1. Use Shared Memory

In the previous tutorial, every thread block loaded its tiles of A and B directly from global memory into registers. Global memory has high latency, so a natural next step is to stage the data through shared memory – a small, fast on-chip scratchpad that is visible to all threads in the same block.

Why shared memory helps

Shared memory is much faster than global memory, with roughly an order of magnitude or more lower latency on typical GPUs. By first copying a tile from global memory into shared memory, all threads in the block can then read the shared copy at high bandwidth. This is especially beneficial when the same data is read multiple times by different threads, as is the case in matrix multiplication, where every element of a tile participates in many multiply-accumulate operations.
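A back-of-envelope estimate makes the reuse concrete. With the default tile sizes used by the kernel below (block_m = block_n = 64, block_k = 16), each iteration stages 2048 elements into shared memory but performs 65536 multiply-accumulates, so every staged element is read 32 times on average:

```python
# Data-reuse estimate for one inner-loop iteration (illustrative arithmetic only)
block_m, block_n, block_k = 64, 64, 16

# elements staged into shared memory per iteration (one tile of A, one of B)
loads = block_m * block_k + block_k * block_n
# multiply-accumulate operations performed on those elements per iteration
macs = block_m * block_n * block_k

reuse = macs / loads  # average number of reads per staged element
print(loads, macs, reuse)
```

The higher this reuse factor, the more the cost of the global-memory load is amortized.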

The data flow for each iteration of the inner loop becomes:

  1. Load tiles from global memory into register tensors.

  2. Store those register tensors into shared memory (store_shared).

  3. Synchronize all threads so the shared data is fully written.

  4. Load from shared memory back into registers (load_shared).

  5. Compute the dot product.

  6. Synchronize again before the next iteration overwrites shared memory.
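The same data flow can be emulated on the CPU with NumPy to check the arithmetic. This is a numerical sketch only: it mirrors the tiling and accumulation pattern, not the tilus API, and slicing stands in for the global-to-shared staging.

```python
import numpy as np

# Small sizes chosen so the emulation runs instantly; shapes are multiples of the tiles.
m, n, k = 128, 128, 64
block_m, block_n, block_k = 64, 64, 16

rng = np.random.default_rng(0)
a = rng.random((m, k), dtype=np.float32)
b = rng.random((k, n), dtype=np.float32)
c = np.zeros((m, n), dtype=np.float32)

for om in range(0, m, block_m):          # one (om, on) pair per "thread block"
    for on in range(0, n, block_n):
        acc = np.zeros((block_m, block_n), dtype=np.float32)
        for ok in range(0, k, block_k):  # the inner loop over K
            sa = a[om:om + block_m, ok:ok + block_k]  # steps 1-2: stage tile of A
            sb = b[ok:ok + block_k, on:on + block_n]  # steps 1-2: stage tile of B
            acc += sa @ sb                            # step 5: dot product
        c[om:om + block_m, on:on + block_n] = acc

print(np.allclose(c, a @ b, atol=1e-4))
```

The synchronization steps (3 and 6) have no CPU analogue here because NumPy runs single-threaded; they matter only once many threads share the staged tiles.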

Kernel implementation

MatmulV1 – matrix multiplication with shared memory
class MatmulV1(tilus.Script):
    def __init__(self, num_warps=4, block_m=64, block_n=64, block_k=16):
        super().__init__()
        self.num_warps = num_warps
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k

    def __call__(
        self,
        m_size: int32,
        n_size: int32,
        k_size: int32,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [
            cdiv(m_size, self.block_m),
            cdiv(n_size, self.block_n),
        ]
        self.attrs.warps = self.num_warps

        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size])

        # allocate shared memory for the tiles for A and B
        sa = self.shared_tensor(dtype=float16, shape=[self.block_m, self.block_k])
        sb = self.shared_tensor(dtype=float16, shape=[self.block_k, self.block_n])

        acc = self.register_tensor(
            dtype=float32, shape=[self.block_m, self.block_n], init=0.0
        )

        for offset_k in range(0, k_size, self.block_k):
            # load a tile of the A matrix from global memory into registers
            lda = self.load_global(
                ga,
                offsets=[offset_m, offset_k],
                shape=[self.block_m, self.block_k],
            )

            # store the loaded tile in shared memory
            self.store_shared(sa, lda)

            # load a tile of the B matrix from global memory into registers
            ldb = self.load_global(
                gb,
                offsets=[offset_k, offset_n],
                shape=[self.block_k, self.block_n],
            )

            # store the loaded tile in shared memory
            self.store_shared(sb, ldb)

            # synchronize threads to ensure all have stored their data in shared memory
            self.sync()

            # load the tiles from shared memory to registers
            a = self.load_shared(sa)
            b = self.load_shared(sb)

            acc = self.dot(a, b, acc)
            self.sync()

        self.free_shared(sa)
        self.free_shared(sb)

        casted_acc = self.cast(acc, dtype=float16)
        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(gc, casted_acc, offsets=[offset_m, offset_n])

The kernel follows the same overall structure as MatmulV0, with two key additions: shared memory tiles and explicit synchronization.

Shared memory tiles

At the top of the __call__ method we allocate two shared tensors, sa and sb, to hold the current tiles of A and B respectively:

sa = self.shared_tensor(dtype=float16, shape=[self.block_m, self.block_k])
sb = self.shared_tensor(dtype=float16, shape=[self.block_k, self.block_n])
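With the default tile sizes this is a modest footprint. A quick calculation (the 64/16 values are the defaults from __init__, and float16 occupies 2 bytes per element):

```python
block_m, block_n, block_k = 64, 64, 16
bytes_per_elem = 2  # float16

# bytes of shared memory held by sa and sb simultaneously
smem_bytes = (block_m * block_k + block_k * block_n) * bytes_per_elem
print(smem_bytes)
```

At 4 KB total, the two tiles fit comfortably within the shared-memory budget of any modern GPU, leaving headroom for larger tiles or additional buffers in later versions.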

Inside the loop, data flows through shared memory before reaching the accumulator:

# global -> registers -> shared memory
lda = self.load_global(ga, offsets=[offset_m, offset_k], shape=[self.block_m, self.block_k])
self.store_shared(sa, lda)

# ... same for B ...

self.sync()  # ensure all stores to shared memory are visible

# shared memory -> registers
a = self.load_shared(sa)
b = self.load_shared(sb)

acc = self.dot(a, b, acc)
self.sync()  # ensure dot is done before next iteration overwrites shared

At the end, both shared tensors are freed so their memory can be reused:

self.free_shared(sa)
self.free_shared(sb)

Why two synchronizations?

Both self.sync() calls are necessary because sa and sb are reused across loop iterations:

  • The first sync (after store_shared) guarantees that every thread has finished writing its portion of the tile into shared memory before any thread tries to read from it via load_shared.

  • The second sync (after dot) guarantees that the dot product has finished reading from shared memory before the next iteration overwrites it with new data.

Omitting either synchronization leads to a data race.
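The same two-barrier pattern can be illustrated on the CPU with Python threads. This is an analogy only, not the tilus API: a plain list stands in for the shared tile, and threading.Barrier plays the role of sync().

```python
import threading

NUM_THREADS, ITERS = 4, 3
barrier = threading.Barrier(NUM_THREADS)
shared = [0] * NUM_THREADS    # stands in for the shared-memory tile
results = [0] * NUM_THREADS

def worker(tid):
    for it in range(ITERS):
        shared[tid] = it * NUM_THREADS + tid  # "store_shared": write my portion
        barrier.wait()   # first sync: all writes visible before anyone reads
        results[tid] = sum(shared)  # "load_shared" + compute: reads every slot
        barrier.wait()   # second sync: all reads done before the next overwrite

threads = [threading.Thread(target=worker, args=(t,)) for t in range(NUM_THREADS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)
```

Remove the first barrier and a thread may sum slots that have not been written yet for this iteration; remove the second and a fast thread's write for iteration it+1 can corrupt a slow thread's sum for iteration it. Both failures are data races of exactly the kind the two sync() calls prevent.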

New instructions

This example introduces five new Script methods compared to MatmulV0:

  • shared_tensor() – allocate a shared-memory tensor with a given dtype and shape.

  • store_shared() – copy a register tensor into a shared tensor.

  • load_shared() – copy a shared tensor into a new register tensor.

  • free_shared() – release the shared memory so it can be reused. Every allocation must be freed before the kernel ends.

  • sync() – synchronize all threads in the thread block (equivalent to __syncthreads() in CUDA C).

Launching the kernel

The launch code is identical in structure to the previous version – create an instance, prepare tensors, and call the script:

Benchmarking MatmulV1
def main():
    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    workloads = [
        [4096, 4096, 4096],
    ]

    rows = []
    for m, n, k in workloads:
        matmul = MatmulV1()

        a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        c_actual = torch.empty(m, n, dtype=torch.float16).cuda()
        c_expect = a @ b
        matmul(m, n, k, a, b, c_actual)

        torch.cuda.synchronize()

        # check correctness
        torch.testing.assert_close(c_expect, c_actual)

        # benchmark
        for name, func in [
            ("torch", lambda: torch.matmul(a, b, out=c_expect)),
            ("tilus", lambda: matmul(m, n, k, a, b, c_actual)),
        ]:
            latency = benchmark_func(func, warmup=5, repeat=20)
            tflops = 2 * m * n * k / latency * 1e-9
            rows.append([m, n, k, name, latency, tflops])

    df = pandas.DataFrame(rows, columns=headers)
    print(df)
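As a sanity check on the tflops expression in the benchmark loop: a matrix multiplication performs 2*m*n*k floating-point operations (one multiply and one add per contribution to an output element), and latency is measured in milliseconds. The sketch below uses a hypothetical latency value of 1 ms:

```python
m = n = k = 4096
flops = 2 * m * n * k   # one multiply and one add per output contribution

latency_ms = 1.0        # hypothetical latency, for illustration only
# converting ms -> s (x 1e-3) and FLOP -> TFLOP (x 1e-12) combines
# into the single 1e-9 factor used in the benchmark loop
tflops = flops / latency_ms * 1e-9
print(round(tflops, 3))
```

So a 4096-cubed multiplication finishing in 1 ms corresponds to roughly 137 TFLOPS; the benchmark table reports this figure for both the torch and tilus runs.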

Full source

The complete example is available at examples/matmul/matmul_v1.py.