where
- nvtripy.where(condition: Tensor, input: Tensor, other: Tensor) -> Tensor
Returns a new tensor of elements selected from either input or other, depending on condition.
- Parameters:
condition (Tensor) – [dtype=T2] The condition tensor. Where this is True, elements are selected from input; otherwise, elements are selected from other.
input (Tensor) – [dtype=T1] Tensor of values selected at indices where condition is True.
other (Tensor) – [dtype=T1] Tensor of values selected at indices where condition is False.
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
- Constraints:
All three parameters must be broadcast-compatible with each other.
- TYPE CONSTRAINTS:
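The broadcasting constraint above only states that the three shapes must be compatible. Assuming nvtripy follows NumPy-style broadcasting (shapes are aligned at their trailing dimensions, and each aligned dimension must either match or be 1), the shape of the returned tensor could be computed with a helper like the hypothetical `broadcast_shapes` below. It is a sketch for illustration only, not part of the nvtripy API.

```python
from typing import Sequence, Tuple


def broadcast_shapes(*shapes: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: NumPy-style broadcast of several shapes.

    Assumes NumPy-style rules; raises ValueError if the shapes are not
    broadcast-compatible.
    """
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    # Walk dimensions right-to-left; shapes with fewer dims are treated as
    # having leading dimensions of size 1, so they are simply skipped here.
    for axis in range(-1, -ndim - 1, -1):
        dims = {s[axis] for s in shapes if len(s) >= -axis}
        non_one = dims - {1}
        if len(non_one) > 1:
            raise ValueError(f"incompatible sizes {sorted(dims)} at axis {axis}")
        # A size-1 dimension stretches to match the single non-1 size, if any.
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))


# condition (2, 1), input (1, 2), other (2, 2)
print(broadcast_shapes((2, 1), (1, 2), (2, 2)))  # (2, 2)
```

Under this assumption, a `(2, 1)` condition combined with a `(1, 2)` input and a `(2, 2)` other yields a `(2, 2)` output, while shapes such as `(2, 3)` and `(4,)` are rejected.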
Example
import nvtripy as tp

condition = tp.Tensor([[True, False], [True, True]])
input = tp.ones([2, 2], dtype=tp.float32)
other = tp.zeros([2, 2], dtype=tp.float32)
# Take elements from `input` where `condition` is True, otherwise from `other`.
output = tp.where(condition, input, other)
>>> condition
tensor(
    [[True, False],
     [True, True]],
    dtype=bool, loc=gpu:0, shape=(2, 2))
>>> input
tensor(
    [[1.0000, 1.0000],
     [1.0000, 1.0000]],
    dtype=float32, loc=gpu:0, shape=(2, 2))
>>> other
tensor(
    [[0.0000, 0.0000],
     [0.0000, 0.0000]],
    dtype=float32, loc=gpu:0, shape=(2, 2))
>>> output
tensor(
    [[1.0000, 0.0000],
     [1.0000, 1.0000]],
    dtype=float32, loc=gpu:0, shape=(2, 2))
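The per-element rule in the example can be mirrored in plain Python to check the expected output without a GPU. The `where_2d` function below is a hypothetical reference written for illustration, not part of nvtripy; it assumes all three arguments are nested lists of the same 2D shape (no broadcasting) and applies `output[i][j] = input[i][j] if condition[i][j] else other[i][j]`.

```python
from typing import List, TypeVar

T = TypeVar("T")


def where_2d(
    condition: List[List[bool]],
    input_: List[List[T]],
    other: List[List[T]],
) -> List[List[T]]:
    """Hypothetical reference for tp.where on same-shaped 2D nested lists."""
    if not (len(condition) == len(input_) == len(other)):
        raise ValueError("row counts differ; this sketch does not broadcast")
    result = []
    for c_row, in_row, other_row in zip(condition, input_, other):
        if not (len(c_row) == len(in_row) == len(other_row)):
            raise ValueError("column counts differ; this sketch does not broadcast")
        # Pick from input_ where the condition is True, otherwise from other.
        result.append([x if c else y for c, x, y in zip(c_row, in_row, other_row)])
    return result


condition = [[True, False], [True, True]]
ones = [[1.0, 1.0], [1.0, 1.0]]
zeros = [[0.0, 0.0], [0.0, 0.0]]
print(where_2d(condition, ones, zeros))  # [[1.0, 0.0], [1.0, 1.0]]
```

The printed result matches the `output` tensor shown above: only the single False position takes its value from `other`.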