quant_utils

Utilities for quantization, including scaling factor adjustments.

Functions

adjust_attn_amax_values

Adjusts the amax values for the attention layers.

all_items_same

Checks if all elements in the provided list are the same.

from_quantized_weight

Converts the quantized weight to the target torch_dtype format.

fuse_prequant_layernorm

Scales layernorm weights with the avg_pre_quant_scale of the modules list and marks the pre_quant_scales for deletion.

get_activation_scaling_factor

Returns the activation scaling factor.

get_kv_cache_bias

Returns the kv_cache bias if _bias_value is set.

get_kv_cache_dtype

Returns the kv_cache dtype.

get_kv_cache_scaling_factor

Returns the kv_cache scaling factor if output quantizer is set.

get_prequant_scaling_factor

Returns the prequant scaling factor.

get_quant_config

Generate quantization config for a torch model.

get_quantization_format

Gets the quantization string.

get_scaling_factor

Returns scaling factor from the quantizer as torch.Tensor.

get_scaling_factor_from_weight

Calculate the weight scaling factor for a given group size.

get_weight_block_size

Returns the weight block size.

get_weight_scaling_factor

Returns the weight scaling factor.

get_weight_scaling_factor_2

Returns the secondary weight scaling factor.

maybe_transpose_expert_weight_dimensions

Transpose the last two dimensions of expert weights.

pack_int4_in_uint8

Packs the INT4 weights into uint8 tensor.

postprocess_state_dict

Filters out keys related to weight quantizers and updates KV cache related keys.

preprocess_linear_fusion

Preprocess the quantized linears that we plan to fuse.

process_layer_quant_config

Processes per layer quantization information for TRTLLM export to quant_cfg.json.

resmooth_and_get_scale

Resmooths weights from a single rank or multiple ranks and gets scaling factors and amax.

to_quantized_weight

Converts the weight to the quantized (packed) format.

adjust_attn_amax_values(module)

Adjusts the amax values for the attention layers.

all_items_same(item_list)

Checks if all elements in the provided list are the same.

from_quantized_weight(weight, weights_scaling_factor, quantization, torch_dtype)

Converts the quantized weight to the target torch_dtype format.

Parameters:
  • weight (Tensor)

  • weights_scaling_factor (Tensor)

  • quantization (str)

fuse_prequant_layernorm(layernorm_module, modules)

Scales layernorm weights with the avg_pre_quant_scale of the modules list and marks the pre_quant_scales for deletion.

Parameters:
  • layernorm_module (Module)

  • modules (list[Tensor])

get_activation_scaling_factor(module, input_quantizer_name='input_quantizer')

Returns the activation scaling factor.

Parameters:
  • module (Module)

  • input_quantizer_name (str)

Return type:

Tensor

get_kv_cache_bias(kv_module)

Returns the kv_cache bias if _bias_value is set; otherwise returns None.

Parameters:

kv_module (Module)

Return type:

list[Tensor]

get_kv_cache_dtype(modules)

Returns the kv_cache dtype.

If num_bits of the output_quantizer is (4, 3), returns FP8; if it is 8, returns int8; otherwise returns None.

Parameters:

modules (list[Module] | Module) – The module or list of modules to inspect.

Returns:

The kv_cache dtype.

Return type:

str | None
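
For orientation, a minimal sketch of the mapping described above; this is illustrative only, not the library's implementation.

  def _kv_cache_dtype_from_num_bits(num_bits):
      # Per the description: (4, 3) -> FP8, 8 -> int8, anything else -> None.
      if num_bits == (4, 3):
          return "FP8"
      if num_bits == 8:
          return "int8"
      return None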

get_kv_cache_scaling_factor(kv_module)

Returns the kv_cache scaling factor if the output quantizer is set; otherwise returns None.

Parameters:

kv_module (Module)

Return type:

list[Tensor]

get_prequant_scaling_factor(module)

Returns the prequant scaling factor.

Parameters:

module (Module)

Return type:

Tensor

get_quant_config(named_modules)

Generate quantization config for a torch model.

Parameters:
  • named_modules (Module | dict[str, Module]) – The PyTorch model or its named modules to analyze

Returns:

Dictionary containing the quantization configuration

Return type:

dict[str, Any]
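
A hedged usage sketch; model is a placeholder for an already-quantized PyTorch model, and per the signature above either the model itself or a dict of its named modules can be passed.

  quant_cfg = get_quant_config(dict(model.named_modules()))
  # quant_cfg is a plain dict describing the model's quantization configuration.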

get_quantization_format(module)

Gets the quantization string.

Gets the quantization string by iterating through the module and its children. The first non-None quantization string is returned.

Return type:

str | None
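
A hedged usage sketch; model is a placeholder and the printed messages are illustrative.

  fmt = get_quantization_format(model)
  if fmt is None:
      print("No quantization detected on this module or its children.")
  else:
      print(f"Detected quantization format: {fmt}")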

get_scaling_factor(quantizer)

Returns scaling factor from the quantizer as torch.Tensor.

Parameters:

quantizer (TensorQuantizer)

Return type:

Tensor

get_scaling_factor_from_weight(weight, group_size)

Calculate the weight scaling factor for a given group size.

Return type:

Tensor
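
As a rough illustration of per-group reduction (grouping along the input dimension is an assumption here, and the division by the quantized dtype's max bound that a real scaling factor needs is omitted):

  import torch

  def _per_group_amax(weight: torch.Tensor, group_size: int) -> torch.Tensor:
      # View (out_dim, in_dim) as (out_dim, in_dim // group_size, group_size)
      # and take the absolute max over each group.
      out_dim, in_dim = weight.shape
      grouped = weight.reshape(out_dim, in_dim // group_size, group_size)
      return grouped.abs().amax(dim=-1)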

get_weight_block_size(module, weight_name='weight')

Returns the weight block size.

Parameters:
  • module (Module)

  • weight_name (str)

Return type:

int

get_weight_scaling_factor(module, weight_name='weight')

Returns the weight scaling factor.

Parameters:
  • module (Module)

  • weight_name (str)

Return type:

Tensor

get_weight_scaling_factor_2(module, weight_name='weight')

Returns the secondary weight scaling factor.

Parameters:
  • module (Module)

  • weight_name (str)

Return type:

Tensor

maybe_transpose_expert_weight_dimensions(weight, weight_scale=None, is_bmm_expert_weight=True)

Transpose the last two dimensions of expert weights.

This function transposes expert weights between the two layouts: (num_experts, input_dim, output_dim) ↔ (num_experts, output_dim, input_dim).

Since transpose(-2, -1) is self-inverse, this function can be used for both forward and backward transformations. This is needed for quantization functions that expect the last dimension to be the input dimension for block quantization. Specifically used for bmm-style expert weights in models like llama4 and gpt-oss.

Parameters:
  • weight (Tensor) – The weight tensor to transpose. Expected shape for experts: (num_experts, dim1, dim2)

  • weight_scale (Tensor | None) – Optional weight scaling factor tensor to transpose alongside weight

  • is_bmm_expert_weight (bool) – Whether this is an expert weight (3D tensor) that needs transposition

Returns:

Tuple of (transposed_weight, transposed_weight_scale)

Return type:

tuple[Tensor, Tensor | None]
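
A hedged usage sketch with made-up shapes (8 experts, a 4096-to-14336 projection); it only illustrates the layout swap and the self-inverse property noted above.

  import torch

  w = torch.randn(8, 4096, 14336)  # (num_experts, input_dim, output_dim)
  w_t, _ = maybe_transpose_expert_weight_dimensions(w, is_bmm_expert_weight=True)
  assert w_t.shape == (8, 14336, 4096)
  # transpose(-2, -1) is self-inverse, so applying it again restores the layout.
  w_back, _ = maybe_transpose_expert_weight_dimensions(w_t, is_bmm_expert_weight=True)
  assert w_back.shape == w.shape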

pack_int4_in_uint8(weight, weights_scaling_factor)

Packs the INT4 weights into uint8 tensor.
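
To illustrate the idea only (the library's actual bit layout, nibble order, and use of the scaling factor are not documented here and may differ), two signed 4-bit values can share a single uint8:

  import torch

  def _pack_two_int4(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
      # Keep the low nibble of each signed int4 value in [-8, 7], then place
      # one value in the low nibble and one in the high nibble of a uint8.
      # The nibble order here is an arbitrary choice.
      lo_u = (lo & 0x0F).to(torch.uint8)
      hi_u = (hi & 0x0F).to(torch.uint8)
      return lo_u | (hi_u << 4)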

postprocess_state_dict(state_dict, maxbound, quantization)

Filters out keys related to weight quantizers and updates KV cache related keys.

Parameters:
  • state_dict (dict) – The full model state_dict.

  • maxbound (float) – The maximum bound value for the output quantizer.

  • quantization (str | None) – The KV cache quantization format.

Returns:

The filtered state_dict without unnecessary keys like '_amax' and non-KV-cache output quantizers.

Return type:

dict
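
A hedged usage sketch; maxbound=448.0 matches the FP8 E4M3 maximum but is only an example, and "fp8" stands in for whichever KV-cache quantization string the export flow actually uses.

  state_dict = model.state_dict()
  state_dict = postprocess_state_dict(state_dict, maxbound=448.0, quantization="fp8")
  # Weight-quantizer keys such as '_amax' are dropped and KV cache related keys
  # are rewritten, per the description above.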

preprocess_linear_fusion(modules, resmooth_only=False)

Preprocess the quantized linears that we plan to fuse.

Use resmooth_only for MOE experts as each individual expert is not fused.

Parameters:

modules (list[Module])
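
A hedged usage sketch; the attribute names (q_proj, k_proj, v_proj, expert_linears) are placeholders for whatever quantized linears a given model exposes.

  # Linears that will be fused into one are preprocessed together.
  preprocess_linear_fusion([attn.q_proj, attn.k_proj, attn.v_proj])

  # MoE experts are only resmoothed, since each individual expert is not fused.
  preprocess_linear_fusion(expert_linears, resmooth_only=True)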

process_layer_quant_config(layer_config_dict)

Processes per layer quantization information for TRTLLM export to quant_cfg.json.

resmooth_and_get_scale(merged_weights, pre_quant_scales, ranks, group_size, new_pre_quant_scale=None, quantization=None)

Resmooths weights from a single rank or multiple ranks and gets scaling factors and amax.

Parameters:
  • merged_weights (Tensor) – Merged weights from ranks.

  • pre_quant_scales (list[Tensor]) – List of pre-quantization scales for each rank.

  • ranks (int) – Number of ranks.

  • group_size (int) – Group size of the quantization block.

  • new_pre_quant_scale (optional) – If not provided, weights will be resmoothed using the average of pre_quant_scales.

  • quantization (str | None)

Returns:

  • weights – Resmoothed weights.

  • weight_scaling_factors – Resmoothed scaling factors.

  • avg_pre_quant_scale – Calculated average of the pre-quantization scales.
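
A hedged usage sketch, assuming the three values listed above are returned as a tuple in that order; the tensors, the two-rank setup, and the group size of 128 are placeholders.

  weights, weight_scaling_factors, avg_pre_quant_scale = resmooth_and_get_scale(
      merged_weights=merged_w,                    # weights merged from both ranks
      pre_quant_scales=[scale_rank0, scale_rank1],
      ranks=2,
      group_size=128,
  )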

to_quantized_weight(weight, weights_scaling_factor, quantization, weights_scaling_factor2=None, block_size=None)

Converts the weight to the quantized (packed) format.

Parameters:
  • weight (Tensor)

  • weights_scaling_factor (Tensor)

  • quantization (str)

  • weights_scaling_factor2 (Tensor | None)

  • block_size (int | None)
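
To close, a hedged round-trip sketch pairing to_quantized_weight with from_quantized_weight; "fp8" is a placeholder quantization string and weight / weights_scaling_factor are assumed to already exist.

  import torch

  qweight = to_quantized_weight(weight, weights_scaling_factor, quantization="fp8")
  restored = from_quantized_weight(qweight, weights_scaling_factor,
                                   quantization="fp8", torch_dtype=torch.float16)
  # restored approximates the original weight up to quantization error.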