mse

Calibrator that returns the amax minimizing the MSE over all collected tensors.

Classes

MseCalibrator

Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x.

class MseCalibrator

Bases: _Calibrator

Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x.

__init__(amax, axis=None, num_steps=10, start_multiplier=0.25, stop_multiplier=4.0, quant_func=None, error_func=None)

Initialize MSE calibrator.

Parameters:
  • amax (Tensor) – Initial amax value (required).

  • axis (int | tuple | list | None) – Quantization axis. None means per-tensor quantization.

  • num_steps (int) – Number of amax candidates to try.

  • start_multiplier (float) – Starting multiplier for amax search.

  • stop_multiplier (float) – Ending multiplier for amax search.

  • quant_func (Callable[[Tensor, Tensor], Tensor] | None) – Function that quantizes input tensor given an amax value. Should have signature: quant_func(x, amax) -> quantized_x.

  • error_func (Callable[[Tensor, Tensor], Tensor] | None) – Function to compute error between x and xq. Default is F.mse_loss(x, xq, reduction='none').
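Any function with the signature quant_func(x, amax) -> quantized_x can be passed as quant_func. As an illustration, a simple symmetric int8 fake-quantization function fits this contract (the function name and the int8 range here are assumptions for the sketch, not part of this API):

```python
import torch

def fake_quant_int8(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    # Hypothetical symmetric int8 fake quantization matching the expected
    # quant_func(x, amax) -> quantized_x signature.
    scale = amax / 127.0
    # Round to the nearest integer level and clip to the symmetric int8 range.
    xq = torch.clamp(torch.round(x / scale), -127, 127)
    # Dequantize back to the input scale so the error against x can be measured.
    return xq * scale
```

The calibrator then evaluates this function at each candidate amax and measures the resulting error with error_func.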

collect(x)

Collect input tensor statistics and compute losses for MSE calibration.

Parameters:

x (Tensor) – Input tensor.

compute_amax(verbose=False)

Return the amax value that minimizes quantization error.

Parameters:

verbose (bool) – If True, print the ratio of best_amax to initial_amax.
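The search that collect and compute_amax perform together can be sketched as a standalone per-tensor function: try num_steps candidate amax values spaced between start_multiplier * amax and stop_multiplier * amax, quantize with each candidate, and return the candidate with the lowest mean squared error. This is an illustrative sketch under those assumptions, not the library implementation; the inline fake-quant helper is likewise assumed:

```python
import torch

def mse_amax_search(x: torch.Tensor, amax_init: torch.Tensor,
                    num_steps: int = 10,
                    start_multiplier: float = 0.25,
                    stop_multiplier: float = 4.0) -> torch.Tensor:
    # Hypothetical symmetric int8 fake quantization used as the quant_func.
    def fake_quant(x, amax):
        scale = amax / 127.0
        return torch.clamp(torch.round(x / scale), -127, 127) * scale

    # Candidate amax values between start_multiplier and stop_multiplier
    # times the initial amax.
    candidates = amax_init * torch.linspace(start_multiplier, stop_multiplier,
                                            num_steps)
    # MSE between x and its quantized version for each candidate.
    losses = torch.stack(
        [torch.mean((x - fake_quant(x, a)) ** 2) for a in candidates]
    )
    # Return the candidate that minimizes the quantization error.
    return candidates[torch.argmin(losses)]
```

A small candidate amax reduces rounding error but clips large values; a large one avoids clipping but coarsens the grid. The argmin balances the two.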

reset()

Reset the stored losses and amax value.