## SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.# SPDX-License-Identifier: Apache-2.0## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.#fromdataclassesimportdataclassfromtripyimportexport,utilsfromtripy.commonimportdatatypefromtripy.frontend.module.moduleimportModulefromtripy.frontend.module.parameterimportParameter,DefaultParameter
@export.public_api(document_under="operations/modules")
@dataclass
@utils.constant_fields(["dtype"])
class Embedding(Module):
    """
    A lookup table holding a fixed number of equally sized embedding vectors.
    Individual vectors are retrieved by passing their indices to the module.
    """

    dtype: datatype.dtype
    r"""The data type used to perform the operation; also the data type of the weight table"""

    weight: Parameter
    r"""The embedding lookup table of shape :math:`[\text{num_embeddings}, \text{embedding_dim}]`."""

    def __init__(self, num_embeddings: int, embedding_dim: int, dtype: datatype.dtype = datatype.float32) -> None:
        r"""
        Args:
            num_embeddings: How many embedding vectors the lookup table holds.
            embedding_dim: The length of every embedding vector in the lookup table.
            dtype: The data type of the weight parameter.

        .. code-block:: python
            :linenos:
            :caption: Example

            embedding = tp.Embedding(num_embeddings=4, embedding_dim=6)

            input = tp.Tensor([0, 2], dtype=tp.int32)
            output = embedding(input)

            assert np.array_equal(cp.from_dlpack(output).get(), cp.from_dlpack(embedding.weight).get()[[0,2], :])
        """
        super().__init__()
        # One row per embedding vector, one column per embedding component.
        table_shape = (num_embeddings, embedding_dim)
        self.dtype = dtype
        self.weight = DefaultParameter(table_shape, dtype)

    def __call__(self, x: "tripy.Tensor") -> "tripy.Tensor":
        r"""
        Args:
            x: A tensor of shape :math:`[N]` holding the indices of the embedding vectors to look up.

        Returns:
            A tensor of shape :math:`[N, \text{embedding_dim}]` holding the selected embedding vectors.
        """
        # Imported at call time rather than module scope; presumably this avoids an
        # import cycle with the frontend ops package (NOTE(review): confirm).
        from tripy.frontend.trace.ops.gather import gather as _gather_rows

        # Axis 0 of the weight table indexes the embedding vectors, so gathering
        # along it with `x` selects one full row per index.
        lookup_axis = 0
        return _gather_rows(self.weight, lookup_axis, x)