4. Tile Rasterization and Pipeline Abstraction¶
V3 introduced warp specialization with separate TMA and MMA warps. This version adds two improvements:
Tile rasterization — reorder tile assignments so that consecutive thread blocks process tiles that share B columns, improving L2 cache reuse.
Pipeline abstraction — encapsulate the producer-consumer barrier logic from V3 into a reusable
Pipeline class using tilus.Class, making it easy to add more pipelines as the kernel grows (e.g., separate TMA and MMA pipelines in later versions).
The Full Kernel¶
class Pipeline(tilus.Class):
    """Producer-consumer ring buffer built on mbarriers.

    Each of the ``num_stages`` slots has two mbarriers:

    - ``full_barriers[i]``: signaled when the producer has filled slot ``i``
      (the consumer waits on it).
    - ``empty_barriers[i]``: signaled when the consumer has drained slot ``i``
      (the producer waits on it).

    Producer and consumer each keep their own stage index and phase bit and
    advance through the ring independently, synchronized only by the barriers.
    """

    def __init__(
        self,
        num_stages: int,
        producer_arrive_count: int = 1,
        consumer_arrive_count: int = 1,
    ):
        # number of slots in the ring buffer
        self.num_stages: int = num_stages
        # one "empty" barrier per stage; each expects `consumer_arrive_count`
        # arrivals before completing
        self.empty_barriers = self.mbarrier.alloc(
            [consumer_arrive_count for _ in range(num_stages)]
        )
        # one "full" barrier per stage; each expects `producer_arrive_count`
        # arrivals (plus any expected TMA transaction bytes) before completing
        self.full_barriers = self.mbarrier.alloc(
            [producer_arrive_count for _ in range(num_stages)]
        )
        # slot each side is currently working on
        self.producer_stage: int32 = 0
        self.consumer_stage: int32 = 0
        # phase bits for mbarrier waits; flipped whenever a side wraps around
        self.producer_phase: uint32 = self.mbarrier.producer_initial_phase
        self.consumer_phase: uint32 = self.mbarrier.consumer_initial_phase

    def producer_acquire(self):
        # wait until the current stage is free (consumer has finished with it)
        self.mbarrier.wait(
            barrier=self.empty_barriers[self.producer_stage],
            phase=self.producer_phase,
            sem="relaxed",
            scope="cta",
        )

    def producer_barrier(self) -> RegisterTensor:
        # return the barrier to signal when the producer has filled this stage
        return self.full_barriers[self.producer_stage]

    def producer_advance(self):
        # advance to the next stage; flip phase when wrapping around
        self.producer_stage = (self.producer_stage + 1) % self.num_stages
        self.producer_phase = self.producer_phase ^ (self.producer_stage == 0)

    def consumer_acquire(self):
        # wait until the current stage is filled (producer has loaded data)
        self.mbarrier.wait(
            barrier=self.full_barriers[self.consumer_stage],
            phase=self.consumer_phase,
            sem="relaxed",
            scope="cta",
        )

    def consumer_barrier(self) -> RegisterTensor:
        # return the barrier to signal when the consumer has consumed this stage
        return self.empty_barriers[self.consumer_stage]

    def consumer_advance(self):
        # advance to the next stage; flip phase when wrapping around
        self.consumer_stage = (self.consumer_stage + 1) % self.num_stages
        self.consumer_phase = self.consumer_phase ^ (self.consumer_stage == 0)
@tilus.autotune(
    "block_m, block_n, e_block_n", [[128, 64, 16], [128, 128, 16], [128, 256, 16]]
)
@tilus.autotune("block_k", [16, 32, 64])
@tilus.autotune("stages", [2, 3, 4])
@tilus.autotune("swizzle_size", [1, 4, 8])
class BlackwellMatmulV4(tilus.Script):
    """Warp-specialized Blackwell matmul: tile rasterization + Pipeline helper.

    A TMA warp produces A/B tiles into a shared-memory ring buffer and an MMA
    warp consumes them with tcgen05 MMA; the producer/consumer barrier
    bookkeeping is delegated to the Pipeline class.  The launch grid is 1D and
    compute_block_coord remaps blockIdx.x to swizzled (m, n) tile coordinates
    so consecutive thread blocks share B columns (better L2 reuse).
    """

    def __init__(
        self,
        block_m: int,
        block_n: int,
        block_k: int,
        stages: int,
        e_block_n: int,
        swizzle_size: int,
    ):
        # All parameters are supplied by the autotuner (see decorators above).
        super().__init__()
        self.block_m = block_m            # rows of C computed per thread block
        self.block_n = block_n            # columns of C computed per thread block
        self.block_k = block_k            # K-slice width loaded per pipeline stage
        self.stages = stages              # depth of the TMA->MMA ring buffer
        self.e_block_n = e_block_n        # column width of each epilogue chunk
        self.swizzle_size = swizzle_size  # N-columns per rasterization group

    def compute_block_coord(
        self, linear_idx: int32, num_m_blocks: int32, num_n_blocks: int
    ):
        """Map a 1D linear block index to 2D (m_block, n_block) with swizzle grouping.

        Tiles within a swizzle group share N-columns, improving L2 cache reuse
        for the B matrix.  Returns the (m_block, n_block) pair for this block.
        """
        swizzle_size = self.swizzle_size
        # each group of swizzle_size N-columns contains num_m_blocks tile rows
        tiles_per_group = num_m_blocks * swizzle_size
        group_idx, in_group_idx = self.fast_divmod(linear_idx, tiles_per_group)
        first_n = group_idx * swizzle_size
        m_block: int32 = 0
        n_block: int32 = 0
        # When num_n_blocks is divisible by swizzle_size, all groups are full and
        # last_group_width is never used. Use swizzle_size as a safe fallback to
        # avoid division-by-zero in the precompute.
        remainder = num_n_blocks - num_n_blocks // swizzle_size * swizzle_size
        last_group_width = remainder if remainder > 0 else swizzle_size
        if first_n + swizzle_size <= num_n_blocks:
            # Full group: swizzle_size is a compile-time constant
            m_block, r = self.fast_divmod(in_group_idx, swizzle_size)
            n_block = first_n + r
        else:
            # Last group: divisor is num_n_blocks % swizzle_size, which is grid-constant
            m_block, r = self.fast_divmod(in_group_idx, last_group_width)
            n_block = first_n + r
        return m_block, n_block

    def __call__(
        self,
        m_size: int32,
        n_size: int,
        k_size: int,
        a_ptr: ~float16,  # A, fp16, row-major [m_size, k_size]
        b_ptr: ~float16,  # B, fp16, stored [n_size, k_size] (transposed in MMA)
        c_ptr: ~float16,  # C, fp16, row-major [m_size, n_size]
    ):
        """Kernel entry point: one thread block computes one C tile."""
        block_m = self.block_m
        block_n = self.block_n
        block_k = self.block_k
        stages = self.stages
        e_block_n = self.e_block_n
        num_m_blocks = cdiv(m_size, block_m)
        num_n_blocks = cdiv(n_size, block_n)
        # 1D grid: tile rasterization maps linear index to 2D coordinates
        self.attrs.blocks = num_m_blocks * num_n_blocks
        self.attrs.warps = 4
        # tile rasterization: swizzle for better L2 cache reuse of B columns
        m_block, n_block = self.compute_block_coord(
            self.blockIdx.x, num_m_blocks, num_n_blocks
        )
        offset_m: int32 = m_block * block_m
        offset_n: int32 = n_block * block_n
        # global views of the three operands
        g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        # shared-memory ring buffers: one [block_m/n, block_k] slot per stage
        s_a = self.shared_tensor(dtype=float16, shape=[stages, block_m, block_k])
        s_b = self.shared_tensor(dtype=float16, shape=[stages, block_n, block_k])
        # fp32 accumulator in tensor memory
        t_acc = self.tcgen05.alloc(dtype=float32, shape=[block_m, block_n])
        # Pipeline class encapsulates barrier/phase/stage management from V3
        tma_pipe = Pipeline(stages)
        # barrier used to detect completion of the final MMA before the epilogue
        flush_barrier = self.mbarrier.alloc(1)
        # --- TMA warp (threads 0-31): producer ---
        with self.thread_group(thread_begin=0, num_threads=32):
            for offset_k in self.range(0, k_size, block_k, unroll=stages):
                tma_pipe.producer_acquire()
                with self.single_thread():
                    # one thread registers the expected transaction bytes so the
                    # full barrier completes only when both TMA copies have landed
                    self.mbarrier.arrive_and_expect_tx(
                        tma_pipe.producer_barrier(),
                        transaction_bytes=s_a[tma_pipe.producer_stage].nbytes
                        + s_b[tma_pipe.producer_stage].nbytes,
                    )
                self.tma.global_to_shared(
                    src=g_a,
                    dst=s_a[tma_pipe.producer_stage],
                    offsets=[offset_m, offset_k],
                    mbarrier=tma_pipe.producer_barrier(),
                )
                self.tma.global_to_shared(
                    src=g_b,
                    dst=s_b[tma_pipe.producer_stage],
                    offsets=[offset_n, offset_k],
                    mbarrier=tma_pipe.producer_barrier(),
                )
                tma_pipe.producer_advance()
        # --- MMA warp (threads 32-63): consumer ---
        with self.thread_group(thread_begin=32, num_threads=32):
            for offset_k in self.range(0, k_size, block_k, unroll=stages):
                tma_pipe.consumer_acquire()
                self.tcgen05.mma(
                    s_a[tma_pipe.consumer_stage],
                    s_b[tma_pipe.consumer_stage].transpose(),
                    t_acc,
                    # first K-slice overwrites the accumulator; later slices add
                    enable_input_d=offset_k != 0,
                )
                self.tcgen05.commit(mbarrier=tma_pipe.consumer_barrier())
                tma_pipe.consumer_advance()
            # wait until the last MMA has drained before the epilogue reads t_acc
            self.tcgen05.commit(mbarrier=flush_barrier)
            self.mbarrier.wait(flush_barrier, phase=0)
        self.sync()
        # TMA epilogue: tmem -> register -> shared -> global (via TMA)
        s_c = self.shared_tensor(dtype=float16, shape=[block_m, e_block_n])
        for e_offset_n in range(0, block_n, e_block_n):
            # slice a e_block_n-wide column from the accumulator
            t_acc_slice = self.tcgen05.slice(
                t_acc,
                offsets=[0, e_offset_n],
                shape=[block_m, e_block_n],
                dims=[0, 1],
            )
            r_acc = self.tcgen05.load(t_acc_slice)
            self.tcgen05.wait_load()
            self.store_shared(s_c, r_acc.to(float16))
            # fence: make generic-proxy writes visible to async-proxy (TMA)
            self.fence.proxy_async(space="shared")
            self.sync()
            with self.single_warp():
                # TMA bulk store from shared to global
                self.tma.shared_to_global(
                    s_c,
                    g_c,
                    offsets=[offset_m, offset_n + e_offset_n],
                    dims=[0, 1],
                )
                self.tma.commit_group()
                self.tma.wait_group(n=0, read=True)
            # ensure the TMA store has read s_c before the next chunk reuses it
            self.sync()
        self.sync()
        self.tcgen05.dealloc(t_acc)
What Changed from V3¶
| | V3 | V4 |
|---|---|---|
| Grid layout | 2D grid (column-major) | 1D grid with swizzled tile rasterization for L2 locality |
| Barrier management | Manual barriers, phases, and stage indices | Encapsulated in a reusable Pipeline class |
| Integer division | Standard // and % | fast_divmod() for grid-constant divisors |
| New instructions | — | fast_divmod() |
Tile Rasterization¶
In matmul, each output tile (m, n) needs a row-strip of A (row m) and a
column-strip of B (column n). A rows are unique per tile, but B columns are
shared across all tiles in the same N-column. This means B data loaded by
one tile can be reused by other tiles that share the same N-column — if the
data is still in L2 cache.
The question is: how do we order the tiles so that the GPU’s L2 cache is used effectively? The key idea is to minimize the working set — the number of distinct A rows and B columns that must be in L2 simultaneously for the currently active thread blocks.
An 8 × 8 tile grid with a wave of 16 active blocks (blue cells). Orange bars on the left mark active A rows; blue bars on top mark active B columns. Swizzle achieves a smaller L2 working set (8 vs 10 tiles) for the same number of active blocks.¶
A concrete example. Consider an 8 × 8 grid of tiles with
block_m = block_n = 128 and K = 8192. Each A row-strip and each B
column-strip is 128 × 8192 in fp16 = 2 MB.
Suppose the GPU can run 16 thread blocks concurrently (one wave). With
column-major ordering (the naive 2D grid from V3), blockIdx.x walks
down M first, so the 16 active blocks fill columns 0 and 1 completely (8 rows
each). As shown in the figure:
Active A rows: 8 (all rows touched). Working set: 8 × 2 MB = 16 MB.
Active B columns: 2 (cols 0–1). Working set: 2 × 2 MB = 4 MB.
Total L2 working set: 20 MB (10 unique tiles).
With swizzled ordering (swizzle_size = 4), the same 16 blocks are
mapped to rows 0–3 of columns 0–3 — a compact 4 × 4 square in the
top-left of the grid:
Active A rows: 4 (rows 0–3 only). Working set: 4 × 2 MB = 8 MB.
Active B columns: 4 (cols 0–3). Working set: 4 × 2 MB = 8 MB.
Total L2 working set: 16 MB (8 unique tiles) — 20% smaller.
The reduction comes from balancing A and B: column-major packs all 8 rows into 2 columns (8 + 2 = 10 tiles), while swizzle distributes the same 16 blocks across 4 rows and 4 columns (4 + 4 = 8 tiles). Fewer unique tiles in L2 means higher hit rates and less off-chip memory traffic.
Formulation¶
Given a grid of num_m_blocks × num_n_blocks tiles and a
swizzle_size, the swizzled mapping works as follows:
Divide the N-columns into groups of swizzle_size: group group_idx covers columns [group_idx * swizzle_size, group_idx * swizzle_size + swizzle_size).
Within each group, the tile shape is [num_m_blocks, swizzle_size], and tiles are assigned in row-major order: (m_block=0, n_block=0), (m_block=0, n_block=1), ..., (m_block=1, n_block=0), ...
This makes consecutive thread blocks touch nearby rows and columns, producing a
more compact active region in the grid. The swizzle_size controls the
trade-off between A and B working sets: a larger value keeps more B columns
active (increasing B working set) but fewer A rows active (decreasing A working
set). The optimal value depends on the problem size and L2 cache capacity, so
it is autotuned.
Implementation¶
The grid is launched as 1D, and compute_block_coord remaps each
blockIdx.x to the swizzled (m, n) coordinates:
num_m_blocks = cdiv(m_size, block_m)
num_n_blocks = cdiv(n_size, block_n)
# 1D grid: tile rasterization maps linear index to 2D coordinates
self.attrs.blocks = num_m_blocks * num_n_blocks
self.attrs.warps = 4
# tile rasterization: swizzle for better L2 cache reuse of B columns
m_block, n_block = self.compute_block_coord(
self.blockIdx.x, num_m_blocks, num_n_blocks
)
offset_m: int32 = m_block * block_m
offset_n: int32 = n_block * block_n
The mapping logic:
def compute_block_coord(
    self, linear_idx: int32, num_m_blocks: int32, num_n_blocks: int
):
    """Map a 1D linear block index to 2D (m_block, n_block) with swizzle grouping.

    Tiles within a swizzle group share N-columns, improving L2 cache reuse
    for the B matrix.  Returns the (m_block, n_block) pair for this block.
    """
    swizzle_size = self.swizzle_size
    # each group of swizzle_size N-columns contains num_m_blocks tile rows
    tiles_per_group = num_m_blocks * swizzle_size
    group_idx, in_group_idx = self.fast_divmod(linear_idx, tiles_per_group)
    first_n = group_idx * swizzle_size
    m_block: int32 = 0
    n_block: int32 = 0
    # When num_n_blocks is divisible by swizzle_size, all groups are full and
    # last_group_width is never used. Use swizzle_size as a safe fallback to
    # avoid division-by-zero in the precompute.
    remainder = num_n_blocks - num_n_blocks // swizzle_size * swizzle_size
    last_group_width = remainder if remainder > 0 else swizzle_size
    if first_n + swizzle_size <= num_n_blocks:
        # Full group: swizzle_size is a compile-time constant
        m_block, r = self.fast_divmod(in_group_idx, swizzle_size)
        n_block = first_n + r
    else:
        # Last group: divisor is num_n_blocks % swizzle_size, which is grid-constant
        m_block, r = self.fast_divmod(in_group_idx, last_group_width)
        n_block = first_n + r
    return m_block, n_block
When num_n_blocks is not divisible by swizzle_size, the last group has
fewer than swizzle_size columns. The code handles this by computing
last_group_width (the remainder) and using it as the divisor for tiles in
the last group, ensuring correct (m_block, n_block) mapping.
Hint
Integer division and modulo are expensive on GPUs. When the divisor is a
compile-time constant (like swizzle_size), the compiler converts division
into a mul + shift automatically, so normal // and % are fine. For
non-constant divisors, the NVIDIA compiler falls back to floating-point
arithmetic with int-to-float and float-to-int conversions, which is slow.
For grid-constant divisors (like tiles_per_group, which is the same
for all thread blocks but not known at compile time),
fast_divmod() implements the fast divmod algorithm using
integer mul + shift, precomputing the magic number once per grid launch.
Pipeline Abstraction¶
On Blackwell, mbarriers are the fundamental mechanism for tracking completion of asynchronous work, and shared memory (or tensor memory) serves as the buffer for data in transit. When the producer and consumer operate at different speeds — which is always the case in practice — we need a pipeline to decouple them.
A pipeline has three components:
Producer — generates data and writes it into a buffer slot when one is available.
Consumer — reads data from a buffer slot when one is filled.
Ring buffer — a fixed number of slots (
stages) that the producer and consumer cycle through independently.
Each slot has two mbarriers:
full_barrier — signaled when the producer has filled the slot. The consumer waits on this.
empty_barrier — signaled when the consumer has consumed the slot. The producer waits on this.
The producer and consumer each maintain a stage pointer (which slot they are currently working on) and a phase variable (for the mbarrier of their current slot). Both advance through the ring buffer independently, synchronized only by the mbarrier signals.
The Pipeline with 5 stages. The producer is filling slot 3; the consumer is consuming slot 1. Slot 2 is full (waiting to be consumed); slots 0 and 4 are empty (waiting to be filled). The ✓/✗ marks indicate whether each slot’s full/empty mbarrier has completed.¶
In V3, we managed all this state manually (barriers, phases, stage indices). The
Pipeline class below encapsulates the bookkeeping into a clean API. Note that
this is not a built-in part of Tilus; it is constructed from existing
instructions (mbarrier.alloc, mbarrier.wait, etc.) as a user-level
helper. You can always manage the mbarriers manually as in V3 if you prefer.
class Pipeline(tilus.Class):
    """Ring-buffer pipeline decoupling a producer warp from a consumer warp.

    Slot ``i`` is guarded by two mbarriers: ``full_barriers[i]`` (producer
    signals "data ready", consumer waits) and ``empty_barriers[i]`` (consumer
    signals "slot drained", producer waits).  Each side tracks its own stage
    index and phase bit and advances through the ring independently.
    """

    def __init__(
        self,
        num_stages: int,
        producer_arrive_count: int = 1,
        consumer_arrive_count: int = 1,
    ):
        # number of slots in the ring buffer
        self.num_stages: int = num_stages
        # per-stage "empty" barriers, completing after `consumer_arrive_count`
        # arrivals
        self.empty_barriers = self.mbarrier.alloc(
            [consumer_arrive_count for _ in range(num_stages)]
        )
        # per-stage "full" barriers, completing after `producer_arrive_count`
        # arrivals (plus any expected TMA transaction bytes)
        self.full_barriers = self.mbarrier.alloc(
            [producer_arrive_count for _ in range(num_stages)]
        )
        # slot each side is currently working on
        self.producer_stage: int32 = 0
        self.consumer_stage: int32 = 0
        # phase bits for mbarrier waits; flipped on every wrap-around
        self.producer_phase: uint32 = self.mbarrier.producer_initial_phase
        self.consumer_phase: uint32 = self.mbarrier.consumer_initial_phase

    def producer_acquire(self):
        # wait until the current stage is free (consumer has finished with it)
        self.mbarrier.wait(
            barrier=self.empty_barriers[self.producer_stage],
            phase=self.producer_phase,
            sem="relaxed",
            scope="cta",
        )

    def producer_barrier(self) -> RegisterTensor:
        # return the barrier to signal when the producer has filled this stage
        return self.full_barriers[self.producer_stage]

    def producer_advance(self):
        # advance to the next stage; flip phase when wrapping around
        self.producer_stage = (self.producer_stage + 1) % self.num_stages
        self.producer_phase = self.producer_phase ^ (self.producer_stage == 0)

    def consumer_acquire(self):
        # wait until the current stage is filled (producer has loaded data)
        self.mbarrier.wait(
            barrier=self.full_barriers[self.consumer_stage],
            phase=self.consumer_phase,
            sem="relaxed",
            scope="cta",
        )

    def consumer_barrier(self) -> RegisterTensor:
        # return the barrier to signal when the consumer has consumed this stage
        return self.empty_barriers[self.consumer_stage]

    def consumer_advance(self):
        # advance to the next stage; flip phase when wrapping around
        self.consumer_stage = (self.consumer_stage + 1) % self.num_stages
        self.consumer_phase = self.consumer_phase ^ (self.consumer_stage == 0)
The Pipeline class inherits from tilus.Class, which works like
Script but for helper objects that are not kernels themselves.
It can allocate barriers, shared tensors, and use all Tilus instructions.
The usage in the kernel becomes straightforward:
tma_pipe = Pipeline(stages)
# TMA warp (producer)
tma_pipe.producer_acquire() # wait for empty slot
# ... issue TMA loads with tma_pipe.producer_barrier() ...
tma_pipe.producer_advance() # move to next stage
# MMA warp (consumer)
tma_pipe.consumer_acquire() # wait for filled slot
# ... issue MMA ...
self.tcgen05.commit(mbarrier=tma_pipe.consumer_barrier()) # signal slot consumed
tma_pipe.consumer_advance() # move to next stage
This pattern is reusable: later versions add a second pipeline (mma_pipe)
between the MMA and epilogue stages, using the same Pipeline class.
Walkthrough¶
The kernel structure is the same as V3 (TMA warp + MMA warp), but now using
the Pipeline class. The epilogue is unchanged from V1.
Setup¶
num_m_blocks = cdiv(m_size, block_m)
num_n_blocks = cdiv(n_size, block_n)
# 1D grid: tile rasterization maps linear index to 2D coordinates
self.attrs.blocks = num_m_blocks * num_n_blocks
self.attrs.warps = 4
# tile rasterization: swizzle for better L2 cache reuse of B columns
m_block, n_block = self.compute_block_coord(
self.blockIdx.x, num_m_blocks, num_n_blocks
)
offset_m: int32 = m_block * block_m
offset_n: int32 = n_block * block_n
g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
s_a = self.shared_tensor(dtype=float16, shape=[stages, block_m, block_k])
s_b = self.shared_tensor(dtype=float16, shape=[stages, block_n, block_k])
t_acc = self.tcgen05.alloc(dtype=float32, shape=[block_m, block_n])
# Pipeline class encapsulates barrier/phase/stage management from V3
tma_pipe = Pipeline(stages)
flush_barrier = self.mbarrier.alloc(1)
Key differences from V3:
The grid is 1D: blocks = num_m_blocks * num_n_blocks.
compute_block_coord maps the linear index to swizzled (m, n) coordinates.
Pipeline(stages) replaces manual barrier/phase/stage management.
TMA and MMA Warps¶
with self.thread_group(thread_begin=0, num_threads=32):
for offset_k in self.range(0, k_size, block_k, unroll=stages):
tma_pipe.producer_acquire()
with self.single_thread():
self.mbarrier.arrive_and_expect_tx(
tma_pipe.producer_barrier(),
transaction_bytes=s_a[tma_pipe.producer_stage].nbytes
+ s_b[tma_pipe.producer_stage].nbytes,
)
self.tma.global_to_shared(
src=g_a,
dst=s_a[tma_pipe.producer_stage],
offsets=[offset_m, offset_k],
mbarrier=tma_pipe.producer_barrier(),
)
self.tma.global_to_shared(
src=g_b,
dst=s_b[tma_pipe.producer_stage],
offsets=[offset_n, offset_k],
mbarrier=tma_pipe.producer_barrier(),
)
tma_pipe.producer_advance()
with self.thread_group(thread_begin=32, num_threads=32):
for offset_k in self.range(0, k_size, block_k, unroll=stages):
tma_pipe.consumer_acquire()
self.tcgen05.mma(
s_a[tma_pipe.consumer_stage],
s_b[tma_pipe.consumer_stage].transpose(),
t_acc,
enable_input_d=offset_k != 0,
)
self.tcgen05.commit(mbarrier=tma_pipe.consumer_barrier())
tma_pipe.consumer_advance()
self.tcgen05.commit(mbarrier=flush_barrier)
self.mbarrier.wait(flush_barrier, phase=0)
The logic is identical to V3, but expressed through the Pipeline API:
producer_acquire / producer_barrier / producer_advance for the TMA
warp, and consumer_acquire / consumer_barrier / consumer_advance
for the MMA warp.
Performance¶
Tile rasterization improves L2 cache reuse for B tiles, and the Pipeline abstraction simplifies the code without sacrificing performance. The complete source is at examples/blackwell_matmul/matmul_v4.py.
Blackwell matmul performance on B200 (M=N=K=8192, fp16). TFLOPS derived from NCU profiling. Peak TFLOPS estimated from cuBLAS tensor core utilization (96.6%).¶
What’s Next¶
V4 is a well-optimized single-CTA kernel. To push performance further, we need to look beyond the single thread block.
In the next version, we introduce CLC (Cluster Launch Control) for persistent kernels — each CTA processes multiple output tiles dynamically via hardware scheduling, avoiding kernel launch overhead. We also add a pipelined epilogue and expand to 4 warp roles (TMA, MMA, scheduler, epilogue).