flatten

tripy.flatten(input: Tensor, start_dim: int = 0, end_dim: int = -1) → Tensor

Flattens the input tensor by merging the dimensions from start_dim through end_dim (inclusive) into a single dimension.

Parameters:
  • input (Tensor) – [dtype=T1] The input tensor to be flattened.

  • start_dim (int) – The first dimension to flatten (default is 0).

  • end_dim (int) – The last dimension to flatten (default is -1, which includes the last dimension).

Returns:

[dtype=T1] A flattened tensor.

Return type:

Tensor
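The effect of start_dim and end_dim on the output shape can be sketched in plain Python. The helper below, flattened_shape, is hypothetical and not part of tripy. It assumes that negative indices count from the last dimension (as the default end_dim of -1 suggests) and that the range is inclusive on both ends; how tripy itself reports an invalid range is not covered here.

```python
from math import prod


def flattened_shape(shape, start_dim=0, end_dim=-1):
    """Hypothetical helper (not a tripy API): the shape that results from
    merging dimensions start_dim..end_dim (inclusive) into one dimension."""
    rank = len(shape)
    # Assumption: negative indices count from the end, like Python indexing.
    start = start_dim + rank if start_dim < 0 else start_dim
    end = end_dim + rank if end_dim < 0 else end_dim
    if not (0 <= start <= end < rank):
        raise ValueError(f"invalid range [{start_dim}, {end_dim}] for rank {rank}")
    # Dimensions before the range are kept, the range collapses to its
    # product, and dimensions after the range are kept.
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


print(flattened_shape((1, 2, 1)))                             # (2,)
print(flattened_shape((2, 3, 4), start_dim=1))                # (2, 12)
print(flattened_shape((2, 3, 4, 5), start_dim=1, end_dim=2))  # (2, 12, 5)
```

The three calls use the same input shapes and arguments as the examples on this page, and the printed shapes match the shapes reported in their outputs.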

TYPE CONSTRAINTS:
Example: Flatten All Dimensions
import tripy as tp  # assumed import alias; the examples on this page use the tp prefix

input = tp.iota((1, 2, 1), dtype=tp.float32)
output = tp.flatten(input)
>>> input
tensor(
    [[[0.0000],
      [0.0000]]], 
    dtype=float32, loc=gpu:0, shape=(1, 2, 1))
>>> output
tensor([0.0000, 0.0000], dtype=float32, loc=gpu:0, shape=(2,))
Example: Flatten Starting from First Dimension
input = tp.iota((2, 3, 4), dtype=tp.float32)
output = tp.flatten(input, start_dim=1)
>>> input
tensor(
    [[[0.0000, 0.0000, 0.0000, 0.0000],
      [0.0000, 0.0000, 0.0000, 0.0000],
      [0.0000, 0.0000, 0.0000, 0.0000]],

     [[1.0000, 1.0000, 1.0000, 1.0000],
      [1.0000, 1.0000, 1.0000, 1.0000],
      [1.0000, 1.0000, 1.0000, 1.0000]]], 
    dtype=float32, loc=gpu:0, shape=(2, 3, 4))
>>> output
tensor(
    [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
      0.0000, 0.0000],
     [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
      1.0000, 1.0000]], 
    dtype=float32, loc=gpu:0, shape=(2, 12))
Example: Flatten a Specific Range of Dimensions
input = tp.iota((2, 3, 4, 5), dtype=tp.float32)
output = tp.flatten(input, start_dim=1, end_dim=2)
>>> input
tensor(
    [[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

      [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

      [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],


     [[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],

      [[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],

      [[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]]], 
    dtype=float32, loc=gpu:0, shape=(2, 3, 4, 5))
>>> output
tensor(
    [[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
      [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

     [[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
      [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
      [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
      [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
      [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
      [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
      [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
      [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
      [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
      [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
      [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
      [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]], 
    dtype=float32, loc=gpu:0, shape=(2, 12, 5))