2. Auto-tuning¶
In previous versions of the matmul kernel, we manually set hyperparameters
such as block_m, block_n, and block_k. These hyperparameters can
significantly affect kernel performance, and finding optimal values for them
is a tedious and time-consuming process. Tilus provides the
tilus.autotune() decorator to annotate the search space of the
hyperparameters and let Tilus automatically search for the best
configuration.
The decorator accepts parameter names and a list of candidate values. When
multiple @tilus.autotune decorators are stacked, Tilus forms the Cartesian
product of all value lists and tries every combination. On the first
invocation, the kernel is compiled for each configuration, benchmarked on the
actual arguments, and the fastest configuration is selected automatically.
Subsequent calls reuse the winner.
@tilus.autotune("arg_name1", [v11, v12, v13])
@tilus.autotune("arg_name2, arg_name3", [(v21, v31), (v22, v32)])
class AwesomeKernel(tilus.Script):
    def __init__(self, user_arg, arg_name1, arg_name2, arg_name3):
        super().__init__()
        ...
When instantiating the class, only the non-tuned arguments are provided – the tuned parameters are filled in automatically by the autotuning engine.
Imports¶
import math

import pandas
import torch

import tilus
from tilus import float16, float32, int32
from tilus.utils import benchmark_func, cdiv
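The cdiv helper performs ceiling division, which the kernel uses below to compute how many thread blocks are needed per dimension. A minimal sketch of its behavior (the real implementation lives in tilus.utils):

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division: the smallest integer >= a / b.
    return (a + b - 1) // b

# Covering 4096 rows with 128-row blocks needs exactly 32 blocks,
# while 4100 rows needs one extra, partially filled block.
print(cdiv(4096, 128))  # 32
print(cdiv(4100, 128))  # 33
```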
Annotate the Search Space¶
Reusing the same kernel implementation as in V1, we add
tilus.autotune() decorators for num_warps, block_m/block_n,
and block_k:
@tilus.autotune("num_warps", [4, 8])
@tilus.autotune("block_m, block_n", [(128, 128), (128, 64), (64, 128)])
@tilus.autotune("block_k", [16, 32])
class MatmulV2(tilus.Script):
    def __init__(
        self,
        num_warps,
        block_m,
        block_n,
        block_k,
    ):
        super().__init__()
        self.num_warps = num_warps
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k

    def __call__(
        self,
        m_size: int32,
        n_size: int32,
        k_size: int32,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [
            cdiv(m_size, self.block_m),
            cdiv(n_size, self.block_n),
        ]
        self.attrs.warps = self.num_warps

        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size])
        sa = self.shared_tensor(dtype=float16, shape=[self.block_m, self.block_k])
        sb = self.shared_tensor(dtype=float16, shape=[self.block_k, self.block_n])
        acc = self.register_tensor(
            dtype=float32, shape=[self.block_m, self.block_n], init=0.0
        )

        for offset_k in range(0, k_size, self.block_k):
            lda = self.load_global(
                ga,
                offsets=[offset_m, offset_k],
                shape=[self.block_m, self.block_k],
            )
            self.store_shared(sa, lda)
            ldb = self.load_global(
                gb,
                offsets=[offset_k, offset_n],
                shape=[self.block_k, self.block_n],
            )
            self.store_shared(sb, ldb)
            self.sync()
            a = self.load_shared(sa)
            b = self.load_shared(sb)
            acc = self.dot(a, b, acc)
            self.sync()

        self.free_shared(sa)
        self.free_shared(sb)

        casted_acc = self.cast(acc, dtype=float16)
        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(gc, casted_acc, offsets=[offset_m, offset_n])
The three decorators create a space of \(2 \times 3 \times 2 = 12\) configurations. Tilus compiles all twelve, benchmarks them, and keeps the fastest.
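To see where the twelve configurations come from, we can enumerate the same search space with plain Python. Note that the (block_m, block_n) pairs count as a single axis of the product, because they are declared together in one decorator:

```python
from itertools import product

num_warps_choices = [4, 8]
block_mn_choices = [(128, 128), (128, 64), (64, 128)]
block_k_choices = [16, 32]

# Cartesian product of the three decorator axes: 2 * 3 * 2 = 12 configurations.
configs = [
    {"num_warps": w, "block_m": bm, "block_n": bn, "block_k": bk}
    for w, (bm, bn), bk in product(num_warps_choices, block_mn_choices, block_k_choices)
]
print(len(configs))  # 12
```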
Launch the Kernel¶
Notice that MatmulV2() is instantiated with no arguments – the
tuned parameters are determined automatically.
def main():
    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    workloads = [
        [4096, 4096, 4096],
    ]
    rows = []
    for m, n, k in workloads:
        matmul = MatmulV2()
        a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        c_actual = torch.empty(m, n, dtype=torch.float16).cuda()
        c_expect = a @ b

        matmul(m, n, k, a, b, c_actual)
        torch.cuda.synchronize()

        # check correctness
        torch.testing.assert_close(c_expect, c_actual)

        # benchmark
        for name, func in [
            ("torch", lambda: torch.matmul(a, b, out=c_expect)),
            ("tilus", lambda: matmul(m, n, k, a, b, c_actual)),
        ]:
            latency = benchmark_func(func, warmup=5, repeat=20)
            tflops = 2 * m * n * k / latency * 1e-9
            rows.append([m, n, k, name, latency, tflops])

    df = pandas.DataFrame(rows, columns=headers)
    print(df)
The first call to matmul(m, n, k, a, b, c_actual) triggers the
autotuning process: every configuration is compiled and benchmarked on the
given arguments. The best configuration is then cached, so subsequent
invocations skip tuning entirely.
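The tune-once-then-cache behavior can be illustrated with a simplified sketch. The names and latencies below are hypothetical stand-ins for Tilus internals, not its actual API; the point is only the pattern of benchmarking every configuration once and memoizing the winner per problem:

```python
from functools import lru_cache

# Hypothetical per-configuration latencies in ms, standing in for real
# benchmark measurements of compiled kernel variants.
LATENCIES = {
    (4, 128, 128, 16): 1.9,
    (4, 64, 128, 32): 1.4,
    (8, 128, 64, 32): 1.1,  # fastest, so this configuration wins
}

@lru_cache(maxsize=None)  # repeated calls with the same key skip tuning
def pick_best_config(problem_key):
    # "Benchmark" every candidate once and keep the fastest.
    return min(LATENCIES, key=LATENCIES.get)

best = pick_best_config((4096, 4096, 4096))
print(best)  # (8, 128, 64, 32)
```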
Note
The full source code for this example can be found at
matmul_v2.py.