3. Async Copy

On NVIDIA Ampere and newer architectures, hardware support for asynchronous copy from global memory to shared memory was introduced. The key advantage is that data moves directly from global to shared memory without using registers as an intermediate buffer, freeing register resources and reducing latency.

Tilus exposes this feature through two block-level instructions:

  • copy_async() – issues an asynchronous copy from a global tensor to a shared tensor. The call returns immediately; the data transfer happens in the background.

  • copy_async_wait_all() – blocks until all previously issued asynchronous copies have completed, guaranteeing that the data is available in shared memory.

Because copy_async_wait_all() does not synchronize threads within the block, a subsequent sync() call is still necessary before reading the shared memory data.

Kernel Implementation

The kernel below is structurally similar to V2 but replaces the load_global / store_shared pair with copy_async(), eliminating the register intermediate buffer:

@tilus.autotune("num_warps", [4, 8])
@tilus.autotune("block_m, block_n", [(128, 128), (128, 64), (64, 128)])
@tilus.autotune("block_k", [16, 32])
class MatmulV3(tilus.Script):
    """Tiled fp16 matmul C = A @ B that loads tiles with asynchronous copies.

    Compared to V2, the load_global/store_shared pair is replaced by
    copy_async(), so tiles of A and B move from global to shared memory
    without a register-tensor intermediate. The autotune decorators
    enumerate candidate (num_warps, block_m, block_n, block_k)
    configurations to choose from.
    """

    def __init__(
        self,
        num_warps,
        block_m,
        block_n,
        block_k,
    ):
        # block_m x block_n is the C tile computed per thread block;
        # block_k is the K-extent loaded and consumed per loop iteration.
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.num_warps = num_warps

    def __call__(
        self,
        m_size: int32,
        # NOTE(review): n_size/k_size are annotated plain `int` while m_size
        # uses `int32` — likely all three should be `int32`; confirm against
        # tilus conventions before changing.
        n_size: int,
        k_size: int,
        a_ptr: ~float16,  # A: shape [m_size, k_size]
        b_ptr: ~float16,  # B: shape [k_size, n_size]
        c_ptr: ~float16,  # C: shape [m_size, n_size], written once at the end
    ):
        """Compute one [block_m, block_n] tile of C per thread block."""
        # One thread block per output tile; 2-D grid over (M, N).
        self.attrs.blocks = [
            cdiv(m_size, self.block_m),
            cdiv(n_size, self.block_n),
        ]
        self.attrs.warps = self.num_warps

        block_m, block_n, block_k = self.block_m, self.block_n, self.block_k
        # Top-left corner of this block's C tile.
        offset_m: int32 = block_m * self.blockIdx.x
        offset_n: int32 = block_n * self.blockIdx.y

        # Global views of A and B, one shared tile per operand for the
        # current K step, and an fp32 register accumulator for the C tile.
        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size])
        sa = self.shared_tensor(dtype=float16, shape=[block_m, block_k])
        sb = self.shared_tensor(dtype=float16, shape=[block_k, block_n])
        acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0)

        for offset_k in range(0, k_size, block_k):
            # issue asynchronous copy instructions to load tiles of A and B
            # (global -> shared directly, no register intermediate)
            self.copy_async(src=ga, dst=sa, offsets=[offset_m, offset_k])
            self.copy_async(src=gb, dst=sb, offsets=[offset_k, offset_n])

            # wait for all asynchronous copy operations to complete
            self.copy_async_wait_all()

            # copy_async_wait_all() does not synchronize the block, so a
            # sync() is still needed before any thread reads sa/sb
            self.sync()

            a = self.load_shared(sa)
            b = self.load_shared(sb)
            self.dot(a, b, acc, out=acc)
            # keep threads from issuing the next iteration's copies while
            # others are still reading the current tiles
            self.sync()

        self.free_shared(sa)
        self.free_shared(sb)

        # Accumulate in fp32, store the result back as fp16.
        casted_acc = self.cast(acc, dtype=float16)
        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(gc, casted_acc, offsets=[offset_m, offset_n])

The three steps inside the loop body (lines 86–90 of the full source file) are the core change compared to V2:

  1. Two copy_async() calls issue background copies for tiles of A and B.

  2. copy_async_wait_all() ensures both copies have landed in shared memory.

  3. sync() synchronizes all threads so every thread sees the updated shared data before the computation begins.

Launch the Kernel

def main():
    """Validate MatmulV3 against torch.matmul and report latency/TFLOPS for both."""
    columns = ["m", "n", "k", "name", "latency (ms)", "tflops"]
    workloads = [
        [4096, 4096, 4096],
    ]

    records = []
    for m, n, k in workloads:
        kernel = MatmulV3()

        # Operands are centered and scaled by 1/sqrt(k) so the accumulated
        # sums stay in a comfortable fp16 range.
        a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
        c_actual = torch.empty(m, n, dtype=torch.float16).cuda()
        c_expect = a @ b
        kernel(m, n, k, a, b, c_actual)

        # check correctness against the torch reference
        torch.testing.assert_close(c_expect, c_actual)

        # benchmark both implementations on the same operands
        runners = [
            ("torch", lambda: torch.matmul(a, b, out=c_expect)),
            ("tilus", lambda: kernel(m, n, k, a, b, c_actual)),
        ]
        for name, func in runners:
            latency = benchmark_func(func, warmup=5, repeat=20)
            # 2*m*n*k flops; latency is in ms, hence the 1e-9 factor for TFLOPS
            tflops = 2 * m * n * k / latency * 1e-9
            records.append([m, n, k, name, latency, tflops])

    report = pandas.DataFrame(records, columns=columns)
    print(report)

Note

The full source code for this example can be found at matmul_v3.py.