Flatten

API

class warp_nn.modules.layers.Flatten(start_dim: int = 1, end_dim: int = -1)

Bases: Module

Flatten a contiguous range of dimensions into a single dimension.

Given an input of shape \((*, D_{\text{start}}, \ldots, D_{\text{end}}, *)\), where \(*\) denotes any number of dimensions (including none), the output has shape \(\left(*, \prod_{i=\text{start}}^{\text{end}} D_i, *\right)\). Dimensions outside the range \([\text{start}, \text{end}]\) are left unchanged.
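The shape rule above can be checked with a small pure-Python helper. This is a sketch of the formula only, not the warp_nn implementation; the function name `flattened_shape` is hypothetical, and it assumes negative dimension indices count from the end (as the default `end_dim=-1` suggests).

```python
from math import prod

def flattened_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: output shape after flattening dims start_dim..end_dim (inclusive)."""
    ndim = len(shape)
    # Assumption: negative indices count from the last dimension, Python-style.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not (0 <= start < ndim and 0 <= end < ndim):
        raise IndexError("dimension out of range")
    if start > end:
        raise IndexError("start_dim must not come after end_dim")
    # Keep leading dims, merge the range into one dim, keep trailing dims.
    return tuple(shape[:start]) + (prod(shape[start:end + 1]),) + tuple(shape[end + 1:])

print(flattened_shape((2, 3, 4, 5)))        # (2, 60)
print(flattened_shape((2, 3, 4, 5), 1, 2))  # (2, 12, 5)
```

With the defaults, everything after the first (typically batch) dimension is merged, which is the usual way to feed convolutional features into a fully connected layer.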

Parameters:
  • start_dim – First dimension to flatten (default: 1).

  • end_dim – Last dimension to flatten, inclusive (default: -1).

__call__(input: array) → array

Forward pass of the module.

Parameters:

input – The input array, with shape (*, D_start, ..., D_end, *), where D_start, ..., D_end are the sizes of dimensions start_dim through end_dim.

Returns:

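The output array, with shape (*, prod(D_start, ..., D_end), *): dimensions start_dim through end_dim are merged into a single dimension whose size is the product of their sizes.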

Raises:

IndexError – If the start dimension comes after the end dimension.
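For reference, the forward pass and its error condition can be mimicked with NumPy. This is a sketch for checking expected shapes and element order, not the warp_nn implementation; the name `flatten_reference` is hypothetical, and it assumes negative indices count from the end and that elements keep row-major (C) order.

```python
import numpy as np

def flatten_reference(x, start_dim=1, end_dim=-1):
    """Hypothetical NumPy sketch of Flatten.__call__ (not the warp_nn code)."""
    ndim = x.ndim
    # Assumption: negative indices are resolved relative to ndim.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not (0 <= start < ndim and 0 <= end < ndim):
        raise IndexError("dimension out of range")
    if start > end:
        # Mirrors the documented IndexError.
        raise IndexError("start_dim must not come after end_dim")
    merged = int(np.prod(x.shape[start:end + 1]))
    # Row-major reshape: merged elements keep their original order.
    return x.reshape(x.shape[:start] + (merged,) + x.shape[end + 1:])

x = np.arange(24).reshape(2, 3, 4)
print(flatten_reference(x).shape)        # (2, 12)
print(flatten_reference(x, 0, 1).shape)  # (6, 4)
```

Calling `flatten_reference(x, start_dim=2, end_dim=1)` raises `IndexError`, matching the condition listed under Raises.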