fp4_kernel
NVFP4 Fake Quantization Triton Implementation.
This module provides high-performance GPU implementations of NVFP4 fake quantization operations using Triton kernels.
Functions

- fp4_dequantize – Dequantizes an FP4 packed tensor using per-block scaling factors.
- static_blockwise_fp4_fake_quant – Static blockwise FP4 fake quantization using a Triton kernel.
- fp4_dequantize(packed_tensor, scale_tensor, global_scale, block_size=16, tile_size=128, dtype=torch.float32)
Dequantizes an FP4 packed tensor using per-block scaling factors.
- Parameters:
packed_tensor (torch.Tensor) – Packed uint8 tensor of shape (M, N//2)
scale_tensor (torch.Tensor) – Per-block scale tensor of shape (M, N//block_size)
global_scale (torch.Tensor) – Global scaling factor tensor
block_size (int) – Size of FP4 quantization blocks
tile_size (int) – Size of processing tiles
dtype (dtype) – Output dtype of the dequantized tensor.
- Returns:
Dequantized tensor of shape (M, N)
- Return type:
torch.Tensor
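To make the packed layout concrete, here is a pure-PyTorch reference sketch of the dequantization the Triton kernel performs. It is an illustrative reimplementation, not the kernel itself; the nibble order (low nibble holds the even-indexed element) and the E2M1 value table are assumptions about the NVFP4 packing, and `fp4_dequantize_ref` is a hypothetical helper name.

```python
import torch

# The eight non-negative E2M1 (FP4) magnitudes; bit 3 of each code is the sign.
FP4_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fp4_dequantize_ref(packed, scales, global_scale, block_size=16):
    """Reference dequantize: packed (M, N//2) uint8 -> (M, N) float32.

    Assumes the low nibble of each byte stores the even-indexed element.
    """
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    codes = torch.stack([lo, hi], dim=-1).reshape(packed.shape[0], -1)  # (M, N)
    sign = torch.where(codes >= 8, -1.0, 1.0)
    vals = FP4_E2M1_VALUES[(codes & 0x7).long()] * sign
    m, n = vals.shape
    # Apply one scale per block_size-wide group of elements, then the global scale.
    vals = vals.reshape(m, n // block_size, block_size)
    vals = vals * scales.reshape(m, n // block_size, 1).float()
    return (vals.reshape(m, n) * global_scale.float())
```

For example, the byte `0x21` holds codes 1 (low nibble, value 0.5) and 2 (high nibble, value 1.0); with a block scale of 2.0 and a global scale of 1.0 it dequantizes to `[1.0, 2.0]`.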
- static_blockwise_fp4_fake_quant(x, amax, global_amax=None, quantize_block_scales=True, out_dtype=None)
Static blockwise FP4 fake quantization using a Triton kernel.
- Parameters:
x (Tensor) – [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA.
amax (Tensor) – [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] per-block amax values.
global_amax (Tensor | None) – FP32 scalar global amax. If provided, used to compute scale_fp8_quant_amax.
quantize_block_scales (bool) – If True, quantize block scales to FP8.
out_dtype (dtype | None) – Output dtype. Defaults to x.dtype if None.
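The fake-quantization semantics can be sketched in pure PyTorch: each block is scaled so its amax maps to the largest E2M1 magnitude (6.0), values are rounded to the nearest representable FP4 value, then rescaled back. This is a simplified reference under stated assumptions: it covers only the `quantize_block_scales=False` path (no FP8 quantization of the block scales), uses round-to-nearest rather than the kernel's exact rounding mode, and `static_blockwise_fp4_fake_quant_ref` is a hypothetical helper name.

```python
import torch

FP4_MAX = 6.0  # largest E2M1 magnitude
FP4_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def static_blockwise_fp4_fake_quant_ref(x, amax, out_dtype=None):
    """x: [NUM_FP4_BLOCKS, BLOCK_SIZE]; amax: [NUM_FP4_BLOCKS] or [..., 1]."""
    out_dtype = out_dtype if out_dtype is not None else x.dtype
    # Per-block scale so that |x| == amax maps to the FP4 maximum of 6.0.
    scale = (amax.reshape(-1, 1).float() / FP4_MAX).clamp(min=1e-12)
    scaled = x.float() / scale
    sign = torch.sign(scaled)
    # Round each magnitude to the nearest representable FP4 value
    # (values above 6.0 saturate to 6.0 automatically).
    idx = torch.argmin((scaled.abs().unsqueeze(-1) - FP4_VALUES).abs(), dim=-1)
    q = FP4_VALUES[idx] * sign
    # Fake quantization: dequantize back to the original scale.
    return (q * scale).to(out_dtype)
```

With `amax = 6.0` the scale is 1.0, so the sketch simply snaps each element to the nearest FP4 value: 3.0 stays 3.0, -0.6 becomes -0.5, and 7.0 saturates to 6.0.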