utils
Utility functions related to onnx.
Functions
bfloat16_to_float32 | Converts a bfloat16 array (as raw data) to a float32 array.
change_casts_to_fp16 | Change FP16-to-FP32 Cast nodes whose entire fanout feeds target ops to cast to FP16 instead.
check_model | Checks if the given model is valid.
check_model_uses_external_data | Checks if the model uses external data.
clear_stale_value_info | Clear stale type/shape metadata that would otherwise trip ORT's type checker.
duplicate_shared_constants | Duplicate constant tensors if they are shared.
find_lowest_common_ancestor | Function to find the lowest common ancestor of two nodes.
fold_dq_fp32_to_fp16_casts | Remove Cast(FP32->FP16) nodes after DequantizeLinear by setting DQ output to FP16.
fold_q_fp16_to_fp32_casts | Remove Cast(FP16→FP32) → Q patterns inserted by convert_float_to_float16.
fold_qdq_scale_fp16_to_fp32_casts | Remove Cast(FP16->FP32) nodes feeding into Q/DQ scale inputs.
gen_random_inputs | This function generates random inputs for an onnx model.
get_all_input_names | This function returns the input names of the given onnx model.
get_attribute | Returns the value of the specified attribute.
get_batch_size | Returns the batch size of the given onnx model.
get_batch_size_from_bytes | Returns the batch size of the given onnx model.
get_cast_to_type | Get the target type from a Cast node.
get_child_nodes | Returns list of output consumer nodes for the given node.
get_consumer_nodes | Get all consumer nodes for a given tensor name.
get_dynamic_graph_inputs | This function returns the dynamic inputs of an ONNX model.
get_input_names | This function returns the external input names of the given onnx model.
get_input_names_from_bytes | This function returns the input names of the given onnx model in bytes.
get_input_shapes | This function returns the input shapes for the given onnx model.
get_input_shapes_from_bytes | This function returns the input shapes of the given onnx model in bytes.
get_min_opset_for_precisions | Gets the minimum required opset version for a set of Q/DQ precision types.
get_node_names | This function returns all node names from the given onnx model.
get_node_names_from_bytes | This function returns all node names from the given onnx model in bytes.
get_opset_version | Returns the opset version of the given model.
get_output_names | This function returns the output names of the given onnx model.
get_output_names_from_bytes | This function returns the output names of the given onnx model in bytes.
get_output_shapes | This function returns the output shapes for the given onnx model.
get_parent_nodes | Returns list of input producer nodes for the given node.
get_producer_nodes | Get all producer nodes for a given tensor name.
get_qdq_precisions | Gets the Q/DQ precision types present in the model.
get_tensor_by_name | This function returns a tensor from its name.
get_variable_inputs | Returns the variable inputs of the given Node.
has_attribute | Checks if the given node has the specified attribute.
infer_shapes | Infers shapes of the onnx graph, handles large models.
infer_types | Infers types (and optionally shapes) based on the use_standalone_type_inference flag.
infer_types_verification | Verify that all reachable tensors have a defined type.
name_onnx_nodes | Assigns names to the onnx nodes that lack one and returns whether the graph was modified.
onnx_type_str_to_enum | Converts ONNX type in string format to onnx.TensorProto format.
parse_shapes_spec | Parses a shapes spec and returns the shapes in a dictionary.
randomize_weights | Assigns random values to the onnx model weights.
randomize_weights_onnx_bytes | Assigns random values to the onnx model weights.
read_f16_tensor_as_fp32 | Reads a float16 or bfloat16 tensor as a float32 numpy ndarray.
remove_node_training_mode | Remove training_mode attribute and extra training outputs from nodes of a given op type.
remove_redundant_casts | Removes both sequential casts and casts that don't change precision.
remove_weights_data | Removes raw weight data from the onnx model.
save_onnx | Save an ONNX model to the given path.
save_onnx_bytes_to_dir | Saves the onnx bytes to a directory with specified file name.
update_domain | Updates the domain of all the nodes of the specified op_type to the specified domain.
validate_batch_size | Returns True if all the model inputs have a batch dimension equal to batch_size.
validate_onnx | Returns True if the onnx_bytes is valid, else False.
- bfloat16_to_float32(bf16_array)
Converts a bfloat16 array (as raw data) to a float32 array.
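The conversion relies on bfloat16 being the upper 16 bits of an IEEE-754 float32. A minimal pure-Python sketch of that idea follows; the helper name `bf16_bits_to_float32` is hypothetical, and the library function itself operates on arrays rather than single integers.

```python
import struct


def bf16_bits_to_float32(bits: int) -> float:
    """Widen one bfloat16 value, given as a 16-bit integer, to a Python float.

    Hypothetical helper for illustration; not part of the documented module.
    """
    # Place the 16 bfloat16 bits in the high half of a 32-bit word and
    # reinterpret that word as a little-endian float32.
    widened = (bits & 0xFFFF) << 16
    return struct.unpack("<f", struct.pack("<I", widened))[0]


# 0x3F80 and 0xC000 are the bfloat16 bit patterns of 1.0 and -2.0.
values = [bf16_bits_to_float32(b) for b in (0x3F80, 0xC000, 0x0000)]
print(values)
```

Because the low 16 mantissa bits are simply zero-filled, the widening is exact and needs no rounding.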
- change_casts_to_fp16(model, target_op_types)
Change FP16-to-FP32 Cast nodes whose entire fanout feeds target ops to cast to FP16 instead.
- Parameters:
model (ModelProto) – The ONNX model to modify.
target_op_types (list[str]) – List of op types to check for. Cast nodes feeding exclusively into these will be changed from FP32 to FP16.
- Returns:
The modified ONNX model with Cast nodes updated.
- Return type:
ModelProto
- check_model(model)
Checks if the given model is valid.
- Parameters:
model (ModelProto)
- Return type:
None
- check_model_uses_external_data(model)
Checks if the model uses external data. True if any initializer tensor has data_location set to EXTERNAL.
- Parameters:
model (ModelProto)
- Return type:
bool
- clear_stale_value_info(model)
Clear stale type/shape metadata that would otherwise trip ORT’s type checker.
Walks every Cast node and forces the elem_type of any graph.output entry produced by that Cast to match the Cast’s to attribute (the spec-defined contract for a Cast’s output dtype). Clears value_info wholesale so ORT/shape-inference re-derives intermediate-tensor types from the operator graph during session setup. Finally, reconciles stale graph.output shapes (e.g. a leftover rank-0 scalar on a tensor that is really rank-2+) which would otherwise propagate a wrong rank into downstream shape inference.
- Parameters:
model (ModelProto) – Loaded in-memory onnx ModelProto.
- Returns:
Total number of entries reconciled or cleared.
- Return type:
int
- duplicate_shared_constants(onnx_model)
Duplicate constant tensors if they are shared.
- Parameters:
onnx_model (ModelProto)
- Return type:
tuple[ModelProto, bool]
- find_lowest_common_ancestor(node1, node2)
Function to find the lowest common ancestor of two nodes.
- Parameters:
node1 (Node) – First node.
node2 (Node) – Second node.
- Returns:
Tuple of the LCA node name, the distance from the first node, and the distance from the second node.
- Return type:
tuple[str | None, int, int]
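One way to compute such a result is to collect breadth-first distances from each node to its ancestors and pick the shared ancestor with the smallest combined distance. The sketch below assumes a plain dictionary `parents` mapping node names to producer names, and returns `-1` distances when nothing is shared; both are illustrative choices, not the library's behavior.

```python
from collections import deque


def ancestor_distances(start, parents):
    """Breadth-first distances from `start` to itself and each of its ancestors."""
    dist = {start: 0}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for parent in parents.get(node, []):
            if parent not in dist:
                dist[parent] = dist[node] + 1
                queue.append(parent)
    return dist


def lowest_common_ancestor(node1, node2, parents):
    """Return (ancestor name or None, distance from node1, distance from node2)."""
    d1 = ancestor_distances(node1, parents)
    d2 = ancestor_distances(node2, parents)
    common = set(d1) & set(d2)
    if not common:
        # Assumed sentinel distances for the "no shared ancestor" case.
        return None, -1, -1
    # Closest shared ancestor; ties are broken by name for determinism.
    best = min(common, key=lambda n: (d1[n] + d2[n], n))
    return best, d1[best], d2[best]


# Hypothetical graph: input -> relu -> {conv_a, conv_b} -> add
parents = {
    "add": ["conv_a", "conv_b"],
    "conv_a": ["relu"],
    "conv_b": ["relu"],
    "relu": ["input"],
}
lca = lowest_common_ancestor("conv_a", "conv_b", parents)
print(lca)
```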
- fold_dq_fp32_to_fp16_casts(onnx_model)
Remove Cast(FP32->FP16) nodes after DequantizeLinear by setting DQ output to FP16.
When convert_float_to_float16 blocks DequantizeLinear, it inserts Cast nodes to bridge the FP32 DQ output to the FP16 graph. This function removes those Cast nodes by:
1. Converting the DQ scale initializer from FP32 to FP16
2. Updating the DQ output type to FP16 in value_info
3. Bypassing and removing the Cast node
NVFP4 uses a nested DQ chain (scale is itself a DQ output). When the outer DQ’s scale is produced by another DQ, recursively retype the inner DQ’s chain so the whole chain produces FP16 tensors under strongly-typed TRT parsing.
- Parameters:
onnx_model (ModelProto) – The ONNX model with DQ -> Cast(FP32->FP16) patterns.
- Returns:
The ONNX model with Cast nodes removed and DQ outputs set to FP16.
- Return type:
ModelProto
- fold_q_fp16_to_fp32_casts(onnx_model)
Remove Cast(FP16→FP32) → Q patterns inserted by convert_float_to_float16.
The Q scale is rewritten to FP16 so Q consumes the FP16 graph directly. Skipped for opsets below BASE_MIN_OPSET since FP16 Q scales require opset >= 19.
- Parameters:
onnx_model (ModelProto)
- Return type:
ModelProto
- fold_qdq_scale_fp16_to_fp32_casts(onnx_model)
Remove Cast(FP16->FP32) nodes feeding into Q/DQ scale inputs.
When convert_float_to_float16 blocks QuantizeLinear/DequantizeLinear, it inserts Cast(FP16->FP32) nodes before every scale input. In opset >=20 Q/DQ natively accept FP16 scales, and leaving the cast in place forces DQ outputs to FP32, breaking downstream FP16 matmul/add operations under strongly-typed TRT parsing.
This function bypasses each such Cast and, when the upstream Constant is FP16, wires the DQ output to FP16 in value_info so shape inference stays consistent.
- Parameters:
onnx_model (ModelProto) – The ONNX model with Cast(FP16->FP32) -> Q/DQ.scale patterns.
- Returns:
The ONNX model with redundant scale-path casts removed.
- Return type:
ModelProto
- gen_random_inputs(model, shapes_spec=None)
This function generates random inputs for an onnx model.
- Parameters:
model (ModelProto) – Loaded in-memory onnx ModelProto.
shapes_spec (str | None) – A string representing the shape of each input tensor. The format is "<tensor1>:<d1>x<d2>,<tensor2>:<d1>,…". If the shape is not provided for an input tensor, the shape is inferred from the onnx model directly, with all the unknown dims filled with 1.
- Returns:
Dictionary of numpy tensors.
- Return type:
dict[str, ndarray]
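The shape-resolution rule described for shapes_spec can be sketched without any ONNX dependency. In the sketch below, `model_shapes` stands in for the shapes read from the model, symbolic or negative dims count as unknown, and plain lists replace numpy arrays; all names are hypothetical.

```python
import random


def resolve_shape(model_shape, override=None):
    """Use the override if given; otherwise fill every unknown dim with 1."""
    if override is not None:
        return list(override)
    return [d if isinstance(d, int) and d > 0 else 1 for d in model_shape]


def random_inputs(model_shapes, overrides=None, seed=0):
    """Build a dict of flat random data per input, keyed by input name."""
    rng = random.Random(seed)
    overrides = overrides or {}
    inputs = {}
    for name, shape in model_shapes.items():
        dims = resolve_shape(shape, overrides.get(name))
        count = 1
        for d in dims:
            count *= d
        inputs[name] = {"shape": dims, "data": [rng.random() for _ in range(count)]}
    return inputs


# "batch" and -1 are unknown dims; "mask" gets an explicit shape override.
model_shapes = {"image": ["batch", 3, 4, 4], "mask": [-1, 8]}
inputs = random_inputs(model_shapes, overrides={"mask": [2, 8]})
print({name: value["shape"] for name, value in inputs.items()})
```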
- get_all_input_names(model)
This function returns the input names of the given onnx model.
- Parameters:
model (ModelProto)
- Return type:
list[str]
- get_attribute(node, attr_name)
Returns the value of the specified attribute.
- Parameters:
node (NodeProto)
attr_name (str)
- Return type:
Any
- get_batch_size(model)
Returns the batch size of the given onnx model.
An assertion fails if the batch size is not the same across all inputs.
- Parameters:
model (ModelProto)
- Return type:
int
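The consistency check can be illustrated on a plain mapping of input names to shapes. This sketch assumes the batch dimension is the first dimension of every input; the helper name `batch_size_of` is hypothetical.

```python
def batch_size_of(input_shapes):
    """Return the shared leading dimension, asserting that all inputs agree."""
    batch_sizes = {name: shape[0] for name, shape in input_shapes.items()}
    unique = set(batch_sizes.values())
    assert len(unique) == 1, f"Inconsistent batch sizes: {batch_sizes}"
    return unique.pop()


batch = batch_size_of({"image": [8, 3, 224, 224], "mask": [8, 224, 224]})
print(batch)
```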
- get_batch_size_from_bytes(onnx_bytes)
Returns the batch size of the given onnx model in bytes.
An assertion fails if the batch size is not the same across all inputs.
- Parameters:
onnx_bytes (bytes)
- Return type:
int
- get_cast_to_type(cast_node)
Get the target type from a Cast node.
- Parameters:
cast_node (NodeProto) – The Cast node to extract type from.
- Returns:
The target type value from the Cast node’s ‘to’ attribute.
- Return type:
int
- Raises:
ValueError – If the Cast node does not have a ‘to’ attribute.
- get_child_nodes(node)
Returns list of output consumer nodes for the given node.
- Parameters:
node (Node)
- Return type:
list[Node]
- get_consumer_nodes(model, tensor_name)
Get all consumer nodes for a given tensor name.
- Parameters:
model (ModelProto) – The ONNX model to search.
tensor_name (str) – Name of the tensor to find consumers for.
- Returns:
List of nodes that consume the tensor.
- Return type:
list[onnx.NodeProto]
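Conceptually, a consumer lookup scans every node's inputs for the tensor name, and a producer lookup does the same over outputs. The sketch below uses a small hypothetical `Node` dataclass in place of onnx.NodeProto so that it runs without ONNX installed.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for a graph node with named input/output tensors."""
    name: str
    op_type: str
    inputs: list = field(default_factory=list)
    outputs: list = field(default_factory=list)


def consumers_of(nodes, tensor_name):
    """Nodes that read the tensor."""
    return [n for n in nodes if tensor_name in n.inputs]


def producers_of(nodes, tensor_name):
    """Nodes that write the tensor."""
    return [n for n in nodes if tensor_name in n.outputs]


graph = [
    Node("conv", "Conv", ["x", "w"], ["conv_out"]),
    Node("relu", "Relu", ["conv_out"], ["relu_out"]),
    Node("add", "Add", ["conv_out", "relu_out"], ["y"]),
]
consumer_names = [n.name for n in consumers_of(graph, "conv_out")]
producer_names = [n.name for n in producers_of(graph, "conv_out")]
print(consumer_names, producer_names)
```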
- get_dynamic_graph_inputs(onnx_model)
This function returns the dynamic inputs of an ONNX model.
- Parameters:
onnx_model (ModelProto) – ONNX model to obtain dynamic inputs from.
- Returns:
List of dynamic inputs.
- get_input_names(model, external_inputs_only=True)
This function returns the external input names of the given onnx model.
Note: external_input_names = input_names - initializer_names
- Parameters:
model (ModelProto) – Loaded in-memory onnx ModelProto.
external_inputs_only (bool)
- Returns:
List of external input names of the model.
- Return type:
list[str]
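The note above amounts to an order-preserving set difference. A minimal sketch, with plain lists standing in for the names found in graph.input and graph.initializer (the helper name is hypothetical):

```python
def external_input_names(graph_input_names, initializer_names):
    """Keep graph inputs that are not also initializers, preserving order."""
    initializers = set(initializer_names)
    return [name for name in graph_input_names if name not in initializers]


names = external_input_names(["image", "conv.weight", "conv.bias"],
                             ["conv.weight", "conv.bias"])
print(names)
```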
- get_input_names_from_bytes(model_bytes, external_inputs_only=True)
This function returns the input names of the given onnx model in bytes.
- Parameters:
model_bytes (bytes) – Onnx model in bytes.
external_inputs_only (bool)
- Returns:
List of input names of the model.
- Return type:
list[str]
- get_input_shapes(model, external_inputs_only=True)
This function returns the input shapes for the given onnx model.
- Parameters:
model (ModelProto)
external_inputs_only (bool)
- Return type:
dict[str, list[int]]
- get_input_shapes_from_bytes(model_bytes)
This function returns the input shapes of the given onnx model in bytes.
- Parameters:
model_bytes (bytes) – Onnx model in bytes.
- Returns:
Dictionary of input names and shapes.
- Return type:
dict[str, list[int]]
- get_min_opset_for_precisions(precisions)
Gets the minimum required opset version for a set of Q/DQ precision types.
- Parameters:
precisions (set) – Set of precision type strings (e.g., ‘float8_e4m3fn’, ‘int4’).
- Returns:
Minimum required opset version for the given precisions.
- Return type:
int
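The result can be thought of as the maximum over per-precision minimum opsets, floored at a base value. The table and base value in the sketch below are assumptions, loosely based on the opsets in which ONNX introduced each data type; they are not taken from this module.

```python
# Assumed values for illustration only; the module's real table may differ.
ASSUMED_BASE_MIN_OPSET = 19
ASSUMED_MIN_OPSET = {
    "int8": 19,
    "float8_e4m3fn": 19,
    "int4": 21,
    "float4_e2m1fn": 23,
}


def min_opset_for(precisions):
    """Highest per-precision requirement, never below the assumed base opset."""
    required = [ASSUMED_MIN_OPSET.get(p, ASSUMED_BASE_MIN_OPSET) for p in precisions]
    return max([ASSUMED_BASE_MIN_OPSET] + required)


print(min_opset_for({"int8"}), min_opset_for({"int4", "float8_e4m3fn"}))
```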
- get_node_names(model)
This function returns all node names from the given onnx model.
- Parameters:
model (ModelProto) – Loaded in-memory onnx ModelProto.
- Returns:
List of node names of the model.
- Return type:
list[str]
- get_node_names_from_bytes(model_bytes)
This function returns all node names from the given onnx model in bytes.
- Parameters:
model_bytes (bytes) – Onnx model in bytes.
- Returns:
List of node names of the model.
- Return type:
list[str]
- get_opset_version(model)
Returns the opset version of the given model.
- Parameters:
model (ModelProto)
- Return type:
int
- get_output_names(model)
This function returns the output names of the given onnx model.
- Parameters:
model (ModelProto) – Loaded in-memory onnx ModelProto.
- Returns:
List of output names of the model.
- Return type:
list[str]
- get_output_names_from_bytes(model_bytes)
This function returns the output names of the given onnx model in bytes.
- Parameters:
model_bytes (bytes) – Onnx model in bytes.
- Returns:
List of output names of the model.
- Return type:
list[str]
- get_output_shapes(model)
This function returns the output shapes for the given onnx model.
- Parameters:
model (ModelProto)
- Return type:
dict[str, list[int]]
- get_parent_nodes(node)
Returns list of input producer nodes for the given node.
- Parameters:
node (Node)
- Return type:
list[Node]
- get_producer_nodes(model, tensor_name)
Get all producer nodes for a given tensor name.
- Parameters:
model (ModelProto) – The ONNX model to search.
tensor_name (str) – Name of the tensor to find producers for.
- Returns:
List of nodes that produce the tensor.
- Return type:
list[onnx.NodeProto]
- get_qdq_precisions(model)
Gets the Q/DQ precision types present in the model.
- Parameters:
model (ModelProto) – Loaded in-memory onnx ModelProto.
- Returns:
Set of Q/DQ precision types present in the model (e.g., ‘float8_e4m3fn’, ‘int8’, ‘int4’, ‘float4_e2m1fn’).
- Return type:
set
- get_tensor_by_name(onnx_model, tensor_name)
This function returns a tensor from its name.
This function searches for a tensor in the model’s:
1. Value info (shape/type info, no data)
2. Initializers (TensorProto, contains actual data)
3. Inputs and outputs
- Parameters:
onnx_model (ModelProto) – ONNX model.
tensor_name (str) – tensor name.
- Returns:
The matching tensor, or None if no tensor with that name is found.
- Return type:
ValueInfoProto | TensorProto | None
- get_variable_inputs(node)
Returns the variable inputs of the given Node.
- Parameters:
node (Node)
- Return type:
list[Variable]
- has_attribute(node, attr_name)
Checks if the given node has the specified attribute.
- Parameters:
node (NodeProto)
attr_name (str)
- Return type:
bool
- infer_shapes(model, **kwargs)
Infers shapes of the onnx graph, handles large models.
- Parameters:
model (ModelProto)
- infer_types(model, use_standalone_type_inference=False, **kwargs)
Infers types (and optionally shapes) based on the use_standalone_type_inference flag.
When use_standalone_type_inference is True, uses a standalone type inference implementation that only infers types. Otherwise, uses ONNX’s infer_shapes which infers both types and shapes.
ONNX’s infer_shapes can fail on weakly-typed models: with strict_mode=True it raises on an op it cannot resolve (e.g. a TopK whose axis it resolves to a stale dimension) instead of silently leaving that node’s outputs untyped. On any shape-inference failure this falls back to the standalone type inferencer, which derives types from operator schemas regardless of shapes, so downstream type lookups (e.g. in AutoCast) do not fail. Callers that need a fully typed graph should pass strict_mode=True so incomplete inference surfaces as an exception that triggers the fallback.
- Parameters:
model (ModelProto) – ONNX model to infer types/shapes for.
use_standalone_type_inference (bool) – If True, use standalone type inference (_infer_types_only). If False, use ONNX’s shape inference (infer_shapes).
**kwargs – Additional arguments passed to infer_shapes when not using standalone type inference (e.g. strict_mode, check_type, data_prop).
- Returns:
Model with inferred types (and shapes if not using standalone type inference).
- Return type:
onnx.ModelProto
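The control flow described above (try shape inference, fall back to types-only inference on failure) can be sketched with injected callables. `shape_inferencer` and `type_inferencer` are hypothetical stand-ins for ONNX's infer_shapes and the standalone inferencer, and the returned path label exists only to make the chosen branch visible.

```python
def infer_with_fallback(model, shape_inferencer, type_inferencer,
                        use_standalone=False, **kwargs):
    """Return (model, path), where path names the inferencer that produced it."""
    if use_standalone:
        return type_inferencer(model), "standalone"
    try:
        return shape_inferencer(model, **kwargs), "shape_inference"
    except Exception:
        # Any shape-inference failure falls back to types-only inference.
        return type_inferencer(model), "standalone_fallback"


def failing_shape_inference(model, strict_mode=False):
    """Toy inferencer that only raises when strict_mode is requested."""
    if strict_mode:
        raise RuntimeError("could not resolve TopK axis")
    return model


def types_only(model):
    """Toy types-only inferencer that returns the model unchanged."""
    return model


strict_result = infer_with_fallback("model", failing_shape_inference, types_only,
                                    strict_mode=True)
lenient_result = infer_with_fallback("model", failing_shape_inference, types_only)
print(strict_result, lenient_result)
```

The toy inferencer shows why strict_mode matters: without it the failure stays silent and the fallback is never triggered.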
- infer_types_verification(model)
Verify that all reachable tensors have a defined type.
This is necessary because some nodes may be removed during the inference process, leaving unreachable value_info entries.
- Parameters:
model (ModelProto)
- Return type:
ModelProto
- name_onnx_nodes(graph)
Assigns names to the onnx nodes that lack one and returns whether the graph was modified.
- Parameters:
graph (GraphProto)
- Return type:
bool
- onnx_type_str_to_enum(dtype)
Converts ONNX type in string format to onnx.TensorProto format.
Example: ‘tensor(float16)’ becomes onnx.TensorProto.FLOAT16
- Parameters:
dtype (str) – ONNX type in string format.
- Returns:
ONNX type in enum format.
- Return type:
int
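A string such as ‘tensor(float16)’ can be mapped to its enum value by extracting the element type and looking it up by upper-cased name. To stay runnable without the onnx package, the sketch below hardcodes a subset of the onnx.TensorProto data-type values; the helper name and the regular expression are illustrative assumptions.

```python
import re

# Subset of onnx.TensorProto.DataType values, hardcoded for illustration.
TENSOR_PROTO_ENUM = {
    "FLOAT": 1, "UINT8": 2, "INT8": 3, "UINT16": 4, "INT16": 5,
    "INT32": 6, "INT64": 7, "STRING": 8, "BOOL": 9, "FLOAT16": 10,
    "DOUBLE": 11, "UINT32": 12, "UINT64": 13, "BFLOAT16": 16,
}


def type_str_to_enum(dtype: str) -> int:
    """Map a string like 'tensor(float16)' to its TensorProto enum value."""
    match = re.fullmatch(r"tensor\((\w+)\)", dtype)
    if match is None:
        raise ValueError(f"Unrecognized ONNX type string: {dtype!r}")
    return TENSOR_PROTO_ENUM[match.group(1).upper()]


print(type_str_to_enum("tensor(float16)"))
```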
- parse_shapes_spec(shapes_spec)
Parses a shapes spec and returns the shapes in a dictionary.
Example shapes spec: input0:1x3x256x256,input1:1x3x128x128
- Parameters:
shapes_spec (str)
- Return type:
dict[str, list[int]]
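A parser for the format shown in the example can be sketched in a few lines. Splitting on the last colon, so that tensor names containing colons still parse, is an assumption of this sketch rather than documented behavior; the helper name is hypothetical.

```python
def parse_spec(shapes_spec: str) -> dict:
    """Parse 'name:d1xd2,...' into a dict mapping names to lists of ints."""
    shapes = {}
    for item in shapes_spec.split(","):
        # Split on the last colon; everything before it is the tensor name.
        name, _, dims = item.strip().rpartition(":")
        shapes[name] = [int(d) for d in dims.split("x")]
    return shapes


parsed = parse_spec("input0:1x3x256x256,input1:1x3x128x128")
print(parsed)
```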
- randomize_weights(onnx_path)
Assigns random values to the onnx model weights.
- Parameters:
onnx_path (str)
- Return type:
None
- randomize_weights_onnx_bytes(onnx_bytes, seed=0)
Assigns random values to the onnx model weights.
- Parameters:
onnx_bytes (bytes)
seed (int)
- Return type:
bytes
- read_f16_tensor_as_fp32(tensor)
Reads a float16 or bfloat16 tensor as a float32 numpy ndarray.
- remove_node_training_mode(onnx_model, node_op_type)
Remove training_mode attribute and extra training outputs from nodes of a given op type.
This also removes the unused outputs from the training_mode nodes.
- Parameters:
onnx_model (ModelProto) – The onnx model.
node_op_type (str) – The node type to remove training_mode attribute from.
- Returns:
The onnx model with the training_mode attribute removed.
- Return type:
ModelProto
- remove_redundant_casts(onnx_model)
Removes both sequential casts and casts that don’t change precision.
This method optimizes the graph by removing unnecessary cast operations that either:
1. Don’t actually change the data type
2. Could be replaced by a single cast operation
3. Can be folded into a preceding Constant node
- Parameters:
onnx_model (ModelProto) – The ONNX model to optimize.
- Returns:
Model with redundant casts removed.
- Return type:
onnx.ModelProto
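The first case, a Cast whose target type equals its input type, can be removed by rewiring consumers to the Cast's input. The toy sketch below covers only that case; it assumes nodes are plain dictionaries in topological order and that `tensor_types` gives each tensor's known dtype. These structures are hypothetical, and a real pass would also have to handle graph outputs.

```python
def remove_noop_casts(nodes, tensor_types):
    """Drop Cast nodes whose 'to' type equals the input type, rewiring consumers."""
    rename = {}  # maps a removed Cast output to the tensor that replaces it
    kept = []
    for node in nodes:
        # Apply earlier renames so consumers read the bypassed tensor.
        node = {**node, "inputs": [rename.get(t, t) for t in node["inputs"]]}
        if node["op"] == "Cast" and tensor_types.get(node["inputs"][0]) == node["to"]:
            rename[node["outputs"][0]] = node["inputs"][0]
            continue
        kept.append(node)
    return kept


nodes = [
    {"op": "Relu", "inputs": ["x"], "outputs": ["a"]},
    {"op": "Cast", "inputs": ["a"], "outputs": ["b"], "to": "float16"},
    {"op": "Add", "inputs": ["b", "c"], "outputs": ["y"]},
]
optimized = remove_noop_casts(nodes, {"x": "float16", "a": "float16"})
print([(n["op"], n["inputs"]) for n in optimized])
```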
- remove_weights_data(onnx_bytes)
Removes raw weight data from the onnx model.
- Parameters:
onnx_bytes (bytes)
- Return type:
bytes
- save_onnx(model, onnx_path, save_as_external_data=False)
Save an ONNX model to the given path. If the model is larger than 2GB, it is saved with external data.
- Parameters:
model (ModelProto)
onnx_path (str)
save_as_external_data (bool)
- save_onnx_bytes_to_dir(onnx_bytes, onnx_dir, onnx_name)
Saves the onnx bytes to a directory with specified file name.
- Parameters:
onnx_bytes (bytes)
onnx_dir (str)
onnx_name (str)
- Return type:
None
- update_domain(onnx_model, op_type, domain)
Updates the domain of all the nodes of the specified op_type to the specified domain.
- Parameters:
onnx_model (ModelProto)
op_type (str)
domain (str)
- Return type:
ModelProto
- validate_batch_size(onnx_bytes, batch_size)
Returns True if all the model inputs have a batch dimension equal to batch_size.
- Parameters:
onnx_bytes (bytes)
batch_size (int)
- Return type:
bool
- validate_onnx(onnx_bytes)
Returns True if the onnx_bytes is valid, else False.
- Parameters:
onnx_bytes (bytes)
- Return type:
bool