Tensor

class nvtripy.Tensor(data: Any, dtype: dtype | None = None, device: device | None = None, name: str | None = None)

Bases: object

A tensor is a multi-dimensional array that contains elements of a uniform data type.

Parameters:
  • data (Any) – The data with which to initialize the tensor. For types that support the DLPack protocol, copying data is avoided if possible.

  • dtype (dtype | None) – The data type of the tensor. This parameter is only allowed when data is an empty list.

  • device (device | None) – The device on which to allocate the tensor. This parameter is only allowed when data is a Python number or list.

  • name (str | None) – The name of the tensor. If provided, this must be a unique string.

Example
tensor = tp.Tensor([1.0, 2.0, 3.0])
Local Variables
>>> tensor
tensor([1, 2, 3], dtype=float32, loc=cpu:0, shape=(3,))
eval() Tensor

Immediately evaluates this tensor. By default, tensors are evaluated lazily.

Returns:

The evaluated tensor.

Return type:

Tensor

Example
import time

start = time.perf_counter()
tensor = tp.ones((3, 3))
init_time = time.perf_counter()
tensor.eval()
eval_time = time.perf_counter()

print(f"Tensor init_time took: {(init_time - start) * 1000.0:.3f} ms")
print(f"Tensor evaluation took: {(eval_time - init_time) * 1000.0:.3f} ms")
Local Variables
>>> tensor
tensor(
    [[1, 1, 1],
     [1, 1, 1],
     [1, 1, 1]], 
    dtype=float32, loc=gpu:0, shape=(3, 3))
Output
Tensor init_time took: 1.479 ms
Tensor evaluation took: 61.540 ms
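The timings above suggest that constructing the tensor only records work, and eval() performs it. The stdlib-only sketch below illustrates that lazy-evaluation pattern with a hypothetical `LazyTensor` class: operations only record how to compute a result, and `eval()` runs the recorded computation once and caches it. It is a conceptual model under that assumption, not nvtripy's implementation.

```python
from typing import Callable, List, Optional


class LazyTensor:
    """Hypothetical lazily evaluated 1D tensor: operations build a graph; eval() runs it."""

    def __init__(self, compute: Callable[[], List[float]]):
        self._compute = compute              # deferred computation
        self._cache: Optional[List[float]] = None
        self.num_evaluations = 0             # how many times the computation actually ran

    @staticmethod
    def ones(n: int) -> "LazyTensor":
        return LazyTensor(lambda: [1.0] * n)

    def __add__(self, other: "LazyTensor") -> "LazyTensor":
        # Nothing is computed here; we only record how to compute the result.
        return LazyTensor(lambda: [x + y for x, y in zip(self.eval(), other.eval())])

    def eval(self) -> List[float]:
        # Run the deferred computation on first use and cache the result.
        if self._cache is None:
            self._cache = self._compute()
            self.num_evaluations += 1
        return self._cache


a = LazyTensor.ones(3)
b = a + a                   # builds the graph only
print(b.num_evaluations)    # 0
print(b.eval())             # [2.0, 2.0, 2.0]
print(b.num_evaluations)    # 1
```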
__abs__() Tensor

Computes the elementwise absolute value of the input tensor.

Parameters:

self (Tensor) – [dtype=T1] The input tensor.

Returns:

[dtype=T1] A new tensor of the same shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
input = tp.Tensor([-1, -2])
output = abs(input)
Local Variables
>>> input
tensor([-1, -2], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([1, 2], dtype=int32, loc=gpu:0, shape=(2,))
__add__(other: Tensor | Number) Tensor

Performs an elementwise sum.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to add to this one. It must be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([1, 2])
b = tp.Tensor([2, 3])
output = a + b
Local Variables
>>> a
tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([3, 5], dtype=int32, loc=gpu:0, shape=(2,))
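The binary operators in this class require their operands to be broadcast-compatible. Assuming nvtripy follows the usual NumPy-style rules (shapes are aligned from the right, and each pair of dimensions must be equal or one of them must be 1), the resulting shape can be computed as in this sketch. `broadcast_shape` is a hypothetical helper, not part of the nvtripy API, and it handles only static integer shapes; a Python number behaves like shape `()`.

```python
from itertools import zip_longest
from typing import Tuple


def broadcast_shape(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return the broadcasted shape of `a` and `b` using NumPy-style rules."""
    result = []
    # Compare dimensions from the trailing (rightmost) end; missing dims count as 1.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"shapes {a} and {b} are not broadcast-compatible")
    return tuple(reversed(result))


print(broadcast_shape((2,), (2,)))       # (2,)
print(broadcast_shape((3, 1), (1, 4)))   # (3, 4)
print(broadcast_shape((2, 3), (3,)))     # (2, 3)
```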
__eq__(other: Tensor | Number) Tensor

Performs an elementwise ‘equal’ comparison.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be compared to this one. It should be broadcast-compatible.

Returns:

[dtype=T2] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([2, 5])
output = b == a
Local Variables
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([2, 5], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
__floordiv__(other: Tensor | Number) Tensor

Performs an elementwise floor division.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor by which to floor-divide this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([4.0, 6.0])
b = tp.Tensor([3.0, 4.0])
output = a // b
Local Variables
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))

>>> b
tensor([3, 4], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([1, 1], dtype=float32, loc=gpu:0, shape=(2,))
__ge__(other: Tensor | Number) Tensor

Performs an elementwise ‘greater than or equal’ comparison.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be compared to this one. It should be broadcast-compatible.

Returns:

[dtype=T2] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([2, 1])
output = b >= a
Local Variables
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([2, 1], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
__getitem__(index: Tensor | slice | int | DimensionSize | ellipsis | None | Sequence[slice | int | DimensionSize | ellipsis | None]) Tensor

Returns a tensor containing a slice of this tensor.

Parameters:
  • self (Tensor) – [dtype=T1] Tensor that will be sliced.

  • index (Tensor | slice | int | DimensionSize | ellipsis | None | Sequence[slice | int | DimensionSize | ellipsis | None]) – The index or slice. If this is a Tensor, the operation is equivalent to calling gather() along the first dimension. If this is None, a new dimension of size 1 will be inserted at that position.

Returns:

[dtype=T1] A tensor containing the slice of this tensor.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example: Indexing With Integers
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
output = input[1]
Local Variables
>>> input
tensor(
    [[0, 1],
     [2, 3],
     [4, 5]], 
    dtype=float32, loc=gpu:0, shape=(3, 2))

>>> output
tensor([2, 3], dtype=float32, loc=gpu:0, shape=(2,))
Example: Indexing With Slices
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
output = input[1:]
Local Variables
>>> input
tensor(
    [[0, 1],
     [2, 3],
     [4, 5]], 
    dtype=float32, loc=gpu:0, shape=(3, 2))

>>> output
tensor(
    [[2, 3],
     [4, 5]], 
    dtype=float32, loc=gpu:0, shape=(2, 2))
Example: Indexing With Ellipsis
input = tp.reshape(tp.arange(6, dtype=tp.float32), (1, 3, 2))
output = input[..., 1:]
Local Variables
>>> input
tensor(
    [[[0, 1],
      [2, 3],
      [4, 5]]], 
    dtype=float32, loc=gpu:0, shape=(1, 3, 2))

>>> output
tensor(
    [[[1],
      [3],
      [5]]], 
    dtype=float32, loc=gpu:0, shape=(1, 3, 1))
Example: Reversing Data With Negative Step
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
output = input[:, ::-1]
Local Variables
>>> input
tensor(
    [[0, 1],
     [2, 3],
     [4, 5]], 
    dtype=float32, loc=gpu:0, shape=(3, 2))

>>> output
tensor(
    [[1, 0],
     [3, 2],
     [5, 4]], 
    dtype=float32, loc=gpu:0, shape=(3, 2))
Example: Indexing With Tensors (Gather)
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
index = tp.Tensor([2, 0])
output = input[index]
Local Variables
>>> input
tensor(
    [[0, 1],
     [2, 3],
     [4, 5]], 
    dtype=float32, loc=gpu:0, shape=(3, 2))

>>> index
tensor([2, 0], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor(
    [[4, 5],
     [0, 1]], 
    dtype=float32, loc=gpu:0, shape=(2, 2))
Example: Adding New Dimensions With None
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
output = input[None, :, None]
Local Variables
>>> input
tensor(
    [[0, 1],
     [2, 3],
     [4, 5]], 
    dtype=float32, loc=gpu:0, shape=(3, 2))

>>> output
tensor(
    [[[[0, 1]],

      [[2, 3]],

      [[4, 5]]]], 
    dtype=float32, loc=gpu:0, shape=(1, 3, 1, 2))
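The indexing rules illustrated above (an integer removes a dimension, a slice keeps it, `None` inserts a size-1 dimension, and an ellipsis stands for every dimension not otherwise indexed) fully determine the output shape. The sketch below computes that shape for static shapes only; `index_output_shape` is a hypothetical helper, not part of nvtripy, and it ignores Tensor (gather) and DimensionSize indices.

```python
from typing import Sequence, Tuple


def index_output_shape(shape: Tuple[int, ...], index: Sequence[object]) -> Tuple[int, ...]:
    """Output shape of `x[index]` for int, slice, None, and Ellipsis entries."""
    entries = list(index)
    # Entries that consume an existing dimension (ints and slices).
    consumed = sum(1 for i in entries if i is not None and i is not Ellipsis)
    if any(i is Ellipsis for i in entries):
        # Expand the ellipsis into as many full slices as there are unindexed dims.
        pos = next(k for k, i in enumerate(entries) if i is Ellipsis)
        entries[pos:pos + 1] = [slice(None)] * (len(shape) - consumed)
    # Trailing dimensions that were not mentioned are kept as full slices.
    entries += [slice(None)] * (len(shape) - sum(1 for i in entries if i is not None))

    out = []
    dim = 0
    for i in entries:
        if i is None:
            out.append(1)                                    # new axis of size 1
        elif isinstance(i, slice):
            out.append(len(range(*i.indices(shape[dim]))))   # slice keeps the dim
            dim += 1
        else:
            dim += 1                                         # an integer removes the dim
    return tuple(out)


print(index_output_shape((3, 2), (1,)))                          # (2,)
print(index_output_shape((1, 3, 2), (Ellipsis, slice(1, None))))  # (1, 3, 1)
print(index_output_shape((3, 2), (None, slice(None), None)))      # (1, 3, 1, 2)
```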
__gt__(other: Tensor | Number) Tensor

Performs an elementwise ‘greater than’ comparison.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be compared to this one. It should be broadcast-compatible.

Returns:

[dtype=T2] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([3, 1])
output = b > a
Local Variables
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([3, 1], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
__invert__() Tensor

Performs an elementwise logical NOT.

Parameters:

self (Tensor) – [dtype=T1] The input tensor.

Returns:

[dtype=T1] A new tensor.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([True, False, False])
output = ~a
Local Variables
>>> a
tensor([True, False, False], dtype=bool, loc=cpu:0, shape=(3,))

>>> output
tensor([False, True, True], dtype=bool, loc=gpu:0, shape=(3,))
__le__(other: Tensor | Number) Tensor

Performs an elementwise ‘less than or equal’ comparison.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be compared to this one. It should be broadcast-compatible.

Returns:

[dtype=T2] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([2, 5])
output = b <= a
Local Variables
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([2, 5], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
__lt__(other: Tensor | Number) Tensor

Performs an elementwise ‘less than’ comparison.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be compared to this one. It should be broadcast-compatible.

Returns:

[dtype=T2] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([1, 5])
output = b < a
Local Variables
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([1, 5], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
__matmul__(other: Tensor) Tensor

Performs matrix multiplication between two tensors.

  • If both tensors are 1D, a dot product is performed. The output is a scalar.

  • If either argument, but not both, is 1D, matrix-vector multiplication is performed:
    • For inputs of shape \((M, N)\) and \((N,)\), the output will have shape \((M,)\).

    • For inputs of shape \((N,)\) and \((N, K)\), the output will have shape \((K,)\).

  • If both tensors are 2D, matrix-matrix multiplication is performed.

    For inputs of shape \((M, N)\) and \((N, K)\), the output will have shape \((M, K)\).

  • If either tensor has more than 2 dimensions, it is treated as a stack of matrices.

    If the ranks differ for tensors with 2 or more dimensions, dimensions are prepended until the ranks match. The leading (batch) dimensions, that is, all but the last two, are broadcast if required.
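The shape rules listed above can be summarized in a small helper. This is a sketch that assumes NumPy-style semantics for the prepended batch dimensions (missing leading dimensions are treated as size 1 and then broadcast); `matmul_output_shape` is hypothetical, works on static integer shapes only, and rejects mixing a 1D operand with a tensor of rank greater than 2, a case the rules above do not describe.

```python
from itertools import zip_longest
from typing import Tuple


def matmul_output_shape(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
    """Output shape of `a @ b` for static shapes, following the listed rules."""
    if len(a) == 1 and len(b) == 1:          # dot product -> scalar
        if a[0] != b[0]:
            raise ValueError("dot product needs vectors of equal length")
        return ()
    if len(a) == 1 and len(b) == 2:          # (N,) @ (N, K) -> (K,)
        if a[0] != b[0]:
            raise ValueError("inner dimensions differ")
        return (b[1],)
    if len(a) == 2 and len(b) == 1:          # (M, N) @ (N,) -> (M,)
        if a[1] != b[0]:
            raise ValueError("inner dimensions differ")
        return (a[0],)
    if len(a) < 2 or len(b) < 2:
        raise ValueError("combination of ranks not covered by this sketch")
    if a[-1] != b[-2]:                       # (..., M, N) @ (..., N, K)
        raise ValueError("inner dimensions differ")
    # Batch dimensions: align from the right, treat missing ones as 1, then broadcast.
    batch = []
    for da, db in zip_longest(reversed(a[:-2]), reversed(b[:-2]), fillvalue=1):
        if da == db or db == 1:
            batch.append(da)
        elif da == 1:
            batch.append(db)
        else:
            raise ValueError("batch dimensions are not broadcast-compatible")
    return tuple(reversed(batch)) + (a[-2], b[-1])


print(matmul_output_shape((3,), (3, 2)))              # (2,)
print(matmul_output_shape((1, 2, 2, 2), (1, 2, 2)))   # (1, 2, 2, 2)
```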

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor) – [dtype=T1] The tensor by which to multiply.

Returns:

[dtype=T1] A new tensor.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example: Dot Product
a = tp.iota((3,), dtype=tp.float32)
b = tp.iota((3,), dtype=tp.float32)

output = a @ b
Local Variables
>>> a
tensor([0, 1, 2], dtype=float32, loc=gpu:0, shape=(3,))

>>> b
tensor([0, 1, 2], dtype=float32, loc=gpu:0, shape=(3,))

>>> output
tensor(5, dtype=float32, loc=gpu:0, shape=())
Example: Matrix-Vector Multiplication
a = tp.iota((3,), dtype=tp.float32)
b = tp.iota((3, 2), dtype=tp.float32)

output = a @ b
Local Variables
>>> a
tensor([0, 1, 2], dtype=float32, loc=gpu:0, shape=(3,))

>>> b
tensor(
    [[0, 0],
     [1, 1],
     [2, 2]], 
    dtype=float32, loc=gpu:0, shape=(3, 2))

>>> output
tensor([5, 5], dtype=float32, loc=gpu:0, shape=(2,))
Example: Matrix-Matrix Multiplication
a = tp.iota((2, 3), dtype=tp.float32)
b = tp.iota((3, 2), dtype=tp.float32)

output = a @ b
Local Variables
>>> a
tensor(
    [[0, 0, 0],
     [1, 1, 1]], 
    dtype=float32, loc=gpu:0, shape=(2, 3))

>>> b
tensor(
    [[0, 0],
     [1, 1],
     [2, 2]], 
    dtype=float32, loc=gpu:0, shape=(3, 2))

>>> output
tensor(
    [[0, 0],
     [3, 3]], 
    dtype=float32, loc=gpu:0, shape=(2, 2))
Example: Batched Matrix Multiplication
a = tp.iota((1, 2, 2, 2), dtype=tp.float32, dim=-1)
b = tp.iota((1, 2, 2), dtype=tp.float32, dim=-2)

output = a @ b
Local Variables
>>> a
tensor(
    [[[[0, 1],
       [0, 1]],

      [[0, 1],
       [0, 1]]]], 
    dtype=float32, loc=gpu:0, shape=(1, 2, 2, 2))

>>> b
tensor(
    [[[0, 0],
      [1, 1]]], 
    dtype=float32, loc=gpu:0, shape=(1, 2, 2))

>>> output
tensor(
    [[[[1, 1],
       [1, 1]],

      [[1, 1],
       [1, 1]]]], 
    dtype=float32, loc=gpu:0, shape=(1, 2, 2, 2))
__mod__(other: Tensor | Number) Tensor

Performs a modulo operation, which computes the remainder of a division.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor by which to divide self. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([4.0, 6.0])
b = tp.Tensor([3.0, 4.0])
output = a % b
Local Variables
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))

>>> b
tensor([3, 4], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([1, 2], dtype=float32, loc=gpu:0, shape=(2,))
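Floor division and modulo are related by the identity `a == (a // b) * b + a % b`. The sketch below checks that identity with plain Python floats, where floor division rounds the quotient toward negative infinity; `floor_div` and `mod` are hypothetical helpers. The first three pairs reproduce values from the examples in this section. The documentation does not state how nvtripy handles negative operands, so the result for the negative pair reflects Python's convention, not necessarily nvtripy's.

```python
import math


def floor_div(a: float, b: float) -> float:
    # Divide, then round the quotient toward negative infinity.
    return float(math.floor(a / b))


def mod(a: float, b: float) -> float:
    # Remainder chosen so that a == floor_div(a, b) * b + mod(a, b).
    return a - floor_div(a, b) * b


# The first three pairs appear in the examples; the last adds a negative operand.
for a, b in [(4.0, 3.0), (6.0, 4.0), (2.0, 4.0), (-7.0, 2.0)]:
    q, r = floor_div(a, b), mod(a, b)
    assert q * b + r == a               # the defining identity holds
    assert (q, r) == (a // b, a % b)    # and agrees with Python's own operators
    print(f"{a} // {b} = {q},  {a} % {b} = {r}")
```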
__mul__(other: Tensor | Number) Tensor

Performs an elementwise multiplication.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor by which to multiply this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([1.0, 2.0])
b = tp.Tensor([2.0, 3.0])
output = a * b
Local Variables
>>> a
tensor([1, 2], dtype=float32, loc=cpu:0, shape=(2,))

>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([2, 6], dtype=float32, loc=gpu:0, shape=(2,))
__ne__(other: Tensor | Number) Tensor

Performs an elementwise ‘not equal’ comparison.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be compared to this one. It should be broadcast-compatible.

Returns:

[dtype=T2] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([1, 3])
output = b != a
Local Variables
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([1, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
__neg__() Tensor

Computes the elementwise negation of the input tensor.

Parameters:

self (Tensor) – [dtype=T1] The input tensor.

Returns:

[dtype=T1] A new tensor of the same shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
input = tp.Tensor([-1, -2])
output = -input
Local Variables
>>> input
tensor([-1, -2], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([1, 2], dtype=int32, loc=gpu:0, shape=(2,))
__or__(other: Tensor) Tensor

Performs an elementwise logical OR.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor) – [dtype=T1] The tensor to OR with this one. It must be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([True, False, False])
b = tp.Tensor([False, True, False])
output = a | b
Local Variables
>>> a
tensor([True, False, False], dtype=bool, loc=cpu:0, shape=(3,))

>>> b
tensor([False, True, False], dtype=bool, loc=cpu:0, shape=(3,))

>>> output
tensor([True, True, False], dtype=bool, loc=gpu:0, shape=(3,))
__pow__(other: Tensor | Number) Tensor

Performs an elementwise exponentiation.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor by which to exponentiate this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([1.0, 2.0])
b = tp.Tensor([2.0, 3.0])
output = a**b
Local Variables
>>> a
tensor([1, 2], dtype=float32, loc=cpu:0, shape=(2,))

>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([1, 8], dtype=float32, loc=gpu:0, shape=(2,))
__radd__(other: Tensor | Number) Tensor

Performs an elementwise sum.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to add to this one. It must be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = 1
b = tp.Tensor([2, 3])
output = a + b
Local Variables
>>> b
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([3, 4], dtype=int32, loc=gpu:0, shape=(2,))
__rfloordiv__(other: Tensor | Number) Tensor

Performs an elementwise floor division.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be floor-divided by this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = 2
b = tp.Tensor([2.0, 3.0])
output = a // b
Local Variables
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([1, 0], dtype=float32, loc=gpu:0, shape=(2,))
__rmod__(other: Tensor | Number) Tensor

Performs a modulo operation, which computes the remainder of a division.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be divided by this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([4.0, 6.0])
output = 2 % a
Local Variables
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([2, 2], dtype=float32, loc=gpu:0, shape=(2,))
__rmul__(other: Tensor | Number) Tensor

Performs an elementwise multiplication.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor by which to multiply this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = 2.0
b = tp.Tensor([2.0, 3.0])
output = a * b
Local Variables
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([4, 6], dtype=float32, loc=gpu:0, shape=(2,))
__rpow__(other: Tensor | Number) Tensor

Performs an elementwise exponentiation.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be exponentiated by this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = 2.0
b = tp.Tensor([2.0, 3.0])
output = a**b
Local Variables
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([4, 8], dtype=float32, loc=gpu:0, shape=(2,))
__rsub__(other: Tensor | Number) Tensor

Performs an elementwise subtraction.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor from which to subtract this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = 1
b = tp.Tensor([1, 2])
output = a - b
Local Variables
>>> b
tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([0, -1], dtype=int32, loc=gpu:0, shape=(2,))
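The `__r*__` methods are what Python calls when the left operand, such as a plain Python number, does not know how to combine with the right operand: `int.__sub__` returns `NotImplemented`, and Python then tries the right operand's `__rsub__`. For non-commutative operations the operand order matters, since `__rsub__` must compute `other - self`. The sketch below demonstrates this standard Python dispatch with a hypothetical `Vec` class; it mirrors the protocol, not nvtripy's implementation.

```python
from typing import List, Union

Number = Union[int, float]


class Vec:
    """Hypothetical 1D vector used only to show Python's reflected-operator dispatch."""

    def __init__(self, data: List[Number]):
        self.data = list(data)

    def __sub__(self, other: Number) -> "Vec":
        # Handles `self - other`.
        return Vec([x - other for x in self.data])

    def __rsub__(self, other: Number) -> "Vec":
        # Handles `other - self` after int.__sub__(other, self) returns NotImplemented.
        return Vec([other - x for x in self.data])

    def __repr__(self) -> str:
        return f"Vec({self.data})"


b = Vec([1, 2])
print(b - 1)   # Vec([0, 1])   via __sub__
print(1 - b)   # Vec([0, -1])  via __rsub__, matching the example above
```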
__rtruediv__(other: Tensor | Number) Tensor

Performs an elementwise division.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to be divided by this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = 6.0
b = tp.Tensor([2.0, 3.0])
output = a / b
Local Variables
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([3, 2], dtype=float32, loc=gpu:0, shape=(2,))
__sub__(other: Tensor | Number) Tensor

Performs an elementwise subtraction.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor to subtract from this one. It must be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([1, 2])
output = a - b
Local Variables
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))

>>> b
tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))

>>> output
tensor([1, 1], dtype=int32, loc=gpu:0, shape=(2,))
__truediv__(other: Tensor | Number) Tensor

Performs an elementwise division.

Parameters:
  • self (Tensor) – [dtype=T1] Input tensor.

  • other (Tensor | Number) – [dtype=T1] The tensor by which to divide this one. It should be broadcast-compatible.

Returns:

[dtype=T1] A new tensor with the broadcasted shape.

Return type:

Tensor

DATA TYPE CONSTRAINTS:
Example
a = tp.Tensor([4.0, 6.0])
b = tp.Tensor([2.0, 3.0])
output = a / b
Local Variables
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))

>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))

>>> output
tensor([2, 2], dtype=float32, loc=gpu:0, shape=(2,))
property shape: Tuple[int | DimensionSize]

Represents the shape of the tensor.

Parameters:

self – [dtype=T1] The input tensor.

Returns:

A sequence containing the shape of this tensor.

DATA TYPE CONSTRAINTS:
Example
input = tp.ones((8, 2))
shape = input.shape
Local Variables
>>> input
tensor(
    [[1, 1],
     [1, 1],
     [1, 1],
     [1, 1],
     [1, 1],
     [1, 1],
     [1, 1],
     [1, 1]], 
    dtype=float32, loc=gpu:0, shape=(8, 2))

>>> shape
(
    8,
    2,
)
tolist() List | Number

Returns the tensor as a nested list. If the tensor is a scalar, returns a Python number.

Returns:

The tensor represented as a nested list or a Python number.

Return type:

List | Number

Example: Ranked tensor
tensor = tp.ones((2, 2))
tensor_list = tensor.tolist()
Local Variables
>>> tensor_list
[[1.0, 1.0], [1.0, 1.0]]
Example: Scalar
tensor = tp.Tensor(2.0)
tensor_scalar = tensor.tolist()
Local Variables
>>> tensor_scalar
2.0
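Conceptually, the nested list returned by tolist() is the tensor's elements regrouped according to its shape, with a bare Python number for a rank-0 tensor. The sketch below shows that regrouping on a plain Python list; `to_nested_list` is a hypothetical helper and assumes the elements are available as a contiguous, row-major flat buffer, which the documentation does not specify.

```python
from math import prod
from typing import List, Sequence, Tuple, Union

Nested = Union[float, List["Nested"]]


def to_nested_list(flat: Sequence[float], shape: Tuple[int, ...]) -> Nested:
    """Regroup a row-major flat buffer into nested lists; rank 0 returns a number."""
    if len(flat) != prod(shape):
        raise ValueError("buffer size does not match shape")
    if not shape:                      # rank 0: return the single element
        return flat[0]
    if len(shape) == 1:
        return list(flat)
    step = prod(shape[1:])             # number of elements per slice along dimension 0
    return [to_nested_list(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


print(to_nested_list([1.0, 1.0, 1.0, 1.0], (2, 2)))  # [[1.0, 1.0], [1.0, 1.0]]
print(to_nested_list([2.0], ()))                     # 2.0
```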

See also: