fp4_kernel

NVFP4 Fake Quantization Triton Implementation.

This module provides high-performance GPU implementations of NVFP4 fake quantization operations using Triton kernels.

Functions

fp4_dequantize

Dequantizes a packed FP4 tensor using per-block scaling factors.

static_blockwise_fp4_fake_quant

Static blockwise FP4 fake quantization using a Triton kernel.

fp4_dequantize(packed_tensor, scale_tensor, global_scale, block_size=16, tile_size=128, dtype=torch.float32)

Dequantizes a packed FP4 tensor using per-block scaling factors.

Parameters:
  • packed_tensor (torch.Tensor) – Packed uint8 tensor of shape (M, N//2)

  • scale_tensor (torch.Tensor) – Per-block scale tensor of shape (M, N//block_size)

  • global_scale (torch.Tensor) – Global scaling factor tensor

  • block_size (int) – Size of FP4 quantization blocks

  • tile_size (int) – Size of processing tiles

  • dtype (torch.dtype) – Output dtype of the dequantized tensor. Defaults to torch.float32.

Returns:

Dequantized tensor of shape (M, N)

Return type:

torch.Tensor
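
A minimal usage sketch (the import path and all inputs below are assumptions for illustration; a real packed_tensor and scale_tensor would come from a matching NVFP4 quantizer):

    import torch

    from fp4_kernel import fp4_dequantize  # hypothetical import path

    M, N, block_size = 128, 256, 16

    # Synthetic stand-ins: two FP4 values are packed per uint8 byte, and each
    # block of block_size values shares one scale. The scale dtype may need to
    # be FP8 (e.g. torch.float8_e4m3fn) depending on the producing quantizer.
    packed_tensor = torch.randint(0, 256, (M, N // 2), dtype=torch.uint8, device="cuda")
    scale_tensor = torch.rand(M, N // block_size, device="cuda")
    global_scale = torch.tensor(1.0, device="cuda")

    out = fp4_dequantize(packed_tensor, scale_tensor, global_scale,
                         block_size=block_size, dtype=torch.float32)
    assert out.shape == (M, N)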

static_blockwise_fp4_fake_quant(x, amax, global_amax=None, quantize_block_scales=True, out_dtype=None)

Static blockwise FP4 fake quantization using a Triton kernel.

Parameters:
  • x (Tensor) – [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA.

  • amax (Tensor) – [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] per-block amax values.

  • global_amax (Tensor | None) – FP32 scalar global amax. If provided, used to compute scale_fp8_quant_amax.

  • quantize_block_scales (bool) – If True, quantize block scales to FP8.

  • out_dtype (dtype | None) – Output dtype. Defaults to x.dtype if None.

Returns:

Fake-quantized tensor with the same shape as x.

Return type:

torch.Tensor
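
A minimal usage sketch (assumed import path; the inputs are synthetic and the per-block layout follows the parameter descriptions above):

    import torch

    from fp4_kernel import static_blockwise_fp4_fake_quant  # hypothetical import path

    NUM_FP4_BLOCKS, BLOCK_SIZE = 1024, 16

    # Reshape the tensor to be quantized into [NUM_FP4_BLOCKS, BLOCK_SIZE].
    x = torch.randn(NUM_FP4_BLOCKS, BLOCK_SIZE, device="cuda", dtype=torch.float16)

    # Per-block amax, plus an FP32 scalar global amax used when quantizing
    # the block scales to FP8 (quantize_block_scales=True by default).
    amax = x.abs().amax(dim=-1)
    global_amax = x.abs().amax().float()

    x_fq = static_blockwise_fp4_fake_quant(x, amax, global_amax=global_amax)
    assert x_fq.shape == x.shape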