CuTe DSL GEMM Runner#
class CuteDslGemmRunner#
Runner for CuTe DSL-compiled GEMM kernels, replacing cuBLAS for the Talker MLP.
Provides FP16 GEMM with FP32 accumulation: C = A @ B^T, where A is [M, K], B is [N, K] (both row-major), and C is [M, N].
Multiple architecture variants are compiled ahead of time (AOT) and selected at runtime based on the GPU SM version:

- Ampere (SM 80/86/87): cp.async + MmaF16BF16Op
- Blackwell datacenter (SM 100/101/103/110): tcgen05 + TMA
- Blackwell GeForce (SM 120/121): WGMMA + TMA
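The runtime selection described above can be sketched as a simple switch over the SM version. This is a hypothetical illustration only; the `Variant` enum and `selectVariant` function are not part of the actual API, whose real dispatch logic lives inside `CuteDslGemmRunner`.

```cpp
#include <cstdint>

// Hypothetical sketch of the SM-version dispatch described above.
enum class Variant { None, Ampere, BlackwellDC, BlackwellGeForce };

Variant selectVariant(int32_t smVersion) {
    switch (smVersion) {
        case 80: case 86: case 87:                // cp.async + MmaF16BF16Op
            return Variant::Ampere;
        case 100: case 101: case 103: case 110:   // tcgen05 + TMA
            return Variant::BlackwellDC;
        case 120: case 121:                       // WGMMA + TMA
            return Variant::BlackwellGeForce;
        default:                                  // no compiled variant
            return Variant::None;
    }
}
```

Under this sketch, `canImplement(sm)` would simply be `selectVariant(sm) != Variant::None`.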
Public Static Functions
static bool canImplement(int32_t smVersion)#
Check if CuTe DSL GEMM can run on this GPU.
- Parameters:
smVersion – GPU SM version (e.g. 87, 100, 121)
- Returns:
true if a compiled GEMM variant exists for this SM
static bool loadKernelModule()#
Load the kernel module (thread-safe, idempotent).
static void unloadKernelModule()#
Unload the kernel module.
static bool run(void const *aPtr, void const *bPtr, void *cPtr, int32_t M, int32_t N, int32_t K, cudaStream_t stream)#
Execute GEMM: C[M,N] = A[M,K] @ B[N,K]^T.
All tensors are FP16, row-major. Accumulation is FP32.
- Parameters:
aPtr – Input A [M, K] (FP16, K contiguous)
bPtr – Weight B [N, K] (FP16, K contiguous)
cPtr – Output C [M, N] (FP16, N contiguous)
M – Number of rows in A / C
N – Number of rows in B / columns in C
K – Inner dimension
stream – CUDA stream
- Returns:
true on success; false if the kernel module is not loaded or no variant is available for this SM
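The semantics of run() can be checked against a naive CPU reference. This is a sketch of the math only, not how the kernel is implemented: the actual kernel stores A, B, and C in FP16 and uses tensor cores, while float storage is used here for brevity. Only the FP32 accumulation is shared with the real implementation.

```cpp
#include <cstdint>
#include <vector>

// Naive CPU reference for C[M,N] = A[M,K] @ B[N,K]^T with FP32 accumulation.
// The real kernel stores A, B, C in FP16; float is used here for brevity.
void gemmReference(const std::vector<float>& A,   // [M, K], row-major
                   const std::vector<float>& B,   // [N, K], row-major
                   std::vector<float>& C,         // [M, N], row-major
                   int32_t M, int32_t N, int32_t K) {
    for (int32_t m = 0; m < M; ++m) {
        for (int32_t n = 0; n < N; ++n) {
            float acc = 0.0f;                        // FP32 accumulator
            for (int32_t k = 0; k < K; ++k) {
                acc += A[m * K + k] * B[n * K + k];  // B read as B^T
            }
            C[m * N + n] = acc;
        }
    }
}
```

Note that because B is [N, K] row-major with K contiguous, both A and B are traversed along their contiguous dimension in the inner loop, which is what makes the B^T layout convenient for a weight matrix.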
static bool runBiasSiLU(void const *aPtr, void const *bPtr, void *cPtr, void const *biasPtr, int32_t M, int32_t N, int32_t K, cudaStream_t stream)#
Execute fused GEMM + bias + SiLU: C = SiLU(A @ B^T + bias).
Uses an AOT-compiled fused-epilogue kernel on all architectures (Ampere, Blackwell datacenter, Blackwell GeForce). Falls back to a plain GEMM followed by a separate CUDA kernel if the fused variant is not compiled for the current architecture.
- Parameters:
biasPtr – Bias vector [N] (FP16); remaining parameters as in run()
- Returns:
true on success
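The fused epilogue applies SiLU(x) = x · sigmoid(x) elementwise after adding the bias. A minimal CPU sketch of the per-element semantics (float in place of FP16 for brevity; the helper names are illustrative, not part of the API):

```cpp
#include <cmath>

// SiLU activation: x * sigmoid(x), written as x / (1 + e^{-x}).
float silu(float x) {
    return x / (1.0f + std::exp(-x));
}

// Epilogue applied to each GEMM output element: C[m][n] = SiLU(acc + bias[n]),
// where acc is the FP32 dot product of row m of A and row n of B.
float biasSiluEpilogue(float acc, float bias) {
    return silu(acc + bias);
}
```

Fusing this into the GEMM epilogue avoids a second pass over C in global memory, which is why the fallback path (plain GEMM plus a separate CUDA kernel) costs an extra read and write of the [M, N] output.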
static bool runBias(void const *aPtr, void const *bPtr, void *cPtr, void const *biasPtr, int32_t M, int32_t N, int32_t K, cudaStream_t stream)#
Execute fused GEMM + bias: C = A @ B^T + bias.
Uses an AOT-compiled fused-epilogue kernel on all architectures. Falls back to a plain GEMM followed by a separate CUDA kernel if the fused variant is not compiled.
- Parameters:
biasPtr – Bias vector [N] (FP16); remaining parameters as in run()
- Returns:
true on success