mse
Calibrator that returns the MSE amax of all collected tensors.
Classes
MseCalibrator | Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x.
NVFP4MSECalibrator | Per-block FP8 scale sweep calibrator for NVFP4 static quantization.
- class MseCalibrator
Bases: _Calibrator
Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x.
- __init__(amax, axis=None, step_size=0.1, start_multiplier=0.25, stop_multiplier=4.0, quant_func=None, error_func=None)
Initialize MSE calibrator.
- Parameters:
amax (Tensor) – Initial amax value (required).
axis (int | tuple | list | None) – Quantization axis. None means per-tensor quantization.
step_size (float) – Step size for amax search. The number of steps is computed as ceil((stop_multiplier - start_multiplier) / step_size) + 1.
start_multiplier (float) – Starting multiplier for amax search.
stop_multiplier (float) – Ending multiplier for amax search.
quant_func (Callable[[Tensor, Tensor], Tensor] | None) – Function that quantizes input tensor given an amax value. Should have signature: quant_func(x, amax) -> quantized_x.
error_func (Callable[[Tensor, Tensor], Tensor] | None) – Function that computes the error between x and xq. Default is F.mse_loss(x, xq, reduction='none').
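As a worked example of the step-count formula given for step_size, the sketch below evaluates it for the default arguments. The helper names num_search_steps and candidate_multipliers are hypothetical, and the evenly spaced grid with both endpoints included is an assumption; this page does not state how the calibrator spaces its candidates.

```python
import math


def num_search_steps(start_multiplier=0.25, stop_multiplier=4.0, step_size=0.1):
    """Step count from the documented formula (helper name is hypothetical)."""
    return math.ceil((stop_multiplier - start_multiplier) / step_size) + 1


def candidate_multipliers(start_multiplier=0.25, stop_multiplier=4.0, step_size=0.1):
    """Assumed grid: evenly spaced multipliers with both endpoints included."""
    n = num_search_steps(start_multiplier, stop_multiplier, step_size)
    if n == 1:
        return [start_multiplier]
    span = stop_multiplier - start_multiplier
    return [start_multiplier + span * i / (n - 1) for i in range(n)]


steps = num_search_steps()      # ceil(3.75 / 0.1) + 1 = 38 + 1 = 39
grid = candidate_multipliers()  # runs from 0.25 to 4.0
print(steps, grid[0], grid[-1])
```

With the defaults, the search evaluates 39 candidate amax values between 0.25x and 4x the initial amax. Under the even-spacing assumption, the gap between candidates is slightly smaller than step_size, because the step count is rounded up.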
- collect(x)
Collect input tensor statistics and compute losses for MSE calibration.
- Parameters:
x (Tensor) – Input tensor.
- compute_amax(verbose=False)
Return the amax value that minimizes quantization error.
- Parameters:
verbose (bool) – If True, print the ratio of best_amax to initial_amax.
- reset()
Reset the stored losses and amax value.
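The collect / compute_amax / reset cycle described above can be illustrated with a small pure-Python sketch. It is not the library's implementation: ToyMseSearch and fake_quant are hypothetical names, the 4-bit symmetric quantizer is an assumed stand-in for quant_func, and the loss is reduced to a single per-tensor sum rather than kept per channel.

```python
def fake_quant(x, amax, num_bits=4):
    """Hypothetical symmetric fake-quantizer with the quant_func(x, amax) signature."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = amax / qmax
    out = []
    for v in x:
        q = round(v / scale)
        q = max(-qmax, min(qmax, q))  # clip to the representable range
        out.append(q * scale)
    return out


class ToyMseSearch:
    """Toy per-tensor search: accumulate squared error for each candidate amax."""

    def __init__(self, amax, multipliers, quant_func=fake_quant):
        self.initial_amax = amax
        self.multipliers = list(multipliers)
        self.quant_func = quant_func
        self.losses = [0.0] * len(self.multipliers)

    def collect(self, x):
        # Add this batch's squared error to each candidate's running loss.
        for i, m in enumerate(self.multipliers):
            xq = self.quant_func(x, self.initial_amax * m)
            self.losses[i] += sum((a - b) ** 2 for a, b in zip(x, xq))

    def compute_amax(self):
        # Pick the candidate with the smallest accumulated loss.
        best = min(range(len(self.losses)), key=self.losses.__getitem__)
        return self.initial_amax * self.multipliers[best]

    def reset(self):
        self.losses = [0.0] * len(self.multipliers)


# Mostly small values plus one outlier, so clipping the outlier may pay off.
x = [0.1 * i for i in range(-10, 11)] + [8.0]
search = ToyMseSearch(amax=8.0, multipliers=[0.25, 0.5, 1.0, 2.0])
search.collect(x)
best_amax = search.compute_amax()
print(best_amax / search.initial_amax)  # ratio of best amax to initial amax
```

Because the losses accumulate across collect calls, the chosen amax reflects every tensor seen since the last reset, which matches the "all collected tensors" wording in the module summary.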
- class NVFP4MSECalibrator
Bases: MseCalibrator
Per-block FP8 scale sweep calibrator for NVFP4 static quantization.
Uses a fused Triton kernel as an internal fast path on the first collect call when (a) error_func is None, (b) the input tensor is on CUDA in the standard blocked [n_blocks, block_size] layout, and (c) Triton + the kernel package are importable. Falls back to the reference 126-step Python sweep otherwise (custom error_func users, multi-collect activation flows, CPU inputs, or when the fast path is disabled via MODELOPT_NVFP4_TRITON_SWEEP=0).
- __init__(amax, global_amax, axis=None, quant_func=None, error_func=None)
Initialize NVFP4 MSE calibrator with per-block and global amax.
- Parameters:
amax (Tensor)
global_amax (Tensor)
axis (int | tuple | list | None)
quant_func (Callable | None)
error_func (Callable | None)
- collect(x)
Collect input statistics. Uses the Triton fast path when eligible.
- Parameters:
x (Tensor)
- compute_amax(verbose=False)
Return the per-block amax: from the fast path if it ran, otherwise from the reference sweep.
- Parameters:
verbose (bool)
- reset()
Reset per-cycle state. Keep _initial_amax so the calibrator stays reusable. MseCalibrator.reset() intentionally drops _initial_amax to free memory in the multi-step search, but the NVFP4 per-block amax is shape [num_blocks], small enough to keep so a follow-up collect() can run again on the same calibrator instance.
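The fast-path conditions listed in the NVFP4MSECalibrator description can be restated as a small decision function. This is only a sketch of the documented rules: should_use_fast_path and fast_path_enabled are hypothetical names, the arguments stand in for checks the library performs internally, and treating only the literal value "0" as disabling the fast path is an assumption.

```python
import os


def fast_path_enabled(environ=os.environ):
    # Assumption: only the literal value "0" disables the fast path.
    return environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0"


def should_use_fast_path(error_func, on_cuda, shape, block_size,
                         triton_available, first_collect, environ=os.environ):
    """Hypothetical restatement of conditions (a), (b), (c) and the env switch."""
    return bool(
        first_collect                   # only the first collect call qualifies
        and error_func is None          # (a) no custom error function
        and on_cuda                     # (b) tensor lives on CUDA ...
        and len(shape) == 2             # ... in [n_blocks, block_size] layout
        and shape[1] == block_size
        and triton_available            # (c) Triton and kernel package importable
        and fast_path_enabled(environ)  # not disabled through the environment
    )


# Example values only; a block size of 16 is used here for illustration.
print(should_use_fast_path(None, True, (128, 16), 16, True, True, environ={}))
print(should_use_fast_path(None, True, (128, 16), 16, True, True,
                           environ={"MODELOPT_NVFP4_TRITON_SWEEP": "0"}))
```

Any condition that fails sends the calibrator to the reference Python sweep, so setting MODELOPT_NVFP4_TRITON_SWEEP=0 is a convenient way to compare the two paths on the same input.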