split

tripy.split(input: Tensor, indices_or_sections: int | Sequence[int], dim: int = 0) → Tensor | Sequence[Tensor]

Splits input along the dimension dim, producing slices of the input tensor.

If indices_or_sections is a single int (call it \(n\)), the input is split into \(n\) slices of equal size along dimension dim. This requires \(n\) to evenly divide the size of dimension dim. For example, if input is one-dimensional with size \(k\), the result is \(\texttt{input[:}k/n\texttt{]}\), \(\texttt{input[}k/n\texttt{:}2k/n\texttt{]}\), \(\ldots\), \(\texttt{input[}(n-1)k/n\texttt{:]}\).
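The equal-sections rule can be modeled on a plain Python list. This is a sketch of the semantics described above, not Tripy's implementation; the function name `equal_sections` is hypothetical, and raising `ValueError` when \(n\) does not divide \(k\) is an assumption (the text only defines the divisible case).

```python
def equal_sections(values, n):
    """Split a 1-D list into n equal slices (hypothetical reference model)."""
    k = len(values)
    if n <= 0 or k % n != 0:
        # Assumption: only the case where n evenly divides k is defined.
        raise ValueError(f"cannot split a dimension of size {k} into {n} equal slices")
    step = k // n
    # Slice j covers [j*k/n, (j+1)*k/n), mirroring input[jk/n : (j+1)k/n].
    return [values[j * step:(j + 1) * step] for j in range(n)]


print(equal_sections(list(range(8)), 4))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```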

If indices_or_sections is a sequence of ints, its values are treated as the boundary indices of the slices. For example, if the indices are \(i_0\), \(i_1\), \(\ldots\), \(i_n\) and input is one-dimensional, the result is equivalent to \(\texttt{input[:}i_0\texttt{]}\), \(\texttt{input[}i_0\texttt{:}i_1\texttt{]}\), \(\texttt{input[}i_1\texttt{:}i_2\texttt{]}\), \(\ldots\), \(\texttt{input[}i_n\texttt{:]}\), so there is one more slice than there are indices.
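The index-based rule amounts to bracketing the given indices with the start and end of the dimension and slicing between consecutive boundaries. A minimal plain-Python sketch, assuming non-decreasing, in-range indices; `split_at_indices` is a hypothetical name, not part of Tripy:

```python
def split_at_indices(values, indices):
    """Split a 1-D list at boundary indices (hypothetical reference model)."""
    # Bracket the indices with 0 and len(values) so that consecutive pairs
    # yield input[:i_0], input[i_0:i_1], ..., input[i_n:].
    bounds = [0, *indices, len(values)]
    return [values[start:stop] for start, stop in zip(bounds, bounds[1:])]


print(split_at_indices(list(range(4)), [1, 2]))  # [[0], [1], [2, 3]]
```

With two indices the result has three slices, matching the "one more slice than indices" count above.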

Parameters:
  • input (Tensor) – [dtype=T1] The input tensor.

  • indices_or_sections (int | Sequence[int]) – If a single integer, the number of equal slices to produce. If a sequence of integers, the boundary indices of the slices.

  • dim (int) – The dimension along which to split. All other dimensions are included in full in every slice.
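The effect of dim on a two-dimensional input can be sketched with nested lists, using the same 4×4 layout as the examples (integers instead of float32). This is a hypothetical model of the documented behavior, not Tripy code: only the chosen dimension is sliced and the other is kept whole.

```python
def split_2d(matrix, sections, dim=0):
    """Split a 2-D nested list into equal sections along dim 0 (rows) or 1 (columns).

    Hypothetical reference model: only the split dimension is sliced; the
    other dimension is included in full in every slice.
    """
    if dim not in (0, 1):
        raise ValueError("this 2-D model only supports dim 0 or 1")
    size = len(matrix) if dim == 0 else len(matrix[0])
    if size % sections != 0:
        raise ValueError(f"cannot split size {size} into {sections} equal slices")
    step = size // sections
    if dim == 0:
        # Each slice is a contiguous block of whole rows.
        return [matrix[j * step:(j + 1) * step] for j in range(sections)]
    # dim == 1: each slice keeps every row but only a block of columns.
    return [[row[j * step:(j + 1) * step] for row in matrix] for j in range(sections)]


matrix = [[4 * r + c for c in range(4)] for r in range(4)]
left, right = split_2d(matrix, 2, dim=1)
print(left)  # [[0, 1], [4, 5], [8, 9], [12, 13]]
```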

Returns:

[dtype=T1] A sequence of slices as described above, or a single tensor if only one slice is produced.

Return type:

Tensor | Sequence[Tensor]
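The return convention (a sequence of slices, unwrapped to a single value when only one slice results) can be sketched in one dispatching function. This is a hypothetical 1-D model, not Tripy's implementation. In this model a single slice arises from indices_or_sections=1 or from an empty sequence; whether Tripy accepts an empty sequence is not stated above, so treat that path as an assumption.

```python
def split_1d(values, indices_or_sections):
    """Hypothetical 1-D model of split's dispatch and return convention."""
    if isinstance(indices_or_sections, int):
        n = indices_or_sections
        if n <= 0 or len(values) % n != 0:
            # Assumption: non-divisible section counts are rejected.
            raise ValueError("sections must evenly divide the dimension")
        step = len(values) // n
        bounds = [j * step for j in range(n + 1)]
    else:
        # Boundary indices, bracketed by the start and end of the list.
        bounds = [0, *indices_or_sections, len(values)]
    slices = [values[a:b] for a, b in zip(bounds, bounds[1:])]
    # A lone slice is returned by itself instead of inside a list.
    return slices[0] if len(slices) == 1 else slices


print(split_1d([1, 2, 3, 4], 1))       # [1, 2, 3, 4]
print(split_1d([1, 2, 3, 4], [1, 2]))  # [[1], [2], [3, 4]]
```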

TYPE CONSTRAINTS:
Example: Simple case.
import tripy as tp
input = tp.reshape(tp.arange(16, dtype=tp.float32), (4, 4))
# outputs[0] holds rows 0-1 of input and outputs[1] holds rows 2-3
outputs = tp.split(input, 2, dim=0)
>>> input
tensor(
    [[0.0000, 1.0000, 2.0000, 3.0000],
     [4.0000, 5.0000, 6.0000, 7.0000],
     [8.0000, 9.0000, 10.0000, 11.0000],
     [12.0000, 13.0000, 14.0000, 15.0000]], 
    dtype=float32, loc=gpu:0, shape=(4, 4))
Example: Choosing a different dimension.
import tripy as tp
input = tp.reshape(tp.arange(16, dtype=tp.float32), (4, 4))
# outputs[0] holds columns 0-1 of input and outputs[1] holds columns 2-3
outputs = tp.split(input, 2, dim=1)
>>> input
tensor(
    [[0.0000, 1.0000, 2.0000, 3.0000],
     [4.0000, 5.0000, 6.0000, 7.0000],
     [8.0000, 9.0000, 10.0000, 11.0000],
     [12.0000, 13.0000, 14.0000, 15.0000]], 
    dtype=float32, loc=gpu:0, shape=(4, 4))
Example: Multiple index arguments.
import tripy as tp
input = tp.reshape(tp.arange(16, dtype=tp.float32), (4, 4))
# Along the default dim=0: outputs are input[:1], input[1:2], and input[2:]
outputs = tp.split(input, [1, 2])
>>> input
tensor(
    [[0.0000, 1.0000, 2.0000, 3.0000],
     [4.0000, 5.0000, 6.0000, 7.0000],
     [8.0000, 9.0000, 10.0000, 11.0000],
     [12.0000, 13.0000, 14.0000, 15.0000]], 
    dtype=float32, loc=gpu:0, shape=(4, 4))