fp4_kernel
NVFP4 Fake Quantization Triton Implementation.
This module provides high-performance GPU implementations of NVFP4 fake quantization operations using Triton kernels.
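The pure-Python sketch below shows roughly what "NVFP4 fake quantization" computes: it quantizes a 1-D list of floats and then immediately dequantizes it. It assumes the common NVFP4 recipe. Values are stored as FP4 E2M1, each block of block_size elements gets one FP8 E4M3 scale, and a per-tensor scale is derived from global_amax as global_amax / (6 × 448). The helper names (round_to_e2m1, round_to_e4m3, nvfp4_fake_quant_ref) are hypothetical, and so is the ties-to-even rounding rule. This is not the Triton kernel's code. It only mirrors the arithmetic on the CPU.

```python
import math

# Magnitudes representable by FP4 E2M1; the list index equals the 3-bit code.
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
E2M1_MAX = 6.0    # largest FP4 E2M1 magnitude
E4M3_MAX = 448.0  # largest FP8 E4M3 (fn variant) magnitude


def round_to_e4m3(v: float) -> float:
    """Round a non-negative float to the nearest FP8 E4M3 value (ties to even), saturating at 448."""
    if v <= 0.0:
        return 0.0
    if v >= E4M3_MAX:
        return E4M3_MAX
    _, exp = math.frexp(v)       # v = m * 2**exp with 0.5 <= m < 1
    exp = max(exp - 1, -6)       # unbiased exponent; below 2**-6 use the subnormal spacing
    step = 2.0 ** (exp - 3)      # spacing of a 3-bit mantissa at this exponent
    return min(round(v / step) * step, E4M3_MAX)  # round() breaks ties to even


def round_to_e2m1(a: float) -> float:
    """Round to the nearest FP4 E2M1 value (ties to the even code), saturating at +/-6."""
    mag = min(abs(a), E2M1_MAX)
    # Sort key: distance first, then prefer the even code on an exact tie.
    idx = min(range(len(E2M1_GRID)), key=lambda i: (abs(mag - E2M1_GRID[i]), i % 2))
    return math.copysign(E2M1_GRID[idx], a)


def nvfp4_fake_quant_ref(row: list[float], global_amax: float, block_size: int = 16) -> list[float]:
    """Quantize-dequantize a 1-D list block by block (hypothetical reference sketch)."""
    if global_amax <= 0.0:
        return [0.0] * len(row)
    # Per-tensor FP32 scale chosen so the largest block scale fits in E4M3.
    global_scale = global_amax / (E2M1_MAX * E4M3_MAX)
    out: list[float] = []
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        block_amax = max(abs(v) for v in block)
        # Per-block scale, stored in FP8 E4M3 relative to the global scale.
        block_scale = round_to_e4m3(block_amax / (E2M1_MAX * global_scale))
        scale = block_scale * global_scale
        if scale == 0.0:
            out.extend(0.0 for _ in block)
            continue
        # Snap each scaled element to the E2M1 grid, then scale back (the "fake" step).
        out.extend(round_to_e2m1(v / scale) * scale for v in block)
    return out


if __name__ == "__main__":
    print(nvfp4_fake_quant_ref([0.1, -0.7, 2.9, 6.0], global_amax=6.0, block_size=4))
    # ≈ [0.0, -0.5, 3.0, 6.0]
```

The output has the same length and float type as the input. Each element, however, is now an E2M1 value multiplied by its block's scale. That is the sense in which the quantization is "fake": the data stays in high precision but carries FP4 rounding error.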
Functions
fp4_fake_quant_block | FP4 fake quantization implementation using block-pointer tiling.
- fp4_fake_quant_block(x, global_amax, block_size=16, tile_rows=16, tile_cols=64, num_warps=None, num_stages=None)
FP4 fake quantization implementation using block-pointer tiling.
- Parameters:
x (torch.Tensor) – Input tensor of shape (M, N) or higher rank.
global_amax (torch.Tensor) – Global maximum absolute value (amax) tensor used for scaling.
block_size (int, optional) – Number of elements per FP4 block. Defaults to 16.
tile_rows (int, optional) – Row tile size. Defaults to 16.
tile_cols (int, optional) – Column tile size. Defaults to 64. Rounded up internally to the nearest multiple of block_size.
num_warps (int | None, optional) – Override for the number of Triton warps. Autotuned when None.
num_stages (int | None, optional) – Override for the number of pipeline stages. Autotuned when None.
- Returns:
Fake-quantized tensor matching the input shape and dtype.
- Return type:
torch.Tensor
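The tile_cols note implies a small amount of integer arithmetic before launch. The sketch below shows one way to express it. It assumes a 2-D (M, N) input and a launch grid with one program per tile_rows × tile_cols tile. The helpers cdiv, effective_tile_cols and launch_grid are hypothetical and are not part of this module's API. Rounding tile_cols up to a multiple of block_size presumably ensures that every FP4 block lies entirely inside one column tile, so each block's amax can be computed without crossing a tile boundary.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for non-negative a and positive b."""
    return -(-a // b)


def effective_tile_cols(tile_cols: int, block_size: int) -> int:
    """Round tile_cols up to a multiple of block_size (hypothetical helper)."""
    return cdiv(tile_cols, block_size) * block_size


def launch_grid(m: int, n: int, tile_rows: int = 16, tile_cols: int = 64,
                block_size: int = 16) -> tuple[int, int]:
    """Number of (row, column) tiles needed to cover an (m, n) input (hypothetical helper)."""
    cols = effective_tile_cols(tile_cols, block_size)
    return cdiv(m, tile_rows), cdiv(n, cols)


if __name__ == "__main__":
    print(effective_tile_cols(50, 16))  # 64: the next multiple of 16
    print(launch_grid(100, 1000))       # (7, 16) tiles with the default tile sizes
```

Under these assumptions, a tile_cols value that is not a multiple of block_size, such as 50, behaves exactly like the next multiple (64). Passing it therefore only changes the effective tile shape, not the per-block scaling.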