quant_aware_conversion
Quantization-aware reverse weight conversion for unified HF export.
Background
transformers may apply a conversion_mapping when loading a model, so the
in-memory parameter names differ from the original model-hub checkpoint (e.g. fused
mlp.gate_up_proj, renamed MoE leaves, reordered model/language_model
prefix). On save, transformers reverses this via revert_weight_conversion so
the on-disk names match the hub checkpoint again.
ModelOpt’s unified export disables that reverse because revert_weight_conversion raises
IndexError on 0-d scalar scale tensors such as weight_scale_2/input_scale. As a result, a
quantized export emits the in-memory (post-conversion) names, violating the unified
checkpoint contract that names stay aligned with the original hub checkpoint.
This module performs the reverse in a quantization-aware way: it carries each
weight’s companion scale tensors (weight_scale, weight_scale_2,
input_scale, weight_scale_inv, bias) through the rename and un-fuse
operations.
Scope
Two reverse primitives cover the common conversion_mapping cases:
- Rename: a key-level string substitution. Because a quantized linear stores every tensor under
<module>.<leaf>, renaming the module substring rewrites the weight and all its scale siblings together with no tensor manipulation.
- Split: un-fuse an output-dim concatenation (e.g.
gate_up_proj -> gate_proj + up_proj). weight / weight_scale / weight_scale_inv / bias are chunked along the fused (output) dim; 0-d scalar weight_scale_2 / input_scale are duplicated to each part (they are per-tensor and shared).
The 3-D stacked-expert case (MergeModulelist, where per-expert weights are
stacked into experts.gate_up_proj with leading expert dim) is intentionally
not handled here: the stacked-scalar-scale layout cannot be validated against a
published checkpoint yet. Encountering it raises QuantConversionUnsupportedError
so the caller can fall back to the legacy (in-memory-name) behavior rather than
emit a silently-wrong checkpoint. See the module TODO.
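The fallback contract described above can be sketched as follows. This is a minimal illustration, not the module's code: the exception class is a local stand-in with the same name, and `reverse_ok`, `reverse_unsupported`, `export_state_dict`, and the `converted.` prefix are hypothetical names introduced only for the example.

```python
class QuantConversionUnsupportedError(Exception):
    """Local stand-in for the module's exception of the same name."""


def reverse_ok(state_dict):
    # Hypothetical reverse that succeeds: drops an illustrative "converted." prefix.
    return {key.replace("converted.", "", 1): value for key, value in state_dict.items()}


def reverse_unsupported(state_dict):
    # Hypothetical reverse that hits an unsupported op (e.g. stacked experts).
    raise QuantConversionUnsupportedError("MergeModulelist is not handled")


def export_state_dict(state_dict, reverse_fn):
    """Return (state_dict, reversed_flag); fall back to in-memory names on failure."""
    try:
        return reverse_fn(state_dict), True
    except QuantConversionUnsupportedError:
        # Legacy behavior: keep the post-conversion names unchanged.
        return dict(state_dict), False


state = {"converted.fc.weight": 1.0}
hub_names, reversed_ok = export_state_dict(state, reverse_ok)
legacy_names, reversed_legacy = export_state_dict(state, reverse_unsupported)
```

Catching only the dedicated exception keeps genuine bugs (for example a KeyError inside the reverse) visible instead of silently triggering the fallback.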
Classes
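RenameRule | Reverse of a WeightRenaming: re.sub(pattern, repl, key) on every key.
SplitRule | Reverse of an output-dim Concatenate: un-fuse one module into parts.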
Functions
apply_reverse_rules | Apply quant-aware reverse conversion: splits first, then renames.
revert_weight_conversion_quant_aware | Reverse a transformers conversion_mapping on a quantized state dict.
- exception QuantConversionUnsupportedError
Bases: Exception
Raised when a conversion op cannot be reversed quant-aware (caller falls back).
- class RenameRule
Bases: object
Reverse of a WeightRenaming: re.sub(pattern, repl, key) on every key.
- __init__(pattern, repl)
- Parameters:
pattern (str)
repl (str)
- Return type:
None
- pattern: str
- repl: str
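A minimal sketch of how a rename rule carries scale siblings along, assuming the rule is applied with `re.sub` to every key as documented. The `RenameRule` dataclass here is a local stand-in mirroring the documented fields, `apply_rename` is a hypothetical helper, and the key names (`fc1`, `up_proj`) are illustrative only. String placeholders stand in for tensors because no tensor is touched.

```python
import re
from dataclasses import dataclass


@dataclass
class RenameRule:
    """Local stand-in with the documented fields."""
    pattern: str
    repl: str


def apply_rename(state_dict, rule):
    # Key-level substitution only; values are passed through untouched.
    return {re.sub(rule.pattern, rule.repl, key): value for key, value in state_dict.items()}


state = {
    "layers.0.mlp.fc1.weight": "W",
    "layers.0.mlp.fc1.weight_scale": "S",
    "layers.0.mlp.fc1.input_scale": "I",
    "layers.0.attn.q_proj.weight": "Q",
}
rule = RenameRule(pattern=r"\.mlp\.fc1\.", repl=".mlp.up_proj.")
renamed = apply_rename(state, rule)
```

Because the pattern matches the module part of the key, the weight and both scales move together, while the unrelated attention key is left as it was.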
- class SplitRule
Bases: object
Reverse of an output-dim Concatenate: un-fuse one module into parts.
- Parameters:
fused_suffix – module suffix of the fused tensor, e.g. ".gate_up_proj".
part_suffixes – ordered replacements, e.g. (".gate_proj", ".up_proj").
dim – the fused (output) dim along which weight / weight_scale / bias are chunked. NVFP4 weight is [out, in//2] and weight_scale is [out, in//block], so the output dim is 0 for both.
- __init__(fused_suffix, part_suffixes, dim=0)
- Parameters:
fused_suffix (str)
part_suffixes (tuple[str, ...])
dim (int)
- Return type:
None
- dim: int = 0
- fused_suffix: str
- part_suffixes: tuple[str, ...]
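The split behavior can be illustrated with a small stand-alone sketch. Several things here are assumptions made for the example: nested Python lists stand in for 2-D tensors and floats for 0-d scalars, the parts are assumed to be equal-sized, only `dim=0` is handled, and `split_fused` is a hypothetical helper rather than the module's function. With real tensors, `torch.chunk(value, n, dim=rule.dim)` would play the role of the list slicing.

```python
from dataclasses import dataclass


@dataclass
class SplitRule:
    """Local stand-in with the documented fields."""
    fused_suffix: str
    part_suffixes: tuple
    dim: int = 0


CHUNKED = ("weight", "weight_scale", "weight_scale_inv", "bias")  # split along dim
SHARED = ("weight_scale_2", "input_scale")  # per-tensor scalars, duplicated


def split_fused(state_dict, rule):
    if rule.dim != 0:
        raise NotImplementedError("this sketch only handles dim=0")
    n = len(rule.part_suffixes)
    out = {}
    for key, value in state_dict.items():
        module, _, leaf = key.rpartition(".")
        if not module.endswith(rule.fused_suffix) or leaf not in CHUNKED + SHARED:
            out[key] = value  # not part of the fused module: pass through
            continue
        base = module[: -len(rule.fused_suffix)]
        for i, part in enumerate(rule.part_suffixes):
            if leaf in SHARED:
                piece = value  # same scalar for every part
            else:
                step = len(value) // n  # assumes equal-sized parts
                piece = value[i * step : (i + 1) * step]
            out[f"{base}{part}.{leaf}"] = piece
    return out


state = {
    "mlp.gate_up_proj.weight": [[1, 2], [3, 4], [5, 6], [7, 8]],
    "mlp.gate_up_proj.weight_scale": [[0.1], [0.2], [0.3], [0.4]],
    "mlp.gate_up_proj.weight_scale_2": 0.5,
    "mlp.down_proj.weight": [[9, 9]],
}
parts = split_fused(state, SplitRule(".gate_up_proj", (".gate_proj", ".up_proj")))
```

The first half of the rows goes to `gate_proj` and the second half to `up_proj`, matching the order of `part_suffixes`; the scalar `weight_scale_2` appears unchanged under both parts.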
- apply_reverse_rules(state_dict, split_rules, rename_rules)
Apply quant-aware reverse conversion: splits first, then renames.
Splits run on the in-memory (post-conversion) names; renames then map the resulting keys back to the original hub names. Renames are applied in order.
- Parameters:
state_dict (dict[str, Tensor])
split_rules (list[SplitRule])
rename_rules (list[RenameRule])
- Return type:
dict[str, Tensor]
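The ordering described above (splits on in-memory names first, then renames in list order) can be traced with a compact sketch. It is illustrative only: `split_step`, `rename_step`, and `reverse` are hypothetical helpers, rules are passed as plain tuples rather than rule objects, lists and floats stand in for tensors and 0-d scalars, and the prefix mapping is an invented example of a reordered model/language_model prefix.

```python
import re


def split_step(state_dict, fused_suffix, part_suffixes):
    # Un-fuse along dim 0: list values are sliced, scalar values are shared.
    n = len(part_suffixes)
    out = {}
    for key, value in state_dict.items():
        module, _, leaf = key.rpartition(".")
        if not module.endswith(fused_suffix):
            out[key] = value
            continue
        base = module[: -len(fused_suffix)]
        for i, part in enumerate(part_suffixes):
            if isinstance(value, list):
                step = len(value) // n  # assumes equal-sized parts
                piece = value[i * step : (i + 1) * step]
            else:
                piece = value
            out[f"{base}{part}.{leaf}"] = piece
    return out


def rename_step(state_dict, pattern, repl):
    return {re.sub(pattern, repl, key): value for key, value in state_dict.items()}


def reverse(state_dict, splits, renames):
    for fused_suffix, part_suffixes in splits:  # splits see in-memory names
        state_dict = split_step(state_dict, fused_suffix, part_suffixes)
    for pattern, repl in renames:  # renames run afterwards, in order
        state_dict = rename_step(state_dict, pattern, repl)
    return state_dict


state = {
    "model.language_model.layers.0.mlp.gate_up_proj.weight": [[1], [2], [3], [4]],
    "model.language_model.layers.0.mlp.gate_up_proj.input_scale": 0.25,
}
result = reverse(
    state,
    splits=[(".gate_up_proj", (".gate_proj", ".up_proj"))],
    renames=[(r"^model\.language_model\.", "language_model.model.")],
)
```

Running the split first means the split rule can be written against the fused in-memory suffix, and the rename patterns only ever need to match keys that already carry the un-fused module names.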
- revert_weight_conversion_quant_aware(model, state_dict)
Reverse a transformers conversion_mapping on a quantized state dict.
Builds reverse rules from the model’s conversion mapping and applies them, carrying the companion scale tensors along. Raises QuantConversionUnsupportedError when the mapping uses an op that cannot be reversed quant-aware yet, so the caller can fall back to the legacy behavior.
- Parameters:
state_dict (dict[str, Tensor])