quantization_utils

Utilities for quantization, including scaling-factor adjustments.

Functions

adjust_attn_amax_values

Adjusts the amax values for the attention layers.

all_items_same

Checks if all elements in the provided list are the same.

convert_state_dict_amax_to_scales

Converts _amax keys in a quantized state dictionary to scale values and updates the state dictionary accordingly.

filter_output_quantizer

Filters out all output quantizers in the state_dict except for the ones related to the kv_cache.

from_quantized_weight

Converts the quantized weight to the target torch_dtype format.

get_activation_scaling_factor

Returns the activation scaling factor.

get_kv_cache_dtype

Returns the kv_cache dtype.

get_kv_cache_scaling_factor

Returns the kv_cache scaling factor if output quantizer is set.

get_prequant_scaling_factor

Returns the prequant scaling factor.

get_qkv_and_avg_prequant_scale

Gets the qkv and average prequant scaling factors for the module.

get_quantization_format

Gets the quantization string.

get_scaling_factor

Returns scaling factor from the quantizer as torch.Tensor.

get_weight_block_size

Returns the weight block size.

get_weight_scaling_factor

Returns the weight scaling factor.

get_weight_scaling_factor_2

Returns the secondary weight scaling factor.

get_weights_scaling_factor_and_amax

Calculates the weight scaling factors for a given group size.

process_layer_quant_config

Processes per layer quantization information for TRTLLM export to quant_cfg.json.

resmooth_and_get_scale_and_amax

Resmooths weights from one or multiple ranks and gets scaling factors and amax.

to_quantized_weight

Converts the weight to the quantized (packed) format.

adjust_attn_amax_values(module)

Adjusts the amax values for the attention layers.
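
The adjustment itself is implementation-defined; a common approach when exporting fused QKV layers is to unify the q/k/v weight-quantizer amax values to their maximum so all three projections share one scale. A minimal sketch of that idea, assuming q_proj/k_proj/v_proj submodules and a weight_quantizer attribute (both are assumptions about the module layout, not guaranteed by this API):

```python
import torch

def unify_qkv_weight_amax(attn_module):
    # Hypothetical layout: q_proj/k_proj/v_proj each carry a weight_quantizer
    # whose amax we overwrite with the max across the three projections.
    projections = [attn_module.q_proj, attn_module.k_proj, attn_module.v_proj]
    unified = torch.stack(
        [proj.weight_quantizer.amax.max() for proj in projections]
    ).max()
    for proj in projections:
        proj.weight_quantizer.amax = unified
```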

all_items_same(item_list)

Checks if all elements in the provided list are the same.
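
The check is equivalent to comparing every element against the first; a minimal reimplementation sketch:

```python
def all_items_same(item_list):
    # Vacuously True for an empty list: the generator yields nothing,
    # so item_list[0] is never evaluated.
    return all(item == item_list[0] for item in item_list)
```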

convert_state_dict_amax_to_scales(quantized_state_dict, maxbound, layers_quant_config)

Converts _amax keys in a quantized state dictionary to scale values and updates the state dictionary accordingly.

Parameters:
  • quantized_state_dict (dict) – The input state dictionary with quantized values.

  • maxbound (float) – The maximum bound value for the given quantization format.

  • layers_quant_config (dict/str) – Per-layer quantization format information: a dict for mixed_precision quantization, or a str giving the quantization format for regular quantization.

Returns:

The updated state dictionary with converted scale values.

Return type:

dict
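
The conversion behind this is the standard amax-to-scale relation, scale = amax / maxbound. A minimal sketch of the key rewrite (the exact output key naming and the per-layer handling driven by layers_quant_config are simplified away; both are assumptions here):

```python
import torch

def amax_keys_to_scales(quantized_state_dict, maxbound):
    # Sketch: replace every "*_amax" tensor with a "*_scale" tensor,
    # using scale = amax / maxbound for a fixed-maxbound format.
    converted = {}
    for key, value in quantized_state_dict.items():
        if key.endswith("_amax"):
            converted[key[: -len("_amax")] + "_scale"] = value.float() / maxbound
        else:
            converted[key] = value
    return converted
```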

filter_output_quantizer(state_dict)

Filters out all output quantizers in the state_dict except for the ones related to the kv_cache.

Parameters:

state_dict (dict) – The full model state_dict.

Returns:

Filtered state_dict with only kv_cache output quantizers.

Return type:

dict
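
A sketch of the filtering, assuming the kv_cache output quantizers live under the key/value projections (the "k_proj"/"v_proj" substrings are an assumption about the state_dict key naming):

```python
def keep_kv_cache_output_quantizers(state_dict):
    # Sketch: drop "output_quantizer" entries unless they belong to the
    # key/value projections, whose output scales feed the kv_cache.
    kept = {}
    for key, value in state_dict.items():
        if "output_quantizer" in key and "k_proj" not in key and "v_proj" not in key:
            continue
        kept[key] = value
    return kept
```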

from_quantized_weight(weight, weights_scaling_factor, quantization, torch_dtype)

Converts the quantized weight to the target torch_dtype format.

Parameters:
  • weight (Tensor) –

  • weights_scaling_factor (Tensor) –

  • quantization (str) –

  • torch_dtype (dtype) –
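
In the common symmetric per-channel case, dequantization is just a cast plus a multiply by the scaling factor. A minimal sketch for an int8 weight (the per-output-channel scale shape is an assumption):

```python
import torch

def dequantize_int8(weight, weights_scaling_factor, torch_dtype):
    # Sketch: symmetric dequantization, w_fp = w_q * scale, with one
    # scale per output channel broadcast across each row.
    return weight.to(torch_dtype) * weights_scaling_factor.to(torch_dtype).view(-1, 1)
```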

get_activation_scaling_factor(module)

Returns the activation scaling factor.

Parameters:

module (Module) –

Return type:

Tensor

get_kv_cache_dtype(modules)

Returns the kv_cache dtype.

If num_bits of the output_quantizer is (4, 3), returns FP8; if it is 8, returns int8; otherwise returns None.

Parameters:

modules (Union[List[nn.Module], nn.Module]) – The module or list of modules to inspect.

Returns:

The kv_cache dtype.

Return type:

str
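
The mapping described above is small enough to sketch directly; per the description, a num_bits of (4, 3) encodes FP8 and 8 encodes int8:

```python
def kv_cache_dtype_from_num_bits(num_bits):
    # (4, 3) means 4 exponent / 3 mantissa bits, i.e. FP8 E4M3.
    if num_bits == (4, 3):
        return "FP8"
    if num_bits == 8:
        return "int8"
    return None
```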

get_kv_cache_scaling_factor(qkv_modules)

Returns the kv_cache scaling factor if the output quantizer is set; otherwise returns None.

Parameters:

qkv_modules (List[Module]) –

Return type:

Tensor

get_prequant_scaling_factor(module, dtype)

Returns the prequant scaling factor.

Parameters:
  • module (Module) –

  • dtype (dtype) –

Return type:

Tensor

get_qkv_and_avg_prequant_scale(module, dtype)

Gets the qkv and average prequant scaling factors for the module.

Parameters:
  • module – The module containing q, k, and v submodules.

  • dtype – The data type for the scaling factors.

Returns:

A tuple containing the average prequant scaling factor and the individual scaling factors for q, k, and v.

Return type:

tuple
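
A plausible sketch of the computation, assuming q/k/v live in q_proj/k_proj/v_proj submodules and reusing this module's get_prequant_scaling_factor (the submodule names are assumptions):

```python
def qkv_and_avg_prequant_scale(module, dtype):
    # Sketch: fetch each projection's prequant scale, then average them
    # so a fused QKV projection can share one smoothing vector.
    q = get_prequant_scaling_factor(module.q_proj, dtype)
    k = get_prequant_scaling_factor(module.k_proj, dtype)
    v = get_prequant_scaling_factor(module.v_proj, dtype)
    avg = (q + k + v) / 3.0
    return avg, q, k, v
```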

get_quantization_format(module)

Gets the quantization string.

Gets the quantization string by iterating through the module and its children. The first non-None quantization string is returned.

Return type:

str | None

get_scaling_factor(quantizer)

Returns scaling factor from the quantizer as torch.Tensor.

Parameters:

quantizer (TensorQuantizer) –

Return type:

Tensor

get_weight_block_size(module)

Returns the weight block size.

Parameters:

module (Module) –

Return type:

int

get_weight_scaling_factor(module)

Returns the weight scaling factor.

Parameters:

module (Module) –

Return type:

Tensor

get_weight_scaling_factor_2(module)

Returns the secondary weight scaling factor.

Parameters:

module (Module) –

Return type:

Tensor

get_weights_scaling_factor_and_amax(weight, group_size)

Calculates the weight scaling factors for a given group size.
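
Group-wise scaling typically reshapes each weight row into group_size chunks, takes the absolute max per chunk as amax, and derives the scale from the format's maxbound. A minimal sketch (the maxbound of 7 for symmetric int4 is an illustrative assumption; in_dim must divide evenly by group_size):

```python
import torch

def groupwise_scale_and_amax(weight, group_size, maxbound=7.0):
    # Sketch: [out, in] -> [out, in // group_size, group_size], then
    # per-group amax and scale = amax / maxbound.
    out_dim, in_dim = weight.shape
    grouped = weight.reshape(out_dim, in_dim // group_size, group_size)
    amax = grouped.abs().amax(dim=-1)
    return amax / maxbound, amax
```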

process_layer_quant_config(layer_config_dict)

Processes per layer quantization information for TRTLLM export to quant_cfg.json.

resmooth_and_get_scale_and_amax(merged_weights, pre_quant_scales, ranks, group_size, avg_pre_quant_scale=None, quantization=None)

Resmooths weights from one or multiple ranks and gets scaling factors and amax.

Parameters:
  • merged_weights (Tensor) – Merged weights from ranks.

  • pre_quant_scales (List[Tensor]) – List of pre-quantization scales for each rank.

  • ranks (int) – Number of ranks.

  • group_size (int) – Group size of the quantization block.

  • avg_pre_quant_scale (optional) – If not provided, weights will be resmoothed using the average of pre_quant_scales.

  • quantization (str | None) –

Returns:

  • weights – Resmoothed weights.

  • weight_scaling_factors – Resmoothed scaling factors.

  • avg_pre_quant_scale – Calculated average of the pre-quantization scales.

  • amaxes – Amax values for the weights.

Return type:

tuple
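
In smoothing schemes of this kind, each rank's weight shard was divided by that rank's own pre_quant_scale; resmoothing undoes the per-rank scale and applies the shared average instead, i.e. w'' = w' * (s_rank / s_avg). One plausible formulation, assuming rank shards are concatenated along the output dimension (both the orientation and the sharding axis are assumptions about the library's convention):

```python
import torch

def resmooth(rank_weights, pre_quant_scales, avg_pre_quant_scale=None):
    # Sketch: undo each rank's own smoothing scale, then re-apply the
    # shared average so every rank agrees on one pre_quant_scale.
    if avg_pre_quant_scale is None:
        avg_pre_quant_scale = torch.stack(pre_quant_scales).mean(dim=0)
    resmoothed = [
        w * (s / avg_pre_quant_scale) for w, s in zip(rank_weights, pre_quant_scales)
    ]
    return torch.cat(resmoothed, dim=0), avg_pre_quant_scale
```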

to_quantized_weight(weight, weights_scaling_factor, quantization, weights_scaling_factor2=None, block_size=None)

Converts the weight to the quantized (packed) format.

Parameters:
  • weight (Tensor) –

  • weights_scaling_factor (Tensor) –

  • quantization (str) –

  • weights_scaling_factor2 (Tensor | None) –

  • block_size (int | None) –
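
This is the inverse of from_quantized_weight; for a symmetric int8 format the packing reduces to scale, round, clamp, and cast. A minimal sketch (the per-output-channel scale shape is an assumption, and real packed formats such as int4 additionally pack two values per byte):

```python
import torch

def quantize_int8(weight, weights_scaling_factor):
    # Sketch: symmetric int8 quantization, w_q = clamp(round(w / scale)),
    # with one scale per output channel broadcast across each row.
    scaled = weight / weights_scaling_factor.view(-1, 1)
    return scaled.round().clamp(-128, 127).to(torch.int8)
```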