5. Split-K

In the previous examples, each output tile of C is computed by a single thread block that iterates over the entire K dimension. This works well when M and N are large enough that the output tiles alone saturate the GPU. For workloads with small M and N but large K, however, there are too few output tiles to keep all SMs busy.

Split-K addresses this by partitioning the K dimension into split_k_factor segments and assigning each segment to a separate thread block. The blocks that share an output tile then aggregate their partial results in place in C, with semaphore-based synchronization enforcing the order.
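To make the occupancy argument concrete, the sketch below counts thread blocks with and without Split-K for a small-M, small-N problem. The SM count of 108 and the helper name num_thread_blocks are assumptions chosen for illustration; only cdiv and the grid shape mirror the kernel shown later.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, mirroring the cdiv helper used by the kernel."""
    return (a + b - 1) // b


def num_thread_blocks(m, n, block_m, block_n, split_k_factor=1):
    """Grid size: one block per output tile, times the number of K-segments."""
    return cdiv(m, block_m) * cdiv(n, block_n) * split_k_factor


NUM_SMS = 108  # assumption: illustrative SM count, not taken from the text

m, n = 256, 256  # small M and N; K can be arbitrarily large
without_split = num_thread_blocks(m, n, 128, 128)
with_split = num_thread_blocks(m, n, 128, 128, split_k_factor=16)

print(f"blocks without split-k: {without_split} (SMs: {NUM_SMS})")
print(f"blocks with split_k_factor=16: {with_split}")
```

With 128x128 tiles, the 256x256 output has only 4 tiles, so most SMs would sit idle; splitting K into 16 segments raises the grid to 64 blocks.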

New Instructions

This example introduces three new tilus instructions:

  • global_tensor() – allocates a global tensor managed by tilus. Here it stores one semaphore per output tile. The requires_clean=True flag guarantees the tensor is zero-initialized before each kernel launch.

  • lock_semaphore() – spins until the semaphore reaches the expected value, then proceeds. This ensures blocks aggregate in the correct order.

  • release_semaphore() – sets the semaphore to a new value, unblocking the next waiting block.
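The blocking behaviour of the two semaphore instructions can be modelled on the CPU. The sketch below is not how tilus implements them; it is a minimal stand-in built on Python's threading.Condition, where SemaphoreModel and run_blocks are hypothetical names, and each thread plays the role of one thread block.

```python
import threading


class SemaphoreModel:
    """CPU-side stand-in for one int32 semaphore in global memory."""

    def __init__(self):
        self.value = 0  # starts at zero, like a tensor with requires_clean=True
        self._cond = threading.Condition()

    def lock(self, value):
        # Wait until the semaphore holds the expected value, then proceed.
        with self._cond:
            self._cond.wait_for(lambda: self.value == value)

    def release(self, value):
        # Publish a new value and wake all waiters so they can re-check.
        with self._cond:
            self.value = value
            self._cond.notify_all()


def run_blocks(split_k_factor):
    """Run one thread per K-segment and record the order they aggregate in."""
    sem = SemaphoreModel()
    order = []

    def block(z):
        if z > 0:
            sem.lock(z)           # blocks 1+ wait for their turn
        order.append(z)           # stands in for load/add/store on C
        sem.release((z + 1) % split_k_factor)

    # Start the threads in reverse to show that launch order does not matter.
    threads = [
        threading.Thread(target=block, args=(z,))
        for z in reversed(range(split_k_factor))
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return order, sem.value


order, final_value = run_blocks(4)
print(order, final_value)
```

Even though the threads start in reverse, the semaphore forces them to aggregate in ascending block order, and the final value returns to zero.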

Aggregation Protocol

Suppose split_k_factor=4, producing blocks 0, 1, 2, 3 for the same output tile:

  1. Block 0 stores its partial result directly to C (no lock needed). It then releases the semaphore with value 1.

  2. Block 1 spins on lock_semaphore() until the semaphore equals 1. It loads the partial C, adds its own contribution, stores the sum back, and releases with value 2.

  3. Blocks 2 and 3 follow the same pattern, waiting for the semaphore to equal 2 and 3 respectively.

  4. The last block releases the semaphore with value 0 instead of 4. This resets the semaphore, so the tensor is all zeros again for the next launch, as requires_clean=True expects.
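The protocol above can be traced on a scalar stand-in for one output tile. The sketch below assumes integer inputs so the result is exact, and it runs the blocks sequentially in the order the semaphore would enforce; split_k_dot and its local names are hypothetical.

```python
def split_k_dot(a, b, split_k_factor):
    """Dot product of two equal-length lists, aggregated the Split-K way."""
    k = len(a)
    seg = -(-k // split_k_factor)  # ceil(k / split_k_factor)
    c = None        # stands in for the output tile in global memory
    semaphore = 0   # clean at launch
    trace = []      # semaphore value after each block releases

    for z in range(split_k_factor):  # order enforced by the semaphore
        lo, hi = z * seg, min((z + 1) * seg, k)
        partial = sum(x * y for x, y in zip(a[lo:hi], b[lo:hi]))
        if z > 0:
            assert semaphore == z    # the value lock_semaphore waits for
            partial += c             # load the partial C and accumulate
        c = partial                  # store back to C
        semaphore = (z + 1) % split_k_factor  # release with the next value
        trace.append(semaphore)
    return c, trace


a = list(range(8))
b = [1] * 8
result, trace = split_k_dot(a, b, split_k_factor=4)
print(result, trace)
```

The aggregated result matches the plain dot product, and the recorded semaphore values follow the sequence 1, 2, 3, 0 described in the steps.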

Kernel Implementation

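# Imports assumed from the layout of the tilus examples;
# see matmul_v5.py for the authoritative list.
import math

import pandas
import tilus
import torch
from tilus import float16, float32, int32
from tilus.utils import benchmark_func, cdiv


@tilus.autotune("num_warps", [4, 8])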
@tilus.autotune("block_m, block_n", [(128, 128), (128, 64), (64, 128), (32, 256)])
@tilus.autotune("block_k", [16, 32])
@tilus.autotune("num_stages", [3, 4, 5])
@tilus.autotune("split_k_factor", [1, 4, 12, 16])
class MatmulV5(tilus.Script):
    def __init__(
        self,
        num_warps,
        block_m,
        block_n,
        block_k,
        num_stages,
        split_k_factor,
    ):
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.num_warps = num_warps
        self.num_stages = num_stages
        self.split_k_factor = split_k_factor

    def __call__(
        self,
        m_size: int32,
        n_size: int,
        k_size: int,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [
            cdiv(m_size, self.block_m),
            cdiv(n_size, self.block_n),
            self.split_k_factor,
        ]
        self.attrs.warps = self.num_warps

        # the k_size for each thread block
        block_k_size = (
            cdiv(cdiv(k_size, self.split_k_factor), self.block_k) * self.block_k
        )
        start_offset_k = self.blockIdx.z * block_k_size
        end_offset_k = min(start_offset_k + block_k_size, k_size)

        block_m, block_n, block_k = self.block_m, self.block_n, self.block_k
        offset_m: int32 = block_m * self.blockIdx.x
        offset_n: int32 = block_n * self.blockIdx.y

        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size])
        sa = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_m, block_k])
        sb = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_k, block_n])
        acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0)

        for stage in range(self.num_stages - 1):
            offset_k = start_offset_k + stage * self.block_k
            self.copy_async(src=ga, dst=sa[stage], offsets=[offset_m, offset_k])
            self.copy_async(src=gb, dst=sb[stage], offsets=[offset_k, offset_n])
            self.copy_async_commit_group()

        self.copy_async_wait_group(n=self.num_stages - 2)
        self.sync()

        current_stage: int32 = 0
        preload_stage: int32 = self.num_stages - 1
        for offset_k in self.range(
            start_offset_k, end_offset_k, block_k, unroll=self.num_stages
        ):
            # computation for current tile
            a = self.load_shared(sa[current_stage])
            b = self.load_shared(sb[current_stage])
            self.dot(a, b, acc, out=acc)

            # preload the next tile of A and B into shared memory
            preload_offset_k = offset_k + (self.num_stages - 1) * block_k
            if preload_offset_k < end_offset_k:
                self.copy_async(
                    src=ga,
                    dst=sa[preload_stage],
                    offsets=[offset_m, preload_offset_k],
                )
                self.copy_async(
                    src=gb,
                    dst=sb[preload_stage],
                    offsets=[preload_offset_k, offset_n],
                )
            self.copy_async_commit_group()

            # update the stage
            current_stage = (current_stage + 1) % self.num_stages
            preload_stage = (preload_stage + 1) % self.num_stages
            self.copy_async_wait_group(n=self.num_stages - 2)
            self.sync()

        # free the shared memory tensors for A and B
        self.free_shared(sa)
        self.free_shared(sb)

        # cast the accumulator to float16 and change the register tensor's layout
        sc = self.shared_tensor(dtype=float16, shape=[block_m, block_n])
        casted_acc = self.cast(acc, dtype=float16)
        self.store_shared(sc, casted_acc)
        self.sync()
        rc = self.load_shared(sc)
        self.free_shared(sc)

        m_blocks, n_blocks = cdiv(m_size, block_m), cdiv(n_size, block_n)
        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        # with a single K-segment there is nothing to aggregate
        if self.split_k_factor == 1:
            self.store_global(gc, rc, offsets=[offset_m, offset_n])
        else:
            semaphores = self.global_tensor(
                dtype=int32, shape=[m_blocks, n_blocks], requires_clean=True
            )
            semaphore: ~int32 = semaphores[self.blockIdx.x, self.blockIdx.y].item_ptr()

            # load and accumulate the partial result in global memory
            if self.blockIdx.z > 0:
                self.lock_semaphore(semaphore, value=self.blockIdx.z)
                partial_rc = self.load_global(
                    gc, offsets=[offset_m, offset_n], shape=[block_m, block_n]
                )
                self.add(rc, partial_rc, out=rc)

            # store the (accumulated) result to global memory
            self.store_global(gc, rc, offsets=[offset_m, offset_n])

            # release the semaphore
            self.sync()  # we need to make sure the previous store_global is finished
            self.release_semaphore(
                semaphore, value=(self.blockIdx.z + 1) % self.split_k_factor
            )

Key points in the code above:

  • Three-dimensional grid (self.attrs.blocks) – the third grid dimension is split_k_factor, and self.blockIdx.z identifies which K-segment a block processes.

  • Per-block K range (block_k_size, start_offset_k, end_offset_k) – each block computes over [start_offset_k, end_offset_k), a contiguous slice of the K dimension whose length is rounded up to a multiple of block_k and clipped to k_size.

  • Pipelined main loop (the prologue over stages and the self.range loop) – identical in structure to V4, with the loop bounds narrowed to the block’s K-segment.

  • Layout change via shared memory (the sc and rc statements) – after accumulation, the result is cast to float16, written to shared memory, and reloaded. This changes the register tensor layout to one suitable for the global store and subsequent aggregation.

  • Semaphore-guarded aggregation (the final else branch):

    • When split_k_factor > 1, a global_tensor() allocates one int32 semaphore per output tile.

    • Block 0 stores directly; blocks 1+ call lock_semaphore(), load the partial result, accumulate, and store.

    • Each block calls release_semaphore() with the next expected value (wrapping to 0 for the last block).
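The per-block K range can be checked outside the kernel. The sketch below repeats the arithmetic from the kernel in plain Python; k_segments is a hypothetical helper, and the inputs are taken from the autotuning space and workloads in this example.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, as used by the kernel."""
    return (a + b - 1) // b


def k_segments(k_size, block_k, split_k_factor):
    """Return the [start, end) K range handled by each blockIdx.z."""
    # Segment length: ceil(k_size / split_k_factor), rounded up to block_k.
    block_k_size = cdiv(cdiv(k_size, split_k_factor), block_k) * block_k
    segments = []
    for z in range(split_k_factor):
        start = z * block_k_size
        end = min(start + block_k_size, k_size)  # clip the last segment
        segments.append((start, end))
    return segments


segs = k_segments(k_size=4096, block_k=32, split_k_factor=12)
for z, (start, end) in enumerate(segs):
    print(f"blockIdx.z={z}: K in [{start}, {end})")
```

For k_size=4096, block_k=32 and split_k_factor=12, every segment is 352 elements long except the last, which is clipped to [3872, 4096). Because of the rounding, a trailing segment can come out empty when K is small relative to split_k_factor * block_k.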

Launch the Kernel

def main():
    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    workloads = [
        [4096, 4096, 4096],
        [4096, 4096, 14336],
    ]

    rows = []
    for m, n, k in workloads:
        matmul = MatmulV5()

        a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        c_actual = torch.empty(m, n, dtype=torch.float16).cuda()
        c_expect = a @ b
        matmul(m, n, k, a, b, c_actual)

        # check correctness
        torch.testing.assert_close(c_actual, c_expect)

        # benchmark
        for name, func in [
            ("torch", lambda: torch.matmul(a, b, out=c_expect)),
            ("tilus", lambda: matmul(m, n, k, a, b, c_actual)),
        ]:
            latency = benchmark_func(func, warmup=5, repeat=20)
            tflops = 2 * m * n * k / latency * 1e-9
            rows.append([m, n, k, name, latency, tflops])

    df = pandas.DataFrame(rows, columns=headers)
    print(df)
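The tflops expression in main() folds a unit conversion into the 1e-9 factor. The sketch below spells it out, assuming the latency is reported in milliseconds as the "latency (ms)" column header suggests; matmul_tflops is a hypothetical helper name.

```python
def matmul_tflops(m, n, k, latency_ms):
    """Throughput of an m x n x k matmul in TFLOPS, given latency in ms."""
    flops = 2 * m * n * k  # one multiply and one add per (m, n, k) triple
    # FLOP/ms -> FLOP/s is a factor of 1e3; FLOP/s -> TFLOP/s is 1e-12.
    return flops / latency_ms * 1e-9


# A 4096^3 matmul finishing in 1.0 ms (an illustrative latency, not a measurement)
tflops = matmul_tflops(4096, 4096, 4096, latency_ms=1.0)
print(f"{tflops:.2f} TFLOPS")
```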

Note

The full source code for this example can be found at matmul_v5.py.