fp4_kernel_hopper
NVFP4 fake-quantization Triton kernels that require compute capability 8.9 or higher (Ada Lovelace, Hopper, and newer).
These kernels use tl.float8e4nv, which requires native FP8 hardware support.
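Because the kernels cannot run below compute capability 8.9, callers may want to gate dispatch on the device first. The helper below is a hypothetical sketch and not part of this module. It only encodes the 8.9 threshold stated above. On a CUDA machine, the capability tuple can be obtained from torch.cuda.get_device_capability().

```python
def supports_native_fp8(capability: tuple[int, int]) -> bool:
    """Hypothetical helper: True when a CUDA compute capability (major, minor) is 8.9 or higher."""
    # Tuple comparison is lexicographic, so (9, 0) > (8, 9) > (8, 6).
    return tuple(capability) >= (8, 9)
```

For example, supports_native_fp8(torch.cuda.get_device_capability()) would select these kernels on an 8.9 or 9.0 device. An 8.0 device would need a fallback path.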
Functions
- fp4_fake_quant_block(x, global_amax, block_size=16, tile_rows=16, tile_cols=64, num_warps=None, num_stages=None)
Apply FP4 fake quantization (quantize, then dequantize) using block-pointer tiling.
- Parameters:
x (torch.Tensor) – Input tensor of shape (M, N) or higher rank.
global_amax (torch.Tensor) – Global absolute-maximum (amax) tensor used for scaling.
block_size (int, optional) – Number of elements per FP4 block. Defaults to 16.
tile_rows (int, optional) – Row tile size. Defaults to 16.
tile_cols (int, optional) – Column tile size. Defaults to 64. Rounded up internally to the nearest multiple of block_size.
num_warps (int | None, optional) – Override for the number of Triton warps. Autotuned when None.
num_stages (int | None, optional) – Override for the number of pipeline stages. Autotuned when None.
- Returns:
Fake-quantized tensor matching the input shape and dtype.
- Return type:
torch.Tensor
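Fake-quantized outputs can be checked on the CPU against a plain-Python reference. The sketch below assumes the common NVFP4 recipe: FP4 E2M1 elements, one FP8 E4M3 scale per block_size elements, and a per-tensor scale derived from global_amax as global_amax / (6 * 448). That formula, the round-half-to-even tie-breaking, and the name fp4_fake_quant_reference are assumptions and are not taken from this module. The Triton kernel may differ in these details.

```python
import math

# FP4 E2M1 magnitudes; an even list index corresponds to an even mantissa bit.
FP4_E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_MAX = 6.0          # largest E2M1 magnitude
FP8_E4M3_MAX = 448.0   # largest finite E4M3 (fn) magnitude


def round_to_fp4(v: float) -> float:
    """Round to the nearest E2M1 value, ties to even mantissa, saturating at +/-6."""
    a = abs(v)
    idx = min(range(len(FP4_E2M1_GRID)),
              key=lambda i: (abs(a - FP4_E2M1_GRID[i]), i % 2))
    return math.copysign(FP4_E2M1_GRID[idx], v)


def round_to_fp8_e4m3(v: float) -> float:
    """Round a non-negative value to the nearest E4M3 value, saturating at 448."""
    if v <= 0.0:
        return 0.0
    v = min(v, FP8_E4M3_MAX)
    _, e = math.frexp(v)              # v = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)              # -6 is the smallest normal exponent (subnormals below)
    step = math.ldexp(1.0, exp - 3)   # 3 mantissa bits -> 8 values per binade
    return min(round(v / step) * step, FP8_E4M3_MAX)


def fp4_fake_quant_reference(x, global_amax, block_size=16):
    """Hypothetical CPU reference: quantize each block to FP4, then dequantize."""
    if global_amax <= 0.0:
        return [0.0] * len(x)
    # ASSUMPTION: the per-tensor scale maps global_amax onto FP4_MAX * FP8_E4M3_MAX.
    global_scale = global_amax / (FP4_MAX * FP8_E4M3_MAX)
    out = []
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        block_amax = max(abs(v) for v in block)
        # Per-block scale, rounded to FP8 E4M3 relative to the per-tensor scale.
        scale = round_to_fp8_e4m3(block_amax / FP4_MAX / global_scale) * global_scale
        if scale == 0.0:
            # All-zero block, or a scale that underflows FP8: everything maps to 0.
            out.extend([0.0] * len(block))
            continue
        out.extend(round_to_fp4(v / scale) * scale for v in block)
    return out
```

With these assumptions, fp4_fake_quant_reference([6.0, 3.0, -1.0, 0.4, 0.75, 0.1, 0.0, -0.3], 6.0, block_size=4) should return [6.0, 3.0, -1.0, 0.5, 0.75, 0.125, 0.0, -0.25] up to floating-point rounding. The second block gets a finer scale (0.125) because its own amax is smaller. Rounding each per-block scale through FP8 is presumably the step that relies on tl.float8e4nv on the GPU.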
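The tile_cols parameter is rounded up to a multiple of block_size. Presumably this ensures that each column tile holds whole FP4 blocks and no block straddles two tiles. The arithmetic is a ceiling division followed by a multiply. The helper name below is hypothetical, and the kernel may compute the same value differently, for example with triton.cdiv.

```python
def round_up_to_multiple(value: int, multiple: int) -> int:
    """Hypothetical helper: smallest multiple of `multiple` that is >= value (positive inputs)."""
    if multiple <= 0:
        raise ValueError("multiple must be positive")
    # Ceiling division via negated floor division, then scale back up.
    return -(-value // multiple) * multiple
```

With the default block_size of 16, a tile_cols of 40 would become 48, while the default 64 is already a multiple and stays 64.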