argmax

tripy.argmax(input: Tensor, dim: int | None = None, keepdim: bool = False) → Tensor

Returns a new tensor containing the indices of the maximum values of the input tensor along the specified dimension. If the maximum value occurs more than once, the index of its first occurrence is returned.
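The tie-breaking rule can be illustrated with a small pure-Python sketch. This is not tripy's implementation. The helper `argmax_1d` is hypothetical and only mirrors the documented "first maximum wins" behavior on a flat, non-empty list.

```python
def argmax_1d(values):
    """Hypothetical helper: index of the first maximum in a non-empty list."""
    if not values:
        raise ValueError("argmax of an empty sequence is undefined")
    best = 0
    for i in range(1, len(values)):
        # Strict '>' means a later value equal to the current best never replaces it,
        # so the earliest index of the maximum is kept.
        if values[i] > values[best]:
            best = i
    return best


print(argmax_1d([3.0, 7.0, 1.0, 7.0]))  # → 1 (the first 7.0, not the second)
```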

Parameters:
  • input (Tensor) – [dtype=T1] The input tensor.

  • dim (int | None) – The dimension along which to reduce. If this is not provided, the input is flattened and the index of the maximum element within the flattened tensor is returned.

  • keepdim (bool) – Whether to retain reduced dimensions in the output. If this is False, reduced dimensions will be squeezed.
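The effect of `dim` and `keepdim` can be sketched on a 2-D nested list. The function `argmax_2d` is a hypothetical reference written for illustration, not part of tripy. It assumes row-major flattening when `dim` is None. It also assumes that `keepdim=True` with `dim=None` keeps both dimensions at size 1, as NumPy's `keepdims` does. Negative `dim` values are not handled.

```python
def argmax_2d(rows, dim=None, keepdim=False):
    """Hypothetical reference for argmax on a 2-D nested list (rows x cols)."""

    def first_max(values):
        # list.index returns the first position equal to the maximum,
        # which gives the "first occurrence wins" tie-breaking rule.
        return values.index(max(values))

    if dim is None:
        # Assumption: flatten in row-major order and return one flat index.
        flat = [x for row in rows for x in row]
        idx = first_max(flat)
        # Assumption: keepdim keeps both dimensions with size 1.
        return [[idx]] if keepdim else idx

    if dim == 0:
        # Reduce down each column: one index per column.
        result = [first_max(list(col)) for col in zip(*rows)]
        # keepdim keeps the reduced row dimension as size 1: shape (1, cols).
        return [result] if keepdim else result

    if dim == 1:
        # Reduce across each row: one index per row.
        result = [first_max(row) for row in rows]
        # keepdim keeps the reduced column dimension as size 1: shape (rows, 1).
        return [[i] for i in result] if keepdim else result

    raise ValueError(f"dim must be None, 0, or 1 in this sketch, got {dim}")


data = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
print(argmax_2d(data, dim=0))                # → [1, 1, 1]
print(argmax_2d(data, dim=1, keepdim=True))  # → [[2], [2]]
print(argmax_2d(data))                       # → 5
```

The `dim=0` result matches the tripy example output for the same 2x3 input.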

Returns:

[dtype=T2] A new tensor containing the indices of the maximum values.

Return type:

Tensor

TYPE CONSTRAINTS:
Example
import tripy as tp

input = tp.reshape(tp.arange(6, dtype=tp.float32), (2, 3))
output = tp.argmax(input, 0)
>>> input
tensor(
    [[0.0000, 1.0000, 2.0000],
     [3.0000, 4.0000, 5.0000]], 
    dtype=float32, loc=gpu:0, shape=(2, 3))
>>> output
tensor([1, 1, 1], dtype=int32, loc=gpu:0, shape=(3,))