
PyTorch Training Recipes

v20260331
ml-training-recipes
Battle-tested PyTorch training recipes covering loops, optimizer choices, learning-rate schedules, mixed precision, debugging, and experimentation guidance for LLMs, vision, diffusion, biomedical, and scientific models.
Overview

ML Training Recipes

Battle-tested patterns for PyTorch training across domains. Drawn from production codebases (Karpathy's autoresearch/nanochat, torchvision, HuggingFace) and modern training practice.

Reference files (read when needed)

  • references/architecture.md — Transformer/LLM architecture code patterns, weight init
  • references/optimizers.md — Muon, AdamW hybrid, per-group LR, compiled optimizer steps
  • references/domain-specific.md — Vision, diffusion, contrastive, distributed, checkpointing, data loading
  • references/scaling-and-selection.md — Scaling laws, compute budget tables, decision trees, DGX Spark
  • references/biomedical.md — Drug discovery, protein models, medical imaging, genomics, clinical NLP
  • references/experiment-loop.md — Autonomous experiment loop (autoresearch keep/discard/revert)

Architecture Selection

Pick the right model by data type and data scale:

| Data Type | < 10K samples | 10K-100K | > 100K |
|---|---|---|---|
| Images | Pretrained CNN + fine-tune | Fine-tune ViT or CNN | ViT from scratch |
| Text (gen) | Few-shot prompting | Fine-tune GPT/LLaMA (LoRA) | Pretrain from scratch |
| Tabular | XGBoost/LightGBM | Still XGBoost | Neural viable |
| Audio | Pretrained Whisper | Fine-tune AST | Train from scratch |
| Molecules | Pretrained GNN | Fine-tune molecular LM | Train GNN from scratch |
| Proteins | ESM-2 embeddings + head | Fine-tune ESM-2 | Train protein LM |
| Medical img | Pretrained CNN | nnU-Net (auto-config) | Swin-UNETR / MedSAM |

Key principle: architecture matters less than training recipe at equal compute. A well-tuned ResNet beats a poorly-tuned ViT (ref: "ResNet Strikes Back", Wightman 2021).

For biomedical domains, see references/biomedical.md. For sequence model selection and compute planning, see references/scaling-and-selection.md.


Scaling Laws

Chinchilla rule (Hoffmann et al., 2022)

Compute-optimal training: ~20 tokens per parameter.

| Model Size | Compute-Optimal (~20 tok/param) | Inference-Optimal (~100 tok/param) |
|---|---|---|
| 125M | 2.5B tokens | 12.5B tokens |
| 1B | 20B tokens | 100B tokens |
| 7B | 140B tokens | 700B tokens |

FLOPs ≈ 6 × N × D (N=params, D=tokens). Data repetition limit: ~4 epochs before diminishing returns.
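The budget arithmetic above can be sketched directly (illustrative helper; numbers match the table):

```python
def chinchilla_budget(n_params: float, tokens_per_param: float = 20.0):
    """Compute-optimal token count and training FLOPs for a dense model."""
    d_tokens = n_params * tokens_per_param   # Chinchilla: ~20 tokens per parameter
    flops = 6 * n_params * d_tokens          # FLOPs ~ 6 * N * D
    return d_tokens, flops

tokens, flops = chinchilla_budget(1e9)  # 1B params -> 20B tokens, 1.2e20 FLOPs
```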


Training Loop

import gc, time, torch

torch.manual_seed(42)
torch.set_float32_matmul_precision("high")  # TF32 on Ampere+
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

grad_accum_steps = total_batch_size // (batch_size * seq_len)  # total batch measured in tokens
step = 0
x, y = next(train_loader)  # prefetch the first batch

while not done:
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        (loss / grad_accum_steps).backward()
        x, y = next(train_loader)  # fetch the next batch while backward runs

    update_lr(optimizer, progress)
    optimizer.step()
    model.zero_grad(set_to_none=True)  # frees memory vs zeroing

    if loss.item() > 100:  # fast-fail on divergence
        print("FAIL: loss exploded"); exit(1)

    torch.cuda.synchronize()
    if step == 0:
        gc.collect(); gc.freeze(); gc.disable()  # avoid ~500ms GC stalls
    step += 1

Key principles

  • Gradient clipping: clip_grad_norm_(params, 1.0) — near-universal for Transformers. Exception: Muon optimizer normalizes updates via orthogonalization, so clipping is optional.
  • Tensor Core alignment: batch size and hidden dims should be multiples of 8 (bf16) or 64 (A100).
  • Time-based budgets make experiments comparable across hardware.
  • cudnn.benchmark = True for fixed-size vision inputs.
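The alignment bullet can be enforced with a rounding helper (a sketch; the vocab example follows nanoGPT's padding of GPT-2's vocabulary to a multiple of 64):

```python
def pad_to_multiple(dim: int, multiple: int = 64) -> int:
    """Round a dimension up to the nearest Tensor Core friendly multiple."""
    return ((dim + multiple - 1) // multiple) * multiple

pad_to_multiple(50257)     # 50304: GPT-2 vocab padded for Tensor Core alignment
pad_to_multiple(1000, 8)   # 1000: already bf16-aligned
```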

Optimizer Configuration

Modern LLM training uses different optimizers per parameter group:

| Parameter Type | Optimizer | LR (base) | Weight Decay |
|---|---|---|---|
| 2D weight matrices | Muon | 0.04 | 0.2 |
| Token embeddings | AdamW | 0.6 × scale | 0.0 |
| Unembedding (lm_head) | AdamW | 0.004 × scale | 0.0 |
| Per-layer scalars | AdamW | 0.005 × scale | 0.0 |

LR scaling by dimension: lr * (d_model / 768)^(-0.5) — keeps dynamics stable across sizes.
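A sketch of that scaling rule as a helper (base LRs taken from the table; `d_base=768` is the reference width):

```python
def scaled_lr(base_lr: float, d_model: int, d_base: int = 768) -> float:
    """Scale a base LR by (d_model / d_base)^(-0.5) to keep dynamics stable."""
    return base_lr * (d_model / d_base) ** -0.5

scaled_lr(0.6, 768)    # 0.6: reference width, multiplier is 1
scaled_lr(0.6, 3072)   # 0.3: 4x wider model -> halve the embedding LR
```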

Rules of thumb

  • Embeddings need higher LR (sparse updates). Never weight-decay embeddings.
  • Weight decay scheduling: linearly decay WD to 0 over training.
  • AdamW defaults: β1=0.9, β2=0.95, eps=1e-10 (not default 1e-8 — prevents stale updates in bf16).

For Muon details (polar express orthogonalization, NorMuon), see references/optimizers.md.


Learning Rate Scheduling

Time-based (autoresearch style)

def get_lr_multiplier(progress):  # progress = elapsed_time / time_budget
    if progress < warmup_ratio:
        return progress / warmup_ratio
    elif progress < 1.0 - warmdown_ratio:
        return 1.0
    else:
        cooldown = (1.0 - progress) / warmdown_ratio
        return cooldown + (1 - cooldown) * final_lr_frac

Cosine decay

import math

def get_lr(step, total_steps, max_lr, min_lr, warmup_steps):
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

WSD (Warmup-Stable-Decay): gaining traction — easier to resume training mid-run.
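A minimal WSD multiplier might look like the following (a sketch; the linear decay shape and `decay_ratio=0.2` are assumptions, since WSD variants differ):

```python
def wsd_multiplier(step: int, total_steps: int, warmup_steps: int,
                   decay_ratio: float = 0.2) -> float:
    """Warmup-Stable-Decay: linear warmup, flat plateau, linear decay to zero."""
    decay_start = int(total_steps * (1 - decay_ratio))
    if step < warmup_steps:
        return step / warmup_steps
    if step < decay_start:
        return 1.0  # stable phase: constant peak LR, so runs are easy to resume
    return (total_steps - step) / (total_steps - decay_start)
```

The flat middle phase is what makes resuming easy: a checkpoint taken anywhere in it can be extended without replaying a decay schedule.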

Guidance

  • Warmup: 1-5% of training. Zero warmup valid with Muon (autoresearch uses WARMUP_RATIO=0.0).
  • Warmdown: 30-50% of training in LR decay. Matters more than warmup for final quality.
  • Final LR: 0 or ~10% of peak. Zero is simpler.

Mixed Precision & Compilation

import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"  # before torch import (PYTORCH_CUDA_ALLOC_CONF on older PyTorch)

import torch
torch.set_float32_matmul_precision("high")
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
model = torch.compile(model, dynamic=False)

  • bf16 (Ampere+): same exponent range as fp32, so no loss scaling needed. Preferred over fp16.
  • fp16: needs GradScaler. Use only on V100 or older.
  • dynamic=False enables max optimization. Add fullgraph=True if no graph breaks.
  • First steps are slow (JIT) — exclude from timing.

Memory & Performance

Meta device init (large models)

with torch.device("meta"):
    model = GPT(config)          # zero memory
model.to_empty(device="cuda")
model.init_weights()

MFU (Model FLOPs Utilization)

achieved_flops = model_flops_per_token * batch_tokens / step_time  # FLOPs/s
mfu = achieved_flops / gpu_peak_flops  # peak must be in the same units (FLOPs/s)
# bf16 peak: H100 SXM 989.5 TFLOPS | A100 312 | RTX 4090 165

Good targets: >30% decent, >40% good, >50% excellent (single-GPU).
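Plugging numbers into the formula (illustrative values; FLOPs per token approximated as 6N, consistent with the scaling-laws section):

```python
def mfu(n_params: float, tokens_per_step: float, step_time_s: float,
        peak_tflops: float) -> float:
    """Model FLOPs Utilization with the 6N FLOPs-per-token approximation."""
    achieved = 6 * n_params * tokens_per_step / step_time_s  # FLOPs/s
    return achieved / (peak_tflops * 1e12)

mfu(124e6, 524_288, 1.0, 989.5)  # ~0.39 on an H100: decent, per the targets above
```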

OOM solutions (in order)

  1. Reduce DEVICE_BATCH_SIZE, increase grad_accum_steps
  2. PYTORCH_ALLOC_CONF=expandable_segments:True
  3. model.zero_grad(set_to_none=True)
  4. Meta device init → to_empty
  5. Activation checkpointing: torch.utils.checkpoint.checkpoint()
  6. 8-bit optimizer (bitsandbytes): ~30% savings on optimizer states
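Step 1 keeps the effective batch constant: halving DEVICE_BATCH_SIZE while doubling grad_accum_steps leaves tokens per optimizer step unchanged. A sketch (names are illustrative):

```python
def grad_accum_for(total_batch_tokens: int, device_batch_size: int, seq_len: int) -> int:
    """Micro-steps so device_batch * seq_len * steps equals the total token batch."""
    tokens_per_micro = device_batch_size * seq_len
    assert total_batch_tokens % tokens_per_micro == 0, "total batch must divide evenly"
    return total_batch_tokens // tokens_per_micro

grad_accum_for(524_288, 32, 1024)  # 16 micro-steps
grad_accum_for(524_288, 16, 1024)  # 32 micro-steps: same tokens/step, half the memory
```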

Hyperparameter Search

Priority order (tune first → last)

  1. Learning rate — most impactful. Always tune first.
  2. Batch size — largest that fits. Speed knob, not quality knob.
  3. Weight decay — 0.01-0.1 for AdamW.
  4. Warmup steps — 1-5% of training.

The 2025 default recipe

| Setting | Value |
|---|---|
| Optimizer | AdamW (β1=0.9, β2=0.95, eps=1e-10) |
| Weight decay | 0.1 |
| LR schedule | Cosine decay or WSD |
| Peak LR | 3e-4 (scale down for larger models) |
| Precision | bf16 |
| Grad clipping | max_norm=1.0 |
| Normalization | RMSNorm (pre-norm) |
| Activation | SwiGLU |
| Position encoding | RoPE |
| Attention | Flash Attention, optionally GQA |

Debugging Checklist

Karpathy's recipe (still canonical)

  1. Become one with the data — visualize, check distributions, verify labels
  2. Get end-to-end running first — verify on a trivial case
  3. Overfit one batch — if you can't, you have a bug
  4. Then regularize — add regularization only after overfitting works
  5. Tune hyperparameters — start with known defaults

Loss exploding / NaN

  1. Reduce LR (3-10× smaller)
  2. Add gradient clipping: clip_grad_norm_(params, 1.0)
  3. Check for inf/nan in inputs
  4. Add logit soft capping: softcap * tanh(logits / softcap)
  5. Add QK-norm in attention
  6. Verify weight init (zero-init output projections?)
  7. Check loss reduction with gradient accumulation (loss / grad_accum_steps)

Slow training / Low MFU

  1. Verify torch.compile is active
  2. Check torch.set_float32_matmul_precision("high")
  3. Pin memory + non_blocking transfers
  4. Profile with torch.profiler
  5. GC stalls? gc.freeze(); gc.disable()
  6. Tensor Core alignment: dims multiples of 8/64

Loss plateau / Slow convergence

  1. LR too low — try 2-5× larger
  2. Warmup too long
  3. Weight decay too high
  4. Verify LR schedule is actually applied (print each step)
  5. Model too small for task

Silent failures

  1. Data leakage between train/val
  2. Wrong preprocessing at inference — augmentation mismatch
  3. Label errors — use cleanlab to detect
  4. Shuffling bugs — correlated batches
  5. Tokenizer mismatch with pretrained model

What to monitor

  • Gradient norms — spike precedes loss spike
  • Per-layer activation stats — reveals exploding/vanishing
  • Dead neurons — >50% zero ReLU = dying ReLU problem
  • Learning rate — verify schedule applied (common silent bug)
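The gradient-norm bullet can be automated with a spike detector against a running average (a sketch; `spike_factor=4.0` is an assumption to tune per model):

```python
class GradNormMonitor:
    """Flag gradient-norm spikes against an exponential moving average."""
    def __init__(self, alpha: float = 0.95, spike_factor: float = 4.0):
        self.alpha, self.spike_factor, self.ema = alpha, spike_factor, None

    def update(self, grad_norm: float) -> bool:
        """Return True if grad_norm exceeds spike_factor times the running EMA."""
        if self.ema is None:
            self.ema = grad_norm
            return False
        spiked = grad_norm > self.spike_factor * self.ema
        self.ema = self.alpha * self.ema + (1 - self.alpha) * grad_norm
        return spiked
```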

Experiment Management

Track experiments in TSV for easy comparison:

commit  val_bpb  memory_gb  status   description
a1b2c3d 0.9979   44.0       keep     baseline
b2c3d4e 0.9932   44.2       keep     increase matrix LR to 0.04
c3d4e5f 1.0050   44.0       discard  switch to GeLU (worse)

Simplicity criterion: all else equal, simpler is better. Removing something and getting equal results is a great outcome. For systematic agent-driven experimentation, see references/experiment-loop.md.

Evaluation metrics by domain

| Domain | Primary Metric | Notes |
|---|---|---|
| LLM | BPB (bits per byte) | Vocab-size-independent |
| Classification | Accuracy / F1 | Macro-F1 for imbalanced |
| Segmentation | mIoU / Dice | Per-class IoU reveals weak spots |
| Generation | FID | Needs >10k samples |
| Regression | RMSE / MAE | Log-transform skewed targets |