Embedding

class tripy.Embedding(num_embeddings: int, embedding_dim: int, dtype: dtype = float32)

Bases: Module

A lookup table for embedding vectors of a fixed size. Embedding vectors can be retrieved by their indices.

Parameters:
  • num_embeddings (int) – Number of embedding vectors in the lookup table.

  • embedding_dim (int) – Size of each embedding vector in the lookup table.

  • dtype (dtype) – The data type to use for the weight parameter.

Example
import tripy as tp  # assumed import alias; the page only shows `tp`

embedding = tp.Embedding(num_embeddings=4, embedding_dim=6)

input = tp.Tensor([0, 2], dtype=tp.int32)
output = embedding(input)
>>> embedding.state_dict()
{
    weight: tensor(
        [[0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000],
         [6.0000, 7.0000, 8.0000, 9.0000, 10.0000, 11.0000],
         [12.0000, 13.0000, 14.0000, 15.0000, 16.0000, 17.0000],
         [18.0000, 19.0000, 20.0000, 21.0000, 22.0000, 23.0000]], 
        dtype=float32, loc=gpu:0, shape=(4, 6)),
}
>>> input
tensor([0, 2], dtype=int32, loc=gpu:0, shape=(2,))
>>> output
tensor(
    [[0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000],
     [12.0000, 13.0000, 14.0000, 15.0000, 16.0000, 17.0000]], 
    dtype=float32, loc=gpu:0, shape=(2, 6))
dtype: dtype

The data type used to perform the operation.

weight: Parameter

The embedding lookup table of shape \([\text{num_embeddings}, \text{embedding_dim}]\).
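For intuition about why the table has one row per index, a lookup is mathematically the same as multiplying one-hot encoded indices by the table. The sketch below illustrates this with NumPy rather than tripy. The variable names and the sequentially filled stand-in table are assumptions chosen to mirror the example above, and nothing here describes how tripy implements the operation.

```python
import numpy as np

num_embeddings, embedding_dim = 4, 6

# Stand-in for the weight parameter: one row per index, filled with 0..23
# (illustrative values only).
weight = np.arange(num_embeddings * embedding_dim, dtype=np.float32).reshape(
    num_embeddings, embedding_dim
)

indices = np.array([0, 2], dtype=np.int32)

# One-hot encode the indices: shape (N, num_embeddings).
one_hot = np.eye(num_embeddings, dtype=np.float32)[indices]

# Multiplying by the table selects the matching rows ...
via_matmul = one_hot @ weight
# ... which is the same result as indexing the rows directly.
via_indexing = weight[indices]

assert np.array_equal(via_matmul, via_indexing)
```

Direct indexing avoids building the one-hot matrix, which is why an embedding is usually described as a lookup table rather than a matrix product.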

__call__(x: Tensor) -> Tensor
Parameters:

x (Tensor) – A tensor of shape \([N]\) containing the indices of the desired embedding vectors.

Returns:

A tensor of shape \([N, \text{embedding_dim}]\) containing the embedding vectors.

Return type:

Tensor
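The shape contract described above can be sketched as a plain NumPy reference. This is an illustrative stand-in, not tripy's implementation: `embedding_lookup_reference` is a hypothetical helper, and the input checks are assumptions added for clarity rather than documented tripy behavior.

```python
import numpy as np


def embedding_lookup_reference(weight: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy reference: gather rows of `weight` selected by `x`."""
    # Assumed checks for this sketch; tripy's own validation is not documented here.
    if x.ndim != 1:
        raise ValueError("x must have shape [N]")
    if not np.issubdtype(x.dtype, np.integer):
        raise TypeError("x must contain integer indices")
    # Row gather: the result has shape [N, embedding_dim].
    return weight[x]


# Stand-in table of shape [num_embeddings, embedding_dim] = [4, 6].
weight = np.arange(24, dtype=np.float32).reshape(4, 6)
x = np.array([0, 2], dtype=np.int32)

out = embedding_lookup_reference(weight, x)
print(out.shape)
```

With these stand-in values, `out` holds rows 0 and 2 of the table, matching the `output` tensor shown in the example.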