implicit_gemm_cuda

Conv3D Implicit GEMM with BF16 WMMA Tensor Cores and optional fused FP4 quantization.

CUDA kernel source: implicit_gemm_kernel.cu
C++ binding: implicit_gemm_binding.cpp
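
The NumPy sketch below makes the implicit GEMM mapping concrete by building the im2col matrix explicitly. An implicit GEMM kernel avoids exactly this materialization by computing those gather indices on the fly. In the usual formulation (assumed here), the GEMM has M = N*OD*OH*OW rows, a reduction dimension K = Cin*kD*kH*kW, and Cout output columns. The function name conv3d_im2col_reference is hypothetical. It runs in float64 on the CPU, omits bias, and does not model the BF16 WMMA tiling or the optional FP4 quantization.

```python
import numpy as np


def conv3d_im2col_reference(x, w, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1)):
    """Hypothetical reference: Conv3D as an explicit im2col matrix times a weight matrix."""
    n, cin, d, h, wd = x.shape
    cout, cin_w, kd, kh, kw = w.shape
    if cin != cin_w:
        raise ValueError("input and weight channel counts differ")
    (sd, sh, sw), (pd, ph, pw), (dd, dh, dw) = stride, padding, dilation

    # Output spatial sizes (standard Conv3D formula).
    od = (d + 2 * pd - dd * (kd - 1) - 1) // sd + 1
    oh = (h + 2 * ph - dh * (kh - 1) - 1) // sh + 1
    ow = (wd + 2 * pw - dw * (kw - 1) - 1) // sw + 1

    # Zero-pad the three spatial dimensions.
    xp = np.pad(x, ((0, 0), (0, 0), (pd, pd), (ph, ph), (pw, pw)))

    # Gather every receptive field into cols[n, od, oh, ow, cin, kd, kh, kw].
    cols = np.empty((n, od, oh, ow, cin, kd, kh, kw), dtype=x.dtype)
    for a in range(kd):
        for b in range(kh):
            for c in range(kw):
                patch = xp[:, :,
                           a * dd : a * dd + sd * (od - 1) + 1 : sd,
                           b * dh : b * dh + sh * (oh - 1) + 1 : sh,
                           c * dw : c * dw + sw * (ow - 1) + 1 : sw]  # [N, Cin, OD, OH, OW]
                cols[..., a, b, c] = patch.transpose(0, 2, 3, 4, 1)

    a_mat = cols.reshape(n * od * oh * ow, cin * kd * kh * kw)  # A: [M, K]
    b_mat = w.reshape(cout, cin * kd * kh * kw).T               # B: [K, Cout]
    out = a_mat @ b_mat                                          # C: [M, Cout]
    return out.reshape(n, od, oh, ow, cout).transpose(0, 4, 1, 2, 3)
```

In a typical implicit GEMM design, the kernel computes the same M x Cout product but never stores a_mat; each tile of A is gathered directly from x while the tile is loaded.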

Functions

conv3d_implicit_gemm_cuda

Conv3D via implicit GEMM with BF16 WMMA tensor cores.

fp4_fake_quant

Standalone FP4 fake quantization using the same CUDA device functions as the GEMM kernel.

conv3d_implicit_gemm_cuda(x, w, bias=None, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), act_amax=None, quant_act=False, fp4_block_size=256)

Conv3D via implicit GEMM with BF16 WMMA tensor cores.

Parameters:
  • x (Tensor) – Input tensor [N, Cin, D, H, W]

  • w (Tensor) – Weight tensor [Cout, Cin, kD, kH, kW]

  • bias (Tensor | None) – Optional bias tensor [Cout]

  • stride (tuple[int, int, int]) – Convolution stride (D, H, W)

  • padding (tuple[int, int, int]) – Convolution padding (D, H, W)

  • dilation (tuple[int, int, int]) – Convolution dilation (D, H, W)

  • act_amax (Tensor | None) – Activation absolute maximum (amax) used for FP4 quantization

  • quant_act (bool) – Whether to apply FP4 quantization to activations

  • fp4_block_size (int) – FP4 quantization block size (16, 32, 64, 128, or 256)

Returns:

Output tensor [N, Cout, OD, OH, OW]

Raises:

ValueError – If fp4_block_size is not one of {16, 32, 64, 128, 256}.

Return type:

Tensor
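
The Returns entry does not define OD, OH, and OW. Assuming the kernel follows the same shape convention as PyTorch's Conv3d, each output spatial size is floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1. The helpers below are hypothetical and not part of implicit_gemm_cuda: one computes the output shape, and the other mirrors the documented fp4_block_size check.

```python
# Hypothetical helpers; not part of the implicit_gemm_cuda API.
VALID_FP4_BLOCK_SIZES = (16, 32, 64, 128, 256)


def conv3d_output_shape(input_shape, weight_shape,
                        stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1)):
    """Return (N, Cout, OD, OH, OW) for a Conv3D with the given hyperparameters."""
    n, cin, *spatial = input_shape
    cout, cin_w, *kernel = weight_shape
    if cin != cin_w:
        raise ValueError(f"channel mismatch: input has {cin}, weight expects {cin_w}")
    out = [
        (size + 2 * p - d * (k - 1) - 1) // s + 1
        for size, k, s, p, d in zip(spatial, kernel, stride, padding, dilation)
    ]
    return (n, cout, *out)


def check_fp4_block_size(block_size):
    """Raise ValueError for block sizes outside the documented set."""
    if block_size not in VALID_FP4_BLOCK_SIZES:
        raise ValueError(f"fp4_block_size must be one of {VALID_FP4_BLOCK_SIZES}, got {block_size}")


print(conv3d_output_shape((1, 16, 8, 64, 64), (32, 16, 3, 3, 3), padding=(1, 1, 1)))
# -> (1, 32, 8, 64, 64)
print(conv3d_output_shape((1, 16, 8, 64, 64), (32, 16, 3, 3, 3),
                          stride=(2, 2, 2), padding=(1, 1, 1)))
# -> (1, 32, 4, 32, 32)
```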

fp4_fake_quant(x, global_amax, block_size=16)

Standalone FP4 fake quantization using the same CUDA device functions as the GEMM kernel.

Applies blockwise FP4 (E2M1) quantize-dequantize with FP8 E4M3 scale quantization.

Parameters:
  • x (Tensor) – Input tensor of any shape; its numel must be divisible by block_size.

  • global_amax (Tensor) – Scalar tensor holding the global absolute maximum used for scale computation.

  • block_size (int) – Number of elements per FP4 quantization block.

Returns:

Fake-quantized tensor with the same shape and dtype as the input.

Return type:

Tensor
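
For intuition about what blockwise FP4 (E2M1) quantize-dequantize with FP8 E4M3 scales involves, here is a NumPy reference sketch. It assumes an NVFP4-style two-level scheme that the description above does not confirm: a global scale of global_amax / (6 * 448), where 6 is the largest E2M1 magnitude and 448 the largest E4M3 magnitude; a per-block scale of block_amax / 6 rounded to E4M3 relative to that global scale; contiguous blocks over the flattened tensor; and round-to-nearest onto the E2M1 grid {0, 0.5, 1, 1.5, 2, 3, 4, 6}, with ties going to the smaller magnitude. The CUDA device functions may differ in any of these details. Treat fp4_fake_quant_reference and round_to_e4m3 as hypothetical illustrations, not a bit-exact match.

```python
import numpy as np

E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])  # non-negative FP4 values
FP4_MAX, FP8_MAX = 6.0, 448.0


def round_to_e4m3(v):
    """Round non-negative values to the nearest FP8 E4M3 value, saturating at 448 (sketch)."""
    v = np.minimum(v, FP8_MAX)
    # Clamp the exponent at -6 (smallest normal) so subnormals use a step of 2**-9.
    e = np.floor(np.log2(np.maximum(v, 2.0 ** -6)))
    step = 2.0 ** (e - 3)  # 3 mantissa bits
    return np.minimum(np.round(v / step) * step, FP8_MAX)


def fp4_fake_quant_reference(x, global_amax, block_size=16):
    """Hypothetical blockwise FP4 quantize-dequantize; same shape and dtype as x."""
    orig = np.asarray(x)
    if orig.size % block_size != 0:
        raise ValueError("numel must be divisible by block_size")
    blocks = orig.astype(np.float32).reshape(-1, block_size)

    # Global (second-level) scale: maps the largest possible block scale to FP8_MAX.
    global_scale = float(global_amax) / (FP4_MAX * FP8_MAX)
    if global_scale == 0.0:
        return np.zeros_like(orig)

    # Per-block scale, stored as an E4M3 value relative to the global scale.
    block_amax = np.abs(blocks).max(axis=1, keepdims=True)
    block_scale = round_to_e4m3(block_amax / FP4_MAX / global_scale) * global_scale

    # Snap each scaled magnitude to the nearest E2M1 value, then dequantize.
    safe = np.where(block_scale > 0, block_scale, 1.0)
    scaled = np.clip(np.abs(blocks) / safe, 0.0, FP4_MAX)
    idx = np.abs(scaled[..., None] - E2M1_GRID).argmin(axis=-1)
    q = np.sign(blocks) * E2M1_GRID[idx] * block_scale
    return q.reshape(orig.shape).astype(orig.dtype)
```

Because each output is an E2M1 value times its block's scale, an input that already lies on that grid (for example, the E2M1 values themselves with global_amax = 6, which gives a block scale of 1) passes through unchanged, while a value such as 2.4 in the same block is snapped to 2.0.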