nvfp4_fp8_sweep
Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep.
Replaces the 126-iteration Python sweep in NVFP4MSECalibrator with a single
kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates
and emits the per-block best_amax directly.
The 126 candidates are constructed as valid_fp8_e4m3_value / 448 (see
fp8_scale_candidates()). For these specific candidates, the FP8 round-trip on
the per-block scale is the identity, so the kernel can use
scale = candidate * global_amax / 6.0 without an explicit FP8 cast — making it
runnable on any CUDA GPU with Triton (no tl.float8e4nv requirement).
Tile shape (BLOCKS_PER_PROGRAM) and num_warps are autotuned per N_BLOCKS.
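The candidate construction and the round-trip claim can be checked on a CPU. The sketch below is plain Python and is not the library's code: the helper names are hypothetical, and it assumes the E4M3 "FN" encoding (bias 7, no infinities, maximum 448), whose finite positive bit patterns are 0x01 through 0x7E. It also assumes the per-block scale is stored in FP8 after division by a global scale of global_amax / (6 * 448), which is why a candidate maps back onto an exact E4M3 value.

```python
import math


def decode_e4m3(bits: int) -> float:
    """Decode a positive FP8 E4M3 (FN variant) bit pattern in 0x00..0x7E."""
    exponent = (bits >> 3) & 0xF
    mantissa = bits & 0x7
    if exponent == 0:
        # Subnormal: (mantissa / 8) * 2^-6
        return mantissa * 2.0 ** -9
    return (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)


# 0x00 is zero and 0x7F is NaN, so 0x01..0x7E gives 126 finite positive values.
E4M3_VALUES = [decode_e4m3(b) for b in range(0x01, 0x7F)]
CANDIDATES = [v / 448.0 for v in E4M3_VALUES]


def roundtrip_is_identity(global_amax: float) -> bool:
    """Check that every candidate scale lands on an E4M3 value after rescaling."""
    global_scale = global_amax / (6.0 * 448.0)  # assumed two-level scaling
    for c in CANDIDATES:
        block_scale = c * global_amax / 6.0
        scaled = block_scale / global_scale  # value that would be cast to FP8
        nearest = min(E4M3_VALUES, key=lambda g: abs(g - scaled))
        # Allow for float64 rounding in this demo; the match is exact in theory.
        if not math.isclose(nearest, scaled, rel_tol=1e-12):
            return False
    return True
```

Because the rescaled value is already on the E4M3 grid, casting it to FP8 and back changes nothing, which is the property that lets the kernel skip the cast.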
Functions
- fp8_scale_candidates: Return the 126 valid finite positive FP8 E4M3 scale candidates / 448.
- nvfp4_fp8_scale_sweep: Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.
- nvfp4_fp8_scale_sweep_hessian: Find the per-block FP8 scale minimizing the Hessian-weighted NVFP4 quant error.
- fp8_scale_candidates(device='cpu')
Return the 126 valid finite positive FP8 E4M3 scale candidates / 448.
- Parameters:
device (device | str)
- Return type:
Tensor
- nvfp4_fp8_scale_sweep(x, global_amax, block_size=16)
Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.
Equivalent to the 126-step sweep in NVFP4MSECalibrator, but fused into a single Triton kernel: every block's weight elements are loaded once, all 126 candidates are evaluated in registers, and the running argmin is kept inline.
- Parameters:
x (Tensor) – Weight tensor on CUDA. Total element count must be divisible by block_size; layout is treated as a flat [N_BLOCKS, BLOCK_SIZE].
global_amax (Tensor) – Scalar FP32 global amax (= reduce_amax(per_block_amax)).
block_size (int) – NVFP4 block size (typically 16).
- Returns:
best_amax of shape [N_BLOCKS], fp32, on the same device as x.
- Return type:
Tensor
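For orientation, the computation the fused kernel performs can be written as a slow reference loop. This is a minimal sketch in plain Python, not the Triton kernel or the NVFP4MSECalibrator code. The function names are hypothetical. It assumes the FP4 E2M1 magnitude grid {0, 0.5, 1, 1.5, 2, 3, 4, 6}, simple nearest-value rounding, first-minimum tie-breaking, and that the returned best_amax equals candidate * global_amax.

```python
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def _e4m3_positive_values():
    """All 126 finite positive FP8 E4M3 (FN) values, in ascending order."""
    values = []
    for bits in range(0x01, 0x7F):
        exponent, mantissa = bits >> 3, bits & 0x7
        if exponent == 0:
            values.append(mantissa * 2.0 ** -9)  # subnormal
        else:
            values.append((1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7))
    return values


CANDIDATES = [v / 448.0 for v in _e4m3_positive_values()]


def fake_quant_fp4(value, scale):
    """Quantize one element to the E2M1 grid at the given scale, then dequantize."""
    magnitude = min(abs(value) / scale, 6.0)  # clamp to the largest FP4 value
    nearest = min(E2M1_GRID, key=lambda g: abs(g - magnitude))
    return (nearest if value >= 0 else -nearest) * scale


def reference_scale_sweep(x, global_amax, block_size=16):
    """Return one best_amax per block of a flat list of weights."""
    assert len(x) % block_size == 0, "element count must divide into blocks"
    best_amax = []
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        best_err, best = float("inf"), 0.0
        for c in CANDIDATES:
            scale = c * global_amax / 6.0  # no FP8 cast needed for these candidates
            err = sum((w - fake_quant_fp4(w, scale)) ** 2 for w in block)
            if err < best_err:  # running argmin; first minimum wins on ties
                best_err, best = err, c * global_amax
        best_amax.append(best)
    return best_amax
```

A block whose values already sit on the FP4 grid at scale 1, such as [6.0, 3.0, 0.0, -1.5] with global_amax = 6.0 and block_size = 4, selects the candidate 1.0, since every smaller candidate clips the largest element.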
- nvfp4_fp8_scale_sweep_hessian(x, global_amax, hessian, block_size=16)
Find the per-block FP8 scale minimizing the Hessian-weighted NVFP4 quant error.
Hessian-weighted counterpart of nvfp4_fp8_scale_sweep(): for each NVFP4 block it minimizes dwᵀ H dw (dw = w - quant(w)) over the 126 FP8 E4M3 candidates, where H is the per-cin-block local Hessian shared across all output rows. Used by NVFP4MSECalibrator for local_hessian calibration.
- Parameters:
x (Tensor) – Weight tensor on CUDA in the blocked [N_BLOCKS, block_size] layout, row-major over (cout, cin // block_size) so flat block b has cin-block b % (cin // block_size).
global_amax (Tensor) – Scalar FP32 global amax (= reduce_amax(per_block_amax)).
hessian (Tensor) – Per-cin-block Hessian of shape [cin // block_size, block_size, block_size], fp32 (typically normalized by sample count).
block_size (int) – NVFP4 block size (typically 16).
- Returns:
best_amax of shape [N_BLOCKS], fp32, on the same device as x.
- Return type:
Tensor
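The two pieces that distinguish the Hessian variant are the block-to-Hessian lookup and the quadratic-form objective. The sketch below illustrates both in plain Python under the layout described above. The helper names are hypothetical and this is not the kernel's implementation.

```python
def cin_block_index(flat_block, cin, block_size=16):
    """Map a flat block index b to its cin-block, b % (cin // block_size).

    Assumes the row-major (cout, cin // block_size) block layout, so every
    output row reuses the same sequence of cin-block Hessians.
    """
    return flat_block % (cin // block_size)


def hessian_weighted_error(w, w_quant, hessian_block):
    """Compute dw^T H dw for one block, with dw = w - quant(w).

    hessian_block is a block_size x block_size nested list.
    """
    dw = [a - b for a, b in zip(w, w_quant)]
    n = len(dw)
    return sum(dw[i] * hessian_block[i][j] * dw[j]
               for i in range(n) for j in range(n))
```

With an identity Hessian the objective reduces to the plain sum of squared errors, so the unweighted sweep can be read as a special case. Off-diagonal Hessian entries let errors on correlated input channels reinforce or cancel each other.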