Source code for tilus.lang.script
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, TypeAlias, TypeVar
from tilus.hidet.ir.expr import Expr
from tilus.lang.instructions import InstructionInterface
from tilus.lang.modules.cuda import cuda
if TYPE_CHECKING:
from tilus.lang.instantiated_script import InstantiatedScript, JitInstance
# An integer-like value: either a plain Python ``int`` or an ``Expr`` imported
# from ``tilus.hidet.ir.expr``.
Int: TypeAlias = int | Expr
[docs]
class Attributes:
    """Container for the launch attributes of a kernel.

    All fields are declared at class level with a default value, so a freshly
    created instance reports these defaults until a field is assigned on it.
    """

    #: The grid dimensions of the kernel launch. Set as a list of up to 3
    #: integers, e.g., ``self.attrs.blocks = [grid_x, grid_y]``. The annotation
    #: also admits a single value. ``None`` (the default) means it has not
    #: been set.
    blocks: Optional[Sequence[Int] | Int] = None

    #: The cluster dimensions, given as a sequence or a single value.
    #: Defaults to ``(1, 1, 1)``.
    cluster_blocks: Sequence[Int] | Int = (1, 1, 1)

    #: The number of warps per thread block. Must be a compile-time constant.
    #: ``None`` (the default) means it has not been set.
    warps: Optional[int] = None
[docs]
class Script(InstructionInterface):
    """A script is a user-defined kernel function that can be compiled and executed on the GPU.

    Constructing a script class does not hand back an instance of that class:
    ``__new__`` returns the ``InstantiatedScript`` obtained from
    ``InstantiatedScriptCache.get`` for the class and the constructor arguments.
    """

    # When set, the compiled program prints the instruction output of this block.
    debug_block: Optional[tuple[int, int, int]] = None

    # Schedule to use while debugging; it overrides any autotune space.
    debug_schedule: Optional[dict[str, Any]] = None

    def __new__(cls, *args, **kwargs) -> InstantiatedScript:  # type: ignore[no-untyped-def, misc]
        """Return the instantiated script for ``cls`` and the given constructor arguments."""
        # Imported here rather than at module level (the module-level import of
        # this module is guarded by TYPE_CHECKING).
        from tilus.lang.instantiated_script import InstantiatedScriptCache

        return InstantiatedScriptCache.get(script_cls=cls, script_args=args, script_kwargs=kwargs)

    def __init__(self) -> None:
        """Initialize the instruction interface, the kernel attributes and the module handles."""
        super().__init__()

        # Kernel attributes, exposed through the ``attrs`` property.
        self._attrs: Attributes = Attributes()

        # Modules reachable from the script body.
        self.cuda = cuda

    def __call__(self, *args, **kwargs):
        """Placeholder for the kernel body; this base implementation always raises.

        Raises
        ------
        RuntimeError
            Always.
        """
        message = "This method should never be called."
        raise RuntimeError(message)

    def jit_instance_for(self, *args: object, **kwargs: object) -> JitInstance:
        """
        Instantiate the script program with the specified arguments and keyword arguments.

        This base implementation only raises; the error message points to
        ``InstantiatedScript.jit_instance`` instead.

        Parameters
        ----------
        args:
            The positional arguments to the __call__ method.
        kwargs:
            The keyword arguments to the __call__ method.

        Returns
        -------
        ret: JitInstance
            The JIT instance for the script with given arguments.

        Raises
        ------
        RuntimeError
            Always.
        """
        message = "This method should never be called. See InstantiatedScript.jit_instance instead."
        raise RuntimeError(message)

    # The following properties should only be accessed in the __call__ function.
    @property
    def attrs(self) -> Attributes:
        """Kernel attributes like number of blocks and warps.

        See :py:class:`Attributes <tilus.lang.Attributes>` for more details.
        """
        return self._attrs
# the following functions should only be called in the __call__ function to construct the script program
T = TypeVar("T")


def autotune(arg_names: str, arg_values: Sequence[Any]) -> Callable[[T], T]:
    """Annotate an autotune subspace for a tilus script.

    The subspace is recorded in the ``_autotune_space`` dictionary of the
    decorated class, keyed by the ``arg_names`` string exactly as given.

    Parameters
    ----------
    arg_names: str
        The names of the arguments for autotuning, separated by commas.
    arg_values: Sequence[Any]
        The sequence of the choices for the autotune parameters. Each choice can be a single value or a sequence of
        values that match the names in `arg_names`.

    Returns
    -------
    ret: Callable[[Type[Script]], Type[Script]]
        The decorator that can be applied to a tilus script class for the marking of autotune parameters.

    Raises
    ------
    RuntimeError
        Raised by the decorator when a name in ``arg_names`` is repeated, or was already
        specified by an autotune space recorded on the class (including one inherited
        from a base class).
    TypeError
        Raised by the decorator when ``arg_values`` is not a sequence, or when several
        names are given and a choice is not a sequence with one value per name.
    """

    def decorator(script_cls: T) -> T:
        names = [name.strip() for name in arg_names.split(",")]

        # Work on a copy: attribute lookup on a class also finds a space defined on a
        # base class, and mutating that dictionary in place would leak the parameters
        # of the decorated class into the base class and into its other subclasses.
        space: dict[str, Sequence[Any]] = dict(getattr(script_cls, "_autotune_space", {}))

        # check names and arg_values
        # 1. can not define the same name more than once. The keys of the space are the
        #    raw comma-separated strings, so split them to compare individual names.
        registered = {part.strip() for key in space for part in key.split(",")}
        repeated = {name for name in names if names.count(name) > 1}
        common_names = (set(names) & registered) | repeated
        if common_names:
            raise RuntimeError("Duplicated specification for parameters: {}".format(common_names))

        # 2. the arg_values should match the names during unpacking
        if not isinstance(arg_values, Sequence):
            raise TypeError("The arg_values values must be a sequence")
        if len(names) > 1:
            for arg_value in arg_values:
                if not isinstance(arg_value, Sequence) or len(arg_value) != len(names):
                    raise TypeError(
                        "Can not unpack the arg_values for arg_names\n"
                        f" arg_names: {arg_names}\n"
                        f" arg_value: {arg_value}"
                    )

        # Only publish the space once every check has passed.
        space[arg_names] = arg_values
        setattr(script_cls, "_autotune_space", space)
        return script_cls

    return decorator