Attention Mechanism Complexity: Detailed Analysis
Overview
This document provides a comprehensive analysis of the computational complexity of different attention mechanisms, explaining why standard attention is O(n²d), what this means in practice, and how various attention variants achieve different complexity trade-offs.
Standard Self-Attention: O(n²d) Complexity
Mathematical Formulation
Standard self-attention computes attention as follows:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
Where:
- Q (Query), K (Key), V (Value) are matrices of shape (n, d)
- n is the sequence length
- d is the model dimension
- d_k is the dimension of keys/queries (typically d_k = d)
Step-by-Step Complexity Analysis
Step 1: Compute QK^T
- Q shape: (n, d)
- K^T shape: (d, n)
- Result QK^T shape: (n, n)
- Operations: For each of n² output elements, compute dot product of d-dimensional vectors
- Complexity: O(n²d)
Step 2: Scale by √d_k
- Element-wise division of (n, n) matrix by scalar
- Complexity: O(n²)
Step 3: Apply Softmax
- Softmax over each row of (n, n) matrix
- For each row: compute exp for n elements, sum, divide
- Complexity: O(n²)
Step 4: Multiply by V
- Attention weights shape: (n, n)
- V shape: (n, d)
- Result shape: (n, d)
- Operations: For each of n output positions, compute weighted sum of n vectors of dimension d
- Complexity: O(n²d)
Total Complexity: O(n²d)
The dominant terms are steps 1 and 4, both O(n²d), making the overall complexity O(n²d).
Space Complexity: O(n²)
The attention score matrix QK^T has shape (n, n), requiring O(n²) memory to store. For a sequence of length 10,000, this requires storing 100 million attention scores. For sequences of length 100,000, this requires 10 billion attention scores, which can be prohibitively expensive.
Why Quadratic in Sequence Length?
The quadratic dependency on sequence length comes from the fact that attention computes pairwise relationships between all tokens. For each token, the model considers its relationship with every other token (or every previous token in causal attention). This "all-pairs" computation inherently requires O(n²) operations.
This is both a strength and a weakness: it allows the model to capture long-range dependencies and complex relationships, but it becomes computationally expensive for very long sequences. This is why GPT models typically have maximum sequence lengths (e.g., 2048 or 4096 tokens) and why processing longer sequences requires specialized techniques.
Multi-Head Attention Complexity
How Multi-Head Works
Multi-head attention splits the model dimension d into h heads, each with dimension d/h:
MultiHead(Q, K, V) = Concat(head₁, ..., headₕ)W^O
where headᵢ = Attention(QWᵢ^Q, KWᵢ^K, VWᵢ^V)
Each head processes smaller vectors (dimension d/h) but computes attention over the same sequence length n.
Complexity Analysis
Per Head:
- Q, K, V projections: O(nd²/h) each (since d → d/h)
- Attention computation: O(n²d/h) (same as standard attention but with d/h)
- Total per head: O(n²d/h)
All Heads:
- h heads × O(n²d/h) = O(n²d)
- Output projection: O(nd²)
Total Complexity: O(n²d + nd²)
For typical cases where n and d are similar (e.g., n=1024, d=768), the n²d term dominates, so complexity is effectively O(n²d), the same as single-head attention.
Parallelization Benefit
While the total computation is the same, multi-head attention can be parallelized across heads, potentially reducing wall-clock time on parallel hardware. However, the fundamental complexity remains O(n²d).
Linear Attention: O(nd²) Complexity
Reformulation
Linear attention reformulates the attention computation to avoid the quadratic dependency:
Standard: Attention = softmax(QK^T)V
Linear: Attention = Q(K^T V) / normalization
The key insight is using the associative property: (QK^T)V = Q(K^T V).
Complexity Analysis
Step 1: Compute K^T V
- K^T shape: (d, n)
- V shape: (n, d)
- Result K^T V shape: (d, d)
- Complexity: O(nd²)
Step 2: Multiply Q by (K^T V)
- Q shape: (n, d)
- (K^T V) shape: (d, d)
- Result shape: (n, d)
- Complexity: O(nd²)
Total Complexity: O(nd²)
This is linear in sequence length n, but quadratic in model dimension d. For sequences where n >> d, this provides significant speedup. However, for typical cases where n and d are similar, the benefit may be limited.
Trade-offs
Linear attention requires using a different similarity function (often a kernel function like exp or ReLU) instead of the standard dot product. This can affect the model's expressiveness. The trade-off is between computational efficiency (linear attention is faster for long sequences) and model capacity (standard attention may be more expressive).
Sparse Attention: Reduced Complexity
Local Attention: O(nk d)
Local attention only computes attention within a sliding window of size k:
For each position i, attend only to positions [i-k, i+k]
Complexity:
- For each of n positions, attend to k positions
- Each attention computation: O(kd)
- Total: O(nk d)
This is linear in sequence length when k is fixed. However, it cannot capture long-range dependencies beyond the window size.
Strided Attention: O(nk d)
Strided attention attends to every k-th token plus local neighbors:
For position i, attend to: [i-k, i-k+1, ..., i-1, i, i+1, ..., i+k]
plus [0, k, 2k, 3k, ...] (global tokens)
Complexity:
- Each position attends to ~2k + n/k positions
- Total: O(n(2k + n/k)d) ≈ O(n²d/k) when k is small
- With careful design: O(nk d)
Longformer Pattern: O(n√n d)
The Longformer uses a combination of local attention and global attention:
- Local attention: Each token attends to w tokens on each side (window size w)
- Global attention: A few special tokens attend to all positions, and all positions attend to these global tokens
Complexity:
- Local attention: O(nw d) where w is window size
- Global attention: O(ng d) where g is number of global tokens
- If g = √n and w = √n: O(n√n d)
This pattern maintains the ability to capture long-range dependencies (through global tokens) while reducing overall complexity.
BigBird Pattern: O(n log n d) or O(n√n d)
BigBird uses a similar pattern with:
- Local attention (sliding window)
- Global attention (a few tokens attend to all)
- Random attention (each token attends to r random positions)
Complexity:
- Local: O(nw d)
- Global: O(ng d)
- Random: O(nr d)
- With optimal parameters: O(n log n d) or O(n√n d)
Flash Attention: O(n²d) Time, O(n) Space
The Memory Problem
Standard attention requires storing the full attention matrix (n, n), which has O(n²) space complexity. For long sequences, this can exceed GPU memory limits, even though the computation itself is feasible.
Flash Attention Solution
Flash Attention computes attention in blocks without materializing the full attention matrix:
- Block-wise computation: Divide Q, K, V into blocks
- Incremental softmax: Compute softmax incrementally across blocks
- Online attention: Compute output directly without storing intermediate attention scores
Time Complexity: Still O(n²d) (same computation, just organized differently)
Space Complexity: O(n) instead of O(n²) (only need to store output, not full attention matrix)
How It Works
Flash Attention processes the sequence in blocks. For each block of queries, it:
- Loads corresponding blocks of keys and values
- Computes attention scores for this block
- Updates the output incrementally
- Discards the attention scores (doesn't store them)
This requires recomputing some attention scores multiple times (trading computation for memory), but enables processing much longer sequences on the same hardware.
Complexity Comparison Table
| Attention Type | Time Complexity | Space Complexity | Use Case |
|---|---|---|---|
| Standard | O(n²d) | O(n²) | Short sequences, maximum quality |
| Multi-Head | O(n²d) | O(n²) | Standard GPT, parallelizable |
| Linear | O(nd²) | O(nd) | Very long sequences (n >> d) |
| Local | O(nk d) | O(nk) | Fixed context window |
| Sparse (Longformer) | O(n√n d) | O(n√n) | Long sequences with structure |
| Sparse (BigBird) | O(n log n d) | O(n log n) | Very long sequences |
| Flash | O(n²d) | O(n) | Memory-constrained training |
Practical Implications
When n and d are Similar (e.g., n=1024, d=768)
- Standard attention: O(1024² × 768) ≈ O(800M) operations
- Linear attention: O(1024 × 768²) ≈ O(600M) operations
- Verdict: Similar cost, standard attention preferred for expressiveness
When n >> d (e.g., n=100,000, d=768)
- Standard attention: O(100,000² × 768) ≈ O(7.7T) operations
- Linear attention: O(100,000 × 768²) ≈ O(59B) operations
- Verdict: Linear attention is ~130x faster
Memory Considerations
For n=10,000, d=768:
- Standard attention: 10,000² × 4 bytes = 400 MB for attention matrix
- Flash attention: ~30 MB (just output)
- Verdict: Flash attention enables 10x longer sequences on same GPU
Summary
Standard self-attention has O(n²d) time complexity and O(n²) space complexity due to computing pairwise relationships between all tokens. Multi-head attention maintains the same complexity but enables parallelization. Linear attention reduces to O(nd²) by reformulating the computation, providing speedup when n >> d. Sparse attention patterns reduce complexity to O(n√n d) or O(n log n d) by only computing attention for a subset of token pairs. Flash Attention maintains O(n²d) time but reduces space to O(n) by computing in blocks without storing the full attention matrix. The choice of attention mechanism depends on the sequence length, model dimension, available memory, and quality requirements.