fp4_kernel

NVFP4 Fake Quantization Triton Implementation.

This module provides high-performance GPU implementations of NVFP4 fake quantization operations using Triton kernels.
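For orientation: a fake-quantization op rounds values onto the low-precision grid and immediately dequantizes them, so the output keeps the input dtype but carries FP4 rounding error. Below is a minimal pure-PyTorch sketch of blockwise FP4 (E2M1) fake quantization, assuming the standard NVFP4 value grid; it is only illustrative and omits the FP8 quantization of per-block scales and the global-amax second-level scaling that the kernels in this module implement.

```python
import torch

# Representable non-negative FP4 (E2M1) magnitudes; sign is handled separately.
_FP4_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fp4_fake_quant_reference(x: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    """Round each block of x onto the FP4 grid, then dequantize back to x.dtype."""
    orig_shape = x.shape
    blocks = x.reshape(-1, block_size).float()
    # Per-block scale so the largest magnitude in a block maps to the FP4 max (6.0).
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
    scale = amax / 6.0
    scaled = blocks / scale
    # Snap every element to the nearest representable FP4 magnitude.
    grid = _FP4_GRID.to(scaled.device)
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    quant = grid[idx] * scaled.sign()
    # "Fake" quantization: multiply the scale back in and return the original dtype.
    return (quant * scale).reshape(orig_shape).to(x.dtype)
```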

Functions

fp4_fake_quant_block

FP4 fake quantization implementation using block-pointer tiling.

static_blockwise_fp4_fake_quant

Static blockwise FP4 fake quantization using Triton kernel.

fp4_fake_quant_block(x, global_amax, block_size=16, tile_rows=16, tile_cols=64, num_warps=None, num_stages=None)

FP4 fake quantization implementation using block-pointer tiling.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (M, N) or higher rank.

  • global_amax (torch.Tensor) – Global maximum value tensor for scaling.

  • block_size (int, optional) – Number of elements per FP4 block. Defaults to 16.

  • tile_rows (int, optional) – Row tile size. Defaults to 16.

  • tile_cols (int, optional) – Column tile size. Defaults to 64. Rounded up to the nearest multiple of block_size internally.

  • num_warps (int | None, optional) – Override for Triton warps. Autotuned when None.

  • num_stages (int | None, optional) – Override for pipeline stages. Autotuned when None.

Returns:

Fake-quantized tensor matching the input shape and dtype.

Return type:

torch.Tensor
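
A hypothetical usage sketch follows; the import path is assumed from this module's name and may differ in your installation.

```python
import torch
from fp4_kernel import fp4_fake_quant_block  # assumed import path

x = torch.randn(512, 1024, dtype=torch.bfloat16, device="cuda")
# Global absolute maximum used to derive the second-level (global) scale.
global_amax = x.abs().amax().float()

y = fp4_fake_quant_block(x, global_amax, block_size=16)
assert y.shape == x.shape and y.dtype == x.dtype
```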

static_blockwise_fp4_fake_quant(x, scale, scale_fp8_quant_amax=None, skip_scale_quant=False, out_dtype=None)

Static blockwise FP4 fake quantization using Triton kernel.

Parameters:
  • x (Tensor) – Input tensor of shape [NUM_FP4_BLOCKS, BLOCK_SIZE] on a CUDA device.

  • scale (Tensor) – Per-block scale tensor of shape [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on a CUDA device.

  • scale_fp8_quant_amax (Tensor | None) – Absolute-maximum range used for FP8 quantization of scale. Computed from scale if None.

  • skip_scale_quant (bool) – If True, skip FP8 quantization of scale.

  • out_dtype (dtype | None) – Output dtype. Defaults to x.dtype if None.
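
A hypothetical usage sketch follows; the import path and the way scale is derived here are assumptions made only so the example is self-contained.

```python
import torch
from fp4_kernel import static_blockwise_fp4_fake_quant  # assumed import path

block_size = 16
x = torch.randn(4096, block_size, dtype=torch.float16, device="cuda")

# Static per-block scales (e.g. from offline calibration); derived from x here
# so that the largest magnitude in each block maps to the FP4 maximum of 6.0.
scale = x.abs().amax(dim=-1, keepdim=True).float() / 6.0

y = static_blockwise_fp4_fake_quant(x, scale, out_dtype=torch.float16)
assert y.shape == x.shape and y.dtype == torch.float16
```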