
Building a PyTorch training loop is fairly straightforward, but getting everything in the right place and in the right order can feel surprisingly fragile. There are loads of moving parts and after the most basic errors are fixed, most of the other mistakes can be pretty hard to spot. Training runs will fail to converge, produce incorrect results, or consume excessive memory if lines are misplaced.
The sections below go through each operation in sequence, explaining how to write each part and the common mistakes to watch out for. Distributed training, FSDP, and multi-GPU setups are out of scope here, but we'll come back to those in a future essay. (The animation above was produced by running the loop on synthetic data and capturing the decision boundary at each epoch.)
The complete loop
Let's look, first of all, at the complete training loop. You don't need to understand or memorise it yet, just get a feel for the structure.
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# --- data ---
dataset = TensorDataset(X_train, y_train)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# --- model, loss, optimiser ---
model = MLP(in_features=2, hidden=128, out_features=3)
criterion = nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=100)

# --- loop ---
for epoch in range(100):
    model.train()
    for X_batch, y_batch in loader:
        optimiser.zero_grad()
        logits = model(X_batch)
        loss = criterion(logits, y_batch)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimiser.step()
    scheduler.step()

    model.eval()
    with torch.no_grad():
        val_logits = model(X_val)
        val_loss = criterion(val_logits, y_val)
```
Now let's go through each line and understand what it does, and how not to break it. We'll start with some of the common mistakes.
TL;DR Where the order really matters
Here are some of the most common failures, and how you can break the training loop by getting the placement a little bit wrong. The reason to memorise these is that none of them will raise an exception. Over time you'll get a sense for what kind of errors to look for in your training runs, but for the first few times this crib sheet will help you out.
| Line | Wrong position | What breaks |
|---|---|---|
| model.to(device) | After optimiser = ... | Depending on the backend and PyTorch configuration, nn.Module.to() (particularly when combined with a dtype conversion such as model.half()) can allocate new nn.Parameter objects; the optimiser then holds references to the discarded originals and applies updates to them instead. |
| optimiser.zero_grad() | After loss.backward() | Gradients from multiple batches accumulate. Update uses their sum, not the current batch alone. |
| clip_grad_norm_() | Before loss.backward() | .grad is empty. The call is a no-op. |
| clip_grad_norm_() | After optimiser.step() | Clips gradients already applied. No effect. |
| scheduler.step() | Inside batch loop | LR decays len(loader) times per epoch instead of once. |
| Omit model.train() after model.eval() | n/a | Dropout disabled, BatchNorm frozen. The model trains in eval mode without error. |
| Omit torch.no_grad() during validation | n/a | Autograd graph builds on every validation batch. Memory grows until OOM. |
| Log loss instead of loss.item() | n/a | Pins the computation graph in memory for the duration of the logging call. |
Now let's go through each of these in detail.
The data
There are two parts to the data pipeline in PyTorch: the Dataset and the DataLoader. The Dataset is just a Python object that implements __len__ (how many elements are in the dataset) and __getitem__ (which, unsurprisingly, gets an item). It can be a simple wrapper around tensors, or it can load data from disk on demand.
The DataLoader wraps a dataset and produces batches.
Each pass through the full dataset is one epoch. With shuffle=True, examples are presented in a different order each epoch.
```python
dataset = TensorDataset(X_train, y_train)
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
)
```
TensorDataset pairs input and label tensors by index. Indexing with dataset[i] returns (X[i], y[i]). DataLoader calls __getitem__ repeatedly, collates the results into batches, and optionally hands work to background worker processes.
num_workers spawns separate processes that prefetch batches in parallel with GPU compute. The main process blocks on .next() only if a batch is not yet ready. Zero workers means the main process does all loading, which often bottlenecks GPU utilisation on data-heavy tasks. Two to four workers is practical, but the right number depends on CPU count and I/O speed.
pin_memory=True allocates batch tensors in pinned (page-locked) host memory. The GPU DMA engine can transfer directly from pinned memory without first staging the data through a temporary buffer, reducing host-to-device transfer time. It only helps when you're transferring to CUDA.
persistent_workers=True keeps worker processes alive between epochs. Without it, workers are respawned at the start of each epoch, adding fork overhead that becomes measurable at large worker counts.
drop_last=True discards the final batch if it is smaller than batch_size. BatchNorm statistics computed from a batch of two or three samples are noisy, so dropping the remainder avoids this. The cost is a handful of unused samples per epoch, which is often worth it for stability.
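To make the effect of drop_last concrete, here is a small framework-free sketch of how many batches one pass produces and how large the last one is. The helper name batch_sizes is hypothetical; it mirrors only the arithmetic, not the DataLoader implementation.

```python
def batch_sizes(num_samples: int, batch_size: int, drop_last: bool) -> list[int]:
    """Sizes of the batches produced by one pass over the data."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)  # the short final batch
    return sizes

# 1,000 samples with batch_size=64: 15 full batches plus a remainder of 40
kept = batch_sizes(1000, 64, drop_last=False)
dropped = batch_sizes(1000, 64, drop_last=True)
print(len(kept), kept[-1])
print(len(dropped), dropped[-1])
```

With drop_last=True the 40 leftover samples are simply skipped for that epoch; with shuffling enabled, a different set is skipped each time.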
Smaller batches produce noisier gradient estimates, which acts as implicit regularisation. Larger batches use more GPU memory but allow more parallelism. Tensor cores operate on fixed tile sizes (typically 16×16 or 8×16 depending on dtype), so setting batch sizes and layer dimensions to multiples of 8 or 16 is a cheap efficiency win.
.to(device) moves a tensor to the target device. For tensors, it is not in-place: it returns a new tensor and leaves the original unchanged. For example, X_batch.to('cuda') returns a new tensor on GPU; X_batch itself remains on CPU.
Reproducibility
Setting seeds before constructing the model and loader makes runs repeatable on the same hardware and software stack. This is essential for reproducing experiments and for checking that a change in results comes from your code rather than from chance. Seeding mainly affects two places: the order in which the data loader shuffles examples, and the initialisation of the model weights.

```python
import random

import numpy as np
import torch

def set_seed(seed: int = 42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
```

torch.manual_seed seeds the CPU generator. torch.cuda.manual_seed_all seeds every GPU. NumPy and Python's random are independent RNGs that PyTorch does not touch, so they need to be seeded separately, as set_seed does above.
cudnn.deterministic = True forces cuDNN to use deterministic convolution algorithms. Some cuDNN kernels are non-deterministic by default for throughput. The deterministic alternatives are slightly slower, but practically it shouldn't matter much during development.
cudnn.benchmark = False must be paired with deterministic = True. When benchmark = True, cuDNN profiles several algorithms per input shape and picks the fastest, a process that itself varies between runs. Fixing it to False makes sure you always get the same results.
When num_workers > 0, each DataLoader worker is a separate process with its own RNG state. PyTorch seeds each worker from a base seed drawn from the main process's generator, plus the worker id. To make worker randomness reproducible you should pass a generator and a worker_init_fn:

```python
def seed_worker(worker_id):
    # a named function (unlike a lambda) can be pickled when workers are spawned
    np.random.seed(42 + worker_id)

g = torch.Generator()
g.manual_seed(42)

loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    generator=g,
    worker_init_fn=seed_worker,
)
```
The model
nn.Module provides parameter tracking, device movement, train/eval mode switching, and serialisation. Each instance needs an __init__ (to register all the submodules) and forward (to define the computation).
Practice: ReLU Activation + Backward (Easy, Q114). Implement a ReLU activation and its backward pass, in numpy, to see what nn.ReLU is actually doing.

Practice: Custom Linear Module (Easy, Q328). Implement a custom nn.Linear as an nn.Module, registering weights and bias as parameters.

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MLP(in_features=2, hidden=128, out_features=3).to(device)
```
super().__init__() is required: it initialises the module registry. Assigning submodules or parameter tensors as attributes in __init__ registers them automatically. Unregistered plain attributes are excluded from parameter iteration, serialisation, and device movement.
.to(device) moves all registered parameters and buffers to the target device in-place. Call it before constructing the optimiser. For an ordinary device move, nn.Module.to() swaps each parameter's .data in place, so the optimiser's references remain correct. That is not guaranteed in every case: depending on the backend and PyTorch configuration, a conversion (for example one that also changes dtype, as in .half().to(device)) can allocate new nn.Parameter objects and replace them in the module's internal registry. An optimiser constructed before this retains references to the originals and applies updates to them instead of the converted parameters. Moving the model first avoids the problem entirely.
Think: Why must the model be moved to the device before constructing the optimiser?
register_buffer is for tensors that should follow the module (move with .to(device), appear in state_dict) but are not trained parameters. BatchNorm's running_mean and running_var are buffers, as is the attention mask in a transformer.
model.parameters() iterates over every registered nn.Parameter, including frozen ones with requires_grad=False. model.named_parameters() returns the same with names, useful for excluding biases from weight decay or assigning different learning rates to different parts of the model.
Calling model(x) invokes __call__, not forward directly, which new users often overlook. That wrapper runs any forward pre-hooks, then forward, then any forward hooks; backward hooks registered on the module fire later, during the backward pass. The distinction matters when using hooks for instrumentation or gradient modification.
torch.compile(model) (available in PyTorch 2.0+) traces the forward pass and emits optimised Triton/CUDA kernels via the backend. It fuses adjacent element-wise operations, reducing memory traffic. The first forward pass is slow because it triggers compilation; subsequent ones are typically 10–30% faster on GPU, sometimes more on inference-heavy workloads.
Note: The PyTorch Parameter
There are two things we care about in the nn.Parameter class: the values stored in .data and the .grad attribute. The backward pass accumulates gradients additively into .grad; the optimiser reads .grad to compute the update and writes the result back to .data.
You can also use the parameters without gradients by setting requires_grad=False. This is useful for freezing layers, or for inference-only models, or non-trainable parameters.
Training mode
[1] Sets every submodule to training mode. Two common layers change behaviour: Dropout applies random masking and rescales, and BatchNorm computes statistics from the current batch rather than its stored running estimates.

```python
model.train()  # [1]
for X_batch, y_batch in loader:
    ...
```
model.train() sets a flag on the module and all sub-modules. Two common layers read it during their forward pass.
Dropout samples a Bernoulli mask with probability p of zeroing each element, then scales surviving activations by $1/(1-p)$ to preserve expected magnitude. In eval mode it acts as an identity function. Scaling at train time (inverted dropout) means no rescaling is needed at inference.
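As a rough illustration of the rescaling, the following plain-Python sketch applies inverted dropout to a long vector of ones and checks that the mean stays close to 1. The function inverted_dropout is a hypothetical stand-in for the layer, not PyTorch's implementation.

```python
import random

def inverted_dropout(values, p, rng):
    """Zero each element with probability p; scale survivors by 1 / (1 - p)."""
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in values]

rng = random.Random(0)
activations = [1.0] * 100_000
dropped = inverted_dropout(activations, p=0.3, rng=rng)

mean_after = sum(dropped) / len(dropped)
print(f"mean before: 1.000  mean after dropout: {mean_after:.3f}")
```

Each surviving element becomes 1/0.7, and about 70% survive, so the expected value of every activation is unchanged.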
BatchNorm during training computes the mean and variance of the current batch, normalises against them, and updates its stored running_mean and running_var via exponential moving average. In eval mode it uses those stored estimates instead of the batch statistics.
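The running-statistics update can be sketched in a few lines. This assumes PyTorch's default momentum of 0.1 and tracks only the mean; update_running_mean is a hypothetical helper written for illustration.

```python
def update_running_mean(running_mean, batch, momentum=0.1):
    """One training-mode update of a BatchNorm-style running mean."""
    batch_mean = sum(batch) / len(batch)
    return (1.0 - momentum) * running_mean + momentum * batch_mean

running_mean = 0.0  # BatchNorm initialises running_mean to zero
for _ in range(50):
    # every batch here has mean 5.0
    running_mean = update_running_mean(running_mean, [4.0, 6.0])

print(f"running mean after 50 batches: {running_mean:.4f}")
```

The estimate creeps toward the true mean over many batches, which is why a model evaluated very early in training can behave differently in eval mode than in train mode.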
LayerNorm, GELU, ReLU, and most other layers are unaffected by the training flag and behave identically in both modes. The flag usually only matters if your architecture contains Dropout or any BatchNorm variant.
Omitting model.train() after validation does not raise an exception. The model trains with dropout disabled and BatchNorm using frozen statistics, both of which alter the effective learning dynamics.
Note: model.eval() only changes layer behaviour. The larger memory reduction during validation comes from torch.no_grad(), which disables graph construction entirely: no grad_fn is attached to output tensors and no intermediate activations are stored for backward. Used together during validation, they can roughly halve memory usage compared to a training-mode forward pass, though the exact saving depends on how much of the footprint is stored activations.
Clearing the gradient buffer
[2] Resets every parameter's .grad before the next forward pass. PyTorch accumulates gradients additively. Without this call, gradients from the previous batch add to the current one.

```python
for X_batch, y_batch in loader:
    optimiser.zero_grad()  # [2]
    ...
```
optimiser.zero_grad() resets every parameter's .grad attribute before the next forward pass. PyTorch accumulates gradients additively: each .backward() call adds to .grad, it does not replace it. This behaviour enables gradient accumulation: summing gradients over multiple micro-batches before stepping is equivalent to training with a proportionally larger batch.
```python
# gradient accumulation: effective batch size = batch_size × ACCUM_STEPS
ACCUM_STEPS = 4
for i, (X_batch, y_batch) in enumerate(loader):
    logits = model(X_batch)
    loss = criterion(logits, y_batch) / ACCUM_STEPS  # scale to match mean reduction
    loss.backward()                                  # adds to .grad each time
    if (i + 1) % ACCUM_STEPS == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimiser.step()
        optimiser.zero_grad()
```
Accumulating gradients across multiple small batches and stepping once is mathematically equivalent to training with a batch four times larger, assuming the loss is mean-reduced. At scale, where a full batch does not fit in GPU memory, accumulation is the standard approach.
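The equivalence can be checked by hand on a toy problem. The sketch below uses a one-parameter linear model with a mean-squared-error loss (an assumption chosen to keep the gradient in closed form) and compares one batch of eight against four accumulated micro-batches of two.

```python
def mse_grad(w, xs, ys):
    """Gradient of the mean squared error of y_hat = w * x with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 15.9]

# one large batch of 8
full_grad = mse_grad(w, xs, ys)

# four micro-batches of 2, each scaled by 1 / ACCUM_STEPS before accumulating
ACCUM_STEPS = 4
accumulated = 0.0
for i in range(0, len(xs), 2):
    accumulated += mse_grad(w, xs[i:i + 2], ys[i:i + 2]) / ACCUM_STEPS

print(full_grad, accumulated)
```

The two numbers agree up to floating-point rounding. The match relies on equal-sized micro-batches and a mean-reduced loss; layers that use batch statistics, such as BatchNorm, still see the smaller micro-batch.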
Think: Why does PyTorch accumulate gradients rather than replace them on each backward call?
zero_grad(set_to_none=True) sets .grad to None rather than filling it with zeros. It is slightly faster (skips the zero-fill pass) and uses less memory (the gradient tensors are deallocated). This is the default in PyTorch 2.0 onwards. Be careful with any code that reads .grad directly, such as gradient logging or manual updates: it must handle None.
The forward pass
[3] Runs forward() and builds the autograd computation graph. Every operation on a tensor with requires_grad=True is recorded: what the inputs were, what the output was, and how to compute the local vector-Jacobian product for backward.

```python
logits = model(X_batch)  # [3]
```
During the forward pass, PyTorch builds a computation graph that records every operation on tensors with requires_grad=True. Each node in the graph represents an operation, storing the input tensors, the output tensor, and the function needed to compute gradients during the backward pass. This allows PyTorch to automatically compute gradients for all parameters with respect to a scalar loss.
Crucially, this is a dynamic graph: it's constructed on-the-fly during the forward pass. This means that the graph can change from one forward pass to the next, allowing for more flexible model architectures, such as those with conditional branches or loops, and variable input sizes.
This graph is a directed acyclic graph (DAG) from the loss back to the leaf parameters. During the forward pass, activations from every layer are kept alive in memory for use in the backward pass. For a batch of size $B$ and a network with $L$ layers, the memory cost scales with $O(B \cdot L)$ in the naive case.
Gradient checkpointing (torch.utils.checkpoint.checkpoint) reduces this. Instead of storing all intermediate activations, it re-runs the forward pass of checkpointed segments during backward to recompute them on demand. Peak activation memory scales with the number of uncheckpointed segments rather than total depth, at the cost of roughly one additional forward pass per checkpointed segment.
```python
from torch.utils.checkpoint import checkpoint

# forward method of an nn.Module with block1, block2 and head submodules
def forward(self, x):
    x = checkpoint(self.block1, x, use_reentrant=False)
    x = checkpoint(self.block2, x, use_reentrant=False)
    return self.head(x)
```

Inside torch.no_grad(), no graph is constructed and no activations are stored for backward. Memory usage can drop by roughly half compared to a training-mode forward pass, depending on the model.
The loss
[4] Computes the scalar loss. For CrossEntropyLoss, this is log-softmax followed by negative log-likelihood, averaged over the batch. The result is a scalar tensor with a grad_fn, still connected to the graph.

```python
loss = criterion(logits, y_batch)  # [4]
```

Practice: Cross-Entropy Loss (Medium, Q18). Implement cross-entropy loss from scratch, in numpy, to see the log-sum-exp trick.
The Cross-Entropy loss is a standard choice for multi-class classification problems. It measures the difference between the predicted probability distribution (from the model's logits) and the true distribution (the one-hot encoded labels).
nn.CrossEntropyLoss is a composition of two operations: LogSoftmax followed by NLLLoss. The combined form is more numerically stable than computing them separately, because it avoids materialising the softmax probabilities and then taking their log.
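The underlying computation uses the log-sum-exp trick to prevent overflow:

$$\log \sum_{j} e^{z_j} = m + \log \sum_{j} e^{z_j - m}, \qquad m = \max_j z_j$$

Subtracting $m$ before exponentiation keeps the values in a safe range. The loss for a single example with logits $z$ and target class $y$ is then:

$$\ell = -z_y + \log \sum_{j} e^{z_j}$$

and the batch loss is the mean over examples.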
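A minimal standard-library sketch of the stabilised computation, next to a naive version that exponentiates the raw logits. Both function names are hypothetical; nn.CrossEntropyLoss does the equivalent work on tensors.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy of one example from raw logits, via the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def naive_cross_entropy(logits, target):
    """Same quantity via an explicit softmax; overflows for large logits."""
    exps = [math.exp(z) for z in logits]
    return -math.log(exps[target] / sum(exps))

small = [2.0, 1.0, 0.1]
print(cross_entropy(small, 0), naive_cross_entropy(small, 0))

large = [1000.0, 999.0, 998.0]
print(cross_entropy(large, 0))  # stays finite
try:
    naive_cross_entropy(large, 0)
except OverflowError as err:
    print("naive version failed:", err)
```

For moderate logits the two versions agree; for large ones only the shifted version survives, because the largest exponent it ever evaluates is zero.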
Think: Why does cross-entropy need the log-sum-exp trick rather than computing softmax directly?
Common arguments:
- weight accepts a 1D tensor of per-class weights, applied to each sample's loss contribution. Use this for class imbalance, i.e. to upweight rare classes.
- label_smoothing=0.1 distributes a fraction of the probability mass uniformly across all classes rather than concentrating it on the target. The effective target distribution becomes $(1 - \epsilon)\,\mathrm{onehot}(y) + \epsilon / K$, where $\epsilon$ is the smoothing value and $K$ is the number of classes. It prevents overconfident pseudo-probabilities and is standard in modern image and language model training.
- ignore_index=-100 masks positions with that label from the loss. Used in sequence modeling to exclude padding and masked tokens.
- reduction='mean' divides by batch size. 'sum' does not. Switching between them shifts the effective loss scale and therefore the effective learning rate.
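To see what label smoothing does to the target, here is a small sketch that builds the smoothed distribution and evaluates the loss against it. The helper names are hypothetical, and the formula assumed is the usual one: (1 - eps) on the one-hot target plus eps / K spread over all K classes.

```python
import math

def smoothed_targets(target, num_classes, smoothing):
    """Target distribution (1 - eps) * one_hot + eps / K."""
    uniform = smoothing / num_classes
    return [(1.0 - smoothing) * (1.0 if k == target else 0.0) + uniform
            for k in range(num_classes)]

def log_softmax(logits):
    """Numerically stable log-softmax."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]

def smoothed_cross_entropy(logits, target, smoothing=0.1):
    """Cross-entropy against the smoothed target distribution."""
    q = smoothed_targets(target, len(logits), smoothing)
    return -sum(qk * lp for qk, lp in zip(q, log_softmax(logits)))

q = smoothed_targets(target=2, num_classes=4, smoothing=0.1)
print(q)

logits = [0.5, -1.0, 3.0, 0.2]
print(smoothed_cross_entropy(logits, target=2, smoothing=0.0))
print(smoothed_cross_entropy(logits, target=2, smoothing=0.1))
```

For a confident, correct prediction the smoothed loss is higher than the unsmoothed one, which is the mechanism that discourages overconfidence.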
Logging loss with loss.item() extracts a Python float and detaches it from the graph. Logging the tensor directly pins the graph in memory for as long as the logger holds the reference. If those tensors are collected (appended to a list, for example), every batch's graph stays alive and memory grows steadily during training.
The backward pass
[5] Traverses the computation graph from the scalar loss back to every leaf parameter. Populates .grad on each parameter via the chain rule. Does not modify weights.
backward() implements reverse-mode automatic differentiation (backpropagation). For a scalar output, a single backward pass computes gradients with respect to all parameters. Forward-mode differentiation requires separate passes, one per parameter. At the parameter counts typical of modern networks, reverse mode is the only practical option.
The backward pass walks the computation graph from the loss in reverse topological order. At each node, it applies the backward function (the vector-Jacobian product for that operation) and propagates the result to the node's inputs. At leaf parameters, the result accumulates into .grad.
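The traversal can be illustrated with a toy scalar autograd. This is a teaching sketch only: the Value class is hypothetical, supports just addition and multiplication, and is not how PyTorch is implemented internally.

```python
class Value:
    """A scalar that records how it was computed (a tiny autograd node)."""

    def __init__(self, data, parents=(), backward_fn=None):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward_fn = backward_fn  # propagates self.grad to the parents

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))
        def backward_fn():
            self.grad += out.grad
            other.grad += out.grad
        out._backward_fn = backward_fn
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))
        def backward_fn():
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward_fn = backward_fn
        return out

    def backward(self):
        # build a topological order of the graph, then walk it in reverse
        order, seen = [], set()
        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)
        visit(self)
        self.grad = 1.0  # d(loss)/d(loss)
        for node in reversed(order):
            if node._backward_fn is not None:
                node._backward_fn()

w, x, b = Value(3.0), Value(2.0), Value(1.0)
loss = w * x + b   # forward pass builds the graph
loss.backward()    # d(loss)/dw = x, d(loss)/dx = w, d(loss)/db = 1
print(w.grad, x.grad, b.grad)

loss2 = w * x + b
loss2.backward()   # no zeroing in between, so leaf gradients add up
print(w.grad, x.grad, b.grad)
```

The second backward call doubles every leaf gradient, and the parameter values themselves are never touched.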
As we said before, the .grad attribute accumulates additively. If .backward() is called without zeroing gradients first, the new gradients add to whatever was already in .grad. Gradient accumulation relies on this behaviour, and we use zero_grad() to reset it.
Think: What does loss.backward() compute, and what does it not do?
retain_graph=True prevents the graph from being deallocated after backward. Normally the graph is freed immediately because it is assumed to be used once. You need retain_graph=True when calling backward multiple times through the same graph (for example, in a GAN where you backpropagate through a shared encoder for both the discriminator and generator losses).
create_graph=True allows differentiation through the backward pass itself. The backward computation becomes differentiable, enabling higher-order derivatives: Hessian-vector products, MAML-style meta-gradients, or second-order optimisers.
Gradient clipping
[6] Computes the global L2 norm of all parameter gradients. If it exceeds max_norm, rescales all gradients proportionally. Called after backward(), before step().

```python
torch.nn.utils.clip_grad_norm_(  # [6]
    model.parameters(), max_norm=1.0
)
```

clip_grad_norm_ computes the global norm across all parameters:

$$\|g\| = \sqrt{\sum_i \|g_i\|_2^2}$$

where $g_i$ is the gradient of the $i$-th parameter tensor. If $\|g\| > \text{max\_norm}$, every gradient tensor is multiplied by $\text{max\_norm} / \|g\|$. Relative direction is preserved; only the magnitude is bounded.
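A plain-Python sketch of the same computation on nested lists, useful for checking the arithmetic by hand. The function clip_grad_norm here is a hypothetical re-implementation; the small epsilon in the denominator is an assumption added to avoid dividing by zero.

```python
import math

def clip_grad_norm(grads, max_norm):
    """Scale a list of gradient vectors so their global L2 norm is at most max_norm.

    Returns the clipped gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = min(1.0, max_norm / (total_norm + 1e-6))  # never scales gradients up
    return [[g * scale for g in vec] for vec in grads], total_norm

grads = [[3.0, 4.0], [12.0]]  # global norm = sqrt(9 + 16 + 144) = 13
clipped, norm_before = clip_grad_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))
print(norm_before, norm_after)
```

Every component is multiplied by the same factor, so ratios between components (and therefore the update direction) are unchanged.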
Think: Why does global norm clipping preserve gradient direction, while per-element clipping does not?
Gradient spikes are common in transformer training. Attention softmax can saturate, producing near-one-hot distributions, which propagates large gradient norms to earlier parameters via the chain rule.
Global norm clipping is standard practice: max_norm=1.0 is a common choice in large language model training (the GPT-3 paper, for example, clips the global gradient norm at 1.0) and in many vision transformer papers. It is less commonly necessary for small MLPs on well-conditioned data.
Practice: Gradient Clipping (Medium, Q338). Implement gradient clipping by global norm: compute the norm, scale if it exceeds the threshold.

clip_grad_value_ clips each individual gradient component to the range $[-c, c]$ (where $c$ is clip_value) rather than rescaling by the global norm. It does not preserve gradient direction and is less commonly used.
Placing clipping before backward() has no effect (.grad is empty). Placing it after step() clips gradients that have already been applied to weights.
The weight update
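[7] Reads .grad on every parameter and applies the update rule. Moment estimates are updated. Weights change here and only here.

Parameter values change in step() and nowhere else. For plain SGD with momentum $\mu$ and learning rate $\eta$:

$$u_t = \mu\, u_{t-1} + g_t, \qquad \theta_t = \theta_{t-1} - \eta\, u_t$$

Adam maintains per-parameter estimates of the first moment (gradient mean) and second moment (uncentred gradient variance):

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$$

Both estimates are biased toward zero at initialisation (they start at zero, and early values of $m_t$ and $v_t$ underestimate the true moments). Bias correction accounts for this:

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

The weight update is then:

$$\theta_t = \theta_{t-1} - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$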
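The Adam update with bias correction can be written out for a single scalar parameter. This is a sketch for checking the arithmetic, with a hypothetical adam_step function and a toy quadratic objective chosen for illustration; torch.optim.Adam applies the same rule tensor-wise.

```python
import math

def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1.0 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1.0 - beta2) * grad * grad   # second-moment estimate
    m_hat = m / (1.0 - beta1 ** t)                # bias correction
    v_hat = v / (1.0 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

# minimise f(w) = (w - 3)^2, whose gradient is 2 * (w - 3)
w, m, v = 0.0, 0.0, 0.0
for t in range(1, 2001):
    w, m, v = adam_step(w, 2.0 * (w - 3.0), m, v, t, lr=0.05)
print(f"w after 2000 steps: {w:.4f}")
```

On the very first step the bias-corrected ratio is close to the sign of the gradient, so the parameter moves by roughly lr regardless of how large the gradient is.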
Practice: Adam Optimizer Step (Medium, Q118). Implement an Adam optimiser step in numpy: bias correction and per-parameter adaptive rates.

Practice: Adam Step (PyTorch) (Medium, Q335). Implement an Adam step using PyTorch: torch.optim.Adam under the hood.

Practice: AdamW Step (Medium, Q36). Implement an AdamW step: decoupled weight decay applied directly to the weights, not through the gradient.

Default hyperparameters: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$. The per-parameter adaptive learning rate makes Adam robust to varying gradient scales across parameters, which matters in deep networks where early and late layers often have gradients of very different magnitudes.
AdamW decouples weight decay from the gradient update. Standard Adam with L2 regularisation adds a gradient term before the update, which then gets scaled by the adaptive step size. AdamW applies weight decay directly to the weights: $\theta \leftarrow \theta - \eta \lambda \theta$, where $\lambda$ is the weight-decay coefficient. The two are not equivalent; AdamW gives more predictable effective regularisation and is now standard for transformer training.
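The difference shows up even on a single first step from zero-initialised moments. In the sketch below, all three function names are hypothetical and the numbers are chosen only to make the contrast visible.

```python
import math

def adam_direction(grad, m, v, t, beta1=0.9, beta2=0.999, eps=1e-8):
    """Bias-corrected Adam direction m_hat / (sqrt(v_hat) + eps) and new moments."""
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    return m_hat / (math.sqrt(v_hat) + eps), m, v

def adam_l2_step(w, grad, lr, wd):
    # L2 regularisation: the decay term joins the gradient and is rescaled by Adam
    direction, _, _ = adam_direction(grad + wd * w, m=0.0, v=0.0, t=1)
    return w - lr * direction

def adamw_step(w, grad, lr, wd):
    # decoupled decay: shrink the weight directly, then take the Adam step
    direction, _, _ = adam_direction(grad, m=0.0, v=0.0, t=1)
    return w * (1.0 - lr * wd) - lr * direction

w, lr, wd = 10.0, 0.01, 0.1
for grad in (0.001, 100.0):  # tiny vs huge raw gradient
    print(grad, adam_l2_step(w, grad, lr, wd), adamw_step(w, grad, lr, wd))
```

With L2 regularisation the decay term is normalised along with the gradient, so on this step the weight moves by about lr whatever wd is. AdamW shrinks the weight by lr * wd * w in addition to the Adam step, independent of the gradient's scale.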
Think: What is the practical difference between Adam with L2 regularisation and AdamW?
fused=True (CUDA only): fuses the entire parameter update into a single CUDA kernel per parameter group, avoiding the Python loop and multiple kernel launches. It runs approximately 30–50% faster than the default at large parameter counts.
foreach=True: uses batched torch._foreach_* operations that process all parameters together in vectorised Python loops. It is intermediate between the default and fused in speed, and available on CPU.
The optimiser stores moment state for every parameter. For Adam on a 7B-parameter model, that is 7B float32 values for $m$ and 7B for $v$, roughly 56GB of optimiser state at full precision. At that scale, training typically uses 8-bit optimisers (bitsandbytes) or shards state across devices (ZeRO).
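The memory figure is simple arithmetic, sketched here with a hypothetical helper. The 8-bit line assumes one byte per stored value and ignores any quantisation metadata a real 8-bit optimiser keeps.

```python
def adam_state_bytes(num_params: int, bytes_per_value: int = 4) -> int:
    """Bytes needed for Adam's two moment buffers (m and v)."""
    return 2 * num_params * bytes_per_value

params = 7_000_000_000
print(adam_state_bytes(params) / 1e9, "GB in float32")
print(adam_state_bytes(params, bytes_per_value=1) / 1e9, "GB with 8-bit state")
```

This counts optimiser state only; the weights and gradients need their own memory on top.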
Learning rate schedules
[8] Updates the learning rate according to a schedule. Called once per epoch, after optimiser.step(), outside the batch loop.

The scheduler modifies the lr field of each parameter group in the optimiser. The most common mistake is calling it inside the batch loop:

```python
# wrong: lr decays len(loader) times per epoch instead of once
for X_batch, y_batch in loader:
    optimiser.step()
    scheduler.step()

# correct
for X_batch, y_batch in loader:
    optimiser.step()
scheduler.step()
```

Cosine annealing decays the learning rate from eta_max to eta_min over T_max epochs:

$$\eta_t = \eta_{\min} + \tfrac{1}{2}\,(\eta_{\max} - \eta_{\min})\left(1 + \cos\frac{\pi t}{T_{\max}}\right)$$
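The closed-form cosine schedule is easy to tabulate without a framework. The function below is a hypothetical helper that evaluates the standard expression; it is meant for sanity-checking the values a scheduler should report, not as a replacement for CosineAnnealingLR.

```python
import math

def cosine_annealing_lr(epoch, t_max, eta_max, eta_min=0.0):
    """Learning rate after `epoch` scheduler steps of cosine annealing."""
    cosine = math.cos(math.pi * epoch / t_max)
    return eta_min + 0.5 * (eta_max - eta_min) * (1.0 + cosine)

for epoch in (0, 25, 50, 75, 100):
    print(epoch, f"{cosine_annealing_lr(epoch, t_max=100, eta_max=1e-3):.6f}")
```

The rate starts at eta_max, passes through the midpoint halfway through, and reaches eta_min at T_max. Stepping the scheduler per batch with an epoch-based T_max would run through this whole curve far too early.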
In modern large model training, the standard schedule combines a linear warmup with cosine decay. A warmup phase (typically 1–5% of total steps) limits the update magnitude while parameters are far from their trained values. Cosine decay reduces the rate for the remaining steps.
```python
# linear warmup + cosine decay
# (using HuggingFace's implementation as reference)
from transformers import get_cosine_schedule_with_warmup

scheduler = get_cosine_schedule_with_warmup(
    optimiser,
    num_warmup_steps=100,
    num_training_steps=10_000,
)
scheduler.step()  # called per step, not per epoch, with this scheduler
```
ReduceLROnPlateau is the exception to the no-argument rule. It takes a metric value and reduces the learning rate only when the metric has stopped improving for patience epochs. It must be called with the validation loss: scheduler.step(val_loss).
Validation
[9] model.eval() changes layer behaviour. torch.no_grad() stops graph construction. They are independent operations; you need both for validation.

```python
model.eval()               # [9]
with torch.no_grad():      # [10]
    val_logits = model(X_val)
    val_loss = criterion(val_logits, y_val)
```
model.eval() and torch.no_grad() are independent operations.
model.eval() sets self.training = False on every module. BatchNorm switches from computing batch statistics to using its stored running_mean and running_var. Dropout switches from sampling Bernoulli masks to the identity function. No other standard layers are affected.
torch.no_grad() disables the construction of the autograd graph entirely. No grad_fn is attached to tensors produced inside the context; no intermediate activations are saved. This can roughly halve memory usage compared to a training-mode forward pass, depending on the model, and makes computation slightly faster (fewer bookkeeping operations per op).
The two are independent choices:
- Eval mode, graph enabled: valid for computing input gradients (saliency maps, adversarial examples).
- Train mode, no_grad: for forward-pass checks that don't require gradients.
- Both together: the standard validation pass.
Think: Is model.eval() the same as torch.no_grad()?
torch.inference_mode() is a stricter form of no_grad(). Tensors created inside the context are inference tensors (is_inference() returns True), which prevents them from being used in autograd-recorded computation even if they escape the context manager. It removes a few additional checks in the autograd engine and can run somewhat faster than no_grad(); the gain depends on the workload.
Logging
Track training and validation metrics using loss.item(), not loss. Calling .item() extracts a Python float and detaches from the graph; holding a reference to the tensor keeps the full backward graph alive until that reference is released.

```python
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for X_batch, y_batch in loader:
        optimiser.zero_grad()
        loss = criterion(model(X_batch), y_batch)
        loss.backward()
        optimiser.step()
        running_loss += loss.item()
    train_loss = running_loss / len(loader)

    model.eval()
    with torch.no_grad():
        val_logits = model(X_val)
        val_loss = criterion(val_logits, y_val).item()
        val_acc = (val_logits.argmax(1) == y_val).float().mean().item()

    print(f'epoch {epoch:3d}  train {train_loss:.4f}  val {val_loss:.4f}  acc {val_acc:.3f}')
```
Checkpointing
torch.save writes a checkpoint file; torch.load reads it back. Checkpoints allow a training run to survive crashes and preemption.
```python
# save
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimiser_state_dict': optimiser.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'val_loss': val_loss,
}, 'checkpoint.pt')

# resume
checkpoint = torch.load('checkpoint.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimiser.load_state_dict(checkpoint['optimiser_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch'] + 1
```
state_dict() returns an ordered dictionary of parameter and buffer tensors. It does not include the class definition, so the model class must be defined in scope before calling load_state_dict.
The optimiser state must be saved alongside the model to resume training correctly. For Adam, the state includes per-parameter moment estimates $m$ and $v$, which take several hundred steps to warm up from zero. Resuming without them is equivalent to restarting the optimiser cold, which typically produces a loss spike at the resume point.
map_location=device handles the common case where the checkpoint was saved on a different GPU than the one loading it.
The standard pattern saves only when validation loss improves, so the saved weights correspond to the best-generalising epoch rather than the final one:
```python
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    # ... training loop ...

    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(X_val), y_val).item()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pt')
```
To restore the best model after training: model.load_state_dict(torch.load('best_model.pt', map_location=device)).
GPU efficiency
While we're working through the nitty-gritty of the training loop, it's worth spending a few minutes thinking about some of the basic GPU optimisation techniques. None of these exactly count as "melting the hardware" but they're almost all entirely free speed.
Device placement
Put the model and data on the same GPU; this minimises data transfer overhead.

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MLP(in_features=2, hidden=128, out_features=3).to(device)
# construct optimiser AFTER moving model: it captures parameter references
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)

# per-batch: move data to device
for X_batch, y_batch in loader:
    X_batch = X_batch.to(device, non_blocking=True)
    y_batch = y_batch.to(device, non_blocking=True)
```
non_blocking=True initiates the host-to-device transfer asynchronously and returns immediately. The GPU can begin work from a previous batch while the new batch is still transferring. The overlap only helps when pin_memory=True in the DataLoader; unpinned memory cannot be transferred asynchronously.
Mixed precision
Modern GPUs have dedicated tensor cores that execute float16 (and bfloat16) matrix multiplications 4–8× faster than float32 on the same hardware. Mixed precision training keeps weights in float32 but runs the forward and backward passes in float16.
This makes the computation more efficient, but float16 has a much smaller dynamic range than float32, so small gradient values underflow to zero, producing incorrect updates. The compromise is to maintain the weights in float32, where accumulated rounding errors remain within the float32 dynamic range.
1scaler = torch.amp.GradScaler('cuda') 2 3for X_batch, y_batch in loader: 4 optimiser.zero_grad() 5 6 with torch.amp.autocast('cuda', dtype=torch.float16): 7 logits = model(X_batch) 8 loss = criterion(logits, y_batch) 9 10 scaler.scale(loss).backward() # backward in float16 11 scaler.unscale_(optimiser) # unscale before clipping 12 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 13 scaler.step(optimiser) # step only if no inf/nan 14 scaler.update() # adjust scale factor
torch.amp.autocast and torch.amp.GradScaler are the current, device-generic API in recent PyTorch 2.x releases. The older from torch.cuda.amp import autocast, GradScaler still works but is deprecated.
Why GradScaler: small float16 gradients underflow to zero, as described above. GradScaler multiplies the loss by a large scale factor (its default initial value is 2^16 = 65536) before backward. By the chain rule, every gradient is scaled by the same factor, which keeps small values representable in float16. Before the optimiser step, gradients are divided back to their true scale. If an overflow is detected (inf or nan in any gradient), the step is skipped and the scale factor is reduced; after a long enough run of clean steps it is increased again.
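The effect of scaling can be seen on a single number. This is an illustrative sketch only: it imitates the arithmetic with plain tensor casts rather than calling GradScaler, and the gradient value 1e-8 is an arbitrary example.

```python
import torch

scale = 2.0 ** 16  # GradScaler's default initial scale factor
true_grad = torch.tensor([1e-8], dtype=torch.float32)

# Cast directly to float16: the value is below the smallest positive
# float16 number (about 6e-8), so it rounds to zero.
lost = true_grad.half()

# Scale first (in float32), then cast: the scaled value is representable.
scaled = (true_grad * scale).half()

# Divide the scale back out in float32, as happens before the optimiser step.
recovered = scaled.float() / scale

# Scaling can also push large values past float16's maximum (65504),
# which is why an inf/nan check and a skipped step are needed.
overflowed = (torch.tensor([10.0]) * scale).half()

print(lost.item(), recovered.item(), overflowed.item())
```

The recovered value is close to, but not exactly, 1e-8, because float16 rounding still applies to the scaled number.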
bfloat16 (dtype=torch.bfloat16) has the same exponent range as float32 but fewer mantissa bits. A gradient that is representable in float32 therefore stays representable in bfloat16, only at lower precision, so it does not underflow the way float16 does and does not need GradScaler. It is the preferred dtype on recent hardware (A100, H100, TPUs).
Think: why does bfloat16 not need GradScaler, but float16 does?
```python
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
    logits = model(X_batch)
    loss = criterion(logits, y_batch)

loss.backward()      # no scaler needed
optimiser.step()
```
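One way to check the answer to the question above is to compare the numeric limits of the three dtypes with torch.finfo. This is a small sketch that runs on CPU; the test value 1e-8 is an arbitrary example.

```python
import torch

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    # tiny: smallest positive normal number; max: largest finite number;
    # eps: gap between 1.0 and the next representable value (precision).
    print(f"{str(dtype):15s} tiny={info.tiny:.3e} max={info.max:.3e} eps={info.eps:.3e}")

small = torch.tensor([1e-8])
as_fp16 = small.to(torch.float16)   # underflows to zero
as_bf16 = small.to(torch.bfloat16)  # survives, with a coarser mantissa
print(as_fp16.item(), as_bf16.item())
```

bfloat16 trades precision (a larger eps) for range (a far smaller tiny and a far larger max), which is the trade that makes loss scaling unnecessary.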
DataLoader pipeline
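For GPU training, the bottleneck is often not the GPU but the data pipeline feeding it. A DataLoader configured with num_workers=4, pin_memory=True, persistent_workers=True and prefetch_factor=2 is a reasonable starting point for keeping the GPU fed without stalling. prefetch_factor=2 (the default whenever num_workers > 0) means each worker keeps two batches loaded in advance, so 2 × num_workers batches are queued in total. persistent_workers=True keeps the worker processes alive between epochs instead of restarting them each time.
torch.backends.cudnn.benchmark = True makes cuDNN time the available convolution algorithms the first time it sees a given input shape and cache the fastest one. Subsequent batches with that shape reuse the cached choice. The flag only matters for cuDNN-backed operations such as convolutions, so it should make little or no difference to a pure MLP like the one used here. Do not use it if input shapes vary between batches; the profiling overhead is incurred again for each new shape.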
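A sketch of that configuration as a helper function. `make_train_loader` is a hypothetical name and the dataset is random. The one subtlety it handles is that persistent_workers and prefetch_factor are only accepted when worker processes are in use, so they are added conditionally.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_train_loader(dataset, batch_size=64, num_workers=4):
    # Hypothetical helper that builds the configuration described above.
    kwargs = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    if num_workers > 0:
        # Worker-only options: current PyTorch releases reject them
        # (ValueError) when num_workers=0.
        kwargs["persistent_workers"] = True
        kwargs["prefetch_factor"] = 2
    return DataLoader(dataset, **kwargs)


# Synthetic stand-in dataset: 2000 points, 2 features, 3 classes.
dataset = TensorDataset(torch.randn(2000, 2), torch.randint(0, 3, (2000,)))

# num_workers=0 keeps this sketch single-process so it runs anywhere;
# for real GPU training, start from num_workers=4 and tune from there.
loader = make_train_loader(dataset, num_workers=0)
X_batch, y_batch = next(iter(loader))
```

On platforms that start workers with spawn (Windows, macOS), code that creates a multi-worker loader should live under an `if __name__ == "__main__":` guard.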
Compilation
```python
model = torch.compile(model)
```
Available in PyTorch 2.0 and above, torch.compile traces the forward pass and emits optimised Triton kernels via the Inductor backend. The main benefit is operation fusion: adjacent element-wise operations (activation + dropout + residual add, for example) are merged into a single kernel, reducing memory reads and writes. Typical speedups are 10–30% on GPU training, higher for inference workloads and architectures with many small ops.
mode='max-autotune' runs additional profiling to choose the best kernel for each operation. Compilation takes longer, but the resulting model is faster, and the cost amortises over long training runs.
The first forward pass triggers compilation and is slow. torch._dynamo.reset() clears the cache if you need to recompile (e.g., after changing the model structure).
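To see the compilation cost for yourself, time the first few calls separately. The sketch below is one way to do that; `time_calls` and `compile_and_time` are hypothetical helper names, and any keyword arguments are simply forwarded to torch.compile.

```python
import time

import torch
import torch.nn as nn


def time_calls(model, x, n_calls=3):
    # Returns the wall-clock duration of each of n_calls forward passes.
    timings = []
    for _ in range(n_calls):
        if x.is_cuda:
            torch.cuda.synchronize()  # finish pending GPU work before timing
        start = time.perf_counter()
        with torch.no_grad():
            model(x)
        if x.is_cuda:
            torch.cuda.synchronize()  # wait for the forward pass itself
        timings.append(time.perf_counter() - start)
    return timings


def compile_and_time(model, x, **compile_kwargs):
    # compile_kwargs go straight to torch.compile, e.g. mode='max-autotune'.
    compiled = torch.compile(model, **compile_kwargs)
    return compiled, time_calls(compiled, x)
```

Calling `compile_and_time(model, x)` on a model and a batch that live on the same device returns the compiled module and a list of timings. The first entry includes compilation and should be far larger than the rest. The torch.cuda.synchronize() calls matter because CUDA kernels launch asynchronously, so timings taken without them measure launch time rather than execution time.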
The efficient loop
```python
device = torch.device('cuda')
model = torch.compile(MLP(...).to(device))

torch.backends.cudnn.benchmark = True

for epoch in range(NUM_EPOCHS):
    model.train()
    for X_batch, y_batch in loader:
        X_batch = X_batch.to(device, non_blocking=True)
        y_batch = y_batch.to(device, non_blocking=True)

        optimiser.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            logits = model(X_batch)
            loss = criterion(logits, y_batch)

        loss.backward()   # no scaler: bf16 doesn't underflow
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimiser.step()

    scheduler.step()

    model.eval()
    with torch.inference_mode():
        val_logits = model(X_val.to(device))
        val_loss = criterion(val_logits, y_val.to(device))
```
Practice: Train a Classifier (Hard, Q343)
Put it all together: train a classifier on a synthetic dataset and hit a target accuracy.
· · ·
The complete annotated script
The same loop, with every line referenced.
[1] Raw input tensor, shape (N, features). For the spiral dataset, N=2000, features=2.
[2] Integer class indices, shape (N,). Not one-hot: CrossEntropyLoss expects indices directly.
[3] TensorDataset pairs inputs and labels by index. __getitem__ returns (X[i], y[i]).
[4] DataLoader handles batching, shuffling each epoch, and optional parallel prefetching via num_workers.
[5] nn.Module subclass. Register submodules in __init__; define the computation in forward.
[6] .to(device) moves all parameters and buffers. Do this before constructing the optimiser.
[7] CrossEntropyLoss = LogSoftmax + NLLLoss. Pass raw logits, not softmax outputs.
[8] Adam maintains per-parameter moment estimates. model.parameters() supplies the tensors to optimise.
[9] Cosine annealing scheduler. scheduler.step() adjusts lr in the optimiser's param groups.
[10] model.train(): dropout masks active, batchnorm uses batch statistics.
[11] zero_grad(): clears .grad. PyTorch accumulates by default; without this, gradients compound across batches.
[12] Forward pass: builds the autograd graph. Every op records inputs and its backward function.
[13] Loss: scalar tensor, still in the graph. grad_fn is the entry point for backward.
[14] backward(): reverse-mode AD. Populates .grad on every leaf parameter. Weights unchanged.
[15] Gradient clipping: global norm rescale. Prevents spikes from destabilising training.
[16] optimiser.step(): reads .grad, applies Adam update, updates moment estimates. Weights change.
[17] scheduler.step(): adjusts lr. Once per epoch, after optimiser.step(), outside the batch loop.
[18] model.eval(): dropout pass-through, batchnorm uses running statistics.
[19] torch.no_grad() (or inference_mode()): no graph construction. Faster, lower memory.
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ── data ──────────────────────────────────────────────────────────────────────
# make_labels, X_val and y_val are assumed to be defined elsewhere.
X_train = torch.randn(2000, 2)                                      # [1]
y_train = make_labels(X_train)                                      # [2]

dataset = TensorDataset(X_train, y_train)                           # [3]
loader  = DataLoader(dataset, batch_size=64, shuffle=True,          # [4]
                     num_workers=2, pin_memory=True)

# ── model ──────────────────────────────────────────────────────────────────────
class MLP(nn.Module):                                               # [5]
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)

model = MLP(in_features=2, hidden=128, out_features=3).to(device)   # [6]

# ── loss, optimiser, scheduler ────────────────────────────────────────────────
criterion = nn.CrossEntropyLoss()                                   # [7]
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)           # [8]
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(             # [9]
    optimiser, T_max=100
)

# validation tensors must live on the same device as the model
X_val, y_val = X_val.to(device), y_val.to(device)

# ── training loop ─────────────────────────────────────────────────────────────
for epoch in range(100):

    model.train()                                                   # [10]

    for X_batch, y_batch in loader:
        # batches arrive on the CPU; move them to the model's device
        X_batch = X_batch.to(device, non_blocking=True)
        y_batch = y_batch.to(device, non_blocking=True)

        optimiser.zero_grad()                                       # [11]

        logits = model(X_batch)                                     # [12]

        loss = criterion(logits, y_batch)                           # [13]

        loss.backward()                                             # [14]

        torch.nn.utils.clip_grad_norm_(                             # [15]
            model.parameters(), max_norm=1.0
        )

        optimiser.step()                                            # [16]

    scheduler.step()                                                # [17]

    # ── validation ────────────────────────────────────────────────────────────
    model.eval()                                                    # [18]
    with torch.no_grad():                                           # [19]
        val_logits = model(X_val)
        val_loss = criterion(val_logits, y_val)
        val_acc = (val_logits.argmax(1) == y_val).float().mean()
```
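Several of the annotations can be checked directly on a toy model. The sketch below is not part of the script above: it uses a single nn.Linear layer and random data as stand-ins, and checks notes [7], [11] and [15].

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(2, 3)          # stand-in for the MLP
X = torch.randn(8, 2)
y = torch.randint(0, 3, (8,))
criterion = nn.CrossEntropyLoss()

# [7] CrossEntropyLoss on raw logits equals NLLLoss on log-softmax outputs.
logits = model(X)
ce = criterion(logits, y)
nll = F.nll_loss(F.log_softmax(logits, dim=1), y)

# [11] Without zero_grad(), a second backward() adds to the existing .grad.
ce.backward()
grad_once = model.weight.grad.clone()
criterion(model(X), y).backward()   # same batch, no zero_grad() in between
grad_twice = model.weight.grad.clone()

# [15] clip_grad_norm_ rescales all gradients together so that their combined
# L2 norm is at most max_norm, and returns the norm measured before clipping.
max_norm = 0.01
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))

print(ce.item(), nll.item())
print(norm_before.item(), norm_after.item())
```

The doubled gradient in the second check is what "gradients compound across batches" means in practice: the update direction becomes a sum over every batch since the last zero_grad().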