fp4_kernel

NVFP4 Fake Quantization Triton Implementation.

This module provides high-performance GPU implementations of NVFP4 fake quantization operations using Triton kernels.

Functions

compute_fp4_scales

Compute per-block FP4 scales from amax values.

fp4_dequantize

Dequantize an FP4-packed tensor using per-block scaling factors.

static_blockwise_fp4_fake_quant

Static blockwise FP4 fake quantization using Triton kernel.

compute_fp4_scales(amax, global_amax=None, quantize_block_scales=True)

Compute per-block FP4 scales from amax values.

scale = amax / 6.0, optionally quantized to FP8 E4M3.

Parameters:
  • amax (Tensor) – Per-block amax values (any shape).

  • global_amax (Tensor | None) – Global amax for FP8 two-level scaling. Computed from amax if None.

  • quantize_block_scales (bool) – If True, quantize scales to FP8 E4M3.

Returns:

Per-block scales (same shape as amax), float32.

Return type:

Tensor

fp4_dequantize(packed_tensor, scale_tensor, global_scale, block_size=16, tile_size=128, dtype=torch.float32)

Dequantize an FP4-packed tensor using per-block scaling factors.

Parameters:
  • packed_tensor (torch.Tensor) – Packed uint8 tensor of shape (M, N//2)

  • scale_tensor (torch.Tensor) – Per-block scale tensor of shape (M, N//block_size)

  • global_scale (torch.Tensor) – Global scaling factor tensor

  • block_size (int) – Size of FP4 quantization blocks

  • tile_size (int) – Size of processing tiles

  • dtype (dtype) – Output dtype of the dequantized tensor

Returns:

Dequantized tensor of shape (M, N)

Return type:

torch.Tensor
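
The unpack-and-rescale step can be sketched in plain PyTorch. The E2M1 value set {0, 0.5, 1, 1.5, 2, 3, 4, 6} with a sign bit in the high nibble bit, and the low-nibble-first packing order, are assumptions about the layout, not confirmed by this module; the real kernel processes `tile_size`-sized tiles in Triton rather than materializing everything at once:

```python
import torch

# Assumed E2M1 magnitude lookup for the low 3 bits; bit 3 is the sign.
E2M1_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fp4_dequantize_ref(packed, scales, global_scale,
                       block_size=16, dtype=torch.float32):
    """Pure-PyTorch sketch of FP4 dequantization (not the Triton kernel)."""
    M, half_n = packed.shape
    # Unpack two FP4 codes per uint8; low nibble first is an assumption.
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    codes = torch.stack([lo, hi], dim=-1).reshape(M, half_n * 2).long()
    vals = E2M1_LUT[codes & 0x7]
    vals = torch.where((codes & 0x8).bool(), -vals, vals)
    # Apply per-block scales, then the global scale.
    vals = vals.reshape(M, -1, block_size) * scales.float().unsqueeze(-1)
    return (vals.reshape(M, -1) * global_scale).to(dtype)
```

Each output row has twice as many elements as the packed input row, and every run of `block_size` consecutive elements shares one entry of `scale_tensor`.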

static_blockwise_fp4_fake_quant(x, amax, global_amax=None, quantize_block_scales=True, out_dtype=None)

Static blockwise FP4 fake quantization using Triton kernel.

Parameters:
  • x (Tensor) – Input tensor on CUDA. The last dim must be the block dim (each consecutive BLOCK_SIZE elements form one FP4 block). Any number of leading dims is supported — they’re flattened internally and the shape is restored on output (so MoE expert weights (E, F, K) work the same as plain linear weights (N, K)).

  • amax (Tensor) – Per-block amax values. amax.numel() must equal x.numel() // BLOCK_SIZE. Shape is otherwise free; the kernel consumes it as a flat 1-D buffer of length NUM_FP4_BLOCKS.

  • global_amax (Tensor | None) – FP32 scalar global amax. If provided, used to compute scale_fp8_quant_amax.

  • quantize_block_scales (bool) – If True, quantize block scales to FP8.

  • out_dtype (dtype | None) – Output dtype. Defaults to x.dtype if None.
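
The end-to-end operation (quantize, then immediately dequantize, so the output stays in floating point but carries FP4 error) can be sketched in plain PyTorch for the `quantize_block_scales=False` path. The E2M1 value set and plain round-to-nearest are assumptions; the actual Triton kernel may use a different rounding mode:

```python
import torch

# Assumed representable E2M1 magnitudes (sign handled separately).
FP4_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fp4_fake_quant_ref(x, amax, block_size=16, out_dtype=None):
    """Pure-PyTorch sketch of static blockwise FP4 fake quantization
    (block scales kept in FP32, i.e. quantize_block_scales=False)."""
    shape = x.shape
    # Leading dims are flattened; the last dim is the block dim.
    xb = x.float().reshape(-1, block_size)
    # One scale per FP4 block; amax is consumed as a flat buffer.
    # amax is assumed nonzero here (a zero block would divide by zero).
    scales = amax.float().reshape(-1, 1) / 6.0
    scaled = xb / scales
    # Round each element to the nearest representable FP4 magnitude,
    # preserving sign, then scale back.
    sign = torch.sign(scaled)
    idx = torch.argmin((scaled.abs().unsqueeze(-1) - FP4_VALUES).abs(), dim=-1)
    q = sign * FP4_VALUES[idx]
    out = (q * scales).reshape(shape)
    return out.to(out_dtype or x.dtype)
```

Because only shapes are flattened and restored, a 3-D MoE weight `(E, F, K)` behaves exactly like a 2-D weight `(E * F, K)`, as the `x` parameter notes.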