3. Warp Specialization¶
In V2, each iteration issues one MMA and waits for it to complete
(via tcgen05.commit + mbarrier.wait) before issuing the next. The tensor
core MMA pipeline sees only one operation at a time — it goes idle between
consecutive MMAs while the warp handles TMA and synchronization.
To keep the MMA pipeline busy, we want multiple MMAs in flight: issue the next MMA as soon as its input data is ready, without waiting for the previous MMA to finish writing its results. This is achieved through warp specialization — assigning TMA and MMA to separate warps, each running its own loop. The MMA warp can issue MMA after MMA without interruption, while the TMA warp independently keeps shared memory filled. The two warps synchronize only through mbarriers, allowing multiple in-flight MMA and TMA operations simultaneously.
Triton also performs warp specialization internally, but as a compiler pass with no user-level control. In Tilus, you explicitly assign different roles to different warps and define how they communicate.
The Full Kernel¶
```python
@tilus.autotune(
    "block_m, block_n, e_block_n", [[128, 64, 16], [128, 128, 16], [128, 256, 16]]
)
@tilus.autotune("block_k", [16, 32, 64])
@tilus.autotune("stages", [2, 3, 4])
class BlackwellMatmulV3(tilus.Script):
    def __init__(
        self, block_m: int, block_n: int, block_k: int, stages: int, e_block_n: int
    ):
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.stages = stages
        self.e_block_n = e_block_n

    def __call__(
        self,
        m_size: int32,
        n_size: int32,
        k_size: int32,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)]
        self.attrs.warps = 4

        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y
        g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
        s_a = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_m, self.block_k]
        )
        s_b = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_n, self.block_k]
        )
        t_acc = self.tcgen05.alloc(dtype=float32, shape=[self.block_m, self.block_n])

        # full_barriers: signaled when TMA has filled the stage (data ready)
        full_barriers = self.mbarrier.alloc(counts=[1] * self.stages)
        # empty_barriers: signaled when MMA has consumed the stage (slot free)
        empty_barriers = self.mbarrier.alloc(counts=[1] * self.stages)

        # TMA warp (producer): loads tiles from global to shared memory
        with self.thread_group(thread_begin=0, num_threads=32):
            stage: int32 = 0
            # phase=1: mbarrier starts at phase 0, so waiting for phase 1
            # passes immediately (slot is empty, ready to fill)
            producer_phase: uint32 = 1
            for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
                # wait for the MMA warp to free this stage
                self.mbarrier.wait(
                    empty_barriers[stage],
                    phase=producer_phase,
                    sem="relaxed",
                    scope="cta",
                )
                with self.single_thread():
                    self.mbarrier.arrive_and_expect_tx(
                        full_barriers[stage],
                        transaction_bytes=s_a[stage].nbytes + s_b[stage].nbytes,
                    )
                self.tma.global_to_shared(
                    src=g_a,
                    dst=s_a[stage],
                    offsets=[offset_m, offset_k],
                    mbarrier=full_barriers[stage],
                )
                self.tma.global_to_shared(
                    src=g_b,
                    dst=s_b[stage],
                    offsets=[offset_n, offset_k],
                    mbarrier=full_barriers[stage],
                )
                # advance stage; flip phase when wrapping to stage 0
                stage = (stage + 1) % self.stages
                producer_phase ^= stage == 0

        # MMA warp (consumer): computes on tiles loaded by the TMA warp
        with self.thread_group(thread_begin=32, num_threads=32):
            # phase=0: mbarrier starts at phase 0, so waiting for phase 0
            # blocks until the producer signals (slot is not yet filled)
            consumer_phase: uint32 = 0
            stage: int32 = 0
            for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
                # wait for the TMA warp to fill this stage
                self.mbarrier.wait(
                    full_barriers[stage],
                    phase=consumer_phase,
                    sem="relaxed",
                    scope="cta",
                )
                self.tcgen05.mma(
                    s_a[stage],
                    s_b[stage].transpose(),
                    t_acc,
                    enable_input_d=offset_k != 0,
                )
                # commit signals empty_barriers: marks this stage as consumed
                self.tcgen05.commit(mbarrier=empty_barriers[stage])
                # advance stage; flip phase when wrapping to stage 0
                stage = (stage + 1) % self.stages
                consumer_phase ^= stage == 0

            # drain: wait for all in-flight MMA to finish
            flush_barrier = self.mbarrier.alloc(1)
            self.tcgen05.commit(mbarrier=flush_barrier)
            self.mbarrier.wait(flush_barrier, phase=0)

        self.sync()

        # TMA epilogue: tmem -> register -> shared -> global (via TMA)
        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        s_c = self.shared_tensor(dtype=float16, shape=[self.block_m, self.e_block_n])
        for e_offset_n in range(0, self.block_n, self.e_block_n):
            t_acc_slice = self.tcgen05.slice(
                t_acc,
                offsets=[0, e_offset_n],
                shape=[self.block_m, self.e_block_n],
                dims=[0, 1],
            )
            r_acc = self.tcgen05.load(t_acc_slice)
            self.tcgen05.wait_load()
            self.store_shared(s_c, r_acc.to(float16))
            self.fence.proxy_async(space="shared")
            self.sync()
            with self.single_warp():
                self.tma.shared_to_global(
                    s_c,
                    g_c,
                    offsets=[offset_m, offset_n + e_offset_n],
                    dims=[0, 1],
                )
                self.tma.commit_group()
                self.tma.wait_group(n=0, read=True)
            self.sync()

        self.sync()
        self.tcgen05.dealloc(t_acc)
```
What Changed from V2¶
| | V2 | V3 |
|---|---|---|
| Warp structure | Single warp does both TMA and MMA | TMA warp + MMA warp (warp specialization) |
| Barriers | TMA barriers + 1 MMA barrier | `full_barriers` + `empty_barriers` (producer-consumer pair), plus a flush barrier |
| Parallelism | TMA and MMA overlap across iterations | TMA and MMA run truly in parallel on separate warps |
| In-flight MMA | 1 (wait after each MMA) | Multiple (MMA warp issues the next MMA without waiting for the previous one to finish) |
| Prefill | Explicit prefill loop | Implicit: TMA warp runs ahead of MMA warp |
| New instructions | | `self.thread_group()` (explicit warp roles) |
Warp Specialization¶
In V2, a single warp handles both TMA and MMA. Each iteration looks like: issue
TMA → wait for TMA → issue MMA → wait for MMA → repeat.
The tcgen05.commit + mbarrier.wait after each MMA means the tensor core
pipeline sees at most one MMA at a time. Between consecutive MMAs, the pipeline
goes idle while the warp handles synchronization, TMA issuing, and TMA waiting.
With warp specialization, we assign TMA and MMA to separate warps, each running its own loop:
- **TMA warp** (threads 0–31): issues TMA loads back-to-back. Before each load, it only waits on `empty_barriers` to ensure the shared memory slot is free.
- **MMA warp** (threads 32–63): issues MMA operations back-to-back. Before each MMA, it only waits on `full_barriers` to ensure the input data has arrived. Crucially, it does not wait for the previous MMA to complete: it issues `tcgen05.commit(mbarrier=empty_barriers[stage])`, which returns immediately and makes the barrier track the MMA's completion in the background. When the MMA actually finishes, the hardware signals the barrier, freeing the stage for the TMA warp. Meanwhile, the MMA warp has already moved on to the next iteration. This allows multiple MMAs to be in flight in the tensor core pipeline.
The result: as long as TMA delivers data fast enough, the MMA warp keeps the tensor core pipeline fully occupied with no idle gaps between operations.
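The two-loop structure can be sketched in plain Python as a host-side analogy (not Tilus code): one thread per role, with a pair of semaphores per ring-buffer slot standing in for `full_barriers`/`empty_barriers`. The names `tma_warp` and `mma_warp` here are illustrative only.

```python
import threading

STAGES, TILES = 2, 8
# empty[s] starts signaled (slot free); full[s] starts unsignaled (no data yet)
empty = [threading.Semaphore(1) for _ in range(STAGES)]
full = [threading.Semaphore(0) for _ in range(STAGES)]
slots = [None] * STAGES  # stands in for the shared-memory ring buffer
consumed = []

def tma_warp():
    # producer loop: fill each slot as soon as it is free
    for k in range(TILES):
        s = k % STAGES
        empty[s].acquire()       # wait(empty_barriers[stage])
        slots[s] = k             # "TMA load of tile k into shared memory"
        full[s].release()        # signal full_barriers[stage]

def mma_warp():
    # consumer loop: compute as soon as data is ready
    for k in range(TILES):
        s = k % STAGES
        full[s].acquire()        # wait(full_barriers[stage])
        consumed.append(slots[s])  # "issue MMA on tile"
        empty[s].release()       # commit -> signal empty_barriers[stage]

threads = [threading.Thread(target=tma_warp), threading.Thread(target=mma_warp)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert consumed == list(range(TILES))  # every tile processed once, in order
```

The producer can run up to `STAGES` tiles ahead of the consumer before blocking, which is exactly the slack that keeps both roles busy.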
Producer-Consumer Barriers¶
V2 used TMA barriers and a single MMA barrier. V3 replaces them with a producer-consumer barrier pair, as described in the mbarrier guide:
- `full_barriers`: signaled when TMA has filled a stage (data ready). The MMA warp (consumer) waits on these.
- `empty_barriers`: signaled when MMA has consumed a stage (slot free). The TMA warp (producer) waits on these.
```python
full_barriers = self.mbarrier.alloc(counts=[1] * self.stages)
empty_barriers = self.mbarrier.alloc(counts=[1] * self.stages)
```
Each role tracks a single per-role phase variable (as in V2’s XOR-on-wrap pattern). The initial values are critical:
- `producer_phase = 1`: all mbarriers start at hardware phase 0. The producer waits with expected phase 1, which does not match the barrier's current phase, so the wait passes immediately. This is correct: all stages start empty, so the TMA warp can begin filling right away without blocking.
- `consumer_phase = 0`: the consumer waits with expected phase 0, which matches the barrier's current phase, so the wait blocks until the producer signals. This is correct: no stages are full yet, so the MMA warp must wait for the first TMA load to complete.
As in V2, the phase flips via XOR when the stage index wraps back to 0
(producer_phase ^= (stage == 0)), tracking the alternating barrier phases
across ring buffer cycles.
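You can replay this bookkeeping in plain Python to check it. The helper `phase_sequence` below is hypothetical (not part of Tilus); it just mirrors the stage/phase updates from both loops.

```python
def phase_sequence(num_tiles: int, stages: int, start_phase: int):
    """Replay the stage/phase bookkeeping from the kernel's loops."""
    seen = []
    stage, phase = 0, start_phase
    for _ in range(num_tiles):
        seen.append((stage, phase))   # (which slot, which phase we wait for)
        stage = (stage + 1) % stages  # advance ring-buffer slot
        phase ^= stage == 0           # flip phase when wrapping to slot 0
    return seen

producer = phase_sequence(num_tiles=6, stages=3, start_phase=1)
consumer = phase_sequence(num_tiles=6, stages=3, start_phase=0)
# Each slot is waited on with alternating phases across ring-buffer cycles:
# producer: [(0, 1), (1, 1), (2, 1), (0, 0), (1, 0), (2, 0)]
# consumer: [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]
```

Note that for any given slot, the consumer's expected phase is always one flip behind the producer's, which is what lets the same barrier alternate between "fill" and "drain" roles.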
Flush Barrier¶
In V2, we waited for each MMA individually (tcgen05.commit + mbarrier.wait
after every iteration), so no MMA was ever in flight when the loop ended. In V3,
the MMA warp issues tcgen05.commit without waiting — multiple MMAs can be
in flight simultaneously. After the loop ends, we need to ensure all in-flight
MMAs have completed before reading the accumulator. This is done with a flush
barrier:
```python
# drain: wait for all in-flight MMA to finish
flush_barrier = self.mbarrier.alloc(1)
self.tcgen05.commit(mbarrier=flush_barrier)
self.mbarrier.wait(flush_barrier, phase=0)
```
This extra tcgen05.commit makes the flush_barrier track all prior
uncommitted MMA operations. Once the barrier completes, it is safe to read
from tensor memory.
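The pattern is loosely analogous to submitting work without waiting and then draining all pending futures at the end. The sketch below is a plain-Python analogy, not how `tcgen05` tracks completion in hardware:

```python
from concurrent.futures import ThreadPoolExecutor, wait
import time

def mma(i):
    time.sleep(0.01)  # pretend tensor-core latency
    return i

with ThreadPoolExecutor(max_workers=4) as pool:
    # "commit without waiting": issue all operations, keep only handles
    in_flight = [pool.submit(mma, i) for i in range(8)]
    # "flush barrier": block until every in-flight operation has completed
    done, pending = wait(in_flight)

assert len(done) == 8 and not pending
```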
Walkthrough¶
TMA Warp (Producer)¶
```python
# TMA warp (producer): loads tiles from global to shared memory
with self.thread_group(thread_begin=0, num_threads=32):
    stage: int32 = 0
    # phase=1: mbarrier starts at phase 0, so waiting for phase 1
    # passes immediately (slot is empty, ready to fill)
    producer_phase: uint32 = 1
    for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
        # wait for the MMA warp to free this stage
        self.mbarrier.wait(
            empty_barriers[stage],
            phase=producer_phase,
            sem="relaxed",
            scope="cta",
        )
        with self.single_thread():
            self.mbarrier.arrive_and_expect_tx(
                full_barriers[stage],
                transaction_bytes=s_a[stage].nbytes + s_b[stage].nbytes,
            )
        self.tma.global_to_shared(
            src=g_a,
            dst=s_a[stage],
            offsets=[offset_m, offset_k],
            mbarrier=full_barriers[stage],
        )
        self.tma.global_to_shared(
            src=g_b,
            dst=s_b[stage],
            offsets=[offset_n, offset_k],
            mbarrier=full_barriers[stage],
        )
        # advance stage; flip phase when wrapping to stage 0
        stage = (stage + 1) % self.stages
        producer_phase ^= stage == 0
```
The TMA warp runs a loop over all K-tiles:
1. `mbarrier.wait()` on `empty_barriers[stage]` blocks until the MMA warp has freed this stage.
2. `mbarrier.arrive_and_expect_tx()` on `full_barriers[stage]` declares the expected TMA bytes.
3. Two `tma.global_to_shared()` calls load the A and B tiles.
4. The stage index advances modulo `stages`; the phase flips when wrapping.
Note there is no explicit prefill loop as in V2: the TMA warp simply starts
running and naturally gets ahead of the MMA warp. This works because
producer_phase starts at 1, which doesn’t match the barrier’s initial phase
(0), so the producer’s first wait passes immediately for all stages.
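The "no prefill needed" behavior follows directly from the wait rule stated earlier: a wait blocks only while the barrier is still in the expected phase. A toy model of that rule (plain Python, not the full hardware semantics):

```python
def wait_would_block(current_phase: int, expected_phase: int) -> bool:
    # mbarrier.wait(phase=p) blocks while the barrier is still in phase p
    return current_phase == expected_phase

STAGES = 3
barrier_phase = [0] * STAGES  # every mbarrier starts at hardware phase 0

# Producer expects phase 1: no match, so its first `stages` waits all fall
# through immediately -- this is the implicit prefill.
assert all(not wait_would_block(barrier_phase[s], 1) for s in range(STAGES))

# Consumer expects phase 0: match, so it blocks until the first TMA load
# completes and flips full_barriers[0] to phase 1.
assert wait_would_block(barrier_phase[0], 0)
```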
MMA Warp (Consumer)¶
```python
# MMA warp (consumer): computes on tiles loaded by the TMA warp
with self.thread_group(thread_begin=32, num_threads=32):
    # phase=0: mbarrier starts at phase 0, so waiting for phase 0
    # blocks until the producer signals (slot is not yet filled)
    consumer_phase: uint32 = 0
    stage: int32 = 0
    for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
        # wait for the TMA warp to fill this stage
        self.mbarrier.wait(
            full_barriers[stage],
            phase=consumer_phase,
            sem="relaxed",
            scope="cta",
        )
        self.tcgen05.mma(
            s_a[stage],
            s_b[stage].transpose(),
            t_acc,
            enable_input_d=offset_k != 0,
        )
        # commit signals empty_barriers: marks this stage as consumed
        self.tcgen05.commit(mbarrier=empty_barriers[stage])
        # advance stage; flip phase when wrapping to stage 0
        stage = (stage + 1) % self.stages
        consumer_phase ^= stage == 0

    # drain: wait for all in-flight MMA to finish
    flush_barrier = self.mbarrier.alloc(1)
    self.tcgen05.commit(mbarrier=flush_barrier)
    self.mbarrier.wait(flush_barrier, phase=0)
```
The MMA warp runs a matching loop:
1. `mbarrier.wait()` on `full_barriers[stage]` blocks until TMA data has arrived.
2. `tcgen05.mma()` computes on the loaded tiles.
3. `tcgen05.commit()` on `empty_barriers[stage]` signals that this stage is now free for the TMA warp to reuse.
4. The stage index advances modulo `stages`; the phase flips when wrapping.

After the loop, the flush barrier ensures all MMA writes to tensor memory are complete.
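To see why commit-without-wait matters, here is a toy discrete-event model (plain Python with illustrative constants, not Tilus code): each MMA retires `MMA_LATENCY` ticks after issue, and the ring buffer caps outstanding work at `STAGES` slots.

```python
# Discrete-event sketch: how many MMAs are in flight when the consumer
# commits without waiting (completion signals empty_barriers later).
STAGES, TILES, MMA_LATENCY = 3, 9, 2  # an MMA takes 2 "ticks" to retire

issued, retired, max_in_flight = 0, 0, 0
retire_at = []  # tick at which each issued MMA completes
tick = 0
while retired < TILES:
    # retire MMAs whose latency has elapsed (hardware signals empty_barriers)
    retired += sum(1 for t in retire_at if t == tick)
    # issue the next MMA if its data slot is ready and < STAGES outstanding
    if issued < TILES and issued - retired < STAGES:
        retire_at.append(tick + MMA_LATENCY)
        issued += 1
    max_in_flight = max(max_in_flight, issued - retired)
    tick += 1

assert max_in_flight > 1  # multiple MMAs overlap in the pipeline
```

With a wait after every MMA (the V2 pattern), `issued - retired` could never exceed 1; here the pipeline keeps up to `min(STAGES, MMA_LATENCY)` operations overlapped.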
Performance¶
Warp specialization allows multiple MMAs to be in flight simultaneously, pushing tensor core utilization to ~79%. The complete source is at examples/blackwell_matmul/matmul_v3.py.
Blackwell matmul performance on B200 (M=N=K=8192, fp16). TFLOPS derived from NCU profiling. Peak TFLOPS estimated from cuBLAS tensor core utilization (96.6%).¶
What’s Next¶
V3 achieves true overlap between TMA and MMA through warp specialization. However, the code still manages the pipeline state (stages, barriers, phases) inline, which becomes increasingly complex as we add more features.
In the next version, we refactor the pipeline logic into a reusable
Pipeline class using tilus.Class, and add tile rasterization for
better L2 cache locality.