Tensor
- class nvtripy.Tensor(data: Any, dtype: dtype | None = None, device: device | None = None, name: str | None = None)
- Bases: object
- A tensor is a multi-dimensional array that contains elements of a uniform data type.
- Parameters:
- data (Any) – The data with which to initialize the tensor. For types that support the DLPack protocol, copying data is avoided if possible. 
- dtype (dtype | None) – The data type of the tensor. This parameter is only allowed when data is an empty list. 
- device (device | None) – The device on which to allocate the tensor. This parameter is only allowed when data is a Python number or list. 
- name (str | None) – The name of the tensor. If provided, this must be a unique string. 
 
- Example:
    import nvtripy as tp

    tensor = tp.Tensor([1.0, 2.0, 3.0])
- Local Variables:
    >>> tensor
    tensor([1, 2, 3], dtype=float32, loc=cpu:0, shape=(3,))

 - eval() → Tensor
- Immediately evaluates this tensor. By default, tensors are evaluated lazily, so their values are not computed until they are needed.
- Returns:
- The evaluated tensor.
- Return type:
- Tensor
- Example:
    import time

    start = time.perf_counter()
    tensor = tp.ones((3, 3))
    init_time = time.perf_counter()
    tensor.eval()
    eval_time = time.perf_counter()

    print(f"Tensor init_time took: {(init_time - start) * 1000.0:.3f} ms")
    print(f"Tensor evaluation took: {(eval_time - init_time) * 1000.0:.3f} ms")
- Local Variables:
    >>> tensor
    tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]], dtype=float32, loc=gpu:0, shape=(3, 3))
- Output:
    Tensor init_time took: 1.564 ms
    Tensor evaluation took: 68.178 ms
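The timings above show that constructing the tensor is cheap and that most of the cost is paid when `eval()` runs. The plain-Python sketch below illustrates the general idea of lazy evaluation with caching. `LazyValue` is a hypothetical class, not part of nvtripy, and the sketch makes no claim about how nvtripy builds, compiles, or caches its computations.

```python
from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")


class LazyValue(Generic[T]):
    """Hypothetical stand-in for a lazily evaluated value; not nvtripy's implementation."""

    def __init__(self, compute: Callable[[], T]) -> None:
        self._compute = compute  # the deferred work: nothing runs at construction time
        self._result: Optional[T] = None
        self._evaluated = False

    def eval(self) -> "LazyValue[T]":
        # Run the deferred computation at most once and cache its result.
        if not self._evaluated:
            self._result = self._compute()
            self._evaluated = True
        return self

    def value(self) -> T:
        # Reading the value forces evaluation if eval() has not been called yet.
        self.eval()
        return self._result  # type: ignore[return-value]


def build_ones() -> list:
    print("computing...")
    return [[1.0] * 3 for _ in range(3)]


lazy = LazyValue(build_ones)  # returns immediately; prints nothing
lazy.eval()                   # prints "computing..."
lazy.eval()                   # cached: prints nothing
```

As with `tensor.eval()` in the example, calling `eval()` returns the same object, now holding a computed result.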
 - __abs__() → Tensor
- Computes the elementwise absolute value of the input tensor.
- Parameters:
- self (Tensor) – [dtype=T1] The input tensor.
- Returns:
- [dtype=T1] A new tensor of the same shape.
- Return type:
- Tensor
- Example:
    input = tp.Tensor([-1, -2])
    output = abs(input)
- Local Variables:
    >>> input
    tensor([-1, -2], dtype=int32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([1, 2], dtype=int32, loc=gpu:0, shape=(2,))
 - __add__(other: Tensor | Number) → Tensor
- Performs an elementwise sum.
- Parameters:
- self (Tensor) – The left-hand operand.
- other (Tensor | Number) – The right-hand operand.
- Returns:
- [dtype=T1] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([1, 2])
    b = tp.Tensor([2, 3])
    output = a + b
- Local Variables:
    >>> a
    tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))
    >>> b
    tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([3, 5], dtype=int32, loc=gpu:0, shape=(2,))
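Every binary operator in this section returns "a new tensor with the broadcasted shape". This page does not spell out the broadcasting rules, so the sketch below assumes NumPy-style broadcasting (shapes aligned from the right, size-1 or missing dimensions stretched), and also assumes a Python number behaves like a rank-0 shape `()`. `broadcast_shape` is a hypothetical helper, not an nvtripy API.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Result shape of an elementwise op under (assumed) NumPy-style broadcasting."""
    result = []
    # Align shapes from the trailing dimension; missing leading dims act as size 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} are not broadcastable")
    return tuple(reversed(result))


print(broadcast_shape((2,), (2,)))      # (2,), as in the a + b example above
print(broadcast_shape((3, 1), (1, 4)))  # (3, 4)
```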
 - __eq__(other: Tensor | Number) → Tensor
- Performs an elementwise ‘equal’ comparison.
- Parameters:
- self (Tensor) – The left-hand operand.
- other (Tensor | Number) – The right-hand operand.
- Returns:
- [dtype=T2] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([2, 3])
    b = tp.Tensor([2, 5])
    output = b == a
- Local Variables:
    >>> a
    tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
    >>> b
    tensor([2, 5], dtype=int32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
 - __floordiv__(other: Tensor | Number) → Tensor
- Performs an elementwise floor division.
- Parameters:
- self (Tensor) – The left-hand operand (dividend).
- other (Tensor | Number) – The right-hand operand (divisor).
- Returns:
- [dtype=T1] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([4.0, 6.0])
    b = tp.Tensor([3.0, 4.0])
    output = a // b
- Local Variables:
    >>> a
    tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))
    >>> b
    tensor([3, 4], dtype=float32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([1, 1], dtype=float32, loc=gpu:0, shape=(2,))
 - __ge__(other: Tensor | Number) → Tensor
- Performs an elementwise ‘greater than or equal’ comparison.
- Parameters:
- self (Tensor) – The left-hand operand.
- other (Tensor | Number) – The right-hand operand.
- Returns:
- [dtype=T2] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([2, 3])
    b = tp.Tensor([2, 1])
    output = b >= a
- Local Variables:
    >>> a
    tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
    >>> b
    tensor([2, 1], dtype=int32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
 - __getitem__(index: Tensor | slice | int | DimensionSize | ellipsis | None | Sequence[slice | int | DimensionSize | ellipsis | None]) → Tensor
- Returns a tensor containing a slice of this tensor.
- Parameters:
- self (Tensor) – [dtype=T1] Tensor that will be sliced.
- index (Tensor | slice | int | DimensionSize | ellipsis | None | Sequence[slice | int | DimensionSize | ellipsis | None]) – The index or slice. If this is a Tensor, the operation is equivalent to calling gather() along the first dimension. If this is None, a new dimension of size 1 is inserted at that position.
- Returns:
- [dtype=T1] A tensor containing the slice of this tensor.
- Return type:
- Tensor
- Example: Indexing With Integers
    input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
    output = input[1]
- Local Variables:
    >>> input
    tensor([[0, 1], [2, 3], [4, 5]], dtype=float32, loc=gpu:0, shape=(3, 2))
    >>> output
    tensor([2, 3], dtype=float32, loc=gpu:0, shape=(2,))
- Example: Indexing With Slices
    input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
    output = input[1:]
- Local Variables:
    >>> input
    tensor([[0, 1], [2, 3], [4, 5]], dtype=float32, loc=gpu:0, shape=(3, 2))
    >>> output
    tensor([[2, 3], [4, 5]], dtype=float32, loc=gpu:0, shape=(2, 2))
- Example: Indexing With Ellipsis
    input = tp.reshape(tp.arange(6, dtype=tp.float32), (1, 3, 2))
    output = input[..., 1:]
- Local Variables:
    >>> input
    tensor([[[0, 1], [2, 3], [4, 5]]], dtype=float32, loc=gpu:0, shape=(1, 3, 2))
    >>> output
    tensor([[[1], [3], [5]]], dtype=float32, loc=gpu:0, shape=(1, 3, 1))
- Example: Reversing Data With Negative Step
    input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
    output = input[:, ::-1]
- Local Variables:
    >>> input
    tensor([[0, 1], [2, 3], [4, 5]], dtype=float32, loc=gpu:0, shape=(3, 2))
    >>> output
    tensor([[1, 0], [3, 2], [5, 4]], dtype=float32, loc=gpu:0, shape=(3, 2))
- Example: Indexing With Tensors (Gather)
    input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
    index = tp.Tensor([2, 0])
    output = input[index]
- Local Variables:
    >>> input
    tensor([[0, 1], [2, 3], [4, 5]], dtype=float32, loc=gpu:0, shape=(3, 2))
    >>> index
    tensor([2, 0], dtype=int32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([[4, 5], [0, 1]], dtype=float32, loc=gpu:0, shape=(2, 2))
- Example: Adding New Dimensions With None
    input = tp.reshape(tp.arange(6, dtype=tp.float32), (3, 2))
    output = input[None, :, None]
- Local Variables:
    >>> input
    tensor([[0, 1], [2, 3], [4, 5]], dtype=float32, loc=gpu:0, shape=(3, 2))
    >>> output
    tensor([[[[0, 1]], [[2, 3]], [[4, 5]]]], dtype=float32, loc=gpu:0, shape=(1, 3, 1, 2))
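The result shapes in the examples above are consistent with NumPy-style basic indexing. As a sketch of those shape rules only (not of nvtripy's implementation), the hypothetical helper `indexed_shape` computes the output shape for integer, slice, `Ellipsis`, and `None` indices; tensor (gather) indices are deliberately left out. `IX` is a hypothetical convenience object that lets the index be written with ordinary `[...]` syntax.

```python
from typing import Any, Sequence, Tuple


class _IndexCapture:
    """Returns whatever Python passes to __getitem__, so IX[..., 1:] yields a plain index tuple."""

    def __getitem__(self, index: Any) -> Any:
        return index


IX = _IndexCapture()


def indexed_shape(shape: Sequence[int], index: Any) -> Tuple[int, ...]:
    """Result shape for int, slice, Ellipsis, and None indices (tensor/gather indices not handled)."""
    items = index if isinstance(index, tuple) else (index,)
    # Expand a single Ellipsis into as many full slices as needed to cover the rank.
    consumed = sum(1 for it in items if it is not None and it is not Ellipsis)
    if any(it is Ellipsis for it in items):
        pos = next(i for i, it in enumerate(items) if it is Ellipsis)
        items = items[:pos] + (slice(None),) * (len(shape) - consumed) + items[pos + 1:]
    out, dim = [], 0
    for it in items:
        if it is None:
            out.append(1)  # None inserts a new dimension of size 1
        elif isinstance(it, int):
            dim += 1  # an integer index removes the dimension it selects from
        else:
            # slice.indices() resolves open bounds and negative steps for this length.
            out.append(len(range(*it.indices(shape[dim]))))
            dim += 1
    out.extend(shape[dim:])  # dimensions that were not indexed are kept as-is
    return tuple(out)


print(indexed_shape((1, 3, 2), IX[..., 1:]))     # (1, 3, 1)
print(indexed_shape((3, 2), IX[None, :, None]))  # (1, 3, 1, 2)
```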
 - __gt__(other: Tensor | Number) → Tensor
- Performs an elementwise ‘greater than’ comparison.
- Parameters:
- self (Tensor) – The left-hand operand.
- other (Tensor | Number) – The right-hand operand.
- Returns:
- [dtype=T2] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([2, 3])
    b = tp.Tensor([3, 1])
    output = b > a
- Local Variables:
    >>> a
    tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
    >>> b
    tensor([3, 1], dtype=int32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
 - __invert__() → Tensor
- Performs an elementwise logical NOT.
- Parameters:
- self (Tensor) – [dtype=T1] The input tensor.
- Returns:
- [dtype=T1] A new tensor.
- Return type:
- Tensor
- DATA TYPE CONSTRAINTS:
- T1: bool
- Example:
    a = tp.Tensor([True, False, False])
    output = ~a
- Local Variables:
    >>> a
    tensor([True, False, False], dtype=bool, loc=cpu:0, shape=(3,))
    >>> output
    tensor([False, True, True], dtype=bool, loc=gpu:0, shape=(3,))
 - __le__(other: Tensor | Number) → Tensor
- Performs an elementwise ‘less than or equal’ comparison.
- Parameters:
- self (Tensor) – The left-hand operand.
- other (Tensor | Number) – The right-hand operand.
- Returns:
- [dtype=T2] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([2, 3])
    b = tp.Tensor([2, 5])
    output = b <= a
- Local Variables:
    >>> a
    tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
    >>> b
    tensor([2, 5], dtype=int32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
 - __lt__(other: Tensor | Number) → Tensor
- Performs an elementwise ‘less than’ comparison.
- Parameters:
- self (Tensor) – The left-hand operand.
- other (Tensor | Number) – The right-hand operand.
- Returns:
- [dtype=T2] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([2, 3])
    b = tp.Tensor([1, 5])
    output = b < a
- Local Variables:
    >>> a
    tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
    >>> b
    tensor([1, 5], dtype=int32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
 - __matmul__(other: Tensor) → Tensor
- Performs matrix multiplication between two tensors.
- If both tensors are 1D, a dot product is performed and the output is a scalar.
- If exactly one of the tensors is 1D, matrix-vector multiplication is performed:
  - For inputs of shape \((M, N)\) and \((N,)\), the output will have shape \((M,)\).
  - For inputs of shape \((N,)\) and \((N, K)\), the output will have shape \((K,)\).
- If both tensors are 2D, matrix-matrix multiplication is performed:
  - For inputs of shape \((M, N)\) and \((N, K)\), the output will have shape \((M, K)\).
- If a tensor has more than 2 dimensions, it is treated as a stack of matrices:
  - If the ranks of two tensors with 2 or more dimensions differ, size-1 dimensions are prepended to the lower-rank tensor until the ranks match. All dimensions except the last two (the batch dimensions) are then broadcast if required.
- Parameters:
- self (Tensor) – The left-hand operand.
- other (Tensor) – The right-hand operand.
- Returns:
- [dtype=T1] A new tensor.
- Return type:
- Tensor
- Example: Dot Product
    a = tp.iota((3,), dtype=tp.float32)
    b = tp.iota((3,), dtype=tp.float32)

    output = a @ b
- Local Variables:
    >>> a
    tensor([0, 1, 2], dtype=float32, loc=gpu:0, shape=(3,))
    >>> b
    tensor([0, 1, 2], dtype=float32, loc=gpu:0, shape=(3,))
    >>> output
    tensor(5, dtype=float32, loc=gpu:0, shape=())
- Example: Matrix-Vector Multiplication
    a = tp.iota((3,), dtype=tp.float32)
    b = tp.iota((3, 2), dtype=tp.float32)

    output = a @ b
- Local Variables:
    >>> a
    tensor([0, 1, 2], dtype=float32, loc=gpu:0, shape=(3,))
    >>> b
    tensor([[0, 0], [1, 1], [2, 2]], dtype=float32, loc=gpu:0, shape=(3, 2))
    >>> output
    tensor([5, 5], dtype=float32, loc=gpu:0, shape=(2,))
- Example: Matrix-Matrix Multiplication
    a = tp.iota((2, 3), dtype=tp.float32)
    b = tp.iota((3, 2), dtype=tp.float32)

    output = a @ b
- Local Variables:
    >>> a
    tensor([[0, 0, 0], [1, 1, 1]], dtype=float32, loc=gpu:0, shape=(2, 3))
    >>> b
    tensor([[0, 0], [1, 1], [2, 2]], dtype=float32, loc=gpu:0, shape=(3, 2))
    >>> output
    tensor([[0, 0], [3, 3]], dtype=float32, loc=gpu:0, shape=(2, 2))
- Example: Batched Matrix Multiplication
    a = tp.iota((1, 2, 2, 2), dtype=tp.float32, dim=-1)
    b = tp.iota((1, 2, 2), dtype=tp.float32, dim=-2)

    output = a @ b
- Local Variables:
    >>> a
    tensor([[[[0, 1], [0, 1]], [[0, 1], [0, 1]]]], dtype=float32, loc=gpu:0, shape=(1, 2, 2, 2))
    >>> b
    tensor([[[0, 0], [1, 1]]], dtype=float32, loc=gpu:0, shape=(1, 2, 2))
    >>> output
    tensor([[[[1, 1], [1, 1]], [[1, 1], [1, 1]]]], dtype=float32, loc=gpu:0, shape=(1, 2, 2, 2))
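The shape rules listed for `__matmul__` can be condensed into a small pure-Python function. `matmul_shape` and its helpers are hypothetical and only do shape bookkeeping; they do not call nvtripy. The case of a 1D operand combined with an operand of rank greater than 2 is not described on this page, so the sketch rejects it rather than guess.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def _check_inner(n_left: int, n_right: int) -> None:
    if n_left != n_right:
        raise ValueError(f"contracted dimensions differ: {n_left} vs {n_right}")


def _broadcast(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    # Align from the right; a missing dimension behaves like size 1.
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"batch dimensions {tuple(a)} and {tuple(b)} do not broadcast")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def matmul_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Output shape of a @ b under the rules listed above (shape bookkeeping only)."""
    a, b = tuple(a), tuple(b)
    if len(a) == 1 and len(b) == 1:  # dot product -> scalar
        _check_inner(a[0], b[0])
        return ()
    if len(a) == 1 and len(b) == 2:  # (N,) @ (N, K) -> (K,)
        _check_inner(a[0], b[0])
        return (b[1],)
    if len(a) == 2 and len(b) == 1:  # (M, N) @ (N,) -> (M,)
        _check_inner(a[1], b[0])
        return (a[0],)
    if len(a) < 2 or len(b) < 2:
        raise NotImplementedError("a 1D operand with an operand of rank > 2 is not covered here")
    # Prepend size-1 dimensions until the ranks match, then broadcast the batch dimensions.
    rank = max(len(a), len(b))
    a = (1,) * (rank - len(a)) + a
    b = (1,) * (rank - len(b)) + b
    _check_inner(a[-1], b[-2])
    return _broadcast(a[:-2], b[:-2]) + (a[-2], b[-1])


print(matmul_shape((3,), (3, 2)))             # (2,)
print(matmul_shape((1, 2, 2, 2), (1, 2, 2)))  # (1, 2, 2, 2)
```

Both printed shapes match the Matrix-Vector and Batched examples above.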
 - __mod__(other: Tensor | Number) → Tensor
- Performs an elementwise modulo operation, which computes the remainder of a division.
- Parameters:
- self (Tensor) – The left-hand operand (dividend).
- other (Tensor | Number) – The right-hand operand (divisor).
- Returns:
- [dtype=T1] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([4.0, 6.0])
    b = tp.Tensor([3.0, 4.0])
    output = a % b
- Local Variables:
    >>> a
    tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))
    >>> b
    tensor([3, 4], dtype=float32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([1, 2], dtype=float32, loc=gpu:0, shape=(2,))
 - __mul__(other: Tensor | Number) → Tensor
- Performs an elementwise multiplication.
- Parameters:
- self (Tensor) – The left-hand operand.
- other (Tensor | Number) – The right-hand operand.
- Returns:
- [dtype=T1] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([1.0, 2.0])
    b = tp.Tensor([2.0, 3.0])
    output = a * b
- Local Variables:
    >>> a
    tensor([1, 2], dtype=float32, loc=cpu:0, shape=(2,))
    >>> b
    tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([2, 6], dtype=float32, loc=gpu:0, shape=(2,))
 - __ne__(other: Tensor | Number) → Tensor
- Performs an elementwise ‘not equal’ comparison.
- Parameters:
- self (Tensor) – The left-hand operand.
- other (Tensor | Number) – The right-hand operand.
- Returns:
- [dtype=T2] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([2, 3])
    b = tp.Tensor([1, 3])
    output = b != a
- Local Variables:
    >>> a
    tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
    >>> b
    tensor([1, 3], dtype=int32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([True, False], dtype=bool, loc=gpu:0, shape=(2,))
 - __neg__() → Tensor
- Computes the elementwise negation of the input tensor.
- Parameters:
- self (Tensor) – [dtype=T1] The input tensor.
- Returns:
- [dtype=T1] A new tensor of the same shape.
- Return type:
- Tensor
- Example:
    input = tp.Tensor([-1, -2])
    output = -input
- Local Variables:
    >>> input
    tensor([-1, -2], dtype=int32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([1, 2], dtype=int32, loc=gpu:0, shape=(2,))
 - __or__(other: Tensor) → Tensor
- Performs an elementwise logical OR.
- Parameters:
- self (Tensor) – The left-hand operand.
- other (Tensor) – The right-hand operand.
- Returns:
- [dtype=T1] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- DATA TYPE CONSTRAINTS:
- T1: bool
- Example:
    a = tp.Tensor([True, False, False])
    b = tp.Tensor([False, True, False])
    output = a | b
- Local Variables:
    >>> a
    tensor([True, False, False], dtype=bool, loc=cpu:0, shape=(3,))
    >>> b
    tensor([False, True, False], dtype=bool, loc=cpu:0, shape=(3,))
    >>> output
    tensor([True, True, False], dtype=bool, loc=gpu:0, shape=(3,))
 - __pow__(other: Tensor | Number) → Tensor
- Performs an elementwise exponentiation.
- Parameters:
- self (Tensor) – The base.
- other (Tensor | Number) – The exponent.
- Returns:
- [dtype=T1] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([1.0, 2.0])
    b = tp.Tensor([2.0, 3.0])
    output = a**b
- Local Variables:
    >>> a
    tensor([1, 2], dtype=float32, loc=cpu:0, shape=(2,))
    >>> b
    tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([1, 8], dtype=float32, loc=gpu:0, shape=(2,))
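 - __radd__(other: Tensor | Number) → Tensor
- Performs an elementwise sum.
- Parameters:
- self (Tensor) – The right-hand operand.
- other (Tensor | Number) – The left-hand operand.
- Returns:
- [dtype=T1] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([1, 2])
    b = tp.Tensor([2, 3])
    output = a + b
- Local Variables:
    >>> a
    tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))
    >>> b
    tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([3, 5], dtype=int32, loc=gpu:0, shape=(2,))
- Note: Both operands in this example are tensors, so Python evaluates a + b with __add__. __radd__ is used when the left operand is a Python number, as in 1 + b.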
 - __rfloordiv__(other: Tensor | Number) → Tensor
- Performs an elementwise floor division.
- Parameters:
- self (Tensor) – The right-hand operand (divisor).
- other (Tensor | Number) – The left-hand operand (dividend).
- Returns:
- [dtype=T1] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = 2
    b = tp.Tensor([2.0, 3.0])
    output = a // b
- Local Variables:
    >>> b
    tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([1, 0], dtype=float32, loc=gpu:0, shape=(2,))
 - __rmod__(other: Tensor | Number) → Tensor
- Performs an elementwise modulo operation, which computes the remainder of a division.
- Parameters:
- self (Tensor) – The right-hand operand (divisor).
- other (Tensor | Number) – The left-hand operand (dividend).
- Returns:
- [dtype=T1] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([4.0, 6.0])
    output = 2 % a
- Local Variables:
    >>> a
    tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([2, 2], dtype=float32, loc=gpu:0, shape=(2,))
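 - __rmul__(other: Tensor | Number) → Tensor
- Performs an elementwise multiplication.
- Parameters:
- self (Tensor) – The right-hand operand.
- other (Tensor | Number) – The left-hand operand.
- Returns:
- [dtype=T1] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([1.0, 2.0])
    b = tp.Tensor([2.0, 3.0])
    output = a * b
- Local Variables:
    >>> a
    tensor([1, 2], dtype=float32, loc=cpu:0, shape=(2,))
    >>> b
    tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([2, 6], dtype=float32, loc=gpu:0, shape=(2,))
- Note: Both operands in this example are tensors, so Python evaluates a * b with __mul__. __rmul__ is used when the left operand is a Python number, as in 2 * b.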
 - __rpow__(other: Tensor | Number) → Tensor
- Performs an elementwise exponentiation.
- Parameters:
- self (Tensor) – The exponent.
- other (Tensor | Number) – The base.
- Returns:
- [dtype=T1] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = 2.0
    b = tp.Tensor([2.0, 3.0])
    output = a**b
- Local Variables:
    >>> b
    tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([4, 8], dtype=float32, loc=gpu:0, shape=(2,))
 - __rsub__(other: Tensor | Number) → Tensor
- Performs an elementwise subtraction.
- Parameters:
- self (Tensor) – The right-hand operand (subtrahend).
- other (Tensor | Number) – The left-hand operand (minuend).
- Returns:
- [dtype=T1] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = 1
    b = tp.Tensor([1, 2])
    output = a - b
- Local Variables:
    >>> b
    tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([0, -1], dtype=int32, loc=gpu:0, shape=(2,))
 - __rtruediv__(other: Tensor | Number) → Tensor
- Performs an elementwise division.
- Parameters:
- self (Tensor) – The right-hand operand (divisor).
- other (Tensor | Number) – The left-hand operand (dividend).
- Returns:
- [dtype=T1] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = 6.0
    b = tp.Tensor([2.0, 3.0])
    output = a / b
- Local Variables:
    >>> b
    tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([3, 2], dtype=float32, loc=gpu:0, shape=(2,))
 - __sub__(other: Tensor | Number) → Tensor
- Performs an elementwise subtraction.
- Parameters:
- self (Tensor) – The left-hand operand (minuend).
- other (Tensor | Number) – The right-hand operand (subtrahend).
- Returns:
- [dtype=T1] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([2, 3])
    b = tp.Tensor([1, 2])
    output = a - b
- Local Variables:
    >>> a
    tensor([2, 3], dtype=int32, loc=cpu:0, shape=(2,))
    >>> b
    tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([1, 1], dtype=int32, loc=gpu:0, shape=(2,))
 - __truediv__(other: Tensor | Number) → Tensor
- Performs an elementwise division.
- Parameters:
- self (Tensor) – The left-hand operand (dividend).
- other (Tensor | Number) – The right-hand operand (divisor).
- Returns:
- [dtype=T1] A new tensor with the broadcasted shape.
- Return type:
- Tensor
- Example:
    a = tp.Tensor([4.0, 6.0])
    b = tp.Tensor([2.0, 3.0])
    output = a / b
- Local Variables:
    >>> a
    tensor([4, 6], dtype=float32, loc=cpu:0, shape=(2,))
    >>> b
    tensor([2, 3], dtype=float32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([2, 2], dtype=float32, loc=gpu:0, shape=(2,))
 - cast(dtype: dtype) → Tensor
- Returns a tensor with the contents of the input tensor cast to the specified data type.
- For casts into quantized data types (int4 and float8), this performs a per-tensor quantization into that data type with a scale of 1.0; for casts from those data types, it performs a per-tensor dequantization with a scale of 1.0. Use quantize() and dequantize() directly for finer control over these parameters.
- Parameters:
- dtype (dtype) – The data type to cast to.
- Returns:
- [dtype=T2] A tensor containing the cast values.
- Return type:
- Tensor
- DATA TYPE CONSTRAINTS:
- UNSUPPORTED DATA TYPE COMBINATIONS:
- Example:
    input = tp.Tensor([1, 2])
    output = tp.cast(input, tp.float32)
- Local Variables:
    >>> input
    tensor([1, 2], dtype=int32, loc=cpu:0, shape=(2,))
    >>> output
    tensor([1, 2], dtype=float32, loc=gpu:0, shape=(2,))
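With a scale of 1.0, per-tensor quantization in its usual formulation, q = clamp(round(x / scale)), reduces to rounding each element and saturating it to the target range. The plain-Python sketch below illustrates that arithmetic for int4 (range -8 to 7). The rounding mode (Python's `round`, which rounds halves to even) and the saturation behavior are assumptions for illustration; the backend's actual rounding and overflow handling may differ. `quantize_int4` and `dequantize_int4` are hypothetical helpers, not nvtripy APIs, and float8 is not covered.

```python
from typing import List, Sequence

INT4_MIN, INT4_MAX = -8, 7  # signed 4-bit range


def quantize_int4(values: Sequence[float], scale: float = 1.0) -> List[int]:
    """Per-tensor quantization: one shared scale for every element (sketch)."""
    # round() rounds halves to even; min/max saturate out-of-range values (assumed behavior).
    return [max(INT4_MIN, min(INT4_MAX, round(v / scale))) for v in values]


def dequantize_int4(values: Sequence[int], scale: float = 1.0) -> List[float]:
    """Inverse mapping: multiply each quantized value by the shared scale."""
    return [q * scale for q in values]


q = quantize_int4([1.2, -9.0, 7.6, 2.5])
print(q)                   # [1, -8, 7, 2]
print(dequantize_int4(q))  # [1.0, -8.0, 7.0, 2.0]
```

The round trip shows why a fixed scale of 1.0 is lossy for values outside the int4 range or between integers, which is where the finer control offered by quantize() and dequantize() matters.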
 - copy(device: device) → Tensor
- Copies the input tensor to the specified device.
- Caution: This function cannot be used in a compiled function or nvtripy.Module because it depends on evaluating its inputs, which is not allowed during compilation.
- Parameters:
- device (device) – The device to copy the tensor to.
- Returns:
- [dtype=T1] A new tensor on the specified device.
- Raises:
- TripyException – If the input tensor is already on the specified device, since copies within the same device are not currently supported.
- Return type:
- Tensor
 
 - flatten(start_dim: int = 0, end_dim: int = -1) → Tensor
- Flattens the dimensions of the input tensor from start_dim through end_dim (inclusive) into a single dimension.
- Parameters:
- input (Tensor) – [dtype=T1] The input tensor to be flattened.
- start_dim (int) – The first dimension to flatten (default is 0).
- end_dim (int) – The last dimension to flatten (default is -1, which includes the last dimension).
- Returns:
- [dtype=T1] A flattened tensor.
- Return type:
- Tensor
- Example: Flatten All Dimensions
    input = tp.iota((1, 2, 1), dtype=tp.float32)
    output = tp.flatten(input)
- Local Variables:
    >>> input
    tensor([[[0], [0]]], dtype=float32, loc=gpu:0, shape=(1, 2, 1))
    >>> output
    tensor([0, 0], dtype=float32, loc=gpu:0, shape=(2,))
- Example: Flatten Starting from Dimension 1
    input = tp.iota((2, 3, 4), dtype=tp.float32)
    output = tp.flatten(input, start_dim=1)
- Local Variables:
    >>> input
    tensor([[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]], dtype=float32, loc=gpu:0, shape=(2, 3, 4))
    >>> output
    tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=float32, loc=gpu:0, shape=(2, 12))
- Example: Flatten a Specific Range of Dimensions
    input = tp.iota((2, 3, 4, 5), dtype=tp.float32)
    output = tp.flatten(input, start_dim=1, end_dim=2)
- Local Variables:
    >>> input
    tensor([[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], [[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]], [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]], [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]]], dtype=float32, loc=gpu:0, shape=(2, 3, 4, 5))
    >>> output
    tensor([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], ..., [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], ..., [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]], dtype=float32, loc=gpu:0, shape=(2, 12, 5))
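The output shape of `flatten` follows directly from its description: the dimensions from `start_dim` through `end_dim` (inclusive) collapse into one whose size is their product. `flattened_shape` below is a hypothetical pure-Python helper sketching that rule; treating a negative `start_dim` the same way as a negative `end_dim` is an assumption, since this page only documents the `end_dim=-1` default.

```python
from math import prod
from typing import Sequence, Tuple


def flattened_shape(shape: Sequence[int], start_dim: int = 0, end_dim: int = -1) -> Tuple[int, ...]:
    """Shape after merging dimensions start_dim..end_dim (inclusive) into one."""
    rank = len(shape)
    # Negative indices count from the end, so end_dim=-1 names the last dimension.
    start = start_dim + rank if start_dim < 0 else start_dim
    end = end_dim + rank if end_dim < 0 else end_dim
    if not 0 <= start <= end < rank:
        raise ValueError(f"invalid dimension range [{start_dim}, {end_dim}] for rank {rank}")
    merged = prod(shape[start:end + 1])  # the merged size is the product of the collapsed sizes
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


print(flattened_shape((2, 3, 4, 5), start_dim=1, end_dim=2))  # (2, 12, 5)
```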
 - permute(perm: Sequence[int]) → Tensor
- Returns a tensor with its dimensions permuted.
- Parameters:
- input (Tensor) – [dtype=T1] The input tensor.
- perm (Sequence[int]) – The desired ordering of dimensions. It must contain all integers in \([0..N-1]\) exactly once, where \(N\) is the rank of the input tensor.
- Returns:
- [dtype=T1] A new tensor.
- Return type:
- Tensor
- Example:
    input = tp.reshape(tp.arange(6, dtype=tp.float32), (2, 3))
    output = tp.permute(input, (1, 0))
- Local Variables:
    >>> input
    tensor([[0, 1, 2], [3, 4, 5]], dtype=float32, loc=gpu:0, shape=(2, 3))
    >>> output
    tensor([[0, 3], [1, 4], [2, 5]], dtype=float32, loc=gpu:0, shape=(3, 2))
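A permutation is pure index bookkeeping: output dimension k takes its size from input dimension perm[k], and output element (j0, ..., jN-1) is read from the input position whose entry perm[k] equals jk. The hypothetical helpers below sketch both mappings in plain Python; they are illustrations, not nvtripy APIs. The last call checks one entry of the example above: output[2][1] and input[1][2] are both 5.

```python
from typing import Sequence, Tuple


def permuted_shape(shape: Sequence[int], perm: Sequence[int]) -> Tuple[int, ...]:
    # perm must contain each of 0..N-1 exactly once.
    if sorted(perm) != list(range(len(shape))):
        raise ValueError(f"{tuple(perm)} is not a permutation of 0..{len(shape) - 1}")
    # Output dimension k takes its size from input dimension perm[k].
    return tuple(shape[p] for p in perm)


def source_index(out_index: Sequence[int], perm: Sequence[int]) -> Tuple[int, ...]:
    # output[out_index] == input[src], where src[perm[k]] == out_index[k].
    src = [0] * len(perm)
    for k, p in enumerate(perm):
        src[p] = out_index[k]
    return tuple(src)


print(permuted_shape((2, 3), (1, 0)))  # (3, 2)
print(source_index((2, 1), (1, 0)))    # (1, 2): output[2][1] reads input[1][2]
```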
 - reshape(shape: Sequence[int | DimensionSize]) → Tensor
- Returns a new tensor with the contents of the input tensor in the specified shape.
- Parameters:
- input (Tensor) – [dtype=T1] The input tensor.
- shape (Sequence[int | DimensionSize]) – The desired shape, which must contain the same number of elements as the input. If a dimension is -1, its value is inferred from the other dimensions and the number of elements in the input. At most one dimension can be -1.
- Returns:
- [dtype=T1] A new tensor with the specified shape.
- Return type:
- Tensor
- Example:
    input = tp.iota((2, 3), dtype=tp.float32)
    output = tp.reshape(input, (1, 6))
- Local Variables:
    >>> input
    tensor([[0, 0, 0], [1, 1, 1]], dtype=float32, loc=gpu:0, shape=(2, 3))
    >>> output
    tensor([[0, 0, 0, 1, 1, 1]], dtype=float32, loc=gpu:0, shape=(1, 6))
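Inferring a `-1` entry is simple arithmetic: its size is the input's element count divided by the product of the other requested dimensions, and that division must be exact. `resolve_reshape` is a hypothetical pure-Python helper sketching this check; it handles only plain integers, not `DimensionSize` values.

```python
from math import prod
from typing import Sequence, Tuple


def resolve_reshape(input_shape: Sequence[int], shape: Sequence[int]) -> Tuple[int, ...]:
    """Replace a single -1 in `shape` with the size implied by the element count."""
    numel = prod(input_shape)
    if list(shape).count(-1) > 1:
        raise ValueError("at most one dimension can be -1")
    known = prod(d for d in shape if d != -1)  # product of the explicitly given sizes
    if -1 in shape:
        if known == 0 or numel % known != 0:
            raise ValueError(f"cannot reshape {tuple(input_shape)} into {tuple(shape)}")
        return tuple(numel // known if d == -1 else d for d in shape)
    if known != numel:
        raise ValueError(f"cannot reshape {tuple(input_shape)} into {tuple(shape)}")
    return tuple(shape)


print(resolve_reshape((2, 3), (1, 6)))      # (1, 6)
print(resolve_reshape((2, 3, 4), (4, -1)))  # (4, 6)
```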
 - property shape: Tuple[int | DimensionSize]
- Represents the shape of the tensor.
- Parameters:
- self – [dtype=T1] The input tensor.
- Returns:
- A sequence containing the shape of this tensor.
- Example:
    input = tp.ones((8, 2))
    shape = input.shape
- Local Variables:
    >>> input
    tensor([[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]], dtype=float32, loc=gpu:0, shape=(8, 2))
    >>> shape
    (8, 2)
 - squeeze(dims: Sequence[int] | int) → Tensor
- Returns a new tensor with the specified singleton dimensions of the input tensor removed.
- Parameters:
- input (Tensor) – [dtype=T1] The input tensor.
- dims (Sequence[int] | int) – The dimension(s) to remove. These must have a length of 1.
- Returns:
- [dtype=T1] A new tensor.
- Return type:
- Tensor
- Example: Squeeze All Singleton Dimensions
    input = tp.iota((1, 2, 1), dtype=tp.float32)
    output = tp.squeeze(input, dims=(0, 2))
- Local Variables:
    >>> input
    tensor([[[0], [0]]], dtype=float32, loc=gpu:0, shape=(1, 2, 1))
    >>> output
    tensor([0, 0], dtype=float32, loc=gpu:0, shape=(2,))
- Example: Squeeze First Dimension
    input = tp.iota((1, 2, 1), dtype=tp.float32)
    output = tp.squeeze(input, 0)
- Local Variables:
    >>> input
    tensor([[[0], [0]]], dtype=float32, loc=gpu:0, shape=(1, 2, 1))
    >>> output
    tensor([[0], [0]], dtype=float32, loc=gpu:0, shape=(2, 1))
- Example: Squeeze First And Third Dimensions
    input = tp.iota((1, 2, 1), dtype=tp.float32)
    output = tp.squeeze(input, (0, 2))
- Local Variables:
    >>> input
    tensor([[[0], [0]]], dtype=float32, loc=gpu:0, shape=(1, 2, 1))
    >>> output
    tensor([0, 0], dtype=float32, loc=gpu:0, shape=(2,))
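Because `dims` may only name dimensions of length 1, squeezing never changes the number of elements; it only drops entries from the shape. The hypothetical helper below sketches that check in plain Python. Accepting negative indices is an assumption, since this page does not say whether `squeeze` supports them.

```python
from typing import Sequence, Tuple, Union


def squeezed_shape(shape: Sequence[int], dims: Union[int, Sequence[int]]) -> Tuple[int, ...]:
    rank = len(shape)
    dims = (dims,) if isinstance(dims, int) else tuple(dims)
    # Assumption: negative dims count from the end, as elsewhere in Python.
    targets = {d + rank if d < 0 else d for d in dims}
    for d in targets:
        if not 0 <= d < rank or shape[d] != 1:
            raise ValueError(f"dimension {d} is not a singleton dimension of {tuple(shape)}")
    # Keep every dimension that was not selected for removal.
    return tuple(size for i, size in enumerate(shape) if i not in targets)


print(squeezed_shape((1, 2, 1), (0, 2)))  # (2,)
print(squeezed_shape((1, 2, 1), 0))       # (2, 1)
```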
 - tolist() → List | Number
- Returns the tensor as a nested list. If the tensor is a scalar, returns a Python number.
- Returns:
- The tensor represented as a nested list or a Python number.
- Return type:
- List | Number
- Example: Ranked Tensor
    tensor = tp.ones((2, 2))
    tensor_list = tensor.tolist()
- Local Variables:
    >>> tensor_list
    [[1.0, 1.0], [1.0, 1.0]]
- Example: Scalar
    tensor = tp.Tensor(2.0)
    tensor_scalar = tensor.tolist()
- Local Variables:
    >>> tensor_scalar
    2.0
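The return type of `tolist()` depends only on the rank: a rank-0 tensor becomes a plain Python number, and every other rank becomes a list nested once per dimension. As an illustration of that structure only (nvtripy's own conversion is not shown on this page), the hypothetical helper `to_nested` rebuilds such a nested list from row-major flat data and a shape.

```python
from typing import List, Sequence, Union

Nested = Union[float, List["Nested"]]


def to_nested(flat: Sequence[float], shape: Sequence[int]) -> Nested:
    """Rebuild nested lists from row-major flat data; an empty shape yields a plain number."""
    if len(shape) == 0:
        return flat[0]  # rank 0: a scalar, not a list
    if shape[0] == 0:
        return []
    step = len(flat) // shape[0]  # number of elements under each entry of the first dimension
    return [to_nested(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


print(to_nested([1.0, 1.0, 1.0, 1.0], (2, 2)))  # [[1.0, 1.0], [1.0, 1.0]]
print(to_nested([2.0], ()))                     # 2.0
```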
 - transpose(dim0: int, dim1: int) → Tensor
- Returns a new tensor that is a transposed version of the input tensor, with dimensions dim0 and dim1 swapped.
- Parameters:
- input (Tensor) – [dtype=T1] The input tensor.
- dim0 (int) – The first dimension to be transposed.
- dim1 (int) – The second dimension to be transposed.
- Returns:
- [dtype=T1] A new tensor.
- Return type:
- Tensor
- Example:
    input = tp.reshape(tp.arange(6, dtype=tp.float32), (2, 3))
    output = tp.transpose(input, 0, 1)
- Local Variables:
    >>> input
    tensor([[0, 1, 2], [3, 4, 5]], dtype=float32, loc=gpu:0, shape=(2, 3))
    >>> output
    tensor([[0, 3], [1, 4], [2, 5]], dtype=float32, loc=gpu:0, shape=(3, 2))
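Swapping two dimensions is a special case of a permutation: the identity ordering with positions `dim0` and `dim1` exchanged. Under that reading, `transpose(input, 0, 1)` on a rank-2 tensor corresponds to `permute(input, (1, 0))`, which is consistent with the permute example on this page producing the same output. `transpose_perm` is a hypothetical helper that builds the permutation; it is not an nvtripy API.

```python
from typing import List


def transpose_perm(rank: int, dim0: int, dim1: int) -> List[int]:
    """Permutation that swaps dim0 and dim1 and leaves every other axis in place."""
    perm = list(range(rank))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return perm


print(transpose_perm(2, 0, 1))  # [1, 0]
print(transpose_perm(4, 1, 3))  # [0, 3, 2, 1]
```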
 - unsqueeze(dim: int) → Tensor
- Returns a new tensor containing the contents of the input tensor, with a singleton dimension inserted before the specified axis.
- Parameters:
- input (Tensor) – [dtype=T1] The input tensor.
- dim (int) – The index before which to insert the singleton dimension. A negative dimension is converted to dim = dim + input.rank + 1.
- Returns:
- [dtype=T1] A new tensor.
- Return type:
- Tensor
- Example:
    input = tp.iota((2, 2), dtype=tp.float32)
    output = tp.unsqueeze(input, 1)
- Local Variables:
    >>> input
    tensor([[0, 0], [1, 1]], dtype=float32, loc=gpu:0, shape=(2, 2))
    >>> output
    tensor([[[0, 0]], [[1, 1]]], dtype=float32, loc=gpu:0, shape=(2, 1, 2))
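The conversion rule for negative `dim` means that -1 appends the new axis at the end, just as `dim = input.rank` does. The hypothetical helper below sketches the resulting shape in plain Python; the accepted range of `dim` is derived from the documented conversion formula rather than stated on this page.

```python
from typing import Sequence, Tuple


def unsqueezed_shape(shape: Sequence[int], dim: int) -> Tuple[int, ...]:
    rank = len(shape)
    # Negative dims are converted with dim = dim + rank + 1, as documented above.
    if dim < 0:
        dim = dim + rank + 1
    # After conversion the new axis can go anywhere from position 0 to position rank.
    if not 0 <= dim <= rank:
        raise ValueError(f"dim must lie in [{-rank - 1}, {rank}] for rank {rank}")
    return tuple(shape[:dim]) + (1,) + tuple(shape[dim:])


print(unsqueezed_shape((2, 2), 1))   # (2, 1, 2)
print(unsqueezed_shape((2, 2), -1))  # (2, 2, 1)
```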
 
See also: