0. A Minimal Blackwell Matmul

This first version implements a minimal but correct matrix multiplication kernel on Blackwell GPUs. It introduces two key Blackwell features: Tensor Memory together with the 5th-generation Tensor Cores (tcgen05), and asynchronous barriers (mbarrier).

The kernel is not yet fast (we will optimize it step by step in later versions), but it establishes the foundation for everything that follows.

The Full Kernel

Before diving into the details, here is the complete kernel so you can see the big picture. We will explain each part in the sections that follow.

Hint

To view the generated CUDA source code, check the cache directory. See Cache for details.

BlackwellMatmulV0 — full kernel
@tilus.autotune("block_m, block_n", [[128, 64], [128, 128], [128, 256]])
@tilus.autotune("block_k", [16, 32, 64])
class BlackwellMatmulV0(tilus.Script):
    def __init__(self, block_m: int, block_n: int, block_k: int):
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k

    def __call__(
        self,
        m_size: int32,
        n_size: int,
        k_size: int,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        # set the number of blocks and warps for the kernel launch
        self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)]
        self.attrs.warps = 4

        # compute the tile offset from the block index
        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        # create global tensor views from raw pointers
        g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])

        # allocate shared memory tiles for A and B
        s_a = self.shared_tensor(dtype=float16, shape=[self.block_m, self.block_k])
        s_b = self.shared_tensor(dtype=float16, shape=[self.block_n, self.block_k])

        # allocate a tensor in tensor memory (tmem) as the MMA accumulator
        t_acc = self.tcgen05.alloc(dtype=float32, shape=[self.block_m, self.block_n])

        # allocate one mbarrier to track MMA completion
        mbarriers = self.mbarrier.alloc(counts=[1])

        # mbarrier phase flips between 0 and 1 after each wait
        phase: uint32 = 0

        # synchronize all threads before entering the main loop
        self.sync()

        for offset_k in range(0, k_size, self.block_k):
            # async copy tiles from global to shared memory (legacy, non-TMA)
            self.copy_async(src=g_a, dst=s_a, offsets=[offset_m, offset_k])
            self.copy_async(src=g_b, dst=s_b, offsets=[offset_n, offset_k])
            self.copy_async_wait_all()
            self.sync()

            # tcgen05 instructions are warp-cooperative (issued by a single warp)
            with self.single_warp():
                # D = A @ B (first iter) or D = A @ B + D (subsequent iters)
                self.tcgen05.mma(
                    s_a, s_b.transpose(), t_acc, enable_input_d=offset_k != 0
                )
                # make the mbarrier track completion of prior async tcgen05 ops
                self.tcgen05.commit(mbarrier=mbarriers[0])
                # wait until the MMA writes to tmem are complete
                self.mbarrier.wait(mbarriers[0], phase=phase)
            self.sync()

            phase ^= 1

        # load the result from tensor memory to registers
        r_acc = self.tcgen05.load(t_acc)

        # cast to float16 and store to global memory
        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])

        # all allocated tensor memory must be deallocated before kernel exits
        self.sync()
        self.tcgen05.dealloc(t_acc)

Block Tiling

We compute \(C = A \times B^T\) where A is (M, K) and B is (N, K). The output matrix C is (M, N).

Each thread block is responsible for computing one block_m x block_n tile of C. The K dimension is iterated in chunks of block_k.
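The tiling arithmetic above can be sketched in plain Python (the helper names `cdiv` and `tile_origin` are illustrative; `cdiv` mirrors the one used in the kernel):

```python
# Sketch of the block-tiling arithmetic: each thread block owns one
# block_m x block_n tile of C, located by its block index.

def cdiv(a: int, b: int) -> int:
    """Ceiling division: number of tiles needed to cover a dimension."""
    return (a + b - 1) // b

def tile_origin(block_idx_x: int, block_idx_y: int, block_m: int, block_n: int):
    """Top-left corner (offset_m, offset_n) of the C tile for one block."""
    return block_m * block_idx_x, block_n * block_idx_y

m_size, n_size = 8192, 8192
block_m, block_n = 128, 128
grid = (cdiv(m_size, block_m), cdiv(n_size, block_n))   # 64 x 64 blocks
offset_m, offset_n = tile_origin(3, 5, block_m, block_n)  # (384, 640)
```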

../../_images/v0_block_tiling.svg

Block tiling of the matmul. Each thread block computes one output tile. The hatched regions show the full slices of A and \(B^T\) that participate in computing the highlighted C tile.

Note

Data layout: K-major. Blackwell tensor cores expect operands in one of two shared memory layouts: MN-major or K-major. This tutorial uses K-major (K is the contiguous dimension), so A is [M, K] and B is [N, K]. tcgen05.mma expects logical shapes [M, K] and [K, N], which is why we call s_b.transpose(), a view operation that reinterprets the layout without copying data.
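The idea that `transpose()` is a view, not a copy, can be illustrated outside Tilus with a strided indexing sketch (plain Python, hypothetical `make_view` helper): swapping the strides reinterprets the same flat buffer with a transposed logical shape.

```python
# Illustration (not Tilus code): transposing a K-major [N, K] tile can be a
# pure view change. The same flat buffer is addressed with swapped strides,
# so no element moves in memory.

def make_view(strides):
    """Return an indexing function (i, j) -> flat offset for a strided view."""
    def index(i, j):
        return i * strides[0] + j * strides[1]
    return index

N, K = 4, 8
buf = [n * K + k for n in range(N) for k in range(K)]  # row-major [N, K] data

b_view = make_view((K, 1))   # logical [N, K]: K is the contiguous dimension
bt_view = make_view((1, K))  # logical [K, N]: strides swapped, same storage

# Element (k, n) of B^T is element (n, k) of B at the same flat offset.
assert buf[bt_view(2, 3)] == buf[b_view(3, 2)]
```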

Data Flow

Triton also uses Blackwell hardware features such as TMA and tcgen05, but it manages them automatically through compiler passes. Tilus opens the black box: you control memory placement, data movement, and synchronization directly, which is necessary for achieving peak performance. The kernel moves data through four memory levels:

../../_images/v0_data_flow.svg

Data flow in the kernel: Global Memory → Shared Memory → Tensor Memory → Registers → Global Memory.

  1. Global → Shared: copy_async() loads tiles of A and B from global memory into shared memory asynchronously. This uses the legacy async copy mechanism (Ampere-era); we will replace it with the hardware TMA engine in the next version.

  2. Shared → Tensor Memory: tcgen05.mma() reads operands from shared memory and accumulates the result in tensor memory.

  3. Tensor Memory → Register: tcgen05.load() moves the accumulated result from tensor memory to registers.

  4. Register → Global: store_global() writes the final result back to global memory.

Tensor Memory (TMEM)

On pre-Blackwell architectures, MMA results are accumulated in the register file, which creates two problems: the accumulator tiles consume a large number of registers per thread (e.g., 128 registers just for one Hopper GMMA N=256 result), and the MMA traffic monopolizes register file bandwidth, stalling other work. Blackwell introduces Tensor Memory, a dedicated on-chip memory private to the SM’s tensor cores. By moving the accumulator into TMEM, the register file is freed for epilogue and load instructions to run concurrently with the tensor cores. If you have used Triton, tensor memory replaces Triton’s implicit register-based accumulation; here you manage the allocation and lifecycle explicitly.

Tensor Memory is organized as a 2D structure of 128 lanes (rows) and 512 columns per CTA, with each cell being 32 bits. Memory is allocated in units of 32 columns.
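As a rough sizing sketch (assuming one 32-bit accumulator element per cell and block_m ≤ 128 lanes; the actual hardware mapping may differ), the columns consumed by an accumulator tile round up to the 32-column allocation unit:

```python
# Rough TMEM sizing sketch: a block_m x block_n fp32 accumulator occupies
# one 32-bit cell per element per lane, allocated in units of 32 columns.

TMEM_LANES = 128
TMEM_COLUMNS = 512
ALLOC_UNIT = 32  # allocation granularity in columns

def tmem_columns(block_n: int, elem_bits: int = 32) -> int:
    """Columns needed per lane for block_n elements, rounded up to the
    32-column allocation granularity."""
    raw = block_n * elem_bits // 32   # 32-bit cells per lane
    return -(-raw // ALLOC_UNIT) * ALLOC_UNIT  # ceil to multiple of 32

# A 128 x 256 fp32 accumulator fills exactly half of the 512 columns.
assert tmem_columns(256) == 256
assert tmem_columns(40) == 64   # rounded up to a multiple of 32
```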

../../_images/tmem_layout.svg

Tensor Memory layout: 128 lanes x 512 columns, each cell 32 bits.

In Tilus, the lifecycle of a tensor memory allocation is:

  1. tcgen05.alloc() — allocate a TMemoryTensor in tensor memory.

  2. tcgen05.mma() — use it as the accumulator in MMA operations.

  3. tcgen05.load() — read the result out to a RegisterTensor.

  4. tcgen05.dealloc() — free the allocation (required before the kernel exits).

For more details, see Script.tcgen05.

Asynchronous Barriers (mbarrier)

In Triton, synchronization is handled implicitly. On Blackwell, many operations are asynchronous: the instruction returns immediately and the work completes in the background. This enables overlapping data movement with computation, but requires explicit tracking of when operations finish. This is the role of the mbarrier (memory barrier, see Script.mbarrier).

../../_images/mbarrier_state.svg

An mbarrier tracks pending arrivals and a phase bit.

An mbarrier is a 64-bit synchronization object in shared memory that tracks two things:

  • Pending arrivals: how many arrivals are still outstanding before the current phase completes. Each mbarrier.arrive() call (from a thread or a completing async operation) decrements this count.

  • Phase (1 bit): flips between 0 and 1 each time a phase completes.

Note

mbarriers can also track asynchronous byte transactions (tx-count) for TMA transfers. We will introduce this in the next version.

A phase completes when the pending-arrival count reaches zero. At that point, the hardware automatically flips the phase bit and resets the count for the next phase.

Wait checks the phase: mbarrier.wait(barrier, phase=p) blocks until the barrier’s current phase differs from p. When the phase has flipped, the tracked operations are guaranteed to have completed.

Why flip the phase? The same barrier is reused across loop iterations. The phase bit distinguishes “this iteration completed” from “the previous iteration completed.” After each wait, the caller flips its local phase variable (phase ^= 1) so the next wait targets the new phase:

phase: uint32 = 0               # start expecting phase 0
for ...:
    ...                          # do async work
    # arrive on the barrier (e.g., via tcgen05.commit)
    ...
    self.mbarrier.wait(barrier, phase=phase)  # wait for current phase
    phase ^= 1                  # next iteration waits for the other phase

In this kernel, we use an mbarrier to track when tcgen05.mma() has finished writing to tensor memory. tcgen05.commit() groups all prior async tcgen05 operations issued by the current warp (such as tcgen05.mma() and tcgen05.copy()) and signals one arrival (arrival count = 1) on the given mbarrier when those operations complete.
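The arrival-count and phase mechanics can be captured in a toy Python model (for intuition only; `ToyMbarrier` is not part of the Tilus API, and a real wait spins on hardware state rather than returning a boolean):

```python
# Toy model of mbarrier semantics: arrivals decrement a counter; when the
# counter hits zero the phase bit flips and the counter resets.

class ToyMbarrier:
    def __init__(self, expected_arrivals: int):
        self.expected = expected_arrivals
        self.pending = expected_arrivals
        self.phase = 0

    def arrive(self):
        self.pending -= 1
        if self.pending == 0:             # phase completes
            self.phase ^= 1               # hardware flips the phase bit...
            self.pending = self.expected  # ...and resets the count

    def try_wait(self, phase: int) -> bool:
        """wait(phase=p) succeeds once the barrier's phase differs from p."""
        return self.phase != phase

bar = ToyMbarrier(expected_arrivals=1)
phase = 0
assert not bar.try_wait(phase)  # MMA not yet committed -> wait would block
bar.arrive()                    # e.g., tcgen05.commit signals completion
assert bar.try_wait(phase)      # phase flipped -> wait returns
phase ^= 1                      # caller tracks the new phase for next time
```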

Thread Groups

By default, every instruction in a Tilus kernel operates on the entire thread block: the __call__ body defines the behavior of all threads in the block collectively. However, efficient matrix multiplication kernels on Hopper and Blackwell architectures require different warps to perform different jobs and collaborate with each other asynchronously. To narrow the execution scope to a subset of threads, Tilus provides thread groups.

A thread group selects a subset of threads within the block using thread_group(). For example:

with self.thread_group(thread_begin=0, num_threads=32):
    # Only threads 0-31 (one warp) execute this
    ...

with self.thread_group(thread_begin=32, num_threads=32):
    # Only threads 32-63 execute this
    ...

Tilus also provides shortcuts for common patterns: single_thread() for one thread, single_warp() for one warp (32 threads), and warp_group() for multiple warps.

Note that Tilus does not expose threadIdx to the user. There is no way to write if threadIdx.x < 32 in a Tilus program. Instead, use thread_group() and its shortcuts to narrow the execution scope.

Every Tilus instruction has a requirement on the thread group it can execute in. Some instructions work in any thread group, while others require a single thread, a single warp, or a warp group. In this kernel, tcgen05.mma() and tcgen05.commit() require the thread group to be a single warp. We use single_warp() to narrow the execution scope:

with self.single_warp():
    # D = A @ B (first iter) or D = A @ B + D (subsequent iters)
    self.tcgen05.mma(
        s_a, s_b.transpose(), t_acc, enable_input_d=offset_k != 0
    )
    # make the mbarrier track completion of prior async tcgen05 ops
    self.tcgen05.commit(mbarrier=mbarriers[0])
    # wait until the MMA writes to tmem are complete
    self.mbarrier.wait(mbarriers[0], phase=phase)

For more details, see Thread Group.

Walkthrough

With the key Blackwell features covered above (tensor memory, asynchronous barriers, and thread groups), let us now walk through the kernel code in detail.

A Tilus kernel is defined as a subclass of Script. The __init__ method stores compile-time hyperparameters (tile sizes), and __call__ describes the kernel logic. For more on the script structure, see Tilus Script.

The @tilus.autotune decorators define a search space for compile-time hyperparameters. Tilus benchmarks all combinations and picks the fastest configuration automatically. For more on autotuning, see Autotuning.

Kernel Setup

Kernel setup
# set the number of blocks and warps for the kernel launch
self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)]
self.attrs.warps = 4

# compute the tile offset from the block index
offset_m: int32 = self.block_m * self.blockIdx.x
offset_n: int32 = self.block_n * self.blockIdx.y

# create global tensor views from raw pointers
g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])

# allocate shared memory tiles for A and B
s_a = self.shared_tensor(dtype=float16, shape=[self.block_m, self.block_k])
s_b = self.shared_tensor(dtype=float16, shape=[self.block_n, self.block_k])

# allocate a tensor in tensor memory (tmem) as the MMA accumulator
t_acc = self.tcgen05.alloc(dtype=float32, shape=[self.block_m, self.block_n])

# allocate one mbarrier to track MMA completion
mbarriers = self.mbarrier.alloc(counts=[1])

# mbarrier phase flips between 0 and 1 after each wait
phase: uint32 = 0

# synchronize all threads before entering the main loop
self.sync()
  • self.attrs.blocks sets the grid dimensions: ceil(M / block_m) x ceil(N / block_n) thread blocks.

  • self.attrs.warps sets the number of warps per block. Here we use 4 warps (128 threads). Tensor memory has 128 lanes, and tcgen05.load() maps each thread to one lane: warp 0 covers lanes 0–31, warp 1 covers lanes 32–63, and so on. Four warps are needed to cover all 128 lanes and load the complete accumulator tile. Later versions will use more warps to overlap loading and computing.

  • offset_m and offset_n are the output tile offsets, computed from the block index (blockIdx).

  • global_view() interprets the raw pointers as 2D global memory tensors with the given dtype and shape.

  • shared_tensor() allocates shared memory tiles for staging A and B data.

  • tcgen05.alloc() allocates the tensor memory accumulator.

  • mbarrier.alloc() allocates one mbarrier with an expected arrival count of 1, because tcgen05.commit() will perform exactly one arrival on the barrier when the MMA completes.

  • sync() ensures all preceding instructions (such as the mbarrier and shared memory allocations) have completed before execution proceeds. Without it, subsequent instructions could execute concurrently with the allocations, leading to use of uninitialized resources. This is lowered to __syncthreads() in CUDA.
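The warp-to-lane coverage described above can be sketched in plain Python (assumed mapping: thread t in a 128-thread block covers lane t, so warp w covers lanes w*32 through w*32+31):

```python
# Sketch of the thread -> TMEM lane coverage: four warps of 32 threads
# together cover all 128 tensor memory lanes.

WARP_SIZE = 32
num_warps = 4

def lane_of(thread_id: int) -> int:
    """Assumed identity mapping from thread index to TMEM lane."""
    return thread_id

lanes = {lane_of(t) for t in range(num_warps * WARP_SIZE)}
assert lanes == set(range(128))   # 4 warps cover all 128 lanes
assert lane_of(32) == 32          # first thread of warp 1 -> lane 32
```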

Main Loop

Main loop
for offset_k in range(0, k_size, self.block_k):
    # async copy tiles from global to shared memory (legacy, non-TMA)
    self.copy_async(src=g_a, dst=s_a, offsets=[offset_m, offset_k])
    self.copy_async(src=g_b, dst=s_b, offsets=[offset_n, offset_k])
    self.copy_async_wait_all()
    self.sync()

    # tcgen05 instructions are warp-cooperative (issued by a single warp)
    with self.single_warp():
        # D = A @ B (first iter) or D = A @ B + D (subsequent iters)
        self.tcgen05.mma(
            s_a, s_b.transpose(), t_acc, enable_input_d=offset_k != 0
        )
        # make the mbarrier track completion of prior async tcgen05 ops
        self.tcgen05.commit(mbarrier=mbarriers[0])
        # wait until the MMA writes to tmem are complete
        self.mbarrier.wait(mbarriers[0], phase=phase)
    self.sync()

    phase ^= 1

In each iteration:

  • copy_async() loads a block_m x block_k tile of A and a block_n x block_k tile of B from global to shared memory. copy_async_wait_all() waits for all outstanding copies, and sync() ensures all threads see the shared memory writes before the MMA warp reads them.

  • single_warp() narrows the execution scope to one warp. tcgen05.mma() multiplies the two tiles and accumulates into tensor memory. enable_input_d=offset_k != 0 controls whether the existing accumulator value is used as an addend. The MMA computes \(D = A \times B + D\) when enabled, or \(D = A \times B\) when disabled. On the first K-iteration (offset_k == 0), tensor memory contains uninitialized data, so we disable the addend. On subsequent iterations, the accumulator holds the running sum from prior tiles, so we enable it to accumulate.

  • tcgen05.commit() groups all prior async tcgen05 operations issued by the current warp and signals one arrival on the mbarrier when they complete. mbarrier.wait() blocks until the MMA writes to tensor memory are complete.

  • sync() after the single_warp block ensures all threads reconverge before the next iteration.

  • phase ^= 1 flips the local phase so the next mbarrier.wait targets the new phase of the reused barrier.
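The enable_input_d logic can be checked against a plain-Python reference of the K-loop (small lists-of-lists standing in for tiles; `matmul_tile` is an illustrative name): the accumulator starts uninitialized, so the first chunk must overwrite rather than add.

```python
# Reference model of the K-loop accumulation: the addend is disabled only on
# the first chunk, because the accumulator holds garbage before the first
# MMA writes it.

def matmul_tile(a_rows, b_rows, block_k=2):
    """C = A @ B^T for small lists-of-lists A (m x k) and B (n x k)."""
    m, k, n = len(a_rows), len(a_rows[0]), len(b_rows)
    c = [[float("nan")] * n for _ in range(m)]  # "uninitialized" accumulator
    for offset_k in range(0, k, block_k):
        for i in range(m):
            for j in range(n):
                partial = sum(
                    a_rows[i][kk] * b_rows[j][kk]
                    for kk in range(offset_k, offset_k + block_k)
                )
                if offset_k == 0:
                    c[i][j] = partial            # enable_input_d=False
                else:
                    c[i][j] = c[i][j] + partial  # enable_input_d=True
    return c

a = [[1.0, 2.0, 3.0, 4.0]]
b = [[1.0, 1.0, 1.0, 1.0], [2.0, 0.0, 2.0, 0.0]]
assert matmul_tile(a, b) == [[10.0, 8.0]]
```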

Epilogue

Epilogue
# load the result from tensor memory to registers
r_acc = self.tcgen05.load(t_acc)

# cast to float16 and store to global memory
g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])

# all allocated tensor memory must be deallocated before kernel exits
self.sync()
self.tcgen05.dealloc(t_acc)

After the loop, tcgen05.load() moves the accumulator from tensor memory to registers, .to(float16) casts it, and store_global() writes the result to global memory. Finally, tcgen05.dealloc() frees the tensor memory. All TMEM allocations must be explicitly deallocated before the kernel exits.

Running the Kernel

BlackwellMatmulV0() creates a kernel template. Compilation happens on the first call.

Note the two different integer annotations in the function signature:

  • int32 (e.g., m_size: int32): a runtime parameter. The value is passed to the GPU kernel as an argument and can change between calls without recompilation.

  • int (e.g., n_size: int, k_size: int): a compile-time constant. The value is baked into the generated CUDA code, so a new value triggers JIT recompilation and autotuning.

Making n_size and k_size compile-time constants allows the compiler to specialize the kernel (e.g., unroll loops, compute constant addresses). For more details, see Tilus Script.
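The specialization behavior can be sketched with a hypothetical memoization cache (not the actual Tilus implementation): compile-time arguments form the cache key, while runtime arguments like m_size do not.

```python
# Sketch of JIT specialization: a new (n_size, k_size) pair triggers
# "compilation"; repeated values hit the cache. m_size never appears in
# the key, so it can vary freely at runtime.

compiled_kernels: dict = {}
compile_count = 0

def get_kernel(n_size: int, k_size: int):
    """Compile once per compile-time value combination."""
    global compile_count
    key = (n_size, k_size)
    if key not in compiled_kernels:
        compile_count += 1  # stands in for JIT compilation + autotuning
        compiled_kernels[key] = f"kernel_n{n_size}_k{k_size}"
    return compiled_kernels[key]

get_kernel(8192, 8192)  # first call: compiles
get_kernel(8192, 8192)  # same compile-time values: cache hit
get_kernel(4096, 8192)  # new n_size: recompiles
assert compile_count == 2
```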

Once compiled, subsequent calls with the same compile-time values dispatch directly to the GPU.

Launch, verify, and benchmark
def main(bench=True):
    matmul = BlackwellMatmulV0()

    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    rows = []

    for m_size, n_size, k_size in [
        [8192, 8192, 8192],
    ]:
        print(f"Running with m_size={m_size}, n_size={n_size}, k_size={k_size}")
        a = torch.randn(m_size, k_size, dtype=torch.float16, device="cuda")
        b = torch.randn(n_size, k_size, dtype=torch.float16, device="cuda")
        c = torch.empty(m_size, n_size, dtype=torch.float16, device="cuda")

        c_ref = a @ b.T
        torch.cuda.synchronize()

        matmul(m_size, n_size, k_size, a, b, c)
        torch.cuda.synchronize()

        torch.testing.assert_close(c, c_ref, atol=1e-2, rtol=1e-2)

        # benchmark
        if bench:
            for name, func in [
                ("torch", lambda: a @ b.T),
                ("tilus", lambda: matmul(m_size, n_size, k_size, a, b, c)),
            ]:
                latency = benchmark_func(func, warmup=5, repeat=100)
                tflops = 2 * m_size * n_size * k_size / latency * 1e-9
                rows.append([m_size, n_size, k_size, name, latency, tflops])
                time.sleep(3)  # sleep 3s to cool down the GPU between runs

    if bench:
        df = pandas.DataFrame(rows, columns=headers)
        print(df)


if __name__ == "__main__":
    main(bench=True)
    # tilus.utils.ncu_run(main, bench=False, kernel_regex="tilus|nvjet")

Performance

This minimal kernel achieves ~491 TFLOPS. The main bottleneck is the legacy cp.async load path, which generates high instruction overhead. The complete source is at examples/blackwell_matmul/matmul_v0.py.

../../_images/plot_v0.svg

Blackwell matmul performance on B200 (M=N=K=8192, fp16). TFLOPS derived from NCU profiling. Peak TFLOPS estimated from cuBLAS tensor core utilization (96.6%).

What’s Next

This kernel works but is far from optimal. The main bottleneck is the load path: copy_async() uses the legacy async copy mechanism (cp.async in PTX), where every thread issues its own small copy instruction. This has high instruction overhead and does not leverage dedicated hardware for bulk data movement.

In the next version, we replace copy_async with TMA (Tensor Memory Accelerator) loads. TMA offloads address generation and data movement to a dedicated hardware unit. A single TMA instruction copies an entire tile with minimal instruction overhead, and is issued by just one warp.