KV Cache: Detailed Explanation of How It Improves Inference
Overview
KV Cache is a critical optimization technique for autoregressive language model inference that dramatically reduces computational redundancy. This document provides a detailed explanation of exactly how KV cache improves over standard inference, with code-level comparisons showing what changes and why.
Part 1: The Problem with Standard Inference
Standard Inference Without KV Cache
In standard autoregressive generation, each new token is generated by processing the entire sequence from scratch. This means that for every new token, the model recomputes all the attention scores and key-value pairs for all previous tokens, even though these computations were already done in previous steps.
Example: Generating "The cat sat"
Step 1: Generate "The"
Input: [<start>]
Process: Compute Q, K, V for [<start>]
Output: "The"
Step 2: Generate "cat"
Input: [<start>, "The"]
Process: Compute Q, K, V for [<start>, "The"] ← Recomputes <start>!
Output: "cat"
Step 3: Generate "sat"
Input: [<start>, "The", "cat"]
Process: Compute Q, K, V for [<start>, "The", "cat"] ← Recomputes all previous!
Output: "sat"
The Redundancy Problem
The key insight is that when generating token tᵢ, we need to attend to all previous tokens [t₁, t₂, ..., tᵢ₋₁]. In standard inference, we recompute the keys and values for all these previous tokens, even though we already computed them when generating tᵢ₋₁.
Mathematical View:
For a sequence of length n, generating the i-th token requires:
- Computing Qᵢ (query for position i)
- Computing Kⱼ and Vⱼ for all j ≤ i (keys and values for all positions up to i)
- Computing attention: Attention(Qᵢ, [K₁, ..., Kᵢ], [V₁, ..., Vᵢ])
In standard inference, when generating token i+1:
- We recompute K₁, ..., Kᵢ and V₁, ..., Vᵢ (even though we computed them for token i)
- We only need to compute Kᵢ₊₁ and Vᵢ₊₁ (the new token)
Computational Cost:
For generating a sequence of length n:
- Without KV cache: O(n³d) total attention computation, because step i costs O(i²d) (quadratic per step)
- With KV cache: O(n²d) total attention computation, because step i costs only O(id) (linear per step)
The key difference is that without KV cache, each step recomputes everything, leading to redundant computation. With KV cache, we reuse previous computations, making each step only compute the new token's contribution.
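To make the redundancy concrete, the sketch below counts how many per-token K/V projections each strategy performs while generating n tokens. The function names are illustrative and not taken from any library, and the count deliberately ignores the attention computation itself.

```python
def kv_projections_without_cache(n: int) -> int:
    """Step i re-projects all i tokens seen so far, so the total is 1 + 2 + ... + n."""
    return sum(i for i in range(1, n + 1))


def kv_projections_with_cache(n: int) -> int:
    """Step i projects only the newest token, so the total is n."""
    return n


for n in (3, 100, 2048):
    print(n, kv_projections_without_cache(n), kv_projections_with_cache(n))
# 3 6 3
# 100 5050 100
# 2048 2098176 2048
```

The gap grows with sequence length: at 2048 tokens the uncached path performs roughly a thousand times more K/V projections.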
Part 2: How KV Cache Solves the Problem
The KV Cache Solution
KV cache stores the computed keys and values for all previous tokens, so we don't need to recompute them. When generating a new token, we only compute the keys and values for that new token, then concatenate them with the cached keys and values from previous tokens.
How It Works:
Step 1: Generate "The"
Input: [<start>]
Compute: Q₁, K₁, V₁ for <start>
Cache: KV_cache = [(K₁, V₁)]
Output: "The"
Step 2: Generate "cat"
Input: ["The"] only (context so far: [<start>, "The"])
Compute: Q₂, K₂, V₂ for "The" ← Only compute for new token!
Retrieve: KV_cache = [(K₁, V₁)]
Combine: K = [K₁, K₂], V = [V₁, V₂]
Update cache: KV_cache = [(K₁, V₁), (K₂, V₂)]
Output: "cat"
Step 3: Generate "sat"
Input: ["cat"] only (context so far: [<start>, "The", "cat"])
Compute: Q₃, K₃, V₃ for "cat" ← Only compute for new token!
Retrieve: KV_cache = [(K₁, V₁), (K₂, V₂)]
Combine: K = [K₁, K₂, K₃], V = [V₁, V₂, V₃]
Update cache: KV_cache = [(K₁, V₁), (K₂, V₂), (K₃, V₃)]
Output: "sat"
Key Insight
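The keys and values for a token never depend on future tokens. In the first layer they depend only on that token's embedding (plus its position) and the model's weight matrices. In deeper layers they depend on the token's hidden state, which under causal masking is a function of that token and the tokens before it. Therefore, once we compute Kᵢ and Vᵢ for token i, we can reuse them for all future tokens that need to attend to token i.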
Mathematical Formulation:
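For the token at position i (first layer shown; deeper layers use the hidden state of position i in place of the embedding):
K_i = Embedding(token_i) @ W_k # Never depends on tokens after position i
V_i = Embedding(token_i) @ W_v # Never depends on tokens after position i
These are independent of future tokens, so they can be cached and reused.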
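The independence claim can be checked with a toy example. The following is a minimal sketch in plain Python for the first-layer setting of the formulas above; the helper names `project` and `keys_for` and the random weights are assumptions made for illustration only. It computes keys for a 3-token sequence, then for the same sequence extended by 2 more tokens, and confirms the first 3 keys are identical.

```python
import random


def project(x, W):
    """Row vector x (length d_in) times matrix W (d_in x d_out)."""
    return [sum(x[r] * W[r][c] for r in range(len(x))) for c in range(len(W[0]))]


def keys_for(embeddings, W_k):
    """Compute one key per token; each key uses only its own token's embedding."""
    return [project(e, W_k) for e in embeddings]


random.seed(0)
d_model, d_k = 4, 3
W_k = [[random.uniform(-1, 1) for _ in range(d_k)] for _ in range(d_model)]
tokens = [[random.uniform(-1, 1) for _ in range(d_model)] for _ in range(5)]

keys_short = keys_for(tokens[:3], W_k)  # sequence of 3 tokens
keys_long = keys_for(tokens, W_k)       # same 3 tokens plus 2 later ones

prefix_unchanged = keys_long[:3] == keys_short
print(prefix_unchanged)  # True
```

Because appending later tokens leaves the earlier keys untouched, storing them once is safe. The same argument applies to the values.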
Part 3: Code-Level Comparison
Standard Inference (Without KV Cache)
import torch

def generate_standard(model, prompt, max_length=100):
    """
    Standard generation WITHOUT KV cache.
    Recomputes everything at each step.
    `sample` is a user-supplied function mapping logits to a token id (int).
    """
    generated = list(prompt)
    for step in range(max_length):
        # At each step, process the ENTIRE sequence from scratch
        input_ids = torch.tensor([generated])
        # Forward pass over the whole sequence:
        # this recomputes K and V for ALL previous tokens
        outputs = model(input_ids, use_cache=False)
        # Logits at the last position predict the next token
        logits = outputs.logits[0, -1, :]
        # Sample next token
        next_token = sample(logits)
        generated.append(next_token)
    return generated
What Happens at Each Step:
Step 1 (generating token 1):
input_ids = [token_0]
# Model processes: [token_0]
# Computes: Q_0, K_0, V_0
# Attention: Attention(Q_0, [K_0], [V_0])
Step 2 (generating token 2):
input_ids = [token_0, token_1]
# Model processes: [token_0, token_1] ← REPROCESSES token_0!
# Computes: Q_0, K_0, V_0, Q_1, K_1, V_1 ← Recomputes K_0, V_0!
# Attention: Attention(Q_1, [K_0, K_1], [V_0, V_1])
Step 3 (generating token 3):
input_ids = [token_0, token_1, token_2]
# Model processes: [token_0, token_1, token_2] ← REPROCESSES all previous!
# Computes: Q_0, K_0, V_0, Q_1, K_1, V_1, Q_2, K_2, V_2 ← Recomputes all!
# Attention: Attention(Q_2, [K_0, K_1, K_2], [V_0, V_1, V_2])
Redundancy:
- Step 2: Recomputes K_0, V_0 (already computed in step 1)
- Step 3: Recomputes K_0, V_0, K_1, V_1 (already computed in step 2)
- Step n: Recomputes K_0 through K_{n-2} and V_0 through V_{n-2} (all already computed)
KV Cache Inference (With KV Cache)
import torch

def generate_with_kv_cache(model, prompt, max_length=100):
    """
    Generation WITH KV cache.
    Only computes K and V for the new token; reuses cached ones.
    `sample` is a user-supplied function mapping logits to a token id (int).
    """
    generated = list(prompt)  # Keep the full prompt in the returned sequence
    # Initialize KV cache (empty at start)
    past_key_values = None
    # Prefill: process all prompt tokens except the last to populate the cache
    if len(prompt) > 1:
        input_ids = torch.tensor([prompt[:-1]])
        outputs = model(input_ids, use_cache=True)
        past_key_values = outputs.past_key_values  # Cache K, V for prompt
    # Generate new tokens
    for step in range(max_length):
        # Only process the CURRENT token (last in sequence)
        input_ids = torch.tensor([[generated[-1]]])  # Only new token!
        # Forward pass with cached K, V:
        # the model computes K and V only for the new token
        # and uses cached K, V for all previous tokens
        outputs = model(
            input_ids,
            past_key_values=past_key_values,  # ← Reuse cached K, V
            use_cache=True,
        )
        # Update cache with new token's K, V
        past_key_values = outputs.past_key_values  # ← Updated cache
        # Get logits for last position
        logits = outputs.logits[0, -1, :]
        # Sample next token
        next_token = sample(logits)
        generated.append(next_token)
    return generated
What Happens at Each Step:
Step 1 (generating token 1):
input_ids = [token_0]
# Model processes: [token_0]
# Computes: Q_0, K_0, V_0
# Cache: past_key_values = [(K_0, V_0)]
# Attention: Attention(Q_0, [K_0], [V_0])
Step 2 (generating token 2):
input_ids = [token_1] # Only new token!
past_key_values = [(K_0, V_0)] # Cached from step 1
# Model processes: [token_1]
# Computes: Q_1, K_1, V_1 ← Only computes for new token!
# Retrieves: K_0, V_0 from cache ← Reuses cached!
# Cache: past_key_values = [(K_0, V_0), (K_1, V_1)] ← Updated
# Attention: Attention(Q_1, [K_0, K_1], [V_0, V_1]) ← Uses cached K_0, V_0
Step 3 (generating token 3):
input_ids = [token_2] # Only new token!
past_key_values = [(K_0, V_0), (K_1, V_1)] # Cached from step 2
# Model processes: [token_2]
# Computes: Q_2, K_2, V_2 ← Only computes for new token!
# Retrieves: K_0, V_0, K_1, V_1 from cache ← Reuses all cached!
# Cache: past_key_values = [(K_0, V_0), (K_1, V_1), (K_2, V_2)] ← Updated
# Attention: Attention(Q_2, [K_0, K_1, K_2], [V_0, V_1, V_2]) ← Uses all cached
Efficiency:
- Step 2: Only computes K_1, V_1 (reuses cached K_0, V_0)
- Step 3: Only computes K_2, V_2 (reuses cached K_0, V_0, K_1, V_1)
- Step n: Only computes K_{n-1}, V_{n-1} (reuses all previous cached)
Part 4: Detailed Code Implementation Comparison
Standard Attention (Without Cache)
import math
import torch

def standard_attention_step(model, input_ids, layer_idx):
    """
    Standard attention computation (simplified, single head).
    Processes the entire sequence, recomputing all K, V.
    """
    # Get embeddings for entire sequence
    embeddings = model.embedding(input_ids)  # Shape: (batch, seq_len, d_model)
    # Process through layers up to current layer
    hidden = embeddings
    for i in range(layer_idx):
        hidden = model.layers[i](hidden)
    # Current layer: compute Q, K, V for ENTIRE sequence
    layer = model.layers[layer_idx]
    Q = hidden @ layer.W_q  # (batch, seq_len, d_k)
    K = hidden @ layer.W_k  # (batch, seq_len, d_k)
    V = hidden @ layer.W_v  # (batch, seq_len, d_v)
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (batch, seq_len, seq_len)
    # Causal mask: position i may attend only to positions j <= i
    seq_len = scores.size(-1)
    future = torch.ones(seq_len, seq_len, device=scores.device).triu(diagonal=1).bool()
    scores = scores.masked_fill(future, float("-inf"))
    attention_weights = torch.softmax(scores, dim=-1)
    output = attention_weights @ V  # (batch, seq_len, d_v)
    return output
At each generation step:
- Input: Entire sequence [token_0, token_1, ..., token_{i-1}, token_i]
- Computes: K and V for ALL tokens (including recomputing previous ones)
- Complexity: O(i²d) for step i (quadratic in current sequence length)
KV Cache Attention (With Cache)
import math
import torch

def kv_cache_attention_step(model, input_ids, past_key_values, layer_idx):
    """
    KV cache attention computation (simplified, single head).
    Only processes the new token; reuses cached K, V.
    """
    # Get embeddings for ONLY the new token (last in sequence)
    # input_ids shape: (batch, 1) - only new token!
    embeddings = model.embedding(input_ids)  # Shape: (batch, 1, d_model)
    # Process through layers up to current layer.
    # Only the new token passes through the earlier layers; each of them
    # attends over its own cached K, V (simplified: the cache updates
    # for those earlier layers are not shown here)
    hidden = embeddings
    for i in range(layer_idx):
        hidden = model.layers[i](hidden, past_key_values[i] if past_key_values else None)
    # Current layer: compute Q, K, V for ONLY the new token
    layer = model.layers[layer_idx]
    Q = hidden @ layer.W_q      # (batch, 1, d_k) ← Only new token!
    K_new = hidden @ layer.W_k  # (batch, 1, d_k) ← Only new token!
    V_new = hidden @ layer.W_v  # (batch, 1, d_v) ← Only new token!
    # Retrieve cached K, V from previous tokens
    if past_key_values and past_key_values[layer_idx] is not None:
        K_past, V_past = past_key_values[layer_idx]
        # K_past shape: (batch, past_len, d_k)
        # V_past shape: (batch, past_len, d_v)
        # Concatenate: cached + new
        K = torch.cat([K_past, K_new], dim=1)  # (batch, past_len + 1, d_k)
        V = torch.cat([V_past, V_new], dim=1)  # (batch, past_len + 1, d_v)
    else:
        # First token: no cache yet
        K = K_new
        V = V_new
    # Attention: new token attends to all (cached + new).
    # No causal mask is needed: every cached position precedes the new token
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (batch, 1, past_len + 1)
    attention_weights = torch.softmax(scores, dim=-1)
    output = attention_weights @ V  # (batch, 1, d_v)
    # Update cache: store the concatenated K, V for the next step
    new_cache = (K, V)
    return output, new_cache
At each generation step:
- Input: Only new token [token_i]
- Computes: K and V for ONLY the new token
- Retrieves: Cached K, V for all previous tokens
- Complexity: O(id) for step i (linear in sequence length)
Key Differences:
Input Size:
- Standard: Entire sequence [token_0, ..., token_i] (length i+1)
- KV Cache: Only new token [token_i] (length 1)
K, V Computation:
- Standard: Computes K, V for all tokens (recomputes previous)
- KV Cache: Computes K, V only for new token (reuses cached)
Attention Computation:
- Standard: Q, K, V all have shape (batch, i+1, d_k)
- KV Cache: Q has shape (batch, 1, d_k); K, V have shape (batch, i+1, d_k)
Memory:
- Standard: No cache, recomputes everything
- KV Cache: Stores K, V for all previous tokens (memory trade-off)
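A useful sanity check is that both paths produce the same output for the newest token. The sketch below is a minimal single-layer, single-head version in plain Python, not the implementation of any particular library; every helper name (`matvec`, `attend`, `last_output_full`, `last_output_cached`) and the random weights are assumptions made for illustration.

```python
import math
import random


def matvec(x, W):
    """Row vector x (d_in) times matrix W (d_in x d_out)."""
    return [sum(x[r] * W[r][c] for r in range(len(x))) for c in range(len(W[0]))]


def attend(q, keys, values):
    """Scaled dot-product attention of one query over lists of keys and values."""
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
    top = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w / total * v[c] for w, v in zip(weights, values))
            for c in range(len(values[0]))]


def last_output_full(xs, W_q, W_k, W_v):
    """Standard path: re-project every token, then attend from the last one."""
    keys = [matvec(x, W_k) for x in xs]
    values = [matvec(x, W_v) for x in xs]
    return attend(matvec(xs[-1], W_q), keys, values)


def last_output_cached(x_new, cache, W_q, W_k, W_v):
    """Cached path: project only the new token and append it to the cache."""
    cache["k"].append(matvec(x_new, W_k))
    cache["v"].append(matvec(x_new, W_v))
    return attend(matvec(x_new, W_q), cache["k"], cache["v"])


def rand_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
d = 4
W_q, W_k, W_v = rand_matrix(d, d), rand_matrix(d, d), rand_matrix(d, d)
xs = rand_matrix(6, d)  # six token vectors

cache = {"k": [], "v": []}
max_diff = 0.0
for t in range(1, len(xs) + 1):
    full = last_output_full(xs[:t], W_q, W_k, W_v)
    cached = last_output_cached(xs[t - 1], cache, W_q, W_k, W_v)
    max_diff = max(max_diff, max(abs(a - b) for a, b in zip(full, cached)))

print(max_diff < 1e-12)  # True
```

The cached path performs one projection per step while the full path performs t, yet the outputs agree at every step.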
Part 5: Computational Complexity Analysis
Standard Inference Complexity
For generating a sequence of length n:
Step 1:
- Process: 1 token
- Compute: Q₁, K₁, V₁
- Attention: O(1²d) = O(d)
Step 2:
- Process: 2 tokens
- Compute: Q₁, K₁, V₁, Q₂, K₂, V₂ (recomputes step 1)
- Attention: O(2²d) = O(4d)
Step 3:
- Process: 3 tokens
- Compute: Q₁, K₁, V₁, Q₂, K₂, V₂, Q₃, K₃, V₃ (recomputes steps 1-2)
- Attention: O(3²d) = O(9d)
Step i:
- Process: i tokens
- Compute: All Q, K, V up to i (recomputes all previous)
- Attention: O(i²d)
Total Complexity:
Total = O(d) + O(4d) + O(9d) + ... + O(n²d)
= O(d) × (1² + 2² + 3² + ... + n²)
= O(d) × n(n+1)(2n+1)/6
= O(n³d)
Per-step complexity: O(i²d) for step i (quadratic in current length)
KV Cache Inference Complexity
For generating a sequence of length n:
Step 1:
- Process: 1 token
- Compute: Q₁, K₁, V₁
- Cache: Store K₁, V₁
- Attention: O(1²d) = O(d)
Step 2:
- Process: 1 token (new)
- Compute: Q₂, K₂, V₂ (only new token!)
- Retrieve: K₁, V₁ from cache
- Cache: Store K₂, V₂ (update cache)
- Attention: O(1 × 2d) = O(2d) (1 query, 2 keys/values)
Step 3:
- Process: 1 token (new)
- Compute: Q₃, K₃, V₃ (only new token!)
- Retrieve: K₁, V₁, K₂, V₂ from cache
- Cache: Store K₃, V₃ (update cache)
- Attention: O(1 × 3d) = O(3d) (1 query, 3 keys/values)
Step i:
- Process: 1 token (new)
- Compute: Qᵢ, Kᵢ, Vᵢ (only new token!)
- Retrieve: K₁ through K_{i-1}, V₁ through V_{i-1} from cache
- Cache: Store Kᵢ, Vᵢ (update cache)
- Attention: O(1 × id) = O(id) (1 query, i keys/values)
Total Complexity:
Total = O(d) + O(2d) + O(3d) + ... + O(nd)
= O(d) × (1 + 2 + 3 + ... + n)
= O(d) × n(n+1)/2
= O(n²d)
Per-step complexity: O(id) for step i (linear in sequence length)
Comparison
| Aspect | Standard Inference | KV Cache Inference |
|---|---|---|
| Total Complexity | O(n³d) | O(n²d) |
| Per-Step Complexity | O(i²d) | O(id) |
| K, V Computation | Recomputes all | Only new token |
| Memory Kept Between Steps | None (no cache) | O(nd) per layer (cache) |
| Attention Cost | 1× (baseline) | Roughly n× lower for large n |
Key Insight:
- Standard inference: Quadratic per step, cubic total
- KV cache: Linear per step, quadratic total
- For long sequences (large n), KV cache provides significant speedup
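The two sums can be evaluated directly to see how the gap grows. This is a small illustrative sketch that counts query-key pairs only (the factor d is common to both and omitted); the function names are hypothetical.

```python
def attention_cost_standard(n: int) -> int:
    """Step i scores i queries against i keys: i * i query-key pairs."""
    return sum(i * i for i in range(1, n + 1))


def attention_cost_cached(n: int) -> int:
    """Step i scores 1 query against i keys: i query-key pairs."""
    return sum(i for i in range(1, n + 1))


for n in (10, 100, 1000):
    std, cached = attention_cost_standard(n), attention_cost_cached(n)
    print(n, std, cached, std / cached)
# 10 385 55 7.0
# 100 338350 5050 67.0
# 1000 333833500 500500 667.0
```

The ratio equals (2n+1)/3, so the saving in attention work grows linearly with n, which is what the cubic-versus-quadratic totals predict.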
Part 6: Memory Considerations
Memory Trade-off
KV cache trades computation for memory:
Standard Inference:
- Memory: No persistent cache; only the activations of the current forward pass are held
- Computation: High (recomputes everything)
KV Cache:
- Memory: O(nd) - stores K, V for all tokens
- Computation: Low (reuses cached values)
Memory Breakdown:
For a sequence of length n, model dimension d, num_heads h:
- K cache: n × d (or n × d/h per head)
- V cache: n × d (or n × d/h per head)
- Total per layer: 2nd
- Total for L layers: 2Lnd
Example:
- n = 2048 tokens
- d = 4096
- L = 32 layers
- Memory: 2 × 32 × 2048 × 4096 × 2 bytes (float16) = 2³⁰ bytes ≈ 1 GiB per sequence
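The product in this example is easy to evaluate programmatically. The helper below is a hypothetical sketch that assumes standard multi-head attention where the cached K and V each have width d per token (it does not model grouped-query attention or other variants that shrink the cache).

```python
def kv_cache_bytes(seq_len: int, d_model: int, num_layers: int,
                   bytes_per_element: int = 2, batch_size: int = 1) -> int:
    """Bytes for K and V caches: 2 (K and V) x layers x tokens x d_model x element size."""
    return 2 * num_layers * seq_len * d_model * bytes_per_element * batch_size


size = kv_cache_bytes(seq_len=2048, d_model=4096, num_layers=32)
print(size, size / 2**30)  # 1073741824 1.0
```

With these numbers the cache occupies 2³⁰ bytes, i.e. 1 GiB per sequence, and it scales linearly with both sequence length and batch size.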
Optimization:
- Can use quantization (int8, int4) to reduce memory
- Can use PagedAttention (vLLM) for efficient memory management
- Can limit cache size for very long sequences
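As one possible reading of "limit cache size", the sketch below keeps only the most recent tokens in a fixed-size window. The class `SlidingWindowKVCache` is a hypothetical helper, not an API of any library, and strings stand in for the real K and V tensors. Note that evicting entries changes what the model can attend to, so this is only appropriate when the model or application tolerates a bounded context.

```python
from collections import deque


class SlidingWindowKVCache:
    """Keeps K/V for at most `max_tokens` most recent tokens (hypothetical helper)."""

    def __init__(self, max_tokens: int):
        self.keys = deque(maxlen=max_tokens)
        self.values = deque(maxlen=max_tokens)

    def append(self, k, v):
        # deque(maxlen=...) drops the oldest entry automatically once full
        self.keys.append(k)
        self.values.append(v)

    def __len__(self):
        return len(self.keys)


cache = SlidingWindowKVCache(max_tokens=4)
for t in range(10):
    cache.append(f"K_{t}", f"V_{t}")

print(len(cache), list(cache.keys))  # 4 ['K_6', 'K_7', 'K_8', 'K_9']
```

Memory stays constant at the window size regardless of how many tokens are generated.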
Part 7: Practical Implementation Details
How KV Cache is Stored
Structure:
past_key_values = [
(K_layer_0, V_layer_0), # Cache for layer 0
(K_layer_1, V_layer_1), # Cache for layer 1
...
(K_layer_{L-1}, V_layer_{L-1}), # Cache for layer L-1 (last of L layers)
]
Shape Evolution:
After step 1:
K_layer_0: (batch, 1, d_k) # 1 token cached
V_layer_0: (batch, 1, d_v)
After step 2:
K_layer_0: (batch, 2, d_k) # 2 tokens cached
V_layer_0: (batch, 2, d_v)
After step i:
K_layer_0: (batch, i, d_k) # i tokens cached
V_layer_0: (batch, i, d_v)
Concatenation Operation
The key operation is concatenation:
# New token's K, V
K_new = compute_K(new_token) # (batch, 1, d_k)
V_new = compute_V(new_token) # (batch, 1, d_v)
# Cached K, V
K_past = past_key_values[layer][0] # (batch, past_len, d_k)
V_past = past_key_values[layer][1] # (batch, past_len, d_v)
# Concatenate
K = torch.cat([K_past, K_new], dim=1) # (batch, past_len + 1, d_k)
V = torch.cat([V_past, V_new], dim=1) # (batch, past_len + 1, d_v)
# Update cache
past_key_values[layer] = (K, V)
This concatenation is the core of KV cache:
- It combines cached (past) and new (current) keys/values
- Allows attention to attend to all tokens (past + current)
- Without recomputing past keys/values
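The per-layer structure and the growth of the cached length can be sketched without any tensor library. `SimpleKVCache` below is a hypothetical container written for illustration; it uses Python lists and in-place appends where the snippets above use `torch.cat`, and the placeholder vectors of zeros stand in for real keys and values.

```python
class SimpleKVCache:
    """One (keys, values) pair per layer; each list holds one entry per cached token."""

    def __init__(self, num_layers: int):
        self.layers = [([], []) for _ in range(num_layers)]

    def update(self, layer: int, k_new, v_new):
        """Append the new token's K, V for one layer and return the full K, V."""
        keys, values = self.layers[layer]
        keys.append(k_new)
        values.append(v_new)
        return keys, values

    def seq_len(self, layer: int = 0) -> int:
        """Number of tokens currently cached for the given layer."""
        return len(self.layers[layer][0])


cache = SimpleKVCache(num_layers=2)
lengths = []
for step in range(3):
    for layer in range(2):
        cache.update(layer, k_new=[0.0] * 8, v_new=[0.0] * 8)
    lengths.append(cache.seq_len())

print(lengths)  # [1, 2, 3]
```

The cached length grows by exactly one per generation step in every layer, matching the shape evolution listed above.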
Part 8: Summary
What KV Cache Does
- Stores computed keys and values for all previous tokens
- Reuses cached values instead of recomputing them
- Only computes K, V for new token at each step
- Concatenates cached + new for attention computation
How It Improves Over Standard Inference
Reduces computation:
- Standard: O(n³d) total, O(i²d) per step
- KV Cache: O(n²d) total, O(id) per step
- Speedup: attention work drops by a factor that grows roughly linearly with n
Eliminates redundancy:
- Standard: Recomputes K, V for all previous tokens
- KV Cache: Computes K, V only once per token
Makes each step efficient:
- Standard: Processes entire sequence at each step
- KV Cache: Processes only new token at each step
Code Changes
Standard Inference:
input_ids = entire_sequence # All tokens
outputs = model(input_ids) # Processes all, recomputes all
KV Cache Inference:
input_ids = [new_token] # Only new token
outputs = model(input_ids, past_key_values=cache) # Uses cache, only computes new
cache = outputs.past_key_values # Update cache
Trade-offs
Benefits:
- Much faster inference (especially for long sequences)
- Reduces redundant computation
- Enables efficient generation
Costs:
- Requires memory to store cache (O(nd) per layer)
- Slightly more complex implementation
- Cache management overhead
Verdict:
- KV cache is essential for efficient LLM inference
- The memory cost is usually worth the speedup
- Standard practice in production LLM serving systems
Conclusion
KV cache is a fundamental optimization that eliminates redundant computation in autoregressive generation. By caching and reusing computed keys and values, it reduces the total attention cost of generating n tokens from O(n³d) to O(n²d), providing significant speedups especially for long sequences, at the price of O(nd) cache memory per layer. The key insight is that keys and values for a token never depend on future tokens, so they can be computed once and reused for all future attention computations.