KV Cache: 2 and Layers Multiplier | Generated by AI

2026.06

Question: Why is there a 2 multiplier in the KV cache formula, and why multiply by num_layers?

Answer:

The 2 and num_layers are two separate multipliers, not a combined factor. Let me break it down:

KV cache size = batch_size × seq_len × hidden_dim × 2 × num_layers × bytes
                                                    ↑        ↑
                                          K + V tensors  one per layer
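To make the formula concrete, here is a minimal sketch of it as a Python function. The name kv_cache_bytes and its parameter names are my own (hypothetical), and it assumes standard multi-head attention where K and V each have width hidden_dim.

```python
def kv_cache_bytes(batch_size, seq_len, hidden_dim, num_layers, bytes_per_elem=2):
    """Return the KV cache size in bytes, assuming standard multi-head attention."""
    elements_per_tensor = batch_size * seq_len * hidden_dim  # one K or one V tensor
    tensors_per_layer = 2                                    # the "2": K and V
    return elements_per_tensor * tensors_per_layer * num_layers * bytes_per_elem


# Cost of caching one token of one sequence (hidden_dim=5120, 40 layers, fp16).
per_token = kv_cache_bytes(batch_size=1, seq_len=1, hidden_dim=5120, num_layers=40)
print(per_token)  # 819200 bytes, roughly 0.8 MB per cached token
```

Setting batch_size and seq_len to 1 is a handy way to get the per-token cost, which you can then scale by however many tokens you expect to hold in memory.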

The 2: Key and Value Tensors

In multi-head attention, you have Q, K, V:

Q = input @ W_q  # Query: computed fresh each step
K = input @ W_k  # Key: reused across all future steps
V = input @ W_v  # Value: reused across all future steps

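During inference, you only cache K and V, not Q. Why? At each decoding step, the new token's query is used once, to attend over the keys and values of every token so far, and is never needed again. The K and V of past tokens, by contrast, are read again at every later step, so storing them avoids recomputing them. One K tensor plus one V tensor gives the factor of 2.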
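A toy illustration of that point, as a sketch only: single-head attention in plain Python, with made-up 2-dimensional weights and inputs (all values and helper names here are hypothetical). It decodes token by token with a growing K/V cache, then checks that the result matches recomputing K and V for the whole prefix at every step.

```python
import math


def matvec(x, W):
    """Multiply row vector x (length d_in) by matrix W (d_in x d_out)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def attend(q, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = math.sqrt(len(q))
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / scale for k in keys]
    top = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w / total * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))]


# Toy projection matrices and token embeddings (hypothetical values).
W_q = [[1.0, 0.0], [0.0, 1.0]]
W_k = [[0.5, 0.1], [0.2, 0.3]]
W_v = [[1.0, 2.0], [3.0, 4.0]]
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Decoding with a cache: K and V are appended once per token and then reused.
k_cache, v_cache, cached_outputs = [], [], []
for x in tokens:
    q = matvec(x, W_q)              # used for this step only, never stored
    k_cache.append(matvec(x, W_k))  # kept for every later step
    v_cache.append(matvec(x, W_v))
    cached_outputs.append(attend(q, k_cache, v_cache))

# Reference: recompute K and V for the whole prefix at every step.
recomputed_outputs = []
for t in range(len(tokens)):
    prefix = tokens[: t + 1]
    keys = [matvec(x, W_k) for x in prefix]
    values = [matvec(x, W_v) for x in prefix]
    recomputed_outputs.append(attend(matvec(tokens[t], W_q), keys, values))

print(cached_outputs == recomputed_outputs)  # True: same math, less recomputation
```

Notice that no step ever looks up an old query, which is why nothing is gained by storing Q.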

The × num_layers: Every Attention Layer Stores Its Own Cache

A transformer has many layers (e.g., 40 for a typical 13B model). Each layer has its own W_k and W_v and sees different inputs, so each layer keeps its own K and V cache:

Layer 1: K_cache_1 (batch, seq_len, hidden_dim)
         V_cache_1 (batch, seq_len, hidden_dim)
Layer 2: K_cache_2 (batch, seq_len, hidden_dim)
         V_cache_2 (batch, seq_len, hidden_dim)
...
Layer 40: K_cache_40 (batch, seq_len, hidden_dim)
          V_cache_40 (batch, seq_len, hidden_dim)

Total = 2 * 40 = 80 cache tensors.
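If it helps to see the layout as a data structure, here is a small sketch that only tracks tensor shapes (no real tensors are allocated). LayerCache and build_cache_layout are hypothetical names I made up for illustration, not part of any library.

```python
from dataclasses import dataclass


@dataclass
class LayerCache:
    """Shapes of the K and V tensors held by one attention layer."""
    k_shape: tuple
    v_shape: tuple


def build_cache_layout(num_layers, batch, seq_len, hidden_dim):
    """Return one LayerCache per layer; every layer gets its own K and V."""
    shape = (batch, seq_len, hidden_dim)
    return [LayerCache(k_shape=shape, v_shape=shape) for _ in range(num_layers)]


layout = build_cache_layout(num_layers=40, batch=32, seq_len=4096, hidden_dim=5120)
num_tensors = sum(2 for _ in layout)  # one K and one V per layer
print(num_tensors)  # 80
```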

Concrete Example

For a 13B model (typically ~5120 hidden_dim, ~40 layers):

batch_size = 32
seq_len = 4096
hidden_dim = 5120
num_layers = 40
bytes = 2  # fp16

# Per layer: batch × seq × hidden = 32 × 4096 × 5120 = 671M elements
# K cache per layer: 671M × 2 bytes ≈ 1.34 GB
# V cache per layer: 671M × 2 bytes ≈ 1.34 GB
# Both per layer: ≈ 2.68 GB
# Total × 40 layers: ≈ 107 GB (exactly 100 GiB)

So when you see a figure like 52 GB (my earlier example), it is the full KV cache across all layers for the whole batch, not the size of a single layer's cache.

Why This Matters

This is why the KV cache becomes the hard limit on batch size during inference:

max_batch_size ≈ available_memory / (2 × num_layers × seq_len × hidden_dim × bytes)

On MI300X (192 GB), ignoring the memory taken by the model weights:

max_bsz ≈ 192 GB / (2 × 40 × 4096 × 5120 × 2 bytes)
        ≈ 192 GB / 3.36 GB per full-length sequence
        ≈ 57 sequences at full sequence length
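The same estimate as a short sketch. The function name and the reserved_bytes parameter are hypothetical additions of mine; reserved_bytes stands in for memory that is not available to the cache, such as model weights. The 26 GB figure below is an assumption (13B parameters × 2 bytes in fp16), and GB is taken as 10^9 bytes.

```python
def max_batch_size(available_bytes, seq_len, hidden_dim, num_layers,
                   bytes_per_elem=2, reserved_bytes=0):
    """Upper bound on concurrent full-length sequences that fit in the KV cache."""
    # KV cache for one sequence at full length: K and V, in every layer.
    per_sequence = 2 * num_layers * seq_len * hidden_dim * bytes_per_elem
    return (available_bytes - reserved_bytes) // per_sequence


GB = 10**9  # assumption: decimal gigabytes

# All 192 GB given to the cache.
print(max_batch_size(192 * GB, 4096, 5120, 40))  # 57

# Assumed: about 26 GB set aside for 13B fp16 weights.
print(max_batch_size(192 * GB, 4096, 5120, 40, reserved_bytes=26 * GB))  # 49
```

In practice activations and framework overhead take a further slice, so treat both numbers as upper bounds.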

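This is why MQA (Multi-Query Attention) and GQA (Grouped-Query Attention) help: several query heads share each K and V head, so the cached width drops below hidden_dim:

# Standard MHA: K and V are (batch, seq, hidden_dim), where hidden_dim = num_heads × head_dim
# GQA: K and V are (batch, seq, num_kv_heads × head_dim), with num_kv_heads < num_heads
# MQA: K and V are (batch, seq, head_dim) = (batch, seq, hidden_dim // num_heads)  [extreme case]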
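A quick sketch of how the cached width changes across the three variants. The head counts here are assumptions for illustration (40 query heads for the 5120-wide model, and a hypothetical GQA setup with 8 KV heads); kv_width is a made-up helper name.

```python
def kv_width(hidden_dim, num_heads, num_kv_heads):
    """Width of the cached K (or V) per token: num_kv_heads * head_dim."""
    head_dim = hidden_dim // num_heads
    return num_kv_heads * head_dim


hidden_dim, num_heads = 5120, 40  # assumed head count; gives head_dim = 128

mha = kv_width(hidden_dim, num_heads, num_kv_heads=num_heads)  # every head has its own K/V
gqa = kv_width(hidden_dim, num_heads, num_kv_heads=8)          # hypothetical: 8 shared KV heads
mqa = kv_width(hidden_dim, num_heads, num_kv_heads=1)          # one KV head for all queries

print(mha, gqa, mqa)         # 5120 1024 128
print(mha // gqa, mha // mqa)  # 5 40: cache reduction factors vs. MHA
```

The reduction factor is just num_heads / num_kv_heads, so it depends entirely on how many KV heads a given model keeps.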

Reducing the cache size by the ratio of query heads to KV heads (8x, for example, when 64 query heads share 8 KV heads) lets you fit proportionally larger batches.

