split

nvtripy.split(input: Tensor, num_split_or_sizes: int | Sequence[int], dim: int = 0) → Tuple[Tensor]

Splits a tensor along the specified dimension.

Parameters:
  • input (Tensor) – [dtype=T1] The input tensor.

  • num_split_or_sizes (int | Sequence[int]) –

    If this is an int, the input is split into this many equal-sized chunks. If the size of dim is not evenly divisible by this number, the last chunk will be smaller.

    If this is a Sequence[int], the input is split into len(num_split_or_sizes) chunks, where the i-th chunk has size num_split_or_sizes[i] along dim. If the input is too small along dim, any chunk that would extend past its end is clamped to fit.

  • dim (int) – The dimension along which to split. All other dimensions are included in full in every chunk.
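The chunking rule described by these parameters can be modeled as (start, stop) index pairs along dim. The following is a minimal sketch, not nvtripy's implementation. `chunk_bounds` is a hypothetical helper. It assumes two things: the integer form uses a ceiling-divided chunk size, so only the last chunk can be smaller, and "clamped" means each range is truncated at the end of the dimension.

```python
from typing import List, Sequence, Tuple, Union


def chunk_bounds(
    dim_size: int, num_split_or_sizes: Union[int, Sequence[int]]
) -> List[Tuple[int, int]]:
    """Return (start, stop) pairs for each chunk along a dimension of length dim_size.

    Hypothetical helper that models the documented behavior of split.
    """
    if isinstance(num_split_or_sizes, int):
        num_split = num_split_or_sizes
        if num_split <= 0:
            raise ValueError("num_split must be positive")
        # Assumption: ceiling division, so every chunk except the last has the same size.
        chunk = -(-dim_size // num_split)
        sizes = [chunk] * num_split
    else:
        sizes = list(num_split_or_sizes)

    bounds = []
    start = 0
    for size in sizes:
        # Clamp both ends so no chunk runs past the end of the dimension.
        bounds.append((min(start, dim_size), min(start + size, dim_size)))
        start += size
    return bounds


print(chunk_bounds(4, 2))          # [(0, 2), (2, 4)]
print(chunk_bounds(4, [1, 1, 2]))  # [(0, 1), (1, 2), (2, 4)]
print(chunk_bounds(5, 2))          # [(0, 3), (3, 5)]
```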

Returns:

[dtype=T1] A tuple of slices of the input tensor.

Return type:

Tuple[Tensor]

DATA TYPE CONSTRAINTS:
Example: Splitting Into 2 Chunks
import nvtripy as tp

input = tp.reshape(tp.arange(16, dtype=tp.float32), (4, 4))
outputs = tp.split(input, 2)
Local Variables
>>> input
tensor(
    [[0, 1, 2, 3],
     [4, 5, 6, 7],
     [8, 9, 10, 11],
     [12, 13, 14, 15]], 
    dtype=float32, loc=gpu:0, shape=(4, 4))

>>> outputs
(
    tensor(
        [[0, 1, 2, 3],
         [4, 5, 6, 7]], 
        dtype=float32, loc=gpu:0, shape=(2, 4)),
    tensor(
        [[8, 9, 10, 11],
         [12, 13, 14, 15]], 
        dtype=float32, loc=gpu:0, shape=(2, 4)),
)
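The example above divides evenly. For an uneven integer split, the description says only the last chunk is smaller. The pure-Python sketch below illustrates this on a nested list with five rows. It assumes the chunk size is the ceiling of rows / num_split. `split_rows` is a hypothetical helper and does not call nvtripy.

```python
import math
from typing import List


def split_rows(rows: List[List[float]], num_split: int) -> tuple:
    """Split a 2-D nested list along dim 0 into num_split chunks (hypothetical helper)."""
    # Assumption: every chunk has ceil(len(rows) / num_split) rows except the last.
    chunk = math.ceil(len(rows) / num_split)
    # Python slicing clamps out-of-range stops, so the last chunk simply comes out shorter.
    return tuple(rows[i * chunk:(i + 1) * chunk] for i in range(num_split))


data = [[float(4 * r + c) for c in range(4)] for r in range(5)]  # 5 x 4 matrix
first, second = split_rows(data, 2)
print(len(first), len(second))  # 3 2
```

NumPy's np.array_split handles the remainder differently: it spreads the extra elements over the first chunks. For example, it splits 10 elements into 4 chunks of sizes 3, 3, 2, 2, whereas the ceiling rule gives 3, 3, 3, 1. The two should therefore not be assumed interchangeable for uneven splits.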
Example: Splitting Along A Different Dimension
import nvtripy as tp

input = tp.reshape(tp.arange(16, dtype=tp.float32), (4, 4))
outputs = tp.split(input, 2, dim=1)
Local Variables
>>> input
tensor(
    [[0, 1, 2, 3],
     [4, 5, 6, 7],
     [8, 9, 10, 11],
     [12, 13, 14, 15]], 
    dtype=float32, loc=gpu:0, shape=(4, 4))

>>> outputs
(
    tensor(
        [[0, 1],
         [4, 5],
         [8, 9],
         [12, 13]], 
        dtype=float32, loc=gpu:0, shape=(4, 2)),
    tensor(
        [[2, 3],
         [6, 7],
         [10, 11],
         [14, 15]], 
        dtype=float32, loc=gpu:0, shape=(4, 2)),
)
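With dim=1, every row is kept and the columns are partitioned instead. The sketch below reproduces the two outputs above on a plain nested list. `split_columns` is a hypothetical illustration of the indexing, not part of nvtripy, and it assumes the same ceiling-division chunk size as the integer case along dim 0.

```python
from typing import List, Tuple


def split_columns(rows: List[List[int]], num_split: int) -> Tuple[List[List[int]], ...]:
    """Split a 2-D nested list along dim 1 (hypothetical helper mirroring dim=1)."""
    width = len(rows[0])
    chunk = -(-width // num_split)  # assumed ceiling division
    # Each output keeps every row (dim 0 in full) but only a window of columns.
    return tuple(
        [row[i * chunk:(i + 1) * chunk] for row in rows] for i in range(num_split)
    )


matrix = [[4 * r + c for c in range(4)] for r in range(4)]
left, right = split_columns(matrix, 2)
print(left)   # [[0, 1], [4, 5], [8, 9], [12, 13]]
print(right)  # [[2, 3], [6, 7], [10, 11], [14, 15]]
```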
Example: Splitting With Custom Chunk Sizes
import nvtripy as tp

input = tp.reshape(tp.arange(16, dtype=tp.float32), (4, 4))
outputs = tp.split(input, [1, 1, 2])
Local Variables
>>> input
tensor(
    [[0, 1, 2, 3],
     [4, 5, 6, 7],
     [8, 9, 10, 11],
     [12, 13, 14, 15]], 
    dtype=float32, loc=gpu:0, shape=(4, 4))

>>> outputs
(
    tensor(
        [[0, 1, 2, 3]], 
        dtype=float32, loc=gpu:0, shape=(1, 4)),
    tensor(
        [[4, 5, 6, 7]], 
        dtype=float32, loc=gpu:0, shape=(1, 4)),
    tensor(
        [[8, 9, 10, 11],
         [12, 13, 14, 15]], 
        dtype=float32, loc=gpu:0, shape=(2, 4)),
)
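In the sequence form, each chunk starts where the previous one ended. The sizes [1, 1, 2] therefore correspond to stop offsets 1, 2 and 4 along dim. The pure-Python sketch below makes that bookkeeping explicit. It also shows one plausible reading of the clamping rule: sizes [3, 3] on a 4-row input truncate the second chunk to a single row. `split_by_sizes` is a hypothetical helper. The clamping behavior is an assumption based on the description of num_split_or_sizes and has not been checked against nvtripy.

```python
from itertools import accumulate
from typing import List, Sequence


def split_by_sizes(rows: List[List[int]], sizes: Sequence[int]) -> tuple:
    """Split a nested list along dim 0 into chunks of the given sizes (hypothetical helper)."""
    # Running totals give each chunk's stop index: [1, 1, 2] -> [1, 2, 4].
    stops = list(accumulate(sizes))
    starts = [0] + stops[:-1]
    # Slicing clamps indices past the end, mimicking the assumed clamping behavior.
    return tuple(rows[start:stop] for start, stop in zip(starts, stops))


matrix = [[4 * r + c for c in range(4)] for r in range(4)]
first, second, third = split_by_sizes(matrix, [1, 1, 2])
print([len(first), len(second), len(third)])              # [1, 1, 2]
print([len(t) for t in split_by_sizes(matrix, [3, 3])])   # [3, 1]
```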