gather

tripy.gather(input: Tensor, dim: int, index: Tensor) → Tensor

Gathers values from input along axis dim, selecting the positions given by index. This behaves similarly to numpy.take().
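Because the behavior is described in terms of numpy.take(), here is a small NumPy sketch of that reference semantics. It uses NumPy only as an illustration of "gather along an axis" and is not tripy's implementation. The arrays x and idx are hypothetical.

```python
import numpy as np

# Hypothetical 3x3 input: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
x = np.arange(9).reshape(3, 3)
idx = np.array([2, 0])

# Gathering along axis 0 picks whole rows, in the order given by idx.
rows = np.take(x, idx, axis=0)  # [[6, 7, 8], [0, 1, 2]]

# Gathering along axis 1 picks whole columns, in the order given by idx.
cols = np.take(x, idx, axis=1)  # [[2, 0], [5, 3], [8, 6]]
```

Note that the gathered axis shrinks (or grows) to len(idx) while every other axis keeps its size.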

Parameters:
  • input (Tensor) – [dtype=T1] The input tensor.

  • dim (int) – Axis along which data is gathered.

  • index (Tensor) – [dtype=T2] The indices of the elements to gather along dim.

Returns:

[dtype=T1] A new tensor of the same shape along every dimension except dim, which will have a size equal to len(index).

Return type:

Tensor
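The shape rule in the return description can be written as a tiny helper. This is a hypothetical sketch (gather_output_shape is not a tripy function) and it assumes index is 1-D, which is the case the len(index) wording covers.

```python
def gather_output_shape(input_shape, dim, num_indices):
    """Hypothetical helper: shape of gather's result for a 1-D index.

    Every dimension of input_shape is kept except `dim`, which is
    replaced by the number of indices (len(index)).
    """
    shape = list(input_shape)
    # Python list indexing also accepts a negative dim (e.g. -1 for the last axis).
    shape[dim] = num_indices
    return tuple(shape)
```

For instance, gather_output_shape((3, 3, 2), 1, 2) gives (3, 2, 2), which matches the shape reported for output in the example.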

TYPE CONSTRAINTS:
Example
import tripy as tp

data = tp.iota((3, 3, 2))
indices = tp.Tensor([0, 2], dtype=tp.int32)
output = tp.gather(data, 1, indices)
>>> data
tensor(
    [[[0.0000, 0.0000],
      [0.0000, 0.0000],
      [0.0000, 0.0000]],

     [[1.0000, 1.0000],
      [1.0000, 1.0000],
      [1.0000, 1.0000]],

     [[2.0000, 2.0000],
      [2.0000, 2.0000],
      [2.0000, 2.0000]]], 
    dtype=float32, loc=gpu:0, shape=(3, 3, 2))
>>> indices
tensor([0, 2], dtype=int32, loc=gpu:0, shape=(2,))
>>> output
tensor(
    [[[0.0000, 0.0000],
      [0.0000, 0.0000]],

     [[1.0000, 1.0000],
      [1.0000, 1.0000]],

     [[2.0000, 2.0000],
      [2.0000, 2.0000]]], 
    dtype=float32, loc=gpu:0, shape=(3, 2, 2))
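The printed result can be cross-checked with a NumPy reconstruction of the same example. This sketch assumes, based on the printed data above, that tp.iota((3, 3, 2)) fills each element with its index along dim 0; it rebuilds that array with broadcasting and uses numpy.take as the reference, not tripy itself.

```python
import numpy as np

# Rebuild `data`: each element equals its index along dim 0 (per the printed output).
data = np.broadcast_to(np.arange(3, dtype=np.float32)[:, None, None], (3, 3, 2))
indices = np.array([0, 2], dtype=np.int32)

# Gather along dim 1, mirroring tp.gather(data, 1, indices).
output = np.take(data, indices, axis=1)

# Each output[:, j, :] is the slice data[:, indices[j], :].
for j, i in enumerate(indices):
    assert np.array_equal(output[:, j, :], data[:, i, :])
```

Since every row along dim 1 of data is identical here, the gathered output keeps the values 0, 1, 2 along dim 0, just as in the tripy output.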