Table of Contents
Open Table of Contents
- Back to Backprop
- From Derivatives to Gradients and Jacobians
- The Mechanics of Backpropagation
- Backprop by hand
- Softmax, LogSoftmax and LogSumExp
- Building a Simplified Tensor and Autograd Engine (aka “Micrograd”)
- Challenges in Practice
- Modern Tweaks to Backpropagation
- Why Revisit Backpropagation?
- Conclusion
Back to Backprop
In the rapidly advancing field of deep learning, it’s easy to get swept up in the excitement of state-of-the-art architectures and massive datasets. Yet, beneath the surface of every breakthrough is the workhorse of gradient-based optimization: backpropagation. For seasoned practitioners, revisiting this foundational algorithm isn’t just a nostalgic exercise—it’s an opportunity to sharpen our intuition and improve our practice.
This post explores backpropagation from a practitioner’s lens. We’ll begin with a concise refresher, implementing the forward and backward passes for foundational layer types. We’ll demonstrate a full training loop for a simple neural network, using our own custom implementation and mini-autograd engine. Along the way, we’ll dive into real-world challenges and explore modern tweaks. Finally, we’ll culminate in building a simplified PyTorch-like tensor and autograd engine.
Here are some quick links to some awesome resources on backpropagation:
- Derivatives, Backpropagation, and Vectorization - Justin Johnson
- cs224n Lecture 3: Neural net learning: Gradients by hand (matrix calculus) and algorithmically (the backpropagation algorithm)
- cs231n BackProp Notes- Andrej Karpathy
- Computing Neural Network Gradients - Kevin Clark
From Derivatives to Gradients and Jacobians
For completeness’ sake, it is worth restating some of the basic definitions and concepts that underpin backpropagation.
Derivatives
Derivatives are a way to measure change in the scalar case:
Derivative Definition
Given a function $f: \mathbb{R} \to \mathbb{R}$, the derivative of $f$ at a point $x$ is defined as:
$$f'(x) = \lim_{h \to 0} \frac{f(x + h) - f(x)}{h}$$
Put differently, the derivative tells us how much a function changes as the input changes by a small amount $\epsilon$: $f(x + \epsilon) \approx f(x) + \epsilon f'(x)$.
The Chain Rule
The chain rule tells us how to compute the derivative of a composition of functions. Suppose we have functions $f: \mathbb{R} \to \mathbb{R}$ and $g: \mathbb{R} \to \mathbb{R}$ with $y = g(x)$ and $z = f(y)$. In other words, $z = f(g(x))$. Then, the chain rule states:
$$\frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx}$$
In modern neural networks, it turns out that backpropagation is essentially the chain rule applied to a composition of vector/tensor-valued functions.
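As a quick sanity check of the chain rule in code (a minimal sketch; the particular functions $g(x) = x^2$ and $f(y) = \sin(y)$ are just for illustration), we can compare PyTorch’s autograd against the hand-computed product of derivatives:
import torch

# z = f(g(x)) with g(x) = x**2 and f(y) = sin(y)
x = torch.tensor(1.5, requires_grad=True)
y = x ** 2
z = torch.sin(y)
z.backward()

# Chain rule by hand: dz/dx = dz/dy * dy/dx = cos(x^2) * 2x
manual = torch.cos(x.detach() ** 2) * 2 * x.detach()
print(torch.allclose(x.grad, manual))  # True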
Gradients: Vector in -> Scalar out
Gradients are the generalization of derivatives to the multivariate case, where we have a scalar-valued function of multiple variables. The gradient of a function is a vector of partial derivatives, one for each input variable.
Gradient Definition
Given a function $f: \mathbb{R}^n \to \mathbb{R}$, the gradient of $f$ at a point $x \in \mathbb{R}^n$ is defined as the vector $\nabla f(x) \in \mathbb{R}^n$ satisfying:
$$f(x + \Delta x) \approx f(x) + \nabla f(x) \cdot \Delta x$$
It can also be viewed as a vector of partial derivatives:
$$\nabla f(x) = \left[ \frac{\partial f}{\partial x_1}, \frac{\partial f}{\partial x_2}, \ldots, \frac{\partial f}{\partial x_n} \right]$$
When we consider the relationship
$$f(x + \Delta x) \approx f(x) + \nabla f(x) \cdot \Delta x,$$
it is important to note that $x$, $\Delta x$, and $\nabla f(x)$ are all vectors now. The term $\nabla f(x) \cdot \Delta x$ is computed by taking the dot product, producing a scalar, which matches the output signature of $f$.
Jacobians: Vector in -> Vector out
Now we consider a function $f: \mathbb{R}^N \to \mathbb{R}^M$. The Jacobian of $f$ at a point $x$ is an $M \times N$ matrix of partial derivatives, defined such that the $(i, j)$ entry of the Jacobian is $\frac{\partial f_i}{\partial x_j}$. Written out in other forms:
$$J = \frac{\partial f}{\partial x} = \begin{bmatrix} \frac{\partial f_1}{\partial x_1} & \cdots & \frac{\partial f_1}{\partial x_N} \\ \vdots & \ddots & \vdots \\ \frac{\partial f_M}{\partial x_1} & \cdots & \frac{\partial f_M}{\partial x_N} \end{bmatrix} = \begin{bmatrix} \nabla f_1(x)^\top \\ \vdots \\ \nabla f_M(x)^\top \end{bmatrix}$$
Ways to remember the formula (a quick shape check in code follows this list):
- Outputs are rows ($M$): outputs vary down a column.
- Inputs are columns ($N$): inputs vary across a row.
- Each row is the gradient of the corresponding output with respect to all the inputs.
- When $M = 1$ (so $f: \mathbb{R}^N \to \mathbb{R}$), the Jacobian is the transpose of the gradient, $J = \nabla f(x)^\top$.
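A quick way to check the “outputs are rows, inputs are columns” convention is `torch.autograd.functional.jacobian`. Here is a small sketch (the linear map is just for illustration):
import torch
from torch.autograd.functional import jacobian

N, M = 4, 3
W = torch.randn(M, N)

def f(x: torch.Tensor) -> torch.Tensor:
    return W @ x  # f: R^N -> R^M

x = torch.randn(N)
J = jacobian(f, x)
print(J.shape)               # torch.Size([3, 4]): M rows (outputs), N columns (inputs)
print(torch.allclose(J, W))  # for a linear map, the Jacobian is exactly W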
Generalized Jacobians: Tensor in -> Tensor out
From scalars to vectors (1-D arrays) to matrices (2-D arrays), we can generalize the concept of Jacobians to higher-order tensors, which we can view as $n$-dimensional arrays.
In neural networks, it is common to think of tensor operations in terms of their “shapes”, which are the dimensions of the tensor, i.e., the tuple returned by `tensor.shape` in numpy/PyTorch. You will frequently see code where comments track the shape of the tensor, e.g., `# [B, N, D] x [B, D, M] -> [B, N, M]`. This shorthand has even been formalized in the Einstein notation used in PyTorch, with `einsum` and the python package `einops`. We will write a shape like $N \times M$ to indicate a 2-D tensor with $N$ rows and $M$ columns; for example, a tensor of shape $3 \times 4$ is a 2-D tensor with 3 rows and 4 columns.
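To make the shape-comment convention concrete, here is a small sketch of the batched matmul from the comment above, written both with `@` and with `einsum`:
import torch

B, N, D, M = 2, 3, 4, 5
a = torch.randn(B, N, D)
b = torch.randn(B, D, M)

out1 = a @ b                                # [B, N, D] x [B, D, M] -> [B, N, M]
out2 = torch.einsum("bnd,bdm->bnm", a, b)   # same contraction in Einstein notation
print(out1.shape, torch.allclose(out1, out2))  # torch.Size([2, 3, 5]) True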
Now, suppose we had a function $f: \mathbb{R}^{N_1 \times \cdots \times N_{D_x}} \to \mathbb{R}^{M_1 \times \cdots \times M_{D_y}}$, i.e., a function mapping an input tensor with $D_x$ dimensions to an output tensor with $D_y$ dimensions. The generalized Jacobian of $f$ at a point $x$ is another tensor with shape:
$$\left(M_1 \times \cdots \times M_{D_y}\right) \times \left(N_1 \times \cdots \times N_{D_x}\right)$$
The Mechanics of Backpropagation
At its core, backpropagation efficiently computes gradients for training neural networks by leveraging the chain rule. Here’s how it works:
- Forward Pass: Inputs flow through the network, producing intermediate activations and final outputs.
- Backward Pass: Gradients are calculated layer by layer in reverse, propagating partial derivatives back to earlier layers.
Suppose we have a function $z = f(x, y)$ as a single node in a computational graph. The output $z$ is passed through additional functions (nodes in the graph) before producing a final scalar loss $L$. The chain rule allows us to compute the gradient of $L$ with respect to $x$ and $y$ by multiplying the gradients of each intermediate function:
$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z} \frac{\partial z}{\partial x}, \qquad \frac{\partial L}{\partial y} = \frac{\partial L}{\partial z} \frac{\partial z}{\partial y}$$
where $\frac{\partial L}{\partial z}$ represents the chain of gradients from the loss to the output $z$, known as the upstream gradient. The goal is to compute the downstream gradients $\frac{\partial L}{\partial x}$ and $\frac{\partial L}{\partial y}$:
[downstream gradient] = [upstream gradient] * [local gradient]
Backprop by hand
Let’s walk through a simple example to illustrate backpropagation. We’ll use a 1-layer neural network and assume vectorized/batched computation, starting with input $X \in \mathbb{R}^{N \times D}$, where $N$ is the number of samples and $D$ is the input dimension.
Linear Layer:
A linear layer consists of weights $W \in \mathbb{R}^{D \times M}$ and bias $b \in \mathbb{R}^{M}$.
- Forward Pass: $Z = XW + b$ with $Z \in \mathbb{R}^{N \times M}$.
- Gradient Identities: $\frac{\partial Z}{\partial X} = W^\top$, $\frac{\partial Z}{\partial W} = X^\top$, $\frac{\partial Z}{\partial b} = I$ (shorthands made precise in the derivation below).
- Backward Pass: $\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Z} W^\top$, $\frac{\partial L}{\partial W} = X^\top \frac{\partial L}{\partial Z}$, $\frac{\partial L}{\partial b} = \sum_{n=1}^{N} \left(\frac{\partial L}{\partial Z}\right)_{n,:}$.
Derivation: Linear Layer Gradients
We have $Z = XW + b$.
Shapes: $X \in \mathbb{R}^{N \times D}$, $W \in \mathbb{R}^{D \times M}$, $b \in \mathbb{R}^{M}$, and $Z \in \mathbb{R}^{N \times M}$.
For simplicity, let’s consider the first row of $X$ and $Z$, or just $x = X_{1,:}$ and $z = Z_{1,:}$ (we’ll drop the subscript $1$). Then we have $x \in \mathbb{R}^{D}$ and $z = xW + b \in \mathbb{R}^{M}$.
Note: counter to normal convention, we use $i$ to index the columns of $W$ so that we can use $i$ to index into the output $z$.
What is $\frac{\partial z}{\partial x}$?
Notice that $z_i = \sum_k x_k W_{ki} + b_i$, so $\frac{\partial z_i}{\partial x_j} = \sum_k W_{ki} \frac{\partial x_k}{\partial x_j}$. In other words, $\frac{\partial x_k}{\partial x_j}$ is 1 if $k = j$ and 0 otherwise. Thus,
$$\frac{\partial z_i}{\partial x_j} = W_{ji}$$
By the same logic, we can write this as
$$\frac{\partial z}{\partial x} = W^\top \in \mathbb{R}^{M \times D}$$
For the general case, the same pattern holds. See Computing the Jacobian of a Matrix Product for more details.
Next, what is $\frac{\partial z}{\partial W}$? This is a bit more complicated because the full Jacobian is a 3-D tensor with shape $M \times D \times M$ (or $N \times M \times D \times M$ for the general batched case). To simplify the notation, we’ll consider the gradient of $z$ with respect to a single element $W_{jk}$. This is just a vector!
We have
$$z_i = \sum_a x_a W_{ai} + b_i$$
and
$$\frac{\partial z_i}{\partial W_{jk}} = \sum_a x_a \frac{\partial W_{ai}}{\partial W_{jk}}.$$
Note that $\frac{\partial W_{ai}}{\partial W_{jk}} = 1$ if $a = j$ and $i = k$, and 0 otherwise. This is just $\delta_{aj}\,\delta_{ik}$. So when $i \neq k$, the derivative is 0. Otherwise, the only nonzero element of the sum is when $a = j$, so we just get $x_j$. Thus, $\frac{\partial z_i}{\partial W_{jk}} = x_j \delta_{ik}$.
Finally, let’s compute the full gradient. Assume we have some upstream gradient $\frac{\partial L}{\partial z} \in \mathbb{R}^{1 \times M}$. Then, we have
$$\frac{\partial L}{\partial W_{jk}} = \sum_i \frac{\partial L}{\partial z_i} \frac{\partial z_i}{\partial W_{jk}}.$$
From above, we know that $\frac{\partial z_i}{\partial W_{jk}} = x_j \delta_{ik}$, so
$$\frac{\partial L}{\partial W_{jk}} = \frac{\partial L}{\partial z_k}\, x_j,$$
where $\frac{\partial L}{\partial z_k}$ is the $k$th element of the upstream gradient $\frac{\partial L}{\partial z}$, which in this case is technically a $1 \times M$ matrix. Similarly, we have been considering $x$ as a $1 \times D$ matrix, so inspecting the dimensions, we see that the full gradient is a $D \times M$ matrix given by
$$\frac{\partial L}{\partial W} = x^\top \frac{\partial L}{\partial z},$$
because a $D \times 1$ matrix times a $1 \times M$ matrix gives a $D \times M$ matrix, as expected.
In code:
class Linear:
def __init__(self, in_features: int, out_features: int):
self.weight = torch.empty(in_features, out_features)
self.bias = torch.zeros(out_features)
self.x: torch.Tensor | None = None
init_uniform_(self.weight)
init_uniform_(self.bias)
def parameters(self):
return [self.weight, self.bias]
def forward(self, x: torch.Tensor):
"""Computes the forward pass for a linear layer.
Args:
x (Tensor): Input data, of shape (N, in_features) where N is the batch size and
in_features is the number of input features.
"""
self.x = x
return x @ self.weight + self.bias # @ is torch.matmul
def backward(self, dout: torch.Tensor):
"""Computes the backward pass for a linear layer.
Args:
dout (Tensor): Upstream derivative, of shape (N, out_features).
Notes:
------------
z = xW + b
dz/dx = W.T
dz/dW = x.T
dz/db = I_{out}
chain rule
dL/dx = dout * dz/dx
= dout * W.T [N, out] x [out, in]
dL/dW = dout * dz/dW (need to swap to make shapes work out)
= x.T * dout [in, N] x [N, out]
dL/db = dout * dz/db
= dout * I_{out} [N, out] x [out, out]
"""
if self.x is None:
raise ValueError("No cache found. Run forward pass first.")
dx = dout.mm(self.weight.T) # [N, in_features]
        self.weight.grad = self.x.T.mm(dout)  # [in_features, out_features]
self.bias.grad = dout.sum(dim=0)
        return dx

    __call__ = forward
ReLU Activation:
To start with, recall that $\text{ReLU}(x) = \max(x, 0)$. This means that
$$\frac{\partial\, \text{ReLU}(x)}{\partial x} = \begin{cases} 1 & \text{if } x > 0 \\ 0 & \text{otherwise} \end{cases}$$
which is also $\mathbb{1}[x > 0]$.
def relu(x: torch.Tensor) -> torch.Tensor:
"""Computes max(0, x) element-wise."""
return x.clamp(min=0)
class ReLU:
def parameters(self):
return []
def forward(self, x: torch.Tensor):
"""Computes max(0, x) element-wise."""
self.x = x
return relu(x)
def backward(self, dout: torch.Tensor):
"""Computes the backward pass for a layer of rectified linear units (ReLUs).
ReLU(x) = max(0, x)
ReLU'(x) = 1 if x > 0 else 0
out = ReLU(x)
        dL/dx = dout * ReLU'(x)
              = dout * 1[x > 0]
        (i.e., zero out the entries of dout wherever x <= 0)
Args:
dout (Tensor): Upstream derivative, of shape (N, M).
"""
return dout * (self.x > 0).float()
__call__ = forward
Tanh Activation:
To start with, recall that $\tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$. This means that the derivative is:
$$\frac{d}{dx}\tanh(x) = 1 - \tanh^2(x)$$
class TanhLayer:
def forward(self, x):
self.output = torch.tanh(x)
return self.output
def backward(self, dout: torch.Tensor):
return dout * (1 - self.output**2)
def parameters(self):
return []
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
Softmax, LogSoftmax and LogSumExp
These functions commonly arise in the context of neural networks, especially in the output layer. They occur when dealing with categorical and multinomial probability distributions, converting raw logits into a probability distribution over classes, as well as in attention mechanisms in transformers. These functions are all closely related, involving sums of exponentials, and have nice relationships with each other.
Definitions
Given a vector $z \in \mathbb{R}^{K}$, we define the following functions:
$$\text{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}, \qquad \text{logsumexp}(z) = \log \sum_{j=1}^{K} e^{z_j}, \qquad \text{logsoftmax}(z)_i = z_i - \text{logsumexp}(z)$$
So we see that the softmax and logsoftmax are both functions that take in a vector and return a vector, while the logsumexp is a scalar-valued function. The logsoftmax is just the log of the softmax, which also happens to be the input minus its logsumexp.
And in code:
def softmax(z):
"""Compuze softmax values for each sets of scores in z.
Args:
z (Tensor): input tensor
Notes:
The definition of the softmax:
softmax(z) = exp(z) / sum(exp(z))
can be numerical unstable, so we use the following definition:
softmax(z) = exp(z) / sum(exp(z - max(z)))
which multiplies the numerator and denominator by a constant factor.
It is usually chosen to be -max(z) producing negative to zero range
to avoid numerical overflow.
"""
e_z = torch.exp(z - torch.max(z))
return e_z / e_z.sum()
def logsumexp(s: torch.Tensor, dim=-1):
"""Numerically stable log(sum(exp(s)))"""
max_val = s.max(dim=dim, keepdim=True)[0]
return torch.log(torch.sum(torch.exp(s - max_val), dim=dim, keepdim=True)) + max_val
def log_softmax(z: torch.Tensor, dim=-1):
"""Return logprobs"""
return z - logsumexp(z, dim=dim)
class SoftmaxLayer:
def forward(self, x):
# Save input for backward pass
self.input = x
# For numerical stability, subtract max value before exponential
# Shape: (batch_size, num_classes)
shifted_x = x - torch.max(x, dim=1, keepdim=True)[0]
# Compute exponentials
# Shape: (batch_size, num_classes)
exp_x = torch.exp(shifted_x)
# Compute sum of exponentials for normalization
# Shape: (batch_size, 1)
sum_exp = torch.sum(exp_x, dim=1, keepdim=True)
# Compute softmax output
# Shape: (batch_size, num_classes)
self.output = exp_x / sum_exp
return self.output
def backward(self, grad_output):
# grad_output shape: (batch_size, num_classes)
# output shape: (batch_size, num_classes)
# The Jacobian of softmax for each sample is:
# J[i,j] = output[i] * (δ[i,j] - output[j])
# where δ[i,j] is 1 if i=j and 0 otherwise
# Compute using the compact form:
# grad = output * (grad_output - (output · grad_output))
# Shape: (batch_size, 1)
sum_grad = torch.sum(grad_output * self.output, dim=1, keepdim=True) # s · ∇_s L
# Shape: (batch_size, num_classes)
return self.output * (grad_output - sum_grad) # s ⊙ (∇_s L - (s · ∇_s L))
    def parameters(self):
        return []

    def __call__(self, x):
        return self.forward(x)
Sigmoid:
While the softmax function is used for multi-class classification, the sigmoid function is used for binary classification. The primary difference is that the softmax enforces that the elements along a particular dimension sum to 1 (producing a probability distribution), while the sigmoid function squashes each output to the range $(0, 1)$, making it suitable for binary classification.
Definition
The sigmoid function, also known as the logistic function, maps any real number to a value between 0 and 1. It’s particularly useful for binary classification tasks since its output can be interpreted as a probability. The sigmoid function is defined as:
$$\sigma(x) = \frac{1}{1 + e^{-x}}$$
Its derivative has a particularly nice form:
$$\sigma'(x) = \sigma(x)\,(1 - \sigma(x))$$
which makes it efficient to compute during backpropagation. And in code:
def sigmoid(x: torch.Tensor) -> torch.Tensor:
return 1 / (1 + torch.exp(-x))
class SigmoidLayer:
def forward(self, x: torch.Tensor):
self.out = sigmoid(x)
return self.out
def backward(self, dout: torch.Tensor):
"""
sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))
"""
        return dout * self.out * (1 - self.out)

    def parameters(self):
        return []
Typically, in frameworks like PyTorch, models output logits (unnormalized scores) and it is usually the role of the criterion (e.g., CrossEntropyLoss) to implicitly convert the logits to probabilities using the softmax function.
Losses:
Cross Entropy:
Gradient of the cross-entropy loss w.r.t. the logits (i.e., with $L = \text{CE}(\text{softmax}(s), y)$, what is $\frac{\partial L}{\partial s}$):
$$\frac{\partial L}{\partial s} = \text{softmax}(s) - y$$
class CrossEntropyLoss:
"""Computes the forward pass for the cross-entropy loss.
CE(s, y) = -SUM_i y_i * log(softmax(s)_i)
    torch.mean(-torch.sum(targets * torch.log(torch.softmax(inputs, dim=1)), dim=1))
Args:
x (Tensor): Input data, of shape (N, C) where N is the batch size and
C is the number of classes.
y (Tensor): Ground truth labels, of shape (N,).
"""
def forward(self, input: torch.Tensor, target: torch.Tensor):
self.input = input
self.target = target
indices = torch.arange(target.shape[0])
return -log_softmax(input)[indices, target].mean()
def backward(self):
"""Compute the backward pass for the cross-entropy loss.
dL(x, y)/dx = softmax(x) - y
where x is the unnormalized score, y is the one-hot encoded target label.
dL(p, y)/dx = p - y (p is predicted prob)
Returns:
dx (Tensor): Gradient of the loss with respect to the input x.
"""
N = self.target.shape[0]
        dx = torch.softmax(self.input, dim=1)
        indices = torch.arange(N, device=self.input.device)
        dx[indices, self.target] -= 1
dx /= N
return dx
__call__ = forward
import torch
class BCELoss:
"""
Need to apply sigmoid activation to inputs before passing it here
"""
def forward(self, probs, target):
"""
BCELoss(x, y) = -[y * log(x) + (1 - y) * log(1 - x)]
Args:
probs (Tensor): Predictions from the model
target (Tensor): Ground truth labels
Returns:
Binary cross-entropy loss averaged over batch
"""
self.probs = probs
self.target = target
return -torch.mean(
target * torch.log(probs) + (1 - target) * torch.log(1 - probs)
)
def backward(self):
"""Compute the gradient of binary cross-entropy loss.
dBCE(p, y)/dp = (p - y) / (p * (1 - p))
"""
return (self.probs - self.target) / (self.probs * (1 - self.probs))
import torch
class BCEWithLogitsLoss:
def forward(self, input: torch.Tensor, target: torch.Tensor):
"""
input: (*) shape
target: (*) shape
"""
# Use logsumexp trick
# BCE(x, y) = (1 - y)*x + log(1 + e^{-x})
# = (1 - y)*x + log (e^0 + e^{-x})
zeros = torch.zeros_like(input)
m = torch.maximum(zeros, -input)
def lse(a, b):
return m + torch.log(torch.exp(a - m) + torch.exp(b - m))
return (1 - target) * input + lse(zeros, -input)
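The class above only implements the forward pass. For completeness, here is a sketch of the corresponding gradient, using the identity $\frac{\partial}{\partial x}\left[(1 - y)x + \log(1 + e^{-x})\right] = \sigma(x) - y$; the helper name `bce_with_logits_grad` is ours, and the check assumes the element-wise, unreduced loss:
import torch
import torch.nn.functional as F

def bce_with_logits_grad(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # d/dx [(1 - y) * x + log(1 + e^{-x})] = sigmoid(x) - y
    return torch.sigmoid(input) - target

# Quick check against autograd on the unreduced loss
x = torch.randn(6, requires_grad=True)
y = torch.randint(0, 2, (6,)).float()
loss = F.binary_cross_entropy_with_logits(x, y, reduction="none")
loss.backward(torch.ones_like(loss))
print(torch.allclose(x.grad, bce_with_logits_grad(x.detach(), y)))  # True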
Mean Squared Error (MSE) loss measures the average squared difference between predictions and true values. For a batch of $N$ samples:
$$\text{MSE}(x, y) = \frac{1}{N} \sum_{i=1}^{N} (x_i - y_i)^2$$
The gradient with respect to the predictions $x$ is:
$$\frac{\partial\, \text{MSE}}{\partial x_i} = \frac{2}{N} (x_i - y_i)$$
import torch
class MSELoss:
def forward(self, input: torch.Tensor, target: torch.Tensor):
"""Compute mean squared error loss.
Args:
y_pred: Model predictions
y_true: Ground truth values
Returns:
Mean squared error averaged over batch
"""
self.diff = input - target
return torch.mean(self.diff**2)
def backward(self, dout: torch.Tensor):
"""Compute gradient of MSE loss.
Returns:
Gradient with respect to predictions: 2/N * (pred - true)
"""
N = self.diff.shape[0]
return 2 * self.diff / N
L1 loss measures the mean absolute difference between predictions and true values, $\text{L1}(x, y) = \frac{1}{N} \sum_{i=1}^{N} |x_i - y_i|$. The gradient with respect to the predictions $x$ is:
$$\frac{\partial\, \text{L1}}{\partial x_i} = \frac{1}{N}\,\text{sign}(x_i - y_i)$$
import torch
class L1Loss:
def forward(self, input: torch.Tensor, target: torch.Tensor):
"""Compute L1 loss."""
self.diff = input - target
return torch.mean(torch.abs(self.diff))
def backward(self):
"""Compute gradient of L1 loss.
Returns:
Gradient with respect to predictions: 1/N * sign(pred - true)
"""
return torch.sign(self.diff) / self.diff.shape[0]
PyTorch’s CrossEntropyLoss
combines the softmax and negative log likelihood loss into a single operation. For a complete understanding, we can break it down into two parts:
- LogSoftmax: Take input $x \in \mathbb{R}^{C}$ and compute the log softmax:
$$z_i = x_i - \log \sum_{j} e^{x_j}$$
To compute the gradient, differentiate $z_i$ with respect to $x_k$, considering two cases:
  - When $i = k$: $\frac{\partial z_i}{\partial x_k} = 1 - \text{softmax}(x)_k$
  - When $i \neq k$: $\frac{\partial z_i}{\partial x_k} = -\,\text{softmax}(x)_k$
So in vectorized form:
$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z} - \text{softmax}(x) \sum_{j} \frac{\partial L}{\partial z_j}$$
In code:
import torch
import torch.nn.functional as F
class LogSoftmaxLayer:
def forward(self, x):
# x: (N, C)
# store input and output for backward
self.x = x
self.z = F.log_softmax(x, dim=1)
return self.z
def backward(self, dz):
# dz: gradient w.r.t. output z = log_softmax(x)
# We know dL/dx_i = dz_i - softmax(x)_i * sum_j dz_j
# But since softmax(x) = exp(x)/sum exp(x), we can quickly compute it:
# s = F.softmax(self.x, dim=1)
# or
s = torch.exp(self.z)
# Sum over classes: sum_j dz_j
sum_dz = torch.sum(dz, dim=1, keepdim=True)
return dz - s * sum_dz # dx
- Negative Log Likelihood Loss: given log probabilities $z = \log p$ and targets $y$, the negative log likelihood loss is defined as:
$$\text{NLL}(z, y) = -\frac{1}{N} \sum_{n=1}^{N} z_{n, y_n}$$
The gradient with respect to the log probabilities is:
$$\frac{\partial\, \text{NLL}}{\partial z_{n,c}} = -\frac{1}{N}\, \mathbb{1}[c = y_n]$$
import torch
class NLLLoss:
def forward(self, log_probs: torch.Tensor, target: torch.Tensor):
"""Compute negative log likelihood loss.
NLL = -1/N * sum_i log_probs[i, target[i]]
Args:
log_probs: Log probabilities, shape (N, C)
target: Ground truth values, shape (N,)
Returns:
Negative log likelihood loss
"""
self.log_probs = log_probs
self.target = target
N = target.shape[0]
return -log_probs[torch.arange(N), target].mean()
def backward(self):
"""Compute gradient of negative log likelihood loss.
dL/dlogp = -1/N * one_hot(target)
Returns:
        Gradient with respect to log probabilities: -1/N at the target index, 0 elsewhere
"""
N = self.target.shape[0]
grad = torch.zeros_like(self.log_probs)
grad[torch.arange(N), self.target] = -1.0
grad /= N
return grad
All together:
$$\frac{\partial L}{\partial x_k} = \sum_{i} \frac{\partial L}{\partial z_i} \frac{\partial z_i}{\partial x_k}$$
Evaluating the summation, minding that $\frac{\partial z_i}{\partial x_k} = 1 - \text{softmax}(x)_k$ when $i = k$ and $-\,\text{softmax}(x)_k$ otherwise:
$$\frac{\partial L}{\partial x_k} = \frac{\partial L}{\partial z_k} - \text{softmax}(x)_k \sum_{i} \frac{\partial L}{\partial z_i}$$
Finally, we know that $\frac{\partial L}{\partial z} = -y$ (up to the $1/N$ factor), where $y$ is a one-hot encoded vector that is zero everywhere except at the correct class, so the sum collapses to:
$$\frac{\partial L}{\partial x} = \text{softmax}(x) - y$$
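We can verify this identity numerically against PyTorch’s built-in `F.cross_entropy` (a small sketch; with the default mean reduction, the manual gradient also picks up a $1/N$ factor):
import torch
import torch.nn.functional as F

N, C = 4, 5
x = torch.randn(N, C, requires_grad=True)
y = torch.randint(0, C, (N,))

F.cross_entropy(x, y).backward()  # mean reduction over the batch

with torch.no_grad():
    manual = F.softmax(x, dim=1)           # softmax(x)
    manual[torch.arange(N), y] -= 1        # softmax(x) - y (one-hot)
    manual /= N                            # account for the mean reduction
print(torch.allclose(x.grad, manual, atol=1e-6))  # True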
Evaluating Gradients: Numerical Gradient Checking
While we’ve derived the gradients for common operations, it’s crucial to validate our implementations. One way to do this is by comparing the analytical gradients with numerical approximations. Numerical gradient checking is a technique to verify the correctness of our gradients by approximating them using finite differences.
The central difference formula approximates the gradient of a function $f$ at a point $x$ by evaluating the function at two nearby points $x + \epsilon$ and $x - \epsilon$:
$$\frac{\partial f}{\partial x} \approx \frac{f(x + \epsilon) - f(x - \epsilon)}{2\epsilon}$$
By comparing the analytical gradients with numerical approximations, we can catch errors in our implementations and ensure that our models are learning correctly.
Here’s a simple implementation of gradient checking for any layer implementing the Layer
protocol (which you’ll see in the next section):
import torch
from activations import ReLU, SigmoidLayer, SoftmaxLayer, TanhLayer
from layers import Layer, Linear
def eval_numerical_gradients(
layer: Layer,
x: torch.Tensor,
epsilon: float = 1e-6,
tolerance: float = 1e-5,
):
print(f"Evaluating grads for: {type(layer).__name__}")
# Store original parameters
original_params = [p.data.clone() for p in layer.parameters()]
# Compute the analytical gradients
output = layer.forward(x)
dout = torch.rand_like(output)
dx = layer.backward(dout)
analytical_grads = [
p.grad.data.clone() for p in layer.parameters() if p.grad is not None
]
analytical_input_grad = dx.clone()
# Compute numerical gradients
numerical_grads = []
for i, params in enumerate(layer.parameters()):
param_grad = torch.zeros_like(params)
flat_params = params.flatten()
flat_grads = param_grad.flatten()
for i in range(params.numel()):
orig_val = flat_params.data[i].clone()
# Compute f(x + epsilon)
flat_params.data[i] = orig_val + epsilon
output_plus = layer.forward(x)
f_plus = (output_plus * dout).sum()
# Compute f(x - epsilon)
flat_params.data[i] = orig_val - epsilon
output_minus = layer.forward(x)
f_minus = (output_minus * dout).sum()
# Compute numerical gradient
flat_grads.data[i] = (f_plus - f_minus) / (2 * epsilon)
# Restore original value
flat_params.data[i] = orig_val
numerical_grads.append(param_grad)
numerical_input_grad = torch.zeros_like(dx)
flat_x = x.flatten()
flat_input_grads = numerical_input_grad.flatten()
for i in range(x.numel()):
orig_val = flat_x.data[i].clone()
# Compute f(x + epsilon)
flat_x.data[i] = orig_val + epsilon
output_plus = layer.forward(x)
f_plus = (output_plus * dout).sum()
# Compute f(x - epsilon)
flat_x.data[i] = orig_val - epsilon
output_minus = layer.forward(x)
f_minus = (output_minus * dout).sum()
# Compute numerical gradient
flat_input_grads.data[i] = (f_plus - f_minus) / (2 * epsilon)
flat_x.data[i] = orig_val
# Compare analytical and numerical gradients
for i, (analytical, numerical) in enumerate(zip(analytical_grads, numerical_grads)):
diff = torch.abs(analytical - numerical).max().item()
print(f"Max difference for parameter {i}: {diff}")
print(f"Relative error: {relative_error(analytical, numerical)}")
assert (
diff < tolerance
), f"Gradient check failed for parameter {i}. Max difference: {diff}"
# Compare analytical and numerical gradients for input
input_diff = torch.abs(analytical_input_grad - numerical_input_grad).max().item()
print(f"Max difference for input: {input_diff}")
print(
f"Relative error for input: {relative_error(analytical_input_grad, numerical_input_grad)}"
)
assert (
input_diff < tolerance
), f"Gradient check failed for input. Max difference: {input_diff}"
# Restore original parameters
for param, original in zip(layer.parameters(), original_params):
param.data.copy_(original)
print("Gradient check passed!")
def relative_error(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-11):
return (torch.abs(a - b) / (torch.abs(a) + torch.abs(b) + eps)).mean()
if __name__ == "__main__":
# For the Linear layer
dtype = torch.float64
linear = Linear(10, 5).to(dtype=dtype)
x_linear = torch.randn(32, 10, dtype=dtype)
eval_numerical_gradients(linear, x_linear)
relu = ReLU()
x_relu = torch.randn(32, 10, dtype=dtype)
eval_numerical_gradients(relu, x_relu)
x = torch.randn(32, 10, dtype=dtype)
eval_numerical_gradients(SigmoidLayer(), x)
x = torch.randn(32, 10, dtype=dtype)
eval_numerical_gradients(SoftmaxLayer(), x)
x = torch.randn(32, 10, dtype=dtype)
eval_numerical_gradients(TanhLayer(), x)
Putting it all together: Training a small MLP on MNIST
Now that we have all the building blocks in place, let’s put them together to train a simple multi-layer perceptron (MLP) on the MNIST dataset. This classic computer vision task involves classifying 28x28 grayscale images of handwritten digits (0-9).
Our MLP will consist of multiple linear layers with ReLU activations between them, followed by a final linear layer that outputs logits for each of the 10 digit classes. We’ll use cross-entropy loss and stochastic gradient descent (SGD) for optimization.
The training loop follows these key steps:
- Forward pass through the network to get predictions
- Calculate loss using cross-entropy
- Backward pass to compute gradients
- Update parameters using SGD
- Repeat!
This end-to-end example demonstrates how the individual components we’ve built - layers, activations, loss functions, optimizers - work together during training. It also showcases common training practices like:
- Using mini-batches for efficiency
- Tracking metrics like accuracy and loss
- Validating on a test set
- Logging progress during training
While modern deep learning frameworks abstract away much of this complexity, understanding these fundamentals is invaluable for debugging, optimization, and developing new techniques.
Let’s examine the code and dissect what’s happening at each step:
First, let’s examine how we’ve structured our custom layers to mirror PyTorch’s Module interface. Our layers accept input tensors in their forward methods, maintain their own parameters and gradients, and properly chain the backward pass. Just like PyTorch modules, they:
- Have forward() and backward() methods for the computation graph
- Track parameters that need gradients
- Support function-like calling through call
- Maintain state between forward and backward passes
This design pattern makes it easy to compose layers into networks and perform automatic differentiation in a familiar way. The Layer protocol we defined above formalizes this interface.
Layer and Optim Protocols
class Layer(Protocol):
def forward(self, x: torch.Tensor) -> torch.Tensor: ...
def backward(self, dout: torch.Tensor) -> torch.Tensor: ...
def parameters(self) -> list[torch.Tensor]: ...
def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
class Optimizer(Protocol):
params: list[torch.Tensor]
def __init__(self, params: list[torch.Tensor], lr: float):
self.params = params
def step(self): ...
def zero_grad(self): ...
As you can see above, I’ve also defined optimizers with a Protocol interface to establish a standard contract for optimizer types. Following these protocols helps maintain consistency and type safety in our implementations. When we extend this to new optimizers, they must provide step()
and zero_grad()
methods that match this interface, while maintaining the expected parameter update behavior.
Network Architecture
In our simple network architecture, we have a feedforward neural network (also called a multilayer perceptron or MLP) that takes 784-dimensional input (flattened 28x28 MNIST images) and outputs 10-dimensional logits (one for each digit class). The network consists of:
- Input layer: Flattens 28x28 images into 784-dimensional vectors
- Three hidden layers with ReLU activations between them, each with size hidden_size (default 1024)
- Output layer: Linear layer producing 10 logits
- Cross entropy loss for training
The network uses standard linear layers with learned weight matrices and bias vectors, with ReLU nonlinearities between each layer to allow learning of non-linear functions.
Each linear layer performs an affine transformation: output = input @ weight + bias. The ReLU activation applies an elementwise nonlinearity: output = max(input, 0).
This relatively simple architecture is still capable of achieving ~98% accuracy on MNIST with proper training. The multiple hidden layers allow the network to learn hierarchical features from the input digits.
import torch.nn as nn
from activations import ReLU
from layers import Layer, Linear
from losses import CrossEntropyLoss
class Net:
def __init__(self, hidden_size=1024, num_classes=10):
self.layers: list[Layer] = [
Linear(784, hidden_size),
ReLU(),
Linear(hidden_size, hidden_size),
ReLU(),
Linear(hidden_size, hidden_size),
ReLU(),
Linear(hidden_size, num_classes),
]
self.loss = CrossEntropyLoss()
def forward(self, x):
x = x.view(x.shape[0], -1)
for layer in self.layers:
x = layer(x)
return x
def backward(self, dout):
for layer in reversed(self.layers):
dout = layer.backward(dout)
return dout
def parameters(self):
return [param for layer in self.layers for param in layer.parameters()]
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
class PytorchNet(nn.Module):
def __init__(self, hidden_size=1024, num_classes=10):
super().__init__()
self.layers = nn.Sequential(
*[
nn.Linear(784, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, num_classes),
]
)
def forward(self, x):
x = x.view(x.shape[0], -1)
return self.layers(x)
Optimizers
Let’s look at different optimizers and their implementations:
class SGD:
params: list[torch.Tensor]
def __init__(self, params: list[torch.Tensor], lr: float = 0.01):
self.params = params
self.lr = lr
def step(self):
if self.lr <= 0:
raise ValueError("Learning rate must be positive")
if self.params and self.params[0].grad is None:
raise ValueError("Parameters have no gradients. Call backward first.")
for param in self.params:
param -= self.lr * param.grad
def zero_grad(self):
for param in self.params:
param.grad = torch.zeros_like(param)
Momentum helps accelerate gradients in the right direction by accumulating a velocity vector in directions of persistent reduction in the objective:
class SGDMomentum:
def __init__(self, params: list[torch.Tensor], lr: float = 0.01, momentum: float = 0.9):
self.params = params
self.lr = lr
self.momentum = momentum
self.velocities = [torch.zeros_like(param) for param in self.params]
def step(self):
"""
dW = momentum * dW - lr * dL/dW (perform grad update on velocity weight)
W = W + dW
"""
if self.lr <= 0:
raise ValueError("Learning rate must be positive")
if self.params and self.params[0].grad is None:
raise ValueError("Parameters have no gradients. Call backward first.")
for i, param in enumerate(self.params):
            self.velocities[i] = self.momentum * self.velocities[i] - self.lr * param.grad
            param += self.velocities[i]

    def zero_grad(self):
        for param in self.params:
            param.grad = torch.zeros_like(param)
RMSprop maintains per-parameter learning rates adapted based on the root mean square of recent gradients. In other words, it divides the learning rate for a given weight by a running average of the magnitudes of recent gradients for that weight:
class RMSprop:
def __init__(
self,
params: list[torch.Tensor],
lr=0.01,
alpha=0.99,
momentum: float = 0,
eps=1e-8,
):
self.params = params
self.lr = lr
self.alpha = alpha
self.eps = eps
self.square_avg = [torch.zeros_like(p) for p in params]
self.momentum = momentum
if momentum > 0:
self.buffers = [torch.zeros_like(p) for p in params]
def step(self):
"""
v = alpha * v + (1 - alpha) * (dL/dW)^2
W = W - lr * dL/dW / sqrt(v + eps)
"""
if self.lr <= 0:
raise ValueError("Learning rate must be positive")
if self.params and self.params[0].grad is None:
raise ValueError("Parameters have no gradients. Call backward first.")
for i, param in enumerate(self.params):
grad: torch.Tensor = param.grad # type: ignore
self.square_avg[i] = (
self.alpha * self.square_avg[i] + (1 - self.alpha) * grad**2
)
            if self.momentum > 0:
                # PyTorch-style momentum buffer: accumulate the preconditioned gradient
                self.buffers[i] = self.momentum * self.buffers[i] + grad / (
                    torch.sqrt(self.square_avg[i]) + self.eps
                )
                self.params[i] -= self.lr * self.buffers[i]
            else:
                self.params[i] -= (
                    self.lr * grad / (torch.sqrt(self.square_avg[i]) + self.eps)
                )

    def zero_grad(self):
        for param in self.params:
            param.grad = torch.zeros_like(param)
Adam combines the best of RMSprop and momentum to achieve faster convergence. It uses the first and second moments of the gradients to adapt the learning rate for each parameter:
class Adam:
def __init__(
self,
params: list[torch.Tensor],
lr=0.001,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
):
self.params = params
self.lr = lr
self.beta1, self.beta2 = betas
self.eps = eps
self.m = [torch.zeros_like(p) for p in params] # First moment
self.v = [torch.zeros_like(p) for p in params] # Second moment
self.t = 0
self.weight_decay = weight_decay
def step(self):
self.t += 1
for i, (param, m, v) in enumerate(zip(self.params, self.m, self.v)):
grad: torch.Tensor = param.grad # type: ignore
if self.weight_decay > 0:
grad += self.weight_decay * param
# Update biased first moment estimate
            self.m[i] = self.beta1 * m + (1 - self.beta1) * grad
            # Update biased second raw moment estimate
            self.v[i] = self.beta2 * v + (1 - self.beta2) * grad**2
# Bias correction
m_hat = self.m[i] / (1 - self.beta1**self.t)
v_hat = self.v[i] / (1 - self.beta2**self.t)
# Update parameters
param -= self.lr * m_hat / (torch.sqrt(v_hat) + self.eps)
def zero_grad(self):
for param in self.params:
param.grad = torch.zeros_like(param)
While Adam
is the de-facto optimizer for most deep learning tasks, it’s always good to experiment with different optimizers to see what works best for your specific problem. For example, for this problem, I found that RMSprop
with momentum converges to nearly 100% accuracy on MNIST in just a few epochs.
Training loop:
The training loop employs several best practices for modern deep learning:
- DataLoader with batching and shuffling for efficient training
- Using AverageMeter class for tracking running statistics
- Regular logging of train/val metrics at configurable intervals
- Clean organization with separate model, criterion, optimizer
- Proper handling of train vs eval modes
- Time tracking for performance analysis
- Configurable hyperparameters like batch size, learning rate
- Progress tracking with global step counter
- Memory efficient by avoiding unnecessary tensor allocation
The validation loop mirrors the training structure but disables gradients for efficiency. Overall, the code follows a modular design that would be familiar to PyTorch users while implementing core functionality from scratch.
Full Training loop
import time
import torch
from data_utils import get_data
from losses import CrossEntropyLoss
from net import Net
from optimizers import get_optimizer
from torch_utils import AverageMeter, Logger, accuracy
def train(epochs=20, batch_size=512, lr=0.01, log_interval=10, val_interval=1):
data_iter = get_data(batchsize=batch_size, debug=False)
net = Net()
criterion = CrossEntropyLoss()
optimizer = get_optimizer("rmsprop", net.parameters(), lr=lr)
logger = Logger()
start_time = time.time()
global_step = 0
for epoch in range(epochs):
loss_meter = AverageMeter()
acc_meter = AverageMeter()
batch_time = AverageMeter()
for batch_idx, (x, y) in enumerate(data_iter("train")):
batch_start = time.time()
out = net(x)
loss = criterion(out, y)
optimizer.zero_grad()
dout = criterion.backward()
net.backward(dout)
optimizer.step()
loss_meter.update(loss.item())
acc_meter.update(accuracy(out, y))
batch_time.update(time.time() - batch_start)
global_step += 1
if (batch_idx + 1) % log_interval == 0:
logger.log(
{
"train_loss": loss_meter.avg,
"train_acc": acc_meter.avg,
"batch_time": batch_time.avg,
},
global_step,
)
# Validation
if (epoch + 1) % val_interval == 0:
val_loss_meter = AverageMeter()
val_acc_meter = AverageMeter()
val_batch_time = AverageMeter()
with torch.no_grad():
for x, y in data_iter("test"):
batch_start = time.time()
out = net(x)
loss = criterion(out, y)
val_loss_meter.update(loss.item())
val_acc_meter.update(accuracy(out, y))
val_batch_time.update(time.time() - batch_start)
logger.log(
{
"val_loss": val_loss_meter.avg,
"val_acc": val_acc_meter.avg,
"val_batch_time": val_batch_time.avg,
},
global_step,
)
end_time = time.time()
print(f"Training completed in {end_time - start_time:.2f} seconds")
train()
Building a Simplified Tensor and Autograd Engine (aka “Micrograd”)
Now that we’ve covered the fundamentals of backpropagation and built a basic neural network training loop from scratch, let’s dive deeper by implementing a minimal version of an automatic differentiation engine - similar to PyTorch’s autograd but significantly simplified. This exercise, sometimes called “building micrograd,” will help cement our understanding of how modern deep learning frameworks work under the hood. Unlike before, where we manually had to pass gradients through the network, this engine will automatically compute gradients for us.
The key insight of automatic differentiation is that we can automatically compute gradients through arbitrary computation graphs by tracking operations and their corresponding derivative rules. Our implementation will support basic tensor operations and automatic gradient computation through the chain rule - providing a foundation similar to what powers frameworks like PyTorch and TensorFlow. We are truly lucky to have access to powerful deep learning libraries like PyTorch, TensorFlow, and JAX, which hide much of the complexity of automatic differentiation and GPU acceleration behind user-friendly APIs. However, building a simplified version of these libraries from scratch can deepen our understanding of how they work under the hood.
Tensor Class
We start by defining a Tensor
class that wraps a torch.Tensor
or np.ndarray
and tracks gradients. Our Tensor
class should support operations like addition, multiplication, and backpropagation. The goal is to create a basic computational graph that can compute gradients using reverse-mode automatic differentiation (backpropagation).
from typing import Callable, TypeAlias, Union
import numpy as np
import torch
TensorData = torch.Tensor | np.ndarray | float | int | list[float | int] # Underlying data type
Operand: TypeAlias = Union["Tensor", TensorData] # Input type for operations
class Tensor:
def __init__(
self,
data: TensorData,
requires_grad: bool = True,
_children: tuple["Tensor", ...] = (),
_op: str | None = None,
):
self.data = self._to_torch(data) # Convert underlying data to torch.Tensor
self.requires_grad = requires_grad # Flag to track gradients
self.grad = torch.zeros_like(self.data) if requires_grad else None
self._backward: Callable | None = None # Function to compute gradients
self._prev = set(_children) # Parents in the computation graph
self._op = _op # Operation (name) that produced this tensor
# (..., more methods to come)
Above, we’ve defined the Tensor
class with the following key attributes:
data
: The underlying tensor data (converted to atorch.Tensor
)requires_grad
: A flag to track gradientsgrad
: The gradient of the tensor (initialized to zeros ifrequires_grad
)_backward
: A function to compute gradients (to be defined later) - each operation will define its own_backward
function
We’ve also included a _prev
set to track the parents of the current tensor in the computation graph. This set will be used to construct the computation graph during the forward pass and perform the backward pass during gradient computation.
Finally, _to_torch is a helper method that converts raw input data to a torch.Tensor for consistency, and the related _to_tensor helper (used by the operator methods below) wraps a raw operand in a Tensor.
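Neither helper is shown above, so here is a minimal sketch of what they might look like inside the Tensor class (the bodies are assumptions; a real implementation may handle more dtypes and devices):
    # Inside the Tensor class (sketch)
    @staticmethod
    def _to_torch(data: TensorData) -> torch.Tensor:
        """Convert raw data (scalar, list, ndarray, or torch.Tensor) to a torch.Tensor."""
        if isinstance(data, torch.Tensor):
            return data.detach().clone()
        return torch.as_tensor(data, dtype=torch.float32)

    def _to_tensor(self, other: Operand) -> "Tensor":
        """Wrap a raw operand in a (non-grad-tracking) Tensor so ops treat both inputs uniformly."""
        if isinstance(other, Tensor):
            return other
        return Tensor(other, requires_grad=False)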
Tracking Operations
Now let’s implement the basic operations that our Tensor class should support. Each operation needs to:
- Compute the forward result
- Define a backward function that specifies how to compute gradients
- Track the computation graph by storing the input tensors as parents of the output tensor
Here are the key operations we’ll implement:
- Basic arithmetic (+, -, *, /, @)
- Elementwise functions (ReLU, sigmoid, tanh)
- Reduction operations (sum, mean)
- Matrix operations (matmul)
Implementing Basic Arithmetic Operations
Let’s examine how multiplication works as an archetypal example of implementing tensor operations. When multiplying two tensors $x$ and $y$ to produce $z = x * y$, three key components work together:
- Forward Pass: Computes the result and stores it in a new tensor
- Gradient Rules: Implements the partial derivatives needed for backpropagation:
  - $\frac{\partial z}{\partial x} = y$ (gradient with respect to first input)
  - $\frac{\partial z}{\partial y} = x$ (gradient with respect to second input)
- Graph Construction: Records the computational history by:
  - Creating a new tensor to hold the result
  - Storing references to the input tensors $x$ and $y$
  - Defining a backward function that knows how to compute gradients using the stored inputs
This three-part structure creates a computational graph that enables automatic differentiation through:
- Forward propagation that builds the computation graph by tracking dependencies
- Backward traversal through the graph starting from the final output
- Gradient accumulation at each node according to the chain rule of calculus
Each operation in our tensor library will follow this same pattern, just with different forward computations and gradient formulas. The consistent structure allows automatic differentiation to work uniformly across all operations.
Let’s look at an example implementing these operations:
def __mul__(self, other: Operand) -> "Tensor":
"""Implement multiplication operation.
Args:
self: current tensor (`x`)
other: Second operand for multiplication (`y`)
Returns:
Tensor: Result of the multiplication operation (`z = x * y`)
"""
other = self._to_tensor(other)
requires_grad = self.requires_grad or other.requires_grad
data = self.data * other.data # Forward pass: z = x * y
out = Tensor(
data, # Result of the operation
requires_grad, # Flag to track gradients
(self, other), # Parents in the computation graph (x, y)
_op="mul" # Operation name
)
def _backward():
"""Compute gradients for the multiplication operation."""
if self.requires_grad:
# Gradient with respect to x: dz/dx = y (stored in self.grad (x.grad))
self.grad += (other.data * out.grad).type_as(self.grad)
if other.requires_grad:
# Gradient with respect to y: dz/dy = x (stored in other.grad (y.grad))
other.grad += (self.data * out.grad).type_as(other.grad)
out._backward = _backward
return out
The backward method
The backward
method is where the magic of automatic differentiation happens. This method orchestrates the computation of gradients through the entire computational graph. When called on a tensor, it:
- Initializes the gradient (typically 1.0 for scalar outputs)
- Performs a topological sort of the computation graph
- Applies the chain rule in reverse order through the sorted nodes
The topological sort is crucial because it ensures we compute gradients in the correct order - from the output back through to the inputs. Let’s dive into how this process works in detail with the implementation below:
def backward(self, gradient: Operand | None = None):
if gradient is None:
if self.data.ndim == 0 or self.data.numel() == 1:
gradient = torch.tensor(1.0, requires_grad=False)
else:
raise RuntimeError(
"grad can be implicitly created only for scalar outputs"
)
# Initialize gradient accumulation
if self.grad is None:
self.grad = torch.zeros_like(
self.data, requires_grad=False, dtype=get_dtype(gradient)
)
self.grad = self.grad.type(get_dtype(gradient)) + gradient
    # topological sort for backward pass
topo: list["Tensor"] = []
visited = set()
def build_topo(v: "Tensor"):
"""Topological sort helper function.
Does a depth-first search on the computation graph to build a topological ordering.
"""
if v not in visited:
visited.add(v)
for child in v._prev:
build_topo(child)
topo.append(v)
build_topo(self)
# Apply the chain rule to nodes in reverse order
for v in reversed(topo):
if v._backward is not None:
v._backward()
Understanding how the computational graph is built
Let’s break down the topological sort step by step, as it’s crucial for ensuring we compute gradients in the right order. Let’s use a concrete example:
# Let's say we have this computation:
x = Tensor(2.0, requires_grad=True) # Let's call this node x
y = x * 3 # Let's call this node y
z = y + 1 # Let's call this node z
z.backward() # We start backward from z
When we call z.backward()
, here’s what happens:
- First, let’s understand what we’re trying to solve:
- We need to compute gradients in reverse order of the computation
- $z$ depends on $y$, which depends on $x$
- We must calculate $\frac{\partial z}{\partial y}$ and $\frac{\partial z}{\partial x}$
- The
build_topo
function does depth-first search:
topo = []
visited = set()
def build_topo(v: Tensor):
if v not in visited:
visited.add(v)
for child in v._prev:
build_topo(child)
topo.append(v)
build_topo(self)
- Let’s trace
build_topo(self)
for our example:
# Starting with node z
build_topo(z):
visited = {} # Empty set initially
z not in visited → True
Add z to visited = {z}
Look at z._prev = {y} # y was an input to z
# Recursively call for y
build_topo(y):
y not in visited → True
Add y to visited = {z, y}
Look at y._prev = {x} # x was an input to y
# Recursively call for x
build_topo(x):
x not in visited → True
Add x to visited = {z, y, x}
x._prev is empty # x had no inputs
Append x to topo → [x]
Append y to topo → [x, y]
Append z to topo → [x, y, z]
- Finally, we reverse the list and apply backward:
for v in reversed(topo):
    if v._backward is not None:
v._backward()
This processes the gradients in reverse order of computation:
- First computes
z._backward()
- Then
y._backward()
- Finally
x._backward()
The backward functions we defined for each operation accumulate the gradients in the .grad
attribute of each tensor according to the chain rule.
Getting back to our example:
x = Tensor(2.0, requires_grad=True)
y = x * 3
z = y + 1
z.backward()
The backward pass computes:
- dz/dy = 1 (derivative of addition)
- dz/dx = dz/dy * dy/dx = 1 * 3 = 3 (chain rule)
After z.backward() completes:
- z.grad = 1.0 (initialized gradient)
- y.grad = 1.0 (dz/dy)
- x.grad = 3.0 (dz/dx)
This matches our expectation since z = (x * 3) + 1, so dz/dx = 3.
Implementing other operations
At the heart of any modern neural network is the matrix multiplication operation. Let’s implement the matmul
operation, which computes the matrix product of two tensors. This operation is crucial for linear layers in neural networks and is the backbone of deep learning computations.
def __matmul__(self, other: Operand) -> "Tensor":
other = self._to_tensor(other)
requires_grad = self.requires_grad or other.requires_grad
data = self.data @ other.data
out = Tensor(data, requires_grad, (self, other), _op="matmul")
def _backward():
if self.requires_grad:
self.grad += out.grad @ other.data.T
if other.requires_grad:
other.grad += self.data.T @ out.grad
out._backward = _backward
return out
def __rmatmul__(self, other):
return self.__matmul__(other)
Implementing Other Operations
def __add__(self, other: Operand) -> "Tensor":
other = self._to_tensor(other)
requires_grad = self.requires_grad or other.requires_grad
data = self.data + other.data
out = Tensor(data, requires_grad, (self, other), _op="add")
def _backward():
if self.requires_grad:
self.grad += out.grad # type: ignore
if other.requires_grad:
                other.grad += out.grad  # type: ignore
out._backward = _backward
return out
def sum(self):
out = Tensor(
self.data.sum(),
requires_grad=self.requires_grad,
_children=(self,),
_op="sum",
)
def _backward():
if self.requires_grad:
self.grad += torch.ones_like(self.data, requires_grad=False) * out.grad
out._backward = _backward
return out
def __radd__(self, other):
return self.__add__(other)
def __rmul__(self, other):
return self.__mul__(other)
def __pow__(self, other: Operand) -> "Tensor":
other = self._to_tensor(other)
requires_grad = self.requires_grad or other.requires_grad
data = self.data**other.data
out = Tensor(data, requires_grad, (self, other), _op="pow")
def _backward():
if self.requires_grad:
# power rule: d/dx x^n = n * x^(n-1)
self.grad += other.data * self.data ** (other.data - 1) * out.grad
if other.requires_grad:
# d/dx a^x = a^x * ln(a)
other.grad += (self.data**other.data) * torch.log(self.data) * out.grad
out._backward = _backward
return out
def __neg__(self) -> "Tensor":
return self * -1
def __sub__(self, other: Operand) -> "Tensor":
return self + -(self._to_tensor(other))
def __rsub__(self, other: Operand) -> "Tensor":
return self._to_tensor(other) + (-self)
def __truediv__(self, other: Operand) -> "Tensor":
other = self._to_tensor(other)
requires_grad = self.requires_grad or other.requires_grad
data = self.data / other.data
out = Tensor(data, requires_grad, (self, other), _op="div")
def _backward():
if self.requires_grad:
# d/dx (a / b) = 1 / b
self.grad += (1 / other.data) * out.grad
if other.requires_grad:
# d/db (a / b) = -a / b^2
other.grad += (-self.data / other.data**2) * out.grad
out._backward = _backward
return out
def __rtruediv__(self, other: Operand) -> "Tensor":
return self._to_tensor(other) / self
def exp(self) -> "Tensor":
out = Tensor(
torch.exp(self.data),
self.requires_grad,
(self,),
_op="exp",
)
def _backward():
if self.requires_grad:
assert self.grad is not None
self.grad += (out.data * out.grad).type_as(self.grad)
out._backward = _backward
return out
def log(self) -> "Tensor":
out = Tensor(
torch.log(self.data),
self.requires_grad,
(self,),
_op="log",
)
def _backward():
if self.requires_grad:
assert self.grad is not None
self.grad += ((1 / self.data) * out.grad).type_as(self.grad)
out._backward = _backward
return out
def sigmoid(self) -> "Tensor":
out = Tensor(
torch.sigmoid(self.data),
self.requires_grad,
(self,),
_op="sigmoid",
)
def _backward():
if self.requires_grad:
assert self.grad is not None
self.grad += (out.data * (1 - out.data) * out.grad).type_as(self.grad)
out._backward = _backward
return out
def tanh(self) -> "Tensor":
out = Tensor(
torch.tanh(self.data),
self.requires_grad,
(self,),
_op="tanh",
)
def _backward():
if self.requires_grad:
assert self.grad is not None
self.grad += (1 - out.data**2) * out.grad
out._backward = _backward
return out
def sum_dim(self, dim: int, keepdim: bool = False) -> "Tensor":
data = self.data.sum(dim=dim, keepdim=keepdim)
out = Tensor(data, self.requires_grad, (self,), _op="sum_dim")
def _backward():
if self.requires_grad:
grad = out.grad
assert grad is not None
if not keepdim:
shape = list(self.data.shape)
shape[dim] = 1
grad = grad.view(shape)
self.grad += grad.expand_as(self.data)
out._backward = _backward
return out
def mean(self) -> "Tensor":
data = self.data.mean()
out = Tensor(data, self.requires_grad, (self,), _op="mean")
def _backward():
if self.requires_grad:
assert out.grad is not None
assert self.grad is not None
self.grad += (out.grad / self.data.numel()).type_as(self.grad)
out._backward = _backward
return out
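With these operations in place, we can sanity-check the engine against PyTorch’s autograd on a small expression (a sketch; `micrograd_tensor` is a hypothetical module name for the Tensor class above):
import torch
from micrograd_tensor import Tensor  # hypothetical module holding the Tensor class above

# Our engine
x = Tensor([[1.0, -2.0], [3.0, 0.5]])
w = Tensor([[0.1, 0.2], [0.3, 0.4]])
(x @ w).tanh().sum().backward()

# Reference computation with torch autograd
xt = torch.tensor([[1.0, -2.0], [3.0, 0.5]], requires_grad=True)
wt = torch.tensor([[0.1, 0.2], [0.3, 0.4]], requires_grad=True)
(xt @ wt).tanh().sum().backward()

print(torch.allclose(x.grad, xt.grad))  # expect True
print(torch.allclose(w.grad, wt.grad))  # expect True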
Challenges in Practice
Despite its elegance, backpropagation faces practical hurdles in real-world scenarios:
-
Vanishing/Exploding Gradients: Deep networks can suffer from gradients shrinking to zero or growing too large (a short demonstration follows this list).
- Solution: Use proper activation functions (e.g., ReLU), initialization strategies (e.g., Xavier or He), and normalization techniques.
-
Numerical Precision Errors: Tiny gradients may be lost in floating-point arithmetic.
- Solution: Use mixed-precision training and gradient clipping.
-
Computational Overheads: Large networks are expensive to train.
- Solution: Use mini-batching, efficient optimizers, and gradient checkpointing.
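To make the first point concrete, here is a short demonstration (a sketch using plain PyTorch): a deep stack of small sigmoid layers drives the gradient at the first layer toward zero, while the last layer still receives a healthy signal.
import torch
import torch.nn as nn

torch.manual_seed(0)
layers = [nn.Sequential(nn.Linear(32, 32), nn.Sigmoid()) for _ in range(30)]
net = nn.Sequential(*layers)

x = torch.randn(8, 32)
net(x).sum().backward()

first = net[0][0].weight.grad.norm().item()   # first Linear layer
last = net[-1][0].weight.grad.norm().item()   # last Linear layer
print(f"first-layer grad norm: {first:.2e}, last-layer grad norm: {last:.2e}")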
Modern Tweaks to Backpropagation
Over the years, practitioners have developed techniques to make backpropagation more robust:
Gradient Monitoring & Clipping
Prevent exploding gradients by capping their magnitude
Gradient Monitoring & Clipping
import time
from collections import defaultdict
from typing import Iterable
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from data_utils import get_data
from net import PytorchNet as Net
from torch_utils import (
AverageMeter,
Logger,
accuracy,
)
def tensor_norm(tensor: torch.Tensor, p: int = 2):
return torch.sqrt(tensor.pow(p).sum())
def compute_total_norm(params: Iterable[torch.Tensor], norm_type: float = 2):
return torch.norm(
torch.stack(
[
torch.norm(p.grad.detach(), norm_type)
for p in params
if p.grad is not None
]
),
p=norm_type,
)
def clip_grad_norm_(params: Iterable[torch.Tensor], max_norm: float):
"""
Clips gradient norm of parameters.
The norm is computed over all gradients together, as if they were
concatenated into a single vector. Gradients are modified in-place.
Args:
params: List of parameters with gradients
max_norm: Max norm of the gradients
Returns:
Total norm of the parameters (viewed as a single vector).
"""
# First compute total norm of gradients
if isinstance(params, torch.Tensor):
params = [params]
params = [p for p in params if p.grad is not None]
if not params:
return torch.tensor(0.0)
total_norm = compute_total_norm(params)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1: # Only clip if total_norm > max_grad_norm
for param in params:
if param.grad is not None:
param.grad.data.mul_(clip_coef.to(param.grad.device))
return total_norm
class GradientMonitor:
def __init__(self):
self.grad_stats = defaultdict(
lambda: {"mean": [], "std": [], "norm": [], "histogram": []}
)
self.param_stats = defaultdict(lambda: {"mean": [], "std": [], "norm": []})
self.total_norm_stats = {"before_clip": [], "after_clip": []}
def update(self, model: nn.Module):
total_norm_before = compute_total_norm(model.parameters())
self.total_norm_stats["before_clip"].append(total_norm_before)
for name, param in model.named_parameters():
if param.grad is not None:
grad = param.grad
self.grad_stats[name]["mean"].append(grad.mean().item())
self.grad_stats[name]["std"].append(grad.std().item())
self.grad_stats[name]["norm"].append(grad.norm().item())
self.grad_stats[name]["histogram"].append(
grad.cpu().histogram(bins=30).hist
)
self.param_stats[name]["mean"].append(param.data.mean().item())
self.param_stats[name]["std"].append(param.data.std().item())
self.param_stats[name]["norm"].append(param.data.norm().item())
def update_after_clip(self, model: nn.Module):
# Compute total norm after clipping
total_norm_after = compute_total_norm(model.parameters())
self.total_norm_stats["after_clip"].append(total_norm_after)
def log_stats(self, logger: Logger, global_step: int, log_interval):
for name in self.grad_stats.keys():
logger.log(
{
f"grad_{name}_mean": np.mean(
self.grad_stats[name]["mean"][-log_interval:]
),
f"grad_{name}_std": np.mean(
self.grad_stats[name]["std"][-log_interval:]
),
f"grad_{name}_norm": np.mean(
self.grad_stats[name]["norm"][-log_interval:]
),
f"param_{name}_mean": np.mean(
self.param_stats[name]["mean"][-log_interval:]
),
f"param_{name}_std": np.mean(
self.param_stats[name]["std"][-log_interval:]
),
f"param_{name}_norm": np.mean(
self.param_stats[name]["norm"][-log_interval:]
),
},
global_step,
)
logger.log(
{
"total_grad_norm_before_clip": np.mean(
self.total_norm_stats["before_clip"][-log_interval:]
),
"total_grad_norm_after_clip": np.mean(
self.total_norm_stats["after_clip"][-log_interval:]
),
},
global_step,
)
def visualize_stats(self):
self._visualize_stats(self.grad_stats, "gradient")
self._visualize_stats(self.param_stats, "parameter")
self._visualize_grad_histograms()
def visualize_total_norm(self):
plt.figure(figsize=(12, 6))
plt.plot(self.total_norm_stats["before_clip"], label="Before Clip")
plt.plot(self.total_norm_stats["after_clip"], label="After Clip")
plt.xlabel("Steps")
plt.ylabel("Total Gradient Norm")
plt.title("Total Gradient Norm Before and After Clipping")
plt.legend()
plt.tight_layout()
plt.savefig("total_gradient_norm.png")
plt.close()
def _visualize_stats(self, stats, stat_type):
fig, axs = plt.subplots(3, 1, figsize=(12, 18))
for name, data in stats.items():
axs[0].plot(data["mean"], label=f"{name}_mean")
axs[1].plot(data["std"], label=f"{name}_std")
axs[2].plot(data["norm"], label=f"{name}_norm")
axs[0].set_title(f"{stat_type.capitalize()} Mean")
axs[1].set_title(f"{stat_type.capitalize()} Std")
axs[2].set_title(f"{stat_type.capitalize()} Norm")
for ax in axs:
ax.legend()
ax.set_xlabel("Steps")
ax.set_ylabel("Value")
plt.tight_layout()
plt.savefig(f"{stat_type}_statistics.png")
plt.close()
def _visualize_grad_histograms(self):
num_layers = len(self.grad_stats)
fig, axs = plt.subplots(num_layers, 1, figsize=(12, 4 * num_layers))
for idx, (name, data) in enumerate(self.grad_stats.items()):
histograms = np.array(data["histogram"])
im = axs[idx].imshow(histograms.T, aspect="auto", cmap="viridis")
axs[idx].set_title(f"Gradient Histogram: {name}")
axs[idx].set_xlabel("Steps")
axs[idx].set_ylabel("Bins")
plt.colorbar(im, ax=axs[idx])
plt.tight_layout()
plt.savefig("gradient_histograms.png")
plt.close()
def train(epochs=20, batch_size=512, lr=0.01, log_interval=10, val_interval=1):
data_iter = get_data(batchsize=batch_size, debug=False)
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=lr)
logger = Logger()
gradient_monitor = GradientMonitor()
start_time = time.time()
global_step = 0
for epoch in range(epochs):
loss_meter = AverageMeter()
acc_meter = AverageMeter()
batch_time = AverageMeter()
for batch_idx, (x, y) in enumerate(data_iter("train")):
batch_start = time.time()
out = net(x)
loss = criterion(out, y)
optimizer.zero_grad()
loss.backward()
gradient_monitor.update(net)
# Gradient clipping (optional, but good to show in interviews)
clip_grad_norm_(net.parameters(), max_norm=1.0)
gradient_monitor.update_after_clip(net)
optimizer.step()
loss_meter.update(loss.item())
acc_meter.update(accuracy(out, y))
batch_time.update(time.time() - batch_start)
global_step += 1
if (batch_idx + 1) % log_interval == 0:
logger.log(
{
"train_loss": loss_meter.avg,
"train_acc": acc_meter.avg,
"batch_time": batch_time.avg,
},
global_step,
)
# gradient_monitor.log_stats(logger, global_step, log_interval)
# Validation
if (epoch + 1) % val_interval == 0:
val_loss_meter = AverageMeter()
val_acc_meter = AverageMeter()
val_batch_time = AverageMeter()
with torch.no_grad():
for x, y in data_iter("test"):
batch_start = time.time()
out = net(x)
loss = criterion(out, y)
val_loss_meter.update(loss.item())
val_acc_meter.update(accuracy(out, y))
val_batch_time.update(time.time() - batch_start)
logger.log(
{
"val_loss": val_loss_meter.avg,
"val_acc": val_acc_meter.avg,
"val_batch_time": val_batch_time.avg,
},
global_step,
)
end_time = time.time()
print(f"Training completed in {end_time - start_time:.2f} seconds")
gradient_monitor.visualize_stats()
gradient_monitor.visualize_total_norm()
train(epochs=2)
- Weight Initialization: Properly initialize weights to avoid unstable gradients (e.g., Xavier or He initialization); a short sketch follows this list.
- Normalization: Apply BatchNorm or LayerNorm to stabilize gradient flow.
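As a sketch of the weight-initialization point, here is He (Kaiming) initialization applied to an MLP like the one above; the helper name `init_he_` is ours, not part of the training code in this post:
import torch.nn as nn

def init_he_(module: nn.Module) -> None:
    """Apply He initialization to Linear layers (well-suited to ReLU networks)."""
    if isinstance(module, nn.Linear):
        nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
        nn.init.zeros_(module.bias)

net = nn.Sequential(nn.Linear(784, 1024), nn.ReLU(), nn.Linear(1024, 10))
net.apply(init_he_)  # recursively applies init_he_ to every submodule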
Gradient accumulation
Gradient Accumulation
import time
from contextlib import contextmanager
import torch
from configs import TrainingConfig
from data_utils import get_data
from layers import Optimizer
from losses import CrossEntropyLoss
from net import Net
from optimizers import get_optimizer
from torch_utils import AverageMeter, Logger, accuracy, clip_grad_norm_
class GradientAccumulator:
def __init__(
self,
model: Net,
optimizer: Optimizer,
criterion: CrossEntropyLoss,
config: TrainingConfig,
):
self.model = model
self.optimizer = optimizer
self.criterion = criterion
self.config = config
# Track accumulation state
self.current_step = 0
self.loss_accumulator = 0.0
# Scaling factors for loss
self.accumulation_scale = 1.0 / self.config.accumulation_steps
# Monitoring metrics
self.max_gradient_seen = 0.0
self.min_gradient_seen = float("inf")
self.gradient_norms = []
@property
def accumulation_steps(self):
return self.config.accumulation_steps
@contextmanager
def accumulation_context(self, global_step: int | None = None):
curr_step = global_step or self.current_step
yield (curr_step + 1) % self.config.accumulation_steps == 0
self.current_step += 1
def backward(self, output: torch.Tensor, target: torch.Tensor):
"""
Perform backward pass with proper scaling
The loss is scaled by 1/accumulation_steps to maintain
equivalent gradients when accumulating over multiple steps.
"""
# First compute the raw loss
raw_loss = self.criterion(output, target)
# Scale the loss
scaled_loss = raw_loss * self.accumulation_scale
self.loss_accumulator += raw_loss.item() # Store unscaled loss for logging
# Get gradients from criterion with scaling
dout = self.criterion.backward() * self.accumulation_scale
self.model.backward(dout)
return scaled_loss
def step(self):
if (self.current_step + 1) % self.accumulation_steps == 0:
if self.config.max_grad_norm > 0:
self._clip_gradients()
self.optimizer.step()
self.optimizer.zero_grad()
# Get average loss over accumulation steps
avg_loss = self.loss_accumulator / self.config.accumulation_steps
self.loss_accumulator = 0.0
return avg_loss
def _clip_gradients(self):
"""Clip gradients and collect statistics"""
# Calculate gradient norm
grad_norm = clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
# Update statistics
self.max_gradient_seen = max(self.max_gradient_seen, grad_norm) # type: ignore
self.min_gradient_seen = min(self.min_gradient_seen, grad_norm) # type: ignore
self.gradient_norms.append(grad_norm)
return grad_norm
def train(config: TrainingConfig):
data_iter = get_data(batchsize=config.batch_size, debug=False)
net = Net()
criterion = CrossEntropyLoss()
optimizer = get_optimizer("rmsprop", net.parameters(), lr=config.lr)
logger = Logger()
accumulator = GradientAccumulator(net, optimizer, criterion, config)
start_time = time.time()
global_step = 0
for epoch in range(config.epochs):
loss_meter = AverageMeter()
acc_meter = AverageMeter()
batch_time = AverageMeter()
for batch_idx, (x, y) in enumerate(data_iter("train")):
batch_start = time.time()
with accumulator.accumulation_context() as is_last_step:
out = net(x)
loss = accumulator.backward(out, y)
if is_last_step:
accumulator.step()
loss_meter.update(loss.item())
acc_meter.update(accuracy(out, y))
batch_time.update(time.time() - batch_start)
global_step += 1
if (batch_idx + 1) % config.log_interval == 0:
logger.log(
{
"train_loss": loss_meter.avg,
"train_acc": acc_meter.avg,
"batch_time": batch_time.avg,
},
global_step,
)
# Validation
if (epoch + 1) % config.val_interval == 0:
val_loss_meter = AverageMeter()
val_acc_meter = AverageMeter()
val_batch_time = AverageMeter()
with torch.no_grad():
for x, y in data_iter("test"):
batch_start = time.time()
out = net(x)
loss = criterion(out, y)
val_loss_meter.update(loss.item())
val_acc_meter.update(accuracy(out, y))
val_batch_time.update(time.time() - batch_start)
logger.log(
{
"val_loss": val_loss_meter.avg,
"val_acc": val_acc_meter.avg,
"val_batch_time": val_batch_time.avg,
},
global_step,
)
end_time = time.time()
print(f"Training completed in {end_time - start_time:.2f} seconds")
def test_gradient_accumulation():
# Create small test data
batch_size = 4
input_dim = 784
num_classes = 10
# Create model and accumulator
net = Net()
criterion = CrossEntropyLoss()
optimizer = get_optimizer("rmsprop", net.parameters(), lr=0.01)
config = TrainingConfig(accumulation_steps=4)
accumulator = GradientAccumulator(net, optimizer, criterion, config)
# Run forward/backward with accumulation
for i in range(8): # 2 complete accumulation cycles
x = torch.randn(batch_size, input_dim)
y = torch.randint(0, num_classes, (batch_size,))
with accumulator.accumulation_context() as is_last_step:
out = net(x)
loss = accumulator.backward(out, y)
if is_last_step:
# Check gradients before and after clipping
params = net.parameters()
grad_norm_before = sum(p.grad.norm().item() ** 2 for p in params) ** 0.5 # type: ignore
accumulator._clip_gradients()
grad_norm_after = sum(p.grad.norm().item() ** 2 for p in params) ** 0.5 # type: ignore
print(f"Step {i+1}:")
print(f" Grad norm before clipping: {grad_norm_before:.4f}")
print(f" Grad norm after clipping: {grad_norm_after:.4f}")
accumulator.step()
return net
test_gradient_accumulation()
train(TrainingConfig(accumulation_steps=4))
Why Revisit Backpropagation?
Understanding backprop deeply allows us to:
- Debug and fine-tune models effectively.
- Innovate on optimization techniques and architectures.
- Appreciate the mathematical and computational elegance that powers modern AI.
By building our own tools and revisiting foundational concepts, we bridge the gap between theory and practice.
Conclusion
“Back to Backprop” isn’t just a nod to the algorithm’s history—it’s a call to master the fundamentals that underpin our most advanced systems. Whether you’re troubleshooting a vanishing gradient, designing a new architecture, or simply exploring the beauty of gradient-based learning, backpropagation remains a cornerstone of machine learning.
Happy coding, and here’s to gradients that flow smoothly!