config

This document lists the quantization formats supported by Model Optimizer and example quantization configs.

Quantization Formats

The following table lists the quantization formats supported by Model Optimizer and the corresponding quantization config. See Quantization Configs for the specific quantization config definitions.

Please see choosing the right quantization formats to learn more about the formats and their use-cases.

Note

The recommended configs given below are for LLM models. For CNN models, only INT8 quantization is supported. Please use quantization config INT8_DEFAULT_CFG for CNN models.

Quantization Format              Model Optimizer config
-------------------------------  ----------------------
INT8                             INT8_SMOOTHQUANT_CFG
FP8                              FP8_DEFAULT_CFG
INT4 Weights only AWQ (W4A16)    INT4_AWQ_CFG
INT4-FP8 AWQ (W4A8)              W4A8_AWQ_BETA_CFG

Quantization Configs

A quantization config is a dictionary specifying the values for the keys "quant_cfg" and "algorithm". The "quant_cfg" key specifies the quantization configurations. The "algorithm" key specifies the algorithm argument passed to calibrate. Please see QuantizeConfig for the quantization config definition.

‘Quantization configurations’ is a dictionary mapping wildcards or filter functions to their ‘quantizer attributes’. The wildcards or filter functions are matched against the quantizer module names. Quantizer modules whose names end with weight_quantizer and input_quantizer perform weight quantization and input (activation) quantization, respectively. The quantizer modules are generally instances of TensorQuantizer. The quantizer attributes are defined by QuantizerAttributeConfig; see QuantizerAttributeConfig for details on the quantizer attributes and their values.

The key “default” in the quantization configuration dictionary is applied if no other wildcard or filter function matches the quantizer module name.

The quantizer attributes are applied in the order they are specified, so when several entries match the same quantizer, the later entry takes precedence. For any attributes an entry does not specify, the defaults defined by QuantizerAttributeConfig are used.
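The matching rules above can be modeled with a small pure-Python sketch. resolve_attributes and DEFAULT_ATTRS are hypothetical names, not part of Model Optimizer; the sketch assumes wildcard keys use shell-style matching (Python's fnmatch) and that a filter function takes the quantizer module name and returns a bool. It applies "default" first, then every other matching entry in order, filling unspecified attributes from a subset of the QuantizerAttributeConfig defaults. Module class name keys and list values (SequentialQuantizer) are deliberately not modeled.

```python
import fnmatch
from typing import Callable, Dict, Optional, Union

# Subset of the QuantizerAttributeConfig defaults listed later in this document.
DEFAULT_ATTRS = {"enable": True, "num_bits": 8, "axis": None, "block_sizes": None}

Key = Union[str, Callable[[str], bool]]


def _matches(key: Key, quantizer_name: str) -> bool:
    """Callable keys act as filter functions; string keys are shell-style wildcards."""
    if callable(key):
        return bool(key(quantizer_name))
    return fnmatch.fnmatch(quantizer_name, key)


def resolve_attributes(quantizer_name: str, quant_cfg: Dict[Key, dict]) -> Optional[dict]:
    """Hypothetical helper: the attributes a quantizer ends up with under quant_cfg.

    Returns None when neither "default" nor any other entry applies.
    """
    resolved = None
    if "default" in quant_cfg:
        resolved = {**DEFAULT_ATTRS, **quant_cfg["default"]}
    for key, attrs in quant_cfg.items():
        if key == "default":
            continue
        if _matches(key, quantizer_name):
            # Unspecified attributes fall back to the defaults; later matches win.
            resolved = {**DEFAULT_ATTRS, **attrs}
    return resolved


cfg = {
    "*weight_quantizer": {"num_bits": 8, "axis": 0},
    "*input_quantizer": {"num_bits": 8, "axis": None},
    "*lm_head*": {"enable": False},
    # Hypothetical filter function: FP8 (E4M3) inputs for MLP layers only.
    (lambda name: ".mlp." in name and name.endswith("input_quantizer")): {"num_bits": (4, 3)},
    "default": {"enable": False},
}
```

For example, resolve_attributes("lm_head.weight_quantizer", cfg) first matches "*weight_quantizer" and then "*lm_head*", so the later entry disables the quantizer.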

Quantizer attributes can also be a list of dictionaries. In this case, the matched quantizer module is replaced with a SequentialQuantizer module, which quantizes a tensor in multiple formats sequentially. Each dictionary in the list specifies the quantization format for one quantization step of the sequential quantizer. For example, SequentialQuantizer is used in ‘INT4 Weights, FP8 Activations’ quantization (W4A8_AWQ_BETA_CFG), in which the weights are quantized to INT4 followed by FP8.

In addition, the dictionary keys can also be PyTorch module class names that map to class-specific quantization configurations. The PyTorch module classes should have a quantized equivalent.

To get the string representation of a module class, do:

import torch.nn as nn
from modelopt.torch.quantization.nn import QuantModuleRegistry

# Get the registry key (class name string) for nn.Conv2d
class_name = QuantModuleRegistry.get_key(nn.Conv2d)

Here is an example of a quantization config:

MY_QUANT_CFG = {
    "quant_cfg": {
        # Quantizer wildcard strings mapping to quantizer attributes
        "*weight_quantizer": {"num_bits": 8, "axis": 0},
        "*input_quantizer": {"num_bits": 8, "axis": None},

        # Module class names mapping to quantizer configurations
        "nn.LeakyReLU": {"*input_quantizer": {"enable": False}},

    }
}

Example Quantization Configurations

Here are the recommended quantization configs from Model Optimizer for quantization formats such as FP8, INT8, INT4, etc.:

INT8_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": 8, "axis": 0},
        "*input_quantizer": {"num_bits": 8, "axis": None},
        "*lm_head*": {"enable": False},
        "*block_sparse_moe.gate*": {"enable": False},  # Skip the MOE router
        "*router*": {"enable": False},  # Skip the MOE router
        "default": {"enable": False},
    },
    "algorithm": "max",
}

INT8_SMOOTHQUANT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": 8, "axis": 0},
        "*input_quantizer": {"num_bits": 8, "axis": -1},
        "*lm_head*": {"enable": False},
        "*block_sparse_moe.gate*": {"enable": False},  # Skip the MOE router
        "*router*": {"enable": False},  # Skip the MOE router
        "nn.Conv2d": {
            "*weight_quantizer": {"num_bits": 8, "axis": 0},
            "*input_quantizer": {"num_bits": 8, "axis": None},
        },
        "default": {"enable": False},
    },
    "algorithm": "smoothquant",
}

FP8_DEFAULT_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
        "*input_quantizer": {"num_bits": (4, 3), "axis": None},
        "*block_sparse_moe.gate*": {"enable": False},  # Skip the MOE router
        "*router*": {"enable": False},  # Skip the MOE router
        "default": {"enable": False},
    },
    "algorithm": "max",
}

INT4_BLOCKWISE_WEIGHT_ONLY_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True},
        "*input_quantizer": {"enable": False},
        "*lm_head*": {"enable": False},
        "*block_sparse_moe.gate*": {"enable": False},  # Skip the MOE router
        "*router*": {"enable": False},  # Skip the MOE router
        "default": {"enable": False},
    },
    "algorithm": "max",
}

INT4_AWQ_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True},
        "*input_quantizer": {"enable": False},
        "*lm_head*": {"enable": False},
        "*block_sparse_moe.gate*": {"enable": False},  # Skip the MOE router
        "*router*": {"enable": False},  # Skip the MOE router
        "default": {"enable": False},
    },
    "algorithm": {"method": "awq_lite", "alpha_step": 0.1},
    # "algorithm": {"method": "awq_full", "alpha_step": 0.1, "max_co_batch_size": 1024},
    # "algorithm": {"method": "awq_clip", "max_co_batch_size": 2048},
}

W4A8_AWQ_BETA_CFG = {
    "quant_cfg": {
        "*weight_quantizer": [
            {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True},
            {"num_bits": (4, 3), "axis": None, "enable": True},
        ],
        "*input_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True},
        "*lm_head*": {"enable": False},
        "*block_sparse_moe.gate*": {"enable": False},  # Skip the MOE router
        "*router*": {"enable": False},  # Skip the MOE router
        "default": {"enable": False},
    },
    "algorithm": "awq_lite",
}

These configs can be accessed as attributes of modelopt.torch.quantization and passed as input to mtq.quantize(). For example:

import modelopt.torch.quantization as mtq

# forward_loop(model) should run calibration data through the model
model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)

You can also create your own config by following these examples. For instance, if you want to quantize a model with the INT4 AWQ algorithm but need to skip quantizing the layer named lm_head, you can create a custom config and quantize your model as follows (INT4_AWQ_CFG above already disables *lm_head*; the same pattern works for any other layer name):

import copy
import modelopt.torch.quantization as mtq

# Create a custom config; deep-copy so the built-in config is not modified
CUSTOM_INT4_AWQ_CFG = copy.deepcopy(mtq.INT4_AWQ_CFG)
CUSTOM_INT4_AWQ_CFG["quant_cfg"]["*lm_head*"] = {"enable": False}

# Quantize the model
model = mtq.quantize(model, CUSTOM_INT4_AWQ_CFG, forward_loop)

ModeloptConfig AWQClipCalibConfig

Bases: QuantizeAlgorithmConfig

The config for awq_clip (AWQ clip) algorithm.

AWQ clip searches for a clipped amax for per-group quantization. This search requires much more compute than AWQ lite. To avoid out-of-memory (OOM) errors, the linear layer weights are processed in batches of max_co_batch_size along the out_features dimension. AWQ clip calibration also takes longer than AWQ lite.

Default config (JSON):

{
   "method": "max",
   "max_co_batch_size": 1024,
   "max_tokens_per_batch": 64,
   "min_clip_ratio": 0.5,
   "shrink_step": 0.05,
   "debug": false
}

field debug: bool | None

If True, the module’s search metadata is kept as a module attribute named awq_clip.

field max_co_batch_size: int | None

Reduce this number if a CUDA out-of-memory error occurs.

field max_tokens_per_batch: int | None

The total number of tokens used for the clip search is max_tokens_per_batch * number of batches. Original AWQ uses a total of 512 tokens to search for clip values; with the default max_tokens_per_batch of 64, that corresponds to 8 calibration batches.

field min_clip_ratio: float | None

It should be in (0, 1.0). The clip search looks for the optimal clipping value in the range [original block amax * min_clip_ratio, original block amax].

Constraints:
  • gt = 0.0

  • lt = 1.0

field shrink_step: float | None

It should be in range (0, 1.0]. The clip ratio will be searched from min_clip_ratio to 1 with the step size specified.

Constraints:
  • gt = 0.0

  • le = 1.0
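How min_clip_ratio and shrink_step define the search can be illustrated with a minimal NumPy sketch. It is not Model Optimizer's implementation: clip_ratio_grid, fake_quant_int4 and search_block_clip are hypothetical helpers, the grid is assumed to include both endpoints, and for brevity each candidate is scored by the block's weight reconstruction error, whereas the original AWQ clip search scores candidates by the error on the layer output over calibration tokens.

```python
import numpy as np


def clip_ratio_grid(min_clip_ratio: float = 0.5, shrink_step: float = 0.05) -> list:
    """Candidate clip ratios from 1.0 down to min_clip_ratio in steps of shrink_step."""
    assert 0.0 < min_clip_ratio < 1.0 and 0.0 < shrink_step <= 1.0
    # The small epsilon keeps min_clip_ratio in the grid despite float rounding.
    n_steps = int(np.floor((1.0 - min_clip_ratio) / shrink_step + 1e-9))
    return [1.0 - i * shrink_step for i in range(n_steps + 1)]


def fake_quant_int4(block: np.ndarray, amax: float) -> np.ndarray:
    """Symmetric INT4 fake quantization; values outside the INT4 range are clipped."""
    scale = amax / 7.0 if amax > 0 else 1.0
    return np.clip(np.round(block / scale), -8, 7) * scale


def search_block_clip(block: np.ndarray, ratios: list) -> tuple:
    """Return the (ratio, error) pair minimizing reconstruction MSE for one block."""
    block_amax = float(np.abs(block).max())
    best_ratio, best_err = 1.0, np.inf
    for ratio in ratios:
        err = float(np.mean((fake_quant_int4(block, ratio * block_amax) - block) ** 2))
        if err < best_err:
            best_ratio, best_err = ratio, err
    return best_ratio, best_err


ratios = clip_ratio_grid()  # defaults mirror the JSON above: 0.5 and 0.05
rng = np.random.default_rng(0)
block = rng.normal(size=128)
block[0] = 12.0  # a single outlier stretches the block amax
best_ratio, best_err = search_block_clip(block, ratios)
```

With the defaults, the grid holds 11 ratios (1.0, 0.95, ..., 0.5); a smaller shrink_step gives a finer but slower search.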

ModeloptConfig AWQFullCalibConfig

Bases: AWQLiteCalibConfig, AWQClipCalibConfig

The config for awq or awq_full algorithm (AWQ full).

AWQ full performs awq_lite followed by awq_clip.

Default config (JSON):

{
   "method": "max",
   "max_co_batch_size": 1024,
   "max_tokens_per_batch": 64,
   "min_clip_ratio": 0.5,
   "shrink_step": 0.05,
   "debug": false,
   "alpha_step": 0.1
}

field debug: bool | None

If True, the module’s search metadata is kept as module attributes named awq_lite and awq_clip.

ModeloptConfig AWQLiteCalibConfig

Bases: QuantizeAlgorithmConfig

The config for awq_lite (AWQ lite) algorithm.

AWQ lite applies a channel-wise scaling factor which minimizes the output difference after quantization. See AWQ paper for more details.

Default config (JSON):

{
   "method": "max",
   "alpha_step": 0.1,
   "debug": false
}

field alpha_step: float | None

The alpha will be searched from 0 to 1 with the step size specified.

Constraints:
  • gt = 0.0

  • le = 1.0
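To show what the alpha search does, here is a minimal NumPy sketch of the idea from the AWQ paper, not Model Optimizer's implementation. It assumes per-input-channel scales s = mean|x| ** alpha, scales the weights by s before group-wise INT4 fake quantization and the activations by 1/s, and keeps the alpha with the lowest output error. The helper names, the use of mean |x| as the activation statistic, and the group size of 128 (matching block_sizes {-1: 128} in INT4_AWQ_CFG) are assumptions.

```python
import numpy as np


def fake_quant_int4_groupwise(w: np.ndarray, group_size: int = 128) -> np.ndarray:
    """Symmetric INT4 fake quantization with one max-based scale per group along in_features."""
    out_f, in_f = w.shape
    groups = w.reshape(out_f, in_f // group_size, group_size)
    amax = np.abs(groups).max(axis=-1, keepdims=True)
    scale = np.where(amax > 0, amax / 7.0, 1.0)
    return (np.clip(np.round(groups / scale), -8, 7) * scale).reshape(out_f, in_f)


def awq_lite_search(w: np.ndarray, x: np.ndarray, alpha_step: float = 0.1, group_size: int = 128):
    """Grid-search alpha in [0, 1] for scales s = mean|x| ** alpha.

    Scaling W by s and x by 1/s leaves the unquantized product x @ W.T unchanged,
    so only the quantization error differs between candidates.
    """
    ref = x @ w.T
    act_mag = np.maximum(np.abs(x).mean(axis=0), 1e-8)  # one magnitude per input channel
    best_alpha, best_err = 0.0, np.inf
    for i in range(int(round(1.0 / alpha_step)) + 1):
        alpha = i * alpha_step
        s = act_mag ** alpha  # alpha = 0 gives s = 1, i.e. plain quantization
        w_q = fake_quant_int4_groupwise(w * s, group_size)
        err = float(np.mean(((x / s) @ w_q.T - ref) ** 2))
        if err < best_err:
            best_alpha, best_err = alpha, err
    return best_alpha, best_err


rng = np.random.default_rng(0)
w = rng.normal(size=(64, 256))   # (out_features, in_features)
x = rng.normal(size=(32, 256))   # (tokens, in_features)
x[:, :4] *= 20.0                 # a few salient (large-magnitude) input channels
best_alpha, best_err = awq_lite_search(w, x)
```

Because alpha = 0 is always a candidate, the chosen scaling is never worse (on the calibration data) than quantizing W directly.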

field debug: bool | None

If True, the module’s search metadata is kept as a module attribute named awq_lite.

ModeloptConfig MaxCalibConfig

Bases: QuantizeAlgorithmConfig

The config for max calibration algorithm.

Max calibration estimates the maximum absolute values (amax) of activations or weights and uses these values to set the quantization scaling factors. See Integer Quantization for the concepts.

Default config (JSON):

{
   "method": "max"
}
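A minimal NumPy sketch of per-tensor max calibration for symmetric signed integer quantization. The helper names are hypothetical, and the scale formula amax / (2**(num_bits - 1) - 1) is the conventional symmetric choice, assumed here rather than taken from Model Optimizer's code.

```python
import numpy as np


def max_calibrate(batches) -> float:
    """Per-tensor max calibration: the running max of |x| over all calibration batches."""
    amax = 0.0
    for x in batches:
        amax = max(amax, float(np.abs(x).max()))
    return amax


def fake_quant_int(x: np.ndarray, amax: float, num_bits: int = 8) -> np.ndarray:
    """Symmetric signed integer fake quantization with scale = amax / (2**(num_bits - 1) - 1)."""
    max_bound = 2 ** (num_bits - 1) - 1
    scale = amax / max_bound if amax > 0 else 1.0
    return np.clip(np.round(x / scale), -max_bound - 1, max_bound) * scale


rng = np.random.default_rng(0)
calib_batches = [rng.normal(size=(4, 16)) for _ in range(3)]
amax = max_calibrate(calib_batches)
x_dq = fake_quant_int(calib_batches[0], amax)
```

Since every calibrated value lies within [-amax, amax], no value is clipped and the rounding error is at most half a quantization step (amax / 127 / 2 for INT8).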

ModeloptConfig QuantizeAlgorithmConfig

Bases: ModeloptBaseConfig

Calibration algorithm config base.

Default config (JSON):

{
   "method": "max"
}

field method: str

The algorithm used for calibration. Supported algorithms include "max", "smoothquant", "awq_lite", "awq_full", and "awq_clip".

ModeloptConfig QuantizeConfig

Bases: ModeloptBaseConfig

Default configuration for quantize mode.

Default config (JSON):

{
   "quant_cfg": {
      "default": {
         "num_bits": 8,
         "axis": null
      }
   },
   "algorithm": "max"
}

field algorithm: None | str | MaxCalibConfig | SmoothQuantCalibConfig | AWQLiteCalibConfig | AWQClipCalibConfig | AWQFullCalibConfig | RealQuantizeConfig
field quant_cfg: Dict[str | Callable, QuantizerAttributeConfig | List[QuantizerAttributeConfig] | Dict[str | Callable, QuantizerAttributeConfig | List[QuantizerAttributeConfig]]]

ModeloptConfig QuantizerAttributeConfig

Bases: ModeloptBaseConfig

Quantizer attribute type.

Default config (JSON):

{
   "enable": true,
   "num_bits": 8,
   "axis": null,
   "fake_quant": true,
   "unsigned": false,
   "narrow_range": false,
   "learn_amax": false,
   "type": "static",
   "block_sizes": null,
   "trt_high_precision_dtype": "Float",
   "calibrator": "max"
}

field axis: int | Tuple[int, ...] | None

Each index along the specified axis/axes gets its own amax for computing the scaling factor. If None (the default), a single per-tensor scale is used. Must be in the range [-rank(input_tensor), rank(input_tensor)). For example, for a KCRS weight tensor, axis=0 yields per-channel scaling.
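A minimal NumPy sketch of what axis means for the amax shape: the amax is reduced over every dimension except the kept axis/axes. per_axis_amax is a hypothetical helper used only for illustration.

```python
import numpy as np


def per_axis_amax(x: np.ndarray, axis=None) -> np.ndarray:
    """amax reduced over every dim except `axis` (int or tuple); None gives one per-tensor amax."""
    if axis is None:
        return np.abs(x).max()
    keep = {a % x.ndim for a in ((axis,) if isinstance(axis, int) else axis)}
    reduce_dims = tuple(d for d in range(x.ndim) if d not in keep)
    return np.abs(x).max(axis=reduce_dims, keepdims=True)


w = np.random.default_rng(0).normal(size=(64, 32, 3, 3))  # KCRS conv weight
per_channel = per_axis_amax(w, axis=0)  # one amax per output channel K: shape (64, 1, 1, 1)
per_tensor = per_axis_amax(w)           # a single amax for the whole tensor
```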

field block_sizes: Dict[int | str, int | Tuple[int, int] | str | Dict[int, int]] | None

The keys are the axes for block quantization and the values are the block sizes along the respective axes. Keys must be in the range [-tensor.dim(), tensor.dim()). Values, which are the block sizes for quantization, must be positive integers.

In addition, there can be special string keys "type", "scale_bits" and "scale_block_sizes".

Key "type" should map to "dynamic" or "static", where "dynamic" indicates dynamic block quantization and "static" indicates static calibrated block quantization. By default, the type is "static".

Key "scale_bits" specifies the quantization bits for the per-block quantization scale factor (i.e., a double quantization scheme).

Key "scale_block_sizes" specifies the block size for double quantization. By default, the per-block quantization scale is not quantized.

For example, block_sizes = {-1: 32} will quantize the last axis of the input tensor in blocks of size 32 with static calibration and block_sizes = {-1: 32, "type": "dynamic"} will perform dynamic block quantization. If None, block quantization is not performed. axis must be None when block_sizes is not None.
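A minimal NumPy sketch of static block quantization's amax layout: the chosen axis is split into contiguous blocks and each block gets its own amax. block_amax is a hypothetical helper and assumes the axis length is divisible by the block size; the two config dicts simply restate the examples above.

```python
import numpy as np

# Attribute dicts from the examples above (axis must be None with block_sizes).
static_block_cfg = {"num_bits": 4, "block_sizes": {-1: 32}, "axis": None}
dynamic_block_cfg = {"num_bits": 4, "block_sizes": {-1: 32, "type": "dynamic"}, "axis": None}


def block_amax(x: np.ndarray, block_size: int, axis: int = -1) -> np.ndarray:
    """One amax per contiguous block of `block_size` elements along `axis`."""
    x = np.moveaxis(x, axis, -1)
    n = x.shape[-1]
    assert n % block_size == 0, "this sketch assumes the axis length is divisible by block_size"
    blocks = x.reshape(*x.shape[:-1], n // block_size, block_size)
    return np.abs(blocks).max(axis=-1)


# Mirrors block_sizes = {-1: 32}: each row of 128 values is split into 4 blocks of 32.
x = np.random.default_rng(0).normal(size=(4, 128))
amax = block_amax(x, block_size=32)  # shape (4, 4)
```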

field calibrator: str | Callable | Tuple

The calibrator can be a string from ["max", "histogram"] or a constructor to create a calibrator which subclasses _Calibrator. See standardize_constructor_args for more information on how to specify the constructor.

field enable: bool

If True, enables the quantizer. If False, bypasses the quantizer and returns the input tensor unchanged.

field fake_quant: bool

If True, enable fake quantization.

field learn_amax: bool

If True, enable learning amax.

field narrow_range: bool

If True, enable narrow range quantization. Used only for integer quantization.

field num_bits: int | Tuple[int, int]

num_bits can be:

  1. A positive integer for integer quantization. num_bits specifies the number of bits used for integer quantization.

  2. A constant integer tuple (E, M) for floating-point quantization emulating NVIDIA’s FPx quantization. E is the number of exponent bits and M is the number of mantissa bits. Supported FPx quantization formats: FP8 (E4M3, i.e. (4, 3); E5M2, i.e. (5, 2)).

field trt_high_precision_dtype: str

The value is a string from ["Float", "Half", "BFloat16"]. The QDQ nodes are assigned the corresponding data type; this field is used only when exporting the quantized model to ONNX.

Constraints:
  • pattern = ^Float$|^Half$|^BFloat16$

field type: str

The value is a string from ["static", "dynamic"]. If "dynamic", dynamic quantization will be enabled which does not collect any statistics during calibration.

Constraints:
  • pattern = ^static$|^dynamic$

field unsigned: bool

If True, enable unsigned quantization. Used only for integer quantization.
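How num_bits, unsigned and narrow_range interact for integer quantization can be sketched with the usual conventions (a symmetric range drops the most negative code). int_quant_range is a hypothetical helper; the exact bounds Model Optimizer uses, in particular for narrow_range combined with unsigned, are an assumption here.

```python
def int_quant_range(num_bits: int, unsigned: bool = False, narrow_range: bool = False):
    """(min, max) integer codes for a bit width under common integer quantization conventions."""
    max_bound = 2 ** (num_bits - 1 + int(unsigned)) - 1
    if unsigned:
        min_bound = 0  # narrow_range only affects the signed range in this sketch
    elif narrow_range:
        min_bound = -max_bound  # drop -2**(num_bits - 1) so the range is symmetric
    else:
        min_bound = -max_bound - 1
    return min_bound, max_bound


int8_range = int_quant_range(8)                          # (-128, 127)
int8_narrow_range = int_quant_range(8, narrow_range=True)  # (-127, 127)
uint8_range = int_quant_range(8, unsigned=True)          # (0, 255)
```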

ModeloptConfig RealQuantizeConfig

Bases: QuantizeAlgorithmConfig

The config for real quantization.

The additional_algorithm will be used for calibration before quantizing weights into low precision.

Default config (JSON):

{
   "method": "max",
   "additional_algorithm": ""
}

field additional_algorithm: AWQLiteCalibConfig | AWQClipCalibConfig | AWQFullCalibConfig | None

The algorithm used for calibration. Supported algorithms include "awq_lite", "awq_full", and "awq_clip".

ModeloptConfig SmoothQuantCalibConfig

Bases: QuantizeAlgorithmConfig

The config for smoothquant algorithm (SmoothQuant).

SmoothQuant applies a smoothing factor which balances the scale of outliers in weights and activations. See SmoothQuant paper for more details.

Default config (JSON):

{
   "method": "max",
   "alpha": 1.0
}

field alpha: float | None

This hyper-parameter controls the migration strength. The migration strength is within [0, 1]; a larger value migrates more quantization difficulty to the weights.

Constraints:
  • ge = 0.0

  • le = 1.0
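A minimal NumPy sketch of the smoothing step from the SmoothQuant paper, using per-input-channel factors s_j = max|X_j| ** alpha / max|W_j| ** (1 - alpha). Dividing the activations by s and multiplying the weights by s leaves the layer output unchanged while moving outlier magnitude into the weights. smoothing_factors is a hypothetical helper, not Model Optimizer's implementation; its default alpha of 1.0 mirrors the JSON above.

```python
import numpy as np


def smoothing_factors(x: np.ndarray, w: np.ndarray, alpha: float = 1.0, eps: float = 1e-8) -> np.ndarray:
    """Per-input-channel factors s_j = max|X_j| ** alpha / max|W_j| ** (1 - alpha)."""
    act_max = np.maximum(np.abs(x).max(axis=0), eps)  # x: (tokens, in_features)
    w_max = np.maximum(np.abs(w).max(axis=0), eps)    # w: (out_features, in_features)
    return act_max ** alpha / w_max ** (1.0 - alpha)


rng = np.random.default_rng(0)
x = rng.normal(size=(16, 8))
x[:, 3] *= 50.0  # an outlier activation channel
w = rng.normal(size=(4, 8))

s = smoothing_factors(x, w, alpha=0.5)
x_smooth, w_smooth = x / s, w * s  # X diag(s)^-1 and W diag(s): same product as X W^T
```

At alpha = 1.0 every activation channel is rescaled to a maximum of 1, so all of the quantization difficulty moves to the weights; smaller values split it between the two.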