nvfp4_fp8_sweep
Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep.
Replaces the 126-iteration Python sweep in NVFP4MSECalibrator with a single
kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates
and emits the per-block best_amax directly.
The 126 candidates are constructed as valid_fp8_e4m3_value / 448 (see
fp8_scale_candidates()). For these specific candidates, the FP8 round-trip on
the per-block scale is the identity, so the kernel can use
scale = candidate * global_amax / 6.0 without an explicit FP8 cast — making it
runnable on any CUDA GPU with Triton (no tl.float8e4nv requirement).
Tile shape (BLOCKS_PER_PROGRAM) and num_warps are autotuned per N_BLOCKS.
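For orientation, the following is a minimal pure-PyTorch sketch of the per-block sweep that the kernel fuses, written from the description above. The helper names (E2M1_VALUES, quantize_e2m1, reference_sweep) and the exact rounding and tie-breaking are illustrative assumptions rather than the module's actual code; only the 126-candidate count and the scale = candidate * global_amax / 6.0 formula follow the documented behavior.

```python
import torch

# Illustrative FP4 E2M1 magnitude table; name and rounding are assumptions for this sketch.
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_e2m1(x: torch.Tensor) -> torch.Tensor:
    # Round each element to the nearest representable E2M1 magnitude, keeping the sign.
    vals = E2M1_VALUES.to(x.device)
    idx = (x.abs().unsqueeze(-1) - vals).abs().argmin(dim=-1)
    return torch.sign(x) * vals[idx]

def reference_sweep(x, global_amax, candidates, block_size=16):
    # x is treated as a flat [N_BLOCKS, BLOCK_SIZE] view, as in the kernel.
    blocks = x.reshape(-1, block_size).float()
    ga = float(global_amax)
    best_mse = torch.full((blocks.shape[0],), float("inf"), device=x.device)
    best_amax = torch.zeros(blocks.shape[0], device=x.device)
    for c in candidates.tolist():                     # 126 candidates
        scale = c * ga / 6.0                          # no explicit FP8 cast needed (see above)
        deq = quantize_e2m1(blocks / scale) * scale   # NVFP4 fake-quant round trip
        mse = (blocks - deq).pow(2).mean(dim=1)
        better = mse < best_mse
        best_mse = torch.where(better, mse, best_mse)
        best_amax = torch.where(better, torch.full_like(best_amax, c * ga), best_amax)
    return best_amax                                  # per-block best amax, fp32
```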
Functions
- fp8_scale_candidates() – Return the 126 valid finite positive FP8 E4M3 values divided by 448, used as scale candidates.
- nvfp4_fp8_scale_sweep() – Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.
- fp8_scale_candidates(device='cpu')
Return the 126 valid finite positive FP8 E4M3 values divided by 448, used as scale candidates.
- Parameters:
device (device | str) – Device on which to allocate the returned candidate tensor.
- Return type:
Tensor
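A minimal sketch of one way these candidates could be enumerated, assuming a PyTorch build that exposes torch.float8_e4m3fn (2.1+); the actual implementation may construct them differently:

```python
import torch

def fp8_scale_candidates_sketch(device="cpu"):
    # Bit patterns 0x01..0x7E cover the 126 finite positive E4M3 values
    # (0x00 is +0 and 0x7F is NaN in the e4m3fn encoding).
    bits = torch.arange(0x01, 0x7F, dtype=torch.uint8, device=device)
    values = bits.view(torch.float8_e4m3fn).float()
    return values / 448.0  # divide by the E4M3 max so candidates fall in (0, 1]
```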
- nvfp4_fp8_scale_sweep(x, global_amax, block_size=16, candidates=None)
Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.
Equivalent to the 126-step sweep in
NVFP4MSECalibrator, but fused into a single Triton kernel: every block’s weight elements are loaded once, all 126 candidates are evaluated in registers, and the running argmin is kept inline.
- Parameters:
x (Tensor) – Weight tensor on CUDA. Total element count must be divisible by block_size; layout is treated as a flat [N_BLOCKS, BLOCK_SIZE].
global_amax (Tensor) – Scalar FP32 global amax (= reduce_amax(per_block_amax)).
block_size (int) – NVFP4 block size (typically 16).
candidates (Tensor | None) – Optional precomputed candidate tensor of shape [126] (must be the FP8 E4M3 valid values divided by 448). Built lazily if omitted.
- Returns:
best_amax of shape [N_BLOCKS], fp32, on the same device as x.
- Return type:
Tensor
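A short usage sketch; the import path and weight shape are assumptions, while the call itself follows the signature documented above:

```python
import torch
from nvfp4_fp8_sweep import nvfp4_fp8_scale_sweep  # import path is an assumption

w = torch.randn(4096, 4096, device="cuda")          # element count divisible by block_size
global_amax = w.view(-1, 16).abs().amax(dim=1).max().to(torch.float32)
best_amax = nvfp4_fp8_scale_sweep(w, global_amax, block_size=16)
assert best_amax.shape == (w.numel() // 16,)         # one fp32 amax per NVFP4 block
```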