mse
Calibrator that returns the MSE amax of all collected tensors.
Classes
| Class | Description |
| --- | --- |
| MseCalibrator | Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x. |
| NVFP4MSECalibrator | Per-block FP8 scale sweep calibrator for NVFP4 static quantization. |
- class MseCalibrator
Bases: _Calibrator
Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x.
- __init__(amax, axis=None, step_size=0.1, start_multiplier=0.25, stop_multiplier=4.0, quant_func=None, error_func=None)
Initialize MSE calibrator.
- Parameters:
amax (Tensor) – Initial amax value (required).
axis (int | tuple | list | None) – Quantization axis. None means per-tensor quantization.
step_size (float) – Step size for amax search. The number of steps is computed as ceil((stop_multiplier - start_multiplier) / step_size) + 1.
start_multiplier (float) – Starting multiplier for amax search.
stop_multiplier (float) – Ending multiplier for amax search.
quant_func (Callable[[Tensor, Tensor], Tensor] | None) – Function that quantizes input tensor given an amax value. Should have signature: quant_func(x, amax) -> quantized_x.
error_func (Callable[[Tensor, Tensor], Tensor] | None) – Function to compute error between x and xq. Default is F.mse_loss(x, xq, reduction='none').
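The step-count rule above can be checked with a few lines of plain Python. This is a minimal sketch: the helper names are hypothetical, and the assumption that candidate multipliers are spaced evenly from start_multiplier to stop_multiplier inclusive (linspace-style) is not confirmed by the text. Each candidate amax would then be multiplier × the initial amax.

```python
import math
from typing import List


def num_search_steps(start_multiplier: float, stop_multiplier: float, step_size: float) -> int:
    # Documented rule: ceil((stop - start) / step) + 1
    return math.ceil((stop_multiplier - start_multiplier) / step_size) + 1


def candidate_multipliers(
    start_multiplier: float = 0.25,
    stop_multiplier: float = 4.0,
    step_size: float = 0.1,
) -> List[float]:
    # Assumption: evenly spaced multipliers from start to stop, both ends included.
    n = num_search_steps(start_multiplier, stop_multiplier, step_size)
    if n == 1:
        return [start_multiplier]
    span = stop_multiplier - start_multiplier
    return [start_multiplier + span * i / (n - 1) for i in range(n)]


mults = candidate_multipliers()
print(len(mults), mults[0], mults[-1])  # 39 0.25 4.0
```

With the defaults, (4.0 - 0.25) / 0.1 = 37.5 rounds up to 38, so the search evaluates 39 candidate amax values.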
- collect(x)
Collect input tensor statistics and compute losses for MSE calibration.
- Parameters:
x (Tensor) – Input tensor.
- compute_amax(verbose=False)
Return the amax value that minimizes quantization error.
- Parameters:
verbose (bool) – If True, print the ratio of best_amax to initial_amax.
- reset()
Reset the stored losses and amax value.
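The collect / compute_amax / reset cycle can be sketched as follows. This is a per-tensor toy written in plain Python, not the library's implementation: ToyMseCalibrator and fake_quant are hypothetical, axis handling is omitted, the error is a summed squared error rather than F.mse_loss(..., reduction='none') reduced over non-axis dimensions, and the linspace-style candidate grid is an assumption.

```python
import math
from typing import Callable, List

QuantFunc = Callable[[List[float], float], List[float]]


def fake_quant(x: List[float], amax: float, num_bits: int = 4) -> List[float]:
    """Hypothetical symmetric uniform fake-quantizer; amax must be > 0."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = amax / qmax
    return [max(-qmax, min(qmax, round(v / scale))) * scale for v in x]


class ToyMseCalibrator:
    """Per-tensor sketch of an MSE amax search (no axis support)."""

    def __init__(self, amax: float, step_size: float = 0.1, start_multiplier: float = 0.25,
                 stop_multiplier: float = 4.0, quant_func: QuantFunc = fake_quant) -> None:
        self._initial_amax = amax
        n = math.ceil((stop_multiplier - start_multiplier) / step_size) + 1
        span = stop_multiplier - start_multiplier
        # Assumption: evenly spaced multipliers from start to stop inclusive.
        self._multipliers = [start_multiplier + span * i / max(n - 1, 1) for i in range(n)]
        self._quant_func = quant_func
        self._losses = [0.0] * n

    def collect(self, x: List[float]) -> None:
        # Accumulate each candidate's squared error across collect() calls.
        for i, m in enumerate(self._multipliers):
            xq = self._quant_func(x, m * self._initial_amax)
            self._losses[i] += sum((a - b) ** 2 for a, b in zip(x, xq))

    def compute_amax(self, verbose: bool = False) -> float:
        # Pick the candidate with the lowest accumulated loss (first one on ties).
        best = min(range(len(self._losses)), key=self._losses.__getitem__)
        best_amax = self._multipliers[best] * self._initial_amax
        if verbose:
            print(f"best_amax / initial_amax = {best_amax / self._initial_amax:.4f}")
        return best_amax

    def reset(self) -> None:
        self._losses = [0.0] * len(self._multipliers)
```

Usage: `cal = ToyMseCalibrator(amax=2.0); cal.collect([0.1, -0.4, 1.2, 2.0]); cal.compute_amax(verbose=True)`. Accumulating losses rather than amax values is what lets several collect() calls contribute to a single argmin.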
- class NVFP4MSECalibrator
Bases: MseCalibrator
Per-block FP8 scale sweep calibrator for NVFP4 static quantization.
collect dispatches to one of two fused Triton fast paths and otherwise falls back to the reference 126-step Python sweep. Both fast paths require the input to be on CUDA in the blocked [n_blocks, block_size] layout, with Triton and the kernel package importable:
  - Hessian-weighted (local_hessian): taken when hessian is not None; minimizes dwᵀ H dw. It takes precedence over the plain path, so it fires even when error_func is also set.
  - Plain squared error: taken when hessian is None and error_func is None.
Otherwise (CPU input, non-blocked layout, Triton unavailable, or an error_func with no hessian), collect runs the reference sweep, using error_func as the metric when set. The final amax is cached immediately, so this calibrator is one-shot between resets.
- __init__(amax, global_amax, axis=None, quant_func=None, error_func=None, hessian=None)
Initialize the NVFP4 MSE calibrator with a per-block amax and a global amax.
hessian (per-cin-block, shape [cin // block_size, block_size, block_size]) enables the Hessian-weighted Triton fast path (local_hessian); error_func carries the same metric for the reference fallback when the fast path is unavailable.
- Parameters:
amax (Tensor) – Initial per-block amax, of shape [num_blocks].
global_amax (Tensor) – Global amax.
axis (int | tuple | list | None)
quant_func (Callable | None)
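error_func (Callable | None) – Error metric for the reference sweep when set; ignored by the Hessian-weighted fast path.
hessian (Tensor | None) – Optional per-cin-block Hessian of shape [cin // block_size, block_size, block_size]; enables the Hessian-weighted fast path.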
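To make the dwᵀ H dw metric and the [cin // block_size, block_size, block_size] layout concrete, here is a plain-Python sketch. It assumes dw is the quantization error w − w_q of one output row across cin input channels, and that block k of that row is weighted by hessian[k]; the function name is hypothetical. With identity blocks it reduces to the plain squared error.

```python
from typing import Sequence


def hessian_weighted_error(
    dw: Sequence[float],
    hessian_blocks: Sequence[Sequence[Sequence[float]]],
    block_size: int,
) -> float:
    """Sum of dw_k^T H_k dw_k over the cin // block_size input-channel blocks."""
    if len(dw) != len(hessian_blocks) * block_size:
        raise ValueError("len(dw) must equal (cin // block_size) * block_size")
    total = 0.0
    for k, h in enumerate(hessian_blocks):
        d = dw[k * block_size:(k + 1) * block_size]  # error slice for block k
        for i in range(block_size):
            for j in range(block_size):
                total += d[i] * h[i][j] * d[j]
    return total


identity = [[1.0, 0.0], [0.0, 1.0]]
print(hessian_weighted_error([1.0, 2.0, 3.0, 4.0], [identity, identity], 2))  # 30.0
```

Off-diagonal entries of H are what distinguish this metric from a weighted MSE: they let correlated errors within a block reinforce or cancel.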
- collect(x)
Collect input statistics and cache the final per-block amax.
- Parameters:
x (Tensor)
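The dispatch rules described for collect can be restated as a pure function. This is a sketch of the documented decision order only; the function name, its boolean parameters, and the returned path labels are hypothetical and do not correspond to real internals.

```python
def select_collect_path(
    on_cuda: bool,
    blocked_layout: bool,
    triton_and_kernels_importable: bool,
    has_hessian: bool,
    has_error_func: bool,
) -> str:
    fast_path_ok = on_cuda and blocked_layout and triton_and_kernels_importable
    if fast_path_ok and has_hessian:
        # Hessian-weighted path wins even when error_func is also set.
        return "triton_local_hessian"
    if fast_path_ok and not has_error_func:
        # Reached only with hessian None, so this is the plain squared-error path.
        return "triton_plain"
    # Reference sweep: error_func is the metric when set, else squared error.
    return "reference_sweep_error_func" if has_error_func else "reference_sweep_mse"
```

Checking the Hessian condition first is what gives local_hessian precedence: an error_func alone only diverts work to the reference sweep when no hessian is supplied.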
- compute_amax(verbose=False)
Return the cached per-block amax.
- Parameters:
verbose (bool)
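The reference sweep that produces the cached per-block amax can be sketched for a single block. The sketch rests on three assumptions that the text does not confirm: (1) the 126 steps are the positive finite FP8 E4M3 values, which number exactly 126; (2) the global scale is global_amax / (6 · 448), where 6 is the largest FP4 E2M1 magnitude and 448 the largest E4M3 value; (3) the per-block amax is 6 × the dequantized block scale. All names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # FP4 E2M1, max 6.0
FP8_E4M3_MAX = 448.0


def fp8_e4m3_positive_values() -> List[float]:
    """All positive finite FP8 E4M3 values, ascending (E4M3fn: no infinities)."""
    vals = []
    for e in range(16):  # 4 exponent bits, bias 7
        for m in range(8):  # 3 mantissa bits
            if e == 0 and m == 0:
                continue  # +0.0
            if e == 15 and m == 7:
                continue  # NaN encoding
            if e == 0:
                vals.append(m * 2.0 ** -9)  # subnormal: (m / 8) * 2**(1 - 7)
            else:
                vals.append((1 + m / 8) * 2.0 ** (e - 7))  # normal
    return vals


def quantize_e2m1(v: float, scale: float) -> float:
    # Nearest E2M1 magnitude, saturating at 6.0. Ties go to the smaller
    # magnitude here; a real converter would typically round half to even.
    mag = abs(v) / scale
    q = min(E2M1_MAGNITUDES, key=lambda c: abs(c - mag))
    return math.copysign(q * scale, v)


def sweep_block_amax(block: Sequence[float], global_amax: float) -> Tuple[float, float]:
    # Assumed global scale: maps global_amax onto 6.0 * 448.0.
    global_scale = global_amax / (max(E2M1_MAGNITUDES) * FP8_E4M3_MAX)
    best_amax, best_err = math.nan, math.inf
    for s in fp8_e4m3_positive_values():  # 126 candidate block scales
        scale = s * global_scale
        err = sum((v - quantize_e2m1(v, scale)) ** 2 for v in block)
        if err < best_err:
            best_amax, best_err = scale * 6.0, err
    return best_amax, best_err


print(sweep_block_amax([0.5, -1.0, 3.0, 6.0], global_amax=2688.0))  # (6.0, 0.0)
```

Because each block is swept independently, the result has one amax per block, matching the [num_blocks] shape of the cached value.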
- reset()
Reset per-cycle state. Keeps _initial_amax so the calibrator stays reusable. MseCalibrator.reset() intentionally drops _initial_amax to free memory in the multi-step search, but the NVFP4 per-block amax has shape [num_blocks], small enough to keep, so a follow-up collect() can run again on the same calibrator instance.