## SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.# SPDX-License-Identifier: Apache-2.0## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.#importmathfromtypingimportAny,Sequence,Optionalfromtripyimportexportfromtripy.frontend.tensorimportTensorfromtripy.utilsimportResult
@export.public_api(document_under="modules", autodoc_options=[":no-members:", ":no-special-members:"])
class Parameter(Tensor):
    """
    A Parameter is a :class:`tripy.Tensor` that the compiler treats as a constant,
    which opens up additional optimization opportunities.
    """

    def __init__(self, tensor: Any) -> None:
        """
        Args:
            tensor: The value held by this parameter. Inputs given in an external data format
                (e.g., a Numpy array) are first converted into a Tripy Tensor.

        .. code-block:: python
            :linenos:
            :caption: Example

            parameter = tp.Parameter(tp.Tensor([1.0, 1.0], dtype=tp.float32))

            assert isinstance(parameter, tp.Parameter)
            assert isinstance(parameter, tp.Tensor)
        """
        # Anything that is not already a Tensor (including other DLPack-capable objects)
        # is wrapped in one, so there is always a Tensor to take state from.
        source = tensor if isinstance(tensor, Tensor) else Tensor(tensor)

        # Tensor.__init__ is not run on `self`; instead the source tensor's attribute
        # dictionary is adopted. The dict object is shared (not copied), so this
        # Parameter and `source` see the same instance attributes.
        self.__dict__ = source.__dict__

    def _is_compatible_helper(self, original_shape, other_shape, original_dtype, other_dtype) -> Result:
        """
        Checks a candidate shape/dtype pair against a reference shape/dtype pair.

        Shapes are compared element-wise after converting both to lists. The shape
        check runs first, so when both shape and dtype differ only the shape
        mismatch is reported.

        Returns:
            ``Result.ok()`` when both match; otherwise ``Result.err`` with a list of
            message fragments naming the new and the current value.
        """
        shapes_match = list(original_shape) == list(other_shape)
        if not shapes_match:
            return Result.err(
                ["New parameter shape: ", other_shape, " is not compatible with current shape: ", original_shape]
            )

        dtypes_differ = original_dtype != other_dtype
        if dtypes_differ:
            return Result.err(
                ["New parameter dtype: ", other_dtype, " is not compatible with current dtype: ", original_dtype]
            )

        return Result.ok()

    def _is_compatible(self, other: "Parameter") -> Result:
        """
        Reports whether ``other`` has the same shape and data type as this parameter,
        using this parameter's ``shape``/``dtype`` as the reference.
        """
        return self._is_compatible_helper(
            self.shape,
            other.shape,
            self.dtype,
            other.dtype,
        )
class DefaultParameter(Parameter):
    """
    Acts just like a :class:`Parameter`, except that shape/dtype compatibility checks
    use the values recorded at construction instead of materializing data.
    Handy for initializing module parameters.
    """

    def __init__(self, shape: Optional[Sequence[int]], dtype: "tripy.dtype") -> None:
        """
        Args:
            shape: Shape of the placeholder data, or ``None`` when the shape is not known
                yet. With ``None``, an empty shape ``[]`` is used for the placeholder
                and the shape comparison in ``_is_compatible`` is bypassed.
            dtype: Data type of the placeholder data, also recorded for later dtype checks.
        """
        from tripy.frontend.ops.tensor_initializers import arange
        from tripy.frontend.trace.ops.reshape import reshape

        shape_known = shape is not None
        if not shape_known:
            # math.prod([]) == 1, so the placeholder below holds a single element.
            shape = []

        # Placeholder contents: arange over prod(shape) elements, reshaped to `shape`.
        placeholder = reshape(arange(math.prod(shape), dtype), shape)
        super().__init__(placeholder)

        # Recorded metadata consulted by `_is_compatible` instead of the tensor data.
        self._shape = shape
        self._dtype = dtype
        self._is_shape_known = shape_known

    def _is_compatible(self, other: "Parameter") -> Result:
        """
        Checks ``other`` against the shape/dtype recorded at construction.

        If no shape was given at construction, ``other`` is compared against its own
        shape, so only a dtype mismatch can make the check fail.
        """
        reference_shape = self._shape if self._is_shape_known else other.shape
        return self._is_compatible_helper(reference_shape, other.shape, self._dtype, other.dtype)