
Tensor Puzzles

Published: at 05:32 AM

My solutions to srush’s tensor puzzles. This is useful as a quick refresher for things you can do with vanilla tensor operations.

These puzzles are about broadcasting. Know this rule:

There is a rule you should learn at last.
combination of tensors the task.
Dims right-aligned,
extra lefts 1s assigned,
match paired dimensions: Broadcast!

Example:

9 x 1 x 3                     9 x 1 x 3
    8 x 1 ->  extra left 1 -> 1 x 8 x 1
---------                     ---------
                              9 x 8 x 3

More seriously, here is the full set of broadcasting rules from the cs231n numpy tutorial:

  1. If the arrays do not have the same rank, prepend the shape of the lower rank array with 1s until both shapes have the same length (as we saw above).
  2. The two arrays are said to be compatible in a dimension if they have the same size in the dimension, or if one of the arrays has size 1 in that dimension.
  3. The arrays can be broadcast together if they are compatible in all dimensions.
  4. After broadcasting, each array behaves as if it had shape equal to the elementwise maximum of shapes of the two input arrays.
  5. In any dimension where one array had size 1 and the other array had size greater than 1, the first array behaves as if it were copied along that dimension.
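The rules above can be checked directly in PyTorch, which the puzzles use and which follows the same broadcasting semantics as NumPy. This quick sketch reuses the shapes from the example. `torch.broadcast_shapes` computes the resulting shape without allocating any tensors, and the last case shows a pair of shapes that rule 2 rejects.

```python
import torch

# Shapes from the example above: (9, 1, 3) and (8, 1).
x = torch.zeros(9, 1, 3)
y = torch.zeros(8, 1)

# Rule 1 treats y as (1, 8, 1); rule 4 gives the elementwise max of the shapes.
z = x + y
print(z.shape)  # torch.Size([9, 8, 3])

# The same result shape, computed from the shapes alone.
print(torch.broadcast_shapes((9, 1, 3), (8, 1)))  # torch.Size([9, 8, 3])

# Trailing sizes 3 and 4 are neither equal nor 1 (rule 2), so this fails.
try:
    torch.zeros(9, 1, 3) + torch.zeros(8, 4)
    compatible = True
except RuntimeError:
    compatible = False
print(compatible)  # False
```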


The rules for these puzzles are the following:

  1. Each puzzle can be solved in one line of code (under 80 columns).
  2. We are allowed @, arithmetic, comparison, shape, any indexing (e.g. a[:j], a[:, None], a[arange(10)]), and previous puzzle functions.
  3. We are not allowed anything else: no view, sum, take, squeeze, tensor.

Additionally, we are provided with two primitives: arange and where:

import torch


def arange(i: int):
    "Use this function to replace a for-loop."
    return torch.tensor(range(i))


def where(q, a, b):
    "Use this function to replace an if-statement."
    return (q * a) + (~q) * b
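Here is a quick sanity check of the two primitives. It compares the arithmetic `where` against `torch.where` on a small example with arbitrary values.

```python
import torch

def arange(i: int):
    "Use this function to replace a for-loop."
    return torch.tensor(range(i))

def where(q, a, b):
    "Use this function to replace an if-statement."
    return (q * a) + (~q) * b

# q must be a boolean tensor: ~q is logical NOT only for the bool dtype.
q = arange(5) % 2 == 0          # tensor([True, False, True, False, True])
a = arange(5) * 10              # tensor([ 0, 10, 20, 30, 40])
b = -arange(5)                  # tensor([ 0, -1, -2, -3, -4])
out = where(q, a, b)
print(out)                      # tensor([ 0, -1, 20, -3, 40])
print(torch.equal(out, torch.where(q, a, b)))  # True
```

Implementing `where` with arithmetic has two consequences:
- `q` must be boolean, because on integer tensors `~` is bitwise NOT (`~1 == -2`).
- Both branches are multiplied rather than selected, so an inf or NaN in the branch that is not chosen still leaks into the result (`0 * inf` is NaN).

The puzzles only use finite values, so the second point does not matter here.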

Puzzle 1 - ones

Compute ones - the vector of all ones.

Solution: Create a tensor of shape (i,) using arange. Since all values are greater than or equal to 0, we can use where to set all values to 1.

def ones(i: int) -> TT["i"]:
    return where(arange(i) >= 0, 1, 0)

Puzzle 2 - sum

Compute sum - the sum of a vector.

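Solution: Compute the dot product of the vector with a vector of ones. ones(i) has shape [i] and a[:, None] has shape [i x 1], so [i] @ [i x 1] = [1], which matches the TT[1] return type. Written with an explicit row vector, this is [1 x i] @ [i x 1] = [1 x 1].

def sum(a: TT["i"]) -> TT[1]:
    return ones(a.shape[0]) @ a[:, None]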
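To confirm the shape claim, here is the solution run on a small vector. `ones` and the primitives are redefined so the snippet is self-contained. As in the puzzle, this `sum` shadows Python's built-in `sum`.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def ones(i: int):
    return where(arange(i) >= 0, 1, 0)

def sum(a):  # shadows Python's built-in sum, as in the puzzle
    # [i] @ [i x 1] -> [1]
    return ones(a.shape[0]) @ a[:, None]

a = torch.tensor([3, 1, 4, 1, 5])
print(sum(a))        # tensor([14])
print(sum(a).shape)  # torch.Size([1])
```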

Puzzle 3 - outer

Compute outer - the outer product of two vectors.

Solution: Compute the matrix-matrix product @ by adding singleton dimensions to the vectors. The shapes are [i x 1] @ [1 x j] = [i x j].

def outer(a: TT["i"], b: TT["j"]) -> TT["i", "j"]:
    return a[:, None] @ b[None]
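Here is a small check against `torch.outer`. Since arithmetic is allowed, plain broadcasting (`a[:, None] * b`) gives the same matrix without a matmul. The example values are arbitrary.

```python
import torch

def outer(a, b):
    # [i x 1] @ [1 x j] -> [i x j]
    return a[:, None] @ b[None]

a = torch.tensor([1, 2, 3])
b = torch.tensor([10, 20])
print(outer(a, b))
# tensor([[10, 20],
#         [20, 40],
#         [30, 60]])
print(torch.equal(outer(a, b), torch.outer(a, b)))  # True
print(torch.equal(outer(a, b), a[:, None] * b))     # True
```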

Puzzle 4 - diag

Compute diag - the diagonal vector of a square matrix.

Solution: Use arange to create a vector of indices and use it to index into the matrix along the diagonal.

def diag(a: TT["i", "i"]) -> TT["i"]:
    ind = arange(a.shape[0])
    return a[ind, ind]

Puzzle 5 - eye

Compute eye - the identity matrix.

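Solution: Use broadcasting to compare every row index with every column index. A (j, 1) column compared against a (j,) row gives a (j, j) boolean matrix that is True only on the diagonal. where then turns it into ones and zeros.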

def eye(j: int) -> TT["j", "j"]:
    return where(arange(j)[:, None] == arange(j), 1, 0)
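Both diag and eye can be checked against their PyTorch counterparts, `torch.diagonal` and `torch.eye`. The sketch asks `torch.eye` for an integer dtype so the comparison is exact.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def diag(a):
    ind = arange(a.shape[0])
    # Paired index arrays select a[0, 0], a[1, 1], ...
    return a[ind, ind]

def eye(j: int):
    # (j, 1) == (j,) broadcasts to (j, j); True only where row == column.
    return where(arange(j)[:, None] == arange(j), 1, 0)

m = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(diag(m))  # tensor([1, 5, 9])
print(torch.equal(diag(m), torch.diagonal(m)))               # True
print(torch.equal(eye(4), torch.eye(4, dtype=torch.int64)))  # True
```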

Puzzle 6 - triu

Compute triu - the upper triangular matrix.

Solution: Use broadcasting to check if the row index is less than or equal to the column index.

def triu(j: int) -> TT["j", "j"]:
    return where(arange(j)[:, None] <= arange(j), 1, 0)

Puzzle 7 - cumsum

Compute cumsum - the cumulative sum.

Solution: For each position k in the array, we want the sum a[0] + ... + a[k]. Entry (m, k) of triu is 1 exactly when m <= k, so the vector-matrix product a @ triu(n) sums precisely those entries.

def cumsum(a: TT["i"]) -> TT["i"]:
    return a @ triu(a.shape[0])
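Printing triu for a small size makes the cumsum argument concrete. Column k has ones in rows 0 through k, so a @ triu(n) picks out exactly a[0] through a[k]. The result is then compared with `torch.cumsum`.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def triu(j: int):
    # Entry (m, k) is 1 when m <= k.
    return where(arange(j)[:, None] <= arange(j), 1, 0)

def cumsum(a):
    # out[k] = sum over m of a[m] * triu[m, k] = a[0] + ... + a[k]
    return a @ triu(a.shape[0])

print(triu(3))
# tensor([[1, 1, 1],
#         [0, 1, 1],
#         [0, 0, 1]])
a = torch.tensor([3, 1, 4, 1, 5])
print(cumsum(a))  # tensor([ 3,  4,  8,  9, 14])
print(torch.equal(cumsum(a), torch.cumsum(a, dim=0)))  # True
```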

Puzzle 8 - diff

Compute diff - the running difference.

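Solution: Subtract from each element its predecessor (the array shifted by one), and keep the first value unchanged. At position 0 the index arange(i) - 1 is -1, which wraps to the last element, but where discards that value.

def diff(a: TT["i"], i: int) -> TT["i"]:
    return where(arange(i) > 0, a[arange(i)] - a[arange(i) - 1], a[arange(i)])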
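The running difference matches `torch.diff` when a 0 is prepended. That comparison also confirms that the wrapped-around value at position 0 is discarded.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def diff(a, i: int):
    # At position 0, a[-1] wraps to the last element, but where discards it.
    return where(arange(i) > 0, a[arange(i)] - a[arange(i) - 1], a[arange(i)])

a = torch.tensor([3, 1, 4, 1, 5])
print(diff(a, 5))  # tensor([ 3, -2,  3, -3,  4])
print(torch.equal(diff(a, 5), torch.diff(a, prepend=torch.tensor([0]))))  # True
```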

Puzzle 9 - vstack

Compute vstack - the matrix of two vectors

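Solution: Create the 2x1 matrix [[0], [1]] and compare it with ones(i) using broadcasting. Comparing (2, 1) against (i,) gives a (2, i) boolean matrix whose first row is all True and whose second row is all False. where then takes a for the first row and b for the second.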

def vstack(a: TT["i"], b: TT["i"]) -> TT[2, "i"]:
    return where(arange(2)[:, None] != ones(a.shape[0]), a, b)
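Here is a check against `torch.stack`, which stacks the two vectors as rows.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def ones(i: int):
    return where(arange(i) >= 0, 1, 0)

def vstack(a, b):
    # (2, 1) != (i,) -> (2, i): row 0 is all True (0 != 1), row 1 all False.
    return where(arange(2)[:, None] != ones(a.shape[0]), a, b)

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
print(vstack(a, b))
# tensor([[1, 2, 3],
#         [4, 5, 6]])
print(torch.equal(vstack(a, b), torch.stack([a, b])))  # True
```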

Puzzle 10 - roll

Compute roll - the vector shifted 1 circular position.

Note: I believe, but have not confirmed, that the solution accepted by the test_roll and roll_spec functions does not faithfully recreate the behavior of numpy.roll: no matter what shift value is provided, the accepted solution expects a shift of exactly 1. I’m not sure why this is the desired behavior, but I have provided a more faithful recreation of numpy.roll below.

Solution: Subtract the shift from each index and wrap around with the modulo operator, so that out[k] = a[(k - i) % n] for a vector of length n.

def roll(a: TT["i"], i: int) -> TT["i"]:
    return a[(arange(a.shape[0]) - i) % a.shape[0]]
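To back up the claim that this version behaves like numpy.roll, this sketch compares it with `torch.roll`, which uses the same convention. It tests shifts from -6 to 6 on a length-5 vector, including negative shifts and shifts larger than the length. The approach relies on `%` on tensors following Python's sign convention, so `(0 - 1) % 5` is 4.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def roll(a, i: int):
    # out[k] = a[(k - i) % n]; % on tensors keeps the sign of the divisor.
    return a[(arange(a.shape[0]) - i) % a.shape[0]]

a = torch.tensor([10, 20, 30, 40, 50])
for shift in range(-6, 7):
    assert torch.equal(roll(a, shift), torch.roll(a, shift))

print(roll(a, 1))   # tensor([50, 10, 20, 30, 40])
print(roll(a, -1))  # tensor([20, 30, 40, 50, 10])
```

A spec that always shifts by one position corresponds to a single fixed shift here (1 or -1, depending on direction).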

Puzzle 11 - flip

Compute flip - the reversed vector

Solution: Create a reversed index array by subtracting each index from the length minus one, i.e. position k reads from index i - 1 - k. Use this array to index into the original array.

def flip(a: TT["i"], i: int) -> TT["i"]:
    return a[i - arange(i) - 1]
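Here is a check against `torch.flip` on a small vector.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def flip(a, i: int):
    # Position k reads from index i - 1 - k.
    return a[i - arange(i) - 1]

a = torch.tensor([1, 2, 3, 4])
print(flip(a, 4))  # tensor([4, 3, 2, 1])
print(torch.equal(flip(a, 4), torch.flip(a, dims=(0,))))  # True
```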

Puzzle 12 - compress

Compute compress - keep only masked entries (left-aligned).

Note: Again, the accepted solution does not faithfully recreate the behavior of numpy.compress. The accepted solution appears to expect an output with the same shape as the input: the kept entries move to the front and the remaining positions are filled with 0. As of numpy 2.1, np.compress returns an output whose number of elements equals the number of True values in the mask.

# accepted solution
def compress(g: TT["i", bool], v: TT["i"], i:int) -> TT["i"]:
    return v @ where(g[:, None], arange(i) == cumsum(1 * g)[:, None] - 1, 0)

# np.compress faithful
def compress(g: TT["i", bool], v: TT["i"], i:int) -> TT["i"]:
    return v[arange(i)[g]]
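To make the difference between the two versions concrete, this sketch runs both on the same mask. The names `compress_accepted` and `compress_faithful` are hypothetical, chosen only so both versions can coexist in one snippet. The faithful version is checked against `torch.masked_select`.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def triu(j: int):
    return where(arange(j)[:, None] <= arange(j), 1, 0)

def cumsum(a):
    return a @ triu(a.shape[0])

def compress_accepted(g, v, i: int):
    # Kept entry m goes to slot cumsum(g)[m] - 1; rows of the (i, i) routing
    # matrix for dropped entries are zeroed, and unused slots stay 0.
    return v @ where(g[:, None], arange(i) == cumsum(1 * g)[:, None] - 1, 0)

def compress_faithful(g, v, i: int):
    # Boolean indexing keeps only the True positions, like np.compress.
    return v[arange(i)[g]]

g = torch.tensor([True, False, True, True, False])
v = torch.tensor([10, 20, 30, 40, 50])
print(compress_accepted(g, v, 5))  # tensor([10, 30, 40,  0,  0])
print(compress_faithful(g, v, 5))  # tensor([10, 30, 40])
```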

Puzzle 13 - pad_to

Compute pad_to - eliminate or add 0s to change size of vector.

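Solution: Create a non-square identity matrix of size i x j and matmul the vector of shape (i,) with it: (i,) @ (i, j) = (j,). This works because matmul treats the 1D vector as a (1, i) row and drops the added dimension from the result. When j > i, the extra columns are all zero (padding). When j < i, the rows beyond j are ignored (truncation).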

Note: here is a utility function to create a non-square identity matrix:

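def eye(i: int, j: int = None) -> TT["i", "j"]:
    # j defaults to i, so the one-argument eye(j) used by later puzzles keeps working.
    return 1 * (arange(i)[:, None] == arange(i if j is None else j))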
def pad_to(a: TT["i"], i: int, j: int) -> TT["j"]:
    return a @ (1 * (arange(i)[:, None] == arange(j)))

Note: the 1 * is necessary to convert the boolean array to an integer array.
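Running pad_to in both directions shows the padding and truncation behavior on a small example.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def pad_to(a, i: int, j: int):
    # (i,) @ (i, j) -> (j,); 1 * turns the boolean rectangular identity into ints.
    return a @ (1 * (arange(i)[:, None] == arange(j)))

a = torch.tensor([1, 2, 3])
print(pad_to(a, 3, 5))  # tensor([1, 2, 3, 0, 0])  (padded)
print(pad_to(a, 3, 2))  # tensor([1, 2])           (truncated)
print(pad_to(a, 3, 3))  # tensor([1, 2, 3])        (unchanged)
```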

Puzzle 14 - sequence_mask

Compute sequence_mask - pad out to length per batch.

Solution: The sequence_mask operation masks a batch of sequences based on their lengths. The values input is a 2D tensor of shape (batch_size, max_length) representing a batch of sequences. The length input is a 1D tensor of shape (batch_size,) holding the length of each sequence. Positions at or beyond a sequence's length are set to 0, and the output keeps the shape (batch_size, max_length).

def sequence_mask(values: TT["i", "j"], length: TT["i", int]) -> TT["i", "j"]:
    """
    max_seq_len = values.shape[1]
    arange(max_seq_len) < length[:, None]
    (j, )                 (i, 1)          -> (i, j)
    """
    return where(arange(values.shape[1]) < length[:, None], values, 0)
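A small batch with lengths 2, 4 and 0 shows the mask in action. The values are arbitrary.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def sequence_mask(values, length):
    # (j,) < (i, 1) -> (i, j): True where the position is inside the sequence.
    return where(arange(values.shape[1]) < length[:, None], values, 0)

values = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
length = torch.tensor([2, 4, 0])
print(sequence_mask(values, length))
# tensor([[1, 2, 0, 0],
#         [5, 6, 7, 8],
#         [0, 0, 0, 0]])
```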

Puzzle 15 - bincount

Compute bincount - count number of times an entry was seen.

Solution: The trick here lies with eye(j)[a], which has shape (i, j). For each value in a, it picks the corresponding row of the identity matrix, i.e. a one-hot vector. We have a guarantee that j = max(a) + 1, so eye(j) contains all of the relevant rows. We can then do a simple matmul with a ones vector, (1 x i) @ (i x j) -> (1 x j) -> (j,), to sum the rows vertically, which counts the occurrences of each value.

def bincount(a: TT["i"], j: int) -> TT["j"]:
    return ones(a.shape[0]) @ eye(j)[a]
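The result can be compared with `torch.bincount`, passing `minlength=j` so both outputs have j bins.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def ones(i: int):
    return where(arange(i) >= 0, 1, 0)

def eye(j: int):
    return where(arange(j)[:, None] == arange(j), 1, 0)

def bincount(a, j: int):
    # eye(j)[a] has one one-hot row per element; summing the rows counts values.
    return ones(a.shape[0]) @ eye(j)[a]

a = torch.tensor([0, 2, 2, 3, 0, 2])
print(bincount(a, 5))  # tensor([2, 0, 3, 1, 0])
print(torch.equal(bincount(a, 5), torch.bincount(a, minlength=5)))  # True
```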

Puzzle 16 - scatter_add

Compute scatter_add - add together values that link to the same location.

Solution: We reuse the bincount technique, except that we multiply by values instead of ones, since we want to sum values into specific positions of a new array. Row k of eye(j)[link] is the one-hot vector for link[k], so values @ eye(j)[link] adds values[k] into output position link[k].

def scatter_add(values: TT["i"], link: TT["i"], j: int) -> TT["j"]:
    return values @ eye(j)[link]
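PyTorch's `Tensor.scatter_add` performs the same accumulation, so it makes a convenient reference. The values and links below are arbitrary.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def eye(j: int):
    return where(arange(j)[:, None] == arange(j), 1, 0)

def scatter_add(values, link, j: int):
    # Row k of eye(j)[link] is one-hot at link[k], so values[k] lands there.
    return values @ eye(j)[link]

values = torch.tensor([5, 1, 7, 2, 3])
link = torch.tensor([0, 1, 0, 1, 3])
print(scatter_add(values, link, 4))  # tensor([12,  3,  0,  3])
ref = torch.zeros(4, dtype=values.dtype).scatter_add(0, link, values)
print(torch.equal(scatter_add(values, link, 4), ref))  # True
```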

Puzzle 17 - flatten

Compute flatten

Solution: This problem can be solved via indexing with multidimensional indexing arrays. According to the numpy docs, in this case, when the index arrays have a matching shape, and there is an index array for each dimension of the array being indexed, the resultant array has the same shape as the index arrays, and the values correspond to the index set for each position in the index arrays. In this case, we will use two 1D index arrays to flatten the 2D array. The first index array will be the row indices, and the second index array will be the column indices.

arange(i * j) // j -> [0, 0, ..., 0, 1, 1, ..., 1, ..., i-1, ..., i-1]   (each row index repeated j times)
arange(i * j) % j  -> [0, 1, 2, ..., j-1, 0, 1, 2, ..., j-1, ..., 0, 1, 2, ..., j-1]

So we have:

def flatten(a: TT["i", "j"], i:int, j:int) -> TT["i * j"]:
    return a[arange(i * j) // j, arange(i * j) % j]
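Here is a check against `reshape` on a 2 x 3 matrix. Both index arrays need i * j entries, one per output element.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def flatten(a, i: int, j: int):
    # Output position p reads a[p // j, p % j] for p in 0 .. i*j - 1.
    return a[arange(i * j) // j, arange(i * j) % j]

a = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(flatten(a, 2, 3))  # tensor([1, 2, 3, 4, 5, 6])
print(torch.equal(flatten(a, 2, 3), a.reshape(-1)))  # True
```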

Puzzle 18 - linspace

Compute linspace

Solution: Use arange to create the integers 0, ..., n - 1, then scale and shift them to fit the desired range: out[k] = i + (j - i) * k / (n - 1). Dividing by max(1, n - 1) avoids a division by zero when n = 1.

def linspace(i: TT[1], j: TT[1], n: int) -> TT["n", float]:
    return (j - i) * arange(n) / max(1, n - 1) + i
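Here is a comparison with `torch.linspace`, including the n = 1 edge case that the `max(1, n - 1)` guard handles. Floating-point results are compared with `torch.allclose` rather than exact equality.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def linspace(i, j, n: int):
    # out[k] = i + (j - i) * k / (n - 1); max(1, n - 1) guards n = 1.
    return (j - i) * arange(n) / max(1, n - 1) + i

i = torch.tensor([0.0])
j = torch.tensor([1.0])
print(linspace(i, j, 5))  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
print(torch.allclose(linspace(i, j, 5), torch.linspace(0.0, 1.0, 5)))  # True
print(linspace(i, j, 1))  # tensor([0.])
```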

Puzzle 19 - heaviside

Compute heaviside

Solution: The Heaviside function $H(x_1, x_2)$ is defined as 0 when $x_1 < 0$, $x_2$ when $x_1 = 0$, and 1 when $x_1 > 0$. Here a plays the role of $x_1$ and b the role of $x_2$, so we can use the where function to implement this logic.

def heaviside(a: TT["i"], b: TT["i"]) -> TT["i"]:
    return where(a == 0, b, a > 0)
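Here is a check against `torch.heaviside`, whose second argument plays the role of b (the value used where the input is zero).

```python
import torch

def where(q, a, b):
    return (q * a) + (~q) * b

def heaviside(a, b):
    # b where a == 0, otherwise 1 for positive a and 0 for negative a.
    return where(a == 0, b, a > 0)

a = torch.tensor([-1.5, 0.0, 2.0])
b = torch.tensor([0.5, 0.5, 0.5])
print(heaviside(a, b))  # tensor([0.0000, 0.5000, 1.0000])
print(torch.equal(heaviside(a, b), torch.heaviside(a, b)))  # True
```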

Puzzle 20 - repeat (1d)

Compute repeat

Solution: Take the outer product of a ones vector of length d with a. The shapes are [d x 1] @ [1 x i] = [d x i], so every row of the result is a copy of a.

def repeat(a: TT["i"], d: TT[1]) -> TT["d", "i"]:
    return ones(d)[:, None] @ a[None]
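The outer-product trick gives the same result as `Tensor.repeat`. For simplicity, this sketch passes d as a plain int rather than a one-element tensor.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def ones(i: int):
    return where(arange(i) >= 0, 1, 0)

def repeat(a, d: int):
    # [d x 1] @ [1 x i] -> [d x i]: every row is a copy of a.
    return ones(d)[:, None] @ a[None]

a = torch.tensor([1, 2, 3])
print(repeat(a, 2))
# tensor([[1, 2, 3],
#         [1, 2, 3]])
print(torch.equal(repeat(a, 3), a.repeat(3, 1)))  # True
```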

Puzzle 21 - bucketize

Compute bucketize

Solution: Use broadcasting to compare each input value with all of the boundaries. This gives an (i, j) boolean matrix indicating which boundaries each value is greater than or equal to. Summing each row gives the bucket index for that value. We can sum across the columns with a matrix-vector product against a vector of ones: (i, j) @ (j,) = (i,).

def bucketize(v: TT["i"], boundaries: TT["j"]) -> TT["i"]:
    return (1 * (v[:, None] >= boundaries)) @ ones(boundaries.shape[0])
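The >= comparison counts the boundaries that are less than or equal to each value, which matches `torch.bucketize` with `right=True`. With the default `right=False`, values that fall exactly on a boundary land one bucket lower, as the last line shows.

```python
import torch

def arange(i: int):
    return torch.tensor(range(i))

def where(q, a, b):
    return (q * a) + (~q) * b

def ones(i: int):
    return where(arange(i) >= 0, 1, 0)

def bucketize(v, boundaries):
    # (i, 1) >= (j,) -> (i, j); each row sums to the number of boundaries <= v.
    return (1 * (v[:, None] >= boundaries)) @ ones(boundaries.shape[0])

v = torch.tensor([0, 1, 3, 5, 9])
boundaries = torch.tensor([1, 3, 5, 7])
print(bucketize(v, boundaries))  # tensor([0, 1, 2, 3, 4])
print(torch.equal(bucketize(v, boundaries),
                  torch.bucketize(v, boundaries, right=True)))  # True
print(torch.bucketize(v, boundaries))  # tensor([0, 0, 1, 2, 4])
```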

General Tips and Tricks

These tensor puzzles demonstrate a variety of powerful techniques for manipulating tensors using only basic operations. Here are some key takeaways and general techniques explored:

  1. Broadcasting: Many puzzles leverage broadcasting to perform operations between tensors of different shapes. Understanding how to use [:, None] or [None, :] to add dimensions is crucial (e.g., outer, eye, triu).

  2. Indexing with arange: The arange function is frequently used to create index arrays, enabling complex slicing and rearrangement of tensor elements (e.g., diag, roll, flip).

  3. Conditional operations with where: The where function is a powerful tool for implementing conditional logic without explicit loops or if-statements (e.g., ones, compress, sequence_mask).

  4. Matrix multiplication for aggregation: Matrix multiplication (@) is used in creative ways to sum or aggregate values (e.g., sum, cumsum, bincount).

  5. Creating and using identity matrices: Several puzzles use the concept of identity matrices (or variations) for selection or mapping operations (e.g., eye, pad_to, scatter_add).

  6. Modular arithmetic: Operations like modulo (%) and integer division (//) are useful for creating repeating patterns or mapping between different shapes (e.g., roll, flatten).

  7. Comparison operations for masking: Many solutions use comparison operations to create boolean masks, which are then used for selection or filtering (e.g., triu, compress, bucketize).

  8. Clever use of cumsum: The cumulative sum operation, implemented using matrix multiplication, proves useful in problems involving running totals or inclusive scans (e.g., cumsum, compress).

  9. Dimension manipulation: Adding or removing dimensions strategically can enable powerful broadcasting operations (e.g., vstack, repeat).

  10. Linear algebra concepts: Some solutions leverage linear algebra concepts like outer products and matrix transformations (e.g., outer, linspace).

Matrix Multiplication Tricks

1. Multiplication with Identity Matrix (eye)

The identity matrix (eye) is a powerful tool in tensor manipulations. Indexing its rows, as in eye(j)[link], produces one-hot rows. Multiplying by such a matrix selects or rearranges elements, or accumulates them when several rows point to the same position.

def scatter_add(values: TT["i"], link: TT["i"], j: int) -> TT["j"]:
    return values @ eye(j)[link]

2. Multiplication with Ones Vector/Matrix

Multiplying with a vector or matrix of ones is a common technique for summing or aggregating values along certain dimensions.

def sum(a: TT["i"]) -> TT[1]:
    return ones(a.shape[0]) @ a[:, None]
def bincount(a: TT["i"], j: int) -> TT["j"]:
    return ones(a.shape[0]) @ eye(j)[a]

3. Multiplication with Upper Triangular Matrix

An upper triangular matrix can be used for cumulative operations.

def cumsum(a: TT["i"]) -> TT["i"]:
    return a @ triu(a.shape[0])

4. Outer Product

The outer product, which can be seen as a special case of matrix multiplication, is useful for creating 2D patterns from 1D inputs.

def outer(a: TT["i"], b: TT["j"]) -> TT["i", "j"]:
    return a[:, None] @ b[None]

5. Boolean Matrices

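Boolean matrices, often created through comparison operations, can be used for filtering or masking. Here the mask is applied with where. Converted to integers with 1 *, the same kind of mask can also be used in a matrix multiplication, as in compress and pad_to.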

def sequence_mask(values: TT["i", "j"], length: TT["i", int]) -> TT["i", "j"]:
    return where(arange(values.shape[1]) < length[:, None], values, 0)

These matrix multiplication tricks demonstrate how fundamental linear algebra operations can be leveraged to perform a wide range of tensor manipulations efficiently. By choosing the right matrix to multiply with (identity, ones, triangular, etc.), we can achieve various goals such as selection, aggregation, cumulative operations, and filtering, all without explicit loops or complex indexing schemes.

Conclusion

Understanding these patterns lets us write vectorized code that avoids explicit loops, which is crucial for performance in numerical computing and machine learning. These puzzles show that a wide range of complex tensor operations can be built from a small set of fundamental operations.

Moreover, these exercises highlight the importance of thinking in terms of whole-array operations rather than element-wise manipulations. This mindset is crucial for writing efficient code in libraries like NumPy, PyTorch, and TensorFlow.

By solving these puzzles, we’ve gained a deeper understanding of tensor operations and improved our ability to write concise, vectorized code for complex numerical tasks.

