Backprop as Graph Transformation
Estimated reading time: 20 minutes
Build the mental models that separate research engineers from ML practitioners.
In this tutorial, you will trace gradient flow through a computation graph, implement a mini-autograd engine, and diagnose common gradient bugs (accumulation leaks, accidental detachment, missing requires_grad).
By the end you will be able to:

- Trace gradient flow through a computation graph
- Implement a mini-autograd engine
- Diagnose why `.grad` is unexpectedly `None` or stale

During the forward pass, PyTorch builds a directed acyclic graph (DAG):
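You can see this DAG directly by walking the `grad_fn` chain of a result tensor. A short sketch, assuming a PyTorch install (the specific tensor values here are arbitrary):

```python
import torch

# Build a tiny graph: y = sum(relu(x @ W))
x = torch.tensor([[1.0, -2.0]], requires_grad=True)
W = torch.tensor([[0.5], [0.25]], requires_grad=True)
y = (x @ W).relu().sum()

# Each non-leaf tensor points at the backward node that created it;
# next_functions links each node to its inputs' nodes, forming the DAG.
node = y.grad_fn
while node is not None:
    print(type(node).__name__)  # e.g. SumBackward0, ReluBackward0, ..., AccumulateGrad
    node = node.next_functions[0][0] if node.next_functions else None
```

The chain ends at `AccumulateGrad` nodes, which are the leaves that deposit gradients into `.grad`.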
Let's implement a complete mini-autograd to see how it all fits together:
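A minimal sketch of such an engine in pure Python, scalar-valued in the spirit of micrograd (the `Value` class and its method names are illustrative, not any library's API):

```python
class Value:
    """A scalar with autograd: records the ops that produced it (a DAG)."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents          # edges of the computation graph
        self._backward_fn = None         # how to push grad to parents

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other))
        def backward_fn():
            self.grad += out.grad        # d(a+b)/da = 1
            other.grad += out.grad       # d(a+b)/db = 1
        out._backward_fn = backward_fn
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other))
        def backward_fn():
            self.grad += other.data * out.grad   # d(a*b)/da = b
            other.grad += self.data * out.grad   # d(a*b)/db = a
        out._backward_fn = backward_fn
        return out

    def relu(self):
        out = Value(max(0.0, self.data), (self,))
        def backward_fn():
            self.grad += (out.data > 0) * out.grad
        out._backward_fn = backward_fn
        return out

    def backward(self):
        # Topologically sort the graph, then apply the chain rule in reverse.
        topo, seen = [], set()
        def visit(v):
            if v not in seen:
                seen.add(v)
                for p in v._parents:
                    visit(p)
                topo.append(v)
        visit(self)
        self.grad = 1.0                  # dL/dL = 1
        for v in reversed(topo):
            if v._backward_fn:
                v._backward_fn()
```

Note that each `backward_fn` uses `+=`, mirroring PyTorch's accumulation semantics: a node used twice in the graph receives the sum of both gradient paths.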
| Scale | What Breaks | Mitigation |
|---|---|---|
| Small models | Nothing; graph memory is negligible | Standard autograd |
| Large models | Activation memory for backward | Gradient checkpointing |
| Very long sequences | Quadratic attention activations | FlashAttention, chunked backward |
| Multi-GPU | Graph needs to sync across devices | Distributed autograd, tensor parallelism |
Gradient Checkpointing: Instead of storing all activations, recompute them during backward:
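In PyTorch this is exposed as `torch.utils.checkpoint.checkpoint`; the trade is one extra forward pass in exchange for not holding a segment's activations in memory. The mechanics can be sketched without PyTorch on a chain of scalar functions (all function names below are illustrative):

```python
import math

def forward_with_activations(x, ops):
    """Standard autograd behavior: store every intermediate activation."""
    acts = [x]
    for f, _ in ops:
        acts.append(f(acts[-1]))
    return acts                          # O(depth) memory held until backward

def forward_checkpointed(x, ops):
    """Checkpointed forward: keep only the input, discard intermediates."""
    y = x
    for f, _ in ops:
        y = f(y)
    return y                             # O(1) extra memory

def backward_checkpointed(x, ops, grad_out=1.0):
    # Pay one extra forward pass to rebuild the activations we dropped...
    acts = forward_with_activations(x, ops)
    g = grad_out
    # ...then run the ordinary chain rule from the output back to x.
    for (f, df), a in zip(reversed(ops), reversed(acts[:-1])):
        g = g * df(a)
    return g
```

For example, with `ops = [(v*v, 2v), (sin, cos)]` the checkpointed backward recomputes `x*x` before applying `cos(x*x) * 2x`, exactly matching the stored-activation result.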
Debugging Tools:

- `tensor.register_hook(fn)` — inspect gradients during backward
- `torch.autograd.grad()` — compute gradients without `.backward()`
- `torch.autograd.set_detect_anomaly(True)` — find NaN/Inf sources

Diagnosis checklist when a parameter is not learning:

- `param.requires_grad` is `True`
- `param.grad` is not `None` after `.backward()`
- No `.detach()` or `.data` access on any tensor upstream of the parameter
- `optimizer.zero_grad()` is called before each backward pass

Each question requires tracing, calculating, or diagnosing — not just recall.
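The `zero_grad` item in the checklist comes down to one fact: `.backward()` accumulates into `.grad` rather than overwriting it. A pure-Python sketch of that accumulation semantics (the `Param` class here is illustrative, not the torch API):

```python
class Param:
    """Toy parameter mimicking PyTorch's gradient-accumulation semantics."""

    def __init__(self, value):
        self.value = value
        self.grad = None                 # like torch, grads start out as None

    def accumulate(self, g):
        # backward() ADDS into .grad; it never overwrites
        if self.grad is None:
            self.grad = list(g)
        else:
            self.grad = [a + b for a, b in zip(self.grad, g)]

    def zero_grad(self):
        self.grad = None                 # what optimizer.zero_grad() restores

p = Param([1.0, 2.0])
p.accumulate([0.25, 0.75])               # first backward pass
p.accumulate([0.50, 0.25])               # second backward, no zero_grad
print(p.grad)                            # [0.75, 1.0]: the sum, not the last value
```

Skipping `zero_grad` therefore scales the effective step by however many backward passes have accumulated, which is also the mechanism behind deliberate gradient accumulation for large effective batch sizes.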
1. For `loss = sum(relu(x @ W1) @ W2)`, list every tensor that receives a gradient during `.backward()`. If you insert `h = h.detach()` after the `relu`, which tensors lose their gradients?
2. You call `.backward()` twice without zeroing gradients. If the true gradient for step 1 is `[0.5, -0.3]` and for step 2 is `[0.1, 0.2]`, what value does `.grad` hold after both calls? What effective learning rate multiplier does this create?

Papers:
Open Questions:
Next up: Initialization determines whether training starts or stalls. We'll derive why random scale must depend on layer width.