Tensor
- class nvtripy.Tensor(data: Any, dtype: dtype | None = None, device: device | None = None, name: str | None = None)
Bases: object
A tensor is a multi-dimensional array that contains elements of a uniform data type.
- Parameters:
data (Any) – The data with which to initialize the tensor. For types that support the DLPack protocol, copying data is avoided if possible.
dtype (dtype | None) – The data type of the tensor. This parameter is only allowed when data is an empty list.
device (device | None) – The device on which to allocate the tensor. This parameter is only allowed when data is a Python number or list.
name (str | None) – The name of the tensor. If provided, this must be a unique string.
Example
import nvtripy as tp

tensor = tp.Tensor([1.0, 2.0, 3.0])
Local Variables:
>>> tensor
tensor([1, 2, 3], dtype=float32, loc=cpu:0, shape=(3,))
- eval() → Tensor
Immediately evaluates this tensor. By default, tensors are evaluated lazily.
- Returns:
The evaluated tensor.
- Return type:
Tensor
Example
import time

start = time.perf_counter()
tensor = tp.ones((3, 3))
init_time = time.perf_counter()
tensor.eval()
eval_time = time.perf_counter()

print(f"Tensor init_time took: {(init_time - start) * 1000.0:.3f} ms")
print(f"Tensor evaluation took: {(eval_time - init_time) * 1000.0:.3f} ms")
Local Variables:
>>> tensor
tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]], dtype=float32, loc=gpu:0, shape=(3, 3))
Output:
Tensor init_time took: 1.564 ms
Tensor evaluation took: 68.178 ms
- __abs__() → Tensor
Computes the elementwise absolute value of the input tensor.
- Parameters:
self (Tensor) – [dtype=T1] The input tensor.
- Returns:
[dtype=T1] A new tensor of the same shape.
- Return type:
Tensor
Example
input = tp.Tensor([-1, -2])
output = abs(input)
Local Variables:
>>> input
tensor([-1, -2], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([1, 2], dtype=int32, loc=gpu:0, shape=(2,))
- __add__(other: Tensor | Number) → Tensor
Performs an elementwise sum.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([1, 2])
b = tp.Tensor([2, 3])
output = a + b
Local Variables:
>>> a
tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([3, 5], dtype=int32, loc=gpu:0, shape=(2,))
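The elementwise operators on this page return "a new tensor with the broadcasted shape". The plain-Python sketch below shows how that shape can be computed. It assumes tripy follows the NumPy-style rule: align the shapes from the right, and require each pair of dimensions to be equal or to contain a 1. The helper `broadcast_shape` is hypothetical and is not part of the nvtripy API.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: broadcast two shapes with the NumPy-style rule."""
    result = []
    # Walk both shapes from the last dimension; a missing leading dimension acts like 1.
    for dim_a, dim_b in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if dim_a == dim_b or dim_b == 1:
            result.append(dim_a)
        elif dim_a == 1:
            result.append(dim_b)
        else:
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} cannot be broadcast together")
    return tuple(reversed(result))


print(broadcast_shape((2,), (2,)))    # (2,), as in the example above
print(broadcast_shape((3, 1), (4,)))  # (3, 4)
```

Under this rule, a Python number behaves like a rank-0 operand, so it broadcasts against any shape.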
- __eq__(other: Tensor | Number) → Tensor
Performs an elementwise ‘equal’ comparison.
- Parameters:
- Returns:
[dtype=T2] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([2, 5])
output = b == a
Local Variables:
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 5], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
- __floordiv__(other: Tensor | Number) → Tensor
Performs an elementwise floor division.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([4.0, 6.0])
b = tp.Tensor([3.0, 4.0])
output = a // b
Local Variables:
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))
>>> b
tensor([3, 4], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([1, 1], dtype=float32, loc=gpu:0, shape=(2,))
- __ge__(other: Tensor | Number) → Tensor
Performs an elementwise ‘greater than or equal’ comparison.
- Parameters:
- Returns:
[dtype=T2] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([2, 1])
output = b >= a
Local Variables:
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 1], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
- __getitem__(index: Tensor | slice | int | DimensionSize | ellipsis | None | Sequence[slice | int | DimensionSize | ellipsis | None]) → Tensor
Returns a tensor containing a slice of this tensor.
- Parameters:
self (Tensor) – [dtype=T1] Tensor that will be sliced.
index (Tensor | slice | int | DimensionSize | ellipsis | None | Sequence[slice | int | DimensionSize | ellipsis | None]) – The index or slice. If this is a Tensor, the operation is equivalent to calling gather() along the first dimension. If this is None, a new dimension of size 1 will be inserted at that position.
- Returns:
[dtype=T1] A tensor containing the slice of this tensor.
- Return type:
Tensor
Example: Indexing With Integers
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
output = input[1]
Local Variables:
>>> input
tensor([[0, 1], [2, 3], [4, 5]], dtype=float32, loc=gpu:0, shape=(3, 2))
>>> output
tensor([2, 3], dtype=float32, loc=gpu:0, shape=(2,))
Example: Indexing With Slices
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
output = input[1:]
Local Variables:
>>> input
tensor([[0, 1], [2, 3], [4, 5]], dtype=float32, loc=gpu:0, shape=(3, 2))
>>> output
tensor([[2, 3], [4, 5]], dtype=float32, loc=gpu:0, shape=(2, 2))
Example: Indexing With Ellipsis
input = tp.reshape(tp.arange(6, dtype=tp.float32), (1, 3, 2))
output = input[..., 1:]
Local Variables:
>>> input
tensor([[[0, 1], [2, 3], [4, 5]]], dtype=float32, loc=gpu:0, shape=(1, 3, 2))
>>> output
tensor([[[1], [3], [5]]], dtype=float32, loc=gpu:0, shape=(1, 3, 1))
Example: Reversing Data With Negative Step
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
output = input[:, ::-1]
Local Variables:
>>> input
tensor([[0, 1], [2, 3], [4, 5]], dtype=float32, loc=gpu:0, shape=(3, 2))
>>> output
tensor([[1, 0], [3, 2], [5, 4]], dtype=float32, loc=gpu:0, shape=(3, 2))
Example: Indexing With Tensors (Gather)
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
index = tp.Tensor([2, 0])
output = input[index]
Local Variables:
>>> input
tensor([[0, 1], [2, 3], [4, 5]], dtype=float32, loc=gpu:0, shape=(3, 2))
>>> index
tensor([2, 0], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([[4, 5], [0, 1]], dtype=float32, loc=gpu:0, shape=(2, 2))
Example: Adding New Dimensions With None
input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
output = input[None, :, None]
Local Variables:
>>> input
tensor([[0, 1], [2, 3], [4, 5]], dtype=float32, loc=gpu:0, shape=(3, 2))
>>> output
tensor([[[[0, 1]], [[2, 3]], [[4, 5]]]], dtype=float32, loc=gpu:0, shape=(1, 3, 1, 2))
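To make the indexing examples above concrete, here is a plain-Python sketch that computes the output shape of an index built from integers, slices, `...`, and `None`. It assumes these index kinds behave as in Python/NumPy basic indexing: an integer removes a dimension, a slice keeps the dimension with the sliced length, `...` expands to as many full slices as needed, and `None` inserts a dimension of size 1. It does not cover Tensor (gather) or DimensionSize indices. `expand_ellipsis` and `indexed_shape` are hypothetical names, not nvtripy functions.

```python
from typing import Any, List, Sequence, Tuple


def expand_ellipsis(index: Sequence[Any], rank: int) -> List[Any]:
    """Replace `...` with full slices so that every input dimension is indexed."""
    if sum(1 for item in index if item is Ellipsis) > 1:
        raise IndexError("an index can contain at most one ellipsis")
    # Integers and slices consume an input dimension; None and ... do not.
    consumed = sum(1 for item in index if item is not None and item is not Ellipsis)
    if consumed > rank:
        raise IndexError("too many indices for tensor")
    fill = [slice(None)] * (rank - consumed)
    expanded: List[Any] = []
    for item in index:
        expanded.extend(fill if item is Ellipsis else [item])
    if not any(item is Ellipsis for item in index):
        expanded.extend(fill)  # unindexed trailing dimensions are taken whole
    return expanded


def indexed_shape(shape: Sequence[int], index: Sequence[Any]) -> Tuple[int, ...]:
    """Hypothetical helper: output shape of basic indexing on a tensor of `shape`."""
    dims = iter(shape)
    out = []
    for item in expand_ellipsis(index, len(shape)):
        if item is None:
            out.append(1)  # new singleton dimension
        elif isinstance(item, slice):
            size = next(dims)
            out.append(len(range(*item.indices(size))))  # length after slicing
        else:
            next(dims)  # an integer index removes the dimension
    return tuple(out)


print(indexed_shape((3, 2), (1,)))                           # (2,)
print(indexed_shape((1, 3, 2), (Ellipsis, slice(1, None))))  # (1, 3, 1)
print(indexed_shape((3, 2), (None, slice(None), None)))      # (1, 3, 1, 2)
```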
- __gt__(other: Tensor | Number) → Tensor
Performs an elementwise ‘greater than’ comparison.
- Parameters:
- Returns:
[dtype=T2] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([3, 1])
output = b > a
Local Variables:
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([3, 1], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
- __invert__() → Tensor
Performs an elementwise logical NOT.
- Parameters:
self (Tensor) – [dtype=T1] The input tensor.
- Returns:
[dtype=T1] A new tensor.
- Return type:
Tensor
- DATA TYPE CONSTRAINTS:
T1:
bool
Example
a = tp.Tensor([True, False, False])
output = ~a
Local Variables:
>>> a
tensor([True, False, False], dtype=bool, loc=cpu:0, shape=(3,))
>>> output
tensor([False, True, True], dtype=bool, loc=gpu:0, shape=(3,))
- __le__(other: Tensor | Number) → Tensor
Performs an elementwise ‘less than or equal’ comparison.
- Parameters:
- Returns:
[dtype=T2] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([2, 5])
output = b <= a
Local Variables:
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 5], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
- __lt__(other: Tensor | Number) → Tensor
Performs an elementwise ‘less than’ comparison.
- Parameters:
- Returns:
[dtype=T2] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([1, 5])
output = b < a
Local Variables:
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([1, 5], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
- __matmul__(other: Tensor) → Tensor
Performs matrix multiplication between two tensors.
If both tensors are 1D, a dot product is performed. The output is a scalar.
- If either argument, but not both, is 1D, matrix-vector multiplication is performed:
For inputs of shape \((M, N)\) and \((N,)\), the output will have shape \((M,)\).
For inputs of shape \((N,)\) and \((N, K)\), the output will have shape \((K,)\).
- If both tensors are 2D, matrix-matrix multiplication is performed.
For inputs of shape \((M, N)\) and \((N, K)\), the output will have shape \((M, K)\).
- If either tensor has more than 2 dimensions, it is treated as a stack of matrices.
If the ranks differ for tensors with 2 or more dimensions, size-1 dimensions are prepended to the lower-rank tensor until the ranks match. The first \(N-2\) dimensions, where \(N\) is the resulting rank, are then broadcast if required.
- Parameters:
- Returns:
[dtype=T1] A new tensor.
- Return type:
Tensor
Example: Dot Product
a = tp.iota((3,), dtype=tp.float32)
b = tp.iota((3,), dtype=tp.float32)

output = a @ b
Local Variables:
>>> a
tensor([0, 1, 2], dtype=float32, loc=gpu:0, shape=(3,))
>>> b
tensor([0, 1, 2], dtype=float32, loc=gpu:0, shape=(3,))
>>> output
tensor(5, dtype=float32, loc=gpu:0, shape=())
Example: Matrix-Vector Multiplication
a = tp.iota((3,), dtype=tp.float32)
b = tp.iota((3, 2), dtype=tp.float32)

output = a @ b
Local Variables:
>>> a
tensor([0, 1, 2], dtype=float32, loc=gpu:0, shape=(3,))
>>> b
tensor([[0, 0], [1, 1], [2, 2]], dtype=float32, loc=gpu:0, shape=(3, 2))
>>> output
tensor([5, 5], dtype=float32, loc=gpu:0, shape=(2,))
Example: Matrix-Matrix Multiplication
a = tp.iota((2, 3), dtype=tp.float32)
b = tp.iota((3, 2), dtype=tp.float32)

output = a @ b
Local Variables:
>>> a
tensor([[0, 0, 0], [1, 1, 1]], dtype=float32, loc=gpu:0, shape=(2, 3))
>>> b
tensor([[0, 0], [1, 1], [2, 2]], dtype=float32, loc=gpu:0, shape=(3, 2))
>>> output
tensor([[0, 0], [3, 3]], dtype=float32, loc=gpu:0, shape=(2, 2))
Example: Batched Matrix Multiplication
a = tp.iota((1, 2, 2, 2), dtype=tp.float32, dim=-1)
b = tp.iota((1, 2, 2), dtype=tp.float32, dim=-2)

output = a @ b
Local Variables:
>>> a
tensor([[[[0, 1], [0, 1]], [[0, 1], [0, 1]]]], dtype=float32, loc=gpu:0, shape=(1, 2, 2, 2))
>>> b
tensor([[[0, 0], [1, 1]]], dtype=float32, loc=gpu:0, shape=(1, 2, 2))
>>> output
tensor([[[[1, 1], [1, 1]], [[1, 1], [1, 1]]]], dtype=float32, loc=gpu:0, shape=(1, 2, 2, 2))
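The shape rules listed for __matmul__ can be summarized in a short plain-Python sketch. `matmul_shape` is a hypothetical helper written for illustration; it is not part of nvtripy. It deliberately covers only the cases described above: 1D with 1D, 1D with 2D, and operands that both have at least 2 dimensions. In the last case, batch dimensions are padded with leading 1s and then broadcast.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def matmul_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: output shape of `a @ b` for the documented cases."""
    a, b = tuple(a), tuple(b)
    if len(a) == 1 and len(b) == 1:  # dot product -> scalar
        if a[0] != b[0]:
            raise ValueError("dot product requires vectors of equal length")
        return ()
    if len(a) == 2 and len(b) == 1:  # (M, N) @ (N,) -> (M,)
        if a[1] != b[0]:
            raise ValueError("inner dimensions do not match")
        return (a[0],)
    if len(a) == 1 and len(b) == 2:  # (N,) @ (N, K) -> (K,)
        if a[0] != b[0]:
            raise ValueError("inner dimensions do not match")
        return (b[1],)
    if len(a) < 2 or len(b) < 2:
        raise NotImplementedError("case not covered by the rules above")
    if a[-1] != b[-2]:  # (..., M, N) @ (..., N, K)
        raise ValueError("inner dimensions do not match")
    batch = []
    # Align batch dimensions from the right, treating missing leading ones as 1.
    for dim_a, dim_b in zip_longest(reversed(a[:-2]), reversed(b[:-2]), fillvalue=1):
        if dim_a != dim_b and 1 not in (dim_a, dim_b):
            raise ValueError("batch dimensions cannot be broadcast")
        batch.append(dim_a if dim_b == 1 else dim_b)
    return tuple(reversed(batch)) + (a[-2], b[-1])


print(matmul_shape((3,), (3,)))               # ()
print(matmul_shape((3,), (3, 2)))             # (2,)
print(matmul_shape((1, 2, 2, 2), (1, 2, 2)))  # (1, 2, 2, 2)
```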
- __mod__(other: Tensor | Number) → Tensor
Performs a modulo operation, which computes the remainder of a division.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([4.0, 6.0])
b = tp.Tensor([3.0, 4.0])
output = a % b
Local Variables:
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))
>>> b
tensor([3, 4], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([1, 2], dtype=float32, loc=gpu:0, shape=(2,))
- __mul__(other: Tensor | Number) → Tensor
Performs an elementwise multiplication.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([1.0, 2.0])
b = tp.Tensor([2.0, 3.0])
output = a * b
Local Variables:
>>> a
tensor([1, 2], dtype=float32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([2, 6], dtype=float32, loc=gpu:0, shape=(2,))
- __ne__(other: Tensor | Number) → Tensor
Performs an elementwise ‘not equal’ comparison.
- Parameters:
- Returns:
[dtype=T2] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([1, 3])
output = b != a
Local Variables:
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([1, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
- __neg__() → Tensor
Computes the elementwise negation of the input tensor.
- Parameters:
self (Tensor) – [dtype=T1] The input tensor.
- Returns:
[dtype=T1] A new tensor of the same shape.
- Return type:
Tensor
Example
input = tp.Tensor([-1, -2])
output = -input
Local Variables:
>>> input
tensor([-1, -2], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([1, 2], dtype=int32, loc=gpu:0, shape=(2,))
- __or__(other: Tensor) → Tensor
Performs an elementwise logical OR.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
- DATA TYPE CONSTRAINTS:
T1:
bool
Example
a = tp.Tensor([True, False, False])
b = tp.Tensor([False, True, False])
output = a | b
Local Variables:
>>> a
tensor([True, False, False], dtype=bool, loc=cpu:0, shape=(3,))
>>> b
tensor([False, True, False], dtype=bool, loc=cpu:0, shape=(3,))
>>> output
tensor([True, True, False], dtype=bool, loc=gpu:0, shape=(3,))
- __pow__(other: Tensor | Number) → Tensor
Performs an elementwise exponentiation.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([1.0, 2.0])
b = tp.Tensor([2.0, 3.0])
output = a**b
Local Variables:
>>> a
tensor([1, 2], dtype=float32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([1, 8], dtype=float32, loc=gpu:0, shape=(2,))
- __radd__(other: Tensor | Number) → Tensor
Performs an elementwise sum.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([1, 2])
b = tp.Tensor([2, 3])
output = a + b
Local Variables:
>>> a
tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([3, 5], dtype=int32, loc=gpu:0, shape=(2,))
- __rfloordiv__(other: Tensor | Number) → Tensor
Performs an elementwise floor division.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = 2
b = tp.Tensor([2.0, 3.0])
output = a // b
Local Variables:
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([1, 0], dtype=float32, loc=gpu:0, shape=(2,))
- __rmod__(other: Tensor | Number) → Tensor
Performs a modulo operation, which computes the remainder of a division.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([4.0, 6.0])
output = 2 % a
Local Variables:
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([2, 2], dtype=float32, loc=gpu:0, shape=(2,))
- __rmul__(other: Tensor | Number) → Tensor
Performs an elementwise multiplication.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([1.0, 2.0])
b = tp.Tensor([2.0, 3.0])
output = a * b
Local Variables:
>>> a
tensor([1, 2], dtype=float32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([2, 6], dtype=float32, loc=gpu:0, shape=(2,))
- __rpow__(other: Tensor | Number) → Tensor
Performs an elementwise exponentiation.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = 2.0
b = tp.Tensor([2.0, 3.0])
output = a**b
Local Variables:
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([4, 8], dtype=float32, loc=gpu:0, shape=(2,))
- __rsub__(other: Tensor | Number) → Tensor
Performs an elementwise subtraction.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = 1
b = tp.Tensor([1, 2])
output = a - b
Local Variables:
>>> b
tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([0, -1], dtype=int32, loc=gpu:0, shape=(2,))
- __rtruediv__(other: Tensor | Number) → Tensor
Performs an elementwise division.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = 6.0
b = tp.Tensor([2.0, 3.0])
output = a / b
Local Variables:
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([3, 2], dtype=float32, loc=gpu:0, shape=(2,))
- __sub__(other: Tensor | Number) → Tensor
Performs an elementwise subtraction.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([2, 3])
b = tp.Tensor([1, 2])
output = a - b
Local Variables:
>>> a
tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
>>> b
tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([1, 1], dtype=int32, loc=gpu:0, shape=(2,))
- __truediv__(other: Tensor | Number) → Tensor
Performs an elementwise division.
- Parameters:
- Returns:
[dtype=T1] A new tensor with the broadcasted shape.
- Return type:
Tensor
Example
a = tp.Tensor([4.0, 6.0])
b = tp.Tensor([2.0, 3.0])
output = a / b
Local Variables:
>>> a
tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))
>>> b
tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
>>> output
tensor([2, 2], dtype=float32, loc=gpu:0, shape=(2,))
- cast(dtype: dtype) → Tensor
Returns a tensor with the contents of the input tensor cast to the specified data type.
For casts into quantized data types (int4 and float8), this performs a per-tensor quantization into that data type with scale 1.0; for casts from those data types, it performs a per-tensor dequantization with scale 1.0. Direct use of quantize() and dequantize() allows finer control over these parameters.
- Parameters:
- Returns:
[dtype=T2] A tensor containing the cast values.
- Return type:
Tensor
- DATA TYPE CONSTRAINTS:
- UNSUPPORTED DATA TYPE COMBINATIONS:
Example
input = tp.Tensor([1, 2])
output = tp.cast(input, tp.float32)
Local Variables:
>>> input
tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))
>>> output
tensor([1, 2], dtype=float32, loc=gpu:0, shape=(2,))
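The following plain-Python sketch gives a rough mental model of a per-tensor int4 quantization "with scale 1.0". It quantizes a list of floats using one shared scale. Several details are assumptions for illustration and are not stated in this reference: the signed int4 range of [-8, 7], round-half-to-even rounding (Python's built-in round()), and saturation of out-of-range values. The function names are hypothetical.

```python
from typing import List, Sequence

# Assumed signed 4-bit range; not taken from the nvtripy documentation.
INT4_MIN, INT4_MAX = -8, 7


def quantize_int4(values: Sequence[float], scale: float = 1.0) -> List[int]:
    """Hypothetical per-tensor quantization: one scale shared by every element."""
    # round() rounds halves to even; results are clamped to the assumed int4 range.
    return [max(INT4_MIN, min(INT4_MAX, round(v / scale))) for v in values]


def dequantize_int4(values: Sequence[int], scale: float = 1.0) -> List[float]:
    """Hypothetical per-tensor dequantization with the same shared scale."""
    return [float(q) * scale for q in values]


q = quantize_int4([0.4, 1.6, -3.5, 10.0])  # scale 1.0, as cast() uses
print(q)                   # [0, 2, -4, 7]
print(dequantize_int4(q))  # [0.0, 2.0, -4.0, 7.0]
```

Under these assumptions, a fixed scale of 1.0 saturates any value outside the int4 range. This is the kind of case where calling quantize() directly, with a scale chosen for the data, gives finer control.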
See also
- copy(device: device) → Tensor
Copies the input tensor to the specified device.
Caution
This function cannot be used in a compiled function or nvtripy.Module because it depends on evaluating its inputs, which is not allowed during compilation.
- Parameters:
- Returns:
[dtype=T1] A new tensor on the specified device.
- Raises:
TripyException – If the input tensor is already on the specified device, as performing copies within the same device is currently not supported.
- Return type:
Tensor
- flatten(start_dim: int = 0, end_dim: int = -1) → Tensor
Flattens the dimensions of the input tensor from start_dim through end_dim (inclusive) into a single dimension.
- Parameters:
input (Tensor) – [dtype=T1] The input tensor to be flattened.
start_dim (int) – The first dimension to flatten (default is 0).
end_dim (int) – The last dimension to flatten (default is -1, which includes the last dimension).
- Returns:
[dtype=T1] A flattened tensor.
- Return type:
Tensor
Example: Flatten All Dimensions
input = tp.iota((1, 2, 1), dtype=tp.float32)
output = tp.flatten(input)
Local Variables:
>>> input
tensor([[[0], [0]]], dtype=float32, loc=gpu:0, shape=(1, 2, 1))
>>> output
tensor([0, 0], dtype=float32, loc=gpu:0, shape=(2,))
Example: Flatten Starting from Dimension 1
input = tp.iota((2, 3, 4), dtype=tp.float32)
output = tp.flatten(input, start_dim=1)
Local Variables:
>>> input
tensor([[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]], dtype=float32, loc=gpu:0, shape=(2, 3, 4))
>>> output
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=float32, loc=gpu:0, shape=(2, 12))
Example: Flatten a Specific Range of Dimensions
input = tp.iota((2, 3, 4, 5), dtype=tp.float32)
output = tp.flatten(input, start_dim=1, end_dim=2)
Local Variables:
>>> input
tensor([[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], [[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]], [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]], [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]]], dtype=float32, loc=gpu:0, shape=(2, 3, 4, 5))
>>> output
tensor([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], ..., [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], ..., [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]], dtype=float32, loc=gpu:0, shape=(2, 12, 5))
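The shape that flatten() produces follows directly from start_dim and end_dim: all dimensions in that inclusive range are multiplied into one. The plain-Python sketch below illustrates this. The helper `flatten_shape` is hypothetical. Negative dimensions are assumed to count from the end, consistent with the documented default end_dim=-1.

```python
from math import prod
from typing import Sequence, Tuple


def flatten_shape(shape: Sequence[int], start_dim: int = 0, end_dim: int = -1) -> Tuple[int, ...]:
    """Hypothetical helper: shape after flattening dims start_dim..end_dim (inclusive)."""
    rank = len(shape)
    start = start_dim + rank if start_dim < 0 else start_dim
    end = end_dim + rank if end_dim < 0 else end_dim
    if not 0 <= start <= end < rank:
        raise ValueError(f"invalid range [{start_dim}, {end_dim}] for rank {rank}")
    merged = prod(shape[start:end + 1])  # product of the flattened dimensions
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


print(flatten_shape((1, 2, 1)))                             # (2,)
print(flatten_shape((2, 3, 4), start_dim=1))                # (2, 12)
print(flatten_shape((2, 3, 4, 5), start_dim=1, end_dim=2))  # (2, 12, 5)
```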
- permute(perm: Sequence[int]) → Tensor
Returns a tensor with its dimensions permuted.
- Parameters:
input (Tensor) – [dtype=T1] The input tensor.
perm (Sequence[int]) – The desired ordering of dimensions. It must contain all integers in \([0..N-1]\) exactly once, where \(N\) is the rank of the input tensor.
- Returns:
[dtype=T1] A new tensor.
- Return type:
Tensor
Example
input = tp.reshape(tp.arange(6, dtype=tp.float32), (2, 3))
output = tp.permute(input, (1, 0))
Local Variables:
>>> input
tensor([[0, 1, 2], [3, 4, 5]], dtype=float32, loc=gpu:0, shape=(2, 3))
>>> output
tensor([[0, 3], [1, 4], [2, 5]], dtype=float32, loc=gpu:0, shape=(3, 2))
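The plain-Python sketch below shows how perm maps dimensions by permuting a nested list. It assumes the usual NumPy/PyTorch convention that output dimension k is input dimension perm[k]. The 2D example above cannot distinguish this convention from its inverse. The helpers are hypothetical illustrations of the semantics; they reproduce the (2, 3) to (3, 2) example.

```python
from typing import Any, Callable, Sequence, Tuple


def permute_shape(shape: Sequence[int], perm: Sequence[int]) -> Tuple[int, ...]:
    """Output dimension k takes the size of input dimension perm[k]."""
    if sorted(perm) != list(range(len(shape))):
        raise ValueError("perm must contain each integer in [0, N-1] exactly once")
    return tuple(shape[p] for p in perm)


def build(shape: Sequence[int], value_at: Callable[[Tuple[int, ...]], Any],
          prefix: Tuple[int, ...] = ()) -> Any:
    """Build a nested list of `shape` by calling value_at(index) for every index."""
    if len(prefix) == len(shape):
        return value_at(prefix)
    return [build(shape, value_at, prefix + (i,)) for i in range(shape[len(prefix)])]


def permute(data: Any, shape: Sequence[int], perm: Sequence[int]) -> Any:
    """Hypothetical nested-list permute: out[idx] = data[in_idx], where in_idx[perm[k]] = idx[k]."""
    def value_at(out_idx: Tuple[int, ...]) -> Any:
        in_idx = [0] * len(perm)
        for k, p in enumerate(perm):
            in_idx[p] = out_idx[k]
        element = data
        for i in in_idx:  # walk down the nested lists
            element = element[i]
        return element

    return build(permute_shape(shape, perm), value_at)


print(permute([[0, 1, 2], [3, 4, 5]], (2, 3), (1, 0)))  # [[0, 3], [1, 4], [2, 5]]
```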
- reshape(shape: Sequence[int | DimensionSize]) → Tensor
Returns a new tensor with the contents of the input tensor in the specified shape.
- Parameters:
input (Tensor) – [dtype=T1] The input tensor.
shape (Sequence[int | DimensionSize]) – The desired compatible shape. If a shape dimension is -1, its value is inferred from the other dimensions and the number of elements in the input. At most one dimension can be -1.
- Returns:
[dtype=T1] A new tensor with the specified shape.
- Return type:
Tensor
Example
input = tp.iota((2, 3), dtype=tp.float32)
output = tp.reshape(input, (1, 6))
Local Variables:
>>> input
tensor([[0, 0, 0], [1, 1, 1]], dtype=float32, loc=gpu:0, shape=(2, 3))
>>> output
tensor([[0, 0, 0, 1, 1, 1]], dtype=float32, loc=gpu:0, shape=(1, 6))
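The rule for a -1 entry in shape can be written as a small plain-Python sketch: the unknown dimension equals the number of elements divided by the product of the known dimensions. `infer_shape` is a hypothetical helper. It also rejects a -1 combined with a 0-sized dimension, because the value would be ambiguous; that check is an assumption, not documented behavior.

```python
from math import prod
from typing import List, Sequence, Tuple


def infer_shape(num_elements: int, shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: resolve a single -1 entry in a reshape target."""
    unknown = [i for i, dim in enumerate(shape) if dim == -1]
    if len(unknown) > 1:
        raise ValueError("at most one dimension can be -1")
    known = prod(dim for dim in shape if dim != -1)  # prod() of nothing is 1
    resolved: List[int] = list(shape)
    if unknown:
        if known == 0 or num_elements % known != 0:
            raise ValueError(f"cannot infer -1 for {num_elements} elements and shape {tuple(shape)}")
        resolved[unknown[0]] = num_elements // known
    elif known != num_elements:
        raise ValueError(f"shape {tuple(shape)} is incompatible with {num_elements} elements")
    return tuple(resolved)


print(infer_shape(6, (1, 6)))   # (1, 6), matching the example above
print(infer_shape(6, (-1, 3)))  # (2, 3)
```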
- property shape: Tuple[int | DimensionSize]
Represents the shape of the tensor.
- Parameters:
self – [dtype=T1] The input tensor.
- Returns:
A sequence containing the shape of this tensor.
Example
input = tp.ones((8, 2))
shape = input.shape
Local Variables:
>>> input
tensor([[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]], dtype=float32, loc=gpu:0, shape=(8, 2))
>>> shape
(
    8,
    2,
)
- squeeze(dims: Sequence[int] | int) → Tensor
Returns a new tensor with the specified singleton dimensions of the input tensor removed.
- Parameters:
input (Tensor) – [dtype=T1] The input tensor.
dims (Sequence[int] | int) – The dimension(s) to remove. These must have a length of 1.
- Returns:
[dtype=T1] A new tensor.
- Return type:
Tensor
Example: Squeeze All Singleton Dimensions
input = tp.iota((1, 2, 1), dtype=tp.float32)
output = tp.squeeze(input, dims=(0, 2))
Local Variables:
>>> input
tensor([[[0], [0]]], dtype=float32, loc=gpu:0, shape=(1, 2, 1))
>>> output
tensor([0, 0], dtype=float32, loc=gpu:0, shape=(2,))
Example: Squeeze First Dimension
input = tp.iota((1, 2, 1), dtype=tp.float32)
output = tp.squeeze(input, 0)
Local Variables:
>>> input
tensor([[[0], [0]]], dtype=float32, loc=gpu:0, shape=(1, 2, 1))
>>> output
tensor([[0], [0]], dtype=float32, loc=gpu:0, shape=(2, 1))
Example: Squeeze First and Third Dimensions
input = tp.iota((1, 2, 1), dtype=tp.float32)
output = tp.squeeze(input, (0, 2))
Local Variables:
>>> input
tensor([[[0], [0]]], dtype=float32, loc=gpu:0, shape=(1, 2, 1))
>>> output
tensor([0, 0], dtype=float32, loc=gpu:0, shape=(2,))
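The plain-Python shape sketch below expresses the requirement that every squeezed dimension have length 1. `squeeze_shape` is a hypothetical helper. As a simplifying assumption, it accepts only non-negative dimension indices.

```python
from typing import Sequence, Tuple, Union


def squeeze_shape(shape: Sequence[int], dims: Union[Sequence[int], int]) -> Tuple[int, ...]:
    """Hypothetical helper: drop the listed dimensions, each of which must be 1."""
    to_remove = {dims} if isinstance(dims, int) else set(dims)
    for dim in to_remove:
        if not 0 <= dim < len(shape):
            raise IndexError(f"dimension {dim} is out of range for rank {len(shape)}")
        if shape[dim] != 1:
            raise ValueError(f"dimension {dim} has length {shape[dim]}, not 1")
    return tuple(size for i, size in enumerate(shape) if i not in to_remove)


print(squeeze_shape((1, 2, 1), (0, 2)))  # (2,)
print(squeeze_shape((1, 2, 1), 0))       # (2, 1)
```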
- tolist() → List | Number
Returns the tensor as a nested list. If the tensor is a scalar, returns a Python number.
- Returns:
The tensor represented as a nested list or a Python number.
- Return type:
List | Number
Example: Ranked tensor
tensor = tp.ones((2, 2))
tensor_list = tensor.tolist()
Local Variables:
>>> tensor_list
[[1.0, 1.0], [1.0, 1.0]]
Example: Scalar
tensor = tp.Tensor(2.0)
tensor_scalar = tensor.tolist()
Local Variables:
>>> tensor_scalar
2.0
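To illustrate the nested-list layout that tolist() returns, the plain-Python sketch below rebuilds nested lists from a flat list of values and a shape, and returns a bare number for a rank-0 shape. The row-major storage order and the helper name `to_nested` are assumptions for illustration only.

```python
from math import prod
from typing import Any, Sequence


def to_nested(flat: Sequence[Any], shape: Sequence[int]) -> Any:
    """Hypothetical helper: nested lists from row-major values; a scalar for shape ()."""
    if len(shape) == 0:
        return flat[0]  # rank-0 tensor -> plain Python number
    if len(shape) == 1:
        return list(flat[: shape[0]])
    step = prod(shape[1:])  # number of elements in each sub-tensor along dim 0
    return [to_nested(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


print(to_nested([1.0, 1.0, 1.0, 1.0], (2, 2)))  # [[1.0, 1.0], [1.0, 1.0]]
print(to_nested([2.0], ()))                     # 2.0
```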
- transpose(dim0: int, dim1: int) → Tensor
Returns a new tensor that is a transposed version of the input tensor, where dim0 and dim1 are swapped.
- Parameters:
input (Tensor) – [dtype=T1] The input tensor.
dim0 (int) – The first dimension to be transposed.
dim1 (int) – The second dimension to be transposed.
- Returns:
[dtype=T1] A new tensor.
- Return type:
Tensor
Example
input = tp.reshape(tp.arange(6, dtype=tp.float32), (2, 3))
output = tp.transpose(input, 0, 1)
Local Variables:
>>> input
tensor([[0, 1, 2], [3, 4, 5]], dtype=float32, loc=gpu:0, shape=(2, 3))
>>> output
tensor([[0, 3], [1, 4], [2, 5]], dtype=float32, loc=gpu:0, shape=(3, 2))
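Swapping two dimensions is the special case of a permutation that exchanges just two entries. The result of transpose() can therefore be described with the same perm notation that permute() uses. The plain-Python sketch below shows that relationship. The helpers are hypothetical, and they accept only non-negative dimensions as a simplification. This is not a claim about how nvtripy implements transpose().

```python
from typing import Sequence, Tuple


def transpose_perm(rank: int, dim0: int, dim1: int) -> Tuple[int, ...]:
    """Hypothetical helper: the permutation equivalent to swapping dim0 and dim1."""
    if not (0 <= dim0 < rank and 0 <= dim1 < rank):
        raise IndexError("dimensions must be in [0, rank - 1]")
    perm = list(range(rank))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return tuple(perm)


def transposed_shape(shape: Sequence[int], dim0: int, dim1: int) -> Tuple[int, ...]:
    """Output dimension k takes the size of input dimension perm[k]."""
    return tuple(shape[p] for p in transpose_perm(len(shape), dim0, dim1))


print(transpose_perm(2, 0, 1))               # (1, 0), the perm used in the permute() example
print(transposed_shape((2, 3, 4, 5), 1, 3))  # (2, 5, 4, 3)
```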
- unsqueeze(dim: int) → Tensor
Returns a new tensor with the contents of the input tensor with a singleton dimension inserted before the specified axis.
- Parameters:
input (Tensor) – [dtype=T1] The input tensor.
dim (int) – The index before which to insert the singleton dimension. A negative dimension is converted to dim = dim + input.rank + 1.
- Returns:
[dtype=T1] A new tensor.
- Return type:
Tensor
Example
input = tp.iota((2, 2), dtype=tp.float32)
output = tp.unsqueeze(input, 1)
Local Variables:
>>> input
tensor([[0, 0], [1, 1]], dtype=float32, loc=gpu:0, shape=(2, 2))
>>> output
tensor([[[0, 0]], [[1, 1]]], dtype=float32, loc=gpu:0, shape=(2, 1, 2))
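The conversion rule for a negative dim (dim = dim + input.rank + 1) can be checked with a plain-Python shape sketch. `unsqueeze_shape` is a hypothetical helper that only computes shapes. The valid range [0, rank] after conversion is an assumption that follows from inserting a dimension before an index of a rank-N shape.

```python
from typing import Sequence, Tuple


def unsqueeze_shape(shape: Sequence[int], dim: int) -> Tuple[int, ...]:
    """Hypothetical helper: insert a size-1 dimension before index `dim`."""
    rank = len(shape)
    if dim < 0:
        dim = dim + rank + 1  # conversion rule from the parameter description
    if not 0 <= dim <= rank:
        raise IndexError(f"dim out of range for rank {rank}")
    return tuple(shape[:dim]) + (1,) + tuple(shape[dim:])


print(unsqueeze_shape((2, 2), 1))   # (2, 1, 2), matching the example above
print(unsqueeze_shape((2, 2), -1))  # (2, 2, 1)
```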
See also: