机器学习中的KL散度 | AI生成和翻译
问题: 什么是KL散度?
答案:
KL散度(Kullback-Leibler散度)衡量一个概率分布与另一个概率分布的差异程度。它是现代大语言模型训练的基石——出现在交叉熵损失、RLHF、DPO和VAE中。
定义
\[D_{KL}(P \| Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}\]对于连续分布:
\[D_{KL}(P \| Q) = \int P(x) \log \frac{P(x)}{Q(x)} dx\]可以理解为:“当真实分布为P时,使用Q会多花费多少额外比特?”
直观理解
可以将其视为使用错误分布时的惊讶惩罚:
- 如果Q完全等于P → KL = 0
- 如果Q在P高概率的区域分配低概率 → KL值会急剧增大
- 非对称性:$D_{KL}(P | Q) \neq D_{KL}(Q | P)$
与交叉熵的联系
这是理解大语言模型训练的关键:
\[H(P, Q) = H(P) + D_{KL}(P \| Q)\]交叉熵 = 真实分布的熵 + KL散度。由于$H(P)$是常数(你无法控制数据),最小化交叉熵损失 = 最小化模型Q与数据分布P之间的KL散度。
这就是语言模型训练有效的原因。
前向KL与反向KL
| $D_{KL}(P | Q)$ | $D_{KL}(Q | P)$ | |
|---|---|---|
| 名称 | 前向KL / “I-投影” | 反向KL / “M-投影” |
| 行为 | Q扩散以覆盖P的所有区域(”均值寻求”) | Q收缩到P的单一模式上(”模式寻求”) |
| 应用场景 | 最大似然估计 / 交叉熵训练 | VAE编码器、RL PPO KL惩罚 |
在RLHF / PPO中的应用
PPO中的KL项防止策略偏离参考模型过远:
\[\mathcal{L} = \mathbb{E}\left[r(x, y)\right] - \beta \cdot D_{KL}(\pi_\theta \| \pi_\text{ref})\]β控制奖励最大化与保持接近SFT模型之间的权衡。太低 → 奖励作弊。太高 → 无法学习。
在DPO中的应用
DPO隐式优化KL,无需显式的RL循环——带有KL约束的PPO目标的闭式解正是DPO损失所计算的。
最小化Python实现
import numpy as np
def kl_divergence(p, q):
# P是真实分布,Q是近似分布
p, q = np.array(p), np.array(q)
# 裁剪以避免log(0)
return np.sum(p * np.log(p / (q + 1e-10) + 1e-10))
p = [0.4, 0.4, 0.2] # 真实分布
q = [0.3, 0.4, 0.3] # 模型的分布
print(kl_divergence(p, q)) # ~0.026 nats
print(kl_divergence(q, p)) # ~0.027 nats — 不同!
在PyTorch中的实际使用
import torch
import torch.nn.functional as F
# 两者都处于对数空间以保证数值稳定性
log_q = F.log_softmax(logits, dim=-1)
p = target_distribution # 软标签
# PyTorch的KLDivLoss期望输入为对数概率
loss = F.kl_div(log_q, p, reduction='batchmean')
注意:PyTorch的kl_div以log(Q)作为输入,P作为目标——容易出错。
为什么是非对称的
\[D_{KL}(P \| Q) = 0.026, \quad D_{KL}(Q \| P) = 0.027\]这里差异很小,但当分布的支持集不重叠时,结果可能是灾难性的。在VAE中,反向KL $D_{KL}(Q | P)$ 推动编码器后验分布趋向先验分布——它会忽略Q=0的区域,导致某些架构中出现后验坍塌。
实用总结
| 场景 | 使用哪种KL | 原因 |
|---|---|---|
| 大语言模型交叉熵损失 | $D_{KL}(P_{data} | P_{model})$ | 最大似然估计 = 前向KL |
| PPO惩罚项 | $D_{KL}(\pi_\theta | \pi_{ref})$ | 反向KL,防止漂移 |
| VAE | $D_{KL}(q_\phi(z|x) | p(z))$ | 反向KL,正则化潜在空间 |
| 知识蒸馏 | $D_{KL}(P_{teacher} | P_{student})$ | 前向KL,匹配教师模型 |