
Track 0: Foundations

Build the mental models that separate research engineers from ML practitioners.

Memory & Compute
The Memory Wall15m
Gradient Flow Under Pressure18m
Optimizers
SGD & Momentum15m
Adam, Warmup & Scheduling18m
Gradient Mechanics
Backprop as Graph Transformation20m
Initialization & Residual Connections18m
Scaling Laws & μ-Transfer20m
Systems Thinking
Bandwidth & Profiling18m
The Debugging Flowchart22m



Adam, Warmup & Scheduling

Estimated reading time: 18 minutes


In this tutorial, you will implement AdamW from scratch, design a cosine learning rate schedule with warmup, and calculate the correct learning rate when scaling batch size.

By the end you will be able to:

  • Compute Adam's effective per-parameter learning rate from gradient statistics
  • Choose warmup length and decay schedule for a given training budget
  • Apply the linear scaling rule and identify when it breaks down
💡

Core Idea

Adam gives each parameter its own effective learning rate by dividing by the running RMS of its gradients. Parameters with noisy gradients get smaller updates; parameters with consistent gradients get larger updates. This is adaptive normalization, not just momentum.

The Adam Algorithm#

m = β₁ × m + (1 - β₁) × g           # First moment (mean of gradients)
v = β₂ × v + (1 - β₂) × g²          # Second moment (variance of gradients)
m̂ = m / (1 - β₁ᵗ)                   # Bias correction
v̂ = v / (1 - β₂ᵗ)                   # Bias correction
θ = θ - lr × m̂ / (√v̂ + ε)          # Normalized update

The key insight: dividing by √v̂ normalizes each parameter's update by the running RMS of its gradients, so update magnitudes approach lr regardless of raw gradient scale.
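This normalization can be seen numerically. The sketch below (a hypothetical helper, not from the lesson's editor) runs the moment updates above for two scalar parameters: one with steady gradients of 0.1, one with sign-alternating gradients of ±1. Despite the noisy parameter having 10× larger gradients, its update ends up far smaller:

```python
import math

def adam_update(grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Final Adam update for a scalar parameter fed this gradient history."""
    m = v = 0.0
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g        # first moment
        v = beta2 * v + (1 - beta2) * g * g    # second moment
    m_hat = m / (1 - beta1 ** t)               # bias correction
    v_hat = v / (1 - beta2 ** t)
    return lr * m_hat / (math.sqrt(v_hat) + eps)

# Consistent gradients: m_hat ~ 0.1, sqrt(v_hat) ~ 0.1, update magnitude ~ lr
steady = adam_update([0.1] * 100)
# Alternating gradients: m_hat is damped toward 0 while sqrt(v_hat) ~ 1,
# so the update is a small fraction of lr
noisy = adam_update([(-1.0) ** t for t in range(100)])
```

The steady parameter's update is roughly lr itself; the alternating one is over an order of magnitude smaller, which is exactly the adaptive normalization described above.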


AdamW: Why Weight Decay Must Be Decoupled#

⚠️

The AdamW Fix

Original Adam applied weight decay as part of the gradient: g = ∇L + λθ

But Adam then scales this by 1/√v. So weight decay gets scaled differently for each parameter!

AdamW fixes this: apply weight decay after the adaptive scaling: θ = θ - lr × (m̂/(√v̂ + ε) + λθ)

adamw_implementation.py
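A minimal NumPy sketch of one AdamW step, following the update rules above (the function name and defaults are illustrative, not the lesson's reference code):

```python
import numpy as np

def adamw_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update; returns updated (theta, m, v)."""
    m = beta1 * m + (1 - beta1) * grad          # first moment
    v = beta2 * v + (1 - beta2) * grad ** 2     # second moment
    m_hat = m / (1 - beta1 ** t)                # bias correction
    v_hat = v / (1 - beta2 ** t)
    # Decoupled weight decay: applied OUTSIDE the 1/sqrt(v_hat) scaling,
    # so every parameter decays at the same relative rate
    theta = theta - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * theta)
    return theta, m, v

# Minimizing f(theta) = theta^2 for a few steps
theta = np.array([1.0])
m = np.zeros_like(theta)
v = np.zeros_like(theta)
for t in range(1, 101):
    theta, m, v = adamw_step(theta, 2 * theta, m, v, t)
```

Note that `t` starts at 1: with `t = 0` the bias-correction denominators would be zero.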

Learning Rate Warmup#

💡

Why Warmup?

At random initialization, gradients point in essentially random directions. Taking large steps amplifies this randomness, causing loss spikes or divergence.

Warmup lets the gradients "settle down" before taking big steps. It's like warming up an engine before driving fast.

lr_schedule.py

Step 1: Warmup phase (0 → 1000 steps). Learning rate increases linearly from 0 to max_lr. This lets gradients stabilize before we take big steps.
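The warmup-then-cosine schedule can be sketched as follows (the 1000-step warmup matches the example above; `max_lr` and `total_steps` are assumed values):

```python
import math

def lr_at_step(step, max_lr=3e-4, warmup_steps=1000, total_steps=10000):
    """Linear warmup from 0 to max_lr, then cosine decay to 0."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps    # warmup: 0 -> max_lr
    # cosine decay: max_lr -> 0 over the remaining steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

In practice this is passed to the optimizer as a multiplier per step (e.g. via a lambda-style scheduler) rather than called by hand.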

Batch Size and Learning Rate Coupling#

⚠️

Linear Scaling Rule

When you increase batch size, gradient noise decreases. You can take bigger steps:

lr_new = lr_base × (batch_new / batch_base)

This works up to a point (~32K batch size for ImageNet), then breaks down.

batch_lr_scaling.py
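A minimal sketch of the rule (the helper name and the 32K cap default are illustrative; the cap is the rough empirical limit mentioned above, not a hard constant):

```python
def scaled_lr(lr_base, batch_base, batch_new, max_batch=32768):
    """Linear scaling rule: lr_new = lr_base * (batch_new / batch_base)."""
    if batch_new > max_batch:
        # Above ~32K the rule breaks down; consider LARS/LAMB instead
        raise ValueError("batch size beyond the linear-scaling regime")
    return lr_base * (batch_new / batch_base)

# Doubling batch size doubles the learning rate
new_lr = scaled_lr(lr_base=3e-4, batch_base=512, batch_new=1024)
```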

Scale Thought Experiment#

Scale | What Breaks | Mitigation
Small models | Nothing, Adam just works | Default hyperparameters
Large batch training | LR scaling breaks above ~32K batch | LARS, LAMB optimizers
Very deep transformers | Warmup needs to be longer | 2000-4000 warmup steps
Long training runs | Adam's memory overhead adds up | 8-bit Adam, Adafactor

Production Reality#

OpenAI (GPT-4 and Predecessors):

  • AdamW with β₁=0.9, β₂=0.95 (slightly lower β₂ for stability)
  • Cosine decay with warmup
  • Gradient clipping at 1.0

8-bit Adam (Memory Optimization):

  • Quantize optimizer states to 8-bit
  • 75% memory reduction for optimizer states
  • Minimal impact on training quality

Break It: Common Optimizer Failures#

❌

Common Mistakes

  1. Forgetting bias correction — Without the 1 - beta^t correction, early updates are biased toward zero. This causes slow starts or requires longer warmup to compensate.

  2. Applying weight decay inside the gradient — Standard L2 regularization gets scaled by Adam's 1/sqrt(v). Embedding parameters (large v) get almost no decay. Use AdamW, not Adam + L2.

  3. Scaling LR without adjusting warmup — When doubling batch size and LR, you also need to scale warmup steps. Otherwise the larger LR hits before gradients stabilize.

Checkpoint Questions#

Each question requires calculation or diagnosis, not just recall.

  1. A model has a parameter whose gradient has been [0.1, 0.1, 0.1] for 100 steps. Another parameter has gradient [1.0, -1.0, 1.0, -1.0, ...] alternating. Using Adam defaults (beta2=0.999), compute the approximate effective learning rate for each. Which parameter gets larger updates and why?
  2. You are training a 7B model with base config batch=256, lr=3e-4, warmup=2000 steps. You scale to batch=2048. Compute the new learning rate using the linear scaling rule. Should warmup steps change?
  3. A training run shows loss NaN at step 50 (before warmup ends). The learning rate at step 50 is 5e-6. Identify the most likely cause and the first two fixes to try.

Research Hooks#

Papers:

  1. "Decoupled Weight Decay Regularization" (Loshchilov & Hutter, 2019) — The AdamW paper explaining why weight decay matters for transformers
  2. "On the Variance of the Adaptive Learning Rate and Beyond" (Liu et al., 2020) — RAdam explains warmup as compensating for the high variance of Adam's adaptive learning rate early in training, and rectifies that variance directly

Open Questions:

  • Can we automatically tune warmup length based on model architecture?
  • Is there an optimizer that combines Adam's adaptivity with SGD's generalization?

Next up: Backprop isn't "chain rule backwards"—it's graph traversal. Understanding the computation graph lets you catch bugs that waste $2M training runs.