KL Divergence

🟡 intermediate

KL Divergence: Measuring Probability Distribution Differences

What is KL Divergence?

KL divergence (Kullback-Leibler divergence) measures how different two probability distributions are. It answers the question: “If the data actually come from distribution P but I model them with distribution Q, how much extra surprise do I experience on average?”

It’s a distance-like measure between distributions, but it’s not symmetric: the “distance” from P to Q is usually different from the “distance” from Q to P.

Formal definition:

$$KL(P \parallel Q) = \sum_{x} P(x) \log \frac{P(x)}{Q(x)}$$

Where:

  • $P(x)$ is the “true” or “reference” distribution
  • $Q(x)$ is the approximate or “model” distribution
  • The sum is over all possible outcomes $x$
  • The log base sets the unit: natural log gives nats, log base 2 gives bits (the worked examples below use natural log)

Equivalent form (convenient when a model outputs log-probabilities, since it never forms the ratio P(x)/Q(x) directly):

$$KL(P \parallel Q) = \sum_{x} P(x) [\log P(x) - \log Q(x)]$$
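A minimal Python sketch of both forms, assuming plain lists of probabilities or log-probabilities (the function names are illustrative, not from any library). The second example uses a model whose log-probability for tails is -800: converting that to a probability underflows to 0.0 in double precision, so the ratio form divides by zero, while the log-difference form still returns the correct value, $400 - \ln 2$.

```python
import math

def kl_ratio_form(p, q):
    """KL(P || Q) = sum_x P(x) * log(P(x) / Q(x)), from probabilities (nats)."""
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)

def kl_log_form(log_p, log_q):
    """KL(P || Q) = sum_x P(x) * (log P(x) - log Q(x)), from log-probabilities (nats)."""
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))

# Both forms agree on well-behaved inputs (fair coin vs. 60/40 model).
p, q = [0.5, 0.5], [0.6, 0.4]
print(kl_ratio_form(p, q), kl_log_form([math.log(x) for x in p], [math.log(x) for x in q]))

# Model with log Q(heads) = 0 and log Q(tails) = -800; e^-800 underflows to 0.0 in float64.
log_p = [math.log(0.5), math.log(0.5)]
log_q = [0.0, -800.0]
print(kl_log_form(log_p, log_q))   # 400 - ln 2, about 399.307
try:
    kl_ratio_form([0.5, 0.5], [math.exp(lq) for lq in log_q])
except ZeroDivisionError:
    print("ratio form failed: Q(tails) underflowed to 0.0")
```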

Intuition: Surprise and Inefficiency

Imagine you’re a programmer designing a compression algorithm.

You believe the message distribution is Q (maybe 70% A’s, 20% B’s, 10% C’s). So you design a code optimized for this: A uses 1 bit, B uses 2 bits, C uses 3 bits.

But the actual distribution is P (maybe 10% A’s, 30% B’s, 60% C’s). Now your code is inefficient! You’re spending the most bits on C, the most common symbol, when it should get the shortest codeword.

KL divergence measures this inefficiency: for ideally designed codes, KL(P || Q) is the average number of extra bits per symbol you waste by coding with the wrong distribution.

In machine learning: KL divergence measures how much extra information (wasted bits, or nats with natural log) you need on average if you approximate the true distribution P with the model distribution Q.
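To check this numerically, the sketch below uses idealized code lengths of $-\log_2 Q(x)$ bits per symbol (real prefix codes round these to whole bits, as in the 1/2/3-bit code above). The average cost of coding data from P with a code built for Q, minus the cost with a code built for P, comes out equal to KL(P || Q) in bits.

```python
import math

symbols = ["A", "B", "C"]
q = {"A": 0.7, "B": 0.2, "C": 0.1}   # distribution the code was designed for
p = {"A": 0.1, "B": 0.3, "C": 0.6}   # distribution that actually occurs

def ideal_length(d, x):
    """Idealized (Shannon) code length for symbol x under distribution d, in bits."""
    return -math.log2(d[x])

# Average bits per symbol coding P-data with the Q-code: cross-entropy H(P, Q)
bits_with_wrong_code = sum(p[x] * ideal_length(q, x) for x in symbols)
# Average bits per symbol with a code matched to P: entropy H(P)
bits_with_right_code = sum(p[x] * ideal_length(p, x) for x in symbols)
# KL(P || Q) in bits, straight from the definition
kl_bits = sum(p[x] * math.log2(p[x] / q[x]) for x in symbols)

print(f"wrong code: {bits_with_wrong_code:.3f} bits/symbol")  # 2.741
print(f"right code: {bits_with_right_code:.3f} bits/symbol")  # 1.295
print(f"wasted:     {kl_bits:.3f} bits/symbol")               # 1.446
```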

Worked Example: Two Coin Flips

True distribution P: Fair coin (50% heads, 50% tails)
Model distribution Q: Biased coin (60% heads, 40% tails)

Calculate KL(P || Q):

$$KL(P \parallel Q) = P(\text{heads}) \log \frac{P(\text{heads})}{Q(\text{heads})} + P(\text{tails}) \log \frac{P(\text{tails})}{Q(\text{tails})}$$

$$KL(P \parallel Q) = 0.5 \log \frac{0.5}{0.6} + 0.5 \log \frac{0.5}{0.4}$$

Calculate each term:

  • $\log \frac{0.5}{0.6} = \log(0.833) \approx -0.1823$ (using natural log)
  • $\log \frac{0.5}{0.4} = \log(1.25) \approx 0.2231$ (using natural log)

$$KL(P \parallel Q) = 0.5 \times (-0.1823) + 0.5 \times 0.2231$$

$$KL(P \parallel Q) = -0.0912 + 0.1116 = 0.0204 \text{ nats}$$

Interpretation: If the true distribution is fair (P) but you model it as biased (Q), you waste about 0.0204 nats of information per flip on average. This is a small difference because the distributions are similar.
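The same arithmetic as a short Python check (the helper name `kl_terms` is illustrative). It also shows that individual terms can be negative; only their probability-weighted sum is guaranteed to be non-negative.

```python
import math

def kl_terms(p, q):
    """Per-outcome contributions P(x) * ln(P(x) / Q(x)) to KL(P || Q), in nats."""
    return [px * math.log(px / qx) for px, qx in zip(p, q)]

p = [0.5, 0.5]   # fair coin: [heads, tails]
q = [0.6, 0.4]   # biased model: [heads, tails]

terms = kl_terms(p, q)
print([round(t, 4) for t in terms])   # [-0.0912, 0.1116]
print(round(sum(terms), 4))           # 0.0204
```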

Worked Example: Fair Coin vs Heavily Biased Coin

True distribution P: Fair coin (50% heads, 50% tails)
Wrong model Q: Heavily biased (99% heads, 1% tails)

Calculate KL(P || Q):

$$KL(P \parallel Q) = 0.5 \log \frac{0.5}{0.99} + 0.5 \log \frac{0.5}{0.01}$$

Calculate each term:

  • $\log \frac{0.5}{0.99} = \log(0.505) \approx -0.6831$ (natural log)
  • $\log \frac{0.5}{0.01} = \log(50) \approx 3.9120$ (natural log)

$$KL(P \parallel Q) = 0.5 \times (-0.6831) + 0.5 \times 3.9120$$

$$KL(P \parallel Q) = -0.3415 + 1.9560 = 1.6145 \text{ nats}$$

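Interpretation: This is a large difference! If you use an almost-always-heads model to approximate a fair coin, you waste about 1.61 nats per flip. Most of the penalty comes from tails: they happen half the time under P, but Q gives them only 1% probability. The distributions are very different.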
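A quick sweep shows how the tails term drives the result (a sketch; the list of model probabilities is arbitrary). As Q(tails) shrinks, KL(P || Q) keeps growing, roughly like $\frac{1}{2}\ln\frac{1}{Q(\text{tails})}$, and it would become infinite at Q(tails) = 0.

```python
import math

def kl(p, q):
    """KL(P || Q) in nats; assumes Q(x) > 0 wherever P(x) > 0."""
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)

fair = [0.5, 0.5]   # [heads, tails]
results = []
for q_tails in [0.4, 0.1, 0.01, 0.001, 1e-6]:
    model = [1 - q_tails, q_tails]
    results.append(kl(fair, model))
    print(f"Q(tails) = {q_tails:g}: KL = {results[-1]:.3f} nats")
```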

Worked Example: Three Possible Outcomes

True distribution P:

  • Outcome A: P(A) = 0.5
  • Outcome B: P(B) = 0.3
  • Outcome C: P(C) = 0.2

Model distribution Q:

  • Outcome A: Q(A) = 0.4
  • Outcome B: Q(B) = 0.4
  • Outcome C: Q(C) = 0.2

Calculate KL(P || Q):

$$KL(P \parallel Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}$$

$$= P(A) \log \frac{P(A)}{Q(A)} + P(B) \log \frac{P(B)}{Q(B)} + P(C) \log \frac{P(C)}{Q(C)}$$

$$= 0.5 \log \frac{0.5}{0.4} + 0.3 \log \frac{0.3}{0.4} + 0.2 \log \frac{0.2}{0.2}$$

Calculate each term:

  • $0.5 \log \frac{0.5}{0.4} = 0.5 \log(1.25) = 0.5 \times 0.2231 = 0.1116$
  • $0.3 \log \frac{0.3}{0.4} = 0.3 \log(0.75) = 0.3 \times (-0.2877) = -0.0863$
  • $0.2 \log \frac{0.2}{0.2} = 0.2 \log(1) = 0.2 \times 0 = 0$

$$KL(P \parallel Q) = 0.1116 - 0.0863 + 0 = 0.0253 \text{ nats}$$

Interpretation: The model Q is a reasonable approximation of P, with a KL divergence of about 0.025 nats.
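In code, the only difference between nats and bits is the log base. This plain-Python sketch (dictionaries chosen for readability) reproduces the calculation above and converts the result to bits, using 1 bit = ln 2 nats.

```python
import math

p = {"A": 0.5, "B": 0.3, "C": 0.2}   # true distribution
q = {"A": 0.4, "B": 0.4, "C": 0.2}   # model distribution

kl_nats = sum(p[x] * math.log(p[x] / q[x]) for x in p)    # natural log
kl_bits = sum(p[x] * math.log2(p[x] / q[x]) for x in p)   # log base 2

print(round(kl_nats, 4))                              # 0.0253
print(round(kl_bits, 4))                              # 0.0365
print(math.isclose(kl_bits, kl_nats / math.log(2)))   # True: 1 bit = ln 2 nats
```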

Key Properties

1. KL Divergence is Always Non-Negative

$$KL(P \parallel Q) \geq 0$$

Equality holds if and only if $P = Q$ (the distributions are identical).

Proof intuition: Gibbs’ inequality states that cross-entropy $H(P, Q) \geq H(P)$, where $H(P, Q) = -\sum P(x) \log Q(x)$. Rearranging gives KL ≥ 0.
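Non-negativity is easy to spot-check numerically, though a check like this is a sanity test, not a proof. The sketch draws random pairs of 5-outcome distributions (the helper `random_distribution` is hypothetical; it just normalizes random positive weights) and confirms that no pair produces a negative KL, and that a distribution compared with itself gives exactly 0.

```python
import math
import random

def kl(p, q):
    """KL(P || Q) in nats for strictly positive distributions."""
    return sum(px * math.log(px / qx) for px, qx in zip(p, q))

def random_distribution(n, rng):
    """Random probability vector with n strictly positive entries."""
    weights = [rng.random() + 1e-9 for _ in range(n)]
    total = sum(weights)
    return [w / total for w in weights]

rng = random.Random(42)
values = [kl(random_distribution(5, rng), random_distribution(5, rng))
          for _ in range(10_000)]
print(min(values) >= 0)   # True: no pair gave a negative value

p = random_distribution(5, rng)
print(kl(p, p))           # 0.0: identical distributions
```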

2. KL Divergence is NOT Symmetric

$$KL(P \parallel Q) \neq KL(Q \parallel P)$$

This is fundamentally important. The divergence from P to Q is different from Q to P.

Example:

  • KL(fair coin || biased coin) = 0.0204 nats
  • KL(biased coin || fair coin) = ? (Different!)

Let’s calculate KL(Q || P) from our first example:

$$KL(Q \parallel P) = 0.6 \log \frac{0.6}{0.5} + 0.4 \log \frac{0.4}{0.5}$$

$$= 0.6 \log(1.2) + 0.4 \log(0.8)$$

$$= 0.6 \times 0.1823 + 0.4 \times (-0.2231)$$

$$= 0.1094 - 0.0893 = 0.0201 \text{ nats}$$

In this case, they’re close (both ~0.02), but in general, they can be very different.
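Computing both directions for the two coin examples makes the asymmetry concrete (a sketch with an illustrative `kl` helper). For the 60/40 model the two directions nearly agree, but for the 99/1 model the forward value is about 1.61 nats while the reverse is only about 0.64 nats.

```python
import math

def kl(p, q):
    """KL(P || Q) in nats; assumes Q(x) > 0 wherever P(x) > 0."""
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)

fair = [0.5, 0.5]
for name, model in [("60/40", [0.6, 0.4]), ("99/1", [0.99, 0.01])]:
    forward = kl(fair, model)   # KL(fair || model)
    reverse = kl(model, fair)   # KL(model || fair)
    print(f"{name}: KL(P||Q) = {forward:.4f}, KL(Q||P) = {reverse:.4f}")
# 60/40: KL(P||Q) = 0.0204, KL(Q||P) = 0.0201
# 99/1: KL(P||Q) = 1.6145, KL(Q||P) = 0.6371
```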

3. Interpretation of the Direction

KL(P || Q): “The data come from P, but I modeled them with Q. How much extra surprise did that cost?” Penalizes Q for assigning low probability to events that actually happen under P.

KL(Q || P): “I sample from my model Q. How implausible are those samples under the truth P?” Penalizes Q for assigning high probability to events that rarely or never happen under P.
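The zero-probability case shows the two penalties most sharply. This sketch uses an illustrative `kl` helper with the standard conventions: terms with P(x) = 0 contribute nothing, and the divergence is infinite when Q(x) = 0 but P(x) > 0. The reference `spread` puts mass on both outcomes, while the model `peaked` puts all of its mass on one. The forward direction is infinite because the model misses an outcome that really happens; the reverse direction stays finite because the model only puts mass where the reference does.

```python
import math

def kl(p, q):
    """KL(P || Q) in nats, with the usual zero-probability conventions."""
    total = 0.0
    for px, qx in zip(p, q):
        if px == 0:
            continue            # outcomes P never produces contribute nothing
        if qx == 0:
            return math.inf     # P produces this outcome but Q calls it impossible
        total += px * math.log(px / qx)
    return total

spread = [0.5, 0.5]   # reference P: both outcomes happen
peaked = [1.0, 0.0]   # model Q: all mass on the first outcome

print(kl(spread, peaked))   # inf: forward KL punishes Q for missing an outcome
print(kl(peaked, spread))   # 0.693... (= ln 2): reverse KL only checks where Q puts mass
```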

In machine learning practice:

  • Training: minimizing cross-entropy (log-loss) on data drawn from P_true is equivalent to minimizing KL(P_true || Q_model), because H(P_true) does not depend on the model
  • RLHF (Paper 15): We use KL(Q_RL || Q_SFT) as a penalty to prevent the RL model from diverging too far from the base model

4. Relationship to Entropy

KL divergence can be decomposed as:

$$KL(P \parallel Q) = H(P, Q) - H(P)$$

Where:

  • $H(P, Q) = -\sum P(x) \log Q(x)$ is cross-entropy
  • $H(P) = -\sum P(x) \log P(x)$ is entropy

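Interpretation: KL divergence is the cross-entropy in excess of the entropy: the expected number of extra bits (or nats) per sample needed to encode data from P with a code built for Q instead of one built for P.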
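The decomposition can be verified on the three-outcome example from earlier (helper names are illustrative). It also shows why minimizing cross-entropy minimizes KL: $H(P)$ does not depend on Q, so any change in cross-entropy between two models equals the change in KL divergence.

```python
import math

def entropy(p):
    """H(P) = -sum_x P(x) ln P(x), in nats."""
    return -sum(px * math.log(px) for px in p if px > 0)

def cross_entropy(p, q):
    """H(P, Q) = -sum_x P(x) ln Q(x), in nats."""
    return -sum(px * math.log(qx) for px, qx in zip(p, q) if px > 0)

def kl(p, q):
    """KL(P || Q) = sum_x P(x) ln(P(x) / Q(x)), in nats."""
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)

p = [0.5, 0.3, 0.2]   # true distribution
q = [0.4, 0.4, 0.2]   # model distribution

print(round(cross_entropy(p, q), 4))   # 1.0549
print(round(entropy(p), 4))            # 1.0297
print(round(kl(p, q), 4))              # 0.0253
print(math.isclose(kl(p, q), cross_entropy(p, q) - entropy(p)))   # True
```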

Real-World Example: Language Model Training

In Paper 15 (RLHF), we have:

  • P_true: The true distribution of good, helpful responses (from human preferences)
  • Q_base: The SFT model’s distribution (supervised fine-tuned on human examples)
  • Q_RL: The RL-optimized model’s distribution (after policy gradient updates)

We want to:

  1. Optimize Q_RL to maximize reward
  2. But stay close to Q_base (don’t forget general knowledge)

So we minimize:

$$\text{Loss} = -\mathbb{E}_{x \sim Q_{\text{RL}}}[\text{Reward}(x)] + \beta \cdot KL(Q_{\text{RL}} \parallel Q_{\text{base}})$$

The KL term (with coefficient β) prevents Q_RL from diverging too far from Q_base.

  • If β = 0: Q_RL optimizes only for reward (and can drift into nonsense that happens to score well)
  • If β is too large: Q_RL stays close to Q_base but doesn’t improve much
  • β ≈ 0.01–0.1: often a good balance in practice, though the right value depends on the scale of the reward
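A toy version of this trade-off can be solved exactly. For a finite set of responses, the distribution that maximizes $\mathbb{E}_{x \sim Q}[\text{Reward}(x)] - \beta \cdot KL(Q \parallel Q_{\text{base}})$ (equivalently, minimizes the loss above) is $Q^*(x) \propto Q_{\text{base}}(x) \exp(\text{Reward}(x)/\beta)$. The sketch applies that formula to the four response types from the cake example below, with made-up reward values. Real RLHF trains a neural policy with gradient updates rather than using this closed form, so treat it only as an illustration of how β moves the optimum.

```python
import math

# Q_base over four response types (from the cake example in this section).
q_base = {"fluent": 0.4, "rambling": 0.3, "irrelevant": 0.2, "nonsense": 0.1}
# Hypothetical rewards, made up for illustration only.
reward = {"fluent": 1.0, "rambling": 0.2, "irrelevant": 0.0, "nonsense": -1.0}

def kl(q, p):
    """KL(Q || P) in nats."""
    return sum(q[x] * math.log(q[x] / p[x]) for x in q if q[x] > 0)

def best_policy(beta):
    """Maximizer of E_q[reward] - beta * KL(q || q_base):
    q(x) proportional to q_base(x) * exp(reward(x) / beta)."""
    unnorm = {x: q_base[x] * math.exp(reward[x] / beta) for x in q_base}
    z = sum(unnorm.values())
    return {x: w / z for x, w in unnorm.items()}

def objective(q, beta):
    """Expected reward minus the KL penalty (the negative of the loss)."""
    return sum(q[x] * reward[x] for x in q) - beta * kl(q, q_base)

for beta in [10.0, 1.0, 0.1]:
    q = best_policy(beta)
    print(f"beta={beta:<5} P(fluent)={q['fluent']:.3f}  KL={kl(q, q_base):.3f} nats")
```

Large β keeps the optimum close to Q_base; small β concentrates almost all probability on the highest-reward response, at the cost of a much larger KL.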

Worked Example: RLHF-Like Scenario

Imagine:

  • Q_base: When asked “How do I bake a cake?”, the model outputs:

    • 40% fluent explanation
    • 30% confused rambling
    • 20% irrelevant text
    • 10% nonsensical
  • Q_RL: After RL training on human preferences, we want:

    • 90% fluent explanation
    • 5% confused rambling
    • 4% irrelevant text
    • 1% nonsensical

KL divergence penalty:

$$KL(Q_{\text{RL}} \parallel Q_{\text{base}}) = \sum_x Q_{\text{RL}}(x) \log \frac{Q_{\text{RL}}(x)}{Q_{\text{base}}(x)}$$

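$$= 0.9 \log \frac{0.9}{0.4} + 0.05 \log \frac{0.05}{0.3} + 0.04 \log \frac{0.04}{0.2} + 0.01 \log \frac{0.01}{0.1} \approx 0.730 - 0.090 - 0.064 - 0.023 \approx 0.553 \text{ nats}$$

This is far larger than the coin examples because we’ve shifted probability mass significantly. The penalty term in the loss (β × KL) forces us to make a trade-off: achieve high reward while staying somewhat similar to the base model.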
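Here the exact value is a four-term sum, but for a real language model the space of responses cannot be enumerated, so the KL term is usually estimated from samples: draw responses from Q_RL and average $\log Q_{\text{RL}}(x) - \log Q_{\text{base}}(x)$. The sketch below does both for the toy distributions (the seed and sample count are arbitrary choices).

```python
import math
import random

# Toy distributions over four response types, from the example above.
q_base = {"fluent": 0.40, "rambling": 0.30, "irrelevant": 0.20, "nonsense": 0.10}
q_rl = {"fluent": 0.90, "rambling": 0.05, "irrelevant": 0.04, "nonsense": 0.01}

# Exact KL(Q_RL || Q_base): enumerate every response type.
exact = sum(q_rl[x] * math.log(q_rl[x] / q_base[x]) for x in q_rl)

# Sample-based estimate: draw responses from Q_RL and average the log-ratio.
rng = random.Random(0)
outcomes = list(q_rl)
samples = rng.choices(outcomes, weights=[q_rl[x] for x in outcomes], k=100_000)
estimate = sum(math.log(q_rl[x]) - math.log(q_base[x]) for x in samples) / len(samples)

print(f"exact    KL = {exact:.3f} nats")      # exact    KL = 0.553 nats
print(f"estimate KL = {estimate:.3f} nats")   # close to the exact value
```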

Related Measures

1. Jensen-Shannon Divergence

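A symmetric, smoothed relative of KL divergence that compares each distribution to their mixture $M = \frac{1}{2}(P + Q)$:

$$JS(P \parallel Q) = \frac{1}{2} KL(P \parallel M) + \frac{1}{2} KL(Q \parallel M)$$

Unlike KL, it is symmetric and always finite, bounded above by $\log 2$ (1 bit, or about 0.693 nats). Used when you don’t want asymmetry.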
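A minimal sketch of Jensen-Shannon divergence built from KL (illustrative helper names). Because each distribution is compared to the mixture M, which is positive wherever P or Q is, the result stays finite even when one distribution assigns zero probability to an outcome; with natural log its maximum is $\ln 2$, reached when P and Q have disjoint supports.

```python
import math

def kl(p, q):
    """KL(P || Q) in nats; terms with P(x) = 0 are skipped."""
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)

def js(p, q):
    """Jensen-Shannon divergence in nats: average KL to the mixture M = (P + Q) / 2."""
    m = [(px + qx) / 2 for px, qx in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

fair, peaked = [0.5, 0.5], [1.0, 0.0]
print(js(fair, peaked) == js(peaked, fair))      # True: symmetric
print(js(fair, peaked))                          # finite, although KL(fair || peaked) is infinite
print(js([1.0, 0.0], [0.0, 1.0]), math.log(2))   # disjoint supports reach the ln 2 maximum
```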

2. Reverse KL Divergence

$$KL(Q \parallel P)$$

Used in variational inference and some RL algorithms. Penalizes Q for having probability where P doesn’t.

3. Wasserstein Distance

A different notion of distance between distributions. Less commonly used in LLMs but important in generative models.

Summary Table

| Concept | Formula | Interpretation |
|---|---|---|
| KL Divergence | $KL(P \parallel Q) = \sum P(x) \log \frac{P(x)}{Q(x)}$ | How different Q is from P |
| Range | $KL \geq 0$; equals 0 only if $P = Q$ | Always non-negative |
| Symmetry | Not symmetric | $KL(P \parallel Q) \neq KL(Q \parallel P)$ in general |
| Measured in | Nats (base-e log) or bits (base-2 log) | Information content units |
| Related to Entropy | $KL(P \parallel Q) = H(P, Q) - H(P)$ | Extra bits needed if using Q instead of P |

Key Takeaways

  1. KL divergence measures distribution difference. Higher KL = more different distributions.
  2. It’s not symmetric. The direction matters: P to Q ≠ Q to P.
  3. Always non-negative, equals 0 only when distributions are identical.
  4. In RLHF (Paper 15), KL divergence is used as a penalty term to keep the RL-optimized model close to the base model.
  5. Interpreted as wasted information: If you use the wrong distribution Q to approximate the true P, you waste KL(P || Q) bits (or nats, with natural log) of information per sample on average.
