nvfp4_fp8_sweep

Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep.

Replaces the 126-iteration Python sweep in NVFP4MSECalibrator with a single kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates and emits the per-block best_amax directly.

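The 126 candidates are constructed as valid_fp8_e4m3_value / 448 (see fp8_scale_candidates()). For these specific candidates, the FP8 round-trip on the per-block scale is the identity: measured in units of the global scale (global_amax / (6 * 448)), the per-block scale is candidate * 448, which is by construction an exactly representable E4M3 value. The kernel can therefore use scale = candidate * global_amax / 6.0 without an explicit FP8 cast, which makes it runnable on any CUDA GPU with Triton (no tl.float8e4nv requirement).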
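The two claims above (exactly 126 candidates, and a lossless FP8 round-trip for each) can be checked without a GPU. The sketch below decodes every positive finite E4M3 bit pattern in plain Python. It assumes the e4m3fn encoding (exponent bias 7, no infinities, 0x7F reserved for NaN, maximum 448). The helper names are hypothetical and not part of this module, whose fp8_scale_candidates() returns a Tensor rather than a list.

```python
# Hypothetical, GPU-free sketch; not this module's implementation.
FP8_E4M3_MAX = 448.0


def e4m3fn_positive_finite_values():
    """Decode every positive, finite FP8 E4M3 (e4m3fn) value in ascending order."""
    values = []
    # 0x00 encodes +0 and 0x7F encodes NaN in e4m3fn, so the bit patterns
    # 0x01..0x7E are exactly the 126 positive finite values.
    for bits in range(0x01, 0x7F):
        exp = bits >> 3    # 4 exponent bits (bias 7)
        man = bits & 0x7   # 3 mantissa bits
        if exp == 0:
            value = (man / 8.0) * 2.0 ** -6               # subnormal
        else:
            value = (1.0 + man / 8.0) * 2.0 ** (exp - 7)  # normal
        values.append(value)
    return values


def scale_candidates_ref():
    """Candidates as described above: valid E4M3 value / 448, so the largest is 1.0."""
    return [v / FP8_E4M3_MAX for v in e4m3fn_positive_finite_values()]


def round_to_e4m3(x, grid):
    """Round a positive float to the nearest value on the E4M3 grid (ties not handled)."""
    return min(grid, key=lambda g: abs(g - x))


grid = e4m3fn_positive_finite_values()
candidates = scale_candidates_ref()

# In units of the global scale the per-block scale is candidate * 448; rounding
# it back onto the E4M3 grid returns the value it was built from.
assert all(round_to_e4m3(c * FP8_E4M3_MAX, grid) == v for c, v in zip(candidates, grid))
```

The smallest candidate is 2^-9 / 448 (the smallest E4M3 subnormal divided by 448) and the largest is exactly 1.0, i.e. a per-block amax equal to global_amax.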

Tile shape (BLOCKS_PER_PROGRAM, the number of NVFP4 blocks handled by each program) and num_warps are autotuned, keyed on N_BLOCKS.

Functions

fp8_scale_candidates

Return the 126 valid finite positive FP8 E4M3 scale candidates / 448.

nvfp4_fp8_scale_sweep

Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.

fp8_scale_candidates(device='cpu')

Return the 126 valid finite positive FP8 E4M3 scale candidates / 448.

Parameters:

device (device | str) – Device on which to create the returned candidate tensor.

Return type:

Tensor

nvfp4_fp8_scale_sweep(x, global_amax, block_size=16, candidates=None)

Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.

Equivalent to the 126-step sweep in NVFP4MSECalibrator, but fused into a single Triton kernel: every block’s weight elements are loaded once, all 126 candidates are evaluated in registers, and the running argmin is updated inside the same loop, so no per-candidate error tensor is materialized.
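To check the semantics on CPU, the sketch below is an unfused, plain-Python reference of what the kernel is described as computing per block. It assumes the NVFP4 E2M1 magnitude grid {0, 0.5, 1, 1.5, 2, 3, 4, 6}, round-to-nearest with saturation at 6, and best_amax = winning candidate * global_amax (consistent with scale = candidate * global_amax / 6.0). It minimizes the per-block sum of squared errors, which has the same argmin as the MSE because every block holds block_size elements. All function names are hypothetical, and the tie-breaking is this sketch's own choice.

```python
import math

# Hypothetical CPU reference; not the Triton kernel or NVFP4MSECalibrator itself.

# Non-negative FP4 E2M1 magnitudes (assumed NVFP4 element grid).
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
FP4_E2M1_MAX = 6.0


def fp8_e4m3_candidates():
    """The 126 positive finite e4m3fn values divided by 448, ascending."""
    out = []
    for bits in range(0x01, 0x7F):
        exp, man = bits >> 3, bits & 0x7
        if exp == 0:
            v = (man / 8.0) * 2.0 ** -6
        else:
            v = (1.0 + man / 8.0) * 2.0 ** (exp - 7)
        out.append(v / 448.0)
    return out


def fake_quant_e2m1(v, scale):
    """Quantize v / scale to the nearest E2M1 magnitude (saturating at 6), then dequantize.

    Ties go to the smaller magnitude because min() keeps the first match.
    """
    mag = min(E2M1_GRID, key=lambda g: abs(g - abs(v) / scale))
    return math.copysign(mag * scale, v)


def block_sse(block, scale):
    """Sum of squared quantization errors for one block at a given scale."""
    return sum((v - fake_quant_e2m1(v, scale)) ** 2 for v in block)


def sweep_block(block, global_amax, candidates):
    """Return candidate * global_amax for the lowest-error candidate (assumes global_amax > 0)."""
    best_err, best_amax = math.inf, None
    for c in candidates:
        scale = c * global_amax / FP4_E2M1_MAX
        err = block_sse(block, scale)
        if err < best_err:  # strict '<': the smallest candidate wins ties
            best_err, best_amax = err, c * global_amax
    return best_amax


def sweep(x, global_amax, block_size=16, candidates=None):
    """Unfused reference over a flat list x: one best_amax per block."""
    if len(x) % block_size != 0:
        raise ValueError("len(x) must be divisible by block_size")
    if candidates is None:
        candidates = fp8_e4m3_candidates()
    return [
        sweep_block(x[i:i + block_size], global_amax, candidates)
        for i in range(0, len(x), block_size)
    ]
```

Because candidate 1.0 (plain amax scaling) is always in the set, the swept result can never have a larger block error than scaling by global_amax / 6. Whether the fused kernel breaks ties toward the smaller candidate, as this sketch does, is an assumption.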

Parameters:
  • x (Tensor) – Weight tensor on a CUDA device. Its total element count must be divisible by block_size; the data is treated as a flat [N_BLOCKS, block_size] layout with N_BLOCKS = x.numel() // block_size.

  • global_amax (Tensor) – Scalar FP32 global amax, i.e. reduce_amax(per_block_amax): the maximum of the per-block amax values.

  • block_size (int) – NVFP4 block size (typically 16).

  • candidates (Tensor | None) – Optional precomputed candidate tensor of shape [126]; it must hold the valid FP8 E4M3 values divided by 448, as produced by fp8_scale_candidates(). If omitted, it is built lazily.
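The input contract for x and global_amax can be illustrated with nested Python lists standing in for a tensor. This is a minimal sketch: prepare_blocks and _flatten are hypothetical names, and it assumes row-major flattening of a contiguous tensor, so when the last dimension is not a multiple of block_size a block can span two rows.

```python
# Hypothetical helper mirroring the documented input contract; not part of this module.
def _flatten(x):
    """Yield scalars from an arbitrarily nested list/tuple in row-major order."""
    if isinstance(x, (list, tuple)):
        for item in x:
            yield from _flatten(item)
    else:
        yield float(x)


def prepare_blocks(x, block_size=16):
    """Split x into flat [N_BLOCKS][block_size] blocks and compute global_amax."""
    flat = list(_flatten(x))
    if len(flat) % block_size != 0:
        raise ValueError(
            f"numel={len(flat)} is not divisible by block_size={block_size}"
        )
    blocks = [flat[i:i + block_size] for i in range(0, len(flat), block_size)]
    per_block_amax = [max(abs(v) for v in b) for b in blocks]
    # reduce_amax over the per-block amax values, i.e. max |x| over the tensor
    global_amax = max(per_block_amax)
    return blocks, per_block_amax, global_amax
```

For example, prepare_blocks([[1, -2, 3], [-4, 5, -6]], block_size=2) yields three blocks, the middle one [3.0, -4.0] spanning both rows, with per-block amax [2.0, 4.0, 6.0] and global_amax 6.0.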

Returns:

best_amax of shape [N_BLOCKS], fp32, on the same device as x.

Return type:

Tensor