scaling_factor_utils

Utilities for scaling factor adjustments.

Functions

adjust_attn_amax_values

Adjusts the amax values for the attention layers.

convert_state_dict_amax_to_scales

Convert _amax keys in a quantized state dictionary to scale values and update the state dictionary accordingly.

get_weights_scaling_factor_and_amax

Calculate the weight scaling factors and amax values for a given group size.

resmooth_and_get_scale_and_amax

Resmooths weights from one or more ranks and gets the scaling factors and amax values.

adjust_attn_amax_values(module)

Adjusts the amax values for the attention layers.

convert_state_dict_amax_to_scales(quantized_state_dict, maxbound, layers_quant_config)

Convert _amax keys in a quantized state dictionary to scale values and update the state dictionary accordingly.

Parameters:
  • quantized_state_dict (dict) – The input state dictionary with quantized values.

  • maxbound (float) – The maximum bound value for the given quantization format.

  • layers_quant_config (dict/str) – Dictionary containing per-layer quantization format information for mixed_precision, or a str containing the quantization format for regular quantization.

Returns:

The updated state dictionary with converted scale values.

Return type:

dict
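The conversion can be pictured with a minimal sketch that uses plain floats in place of tensors. It assumes the common convention scale = amax / maxbound and a single quantization format (the per-layer mixed_precision case of layers_quant_config is ignored). The helper name amax_state_dict_to_scales and the "_amax" to "_scale" key renaming are hypothetical and are not the library's actual key names.

```python
from typing import Dict


def amax_state_dict_to_scales(state_dict: Dict[str, float], maxbound: float) -> Dict[str, float]:
    """Hypothetical helper: replace every '<name>_amax' entry with '<name>_scale'.

    Assumes scale = amax / maxbound for every key and a single quantization
    format; the '_scale' key suffix is illustrative, not the library's naming.
    """
    converted: Dict[str, float] = {}
    for key, value in state_dict.items():
        if key.endswith("_amax"):
            # Map the observed range [-amax, amax] onto [-maxbound, maxbound].
            converted[key[: -len("_amax")] + "_scale"] = value / maxbound
        else:
            # Non-amax entries (weights, biases, ...) pass through unchanged.
            converted[key] = value
    return converted


state = {"layer0.weight_amax": 896.0, "layer0.input_amax": 224.0, "layer0.bias": 0.5}
# 448.0 is the largest finite FP8 E4M3 value.
print(amax_state_dict_to_scales(state, maxbound=448.0))
# {'layer0.weight_scale': 2.0, 'layer0.input_scale': 0.5, 'layer0.bias': 0.5}
```

Returning a new dictionary rather than mutating the input keeps the sketch easy to test; the real function is documented as updating the state dictionary it receives.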

get_weights_scaling_factor_and_amax(weight, group_size)

Calculate the weight scaling factors and amax values for a given group size.
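A minimal sketch of group-wise scaling, in plain Python lists rather than tensors: each row of an [out_features][in_features] weight is split into consecutive blocks of group_size input elements, and each block gets amax = max |w| and scale = amax / maxbound. The helper name group_scales_and_amax and the default maxbound of 7.0 (the largest positive signed INT4 value) are assumptions for illustration, not the library's implementation.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def group_scales_and_amax(weight: Matrix, group_size: int, maxbound: float = 7.0) -> Tuple[Matrix, Matrix]:
    """Hypothetical helper: per-group scaling factors and amax for a 2-D weight.

    weight is row-major [out_features][in_features]; each row is cut into
    consecutive groups of group_size elements. maxbound=7.0 assumes signed INT4.
    Returns (scales, amaxes), each shaped [out_features][in_features // group_size].
    """
    scales: Matrix = []
    amaxes: Matrix = []
    for row in weight:
        if len(row) % group_size != 0:
            raise ValueError("in_features must be divisible by group_size")
        # The amax of a group is its largest absolute value.
        row_amax = [max(abs(w) for w in row[s:s + group_size]) for s in range(0, len(row), group_size)]
        amaxes.append(row_amax)
        # The scale maps each group's [-amax, amax] onto [-maxbound, maxbound].
        scales.append([a / maxbound for a in row_amax])
    return scales, amaxes


w = [[0.875, -3.5, 7.0, 1.75],
     [-14.0, 0.0, 0.4375, -0.875]]
scales, amaxes = group_scales_and_amax(w, group_size=2)
print(amaxes)  # [[3.5, 7.0], [14.0, 0.875]]
print(scales)  # [[0.5, 1.0], [2.0, 0.125]]
```

An all-zero group yields a zero scale, which a real implementation has to guard against before dividing by it during quantization.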

resmooth_and_get_scale_and_amax(merged_weights, pre_quant_scales, ranks, group_size, avg_pre_quant_scale=None)

Resmooths weights from one or more ranks and gets the scaling factors and amax values.

Parameters:
  • merged_weights (Tensor) – Merged weights from ranks.

  • pre_quant_scales (List[Tensor]) – List of pre-quantization scales for each rank.

  • ranks (int) – Number of ranks.

  • group_size (int) – Group size of the quantization block.

  • avg_pre_quant_scale (optional) – The average pre-quantization scale to resmooth the weights with. If not provided, weights are resmoothed using the average of pre_quant_scales.

Returns:

  • weights – Resmoothed weights.

  • weight_scaling_factors – Resmoothed scaling factors.

  • avg_pre_quant_scale – Calculated average of the pre-quantization scales.

  • amaxes – Amax values for the weights.

Return type:

tuple
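Resmoothing can be sketched in plain Python under stated assumptions: the rank shards are stacked along the output dimension of merged_weights in equal chunks, and each rank computes y = (x * p_i) @ W_i^T with a per-input-channel pre-quantization scale p_i. Replacing p_i with a shared average p_avg leaves y unchanged if column j of W_i is multiplied by p_i[j] / p_avg[j]. The group-wise scales are then recomputed on the resmoothed weights. The function name resmooth_sketch and the maxbound default of 7.0 (signed INT4) are hypothetical, and this is not the library's code.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def resmooth_sketch(
    merged_weights: Matrix,
    pre_quant_scales: List[List[float]],
    ranks: int,
    group_size: int,
    avg_pre_quant_scale: Optional[List[float]] = None,
    maxbound: float = 7.0,
) -> Tuple[Matrix, Matrix, List[float], Matrix]:
    """Hypothetical sketch: move every rank shard onto one shared pre_quant_scale.

    Assumes merged_weights is [out_features][in_features] with equal-sized
    rank shards stacked along out_features, and that rank i computes
    y = (x * p_i) @ W_i^T with a per-input-channel scale p_i.
    maxbound=7.0 assumes signed INT4 weights.
    """
    rows_per_rank = len(merged_weights) // ranks
    in_features = len(merged_weights[0])
    if avg_pre_quant_scale is None:
        # Element-wise mean of the per-rank scales.
        avg_pre_quant_scale = [sum(p[j] for p in pre_quant_scales) / ranks for j in range(in_features)]

    weights: Matrix = []
    for rank in range(ranks):
        p_i = pre_quant_scales[rank]
        shard = merged_weights[rank * rows_per_rank:(rank + 1) * rows_per_rank]
        for row in shard:
            # (x * p_avg) @ W'^T equals (x * p_i) @ W_i^T when W'[:, j] = W_i[:, j] * p_i[j] / p_avg[j].
            weights.append([w * p_i[j] / avg_pre_quant_scale[j] for j, w in enumerate(row)])

    # Group-wise amax and scale = amax / maxbound on the resmoothed weights.
    amaxes = [
        [max(abs(w) for w in row[s:s + group_size]) for s in range(0, in_features, group_size)]
        for row in weights
    ]
    scales = [[a / maxbound for a in row] for row in amaxes]
    return weights, scales, avg_pre_quant_scale, amaxes


weights, scales, avg, amaxes = resmooth_sketch(
    [[14.0, 7.0], [7.0, -14.0]],  # rank 0 owns row 0, rank 1 owns row 1
    [[1.0, 2.0], [3.0, 2.0]],     # per-rank pre_quant_scales
    ranks=2,
    group_size=2,
)
print(weights)  # [[7.0, 7.0], [10.5, -14.0]]
print(avg)      # [2.0, 2.0]
print(scales)   # [[1.0], [2.0]]
```

With a single rank and no avg_pre_quant_scale, the average equals that rank's own scale, so the weights come back unchanged and only the scales and amax values are computed.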