KV Cache: Detailed Explanation of How It Improves Inference

Overview

KV caching is an optimization for autoregressive language model inference that avoids recomputing the keys and values of tokens that have already been processed. This document explains how it improves on standard inference, with code-level comparisons showing what changes and why.


Part 1: The Problem with Standard Inference

Standard Inference Without KV Cache

In standard autoregressive generation, each new token is generated by processing the entire sequence from scratch. This means that for every new token, the model recomputes all the attention scores and key-value pairs for all previous tokens, even though these computations were already done in previous steps.

Example: Generating "The cat sat"

Step 1: Generate "The"

Input: [<start>]
Process: Compute Q, K, V for [<start>]
Output: "The"

Step 2: Generate "cat"

Input: [<start>, "The"]
Process: Compute Q, K, V for [<start>, "The"]  ← Recomputes <start>!
Output: "cat"

Step 3: Generate "sat"

Input: [<start>, "The", "cat"]
Process: Compute Q, K, V for [<start>, "The", "cat"]  ← Recomputes all previous!
Output: "sat"

The Redundancy Problem

The key insight is that when generating token tᵢ, we need to attend to all previous tokens [t₁, t₂, ..., tᵢ₋₁]. In standard inference, we recompute the keys and values for all these previous tokens, even though we already computed them when generating tᵢ₋₁.

Mathematical View:

For a sequence of length n, generating the i-th token requires:

  • Computing Qᵢ (query for position i)
  • Computing Kⱼ and Vⱼ for all j ≤ i (keys and values for all positions up to i)
  • Computing attention: Attention(Qᵢ, [K₁, ..., Kᵢ], [V₁, ..., Vᵢ])

In standard inference, when generating token i+1:

  • We recompute K₁, ..., Kᵢ and V₁, ..., Vᵢ (even though we computed them for token i)
  • We only need to compute Kᵢ₊₁ and Vᵢ₊₁ (the new token)

Computational Cost:

For generating a sequence of length n:

  • Without KV cache: O(n³d) total attention computation, O(i²d) at step i (quadratic per step)
  • With KV cache: O(n²d) total attention computation, O(id) at step i (linear per step)

The key difference is that without KV cache, each step recomputes everything, leading to redundant computation. With KV cache, we reuse previous computations, making each step only compute the new token's contribution.


Part 2: How KV Cache Solves the Problem

The KV Cache Solution

KV cache stores the computed keys and values for all previous tokens, so we don't need to recompute them. When generating a new token, we only compute the keys and values for that new token, then concatenate them with the cached keys and values from previous tokens.

How It Works:

Step 1: Generate "The"

Input: [<start>]
Compute: Q₁, K₁, V₁ for <start>
Cache: KV_cache = [(K₁, V₁)]
Output: "The"

Step 2: Generate "cat"

Input: [<start>, "The"]
Compute: Q₂, K₂, V₂ for "The"  ← Only compute for new token!
Retrieve: KV_cache = [(K₁, V₁)]
Combine: K = [K₁, K₂], V = [V₁, V₂]
Update cache: KV_cache = [(K₁, V₁), (K₂, V₂)]
Output: "cat"

Step 3: Generate "sat"

Input: [<start>, "The", "cat"]
Compute: Q₃, K₃, V₃ for "cat"  ← Only compute for new token!
Retrieve: KV_cache = [(K₁, V₁), (K₂, V₂)]
Combine: K = [K₁, K₂, K₃], V = [V₁, V₂, V₃]
Update cache: KV_cache = [(K₁, V₁), (K₂, V₂), (K₃, V₃)]
Output: "sat"

Key Insight

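With causal (masked) attention, the keys and values at position i depend only on tokens 1 through i and the model's weight matrices, never on later tokens. In the first layer they are a function of token i's embedding alone; in deeper layers they are a function of the hidden state at position i, which itself attends only to earlier positions. Therefore, once we compute Kᵢ and Vᵢ in a given layer, we can reuse them for all future tokens that need to attend to position i.

Mathematical Formulation:

For the token at position i, in a layer whose input hidden state is h_i:

K_i = h_i @ W_k  # h_i depends only on tokens 1..i (causal masking)
V_i = h_i @ W_v  # In the first layer, h_i = Embedding(token_i)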

These are independent of future tokens, so they can be cached and reused.
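
The independence claim can be checked numerically. The sketch below assumes a toy two-layer, single-head causal attention stack in NumPy with random weights; the helper names `causal_attention` and `layer2_keys` are hypothetical and not taken from any library. It computes second-layer keys for a 3-token prefix and for the same prefix extended by two more tokens, then compares the prefix rows.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8  # toy model width (assumption for illustration)

# Random projection weights for two toy attention layers
W1 = {name: rng.normal(size=(d, d)) for name in ("q", "k", "v")}
W2_k = rng.normal(size=(d, d))


def causal_attention(x, W):
    """Single-head self-attention where position i only sees positions <= i."""
    q, k, v = x @ W["q"], x @ W["k"], x @ W["v"]
    scores = q @ k.T / np.sqrt(d)
    mask = np.triu(np.ones_like(scores, dtype=bool), k=1)  # True above the diagonal
    scores = np.where(mask, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def layer2_keys(x):
    """Keys of the second layer; they depend on layer-1 hidden states."""
    hidden = x + causal_attention(x, W1)  # residual connection
    return hidden @ W2_k


tokens = rng.normal(size=(5, d))       # stand-in embeddings for 5 tokens
keys_prefix = layer2_keys(tokens[:3])  # only the first 3 tokens
keys_full = layer2_keys(tokens)        # all 5 tokens

prefix_keys_unchanged = np.allclose(keys_prefix, keys_full[:3])
print(prefix_keys_unchanged)  # True
```

Removing the mask in `causal_attention` would make the comparison fail, which is why caching relies on causal attention.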


Part 3: Code-Level Comparison

Standard Inference (Without KV Cache)

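import torch

def sample(logits):
    """Draw one token id from the softmax distribution over the logits."""
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

@torch.no_grad()
def generate_standard(model, prompt, max_length=100):
    """
    Standard generation WITHOUT KV cache
    Recomputes everything at each step
    (assumes a Hugging Face-style causal LM; max_length counts new tokens)
    """
    generated = list(prompt)

    for step in range(max_length):
        # At each step, process ENTIRE sequence from scratch
        input_ids = torch.tensor([generated])

        # Forward pass processes entire sequence
        # This recomputes K and V for ALL previous tokens
        outputs = model(input_ids, use_cache=False)

        # Get logits for last position
        logits = outputs.logits[0, -1, :]

        # Sample next token
        next_token = sample(logits)
        generated.append(next_token)

    return generated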

What Happens at Each Step:

Step 1 (generating token 1):

input_ids = [token_0]
# Model processes: [token_0]
# Computes: Q_0, K_0, V_0
# Attention: Attention(Q_0, [K_0], [V_0])

Step 2 (generating token 2):

input_ids = [token_0, token_1]
# Model processes: [token_0, token_1]  ← REPROCESSES token_0!
# Computes: Q_0, K_0, V_0, Q_1, K_1, V_1  ← Recomputes K_0, V_0!
# Attention: Attention(Q_1, [K_0, K_1], [V_0, V_1])

Step 3 (generating token 3):

input_ids = [token_0, token_1, token_2]
# Model processes: [token_0, token_1, token_2]  ← REPROCESSES all previous!
# Computes: Q_0, K_0, V_0, Q_1, K_1, V_1, Q_2, K_2, V_2  ← Recomputes all!
# Attention: Attention(Q_2, [K_0, K_1, K_2], [V_0, V_1, V_2])

Redundancy:

  • Step 2: Recomputes K_0, V_0 (already computed in step 1)
  • Step 3: Recomputes K_0, V_0, K_1, V_1 (already computed in step 2)
  • Step n: Recomputes K_0 through K_{n-2} and V_0 through V_{n-2} (all already computed)

KV Cache Inference (With KV Cache)

import torch

@torch.no_grad()
def generate_with_kv_cache(model, prompt, max_length=100):
    """
    Generation WITH KV cache
    Only computes K and V for new token, reuses cached ones
    (`sample` is the same sampling helper used by generate_standard)
    """
    generated = list(prompt)  # Full sequence: prompt + generated tokens

    # Initialize KV cache (empty at start)
    past_key_values = None

    # Process prompt (if any): cache K, V for all but the last prompt token
    if len(prompt) > 1:
        input_ids = torch.tensor([prompt[:-1]])
        outputs = model(input_ids, use_cache=True)
        past_key_values = outputs.past_key_values  # Cache K, V for prompt
    # The last prompt token stays in `generated` and is fed in the first loop step

    # Generate new tokens
    for step in range(max_length):
        # Only process the CURRENT token (last in sequence)
        input_ids = torch.tensor([[generated[-1]]])  # Only new token!

        # Forward pass with cached K, V
        # Model only computes K and V for the new token
        # Uses cached K, V for all previous tokens
        outputs = model(
            input_ids,
            past_key_values=past_key_values,  # ← Reuse cached K, V
            use_cache=True
        )

        # Update cache with new token's K, V
        past_key_values = outputs.past_key_values  # ← Updated cache

        # Get logits for last position
        logits = outputs.logits[0, -1, :]

        # Sample next token
        next_token = sample(logits)
        generated.append(next_token)

    return generated  # Prompt followed by the generated tokens

What Happens at Each Step:

Step 1 (generating token 1):

input_ids = [token_0]
# Model processes: [token_0]
# Computes: Q_0, K_0, V_0
# Cache: past_key_values = [(K_0, V_0)]
# Attention: Attention(Q_0, [K_0], [V_0])

Step 2 (generating token 2):

input_ids = [token_1]  # Only new token!
past_key_values = [(K_0, V_0)]  # Cached from step 1
# Model processes: [token_1]
# Computes: Q_1, K_1, V_1  ← Only computes for new token!
# Retrieves: K_0, V_0 from cache  ← Reuses cached!
# Cache: past_key_values = [(K_0, V_0), (K_1, V_1)]  ← Updated
# Attention: Attention(Q_1, [K_0, K_1], [V_0, V_1])  ← Uses cached K_0, V_0

Step 3 (generating token 3):

input_ids = [token_2]  # Only new token!
past_key_values = [(K_0, V_0), (K_1, V_1)]  # Cached from step 2
# Model processes: [token_2]
# Computes: Q_2, K_2, V_2  ← Only computes for new token!
# Retrieves: K_0, V_0, K_1, V_1 from cache  ← Reuses all cached!
# Cache: past_key_values = [(K_0, V_0), (K_1, V_1), (K_2, V_2)]  ← Updated
# Attention: Attention(Q_2, [K_0, K_1, K_2], [V_0, V_1, V_2])  ← Uses all cached

Efficiency:

  • Step 2: Only computes K_1, V_1 (reuses cached K_0, V_0)
  • Step 3: Only computes K_2, V_2 (reuses cached K_0, V_0, K_1, V_1)
  • Step n: Only computes K_{n-1}, V_{n-1} (reuses all previous cached)
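
The bookkeeping in the two walkthroughs can be tallied with a short sketch. `count_kv_projections` is a hypothetical helper, not part of any library. It counts, for a single layer, how many per-token K/V projections are performed over a run, assuming one token is added to the context at each step.

```python
def count_kv_projections(num_steps, use_cache):
    """Count per-token K/V projections (for one layer) over a generation run."""
    cache = []       # positions whose (K, V) are already stored
    projections = 0
    for step in range(num_steps):
        context = list(range(step + 1))  # token positions visible at this step
        if use_cache:
            # Only positions missing from the cache need a projection
            new_positions = [p for p in context if p not in cache]
            cache.extend(new_positions)
        else:
            # No cache: every visible position is projected again
            new_positions = context
        projections += len(new_positions)
    return projections


print(count_kv_projections(5, use_cache=False))    # 15  (1 + 2 + 3 + 4 + 5)
print(count_kv_projections(5, use_cache=True))     # 5   (one per token)
print(count_kv_projections(100, use_cache=False))  # 5050
```

Without a cache the count grows as n(n+1)/2, while with a cache it is exactly n.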

Part 4: Detailed Code Implementation Comparison

Standard Attention (Without Cache)

import torch

def standard_attention_step(model, input_ids, layer_idx):
    """
    Standard attention computation
    Processes entire sequence, recomputes all K, V
    (model.embedding, model.layers and W_q/W_k/W_v are illustrative names)
    """
    # Get embeddings for entire sequence
    embeddings = model.embedding(input_ids)  # Shape: (batch, seq_len, d_model)

    # Process through layers up to current layer
    hidden = embeddings
    for i in range(layer_idx):
        hidden = model.layers[i](hidden)

    # Current layer: compute Q, K, V for ENTIRE sequence
    Q = hidden @ model.layers[layer_idx].W_q  # (batch, seq_len, d_k)
    K = hidden @ model.layers[layer_idx].W_k  # (batch, seq_len, d_k)
    V = hidden @ model.layers[layer_idx].W_v  # (batch, seq_len, d_v)

    # Attention: each position attends to itself and earlier positions only
    d_k, seq_len = Q.size(-1), Q.size(1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # (batch, seq_len, seq_len)
    # Causal mask: True above the diagonal marks future positions
    causal_mask = torch.ones(seq_len, seq_len, device=scores.device).triu(diagonal=1).bool()
    scores = scores.masked_fill(causal_mask, float("-inf"))
    attention_weights = torch.softmax(scores, dim=-1)
    output = attention_weights @ V  # (batch, seq_len, d_v)

    return output

At each generation step:

  • Input: Entire sequence [token_0, token_1, ..., token_{i-1}, token_i]
  • Computes: K and V for ALL tokens (including recomputing previous ones)
  • Complexity: O(i²d) for step i (quadratic in current sequence length)

KV Cache Attention (With Cache)

def kv_cache_attention_step(model, input_ids, past_key_values, layer_idx):
    """
    KV cache attention computation
    Only processes new token, reuses cached K, V
    """
    # Get embeddings for ONLY the new token (last in sequence)
    # input_ids shape: (batch, 1) - only new token!
    embeddings = model.embedding(input_ids)  # Shape: (batch, 1, d_model)
    
    # Process through layers up to current layer
    # Only the new token is pushed through the earlier layers
    hidden = embeddings
    for i in range(layer_idx):
        # Each earlier layer attends over its own cached K, V (not cached
        # hidden states); updating those layers' caches is omitted here
        hidden = model.layers[i](hidden, past_key_values[i] if past_key_values else None)
    
    # Current layer: compute Q, K, V for ONLY the new token
    Q = hidden @ model.layers[layer_idx].W_q  # (batch, 1, d_k) ← Only new token!
    K_new = hidden @ model.layers[layer_idx].W_k  # (batch, 1, d_k) ← Only new token!
    V_new = hidden @ model.layers[layer_idx].W_v  # (batch, 1, d_v) ← Only new token!
    
    # Retrieve cached K, V from previous tokens
    if past_key_values and past_key_values[layer_idx] is not None:
        K_past, V_past = past_key_values[layer_idx]
        # K_past shape: (batch, past_len, d_k)
        # V_past shape: (batch, past_len, d_v)
        
        # Concatenate: cached + new
        K = torch.cat([K_past, K_new], dim=1)  # (batch, past_len + 1, d_k)
        V = torch.cat([V_past, V_new], dim=1)  # (batch, past_len + 1, d_v)
    else:
        # First token: no cache yet
        K = K_new
        V = V_new
    
    # Attention: new token attends to all (cached + new)
    # No causal mask is needed: the single query is the newest position
    # Q shape: (batch, 1, d_k) - only query for new token
    # K shape: (batch, past_len + 1, d_k) - all keys
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # (batch, 1, past_len + 1)
    attention_weights = torch.softmax(scores, dim=-1)
    output = attention_weights @ V  # (batch, 1, d_v)
    
    # Update cache: add new K, V to cache
    new_cache = (K, V)  # Store for next step
    
    return output, new_cache

At each generation step:

  • Input: Only new token [token_i]
  • Computes: K and V for ONLY the new token
  • Retrieves: Cached K, V for all previous tokens
  • Complexity: O(id) for step i (linear in sequence length)

Key Differences:

  1. Input Size:

    • Standard: Entire sequence [token_0, ..., token_i] (length i+1)
    • KV Cache: Only new token [token_i] (length 1)
  2. K, V Computation:

    • Standard: Computes K, V for all tokens (recomputes previous)
    • KV Cache: Computes K, V only for new token (reuses cached)
  3. Attention Computation:

    • Standard: Q, K, V all have shape (batch, i+1, d_k)
    • KV Cache: Q has shape (batch, 1, d_k), K, V have shape (batch, i+1, d_k)
  4. Memory:

    • Standard: No cache, recomputes everything
    • KV Cache: Stores K, V for all previous tokens (memory trade-off)
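
The differences above should not change the result: both paths are meant to produce the same attention output. The following is a minimal NumPy sketch of that check for a single attention layer, assuming toy sizes, random weights, and random stand-in layer inputs (all names here are illustrative).

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 6, 16, 8  # toy sizes (assumptions)

W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_k))
hidden = rng.normal(size=(seq_len, d_model))  # stand-in layer inputs


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


# Standard path: recompute Q, K, V for every position and apply a causal mask
Q, K, V = hidden @ W_q, hidden @ W_k, hidden @ W_v
scores = Q @ K.T / np.sqrt(d_k)
future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
full_output = softmax(np.where(future, -np.inf, scores)) @ V  # (seq_len, d_k)

# Cached path: feed one position at a time and append its K, V to the cache
K_cache = np.empty((0, d_k))
V_cache = np.empty((0, d_k))
cached_rows = []
for t in range(seq_len):
    h_t = hidden[t:t + 1]  # (1, d_model), only the new token
    K_cache = np.concatenate([K_cache, h_t @ W_k], axis=0)
    V_cache = np.concatenate([V_cache, h_t @ W_v], axis=0)
    q_t = h_t @ W_q  # (1, d_k)
    # No mask needed: the cache holds only current and past positions
    attn = softmax(q_t @ K_cache.T / np.sqrt(d_k))  # (1, t + 1)
    cached_rows.append(attn @ V_cache)
cached_output = np.concatenate(cached_rows, axis=0)

outputs_match = np.allclose(full_output, cached_output)
print(outputs_match)  # True
```

The cached loop performs one query per step against a growing key/value matrix, which is the shape pattern listed under "Attention Computation".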

Part 5: Computational Complexity Analysis

Standard Inference Complexity

For generating a sequence of length n:

Step 1:

  • Process: 1 token
  • Compute: Q₁, K₁, V₁
  • Attention: O(1²d) = O(d)

Step 2:

  • Process: 2 tokens
  • Compute: Q₁, K₁, V₁, Q₂, K₂, V₂ (recomputes step 1)
  • Attention: O(2²d) = O(4d)

Step 3:

  • Process: 3 tokens
  • Compute: Q₁, K₁, V₁, Q₂, K₂, V₂, Q₃, K₃, V₃ (recomputes steps 1-2)
  • Attention: O(3²d) = O(9d)

Step i:

  • Process: i tokens
  • Compute: All Q, K, V up to i (recomputes all previous)
  • Attention: O(i²d)

Total Complexity:

Total = O(d) + O(4d) + O(9d) + ... + O(n²d)
     = O(d) × (1² + 2² + 3² + ... + n²)
     = O(d) × n(n+1)(2n+1)/6
     = O(n³d)

Per-step complexity: O(i²d) for step i (quadratic in current length)

KV Cache Inference Complexity

For generating a sequence of length n:

Step 1:

  • Process: 1 token
  • Compute: Q₁, K₁, V₁
  • Cache: Store K₁, V₁
  • Attention: O(1²d) = O(d)

Step 2:

  • Process: 1 token (new)
  • Compute: Q₂, K₂, V₂ (only new token!)
  • Retrieve: K₁, V₁ from cache
  • Cache: Store K₂, V₂ (update cache)
  • Attention: O(1 × 2d) = O(2d) (1 query, 2 keys/values)

Step 3:

  • Process: 1 token (new)
  • Compute: Q₃, K₃, V₃ (only new token!)
  • Retrieve: K₁, V₁, K₂, V₂ from cache
  • Cache: Store K₃, V₃ (update cache)
  • Attention: O(1 × 3d) = O(3d) (1 query, 3 keys/values)

Step i:

  • Process: 1 token (new)
  • Compute: Qᵢ, Kᵢ, Vᵢ (only new token!)
  • Retrieve: K₁ through K_{i-1}, V₁ through V_{i-1} from cache
  • Cache: Store Kᵢ, Vᵢ (update cache)
  • Attention: O(1 × id) = O(id) (1 query, i keys/values)

Total Complexity:

Total = O(d) + O(2d) + O(3d) + ... + O(nd)
     = O(d) × (1 + 2 + 3 + ... + n)
     = O(d) × n(n+1)/2
     = O(n²d)

Per-step complexity: O(id) for step i (linear in sequence length)

Comparison

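Aspect               | Standard Inference            | KV Cache Inference
Total Complexity     | O(n³d)                        | O(n²d)
Per-Step Complexity  | O(i²d)                        | O(id)
K, V Computation     | Recomputes all                | Only new token
Cache Memory         | None (no persistent cache)    | O(nd) per layer
Speedup (attention)  | 1x (baseline)                 | About (2n+1)/3× fewer operations, growing linearly with n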

Key Insight:

  • Standard inference: Quadratic per step, cubic total
  • KV cache: Linear per step, quadratic total
  • For long sequences (large n), KV cache provides significant speedup
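
To make the gap concrete, the per-step terms from the two analyses can be summed directly. This is a small sketch under the same simplified cost model (attention only, constant factors and the Q/K/V projection cost ignored); `attention_cost` is a hypothetical helper name.

```python
def attention_cost(n, d, use_cache):
    """Sum the per-step attention cost terms from the analysis above.

    Without a cache step i costs i * i * d; with a cache it costs i * d.
    """
    if use_cache:
        return sum(i * d for i in range(1, n + 1))
    return sum(i * i * d for i in range(1, n + 1))


d = 64
for n in (10, 100, 1000):
    standard = attention_cost(n, d, use_cache=False)
    cached = attention_cost(n, d, use_cache=True)
    # The ratio simplifies to (2n + 1) / 3, independent of d
    print(n, standard // cached)
# 10 7
# 100 67
# 1000 667
```

Under this model the ratio of the two totals is (2n+1)/3, so the saving grows linearly with sequence length. Measured wall-clock speedups depend on hardware and on the projection and feed-forward costs that this count leaves out.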

Part 6: Memory Considerations

Memory Trade-off

KV cache trades computation for memory:

Standard Inference:

  • Cache memory: none - only transient activations for the current forward pass
  • Computation: High (recomputes everything)

KV Cache:

  • Memory: O(nd) per layer - stores K, V for all tokens
  • Computation: Low (reuses cached values)

Memory Breakdown:

For a sequence of length n, model dimension d, num_heads h:

  • K cache: n × d (or n × d/h per head)
  • V cache: n × d (or n × d/h per head)
  • Total per layer: 2nd
  • Total for L layers: 2Lnd

Example:

  • n = 2048 tokens
  • d = 4096
  • L = 32 layers
  • Memory: 2 × 32 × 2048 × 4096 × 2 bytes (float16) = 2³⁰ bytes ≈ 1 GiB per sequence
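
The 2Lnd formula is easy to evaluate for other configurations. A minimal sketch follows; `kv_cache_bytes` is a hypothetical helper, and the `batch_size` parameter is an added assumption that the cache scales linearly with the number of sequences held at once.

```python
def kv_cache_bytes(seq_len, d_model, num_layers, bytes_per_value=2, batch_size=1):
    """Bytes needed to hold K and V for every layer (the factor 2 is K plus V)."""
    return 2 * num_layers * batch_size * seq_len * d_model * bytes_per_value


size = kv_cache_bytes(seq_len=2048, d_model=4096, num_layers=32)
print(size, size / 2**30)  # 1073741824 1.0
```

For the example configuration this evaluates to 2³⁰ bytes, that is 1 GiB in float16 for a single sequence.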

Optimization:

  • Can use quantization (int8, int4) to reduce memory
  • Can use PagedAttention (vLLM) for efficient memory management
  • Can limit cache size for very long sequences
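
One simple way to limit cache size is to keep only the most recent positions. The sketch below is an illustration under that assumption, not a description of any particular library's behaviour; `truncate_cache` is a hypothetical helper and the tensors use the (batch, seq_len, d) layout from this document. Dropping old positions changes what the model can attend to, so outputs are no longer identical to full-cache generation.

```python
import numpy as np


def truncate_cache(K, V, max_tokens):
    """Keep only the most recent `max_tokens` positions along the sequence axis."""
    if max_tokens <= 0:
        raise ValueError("max_tokens must be positive")
    return K[:, -max_tokens:, :], V[:, -max_tokens:, :]


K = np.zeros((1, 10, 4))
V = np.zeros((1, 10, 4))
K, V = truncate_cache(K, V, max_tokens=6)
print(K.shape, V.shape)  # (1, 6, 4) (1, 6, 4)
```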

Part 7: Practical Implementation Details

How KV Cache is Stored

Structure:

past_key_values = [
    (K_layer_0, V_layer_0),  # Cache for layer 0
    (K_layer_1, V_layer_1),  # Cache for layer 1
    ...
    (K_layer_{L-1}, V_layer_{L-1}),  # Cache for the last of the L layers
]

Shape Evolution:

After step 1:

K_layer_0: (batch, 1, d_k)  # 1 token cached
V_layer_0: (batch, 1, d_v)

After step 2:

K_layer_0: (batch, 2, d_k)  # 2 tokens cached
V_layer_0: (batch, 2, d_v)

After step i:

K_layer_0: (batch, i, d_k)  # i tokens cached
V_layer_0: (batch, i, d_v)

Concatenation Operation

The key operation is concatenation:

# New token's K, V
K_new = compute_K(new_token)  # (batch, 1, d_k)
V_new = compute_V(new_token)  # (batch, 1, d_v)

# Cached K, V
K_past = past_key_values[layer][0]  # (batch, past_len, d_k)
V_past = past_key_values[layer][1]  # (batch, past_len, d_v)

# Concatenate
K = torch.cat([K_past, K_new], dim=1)  # (batch, past_len + 1, d_k)
V = torch.cat([V_past, V_new], dim=1)  # (batch, past_len + 1, d_v)

# Update cache
past_key_values[layer] = (K, V)

This concatenation is the core of KV cache:

  • It combines cached (past) and new (current) keys/values
  • Allows attention to attend to all tokens (past + current)
  • Without recomputing past keys/values

Part 8: Summary

What KV Cache Does

  1. Stores computed keys and values for all previous tokens
  2. Reuses cached values instead of recomputing them
  3. Only computes K, V for new token at each step
  4. Concatenates cached + new for attention computation

How It Improves Over Standard Inference

  1. Reduces computation:

    • Standard: O(n³d) total, O(i²d) per step
    • KV Cache: O(n²d) total, O(id) per step
    • Speedup: about (2n+1)/3× fewer attention operations, growing linearly with n
  2. Eliminates redundancy:

    • Standard: Recomputes K, V for all previous tokens
    • KV Cache: Computes K, V only once per token
  3. Makes each step efficient:

    • Standard: Processes entire sequence at each step
    • KV Cache: Processes only new token at each step

Code Changes

Standard Inference:

input_ids = entire_sequence  # All tokens
outputs = model(input_ids)  # Processes all, recomputes all

KV Cache Inference:

input_ids = [new_token]  # Only new token
outputs = model(input_ids, past_key_values=cache)  # Uses cache, only computes new
cache = outputs.past_key_values  # Update cache

Trade-offs

Benefits:

  • Much faster inference (especially for long sequences)
  • Reduces redundant computation
  • Enables efficient generation

Costs:

  • Requires memory to store cache (O(nd) per layer)
  • Slightly more complex implementation
  • Cache management overhead

Verdict:

  • KV cache is essential for efficient LLM inference
  • The memory cost is usually worth the speedup
  • Standard practice in production LLM serving systems such as vLLM

Conclusion

KV cache is a fundamental optimization that eliminates redundant computation in autoregressive generation. By caching and reusing computed keys and values, it reduces total attention cost from O(n³d) to O(n²d), providing significant speedups especially for long sequences. The key insight is that, under causal attention, the keys and values at a position depend only on that token and the tokens before it, not on future tokens, so they can be computed once and reused for all future attention computations.