PPO Models: Detailed Explanation of All Components
Overview
In PPO (Proximal Policy Optimization) used for RLHF, there are four key models/components that work together. This document explains each one in detail: what they are, their mathematical role, how they're used, and where they appear in the training pipeline.
Part 1: The Four Models in PPO/RLHF
Model 1: Policy Model ($\pi_\theta$)
What it is:
- The main model being trained
- Generates responses/actions
- Outputs probability distribution over actions
- This is what we're optimizing
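Mathematical role. $\pi_\theta(a \mid s)$ is the probability of action $a$ given state $s$.
In language models. $\pi_\theta(y \mid x)$ is the probability of generating response $y$ given prompt $x$.
Outputs:
- Log probabilities: $\log \pi_\theta(a \mid s)$
- Action probabilities: $\pi_\theta(a \mid s)$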
- Can also include value estimate (if using actor-critic architecture)
Where it's used:
- Generation: generate responses during training.
- Loss computation: compute policy gradient.
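- Importance sampling: compute the ratio $r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_{\text{old}}}(a_t \mid s_t)$.
Mathematical formulation in PPO:
$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\, A_t,\ \operatorname{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right) A_t\right)\right]$$
where $r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_{\text{old}}}(a_t \mid s_t)$ is the probability ratio against the policy snapshot $\pi_{\theta_{\text{old}}}$ that generated the samples, $A_t$ is the advantage, and $\epsilon$ is the clip range. $L^{\text{CLIP}}$ is a surrogate objective to be maximized.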
Key point. This is the model we're training. It learns to maximize reward, constrained by a KL penalty to stay close to the reference.
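As a rough illustration of the clipped objective, the sketch below evaluates it in plain Python from per-token log-probabilities. The function name `clipped_surrogate` and the default `eps=0.2` are assumptions made for this example, not values taken from the text.

```python
import math

def clipped_surrogate(logp_new, logp_old, advantages, eps=0.2):
    """Mean PPO clipped surrogate objective (a quantity to maximize)."""
    total = 0.0
    for lp_new, lp_old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(lp_new - lp_old)                # r_t(theta)
        clipped = max(1.0 - eps, min(1.0 + eps, ratio))  # clip(r_t, 1-eps, 1+eps)
        total += min(ratio * adv, clipped * adv)         # pessimistic of the two
    return total / len(advantages)

# Identical policies: every ratio is 1, so the result is the mean advantage.
same = clipped_surrogate([-1.0, -2.0], [-1.0, -2.0], [0.5, -0.5])

# A ratio far above 1 + eps with a positive advantage is capped at (1 + eps) * A.
capped = clipped_surrogate([0.0], [-1.0], [1.0])
```

In a real trainer the same arithmetic would run on tensors and the negated result would be handed to an optimizer; the scalar version is only meant to make the min/clip structure visible.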
Model 2: Critic Model (Value Function $V_\phi$)
What it is:
- Estimates the value of a state
- Predicts expected future return
- Used to compute advantages
- Can be a separate model or share parameters with the policy
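Mathematical role. $V_\phi(s) \approx \mathbb{E}_{\pi_\theta}\left[\sum_{k \ge 0} \gamma^k\, r_{t+k} \mid s_t = s\right]$, the expected discounted return from state $s$ under the current policy.
In language models. $V_\phi(x)$ is the expected reward for prompt $x$.
Outputs:
- Scalar value estimate $V_\phi(s)$.
- Used to compute advantages $A_t$.
Where it's used:
- Advantage computation: $A_t = R_t - V_\phi(s_t)$.
- Value loss: $L^{V}(\phi) = \mathbb{E}_t\left[(V_\phi(s_t) - R_t)^2\right]$.
- Baseline: reduces variance in the policy gradient.
Mathematical formulation.
$$A_t = R_t - V_\phi(s_t)$$
with
$$R_t = \sum_{k=0}^{T-t} \gamma^k\, r_{t+k}$$
Value loss.
$$L^{V}(\phi) = \mathbb{E}_t\left[\left(V_\phi(s_t) - R_t\right)^2\right]$$
where $R_t$ is the actual return (discounted sum of rewards).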
Key point. Estimates how good a state is, used to compute advantages (how much better than average), trained with MSE loss against actual returns.
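A small sketch of how returns, advantages, and the MSE value loss fit together for one episode. The helper names and the toy numbers are assumptions for illustration; the advantage here is the simple return-minus-value form rather than a more elaborate estimator.

```python
def discounted_returns(rewards, gamma=1.0):
    """R_t = r_t + gamma * R_{t+1}, computed backwards over one episode."""
    returns, running = [], 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    return returns[::-1]

def advantages_and_value_loss(rewards, values, gamma=1.0):
    """Return (returns, advantages, mean squared value error)."""
    returns = discounted_returns(rewards, gamma)
    advantages = [R - v for R, v in zip(returns, values)]
    value_loss = sum(a * a for a in advantages) / len(advantages)
    return returns, advantages, value_loss

# Toy episode: the only reward arrives at the last step.
rewards = [0.0, 0.0, 1.0]
values = [0.5, 0.5, 0.5]   # critic predictions for each step
returns, adv, vloss = advantages_and_value_loss(rewards, values, gamma=1.0)
```

With `gamma=1.0` every step sees the full final reward, so each advantage is `1.0 - 0.5`; lowering `gamma` shrinks the return credited to earlier steps.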
Architecture options:
- Separate critic: independent model $V_\phi$.
- Shared base: policy and critic share base layers, separate heads.
- Actor-critic: single model with policy and value heads.
Model 3: Reference Model ($\pi_{\text{ref}}$)
What it is:
- Frozen copy of the policy before RL training
- Used to compute the KL penalty
- Prevents the policy from deviating too much
- Typically the SFT (supervised fine-tuned) model
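Mathematical role. $\pi_{\text{ref}}(a \mid s)$ is the (frozen) reference policy. For language models, $\pi_{\text{ref}}(y \mid x)$ is the reference model's probability of response $y$ given prompt $x$.
Outputs:
- Log probabilities $\log \pi_{\text{ref}}(y \mid x)$.
- Used to compute KL divergence.
Where it's used:
- KL penalty computation: $D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}})$.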
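- Log-ratio against the policy: $\log \pi_\theta(y \mid x) - \log \pi_{\text{ref}}(y \mid x)$, the per-sample KL term. The PPO importance ratio is a different quantity: it compares $\pi_\theta$ with the policy snapshot that generated the samples, not with $\pi_{\text{ref}}$.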
- Regularization: prevents policy collapse.
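Mathematical formulation.
$$D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}}) = \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\left[\log \pi_\theta(y \mid x) - \log \pi_{\text{ref}}(y \mid x)\right]$$
where $\log \pi(y \mid x) = \sum_t \log \pi(y_t \mid x, y_{<t})$ for both models.
In the PPO loss.
$$L^{\text{PPO}}(\theta) = -L^{\text{CLIP}}(\theta) + \beta\, D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}})$$
with $\beta > 0$ the KL coefficient and $L^{\text{CLIP}}$ the clipped surrogate objective.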
Key point. Frozen (not trained); provides stability, prevents mode collapse, and ensures the policy doesn't forget SFT capabilities.
Why important:
- Prevents mode collapse: keeps the policy diverse.
- Prevents reward hacking: constrains the policy.
- Maintains capabilities: preserves SFT knowledge.
- Stability: prevents large policy changes.
Model 4: Reward Model ($r_\psi$)
What it is:
- Predicts a reward for a response
- Trained on human preferences
- Scores how good a response is
- Used to compute rewards during RL training
Mathematical role. $r_\psi(x, y)$ is the scalar reward for response $y$ to prompt $x$. Higher means better response.
Where it's used:
- Reward computation: score generated responses.
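- Return computation: $R_t = \sum_{k=0}^{T-t} \gamma^k\, r_{t+k}$, with the score $r_\psi(x, y)$ usually assigned at the final token.
- Advantage computation: $A_t = R_t - V_\phi(s_t)$.
Mathematical formulation.
$$r_\psi : (x, y) \mapsto r_\psi(x, y) \in \mathbb{R}$$
Training (before RL). Bradley–Terry preference loss:
$$L^{\text{RM}}(\psi) = -\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}}\left[\log \sigma\left(r_\psi(x, y_w) - r_\psi(x, y_l)\right)\right]$$
where $y_w$ is the chosen (winning) response, $y_l$ is the rejected (losing) response, and $\sigma$ is the sigmoid function.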
Key point. Trained separately before RL; captures human preferences; used to score responses during RL training; can be frozen or updated during RL.
Why important:
- Human preferences: encodes what humans want.
- Reward signal: provides the learning signal for the policy.
- Quality assessment: measures response quality.
Part 2: How They Work Together in PPO Training
Complete PPO Training Loop
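Step 1: Generate responses. Using policy model $\pi_\theta$:
$$y \sim \pi_\theta(\cdot \mid x)$$
Step 2: Score with reward model. Using reward model $r_\psi$:
$$r = r_\psi(x, y)$$
Step 3: Get log probabilities.
$$\log \pi_\theta(y \mid x), \qquad \log \pi_{\text{ref}}(y \mid x)$$
The policy log-probabilities recorded at sampling time serve as $\log \pi_{\theta_{\text{old}}}$ in the PPO ratio.
Step 4: Compute returns.
$$R_t = \sum_{k=0}^{T-t} \gamma^k\, r_{t+k}$$
Step 5: Compute values. Using critic model $V_\phi$:
$$V_t = V_\phi(s_t)$$
Step 6: Compute advantages.
$$A_t = R_t - V_t$$
Step 7: Compute PPO loss.
$$L^{\text{PPO}}(\theta) = -L^{\text{CLIP}}(\theta) + \beta\, D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}})$$
where $L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta) A_t,\ \operatorname{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t\right)\right]$ is the clipped surrogate objective.
Step 8: Update models.
- Update policy $\pi_\theta$: minimize $L^{\text{PPO}}(\theta)$.
- Update critic $V_\phi$: minimize $L^{V}(\phi)$.
- Reference $\pi_{\text{ref}}$: frozen (no update).
- Reward $r_\psi$: typically frozen (can be updated).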
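To make the data flow of the steps above concrete, here is a toy pass over a "policy" that chooses among three candidate responses to one prompt. Everything in it is an assumption for illustration: the policy and reference are plain logit lists, the reward model is a lookup function, the critic is a single scalar, and no parameter update is performed.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def rollout_pass(policy_logits, ref_logits, reward_fn, value,
                 beta=0.1, eps=0.2, seed=0):
    """Walk steps 1-7 once for a single prompt (no gradient step)."""
    rng = random.Random(seed)
    probs, ref_probs = softmax(policy_logits), softmax(ref_logits)
    y = rng.choices(range(len(probs)), weights=probs)[0]  # step 1: sample
    reward = reward_fn(y)                                 # step 2: score
    logp_old = math.log(probs[y])                         # step 3: log-probs
    logp_ref = math.log(ref_probs[y])
    ret = reward                                          # step 4: one-step return
    advantage = ret - value                               # steps 5-6
    logp_new = logp_old  # before any update, theta equals theta_old
    ratio = math.exp(logp_new - logp_old)                 # step 7
    clipped = max(1.0 - eps, min(1.0 + eps, ratio))
    surrogate = min(ratio * advantage, clipped * advantage)
    kl_sample = logp_new - logp_ref  # single-sample KL term
    return {
        "response": y,
        "reward": reward,
        "advantage": advantage,
        "policy_loss": -surrogate + beta * kl_sample,
        "value_loss": (value - ret) ** 2,
    }

scores = [0.0, 1.0, 2.0]  # toy reward for each candidate response
out = rollout_pass([0.0, 0.0, 0.0], [0.0, 0.0, 0.0],
                   lambda y: scores[y], value=1.0)
```

Because the policy equals the reference and no update has happened yet, the ratio is 1 and the KL term is 0, so the policy loss reduces to the negated advantage. Step 8 would follow by differentiating `policy_loss` and `value_loss` with respect to the policy and critic parameters.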
Part 3: Mathematical Details for Each Model
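Policy Model ($\pi_\theta$): detailed mathematics
Forward pass. Input prompt $x$, output response $y = (y_1, \dots, y_T)$ with probability $\pi_\theta(y \mid x)$. For each token:
$$\pi_\theta(y_t \mid x, y_{<t}) = \operatorname{softmax}(z_t)_{y_t}, \qquad \pi_\theta(y \mid x) = \prod_{t=1}^{T} \pi_\theta(y_t \mid x, y_{<t})$$
where $z_t$ is the logit vector at position $t$.
Log probability.
$$\log \pi_\theta(y \mid x) = \sum_{t=1}^{T} \log \pi_\theta(y_t \mid x, y_{<t})$$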
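The per-token factorization can be checked numerically. The sketch below assumes the model has already produced one logit vector per generated position; `sequence_logprob` is a hypothetical helper name, and the four-token vocabulary is a toy choice.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over one logit vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]

def sequence_logprob(step_logits, token_ids):
    """Sum over positions of log pi(y_t | x, y_<t)."""
    return sum(log_softmax(z)[tok] for z, tok in zip(step_logits, token_ids))

# Two generated tokens, each drawn from a uniform distribution over 4 tokens.
step_logits = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
lp = sequence_logprob(step_logits, [1, 3])   # log(1/4) + log(1/4)
```

Working in log space avoids the underflow that multiplying many small per-token probabilities would cause on long responses.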
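Policy gradient.
$$\nabla_\theta J(\theta) = \mathbb{E}_t\left[\nabla_\theta \log \pi_\theta(a_t \mid s_t)\, A_t\right]$$
where $s_t = (x, y_{<t})$, $a_t = y_t$, and $A_t$ is the advantage.
PPO clipping.
$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\, A_t,\ \operatorname{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right) A_t\right)\right]$$
with $r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_{\text{old}}}(a_t \mid s_t)$ and clip range $\epsilon$.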
This prevents large policy updates, over-optimization, and training instability.
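One way to see why clipping limits the update is to look at the slope of the per-sample objective with respect to the new log-probability. The sketch below uses a central finite difference; the function names and `eps=0.2` are assumptions for this example.

```python
import math

def per_sample_objective(logp_new, logp_old, adv, eps=0.2):
    ratio = math.exp(logp_new - logp_old)
    clipped = max(1.0 - eps, min(1.0 + eps, ratio))
    return min(ratio * adv, clipped * adv)

def slope(logp_new, logp_old, adv, h=1e-6):
    """Central finite difference of the objective w.r.t. the new log-prob."""
    up = per_sample_objective(logp_new + h, logp_old, adv)
    down = per_sample_objective(logp_new - h, logp_old, adv)
    return (up - down) / (2 * h)

# Ratio 1 (inside the clip range): the sample still pushes the policy.
inside = slope(0.0, 0.0, adv=1.0)

# Ratio 1.5 with a positive advantage: already past 1 + eps, slope is zero.
outside = slope(math.log(1.5), 0.0, adv=1.0)

# Ratio 1.5 with a negative advantage: the min keeps the unclipped term,
# so the policy can still be pushed back toward the old one.
negative = slope(math.log(1.5), 0.0, adv=-1.0)
```

The zero slope in the second case is the mechanism at work: once a sample's ratio has moved past the clip range in the direction its advantage favors, that sample stops contributing gradient.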
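Critic Model ($V_\phi$): detailed mathematics
Value function.
$$V^{\pi}(s) = \mathbb{E}_{\pi}\left[\sum_{t=0}^{\infty} \gamma^t\, r_t \,\middle|\, s_0 = s\right]$$
where $\gamma$ is the discount factor, $r_t$ the reward at time $t$, and $\pi$ the current policy.
Bellman equation.
$$V^{\pi}(s) = \mathbb{E}_{a \sim \pi(\cdot \mid s),\ s'}\left[r(s, a) + \gamma\, V^{\pi}(s')\right]$$
Value loss.
$$L^{V}(\phi) = \mathbb{E}_t\left[\left(V_\phi(s_t) - R_t\right)^2\right]$$
Gradient.
$$\nabla_\phi L^{V}(\phi) = 2\, \mathbb{E}_t\left[\left(V_\phi(s_t) - R_t\right) \nabla_\phi V_\phi(s_t)\right]$$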
Why a value function:
- Baseline: reduces variance in the policy gradient.
- Advantages: $A_t = R_t - V_\phi(s_t)$ measures how much better than average an outcome was.
- Stability: more stable than raw returns.
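The value loss and its gradient can be illustrated with the simplest possible critic: one free parameter per state, so that the gradient of each value with respect to its own parameter is 1. This tabular setup, the function name, and the learning rate are assumptions for the sketch; a real critic is a neural network head.

```python
def value_gradient_step(values, returns, lr=0.1):
    """One gradient-descent step on the mean squared error of a tabular critic."""
    n = len(values)
    # d/dV_i of (1/n) * sum_j (V_j - R_j)^2  =  2 * (V_i - R_i) / n
    grads = [2.0 * (v - R) / n for v, R in zip(values, returns)]
    return [v - lr * g for v, g in zip(values, grads)]

values = [0.0, 0.0]     # initial estimates for two states
returns = [1.0, -1.0]   # observed returns for those states
for _ in range(200):
    values = value_gradient_step(values, returns, lr=0.5)
```

Each step moves every estimate a fixed fraction of the way toward its observed return, so the estimates converge to the returns and the advantages computed from them shrink toward zero on this fixed data.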
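Reference Model ($\pi_{\text{ref}}$): detailed mathematics
KL divergence.
$$D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}}) = \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\left[\log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}\right]$$
In practice. The expectation is estimated from $N$ responses $y_i$ sampled from $\pi_\theta$:
$$\hat{D}_{\text{KL}} = \frac{1}{N} \sum_{i=1}^{N} \left[\log \pi_\theta(y_i \mid x_i) - \log \pi_{\text{ref}}(y_i \mid x_i)\right]$$
Properties.
- $D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}}) \ge 0$ (always non-negative).
- $D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}}) = 0$ iff $\pi_\theta = \pi_{\text{ref}}$.
- Asymmetric: $D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}}) \ne D_{\text{KL}}(\pi_{\text{ref}} \,\|\, \pi_\theta)$ in general.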
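These properties, and the sample-based estimate used in practice, can be checked on small discrete distributions. The two distributions below are arbitrary toy values, and `sampled_kl` is a hypothetical helper that mimics averaging log-ratios over responses drawn from the policy.

```python
import math
import random

def kl(p, q):
    """Exact KL(p || q) for two discrete distributions on the same support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

def sampled_kl(p, q, n, rng):
    """Monte Carlo estimate: average log p - log q over samples drawn from p."""
    idx = rng.choices(range(len(p)), weights=p, k=n)
    return sum(math.log(p[i]) - math.log(q[i]) for i in idx) / n

policy = [0.7, 0.2, 0.1]
reference = [0.4, 0.4, 0.2]

forward = kl(policy, reference)    # KL(policy || reference)
reverse = kl(reference, policy)    # KL(reference || policy), a different number
zero = kl(policy, policy)          # identical distributions give 0
estimate = sampled_kl(policy, reference, n=20000, rng=random.Random(0))
```

Individual log-ratio samples can be negative even though the divergence itself is not, which is why per-sample KL terms in a training log may dip below zero.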
Why a KL penalty:
- Trust region: keeps the policy close to the reference.
- Prevents collapse: maintains diversity.
- Stability: prevents large changes.
- Capability preservation: keeps SFT knowledge.
Tuning $\beta$. The KL coefficient $\beta$ is usually tuned against a target KL measured in nats per token. If the measured KL is too high, increase $\beta$; if it is too low, decrease $\beta$.
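The increase/decrease rule can be automated with a simple proportional controller. The sketch below is one possible scheme, not a prescribed one: the function name, the `gain`, and the `max_error` clamp are all illustrative assumptions.

```python
def update_beta(beta, measured_kl, target_kl, gain=0.1, max_error=0.2):
    """Nudge beta up when KL exceeds the target and down when it falls short."""
    error = measured_kl / target_kl - 1.0            # relative deviation
    error = max(-max_error, min(max_error, error))   # clamp for stability
    return beta * (1.0 + gain * error)

raised = update_beta(0.1, measured_kl=0.2, target_kl=0.1)     # KL too high
lowered = update_beta(0.1, measured_kl=0.05, target_kl=0.1)   # KL too low
unchanged = update_beta(0.1, measured_kl=0.1, target_kl=0.1)  # on target
```

Clamping the error keeps a single noisy KL measurement from swinging the coefficient by more than a few percent per update.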
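Reward Model ($r_\psi$): detailed mathematics
Reward function. $r_\psi : \mathcal{X} \times \mathcal{Y} \to \mathbb{R}$ maps (prompt, response) to a scalar reward.
Training objective (Bradley–Terry).
$$L^{\text{RM}}(\psi) = -\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}}\left[\log \sigma\left(r_\psi(x, y_w) - r_\psi(x, y_l)\right)\right]$$
where $y_w$ is the chosen (winning) response, $y_l$ the rejected (losing) response, and $\sigma$ the sigmoid function.
Interpretation.
$$P(y_w \succ y_l \mid x) = \sigma\left(r_\psi(x, y_w) - r_\psi(x, y_l)\right)$$
the probability that the chosen response is preferred over the rejected one.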
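A short numeric sketch of the Bradley–Terry loss, operating directly on pairs of scalar scores. The scores are made-up values standing in for reward-model outputs, and the helper names are assumptions for this example.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def preference_probability(r_chosen, r_rejected):
    """Modeled probability that the chosen response beats the rejected one."""
    return sigmoid(r_chosen - r_rejected)

def bradley_terry_loss(pairs):
    """Mean of -log sigma(r_w - r_l) over (r_chosen, r_rejected) score pairs."""
    total = sum(math.log(preference_probability(rw, rl)) for rw, rl in pairs)
    return -total / len(pairs)

tied = bradley_terry_loss([(0.0, 0.0)])   # no preference expressed: log 2
good = bradley_terry_loss([(3.0, 0.0)])   # chosen scored higher: small loss
bad = bradley_terry_loss([(0.0, 3.0)])    # chosen scored lower: large loss
```

Only the difference between the two scores enters the loss, so adding a constant to every reward leaves it unchanged; the reward scale is pinned down by the preference data, not by an absolute zero.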
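During RL. For a generated response $y$:
$$r = r_\psi(x, y)$$
used to compute returns $R_t$ and advantages $A_t$.
Reward shaping (optional).
$$r_{\text{shaped}}(x, y) = r_\psi(x, y) - P_{\text{KL}}(x, y) - P_{\text{len}}(y)$$
where $P_{\text{KL}} = \beta \left[\log \pi_\theta(y \mid x) - \log \pi_{\text{ref}}(y \mid x)\right]$ is a KL penalty (can live in the reward or in the loss) and $P_{\text{len}}$ is a length penalty.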
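A minimal sketch of reward shaping with both penalties. The particular length penalty (linear in the number of tokens beyond a free allowance) and all parameter names and defaults are assumptions chosen for illustration; the text does not prescribe a specific form.

```python
def shaped_reward(rm_score, logp_policy, logp_ref, num_tokens,
                  beta=0.1, length_coef=0.01, free_tokens=100):
    """Reward-model score minus a KL penalty and a length penalty."""
    kl_penalty = beta * (logp_policy - logp_ref)           # per-sample KL term
    length_penalty = length_coef * max(0, num_tokens - free_tokens)
    return rm_score - kl_penalty - length_penalty

# A 120-token response that the policy finds more likely than the reference does.
r = shaped_reward(rm_score=1.0, logp_policy=-10.0, logp_ref=-12.0, num_tokens=120)
```

If the KL penalty is folded into the reward like this, it should not also be added to the loss with the same coefficient, or the policy is penalized twice for the same drift.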
Part 4: Architecture Details
Policy Model Architecture
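Option 1 (separate policy network):

    import torch.nn as nn

    class PolicyModel(nn.Module):
        def __init__(self, base: nn.Module, d_model: int, vocab_size: int):
            super().__init__()
            self.base = base  # transformer backbone returning hidden states
            self.head = nn.Linear(d_model, vocab_size)  # language-model head

        def forward(self, x):
            hidden = self.base(x)       # (batch, seq, d_model)
            logits = self.head(hidden)  # (batch, seq, vocab_size)
            return logits
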
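Option 2 (actor–critic, shared base):

    import torch.nn as nn

    class ActorCritic(nn.Module):
        def __init__(self, base: nn.Module, d_model: int, vocab_size: int):
            super().__init__()
            self.base = base  # shared transformer backbone
            self.policy_head = nn.Linear(d_model, vocab_size)
            self.value_head = nn.Linear(d_model, 1)

        def forward(self, x):
            hidden = self.base(x)                          # (batch, seq, d_model)
            logits = self.policy_head(hidden)              # (batch, seq, vocab_size)
            values = self.value_head(hidden).squeeze(-1)   # (batch, seq)
            return logits, values
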
Critic Model Architecture
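Option 1 (separate critic):

    import torch.nn as nn

    class CriticModel(nn.Module):
        def __init__(self, base: nn.Module, d_model: int):
            super().__init__()
            self.base = base  # the critic's own transformer backbone
            self.head = nn.Linear(d_model, 1)  # scalar value head

        def forward(self, x):
            hidden = self.base(x)                   # (batch, seq, d_model)
            value = self.head(hidden).squeeze(-1)   # (batch, seq)
            return value

Option 2 (shared with policy, actor–critic): the same `ActorCritic` module as above, in which the value head shares the base with the policy head.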
Reference Model Architecture
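Same architecture as the policy model: a copy of the policy taken before RL training. It is frozen (no gradients) and used only for log-probability computation.

    import copy

    # Initialize the reference model as a snapshot of the (SFT) policy;
    # policy_model is assumed to be an existing nn.Module.
    reference_model = copy.deepcopy(policy_model)
    reference_model.eval()  # inference mode (e.g. disables dropout); does not freeze by itself
    for param in reference_model.parameters():
        param.requires_grad = False  # freeze: exclude from gradient updates
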
Reward Model Architecture
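
    import torch
    import torch.nn as nn

    class RewardModel(nn.Module):
        def __init__(self, base_model: nn.Module, d_model: int):
            super().__init__()
            self.base = base_model  # can be initialized from the policy base
            self.head = nn.Linear(d_model, 1)  # scalar reward head

        def forward(self, x, y):
            # Concatenate prompt and response token ids along the sequence axis
            input_ids = torch.cat([x, y], dim=1)
            hidden = self.base(input_ids)  # (batch, seq, d_model)
            # Use the last token's hidden state (assumes no right padding);
            # mean pooling over the sequence is an alternative.
            reward = self.head(hidden[:, -1]).squeeze(-1)  # (batch,)
            return reward
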
Part 5: Training Phases
Phase 1: Supervised Fine-Tuning (SFT)
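Models used. Policy model $\pi_\theta$ (being trained).
Objective (standard language modeling loss).
$$L^{\text{SFT}}(\theta) = -\mathbb{E}_{(x,\, y) \sim \mathcal{D}}\left[\sum_{t=1}^{T} \log \pi_\theta(y_t \mid x, y_{<t})\right]$$
Result. A policy model that can follow instructions; this becomes the reference model $\pi_{\text{ref}}$.
Phase 2: Reward Model Training
Models used. Reward model $r_\psi$ (being trained).
Data. Preference pairs $(x, y_w, y_l)$, with $y_w$ the chosen and $y_l$ the rejected response.
Objective.
$$L^{\text{RM}}(\psi) = -\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}}\left[\log \sigma\left(r_\psi(x, y_w) - r_\psi(x, y_l)\right)\right]$$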
Result. A reward model that scores responses, trained to prefer chosen over rejected.
Phase 3: RL Optimization (PPO)
Models used.
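- Policy model $\pi_\theta$ (being trained).
- Critic model $V_\phi$ (being trained).
- Reference model $\pi_{\text{ref}}$ (frozen).
- Reward model $r_\psi$ (typically frozen).
Objective.
$$L^{\text{PPO}}(\theta) = -L^{\text{CLIP}}(\theta) + \beta\, D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}})$$
with
$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\, A_t,\ \operatorname{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right) A_t\right)\right]$$
Training loop.
- Generate responses with $\pi_\theta$.
- Score with $r_\psi$.
- Get logprobs from $\pi_\theta$ and $\pi_{\text{ref}}$.
- Compute values with $V_\phi$.
- Compute advantages.
- Update $\pi_\theta$ and $V_\phi$.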
Result. An aligned policy model, better at generating preferred responses.
Part 6: Summary Table
| Model | Role | Trained? | Used for | Mathematical form |
|---|---|---|---|---|
| Policy | Generate responses | Yes | Generation, loss | $\pi_\theta(y \mid x)$ |
| Critic | Estimate state value | Yes | Advantages | $V_\phi(s)$ |
| Reference | Regularization | No (frozen) | KL penalty | $\pi_{\text{ref}}(y \mid x)$ |
| Reward | Score responses | Before RL | Rewards | $r_\psi(x, y)$ |
Key relationships.
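- Advantage: $A_t = R_t - V_\phi(s_t)$.
- Ratio: $r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_{\text{old}}}(a_t \mid s_t)$.
- KL: $D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}}) = \mathbb{E}_{y \sim \pi_\theta}\left[\log \pi_\theta(y \mid x) - \log \pi_{\text{ref}}(y \mid x)\right]$.
- Reward: $r = r_\psi(x, y)$.
Training.
- SFT: train $\pi_\theta$.
- Reward: train $r_\psi$.
- RL: train $\pi_\theta$ and $V_\phi$ ($\pi_{\text{ref}}$ and $r_\psi$ frozen).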
Conclusion
Understanding these four models is crucial for PPO/RLHF:
- Policy model: what we're optimizing; generates responses.
- Critic model: estimates values; computes advantages.
- Reference model: provides stability; prevents collapse.
- Reward model: scores responses; provides the learning signal.
Each has a specific mathematical role and is used at different stages of training. Together, they enable stable and effective RLHF training.