Adding New Operators
You may find it helpful to read the architecture documentation before you start reading this guide.
Adding new operators to Tripy typically involves making changes in the frontend as well as in the FlatIR. In some cases, the frontend operator can be expressed in terms of existing FlatIR operators, in which case you only need to make changes in the frontend.
Let’s take a look at an example of how you might add an Iota operator to Tripy. So that it doesn’t clash with Tripy’s actual Iota implementation, we’ll call it Theta instead.
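Before diving into the implementation, it may help to pin down what the operator is expected to compute. The following is a minimal pure-Python sketch of the semantics (each element equals its own index along the chosen dimension). The name theta_reference is hypothetical and is not part of Tripy; it also returns nested lists of integers rather than a tensor of the requested data type.

```python
from typing import List, Sequence, Union

Nested = Union[int, List["Nested"]]


def theta_reference(shape: Sequence[int], dim: int = 0) -> Nested:
    """Builds nested lists in which every element equals its index along `dim`."""
    if not 0 <= dim < len(shape):
        raise ValueError(f"dim must be in [0, {len(shape)}), got {dim}")

    def build(axis: int, index_along_dim: int) -> Nested:
        # Once every axis has been visited, emit the index recorded along `dim`.
        if axis == len(shape):
            return index_along_dim
        return [
            build(axis + 1, i if axis == dim else index_along_dim)
            for i in range(shape[axis])
        ]

    return build(0, 0)


print(theta_reference([3]))            # [0, 1, 2]
print(theta_reference([2, 3], dim=0))  # [[0, 0, 0], [1, 1, 1]]
print(theta_reference([2, 3], dim=1))  # [[0, 1, 2], [0, 1, 2]]
```

A reference like this can be handy when writing expected values for tests.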
Implementation
FlatIR Operator
The FlatIR operator is usually the most challenging aspect of implementing operators in Tripy. The good news is that you might not even need to do this if the low-level operators you need already exist in the FlatIR. And if you do, then it’ll only get easier after this!
We’ll start by adding a new file under tripy/flat_ir/ops called theta.py; see the inline comments for explanations of what’s happening:
    from dataclasses import dataclass

    from mlir_tensorrt.compiler import ir
    from mlir_tensorrt.compiler.dialects import stablehlo

    from tripy.flat_ir.ops.base import BaseFlatIROp


    # Every `FlatIR` operator is implemented as a `dataclass` so that the base
    # class can automatically implement several methods by inspecting the child
    # class fields at runtime. The `repr=False` is important because the default
    # `__repr__` method generated by `dataclass` will be extremely verbose and
    # makes interactive debugging more difficult.
    @dataclass(repr=False)
    class ThetaOp(BaseFlatIROp):
        dim: int

        # `to_mlir()` is the trickiest bit. As the name implies, the method is meant to lower the
        # `FlatIR` operator into MLIR. To figure out which MLIR operators to use, refer to
        # the 'MLIR Python API Guide' (linked below).
        def to_mlir(self, operands):
            out_type = self.outputs[0].to_mlir()
            theta_dim = ir.IntegerAttr.get(type=ir.IntegerType.get_signless(64), value=self.dim)
            output = stablehlo.DynamicIotaOp(result=out_type, output_shape=operands[0], iota_dimension=theta_dim)
            return [output]
Exposing The Operator
One of the principles we follow when writing submodules is that other submodules should not need to reach into the internals of a submodule to retrieve something they need. For example, a class which needs to import ThetaOp does not need to know where exactly within the flat_ir.ops module the ThetaOp lives - it should be able to just import it from the submodule.
To make this possible, we need to import the ThetaOp into the flat_ir.ops submodule. We can do so by adding the following line into tripy/flat_ir/ops/__init__.py:
    from tripy.flat_ir.ops.theta import ThetaOp
Trace Operator And The Public API
Now that we have a FlatIR operator, we can implement a Trace operator that will use it along with a public API function. Let’s create a new file under tripy/frontend/trace/ops called theta.py.
Trace Operator
First, we’ll implement the Trace operator itself:
    from dataclasses import dataclass
    from typing import Tuple

    from tripy import utils
    from tripy.common import datatype, device
    from tripy.common.exception import raise_error
    from tripy.frontend.trace.ops.base import BaseTraceOp
    import tripy.frontend.trace.ops.utils as op_utils


    # Just like with `FlatIR` operators, all `Trace` operators are implemented as `dataclass`es.
    # As before, we want `repr=False` here.
    @dataclass(repr=False)
    class Theta(BaseTraceOp):
        # Notice that we do *not* need to define a constructor and can rely on the default
        # implementation provided by `dataclass`.
        dim: int
        dtype: datatype.dtype

        # `infer_rank()` populates the rank of the output `TraceTensor`s.
        # Here we use one of the predefined policies to set the output rank to the same as the shape (i.e. the length)
        # of the shape operand.
        infer_rank = op_utils.InferRankPolicies.same_as_shape_of_shape_input()

        # *Optional* `infer_dtypes()` populates the data types of the
        # output `TraceTensor`s. The default implementation copies the input
        # data types if they are all the same, so you may not need to implement this.
        def infer_dtypes(self):
            self.outputs[0].dtype = self.dtype

        # *Optional* `infer_devices()` populates the devices of the
        # output `TraceTensor`s. The default implementation copies the input
        # devices if they are all the same, so you may not need to implement this either.
        def infer_devices(self):
            self.outputs[0].device = device("gpu")

        # `to_flat_ir()` translates the `Trace` operator to a subgraph of
        # one or more `FlatIR` operators. In our case, it's just a 1:1
        # mapping to the `ThetaOp` we created earlier.
        def to_flat_ir(self, inputs, outputs):
            # Note that we import the `FlatIR` operator within the function
            # call - this is to avoid circular dependencies.
            from tripy.flat_ir.ops import ThetaOp
            import tripy.frontend.trace.ops.utils as op_utils

            # This code may look a bit confusing; for more details, look at the
            # 'FlatIR section in the architecture document' (linked below).
            ThetaOp.build(inputs, outputs, dim=self.dim)
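The comments above mention that the default infer_dtypes() copies the input data types when they all agree, and that Theta overrides it because its output type comes from a parameter instead. A rough standalone sketch of that relationship follows. FakeTensor, FakeBaseTraceOp, FakeAdd and FakeTheta are hypothetical, and the exact behavior of the real default (including what happens when the inputs disagree) is an assumption.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeTensor:
    # Hypothetical stand-in for a `TraceTensor`; it only tracks a dtype name.
    dtype: Optional[str] = None


@dataclass(repr=False)
class FakeBaseTraceOp:
    inputs: List[FakeTensor]
    outputs: List[FakeTensor]

    # Assumed default: copy the input dtype to every output, but only when
    # all inputs agree on a single dtype.
    def infer_dtypes(self) -> None:
        dtypes = {inp.dtype for inp in self.inputs}
        if len(dtypes) != 1:
            raise ValueError("inputs do not share a single dtype")
        (dtype,) = dtypes
        for out in self.outputs:
            out.dtype = dtype


@dataclass(repr=False)
class FakeAdd(FakeBaseTraceOp):
    # Relies entirely on the inherited default.
    pass


@dataclass(repr=False)
class FakeTheta(FakeBaseTraceOp):
    dtype: str = "float32"

    # Overrides the default because the output dtype is a parameter of the
    # operator rather than something derived from the (shape) input.
    def infer_dtypes(self) -> None:
        self.outputs[0].dtype = self.dtype


add = FakeAdd(inputs=[FakeTensor("float16"), FakeTensor("float16")], outputs=[FakeTensor()])
add.infer_dtypes()
theta = FakeTheta(inputs=[FakeTensor("int32")], outputs=[FakeTensor()], dtype="float32")
theta.infer_dtypes()

print(add.outputs[0].dtype)    # float16
print(theta.outputs[0].dtype)  # float32
```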
Public API
Next, we can define the public interface. Since our public interface maps 1:1 with the Trace operator we just implemented and does not require weights, we’ll add it in the same file.
If our API required a composition of multiple Trace operators, then we would instead implement it under frontend/ops/. If it required weights (i.e. inputs that are expected to always be constant), then we would implement it as a tripy.Module under frontend/module.
    from tripy import export
    import tripy.frontend.utils as frontend_utils
    from tripy.types import ShapeLike

    # We can use the `export.public_api()` decorator to automatically export this function into the
    # top-level module. This means it will be accessible as `tripy.theta`.
    #
    # This decorator also controls how the API is exposed in the documentation - the `document_under`
    # option determines where in the documentation hierarchy this API will show up.
    #
    # If we needed to provide any special autodoc options, we could use the `autodoc_options` parameter.
    @export.public_api(document_under="tensor_operations")
    # The `convert_to_tensors` decorator automatically converts compatible arguments,
    # like `TensorLike` or `ShapeLike`s, into tensors.
    @frontend_utils.convert_to_tensors()
    def theta(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float32) -> "tripy.Tensor":
        # For any public facing interfaces, we have documentation requirements which you can read
        # about in the 'Docs README' (linked below). The docstring we've implemented here
        # adheres to all of these requirements. Non-compliant docstrings will, in most cases,
        # cause test failures; however, you should still manually ensure you're writing high-quality
        # docstrings.
        #
        # The examples in docstrings are run as part of our tests, so you should also add
        # assertions to make sure things are functionally correct. In this case, we check
        # that the `output` we create in the code example is what we expect.
        """
        Fills an output tensor with consecutive values starting from zero along the given dimension.

        Args:
            shape: The desired shape.
            dim: Dimension along which to perform the theta operation.
                This cannot exceed the rank of the specified shape.
            dtype: The desired data type.

        Returns:
            A tensor of shape ``shape`` and data type ``dtype``.

        .. code-block:: python
            :linenos:
            :caption: Example

            output = tp.theta([3])

            assert np.array_equal(cp.from_dlpack(output).get(), np.arange(0, 3, dtype=np.float32))
        """

        # Next we build the trace operator. The `build()` function is also responsible for constructing
        # the output frontend Tensors. All of the arguments that follow the inputs
        # are forwarded directly to the constructor of the `Trace` operator.
        return Theta.build([shape], dim, dtype)
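The final comment describes build() as creating the output tensors and forwarding the remaining arguments to the operator's constructor. One way such a classmethod could be structured is sketched below. The classes are hypothetical stand-ins and the single-output behavior is an assumption, so treat this as an illustration of the forwarding idea rather than Tripy's implementation.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class FakeTensor:
    # Hypothetical stand-in for a frontend tensor that remembers its producer.
    producer: Any = None


@dataclass(repr=False)
class FakeBaseTraceOp:
    inputs: List[Any]
    outputs: List[FakeTensor]

    @classmethod
    def build(cls, inputs: List[Any], *args: Any, **kwargs: Any) -> FakeTensor:
        # Create the output tensor, then forward everything after `inputs`
        # to the dataclass-generated constructor of the concrete operator.
        out = FakeTensor()
        out.producer = cls(inputs, [out], *args, **kwargs)
        return out


@dataclass(repr=False)
class FakeTheta(FakeBaseTraceOp):
    dim: int
    dtype: str


result = FakeTheta.build([[2, 3]], 1, "float32")
print(type(result.producer).__name__, result.producer.dim, result.producer.dtype)
# FakeTheta 1 float32
```

Because the constructor is generated from the dataclass fields, the positional arguments after the inputs line up with the fields in declaration order (here dim, then dtype).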
Exposing The Operator
Similarly to the FlatIR operator, we need to import Theta into the frontend.trace.ops submodule. We can do so by adding the following line into tripy/frontend/trace/ops/__init__.py:
    from tripy.frontend.trace.ops.theta import Theta, theta
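Decorator-based export schemes generally register a function only when the module that defines it is actually imported, which is one reason the import above matters. The sketch below shows the general pattern with a hypothetical registry. PUBLIC_APIS and this public_api function are made up for illustration and are assumptions about the mechanism, not Tripy's export module.

```python
from typing import Any, Callable, Dict

# Hypothetical registry standing in for the top-level module namespace.
PUBLIC_APIS: Dict[str, Dict[str, Any]] = {}


def public_api(document_under: str = "") -> Callable[[Callable], Callable]:
    def decorator(func: Callable) -> Callable:
        # Registration is a side effect of executing the decorated definition,
        # i.e. it happens when the defining module is imported.
        PUBLIC_APIS[func.__name__] = {"func": func, "document_under": document_under}
        return func

    return decorator


@public_api(document_under="tensor_operations")
def theta(shape, dim=0):
    # Placeholder body; a real implementation would build the operator.
    return ("theta", tuple(shape), dim)


print(sorted(PUBLIC_APIS))                      # ['theta']
print(PUBLIC_APIS["theta"]["document_under"])   # tensor_operations
```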
Testing
Now that we’ve implemented our operator, let’s write tests for it. The structure of the tests/ directory mirrors that of the tripy/ directory (you can read more about that here). We need to test both the FlatIR and Trace operators.
Testing The Trace Operator And Public API
Since we implemented our Trace operator and public API in tripy/frontend/trace/ops, we’ll add the test under tests/frontend/trace/ops. Create a new file there called test_theta.py:
    import tripy as tp
    from tests import helper
    from tripy.frontend.trace.ops import Theta


    class TestTheta:
        # This ensures that the public API function creates a frontend `Tensor`
        # and populates it with the right `Trace` operator.
        def test_op_func(self):
            a = tp.theta([2, 3])
            assert isinstance(a, tp.Tensor)
            assert isinstance(a.trace_tensor.producer, Theta)

        # You should also include negative tests for anything that is expected to
        # fail. In our case, we just have `test_invalid_dim`,
        # which ensures that we emit an error if the `dim` parameter is outside
        # the allowed range.
        def test_invalid_dim(self):
            with helper.raises(tp.TripyException, match="iota dimension cannot go beyond the output rank"):
                tp.theta([2, 3], dim=3).eval()
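The negative test follows a common pattern: trigger the failure inside a context manager that checks both the exception type and a fragment of its message. For readers who want to try the pattern without a Tripy installation, here is a standalone sketch using only the standard library. check_dim is a hypothetical helper written for this illustration, and unittest's assertRaisesRegex is swapped in for the helper.raises utility used above.

```python
import unittest


def check_dim(dim: int, rank: int) -> int:
    # Hypothetical validation helper, not part of Tripy.
    if not 0 <= dim < rank:
        raise ValueError(
            f"iota dimension cannot go beyond the output rank: dim={dim}, rank={rank}"
        )
    return dim


class TestCheckDim(unittest.TestCase):
    def test_valid_dim(self):
        self.assertEqual(check_dim(1, rank=2), 1)

    # Negative test: matching on the message guards against the right
    # exception type being raised for the wrong reason.
    def test_invalid_dim(self):
        with self.assertRaisesRegex(ValueError, "cannot go beyond the output rank"):
            check_dim(3, rank=2)


suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestCheckDim)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print(result.wasSuccessful())  # True
```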
Integration Tests
The code examples in the docstring of the public API serve as good sanity integration tests. However, you should still add separate integration tests to get better coverage.
Our docstring covers the 1D case, so let’s add an integration test to cover the multidimensional case. Create a new file called test_theta.py under tests/integration:
    import numpy as np
    import cupy as cp

    import tripy as tp


    def test_multi_dimensional():
        output = tp.theta([2, 3], dim=1)
        expected = np.broadcast_to(np.arange(0, 3, dtype=np.float32), (2, 3))

        assert np.array_equal(cp.from_dlpack(output).get(), expected)
Done!
If you’ve reached this point, you have successfully added a new operation to Tripy. Congratulations!