Tensor

class nvtripy.Tensor(data: Any, dtype: dtype | None = None, device: device | None = None, name: str | None = None, fetch_stack_info: bool = True)

Bases: object

A tensor is a multi-dimensional array that contains elements of a uniform data type.

Parameters:
  • data (Any) – The data with which to initialize the tensor. For types that support the DLPack protocol, copying data is avoided if possible.

  • dtype (dtype | None) – The data type of the tensor. If omitted, it is inferred from data.

  • device (device | None) – The device on which to allocate the tensor. If the provided data is not on this device, it will be copied. By default, the tensor will be allocated on the same device as the data argument.

  • name (str | None) – The name of the tensor. If provided, this must be a unique string.

  • fetch_stack_info (bool) – Whether to fetch stack information for the tensor. Stack information allows Tripy to generate much higher quality error messages at the cost of a small overhead when initializing the tensor.

Example
import nvtripy as tp

tensor = tp.Tensor([1.0, 2.0, 3.0], dtype=tp.float32)
Local Variables
>>> tensor
tensor([1, 2, 3], dtype=float32, loc=cpu:0, shape=(3,))
eval() → Tensor

Immediately evaluates this tensor. By default, tensors are evaluated lazily.

Note that an evaluated tensor will always reside in device memory.
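The difference between lazy and immediate evaluation can be illustrated with a small, self-contained Python sketch. It is a conceptual model only, not Tripy's implementation: the hypothetical `LazyValue` class records each operation as a deferred closure and computes nothing until `eval()` is called, after which the result is cached.

```python
from typing import Callable, List, Optional


class LazyValue:
    """Hypothetical stand-in for a lazily evaluated tensor (not part of Tripy)."""

    def __init__(self, compute: Callable[[], List[float]]):
        self._compute = compute  # deferred computation; not run yet
        self._result: Optional[List[float]] = None
        self.evaluated = False

    @staticmethod
    def ones(n: int) -> "LazyValue":
        return LazyValue(lambda: [1.0] * n)

    def __add__(self, other: "LazyValue") -> "LazyValue":
        # Build a new graph node; neither operand is evaluated here.
        return LazyValue(lambda: [x + y for x, y in zip(self.eval(), other.eval())])

    def eval(self) -> List[float]:
        # Run the recorded computation once and cache the result.
        if not self.evaluated:
            self._result = self._compute()
            self.evaluated = True
        return self._result


a = LazyValue.ones(3)
b = a + a            # builds a graph node; nothing is computed yet
print(b.evaluated)   # False
print(b.eval())      # [2.0, 2.0, 2.0]
```

Deferring work this way lets a framework see the whole computation before running it, which is why the timing example below attributes most of the cost to `eval()`.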

Returns:

The evaluated tensor.

Return type:

Tensor

Example
import time

start = time.perf_counter()
tensor = tp.ones((3, 3))
init_time = time.perf_counter()
tensor.eval()
eval_time = time.perf_counter()

print(f"Tensor init_time took: {(init_time - start) * 1000.0:.3f} ms")
print(f"Tensor evaluation took: {(eval_time - init_time) * 1000.0:.3f} ms")
Local Variables
>>> tensor
tensor(
    [[1, 1, 1],
     [1, 1, 1],
     [1, 1, 1]], 
    dtype=float32, loc=gpu:0, shape=(3, 3))
Output
Tensor init_time took: 3.364 ms
Tensor evaluation took: 31.854 ms
__abs__() → Tensor

Computes the elementwise absolute value of the input tensor.

Parameters:

self (Tensor) – [dtype=T1] The input tensor.

Returns:

[dtype=T1] A new tensor of the same shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
input = tp.Tensor([-1, -2], dtype=tp.int32)
output = abs(input)
Local Variables
>>> input
tensor([-1, -2], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([1, 2], dtype=int32, loc=gpu:0, shape=(2,))
__add__(other: Tensor | Number) → Tensor

Performs an elementwise sum.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to add to this one. It must be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([1, 2])
b = tp.Tensor([2, 3])
output = a + b
Local Variables
>>> a
tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([3, 5], dtype=int32, loc=gpu:0, shape=(2,))
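Broadcast compatibility is not spelled out in this section. The sketch below assumes NumPy-style rules: shapes are aligned from the trailing dimension, a missing leading dimension counts as 1, and two dimensions are compatible when they are equal or one of them is 1. The function name `broadcast_shape` is illustrative, not a Tripy API.

```python
from typing import Tuple


def broadcast_shape(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return the broadcast shape of `a` and `b` (NumPy-style rules, assumed)."""
    rank = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both shapes have the same rank.
    a = (1,) * (rank - len(a)) + tuple(a)
    b = (1,) * (rank - len(b)) + tuple(b)
    result = []
    for da, db in zip(a, b):
        if da != db and 1 not in (da, db):
            raise ValueError(f"shapes are not broadcast-compatible: {da} vs {db}")
        # A size-1 dimension stretches to match the other operand.
        result.append(db if da == 1 else da)
    return tuple(result)


print(broadcast_shape((2,), (2,)))    # (2,), as for a + b in the example above
print(broadcast_shape((3, 1), (4,)))  # (3, 4)
```

Under this assumption a Python number behaves like a shape-`()` operand, so it is compatible with every tensor shape.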
__eq__(other: Tensor | Number) → Tensor

Performs an elementwise ‘equal’ comparison.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be compared to this one. It should be broadcast-compatible.

Returns:

[dtype=T2] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([2, 5])
output = b == a
Local Variables
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([2, 5], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
__floordiv__(other: Tensor | Number) → Tensor

Performs an elementwise floor division.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor by which to floor-divide this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([4.0, 6.0])
b = tp.Tensor([3.0, 4.0])
output = a // b
Local Variables
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))

>>> b
tensor([3, 4], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([1, 1], dtype=float32, loc=gpu:0, shape=(2,))
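Floor division rounds each quotient toward negative infinity, which differs from truncation only when the quotient is negative. The plain-Python sketch below works on equal-length lists without broadcasting; `floor_divide` is a hypothetical helper, and it assumes Tripy follows the usual floor semantics that the operator's name implies. It reproduces the example above and shows the negative case.

```python
import math
from typing import List


def floor_divide(a: List[float], b: List[float]) -> List[float]:
    """Elementwise floor(a / b) for equal-length lists."""
    return [float(math.floor(x / y)) for x, y in zip(a, b)]


print(floor_divide([4.0, 6.0], [3.0, 4.0]))  # [1.0, 1.0], matching the example
print(floor_divide([-7.0], [2.0]))           # [-4.0]; truncation would give -3.0
```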
__ge__(other: Tensor | Number) → Tensor

Performs an elementwise ‘greater than or equal’ comparison.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be compared to this one. It should be broadcast-compatible.

Returns:

[dtype=T2] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([2, 1])
output = b >= a
Local Variables
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([2, 1], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
__getitem__(index: Tensor | slice | int | DimensionSize | Sequence[slice | int | DimensionSize]) → Tensor

Returns a tensor containing a slice of this tensor.

Parameters:
  • self (Tensor) – [dtype=T1] Tensor that will be sliced.

  • index (Tensor | slice | int | DimensionSize | Sequence[slice | int | DimensionSize]) – The index or slice. If this is a Tensor, the operation is equivalent to calling gather() along the first dimension.

Returns:

[dtype=T1] A tensor containing the slice of this tensor.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example: Indexing With Integers
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
output = input[1]
Local Variables
>>> input
tensor(
    [[0, 1],
     [2, 3],
     [4, 5]], 
    dtype=float32, loc=gpu:0, shape=(3, 2))

>>> output
tensor([2, 3], dtype=float32, loc=gpu:0, shape=(2,))
Example: Indexing With Slices
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
output = input[1:]
Local Variables
>>> input
tensor(
    [[0, 1],
     [2, 3],
     [4, 5]], 
    dtype=float32, loc=gpu:0, shape=(3, 2))

>>> output
tensor(
    [[2, 3],
     [4, 5]], 
    dtype=float32, loc=gpu:0, shape=(2, 2))
Example: Reversing Data With Negative Step
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
output = input[:, ::-1]
Local Variables
>>> input
tensor(
    [[0, 1],
     [2, 3],
     [4, 5]], 
    dtype=float32, loc=gpu:0, shape=(3, 2))

>>> output
tensor(
    [[1, 0],
     [3, 2],
     [5, 4]], 
    dtype=float32, loc=gpu:0, shape=(3, 2))
Example: Indexing With Tensors (Gather)
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
index = tp.Tensor([2, 0], dtype=tp.int32)
output = input[index]
Local Variables
>>> input
tensor(
    [[0, 1],
     [2, 3],
     [4, 5]], 
    dtype=float32, loc=gpu:0, shape=(3, 2))

>>> index
tensor([2, 0], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor(
    [[4, 5],
     [0, 1]], 
    dtype=float32, loc=gpu:0, shape=(2, 2))
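The integer, slice, negative-step, and gather examples above can be mirrored on nested Python lists, which is a convenient way to check expected results by hand. This illustrates the semantics only; `gather_rows` is a hypothetical helper, not a Tripy function.

```python
from typing import List, Sequence

# The same 3x2 data as tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2)).
data = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]


def gather_rows(rows: List[List[float]], indices: Sequence[int]) -> List[List[float]]:
    """Select rows along the first dimension, in the order given by `indices`."""
    return [rows[i] for i in indices]


print(data[1])                      # integer index: [2.0, 3.0]
print(data[1:])                     # slice: [[2.0, 3.0], [4.0, 5.0]]
print([row[::-1] for row in data])  # input[:, ::-1]: reverse the last dimension
print(gather_rows(data, [2, 0]))    # input[index]: [[4.0, 5.0], [0.0, 1.0]]
```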
__gt__(other: Tensor | Number) → Tensor

Performs an elementwise ‘greater than’ comparison.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be compared to this one. It should be broadcast-compatible.

Returns:

[dtype=T2] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([3, 1])
output = b > a
Local Variables
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([3, 1], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
__invert__() → Tensor

Performs an elementwise logical NOT.

Parameters:

self (Tensor) – [dtype=T1] The input tensor.

Returns:

[dtype=T1] A new tensor.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([True, False, False])
output = ~a
Local Variables
>>> a
tensor([True, False, False], dtype=bool, loc=cpu:0, shape=(3,))

>>> output
tensor([False, True, True], dtype=bool, loc=gpu:0, shape=(3,))
__le__(other: Tensor | Number) → Tensor

Performs an elementwise ‘less than or equal’ comparison.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be compared to this one. It should be broadcast-compatible.

Returns:

[dtype=T2] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([2, 5])
output = b <= a
Local Variables
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([2, 5], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
__lt__(other: Tensor | Number) → Tensor

Performs an elementwise ‘less than’ comparison.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be compared to this one. It should be broadcast-compatible.

Returns:

[dtype=T2] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([1, 5])
output = b < a
Local Variables
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([1, 5], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
__matmul__(other: Tensor) → Tensor

Performs matrix multiplication between two tensors.

  • If both tensors are 1D, a dot product is performed. The output is a scalar.

  • If either argument, but not both, is 1D, matrix-vector multiplication is performed:
    • For inputs of shape \((M, N)\) and \((N,)\), the output will have shape \((M,)\).

    • For inputs of shape \((N,)\) and \((N, K)\), the output will have shape \((K,)\).

  • If both tensors are 2D, matrix-matrix multiplication is performed.

    For inputs of shape \((M, N)\) and \((N, K)\), the output will have shape \((M, K)\).

  • If either tensor has more than 2 dimensions, it is treated as a stack of matrices.

    If the ranks differ for tensors with 2 or more dimensions, dimensions of size 1 are prepended to the lower-rank tensor until the ranks match. All dimensions except the last two (the batch dimensions) are broadcast if required.
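The shape rules above can be captured in a short function that predicts the output shape of `a @ b`. This is a sketch that assumes Tripy's batch broadcasting follows NumPy's rules, which the batched example below is consistent with; `matmul_output_shape` is illustrative, not a Tripy API.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def matmul_output_shape(a: Shape, b: Shape) -> Shape:
    """Predict the shape of a @ b (sketch; NumPy-style batch broadcasting assumed)."""
    if not a or not b:
        raise ValueError("matmul requires tensors of rank >= 1")
    a_is_vec, b_is_vec = len(a) == 1, len(b) == 1
    # Promote vectors to matrices: (N,) -> (1, N) on the left, (N,) -> (N, 1) on the right.
    a = (1,) + a if a_is_vec else a
    b = b + (1,) if b_is_vec else b
    if a[-1] != b[-2]:
        raise ValueError(f"inner dimensions differ: {a[-1]} vs {b[-2]}")
    # Prepend 1s to the shorter batch shape, then broadcast batch dims pairwise.
    rank = max(len(a), len(b)) - 2
    batch_a = (1,) * (rank - (len(a) - 2)) + a[:-2]
    batch_b = (1,) * (rank - (len(b) - 2)) + b[:-2]
    batch = []
    for da, db in zip(batch_a, batch_b):
        if da != db and 1 not in (da, db):
            raise ValueError(f"batch dimensions not broadcast-compatible: {da} vs {db}")
        batch.append(db if da == 1 else da)
    out = tuple(batch) + (a[-2], b[-1])
    # Drop the dimensions that were added for vector operands.
    if a_is_vec:
        out = out[:-2] + out[-1:]
    if b_is_vec:
        out = out[:-1]
    return out


print(matmul_output_shape((3,), (3,)))                # (): dot product
print(matmul_output_shape((3,), (3, 2)))              # (2,)
print(matmul_output_shape((2, 3), (3, 2)))            # (2, 2)
print(matmul_output_shape((1, 2, 2, 2), (1, 2, 2)))  # (1, 2, 2, 2)
```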

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor) – [dtype=T1] The tensor by which to multiply.

Returns:

[dtype=T1] A new tensor.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example: Dot Product
a = tp.iota((3,), dtype=tp.float32)
b = tp.iota((3,), dtype=tp.float32)

output = a @ b
Local Variables
>>> a
tensor([0, 1, 2], dtype=float32, loc=gpu:0, shape=(3,))

>>> b
tensor([0, 1, 2], dtype=float32, loc=gpu:0, shape=(3,))

>>> output
tensor(5, dtype=float32, loc=gpu:0, shape=())
Example: Matrix-Vector Multiplication
a = tp.iota((3,), dtype=tp.float32)
b = tp.iota((3, 2), dtype=tp.float32)

output = a @ b
Local Variables
>>> a
tensor([0, 1, 2], dtype=float32, loc=gpu:0, shape=(3,))

>>> b
tensor(
    [[0, 0],
     [1, 1],
     [2, 2]], 
    dtype=float32, loc=gpu:0, shape=(3, 2))

>>> output
tensor([5, 5], dtype=float32, loc=gpu:0, shape=(2,))
Example: Matrix-Matrix Multiplication
a = tp.iota((2, 3), dtype=tp.float32)
b = tp.iota((3, 2), dtype=tp.float32)

output = a @ b
Local Variables
>>> a
tensor(
    [[0, 0, 0],
     [1, 1, 1]], 
    dtype=float32, loc=gpu:0, shape=(2, 3))

>>> b
tensor(
    [[0, 0],
     [1, 1],
     [2, 2]], 
    dtype=float32, loc=gpu:0, shape=(3, 2))

>>> output
tensor(
    [[0, 0],
     [3, 3]], 
    dtype=float32, loc=gpu:0, shape=(2, 2))
Example: Batched Matrix Multiplication
a = tp.iota((1, 2, 2, 2), dtype=tp.float32, dim=-1)
b = tp.iota((1, 2, 2), dtype=tp.float32, dim=-2)

output = a @ b
Local Variables
>>> a
tensor(
    [[[[0, 1],
       [0, 1]],

      [[0, 1],
       [0, 1]]]], 
    dtype=float32, loc=gpu:0, shape=(1, 2, 2, 2))

>>> b
tensor(
    [[[0, 0],
      [1, 1]]], 
    dtype=float32, loc=gpu:0, shape=(1, 2, 2))

>>> output
tensor(
    [[[[1, 1],
       [1, 1]],

      [[1, 1],
       [1, 1]]]], 
    dtype=float32, loc=gpu:0, shape=(1, 2, 2, 2))
__mod__(other: Tensor | Number) → Tensor

Performs a modulo operation, which computes the remainder of a division.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor by which to divide self. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([4.0, 6.0])
b = tp.Tensor([3.0, 4.0])
output = a % b
Local Variables
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))

>>> b
tensor([3, 4], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([1, 2], dtype=float32, loc=gpu:0, shape=(2,))
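For positive operands, as in the example above, the result is the ordinary remainder. How Tripy treats negative operands is not stated in this section; the sketch below assumes the floor-based convention of Python's `%`, under which `a % b == a - b * floor(a / b)` and a nonzero result takes the sign of the divisor. `floor_mod` is a hypothetical helper.

```python
import math
from typing import List


def floor_mod(a: List[float], b: List[float]) -> List[float]:
    """Elementwise a - b * floor(a / b) for equal-length lists (floor convention assumed)."""
    return [x - y * math.floor(x / y) for x, y in zip(a, b)]


print(floor_mod([4.0, 6.0], [3.0, 4.0]))  # [1.0, 2.0], matching the example
print(floor_mod([-7.0], [2.0]))           # [1.0], same as Python's -7.0 % 2.0
```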
__mul__(other: Tensor | Number) → Tensor

Performs an elementwise multiplication.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor by which to multiply this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([1.0, 2.0])
b = tp.Tensor([2.0, 3.0])
output = a * b
Local Variables
>>> a
tensor([1, 2], dtype=float32, loc=cpu:0, shape=(2,))

>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([2, 6], dtype=float32, loc=gpu:0, shape=(2,))
__ne__(other: Tensor | Number) → Tensor

Performs an elementwise ‘not equal’ comparison.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be compared to this one. It should be broadcast-compatible.

Returns:

[dtype=T2] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([1, 3])
output = b != a
Local Variables
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([1, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
__neg__() → Tensor

Computes the elementwise negation of the input tensor.

Parameters:

self (Tensor) – [dtype=T1] The input tensor.

Returns:

[dtype=T1] A new tensor of the same shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
input = tp.Tensor([-1, -2], dtype=tp.int32)
output = -input
Local Variables
>>> input
tensor([-1, -2], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([1, 2], dtype=int32, loc=gpu:0, shape=(2,))
__or__(other: Tensor) → Tensor

Performs an elementwise logical OR.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor) – [dtype=T1] The tensor to OR with this one. It must be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([True, False, False])
b = tp.Tensor([False, True, False])
output = a | b
Local Variables
>>> a
tensor([True, False, False], dtype=bool, loc=cpu:0, shape=(3,))

>>> b
tensor([False, True, False], dtype=bool, loc=cpu:0, shape=(3,))

>>> output
tensor([True, True, False], dtype=bool, loc=gpu:0, shape=(3,))
__pow__(other: Tensor | Number) → Tensor

Performs an elementwise exponentiation.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor by which to exponentiate this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([1.0, 2.0])
b = tp.Tensor([2.0, 3.0])
output = a**b
Local Variables
>>> a
tensor([1, 2], dtype=float32, loc=cpu:0, shape=(2,))

>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([1, 8], dtype=float32, loc=gpu:0, shape=(2,))
__radd__(other: Tensor | Number) → Tensor

Performs an elementwise sum.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to add to this one. It must be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = 1
b = tp.Tensor([1, 2])
output = a + b
Local Variables
>>> b
tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([2, 3], dtype=int32, loc=gpu:0, shape=(2,))
__rfloordiv__(other: Tensor | Number) → Tensor

Performs an elementwise floor division.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be floor-divided by this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = 2
b = tp.Tensor([2.0, 3.0])
output = a // b
Local Variables
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([1, 0], dtype=float32, loc=gpu:0, shape=(2,))
__rmod__(other: Tensor | Number) → Tensor

Performs a modulo operation, which computes the remainder of a division.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be divided by this one (the dividend). It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([4.0, 6.0])
output = 2 % a
Local Variables
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([2, 2], dtype=float32, loc=gpu:0, shape=(2,))
__rmul__(other: Tensor | Number) → Tensor

Performs an elementwise multiplication.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor by which to multiply this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = 2.0
b = tp.Tensor([2.0, 3.0])
output = a * b
Local Variables
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([4, 6], dtype=float32, loc=gpu:0, shape=(2,))
__rpow__(other: Tensor | Number) → Tensor

Performs an elementwise exponentiation.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be exponentiated by this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = 2.0
b = tp.Tensor([2.0, 3.0])
output = a**b
Local Variables
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([4, 8], dtype=float32, loc=gpu:0, shape=(2,))
__rsub__(other: Tensor | Number) → Tensor

Performs an elementwise subtraction.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor from which to subtract this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = 1
b = tp.Tensor([1, 2])
output = a - b
Local Variables
>>> b
tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([0, -1], dtype=int32, loc=gpu:0, shape=(2,))
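The reflected methods (`__radd__`, `__rsub__`, and so on) exist because of how Python dispatches binary operators: for `1 - b`, `int.__sub__` does not know the tensor type and returns `NotImplemented`, so Python falls back to `b.__rsub__(1)`, which must compute `other - self`. The toy `Vec` class below is hypothetical, not Tripy's implementation, and only demonstrates that argument order.

```python
from numbers import Number
from typing import List, Union


class Vec:
    """Toy 1D container illustrating forward vs reflected subtraction."""

    def __init__(self, values: List[int]):
        self.values = list(values)

    def __sub__(self, other: Union["Vec", Number]) -> "Vec":
        # Computes self - other; a number is repeated to match the length.
        rhs = other.values if isinstance(other, Vec) else [other] * len(self.values)
        return Vec([x - y for x, y in zip(self.values, rhs)])

    def __rsub__(self, other: Number) -> "Vec":
        # Called for `other - self` when `other` cannot handle the operation.
        return Vec([other - x for x in self.values])


b = Vec([1, 2])
print((1 - b).values)  # [0, -1], matching the __rsub__ example above
print((b - 1).values)  # [0, 1]
```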
__rtruediv__(other: Tensor | Number) → Tensor

Performs an elementwise division.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be divided by this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = 6.0
b = tp.Tensor([2.0, 3.0])
output = a / b
Local Variables
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([3, 2], dtype=float32, loc=gpu:0, shape=(2,))
__sub__(other: Tensor | Number) → Tensor

Performs an elementwise subtraction.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to subtract from this one. It must be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([1, 2])
output = a - b
Local Variables
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([1, 1], dtype=int32, loc=gpu:0, shape=(2,))
__truediv__(other: Tensor | Number) → Tensor

Performs an elementwise division.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor by which to divide this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([4.0, 6.0])
b = tp.Tensor([2.0, 3.0])
output = a / b
Local Variables
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))

>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([2, 2], dtype=float32, loc=gpu:0, shape=(2,))
property shape: Tuple[int | DimensionSize]

Represents the shape of the tensor.

Parameters:

self – [dtype=T1] The input tensor.

Returns:

A sequence containing the shape of this tensor.

DATA TYPE CONSTRAINTS:
Example
input = tp.ones((8, 2))
shape = input.shape
Local Variables
>>> input
tensor(
    [[1, 1],
     [1, 1],
     [1, 1],
     [1, 1],
     [1, 1],
     [1, 1],
     [1, 1],
     [1, 1]], 
    dtype=float32, loc=gpu:0, shape=(8, 2))

>>> shape
(
    8,
    2,
)
tolist() → List | Number

Returns the tensor as a nested list. If the tensor is a scalar, returns a Python number.

Returns:

The tensor represented as a nested list or a Python number.

Return type:

List | Number

Example: Ranked tensor
tensor = tp.ones((2, 2))
tensor_list = tensor.tolist()
Local Variables
>>> tensor_list
[[1.0, 1.0], [1.0, 1.0]]
Example: Scalar
tensor = tp.Tensor(2.0, dtype=tp.float32)
tensor_scalar = tensor.tolist()
Local Variables
>>> tensor_scalar
2.0
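Conceptually, `tolist()` rebuilds a nested list from the tensor's shape: one level of nesting per dimension, and a bare number for a rank-0 tensor. The sketch below shows that reshaping on a flat, row-major Python list; `to_nested` is a hypothetical helper and says nothing about how Tripy copies data off the device.

```python
from typing import List, Sequence, Union

Nested = Union[float, List["Nested"]]


def to_nested(flat: Sequence[float], shape: Sequence[int]) -> Nested:
    """Rebuild a nested list from row-major `flat` data with the given `shape`."""
    if len(shape) == 0:
        return flat[0]  # rank-0 tensor: a bare Python number
    # Each entry along the first dimension owns an equal, contiguous chunk.
    stride = len(flat) // shape[0] if shape[0] else 0
    return [to_nested(flat[i * stride:(i + 1) * stride], shape[1:]) for i in range(shape[0])]


print(to_nested([1.0, 1.0, 1.0, 1.0], (2, 2)))  # [[1.0, 1.0], [1.0, 1.0]]
print(to_nested([2.0], ()))                     # 2.0
```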

See also: