utils

Utility functions for AutoCast.

This module provides common utility functions used across the AutoCast package. It includes functions for graph traversal, tensor type inference, model validation, and mapping setup between nodes, initializers, and value info. These utilities support the core functionality of model precision conversion.

Functions

clear_types_and_shapes_recursive

Recursively clear type/shape information for a graph and all its subgraphs.

get_op_types_not_supported_in_low_precision

Get the list of ops not supported in low precision at opset_version = max(model.opset, min_opset).

get_unique_consumer_node

Get the single consumer node of a tensor, raising an exception if there is not exactly one consumer.

setup_mappings

Set up and return mappings for model components.

walk_subgraphs_recursive

Recursively walk through a graph and all its subgraphs, applying a callback.

clear_types_and_shapes_recursive(graph, clear_shapes=True, is_subgraph=False)

Recursively clear type/shape information for a graph and all its subgraphs.

Resets intermediate (value_info) and output tensor types to UNDEFINED and, when clear_shapes is True, replaces concrete dims with a symbolic "unk" so a subsequent modelopt.onnx.utils.infer_types() re-derives them from the operator graph. For subgraphs, input types/shapes are cleared too so they propagate from the parent graph. This does not change tensor rank, so it cannot repair a stale rank (see _reconcile_stale_output_shapes).

Parameters:
  • graph (GraphProto) – The ONNX graph to clear types and shapes for.

  • clear_shapes (bool) – If True, also clear shapes (False keeps shapes for type-only inference).

  • is_subgraph (bool) – Whether this is a subgraph (True) or the main graph (False).

Return type:

None
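The behaviour described above can be sketched on a toy model. This is not the ONNX protobuf layout: `tensor`, the flat `dims` list, and the `subgraphs` attribute are hypothetical stand-ins (real subgraphs live in node attributes), and `clear_types_and_shapes_sketch` is an illustrative name rather than the library function. The sketch only shows which tensors are reset and that the rank is left alone.

```python
from types import SimpleNamespace

UNDEFINED = 0  # same numeric value as onnx.TensorProto.UNDEFINED
FLOAT = 1      # same numeric value as onnx.TensorProto.FLOAT


def tensor(name, elem_type, dims):
    """Toy stand-in for a value info: a name, an element type, and a dim list."""
    return SimpleNamespace(name=name, elem_type=elem_type, dims=list(dims))


def clear_types_and_shapes_sketch(graph, clear_shapes=True, is_subgraph=False):
    """Reset types (and optionally dims) in place, then recurse into toy subgraphs."""
    targets = list(graph.value_info) + list(graph.output)
    if is_subgraph:
        # Subgraph inputs are cleared too so they can propagate from the parent.
        targets += list(graph.input)
    for t in targets:
        t.elem_type = UNDEFINED
        if clear_shapes:
            # One symbolic dim per original dim: the rank is not changed.
            t.dims = ["unk"] * len(t.dims)
    for sub in graph.subgraphs:  # hypothetical attribute of the toy model
        clear_types_and_shapes_sketch(sub, clear_shapes, is_subgraph=True)


body = SimpleNamespace(
    input=[tensor("i", FLOAT, [4])],
    value_info=[],
    output=[tensor("o", FLOAT, [4])],
    subgraphs=[],
)
main = SimpleNamespace(
    input=[tensor("x", FLOAT, [1, 3])],
    value_info=[tensor("h", FLOAT, [1, 3])],
    output=[tensor("y", FLOAT, [1, 3])],
    subgraphs=[body],
)
clear_types_and_shapes_sketch(main)
```

After the call, the main graph input `x` keeps its type and shape, while `h`, `y`, and every tensor of `body` (including its input) are reset. A two-dimensional tensor still has two dims afterwards, which is why a stale rank survives this step.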

get_op_types_not_supported_in_low_precision(model, min_opset, low_precision_type='float16')

Get the list of ops not supported in low precision at opset_version = max(model.opset, min_opset).

An op counts as supported if at least one of its inputs may be in low precision. Ops where only some of the inputs may be in low precision are therefore reported as supported by this function and may need special handling; see PrecisionConverter::_should_skip_low_precision_input_conversion.

Parameters:
  • model (ModelProto) – ONNX model.

  • min_opset (int) – Minimum opset version.

  • low_precision_type (str) – Target precision to reduce to ('float16' or 'bfloat16').

Returns:

ops_without_support – List of ops not supported in low precision for the effective opset version, max(model.opset, min_opset).
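The opset rule in the summary, max(model.opset, min_opset), can be illustrated with a small sketch. It assumes the model opset is the version of the default ONNX domain entry in `opset_import` (domain `""` or `"ai.onnx"`); `effective_opset` is a hypothetical helper, and the `SimpleNamespace` objects are stand-ins that expose the same `opset_import`, `domain`, and `version` fields as a ModelProto.

```python
from types import SimpleNamespace


def effective_opset(model, min_opset):
    """Return max(model opset, min_opset) for the default ONNX domain."""
    model_opset = max(
        (imp.version for imp in model.opset_import if imp.domain in ("", "ai.onnx")),
        default=0,  # assumption: treat a missing default-domain entry as opset 0
    )
    return max(model_opset, min_opset)


old_model = SimpleNamespace(opset_import=[SimpleNamespace(domain="", version=11)])
new_model = SimpleNamespace(
    opset_import=[
        SimpleNamespace(domain="", version=19),
        SimpleNamespace(domain="com.microsoft", version=1),  # custom domain, ignored
    ]
)

print(effective_opset(old_model, min_opset=13))  # 13
print(effective_opset(new_model, min_opset=13))  # 19
```

A model exported at an older opset is thus evaluated against the minimum opset, while a newer model keeps its own version.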

get_unique_consumer_node(model, tensor_name)

Get the single consumer node of a tensor, raising an exception if there is not exactly one consumer.

Parameters:
  • model (ModelProto) – The ONNX model to search.

  • tensor_name (str) – Name of the tensor to find consumer for.

Returns:

The single consumer node.

Return type:

onnx.NodeProto

Raises:

Exception – If there is not exactly one consumer node.
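A minimal sketch of the lookup, assuming a consumer is any top-level node that lists the tensor among its inputs. Whether the library function also searches subgraphs is not stated here. `get_unique_consumer_sketch` is an illustrative name, and the `SimpleNamespace` objects stand in for ModelProto and NodeProto, which expose the same `graph.node` and `node.input` fields.

```python
from types import SimpleNamespace


def get_unique_consumer_sketch(model, tensor_name):
    """Return the only node that consumes tensor_name, or raise."""
    consumers = [n for n in model.graph.node if tensor_name in n.input]
    if len(consumers) != 1:
        raise Exception(
            f"Expected exactly one consumer of {tensor_name!r}, found {len(consumers)}"
        )
    return consumers[0]


def node(name, inputs, outputs):
    """Stand-in for a NodeProto with only the fields the sketch reads."""
    return SimpleNamespace(name=name, input=list(inputs), output=list(outputs))


model = SimpleNamespace(
    graph=SimpleNamespace(
        node=[
            node("relu", ["x"], ["a"]),
            node("sigmoid", ["a"], ["b"]),
            node("tanh", ["a"], ["c"]),
        ]
    )
)

print(get_unique_consumer_sketch(model, "x").name)  # relu

try:
    get_unique_consumer_sketch(model, "a")  # consumed by both sigmoid and tanh
except Exception as exc:
    print(exc)
```

Tensors with no consumer at all, such as `b` here, raise as well, matching the "not exactly one" condition.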

setup_mappings(model)

Set up and return mappings for model components.

Parameters:

model (ModelProto) – ONNX model to create mappings for.

Returns:

Tuple containing:

  • value_info_map: Mapping of names to value infos.

  • initializer_map: Mapping of names to initializers.

  • node_to_init_map: Mapping of node names to their initializer inputs.

Return type:

tuple
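The three mappings can be sketched as plain dictionaries. This is an illustration under assumptions: `setup_mappings_sketch` is a hypothetical name, only `graph.value_info` is indexed (the library may also cover graph inputs and outputs), and `node_to_init_map` is assumed to hold the initializer objects rather than their names. The `SimpleNamespace` stand-ins mirror the `graph.value_info`, `graph.initializer`, and `graph.node` fields of a ModelProto.

```python
from types import SimpleNamespace


def setup_mappings_sketch(model):
    """Build name-keyed lookups for value infos, initializers, and node weights."""
    graph = model.graph
    value_info_map = {vi.name: vi for vi in graph.value_info}
    initializer_map = {init.name: init for init in graph.initializer}
    node_to_init_map = {
        node.name: [initializer_map[i] for i in node.input if i in initializer_map]
        for node in graph.node
    }
    return value_info_map, initializer_map, node_to_init_map


named = lambda name: SimpleNamespace(name=name)  # stand-in for any named proto
model = SimpleNamespace(
    graph=SimpleNamespace(
        value_info=[named("hidden")],
        initializer=[named("weight"), named("bias")],
        node=[
            SimpleNamespace(name="gemm", input=["x", "weight", "bias"], output=["hidden"]),
            SimpleNamespace(name="relu", input=["hidden"], output=["y"]),
        ],
    )
)

value_info_map, initializer_map, node_to_init_map = setup_mappings_sketch(model)
print(sorted(initializer_map))                      # ['bias', 'weight']
print([i.name for i in node_to_init_map["gemm"]])   # ['weight', 'bias']
print(node_to_init_map["relu"])                     # []
```

Building the maps once avoids repeated linear scans of the graph when later passes look up tensors by name.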

walk_subgraphs_recursive(graph, callback, parent_node=None, is_subgraph=False)

Recursively walk through a graph and all its subgraphs, applying a callback.

This utility function traverses an ONNX graph and all nested subgraphs by examining graph attributes in nodes. It works with standard control flow operators (Scan, If, Loop) as well as custom operators that define subgraphs using ONNX graph attributes.

Parameters:
  • graph (GraphProto) – The graph to walk.

  • callback (Callable) – Function to call for each graph. Signature: callback(graph, parent_node, is_subgraph).

  • parent_node (NodeProto) – The parent node containing this subgraph (None for main graph).

  • is_subgraph (bool) – Whether this is a subgraph (True) or the main graph (False).

Return type:

None

Note

Works with any node that has attributes of type AttributeProto.GRAPH or AttributeProto.GRAPHS, including custom operators.
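The traversal can be sketched as follows. `walk_subgraphs_sketch` is an illustrative name, the pre-order visiting order is an assumption, and the `SimpleNamespace` objects are stand-ins for GraphProto, NodeProto, and AttributeProto (which expose `type`, `g`, and `graphs`). The constants 5 and 10 are the numeric values of AttributeProto.GRAPH and AttributeProto.GRAPHS.

```python
from types import SimpleNamespace

GRAPH, GRAPHS = 5, 10  # AttributeProto.GRAPH and AttributeProto.GRAPHS


def walk_subgraphs_sketch(graph, callback, parent_node=None, is_subgraph=False):
    """Call callback on graph, then on every graph nested in node attributes."""
    callback(graph, parent_node, is_subgraph)
    for node in graph.node:
        for attr in node.attribute:
            if attr.type == GRAPH:
                walk_subgraphs_sketch(attr.g, callback, node, True)
            elif attr.type == GRAPHS:
                for sub in attr.graphs:
                    walk_subgraphs_sketch(sub, callback, node, True)


def make_graph(name, nodes=()):
    """Stand-in for a GraphProto with only a name and a node list."""
    return SimpleNamespace(name=name, node=list(nodes))


if_node = SimpleNamespace(
    name="if_0",
    op_type="If",
    attribute=[
        SimpleNamespace(type=GRAPH, g=make_graph("then"), graphs=[]),
        SimpleNamespace(type=GRAPH, g=make_graph("else"), graphs=[]),
    ],
)
main = make_graph("main", [if_node])

visited = []
walk_subgraphs_sketch(
    main,
    lambda g, parent, is_sub: visited.append(
        (g.name, parent.name if parent else None, is_sub)
    ),
)
print(visited)
# [('main', None, False), ('then', 'if_0', True), ('else', 'if_0', True)]
```

Because the check is on the attribute type rather than the operator name, a custom operator carrying graph attributes is traversed the same way as If, Loop, or Scan.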