Adding New Operators

You may find it helpful to read the architecture documentation before you start reading this guide.

Adding new operators to Tripy typically involves making changes in the frontend as well as in the FlatIR. In some cases, the frontend operator can be expressed in terms of existing FlatIR operators, in which case you only need to make changes in the frontend.

Let’s take a look at an example of how you might add an Iota operator to Tripy. So that it doesn’t clash with Tripy’s actual Iota implementation, we’ll call it Theta instead.
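To make the target behavior concrete before touching Tripy, here is a minimal pure-Python reference for what an iota-style operator computes: each element equals its own index along the chosen dimension. The name `iota_reference` and the nested-list representation are assumptions made for illustration only; nothing here is part of Tripy.

```python
from typing import List, Sequence


def iota_reference(shape: Sequence[int], dim: int = 0) -> List:
    """Reference semantics: the element at index (i0, i1, ...) equals its index along `dim`."""
    if not 0 <= dim < len(shape):
        raise ValueError(f"dim {dim} is out of range for rank {len(shape)}")

    def build(prefix: tuple, remaining: list):
        # Recurse one axis at a time to produce nested lists.
        if not remaining:
            return float(prefix[dim])
        return [build(prefix + (i,), remaining[1:]) for i in range(remaining[0])]

    return build((), list(shape))


print(iota_reference([3]))            # [0.0, 1.0, 2.0]
print(iota_reference([2, 3], dim=1))  # [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]]
print(iota_reference([2, 3], dim=0))  # [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
```

A reference like this is handy later as a source of expected values when writing tests.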

Implementation

FlatIR Operator

The FlatIR operator is usually the most challenging aspect of implementing operators in Tripy. The good news is that you might not even need to do this if the low-level operators you need already exist in the FlatIR. And if you do need a new one, the remaining steps only get easier after this!

We’ll start by adding a new file under tripy/flat_ir/ops called theta.py; see the inline comments for explanations of what’s happening:

from dataclasses import dataclass

from mlir_tensorrt.compiler import ir
from mlir_tensorrt.compiler.dialects import stablehlo

from tripy.flat_ir.ops.base import BaseFlatIROp


# Every `FlatIR` operator is implemented as a `dataclass` so that the base
# class can automatically implement several methods by inspecting the child
# class fields at runtime. The `repr=False` is important because the default
# `__repr__` method generated by `dataclass` will be extremely verbose and
# makes interactive debugging more difficult.
@dataclass(repr=False)
class ThetaOp(BaseFlatIROp):
    dim: int

    # `to_mlir()` is the trickiest bit. As the name implies, the method is meant to lower the
    # `FlatIR` operator into MLIR. To figure out which MLIR operators to use, refer to
    # the 'MLIR Python API Guide' (linked below).
    def to_mlir(self, operands):
        out_type = self.outputs[0].to_mlir()
        theta_dim = ir.IntegerAttr.get(type=ir.IntegerType.get_signless(64), value=self.dim)
        output = stablehlo.DynamicIotaOp(result=out_type, output_shape=operands[0], iota_dimension=theta_dim)
        return [output]
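The comment about `dataclass` and `repr=False` can be illustrated with a standalone sketch. `SketchBaseOp` and `SketchThetaOp` are hypothetical stand-ins, not Tripy's `BaseFlatIROp` or `ThetaOp`; the sketch only shows the general mechanism of a base class inspecting a child dataclass's fields at runtime.

```python
from dataclasses import dataclass, fields


class SketchBaseOp:
    """Hypothetical base class; Tripy's real base class is assumed to do more."""

    def __repr__(self) -> str:
        # Inspect the child dataclass's fields at runtime to build a compact repr.
        args = ", ".join(f"{f.name}={getattr(self, f.name)!r}" for f in fields(self))
        return f"{type(self).__name__}[{args}]"


# `repr=False` stops `dataclass` from generating its own `__repr__`,
# so the one inherited from the base class is used instead.
@dataclass(repr=False)
class SketchThetaOp(SketchBaseOp):
    dim: int


print(SketchThetaOp(dim=1))  # SketchThetaOp[dim=1]
```

Without `repr=False`, the generated `__repr__` would override the inherited one.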

Links:

Exposing The Operator

One of the principles we follow when writing submodules is that other submodules should not need to reach into the internals of a submodule to retrieve something they need.

For example, a class which needs to import ThetaOp does not need to know where exactly within the flat_ir.ops module the ThetaOp lives - it should be able to just import it from the submodule.

To make this possible, we need to import the ThetaOp into the flat_ir.ops submodule. We can do so by adding the following line into tripy/flat_ir/ops/__init__.py:

from tripy.flat_ir.ops.theta import ThetaOp
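The effect of such a re-export can be tried out in isolation. The sketch below builds a throwaway package on disk; the package name `sketch_ops` is hypothetical and the layout only mimics the idea, not Tripy's actual directory structure.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Build a throwaway package on disk; `sketch_ops` is a hypothetical name.
root = Path(tempfile.mkdtemp())
pkg = root / "sketch_ops"
pkg.mkdir()
(pkg / "theta.py").write_text("class ThetaOp:\n    pass\n")
# The re-export: callers never need to know about the `theta` module.
(pkg / "__init__.py").write_text("from sketch_ops.theta import ThetaOp\n")

sys.path.insert(0, str(root))
importlib.invalidate_caches()

# Imported from the package itself, not from `sketch_ops.theta`.
from sketch_ops import ThetaOp

print(ThetaOp.__module__)  # sketch_ops.theta
```

If `theta.py` were later renamed or split, only `__init__.py` would need to change; importers of `sketch_ops.ThetaOp` would be unaffected.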

Trace Operator And The Public API

Now that we have a FlatIR operator, we can implement a Trace operator that will use it along with a public API function. Let’s create a new file under tripy/frontend/trace/ops called theta.py.

Trace Operator

First, we’ll implement the Trace operator itself:

from dataclasses import dataclass
from typing import Tuple

from tripy import utils
from tripy.common import datatype, device
from tripy.common.exception import raise_error
from tripy.frontend.trace.ops.base import BaseTraceOp
import tripy.frontend.trace.ops.utils as op_utils


# Just like with `FlatIR` operators, all `Trace` operators are implemented as `dataclass`es.
# As before, we want `repr=False` here.
@dataclass(repr=False)
class Theta(BaseTraceOp):
    # Notice that we do *not* need to define a constructor and can rely on the default
    # implementation provided by `dataclass`.
    dim: int
    dtype: datatype.dtype

    # `infer_rank()` populates the rank of the output `TraceTensor`s.
    # Here we use one of the predefined policies to set the output rank to the same as the shape (i.e. the length)
    # of the shape operand.
    infer_rank = op_utils.InferRankPolicies.same_as_shape_of_shape_input()

    # *Optional* `infer_dtypes()` populates the data types of the
    # output `TraceTensor`s. The default implementation copies the input
    # data types if they are all the same, so you may not need to implement this.
    def infer_dtypes(self):
        self.outputs[0].dtype = self.dtype

    # *Optional* `infer_devices()` populates the devices of the
    # output `TraceTensor`s. The default implementation copies the input
    # devices if they are all the same, so you may not need to implement this either.
    def infer_devices(self):
        self.outputs[0].device = device("gpu")

    # `to_flat_ir()` translates the `Trace` operator to a subgraph of
    # one or more `FlatIR` operators. In our case, it's just a 1:1
    # mapping to the `ThetaOp` we created earlier.
    def to_flat_ir(self, inputs, outputs):
        # Note that we import the `FlatIR` operator within the function
        # call - this is to avoid circular dependencies.
        from tripy.flat_ir.ops import ThetaOp
        import tripy.frontend.trace.ops.utils as op_utils

        # This code may look a bit confusing; for more details, look at the
        # 'FlatIR section in the architecture document' (linked below).
        ThetaOp.build(inputs, outputs, dim=self.dim)
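The two inference behaviors mentioned in the comments (copying input data types when they agree, and deriving the output rank from the length of the shape operand) can be sketched in isolation. Everything below is hypothetical: `SketchTensor`, `default_infer_dtypes`, and `infer_rank_from_shape_input` are illustrative names, and the `"int32"` dtype of the shape operand is an assumption, not a statement about Tripy's internals.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SketchTensor:
    """Hypothetical stand-in for a trace tensor; not Tripy's `TraceTensor`."""

    dtype: Optional[str] = None
    rank: Optional[int] = None
    num_elements: Optional[int] = None  # only meaningful for a 1D shape operand


def default_infer_dtypes(inputs: List[SketchTensor], outputs: List[SketchTensor]) -> None:
    # Copy the input dtype to the outputs, but only if all inputs agree.
    dtypes = {inp.dtype for inp in inputs}
    if len(dtypes) != 1:
        raise TypeError(f"cannot infer output dtype from mixed inputs: {sorted(map(str, dtypes))}")
    for out in outputs:
        out.dtype = next(iter(dtypes))


def infer_rank_from_shape_input(inputs: List[SketchTensor], outputs: List[SketchTensor]) -> None:
    # A shape operand with N elements describes a rank-N output.
    outputs[0].rank = inputs[0].num_elements


shape_operand = SketchTensor(dtype="int32", rank=1, num_elements=2)  # e.g. the shape [2, 3]
output = SketchTensor()

infer_rank_from_shape_input([shape_operand], [output])
default_infer_dtypes([shape_operand], [output])
print(output.rank, output.dtype)  # 2 int32
```

In this sketch the default policy would give the output the shape operand's integer dtype, which suggests why an operator like `Theta` overrides `infer_dtypes()` to use the requested `dtype` instead.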

Links:

Public API

Next, we can define the public interface. Since our public interface maps 1:1 with the Trace operator we just implemented and does not require weights, we’ll add it in the same file.

If our API required a composition of multiple Trace operators, then we would instead implement it under frontend/ops/.

If it required weights (i.e. inputs that are expected to always be constant), then we would implement it as a tripy.Module under frontend/module.

from tripy import export
import tripy.frontend.utils as frontend_utils
from tripy.types import ShapeLike

# We can use the `export.public_api()` decorator to automatically export this function into the
# top-level module. This means it will be accessible as `tripy.theta`.
#
# This decorator also controls how the API is exposed in the documentation - the `document_under`
# option determines where in the documentation hierarchy this API will show up.
#
# If we needed to provide any special autodoc options, we could use the `autodoc_options` parameter.
@export.public_api(document_under="tensor_operations")

# The `convert_to_tensors` decorator automatically converts compatible arguments,
# like `TensorLike` or `ShapeLike`s, into tensors.
@frontend_utils.convert_to_tensors()
def theta(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float32) -> "tripy.Tensor":
    # For any public facing interfaces, we have documentation requirements which you can read
    # about in the 'Docs README' (linked below). The docstring we've implemented here
    # adheres to all of these requirements. Non-compliant docstrings will, in most cases,
    # cause test failures; however, you should still manually ensure you're writing high-quality
    # docstrings.
    #
    # The examples in docstrings are run as part of our tests, so you should also add
    # assertions to make sure things are functionally correct. In this case, we check
    # that the `output` we create in the code example is what we expect.
    """
    Fills an output tensor with consecutive values starting from zero along the given dimension.

    Args:
        shape: The desired shape.
        dim: Dimension along which to perform the theta operation.
            This cannot exceed the rank of the specified shape.
        dtype: The desired data type.

    Returns:
        A tensor of shape ``shape`` and data type ``dtype``.

    .. code-block:: python
        :linenos:
        :caption: Example

        output = tp.theta([3])

        assert np.array_equal(cp.from_dlpack(output).get(), np.arange(0, 3, dtype=np.float32))
    """

    # Next we build the trace operator. The `build()` function is also responsible for constructing
    # the output frontend Tensors. All of the arguments that follow the inputs
    # are forwarded directly to the constructor of the `Trace` operator.
    return Theta.build([shape], dim, dtype)
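For readers unfamiliar with registration decorators, the general pattern behind something like `export.public_api()` can be sketched as follows. This is not Tripy's implementation: `PUBLIC_API`, `DOC_SECTIONS`, and this `public_api` function are hypothetical and only show how a decorator can record a function for export while returning it unchanged.

```python
from typing import Callable, Dict

# Hypothetical registries standing in for a top-level module namespace
# and a documentation hierarchy.
PUBLIC_API: Dict[str, Callable] = {}
DOC_SECTIONS: Dict[str, str] = {}


def public_api(document_under: str = "") -> Callable[[Callable], Callable]:
    def decorator(func: Callable) -> Callable:
        # Record the function under its own name and remember where to document it.
        PUBLIC_API[func.__name__] = func
        DOC_SECTIONS[func.__name__] = document_under
        return func  # the function itself is returned unchanged

    return decorator


@public_api(document_under="tensor_operations")
def theta(shape, dim=0):
    # Placeholder body; a real implementation would build a trace operator.
    return ("theta", tuple(shape), dim)


print(sorted(PUBLIC_API))     # ['theta']
print(DOC_SECTIONS["theta"])  # tensor_operations
```

Because the decorator returns the original function, stacking further decorators beneath it behaves as usual.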

Links:

Exposing The Operator

Similarly to the FlatIR operator, we need to import Theta into the frontend.trace.ops submodule. We can do so by adding the following line into tripy/frontend/trace/ops/__init__.py:

from tripy.frontend.trace.ops.theta import Theta, theta

Testing

Now that we’ve implemented our operator, let’s write tests for it. The structure of the tests/ directory mirrors that of the tripy/ directory (you can read more about that here). We need to test both the FlatIR and Trace operators.

Testing The Trace Operator And Public API

Since we implemented our Trace operator and public API in tripy/frontend/trace/ops, we’ll add the test under tests/frontend/trace/ops. Create a new file there called test_theta.py:

import tripy as tp
from tests import helper
from tripy.frontend.trace.ops import Theta


class TestTheta:
    # This ensures that the public API function creates a frontend `Tensor`
    # and populates it with the right `Trace` operator.
    def test_op_func(self):
        a = tp.theta([2, 3])
        assert isinstance(a, tp.Tensor)
        assert isinstance(a.trace_tensor.producer, Theta)

    # You should also include negative tests for anything that is expected to
    # fail. In our case, we just have `test_invalid_dim`,
    # which ensures that we emit an error if the `dim` parameter is outside
    # the allowed range.
    def test_invalid_dim(self):
        with helper.raises(tp.TripyException, match="iota dimension cannot go beyond the output rank"):
            tp.theta([2, 3], dim=3).eval()
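The negative-test pattern above can be tried without Tripy installed. The sketch below assumes that a `raises` helper checks both the exception type and a message pattern; the `raises` context manager and `checked_dim` function here are hypothetical stand-ins written for this example, not the contents of `tests/helper`.

```python
import re
from contextlib import contextmanager


@contextmanager
def raises(exc_type, match):
    """Hypothetical helper: the body must raise `exc_type` with a message matching `match`."""
    try:
        yield
    except exc_type as exc:
        assert re.search(match, str(exc)), f"message {str(exc)!r} does not match {match!r}"
    else:
        raise AssertionError(f"{exc_type.__name__} was not raised")


def checked_dim(rank: int, dim: int) -> int:
    # Stand-in for the validation the real operator is expected to perform.
    if not 0 <= dim < rank:
        raise ValueError("iota dimension cannot go beyond the output rank")
    return dim


# Passes: a rank-2 output has no dimension 3, so the error is raised and matched.
with raises(ValueError, match="cannot go beyond the output rank"):
    checked_dim(rank=2, dim=3)
```

Matching on the message as well as the type guards against the test passing because of an unrelated error.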

Integration Tests

The code examples in the docstring of the public API serve as good sanity integration tests. However, you should still add separate integration tests to get better coverage.

Our docstring covers the 1D case, so let’s add an integration test to cover the multidimensional case. Create a new file called test_theta.py under tests/integration:

import numpy as np
import cupy as cp

import tripy as tp


def test_multi_dimensional():
    output = tp.theta([2, 3], dim=1)
    expected = np.broadcast_to(np.arange(0, 3, dtype=np.float32), (2, 3))

    assert np.array_equal(cp.from_dlpack(output).get(), expected)
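When extending coverage to other dimensions, it helps to work out the expected arrays with NumPy alone first. The sketch below only computes expected values for a `(2, 3)` shape with `dim=1` and `dim=0`; it does not call Tripy, and the variable names are illustrative.

```python
import numpy as np

# dim=1: values increase along the last axis, so every row is [0, 1, 2].
expected_dim1 = np.broadcast_to(np.arange(0, 3, dtype=np.float32), (2, 3))

# dim=0: values increase along the first axis, so row i is filled with i.
expected_dim0 = np.broadcast_to(np.arange(0, 2, dtype=np.float32)[:, None], (2, 3))

print(expected_dim1.tolist())  # [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]]
print(expected_dim0.tolist())  # [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
```

The `[:, None]` index turns the 1D range into a column so that broadcasting repeats it across the second axis rather than the first.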

Done!

If you’ve reached this point, you have successfully added a new operation to Tripy. Congratulations!