# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Sequence
from hidet.ir.dtypes import int32
from hidet.ir.expr import Expr, Var, as_expr
from tilus.extensions.hidet.ir.expr import index_vars
from tilus.extensions.hidet.ir.utils.index_transform import index_multiply
from tilus.ir.node import IRNode
from tilus.utils import prod
[docs]
@dataclass(frozen=True, eq=False)
class GlobalLayout(IRNode):
    """Mapping from the indices of a global tensor to a linear element offset.

    Attributes
    ----------
    shape: tuple[Expr, ...]
        Extent of each dimension of the global tensor. Every extent is either a constant integer or a
        grid-invariant expression.
    size: Expr
        Number of elements of storage backing the tensor. For a `compact` layout it equals the product of
        the extents in ``shape``; it may be larger (padding) or smaller (several elements sharing storage).
    axes: tuple[Var, ...]
        One index variable per dimension, in the same order as ``shape``.
    offset: Expr
        Expression over ``axes`` (and grid-invariant variables only) giving the element offset.
    """

    shape: tuple[Expr, ...]
    size: Expr
    axes: tuple[Var, ...]
    offset: Expr

    def __call__(self, *indices: Expr) -> Expr:
        """Evaluate the layout at concrete indices.

        Parameters
        ----------
        indices: Sequence[Expr]
            One index per axis of the layout; the count must equal ``len(self.axes)``.

        Returns
        -------
        ret: Expr
            ``self.offset`` with every axis variable replaced by the matching index.
        """
        assert len(indices) == len(self.axes)
        # Imported at call time, as in the original implementation.
        from hidet.ir.tools import rewrite

        substitution = dict(zip(self.axes, indices))
        return rewrite(self.offset, rewrite_map=substitution)

    @staticmethod
    def create(shape: Sequence[Expr | int], size: Expr, f_offset: Callable[[Sequence[Var]], Expr]) -> GlobalLayout:
        """Build a layout from a shape, a storage size and an offset function.

        Parameters
        ----------
        shape: Sequence[Expr | int]
            Extent of each dimension, given as constant integers or grid-invariant expressions. Every
            entry is converted with ``as_expr``.
        size: Expr
            Storage size in elements. It is stored unchanged. For a `compact` layout it equals the
            product of the extents; it may be larger (padding) or smaller (shared storage).
        f_offset: Callable[[Sequence[Var]], Expr]
            Called once with freshly created axis variables (one per dimension); must return the offset
            expression in terms of those variables and grid-invariant variables only.

        Returns
        -------
        ret: GlobalLayout
            Layout holding the converted shape, ``size``, the new axes and the offset produced by
            ``f_offset``.
        """
        expr_shape = tuple(map(as_expr, shape))
        axis_vars: list[Var] = index_vars(num_vars=len(shape))
        offset_expr = f_offset(axis_vars)
        return GlobalLayout(shape=expr_shape, size=size, axes=tuple(axis_vars), offset=offset_expr)
def _generic_repeat(shape: Sequence[Expr | int], ranks: Sequence[int]) -> GlobalLayout:
    """Create a compact strided layout whose dimension order is given by ``ranks``.

    ``ranks`` must be a permutation of ``0 .. len(shape) - 1``. The stride of dimension ``i`` is the
    product of the extents of all dimensions whose rank is greater than ``ranks[i]``, so the dimension
    with the largest rank gets the empty product as its stride.

    Parameters
    ----------
    shape: Sequence[Expr | int]
        Extent of each dimension.
    ranks: Sequence[int]
        Rank of each dimension; same length as ``shape``, all values distinct and within range.

    Returns
    -------
    ret: GlobalLayout
        Layout with ``size = prod(shape)`` and ``offset = sum(axes[i] * strides[i])``.
    """
    ndims = len(shape)
    assert ndims == len(ranks)
    assert len(ranks) == len(set(ranks))
    assert all(0 <= d < ndims for d in ranks)

    strides: list[Expr] = []
    for i in range(ndims):
        # Extents of the dimensions that vary faster than dimension i.
        faster_extents = [extent for j, extent in enumerate(shape) if ranks[j] > ranks[i]]
        strides.append(prod(faster_extents))

    def f_offset(axes: Sequence[Var]) -> Expr:
        # Left fold starting at int32.zero, the same association order as sum(..., start=int32.zero).
        total = int32.zero
        for i in range(ndims):
            total = total + axes[i] * strides[i]
        return total

    return GlobalLayout.create(shape=shape, size=prod(shape), f_offset=f_offset)
def _global_compose(lhs: GlobalLayout, rhs: GlobalLayout) -> GlobalLayout:
    """Compose two global layouts of equal rank, with ``lhs`` as the outer layout.

    The composed shape is ``index_multiply(lhs.shape, rhs.shape)`` and the composed size is
    ``lhs.size * rhs.size``. An index ``a`` along dimension ``i`` of the composed layout is split into

    - ``a // rhs.shape[i]``: which ``rhs``-sized tile is addressed (the index fed to ``lhs``), and
    - ``a % rhs.shape[i]``: the position inside that tile (the index fed to ``rhs``).

    The resulting offset is ``lhs_offset * rhs.size + rhs_offset``, i.e. each tile selected by ``lhs``
    occupies ``rhs.size`` consecutive elements of storage.

    Parameters
    ----------
    lhs: GlobalLayout
        The outer layout (arranges tiles).
    rhs: GlobalLayout
        The inner layout (arranges elements inside one tile). Must have as many dimensions as ``lhs``.

    Returns
    -------
    ret: GlobalLayout
        The composed layout.
    """
    assert len(lhs.shape) == len(rhs.shape)

    def f_offset(axes: Sequence[Var]) -> Expr:
        # Quotient selects the tile, remainder selects the element inside the tile. Using the quotient
        # for both (as the previous code did) hands rhs the tile index, so positions within a tile are
        # not distinguished.
        lhs_indices = [axis // extent for axis, extent in zip(axes, rhs.shape)]
        rhs_indices = [axis % extent for axis, extent in zip(axes, rhs.shape)]
        lhs_offset = lhs(*lhs_indices)
        rhs_offset = rhs(*rhs_indices)
        return lhs_offset * rhs.size + rhs_offset

    shape = index_multiply(lhs.shape, rhs.shape)
    size = lhs.size * rhs.size
    return GlobalLayout.create(shape=shape, size=size, f_offset=f_offset)
[docs]
def global_row_major(*shape: Expr | int) -> GlobalLayout:
    """Build a compact row-major global layout.

    Parameters
    ----------
    shape: Sequence[Expr | int]
        Extent of each dimension, given as constant integers or grid-invariant expressions.

    Returns
    -------
    ret: GlobalLayout
        Layout over ``shape`` in which the last dimension has the smallest stride.
    """
    # Rank i for dimension i: later dimensions vary faster.
    ascending_ranks = list(range(len(shape)))
    return _generic_repeat(shape=shape, ranks=ascending_ranks)
[docs]
def global_column_major(*shape: Expr | int) -> GlobalLayout:
    """Build a compact column-major global layout.

    Parameters
    ----------
    shape: Sequence[Expr | int]
        Extent of each dimension, given as constant integers or grid-invariant expressions.

    Returns
    -------
    ret: GlobalLayout
        Layout over ``shape`` in which the first dimension has the smallest stride.
    """
    # Ranks count down from len(shape) - 1 to 0: earlier dimensions vary faster.
    descending_ranks = list(range(len(shape) - 1, -1, -1))
    return _generic_repeat(shape=shape, ranks=descending_ranks)
[docs]
def global_compose(lhs: GlobalLayout, rhs: GlobalLayout, *others: GlobalLayout) -> GlobalLayout:
    """Compose two or more global layouts into one.

    The layouts are combined pairwise from left to right: ``lhs`` with ``rhs`` first, then the result
    with each layout in ``others`` in order. See the `Tilus <https://arxiv.org/pdf/2504.12984>`_ paper,
    Section 4.2, for the definition of layout composition.

    Parameters
    ----------
    lhs: GlobalLayout
        The left-hand side global layout.
    rhs: GlobalLayout
        The right-hand side global layout.
    others: Sequence[GlobalLayout]
        Further layouts to compose onto the result; may be empty.

    Returns
    -------
    ret: GlobalLayout
        The layout obtained by composing all inputs.
    """
    composed = _global_compose(lhs, rhs)
    for layout in others:
        composed = _global_compose(composed, layout)
    return composed
[docs]
def global_strides(shape: Sequence[Expr | int], strides: Sequence[Expr | int]) -> GlobalLayout:
    """Build a global layout from explicit per-dimension strides.

    The offset of an element is ``sum(axes[i] * strides[i])`` over all dimensions.

    Parameters
    ----------
    shape: Sequence[Expr | int]
        Extent of each dimension, given as constant integers or grid-invariant expressions.
    strides: Sequence[Expr | int]
        Step, in elements, taken per unit increase of each axis. Must have the same length as ``shape``;
        entries are constant integers or grid-invariant expressions.

    Returns
    -------
    ret: GlobalLayout
        Layout over ``shape`` with the strided offset described above. Its ``size`` is always
        ``prod(shape)``, independent of ``strides``.
        NOTE(review): for non-compact strides this may not match the storage size -- confirm intent.
    """
    ndims = len(shape)
    assert ndims == len(strides)

    def f_offset(axes: Sequence[Var]) -> Expr:
        # Left fold starting at int32.zero, the same association order as sum(..., start=int32.zero).
        total = int32.zero
        for i in range(ndims):
            total = total + axes[i] * strides[i]
        return total

    return GlobalLayout.create(shape=shape, size=prod(shape), f_offset=f_offset)