RLHF Pipeline: Complete Process Explanation (Interview Style)
The Three-Stage RLHF Pipeline
Stage 1: Supervised Fine-Tuning (SFT)
The first stage of the RLHF pipeline is supervised fine-tuning (SFT), which turns a base language model into a model that can follow instructions and produce appropriate responses. Base models trained on general text are good at continuing text but often do not reliably follow instructions, so this stage teaches that behavior directly. The training data consists of human-written demonstrations: each example pairs a prompt or instruction with a high-quality response written by a human expert.
During SFT, the model is trained with standard maximum likelihood estimation: it learns to predict each next token of the response given the prompt and the preceding response tokens. The loss is the cross-entropy between the model's predicted token distribution and the actual tokens in the human-written response, and in many implementations it is computed only on the response tokens rather than on the prompt. This is a straightforward supervised learning problem, but it matters because it establishes the baseline capabilities the later stages build on: the model learns not just to generate coherent text, but to generate text that follows instructions and gives helpful, relevant responses.
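The SFT loss described above can be sketched in a few lines. This is a minimal, framework-free illustration, not a production implementation: it assumes the model's per-position logits are already available as plain lists, that `targets[t]` is the token the model should predict at position `t`, and that a hypothetical `loss_mask` marks response tokens with 1 and prompt tokens with 0.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary distribution."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sft_loss(logits: Sequence[Sequence[float]],
             targets: Sequence[int],
             loss_mask: Sequence[int]) -> float:
    """Mean next-token cross-entropy over positions where loss_mask == 1.

    logits[t] is the model's score vector for predicting targets[t];
    positions with loss_mask[t] == 0 (the prompt) contribute nothing.
    """
    total, count = 0.0, 0
    for row, tok, keep in zip(logits, targets, loss_mask):
        if keep:
            total -= log_softmax(row)[tok]  # negative log-likelihood of the true token
            count += 1
    if count == 0:
        raise ValueError("loss_mask selects no tokens")
    return total / count


if __name__ == "__main__":
    # Uniform logits over a 4-token vocabulary give a loss of log(4).
    print(sft_loss([[0.0] * 4] * 3, [1, 2, 3], [0, 1, 1]))
```

Masking the prompt means the model is graded only on the part it is supposed to write; whether to mask the prompt is a design choice that varies between implementations.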
The quality of the SFT stage is critical because the resulting model initializes the policy for reinforcement learning, serves as the frozen reference model in later stages, and typically provides the starting weights for the reward model. If the SFT model is poor, the rest of the pipeline will struggle. Significant effort therefore goes into curating high-quality demonstration data, ensuring diversity in the types of instructions and responses, and choosing how long to train so the model performs well without overfitting the demonstrations. The resulting model should produce reasonable responses to a wide variety of prompts, even if those responses do not always match human preferences in style, tone, or content.
Stage 2: Reward Model Training
The second stage of RLHF involves training a reward model that can score how good a response is according to human preferences. This is a critical component because the reward model will provide the learning signal for the reinforcement learning stage. Unlike supervised fine-tuning, which uses demonstration data, reward model training uses preference data, where humans have indicated which of two responses they prefer for a given prompt. This preference data is more scalable to collect than demonstrations because it's easier for humans to compare two responses and indicate a preference than it is to write a high-quality response from scratch.
The reward model is typically initialized from the SFT model, reusing its representations as a starting point. The usual architecture replaces the language-modeling output layer with a linear layer that maps the final hidden state (typically the one at the last token of the prompt plus response) to a single scalar score. During training, the reward model learns to assign higher scores to responses that humans prefer and lower scores to those they do not. The objective can be framed as pairwise binary classification: given a prompt and two responses, the reward model should predict which one humans preferred by giving it the higher score.
The mathematical formulation uses a sigmoid function to convert the difference in reward scores into the probability that the preferred response is preferred: P(y_w ≻ y_l | x) = σ(r(x, y_w) - r(x, y_l)), where x is the prompt, y_w the preferred response, y_l the rejected one, and r the reward model. The loss is the negative log of this probability, -log σ(r(x, y_w) - r(x, y_l)), which pushes the reward model to score preferred responses higher. Training is typically monitored by how accurately the reward model ranks held-out preference pairs. The quality of the reward model is crucial because the policy will learn to exploit any biases or errors it contains during reinforcement learning, potentially leading to undesirable behaviors.
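The scalar head and pairwise loss described above can be sketched as follows. This is an illustrative, dependency-free version under simplifying assumptions: the final hidden state is given as a plain list, and `weight` and `bias` stand in for the learned parameters of the linear head. In a real system these would be tensors and the loss would be minimized by gradient descent.

```python
import math
from typing import Iterable, Sequence, Tuple


def reward_head(hidden: Sequence[float], weight: Sequence[float], bias: float) -> float:
    """Linear layer mapping one final hidden state to a scalar reward."""
    return sum(h * w for h, w in zip(hidden, weight)) + bias


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x)); avoids overflow for large |x|."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def pairwise_loss(r_chosen: float, r_rejected: float) -> float:
    """-log P(chosen preferred), with P = sigmoid(r_chosen - r_rejected)."""
    return -log_sigmoid(r_chosen - r_rejected)


def mean_pairwise_loss(pairs: Iterable[Tuple[float, float]]) -> float:
    """Average loss over (chosen_score, rejected_score) pairs in a batch."""
    losses = [pairwise_loss(c, r) for c, r in pairs]
    return sum(losses) / len(losses)


if __name__ == "__main__":
    print(pairwise_loss(1.0, 1.0))  # equal scores: loss is log(2)
    print(pairwise_loss(3.0, 0.0))  # correct ranking: small loss
    print(pairwise_loss(0.0, 3.0))  # wrong ranking: large loss
```

Only the difference between the two scores enters the loss, so the reward model's absolute scale is arbitrary; this is one reason reward scores are often normalized before use in RL.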
One challenge in reward model training is ensuring that the model generalizes well to responses it hasn't seen during training. The reward model needs to be able to score any response, including those generated by the policy during RL training, which may have different characteristics than the responses in the training data. This requires careful data collection that covers a diverse range of prompts and response types, and it may also require periodic retraining of the reward model as the policy evolves and generates different types of responses.
Stage 3: Reinforcement Learning Optimization (PPO)
The third and final stage of RLHF uses reinforcement learning, specifically Proximal Policy Optimization (PPO), to optimize the policy model, initialized from the SFT model, to maximize the reward model's score while staying close to the reference model. This is the stage in which the model is optimized directly against the reward model's learned approximation of human preferences. Each iteration begins by sampling responses from the current policy for a batch of prompts; the reward model then scores these responses to produce reward signals.
PPO relies on several components to keep training stable and effective. First, it uses a critic model, also called a value function, to estimate the expected future return from each state; in RLHF, a state is the prompt plus the response tokens generated so far, and an action is the next token. The value function reduces the variance of the policy gradient estimates by acting as a baseline. In its simplest form, the advantage of an action is the difference between the observed return and the value function's prediction, which measures how much better or worse the action turned out than expected. In practice, many implementations use generalized advantage estimation (GAE), which blends multi-step estimates to trade off bias against variance.
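Both advantage estimates mentioned above fit in a short sketch. It assumes one reward and one value prediction per generated token, and that the value after the final token is zero because the episode ends there. The defaults `gamma=1.0` and `lam=0.95` are illustrative assumptions, not values taken from this document. With `lam=1.0`, GAE reduces exactly to "return minus value", which the usage example checks.

```python
from typing import List, Sequence


def discounted_returns(rewards: Sequence[float], gamma: float = 1.0) -> List[float]:
    """G_t = r_t + gamma * G_{t+1}, computed backwards from the last token."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def gae_advantages(rewards: Sequence[float], values: Sequence[float],
                   gamma: float = 1.0, lam: float = 0.95) -> List[float]:
    """Generalized advantage estimation over one response.

    values[t] is the critic's estimate V(s_t); the value after the last
    token is taken as 0 because the episode terminates there.
    """
    if len(rewards) != len(values):
        raise ValueError("rewards and values must have the same length")
    n = len(rewards)
    advantages = [0.0] * n
    running = 0.0
    for t in reversed(range(n)):
        next_value = values[t + 1] if t + 1 < n else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # one-step TD error
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages


if __name__ == "__main__":
    rewards, values = [0.0, 0.0, 1.0], [0.2, 0.5, 0.7]
    returns = discounted_returns(rewards)
    # lam = 1 recovers the simple "return minus value" advantage.
    print(gae_advantages(rewards, values, lam=1.0))
    print([g - v for g, v in zip(returns, values)])
```

Lowering `lam` toward 0 relies more on the critic (lower variance, more bias); `lam=1` relies entirely on observed returns.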
The policy update in PPO uses importance sampling so that data collected with an older version of the policy can be reused for several gradient steps. The importance sampling ratio is the probability of an action under the current policy divided by its probability under the policy that collected the data. PPO clips this ratio to the range [1 - ε, 1 + ε] and takes the minimum of the clipped and unclipped objectives. As a result, once the ratio moves past the clip boundary in the direction the advantage favors (above 1 + ε for a positive advantage, below 1 - ε for a negative one), the objective gains nothing from moving further. This discourages large, destabilizing updates while still letting the policy improve within that range.
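A minimal sketch of the clipped objective described above, written per token in plain Python. It assumes log probabilities of the sampled tokens under the current and data-collecting policies, plus precomputed advantages; `clip_eps=0.2` is a commonly used default but is an assumption here. The function returns a loss (the negated objective) so it can be minimized.

```python
import math
from typing import Sequence


def ppo_clip_loss(logp_new: Sequence[float], logp_old: Sequence[float],
                  advantages: Sequence[float], clip_eps: float = 0.2) -> float:
    """Negated PPO clipped surrogate objective, averaged over tokens."""
    if not (len(logp_new) == len(logp_old) == len(advantages)) or not advantages:
        raise ValueError("inputs must be non-empty and equally long")
    total = 0.0
    for new, old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(new - old)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Pessimistic choice: take the smaller of the two surrogate terms.
        total += min(ratio * adv, clipped * adv)
    return -total / len(advantages)


if __name__ == "__main__":
    ln2 = math.log(2.0)
    print(ppo_clip_loss([ln2], [0.0], [1.0]))   # ratio 2, A > 0: clipped at 1.2
    print(ppo_clip_loss([ln2], [0.0], [-1.0]))  # ratio 2, A < 0: not clipped
```

The second case shows the asymmetry: clipping caps the gain from a good action but never hides the penalty for making a bad action more likely.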
A crucial component of PPO in RLHF is the KL divergence penalty, which keeps the policy from drifting too far from the reference model. The penalty is estimated from the log probabilities of the sampled response tokens under the current policy and the reference model; a common approach subtracts β times the per-token log ratio from the reward. By penalizing large KL divergence, the policy retains the capabilities learned during SFT while still adapting to human preferences. The penalty strength β is a hyperparameter that must be tuned carefully: too small, and the policy may over-optimize the reward model; too large, and it barely moves away from the reference model.
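One common way to apply the penalty described above is to fold it into the per-token rewards before computing advantages. The sketch below assumes that convention: every token receives -β times its log ratio against the reference model, and the reward model's scalar score is added at the final token. The function names and the default `beta=0.1` are hypothetical choices for illustration.

```python
from typing import List, Sequence


def shaped_rewards(rm_score: float, logp_policy: Sequence[float],
                   logp_ref: Sequence[float], beta: float = 0.1) -> List[float]:
    """Per-token rewards: KL penalty on every token, RM score on the last."""
    if len(logp_policy) != len(logp_ref) or not logp_policy:
        raise ValueError("log-prob sequences must be non-empty and equally long")
    rewards = [-beta * (p - r) for p, r in zip(logp_policy, logp_ref)]
    rewards[-1] += rm_score  # sequence-level score arrives when the response ends
    return rewards


def sequence_kl_estimate(logp_policy: Sequence[float], logp_ref: Sequence[float]) -> float:
    """Sum of log ratios over sampled tokens: a Monte Carlo estimate of
    KL(policy || reference) for this response."""
    return sum(p - r for p, r in zip(logp_policy, logp_ref))


if __name__ == "__main__":
    pol, ref = [-0.5, -1.0], [-1.0, -1.0]
    print(shaped_rewards(1.0, pol, ref, beta=0.2))  # first token penalized
    print(sequence_kl_estimate(pol, ref))
```

These shaped rewards are what the advantage computation would consume, so the KL term influences every token while the reward model only scores the finished response.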
The complete training loop generates responses, scores them with the reward model, applies the KL penalty, computes advantages with the critic, and then updates the policy with the clipped PPO objective and the critic with a regression loss that fits its value predictions to the observed returns. This process is repeated for many iterations, gradually improving the policy's ability to generate responses that the reward model, and by extension humans, prefer, while preserving its core language modeling capabilities. The result is a model that remains capable and is better aligned with the preferences captured in the feedback data.
Challenges and Solutions in RLHF
One of the major challenges in RLHF is reward hacking, where the policy learns to exploit the reward model in unintended ways rather than actually improving response quality. For example, the policy might learn to generate very long responses if the reward model happens to favor length, or it might learn to use certain phrases that the reward model associates with high scores, even if those phrases don't actually improve the response quality. The KL penalty helps mitigate this by keeping the policy close to the reference model, which was trained to generate good responses through supervised learning. Additionally, careful design of the reward model and regular monitoring of policy behavior can help detect and prevent reward hacking.
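As one hypothetical example of the monitoring mentioned above, a simple check for the length bias is to track the correlation between response length and reward score across a batch of policy samples. A correlation that keeps climbing during training can be a warning sign, though it is not proof of reward hacking on its own. The sketch assumes whitespace-separated word count as a stand-in for length; a real pipeline would more likely count tokens.

```python
import math
from typing import Sequence


def pearson(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Pearson correlation coefficient between two equal-length samples."""
    n = len(xs)
    if n != len(ys) or n < 2:
        raise ValueError("need two equal-length samples with at least 2 points")
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    if vx == 0 or vy == 0:
        raise ValueError("correlation undefined for zero-variance input")
    return cov / math.sqrt(vx * vy)


def length_reward_correlation(responses: Sequence[str], rewards: Sequence[float]) -> float:
    """Correlate (word-count) response length with reward-model score."""
    lengths = [float(len(r.split())) for r in responses]
    return pearson(lengths, rewards)


if __name__ == "__main__":
    batch = ["short answer", "a somewhat longer answer here", "a very long answer " * 5]
    scores = [0.1, 0.4, 0.9]
    print(length_reward_correlation(batch, scores))
```

Comparing this number across training checkpoints is more informative than any single value, since some correlation between length and quality can be legitimate.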
Another challenge is distribution shift: as training proceeds, the policy generates responses that differ from the distribution of responses the reward model was trained on. The reward model's predictions on these out-of-distribution responses can be inaccurate, producing poor learning signals. A common mitigation is to periodically retrain the reward model on new preference data that includes responses from the current policy, keeping it calibrated to the policy's output distribution. Some implementations go further with online or iterated data collection, where humans continually label fresh policy samples and the reward model is updated as training proceeds.
Mode collapse is another concern: the policy converges on a narrow set of responses instead of maintaining diversity. This can happen when the reward model consistently favors certain response patterns, so the policy over-optimizes for them. The KL penalty helps here as well, since it pulls the policy back toward the reference model's broader output distribution, although it does not guarantee diversity on its own. Some implementations also add an entropy bonus to the objective that explicitly rewards diversity in the policy's output distribution.
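The entropy bonus mentioned above can be sketched as subtracting a scaled mean per-token entropy from the policy loss, so that minimizing the loss also rewards spreading probability mass. This is an illustrative sketch: it assumes per-position logits as plain lists, and the coefficient name `ent_coef` and its default of 0.01 are hypothetical choices.

```python
import math
from typing import Sequence


def entropy_from_logits(logits: Sequence[float]) -> float:
    """Shannon entropy (in nats) of the softmax distribution over logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    probs = [e / z for e in exps]
    return -sum(p * math.log(p) for p in probs if p > 0)  # skip underflowed zeros


def loss_with_entropy_bonus(policy_loss: float,
                            per_token_logits: Sequence[Sequence[float]],
                            ent_coef: float = 0.01) -> float:
    """Policy loss minus ent_coef times the mean per-token entropy."""
    if not per_token_logits:
        raise ValueError("need at least one token position")
    mean_entropy = sum(entropy_from_logits(l) for l in per_token_logits) / len(per_token_logits)
    return policy_loss - ent_coef * mean_entropy


if __name__ == "__main__":
    print(entropy_from_logits([0.0, 0.0, 0.0, 0.0]))  # uniform: log(4)
    print(entropy_from_logits([100.0, 0.0, 0.0]))     # near-deterministic: ~0
```

Because a higher entropy lowers the loss, the optimizer is nudged away from collapsing onto a single high-reward continuation; too large a coefficient, however, can make outputs incoherent.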
The computational cost of RLHF is significant. The pipeline trains several models (the SFT model, the reward model, and the policy with its critic), requires large amounts of human preference data, and runs many iterations of the RL loop, during which the policy, critic, reference model, and reward model must all be available for generation, scoring, and updates. In practice, the gains in alignment and response quality are often judged worth this cost. The key is to design an efficient pipeline that extracts as much value as possible from each piece of human feedback while minimizing computational overhead.
Evaluation and Iteration
Evaluating the success of RLHF is challenging because there's no single metric that captures all aspects of alignment. Common evaluation approaches include human evaluation, where humans rate the quality of model responses, automated metrics that measure specific aspects like helpfulness or harmlessness, and red teaming, where experts try to find ways to make the model produce undesirable outputs. The evaluation process is iterative, with the results informing adjustments to the training process, reward model design, and hyperparameters.
The RLHF pipeline is not a one-time process but rather an iterative cycle of improvement. As the model is deployed and used, new data is collected, preferences may evolve, and new issues may be discovered. This requires ongoing monitoring, evaluation, and potentially retraining of various components. The ability to iterate and improve is crucial for maintaining alignment as the model is used in production and as our understanding of what constitutes good alignment evolves.