graph_surgery
Graph surgery module for post-processing ONNX models.
This module provides utilities for performing graph-level transformations on ONNX models after export. Common use cases include:
- Replacing standard attention patterns with GroupQueryAttention (GQA) for LLMs
- Adding cross-attention KV cache outputs to encoder models
- Converting model precision (e.g., FP16 to BF16)
- Transposing DequantizeLinear weights for column-major storage optimization
- Graph cleanup and optimization
- Example usage:
>>> from modelopt.onnx.graph_surgery import (
...     replace_attention_with_gqa,
...     convert_fp16_to_bf16,
...     transpose_dequantize_linear_weights,
...     add_cross_kv_to_encoder,
... )
>>> # Replace attention with GQA for LLMs (FP16 model)
>>> replace_attention_with_gqa(
...     model_path="model_fp16.onnx",
...     output_path="model_gqa.onnx",
...     hf_model_id="meta-llama/Llama-2-7b-hf",
...     io_dtype="float16",
... )
>>> # Replace attention with GQA and convert to BF16 in one step
>>> replace_attention_with_gqa(
...     model_path="model_fp16.onnx",
...     output_path="model_gqa_bf16.onnx",
...     hf_model_id="meta-llama/Llama-2-7b-hf",
...     io_dtype="bfloat16",  # Automatically converts FP16 to BF16
... )
>>> # Add cross-attention KV cache outputs to encoder (GenAI compatible)
>>> add_cross_kv_to_encoder(
...     encoder_path="encoder_model.onnx",
...     output_path="encoder_with_kv.onnx",
...     hf_model_id="openai/whisper-large-v3-turbo",
... )
>>> # Standalone FP16 to BF16 conversion
>>> convert_fp16_to_bf16(
...     input_path="model_fp16.onnx",
...     output_path="model_bf16.onnx",
... )
>>> # Transpose DequantizeLinear weights for column-major storage
>>> transpose_dequantize_linear_weights(
...     model_path="model_quantized.onnx",
...     output_path="model_quantized_transposed.onnx",
... )
Functions
| add_cross_kv_to_encoder | Add cross-attention KV cache outputs to encoder model. |
| convert_fp16_to_bf16 | Convert an ONNX model from FP16 to BF16. |
| replace_attention_with_gqa | Replace attention subgraphs with GroupQueryAttention (GQA) in an ONNX model. |
| transpose_dequantize_linear_weights | Transpose weights in DequantizeLinear nodes for column-major storage. |
- add_cross_kv_to_encoder(encoder_path, output_path, hf_model_id, hidden_state_output_name='last_hidden_state', rename_input_features=True, use_external_data=True, external_data_name=None, decoder_filename='decoder_with_past_model.onnx', generate_genai_config=True, provider='cuda', verbose=True, trust_remote_code=False)
Add cross-attention KV cache outputs to encoder model.
This function transforms an Optimum-exported encoder model by adding cross-attention Key/Value projection outputs. This is required for ONNX Runtime GenAI compatibility where the decoder expects pre-computed encoder K/V caches.
The transformation:
1. Renames input_features -> audio_features (optional)
2. Renames last_hidden_state -> encoder_hidden_states
3. Adds K/V projection weights from HuggingFace model
4. Adds MatMul -> Reshape -> Transpose subgraph for each layer
5. Adds outputs: present_key_cross_X, present_value_cross_X
6. Generates genai_config.json and audio_processor_config.json (optional)
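For reference, the per-layer subgraph is small. The following is a minimal sketch (not this module's internal code) of how one cross-attention key output could be appended with onnx.helper; the weight array, head counts, dtype, and all tensor names are illustrative assumptions, and the value path is analogous.

import numpy as np
from onnx import TensorProto, helper, numpy_helper

def add_cross_key_output(model, k_proj_weight, layer, num_heads, head_dim):
    """Append MatMul -> Reshape -> Transpose producing present_key_cross_<layer> (hypothetical helper)."""
    g = model.graph
    w_name = f"cross_k_proj_weight_{layer}"  # illustrative initializer name
    g.initializer.append(numpy_helper.from_array(k_proj_weight, w_name))
    shape_name = f"cross_k_shape_{layer}"
    g.initializer.append(
        numpy_helper.from_array(np.array([0, 0, num_heads, head_dim], dtype=np.int64), shape_name)
    )
    g.node.extend([
        helper.make_node("MatMul", ["encoder_hidden_states", w_name], [f"k_proj_{layer}"]),
        helper.make_node("Reshape", [f"k_proj_{layer}", shape_name], [f"k_4d_{layer}"]),
        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        helper.make_node("Transpose", [f"k_4d_{layer}"], [f"present_key_cross_{layer}"], perm=[0, 2, 1, 3]),
    ])
    g.output.append(
        helper.make_tensor_value_info(
            f"present_key_cross_{layer}",
            TensorProto.FLOAT16,  # assumes an FP16 encoder
            ["batch_size", num_heads, "encoder_sequence_length", head_dim],
        )
    )
    return model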
- Parameters:
encoder_path (str) – Path to encoder ONNX model.
output_path (str) – Path to save modified encoder.
hf_model_id (str) – HuggingFace model ID for loading cross-attention weights.
hidden_state_output_name (str) – Name of encoder hidden state output.
rename_input_features (bool) – Whether to rename input_features to audio_features.
use_external_data (bool) – Whether to save weights as external data.
external_data_name (str | None) – Name for external data file.
decoder_filename (str) – Filename for decoder model in genai_config.json. Default is “decoder_with_past_model.onnx”.
generate_genai_config (bool) – Whether to generate genai_config.json.
provider (str) – Execution provider for genai_config.json (“cuda”, “cpu”, “NvTensorRtRtx”).
verbose (bool) – Whether to print progress messages.
trust_remote_code (bool) – Whether to trust remote code in HuggingFace model.
- Returns:
Modified encoder model with cross-attention KV cache outputs.
- Return type:
ModelProto
Example
>>> from modelopt.onnx.graph_surgery import add_cross_kv_to_encoder
>>> model = add_cross_kv_to_encoder(
...     encoder_path="encoder_model.onnx",
...     output_path="encoder_model_with_kv.onnx",
...     hf_model_id="openai/whisper-large-v3-turbo",
... )
- convert_fp16_to_bf16(input_path, output_path, external_data=True, verbose=True)
Convert an ONNX model from FP16 to BF16.
This function converts:
1. All FP16 initializers (weights) to BF16
2. All FP16 value_info (intermediate tensors) to BF16
3. All FP16 graph inputs/outputs to BF16
4. All Cast nodes that target FP16 to target BF16
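NumPy has no native bfloat16 dtype, so the initializer conversion typically works at the bit level: reinterpret the float32 representation and keep its upper 16 bits. A minimal sketch of that rewrite for a single initializer is shown below (truncating rounding; the function's actual rounding and protobuf field handling may differ).

import numpy as np
from onnx import TensorProto, numpy_helper

def fp16_initializer_to_bf16(init):
    """Rewrite one FP16 TensorProto in place as BF16 (illustrative sketch)."""
    if init.data_type != TensorProto.FLOAT16:
        return False
    f32 = numpy_helper.to_array(init).astype(np.float32)
    # BF16 keeps the sign, exponent, and top 7 mantissa bits of float32.
    bf16_bits = (f32.view(np.uint32) >> 16).astype(np.uint16)
    init.ClearField("int32_data")  # FP16 payloads may be stored here
    init.raw_data = bf16_bits.tobytes()
    init.data_type = TensorProto.BFLOAT16
    return True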
- Parameters:
input_path (str) – Path to input FP16 ONNX model.
output_path (str) – Path to output BF16 ONNX model.
external_data (bool) – Whether to save weights as external data.
verbose (bool) – Whether to print progress messages.
- Returns:
Dictionary with conversion statistics.
- Return type:
dict[str, int]
Example
>>> stats = convert_fp16_to_bf16(
...     input_path="model_fp16.onnx",
...     output_path="model_bf16.onnx",
... )
>>> print(f"Converted {stats['initializers_converted']} initializers")
- replace_attention_with_gqa(model_path, output_path, hf_model_id, max_seq_len=4096, io_dtype='float16', use_external_data=True, external_data_name=None, ir_version=None, pack_qkv=False, verbose=True, trust_remote_code=False)
Replace attention subgraphs with GroupQueryAttention (GQA) in an ONNX model.
This function transforms an ONNX model exported from HuggingFace/Optimum to use Microsoft’s GroupQueryAttention operator, which is optimized for inference with ONNX Runtime.
The transformation includes:
- Converting weights to target dtype (FP16/BF16) [non-quantized models only]
- Adding RoPE cos/sin caches
- Replacing attention patterns with GQA for all layers
- Fusing Q/K/V projections into single MatMul [non-quantized models only]
- Concatenating Q/K/V outputs for GQA [quantized models only]
- Adding past/present KV cache inputs/outputs
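At the center of the rewrite is one GroupQueryAttention node (com.microsoft domain) per layer, wired to the projection outputs, the past KV cache tensors, and the RoPE caches. The sketch below shows roughly what such a node looks like when built with onnx.helper; the tensor names are illustrative, and the exact input list and attributes should be checked against the GroupQueryAttention contrib-op schema of the onnxruntime version you target.

from onnx import helper

def make_gqa_node(layer, num_heads, kv_num_heads):
    """Sketch of a com.microsoft::GroupQueryAttention node for one layer (illustrative names)."""
    return helper.make_node(
        "GroupQueryAttention",
        inputs=[
            f"query_{layer}",                   # Q projection output
            f"key_{layer}",                     # K projection output
            f"value_{layer}",                   # V projection output
            f"past_key_values.{layer}.key",     # past K cache
            f"past_key_values.{layer}.value",   # past V cache
            "seqlens_k",                        # per-batch KV sequence lengths
            "total_sequence_length",
            "cos_cache",                        # RoPE caches added by the transformation
            "sin_cache",
        ],
        outputs=[
            f"attn_out_{layer}",
            f"present.{layer}.key",
            f"present.{layer}.value",
        ],
        name=f"GroupQueryAttention_{layer}",
        domain="com.microsoft",
        num_heads=num_heads,
        kv_num_heads=kv_num_heads,
        do_rotary=1,
    )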
- Parameters:
model_path (str) – Path to input ONNX model.
output_path (str) – Path to save modified model.
hf_model_id (str) – HuggingFace model ID for config.
max_seq_len (int) – Maximum sequence length for caches.
io_dtype (str) – Data type for I/O tensors (“float16”, “float32”, or “bfloat16”). If the model has FP16 initializers and “bfloat16” is specified, they are automatically converted to BF16.
use_external_data (bool) – Save weights as external data file.
external_data_name (str | None) – Name for external data file (default: model.onnx_data).
ir_version (int | None) – If specified, set the ONNX IR version to this value. Useful for compatibility with older ONNX Runtime versions (e.g., set to 9 for ORT 1.16).
verbose (bool) – Whether to print progress messages.
trust_remote_code (bool) – Whether to trust remote code in HuggingFace model config.
pack_qkv (bool)
- Returns:
Modified ONNX model.
- Return type:
ModelProto
Example
>>> from modelopt.onnx.graph_surgery import replace_attention_with_gqa
>>> model = replace_attention_with_gqa(
...     model_path="model_fp16.onnx",
...     output_path="model_gqa.onnx",
...     hf_model_id="meta-llama/Llama-2-7b-hf",
...     max_seq_len=4096,
...     io_dtype="float16",
... )
- transpose_dequantize_linear_weights(model_path, output_path, use_external_data=True, external_data_name=None, verbose=True)
Transpose weights in DequantizeLinear nodes for column-major storage.
This function transforms a quantized ONNX model by:
1. Finding all DequantizeLinear nodes that feed into MatMul/Gemm
2. Transposing the quantized weights, scales, and zero points
3. Updating the axis attribute (0 -> 1)
4. Adding Transpose nodes after DequantizeLinear to recover original shape
This optimization is useful for backends that prefer column-major weight storage, such as NvTensorRtRtx.
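A rough sketch of the per-weight rewrite is shown below (illustrative only; 1-D per-axis scales and zero points are unaffected by the transpose, while block-quantized 2-D scales would need matching handling that is omitted here).

from onnx import helper, numpy_helper

def transpose_one_dq(graph, dq_node, inits_by_name):
    """Store one DequantizeLinear weight column-major and restore the original layout (hypothetical helper)."""
    # 1. Transpose the 2-D quantized weight initializer in place.
    w_init = inits_by_name[dq_node.input[0]]
    w_t = numpy_helper.to_array(w_init).T.copy()
    w_init.CopyFrom(numpy_helper.from_array(w_t, w_init.name))

    # 2. Flip the per-channel quantization axis 0 -> 1.
    for attr in dq_node.attribute:
        if attr.name == "axis":
            attr.i = 1

    # 3. Re-route the output through a Transpose so downstream MatMul/Gemm see the original shape.
    orig_out = dq_node.output[0]
    dq_node.output[0] = orig_out + "_colmajor"
    idx = list(graph.node).index(dq_node)
    graph.node.insert(
        idx + 1,
        helper.make_node("Transpose", [orig_out + "_colmajor"], [orig_out], perm=[1, 0]),
    )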
- Parameters:
model_path (str) – Path to input quantized ONNX model.
output_path (str) – Path to save modified model.
use_external_data (bool) – Whether to save weights as external data.
external_data_name (str | None) – Name for external data file.
verbose (bool) – Whether to print progress messages.
- Returns:
Modified ONNX model.
- Return type:
ModelProto
Example
>>> from modelopt.onnx.graph_surgery import transpose_dequantize_linear_weights
>>> model = transpose_dequantize_linear_weights(
...     model_path="model_quantized.onnx",
...     output_path="model_quantized_transposed.onnx",
... )