0. Naive Matmul¶
This tutorial demonstrates a simple implementation of matrix multiplication using Tilus. We use this example to illustrate the basic concepts of writing a kernel in Tilus, including kernel definition, data types, tensors, and kernel invocation.
Tilus Script¶
In Tilus, every kernel is defined by subclassing tilus.Script. A script
has two methods that you must implement:
- __init__ – initializes the compilation-time hyperparameters of the script (tile sizes, pipeline depths, etc.).
- __call__ – the main entry point that describes the computation logic of the kernel.
The skeleton looks like this:
class TilusScriptKernel(tilus.Script):
    def __init__(
        self,
        # compilation-time known hyperparameters
    ):
        super().__init__()
        ...  # process the hyperparameters

    def __call__(
        self,
        # kernel parameters
    ):
        ...  # define the computation logic of the kernel
Naive Matmul Implementation¶
With the script concept in mind, let us implement a naive matrix multiplication kernel. This implementation is not optimized for performance, but it serves as a good starting point to understand the basics of Tilus.
We begin with the necessary imports:
import tilus
from tilus import float16, float32, int32
from tilus.utils import cdiv
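Here, cdiv is ceiling division; it is used later to compute how many tiles are needed to cover a dimension. It behaves like this minimal pure-Python equivalent (a sketch of the semantics, not the actual tilus.utils implementation):

```python
def cdiv(a: int, b: int) -> int:
    # ceiling division: the smallest integer >= a / b
    return (a + b - 1) // b

# a 4096-wide dimension split into 64-wide tiles needs exactly 64 tiles,
# while a 100-wide dimension needs 2 tiles (the second one is partial)
print(cdiv(4096, 64))  # 64
print(cdiv(100, 64))   # 2
```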
The full kernel class is shown below. It tiles the output matrix into blocks of
size block_m x block_n and iterates over the K dimension in chunks of
block_k. Each thread block computes one output tile by accumulating partial
dot products in a register tensor.
class MatmulV0(tilus.Script):
    def __init__(self):
        super().__init__()
        # define three hyperparameters, block_m, block_n, and block_k, that determine
        # the tile size on the m, n, and k dimensions for each thread block of the kernel
        self.block_m = 64
        self.block_n = 64
        self.block_k = 16

    def __call__(
        self,
        m_size: int32,  # the size of the m dimension of the input matrix A and output matrix C
        n_size: int,  # the size of the n dimension of the input matrix B and output matrix C
        k_size: int,  # the size of the k dimension of the input matrices A and B
        a_ptr: ~float16,  # the pointer to the input matrix A, a 2D tensor of shape [m_size, k_size]
        b_ptr: ~float16,  # the pointer to the input matrix B, a 2D tensor of shape [k_size, n_size]
        c_ptr: ~float16,  # the pointer to the output matrix C, a 2D tensor of shape [m_size, n_size]
    ):
        self.attrs.blocks = [
            cdiv(m_size, self.block_m),  # the x dimension size of the grid
            cdiv(n_size, self.block_n),  # the y dimension size of the grid
        ]
        self.attrs.warps = 1  # the number of warps per thread block, must be a compile-time known integer

        # define two int32 variables that store the offsets of the m and n dimensions
        # for the current thread block
        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        # create two global tensors, ga and gb, that represent the input matrices A and B
        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size])

        # create a register tensor acc to accumulate the results of the matrix multiplication
        acc = self.register_tensor(
            dtype=float32, shape=[self.block_m, self.block_n], init=0.0
        )

        # iterate over the k dimension in blocks of size block_k
        for k in range(cdiv(k_size, self.block_k)):
            # calculate the offset of the current block in the k dimension
            offset_k = k * self.block_k

            # load a block of matrix A and a block of matrix B into register tensors a and b
            a = self.load_global(
                ga, offsets=[offset_m, offset_k], shape=[self.block_m, self.block_k]
            )
            b = self.load_global(
                gb, offsets=[offset_k, offset_n], shape=[self.block_k, self.block_n]
            )

            # perform the dot product: acc = a @ b + acc
            self.dot(a, b, acc, out=acc)

        # after the loop, cast the accumulated result acc to float16 and store it
        # back to the output matrix C
        acc_f16 = self.cast(acc, dtype=float16)
        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(gc, acc_f16, offsets=[offset_m, offset_n])
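To see what the tiling scheme computes, the same loop structure can be mirrored in plain NumPy on the CPU. In the sketch below, each (bi, bj) pair plays the role of one thread block in the grid; this is a reference for the indexing only (assuming the matrix sizes divide the tile sizes evenly), not a model of how Tilus executes the kernel:

```python
import numpy as np

def tiled_matmul(a, b, block_m=64, block_n=64, block_k=16):
    # reference for the tiling/indexing scheme, assuming each size
    # is an exact multiple of its tile size
    m_size, k_size = a.shape
    _, n_size = b.shape
    c = np.zeros((m_size, n_size), dtype=np.float32)
    for bi in range(m_size // block_m):      # plays the role of blockIdx.x
        for bj in range(n_size // block_n):  # plays the role of blockIdx.y
            offset_m = bi * block_m
            offset_n = bj * block_n
            # the "register tensor" accumulating one output tile
            acc = np.zeros((block_m, block_n), dtype=np.float32)
            for kb in range(k_size // block_k):  # loop over k in block_k chunks
                offset_k = kb * block_k
                a_tile = a[offset_m:offset_m + block_m, offset_k:offset_k + block_k]
                b_tile = b[offset_k:offset_k + block_k, offset_n:offset_n + block_n]
                acc += a_tile @ b_tile           # dot: acc = a @ b + acc
            c[offset_m:offset_m + block_m, offset_n:offset_n + block_n] = acc
    return c
```

Running this against a direct `a @ b` on small float32 inputs reproduces the same result, which confirms the offset arithmetic used in the kernel.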
Type Annotations and Instructions¶
The __call__ signature uses three kinds of type annotations for the kernel
parameters:
- int32 – a runtime-known 32-bit integer. In the example this is used for m_size.
- int – a compile-time-known integer. Different values trigger just-in-time (JIT) re-compilation of the kernel. In the example this is used for n_size and k_size.
- ~float16 – a pointer to a float16 array (equivalent to float16* in C/C++). In the example this is used for a_ptr, b_ptr, and c_ptr.
The kernel body uses the following Tilus instructions:
- global_view() – create a global tensor view of the input/output matrices.
- register_tensor() – allocate a register tensor to accumulate partial results.
- load_global() – load a tile from a global tensor into a register tensor.
- dot() – perform a matrix multiply-accumulate on two register tensors.
- cast() – cast a register tensor to a different data type.
- store_global() – store a register tensor back to a global tensor.
All of these instructions have block semantics: they are collectively executed by every thread in the thread block.
Launching the Kernel¶
To launch the kernel, create an instance of MatmulV0 and call it with the
appropriate arguments. The code below also verifies correctness and benchmarks
the kernel against PyTorch:
import math

import pandas
import torch

from tilus.utils import benchmark_func


def main():
    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    workloads = [[4096, 4096, 4096]]
    rows = []
    for m, n, k in workloads:
        # create an instance of the kernel we have just defined
        matmul = MatmulV0()

        a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        c_actual = torch.empty(m, n, dtype=torch.float16).cuda()
        c_expect = a @ b
        torch.cuda.synchronize()

        # launch the kernel by passing the required arguments
        matmul(m, n, k, a, b, c_actual)
        torch.cuda.synchronize()

        # check correctness
        torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2)

        # benchmark
        for name, func in [
            ("torch", lambda: torch.matmul(a, b, out=c_expect)),
            ("tilus", lambda: matmul(m, n, k, a, b, c_actual)),
        ]:
            latency = benchmark_func(func, warmup=5, repeat=20)
            tflops = 2 * m * n * k / latency * 1e-9
            rows.append([m, n, k, name, latency, tflops])

    df = pandas.DataFrame(rows, columns=headers)
    print(df)


if __name__ == "__main__":
    main()
The kernel is invoked just like a regular Python function – Tilus handles grid
configuration and kernel dispatch behind the scenes. The call
matmul(m, n, k, a, b, c_actual) launches the GPU kernel with the specified
parameters, and the results are checked for correctness using
torch.testing.assert_close.
The output is a pandas.DataFrame that contains the latency and throughput
(TFLOPS) of both the Tilus kernel and the PyTorch (cuBLAS) baseline. This naive
kernel is not yet competitive with vendor libraries, but it establishes the
foundation we build on. In subsequent versions we will introduce optimizations
– shared memory tiling, software pipelining, TMA loads, and more – that bring
performance up to cuBLAS levels.
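As a sanity check on the throughput formula used in the benchmark loop: a matmul performs 2 * m * n * k floating-point operations (one multiply and one add per accumulated term), and the latency is measured in milliseconds, which is where the 1e-9 factor comes from. For the 4096 x 4096 x 4096 workload and a hypothetical 1 ms latency:

```python
m = n = k = 4096
flops = 2 * m * n * k           # total floating-point operations (~137.4 GFLOP)
latency_ms = 1.0                # hypothetical latency, for illustration only

# TFLOPS = FLOPs / seconds / 1e12, equivalent to 2*m*n*k / latency_ms * 1e-9
tflops = flops / (latency_ms * 1e-3) / 1e12
print(round(tflops, 3))  # 137.439
```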
Full Source¶
The complete example file is located at
examples/matmul/matmul_v0.py.