FlashAttention算法及其优势 | AI生成和翻译

Home 2025.11

FlashAttention 深度解析(Tri Dao 等人,2022)

简短总结:FlashAttention 通过分块计算注意力机制,在芯片内 SRAM(寄存器/共享内存)中完成全部操作,将整个注意力流程融合为单一内核,并采用数值正确的分块 softmax 累积技术。这种方法消除了在 GPU DRAM(HBM)中分配完整 \(N\times N\) 注意力矩阵的需求,从而将 HBM 流量和内存占用从 \(O(N^2)\) 显著降低至实际 \(O(N)\),在处理长序列时能为 GPU 带来显著的实时加速效果。citeturn0search0turn0search9


问题根源:标准注意力为何受 IO 限制

Transformer 自注意力机制(缩放点积)通常通过三个步骤实现:

  1. 计算得分矩阵 \(S = Q K^\top\)(尺寸 \(N\times N\));
  2. 计算行向 softmax \(P = \mathrm{softmax}(S)\);
  3. 计算输出 \(O = P V\)。

传统实现会将 \(S\)(及常包括 \(P\))显式存储在 GPU DRAM 中。对于序列长度 \(N\),这会占用 \(O(N^2)\) 内存并引发两个 IO 问题:

FlashAttention 将注意力重新定义为IO 问题而非单纯计算问题,致力于减少 HBM 访问次数。citeturn0search0


核心思想(高层视角)

  1. 矩阵分块:将 \(Q, K, V\) 矩阵划分为能放入芯片内 SRAM(共享内存/寄存器)的块。
  2. 分块处理注意力:对每个 \(Q\) 块与流式传输的 \(K,V\) 块集合,计算输出部分贡献并立即累积——永远不在 DRAM 中显式存储完整的 \(N\times N\) 得分矩阵。
  3. 全流程内核融合:内核将数据块加载至 SRAM,计算块间 \(QK^\top\),应用 softmax 逻辑并与 \(V\) 块相乘,写入部分输出——所有操作均无需将中间大矩阵往返 DRAM。内核融合减少了指令与内存开销。
  4. 分块数值稳定 softmax 累积:由于整行 softmax 需要全局最大值与求和值,FlashAttention 采用运行最大值/运行求和(类 log-sum-exp 方法)来精确稳定地合并来自多个 \(K\) 块的 softmax 贡献,而无需存储整行得分。
  5. 重计算反向传播:为避免存储大型中间结果供反向传播使用,在反向过程中对每个数据块重新计算前向注意力(以额外计算量换取显著减少的 DRAM IO)。由于 DRAM IO 占主导地位,节省的 DRAM IO 通常能带来净加速效果。citeturn0search2turn0search10

这些思想共同实现了内存占用降低与实时速度提升。citeturn0search0


分块算法逐步解析(前向传播)

考虑单注意力头,序列长度 \(N\),头维度 \(d\)。选择分块尺寸 \(B\),使得 \(B\times B\) 得分块及对应 \(Q\)、\(K\)、\(V\) 块能放入 SRAM。

对每个查询块 \(Q_{i}\)(行索引 \(iB:(i+1)B\)):

  1. 初始化输出累加器 \(O_i \leftarrow 0\)。
  2. 初始化运行归一化状态:每查询行的 row_max 设为 \(-\infty\),row_sum 设为 0。这些变量用于跨多个 K 块跟踪数值稳定的 softmax 分母。
  3. 对每个键/值块 \(K_{j}, V_{j}\)(列索引 \(jB:(j+1)B\)):
    • 将 \(Q_i\)、\(K_j\)、\(V_j\) 加载至 SRAM。
    • 计算原始得分块 \(S_{ij} = Q_i K_j^\top / \sqrt{d}\)(以向量化形式实现,形状 \(B\times B\))。
    • 对 \(S_{ij}\) 中的每一行,计算局部行最大值 \(m_{ij}\) 及指数化值 \(\exp(S_{ij} - m_{ij})\)。
    • 使用 log-sum-exp 技巧将该块的指数值合并至运行行归一化状态:
      • 令 \(M = \max(\text{row_max}, m_{ij})\)。
      • 更新 row_sum := row_sum · exp(row_max − M) + local_sum · exp(m_{ij} − M)。
      • 设置 row_max := \(M\)。
    • 使用适当缩放的指数值计算该块对累加器的贡献:累加 \(O_i \mathrel{+}= \text{(块内-softmax)} \times V_j\)(全部在 SRAM 内完成)。
  4. 流式处理完所有 K 块后,使用 row_sum 和 row_max 完成最终归一化,生成正确 softmax 输出;将 \(O_i\) 写入 DRAM。

关键点:没有任何 \(N\times N\) 矩阵被写入 DRAM;仅小型数据块与最终输出需要存储。采用运行最大值与求和的数值正确累积技术,使得每块 softmax 片段能精确组合成与整行 softmax 相同的结果。citeturn0search2turn0search10


内核融合与 SRAM 分块的实际优势

论文及后续研究的实证结果显示,在不同模型和序列长度下可实现数倍加速(如报告基准测试中的 2–7 倍),并显著降低峰值内存占用。citeturn0search0turn0search10


重要实现细节与权衡


FlashAttention 与近似长注意力方法对比

FlashAttention 保持精确的注意力语义(与完整注意力的数值结果相同,仅存在浮点舍入误差),而许多长注意力方法采用近似注意力(稀疏化、低秩、FAVOR+ 等),以质量换取内存/时间效率。FlashAttention 则在保持精确计算的前提下降低内存/IO 成本,因此模型质量不变而吞吐量/内存效率提升。这正是其广受青睐的原因:无需精度妥协,仅通过更优底层内核实现性能突破。citeturn0search0


实际可用性与生态整合

注意事项:“无需自定义内核”仅部分成立——FlashAttention 本身即是需由框架调用的自定义融合内核(代码库中的实现)。现代 PyTorch 版本可能内部搭载了类似融合内核或委托至供应商库,但核心思想仍需融合内核实现(无论是用 CUDA、Triton 还是供应商代码)。重要启示:模型使用者无需亲自编写这些内核——直接使用现成算子即可。citeturn0search9turn0search7


扩展与后续研究


FlashAttention 最适用场景(经验法则)


快速伪代码(概念性)

对每个查询块 Qi:
    row_max = -inf (每行)
    row_sum = 0    (每行)
    out_acc = 0    (B x d_v)
    对每个键/值块 Kj, Vj:
        将 Qi, Kj, Vj 加载至 SRAM
        S = Qi @ Kj.T / sqrt(d)
        m = 行最大值(S)
        expS = exp(S - m)      # 局部指数值
        local_sum = 行求和(expS)
        # 稳定合并归一化因子:
        M = max(row_max, m)
        row_sum = row_sum * exp(row_max - M) + local_sum * exp(m - M)
        row_max = M
        # 使用缩放后 expS 累积部分输出:
        out_acc += (expS / (合并后 row_sum)) @ Vj
    将 out_acc 写入 DRAM

(实际实现会存储更多元数据并精心安排精确缩放;此处为概念示意。)citeturn0search2


参考文献与延伸阅读


如需进一步了解,我可以:


Back

openai/gpt-5

Donate