utils

Quantization utilities.

Classes

SharedQuantState

Base class for shared quantization state owned by a group parent.

SharedWeightGlobalAmaxState

Canonical shared weight global_amax for one fusible sibling group.

Functions

convert_quantization_axis_to_reduce_axis

Convert the quantization axis to the reduce axis.

export_torch_mode

Context manager enabling the export mode.

find_shared_input_groups

Find fusible sibling groups by regex over module FQNs; capture groups define the key.

is_quantized

Check if a module is quantized.

is_quantized_column_parallel_linear

Check if a module is a quantized column parallel linear module.

is_quantized_linear

Check if a module is a quantized linear module.

is_quantized_row_parallel_linear

Check if a module is a quantized row parallel linear module.

iter_shared_quant_states

Yield shared quant states owned within model.

reduce_amax

Compute the absolute maximum value of a tensor.

reduce_sum

Compute the sum of a tensor along specified axes.

replace_function

Replace a function with a new one within a context.

representative_weight_quantizer

Return the representative weight quantizer for weight_name on module.

update_quant_cfg_with_kv_cache_quant

Update the quant_cfg with the kv cache quant_cfg.

weight_attr_names

Get the weight param attribute names in a converted module, non-recursive.

class SharedQuantState

Bases: Module, ABC

Base class for shared quantization state owned by a group parent.

Subclasses define when and how their canonical state is initialized. Runtime states can override install_hooks() to compute/cache group-level values at the parent instead of doing the same work in every member.

__init__()

Initialize an empty shared-state owner.

Return type:

None

classmethod attach(model, patterns=None)

Create this state on each discovered group’s parent.

Parameters:
  • model (Module)

  • patterns (Sequence[str] | None)

Return type:

int

default_patterns: ClassVar[tuple[str, ...]] = ()

finalize()

Return whether the managed buffer(s) are populated. This default implementation is the finalize that hook-produced states inherit.

A state whose value is produced during the forward pass (e.g. the shared input-amax state’s parent hook) needs only this readiness gate, because the value already exists by the time finalize runs. States that produce their value on demand override it: the weight state aggregates the members’ _amax, and SVDQuant runs an SVD. populate() skips a state whose finalize returns False (uncalibrated, meta, or the forward pass never ran).

Return type:

bool

install_hooks()

Install parent/member hooks for runtime shared computation.

Return type:

None

managed_attrs: ClassVar[tuple[str, ...]] = ()

property members: tuple[Module, ...]

Return linked member modules.

classmethod metadata(model)

Return restore metadata for this state when present in model.

Parameters:

model (Module)

Return type:

dict[str, bool]

name: ClassVar[str]

classmethod populate(model)

Finalize and sync every state of this type in model; return the count populated.

Parameters:

model (Module)

Return type:

int

remove_hooks()

Remove hooks installed by install_hooks().

Return type:

None

classmethod resolve_patterns(shared_states=None)

Resolve the max-calibration config into grouping patterns for this state.

Parameters:

shared_states (Mapping[str, Mapping[str, Sequence[str]]] | None)

Return type:

list[str]

classmethod restore(model, patterns=None)

Re-attach states and rebuild member aliases from members’ restored buffers.

Parameters:
  • model (Module)

  • patterns (Sequence[str] | None)

Return type:

None

restore_from_members()

Rebuild the canonical buffer from members’ restored buffers and re-tie.

Used only on checkpoint restore: the state is non-persistent, so it is absent until rebuilt here from the members’ (persistent, just-loaded) buffers.

Return type:

bool

set_members(parent, members)

Set the owning parent and linked member modules.

Parameters:
  • parent (Module)

  • members (Sequence[Module])

Return type:

None

abstract sync(parallel_state=None)

Synchronize canonical state across distributed process groups.

Parameters:

parallel_state (ParallelState | None)

Return type:

None

target_quantizer_kind: ClassVar[str] = 'weight'

tie_member_quantizer(quantizer)

Alias a member quantizer’s managed buffers to this state’s canonical buffers.

For each managed attr, point quantizer._buffers[attr] at the same tensor object as self.<attr> (register it if absent, else replace) so the member and the state share one storage, not a copy. Records the attr in the quantizer’s _shared_quant_tied_attrs so TensorQuantizer.__setattr__ rejects a later rebind. Returns whether anything was tied.

Parameters:

quantizer (Module)

Return type:

bool

tie_member_quantizers()

Tie all eligible member quantizers to the canonical state buffers.

Return type:

None

class SharedWeightGlobalAmaxState

Bases: SharedQuantState

Canonical shared weight global_amax for one fusible sibling group.

__init__()

Initialize with an unset canonical global_amax buffer.

Return type:

None

default_patterns: ClassVar[tuple[str, ...]] = ('(?:(.*)\\.)?(?:q_proj|k_proj|v_proj)', '(?:(.*)\\.)?(?:gate_proj|up_proj)', '(?:(.*)\\.)?(?:w1|w3)')

finalize()

Set global_amax to the max over members’ calibrated _amax.

Return type:

bool

property global_amax

Return the canonical shared global amax.

managed_attrs: ClassVar[tuple[str, ...]] = ('_global_amax',)

name: ClassVar[str] = 'weight_global_amax'

sync(parallel_state=None)

All-reduce (MAX) global_amax across the expert-parallel (EP) group and, defensively, across the tensor-parallel (TP) group as well.

Parameters:

parallel_state (ParallelState | None)

Return type:

None

target_quantizer_kind: ClassVar[str] = 'weight'

tie_member_quantizer(quantizer)

Tie one member quantizer to the shared _global_amax buffer when eligible.

Parameters:

quantizer (Module)

Return type:

bool

convert_quantization_axis_to_reduce_axis(input, axis)

Convert the quantization axis to the reduce axis.

Parameters:
  • input (torch.Tensor) – The input tensor.

  • axis (int, tuple, list, or None) – The quantization axis. None means per-tensor quantization.

Returns:

The axis to reduce. None suggests all dimensions should be reduced.

Return type:

list
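
The relationship between the two axes can be illustrated with a minimal sketch. It assumes the reduce axis is simply every dimension that is not a quantization axis; the helper name quant_axis_to_reduce_axis is hypothetical, and it takes the tensor rank (ndim) instead of a tensor so that it runs without torch. The library function may differ in details such as how negative indices are represented.

```python
def quant_axis_to_reduce_axis(ndim, axis):
    """Return the dimensions to reduce, or None to reduce all of them.

    Hypothetical helper for illustration only, not the library function.
    """
    if axis is None:
        return None  # per-tensor quantization: reduce every dimension
    kept = (axis,) if isinstance(axis, int) else tuple(axis)
    # Normalize negative indices so that -1 and ndim - 1 compare equal.
    kept = {a % ndim for a in kept}
    return [d for d in range(ndim) if d not in kept]


# A 2-D weight quantized per output channel (axis 0) reduces over dimension 1.
print(quant_axis_to_reduce_axis(2, 0))        # [1]
print(quant_axis_to_reduce_axis(4, (0, -1)))  # [1, 2]
print(quant_axis_to_reduce_axis(3, None))     # None
```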

export_torch_mode()

Context manager enabling the export mode.

find_shared_input_groups(model, patterns=None, target_quantizer_kind='weight')

Find fusible sibling groups by regex over module FQNs; capture groups define the key.

Each pattern is re.fullmatch-ed against every quantized module’s fully-qualified name; modules whose match yields the same capture-group tuple form one group, parented at their LCA. Granularity is set by what you capture:

  • Capture the immediate parent -> per-parent grouping: q/k/v per attention block, and per-expert w1/w3 (each expert is the immediate parent), e.g. r"(.*)\.(?:w1|w3)$".

  • Capture only a level above the expert index, leaving the index uncaptured -> one cross-expert group, e.g. r"(.*)\.experts\.\d+\.(?:w1|w3)$".

Roles to fuse together go in a non-capturing alternation (?:w1|w3) so they don’t split the key; what you wrap in (...) is the group boundary. Pass SHARED_PATTERNS for the standard q/k/v + gate/up groups, or override via MaxCalibConfig.shared_states. The caller selects which quantizer these groups apply to. Returns (parent, members) tuples; empty when no patterns are given.

Parameters:
  • model (Module)

  • patterns (Sequence[str] | None)

  • target_quantizer_kind (str)

Return type:

list[tuple[Module, list[Module]]]
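
The effect of the capture group on grouping granularity can be sketched with plain strings. This is an illustration of the regex behavior only: group_by_capture and the module names are hypothetical, and the sketch does not resolve the parent module at the LCA or filter by quantizer kind as the real function does.

```python
import re
from collections import defaultdict


def group_by_capture(names, patterns):
    """Group names whose full match yields the same capture-group tuple.

    Hypothetical helper for illustration only.
    """
    groups = defaultdict(list)
    for pattern in patterns:
        compiled = re.compile(pattern)
        for name in names:
            match = compiled.fullmatch(name)
            if match:
                # Key on (pattern, captures) so different patterns never merge.
                groups[(pattern, match.groups())].append(name)
    return dict(groups)


names = [
    "layers.0.attn.q_proj",
    "layers.0.attn.k_proj",
    "layers.0.attn.v_proj",
    "layers.0.moe.experts.0.w1",
    "layers.0.moe.experts.0.w3",
    "layers.0.moe.experts.1.w1",
    "layers.0.moe.experts.1.w3",
]

# Capturing the immediate parent gives one group per expert.
per_expert = group_by_capture(names, [r"(.*)\.(?:w1|w3)$"])
# Leaving the expert index uncaptured gives one cross-expert group.
cross_expert = group_by_capture(names, [r"(.*)\.experts\.\d+\.(?:w1|w3)$"])

print(len(per_expert))    # 2
print(len(cross_expert))  # 1
```

In the first call the key is the expert's own name (for example layers.0.moe.experts.0), so w1 and w3 pair up per expert. In the second call the key is layers.0.moe, so all four expert weights fall into a single group.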

is_quantized(module)

Check if a module is quantized.

is_quantized_column_parallel_linear(module)

Check if a module is a quantized column parallel linear module.

is_quantized_linear(module)

Check if a module is a quantized linear module.

is_quantized_row_parallel_linear(module)

Check if a module is a quantized row parallel linear module.

iter_shared_quant_states(model, state_cls=<class 'modelopt.torch.quantization.utils.shared_input.SharedQuantState'>)

Yield shared quant states owned within model.

Parameters:
  • model (Module)

  • state_cls (defaults to SharedQuantState)

reduce_amax(input, axis=None, keepdims=True, squeeze_scalar=True)

Compute the absolute maximum value of a tensor.

Reduces input along the dimensions given in axis. Unless keepdims is true, the rank of the tensor is reduced by 1 for each entry in axis. If keepdims is true, the reduced dimensions are retained with length 1.

Note

Gradient computation is disabled, as this function is never meant for learning.

Parameters:
  • input – Input tensor

  • axis – The dimensions to reduce. None or int or tuple of ints. If None (the default), reduces all dimensions. Must be in the range [-rank(input), rank(input)).

  • keepdims – A boolean. If true, retains reduced dimensions with length 1. Default True

Returns:

The reduced tensor.
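
The keepdims rule can be seen in a small pure-Python sketch over a 2-D nested list. The helper amax_2d is hypothetical and covers only a single integer axis; it does not model axis=None or the squeeze_scalar argument, and it is not how the library computes the result on tensors.

```python
def amax_2d(rows, axis, keepdims=True):
    """Absolute maximum of a 2-D nested list along one axis.

    Hypothetical helper for illustration only.
    """
    if axis == 0:  # reduce over rows: one value per column
        reduced = [max(abs(x) for x in col) for col in zip(*rows)]
        return [reduced] if keepdims else reduced
    if axis == 1:  # reduce over columns: one value per row
        reduced = [max(abs(x) for x in row) for row in rows]
        return [[v] for v in reduced] if keepdims else reduced
    raise ValueError("axis must be 0 or 1 in this sketch")


weight = [[1.0, -4.0, 2.0],
          [-0.5, 3.0, -6.0]]

print(amax_2d(weight, axis=1))                  # [[4.0], [6.0]]  shape (2, 1)
print(amax_2d(weight, axis=1, keepdims=False))  # [4.0, 6.0]      shape (2,)
print(amax_2d(weight, axis=0))                  # [[1.0, 4.0, 6.0]]  shape (1, 3)
```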

reduce_sum(input, axis=None, keepdims=True)

Compute the sum of a tensor along specified axes.

Reduces input along the dimensions given in axis. Unless keepdims is true, the rank of the tensor is reduced by 1 for each entry in axis. If keepdims is true, the reduced dimensions are retained with length 1.

Note

Gradient computation is disabled as this function is never meant for learning.

Parameters:
  • input – Input tensor

  • axis – The dimensions to reduce. None or int or tuple of ints. If None (the default), reduces all dimensions. Must be in the range [-rank(input), rank(input)).

  • keepdims – A boolean. If true, retains reduced dimensions with length 1. Default True

Returns:

The reduced tensor.

replace_function(package, name, new_func, og_func_cache_name=None)

Replace a function with a new one within a context.
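
The general pattern of replacing a function within a context can be sketched as follows. This is a simplified illustration, not the library implementation: swap_attribute is a hypothetical name, and the og_func_cache_name argument is not modelled here.

```python
import contextlib
from types import SimpleNamespace


@contextlib.contextmanager
def swap_attribute(package, name, new_func):
    """Temporarily bind package.<name> to new_func, restoring it on exit.

    Hypothetical helper for illustration only.
    """
    original = getattr(package, name)
    setattr(package, name, new_func)
    try:
        yield
    finally:
        # Restore the original even if the body raises.
        setattr(package, name, original)


# A stand-in "package" holding one function.
pkg = SimpleNamespace(scale=lambda x: 2 * x)

with swap_attribute(pkg, "scale", lambda x: 10 * x):
    inside = pkg.scale(3)   # uses the replacement
outside = pkg.scale(3)      # original is back

print(inside, outside)  # 30 6
```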

representative_weight_quantizer(module, weight_name='weight')

Return the representative weight quantizer for weight_name on module.

Handles two layouts:

  • singular <name>_weight_quantizer — standard nn.Linear / _QuantLinear.

  • plural <name>_weight_quantizers (nn.ModuleList) — fused-experts modules (_QuantFusedExperts) hold one TensorQuantizer per expert. Per-expert formats are identical, so the first element is representative.

Returns None if no matching quantizer is found.

Parameters:
  • module (Module)

  • weight_name (str)
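
The two layouts can be sketched as a simple attribute lookup. The helper names below are hypothetical, and the attribute-naming rule (weight maps to weight_quantizer, any other name maps to <name>_weight_quantizer) is an assumption inferred from the layouts listed in this reference. Plain objects and strings stand in for modules and quantizers.

```python
from types import SimpleNamespace


def quantizer_attr(weight_name, plural=False):
    """Assumed naming rule: 'weight' -> 'weight_quantizer', else '<name>_weight_quantizer'."""
    base = "weight_quantizer" if weight_name == "weight" else f"{weight_name}_weight_quantizer"
    return base + ("s" if plural else "")


def pick_weight_quantizer(module, weight_name="weight"):
    """Return the singular quantizer, else the first of the plural list, else None."""
    single = getattr(module, quantizer_attr(weight_name), None)
    if single is not None:
        return single
    plural = getattr(module, quantizer_attr(weight_name, plural=True), None)
    if plural is not None and len(plural) > 0:
        return plural[0]  # per-expert formats are identical, so the first is representative
    return None


linear = SimpleNamespace(weight_quantizer="linear_q")
experts = SimpleNamespace(gate_up_proj_weight_quantizers=["expert0_q", "expert1_q"])

print(pick_weight_quantizer(linear))                   # linear_q
print(pick_weight_quantizer(experts, "gate_up_proj"))  # expert0_q
print(pick_weight_quantizer(SimpleNamespace()))        # None
```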

update_quant_cfg_with_kv_cache_quant(quant_cfg, kv_cache_quant_cfg)

Update the quant_cfg with the kv cache quant_cfg.

Parameters:
  • quant_cfg (dict[str, Any]) – The outer quantization config dict (with "quant_cfg" and "algorithm" keys).

  • kv_cache_quant_cfg (list[QuantizerCfgEntry]) – A list of QuantizerCfgEntry dicts for KV cache quantization, typically some_kv_cfg["quant_cfg"].

Returns:

A deep copy of quant_cfg with the KV cache entries appended to quant_cfg["quant_cfg"].

Return type:

dict[str, Any]
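
The documented return value (a deep copy with the KV cache entries appended) can be sketched in a few lines. The helper append_kv_cache_entries is a hypothetical stand-in, and the entry dicts are placeholders, not real QuantizerCfgEntry contents.

```python
import copy


def append_kv_cache_entries(quant_cfg, kv_cache_quant_cfg):
    """Return a deep copy of quant_cfg with the KV cache entries appended.

    Hypothetical helper for illustration only.
    """
    updated = copy.deepcopy(quant_cfg)
    # Copy the new entries too, so the result shares nothing with the inputs.
    updated["quant_cfg"] = list(updated.get("quant_cfg", [])) + copy.deepcopy(
        list(kv_cache_quant_cfg)
    )
    return updated


# Placeholder entries; the real entry schema is not shown here.
base_cfg = {"quant_cfg": [{"entry": "weights"}], "algorithm": "max"}
kv_entries = [{"entry": "kv_cache"}]

merged = append_kv_cache_entries(base_cfg, kv_entries)

print(len(merged["quant_cfg"]))    # 2
print(len(base_cfg["quant_cfg"]))  # 1, the input is left untouched
```

Because the result is a deep copy, the same base config can be combined with different KV cache configs without one call affecting another.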

weight_attr_names(module)

Get the weight param attribute names in a converted module, non-recursive.

Covers three layouts:

  • standard nn.Linear: weight + weight_quantizer.

  • custom per-weight quantizer (e.g. Llama4TextExperts with gate_up_proj + gate_up_proj_weight_quantizer).

  • fused-experts nn.ModuleList quantizers (_QuantFusedExperts with gate_up_proj + gate_up_proj_weight_quantizers plural list).

Parameters:

module (Module)

Return type:

Generator[str, None, None]