trtorch

Functions

trtorch. set_device ( gpu_id )
trtorch. compile ( module: torch.jit._script.ScriptModule , inputs=[] , device=None , disable_tf32=False , sparse_weights=False , enabled_precisions={} , refit=False , debug=False , strict_types=False , capability=<EngineCapability.default: 0> , num_min_timing_iters=2 , num_avg_timing_iters=1 , workspace_size=0 , max_batch_size=0 , calibrator=None , truncate_long_and_double=False , require_full_compilation=False , min_block_size=3 , torch_executed_ops=[] , torch_executed_modules=[] ) → torch.jit._script.ScriptModule

Compile a TorchScript module for NVIDIA GPUs using TensorRT

Takes an existing TorchScript module and a set of settings to configure the compiler, and converts methods to JIT graphs that call equivalent TensorRT engines

Converts specifically the forward method of a TorchScript Module

Parameters

module ( torch.jit.ScriptModule ) – Source module, a result of tracing or scripting a PyTorch torch.nn.Module

Keyword Arguments
  • inputs ( List [ Union ( trtorch.Input , torch.Tensor ) ] ) –

    List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or trtorch datatypes, and you can use either torch devices or the trtorch device type enum to select device type.

    inputs=[
        trtorch.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
        trtorch.Input(
            min_shape=(1, 224, 224, 3),
            opt_shape=(1, 512, 512, 3),
            max_shape=(1, 1024, 1024, 3),
            dtype=torch.int32,
            format=torch.channels_last
        ), # Dynamic input shape for input #2
        torch.randn((1, 3, 224, 244)) # Use an example tensor and let trtorch infer settings
    ]
    
  • device ( Union ( trtorch.Device , torch.device , dict ) ) –

    Target device for TensorRT engines to run on

    device=trtorch.Device("dla:1", allow_gpu_fallback=True)
    
  • disable_tf32 ( bool ) – Force FP32 layers to use the traditional FP32 format, instead of the default behavior of rounding the inputs to 10-bit mantissas before multiplying while accumulating the sum using 23-bit mantissas

  • sparse_weights ( bool ) – Enable sparsity for convolution and fully connected layers.

  • enabled_precisions ( Set ( Union ( torch.dtype , trtorch.dtype ) ) ) – The set of datatypes that TensorRT can use when selecting kernels

  • refit ( bool ) – Enable refitting

  • debug ( bool ) – Enable debuggable engine

  • strict_types ( bool ) – Kernels should strictly run in a particular operating precision. enabled_precisions should contain only one type when this is set

  • capability ( trtorch.EngineCapability ) – Restrict kernel selection to safe gpu kernels or safe dla kernels

  • num_min_timing_iters ( int ) – Number of minimization timing iterations used to select kernels

  • num_avg_timing_iters ( int ) – Number of averaging timing iterations used to select kernels

  • workspace_size ( int ) – Maximum size of workspace given to TensorRT

  • max_batch_size ( int ) – Maximum batch size (must be >= 1 to be set, 0 means not set)

  • truncate_long_and_double ( bool ) – Truncate weights provided in int64 or double (float64) to int32 and float32

  • calibrator ( Union ( trtorch._C.IInt8Calibrator , tensorrt.IInt8Calibrator ) ) – Calibrator object which will provide data to the PTQ system for INT8 Calibration

  • require_full_compilation ( bool ) – Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch

  • min_block_size ( int ) – The minimum number of contiguous TensorRT convertible operations required in order to run a set of operations in TensorRT

  • torch_executed_ops ( List [ str ] ) – List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but require_full_compilation is True

  • torch_executed_modules ( List [ str ] ) – List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but require_full_compilation is True

Returns

Compiled TorchScript Module, when run it will execute via TensorRT

Return type

torch.jit.ScriptModule
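The constraints stated in the keyword arguments above can be checked before calling the compiler. The following is a minimal sketch, not part of TRTorch: `build_compile_settings` is a hypothetical helper, and plain tuples and strings are used as stand-ins for `trtorch.Input` and `torch.dtype` objects so the block runs without TRTorch installed.

```python
from typing import Any, Dict, Iterable, List, Optional


def build_compile_settings(
    inputs: List[Any],
    enabled_precisions: Optional[Iterable[Any]] = None,
    strict_types: bool = False,
    require_full_compilation: bool = False,
    min_block_size: int = 3,
    torch_executed_ops: Optional[List[str]] = None,
    torch_executed_modules: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: gather keyword arguments and check documented constraints."""
    precisions = set(enabled_precisions or ())
    ops = list(torch_executed_ops or [])
    modules = list(torch_executed_modules or [])

    # The inputs argument is documented as required.
    if not inputs:
        raise ValueError("inputs is required and must not be empty")
    # strict_types is documented to expect a single enabled precision.
    if strict_types and len(precisions) > 1:
        raise ValueError("strict_types expects only one enabled precision")
    # Forcing ops or modules into PyTorch conflicts with full compilation.
    if require_full_compilation and (ops or modules):
        raise ValueError(
            "torch_executed_ops and torch_executed_modules must be empty "
            "when require_full_compilation is True"
        )

    return {
        "inputs": inputs,
        "enabled_precisions": precisions,
        "strict_types": strict_types,
        "require_full_compilation": require_full_compilation,
        "min_block_size": min_block_size,
        "torch_executed_ops": ops,
        "torch_executed_modules": modules,
    }


settings = build_compile_settings(
    inputs=[(1, 3, 224, 224)],       # stand-in for trtorch.Input((1, 3, 224, 224))
    enabled_precisions={"float16"},  # stand-in for torch.half
    torch_executed_ops=["aten::max_pool2d"],  # example operator kept in PyTorch
)
print(settings)
```

With TRTorch installed and the stand-in values replaced by real `trtorch.Input` and `torch.dtype` objects, a dictionary of this form could be passed as `trtorch.compile(module, **settings)`.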

trtorch. convert_method_to_trt_engine ( module: torch.jit._script.ScriptModule , method_name: str , inputs=[] , device=None , disable_tf32=False , sparse_weights=False , enabled_precisions={} , refit=False , debug=False , strict_types=False , capability=<EngineCapability.default: 0> , num_min_timing_iters=2 , num_avg_timing_iters=1 , workspace_size=0 , max_batch_size=0 , truncate_long_and_double=False , calibrator=None ) → str

Convert a TorchScript module method to a serialized TensorRT engine

Converts a specified method of a module to a serialized TensorRT engine given a set of conversion settings

Parameters
  • module ( torch.jit.ScriptModule ) – Source module, a result of tracing or scripting a PyTorch torch.nn.Module

  • method_name ( str ) – Name of method to convert

Keyword Arguments
  • inputs ( List [ Union ( trtorch.Input , torch.Tensor ) ] ) –

    List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or trtorch datatypes, and you can use either torch devices or the trtorch device type enum to select device type.

    inputs=[
        trtorch.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
        trtorch.Input(
            min_shape=(1, 224, 224, 3),
            opt_shape=(1, 512, 512, 3),
            max_shape=(1, 1024, 1024, 3),
            dtype=torch.int32,
            format=torch.channels_last
        ), # Dynamic input shape for input #2
        torch.randn((1, 3, 224, 244)) # Use an example tensor and let trtorch infer settings
    ]
    
  • device ( Union ( trtorch.Device , torch.device , dict ) ) –

    Target device for TensorRT engines to run on

    device=trtorch.Device("dla:1", allow_gpu_fallback=True)
    
  • disable_tf32 ( bool ) – Force FP32 layers to use the traditional FP32 format, instead of the default behavior of rounding the inputs to 10-bit mantissas before multiplying while accumulating the sum using 23-bit mantissas

  • sparse_weights ( bool ) – Enable sparsity for convolution and fully connected layers.

  • enabled_precisions ( Set ( Union ( torch.dtype , trtorch.dtype ) ) ) – The set of datatypes that TensorRT can use when selecting kernels

  • refit ( bool ) – Enable refitting

  • debug ( bool ) – Enable debuggable engine

  • strict_types ( bool ) – Kernels should strictly run in a particular operating precision. enabled_precisions should contain only one type when this is set

  • capability ( trtorch.EngineCapability ) – Restrict kernel selection to safe gpu kernels or safe dla kernels

  • num_min_timing_iters ( int ) – Number of minimization timing iterations used to select kernels

  • num_avg_timing_iters ( int ) – Number of averaging timing iterations used to select kernels

  • workspace_size ( int ) – Maximum size of workspace given to TensorRT

  • max_batch_size ( int ) – Maximum batch size (must be >= 1 to be set, 0 means not set)

  • truncate_long_and_double ( bool ) – Truncate weights provided in int64 or double (float64) to int32 and float32

  • calibrator ( Union ( trtorch._C.IInt8Calibrator , tensorrt.IInt8Calibrator ) ) – Calibrator object which will provide data to the PTQ system for INT8 Calibration

Returns

Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs

Return type

bytes
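Since the serialized engine is returned as bytes, saving it to a file is ordinary binary file I/O. The sketch below assumes nothing about TRTorch itself: `save_engine` and `load_engine` are hypothetical helper names, and the placeholder bytes stand in for the real return value of `trtorch.convert_method_to_trt_engine`.

```python
import tempfile
from pathlib import Path


def save_engine(serialized_engine: bytes, path: str) -> int:
    """Write a serialized engine to disk in binary mode; returns the byte count."""
    return Path(path).write_bytes(serialized_engine)


def load_engine(path: str) -> bytes:
    """Read a serialized engine back from disk as bytes."""
    return Path(path).read_bytes()


# Placeholder payload: NOT a real TensorRT engine, only used to show the round trip.
placeholder = b"\x00not-a-real-engine"

with tempfile.TemporaryDirectory() as tmp:
    engine_path = str(Path(tmp) / "model.engine")
    written = save_engine(placeholder, engine_path)
    restored = load_engine(engine_path)

print(written, restored == placeholder)
```

Bytes read back this way are in the form that `trtorch.embed_engine_in_new_module` documents for its `serialized_engine` parameter.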

trtorch. check_method_op_support ( module : torch.jit._script.ScriptModule , method_name : str ) → bool

Checks to see if a method is fully supported by TRTorch

Checks if a method of a TorchScript module can be compiled by TRTorch. If it cannot, a list of the unsupported operators is printed and the function returns False; otherwise it returns True.

Parameters
  • module ( torch.jit.ScriptModule ) – Source module, a result of tracing or scripting a PyTorch torch.nn.Module

  • method_name ( str ) – Name of method to check

Returns

True if the method is supported

Return type

bool
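One way to use the boolean result is to gate compilation on it. The following is a minimal sketch under stated assumptions: `compile_if_supported` is a hypothetical helper, and the checker and compiler are passed in as callables (stubbed here) so the block runs without TRTorch installed.

```python
from typing import Any, Callable


def compile_if_supported(
    module: Any,
    method_name: str,
    check_support: Callable[[Any, str], bool],
    compile_fn: Callable[[Any], Any],
) -> Any:
    """Return compile_fn(module) if the method is fully supported, else module unchanged."""
    if check_support(module, method_name):
        return compile_fn(module)
    return module


# Stubs standing in for trtorch.check_method_op_support and trtorch.compile.
def always_supported(module: Any, method_name: str) -> bool:
    return True


def never_supported(module: Any, method_name: str) -> bool:
    return False


def fake_compile(module: Any) -> str:
    return f"compiled({module})"


compiled = compile_if_supported("my_module", "forward", always_supported, fake_compile)
untouched = compile_if_supported("my_module", "forward", never_supported, fake_compile)
print(compiled, untouched)
```

With TRTorch installed, `check_support` could be `trtorch.check_method_op_support` and `compile_fn` a small function that calls `trtorch.compile` with the desired settings.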

trtorch. embed_engine_in_new_module ( serialized_engine : bytes , device = None ) → torch.jit._script.ScriptModule

Takes a pre-built serialized TensorRT engine and embeds it within a TorchScript module

Takes a pre-built serialized TensorRT engine (as bytes) and embeds it within a TorchScript module. Registers the forward method to execute the TensorRT engine with the function signature:

forward(Tensor[]) -> Tensor[]

The module can be saved with the engine embedded using torch.jit.save and moved / loaded according to TRTorch portability rules

Parameters

serialized_engine ( bytes ) – Serialized TensorRT engine from either TRTorch or TensorRT APIs

Keyword Arguments

device ( Union ( trtorch.Device , torch.device , dict ) ) – Target device to run engine on. Must be compatible with engine provided. Default: Current active device

Returns

New TorchScript module with engine embedded

Return type

torch.jit.ScriptModule

trtorch. get_build_info ( ) → str

Returns a string containing the build information of TRTorch distribution

Returns

String containing the build information for TRTorch distribution

Return type

str

trtorch. dump_build_info ( )

Prints build information about the TRTorch distribution to stdout

trtorch. TensorRTCompileSpec ( inputs=[] , device=None , disable_tf32=False , sparse_weights=False , enabled_precisions={} , refit=False , debug=False , strict_types=False , capability=<EngineCapability.default: 0> , num_min_timing_iters=2 , num_avg_timing_iters=1 , workspace_size=0 , max_batch_size=0 , truncate_long_and_double=False , calibrator=None ) → torch._C.ScriptClass

Utility to create a formatted spec dictionary for using the PyTorch TensorRT backend

Keyword Arguments
  • inputs ( List [ Union ( trtorch.Input , torch.Tensor ) ] ) –

    List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or trtorch datatypes, and you can use either torch devices or the trtorch device type enum to select device type.

    inputs=[
        trtorch.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
        trtorch.Input(
            min_shape=(1, 224, 224, 3),
            opt_shape=(1, 512, 512, 3),
            max_shape=(1, 1024, 1024, 3),
            dtype=torch.int32,
            format=torch.channels_last
        ), # Dynamic input shape for input #2
        torch.randn((1, 3, 224, 244)) # Use an example tensor and let trtorch infer settings
    ]
    
  • device ( Union ( trtorch.Device , torch.device , dict ) ) –

    Target device for TensorRT engines to run on

    device=trtorch.Device("dla:1", allow_gpu_fallback=True)
    
  • disable_tf32 ( bool ) – Force FP32 layers to use the traditional FP32 format, instead of the default behavior of rounding the inputs to 10-bit mantissas before multiplying while accumulating the sum using 23-bit mantissas

  • sparse_weights ( bool ) – Enable sparsity for convolution and fully connected layers.

  • enabled_precisions ( Set ( Union ( torch.dtype , trtorch.dtype ) ) ) – The set of datatypes that TensorRT can use when selecting kernels

  • refit ( bool ) – Enable refitting

  • debug ( bool ) – Enable debuggable engine

  • strict_types ( bool ) – Kernels should strictly run in a particular operating precision. enabled_precisions should contain only one type when this is set

  • capability ( trtorch.EngineCapability ) – Restrict kernel selection to safe gpu kernels or safe dla kernels

  • num_min_timing_iters ( int ) – Number of minimization timing iterations used to select kernels

  • num_avg_timing_iters ( int ) – Number of averaging timing iterations used to select kernels

  • workspace_size ( int ) – Maximum size of workspace given to TensorRT

  • max_batch_size ( int ) – Maximum batch size (must be >= 1 to be set, 0 means not set)

  • truncate_long_and_double ( bool ) – Truncate weights provided in int64 or double (float64) to int32 and float32

  • calibrator – Calibrator object which will provide data to the PTQ system for INT8 Calibration
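The disable_tf32 description above contrasts 10-bit and 23-bit mantissas. The sketch below illustrates what rounding an FP32 value to a 10-bit mantissa does to a product. It is an illustration only: `round_to_tf32` is a hypothetical helper, round-to-nearest-even is an assumed rounding mode, and NaN and infinity are not handled.

```python
import struct


def round_to_tf32(x: float) -> float:
    """Round a float32 value to a 10-bit mantissa, keeping the 8-bit exponent.

    Assumption: round-to-nearest-even. NaN and infinity are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    drop = 13  # float32 stores 23 mantissa bits; 23 - 10 = 13 bits are dropped
    lsb = (bits >> drop) & 1  # lowest bit that survives, used for ties-to-even
    bits = (bits + (1 << (drop - 1)) - 1 + lsb) & 0xFFFFFFFF
    bits &= ~((1 << drop) - 1) & 0xFFFFFFFF  # clear the dropped bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


x = 1.0 + 2.0 ** -12  # representable in FP32, but below 10-bit mantissa resolution
full = x * x                                  # product of the unrounded inputs
reduced = round_to_tf32(x) * round_to_tf32(x)  # inputs rounded before multiplying

print(full, reduced)
```

Here `reduced` is exactly 1.0 because the `2 ** -12` term is lost when the input is rounded, while `full` keeps it. Setting disable_tf32=True asks for the unrounded behavior.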

Classes

class trtorch. Input ( * args , ** kwargs )

Defines an input to a module in terms of expected shape, data type and tensor format.

__init__ ( * args , ** kwargs )

__init__ Method for trtorch.Input

Input accepts one of a few construction patterns

Parameters

shape ( Tuple or List , optional ) – Static shape of input tensor

Keyword Arguments
  • shape ( Tuple or List , optional ) – Static shape of input tensor

  • min_shape ( Tuple or List , optional ) – Min size of input tensor’s shape range. Note: all three of min_shape, opt_shape and max_shape must be provided, there must be no positional arguments, and shape must not be defined. Providing them implicitly sets the Input’s shape_mode to DYNAMIC

  • opt_shape ( Tuple or List , optional ) – Opt size of input tensor’s shape range. Note: all three of min_shape, opt_shape and max_shape must be provided, there must be no positional arguments, and shape must not be defined. Providing them implicitly sets the Input’s shape_mode to DYNAMIC

  • max_shape ( Tuple or List , optional ) – Max size of input tensor’s shape range. Note: all three of min_shape, opt_shape and max_shape must be provided, there must be no positional arguments, and shape must not be defined. Providing them implicitly sets the Input’s shape_mode to DYNAMIC

  • dtype ( torch.dtype or trtorch.dtype ) – Expected data type for input tensor (default: trtorch.dtype.float32)

  • format ( torch.memory_format or trtorch.TensorFormat ) – The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)

Examples

  • Input([1,3,32,32], dtype=torch.float32, format=torch.channels_last)

  • Input(shape=(1,3,32,32), dtype=trtorch.dtype.int32, format=trtorch.TensorFormat.NCHW)

  • Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=trtorch.dtype.float32, format=trtorch.TensorFormat.NCHW

dtype = <dtype.unknown: 5>

The expected data type of the input tensor (default: trtorch.dtype.float32)

format = <TensorFormat.contiguous: 0>

The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)

shape = None

Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form { "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }

Type

(Tuple or Dict)

shape_mode = None

Is input statically or dynamically shaped

Type

(trtorch.Input._ShapeMode)
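The construction rules for static and dynamic shapes can be sketched as a small resolver that returns either a single tuple or a dict of the documented form. This is not TRTorch's implementation: `resolve_shape` is a hypothetical helper, and the elementwise min <= opt <= max check is an added assumption rather than something stated on this page.

```python
from typing import Dict, Optional, Sequence, Tuple, Union

Shape = Tuple[int, ...]


def resolve_shape(
    shape: Optional[Sequence[int]] = None,
    min_shape: Optional[Sequence[int]] = None,
    opt_shape: Optional[Sequence[int]] = None,
    max_shape: Optional[Sequence[int]] = None,
) -> Union[Shape, Dict[str, Shape]]:
    """Hypothetical resolver: a tuple for static shapes, a dict for dynamic ones."""
    ranges = (min_shape, opt_shape, max_shape)

    if shape is not None:
        # Static mode: shape must not be combined with a shape range.
        if any(r is not None for r in ranges):
            raise ValueError("shape cannot be combined with min/opt/max shapes")
        return tuple(shape)

    # Dynamic mode: all three of min_shape, opt_shape, max_shape are needed.
    if any(r is None for r in ranges):
        raise ValueError("min_shape, opt_shape and max_shape must all be provided")

    lo, opt, hi = (tuple(r) for r in ranges)
    if not (len(lo) == len(opt) == len(hi)):
        raise ValueError("min_shape, opt_shape and max_shape must have the same rank")
    # Assumption (not from this page): each dimension satisfies min <= opt <= max.
    if any(not (a <= b <= c) for a, b, c in zip(lo, opt, hi)):
        raise ValueError("expected min_shape <= opt_shape <= max_shape per dimension")

    return {"min_shape": lo, "opt_shape": opt, "max_shape": hi}


static = resolve_shape(shape=[1, 3, 32, 32])
dynamic = resolve_shape(
    min_shape=(1, 3, 32, 32), opt_shape=[2, 3, 32, 32], max_shape=(3, 3, 32, 32)
)
print(static, dynamic)
```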

class trtorch. Device ( * args , ** kwargs )

Defines a device that can be used to specify target devices for engines

__init__ ( * args , ** kwargs )

__init__ Method for trtorch.Device

Device accepts one of a few construction patterns

Parameters

spec ( str ) – String with device spec, e.g. "dla:0" for DLA, core_id 0

Keyword Arguments
  • gpu_id ( int ) – ID of target GPU (will be overridden with the GPU managing the DLA if dla_core is specified). If specified, no positional arguments should be provided

  • dla_core ( int ) – ID of target DLA core. If specified, no positional arguments should be provided.

  • allow_gpu_fallback ( bool ) – Allow TensorRT to schedule operations on GPU if they are not supported on DLA (ignored if device type is not DLA)

Examples

  • Device("gpu:1")

  • Device("cuda:1")

  • Device("dla:0", allow_gpu_fallback=True)

  • Device(gpu_id=0, dla_core=0, allow_gpu_fallback=True)

  • Device(dla_core=0, allow_gpu_fallback=True)

  • Device(gpu_id=1)

allow_gpu_fallback = False

(bool) Whether to allow falling back to the GPU when DLA cannot support an op

device_type = None

Target device type (GPU or DLA). Set implicitly based on whether dla_core is specified.

Type

( trtorch.DeviceType )

dla_core = -1

(int) Core ID for target DLA core

gpu_id = -1

(int) Device ID for target GPU
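The string forms shown in the examples above ("gpu:1", "cuda:1", "dla:0") can be read as a device kind followed by an index. The sketch below is a hypothetical parser written for illustration, not the one TRTorch uses. It reuses the documented defaults of -1 for gpu_id and dla_core, and for DLA specs it leaves gpu_id at -1 because this page does not say which GPU manages a given DLA core.

```python
from typing import NamedTuple


class ParsedDevice(NamedTuple):
    device_type: str  # "GPU" or "DLA"
    gpu_id: int       # -1 when not specified
    dla_core: int     # -1 when not specified


def parse_device_spec(spec: str) -> ParsedDevice:
    """Hypothetical parser for strings such as "gpu:1", "cuda:1" or "dla:0"."""
    kind, sep, index = spec.strip().lower().partition(":")
    if not sep or not index.isdigit():
        raise ValueError(f"expected '<kind>:<index>', got {spec!r}")
    idx = int(index)

    if kind in ("gpu", "cuda"):
        return ParsedDevice("GPU", gpu_id=idx, dla_core=-1)
    if kind == "dla":
        # The device type follows from a DLA core being specified.
        return ParsedDevice("DLA", gpu_id=-1, dla_core=idx)
    raise ValueError(f"unknown device kind {kind!r}")


print(parse_device_spec("cuda:1"))
print(parse_device_spec("dla:0"))
```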

Enums

class trtorch. dtype

Enum to specify operating precision for engine execution

Members:

float : 32 bit floating point number

float32 : 32 bit floating point number

half : 16 bit floating point number

float16 : 16 bit floating point number

int8 : 8 bit integer number

int32 : 32 bit integer number

bool : Boolean value

unknown : Unknown data type

class trtorch. DeviceType

Enum to specify device kinds to build TensorRT engines for

Members:

GPU : Specify using GPU to execute TensorRT Engine

DLA : Specify using DLA to execute TensorRT Engine (Jetson Only)

class trtorch. EngineCapability

Enum to specify engine capability settings (selections of kernels to meet safety requirements)

Members:

safe_gpu : Use safety GPU kernels only

safe_dla : Use safety DLA kernels only

default : Use default behavior

class trtorch. TensorFormat

Enum to specify the memory layout of tensors

Members:

contiguous : Contiguous memory layout (NCHW / Linear)

channel_last : Channel last memory layout (NHWC)

Submodules