2. Multi-Stage Software Pipelining¶
In V1, each loop iteration first waits for TMA to finish loading data, then issues the MMA. Load and compute are fully serialized — the TMA engine sits idle during MMA, and the tensor cores sit idle during TMA.
This version introduces multi-stage software pipelining: shared memory is divided into multiple stages (a ring buffer), and the kernel prefills several stages before entering the main loop. In each iteration of the main loop, the TMA loads data for a future iteration while the MMA processes data from a previously loaded stage. This overlaps load and compute, significantly improving utilization.
If you have used Triton, this is similar to Triton’s num_stages parameter —
but here you control the pipelining explicitly: allocating per-stage buffers,
issuing prefill loads, and managing phase tracking yourself.
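Before looking at the kernel, a back-of-the-envelope model makes the payoff concrete. The sketch below is plain Python, not tilus code, and the per-tile costs T_LOAD and T_MMA are made-up illustrative numbers: V1 pays load plus compute on every iteration, while a deep-enough pipeline pays only the larger of the two in steady state.

```python
# Toy timing model (not tilus code): compare serialized V1 against an
# S-stage pipeline. T_LOAD / T_MMA are made-up per-tile costs.
T_LOAD, T_MMA, ITERS = 3, 2, 64

def serialized_time() -> int:
    # V1: each iteration waits for the load, then runs the MMA
    return ITERS * (T_LOAD + T_MMA)

def pipelined_time(stages: int) -> int:
    # V2: assuming `stages` is deep enough to hide load latency, each
    # iteration's load overlaps compute, so steady-state cost per
    # iteration is max(load, compute); the first tile is still loaded
    # up front during the prefill.
    prefill = T_LOAD
    return prefill + ITERS * max(T_LOAD, T_MMA)

print(serialized_time())   # 320
print(pipelined_time(3))   # 195
```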
The Full Kernel¶
@tilus.autotune(
    "block_m, block_n, e_block_n", [[128, 64, 16], [128, 128, 16], [128, 256, 16]]
)
@tilus.autotune("block_k", [16, 32, 64])
@tilus.autotune("stages", [2, 3, 4])
class BlackwellMatmulV2(tilus.Script):
    def __init__(
        self, block_m: int, block_n: int, block_k: int, stages: int, e_block_n: int
    ):
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.stages = stages
        self.e_block_n = e_block_n

    def __call__(
        self,
        m_size: int32,
        n_size: int32,
        k_size: int32,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)]
        self.attrs.warps = 4

        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y
        g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])

        # multi-stage shared memory: leading dimension indexes the pipeline stage
        s_a = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_m, self.block_k]
        )
        s_b = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_n, self.block_k]
        )
        t_acc = self.tcgen05.alloc(dtype=float32, shape=[self.block_m, self.block_n])

        # one TMA barrier per stage
        tma_barriers = self.mbarrier.alloc(counts=[1 for _ in range(self.stages)])
        mma_barrier = self.mbarrier.alloc(counts=1)

        # per-role phase: tracks the expected phase for the next wait
        tma_phase: uint32 = 0
        mma_phase: uint32 = 0

        # prefill: issue TMA loads for the first (stages - 1) tiles without waiting
        for i in range(self.stages - 1):
            offset_k = i * self.block_k
            with self.single_warp():
                with self.single_thread():
                    self.mbarrier.arrive_and_expect_tx(
                        tma_barriers[i], transaction_bytes=s_a[i].nbytes + s_b[i].nbytes
                    )
                self.tma.global_to_shared(
                    src=g_a,
                    dst=s_a[i],
                    offsets=[offset_m, offset_k],
                    mbarrier=tma_barriers[i],
                )
                self.tma.global_to_shared(
                    src=g_b,
                    dst=s_b[i],
                    offsets=[offset_n, offset_k],
                    mbarrier=tma_barriers[i],
                )
        self.sync()

        current_stage: int32 = 0
        preload_stage: int32 = self.stages - 1

        # unroll by stages so the compiler can resolve stage indices to constants
        for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
            with self.single_warp():
                # preload: issue TMA for a future tile into the next free stage
                preload_offset_k = offset_k + (self.stages - 1) * self.block_k
                with self.single_thread():
                    self.mbarrier.arrive_and_expect_tx(
                        tma_barriers[preload_stage],
                        transaction_bytes=s_a[preload_stage].nbytes
                        + s_b[preload_stage].nbytes,
                    )
                self.tma.global_to_shared(
                    src=g_a,
                    dst=s_a[preload_stage],
                    offsets=[offset_m, preload_offset_k],
                    mbarrier=tma_barriers[preload_stage],
                )
                self.tma.global_to_shared(
                    src=g_b,
                    dst=s_b[preload_stage],
                    offsets=[offset_n, preload_offset_k],
                    mbarrier=tma_barriers[preload_stage],
                )
                # wait for the current stage's TMA data to arrive
                self.mbarrier.wait(
                    tma_barriers[current_stage],
                    phase=tma_phase,
                    sem="relaxed",
                    scope="cta",
                )
                # compute on the current stage
                self.tcgen05.mma(
                    s_a[current_stage],
                    s_b[current_stage].transpose(),
                    t_acc,
                    enable_input_d=offset_k != 0,
                )
                self.tcgen05.commit(mbarrier=mma_barrier)
                self.mbarrier.wait(
                    mma_barrier, phase=mma_phase, sem="relaxed", scope="cta"
                )
            # advance stage indices (ring buffer); flip phase when wrapping to stage 0
            preload_stage = (preload_stage + 1) % self.stages
            current_stage = (current_stage + 1) % self.stages
            tma_phase ^= current_stage == 0
            mma_phase ^= 1
            self.sync()

        # TMA epilogue: tmem -> register -> shared -> global (via TMA)
        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        s_c = self.shared_tensor(dtype=float16, shape=[self.block_m, self.e_block_n])
        for e_offset_n in range(0, self.block_n, self.e_block_n):
            t_acc_slice = self.tcgen05.slice(
                t_acc,
                offsets=[0, e_offset_n],
                shape=[self.block_m, self.e_block_n],
                dims=[0, 1],
            )
            r_acc = self.tcgen05.load(t_acc_slice)
            self.tcgen05.wait_load()
            self.store_shared(s_c, r_acc.to(float16))
            self.fence.proxy_async(space="shared")
            self.sync()
            with self.single_warp():
                self.tma.shared_to_global(
                    s_c,
                    g_c,
                    offsets=[offset_m, offset_n + e_offset_n],
                    dims=[0, 1],
                )
                self.tma.commit_group()
                self.tma.wait_group(n=0, read=True)
            self.sync()
        self.sync()
        self.tcgen05.dealloc(t_acc)
What Changed from V1¶
| | V1 | V2 |
|---|---|---|
| Shared memory | Single stage: `[block_m, block_k]` | Multi-stage ring buffer: `[stages, block_m, block_k]` |
| TMA barriers | 1 barrier | 1 barrier per stage |
| Phase tracking | Single `phase` | Per-role `tma_phase` and `mma_phase` |
| Loop structure | Load then compute, serial | Prefill stages, then overlap load and compute |
| New parameter | — | `stages` |
| New instructions | — | `self.range(..., unroll=...)` |
Software Pipelining¶
Top: V1 serializes load and compute. Bottom: V2 overlaps them using a multi-stage pipeline.¶
The idea is simple: if we have S stages of shared memory, we can have up to
S - 1 TMA loads in flight while one stage is being consumed by MMA. The
kernel proceeds in two phases:
1. Prefill — Before the main loop, load the first S - 1 stages via TMA. These loads run asynchronously; we do not wait for them yet.
2. Main loop — Each iteration does three things:
   1. Preload: issue a TMA load into the next free stage (preload_stage).
   2. Wait: wait for the current stage’s TMA to complete (tma_barriers[current_stage]).
   3. Compute: run MMA on the current stage’s data.

After each iteration, both current_stage and preload_stage advance modulo stages, cycling through the ring buffer.
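The prefill/preload/compute schedule can be traced in plain Python. The sketch below is illustrative, not tilus code; STAGES and K_TILES are arbitrary example values, and each event records which K-tile is loaded into or consumed from which stage.

```python
# Trace of the ring-buffer schedule: stages 0..S-2 are prefilled, then each
# iteration preloads one future tile while computing on the oldest pending one.
STAGES, K_TILES = 3, 8

events = []
for i in range(STAGES - 1):                 # prefill
    events.append(("load", i, i))           # (action, k_tile, stage)

current, preload = 0, STAGES - 1
for it in range(K_TILES):
    k_ahead = it + STAGES - 1
    if k_ahead < K_TILES:                   # preload a future tile
        events.append(("load", k_ahead, preload))
    events.append(("mma", it, current))     # compute on the current stage
    preload = (preload + 1) % STAGES
    current = (current + 1) % STAGES

for e in events:
    print(e)
```

Reading the trace, every tile t lands in stage t % STAGES, and its load is always issued before the MMA that consumes it.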
Per-Stage Barriers and Phase Tracking¶
Each stage has its own mbarrier to track TMA completion independently:
tma_barriers = self.mbarrier.alloc(counts=[1 for _ in range(self.stages)])
tma_phase: uint32 = 0
mma_phase: uint32 = 0
Instead of tracking a separate phase for each stage’s barrier, we use a single per-role phase variable. After each iteration, the stage index advances through the ring buffer. When the stage index wraps back to 0 (completing a full cycle through all stages), the phase flips via XOR:
current_stage = (current_stage + 1) % self.stages
tma_phase ^= (current_stage == 0) # flip when wrapping to stage 0
This works because each barrier sees one wait per cycle through the ring
buffer. After a full cycle (stages iterations), every barrier has been waited
on once, and the next wait on the same barrier needs the opposite phase. The XOR
flips the phase exactly at that point.
The MMA barrier remains a single barrier (as in V1) since MMA operations are
still serialized. Its phase flips every iteration (mma_phase ^= 1).
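A quick plain-Python sanity check (not part of the kernel) confirms that the single per-role variable always agrees with tracking one phase per barrier:

```python
# A single per-role phase, flipped when the stage index wraps to 0, matches
# per-barrier phase tracking at every wait.
STAGES, ITERS = 3, 20

per_barrier_phase = [0] * STAGES   # ground truth: one phase per mbarrier
tma_phase = 0                      # the kernel's single per-role phase
current = 0
for _ in range(ITERS):
    # the wait on tma_barriers[current] must use that barrier's current phase
    assert tma_phase == per_barrier_phase[current]
    per_barrier_phase[current] ^= 1      # this wait consumes the phase
    current = (current + 1) % STAGES
    tma_phase ^= current == 0            # flip once per full ring cycle
print("ok")
```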
Loop Unrolling¶
The main loop uses self.range() with
unroll=self.stages instead of Python’s range():
for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
Both Python’s range() and self.range() are lowered to the
same loop statement internally. The difference is that self.range provides
additional control — here the unroll hint instructs the compiler to unroll
the loop body by the number of stages. Loop unrolling is important for pipelined
kernels because the stage index
(current_stage, preload_stage) cycles modulo stages — with
unrolling, the compiler can resolve these indices to constants, eliminating
modular arithmetic and enabling more efficient code generation.
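A small plain-Python analogy of what the compiler does: after unrolling by stages, each copy of the loop body sees a fixed stage index, so i % stages never needs to be computed at runtime.

```python
# Unrolling by STAGES turns the runtime modulo into per-copy constants.
STAGES = 3

# rolled loop: the stage index is a runtime value, recomputed each iteration
rolled = [i % STAGES for i in range(9)]

# unrolled by STAGES: each of the 3 copies of the body hard-codes its stage
unrolled = []
for _ in range(9 // STAGES):      # outer loop over unrolled super-iterations
    unrolled += [0, 1, 2]         # constants baked into each body copy

assert rolled == unrolled
print(unrolled)                   # [0, 1, 2, 0, 1, 2, 0, 1, 2]
```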
Walkthrough¶
Prefill¶
for i in range(self.stages - 1):
    offset_k = i * self.block_k
    with self.single_warp():
        with self.single_thread():
            self.mbarrier.arrive_and_expect_tx(
                tma_barriers[i], transaction_bytes=s_a[i].nbytes + s_b[i].nbytes
            )
        self.tma.global_to_shared(
            src=g_a,
            dst=s_a[i],
            offsets=[offset_m, offset_k],
            mbarrier=tma_barriers[i],
        )
        self.tma.global_to_shared(
            src=g_b,
            dst=s_b[i],
            offsets=[offset_n, offset_k],
            mbarrier=tma_barriers[i],
        )
Before the main loop, the first stages - 1 TMA loads are issued without
waiting. Iteration i of the prefill loop loads into stage i and signals
tma_barriers[i]. After the prefill, self.sync() ensures the TMA requests
have been issued (by the single issuing warp) before any thread enters the
main loop.
Main Loop¶
current_stage: int32 = 0
preload_stage: int32 = self.stages - 1

# unroll by stages so the compiler can resolve stage indices to constants
for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
    with self.single_warp():
        # preload: issue TMA for a future tile into the next free stage
        preload_offset_k = offset_k + (self.stages - 1) * self.block_k
        with self.single_thread():
            self.mbarrier.arrive_and_expect_tx(
                tma_barriers[preload_stage],
                transaction_bytes=s_a[preload_stage].nbytes
                + s_b[preload_stage].nbytes,
            )
        self.tma.global_to_shared(
            src=g_a,
            dst=s_a[preload_stage],
            offsets=[offset_m, preload_offset_k],
            mbarrier=tma_barriers[preload_stage],
        )
        self.tma.global_to_shared(
            src=g_b,
            dst=s_b[preload_stage],
            offsets=[offset_n, preload_offset_k],
            mbarrier=tma_barriers[preload_stage],
        )
        # wait for the current stage's TMA data to arrive
        self.mbarrier.wait(
            tma_barriers[current_stage],
            phase=tma_phase,
            sem="relaxed",
            scope="cta",
        )
        # compute on the current stage
        self.tcgen05.mma(
            s_a[current_stage],
            s_b[current_stage].transpose(),
            t_acc,
            enable_input_d=offset_k != 0,
        )
        self.tcgen05.commit(mbarrier=mma_barrier)
        self.mbarrier.wait(
            mma_barrier, phase=mma_phase, sem="relaxed", scope="cta"
        )
    # advance stage indices (ring buffer); flip phase when wrapping to stage 0
    preload_stage = (preload_stage + 1) % self.stages
    current_stage = (current_stage + 1) % self.stages
    tma_phase ^= current_stage == 0
    mma_phase ^= 1
    self.sync()
In each iteration:

1. Preload (into preload_stage): mbarrier.arrive_and_expect_tx() and tma.global_to_shared() issue TMA loads for a future K-tile into the next free stage. This runs asynchronously — the TMA engine works in the background.
2. Wait (on current_stage): mbarrier.wait() blocks until the current stage’s TMA data has arrived. The phase comes from the per-role tma_phase variable.
3. Compute (from current_stage): tcgen05.mma() reads s_a[current_stage] and s_b[current_stage], accumulating into the tensor memory accumulator.

After the MMA, tcgen05.commit() and mbarrier.wait() on the MMA barrier ensure the MMA completes before the next iteration.
Finally, the stage indices advance modulo stages, and the per-role phase
flips when the stage wraps back to 0:
preload_stage = (preload_stage + 1) % self.stages # next free stage
current_stage = (current_stage + 1) % self.stages # next stage to consume
tma_phase ^= (current_stage == 0) # flip phase on wrap
mma_phase ^= 1
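The same schedule can also be checked functionally on the CPU. The NumPy sketch below is illustrative, not tilus code (sizes and the load/mma stand-ins are made up): it mimics the staged shared buffers and the prefill/preload/compute loop, then verifies that accumulating one K-tile per iteration reproduces the full matmul.

```python
# Functional CPU model of the ring buffer: staged copies stand in for
# tma.global_to_shared, and a tile matmul stands in for tcgen05.mma.
import numpy as np

STAGES, BLOCK_K, K = 3, 16, 128
rng = np.random.default_rng(0)
a = rng.standard_normal((32, K), dtype=np.float32)
b = rng.standard_normal((48, K), dtype=np.float32)

s_a = np.zeros((STAGES, 32, BLOCK_K), dtype=np.float32)   # staged "shared" buffers
s_b = np.zeros((STAGES, 48, BLOCK_K), dtype=np.float32)
acc = np.zeros((32, 48), dtype=np.float32)                # the accumulator

def load(stage, k0):                       # stand-in for tma.global_to_shared
    s_a[stage] = a[:, k0:k0 + BLOCK_K]
    s_b[stage] = b[:, k0:k0 + BLOCK_K]

for i in range(STAGES - 1):                # prefill the first S - 1 stages
    load(i, i * BLOCK_K)

current, preload = 0, STAGES - 1
for k0 in range(0, K, BLOCK_K):
    k_ahead = k0 + (STAGES - 1) * BLOCK_K
    if k_ahead < K:                        # preload a future K-tile
        load(preload, k_ahead)
    acc += s_a[current] @ s_b[current].T   # stand-in for tcgen05.mma
    preload = (preload + 1) % STAGES
    current = (current + 1) % STAGES

assert np.allclose(acc, a @ b.T, atol=1e-3)
print("match")
```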
Performance¶
Multi-stage pipelining overlaps TMA loads with MMA compute, more than doubling throughput compared to V1. The complete source is at examples/blackwell_matmul/matmul_v2.py.
Blackwell matmul performance on B200 (M=N=K=8192, fp16). TFLOPS derived from NCU profiling. Peak TFLOPS estimated from cuBLAS tensor core utilization (96.6%).¶
What’s Next¶
V2 overlaps TMA loads with MMA compute across iterations, but there is still a
limitation: each iteration runs one MMA and must wait for it to complete
(via tcgen05.commit + mbarrier.wait) before issuing the next. The MMA
pipeline depth is effectively 1 — no in-flight MMAs overlap.
In the next version, we separate the load and compute into different thread groups: a dedicated TMA warp and a dedicated MMA warp. With separate warps, the MMA warp can issue its next MMA immediately after the previous one without waiting for the TMA warp, and vice versa. This enables true parallelism between TMA and MMA, with the two warps progressing independently and synchronizing only through mbarriers.