
Attention Mechanism

Attention is the single operation that made transformers possible. It lets every position in a sequence directly communicate with every other position — regardless of distance — in one parallel step. Before attention, sequence models had to pass information through recurrent bottlenecks, and long-range dependencies were routinely lost. Attention replaces that bottleneck with a differentiable, fully-connected lookup that scales gracefully with model size.

The Core Problem Attention Solves

Consider this sentence: "The animal didn't cross the street because it was too tired."

What does "it" refer to? A human reader immediately resolves it to "animal", not "street"; the word "tired" makes it unambiguous. But for a model to do the same, it must relate "it" to a word that appears six words earlier, across five intervening words. This is the long-range dependency problem.

Recurrent neural networks (RNNs) process sequences token by token, threading a fixed-size hidden state through each step. Information about "animal" must survive six state transitions to reach the position processing "it". In practice this doesn't reliably happen: gradients vanish over long distances, and information from early tokens is overwritten by later ones. LSTMs and GRUs mitigate this with gating, but the fundamental bottleneck remains: all history must be compressed into a single vector.

Attention eliminates the bottleneck entirely. Every position can directly attend to every other position in the same sequence. The model learns which positions to attend to by computing a similarity score between positions and using those scores as weights. For co-reference resolution ("it" → "animal"), the model simply learns that the pronoun position should attend strongly to its antecedent.

Queries, Keys, and Values

Attention is best understood as a soft, differentiable lookup into a key-value store.

In a conventional database, you issue an exact query, the system finds the matching key, and returns the associated value. Attention does something softer: the query is compared against all keys simultaneously, producing a similarity score for each. Those scores become weights, and the output is a weighted sum over all values — no hard selection, just a continuous blend.
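The following minimal sketch in Python with NumPy makes the contrast concrete. The toy two-dimensional key vectors, the query, and all variable names are hypothetical choices for illustration; a real model learns its vectors. The dictionary lookup returns exactly one value. The soft lookup scores every key and returns a softmax-weighted blend of all values.

```python
import numpy as np

# Conventional key-value store: exact match, exactly one value returned.
store = {"cat": 1.0, "dog": 2.0, "car": 3.0}
hard_result = store["dog"]

# Soft lookup: keys and query are vectors (hypothetical toy values).
keys = np.array([[1.0, 0.0],    # "cat"
                 [0.9, 0.1],    # "dog" (similar to "cat")
                 [0.0, 1.0]])   # "car"
values = np.array([1.0, 2.0, 3.0])
query = np.array([0.95, 0.05])

scores = keys @ query                    # one similarity score per key
weights = np.exp(scores - scores.max())  # subtract max for numerical stability
weights /= weights.sum()                 # softmax: non-negative, sums to 1
soft_result = weights @ values           # continuous blend of all values
```

No key is ever "selected". Every value contributes in proportion to its weight, which is what makes the lookup differentiable.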

Query (Q)

What this position is looking for. "I'm a pronoun — what noun am I referring to?"

Key (K)

What each position advertises about itself. "I'm a noun phrase in subject position."

Value (V)

What each position contributes when attended to. The actual representation passed forward.

Each token's embedding is projected into three separate spaces via learned weight matrices W_Q, W_K, and W_V. The query vector for position i is computed as q_i = x_i W_Q, and similarly k_i = x_i W_K and v_i = x_i W_V. These projections are what the model trains: it learns what queries and keys should look like such that useful positions score highly.
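Below is a minimal NumPy sketch of the three projections. It assumes the row-vector convention used above (q_i = x_i W_Q). The sizes are arbitrary, and the matrices are random stand-ins for trained weights. The point is only that one matrix product per projection computes the vectors for every position at once.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_k, d_v = 4, 8, 6, 5   # illustrative sizes, not from any real model

X = rng.normal(size=(n, d_model))   # row i is the embedding x_i

# Projection matrices: random here, learned in a real model.
W_Q = rng.normal(size=(d_model, d_k))
W_K = rng.normal(size=(d_model, d_k))
W_V = rng.normal(size=(d_model, d_v))

Q = X @ W_Q   # row i is q_i = x_i W_Q
K = X @ W_K   # row i is k_i = x_i W_K
V = X @ W_V   # row i is v_i = x_i W_V

# The per-position formula agrees with the batched product.
q_2 = X[2] @ W_Q
```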

As a concrete example, imagine a machine translation decoder about to generate the French word "il" ("he"). After training, its query encodes something like "find the source word that this third-person singular masculine pronoun translates". The encoder's key for the source word "he" (or "the man") has a compatible representation, so the dot product is large. The corresponding value, which carries the semantic content of the source word, receives high weight in the output. In effect, the decoder looks up the most relevant source context for each target word it generates.

Scaled Dot-Product Attention

Given a batch of queries packed into matrix Q (shape n×dₖ), keys K (shape m×dₖ), and values V (shape m×dᵥ), attention computes:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V$$

Step by step:

  1. Score matrix: QKᵀ produces an n×m matrix. Entry (i, j) is the dot product between query i and key j, a scalar measuring their compatibility.
  2. Scale by √dₖ: Divide every score by the square root of the key dimension. Without this, dot products grow large as dₖ increases. If the components of q and k are independent with mean 0 and variance 1, their dot product is a sum of dₖ such products, so it has variance dₖ (standard deviation √dₖ). Dividing by √dₖ brings the variance back to 1. Large scores push softmax into a saturated, nearly one-hot regime where gradients become vanishingly small.
  3. Softmax: Applied row-wise to convert raw scores into a probability distribution over positions. Row i now sums to 1.
  4. Weighted sum: Multiply the attention weights (n×m) by V (m×dᵥ) to get the output (n×dᵥ). Each output position is a convex combination of all value vectors.
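The four steps map almost line for line onto code. This minimal NumPy sketch omits batching, masking, and multiple heads. The function names are illustrative and are not any library's API.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtracting the row max avoids overflow and leaves the result unchanged.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """Q: (n, d_k), K: (m, d_k), V: (m, d_v) -> output (n, d_v), weights (n, m)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)     # steps 1-2: (n, m) scaled compatibilities
    weights = softmax(scores, axis=-1)  # step 3: each row sums to 1
    output = weights @ V                # step 4: convex combination of value rows
    return output, weights

rng = np.random.default_rng(0)
Q = rng.normal(size=(3, 4))   # n = 3 queries, d_k = 4
K = rng.normal(size=(5, 4))   # m = 5 keys
V = rng.normal(size=(5, 2))   # d_v = 2
out, w = scaled_dot_product_attention(Q, K, V)
```

Because each output row is a convex combination, every output coordinate lies between the smallest and largest value in the corresponding column of V.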

The √dₖ scaling was introduced in the original "Attention Is All You Need" paper (Vaswani et al., 2017), which motivated it with this variance argument. Without it, large dₖ produces saturated softmax outputs and tiny gradients, which makes models with high-dimensional projections harder to train. With it, the variance of the scores stays roughly constant as dₖ grows.
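The variance argument can be checked numerically. This sketch assumes query and key components are independent standard normal variables, which real learned projections need not be. Under that assumption, the raw dot-product variance tracks dₖ and the scaled variance stays near 1. The last lines show the saturation effect. Multiplying the same logits by √1024 = 32 mimics the growth in standard deviation at dₖ = 1024, and it raises the largest softmax weight.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 4_000
variances = {}

for d_k in (16, 256, 1024):
    # Assumption: query/key components are i.i.d. with mean 0 and variance 1.
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    raw = np.einsum("ij,ij->i", q, k)           # n_samples independent dot products
    scaled = raw / np.sqrt(d_k)
    variances[d_k] = (raw.var(), scaled.var())  # roughly (d_k, 1.0)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Unscaled scores at d_k = 1024 are ~32x larger, pushing softmax toward one-hot.
logits = rng.standard_normal(8)
peak_scaled = softmax(logits).max()
peak_unscaled = softmax(logits * np.sqrt(1024)).max()
```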

Attention Weights Visualisation

The n×m softmax output is an attention weight matrix. Examining it can suggest what the model has learned to attend to.

Common Patterns Observed

  • Diagonal: A position attending mostly to itself (identity pass-through)
  • Off-diagonal bands: Consistent offset attention (e.g. always attending one position back)
  • Induction heads: Attending to the token that followed the previous occurrence of the current token, then copying it forward (a pattern-completion mechanism: [A][B] ... [A] → predict [B])
  • Special-token peaks: Some heads put most of their weight on a special token such as [SEP], [EOS], or the first token, which often acts as a default "no-op" target
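An idealised induction head attends not to the earlier occurrence itself but to the token that followed it. The pattern [A][B] ... [A] therefore leads to predicting [B]. The hypothetical function below computes that target position for every token of a toy sequence. It describes the pattern a trained head approximates, not how any real head is implemented.

```python
def induction_targets(tokens):
    """For each position i, return the position an idealised induction head
    would attend to: the token right after the most recent earlier occurrence
    of tokens[i], or None if tokens[i] has not appeared before."""
    last_seen = {}   # token -> most recent position seen so far
    targets = []
    for i, tok in enumerate(tokens):
        j = last_seen.get(tok)
        targets.append(j + 1 if j is not None else None)
        last_seen[tok] = i
    return targets

tokens = ["the", "cat", "sat", ".", "the", "cat"]
targets = induction_targets(tokens)
# The token found at each target is what the head would copy forward.
predicted = [tokens[t] if t is not None else None for t in targets]
```

On the second "the", the head looks at position 1 and copies "cat". On the second "cat", it looks at position 2 and copies "sat".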

Interpretation Caution

  • High attention weight does not mean causal importance
  • Ablation studies show high-weight heads can often be removed with minimal loss
  • Attention weights and gradient-based importance measures frequently disagree
  • Visualising attention is useful for hypothesis generation, not ground truth

Anthropic's mechanistic interpretability research has documented specific head functions (e.g. induction heads, which appear to underlie much of in-context learning). The general lesson, however, is that attention visualisation is a starting point for analysis, not a complete explanation of model behaviour.

Self-Attention vs Cross-Attention

The distinction is simply where Q, K, and V come from.

Self-Attention

Q, K, and V are all derived from the same sequence. Each position attends to all other positions (including itself) in the same sequence.

Used in: transformer encoder layers, transformer decoder layers (for the decoder's own generated tokens)

Cross-Attention

Q comes from one sequence (e.g. the decoder's current state), while K and V come from a different sequence (e.g. the encoder's output). The decoder attends to the encoder.

Used in: encoder-decoder models (translation, summarisation), diffusion model conditioning on text embeddings

Decoder-only language models (GPT, LLaMA, Gemini) use only self-attention. Encoder-decoder models (T5, Whisper, the original transformer) use both: self-attention within each stack and cross-attention connecting the two stacks.
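The difference lies only in where the inputs come from, and it shows up in the shape of the weight matrix. This NumPy sketch uses toy sizes and random stand-ins for learned weights. Self-attention over a 4-token decoder sequence gives 4×4 weights. Cross-attention from those 4 positions into a 6-position encoder output gives 4×6 weights. For brevity, one set of projection matrices is reused for both; a real transformer layer learns separate projections for its self- and cross-attention sublayers.

```python
import numpy as np

def attention(Q, K, V):
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)   # stable softmax
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V, w

rng = np.random.default_rng(0)
d_model = 8
# Shared projections for brevity only (a real layer has separate ones per sublayer).
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_model)) for _ in range(3))

encoder_out = rng.standard_normal((6, d_model))  # m = 6 source positions
decoder_x = rng.standard_normal((4, d_model))    # n = 4 target positions

# Self-attention: Q, K, V all come from the decoder sequence -> (4, 4) weights.
_, self_w = attention(decoder_x @ W_Q, decoder_x @ W_K, decoder_x @ W_V)

# Cross-attention: Q from the decoder, K and V from the encoder -> (4, 6) weights.
cross_out, cross_w = attention(decoder_x @ W_Q, encoder_out @ W_K, encoder_out @ W_V)
```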

Causal (Masked) Self-Attention

Autoregressive generation requires that the prediction made at position i (the guess for token i+1) uses only tokens at positions 1 through i. If future tokens were visible during training, the model could simply copy the next token from its input. The task would become trivial, and the model would learn nothing useful for generation.

The solution is the causal mask: before applying softmax, set the attention score to −∞ for any position j > i. After softmax, e^−∞ = 0, so those positions contribute zero to the weighted sum. Concretely, this is implemented as:

    scores  = QKᵀ / √dₖ
    scores  = scores + mask      # mask[i, j] = 0 if j ≤ i, else −∞
    weights = softmax(scores)    # future positions get weight ≈ 0
    output  = weights · V

The mask is an n×n matrix with zeros on and below the diagonal (allowed positions) and −∞ above it (blocked positions). It depends only on sequence length, so it is computed once and reused across all layers. The mask is applied before softmax rather than by zeroing weights afterwards. This means each row is normalised over the allowed positions only, so the remaining weights still sum to 1. Masked positions receive exactly zero weight and therefore zero gradient through the attention weights.
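The pseudocode above can be made runnable in NumPy as follows. This sketch covers a single head with no batching. The last lines check the defining property of the mask: changing the final token leaves every earlier output unchanged. Some implementations use a large finite negative number instead of −∞, which behaves the same way after softmax.

```python
import numpy as np

def causal_attention(Q, K, V):
    n, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    # mask[i, j] = 0 where j <= i (allowed), -inf where j > i (future, blocked)
    mask = np.triu(np.full((n, n), -np.inf), k=1)
    scores = scores + mask
    # Stable softmax: each row's max is finite because the diagonal is never masked.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)   # exp(-inf) is exactly 0.0
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 4))
out, w = causal_attention(X, X, X)

# Changing the last token must not change any earlier output.
X_changed = X.copy()
X_changed[-1] += 10.0
out_changed, _ = causal_attention(X_changed, X_changed, X_changed)
```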

During inference (token-by-token generation), only the newest position's query needs to be scored against the keys of all past positions. This is why the KV cache, which stores the K and V vectors already computed for past tokens, is so important: it avoids recomputing keys and values for all past tokens on every generation step.
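Here is a minimal single-head, single-layer NumPy sketch of KV caching, with random stand-ins for trained weights. Each step projects only the newest token, appends its key and value to the cache, and scores one query against all cached keys. No mask is needed because the cache holds only past and current positions. The result matches full causal attention recomputed from scratch. Real models keep a cache per layer and per head; that bookkeeping is omitted here.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
d = 4
W_Q, W_K, W_V = (rng.standard_normal((d, d)) for _ in range(3))
X = rng.standard_normal((6, d))   # 6 tokens arriving one at a time

k_cache, v_cache, cached_out = [], [], []
for x in X:
    q = x @ W_Q                # only the new token is projected
    k_cache.append(x @ W_K)    # past keys/values are reused, not recomputed
    v_cache.append(x @ W_V)
    K = np.stack(k_cache)      # (t, d): all keys so far
    V = np.stack(v_cache)
    w = softmax(K @ q / np.sqrt(d))   # one query against all cached keys
    cached_out.append(w @ V)
cached_out = np.stack(cached_out)

# Reference: full causal attention over all 6 tokens, computed from scratch.
Q_full, K_full, V_full = X @ W_Q, X @ W_K, X @ W_V
scores = Q_full @ K_full.T / np.sqrt(d)
scores += np.triu(np.full((6, 6), -np.inf), k=1)
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)
full_out = weights @ V_full
```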

Complexity

Standard attention has two dominant costs that scale with sequence length n:

Resource   Cost        Bottleneck
Compute    O(n² · d)   QKᵀ matrix multiplication
Memory     O(n²)       Attention weight matrix must be materialised

The quadratic scaling in n is the principal limitation of transformers. For n = 8,192 tokens (a modest context window by 2025 standards), the attention matrix has about 67 million entries per head per layer. At n = 128,000, that becomes over 16 billion entries per head, which is impractical to store naively in GPU memory. Models such as Gemini 1.5 and Claude 3 support even longer contexts.
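The arithmetic behind these figures is shown below, with an added memory estimate. The estimate assumes 2 bytes per entry (fp16 or bf16) and counts a single head in a single layer.

```python
results = {}
for n in (8_192, 128_000):
    entries = n * n                  # one score per (query, key) pair
    gib_fp16 = entries * 2 / 2**30   # assumption: 2 bytes per entry
    results[n] = (entries, gib_fp16)
    print(f"n={n:,}: {entries:,} entries, {gib_fp16:.1f} GiB per head per layer")
```

At 128,000 tokens, a single head in a single layer would already need roughly 30.5 GiB just for its attention weights.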

Several approaches address this: sparse attention (attend only to a fixed pattern of positions), linear attention (approximate the softmax kernel to achieve O(n) complexity), sliding window attention (Mistral's approach), and — most importantly in practice — Flash Attention, which keeps the same mathematical result but reduces memory traffic dramatically.
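The hypothetical helper below illustrates the sliding-window idea. It builds a causal mask in which each query may attend only to itself and the previous window − 1 positions. The window of 3 is arbitrary, and the sketch does not reproduce Mistral's configuration. Building the mask densely, as here, still costs O(n²) memory. The savings come from kernels that skip the blocked regions entirely, so each query touches at most `window` keys.

```python
import numpy as np

def sliding_window_causal_mask(n, window):
    """0 where query i may attend to key j (i - window < j <= i), -inf elsewhere."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    allowed = (j <= i) & (j > i - window)
    return np.where(allowed, 0.0, -np.inf)

mask = sliding_window_causal_mask(6, window=3)
allowed_per_row = np.isfinite(mask).sum(axis=1)   # keys visible to each query
```

Early rows see fewer than `window` keys simply because fewer past positions exist. From the third row on, every query sees exactly three.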

Flash Attention

Flash Attention (Dao et al., 2022; v2 2023; v3 2024) is an IO-aware implementation of exact attention that achieves the same mathematical result as the standard algorithm while using far less memory bandwidth.

The key insight is that modern GPUs have a small, fast on-chip memory (SRAM, roughly 20 MB in total on an A100) and a large, slower off-chip memory (HBM / VRAM, 40–80 GB). Standard attention materialises the full n×n attention matrix in HBM, writing and reading it multiple times (once for QKᵀ, once for softmax, once for the output). For long sequences, this HBM traffic dominates runtime, not the arithmetic itself.

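Flash Attention tiles the computation: it processes blocks of Q, K, and V that fit in SRAM. For each row it keeps a running maximum and a running softmax denominator, and it rescales the partial output whenever the maximum changes. This "online softmax" is closely related to the log-sum-exp trick. The full n×n matrix is never written to HBM. Only the final output (shape n×dᵥ) is written out, plus one log-sum-exp value per row for the backward pass.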
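The tiling idea can be sketched in NumPy. This is not Flash Attention's CUDA kernel: it loops only over key/value blocks, runs in ordinary memory, and uses an arbitrary block size of 2. It does show the core algorithm. An online softmax keeps a running row maximum and denominator and rescales earlier partial results when the maximum grows. The sketch never holds more than one (n × block) tile of scores, yet it produces the same output as standard attention.

```python
import numpy as np

def tiled_attention(Q, K, V, block_size=2):
    """Exact attention computed one key/value block at a time (online softmax)."""
    n, d_k = Q.shape
    out = np.zeros((n, V.shape[1]))   # unnormalised numerator, per query row
    row_max = np.full(n, -np.inf)     # running max of scores per row
    row_sum = np.zeros(n)             # running softmax denominator per row
    for start in range(0, K.shape[0], block_size):
        Kb = K[start:start + block_size]
        Vb = V[start:start + block_size]
        s = Q @ Kb.T / np.sqrt(d_k)                 # scores for this tile only
        new_max = np.maximum(row_max, s.max(axis=1))
        rescale = np.exp(row_max - new_max)         # 0 on the first block
        p = np.exp(s - new_max[:, None])
        row_sum = row_sum * rescale + p.sum(axis=1)
        out = out * rescale[:, None] + p @ Vb
        row_max = new_max
    return out / row_sum[:, None]

def reference_attention(Q, K, V):
    s = Q @ K.T / np.sqrt(Q.shape[1])
    w = np.exp(s - s.max(axis=1, keepdims=True))
    return (w / w.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 8))
K = rng.standard_normal((7, 8))   # 7 keys: the last block is only partly full
V = rng.standard_normal((7, 3))
tiled = tiled_attention(Q, K, V, block_size=2)
reference = reference_attention(Q, K, V)
```

The per-row value `row_max + log(row_sum)` at the end is the log-sum-exp statistic that gets saved for the backward pass.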

Flash Attention Properties

  • Mathematically exact — not an approximation
  • 2–4× faster than naive PyTorch attention in practice
  • Memory footprint O(n) instead of O(n²)
  • Enables training on 4–8× longer sequences for the same GPU memory
  • Now widely adopted across major LLM training and inference frameworks

Flash Attention v3 (2024)

  • Optimised for Hopper architecture (H100)
  • Overlaps compute and memory operations via warp specialisation
  • Supports FP8 for further throughput gains
  • Roughly 1.5–2× faster than FA2 on H100 in FP16

Flash Attention's practical impact is large. Widely used open models such as Llama 3, Mistral, and Gemma are commonly trained and served with it, and inference libraries like vLLM and TGI use it for production serving. It is important to understand that it is an algorithmic/hardware co-design, not a mathematical approximation. You get the same outputs, up to floating-point rounding from the reordered arithmetic, with dramatically better utilisation of the GPU's memory hierarchy.

Checklist: Do You Understand This?

  • Can you explain why dot products need to be scaled by √dₖ before softmax?
  • What is the difference between a query vector and a key vector? What role does each play in the attention score?
  • If attention is O(n²) in memory, roughly how many values must be stored for n=16,384 tokens across 32 heads?
  • Why does causal masking use −∞ rather than 0 before the softmax?
  • What distinguishes self-attention from cross-attention? Give one architecture that uses each.
  • Flash Attention achieves the same result as standard attention but is faster. What does it change, and what does it preserve?