utils

Utility functions related to onnx.

Functions

bfloat16_to_float32

Converts a bfloat16 array (as raw data) to a float32 array.

change_casts_to_fp16

Change FP16-to-FP32 Cast nodes whose entire fanout feeds target ops to cast to FP16 instead.

check_model

Checks if the given model is valid.

check_model_uses_external_data

Checks if the model uses external data.

duplicate_shared_constants

Duplicate constant tensors if they are shared.

find_lowest_common_ancestor

Finds the lowest common ancestor of two nodes.

fix_fp16_fp32_mismatches

Insert Cast nodes to resolve FP32/FP16 type mismatches after blocked-op FP16 conversion.

gen_random_inputs

This function generates random inputs for an onnx model.

get_all_input_names

This function returns the input names of the given onnx model.

get_attribute

Returns the value of the specified attribute.

get_batch_size

Returns the batch size of the given onnx model.

get_batch_size_from_bytes

Returns the batch size of the given onnx model.

get_cast_to_type

Get the target type from a Cast node.

get_child_nodes

Returns list of output consumer nodes for the given node.

get_consumer_nodes

Get all consumer nodes for a given tensor name.

get_dynamic_graph_inputs

This function returns the dynamic inputs of an ONNX model.

get_input_names

This function returns the external input names of the given onnx model.

get_input_names_from_bytes

This function returns the input names of the given onnx model in bytes.

get_input_shapes

This function returns the input shapes for the given onnx model.

get_input_shapes_from_bytes

This function returns the input shapes of the given onnx model in bytes.

get_min_opset_for_precisions

Gets the minimum required opset version for a set of Q/DQ precision types.

get_node_names

This function returns all node names from the given onnx model.

get_node_names_from_bytes

This function returns all node names from the given onnx model in bytes.

get_opset_version

Returns the opset version of the given model.

get_output_names

This function returns the output names of the given onnx model.

get_output_names_from_bytes

This function returns the output names of the given onnx model in bytes.

get_output_shapes

This function returns the output shapes for the given onnx model.

get_parent_nodes

Returns list of input producer nodes for the given node.

get_producer_nodes

Get all producer nodes for a given tensor name.

get_qdq_precisions

Gets the Q/DQ precision types present in the model.

get_tensor_by_name

This function returns a tensor from its name.

get_variable_inputs

Returns the variable inputs of the given Node.

has_attribute

Checks if the given node has the specified attribute.

infer_shapes

Infers shapes of the onnx graph, handles large models.

infer_types

Infers types (and optionally shapes) based on the use_standalone_type_inference flag.

infer_types_verification

Verify that all reachable tensors have a defined type.

name_onnx_nodes

Assigns names to onnx nodes that lack one and returns the modified status.

onnx_type_str_to_enum

Converts ONNX type in string format to onnx.TensorProto format.

parse_shapes_spec

Parses a shapes spec and returns the shapes in a dictionary.

randomize_weights

Assigns random values to the onnx model weights.

randomize_weights_onnx_bytes

Assigns random values to the onnx model weights.

read_f16_tensor_as_fp32

Reads a float16 or bfloat16 tensor as a float32 numpy ndarray.

remove_node_training_mode

Remove training_mode attribute and extra training outputs from nodes of a given op type.

remove_redundant_casts

Removes both sequential casts and casts that don't change precision.

remove_weights_data

Removes raw weight data from the onnx model.

save_onnx

Save an ONNX model to given path.

save_onnx_bytes_to_dir

Saves the onnx bytes to a directory with specified file name.

update_domain

Updates the domain of all the nodes of the specified op_type to the specified domain.

validate_batch_size

Returns True if all the model inputs have a batch dimension equal to batch_size.

validate_onnx

Returns True if the onnx_bytes is valid, else False.

bfloat16_to_float32(bf16_array)

Converts a bfloat16 array (as raw data) to a float32 array.
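
As an illustration (a sketch, not the library's implementation), the widening can be done with the standard struct module: bfloat16 is the upper 16 bits of an IEEE-754 float32, so each value is shifted left 16 bits and reinterpreted as float32. The function names here are hypothetical.

```python
import struct

def bf16_bits_to_float32(bits: int) -> float:
    # bfloat16 keeps the sign, exponent, and top 7 mantissa bits of a
    # float32, so widening is a 16-bit left shift plus a reinterpretation.
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

def bfloat16_raw_to_float32(raw: bytes) -> list[float]:
    # Interpret little-endian raw data as uint16 values and widen each one.
    count = len(raw) // 2
    return [bf16_bits_to_float32(v) for v in struct.unpack(f"<{count}H", raw)]
```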

change_casts_to_fp16(model, target_op_types)

Change FP16-to-FP32 Cast nodes whose entire fanout feeds target ops to cast to FP16 instead.

Parameters:
  • model (ModelProto) – The ONNX model to modify.

  • target_op_types (list[str]) – List of op types to check for. Cast nodes feeding exclusively into these will be changed from FP32 to FP16.

Returns:

The modified ONNX model with Cast nodes updated.

Return type:

ModelProto

check_model(model)

Checks if the given model is valid.

Parameters:

model (ModelProto)

Return type:

None

check_model_uses_external_data(model)

Checks if the model uses external data. True if any initializer tensor has data_location set to EXTERNAL.

Parameters:

model (ModelProto)

Return type:

bool

duplicate_shared_constants(onnx_model)

Duplicate constant tensors if they are shared.

Parameters:

onnx_model (ModelProto)

Return type:

tuple[ModelProto, bool]

find_lowest_common_ancestor(node1, node2)

Finds the lowest common ancestor of two nodes.

Parameters:
  • node1 (Node) – First node.

  • node2 (Node) – Second node.

Returns:

The name of the LCA node (or None if there is none), the distance from the first node, and the distance from the second node.

Return type:

tuple[str | None, int, int]
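
A minimal sketch of one way such a search could work, assuming a simplified graph given as a parent map (a stand-in for walking NodeProto inputs); the function name and the tie-breaking rule (minimize the summed distance) are assumptions, not the library's documented behavior.

```python
from collections import deque

def find_lca_sketch(parents: dict[str, list[str]], node1: str, node2: str):
    """Return (lca_name_or_None, dist_from_node1, dist_from_node2)."""
    def ancestors(start):
        # BFS upward through producers, recording the hop count to each.
        dist = {start: 0}
        queue = deque([start])
        while queue:
            cur = queue.popleft()
            for p in parents.get(cur, []):
                if p not in dist:
                    dist[p] = dist[cur] + 1
                    queue.append(p)
        return dist

    d1, d2 = ancestors(node1), ancestors(node2)
    common = set(d1) & set(d2)
    if not common:
        return None, -1, -1
    lca = min(common, key=lambda n: d1[n] + d2[n])
    return lca, d1[lca], d2[lca]
```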

fix_fp16_fp32_mismatches(model)

Insert Cast nodes to resolve FP32/FP16 type mismatches after blocked-op FP16 conversion.

After convert_float_to_float16 with an op_block_list, FP32 data from blocked ops (e.g., QDQ paths) can flow into nodes whose other inputs are FP16. TensorRT --stronglyTyped rejects such mismatches. This function propagates “real” types through the graph and inserts FP32->FP16 Cast nodes where needed.

Note: value_info types are unreliable after convert_float_to_float16 with blocked ops (metadata may say FP16 even when actual data is FP32), so this function re-derives types by following op semantics.

Parameters:

model (ModelProto) – The ONNX model to fix.

Returns:

The modified ONNX model with Cast nodes inserted to resolve mismatches.

Return type:

ModelProto

gen_random_inputs(model, shapes_spec=None)

This function generates random inputs for an onnx model.

Parameters:
  • model (ModelProto) – Loaded in-memory onnx ModelProto.

  • shapes_spec (str | None) – A string representing the shape of each input tensor, in the format "<tensor1>:<d1>x<d2>,<tensor2>:<d1>,…". If the shape is not provided for an input tensor, it is inferred from the onnx model directly, with all unknown dims filled with 1.

Returns:

Dictionary of numpy tensors.

Return type:

dict[str, ndarray]
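
A stdlib-only sketch of the filling step (the real utility returns numpy ndarrays and can infer shapes from the model; this stand-in takes explicit shapes, fills unknown dims with 1 as described above, and returns flat Python lists; the function name is hypothetical):

```python
import random

def gen_random_inputs_sketch(input_shapes: dict[str, list[int]], seed: int = 0):
    rng = random.Random(seed)
    out = {}
    for name, shape in input_shapes.items():
        # Unknown (non-positive) dims default to 1, per the docstring above.
        resolved = [d if d > 0 else 1 for d in shape]
        n = 1
        for d in resolved:
            n *= d
        out[name] = {"shape": resolved, "data": [rng.random() for _ in range(n)]}
    return out
```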

get_all_input_names(model)

This function returns the input names of the given onnx model.

Parameters:

model (ModelProto)

Return type:

list[str]

get_attribute(node, attr_name)

Returns the value of the specified attribute.

Parameters:
  • node (NodeProto)

  • attr_name (str)

Return type:

Any

get_batch_size(model)

Returns the batch size of the given onnx model.

An assertion will fail if the batch size is not the same across all inputs.

Parameters:

model (ModelProto)

Return type:

int

get_batch_size_from_bytes(onnx_bytes)

Returns the batch size of the given onnx model.

An assertion will fail if the batch size is not the same across all inputs.

Parameters:

onnx_bytes (bytes)

Return type:

int

get_cast_to_type(cast_node)

Get the target type from a Cast node.

Parameters:

cast_node (NodeProto) – The Cast node to extract type from.

Returns:

The target type value from the Cast node’s ‘to’ attribute.

Return type:

int

Raises:

ValueError – If the Cast node does not have a ‘to’ attribute.

get_child_nodes(node)

Returns list of output consumer nodes for the given node.

Parameters:

node (Node)

Return type:

list[Node]

get_consumer_nodes(model, tensor_name)

Get all consumer nodes for a given tensor name.

Parameters:
  • model (ModelProto) – The ONNX model to search.

  • tensor_name (str) – Name of the tensor to find consumers for.

Returns:

List of nodes that consume the tensor.

Return type:

list[onnx.NodeProto]

get_dynamic_graph_inputs(onnx_model)

This function returns the dynamic inputs of an ONNX model.

Parameters:

onnx_model (ModelProto) – ONNX model to obtain dynamic inputs from.

Returns:

List of dynamic inputs.

get_input_names(model, external_inputs_only=True)

This function returns the external input names of the given onnx model.

Note: external_input_names = input_names - initializer_names

Parameters:
  • model (ModelProto) – Loaded in-memory onnx ModelProto.

  • external_inputs_only (bool)

Returns:

List of external input names of the model.

Return type:

list[str]

get_input_names_from_bytes(model_bytes, external_inputs_only=True)

This function returns the input names of the given onnx model in bytes.

Parameters:
  • model_bytes (bytes) – Onnx model in bytes.

  • external_inputs_only (bool)

Returns:

List of input names of the model.

Return type:

list[str]

get_input_shapes(model, external_inputs_only=True)

This function returns the input shapes for the given onnx model.

Parameters:
  • model (ModelProto)

  • external_inputs_only (bool)

Return type:

dict[str, list[int]]

get_input_shapes_from_bytes(model_bytes)

This function returns the input shapes of the given onnx model in bytes.

Parameters:

model_bytes (bytes) – Onnx model in bytes.

Returns:

Dictionary of input names and shapes.

Return type:

dict[str, list[int]]

get_min_opset_for_precisions(precisions)

Gets the minimum required opset version for a set of Q/DQ precision types.

Parameters:

precisions (set) – Set of precision type strings (e.g., ‘float8_e4m3fn’, ‘int4’).

Returns:

Minimum required opset version for the given precisions.

Return type:

int

get_node_names(model)

This function returns all node names from the given onnx model.

Parameters:

model (ModelProto) – Loaded in-memory onnx ModelProto.

Returns:

List of node names of the model.

Return type:

list[str]

get_node_names_from_bytes(model_bytes)

This function returns all node names from the given onnx model in bytes.

Parameters:

model_bytes (bytes) – Onnx model in bytes.

Returns:

List of node names of the model.

Return type:

list[str]

get_opset_version(model)

Returns the opset version of the given model.

Parameters:

model (ModelProto)

Return type:

int

get_output_names(model)

This function returns the output names of the given onnx model.

Parameters:

model (ModelProto) – Loaded in-memory onnx ModelProto.

Returns:

List of output names of the model.

Return type:

list[str]

get_output_names_from_bytes(model_bytes)

This function returns the output names of the given onnx model in bytes.

Parameters:

model_bytes (bytes) – Onnx model in bytes.

Returns:

List of output names of the model.

Return type:

list[str]

get_output_shapes(model)

This function returns the output shapes for the given onnx model.

Parameters:

model (ModelProto)

Return type:

dict[str, list[int]]

get_parent_nodes(node)

Returns list of input producer nodes for the given node.

Parameters:

node (Node)

Return type:

list[Node]

get_producer_nodes(model, tensor_name)

Get all producer nodes for a given tensor name.

Parameters:
  • model (ModelProto) – The ONNX model to search.

  • tensor_name (str) – Name of the tensor to find producers for.

Returns:

List of nodes that produce the tensor.

Return type:

list[onnx.NodeProto]
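
The producer/consumer lookups above amount to scanning each node's input and output lists for the tensor name. A sketch using plain dicts in place of NodeProto objects (the function names are hypothetical; the real utilities iterate model.graph.node):

```python
def get_consumer_nodes_sketch(nodes, tensor_name):
    # A node consumes the tensor if the name appears among its inputs.
    return [n for n in nodes if tensor_name in n["input"]]

def get_producer_nodes_sketch(nodes, tensor_name):
    # A node produces the tensor if the name appears among its outputs.
    return [n for n in nodes if tensor_name in n["output"]]
```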

get_qdq_precisions(model)

Gets the Q/DQ precision types present in the model.

Parameters:

model (ModelProto) – Loaded in-memory onnx ModelProto.

Returns:

Set of Q/DQ precision types present in the model (e.g., ‘float8_e4m3fn’, ‘int8’, ‘int4’, ‘float4_e2m1fn’).

Return type:

set

get_tensor_by_name(onnx_model, tensor_name)

This function returns a tensor from its name.

This function searches for a tensor in the model’s:

1. Value info (shape/type info, no data)
2. Initializers (TensorProto, contains actual data)
3. Inputs and outputs

Parameters:
  • onnx_model (ModelProto) – ONNX model.

  • tensor_name (str) – tensor name.

Returns:

tensor

Return type:

ValueInfoProto | TensorProto | None

get_variable_inputs(node)

Returns the variable inputs of the given Node.

Parameters:

node (Node)

Return type:

list[Variable]

has_attribute(node, attr_name)

Checks if the given node has the specified attribute.

Parameters:
  • node (NodeProto)

  • attr_name (str)

Return type:

bool

infer_shapes(model, **kwargs)

Infers shapes of the onnx graph, handles large models.

Parameters:

model (ModelProto)

infer_types(model, use_standalone_type_inference=False, **kwargs)

Infers types (and optionally shapes) based on the use_standalone_type_inference flag.

When use_standalone_type_inference is True, uses a standalone type inference implementation that only infers types. Otherwise, uses ONNX’s infer_shapes which infers both types and shapes.

Parameters:
  • model (ModelProto) – ONNX model to infer types/shapes for.

  • use_standalone_type_inference (bool) – If True, use standalone type inference (_infer_types_only). If False, use ONNX’s shape inference (infer_shapes).

  • **kwargs – Additional arguments passed to infer_shapes when not using standalone type inference.

Returns:

Model with inferred types (and shapes if not using standalone type inference).

Return type:

onnx.ModelProto

infer_types_verification(model)

Verify that all reachable tensors have a defined type.

This is necessary because some nodes may be removed during the inference process, leaving unreachable value_info entries.

Parameters:

model (ModelProto)

Return type:

ModelProto

name_onnx_nodes(graph)

Assigns names to onnx nodes that lack one and returns the modified status.

Parameters:

graph (GraphProto)

Return type:

bool

onnx_type_str_to_enum(dtype)

Converts ONNX type in string format to onnx.TensorProto format.

Example: ‘tensor(float16)’ becomes onnx.TensorProto.FLOAT16

Parameters:

dtype (str) – ONNX type in string format.

Returns:

ONNX type in enum format.

Return type:

int
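
A sketch of the string-to-enum conversion with a hand-written partial table (the function name is hypothetical; the listed values match onnx.TensorProto for this subset, and the real utility likely covers every element type):

```python
# Partial mapping of ONNX element-type names to TensorProto enum values.
_TYPE_ENUMS = {
    "float": 1, "int8": 3, "int32": 6, "int64": 7,
    "bool": 9, "float16": 10, "double": 11, "bfloat16": 16,
}

def onnx_type_str_to_enum_sketch(dtype: str) -> int:
    # Strip the 'tensor(...)' wrapper, then look up the element type.
    if not (dtype.startswith("tensor(") and dtype.endswith(")")):
        raise ValueError(f"unexpected type string: {dtype!r}")
    return _TYPE_ENUMS[dtype[len("tensor("):-1]]
```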

parse_shapes_spec(shapes_spec)

Parses a shapes spec and returns the shapes in a dictionary.

Example shapes spec: input0:1x3x256x256,input1:1x3x128x128

Parameters:

shapes_spec (str)

Return type:

dict[str, list[int]]
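
A minimal parser for the documented format (the function name is hypothetical, and the real utility may handle whitespace or quoting that this sketch does not):

```python
def parse_shapes_spec_sketch(shapes_spec: str) -> dict[str, list[int]]:
    shapes = {}
    for entry in shapes_spec.split(","):
        # Split on the last ':' so tensor names containing ':' still work.
        name, _, dims = entry.strip().rpartition(":")
        shapes[name] = [int(d) for d in dims.split("x")]
    return shapes
```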

randomize_weights(onnx_path)

Assigns random values to the onnx model weights.

Parameters:

onnx_path (str)

Return type:

None

randomize_weights_onnx_bytes(onnx_bytes, seed=0)

Assigns random values to the onnx model weights.

Parameters:
  • onnx_bytes (bytes)

  • seed (int)

Return type:

bytes

read_f16_tensor_as_fp32(tensor)

Reads a float16 or bfloat16 tensor as a float32 numpy ndarray.

remove_node_training_mode(onnx_model, node_op_type)

Remove training_mode attribute and extra training outputs from nodes of a given op type.

This also removes the unused outputs from the training_mode nodes.

Parameters:
  • onnx_model (ModelProto) – The onnx model.

  • node_op_type (str) – The node type to remove training_mode attribute from.

Returns:

The onnx model with the training_mode attribute removed.

Return type:

ModelProto

remove_redundant_casts(onnx_model)

Removes both sequential casts and casts that don’t change precision.

This method optimizes the graph by removing unnecessary cast operations that either:

1. Don’t actually change the data type
2. Could be replaced by a single cast operation
3. Can be folded into a preceding Constant node

Parameters:

onnx_model (ModelProto) – The ONNX model to optimize.

Returns:

Model with redundant casts removed.

Return type:

onnx.ModelProto
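
To illustrate cases 1 and 2 on a toy straight-line chain (not the library's graph rewrite; the function name and the chain representation are assumptions, and constant folding, case 3, needs real graph access so it is omitted): in a chain of Casts only the final target type matters, and a chain ending at the input dtype is a no-op.

```python
def simplify_cast_chain(input_dtype: str, casts: list[str]) -> list[str]:
    # Sequential casts collapse to the last one; a cast back to the
    # incoming dtype changes nothing and can be dropped entirely.
    final = casts[-1] if casts else input_dtype
    return [] if final == input_dtype else [final]
```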

remove_weights_data(onnx_bytes)

Removes raw weight data from the onnx model.

Parameters:

onnx_bytes (bytes)

Return type:

bytes

save_onnx(model, onnx_path, save_as_external_data=False)

Save an ONNX model to given path. If a model is larger than 2GB, will save with external data.

Parameters:
  • model (ModelProto)

  • onnx_path (str)

  • save_as_external_data (bool)

save_onnx_bytes_to_dir(onnx_bytes, onnx_dir, onnx_name)

Saves the onnx bytes to a directory with specified file name.

Parameters:
  • onnx_bytes (bytes)

  • onnx_dir (str)

  • onnx_name (str)

Return type:

None

update_domain(onnx_model, op_type, domain)

Updates the domain of all the nodes of the specified op_type to the specified domain.

Parameters:
  • onnx_model (ModelProto)

  • op_type (str)

  • domain (str)

Return type:

ModelProto

validate_batch_size(onnx_bytes, batch_size)

Returns True if all the model inputs have a batch dimension equal to batch_size.

Parameters:
  • onnx_bytes (bytes)

  • batch_size (int)

Return type:

bool
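
The check reduces to comparing dim 0 of every input shape. A sketch over a {name: shape} dict (as returned by get_input_shapes) rather than serialized model bytes, which the real utility parses internally; the function name is hypothetical, and scalar (empty-shape) inputs are treated as failing here:

```python
def validate_batch_size_sketch(
    input_shapes: dict[str, list[int]], batch_size: int
) -> bool:
    # Every input must have a leading dimension equal to batch_size.
    return all(shape and shape[0] == batch_size for shape in input_shapes.values())
```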

validate_onnx(onnx_bytes)

Returns True if the onnx_bytes is valid, else False.

Parameters:

onnx_bytes (bytes)

Return type:

bool