Section 06

The Code: Reward Model Training and RL Setup

Training Language Models to Follow Instructions with Human Feedback (2022)


This section demonstrates the core components of RLHF: reward model training and the RL loss setup. This code runs on Google Colab with PyTorch and Transformers.

Code 1: Reward Model Training (Bradley-Terry)

# Reward Model Training with Bradley-Terry Loss
# Runs on Google Colab

import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModel

# Load a small pretrained model as the reward model base
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModel.from_pretrained(model_name)

# Create reward head: [CLS] token → scalar reward
class RewardModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.model = base_model
        hidden_size = base_model.config.hidden_size  # 768 for DistilBERT
        self.reward_head = nn.Linear(hidden_size, 1)  # Output scalar
    
    def forward(self, input_ids, attention_mask):
        # Get [CLS] (first-token) representation
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        cls_hidden = outputs.last_hidden_state[:, 0, :]  # [batch, hidden]
        # Map to scalar reward
        reward = self.reward_head(cls_hidden).squeeze(-1)  # [batch]
        return reward

# Initialize model and optimizer
rm = RewardModel(base_model)
optimizer = optim.Adam(rm.parameters(), lr=5e-5)
criterion = nn.BCEWithLogitsLoss()  # Binary cross-entropy (sigmoid built-in)

# Mock comparison data: (prompt + response_A, prompt + response_B, label)
# In reality, these come from human annotations
training_data = [
    # (input_text_A, input_text_B, label_A_better): 1 if A is preferred, 0 if B is
    ("What is 2+2? Answer: 4", "What is 2+2? Answer: 5", 1),
    ("Explain gravity. Answer: Gravity is the attraction between masses.",
     "Explain gravity. Answer: It's complicated.", 1),
    ("What is Python? Answer: It's a snake.",
     "What is Python? Answer: A programming language.", 0),
]

# Training loop (simplified for demonstration)
rm.train()
for epoch in range(3):  # 3 epochs for this demo
    total_loss = 0
    for text_a, text_b, label_a_better in training_data:
        # Tokenize both responses
        tokens_a = tokenizer(text_a, return_tensors="pt", 
                            padding=True, truncation=True)
        tokens_b = tokenizer(text_b, return_tensors="pt", 
                            padding=True, truncation=True)
        
        # Forward pass: get rewards
        reward_a = rm(tokens_a["input_ids"], tokens_a["attention_mask"])
        reward_b = rm(tokens_b["input_ids"], tokens_b["attention_mask"])
        
        # Bradley-Terry loss: -log(sigmoid(reward_winner - reward_loser))
        # BCEWithLogitsLoss with a target of 1 computes exactly this quantity
        if label_a_better == 1:
            logits = reward_a - reward_b  # A better → positive
        else:
            logits = reward_b - reward_a  # B better → positive
        
        # Sigmoid of logits should be close to 1
        loss = criterion(logits.unsqueeze(-1), 
                        torch.ones(1, 1))  # Target: sigmoid = 1
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(training_data)
    print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")

# After training, evaluate on new responses
rm.eval()
with torch.no_grad():
    test_text = "What is photosynthesis? Answer: Photosynthesis is the process where plants convert light into chemical energy."
    test_tokens = tokenizer(test_text, return_tensors="pt", 
                           padding=True, truncation=True)
    reward_score = rm(test_tokens["input_ids"], 
                      test_tokens["attention_mask"])
    print(f"\nReward for test response: {reward_score.item():.4f}")

What this code does:

  1. Loads a pretrained language model (DistilBERT)
  2. Adds a scalar reward head on top (reward = Linear(hidden state of the [CLS] token))
  3. Trains on preference pairs using the Bradley-Terry loss
  4. Uses gradient descent to widen the reward margin between preferred and rejected responses
  5. Outputs a scalar reward for any (prompt, response) pair

Key insight: The reward model is just a pretrained language model with a small scalar head, trained as a pairwise classifier on comparisons.
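Code 1 processes one pair at a time. A minimal sketch of the same Bradley-Terry loss over a batch of pairs, assuming the reward model has already produced two score tensors, `r_chosen` and `r_rejected` (hypothetical names, filled with made-up values here). It checks that `-logsigmoid(r_chosen - r_rejected)` matches `BCEWithLogitsLoss` with every target set to 1, and computes pairwise accuracy, the fraction of pairs the reward model orders correctly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical reward-model scores for 4 comparison pairs (made-up values)
r_chosen = torch.tensor([1.2, 0.3, -0.5, 2.0])    # reward of the preferred response
r_rejected = torch.tensor([0.4, 0.9, -1.5, 1.0])  # reward of the rejected response

margin = r_chosen - r_rejected

# Bradley-Terry negative log-likelihood: -log sigmoid(r_w - r_l), averaged over pairs
bt_loss = -F.logsigmoid(margin).mean()

# The same quantity via BCEWithLogitsLoss with all targets equal to 1
bce_loss = nn.BCEWithLogitsLoss()(margin, torch.ones_like(margin))

# Pairwise accuracy: fraction of pairs where the preferred response scores higher
accuracy = (margin > 0).float().mean()

print(f"BT loss:  {bt_loss.item():.4f}")
print(f"BCE loss: {bce_loss.item():.4f}")
print(f"Pairwise accuracy: {accuracy.item():.2f}")
```

With these scores, three of the four margins are positive, so the pairwise accuracy is 0.75. Pairwise accuracy on held-out comparisons is a simple sanity check that the reward model has learned the preference ordering rather than just lowering its loss.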


Code 2: RL Loss with KL Penalty

# RL Objective with KL Divergence Penalty
# Demonstrating the PPO-style loss with KL constraint

import torch

# Mock scenario:
# - RL policy (π_RL) generates a response
# - Reward model rates it
# - KL penalty keeps policy close to SFT baseline (π_SFT)

# Suppose we have log probabilities from both models
batch_size = 2
seq_len = 10

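# Mock per-token log-probs of the sampled response tokens
# (in practice: log_softmax of the model's logits, gathered at the sampled token ids)
log_probs_rl = -3 * torch.rand(batch_size, seq_len)   # π_RL log-probs (all ≤ 0)
log_probs_sft = -3 * torch.rand(batch_size, seq_len)  # π_SFT log-probs (all ≤ 0)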

# Mock rewards from the reward model
rewards = torch.tensor([2.5, -0.3])  # One good response, one bad

# KL divergence penalty coefficient
beta = 0.02

# Estimate KL divergence (per example)
# KL[π_RL || π_SFT] = E_{y ~ π_RL}[log π_RL(y|x) - log π_SFT(y|x)]
# For one sampled response y, the sum of per-token log-ratios is a single-sample estimate
kl_divergence_per_token = log_probs_rl - log_probs_sft  # [batch, seq]
kl_per_example = kl_divergence_per_token.sum(dim=1)  # [batch]

# RL Loss: maximize reward, minimize KL divergence
# L = -reward + beta * KL
rl_loss = -rewards + beta * kl_per_example

print("Rewards from RM:", rewards.numpy())
print("KL divergence per example:", kl_per_example.detach().numpy())
print("RL Loss (raw):", rl_loss.detach().numpy())

# Average loss across batch
total_loss = rl_loss.mean()
print(f"\nTotal RL Loss: {total_loss.item():.4f}")

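# In practice, the reward comes from scoring sampled text, so this loss has no
# gradient path back to the policy through the reward. PPO instead treats
# (reward - beta * KL) as a fixed training signal that weights the policy's
# log-probabilities (see the PPO notes below).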

What this code does:

  1. Simulates log probabilities from RL policy and SFT baseline
  2. Computes KL divergence as the difference in log-probs
  3. Combines reward maximization and KL penalty
  4. Shows that loss = -reward + beta * KL (trade-off)

Key insight: The RL loss balances two objectives:

  • First term (-reward): maximize reward (RL wants high-reward responses)
  • Second term (beta * KL): stay close to SFT (the KL penalty keeps the policy from drifting too far)
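Code 2 fills the log-probabilities with random numbers. A minimal sketch of where they come from, assuming both models return logits of shape [batch, seq_len, vocab] that are already aligned so position t scores token t (with a causal LM you would shift the logits by one position first). The helper `token_log_probs`, the padding `mask`, and the mock `rm_scores` are hypothetical, and the token ids are random stand-ins rather than true samples from π_RL, so the KL estimate here can come out negative. Summing the masked per-token log-ratios gives the sequence-level log-ratio log π_RL(y|x) - log π_SFT(y|x), and subtracting beta times it from the reward-model score gives the KL-penalized reward.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, vocab_size = 2, 6, 50
beta = 0.02

# Stand-ins for policy and SFT outputs on the same response tokens
logits_rl = torch.randn(batch_size, seq_len, vocab_size)
logits_sft = torch.randn(batch_size, seq_len, vocab_size)
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))  # response token ids
mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 0, 0]], dtype=torch.float)  # 0 marks padding

def token_log_probs(logits, tokens):
    """Log-probability of each given token: log_softmax over vocab, then gather."""
    log_probs = F.log_softmax(logits, dim=-1)                      # [B, T, V]
    return log_probs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)  # [B, T]

lp_rl = token_log_probs(logits_rl, tokens)
lp_sft = token_log_probs(logits_sft, tokens)

# Per-sequence KL estimate: sum of log-ratios over real (unpadded) tokens
kl_per_seq = ((lp_rl - lp_sft) * mask).sum(dim=1)  # [B]

rm_scores = torch.tensor([2.5, -0.3])               # mock reward-model outputs
penalized_reward = rm_scores - beta * kl_per_seq     # signal the RL step maximizes

print("KL per sequence:", kl_per_seq)
print("Penalized reward:", penalized_reward)
```

The mask matters once responses in a batch have different lengths: without it, padding positions would add spurious log-ratio terms to the KL estimate.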

Code 3: Full RLHF Training Loop (Simplified)

# Simplified RLHF Training Loop
# (In production, this would use PPO with advantage estimation)

import torch
import torch.optim as optim

torch.manual_seed(0)  # reproducible samples

# Toy 1-D setup: each "policy" is a Gaussian N(weight, 1) over a scalar output x
class DummyModel(torch.nn.Module):
    def __init__(self, init=0.0):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.tensor(init))
    
    def log_prob(self, x):
        # Log-density of N(weight, 1), up to an additive constant
        return -0.5 * (x - self.weight) ** 2

sft_policy = DummyModel(init=0.0)        # SFT baseline centered at 0.0
sft_policy.weight.requires_grad_(False)  # Frozen reference model
rl_policy = DummyModel(init=0.0)         # RL policy starts from the SFT model

reward_model = lambda x: -2.0 * (x - 1.0) ** 2  # Highest reward at x = 1.0

optimizer = optim.Adam(rl_policy.parameters(), lr=0.05)
beta = 2.0  # Toy value, large enough to visibly constrain this 1-D problem

# RL training loop
for step in range(200):
    optimizer.zero_grad()
    
    # Sample from the RL policy as weight + noise. Gradients flow through the
    # samples into the reward here; with text, sampling is discrete, which is
    # why PPO works with log-prob ratios instead.
    x_sampled = rl_policy.weight + torch.randn(64)
    
    # Compute rewards and log probabilities
    rewards = reward_model(x_sampled)               # [64]
    log_probs_rl = rl_policy.log_prob(x_sampled)    # [64]
    log_probs_sft = sft_policy.log_prob(x_sampled)  # [64]
    
    # Monte Carlo estimate of KL[π_RL || π_SFT]
    kl = (log_probs_rl - log_probs_sft).mean()
    
    # RL loss: maximize reward, penalize drift from SFT
    loss = -rewards.mean() + beta * kl
    
    # Backprop
    loss.backward()
    optimizer.step()
    
    if step % 50 == 0:
        print(f"Step {step}: Loss={loss.item():.4f}, "
              f"Reward={rewards.mean().item():.4f}, "
              f"KL={kl.item():.4f}, "
              f"Policy param={rl_policy.weight.item():.4f}")

print("\nAfter RL training:")
print(f"RL policy learned weight: {rl_policy.weight.item():.4f}")
print(f"SFT baseline weight: {sft_policy.weight.item():.4f}")
print(f"Optimum for this toy, 4 / (4 + beta): {4 / (4 + beta):.4f}")
print("(RL policy moved toward the reward peak at 1.0, but the KL penalty held it back)")

What this code does:

  1. Simulates an RL policy trying to maximize a reward
  2. Shows how the KL penalty keeps the policy close to the SFT baseline
  3. Demonstrates the trade-off: RL can improve the reward, but only up to a point
  4. Lets you set beta=0 to see the unconstrained case

Key behavior (the expected objective is -2((w - 1)^2 + 1) - beta * w^2 / 2, which is maximized at w = 4 / (4 + beta)):

  • With beta=2.0: Policy settles near 0.67, between the baseline (0.0) and the reward peak (1.0)
  • With beta=0: Policy moves all the way to the reward peak at 1.0 (unconstrained)
  • With beta large: Policy would barely move (constrained too much)
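Code 3 gets its gradient by differentiating the reward function directly, which only works because the toy reward is a smooth function of a continuous sample. A language model's reward comes from scoring discrete sampled tokens, so the gradient has to flow through log-probabilities instead. A minimal sketch of the same kind of toy with a score-function (REINFORCE) gradient, a simpler relative of PPO swapped in here for illustration: the policy is assumed to be a Gaussian N(mu, 1), the SFT baseline N(0, 1), the reward -2(x - 1)^2, and the KL-penalized reward is treated as a fixed (no-gradient) weight on log π(x). Subtracting the batch-mean reward is a simple variance-reducing baseline. For this setup the expected objective is -2((mu - 1)^2 + 1) - beta * mu^2 / 2, maximized at mu = 4 / (4 + beta).

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

mu = torch.nn.Parameter(torch.tensor(0.0))            # RL policy mean, starts at the SFT mean
sft = Normal(torch.tensor(0.0), torch.tensor(1.0))    # frozen SFT baseline N(0, 1)
beta = 2.0
optimizer = torch.optim.Adam([mu], lr=0.05)

def reward_model(x):
    return -2.0 * (x - 1.0) ** 2  # peaks at x = 1.0

for step in range(300):
    policy = Normal(mu, torch.tensor(1.0))
    x = policy.sample((256,))       # sample() carries no gradient, like sampling text
    log_prob = policy.log_prob(x)   # differentiable w.r.t. mu

    # KL-penalized reward, used only as a fixed weight on log_prob
    with torch.no_grad():
        r = reward_model(x) - beta * (log_prob - sft.log_prob(x))
        advantage = r - r.mean()    # batch-mean baseline reduces variance

    loss = -(advantage * log_prob).mean()  # REINFORCE surrogate
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f"Learned mean: {mu.item():.3f}")
print(f"Analytic optimum 4 / (4 + beta): {4 / (4 + beta):.3f}")
```

The learned mean should land close to the analytic optimum of about 0.667, short of the reward peak at 1.0, even though no gradient ever passes through `reward_model`.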

Practical Notes on Implementation

1. Batch Size for RL

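RL training is expensive because every step requires:

  • Generating a response from the policy: one sequential forward pass per token (seq_len passes)
  • Scoring the response: one forward pass of the RM over the full sequence
  • Computing log-probs: forward passes through both the policy and the frozen SFT reference model

As a rough guide, use smaller batch sizes (~16) for RL and larger ones (~128) for RM and SFT training.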

2. PPO Clipping (Not Shown)

The paper uses PPO, which clips the probability ratio between the new and old policy (not the gradients) to prevent overshooting:

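# Simplified PPO clipped surrogate (mock tensors for illustration)
import torch
epsilon = 0.2  # clip range
log_prob_old, log_prob_new = torch.tensor([-1.2, -0.8]), torch.tensor([-1.0, -0.9])
advantage = torch.tensor([0.5, -0.3])
ratio = torch.exp(log_prob_new - log_prob_old)
clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
loss = -torch.min(ratio * advantage, clipped_ratio * advantage).mean()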

This prevents one batch from causing huge policy updates.
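To see what the clip does to the gradient, a minimal sketch with made-up probabilities (the helper name `ppo_clip_loss` is hypothetical). With a positive advantage, token 0 already has 1.5x its old probability, which is outside the range [1 - epsilon, 1 + epsilon] = [0.8, 1.2], so the clipped term wins the minimum and its gradient is zero: the update stops pushing that token further. Token 1 has ratio 1.1, inside the range, so the ordinary ratio * advantage gradient flows.

```python
import math
import torch

epsilon = 0.2

def ppo_clip_loss(log_prob_new, log_prob_old, advantage):
    """Clipped PPO surrogate, negated so it can be minimized."""
    ratio = torch.exp(log_prob_new - log_prob_old)
    clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    return -torch.min(ratio * advantage, clipped_ratio * advantage).mean()

log_prob_old = torch.tensor([math.log(0.2), math.log(0.2)])
# Token 0: ratio 0.3 / 0.2 = 1.5 (clipped); token 1: ratio 0.22 / 0.2 = 1.1 (not clipped)
log_prob_new = torch.tensor([math.log(0.3), math.log(0.22)], requires_grad=True)
advantage = torch.tensor([1.0, 1.0])

loss = ppo_clip_loss(log_prob_new, log_prob_old, advantage)
loss.backward()
print("gradient w.r.t. new log-probs:", log_prob_new.grad)
```

The gradient for token 0 is 0, and for token 1 it is -ratio * advantage / 2 = -0.55 (the 1/2 comes from averaging over two tokens). With a negative advantage, the same mechanism applies at the lower bound 1 - epsilon instead.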

3. Value Function for Advantage (Not Shown)

In practice, PPO weights each update by an advantage rather than the raw reward:

advantage = reward - V(x)

where V(x) is a learned baseline that estimates expected reward given prompt x. This reduces variance in gradient estimates.
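A minimal sketch of the baseline idea, assuming a hypothetical `ValueHead` (a single linear layer on a prompt embedding; random stand-ins replace real embeddings and rewards). The value head is trained by mean-squared error to predict the reward, and the advantage uses its detached prediction, so the policy update does not backpropagate into the value function. PPO implementations usually estimate values per token and combine them with generalized advantage estimation; this sketch keeps one value per prompt to match V(x) in the text.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, hidden_size = 4, 16

class ValueHead(nn.Module):  # hypothetical: prompt embedding -> expected reward
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, h):
        return self.linear(h).squeeze(-1)  # [batch]

value_head = ValueHead(hidden_size)
value_opt = torch.optim.Adam(value_head.parameters(), lr=1e-2)

prompt_emb = torch.randn(batch_size, hidden_size)  # stand-in prompt representations
rewards = torch.tensor([2.5, -0.3, 1.0, 0.4])       # stand-in (KL-penalized) rewards

values = value_head(prompt_emb)                    # V(x), one per prompt
advantage = rewards - values.detach()              # signal for the policy update
value_loss = ((values - rewards) ** 2).mean()      # regress V(x) toward observed reward

value_opt.zero_grad()
value_loss.backward()
value_opt.step()

print("advantages:", advantage)
print("value loss:", value_loss.item())
```

Detaching matters: the advantage should act as a fixed weight on the policy's log-probabilities, while the value head is updated only through its own regression loss.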

4. Data Flow in Production

SFT Model (trained)
    ├→ Generates responses on diverse prompts
    ├→ Human labelers compare responses (~33k prompts in the RM dataset)
    └→ Reward Model (trained on comparisons)
         ├→ Scores RL rollouts (cheap, fast)
         └→ RL Training Loop (PPO, policy initialized from SFT)
              └→ InstructGPT (aligned model)

Colab-Ready Code Summary

The three code blocks above are self-contained and run on free Google Colab:

  1. Code 1: Reward model training (~5 min)
  2. Code 2: RL loss computation (~1 min)
  3. Code 3: Full RL loop (~2 min)

For production systems:

  • Use Hugging Face’s TRL library (Transformer Reinforcement Learning) for PPO
  • Scale to 13B+ parameter models
  • Collect tens of thousands of human preference pairs
  • Train on 8× GPU setups

The paper’s key contribution isn’t new algorithms (SFT, RM, RL all existed before), but showing how to combine them effectively at scale.