🧠 All Things AI
Advanced

Multi-Head Self-Attention

The attention mechanism described in "Attention Is All You Need" (Vaswani et al., 2017) is powerful — but a single attention head computing scaled dot-product attention learns one type of relationship across the sequence per layer. Multi-head attention runs H copies of attention in parallel, each in a lower-dimensional subspace, allowing a single layer to simultaneously capture syntactic structure, semantic similarity, positional patterns, and long-range coreference. This is one of the key architectural decisions that makes transformers expressive at depth.

Why Multiple Heads

Consider a sentence like "The bank by the river had willow trees." The word "bank" is ambiguous — it could be financial or geographic. Resolving the ambiguity requires attending simultaneously to multiple signals: "river" (nearby, semantic), "trees" (nearby, semantic), and perhaps "the" (syntactic determiner pattern). A single attention head must learn one generalised relationship per layer. Multiple heads let the model maintain parallel, specialised "queries" about different aspects of the sequence at the same depth.

A key insight from the original paper: running H heads, each in a subspace of dimension d_model/H, has the same total computational cost as running a single head in the full d_model space. The parallelism is free in terms of FLOPs, but provides much richer representational capacity.
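The FLOP parity can be checked with back-of-envelope arithmetic. The sketch below counts only the dominant matrix-multiply terms (projections and attention scores) and ignores constants like the softmax; the function name is illustrative.

```python
# Back-of-envelope FLOP count for the Q/K/V projections and the attention
# matmuls, comparing one full-width head against H narrow heads.
# Lower-order terms (softmax, output projection) are ignored.

def attention_flops(seq_len: int, d_model: int, d_head: int, n_heads: int) -> int:
    # Q, K, V projections: 3 matmuls of (seq_len x d_model) @ (d_model x d_head) per head
    proj = 3 * n_heads * 2 * seq_len * d_model * d_head
    # Scores Q·Kᵀ and the weighted sum with V: 2 matmuls of cost seq_len² x d_head per head
    attn = n_heads * 2 * 2 * seq_len * seq_len * d_head
    return proj + attn

single = attention_flops(seq_len=128, d_model=512, d_head=512, n_heads=1)
multi = attention_flops(seq_len=128, d_model=512, d_head=64, n_heads=8)
assert single == multi  # H heads of width d_model/H cost the same as one full-width head
```

The equality holds because every term is linear in n_heads × d_head, which is d_model in both configurations.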

Mechanics of Multi-Head Attention

Given input X of shape (seq_len × d_model), multi-head attention proceeds as follows:

For each head h ∈ {1, ..., H}:

    Qₕ = X · Wₕ^Q    (seq_len × d_k, where d_k = d_model / H)
    Kₕ = X · Wₕ^K    (seq_len × d_k)
    Vₕ = X · Wₕ^V    (seq_len × d_v, where d_v = d_model / H)
    headₕ = Attention(Qₕ, Kₕ, Vₕ) = softmax(Qₕ Kₕᵀ / √d_k) · Vₕ

MultiHead(X) = Concat(head₁, ..., head_H) · W^O

where W^O has shape (H · d_v × d_model), projecting the concatenated heads back to d_model.
[Diagram: X (seq × d_model) → per-head projections Q₁K₁V₁ through QₕKₕVₕ → attention per head → concat → W^O → output (seq × d_model)]

Multi-head attention: H parallel attention operations concatenated and projected

Each head has its own learnable projection matrices Wₕ^Q, Wₕ^K, Wₕ^V of shape (d_model × d_k). These projections are the trainable parameters: the model learns which subspace each head should operate in. The output projection matrix W^O maps the concatenated head outputs back to d_model, mixing information across heads before it enters the residual stream.
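The computation above can be sketched in a few lines of NumPy. The weights here are random stand-ins for learned parameters, and the function names are illustrative; shapes follow the equations in the text.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads):
    """Multi-head self-attention following the per-head equations in the text."""
    seq_len, d_model = X.shape
    d_k = d_model // n_heads
    heads = []
    for h in range(n_heads):
        Q = X @ W_q[h]                             # (seq_len, d_k) projection into head subspace
        K = X @ W_k[h]
        V = X @ W_v[h]
        scores = softmax(Q @ K.T / np.sqrt(d_k))   # (seq_len, seq_len) attention weights
        heads.append(scores @ V)                   # (seq_len, d_k) per-head output
    concat = np.concatenate(heads, axis=-1)        # (seq_len, H * d_k)
    return concat @ W_o                            # project back to (seq_len, d_model)

rng = np.random.default_rng(0)
H, d_model, seq_len = 4, 32, 10
d_k = d_model // H
W_q = rng.normal(size=(H, d_model, d_k))
W_k = rng.normal(size=(H, d_model, d_k))
W_v = rng.normal(size=(H, d_model, d_k))
W_o = rng.normal(size=(H * d_k, d_model))
X = rng.normal(size=(seq_len, d_model))
out = multi_head_attention(X, W_q, W_k, W_v, W_o, H)
assert out.shape == (seq_len, d_model)
```

Note that the loop over heads is for clarity only; real implementations batch all heads into a single tensor operation.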

What Different Heads Learn

Empirical studies of trained transformers have revealed consistent patterns in what individual attention heads specialise in. While not every head falls neatly into a single category, several functions appear reliably across model families and scales.

Syntactic Heads

Track grammatical dependencies: subject-verb agreement, noun-adjective relationships, prepositional phrase attachment. These heads show strong attention between syntactically related tokens even when they are far apart in the sequence. Found reliably in BERT's earlier layers.

Positional Heads

Attend to tokens at consistent offsets: always the previous token, always the next token, or always a fixed distance back. These implement local context windows within the global attention operation and are often found in the earliest layers.

Coreference Heads

Connect pronouns to their antecedents ("it" → "the animal"). These heads attend across long distances and require semantic understanding of noun phrase identity. Found in middle layers of larger models where sufficient contextual information has accumulated.

Semantic Heads

Attend to semantically related tokens regardless of syntactic relationship or distance: "Paris" attending to "France", verbs attending to their semantic arguments. More common in later layers where positional and syntactic structure has been encoded in earlier layers.

Induction Heads

One of the most significant mechanistic interpretability findings is the identification of induction heads (Olsson et al., Anthropic, 2022). Induction heads are pairs of attention heads — typically one in an early layer and one in a later layer — that together implement a copying mechanism fundamental to in-context learning.

Example sequence: [A] [B] ... [A] [?]

  • Previous-token head (layer L): attends to the token immediately before each position
  • Induction head (layer L+1): at the second [A], attends to the position where [A] appeared earlier, then reads the token that followed it in that context
  • Result: predicts [B] at position [?] — the model has learned "copy the pattern from context"

Induction heads emerge reliably during training: they require at least two layers (one for the previous-token head, one for the induction head), so they appear even in small transformers, and their formation coincides with an abrupt shift (a "phase change") visible in the training loss curve. They are believed to be a core computational primitive underlying few-shot in-context learning: the ability to follow examples in the prompt without gradient updates.
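The behaviour of the two-head circuit can be described as a plain algorithm: find the most recent earlier occurrence of the current token and predict whatever followed it. The sketch below implements that algorithm directly with explicit search; the real circuit approximates it with attention patterns, and the function name is hypothetical.

```python
def induction_predict(tokens):
    """Sketch of the algorithm an induction-head circuit approximates:
    look for the most recent earlier occurrence of the current token,
    then predict the token that followed it."""
    current = tokens[-1]
    # Search backwards for a previous occurrence of the current token
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            return tokens[i + 1]  # copy whatever followed it last time
    return None  # no repeated prefix to copy from

seq = ["A", "B", "C", "D", "A"]
print(induction_predict(seq))  # → "B": the pattern [A] [B] ... [A] → [B]
```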

Head Pruning

Michel et al. (2019) and subsequent work found that many attention heads can be removed from trained models with minimal performance degradation. At test time, for a typical BERT-base model (12 layers × 12 heads = 144 heads), as many as 20% of heads can be ablated without changing performance on GLUE. Some heads, when removed, actually improve performance — suggesting they were contributing noise rather than signal.

Head pruning is useful for inference compression: removing heads reduces the computational cost of the QKV projections and the attention computation proportionally. Models like DistilBERT (Sanh et al., 2019) use knowledge distillation combined with layer reduction to produce models 40% smaller that retain 97% of the original performance. The empirical redundancy of trained attention heads is one argument for architectures that learn to allocate capacity more efficiently.
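The proportional savings are easy to estimate. The arithmetic below is a rough sketch for a BERT-base-like layer (768-dimensional, 12 heads), counting only the per-head Q/K/V projection weights and ignoring W^O and biases.

```python
# Rough estimate of the QKV parameter savings from pruning heads in one
# BERT-base-like layer (d_model = 768, H = 12). Ignores W^O and biases.

d_model, H = 768, 12
d_k = d_model // H                       # 64 dims per head
per_head_params = 3 * d_model * d_k      # Q, K, V projections for one head
layer_qkv_params = H * per_head_params   # full layer's QKV parameters

pruned_heads = 3                         # e.g. remove 3 of the 12 heads
saved = pruned_heads * per_head_params
print(saved / layer_qkv_params)          # → 0.25: a quarter of QKV parameters saved
```

The attention-score computation shrinks by the same fraction, since each head's seq² × d_k matmuls disappear with it.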

GQA and MQA: Reducing KV Cache

In autoregressive inference, every previously generated token requires storing its key and value projections from every layer (the KV cache). With H heads, each layer stores 2 × H × d_k values per token. At long context lengths this becomes the dominant memory cost, often exceeding the model's parameter memory at large batch sizes.
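A quick calculation shows the scale of the problem. The dimensions below are illustrative, for a hypothetical 32-layer model with 32 heads of dimension 128, cached in fp16; they do not describe any specific released model.

```python
# KV cache per generated token: 2 (K and V) x layers x H x d_k values.
# Illustrative dimensions for a hypothetical 32-layer model, fp16 cache.

n_layers, n_heads, d_head = 32, 32, 128
bytes_per_val = 2  # fp16

per_token = 2 * n_layers * n_heads * d_head * bytes_per_val
print(per_token / 1024)           # → 512.0 KiB of cache per token
print(per_token * 8192 / 2**30)   # → 4.0 GiB for a single 8192-token context
```

Multiply the per-context figure by the batch size and the motivation for GQA and MQA is clear: cutting the number of KV heads cuts this memory directly.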

Variant          Query heads   KV head groups   KV cache size    Used by
MHA (standard)   H             H                1× (baseline)    GPT-3, original BERT
GQA              H             G (G < H)        G/H ×            Llama 3, Mistral, Gemma
MQA              H             1                1/H ×            PaLM, Falcon, StarCoder

Grouped Query Attention (GQA) (Ainslie et al., 2023) divides the H query heads into G groups. Each group shares a single K and V head. Query heads within the same group attend to the same keys and values but use different query projections. Llama 3 70B uses 8 KV heads with 64 query heads (GQA ratio 8:1), reducing KV cache by 8× compared to standard MHA with negligible quality degradation.

Multi-Query Attention (MQA) takes this to the extreme: a single K and V head shared across all query heads. This yields the maximum KV cache reduction (H-fold) and the fastest inference, at the cost of a small quality drop that becomes more noticeable at very large model sizes. MQA is most effective when KV cache memory is the binding constraint — serving many concurrent users with long contexts.
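The head-sharing scheme can be sketched by computing K and V once per group and letting each query head read from its group. This NumPy sketch uses random stand-in weights and illustrative names; MQA is the special case n_groups = 1.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def grouped_query_attention(X, W_q, W_k, W_v, W_o, n_heads, n_groups):
    """GQA sketch: H query heads share G sets of K/V projections.
    Query head h reads from KV group h // (H // G)."""
    d_k = W_q.shape[-1]
    # K and V are computed (and cached) once per group, not once per head --
    # this is exactly what shrinks the KV cache by a factor of H / G.
    Ks = [X @ W_k[g] for g in range(n_groups)]
    Vs = [X @ W_v[g] for g in range(n_groups)]
    heads = []
    for h in range(n_heads):
        g = h // (n_heads // n_groups)       # KV group for this query head
        Q = X @ W_q[h]                       # per-head query projection
        heads.append(softmax(Q @ Ks[g].T / np.sqrt(d_k)) @ Vs[g])
    return np.concatenate(heads, axis=-1) @ W_o

rng = np.random.default_rng(1)
H, G, d_model, seq_len = 8, 2, 64, 6
d_k = d_model // H
W_q = rng.normal(size=(H, d_model, d_k))
W_k = rng.normal(size=(G, d_model, d_k))     # only G K/V heads need caching
W_v = rng.normal(size=(G, d_model, d_k))
W_o = rng.normal(size=(H * d_k, d_model))
out = grouped_query_attention(rng.normal(size=(seq_len, d_model)), W_q, W_k, W_v, W_o, H, G)
assert out.shape == (seq_len, d_model)
```

With G = 2 and H = 8, the K/V tensors are a quarter of their MHA size while all eight query projections remain distinct.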

Checklist: Do You Understand This?

  • Can you explain why using H heads in dimension d_model/H has the same FLOPs as a single head in d_model, and why this matters for the design choice?
  • Can you write out the multi-head attention computation — the per-head projections, the attention formula, and the concat + output projection?
  • Can you name four types of specialisation observed in empirical studies of attention heads and describe what each type attends to?
  • Can you explain what an induction head is, what two-head circuit implements it, and why it is significant for in-context learning?
  • Can you describe what GQA is, how it differs from standard MHA and from MQA, and explain why reducing KV heads reduces inference memory?
  • For a model with 32 query heads using GQA with 4 KV groups, what is the KV cache size reduction compared to standard MHA?