Using the Compiler
Modules and functions can be compiled ahead of time for better runtime performance.
Note that the compiler imposes some requirements on the functions/modules it can compile; see tripy.compile() for details.
In this guide, we’ll work with the GEGLU module defined below:
```python
class GEGLU(tp.Module):
    def __init__(self, dim_in, dim_out):
        self.proj = tp.Linear(dim_in, dim_out * 2)
        self.dim_out = dim_out

    def __call__(self, x):
        proj = self.proj(x)
        x, gate = tp.split(proj, 2, proj.rank - 1)
        return x * tp.gelu(gate)
```
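To make the math concrete, here is a pure-Python sketch of what GEGLU computes for a single row. It assumes the exact erf-based GELU (Tripy's `tp.gelu` may use an approximation), and `geglu_reference` is a hypothetical helper for illustration, not part of the Tripy API:

```python
import math

def gelu(v):
    # Exact (erf-based) GELU definition; assumed here for illustration.
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))

def geglu_reference(proj, dim_out):
    # Split the projection into two halves along the last axis,
    # then gate the first half with GELU of the second.
    x, gate = proj[:dim_out], proj[dim_out:]
    return [xi * gelu(gi) for xi, gi in zip(x, gate)]

# For a projected row [1, 2, 3, 4] with dim_out=2:
print(geglu_reference([1.0, 2.0, 3.0, 4.0], 2))
```

The first two values are gated by `gelu(3.0)` and `gelu(4.0)` respectively, so the result is approximately `[2.996, 8.000]`.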
We can run this in eager mode as usual:
```python
layer = GEGLU(2, 8)

inp = tp.ones((1, 2))
out = layer(inp)
```
>>> inp
tensor(
[[1.0000, 1.0000]],
dtype=float32, loc=gpu:0, shape=(1, 2))
>>> out
tensor(
[[41.0000, 276.0000, 561.0000, 896.0000, 1281.0000, 1716.0000, 2201.0000, 2736.0000]],
dtype=float32, loc=gpu:0, shape=(1, 8))
Compiling
Let’s optimize the module using tripy.compile().
When we compile in Tripy, we need to provide shape and data type information about each runtime input using the tripy.InputInfo API. Other parameters to the function will be considered compile-time constants and will be folded into the compiled function.
GEGLU only has one input, for which we’ll create an InputInfo like so:
```python
inp_info = tp.InputInfo(shape=(1, 2), dtype=tp.float32)
```
Then we’ll compile, which will give us a tripy.Executable that we can run:
```python
fast_geglu = tp.compile(layer, args=[inp_info])

out = fast_geglu(inp)
```
>>> out
tensor(
[[41.0000, 276.0000, 561.0000, 896.0000, 1281.0000, 1716.0000, 2201.0000, 2736.0000]],
dtype=float32, loc=gpu:0, shape=(1, 8))
Dynamic Shapes
When we compiled above, we used a static shape of (1, 2) for the input.
Tripy also supports specifying a range of possible values for each dimension like so:
```python
inp_info = tp.InputInfo(shape=([1, 8, 16], 2), dtype=tp.float32)
```
>>> inp_info
InputInfo(min=(1, 2), opt=(8, 2), max=(16, 2), dtype=float32)
The shape we used above indicates that the 0th dimension should support a range of values from 1 to 16, optimizing for a value of 8. For the 1st dimension, we continue using a fixed value of 2.
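Conceptually, each dimension spec expands into (min, opt, max) triples, which is how the `InputInfo` repr shown above is derived. The helper below is a hypothetical sketch of that normalization, not Tripy's actual implementation:

```python
def normalize_shape(shape):
    # Each dimension is either a fixed int or a [min, opt, max] triple.
    # Fixed dimensions use the same value for all three bounds.
    mins, opts, maxs = [], [], []
    for dim in shape:
        if isinstance(dim, int):
            dim = [dim, dim, dim]
        lo, opt, hi = dim
        mins.append(lo)
        opts.append(opt)
        maxs.append(hi)
    return tuple(mins), tuple(opts), tuple(maxs)

print(normalize_shape([[1, 8, 16], 2]))  # → ((1, 2), (8, 2), (16, 2))
```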
Let’s compile again with our updated InputInfo and try changing the input shape:
```python
fast_geglu = tp.compile(layer, args=[inp_info])

# We'll run with the input we created above, which is of shape (1, 2):
out0 = fast_geglu(inp)

# Now let's try an input of shape (2, 2):
inp1 = tp.Tensor([[1., 2.], [2., 3.]], dtype=tp.float32)
out1 = fast_geglu(inp1)
```
>>> out0
tensor(
[[41.0000, 276.0000, 561.0000, 896.0000, 1281.0000, 1716.0000, 2201.0000, 2736.0000]],
dtype=float32, loc=gpu:0, shape=(1, 8))
>>> inp1
tensor(
[[1.0000, 2.0000],
[2.0000, 3.0000]],
dtype=float32, loc=gpu:0, shape=(2, 2))
>>> out1
tensor(
[[116.0000, 585.0000, 1152.0000, 1817.0000, 2580.0000, 3441.0000, 4400.0000, 5457.0000],
[273.0000, 1428.0000, 2825.0000, 4464.0000, 6345.0000, 8468.0000, 10833.0000, 13440.0000]],
dtype=float32, loc=gpu:0, shape=(2, 8))
If we try using a shape outside of the valid range, the executable will raise a clear error:
```python
inp = tp.ones((32, 2), dtype=tp.float32)
print(fast_geglu(inp))
```
Output:
Exception occurred:
--> <string>:3 in <module>()
Unexpected tensor shape.
For tensor: `x`, expected a shape within the bounds: min=(1, 2), max=(16, 2), but got: [32, 2].
Dimension 0 has a shape of 32, which is not within the expected bounds of [1, 16].
Note: The provided argument was:
--> /tripy/tripy/frontend/ops/tensor_initializers.py:60 in ones()
|
60 | return full(shape, 1, dtype)
| ^^^^^^^^^^^^^^^^^^^^^
--> <string>:2 in <module>()
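The bounds check behind that error can be sketched in a few lines of pure Python. This is a conceptual illustration of the validation the executable performs, not Tripy's actual code, and `check_shape` is a hypothetical helper:

```python
def check_shape(shape, min_shape, max_shape):
    # Validate each runtime dimension against the compiled min/max bounds.
    for i, (s, lo, hi) in enumerate(zip(shape, min_shape, max_shape)):
        if not (lo <= s <= hi):
            raise ValueError(
                f"Dimension {i} has a shape of {s}, which is not within "
                f"the expected bounds of [{lo}, {hi}]."
            )

check_shape((8, 2), (1, 2), (16, 2))   # within bounds: no error
try:
    check_shape((32, 2), (1, 2), (16, 2))
except ValueError as e:
    print(e)
```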
Saving The Executable
You can serialize and save executables like so:
```python
import os

# Assuming `out_dir` is the directory where you'd like to save the executable:
executable_file_path = os.path.join(out_dir, "executable.json")
fast_geglu.save(executable_file_path)
```
Then you can load and run it again:
```python
loaded_fast_geglu = tp.Executable.load(executable_file_path)

out = loaded_fast_geglu(inp)
```
>>> out
tensor(
[[41.0000, 276.0000, 561.0000, 896.0000, 1281.0000, 1716.0000, 2201.0000, 2736.0000]],
dtype=float32, loc=gpu:0, shape=(1, 8))
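The save/load round trip can be exercised without Tripy using a plain JSON stand-in. This is only a sketch of the pattern (the real tp.Executable serializes the compiled engine, not just metadata, and its on-disk format is not documented here):

```python
import json
import os
import tempfile

# A stand-in "executable" description for illustration only.
spec = {
    "inputs": {"x": {"min": [1, 2], "opt": [8, 2], "max": [16, 2]}},
    "dtype": "float32",
}

with tempfile.TemporaryDirectory() as out_dir:
    path = os.path.join(out_dir, "executable.json")
    with open(path, "w") as f:
        json.dump(spec, f)
    with open(path) as f:
        loaded = json.load(f)

print(loaded == spec)  # → True
```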