nvfp4_fp8_sweep

Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep.

Replaces the 126-iteration Python sweep in NVFP4MSECalibrator with a single kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates and emits the per-block best_amax directly.

The 126 candidates are constructed as valid_fp8_e4m3_value / 448 (see fp8_scale_candidates()), so the largest candidate is exactly 1.0. For these specific candidates the FP8 round-trip on the per-block scale is the identity, because each candidate maps back to its source E4M3 value (candidate * 448), which FP8 represents exactly. The kernel can therefore use scale = candidate * global_amax / 6.0 without an explicit FP8 cast, which makes it runnable on any CUDA GPU with Triton (no tl.float8e4nv requirement).
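The FP8 round-trip identity for these candidates can be written out as below. This assumes the usual NVFP4 two-level scaling with a per-tensor FP32 scale of global_amax / (6 * 448); that convention is an assumption here, since only the candidate formula is stated above. Notation: c = v / 448 for a finite positive E4M3 value v, and g = global_amax.

```latex
$$
\begin{aligned}
s_{\text{block}} &= \frac{c\,g}{6}, \qquad s_{\text{tensor}} = \frac{g}{6 \cdot 448},\\
q &= \frac{s_{\text{block}}}{s_{\text{tensor}}} = \frac{c\,g/6}{g/(6 \cdot 448)} = 448\,c = v,\\
\operatorname{fp8}(q)\, s_{\text{tensor}} &= v \cdot \frac{g}{6 \cdot 448} = \frac{c\,g}{6} = s_{\text{block}}.
\end{aligned}
$$
```

Because fp8(v) = v exactly, the cast can be dropped. For example, v = 224 gives c = 0.5 and a block scale of g / 12.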

Tile shape (BLOCKS_PER_PROGRAM) and num_warps are autotuned per N_BLOCKS.

Functions

  • fp8_scale_candidates – Return the 126 finite positive FP8 E4M3 values, each divided by 448, as scale candidates.

  • nvfp4_fp8_scale_sweep – Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.

  • nvfp4_fp8_scale_sweep_hessian – Find the per-block FP8 scale minimizing the Hessian-weighted NVFP4 quant error.

fp8_scale_candidates(device='cpu')

Return the 126 finite positive FP8 E4M3 values, each divided by 448, as scale candidates.

Parameters:
  • device (device | str) – Device on which the returned tensor is created.

Return type:

Tensor
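To illustrate what fp8_scale_candidates() returns, the sketch below enumerates the finite positive E4M3 values in pure Python and divides them by 448. It assumes the e4m3fn variant (exponent bias 7, no infinities, a single NaN pattern S.1111.111). Both helper names are hypothetical. The real function returns a torch Tensor on device, and this sketch makes no claim about its element ordering.

```python
def e4m3_positive_values() -> list[float]:
    """Enumerate every finite, positive FP8 E4M3 (e4m3fn) value in ascending order."""
    values = []
    for exp in range(16):          # 4 exponent bits
        for man in range(8):       # 3 mantissa bits
            if exp == 15 and man == 7:
                continue           # S.1111.111 encodes NaN in e4m3fn
            if exp == 0:
                v = (man / 8) * 2.0 ** (1 - 7)       # subnormal: 0.mmm * 2^(1 - bias)
            else:
                v = (1 + man / 8) * 2.0 ** (exp - 7)  # normal: 1.mmm * 2^(e - bias)
            if v > 0:
                values.append(v)   # skips +0 (exp=0, man=0)
    return values


def fp8_scale_candidates_reference() -> list[float]:
    """Hypothetical pure-Python stand-in for fp8_scale_candidates(): E4M3 values / 448."""
    return [v / 448.0 for v in e4m3_positive_values()]
```

The 16 * 8 = 128 positive-sign encodings, minus +0 and the NaN pattern, leave the 126 values. Because the largest finite E4M3 value is 448, the largest candidate is exactly 1.0, which corresponds to a block amax equal to global_amax.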

nvfp4_fp8_scale_sweep(x, global_amax, block_size=16)

Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.

Equivalent to the 126-step sweep in NVFP4MSECalibrator, but fused into a single Triton kernel: every block’s weight elements are loaded once, all 126 candidates are evaluated in registers, and the running argmin is kept inline.

Parameters:
  • x (Tensor) – Weight tensor on CUDA. Its total element count must be divisible by block_size; it is treated as a flat [N_BLOCKS, block_size] layout with N_BLOCKS = x.numel() // block_size.

  • global_amax (Tensor) – Scalar FP32 global amax (= reduce_amax(per_block_amax)).

  • block_size (int) – NVFP4 block size (typically 16).

Returns:

best_amax of shape [N_BLOCKS], fp32, on the same device as x.

Return type:

Tensor
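Below is a slow pure-Python sketch of the per-block search that the fused kernel performs. It is not the kernel itself. It assumes round-to-nearest onto the FP4 E2M1 magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6} with saturation at 6, the scale = candidate * global_amax / 6.0 rule from the module description, and best_amax = best_candidate * global_amax. Tie-breaking (in the rounding and in the argmin, where the first minimizer wins) is an assumption and may differ from the Triton kernel. The function names are hypothetical, and the sketch works on a flat Python list instead of a CUDA Tensor.

```python
import math

# FP4 E2M1 magnitudes representable in NVFP4 (sign handled separately).
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def fp8_candidates():
    """The 126 finite positive E4M3 (e4m3fn) values divided by 448, ascending."""
    out = []
    for e in range(16):
        for m in range(8):
            if (e, m) in ((0, 0), (15, 7)):  # skip +0 and the NaN encoding
                continue
            v = (m / 8) * 2.0 ** -6 if e == 0 else (1 + m / 8) * 2.0 ** (e - 7)
            out.append(v / 448.0)
    return out


def qdq_e2m1(w, scale):
    """Quantize w/scale to the nearest E2M1 magnitude (saturating at 6), then dequantize."""
    mag = min(abs(w) / scale, 6.0)
    q = min(E2M1_GRID, key=lambda g: abs(g - mag))  # ties go to the smaller grid value
    return math.copysign(q, w) * scale


def mse_sweep_reference(x, global_amax, block_size=16):
    """Hypothetical slow reference for nvfp4_fp8_scale_sweep on a flat list of floats."""
    assert global_amax > 0 and len(x) % block_size == 0
    cands = fp8_candidates()
    best_amax = []
    for b in range(len(x) // block_size):
        block = x[b * block_size:(b + 1) * block_size]
        best_err, best_c = math.inf, cands[0]
        for c in cands:
            scale = c * global_amax / 6.0
            err = sum((w - qdq_e2m1(w, scale)) ** 2 for w in block) / block_size
            if err < best_err:  # strict '<': the first (smallest) minimizer wins
                best_err, best_c = err, c
        best_amax.append(best_c * global_amax)
    return best_amax
```

This sketch re-walks each block in Python for every candidate and is only meant as a readable reference. As described above, the kernel instead loads each block once and keeps all 126 evaluations and the running argmin in registers.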

nvfp4_fp8_scale_sweep_hessian(x, global_amax, hessian, block_size=16)

Find the per-block FP8 scale minimizing the Hessian-weighted NVFP4 quant error.

Hessian-weighted counterpart of nvfp4_fp8_scale_sweep(): for each NVFP4 block it minimizes dwᵀ H dw (dw = w - quant(w)) over the 126 FP8 E4M3 candidates, where H is the per-cin-block local Hessian shared across all output rows. Used by NVFP4MSECalibrator for local_hessian calibration.
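Per flat block b, the objective can be written out as follows. The notation is introduced here for illustration: w_b is the block's block_size weights, k its cin-block index, H_k the matching slice of hessian, Q_c the NVFP4 quantize-dequantize at candidate c, the set C the 126 candidates, and g = global_amax. The last line assumes best_amax = candidate * global_amax, consistent with scale = candidate * global_amax / 6.0.

```latex
$$
\begin{aligned}
k &= b \bmod (\mathrm{cin} / \mathrm{block\_size}),\\
\Delta w_b(c) &= w_b - Q_c(w_b),\\
c_b^\star &= \arg\min_{c \in \mathcal{C}} \; \Delta w_b(c)^\top H_k \, \Delta w_b(c),\\
\mathrm{best\_amax}_b &= c_b^\star \, g.
\end{aligned}
$$
```

With H_k = I the objective reduces to the sum of squared errors over the block, which is block_size times the block MSE. The selection then matches nvfp4_fp8_scale_sweep() up to tie-breaking.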

Parameters:
  • x (Tensor) – Weight tensor on CUDA in the blocked [N_BLOCKS, block_size] layout, row-major over (cout, cin // block_size), so flat block b uses the Hessian of cin-block b % (cin // block_size).

  • global_amax (Tensor) – Scalar FP32 global amax (= reduce_amax(per_block_amax)).

  • hessian (Tensor) – Per-cin-block Hessian of shape [cin // block_size, block_size, block_size], fp32 (typically normalized by sample count).

  • block_size (int) – NVFP4 block size (typically 16).

Returns:

best_amax of shape [N_BLOCKS], fp32, on the same device as x.

Return type:

Tensor
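Below is a pure-Python sketch of the Hessian-weighted search, including the flat-block to cin-block mapping (b % (cin // block_size)) described above. It rests on the same assumptions as a plain MSE sweep: round-to-nearest E2M1 with saturation at 6, scale = candidate * global_amax / 6.0, best_amax = best_candidate * global_amax, and the first minimizer winning ties. None of this is taken from the Triton kernel. The function names are hypothetical, and hessian is a nested list of shape [cin // block_size][block_size][block_size].

```python
import math

# FP4 E2M1 magnitudes representable in NVFP4 (sign handled separately).
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def fp8_candidates():
    """The 126 finite positive E4M3 (e4m3fn) values divided by 448, ascending."""
    out = []
    for e in range(16):
        for m in range(8):
            if (e, m) in ((0, 0), (15, 7)):  # skip +0 and the NaN encoding
                continue
            v = (m / 8) * 2.0 ** -6 if e == 0 else (1 + m / 8) * 2.0 ** (e - 7)
            out.append(v / 448.0)
    return out


def qdq_e2m1(w, scale):
    """Quantize w/scale to the nearest E2M1 magnitude (saturating at 6), then dequantize."""
    mag = min(abs(w) / scale, 6.0)
    q = min(E2M1_GRID, key=lambda g: abs(g - mag))
    return math.copysign(q, w) * scale


def hessian_sweep_reference(x, global_amax, hessian, block_size=16):
    """Hypothetical reference: per block, minimize dw^T H dw over the 126 candidates."""
    assert global_amax > 0 and len(x) % block_size == 0
    n_cin_blocks = len(hessian)                # = cin // block_size
    cands = fp8_candidates()
    best_amax = []
    for b in range(len(x) // block_size):
        w = x[b * block_size:(b + 1) * block_size]
        H = hessian[b % n_cin_blocks]          # same H for every output row
        best_err, best_c = math.inf, cands[0]
        for c in cands:
            scale = c * global_amax / 6.0
            dw = [wi - qdq_e2m1(wi, scale) for wi in w]
            err = sum(dw[i] * H[i][j] * dw[j]
                      for i in range(block_size) for j in range(block_size))
            if err < best_err:                 # first (smallest) minimizer wins
                best_err, best_c = err, c
        best_amax.append(best_c * global_amax)
    return best_amax
```

The double loop costs O(block_size^2) per candidate, which is why sharing one small [block_size, block_size] Hessian per cin-block across all output rows keeps the objective cheap to evaluate.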