Source code for tilus.ir.layout.register_layout

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import dataclasses
import itertools
from dataclasses import dataclass
from functools import cached_property
from typing import Sequence

import tabulate
from hidet import boolean
from hidet.ir.expr import Expr, logical_and
from hidet.utils import prod

from tilus.extensions.hidet.ir.utils.index_transform import index_deserialize, index_serialize
from tilus.ir.mfunction import MultiFunction, multi_function
from tilus.ir.node import IRNode

# Index type accepted by the layout helpers below: either a concrete Python int or a hidet expression.
Int = int | Expr


@dataclass(frozen=True, eq=False)
class RegisterLayout(IRNode):
    """Layout for register tensor.

    The tensor ``shape`` is factored into modes (``mode_shape``); each dimension of the shape owns a
    contiguous group of modes (see ``grouped_modes``). Every mode is referenced by either
    ``spatial_modes`` or ``local_modes``, which together determine the serialized spatial index and
    local index of each tensor element.

    Attributes
    ----------
    shape: tuple[int, ...]
        The shape of the layout, which is the shape of the register tensor.
    mode_shape: tuple[int, ...]
        The size of each mode.
    spatial_modes: tuple[int, ...]
        The spatial modes. A non-negative entry is an index into ``mode_shape``; a negative entry ``-k``
        is a replicated spatial dimension of size ``k``: every element maps to all ``k`` positions
        along it.
    local_modes: tuple[int, ...]
        The local modes, as indices into ``mode_shape``.
    """

    shape: tuple[int, ...]
    mode_shape: tuple[int, ...]
    spatial_modes: tuple[int, ...]
    local_modes: tuple[int, ...]

    def __mul__(self, other):
        if not isinstance(other, RegisterLayout):
            raise TypeError(f"Cannot multiply {type(self)} with {type(other)}")
        from tilus.ir.layout.register_layout_ops import compose

        return compose(self, other)

    def __truediv__(self, other):
        if not isinstance(other, RegisterLayout):
            raise TypeError(f"Cannot divide {type(self)} with {type(other)}")
        from tilus.ir.layout.register_layout_ops import divide

        return divide(self, other)

    def __eq__(self, other):
        if not isinstance(other, RegisterLayout):
            return False
        return (
            self.shape == other.shape
            and self.mode_shape == other.mode_shape
            and self.spatial_modes == other.spatial_modes
            and self.local_modes == other.local_modes
        )

    def __hash__(self):
        # Must agree with __eq__: objects that compare equal are required to have equal hashes,
        # otherwise equal layouts would be treated as distinct keys in dicts and sets.
        return hash((self.shape, self.mode_shape, self.spatial_modes, self.local_modes))

    def with_shape(self, shape: Sequence[int]) -> RegisterLayout:
        """Return a copy of this layout with the given shape, keeping the same modes.

        Raises
        ------
        ValueError
            If the modes do not factor the new shape (raised by ``validate_layout``).
        """
        validate_layout(shape, self.mode_shape, self.spatial_modes, self.local_modes)
        return dataclasses.replace(self, shape=tuple(shape))

    @cached_property
    def grouped_modes(self):
        from .utils import get_mode_groups

        return get_mode_groups(self.shape, self.mode_shape)

    @cached_property
    def spatial_shape(self) -> list[int]:
        # a negative spatial mode -k stands for a replicated dimension of size k
        return [self.mode_shape[i] if i >= 0 else -i for i in self.spatial_modes]

    @cached_property
    def local_shape(self) -> list[int]:
        return [self.mode_shape[i] for i in self.local_modes]

    @cached_property
    def local_size(self) -> int:
        return prod(self.local_shape)

    @cached_property
    def spatial_size(self) -> int:
        return prod(self.spatial_shape)

    @cached_property
    def size(self) -> int:
        return prod(self.shape)

    def spatial_mfunction(self) -> MultiFunction:
        """
        Get the multi-function that maps the global indices to the spatial indices (serialized).
        """
        return multi_function(
            shape=self.shape,
            mode_shape=self.mode_shape,
            modes=self.spatial_modes,
        )

    def _mode_indices(self, global_indices: Sequence[Int]) -> list[Int]:
        """Split each global index into the indices of the modes owned by its dimension.

        Raises
        ------
        ValueError
            If the number of global indices differs from the rank of the layout.
        """
        if len(global_indices) != len(self.shape):
            raise ValueError(
                "Global indices must match the shape of the layout, got {} vs {}".format(
                    len(global_indices), len(self.shape)
                )
            )
        mode_indices: list[Int] = []
        for index, modes in zip(global_indices, self.grouped_modes):
            shape = [self.mode_shape[mode] for mode in modes]
            mode_indices.extend(index_deserialize(index, shape))
        return mode_indices

    def get_spatial(self, global_indices: Sequence[Int]) -> list[Expr]:
        """Get the serialized spatial indices that hold the element at ``global_indices``.

        One index is returned per combination of positions along the replicated (negative) spatial
        modes, so the list has a single entry when the layout has no replication.

        Raises
        ------
        ValueError
            If the number of global indices differs from the rank of the layout.
        """
        mode_indices = self._mode_indices(global_indices)
        replicate_dims: list[int] = []
        replicate_sizes: list[int] = []
        spatial_indices: list[Int] = []
        for i, mode in enumerate(self.spatial_modes):
            if mode < 0:
                replicate_dims.append(i)
                replicate_sizes.append(-mode)
                spatial_indices.append(0)  # placeholder, filled per replica below
            else:
                spatial_indices.append(mode_indices[mode])
        results: list[Expr] = []
        for items in itertools.product(*[range(s) for s in replicate_sizes]):
            for dim, value in zip(replicate_dims, items):
                spatial_indices[dim] = value
            results.append(index_serialize(spatial_indices, self.spatial_shape))
        return results

    def get_local(self, global_indices: Sequence[Int]) -> Expr:
        """Get the serialized local index of the element at ``global_indices``.

        Raises
        ------
        ValueError
            If the number of global indices differs from the rank of the layout.
        """
        mode_indices = self._mode_indices(global_indices)
        local_indices: list[Int] = [mode_indices[i] for i in self.local_modes]
        return index_serialize(local_indices, self.local_shape)

    def get_global(self, *, spatial_index: Int, local_index: Int) -> list[Expr]:
        """Get the global indices of the element stored at ``(spatial_index, local_index)``.

        Replicated (negative) spatial modes do not correspond to any tensor mode and are ignored.
        """
        spatial_indices = index_deserialize(spatial_index, self.spatial_shape)
        local_indices = index_deserialize(local_index, self.local_shape)
        mode_indices: list[Int] = [0 for _ in range(len(self.mode_shape))]
        for mode, index in zip(self.spatial_modes, spatial_indices):
            if mode >= 0:
                mode_indices[mode] = index
        for mode, index in zip(self.local_modes, local_indices):
            mode_indices[mode] = index
        global_indices: list[Expr] = []
        for modes in self.grouped_modes:
            group_indices = [mode_indices[mode] for mode in modes]
            group_shape = [self.mode_shape[mode] for mode in modes]
            global_indices.append(index_serialize(group_indices, group_shape))
        return global_indices

    # operations

    def local(self, *shape: int) -> RegisterLayout:
        from tilus.ir.layout.register_layout_ops import compose, local

        return compose(self, local(*shape))

    def spatial(self, *shape: int) -> RegisterLayout:
        from tilus.ir.layout.register_layout_ops import compose, spatial

        return compose(self, spatial(*shape))

    def column_spatial(self, *shape: int) -> RegisterLayout:
        from tilus.ir.layout.register_layout_ops import column_spatial, compose

        return compose(self, column_spatial(*shape))

    def column_local(self, *shape: int) -> RegisterLayout:
        from tilus.ir.layout.register_layout_ops import column_local, compose

        return compose(self, column_local(*shape))

    def reduce_to(self, shape: Sequence[int]) -> RegisterLayout:
        """
        Reduce the layout to the given shape by removing the modes that are not in the shape.

        Parameters
        ----------
        shape: Sequence[int]
            The shape to reduce to.

        Returns
        -------
        ret: RegisterLayout
            The reduced layout.
        """
        from tilus.ir.layout.register_layout_ops import reduce_to

        return reduce_to(self, shape)
def validate_layout(
    shape: Sequence[int],
    mode_shape: Sequence[int],
    spatial_modes: Sequence[int],
    local_modes: Sequence[int],
) -> None:
    """
    Validate the layout parameters.

    Parameters
    ----------
    shape: Sequence[int]
        The shape of the layout.
    mode_shape: Sequence[int]
        The shape of the modes.
    spatial_modes: Sequence[int]
        The spatial modes of the layout.
    local_modes: Sequence[int]
        The local modes of the layout.

    Raises
    ------
    ValueError
        If the shape or a mode size is not a positive integer, if the modes do not exactly factor
        the shape (a mode may not span two dimensions), if a spatial/local mode index is out of
        range, or if a mode is referenced more than once.
    """
    # Explicit exceptions rather than `assert`, which is stripped when Python runs with -O.
    if not all(s >= 1 for s in shape):
        raise ValueError("Shape must only be positive integers")
    if not all(m >= 1 for m in mode_shape):
        # a zero mode would otherwise crash below with ZeroDivisionError
        raise ValueError(f"Modes must only be positive integers, got {tuple(mode_shape)}")

    # validate modes: consume the shape from the innermost dimension outwards
    remaining_shape = list(shape)
    for mode in reversed(mode_shape):
        if mode == 1:
            continue
        while remaining_shape and remaining_shape[-1] == 1:
            remaining_shape.pop()
        if len(remaining_shape) == 0 or remaining_shape[-1] % mode != 0:
            raise ValueError(f"Mode {mode} does not divide the remaining shape {remaining_shape}")
        remaining_shape[-1] //= mode
    while remaining_shape and remaining_shape[-1] == 1:
        remaining_shape.pop()
    if remaining_shape:
        raise ValueError("Modes {} and shape {} do not match".format(mode_shape, shape))

    # the thread dims and local dims must be
    # 1. the indices of modes
    # 2. the thread_dims can contain negative values representing replicated threads (containing the same value)
    used_dims = []
    for dim in spatial_modes:
        if dim < 0:
            continue
        if not (0 <= dim < len(mode_shape)):
            raise ValueError(f"Thread dim {dim} is out of range for modes {mode_shape}")
        used_dims.append(dim)
    for dim in local_modes:
        if not (0 <= dim < len(mode_shape)):
            raise ValueError(f"Local dim {dim} is out of range for modes {mode_shape}")
        used_dims.append(dim)
    if len(used_dims) != len(set(used_dims)):
        raise ValueError("Thread dims and local dims must be unique")


def visualize_layout(layout: RegisterLayout) -> str:
    """
    Visualize the layout in a human-readable format.

    The shape is normalized to three dimensions (size-1 dims dropped, leading dims padded with 1 or
    merged) and one table is rendered per outermost index. Each cell shows the spatial index (or the
    sorted list of spatial indices for replicated elements) followed by the local index.

    Parameters
    ----------
    layout: RegisterLayout
        The layout to be converted.

    Returns
    -------
    ret: str
        The string representation of the layout that is human-readable.
    """
    head = str(layout)

    # normalize the shape into 3-dimension
    shape = [s for s in layout.shape if s > 1]  # prune 1s
    while len(shape) < 3:
        shape.insert(0, 1)
    while len(shape) > 3:
        shape = [prod(shape[0:2])] + shape[2:]
    layout = layout.with_shape(shape)

    tables: list[str] = []
    for batch in range(shape[0]):
        table: list[list[str]] = []
        for i in range(shape[1]):
            row = []
            for j in range(shape[2]):
                local_index = layout.get_local(global_indices=[batch, i, j])
                thread_indices = layout.get_spatial(global_indices=[batch, i, j])
                thread_indices.sort()
                if len(thread_indices) == 1:
                    row.append(f"{thread_indices[0]}: {local_index}")
                else:
                    row.append(f"{thread_indices}: {local_index}")
            table.append(row)
        tables.append(tabulate.tabulate(table, tablefmt="simple_grid"))
    return head + "\n" + "\n".join(tables)


def _canonicalize_singleton_modes(layout: RegisterLayout) -> RegisterLayout:
    """Drop size-1 modes and size-1 replicated spatial dimensions (``-1`` spatial entries).

    Neither carries information: a size-1 mode always has index 0, and a replication factor of 1
    does not replicate, so removing them preserves the mapping while making equivalent layouts
    canonicalize identically.
    """
    has_singleton_mode = any(size == 1 for size in layout.mode_shape)
    has_trivial_replica = any(mode == -1 for mode in layout.spatial_modes)
    if not has_singleton_mode and not has_trivial_replica:
        return layout

    # map each old mode to its index after the size-1 modes are removed; -1 marks a removed mode
    mode_map: dict[int, int] = {}
    next_mode = 0
    for mode, size in enumerate(layout.mode_shape):
        if size == 1:
            mode_map[mode] = -1
        else:
            mode_map[mode] = next_mode
            next_mode += 1

    mode_shape = [size for size in layout.mode_shape if size != 1]
    spatial_modes: list[int] = []
    for mode in layout.spatial_modes:
        if mode >= 0:
            if mode_map[mode] != -1:
                spatial_modes.append(mode_map[mode])
        elif mode != -1:
            # replicated dimension of size > 1: kept as-is
            spatial_modes.append(mode)
    local_modes = [mode_map[mode] for mode in layout.local_modes if mode_map[mode] != -1]
    return RegisterLayout(
        shape=layout.shape,
        mode_shape=tuple(mode_shape),
        spatial_modes=tuple(spatial_modes),
        local_modes=tuple(local_modes),
    )


def _canonicalize_contiguous_modes(layout: RegisterLayout) -> RegisterLayout:
    """Merge runs of modes that are adjacent within one dimension's mode group AND adjacent, in the
    same order, within the same kind (spatial or local) list.

    Such a run behaves like a single mode whose size is the product of the run, so the merge
    preserves the mapping. Returns ``layout`` unchanged when nothing can be merged.
    """
    # get the map to mode kind ("spatial" / "local") and the position of the mode in that list
    mode_kind: dict[int, str] = {}
    mode_index: dict[int, int] = {}
    for i, mode in enumerate(layout.spatial_modes):
        if mode < 0:
            continue
        mode_kind[mode] = "spatial"
        mode_index[mode] = i
    for i, mode in enumerate(layout.local_modes):
        mode_kind[mode] = "local"
        mode_index[mode] = i

    # determine the modes that should be merged
    merge_modes: list[list[int]] = []
    for modes in layout.grouped_modes:
        i = 0
        while i < len(modes):
            j = i
            while (
                j + 1 < len(modes)
                and mode_kind[modes[j]] == mode_kind[modes[j + 1]]
                and mode_index[modes[j]] + 1 == mode_index[modes[j + 1]]
            ):
                j += 1
            merge_modes.append(modes[i : j + 1])
            i = j + 1

    if all(len(modes) == 1 for modes in merge_modes):
        # no merge needed
        return layout

    # get the map from the original mode to the new mode-group, which corresponds to the mode in the canonical layout
    mode_map: dict[int, int] = {}
    for i, modes in enumerate(merge_modes):
        for j, mode in enumerate(modes):
            if j == 0:
                mode_map[mode] = i
            else:
                mode_map[mode] = -1  # this mode is merged with the first mode in the group, mark it as -1

    mode_shape = [prod(layout.mode_shape[i] for i in modes) for modes in merge_modes]
    spatial_modes = [
        mode_map[mode] if mode >= 0 else mode
        for mode in layout.spatial_modes
        if mode < 0 or mode_map[mode] != -1
    ]
    local_modes = [mode_map[mode] for mode in layout.local_modes if mode_map[mode] != -1]
    return RegisterLayout(
        shape=layout.shape,
        mode_shape=tuple(mode_shape),
        spatial_modes=tuple(spatial_modes),
        local_modes=tuple(local_modes),
    )
def register_layout(
    shape: Sequence[int],
    mode_shape: Sequence[int],
    spatial_modes: Sequence[int],
    local_modes: Sequence[int],
) -> RegisterLayout:
    """
    Build a validated, canonicalized register layout.

    Parameters
    ----------
    shape: Sequence[int]
        The shape of the layout.
    mode_shape: Sequence[int]
        The shape of the modes.
    spatial_modes: Sequence[int]
        The spatial modes of the layout.
    local_modes: Sequence[int]
        The local modes of the layout.

    Returns
    -------
    ret: RegisterLayout
        The created register layout, passed through ``canonicalize_layout``.

    Raises
    ------
    ValueError
        If the parameters are rejected by ``validate_layout``.
    """
    # reject inconsistent parameters before anything is constructed
    validate_layout(shape, mode_shape, spatial_modes, local_modes)

    raw_layout = RegisterLayout(
        shape=tuple(shape),
        mode_shape=tuple(mode_shape),
        spatial_modes=tuple(spatial_modes),
        local_modes=tuple(local_modes),
    )
    canonical = canonicalize_layout(raw_layout)
    return canonical
def canonicalize_layout(layout: RegisterLayout) -> RegisterLayout:
    """
    Canonicalize the layout by
    1. removing the singletons: size-1 modes and size-1 replicated spatial dimensions, then
    2. merging the modes that are contiguous in the three places: shape, spatial_modes, and local_modes.

    The intent is that layouts with the same mapping canonicalize to the same layout.

    Parameters
    ----------
    layout: RegisterLayout
        The layout to be canonicalized

    Returns
    -------
    ret: RegisterLayout
        The canonicalized layout
    """
    # singletons are removed first so they cannot block the contiguity checks of the merge step
    return _canonicalize_contiguous_modes(_canonicalize_singleton_modes(layout))


def locate_at(layout: RegisterLayout, global_indices: Sequence[Int], spatial_index: Int) -> Expr:
    """
    Check if the global indices are located at the given spatial index.

    Parameters
    ----------
    layout: RegisterLayout
        The layout to be checked.
    global_indices: Sequence[Int]
        The global indices to be checked.
    spatial_index: Int
        The spatial index to be checked.

    Returns
    -------
    ret: Expr
        Expression with value True if the global indices are located at the given spatial index, False otherwise.

    Raises
    ------
    ValueError
        If the number of global indices differs from the rank of the layout.
    """
    if len(global_indices) != len(layout.shape):
        raise ValueError(
            "Global indices must match the shape of the layout, got {} vs {}".format(
                len(global_indices), len(layout.shape)
            )
        )

    # split each global index into the indices of the modes owned by its dimension
    mode_indices: list[Int] = []
    for index, modes in zip(global_indices, layout.grouped_modes):
        mode_indices.extend(index_deserialize(index, [layout.mode_shape[mode] for mode in modes]))

    spatial_indices: list[Expr] = index_deserialize(spatial_index, layout.spatial_shape)
    condition = boolean.true
    for mode, spatial in zip(layout.spatial_modes, spatial_indices):
        if mode < 0:
            # replicated dimension: every position along it holds the element, so no constraint
            continue
        condition = logical_and(condition, mode_indices[mode] == spatial)
    return condition