KL Divergence: Measuring Probability Distribution Differences
What is KL Divergence?
KL divergence (Kullback-Leibler divergence) measures how different two probability distributions are. It answers the question: “If the data actually follow distribution P but I model them with distribution Q, how much extra surprise do I incur on average?”
It’s a distance-like measure between distributions, but it’s not symmetric: the “distance” from P to Q is usually different from the “distance” from Q to P.
Formal definition:
$$KL(P \parallel Q) = \sum_{x} P(x) \log \frac{P(x)}{Q(x)}$$
Where:
- $P(x)$ is the “true” or “reference” distribution
- $Q(x)$ is the approximate or “model” distribution
- The sum is over all possible outcomes $x$
- The logarithm is either natural (result in nats) or base 2 (result in bits); the worked examples below use the natural log
Alternative form (numerically more stable):
$$KL(P \parallel Q) = \sum_{x} P(x) [\log P(x) - \log Q(x)]$$
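The definition translates almost directly into code. The following is a minimal sketch, not a reference implementation: the function name `kl_divergence`, the `base` parameter, and the handling of zero probabilities (terms with $P(x) = 0$ contribute nothing; $Q(x) = 0$ with $P(x) > 0$ gives infinity) are conventions assumed here.

```python
import math

def kl_divergence(p, q, base=math.e):
    """KL(P || Q) for two discrete distributions given as equal-length sequences."""
    if len(p) != len(q):
        raise ValueError("p and q must have the same number of outcomes")
    total = 0.0
    for p_x, q_x in zip(p, q):
        if p_x == 0.0:
            continue  # assumed convention: 0 * log(0 / q) = 0
        if q_x == 0.0:
            return math.inf  # P puts mass on an outcome that Q rules out
        # Difference-of-logs form, accumulated in nats
        total += p_x * (math.log(p_x) - math.log(q_x))
    return total / math.log(base)  # convert from nats to the requested base

print(kl_divergence([0.5, 0.5], [0.5, 0.5]))          # identical distributions
print(kl_divergence([0.5, 0.5], [0.6, 0.4]))          # result in nats
print(kl_divergence([0.5, 0.5], [0.6, 0.4], base=2))  # same comparison in bits
```

Passing `base=2` reports the result in bits instead of nats.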
Intuition: Surprise and Inefficiency
Imagine you’re a programmer designing a compression algorithm.
You believe the message distribution is Q (maybe 70% A’s, 20% B’s, 10% C’s). So you design a code optimized for this: A uses 1 bit, B uses 2 bits, C uses 3 bits.
But the actual distribution is P (maybe 10% A’s, 30% B’s, 60% C’s). Now your code is inefficient! You’re spending 3 bits on C (which is common) when it should have the shortest codeword.
KL divergence measures this inefficiency: how much wasted communication you get by using the wrong distribution.
In machine learning: KL divergence measures how much extra information (wasted bits) you need if you approximate true distribution P with model distribution Q.
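The compression story can be checked with a few lines of arithmetic. This sketch uses the A/B/C percentages and the 1/2/3-bit code from the example; the names `assumed`, `actual`, and `matched_code` are hypothetical, and the "matched" code is simply the same codeword lengths reassigned so the most common symbol gets the shortest one.

```python
import math

# Distribution the code was designed for, and the one that actually occurs
assumed = {"A": 0.7, "B": 0.2, "C": 0.1}
actual = {"A": 0.1, "B": 0.3, "C": 0.6}

# Codeword lengths in bits: the code from the text, and one matched to `actual`
mismatched_code = {"A": 1, "B": 2, "C": 3}
matched_code = {"C": 1, "B": 2, "A": 3}

def expected_length(dist, code):
    """Average bits per symbol when symbols are drawn from `dist`."""
    return sum(dist[s] * code[s] for s in dist)

def kl_bits(p, q):
    """KL(p || q) in bits for dicts over the same symbols."""
    return sum(p[s] * math.log2(p[s] / q[s]) for s in p if p[s] > 0)

cost_mismatched = expected_length(actual, mismatched_code)
cost_matched = expected_length(actual, matched_code)
print(f"mismatched code: {cost_mismatched:.2f} bits/symbol")
print(f"matched code:    {cost_matched:.2f} bits/symbol")
print(f"KL(actual || assumed): {kl_bits(actual, assumed):.3f} bits")
```

The gap between the two integer-length codes and the KL value need not agree exactly: KL is the idealized overhead when codeword lengths are allowed to be the fractional values $-\log_2$ of the assumed probabilities, whereas real codewords have whole-bit lengths.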
Worked Example: Two Coin Flips
True distribution P: Fair coin (50% heads, 50% tails)
Model distribution Q: Biased coin (60% heads, 40% tails)
Calculate KL(P || Q):
$$KL(P \parallel Q) = P(\text{heads}) \log \frac{P(\text{heads})}{Q(\text{heads})} + P(\text{tails}) \log \frac{P(\text{tails})}{Q(\text{tails})}$$
$$KL(P \parallel Q) = 0.5 \log \frac{0.5}{0.6} + 0.5 \log \frac{0.5}{0.4}$$
Calculate each term:
- $\log \frac{0.5}{0.6} = \log(0.833) \approx -0.1823$ (using natural log)
- $\log \frac{0.5}{0.4} = \log(1.25) \approx 0.2231$ (using natural log)
$$KL(P \parallel Q) = 0.5 \times (-0.1823) + 0.5 \times 0.2231$$
$$KL(P \parallel Q) = -0.0912 + 0.1116 = 0.0204 \text{ nats}$$
Interpretation: If the true distribution is fair (P) but you model it as biased (Q), you waste about 0.0204 nats per flip on average. This is small because the distributions are similar.
Worked Example: Comparing Fair Coin vs Heavily Biased Coin
True distribution P: Fair coin (50% heads, 50% tails)
Wrong model Q: Heavily biased (99% heads, 1% tails)
Calculate KL(P || Q):
$$KL(P \parallel Q) = 0.5 \log \frac{0.5}{0.99} + 0.5 \log \frac{0.5}{0.01}$$
Calculate each term:
- $\log \frac{0.5}{0.99} = \log(0.505) \approx -0.6831$ (natural log)
- $\log \frac{0.5}{0.01} = \log(50) \approx 3.912$ (natural log)
$$KL(P \parallel Q) = 0.5 \times (-0.6831) + 0.5 \times 3.912$$
$$KL(P \parallel Q) = -0.3415 + 1.9560 = 1.6145 \text{ nats}$$
Interpretation: This is a large difference! If you use an almost-always-heads model to approximate a fair coin, you waste about 1.61 nats per flip, roughly 80 times more than with the 60/40 model. The distributions are very different.
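To see how quickly the divergence grows as the model coin becomes more lopsided, the two-outcome sum can be evaluated over a range of biases. This is a small sketch; the helper name `kl_fair_vs_biased` is made up for illustration and only covers the special case where P is a fair coin.

```python
import math

def kl_fair_vs_biased(q_heads):
    """KL(P || Q) in nats, where P is a fair coin and Q lands heads with probability q_heads."""
    return 0.5 * math.log(0.5 / q_heads) + 0.5 * math.log(0.5 / (1.0 - q_heads))

for q_heads in (0.5, 0.6, 0.8, 0.9, 0.99):
    print(f"Q(heads) = {q_heads:.2f}  KL = {kl_fair_vs_biased(q_heads):.4f} nats")
```

The divergence is zero when the model is also fair and rises steeply as Q(tails) approaches zero, because the $\log \frac{0.5}{Q(\text{tails})}$ term blows up.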
Worked Example: Three Possible Outcomes
True distribution P:
- Outcome A: P(A) = 0.5
- Outcome B: P(B) = 0.3
- Outcome C: P(C) = 0.2
Model distribution Q:
- Outcome A: Q(A) = 0.4
- Outcome B: Q(B) = 0.4
- Outcome C: Q(C) = 0.2
Calculate KL(P || Q):
$$KL(P \parallel Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}$$
$$= P(A) \log \frac{P(A)}{Q(A)} + P(B) \log \frac{P(B)}{Q(B)} + P(C) \log \frac{P(C)}{Q(C)}$$
$$= 0.5 \log \frac{0.5}{0.4} + 0.3 \log \frac{0.3}{0.4} + 0.2 \log \frac{0.2}{0.2}$$
Calculate each term:
- $0.5 \log \frac{0.5}{0.4} = 0.5 \log(1.25) = 0.5 \times 0.2231 = 0.1116$
- $0.3 \log \frac{0.3}{0.4} = 0.3 \log(0.75) = 0.3 \times (-0.2877) = -0.0863$
- $0.2 \log \frac{0.2}{0.2} = 0.2 \log(1) = 0.2 \times 0 = 0$
$$KL(P \parallel Q) = 0.1116 - 0.0863 + 0 = 0.0253 \text{ nats}$$
Interpretation: The model Q is a reasonable approximation of P, with a KL divergence of about 0.025 nats.
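The per-outcome terms are worth inspecting individually, since some of them are negative even though the total never is. A short sketch using the distributions from this example (the variable names are illustrative):

```python
import math

p = {"A": 0.5, "B": 0.3, "C": 0.2}  # true distribution from the example
q = {"A": 0.4, "B": 0.4, "C": 0.2}  # model distribution from the example

# Per-outcome contributions P(x) * log(P(x) / Q(x)), in nats
terms = {x: p[x] * math.log(p[x] / q[x]) for x in p}
kl = sum(terms.values())

for x, term in terms.items():
    print(f"{x}: {term:+.4f}")
print(f"KL(P || Q) = {kl:.4f} nats")
```

Outcome B contributes a negative term because Q overestimates it, outcome C contributes nothing because P and Q agree there, and the positive term from A outweighs the rest.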
Key Properties
1. KL Divergence is Always Non-Negative
$$KL(P \parallel Q) \geq 0$$
Equality holds only when $P = Q$ (the distributions are identical).
Proof intuition: Gibbs’ inequality states that cross-entropy $H(P, Q) \geq H(P)$, where $H(P, Q) = -\sum P(x) \log Q(x)$. Rearranging gives KL ≥ 0.
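One common way to justify this, sketched here with the sums taken over outcomes where $P(x) > 0$, applies Jensen’s inequality to the concave logarithm:

$$-KL(P \parallel Q) = \sum_{x} P(x) \log \frac{Q(x)}{P(x)} \leq \log \sum_{x} P(x) \frac{Q(x)}{P(x)} = \log \sum_{x} Q(x) \leq \log 1 = 0$$

Multiplying through by $-1$ gives $KL(P \parallel Q) \geq 0$.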
2. KL Divergence is NOT Symmetric
$$KL(P \parallel Q) \neq KL(Q \parallel P) \quad \text{(in general)}$$
This is fundamentally important. The divergence from P to Q is different from Q to P.
Example:
- KL(fair coin || biased coin) = 0.0204 nats
- KL(biased coin || fair coin) = ? (Different!)
Let’s calculate KL(Q || P) from our first example:
$$KL(Q \parallel P) = 0.6 \log \frac{0.6}{0.5} + 0.4 \log \frac{0.4}{0.5}$$
$$= 0.6 \log(1.2) + 0.4 \log(0.8)$$
$$= 0.6 \times 0.18232 + 0.4 \times (-0.22314)$$
$$= 0.10939 - 0.08926 = 0.0201 \text{ nats}$$
In this case, they’re close (both ~0.02), but in general, they can be very different.
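The gap between the two directions becomes obvious with a more extreme pair. The sketch below compares both directions for the 60/40 coin and for the 99/1 coin from the second worked example; the helper `kl_divergence` is a minimal version that assumes no zero probabilities in the second argument.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) in nats; assumes q[i] > 0 wherever p[i] > 0."""
    return sum(p_x * math.log(p_x / q_x) for p_x, q_x in zip(p, q) if p_x > 0)

fair = [0.5, 0.5]
slightly_biased = [0.6, 0.4]
heavily_biased = [0.99, 0.01]

for name, q in (("slightly biased", slightly_biased), ("heavily biased", heavily_biased)):
    forward = kl_divergence(fair, q)
    reverse = kl_divergence(q, fair)
    print(f"{name}: KL(fair || Q) = {forward:.4f}, KL(Q || fair) = {reverse:.4f}")
```

For the heavily biased coin the forward direction is more than twice the reverse one, because the fair coin produces tails half the time while the model almost never expects it.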
3. Interpretation of the Direction
KL(P || Q): “The data come from P, but I model them with Q. How much do I lose?” The expectation is taken under P, so Q is penalized for assigning low probability to events that are likely under P.
KL(Q || P): “I sample from my model Q and score the samples against P. How far off are they?” The expectation is taken under Q, so Q is penalized for assigning high probability to events that are unlikely under P.
In machine learning practice:
- Training: We often minimize KL(P_true || Q_model), which is equivalent to minimizing cross-entropy (maximizing likelihood) on data drawn from P_true
- RLHF (Paper 15): We use KL(Q_RL || Q_SFT) as a penalty to prevent the RL model from diverging too far from the SFT model it started from
4. Relationship to Entropy
KL divergence can be decomposed as:
$$KL(P \parallel Q) = H(P, Q) - H(P)$$
Where:
- $H(P, Q) = -\sum P(x) \log Q(x)$ is cross-entropy
- $H(P) = -\sum P(x) \log P(x)$ is entropy
Interpretation: KL divergence is the expected extra code length (extra nats or bits per outcome) you pay for coding with Q instead of P.
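The decomposition is easy to confirm numerically. This sketch reuses the three-outcome distributions from the earlier worked example; the function names are illustrative.

```python
import math

def entropy(p):
    """H(P) in nats."""
    return -sum(p_x * math.log(p_x) for p_x in p if p_x > 0)

def cross_entropy(p, q):
    """H(P, Q) in nats; assumes q[i] > 0 wherever p[i] > 0."""
    return -sum(p_x * math.log(q_x) for p_x, q_x in zip(p, q) if p_x > 0)

def kl_divergence(p, q):
    """KL(P || Q) in nats, computed directly from the definition."""
    return sum(p_x * math.log(p_x / q_x) for p_x, q_x in zip(p, q) if p_x > 0)

p = [0.5, 0.3, 0.2]
q = [0.4, 0.4, 0.2]

print(f"H(P)          = {entropy(p):.4f} nats")
print(f"H(P, Q)       = {cross_entropy(p, q):.4f} nats")
print(f"H(P, Q) - H(P) = {cross_entropy(p, q) - entropy(p):.4f} nats")
print(f"KL(P || Q)    = {kl_divergence(p, q):.4f} nats")
```

Because $H(P)$ does not depend on Q, minimizing cross-entropy over Q and minimizing $KL(P \parallel Q)$ over Q select the same model.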
Real-World Example: Language Model Training
In Paper 15 (RLHF), we have:
- P_true: The true distribution of good, helpful responses (from human preferences)
- Q_base: The SFT model’s distribution (supervised fine-tuned on human examples)
- Q_RL: The RL-optimized model’s distribution (after policy gradient updates)
We want to:
- Optimize Q_RL to maximize reward
- But stay close to Q_base (don’t forget general knowledge)
So we minimize:
$$\text{Loss} = -\text{Reward}(Q_{\text{RL}}) + \beta \cdot KL(Q_{\text{RL}} \parallel Q_{\text{base}})$$
The KL term (with coefficient β) prevents Q_RL from diverging too far from Q_base.
- If β = 0: Q_RL optimizes only for reward (can get nonsense)
- If β is too large: Q_RL stays close to Q_base but doesn’t improve much
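- β ≈ 0.01–0.1: Often a reasonable starting range in practice, though the best value depends on the reward scale and on how the KL term is computed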
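The effect of β can be illustrated on a toy problem with four possible responses. Everything in this sketch is hypothetical: the reward values, the two candidate policies `cautious` and `greedy`, and the function name `kl_regularized_loss`. The β values are chosen to suit this toy reward scale and are not recommendations for real training.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) in nats; assumes q[i] > 0 wherever p[i] > 0."""
    return sum(p_x * math.log(p_x / q_x) for p_x, q_x in zip(p, q) if p_x > 0)

def kl_regularized_loss(q_rl, q_base, rewards, beta):
    """-E_{x ~ q_rl}[reward(x)] + beta * KL(q_rl || q_base) over a discrete response set."""
    expected_reward = sum(prob * r for prob, r in zip(q_rl, rewards))
    return -expected_reward + beta * kl_divergence(q_rl, q_base)

# Hypothetical numbers for illustration only
q_base = [0.4, 0.3, 0.2, 0.1]
rewards = [1.0, 0.2, 0.0, -1.0]
cautious = [0.6, 0.2, 0.15, 0.05]        # modest shift toward the best response
greedy = [0.997, 0.001, 0.001, 0.001]    # nearly all mass on the best response

for beta in (0.0, 0.1, 1.0):
    loss_cautious = kl_regularized_loss(cautious, q_base, rewards, beta)
    loss_greedy = kl_regularized_loss(greedy, q_base, rewards, beta)
    print(f"beta = {beta}: cautious = {loss_cautious:+.4f}, greedy = {loss_greedy:+.4f}")
```

With β = 0 the greedy policy has the lower loss; with a large enough β the KL penalty outweighs its reward advantage and the cautious policy wins.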
Worked Example: RLHF-Like Scenario
Imagine:
- Q_base: When asked “How do I bake a cake?”, the model outputs:
  - 40% fluent explanation
  - 30% confused rambling
  - 20% irrelevant text
  - 10% nonsensical
- Q_RL: After RL training on human preferences, we want:
  - 90% fluent explanation
  - 5% confused rambling
  - 4% irrelevant text
  - 1% nonsensical
KL divergence penalty:
$$KL(Q_{\text{RL}} \parallel Q_{\text{base}}) = \sum_{x} Q_{\text{RL}}(x) \log \frac{Q_{\text{RL}}(x)}{Q_{\text{base}}(x)}$$
This is substantial (about 0.55 nats, versus roughly 0.02 nats for the slightly biased coin earlier) because we’ve shifted probability mass significantly. The penalty term in the loss (β × KL) forces us to make a trade-off: achieve high reward while staying somewhat similar to the base model.
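The penalty for this scenario can be computed directly from the eight percentages given above. The sketch below evaluates the sum term by term and then scales it by two example values of β; the variable names are illustrative.

```python
import math

categories = ["fluent explanation", "confused rambling", "irrelevant text", "nonsensical"]
q_base = [0.40, 0.30, 0.20, 0.10]
q_rl = [0.90, 0.05, 0.04, 0.01]

# Per-category contributions Q_RL(x) * log(Q_RL(x) / Q_base(x)), in nats
terms = [rl * math.log(rl / base) for rl, base in zip(q_rl, q_base)]
kl = sum(terms)

for name, term in zip(categories, terms):
    print(f"{name:>20}: {term:+.4f}")
print(f"KL(Q_RL || Q_base) = {kl:.4f} nats")

for beta in (0.01, 0.1):
    print(f"beta = {beta}: penalty = {beta * kl:.4f}")
```

Almost all of the divergence comes from the "fluent explanation" category, where Q_RL has more than doubled the base model's probability.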
Variations and Related Measures
1. Jensen-Shannon Divergence
A symmetric version of KL divergence:
$$JS(P \parallel Q) = \frac{1}{2} KL(P \parallel M) + \frac{1}{2} KL(Q \parallel M), \quad M = \frac{1}{2}(P + Q)$$
Used when you don’t want asymmetry. Unlike KL, it is always finite and bounded above by $\log 2$.
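A minimal sketch of Jensen-Shannon divergence, using the standard mixture-based definition in which both distributions are compared against their average $M = \frac{1}{2}(P + Q)$. The function names are illustrative.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) in nats; assumes q[i] > 0 wherever p[i] > 0."""
    return sum(p_x * math.log(p_x / q_x) for p_x, q_x in zip(p, q) if p_x > 0)

def js_divergence(p, q):
    """Jensen-Shannon divergence in nats, via the mixture M = (P + Q) / 2."""
    m = [(p_x + q_x) / 2 for p_x, q_x in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

fair = [0.5, 0.5]
heavily_biased = [0.99, 0.01]
print(js_divergence(fair, heavily_biased))
print(js_divergence(heavily_biased, fair))   # same value: JS is symmetric
print(js_divergence([1.0, 0.0], [0.0, 1.0])) # disjoint distributions give log 2
```

Because M is positive wherever P or Q is, neither KL term can be infinite, which is why the result stays finite even for distributions that share no outcomes.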
2. Reverse KL Divergence
$$KL(Q \parallel P)$$
Used in variational inference and some RL algorithms. Penalizes Q for having probability where P doesn’t.
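The practical consequence is often summarized as forward KL being "mass-covering" and reverse KL being "mode-seeking". The toy comparison below, with entirely hypothetical numbers, scores two candidate models against a two-peaked target: one that commits to a single peak and one that spreads evenly.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) in nats; assumes q[i] > 0 wherever p[i] > 0."""
    return sum(p_x * math.log(p_x / q_x) for p_x, q_x in zip(p, q) if p_x > 0)

# Hypothetical target with two peaks separated by an unlikely middle outcome
p = [0.49, 0.02, 0.49]
q_one_mode = [0.96, 0.02, 0.02]   # commits to the first peak
q_spread = [1 / 3, 1 / 3, 1 / 3]  # covers every outcome evenly

for name, q in (("one mode", q_one_mode), ("spread", q_spread)):
    print(f"{name:>8}: forward KL(P || Q) = {kl_divergence(p, q):.3f}, "
          f"reverse KL(Q || P) = {kl_divergence(q, p):.3f}")
```

In this example the forward direction prefers the spread-out model, since it never badly underestimates a likely outcome, while the reverse direction prefers the single-peak model, since it places little mass where P is small.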
3. Wasserstein Distance
A different notion of distance between distributions. Less commonly used in LLMs but important in generative models.
Summary Table
| Concept | Formula | Interpretation |
|---|---|---|
| KL Divergence | $KL(P \parallel Q) = \sum P(x) \log \frac{P(x)}{Q(x)}$ | How different Q is from P |
| Range | $KL \geq 0$; equals 0 only if $P = Q$ | Always non-negative |
| Symmetry | Not symmetric | $KL(P \parallel Q) \neq KL(Q \parallel P)$ in general |
| Measured in | Nats (base-e log) or Bits (base-2 log) | Information content units |
| Related to Entropy | $KL(P \parallel Q) = H(P, Q) - H(P)$ | Extra bits needed if using Q instead of P |
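Since nats and bits differ only in the base of the logarithm, converting between them is a single division or multiplication by $\ln 2$. A small sketch with hypothetical helper names:

```python
import math

def nats_to_bits(nats):
    """Convert an information quantity from nats (base e) to bits (base 2)."""
    return nats / math.log(2)

def bits_to_nats(bits):
    """Convert an information quantity from bits (base 2) to nats (base e)."""
    return bits * math.log(2)

print(f"{nats_to_bits(0.0204):.4f} bits")  # the fair vs 60/40 coin result, in bits
print(f"{bits_to_nats(1.0):.4f} nats")     # one bit expressed in nats
```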
Key Takeaways
- KL divergence measures distribution difference. Higher KL = more different distributions.
- It’s not symmetric. The direction matters: P to Q ≠ Q to P.
- Always non-negative, equals 0 only when distributions are identical.
- In RLHF (Paper 15), KL divergence is used as a penalty term to keep the RL-optimized model close to the base model.
- Interpreted as wasted information: If you use wrong distribution Q to approximate true P, you waste KL(P || Q) nats (or bits, with log base 2) per outcome on average.
Further Reading
- Intuition on KL Divergence by Count Bayesie
- KL Divergence in Depth
- Information Theory Tutorial by James V. Stone
Used in:
- Paper 15: Training Language Models to Follow Instructions with Human Feedback — KL penalty in PPO loss
- [Paper 24: Language Model Alignment] — KL divergence in safety constraints