2.1.6. Split-K

This example demonstrates how to implement a matrix multiplication kernel with the split-K optimization in tilus.

In the previous examples, a single thread block computes each tile of the output matrix C. This works well for workloads with large m and n dimensions, since there are enough C tiles to saturate the GPU. For workloads with small m and n but a large k dimension, however, there are too few C tiles to keep the GPU busy. In that case it is more efficient to split the k dimension into multiple segments and assign each segment to a separate thread block. The partial results of the thread blocks that compute the same C tile are then aggregated to form the final result.
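
The per-segment results can simply be added because C[i][j] is a sum over k, and that sum splits into disjoint k segments. The sketch below checks this on small integer matrices using plain Python lists. It is illustrative only and does not use tilus; the function names `matmul` and `split_k_matmul` are hypothetical.

```python
def matmul(a, b):
    """Reference product of an m x k and a k x n matrix (nested lists)."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


def split_k_matmul(a, b, split_k_factor):
    """Compute a @ b by splitting k into split_k_factor segments."""
    m, k, n = len(a), len(b), len(b[0])
    seg = -(-k // split_k_factor)  # ceil(k / split_k_factor)
    partials = []
    for z in range(split_k_factor):  # z plays the role of blockIdx.z
        k0, k1 = z * seg, min((z + 1) * seg, k)
        # partial C tile computed from this k segment only
        partials.append([[sum(a[i][p] * b[p][j] for p in range(k0, k1))
                          for j in range(n)] for i in range(m)])
    # aggregation: element-wise sum of the partial tiles
    return [[sum(part[i][j] for part in partials) for j in range(n)]
            for i in range(m)]


a = [[i * 10 + p for p in range(10)] for i in range(3)]  # 3 x 10
b = [[p - j for j in range(2)] for p in range(10)]       # 10 x 2
print(split_k_matmul(a, b, 4) == matmul(a, b))  # True (segments of 3, 3, 3, 1)
```

With floating-point inputs the two results agree only up to rounding, because the additions happen in a different order.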

There are two main ways to implement the split-K optimization: 1) launch a separate kernel to perform the aggregation, or 2) perform the aggregation inside the same kernel that computes the C tile, using semaphores to order the thread blocks. In this example, we implement the second approach.
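
For contrast, the first approach can be emulated in plain Python: one "launch" writes each segment's partial tile into a workspace buffer, and a second "launch" reduces that buffer along the split-K axis. This is a hypothetical host-side sketch; `partial_kernel` and `reduce_kernel` are illustrative names, not tilus functions.

```python
def partial_kernel(a, b, split_k_factor):
    """First launch: block z writes its partial C into workspace[z]."""
    m, k, n = len(a), len(b), len(b[0])
    seg = -(-k // split_k_factor)  # ceil(k / split_k_factor)
    # extra buffer of shape [split_k_factor, m, n]
    workspace = [[[0] * n for _ in range(m)] for _ in range(split_k_factor)]
    for z in range(split_k_factor):
        k0, k1 = z * seg, min((z + 1) * seg, k)
        for i in range(m):
            for j in range(n):
                workspace[z][i][j] = sum(a[i][p] * b[p][j] for p in range(k0, k1))
    return workspace


def reduce_kernel(workspace):
    """Second launch: sum the workspace along the split-K axis."""
    m, n = len(workspace[0]), len(workspace[0][0])
    return [[sum(part[i][j] for part in workspace) for j in range(n)]
            for i in range(m)]


a = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
b = [[1, 0], [0, 1], [1, 1], [2, 0], [0, 3]]
c = reduce_kernel(partial_kernel(a, b, split_k_factor=2))
print(c)  # [[12, 20], [32, 45]]
```

This variant needs split_k_factor × m × n elements of extra memory and a second kernel launch. The semaphore-based approach avoids both by accumulating directly in C, at the cost of making the blocks of a tile take turns.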

We will use several new tilus instructions:

global_tensor(dtype, shape, *[, layout])

Allocate a global tensor.

lock_semaphore(semaphore, value)

Lock semaphore with a specified value.

release_semaphore(semaphore, value)

Release semaphore with a specified value.

We use global_tensor() to allocate a global tensor that stores one semaphore per C tile. Its shape is [cdiv(m_size, block_m), cdiv(n_size, block_n)], where block_m and block_n are the dimensions of a C tile. All thread blocks that compute the same C tile use the same semaphore to synchronize the aggregation of their results.
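
The arithmetic behind this layout can be sketched in plain Python. The helper `split_k_launch_info` is hypothetical. It mirrors the grid set through `self.attrs.blocks` in the kernel below and assumes one 4-byte int32 counter per C tile.

```python
def cdiv(a, b):
    """Ceiling division, like tilus.utils.cdiv."""
    return (a + b - 1) // b


def split_k_launch_info(m_size, n_size, block_m, block_n, split_k_factor):
    tiles = (cdiv(m_size, block_m), cdiv(n_size, block_n))
    grid = (tiles[0], tiles[1], split_k_factor)  # (x, y, z) thread blocks
    num_blocks = tiles[0] * tiles[1] * split_k_factor
    semaphore_shape = tiles                      # one counter per C tile
    semaphore_bytes = tiles[0] * tiles[1] * 4    # int32 counters
    # Block (x, y, z) uses semaphores[x, y]: z is not part of the index,
    # so all split_k_factor blocks of one C tile share a single counter.
    return grid, num_blocks, semaphore_shape, semaphore_bytes


# Small m and n: only 4 C tiles, but 16 thread blocks with split_k_factor=4.
print(split_k_launch_info(256, 256, 128, 128, 4))  # ((2, 2, 4), 16, (2, 2), 16)
```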

Once the kernel is launched, every thread block computes its own partial accumulation for its C tile. For example, if k_size is 1024 and we split it into 4 segments (i.e., split_k_factor=4), each thread block computes over a k segment of size 256. Call these four blocks 0, 1, 2, and 3.

Block 0 stores its result directly to the C matrix, while each other block waits until the semaphore equals its block index. After block 0 stores its result, it releases the semaphore with the value 1. This allows block 1 to load the partial result from C, add its own result, store the sum back to C, and release the semaphore with the value 2, which in turn allows block 2 to aggregate its result with those of the first two blocks. The process continues until every block has added its result and the final result is stored in C. The last block releases the semaphore with the value 0, restoring it to zero as required for a tensor allocated by global_tensor() with requires_clean=True.
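
The ordering protocol can be simulated on the host with Python threads. The sketch makes two assumptions. First, `lock_semaphore(sem, v)` is modeled as "block until the counter equals v" and `release_semaphore(sem, v)` as "set the counter to v", following the description above. Second, a single number stands in for each C tile. The `Semaphore` class and `run_split_k` are hypothetical names, not tilus APIs.

```python
import random
import threading


class Semaphore:
    """Hypothetical stand-in for one entry of the semaphore tensor."""

    def __init__(self):
        self.value = 0  # starts at 0, like a requires_clean tensor
        self.cond = threading.Condition()

    def lock(self, value):
        # models lock_semaphore: wait until the counter equals `value`
        with self.cond:
            self.cond.wait_for(lambda: self.value == value)

    def release(self, value):
        # models release_semaphore: publish `value` and wake the waiters
        with self.cond:
            self.value = value
            self.cond.notify_all()


def run_split_k(partials):
    """partials[z] is the partial result computed by block z."""
    split_k_factor = len(partials)
    sem = Semaphore()
    c = [None]  # the shared C tile in "global memory"

    def block(z):
        result = partials[z]
        if z > 0:
            sem.lock(z)              # wait for blocks 0..z-1
            result = result + c[0]   # add what they stored
        c[0] = result                # block 0 stores directly
        sem.release((z + 1) % split_k_factor)  # last block resets to 0

    threads = [threading.Thread(target=block, args=(z,))
               for z in range(split_k_factor)]
    random.shuffle(threads)  # launch order does not matter
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return c[0], sem.value


print(run_split_k([1, 2, 3, 4]))  # (10, 0): full sum, semaphore back to 0
```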

The following example implements this logic in the matrix multiplication kernel:

import tilus
from tilus import float16, float32, int32
from tilus.utils import cdiv


@tilus.autotune("num_warps", [4, 8])
@tilus.autotune("block_m, block_n", [(128, 128), (128, 64), (64, 128), (32, 256)])
@tilus.autotune("block_k", [16, 32])
@tilus.autotune("num_stages", [3, 4, 5])
@tilus.autotune("split_k_factor", [1, 4, 12, 16])
class MatmulV5(tilus.Script):
    def __init__(
        self,
        num_warps,
        block_m,
        block_n,
        block_k,
        num_stages,
        split_k_factor,
    ):
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.num_warps = num_warps
        self.num_stages = num_stages
        self.split_k_factor = split_k_factor

    def __call__(
        self,
        m_size: int32,
        n_size: int,
        k_size: int,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [
            cdiv(m_size, self.block_m),
            cdiv(n_size, self.block_n),
            self.split_k_factor,
        ]
        self.attrs.warps = self.num_warps

        # the k_size for each thread block
        block_k_size = (
            cdiv(cdiv(k_size, self.split_k_factor), self.block_k) * self.block_k
        )
        start_offset_k = self.blockIdx.z * block_k_size
        end_offset_k = min(start_offset_k + block_k_size, k_size)

        block_m, block_n, block_k = self.block_m, self.block_n, self.block_k
        offset_m: int32 = block_m * self.blockIdx.x
        offset_n: int32 = block_n * self.blockIdx.y

        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size])
        sa = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_m, block_k])
        sb = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_k, block_n])
        acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0)

        for stage in range(self.num_stages - 1):
            offset_k = start_offset_k + stage * self.block_k
            self.copy_async(src=ga, dst=sa[stage], offsets=[offset_m, offset_k])
            self.copy_async(src=gb, dst=sb[stage], offsets=[offset_k, offset_n])
            self.copy_async_commit_group()

        self.copy_async_wait_group(n=self.num_stages - 2)
        self.sync()

        current_stage: int32 = 0
        preload_stage: int32 = self.num_stages - 1
        for offset_k in self.range(
            start_offset_k, end_offset_k, block_k, unroll=self.num_stages
        ):
            # computation for current tile
            a = self.load_shared(sa[current_stage])
            b = self.load_shared(sb[current_stage])
            self.dot(a, b, acc, out=acc)

            # preload the next tile of A and B into shared memory
            preload_offset_k = offset_k + (self.num_stages - 1) * block_k
            if preload_offset_k < end_offset_k:
                self.copy_async(
                    src=ga,
                    dst=sa[preload_stage],
                    offsets=[offset_m, preload_offset_k],
                )
                self.copy_async(
                    src=gb,
                    dst=sb[preload_stage],
                    offsets=[preload_offset_k, offset_n],
                )
            self.copy_async_commit_group()

            # update the stage
            current_stage = (current_stage + 1) % self.num_stages
            preload_stage = (preload_stage + 1) % self.num_stages
            self.copy_async_wait_group(n=self.num_stages - 2)
            self.sync()

        # free the shared memory tensors for A and B
        self.free_shared(sa)
        self.free_shared(sb)

        # cast the accumulator to float16 and change the register tensor's layout
        sc = self.shared_tensor(dtype=float16, shape=[block_m, block_n])
        casted_acc = self.cast(acc, dtype=float16)
        self.store_shared(sc, casted_acc)
        self.sync()
        rc = self.load_shared(sc)
        self.free_shared(sc)

        m_blocks, n_blocks = cdiv(m_size, block_m), cdiv(n_size, block_n)
        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        if self.split_k_factor == 1:  # no split: one block per C tile, no semaphore needed
            self.store_global(gc, rc, offsets=[offset_m, offset_n])
        else:
            semaphores = self.global_tensor(
                dtype=int32, shape=[m_blocks, n_blocks], requires_clean=True
            )
            semaphore: ~int32 = ~semaphores[self.blockIdx.x, self.blockIdx.y]

            # load and accumulate the partial result in global memory
            if self.blockIdx.z > 0:
                self.lock_semaphore(semaphore, value=self.blockIdx.z)
                partial_rc = self.load_global(
                    gc, offsets=[offset_m, offset_n], shape=[block_m, block_n]
                )
                self.add(rc, partial_rc, out=rc)

            # store the (possibly accumulated) result to global memory
            self.store_global(gc, rc, offsets=[offset_m, offset_n])

            # release the semaphore so that the next block (blockIdx.z + 1) can proceed
            self.sync()  # make sure the store_global above has finished before releasing
            self.release_semaphore(
                semaphore, value=(self.blockIdx.z + 1) % self.split_k_factor
            )
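
Each block's k range follows the block_k_size formula in the kernel: the per-block share of k is rounded up to a multiple of block_k, so the last block may receive a shorter segment. The hypothetical helper below reproduces that arithmetic in plain Python.

```python
def cdiv(a, b):
    """Ceiling division, like tilus.utils.cdiv."""
    return (a + b - 1) // b


def k_segments(k_size, split_k_factor, block_k):
    # same formula as block_k_size in the kernel
    block_k_size = cdiv(cdiv(k_size, split_k_factor), block_k) * block_k
    segments = []
    for z in range(split_k_factor):  # z corresponds to blockIdx.z
        start = z * block_k_size
        end = min(start + block_k_size, k_size)
        segments.append((start, end))
    return block_k_size, segments


print(k_segments(1024, 4, 32))  # (256, [(0, 256), (256, 512), (512, 768), (768, 1024)])
size, segs = k_segments(14336, 12, 32)
print(size, segs[-1])           # 1216 (13376, 14336): the last block covers only 960
```

If split_k_factor is large relative to k_size, some blocks can get start > end. Going by the kernel code, their main loop then runs zero iterations, so they contribute a zero tile but still take their turn in the semaphore chain.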
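
The following script checks the kernel against PyTorch on two workloads and benchmarks both implementations:

import math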

import pandas
import torch
from tilus.utils import benchmark_func


def main():
    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    workloads = [
        [4096, 4096, 4096],
        [4096, 4096, 14336],
    ]

    rows = []
    for m, n, k in workloads:
        matmul = MatmulV5()

        a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        c_actual = torch.empty(m, n, dtype=torch.float16).cuda()
        c_expect = a @ b
        matmul(m, n, k, a, b, c_actual)

        # check correctness
        torch.testing.assert_close(c_expect, c_actual)

        # benchmark
        for name, func in [
            ("torch", lambda: torch.matmul(a, b, out=c_expect)),
            ("tilus", lambda: matmul(m, n, k, a, b, c_actual)),
        ]:
            latency = benchmark_func(func, warmup=5, repeat=20)  # milliseconds
            # 2*m*n*k FLOPs / latency (ms) * 1e-9 gives TFLOP/s
            flops = 2 * m * n * k / latency * 1e-9
            rows.append([m, n, k, name, latency, flops])

    df = pandas.DataFrame(rows, columns=headers)
    print(df)


if __name__ == "__main__":
    main()
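
Running the script first builds the autotuning candidates for each workload. Assuming tilus builds one kernel per combination of the values in the @tilus.autotune decorators (consistent with the 192/192 count in the build log), the size of the search space can be checked with itertools:

```python
import itertools

# Values copied from the @tilus.autotune decorators on MatmulV5.
space = {
    "num_warps": [4, 8],
    "block_m, block_n": [(128, 128), (128, 64), (64, 128), (32, 256)],
    "block_k": [16, 32],
    "num_stages": [3, 4, 5],
    "split_k_factor": [1, 4, 12, 16],
}

configs = list(itertools.product(*space.values()))
print(len(configs))  # 192 = 2 * 4 * 2 * 3 * 4
```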
[Building] matmul_v5-4096-4096-d1: 100%|██████████████████████████| 192/192 [00:03<00:00, 48.72it/s]

[Building] matmul_v5-4096-14336-d1: 100%|█████████████████████████| 192/192 [00:03<00:00, 48.48it/s]
      m     n      k   name  latency (ms)      tflops
0  4096  4096   4096  torch      0.863824  159.105271
1  4096  4096   4096  tilus      0.867232  158.480031
2  4096  4096  14336  torch      2.987632  161.009231
3  4096  4096  14336  tilus      2.857008  168.370666
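
main() computes the last column as 2 * m * n * k / latency * 1e-9 with the latency in milliseconds, so the values are in TFLOP/s. The sketch below checks that unit conversion and the relative speed on the k = 14336 workload, using numbers copied from the table. `throughput` is a hypothetical helper.

```python
def throughput(m, n, k, latency_ms):
    flop = 2 * m * n * k                        # one multiply + one add per term
    as_printed = flop / latency_ms * 1e-9       # formula used in main()
    tflop_per_s = flop / (latency_ms * 1e-3) / 1e12
    return as_printed, tflop_per_s


printed, tflops = throughput(4096, 4096, 4096, 0.863824)
print(round(printed, 3), round(tflops, 3))  # both about 159.105

speedup = 2.987632 / 2.857008  # torch latency / tilus latency, k = 14336
print(round(speedup, 3))       # about 1.046, i.e. tilus is ~4.6% faster here
```

On the 4096 × 4096 × 4096 workload the two are within about 0.4% of each other, which is expected: with 128 × 128 tiles there are already 1024 C tiles, so splitting k has little room to help.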
