split
- nvtripy.split(input: Tensor, num_split_or_sizes: int | Sequence[int], dim: int = 0) → Tuple[Tensor]
Splits a tensor along the specified dimension.
- Parameters:
  - input (Tensor) – [dtype=T1] The input tensor.
  - num_split_or_sizes (int | Sequence[int]) – If this is an int, the input is split into this many equal-sized chunks. If the dimension cannot be divided evenly, the last chunk will be smaller. If this is a Sequence[int], the input is split into len(num_split_or_sizes) chunks, where the \(i^{th}\) chunk has a size of num_split_or_sizes[i]. A chunk's size is clamped if the input is too small along dim.
  - dim (int) – The dimension along which to split. All other dimensions are included in full. Defaults to 0.
- Returns:
  [dtype=T1] A tuple of slices of the input tensor.
- Return type:
  Tuple[Tensor]
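To make the two forms of num_split_or_sizes concrete, the plain-Python sketch below computes the (start, stop) bounds each chunk would cover along dim. It is not nvtripy code: split_bounds is a hypothetical helper, and for the int form it assumes a chunk size of ceil(dim_size / n), the same rounding torch.chunk uses, which is one way to get chunks that are equal except for a smaller tail. Clamping is modeled by capping every bound at dim_size.

```python
from typing import List, Sequence, Tuple, Union


def split_bounds(
    dim_size: int, num_split_or_sizes: Union[int, Sequence[int]]
) -> List[Tuple[int, int]]:
    """Return (start, stop) bounds for each chunk along one dimension.

    Hypothetical helper. Assumption: an int n uses a chunk size of
    ceil(dim_size / n), so chunks fill front to back and any shortfall
    lands at the end. Not taken from nvtripy's source.
    """
    if isinstance(num_split_or_sizes, int):
        if num_split_or_sizes < 1:
            raise ValueError("the number of chunks must be positive")
        # Integer ceiling division: ceil(dim_size / n) without floats.
        chunk = -(-dim_size // num_split_or_sizes)
        sizes = [chunk] * num_split_or_sizes
    else:
        sizes = list(num_split_or_sizes)

    bounds = []
    start = 0
    for size in sizes:
        # Clamp both ends so no chunk extends past the end of the dimension.
        bounds.append((min(start, dim_size), min(start + size, dim_size)))
        start += size
    return bounds


print(split_bounds(4, 2))          # [(0, 2), (2, 4)]
print(split_bounds(4, [1, 1, 2]))  # [(0, 1), (1, 2), (2, 4)]
print(split_bounds(10, 3))         # [(0, 4), (4, 8), (8, 10)]
```

The first two calls match the row ranges in the examples that follow; the third shows the smaller final chunk under the stated rounding assumption.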
Example: Splitting Into 2 Chunks
    import nvtripy as tp

    input = tp.reshape(tp.arange(16, dtype=tp.float32), (4, 4))
    outputs = tp.split(input, 2)
Local Variables

    >>> input
    tensor(
        [[0, 1, 2, 3],
         [4, 5, 6, 7],
         [8, 9, 10, 11],
         [12, 13, 14, 15]],
        dtype=float32, loc=gpu:0, shape=(4, 4))
    >>> outputs
    (
        tensor(
            [[0, 1, 2, 3],
             [4, 5, 6, 7]],
            dtype=float32, loc=gpu:0, shape=(2, 4)),
        tensor(
            [[8, 9, 10, 11],
             [12, 13, 14, 15]],
            dtype=float32, loc=gpu:0, shape=(2, 4)),
    )
Example: Splitting Along A Different Dimension
    import nvtripy as tp

    input = tp.reshape(tp.arange(16, dtype=tp.float32), (4, 4))
    outputs = tp.split(input, 2, dim=1)
Local Variables

    >>> input
    tensor(
        [[0, 1, 2, 3],
         [4, 5, 6, 7],
         [8, 9, 10, 11],
         [12, 13, 14, 15]],
        dtype=float32, loc=gpu:0, shape=(4, 4))
    >>> outputs
    (
        tensor(
            [[0, 1],
             [4, 5],
             [8, 9],
             [12, 13]],
            dtype=float32, loc=gpu:0, shape=(4, 2)),
        tensor(
            [[2, 3],
             [6, 7],
             [10, 11],
             [14, 15]],
            dtype=float32, loc=gpu:0, shape=(4, 2)),
    )
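Splitting along dim=1 keeps every row and gives each output a contiguous range of columns. The following plain-Python sketch reproduces the outputs of this example on a nested list instead of an nvtripy tensor; split_columns is a hypothetical helper, and it assumes the column count divides evenly by the number of chunks.

```python
from typing import List

Matrix = List[List[float]]


def split_columns(matrix: Matrix, num_chunks: int) -> List[Matrix]:
    """Illustrate splitting a 2-D nested list along dim=1.

    Hypothetical teaching helper, not nvtripy code. Assumes the number of
    columns is divisible by num_chunks.
    """
    width = len(matrix[0])
    step = width // num_chunks
    # Every chunk keeps all rows (dim 0 in full) and takes one column range.
    return [[row[i * step:(i + 1) * step] for row in matrix] for i in range(num_chunks)]


matrix = [[float(4 * r + c) for c in range(4)] for r in range(4)]
left, right = split_columns(matrix, 2)
print(left)   # [[0.0, 1.0], [4.0, 5.0], [8.0, 9.0], [12.0, 13.0]]
print(right)  # [[2.0, 3.0], [6.0, 7.0], [10.0, 11.0], [14.0, 15.0]]
```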
Example: Splitting With Custom Chunk Sizes
    import nvtripy as tp

    input = tp.reshape(tp.arange(16, dtype=tp.float32), (4, 4))
    outputs = tp.split(input, [1, 1, 2])
Local Variables

    >>> input
    tensor(
        [[0, 1, 2, 3],
         [4, 5, 6, 7],
         [8, 9, 10, 11],
         [12, 13, 14, 15]],
        dtype=float32, loc=gpu:0, shape=(4, 4))
    >>> outputs
    (
        tensor(
            [[0, 1, 2, 3]],
            dtype=float32, loc=gpu:0, shape=(1, 4)),
        tensor(
            [[4, 5, 6, 7]],
            dtype=float32, loc=gpu:0, shape=(1, 4)),
        tensor(
            [[8, 9, 10, 11],
             [12, 13, 14, 15]],
            dtype=float32, loc=gpu:0, shape=(2, 4)),
    )
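The clamping rule for custom sizes can be illustrated with ordinary Python slicing, which stops at the end of a list without raising an error. In this sketch, split_rows is a hypothetical helper that works on nested lists rather than nvtripy tensors; it reproduces the chunk sizes of the example above and shows a request that overruns the input being clamped. Treat it as an illustration of the documented behavior, not of nvtripy's implementation.

```python
from itertools import accumulate
from typing import List, Sequence

Matrix = List[List[float]]


def split_rows(rows: Matrix, sizes: Sequence[int]) -> List[Matrix]:
    """Split a nested list along dim=0 into chunks of the requested sizes.

    Hypothetical helper. Python slicing clamps an out-of-range stop, so an
    oversized request simply yields a shorter final chunk.
    """
    stops = list(accumulate(sizes))   # running totals: end of each chunk
    starts = [0] + stops[:-1]         # each chunk starts where the last ended
    return [rows[start:stop] for start, stop in zip(starts, stops)]


rows = [[float(4 * r + c) for c in range(4)] for r in range(4)]
chunks = split_rows(rows, [1, 1, 2])
print([len(chunk) for chunk in chunks])   # [1, 1, 2]

# Asking for 3 rows when only 2 remain: the last chunk is clamped to 2 rows.
clamped = split_rows(rows, [1, 1, 3])
print([len(chunk) for chunk in clamped])  # [1, 1, 2]
```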