PPO Models: Detailed Explanation of All Components

Overview

In PPO (Proximal Policy Optimization) used for RLHF, there are four key models/components that work together. This document explains each one in detail: what they are, their mathematical role, how they're used, and where they appear in the training pipeline.


Part 1: The Four Models in PPO/RLHF

Model 1: Policy Model ($\pi_\theta$)

What it is:

  • The main model being trained
  • Generates responses/actions
  • Outputs a probability distribution over actions (next tokens, for a language model)
  • This is what we're optimizing

Mathematical role. $\pi_\theta(a \mid s)$ is the probability of action $a$ given state $s$.

In language models. $\pi_\theta(y \mid x)$ is the probability of generating response $y$ given prompt $x$.

Outputs:

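  • Log probabilities: $\log \pi_\theta(y \mid x)$ (per token: $\log \pi_\theta(y_t \mid x, y_{<t})$)
  • Action probabilities: $\pi_\theta(a \mid s)$
  • Can also include a value estimate $V(s)$ (if using an actor-critic architecture)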

Where it's used:

  1. Generation: generate responses during training.
  2. Loss computation: compute policy gradient.
  3. Importance sampling: compute the ratio $\rho_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_{\text{old}}}(a_t \mid s_t)$.

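Mathematical formulation in PPO:

$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(\rho_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\big(\rho_t(\theta), 1-\epsilon, 1+\epsilon\big)\,\hat{A}_t\right)\right]$$

where $\rho_t(\theta) = \dfrac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$, $\hat{A}_t$ is the advantage estimate, and $\epsilon$ is the clipping range.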

Key point. This is the model we're training. It learns to maximize reward, constrained by a KL penalty to stay close to the reference.


Model 2: Critic Model (Value Function $V_\phi$)

What it is:

  • Estimates the value of a state
  • Predicts expected future return
  • Used to compute advantages
  • Can be a separate model or share parameters with the policy

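Mathematical role. $V_\phi(s) \approx \mathbb{E}_\pi[G_t \mid s_t = s]$, the expected return from state $s$ under the current policy.

In language models. $V_\phi(x)$ is the expected reward for prompt $x$. In token-level PPO, the critic predicts $V_\phi(s_t)$ at every response position, with $s_t = (x, y_{<t})$.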

Outputs:

  • Scalar value estimate $V_\phi(s)$.
  • Used to compute advantages $\hat{A}_t$.

Where it's used:

  1. Advantage computation: $\hat{A}_t = G_t - V_\phi(s_t)$.
  2. Value loss: $L^V(\phi) = \mathbb{E}_t\big[(V_\phi(s_t) - G_t)^2\big]$.
  3. Baseline: reduces variance in the policy gradient.

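Mathematical formulation.

$$\hat{A}_t = G_t - V_\phi(s_t)$$

with

$$G_t = \sum_{k=0}^{T-t} \gamma^k\, r_{t+k}.$$

Value loss.

$$L^V(\phi) = \mathbb{E}_t\left[\big(V_\phi(s_t) - G_t\big)^2\right]$$

where $G_t$ is the actual return (discounted sum of rewards).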

Key point. The critic estimates how good a state is. It is used to compute advantages (how much better an action turned out than the critic expected) and is trained with an MSE loss against actual returns.
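Here is a minimal sketch of the critic-side arithmetic above in plain Python. It assumes one episode with per-step rewards `rewards` and critic predictions `values` (both hypothetical inputs). It uses the simple estimator $\hat{A}_t = G_t - V_\phi(s_t)$ from this section rather than any particular library's advantage estimator.

```python
def discounted_returns(rewards, gamma=1.0):
    # G_t = r_t + gamma * G_{t+1}, accumulated backwards from the last step
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    return returns[::-1]


def advantages_and_value_loss(rewards, values, gamma=1.0):
    # A_t = G_t - V(s_t); value loss = mean over t of (V(s_t) - G_t)^2
    returns = discounted_returns(rewards, gamma)
    advs = [g - v for g, v in zip(returns, values)]
    loss = sum((v - g) ** 2 for v, g in zip(values, returns)) / len(values)
    return returns, advs, loss
```

Consider a response whose only reward is the reward-model score on its last token, for example `rewards = [0.0, 0.0, 1.0]` with `gamma = 0.9`. The returns are 0.81, 0.9 and 1.0, so every earlier token receives a discounted share of the final score.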

Architecture options:

  1. Separate critic: independent model $V_\phi$.
  2. Shared base: policy and critic share base layers, separate heads.
  3. Actor-critic: single model with policy and value heads.

Model 3: Reference Model ($\pi_{\text{ref}}$)

What it is:

  • Frozen copy of the policy before RL training
  • Used to compute the KL penalty
  • Prevents the policy from deviating too much
  • Typically the SFT (supervised fine-tuned) model

Mathematical role. $\pi_{\text{ref}}(a \mid s)$ is the (frozen) reference policy. For language models, $\pi_{\text{ref}}(y \mid x)$ is the reference model's probability of response $y$.

Outputs:

  • Log probabilities $\log \pi_{\text{ref}}(y \mid x)$.
  • Used to compute KL divergence.

Where it's used:

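  1. KL penalty computation: $D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}})$.
  2. Log-ratio for the KL estimate: $\log \pi_\theta(y \mid x) - \log \pi_{\text{ref}}(y \mid x)$. (The PPO importance-sampling ratio uses $\pi_{\theta_{\text{old}}}$, the policy that generated the batch, not $\pi_{\text{ref}}$.)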
  3. Regularization: prevents policy collapse.

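Mathematical formulation.

$$D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}}) = \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\left[\log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}\right]$$

where $\log \pi(y \mid x) = \sum_{t=1}^{T} \log \pi(y_t \mid x, y_{<t})$.

In the PPO loss.

$$L^{\text{PPO}}(\theta) = L^{\text{CLIP}}(\theta) - \beta\, D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}})$$

with $\beta > 0$ the KL coefficient. This objective is maximized; equivalently, its negative is minimized as the training loss.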

Key point. The reference model is frozen (not trained). It stabilizes training, helps prevent mode collapse, and helps the policy keep its SFT capabilities.

Why important:

  1. Prevents mode collapse: keeps the policy diverse.
  2. Prevents reward hacking: constrains the policy.
  3. Maintains capabilities: preserves SFT knowledge.
  4. Stability: prevents large policy changes.

Model 4: Reward Model ($r_\psi$)

What it is:

  • Predicts a reward for a response
  • Trained on human preferences
  • Scores how good a response is
  • Used to compute rewards during RL training

Mathematical role. $r_\psi(x, y) \in \mathbb{R}$ is the scalar reward for response $y$ to prompt $x$. A higher value means a better response.

Where it's used:

  1. Reward computation: score generated responses.
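  2. Return computation: $G_t = \sum_{k=0}^{T-t} \gamma^k r_{t+k}$, where the score $r_\psi(x, y)$ is typically assigned to the final token.
  3. Advantage computation: $\hat{A}_t = G_t - V_\phi(s_t)$.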

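Mathematical formulation. $r_\psi : (x, y) \mapsto \mathbb{R}$.

Training (before RL). Bradley–Terry preference loss:

$$L_{\text{RM}}(\psi) = -\mathbb{E}_{(x, y_w, y_l)}\left[\log \sigma\big(r_\psi(x, y_w) - r_\psi(x, y_l)\big)\right]$$

where $y_w$ is the chosen (winning) response, $y_l$ is the rejected (losing) response, and $\sigma$ is the sigmoid function.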

Key point. Trained separately before RL; captures human preferences; used to score responses during RL training; can be frozen or updated during RL.

Why important:

  1. Human preferences: encodes what humans want.
  2. Reward signal: provides the learning signal for the policy.
  3. Quality assessment: measures response quality.

Part 2: How They Work Together in PPO Training

Complete PPO Training Loop

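Step 1: Generate responses. Using policy model $\pi_\theta$:

$$y \sim \pi_\theta(\cdot \mid x)$$

Step 2: Score with reward model. Using reward model $r_\psi$:

$$R = r_\psi(x, y)$$

Step 3: Get log probabilities. For every response token, get them from the policy (stored as $\log \pi_{\theta_{\text{old}}}$ for the ratio) and from the reference:

$$\log \pi_\theta(y_t \mid x, y_{<t}), \qquad \log \pi_{\text{ref}}(y_t \mid x, y_{<t})$$

Step 4: Compute returns. The reward is placed on the final token ($r_T = R$, and $r_t = 0$ for $t < T$):

$$G_t = \sum_{k=0}^{T-t} \gamma^k\, r_{t+k}$$

Step 5: Compute values. Using critic model $V_\phi$:

$$V_\phi(s_t), \qquad s_t = (x, y_{<t})$$

Step 6: Compute advantages.

$$\hat{A}_t = G_t - V_\phi(s_t)$$

Step 7: Compute PPO loss.

$$L^{\text{PPO}}(\theta) = L^{\text{CLIP}}(\theta) - \beta\, D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}}), \qquad L^V(\phi) = \mathbb{E}_t\big[(V_\phi(s_t) - G_t)^2\big]$$

Step 8: Update models.

  • Update policy $\pi_\theta$: maximize $L^{\text{PPO}}(\theta)$.
  • Update critic $V_\phi$: minimize $L^V(\phi)$.
  • Reference $\pi_{\text{ref}}$: frozen (no update).
  • Reward $r_\psi$: typically frozen (can be updated).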
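The training steps can be sketched end to end on a deliberately tiny problem. Everything here is a hypothetical stand-in:

- one prompt, with one-token responses drawn from a 3-item vocabulary;
- a fixed `reward_table` in place of $r_\psi$;
- a single scalar in place of $V_\phi$;
- plain softmax logits in place of a transformer.

Because the vocabulary is tiny, the sketch computes the KL term exactly instead of sampling it, and it writes the softmax gradients out by hand. It illustrates the PPO update logic only; it is not a production RLHF trainer.

```python
import math
import random


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def kl(p, q):
    # Exact KL(p || q) in nats for two categorical distributions
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def toy_ppo(reward_table, iters=200, batch=16, epochs=4, lr=0.5, vf_lr=0.1,
            eps=0.2, beta=0.05, seed=0):
    rng = random.Random(seed)
    n = len(reward_table)
    logits = [0.0] * n            # policy pi_theta over n one-token "responses"
    ref = softmax(logits)         # frozen reference pi_ref (copy of the initial policy)
    value = 0.0                   # critic V_phi: a single scalar, since there is one prompt
    for _ in range(iters):
        old = softmax(logits)     # pi_theta_old: fixed while this batch is reused
        # Steps 1-2: sample responses and score them with the stand-in reward model
        acts = rng.choices(range(n), weights=old, k=batch)
        returns = [reward_table[a] for a in acts]   # one-step episodes, so G = R
        # Steps 5-6: advantages against the critic's baseline
        advs = [g - value for g in returns]
        # Steps 7-8: a few epochs of gradient ascent on L^CLIP - beta * KL
        for _ in range(epochs):
            pi = softmax(logits)
            grad = [0.0] * n
            for a, adv in zip(acts, advs):
                ratio = pi[a] / old[a]
                # Here min(...) selects the constant clipped term, so this sample adds no gradient
                if (adv > 0 and ratio > 1 + eps) or (adv < 0 and ratio < 1 - eps):
                    continue
                for j in range(n):  # d(ratio * adv)/d logit_j = adv * ratio * (1[j == a] - pi_j)
                    onehot = 1.0 if j == a else 0.0
                    grad[j] += adv * ratio * (onehot - pi[j]) / batch
            d = kl(pi, ref)
            for j in range(n):      # d KL(pi || ref)/d logit_j = pi_j * (log(pi_j / ref_j) - KL)
                grad[j] -= beta * pi[j] * (math.log(pi[j] / ref[j]) - d)
            logits = [z + lr * g for z, g in zip(logits, grad)]
        # Critic update: one gradient step on the batch mean of (V - G)^2
        value -= vf_lr * sum(2 * (value - g) for g in returns) / batch
    return softmax(logits), ref, value


if __name__ == "__main__":
    policy, ref, v = toy_ppo([0.0, 1.0, 0.2])  # hypothetical rewards for 3 candidate responses
    print([round(p, 3) for p in policy], round(v, 3))
```

Reusing each sampled batch for several `epochs` is what moves the ratio away from 1 and makes clipping matter. With `epochs=1` the ratio is exactly 1 on the only pass.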

Part 3: Mathematical Details for Each Model

Policy Model ($\pi_\theta$): detailed mathematics

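Forward pass. The input is prompt $x$ and the output is response $y = (y_1, \dots, y_T)$ with probability $\pi_\theta(y \mid x)$. For each token:

$$\pi_\theta(y_t \mid x, y_{<t}) = \operatorname{softmax}(z_t)_{y_t}$$

where $z_t$ are the logits at position $t$.

Log probability.

$$\log \pi_\theta(y \mid x) = \sum_{t=1}^{T} \log \pi_\theta(y_t \mid x, y_{<t})$$

Policy gradient.

$$\nabla_\theta J(\theta) = \mathbb{E}_t\left[\nabla_\theta \log \pi_\theta(a_t \mid s_t)\, \hat{A}_t\right]$$

where $s_t = (x, y_{<t})$, $a_t = y_t$, and $\hat{A}_t$ is the advantage.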
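Here is a minimal sketch of the forward-pass arithmetic above. It assumes `step_logits`, a hypothetical list of the logits the policy produced at each response position, already conditioned on $x$ and $y_{<t}$. The sequence log probability is the sum of the log-softmax of the chosen tokens. Real implementations do this on tensors, but the math is the same.

```python
import math


def log_softmax(logits):
    # Numerically stable log-softmax: z - logsumexp(z)
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def sequence_logprob(step_logits, response_ids):
    # log pi(y | x) = sum over t of log pi(y_t | x, y_<t);
    # step_logits[t] are the policy's logits at response position t
    return sum(log_softmax(l)[tok] for l, tok in zip(step_logits, response_ids))
```

For example, logits `[0, 0]` then `[ln 3, 0]` with response tokens `[0, 0]` give $\log 0.5 + \log 0.75$.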

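PPO clipping.

$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(\rho_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\big(\rho_t(\theta), 1-\epsilon, 1+\epsilon\big)\,\hat{A}_t\right)\right], \qquad \rho_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$$

Once $\rho_t$ leaves $[1-\epsilon, 1+\epsilon]$ in the direction that would increase the objective, the clipped term stops contributing gradient. This limits the size of each policy update, which guards against over-optimization and training instability.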
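Here is a minimal sketch of the clipped surrogate. It assumes per-token log probabilities from the current policy (`logp_new`) and from the batch-generating policy (`logp_old`), plus advantages; all three inputs are hypothetical. The code only evaluates the objective, which is maximized. An autodiff framework would supply the gradient.

```python
import math


def ppo_clip_objective(logp_new, logp_old, advantages, eps=0.2):
    # L^CLIP = mean over t of min(rho_t * A_t, clip(rho_t, 1 - eps, 1 + eps) * A_t),
    # with rho_t = exp(log pi_theta - log pi_theta_old)
    total = 0.0
    for ln, lo, a in zip(logp_new, logp_old, advantages):
        rho = math.exp(ln - lo)
        clipped = min(max(rho, 1.0 - eps), 1.0 + eps)
        total += min(rho * a, clipped * a)
    return total / len(advantages)
```

With $\epsilon = 0.2$, a token with ratio 1.5 and advantage +1 contributes 1.2, not 1.5, so the gain from pushing that token further is capped. A token with ratio 0.5 and advantage -1 contributes -0.8, the more pessimistic of the two terms.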


Critic Model ($V_\phi$): detailed mathematics

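Value function.

$$V^\pi(s) = \mathbb{E}_\pi\left[\sum_{k=0}^{\infty} \gamma^k\, r_{t+k} \;\middle|\; s_t = s\right]$$

where $\gamma \in [0, 1]$ is the discount factor, $r_{t+k}$ is the reward at time $t+k$, and $\pi$ is the current policy.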

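Bellman equation.

$$V^\pi(s_t) = \mathbb{E}_\pi\left[r_t + \gamma\, V^\pi(s_{t+1})\right]$$

Value loss.

$$L^V(\phi) = \mathbb{E}_t\left[\big(V_\phi(s_t) - G_t\big)^2\right]$$

Gradient.

$$\nabla_\phi L^V(\phi) = 2\,\mathbb{E}_t\left[\big(V_\phi(s_t) - G_t\big)\,\nabla_\phi V_\phi(s_t)\right]$$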

Why a value function:

  1. Baseline: reduces variance in the policy gradient.
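  2. Advantages: $\hat{A}_t = G_t - V_\phi(s_t)$ measures how much better than expected an action turned out.
  3. Stability: advantages are centered around zero, which keeps update magnitudes steadier than raw returns would.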

Reference Model ($\pi_{\text{ref}}$): detailed mathematics

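KL divergence.

$$D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}}) = \sum_{y} \pi_\theta(y \mid x) \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$$

In practice. The sum over all possible responses is intractable, so it is estimated from the sampled response $y \sim \pi_\theta(\cdot \mid x)$:

$$\hat{D}_{\text{KL}} = \sum_{t=1}^{T} \left[\log \pi_\theta(y_t \mid x, y_{<t}) - \log \pi_{\text{ref}}(y_t \mid x, y_{<t})\right]$$

Dividing by $T$ gives a per-token value.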

Properties.

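  • $D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}}) \ge 0$ (always non-negative).
  • $D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}}) = 0$ iff $\pi_\theta = \pi_{\text{ref}}$.
  • Asymmetric: in general $D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}}) \neq D_{\text{KL}}(\pi_{\text{ref}} \,\|\, \pi_\theta)$.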
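The KL definition and its properties can be checked on small categorical distributions. This is a minimal sketch: `p` and `q` are hypothetical stand-ins for $\pi_\theta$ and $\pi_{\text{ref}}$ over three outcomes, not real model outputs. The sampled estimator mirrors the "in practice" estimate. It averages $\log p - \log q$ over draws from `p`. That average is unbiased, even though individual log-ratio samples can be negative.

```python
import math
import random


def kl(p, q):
    # Exact D_KL(p || q) in nats; terms with p_i = 0 contribute nothing
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kl_sample_estimate(p, q, n=100_000, seed=0):
    # Monte Carlo estimate: mean of log p(y) - log q(y) over samples y ~ p
    rng = random.Random(seed)
    draws = rng.choices(range(len(p)), weights=p, k=n)
    return sum(math.log(p[i]) - math.log(q[i]) for i in draws) / n


p = [0.7, 0.2, 0.1]   # hypothetical stand-in for pi_theta
q = [0.4, 0.4, 0.2]   # hypothetical stand-in for pi_ref
```

Here `kl(p, q)` is about 0.184 nats and `kl(q, p)` is about 0.192, which illustrates the asymmetry.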

Why a KL penalty:

  1. Trust region: keeps the policy close to the reference.
  2. Prevents collapse: maintains diversity.
  3. Stability: prevents large changes.
  4. Capability preservation: keeps SFT knowledge.

Typical values. $\beta$ is a small coefficient, tuned together with a target KL measured in nats per token. If the observed KL is too high, increase $\beta$; if it is too low, decrease $\beta$.
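The "increase or decrease $\beta$" rule can be written as a tiny controller. This sketch is modeled on the adaptive-KL variant described in the original PPO paper: double or halve $\beta$ when the measured KL is off from its target by more than a factor of 1.5. Some RLHF setups use a smoother proportional controller instead. The function name and defaults here are illustrative assumptions, not a specific library's API.

```python
def update_beta(beta, observed_kl, target_kl, factor=2.0, tolerance=1.5):
    # Hypothetical adaptive KL coefficient update:
    # KL well above target -> penalize more; well below target -> penalize less
    if observed_kl > target_kl * tolerance:
        return beta * factor
    if observed_kl < target_kl / tolerance:
        return beta / factor
    return beta
```

You would call it once per batch with the measured mean KL, so $\beta$ drifts until the observed KL stays near the target.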


Reward Model ($r_\psi$): detailed mathematics

Reward function. $r_\psi(x, y)$ maps (prompt, response) to a scalar reward.

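Training objective (Bradley–Terry).

$$L_{\text{RM}}(\psi) = -\mathbb{E}_{(x, y_w, y_l)}\left[\log \sigma\big(r_\psi(x, y_w) - r_\psi(x, y_l)\big)\right]$$

where $y_w$ is the chosen (winning) response, $y_l$ the rejected (losing) response, and $\sigma$ the sigmoid function.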

Interpretation.

$$P(y_w \succ y_l \mid x) = \sigma\big(r_\psi(x, y_w) - r_\psi(x, y_l)\big)$$

is the probability that the chosen response is preferred over the rejected one.
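Here is a minimal sketch of the Bradley–Terry loss and its preference-probability reading. It takes scalar scores for chosen and rejected responses as plain floats (hypothetical inputs standing in for $r_\psi$ outputs). It uses a numerically stable $\log\sigma$ so that large score gaps do not overflow.

```python
import math


def log_sigmoid(z):
    # Numerically stable log(sigma(z)) for either sign of z
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def preference_prob(r_chosen, r_rejected):
    # P(y_w preferred over y_l | x) = sigma(r(x, y_w) - r(x, y_l))
    return math.exp(log_sigmoid(r_chosen - r_rejected))


def bradley_terry_loss(chosen_scores, rejected_scores):
    # L_RM = -mean over pairs of log sigma(r(x, y_w) - r(x, y_l))
    pairs = list(zip(chosen_scores, rejected_scores))
    return -sum(log_sigmoid(w - l) for w, l in pairs) / len(pairs)
```

For example, scores of 2.0 (chosen) and 0.0 (rejected) give a preference probability of about 0.88. Equal scores give 0.5 and a loss of $\log 2$.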

During RL. For a generated response $y$:

$$R = r_\psi(x, y)$$

which is used to compute returns $G_t$ and advantages $\hat{A}_t$.

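Reward shaping (optional).

$$R'(x, y) = r_\psi(x, y) - \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)} - \ell(y)$$

where $\beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$ is a KL penalty and $\ell(y)$ is a length penalty. The KL penalty can live in the reward, as here, or in the loss, but not both.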
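A common way to apply this shaping in token-level PPO is to spread it over per-token rewards whose sum equals $R'(x, y)$. This is a minimal sketch under stated assumptions:

- the KL term is charged at every token;
- the reward-model score goes on the final token;
- the length penalty $\ell(y)$ is assumed to be linear, `length_coef * len(y)`, a hypothetical choice.

Inputs are per-token log probabilities from the policy and the reference.

```python
def shaped_token_rewards(rm_score, logp_policy, logp_ref, beta=0.05, length_coef=0.0):
    # Per-token rewards whose sum over the response equals
    #   r_psi(x, y) - beta * log(pi_theta(y|x) / pi_ref(y|x)) - length_coef * |y|
    # KL penalty at every token: -beta * (log pi_theta(y_t|...) - log pi_ref(y_t|...))
    rewards = [-beta * (lp - lq) for lp, lq in zip(logp_policy, logp_ref)]
    # Reward-model score and the (assumed linear) length penalty go on the final token
    rewards[-1] += rm_score - length_coef * len(rewards)
    return rewards
```

Because the per-token log-ratios sum to the sequence log-ratio, this gives the same total as the formula above. It also gives the critic a denser signal than a single end-of-sequence reward.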


Part 4: Architecture Details

Policy Model Architecture

Option 1 — separate policy network:

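import torch.nn as nn

class PolicyModel(nn.Module):
    def __init__(self, base: nn.Module, d_model: int, vocab_size: int):
        super().__init__()
        self.base = base                             # transformer trunk: (batch, seq) -> (batch, seq, d_model)
        self.head = nn.Linear(d_model, vocab_size)   # LM head

    def forward(self, x):
        hidden = self.base(x)                        # (batch, seq, d_model)
        logits = self.head(hidden)                   # (batch, seq, vocab_size)
        return logits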

Option 2 — actor–critic (shared base):

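import torch.nn as nn

class ActorCritic(nn.Module):
    def __init__(self, base: nn.Module, d_model: int, vocab_size: int):
        super().__init__()
        self.base = base                                    # shared trunk
        self.policy_head = nn.Linear(d_model, vocab_size)   # actor: token logits
        self.value_head  = nn.Linear(d_model, 1)            # critic: scalar value per position

    def forward(self, x):
        hidden = self.base(x)                               # (batch, seq, d_model)
        logits = self.policy_head(hidden)                   # (batch, seq, vocab_size)
        values = self.value_head(hidden).squeeze(-1)        # (batch, seq)
        return logits, values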

Critic Model Architecture

Option 1 — separate critic:

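import torch.nn as nn

class CriticModel(nn.Module):
    def __init__(self, base: nn.Module, d_model: int):
        super().__init__()
        self.base = base                        # separate transformer trunk
        self.head = nn.Linear(d_model, 1)       # value head

    def forward(self, x):
        hidden = self.base(x)                   # (batch, seq, d_model)
        value = self.head(hidden).squeeze(-1)   # (batch, seq): one value per position
        return value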

Option 2 — shared with policy (actor–critic): same as above, but shares the base with the policy.

Reference Model Architecture

Same as the policy model — a copy of the policy before RL training. Frozen (no gradients), used only for log-probability computation.

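import copy

# Initialize reference model as a snapshot of the policy before RL (the SFT model)
reference_model = copy.deepcopy(policy_model)
reference_model.eval()                       # eval mode disables dropout; it does not freeze weights by itself
for param in reference_model.parameters():
    param.requires_grad = False              # freeze: no gradients, never updated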

Reward Model Architecture

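import torch
import torch.nn as nn

class RewardModel(nn.Module):
    def __init__(self, base_model: nn.Module, d_model: int):
        super().__init__()
        self.base = base_model                 # can be initialized from the policy (SFT) base
        self.head = nn.Linear(d_model, 1)      # scalar reward head

    def forward(self, x, y):
        # Concatenate prompt and response token ids along the sequence dimension
        input_ids = torch.cat([x, y], dim=1)
        hidden = self.base(input_ids)          # (batch, seq, d_model)
        # Use the last token (assumes no right padding) or mean pooling: hidden.mean(dim=1)
        reward = self.head(hidden[:, -1, :])   # (batch, 1)
        return reward.squeeze(-1)              # (batch,)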

Part 5: Training Phases

Phase 1: Supervised Fine-Tuning (SFT)

Models used. Policy model $\pi_\theta$ (being trained).

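Objective (standard language modeling loss).

$$L_{\text{SFT}}(\theta) = -\mathbb{E}_{(x, y) \sim \mathcal{D}}\left[\sum_{t=1}^{T} \log \pi_\theta(y_t \mid x, y_{<t})\right]$$

Result. A policy model that can follow instructions; this becomes the reference model $\pi_{\text{ref}}$.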


Phase 2: Reward Model Training

Models used. Reward model $r_\psi$ (being trained).

Data. Preference pairs $(x, y_w, y_l)$.

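Objective.

$$L_{\text{RM}}(\psi) = -\mathbb{E}_{(x, y_w, y_l)}\left[\log \sigma\big(r_\psi(x, y_w) - r_\psi(x, y_l)\big)\right]$$

Result. A reward model that scores responses, trained to score chosen responses above rejected ones.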


Phase 3: RL Optimization (PPO)

Models used.

  • Policy model $\pi_\theta$ (being trained).
  • Critic model $V_\phi$ (being trained).
  • Reference model $\pi_{\text{ref}}$ (frozen).
  • Reward model $r_\psi$ (typically frozen).

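Objective.

$$\max_\theta\; L^{\text{PPO}}(\theta) = L^{\text{CLIP}}(\theta) - \beta\, D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}})$$

with the critic trained jointly by minimizing $L^V(\phi) = \mathbb{E}_t\big[(V_\phi(s_t) - G_t)^2\big]$.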

Training loop.

  1. Generate responses with $\pi_\theta$.
  2. Score with $r_\psi$.
  3. Get log probabilities from $\pi_\theta$ and $\pi_{\text{ref}}$.
  4. Compute values with $V_\phi$.
  5. Compute advantages.
  6. Update $\pi_\theta$ and $V_\phi$.

Result. An aligned policy model, better at generating preferred responses.


Part 6: Summary Table

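| Model | Role | Trained? | Used for | Mathematical form |
|---|---|---|---|---|
| Policy $\pi_\theta$ | Generate responses | Yes | Generation, loss | $\pi_\theta(y \mid x)$ |
| Critic $V_\phi$ | Estimate state value | Yes | Advantages | $V_\phi(s)$ |
| Reference $\pi_{\text{ref}}$ | Regularization | No (frozen) | KL penalty | $\pi_{\text{ref}}(y \mid x)$ |
| Reward $r_\psi$ | Score responses | Before RL | Rewards | $r_\psi(x, y)$ |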

Key relationships.

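  • Advantage: $\hat{A}_t = G_t - V_\phi(s_t)$.
  • Ratio: $\rho_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_{\text{old}}}(a_t \mid s_t)$.
  • KL: $D_{\text{KL}}(\pi_\theta \,\|\, \pi_{\text{ref}}) = \mathbb{E}_{y \sim \pi_\theta}\big[\log \pi_\theta(y \mid x) - \log \pi_{\text{ref}}(y \mid x)\big]$.
  • Reward: $R = r_\psi(x, y)$.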

Training.

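  • SFT: train $\pi_\theta$ (the result becomes $\pi_{\text{ref}}$).
  • Reward: train $r_\psi$.
  • RL: train $\pi_\theta$ and $V_\phi$ ($\pi_{\text{ref}}$ and $r_\psi$ frozen).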

Conclusion

Understanding these four models is crucial for PPO/RLHF:

  1. Policy model: what we're optimizing; generates responses.
  2. Critic model: estimates values; computes advantages.
  3. Reference model: provides stability; prevents collapse.
  4. Reward model: scores responses; provides the learning signal.

Each has a specific mathematical role and is used at different stages of training. Together, they enable stable and effective RLHF training.