Flatten
API
- class warp_nn.modules.layers.Flatten(start_dim: int = 1, end_dim: int = -1)
Bases: Module
Flatten a contiguous range of dimensions into a single dimension.
Given an input of shape \((*, D_{\text{start}}, \ldots, D_{\text{end}}, *)\), where \(*\) denotes any number of dimensions (possibly none), the output has shape \(\left(*, \prod_{i=\text{start}}^{\text{end}} D_i, *\right)\). Here \(\text{start}\) and \(\text{end}\) are the indices given by start_dim and end_dim.
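A worked instance of the shape rule, using an arbitrary illustrative input of shape (2, 3, 4, 5). The second line assumes, as the default end_dim = -1 suggests, that negative indices count back from the last dimension:

```latex
\begin{aligned}
&\text{start} = 1,\ \text{end} = 2: && (2, 3, 4, 5) \;\mapsto\; \Bigl(2,\ \textstyle\prod_{i=1}^{2} D_i,\ 5\Bigr) = (2,\ 3 \cdot 4,\ 5) = (2,\ 12,\ 5) \\
&\text{start} = 1,\ \text{end} = -1 \ (\text{i.e. } 3): && (2, 3, 4, 5) \;\mapsto\; \Bigl(2,\ \textstyle\prod_{i=1}^{3} D_i\Bigr) = (2,\ 3 \cdot 4 \cdot 5) = (2,\ 60)
\end{aligned}
```

With the defaults, dimension 0 (typically the batch dimension) is left untouched and all remaining dimensions are merged.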
- Parameters:
start_dim – First dimension to flatten (inclusive). Default: 1.
end_dim – Last dimension to flatten (inclusive). Default: -1, the last dimension.
- __call__(input: array)
Forward pass of the module.
- Parameters:
input – The input array, with shape (*, D_start, ..., D_end, *).
- Returns:
The output array, with shape (*, D_start × ... × D_end, *).
- Raises:
IndexError – If the start dimension comes after the end dimension.
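The forward pass can be approximated with a reshape. The function below is a hypothetical NumPy sketch, not warp_nn's implementation: it uses `numpy.ndarray` in place of the library's `array` type, and the helper name `flatten` is illustrative. It assumes negative indices count from the last dimension, and its extra out-of-range check is likewise an assumption. The documented behavior only covers the start-after-end case.

```python
import numpy as np


def flatten(x: np.ndarray, start_dim: int = 1, end_dim: int = -1) -> np.ndarray:
    """Hypothetical NumPy sketch of Flatten's forward pass (not warp_nn code)."""
    ndim = x.ndim
    # Assumption: negative indices count back from the last dimension.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    # Assumption: out-of-range dimensions are also rejected with IndexError.
    if not (0 <= start < ndim and 0 <= end < ndim):
        raise IndexError(f"dimension out of range for an array with {ndim} dims")
    # Documented behavior: start must not come after end.
    if start > end:
        raise IndexError("start_dim must not come after end_dim")
    # Merge dimensions start..end (inclusive) into one of size prod(D_start..D_end).
    merged = int(np.prod(x.shape[start : end + 1]))
    new_shape = x.shape[:start] + (merged,) + x.shape[end + 1 :]
    return x.reshape(new_shape)


x = np.zeros((2, 3, 4, 5))
print(flatten(x).shape)        # (2, 60)
print(flatten(x, 1, 2).shape)  # (2, 12, 5)
```

Because a reshape preserves row-major element order, the merged dimension lists elements in the same order as the original contiguous range.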