
Back to Backprop

Published: at 06:42 AM (29 min read)


In the rapidly advancing field of deep learning, it’s easy to get swept up in the excitement of state-of-the-art architectures and massive datasets. Yet, beneath the surface of every breakthrough is the workhorse of gradient-based optimization: backpropagation. For seasoned practitioners, revisiting this foundational algorithm isn’t just a nostalgic exercise—it’s an opportunity to sharpen our intuition and improve our practice.

This post explores backpropagation from a practitioner’s lens. We’ll begin with a concise refresher, implementing the forward and backward passes for foundational layer types. We’ll then demonstrate a full training loop for a simple neural network, using our own custom implementation and mini-autograd engine. Along the way, we’ll dive into real-world challenges and explore modern tweaks. Finally, we’ll build a simplified PyTorch-like tensor and autograd engine.

Here are some quick links to some awesome resources on backpropagation:

From Derivatives to Gradients and Jacobians

For completeness’ sake, I think it is worth restating some of the basic definitions and concepts that underpin backpropagation.

Derivatives

Derivatives are a way to measure change in the scalar case:

Derivative Definition

Given a function $f: \mathbb{R} \rightarrow \mathbb{R}$, the derivative of $f$ at a point $x \in \mathbb{R}$ is defined as:

$$f'(x) = \lim_{h \to 0} \frac{f(x + h) - f(x)}{h}$$

Put differently, they tell us how much a function $f$ changes as the input $x$ changes by a small amount $\epsilon$: $f(x + \epsilon) \approx f(x) + \epsilon f'(x)$.
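To make the definition concrete, here is a small plain-Python sketch. The helper `derivative` is hypothetical (not from the post): it plugs a small but finite h into the limit above, and the script then checks the first-order approximation f(x + eps) ≈ f(x) + eps * f'(x) for f = sin.

```python
import math


def derivative(f, x, h=1e-6):
    """Approximate f'(x) by plugging a small, finite h into the limit definition."""
    return (f(x + h) - f(x)) / h


x = 1.0
eps = 1e-3
fprime = derivative(math.sin, x)        # close to cos(1)
linear = math.sin(x) + eps * fprime     # first-order (tangent-line) prediction
actual = math.sin(x + eps)
print(fprime, abs(actual - linear))     # the prediction error shrinks like eps**2
```

Halving eps roughly quarters the printed error, which is exactly what "first-order approximation" means.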

The Chain Rule

The chain rule tells us how to compute the derivative of a composition of functions. Suppose we have functions $f, g: \mathbb{R} \rightarrow \mathbb{R}$ and $h = g \circ f$. In other words, $y = f(x)$ and $z = g(y) = h(x)$. Then, the chain rule states:

$$\frac{\partial z}{\partial x} = \frac{\partial z}{\partial y} \frac{\partial y}{\partial x}$$

In modern neural networks, it turns out that backpropagation is essentially the chain rule applied to a composition of vector/tensor-valued functions.
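A quick numerical check of the chain rule, as a plain-Python sketch with the illustrative choices f(x) = x² and g(y) = sin(y): the product of the two local derivatives should match a finite-difference derivative of the whole composition.

```python
import math


def f(x):  # inner function: y = f(x) = x^2
    return x * x


def g(y):  # outer function: z = g(y) = sin(y)
    return math.sin(y)


x = 0.7
y = f(x)
dz_dy = math.cos(y)    # local derivative of the outer function, evaluated at y
dy_dx = 2 * x          # local derivative of the inner function, evaluated at x
dz_dx = dz_dy * dy_dx  # chain rule: multiply the local derivatives

# Cross-check against a central finite difference of the composition z = g(f(x))
h = 1e-6
numeric = (g(f(x + h)) - g(f(x - h))) / (2 * h)
print(dz_dx, numeric)
```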

Gradients: Vector in -> Scalar out

Gradients are the generalization of derivatives to the multivariate case, where we have a scalar-valued function of multiple variables. The gradient of a function $f: \mathbb{R}^N \rightarrow \mathbb{R}$ is a vector of partial derivatives, one for each input variable.

Gradient Definition

Given a function $f: \mathbb{R}^N \rightarrow \mathbb{R}$, the gradient of $f$ at a point $\mathbf{x} \in \mathbb{R}^N$ is the vector $\nabla_{\mathbf{x}} f(\mathbf{x}) \in \mathbb{R}^N$ that satisfies:

$$\lim_{\mathbf{h} \to \mathbf{0}} \frac{f(\mathbf{x} + \mathbf{h}) - f(\mathbf{x}) - \nabla_{\mathbf{x}} f(\mathbf{x})^T \mathbf{h}}{\|\mathbf{h}\|} = 0$$

It can also be viewed as a vector of partial derivatives:

$$\nabla_{\mathbf{x}} f(\mathbf{x}) = \left[ \frac{\partial f(\mathbf{x})}{\partial x_1}, \frac{\partial f(\mathbf{x})}{\partial x_2}, \ldots, \frac{\partial f(\mathbf{x})}{\partial x_N} \right] = \left[ \frac{\partial f(\mathbf{x})}{\partial x_i} \right]_{i=1}^N$$

When we consider the relationship

$$\mathbf{x} \rightarrow \mathbf{x} + \Delta \mathbf{x} \Rightarrow f(\mathbf{x} + \Delta \mathbf{x}) \approx f(\mathbf{x}) + \nabla_{\mathbf{x}} f(\mathbf{x}) \cdot \Delta \mathbf{x}$$

it is important to note that $\mathbf{x}$, $\Delta \mathbf{x}$, and $\nabla_{\mathbf{x}} f(\mathbf{x})$ are now all vectors in $\mathbb{R}^N$. The term $\nabla_{\mathbf{x}} f(\mathbf{x}) \cdot \Delta \mathbf{x}$ is a dot product, producing a scalar, which matches the scalar output of $f$.
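The same ideas can be checked numerically in N dimensions. This NumPy sketch (the functions are illustrative, and NumPy stands in for PyTorch so it runs standalone) estimates the gradient of a scalar function with central differences, compares it with the analytic gradient, and then verifies the first-order approximation f(x + Δx) ≈ f(x) + ∇f(x)·Δx, whose error is on the order of ‖Δx‖².

```python
import numpy as np


def numerical_gradient(f, x, eps=1e-6):
    """Central-difference estimate of the gradient of a scalar function f: R^N -> R."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = eps
        grad[i] = (f(x + e) - f(x - e)) / (2 * eps)
    return grad


def f(x):
    return np.sum(x ** 2) + np.sin(x[0])  # scalar output


def grad_f(x):
    g = 2 * x
    g[0] += np.cos(x[0])
    return g


x = np.array([0.5, -1.0, 2.0])
dx = np.array([1e-3, 2e-3, -1e-3])

print(np.allclose(numerical_gradient(f, x), grad_f(x), atol=1e-6))
# First-order approximation: the dot product grad_f(x) . dx is a scalar, like f itself
approx = f(x) + grad_f(x) @ dx
print(abs(f(x + dx) - approx))  # small, O(||dx||^2)
```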

Jacobians: Vector in -> Vector out

Now we consider a function $\mathbf{f}: \mathbb{R}^N \rightarrow \mathbb{R}^M$. The Jacobian of $\mathbf{f}$ at a point $\mathbf{x} \in \mathbb{R}^N$ is an $M \times N$ matrix of partial derivatives whose $(i, j)$-th entry is $\boxed{\frac{\partial f_i(\mathbf{x})}{\partial x_j}}$. Written out in other forms:

$$\mathbf{J}_{\mathbf{f}} = \left[ \frac{\partial \mathbf{f}}{\partial x_1}, \frac{\partial \mathbf{f}}{\partial x_2}, \ldots, \frac{\partial \mathbf{f}}{\partial x_N} \right] = \begin{bmatrix} \nabla^T f_1 \\ \vdots \\ \nabla^T f_M \end{bmatrix} = \begin{bmatrix} \frac{\partial f_1}{\partial x_1} & \frac{\partial f_1}{\partial x_2} & \cdots & \frac{\partial f_1}{\partial x_N} \\ \frac{\partial f_2}{\partial x_1} & \frac{\partial f_2}{\partial x_2} & \cdots & \frac{\partial f_2}{\partial x_N} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial f_M}{\partial x_1} & \frac{\partial f_M}{\partial x_2} & \cdots & \frac{\partial f_M}{\partial x_N} \end{bmatrix}$$
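A Jacobian can be estimated column by column: perturbing input x_j by ±ε gives column j. This NumPy sketch (the helper numerical_jacobian is hypothetical) does exactly that for f(x) = tanh(Ax) with A of shape 2×3, confirming the M × N shape and the analytic Jacobian, whose i-th row is (1 - tanh²(a_i)) times row i of A.

```python
import numpy as np


def numerical_jacobian(f, x, eps=1e-6):
    """Central-difference Jacobian of f: R^N -> R^M, returned with shape (M, N)."""
    y = f(x)
    J = np.zeros((y.size, x.size))
    for j in range(x.size):  # one column per input coordinate x_j
        e = np.zeros_like(x)
        e[j] = eps
        J[:, j] = (f(x + e) - f(x - e)) / (2 * eps)
    return J


A = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])  # M = 2 outputs, N = 3 inputs


def f(x):
    return np.tanh(A @ x)  # f: R^3 -> R^2


x = np.array([0.1, -0.2, 0.3])
J = numerical_jacobian(f, x)
# Row i of the analytic Jacobian is the (transposed) gradient of f_i
J_exact = (1 - np.tanh(A @ x) ** 2)[:, None] * A
print(J.shape)  # (2, 3), i.e. M x N
print(np.allclose(J, J_exact, atol=1e-6))
```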

Ways to remember the formula:

Generalized Jacobians: Tensor in -> Tensor out

From scalars to vectors (1-D arrays) to matrices (2-D arrays), we can generalize the concept of Jacobians to higher-order tensors, which we can view as $D$-dimensional arrays.

In neural networks, it is very common to think of tensor operations in terms of their “shapes”, i.e., the tuple returned by tensor.shape in NumPy/PyTorch. For example, a tensor of shape $(N, D)$ is a 2-D tensor with $N$ rows and $D$ columns. You will frequently see code where comments track the shape of each tensor, e.g., # [B, N, D] x [B, D, M] -> [B, N, M]. This shorthand has even been formalized in Einstein notation, available in PyTorch via einsum and in the Python package einops.

Now, suppose we have a function $f: \mathbb{R}^{N_1 \times N_2 \times \cdots \times N_{D_x}} \rightarrow \mathbb{R}^{M_1 \times M_2 \times \cdots \times M_{D_y}}$. The Jacobian of $f$ at a point $\mathbf{X} \in \mathbb{R}^{N_1 \times N_2 \times \cdots \times N_{D_x}}$ is a tensor of shape $(M_1, M_2, \ldots, M_{D_y}, N_1, N_2, \ldots, N_{D_x})$, holding one partial derivative for every (output element, input element) pair. Grouping the output dimensions and the input dimensions, this generalized Jacobian can be viewed as a matrix with shape:

$$(M_1 \times M_2 \times \cdots \times M_{D_y}) \times (N_1 \times N_2 \times \cdots \times N_{D_x})$$
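As a concrete instance of the generalized Jacobian, here is a NumPy sketch (NumPy standing in for PyTorch; sizes and seed are illustrative) for Z = XW with X of shape (N, D) and W of shape (D, H). The Jacobian of Z with respect to X has shape (N, H, N, D), and flattening the output and input dimensions gives an (N·H) × (N·D) matrix. Note how sparse it is: only entries whose two batch indices match are nonzero.

```python
import numpy as np

N, D, H = 2, 3, 4
rng = np.random.default_rng(0)
X = rng.standard_normal((N, D))
W = rng.standard_normal((D, H))


def f(X):
    return X @ W  # f: R^{N x D} -> R^{N x H}


# Generalized Jacobian: J[n, h, m, d] = dZ[n, h] / dX[m, d], shape (N, H, N, D)
eps = 1e-6
J = np.zeros((N, H, N, D))
for m in range(N):
    for d in range(D):
        E = np.zeros_like(X)
        E[m, d] = eps
        J[:, :, m, d] = (f(X + E) - f(X - E)) / (2 * eps)

# Closed form: dZ[n, h] / dX[m, d] = [n == m] * W[d, h]
J_exact = np.einsum("nm,dh->nhmd", np.eye(N), W)
print(J.shape)                        # (2, 4, 2, 3)
print(np.allclose(J, J_exact, atol=1e-6))
print(J.reshape(N * H, N * D).shape)  # (8, 6): the matrix view
```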

The Mechanics of Backpropagation

At its core, backpropagation efficiently computes gradients for training neural networks by leveraging the chain rule. Here’s how it works:

  1. Forward Pass: Inputs flow through the network, producing intermediate activations and final outputs.
  2. Backward Pass: Gradients are calculated layer by layer in reverse, propagating partial derivatives back to earlier layers.

Suppose we have a function $z = xW$ as a single node in a computational graph. The output $z$ is passed through additional functions (nodes in the graph) before producing a final scalar loss $L$. The chain rule allows us to compute the gradient of $L$ with respect to $x$ and $W$ by multiplying the gradients of each intermediate function:

$$\frac{\partial L}{\partial x} = \cdots \frac{\partial s}{\partial z}\frac{\partial z}{\partial x} \quad \text{and} \quad \frac{\partial L}{\partial W} = \cdots \frac{\partial s}{\partial z}\frac{\partial z}{\partial W}$$

where $s$ is the next node that consumes $z$, and $\cdots \frac{\partial s}{\partial z}$ represents the chain of gradients from the loss $L$ back to the output $z$, known as the upstream gradient. The goal is to compute the downstream gradients:

[downstream gradient] = [upstream gradient] * [local gradient]

Backprop by hand

Let’s walk through a simple example to illustrate backpropagation, using a neural network with one hidden layer. Let’s assume vectorized/batched computation starting with input $x \in \mathbb{R}^{N \times D}$, where $N$ is the number of samples and $D$ is the input dimension.

$$
\begin{align*}
x &= \text{input} \in \mathbb{R}^{N \times D} \\
h &= xW + b \in \mathbb{R}^{N \times H} \\
a &= \text{ReLU}(h) \\
z &= aU + b_2 \\
\hat{y} &= \text{softmax}(z) \\
L &= \text{cross-entropy}(\hat{y}, y)
\end{align*}
$$
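The same network can be run end to end as a compact NumPy sketch (NumPy stands in for PyTorch here; the sizes, seed, and numerical_grad helper are illustrative, not from the post). It performs the forward pass above, backpropagates one node at a time with the layer gradients derived in this section, and checks every parameter gradient against central finite differences.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, H, C = 4, 5, 6, 3                  # batch size, input dim, hidden dim, classes
x = rng.standard_normal((N, D))
y = rng.integers(0, C, size=N)           # integer class labels
W, b = rng.standard_normal((D, H)), np.zeros(H)
U, b2 = rng.standard_normal((H, C)), np.zeros(C)


def forward(W, b, U, b2):
    h = x @ W + b                        # [N, H]
    a = np.maximum(h, 0)                 # ReLU, [N, H]
    z = a @ U + b2                       # logits, [N, C]
    z = z - z.max(axis=1, keepdims=True) # shift for numerical stability
    y_hat = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)  # softmax, [N, C]
    L = -np.log(y_hat[np.arange(N), y]).mean()                # mean cross-entropy
    return L, h, a, y_hat


L, h, a, y_hat = forward(W, b, U, b2)

# Backward pass, one node at a time: [downstream] = [upstream] x [local]
dz = y_hat.copy()
dz[np.arange(N), y] -= 1
dz /= N                                  # dL/dz = (y_hat - onehot(y)) / N, [N, C]
dU = a.T @ dz                            # [H, C]
db2 = dz.sum(axis=0)                     # [C]
da = dz @ U.T                            # [N, H]
dh = da * (h > 0)                        # ReLU passes gradient only where h > 0
dW = x.T @ dh                            # [D, H]
db = dh.sum(axis=0)                      # [H]


def numerical_grad(P, eps=1e-6):
    """Central differences of L w.r.t. array P (one of W, b, U, b2), perturbed in place."""
    G = np.zeros_like(P)
    for idx in np.ndindex(P.shape):
        old = P[idx]
        P[idx] = old + eps
        L_plus = forward(W, b, U, b2)[0]
        P[idx] = old - eps
        L_minus = forward(W, b, U, b2)[0]
        P[idx] = old                     # restore
        G[idx] = (L_plus - L_minus) / (2 * eps)
    return G


for name, analytic, P in [("W", dW, W), ("b", db, b), ("U", dU, U), ("b2", db2, b2)]:
    print(name, np.abs(analytic - numerical_grad(P)).max())  # each should be tiny
```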

Linear Layer:

A linear layer consists of weights $W \in \mathbb{R}^{D \times H}$ and a bias $b \in \mathbb{R}^{H}$.

$$\frac{\partial}{\partial x}(xW + b) = W^T$$
$$\frac{\partial}{\partial W}(xW + b) = x^T$$
$$\frac{\partial}{\partial b}(xW + b) = I$$
Derivation: Linear Layer Gradients

We have $z = xW$.

Shapes:

  • $x \in \mathbb{R}^{N \times D}$
  • $W \in \mathbb{R}^{D \times H}$
  • $z \in \mathbb{R}^{N \times H}$

For simplicity, let’s consider just the first row of $x$ and $z$, i.e., $N=1$ (we’ll drop the row subscript). Then we have $x \in \mathbb{R}^{D}$ and $z \in \mathbb{R}^{H}$.

$$z_i = \sum_{k=1}^{D} x_k W_{ki}$$

Note: counter to normal convention, we use $i$ to index the columns of $W$ so that we can use $i$ to index into the output $z$.

What is $\frac{\partial z}{\partial x}$?

$$\left(\frac{\partial z}{\partial x}\right)_{ij} = \frac{\partial z_i}{\partial x_j} = \frac{\partial}{\partial x_j} \sum_{k=1}^{D} x_k W_{ki} = \sum_{k=1}^{D} \frac{\partial x_k}{\partial x_j} W_{ki}$$

Notice that $\frac{\partial x_k}{\partial x_j} = \delta_{kj}$, or $\mathbb{1}[k = j]$. In other words, it is 1 if $k=j$ and 0 otherwise. Thus,

$$\left(\frac{\partial z}{\partial x}\right)_{ij} = \sum_{k=1}^{D} \mathbb{1}[k = j] \cdot W_{ki} = W_{ji}$$

Since the $(i, j)$-th entry of the Jacobian is $W_{ji}$, the Jacobian is the transpose of $W$: $\boxed{\frac{\partial z}{\partial x} = W^T}$

For the general case, the same pattern holds. See Computing the Jacobian of a Matrix Product for more details.
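To see that the same pattern holds for a full batch, and why layers never build these Jacobians explicitly, the NumPy sketch below (illustrative sizes and names) contracts a random upstream gradient with the full generalized Jacobians of z = xW and compares the results with the closed forms dout @ W.T and x.T @ dout used in the Linear backward pass. The Jacobian with respect to x alone has N·H·N·D entries, while dout @ W.T only touches the two small matrices.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, H = 3, 4, 5
x = rng.standard_normal((N, D))
W = rng.standard_normal((D, H))
dout = rng.standard_normal((N, H))  # upstream gradient dL/dz, same shape as z = x @ W

# Full generalized Jacobians of z = x @ W
J_x = np.einsum("nm,dh->nhmd", np.eye(N), W)  # dz[n,h]/dx[m,d] = [n==m] * W[d,h], shape [N, H, N, D]
J_W = np.einsum("nd,hk->nhdk", x, np.eye(H))  # dz[n,h]/dW[d,k] = x[n,d] * [h==k], shape [N, H, D, H]

# Downstream = upstream contracted with the local Jacobian over z's indices (n, h)
dx_full = np.einsum("nh,nhmd->md", dout, J_x)
dW_full = np.einsum("nh,nhdk->dk", dout, J_W)

# The same results without ever materialising a Jacobian
print(np.allclose(dx_full, dout @ W.T))  # True
print(np.allclose(dW_full, x.T @ dout))  # True
```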

In code:

import torch


class Linear:
    def __init__(self, in_features: int, out_features: int):
        # Uniform init in [-1/sqrt(in_features), 1/sqrt(in_features)], the same range
        # PyTorch's nn.Linear effectively uses for its weight and bias by default.
        bound = in_features**-0.5
        self.weight = torch.empty(in_features, out_features).uniform_(-bound, bound)
        self.bias = torch.empty(out_features).uniform_(-bound, bound)
        self.x: torch.Tensor | None = None  # cached input for the backward pass

    def parameters(self):
        return [self.weight, self.bias]

    def forward(self, x: torch.Tensor):
        """Computes the forward pass for a linear layer.

        Args:
            x (Tensor): Input data, of shape (N, in_features) where N is the batch size and
                in_features is the number of input features.
        """
        self.x = x
        return x @ self.weight + self.bias  # @ is torch.matmul

    def backward(self, dout: torch.Tensor):
        """Computes the backward pass for a linear layer.

        Args:
            dout (Tensor): Upstream derivative, of shape (N, out_features).

        Notes:
        ------------
        z = xW + b
        dz/dx = W.T
        dz/dW = x.T
        dz/db = I_{out}

        chain rule
        dL/dx = dout @ dz/dx
              = dout @ W.T   [N, out] x [out, in] -> [N, in]
        dL/dW = dz/dW @ dout (order swapped to make the shapes work out)
              = x.T @ dout   [in, N] x [N, out] -> [in, out]
        dL/db = dout summed over the batch
              = dout.sum(0)  [N, out] -> [out]
        """
        if self.x is None:
            raise ValueError("No cache found. Run forward pass first.")
        dx = dout.mm(self.weight.T)  # [N, in_features]
        self.weight.grad = self.x.T.mm(dout)  # [in_features, out_features]
        self.bias.grad = dout.sum(dim=0)  # [out_features]
        return dx

    __call__ = forward

ReLU Activation:

To start with, recall that ReLU(x) = max(x, 0). This means that

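$$\text{ReLU}'(x) = 1 \text{ if } x > 0 \text{ else } 0 = \mathbb{1}[x > 0]$$

which is also $\text{sgn}(\text{ReLU}(x))$. (Strictly, ReLU is not differentiable at $x = 0$; the usual convention, also used in the code below, is to take the derivative there to be 0.)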

import torch


def relu(x: torch.Tensor) -> torch.Tensor:
    """Computes max(0, x) element-wise."""
    return x.clamp(min=0)


class ReLU:
    def parameters(self):
        return []

    def forward(self, x: torch.Tensor):
        """Computes max(0, x) element-wise."""
        self.x = x  # cache the input: the backward pass needs the sign of x
        return relu(x)

    def backward(self, dout: torch.Tensor):
        """Computes the backward pass for a layer of rectified linear units (ReLUs).

        ReLU(x) = max(0, x)
        ReLU'(x) = 1 if x > 0 else 0
        out = ReLU(x)
        dL/dx = dout * ReLU'(x)
              = dout * 1[x > 0]   (i.e., dout with entries where x <= 0 set to 0)

        Args:
            dout (Tensor): Upstream derivative, of shape (N, M).

        """
        return dout * (self.x > 0).to(dout.dtype)

    __call__ = forward

Tanh Activation:

To start with, recall that $\text{Tanh}(x) = \frac{e^{2x} - 1}{e^{2x} + 1}$. This means that the derivative is:

$$\text{Tanh}'(x) = 1 - \text{Tanh}^2(x)$$
class TanhLayer:
    def forward(self, x):
        self.output = torch.tanh(x)
        return self.output

    def backward(self, dout: torch.Tensor):
        return dout * (1 - self.output**2)

    def parameters(self):
        return []

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
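As a quick sanity check on the two formulas above, this plain-Python sketch (the helper tanh_from_exp is hypothetical) confirms that the exponential form matches math.tanh and that 1 - Tanh(x)² agrees with a central finite difference. The derivative only needs the forward output, which is why TanhLayer caches self.output rather than its input.

```python
import math


def tanh_from_exp(x):
    # Tanh(x) = (e^{2x} - 1) / (e^{2x} + 1); fine for moderate x, but e^{2x} overflows for large x
    return (math.exp(2 * x) - 1) / (math.exp(2 * x) + 1)


h = 1e-6
max_formula_err = 0.0
max_deriv_err = 0.0
for x in [-2.0, -0.5, 0.0, 0.5, 2.0]:
    t = tanh_from_exp(x)
    analytic = 1 - t ** 2  # Tanh'(x), computed from the forward output t only
    numeric = (tanh_from_exp(x + h) - tanh_from_exp(x - h)) / (2 * h)
    max_formula_err = max(max_formula_err, abs(t - math.tanh(x)))
    max_deriv_err = max(max_deriv_err, abs(analytic - numeric))

print(max_formula_err, max_deriv_err)  # both should be tiny
```

In practice torch.tanh (as TanhLayer uses) is preferable to the explicit exponential form, precisely because of the overflow noted in the comment.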

Softmax, LogSoftmax and LogSumExp

These functions commonly arise in neural networks, especially in the output layer. They occur when dealing with categorical and multinomial probability distributions (converting raw logits into a probability distribution over classes), as well as in the attention mechanisms of transformers. These functions are all closely related, involving sums of exponentials, and have nice relationships with each other.

Definitions

Given a vector $\mathbf{x} \in \mathbb{R}^n$, we define the following functions:

$$\mathbf{s} = \text{softmax}(\mathbf{x}) := \frac{\exp(\mathbf{x})}{\sum_{i} \exp(x_i)}$$
$$\frac{\partial s_i}{\partial x_j} = s_i(\delta_{ij} - s_j) \text{ or } \mathbf{J}_{\mathbf{s}} = \text{diag}(\mathbf{s}) - \mathbf{s} \mathbf{s}^T$$
$$\ell = \text{logsumexp}(\mathbf{x}) := \log \left( \sum_{i} \exp(x_i) \right)$$
$$\frac{\partial \ell}{\partial x_j} = s_j \text{ or } \nabla_{\mathbf{x}} \ell = \text{softmax}(\mathbf{x}) \text{ (equivalently, } \mathbf{J}_{\ell} = \text{softmax}(\mathbf{x})^T \text{)}$$
$$\tilde{\mathbf{s}} = \text{logsoftmax}(\mathbf{x}) := \mathbf{x} - \text{logsumexp}(\mathbf{x})$$
$$\frac{\partial \tilde{s}_i}{\partial x_j} = \delta_{ij} - s_j \text{ or } \mathbf{J}_{\text{logsoftmax}} = \mathbf{I} - \mathbf{1} \cdot \text{softmax}(\mathbf{x})^T$$

So we see that the softmax and logsoftmax are both functions that take in a vector and return a vector, while the logsumexp is a scalar valued function. The logsoftmax is just the log of the softmax, which also happens to be the input minus its logsumexp.

And in code:

import torch


def softmax(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Compute softmax values of z along `dim`.

    Args:
        z (Tensor): input tensor
        dim (int): dimension to normalize over (the class dimension)

    Notes:
        The definition of the softmax:
            softmax(z) = exp(z) / sum(exp(z))
        can be numerically unstable (exp overflows for large z), so we use:
            softmax(z) = exp(z - max(z)) / sum(exp(z - max(z)))
        which multiplies the numerator and denominator by the same constant
        exp(-max(z)). Every exponent is then <= 0, so exp stays in (0, 1]
        and cannot overflow.
    """
    e_z = torch.exp(z - z.max(dim=dim, keepdim=True).values)
    return e_z / e_z.sum(dim=dim, keepdim=True)

def logsumexp(s: torch.Tensor, dim=-1):
    """Numerically stable log(sum(exp(s)))"""
    max_val = s.max(dim=dim, keepdim=True)[0]
    return torch.log(torch.sum(torch.exp(s - max_val), dim=dim, keepdim=True)) + max_val


def log_softmax(z: torch.Tensor, dim=-1):
    """Return logprobs"""
    return z - logsumexp(z, dim=dim)

class SoftmaxLayer:
    def forward(self, x):
        # Save input for backward pass
        self.input = x

        # For numerical stability, subtract max value before exponential
        # Shape: (batch_size, num_classes)
        shifted_x = x - torch.max(x, dim=1, keepdim=True)[0]

        # Compute exponentials
        # Shape: (batch_size, num_classes)
        exp_x = torch.exp(shifted_x)

        # Compute sum of exponentials for normalization
        # Shape: (batch_size, 1)
        sum_exp = torch.sum(exp_x, dim=1, keepdim=True)

        # Compute softmax output
        # Shape: (batch_size, num_classes)
        self.output = exp_x / sum_exp

        return self.output

    def backward(self, grad_output):
        # grad_output shape: (batch_size, num_classes)
        # output shape: (batch_size, num_classes)

        # The Jacobian of softmax for each sample is:
        # J[i,j] = output[i] * (δ[i,j] - output[j])
        # where δ[i,j] is 1 if i=j and 0 otherwise

        # Compute using the compact form:
        # grad = output * (grad_output - (output · grad_output))

        # Shape: (batch_size, 1)
        sum_grad = torch.sum(grad_output * self.output, dim=1, keepdim=True) # s · ∇_s L

        # Shape: (batch_size, num_classes)
        return self.output * (grad_output - sum_grad) # s ⊙ (∇_s L - (s · ∇_s L))

    def parameters(self):
        return []

    def __call__(self, x):
        return self.forward(x)
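The Jacobian formulas in the definitions above are easy to get subtly wrong, so here is a NumPy sketch (standalone re-implementations for illustration, not the torch functions above) that checks each one against a central-difference Jacobian. It also checks the compact vector-Jacobian product used in SoftmaxLayer.backward, s * (g - s·g), against an explicit gᵀJ.

```python
import numpy as np


def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


def logsumexp(x):
    m = x.max()
    return m + np.log(np.exp(x - m).sum())


def log_softmax(x):
    return x - logsumexp(x)


def numerical_jacobian(f, x, eps=1e-6):
    """Central-difference Jacobian with shape (output size, input size)."""
    def out(v):
        return np.atleast_1d(f(v))  # treat scalar outputs as length-1 vectors

    J = np.zeros((out(x).size, x.size))
    for j in range(x.size):
        e = np.zeros_like(x)
        e[j] = eps
        J[:, j] = (out(x + e) - out(x - e)) / (2 * eps)
    return J


x = np.array([1.0, -0.5, 2.0, 0.3])
s = softmax(x)
n = x.size
J_softmax = np.diag(s) - np.outer(s, s)

checks = {
    # J_softmax = diag(s) - s s^T
    "softmax": np.allclose(numerical_jacobian(softmax, x), J_softmax, atol=1e-6),
    # gradient of logsumexp is softmax(x); as a 1 x n Jacobian: s^T
    "logsumexp": np.allclose(numerical_jacobian(logsumexp, x), s[None, :], atol=1e-6),
    # J_logsoftmax = I - 1 s^T
    "log_softmax": np.allclose(
        numerical_jacobian(log_softmax, x), np.eye(n) - np.outer(np.ones(n), s), atol=1e-6
    ),
    # logsoftmax(x) == log(softmax(x))
    "log_of_softmax": np.allclose(log_softmax(x), np.log(s)),
}

# Vector-Jacobian product as in SoftmaxLayer.backward: g^T J = s * (g - s . g)
g = np.array([0.2, -1.0, 0.5, 0.3])
checks["softmax_vjp"] = np.allclose(g @ J_softmax, s * (g - s @ g))
print(checks)  # every value should be True
```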

Sigmoid:

While the softmax function is used for multi-class classification, the sigmoid function is used for binary classification. The primary difference is that the softmax enforces that the elements along a particular dimension sum to 1 (producing a probability distribution), while the sigmoid function independently squashes each output to the range $(0, 1)$, making it suitable for binary classification.

Definition

The sigmoid function, also known as the logistic function, maps any real number to a value between 0 and 1. It’s particularly useful for binary classification tasks since its output can be interpreted as a probability. The sigmoid function is defined as:

$$\sigma(x) = \frac{1}{1 + e^{-x}}$$

Its derivative has a particularly nice form:

$$\sigma'(x) = \sigma(x)(1 - \sigma(x))$$

This makes it efficient to compute during backpropagation, since the backward pass can reuse the output of the forward pass. And in code:

import torch


def sigmoid(x: torch.Tensor) -> torch.Tensor:
    return 1 / (1 + torch.exp(-x))


class SigmoidLayer:
    def forward(self, x: torch.Tensor):
        self.out = sigmoid(x)  # cache the output: the derivative only needs sigma(x)
        return self.out

    def backward(self, dout: torch.Tensor):
        """
        sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))
        """
        return dout * self.out * (1 - self.out)

    def parameters(self):
        return []

    __call__ = forward

Typically, in frameworks like PyTorch, models output logits (unnormalized scores), and it is usually the role of the criterion (e.g., CrossEntropyLoss) to implicitly convert the logits to probabilities using a (log-)softmax.
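To connect the sigmoid back to the softmax, this plain-Python sketch checks two facts: σ'(x) = σ(x)(1 - σ(x)) against a central difference, and that the sigmoid is exactly a two-class softmax with the second logit fixed at 0, i.e. σ(x) = softmax([x, 0])[0]. The helper softmax2 is illustrative.

```python
import math


def sigmoid(x):
    return 1 / (1 + math.exp(-x))


def softmax2(z0, z1):
    """Two-class softmax, returning the probability of class 0."""
    m = max(z0, z1)
    e0, e1 = math.exp(z0 - m), math.exp(z1 - m)
    return e0 / (e0 + e1)


x = 0.8
s = sigmoid(x)
h = 1e-6
numeric = (sigmoid(x + h) - sigmoid(x - h)) / (2 * h)
print(abs(numeric - s * (1 - s)))  # sigma'(x) = sigma(x)(1 - sigma(x))
print(abs(softmax2(x, 0.0) - s))   # sigmoid(x) == softmax([x, 0])[0]
```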

Losses:

Cross Entropy: $H(p, q) = -\sum_{x \in X} p(x) \log q(x)$

Gradient of the cross-entropy loss w.r.t. the logits (i.e., with $\hat{y} = \text{softmax}(z)$ and $L = H(y, \hat{y})$ for a one-hot target $y$, what is $\frac{\partial L}{\partial z}$?):

$$\boxed{\frac{\partial L}{\partial z} = \hat{y} - y}$$
import torch

from activations import log_softmax  # assumes log_softmax lives with the other activations


class CrossEntropyLoss:
    """Computes the forward pass for the cross-entropy loss.

    CE(s, y) = -SUM_i y_i * log(softmax(s)_i)
    With one-hot targets this equals
    torch.mean(-torch.sum(targets * torch.log(torch.softmax(inputs, dim=1)), dim=1))

    Args:
        input (Tensor): Logits, of shape (N, C) where N is the batch size and
            C is the number of classes.
        target (Tensor): Ground truth class indices, of shape (N,).
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        self.input = input
        self.target = target
        indices = torch.arange(target.shape[0], device=input.device)
        return -log_softmax(input, dim=1)[indices, target].mean()

    def backward(self):
        """Compute the backward pass for the cross-entropy loss.

        dL(x, y)/dx = softmax(x) - y
        where x is the unnormalized score, y is the one-hot encoded target label.
        Because forward() averages over the batch, the gradient is also divided by N.

        Returns:
            dx (Tensor): Gradient of the loss with respect to the input x, shape (N, C).
        """
        N = self.target.shape[0]
        dx = torch.softmax(self.input, dim=1)  # a new tensor, safe to modify in place
        indices = torch.arange(N, device=self.input.device)  # integer (long) indices
        dx[indices, self.target] -= 1  # subtract the one-hot target
        dx /= N  # account for the mean in forward()
        return dx

    __call__ = forward
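To check the boxed result, including the 1/N factor that CrossEntropyLoss.backward applies because forward() averages over the batch, here is a NumPy sketch (illustrative names; NumPy instead of torch so it runs standalone) comparing (softmax(z) - onehot(y)) / N with a central-difference gradient of the mean loss. Each row of the gradient sums to zero, since the softmax outputs and the one-hot target both sum to 1.

```python
import numpy as np


def cross_entropy(z, y):
    """Mean cross-entropy of integer labels y under softmax(z); z has shape (N, C)."""
    z_shift = z - z.max(axis=1, keepdims=True)
    log_probs = z_shift - np.log(np.exp(z_shift).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(y)), y].mean()


rng = np.random.default_rng(0)
N, C = 5, 4
z = rng.standard_normal((N, C))
y = np.array([0, 3, 1, 1, 2])

# Analytic gradient: (softmax(z) - onehot(y)) / N
probs = np.exp(z - z.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)
grad = probs.copy()
grad[np.arange(N), y] -= 1
grad /= N

# Central-difference gradient, one logit at a time
eps = 1e-6
numeric = np.zeros_like(z)
for idx in np.ndindex(z.shape):
    zp, zm = z.copy(), z.copy()
    zp[idx] += eps
    zm[idx] -= eps
    numeric[idx] = (cross_entropy(zp, y) - cross_entropy(zm, y)) / (2 * eps)

print(np.abs(grad - numeric).max())  # should be tiny
```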

Evaluating Gradients: Numerical Gradient Checking

While we’ve derived the gradients for common operations, it’s crucial to validate our implementations. One way to do this is by comparing the analytical gradients with numerical approximations. Numerical gradient checking is a technique to verify the correctness of our gradients by approximating them using finite differences.

The central difference formula approximates the gradient of a function $f$ at a point $x$ by evaluating the function at two nearby points $x + \epsilon$ and $x - \epsilon$:

$$\frac{\partial f}{\partial x} \approx \frac{f(x + \epsilon) - f(x - \epsilon)}{2\epsilon}$$

By comparing the analytical gradients with numerical approximations, we can catch errors in our implementations and ensure that our models are learning correctly.
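Why the central difference rather than the one-sided (f(x + ε) - f(x)) / ε from the derivative definition? The one-sided error shrinks linearly in ε, while the central-difference error shrinks quadratically. This plain-Python sketch illustrates it for f(x) = eˣ at x = 1, where the exact derivative is also e.

```python
import math

f = df = math.exp  # f(x) = e^x, so f'(x) = e^x as well
x = 1.0

errors = []
for eps in [1e-1, 1e-2, 1e-3]:
    forward_diff = (f(x + eps) - f(x)) / eps
    central_diff = (f(x + eps) - f(x - eps)) / (2 * eps)
    fwd_err, cen_err = abs(forward_diff - df(x)), abs(central_diff - df(x))
    errors.append((eps, fwd_err, cen_err))
    print(f"eps={eps:.0e}  forward err={fwd_err:.2e}  central err={cen_err:.2e}")
```

Each 10x reduction in ε cuts the forward error by roughly 10x but the central error by roughly 100x, which is why the checker below can use a tight tolerance.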

Here’s a simple implementation of gradient checking for any layer implementing the Layer protocol (which you’ll see in the next section):

import torch
from activations import ReLU, SigmoidLayer, SoftmaxLayer, TanhLayer
from layers import Layer, Linear


def relative_error(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-11):
    return (torch.abs(a - b) / (torch.abs(a) + torch.abs(b) + eps)).mean()


def numerical_grad(
    layer: Layer,
    x: torch.Tensor,
    target: torch.Tensor,
    dout: torch.Tensor,
    epsilon: float,
) -> torch.Tensor:
    """Central-difference gradient of sum(layer(x) * dout) w.r.t. `target`.

    `target` is either the input x or one of the layer's parameters; it is
    perturbed in place, one element at a time, and restored afterwards.
    """
    grad = torch.zeros_like(target)
    flat_target = target.data.view(-1)  # a view, so writes land in `target`
    flat_grad = grad.view(-1)
    for i in range(target.numel()):
        orig_val = flat_target[i].item()

        # Compute f(x + epsilon)
        flat_target[i] = orig_val + epsilon
        f_plus = (layer.forward(x) * dout).sum()

        # Compute f(x - epsilon)
        flat_target[i] = orig_val - epsilon
        f_minus = (layer.forward(x) * dout).sum()

        # Compute numerical gradient, then restore the original value
        flat_grad[i] = (f_plus - f_minus) / (2 * epsilon)
        flat_target[i] = orig_val
    return grad


def eval_numerical_gradients(
    layer: Layer,
    x: torch.Tensor,
    epsilon: float = 1e-6,
    tolerance: float = 1e-5,
):
    print(f"Evaluating grads for: {type(layer).__name__}")
    # Store original parameters
    original_params = [p.data.clone() for p in layer.parameters()]

    # Analytical gradients. A random upstream gradient `dout` turns the output into
    # the scalar f = sum(output * dout), whose gradient is exactly backward(dout).
    output = layer.forward(x)
    dout = torch.rand_like(output)
    analytical_input_grad = layer.backward(dout).clone()
    analytical_grads = [p.grad.clone() for p in layer.parameters()]

    # Numerical gradients for every parameter and for the input
    numerical_grads = [
        numerical_grad(layer, x, p, dout, epsilon) for p in layer.parameters()
    ]
    numerical_input_grad = numerical_grad(layer, x, x, dout, epsilon)

    # Compare analytical and numerical gradients for parameters
    for idx, (analytical, numerical) in enumerate(zip(analytical_grads, numerical_grads)):
        diff = torch.abs(analytical - numerical).max().item()
        print(f"Max difference for parameter {idx}: {diff}")
        print(f"Relative error: {relative_error(analytical, numerical)}")
        assert (
            diff < tolerance
        ), f"Gradient check failed for parameter {idx}. Max difference: {diff}"

    # Compare analytical and numerical gradients for input
    input_diff = torch.abs(analytical_input_grad - numerical_input_grad).max().item()
    print(f"Max difference for input: {input_diff}")
    print(
        f"Relative error for input: {relative_error(analytical_input_grad, numerical_input_grad)}"
    )
    assert (
        input_diff < tolerance
    ), f"Gradient check failed for input. Max difference: {input_diff}"

    # Restore original parameters
    for param, original in zip(layer.parameters(), original_params):
        param.data.copy_(original)

    print("Gradient check passed!")


if __name__ == "__main__":
    dtype = torch.float64  # double precision keeps finite-difference round-off small

    # Linear has no .to() method, so cast its parameters directly
    linear = Linear(10, 5)
    linear.weight = linear.weight.to(dtype)
    linear.bias = linear.bias.to(dtype)
    eval_numerical_gradients(linear, torch.randn(32, 10, dtype=dtype))

    for layer in [ReLU(), SigmoidLayer(), SoftmaxLayer(), TanhLayer()]:
        eval_numerical_gradients(layer, torch.randn(32, 10, dtype=dtype))

Putting it all together: Training a small MLP on MNIST

Now that we have all the building blocks in place, let’s put them together to train a simple multi-layer perceptron (MLP) on the MNIST dataset. This classic computer vision task involves classifying 28x28 grayscale images of handwritten digits (0-9).

Our MLP will consist of multiple linear layers with ReLU activations between them, followed by a final linear layer that outputs logits for each of the 10 digit classes. We’ll use cross-entropy loss and stochastic gradient descent (SGD) for optimization.

The training loop follows these key steps:

  1. Forward pass through the network to get predictions
  2. Calculate loss using cross-entropy
  3. Backward pass to compute gradients
  4. Update parameters using SGD
  5. Repeat!

This end-to-end example demonstrates how the individual components we’ve built - layers, activations, loss functions, optimizers - work together during training. It also showcases common training practices like:

While modern deep learning frameworks abstract away much of this complexity, understanding these fundamentals is invaluable for debugging, optimization, and developing new techniques.

Let’s examine the code and dissect what’s happening at each step:

First, let’s examine how we’ve structured our custom layers to mirror PyTorch’s Module interface. Just like PyTorch modules, our layers accept input tensors in their forward methods, maintain their own parameters and gradients, and properly chain the backward pass.

This design pattern makes it easy to compose layers into networks and run the backward pass in a familiar way. The Layer protocol below formalizes this interface.

Layer and Optim Protocols
from typing import Protocol

import torch


class Layer(Protocol):
    def forward(self, x: torch.Tensor) -> torch.Tensor: ...
    def backward(self, dout: torch.Tensor) -> torch.Tensor: ...
    def parameters(self) -> list[torch.Tensor]: ...
    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...


class Optimizer(Protocol):
    params: list[torch.Tensor]

    def __init__(self, params: list[torch.Tensor], lr: float):
        self.params = params

    def step(self): ...

    def zero_grad(self): ...

As you can see above, I’ve also defined optimizers with a Protocol interface to establish a standard contract for optimizer types. Following these protocols keeps our implementations consistent and lets a static type checker catch interface mismatches. Any new optimizer must provide step() and zero_grad() methods that match this interface, while preserving the expected parameter-update behavior.

Network Architecture

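In our simple network architecture, we have a feedforward neural network (also called a multilayer perceptron or MLP) that takes a 784-dimensional input (flattened 28x28 MNIST images) and outputs 10-dimensional logits (one for each digit class). With the default hidden_size=1024 used in Net below, the network consists of:

  • an input layer: Linear(784, 1024) followed by ReLU
  • two hidden layers: Linear(1024, 1024), each followed by ReLU
  • an output layer: Linear(1024, 10) producing the logits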

The network uses standard linear layers with learned weight matrices and bias vectors, with ReLU nonlinearities between each layer to allow learning of non-linear functions.

Each linear layer performs an affine transformation: output = input @ weight + bias. The ReLU activation applies an elementwise nonlinearity: output = max(input, 0).

This relatively simple architecture is still capable of achieving ~98% accuracy on MNIST with proper training. The multiple hidden layers allow the network to learn hierarchical features from the input digits.

import torch.nn as nn
from activations import ReLU
from layers import Layer, Linear
from losses import CrossEntropyLoss


class Net:
    def __init__(self, hidden_size=1024, num_classes=10):
        self.layers: list[Layer] = [
            Linear(784, hidden_size),
            ReLU(),
            Linear(hidden_size, hidden_size),
            ReLU(),
            Linear(hidden_size, hidden_size),
            ReLU(),
            Linear(hidden_size, num_classes),
        ]
        self.loss = CrossEntropyLoss()

    def forward(self, x):
        x = x.view(x.shape[0], -1)
        for layer in self.layers:
            x = layer(x)
        return x

    def backward(self, dout):
        for layer in reversed(self.layers):
            dout = layer.backward(dout)
        return dout

    def parameters(self):
        return [param for layer in self.layers for param in layer.parameters()]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class PytorchNet(nn.Module):
    def __init__(self, hidden_size=1024, num_classes=10):
        super().__init__()
        self.layers = nn.Sequential(
            *[
                nn.Linear(784, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, num_classes),
            ]
        )

    def forward(self, x):
        x = x.view(x.shape[0], -1)
        return self.layers(x)

Optimizers

Let’s look at different optimizers and their implementations:

import torch


class SGD:
    params: list[torch.Tensor]

    def __init__(self, params: list[torch.Tensor], lr: float = 0.01):
        if lr <= 0:
            raise ValueError("Learning rate must be positive")
        self.params = params
        self.lr = lr

    def step(self):
        if self.params and self.params[0].grad is None:
            raise ValueError("Parameters have no gradients. Call backward first.")
        for param in self.params:
            param -= self.lr * param.grad  # in-place update: w <- w - lr * dL/dw

    def zero_grad(self):
        for param in self.params:
            param.grad = torch.zeros_like(param)

While Adam is the de-facto optimizer for most deep learning tasks, it’s always good to experiment with different optimizers to see what works best for your specific problem. For example, for this problem, I found that RMSprop with momentum converges to nearly 100% accuracy on MNIST in just a few epochs.
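The post doesn’t show the RMSprop implementation behind get_optimizer, so the following is only a minimal NumPy sketch, assuming the update rule documented for torch.optim.RMSprop (non-centered, no weight decay): a running average of squared gradients rescales each coordinate, and a momentum buffer accumulates the rescaled steps. The class name RMSpropMomentum is hypothetical, and step takes the gradients explicitly because NumPy arrays have no .grad attribute, so it does not match the Optimizer protocol above.

```python
import numpy as np


class RMSpropMomentum:
    """Hypothetical RMSprop-with-momentum optimizer for NumPy arrays (updates in place)."""

    def __init__(self, params, lr=0.01, alpha=0.99, momentum=0.9, eps=1e-8):
        if lr <= 0:
            raise ValueError("Learning rate must be positive")
        self.params = params
        self.lr, self.alpha, self.momentum, self.eps = lr, alpha, momentum, eps
        self.square_avg = [np.zeros_like(p) for p in params]  # running mean of grad^2
        self.buf = [np.zeros_like(p) for p in params]         # momentum buffer

    def step(self, grads):
        for p, g, v, buf in zip(self.params, grads, self.square_avg, self.buf):
            v *= self.alpha
            v += (1 - self.alpha) * g * g          # v = alpha * v + (1 - alpha) * g^2
            buf *= self.momentum
            buf += g / (np.sqrt(v) + self.eps)     # scale each coordinate by 1 / RMS(grad)
            p -= self.lr * buf                     # in-place parameter update


# Usage: minimise f(p) = ||p||^2, whose gradient is 2p
p = np.array([3.0, -2.0])
opt = RMSpropMomentum([p], lr=0.01)
initial_loss = float(p @ p)
for _ in range(500):
    opt.step([2 * p])
final_loss = float(p @ p)
print(initial_loss, final_loss)
```

Dividing by the running RMS makes the step size roughly independent of the raw gradient scale, which is one plausible reason it behaves well on a problem like MNIST; treat that as intuition rather than a guarantee.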

Training loop:

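The training loop employs several best practices for modern deep learning:

  • gradients are reset with optimizer.zero_grad() before each backward pass
  • running averages of loss, accuracy, and batch time are tracked with AverageMeter
  • training metrics are logged every log_interval batches
  • the model is evaluated on the test split every val_interval epochs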

The validation loop mirrors the training structure but disables gradients for efficiency. Overall, the code follows a modular design that would be familiar to PyTorch users while implementing core functionality from scratch.

Full Training loop
import time

import torch
from data_utils import get_data
from losses import CrossEntropyLoss
from net import Net
from optimizers import get_optimizer
from torch_utils import AverageMeter, Logger, accuracy


def train(epochs=20, batch_size=512, lr=0.01, log_interval=10, val_interval=1):
    data_iter = get_data(batchsize=batch_size, debug=False)
    net = Net()
    criterion = CrossEntropyLoss()
    optimizer = get_optimizer("rmsprop", net.parameters(), lr=lr)
    logger = Logger()

    start_time = time.time()
    global_step = 0

    for epoch in range(epochs):
        loss_meter = AverageMeter()
        acc_meter = AverageMeter()
        batch_time = AverageMeter()

        for batch_idx, (x, y) in enumerate(data_iter("train")):
            batch_start = time.time()

            out = net(x)
            loss = criterion(out, y)

            optimizer.zero_grad()
            dout = criterion.backward()
            net.backward(dout)
            optimizer.step()

            loss_meter.update(loss.item())
            acc_meter.update(accuracy(out, y))
            batch_time.update(time.time() - batch_start)

            global_step += 1

            if (batch_idx + 1) % log_interval == 0:
                logger.log(
                    {
                        "train_loss": loss_meter.avg,
                        "train_acc": acc_meter.avg,
                        "batch_time": batch_time.avg,
                    },
                    global_step,
                )

        # Validation
        if (epoch + 1) % val_interval == 0:
            val_loss_meter = AverageMeter()
            val_acc_meter = AverageMeter()
            val_batch_time = AverageMeter()

            with torch.no_grad():
                for x, y in data_iter("test"):
                    batch_start = time.time()
                    out = net(x)
                    loss = criterion(out, y)

                    val_loss_meter.update(loss.item())
                    val_acc_meter.update(accuracy(out, y))
                    val_batch_time.update(time.time() - batch_start)

            logger.log(
                {
                    "val_loss": val_loss_meter.avg,
                    "val_acc": val_acc_meter.avg,
                    "val_batch_time": val_batch_time.avg,
                },
                global_step,
            )

    end_time = time.time()
    print(f"Training completed in {end_time - start_time:.2f} seconds")


train()
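The loop above leans on AverageMeter from torch_utils, which the post imports but doesn't show. A minimal stand-in, assuming the loop only needs update(value) and an avg attribute, could look like this (the optional n weight is an assumption, useful when batches differ in size):

```python
class AverageMeter:
    """Hypothetical stand-in for torch_utils.AverageMeter: a running mean."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val: float, n: int = 1):
        # Weight each value by n so per-batch means combine into a per-sample mean
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = AverageMeter()
for batch_loss in (0.9, 0.7, 0.5):
    meter.update(batch_loss)
print(meter.avg)
```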

Building a Simplified Tensor and Autograd Engine (aka “Micrograd”)

Now that we’ve covered the fundamentals of backpropagation and built a basic neural network training loop from scratch, let’s go deeper by implementing a minimal automatic differentiation engine, similar to PyTorch’s autograd but far simpler. This exercise, often called “building micrograd,” will help cement our understanding of how modern deep learning frameworks work under the hood. Unlike before, where we had to pass gradients through the network by hand, this engine will compute gradients for us automatically.

The key insight of automatic differentiation is that if every operation records its inputs and knows its own local derivative, the chain rule can compute gradients through an arbitrary computation graph. Our implementation will support basic tensor operations and gradient computation via the chain rule, the same foundation that powers frameworks like PyTorch and TensorFlow. We are lucky to have libraries like PyTorch, TensorFlow, and JAX, which hide much of the complexity of automatic differentiation and GPU acceleration behind user-friendly APIs, but building a simplified version from scratch makes it much clearer what they do for us.

Tensor Class

We start by defining a Tensor class that wraps numeric data (a torch.Tensor, np.ndarray, Python scalar, or list, converted internally to a torch.Tensor) and tracks gradients. The class should support operations like addition and multiplication, as well as a backward pass. The goal is a basic computational graph that computes gradients using reverse-mode automatic differentiation (backpropagation).

from typing import Callable, TypeAlias, Union

import numpy as np
import torch

TensorData = torch.Tensor | np.ndarray | float | int | list[float | int]  # Underlying data type
Operand: TypeAlias = Union["Tensor", TensorData]  # Input type for operations

class Tensor:
    def __init__(
        self,
        data: TensorData,
        requires_grad: bool = True,
        _children: tuple["Tensor", ...] = (),
        _op: str | None = None,
    ):
        self.data = self._to_torch(data)  # Convert underlying data to torch.Tensor
        self.requires_grad = requires_grad  # Flag to track gradients
        self.grad = torch.zeros_like(self.data) if requires_grad else None
        self._backward: Callable | None = None  # Function to compute gradients
        self._prev = set(_children)  # Parents in the computation graph
        self._op = _op   # Operation (name) that produced this tensor

# (..., more methods to come)

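Above, we’ve defined the Tensor class with the following key attributes:

  • data: the underlying torch.Tensor holding the values
  • requires_grad: whether gradients should be tracked for this tensor
  • grad: the accumulated gradient, initialized to zeros (or None when requires_grad is False)
  • _backward: the function that propagates this tensor’s gradient to its parents
  • _op: the name of the operation that produced this tensor

We’ve also included a _prev set holding the parents of the current tensor. It is filled in during the forward pass, as each operation records its inputs, and traversed during the backward pass to order the gradient computations. Finally, _to_torch is a helper that converts the raw input data to a torch.Tensor, while _to_tensor (used by the operations below) wraps a non-Tensor operand in a Tensor.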
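The two conversion helpers aren't shown in the post. Here is one plausible way to write them, as a sketch under two assumptions: _to_torch promotes integer data to float so gradients are well defined, and _to_tensor wraps plain operands (like the 3 in x * 3) as constants with requires_grad=False, so no gradient buffer is allocated for them.

```python
import numpy as np
import torch


class Tensor:
    def __init__(self, data, requires_grad=True, _children=(), _op=None):
        self.data = self._to_torch(data)
        self.requires_grad = requires_grad
        self.grad = torch.zeros_like(self.data) if requires_grad else None
        self._backward = None
        self._prev = set(_children)
        self._op = _op

    @staticmethod
    def _to_torch(data) -> torch.Tensor:
        """Convert raw data (tensor, array, scalar, list) to a torch.Tensor."""
        if isinstance(data, torch.Tensor):
            t = data.detach()
        elif isinstance(data, np.ndarray):
            t = torch.from_numpy(data)
        else:
            t = torch.tensor(data)
        # Assumption: integers are promoted to float32
        return t if t.is_floating_point() else t.float()

    @staticmethod
    def _to_tensor(other) -> "Tensor":
        """Wrap a non-Tensor operand as a constant that does not track gradients."""
        if isinstance(other, Tensor):
            return other
        return Tensor(other, requires_grad=False)


t = Tensor(3)
print(t.data.dtype, Tensor._to_tensor(2.5).requires_grad)
```

Making constants non-differentiable is a design choice: it keeps operations like x * 3 from accumulating gradients into throwaway tensors.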

Tracking Operations

Now let’s implement the basic operations that our Tensor class should support. Each operation needs to:

  1. Compute the forward result
  2. Define a backward function that specifies how to compute gradients
  3. Track the computation graph by storing the input tensors as parents of the output tensor

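Here are the key operations we’ll implement: multiplication, matrix multiplication, addition, subtraction, negation, division, power, exp, log, sigmoid, tanh, and the sum and mean reductions.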

Implementing Basic Arithmetic Operations

Let’s examine how multiplication works as an archetypal example of implementing tensor operations. When multiplying two tensors $x$ and $y$, three key components work together:

  1. Forward Pass: Computes the result $z = x \times y$ and stores it in a new tensor
  2. Gradient Rules: Implements the partial derivatives needed for backpropagation:
    • $\frac{\partial z}{\partial x} = y$ (gradient with respect to first input)
    • $\frac{\partial z}{\partial y} = x$ (gradient with respect to second input)
  3. Graph Construction: Records the computational history by:
    • Creating a new tensor to hold the result
    • Storing references to the input tensors $x$ and $y$
    • Defining a backward function that knows how to compute gradients using the stored inputs

This three-part structure creates a computational graph that enables automatic differentiation through:

  1. Forward propagation that builds the computation graph by tracking dependencies
  2. Backward traversal through the graph starting from the final output
  3. Gradient accumulation at each node according to the chain rule of calculus

Each operation in our tensor library will follow this same pattern, just with different forward computations and gradient formulas. The consistent structure allows automatic differentiation to work uniformly across all operations.

Here is the multiplication operation in full:

def __mul__(self, other: Operand) -> "Tensor":
    """Implement multiplication operation.

    Args:
        self: current tensor (`x`)
        other: Second operand for multiplication (`y`)

    Returns:
        Tensor: Result of the multiplication operation (`z = x * y`)
    """
    other = self._to_tensor(other)
    requires_grad = self.requires_grad or other.requires_grad
    data = self.data * other.data # Forward pass: z = x * y
    out = Tensor(
        data, # Result of the operation
        requires_grad, # Flag to track gradients
        (self, other), # Parents in the computation graph (x, y)
        _op="mul"  # Operation name
    )

    def _backward():
        """Compute gradients for the multiplication operation."""
        if self.requires_grad:
            # Chain rule: dL/dx = dz/dx * dL/dz = y * out.grad, accumulated into x.grad
            self.grad += (other.data * out.grad).type_as(self.grad)
        if other.requires_grad:
            # Chain rule: dL/dy = dz/dy * dL/dz = x * out.grad, accumulated into y.grad
            other.grad += (self.data * out.grad).type_as(other.grad)

    out._backward = _backward
    return out
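Before trusting a hand-written gradient rule, it helps to check it numerically. A minimal sketch using central differences, (f(x + eps) - f(x - eps)) / (2 eps), on plain Python floats; the helper name numerical_partial is hypothetical. It confirms the multiplication rules, and the same check applies to any other rule below.

```python
def numerical_partial(f, args, i, eps=1e-6):
    """Central-difference estimate of the partial derivative of f w.r.t. args[i]."""
    plus, minus = list(args), list(args)
    plus[i] += eps
    minus[i] -= eps
    return (f(*plus) - f(*minus)) / (2 * eps)


def mul(x, y):
    return x * y


x, y = 2.0, -3.5
# Gradient rules from the multiplication backward: dz/dx = y and dz/dy = x
print(numerical_partial(mul, (x, y), 0), "vs", y)
print(numerical_partial(mul, (x, y), 1), "vs", x)
```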

The backward method

The backward method is where the magic of automatic differentiation happens. This method orchestrates the computation of gradients through the entire computational graph. When called on a tensor, it:

  1. Initializes the gradient (typically 1.0 for scalar outputs)
  2. Performs a topological sort of the computation graph
  3. Applies the chain rule in reverse order through the sorted nodes

The topological sort is crucial because it guarantees that a node’s gradient is fully accumulated before that node passes it on: each node is processed only after every node that consumes its output, moving from the output back to the inputs. Let’s look at the implementation:


    def backward(self, gradient: Operand | None = None):
        if gradient is None:
            if self.data.ndim == 0 or self.data.numel() == 1:
                gradient = torch.tensor(1.0, requires_grad=False)
            else:
                raise RuntimeError(
                    "grad can be implicitly created only for scalar outputs"
                )
        # Initialize gradient accumulation (get_dtype, not shown, returns the dtype of `gradient`)
        if self.grad is None:
            self.grad = torch.zeros_like(
                self.data, requires_grad=False, dtype=get_dtype(gradient)
            )

        self.grad = self.grad.type(get_dtype(gradient)) + gradient

        # topological sort for backward pass
        topo: list["Tensor"] = []
        visited = set()

        def build_topo(v: "Tensor"):
            """Topological sort helper function.
            Does a depth-first search on the computation graph to build a topological ordering.
            """
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)

        build_topo(self)

        # Apply the chain rule to nodes in reverse order
        for v in reversed(topo):
            if v._backward is not None:
                v._backward()
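One practical caveat: the recursive build_topo can exceed CPython’s default recursion limit (1000 frames) on very deep graphs, such as long unrolled loops. A sketch of an iterative post-order DFS that produces the same parents-before-children ordering with an explicit stack; Node is a hypothetical stand-in for Tensor, since only the _prev attribute matters here.

```python
class Node:
    """Minimal graph node: only the _prev parent set matters for ordering."""

    def __init__(self, name, parents=()):
        self.name = name
        self._prev = set(parents)


def build_topo_iterative(root):
    """Post-order DFS with an explicit stack instead of recursion.

    Every node is appended after all of its parents, matching build_topo.
    """
    topo, visited = [], set()
    stack = [(root, False)]
    while stack:
        node, expanded = stack.pop()
        if expanded:
            topo.append(node)  # all parents have already been appended
            continue
        if node in visited:
            continue
        visited.add(node)
        stack.append((node, True))  # come back to this node after its parents
        for parent in node._prev:
            if parent not in visited:
                stack.append((parent, False))
    return topo


x, c3 = Node("x"), Node("3")
y = Node("y", (x, c3))
z = Node("z", (y, Node("1")))
print([n.name for n in build_topo_iterative(z)])
```

The backward loop itself is unchanged: iterate over reversed(topo) and call each _backward.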

Understanding how the computational graph is built

Let’s break down the topological sort step by step, as it’s crucial for ensuring we compute gradients in the right order. Let’s use a concrete example:

# Let's say we have this computation:
x = Tensor(2.0, requires_grad=True)  # Let's call this node x
y = x * 3                            # Let's call this node y
z = y + 1                            # Let's call this node z
z.backward()                         # We start backward from z

When we call z.backward(), here’s what happens:

  1. First, let’s understand what we’re trying to solve:
    • We need to compute gradients in reverse order of the computation
    • $z$ depends on $y$, which depends on $x$
    • We must calculate $\frac{\partial z}{\partial y}$ and $\frac{\partial z}{\partial x}$
  2. The build_topo function does depth-first search:
topo = []
visited = set()
def build_topo(v: Tensor):
    if v not in visited:
        visited.add(v)
        for child in v._prev:
            build_topo(child)
        topo.append(v)
build_topo(self)
  3. Let’s trace build_topo(self) for our example:
# Starting with node z
build_topo(z):
    visited = set()               # Empty set initially
    z not in visited → True
    Add z to visited = {z}
    Look at z._prev = {y}         # y was an input to z (the constant 1 is omitted for brevity)

    # Recursively call for y
    build_topo(y):
        y not in visited → True
        Add y to visited = {z, y}
        Look at y._prev = {x}     # x was an input to y (the constant 3 is omitted for brevity)

        # Recursively call for x
        build_topo(x):
            x not in visited → True
            Add x to visited = {z, y, x}
            x._prev is empty      # x had no inputs
            Append x to topo → [x]

        Append y to topo → [x, y]

    Append z to topo → [x, y, z]
  4. Finally, we reverse the list and apply backward:
for v in reversed(topo):
    if v._backward is not None:
        v._backward()

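This processes the gradients in reverse order of computation:

  1. First calls z._backward(), which passes z’s gradient to y
  2. Then y._backward(), which passes y’s gradient to x
  3. x is a leaf with no _backward, so nothing more runs; its .grad now holds the final gradient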

The backward functions we defined for each operation accumulate the gradients in the .grad attribute of each tensor according to the chain rule.

Getting back to our example:

x = Tensor(2.0, requires_grad=True)
y = x * 3
z = y + 1
z.backward()

The backward pass computes:

  1. dz/dy = 1 (derivative of addition)
  2. dz/dx = dz/dy * dy/dx = 1 * 3 = 3 (chain rule)

After z.backward() completes, x.grad holds 3.0 and y.grad holds 1.0.

This matches our expectation: since z = (x * 3) + 1, dz/dx = 3.
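To run the whole trace end to end without the rest of the Tensor class, here is a condensed, scalar-only sketch. Value is a hypothetical name; it uses plain Python floats instead of torch tensors and supports only + and *, but follows the same _prev / _backward / topological-sort structure.

```python
class Value:
    """Scalar stand-in for Tensor: plain floats, only + and *."""

    def __init__(self, data, _children=()):
        self.data = float(data)
        self.grad = 0.0
        self._backward = None
        self._prev = set(_children)

    @staticmethod
    def _wrap(other):
        return other if isinstance(other, Value) else Value(other)

    def __add__(self, other):
        other = self._wrap(other)
        out = Value(self.data + other.data, (self, other))

        def _backward():
            # d(a + b)/da = d(a + b)/db = 1
            self.grad += out.grad
            other.grad += out.grad

        out._backward = _backward
        return out

    def __mul__(self, other):
        other = self._wrap(other)
        out = Value(self.data * other.data, (self, other))

        def _backward():
            # d(a * b)/da = b and d(a * b)/db = a
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward = _backward
        return out

    def backward(self):
        topo, visited = [], set()

        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)

        build_topo(self)
        self.grad = 1.0  # seed: dz/dz = 1
        for v in reversed(topo):
            if v._backward is not None:
                v._backward()


x = Value(2.0)
y = x * 3
z = y + 1
z.backward()
print(x.grad, y.grad)  # 3.0 1.0
```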

Implementing other operations

At the heart of any modern neural network is matrix multiplication: every linear layer is one. Let’s implement the matmul operation, which computes the matrix product of two tensors. For $C = AB$ with upstream gradient $G = \partial L / \partial C$, the chain rule gives $\partial L / \partial A = G B^\top$ and $\partial L / \partial B = A^\top G$, which is exactly what the backward function accumulates.

    def __matmul__(self, other: Operand) -> "Tensor":
        other = self._to_tensor(other)
        requires_grad = self.requires_grad or other.requires_grad
        data = self.data @ other.data
        out = Tensor(data, requires_grad, (self, other), _op="matmul")

        def _backward():
            if self.requires_grad:
                self.grad += out.grad @ other.data.T
            if other.requires_grad:
                other.grad += self.data.T @ out.grad

        out._backward = _backward
        return out

    def __rmatmul__(self, other):
        # Computes other @ self; matmul is not commutative, so keep the operand order
        return self._to_tensor(other) @ self
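A quick way to convince yourself the matmul rules (and the shape bookkeeping) are right is to compare them against PyTorch’s own autograd on random inputs. This sketch, like __matmul__ above, covers 2-D matrices only; G plays the role of out.grad.

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
B = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
G = torch.randn(4, 5, dtype=torch.float64)  # arbitrary upstream gradient dL/dC

# Reference gradients from PyTorch's autograd
C = A @ B
C.backward(G)

# Hand-written rules from the matmul backward
grad_A = G @ B.detach().T  # dL/dA = dL/dC @ B^T, shape (4, 3)
grad_B = A.detach().T @ G  # dL/dB = A^T @ dL/dC, shape (3, 5)

print(torch.allclose(grad_A, A.grad), torch.allclose(grad_B, B.grad))
```

Note that each gradient has the same shape as the tensor it belongs to, which is a cheap first check for any new operation.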
Implementing Other Operations
    def __add__(self, other: Operand) -> "Tensor":
        other = self._to_tensor(other)
        requires_grad = self.requires_grad or other.requires_grad
        data = self.data + other.data
        out = Tensor(data, requires_grad, (self, other), _op="add")

        def _backward():
            if self.requires_grad:
                self.grad += out.grad  # type: ignore
            if other.requires_grad:
                other.grad += out.grad  # type: ignore

        out._backward = _backward
        return out

    def sum(self):
        out = Tensor(
            self.data.sum(),
            requires_grad=self.requires_grad,
            _children=(self,),
            _op="sum",
        )

        def _backward():
            if self.requires_grad:
                self.grad += torch.ones_like(self.data, requires_grad=False) * out.grad

        out._backward = _backward
        return out

    def __radd__(self, other):
        return self.__add__(other)

    def __rmul__(self, other):
        return self.__mul__(other)


    def __pow__(self, other: Operand) -> "Tensor":
        other = self._to_tensor(other)
        requires_grad = self.requires_grad or other.requires_grad
        data = self.data**other.data
        out = Tensor(data, requires_grad, (self, other), _op="pow")

        def _backward():
            if self.requires_grad:
                # power rule: d/dx x^n = n * x^(n-1)
                self.grad += other.data * self.data ** (other.data - 1) * out.grad
            if other.requires_grad:
                # d/dx a^x = a^x * ln(a)
                other.grad += (self.data**other.data) * torch.log(self.data) * out.grad

        out._backward = _backward
        return out

    def __neg__(self) -> "Tensor":
        return self * -1

    def __sub__(self, other: Operand) -> "Tensor":
        return self + -(self._to_tensor(other))

    def __rsub__(self, other: Operand) -> "Tensor":
        return self._to_tensor(other) + (-self)

    def __truediv__(self, other: Operand) -> "Tensor":
        other = self._to_tensor(other)
        requires_grad = self.requires_grad or other.requires_grad
        data = self.data / other.data
        out = Tensor(data, requires_grad, (self, other), _op="div")

        def _backward():
            if self.requires_grad:
                # d/da (a / b) = 1 / b
                self.grad += (1 / other.data) * out.grad
            if other.requires_grad:
                # d/db (a / b) = -a / b^2
                other.grad += (-self.data / other.data**2) * out.grad

        out._backward = _backward
        return out

    def __rtruediv__(self, other: Operand) -> "Tensor":
        return self._to_tensor(other) / self

    def exp(self) -> "Tensor":
        out = Tensor(
            torch.exp(self.data),
            self.requires_grad,
            (self,),
            _op="exp",
        )

        def _backward():
            if self.requires_grad:
                assert self.grad is not None
                self.grad += (out.data * out.grad).type_as(self.grad)

        out._backward = _backward
        return out

    def log(self) -> "Tensor":
        out = Tensor(
            torch.log(self.data),
            self.requires_grad,
            (self,),
            _op="log",
        )

        def _backward():
            if self.requires_grad:
                assert self.grad is not None
                self.grad += ((1 / self.data) * out.grad).type_as(self.grad)

        out._backward = _backward
        return out

    def sigmoid(self) -> "Tensor":
        out = Tensor(
            torch.sigmoid(self.data),
            self.requires_grad,
            (self,),
            _op="sigmoid",
        )

        def _backward():
            if self.requires_grad:
                assert self.grad is not None
                self.grad += (out.data * (1 - out.data) * out.grad).type_as(self.grad)

        out._backward = _backward
        return out

    def tanh(self) -> "Tensor":
        out = Tensor(
            torch.tanh(self.data),
            self.requires_grad,
            (self,),
            _op="tanh",
        )

        def _backward():
            if self.requires_grad:
                assert self.grad is not None
                self.grad += (1 - out.data**2) * out.grad

        out._backward = _backward
        return out

    def sum_dim(self, dim: int, keepdim: bool = False) -> "Tensor":
        data = self.data.sum(dim=dim, keepdim=keepdim)
        out = Tensor(data, self.requires_grad, (self,), _op="sum_dim")

        def _backward():
            if self.requires_grad:
                grad = out.grad
                assert grad is not None
                if not keepdim:
                    shape = list(self.data.shape)
                    shape[dim] = 1
                    grad = grad.view(shape)
                self.grad += grad.expand_as(self.data)

        out._backward = _backward
        return out

    def mean(self) -> "Tensor":
        data = self.data.mean()
        out = Tensor(data, self.requires_grad, (self,), _op="mean")

        def _backward():
            if self.requires_grad:
                assert out.grad is not None
                assert self.grad is not None
                self.grad += (out.grad / self.data.numel()).type_as(self.grad)

        out._backward = _backward
        return out
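The same autograd cross-check works for the element-wise operations and reductions above. This sketch applies each local-gradient rule as written in the _backward closures (multiplied by an arbitrary upstream gradient, per the chain rule) and compares it with PyTorch; it uses raw torch tensors rather than the Tensor class.

```python
import torch

torch.manual_seed(0)
x = torch.rand(6, dtype=torch.float64) + 0.5  # strictly positive so log is defined
upstream = torch.randn(6, dtype=torch.float64)  # arbitrary dL/d(out)

# (forward, local gradient d(out)/dx expressed via x and out)
rules = {
    "exp": (torch.exp, lambda x, out: out),
    "log": (torch.log, lambda x, out: 1 / x),
    "sigmoid": (torch.sigmoid, lambda x, out: out * (1 - out)),
    "tanh": (torch.tanh, lambda x, out: 1 - out**2),
}

for name, (fn, local_grad) in rules.items():
    xr = x.clone().requires_grad_(True)
    fn(xr).backward(upstream)  # reference gradient from PyTorch autograd
    manual = local_grad(x, fn(x)) * upstream  # chain rule: local * upstream
    print(name, torch.allclose(manual, xr.grad))

# mean: each of the n elements receives upstream / n
xr = x.clone().requires_grad_(True)
xr.mean().backward()
print("mean", torch.allclose(xr.grad, torch.full_like(x, 1 / x.numel())))

# sum_dim without keepdim: re-insert the summed dimension, then broadcast back
M = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
g = torch.randn(4, dtype=torch.float64)
M.sum(dim=0).backward(g)
print("sum_dim", torch.allclose(g.view(1, 4).expand(3, 4), M.grad))
```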

Challenges in Practice

Despite its elegance, backpropagation faces practical hurdles in real-world scenarios:


Modern Tweaks to Backpropagation

Over the years, practitioners have developed techniques to make backpropagation more robust:

Gradient Monitoring & Clipping

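Gradient clipping prevents exploding gradients by capping their magnitude: the L2 norm is computed over all gradients together, and if it exceeds max_norm, every gradient is scaled by max_norm / total_norm (plus a small epsilon for stability). The example below also monitors per-layer gradient and parameter statistics before and after clipping.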

Gradient Monitoring & Clipping
import time
from collections import defaultdict
from typing import Iterable

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from data_utils import get_data
from net import PytorchNet as Net
from torch_utils import (
    AverageMeter,
    Logger,
    accuracy,
)


def tensor_norm(tensor: torch.Tensor, p: int = 2):
    # p-norm: (sum_i |x_i|^p)^(1/p)
    return tensor.abs().pow(p).sum().pow(1.0 / p)


def compute_total_norm(params: Iterable[torch.Tensor], norm_type: float = 2):
    return torch.norm(
        torch.stack(
            [
                torch.norm(p.grad.detach(), norm_type)
                for p in params
                if p.grad is not None
            ]
        ),
        p=norm_type,
    )


def clip_grad_norm_(params: Iterable[torch.Tensor], max_norm: float):
    """
    Clips gradient norm of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Args:
        params: List of parameters with gradients
        max_norm: Max norm of the gradients

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """
    # First compute total norm of gradients
    if isinstance(params, torch.Tensor):
        params = [params]

    params = [p for p in params if p.grad is not None]
    if not params:
        return torch.tensor(0.0)

    total_norm = compute_total_norm(params)

    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:  # Only clip if total_norm > max_norm
        for param in params:
            if param.grad is not None:
                param.grad.data.mul_(clip_coef.to(param.grad.device))

    return total_norm


class GradientMonitor:
    def __init__(self):
        self.grad_stats = defaultdict(
            lambda: {"mean": [], "std": [], "norm": [], "histogram": []}
        )
        self.param_stats = defaultdict(lambda: {"mean": [], "std": [], "norm": []})
        self.total_norm_stats = {"before_clip": [], "after_clip": []}

    def update(self, model: nn.Module):
        total_norm_before = compute_total_norm(model.parameters()).item()
        self.total_norm_stats["before_clip"].append(total_norm_before)
        for name, param in model.named_parameters():
            if param.grad is not None:
                grad = param.grad
                self.grad_stats[name]["mean"].append(grad.mean().item())
                self.grad_stats[name]["std"].append(grad.std().item())
                self.grad_stats[name]["norm"].append(grad.norm().item())
                self.grad_stats[name]["histogram"].append(
                    grad.detach().cpu().histogram(bins=30).hist.numpy()
                )

                self.param_stats[name]["mean"].append(param.data.mean().item())
                self.param_stats[name]["std"].append(param.data.std().item())
                self.param_stats[name]["norm"].append(param.data.norm().item())

    def update_after_clip(self, model: nn.Module):
        # Compute total norm after clipping
        total_norm_after = compute_total_norm(model.parameters()).item()
        self.total_norm_stats["after_clip"].append(total_norm_after)

    def log_stats(self, logger: Logger, global_step: int, log_interval):
        for name in self.grad_stats.keys():
            logger.log(
                {
                    f"grad_{name}_mean": np.mean(
                        self.grad_stats[name]["mean"][-log_interval:]
                    ),
                    f"grad_{name}_std": np.mean(
                        self.grad_stats[name]["std"][-log_interval:]
                    ),
                    f"grad_{name}_norm": np.mean(
                        self.grad_stats[name]["norm"][-log_interval:]
                    ),
                    f"param_{name}_mean": np.mean(
                        self.param_stats[name]["mean"][-log_interval:]
                    ),
                    f"param_{name}_std": np.mean(
                        self.param_stats[name]["std"][-log_interval:]
                    ),
                    f"param_{name}_norm": np.mean(
                        self.param_stats[name]["norm"][-log_interval:]
                    ),
                },
                global_step,
            )
        logger.log(
            {
                "total_grad_norm_before_clip": np.mean(
                    self.total_norm_stats["before_clip"][-log_interval:]
                ),
                "total_grad_norm_after_clip": np.mean(
                    self.total_norm_stats["after_clip"][-log_interval:]
                ),
            },
            global_step,
        )

    def visualize_stats(self):
        self._visualize_stats(self.grad_stats, "gradient")
        self._visualize_stats(self.param_stats, "parameter")
        self._visualize_grad_histograms()

    def visualize_total_norm(self):
        plt.figure(figsize=(12, 6))
        plt.plot(self.total_norm_stats["before_clip"], label="Before Clip")
        plt.plot(self.total_norm_stats["after_clip"], label="After Clip")
        plt.xlabel("Steps")
        plt.ylabel("Total Gradient Norm")
        plt.title("Total Gradient Norm Before and After Clipping")
        plt.legend()
        plt.tight_layout()
        plt.savefig("total_gradient_norm.png")
        plt.close()

    def _visualize_stats(self, stats, stat_type):
        fig, axs = plt.subplots(3, 1, figsize=(12, 18))
        for name, data in stats.items():
            axs[0].plot(data["mean"], label=f"{name}_mean")
            axs[1].plot(data["std"], label=f"{name}_std")
            axs[2].plot(data["norm"], label=f"{name}_norm")

        axs[0].set_title(f"{stat_type.capitalize()} Mean")
        axs[1].set_title(f"{stat_type.capitalize()} Std")
        axs[2].set_title(f"{stat_type.capitalize()} Norm")

        for ax in axs:
            ax.legend()
            ax.set_xlabel("Steps")
            ax.set_ylabel("Value")

        plt.tight_layout()
        plt.savefig(f"{stat_type}_statistics.png")
        plt.close()

    def _visualize_grad_histograms(self):
        num_layers = len(self.grad_stats)
        fig, axs = plt.subplots(
            num_layers, 1, figsize=(12, 4 * num_layers), squeeze=False
        )

        for idx, (name, data) in enumerate(self.grad_stats.items()):
            ax = axs[idx, 0]  # squeeze=False keeps axs 2-D even for a single layer
            histograms = np.array(data["histogram"])
            im = ax.imshow(histograms.T, aspect="auto", cmap="viridis")
            ax.set_title(f"Gradient Histogram: {name}")
            ax.set_xlabel("Steps")
            ax.set_ylabel("Bins")
            plt.colorbar(im, ax=ax)

        plt.tight_layout()
        plt.savefig("gradient_histograms.png")
        plt.close()


def train(epochs=20, batch_size=512, lr=0.01, log_interval=10, val_interval=1):
    data_iter = get_data(batchsize=batch_size, debug=False)
    net = Net()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=lr)
    logger = Logger()
    gradient_monitor = GradientMonitor()

    start_time = time.time()
    global_step = 0

    for epoch in range(epochs):
        loss_meter = AverageMeter()
        acc_meter = AverageMeter()
        batch_time = AverageMeter()

        for batch_idx, (x, y) in enumerate(data_iter("train")):
            batch_start = time.time()

            out = net(x)
            loss = criterion(out, y)

            optimizer.zero_grad()
            loss.backward()

            gradient_monitor.update(net)
            # Gradient clipping (optional): rescale so the total gradient norm is at most 1.0
            clip_grad_norm_(net.parameters(), max_norm=1.0)

            gradient_monitor.update_after_clip(net)

            optimizer.step()

            loss_meter.update(loss.item())
            acc_meter.update(accuracy(out, y))
            batch_time.update(time.time() - batch_start)

            global_step += 1

            if (batch_idx + 1) % log_interval == 0:
                logger.log(
                    {
                        "train_loss": loss_meter.avg,
                        "train_acc": acc_meter.avg,
                        "batch_time": batch_time.avg,
                    },
                    global_step,
                )
                # gradient_monitor.log_stats(logger, global_step, log_interval)

        # Validation
        if (epoch + 1) % val_interval == 0:
            val_loss_meter = AverageMeter()
            val_acc_meter = AverageMeter()
            val_batch_time = AverageMeter()

            with torch.no_grad():
                for x, y in data_iter("test"):
                    batch_start = time.time()
                    out = net(x)
                    loss = criterion(out, y)

                    val_loss_meter.update(loss.item())
                    val_acc_meter.update(accuracy(out, y))
                    val_batch_time.update(time.time() - batch_start)

            logger.log(
                {
                    "val_loss": val_loss_meter.avg,
                    "val_acc": val_acc_meter.avg,
                    "val_batch_time": val_batch_time.avg,
                },
                global_step,
            )

    end_time = time.time()
    print(f"Training completed in {end_time - start_time:.2f} seconds")
    gradient_monitor.visualize_stats()
    gradient_monitor.visualize_total_norm()


train(epochs=2)
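To sanity-check the global-norm rule, here is a compact re-implementation (the name clip_by_global_norm is hypothetical) compared against PyTorch’s built-in torch.nn.utils.clip_grad_norm_, which applies the same max_norm / (total_norm + 1e-6) scaling, capped at 1.

```python
import torch


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale gradient tensors in place so their joint L2 norm is at most max_norm.

    total = sqrt(sum_i ||g_i||^2); each g_i is multiplied by
    max_norm / (total + eps) only when that factor is below 1.
    Returns the norm measured before clipping.
    """
    total = torch.sqrt(sum((g.detach() ** 2).sum() for g in grads))
    coef = max_norm / (total + eps)
    if coef < 1:
        for g in grads:
            g.mul_(coef)
    return total


torch.manual_seed(0)
params_a = [torch.randn(10, 5, requires_grad=True), torch.randn(5, requires_grad=True)]
params_b = [p.detach().clone().requires_grad_(True) for p in params_a]
for ps in (params_a, params_b):
    (sum((p ** 2).sum() for p in ps) * 10).backward()  # deliberately large gradients

norm_ours = clip_by_global_norm([p.grad for p in params_a], max_norm=1.0)
norm_torch = torch.nn.utils.clip_grad_norm_(params_b, max_norm=1.0)
print(norm_ours.item(), norm_torch.item())
```

Because every gradient is scaled by the same factor, clipping by global norm shrinks the update without changing its direction.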

Gradient accumulation
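Gradient accumulation simulates a larger batch by summing gradients from several micro-batches before a single optimizer step. Because the loss is a batch mean, each micro-batch loss must be scaled by 1/accumulation_steps for the summed gradient to match the full-batch gradient; this holds exactly when the micro-batches are equally sized. A small check with PyTorch autograd on a toy linear model (W, x, and y are illustrative data, not from the post):

```python
import torch

torch.manual_seed(0)
W = torch.randn(3, 2, requires_grad=True)
x = torch.randn(8, 3)
y = torch.randint(0, 2, (8,))
loss_fn = torch.nn.CrossEntropyLoss()  # averages over the batch

# Full batch of 8 examples: one backward pass
loss_fn(x @ W, y).backward()
full_grad = W.grad.clone()

# Same 8 examples as 4 micro-batches of 2, each loss scaled by 1/4
W.grad = None
steps = 4
for xb, yb in zip(x.chunk(steps), y.chunk(steps)):
    (loss_fn(xb @ W, yb) / steps).backward()  # .grad accumulates across calls

print(torch.allclose(full_grad, W.grad, atol=1e-6))
```

Without the 1/steps factor, the accumulated gradient would be steps times too large, which acts like silently multiplying the learning rate.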

Gradient Accumulation
import time
from contextlib import contextmanager

import torch
from configs import TrainingConfig
from data_utils import get_data
from layers import Optimizer
from losses import CrossEntropyLoss
from net import Net
from optimizers import get_optimizer
from torch_utils import AverageMeter, Logger, accuracy, clip_grad_norm_


class GradientAccumulator:
    def __init__(
        self,
        model: Net,
        optimizer: Optimizer,
        criterion: CrossEntropyLoss,
        config: TrainingConfig,
    ):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.config = config

        # Track accumulation state
        self.current_step = 0
        self.loss_accumulator = 0.0

        # Scaling factors for loss
        self.accumulation_scale = 1.0 / self.config.accumulation_steps

        # Monitoring metrics
        self.max_gradient_seen = 0.0
        self.min_gradient_seen = float("inf")
        self.gradient_norms = []

    @property
    def accumulation_steps(self):
        return self.config.accumulation_steps

    @contextmanager
    def accumulation_context(self, global_step: int | None = None):
        curr_step = self.current_step if global_step is None else global_step
        yield (curr_step + 1) % self.config.accumulation_steps == 0
        self.current_step += 1

    def backward(self, output: torch.Tensor, target: torch.Tensor):
        """
        Perform backward pass with proper scaling

        The loss is scaled by 1/accumulation_steps to maintain
        equivalent gradients when accumulating over multiple steps.
        """
        # First compute the raw loss
        raw_loss = self.criterion(output, target)

        # Scale the loss
        scaled_loss = raw_loss * self.accumulation_scale
        self.loss_accumulator += raw_loss.item()  # Store unscaled loss for logging

        # Get gradients from criterion with scaling
        dout = self.criterion.backward() * self.accumulation_scale
        self.model.backward(dout)

        return scaled_loss

    def step(self):
        if (self.current_step + 1) % self.accumulation_steps == 0:
            if self.config.max_grad_norm > 0:
                self._clip_gradients()

            self.optimizer.step()
            self.optimizer.zero_grad()
            # Get average loss over accumulation steps
            avg_loss = self.loss_accumulator / self.config.accumulation_steps
            self.loss_accumulator = 0.0
            return avg_loss

    def _clip_gradients(self):
        """Clip gradients and collect statistics"""
        # Calculate gradient norm
        grad_norm = clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

        # Update statistics
        self.max_gradient_seen = max(self.max_gradient_seen, grad_norm)  # type: ignore
        self.min_gradient_seen = min(self.min_gradient_seen, grad_norm)  # type: ignore
        self.gradient_norms.append(grad_norm)

        return grad_norm


def train(config: TrainingConfig):
    data_iter = get_data(batchsize=config.batch_size, debug=False)
    net = Net()
    criterion = CrossEntropyLoss()
    optimizer = get_optimizer("rmsprop", net.parameters(), lr=config.lr)
    logger = Logger()
    accumulator = GradientAccumulator(net, optimizer, criterion, config)

    start_time = time.time()
    global_step = 0

    for epoch in range(config.epochs):
        loss_meter = AverageMeter()
        acc_meter = AverageMeter()
        batch_time = AverageMeter()

        for batch_idx, (x, y) in enumerate(data_iter("train")):
            batch_start = time.time()
            with accumulator.accumulation_context() as is_last_step:
                out = net(x)
                loss = accumulator.backward(out, y)

                if is_last_step:
                    accumulator.step()

            loss_meter.update(loss.item() * config.accumulation_steps)  # undo the 1/N loss scaling for logging
            acc_meter.update(accuracy(out, y))
            batch_time.update(time.time() - batch_start)

            global_step += 1

            if (batch_idx + 1) % config.log_interval == 0:
                logger.log(
                    {
                        "train_loss": loss_meter.avg,
                        "train_acc": acc_meter.avg,
                        "batch_time": batch_time.avg,
                    },
                    global_step,
                )

        # Validation
        if (epoch + 1) % config.val_interval == 0:
            val_loss_meter = AverageMeter()
            val_acc_meter = AverageMeter()
            val_batch_time = AverageMeter()

            with torch.no_grad():
                for x, y in data_iter("test"):
                    batch_start = time.time()
                    out = net(x)
                    loss = criterion(out, y)

                    val_loss_meter.update(loss.item())
                    val_acc_meter.update(accuracy(out, y))
                    val_batch_time.update(time.time() - batch_start)

            logger.log(
                {
                    "val_loss": val_loss_meter.avg,
                    "val_acc": val_acc_meter.avg,
                    "val_batch_time": val_batch_time.avg,
                },
                global_step,
            )

    end_time = time.time()
    print(f"Training completed in {end_time - start_time:.2f} seconds")


def test_gradient_accumulation():
    # Create small test data
    batch_size = 4
    input_dim = 784
    num_classes = 10

    # Create model and accumulator
    net = Net()
    criterion = CrossEntropyLoss()
    optimizer = get_optimizer("rmsprop", net.parameters(), lr=0.01)
    config = TrainingConfig(accumulation_steps=4)
    accumulator = GradientAccumulator(net, optimizer, criterion, config)

    # Run forward/backward with accumulation
    for i in range(8):  # 2 complete accumulation cycles
        x = torch.randn(batch_size, input_dim)
        y = torch.randint(0, num_classes, (batch_size,))

        with accumulator.accumulation_context() as is_last_step:
            out = net(x)
            loss = accumulator.backward(out, y)

            if is_last_step:
                # Check gradients before and after clipping
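                # Materialize once: if parameters() returns a generator, the second
                # sum below would otherwise iterate over nothing and report 0
                params = list(net.parameters())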
                grad_norm_before = sum(p.grad.norm().item() ** 2 for p in params) ** 0.5  # type: ignore
                accumulator._clip_gradients()
                grad_norm_after = sum(p.grad.norm().item() ** 2 for p in params) ** 0.5  # type: ignore

                print(f"Step {i+1}:")
                print(f"  Grad norm before clipping: {grad_norm_before:.4f}")
                print(f"  Grad norm after clipping: {grad_norm_after:.4f}")

                accumulator.step()

    return net


test_gradient_accumulation()

train(TrainingConfig(accumulation_steps=4))
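The core idea behind the `GradientAccumulator` can be checked without the `Net` above. Scaling each micro-batch loss by 1/k and summing the resulting gradients over k backward calls should reproduce the full-batch mean gradient. That holds as long as `accumulation_scale` is `1 / accumulation_steps`, which is an assumption, since its definition isn't shown here. The following is a minimal pure-Python sketch of that check on a toy linear model, followed by global-norm clipping. The helpers `mse_grads`, `accumulated_grads`, and `clip_by_global_norm` are hypothetical, written only for illustration. They are not the post's `clip_grad_norm_` or any library API.

```python
import math
from typing import List, Tuple


def mse_grads(w: float, b: float, xs: List[float], ys: List[float]) -> Tuple[float, float]:
    """Gradients of mean((w*x + b - y)**2) with respect to w and b."""
    n = len(xs)
    dw = sum(2.0 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    db = sum(2.0 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return dw, db


def accumulated_grads(w: float, b: float, xs: List[float], ys: List[float],
                      accumulation_steps: int) -> Tuple[float, float]:
    """Sum micro-batch gradients, each scaled by 1 / accumulation_steps.

    Assumes len(xs) is divisible by accumulation_steps so every micro-batch
    has the same size; otherwise this is not the full-batch mean gradient.
    """
    if len(xs) % accumulation_steps != 0:
        raise ValueError("batch size must be divisible by accumulation_steps")
    micro = len(xs) // accumulation_steps
    scale = 1.0 / accumulation_steps  # hypothetical stand-in for accumulation_scale
    gw, gb = 0.0, 0.0
    for i in range(accumulation_steps):
        sl = slice(i * micro, (i + 1) * micro)
        dw, db = mse_grads(w, b, xs[sl], ys[sl])
        gw += scale * dw  # gradients add up across backward calls
        gb += scale * db
    return gw, gb


def clip_by_global_norm(grads: List[float], max_norm: float,
                        eps: float = 1e-6) -> Tuple[List[float], float]:
    """Rescale all gradients together if their joint L2 norm exceeds max_norm.

    Returns the (possibly) clipped gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = max_norm / (total_norm + eps)
    if coef < 1.0:
        grads = [g * coef for g in grads]  # same factor for every gradient keeps the direction
    return grads, total_norm


# Toy data from y = 2x + 1, gradients evaluated at w = b = 0
xs = [0.5, -1.0, 2.0, 1.5, -0.5, 3.0, 0.0, 1.0]
ys = [2.0 * x + 1.0 for x in xs]
full = mse_grads(0.0, 0.0, xs, ys)
acc = accumulated_grads(0.0, 0.0, xs, ys, accumulation_steps=4)
clipped, norm_before = clip_by_global_norm(list(acc), max_norm=1.0)

print(full)  # (-10.5, -5.25)
print(acc)   # (-10.5, -5.25)
print(norm_before, math.hypot(*clipped))
```

The equivalence only holds when micro-batches are equal-sized and the loss is a per-batch mean. Layers that rely on batch statistics, such as BatchNorm, also break it, because each micro-batch sees different statistics. Returning the pre-clipping norm mirrors PyTorch's `torch.nn.utils.clip_grad_norm_`. If the custom `clip_grad_norm_` used above follows the same convention, that is what makes `max_gradient_seen` and `min_gradient_seen` meaningful.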

Why Revisit Backpropagation?

Understanding backprop deeply allows us to:

By building our own tools and revisiting foundational concepts, we bridge the gap between theory and practice.


Conclusion

“Back to Backprop” isn’t just a nod to the algorithm’s history—it’s a call to master the fundamentals that underpin our most advanced systems. Whether you’re troubleshooting a vanishing gradient, designing a new architecture, or simply exploring the beauty of gradient-based learning, backpropagation remains a cornerstone of machine learning.

Happy coding, and here’s to gradients that flow smoothly!

