GPU Puzzles

Published at 05:32 AM · 10 min read

My solutions to srush’s GPU Puzzles. The puzzles are a series of exercises that test your understanding of GPU programming by having you write small CUDA kernels. They are written in Python and use the NUMBA library to compile those kernels for the GPU.

Introduction

GPU Programming for Beginners: A Hands-On Approach with NUMBA

GPU architectures are essential to modern machine learning, but it’s possible to be an expert in the field without ever touching GPU code directly. This can make it difficult to develop intuition about how these architectures work.

This blog post introduces a hands-on approach to learning beginner GPU programming. Instead of focusing on theoretical concepts, we’ll dive straight into writing GPU kernels using NUMBA, a Python JIT compiler that maps a subset of Python code directly onto CUDA kernels.

By the end of this post, you’ll have a basic understanding of the real algorithms that power most deep learning today.

Getting Started

Before we begin, make sure you have access to a GPU. If you’re using Google Colab, you can enable GPU mode by navigating to Runtime / Change runtime type and setting Hardware accelerator to GPU.
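If you want to confirm that NUMBA can actually see the GPU before running any puzzle, a quick sanity check like the following should work (this snippet is a convenience, not part of the original puzzles):

from numba import cuda

if cuda.is_available():
    cuda.detect()  # prints the CUDA devices NUMBA can see
else:
    print("No CUDA device detected - check the runtime/hardware settings")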

Next, install the necessary libraries:

!pip install -qqq git+https://github.com/chalk-diagrams/planar git+https://github.com/danoneata/chalk@srush-patch-1
!wget -q https://github.com/srush/GPU-Puzzles/raw/main/robot.png https://github.com/srush/GPU-Puzzles/raw/main/lib.py

Import the necessary libraries:

import numba
import numpy as np
import warnings
from lib import CudaProblem, Coord

The Puzzles

Puzzle 1: Map

Problem: Implement a kernel that adds 10 to each position of a vector a and stores it in vector out. You have one thread per position.

Solution:

def map_test(cuda):
    def call(out, a) -> None:
        local_i = cuda.threadIdx.x
        out[local_i] = a[local_i] + 10

    return call

Explanation: Each thread handles exactly one element: it uses its thread index local_i to read a[local_i], adds 10, and writes the result to out[local_i].
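
The blog runs every kernel through the CudaProblem harness from lib.py; as a point of comparison, the same kernel can also be compiled and launched directly with numba.cuda. The snippet below is a standalone sketch, not how the puzzle harness itself works:

from numba import cuda
import numpy as np

# Same kernel body as map_test, compiled directly with numba.cuda
@cuda.jit
def map_kernel(out, a):
    local_i = cuda.threadIdx.x
    out[local_i] = a[local_i] + 10

SIZE = 4
a = np.arange(SIZE, dtype=np.float32)
out = np.zeros(SIZE, dtype=np.float32)
map_kernel[1, SIZE](out, a)  # launch 1 block of SIZE threads
print(out)                   # expected: [10. 11. 12. 13.]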

Puzzle 2: Zip

Problem: Implement a kernel that adds together each position of a and b and stores it in out. You have one thread per position.

Solution:

def zip_test(cuda):
    def call(out, a, b) -> None:
        local_i = cuda.threadIdx.x
        out[local_i] = a[local_i] + b[local_i]

    return call

Explanation: Similar to the previous puzzle, we use the thread index to access corresponding elements of the input arrays and perform the addition.

Puzzle 3: Guards

Problem: Implement a kernel that adds 10 to each position of a and stores it in out. You have more threads than positions.

Solution:

def map_guard_test(cuda):
    def call(out, a, size) -> None:
        local_i = cuda.threadIdx.x
        if local_i < size:
            out[local_i] = a[local_i] + 10

    return call

Explanation: We introduce a guard condition (if local_i < size) to ensure that threads with indices exceeding the size of the input array do not perform any operations.

Puzzle 4: Map 2D

Problem: Implement a kernel that adds 10 to each position of a and stores it in out. Input a is 2D and square. You have more threads than positions.

Solution:

def map_2D_test(cuda):
    def call(out, a, size) -> None:
        local_i = cuda.threadIdx.x
        local_j = cuda.threadIdx.y
        if local_i < size and local_j < size:
            out[local_i, local_j] = a[local_i, local_j] + 10

    return call

Explanation: We extend the previous puzzle to handle 2D arrays by using both cuda.threadIdx.x and cuda.threadIdx.y to access elements.

Puzzle 5: Broadcast

Problem: Implement a kernel that adds a and b and stores it in out. Inputs a and b are vectors. You have more threads than positions.

Solution:

def broadcast_test(cuda):
    def call(out, a, b, size) -> None:
        local_i = cuda.threadIdx.x
        local_j = cuda.threadIdx.y
        if local_i < size and local_j < size:
            out[local_i, local_j] = a[local_i, 0] + b[0, local_j]

    return call

Explanation: This puzzle involves broadcasting, where we add a column vector to a row vector to produce a matrix. We use the thread indices to access the appropriate elements for the addition.
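
For intuition, this is the same broadcast written in NumPy, assuming (as the kernel's indexing suggests) that a is a (size, 1) column vector and b is a (1, size) row vector:

import numpy as np

size = 4
a = np.arange(size, dtype=np.float32).reshape(size, 1)  # column vector, shape (size, 1)
b = np.arange(size, dtype=np.float32).reshape(1, size)  # row vector, shape (1, size)
out = a + b  # NumPy broadcasts both operands to a (size, size) matrix
print(out)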

Puzzle 6: Blocks

Problem: Implement a kernel that adds 10 to each position of a and stores it in out. You have fewer threads per block than the size of a.

Solution:

def map_block_test(cuda):
    def call(out, a, size) -> None:
        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        if i < size:
            out[i] = a[i] + 10

    return call

Explanation: A single block no longer covers the array, so each thread computes its global index as cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x and guards against indices that run past the end of a.
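
When choosing the launch configuration yourself, the block count is usually a ceiling division of the array size by the threads per block; a tiny illustrative calculation (not part of the puzzle harness):

TPB = 4    # threads per block
size = 9   # number of elements to process

# Ceiling division: enough blocks so that blocks * TPB covers the whole array
blocks = (size + TPB - 1) // TPB
print(blocks)  # 3 -> 12 threads launched in total; the guard masks the extra 3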

Puzzle 7: Blocks 2D

Problem: Implement the same kernel in 2D. You have fewer threads per block than the size of a in both directions.

Solution:

def map_block2D_test(cuda):
    def call(out, a, size) -> None:
        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        j = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y
        if i < size and j < size:
            out[i, j] = a[i, j] + 10

    return call

Explanation: Similar to the previous puzzle, but we now use cuda.blockIdx.y and cuda.blockDim.y to handle the y-dimension.

Puzzle 8: Shared

Problem: Implement a kernel that adds 10 to each position of a and stores it in out. You have fewer threads per block than the size of a.

Solution:

TPB = 4
def shared_test(cuda):
    def call(out, a, size) -> None:
        shared = cuda.shared.array(TPB, numba.float32)
        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        local_i = cuda.threadIdx.x

        if i < size:
            shared[local_i] = a[i]
            cuda.syncthreads()

        if i < size:
            out[i] = shared[local_i] + 10

    return call

Explanation: Each thread copies one element of a into the block’s shared memory array, the block synchronizes with cuda.syncthreads(), and then every thread reads its value back from shared memory to produce the output. Shared memory is not actually needed just to add 10, but this puzzle introduces the load / sync / compute pattern that the remaining puzzles rely on.

Puzzle 9: Pooling

Problem: Implement a kernel that sums together the last 3 positions of a and stores it in out. You have one thread per position. You only need one global read and one global write per thread.

Solution:

TPB = 8
def pool_test(cuda):
    def call(out, a, size) -> None:
        shared = cuda.shared.array(TPB, numba.float32)
        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        local_i = cuda.threadIdx.x

        if i < size:
            shared[local_i] = a[i]
            cuda.syncthreads()
            o1 = shared[local_i]
            o2 = shared[local_i - 1] if local_i > 0 else 0
            o3 = shared[local_i - 2] if local_i > 1 else 0
            out[i] = o1 + o2 + o3

    return call

Explanation: We use shared memory to store a small portion of the input array, allowing each thread to access the necessary elements for the pooling operation with only one global read.
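
As a reference for what the kernel computes, here is the pooling written directly in NumPy (a sketch of the spec, not the harness’s version):

import numpy as np

a = np.arange(8, dtype=np.float32)
out = np.zeros_like(a)
for i in range(len(a)):
    # sum of a[i] and up to two preceding elements
    out[i] = a[max(i - 2, 0) : i + 1].sum()
print(out)  # [ 0.  1.  3.  6.  9. 12. 15. 18.]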

Puzzle 10: Dot Product

Problem: Implement a kernel that computes the dot product of a and b and stores it in out. You have one thread per position. You only need two global reads and one global write per thread.

Solution:

TPB = 8
def dot_test(cuda):
    def call(out, a, b, size) -> None:
        shared = cuda.shared.array(TPB, numba.float32)

        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        local_i = cuda.threadIdx.x

        if i < size:
            shared[local_i] = a[i] * b[i]
        else:
            shared[local_i] = 0

        cuda.syncthreads()

        # Naive sequential reduction (correct, but only thread 0 does any work)
        # if local_i == 0:
        #     total_sum = 0
        #     for j in range(TPB):
        #         total_sum += shared[j]
        #     out[0] = total_sum

        # Parallel tree reduction in shared memory (the same pattern
        # reappears in Puzzle 12): halve the number of active threads each step
        s = TPB // 2
        while s > 0:
            if local_i < s:
                shared[local_i] += shared[local_i + s]
            cuda.syncthreads()
            s //= 2

        if local_i == 0:
            out[0] = shared[0]  # Store the final result

    return call

Explanation: Each thread writes one product a[i] * b[i] into shared memory (threads past the end of the arrays write 0), and the block then sums those products with a tree reduction: at every step the first half of the active threads fold in the value held s positions away, halving s until shared[0] holds the total. Because the puzzle runs in a single block, thread 0 can write the final dot product straight to out[0].
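
To see what the reduction loop does, the same halving schedule can be traced on the host; the values below are placeholders for the per-thread products (purely illustrative):

shared = [1.0] * 8   # pretend every product a[i] * b[i] equals 1
s = len(shared) // 2
while s > 0:
    # the "first s threads" each fold in a partner value s slots away
    for local_i in range(s):
        shared[local_i] += shared[local_i + s]
    s //= 2
print(shared[0])  # 8.0: the sum of all eight products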

Puzzle 11: 1D Convolution

Problem: Implement a kernel that computes a 1D convolution between a and b and stores it in out. You need to handle the general case. You only need two global reads and one global write per thread.

Solution:

MAX_CONV = 4
TPB = 8
TPB_MAX_CONV = TPB + MAX_CONV
def conv_test(cuda):
    def call(out, a, b, a_size, b_size) -> None:
        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        local_i = cuda.threadIdx.x

        # Shared memory
        shared_a = cuda.shared.array(TPB_MAX_CONV, numba.float32)
        shared_b = cuda.shared.array(MAX_CONV, numba.float32)  # Shared memory for b

        # Load into shared memory
        if i < a_size:
            shared_a[local_i] = a[i]

        if local_i < b_size:
            shared_b[local_i] = b[local_i]
        else:
            # If we are not in the last block, we need to consider
            # values that extend into the next block for computing the full
            # conv
            local_i2 = local_i - b_size
            i2 = i - b_size
            if i2 + TPB < a_size and local_i2 < b_size:
                shared_a[TPB + local_i2] = a[i2 + TPB]

        cuda.syncthreads()

        conv_sum = 0
        for k in range(b_size):
            if i + k < a_size:  # Check if within bounds
                conv_sum += shared_a[local_i + k] * shared_b[k]

        if i < a_size:
            out[i] = conv_sum

    return call

Explanation: Each block stages its slice of a in shared memory, and the threads that are not needed to load b also load the up-to-b_size extra elements of a that spill into the next block, so the windows for the last few outputs are available locally. The filter b is staged in shared memory as well. After synchronizing, each thread computes its output as the dot product of the b_size-long window of shared_a starting at its local index with shared_b, skipping positions that fall past the end of a.
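
The operation being parallelized is a one-sided 1D convolution; as a plain NumPy reference (independent of the kernel above):

import numpy as np

a = np.arange(6, dtype=np.float32)
b = np.array([1.0, 2.0, 3.0], dtype=np.float32)
out = np.zeros_like(a)
for i in range(len(a)):
    for k in range(len(b)):
        if i + k < len(a):          # positions past the end contribute nothing
            out[i] += a[i + k] * b[k]
print(out)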

Puzzle 12: Prefix Sum

Problem: Implement a kernel that computes a sum over a and stores it in out. If the size of a is greater than the block size, only store the sum of each block.

Solution:

TPB = 8
def sum_test(cuda):
    def call(out, a, size: int) -> None:
        cache = cuda.shared.array(TPB, numba.float32)
        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        local_i = cuda.threadIdx.x

        if i < size:
            cache[local_i] = a[i]
            cuda.syncthreads()

            for k in range(3):
                p = 2 ** k
                if local_i % (p * 2) == 0:
                    if i + p < size:
                        cache[local_i] = cache[local_i] + cache[local_i + p]
                cuda.syncthreads()
            if local_i == 0:
                out[cuda.blockIdx.x] = cache[local_i]

    return call

Explanation: Each block copies its chunk of a into shared memory and runs log2(TPB) = 3 rounds of pairwise summation: in round k, threads whose local index is a multiple of 2^(k+1) add in the value 2^k positions to their right. After the last round, cache[0] holds the block’s partial sum, which thread 0 writes to out[blockIdx.x].
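
For reference, the per-block output this kernel produces corresponds to the following NumPy computation (an illustrative sketch using the puzzle’s TPB = 8):

import numpy as np

TPB = 8
a = np.arange(15, dtype=np.float32)        # size not a multiple of TPB
blocks = (len(a) + TPB - 1) // TPB
out = np.array([a[b * TPB:(b + 1) * TPB].sum() for b in range(blocks)])
print(out)  # one partial sum per block; summing `out` gives the full total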

Puzzle 13: Axis Sum

Problem: Implement a kernel that computes a sum over each column of a and stores it in out.

Solution:

TPB = 8
def axis_sum_test(cuda):
    def call(out, a, size: int) -> None:
        cache = cuda.shared.array(TPB, numba.float32)
        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        local_i = cuda.threadIdx.x
        batch = cuda.blockIdx.y
        if i < size:
            cache[local_i] = a[batch, i]
            cuda.syncthreads()
            for k in range(3):
                p = 2 ** k
                if local_i % (p * 2) == 0:
                    if i + p < size:
                        cache[local_i] = cache[local_i] + cache[local_i + p]
                cuda.syncthreads()
            if local_i == 0:
                out[batch, 0] = cache[local_i]

    return call

Explanation: This reuses the block reduction from Puzzle 12, but the grid gains a y dimension: cuda.blockIdx.y selects which batch (row of a) the block is reducing, each thread loads a[batch, i], and after the reduction thread 0 writes that batch’s sum to out[batch, 0].
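
In NumPy terms, the kernel above computes a sum along the last axis for every batch index, i.e. something like this sketch:

import numpy as np

a = np.arange(12, dtype=np.float32).reshape(3, 4)  # 3 batches of 4 values each
out = a.sum(axis=1, keepdims=True)                 # shape (3, 1): one sum per batch
print(out)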

Puzzle 14: Matrix Multiply!

Problem: Implement a kernel that multiplies square matrices a and b and stores the result in out.

Solution:

TPB = 3
def mm_oneblock_test(cuda):
    def call(out, a, b, size: int) -> None:
        a_shared = cuda.shared.array((TPB, TPB), numba.float32)
        b_shared = cuda.shared.array((TPB, TPB), numba.float32)

        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        j = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y
        local_i = cuda.threadIdx.x
        local_j = cuda.threadIdx.y

        acc = 0
        for k in range(0, size, TPB):
            # Load a and b into shared memory
            if i < size and k + local_j < size:
                a_shared[local_i, local_j] = a[i, k + local_j]
            if j < size and k + local_i < size:
                b_shared[local_i, local_j] = b[k + local_i, j]
            cuda.syncthreads()

            for local_k in range(min(TPB, size - k)):
                acc += a_shared[local_i, local_k] * b_shared[local_k, local_j]

            # Wait until every thread is done with this tile before the next
            # iteration overwrites the shared arrays
            cuda.syncthreads()
        if i < size and j < size:
            out[i, j] = acc

    return call

Explanation:

This puzzle challenges you to implement matrix multiplication, a core operation in linear algebra and deep learning. The solution utilizes shared memory to store blocks of the input matrices, allowing each thread to efficiently compute partial results and ultimately contribute to the final product.
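
The tiling idea, accumulating the product one TPB-wide slice of the inner dimension at a time, can be mimicked on the host with NumPy. This is only an analogy for the shared-memory tiles, not GPU code:

import numpy as np

TPB = 3
size = 5
a = np.random.rand(size, size).astype(np.float32)
b = np.random.rand(size, size).astype(np.float32)

# Accumulate one TPB-wide slice of the inner dimension per iteration,
# mirroring what the kernel does with one pair of shared-memory tiles
out = np.zeros((size, size), dtype=np.float32)
for k in range(0, size, TPB):
    out += a[:, k:k + TPB] @ b[k:k + TPB, :]

np.testing.assert_allclose(out, a @ b, rtol=1e-5)
print("tiled accumulation matches a @ b")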

Conclusion

Through these puzzles, we’ve explored fundamental concepts in GPU programming using NUMBA. By directly coding GPU kernels, we’ve gained valuable intuition about how these architectures work and how to optimize code for performance.

NUMBA provides a user-friendly interface for writing CUDA kernels, making it an excellent tool for beginners to learn GPU programming. With further exploration and practice, you can leverage these concepts to accelerate your deep learning models and tackle complex computational tasks.

