Custom Operations

Plugins allow you to extend TensorRT with custom operations.

  • The quickly deployable plugin (QDP) framework is the easiest way to write plugins.

Implementing The Plugin

In this guide, we’ll implement a plugin that increments a tensor by 1.

See also

TensorRT’s guide on QDPs includes more details on implementing plugins.

We must:

  1. Register the interface for the plugin.

  2. Implement the plugin kernel.

  3. Retrieve PTX.

Registering The Plugin Interface

trtp.register decorates a function that defines the plugin interface:

import tensorrt.plugin as trtp

# Plugin IDs are of the form: "<namespace>::<name>" and
# uniquely identify a plugin.
INCREMENT_PLUGIN_ID = "example::increment"


@trtp.register(INCREMENT_PLUGIN_ID)
def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> trtp.TensorDesc:
    """
    Defines the plugin interface - inputs, outputs and attributes.

    Args:
        inp0: Input tensor descriptor
        block_size: Block size for the Triton kernel

    Returns:
        Output tensor descriptor with same shape/dtype as input
    """
    return inp0.like()

Implementing The Kernel

For this example, we use OpenAI’s Triton language to implement the kernel:

import triton
import triton.language as tl


@triton.jit
def increment(x_ptr, num_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x + 1, mask=mask)
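To make the blockwise indexing concrete, here is a plain-Python sketch of what the kernel computes (illustrative only; it is not part of the plugin, and the `increment_reference` name and default block size are made up for this example):

```python
def increment_reference(x, block_size=4):
    """Plain-Python sketch of the Triton kernel above: walk the input in
    fixed-size blocks, skipping offsets past the end (the final block may
    be partial - this mirrors the kernel's mask)."""
    num_elements = len(x)
    y = [None] * num_elements
    num_blocks = -(-num_elements // block_size)  # ceil division, one block per program id
    for pid in range(num_blocks):  # each iteration plays the role of one Triton program instance
        for offset in range(pid * block_size, (pid + 1) * block_size):
            if offset < num_elements:  # the `mask` in the Triton kernel
                y[offset] = x[offset] + 1
    return y

print(increment_reference([0.0, 1.0, 2.0, 3.0, 4.0]))  # [1.0, 2.0, 3.0, 4.0, 5.0]
```

On the GPU, the outer loop runs in parallel (one program instance per block) and the inner loop is a single vectorized load/add/store; the mask guards the partial final block in both versions.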

Note

Kernels can be written in many other ways (e.g., CUDA, CUTLASS, or Numba) as long as we can emit PTX.

Retrieving PTX

trtp.aot_impl decorates a function that retrieves PTX, launch parameters, and any extra scalar arguments:

from typing import Tuple, Union

import triton
import tensorrt.plugin as trtp


@trtp.aot_impl(INCREMENT_PLUGIN_ID)
def increment_aot_impl(
    inp0: trtp.TensorDesc,
    block_size: int,
    outputs: Tuple[trtp.TensorDesc],
    tactic: int,
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:

    def _drop_unused_entry_params(ptx: str, kernel_name: str) -> str:
        """
        Removes unreferenced PTX entry parameters for the given kernel.

        NOTE: This is a temporary workaround and will not be necessary in a
            future version of TensorRT.

        Why this exists:
        - Newer Triton versions may add extra kernel entry parameters for
          runtime plumbing.
        - For simple kernels, these extra params can be unreferenced
          in the PTX body.
        - Some plugin launch paths expect only the explicitly
          modeled arguments.

        This helper keeps only parameters that are actually referenced by
        `ld.param ... [<param_name>]` in the PTX body.
        """
        import re

        lines = ptx.splitlines()
        entry_start = next(
            (
                i
                for i, line in enumerate(lines)
                if f".entry {kernel_name}(" in line
            ),
            None,
        )
        if entry_start is None:
            return ptx

        entry_end = next(
            (
                i
                for i in range(entry_start + 1, len(lines))
                if lines[i].strip() == ")"
            ),
            None,
        )
        if entry_end is None:
            return ptx

        param_lines = lines[entry_start + 1 : entry_end]
        body = "\n".join(lines[entry_end + 1 :])

        def param_name(line: str):
            match = re.search(r"\b([A-Za-z_][A-Za-z0-9_]*)\s*,?\s*$", line)
            return match.group(1) if match and ".param" in line else None

        used = {
            name
            for line in param_lines
            if (name := param_name(line))
            and re.search(rf"\[{re.escape(name)}\]", body)
        }
        filtered_params = [
            line
            for line in param_lines
            if (name := param_name(line)) is None or name in used
        ]
        if len(filtered_params) == len(param_lines):
            return ptx

        for i in range(len(filtered_params) - 1, -1, -1):
            if ".param" in filtered_params[i]:
                filtered_params[i] = filtered_params[i].rstrip().rstrip(",")
                break

        return "\n".join(
            lines[: entry_start + 1] + filtered_params + lines[entry_end:]
        )

    src = triton.compiler.ASTSource(
        fn=increment,
        signature={"x_ptr": "*fp32", "num_elements": "i32", "y_ptr": "*fp32"},
        constexprs={
            "BLOCK_SIZE": block_size,
        },
    )

    compiled_kernel = triton.compile(src)
    metadata = compiled_kernel.metadata

    # Set the grid, block dims and shared memory for the
    # kernel (as symbolic expressions)
    launch_params = trtp.KernelLaunchParams()
    num_elements = inp0.shape_expr.numel()

    launch_params.grid_x = trtp.cdiv(num_elements, block_size)
    launch_params.block_x = metadata.num_warps * 32
    launch_params.shared_mem = metadata.shared

    # Define extra scalar arguments for the
    # kernel (as symbolic expressions)
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(num_elements)

    # Optional compatibility step for environments where Triton emits extra
    # unreferenced entry parameters.
    ptx = _drop_unused_entry_params(compiled_kernel.asm["ptx"], metadata.name)

    return metadata.name, ptx, launch_params, extra_args
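The launch-parameter expressions above stay symbolic until the input shape is known at runtime. For a concrete sense of the arithmetic, here is a small sketch with hypothetical numbers (the values below are assumptions for illustration; `cdiv` is a plain-integer stand-in for the symbolic trtp.cdiv):

```python
def cdiv(a, b):
    """Ceiling division: the concrete counterpart of trtp.cdiv."""
    return -(-a // b)

# Hypothetical concrete values for illustration:
num_elements = 1000  # e.g. a (10, 100) input tensor
block_size = 256     # the plugin's block_size attribute
num_warps = 4        # as reported in compiled_kernel.metadata

grid_x = cdiv(num_elements, block_size)  # 4 blocks cover 1000 elements
block_x = num_warps * 32                 # 32 threads per warp -> 128 threads
print(grid_x, block_x)                   # 4 128
```

Because `grid_x` depends on `inp0.shape_expr`, TensorRT can re-evaluate it for each input shape without recompiling the kernel.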

Using The Plugin

We can use the plugin with nvtripy.plugin() (nvtripy is imported as tp below):

inp = tp.iota((2, 2))
# Plugin attributes are passed as keyword arguments and must match
# the attributes specified by the registration function.
out = tp.plugin(INCREMENT_PLUGIN_ID, [inp], block_size=256)
>>> inp
tensor(
    [[0, 0],
     [1, 1]],
    dtype=float32, loc=gpu:0, shape=(2, 2))

>>> out
tensor(
    [[1, 1],
     [2, 2]],
    dtype=float32, loc=gpu:0, shape=(2, 2))