优化稀疏注意力中的令牌丢弃机制 | AI生成和翻译
问题: 深入解释本文档中描述的技术——基于输出感知的令牌丢弃(OAM)、Blackwell 上注意力内核的动态负载均衡、以及 INT8 KV 缓存量化方案。
回答:
1. 核心问题:稀疏注意力与“令牌丢弃”的含义
在全注意力机制中,每个查询都关注所有键:复杂度 O(n²)。对于 128K 上下文,每层有 128K × 128K = 16B 对。这在推理时是不切实际的。
稀疏注意力通过仅计算子集令牌的注意力来解决这个问题。问题在于:你保留哪些令牌?
2. 传统令牌选择:Q×K 分数(问题所在)
经典的稀疏注意力根据注意力 logit 对令牌进行排序:
score(q, k_i) = q · k_i / sqrt(d_k)
然后你选择分数最高的 top-K 令牌。这很简单,但有一个根本缺陷:
Q×K 分数是一种路由概率——它衡量查询和键之间的几何对齐程度。它并不衡量该令牌的值对输出贡献了多少。
具体来说:注意力输出是
out = sum_i [ softmax(q · k_i / sqrt(d)) * v_i ]
一个具有高 q · k_i 分数的令牌可能有 v_i ≈ 0(微小的值向量)。你保留它却一无所获。相反,一个具有中等 q · k_i 的令牌可能有巨大的 |v_i|——丢弃它会破坏输出质量。
3. OAM:输出感知度量——解决方案
关键见解:根据每个令牌对输出的实际贡献(而不仅仅是键-查询对齐程度)来评分。
令牌 i 对输出的贡献近似为:
contribution_i ≈ softmax(q · k_i) * |v_i|
或者在对数空间(数值更稳定):
log_contribution_i = (q · k_i) + log|v_i|
这就是文档中提到的基于 BetaMax 的对数评估。“BetaMax”可能指的是最大归一化的对数 softmax 近似——不是计算完整的 softmax(需要所有令牌),而是使用局部近似:
# OAM 分数的伪代码
def oam_score(q, k_i, v_i):
routing_score = q @ k_i / sqrt(d_k) # 标准注意力 logit
value_magnitude = torch.log(v_i.norm() + eps) # |v_i| 贡献
return routing_score + beta * value_magnitude # beta 是学习或固定的权重
然后你丢弃具有最低 OAM 分数的令牌,而不是最低的 q·k 分数。
为什么这是有原则的: 你实际上是在计算 softmax 之前近似每个令牌输出贡献的上界。这与 Reformer(LSH 注意力)和 Mamba 的选择性状态空间背后的思想密切相关——选择器应该知道下游什么重要,而不仅仅是什么看起来很相似。
4. “Slim”实现结果
在 128K 输入上实现 3.6 倍吞吐量是合理的:
- 在 128K 序列长度下,全注意力 = 每头 128K² 次运算
- 如果 OAM 丢弃 90% 的令牌 → 每个查询的有效上下文 = 12.8K
- 计算减少:注意力部分约 10 倍,实际由于开销会更少
- 3.6 倍的端到端提升是现实的,因为注意力仅是瓶颈之一(FFN、内存带宽也很重要)
5. 注意力内核的动态负载均衡(Blackwell / H20)
问题:变长序列导致的线程发散
在批量推理中,序列具有不同长度。当你将 KV 缓存分块到线程块(CTA/SM)时,一个序列可能是 128K 令牌,而另一个是 512。处理长序列的 CTA 要做 250 倍的工作——所有其他 CTA 都完成并闲置(warp 停顿)。
这就是注意力内核中的负载不均衡问题,从 FlashAttention-2 到 FlashAttention-3 的演变中广为人知。
Split-KV 方法
解决方案:将 KV 维度拆分到多个 CTA 上,然后进行归约:
长度为 L 的序列 → 拆分为 N 个长度为 L/N 的块
每个块由一个 CTA 处理
最终输出 = 跨块的在线 softmax 归约(log-sum-exp 技巧)
这正是 FlashAttention-3 通过其“持久内核 + 工作窃取”调度器所做的,也是 vAttention/PagedAttention 变体所实现的。
归约仍然使用数值稳定的在线 softmax:
# 跨分片的在线 softmax
m_new = max(m_prev, m_chunk)
exp_sum = exp(m_prev - m_new) * exp_sum_prev + exp(m_chunk - m_new) * exp_sum_chunk
out = (exp(m_prev - m_new) * out_prev * exp_sum_prev + out_chunk * exp_sum_chunk) / exp_sum
结果:内核加速 1.5 倍,端到端提升 1%。端到端增益较小意味着注意力并非他们工作负载的主要瓶颈(在较短上下文中可能受计算限制,在较长上下文中受内存限制)。
6. INT8 KV 缓存量化:两种方案
为什么需要 INT8 KV 缓存?
KV 缓存是长上下文情况下的内存瓶颈。对于一个具有 128K 上下文的 70B 模型:
KV 缓存大小 = 2 * n_layers * n_heads * d_head * seq_len * bytes_per_element
= 2 * 80 * 64 * 128 * 128000 * 2 (BF16)
≈ 210 GB ← 无法放入一个 B200(192GB HBM3)
INT8 将其减半至约 105 GB。如果精度可接受,这是值得的。
逐张量静态量化(基线——问题所在)
q_val = round(val / scale) # scale 是每个张量的单个浮点数
val_reconstructed = q_val * scale
整个 Q、K 或 V 张量使用一个 scale。速度快(一次乘法即可反量化),但对于长上下文来说是灾难性的:
- 在 128K 令牌下,K/V 值的范围跨越多个数量级
- 一个单一的 scale 无法同时准确表示小值和大值
- 量化误差在 80 层中累积 → 智能体任务失败(它们维护长时间的历史)
方案 A:细粒度 QK(最佳精度,有一定开销)
Q 和 K 的逐令牌分区量化:
每个“分区”(例如 64 或 128 个令牌)有自己的 scale
scale[i] = max(|k[i*chunk:(i+1)*chunk]|) / 127
V 使用逐张量(一个全局 scale)
为什么 Q 和 K 最敏感:softmax(QK^T) 是指数级的——Q 或 K 的微小误差会被放大。V 是线性进入的,因此更粗粒度的量化是可以容忍的。
方案 B:更粗粒度的 KV(接近零性能损失)
Q: 逐令牌分区(细粒度)
K: 逐张量(粗粒度)
V: 逐张量(粗粒度)
牺牲一些精度换取性能。两种方案“基本上是零损失”的,因为具有适当校准 scale 的 INT8 相对误差约为 0.4%——在大多数基准测试的噪声范围内。
精度/吞吐量权衡可视化
方案 | Q scale | K scale | V scale | 精度 | 性能
----------------|------------|------------|------------|----------|------
基线 INT8 | 逐张量 | 逐张量 | 逐张量 | 降级 | 最快
方案 B | 逐令牌 | 逐张量 | 逐张量 | 良好 | 接近基线
方案 A | 逐令牌 | 逐令牌 | 逐张量 | 最佳 | 慢约 5-15%
BF16(完整) | 不适用 | 不适用 | 不适用 | 完美 | 2 倍内存
7. 这三者如何协同工作
推理请求(128K 上下文)
│
▼
OAM 令牌丢弃 ──► 减少需要关注的 KV 数量(3.6 倍计算减少)
│
▼
INT8 KV 缓存 ──► KV 读取的内存带宽减半
│
▼
Split-KV 负载均衡 ► 将剩余的 KV 工作均匀分配到各 SM 上
│
▼
输出
它们是正交的优化,可以叠加:OAM 减少哪些令牌被处理,INT8 减少每次令牌读取的开销,Split-KV 减少不均衡造成的空闲时间。
与你当前工作的联系
鉴于你正在 MI300X 上训练并研究 DeepSeek v4 MoE:
- OAM 与 MoE 推理直接相关:在 MoE 中,令牌路由也使用不考虑输出大小的分数。同样的论点也适用——你可以根据专家输出范数(而非仅门控分数)来加权专家路由。这是一个开放的研究方向。
- INT8 KV 量化:当你在 RTX 4070(12GB)上运行 128K 上下文推理时,这就是它 OOM 的原因——未量化的 KV。尝试在
llama.cpp中使用-cache-type-k q8_0 -cache-type-v q8_0来获得本地模型上的方案 B 等效实现。 - Split-KV:FlashAttention-3 的
tile_scheduler正是做这个的。值得将 FA3 论文与此一同阅读。