fp4_kernel

NVFP4 Fake Quantization Triton Implementation.

This module provides GPU implementations of NVFP4 fake quantization (quantize, then immediately dequantize) written as Triton kernels.
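NVFP4 stores each element as a 4-bit E2M1 value: 1 sign bit, 2 exponent bits (bias 1) and 1 mantissa bit, with a shared scale per block. Fake quantization rounds every scaled element to this small set of values but returns the result in the original floating-point dtype. The pure-Python sketch below enumerates the representable values. `decode_e2m1` is a hypothetical helper written for illustration and is not part of this module.

```python
def decode_e2m1(code: int) -> float:
    """Decode a 4-bit FP4 E2M1 code (1 sign, 2 exponent, 1 mantissa bit; bias 1)."""
    sign = -1.0 if code & 0b1000 else 1.0
    exp = (code >> 1) & 0b11
    man = code & 0b1
    if exp == 0:
        # Subnormal range: only 0 and 0.5 are representable.
        return sign * man * 0.5
    # Normal range: (1 + m/2) * 2**(e - bias), with bias = 1.
    return sign * (1.0 + man * 0.5) * 2.0 ** (exp - 1)


# The 16 codes give 15 distinct values (+0 and -0 compare equal).
values = sorted({decode_e2m1(c) for c in range(16)})
print(values)
```

Because the largest magnitude is 6, each block must be divided by a scale that brings its values into this range before rounding.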

Functions

fp4_fake_quant_block


fp4_fake_quant_block(x, global_amax, block_size=16, tile_rows=16, tile_cols=64, num_warps=None, num_stages=None)

FP4 fake quantization implementation using block-pointer tiling.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (M, N) or of higher rank.

  • global_amax (torch.Tensor) – Global absolute maximum (amax) used to compute the quantization scaling factors.

  • block_size (int, optional) – Number of elements per FP4 block; all elements of a block share one scale factor. Defaults to 16.

  • tile_rows (int, optional) – Row tile size. Defaults to 16.

  • tile_cols (int, optional) – Column tile size. Defaults to 64. Rounded up to the nearest multiple of block_size internally.

  • num_warps (int | None, optional) – Override for Triton warps. Autotuned when None.

  • num_stages (int | None, optional) – Override for pipeline stages. Autotuned when None.

Returns:

Fake-quantized tensor matching the input shape and dtype.

Return type:

torch.Tensor
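For checking results on small inputs, the pure-Python reference below follows the two-level NVFP4 scaling recipe as it is commonly described: a per-tensor scale global_amax / (6 * 448), per-block scales rounded to FP8 E4M3, and elements rounded to E2M1 (6 and 448 are the E2M1 and E4M3 maxima). It is an assumption that fp4_fake_quant_block uses exactly this formula. Its handling of all-zero blocks, saturation and rounding ties may differ. The sketch takes a flat list and a float global_amax instead of torch tensors, and `nvfp4_fake_quant_ref` is a hypothetical name.

```python
import math

E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # index == 3-bit magnitude code
E2M1_MAX = 6.0
E4M3_MAX = 448.0


def round_to_e2m1(v: float) -> float:
    """Round to the nearest E2M1 value, saturating at +/-6; ties go to the even code."""
    mag = min(abs(v), E2M1_MAX)
    idx = min(range(len(E2M1_GRID)), key=lambda i: (abs(E2M1_GRID[i] - mag), i % 2))
    return math.copysign(E2M1_GRID[idx], v)


def round_to_e4m3(v: float) -> float:
    """Round a non-negative float to FP8 E4M3 (bias 7, max 448), round-half-to-even."""
    if v <= 0.0:
        return 0.0
    v = min(v, E4M3_MAX)
    _, e = math.frexp(v)           # v = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)           # unbiased exponent; -6 is the smallest normal one
    step = 2.0 ** (exp - 3)        # spacing given 3 mantissa bits (subnormal step 2**-9)
    return min(round(v / step) * step, E4M3_MAX)


def nvfp4_fake_quant_ref(x, global_amax, block_size=16):
    """Hypothetical reference: NVFP4 fake quantization of a flat list of floats."""
    global_scale = global_amax / (E2M1_MAX * E4M3_MAX)
    out = []
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        block_amax = max(abs(v) for v in block)
        # Per-block scale, expressed relative to global_scale and rounded to E4M3.
        block_scale = round_to_e4m3(block_amax / E2M1_MAX / global_scale) if global_scale > 0 else 0.0
        scale = block_scale * global_scale
        if scale == 0.0:
            out.extend(0.0 for _ in block)
            continue
        # Quantize to E2M1, then dequantize back to float (same length as input).
        out.extend(round_to_e2m1(v / scale) * scale for v in block)
    return out


x = [0.0, 1.0, 2.0, 3.0, 6.0, -5.0, 0.7, 4.0]
print(nvfp4_fake_quant_ref(x, global_amax=10.5, block_size=4))
```

In this example -5.0 lies halfway between 4 and 6 and rounds to -4.0, while 0.7 rounds to 0.5. The output has the same length as the input, mirroring how the kernel preserves shape and dtype.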