These functions commonly arise in the context of neural networks, especially in the output layer.
They occur when dealing with categorical and multinomial probability distributions, converting raw logits into a probability distribution over classes, as well as in the attention mechanisms of transformers.
These functions are all closely related, involving sums of exponentials, and have nice relationships with each other.
Definitions
Given a vector $x \in \mathbb{R}^n$, we define the following functions:
$$s = \mathrm{softmax}(x) := \frac{\exp(x)}{\sum_i \exp(x_i)}$$
$$\frac{\partial s_i}{\partial x_j} = s_i(\delta_{ij} - s_j) \quad\text{or}\quad J_s = \mathrm{diag}(s) - ss^T$$
$$\ell = \mathrm{logsumexp}(x) := \log\left(\sum_i \exp(x_i)\right)$$
$$\frac{\partial \ell}{\partial x_j} = s_j \quad\text{or}\quad \nabla\ell = \mathrm{softmax}(x)$$
$$\tilde{s} = \mathrm{logsoftmax}(x) := x - \mathrm{logsumexp}(x)$$
$$\frac{\partial \tilde{s}_i}{\partial x_j} = \delta_{ij} - s_j \quad\text{or}\quad J_{\tilde{s}} = I - \mathbf{1}\,\mathrm{softmax}(x)^T$$
So we see that the softmax and logsoftmax are both functions that take in a vector and return a vector, while the logsumexp is a scalar valued function. The logsoftmax is just the log of the softmax, which also happens to be the input minus its logsumexp.
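As a quick numerical sanity check of these relationships, here is a minimal sketch in PyTorch (the input values are arbitrary; `torch.softmax`, `torch.logsumexp`, and `torch.log_softmax` are the built-in functions):

```python
import torch

x = torch.randn(5)  # arbitrary logits

s = torch.softmax(x, dim=0)          # softmax(x): a probability vector
lse = torch.logsumexp(x, dim=0)      # logsumexp(x): a scalar
log_s = torch.log_softmax(x, dim=0)  # logsoftmax(x): a vector

# logsoftmax is the log of the softmax ...
assert torch.allclose(log_s, torch.log(s))
# ... which is also the input minus its logsumexp
assert torch.allclose(log_s, x - lse)
# and the softmax sums to 1
assert torch.isclose(s.sum(), torch.tensor(1.0))
```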
How logsumexp connects softmax and logsoftmax
By the definition of logsoftmax, we have
$$
\begin{aligned}
\mathrm{logsoftmax}(x) &= \log(\mathrm{softmax}(x)) \\
&= \log\frac{\exp(x)}{\sum_i \exp(x_i)} \\
&= \log\exp(x) - \log\sum_i \exp(x_i) && \text{(by log rules)} \\
&= x - \log\sum_i \exp(x_i) && \text{(log is the inverse of exp)} \\
&= x - \mathrm{logsumexp}(x) && \text{(by definition)}
\end{aligned}
$$
Numerical Stability
In practice, the softmax function can be numerically unstable, especially when the input values have large magnitudes, where very negative numbers underflow to zero and very positive numbers can overflow to infinity. To mitigate this, we can use a nice property of the logsumexp function, which allows us to shift the values in the exponent by an arbitrary constant without changing the result. That is,
$$
\begin{aligned}
y &= \log\left(\sum_i \exp(x_i)\right) \\
e^y &= \sum_i \exp(x_i) && \text{(exponentiate both sides)} \\
e^y &= e^c \sum_i \exp(x_i - c) && \left(e^x = e^{x-c} \cdot e^c = e^{x-c+c}\right) \\
y &= c + \log\left(\sum_i \exp(x_i - c)\right) && \text{(take the log of both sides)}
\end{aligned}
$$
If we choose c=max(x), then the largest value in the exponent is zero, which is numerically stable. This is the trick used in the implementation of the softmax and logsumexp functions in PyTorch and TensorFlow.
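Here is a minimal sketch of that shift (the helper `stable_logsumexp` is ours, written just for illustration; `torch.logsumexp` already does this internally):

```python
import torch

def stable_logsumexp(x: torch.Tensor) -> torch.Tensor:
    # Shift by c = max(x) so the largest exponent is exp(0) = 1
    c = x.max()
    return c + torch.log(torch.exp(x - c).sum())

x = torch.tensor([1000.0, 1001.0, 1002.0])  # exp(1000) overflows float32
naive = torch.log(torch.exp(x).sum())       # inf
stable = stable_logsumexp(x)                # ~1002.41
print(naive, stable, torch.logsumexp(x, dim=0))  # inf vs. matching finite values
```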
Gradients
The gradients of these functions similarly have nice relationships to each other, so let’s build them up piece by piece.
Gradient of LogSumExp
Recall, logsumexp is a scalar-valued function $\mathbb{R}^n \to \mathbb{R}$, so the gradient is a vector. Set $\ell = \mathrm{logsumexp}(x)$ and $s = \mathrm{softmax}(x)$. Then we have
$$
\begin{aligned}
\frac{\partial \ell}{\partial x_i} &= \frac{\partial}{\partial x_i} \log\left(\sum_j \exp(x_j)\right) \\
&= \frac{1}{\sum_j \exp(x_j)} \frac{\partial}{\partial x_i} \sum_j \exp(x_j) && \left(\tfrac{d}{dx}\log(f(x)) = \tfrac{f'(x)}{f(x)} \text{ by chain rule}\right) \\
&= \frac{\exp(x_i)}{\sum_j \exp(x_j)} && \text{(only nonzero term is } e^{x_i}\text{, when } j = i\text{)} \\
&= s_i = \mathrm{softmax}(x)_i
\end{aligned}
$$
So the gradient of logsumexp is the softmax function! This will be useful when we compute the gradient of the logsoftmax function.
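We can confirm this numerically with autograd (a quick sketch; the input is arbitrary):

```python
import torch

x = torch.randn(4, requires_grad=True)
torch.logsumexp(x, dim=0).backward()

# The gradient of logsumexp is the softmax of the input
assert torch.allclose(x.grad, torch.softmax(x.detach(), dim=0))
```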
Gradient of LogSoftmax
The logsoftmax function is a vector-valued function $\mathbb{R}^n \to \mathbb{R}^n$, so the gradient is a matrix. Set $\tilde{s} = \mathrm{logsoftmax}(x)$ and $s = \mathrm{softmax}(x)$. Then we have
$$
\begin{aligned}
\frac{\partial \tilde{s}_i}{\partial x_j} &= \frac{\partial}{\partial x_j}\left(x_i - \log\sum_k \exp(x_k)\right) \\
&= \delta_{ij} - \frac{\partial}{\partial x_j}\log\sum_k \exp(x_k) && \text{(only the } x_i \text{ term depends on } x_j\text{)} \\
&= \delta_{ij} - s_j && \text{(by the gradient of logsumexp)} \\
&= \delta_{ij} - \mathrm{softmax}(x)_j
\end{aligned}
$$
In matrix form, $J_{\tilde{s}} = I - \mathbf{1}\,\mathrm{softmax}(x)^T$, where $\mathbf{1}$ is the vector of $n$ ones. The outer product $\mathbf{1}\,\mathrm{softmax}(x)^T$ broadcasts the softmax vector into a matrix whose $j$-th column contains the $j$-th element of the softmax vector (it vstacks the softmax vector).
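We can check this Jacobian numerically with `torch.autograd.functional.jacobian` (a sketch; the input is arbitrary):

```python
import torch
from torch.autograd.functional import jacobian

x = torch.randn(4)
s = torch.softmax(x, dim=0)

# J[i, j] = ∂ logsoftmax(x)_i / ∂ x_j
J = jacobian(lambda v: torch.log_softmax(v, dim=0), x)

# Expected: I - 1 · softmax(x)^T, i.e. every row subtracts the softmax vector
expected = torch.eye(4) - torch.ones(4, 1) * s.unsqueeze(0)
assert torch.allclose(J, expected, atol=1e-6)
```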
Gradient of Softmax
The softmax function is a vector-valued function $\mathbb{R}^n \to \mathbb{R}^n$, so the gradient is a matrix. Set $s = \mathrm{softmax}(x)$. There are several ways to compute the gradient of the softmax function:
One way goes through the logsoftmax: instead of $\frac{\partial s_i}{\partial x_j}$, we will compute $\frac{\partial \log(s_i)}{\partial x_j}$.
Notice that
$$\frac{\partial \log(s_i)}{\partial x_j} = \frac{1}{s_i} \cdot \frac{\partial s_i}{\partial x_j}$$
If we rearrange the terms, we get:
$$\frac{\partial s_i}{\partial x_j} = s_i \cdot \frac{\partial \log(s_i)}{\partial x_j}$$
which tells us that the gradient of the softmax is the softmax itself times the gradient of the log-softmax. We showed that the gradient of the log-softmax is
$$\frac{\partial \log(s_i)}{\partial x_j} = \delta_{ij} - \mathrm{softmax}(x)_j$$
So the gradient of the softmax is:
$$\frac{\partial s_i}{\partial x_j} = s_i(\delta_{ij} - s_j)$$
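In matrix form this is the Jacobian $J_s = \mathrm{diag}(s) - ss^T$ from the definitions, and we can check it the same way (again a sketch with an arbitrary input):

```python
import torch
from torch.autograd.functional import jacobian

x = torch.randn(4)
s = torch.softmax(x, dim=0)

# J[i, j] = ∂ softmax(x)_i / ∂ x_j
J = jacobian(lambda v: torch.softmax(v, dim=0), x)

expected = torch.diag(s) - torch.outer(s, s)  # diag(s) - s s^T
assert torch.allclose(J, expected, atol=1e-6)
```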
Alternatively, we can derive the same result with the quotient rule. Recall that the quotient rule states that the derivative of a function of the form $h(x) = \frac{f(x)}{g(x)}$ is:
$$\frac{d}{dx}h(x) = \frac{f'(x)g(x) - f(x)g'(x)}{g(x)^2}$$
In the case of the softmax, we have:
$$s_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$$
so we can write the softmax as a function of the form $h(x) = \frac{f(x)}{g(x)}$ where $f_i = e^{x_i}$ and $g = \sum_j e^{x_j}$.
Now, we have $\frac{\partial f_i}{\partial x_j} = \delta_{ij} e^{x_i}$ and $\frac{\partial g}{\partial x_j} = e^{x_j}$, so we can apply the quotient rule. For convenience, let $\Sigma = \sum_j e^{x_j}$:

$$\frac{\partial s_i}{\partial x_j} = \frac{\delta_{ij} e^{x_i}\,\Sigma - e^{x_i} e^{x_j}}{\Sigma^2} = \frac{e^{x_i}}{\Sigma}\left(\delta_{ij} - \frac{e^{x_j}}{\Sigma}\right) = s_i(\delta_{ij} - s_j)$$

which is the same result as before.
In practice, we rarely form this full Jacobian. During backpropagation we are given the upstream gradient $\frac{\partial L}{\partial s}$ of some loss $L$ and want $\frac{\partial L}{\partial x}$. By the chain rule,

$$\frac{\partial L}{\partial x_i} = \sum_{k=1}^{D} \frac{\partial L}{\partial s_k}\frac{\partial s_k}{\partial x_i} = \sum_{k=1}^{D} s_k \frac{\partial L}{\partial s_k}(\delta_{ik} - s_i)$$

Let's separate this into two terms to make it clearer:
$$\frac{\partial L}{\partial x_i} = \sum_{k=1}^{D} s_k \frac{\partial L}{\partial s_k} \delta_{ik} - \sum_{k=1}^{D} s_k \frac{\partial L}{\partial s_k} s_i$$
Now, let’s analyze each term:
In the first term, because of the Kronecker delta δik, the sum collapses to just the i-th term:
$$\sum_{k=1}^{D} s_k \frac{\partial L}{\partial s_k} \delta_{ik} = s_i \frac{\partial L}{\partial s_i}$$
In the second term, notice that si can come out of the sum since it doesn’t depend on k:
$$\sum_{k=1}^{D} s_k \frac{\partial L}{\partial s_k} s_i = s_i \sum_{k=1}^{D} s_k \frac{\partial L}{\partial s_k}$$
Putting it back together:
$$\frac{\partial L}{\partial x_i} = s_i \frac{\partial L}{\partial s_i} - s_i \sum_{k=1}^{D} s_k \frac{\partial L}{\partial s_k}$$
This can be factored as:
$$\frac{\partial L}{\partial x_i} = s_i \left(\frac{\partial L}{\partial s_i} - \sum_{k=1}^{D} s_k \frac{\partial L}{\partial s_k}\right)$$
Now, here’s the key insight that leads to the vectorized form: this equation gives us the i-th component of the gradient, and we can write it in vector form! The gradient is:
$$\nabla_x L = s \odot \left(\nabla_s L - (s \cdot \nabla_s L)\,\mathbf{1}\right)$$
where:
⊙ represents element-wise multiplication
⋅ represents dot product
1 is a vector of ones
Let’s implement this in PyTorch:
```python
import torch

class SoftmaxLayer:
    def forward(self, x):
        # Save input for backward pass
        self.input = x
        # For numerical stability, subtract max value before exponential
        # Shape: (batch_size, num_classes)
        shifted_x = x - torch.max(x, dim=1, keepdim=True)[0]
        # Compute exponentials
        # Shape: (batch_size, num_classes)
        exp_x = torch.exp(shifted_x)
        # Compute sum of exponentials for normalization
        # Shape: (batch_size, 1)
        sum_exp = torch.sum(exp_x, dim=1, keepdim=True)
        # Compute softmax output
        # Shape: (batch_size, num_classes)
        self.output = exp_x / sum_exp
        return self.output

    def backward(self, grad_output):
        # grad_output shape: (batch_size, num_classes)
        # output shape: (batch_size, num_classes)
        # The Jacobian of softmax for each sample is:
        #   J[i,j] = output[i] * (δ[i,j] - output[j])
        # where δ[i,j] is 1 if i=j and 0 otherwise
        # Compute using the compact form:
        #   grad = output * (grad_output - (output · grad_output))
        # Shape: (batch_size, 1)
        sum_grad = torch.sum(grad_output * self.output, dim=1, keepdim=True)  # s · ∇_s L
        # Shape: (batch_size, num_classes)
        return self.output * (grad_output - sum_grad)  # s ⊙ (∇_s L - (s · ∇_s L))

    def __call__(self, x):
        return self.forward(x)
```
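A quick way to sanity-check both passes of this layer is to compare them against `torch.softmax` and autograd's vector-Jacobian product (a sketch; the batch size, number of classes, and upstream gradient are arbitrary):

```python
import torch

layer = SoftmaxLayer()

x = torch.randn(3, 5)         # (batch_size, num_classes)
upstream = torch.randn(3, 5)  # pretend gradient flowing in from the next layer

# Forward pass matches torch.softmax
out = layer(x)
assert torch.allclose(out, torch.softmax(x, dim=1))

# Backward pass matches autograd's vector-Jacobian product
x_ref = x.clone().requires_grad_(True)
torch.softmax(x_ref, dim=1).backward(upstream)
assert torch.allclose(layer.backward(upstream), x_ref.grad, atol=1e-6)
```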
Chain Rule with Jacobians (shape analysis)
Now, for the chain rule with Jacobians, we have:
$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial s}\frac{\partial s}{\partial x}$$
The dimensions help us determine the multiplication order:
$\frac{\partial L}{\partial s}$ is $1 \times D$ (a row vector)
$\frac{\partial s}{\partial x}$ is $D \times D$ (the Jacobian matrix)
$\frac{\partial L}{\partial x}$ should be $1 \times D$ (a row vector)
Now, let’s connect this to our earlier vectorized form. If we multiply it out:
The first term, $\frac{\partial L}{\partial s}\,\mathrm{diag}(s)$, becomes the element-wise product of $\frac{\partial L}{\partial s}$ and $s$. The second term, $\frac{\partial L}{\partial s}\,(s s^T)$, becomes $\left(\frac{\partial L}{\partial s} \cdot s\right) s^T$.
When we transpose everything (since we typically work with column vectors in practice), this gives us exactly our previous formula:
$$\frac{\partial L}{\partial x} = s \odot \left(\frac{\partial L}{\partial s} - \left(s \cdot \frac{\partial L}{\partial s}\right)\mathbf{1}\right)$$
Negative Log Likelihood Loss
First, let’s recall that the negative log likelihood loss for a single example is:
$$L = -\log(s_y)$$
where y is the target class index and sy is the softmax probability for that class.
For a single example, the gradient of NLL with respect to its input (the softmax outputs) is:
$$\frac{\partial L}{\partial s_i} = -\frac{1}{s_i}\,\delta_{iy}$$
We can write this as a row vector:
$$\nabla_s L = \left[-\frac{1}{s_1}\delta_{1y},\ -\frac{1}{s_2}\delta_{2y},\ \ldots,\ -\frac{1}{s_n}\delta_{ny}\right]$$
Now, let’s chain this with our softmax Jacobian. Remember, we want:
$$\frac{\partial L}{\partial x} = \nabla_s L \cdot \frac{\partial s}{\partial x}$$
Using our softmax Jacobian:
$$\frac{\partial s}{\partial x} = \mathrm{diag}(s) - s s^T$$
When we multiply these out, something remarkable happens. Let’s write it component by component:
$$\frac{\partial L}{\partial x_i} = \sum_k \frac{\partial L}{\partial s_k}\frac{\partial s_k}{\partial x_i}$$
Substituting our expressions:
$$\frac{\partial L}{\partial x_i} = \sum_k \left(-\frac{1}{s_k}\,\delta_{ky}\right)\bigl(s_k(\delta_{ki} - s_i)\bigr)$$
Most terms in this sum are zero because of the δky term. The only non-zero term is when k=y:
$$\frac{\partial L}{\partial x_i} = -\frac{1}{s_y}\,s_y(\delta_{yi} - s_i) = -(\delta_{yi} - s_i)$$
Therefore:
$$\frac{\partial L}{\partial x_i} = s_i - \delta_{yi}$$
In vector form, this gives us:
$$\frac{\partial L}{\partial x} = s - y$$
where y is the one-hot encoded target vector.
This is why we get the simple s−y gradient! The division by sy in the NLL gradient exactly cancels with the multiplication by sy in the softmax Jacobian.
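Here is a quick autograd check of this cancellation (a sketch; the logits and target index are arbitrary):

```python
import torch

x = torch.randn(5, requires_grad=True)  # logits
y = 2                                   # target class index

s = torch.softmax(x, dim=0)
loss = -torch.log(s[y])  # NLL applied to the softmax probability of the target class
loss.backward()

one_hot = torch.zeros(5)
one_hot[y] = 1.0
# The gradient with respect to the logits is softmax(x) - one_hot(y)
assert torch.allclose(x.grad, s.detach() - one_hot, atol=1e-6)
```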
LogSoftmax and NLL
If we had instead used log-softmax followed by NLL, we would have arrived at the same result through a different path:
The NLL gradient with respect to the log-softmax outputs is just $-y$
The log-softmax Jacobian is $I - \mathbf{1} s^T$
Multiplying these together also gives us $s - y$
This elegant result explains why cross-entropy loss is numerically stable and computationally efficient - all the complex terms in the chain rule cancel out to give us this simple gradient!
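As a final check, the log-softmax + NLL path (which is exactly what `F.cross_entropy` computes) gives the same gradient; a brief sketch with arbitrary values:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 5, requires_grad=True)  # logits for a single example
y = torch.tensor([2])                      # target class index

# log-softmax followed by NLL, the numerically stable path
F.nll_loss(F.log_softmax(x, dim=1), y).backward()

one_hot = F.one_hot(y, num_classes=5).float()
# The gradient with respect to the logits is again softmax(x) - one_hot(y)
assert torch.allclose(x.grad, torch.softmax(x.detach(), dim=1) - one_hot, atol=1e-6)
```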