Tensor
- class nvtripy.Tensor(data: Any, dtype: dtype | None = None, device: device | None = None, name: str | None = None, fetch_stack_info: bool = True)
Bases: object
A tensor is a multi-dimensional array that contains elements of a uniform data type.
- Parameters:
data (Any) – The data with which to initialize the tensor. For types that support the DLPack protocol, copying data is avoided if possible.
dtype (dtype | None) – The data type of the tensor.
device (device | None) – The device on which to allocate the tensor. If the provided data is not on this device, it will be copied. By default, the tensor will be allocated on the same device as the data argument.
name (str | None) – The name of the tensor. If provided, this must be a unique string.
fetch_stack_info (bool) – Whether to fetch stack information for the tensor. Stack information allows Tripy to generate much higher-quality error messages, at the cost of a small overhead when initializing the tensor.
Example
```python
import nvtripy as tp
tensor = tp.Tensor([1.0, 2.0, 3.0], dtype=tp.float32)
```
Local Variables
```
>>> tensor
tensor([1, 2, 3], dtype=float32, loc=cpu:0, shape=(3,))
```
- eval() → Tensor
Immediately evaluates this tensor. By default, tensors are evaluated lazily.
Note that an evaluated tensor will always reside in device memory.
- Returns:
The evaluated tensor.
- Return type:
Tensor
Example
```python
import time

import nvtripy as tp

start = time.perf_counter()
tensor = tp.ones((3, 3))
init_time = time.perf_counter()
tensor.eval()
eval_time = time.perf_counter()

print(f"Tensor init_time took: {(init_time - start) * 1000.0:.3f} ms")
print(f"Tensor evaluation took: {(eval_time - init_time) * 1000.0:.3f} ms")
```
Local Variables
```
>>> tensor
tensor(
    [[1, 1, 1],
     [1, 1, 1],
     [1, 1, 1]],
    dtype=float32, loc=gpu:0, shape=(3, 3))
```
Output
```
Tensor init_time took: 3.364 ms
Tensor evaluation took: 31.854 ms
```
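Lazy evaluation is why almost none of the cost appears at construction time in the timings above: the work is deferred until evaluation is requested. The toy class below models that idea in plain Python. `LazyValue` and `build_ones` are hypothetical names for illustration only; this is not how Tripy implements tensors, and the result caching is a property of this toy alone.

```python
from typing import Callable, List, Optional


class LazyValue:
    """Toy model of deferred evaluation: nothing is computed until eval() is called."""

    def __init__(self, compute: Callable[[], List[List[float]]]):
        self._compute = compute  # record the work instead of doing it now
        self._result: Optional[List[List[float]]] = None

    @property
    def evaluated(self) -> bool:
        return self._result is not None

    def eval(self) -> "LazyValue":
        # In this toy, the deferred work runs once and the result is cached.
        if self._result is None:
            self._result = self._compute()
        return self


calls = []


def build_ones() -> List[List[float]]:
    calls.append("compute")
    return [[1.0] * 3 for _ in range(3)]


value = LazyValue(build_ones)
print(value.evaluated, calls)  # False []: construction did no work
value.eval()
print(value.evaluated, calls)  # True ['compute']
```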
- __abs__() → Tensor
Computes the elementwise absolute value of the input tensor.
- Parameters:
self (Tensor) – [dtype=T1] The input tensor.
- Returns:
[dtype=T1] A new tensor of the same shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
input = tp.Tensor([-1, -2], dtype=tp.int32)
output = abs(input)
```
Local Variables
```
>>> input
tensor([-1, -2], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([1, 2], dtype=int32, loc=gpu:0, shape=(2,))
```
- __add__(other: Tensor | Number) → Tensor
Performs an elementwise sum.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = tp.Tensor([1, 2])
b = tp.Tensor([2, 3])
output = a + b
```
Local Variables
```
>>> a
tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([3, 5], dtype=int32, loc=gpu:0, shape=(2,))
```
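The elementwise operators on this page return "a new tensor with the broadcasted shape". As a rough illustration, the plain-Python helper below computes that result shape for two static shapes. It assumes Tripy follows the NumPy-style convention (align trailing dimensions, stretch size-1 dimensions, treat missing leading dimensions as 1); this page does not spell the rules out, and `broadcast_shape` is a hypothetical helper, not part of `nvtripy`.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Return the broadcasted shape of two static shapes (NumPy-style rules, assumed)."""
    result = []
    # Walk both shapes from the trailing dimension; a missing leading dim acts as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} are not broadcast-compatible")
        # A size-1 dimension stretches to match the other operand.
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


print(broadcast_shape((2,), (2,)))         # (2,), as in the __add__ example
print(broadcast_shape((3, 1), (4,)))       # (3, 4)
print(broadcast_shape((2, 1, 5), (3, 1)))  # (2, 3, 5)
```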
- __eq__(other: Tensor | Number) → Tensor
Performs an elementwise ‘equal’ comparison.
- Parameters:
- Returns:
[dtype=T2] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = tp.Tensor([2, 3])
b = tp.Tensor([2, 5])
output = b == a
```
Local Variables
```
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 5], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
```
- __floordiv__(other: Tensor | Number) → Tensor
Performs an elementwise floor division.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = tp.Tensor([4.0, 6.0])
b = tp.Tensor([3.0, 4.0])
output = a // b
```
Local Variables
```
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))
>>> b
tensor([3, 4], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([1, 1], dtype=float32, loc=gpu:0, shape=(2,))
```
- __ge__(other: Tensor | Number) → Tensor
Performs an elementwise ‘greater than or equal’ comparison.
- Parameters:
- Returns:
[dtype=T2] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = tp.Tensor([2, 3])
b = tp.Tensor([2, 1])
output = b >= a
```
Local Variables
```
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 1], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
```
- __getitem__(index: Tensor | slice | int | DimensionSize | Sequence[slice | int | DimensionSize]) → Tensor
Returns a tensor containing a slice of this tensor.
- Parameters:
self (Tensor) – [dtype=T1] Tensor that will be sliced.
index (Tensor | slice | int | DimensionSize | Sequence[slice | int | DimensionSize]) – The index or slice. If this is a Tensor, the operation is equivalent to calling gather() along the first dimension.
- Returns:
[dtype=T1] A tensor containing the slice of this tensor.
- Return type:
Tensor
Example: Indexing With Integers
```python
import nvtripy as tp
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
output = input[1]
```
Local Variables
```
>>> input
tensor(
    [[0, 1],
     [2, 3],
     [4, 5]],
    dtype=float32, loc=gpu:0, shape=(3, 2))
>>> output
tensor([2, 3], dtype=float32, loc=gpu:0, shape=(2,))
```
Example: Indexing With Slices
```python
import nvtripy as tp
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
output = input[1:]
```
Local Variables
```
>>> input
tensor(
    [[0, 1],
     [2, 3],
     [4, 5]],
    dtype=float32, loc=gpu:0, shape=(3, 2))
>>> output
tensor(
    [[2, 3],
     [4, 5]],
    dtype=float32, loc=gpu:0, shape=(2, 2))
```
Example: Reversing Data With Negative Step
```python
import nvtripy as tp
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
output = input[:, ::-1]
```
Local Variables
```
>>> input
tensor(
    [[0, 1],
     [2, 3],
     [4, 5]],
    dtype=float32, loc=gpu:0, shape=(3, 2))
>>> output
tensor(
    [[1, 0],
     [3, 2],
     [5, 4]],
    dtype=float32, loc=gpu:0, shape=(3, 2))
```
Example: Indexing With Tensors (Gather)
```python
import nvtripy as tp
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
index = tp.Tensor([2, 0], dtype=tp.int32)
output = input[index]
```
Local Variables
```
>>> input
tensor(
    [[0, 1],
     [2, 3],
     [4, 5]],
    dtype=float32, loc=gpu:0, shape=(3, 2))
>>> index
tensor([2, 0], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor(
    [[4, 5],
     [0, 1]],
    dtype=float32, loc=gpu:0, shape=(2, 2))
```
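To make the indexing semantics concrete, the sketch below reproduces the results of the four examples above using nested Python lists in place of tensors. `gather_first_dim` is a hypothetical helper that stands in for gather() along the first dimension; it models what the result contains, not how Tripy computes it.

```python
from typing import List, Sequence

# Nested-list stand-in for tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2)).
rows = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]

by_int = rows[1]  # input[1]: an integer index removes the first dimension
by_slice = rows[1:]  # input[1:]: a slice keeps the dimension
reversed_cols = [row[::-1] for row in rows]  # input[:, ::-1]: keep every row, reverse each one


def gather_first_dim(data: List[List[float]], index: Sequence[int]) -> List[List[float]]:
    """Pick whole rows of `data` in the order given by `index` (gather along dim 0)."""
    return [data[i] for i in index]


by_tensor = gather_first_dim(rows, [2, 0])  # input[tp.Tensor([2, 0])]

print(by_int)         # [2.0, 3.0]
print(by_slice)       # [[2.0, 3.0], [4.0, 5.0]]
print(reversed_cols)  # [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0]]
print(by_tensor)      # [[4.0, 5.0], [0.0, 1.0]]
```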
- __gt__(other: Tensor | Number) → Tensor
Performs an elementwise ‘greater than’ comparison.
- Parameters:
- Returns:
[dtype=T2] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = tp.Tensor([2, 3])
b = tp.Tensor([3, 1])
output = b > a
```
Local Variables
```
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([3, 1], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
```
- __invert__() → Tensor
Performs an elementwise logical NOT.
- Parameters:
self (Tensor) – [dtype=T1] The input tensor.
- Returns:
[dtype=T1] A new tensor.
- Return type:
Tensor
- DATA TYPE CONSTRAINTS:
T1: bool
Example
```python
import nvtripy as tp
a = tp.Tensor([True, False, False])
output = ~a
```
Local Variables
```
>>> a
tensor([True, False, False], dtype=bool, loc=cpu:0, shape=(3,))
>>> output
tensor([False, True, True], dtype=bool, loc=gpu:0, shape=(3,))
```
- __le__(other: Tensor | Number) → Tensor
Performs an elementwise ‘less than or equal’ comparison.
- Parameters:
- Returns:
[dtype=T2] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = tp.Tensor([2, 3])
b = tp.Tensor([2, 5])
output = b <= a
```
Local Variables
```
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 5], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
```
- __lt__(other: Tensor | Number) → Tensor
Performs an elementwise ‘less than’ comparison.
- Parameters:
- Returns:
[dtype=T2] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = tp.Tensor([2, 3])
b = tp.Tensor([1, 5])
output = b < a
```
Local Variables
```
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([1, 5], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
```
- __matmul__(other: Tensor) → Tensor
Performs matrix multiplication between two tensors.
- If both tensors are 1D, a dot product is performed. The output is a scalar.
- If either argument, but not both, is 1D, matrix-vector multiplication is performed:
  - For inputs of shape \((M, N)\) and \((N,)\), the output will have shape \((M,)\).
  - For inputs of shape \((N,)\) and \((N, K)\), the output will have shape \((K,)\).
- If both tensors are 2D, matrix-matrix multiplication is performed:
  - For inputs of shape \((M, N)\) and \((N, K)\), the output will have shape \((M, K)\).
- If either tensor has more than 2 dimensions, it is treated as a stack of matrices:
  - If the ranks differ for tensors with 2 or more dimensions, dimensions of size 1 are prepended to the lower-rank tensor until the ranks match. All but the last two dimensions (the batch dimensions) are broadcasted if required.
- Parameters:
- Returns:
[dtype=T1] A new tensor.
- Return type:
Tensor
Example: Dot Product
```python
import nvtripy as tp
a = tp.iota((3,), dtype=tp.float32)
b = tp.iota((3,), dtype=tp.float32)

output = a @ b
```
Local Variables
```
>>> a
tensor([0, 1, 2], dtype=float32, loc=gpu:0, shape=(3,))
>>> b
tensor([0, 1, 2], dtype=float32, loc=gpu:0, shape=(3,))
>>> output
tensor(5, dtype=float32, loc=gpu:0, shape=())
```
Example: Matrix-Vector Multiplication
```python
import nvtripy as tp
a = tp.iota((3,), dtype=tp.float32)
b = tp.iota((3, 2), dtype=tp.float32)

output = a @ b
```
Local Variables
```
>>> a
tensor([0, 1, 2], dtype=float32, loc=gpu:0, shape=(3,))
>>> b
tensor(
    [[0, 0],
     [1, 1],
     [2, 2]],
    dtype=float32, loc=gpu:0, shape=(3, 2))
>>> output
tensor([5, 5], dtype=float32, loc=gpu:0, shape=(2,))
```
Example: Matrix-Matrix Multiplication
```python
import nvtripy as tp
a = tp.iota((2, 3), dtype=tp.float32)
b = tp.iota((3, 2), dtype=tp.float32)

output = a @ b
```
Local Variables
```
>>> a
tensor(
    [[0, 0, 0],
     [1, 1, 1]],
    dtype=float32, loc=gpu:0, shape=(2, 3))
>>> b
tensor(
    [[0, 0],
     [1, 1],
     [2, 2]],
    dtype=float32, loc=gpu:0, shape=(3, 2))
>>> output
tensor(
    [[0, 0],
     [3, 3]],
    dtype=float32, loc=gpu:0, shape=(2, 2))
```
Example: Batched Matrix Multiplication
```python
import nvtripy as tp
a = tp.iota((1, 2, 2, 2), dtype=tp.float32, dim=-1)
b = tp.iota((1, 2, 2), dtype=tp.float32, dim=-2)

output = a @ b
```
Local Variables
```
>>> a
tensor(
    [[[[0, 1],
       [0, 1]],
      [[0, 1],
       [0, 1]]]],
    dtype=float32, loc=gpu:0, shape=(1, 2, 2, 2))
>>> b
tensor(
    [[[0, 0],
      [1, 1]]],
    dtype=float32, loc=gpu:0, shape=(1, 2, 2))
>>> output
tensor(
    [[[[1, 1],
       [1, 1]],
      [[1, 1],
       [1, 1]]]],
    dtype=float32, loc=gpu:0, shape=(1, 2, 2, 2))
```
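The shape rules listed for __matmul__ can be restated as a small plain-Python function, which is handy for checking shapes before building a graph. This is a sketch, not Tripy's implementation: `matmul_shape` and `broadcast_shape` are hypothetical helpers, batch broadcasting assumes NumPy-style rules, and combining a 1D operand with a batched (rank > 2) operand follows NumPy's convention rather than anything stated above.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """NumPy-style broadcast of two batch shapes (assumed to match Tripy's rules)."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"cannot broadcast {tuple(a)} with {tuple(b)}")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def matmul_shape(a_shape: Sequence[int], b_shape: Sequence[int]) -> Tuple[int, ...]:
    """Output shape of a @ b following the rules listed for __matmul__."""
    a, b = tuple(a_shape), tuple(b_shape)
    if not a or not b:
        raise ValueError("matmul requires inputs of rank >= 1")
    if len(a) == 1 and len(b) == 1:
        if a[0] != b[0]:
            raise ValueError(f"dot product length mismatch: {a} vs {b}")
        return ()  # dot product -> scalar
    # Promote 1D operands to matrices and remember which dimension to drop afterward.
    drop_m, drop_k = len(a) == 1, len(b) == 1
    if drop_m:
        a = (1,) + a  # (N,) -> (1, N)
    if drop_k:
        b = b + (1,)  # (N,) -> (N, 1)
    m, n = a[-2], a[-1]
    n2, k = b[-2], b[-1]
    if n != n2:
        raise ValueError(f"inner dimensions differ: {n} vs {n2}")
    batch = broadcast_shape(a[:-2], b[:-2])  # leading (batch) dimensions broadcast
    return batch + (() if drop_m else (m,)) + (() if drop_k else (k,))


print(matmul_shape((3,), (3,)))                # ()
print(matmul_shape((3,), (3, 2)))              # (2,)
print(matmul_shape((2, 3), (3, 2)))            # (2, 2)
print(matmul_shape((1, 2, 2, 2), (1, 2, 2)))   # (1, 2, 2, 2)
```

The printed shapes match the four examples above.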
- __mod__(other: Tensor | Number) → Tensor
Performs a modulo operation, which computes the remainder of a division.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = tp.Tensor([4.0, 6.0])
b = tp.Tensor([3.0, 4.0])
output = a % b
```
Local Variables
```
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))
>>> b
tensor([3, 4], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([1, 2], dtype=float32, loc=gpu:0, shape=(2,))
```
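The __floordiv__ and __mod__ examples use the same operands, so their outputs can be cross-checked against the identity a = (a // b) · b + (a % b). The sketch below does that arithmetic with plain Python floats. It only confirms the positive-operand values shown in the examples; Python's sign conventions for negative operands are not claimed to match Tripy's.

```python
# Operands shared by the __floordiv__ and __mod__ examples.
a = [4.0, 6.0]
b = [3.0, 4.0]

floor_q = [x // y for x, y in zip(a, b)]  # quotient rounded down
rem = [x % y for x, y in zip(a, b)]       # remainder of the division

# Rebuild the dividends from quotient and remainder.
reconstructed = [q * y + r for q, y, r in zip(floor_q, b, rem)]

print(floor_q)        # [1.0, 1.0], matching the __floordiv__ output
print(rem)            # [1.0, 2.0], matching the __mod__ output
print(reconstructed)  # [4.0, 6.0]
```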
- __mul__(other: Tensor | Number) → Tensor
Performs an elementwise multiplication.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = tp.Tensor([1.0, 2.0])
b = tp.Tensor([2.0, 3.0])
output = a * b
```
Local Variables
```
>>> a
tensor([1, 2], dtype=float32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([2, 6], dtype=float32, loc=gpu:0, shape=(2,))
```
- __ne__(other: Tensor | Number) → Tensor
Performs an elementwise ‘not equal’ comparison.
- Parameters:
- Returns:
[dtype=T2] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = tp.Tensor([2, 3])
b = tp.Tensor([1, 3])
output = b != a
```
Local Variables
```
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([1, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
```
- __neg__() → Tensor
Computes the elementwise negation of the input tensor.
- Parameters:
self (Tensor) – [dtype=T1] The input tensor.
- Returns:
[dtype=T1] A new tensor of the same shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
input = tp.Tensor([-1, -2], dtype=tp.int32)
output = -input
```
Local Variables
```
>>> input
tensor([-1, -2], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([1, 2], dtype=int32, loc=gpu:0, shape=(2,))
```
- __or__(other: Tensor) → Tensor
Performs an elementwise logical OR.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
- DATA TYPE CONSTRAINTS:
T1: bool
Example
```python
import nvtripy as tp
a = tp.Tensor([True, False, False])
b = tp.Tensor([False, True, False])
output = a | b
```
Local Variables
```
>>> a
tensor([True, False, False], dtype=bool, loc=cpu:0, shape=(3,))
>>> b
tensor([False, True, False], dtype=bool, loc=cpu:0, shape=(3,))
>>> output
tensor([True, True, False], dtype=bool, loc=gpu:0, shape=(3,))
```
- __pow__(other: Tensor | Number) → Tensor
Performs an elementwise exponentiation.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = tp.Tensor([1.0, 2.0])
b = tp.Tensor([2.0, 3.0])
output = a**b
```
Local Variables
```
>>> a
tensor([1, 2], dtype=float32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([1, 8], dtype=float32, loc=gpu:0, shape=(2,))
```
- __radd__(other: Tensor | Number) → Tensor
Performs an elementwise sum.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = 1  # a Python number on the left-hand side invokes b.__radd__(a)
b = tp.Tensor([2, 3])
output = a + b
```
Local Variables
```
>>> b
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([3, 5], dtype=int32, loc=gpu:0, shape=(2,))
```
- __rfloordiv__(other: Tensor | Number) → Tensor
Performs an elementwise floor division.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = 2
b = tp.Tensor([2.0, 3.0])
output = a // b
```
Local Variables
```
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([1, 0], dtype=float32, loc=gpu:0, shape=(2,))
```
- __rmod__(other: Tensor | Number) → Tensor
Performs a modulo operation, which computes the remainder of a division.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = tp.Tensor([4.0, 6.0])
output = 2 % a
```
Local Variables
```
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([2, 2], dtype=float32, loc=gpu:0, shape=(2,))
```
- __rmul__(other: Tensor | Number) → Tensor
Performs an elementwise multiplication.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = 2.0  # a Python number on the left-hand side invokes b.__rmul__(a)
b = tp.Tensor([1.0, 3.0])
output = a * b
```
Local Variables
```
>>> b
tensor([1, 3], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([2, 6], dtype=float32, loc=gpu:0, shape=(2,))
```
- __rpow__(other: Tensor | Number) → Tensor
Performs an elementwise exponentiation.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = 2.0
b = tp.Tensor([2.0, 3.0])
output = a**b
```
Local Variables
```
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([4, 8], dtype=float32, loc=gpu:0, shape=(2,))
```
- __rsub__(other: Tensor | Number) → Tensor
Performs an elementwise subtraction.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = 1
b = tp.Tensor([1, 2])
output = a - b
```
Local Variables
```
>>> b
tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([0, -1], dtype=int32, loc=gpu:0, shape=(2,))
```
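The reflected (__r*__) methods exist because of Python's operator dispatch: for `1 - b`, Python first tries `int.__sub__`, which does not understand the right-hand operand and returns `NotImplemented`, and then calls `b.__rsub__(1)`. The reflected method must therefore compute other − self, preserving operand order, which is why the __rsub__ example yields [0, -1] rather than [0, 1]. The toy `Box` class below is a hypothetical plain-Python stand-in (not a Tripy type) that demonstrates the protocol with scalar operands only.

```python
from typing import Iterable


class Box:
    """Toy container showing how Python dispatches forward and reflected operators."""

    def __init__(self, values: Iterable[int]):
        self.values = list(values)

    def __sub__(self, other: int) -> "Box":
        # Box - other
        return Box(v - other for v in self.values)

    def __rsub__(self, other: int) -> "Box":
        # other - Box: called after int.__sub__ returns NotImplemented
        return Box(other - v for v in self.values)

    def __radd__(self, other: int) -> "Box":
        # other + Box
        return Box(other + v for v in self.values)


b = Box([1, 2])
print((1 - b).values)  # [0, -1], operand order preserved as in the __rsub__ example
print((b - 1).values)  # [0, 1]
print((1 + b).values)  # [2, 3]
```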
- __rtruediv__(other: Tensor | Number) → Tensor
Performs an elementwise division.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = 6.0
b = tp.Tensor([2.0, 3.0])
output = a / b
```
Local Variables
```
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([3, 2], dtype=float32, loc=gpu:0, shape=(2,))
```
- __sub__(other: Tensor | Number) → Tensor
Performs an elementwise subtraction.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = tp.Tensor([2, 3])
b = tp.Tensor([1, 2])
output = a - b
```
Local Variables
```
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([1, 1], dtype=int32, loc=gpu:0, shape=(2,))
```
- __truediv__(other: Tensor | Number) → Tensor
Performs an elementwise division.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
```python
import nvtripy as tp
a = tp.Tensor([4.0, 6.0])
b = tp.Tensor([2.0, 3.0])
output = a / b
```
Local Variables
```
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([2, 2], dtype=float32, loc=gpu:0, shape=(2,))
```
- property shape: Tuple[int | DimensionSize]
Represents the shape of the tensor.
- Parameters:
self – [dtype=T1] The input tensor.
- Returns:
A sequence containing the shape of this tensor.
Example
```python
import nvtripy as tp
input = tp.ones((8, 2))
shape = input.shape
```
Local Variables
```
>>> input
tensor(
    [[1, 1],
     [1, 1],
     [1, 1],
     [1, 1],
     [1, 1],
     [1, 1],
     [1, 1],
     [1, 1]],
    dtype=float32, loc=gpu:0, shape=(8, 2))
>>> shape
(
    8,
    2,
)
```
- tolist() → List | Number
Returns the tensor as a nested list. If the tensor is a scalar, returns a Python number.
- Returns:
The tensor represented as a nested list or a Python number.
- Return type:
List | Number
Example: Ranked tensor
```python
import nvtripy as tp
tensor = tp.ones((2, 2))
tensor_list = tensor.tolist()
```
Local Variables
```
>>> tensor_list
[[1.0, 1.0], [1.0, 1.0]]
```
Example: Scalar
```python
import nvtripy as tp
tensor = tp.Tensor(2.0, dtype=tp.float32)
tensor_scalar = tensor.tolist()
```
Local Variables
```
>>> tensor_scalar
2.0
```
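The output of tolist() has one level of list nesting per tensor dimension, and a rank-0 tensor yields a bare Python number. The sketch below models that output structure for row-major data in plain Python. `to_nested_list` is a hypothetical helper that only illustrates the shape of the result; it is not Tripy's implementation.

```python
from typing import List, Sequence, Union

Nested = Union[float, List["Nested"]]


def to_nested_list(flat: Sequence[float], shape: Sequence[int]) -> Nested:
    """Reshape row-major flat data into nested lists; an empty shape yields a bare number."""
    if len(shape) == 0:
        return flat[0]  # rank 0: a plain Python number, not a list
    if shape[0] == 0:
        return []  # an empty leading dimension produces an empty list
    step = len(flat) // shape[0]  # number of elements in each sub-tensor along dim 0
    return [to_nested_list(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


print(to_nested_list([1.0] * 4, (2, 2)))  # [[1.0, 1.0], [1.0, 1.0]], like the ranked example
print(to_nested_list([2.0], ()))          # 2.0, like the scalar example
```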
See also: