优化稀疏注意力中的令牌丢弃机制 | AI生成和翻译

Home 2026.06

问题: 深入解释本文档中描述的技术——基于输出感知的令牌丢弃(OAM)、Blackwell 上注意力内核的动态负载均衡、以及 INT8 KV 缓存量化方案。

回答:


1. 核心问题:稀疏注意力与“令牌丢弃”的含义

在全注意力机制中,每个查询都关注所有键:复杂度 O(n²)。对于 128K 上下文,每层有 128K × 128K = 16B 对。这在推理时是不切实际的。

稀疏注意力通过仅计算子集令牌的注意力来解决这个问题。问题在于:你保留哪些令牌?


2. 传统令牌选择:Q×K 分数(问题所在)

经典的稀疏注意力根据注意力 logit 对令牌进行排序:

score(q, k_i) = q · k_i / sqrt(d_k)

然后你选择分数最高的 top-K 令牌。这很简单,但有一个根本缺陷

Q×K 分数是一种路由概率——它衡量查询和键之间的几何对齐程度。它并不衡量该令牌的对输出贡献了多少。

具体来说:注意力输出是

out = sum_i [ softmax(q · k_i / sqrt(d)) * v_i ]

一个具有高 q · k_i 分数的令牌可能有 v_i ≈ 0(微小的值向量)。你保留它却一无所获。相反,一个具有中等 q · k_i 的令牌可能有巨大的 |v_i|——丢弃它会破坏输出质量。


3. OAM:输出感知度量——解决方案

关键见解:根据每个令牌对输出的实际贡献(而不仅仅是键-查询对齐程度)来评分

令牌 i 对输出的贡献近似为:

contribution_i ≈ softmax(q · k_i) * |v_i|

或者在对数空间(数值更稳定):

log_contribution_i = (q · k_i) + log|v_i|

这就是文档中提到的基于 BetaMax 的对数评估。“BetaMax”可能指的是最大归一化的对数 softmax 近似——不是计算完整的 softmax(需要所有令牌),而是使用局部近似:

# OAM 分数的伪代码
def oam_score(q, k_i, v_i):
    routing_score = q @ k_i / sqrt(d_k)          # 标准注意力 logit
    value_magnitude = torch.log(v_i.norm() + eps) # |v_i| 贡献
    return routing_score + beta * value_magnitude  # beta 是学习或固定的权重

然后你丢弃具有最低 OAM 分数的令牌,而不是最低的 q·k 分数。

为什么这是有原则的: 你实际上是在计算 softmax 之前近似每个令牌输出贡献的上界。这与 Reformer(LSH 注意力)和 Mamba 的选择性状态空间背后的思想密切相关——选择器应该知道下游什么重要,而不仅仅是什么看起来很相似


4. “Slim”实现结果

在 128K 输入上实现 3.6 倍吞吐量是合理的:


5. 注意力内核的动态负载均衡(Blackwell / H20)

问题:变长序列导致的线程发散

在批量推理中,序列具有不同长度。当你将 KV 缓存分块到线程块(CTA/SM)时,一个序列可能是 128K 令牌,而另一个是 512。处理长序列的 CTA 要做 250 倍的工作——所有其他 CTA 都完成并闲置(warp 停顿)。

这就是注意力内核中的负载不均衡问题,从 FlashAttention-2 到 FlashAttention-3 的演变中广为人知。

Split-KV 方法

解决方案:将 KV 维度拆分到多个 CTA 上,然后进行归约

长度为 L 的序列 → 拆分为 N 个长度为 L/N 的块
每个块由一个 CTA 处理
最终输出 = 跨块的在线 softmax 归约(log-sum-exp 技巧)

这正是 FlashAttention-3 通过其“持久内核 + 工作窃取”调度器所做的,也是 vAttention/PagedAttention 变体所实现的。

归约仍然使用数值稳定的在线 softmax:

# 跨分片的在线 softmax
m_new = max(m_prev, m_chunk)
exp_sum = exp(m_prev - m_new) * exp_sum_prev + exp(m_chunk - m_new) * exp_sum_chunk
out = (exp(m_prev - m_new) * out_prev * exp_sum_prev + out_chunk * exp_sum_chunk) / exp_sum

结果:内核加速 1.5 倍,端到端提升 1%。端到端增益较小意味着注意力并非他们工作负载的主要瓶颈(在较短上下文中可能受计算限制,在较长上下文中受内存限制)。


6. INT8 KV 缓存量化:两种方案

为什么需要 INT8 KV 缓存?

KV 缓存是长上下文情况下的内存瓶颈。对于一个具有 128K 上下文的 70B 模型:

KV 缓存大小 = 2 * n_layers * n_heads * d_head * seq_len * bytes_per_element
             = 2 * 80 * 64 * 128 * 128000 * 2 (BF16)
             ≈ 210 GB  ← 无法放入一个 B200(192GB HBM3)

INT8 将其减半至约 105 GB。如果精度可接受,这是值得的。

逐张量静态量化(基线——问题所在)

q_val = round(val / scale)   # scale 是每个张量的单个浮点数
val_reconstructed = q_val * scale

整个 Q、K 或 V 张量使用一个 scale。速度快(一次乘法即可反量化),但对于长上下文来说是灾难性的

方案 A:细粒度 QK(最佳精度,有一定开销)

Q 和 K 的逐令牌分区量化:
  每个“分区”(例如 64 或 128 个令牌)有自己的 scale
  scale[i] = max(|k[i*chunk:(i+1)*chunk]|) / 127

V 使用逐张量(一个全局 scale)

为什么 Q 和 K 最敏感:softmax(QK^T) 是指数级的——Q 或 K 的微小误差会被放大。V 是线性进入的,因此更粗粒度的量化是可以容忍的。

方案 B:更粗粒度的 KV(接近零性能损失)

Q: 逐令牌分区(细粒度)
K: 逐张量(粗粒度)
V: 逐张量(粗粒度)

牺牲一些精度换取性能。两种方案“基本上是零损失”的,因为具有适当校准 scale 的 INT8 相对误差约为 0.4%——在大多数基准测试的噪声范围内。

精度/吞吐量权衡可视化

方案           | Q scale    | K scale    | V scale    | 精度     | 性能
----------------|------------|------------|------------|----------|------
基线 INT8      | 逐张量     | 逐张量     | 逐张量     | 降级     | 最快
方案 B         | 逐令牌     | 逐张量     | 逐张量     | 良好     | 接近基线
方案 A         | 逐令牌     | 逐令牌     | 逐张量     | 最佳     | 慢约 5-15%
BF16(完整)   | 不适用     | 不适用     | 不适用     | 完美     | 2 倍内存

7. 这三者如何协同工作

推理请求(128K 上下文)
         │
         ▼
OAM 令牌丢弃 ──► 减少需要关注的 KV 数量(3.6 倍计算减少)
         │
         ▼
INT8 KV 缓存 ──► KV 读取的内存带宽减半
         │
         ▼
Split-KV 负载均衡 ► 将剩余的 KV 工作均匀分配到各 SM 上
         │
         ▼
输出

它们是正交的优化,可以叠加:OAM 减少哪些令牌被处理,INT8 减少每次令牌读取的开销,Split-KV 减少不均衡造成的空闲时间


与你当前工作的联系

鉴于你正在 MI300X 上训练并研究 DeepSeek v4 MoE:


Back Donate