24th April 2023
ChatGPT’s ability to generate coherent, fluent text responses to a wide assortment of natural language instructions has revolutionised the field of natural language processing (NLP) and captured the attention of the world. This type of large language model (LLM) is trained under a new training paradigm.
Traditionally, a language model would first undergo large-scale pretraining on internet data, and then be fine-tuned on a downstream task, yielding huge performance gains for that specific task. However, when we want to build a model that generalises across hundreds of different natural language tasks, the labelled data required for this supervised fine-tuning (SFT) is practically impossible to acquire.
Which brings us to the topic of this article: "reinforcement learning from human feedback" (RLHF) with "proximal policy optimization" (PPO). This new training paradigm moves beyond the boundaries of task-specific supervised training, allowing us to build generalised LLMs that achieve extremely impressive results across a whole range of natural language tasks.
To understand RLHF and PPO in the context of instruction fine-tuning of LLMs, we need to understand the basic concepts of reinforcement learning (RL). In this post, I aim to give an intuitive overview of how RL works, policy gradient methods, and human feedback in the context of LLM instruction fine-tuning.
This is by no means a comprehensive overview of the field of RL; for example, Q-learning, Bellman equations, and other terms and concepts are beyond the scope of this post. I’ve included the bare minimum that I believe is necessary to understand RLHF, omitting mathematical details and derivations.
The papers that I have based my post on are:
InstructGPT is the model that OpenAI designed to generate text responses to general instruction prompts; ChatGPT builds on top of this by providing a chatbot user interface. Thus, throughout this post, I will be referring to InstructGPT and its paper, though the design and training procedure for ChatGPT-like LLMs are practically the same.
Photo credits: HuggingFace
The InstructGPT paper describes the 3-step method used to train InstructGPT as follows:
Step 1 is the regular SFT that we’re used to. This is a technique used to improve an LLM’s performance on a specific task or domain by fine-tuning its pre-trained weights on a labelled dataset that is much smaller than the one used during pre-training. This is an effective way to improve a model’s performance on a specific task such as text classification, sentiment analysis, or named entity recognition.
The SFT dataset contained about 13k training prompts, but it is not task-specific. It is populated with natural language input prompts written by labellers, covering instructions ranging from writing poetry to answering questions about physics. In addition, each instruction prompt requires a corresponding ground-truth answer written by a human.
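To make step 1 concrete, here is a minimal sketch of what SFT boils down to: standard next-token prediction (cross-entropy loss) on prompt/answer pairs. The model, toy dataset, and hyperparameters below are purely illustrative assumptions, not OpenAI's actual setup.

```python
# Minimal SFT sketch: fine-tune a causal LM on (instruction, answer) pairs
# with the usual next-token cross-entropy loss. Model and data are
# illustrative stand-ins, not the actual InstructGPT setup.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# A toy SFT dataset: each item is an instruction prompt plus a human-written answer.
sft_data = [
    ("Write a haiku about the sea.", "Waves fold into foam..."),
    ("Explain gravity to a child.", "Gravity is the pull that keeps us on the ground..."),
]

model.train()
for prompt, answer in sft_data:
    # Concatenate prompt and answer into one training sequence.
    ids = tokenizer(prompt + "\n" + answer, return_tensors="pt").input_ids
    # labels=input_ids -> the model is trained to predict each next token.
    loss = model(ids, labels=ids).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

(In practice, the loss on the prompt tokens is often masked out so that only the answer tokens contribute, but the overall shape of the training loop is the same.)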
But training an LLM to generalise across a large variety of instruction prompts requires far more than 13k data samples, and getting human writers to come up with new instruction prompts and high-quality answers is expensive, time-consuming, and extremely difficult.
Hence the need for steps 2 and 3 to accelerate the training process and overcome the need for further SFT data. But what are a reward model (RM), policy optimization, and PPO? Why do we need humans in the loop? To fully understand the intuition behind all this, we need to understand the basics of reinforcement learning.
Reinforcement learning (RL) is a field of machine learning that aims to train agents to make good decisions in a specific environment through trial and error by receiving feedback in the form of rewards.
One of the most prominent examples of RL’s successes is the AlphaGo program developed by DeepMind around 2015: a computer program that was designed to play the board game Go, which ended up defeating the world champion Go player in a series of high-profile matches (link here https://www.deepmind.com/research/highlighted-research/alphago).
In order to understand RL, we need to familiarise ourselves with the terminology used widely across RL literature:
Photo credits: Medium
This figure illustrates the high-level pattern that typical RL systems follow. The agent is placed in an environment and observes the current state; it then decides on the most appropriate action to take based on its policy. Every action causes a change in the current state, and the agent receives a reward signal indicating whether the new state is good or bad. The goal of the agent is to learn a policy that maximises the expected cumulative reward for a given session.
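That loop can be written down in a few lines of code. The toy environment and the random policy below are illustrative placeholders of my own; the only point is the state, action, reward, update cycle.

```python
import random

class ToyEnvironment:
    """A stand-in environment: the 'state' is just a counter and the
    reward is +1 when the action matches a simple hidden rule."""
    def __init__(self):
        self.state = 0

    def step(self, action):
        reward = 1.0 if action == self.state % 3 else 0.0
        self.state += 1
        done = self.state >= 10            # the episode ends after 10 steps
        return self.state, reward, done

def policy(state):
    # A real agent would map state -> action using learned parameters;
    # here we just act randomly to illustrate the loop.
    return random.choice([0, 1, 2])

env = ToyEnvironment()
state, done, total_reward = env.state, False, 0.0
while not done:                             # one episode ("session")
    action = policy(state)                  # agent picks an action from its policy
    state, reward, done = env.step(action)  # environment returns new state + reward
    total_reward += reward                  # agent accumulates reward
    # ...a learning agent would now use the reward to improve its policy
print("cumulative reward:", total_reward)
```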
In RL, our agent is usually some kind of computer system or neural network model that is tasked with optimizing its policy (e.g., its model weights) within its environment to receive the highest reward possible.
To get a better grasp of how RL works, let's consider the example of an RL agent we want to teach to play the classic game of Super Mario Bros:
Now that we have a basic understanding of how this works in the context of a video game, let’s consider how RL works when training LLMs to output text based on an input prompt consisting of general natural language tasks:
To successfully train an agent via RL, we need a suitable RM that returns a reward signal based on the actions chosen by the agent. The reward is communicated to the agent, which in turn uses this information to update its policy with the goal of maximising the expected cumulative reward for the next session.
If we return to our Super Mario Bros analogy, a simple RM (but by no means sufficient in reality) can be the scoring mechanism in the video game itself (e.g., the reward signal is the video game score for the actions taken so far in a particular level). Here, the agent’s objective is to update its current policy such that it learns the appropriate actions in each state that will lead to a higher score in the video game.
However, for LLMs, the RM itself is not so straightforward. We don’t have a scoring mechanism to tell us whether an LLM’s response to an input prompt is good or not. Designing an appropriate RM is not a trivial task, as we want an RM that is informative and captures the desired behaviour of the LLM, but is also easy to compute and does not lead to unintended consequences and loopholes.
Usually, in traditional SFT of language models, the loss function tells us how well the model is performing by calculating the discrepancy between the model’s predictions and the ground-truth answers. But the problem is, we don’t have any more ground-truth answers to natural language instructions, nor are we looking to perform simple SFT.
So where do we find this RM that will appropriately score the text outputs of our LLM for any given instruction prompt? This is where training a RM from human feedback comes in.
Photo credits: HuggingFace
To train an RM for RL on LLMs, an RM dataset needs to be carefully curated by human annotators. These annotators judge the overall quality of the text outputs generated by various LLMs by ranking them from best to worst response.
One may initially think that humans should assign a scalar score directly to each piece of text to train an RM, but this is difficult to do in practice. The differing values held by different people cause these scores to be uncalibrated and noisy. Instead, rankings are used to compare the outputs of multiple LLMs and create a much better regularized dataset.
From a practical perspective, it is also much easier to rank outputs than to hand-write responses to varied text prompts. This way, OpenAI was able to collect 33k training prompts for the RM dataset, compared to the 13k training prompts for the SFT dataset.
Photo credits: InstructGPT paper
Another advantage of training an RM from human feedback is that the RM itself can effectively encode the preferences of the human labellers, e.g., is this harmful, is this safe, is this funny, etc. These are very subjective values that can’t be encoded in objective loss functions. But by gathering tagged data from human labellers indicating whether an output is "harmful" or not (e.g., via a Likert scale, see here), we can train an RM to learn how to respond to the question "how harmful is this response?".
Photo credits: InstructGPT paper
This RM data is then used to perform SFT on another LLM, which effectively learns which generated text output to an instruction prompt a human labeller would prefer. In other words, the RM learns to imitate how a human would rate a response generated for an instruction prompt. The RM will indicate whether a generated response is “good” according to the ranking preferences of the human labellers, and this judgement is presented as a reward score.
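Concretely, the human rankings are broken into pairwise comparisons, and the RM is trained to score the preferred response higher than the rejected one. The sketch below shows roughly that pairwise ranking loss (in the spirit of the formulation in the InstructGPT paper); the tensor values and function name are my own illustrative choices.

```python
# Sketch of the pairwise ranking loss used to train a reward model from
# human rankings: for a given prompt, the RM should score the preferred
# response y_w higher than the rejected response y_l.
import torch
import torch.nn.functional as F

def reward_model_loss(r_preferred: torch.Tensor, r_rejected: torch.Tensor) -> torch.Tensor:
    # r_preferred, r_rejected: scalar reward scores produced by the RM
    # for the preferred and rejected responses to the same prompt.
    return -F.logsigmoid(r_preferred - r_rejected).mean()

# Toy example: the RM currently scores the rejected response higher,
# so the loss is large and gradients push the two scores apart the right way.
r_w = torch.tensor([0.2], requires_grad=True)
r_l = torch.tensor([1.1], requires_grad=True)
loss = reward_model_loss(r_w, r_l)
loss.backward()
print(loss.item())
```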
So, we now have an RM trained on human feedback that can signal to our LLM whether the text response generated for an input prompt is good or bad. But remember that in RL, the aim is for our LLM to learn the policy that maximises the cumulative reward for a given set of actions (the generated text response). How do we get the agent to update its policy from the reward signals?
Photo credits: OpenAI - Spinning Up
In RL, there are many methods/algorithms for teaching agents to perform better, each with their own strengths, weaknesses, and use-cases. But for this article, we will look exclusively at a class of RL algorithms called “policy gradient methods”, and specifically at proximal policy optimization (PPO) later on. This is the most relevant class, as we can model the LLM training problem as a model-free RL problem, and it is also the approach used by OpenAI for InstructGPT.
Policy gradient methods aim to optimize the policy directly. For our LLMs, the policy is effectively the LLM itself, which decides what words to output based on the current state of the conversation or input prompt. Thus, the policy can be viewed as a parameterized function, where the parameters are the LLM weights. The optimization takes place by calculating the “policy gradients”, i.e., the gradients of an “objective function” with respect to the policy parameters. Let’s talk about these two new concepts.
The objective function specifies the goal that the agent is trying to achieve (i.e., maximise the expected cumulative reward for the text generated by an LLM based on the input prompts), which is often defined as a function of the policy and RM.
The main difference between the various policy gradient methods (VPG, TRPO, DPG, PPO, etc.) is how they’ve defined their objective function – each having their own advantages and disadvantages. But they all use the objective function to optimize the policy of the agent (by updating the policy parameters, i.e., the model weights).
Let us consider the simplest policy gradient method here, the “vanilla policy gradient” (VPG). (For the full details and derivations, please check here.)
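Roughly, in the notation of OpenAI's Spinning Up guide, the VPG objective is simply the expected return over trajectories sampled from the current policy, where $\pi_\theta$ is the policy (our LLM with weights $\theta$), $\tau$ is a trajectory sampled by running that policy, and $R(\tau)$ is the return accumulated along it:

$$ J(\pi_\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\big[ R(\tau) \big] $$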
Note: the trajectory is the sequence of actions taken by the agent in response to a given state throughout the episode. For example, the word generated based on a given text sequence, then the next word generated based on the updated text sequence, and so on, until the completed text output is generated (i.e., the episode is completed).
In essence, VPG aims to maximise a very simple objective function: the expected return from an RM for a given trajectory taken by the agent (i.e., the reward score the RM gives to the text generated by the LLM for a given input prompt).
Here we want to use the objective function to update the policy via gradient ascent (as we want to maximise the score, unlike gradient descent, where we want to minimize a loss function). We can do this by optimizing the policy parameters (our model weights) using the “policy gradients”.
The policy gradient is the gradient of the objective function with respect to the policy parameters, which is used to update the parameters (model weights) and thus the policy of the agent (our LLM). This should look very familiar to you: it’s basically gradient descent, except that in gradient ascent the sign in front of the gradient term is positive!
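Concretely, a single gradient ascent step on the policy parameters $\theta$ with learning rate $\alpha$ is just the standard update rule with a plus sign:

$$ \theta_{k+1} = \theta_k + \alpha \nabla_\theta J(\pi_\theta)\big|_{\theta_k} $$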
But note that the reward score is not differentiable with respect to any learnable parameters, because the RM simply returns a ranking score rather than a ground-truth answer (which we don’t have here, unlike in the SFT setting). How do we calculate the policy gradient if we can’t directly take the derivatives of the objective function (which is a function of the reward)?
Luckily, in RL, the policy gradient can be expressed in other ways (the proof is beyond the scope of this article, but can be found here).
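In VPG, roughly in Spinning Up's notation, the gradient can be written purely in terms of quantities we can compute: the log-probabilities of the actions the policy actually took, weighted by the return of the trajectory, so the non-differentiable RM score only ever enters as a multiplicative weight:

$$ \nabla_\theta J(\pi_\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[ \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, R(\tau) \right] $$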
The mathematical details behind this policy gradient are not important here; the point is that all policy gradient methods share the same characteristic: they update the policy parameters (LLM weights) by calculating policy gradients and performing gradient ascent. The distinguishing feature between different policy gradient methods is how the objective function is defined, which in turn affects how we calculate the policy gradient to update our policy.
During training, policy gradients are calculated via backpropagation and used to update the model weights with a typical optimization algorithm such as stochastic gradient ascent, and this process is repeated iteratively until the model weights have been updated sufficiently. In other words, the LLM’s weights are slowly adjusted so that it is incentivized to generate responses to instruction prompts that maximise the objective function (which should indirectly reflect a “good” natural language response), and conversely discouraged from generating bad responses.
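To make this concrete for an LLM, here is a minimal sketch of one vanilla-policy-gradient (REINFORCE-style) update for a single prompt. The model, prompt, and the placeholder reward value are my own illustrative assumptions; the point is only the shape of the computation: sample a response, sum the log-probabilities of its tokens, multiply by the reward, and ascend that objective.

```python
# Sketch of one vanilla-policy-gradient update for an LLM.
# The model, prompt, and reward value are illustrative assumptions only.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
policy = AutoModelForCausalLM.from_pretrained("gpt2")        # the policy = the LLM
optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-6)

prompt_ids = tokenizer("Explain photosynthesis briefly.", return_tensors="pt").input_ids

# 1. Sample a trajectory: the generated response tokens are the "actions".
with torch.no_grad():
    full_ids = policy.generate(prompt_ids, do_sample=True, max_new_tokens=30,
                               pad_token_id=tokenizer.eos_token_id)

# 2. Get a reward for the whole response (in RLHF this would come from the RM).
reward = 0.7                                                 # placeholder scalar reward

# 3. Re-run the policy with gradients and collect log-probs of the response tokens.
logits = policy(full_ids).logits[:, :-1, :]                  # predictions for tokens 1..T
log_probs = F.log_softmax(logits, dim=-1)
targets = full_ids[:, 1:]                                     # the tokens actually produced
token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
response_log_prob = token_log_probs[:, prompt_ids.shape[1] - 1:].sum()

# 4. Gradient ascent on (log-prob of trajectory) * reward,
#    implemented as descent on the negative.
loss = -(reward * response_log_prob)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```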
Photo credits: StackOverflow
Note: The main difference between an objective function and a loss function is that the former is used to maximise the performance (or reward received) of the LLM on a particular task, while the latter measures the error of the LLM’s predictions against ground-truth values during training. However, both are optimized through gradient ascent/descent. Though I use the term objective function here, the subtle differences aren’t that important, and you can think of the objective function as the loss function of an RL problem.
As mentioned above, VPG is a simple algorithm in which the objective function is purely the reward signal for an LLM response, as calculated by our RM. This works fairly well in theory, but VPG suffers from several problems that can make it difficult to train our model and lead to suboptimal performance.
These problems include high variance, limited exploration, limited sample efficiency, and inconsistent policy updates.
PPO is the algorithm used by OpenAI to optimize the policy in InstructGPT, and therefore this is the one we’ll look at here. According to the PPO paper, there are two variations of PPO:
In practice, both variations of PPO should ensure that the LLM generates reasonably coherent text responses. Without the clipped objective function or the KL divergence penalty, the optimized policy can drift into a region of policy space that tricks the RM into giving a high reward despite the model generating incomprehensible text.
The PPO variation that InstructGPT adopts here is the one that makes use of the “KL penalty coefficient”, though it is not clear why they have opted for this variation when the “clipped surrogate objective” variation appears to perform better (at least for simulated robotic locomotion and Atari game playing, see paper).
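For reference, the clipped surrogate objective from the PPO paper looks roughly like this, where $r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$ is the probability ratio between the new and old policies, $\hat{A}_t$ is the advantage estimate, and $\epsilon$ is a small clipping constant (e.g., 0.2):

$$ L^{\text{CLIP}}(\theta) = \hat{\mathbb{E}}_t\Big[ \min\big( r_t(\theta)\hat{A}_t,\ \text{clip}(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon)\hat{A}_t \big) \Big] $$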
I’ve placed the segment from the InstructGPT paper that defines their objective function below:
Photo credits: InstructGPT paper
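For readers who prefer text to screenshots, my rough transcription of that objective in the paper's notation is below: $\pi_\phi^{\text{RL}}$ is the policy being trained, $\pi^{\text{SFT}}$ is the frozen SFT model, $r_\theta$ is the reward model, $\beta$ is the KL penalty coefficient, and $\gamma$ weights an extra pretraining-data term that counteracts performance regressions on standard NLP benchmarks:

$$ \text{objective}(\phi) = \mathbb{E}_{(x, y) \sim D_{\pi_\phi^{\text{RL}}}}\Big[ r_\theta(x, y) - \beta \log\big( \pi_\phi^{\text{RL}}(y \mid x) / \pi^{\text{SFT}}(y \mid x) \big) \Big] + \gamma\, \mathbb{E}_{x \sim D_{\text{pretrain}}}\Big[ \log \pi_\phi^{\text{RL}}(x) \Big] $$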
The point is that with PPO, we make use of an objective function that addresses the problems of high variance, limited exploration, limited sample efficiency, and inconsistent policy updates associated with VPG. Here are some of the ways PPO addresses these problems:
Photo credits: Reddit
Note that during RL training with PPO, we have 3 models: the frozen LLM and the LLM being trained, so that we can calculate the KL divergence term in the PPO objective function, and of course the RM itself, whose weights are also frozen to ensure the reward scores are consistent during training.
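Here is a sketch of how those three models interact when scoring one generated response: the frozen reference model and the trainable policy supply log-probabilities for the KL term, and the frozen RM supplies the scalar score. The per-token KL formulation is similar in spirit to what open-source RLHF libraries do, but the function name, toy numbers, and the value of beta are my own illustrative assumptions.

```python
# Sketch of the KL-penalized reward during PPO training with three models:
# the trainable policy, a frozen copy of it (the reference / SFT model),
# and the frozen reward model. Names and beta are illustrative assumptions.
import torch

def rlhf_reward(policy_logprobs: torch.Tensor,   # log p_policy(token) for each response token
                ref_logprobs: torch.Tensor,      # log p_ref(token) from the frozen model
                rm_score: torch.Tensor,          # scalar score from the frozen reward model
                beta: float = 0.1) -> torch.Tensor:
    # Approximate KL divergence between policy and reference, summed over tokens.
    kl = (policy_logprobs - ref_logprobs).sum()
    # The RM score is reduced by the KL penalty, discouraging the policy
    # from drifting too far from the SFT model just to please the RM.
    return rm_score - beta * kl

# Toy numbers: 5 response tokens, policy slightly more confident than the reference.
policy_lp = torch.tensor([-2.0, -1.5, -3.0, -0.8, -2.2])
ref_lp    = torch.tensor([-2.3, -1.7, -3.1, -1.0, -2.4])
print(rlhf_reward(policy_lp, ref_lp, rm_score=torch.tensor(1.4)))
```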
In practice, having 3 LLM-sized models in one training environment is resource heavy. For example, a full-precision 20B-parameter LLM will require around 80GB of GPU memory, and 3 times this is 240GB (notwithstanding the training data and other considerations). There are ways to make this work in practice, such as using an 8-bit model instead, running on multiple GPUs (with data parallelism, model parallelism, etc.), and more. These practical considerations should not be ignored when performing such a compute-intensive operation.
So now we can see how we can optimize a policy against the RM using PPO. In step 3, a PPO dataset of 31k instruction prompts (with no ground-truth labels) is passed to the LLM to generate output responses. The LLM then updates its model weights via PPO in order to maximise the objective function (which is a function of the reward model’s score).
But let’s consider some of the key advantages of RLHF and why this has led to such amazing results, allowing the LLM to perform well across other tasks that it has never seen before.
Caveat: although RLHF has generated phenomenal results, work on other LLMs like Alpaca, Vicuna, etc. has shown that we can achieve close to ChatGPT-like performance without RLHF, with SFT alone. This indicates that we can think of RLHF as a method that makes up for the lack of SFT data. If we had 100k perfectly curated instruction prompts and their ground-truth answers for SFT, I suspect we wouldn’t have to go through the process of RLHF and would still achieve ChatGPT-like performance, or even better.
NLP is moving very quickly right now, and every week there are new models, papers, and results shifting the landscape. As more research comes out, RLHF may continue to be the dominant training method for LLMs, or it could fall to the side-lines, replaced by a new training paradigm. Regardless, there is no denying that it has caused a massive shift in how we think about training LLMs and aligning them to human preferences. Understanding RLHF will only help with following the exciting research and literature coming our way.
To learn more about RL in detail: