🧠 All Things AI

Backpropagation

Backpropagation is the algorithm that makes it practical to train deep neural networks. It computes the gradient of the loss function with respect to every parameter in the network, enabling gradient descent to update all weights simultaneously. Despite its reputation for complexity, backpropagation is not a special neural-network trick — it is the chain rule of calculus applied systematically to a directed acyclic computation graph. Understanding it precisely demystifies training and explains many of the practical difficulties that arise in deep networks.

What Backpropagation Does

Training a neural network requires answering a specific question: for the current set of weights, how should each weight change in order to reduce the loss? The answer is the gradient — a vector of partial derivatives, one per parameter, indicating the direction and magnitude of steepest ascent in loss. Gradient descent moves in the opposite direction.

A network with millions of parameters could, in principle, have its gradients computed numerically: perturb each weight slightly, observe the change in loss, and divide. But this requires two forward passes per parameter, each costing O(N) for a network with N parameters, so O(N²) work in total, which is completely impractical at scale. Backpropagation computes all gradients in a single backward pass through the network, at a cost within a small constant factor of the forward pass (for networks dominated by matrix multiplications, the backward pass costs roughly twice the forward pass). Backpropagation is therefore O(N) rather than O(N²), which is the difference between feasible and impossible.

Computation Graphs

The forward pass of any neural network can be represented as a directed acyclic graph (DAG). Each node represents an operation — addition, multiplication, activation function, matrix product. Each directed edge carries a value forward from one operation to the next. The final node produces the scalar loss.

Consider a simple example: z = (x + y) · w, then L = z². Writing s = x + y for the intermediate sum, the computation graph has three input nodes (x, y, w) and three operation nodes: addition (s = x + y), multiplication (z = s · w), and squaring (L = z²). Values flow from left to right during the forward pass. During the backward pass, gradients flow from right to left: from the loss back through squaring, back through multiplication, back through addition, and finally to the original inputs x, y, and w.

Forward pass (x = 2, y = 3, w = 4):
  s = x + y = 5
  z = s · w = 20
  L = z² = 400

Backward pass:
  ∂L/∂z = 2z = 40
  ∂L/∂w = ∂L/∂z · ∂z/∂w = 40 · s = 40 · 5 = 200
  ∂L/∂s = ∂L/∂z · ∂z/∂s = 40 · w = 40 · 4 = 160
  ∂L/∂x = ∂L/∂s · ∂s/∂x = 160 · 1 = 160
  ∂L/∂y = ∂L/∂s · ∂s/∂y = 160 · 1 = 160

Each node needs to know only two things: its local gradient (how its output changes with respect to its inputs) and the upstream gradient (how the loss changes with respect to its output). It multiplies these together to produce the gradient it passes further upstream. This locality is what makes the algorithm efficient — each operation only needs to implement its own backward computation without knowing anything about the rest of the network.
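To make this locality concrete, here is a minimal sketch of a scalar reverse-mode autodiff engine in plain Python. The Value class and its square method are hypothetical teaching names, not any library's API. Each node stores only its parents and its local gradients, and backward() multiplies the upstream gradient by each local gradient in reverse topological order. Run on the worked example above, it reproduces the hand-computed gradients.

```python
class Value:
    """A scalar node in a computation graph (hypothetical teaching class)."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents          # input nodes of this operation
        self._local_grads = local_grads  # d(self)/d(parent), one per parent

    def __add__(self, other):
        # d(a + b)/da = 1, d(a + b)/db = 1
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        # d(a * b)/da = b, d(a * b)/db = a
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def square(self):
        # d(a^2)/da = 2a
        return Value(self.data ** 2, (self,), (2 * self.data,))

    def backward(self):
        # Topological order: every node appears after all of its inputs.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0  # dL/dL = 1
        for node in reversed(order):
            for parent, local in zip(node._parents, node._local_grads):
                # upstream gradient x local gradient, added into the parent
                parent.grad += node.grad * local


x, y, w = Value(2.0), Value(3.0), Value(4.0)
L = ((x + y) * w).square()
L.backward()
print(L.data, w.grad, x.grad, y.grad)  # 400.0 200.0 160.0 160.0
```

Each operation only defines its own local derivatives; the backward() loop never needs to know what the rest of the graph computes.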

The Chain Rule in Computation Graphs

The chain rule states that if y = f(g(x)), then dy/dx = (dy/dg) · (dg/dx). In a computation graph with multiple paths from an intermediate node to the loss, the gradient contribution from each path is summed. This is the generalisation that handles arbitrary graph topologies, not just sequential chains.
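A minimal illustration of summing over paths, using a made-up function: with u = 2x and v = x + 1, the loss L = u · v depends on x through both u and v. The sketch applies the chain rule along each path, adds the two contributions, and compares the result with a centred finite difference. Since L = 2x² + 2x, the exact gradient is 4x + 2, which is 14 at x = 3.

```python
def forward(x):
    u = 2 * x        # path 1: x -> u -> L
    v = x + 1        # path 2: x -> v -> L
    return u, v, u * v


def grad_x(x):
    u, v, _ = forward(x)
    dL_du, dL_dv = v, u   # local gradients of L = u * v
    du_dx, dv_dx = 2, 1   # local gradients of u = 2x and v = x + 1
    # Chain rule along each path, then sum the two contributions.
    return dL_du * du_dx + dL_dv * dv_dx


x = 3.0
eps = 1e-6
numeric = (forward(x + eps)[2] - forward(x - eps)[2]) / (2 * eps)
print(grad_x(x), numeric)  # analytical value is 14.0
```

Dropping either term of the sum would give 8 or 6 instead of 14, which the finite difference would immediately expose.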

For a concrete neural network example, consider a single linear layer followed by a squared loss. Let y = Wx + b, L = ‖y − t‖² where t is the target.

Forward pass:
  y = W · x + b        (linear transformation)
  e = y − t            (error)
  L = eᵀe              (squared loss)

Backward pass:
  ∂L/∂e = 2e           (gradient of squared loss)
  ∂L/∂y = ∂L/∂e · 1 = 2e
  ∂L/∂W = ∂L/∂y · xᵀ = 2e · xᵀ   (outer product)
  ∂L/∂b = ∂L/∂y = 2e
  ∂L/∂x = Wᵀ · ∂L/∂y = 2Wᵀe

The gradient with respect to W is an outer product of the upstream gradient and the input x. This pattern — the gradient of a weight matrix is the outer product of the upstream gradient and the layer's input — appears throughout deep learning and is worth memorising. It explains why the forward pass activations must be stored in memory: the backward pass needs them to compute weight gradients.
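The linear-layer formulas can be checked in a few lines of plain Python. This sketch uses hypothetical helper names (forward, backward), plain lists instead of a tensor library, and made-up values. It computes ∂L/∂W as the outer product 2e xᵀ, along with ∂L/∂b and ∂L/∂x, and spot-checks one weight gradient against a centred finite difference.

```python
def forward(W, b, x, t):
    """y = Wx + b, e = y - t, L = e^T e. Returns (e, L)."""
    y = [sum(W[i][j] * x[j] for j in range(len(x))) + b[i] for i in range(len(W))]
    e = [yi - ti for yi, ti in zip(y, t)]
    return e, sum(ei * ei for ei in e)


def backward(W, x, e):
    """Returns (dL/dW, dL/db, dL/dx) for the linear layer with squared loss."""
    dy = [2 * ei for ei in e]                          # dL/dy = 2e
    dW = [[dy_i * x_j for x_j in x] for dy_i in dy]    # outer product 2e x^T
    db = list(dy)                                      # dL/db = 2e
    dx = [sum(W[i][j] * dy[i] for i in range(len(dy)))  # W^T (2e)
          for j in range(len(x))]
    return dW, db, dx


W = [[1.0, 2.0, 0.5], [-1.0, 0.0, 3.0]]
b = [0.1, -0.2]
x = [1.0, -2.0, 0.5]
t = [0.0, 1.0]

e, L = forward(W, b, x, t)
dW, db, dx = backward(W, x, e)

# Spot-check dL/dW[0][1] with a centred finite difference.
eps = 1e-6
Wp = [row[:] for row in W]
Wp[0][1] += eps
Wm = [row[:] for row in W]
Wm[0][1] -= eps
numeric = (forward(Wp, b, x, t)[1] - forward(Wm, b, x, t)[1]) / (2 * eps)
print(dW[0][1], numeric)
```

Note that backward() needs x, the layer's input, to form dW: this is the reason forward activations have to be kept until the backward pass.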

Vanishing and Exploding Gradients

In a deep network with n layers, the gradient of the loss with respect to the parameters in the first layer involves a product of roughly n Jacobian matrices, one per layer, accumulated via the chain rule. This product can behave badly in two directions.

Vanishing Gradients

If the Jacobian matrices have singular values consistently below 1, the product of many such matrices shrinks exponentially toward zero. Sigmoid and tanh activations push in exactly this direction: the sigmoid's derivative σ′(z) = σ(z)(1 − σ(z)) is at most 0.25, and both functions saturate for large |z|, where their derivatives are close to zero. Early layers receive essentially no gradient signal, their weights stop updating, and the network effectively trains only its final few layers.
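A rough numerical illustration, assuming a deep chain of scalar sigmoid units with all weights fixed at 1 so that only the activation derivatives matter: the gradient reaching the first layer is then the product of the per-layer derivatives σ′(z). Even in the best case (every z = 0, where σ′ = 0.25), twenty layers shrink the signal by 0.25²⁰ ≈ 9 × 10⁻¹³; saturated units make it far smaller.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_grad(z):
    s = sigmoid(z)
    return s * (1.0 - s)  # largest at z = 0, where it equals 0.25


depth = 20
best_case = sigmoid_grad(0.0) ** depth  # every unit at z = 0
saturated = sigmoid_grad(5.0) ** depth  # every unit saturated at z = 5
print(best_case, saturated)
```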

Exploding Gradients

If the Jacobian matrices have singular values consistently above 1, as can happen with large initial weights, the product grows exponentially. Gradients become enormous, weight updates overshoot dramatically, and the loss diverges (often to inf or NaN). The problem is most common in recurrent networks, where backpropagation through time multiplies by the same weight matrix at every step of a long sequence.
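Both failure modes appear in the simplest possible recurrence. This sketch assumes a hypothetical scalar linear RNN, h_t = w · h_{t−1}, with no nonlinearity. Backpropagating from h_T to h_0 multiplies by the same local derivative w at every step, so the gradient is w^T. Real RNNs use weight matrices and nonlinearities, where the role of |w| is played by the Jacobians' singular values.

```python
def recurrent_gradient(w, steps):
    """d h_T / d h_0 for the scalar recurrence h_t = w * h_{t-1}.

    Backpropagation through time multiplies by the same local
    derivative w at every step, so the result is w ** steps.
    """
    grad = 1.0
    for _ in range(steps):
        grad *= w  # one local gradient per time step
    return grad


for w in (0.9, 1.0, 1.1):
    print(w, recurrent_gradient(w, 100))
```

A 10% deviation from 1 in either direction is enough: over 100 steps the gradient collapses below 10⁻⁴ or grows past 10⁴.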

These problems are a major reason deep networks were so hard to train before roughly 2010: networks more than a few layers deep often failed to train effectively with the tools available at the time. The solution required several advances working together.

Solutions to Gradient Problems

Three techniques largely resolved vanishing and exploding gradients for different network types:

ReLU Activation

ReLU(x) = max(0, x) has a gradient of exactly 1 for all positive inputs. This means the gradient does not shrink as it passes through a ReLU node (when that neuron is active). Replacing sigmoid/tanh hidden activations with ReLU was one of the key practical steps that made training deep networks feasible.
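A quick comparison, assuming a scalar chain with unit weights and every pre-activation fixed at a hypothetical z = 0.5 so that each ReLU unit is active: the ReLU chain passes the gradient through unchanged, while the sigmoid chain shrinks it by roughly 0.235 per layer. The ReLU derivative is 0 for negative inputs, so an inactive unit blocks the gradient on that path entirely.

```python
import math


def relu_grad(z):
    # Derivative of max(0, z); by convention 0 at z = 0.
    return 1.0 if z > 0 else 0.0


def sigmoid_grad(z):
    s = 1.0 / (1.0 + math.exp(-z))
    return s * (1.0 - s)


depth = 20
pre_activations = [0.5] * depth  # hypothetical: every unit active (z > 0)

relu_chain = 1.0
sigmoid_chain = 1.0
for z in pre_activations:
    relu_chain *= relu_grad(z)
    sigmoid_chain *= sigmoid_grad(z)

print(relu_chain, sigmoid_chain)
```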

Residual Connections

Introduced in ResNet (He et al., 2015), residual connections add the input of a block directly to its output: y = F(x) + x. During backpropagation, this creates a gradient highway: ∂L/∂x = ∂L/∂y · (∂F/∂x + I). Because of the identity term I, the upstream gradient ∂L/∂y reaches x unchanged even when ∂F/∂x is small, so gradient flows from late layers back to early ones regardless of how many residual blocks lie in between. This is one reason transformers and modern LLMs use residual connections throughout.
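The effect of the identity term can be seen with scalar blocks. In this sketch each hypothetical block computes F(x) = a · x with a small a, standing in for a block whose Jacobian is small. Without the skip connection the gradient through 20 blocks is a²⁰; with it, each block contributes a + 1, so the gradient stays close to 1, and is exactly 1 when a = 0.

```python
def stacked_gradient(a, depth, residual):
    """dL/dx_0 through `depth` scalar blocks, given dL/dx_depth = 1.

    Each hypothetical block computes F(x) = a * x, so dF/dx = a.
    Plain block:    y = F(x)      -> local gradient a
    Residual block: y = F(x) + x  -> local gradient a + 1 (the identity term)
    """
    local = a + 1.0 if residual else a
    grad = 1.0
    for _ in range(depth):
        grad *= local
    return grad


print(stacked_gradient(0.01, 20, residual=False))  # about 1e-40
print(stacked_gradient(0.01, 20, residual=True))   # about 1.22
```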

Gradient Clipping

Before applying the weight update, compute the global L2 norm of all gradients. If it exceeds a threshold, scale all gradients down so the norm equals the threshold. This prevents exploding gradients while preserving the gradient's direction, so the learning signal is kept rather than discarded. It is standard in RNN training and is still applied in large LLM training runs as a safety measure.
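A minimal sketch of clipping by global norm in plain Python, with gradients represented as a list of per-parameter lists (an assumption made to keep the example dependency-free). PyTorch provides torch.nn.utils.clip_grad_norm_ for the same purpose on real parameters; its implementation differs in small details such as a numerical-stability epsilon.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale gradient vectors so their combined L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm before clipping.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    if total_norm <= max_norm:
        return grads, total_norm  # already within the threshold: unchanged
    scale = max_norm / total_norm  # one shared factor keeps the direction
    return [[g * scale for g in vec] for vec in grads], total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # global norm = 5
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm, clipped)
```

Using a single scale factor for all parameters, rather than clipping each one separately, is what preserves the overall update direction.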

Practical Backpropagation: Automatic Differentiation

Modern deep learning frameworks implement backpropagation via automatic differentiation (autograd). You write the forward pass as normal Python code. The framework records the operations performed — building the computation graph dynamically — and provides a mechanism to reverse through that graph and accumulate gradients.

```python
# PyTorch autograd example
import torch

x = torch.tensor([2.0], requires_grad=True)
W = torch.tensor([[3.0]], requires_grad=True)

y = W @ x           # forward: matrix-vector product, y = [6.]
L = (y ** 2).sum()  # forward: squared loss, L = 36
L.backward()        # backward: compute all gradients

print(W.grad)  # ∂L/∂W = 2y xᵀ, prints tensor([[24.]])
print(x.grad)  # ∂L/∂x = 2Wᵀy, prints tensor([36.])
```

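The requires_grad=True flag tells the framework to track operations involving this tensor. After calling .backward(), gradients accumulate in the .grad attribute of every leaf tensor that requires a gradient. "Accumulate" is literal: each call adds to any gradient already stored, which is why training loops call optimizer.zero_grad() before each backward pass. Calling optimizer.step() then updates the weights using those stored gradients.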
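A sketch of the usual PyTorch training step, plus a direct check that .grad accumulates. The model, data, learning rate, and step count are arbitrary choices for illustration; the calls themselves (zero_grad, backward, step) are the standard torch.optim workflow.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 3)  # made-up batch of 8 inputs
t = torch.randn(8, 1)  # made-up targets

losses = []
for step in range(5):
    optimizer.zero_grad()                # clear gradients left over from the last step
    loss = ((model(x) - t) ** 2).mean()  # forward pass: mean squared error
    loss.backward()                      # backward pass: fills .grad on each parameter
    optimizer.step()                     # SGD update: p <- p - lr * p.grad
    losses.append(loss.item())
print(losses)

# .grad accumulates: two backward passes without zero_grad() add up.
w = torch.tensor([3.0], requires_grad=True)
(w ** 2).sum().backward()
first = w.grad.clone()  # d(w^2)/dw = 2w = 6
(w ** 2).sum().backward()
print(first, w.grad)    # tensor([6.]) tensor([12.])
```

Forgetting zero_grad() in a loop silently sums gradients across steps, which behaves like an ever-growing learning rate.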

An important memory implication: all intermediate activations computed during the forward pass must be retained until the backward pass completes, because weight gradients depend on each layer's input (the outer product pattern described earlier). This is why training requires substantially more memory than inference, where each activation can be discarded as soon as the next layer has consumed it. For very long sequences or very deep networks, gradient checkpointing trades compute for memory: only selected activations (the checkpoints) are kept during the forward pass, and the rest are recomputed on demand from the nearest checkpoint during the backward pass. This costs roughly one extra forward pass per training step; with a checkpoint every √n layers of an n-layer network, stored-activation memory falls from O(n) to O(√n).
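Gradient checkpointing can be demonstrated on a toy network without any framework. This sketch assumes a hypothetical chain of scalar layers h_{i+1} = tanh(w_i · h_i) with loss L = h_n. The full version stores every activation; the checkpointed version keeps only every k-th one and recomputes each segment's activations during the backward pass. With n = 16 and k = 4 = √n, the gradient is identical while the number of activations held at once drops from 17 to 9. In PyTorch, the same idea is exposed through torch.utils.checkpoint.

```python
import math


def layer(w, h):
    return math.tanh(w * h)


def layer_grad(w, h):
    # d/dh tanh(w * h) = w * (1 - tanh(w * h)**2): needs the layer's input h.
    return w * (1.0 - math.tanh(w * h) ** 2)


def grad_full(ws, h0):
    """Keep every activation. Returns (dL/dh0, activations stored), with L = h_n."""
    acts = [h0]
    for w in ws:
        acts.append(layer(w, acts[-1]))
    grad = 1.0
    for i in reversed(range(len(ws))):
        grad *= layer_grad(ws[i], acts[i])
    return grad, len(acts)


def grad_checkpointed(ws, h0, k):
    """Keep every k-th activation; recompute the rest segment by segment."""
    n = len(ws)
    checkpoints = {0: h0}
    h = h0
    for i, w in enumerate(ws):
        h = layer(w, h)
        if (i + 1) % k == 0:
            checkpoints[i + 1] = h
    grad, peak = 1.0, len(checkpoints)
    for start in reversed(range(0, n, k)):  # last segment first
        end = min(start + k, n)
        seg = [checkpoints[start]]  # recompute inputs of layers start..end-1
        for i in range(start, end - 1):
            seg.append(layer(ws[i], seg[-1]))
        peak = max(peak, len(checkpoints) + len(seg))
        for i in reversed(range(start, end)):
            grad *= layer_grad(ws[i], seg[i - start])
    return grad, peak


ws = [0.9 + 0.01 * i for i in range(16)]  # made-up weights for a 16-layer chain
full, stored_full = grad_full(ws, 0.5)
ckpt, stored_ckpt = grad_checkpointed(ws, 0.5, k=4)
print(full == ckpt, stored_full, stored_ckpt)  # True 17 9
```

The recomputation inside each segment is the extra forward work; the savings come from never holding more than one segment's activations alongside the checkpoints.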

Numerical Gradient Checking

When implementing a custom layer or loss function, bugs in the backward pass are common and easy to miss — the forward pass runs, numbers are produced, but the gradients are silently wrong. Numerical gradient checking is the standard debugging technique.

Numerical gradient approximation (centred finite differences):
  ∂L/∂θᵢ ≈ (L(θ + εeᵢ) − L(θ − εeᵢ)) / (2ε)
  where eᵢ is the unit vector in direction i and ε ≈ 1e-5

Check:
  |analytical − numerical| / max(|analytical|, |numerical|) < 1e-5

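For each parameter θᵢ, run the forward pass twice with a small perturbation in opposite directions and approximate the gradient by the centred finite difference. Compare this to the analytical gradient produced by your backward implementation. If the relative error is below the threshold, the backward pass is very likely correct; if it is much larger, there is a bug. Run the check in double precision (float64): in float32, rounding error in the difference alone typically exceeds a 1e-5 tolerance.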
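The procedure can be written as a small, framework-free gradient checker. The function names and the test function L(θ) = θ₀² θ₁ + sin θ₁ are hypothetical; Python floats are double precision, which the check needs. A correct analytical gradient passes, while one with a dropped term fails.

```python
import math


def numerical_gradient(f, theta, eps=1e-5):
    """Centred finite differences: two forward passes per parameter."""
    grad = []
    for i in range(len(theta)):
        plus, minus = list(theta), list(theta)
        plus[i] += eps
        minus[i] -= eps
        grad.append((f(plus) - f(minus)) / (2 * eps))
    return grad


def relative_error(a, b):
    return abs(a - b) / max(abs(a), abs(b), 1e-12)  # floor avoids 0/0


def grad_check(f, analytic_grad, theta, eps=1e-5, tol=1e-5):
    numeric = numerical_gradient(f, theta, eps)
    analytic = analytic_grad(theta)
    return all(relative_error(a, n) < tol for a, n in zip(analytic, numeric))


def f(th):  # hypothetical test function: L = th0^2 * th1 + sin(th1)
    return th[0] ** 2 * th[1] + math.sin(th[1])


def correct_grad(th):
    return [2 * th[0] * th[1], th[0] ** 2 + math.cos(th[1])]


def buggy_grad(th):  # "forgot" the derivative of sin(th1)
    return [2 * th[0] * th[1], th[0] ** 2]


theta = [1.5, -0.7]
print(grad_check(f, correct_grad, theta), grad_check(f, buggy_grad, theta))  # True False
```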

The cost is 2N forward passes, each O(N), where N is the number of parameters: completely impractical for a network with millions of weights, but perfectly acceptable for verifying a small test case. The standard workflow is to write a minimal version of your layer with a handful of parameters, verify that its gradients match numerically, then deploy the full-scale version with confidence.

Checklist: Do You Understand This?

  • Can you explain why backpropagation is O(N) in the number of parameters rather than O(N²), and why this matters for practical training?
  • Can you draw the computation graph for a small expression like L = (Wx + b)² and derive ∂L/∂W by hand using the chain rule?
  • Can you explain what causes vanishing gradients in sigmoid networks and describe the specific mathematical property that produces the problem?
  • Can you explain what a residual connection does to gradient flow and write the backpropagation equation through a residual block?
  • Can you describe what requires_grad=True and .backward() do in PyTorch and explain what memory cost they impose?
  • Can you describe the numerical gradient check procedure and explain when you would use it and why it cannot be used on full-scale networks?