argmax
- nvtripy.argmax(input: Tensor, dim: int | None = None, keepdim: bool = False) → Tensor
Returns a new tensor containing the indices of maximum values of the input tensor along the specified dimension.
If there are multiple maximum values, then the indices of the first maximum value are returned.
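To make the tie-breaking rule concrete, here is a minimal pure-Python sketch of argmax over a 2-D nested list. It is only an illustration of the semantics described above. `argmax_2d` is a hypothetical helper, not part of nvtripy, and it does not model `dim=None`. A strict `>` comparison means a later equal value never replaces the current best, so the first maximum wins.

```python
from typing import Sequence


def argmax_2d(rows: Sequence[Sequence[float]], dim: int, keepdim: bool = False):
    """Hypothetical reference argmax for a 2-D nested list (not the nvtripy implementation).

    Ties resolve to the first (lowest) index.
    """
    if dim not in (0, 1):
        raise ValueError("this sketch only supports dim 0 or 1")
    n_rows, n_cols = len(rows), len(rows[0])
    if dim == 0:
        # For each column, scan down the rows; strict '>' keeps the first maximum.
        result = []
        for c in range(n_cols):
            best = 0
            for r in range(1, n_rows):
                if rows[r][c] > rows[best][c]:
                    best = r
            result.append(best)
        # keepdim=True keeps the reduced row dimension with size 1: shape (1, n_cols).
        return [result] if keepdim else result
    # dim == 1: for each row, scan across the columns.
    result = []
    for row in rows:
        best = 0
        for c in range(1, n_cols):
            if row[c] > row[best]:
                best = c
        result.append(best)
    # keepdim=True keeps the reduced column dimension with size 1: shape (n_rows, 1).
    return [[i] for i in result] if keepdim else result
```

With the same input as the Example section, `argmax_2d([[1.0, 0.0, 3.0], [0.5, 2.0, 1.5]], 0)` gives `[0, 1, 0]`. For a row with a tie, `argmax_2d([[2.0, 5.0, 5.0]], 1)` gives `[1]`, which is the index of the first `5.0`.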
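- Parameters:
input: The input tensor.
dim: The dimension along which to find the maximum values.
keepdim: Whether to retain the reduced dimension in the output.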
- Returns:
A new tensor.
- Return type:
Tensor
- INPUT REQUIREMENTS:
input.dtype is one of [float32, float16, bfloat16, int32, int64]
- OUTPUT GUARANTEES:
return[0].dtype == int32
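The effect of `dim` and `keepdim` on the output shape can be sketched as below. This assumes the usual reduction convention: the reduced dimension is dropped, or kept with size 1 when `keepdim=True`. `argmax_output_shape` is a hypothetical helper for illustration only. It rejects negative `dim` values because this section does not say whether they are supported.

```python
from typing import Tuple


def argmax_output_shape(shape: Tuple[int, ...], dim: int, keepdim: bool = False) -> Tuple[int, ...]:
    """Hypothetical helper: expected output shape of argmax along `dim`.

    Assumes the common reduction convention (dimension dropped, or size 1 with keepdim).
    """
    if not 0 <= dim < len(shape):
        raise ValueError(f"dim {dim} out of range for rank {len(shape)}")
    if keepdim:
        # Keep the reduced axis in place with size 1.
        return shape[:dim] + (1,) + shape[dim + 1:]
    # Drop the reduced axis entirely.
    return shape[:dim] + shape[dim + 1:]
```

For the `(2, 3)` input in the Example section, reducing over `dim=0` gives `(3,)`, matching the shown output. Under this assumed convention, `keepdim=True` would give `(1, 3)`.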
Example
import nvtripy as tp

input = tp.Tensor([[1.0, 0.0, 3.0], [0.5, 2.0, 1.5]])
output = tp.argmax(input, 0)
Local Variables
>>> input
tensor(
    [[1, 0, 3],
     [0.5, 2, 1.5]],
    dtype=float32, loc=cpu:0, shape=(2, 3))
>>> output
tensor([0, 1, 0], dtype=int32, loc=gpu:0, shape=(3,))