Topic 5: Attention Mechanisms

🔥 For interviews, read these first:

  • ATTENTION_DEEP_DIVE.md — frontier-lab interview deep dive: MHA → MQA → GQA → MLA hierarchy, sliding window receptive-field math, sparse attention, linear attention (Performer, RWKV, SSM connection), induction heads, attention sinks.
  • INTERVIEW_GRILL.md — 50 active-recall questions.

See also 04_transformers/TRANSFORMERS_DEEP_DIVE.md for foundational scaled-dot-product attention.

What You'll Learn

This topic covers the main attention mechanisms and the problem each one solves:

  • Self-attention
  • Causal (masked) self-attention
  • Cross-attention
  • Scaled dot-product attention
  • Sparse attention
  • Longformer/BigBird attention

Why We Need This

Interview Importance

  • Common question: "Explain different attention mechanisms"
  • Problem-solving: Know which attention to use when
  • Understanding: Explain why each variant exists, not just how it is computed

Real-World Application

  • Long context: Sparse attention for long sequences
  • Efficiency: Different attentions have different costs
  • Specialized tasks: Generation needs causal masking; encoder-decoder and multimodal models need cross-attention

Industry Use Cases

1. Self-Attention

Use Case: BERT, GPT

  • Language understanding
  • Text generation
  • Standard transformer attention

2. Sparse Attention

Use Case: Longformer, BigBird

  • Long documents
  • Efficient long-context processing
  • Reduces the quadratic cost of full attention to roughly linear in sequence length

3. Cross-Attention

Use Case: Encoder-decoder models

  • Translation
  • Question answering
  • Cross-modal tasks

Core Intuition

Different attention mechanisms exist because "let every token attend to every token" is not always the right answer.

The right attention pattern depends on the task:

  • do you need bidirectional context?
  • do you need causality?
  • do you need to connect two sequences?
  • do you need long-context efficiency?

Self-Attention

Self-attention is the default when tokens in one sequence need to interact with each other.

Use it when:

  • a sequence needs internal contextualization
  • every token may depend on far-away tokens

Causal Attention

Causal attention is self-attention with a mask that hides future positions.

It is used when:

  • you are generating left-to-right
  • the model must not cheat by looking ahead

Cross-Attention

Cross-attention is used when one sequence should read from another sequence.

Classic example:

  • decoder queries
  • encoder keys and values

That lets the decoder decide which encoded information matters at each step.

Sparse Attention

Sparse attention changes the connectivity pattern so not every token attends to every token.

This matters because full attention becomes expensive for long sequences.

Technical Details Interviewers Often Want

Causal Mask Orientation

A very common interview bug is using the wrong triangular mask.

Correct intuition:

  • token i can attend to tokens <= i
  • token i cannot attend to tokens > i

So the mask must keep the lower triangle.

Why Sparse Attention Helps

Sparse attention reduces the number of token-to-token interactions.

The exact complexity depends on the sparsity pattern, but the main idea is:

  • spend compute only where useful structure is expected

Examples:

  • local window attention
  • global tokens
  • block or pattern-based sparsity
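
To make the savings concrete, the sketch below counts how many query-key pairs a sliding-window pattern keeps compared with full attention. The helper name `window_mask` and the sizes are illustrative assumptions, not taken from Longformer or BigBird.

```python
import numpy as np

def window_mask(seq_len, window_size):
    """Boolean mask: token i may attend to tokens within window_size // 2 of i."""
    idx = np.arange(seq_len)
    # |i - j| <= half-window keeps a symmetric band around the diagonal
    return np.abs(idx[:, None] - idx[None, :]) <= window_size // 2

seq_len, window_size = 1024, 64   # illustrative sizes
mask = window_mask(seq_len, window_size)

full_pairs = seq_len * seq_len    # every token attends to every token
kept_pairs = int(mask.sum())      # pairs allowed by the local window

print(f"full attention pairs: {full_pairs}")
print(f"sliding-window pairs: {kept_pairs}")
print(f"fraction kept:        {kept_pairs / full_pairs:.3f}")
```

The kept fraction shrinks as the sequence grows with a fixed window, which is where the near-linear scaling comes from. Any dependency outside the band is simply unreachable in a single layer unless global tokens or stacked layers carry it.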

Cross-Attention Shape Logic

The most important shape fact is:

  • the query length and key/value length do not need to be the same

That is why cross-attention works naturally across:

  • encoder vs decoder sequence lengths
  • text vs image patches
  • question vs context
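
The shape logic can be checked with a minimal sketch. The sizes and variable names below are illustrative assumptions, and the inputs are random rather than real encoder or decoder states.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

rng = np.random.default_rng(0)
query_len, kv_len, d_k, d_v = 3, 5, 8, 16   # illustrative sizes

Q = rng.normal(size=(query_len, d_k))   # e.g. decoder states
K = rng.normal(size=(kv_len, d_k))      # e.g. encoder states
V = rng.normal(size=(kv_len, d_v))

scores = Q @ K.T / np.sqrt(d_k)         # (query_len, kv_len)
weights = softmax(scores, axis=-1)      # each query row sums to 1 over keys
output = weights @ V                    # (query_len, d_v)

print(scores.shape, weights.shape, output.shape)
```

The output has one row per query, so its length follows the query sequence, while the width follows the value dimension. Only Q and K must share d_k.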

Common Failure Modes

  • wrong mask orientation in causal attention
  • misunderstanding cross-attention as if it were ordinary self-attention
  • assuming sparse attention is automatically better than full attention
  • forgetting that sparse patterns can lose useful long-range interactions
  • softmax applied on the wrong axis
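
The softmax-axis bug is easy to reproduce. In this small sketch (the score values are made up for illustration), rows index queries and columns index keys, so normalization must run over the last axis.

```python
import numpy as np

def softmax(x, axis):
    """Numerically stable softmax along the given axis."""
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

# scores[i, j] = score of query i against key j (3 queries, 4 keys)
scores = np.array([[1.0, 2.0, 3.0, 4.0],
                   [0.0, 0.0, 0.0, 0.0],
                   [4.0, 3.0, 2.0, 1.0]])

right = softmax(scores, axis=-1)  # normalizes over keys
wrong = softmax(scores, axis=0)   # normalizes over queries instead

print(right.sum(axis=-1))  # each query's weights sum to 1
print(wrong.sum(axis=-1))  # each query's weights no longer sum to 1
```

With the wrong axis the code still runs and the shapes still match, which is why this bug often goes unnoticed until the outputs look off.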

Edge Cases and Follow-Up Questions

  1. Why is lower-triangular masking correct for causal attention?
  2. When would sparse attention hurt quality?
  3. Why is cross-attention useful in multimodal systems?
  4. What happens if important dependencies fall outside the sparse pattern?
  5. Why is full attention still often preferred when context length is manageable?

What to Practice Saying Out Loud

  1. The difference between self-attention and cross-attention
  2. Why causal attention is essential for autoregressive generation
  3. Why sparse attention is a compute trade-off, not a free improvement

Industry-Standard Boilerplate Code

Self-Attention (Standard)

"""
Self-Attention
Standard attention used in transformers
"""
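import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax over the key axis (last axis by default)"""
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def self_attention(Q, K, V, d_k, mask=None):
    """Standard self-attention; mask[i, j] == 0 blocks query i from key j"""
    scores = Q @ K.T / np.sqrt(d_k)  # (seq_len, seq_len)
    if mask is not None:
        # Large negative score -> ~0 weight after softmax
        scores = np.where(mask == 0, -1e9, scores)
    attention_weights = softmax(scores, axis=-1)  # each row sums to 1
    return attention_weights @ V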

Causal Attention (GPT-style)

"""
Causal Attention
Masks future positions (for autoregressive generation)
"""
def causal_attention(Q, K, V, d_k):
    """Causal attention with lower triangular mask"""
    seq_len = Q.shape[0]
    # Create lower triangular mask
    mask = np.tril(np.ones((seq_len, seq_len)))
    return self_attention(Q, K, V, d_k, mask=mask)

What This Code Does:

Step 1: Get sequence length

seq_len = Q.shape[0]  # Number of tokens

Step 2: Create lower triangular mask

mask = np.tril(np.ones((seq_len, seq_len)))

What happens:

  • np.ones((seq_len, seq_len)) creates matrix of all 1s
  • np.tril() keeps only lower triangular part (sets upper to 0)
  • Result: Lower triangular matrix where:
    • mask[i, j] = 1 if j ≤ i (can attend to past/current)
    • mask[i, j] = 0 if j > i (cannot attend to future)

Example for seq_len=4:

[[1, 0, 0, 0],   ← Position 0: can only see itself
 [1, 1, 0, 0],   ← Position 1: can see 0, 1
 [1, 1, 1, 0],   ← Position 2: can see 0, 1, 2
 [1, 1, 1, 1]]   ← Position 3: can see all (0, 1, 2, 3)

Step 3: Apply mask in attention

return self_attention(Q, K, V, d_k, mask=mask)

Inside self_attention:

  • Computes attention scores: scores = Q @ K.T / √d_k
  • Applies mask: scores[mask == 0] = -∞ (future positions)
  • After softmax: Future positions get 0 attention weight
  • Result: Each position only attends to past and current tokens

Why Lower Triangular?

  • Lower triangular = can attend to positions ≤ current (past + current)
  • Upper triangular = wrong (would allow future, block past)
  • This enforces causal constraint for autoregressive generation
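
One way to test the causal constraint is to perturb a future token and confirm that earlier outputs do not move. The sketch below is a simplified check that assumes a single head and uses the input itself as Q, K, and V (no learned projections); `causal_self_attention` is a hypothetical helper name.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def causal_self_attention(X):
    """Single-head causal attention with X used as Q, K, and V."""
    seq_len, d_k = X.shape
    scores = X @ X.T / np.sqrt(d_k)
    mask = np.tril(np.ones((seq_len, seq_len)))   # keep j <= i
    scores = np.where(mask == 0, -1e9, scores)    # hide future positions
    return softmax(scores, axis=-1) @ X

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 8))
out_before = causal_self_attention(X)

X_changed = X.copy()
X_changed[3] += 10.0                  # perturb only the last token
out_after = causal_self_attention(X_changed)

# Positions 0..2 never see token 3, so their outputs should match
unchanged_past = np.allclose(out_before[:3], out_after[:3])
# Position 3 attends to itself, so its output should differ
changed_last = not np.allclose(out_before[3], out_after[3])
print(unchanged_past, changed_last)
```

Swapping `np.tril` for `np.triu` makes the first check fail, which is a quick way to catch the mask-orientation bug.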

See causal_attention_detailed.md for complete explanation!

Sparse Attention (Longformer-style)

"""
Sparse Attention
Only attends to local + global positions
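Pattern cost is O(n·w) for window size w, versus O(n²) for full attention.
This reference version still builds the dense n×n score matrix, so it shows the pattern, not the speedup.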
"""
def sparse_attention(Q, K, V, d_k, window_size=512, global_indices=None):
    """
    Sparse attention: local window + global tokens
    
    Args:
        window_size: Local attention window
        global_indices: Positions that attend to all (e.g., [CLS] token)
    """
    seq_len = Q.shape[0]
    scores = Q @ K.T / np.sqrt(d_k)
    
    # Create sparse mask
    mask = np.zeros((seq_len, seq_len))
    
    # Local attention (sliding window)
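    for i in range(seq_len):
        # Symmetric window: window_size // 2 tokens on each side, plus token i
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)  # slice end is exclusive
        mask[i, start:end] = 1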
    
    # Global attention
    if global_indices is not None:  # also safe when a NumPy array is passed
        for idx in global_indices:
            mask[idx, :] = 1  # Attend to all
            mask[:, idx] = 1  # All attend to this
    
    # Apply mask
    scores = np.where(mask == 1, scores, -1e9)
    attention_weights = softmax(scores)
    return attention_weights @ V

Cross-Attention

"""
Cross-Attention
Query from one sequence, Key/Value from another
Used in encoder-decoder architectures
"""
def cross_attention(Q_decoder, K_encoder, V_encoder, d_k):
    """
    Cross-attention: Q from decoder, K/V from encoder
    
    Args:
        Q_decoder: Queries from decoder (decoder_len, d_k)
        K_encoder: Keys from encoder (encoder_len, d_k)
        V_encoder: Values from encoder (encoder_len, d_v)

    Returns:
        Output of shape (decoder_len, d_v)
    """
    scores = Q_decoder @ K_encoder.T / np.sqrt(d_k)  # (decoder_len, encoder_len)
    attention_weights = softmax(scores)  # normalize over encoder positions
    return attention_weights @ V_encoder

What Problems They Solve

Self-Attention

  • Problem: Need to relate all positions
  • Solution: Every position attends to every position
  • Cost: O(n²·d) time, O(n²) memory for the attention matrix

Causal Attention

  • Problem: Autoregressive generation (can't see future)
  • Solution: Mask future positions
  • Use: GPT, language models

Sparse Attention

  • Problem: O(n²) too expensive for long sequences
  • Solution: Only attend to local + few global positions
  • Use: Longformer, BigBird, long documents

Cross-Attention

  • Problem: Need to relate two sequences
  • Solution: Query from one, Key/Value from other
  • Use: Translation, encoder-decoder

Theory

Attention Complexity

Detailed Analysis: See attention_complexity.md for complete complexity breakdown!

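| Type | Time Complexity | Space Complexity | Use Case |
|---|---|---|---|
| Self-attention | O(n²d) | O(n²) | Standard transformers |
| Multi-head | O(n²d) | O(n²) | GPT, parallelizable |
| Linear | O(nd²) | O(nd) | Very long sequences (n >> d) |
| Sparse (Longformer) | O(nwd) | O(nw) | Long sequences |
| Sparse (BigBird) | O(nkd) | O(nk) | Very long sequences |
| Flash Attention | O(n²d) | O(n) | Memory-constrained training |

Here w is the Longformer window size and k is the number of keys each BigBird token attends to (window + random + global). Both are fixed, so cost grows linearly in n. The O(n√n) and O(n log n) figures sometimes quoted for sparse attention belong to the Sparse Transformer and Reformer, respectively.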

Key Insight: Standard attention is O(n²d) because it computes pairwise relationships between all n tokens, with each computation involving d-dimensional vectors. The n² term comes from the attention matrix (n×n), and the d term comes from the vector dimension.
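
A rough FLOP count makes the scaling visible. This is a back-of-the-envelope sketch for a single head that ignores softmax, projections, and constant factors beyond the two matmuls; the function names and the head dimension are illustrative assumptions.

```python
def attention_flops(n, d):
    """Approximate FLOPs for single-head attention over n tokens of dimension d.

    Q @ K.T is an (n, d) x (d, n) matmul: about 2 * n * n * d FLOPs.
    weights @ V is an (n, n) x (n, d) matmul: about 2 * n * n * d FLOPs.
    """
    return 4 * n * n * d

def attention_matrix_entries(n):
    """Entries in the n x n attention matrix (the O(n^2) memory term)."""
    return n * n

d = 64  # illustrative head dimension
for n in (1024, 2048, 4096):
    print(n, attention_flops(n, d), attention_matrix_entries(n))
```

Doubling n quadruples both numbers, while doubling d only doubles the FLOPs and leaves the attention matrix size unchanged.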

When to Use Which

  • Self-attention: Standard, short sequences
  • Causal attention: Autoregressive generation
  • Sparse attention: Long sequences where the O(n²) cost dominates (>2048 tokens is a rough rule of thumb, not a hard threshold)
  • Cross-attention: Encoder-decoder tasks

Exercises

  1. Implement causal mask
  2. Implement sparse attention
  3. Compare attention patterns
  4. Measure computational cost

Causal Attention: Detailed Explanation

New Comprehensive Guide:

  • causal_attention_detailed.md: Complete theoretical explanation

    • Why we need causal attention (autoregressive constraint)
    • How causal attention works (lower triangular mask)
    • Step-by-step code explanation
    • Visual examples
    • Why lower triangular (not upper)
    • Comparison with/without mask
    • Common mistakes and pitfalls
    • Advanced topics
  • causal_attention_code.py: Complete implementation with visualization

    • Step-by-step visualization
    • Comparison with/without mask
    • Explanation of lower triangular
    • Interactive examples

Key Concepts:

  • Lower triangular mask: np.tril(np.ones((seq_len, seq_len)))
  • Sets future positions to -∞ in attention scores
  • After softmax: Future positions get 0 attention weight
  • Enforces: Each position can only see past and current tokens
  • Critical for autoregressive models like GPT

Advanced Attention Mechanisms

New Comprehensive Content:

  • advanced_attention_mechanisms.md: Complete theoretical guide

    • Multi-Head Attention (MHA) - baseline
    • Multi-Query Attention (MQA) - shares K, V across all heads
    • Group Query Attention (GQA) - shares K, V within groups
    • Paged Attention - memory-efficient cache management
    • Detailed comparisons and trade-offs
    • Real-world usage and examples
  • advanced_attention_code.py: Complete implementations

    • MultiQueryAttention class
    • GroupQueryAttention class
    • PagedKVCache class (conceptual)
    • Comparison utilities
    • Memory analysis

Key Concepts:

Multi-Query Attention (MQA):

  • Shares K, V across all heads
  • KV Cache: seq_len × (d_k + d_v) (not per head!)
  • Reduction: num_heads× (e.g., 32× for 32 heads)

Group Query Attention (GQA):

  • Shares K, V within groups of heads
  • KV Cache: num_groups × seq_len × (d_k + d_v)
  • Reduction: (num_heads / num_groups)× (e.g., 4× for 32 heads, 8 groups)
  • Common production choice (e.g., Llama 2 70B, Mistral 7B): keeps most of MQA's cache savings with quality close to MHA
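
The cache formulas above can be compared directly. The sketch below counts cached elements for one layer and one sequence; the configuration values and the helper name `kv_cache_elements` are illustrative assumptions, not any specific model's settings.

```python
def kv_cache_elements(num_kv_heads, seq_len, d_k, d_v):
    """Cached K and V elements for one layer and one sequence."""
    return num_kv_heads * seq_len * (d_k + d_v)

num_heads, num_groups = 32, 8          # illustrative configuration
seq_len, d_k, d_v = 4096, 128, 128

mha = kv_cache_elements(num_heads, seq_len, d_k, d_v)    # one K/V per head
gqa = kv_cache_elements(num_groups, seq_len, d_k, d_v)   # one K/V per group
mqa = kv_cache_elements(1, seq_len, d_k, d_v)            # one K/V shared by all heads

print(f"MHA: {mha}")
print(f"GQA: {gqa} ({mha // gqa}x smaller)")
print(f"MQA: {mqa} ({mha // mqa}x smaller)")
```

Multiply by the number of layers, the batch size, and the bytes per element to get an actual memory footprint.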

Paged Attention:

  • Manages KV cache in non-contiguous pages
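  • Removes external fragmentation; waste is limited to the unfilled tail of each sequence's last block
  • vLLM's authors report under 4% KV-cache waste, versus 60-80% wasted by contiguous pre-allocation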
  • Core of vLLM's efficiency
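
The memory argument can be sketched with simple arithmetic. This is a conceptual illustration only, not vLLM's actual API or allocator; the function names, block size, and maximum length are assumptions chosen for the example.

```python
import math

def blocks_needed(seq_len, block_size):
    """Number of fixed-size blocks needed to hold seq_len tokens."""
    return math.ceil(seq_len / block_size)

def wasted_slots_paged(seq_len, block_size):
    """Unused token slots with paging: only the tail of the last block."""
    return blocks_needed(seq_len, block_size) * block_size - seq_len

def wasted_slots_preallocated(seq_len, max_len):
    """Unused slots when a contiguous buffer of max_len is reserved up front."""
    return max_len - seq_len

block_size, max_len = 16, 2048         # illustrative values
for seq_len in (100, 500, 1500):
    print(seq_len,
          blocks_needed(seq_len, block_size),
          wasted_slots_paged(seq_len, block_size),
          wasted_slots_preallocated(seq_len, max_len))
```

With paging, waste per sequence is bounded by `block_size - 1` slots regardless of how long the sequence could have grown, which is why utilization stays high under mixed-length workloads.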

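Note on "Multi-Head Latent Attention" (MLA):

  • Introduced in DeepSeek-V2 and also used in DeepSeek-V3
  • Compresses keys and values into a shared low-rank latent vector per token; the cache stores that latent instead of per-head K and V
  • Aims for a KV cache smaller than GQA's while keeping quality close to MHA
  • Outside that model family, production systems typically use GQA, MQA, or standard MHA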

Next Steps

  • Topic 6: LLM inference techniques
  • Topic 7: LLM problem solving