mse
Calibrator that returns the MSE amax of all collected tensors.
Classes
MseCalibrator: Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x.
- class MseCalibrator
Bases: _Calibrator
Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x.
- __init__(amax, axis=None, step_size=0.1, start_multiplier=0.25, stop_multiplier=4.0, quant_func=None, error_func=None)
Initialize MSE calibrator.
- Parameters:
amax (Tensor) – Initial amax value (required).
axis (int | tuple | list | None) – Quantization axis. None means per-tensor quantization.
step_size (float) – Step size for amax search. The number of steps is computed as ceil((stop_multiplier - start_multiplier) / step_size) + 1.
start_multiplier (float) – Starting multiplier for amax search.
stop_multiplier (float) – Ending multiplier for amax search.
quant_func (Callable[[Tensor, Tensor], Tensor] | None) – Function that quantizes input tensor given an amax value. Should have signature: quant_func(x, amax) -> quantized_x.
error_func (Callable[[Tensor, Tensor], Tensor] | None) – Function to compute error between x and xq. Default is F.mse_loss(x, xq, reduction='none').
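The step-count formula for step_size can be checked with a few lines of plain Python. This is a minimal sketch, not the library's code: the helper names num_search_steps and candidate_multipliers are hypothetical, and the even spacing of multipliers between start_multiplier and stop_multiplier is an assumption that this page does not state.

```python
import math


def num_search_steps(step_size=0.1, start_multiplier=0.25, stop_multiplier=4.0):
    """Number of search steps, using the formula given for step_size."""
    return math.ceil((stop_multiplier - start_multiplier) / step_size) + 1


def candidate_multipliers(step_size=0.1, start_multiplier=0.25, stop_multiplier=4.0):
    """Hypothetical helper: multipliers applied to the initial amax.

    Assumption: candidates are spaced evenly from start_multiplier to
    stop_multiplier, endpoints included.
    """
    n = num_search_steps(step_size, start_multiplier, stop_multiplier)
    if n == 1:
        return [start_multiplier]
    span = stop_multiplier - start_multiplier
    return [start_multiplier + span * i / (n - 1) for i in range(n)]


multipliers = candidate_multipliers()
print(len(multipliers), multipliers[0], multipliers[-1])
```

Under that assumption, each candidate amax would be the initial amax scaled by one of these multipliers, so a smaller step_size makes the search finer at the cost of more quantize-and-compare passes per collected tensor.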
- clear()
Clear all cached data to free GPU memory.
Call this after compute_amax() and load_calib_amax() are done.
- collect(x)
Collect input tensor statistics and compute losses for MSE calibration.
- Parameters:
x (Tensor) – Input tensor.
- compute_amax(verbose=False)
Return the amax value that minimizes quantization error.
- Parameters:
verbose (bool) – If True, print the ratio of best_amax to initial_amax.
- reset()
Reset the stored losses and amax value.
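The collect / compute_amax / reset cycle described above can be illustrated with a toy per-tensor search in plain Python. This is a sketch under stated assumptions, not MseCalibrator itself: ToyMseSearch, fake_quant and mean_squared_error are hypothetical names, the quantizer is a simple symmetric 8-bit rounding scheme, candidates are assumed to be evenly spaced, and losses are assumed to be summed across collect calls.

```python
import math


def fake_quant(x, amax, num_bits=8):
    """Hypothetical quant_func(x, amax) -> quantized_x (symmetric, per-tensor)."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = amax / qmax
    # Round to the integer grid, clip to [-qmax, qmax], then rescale.
    return [max(-qmax, min(qmax, round(v / scale))) * scale for v in x]


def mean_squared_error(x, xq):
    """Hypothetical error_func, reduced to one scalar per tensor."""
    return sum((a - b) ** 2 for a, b in zip(x, xq)) / len(x)


class ToyMseSearch:
    """Toy stand-in for the search; assumes evenly spaced candidates."""

    def __init__(self, amax, step_size=0.1, start_multiplier=0.25, stop_multiplier=4.0):
        n = math.ceil((stop_multiplier - start_multiplier) / step_size) + 1
        span = stop_multiplier - start_multiplier
        self.candidates = [
            amax * (start_multiplier + span * i / max(n - 1, 1)) for i in range(n)
        ]
        self.losses = [0.0] * n

    def collect(self, x):
        # Accumulate the quantization error of every candidate amax on x.
        for i, cand in enumerate(self.candidates):
            self.losses[i] += mean_squared_error(x, fake_quant(x, cand))

    def compute_amax(self):
        # Return the candidate with the smallest accumulated error.
        best = min(range(len(self.losses)), key=self.losses.__getitem__)
        return self.candidates[best]

    def reset(self):
        self.losses = [0.0] * len(self.candidates)


data = [i / 10 for i in range(-10, 11)]  # values in [-1.0, 1.0]
search = ToyMseSearch(amax=1.0)
search.collect(data)
search.collect([v * 0.5 for v in data])
best_amax = search.compute_amax()
print(best_amax)
```

In this toy setup, candidates below the largest absolute input clip the extreme values and incur a large error, so the selected amax ends up at or above that maximum. The real calibrator additionally supports per-channel search through axis and caches tensors on the device, which is why clear() exists.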