Architecture

This document explains the overall architecture of Tripy.

Overview

The main technical requirement of Tripy is twofold:

  1. We must be able to provide a Pythonic, functional-style interface to the user.

  2. We must be able to provide a computation graph to the compiler.

Tripy uses a series of intermediate representations to solve this problem. Below is a diagram of how these IRs are connected to each other:

graph TD
    subgraph "Python"
        subgraph "Frontend"
            A[Tripy Python API] -->|Stage Out| B[Trace]
        end
        subgraph "Flat IR"
            B --> C[FlatIR]
        end
        subgraph "Backend"
            C --> D[MLIR StableHLO]
        end
    end
    subgraph "C++"
        D --> E[MLIR-TRT];
    end

Trace

The Trace is meant to provide a 1:1 graph representation of the user’s Python code. It is a bipartite graph of Trace operations and TraceTensors. Each Trace operation is connected to its input and output TraceTensors, and each TraceTensor knows which Trace operation produces it.
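The sketch below makes the bipartite structure concrete. TraceTensor and TraceOp are hypothetical, simplified stand-ins, not Tripy's actual classes or fields. Each operation holds its input and output tensors. Constructing an operation back-links each of its output tensors to it, so the graph can be navigated in both directions.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(eq=False)
class TraceTensor:
    """Hypothetical stand-in: a tensor that knows which op produced it."""
    name: str
    producer: Optional["TraceOp"] = None


@dataclass(eq=False)
class TraceOp:
    """Hypothetical stand-in: an op connected to its input and output tensors."""
    kind: str
    inputs: List[TraceTensor] = field(default_factory=list)
    outputs: List[TraceTensor] = field(default_factory=list)

    def __post_init__(self):
        # Back-link every output tensor to this op (tensor -> producer edge).
        for out in self.outputs:
            out.producer = self


t0 = TraceTensor("trace_tensor0")
fill = TraceOp("Fill", outputs=[t0])
t1 = TraceTensor("trace_tensor1")
tanh = TraceOp("UnaryElementwise", inputs=[t0], outputs=[t1])

# Walk from the final tensor back to the first op through producer/input links.
print(t1.producer.kind)                     # UnaryElementwise
print(t1.producer.inputs[0].producer.kind)  # Fill
```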

FlatIR

The FlatIR is a lower-level representation that provides a 1:1 mapping to MLIR operations. Like the Trace, it is a bipartite graph consisting of FlatIR operations and FlatIRTensors.

MLIR

The final representation before we hand off to the compiler is the MLIR itself, which, in our case, consists primarily of StableHLO operations but can also include other dialects for certain operations.

A Day In The Life Of A Tripy Tensor

To understand these components better, let’s take a look at what happens when we write a simple program like:

inp = tp.full((2, 3), value=0.5)
out = tp.tanh(inp)
out.eval()

Tracing

We’ll start with the first line:

inp = tp.full((2, 3), value=0.5)

Where Do tp.full() and tp.tanh() Come From?

The tp.full() and tp.tanh() APIs are part of the frontend. Like other frontend functions, each maps to one or more Trace operations (in this case, exactly one). When a frontend function maps to exactly one Trace operation, we define the function directly alongside that Trace operation. Here, the Fill operation provides tp.full() and the UnaryElementwise operation provides tp.tanh().

We organize the code this way to reduce the number of files that need to be touched when adding a new op. If an operation is composed of multiple Trace operations, its frontend function can instead be defined under the frontend/ops submodule.

What Does It Do?

Tripy uses a lazy evaluation model: computation doesn’t happen immediately when you call a function like tp.full() or tp.tanh(). Instead, all we do is create a frontend Tensor object that contains a Trace operation. The Trace operation includes inputs and outputs in the form of TraceTensors, and it is essentially just a symbolic representation of the computation to be done.

As we call other functions that use this frontend Tensor, we connect new Trace operations to its output TraceTensors. You can think of this as iteratively building up an implicit graph.
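The sketch below models this lazy, graph-building behavior. The classes are hypothetical and only named after the concepts above (Tensor, TraceTensor, Fill, UnaryElementwise). Their constructors and attributes are assumptions for illustration, not Tripy's implementation. Each frontend function sits next to the op it provides. Calling full() or tanh() only creates an operation and wires it to its inputs; no values are computed.

```python
import itertools

_ids = itertools.count()  # readable names: trace_tensor0, trace_tensor1, ...


class TraceTensor:
    def __init__(self, producer):
        self.name = f"trace_tensor{next(_ids)}"
        self.producer = producer


class TraceOp:
    def __init__(self, inputs):
        self.inputs = inputs
        # Each op in this sketch has exactly one output, which points back at it.
        self.outputs = [TraceTensor(producer=self)]


class Fill(TraceOp):
    def __init__(self, shape, value):
        super().__init__(inputs=[])
        self.shape, self.value = shape, value


class UnaryElementwise(TraceOp):
    def __init__(self, x, kind):
        super().__init__(inputs=[x])
        self.kind = kind


class Tensor:
    """Lazy frontend tensor: it records the op that produces it and nothing else."""

    def __init__(self, op):
        self.op = op
        self.trace_tensor = op.outputs[0]


# Frontend functions, each kept next to the Trace op that provides it.
def full(shape, value):
    return Tensor(Fill(shape, value))


def tanh(x):
    # Connect a new op to x's output TraceTensor; no math happens here.
    return Tensor(UnaryElementwise(x.trace_tensor, kind="TANH"))


inp = full((2, 3), value=0.5)
out = tanh(inp)
print(inp.trace_tensor.name, "->", type(out.op).__name__, "->", out.trace_tensor.name)
# trace_tensor0 -> UnaryElementwise -> trace_tensor1
```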

The Implicit Frontend Graph

As mentioned before, creating new frontend Tensors builds up an implicit graph composed of Trace operations and TraceTensors.

After running the first two lines of our program, the implicit graph will look something like this:

graph TD
    subgraph "'inp' Tensor"
        A[Fill] --> B(trace_tensor0)
    end
    subgraph "'out' Tensor"
        B --> C[UnaryElementwise]
        C --> D(trace_tensor1)
    end

Evaluation

The bulk of the real work happens once we reach the final line:

out.eval()

As mentioned before, Tripy uses a lazy evaluation model where a tensor is only evaluated when it is used. A tensor is considered “used” when, for example:

  • We interoperate with another framework (e.g. torch.from_dlpack(out) or np.from_dlpack(out)).

  • __repr__ is called (e.g. if we print(out)).

  • We explicitly call eval(), as we’re doing here.
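The hypothetical LazyTensor below illustrates two of the triggers listed above. It defers work until eval() is called, either explicitly or implicitly through __repr__. The class, its attributes, and the use of a plain Python callable in place of a compiled program are all assumptions. Interop through torch.from_dlpack() or np.from_dlpack() would go through the DLPack protocol methods (__dlpack__ and __dlpack_device__), which this sketch does not model.

```python
import math


class LazyTensor:
    """Hypothetical lazily evaluated tensor (not Tripy's Tensor class)."""

    def __init__(self, compute):
        self._compute = compute  # zero-argument callable standing in for the compiled program
        self.data = None
        self.evaluated = False

    def eval(self):
        # Explicit evaluation: actually run the deferred computation.
        self.data = self._compute()
        self.evaluated = True
        return self

    def __repr__(self):
        # Printing needs concrete values, so it forces evaluation.
        return f"tensor({self.eval().data})"


out = LazyTensor(lambda: [[math.tanh(0.5)] * 3 for _ in range(2)])
print(out.evaluated)  # False: building the tensor ran nothing
print(out)            # __repr__ triggers evaluation
print(out.evaluated)  # True
```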

In order to actually evaluate the tensor, a few different things need to happen:

Building The Trace

We start by building up the Trace. Each frontend Tensor contains a Trace operation that is already connected to the Trace operations in other tensors. So we just need to walk backwards from the output tensor, collecting Trace operations as we go.
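Below is a minimal sketch of that backward walk. It assumes each tensor exposes a hypothetical producer attribute and each operation an inputs list. A depth-first, post-order traversal from the output yields every reachable operation exactly once, with producers ahead of consumers. The operation names mirror the example Trace, but the classes are illustrative only.

```python
class TraceTensor:
    def __init__(self, name, producer=None):
        self.name, self.producer = name, producer


class TraceOp:
    def __init__(self, name, inputs):
        self.name, self.inputs = name, inputs
        self.outputs = [TraceTensor(f"{name}_out", producer=self)]


def collect_ops(output):
    """Walk backwards from `output`; return reachable ops, producers before consumers."""
    ordered, visited = [], set()

    def visit(tensor):
        op = tensor.producer
        if op is None or op in visited:
            return  # graph input, or an op we already collected
        visited.add(op)
        for inp in op.inputs:
            visit(inp)  # collect everything this op depends on first
        ordered.append(op)  # post-order, so this is a valid execution order

    visit(output)
    return ordered


# Mirrors the example Trace: two storage ops feed fill, which feeds tanh.
shape = TraceOp("storage_shape", [])
value = TraceOp("storage_value", [])
fill = TraceOp("fill", [shape.outputs[0], value.outputs[0]])
tanh = TraceOp("tanh", [fill.outputs[0]])

print([op.name for op in collect_ops(tanh.outputs[0])])
# ['storage_shape', 'storage_value', 'fill', 'tanh']
```

The recursive form keeps the sketch short. For very deep graphs, an iterative traversal would avoid Python's recursion limit.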

Here’s the textual representation for the Trace from our example:

t117 = storage(data=[2, 3], shape=(2,), dtype=int32, device=gpu:0)
t118 = storage(shape=(), dtype=float32, device=gpu:0)
t119 = fill(t117, t118, dtype=float32)
t121 = unaryelementwise(t119, kind=Kind.TANH)
outputs:
    t121: [shape=([-1, -1]), dtype=(float32), loc=(gpu:0)]

Once we’ve built up the complete Trace, we run rank, data type, and device inference. This is why the output tensor in the Trace already has its rank (visible as the two entries of its shape), dtype, and loc fields populated.
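The sketch below is a hypothetical inference pass over the Trace from our example (t117 through t121). It works on plain dictionaries rather than Tripy's classes. The per-operation rules are assumptions made for this sketch: a fill's rank is the length of its 1-D shape operand, and elementwise operations preserve rank, dtype, and device. Taking the fill's device from its value operand is also an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TensorInfo:
    # Hypothetical metadata slots that inference fills in.
    rank: Optional[int] = None
    dtype: Optional[str] = None
    device: Optional[str] = None
    static_shape: Optional[tuple] = None  # known only for constants in this sketch


def infer(ops):
    """Run rank/dtype/device inference over ops listed in execution order."""
    info = {}
    for op in ops:
        kind, out = op["kind"], op["out"]
        args = [info[name] for name in op.get("inputs", [])]
        if kind == "storage":
            shape = op["shape"]
            info[out] = TensorInfo(len(shape), op["dtype"], op["device"], shape)
        elif kind == "fill":
            shape_t, value_t = args
            # The 1-D shape operand's length is the output rank; dtype comes from the op.
            # Taking the device from the value operand is an assumption of this sketch.
            info[out] = TensorInfo(shape_t.static_shape[0], op["dtype"], value_t.device)
        elif kind == "unaryelementwise":
            (x,) = args
            # Elementwise ops preserve rank, dtype, and device.
            info[out] = TensorInfo(x.rank, x.dtype, x.device)
        else:
            raise NotImplementedError(kind)
    return info


trace = [
    {"kind": "storage", "out": "t117", "shape": (2,), "dtype": "int32", "device": "gpu:0"},
    {"kind": "storage", "out": "t118", "shape": (), "dtype": "float32", "device": "gpu:0"},
    {"kind": "fill", "out": "t119", "inputs": ["t117", "t118"], "dtype": "float32"},
    {"kind": "unaryelementwise", "out": "t121", "inputs": ["t119"]},
]
print(infer(trace)["t121"])
# TensorInfo(rank=2, dtype='float32', device='gpu:0', static_shape=None)
```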

Lowering To FlatIR

Once we have the Trace, we lower it into FlatIR. FlatIR is a very thin layer which provides a 1:1 mapping with the MLIR dialects we use.

To perform the lowering, each Trace operation implements to_flat_ir(), which generates a subgraph with one or more FlatIR operations.

Here’s a snippet showing how you might implement tanh (the actual implementation differs, but this is good enough for a conceptual understanding):

def to_flat_ir(self, inputs, outputs):
    from tripy.flat_ir.ops import TanhOp

    TanhOp.build(inputs, outputs)

Wait a second: what’s happening here? The function doesn’t return anything; in fact, it doesn’t appear to do anything at all!

Here’s how it works. When we call to_flat_ir(), we provide input and output FlatIRTensors, and to_flat_ir() is responsible for generating a subgraph of FlatIR operations that bind to them. The build() function of BaseFlatIROp updates the producer of each output tensor. Building a FlatIR operation is therefore enough to add it to the subgraph. Once this binding is done, we take the resulting subgraph and inline it into the FlatIR, remapping its I/O tensors to those that already exist in the FlatIR.
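The binding-and-inlining mechanism can be sketched as follows. BaseFlatIROp, build(), FlatIRTensor, and TanhOp are named after the text, but their signatures here are hypothetical. The free-standing tanh_to_flat_ir() function and the inline() helper are also hypothetical. The key point is that build() sets each output's producer. The subgraph can therefore be recovered by walking back from its outputs, then spliced in by swapping its boundary tensors for existing ones.

```python
import itertools

_ids = itertools.count()


class FlatIRTensor:
    def __init__(self):
        self.name = f"t{next(_ids)}"
        self.producer = None


class BaseFlatIROp:
    def __init__(self, inputs, outputs):
        self.inputs, self.outputs = list(inputs), list(outputs)

    @classmethod
    def build(cls, inputs, outputs):
        # Building the op points each output at it; this alone adds it to the subgraph.
        op = cls(inputs, outputs)
        for out in op.outputs:
            out.producer = op
        return op


class TanhOp(BaseFlatIROp):
    pass


def tanh_to_flat_ir(inputs, outputs):
    TanhOp.build(inputs, outputs)  # nothing to return


def inline(to_flat_ir, existing_inputs, existing_outputs, flat_ir_ops):
    """Lower one Trace op on fresh tensors, then splice the result into flat_ir_ops."""
    sub_inputs = [FlatIRTensor() for _ in existing_inputs]
    sub_outputs = [FlatIRTensor() for _ in existing_outputs]
    to_flat_ir(sub_inputs, sub_outputs)

    # Recover the subgraph by walking back from its outputs (producers first).
    ops, seen = [], set()

    def visit(tensor):
        op = tensor.producer
        if op is None or op in seen:
            return  # reached a subgraph input
        seen.add(op)
        for t in op.inputs:
            visit(t)
        ops.append(op)

    for t in sub_outputs:
        visit(t)

    # Remap the subgraph's boundary tensors onto tensors that already exist.
    remap = dict(zip(sub_inputs, existing_inputs))
    remap.update(zip(sub_outputs, existing_outputs))
    for op in ops:
        op.inputs = [remap.get(t, t) for t in op.inputs]
        op.outputs = [remap.get(t, t) for t in op.outputs]
        for out in op.outputs:
            out.producer = op
    flat_ir_ops.extend(ops)


x, y = FlatIRTensor(), FlatIRTensor()  # stand-ins for tensors already in the FlatIR
flat_ir = []
inline(tanh_to_flat_ir, [x], [y], flat_ir)
print([type(op).__name__ for op in flat_ir], flat_ir[0].inputs[0] is x, y.producer is flat_ir[0])
# ['TanhOp'] True True
```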

Here’s the textual representation for the FlatIR from our example; you’ll notice that we have more operations now than we did in the trace:

Main Function:
inputs:
t117: [rank=(1), shape=([2]), dtype=(int32), loc=(gpu:0)] = ConstantOp(data=[2, 3])
t118: [rank=(0), shape=([]), dtype=(float32), loc=(gpu:0)] = ConstantOp()
t119: [rank=(2), dtype=(float32), loc=(gpu:0)] = DynamicBroadcastOp(t118, t117, broadcast_dim=[])
t121: [rank=(2), dtype=(float32), loc=(gpu:0)] = TanhOp(t119)
outputs:
    t121: [rank=(2), dtype=(float32), loc=(gpu:0)]
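To pin down what this FlatIR computes, here is a tiny pure-Python reference interpreter for just the three operations in the listing. It is hypothetical and purely illustrative: Tripy does not execute FlatIR this way, since it compiles through MLIR-TRT. Only the scalar broadcast case (broadcast_dim=[]) is handled. The scalar 0.5 comes from the original program, because the listing does not print t118's data.

```python
import math


def broadcast_scalar(value, shape):
    # Build a nested list of the given shape filled with `value`.
    if not shape:
        return value
    return [broadcast_scalar(value, shape[1:]) for _ in range(shape[0])]


def map_nested(fn, x):
    # Apply fn elementwise through a nested list.
    return [map_nested(fn, v) for v in x] if isinstance(x, list) else fn(x)


def run(program):
    """Evaluate (result, op, operands, attrs) tuples in order; return all values."""
    env = {}
    for result, op, operands, attrs in program:
        args = [env[name] for name in operands]
        if op == "ConstantOp":
            env[result] = attrs["data"]
        elif op == "DynamicBroadcastOp":
            operand, shape = args
            if attrs["broadcast_dim"]:
                raise NotImplementedError("only scalar broadcasts are sketched")
            env[result] = broadcast_scalar(operand, shape)
        elif op == "TanhOp":
            env[result] = map_nested(math.tanh, args[0])
        else:
            raise NotImplementedError(op)
    return env


program = [
    ("t117", "ConstantOp", [], {"data": [2, 3]}),
    ("t118", "ConstantOp", [], {"data": 0.5}),
    ("t119", "DynamicBroadcastOp", ["t118", "t117"], {"broadcast_dim": []}),
    ("t121", "TanhOp", ["t119"], {}),
]
print(run(program)["t121"])  # a 2x3 nested list; every entry is tanh(0.5)
```

The result has the same shape and values that out.eval() is expected to produce for this program.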

Lowering To MLIR

Our final translation step is to go from FlatIR into MLIR.

Like Trace operations, FlatIR operations implement a lowering method, to_mlir(), which generates MLIR operations. Unlike the lowering from Trace to FlatIR, however, this mapping is always 1:1.

Here’s a snippet showing how to_mlir() is implemented for tanh:

def to_mlir(self, operands):
    return [stablehlo.TanhOp(*operands)]

There’s not much more to explain here, so let’s go right to the textual representation:

module @outs_t121_17 {
  func.func @main() -> tensor<?x?xf32> {
    %c = stablehlo.constant dense<[2, 3]> : tensor<2xi32>
    %cst = stablehlo.constant dense<5.000000e-01> : tensor<f32>
    %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<2xi32>) -> tensor<?x?xf32>
    %1 = stablehlo.tanh %0 : tensor<?x?xf32>
    return %1 : tensor<?x?xf32>
  }
}
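To show the 1:1 shape of this lowering, here is a toy emitter that produces one line of textual StableHLO per FlatIR operation in our example. It builds strings rather than using the MLIR Python bindings. The op tuples and the EMITTERS table are hypothetical. SSA names are simply numbered (%0, %1, ...) instead of being chosen the way MLIR's printer picks %c and %cst above.

```python
def fmt_dense(data):
    # Match the listing above: scalars in exponent form (0.5 -> 5.000000e-01).
    return f"{data:e}" if isinstance(data, float) else str(data)


# Hypothetical table: exactly one StableHLO op emitted per FlatIR op kind.
EMITTERS = {
    "ConstantOp": lambda res, args, a: (
        f"{res} = stablehlo.constant dense<{fmt_dense(a['data'])}> : {a['type']}"
    ),
    "DynamicBroadcastOp": lambda res, args, a: (
        f"{res} = stablehlo.dynamic_broadcast_in_dim {args[0]}, {args[1]}, dims = {a['dims']} : {a['type']}"
    ),
    "TanhOp": lambda res, args, a: f"{res} = stablehlo.tanh {args[0]} : {a['type']}",
}


def lower_to_mlir_text(flat_ops):
    """Emit one line of StableHLO text per (result, op, operands, attrs) tuple."""
    names, body = {}, []
    for i, (result, op, operands, attrs) in enumerate(flat_ops):
        names[result] = f"%{i}"
        body.append(EMITTERS[op](names[result], [names[o] for o in operands], attrs))
    return body


flat_ops = [
    ("t117", "ConstantOp", [], {"data": [2, 3], "type": "tensor<2xi32>"}),
    ("t118", "ConstantOp", [], {"data": 0.5, "type": "tensor<f32>"}),
    ("t119", "DynamicBroadcastOp", ["t118", "t117"],
     {"dims": [], "type": "(tensor<f32>, tensor<2xi32>) -> tensor<?x?xf32>"}),
    ("t121", "TanhOp", ["t119"], {"type": "tensor<?x?xf32>"}),
]
print("\n".join(lower_to_mlir_text(flat_ops)))
```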

Compilation

Once we have the complete MLIR representation, we then compile it to an executable using the MLIR-TRT compiler.

Execution

Finally, we use the MLIR-TRT executor to launch the executable and retrieve the output tensors. The executable returns memrefs, which we then wrap in frontend tripy.Tensor objects.