fp4_kernel
NVFP4 Fake Quantization Triton Implementation.
This module provides high-performance GPU implementations of NVFP4 fake quantization operations using Triton kernels.
Functions
fp4_fake_quant_block | FP4 fake quantization implementation using block-pointer tiling.
static_blockwise_fp4_fake_quant | Static blockwise FP4 fake quantization using a Triton kernel.
- fp4_fake_quant_block(x, global_amax, block_size=16, tile_rows=16, tile_cols=64, num_warps=None, num_stages=None)
FP4 fake quantization implementation using block-pointer tiling.
- Parameters:
x (torch.Tensor) – Input tensor of shape (M, N) or of higher rank.
global_amax (torch.Tensor) – Global maximum value (amax) tensor used for scaling.
block_size (int) – Number of elements per FP4 block.
tile_rows (int, optional) – Row tile size. Defaults to 16.
tile_cols (int, optional) – Column tile size. Defaults to 64. Rounded up internally to the nearest multiple of block_size.
num_warps (int | None, optional) – Override for the number of Triton warps. Autotuned when None.
num_stages (int | None, optional) – Override for the number of pipeline stages. Autotuned when None.
- Returns:
Fake-quantized tensor matching the input shape and dtype.
- Return type:
torch.Tensor
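The two-level scaling behind fp4_fake_quant_block can be sketched in plain Python for a single block. This is a minimal reference under stated assumptions, not the Triton kernel: it assumes the common NVFP4 recipe in which global_amax sets a global FP32 scale of global_amax / (6 * 448), each block's scale is its absolute maximum divided by 6, expressed relative to that global scale and rounded to FP8 E4M3, and each element is then rounded to the nearest FP4 E2M1 value. Tie-breaking, NaN handling, and saturation in the real kernel may differ. All helper names (round_to_e4m3, round_to_e2m1, fake_quant_block_ref, effective_tile_cols) are hypothetical; effective_tile_cols only illustrates the round-up of tile_cols to a multiple of block_size.

```python
import math

E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # FP4 E2M1 magnitudes
FP4_MAX = 6.0      # largest E2M1 magnitude
E4M3_MAX = 448.0   # largest finite FP8 E4M3 magnitude


def round_to_e4m3(v: float) -> float:
    """Round a non-negative float to the nearest FP8 E4M3 value, saturating at 448."""
    if v <= 0.0:
        return 0.0
    v = min(v, E4M3_MAX)
    # floor(log2(v)), clamped at -6 so subnormals share the fixed step 2**-9.
    exp = max(math.frexp(v)[1] - 1, -6)
    step = 2.0 ** (exp - 3)  # 3 mantissa bits
    # Python's round() uses round-half-to-even.
    return min(round(v / step) * step, E4M3_MAX)


def round_to_e2m1(v: float) -> float:
    """Round to the nearest E2M1 value, saturating at +/-6; ties go to the smaller magnitude."""
    mag = min(abs(v), FP4_MAX)
    nearest = min(E2M1_GRID, key=lambda g: abs(g - mag))  # min() keeps the first tie
    return math.copysign(nearest, v)


def effective_tile_cols(tile_cols: int, block_size: int) -> int:
    """Hypothetical: round tile_cols up to the nearest multiple of block_size."""
    return -(-tile_cols // block_size) * block_size


def fake_quant_block_ref(block: list[float], global_amax: float) -> list[float]:
    """Hypothetical reference: quantize one block to FP4 with a two-level scale, then dequantize."""
    if global_amax <= 0.0:
        return [0.0 for _ in block]
    global_scale = global_amax / (FP4_MAX * E4M3_MAX)
    block_amax = max(abs(x) for x in block)
    # Per-block scale, stored as E4M3 relative to the global scale.
    scale = round_to_e4m3(block_amax / (FP4_MAX * global_scale)) * global_scale
    if scale == 0.0:
        return [0.0 for _ in block]
    return [round_to_e2m1(x / scale) * scale for x in block]


print(fake_quant_block_ref([0.0, 0.3, -1.2, 6.0], global_amax=6.0))
print(effective_tile_cols(50, 16))
```

For example, with global_amax=6.0 the block [0.0, 0.3, -1.2, 6.0] gets a per-block scale of about 1.0, so 0.3 rounds to 0.5 and -1.2 rounds to -1.0; a tile_cols of 50 with block_size=16 would become 64.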
- static_blockwise_fp4_fake_quant(x, scale, scale_fp8_quant_amax=None, skip_scale_quant=False, out_dtype=None)
Static blockwise FP4 fake quantization using a Triton kernel.
- Parameters:
x (Tensor) – Input tensor of shape [NUM_FP4_BLOCKS, BLOCK_SIZE] on a CUDA device.
scale (Tensor) – Per-block scale of shape [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on a CUDA device.
scale_fp8_quant_amax (Tensor | None) – Absolute maximum that sets the range for FP8 quantization of scale. If None, it is computed from scale.
skip_scale_quant (bool) – If True, skip FP8 quantization of scale.
out_dtype (dtype | None) – Output dtype. Defaults to x.dtype if None.
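What distinguishes static_blockwise_fp4_fake_quant is that the per-block scale is supplied rather than computed from the data, and may itself be FP8-quantized first. The sketch below shows only that scale-preparation step, under assumptions this page does not confirm: FP8 here means E4M3, and the quantization range is mapped so that scale_fp8_quant_amax (or max(scale) when it is None) lands on the largest E4M3 value, 448. prepare_static_scales and round_to_e4m3 are hypothetical helpers, not part of the module's API.

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 magnitude


def round_to_e4m3(v: float) -> float:
    """Round a non-negative float to the nearest FP8 E4M3 value, saturating at 448."""
    if v <= 0.0:
        return 0.0
    v = min(v, E4M3_MAX)
    exp = max(math.frexp(v)[1] - 1, -6)  # floor(log2(v)), clamped for subnormals
    step = 2.0 ** (exp - 3)              # 3 mantissa bits
    return min(round(v / step) * step, E4M3_MAX)  # round-half-to-even


def prepare_static_scales(scale, scale_fp8_quant_amax=None, skip_scale_quant=False):
    """Hypothetical helper: per-block scales after optional FP8 E4M3 quantization.

    Assumes the FP8 range is mapped so that the chosen amax lands on E4M3_MAX.
    """
    scale = [float(s) for s in scale]
    if skip_scale_quant:
        return scale  # use the supplied scales unchanged
    # Assumption: "computed from scale" is taken to mean max(scale).
    amax = max(scale) if scale_fp8_quant_amax is None else float(scale_fp8_quant_amax)
    if amax <= 0.0:
        return [0.0 for _ in scale]
    ratio = amax / E4M3_MAX  # value of one E4M3 unit in the original scale domain
    return [round_to_e4m3(s / ratio) * ratio for s in scale]


print(prepare_static_scales([1.0, 0.5, 0.3]))
```

In this sketch, 1.0 and 0.5 survive the FP8 round trip, while 0.3 is snapped to 128/448 (about 0.286). The prepared scales would then divide each row of x before rounding to E2M1 and multiply back afterwards; passing skip_scale_quant=True keeps the supplied scales unrounded.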