1. Use Shared Memory¶
In the previous tutorial, every thread block loaded its tiles of A and B directly from global memory into registers. Global memory has high latency, so a natural next step is to stage the data through shared memory – a small, fast on-chip scratchpad that is visible to all threads in the same block.
Why shared memory helps¶
Shared memory is much faster than global memory (on the order of 100x lower latency). By first copying a tile from global memory into shared memory, all threads in the block can then read from the shared copy at high bandwidth. This is especially beneficial when the same data is read multiple times by different threads, as is the case in matrix multiplication where every element of a tile participates in multiple multiply-accumulate operations.
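To see how much reuse the staging buys, here is a quick back-of-the-envelope count using the tile sizes from the kernel below (block_m=64, block_n=64, block_k=16). The arithmetic is illustrative only; the exact ratio depends on the tile shape you choose.

```python
# Reuse factor for one inner-loop iteration with the default tile sizes.
block_m, block_n, block_k = 64, 64, 16

# Elements loaded from global memory per iteration (one A tile + one B tile).
loads = block_m * block_k + block_k * block_n

# Multiply-accumulate operations performed on those two tiles.
macs = block_m * block_n * block_k

# Average number of times each staged element is read.
reuse = macs / loads
print(loads, macs, reuse)  # 2048 65536 32.0
```

Each element copied into shared memory is read 32 times on average, which is why paying the cost of the extra copy is worthwhile.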
The data flow for each iteration of the inner loop becomes:

1. Load tiles from global memory into register tensors.
2. Store those register tensors into shared memory (store_shared).
3. Synchronize all threads so the shared data is fully written.
4. Load from shared memory back into registers (load_shared).
5. Compute the dot product.
6. Synchronize again before the next iteration overwrites shared memory.
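The steps above can be modeled in plain NumPy. The sketch below is single-threaded and uses ordinary arrays, so the staging buffers have no performance meaning here; they only mirror the ordering of the tilus kernel (the names sa, sb, and acc are chosen to match it). The helper tiled_matmul and its tile sizes are illustrative assumptions, not part of the tilus API.

```python
import numpy as np

def tiled_matmul(a, b, block_m=4, block_n=4, block_k=2):
    """Tile-by-tile matmul; sizes must divide evenly for this sketch."""
    m, k = a.shape
    k2, n = b.shape
    assert k == k2
    c = np.zeros((m, n), dtype=np.float32)
    for om in range(0, m, block_m):            # one "thread block" per (om, on)
        for on in range(0, n, block_n):
            acc = np.zeros((block_m, block_n), dtype=np.float32)
            sa = np.empty((block_m, block_k), dtype=a.dtype)  # "shared" A tile
            sb = np.empty((block_k, block_n), dtype=b.dtype)  # "shared" B tile
            for ok in range(0, k, block_k):
                sa[:] = a[om:om + block_m, ok:ok + block_k]   # global -> shared
                sb[:] = b[ok:ok + block_k, on:on + block_n]
                acc += sa.astype(np.float32) @ sb.astype(np.float32)  # dot
            c[om:om + block_m, on:on + block_n] = acc
    return c

a = np.random.rand(8, 6).astype(np.float16)
b = np.random.rand(6, 8).astype(np.float16)
ref = a.astype(np.float32) @ b.astype(np.float32)
np.testing.assert_allclose(tiled_matmul(a, b), ref, rtol=1e-3)
```

On a GPU the synchronization steps matter because many threads cooperate on the same buffers; the single-threaded model above needs no barriers, but the copy-then-compute ordering is the same.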
Kernel implementation¶
class MatmulV1(tilus.Script):
    def __init__(self, num_warps=4, block_m=64, block_n=64, block_k=16):
        super().__init__()
        self.num_warps = num_warps
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k

    def __call__(
        self,
        m_size: int32,
        n_size: int32,
        k_size: int32,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [
            cdiv(m_size, self.block_m),
            cdiv(n_size, self.block_n),
        ]
        self.attrs.warps = self.num_warps

        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size])

        # allocate shared memory tiles for A and B
        sa = self.shared_tensor(dtype=float16, shape=[self.block_m, self.block_k])
        sb = self.shared_tensor(dtype=float16, shape=[self.block_k, self.block_n])

        acc = self.register_tensor(
            dtype=float32, shape=[self.block_m, self.block_n], init=0.0
        )

        for offset_k in range(0, k_size, self.block_k):
            # load a tile of A from global memory into registers
            lda = self.load_global(
                ga,
                offsets=[offset_m, offset_k],
                shape=[self.block_m, self.block_k],
            )
            # store the loaded tile into shared memory
            self.store_shared(sa, lda)
            # load a tile of B from global memory into registers
            ldb = self.load_global(
                gb,
                offsets=[offset_k, offset_n],
                shape=[self.block_k, self.block_n],
            )
            # store the loaded tile into shared memory
            self.store_shared(sb, ldb)
            # synchronize so every thread sees the fully written tiles
            self.sync()
            # load the tiles from shared memory into registers
            a = self.load_shared(sa)
            b = self.load_shared(sb)
            acc = self.dot(a, b, acc)
            # synchronize before the next iteration overwrites shared memory
            self.sync()

        self.free_shared(sa)
        self.free_shared(sb)

        casted_acc = self.cast(acc, dtype=float16)
        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(gc, casted_acc, offsets=[offset_m, offset_n])
The kernel follows the same overall structure as MatmulV0, with two key
additions: shared memory tiles and explicit synchronization.
Shared memory tiles¶
At the top of the __call__ method we allocate two shared tensors, sa and
sb, to hold the current tiles of A and B respectively:
sa = self.shared_tensor(dtype=float16, shape=[self.block_m, self.block_k])
sb = self.shared_tensor(dtype=float16, shape=[self.block_k, self.block_n])
Inside the loop, data flows through shared memory before reaching the accumulator:
# global -> registers -> shared memory
lda = self.load_global(ga, offsets=[offset_m, offset_k], shape=[self.block_m, self.block_k])
self.store_shared(sa, lda)
# ... same for B ...
self.sync() # ensure all stores to shared memory are visible
# shared memory -> registers
a = self.load_shared(sa)
b = self.load_shared(sb)
acc = self.dot(a, b, acc)
self.sync() # ensure dot is done before next iteration overwrites shared
At the end, both shared tensors are freed so their memory can be reused:
self.free_shared(sa)
self.free_shared(sb)
Why two synchronizations?¶
Both self.sync() calls are necessary because sa and sb are reused across loop iterations:

1. The first sync (after store_shared) guarantees that every thread has finished writing its portion of the tile into shared memory before any thread tries to read from it via load_shared.
2. The second sync (after dot) guarantees that the dot product has finished reading from shared memory before the next iteration overwrites it with new data.
Omitting either synchronization leads to a data race.
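The same double-barrier pattern can be demonstrated with ordinary Python threads, where a threading.Barrier plays the role of sync(). This is a toy model under assumed names (worker, shared, totals), not tilus code: four threads stand in for the threads of one block, and a single buffer is reused across iterations exactly like sa and sb.

```python
import threading

num_threads, num_iters = 4, 3
shared = [0] * num_threads           # reused "shared memory" tile
totals = [0] * num_threads           # per-thread accumulators
barrier = threading.Barrier(num_threads)

def worker(tid):
    for it in range(num_iters):
        shared[tid] = (it + 1) * (tid + 1)  # each thread writes its slice
        barrier.wait()                       # first sync: all writes visible
        totals[tid] += sum(shared)           # every thread reads the whole tile
        barrier.wait()                       # second sync: reads done before overwrite

threads = [threading.Thread(target=worker, args=(t,)) for t in range(num_threads)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(totals)  # [60, 60, 60, 60]: every thread saw a complete tile each iteration
```

Dropping the first wait() lets a thread read a half-written buffer; dropping the second lets a fast thread overwrite the buffer while a slow one is still reading it. Both failures are data races of exactly the kind the two sync() calls prevent.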
New instructions¶
This example introduces five new Script methods compared to
MatmulV0:
- shared_tensor() – allocate a shared-memory tensor with a given dtype and shape.
- store_shared() – copy a register tensor into a shared tensor.
- load_shared() – copy a shared tensor into a new register tensor.
- free_shared() – release the shared memory so it can be reused. Every allocation must be freed before the kernel ends.
- sync() – synchronize all threads in the thread block (equivalent to __syncthreads() in CUDA C).
Launching the kernel¶
The launch code is identical in structure to the previous version – create an instance, prepare tensors, and call the script:
def main():
    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    workloads = [
        [4096, 4096, 4096],
    ]
    rows = []
    for m, n, k in workloads:
        matmul = MatmulV1()
        a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        c_actual = torch.empty(m, n, dtype=torch.float16).cuda()
        c_expect = a @ b

        matmul(m, n, k, a, b, c_actual)
        torch.cuda.synchronize()

        # check correctness
        torch.testing.assert_close(c_expect, c_actual)

        # benchmark
        for name, func in [
            ("torch", lambda: torch.matmul(a, b, out=c_expect)),
            ("tilus", lambda: matmul(m, n, k, a, b, c_actual)),
        ]:
            latency = benchmark_func(func, warmup=5, repeat=20)
            tflops = 2 * m * n * k / latency * 1e-9
            rows.append([m, n, k, name, latency, tflops])

    df = pandas.DataFrame(rows, columns=headers)
    print(df)
Full source¶
The complete example is available at
examples/matmul/matmul_v1.py.