4. Software Pipelining

This example demonstrates how to overlap computation and memory operations using software pipelining.

The Problem

Without pipelining, a typical matmul main loop looks like this:

for i in range(N):
    async_load(i)
    sync

    compute(i)
    sync

Data loading and computation execute sequentially within the thread block. Although the GPU can schedule other thread blocks on the same SM to hide latency, matrix multiplication kernels typically consume many registers and much shared memory, limiting occupancy. Software pipelining addresses this by overlapping the two phases within a single thread block.

The Pipelined Approach

The core idea is to start loading the next tile while computing the current one:

async_load(0)
for i in range(N):
    if i < N - 1:
        async_load(i + 1)
    compute(i)
    sync

This is generalized to multiple stages: by allocating num_stages copies of the shared memory buffers and using copy_async_commit_group() / copy_async_wait_group() for fine-grained pipeline control, the kernel keeps several loads in flight simultaneously.

For further reading, see ALCOP and Hidet.

Kernel Implementation

@tilus.autotune("num_warps", [4, 8])
@tilus.autotune("block_m, block_n", [(128, 128), (128, 64), (64, 128)])
@tilus.autotune("block_k", [16, 32])
@tilus.autotune("num_stages", [3, 4, 5])
class MatmulV4(tilus.Script):
    """Software-pipelined matmul kernel computing C = A @ B in float16.

    The shared-memory tiles of A and B are replicated ``num_stages`` times so
    that up to ``num_stages - 1`` asynchronous copy groups stay in flight while
    the oldest, already-arrived stage is consumed by the dot product.  The
    autotune decorators above enumerate the tile shapes, warp counts, and
    pipeline depths searched at tuning time.
    """

    def __init__(self, num_warps, block_m, block_n, block_k, num_stages):
        """Store one autotune configuration.

        Args:
            num_warps: Warps per thread block.
            block_m: Rows of the C tile computed by one thread block.
            block_n: Columns of the C tile computed by one thread block.
            block_k: K-dimension step processed per main-loop iteration.
            num_stages: Pipeline depth (number of shared-memory buffer copies).
        """
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.num_warps = num_warps
        self.num_stages = num_stages

    def __call__(
        self,
        m_size: int32,
        # NOTE(review): n_size/k_size are annotated plain `int` while m_size is
        # `int32` — confirm both spellings map to the same kernel parameter type.
        n_size: int,
        k_size: int,
        a_ptr: ~float16,  # A: [m_size, k_size]
        b_ptr: ~float16,  # B: [k_size, n_size]
        c_ptr: ~float16,  # C: [m_size, n_size] (output)
    ):
        # One thread block per (block_m x block_n) tile of C.
        self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)]
        self.attrs.warps = self.num_warps

        block_m, block_n, block_k = self.block_m, self.block_n, self.block_k
        # Top-left corner of this block's C tile.
        offset_m: int32 = block_m * self.blockIdx.x
        offset_n: int32 = block_n * self.blockIdx.y

        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size])
        # Extra leading dimension: one shared-memory buffer per pipeline stage.
        sa = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_m, block_k])
        sb = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_k, block_n])
        # float32 accumulator for the float16 x float16 products.
        acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0)

        # Prologue: preload the first num_stages - 1 K-tiles, one commit group
        # per stage (each group pairs the A copy with the matching B copy).
        for stage in range(self.num_stages - 1):
            offset_k = stage * self.block_k
            self.copy_async(src=ga, dst=sa[stage], offsets=[offset_m, offset_k])
            self.copy_async(src=gb, dst=sb[stage], offsets=[offset_k, offset_n])
            self.copy_async_commit_group()

        # Allow at most num_stages - 2 groups to remain in flight, i.e. block
        # until the oldest group (stage 0) has landed, then sync so all threads
        # see the arrived shared-memory data.
        self.copy_async_wait_group(n=self.num_stages - 2)
        self.sync()

        current_stage: int32 = 0  # stage being consumed by the dot
        preload_stage: int32 = self.num_stages - 1  # stage being filled
        for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages):
            # computation for current tile
            a = self.load_shared(sa[current_stage])
            b = self.load_shared(sb[current_stage])
            self.dot(a, b, acc, out=acc)

            # preload the next tile of A and B into shared memory
            # (num_stages - 1 iterations ahead of the tile being computed).
            # NOTE(review): near the end of the loop this offset runs past
            # k_size — presumably copy_async masks out-of-bounds elements;
            # confirm against the tilus copy_async semantics.
            preload_offset_k = offset_k + (self.num_stages - 1) * block_k
            self.copy_async(
                src=ga,
                dst=sa[preload_stage],
                offsets=[offset_m, preload_offset_k],
            )
            self.copy_async(
                src=gb,
                dst=sb[preload_stage],
                offsets=[preload_offset_k, offset_n],
            )
            self.copy_async_commit_group()

            # update the stage
            current_stage = (current_stage + 1) % self.num_stages
            preload_stage = (preload_stage + 1) % self.num_stages
            # Drain down to num_stages - 2 in-flight groups so that the group
            # holding the next current_stage is complete, then synchronize.
            self.copy_async_wait_group(n=self.num_stages - 2)
            self.sync()

        self.free_shared(sa)
        self.free_shared(sb)

        # Accumulated in float32; cast back to float16 for the output tile.
        casted_acc = self.cast(acc, dtype=float16)
        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(gc, casted_acc, offsets=[offset_m, offset_n])

Key points in the code above:

  • ``num_stages`` as a tuned parameter – the ``@tilus.autotune("num_stages", [3, 4, 5])`` decorator explores 3, 4, and 5 pipeline stages.

  • Multi-stage shared memory – the shared tensors ``sa`` and ``sb`` have an extra leading dimension of size num_stages, so each pipeline stage has its own buffer.

  • Prologue – before the main loop begins, the first num_stages - 1 tiles are preloaded, using copy_async_commit_group() to group each pair of copies.

  • Main loop – each iteration computes on sa[current_stage] / sb[current_stage] while simultaneously issuing an async copy for the tile num_stages - 1 iterations ahead. copy_async_wait_group() with n=num_stages - 2 ensures only the oldest in-flight group must complete before its data is consumed.

  • Stage rotation – at the end of each iteration, current_stage and preload_stage advance modulo num_stages.

Launch the Kernel

def main():
    """Run the pipelined matmul on a few workloads, verify it against
    torch.matmul, and print a latency/TFLOPS comparison table."""
    columns = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    workloads = [
        [4096, 4096, 4096],
        [1024, 1024, 14336],
    ]

    records = []
    for m, n, k in workloads:
        kernel = MatmulV4()

        # Scale the random inputs so the accumulated sum stays well-behaved.
        scale = math.sqrt(k)
        a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / scale
        b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / scale
        c_actual = torch.empty(m, n, dtype=torch.float16).cuda()
        c_expect = a @ b
        kernel(m, n, k, a, b, c_actual)

        # check correctness
        torch.testing.assert_close(c_expect, c_actual)

        # benchmark both implementations on the same tensors
        candidates = [
            ("torch", lambda: torch.matmul(a, b, out=c_expect)),
            ("tilus", lambda: kernel(m, n, k, a, b, c_actual)),
        ]
        for name, func in candidates:
            latency = benchmark_func(func, warmup=5, repeat=20)
            # 2*m*n*k FLOPs; latency is in ms, hence the 1e-9 factor for TFLOPS.
            records.append([m, n, k, name, latency, 2 * m * n * k / latency * 1e-9])

    print(pandas.DataFrame(records, columns=columns))

Note

The full source code for this example can be found at matmul_v4.py.