where
- nvtripy.where(condition: Tensor, input: Tensor, other: Tensor) -> Tensor
Returns a new tensor whose elements are selected from either input or other, depending on condition.
- Parameters:
  - condition (Tensor) – [dtype=T2] The condition tensor. Where this is True, elements are selected from input; otherwise, elements are selected from other.
  - input (Tensor) – [dtype=T1] Tensor of values selected at indices where condition is True.
  - other (Tensor) – [dtype=T1] Tensor of values selected at indices where condition is False.
- Returns:
  [dtype=T1] A new tensor whose shape is the broadcast of the shapes of condition, input, and other.
- Return type:
  Tensor
- Constraints:
  All three parameters must be broadcast-compatible with each other.
- DATA TYPE CONSTRAINTS:
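The broadcast-compatibility constraint and the "broadcast shape" of the result can be illustrated with a small pure-Python sketch. It assumes that nvtripy follows NumPy-style broadcasting: shapes are right-aligned, and at each position the sizes must either agree or be 1. The helper `broadcast_shape` is hypothetical and is not part of the nvtripy API.

```python
def broadcast_shape(*shapes):
    """Return the broadcast of the given shapes, assuming NumPy-style rules.

    Hypothetical helper for illustration; not part of nvtripy.
    """
    rank = max((len(s) for s in shapes), default=0)
    # Right-align every shape by left-padding it with size-1 dimensions.
    padded = [(1,) * (rank - len(s)) + tuple(s) for s in shapes]
    result = []
    for dims in zip(*padded):
        # Size-1 dimensions stretch; all other sizes at this position must agree.
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcast-compatible")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


# Shapes of (condition, input, other) -> shape of the result.
print(broadcast_shape((2, 2), (2, 2), (2, 2)))  # (2, 2)
print(broadcast_shape((2, 1), (1, 3), ()))      # (2, 3)
```

Under these assumed rules, a `(2, 3)` condition paired with a `(2,)` input would be rejected, because the trailing sizes 3 and 2 differ and neither is 1.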
Example
```python
import nvtripy as tp

condition = tp.Tensor([[True, False], [True, True]])
input = tp.ones([2, 2], dtype=tp.float32)
other = tp.zeros([2, 2], dtype=tp.float32)
output = tp.where(condition, input, other)
```
Local Variables
```
>>> condition
tensor(
    [[True, False],
     [True, True]],
    dtype=bool, loc=cpu:0, shape=(2, 2))
>>> input
tensor(
    [[1, 1],
     [1, 1]],
    dtype=float32, loc=gpu:0, shape=(2, 2))
>>> other
tensor(
    [[0, 0],
     [0, 0]],
    dtype=float32, loc=gpu:0, shape=(2, 2))
>>> output
tensor(
    [[1, 0],
     [1, 1]],
    dtype=float32, loc=gpu:0, shape=(2, 2))
```
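To make the element-wise selection concrete, the following pure-Python sketch reproduces the example's output using nested lists instead of nvtripy tensors. It is only a reference model of the semantics, not how nvtripy computes `where`. It assumes NumPy-style broadcasting and treats a bare Python value as a rank-0 tensor. All helper names are hypothetical, and `input_` stands in for the `input` parameter so that it does not shadow Python's built-in `input`.

```python
def shape_of(x):
    """Shape of a nested list; a non-list value is treated as rank 0."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        if not x:
            break
        x = x[0]
    return tuple(shape)


def broadcast_shape(*shapes):
    """Right-align shapes; sizes at each position must agree or be 1 (assumed NumPy-style)."""
    rank = max(len(s) for s in shapes)
    padded = [(1,) * (rank - len(s)) + s for s in shapes]
    out = []
    for dims in zip(*padded):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        out.append(sizes.pop() if sizes else 1)
    return tuple(out)


def element_at(x, index):
    """Read x at an output index, reusing position 0 along size-1 dimensions."""
    shape = shape_of(x)
    # Drop the leading output dimensions that x does not have.
    for i, size in zip(index[len(index) - len(shape):], shape):
        x = x[0 if size == 1 else i]
    return x


def build(shape, fn, prefix=()):
    """Create a nested list of the given shape, filling each position with fn(index)."""
    if not shape:
        return fn(prefix)
    return [build(shape[1:], fn, prefix + (i,)) for i in range(shape[0])]


def where(condition, input_, other):
    """Pick input_ where condition is true and other elsewhere, after broadcasting."""
    out_shape = broadcast_shape(shape_of(condition), shape_of(input_), shape_of(other))

    def pick(index):
        source = input_ if element_at(condition, index) else other
        return element_at(source, index)

    return build(out_shape, pick)


condition = [[True, False], [True, True]]
ones = [[1.0, 1.0], [1.0, 1.0]]
zeros = [[0.0, 0.0], [0.0, 0.0]]
print(where(condition, ones, zeros))  # [[1.0, 0.0], [1.0, 1.0]]

# Broadcasting: a (2, 1) condition, a (1, 3) input, and a rank-0 other.
print(where([[True], [False]], [[1, 2, 3]], 0))  # [[1, 2, 3], [0, 0, 0]]
```

The first call matches the `output` shown above. The second call shows each operand being stretched along its size-1 (or missing) dimensions to the broadcast shape `(2, 3)` before selection.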