Adding New Operators
You may find it helpful to read the architecture documentation before you start reading this guide.
Adding new operators to Tripy typically involves making changes in the frontend as well as in the FlatIR. In some cases, the frontend operator can be expressed in terms of existing FlatIR operators, in which case you only need to make changes in the frontend.
Let’s take a look at an example of how you might add an Iota operator to Tripy. So that it doesn’t clash with Tripy’s actual Iota implementation, we’ll call it Theta instead.
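To make the target behavior concrete before touching any Tripy internals, here is a plain-Python reference sketch of what an iota-style operator computes: every element of the output holds its own index along the chosen dimension. The function name iota_reference and the nested-list representation are illustrative assumptions and are not part of Tripy.

```python
from typing import Sequence


def iota_reference(shape: Sequence[int], dim: int = 0) -> list:
    """Builds a nested list where each element equals its index along `dim`.

    Hypothetical pure-Python reference; this is not Tripy's implementation.
    """
    rank = len(shape)
    if not 0 <= dim < rank:
        raise ValueError(f"dim must be in [0, {rank}), got {dim}")

    def build(prefix):
        # `prefix` holds the indices chosen so far for the leading dimensions.
        if len(prefix) == rank:
            return float(prefix[dim])
        return [build(prefix + (i,)) for i in range(shape[len(prefix)])]

    return build(())


print(iota_reference([3]))
print(iota_reference([2, 3], dim=1))
```

The second call fills each row with 0.0, 1.0, 2.0 because the values count along dimension 1, which is the multidimensional behavior the rest of this guide aims to reproduce.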
Implementation
FlatIR Operator
The FlatIR operator is usually the most challenging aspect of implementing operators in Tripy. The good news is that you might not even need to do this if the low-level operators you need already exist in the FlatIR. And if you do, then it’ll only get easier after this!
We’ll start by adding a new file under nvtripy/flat_ir/ops called theta.py; see the inline comments for explanations of what’s happening:
from dataclasses import dataclass

from mlir_tensorrt.compiler import ir
from mlir_tensorrt.compiler.dialects import stablehlo

from nvtripy.flat_ir.ops.base import BaseFlatIROp


# Every `FlatIR` operator is implemented as a `dataclass` so that the base
# class can automatically implement several methods by inspecting the child
# class fields at runtime. The `repr=False` is important because the default
# `__repr__` method generated by `dataclass` will be extremely verbose and
# makes interactive debugging more difficult.
@dataclass(repr=False)
class ThetaOp(BaseFlatIROp):
    dim: int

    # `to_mlir()` is the trickiest bit. As the name implies, the method is
    # meant to lower the `FlatIR` operator into MLIR. To figure out which
    # MLIR operators to use, refer to the 'MLIR Python API Guide'
    # (linked below).
    def to_mlir(self, operands):
        out_type = self.outputs[0].to_mlir()
        theta_dim = ir.IntegerAttr.get(
            type=ir.IntegerType.get_signless(64), value=self.dim
        )
        output = stablehlo.DynamicIotaOp(
            result=out_type, output_shape=operands[0], iota_dimension=theta_dim
        )
        return [output]
Links:
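The comments in the listing note that the base class inspects the child class's dataclass fields at runtime, and that repr=False avoids a verbose generated __repr__. The sketch below illustrates that general mechanism with a hypothetical ToyBaseOp. It is not the code of BaseFlatIROp, and the inputs and outputs fields are assumptions made for illustration.

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass(repr=False)
class ToyBaseOp:
    # Hypothetical stand-ins for the tensors every operator carries.
    inputs: List[str]
    outputs: List[str]

    def __repr__(self) -> str:
        # Inspect the dataclass fields at runtime and print only the ones
        # that the child class added on top of the base class.
        base_names = {f.name for f in fields(ToyBaseOp)}
        attrs = ", ".join(
            f"{f.name}={getattr(self, f.name)!r}"
            for f in fields(self)
            if f.name not in base_names
        )
        return f"{type(self).__name__}({attrs})"


# With `repr=False`, `dataclass` does not generate a `__repr__` for the
# child, so the compact one inherited from the base class is used.
@dataclass(repr=False)
class ToyThetaOp(ToyBaseOp):
    dim: int


op = ToyThetaOp(inputs=["shape"], outputs=["out"], dim=1)
print(repr(op))
```

If the child were declared with a plain @dataclass, the generated __repr__ would override the inherited one and print every field, including the tensor lists.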
Exposing The Operator
One of the principles we follow when writing submodules is that other submodules should not need to reach into the internals of a submodule to retrieve something they need.
For example, a class that needs to import ThetaOp does not need to know where exactly within the flat_ir.ops module ThetaOp lives; it should be able to import it directly from the submodule.
To make this possible, we need to import ThetaOp into the flat_ir.ops submodule. We can do so by adding the following line to nvtripy/flat_ir/ops/__init__.py:
from nvtripy.flat_ir.ops.theta import ThetaOp
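The re-export pattern can be tried in isolation. The sketch below builds a throwaway package on disk and shows that callers can import the class from the subpackage without knowing which file defines it. The names demo_pkg, ops, and theta are hypothetical and unrelated to Tripy's actual layout.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Build a throwaway package: demo_pkg/ops/theta.py defines the class and
# demo_pkg/ops/__init__.py re-exports it. All names here are hypothetical.
root = Path(tempfile.mkdtemp())
ops_dir = root / "demo_pkg" / "ops"
ops_dir.mkdir(parents=True)
(root / "demo_pkg" / "__init__.py").write_text("")
(ops_dir / "theta.py").write_text("class ThetaOp:\n    pass\n")
(ops_dir / "__init__.py").write_text("from demo_pkg.ops.theta import ThetaOp\n")

sys.path.insert(0, str(root))
importlib.invalidate_caches()

# Callers import from the subpackage, not from the file that defines the class.
from demo_pkg.ops import ThetaOp

# The defining module is still recorded on the class itself.
print(ThetaOp.__module__)
```

Because callers depend only on the subpackage path, the defining file can later be renamed or split without touching any of them.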
Trace Operator And The Public API
Now that we have a FlatIR operator, we can implement a Trace operator that will use it along with a public API function. Let’s create a new file under nvtripy/frontend/trace/ops called theta.py.
Trace Operator
First, we’ll implement the Trace operator itself:
from dataclasses import dataclass
from typing import Tuple

from nvtripy import utils
from nvtripy.common import datatype, device
from nvtripy.common.exception import raise_error
from nvtripy.frontend.trace.ops.base import BaseTraceOp
import nvtripy.frontend.trace.ops.utils as op_utils


# Just like with `FlatIR` operators, all `Trace` operators are implemented
# as `dataclass`es. As before, we want `repr=False` here.
@dataclass(repr=False)
class Theta(BaseTraceOp):
    # Notice that we do *not* need to define a constructor and can rely on
    # the default implementation provided by `dataclass`.
    dim: int
    dtype: datatype.dtype

    # `infer_rank()` populates the rank of the output `TraceTensor`s.
    # Here we use one of the predefined policies to set the output rank
    # to the same as the shape (i.e. the length) of the shape operand.
    infer_rank = op_utils.InferRankPolicies.same_as_shape_of_shape_input()

    # *Optional* `infer_dtypes()` populates the data types of the
    # output `TraceTensor`s. The default implementation copies the input
    # data types if they are all the same, so you may not need to implement
    # this.
    def infer_dtypes(self):
        self.outputs[0].dtype = self.dtype

    # *Optional* `infer_devices()` populates the devices of the
    # output `TraceTensor`s. The default implementation copies the input
    # devices if they are all the same, so you may not need to implement
    # this either.
    def infer_devices(self):
        self.outputs[0].device = device("gpu")

    # `to_flat_ir()` translates the `Trace` operator to a subgraph of
    # one or more `FlatIR` operators. In our case, it's just a 1:1
    # mapping to the `ThetaOp` we created earlier.
    def to_flat_ir(self, inputs, outputs):
        # Note that we import the `FlatIR` operator within the function
        # call - this is to avoid circular dependencies.
        from nvtripy.flat_ir.ops import ThetaOp

        # This code may look a bit confusing; for more details, look at the
        # 'FlatIR section in the architecture document' (linked below).
        ThetaOp.build(inputs, outputs, dim=self.dim)
Links:
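The rank-inference policy used in the listing can be pictured as a factory that returns an infer_rank method. The following is a simplified sketch with hypothetical ToyTensor and ToyOp classes; the real InferRankPolicies helper may differ in its details. The assumption here is that the shape operand is a 1D tensor whose length equals the rank of the output.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyTensor:
    # Hypothetical stand-in for a trace tensor.
    shape: Optional[List[int]] = None
    rank: Optional[int] = None


def same_as_shape_of_shape_input(idx: int = 0):
    """Returns an `infer_rank` method that reads the length of the shape operand."""

    def infer_rank(self):
        # A shape operand is 1D, so its only dimension is the output rank.
        self.outputs[0].rank = self.inputs[idx].shape[0]

    return infer_rank


class ToyOp:
    # Assigning the returned function in the class body makes it a method.
    infer_rank = same_as_shape_of_shape_input()

    def __init__(self, inputs, outputs):
        self.inputs = inputs
        self.outputs = outputs


# A shape operand holding two values, e.g. [2, 3], itself has shape [2].
op = ToyOp(inputs=[ToyTensor(shape=[2], rank=1)], outputs=[ToyTensor()])
op.infer_rank()
print(op.outputs[0].rank)
```

Expressing the policy as a factory keeps the operator class declarative: the class body states which rule applies, and the rule itself lives in one shared place.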
Public API
Next, we can define the public interface. Since our public interface maps 1:1 with the Trace operator we just implemented and does not require weights, we’ll add it in the same file.
If our API required a composition of multiple Trace operators, then we would instead implement it under frontend/ops/.
If it required weights (i.e. inputs that are expected to always be constant), then we would implement it as a nvtripy.Module under frontend/module.
from nvtripy import export
from nvtripy.utils import wrappers
from nvtripy.types import ShapeLike


# We can use the `export.public_api()` decorator to automatically export this
# function into the top-level module. This means it will be accessible as
# `nvtripy.theta`.
#
# This decorator also controls how the API is exposed in the documentation -
# the `document_under` option determines where in the documentation hierarchy
# this API will show up.
#
# If we needed to provide any special autodoc options, we could use the
# `autodoc_options` parameter.
@export.public_api(document_under="tensor_operations")

# We can use the `wrappers.interface` decorator to specify constraints on
# inputs and perform transformations on them, like automatically converting
# compatible arguments (e.g., `TensorLike` or `ShapeLike`s) into tensors.
# We will aim to include most constraints and transformations in this decorator
# so as to avoid layering too many decorators.
@wrappers.interface(convert_to_tensors=True)
def theta(
    shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float32
) -> "nvtripy.Tensor":
    # For any public facing interfaces, we have documentation requirements which
    # you can read about in the 'Docs README' (linked below). The docstring
    # we've implemented here adheres to all of these requirements. Non-compliant
    # docstrings will, in most cases, cause test failures; however, you should
    # still manually ensure you're writing high-quality docstrings.
    #
    # The examples in docstrings are run as part of our tests, so you should
    # also add assertions to make sure things are functionally correct. In this
    # case, we check that the `output` we create in the code example is what we
    # expect.
    """
    Fills an output tensor with consecutive values starting from zero
    along the given dimension.

    Args:
        shape: The desired shape.
        dim: Dimension along which to perform the theta operation.
            This cannot exceed the rank of the specified shape.
        dtype: The desired data type.

    Returns:
        A tensor of shape ``shape`` and data type ``dtype``.

    .. code-block:: python
        :linenos:

        output = tp.theta([3])

        assert np.array_equal(
            cp.from_dlpack(output).get(), np.arange(0, 3, dtype=np.float32)
        )
    """

    # Next we build the trace operator. The `build()` function is also
    # responsible for constructing the output frontend Tensors. All of the
    # arguments that follow the inputs are forwarded directly to the
    # constructor of the `Trace` operator.
    return Theta.build([shape], dim, dtype)
Links:
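One way to picture a decorator like export.public_api() is as a registry: decorating a function records it so that it can later be attached to the top-level module and placed in the documentation tree. The sketch below is a simplified illustration with hypothetical names (PUBLIC_APIS, toy_public_api); it does not reproduce Tripy's implementation.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry mapping exported names to (function, docs location).
PUBLIC_APIS: Dict[str, Tuple[Callable, str]] = {}


def toy_public_api(document_under: str = "") -> Callable:
    def decorator(func: Callable) -> Callable:
        # Record the function, then return it unchanged so that it can
        # still be called directly.
        PUBLIC_APIS[func.__name__] = (func, document_under)
        return func

    return decorator


@toy_public_api(document_under="tensor_operations")
def theta(shape, dim=0):
    """Placeholder body; a real API would build and return a tensor."""
    return ("theta", tuple(shape), dim)


func, location = PUBLIC_APIS["theta"]
print(location)
print(func([2, 3], dim=1))
```

A package's top-level __init__.py could then walk such a registry to expose every recorded function, which is why decorated APIs do not need to be listed by hand in this sketch.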
Exposing The Operator
Similarly to the FlatIR operator, we need to import Theta into the frontend.trace.ops submodule. We can do so by adding the following line to nvtripy/frontend/trace/ops/__init__.py:
from nvtripy.frontend.trace.ops.theta import Theta, theta
Testing
Now that we’ve implemented our operator, let’s write tests for it. The structure of the tests/ directory mirrors that of the nvtripy/ directory (you can read more about that here). We need to test both the FlatIR and Trace operators.
Testing The Trace Operator And Public API
Since we implemented our Trace operator and public API in nvtripy/frontend/trace/ops, we’ll add the test under tests/frontend/trace/ops.
Create a new file there called test_theta.py:
import nvtripy as tp
from tests import helper
from nvtripy.frontend.trace.ops import Theta


class TestTheta:
    # This ensures that the public API function creates a frontend `Tensor`
    # and populates it with the right `Trace` operator.
    def test_op_func(self):
        a = tp.theta([2, 3])
        assert isinstance(a, tp.Tensor)
        assert isinstance(a.trace_tensor.producer, Theta)

    # You should also include negative tests for anything that is expected to
    # fail. In our case, we just have `test_invalid_dim`,
    # which ensures that we emit an error if the `dim` parameter is outside
    # the allowed range.
    def test_invalid_dim(self):
        with helper.raises(
            tp.TripyException,
            match="iota dimension cannot go beyond the output rank",
        ):
            tp.theta([2, 3], dim=3).eval()
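In the negative test, helper.raises is used as a context manager that checks both the exception type and its message. For readers unfamiliar with that pattern, here is a minimal standard-library sketch of the same idea. The raises function below is a hypothetical stand-in, not the helper from Tripy's test suite, and toy_theta merely imitates the error message.

```python
import re
from contextlib import contextmanager


@contextmanager
def raises(exc_type, match):
    """Passes only if the body raises `exc_type` with a message matching `match`."""
    try:
        yield
    except exc_type as exc:
        # The expected exception was raised; now check its message.
        assert re.search(match, str(exc)), f"message did not match: {exc}"
    else:
        raise AssertionError(f"{exc_type.__name__} was not raised")


def toy_theta(shape, dim=0):
    # Hypothetical validation that imitates the error text used in the test.
    if not 0 <= dim < len(shape):
        raise ValueError("iota dimension cannot go beyond the output rank")
    return shape


with raises(ValueError, match="iota dimension cannot go beyond the output rank"):
    toy_theta([2, 3], dim=3)
```

Matching on the message as well as the type guards against the test passing for the wrong reason, for example when an unrelated error of the same type is raised earlier.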
Integration Tests
The code examples in the docstring of the public API serve as good sanity integration tests. However, you should still add separate integration tests to get better coverage.
Our docstring covers the 1D case, so let’s add an integration test to cover the multidimensional case.
Create a new file called test_theta.py under tests/integration:
import numpy as np
import cupy as cp

import nvtripy as tp


def test_multi_dimensional():
    output = tp.theta([2, 3], dim=1)
    expected = tp.Tensor([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]], dtype=tp.float32)

    assert tp.equal(output, expected)
Done!
If you’ve reached this point, you have successfully added a new operation to Tripy. Congratulations!