model_quant

User-facing quantization API.

Functions

  • quantize – Quantizes and calibrates the model in-place.

  • auto_quantize – API for AutoQuantize which quantizes a model by searching for the best quantization formats per-layer.

  • disable_quantizer – Disable quantizer by wildcard or filter function.

  • enable_quantizer – Enable quantizer by wildcard or filter function.

  • print_quant_summary – Print summary of all quantizer modules in the model.

  • fold_weight – Fold weight quantizer for fast evaluation.

auto_quantize(model, constraints={'effective_bits': 4.8}, quantization_formats=['W4A8_AWQ_BETA_CFG', 'FP8_DEFAULT_CFG', None], data_loader=None, forward_step=None, loss_func=None, forward_backward_step=None, num_calib_steps=512, num_score_steps=128, verbose=False)

API for AutoQuantize which quantizes a model by searching for the best quantization formats per-layer.

auto_quantize uses a gradient-based sensitivity score to rank the candidate quantization formats for each layer and searches for the best format per layer.

Parameters:
  • model (Module) – A pytorch model with quantizer modules.

  • constraints (Dict[str, float | str]) –

    Constraints for the search. Currently we support only effective_bits. effective_bits specifies the effective number of bits for the quantized model.

    Here is an example of a valid effective_bits argument:

    # For an effective quantization bits of 4.8
    constraints = {"effective_bits": 4.8}
    

  • quantization_formats (List[str | None]) –

    A list of the string names of the quantization formats to search for. The supported quantization formats are as listed by modelopt.torch.quantization.config.choices.

    In addition, the quantization format can also be None which implies skipping quantization for the layer.

    Note

    The quantization formats are applied on a per-layer match basis; global, model-level name-based quantizer attribute settings are ignored. For example, in the FP8_DEFAULT_CFG quantizer configuration, the key "*lm_head*": {"enable": False} disables quantization for the lm_head layer. However, in auto_quantize the quantization format for the lm_head layer will still be searched, because the key "*lm_head*" sets quantizer attributes based on the global, model-level name rather than on a per-layer basis. Keys such as "*input_quantizer" and "*weight_quantizer" in FP8_DEFAULT_CFG match on a per-layer basis, so the corresponding quantizers will be set as specified.

    Here is an example quantization_formats argument:

    # A valid `quantization_formats` argument
    # This will search for the best per-layer quantization from FP8, W4A8_AWQ or No quantization
    quantization_formats = ["FP8_DEFAULT_CFG", "W4A8_AWQ", None]
    

  • data_loader (Iterable) – An iterable that yields batches of data used for calibrating the quantized layers and estimating the auto_quantize scores.

  • forward_step (Callable[[Module, Any], Any | Tensor]) –

    A callable that takes the model and a batch of data from data_loader as input, forwards the data through the model and returns the model output. This is a required argument.

    Here is an example of a valid forward_step:

    # Takes the model and a batch of data as input and returns the model output
    def forward_step(model, batch) -> torch.Tensor:
        output = model(batch)
        return output
    

  • loss_func (Callable[[Any, Any], Tensor]) –

    (Optional) A callable that takes the model output and the batch of data as input and computes the loss. The model output is the output returned by forward_step; .backward() will be called on the loss.

    Here is an example of a valid loss_func:

    # Takes the model output and a batch of data as input and returns the loss
    def loss_func(output, batch) -> torch.Tensor:
        ...
        return loss
    
    
    # loss should be a scalar tensor such that loss.backward() can be called
    loss = loss_func(output, batch)
    loss.backward()
    

    If this argument is not provided, forward_backward_step should be provided.

  • forward_backward_step (Callable[[Module, Any], Any] | None) –

    (Optional) A callable that takes the model and a batch of data from data_loader as input, forwards the data through the model, computes the loss, and runs backward on the loss.

    Here is an example of a valid forward_backward_step argument:

    # Takes the model and a batch of data as input and runs forward and backward pass
    def forward_backward_step(model, batch) -> None:
        output = model(batch)
        loss = my_loss_func(output, batch)
        run_custom_backward(loss)
    

    If this argument is not provided, loss_func should be provided.

  • num_calib_steps (int) – Number of batches to use for calibrating the quantized model. Suggested value is 512.

  • num_score_steps (int) – Number of batches to use for estimating auto_quantize scores. Suggested value is 128. A higher value could increase the time taken for performing auto_quantize.

  • verbose (bool) – If True, prints the search progress/intermediate results.

Returns: A tuple (model, state_dict) where model is the searched and quantized model, and state_dict contains the history and detailed stats of the search procedure.
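
For orientation, here is a minimal end-to-end sketch of an auto_quantize call. The model, data loader, and loss computation are placeholders, and the import path is an assumption based on this module's name:

# Minimal sketch of an `auto_quantize` call.
# `model` and `calib_loader` are placeholders; the import path assumes this
# module is importable as modelopt.torch.quantization.model_quant.
import torch
from modelopt.torch.quantization.model_quant import auto_quantize

def forward_step(model, batch) -> torch.Tensor:
    # Forward a batch through the model and return the output
    return model(batch)

def loss_func(output, batch) -> torch.Tensor:
    # Placeholder loss; replace with your task loss
    # (e.g. cross entropy against the labels carried in `batch`)
    return output.float().pow(2).mean()

model, state_dict = auto_quantize(
    model,
    constraints={"effective_bits": 4.8},
    quantization_formats=["FP8_DEFAULT_CFG", "W4A8_AWQ_BETA_CFG", None],
    data_loader=calib_loader,
    forward_step=forward_step,
    loss_func=loss_func,
    verbose=True,
)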

Note

auto_quantize groups certain layers and restricts their quantization formats to be the same. For example, the Q, K, V linear layers belonging to the same transformer layer will have the same quantization format. This is to ensure compatibility with TensorRT-LLM, which fuses these three linear layers into a single linear layer.

A list of regex pattern rules, as defined in AutoQuantizeSearcher.rules, is used to specify the groups of layers. The first captured group in the regex pattern (i.e., pattern.match(name).group(1)) is used to group the layers: all layers that share the same first captured group will have the same quantization format.

For example, the rule r"^(.*?)\.(q_proj|k_proj|v_proj)$" groups the q_proj, k_proj, v_proj linear layers belonging to the same transformer layer.

You may modify the rules to group the layers as per your requirement.

from modelopt.torch.quantization.algorithms import AutoQuantizeSearcher

# To additionally group the layers belonging to same `mlp` layer,
# add the following rule
AutoQuantizeSearcher.rules.append(r"^(.*?)\.mlp")

# Perform `auto_quantize`
model, state_dict = auto_quantize(model, ...)

Note

The auto_quantize API and algorithm are experimental and subject to change. Models searched with auto_quantize might not yet be readily deployable to TensorRT-LLM.

disable_quantizer(model, wildcard_or_filter_func)

Disable quantizer by wildcard or filter function.

Parameters:
  • model (Module) – A pytorch model with quantizer modules.

  • wildcard_or_filter_func (str | Callable) – A wildcard string or filter function matched against the quantizer module names; matching quantizers are disabled.

enable_quantizer(model, wildcard_or_filter_func)

Enable quantizer by wildcard or filter function.

Parameters:
  • model (Module) – A pytorch model with quantizer modules.

  • wildcard_or_filter_func (str | Callable) – A wildcard string or filter function matched against the quantizer module names; matching quantizers are enabled.
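
A minimal usage sketch; the wildcard pattern is illustrative and the import path is an assumption based on this module's name:

# Minimal sketch: toggle quantizers by wildcard.
# The pattern below matches the input (activation) quantizers described above.
from modelopt.torch.quantization.model_quant import disable_quantizer, enable_quantizer

disable_quantizer(model, "*input_quantizer")  # e.g. for a weight-only run
# ... evaluate ...
enable_quantizer(model, "*input_quantizer")   # restore activation quantization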

fold_weight(model)

Fold weight quantizer for fast evaluation.

Parameters:

model (Module) – A pytorch model with quantizer modules.
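
A minimal usage sketch, assuming the model has already been quantized and calibrated:

# Minimal sketch: fold weight quantizers of an already-quantized, calibrated
# model for faster evaluation. The import path is an assumption based on this
# module's name.
from modelopt.torch.quantization.model_quant import fold_weight

fold_weight(model)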

print_quant_summary(model)

Print summary of all quantizer modules in the model.

Parameters:

model (Module) – A pytorch model with quantizer modules.

quantize(model, config, forward_loop=None)

Quantizes and calibrates the model in-place.

This method replaces modules with their quantized counterparts and performs calibration as specified by config. forward_loop is used to forward the calibration data through the model and gather statistics for calibration.

Parameters:
  • model (Module) – A pytorch model

  • config (Dict[str, Any]) –

    A dictionary or an instance of QuantizeConfig specifying the values for the keys "quant_cfg" and "algorithm". The "quant_cfg" key specifies the quantization configuration and the "algorithm" key specifies the calibration algorithm (the algorithm argument to calibrate).

    The quantization configuration is a dictionary mapping wildcards or filter functions to quantizer attributes. The wildcards or filter functions are matched against the quantizer module names. Quantizer modules have names ending with weight_quantizer or input_quantizer, and they perform weight quantization and input quantization (activation quantization), respectively. The quantizer modules are instances of TensorQuantizer, and the quantizer attributes are defined by QuantizerAttributeConfig. See QuantizerAttributeConfig for details on the quantizer attributes and their values.

    An example config dictionary is given below:
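
    This sketch is illustrative; the attribute names follow QuantizerAttributeConfig and the values are examples, not a recommended recipe:

    # Illustrative sketch; attribute values are examples only
    MY_QUANT_CFG = {
        "quant_cfg": {
            # Wildcard patterns matched against quantizer module names
            "*weight_quantizer": {"num_bits": 8, "axis": 0},
            "*input_quantizer": {"num_bits": 8, "axis": None},
            # Disable quantization for layers matching this pattern
            "*lm_head*": {"enable": False},
        },
        # Calibration algorithm (the "algorithm" key described above)
        "algorithm": "max",
    }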

    See Quantization Formats to learn more about the supported quantization formats. See Quantization Configs for more details on config dictionary.

  • forward_loop (Callable[[Module], None] | None) –

    A callable that forwards all calibration data through the model. This is used to gather statistics for calibration. It should take model as the argument. It does not need to return anything.

    This argument is not required for weight-only quantization with the "max" algorithm.

    Here are a few examples of correct forward_loop definitions:

    Example 1:

    def forward_loop(model) -> None:
        # iterate over the data loader and forward data through the model
        for batch in data_loader:
            model(batch)
    

    Example 2:

    def forward_loop(model) -> float:
        # evaluate the model on the task
        return evaluate(model, task, ....)
    

    Example 3:

    def forward_loop(model) -> None:
        # run evaluation pipeline
        evaluator.model = model
        evaluator.evaluate()
    

    Note

    Calibration does not require forwarding the entire dataset through the model. Please subsample the dataset or reduce the number of batches if needed.

Returns: A pytorch model which has been quantized and calibrated in-place.

Return type: Module
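
For reference, here is a minimal end-to-end sketch of quantize; the data loader is a placeholder, the config is the illustrative MY_QUANT_CFG sketched above, and the import path is an assumption based on this module's name:

# Minimal sketch of quantization + calibration.
# `calib_loader` is a placeholder; MY_QUANT_CFG is the illustrative config above.
from modelopt.torch.quantization.model_quant import print_quant_summary, quantize

def forward_loop(model) -> None:
    # Forward a subsample of the calibration data through the model
    for batch in calib_loader:
        model(batch)

model = quantize(model, MY_QUANT_CFG, forward_loop=forward_loop)

# Inspect the resulting quantizer modules
print_quant_summary(model)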