FlashAttention Algorithm and Benefits


Deep explanation of FlashAttention (Tri Dao et al., 2022)

Short summary up front: FlashAttention removes the need to allocate the full \(N\times N\) attention matrix in GPU DRAM (HBM) by computing attention block-by-block inside on-chip SRAM (registers / shared memory), fusing the whole attention pipeline into a single kernel, and using numerically exact blockwise softmax accumulation. That drastically cuts HBM traffic, reduces extra memory from \(O(N^2)\) to \(O(N)\), and in practice gives large wall-clock speedups on GPUs for long sequences.


The problem: why standard attention is IO-bound

Transformer self-attention (scaled dot-product) is usually implemented with three steps:

  1. compute scaled scores \(S = Q K^\top / \sqrt{d}\) (size \(N\times N\));
  2. compute rowwise softmax \(P = \mathrm{softmax}(S)\);
  3. compute output \(O = P V\).

Naively you materialize \(S\) (and often \(P\)) in GPU DRAM. For sequence length \(N\) this uses \(O(N^2)\) memory and leads to two IO problems:

  1. the \(N\times N\) scores and probabilities are written to HBM and read back between the matmul and softmax kernels, and for long sequences this HBM traffic, not the matmul FLOPs, dominates runtime;
  2. peak activation memory grows quadratically with \(N\), which caps the sequence length (and batch size) that fits on the device.
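For concreteness, here is a minimal NumPy version of this standard implementation (illustrative only); the \(N\times N\) arrays S and P are exactly the intermediates FlashAttention never forms in DRAM:

import numpy as np

def naive_attention(Q, K, V):
    # Standard scaled dot-product attention: materializes the full N x N
    # score and probability matrices, so memory grows as O(N^2).
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                        # N x N scores
    P = np.exp(S - S.max(axis=1, keepdims=True))    # row-wise stable softmax
    P /= P.sum(axis=1, keepdims=True)
    return P @ V                                    # N x d_v output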

FlashAttention reframes attention as an IO problem, not just a FLOP problem, and targets reducing HBM accesses.


Core ideas (high level)

  1. Tile the matrices \(Q, K, V\) into blocks that fit in on-chip SRAM (shared memory / registers).
  2. Process attention block-by-block: for a given \(Q\)-tile and a streaming set of \(K,V\)-tiles, compute the partial contributions to the output and immediately accumulate them — never materialize the full \(N\times N\) score matrix in DRAM.
  3. Fuse everything into one kernel: the kernel loads tiles into SRAM, computes \(QK^\top\) for that tile pair, applies softmax logic and multiplies by the \(V\)-tile, and writes partial outputs — all without round-trips of intermediate large matrices to DRAM. Kernel fusion reduces instruction and memory overhead.
  4. Blockwise numerically stable softmax accumulation: because softmax across the whole row needs the global max and sum, FlashAttention uses a running max / running sum (log-sum-exp style) to combine softmax contributions from multiple \(K\)-tiles exactly and stably without storing the whole row of scores.
  5. Backward via recomputation: instead of storing large intermediates for backward, recompute the forward attention for each block during the backward pass (trading extra FLOPs for much less DRAM IO); only the \(O(N)\) per-row softmax statistics from the forward pass are saved to make this possible. The saved DRAM IO usually yields a net speedup, since DRAM IO dominates.

These ideas together yield both memory reduction and wall-clock speed improvements.
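To see that idea 4 is exact rather than approximate, here is a tiny NumPy check (illustrative): splitting one row of scores into two chunks and merging their statistics with the running max / running sum reproduces the full-row softmax max and denominator.

import numpy as np

rng = np.random.default_rng(0)
s = rng.standard_normal(8)            # one row of attention scores
chunks = (s[:4], s[4:])               # two "K-tiles'" worth of that row

# Reference: numerically stable softmax statistics over the whole row.
full_max = s.max()
full_sum = np.exp(s - full_max).sum()

# Blockwise: stream the chunks, keeping only a running max and running sum.
row_max, row_sum = -np.inf, 0.0
for chunk in chunks:
    m = chunk.max()
    local_sum = np.exp(chunk - m).sum()
    M = max(row_max, m)
    row_sum = row_sum * np.exp(row_max - M) + local_sum * np.exp(m - M)
    row_max = M

assert np.isclose(row_max, full_max) and np.isclose(row_sum, full_sum)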


Blockwise algorithm — step by step (forward)

Consider a single attention head with sequence length \(N\) and head dim \(d\). Choose a tile size \(B\) so a \(B\times B\) scores block and the corresponding \(Q\), \(K\), \(V\) tiles fit in SRAM.
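As a rough illustration of how one might size \(B\), here is a back-of-the-envelope calculation; the numbers are assumptions for the sake of arithmetic (roughly 100 KB of usable SRAM per SM, fp16 tiles), not what any particular kernel actually budgets:

import math

# Illustrative only: real kernels also budget registers, softmax statistics,
# and double-buffering, so actual tile sizes differ.
sram_bytes = 100 * 1024     # assumed usable on-chip SRAM per SM (hypothetical)
d = 64                      # head dimension
elem = 2                    # bytes per element (fp16)

# Need Qi, Kj, Vj, out_acc (each B x d) plus one B x B score tile in SRAM:
#   4 * B * d * elem  +  B * B * elem  <=  sram_bytes
a, b, c = elem, 4 * d * elem, -sram_bytes
B = int((-b + math.sqrt(b * b - 4 * a * c)) / (2 * a))
print(B)   # ~131 -> round down to a power of two such as 128 or 64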

For each query tile \(Q_{i}\) (rows \(iB:(i+1)B\)):

  1. Initialize an output accumulator \(O_i \leftarrow 0\).
  2. Initialize running normalization state: row_max (per query row) to \(-\infty\), row_sum to 0. These track the numerically stable softmax denominator across multiple K-tiles.
  3. For each key/value tile \(K_{j}, V_{j}\) (columns \(jB:(j+1)B\)):
    • Load \(Q_i\), \(K_j\), \(V_j\) into SRAM.
    • Compute the tile of raw scores \(S_{ij} = Q_i K_j^\top / \sqrt{d}\) (shape \(B\times B\) in vectorized form).
    • For each row in \(S_{ij}\), compute the local row max \(m_{ij}\) and exponentiated values \(\exp(S_{ij} - m_{ij})\).
    • Merge this tile’s exponentials into the running row statistics using the log-sum-exp trick:
      • Let \(M = \max(\text{row_max}, m_{ij})\).
      • Update row_sum := row_sum · exp(row_max − M) + local_sum · exp(m_{ij} − M).
      • Rescale the previously accumulated output onto the new scale and add this tile’s (still unnormalized) contribution: \(O_i \leftarrow O_i \cdot \exp(\text{row_max} - M) + \exp(m_{ij} - M)\,\exp(S_{ij} - m_{ij})\, V_j\). (All done inside SRAM.)
      • Set row_max := \(M\).
  4. After streaming all K-tiles, divide \(O_i\) by row_sum to finalize the normalization — this yields exactly the full-row softmax output — and write \(O_i\) to DRAM.

Key point: no \(N\times N\) matrix is ever written to DRAM; only small tiles and final outputs are. The numerically-correct accumulation using running max + sum is what lets the per-tile softmax pieces combine exactly into the same result as a full softmax over the row.


Why kernel fusion and SRAM tiling win in practice

On modern GPUs, attention over long sequences is limited by memory bandwidth rather than FLOPs: on-chip SRAM is far faster than HBM, so a kernel that streams tiles through SRAM and never writes the \(N\times N\) intermediates to HBM spends its time computing instead of waiting on memory. Fusion also removes the kernel-launch overhead and intermediate-tensor round-trips of separate matmul / softmax / matmul kernels.

Empirical results from the paper and follow-ups show multi-× speedups (e.g., 2–7× in the reported benchmarks, depending on model and sequence length) and large reductions in peak memory.


Important implementation details & tradeoffs

  • Tile sizes must be chosen so that the \(Q\), \(K\), \(V\) tiles and the \(B\times B\) score tile fit in on-chip SRAM; the right \(B\) depends on SRAM capacity, head dimension, and dtype.
  • The forward pass saves only the \(O(N)\) per-row softmax statistics (running max and sum); the backward pass recomputes each score tile, trading extra FLOPs for far less HBM IO.
  • The result is exact attention, but because additions happen in a different order than in the naive kernel, outputs match only up to floating-point rounding, not bitwise.

FlashAttention vs. approximate long-attention methods

FlashAttention keeps exact attention semantics (same numerical result as full attention up to floating-point rounding), whereas many long-attention methods approximate attention (sparsity, low-rank, FAVOR+, etc.) and trade quality for memory/time. FlashAttention instead reduces memory/IO cost while preserving the exact computation, so model quality is unchanged while throughput/memory improve. That’s why it’s widely attractive: no accuracy tradeoff, just a better low-level kernel.


Practical availability & ecosystem

FlashAttention is available as a standalone fused kernel (the authors’ flash-attn package) and is integrated into the major frameworks: recent PyTorch releases expose fused scaled-dot-product attention that can dispatch to a FlashAttention-style backend, and libraries such as xFormers and Hugging Face Transformers can use it as well. One caveat: “no need for custom kernels” is only partly true — FlashAttention is itself a custom fused kernel that frameworks call. Modern PyTorch versions may ship comparable fused kernels internally or delegate to vendor libraries, but the core idea requires a fused kernel implementation (whether in CUDA, Triton, or vendor code). The practical lesson: as a model user you don’t have to write those kernels yourself — call the provided operator.
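A minimal PyTorch usage sketch (illustrative; assumes a recent PyTorch build and a CUDA GPU — whether the fused FlashAttention-style backend is actually selected depends on GPU generation, dtype, and tensor shapes):

import torch
import torch.nn.functional as F

# Batched multi-head attention inputs: (batch, heads, seq_len, head_dim).
q, k, v = (torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# Dispatches to a fused attention kernel when eligible, without ever
# materializing the 4096 x 4096 score matrix in HBM.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)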


Extensions and follow-ups

Follow-up work keeps the exact-attention, IO-aware recipe while squeezing out more GPU utilization: FlashAttention-2 reworks the parallelization and work partitioning across thread blocks and warps to cut non-matmul overhead, and FlashAttention-3 targets Hopper GPUs with asynchronous data movement and low-precision (FP8) support. The same tiling idea also underlies fused attention kernels in Triton and in vendor libraries.

When FlashAttention helps most (rules of thumb)

  • Long sequences: the \(O(N^2)\) intermediates are exactly what FlashAttention eliminates, so gains grow with sequence length.
  • Memory-bound settings: when attention dominates runtime or peak memory (long-context training, long-prompt prefill), the reduced HBM traffic translates directly into wall-clock and capacity wins.
  • Short sequences, or models where attention is a small fraction of total time, still benefit, but the end-to-end speedup is correspondingly smaller.

Quick reference code (conceptual NumPy)

import numpy as np

def flash_attention_forward(Q, K, V, B=64):
    # Tiled attention forward pass (conceptual NumPy version of the loop above).
    # Never forms the full N x N score matrix; keeps only B x B tiles.
    N, d = Q.shape
    out = np.empty((N, V.shape[1]))
    for i in range(0, N, B):                       # query tile Qi
        Qi = Q[i:i + B]
        row_max = np.full(len(Qi), -np.inf)        # running max, per query row
        row_sum = np.zeros(len(Qi))                # running softmax denominator
        out_acc = np.zeros((len(Qi), V.shape[1]))  # unnormalized output accumulator
        for j in range(0, N, B):                   # key/value tile Kj, Vj
            Kj, Vj = K[j:j + B], V[j:j + B]        # "load Kj, Vj into SRAM"
            S = Qi @ Kj.T / np.sqrt(d)             # B x B score tile
            m = S.max(axis=1)                      # local row max
            expS = np.exp(S - m[:, None])          # local exponentials
            local_sum = expS.sum(axis=1)
            # merge normalizers stably (log-sum-exp trick):
            M = np.maximum(row_max, m)
            alpha = np.exp(row_max - M)            # rescales previous state
            beta = np.exp(m - M)                   # rescales this tile
            row_sum = row_sum * alpha + local_sum * beta
            # rescale the old accumulator, then add this tile's contribution:
            out_acc = out_acc * alpha[:, None] + (expS * beta[:, None]) @ Vj
            row_max = M
        out[i:i + B] = out_acc / row_sum[:, None]  # finalize normalization
    return out                                     # "write out to DRAM"

(A real kernel does this tile streaming inside SRAM in fused CUDA/Triton code, saves the per-row softmax statistics for the backward pass, and handles masking, dropout, and multiple heads/batches; the version above is only conceptual.)
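A quick sanity check of the sketch (illustrative only) is to compare it against a directly materialized softmax attention on random inputs:

import numpy as np

rng = np.random.default_rng(0)
N, d = 256, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))

S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V

assert np.allclose(flash_attention_forward(Q, K, V, B=64), reference)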



