Source code for tripy.frontend.ops.flatten

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

from tripy import constraints, export
from tripy.common.exception import raise_error


@export.public_api(document_under="operations/functions")
@constraints.dtypes(
    constraints={"input": "T1", constraints.RETURN_VALUE: "T1"},
    variables={"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"]},
)
def flatten(input: "tripy.Tensor", start_dim: int = 0, end_dim: int = -1) -> "tripy.Tensor":
    """
    Flattens the input tensor from start_dim to end_dim.

    Dimensions ``start_dim`` through ``end_dim`` (both inclusive) are merged into a single
    dimension whose size is the product of the merged sizes; all other dimensions are kept
    in their original order. Negative values of ``start_dim`` and ``end_dim`` count from the
    end, as in Python indexing. An error is raised if, after normalizing negative values,
    either dimension is outside ``[0, rank)`` or ``start_dim > end_dim``.

    Args:
        input: The input tensor to be flattened.
        start_dim: The first dimension to flatten (default is 0).
        end_dim: The last dimension to flatten, inclusive (default is -1, which includes the last dimension).

    Returns:
        A flattened tensor.

    .. code-block:: python
        :linenos:
        :caption: Flatten All Dimensions

        input = tp.iota((1, 2, 1), dtype=tp.float32)
        output = tp.flatten(input)
        assert np.array_equal(cp.from_dlpack(output).get(), cp.from_dlpack(input).get().flatten())

    .. code-block:: python
        :linenos:
        :caption: Flatten Starting from the Second Dimension

        input = tp.iota((2, 3, 4), dtype=tp.float32)
        output = tp.flatten(input, start_dim=1)
        assert np.array_equal(cp.from_dlpack(output).get(), cp.from_dlpack(input).get().reshape(2, -1))

    .. code-block:: python
        :linenos:
        :caption: Flatten a Specific Range of Dimensions

        input = tp.iota((2, 3, 4, 5), dtype=tp.float32)
        output = tp.flatten(input, start_dim=1, end_dim=2)
        assert np.array_equal(cp.from_dlpack(output).get(), cp.from_dlpack(input).get().reshape(2, -1, 5))
    """
    from tripy.frontend.trace.ops.reshape import reshape

    rank = input.rank
    # Remember what the caller passed so the error message reports their values,
    # not the normalized ones computed below.
    orig_start_dim, orig_end_dim = start_dim, end_dim

    # Infer the actual dimensions to flatten: map negative indices to non-negative ones.
    if start_dim < 0:
        start_dim += rank
    if end_dim < 0:
        end_dim += rank

    # Ensure start_dim and end_dim are within the valid range and not reversed.
    # NOTE: for a rank-0 input no start_dim satisfies 0 <= start_dim < 0, so scalars are rejected.
    if not (0 <= start_dim < rank) or not (start_dim <= end_dim < rank):
        raise_error(f"Invalid dimensions: start_dim={orig_start_dim}, end_dim={orig_end_dim}, rank={rank}.")

    shape = input.shape
    # The merged dimension's size is the product of the sizes being flattened.
    flattened_dim_size = math.prod(shape[start_dim : end_dim + 1])
    flattened_shape = shape[:start_dim] + [flattened_dim_size] + shape[end_dim + 1 :]

    return reshape(input, flattened_shape)