implicit_gemm_cuda
Conv3D Implicit GEMM with BF16 WMMA Tensor Cores and optional fused FP4 quantization.
CUDA kernel source: implicit_gemm_kernel.cu
C++ binding: implicit_gemm_binding.cpp
Functions
- `conv3d_implicit_gemm_cuda` – Conv3D via implicit GEMM with BF16 WMMA tensor cores.
- `fp4_fake_quant` – Standalone FP4 fake quantization using the same CUDA device functions as the GEMM kernel.
- conv3d_implicit_gemm_cuda(x, w, bias=None, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), act_amax=None, quant_act=False, fp4_block_size=256)
Conv3D via implicit GEMM with BF16 WMMA tensor cores.
- Parameters:
x (Tensor) – Input tensor [N, Cin, D, H, W].
w (Tensor) – Weight tensor [Cout, Cin, kD, kH, kW].
bias (Tensor | None) – Optional bias tensor [Cout].
stride (tuple[int, int, int]) – Convolution stride (D, H, W).
padding (tuple[int, int, int]) – Convolution padding (D, H, W).
dilation (tuple[int, int, int]) – Convolution dilation (D, H, W).
act_amax (Tensor | None) – Global activation abs-max used for FP4 quantization.
quant_act (bool) – Whether to apply FP4 quantization to activations.
fp4_block_size (int) – FP4 quantization block size (16, 32, 64, 128, or 256).
- Returns:
Output tensor [N, Cout, OD, OH, OW].
- Raises:
ValueError – If fp4_block_size is not one of {16, 32, 64, 128, 256}.
- Return type:
Tensor
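The output spatial sizes OD, OH, OW follow from the stride, padding, and dilation parameters above. A minimal sketch, assuming the kernel uses the standard (PyTorch-style) convolution shape formula; `conv_out_dim` is a hypothetical helper, not part of this module:

```python
# Output spatial size for one dimension -- assumes standard convolution
# shape semantics; this is a sketch, not taken from the kernel source.
def conv_out_dim(in_dim, k, stride, pad, dilation):
    return (in_dim + 2 * pad - dilation * (k - 1) - 1) // stride + 1
```

Applied per axis, e.g. a D=16 input with kD=3, stride 1, padding 1, dilation 1 keeps OD=16.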
- fp4_fake_quant(x, global_amax, block_size=16)
Standalone FP4 fake quantization using the same CUDA device functions as the GEMM kernel.
Applies blockwise FP4 (E2M1) quantize-dequantize with FP8 E4M3 scale quantization.
- Parameters:
x (Tensor) – Input tensor (any shape, numel must be divisible by block_size).
global_amax (Tensor) – Scalar tensor — global abs max for scale computation.
block_size (int) – Number of elements per FP4 quantization block.
- Returns:
Fake-quantized tensor with same shape and dtype as input.
- Return type:
Tensor
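The quantize-dequantize round trip can be sketched in pure Python. This is a hypothetical reference (`fp4_fake_quant_ref` and `E2M1_GRID` are illustrative names, not part of this module): it assumes a per-block scale of block_amax / 6.0 (6.0 being the largest E2M1 magnitude) and omits the kernel's FP8 E4M3 scale quantization step, which is where `global_amax` comes in; scales stay in full precision here.

```python
# Representable non-negative FP4 E2M1 magnitudes.
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def fp4_fake_quant_ref(x, global_amax, block_size=16):
    # global_amax is unused in this simplified sketch: the real kernel uses it
    # to quantize the per-block scales to FP8 E4M3, which is omitted here.
    assert len(x) % block_size == 0, "numel must be divisible by block_size"
    out = []
    for i in range(0, len(x), block_size):
        block = x[i:i + block_size]
        amax = max(abs(v) for v in block)
        scale = amax / 6.0 if amax > 0 else 1.0  # map block amax to E2M1 max
        for v in block:
            # round-to-nearest onto the E2M1 magnitude grid, then dequantize
            mag = min(E2M1_GRID, key=lambda g: abs(abs(v) / scale - g))
            out.append(mag * scale if v >= 0 else -mag * scale)
    return out
```

Note that, as in the CUDA path, the output has the same shape as the input and every value lands on a scaled E2M1 grid point.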