mse
Calibrator that returns the MSE-optimal amax of all collected tensors.
Classes
- MseCalibrator: Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x.
- class MseCalibrator
Bases: _Calibrator

Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x.
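The initial amax is required, and axis selects between per-tensor and per-channel search. As a hedged illustration (not the library's own code), an initial absolute-max amax could be derived from a tensor like this before constructing the calibrator:

```python
import torch

def initial_amax(x: torch.Tensor, axis=None) -> torch.Tensor:
    """Illustrative helper (not part of the library): absolute-max amax.

    axis=None -> per-tensor: a single scalar amax.
    axis=0    -> per-channel: one amax per slice along dimension 0.
    """
    if axis is None:
        return x.abs().amax()
    # Reduce over every dimension except the quantization axis.
    reduce_dims = tuple(d for d in range(x.dim()) if d != axis)
    return x.abs().amax(dim=reduce_dims, keepdim=True)

weight = torch.randn(64, 128)
print(initial_amax(weight).shape)          # torch.Size([])      (per-tensor)
print(initial_amax(weight, axis=0).shape)  # torch.Size([64, 1]) (per-channel)
```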
- __init__(amax, axis=None, num_steps=10, start_multiplier=0.25, stop_multiplier=4.0, quant_func=None, error_func=None)
Initialize MSE calibrator.
- Parameters:
amax (Tensor) – Initial amax value (required).
axis (int | tuple | list | None) – Quantization axis. None means per-tensor quantization.
num_steps (int) – Number of amax candidates to try.
start_multiplier (float) – Starting multiplier for amax search.
stop_multiplier (float) – Ending multiplier for amax search.
quant_func (Callable[[Tensor, Tensor], Tensor] | None) – Function that quantizes input tensor given an amax value. Should have signature: quant_func(x, amax) -> quantized_x.
error_func (Callable[[Tensor, Tensor], Tensor] | None) – Function to compute error between x and xq. Default is F.mse_loss(x, xq, reduction='none').
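Taken together, these parameters describe a search over num_steps candidate amax values between start_multiplier * amax and stop_multiplier * amax, keeping the candidate whose quantized output is closest to the input. A simplified per-tensor sketch of that idea follows; it is not the library's implementation, and the int8 fake-quant helper is hypothetical:

```python
import torch
import torch.nn.functional as F

def mse_amax_search(x, amax, num_steps=10, start_multiplier=0.25,
                    stop_multiplier=4.0, quant_func=None, error_func=None):
    """Simplified per-tensor sketch of the candidate search described above."""
    if quant_func is None:
        # Hypothetical int8 fake-quant, used only for this illustration.
        def quant_func(t, a):
            scale = a / 127.0
            return torch.clamp((t / scale).round(), -128, 127) * scale
    if error_func is None:
        def error_func(t, tq):
            return F.mse_loss(t, tq, reduction="none")

    best_amax, best_err = None, float("inf")
    # Try num_steps amax candidates between the start and stop multipliers.
    for m in torch.linspace(start_multiplier, stop_multiplier, num_steps):
        candidate = amax * m
        err = error_func(x, quant_func(x, candidate)).mean().item()
        if err < best_err:
            best_amax, best_err = candidate, err
    return best_amax

x = torch.randn(1024)
print(mse_amax_search(x, x.abs().amax()))
```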
- collect(x)
Collect input tensor statistics and compute losses for MSE calibration.
- Parameters:
x (Tensor) – Input tensor.
- compute_amax(verbose=False)
Return the amax value that minimizes quantization error.
- Parameters:
verbose (bool) – If True, print the ratio of best_amax to initial_amax.
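Putting the pieces together, a minimal usage sketch could look like the following. The import path and the fake-quant helper are assumptions made for illustration; the constructor, collect, and compute_amax calls follow the signatures documented above.

```python
import torch

# Assumed import path; adjust to where MseCalibrator lives in your install.
from modelopt.torch.quantization.calib import MseCalibrator

# Hypothetical int8 fake-quant passed in as quant_func.
def fake_quant(x, amax):
    scale = amax / 127.0
    return torch.clamp((x / scale).round(), -128, 127) * scale

calib_batches = [torch.randn(32, 512) for _ in range(8)]  # stand-in data

calibrator = MseCalibrator(
    amax=calib_batches[0].abs().amax(),  # required initial per-tensor amax
    axis=None,                           # per-tensor search
    num_steps=10,
    quant_func=fake_quant,
)

# Accumulate quantization-error statistics over the calibration data.
for batch in calib_batches:
    calibrator.collect(batch)

# Pick the candidate amax with the lowest accumulated error; verbose=True
# prints the ratio of best_amax to the initial amax.
best_amax = calibrator.compute_amax(verbose=True)
print(best_amax)
```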
- reset()
Reset the stored losses and amax value.
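For example, to discard the accumulated statistics before starting a fresh calibration pass with the same instance (continuing the sketch above):

```python
# Clears the losses accumulated by collect() and the stored amax value.
calibrator.reset()
```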