precisionconverter
Precision conversion module for ONNX models.
This module provides functionality for converting ONNX models between different floating point precisions, specifically handling conversions between FP32 and lower precisions like FP16 or BF16. It handles the insertion of cast operations, conversion of initializers, and ensures model validity through type checking and cleanup of redundant operations.
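The following stdlib-only sketch shows what a single FP32 value loses under each target format. It rounds a Python float to FP16 with `struct`'s half-precision `'e'` format, and to BF16 by round-to-nearest-even on the upper 16 bits of the FP32 encoding. The helpers `to_fp16` and `to_bf16` are hypothetical illustrations, not part of this module, and NaN handling is omitted.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE 754 half precision and back; raises OverflowError above the FP16 range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Round x to bfloat16 (round-to-nearest-even) and back. NaN inputs are not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # FP32 bit pattern
    # Add 0x7FFF plus the lowest kept bit so that ties round to an even mantissa.
    rounded = (bits + 0x7FFF + ((bits >> 16) & 1)) >> 16
    return struct.unpack("<f", struct.pack("<I", (rounded & 0xFFFF) << 16))[0]


print(to_fp16(0.1))         # FP16 keeps roughly 3 significant decimal digits
print(to_bf16(3.14159265))  # BF16 keeps only 8 significant bits
print(to_bf16(70000.0))     # BF16 keeps FP32's exponent range
try:
    to_fp16(70000.0)
except OverflowError:
    print("70000.0 does not fit in FP16 (max finite value 65504)")
```

BF16 preserves FP32's exponent range at the cost of mantissa bits, while FP16 keeps more mantissa bits but overflows above 65504.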
Classes
PrecisionConverter – Precision conversion module for ONNX models.
PrecisionTypes – PrecisionTypes(onnx_type, numpy_type, str_short, str_full)
- class PrecisionConverter
Bases: object
Precision conversion module for ONNX models.
This class converts ONNX models between floating-point precisions, specifically between FP32 and lower precisions such as FP16 or BF16. It inserts cast operations, converts initializers, and ensures model validity.
- Public Methods:
convert: Convert specified nodes to FP16/BF16 precision while keeping others in FP32.
- __init__(model, value_info_map, initializer_map, node_to_init_map, keep_io_types=False, low_precision_type='fp16', init_conversion_max_bytes=inf)
Initialize PrecisionConverter.
- Parameters:
model (ModelProto) – ONNX model to convert.
value_info_map (dict[str, ValueInfoProto]) – Map of tensor names to value info.
initializer_map (dict[str, TensorProto]) – Map of tensor names to initializers.
node_to_init_map (dict[str, list[str]]) – Map of node names to lists of initializer names.
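keep_io_types (bool) – If True, keep the model's input and output types unchanged; otherwise they are converted to the low-precision type.
low_precision_type (str) – Low-precision type to convert to (FP16 or BF16). Defaults to 'fp16'.
init_conversion_max_bytes (int) – Maximum size, in bytes, of an initializer that is converted ahead of time. Larger initializers are left unconverted and cast at runtime. The default (inf) converts all initializers.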
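The effect of init_conversion_max_bytes can be pictured as a simple size threshold. The sketch below is an illustration, not the module's implementation: the helper `split_initializers` is hypothetical, and it assumes that initializers are FP32 (4 bytes per element) and that an initializer exactly at the limit is still converted ahead of time.

```python
from __future__ import annotations

import math


def split_initializers(shapes: dict[str, tuple[int, ...]],
                       max_bytes: float = math.inf,
                       itemsize: int = 4) -> tuple[list[str], list[str]]:
    """Partition initializers into (convert ahead of time, cast at runtime).

    Hypothetical helper: assumes FP32 storage (itemsize=4) and a '<=' comparison;
    the module's exact rule may differ.
    """
    convert_offline, cast_at_runtime = [], []
    for name, shape in shapes.items():
        nbytes = math.prod(shape) * itemsize  # math.prod(()) == 1, so scalars count as one element
        if nbytes <= max_bytes:
            convert_offline.append(name)
        else:
            cast_at_runtime.append(name)
    return convert_offline, cast_at_runtime


shapes = {"bias": (64,), "weight": (1024, 1024)}
print(split_initializers(shapes))                  # default threshold: everything is converted
print(split_initializers(shapes, max_bytes=1024))  # the 4 MiB weight is cast at runtime instead
```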
- Return type:
None
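The three maps passed to the constructor can be derived from the model's graph. The sketch below shows one plausible way to build them; `build_maps` is a hypothetical helper (the package may provide its own utilities), and it relies only on the `value_info`, `input`, `output`, `initializer`, and `node` fields that `onnx.GraphProto` exposes, so it also works on any object with the same attributes.

```python
from __future__ import annotations

from typing import Any


def build_maps(graph: Any) -> tuple[dict[str, Any], dict[str, Any], dict[str, list[str]]]:
    """Return (value_info_map, initializer_map, node_to_init_map) for a GraphProto-like object.

    Hypothetical helper. Assumes node names are unique and non-empty, since they
    are used as keys of node_to_init_map.
    """
    # Tensor name -> ValueInfoProto for intermediates, graph inputs, and graph outputs.
    value_info_map = {}
    for vi in list(graph.value_info) + list(graph.input) + list(graph.output):
        value_info_map[vi.name] = vi

    # Tensor name -> TensorProto for every initializer (weights, constants).
    initializer_map = {init.name: init for init in graph.initializer}

    # Node name -> names of the initializers that node consumes, in input order.
    node_to_init_map = {
        node.name: [name for name in node.input if name in initializer_map]
        for node in graph.node
    }
    return value_info_map, initializer_map, node_to_init_map
```

With onnx installed, one might run `model = onnx.shape_inference.infer_shapes(onnx.load("model.onnx"))` first, so that `value_info` is populated, and then pass the maps from `build_maps(model.graph)` to the constructor together with the model.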
- convert(high_precision_nodes, low_precision_nodes)
Convert model to mixed precision.
- Parameters:
high_precision_nodes (list[str]) – List of node names to keep in high precision.
low_precision_nodes (list[str]) – List of node names to convert to low precision.
- Returns:
The converted mixed-precision model.
- Return type:
onnx.ModelProto
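Conceptually, a mixed-precision conversion needs a cast wherever a tensor flows between a high-precision node and a low-precision node. The sketch below shows only that boundary test, on a hypothetical edge list of `(tensor, producer, consumer)` triples. Graph inputs and outputs are modeled as high-precision pseudo-nodes, which corresponds to keep_io_types=True. This is not how the module represents the graph, and it ignores initializers and the cleanup of redundant casts.

```python
from __future__ import annotations


def find_cast_edges(edges: list[tuple[str, str, str]],
                    low_precision_nodes: set[str]) -> list[tuple[str, str, str]]:
    """Return the (tensor, producer, consumer) edges that cross a precision boundary.

    Hypothetical helper: an edge needs a Cast exactly when one endpoint is in
    low_precision_nodes and the other is not; the Cast targets the consumer's precision.
    """
    return [
        (tensor, src, dst)
        for tensor, src, dst in edges
        if (src in low_precision_nodes) != (dst in low_precision_nodes)
    ]


edges = [
    ("x", "<graph_input>", "conv"),
    ("h", "conv", "relu"),
    ("r", "relu", "softmax"),
    ("y", "softmax", "<graph_output>"),
]
print(find_cast_edges(edges, low_precision_nodes={"conv", "relu"}))
```

Here "conv" and "relu" play the role of low_precision_nodes and "softmax" of a node listed in high_precision_nodes, so casts are needed only on "x" (into low precision) and "r" (back to high precision).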
- class PrecisionTypes
Bases: tuple
PrecisionTypes(onnx_type, numpy_type, str_short, str_full)
- static __new__(_cls, onnx_type, numpy_type, str_short, str_full)
Create new instance of PrecisionTypes(onnx_type, numpy_type, str_short, str_full)
- numpy_type
Alias for field number 1
- onnx_type
Alias for field number 0
- str_full
Alias for field number 3
- str_short
Alias for field number 2
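PrecisionTypes behaves like any `collections.namedtuple`: each field can be read by name or by the field number listed above. The dependency-free sketch below recreates the same field layout to show this. The values are illustrative assumptions, not the module's constants: onnx_type uses the integer codes of `onnx.TensorProto` (FLOAT is 1, FLOAT16 is 10), numpy_type holds a dtype name string as a stand-in for the actual numpy type, and the short and full names are hypothetical.

```python
from collections import namedtuple

# Same field layout as the documented PrecisionTypes namedtuple.
PrecisionTypes = namedtuple("PrecisionTypes", ["onnx_type", "numpy_type", "str_short", "str_full"])

# Illustrative values only (see the assumptions above).
fp32 = PrecisionTypes(1, "float32", "fp32", "float32")
fp16 = PrecisionTypes(10, "float16", "fp16", "float16")

# Field access by name and by index are equivalent.
print(fp16.onnx_type, fp16[0])
print(fp16.str_short, fp16[2])
```

Because an instance is a tuple, it also unpacks positionally: `onnx_type, numpy_type, str_short, str_full = fp16`.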