KL Divergence in Machine Learning | Generated by AI

2026.06

Question: What is KL divergence?

Answer:

KL divergence (Kullback-Leibler divergence) measures how much one probability distribution differs from another. It is central to modern LLM training, appearing in cross-entropy loss, RLHF, and DPO, and it is also a core term in VAEs.

Definition

\[D_{KL}(P \| Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}\]

For continuous distributions:

\[D_{KL}(P \| Q) = \int P(x) \log \frac{P(x)}{Q(x)} dx\]

Read as: “how many extra bits, on average, do you pay for encoding samples from P with a code built for Q?” (Bits when the log is base 2; nats with the natural log.)

Intuition

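Think of it as the surprise penalty for using the wrong distribution. The surprise of an outcome x under Q is $-\log Q(x)$. When x actually comes from P, each outcome costs an extra $-\log Q(x) - (-\log P(x)) = \log \frac{P(x)}{Q(x)}$, and $D_{KL}(P \| Q)$ is that extra surprise averaged over P.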
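A small worked example of the “extra bits” and “surprise” readings, using base-2 logs and two made-up distributions whose probabilities are powers of 2 so the numbers come out exact. The surprise of an outcome, $-\log_2$ of its probability, is also its ideal code length in bits.

```python
import numpy as np

# Made-up distributions over three outcomes; every probability is a power of 2.
P = np.array([0.5, 0.25, 0.25])   # true distribution (where samples come from)
Q = np.array([0.25, 0.25, 0.5])   # the "wrong" distribution we build the code for

surprise_P = -np.log2(P)          # ideal code lengths if we know P: [1, 2, 2] bits
surprise_Q = -np.log2(Q)          # code lengths if we assume Q:   [2, 2, 1] bits

# Average code length per symbol; samples still come from P in both cases.
avg_len_P = np.sum(P * surprise_P)   # 1.5 bits
avg_len_Q = np.sum(P * surprise_Q)   # 1.75 bits

extra_bits = avg_len_Q - avg_len_P       # average surprise penalty
kl_bits = np.sum(P * np.log2(P / Q))     # D_KL(P || Q) straight from the definition

print(extra_bits, kl_bits)   # 0.25 0.25
```

The two quantities agree: coding with Q instead of P costs exactly $D_{KL}(P \| Q) = 0.25$ extra bits per symbol here.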

Connection to Cross-Entropy

This is the key insight for LLM training:

\[H(P, Q) = H(P) + D_{KL}(P \| Q)\]

Cross-entropy = entropy of the true distribution + KL divergence. Since $H(P)$ does not depend on the model (you don’t control the data), minimizing cross-entropy loss over Q is equivalent to minimizing $D_{KL}(P \| Q)$, the divergence of the model distribution Q from the data distribution P.

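That’s why training a language model with cross-entropy (next-token prediction) amounts to fitting its distribution to the data distribution in the forward-KL sense.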
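A quick numerical check of the decomposition, reusing the p and q values from the Minimal Python section (natural log, so units are nats). It also compares a few candidate models to show that cross-entropy and KL rank them identically, since the two differ only by the constant $H(P)$. The helper names are illustrative.

```python
import numpy as np

# Simple helpers; all entries are assumed strictly positive (no zero handling).
def entropy(p):
    return -np.sum(p * np.log(p))

def cross_entropy(p, q):
    return -np.sum(p * np.log(q))

def kl(p, q):
    return np.sum(p * np.log(p / q))

p = np.array([0.4, 0.4, 0.2])   # data distribution P
q = np.array([0.3, 0.4, 0.3])   # model distribution Q

# H(P, Q) = H(P) + D_KL(P || Q): the two printed values match.
print(cross_entropy(p, q), entropy(p) + kl(p, q))

# Candidate models: the lowest cross-entropy and the lowest KL pick the same one.
candidates = [q, np.array([0.5, 0.3, 0.2]), p.copy()]
ce = [cross_entropy(p, c) for c in candidates]
kls = [kl(p, c) for c in candidates]
print(np.argmin(ce), np.argmin(kls))   # 2 2, i.e. the candidate equal to P
```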

Forward vs Reverse KL

Here Q is the distribution being fit and P is the target.

| | $D_{KL}(P \Vert Q)$ | $D_{KL}(Q \Vert P)$ |
|---|---|---|
| Name | Forward KL / “M-projection” | Reverse KL / “I-projection” |
| Behavior | Q spreads to cover all of P (“mean-seeking”) | Q tends to lock onto one mode of P (“mode-seeking”) |
| Used in | MLE / cross-entropy training | VAE encoder, PPO KL penalty in RLHF |
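To see the two behaviors numerically, here is a brute-force sketch under illustrative assumptions: P is an equal mixture of two unit-variance Gaussians at -3 and +3, Q is restricted to a single Gaussian, both are discretized on a grid over [-10, 10], and Q’s mean and standard deviation are searched over a coarse grid. All names and settings are made up for the demo.

```python
import numpy as np

# Discretize the real line; on this grid both P and Q are probability vectors.
x = np.linspace(-10.0, 10.0, 2001)

def log_normalize(log_w):
    """Normalize unnormalized log-weights into log-probabilities on the grid."""
    m = log_w.max()
    return log_w - (m + np.log(np.sum(np.exp(log_w - m))))

def log_gauss(mu, sigma):
    """Unnormalized Gaussian log-density on the grid (normalized later)."""
    return -0.5 * ((x - mu) / sigma) ** 2

def kl(log_a, log_b):
    """D_KL(A || B) for two probability vectors given as log-probabilities."""
    return np.sum(np.exp(log_a) * (log_a - log_b))

# Target P: equal mixture of N(-3, 1) and N(3, 1). Equal weights and widths,
# so the unnormalized terms can be combined directly before normalizing.
log_p = log_normalize(np.logaddexp(log_gauss(-3.0, 1.0), log_gauss(3.0, 1.0)))

mus = np.linspace(-5.0, 5.0, 101)    # candidate means for Q
sigmas = np.linspace(0.5, 5.0, 46)   # candidate standard deviations for Q

def best_fit(direction):
    """Brute-force the single Gaussian Q minimizing forward or reverse KL."""
    best = (np.inf, None, None)
    for mu in mus:
        for sigma in sigmas:
            log_q = log_normalize(log_gauss(mu, sigma))
            if direction == "forward":
                d = kl(log_p, log_q)   # D_KL(P || Q)
            else:
                d = kl(log_q, log_p)   # D_KL(Q || P)
            if d < best[0]:
                best = (d, mu, sigma)
    return best

_, mu_f, sigma_f = best_fit("forward")
_, mu_r, sigma_r = best_fit("reverse")
print(f"forward KL fit: mean={mu_f:+.1f}, std={sigma_f:.1f}")
print(f"reverse KL fit: mean={mu_r:+.1f}, std={sigma_r:.1f}")
```

With these settings, the forward-KL fit should sit near mean 0 with a wide standard deviation close to the moment-matched value $\sqrt{10} \approx 3.2$, straddling both modes, while the reverse-KL fit should land on one of the two modes (mean near ±3) with standard deviation near 1.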

In RLHF / PPO

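In RLHF with PPO, a KL penalty keeps the policy from drifting too far from the reference model (usually the SFT model). The objective to maximize is:

\[J(\theta) = \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(\cdot \mid x)}\left[r(x, y)\right] - \beta \, \mathbb{E}_{x \sim \mathcal{D}}\left[D_{KL}\big(\pi_\theta(\cdot \mid x) \,\|\, \pi_\text{ref}(\cdot \mid x)\big)\right]\]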

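β controls the tradeoff between reward maximization and staying close to the SFT model. Too low → the policy can exploit flaws in the reward model (reward hacking). Too high → the policy barely moves from the reference, so little is learned.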
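One common way to implement the penalty (details vary across RLHF codebases) is to sample responses from the policy and use the summed log-ratio $\log \pi_\theta(y \mid x) - \log \pi_\text{ref}(y \mid x)$ of the sampled tokens as a per-sample KL estimate, since its expectation under $\pi_\theta$ is exactly the KL. The sketch checks that on a made-up 4-token distribution and defines a hypothetical helper, `kl_penalized_reward`, that subtracts β times the estimate from a scalar reward.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up policy and reference distributions over a 4-token vocabulary.
pi_theta = np.array([0.50, 0.30, 0.15, 0.05])
pi_ref = np.array([0.25, 0.25, 0.25, 0.25])

# Exact D_KL(pi_theta || pi_ref) from the definition.
exact_kl = np.sum(pi_theta * np.log(pi_theta / pi_ref))

# Monte Carlo: sample tokens from the policy and average the log-ratio.
samples = rng.choice(len(pi_theta), size=200_000, p=pi_theta)
log_ratio = np.log(pi_theta[samples]) - np.log(pi_ref[samples])
mc_kl = log_ratio.mean()

def kl_penalized_reward(reward, logp_policy, logp_ref, beta):
    """Hypothetical helper: scalar reward minus beta times the sampled KL estimate.

    logp_policy and logp_ref are per-token log-probs of the *sampled* response
    under the policy and under the reference model.
    """
    return reward - beta * np.sum(np.asarray(logp_policy) - np.asarray(logp_ref))

print(exact_kl, mc_kl)  # the Monte Carlo estimate should be close to the exact value
print(kl_penalized_reward(1.0, np.log([0.5, 0.3]), np.log([0.25, 0.25]), beta=0.1))
```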

In DPO

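DPO optimizes the same KL-regularized objective without an explicit RL loop. That objective has a closed-form optimum, $\pi^*(y \mid x) \propto \pi_\text{ref}(y \mid x) \exp(r(x, y) / \beta)$. DPO inverts it to express the reward as $\beta \log \frac{\pi_\theta(y \mid x)}{\pi_\text{ref}(y \mid x)}$ (up to a term that depends only on x and cancels between the two responses to the same prompt) and plugs that into a Bradley-Terry preference loss, so the KL constraint enters through the log-ratio to the reference model.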
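A minimal NumPy sketch of the DPO loss as described above, assuming you already have the summed log-probabilities of each chosen and rejected response under the policy and under the frozen reference model. The function and argument names are illustrative, not from any library; `np.logaddexp(0, -m)` computes $-\log \sigma(m)$ stably.

```python
import numpy as np

def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Mean DPO loss over a batch of preference pairs.

    Each argument holds sequence-level log-probabilities log pi(y | x)
    for the chosen or rejected response of each pair.
    """
    policy_chosen, policy_rejected = np.asarray(policy_chosen), np.asarray(policy_rejected)
    ref_chosen, ref_rejected = np.asarray(ref_chosen), np.asarray(ref_rejected)
    # Implicit rewards: beta * log(pi_theta / pi_ref) for each response.
    chosen_reward = beta * (policy_chosen - ref_chosen)
    rejected_reward = beta * (policy_rejected - ref_rejected)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(margin)) = log(1 + exp(-margin)), computed stably.
    return np.mean(np.logaddexp(0.0, -margin))

# Made-up log-probs for two preference pairs under the reference model.
ref_c, ref_r = np.array([-12.0, -20.0]), np.array([-11.0, -19.0])

# Policy identical to the reference: margin is 0, so the loss is log 2.
print(dpo_loss(ref_c, ref_r, ref_c, ref_r))
# Policy that raises the chosen responses' log-probs relative to the reference: lower loss.
print(dpo_loss(ref_c + 2.0, ref_r, ref_c, ref_r))
```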

Minimal Python

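```python
import numpy as np

def kl_divergence(p, q):
    """D_KL(P || Q) in nats. P is the true distribution, Q the approximation."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    # Terms with P(x) = 0 contribute 0 by convention (0 * log 0 = 0).
    mask = p > 0
    # If Q(x) = 0 anywhere P(x) > 0, the divergence is infinite.
    if np.any(q[mask] == 0):
        return np.inf
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

p = [0.4, 0.4, 0.2]   # true distribution
q = [0.3, 0.4, 0.3]   # model's distribution

print(kl_divergence(p, q))   # ~0.0340 nats
print(kl_divergence(q, p))   # ~0.0353 nats: different!
```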

In PyTorch (what you actually use)

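```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Example tensors (made up): batch of 4, 10 classes
logits = torch.randn(4, 10)                # model outputs
p = F.softmax(torch.randn(4, 10), dim=-1)  # target P: soft labels, as probabilities

# Input must be log-probabilities of Q; log_softmax is numerically stable
log_q = F.log_softmax(logits, dim=-1)

# D_KL(P || Q), summed over classes and averaged over the batch
loss = F.kl_div(log_q, p, reduction='batchmean')
```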

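Note: PyTorch’s `kl_div` takes log(Q) as the input and P as the target, the opposite order from $D_{KL}(P \| Q)$, and it expects the target as probabilities unless you pass `log_target=True`. Easy gotcha. Use `reduction='batchmean'` to get the per-example KL; the default `'mean'` also divides by the number of classes.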

Why It’s Non-Symmetric

For the p and q from the Minimal Python example (in nats):

\[D_{KL}(P \| Q) \approx 0.0340, \quad D_{KL}(Q \| P) \approx 0.0353\]

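The gap is small here, but it becomes extreme when the supports differ: if Q(x) = 0 somewhere P(x) > 0, then $D_{KL}(P \| Q) = \infty$, while $D_{KL}(Q \| P)$ can stay finite. In VAEs, the reverse KL term $D_{KL}(q_\phi(z \mid x) \| p(z))$ pulls the encoder posterior toward the prior. Because it is averaged under q, it does not penalize q for ignoring regions where p has mass. When the decoder is powerful enough to model x without using z, the encoder can drive this term to zero by setting $q_\phi(z \mid x) \approx p(z)$ for every x, so the latent carries no information about x. This is posterior collapse, which some architectures are prone to.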
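A tiny example of how a support mismatch breaks the symmetry, using made-up two-outcome distributions and a KL helper that follows the usual conventions ($0 \log 0 = 0$, and $D_{KL}(P \| Q) = \infty$ when Q gives zero probability to an outcome P can produce).

```python
import numpy as np

def kl(p, q):
    """D_KL(P || Q) in nats, with 0 * log(0 / q) = 0 and inf when Q misses P's support."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    mask = p > 0
    if np.any(q[mask] == 0):
        return np.inf
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

p = [0.5, 0.5]   # P puts mass on both outcomes
q = [1.0, 0.0]   # Q puts all its mass on the first outcome

print(kl(p, q))  # inf: Q assigns zero probability to an outcome P can produce
print(kl(q, p))  # log 2, about 0.693: Q only visits outcomes where P > 0
```

This is the mode-seeking behavior in miniature: reverse KL tolerates a Q that drops part of P, while forward KL forbids it outright.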

Practical Summary

| Context | Which KL | Why |
|---|---|---|
| LLM cross-entropy loss | $D_{KL}(P_{data} \Vert P_{model})$ | Forward; MLE = forward KL |
| PPO penalty | $D_{KL}(\pi_\theta \Vert \pi_{ref})$ | Reverse; prevents drift |
| VAE | $D_{KL}(q_\phi(z \mid x) \Vert p(z))$ | Reverse; regularizes the latent |
| Distillation | $D_{KL}(P_{teacher} \Vert P_{student})$ | Forward; student matches teacher |
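For the VAE row: when the encoder outputs a diagonal Gaussian $q_\phi(z \mid x) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ and the prior is $p(z) = \mathcal{N}(0, I)$ (the standard setup, assumed here), the KL term has the closed form $\frac{1}{2}\sum_j (\sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2)$. The sketch computes it from `mu` and `log_var` (names and values chosen for illustration) and cross-checks it against a Monte Carlo estimate.

```python
import numpy as np

def gaussian_kl_to_std_normal(mu, log_var):
    """Closed-form D_KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent dims."""
    mu, log_var = np.asarray(mu, dtype=float), np.asarray(log_var, dtype=float)
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var, axis=-1)

# Made-up encoder outputs for one 2-D latent.
rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0])
log_var = np.array([np.log(0.5), np.log(2.0)])
std = np.exp(0.5 * log_var)

# Monte Carlo: sample z ~ q(z|x) and average log q(z|x) - log p(z).
z = mu + std * rng.standard_normal((200_000, 2))
log_q = np.sum(-0.5 * ((z - mu) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi), axis=-1)
log_p = np.sum(-0.5 * z**2 - 0.5 * np.log(2 * np.pi), axis=-1)

print(gaussian_kl_to_std_normal(mu, log_var))  # closed form: 0.875 for these values
print(np.mean(log_q - log_p))                  # Monte Carlo estimate, should be close
```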
