Custom Operations¶
Plugins allow you to extend TensorRT with custom operations.
The quickly deployable plugin (QDP) framework is the easiest way to write plugins.
Implementing The Plugin¶
In this guide, we’ll implement a plugin that increments a tensor by 1.
See also
TensorRT’s guide on QDPs includes more details on implementing plugins.
We must:

1. Register the interface for the plugin.
2. Implement the plugin kernel.
3. Generate PTX.
Registering The Plugin Interface¶
trtp.register decorates a function that defines the plugin interface:
import tensorrt.plugin as trtp

# Plugin IDs are of the form: "<namespace>::<name>" and
# uniquely identify a plugin.
INCREMENT_PLUGIN_ID = "example::increment"


@trtp.register(INCREMENT_PLUGIN_ID)
def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> trtp.TensorDesc:
    """
    Defines the plugin interface - inputs, outputs and attributes.

    Args:
        inp0: Input tensor descriptor
        block_size: Block size for the Triton kernel

    Returns:
        Output tensor descriptor with same shape/dtype as input
    """
    return inp0.like()
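The same pattern extends to plugins with multiple inputs or attributes: each input becomes its own TensorDesc parameter, attributes are additional scalar parameters, and the return value describes the output. A hypothetical sketch (not used elsewhere in this guide):

@trtp.register("example::elementwise_add")
def add_desc(a: trtp.TensorDesc, b: trtp.TensorDesc) -> trtp.TensorDesc:
    # The output matches the first input's shape and dtype.
    return a.like()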
Implementing The Kernel¶
For this example, we use OpenAI’s Triton language to implement the kernel:
import triton
import triton.language as tl


@triton.jit
def increment(x_ptr, num_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous block of BLOCK_SIZE elements.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask out-of-bounds offsets so we never read or write past the end of the tensor.
    mask = offsets < num_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x + 1, mask=mask)
Note
Kernels can be written in many other ways, e.g., CUDA, CUTLASS, or Numba, as long as we can emit PTX.
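Before wiring the kernel into TensorRT, it can be handy to sanity-check it with a plain Triton launch. The snippet below is not part of the plugin; it assumes PyTorch and a CUDA-capable GPU are available and reuses the increment kernel defined above:

import torch

x = torch.arange(8, dtype=torch.float32, device="cuda")
y = torch.empty_like(x)

BLOCK_SIZE = 256
# One Triton program per BLOCK_SIZE chunk of the input.
grid = (triton.cdiv(x.numel(), BLOCK_SIZE),)
increment[grid](x, x.numel(), y, BLOCK_SIZE=BLOCK_SIZE)

assert torch.equal(y, x + 1)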
Retrieving PTX¶
trtp.aot_impl decorates a function that retrieves PTX, launch parameters, and any extra scalar arguments:
from typing import Tuple, Union
import tensorrt.plugin as trtp
import triton


@trtp.aot_impl(INCREMENT_PLUGIN_ID)
def increment_aot_impl(
    inp0: trtp.TensorDesc,
    block_size: int,
    outputs: Tuple[trtp.TensorDesc],
    tactic: int,
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    src = triton.compiler.ASTSource(
        fn=increment,
        signature="*fp32,i32,*fp32",
        constants={
            "BLOCK_SIZE": block_size,
        },
    )

    compiled_kernel = triton.compile(src)

    # Set the grid, block dims and shared memory for the
    # kernel (as symbolic expressions)
    launch_params = trtp.KernelLaunchParams()
    num_elements = inp0.shape_expr.numel()
    launch_params.grid_x = trtp.cdiv(num_elements, block_size)
    launch_params.block_x = compiled_kernel.metadata.num_warps * 32
    launch_params.shared_mem = compiled_kernel.metadata.shared

    # Define extra scalar arguments for the
    # kernel (as symbolic expressions)
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(num_elements)

    return (
        compiled_kernel.metadata.name,
        compiled_kernel.asm["ptx"],
        launch_params,
        extra_args,
    )
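To make the symbolic launch math concrete, here is the same arithmetic in plain Python for a couple of hypothetical input shapes (TensorRT resolves the real values at runtime from the shape expressions):

# Plain-Python mirror of the launch math above (illustration only).
def ceil_div(a, b):
    return -(-a // b)  # what trtp.cdiv expresses symbolically

block_size = 256
for shape in [(2, 2), (1, 1000)]:
    num_elements = shape[0] * shape[1]
    grid_x = ceil_div(num_elements, block_size)
    print(shape, "->", grid_x, "block(s), each covering up to", block_size, "elements")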
Using The Plugin¶
We can use the plugin with nvtripy.plugin():
import nvtripy as tp

inp = tp.iota((2, 2))
# Plugin attributes are passed as keyword arguments and must match
# the attributes specified by the registration function.
out = tp.plugin(INCREMENT_PLUGIN_ID, [inp], block_size=256)
Local Variables
>>> inp
tensor(
[[0, 0],
[1, 1]],
dtype=float32, loc=gpu:0, shape=(2, 2))
>>> out
tensor(
[[1, 1],
[2, 2]],
dtype=float32, loc=gpu:0, shape=(2, 2))
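To verify the result numerically outside of Tripy, one option is to export the tensors to PyTorch; this is a sketch that assumes DLPack interop between nvtripy and PyTorch is available:

import torch

# DLPack views of the Tripy tensors (assumed supported by nvtripy).
inp_torch = torch.from_dlpack(inp)
out_torch = torch.from_dlpack(out)
assert torch.equal(out_torch, inp_torch + 1)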