
Softmax to the Max


Softmax, LogSoftmax and LogSumExp

These functions commonly arise in the context of neural networks, especially in the output layer. They appear when dealing with categorical and multinomial probability distributions, when converting raw logits into a probability distribution over classes, and in the attention mechanisms of transformers. All three are built from sums of exponentials and have nice relationships with each other.

Definitions

Given a vector $\mathbf{x} \in \mathbb{R}^n$, we define the following functions:

$$\mathbf{s} = \text{softmax}(\mathbf{x}) := \frac{\exp(\mathbf{x})}{\sum_{i} \exp(x_i)}$$

$$\frac{\partial s_i}{\partial x_j} = s_i(\delta_{ij} - s_j) \quad \text{or} \quad \mathbf{J}_{\mathbf{s}} = \text{diag}(\mathbf{s}) - \mathbf{s} \mathbf{s}^T$$

$$\ell = \text{logsumexp}(\mathbf{x}) := \log \left( \sum_{i} \exp(x_i) \right)$$

$$\frac{\partial \ell}{\partial x_j} = s_j \quad \text{or} \quad \nabla_{\mathbf{x}}\,\ell = \text{softmax}(\mathbf{x})$$

$$\mathbf{\tilde{s}} = \text{logsoftmax}(\mathbf{x}) := \mathbf{x} - \text{logsumexp}(\mathbf{x})$$

$$\frac{\partial \tilde{s}_i}{\partial x_j} = \delta_{ij} - s_j \quad \text{or} \quad \mathbf{J}_{\mathbf{\tilde{s}}} = \mathbf{I} - \mathbf{1}\cdot \text{softmax}(\mathbf{x})^T$$

So the softmax and logsoftmax both take in a vector and return a vector, while logsumexp is a scalar-valued function. The logsoftmax is just the log of the softmax, which also happens to be the input minus its logsumexp.

How logsumexp connects softmax and logsoftmax

By the definition of logsoftmax, we have

$$
\begin{align*}
\text{logsoftmax}(\mathbf{x}) &= \log(\text{softmax}(\mathbf{x})) \\
&= \log \frac{\exp(\mathbf{x})}{\sum_{i} \exp(x_i)} \\
&= \log \exp(\mathbf{x}) - \log \sum_{i} \exp(x_i) && \text{(by log rules)} \\
&= \mathbf{x} - \log \sum_{i} \exp(x_i) && \text{(log is the inverse of exp)} \\
&= \mathbf{x} - \text{logsumexp}(\mathbf{x}) && \text{(by definition)}
\end{align*}
$$
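To make the identity concrete, here is a quick numerical check using PyTorch's built-in functions (a minimal sketch; the tensor values are arbitrary):

import torch

x = torch.randn(5)

# logsoftmax equals the log of the softmax ...
print(torch.allclose(torch.log(torch.softmax(x, dim=0)), torch.log_softmax(x, dim=0)))  # True

# ... and also equals x - logsumexp(x)
print(torch.allclose(torch.log_softmax(x, dim=0), x - torch.logsumexp(x, dim=0)))       # True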

Numerical Stability

In practice, a naive softmax can be numerically unstable when the input values have large magnitudes: very negative inputs underflow to zero and very positive inputs overflow to infinity. To mitigate this, we can use a nice property of the logsumexp function: we can shift the values in the exponent by an arbitrary constant without changing the result. That is,

$$\text{logsumexp}(\mathbf{x}) = c + \text{logsumexp}(\mathbf{x} - c)$$

Consider the following:

Let $m = \max(\mathbf{x})$, then we can factor out $\exp(m)$ from the sum:

$$\text{logsumexp}(\mathbf{x}) = \log \left( \sum_{i} \exp(x_i) \right) = \log \left( \exp(m) \sum_{i} \exp(x_i - m) \right)$$

because $\exp(x_i) = \exp(x_i - m + m) = \exp(x_i - m) \exp(m)$. Then, using the log rule $\log(ab) = \log(a) + \log(b)$, we get:

$$
\begin{align*}
\text{logsumexp}(\mathbf{x}) &= \log \left( \exp(m) \sum_{i} \exp(x_i - m) \right) \\
&= \log \exp(m) + \log \left( \sum_{i} \exp(x_i - m) \right) \\
&= m + \log \left( \sum_{i} \exp(x_i - m) \right)
\end{align*}
$$

If we choose $c = \max(\mathbf{x})$, then the largest value in the exponent is zero, so every $\exp(x_i - c)$ lies in $(0, 1]$ and cannot overflow. This is the trick used in the implementations of the softmax and logsumexp functions in PyTorch and TensorFlow.
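Here is a minimal sketch of the max-shift trick written with plain tensor ops (the helper names are ours; the built-ins already apply the same shift internally):

import torch

def stable_logsumexp(x: torch.Tensor) -> torch.Tensor:
    # Shift by the max so the largest exponent is exp(0) = 1
    m = x.max()
    return m + torch.log(torch.exp(x - m).sum())

def stable_softmax(x: torch.Tensor) -> torch.Tensor:
    # exp(x - logsumexp(x)) is exactly softmax(x), with no overflow
    return torch.exp(x - stable_logsumexp(x))

x = torch.tensor([1000.0, 1001.0, 1002.0])   # naive exp(x) would overflow to inf
print(stable_softmax(x))                     # ≈ [0.0900, 0.2447, 0.6652]
print(torch.allclose(stable_softmax(x), torch.softmax(x, dim=0)))  # True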

Gradients

The gradients of these functions similarly have nice relationships to each other, so let’s build them up piece by piece.

Gradient of LogSumExp

Recall, logsumexp is a scalar-valued function $\mathbb{R}^n \rightarrow \mathbb{R}$, so the gradient is a vector. Set $\ell = \text{logsumexp}(\mathbf{x})$ and $\mathbf{s} = \text{softmax}(\mathbf{x})$. Then we have

$$
\begin{align*}
\frac{\partial \ell}{\partial x_i} &= \frac{\partial}{\partial x_i} \log \left( \sum_{j} \exp(x_j) \right) \\
&= \frac{1}{\sum_{j} \exp(x_j)} \frac{\partial}{\partial x_i} \sum_{j} \exp(x_j) && \text{($\frac{d}{dx} \log(f(x)) = \frac{f'(x)}{f(x)}$ by the chain rule)} \\
&= \frac{\exp(x_i)}{\sum_{j} \exp(x_j)} && \text{(only nonzero term is $e^{x_i}$ when $i=j$)} \\
&= s_i \\
&= \text{softmax}(\mathbf{x})_i
\end{align*}
$$

So the gradient of logsumexp is the softmax function! This will be useful when we compute the gradient of the logsoftmax function.
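We can confirm this with autograd (a quick sketch: torch.autograd.grad differentiates the scalar logsumexp with respect to its input):

import torch

x = torch.randn(6, requires_grad=True)
lse = torch.logsumexp(x, dim=0)
(grad,) = torch.autograd.grad(lse, x)

# The gradient of logsumexp is the softmax of the input
print(torch.allclose(grad, torch.softmax(x, dim=0)))  # True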

Gradient of LogSoftmax

The logsoftmax function is a vector-valued function $\mathbb{R}^n \rightarrow \mathbb{R}^n$, so the gradient is a matrix. Set $\mathbf{\tilde{s}} = \text{logsoftmax}(\mathbf{x})$ and $\mathbf{s} = \text{softmax}(\mathbf{x})$. Then we have

$$
\begin{align*}
\frac{\partial \tilde{s}_i}{\partial x_j} &= \frac{\partial}{\partial x_j} \left( x_i - \log \sum_{k} \exp(x_k) \right) \\
&= \delta_{ij} - \frac{\partial}{\partial x_j} \log \sum_{k} \exp(x_k) && \text{(only the $x_i$ term depends on $x_j$)} \\
&= \delta_{ij} - s_j && \text{(by the gradient of logsumexp)} \\
&= \delta_{ij} - \text{softmax}(\mathbf{x})_j
\end{align*}
$$

In matrix form, this is

$$
\begin{align*}
\frac{\partial \mathbf{\tilde{s}}}{\partial \mathbf{x}} &= \mathbf{I} - \mathbf{1} \cdot \text{softmax}(\mathbf{x})^T = \begin{pmatrix} 1 - s_1 & -s_2 & -s_3 \\ -s_1 & 1 - s_2 & -s_3 \\ -s_1 & -s_2 & 1 - s_3 \end{pmatrix}.
\end{align*}
$$

where $\mathbf{1}$ is the vector of $n$ ones, which broadcasts the softmax vector to a matrix so that the $j$-th column contains the $j$-th element of the softmax vector (vstacks the softmax vector).
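As a sanity check, we can compare this closed form against the Jacobian computed by autograd (a sketch using torch.autograd.functional.jacobian; the vector size is arbitrary):

import torch
from torch.autograd.functional import jacobian

x = torch.randn(4)
s = torch.softmax(x, dim=0)

# Closed form: I - 1 · softmax(x)^T  (every row subtracts the softmax vector)
closed_form = torch.eye(4) - torch.ones(4, 1) * s.unsqueeze(0)

auto = jacobian(lambda v: torch.log_softmax(v, dim=0), x)
print(torch.allclose(closed_form, auto, atol=1e-6))  # True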

Gradient of Softmax

The softmax function is a vector-valued function $\mathbb{R}^n \rightarrow \mathbb{R}^n$, so the gradient is a matrix. Set $\mathbf{s} = \text{softmax}(\mathbf{x})$. There are several ways to compute the gradient of the softmax function:

Using the logsumexp-based definition $s_i = \exp(x_i - \ell)$:

We know that $\frac{\partial \ell}{\partial x_j} = s_j$. We can use this to compute the gradient of the softmax function:

$$
\begin{align*}
\frac{\partial s_i}{\partial x_j} &= \frac{\partial}{\partial x_j} \exp(x_i - \ell) \\
&= \exp(x_i - \ell) \cdot \frac{\partial}{\partial x_j} (x_i - \ell) \\
&= s_i \left(\frac{\partial x_i}{\partial x_j} - \frac{\partial \ell}{\partial x_j}\right) \\
&= s_i (\delta_{ij} - s_j)
\end{align*}
$$

In matrix form, this gives:

$$
\begin{align*}
\frac{\partial \mathbf{s}}{\partial \mathbf{x}} &= \text{diag}(\mathbf{s}) - \mathbf{s} \mathbf{s}^T = \begin{pmatrix} s_1(1 - s_1) & -s_1 s_2 & -s_1 s_3 \\ -s_2 s_1 & s_2(1 - s_2) & -s_2 s_3 \\ -s_3 s_1 & -s_3 s_2 & s_3(1 - s_3) \end{pmatrix}.
\end{align*}
$$

where $\text{diag}(\mathbf{s})$ is the diagonal matrix with the softmax vector on the diagonal.
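The same autograd check works here (a short sketch; the formula $\text{diag}(\mathbf{s}) - \mathbf{s}\mathbf{s}^T$ should match the Jacobian of torch.softmax):

import torch
from torch.autograd.functional import jacobian

x = torch.randn(4)
s = torch.softmax(x, dim=0)

# Closed form: diag(s) - s s^T
closed_form = torch.diag(s) - torch.outer(s, s)

auto = jacobian(lambda v: torch.softmax(v, dim=0), x)
print(torch.allclose(closed_form, auto, atol=1e-6))  # True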

Softmax and Backpropagation

Backpropagation is fundamentally about repeated application of the chain rule. For an in-depth look into backprop, see the Back to Backprop post.

Let $\mathbf{s} = \text{softmax}(\mathbf{x}) \in \mathbb{R}^D$ and $\nabla_{\mathbf{s}} L = \bigl[\frac{\partial L}{\partial s_1}, \dots, \frac{\partial L}{\partial s_D}\bigr] \in \mathbb{R}^D$.

We want $\nabla_{\mathbf{x}} L = \bigl[\frac{\partial L}{\partial x_1}, \dots, \frac{\partial L}{\partial x_D}\bigr]$. By the chain rule,

$$\frac{\partial L}{\partial x_i} = \sum_{k=1}^D \left( \frac{\partial L}{\partial s_k} \frac{\partial s_k}{\partial x_i} \right) = \sum_{k=1}^D \left( \frac{\partial L}{\partial s_k} s_k (\delta_{ik} - s_i) \right) = \sum_{k=1}^D s_k \left( \frac{\partial L}{\partial s_k} \delta_{ik} - \frac{\partial L}{\partial s_k} s_i \right)$$

Let’s separate this into two terms to make it clearer:

$$\frac{\partial L}{\partial x_i} = \sum_{k=1}^D s_k \frac{\partial L}{\partial s_k} \delta_{ik} - \sum_{k=1}^D s_k \frac{\partial L}{\partial s_k} s_i$$

Now, let’s analyze each term:

In the first term, because of the Kronecker delta $\delta_{ik}$, the sum collapses to just the $i$-th term:

$$\sum_{k=1}^D s_k \frac{\partial L}{\partial s_k} \delta_{ik} = s_i \frac{\partial L}{\partial s_i}$$

In the second term, notice that $s_i$ can come out of the sum since it doesn't depend on $k$:

$$\sum_{k=1}^D s_k \frac{\partial L}{\partial s_k} s_i = s_i \sum_{k=1}^D s_k \frac{\partial L}{\partial s_k}$$

Putting it back together:

$$\frac{\partial L}{\partial x_i} = s_i \frac{\partial L}{\partial s_i} - s_i \sum_{k=1}^D s_k \frac{\partial L}{\partial s_k}$$

This can be factored as:

$$\frac{\partial L}{\partial x_i} = s_i \left(\frac{\partial L}{\partial s_i} - \sum_{k=1}^D s_k \frac{\partial L}{\partial s_k}\right)$$

Now, here's the key insight: this equation gives us the $i$-th component of the gradient, so we can write the whole thing in vector form. The gradient is:

$$\nabla_{\mathbf{x}} L = \mathbf{s} \odot (\nabla_{\mathbf{s}} L - (\mathbf{s} \cdot \nabla_{\mathbf{s}} L)\mathbf{1})$$

where $\odot$ denotes elementwise multiplication, $\mathbf{s} \cdot \nabla_{\mathbf{s}} L$ is a dot product (a scalar), and $\mathbf{1}$ is the vector of $D$ ones.

Let’s implement this in PyTorch:


import torch

class SoftmaxLayer:
    def forward(self, x):
        # Save input for backward pass
        self.input = x

        # For numerical stability, subtract max value before exponential
        # Shape: (batch_size, num_classes)
        shifted_x = x - torch.max(x, dim=1, keepdim=True)[0]

        # Compute exponentials
        # Shape: (batch_size, num_classes)
        exp_x = torch.exp(shifted_x)

        # Compute sum of exponentials for normalization
        # Shape: (batch_size, 1)
        sum_exp = torch.sum(exp_x, dim=1, keepdim=True)

        # Compute softmax output
        # Shape: (batch_size, num_classes)
        self.output = exp_x / sum_exp

        return self.output

    def backward(self, grad_output):
        # grad_output shape: (batch_size, num_classes)
        # output shape: (batch_size, num_classes)

        # The Jacobian of softmax for each sample is:
        # J[i,j] = output[i] * (δ[i,j] - output[j])
        # where δ[i,j] is 1 if i=j and 0 otherwise

        # Compute using the compact form:
        # grad = output * (grad_output - (output · grad_output))

        # Shape: (batch_size, 1)
        sum_grad = torch.sum(grad_output * self.output, dim=1, keepdim=True) # s · ∇_s L

        # Shape: (batch_size, num_classes)
        return self.output * (grad_output - sum_grad) # s ⊙ (∇_s L - (s · ∇_s L))

    def __call__(self, x):
        return self.forward(x)
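A quick way to sanity-check this layer (a sketch assuming the SoftmaxLayer class above; the upstream gradient is just random values standing in for a real loss) is to compare its backward pass against PyTorch's autograd:

import torch

layer = SoftmaxLayer()

x = torch.randn(2, 5, requires_grad=True)
out = layer(x)

# Pretend some upstream loss produced this gradient w.r.t. the softmax output
grad_output = torch.randn(2, 5)

manual_grad = layer.backward(grad_output)

# Reference: let autograd chain the same upstream gradient through torch.softmax
auto_grad, = torch.autograd.grad(torch.softmax(x, dim=1), x, grad_outputs=grad_output)

print(torch.allclose(manual_grad, auto_grad, atol=1e-6))  # True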

Chain Rule with Jacobians (shape analysis)

Now, for the chain rule with Jacobians, we have:

$$\frac{\partial L}{\partial \mathbf{x}} = \frac{\partial L}{\partial \mathbf{s}} \frac{\partial \mathbf{s}}{\partial \mathbf{x}}$$

The dimensions help us determine the multiplication order: $\frac{\partial L}{\partial \mathbf{s}}$ is a $1 \times D$ row vector and $\frac{\partial \mathbf{s}}{\partial \mathbf{x}}$ is a $D \times D$ Jacobian, so the product is the $1 \times D$ row vector $\frac{\partial L}{\partial \mathbf{x}}$.

Now, let’s connect this to our earlier vectorized form. If we multiply it out:

$$
\begin{align*}
\frac{\partial L}{\partial \mathbf{x}} &= \frac{\partial L}{\partial \mathbf{s}}(\text{diag}(\mathbf{s}) - \mathbf{s}\mathbf{s}^T) \\
&= \left(\frac{\partial L}{\partial \mathbf{s}}\text{diag}(\mathbf{s})\right) - \left(\frac{\partial L}{\partial \mathbf{s}}\mathbf{s}\mathbf{s}^T\right)
\end{align*}
$$

The first term becomes element-wise multiplication of $\frac{\partial L}{\partial \mathbf{s}}$ and $\mathbf{s}$. The second term becomes $(\frac{\partial L}{\partial \mathbf{s}} \cdot \mathbf{s})\mathbf{s}^T$. When we transpose everything (since we typically work with column vectors in practice), this gives us exactly our previous formula:

$$\frac{\partial L}{\partial \mathbf{x}} = \mathbf{s} \odot \left(\frac{\partial L}{\partial \mathbf{s}} - \left(\mathbf{s} \cdot \frac{\partial L}{\partial \mathbf{s}}\right)\mathbf{1}\right)$$
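To see that the row-vector-times-Jacobian product and the elementwise form agree, here is a small numerical comparison (a sketch with arbitrary values):

import torch

D = 5
s = torch.softmax(torch.randn(D), dim=0)
grad_s = torch.randn(D)  # stands in for ∂L/∂s, treated as a row vector

# Full Jacobian product: (1 x D) row vector times (D x D) Jacobian
jac = torch.diag(s) - torch.outer(s, s)
full = grad_s @ jac

# Vectorized form: s ⊙ (∂L/∂s - (s · ∂L/∂s) 1)
vectorized = s * (grad_s - torch.dot(s, grad_s))

print(torch.allclose(full, vectorized, atol=1e-6))  # True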

Negative Log Likelihood Loss

First, let’s recall that the negative log likelihood loss for a single example is:

$$L = -\log(s_{y})$$

where $y$ is the target class index and $s_y$ is the softmax probability for that class.

For a single example, the gradient of NLL with respect to its input (the softmax outputs) is:

$$\frac{\partial L}{\partial s_i} = -\frac{1}{s_i}\delta_{iy}$$

We can write this as a row vector:

$$\nabla_{\mathbf{s}}L = \left[-\frac{1}{s_1}\delta_{1y}, -\frac{1}{s_2}\delta_{2y}, \dots, -\frac{1}{s_n}\delta_{ny}\right]$$

Now, let’s chain this with our softmax Jacobian. Remember, we want:

$$\frac{\partial L}{\partial \mathbf{x}} = \nabla_{\mathbf{s}}L \cdot \frac{\partial \mathbf{s}}{\partial \mathbf{x}}$$

Using our softmax Jacobian:

$$\frac{\partial \mathbf{s}}{\partial \mathbf{x}} = \text{diag}(\mathbf{s}) - \mathbf{s}\mathbf{s}^T$$

When we multiply these out, something remarkable happens. Let’s write it component by component:

$$\frac{\partial L}{\partial x_i} = \sum_k \frac{\partial L}{\partial s_k} \frac{\partial s_k}{\partial x_i}$$

Substituting our expressions:

$$\frac{\partial L}{\partial x_i} = \sum_k \left(-\frac{1}{s_k}\delta_{ky}\right) \left(s_k(\delta_{ki} - s_i)\right)$$

Most terms in this sum are zero because of the $\delta_{ky}$ term. The only non-zero term is when $k = y$:

$$\frac{\partial L}{\partial x_i} = -\frac{1}{s_y}s_y(\delta_{yi} - s_i) = -(\delta_{yi} - s_i)$$

Therefore:

$$\frac{\partial L}{\partial x_i} = s_i - \delta_{yi}$$

In vector form, this gives us:

$$\frac{\partial L}{\partial \mathbf{x}} = \mathbf{s} - \mathbf{y}$$

where $\mathbf{y}$ is the one-hot encoded target vector. This is why we get the simple $\mathbf{s} - \mathbf{y}$ gradient! The division by $s_y$ in the NLL gradient exactly cancels with the multiplication by $s_y$ in the softmax Jacobian.
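We can confirm the $\mathbf{s} - \mathbf{y}$ result by letting autograd differentiate the NLL of the softmax directly (a sketch for a single example; the class count and target index are arbitrary):

import torch

n, y = 5, 2                        # number of classes, target class index
x = torch.randn(n, requires_grad=True)

s = torch.softmax(x, dim=0)
loss = -torch.log(s[y])            # NLL of the softmax probability for class y
loss.backward()

one_hot = torch.zeros(n)
one_hot[y] = 1.0

print(torch.allclose(x.grad, s.detach() - one_hot, atol=1e-6))  # True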

LogSoftmax and NLL

If we had instead used logsoftmax followed by NLL, $L = -\tilde{s}_y$, we would have arrived at the same result through a different path: the gradient with respect to $\mathbf{\tilde{s}}$ is just the negative one-hot vector, and chaining it with the logsoftmax Jacobian $\mathbf{I} - \mathbf{1}\cdot\mathbf{s}^T$ picks out the $y$-th row, giving $\frac{\partial L}{\partial x_i} = -(\delta_{yi} - s_i) = s_i - \delta_{yi}$ once again.

This elegant result explains why the combined softmax plus cross-entropy loss is numerically stable and computationally efficient: all the complicated terms in the chain rule cancel out to give us this simple gradient!

