
Track 0: Foundations

Build the mental models that separate research engineers from ML practitioners.

Memory & Compute

  • The Memory Wall (15m)
  • Gradient Flow Under Pressure (18m)

Optimizers

  • SGD & Momentum (15m)
  • Adam, Warmup & Scheduling (18m)

Gradient Mechanics

  • Backprop as Graph Transformation (20m)
  • Initialization & Residual Connections (18m)
  • Scaling Laws & μ-Transfer (20m)

Systems Thinking

  • Bandwidth & Profiling (18m)
  • The Debugging Flowchart (22m)




Backprop as Graph Transformation

Estimated reading time: 20 minutes


In this tutorial, you will trace gradient flow through a computation graph, implement a mini-autograd engine, and diagnose common gradient bugs (accumulation leaks, accidental detachment, missing requires_grad).

By the end you will be able to:

  • Predict which parameters receive gradients in a given graph and which do not
  • Verify autograd correctness using finite-difference gradient checks
  • Identify the root cause when a parameter's .grad is unexpectedly None or stale
💡

Core Idea

Every tensor knows two things: how it was created (its grad_fn) and who its parents are (the input tensors). The backward pass is a reverse traversal of this DAG, accumulating gradients with += at each node.

The Computation Graph#

During the forward pass, PyTorch builds a directed acyclic graph (DAG) recording how each tensor was produced.

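A tiny sketch of the graph PyTorch records (illustrative; the tensor values are arbitrary):

```python
import torch

# Each operation in the forward pass records a node in the autograd DAG.
x = torch.randn(3, requires_grad=True)
y = x * 2        # records grad_fn=<MulBackward0>
z = y.sum()      # records grad_fn=<SumBackward0>

# Leaf tensors have no grad_fn; intermediate results do.
print(x.grad_fn)                  # None (leaf)
print(y.grad_fn)                  # <MulBackward0 ...>
print(z.grad_fn.next_functions)   # parents in the reverse graph
```

`next_functions` is how the backward pass walks from a node to its inputs: the DAG is traversed in reverse topological order during `.backward()`.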

Gradient Accumulation: += Not =#

⚠️

Accumulation, Not Overwrite

.backward() adds to .grad; it does not overwrite it.

If you forget to zero gradients, you'll have gradients from previous batches contaminating current gradients. The model will learn garbage.

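A minimal sketch of this behavior (an illustrative stand-in for gradient_accumulation.py):

```python
import torch

w = torch.tensor([1.0], requires_grad=True)

# First backward: grad = d(3w)/dw = 3
(3 * w).sum().backward()
print(w.grad)   # tensor([3.])

# Second backward WITHOUT zeroing: .grad accumulates with +=
(3 * w).sum().backward()
print(w.grad)   # tensor([6.])  -- the stale gradient contaminates the update

# The fix: zero before each backward (what optimizer.zero_grad() does)
w.grad.zero_()
(3 * w).sum().backward()
print(w.grad)   # tensor([3.])
```

The second call silently doubles the gradient, which is exactly the "effective learning rate multiplier" failure mode discussed in the checkpoint questions below.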

Detachment: Breaking the Graph#

💡

When to Detach

Use .detach() when you want the value of a tensor but not its gradient connection.

Common use cases:

  • REINFORCE reward baseline (shouldn't affect policy gradients)
  • Target networks in RL (shouldn't backprop through)
  • Logging/monitoring (don't want side effects)

Step 1: Normal flow. Without detachment, gradients flow through the entire graph, and both w1 and w2 receive gradients.
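A sketch of the effect (an illustrative stand-in for detachment_demo.py; shapes are arbitrary):

```python
import torch

x = torch.randn(4)
w1 = torch.randn(4, 4, requires_grad=True)
w2 = torch.randn(4, requires_grad=True)

# Without detach: gradients flow to both w1 and w2
h = x @ w1
(h * w2).sum().backward()
print(w1.grad is None, w2.grad is None)   # False False

# With detach: the graph is cut at h, so w1 gets no gradient
w1.grad = None
w2.grad = None
h = (x @ w1).detach()        # value kept, gradient connection severed
(h * w2).sum().backward()
print(w1.grad)   # None -- w1 is upstream of the cut, it stops learning
print(w2.grad)   # tensor([...]) -- w2 is downstream, still learns
```

Note that no error is raised in the second case; the only symptom is the silently missing gradient.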

Build Your Own Autograd#

Let's implement a complete mini-autograd to see how it all fits together:

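One possible implementation, in the spirit of micrograd (a scalar-valued sketch standing in for mini_autograd.py; class and method names are our own):

```python
# A minimal scalar autograd engine: each Value records its parents and a
# closure that propagates gradients backward, mirroring PyTorch's grad_fn.
class Value:
    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward = lambda: None

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other))
        def _backward():
            # d(a+b)/da = d(a+b)/db = 1; accumulate with += like PyTorch
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other))
        def _backward():
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward
        return out

    def relu(self):
        out = Value(max(0.0, self.data), (self,))
        def _backward():
            self.grad += (out.data > 0) * out.grad
        out._backward = _backward
        return out

    def backward(self):
        # Topologically sort the DAG, then replay each node in reverse.
        topo, visited = [], set()
        def build(v):
            if v not in visited:
                visited.add(v)
                for p in v._parents:
                    build(p)
                topo.append(v)
        build(self)
        self.grad = 1.0
        for v in reversed(topo):
            v._backward()

# Example: loss = 2 * relu(x * w); dloss/dx = 2*w = 4, dloss/dw = 2*x = 6
x, w = Value(3.0), Value(2.0)
loss = (x * w).relu() * 2
loss.backward()
print(x.grad, w.grad)   # 4.0 6.0
```

The two details worth internalizing: gradients accumulate with `+=` at every node (the same rule behind the `zero_grad()` bug above), and `backward()` is just a reverse topological traversal of the DAG built during the forward pass.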

Scale Thought Experiment#

| Scale | What breaks | Mitigation |
| --- | --- | --- |
| Small models | Graph memory is negligible | Standard autograd |
| Large models | Activation memory for backward | Gradient checkpointing |
| Very long sequences | Quadratic attention activations | FlashAttention, chunked backward |
| Multi-GPU | Graph needs to sync across devices | Distributed autograd, tensor parallelism |

Production Reality#

Gradient Checkpointing: Instead of storing all activations, recompute them during backward:

  • 30-40% memory savings
  • 20-30% compute overhead
  • Used in all large model training (Megatron-LM, DeepSpeed, FSDP)
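In PyTorch this is one call to `torch.utils.checkpoint.checkpoint` (a minimal sketch; layer sizes are arbitrary):

```python
import torch
from torch.utils.checkpoint import checkpoint

# A block whose activations we'd rather recompute than store.
layer = torch.nn.Sequential(
    torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64)
)

x = torch.randn(8, 64, requires_grad=True)

# Instead of layer(x), wrap the call: activations inside `layer` are not
# saved; they are recomputed during .backward() from the saved input x.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)   # torch.Size([8, 64])
```

Gradients come out identical to the unwrapped call; only the memory/compute trade-off changes.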

Debugging Tools:

  • tensor.register_hook(fn) — Inspect gradients during backward
  • torch.autograd.grad() — Compute gradients without .backward()
  • torch.autograd.set_detect_anomaly(True) — Find NaN/Inf sources
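A short demo of all three tools together (values are arbitrary; the hook here just records the gradient it sees):

```python
import torch

w = torch.randn(3, requires_grad=True)

# register_hook: observe the gradient as it flows backward through w
seen = {}
w.register_hook(lambda g: seen.update(grad=g))

loss = (w * 2).sum()

# torch.autograd.grad: compute gradients without populating .grad
(g,) = torch.autograd.grad(loss, w, retain_graph=True)
print(g)        # tensor([2., 2., 2.])
print(w.grad)   # None -- .grad was never touched

# set_detect_anomaly: adds forward-trace info to backward errors,
# so a NaN/Inf produced in backward points at the op that made it
with torch.autograd.set_detect_anomaly(True):
    loss.backward()
print(seen["grad"])   # the hook fired and captured the gradient
```

`set_detect_anomaly` is slow, so enable it only while hunting a bug, not in production training loops.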

Break It: Gradient Flow Bugs#

❌

Common Mistakes

  1. Stale gradients from missing zero_grad() — After two .backward() calls without zeroing, .grad contains the sum of both batches. The model oscillates or diverges because updates use contaminated gradients.

  2. Accidental .detach() or .data access — Using tensor.data bypasses autograd silently. The parameter stops learning but no error is raised. Check by asserting .grad is not None after a backward pass.

  3. Missing requires_grad=True — A plain tensor defaults to requires_grad=False; only wrapping it in nn.Parameter (or setting requires_grad=True explicitly) opts it into autograd. The optimizer then sees no gradient for it, and the parameter never updates.

Diagnosis checklist when a parameter is not learning:

  1. Verify param.requires_grad is True
  2. Run one forward + backward, then check param.grad is not None
  3. Search for .detach() or .data on any tensor upstream of the parameter
  4. Check that optimizer.zero_grad() is called before each backward
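The checklist above can be scripted as a one-shot sanity check (a sketch; `check_params_learning` is a hypothetical helper, not a PyTorch API):

```python
import torch

def check_params_learning(model, loss_fn, batch):
    """Run one forward/backward and flag parameters that won't learn."""
    problems = []
    # Checklist item 1: requires_grad must be True
    for name, p in model.named_parameters():
        if not p.requires_grad:
            problems.append(f"{name}: requires_grad is False")
    # Checklist item 4: start from clean gradients
    model.zero_grad()
    loss_fn(model, batch).backward()
    # Checklist item 2: every trainable param should now have a gradient;
    # if not, suspect .detach()/.data somewhere upstream (item 3)
    for name, p in model.named_parameters():
        if p.requires_grad and p.grad is None:
            problems.append(f"{name}: .grad is None after backward")
    return problems

model = torch.nn.Linear(4, 2)
batch = torch.randn(8, 4)
print(check_params_learning(model, lambda m, b: m(b).sum(), batch))  # []
```

An empty list means gradients are at least arriving; it does not rule out wrong gradients, only missing ones.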

Checkpoint Questions#

Each question requires tracing, calculating, or diagnosing — not just recall.

  1. Given the graph loss = sum(relu(x @ W1) @ W2), list every tensor that receives a gradient during .backward(). If you insert h = h.detach() after the relu, which tensors lose their gradients?
  2. You call .backward() twice without zeroing gradients. If the true gradient for step 1 is [0.5, -0.3] and for step 2 is [0.1, 0.2], what value does .grad hold after both calls? What effective learning rate multiplier does this create?
  3. A 32-layer transformer uses gradient checkpointing on every other layer. Estimate the activation memory savings as a fraction of the no-checkpointing baseline, and the approximate compute overhead percentage.

Research Hooks#

Papers:

  1. "Training Deep Nets with Sublinear Memory Cost" (Chen et al., 2016) — The original gradient checkpointing paper
  2. "Automatic Differentiation in Machine Learning: a Survey" (Baydin et al., 2018) — Comprehensive overview of autograd techniques

Open Questions:

  • Can we learn optimal checkpointing strategies for a given model architecture?
  • How do we efficiently compute second-order gradients (Hessians) for large models?

Next up: Initialization determines whether training starts or stalls. We'll derive why random scale must depend on layer width.