## SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.# SPDX-License-Identifier: Apache-2.0## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.#fromnvtripyimportexportfromnvtripy.commonimportdatatypeasdtfromnvtripy.frontendimportwrappersfromnvtripy.frontend.constraintsimportGetInput,GetReturn,OneOffromnvtripy.frontend.opsimportutilsasop_utilsfromnvtripy.frontend.ops._registryimportregister_tensor_methodfromnvtripy.frontend.ops.dequantizeimportdequantizefromnvtripy.frontend.ops.quantizeimportquantizefromnvtripy.trace.ops.castimportCast
@register_tensor_method("cast")
@export.public_api(document_under="operations/functions")
@wrappers.interface(
    input_requirements=(
        ((GetInput("input").dtype != dt.float8) | ~OneOf(GetInput("dtype"), [dt.int4, dt.int8]))
        & ((GetInput("input").dtype != dt.int8) | (GetInput("dtype") != dt.float8))
        & ((GetInput("input").dtype != dt.int4) | ~OneOf(GetInput("dtype"), [dt.float8, dt.int8, dt.int64]))
    ),
    output_guarantees=GetReturn(0).dtype == GetInput("dtype"),
)
def cast(input: "nvtripy.Tensor", dtype: "nvtripy.dtype") -> "nvtripy.Tensor":
    r"""
    Returns a tensor with the contents of the input tensor casted to the specified data type.

    For casts into quantized datatypes (:class:`int4` and :class:`float8`), this performs a per-tensor
    quantization into that datatype with scale 1.0; for casts `from` those datatypes, this performs
    a per-tensor dequantization with scale 1.0. Direct use of :func:`quantize` and :func:`dequantize`
    allows for finer control over these parameters.

    Args:
        input: The input tensor.
        dtype: The desired data type.

    Returns:
        A tensor containing the casted values.

    .. code-block:: python
        :linenos:

        input = tp.Tensor([1, 2])
        output = tp.cast(input, tp.float32)

        assert np.array_equal(cp.from_dlpack(output).get(), np.array([1, 2], dtype=np.float32))

    .. seealso:: :func:`quantize`, :func:`dequantize`
    """
    # Nothing to do when the requested type already matches the input's type.
    if dtype == input.dtype:
        return input

    # int8 is technically a quantized dtype, but MLIR-TRT handles it in ordinary
    # conversions, so it is exempted from the dequantize/quantize round-trips below.
    # The one exception is casting int8 -> bool, which compiles into an MLIR-TRT
    # *comparison* op and therefore still needs an explicit dequantize first.
    must_dequantize = op_utils.is_quantized_dtype(input.dtype) and (
        dtype == dt.bool or input.dtype != dt.int8
    )
    if must_dequantize:
        input = dequantize(input, 1.0, dt.float32)
        # Dequantizing already produced the requested type; no further cast needed.
        if dtype == dt.float32:
            return input

    # Casts *into* a quantized dtype (other than int8) go through quantize, which
    # expects a float32 operand.
    if op_utils.is_quantized_dtype(dtype) and dtype != dt.int8:
        if input.dtype != dt.float32:
            input = op_utils.create_op(Cast, [input], dt.float32)
        return quantize(input, 1.0, dtype)

    # Ordinary conversion for all remaining dtype pairs.
    return op_utils.create_op(Cast, [input], dtype)