mse
Calibrator that returns the MSE amax of all collected tensors.
Classes
| Class | Description |
|---|---|
| FP8ScaleSweepCalibrator | MSE calibrator that sweeps 126 valid FP8 E4M3 candidates of initial_amax. |
| MseCalibrator | Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x. |
| NVFP4MSECalibrator | FP8 scale sweep calibrator for NVFP4 per-block static quantization. |
- class FP8ScaleSweepCalibrator
Bases: MseCalibrator
MSE calibrator that sweeps 126 valid FP8 E4M3 candidates (every positive, finite, non-zero E4M3 value) of initial_amax. Candidate amax values are initial_amax * candidate.
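The count of 126 follows from the E4M3 encoding: of the 128 codes with the sign bit clear, 0x00 is zero and 0x7F is NaN (the fn variant has no infinities). The pure-Python sketch below enumerates those values and forms one candidate amax per value. Dividing each FP8 value by the E4M3 maximum of 448, so that the largest candidate equals initial_amax, is an assumption made for illustration; this page only states that candidates are initial_amax * candidate. Both helper names are hypothetical.

```python
E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def fp8_e4m3_positive_values():
    """Positive, finite, non-zero FP8 E4M3fn values, in increasing order."""
    values = []
    # Codes 0x01..0x7E with the sign bit clear; 0x00 is +0 and 0x7F is NaN.
    for bits in range(0x01, 0x7F):
        exponent, mantissa = bits >> 3, bits & 0x7
        if exponent == 0:
            values.append(2.0 ** -6 * mantissa / 8)  # subnormal, exponent bias 7
        else:
            values.append(2.0 ** (exponent - 7) * (1 + mantissa / 8))
    return values


def sweep_candidate_amaxes(initial_amax):
    """Hypothetical helper: one candidate amax per FP8 value."""
    # Assumption: candidates are FP8 values divided by E4M3_MAX, so the
    # largest candidate amax equals initial_amax.
    return [initial_amax * (v / E4M3_MAX) for v in fp8_e4m3_positive_values()]


vals = fp8_e4m3_positive_values()
print(len(vals), vals[0], vals[-1])  # 126 0.001953125 448.0
```

Sweeping only representable FP8 values means whichever candidate wins can be stored as an FP8 scale without further rounding.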
- class MseCalibrator
Bases: _Calibrator
Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x.
- __init__(amax, axis=None, step_size=0.1, start_multiplier=0.25, stop_multiplier=4.0, quant_func=None, error_func=None)
Initialize MSE calibrator.
- Parameters:
amax (Tensor) – Initial amax value (required).
axis (int | tuple | list | None) – Quantization axis. None means per-tensor quantization.
step_size (float) – Step size for amax search. The number of steps is computed as ceil((stop_multiplier - start_multiplier) / step_size) + 1.
start_multiplier (float) – Starting multiplier for amax search.
stop_multiplier (float) – Ending multiplier for amax search.
quant_func (Callable[[Tensor, Tensor], Tensor] | None) – Function that quantizes the input tensor given an amax value. It should have the signature quant_func(x, amax) -> quantized_x.
error_func (Callable[[Tensor, Tensor], Tensor] | None) – Function to compute the error between x and xq. Defaults to F.mse_loss(x, xq, reduction='none').
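As a worked example of the step-count formula, the defaults give ceil((4.0 - 0.25) / 0.1) + 1 = ceil(37.5) + 1 = 39 steps. The sketch below computes that count and builds a multiplier grid from it. How the real class spaces the multipliers is not stated here; even spacing from start_multiplier to stop_multiplier inclusive (as torch.linspace would do) is an assumption, and multiplier_grid is a hypothetical name.

```python
import math


def multiplier_grid(start_multiplier=0.25, stop_multiplier=4.0, step_size=0.1):
    """Hypothetical helper: the multipliers an MSE amax search could try."""
    # Step count, using the formula given in the step_size description.
    num_steps = math.ceil((stop_multiplier - start_multiplier) / step_size) + 1
    if num_steps == 1:
        return [start_multiplier]
    # Assumption: multipliers are spaced evenly from start to stop, inclusive.
    span = stop_multiplier - start_multiplier
    return [start_multiplier + span * i / (num_steps - 1) for i in range(num_steps)]


grid = multiplier_grid()
print(len(grid), grid[0], grid[-1])  # 39 0.25 4.0

initial_amax = 2.0
candidate_amaxes = [initial_amax * m for m in grid]  # 0.5 ... 8.0
```

Because the step count is rounded up, the effective spacing under this assumption (3.75 / 38, about 0.0987 with the defaults) never exceeds step_size, and both endpoints are always tried.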
- collect(x)
Collect input tensor statistics and compute losses for MSE calibration.
- Parameters:
x (Tensor) – Input tensor.
- compute_amax(verbose=False)
Return the amax value that minimizes quantization error.
- Parameters:
verbose (bool) – If True, print the ratio of best_amax to initial_amax.
- reset()
Reset the stored losses and amax value.
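The collect / compute_amax / reset lifecycle can be illustrated with a per-tensor toy in pure Python. collect() evaluates every candidate amax on a batch and stores the losses, compute_amax() returns the candidate with the lowest loss, and reset() clears the state. ToyMseCalibrator, fake_quant_int8 and squared_error are hypothetical stand-ins: the real class works on tensors, supports axis, and defaults to F.mse_loss. Summing losses across repeated collect() calls is an assumption about how multiple batches are combined.

```python
def fake_quant_int8(x, amax):
    """Hypothetical symmetric 8-bit fake quantizer standing in for quant_func."""
    scale = amax / 127.0
    return [max(-127, min(127, round(v / scale))) * scale for v in x]


def squared_error(x, xq):
    """Summed squared error standing in for error_func."""
    return sum((a - b) ** 2 for a, b in zip(x, xq))


class ToyMseCalibrator:
    """Per-tensor sketch of the calibrator lifecycle (not the library class)."""

    def __init__(self, amax, multipliers, quant_func=fake_quant_int8, error_func=squared_error):
        self._initial_amax = amax
        self._candidates = [amax * m for m in multipliers]
        self._quant_func = quant_func
        self._error_func = error_func
        self._losses = None  # one accumulated loss per candidate amax

    def collect(self, x):
        # Quantize x with every candidate amax and record the resulting error.
        losses = [self._error_func(x, self._quant_func(x, a)) for a in self._candidates]
        if self._losses is None:
            self._losses = losses
        else:
            self._losses = [old + new for old, new in zip(self._losses, losses)]

    def compute_amax(self, verbose=False):
        if self._losses is None:
            return None  # nothing collected yet
        best_amax = self._candidates[self._losses.index(min(self._losses))]
        if verbose:
            print(f"best_amax / initial_amax = {best_amax / self._initial_amax:.3f}")
        return best_amax

    def reset(self):
        self._losses = None


cal = ToyMseCalibrator(amax=2.0, multipliers=[0.5, 0.75, 1.0, 1.25])
cal.collect([0.01, -0.02, 0.5, 1.9, -0.3])
print(cal.compute_amax(verbose=True))
```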
- class NVFP4MSECalibrator
Bases: FP8ScaleSweepCalibrator
FP8 scale sweep calibrator for NVFP4 per-block static quantization.
Extends FP8ScaleSweepCalibrator with a global_amax that drives the candidate amax computation: each candidate scales global_amax uniformly across all blocks.
- __init__(amax, global_amax, axis=None, quant_func=None, error_func=None)
Initialize NVFP4 calibrator.
- Parameters:
amax (Tensor) – Per-block amax tensor (shape [num_blocks]).
global_amax (Tensor) – Scalar global amax used to scale all FP8 candidates.
axis (int | tuple | list | None) – Quantization axis. None means per-tensor quantization.
quant_func (Callable | None) – Function that quantizes the input tensor given an amax value.
error_func (Callable | None) – Function to compute error between x and xq.
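A minimal sketch of how a global_amax-driven sweep could select per-block amax values, under several assumptions not stated on this page. Candidate k gives every block the same amax, global_amax * v_k / 448, where v_k runs over the 126 positive E4M3 values. Each block keeps the candidate with its own lowest error. Elements are fake-quantized to the FP4 E2M1 magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}, with the block amax mapped to 6. The per-block amax argument of the real class is not modeled, the blocks are shortened to 4 elements (NVFP4 uses 16), and all helper names are hypothetical.

```python
import math

E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # FP4 E2M1 magnitudes used by NVFP4
E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def fp8_e4m3_positive_values():
    """Positive, finite, non-zero FP8 E4M3fn values (codes 0x01..0x7E; 0x7F is NaN)."""
    values = []
    for bits in range(0x01, 0x7F):
        exponent, mantissa = bits >> 3, bits & 0x7
        if exponent == 0:
            values.append(2.0 ** -6 * mantissa / 8)  # subnormal
        else:
            values.append(2.0 ** (exponent - 7) * (1 + mantissa / 8))
    return values


def fake_quant_nvfp4_block(block, amax):
    """Hypothetical fake quantizer: amax maps to 6, elements round to the nearest E2M1 magnitude."""
    scale = amax / 6.0
    out = []
    for v in block:
        mag = min(E2M1_GRID, key=lambda g: abs(abs(v) / scale - g))
        out.append(math.copysign(mag * scale, v))
    return out


def block_error(block, amax):
    return sum((a - b) ** 2 for a, b in zip(block, fake_quant_nvfp4_block(block, amax)))


def sweep_blocks(blocks, global_amax):
    """Hypothetical helper: best candidate amax for each block."""
    # Assumption: candidate k applies the same amax, global_amax * v_k / 448, to all blocks.
    candidates = [global_amax * v / E4M3_MAX for v in fp8_e4m3_positive_values()]
    best = []
    for block in blocks:
        errors = [block_error(block, c) for c in candidates]
        best.append(candidates[errors.index(min(errors))])  # per-block argmin
    return best


# Toy 4-element blocks; a small block and a large one share one global_amax.
blocks = [[0.1, -0.2, 0.3, 0.05], [1.5, -3.0, 2.2, 0.7]]
global_amax = max(abs(v) for block in blocks for v in block)
print(sweep_blocks(blocks, global_amax))
```

Because every candidate is global_amax times an FP8 value over 448, each selected block amax corresponds to a scale that is exactly representable in E4M3 relative to the global scale.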