where

tripy.where(condition: Tensor, input: Tensor, other: Tensor) → Tensor

Returns a new tensor of elements selected from either input or other, depending on condition.

Parameters:
  • condition (Tensor) – [dtype=T2] The condition tensor. Where this is True, elements are selected from input. Otherwise, elements are selected from other.

  • input (Tensor) – [dtype=T1] Tensor of values selected at indices where condition is True.

  • other (Tensor) – [dtype=T1] Tensor of values selected at indices where condition is False.
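
The selection rule described by these parameters can be sketched in plain Python on nested lists. This is only an illustrative reference, not tripy's implementation: the helper name where_reference is hypothetical, and the sketch assumes all three arguments already have the same shape, so no broadcasting is performed.

```python
def where_reference(condition, input, other):
    """Hypothetical pure-Python sketch of elementwise selection.

    Assumes all three arguments are nested lists of identical shape.
    The parameter name `input` mirrors the documented signature and
    shadows the Python built-in inside this function only.
    """
    if isinstance(condition, list):
        # Walk each dimension of the three arguments in lockstep.
        return [
            where_reference(c, x, y)
            for c, x, y in zip(condition, input, other)
        ]
    # Innermost level: take from `input` where the condition is True,
    # otherwise take from `other`.
    return input if condition else other


condition = [[True, False], [True, True]]
ones = [[1.0, 1.0], [1.0, 1.0]]
zeros = [[0.0, 0.0], [0.0, 0.0]]
result = where_reference(condition, ones, zeros)
```

Here result holds a value from ones wherever condition is True and a value from zeros elsewhere.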

Returns:

[dtype=T1] A new tensor whose shape is the broadcasted shape of condition, input, and other.

Return type:

Tensor

Constraints:

All three parameters must be broadcast-compatible with each other.
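
As a rough illustration of what broadcast compatibility means, the sketch below computes the broadcasted shape of several shapes in plain Python. It assumes the NumPy-style rule (shapes are aligned on the right, and in each dimension the sizes must be equal or one of them must be 1); whether tripy follows this rule in every detail is an assumption here, and broadcast_shape is a hypothetical helper, not part of the tripy API.

```python
def broadcast_shape(*shapes):
    """Hypothetical helper: broadcasted shape under NumPy-style rules."""
    ndim = max(len(shape) for shape in shapes)
    result = []
    for axis in range(ndim):
        dim = 1
        for shape in shapes:
            # Align shapes on the right; missing leading dims count as 1.
            idx = axis - (ndim - len(shape))
            size = shape[idx] if idx >= 0 else 1
            if size == 1:
                continue  # Size 1 stretches to match any other size.
            if dim != 1 and dim != size:
                raise ValueError(
                    f"shapes {shapes} are not broadcast-compatible"
                )
            dim = size
        result.append(dim)
    return tuple(result)


# Shapes of condition, input, and other in a hypothetical call.
out_shape = broadcast_shape((2, 1), (1, 3), (3,))
```

Under these assumptions, a call whose three arguments have shapes (2, 1), (1, 3), and (3,) would produce an output of shape (2, 3), while shapes such as (2,) and (3,) would be rejected.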

TYPE CONSTRAINTS:
Example
import tripy as tp  # the examples refer to the tripy module as tp

condition = tp.Tensor([[True, False], [True, True]])
input = tp.ones([2, 2], dtype=tp.float32)
other = tp.zeros([2, 2], dtype=tp.float32)
output = tp.where(condition, input, other)
>>> condition
tensor(
    [[True, False],
     [True, True]], 
    dtype=bool, loc=gpu:0, shape=(2, 2))
>>> input
tensor(
    [[1.0000, 1.0000],
     [1.0000, 1.0000]], 
    dtype=float32, loc=gpu:0, shape=(2, 2))
>>> other
tensor(
    [[0.0000, 0.0000],
     [0.0000, 0.0000]], 
    dtype=float32, loc=gpu:0, shape=(2, 2))
>>> output
tensor(
    [[1.0000, 0.0000],
     [1.0000, 1.0000]], 
    dtype=float32, loc=gpu:0, shape=(2, 2))