FlashAttention Solves the Transformer Memory Bottleneck


Background: The Challenges of Standard Attention in Transformers

Transformers, the backbone of models like GPT and BERT, rely heavily on the self-attention mechanism to process sequences. In self-attention, given input sequences represented as query matrix \(Q\), key matrix \(K\), and value matrix \(V\) (each of shape \(N \times d\), where \(N\) is the sequence length and \(d\) is the head dimension, typically 64-128), the attention output \(O\) is computed as:

\[ S = QK^T \in \mathbb{R}^{N \times N}, \quad P = \softmax(S) \in \mathbb{R}^{N \times N}, \quad O = PV \in \mathbb{R}^{N \times d}, \]

where \(\softmax\) is applied row-wise, and \(S\) is often scaled by \(\tau = 1 / \sqrt{d}\) for stability. Additional operations like causal masking (for autoregressive models) and dropout are common.

This formulation is elegant but computationally expensive. The intermediate matrices \(S\) and \(P\) are \(N \times N\), leading to quadratic time and memory complexity \(O(N^2)\) in sequence length \(N\). For long contexts (e.g., \(N\) in the low thousands for GPT-2-era models, up to 128k in modern LLMs), this becomes a severe bottleneck: the \(N \times N\) intermediates must be written to and re-read from slow GPU memory, and activation memory grows quadratically.
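To see the scale of the problem, here is a quick back-of-the-envelope script (illustrative; fp16 storage and the listed sequence lengths are assumptions, not values from the paper) for the size of a single materialized \(N \times N\) attention matrix:

# Memory for one materialized N x N attention matrix in fp16 (2 bytes/element).
def attn_matrix_gib(n_seq: int, bytes_per_elem: int = 2) -> float:
    return n_seq * n_seq * bytes_per_elem / 2**30

for n in (1024, 4096, 131072):
    print(f"N={n:>6}: {attn_matrix_gib(n):8.3f} GiB per head, per layer, per pass")
# N=  1024:   0.002 GiB   N=  4096:   0.031 GiB   N=131072:  32.000 GiB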

FlashAttention (introduced in 2022 by Tri Dao et al.) addresses these by rethinking the algorithm to be I/O-aware, leveraging GPU memory hierarchy (fast SRAM ~20 MB vs. slow HBM) without approximations.

Core Ideas: Tiling, Kernel Fusion, and Online Softmax

FlashAttention computes exact attention (no approximations) by:

  1. Tiling: Instead of materializing the full \(N \times N\) matrices, it divides \(Q, K, V\) into small blocks that fit in SRAM. \(Q\) is split into \(T_r = \lceil N / B_r \rceil\) row-blocks of size \(B_r \times d\) (e.g., \(B_r \approx 64-256\)), and \(K, V\) into \(T_c = \lceil N / B_c \rceil\) column-blocks of size \(B_c \times d\) (e.g., \(B_c \approx 128-1024\)). Block sizes are chosen dynamically based on SRAM capacity \(M\) (e.g., \(B_c \approx M / (4d)\)) to maximize reuse.

  2. Kernel Fusion: All operations (the matmul for \(S\), masking, softmax, dropout, and the matmul for \(O\)) are fused into a single CUDA kernel, so the \(N \times N\) intermediates never touch HBM. The kernel loads blocks from HBM to SRAM, computes on-chip, and writes back only the output block and the \(O(N)\) statistics: one HBM read/write per block instead of per element.

  3. Online Softmax with Statistics: Softmax can’t be computed partially without the full row, so FlashAttention uses an associative decomposition for incremental computation. For a row split into blocks \(x = [x^{(1)}; x^{(2)}]\), track running statistics:

    • Row-max \(m_i = \max_j S_{ij}\),
    • Row-sum of exponentials \(\ell_i = \sum_j \exp(S_{ij} - m_i)\).

    Updating for a new block \(x^{(t)}\) with local statistics \(\tilde{m}_t, \tilde{\ell}_t\): \[ m_i^{\new} = \max(m_i, \tilde{m}_t), \quad \ell_i^{\new} = e^{m_i - m_i^{\new}} \ell_i + e^{\tilde{m}_t - m_i^{\new}} \tilde{\ell}_t. \] The block's unnormalized probabilities are \(\tilde{P}_{ij} = \exp(S_{ij} - \tilde{m}_t)\), and the output accumulates as \[ O_i \leftarrow \frac{\ell_i}{\ell_i^{\new}} e^{m_i - m_i^{\new}} O_i + \frac{1}{\ell_i^{\new}} e^{\tilde{m}_t - m_i^{\new}} \sum_j \tilde{P}_{ij} V_j. \]

    This is numerically stable (equivalent to the standard max-subtracted softmax) and exact, as shown by induction over blocks: after the last block, \(O = \softmax(S) V\). See the sketch below.
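A minimal PyTorch sketch of this merge for one row of scores split into two column-blocks (tensor names and sizes are illustrative, not from the paper); the tiled result matches the full-row softmax exactly:

import torch

torch.manual_seed(0)
s = torch.randn(1024)                 # one row of scores S_i
s1, s2 = s.split(512)                 # two column-blocks

m1, l1 = s1.max(), torch.exp(s1 - s1.max()).sum()   # block-1 statistics
m2, l2 = s2.max(), torch.exp(s2 - s2.max()).sum()   # block-2 statistics

m_new = torch.maximum(m1, m2)                       # running max over both blocks
l_new = torch.exp(m1 - m_new) * l1 + torch.exp(m2 - m_new) * l2

p_full = torch.softmax(s, dim=0)
p_tiled = torch.cat([torch.exp(s1 - m_new), torch.exp(s2 - m_new)]) / l_new
print(torch.allclose(p_full, p_tiled, atol=1e-6))   # True

The same rescaling factors, applied to the partial output \(O_i\), let the running \(PV\) product be accumulated block by block without ever normalizing prematurely.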

These ideas reduce extra memory to \(O(N)\) (beyond the inputs and output, only the statistics \(m, \ell\)) and HBM accesses to \(\Theta(N^2 d^2 / M)\), compared with \(\Theta(N d + N^2)\) for standard attention: each \(K/V\) element is read from HBM once, while \(Q\) and \(O\) are revisited \(T_c = \Theta(N d / M)\) times.
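Plugging in illustrative numbers (\(N = 4096\), \(d = 64\), and an assumed usable SRAM budget of \(M = 10^5\) elements, roughly 100 KB per streaming multiprocessor) makes the gap concrete:

# Illustrative HBM-access comparison; M is an assumed SRAM budget in elements.
N, d, M = 4096, 64, 100_000
standard = N * d + N * N            # Theta(Nd + N^2): S and P round-trip through HBM
flash = N * N * d * d // M          # Theta(N^2 d^2 / M): only blocked Q/K/V/O traffic
print(standard, flash, standard / flash)   # ~17.0M vs ~0.69M accesses, roughly 25x fewer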

Forward Pass: Block-by-Block Computation

The forward pass (pseudocode in the paper's Algorithm 2) iterates over column-blocks of \(K, V\), updating the running statistics \(m, \ell\) and the partial output \(O\) for each row of \(Q\); a plain-tensor sketch follows below.
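Below is a minimal pure-PyTorch sketch of that loop for a single head (illustrative only: the block size and function name are made up, \(Q\) is not split into row-blocks for brevity, masking and dropout are omitted, and a real implementation is a fused CUDA kernel, not Python tensor ops):

import torch

def flash_attn_forward_sketch(q, k, v, block_c=256):
    # Exact attention computed block-by-block over K/V columns; never forms N x N.
    n, d = q.shape
    scale = d ** -0.5
    o = torch.zeros_like(q)                      # running partial output
    m = torch.full((n, 1), float('-inf'))        # running row-max
    l = torch.zeros(n, 1)                        # running row-sum of exponentials
    for start in range(0, n, block_c):
        kj = k[start:start + block_c]            # load one K column-block ("into SRAM")
        vj = v[start:start + block_c]
        s = (q @ kj.T) * scale                   # score block
        m_blk = s.max(dim=-1, keepdim=True).values
        p = torch.exp(s - m_blk)                 # block-local unnormalized probabilities
        l_blk = p.sum(dim=-1, keepdim=True)
        m_new = torch.maximum(m, m_blk)
        l_new = torch.exp(m - m_new) * l + torch.exp(m_blk - m_new) * l_blk
        # Rescale the previous partial output, add this block's contribution, renormalize.
        o = (l * torch.exp(m - m_new) * o + torch.exp(m_blk - m_new) * (p @ vj)) / l_new
        m, l = m_new, l_new
    return o

q, k, v = (torch.randn(1024, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
print(torch.allclose(flash_attn_forward_sketch(q, k, v), ref, atol=1e-4))   # True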

In the fused kernel this entire loop runs on-chip: total FLOPs remain \(O(N^2 d)\), but I/O drops dramatically (roughly 9x fewer HBM accesses than standard attention at GPT-2 scale). Causal masking is applied cheaply to each score block before the softmax update, and dropout uses a shared RNG state \(R\) that is saved for the backward pass.

Backward Pass: Gradient Computation via Recomputation

The backward pass (Algorithm 4) is trickier, as gradients depend on \(P\):

\[ dP = dO \cdot V^T, \quad dS = P \odot (dP - \rowsum(dO \odot O)), \quad dQ = dS \cdot K, \quad dK = dS^T \cdot Q, \quad dV = P^T \cdot dO. \]

Storing \(P\) would cost \(O(N^2)\) memory, so FlashAttention instead saves only the output \(O\), the per-row softmax statistics \(m, \ell\) (equivalently the logsumexp), and the dropout RNG state, and recomputes each block of \(S\) and \(P\) on-the-fly in SRAM during the backward pass (selective recomputation, like gradient checkpointing but tiled).
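A minimal pure-PyTorch sketch of one backward pass under the same simplifying assumptions as the forward sketch (single head, no masking or dropout, \(Q\) not split into row-blocks, illustrative names and block size); the saved statistics enter through the per-row logsumexp \(L_i = m_i + \log \ell_i\):

import torch

def flash_attn_backward_sketch(q, k, v, o, do, lse, block_c=256):
    # Recompute each S/P block from Q, K and the saved logsumexp; never store N x N.
    n, d = q.shape
    scale = d ** -0.5
    dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
    delta = (do * o).sum(dim=-1, keepdim=True)       # rowsum(dO ⊙ O), one scalar per row
    for start in range(0, n, block_c):
        kj, vj = k[start:start + block_c], v[start:start + block_c]
        s = (q @ kj.T) * scale                       # recomputed score block
        p = torch.exp(s - lse)                       # recomputed softmax block (exact)
        dv[start:start + block_c] = p.T @ do         # dV = P^T dO
        dp = do @ vj.T                               # dP = dO V^T
        ds = p * (dp - delta)                        # dS = P ⊙ (dP - rowsum(dO ⊙ O))
        dq += (ds @ kj) * scale                      # dQ = dS K  (scale from S = QK^T / sqrt(d))
        dk[start:start + block_c] = (ds.T @ q) * scale   # dK = dS^T Q
    return dq, dk, dv

# Check against autograd on a small problem.
q, k, v = (torch.randn(512, 64, requires_grad=True) for _ in range(3))
s_full = (q @ k.T) * 64 ** -0.5
o = torch.softmax(s_full, dim=-1) @ v
do = torch.randn_like(o)
o.backward(do)
lse = torch.logsumexp(s_full, dim=-1, keepdim=True).detach()
dq, dk, dv = flash_attn_backward_sketch(q.detach(), k.detach(), v.detach(), o.detach(), do, lse)
print(torch.allclose(dq, q.grad, atol=1e-4), torch.allclose(dk, k.grad, atol=1e-4),
      torch.allclose(dv, v.grad, atol=1e-4))         # True True True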

This recomputation costs another \(O(N^2 d)\) FLOPs (essentially one extra \(QK^T\) per backward pass) but only \(O(N)\) extra memory, since \(P\) is never stored; despite the extra arithmetic, forward plus backward typically runs 2-4x faster wall-clock than standard attention thanks to the I/O savings.

I/O-Awareness and GPU Optimizations

GPUs have a memory hierarchy: registers and SRAM are fast but small, while HBM is large but slow. Standard attention thrashes HBM with \(\Theta(N^2)\) accesses per pass. FlashAttention's tiling ensures that each \(K, V\) block is loaded from HBM once, every intermediate \(S, P\) block lives only in SRAM, and only the output \(O\) and the \(O(N)\) statistics are written back.

Empirically, HBM stalls dominate attention runtime on an A100; FlashAttention cuts them by roughly 50-80%, pushing the kernel toward the compute-bound regime. A block-sparse variant skips fully-masked blocks for further gains (2-4x over dense FlashAttention).

Benefits: Speed, Memory, and Downstream Impact

In practice, FlashAttention delivers roughly 2-4x wall-clock speedups over standard attention, cuts activation memory from quadratic to linear in sequence length, and makes much longer training contexts affordable, which is what enables the long-context models discussed below.

PyTorch Integration: Seamless Usage

PyTorch 2.0+ integrates FlashAttention natively via torch.nn.functional.scaled_dot_product_attention(q, k, v), which automatically dispatches to the fused kernel when its constraints are met (CUDA tensors in fp16/bf16, supported head dimensions, no arbitrary attention mask). No custom CUDA is needed:

import torch
from torch.nn.functional import scaled_dot_product_attention

q = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)
# Restrict dispatch to the FlashAttention kernel (PyTorch 2.0-2.2 API).
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=0.1)

The fused kernel handles causal masking and dropout internally; by default PyTorch falls back to the math or memory-efficient kernels when FlashAttention's constraints aren't met (with those backends disabled, as above, the call errors out instead). For custom models it is a drop-in replacement for hand-written attention loops, typically yielding 2-3x speedups out of the box.
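On PyTorch 2.3 and later, the same control is exposed through torch.nn.attention.sdpa_kernel, which supersedes the older torch.backends.cuda.sdp_kernel context manager; a minimal sketch, assuming a CUDA device and fp16 tensors:

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.functional import scaled_dot_product_attention

q, k, v = (torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16) for _ in range(3))

# Allow only the FlashAttention backend; the call raises if it cannot be used.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = scaled_dot_product_attention(q, k, v, is_causal=True)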

FlashAttention revolutionized long-context training, powering models like Llama 2 (4k-token context, up to 70B parameters) and beyond.

References

Tri Dao et al., 2022. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.
PyTorch 2.2: FlashAttention-v2 Integration.
GitHub: Dao-AILab/flash-attention.

