
AutoDiff Puzzles


My solutions to srush’s AutoDiff Puzzles. This is useful as a quick refresher for computing gradients.

Introduction

From the puzzle’s intro:

Deep learning libraries like Torch utilize autodifferentiation of tensors to compute the parameter updates necessary to learn complex models from data. This technique is central to understanding why deep learning has become so widely used and effective. The autodifferentiation process is a neat trick that builds up a computational graph and then uses that graph to provide derivatives for user-constructed mathematical functions. At heart, this process is just an instantiation of the chain-rule based on implementations of every function and its derivative.

However, a library still needs to have efficient implementations of derivatives for its key building blocks. This sounds trivial: just implement high school calculus. In practice, it is a bit trickier than it sounds. Tensor-to-tensor functions are pretty complex and require keeping careful track of indexing on both the input and the output side.

Your goal in these puzzles is to implement the Jacobian, i.e. the derivative of each output cell with respect to each input cell, for each function below. In each case the function takes in a tensor $x \in \mathbb{R}^{I}$ and returns a tensor $f(x) \in \mathbb{R}^{O}$, so your job is to compute $\frac{d f(x)_o}{dx_i}$ for all indices $o \in \{0 \ldots O-1\}$ and $i \in \{0 \ldots I-1\}$. If you get discouraged, just remember, you did this in high school (it just had way less indexing).

Rules and Tips

For each of the problems, a tensor function $f$ is provided. There are two solution formats:

  1. Given a lambda called f, your job is to fill in the function dx(i, o) which provides $\frac{d f(x)_o}{dx_i}$. These puzzles are about correctness, not efficiency.
  2. Given some input and output tensors Is and Os and a function $f$ with Shaped Array types, fill out a function jac which returns the full Jacobian at a point $x$ (identical to jax.jacfwd(f)(x)); a quick way to check this equivalence is sketched right after this list.
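
As a sanity check for the second format, here is a minimal sketch (not part of the original puzzles; the function and values are only illustrative) of comparing a hand-written Jacobian against jax.jacfwd:

import jax
import jax.numpy as jnp


def f(x):
    return jnp.cos(x)  # any tensor-to-tensor function


def manual_jac(x):
    return -jnp.sin(x)[None]  # hand-written Jacobian for cos, shape (O, I)


x = jnp.array([0.5])
assert jnp.allclose(jax.jacfwd(f)(x), manual_jac(x))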

Problem 1: Id

Warmup: $f(x_0) = [x_0]$

Solution: This warmup asks us to compute $\frac{dx}{dx} = 1$. The derivative is 1 since it’s just the identity function. More explicitly, we know that $\frac{df(x)_o}{dx_i}$ is 1 for $o = i = 0$ since $f(x)_0 = x_0$.

def fb_id(x):
  f = lambda o: x[0]
  dx = lambda i, o: 1 # Fill in this line
  return f, dx

Is = np.arange(1)


def f(x: Shaped[Array, "1"]) -> Shaped[Array, "1"]:
    return 2 * x


def jac(x: Shaped[Array, "1"]) -> Shaped[Array, "1 1"]:
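    # x == x gives a boolean array of ones with x's shape; [None] adds the output axis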
    return 2 * (x==x)[None]

Problem 2: Cosine

Warmup: $f(x_0) = [\cos(x_0)]$

Solution: The derivative of $\cos(x)$ is $-\sin(x)$. So we need to fill in the line dx = lambda i, o: -math.sin(x[0]).

import math

def fb_cos(x):
    f = lambda o: math.cos(x[0])
    dx = lambda i, o: -math.sin(x[0])  # Fill in this line
    return f, dx

def f(x: Shaped[Array, "1"]) -> Shaped[Array, "1"]:
    return np.cos(x)

def jac(x: Shaped[Array, "1"]) -> Shaped[Array, "1 1"]:
    return -np.sin(x)[None]

Problem 3: Mean

$$f(x_0, x_1, \ldots, x_{I-1}) = [(x_0 + x_1 + \ldots + x_{I-1}) / I]$$

Solution: The Jacobian has shape $1 \times I$ since the output only has one element. $\frac{\partial f(x)_o}{\partial x_i} = \frac{1}{I}$ for all $i$.

def fb_mean(x):
    I = x.shape[0]
    f = lambda o: sum(x[i] for i in range(I)) / I
    dx = lambda i, o: 1 / I # Fill in this line
    return f, dx

I = 10

Is = np.arange(I)


def f(x: Shaped[Array, "I"]) -> Shaped[Array, "1"]:
    return np.mean(x, axis=0, keepdims=True)


def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "1 I"]:
    return 1 / x.shape[0] * (x == x)[None]

Problem 4: Product

$$f(x_0, x_1, \ldots, x_{I-1}) = x_0 x_1 \cdots x_{I-1}$$

Solution: The Jacobian has shape $1 \times I$ since the output only has one element. For a given $x_i$, the derivative is the product of all the other elements. So we can write this as $\frac{\partial f(x)_o}{\partial x_i} = \prod_{j \neq i} x_j$.

def fb_prod(x):
    pr = torch.prod(x)
    f = lambda o: pr
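    # when x[i] == 0 the division below is undefined, so 0 is returned; strictly, the derivative is the product of the remaining entries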
    dx = lambda i, o: pr / x[i] if x[i] != 0 else 0
    return f, dx

def f(x: Shaped[Array, "I"]) -> Shaped[Array, "1"]:
    return np.prod(x, keepdims=True)


def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "1 I"]:
    pr = f(x)
    return (pr / x)[None]

Problem 5: Repeat

$$f(x_0) = [x_0, x_0, x_0, \ldots, x_0]$$

Hint: The function dx should return a scalar. It is the derivative of $f(x_0)_o$, i.e. the o’th output.

Solution: The Jacobian has shape $O \times 1$ since the output has $O$ elements. The derivative is 1 for all $o$.

def fb_repeat(x):
    f = lambda o: x[0]
    dx = lambda i, o: 1
    return f, dx

Is = np.arange(1)

O = 10
Os = np.arange(O)[:, None]


def f(x: Shaped[Array, "1"]) -> Shaped[Array, "O"]:
    return (x + (Os * 0 + 1))[:, 0]


def jac(x: Shaped[Array, "1"]) -> Shaped[Array, "O 1"]:
    return (x == x)[None]

Problem 6: Repeat and Scale

$$f(x_0) = [x_0 \times 0/I, x_0 \times 1/I, x_0 \times 2/I, \ldots, x_0 \times (I-1)/I]$$

Solution: The scalar in front of the repeated version of the input depends on its index. In particular, the $o$-th element is $x_0 \times o/I$. So the derivative is $o/I$.

def fb_repeat_scale(x):
    I = 50
    f = lambda o: x[0] * (o / I)
    dx = lambda i, o: (o / I)
    return f, dx

Is = np.arange(1)
O = 10
Os = np.arange(O)[:, None]


def f(x: Shaped[Array, "1"]) -> Shaped[Array, "O"]:
    return x * (Os / O)[:, 0]


def jac(x: Shaped[Array, "1"]) -> Shaped[Array, "O 1"]:
    return Os / O

Problem 7: Negation

$$f(x_0, x_1, \ldots) = [-x_0, -x_1, \ldots]$$

Hint: remember the indicator trick, i.e.

(a == b) * 27 # 27 if a == b else 0

Solution: Here, the Jacobian has shape $I \times I$ with $I = O$ since the input/output has $I$ elements. The Jacobian is a diagonal matrix with -1 on the diagonal.

def fb_neg(x):
    f = lambda o: -x[o]
    dx = lambda i, o: -(i == o)
    return f, dx

I = 10
O = 10
Is = np.arange(I)
Os = np.arange(O)[:, None]


def f(x: Shaped[Array, "I"]) -> Shaped[Array, "O"]:
    return -x


def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "O I"]:
    return (0 - (Os == Is[None])).astype(float)

Problem 8: ReLU

$$f(x_0, x_1, \ldots) = [\text{relu}(x_0), \text{relu}(x_1), \ldots]$$

Recall

$$\text{relu}(x) = \begin{cases} 0 & x < 0 \\ x & x \geq 0 \end{cases}$$

(Note: you can ignore the point of non-differentiability at 0.)

Solution: ReLU is an element-wise function, so we know the Jacobian is a diagonal matrix. The derivative is 0 if $x_i < 0$ and 1 otherwise.

def fb_relu(x):
    f = lambda o: (x[o] > 0) * x[o]
    dx = lambda i, o: (i == o) * (x[o] > 0)
    return f, dx

I = 10

O = 10
Is = np.arange(I)
Os = np.arange(O)[:, None]


def f(x: Shaped[Array, "I"]) -> Shaped[Array, "O"]:
    return x * (x > 0)


def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "O I"]:
    # x.shape (I, )
    # (Os == Is).shape (O, I)
    # Broadcasting (1, I) * (O, I)
    return (x > 0) * (Os == Is)

Problem 8.5/9: Index

$$f(x_0, x_1, \ldots, x_{24}) = [x_{10}, x_{11}, \ldots, x_{24}]$$

Solution: The Jacobian is a 15×25 matrix. $x_0 \dots x_9$ are not used in the output, so the derivative is 0 for those indices. The outputs are just the inputs shifted by 10, so the derivative is 1 for $i = o + 10$ and 0 otherwise.

# i o dx
# 0 0 0
# ...
# 10 0 1
# 11 1 1
# ...

def fb_index(x):
    f = lambda o: x[o + 10]
    dx = lambda i, o: 1 if i == (o + 10) else 0
    return f, dx

I = 25
O = 15
Is = np.arange(I)
Os = np.arange(O)[:, None]


def f(x: Shaped[Array, "I"]) -> Shaped[Array, "O"]:
    return x[10:]


def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "O I"]:
    return Is == (Os + 10)

Problem 10: Cumsum

$$f(x_0, x_1, \ldots) = \left[\sum_{i=0}^{0} x_i, \sum_{i=0}^{1} x_i, \sum_{i=0}^{2} x_i, \ldots\right] / 20$$

Solution: Each element of the output is the cumulative sum of the input up to that point; in other words, $x_i$ is present in the sum for output element $o$ if $i \leq o$. So the derivative is 1 if $i \leq o$ and 0 otherwise. There is a scaling factor of 1/20 at the end.

def fb_cumsum(x):
    out = torch.cumsum(x, 0)
    f = lambda o: out[o] / 20
    dx = lambda i, o:  (i <= o) * 1/20
    return f, dx

I = 10
O = 10
Is = np.arange(I)
Os = np.arange(O)[:, None]


def f(x: Shaped[Array, "I"]) -> Shaped[Array, "O"]:
    return np.cumsum(x) / 20


def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "O I"]:
    return 1 / 20 * (Is <= Os)

Problem 11: Running Mean

$$f(x_0, x_1, \ldots)_o = \frac{\sum_{i=o-W}^{o} x_i}{W}$$

# i o dx
# 0 0 1/W
# 0 1 1/W
# 1 0 0
# 1 1 1/W
def fb_running(x):
    W = 10
    f = lambda o: sum([
        x[o - do] for do in range(W)
        if o - do >= 0
    ]) / W
    dx = lambda i, o: ((o - i) < W) * (i <= o) * (1/W)
    return f, dx

I = 10
O = 8
Is = np.arange(I)
Os = np.arange(O)[:, None]
W = 3

def f(x: Shaped[Array, "I"]) -> Shaped[Array, "O"]:
    return np.convolve(x, np.ones(W) / W, mode="valid")


def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "O I"]:
    return ((Is - Os) < W) * (1 / W * (Is >= Os))

Problem 12: Sort

$$f(x_0, x_1, \ldots) = \text{x's in sorted order}$$

(This one is a bit counterintuitive! Note that we are not asking you to differentiate the sorting function itself.)

Solution: The derivative with respect to a given $x_i$ depends on its position in the sorted array. If $x_i$ is the $o$-th element in the sorted array, then the derivative is 1. Otherwise, it’s 0. Below, the argsort array gives the indices of the sorted array.

def fb_sort(x):
    sort, argsort = torch.sort(x, 0)
    f = lambda o: sort[o]
    dx = lambda i, o: i == argsort[o]
    return f, dx

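A minimal sketch of what the Shaped Array version could look like (the shapes and the use of np.argsort here are assumptions, not taken from the original puzzle code):

I = 10
O = 10
Is = np.arange(I)


def f(x: Shaped[Array, "I"]) -> Shaped[Array, "O"]:
    return np.sort(x)


def jac(x: Shaped[Array, "I"]) -> Shaped[Array, "O I"]:
    # row o selects the input index that lands in position o of the sorted order
    argsort = np.argsort(x)
    return Is == argsort[:, None]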

Next we move on to functions of two arguments. For these you will produce two derivatives: $\frac{df(x, y)_o}{dx_i}$ and $\frac{df(x, y)_o}{dy_j}$. Everything else is the same.

Problem 13: Elementwise mean

$$f(x, y)_o = (x_o + y_o) / 2$$

Solution: Since this is an element-wise operation, the Jacobian must be diagonal in the sense that it is nonzero only when $i = o$ or $j = o$. The derivative is 1/2 for both $x$ and $y$.

def fb_emean(x, y):
    f = lambda o: (x[o] + y[o]) / 2
    dx = lambda i, o: (i == o) * 1/2
    dy = lambda j, o: (j == o) * 1/2
    return f, dx, dy

I = 10
O = 10
Is = np.arange(I)
Os = np.arange(O)[:, None]


def f(
    x: Shaped[Array, "I"],
    y: Shaped[Array, "I"],
) -> Shaped[Array, "O"]:
    return (x + y) / 2


def jac(
    x: Shaped[Array, "I"], y: Shaped[Array, "I"]
) -> tuple[Shaped[Array, "O I"], Shaped[Array, "O I"]]:
    return (Is == Os) / 2, (Is == Os) / 2

Problem 14: Elementwise mul

$$f(x, y)_o = x_o \cdot y_o$$

Solution: Similar to the previous problem, the Jacobian is diagonal. The derivative is $y_o$ for $x$ and $x_o$ for $y$.

def fb_mul(x, y):
    f = lambda o: x[o] * y[o]
    dx = lambda i, o: (i == o) * y[o]
    dy = lambda j, o: (j == o) * x[o]
    return f, dx, dy

I = 10
O = 10
Is = np.arange(I)
Os = np.arange(O)[:, None]


def f(
    x: Shaped[Array, "I"],
    y: Shaped[Array, "I"],
) -> Shaped[Array, "O"]:
    return x * y


def jac(
    x: Shaped[Array, "I"], y: Shaped[Array, "I"]
) -> tuple[Shaped[Array, "O I"], Shaped[Array, "O I"]]:
    return (Is == Os) * y, (Is == Os) * x  # fill in

Problem 15: 1D Convolution

This is the standard 1D convolution used in deep learning. There is no wrap-around.

$$f(x, y)_o = \sum_{j=1}^{K} x_{o+j} \cdot y_j / K$$

Note: This is probably the hardest one. The answer is short but tricky.

Solution: Let’s spend a bit more time understanding this one.

(Note: in the text of the puzzle, the summation might be shown starting from j=1, but in the code snippet we see it runs from j=0 to j=W-1. Both are equivalent if we shift indices accordingly.)

Derivative w.r.t. $x_i$

Fix an output index $o$. Then

$$f(x, y)_o = \frac{1}{W} \sum_{j=0}^{W-1} x_{o + j} \, y_j.$$

Inside that sum, $x_i$ appears only when $i = o + j$. One way to write this is:

dx = lambda i, o: sum((i == (o + j)) * y[j] for j in range(W)) / W

which uses the fact that Python booleans are numbers.

We could also take advantage of the direct relationship that $i = o + j$ implies $j = i - o$. However, we need to be more careful about the bounds and indexing errors, namely that $j = i - o$ is only valid if $0 \leq j < W$, assuming that $W$ is the length of the filter:

# Returns 0 if i - o is out of bounds
# Otherwise, returns y[i - o] / W
dx = lambda i, o: (y[i - o] / W if 0 <= i - o < W else 0)

Derivative w.r.t. $y_j$

Again, fix the output index $o$ and use a different name, $k$, for the summation index. Now look at the same sum:

$$f(x, y)_o = \frac{1}{W} \sum_{k=0}^{W-1} x_{o + k} \, y_k.$$

For a given $j$, only the single term where $k = j$ matters. That term is $x_{o + j} \, y_j / W$. So the derivative is $x_{o + j} / W$.

def fb_conv(x, y):
    W = 5
    f = lambda o: sum((x[o + j] * y[j]) / W for j in range(W))
    dx = lambda i, o: y[i - o] / W if 0 <= i - o < W else 0
    dy = lambda j, o: x[o + j] / W
    return f, dx, dy

I = 10
O = 6
W = 5
Is = np.arange(I)
Os = np.arange(O)[:, None]
Ws = np.arange(W)


def f(
    x: Shaped[Array, "I"],
    y: Shaped[Array, "W"],
) -> Shaped[Array, "O"]:
    return np.convolve(x, y, mode="valid") / W


def jac(
    x: Shaped[Array, "I"], y: Shaped[Array, "W"]
) -> tuple[Shaped[Array, "O I"], Shaped[Array, "O W"]]:
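    # np.convolve flips the kernel: f(x, y)[o] = sum_k x[o + k] * y[W - 1 - k] / W,
    # so the negative index y[Os - Is - 1] selects y[W - 1 - (i - o)] wherever 0 <= i - o < W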
    dx = (Is >= Os) * (Is - Os < W) * (y[Os - Is - 1] / W)
    inds = Os - Ws + W - 1
    dy = x[inds] / W
    return dx, dy

For these next problems, the input is a matrix and an optional vector, and the output is a matrix.

$$\frac{df(X, y)_{o, p}}{dX_{i, j}}, \quad \frac{df(X, y)_{o, p}}{dy_j}$$

Problem 16: View

Compute the identity function for all $o, p$. $y$ is ignored.

$$f(X, y)_{o, p} = X_{o, p}$$

Solution: This is similar to the first problem. Since $f(X, y)_{o,p} = X_{o,p}$, each element of the output depends only on one element in the input: the one at the same row-column index $(o, p)$. So the derivative is $\mathbb{1}[(i, j) = (o, p)]$.

def fb_view(x, y):
    f = lambda o, p: x[o, p]
    dx = lambda i, j, o, p: 1 * (i == o) & (j == p)
    dy = lambda j, o, p: 0
    return f, dx, dy

I = 4
J = 4
O = 4
P = 4
Is = np.arange(I)[:, None]
Js = np.arange(J)
Os = np.arange(O)[:, None, None]
Ps = np.arange(P)[:, None, None, None]

def f(x: Shaped[Array, "I J"]) -> Shaped[Array, "O P"]:
    return x


def jac(x: Shaped[Array, "I J"]) -> Shaped[Array, "O P I J"]:
    return (Os == Js) *  (Ps == Is)  # fill in

Problem 16.5 View + Reshape

Same as above, but the output matrix has been flattened to a 1D array.

Solution: Think of how you would iterate over a 2D matrix row-wise, considering the relationship between i, j and its flattened coordinate, o = i * J + j.

I = 4
J = 4
O = 16
Is = np.arange(I)[:, None]
Js = np.arange(J)
Os = np.arange(O)[:, None, None]

def f(x: Shaped[Array, "I J"]) -> Shaped[Array, "O"]:
    return x.reshape((O,))


def jac(
    x: Shaped[Array, "I J"],
) -> Shaped[Array, "O I J"]:
    return (Is * J + Js) == Os

Problem 16: Transpose

Transpose rows and columns

$$f(X, y)_{o, p} = X_{p, o}$$

Solution: The output is the transpose of the input, so the derivative is 1 if $i = p$ and $j = o$.

def fb_trans(x, y):
    f = lambda o, p: x[p, o]
    dx = lambda i, j, o, p: (i == p) & (j == o)  # Fill in this line
    dy = lambda j, o, p: 0  # Fill in this line
    return f, dx, dy

I = 4
J = 4
O = 4
P = 4
Is = np.arange(I)[:, None]
Js = np.arange(J)
Os = np.arange(O)[:, None, None]
Ps = np.arange(P)[:, None, None, None]

def f(x: Shaped[Array, "I J"]) -> Shaped[Array, "O P"]:
    return x.T


def jac(x: Shaped[Array, "I J"]) -> Shaped[Array, "O P I J"]:
    # nonzero exactly where i == p and j == o
    return (Os == Is) * (Ps == Js)

Problem 17: Broadcast

Broadcast a matrix with a vector

$$f(X, y)_{o, p} = X_{o, p} \cdot y_p$$

Solution: Notice that when $(i, j) \neq (o, p)$, $X_{i, j}$ does not appear in the output $f(X, y)_{o, p} = X_{o, p} \cdot y_p$. So the derivative is 0 in those cases. When $(i, j) = (o, p)$, the derivative is $y_p$.

For the derivative w.r.t. $y$, the derivative is $X_{o, p}$ only when $j = p$.

def fb_broad(x, y):
    f = lambda o, p: x[o, p] * y[p]
    dx = lambda i, j, o, p:  (j==p) * (i==o) * y[p]  # Fill in this line
    dy = lambda j, o, p: (j == p) * x[o, p]
    return f, dx, dy
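
As an aside, here is a minimal, self-contained sketch of the full Jacobians for this problem written with explicit loops (the function name and shapes are illustrative assumptions, not part of the original puzzle code):

def jac_broadcast_loops(x, y):
    # x has shape (O, P) and y has shape (P,); f(x, y)[o, p] = x[o, p] * y[p]
    O, P = x.shape
    dx = np.zeros((O, P, O, P))
    dy = np.zeros((O, P, P))
    for o in range(O):
        for p in range(P):
            dx[o, p, o, p] = y[p]  # only x[o, p] contributes to f[o, p]
            dy[o, p, p] = x[o, p]  # only y[p] contributes to f[o, p]
    return dx, dy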

Problem 18: Mean Reduce

Compute the mean over rows

$$f(X, y)_{o, p} = \sum_{i} X_{i, p} / R$$

Note: Based on the formula above, I think the phrasing “compute the mean over rows” is slightly misleading in that one might think we are computing row-wise means (this is how I initially interpreted the operation before examining the formula). However, the formula is computing column-wise means, which I suppose is technically over all rows for a given column. The diagram itself is also misleading.

Solution: If we fix a column $p$, the output is the mean of all the elements in that column. So the derivative is $1/R$ for all $i$ and $j = p$. The derivative w.r.t. $y$ is 0 since it is not used in the output.

def fb_mean(x, y):
    R = x.shape[0]
    f = lambda o, p: sum(x[di, p] for di in range(R)) / R
    dx = lambda i, j, o, p: (j == p) * 1/R  # Fill in this line
    dy = lambda j, o, p: 0  # Fill in this line
    return f, dx, dy

I = 4
J = 4
O = 4
Is = np.arange(I)[:, None]
Js = np.arange(J)
Os = np.arange(O)[:, None, None]

def f(x: Shaped[Array, "I J"]) -> Shaped[Array, "O"]:
    return np.mean(x, axis=0)


def jac(x: Shaped[Array, "I J"]) -> Shaped[Array, "O I J"]:
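    # (Is == Is) is an all-ones (I, 1) array; it only serves to broadcast the i axis into the result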
    dx = (Os == Js) * (Is == Is) * (1 / I)
    return dx

Problem 19: Matmul

Standard matrix multiplication

$$f(X, Y)_{o,p} = \sum_j X_{o, j} Y_{j,p}$$

Solution: For now, I just write out the Jacobian by hand. I will try to come up with a more general solution later; one candidate construction is sketched after the hand-written version below.

I = 4
J = 4
O = 4
Is = np.arange(I)
Js = np.arange(J)
Os = np.arange(O)[:, None]


def f(x: Shaped[Array, "I"], y: Shaped[Array, "J"]) -> Shaped[Array, "O "]:
    return (x.reshape(2, 2) @ y.reshape(2, 2)).reshape(O)


def jac(
    x: Shaped[Array, "I"], y: Shaped[Array, "J"]
) -> tuple[Shaped[Array, "O I"], Shaped[Array, "O J"]]:

    # Build df/dx
    # Row 0 => partial of f_0 wrt x_0..x_3 => [y[0], y[2], 0, 0]
    # Row 1 => partial of f_1 wrt x_0..x_3 => [y[1], y[3], 0, 0]
    # Row 2 => partial of f_2 wrt x_0..x_3 => [0, 0, y[0], y[2]]
    # Row 3 => partial of f_3 wrt x_0..x_3 => [0, 0, y[1], y[3]]
    dx = np.array([
        [y[0], y[2],    0 ,    0 ],
        [y[1], y[3],    0 ,    0 ],
        [   0 ,    0 , y[0], y[2]],
        [   0 ,    0 , y[1], y[3]],
    ], dtype=x.dtype)

    # Build df/dy
    # Row 0 => partial of f_0 wrt y_0..y_3 => [x[0],  0  , x[1],  0  ]
    # Row 1 => partial of f_1 wrt y_0..y_3 => [ 0  , x[0],  0  , x[1]]
    # Row 2 => partial of f_2 wrt y_0..y_3 => [x[2],  0  , x[3],  0  ]
    # Row 3 => partial of f_3 wrt y_0..y_3 => [ 0  , x[2],  0  , x[3]]
    dy = np.array([
        [x[0],   0 , x[1],   0 ],
        [  0 , x[0],   0 , x[1]],
        [x[2],   0 , x[3],   0 ],
        [  0 , x[2],   0 , x[3]],
    ], dtype=y.dtype)

    return dx, dy
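
A sketch of one more general construction (a hypothetical helper, not checked against the puzzle's tests): with row-major flattening, the Jacobians of a matrix product are Kronecker products, d vec(XY)/d vec(X) = I ⊗ Yᵀ and d vec(XY)/d vec(Y) = X ⊗ I, which np.kron expresses directly.

def jac_kron(
    x: Shaped[Array, "I"], y: Shaped[Array, "J"]
) -> tuple[Shaped[Array, "O I"], Shaped[Array, "O J"]]:
    m = k = n = 2  # shapes of the underlying 2x2 matrices
    X = x.reshape(m, k)
    Y = y.reshape(k, n)
    # entry at row o*n + p, column i*k + j is (i == o) * Y[j, p]
    dx = np.kron(np.eye(m), Y.T)
    # entry at row o*n + p, column j*n + q is (q == p) * X[o, j]
    dy = np.kron(X, np.eye(n))
    return dx, dy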
