quantization_utils
Utilities for quantization, including scaling factor adjustments.
Functions

| Function | Description |
|---|---|
| adjust_attn_amax_values | Adjusts the amax values for the attention layers. |
| all_items_same | Checks if all elements in the provided list are the same. |
| convert_state_dict_amax_to_scales | Converts _amax keys in a quantized state dictionary to scale values and updates the state dictionary accordingly. |
| filter_output_quantizer | Filters out all output quantizers in the state_dict except for the ones related to the kv_cache. |
| from_quantized_weight | Converts the quantized weight to the target torch_dtype format. |
| get_activation_scaling_factor | Returns the activation scaling factor. |
| get_kv_cache_dtype | Returns the kv_cache dtype. |
| get_kv_cache_scaling_factor | Returns the kv_cache scaling factor if the output quantizer is set. |
| get_prequant_scaling_factor | Returns the prequant scaling factor. |
| get_qkv_and_avg_prequant_scale | Gets the qkv and average prequant scaling factors for the module. |
| get_quantization_format | Gets the quantization string. |
| get_scaling_factor | Returns the scaling factor from the quantizer as a torch.Tensor. |
| get_weight_block_size | Returns the weight block size. |
| get_weight_scaling_factor | Returns the weight scaling factor. |
| get_weight_scaling_factor_2 | Returns the secondary weight scaling factor. |
| get_weights_scaling_factor_and_amax | Calculates the weight scaling factors for a given group size. |
| process_layer_quant_config | Processes per-layer quantization information for TRTLLM export to quant_cfg.json. |
| resmooth_and_get_scale_and_amax | Resmooths weights from a single or multiple ranks and gets scaling factors and amax. |
| to_quantized_weight | Converts the weight to the quantized (packed) format. |
- adjust_attn_amax_values(module)
Adjusts the amax values for the attention layers.
- all_items_same(item_list)
Checks if all elements in the provided list are the same.
- convert_state_dict_amax_to_scales(quantized_state_dict, maxbound, layers_quant_config)
Converts _amax keys in a quantized state dictionary to scale values and updates the state dictionary accordingly.
- Parameters:
quantized_state_dict (dict) – The input state dictionary with quantized values.
maxbound (float) – The maximum bound value for the given quantization format.
layers_quant_config (dict | str) – Per-layer quantization format information: a dict mapping layers to quantization formats for mixed_precision quantization, or a str containing the quantization format for regular quantization.
- Returns:
The updated state dictionary with converted scale values.
- Return type:
dict
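As a rough sketch of the conversion this performs (the key names and maxbound value below are illustrative assumptions, not the library's actual keys):

```python
import torch

# Hypothetical sketch: convert each "_amax" entry into a "_scale" entry.
# FP8 E4M3 is assumed here, where maxbound is 448.0.
maxbound = 448.0
state_dict = {"layer0.weight_quantizer._amax": torch.tensor(2.0)}

converted = {}
for key, value in state_dict.items():
    if key.endswith("._amax"):
        # the scale maps the observed absolute max onto the format's max bound
        converted[key.replace("._amax", "._scale")] = value / maxbound
    else:
        converted[key] = value
```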
- filter_output_quantizer(state_dict)
Filters out all output quantizers in the state_dict except for the ones related to the kv_cache.
- Parameters:
state_dict (dict) – The full model state_dict.
- Returns:
Filtered state_dict with only kv_cache output quantizers.
- Return type:
dict
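A minimal sketch of the filtering idea follows; the k_proj/v_proj markers are assumptions about the model's layer naming, not something the library guarantees:

```python
# Keep only output_quantizer entries that belong to KV-cache projections.
def filter_output_quantizer_sketch(state_dict):
    kv_markers = ("k_proj", "v_proj")  # assumed KV projection names
    return {
        key: value
        for key, value in state_dict.items()
        if "output_quantizer" not in key or any(m in key for m in kv_markers)
    }
```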
- from_quantized_weight(weight, weights_scaling_factor, quantization, torch_dtype)
Converts the quantized weight to the target torch_dtype format.
- Parameters:
weight (Tensor) –
weights_scaling_factor (Tensor) –
quantization (str) –
torch_dtype (dtype) –
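For intuition, a minimal dequantization sketch for the per-tensor INT8 case (the real function also handles other formats such as FP8 and packed INT4):

```python
import torch

def from_quantized_weight_sketch(weight, weights_scaling_factor, torch_dtype):
    # cast up first, then rescale back to the original dynamic range
    return weight.to(torch_dtype) * weights_scaling_factor.to(torch_dtype)

qweight = torch.randint(-128, 128, (16, 16), dtype=torch.int8)
scale = torch.tensor(0.01)
dequant = from_quantized_weight_sketch(qweight, scale, torch.float16)
```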
- get_activation_scaling_factor(module)
Returns the activation scaling factor.
- Parameters:
module (Module) –
- Return type:
Tensor
- get_kv_cache_dtype(modules)
Returns the kv_cache dtype.
If the num_bits of the output_quantizer is (4, 3), returns FP8; if it is 8, returns int8; otherwise returns None.
- Parameters:
modules (Union[List[nn.Module], nn.Module]) – The module or list of modules to inspect.
- Returns:
The kv_cache dtype.
- Return type:
str
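The mapping described above amounts to the following sketch (the helper is hypothetical; the real function inspects the modules' output quantizers):

```python
def kv_cache_dtype_sketch(num_bits):
    if num_bits == (4, 3):  # FP8 E4M3: 4 exponent bits, 3 mantissa bits
        return "FP8"
    if num_bits == 8:
        return "int8"
    return None
```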
- get_kv_cache_scaling_factor(qkv_modules)
Returns the kv_cache scaling factor if the output quantizer is set; otherwise returns None.
- Parameters:
qkv_modules (List[Module]) –
- Return type:
Tensor
- get_prequant_scaling_factor(module, dtype)
Returns the prequant scaling factor.
- Parameters:
module (Module) –
dtype (dtype) –
- Return type:
Tensor
- get_qkv_and_avg_prequant_scale(module, dtype)
Gets the qkv and average prequant scaling factors for the module.
- Parameters:
module – The module containing q, k, and v submodules.
dtype – The data type for the scaling factors.
- Returns:
A tuple containing the average prequant scaling factor and the individual scaling factors for q, k, and v.
- Return type:
tuple
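A minimal sketch of the averaging step, assuming the average is taken elementwise over the q/k/v prequant scales so all three projections can share one smoothing vector:

```python
import torch

q_scale, k_scale, v_scale = torch.rand(64), torch.rand(64), torch.rand(64)

# shared scale for the fused qkv projection
avg_prequant_scale = (q_scale + k_scale + v_scale) / 3.0
```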
- get_quantization_format(module)
Gets the quantization string.
Gets the quantization string by iterating through the module and its children. The first non-None quantization string is returned.
- Return type:
str | None
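A sketch of the traversal (get_format_for stands in for the per-module format lookup and is hypothetical):

```python
def get_quantization_format_sketch(module, get_format_for):
    # nn.Module.modules() yields the module itself first, then its children
    for m in module.modules():
        fmt = get_format_for(m)
        if fmt is not None:
            return fmt
    return None
```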
- get_scaling_factor(quantizer)
Returns the scaling factor from the quantizer as a torch.Tensor.
- Parameters:
quantizer (TensorQuantizer) –
- Return type:
Tensor
- get_weight_block_size(module)
Returns the weight block size.
- Parameters:
module (Module) –
- Return type:
int
- get_weight_scaling_factor(module)
Returns the weight scaling factor.
- Parameters:
module (Module) –
- Return type:
Tensor
- get_weight_scaling_factor_2(module)
Returns the secondary weight scaling factor.
- Parameters:
module (Module) –
- Return type:
Tensor
- get_weights_scaling_factor_and_amax(weight, group_size)
Calculates the weight scaling factors for a given group size.
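A minimal sketch of per-group scale computation, assuming signed INT4 (maxbound = 7) and groups running along the input-channel dimension:

```python
import torch

def weights_scaling_factor_and_amax_sketch(weight, group_size, maxbound=7.0):
    out_ch, in_ch = weight.shape
    grouped = weight.reshape(out_ch, in_ch // group_size, group_size)
    amax = grouped.abs().amax(dim=-1)  # one amax per group
    scale = amax.float() / maxbound    # one scale per group
    return scale, amax

scale, amax = weights_scaling_factor_and_amax_sketch(torch.randn(16, 128), 64)
```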
- process_layer_quant_config(layer_config_dict)
Processes per-layer quantization information for TRTLLM export to quant_cfg.json.
- resmooth_and_get_scale_and_amax(merged_weights, pre_quant_scales, ranks, group_size, avg_pre_quant_scale=None, quantization=None)
Resmooths weights from a single or multiple ranks and gets scaling factors and amax.
- Parameters:
merged_weights (Tensor) – Merged weights from ranks.
pre_quant_scales (List[Tensor]) – List of pre-quantization scales for each rank.
ranks (int) – Number of ranks.
group_size (int) – Group size of the quantization block.
avg_pre_quant_scale (optional) – If not provided, weights will be resmoothed using the average of pre_quant_scales.
quantization (str | None) –
- Returns:
weights – Resmoothed weights.
weight_scaling_factors – Resmoothed scaling factors.
avg_pre_quant_scale – Calculated average of the pre-quantization scales.
amaxes – Amax values for the weights.
- Return type:
tuple
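The core re-smoothing step can be sketched as below, under the assumption that the input is multiplied elementwise by pre_quant_scale and the weight is divided by it along the input dimension; the real function also recomputes scaling factors and amax:

```python
import torch

def resmooth_sketch(weight, old_scale, new_scale):
    # undo the old smoothing and apply the new one:
    # W_new = W_old * old_scale / new_scale (broadcast over input channels)
    return weight * (old_scale / new_scale)

w = torch.randn(16, 64)
pre_scales = [torch.rand(64) + 0.5, torch.rand(64) + 0.5]
avg_scale = torch.stack(pre_scales).mean(dim=0)
w_resmoothed = resmooth_sketch(w, pre_scales[0], avg_scale)
```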
- to_quantized_weight(weight, weights_scaling_factor, quantization, weights_scaling_factor2=None, block_size=None)
Converts the weight to the quantized (packed) format.
- Parameters:
weight (Tensor) –
weights_scaling_factor (Tensor) –
quantization (str) –
weights_scaling_factor2 (Tensor | None) –
block_size (int | None) –
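For intuition, a per-tensor INT8 quantization sketch (the real function also supports grouped scales, the secondary scaling factor, and packed INT4 formats):

```python
import torch

def to_quantized_weight_sketch(weight, weights_scaling_factor):
    q = torch.round(weight / weights_scaling_factor)
    return q.clamp(-128, 127).to(torch.int8)

weight = torch.randn(16, 16)
scale = weight.abs().max() / 127.0
qweight = to_quantized_weight_sketch(weight, scale)
```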