graph_surgery

Graph surgery module for post-processing ONNX models.

This module provides utilities for performing graph-level transformations on ONNX models after export. Common use cases include:

  • Replacing standard attention patterns with GroupQueryAttention (GQA) for LLMs

  • Adding cross-attention KV cache outputs to encoder models

  • Converting model precision (e.g., FP16 to BF16)

  • Transposing DequantizeLinear weights for column-major storage optimization

  • Graph cleanup and optimization

Example usage:
>>> from modelopt.onnx.graph_surgery import (
...     replace_attention_with_gqa,
...     convert_fp16_to_bf16,
...     transpose_dequantize_linear_weights,
...     add_cross_kv_to_encoder,
... )
>>> # Replace attention with GQA for LLMs (FP16 model)
>>> replace_attention_with_gqa(
...     model_path="model_fp16.onnx",
...     output_path="model_gqa.onnx",
...     hf_model_id="meta-llama/Llama-2-7b-hf",
...     io_dtype="float16",
... )
>>> # Replace attention with GQA and convert to BF16 in one step
>>> replace_attention_with_gqa(
...     model_path="model_fp16.onnx",
...     output_path="model_gqa_bf16.onnx",
...     hf_model_id="meta-llama/Llama-2-7b-hf",
...     io_dtype="bfloat16",  # Automatically converts FP16 to BF16
... )
>>> # Add cross-attention KV cache outputs to encoder (GenAI compatible)
>>> add_cross_kv_to_encoder(
...     encoder_path="encoder_model.onnx",
...     output_path="encoder_with_kv.onnx",
...     hf_model_id="openai/whisper-large-v3-turbo",
... )
>>> # Standalone FP16 to BF16 conversion
>>> convert_fp16_to_bf16(
...     input_path="model_fp16.onnx",
...     output_path="model_bf16.onnx",
... )
>>>
>>> # Transpose DequantizeLinear weights for column-major storage
>>> transpose_dequantize_linear_weights(
...     model_path="model_quantized.onnx",
...     output_path="model_quantized_transposed.onnx",
... )

Functions

add_cross_kv_to_encoder

Add cross-attention KV cache outputs to encoder model.

convert_fp16_to_bf16

Convert an ONNX model from FP16 to BF16.

replace_attention_with_gqa

Replace attention subgraphs with GroupQueryAttention (GQA) in an ONNX model.

transpose_dequantize_linear_weights

Transpose weights in DequantizeLinear nodes for column-major storage.

add_cross_kv_to_encoder(encoder_path, output_path, hf_model_id, hidden_state_output_name='last_hidden_state', rename_input_features=True, use_external_data=True, external_data_name=None, decoder_filename='decoder_with_past_model.onnx', generate_genai_config=True, provider='cuda', verbose=True, trust_remote_code=False)

Add cross-attention KV cache outputs to encoder model.

This function transforms an Optimum-exported encoder model by adding cross-attention Key/Value projection outputs. This is required for ONNX Runtime GenAI compatibility, where the decoder expects pre-computed encoder K/V caches.

The transformation:

  1. Renames input_features -> audio_features (optional)

  2. Renames last_hidden_state -> encoder_hidden_states

  3. Adds K/V projection weights from the HuggingFace model

  4. Adds a MatMul -> Reshape -> Transpose subgraph for each layer

  5. Adds outputs: present_key_cross_X, present_value_cross_X

  6. Generates genai_config.json and audio_processor_config.json (optional)
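The per-layer MatMul -> Reshape -> Transpose subgraph in step 4 computes the standard KV-cache layout from the encoder hidden states. A minimal numpy sketch of what each subgraph evaluates (dimensions and variable names are illustrative; the real values come from the HuggingFace model config):

```python
import numpy as np

# Illustrative dimensions (the actual values come from the HF config)
batch, seq_len, hidden = 1, 6, 8
num_heads, head_dim = 2, 4  # num_heads * head_dim == hidden

hidden_states = np.random.rand(batch, seq_len, hidden).astype(np.float32)
w_k = np.random.rand(hidden, hidden).astype(np.float32)  # cross-attn K projection

# MatMul: project encoder hidden states to keys
k = hidden_states @ w_k  # (batch, seq, hidden)
# Reshape: split the hidden dim into heads
k = k.reshape(batch, seq_len, num_heads, head_dim)
# Transpose: to the (batch, num_heads, seq, head_dim) cache layout
present_key_cross = k.transpose(0, 2, 1, 3)

print(present_key_cross.shape)  # (1, 2, 6, 4)
```

The value path is identical with the V projection weight; the surgery emits one such subgraph pair per decoder layer.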

Parameters:
  • encoder_path (str) – Path to encoder ONNX model.

  • output_path (str) – Path to save modified encoder.

  • hf_model_id (str) – HuggingFace model ID for loading cross-attention weights.

  • hidden_state_output_name (str) – Name of encoder hidden state output.

  • rename_input_features (bool) – Whether to rename input_features to audio_features.

  • use_external_data (bool) – Whether to save weights as external data.

  • external_data_name (str | None) – Name for external data file.

  • decoder_filename (str) – Filename for decoder model in genai_config.json. Default is “decoder_with_past_model.onnx”.

  • generate_genai_config (bool) – Whether to generate genai_config.json.

  • provider (str) – Execution provider for genai_config.json (“cuda”, “cpu”, “NvTensorRtRtx”).

  • verbose (bool) – Whether to print progress messages.

  • trust_remote_code (bool) – Whether to trust remote code in HuggingFace model.

Returns:

Modified encoder model with cross-attention KV cache outputs.

Return type:

ModelProto

Example

>>> from modelopt.onnx.graph_surgery import add_cross_kv_to_encoder
>>> model = add_cross_kv_to_encoder(
...     encoder_path="encoder_model.onnx",
...     output_path="encoder_model_with_kv.onnx",
...     hf_model_id="openai/whisper-large-v3-turbo",
... )
convert_fp16_to_bf16(input_path, output_path, external_data=True, verbose=True)

Convert an ONNX model from FP16 to BF16.

This function converts:

  1. All FP16 initializers (weights) to BF16

  2. All FP16 value_info (intermediate tensors) to BF16

  3. All FP16 graph inputs/outputs to BF16

  4. All Cast nodes that target FP16 to target BF16
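Step 1 is essentially a bit-level truncation: FP16 widens exactly to FP32, and BF16 keeps the top 16 bits of the FP32 pattern. A sketch of that idea in numpy (the module's actual implementation may differ; numpy has no native bfloat16 dtype, so BF16 is represented here as raw uint16 bits, which matches how ONNX stores BF16 tensor data):

```python
import numpy as np

def fp16_to_bf16_bits(weights_fp16: np.ndarray) -> np.ndarray:
    """Convert FP16 values to BF16, returned as raw uint16 bit patterns."""
    f32 = weights_fp16.astype(np.float32)  # FP16 -> FP32 is exact
    bits = f32.view(np.uint32)
    # Round-to-nearest-even on the discarded low 16 bits, keep the top 16
    rounded = bits + 0x7FFF + ((bits >> 16) & 1)
    return (rounded >> 16).astype(np.uint16)

w = np.array([1.0, -2.5, 0.15625], dtype=np.float16)
bf16_bits = fp16_to_bf16_bits(w)
# Reconstruct FP32 from the BF16 bits to check the round-trip
restored = (bf16_bits.astype(np.uint32) << 16).view(np.float32)
print(restored)  # [ 1.      -2.5      0.15625]
```

These particular values survive the round-trip exactly because their mantissas fit in BF16's 7 mantissa bits; in general BF16 keeps FP16's range but loses precision.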

Parameters:
  • input_path (str) – Path to input FP16 ONNX model.

  • output_path (str) – Path to output BF16 ONNX model.

  • external_data (bool) – Whether to save weights as external data.

  • verbose (bool) – Whether to print progress messages.

Returns:

Dictionary with conversion statistics.

Return type:

dict[str, int]

Example

>>> stats = convert_fp16_to_bf16(
...     input_path="model_fp16.onnx",
...     output_path="model_bf16.onnx",
... )
>>> print(f"Converted {stats['initializers_converted']} initializers")
replace_attention_with_gqa(model_path, output_path, hf_model_id, max_seq_len=4096, io_dtype='float16', use_external_data=True, external_data_name=None, ir_version=None, pack_qkv=False, verbose=True, trust_remote_code=False)

Replace attention subgraphs with GroupQueryAttention (GQA) in an ONNX model.

This function transforms an ONNX model exported from HuggingFace/Optimum to use Microsoft’s GroupQueryAttention operator, which is optimized for inference with ONNX Runtime.

The transformation includes:

  • Converting weights to the target dtype (FP16/BF16) [non-quantized models only]

  • Adding RoPE cos/sin caches

  • Replacing attention patterns with GQA for all layers

  • Fusing Q/K/V projections into a single MatMul [non-quantized models only]

  • Concatenating Q/K/V outputs for GQA [quantized models only]

  • Adding past/present KV cache inputs/outputs
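The Q/K/V fusion step amounts to concatenating the three projection weight matrices column-wise so that a single MatMul produces all three outputs, which are then split. A small numpy sketch of the idea (dimensions are illustrative; the real weights come from the exported model, and GQA typically uses fewer K/V heads than Q heads, hence the narrower K/V projections):

```python
import numpy as np

# Illustrative dims: K/V projections are narrower than Q under GQA
hidden, q_dim, kv_dim = 8, 8, 4

w_q = np.random.rand(hidden, q_dim).astype(np.float32)
w_k = np.random.rand(hidden, kv_dim).astype(np.float32)
w_v = np.random.rand(hidden, kv_dim).astype(np.float32)

# Fuse the three projections into one MatMul by concatenating
# the weight columns: out = x @ [w_q | w_k | w_v]
w_qkv = np.concatenate([w_q, w_k, w_v], axis=1)

x = np.random.rand(2, hidden).astype(np.float32)
packed = x @ w_qkv
q, k, v = np.split(packed, [q_dim, q_dim + kv_dim], axis=1)

print(q.shape, k.shape, v.shape)  # (2, 8) (2, 4) (2, 4)
```

One fused MatMul replaces three, which reduces node count and lets the GQA kernel consume the packed projection directly.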

Parameters:
  • model_path (str) – Path to input ONNX model.

  • output_path (str) – Path to save modified model.

  • hf_model_id (str) – HuggingFace model ID for config.

  • max_seq_len (int) – Maximum sequence length for caches.

  • io_dtype (str) – Data type for I/O tensors (“float16”, “float32”, or “bfloat16”). If the model has FP16 initializers and “bfloat16” is specified, they are automatically converted to BF16.

  • use_external_data (bool) – Save weights as external data file.

  • external_data_name (str | None) – Name for external data file (default: model.onnx_data).

  • ir_version (int | None) – If specified, set the ONNX IR version to this value. Useful for compatibility with older ONNX Runtime versions (e.g., set to 9 for ORT 1.16).

  • verbose (bool) – Whether to print progress messages.

  • trust_remote_code (bool) – Whether to trust remote code in HuggingFace model config.

  • pack_qkv (bool) – Whether to pack the Q/K/V projections into a single fused QKV input for GQA.

Returns:

Modified ONNX model.

Return type:

ModelProto

Example

>>> from modelopt.onnx.graph_surgery import replace_attention_with_gqa
>>> model = replace_attention_with_gqa(
...     model_path="model_fp16.onnx",
...     output_path="model_gqa.onnx",
...     hf_model_id="meta-llama/Llama-2-7b-hf",
...     max_seq_len=4096,
...     io_dtype="float16",
... )
transpose_dequantize_linear_weights(model_path, output_path, use_external_data=True, external_data_name=None, verbose=True)

Transpose weights in DequantizeLinear nodes for column-major storage.

This function transforms a quantized ONNX model by:

  1. Finding all DequantizeLinear nodes that feed into MatMul/Gemm

  2. Transposing the quantized weights, scales, and zero points

  3. Updating the axis attribute (0 -> 1)

  4. Adding Transpose nodes after DequantizeLinear to recover the original shape

This optimization is useful for backends that prefer column-major weight storage, such as NvTensorRtRtx.
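The numerical effect of steps 2-4 can be sketched in numpy: transposing the stored weight and bumping the per-channel axis from 0 to 1, then transposing back after dequantization, reproduces the original dequantized tensor. (Zero points are omitted for brevity; the (K, N) shape and axis-0 per-channel layout are illustrative assumptions.)

```python
import numpy as np

# Per-channel quantized weight as consumed by DequantizeLinear:
# w_q is (K, N) int8 with one scale per row (axis=0).
K, N = 4, 3
w_q = np.random.randint(-128, 127, size=(K, N), dtype=np.int8)
scales = np.random.rand(K).astype(np.float32)

def dequantize(w, s, axis):
    shape = [1] * w.ndim
    shape[axis] = -1
    return w.astype(np.float32) * s.reshape(shape)

# Original graph: DequantizeLinear(w_q, scales, axis=0)
original = dequantize(w_q, scales, axis=0)

# After surgery: weight stored transposed (column-major relative to the
# original layout), axis updated 0 -> 1, and a trailing Transpose node
# restores the original shape for the downstream MatMul/Gemm.
transposed_stored = w_q.T  # (N, K)
recovered = dequantize(transposed_stored, scales, axis=1).T

print(np.allclose(recovered, original))  # True
```

The transformation is therefore numerically a no-op; only the storage layout of the quantized weight changes.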

Parameters:
  • model_path (str) – Path to input quantized ONNX model.

  • output_path (str) – Path to save modified model.

  • use_external_data (bool) – Whether to save weights as external data.

  • external_data_name (str | None) – Name for external data file.

  • verbose (bool) – Whether to print progress messages.

Returns:

Modified ONNX model.

Return type:

ModelProto

Example

>>> from modelopt.onnx.graph_surgery import transpose_dequantize_linear_weights
>>> model = transpose_dequantize_linear_weights(
...     model_path="model_quantized.onnx",
...     output_path="model_quantized_transposed.onnx",
... )