Adding New Operators

You may find it helpful to read the architecture documentation before you start reading this guide.

Adding new operators to Tripy typically involves making changes in the frontend as well as in the FlatIR. In some cases, the frontend operator can be expressed in terms of existing FlatIR operators, in which case you only need to make changes in the frontend.

Let’s take a look at an example of how you might add an Iota operator to Tripy. So that it doesn’t clash with Tripy’s actual Iota implementation, we’ll call it Theta instead.
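
To make the target behavior concrete before writing any Tripy code, here is a minimal pure-Python sketch of what an iota-style operator computes: a tensor of the requested shape whose values count up from zero along one dimension. The function name `iota_reference` and the `ValueError` are assumptions made for this illustration and are not part of Tripy.

```python
def iota_reference(shape, dim=0):
    """Return a nested list of the given shape whose values count up along `dim`."""
    if not 0 <= dim < len(shape):
        # Hypothetical error; chosen only for this sketch.
        raise ValueError("iota dimension cannot go beyond the output rank")

    def build(prefix):
        # `prefix` holds the indices chosen so far, one per leading dimension.
        if len(prefix) == len(shape):
            # The value of each element is its index along `dim`.
            return float(prefix[dim])
        return [build(prefix + (i,)) for i in range(shape[len(prefix)])]

    return build(())


print(iota_reference([3]))            # [0.0, 1.0, 2.0]
print(iota_reference([2, 3], dim=1))  # [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]]
```

A reference like this is also handy later when deciding what values a test should expect.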

Implementation

FlatIR Operator

The FlatIR operator is usually the most challenging aspect of implementing operators in Tripy. The good news is that you might not even need to do this if the low-level operators you need already exist in the FlatIR. And if you do, then it’ll only get easier after this!

We’ll start by adding a new file under nvtripy/flat_ir/ops called theta.py; see the inline comments for explanations of what’s happening:

from dataclasses import dataclass

from mlir_tensorrt.compiler import ir
from mlir_tensorrt.compiler.dialects import stablehlo

from nvtripy.flat_ir.ops.base import BaseFlatIROp


# Every `FlatIR` operator is implemented as a `dataclass` so that the base
# class can automatically implement several methods by inspecting the child
# class fields at runtime. The `repr=False` is important because the default
# `__repr__` method generated by `dataclass` will be extremely verbose and
# makes interactive debugging more difficult.
@dataclass(repr=False)
class ThetaOp(BaseFlatIROp):
    dim: int

    # `to_mlir()` is the trickiest bit. As the name implies, the method is
    # meant to lower the `FlatIR` operator into MLIR. To figure out which
    # MLIR operators to use, refer to the 'MLIR Python API Guide'
    # (linked below).
    def to_mlir(self, operands):
        out_type = self.outputs[0].to_mlir()
        theta_dim = ir.IntegerAttr.get(
            type=ir.IntegerType.get_signless(64), value=self.dim
        )
        output = stablehlo.DynamicIotaOp(
            result=out_type, output_shape=operands[0], iota_dimension=theta_dim
        )
        return [output]
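
The interplay between `repr=False` and a base class that inspects dataclass fields can be sketched in isolation. The snippet below is only an illustration under assumed names: `ToyBaseOp` and `ToyThetaOp` are hypothetical, and the real `BaseFlatIROp` may implement its methods quite differently.

```python
from dataclasses import dataclass, fields


class ToyBaseOp:
    # Hypothetical stand-in for a base class; not Tripy's `BaseFlatIROp`.
    def __repr__(self):
        # Inspect the child dataclass's fields at runtime to build a compact repr.
        args = ", ".join(f"{f.name}={getattr(self, f.name)!r}" for f in fields(self))
        return f"{type(self).__name__}[{args}]"


# With `repr=False`, `dataclass` does not generate its own `__repr__`,
# so the one inherited from the base class is used.
@dataclass(repr=False)
class ToyThetaOp(ToyBaseOp):
    dim: int


print(ToyThetaOp(dim=1))  # ToyThetaOp[dim=1]
```

Without `repr=False`, the decorator would add a `__repr__` to the child class and silently shadow the base class version.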

Links:

Exposing The Operator

One of the principles we follow when writing submodules is that other submodules should not need to reach into the internals of a submodule to retrieve something they need.

For example, a class which needs to import ThetaOp does not need to know where exactly within the flat_ir.ops module the ThetaOp lives - it should be able to just import it from the submodule.

To make this possible, we need to import the ThetaOp into the flat_ir.ops submodule. We can do so by adding the following line into nvtripy/flat_ir/ops/__init__.py:

from nvtripy.flat_ir.ops.theta import ThetaOp
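
The effect of such a re-export can be tried out on a throwaway package. This is a rough sketch only: the package name `toy_pkg` and its layout are invented for the demonstration and merely imitate the shape of `nvtripy/flat_ir/ops`.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Build a throwaway package on disk: toy_pkg/ops/{__init__,theta}.py
root = Path(tempfile.mkdtemp())
ops_dir = root / "toy_pkg" / "ops"
ops_dir.mkdir(parents=True)
(root / "toy_pkg" / "__init__.py").write_text("")
(ops_dir / "theta.py").write_text("class ThetaOp:\n    pass\n")
# The re-export: a single import line in the submodule's __init__.py.
(ops_dir / "__init__.py").write_text("from toy_pkg.ops.theta import ThetaOp\n")

sys.path.insert(0, str(root))
importlib.invalidate_caches()

# Callers import from the submodule without knowing that theta.py exists.
from toy_pkg.ops import ThetaOp

print(ThetaOp.__module__)  # toy_pkg.ops.theta
```

If `theta.py` were later renamed or split up, only the line in `__init__.py` would need to change; callers would be unaffected.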

Trace Operator And The Public API

Now that we have a FlatIR operator, we can implement a Trace operator that will use it along with a public API function. Let’s create a new file under nvtripy/frontend/trace/ops called theta.py.

Trace Operator

First, we’ll implement the Trace operator itself:

from dataclasses import dataclass
from typing import Tuple

from nvtripy import utils
from nvtripy.common import datatype, device
from nvtripy.common.exception import raise_error
from nvtripy.frontend.trace.ops.base import BaseTraceOp
import nvtripy.frontend.trace.ops.utils as op_utils


# Just like with `FlatIR` operators, all `Trace` operators are implemented
# as `dataclass`es. As before, we want `repr=False` here.
@dataclass(repr=False)
class Theta(BaseTraceOp):
    # Notice that we do *not* need to define a constructor and can rely on
    # the default implementation provided by `dataclass`.
    dim: int
    dtype: datatype.dtype

    # `infer_rank()` populates the rank of the output `TraceTensor`s.
    # Here we use one of the predefined policies to set the output rank
    # to the same as the shape (i.e. the length) of the shape operand.
    infer_rank = op_utils.InferRankPolicies.same_as_shape_of_shape_input()

    # *Optional* `infer_dtypes()` populates the data types of the
    # output `TraceTensor`s. The default implementation copies the input
    # data types if they are all the same, so you may not need to implement
    # this.
    def infer_dtypes(self):
        self.outputs[0].dtype = self.dtype

    # *Optional* `infer_devices()` populates the devices of the
    # output `TraceTensor`s. The default implementation copies the input
    # devices if they are all the same, so you may not need to implement
    # this either.
    def infer_devices(self):
        self.outputs[0].device = device("gpu")

    # `to_flat_ir()` translates the `Trace` operator to a subgraph of
    # one or more `FlatIR` operators. In our case, it's just a 1:1
    # mapping to the `ThetaOp` we created earlier.
    def to_flat_ir(self, inputs, outputs):
        # Note that we import the `FlatIR` operator within the function
        # call - this is to avoid circular dependencies.
        from nvtripy.flat_ir.ops import ThetaOp

        # This code may look a bit confusing; for more details, look at the
        # 'FlatIR section in the architecture document' (linked below).
        ThetaOp.build(inputs, outputs, dim=self.dim)
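
Since the operators above rely on the `dataclass`-generated constructor and are created through `build()`, it can help to see how a `build()`-style classmethod might forward its extra arguments. This is a simplified sketch with hypothetical names (`ToyBaseOp`, `ToyThetaOp`); Tripy's actual `build()` most likely does additional bookkeeping that is not modeled here.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass(repr=False)
class ToyBaseOp:
    # Hypothetical base class; not Tripy's `BaseFlatIROp` or `BaseTraceOp`.
    inputs: List[Any]
    outputs: List[Any]

    @classmethod
    def build(cls, inputs, outputs, *args, **kwargs):
        # Everything after `inputs` and `outputs` goes straight to the
        # `__init__` that `dataclass` generated for the child class.
        return cls(inputs, outputs, *args, **kwargs)


@dataclass(repr=False)
class ToyThetaOp(ToyBaseOp):
    # The child only declares its own fields; no constructor is written by hand.
    dim: int


op = ToyThetaOp.build(["shape"], ["out"], dim=1)
print(op.dim)  # 1
```

Because the child class only declares fields, adding a new operator parameter amounts to adding one annotated attribute.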

Links:

Public API

Next, we can define the public interface. Since our public interface maps 1:1 with the Trace operator we just implemented and does not require weights, we’ll add it in the same file.

If our API required a composition of multiple Trace operators, then we would instead implement it under frontend/ops/.

If it required weights (i.e. inputs that are expected to always be constant), then we would implement it as an nvtripy.Module under frontend/module.

from nvtripy import export
from nvtripy.utils import wrappers
from nvtripy.types import ShapeLike


# We can use the `export.public_api()` decorator to automatically export this
# function into the top-level module. This means it will be accessible as
# `nvtripy.theta`.
#
# This decorator also controls how the API is exposed in the documentation -
# the `document_under` option determines where in the documentation hierarchy
# this API will show up.
#
# If we needed to provide any special autodoc options, we could use the
# `autodoc_options` parameter.
@export.public_api(document_under="tensor_operations")

# We can use the `wrappers.interface` decorator to specify constraints on
# inputs and perform transformations on them, like automatically converting
# compatible arguments (e.g., `TensorLike` or `ShapeLike`s) into tensors.
# We will aim to include most constraints and transformations in this decorator
# so as to avoid layering too many decorators.
@wrappers.interface(convert_to_tensors=True)
def theta(
    shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float32
) -> "nvtripy.Tensor":
    # For any public facing interfaces, we have documentation requirements which
    # you can read about in the 'Docs README' (linked below). The docstring
    # we've implemented here adheres to all of these requirements. Non-compliant
    # docstrings will, in most cases, cause test failures; however, you should
    # still manually ensure you're writing high-quality docstrings.
    #
    # The examples in docstrings are run as part of our tests, so you should
    # also add assertions to make sure things are functionally correct. In this
    # case, we check that the `output` we create in the code example is what we
    # expect.
    """
    Fills an output tensor with consecutive values starting from zero
    along the given dimension.

    Args:
        shape: The desired shape.
        dim: Dimension along which to perform the theta operation.
            This cannot exceed the rank of the specified shape.
        dtype: The desired data type.

    Returns:
        A tensor of shape ``shape`` and data type ``dtype``.

    .. code-block:: python
        :linenos:

        output = tp.theta([3])

        assert np.array_equal(
            cp.from_dlpack(output).get(), np.arange(0, 3, dtype=np.float32)
        )
    """

    # Next we build the trace operator. The `build()` function is also
    # responsible for constructing the output frontend Tensors. All of the
    # arguments that follow the inputs are forwarded directly to the
    # constructor of the `Trace` operator.
    return Theta.build([shape], dim, dtype)
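
To give a feel for how an export decorator of this kind could work, here is a small self-contained sketch. It is not Tripy's implementation: `public_api`, `toy_top_level`, and `PUBLIC_APIS` are hypothetical names, and the real `export.public_api()` supports more options (such as `autodoc_options`) than are shown here.

```python
import types

# Hypothetical stand-in for a package's top-level module.
toy_top_level = types.ModuleType("toy_top_level")
# Hypothetical registry that a documentation generator could read later.
PUBLIC_APIS = []


def public_api(document_under=""):
    def decorator(func):
        # Record where the function should appear in the documentation...
        PUBLIC_APIS.append({"name": func.__name__, "document_under": document_under})
        # ...and make it reachable from the top-level module.
        setattr(toy_top_level, func.__name__, func)
        return func

    return decorator


@public_api(document_under="tensor_operations")
def theta(shape, dim=0):
    # Placeholder body; a real API would build and return a tensor.
    return (tuple(shape), dim)


print(toy_top_level.theta([3]))  # ((3,), 0)
```

Returning `func` unchanged keeps the decorated function usable from its defining module as well as from the top level.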

Links:

Exposing The Operator

Similarly to the FlatIR operator, we need to import Theta into the frontend.trace.ops submodule. We can do so by adding the following line into nvtripy/frontend/trace/ops/__init__.py:

from nvtripy.frontend.trace.ops.theta import Theta, theta

Testing

Now that we’ve implemented our operator, let’s write tests for it. The structure of the tests/ directory mirrors that of the nvtripy/ directory (you can read more about that here). We need to test both the FlatIR and Trace operators.

Testing The Trace Operator And Public API

Since we implemented our Trace operator and public API in nvtripy/frontend/trace/ops, we’ll add the test under tests/frontend/trace/ops. Create a new file there called test_theta.py:

import nvtripy as tp
from tests import helper
from nvtripy.frontend.trace.ops import Theta


class TestTheta:
    # This ensures that the public API function creates a frontend `Tensor`
    # and populates it with the right `Trace` operator.
    def test_op_func(self):
        a = tp.theta([2, 3])
        assert isinstance(a, tp.Tensor)
        assert isinstance(a.trace_tensor.producer, Theta)

    # You should also include negative tests for anything that is expected to
    # fail. In our case, we just have `test_invalid_dim`,
    # which ensures that we emit an error if the `dim` parameter is outside
    # the allowed range.
    def test_invalid_dim(self):
        with helper.raises(
            tp.TripyException,
            match="iota dimension cannot go beyond the output rank",
        ):
            tp.theta([2, 3], dim=3).eval()
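
The negative-test pattern can be sketched without Tripy at all. The `raises` context manager and `toy_theta` function below are hypothetical stand-ins written for this illustration; Tripy's `helper.raises` may check more than the exception type and message.

```python
import re
from contextlib import contextmanager


@contextmanager
def raises(exc_type, match):
    # Hypothetical helper: passes only if the block raises `exc_type`
    # with a message that matches the `match` regular expression.
    try:
        yield
    except exc_type as exc:
        assert re.search(match, str(exc)), f"{str(exc)!r} does not match {match!r}"
    else:
        raise AssertionError(f"{exc_type.__name__} was not raised")


def toy_theta(shape, dim=0):
    # Hypothetical validation mirroring the constraint on `dim`.
    if not 0 <= dim < len(shape):
        raise ValueError("iota dimension cannot go beyond the output rank")
    return (tuple(shape), dim)


# Passes silently: `dim=3` is out of range for a rank-2 shape.
with raises(ValueError, match="iota dimension cannot go beyond"):
    toy_theta([2, 3], dim=3)
```

Matching on part of the message, and not only on the exception type, guards against the test passing because of an unrelated failure.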

Integration Tests

The code examples in the docstring of the public API serve as good sanity integration tests. However, you should still add separate integration tests to get better coverage.

Our docstring covers the 1D case, so let’s add an integration test to cover the multidimensional case. Create a new file called test_theta.py under tests/integration:

import numpy as np
import cupy as cp

import nvtripy as tp


def test_multi_dimensional():
    output = tp.theta([2, 3], dim=1)
    expected = tp.Tensor([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]], dtype=tp.float32)

    assert tp.equal(output, expected)

Done!

If you’ve reached this point, you have successfully added a new operation to Tripy. Congratulations!