# Source code for tilus.lang.instructions.clc

# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tilus.hidet.ir.expr import Expr, Var
from tilus.ir.tensor import RegisterTensor, SharedTensor
from tilus.lang.constructs.structs import Dim3

from .root import InstructionGroup


class ClusterLaunchControlInstructionGroup(InstructionGroup):
    """Cluster Launch Control (CLC) instructions for dynamic work scheduling on Blackwell GPUs.

    CLC lets a running kernel **cancel clusters that have not yet started**,
    which enables dynamic grid scheduling and work-stealing patterns. A
    scheduler CTA asks the hardware to cancel a pending cluster and, when the
    request succeeds, takes over that cluster's work itself.

    The workflow is:

    1. ``try_cancel()`` — asynchronously request cancellation. An opaque
       16-byte response is written to shared memory, tracked by an mbarrier.
    2. ``mbarrier.wait()`` — wait for the response to arrive.
    3. ``query_response()`` — decode the response to learn whether the
       cancellation succeeded and, if so, the canceled cluster's CTA
       coordinates.

    On success the scheduler uses the returned CTA ID to compute the work tile
    the canceled cluster would have processed and executes that work itself.
    On failure (the cluster already started) the scheduler may retry or
    continue with other work. With ``multicast=True`` the response is
    broadcast to every CTA in the requesting cluster, so each CTA can query
    the result independently.
    """

    def try_cancel(self, response: SharedTensor, mbarrier: Expr | RegisterTensor, multicast: Expr | bool) -> None:
        """Request cancellation of a cluster that has not yet been launched.

        Asynchronously requests cancellation of a not-yet-running cluster and
        writes an opaque 16-byte response into shared memory describing the
        outcome. Completion of the asynchronous write is tracked through the
        supplied mbarrier: issuing the instruction performs an mbarrier arrive
        combined with an expect-tx of 16 bytes, and the eventual response
        write performs the matching complete-tx of 16 bytes.

        On success, the response carries the CTA ID of the first CTA of the
        canceled cluster; no other successful ``try_cancel`` response in the
        same grid will ever carry that ID. Decode the response with
        :meth:`query_response`.

        Important: once the executing CTA has observed a ``try_cancel`` as
        failed, issuing a further ``try_cancel`` is undefined behavior.

        Parameters
        ----------
        response: SharedTensor
            Naturally aligned, 16-byte wide shared memory tensor receiving
            the request's response. Must live in the .shared::cta state space.
        mbarrier: Expr | RegisterTensor
            The mbarrier tracking completion of the asynchronous operation.
            Its tx-count is raised by 16 bytes at issue time and lowered by
            16 bytes when the response write lands, allowing the mbarrier to
            advance phase once tx-count and pending arrivals both reach zero.
        multicast: Expr | bool
            If True, the response is written with weak async-proxy writes to
            the corresponding local shared memory address of every CTA in the
            requesting cluster, and each CTA's mbarrier receives its own
            arrive + expect-tx (16 bytes) and complete-tx signaling. Multicast
            requires at least 32 threads in the current thread group and is
            undefined if any CTA in the cluster has exited. If False, only
            the calling CTA's shared memory and mbarrier are involved.

        Notes
        -----
        - **Thread group**: any size (at least 32 threads when
          ``multicast=True``).
        - **Hardware**: compute capability 10.0+ (sm_100).
        - **PTX**: ``clusterlaunchcontrol.try_cancel``
        """
        self._builder.cluster_launch_control_try_cancel(response, mbarrier, multicast)

    def query_response(self, response: SharedTensor) -> tuple[Var, Dim3]:
        """Query the response from a cluster launch control try_cancel operation.

        Decodes the opaque 16-byte response produced by :meth:`try_cancel`,
        extracting whether the cancellation succeeded and, on success, the
        (x, y, z) coordinates of the first CTA of the canceled cluster. Only
        call this after the mbarrier used in ``try_cancel`` has signaled
        completion; calling earlier is undefined behavior. When the request
        failed, the returned coordinates are undefined and must not be used.

        Parameters
        ----------
        response: SharedTensor
            The same 16-byte shared memory tensor that was passed to
            ``try_cancel``, now holding the opaque response.

        Returns
        -------
        is_canceled: Var
            Predicate variable: True iff the cluster was successfully
            canceled.
        first_cta_id: Dim3
            Coordinates of the first CTA in the canceled cluster; undefined
            when the cancellation failed.

        Notes
        -----
        - **Thread group**: any size.
        - **Hardware**: compute capability 10.0+ (sm_100).
        - **PTX**: ``clusterlaunchcontrol.query_cancel.is_canceled`` and
          ``clusterlaunchcontrol.query_cancel.get_first_ctaid``
        """
        decoded = self._builder.cluster_launch_control_query_response(response)
        # The decoded register tensor holds four scalars in order:
        # (is_canceled, first_cta_x, first_cta_y, first_cta_z).
        is_canceled, cta_x, cta_y, cta_z = (
            self._builder.tensor_item_value(
                self._builder.slice_register(decoded, offsets=[idx], slice_dims=[], slice_shape=[])
            )
            for idx in range(4)
        )
        return is_canceled, Dim3(cta_x, cta_y, cta_z)