Gradient Flow Under Pressure
Estimated reading time: 18 minutes
Build the mental models that separate research engineers from ML practitioners.
In this tutorial, you will trace how gradients break when stored in 16-bit floats, build a loss scaler that prevents the breakage, and learn when to reach for BF16 instead.
Mixed-precision training is how every large model ships today. The failure mode is simple: small gradients underflow to exactly zero in FP16, and layers stop learning. By the end of this tutorial you will be able to:
- Trace exactly where gradients underflow when stored in 16-bit floats
- Build a loss scaler that prevents the breakage
- Decide when to reach for BF16 instead
Every floating-point number is stored as three fields: a sign bit, exponent bits, and mantissa (fraction) bits.
Here are the concrete numbers you need to remember:
| Format | Exponent bits | Mantissa bits | Smallest positive normal | Largest value |
|---|---|---|---|---|
| FP32 | 8 | 23 | ~1.2 x 10⁻³⁸ | ~3.4 x 10³⁸ |
| FP16 | 5 | 10 | ~6.1 x 10⁻⁵ (normal), ~6 x 10⁻⁸ (subnormal) | 65,504 |
| BF16 | 8 | 7 | ~1.2 x 10⁻³⁸ | ~3.4 x 10³⁸ |
The FP16 floor (~6 x 10⁻⁸ including subnormals) is where gradients go to die. Compare that with BF16's floor (~1.2 x 10⁻³⁸), which matches FP32 and is effectively never a problem during training.
During backpropagation, the gradient magnitude at layer k is governed by the product of per-layer Jacobian norms from the output back to layer k. If each layer attenuates the gradient by a factor a (typically 0.7-0.9 for well-initialized networks), the gradient magnitude k layers from the output is approximately a^k.
Worked example: With attenuation factor 0.8 and 50 layers, the gradient at the earliest layer is 0.8⁵⁰ ≈ 1.4 x 10⁻⁵. This is above FP16's floor, so it survives. But with attenuation 0.7 and 50 layers: 0.7⁵⁰ ≈ 1.8 x 10⁻⁸, which is below the FP16 subnormal floor of ~6 x 10⁻⁸. That gradient becomes exactly zero.
Run this simulation to see underflow in action. Before you run it: predict at which layer the gradient will underflow with attenuation = 0.7.
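If you cannot run the interactive version, here is a minimal pure-Python sketch of the same experiment. It uses `struct`'s half-precision format code `'e'` to round values to IEEE-754 FP16; the `to_fp16` and `first_underflow_layer` helpers are illustrative names, not library calls, and the simulation checks whether the true gradient a^k is still representable when stored in FP16:

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE-754 half-precision value."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

def first_underflow_layer(attenuation: float, n_layers: int = 60):
    """Return the first layer (counting back from the output) whose true
    gradient a^k flushes to exactly zero when stored in FP16, or None."""
    for layer in range(1, n_layers + 1):
        exact = attenuation ** layer      # true gradient magnitude at this depth
        if to_fp16(exact) == 0.0:         # does FP16 storage flush it to zero?
            return layer
    return None

print(first_underflow_layer(0.8))  # survives all 60 layers -> None
print(first_underflow_layer(0.7))  # underflows somewhere in the high 40s
```

Note that underflow is a cliff, not a slope: one layer earlier the gradient still round-trips through FP16, one layer deeper it is exactly zero.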
Loss scaling prevents underflow by multiplying the loss (and therefore all gradients via the chain rule) by a large constant before the backward pass. After the backward pass, gradients are divided by the same constant to restore correct magnitudes. The key is that the multiplication happens while values are still in a representable range.
Worked example: A gradient of 1e-8 underflows in FP16 (below 6e-8). But if you multiply by 65,536 first, it becomes 6.5e-4 — well within FP16 range. After the backward pass, divide by 65,536 to recover the true gradient.
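That rescue is easy to verify directly. A short sketch, again using `struct`'s half-precision format to round to FP16 (`to_fp16` is an illustrative helper, not a library function):

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE-754 half-precision value."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

grad = 1e-8
scale = 65_536.0

print(to_fp16(grad))                  # 0.0 -- below the FP16 floor, underflows
print(to_fp16(grad * scale))          # ~6.55e-4 -- comfortably representable
print(to_fp16(grad * scale) / scale)  # ~1e-8 recovered after unscaling
```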
A fixed scale factor is fragile: too low and gradients still underflow; too high and they overflow. Dynamic loss scaling solves this by adapting the scale during training.
The algorithm:
1. Start with a large scale (e.g., 65,536).
2. If any scaled gradient overflows to inf or NaN, skip the optimizer step and halve the scale.
3. After a fixed number of consecutive clean steps (e.g., 2,000), double the scale.
This keeps the scale hovering just below the largest value that avoids overflow.
This is exactly what PyTorch's `torch.cuda.amp.GradScaler` does internally.
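The loop is short enough to sketch in plain Python. The class name `DynamicLossScaler` and its defaults (initial scale 65,536, halve on overflow, double after 2,000 clean steps) mirror GradScaler's documented behavior, but this is an illustration, not PyTorch's actual implementation:

```python
import math

class DynamicLossScaler:
    """Toy dynamic loss scaler: halve the scale on overflow, double it
    after a run of clean steps. Illustrative only, not a real API."""

    def __init__(self, init_scale=65_536.0, growth_interval=2_000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def update(self, grads):
        """Inspect this step's gradients; return True if the optimizer
        step should be applied, False if it must be skipped."""
        if any(math.isinf(g) or math.isnan(g) for g in grads):
            self.scale /= 2          # back off: scaling pushed gradients past 65,504
            self._clean_steps = 0
            return False             # skip this optimizer step entirely
        self._clean_steps += 1
        if self._clean_steps >= self.growth_interval:
            self.scale *= 2          # probe a higher scale
            self._clean_steps = 0
        return True

scaler = DynamicLossScaler()
scaler.update([float("inf")])        # overflow -> step skipped, scale halves
print(scaler.scale)                  # 32768.0
```

Note the design choice: on an overflow step the optimizer update is skipped entirely, which is why occasional inf gradients are harmless under dynamic scaling.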
BF16 keeps FP32's 8 exponent bits, giving it the same range (~1e-38 to ~3.4e38). This means gradients almost never underflow in BF16, and you do not need loss scaling at all.
The trade-off is precision: BF16 has only 7 mantissa bits (~2-3 significant digits) compared to FP16's 10 mantissa bits (~3-4 digits). For most training workloads, this precision loss has negligible effect on convergence.
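You can see that precision gap in a few lines. A BF16 value is just a float32 with the low 16 mantissa bits dropped; this `to_bf16` helper truncates those bits (real hardware rounds to nearest even, so this is a slight simplification for illustration):

```python
import struct

def to_bf16(x: float) -> float:
    """Truncate a float to bfloat16: keep float32's sign, 8 exponent
    bits, and top 7 mantissa bits; zero the rest. (Hardware rounds to
    nearest even; truncation is close enough for illustration.)"""
    bits = int.from_bytes(struct.pack('<f', x), 'little')
    return struct.unpack('<f', (bits & 0xFFFF0000).to_bytes(4, 'little'))[0]

print(to_bf16(2e-38))    # tiny values survive: same exponent range as FP32
print(to_bf16(3.14159))  # 3.140625 -- only ~3 significant digits remain
```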
| Format | Range | Precision | Loss scaling needed? | Hardware support |
|---|---|---|---|---|
| FP32 | ~1e-38 to ~3.4e38 | ~7 digits | No | All GPUs |
| FP16 | ~6e-8 to 65,504 | ~3-4 digits | Yes (recommended) | V100, all modern GPUs |
| BF16 | ~1e-38 to ~3.4e38 | ~2-3 digits | No | A100, H100, TPU v3+ |
This section walks through the symptoms you will see in practice when numerical precision goes wrong, and the first things to check.
Symptom 1: Loss plateaus early, gradient norms collapse to zero. First check: log per-layer gradient norms; exact zeros concentrated in the earliest layers point to FP16 underflow. Cheapest fix: enable (or raise) loss scaling, or switch to BF16.
Symptom 2: Loss spikes to NaN or inf intermittently. First check: whether scaled gradients are overflowing FP16's 65,504 ceiling; a fixed or overly high loss scale will do this. Cheapest fix: dynamic loss scaling, which skips the bad step and halves the scale.
Symptom 3: Training converges but final quality is slightly worse than expected. First check: that master weights and optimizer state are kept in FP32; accumulated 16-bit rounding over millions of steps quietly degrades the final model.
How does the precision problem change as you scale up?
| Scenario | What breaks | Why | Mitigation |
|---|---|---|---|
| Shallow net (10 layers), small batch | Nothing — FP16 is fine | Gradients stay in representable range | Standard mixed precision |
| Deep net (100+ layers) | Early-layer underflow in FP16 | 0.85¹⁰⁰ ≈ 8.8e-8, just above the FP16 subnormal floor | BF16, or aggressive loss scaling |
| Large batch (8K+ samples) | Gradient averaging pushes magnitudes down | Mean of 8K near-uncorrelated gradients is ~90x (≈ √8192) smaller than a single-sample gradient | Higher loss scale to compensate |
| Long training (weeks) | Accumulated FP16 rounding errors | Master weights drift from true value over millions of steps | FP32 master weights (standard practice) |
| Very large model (70B+) | Activation memory forces FP16/BF16 | Cannot afford FP32 for activations | BF16 compute + FP32 accumulation |
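The large-batch row is just the 1/√N statistics of averaging. A quick sketch, modeling per-sample gradients for one parameter as random ±1 values (the ±1 model and the sample count are illustrative assumptions):

```python
import random

random.seed(0)
N = 8_192  # batch size

# Per-sample gradients for one parameter, modeled as random +/-1 values.
per_sample = [random.choice((-1.0, 1.0)) for _ in range(N)]

batch_mean = sum(per_sample) / N
print(abs(batch_mean))  # typically ~1/sqrt(8192) ~= 0.011, vs. 1.0 per sample
```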
NVIDIA Mixed Precision (FP16 path):
- `GradScaler` (starts at 65,536, halves on overflow)

Google TPU / A100+ (BF16 path):
Choosing in practice: Use this rubric in a model review:
- Hardware supports BF16 (A100, H100, TPU v3+)? Default to BF16 and drop loss scaling.
- FP16-only hardware (e.g., V100)? Use FP16 with dynamic loss scaling.
- Either way, keep FP32 master weights and do accumulations in FP32.
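The decision reduces to a couple of branches. A toy chooser (the function name and return strings are made up for illustration, following the comparison table above):

```python
def pick_precision(has_bf16_hardware: bool, can_afford_fp32: bool) -> str:
    """Toy precision chooser mirroring the format comparison table."""
    if has_bf16_hardware:
        return "bf16 (no loss scaling needed)"   # A100 / H100 / TPU v3+
    if can_afford_fp32:
        return "fp32"                            # slow but numerically safe
    return "fp16 + dynamic loss scaling"         # e.g., V100 clusters

print(pick_precision(has_bf16_hardware=False, can_afford_fp32=False))
```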
Test your understanding with these operational questions:
Estimate: A 96-layer Transformer has average per-layer gradient attenuation of 0.82. At which layer will FP16 gradients first underflow? (Hint: solve 0.82^k = 6e-8 for k.)
Calculate: You are training with loss scale = 32,768 in FP16. What is the smallest gradient that will survive without underflowing? (Hint: the gradient times the scale must exceed the FP16 floor.)
Diagnose: Your training loss has been flat for 2,000 steps. Gradient norms for layers 0-20 are ~1e-3, but layers 21-48 all show exactly 0.0. What is the most likely cause, and what is the cheapest fix?
Decide: You have a cluster of V100 GPUs (no BF16 support) and need to train a 30-layer model. FP16 with no loss scaling, FP16 with dynamic loss scaling, or FP32 — which do you pick and why?
Papers:
Open questions:
Next up: We explore how optimizers navigate loss landscapes — and why SGD's noise is a feature, not a bug.