quant_aware_conversion

Quantization-aware reverse weight conversion for unified HF export.

Background

transformers may apply a conversion_mapping when loading a model, so the in-memory parameter names differ from the original model-hub checkpoint (e.g. fused mlp.gate_up_proj, renamed MoE leaves, reordered model/language_model prefix). On save, transformers reverses this via revert_weight_conversion so the on-disk names match the hub checkpoint again.

ModelOpt’s unified export disables that reverse step, because revert_weight_conversion raises IndexError on 0-d scalar scale tensors such as weight_scale_2 and input_scale. As a result, a quantized export emits the in-memory (post-conversion) names, which violates the unified checkpoint contract that names stay aligned with the original hub checkpoint.

This module performs the reverse in a quantization-aware way: it carries each weight’s companion scale tensors (weight_scale, weight_scale_2, input_scale, weight_scale_inv, bias) through the rename and un-fuse operations.

Scope

Two reverse primitives cover the common conversion_mapping cases:

  • Rename – a key-level string substitution. Because a quantized linear stores every tensor under <module>.<leaf>, renaming the module substring rewrites the weight and all its scale siblings together with no tensor manipulation.

  • Split – un-fuse an output-dim concatenation (e.g. gate_up_proj -> gate_proj + up_proj). weight, weight_scale, weight_scale_inv, and bias are chunked along the fused (output) dim; the 0-d scalars weight_scale_2 and input_scale are duplicated to each part, since they are per-tensor and shared.
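
The Rename primitive can be sketched as follows. This is a minimal illustration, not the module's implementation: rename_keys is a hypothetical helper, the prefix pattern is an assumed example of a reordered model/language_model prefix, and placeholder strings stand in for tensors because only the keys change.

```python
import re


def rename_keys(state_dict, pattern, repl):
    """Apply re.sub(pattern, repl, key) to every key; values are left untouched."""
    renamed = {}
    for key, value in state_dict.items():
        new_key = re.sub(pattern, repl, key)
        if new_key in renamed:
            # Two source keys mapping to one target would silently drop a tensor.
            raise ValueError(f"rename collision on {new_key!r}")
        renamed[new_key] = value
    return renamed


# Placeholder strings stand in for tensors; the keys are hypothetical.
state_dict = {
    "model.language_model.layers.0.mlp.down_proj.weight": "W",
    "model.language_model.layers.0.mlp.down_proj.weight_scale": "S",
    "model.language_model.layers.0.mlp.down_proj.input_scale": "I",
}

# One substitution on the module prefix moves the weight and its scale siblings together.
renamed = rename_keys(state_dict, r"^model\.language_model\.", "language_model.model.")
```

Because every companion tensor shares the <module> part of its key, no per-leaf handling is needed for a rename.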

The 3-D stacked-expert case (MergeModulelist, where per-expert weights are stacked into experts.gate_up_proj with a leading expert dim) is intentionally not handled here: the stacked-scalar-scale layout cannot be validated against a published checkpoint yet. Encountering it raises QuantConversionUnsupportedError so the caller can fall back to the legacy (in-memory-name) behavior rather than emit a silently wrong checkpoint. See the module TODO.

Classes

RenameRule

Reverse of a WeightRenaming: re.sub(pattern, repl, key) on every key.

SplitRule

Reverse of an output-dim Concatenate: un-fuse one module into parts.

Functions

apply_reverse_rules

Apply quant-aware reverse conversion: splits first, then renames.

revert_weight_conversion_quant_aware

Reverse a transformers conversion_mapping on a quantized state dict.

exception QuantConversionUnsupportedError

Bases: Exception

Raised when a conversion op cannot be reversed quant-aware (caller falls back).

class RenameRule

Bases: object

Reverse of a WeightRenaming: re.sub(pattern, repl, key) on every key.

__init__(pattern, repl)
Parameters:
  • pattern (str)

  • repl (str)

Return type:

None

pattern: str
repl: str

class SplitRule

Bases: object

Reverse of an output-dim Concatenate: un-fuse one module into parts.

Parameters:
  • fused_suffix – module suffix of the fused tensor, e.g. ".gate_up_proj".

  • part_suffixes – ordered replacements, e.g. (".gate_proj", ".up_proj").

  • dim – the fused (output) dim along which weight/weight_scale/bias are chunked. NVFP4 weight is [out, in//2] and weight_scale is [out, in//block] so the output dim is 0 for both.
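
The effect of a split with dim=0 can be sketched in plain Python. This is an illustration under stated assumptions, not the module's code: split_fused and chunk_rows are hypothetical helpers, nested lists stand in for tensors (the outer list is the output dim), and plain floats stand in for 0-d scalar scales.

```python
def chunk_rows(rows, n_parts):
    """Split a list along its first dim into n_parts equal contiguous chunks."""
    if len(rows) % n_parts != 0:
        raise ValueError("fused dim is not divisible by the number of parts")
    size = len(rows) // n_parts
    return [rows[i * size:(i + 1) * size] for i in range(n_parts)]


def split_fused(state_dict, fused_suffix, part_suffixes):
    """Un-fuse <base><fused_suffix>.<leaf> into one entry per part suffix."""
    out = {}
    for key, value in state_dict.items():
        module, _, leaf = key.rpartition(".")
        if not module.endswith(fused_suffix):
            out[key] = value  # unrelated module: pass through unchanged
            continue
        base = module[: -len(fused_suffix)]
        if isinstance(value, list):
            # weight / weight_scale / bias: chunk along the output dim
            parts = chunk_rows(value, len(part_suffixes))
        else:
            # 0-d scalar scale: per-tensor and shared, so duplicate it
            parts = [value] * len(part_suffixes)
        for suffix, part in zip(part_suffixes, parts):
            out[f"{base}{suffix}.{leaf}"] = part
    return out


fused = {
    "layers.0.mlp.gate_up_proj.weight": [[1, 2], [3, 4], [5, 6], [7, 8]],
    "layers.0.mlp.gate_up_proj.weight_scale": [[0.1], [0.2], [0.3], [0.4]],
    "layers.0.mlp.gate_up_proj.weight_scale_2": 0.5,
    "layers.0.mlp.gate_up_proj.input_scale": 0.25,
    "layers.0.mlp.down_proj.weight": [[9, 9]],
}
split = split_fused(fused, ".gate_up_proj", (".gate_proj", ".up_proj"))
```

The order of part_suffixes matters: the first chunk along the output dim goes to the first suffix, so the tuple has to match the order in which the parts were concatenated.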

__init__(fused_suffix, part_suffixes, dim=0)
Parameters:
  • fused_suffix (str)

  • part_suffixes (tuple[str, ...])

  • dim (int)

Return type:

None

dim: int = 0
fused_suffix: str
part_suffixes: tuple[str, ...]

apply_reverse_rules(state_dict, split_rules, rename_rules)

Apply quant-aware reverse conversion: splits first, then renames.

Splits run on the in-memory (post-conversion) names; renames then map the resulting keys back to the original hub names. Renames are applied in order.
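
The ordering described above can be sketched as a single pass. This is a simplified stand-in, not the real apply_reverse_rules: apply_rules_sketch is hypothetical, rules are passed as plain tuples rather than SplitRule/RenameRule instances, lists stand in for tensors, and the prefix rename is an assumed example.

```python
import re


def apply_rules_sketch(state_dict, split_rules, rename_rules):
    """Splits first (on in-memory names), then renames in the given order."""
    current = dict(state_dict)
    for fused_suffix, part_suffixes in split_rules:
        nxt = {}
        n = len(part_suffixes)
        for key, value in current.items():
            module, _, leaf = key.rpartition(".")
            if not module.endswith(fused_suffix):
                nxt[key] = value
                continue
            base = module[: -len(fused_suffix)]
            if isinstance(value, list):
                size = len(value) // n  # chunk along the output dim
                parts = [value[i * size:(i + 1) * size] for i in range(n)]
            else:
                parts = [value] * n  # scalar scale: duplicate to each part
            for suffix, part in zip(part_suffixes, parts):
                nxt[f"{base}{suffix}.{leaf}"] = part
        current = nxt
    for pattern, repl in rename_rules:
        current = {re.sub(pattern, repl, k): v for k, v in current.items()}
    return current


in_memory = {
    "model.language_model.layers.0.mlp.gate_up_proj.weight": [[1], [2]],
    "model.language_model.layers.0.mlp.gate_up_proj.weight_scale_2": 0.5,
}
result = apply_rules_sketch(
    in_memory,
    split_rules=[(".gate_up_proj", (".gate_proj", ".up_proj"))],
    rename_rules=[(r"^model\.language_model\.", "model.")],
)
```

Running the splits first means the fused suffix only has to be written once, against the in-memory names; a rename applied beforehand could change the part of the key the split rule matches on.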

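Parameters:
  • state_dict – quantized state dict keyed by the in-memory (post-conversion) names.

  • split_rules – SplitRule instances, applied first.

  • rename_rules – RenameRule instances, applied in order after the splits.

Return type: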

dict[str, Tensor]

revert_weight_conversion_quant_aware(model, state_dict)

Reverse a transformers conversion_mapping on a quantized state dict.

Builds reverse rules from the model’s conversion mapping and applies them carrying companion scale tensors. Raises QuantConversionUnsupportedError when the mapping uses an op that cannot be reversed quant-aware yet, so the caller can fall back to the legacy behavior.
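
The fallback contract can be sketched from the caller's side. This is an assumption-laden illustration: export_names_sketch is a hypothetical wrapper, the exception class below is a local stand-in for the documented one so the snippet runs on its own, and revert_fn is where revert_weight_conversion_quant_aware would be passed in real code.

```python
class QuantConversionUnsupportedError(Exception):
    """Local stand-in for the exception documented on this page."""


def export_names_sketch(model, state_dict, revert_fn):
    """Try the quant-aware reverse; fall back to in-memory names if unsupported.

    Returns (state_dict, reverted) so the caller can log which path was taken.
    """
    try:
        return revert_fn(model, state_dict), True
    except QuantConversionUnsupportedError:
        # Legacy behavior: keep the post-conversion names rather than
        # emit a checkpoint that may be silently wrong.
        return state_dict, False


# Two fake revert functions exercise both paths.
def _supported(model, state_dict):
    return {k.replace("new.", "old."): v for k, v in state_dict.items()}


def _unsupported(model, state_dict):
    raise QuantConversionUnsupportedError("stacked-expert mapping")


ok_dict, ok_flag = export_names_sketch(None, {"new.w": 1}, _supported)
fb_dict, fb_flag = export_names_sketch(None, {"new.w": 1}, _unsupported)
```

Catching only QuantConversionUnsupportedError keeps genuine bugs (for example a shape mismatch during a split) visible instead of hiding them behind the fallback.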

Parameters:
  • model – the model whose conversion mapping is reversed.

  • state_dict (dict[str, Tensor])