Multi-Head Self-Attention
The attention mechanism described in "Attention Is All You Need" (Vaswani et al., 2017) is powerful — but a single attention head computing scaled dot-product attention learns one type of relationship across the sequence per layer. Multi-head attention runs H copies of attention in parallel, each in a lower-dimensional subspace, allowing a single layer to simultaneously capture syntactic structure, semantic similarity, positional patterns, and long-range coreference. This is one of the key architectural decisions that makes transformers expressive at depth.
Why Multiple Heads
Consider a sentence like "The bank by the river had willow trees." The word "bank" is ambiguous — it could be financial or geographic. Resolving the ambiguity requires attending simultaneously to multiple signals: "river" (nearby, semantic), "trees" (nearby, semantic), and perhaps "the" (syntactic determiner pattern). A single attention head must learn one generalised relationship per layer. Multiple heads let the model maintain parallel, specialised "queries" about different aspects of the sequence at the same depth.
A key insight from the original paper: running H heads in a subspace of dimension dmodel/H has the same total computational cost as running a single head in the full dmodel space. The parallelism is free in terms of FLOPs, but provides much richer representational capacity.
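This FLOPs parity is easy to verify with back-of-envelope arithmetic. The sketch below counts multiply-adds for the two attention matmuls (scores and weighted sum); the config values (dmodel=512, H=8, seq_len=128) are illustrative, not from the text.

```python
# Toy check: H heads of width d_model/H cost the same attention FLOPs as
# one head of width d_model. Counts multiply-adds for QK^T and attn @ V.
d_model, H, n = 512, 8, 128
d_k = d_model // H

def attn_flops(num_heads: int, head_dim: int, seq_len: int) -> int:
    # Per head: scores cost n*n*head_dim, weighted sum costs n*n*head_dim.
    per_head = 2 * seq_len * seq_len * head_dim
    return num_heads * per_head

single = attn_flops(1, d_model, n)
multi = attn_flops(H, d_k, n)
# H * (n^2 * d_model/H) == n^2 * d_model, so the two are identical
assert single == multi
print(single, multi)  # 16777216 16777216
```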
Mechanics of Multi-Head Attention
Given input X of shape (seq_len × dmodel), multi-head attention runs H parallel attention operations, then concatenates and projects the results. For each head h:

headh = Attention(X WhQ, X WhK, X WhV) = softmax(Qh KhT / √dk) Vh

MultiHead(X) = Concat(head1, …, headH) WO

where Qh = X WhQ, Kh = X WhK, Vh = X WhV, and dk = dmodel / H.
Each head has its own learnable projection matrices WhQ, WhK, WhV of shape (dmodel × dk). These projections are the trainable parameters — the model learns which subspace each head should operate in. The output projection matrix WO maps the concatenated head outputs back to dmodel, mixing information across heads before it enters the residual stream.
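The mechanics above can be sketched in a few lines of numpy. This is a minimal illustration, not a production implementation: the three projection matrices for all H heads are stored side by side in single (dmodel × dmodel) weights, which is the usual packed layout, and all weights here are random placeholders.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, Wq, Wk, Wv, Wo, H):
    """X: (n, d_model). Wq/Wk/Wv: (d_model, d_model), holding all H head
    projections side by side. Wo: (d_model, d_model) output projection."""
    n, d_model = X.shape
    d_k = d_model // H
    # Project once, then split into heads: (H, n, d_k)
    Q = (X @ Wq).reshape(n, H, d_k).transpose(1, 0, 2)
    K = (X @ Wk).reshape(n, H, d_k).transpose(1, 0, 2)
    V = (X @ Wv).reshape(n, H, d_k).transpose(1, 0, 2)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)  # (H, n, n)
    out = softmax(scores) @ V                         # (H, n, d_k)
    # Concatenate heads, then mix across heads with W^O
    return out.transpose(1, 0, 2).reshape(n, d_model) @ Wo

rng = np.random.default_rng(0)
n, d_model, H = 6, 16, 4
X = rng.standard_normal((n, d_model))
Ws = [0.1 * rng.standard_normal((d_model, d_model)) for _ in range(4)]
Y = multi_head_attention(X, *Ws, H)
assert Y.shape == (n, d_model)
```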
What Different Heads Learn
Empirical studies of trained transformers have revealed consistent patterns in what individual attention heads specialise in. While not every head falls neatly into a single category, several functions appear reliably across model families and scales.
- Syntactic heads: track grammatical dependencies such as subject-verb agreement, noun-adjective relationships, and prepositional phrase attachment. These heads show strong attention between syntactically related tokens even when they are far apart in the sequence. Found reliably in BERT's earlier layers.
- Positional heads: attend to tokens at consistent offsets: always the previous token, always the next token, or always a fixed distance back. These implement local context windows within the global attention operation and are often found in the earliest layers.
- Coreference heads: connect pronouns to their antecedents ("it" → "the animal"). These heads attend across long distances and require semantic understanding of noun phrase identity. Found in middle layers of larger models where sufficient contextual information has accumulated.
- Semantic heads: attend to semantically related tokens regardless of syntactic relationship or distance: "Paris" attending to "France", verbs attending to their semantic arguments. More common in later layers, once positional and syntactic structure has been encoded by earlier layers.
Induction Heads
One of the most significant mechanistic interpretability findings is the identification of induction heads (Olsson et al., Anthropic, 2022). Induction heads are pairs of attention heads — typically one in an early layer and one in a later layer — that together implement a copying mechanism fundamental to in-context learning. The earlier head is a previous-token head: it writes information about each token's predecessor into the residual stream. The later induction head then searches for an earlier occurrence of the current token and attends to the token that followed it, so a repeated prefix [A][B] … [A] yields a prediction of [B].
Induction heads emerge reliably with scale — they require at least two layers and appear even in transformers that small — and their formation coincides with a sharp drop in training loss (the "phase transition" visible in loss curves during training). They are believed to be a core computational primitive underlying few-shot in-context learning: the ability to follow examples in the prompt without gradient updates.
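The copying rule an induction-head circuit implements can be illustrated with a toy function (this is a literal restatement of the [A][B] … [A] → [B] pattern, not a model of the actual attention computation):

```python
# Toy illustration of the induction rule: when the current token appeared
# earlier in the sequence, look at the token that followed it and predict
# that token next.
def induction_predict(tokens):
    """For each position, predict the token that followed the most recent
    earlier occurrence of the current token, or None if it is unseen."""
    preds = []
    for i, tok in enumerate(tokens):
        pred = None
        for j in range(i - 1, -1, -1):  # scan backwards for a match
            if tokens[j] == tok and j + 1 < len(tokens):
                pred = tokens[j + 1]    # token after the previous occurrence
                break
        preds.append(pred)
    return preds

# In "A B C A", the second "A" matches the first, so the rule predicts "B"
print(induction_predict(["A", "B", "C", "A"]))  # [None, None, None, 'B']
```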
Head Pruning
Michel et al. (2019) and subsequent work found that many attention heads can be removed from trained models with minimal performance degradation. At test time, for a typical BERT-base model (12 layers × 12 heads = 144 heads), as many as 20% of heads can be ablated without changing performance on GLUE. Some heads, when removed, actually improve performance — suggesting they were contributing noise rather than signal.
Head pruning is useful for inference compression: removing heads reduces the computational cost of the QKV projections and the attention computation proportionally. Models like DistilBERT (Sanh et al., 2019) use knowledge distillation combined with layer and head reduction to produce models 40% smaller with 97% of the original performance. The empirical redundancy in trained attention heads is one argument for architectures that learn to allocate capacity more efficiently.
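Head ablation of the kind these studies perform can be sketched by zeroing a head's slice of the per-head outputs before the concat and output projection, which removes its contribution to the layer (ignoring any output bias). The head indices here are arbitrary examples.

```python
import numpy as np

def ablate_heads(head_outputs, pruned):
    """head_outputs: (H, n, d_k) per-head attention outputs. Zeroing a
    head's slice before the concat + W^O projection removes that head's
    contribution, mimicking pruning it at inference time."""
    out = head_outputs.copy()
    out[list(pruned)] = 0.0
    return out

H, n, d_k = 12, 5, 64  # BERT-base-like layer shape
rng = np.random.default_rng(1)
outs = rng.standard_normal((H, n, d_k))
masked = ablate_heads(outs, pruned={3, 7})
assert np.all(masked[3] == 0) and np.all(masked[1] == outs[1])
```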
GQA and MQA: Reducing KV Cache
In autoregressive inference, every previously generated token requires storing its key and value projections from every layer (the KV cache). With H heads, each layer stores 2 × H × dk values per token — this becomes the dominant memory cost at long context lengths, often exceeding the model's parameter memory for large batch sizes.
| Variant | Query Heads | KV Head Groups | KV Cache Size | Used By |
|---|---|---|---|---|
| MHA (standard) | H | H | 1× (baseline) | GPT-3, original BERT |
| GQA | H | G (G < H) | G/H × | Llama 3, Mistral, Gemma |
| MQA | H | 1 | 1/H × | PaLM, Falcon, StarCoder |
Grouped Query Attention (GQA) (Ainslie et al., 2023) divides the H query heads into G groups. Each group shares a single K and V head. Query heads within the same group attend to the same keys and values but use different query projections. Llama 3 70B uses 8 KV heads with 64 query heads (GQA ratio 8:1), reducing KV cache by 8× compared to standard MHA with negligible quality degradation.
Multi-Query Attention (MQA) takes this to the extreme: a single K and V head shared across all query heads. This yields the maximum KV cache reduction (H-fold) and the fastest inference, at the cost of a small quality drop that becomes more noticeable at very large model sizes. MQA is most effective when KV cache memory is the binding constraint — serving many concurrent users with long contexts.
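In practice, GQA and MQA are often mapped onto an MHA-shaped attention kernel by repeating each KV head across its query-head group. A minimal sketch of that expansion, with arbitrary small shapes:

```python
import numpy as np

def group_kv(K, V, num_query_heads):
    """K, V: (G, n, d_k) with G KV heads. Repeat each KV head so that each
    consecutive group of H/G query heads shares it; G = 1 recovers MQA,
    G = H recovers standard MHA."""
    G = K.shape[0]
    reps = num_query_heads // G
    return np.repeat(K, reps, axis=0), np.repeat(V, reps, axis=0)

G, H, n, d_k = 2, 8, 4, 16
rng = np.random.default_rng(2)
K = rng.standard_normal((G, n, d_k))
V = rng.standard_normal((G, n, d_k))
Kh, Vh = group_kv(K, V, H)
assert Kh.shape == (H, n, d_k)
# Query heads 0..3 share KV head 0; heads 4..7 share KV head 1
assert np.all(Kh[0] == Kh[3]) and np.all(Kh[4] == K[1])
```

Only the small (G, n, d_k) tensors need to live in the KV cache; the repetition is done on the fly at attention time, which is where the memory saving comes from.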
Checklist: Do You Understand This?
- Can you explain why using H heads in dimension dmodel/H has the same FLOPs as a single head in dmodel, and why this matters for the design choice?
- Can you write out the multi-head attention computation — the per-head projections, the attention formula, and the concat + output projection?
- Can you name four types of specialisation observed in empirical studies of attention heads and describe what each type attends to?
- Can you explain what an induction head is, what two-head circuit implements it, and why it is significant for in-context learning?
- Can you describe what GQA is, how it differs from standard MHA and from MQA, and explain why reducing KV heads reduces inference memory?
- For a model with 32 query heads using GQA with 4 KV groups, what is the KV cache size reduction compared to standard MHA?