2. Multi-Stage Software Pipelining

In V1, each loop iteration first waits for TMA to finish loading data, then issues the MMA. Load and compute are fully serialized — the TMA engine sits idle during MMA, and the tensor cores sit idle during TMA.

This version introduces multi-stage software pipelining: shared memory is divided into multiple stages (a ring buffer), and the kernel prefills several stages before entering the main loop. In each iteration of the main loop, the TMA loads data for a future iteration while the MMA processes data from a previously loaded stage. This overlaps load and compute, significantly improving utilization.

If you have used Triton, this is similar to Triton’s num_stages parameter — but here you control the pipelining explicitly: allocating per-stage buffers, issuing prefill loads, and managing phase tracking yourself.

The Full Kernel

BlackwellMatmulV2 — full kernel
@tilus.autotune(
    "block_m, block_n, e_block_n", [[128, 64, 16], [128, 128, 16], [128, 256, 16]]
)
@tilus.autotune("block_k", [16, 32, 64])
@tilus.autotune("stages", [2, 3, 4])
class BlackwellMatmulV2(tilus.Script):
    def __init__(
        self, block_m: int, block_n: int, block_k: int, stages: int, e_block_n: int
    ):
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.stages = stages
        self.e_block_n = e_block_n

    def __call__(
        self,
        m_size: int32,
        n_size: int32,
        k_size: int32,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)]
        self.attrs.warps = 4

        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
        # multi-stage shared memory: leading dimension indexes the pipeline stage
        s_a = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_m, self.block_k]
        )
        s_b = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_n, self.block_k]
        )

        t_acc = self.tcgen05.alloc(dtype=float32, shape=[self.block_m, self.block_n])

        # one TMA barrier per stage
        tma_barriers = self.mbarrier.alloc(counts=[1 for _ in range(self.stages)])
        mma_barrier = self.mbarrier.alloc(counts=1)
        # per-role phase: tracks the expected phase for the next wait
        tma_phase: uint32 = 0
        mma_phase: uint32 = 0

        # prefill: issue TMA loads for the first (stages - 1) tiles without waiting
        for i in range(self.stages - 1):
            offset_k = i * self.block_k
            with self.single_warp():
                with self.single_thread():
                    self.mbarrier.arrive_and_expect_tx(
                        tma_barriers[i], transaction_bytes=s_a[i].nbytes + s_b[i].nbytes
                    )
                self.tma.global_to_shared(
                    src=g_a,
                    dst=s_a[i],
                    offsets=[offset_m, offset_k],
                    mbarrier=tma_barriers[i],
                )
                self.tma.global_to_shared(
                    src=g_b,
                    dst=s_b[i],
                    offsets=[offset_n, offset_k],
                    mbarrier=tma_barriers[i],
                )

        self.sync()

        current_stage: int32 = 0
        preload_stage: int32 = self.stages - 1

        # unroll by stages so the compiler can resolve stage indices to constants
        for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
            with self.single_warp():
                # preload: issue TMA for a future tile into the next free stage
                preload_offset_k = offset_k + (self.stages - 1) * self.block_k
                with self.single_thread():
                    self.mbarrier.arrive_and_expect_tx(
                        tma_barriers[preload_stage],
                        transaction_bytes=s_a[preload_stage].nbytes
                        + s_b[preload_stage].nbytes,
                    )
                self.tma.global_to_shared(
                    src=g_a,
                    dst=s_a[preload_stage],
                    offsets=[offset_m, preload_offset_k],
                    mbarrier=tma_barriers[preload_stage],
                )
                self.tma.global_to_shared(
                    src=g_b,
                    dst=s_b[preload_stage],
                    offsets=[offset_n, preload_offset_k],
                    mbarrier=tma_barriers[preload_stage],
                )
                # wait for the current stage's TMA data to arrive
                self.mbarrier.wait(
                    tma_barriers[current_stage],
                    phase=tma_phase,
                    sem="relaxed",
                    scope="cta",
                )
                # compute on the current stage
                self.tcgen05.mma(
                    s_a[current_stage],
                    s_b[current_stage].transpose(),
                    t_acc,
                    enable_input_d=offset_k != 0,
                )
                self.tcgen05.commit(mbarrier=mma_barrier)
                self.mbarrier.wait(
                    mma_barrier, phase=mma_phase, sem="relaxed", scope="cta"
                )

            # advance stage indices (ring buffer); flip phase when wrapping to stage 0
            preload_stage = (preload_stage + 1) % self.stages
            current_stage = (current_stage + 1) % self.stages
            tma_phase ^= current_stage == 0
            mma_phase ^= 1
            self.sync()

        # TMA epilogue: tmem -> register -> shared -> global (via TMA)
        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        s_c = self.shared_tensor(dtype=float16, shape=[self.block_m, self.e_block_n])
        for e_offset_n in range(0, self.block_n, self.e_block_n):
            t_acc_slice = self.tcgen05.slice(
                t_acc,
                offsets=[0, e_offset_n],
                shape=[self.block_m, self.e_block_n],
                dims=[0, 1],
            )
            r_acc = self.tcgen05.load(t_acc_slice)
            self.tcgen05.wait_load()
            self.store_shared(s_c, r_acc.to(float16))
            self.fence.proxy_async(space="shared")
            self.sync()
            with self.single_warp():
                self.tma.shared_to_global(
                    s_c,
                    g_c,
                    offsets=[offset_m, offset_n + e_offset_n],
                    dims=[0, 1],
                )
                self.tma.commit_group()
                self.tma.wait_group(n=0, read=True)
            self.sync()

        self.sync()
        self.tcgen05.dealloc(t_acc)

What Changed from V1

|                  | V1                               | V2                                                |
|------------------|----------------------------------|---------------------------------------------------|
| Shared memory    | Single stage: [block_m, block_k] | Multi-stage ring buffer: [stages, block_m, block_k] |
| TMA barriers     | 1 barrier                        | 1 barrier per stage                               |
| Phase tracking   | Single phase variable            | Per-role tma_phase with XOR-on-wrap               |
| Loop structure   | Load then compute, serial        | Prefill stages, then overlap load and compute     |
| New parameter    | —                                | stages (autotuned: 2, 3, or 4)                    |
| New instructions | —                                | self.range() (loop with unroll hint)              |

Software Pipelining

../../_images/v2_pipeline.svg

Top: V1 serializes load and compute. Bottom: V2 overlaps them using a multi-stage pipeline.

The idea is simple: if we have S stages of shared memory, we can have up to S - 1 TMA loads in flight while one stage is being consumed by MMA. The kernel proceeds in two phases:

  1. Prefill — Before the main loop, load the first S - 1 stages via TMA. These loads run asynchronously; we do not wait for them yet.

  2. Main loop — Each iteration does three things:

    • Preload: issue a TMA load into the next free stage (preload_stage).

    • Wait: wait for the current stage’s TMA to complete (tma_barriers[current_stage]).

    • Compute: run MMA on the current stage’s data.

    After each iteration, both current_stage and preload_stage advance modulo stages, cycling through the ring buffer.
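The two phases above can be modeled in plain Python (a conceptual sketch of the schedule, not tilus code): prefill queues the first S - 1 tiles, and each main-loop iteration issues one future load while consuming the oldest in-flight tile.

```python
def pipeline_schedule(num_tiles: int, stages: int):
    """Model the V2 schedule: which tile is preloaded vs computed per iteration."""
    events = []
    # prefill: loads for tiles 0 .. stages-2 are issued without waiting
    in_flight = list(range(stages - 1))
    next_tile = stages - 1
    for it in range(num_tiles):
        # preload a future tile into the next free stage (None once tiles run out)
        preload = next_tile if next_tile < num_tiles else None
        if preload is not None:
            in_flight.append(preload)
            next_tile += 1
        # wait on, then consume, the oldest in-flight tile
        compute = in_flight.pop(0)
        events.append((it, preload, compute))
    return events

for it, preload, compute in pipeline_schedule(num_tiles=6, stages=3):
    print(f"iter {it}: preload tile {preload}, compute tile {compute}")
```

Note how the consumed tile always trails the preloaded one by stages - 1 iterations: that gap is exactly the number of loads kept in flight.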

Multi-Stage Shared Memory

In V1, shared tensors had shape [block_m, block_k] — a single buffer that was overwritten every iteration. In V2, shared tensors gain a leading stage dimension:

s_a = self.shared_tensor(dtype=float16, shape=[self.stages, self.block_m, self.block_k])
s_b = self.shared_tensor(dtype=float16, shape=[self.stages, self.block_n, self.block_k])

Each stage s_a[i] / s_b[i] is an independent buffer. TMA writes to one stage while MMA reads from another, without conflicts.
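The stage dimension multiplies the shared memory footprint, which is why stages is autotuned rather than maximized. A quick back-of-the-envelope check (plain Python; the tile sizes come from the autotune lists above, and any per-device shared memory budget is something to verify yourself):

```python
def smem_bytes(stages, block_m, block_n, block_k, elem_bytes=2):
    """Shared memory consumed by s_a and s_b (elem_bytes=2 for fp16)."""
    s_a = stages * block_m * block_k * elem_bytes
    s_b = stages * block_n * block_k * elem_bytes
    return s_a + s_b

# largest autotuned configuration: block_m=128, block_n=256, block_k=64, stages=4
print(smem_bytes(4, 128, 256, 64))  # 196608 bytes = 192 KiB
```

The epilogue buffer s_c adds block_m * e_block_n * 2 bytes on top of this, which is one reason e_block_n is kept small.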

Per-Stage Barriers and Phase Tracking

Each stage has its own mbarrier to track TMA completion independently:

tma_barriers = self.mbarrier.alloc(counts=[1 for _ in range(self.stages)])
tma_phase: uint32 = 0
mma_phase: uint32 = 0

Instead of tracking a separate phase for each stage’s barrier, we use a single per-role phase variable. After each iteration, the stage index advances through the ring buffer. When the stage index wraps back to 0 (completing a full cycle through all stages), the phase flips via XOR:

current_stage = (current_stage + 1) % self.stages
tma_phase ^= (current_stage == 0)   # flip when wrapping to stage 0

This works because each barrier sees one wait per cycle through the ring buffer. After a full cycle (stages iterations), every barrier has been waited on once, and the next wait on the same barrier needs the opposite phase. The XOR flips the phase exactly at that point.
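To convince yourself that the single per-role phase is equivalent to keeping an explicit phase per barrier, here is a small simulation (plain Python, not tilus): it advances both schemes in lockstep and checks that the shared phase always matches the true phase of the barrier being waited on.

```python
stages = 3
# reference scheme: an explicit phase per barrier, flipped after each wait on it
per_barrier_phase = [0] * stages
# actual scheme: one shared per-role phase, flipped when the stage wraps to 0
tma_phase = 0
current_stage = 0

for _ in range(100):
    # the wait on tma_barriers[current_stage] must use that barrier's true phase
    assert tma_phase == per_barrier_phase[current_stage]
    per_barrier_phase[current_stage] ^= 1      # this barrier's next wait flips
    current_stage = (current_stage + 1) % stages
    tma_phase ^= current_stage == 0            # flip once per full ring cycle
print("per-role phase matches per-barrier phases for 100 iterations")
```

This relies on the barriers being waited on in strict ring order; a scheme that skipped stages would need per-barrier phases.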

The MMA barrier remains a single barrier (as in V1) since MMA operations are still serialized. Its phase flips every iteration (mma_phase ^= 1).

Loop Unrolling

The main loop uses self.range() with unroll=self.stages instead of Python’s range():

for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):

Both Python’s range() and self.range() are lowered to the same loop statement internally. The difference is that self.range provides additional control: here the unroll hint instructs the compiler to unroll the loop body by the number of stages. Unrolling matters for pipelined kernels because the stage indices (current_stage, preload_stage) cycle modulo stages; with the body unrolled by that factor, the compiler can resolve the indices to constants, eliminating the modular arithmetic and enabling more efficient code generation.
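Why unrolling by exactly stages makes the indices constant can be checked directly (plain Python, a conceptual model rather than compiler output): advancing a ring index by stages is a no-op, so the j-th copy inside every unrolled body sees the same index.

```python
stages = 3
stage = 0
for body in range(4):                          # four unrolled loop bodies
    per_copy = [(stage + j) % stages for j in range(stages)]
    assert per_copy == [0, 1, 2]               # same constants in every body
    stage = (stage + stages) % stages          # advancing by `stages` is a no-op
print("after unrolling by `stages`, each copy's stage index is a constant")
```

An unroll factor that is not a multiple of stages would leave the indices varying across bodies, which is why the hint is tied to the stage count.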

Walkthrough

Prefill

Prefill: load the first stages - 1 tiles
for i in range(self.stages - 1):
    offset_k = i * self.block_k
    with self.single_warp():
        with self.single_thread():
            self.mbarrier.arrive_and_expect_tx(
                tma_barriers[i], transaction_bytes=s_a[i].nbytes + s_b[i].nbytes
            )
        self.tma.global_to_shared(
            src=g_a,
            dst=s_a[i],
            offsets=[offset_m, offset_k],
            mbarrier=tma_barriers[i],
        )
        self.tma.global_to_shared(
            src=g_b,
            dst=s_b[i],
            offsets=[offset_n, offset_k],
            mbarrier=tma_barriers[i],
        )

Before the main loop, the first stages - 1 TMA loads are issued without waiting. Iteration i loads into stage i and signals tma_barriers[i]. After the prefill, self.sync() brings every thread in the block to the same point before the main loop begins, even though only a single warp issued the TMA requests.
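The transaction_bytes argument tells the barrier how many bytes the pending TMA copies will deposit; the barrier completes only after both the arrival and that byte count are observed. For one of the autotuned configurations the count works out as follows (plain-Python arithmetic check; fp16 is 2 bytes per element):

```python
block_m, block_n, block_k = 128, 128, 32   # one autotuned configuration
elem = 2                                   # bytes per fp16 element
nbytes_a = block_m * block_k * elem        # one stage of s_a
nbytes_b = block_n * block_k * elem        # one stage of s_b
transaction_bytes = nbytes_a + nbytes_b
print(transaction_bytes)  # 16384: the barrier completes after 16 KiB arrive
```

Undercounting here would release waiters before the data landed; overcounting would deadlock the wait, so the s_a[i].nbytes + s_b[i].nbytes expression in the kernel must cover exactly the two copies targeting that barrier.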

Main Loop

Main loop: overlap preload and compute
current_stage: int32 = 0
preload_stage: int32 = self.stages - 1

# unroll by stages so the compiler can resolve stage indices to constants
for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
    with self.single_warp():
        # preload: issue TMA for a future tile into the next free stage
        preload_offset_k = offset_k + (self.stages - 1) * self.block_k
        with self.single_thread():
            self.mbarrier.arrive_and_expect_tx(
                tma_barriers[preload_stage],
                transaction_bytes=s_a[preload_stage].nbytes
                + s_b[preload_stage].nbytes,
            )
        self.tma.global_to_shared(
            src=g_a,
            dst=s_a[preload_stage],
            offsets=[offset_m, preload_offset_k],
            mbarrier=tma_barriers[preload_stage],
        )
        self.tma.global_to_shared(
            src=g_b,
            dst=s_b[preload_stage],
            offsets=[offset_n, preload_offset_k],
            mbarrier=tma_barriers[preload_stage],
        )
        # wait for the current stage's TMA data to arrive
        self.mbarrier.wait(
            tma_barriers[current_stage],
            phase=tma_phase,
            sem="relaxed",
            scope="cta",
        )
        # compute on the current stage
        self.tcgen05.mma(
            s_a[current_stage],
            s_b[current_stage].transpose(),
            t_acc,
            enable_input_d=offset_k != 0,
        )
        self.tcgen05.commit(mbarrier=mma_barrier)
        self.mbarrier.wait(
            mma_barrier, phase=mma_phase, sem="relaxed", scope="cta"
        )

    # advance stage indices (ring buffer); flip phase when wrapping to stage 0
    preload_stage = (preload_stage + 1) % self.stages
    current_stage = (current_stage + 1) % self.stages
    tma_phase ^= current_stage == 0
    mma_phase ^= 1
    self.sync()

In each iteration:

  • Preload (into preload_stage): mbarrier.arrive_and_expect_tx() and tma.global_to_shared() issue TMA loads for a future K-tile into the next free stage. This runs asynchronously — the TMA engine works in the background.

  • Wait (on current_stage): mbarrier.wait() blocks until the current stage’s TMA data has arrived. The phase comes from the per-role tma_phase variable.

  • Compute (from current_stage): tcgen05.mma() reads s_a[current_stage] and s_b[current_stage], accumulating into the tensor memory accumulator.

After the MMA, tcgen05.commit() and mbarrier.wait() on the MMA barrier ensure the MMA completes before the next iteration.

Finally, the stage indices advance modulo stages, and the per-role phase flips when the stage wraps back to 0:

preload_stage = (preload_stage + 1) % self.stages   # next free stage
current_stage = (current_stage + 1) % self.stages   # next stage to consume
tma_phase ^= (current_stage == 0)                   # flip phase on wrap
mma_phase ^= 1
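Tracing these four updates for stages = 3 makes the ring-buffer motion concrete (plain Python; each row shows the values an iteration uses before its end-of-iteration updates):

```python
stages = 3
current, preload = 0, stages - 1
tma_phase, mma_phase = 0, 0
trace = []
for it in range(6):
    trace.append((current, preload, tma_phase, mma_phase))
    preload = (preload + 1) % stages
    current = (current + 1) % stages
    tma_phase ^= current == 0
    mma_phase ^= 1

for it, (c, p, tp, mp) in enumerate(trace):
    print(f"iter {it}: compute stage {c}, preload stage {p}, "
          f"tma_phase {tp}, mma_phase {mp}")
```

preload_stage stays exactly one ring position ahead of current_stage, and tma_phase flips only on iterations 3, 6, 9, ... when current_stage returns to 0.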

Performance

Multi-stage pipelining overlaps TMA loads with MMA compute, more than doubling throughput compared to V1. The complete source is at examples/blackwell_matmul/matmul_v2.py.

../../_images/plot_v2.svg

Blackwell matmul performance on B200 (M=N=K=8192, fp16). TFLOPS derived from NCU profiling. Peak TFLOPS estimated from cuBLAS tensor core utilization (96.6%).

What’s Next

V2 overlaps TMA loads with MMA compute across iterations, but there is still a limitation: each iteration runs one MMA and must wait for it to complete (via tcgen05.commit + mbarrier.wait) before issuing the next. The MMA pipeline depth is effectively 1 — no in-flight MMAs overlap.

In the next version, we separate the load and compute into different thread groups: a dedicated TMA warp and a dedicated MMA warp. With separate warps, the MMA warp can issue its next MMA immediately after the previous one without waiting for the TMA warp, and vice versa. This enables true parallelism between TMA and MMA, with the two warps progressing independently and synchronizing only through mbarriers.