3. Warp Specialization

In V2, each iteration issues one MMA and waits for it to complete (via tcgen05.commit + mbarrier.wait) before issuing the next. The tensor core MMA pipeline sees only one operation at a time — it goes idle between consecutive MMAs while the warp handles TMA and synchronization.

To keep the MMA pipeline busy, we want multiple MMAs in flight: issue the next MMA as soon as its input data is ready, without waiting for the previous MMA to finish writing its results. This is achieved through warp specialization — assigning TMA and MMA to separate warps, each running its own loop. The MMA warp can issue MMA after MMA without interruption, while the TMA warp independently keeps shared memory filled. The two warps synchronize only through mbarriers, allowing multiple in-flight MMA and TMA operations simultaneously.
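As a mental model, the producer-consumer handshake between the two warps can be sketched on the CPU with two Python threads standing in for the TMA and MMA warps, and per-stage Event pairs standing in for the full/empty mbarriers (all names here are illustrative, not Tilus APIs):

```python
import threading

STAGES = 3
NUM_TILES = 10

slots = [None] * STAGES                               # "shared memory" ring buffer
full = [threading.Event() for _ in range(STAGES)]     # signaled: data ready
empty = [threading.Event() for _ in range(STAGES)]    # signaled: slot free
for e in empty:
    e.set()                                           # all slots start empty

result = []

def producer():                           # plays the role of the TMA warp
    for k in range(NUM_TILES):
        stage = k % STAGES
        empty[stage].wait()               # wait for the consumer to free the slot
        empty[stage].clear()
        slots[stage] = k * k              # "load" a tile
        full[stage].set()                 # signal: data ready

def consumer():                           # plays the role of the MMA warp
    for k in range(NUM_TILES):
        stage = k % STAGES
        full[stage].wait()                # wait for the producer's data
        full[stage].clear()
        result.append(slots[stage])       # "compute" on the tile
        empty[stage].set()                # signal: slot free

t1 = threading.Thread(target=producer)
t2 = threading.Thread(target=consumer)
t1.start(); t2.start(); t1.join(); t2.join()
print(result)
```

Just as in the kernel, neither thread ever waits on its own work: the producer blocks only until a slot is free, and the consumer only until data is ready.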

Triton also performs warp specialization internally, but as a compiler pass with no user-level control. In Tilus, you explicitly assign different roles to different warps and define how they communicate.

The Full Kernel

BlackwellMatmulV3 — full kernel
@tilus.autotune(
    "block_m, block_n, e_block_n", [[128, 64, 16], [128, 128, 16], [128, 256, 16]]
)
@tilus.autotune("block_k", [16, 32, 64])
@tilus.autotune("stages", [2, 3, 4])
class BlackwellMatmulV3(tilus.Script):
    def __init__(
        self, block_m: int, block_n: int, block_k: int, stages: int, e_block_n: int
    ):
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.stages = stages
        self.e_block_n = e_block_n

    def __call__(
        self,
        m_size: int32,
        n_size: int32,
        k_size: int32,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)]
        self.attrs.warps = 4

        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
        s_a = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_m, self.block_k]
        )
        s_b = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_n, self.block_k]
        )

        t_acc = self.tcgen05.alloc(dtype=float32, shape=[self.block_m, self.block_n])

        # full_barriers: signaled when TMA has filled the stage (data ready)
        full_barriers = self.mbarrier.alloc(counts=[1] * self.stages)
        # empty_barriers: signaled when MMA has consumed the stage (slot free)
        empty_barriers = self.mbarrier.alloc(counts=[1] * self.stages)

        # TMA warp (producer): loads tiles from global to shared memory
        with self.thread_group(thread_begin=0, num_threads=32):
            stage: int32 = 0
            # phase=1: mbarrier starts at phase 0, so waiting for phase 1
            # passes immediately (slot is empty, ready to fill)
            producer_phase: uint32 = 1
            for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
                # wait for the MMA warp to free this stage
                self.mbarrier.wait(
                    empty_barriers[stage],
                    phase=producer_phase,
                    sem="relaxed",
                    scope="cta",
                )
                with self.single_thread():
                    self.mbarrier.arrive_and_expect_tx(
                        full_barriers[stage],
                        transaction_bytes=s_a[stage].nbytes + s_b[stage].nbytes,
                    )
                self.tma.global_to_shared(
                    src=g_a,
                    dst=s_a[stage],
                    offsets=[offset_m, offset_k],
                    mbarrier=full_barriers[stage],
                )
                self.tma.global_to_shared(
                    src=g_b,
                    dst=s_b[stage],
                    offsets=[offset_n, offset_k],
                    mbarrier=full_barriers[stage],
                )
                # advance stage; flip phase when wrapping to stage 0
                stage = (stage + 1) % self.stages
                producer_phase ^= stage == 0

        # MMA warp (consumer): computes on tiles loaded by the TMA warp
        with self.thread_group(thread_begin=32, num_threads=32):
            # phase=0: mbarrier starts at phase 0, so waiting for phase 0
            # blocks until the producer signals (slot is not yet filled)
            consumer_phase: uint32 = 0
            stage: int32 = 0
            for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
                # wait for the TMA warp to fill this stage
                self.mbarrier.wait(
                    full_barriers[stage],
                    phase=consumer_phase,
                    sem="relaxed",
                    scope="cta",
                )
                self.tcgen05.mma(
                    s_a[stage],
                    s_b[stage].transpose(),
                    t_acc,
                    enable_input_d=offset_k != 0,
                )
                # commit signals empty_barriers: marks this stage as consumed
                self.tcgen05.commit(mbarrier=empty_barriers[stage])
                # advance stage; flip phase when wrapping to stage 0
                stage = (stage + 1) % self.stages
                consumer_phase ^= stage == 0

            # drain: wait for all in-flight MMA to finish
            flush_barrier = self.mbarrier.alloc(1)
            self.tcgen05.commit(mbarrier=flush_barrier)
            self.mbarrier.wait(flush_barrier, phase=0)

        self.sync()

        # TMA epilogue: tmem -> register -> shared -> global (via TMA)
        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        s_c = self.shared_tensor(dtype=float16, shape=[self.block_m, self.e_block_n])
        for e_offset_n in range(0, self.block_n, self.e_block_n):
            t_acc_slice = self.tcgen05.slice(
                t_acc,
                offsets=[0, e_offset_n],
                shape=[self.block_m, self.e_block_n],
                dims=[0, 1],
            )
            r_acc = self.tcgen05.load(t_acc_slice)
            self.tcgen05.wait_load()
            self.store_shared(s_c, r_acc.to(float16))
            self.fence.proxy_async(space="shared")
            self.sync()
            with self.single_warp():
                self.tma.shared_to_global(
                    s_c,
                    g_c,
                    offsets=[offset_m, offset_n + e_offset_n],
                    dims=[0, 1],
                )
                self.tma.commit_group()
                self.tma.wait_group(n=0, read=True)
            self.sync()

        self.sync()
        self.tcgen05.dealloc(t_acc)

What Changed from V2

                  V2                                      V3
Warp structure    Single warp does both TMA and MMA       TMA warp + MMA warp (warp specialization)
Barriers          TMA barriers + 1 MMA barrier            full_barriers + empty_barriers (producer-consumer)
Parallelism       TMA and MMA overlap across iterations   TMA and MMA run truly in parallel on separate warps
In-flight MMA     1 (wait after each MMA)                 Multiple (MMA warp issues next MMA without waiting for the previous one)
Prefill           Explicit prefill loop                   Implicit: TMA warp runs ahead of MMA warp
New instructions  (none)                                  thread_group(), commit() (with consumer mbarrier)

Warp Specialization

In V2, a single warp handles both TMA and MMA. Each iteration looks like: issue TMA → wait for TMA → issue MMA → wait for MMA → repeat. The tcgen05.commit + mbarrier.wait after each MMA means the tensor core pipeline sees at most one MMA at a time. Between consecutive MMAs, the pipeline goes idle while the warp handles synchronization, TMA issuing, and TMA waiting.

With warp specialization, we assign TMA and MMA to separate warps, each running its own loop:

  • TMA warp (threads 0–31): issues TMA loads back-to-back. Before each load, it only waits on empty_barriers to ensure the shared memory slot is free.

  • MMA warp (threads 32–63): issues MMA operations back-to-back. Before each MMA, it only waits on full_barriers to ensure the input data has arrived. Crucially, it does not wait for the previous MMA to complete — it issues tcgen05.commit(mbarrier=empty_barriers[stage]) which returns immediately and makes the barrier track the MMA’s completion in the background. When the MMA actually finishes, the hardware will signal the barrier, freeing the stage for the TMA warp. Meanwhile, the MMA warp has already moved on to the next iteration. This allows multiple MMAs to be in flight in the tensor core pipeline.

The result: as long as TMA delivers data fast enough, the MMA warp keeps the tensor core pipeline fully occupied with no idle gaps between operations.

Producer-Consumer Barriers

V2 used TMA barriers and a single MMA barrier. V3 replaces them with a producer-consumer barrier pair, as described in the mbarrier guide:

  • full_barriers: signaled when TMA has filled a stage (data ready). The MMA warp (consumer) waits on these.

  • empty_barriers: signaled when MMA has consumed a stage (slot free). The TMA warp (producer) waits on these.

full_barriers = self.mbarrier.alloc(counts=[1] * self.stages)
empty_barriers = self.mbarrier.alloc(counts=[1] * self.stages)

Each role tracks a single per-role phase variable (as in V2’s XOR-on-wrap pattern). The initial values are critical:

  • producer_phase = 1: all mbarriers start at hardware phase 0. The producer waits with expected phase 1, which does not match the barrier’s current phase — so the wait passes immediately. This is correct: all stages start empty, so the TMA warp can begin filling right away without blocking.

  • consumer_phase = 0: the consumer waits with expected phase 0, which matches the barrier’s current phase — so the wait blocks until the producer signals. This is correct: no stages are full yet, so the MMA warp must wait for the first TMA load to complete.

As in V2, the phase flips via XOR when the stage index wraps back to 0 (producer_phase ^= (stage == 0)), tracking the alternating barrier phases across ring buffer cycles.
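The resulting (stage, expected-phase) sequence can be sketched in a few lines of plain Python (the function name is illustrative):

```python
# Per-role phase bookkeeping as described above: the expected phase starts at
# `init` and flips (XOR) every time the stage index wraps back to 0.
def phase_sequence(stages, iters, init):
    out = []
    stage, phase = 0, init
    for _ in range(iters):
        out.append((stage, phase))
        stage = (stage + 1) % stages
        phase ^= stage == 0              # flip on wrap, as in the kernel
    return out

# With 2 stages, the producer (init=1) expects:
print(phase_sequence(2, 6, init=1))
# [(0, 1), (1, 1), (0, 0), (1, 0), (0, 1), (1, 1)]
```

Note how the expected phase alternates once per full trip around the ring buffer, matching the hardware mbarrier, which flips its phase each time it completes.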

Flush Barrier

In V2, we waited for each MMA individually (tcgen05.commit + mbarrier.wait after every iteration), so no MMA was ever in flight when the loop ended. In V3, the MMA warp issues tcgen05.commit without waiting — multiple MMAs can be in flight simultaneously. After the loop ends, we need to ensure all in-flight MMAs have completed before reading the accumulator. This is done with a flush barrier:

# drain: wait for all in-flight MMA to finish
flush_barrier = self.mbarrier.alloc(1)
self.tcgen05.commit(mbarrier=flush_barrier)
self.mbarrier.wait(flush_barrier, phase=0)

This extra tcgen05.commit makes the flush_barrier track all prior uncommitted MMA operations. Once the barrier completes, it is safe to read from tensor memory.
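As a loose CPU analogy (not the hardware semantics), the flush barrier plays the same role as draining an executor of its in-flight tasks before reading the results:

```python
from concurrent.futures import ThreadPoolExecutor

acc = []
with ThreadPoolExecutor(max_workers=4) as pool:
    for i in range(8):
        pool.submit(acc.append, i)    # fire-and-forget, like an unwaited MMA
# leaving the `with` block drains all in-flight tasks, like the flush barrier
print(sorted(acc))  # [0, 1, 2, 3, 4, 5, 6, 7]
```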

Walkthrough

TMA Warp (Producer)

TMA warp
# TMA warp (producer): loads tiles from global to shared memory
with self.thread_group(thread_begin=0, num_threads=32):
    stage: int32 = 0
    # phase=1: mbarrier starts at phase 0, so waiting for phase 1
    # passes immediately (slot is empty, ready to fill)
    producer_phase: uint32 = 1
    for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
        # wait for the MMA warp to free this stage
        self.mbarrier.wait(
            empty_barriers[stage],
            phase=producer_phase,
            sem="relaxed",
            scope="cta",
        )
        with self.single_thread():
            self.mbarrier.arrive_and_expect_tx(
                full_barriers[stage],
                transaction_bytes=s_a[stage].nbytes + s_b[stage].nbytes,
            )
        self.tma.global_to_shared(
            src=g_a,
            dst=s_a[stage],
            offsets=[offset_m, offset_k],
            mbarrier=full_barriers[stage],
        )
        self.tma.global_to_shared(
            src=g_b,
            dst=s_b[stage],
            offsets=[offset_n, offset_k],
            mbarrier=full_barriers[stage],
        )
        # advance stage; flip phase when wrapping to stage 0
        stage = (stage + 1) % self.stages
        producer_phase ^= stage == 0

The TMA warp runs a loop over all K-tiles:

  • mbarrier.wait() on empty_barriers[stage] blocks until the MMA warp has freed this stage.

  • arrive_and_expect_tx() arms full_barriers[stage] with the total transaction bytes of the two incoming tiles.

  • Two tma.global_to_shared() calls load the A and B tiles, each reporting its bytes to full_barriers[stage].

  • The stage index advances modulo stages; the phase flips when wrapping.

Note there is no explicit prefill loop as in V2: the TMA warp simply starts running and naturally gets ahead of the MMA warp. This works because producer_phase starts at 1, which doesn’t match the barrier’s initial phase (0), so the producer’s first wait passes immediately for all stages.
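A tiny model of the wait rule (a wait passes as soon as the barrier's phase differs from the expected phase) shows the producer getting exactly one free pass per stage before it would block, assuming the consumer never signals (illustrative code, not Tilus APIs):

```python
def free_passes(stages):
    # empty_barriers all start at hardware phase 0 and are never signaled here
    barrier_phase = [0] * stages
    expected, stage, passes = 1, 0, 0
    while barrier_phase[stage] != expected:  # wait passes when phases differ
        passes += 1
        stage = (stage + 1) % stages
        expected ^= stage == 0               # flip on wrap, as in the kernel
    return passes

print(free_passes(3))  # 3 -- the producer fills every stage before blocking
```

After `stages` iterations the expected phase flips to 0, matching the still-unsignaled barriers, so the producer blocks exactly where the explicit prefill loop in V2 would have ended.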

MMA Warp (Consumer)

MMA warp
# MMA warp (consumer): computes on tiles loaded by the TMA warp
with self.thread_group(thread_begin=32, num_threads=32):
    # phase=0: mbarrier starts at phase 0, so waiting for phase 0
    # blocks until the producer signals (slot is not yet filled)
    consumer_phase: uint32 = 0
    stage: int32 = 0
    for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
        # wait for the TMA warp to fill this stage
        self.mbarrier.wait(
            full_barriers[stage],
            phase=consumer_phase,
            sem="relaxed",
            scope="cta",
        )
        self.tcgen05.mma(
            s_a[stage],
            s_b[stage].transpose(),
            t_acc,
            enable_input_d=offset_k != 0,
        )
        # commit signals empty_barriers: marks this stage as consumed
        self.tcgen05.commit(mbarrier=empty_barriers[stage])
        # advance stage; flip phase when wrapping to stage 0
        stage = (stage + 1) % self.stages
        consumer_phase ^= stage == 0

    # drain: wait for all in-flight MMA to finish
    flush_barrier = self.mbarrier.alloc(1)
    self.tcgen05.commit(mbarrier=flush_barrier)
    self.mbarrier.wait(flush_barrier, phase=0)

The MMA warp runs a matching loop:

  • mbarrier.wait() on full_barriers[stage] blocks until TMA data has arrived.

  • tcgen05.mma() computes on the loaded tiles.

  • tcgen05.commit() on empty_barriers[stage] signals that this stage is now free for the TMA warp to reuse.

  • The stage index advances modulo stages; the phase flips when wrapping.

  • After the loop, the flush barrier ensures all MMA writes to tensor memory are complete.

Performance

Warp specialization allows multiple MMAs to be in flight simultaneously, pushing tensor core utilization to ~79%. The complete source is at examples/blackwell_matmul/matmul_v3.py.

../../_images/plot_v3.svg

Blackwell matmul performance on B200 (M=N=K=8192, fp16). TFLOPS derived from NCU profiling. Peak TFLOPS estimated from cuBLAS tensor core utilization (96.6%).

What’s Next

V3 achieves true overlap between TMA and MMA through warp specialization. However, the code still manages the pipeline state (stages, barriers, phases) inline, which becomes increasingly complex as we add more features.

In the next version, we refactor the pipeline logic into a reusable Pipeline class using tilus.Class, and add tile rasterization for better L2 cache locality.