Custom Operations¶
Plugins allow you to extend TensorRT with custom operations.
The quickly deployable plugin (QDP) framework is the easiest way to write plugins.
Implementing The Plugin¶
In this guide, we’ll implement a plugin that increments a tensor by 1.
See also
TensorRT’s guide on QDPs includes more details on implementing plugins.
We must:
Register the interface for the plugin.
Implement the plugin kernel.
Generate PTX.
Registering The Plugin Interface¶
trtp.register decorates a function that defines the plugin interface:
import tensorrt.plugin as trtp

# Plugin IDs are of the form: "<namespace>::<name>" and
# uniquely identify a plugin.
INCREMENT_PLUGIN_ID = "example::increment"


@trtp.register(INCREMENT_PLUGIN_ID)
def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> trtp.TensorDesc:
    """
    Defines the plugin interface - inputs, outputs and attributes.

    Args:
        inp0: Input tensor descriptor
        block_size: Block size for the Triton kernel

    Returns:
        Output tensor descriptor with same shape/dtype as input
    """
    return inp0.like()
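The registered function runs at engine build time on symbolic tensor descriptors rather than concrete data; its only job is to declare the outputs in terms of the inputs (here, `inp0.like()` produces a descriptor with the same shape and dtype). To build intuition, here is a toy pure-Python sketch of that contract — `ToyDesc` is hypothetical, for illustration only, and is not a TensorRT API:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyDesc:
    """Hypothetical stand-in for trtp.TensorDesc: shape/dtype metadata only."""

    shape: tuple
    dtype: str

    def like(self) -> "ToyDesc":
        # Mirrors the role of TensorDesc.like(): same shape and dtype, no data.
        return replace(self)


def toy_plugin_desc(inp0: ToyDesc, block_size: int) -> ToyDesc:
    # Shape/dtype propagation only - no computation happens at this stage.
    return inp0.like()


out_desc = toy_plugin_desc(ToyDesc((2, 2), "float32"), block_size=256)
```

Note that `block_size` is declared here purely as a plugin attribute; it does not affect the output descriptor.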
Implementing The Kernel¶
For this example, we use OpenAI’s Triton language to implement the kernel:
import triton
import triton.language as tl


@triton.jit
def increment(x_ptr, num_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x + 1, mask=mask)
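To make the kernel's behavior concrete, here is a pure-Python reference model of the same blocked, masked increment — a sketch for intuition, not production code. Each program ID becomes an outer loop iteration, `tl.arange` becomes a `range`, and the mask guards the partially-filled final block:

```python
import math


def increment_reference(x, block_size):
    """Pure-Python model of the Triton kernel: one iteration per program id."""
    num_elements = len(x)
    y = [0.0] * num_elements
    for pid in range(math.ceil(num_elements / block_size)):  # grid of programs
        for i in range(block_size):                          # tl.arange(0, BLOCK_SIZE)
            offset = pid * block_size + i
            if offset < num_elements:                        # mask
                y[offset] = x[offset] + 1                    # load, add, store
    return y


print(increment_reference([0.0, 0.0, 1.0, 1.0, 2.0], block_size=4))
# [1.0, 1.0, 2.0, 2.0, 3.0]
```

The mask matters whenever `num_elements` is not a multiple of `block_size`, as in this 5-element example with a block size of 4.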
Note
Kernels can be written in many other ways (e.g., CUDA, CUTLASS, or Numba) as long as we can emit PTX.
Retrieving PTX¶
trtp.aot_impl decorates a function that retrieves the PTX, launch parameters, and any extra scalar arguments:
from typing import Tuple, Union

import triton
import tensorrt.plugin as trtp


@trtp.aot_impl(INCREMENT_PLUGIN_ID)
def increment_aot_impl(
    inp0: trtp.TensorDesc,
    block_size: int,
    outputs: Tuple[trtp.TensorDesc],
    tactic: int,
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:

    def _drop_unused_entry_params(ptx: str, kernel_name: str) -> str:
        """
        Removes unreferenced PTX entry parameters for the given kernel.

        NOTE: This is a temporary workaround and will not be necessary in a
        future version of TensorRT.

        Why this exists:
        - Newer Triton versions may add extra kernel entry parameters for
          runtime plumbing.
        - For simple kernels, these extra params can be unreferenced
          in the PTX body.
        - Some plugin launch paths expect only the explicitly
          modeled arguments.

        This helper keeps only parameters that are actually referenced by
        `ld.param ... [<param_name>]` in the PTX body.
        """
        import re

        lines = ptx.splitlines()
        entry_start = next(
            (
                i
                for i, line in enumerate(lines)
                if f".entry {kernel_name}(" in line
            ),
            None,
        )
        if entry_start is None:
            return ptx

        entry_end = next(
            (
                i
                for i in range(entry_start + 1, len(lines))
                if lines[i].strip() == ")"
            ),
            None,
        )
        if entry_end is None:
            return ptx

        param_lines = lines[entry_start + 1 : entry_end]
        body = "\n".join(lines[entry_end + 1 :])

        def param_name(line: str):
            match = re.search(r"\b([A-Za-z_][A-Za-z0-9_]*)\s*,?\s*$", line)
            return match.group(1) if match and ".param" in line else None

        used = {
            name
            for line in param_lines
            if (name := param_name(line))
            and re.search(rf"\[{re.escape(name)}\]", body)
        }
        filtered_params = [
            line
            for line in param_lines
            if (name := param_name(line)) is None or name in used
        ]
        if len(filtered_params) == len(param_lines):
            return ptx

        # Strip the trailing comma from what is now the last parameter.
        for i in range(len(filtered_params) - 1, -1, -1):
            if ".param" in filtered_params[i]:
                filtered_params[i] = filtered_params[i].rstrip().rstrip(",")
                break

        return "\n".join(
            lines[: entry_start + 1] + filtered_params + lines[entry_end:]
        )

    src = triton.compiler.ASTSource(
        fn=increment,
        signature={"x_ptr": "*fp32", "num_elements": "i32", "y_ptr": "*fp32"},
        constexprs={
            "BLOCK_SIZE": block_size,
        },
    )

    compiled_kernel = triton.compile(src)
    metadata = compiled_kernel.metadata

    # Set the grid, block dims and shared memory for the
    # kernel (as symbolic expressions)
    launch_params = trtp.KernelLaunchParams()
    num_elements = inp0.shape_expr.numel()

    launch_params.grid_x = trtp.cdiv(num_elements, block_size)
    launch_params.block_x = metadata.num_warps * 32
    launch_params.shared_mem = metadata.shared

    # Define extra scalar arguments for the
    # kernel (as symbolic expressions)
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(num_elements)

    # Optional compatibility step for environments where Triton emits extra
    # unreferenced entry parameters.
    ptx = _drop_unused_entry_params(compiled_kernel.asm["ptx"], metadata.name)

    return metadata.name, ptx, launch_params, extra_args
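The launch-parameter expressions above are ordinary ceiling-division arithmetic, just written over symbolic shapes so TensorRT can evaluate them for whatever input shape arrives at runtime. A concrete sketch with illustrative numbers (a (2, 2) input and the `block_size=256` used later; `num_warps` is made up here — in practice it comes from the compiled kernel's metadata):

```python
def cdiv(a, b):
    # Ceiling division, the same arithmetic trtp.cdiv expresses symbolically.
    return -(-a // b)


num_elements = 2 * 2  # inp0.shape_expr.numel() for a (2, 2) tensor
block_size = 256
num_warps = 4         # illustrative only; read from kernel metadata in practice

grid_x = cdiv(num_elements, block_size)  # one block covers all 4 elements
block_x = num_warps * 32                 # threads per block; a warp is 32 threads

print(grid_x, block_x)
# 1 128
```

Because the grid is sized with a ceiling division, every element is covered even when the element count is not a multiple of the block size; the kernel's mask discards the overshoot.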
Using The Plugin¶
We can use the plugin with nvtripy.plugin():
import nvtripy as tp

inp = tp.iota((2, 2))
# Plugin attributes are passed as keyword arguments and must match
# the attributes specified by the registration function.
out = tp.plugin(INCREMENT_PLUGIN_ID, [inp], block_size=256)
Local Variables
>>> inp
tensor(
[[0, 0],
[1, 1]],
dtype=float32, loc=gpu:0, shape=(2, 2))
>>> out
tensor(
[[1, 1],
[2, 2]],
dtype=float32, loc=gpu:0, shape=(2, 2))