2. Auto-tuning

In previous versions of the matmul kernel, we set hyperparameters such as block_m, block_n, and block_k by hand. These values can significantly affect kernel performance, and finding good settings manually is a tedious and time-consuming process.

Tilus provides the tilus.autotune() decorator to annotate the search space of the hyperparameters and let tilus automatically search for the best configuration.

The decorator accepts parameter names and a list of values. When multiple @tilus.autotune decorators are stacked, tilus forms the Cartesian product of all value lists and tries every combination. At the first invocation the kernel is compiled for each configuration, benchmarked on the actual arguments, and the fastest configuration is selected automatically. Subsequent calls reuse the winner.

@tilus.autotune("arg_name1", [v11, v12, v13])
@tilus.autotune("arg_name2, arg_name3", [(v21, v31), (v22, v32)])
class AwesomeKernel(tilus.Script):
    def __init__(self, user_arg, arg_name1, arg_name2, arg_name3):
        super().__init__()
        ...

When instantiating the class, only the non-tuned arguments are provided – the tuned parameters are filled in automatically by the autotuning engine.
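The Cartesian-product behaviour of stacked decorators can be sketched in plain Python. This is an illustration only, not tilus internals; the strings stand in for the placeholder values v11, v21, etc. from the snippet above:

```python
from itertools import product

# Value lists mirroring the two stacked decorators above.
space1 = [{"arg_name1": v} for v in ("v11", "v12", "v13")]
space2 = [{"arg_name2": a, "arg_name3": b}
          for a, b in (("v21", "v31"), ("v22", "v32"))]

# The engine tries every combination: 3 x 2 = 6 candidate configurations.
configs = [{**c1, **c2} for c1, c2 in product(space1, space2)]
print(len(configs))  # 6
```

Each resulting dictionary is one complete assignment of the tuned parameters, which the engine passes to __init__ alongside the user-provided arguments.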

Imports

import math

import pandas
import torch

import tilus
from tilus import float16, float32, int32
from tilus.utils import cdiv

Annotate the Search Space

Reusing the same kernel implementation as in V1, we add tilus.autotune() decorators for num_warps, block_m/block_n, and block_k:

@tilus.autotune("num_warps", [4, 8])
@tilus.autotune("block_m, block_n", [(128, 128), (128, 64), (64, 128)])
@tilus.autotune("block_k", [16, 32])
class MatmulV2(tilus.Script):
    def __init__(
        self,
        num_warps,
        block_m,
        block_n,
        block_k,
    ):
        super().__init__()
        self.num_warps = num_warps
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k

    def __call__(
        self,
        m_size: int32,
        n_size: int32,
        k_size: int32,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [
            cdiv(m_size, self.block_m),
            cdiv(n_size, self.block_n),
        ]
        self.attrs.warps = self.num_warps

        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size])
        sa = self.shared_tensor(dtype=float16, shape=[self.block_m, self.block_k])
        sb = self.shared_tensor(dtype=float16, shape=[self.block_k, self.block_n])
        acc = self.register_tensor(
            dtype=float32, shape=[self.block_m, self.block_n], init=0.0
        )

        for offset_k in range(0, k_size, self.block_k):
            lda = self.load_global(
                ga,
                offsets=[offset_m, offset_k],
                shape=[self.block_m, self.block_k],
            )
            self.store_shared(sa, lda)
            ldb = self.load_global(
                gb,
                offsets=[offset_k, offset_n],
                shape=[self.block_k, self.block_n],
            )
            self.store_shared(sb, ldb)
            self.sync()

            a = self.load_shared(sa)
            b = self.load_shared(sb)
            acc = self.dot(a, b, acc)
            self.sync()

        self.free_shared(sa)
        self.free_shared(sb)

        casted_acc = self.cast(acc, dtype=float16)
        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(gc, casted_acc, offsets=[offset_m, offset_n])

The three decorators create a space of \(2 \times 3 \times 2 = 12\) configurations. Tilus compiles all twelve, benchmarks them, and keeps the fastest.
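As a quick sanity check on the launch-grid arithmetic in __call__, the computation can be reproduced in plain Python. The 4096-sized problem and the 128x128 tile are assumptions chosen for illustration:

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching the behaviour of tilus.utils.cdiv.
    return (a + b - 1) // b

m_size, n_size = 4096, 4096    # assumed problem size
block_m, block_n = 128, 128    # one candidate configuration

# self.attrs.blocks: one thread block per output tile of C.
blocks = [cdiv(m_size, block_m), cdiv(n_size, block_n)]
print(blocks)  # [32, 32]

# Block (x, y) computes the tile starting at these offsets:
offset_m = block_m * 5   # blockIdx.x == 5 -> row 640
offset_n = block_n * 7   # blockIdx.y == 7 -> column 896
print(offset_m, offset_n)  # 640 896
```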

Launch the Kernel

Notice that MatmulV2() is instantiated with no arguments – the tuned parameters are determined automatically.

def main():
    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    workloads = [
        [4096, 4096, 4096],
    ]

    rows = []
    for m, n, k in workloads:
        matmul = MatmulV2()

        a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        c_actual = torch.empty(m, n, dtype=torch.float16).cuda()
        c_expect = a @ b
        matmul(m, n, k, a, b, c_actual)

        torch.cuda.synchronize()

        # check correctness
        torch.testing.assert_close(c_expect, c_actual)

        # benchmark
        for name, func in [
            ("torch", lambda: torch.matmul(a, b, out=c_expect)),
            ("tilus", lambda: matmul(m, n, k, a, b, c_actual)),
        ]:
            latency = benchmark_func(func, warmup=5, repeat=20)
            tflops = 2 * m * n * k / latency * 1e-9
            rows.append([m, n, k, name, latency, tflops])

    df = pandas.DataFrame(rows, columns=headers)
    print(df)

The first call to matmul(m, n, k, a, b, c_actual) triggers the autotuning process: every configuration is compiled and benchmarked on the given arguments. The best configuration is then cached, so subsequent invocations skip tuning entirely.
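The tune-once-then-cache behaviour can be illustrated with a small plain-Python sketch. This mirrors the idea only; pick_best and make_kernel are hypothetical stand-ins, not tilus internals:

```python
import time

_cache = {}

def pick_best(configs, make_kernel, args):
    """Compile every candidate once, time it on the real arguments,
    and cache the winner so later calls skip tuning (conceptual sketch)."""
    key = tuple(configs)  # a real engine would also key on arg shapes/dtypes
    if key in _cache:
        return _cache[key]            # subsequent calls reuse the winner
    best, best_time = None, float("inf")
    for cfg in configs:
        kernel = make_kernel(cfg)     # "compile" this configuration
        start = time.perf_counter()
        kernel(*args)                 # benchmark on the actual inputs
        elapsed = time.perf_counter() - start
        if elapsed < best_time:
            best, best_time = kernel, elapsed
    _cache[key] = best
    return best
```

The first call with a given key pays the full compile-and-benchmark cost for every configuration; every later call is a dictionary lookup.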

Note

The full source code for this example can be found at matmul_v2.py.