graph_surgery
Graph surgery module for post-processing ONNX models.
This module provides utilities for performing graph-level transformations on ONNX models after export. Common use cases include:
- Replacing standard attention patterns with GroupQueryAttention (GQA) for LLMs
- Adding cross-attention KV cache outputs to encoder models
- Converting model precision (e.g., FP16 to BF16)
- Transposing DequantizeLinear weights for column-major storage optimization
- Graph cleanup and optimization
Example usage:
>>> from modelopt.onnx.graph_surgery import (
...     replace_attention_with_gqa,
...     convert_fp16_to_bf16,
...     transpose_dequantize_linear_weights,
...     add_cross_kv_to_encoder,
... )
>>> # Replace attention with GQA for LLMs (FP16 model)
>>> replace_attention_with_gqa(
...     model_path="model_fp16.onnx",
...     output_path="model_gqa.onnx",
...     hf_model_id="meta-llama/Llama-2-7b-hf",
...     io_dtype="float16",
... )
>>> # Replace attention with GQA and convert to BF16 in one step
>>> replace_attention_with_gqa(
...     model_path="model_fp16.onnx",
...     output_path="model_gqa_bf16.onnx",
...     hf_model_id="meta-llama/Llama-2-7b-hf",
...     io_dtype="bfloat16",  # Automatically converts FP16 to BF16
... )
>>> # Add cross-attention KV cache outputs to encoder (GenAI compatible)
>>> add_cross_kv_to_encoder(
...     model_path="encoder_model.onnx",
...     output_path="encoder_with_kv.onnx",
...     hf_model_id="openai/whisper-large-v3-turbo",
... )
>>> # Standalone FP16 to BF16 conversion
>>> convert_fp16_to_bf16(
...     model_path="model_fp16.onnx",
...     output_path="model_bf16.onnx",
... )
>>> # Transpose DequantizeLinear weights for column-major storage
>>> transpose_dequantize_linear_weights(
...     model_path="model_quantized.onnx",
...     output_path="model_quantized_transposed.onnx",
... )
Functions
- add_cross_kv_to_encoder: Add cross-attention KV cache outputs to encoder model.
- convert_fp16_to_bf16: Convert an ONNX model from FP16 to BF16.
- get_available_surgeries: Return a list of all registered graph surgery names.
- replace_attention_with_gqa: Replace attention subgraphs with GroupQueryAttention (GQA) in an ONNX model.
- run_graph_surgery: Run a graph surgery by name.
- transpose_dequantize_linear_weights: Transpose weights in DequantizeLinear nodes for column-major storage.
- add_cross_kv_to_encoder(model_path, output_path, hf_model_id, hidden_state_output_name='last_hidden_state', rename_input_features=True, use_external_data=True, external_data_name=None, decoder_filename='decoder_with_past_model.onnx', generate_genai_config=True, provider='cuda', verbose=True, trust_remote_code=False)
Add cross-attention KV cache outputs to encoder model.
This function transforms an Optimum-exported encoder model by adding cross-attention Key/Value projection outputs. This is required for ONNX Runtime GenAI compatibility where the decoder expects pre-computed encoder K/V caches.
The transformation:
1. Renames input_features -> audio_features (optional)
2. Renames last_hidden_state -> encoder_hidden_states
3. Adds K/V projection weights from HuggingFace model
4. Adds MatMul -> Reshape -> Transpose subgraph for each layer
5. Adds outputs: present_key_cross_X, present_value_cross_X
6. Generates genai_config.json and audio_processor_config.json (optional)
- Parameters:
model_path (str) – Path to encoder ONNX model.
output_path (str) – Path to save modified encoder.
hf_model_id (str) – HuggingFace model ID for loading cross-attention weights.
hidden_state_output_name (str) – Name of encoder hidden state output.
rename_input_features (bool) – Whether to rename input_features to audio_features.
use_external_data (bool) – Whether to save weights as external data.
external_data_name (str | None) – Name for external data file.
decoder_filename (str) – Filename for decoder model in genai_config.json. Default is "decoder_with_past_model.onnx".
generate_genai_config (bool) – Whether to generate genai_config.json.
provider (str) – Execution provider for genai_config.json ("cuda", "cpu", "NvTensorRtRtx").
verbose (bool) – Whether to print progress messages.
trust_remote_code (bool) – Whether to trust remote code in HuggingFace model.
- Returns:
Modified encoder model with cross-attention KV cache outputs.
- Return type:
ModelProto
Example
>>> from modelopt.onnx.graph_surgery import add_cross_kv_to_encoder
>>> model = add_cross_kv_to_encoder(
...     model_path="encoder_model.onnx",
...     output_path="encoder_model_with_kv.onnx",
...     hf_model_id="openai/whisper-large-v3-turbo",
... )
- convert_fp16_to_bf16(model_path, output_path, external_data=True, verbose=True)
Convert an ONNX model from FP16 to BF16.
This function converts:
1. All FP16 initializers (weights) to BF16
2. All FP16 value_info (intermediate tensors) to BF16
3. All FP16 graph inputs/outputs to BF16
4. All Cast nodes that target FP16 to target BF16
- Parameters:
model_path (str) – Path to input FP16 ONNX model.
output_path (str) – Path to output BF16 ONNX model.
external_data (bool) – Whether to save weights as external data.
verbose (bool) – Whether to print progress messages.
- Returns:
Dictionary with conversion statistics.
- Return type:
dict[str, int]
Example
>>> stats = convert_fp16_to_bf16(
...     model_path="model_fp16.onnx",
...     output_path="model_bf16.onnx",
... )
>>> print(f"Converted {stats['initializers_converted']} initializers")
- get_available_surgeries()
Return a list of all registered graph surgery names.
- Return type:
list[str]
- replace_attention_with_gqa(model_path, output_path, hf_model_id, max_seq_len=4096, io_dtype='float16', use_external_data=True, external_data_name=None, ir_version=None, pack_qkv=False, verbose=True, trust_remote_code=False)
Replace attention subgraphs with GroupQueryAttention (GQA) in an ONNX model.
This function transforms an ONNX model exported from HuggingFace/Optimum to use Microsoft's GroupQueryAttention operator, which is optimized for inference with ONNX Runtime.
The transformation includes:
- Converting weights to target dtype (FP16/BF16) [non-quantized models only]
- Adding RoPE cos/sin caches
- Replacing attention patterns with GQA for all layers
- Fusing Q/K/V projections into single MatMul [non-quantized models only]
- Concatenating Q/K/V outputs for GQA [quantized models only]
- Adding past/present KV cache inputs/outputs
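The "RoPE cos/sin caches" step amounts to precomputing rotary-embedding tables for every position up to `max_seq_len`. A minimal numpy sketch is below; the `theta=10000.0` base is the common Llama-style default and an assumption here, since the real value is read from the HuggingFace config.

```python
# Sketch: precompute RoPE cos/sin caches of shape (max_seq_len, head_dim/2).
# These become initializers feeding the GroupQueryAttention inputs.
import numpy as np

def make_rope_caches(max_seq_len, head_dim, theta=10000.0):
    # Inverse frequencies for each even dimension index
    inv_freq = 1.0 / (theta ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim))
    positions = np.arange(max_seq_len, dtype=np.float32)
    angles = np.outer(positions, inv_freq)  # (max_seq_len, head_dim // 2)
    return np.cos(angles), np.sin(angles)

cos_cache, sin_cache = make_rope_caches(max_seq_len=4096, head_dim=128)
```

Position 0 yields all-ones cosines and all-zeros sines, which leaves the first token's Q/K vectors unrotated.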
- Parameters:
model_path (str) – Path to input ONNX model.
output_path (str) – Path to save modified model.
hf_model_id (str) – HuggingFace model ID for config.
max_seq_len (int) – Maximum sequence length for caches.
io_dtype (str) – Data type for I/O tensors (“float16”, “float32”, or “bfloat16”). If the model has FP16 initializers and “bfloat16” is specified, they are automatically converted to BF16.
use_external_data (bool) – Save weights as external data file.
external_data_name (str | None) – Name for external data file (default: model.onnx_data).
ir_version (int | None) – If specified, set the ONNX IR version to this value. Useful for compatibility with older ONNX Runtime versions (e.g., set to 9 for ORT 1.16).
verbose (bool) – Whether to print progress messages.
trust_remote_code (bool) – Whether to trust remote code in HuggingFace model config.
pack_qkv (bool) – Whether to pack the Q/K/V projections into a single fused weight, per the transformation notes above.
- Returns:
Modified ONNX model.
- Return type:
ModelProto
Example
>>> from modelopt.onnx.graph_surgery import replace_attention_with_gqa
>>> model = replace_attention_with_gqa(
...     model_path="model_fp16.onnx",
...     output_path="model_gqa.onnx",
...     hf_model_id="meta-llama/Llama-2-7b-hf",
...     max_seq_len=4096,
...     io_dtype="float16",
... )
- run_graph_surgery(surgery_name, model_path, output_path, **kwargs)
Run a graph surgery by name.
This is the unified entry point for all graph surgeries. It dispatches to the appropriate surgery function based on the surgery name.
When new surgeries are added to the registry, they are automatically available through this function without any changes to calling code.
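The registry-plus-dispatch pattern described here can be sketched in a few lines. This is an illustrative stand-in, not the module's real internals; the decorator name and stub surgery body are hypothetical.

```python
# Sketch: a name-keyed registry so new surgeries become runnable by name
# without changing any calling code.
_SURGERIES = {}

def register_surgery(name):
    """Decorator that records a surgery function under a short name."""
    def deco(fn):
        _SURGERIES[name] = fn
        return fn
    return deco

@register_surgery("convert-bf16")
def _convert_bf16(model_path, output_path, **kwargs):
    # Stand-in body; the real surgery would rewrite the model here.
    return {"initializers_converted": 0}

def run(surgery_name, model_path, output_path, **kwargs):
    if surgery_name not in _SURGERIES:
        raise ValueError(
            f"Unknown surgery {surgery_name!r}; available: {sorted(_SURGERIES)}"
        )
    return _SURGERIES[surgery_name](model_path, output_path, **kwargs)
```

Registering a new function is then the only change needed for it to appear in `get_available_surgeries()`-style listings and be dispatchable by name.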
- Parameters:
surgery_name (str) – Name of the surgery to run (e.g. 'replace-gqa', 'transpose-dq'). Use get_available_surgeries() to see all available options.
model_path (str) – Path to the input ONNX model.
output_path (str) – Path to save the output ONNX model.
**kwargs – Surgery-specific parameters. Passed directly to the surgery function.
- Returns:
The return value of the surgery function (typically ModelProto or dict).
- Raises:
ValueError – If surgery_name is not registered.
Example
>>> from modelopt.onnx.graph_surgery import run_graph_surgery, get_available_surgeries
>>> print(get_available_surgeries())
['replace-gqa', 'add-cross-kv', 'convert-bf16', 'transpose-dq']
>>> run_graph_surgery(
...     "replace-gqa",
...     model_path="model.onnx",
...     output_path="model_gqa.onnx",
...     hf_model_id="meta-llama/Llama-2-7b-hf",
... )
- transpose_dequantize_linear_weights(model_path, output_path, use_external_data=True, external_data_name=None, verbose=True)
Transpose weights in DequantizeLinear nodes for column-major storage.
This function transforms a quantized ONNX model by:
1. Finding all DequantizeLinear nodes that feed into MatMul/Gemm
2. Transposing the quantized weights, scales, and zero points
3. Updating the axis attribute (0 -> 1)
4. Adding Transpose nodes after DequantizeLinear to recover original shape
This optimization is useful for backends that prefer column-major weight storage, such as NvTensorRtRtx.
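Steps 2-4 preserve values exactly: per-channel dequantizing W along axis 0 equals transposing W, its scales, and zero points, dequantizing along axis 1, then transposing back. A small numpy check of that identity (illustrative shapes and random data, not the module's code):

```python
# Sketch: per-channel DequantizeLinear along axis 0 of W is equivalent to
# storing W^T with axis flipped to 1 and transposing after dequantization.
import numpy as np

rng = np.random.default_rng(0)
w_q = rng.integers(0, 255, size=(4, 6), dtype=np.uint8)      # quantized weight
scale = rng.random(4, dtype=np.float32)                      # per-channel, axis 0
zero = rng.integers(0, 255, size=4, dtype=np.uint8)          # per-channel, axis 0

def dequantize(w, s, z, axis):
    """Per-channel DequantizeLinear: (w - zero_point) * scale along `axis`."""
    shape = [1, 1]
    shape[axis] = -1
    return (w.astype(np.float32) - z.astype(np.float32).reshape(shape)) * s.reshape(shape)

original = dequantize(w_q, scale, zero, axis=0)
# Column-major variant: store W^T with axis=1, add a Transpose afterwards.
transposed = dequantize(w_q.T, scale, zero, axis=1).T
```

The trailing `.T` plays the role of the Transpose node inserted in step 4, so downstream MatMul/Gemm consumers see the original weight shape.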
- Parameters:
model_path (str) – Path to input quantized ONNX model.
output_path (str) – Path to save modified model.
use_external_data (bool) – Whether to save weights as external data.
external_data_name (str | None) – Name for external data file.
verbose (bool) – Whether to print progress messages.
- Returns:
Modified ONNX model.
- Return type:
ModelProto
Example
>>> from modelopt.onnx.graph_surgery import transpose_dequantize_linear_weights
>>> model = transpose_dequantize_linear_weights(
...     model_path="model_quantized.onnx",
...     output_path="model_quantized_transposed.onnx",
... )