gather
- tripy.gather(input: Tensor, dim: int, index: Tensor) → Tensor
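Gather values from the input tensor along the specified axis based on the specified indices. This behaves similarly to numpy.take().
- Parameters:
  - input (Tensor) [dtype=T1]: The tensor to gather values from.
  - dim (int): The axis along which to gather.
  - index (Tensor): The indices of the values to gather along dim.
- Returns:
  [dtype=T1] A new tensor with the same shape as input along every dimension except dim, which will have a size equal to len(index).
- Return type:
  Tensor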
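The selection rule above can be sketched in plain Python on nested lists. This is a hypothetical reference helper, not tripy's implementation: `gather_nested` and `shape_of` are illustrative names, and the sketch assumes a rectangular nested list, `0 <= dim < rank`, and non-negative in-range indices. It picks `index[0], index[1], ...` along axis `dim`, much like `numpy.take(data, index, axis=dim)` with a 1-D index.

```python
from typing import Any, Sequence, Tuple


def gather_nested(data: Any, dim: int, index: Sequence[int]) -> Any:
    """Hypothetical helper: select index[0], index[1], ... along axis `dim`.

    Assumes `data` is a rectangular nested list, 0 <= dim < rank,
    and every index is a valid non-negative position along `dim`.
    """
    if dim == 0:
        # At the target axis, take the requested entries in index order
        # (repeated indices are allowed and simply repeat the entry).
        return [data[i] for i in index]
    # Above the target axis, recurse into every sub-list, keeping its length.
    return [gather_nested(sub, dim - 1, index) for sub in data]


def shape_of(data: Any) -> Tuple[int, ...]:
    """Hypothetical helper: return the shape of a rectangular nested list."""
    shape = []
    while isinstance(data, list):
        shape.append(len(data))
        data = data[0] if data else None
    return tuple(shape)


matrix = [[10, 11, 12],
          [20, 21, 22]]

by_column = gather_nested(matrix, 1, [2, 0])
print(by_column)            # [[12, 10], [22, 20]]
print(shape_of(by_column))  # (2, 2)

by_row = gather_nested(matrix, 0, [1, 1, 0])
print(by_row)               # [[20, 21, 22], [20, 21, 22], [10, 11, 12]]
print(shape_of(by_row))     # (3, 3)
```

Because the selection only happens at depth `dim`, every other axis keeps its original length and axis `dim` ends up with `len(index)` entries, which is the shape rule stated under Returns.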
Example
import tripy as tp

data = tp.iota((3, 3, 2))
indices = tp.Tensor([0, 2], dtype=tp.int32)
output = tp.gather(data, 1, indices)
>>> data
tensor(
    [[[0.0000, 0.0000],
      [0.0000, 0.0000],
      [0.0000, 0.0000]],
     [[1.0000, 1.0000],
      [1.0000, 1.0000],
      [1.0000, 1.0000]],
     [[2.0000, 2.0000],
      [2.0000, 2.0000],
      [2.0000, 2.0000]]],
    dtype=float32, loc=gpu:0, shape=(3, 3, 2))
>>> indices
tensor([0, 2], dtype=int32, loc=gpu:0, shape=(2,))
>>> output
tensor(
    [[[0.0000, 0.0000],
      [0.0000, 0.0000]],
     [[1.0000, 1.0000],
      [1.0000, 1.0000]],
     [[2.0000, 2.0000],
      [2.0000, 2.0000]]],
    dtype=float32, loc=gpu:0, shape=(3, 2, 2))
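To make the indexing in this example explicit, the sketch below rebuilds it with plain Python lists instead of tripy, so no GPU is needed. It assumes, as the printed `data` shows, that every element of `data` equals its index along axis 0; gathering along `dim=1` with `indices = [0, 2]` then means `output[i][j][k] = data[i][indices[j]][k]`.

```python
# Mirrors the printed `data` above: element [i][j][k] equals float(i), shape (3, 3, 2).
data = [[[float(i) for _k in range(2)] for _j in range(3)] for i in range(3)]
indices = [0, 2]

# Gather along dim=1: axis 1 is replaced by the rows named in `indices`;
# axes 0 and 2 keep their lengths.
output = [
    [[data[i][indices[j]][k] for k in range(2)] for j in range(len(indices))]
    for i in range(3)
]

print(output)
# [[[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]], [[2.0, 2.0], [2.0, 2.0]]]

shape = (len(output), len(output[0]), len(output[0][0]))
print(shape)  # (3, 2, 2)
```

The resulting values and shape `(3, 2, 2)` agree with the `output` tensor printed above.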