Optimizing Sparse Attention with Token Dropping | Generated by AI

2026.06

Question: Explain deeply the techniques described in this document — Output-Aware Token Dropping (OAM-based), Dynamic Load Balancing for Attention Kernels on Blackwell, and INT8 KV Cache Quantization schemes.

Answer:


1. The Core Problem: Sparse Attention and What “Token Dropping” Means

In full attention, every query attends to every key, which is O(n²) in sequence length. For a 128K context, that’s 128K × 128K ≈ 16 billion query-key pairs per head per layer, which is impractical at inference time.

Sparse attention fixes this by only computing attention for a subset of tokens. The question is: which tokens do you keep?


2. Traditional Token Selection: Q×K Score (The Problem)

Classic sparse attention ranks tokens by their attention logit:

score(q, k_i) = q · k_i / sqrt(d_k)

Then you take the top-K tokens by this score. Simple, but this has a fundamental flaw:

The Q×K score is a routing signal: it measures geometric alignment between query and key, and after softmax it becomes the weight given to that token. It does not measure how much that token’s value contributes to the output.

Concretely: the attention output is

out = sum_i [ softmax(q · k_i / sqrt(d)) * v_i ]

A token with a high q · k_i score might have v_i ≈ 0 (tiny value vector). You kept it for nothing. Conversely, a token with moderate q · k_i might have a massive |v_i| — dropping it destroys output quality.
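To make the mismatch concrete, here is a small sketch with made-up numbers. The three logits and value vectors are hypothetical, chosen so that the highest-scoring token carries a near-zero value vector. It ranks tokens once by logit and once by the size of their term in the output sum.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def norm(vec):
    """Euclidean norm of a vector given as a list of floats."""
    return math.sqrt(sum(x * x for x in vec))

# Hypothetical scaled logits q . k_i / sqrt(d) for three tokens.
logits = [2.0, 1.0, 0.0]
# Hypothetical value vectors: token 0 is nearly zero, token 1 is large.
values = [[0.01, 0.0], [8.0, 6.0], [1.0, 0.0]]

weights = softmax(logits)
# Norm of each token's term in the output sum: p_i * |v_i|.
contributions = [p * norm(v) for p, v in zip(weights, values)]

rank_by_logit = sorted(range(3), key=lambda i: logits[i], reverse=True)
rank_by_contribution = sorted(range(3), key=lambda i: contributions[i], reverse=True)

print("ranked by logit:       ", rank_by_logit)
print("ranked by contribution:", rank_by_contribution)
```

Token 0 wins on logit but comes last once the value norm is taken into account, which is the failure mode a pure Q×K top-K selection cannot see.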


3. OAM: Output-Aware Metric — The Fix

The key insight: score each token by its actual contribution to the output, not just its key-query alignment.

The contribution of token i to the output is approximately:

contribution_i ≈ softmax(q · k_i / sqrt(d)) * |v_i|

Or in log-space (more numerically stable):

log_contribution_i = q · k_i / sqrt(d) + log|v_i| - logsumexp_j(q · k_j / sqrt(d))

The logsumexp term is the same for every token i, so it can be dropped when ranking tokens, leaving q · k_i / sqrt(d) + log|v_i|.

This is the BetaMax-based log evaluation referenced in the doc. “BetaMax” likely refers to a max-normalized log-softmax approximation. Instead of computing the full softmax (which requires all tokens), you use a local approximation:

# Sketch of the OAM score (beta and eps defaults are illustrative)
import math
import torch

def oam_score(q, k_i, v_i, beta=1.0, eps=1e-6):
    d_k = q.shape[-1]
    routing_score = (q @ k_i) / math.sqrt(d_k)     # standard attention logit
    value_magnitude = torch.log(v_i.norm() + eps)  # log |v_i|
    return routing_score + beta * value_magnitude  # beta is a learned or fixed weight

Then you drop tokens with the lowest OAM score instead of lowest q·k score.
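The selection step can be sketched in a few lines. This is a minimal illustration, not the implementation from the doc: the helper names `oam_scores` and `keep_top_k`, the `beta` and `eps` defaults, and the toy vectors are all assumptions. Setting `beta=0.0` recovers plain Q×K ranking, so the two policies can be compared on the same inputs.

```python
import math

def oam_scores(q, keys, values, beta=1.0, eps=1e-6):
    """Log-space score per token: q.k_i / sqrt(d) + beta * log(|v_i| + eps)."""
    d = len(q)
    scores = []
    for k, v in zip(keys, values):
        logit = sum(a * b for a, b in zip(q, k)) / math.sqrt(d)
        v_norm = math.sqrt(sum(x * x for x in v))
        scores.append(logit + beta * math.log(v_norm + eps))
    return scores

def keep_top_k(scores, k):
    """Indices of the k highest-scoring tokens, returned in original order."""
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return sorted(top)

# Hypothetical toy inputs: 4 tokens, head dimension 4.
q = [2.0, 0.0, 0.0, 0.0]
keys = [[2.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0],
        [0.5, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
values = [[0.01, 0.0], [8.0, 6.0], [3.0, 4.0], [1.0, 0.0]]

qk_only = keep_top_k(oam_scores(q, keys, values, beta=0.0), k=2)
with_oam = keep_top_k(oam_scores(q, keys, values, beta=1.0), k=2)

print("kept by q.k only:", qk_only)
print("kept by OAM:     ", with_oam)
```

Whether the softmax is renormalized over the surviving tokens afterwards is a separate design choice that the doc does not specify. It matters in practice, because dropping a token with a large softmax weight shifts that weight onto the tokens that remain.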

Why this is principled: softmax weight times |v_i| is the norm of token i’s term in the output sum, so it bounds how much that token can contribute, and its ranking is available from the logits and value norms before the softmax is normalized. This is loosely related in spirit to other content-aware selection schemes such as Reformer (LSH attention) and Mamba’s selective state spaces: the selector should know what matters downstream, not just what looks similar.


4. The “Slim” Implementation Result

The 3.6x throughput on 128K inputs makes sense:


5. Dynamic Load Balancing for Attention Kernels (Blackwell / H20)

The Problem: Load Imbalance from Variable Sequence Lengths

In batched inference, sequences have different lengths. When each sequence’s KV cache is handled by its own thread block (CTA), one sequence might be 128K tokens while another is 512. The CTA handling the long sequence does 250x more work, so the other CTAs finish early and their SMs sit idle until it completes.

This is the load imbalance problem in attention kernels, a recurring theme across the FlashAttention-2 → FlashAttention-3 evolution.

Split-KV Approach

The fix: Split the KV dimension across multiple CTAs, then reduce:

Sequence of length L → split into N chunks of L/N
Each chunk handled by one CTA
Final output = online softmax reduction across chunks (log-sum-exp trick)

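This is the same idea as Flash-Decoding in the FlashAttention family, and PagedAttention-style kernels use a similar partition-then-reduce path for long sequences. Persistent kernels with dynamic tile scheduling attack the same imbalance from the scheduling side.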
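A rough way to see why splitting helps is to compare the busiest CTA’s load with and without it. The sketch below is a toy scheduling model, not a GPU simulation: the batch shape, the CTA count, the 512-token chunk size, and the greedy longest-first assignment are all assumptions, and it ignores the cost of the final reduction.

```python
import heapq

def makespan(work_items, num_ctas):
    """Greedy longest-first assignment of work to CTAs; returns the busiest CTA's load."""
    loads = [0] * num_ctas
    heapq.heapify(loads)
    for w in sorted(work_items, reverse=True):
        lightest = heapq.heappop(loads)      # CTA with the least work so far
        heapq.heappush(loads, lightest + w)
    return max(loads)

def split_kv(seq_lens, chunk):
    """Break each sequence's KV length into chunks of at most `chunk` tokens."""
    items = []
    for n in seq_lens:
        full, rest = divmod(n, chunk)
        items.extend([chunk] * full)
        if rest:
            items.append(rest)
    return items

# Hypothetical decode batch: one long sequence and seven short ones, on 8 CTAs.
seq_lens = [128_000] + [512] * 7
unsplit = makespan(seq_lens, num_ctas=8)
split = makespan(split_kv(seq_lens, chunk=512), num_ctas=8)

print("busiest CTA, one sequence per CTA:", unsplit, "tokens")
print("busiest CTA, split-KV chunks:     ", split, "tokens")
```

In this toy setup the critical path drops from 128,000 tokens to 16,896, because the long sequence’s 250 chunks are spread over all eight CTAs instead of pinning one.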

The reduction still uses the numerically stable online softmax:

# Online softmax across splits (out_prev and out_chunk are already normalized by their own exp-sums)
m_new = max(m_prev, m_chunk)
exp_sum = exp(m_prev - m_new) * exp_sum_prev + exp(m_chunk - m_new) * exp_sum_chunk
out = (exp(m_prev - m_new) * out_prev * exp_sum_prev + exp(m_chunk - m_new) * out_chunk * exp_sum_chunk) / exp_sum
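The merge rule is easy to check numerically: attention computed chunk by chunk and then reduced should match attention over the whole sequence. The following is a minimal pure-Python sketch under assumed names (`attend_chunk`, `merge`, `split_kv_attention`) and made-up logits and values. It treats each partial result as a triple of (max logit, exp-sum, normalized output), and both partial outputs are rescaled to the shared maximum before they are combined.

```python
import math

def attend_chunk(logits, values):
    """Attention over one KV chunk: returns (max logit, exp-sum, normalized output)."""
    m = max(logits)
    weights = [math.exp(s - m) for s in logits]
    exp_sum = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[j] for w, v in zip(weights, values)) / exp_sum for j in range(dim)]
    return m, exp_sum, out

def merge(a, b):
    """Combine two partial results, rescaling both to the shared maximum."""
    m_a, l_a, o_a = a
    m_b, l_b, o_b = b
    m_new = max(m_a, m_b)
    w_a = math.exp(m_a - m_new) * l_a
    w_b = math.exp(m_b - m_new) * l_b
    exp_sum = w_a + w_b
    out = [(w_a * x + w_b * y) / exp_sum for x, y in zip(o_a, o_b)]
    return m_new, exp_sum, out

def split_kv_attention(logits, values, num_splits):
    """Split the KV axis into contiguous chunks, attend to each, then reduce."""
    n = len(logits)
    chunk = math.ceil(n / num_splits)
    partials = [attend_chunk(logits[i:i + chunk], values[i:i + chunk])
                for i in range(0, n, chunk)]
    result = partials[0]
    for p in partials[1:]:
        result = merge(result, p)
    return result[2]

# Hypothetical logits and 2-d value vectors for a 7-token sequence.
logits = [0.5, 3.0, -1.0, 2.0, 0.0, 4.0, 1.5]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5],
          [0.3, 0.3], [1.0, -1.0], [0.0, 2.0]]

full = attend_chunk(logits, values)[2]
split = split_kv_attention(logits, values, num_splits=3)
print("full :", full)
print("split:", split)
```

On a GPU the per-chunk work runs in parallel across CTAs and only the merge is serialized, but the arithmetic is the same.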

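Result: 1.5x kernel speedup, 1% end-to-end. The small end-to-end gain means attention wasn’t the dominant bottleneck for their workload: by Amdahl’s law, a 1.5x kernel speedup that yields about 1% overall implies the affected kernel accounted for roughly 3% of end-to-end time. (The workload was likely compute-bound at shorter contexts and memory-bound at longer ones.)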
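The gap between the kernel number and the end-to-end number can be turned into an estimate of how much runtime the kernel occupied. This sketch solves Amdahl’s law for that fraction. Reading “1% end-to-end” as a 1.01x overall speedup is an assumption, and `implied_fraction` is a hypothetical helper name.

```python
def implied_fraction(kernel_speedup, end_to_end_speedup):
    """Solve Amdahl's law for the runtime fraction f spent in the sped-up kernel.

    end_to_end_speedup = 1 / ((1 - f) + f / kernel_speedup)
    """
    return (1 - 1 / end_to_end_speedup) / (1 - 1 / kernel_speedup)

# Assumption: "1% end-to-end" means the whole request runs 1.01x faster.
f = implied_fraction(kernel_speedup=1.5, end_to_end_speedup=1.01)
print(f"attention kernel share of runtime: {f:.1%}")
```

Under that reading, the attention kernel was only about 3% of total time, so even an infinitely fast kernel would cap the end-to-end gain at roughly 3%.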


6. INT8 KV Cache Quantization: Two Schemes

Why INT8 KV Cache?

KV cache is the memory bottleneck at long context. For a 70B model with 128K context:

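KV cache size = 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_element
             = 2 * 80 * 64 * 128 * 128000 * 2 (BF16)
             ≈ 336 GB  ← doesn't fit on one B200 (192GB HBM3e)

INT8 halves this to ~168 GB. Worth it if accuracy holds. Note that 64 KV heads assumes full multi-head attention; with grouped-query attention at 8 KV heads, as in Llama 70B-class models, the BF16 figure falls to about 42 GB.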
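The sizing formula is simple enough to evaluate directly, which makes it easy to try other configurations. In the sketch below, `kv_cache_bytes` is a hypothetical helper, sizes are per sequence in decimal gigabytes, and the grouped-query variant with 8 KV heads is included as an assumed comparison point.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_element):
    """Bytes for one sequence's KV cache: K and V for every layer, head, and token."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_element

GB = 1e9  # decimal gigabytes

mha_bf16 = kv_cache_bytes(80, 64, 128, 128_000, 2) / GB  # 64 KV heads, BF16
mha_int8 = kv_cache_bytes(80, 64, 128, 128_000, 1) / GB  # same shape, INT8
gqa_bf16 = kv_cache_bytes(80, 8, 128, 128_000, 2) / GB   # assumed GQA with 8 KV heads

print(f"MHA BF16: {mha_bf16:.0f} GB")
print(f"MHA INT8: {mha_int8:.0f} GB")
print(f"GQA BF16: {gqa_bf16:.0f} GB")
```

With these dimensions the multi-head BF16 cache is about 336 GB, INT8 brings it to about 168 GB, and grouped-query attention with 8 KV heads would need about 42 GB in BF16.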

Per-Tensor Static Quantization (Baseline — The Problem)

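q_val = clamp(round(val / scale), -127, 127)   # scale is a single float per tensor
val_reconstructed = q_val * scale

One scale for the entire Q, K, or V tensor. Fast (one multiply to dequantize), but catastrophic for long context: a single outlier anywhere in the sequence typically inflates the scale, so the many small values collapse onto a handful of integer levels, and the longer the context, the more likely such an outlier is.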

Scheme A: Fine-Grained QK (Best Accuracy, Some Cost)

Per-token-partition quantization for Q and K:
  Each "partition" (e.g., 64 or 128 tokens) gets its own scale
  scale[i] = max(|k[i*chunk:(i+1)*chunk]|) / 127

V uses per-tensor (one global scale)

Why Q and K are most sensitive: the softmax(QK^T) is exponential — small errors in Q or K get amplified. V enters linearly, so coarser quantization is tolerable.
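The benefit of per-partition scales shows up as soon as one partition contains an outlier. The sketch below compares round-trip error under a single global scale and under one scale per partition. The helper names, the partition size of 4, and the key values (including the 50.0 outlier) are all made up for illustration.

```python
def quantize(values, scale):
    """Symmetric INT8: round to nearest, clamp to [-127, 127]."""
    return [max(-127, min(127, round(v / scale))) for v in values]

def dequantize(q_values, scale):
    return [q * scale for q in q_values]

def roundtrip_per_tensor(values):
    """Quantize and dequantize with one scale for the whole tensor."""
    scale = max(abs(v) for v in values) / 127 or 1.0  # fall back to 1.0 for all-zero input
    return dequantize(quantize(values, scale), scale)

def roundtrip_per_partition(values, chunk):
    """Quantize and dequantize with one scale per `chunk` consecutive values."""
    out = []
    for start in range(0, len(values), chunk):
        part = values[start:start + chunk]
        scale = max(abs(v) for v in part) / 127 or 1.0
        out.extend(dequantize(quantize(part, scale), scale))
    return out

def max_abs_error(a, b):
    return max(abs(x - y) for x, y in zip(a, b))

# Hypothetical key activations: small values, then one outlier in the last partition.
keys = [0.01 * i for i in range(1, 9)] + [0.5, -0.3, 0.2, 50.0]

small = slice(0, 8)  # the two partitions that do not contain the outlier
err_tensor = max_abs_error(keys[small], roundtrip_per_tensor(keys)[small])
err_partition = max_abs_error(keys[small], roundtrip_per_partition(keys, chunk=4)[small])

print(f"max error on small values, per-tensor scale:    {err_tensor:.5f}")
print(f"max error on small values, per-partition scale: {err_partition:.5f}")
```

With the global scale, every small value rounds to zero because the outlier sets the step size. With per-partition scales, only the outlier’s own partition pays for it.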

Scheme B: Coarser KV (Near-Zero Perf Loss)

Q: per-token-partition (fine)
K: per-tensor (coarse)
V: per-tensor (coarse)

Trades some accuracy for performance. Both schemes are “essentially lossless” because symmetric INT8 with properly calibrated scales has a worst-case rounding error of 0.5/127 ≈ 0.4% of the largest magnitude sharing a scale, which is within noise for most benchmarks.

The Accuracy/Throughput Tradeoff Visualized

Scheme          | Q scale       | K scale       | V scale    | Accuracy  | Perf
----------------|---------------|---------------|------------|-----------|--------------
Baseline INT8   | per-tensor    | per-tensor    | per-tensor | degraded  | fastest
Scheme B        | per-partition | per-tensor    | per-tensor | good      | ~baseline
Scheme A        | per-partition | per-partition | per-tensor | best      | ~5-15% slower
BF16 (full)     | n/a           | n/a           | n/a        | reference | 2x KV memory

7. How These Three Fit Together

Inference request (128K context)
         │
         ▼
OAM Token Dropping ──► reduce effective KV to attend over (3.6x throughput on 128K inputs)
         │
         ▼
INT8 KV Cache ──────► halve memory bandwidth for KV reads
         │
         ▼
Split-KV Load Balance ► evenly distribute remaining KV work across SMs
         │
         ▼
Output

They’re orthogonal optimizations that stack: OAM reduces which tokens you process, INT8 reduces how expensive each token read is, and Split-KV reduces idle time from imbalance.


Connections to What You’re Already Working On

Given you’re training on MI300X and studying DeepSeek v4 MoE:

