Custom Operations

Plugins allow you to extend TensorRT with custom operations.

  • The quickly deployable plugin (QDP) framework is the easiest way to write plugins.

Implementing The Plugin

In this guide, we’ll implement a plugin that increments a tensor by 1.

See also

TensorRT’s guide on QDPs includes more details on implementing plugins.

We must:

  1. Register the interface for the plugin.

  2. Implement the plugin kernel.

  3. Generate PTX.

Registering The Plugin Interface

trtp.register decorates a function that defines the plugin interface:

import tensorrt.plugin as trtp

# Plugin IDs are of the form: "<namespace>::<name>" and
# uniquely identify a plugin.
INCREMENT_PLUGIN_ID = "example::increment"


@trtp.register(INCREMENT_PLUGIN_ID)
def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> trtp.TensorDesc:
    """
    Defines the plugin interface - inputs, outputs and attributes.

    Args:
        inp0: Input tensor descriptor
        block_size: Block size for the Triton kernel

    Returns:
        Output tensor descriptor with same shape/dtype as input
    """
    return inp0.like()
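
The same pattern extends to plugins with multiple inputs: each tensor input becomes a trtp.TensorDesc parameter. As a sketch (the example::elementwise_add ID and add_desc function below are hypothetical, not part of this guide's example):

import tensorrt.plugin as trtp

# Hypothetical two-input plugin, for illustration only.
ADD_PLUGIN_ID = "example::elementwise_add"


@trtp.register(ADD_PLUGIN_ID)
def add_desc(inp0: trtp.TensorDesc, inp1: trtp.TensorDesc) -> trtp.TensorDesc:
    # The output shares the first input's shape and dtype.
    return inp0.like()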

Implementing The Kernel

For this example, we use OpenAI’s Triton language to implement the kernel:

import triton
import triton.language as tl


@triton.jit
def increment(x_ptr, num_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one block of BLOCK_SIZE elements.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask out-of-bounds accesses in the final block.
    mask = offsets < num_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x + 1, mask=mask)
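
Before wiring the kernel into TensorRT, it can be sanity-checked with a direct launch. A minimal sketch, assuming PyTorch and a CUDA device are available (not part of the plugin itself):

import torch

# Launch the kernel over one block and compare against the expected result.
x = torch.arange(8, dtype=torch.float32, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
increment[grid](x, x.numel(), y, BLOCK_SIZE=256)
assert torch.equal(y, x + 1)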

Note

Kernels can be written in many other ways (e.g., CUDA, CUTLASS, or Numba) as long as we can emit PTX.
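
For instance, Numba can emit PTX from a plain Python kernel. A rough sketch, assuming Numba's cuda.compile_ptx API (the increment_numba function is hypothetical, and exact signatures may vary between Numba versions):

from numba import cuda, float32, int32


def increment_numba(x, n, y):
    # One thread per element; guard against out-of-bounds threads.
    i = cuda.grid(1)
    if i < n:
        y[i] = x[i] + 1.0


# compile_ptx returns the PTX source and the inferred return type.
ptx, _ = cuda.compile_ptx(increment_numba, (float32[:], int32, float32[:]))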

Retrieving PTX

trtp.aot_impl decorates a function that retrieves PTX, launch parameters, and any extra scalar arguments:

from typing import Tuple, Union

import tensorrt.plugin as trtp
import triton


@trtp.aot_impl(INCREMENT_PLUGIN_ID)
def increment_aot_impl(
    inp0: trtp.TensorDesc,
    block_size: int,
    outputs: Tuple[trtp.TensorDesc],
    tactic: int,
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    src = triton.compiler.ASTSource(
        fn=increment,
        signature="*fp32,i32,*fp32",
        constants={
            "BLOCK_SIZE": block_size,
        },
    )

    compiled_kernel = triton.compile(src)

    # Set the grid, block dims and shared memory for the
    # kernel (as symbolic expressions)
    launch_params = trtp.KernelLaunchParams()
    num_elements = inp0.shape_expr.numel()
    launch_params.grid_x = trtp.cdiv(num_elements, block_size)
    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
    launch_params.shared_mem = compiled_kernel.metadata.shared

    # Define extra scalar arguments for the
    # kernel (as symbolic expressions)
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(num_elements)

    return (
        compiled_kernel.metadata.name,
        compiled_kernel.asm["ptx"],
        launch_params,
        extra_args,
    )
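
The returned tuple holds the compiled kernel's name, the PTX itself, the launch parameters, and the extra scalar arguments (here, the element count) that TensorRT passes to the kernel at runtime alongside the tensor pointers. The tactic argument goes unused in this example; it exists so that plugins registering multiple kernel implementations can select between them.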

Using The Plugin

We can use the plugin with nvtripy.plugin():

import nvtripy as tp

inp = tp.iota((2, 2))
# Plugin attributes are passed as keyword arguments and must match
# the attributes specified by the registration function.
out = tp.plugin(INCREMENT_PLUGIN_ID, [inp], block_size=256)

Output:
>>> inp
tensor(
    [[0, 0],
     [1, 1]],
    dtype=float32, loc=gpu:0, shape=(2, 2))

>>> out
tensor(
    [[1, 1],
     [2, 2]],
    dtype=float32, loc=gpu:0, shape=(2, 2))
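
To double-check the result against PyTorch, the tensors can be exchanged via DLPack. A small sketch, assuming nvtripy tensors support torch.from_dlpack (illustrative only):

import torch

# Import both tensors into PyTorch and verify the increment.
torch_inp = torch.from_dlpack(inp)
torch_out = torch.from_dlpack(out)
assert torch.equal(torch_out, torch_inp + 1)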