5. Split-K¶
In previous examples each output tile of C is computed by a single thread block that iterates over the entire K dimension. This works well when M and N are large enough to saturate the GPU. However, for workloads with small M and N but large K, there are not enough output tiles to keep all SMs busy.
Split-K addresses this by partitioning the K dimension into
split_k_factor segments, assigning each segment to a separate thread block.
The partial results are then aggregated in-place using semaphore-based
synchronization.
New Instructions¶
This example introduces three new tilus instructions:
global_tensor() – allocates a global tensor managed by tilus. Here it stores one semaphore per output tile. The requires_clean=True flag guarantees the tensor is zero-initialized before each kernel launch. lock_semaphore() – spins until the semaphore reaches the expected value, then proceeds. This ensures blocks aggregate in the correct order. release_semaphore() – sets the semaphore to a new value, unblocking the next waiting block.
Aggregation Protocol¶
Suppose split_k_factor=4, producing blocks 0, 1, 2, 3 for the same output
tile:
Block 0 stores its partial result directly to C (no lock needed). It then releases the semaphore with value 1.
Block 1 spins on lock_semaphore() until the semaphore equals 1. It loads the partial C, adds its own contribution, stores the sum back, and releases with value 2. Block 2 and Block 3 follow the same pattern.
The last block releases the semaphore with value 0, restoring the zero state that requires_clean=True guarantees at the start of the next launch.
Kernel Implementation¶
@tilus.autotune("num_warps", [4, 8])
@tilus.autotune("block_m, block_n", [(128, 128), (128, 64), (64, 128), (32, 256)])
@tilus.autotune("block_k", [16, 32])
@tilus.autotune("num_stages", [3, 4, 5])
@tilus.autotune("split_k_factor", [1, 4, 12, 16])
class MatmulV5(tilus.Script):
    """Split-K matmul kernel computing C = A @ B (float16 in, float16 out).

    The K dimension is partitioned into ``split_k_factor`` contiguous segments,
    each handled by a separate thread block. Partial results are aggregated
    in-place in C, ordered by one int32 semaphore per output tile.
    """

    def __init__(
        self,
        num_warps,
        block_m,
        block_n,
        block_k,
        num_stages,
        split_k_factor,
    ):
        super().__init__()
        # Tile sizes, software-pipeline depth, and K-split degree are all
        # supplied by the autotuner (decorators above).
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.num_warps = num_warps
        self.num_stages = num_stages
        self.split_k_factor = split_k_factor

    def __call__(
        self,
        m_size: int32,
        n_size: int32,  # fixed: was plain `int`, inconsistent with m_size
        k_size: int32,  # fixed: was plain `int`, inconsistent with m_size
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        # 3D grid: (M tiles, N tiles, K segments). blockIdx.z selects which
        # K segment this block accumulates.
        self.attrs.blocks = [
            cdiv(m_size, self.block_m),
            cdiv(n_size, self.block_n),
            self.split_k_factor,
        ]
        self.attrs.warps = self.num_warps
        # the k_size for each thread block, rounded up to a multiple of block_k
        block_k_size = (
            cdiv(cdiv(k_size, self.split_k_factor), self.block_k) * self.block_k
        )
        # This block covers [start_offset_k, end_offset_k) of the K dimension.
        start_offset_k = self.blockIdx.z * block_k_size
        end_offset_k = min(start_offset_k + block_k_size, k_size)
        block_m, block_n, block_k = self.block_m, self.block_n, self.block_k
        offset_m: int32 = block_m * self.blockIdx.x
        offset_n: int32 = block_n * self.blockIdx.y
        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size])
        # num_stages-deep shared-memory buffers for the software pipeline.
        sa = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_m, block_k])
        sb = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_k, block_n])
        acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0)
        # Prologue: asynchronously preload the first (num_stages - 1) tiles.
        for stage in range(self.num_stages - 1):
            offset_k = start_offset_k + stage * self.block_k
            self.copy_async(src=ga, dst=sa[stage], offsets=[offset_m, offset_k])
            self.copy_async(src=gb, dst=sb[stage], offsets=[offset_k, offset_n])
            self.copy_async_commit_group()
        self.copy_async_wait_group(n=self.num_stages - 2)
        self.sync()
        current_stage: int32 = 0
        preload_stage: int32 = self.num_stages - 1
        # Pipelined main loop over this block's K segment only.
        for offset_k in self.range(
            start_offset_k, end_offset_k, block_k, unroll=self.num_stages
        ):
            # computation for current tile
            a = self.load_shared(sa[current_stage])
            b = self.load_shared(sb[current_stage])
            self.dot(a, b, acc, out=acc)
            # preload the next tile of A and B into shared memory
            preload_offset_k = offset_k + (self.num_stages - 1) * block_k
            if preload_offset_k < end_offset_k:
                self.copy_async(
                    src=ga,
                    dst=sa[preload_stage],
                    offsets=[offset_m, preload_offset_k],
                )
                self.copy_async(
                    src=gb,
                    dst=sb[preload_stage],
                    offsets=[preload_offset_k, offset_n],
                )
                self.copy_async_commit_group()
            # update the stage
            current_stage = (current_stage + 1) % self.num_stages
            preload_stage = (preload_stage + 1) % self.num_stages
            self.copy_async_wait_group(n=self.num_stages - 2)
            self.sync()
        # free the shared memory tensors for A and B
        self.free_shared(sa)
        self.free_shared(sb)
        # cast the accumulator to float16 and change the register tensor's layout
        # by bouncing it through shared memory (needed for the global store).
        sc = self.shared_tensor(dtype=float16, shape=[block_m, block_n])
        casted_acc = self.cast(acc, dtype=float16)
        self.store_shared(sc, casted_acc)
        self.sync()
        rc = self.load_shared(sc)
        self.free_shared(sc)
        m_blocks, n_blocks = cdiv(m_size, block_m), cdiv(n_size, block_n)
        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        if self.split_k_factor == 1:
            # Fixed: was `== 0`, which can never be true (autotuned values are
            # 1/4/12/16). With a single K segment there is nothing to
            # aggregate, so store directly and skip the semaphore machinery.
            self.store_global(gc, rc, offsets=[offset_m, offset_n])
        else:
            # One semaphore per output tile; requires_clean=True guarantees
            # they start at 0 before the launch.
            semaphores = self.global_tensor(
                dtype=int32, shape=[m_blocks, n_blocks], requires_clean=True
            )
            semaphore: ~int32 = semaphores[self.blockIdx.x, self.blockIdx.y].item_ptr()
            # load and accumulate the partial result in global memory
            if self.blockIdx.z > 0:
                # Wait until all preceding K segments have been accumulated.
                self.lock_semaphore(semaphore, value=self.blockIdx.z)
                partial_rc = self.load_global(
                    gc, offsets=[offset_m, offset_n], shape=[block_m, block_n]
                )
                self.add(rc, partial_rc, out=rc)
            # store the result to global memory and release the semaphore
            self.store_global(gc, rc, offsets=[offset_m, offset_n])
            # release the semaphore
            self.sync()  # we need to make sure the previous store_global is finished
            # The last segment wraps back to 0, leaving the semaphore clean.
            self.release_semaphore(
                semaphore, value=(self.blockIdx.z + 1) % self.split_k_factor
            )
Key points in the code above:
Three-dimensional grid (lines 85–89) – the third grid dimension is split_k_factor, and self.blockIdx.z identifies which K-segment a block processes. Per-block K range (lines 93–97) – each block computes over [start_offset_k, end_offset_k), a contiguous slice of the K dimension rounded to multiples of block_k. Pipelined main loop (lines 109–147) – identical in structure to V4, with the loop bounds narrowed to the block’s K-segment.
Layout change via shared memory (lines 154–159) – after accumulation, the result is cast to float16, written to shared memory, and reloaded. This changes the register tensor layout to one suitable for the global store and subsequent aggregation.
Semaphore-guarded aggregation (lines 166–186):
When split_k_factor > 1, a global_tensor() allocates one int32 semaphore per output tile. Block 0 stores directly; blocks 1+ call lock_semaphore(), load the partial result, accumulate, and store. Each block calls release_semaphore() with the next expected value (wrapping to 0 for the last block).
Launch the Kernel¶
def main():
    """Check MatmulV5 against torch.matmul, then benchmark both and print a table."""
    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    workloads = [
        [4096, 4096, 4096],
        [4096, 4096, 14336],
    ]
    records = []
    for m, n, k in workloads:
        matmul = MatmulV5()
        # Scale inputs by 1/sqrt(k) so the fp16 accumulation stays well-conditioned.
        inv_scale = math.sqrt(k)
        a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / inv_scale
        b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / inv_scale
        c_actual = torch.empty(m, n, dtype=torch.float16).cuda()
        c_expect = a @ b
        # Run once and verify correctness before timing anything.
        matmul(m, n, k, a, b, c_actual)
        torch.testing.assert_close(c_expect, c_actual)
        # Time both implementations on the same tensors.
        candidates = [
            ("torch", lambda: torch.matmul(a, b, out=c_expect)),
            ("tilus", lambda: matmul(m, n, k, a, b, c_actual)),
        ]
        for name, func in candidates:
            latency = benchmark_func(func, warmup=5, repeat=20)
            # 2*m*n*k flops per matmul; latency is in milliseconds.
            tflops = 2 * m * n * k / latency * 1e-9
            records.append([m, n, k, name, latency, tflops])
    print(pandas.DataFrame(records, columns=headers))
Note
The full source code for this example can be found at
matmul_v5.py.