My solutions to srush’s AutoDiff Puzzles. This is useful as a quick refresher for computing gradients.
Introduction
From the puzzle’s intro:
Deep learning libraries like Torch utilize autodifferentiation of tensors to compute the parameter updates necessary to learn complex models from data. This technique is central to understanding why deep learning has become so widely used and effective. The autodifferentiation process is a neat trick that builds up a computational graph and then uses that graph to provide derivatives for user-constructed mathematical functions. At heart, this process is just an instantiation of the chain-rule based on implementations of every function and its derivative.
However, a library still needs to have efficient implementations of derivatives for its key building blocks. This sounds trivial: just implement high school calculus. However, it is a bit trickier than it sounds. Tensor-to-tensor functions are pretty complex and require keeping careful track of indexing on both the input and the output side.
Your goal in these puzzles is to implement the Jacobian, i.e. the derivative of each output cell with respect to each input cell, for each function below. In each case the function takes in a tensor $x$ and returns a tensor $f(x)$, so your job is to compute $\frac{\partial f(x)_j}{\partial x_i}$ for all indices $i$ and $j$. If you get discouraged, just remember, you did this in high school (it just had way less indexing).
Rules and Tips
- Every answer is 1 line of 80-column code.
- Everything in these puzzles should be done with standard Python numbers. (There is no need for Torch or tensors.)
- Recall the basic multivariate calculus identities.
- Hint: Python booleans are numbers, so you can use them as indicator functions, e.g. (a == b) * 27 evaluates to 27 if a == b and to 0 otherwise (a quick demo follows this list).
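For example, in plain Python (nothing puzzle-specific, just the boolean-as-indicator trick):
print((3 == 3) * 27)  # 27, since True behaves as 1
print((3 == 4) * 27)  # 0, since False behaves as 0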
For each of the problems, a tensor function is provided. There are two solution formats:
- Given a lambda called f, your job is to fill in the function dx(i, j), which provides $\frac{\partial f(x)_j}{\partial x_i}$. These puzzles are about correctness, not efficiency.
- Given some input and output index tensors Is and Os and a function with Shaped Array types, fill out a function jac which returns the full Jacobian at a point (identical to jax.jacfwd(f)(x)).
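A handy way to sanity-check a jac answer in the second format is to compare it against a finite-difference approximation. This is only a sketch I use for checking, not part of the puzzles; numerical_jac is my own helper name, and it assumes f maps a 1D array to a 1D array:
import numpy as np

def numerical_jac(f, x, eps=1e-6):
    # My own sanity-check helper (not part of the puzzles).
    # Approximates J[o, i] = d f(x)[o] / d x[i] with central differences.
    J = np.zeros((f(x).shape[0], x.shape[0]))
    for i in range(x.shape[0]):
        step = np.zeros_like(x, dtype=float)
        step[i] = eps
        J[:, i] = (f(x + step) - f(x - step)) / (2 * eps)
    return J

For example, np.allclose(jac(x), numerical_jac(f, x), atol=1e-4) should then hold (up to broadcasting of the returned shapes).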
Problem 1: Id
Warmup: the identity function $f(x) = x$ on a length-1 input.
Solution: This warmup asks us to compute $\frac{\partial f_0}{\partial x_0}$. The derivative is 1 since it’s just the identity function. More explicitly, we know that $\frac{\partial f_o}{\partial x_i}$ is 1 for $i = o = 0$ since $f(x) = x$. (The array-format warmup below is given as $f(x) = 2x$, so its Jacobian is the constant 2.)
def fb_id(x):
f = lambda o: x[0]
dx = lambda i, o: 1 # Fill in this line
return f, dx
Is = np.arange(1)
def f(x: Shaped[Array, "1"]) -> Shaped[Array, "1"]:
return 2 * x
def jac(x: Shaped[Array, "1"]) -> Shaped[Array, "1 1"]:
return 2 * (x==x)[None]
Problem 2: Cosine
Warmup: $f(x) = \cos(x)$.
Solution: The derivative of $\cos(x)$ is $-\sin(x)$. So we need to fill in the line dx = lambda i, o: -math.sin(x[0]).
def fb_cos(x):
f = lambda o: math.cos(x[0])
    dx = lambda i, o: -math.sin(x[0]) # Fill in this line
return f, dx
import math
def f(x: Shaped[Array, "1"]) -> Shaped[Array, "1"]:
return np.cos(x)
def jac(x: Shaped[Array, "1"]) -> Shaped[Array, "1 1"]:
return -np.sin(x)[None]
Problem 3: Mean
Solution: The Jacobian has shape $1 \times I$ since the output only has one element. $\frac{\partial f_0}{\partial x_i} = \frac{1}{I}$ for all $i$.
def fb_mean(x):
I = x.shape[0]
f = lambda o: sum(x[i] for i in range(I)) / I
dx = lambda i, o: 1 / I # Fill in this line
return f, dx
I = 10
Is = np.arange(I)
def f(x: Shaped[Array, "I"]) -> Shaped[Array, "1"]:
return np.mean(x, axis=0, keepdims=True)
def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "1 I"]:
return 1 / x.shape[0] * (x == x)[None]
Problem 4: Product
Solution: The Jacobian has shape $1 \times I$ since the output only has one element. For a given $x_i$, the derivative is the product of all the other elements. So we can write this as $\frac{\partial f_0}{\partial x_i} = \prod_{k \neq i} x_k = \left(\prod_k x_k\right) / x_i$ whenever $x_i \neq 0$.
def fb_prod(x):
pr = torch.prod(x)
f = lambda o: pr
dx = lambda i, o: pr / x[i] if x[i] != 0 else 0
return f, dx
def f(x: Shaped[Array, "I"]) -> Shaped[Array, "1"]:
return np.prod(x, keepdims=True)
def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "1 I"]:
pr = f(x)
return (pr / x)[None]
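One caveat with the division trick: if exactly one entry of $x$ is zero, the true derivative at that entry is the product of the remaining entries, not the 0 returned by the else branch. A safer (but slower) sketch, assuming NumPy-style arrays (jac_loo is my own name, not part of the puzzle):
def jac_loo(x: Shaped[Array, "I"]) -> Shaped[Array, "1 I"]:
    # Sketch: leave-one-out products; entry i is the product of all x[k] with k != i.
    return np.array([np.prod(np.delete(x, i)) for i in range(x.shape[0])])[None]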
Problem 5: Repeat
Hint: The function dx should return a scalar. It is the derivative of $f(x)_o$, i.e. the o’th output.
Solution: The Jacobian has shape $O \times 1$ since the output has $O$ elements and the input has one. The derivative is 1 for all $o$.
def fb_repeat(x):
f = lambda o: x[0]
dx = lambda i, o: 1
return f, dx
Is = np.arange(1)
O = 10
Os = np.arange(O)[:, None]
def f(x: Shaped[Array, "1"]) -> Shaped[Array, "O"]:
return (x + (Os * 0 + 1))[:, 0]
def jac(x: Shaped[Array, "1"]) -> Shaped[Array, "O 1"]:
return (x == x)[None]
Problem 6: Repeat and Scale
Solution: The scalar in front of the repeated version of the input depends on its index. In particular, the $o$-th element is $x_0 \cdot \frac{o}{I}$. So the derivative is $\frac{\partial f_o}{\partial x_0} = \frac{o}{I}$.
def fb_repeat_scale(x):
I = 50
f = lambda o: x[0] * (o / I)
dx = lambda i, o: (o / I)
return f, dx
Is = np.arange(1)
O = 10
Os = np.arange(O)[:, None]
def f(x: Shaped[Array, "1"]) -> Shaped[Array, "O"]:
return x * (Os / O)[:, 0]
def jac(x: Shaped[Array, "1"]) -> Shaped[Array, "O 1"]:
return Os / O
Problem 7: Negation
(Hint: remember the indicator trick, i.e. (a == b) * 27 evaluates to 27 if a == b and to 0 otherwise.)
Solution: Here, the Jacobian has shape $O \times I$ with $O = I$ since the input and output have the same number of elements. The Jacobian is a diagonal matrix with -1 on the diagonal.
def fb_neg(x):
f = lambda o: -x[o]
dx = lambda i, o: -(i == o)
return f, dx
I = 10
O = 10
Is = np.arange(I)
Os = np.arange(O)[:, None]
def f(x: Shaped[Array, "I"]) -> Shaped[Array, "O"]:
return -x
def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "O I"]:
    return (0 - (Os == Is[None])).astype(float)
Problem 8: ReLU
Recall that $\mathrm{ReLU}(x) = \max(x, 0)$.
(Note: you can ignore the point of non-differentiability at 0.)
Solution: ReLU is an element-wise function, so we know the Jacobian is a diagonal matrix. The derivative is 1 if $x_o > 0$ and 0 otherwise.
def fb_relu(x):
f = lambda o: (x[o] > 0) * x[o]
dx = lambda i, o: (i == o) * (x[o] > 0)
return f, dx
I = 10
O = 10
Is = np.arange(I)
Os = np.arange(O)[:, None]
def f(x: Shaped[Array, "I"]) -> Shaped[Array, "O"]:
return x * (x > 0)
def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "O I"]:
# x.shape (I, )
# (Os == Is).shape (O, I)
# Broadcasting (1, I) * (O, I)
return (x > 0) * (Os == Is)
Problem 8.5/9: Index
Solution: The Jacobian is a 15x25 matrix. $x_0, \ldots, x_9$ are not used in the output, so the derivative is 0 for those indices. The outputs are just the inputs shifted by 10, so the derivative is 1 for $i = o + 10$ and 0 otherwise.
# i o dx
# 0 0 0
# ...
# 10 0 1
# 11 1 1
# ...
def fb_index(x):
f = lambda o: x[o + 10]
dx = lambda i, o: 1 if i == (o + 10) else 0
return f, dx
I = 25
O = 15
Is = np.arange(I)
Os = np.arange(O)[:, None]
def f(x: Shaped[Array, "I"]) -> Shaped[Array, "O"]:
return x[10:]
def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "O I"]:
    return Is == (Os + 10)
Problem 10: Cumsum
Solution: Each element of the output is the cumulative sum of the input up to that point; in other words, $x_i$ is present in the sum for output element $o$ if $i \leq o$. So the derivative is 1 if $i \leq o$ and 0 otherwise, times the scaling factor of $\frac{1}{20}$ applied at the end.
def fb_cumsum(x):
out = torch.cumsum(x, 0)
f = lambda o: out[o] / 20
dx = lambda i, o: (i <= o) * 1/20
return f, dx
I = 10
O = 10
Is = np.arange(I)
Os = np.arange(O)[:, None]
def f(x: Shaped[Array, "I"]) -> Shaped[Array, "O"]:
return np.cumsum(x) / 20
def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "O I"]:
return 1 / 20 * (Is <= Os)
Problem 11: Running Mean
Solution: Each output is the mean of a window of $W$ consecutive inputs (the fb version looks backwards from $o$, the array version forwards), so the derivative is $\frac{1}{W}$ whenever $x_i$ falls inside output $o$’s window and 0 otherwise.
# i o dx
# 0 0 1/W
# 0 1 1/W
# 1 0 0
# 1 1 1/W
def fb_running(x):
W = 10
f = lambda o: sum([
x[o - do] for do in range(W)
if o - do >= 0
]) / W
dx = lambda i, o: ((o - i) < 10) * (i <= o) * (1/W)
return f, dx
I = 10
O = 8
Is = np.arange(I)
Os = np.arange(O)[:, None]
W = 3
def f(x: Shaped[Array, "I"]) -> Shaped[Array, "O"]:
return np.convolve(x, np.ones(W) / W, mode="valid")
def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "O I"]:
return ((Is - Os) < W) * (1 / W * (Is >= Os))
Problem 12: Sort
(This one is a bit counterintuitive! Note that we are not asking you to differentiate the sorting function itself.)
Solution: The derivative with respect to a given $x_i$ depends on its position in the sorted array. If $x_i$ ends up as the $o$-th element of the sorted array, then $\frac{\partial f_o}{\partial x_i}$ is 1. Otherwise, it’s 0.
Below, argsort[o] gives the original index of the element that lands in sorted position $o$.
def fb_sort(x):
sort, argsort = torch.sort(x, 0)
f = lambda o: sort[o]
dx = lambda i, o: i == argsort[o]
return f, dx
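In the array format, a minimal sketch of the same idea (my own, assuming the sorted output keeps the full input length; np.argsort plays the role of argsort above):
I = 10
Is = np.arange(I)
Os = np.arange(I)[:, None]
def f(x: Shaped[Array, "I"]) -> Shaped[Array, "I"]:
    return np.sort(x)
def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "I I"]:
    # Sketch: row o has a single 1 at the input index that lands in sorted position o.
    return (Is == np.argsort(x)[Os]).astype(float)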
Next we move on to functions of two arguments. For these you will produce two derivatives: $\frac{\partial f(x, y)_o}{\partial x_i}$ and $\frac{\partial f(x, y)_o}{\partial y_j}$. Everything else is the same.
Problem 13: Elementwise mean
Solution: Since this is an element-wise operation, the Jacobian must be diagonal in the sense that it is nonzero only when $i = o$ (for $x$) or $j = o$ (for $y$). The derivative is $\frac{1}{2}$ for both $x$ and $y$.
def fb_emean(x, y):
f = lambda o: (x[o] + y[o]) / 2
dx = lambda i, o: (i == o) * 1/2
dy = lambda j, o: (j == o) * 1/2
return f, dx, dy
I = 10
O = 10
Is = np.arange(I)
Os = np.arange(O)[:, None]
def f(
x: Shaped[Array, "I"],
y: Shaped[Array, "I"],
) -> Shaped[Array, "O"]:
return (x + y) / 2
def jac(
x: Shaped[Array, "I"], y: Shaped[Array, "I"]
) -> tuple[Shaped[Array, "O I"], Shaped[Array, "O I"]]:
return (Is == Os) / 2, (Is == Os) / 2
Problem 14: Elementwise mul
Solution: Similar to the previous problem, the Jacobian is diagonal. The derivative is $y_o$ for $x_i$ (when $i = o$) and $x_o$ for $y_j$ (when $j = o$).
def fb_mul(x, y):
f = lambda o: x[o] * y[o]
dx = lambda i, o: (i == o) * y[o]
dy = lambda j, o: (j == o) * x[o]
return f, dx, dy
I = 10
O = 10
Is = np.arange(I)
Os = np.arange(O)[:, None]
def f(
x: Shaped[Array, "I"],
y: Shaped[Array, "I"],
) -> Shaped[Array, "O"]:
return x * y
def jac(
x: Shaped[Array, "I"], y: Shaped[Array, "I"]
) -> tuple[Shaped[Array, "O I"], Shaped[Array, "O I"]]:
return (Is == Os) * y, (Is == Os) * x # fill in
Problem 15: 1D Convolution
This is the standard 1D convolution used in deep learning. There is no wrap-around.
Note: This is probably the hardest one. The answer is short but tricky.
Solution: Let’s spend a bit more time understanding this one.
(Note: in the text of the puzzle, the summation might be shown starting from j=1, but in the code snippet we see it runs from j=0 to j=W-1. Both are equivalent if we shift indices accordingly.)
Derivative w.r.t. $x$
Fix an output index $o$. Then
$$f_o = \frac{1}{W} \sum_{j=0}^{W-1} x_{o+j}\, y_j.$$
Inside that sum, $x_i$ appears only when $i = o + j$. In other words:
- If $i \neq o + j$ for every $j \in \{0, \ldots, W-1\}$, then the derivative is 0.
- If $i = o + j$ for some $j$, then the derivative is $y_j / W = y_{i-o} / W$.
One way to write this is:
dx = lambda i, o: sum((i == (o + j)) * y[j] for j in range(W)) / W
which uses the fact that Python booleans are numbers.
We could also take advantage of the direct relationship that $i = o + j$ implies $j = i - o$. However, we need to be more careful about the bounds and indexing errors, namely that $y_{i-o}$ is only valid if $0 \leq i - o < W$, where $W$ is the length of the filter:
# Returns 0 if i - o is out of bounds
# Otherwise, returns y[i - o] / W
dx = lambda i, o: (y[i - o] / W if 0 <= i - o < W else 0)
Derivative w.r.t. $y$
Again, fix the output index $o$ and use the index $j$ for the filter $y$. Now look at the same sum: for a given $y_j$, only the single term $x_{o+j}\, y_j / W$ matters. So the derivative is $x_{o+j} / W$.
import math
def fb_conv(x, y):
W = 5
f = lambda o: sum((x[o + j] * y[j]) / W for j in range(W))
    dx = lambda i, o: y[i - o] / W if 0 <= i - o < W else 0
dy = lambda j, o: x[o + j] / W
return f, dx, dy
I = 10
O = 6
W = 5
Is = np.arange(I)
Os = np.arange(O)[:, None]
Ws = np.arange(W)
def f(
x: Shaped[Array, "I"],
y: Shaped[Array, "W"],
) -> Shaped[Array, "O"]:
return np.convolve(x, y, mode="valid") / W
def jac(
x: Shaped[Array, "I"], y: Shaped[Array, "W"]
) -> tuple[Shaped[Array, "O I"], Shaped[Array, "O W"]]:
    # np.convolve flips the kernel, so df_o/dx_i = y[W - 1 - (i - o)] / W when 0 <= i - o < W.
    # Taking the index modulo W keeps it in range; the mask zeroes out the invalid (o, i) pairs.
    dx = (Is >= Os) * (Is - Os < W) * (y[(Os - Is - 1) % W] / W)
inds = Os - Ws + W - 1
dy = x[inds] / W
return dx, dy
For these next problems, the input is a matrix and an optional vector, and the output is a matrix.
Problem 16: View
Compute the identity function $f(x)_{o,p} = x_{o,p}$ for all $o, p$; $y$ is ignored.
Solution: This is similar to the first problem. Since $f(x)_{o,p} = x_{o,p}$, each element of the output depends only on one element in the input: the one at the same row-column index. So the derivative is 1 when $i = o$ and $j = p$, and 0 otherwise.
def fb_view(x, y):
f = lambda o, p: x[o, p]
    dx = lambda i, j, o, p: (i == o) * (j == p)
dy = lambda j, o, p: 0
return f, dx, dy
I = 4
J = 4
O = 4
P = 4
Is = np.arange(I)[:, None]
Js = np.arange(J)
Os = np.arange(O)[:, None, None]
Ps = np.arange(P)[:, None, None, None]
def f(x: Shaped[Array, "I J"]) -> Shaped[Array, "O P"]:
return x
def jac(x: Shaped[Array, "I J"]) -> Shaped[Array, "O P I J"]:
return (Os == Js) * (Ps == Is) # fill in
Problem 16.5: View + Reshape
Same as above, but the output matrix has been flattened to a 1D array.
Solution: Think of how you would iterate over a 2D matrix row-wise: cell $(i, j)$ maps to the flattened coordinate $o = i \cdot J + j$, so the derivative is 1 exactly when $o = iJ + j$.
I = 4
J = 4
O = 16
Is = np.arange(I)[:, None]
Js = np.arange(J)
Os = np.arange(O)[:, None, None]
def f(x: Shaped[Array, "I J"]) -> Shaped[Array, "O"]:
return x.reshape((O,))
def jac(
x: Shaped[Array, "I J"],
) -> Shaped[Array, "O I J"]:
    flat = Is * J + Js  # row-major flattened index of input cell (i, j)
    return flat == Os
Problem 16: Transpose
Transpose rows and columns.
Solution: The output is the transpose of the input, so the derivative is 1 if $i = p$ and $j = o$, and 0 otherwise.
def fb_trans(x, y):
f = lambda o, p: x[p, o]
dx = lambda i, j, o, p: (i == p) & (j == o) # Fill in this line
dy = lambda j, o, p: 0 # Fill in this line
return f, dx, dy
I = 4
J = 4
O = 4
P = 4
Is = np.arange(I)[:, None]
Js = np.arange(J)
Os = np.arange(O)[:, None, None]
Ps = np.arange(P)[:, None, None, None]
def f(x: Shaped[Array, "I J"]) -> Shaped[Array, "O P"]:
return x
def jac(x: Shaped[Array, "I J"]) -> Shaped[Array, "O P I J"]:
return (Os == Js) * (Ps == Is) # fill in
Problem 17: Broadcast
Broadcast a matrix with a vector
Solution: Notice that $f(x, y)_{o,p} = x_{o,p} \cdot y_p$. When $i \neq o$ or $j \neq p$, $x_{i,j}$ does not appear in the output $f(x, y)_{o,p}$, so the derivative is 0 in those cases. When $i = o$ and $j = p$, the derivative is $y_p$.
For the derivative w.r.t. $y_j$, the derivative is $x_{o,p}$, and it is nonzero only when $j = p$.
def fb_broad(x, y):
f = lambda o, p: x[o, p] * y[p]
dx = lambda i, j, o, p: (j==p) * (i==o) * y[p] # Fill in this line
dy = lambda j, o, p: (j == p) * x[o, p]
return f, dx, dy
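As a sketch of what the array-format Jacobians would look like here (my own construction, not the puzzle's template; it assumes x has shape (O, P) and y has shape (P,), and uses np.eye for the indicator terms):
def jac(x: Shaped[Array, "O P"], y: Shaped[Array, "P"]):
    # Sketch only: shapes assumed, not taken from the puzzle.
    O, P = x.shape
    eye_O, eye_P = np.eye(O), np.eye(P)
    # d f[o, p] / d x[i, j] = (i == o) * (j == p) * y[p], shape (O, P, I, J)
    dx = eye_O[:, None, :, None] * (eye_P * y[:, None])[None, :, None, :]
    # d f[o, p] / d y[j] = (j == p) * x[o, p], shape (O, P, J)
    dy = x[:, :, None] * eye_P[None, :, :]
    return dx, dy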
Problem 18: Mean Reduce
Compute the mean over rows
Note: Based on the formula above, I think the phrasing “compute the mean over rows” is slightly misleading: one might think we are computing row-wise means (this is how I initially interpreted the operation before examining the formula). However, the formula computes column-wise means, which is, I suppose, technically a mean over all rows for a given column. The diagram itself is also misleading.
Solution: If we fix a column $p$, the output is the mean of all the elements in that column. So the derivative is $\frac{1}{R}$ (where $R$ is the number of rows) for all $i$ whenever $j = p$. The derivative w.r.t. $y$ is 0 since it is not used in the output.
def fb_mean(x, y):
R = x.shape[0]
f = lambda o, p: sum(x[di, p] for di in range(R)) / R
dx = lambda i, j, o, p: (j == p) * 1/R # Fill in this line
dy = lambda j, o, p: 0 # Fill in this line
return f, dx, dy
I = 4
J = 4
O = 4
Is = np.arange(I)[:, None]
Js = np.arange(J)
Os = np.arange(O)[:, None, None]
def f(x: Shaped[Array, "I J"]) -> Shaped[Array, "O"]:
return np.mean(x, axis=0)
def jac(x: Shaped[Array, "I J"]) -> Shaped[Array, "O I J"]:
    # (Is == Is) is all ones; it only serves to broadcast in the I axis so dx has shape (O, I, J)
    dx = (Os == Js) * (Is == Is) * (1 / I)
return dx
Problem 19: Matmul
Standard matrix multiplication
Solution: For now, I just write out the Jacobian by hand. I will try to come up with a more general solution later.
I = 4
J = 4
O = 4
Is = np.arange(I)
Js = np.arange(J)
Os = np.arange(O)[:, None]
def f(x: Shaped[Array, "I"], y: Shaped[Array, "J"]) -> Shaped[Array, "O "]:
return (x.reshape(2, 2) @ y.reshape(2, 2)).reshape(O)
def jac(
x: Shaped[Array, "I"], y: Shaped[Array, "J"]
) -> tuple[Shaped[Array, "O I"], Shaped[Array, "O J"]]:
# Build df/dx
# Row 0 => partial of f_0 wrt x_0..x_3 => [y[0], y[2], 0, 0]
# Row 1 => partial of f_1 wrt x_0..x_3 => [y[1], y[3], 0, 0]
# Row 2 => partial of f_2 wrt x_0..x_3 => [0, 0, y[0], y[2]]
# Row 3 => partial of f_3 wrt x_0..x_3 => [0, 0, y[1], y[3]]
dx = np.array([
[y[0], y[2], 0 , 0 ],
[y[1], y[3], 0 , 0 ],
[ 0 , 0 , y[0], y[2]],
[ 0 , 0 , y[1], y[3]],
], dtype=x.dtype)
# Build df/dy
# Row 0 => partial of f_0 wrt y_0..y_3 => [x[0], 0 , x[1], 0 ]
# Row 1 => partial of f_1 wrt y_0..y_3 => [ 0 , x[0], 0 , x[1]]
# Row 2 => partial of f_2 wrt y_0..y_3 => [x[2], 0 , x[3], 0 ]
# Row 3 => partial of f_3 wrt y_0..y_3 => [ 0 , x[2], 0 , x[3]]
dy = np.array([
[x[0], 0 , x[1], 0 ],
[ 0 , x[0], 0 , x[1]],
[x[2], 0 , x[3], 0 ],
[ 0 , x[2], 0 , x[3]],
], dtype=y.dtype)
return dx, dy
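For reference, the same two matrices fall out of the general rule: with $C = AB$, $\frac{\partial C_{ab}}{\partial A_{rc}} = B_{cb}$ when $a = r$ (else 0), and $\frac{\partial C_{ab}}{\partial B_{rc}} = A_{ar}$ when $b = c$ (else 0). A sketch using einsum (jac_general is my own helper name; it assumes the same flattened 2x2 layout as f above):
def jac_general(
    x: Shaped[Array, "I"], y: Shaped[Array, "J"]
) -> tuple[Shaped[Array, "O I"], Shaped[Array, "O J"]]:
    # Sketch only: reuses the 2x2 reshape convention from f above.
    A, B = x.reshape(2, 2), y.reshape(2, 2)
    eye = np.eye(2)
    # d C[a, b] / d A[r, c] = (a == r) * B[c, b]
    dA = np.einsum("ar,cb->abrc", eye, B).reshape(4, 4)
    # d C[a, b] / d B[r, c] = (b == c) * A[a, r]
    dB = np.einsum("bc,ar->abrc", eye, A).reshape(4, 4)
    return dA, dB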