Understanding KV Cache in LLM Inference | Generated by AI
Question: Explain the KV cache in detail.
Answer:
Detailed Explanation of the KV Cache in Large Language Models
The KV cache (Key-Value cache) is a critical optimization technique used during inference in transformer-based large language models (LLMs), particularly in autoregressive decoding (e.g., models like GPT). It dramatically reduces redundant computation when generating text token by token, improving speed and efficiency at the cost of increased memory usage.
Background: Self-Attention in Transformers
In the transformer architecture, the self-attention mechanism allows the model to weigh the importance of different tokens in a sequence. For each token, three vectors are computed from its hidden representation:
- Query (Q): Represents what the current token is “looking for.”
- Key (K): Represents what other tokens “offer.”
- Value (V): The actual information from other tokens.
Attention scores are calculated as:
\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V \]
Where \(d_k\) is the dimension of the keys.
In multi-head attention, this is done in parallel across multiple heads.
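To make this concrete, here is a minimal single-head sketch in PyTorch; the sizes (seq_len, d_model, d_k) and the random projection matrices are illustrative assumptions, not taken from any particular model.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes (assumptions, not tied to a specific model)
seq_len, d_model, d_k = 6, 32, 16

hidden = torch.randn(seq_len, d_model)   # token hidden states
W_q = torch.randn(d_model, d_k)          # learned projection for queries
W_k = torch.randn(d_model, d_k)          # learned projection for keys
W_v = torch.randn(d_model, d_k)          # learned projection for values

Q = hidden @ W_q                         # what each token is "looking for"
K = hidden @ W_k                         # what each token "offers"
V = hidden @ W_v                         # the information itself

scores = Q @ K.T / d_k ** 0.5            # scaled dot products
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
scores = scores.masked_fill(mask, float("-inf"))   # causal mask: no attending to future tokens
weights = F.softmax(scores, dim=-1)      # attention weights, one row per token
output = weights @ V                     # weighted sum of values
print(output.shape)                      # torch.Size([6, 16])
```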
During training, the entire sequence is processed in parallel, so all Q, K, V are computed at once.
During inference (text generation), tokens are produced autoregressively: one at a time. The input grows by one token each step.
The Problem Without KV Cache
Without caching:
- To generate the next token, the model must reprocess the entire sequence (prompt + all previously generated tokens).
- This means recomputing K and V for all previous tokens every time.
- The attention cost of each step scales quadratically with sequence length: reprocessing a sequence of length n costs O(n²) per step, so generating m tokens costs roughly O(m·n²) overall.
- For a sequence of length n, generating the (n+1)th token requires recomputing K and V for all previous tokens and attention over all n+1 positions.
This leads to high latency, especially for long contexts or outputs.
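As a hedged sketch of that redundancy, the toy loop below reprojects K and V for every token at every step; the random embeddings stand in for a real model's hidden states.

```python
import torch
import torch.nn.functional as F

# Toy single-head setup (illustrative sizes, not a real model)
d_model, d_k = 32, 16
W_q, W_k, W_v = (torch.randn(d_model, d_k) for _ in range(3))

def attend(hidden):
    """Recompute Q, K, V for *every* token from scratch: the no-cache path."""
    n = hidden.shape[0]
    Q, K, V = hidden @ W_q, hidden @ W_k, hidden @ W_v
    mask = torch.triu(torch.ones(n, n), diagonal=1).bool()
    scores = (Q @ K.T / d_k ** 0.5).masked_fill(mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ V

hidden = torch.randn(4, d_model)              # hidden states of the prompt tokens
for step in range(3):                         # generate 3 tokens
    out = attend(hidden)                      # O(n^2) attention work redone each step
    new_hidden = torch.randn(1, d_model)      # stand-in for the newly generated token
    hidden = torch.cat([hidden, new_hidden])  # sequence grows; next step repeats everything
```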
How KV Cache Works
The key insight: In autoregressive generation, the K and V vectors for previous tokens do not change when generating a new token. Only the Query (Q) for the new token needs to interact with all previous K and V.
The KV cache solves this by splitting generation into two phases:
- During the prefill phase (processing the initial prompt):
  - Compute Q, K, V for all prompt tokens in parallel.
  - Store the K and V tensors in the cache.
- During the decode phase (generating new tokens one by one), for each new token:
  - Compute only its Q, K, and V.
  - Append the new K and V to the cache.
  - Compute attention using the new token's Q against the full (cached + new) K, then take the weighted sum of the corresponding V.
This changes the attention computation from quadratic to linear in sequence length during decoding: each new token adds O(n) operations instead of O(n²).
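The same toy attention with a cache looks like the sketch below: the prompt is projected once during prefill, and each decode step projects only the newest token (single head, no batching, purely for illustration).

```python
import torch
import torch.nn.functional as F

d_model, d_k = 32, 16
W_q, W_k, W_v = (torch.randn(d_model, d_k) for _ in range(3))

# Prefill: project the whole prompt once and store K and V
prompt_hidden = torch.randn(4, d_model)
k_cache = prompt_hidden @ W_k                 # shape [prompt_len, d_k]
v_cache = prompt_hidden @ W_v

for step in range(3):                         # decode 3 tokens
    new_hidden = torch.randn(1, d_model)      # stand-in for the newest token's hidden state
    q = new_hidden @ W_q                      # only the new token's Q ...
    k_cache = torch.cat([k_cache, new_hidden @ W_k])  # ... and its K and V, appended to the cache
    v_cache = torch.cat([v_cache, new_hidden @ W_v])
    scores = q @ k_cache.T / d_k ** 0.5       # 1 x cache_len scores: O(n) per step, not O(n^2)
    out = F.softmax(scores, dim=-1) @ v_cache # attention output for the new token only
```

No causal mask is needed in the decode step, because the single query is the last token and may attend to every cached position.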
Dimensions of the KV Cache
Typically, the cache has shape:
- [batch_size, num_heads (or num_kv_heads), sequence_length, head_dim] per layer, for K and V separately.
Memory usage grows linearly with sequence length, batch size, and model size. For large models (e.g., 70B parameters) and long contexts (e.g., 128k tokens), the KV cache can dominate GPU memory.
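As a back-of-the-envelope check, the helper below computes the cache size for an assumed 70B-class configuration with GQA (80 layers, 8 KV heads, head_dim 128, FP16); the numbers are illustrative, not an exact published spec.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size, bytes_per_value=2):
    """Two tensors (K and V) per layer, each [batch, num_kv_heads, seq_len, head_dim]."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_value

size = kv_cache_bytes(num_layers=80, num_kv_heads=8, head_dim=128,
                      seq_len=128_000, batch_size=1, bytes_per_value=2)  # FP16
print(f"{size / 1e9:.1f} GB")  # ~41.9 GB for a single 128k-token sequence
```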
Benefits
- Speed: Reduces redundant computations; generation becomes memory-bound rather than compute-bound for long sequences.
- Lower latency: Essential for real-time applications like chatbots.
- Scalability: Enables longer context windows.
Drawbacks and Optimizations
- Memory bottleneck: KV cache size can exceed model weights for very long sequences or large batches.
- Common optimizations:
- Quantization: Reduce precision (e.g., to INT8 or FP8) for KV cache.
- PagedAttention: Virtual memory-like paging to avoid fragmentation (used in vLLM).
- Grouped-Query Attention (GQA) or Multi-Query Attention (MQA): Share K/V across heads to reduce cache size (see the sketch after this list).
- Eviction/Compression: Remove less important entries (e.g., StreamingLLM keeps initial “sink” tokens plus a window of recent tokens).
- Offloading: Move parts to CPU/disk for extreme lengths.
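As an example of the GQA/MQA idea, the sketch below stores K and V for only num_kv_heads and repeats them across each group of query heads at attention time; the head counts are illustrative assumptions.

```python
import torch

batch, seq_len, head_dim = 1, 1024, 128
num_q_heads, num_kv_heads = 32, 8        # 4 query heads share each KV head

# The cache holds only num_kv_heads heads: 4x smaller than caching all 32
k_cache = torch.randn(batch, num_kv_heads, seq_len, head_dim)
v_cache = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# At attention time, each cached KV head is repeated for its group of query heads
group_size = num_q_heads // num_kv_heads
k_for_attn = k_cache.repeat_interleave(group_size, dim=1)  # [1, 32, 1024, 128]
v_for_attn = v_cache.repeat_interleave(group_size, dim=1)
print(k_for_attn.shape)
```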
In frameworks like Hugging Face Transformers, KV caching is enabled by default during generation (via use_cache=True).
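A minimal usage sketch (gpt2 is used here only as a small stand-in model; use_cache=True is already the default and is passed explicitly just for illustration):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The KV cache speeds up decoding because", return_tensors="pt")
# With use_cache=True, generate() reuses K/V from earlier steps instead of recomputing them
output = model.generate(**inputs, max_new_tokens=30, use_cache=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```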
The KV cache is what makes efficient LLM inference possible at scale, transforming potentially prohibitive quadratic costs into manageable linear ones.