# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

from tripy import constraints, export
from tripy.common.exception import raise_error
@export.public_api(document_under="operations/functions")
@constraints.dtypes(
    constraints={"input": "T1", constraints.RETURN_VALUE: "T1"},
    variables={
        "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"],
    },
)
def flatten(input: "tripy.Tensor", start_dim: int = 0, end_dim: int = -1) -> "tripy.Tensor":
    """
    Flattens the input tensor from start_dim to end_dim.

    Dimensions ``start_dim`` through ``end_dim`` (inclusive) are collapsed into a single
    dimension whose size is the product of their sizes; all other dimensions are kept.
    Negative values count from the end, as in Python indexing.

    Args:
        input: The input tensor to be flattened.
        start_dim: The first dimension to flatten (default is 0).
        end_dim: The last dimension to flatten (default is -1, which includes the last dimension).

    Returns:
        A flattened tensor of rank ``input.rank - (end_dim - start_dim)``
        (using the normalized, non-negative dimension indices).

    Raises:
        An error is raised via ``raise_error`` if, after normalizing negative indices,
        ``start_dim`` is not in ``[0, rank)``, ``end_dim`` is not in ``[start_dim, rank)``.
        In particular, a rank-0 input is rejected because no dimension index is valid for it.

    .. code-block:: python
        :linenos:
        :caption: Flatten All Dimensions

        input = tp.iota((1, 2, 1), dtype=tp.float32)
        output = tp.flatten(input)
        assert np.array_equal(cp.from_dlpack(output).get(), cp.from_dlpack(input).get().flatten())

    .. code-block:: python
        :linenos:
        :caption: Flatten Starting from First Dimension

        input = tp.iota((2, 3, 4), dtype=tp.float32)
        output = tp.flatten(input, start_dim=1)
        assert np.array_equal(cp.from_dlpack(output).get(), cp.from_dlpack(input).get().reshape(2, -1))

    .. code-block:: python
        :linenos:
        :caption: Flatten a Specific Range of Dimensions

        input = tp.iota((2, 3, 4, 5), dtype=tp.float32)
        output = tp.flatten(input, start_dim=1, end_dim=2)
        assert np.array_equal(cp.from_dlpack(output).get(), cp.from_dlpack(input).get().reshape(2, -1, 5))
    """
    from tripy.frontend.trace.ops.reshape import reshape

    rank = input.rank
    # Remember what the caller actually passed so the error message refers to those
    # values rather than the normalized indices computed below.
    requested_start_dim, requested_end_dim = start_dim, end_dim

    # Infer the actual dimensions to flatten: negative indices count from the end.
    if start_dim < 0:
        start_dim += rank
    if end_dim < 0:
        end_dim += rank

    # Both dims must be valid indices and the range must be non-empty (start <= end).
    if not (0 <= start_dim < rank) or not (start_dim <= end_dim < rank):
        raise_error(
            f"Invalid dimensions: start_dim={requested_start_dim}, end_dim={requested_end_dim}, rank={rank}."
        )

    shape = input.shape
    # The inclusive range [start_dim, end_dim] collapses into one dimension of product size.
    flattened_dim_size = math.prod(shape[start_dim : end_dim + 1])
    flattened_shape = shape[:start_dim] + [flattened_dim_size] + shape[end_dim + 1 :]
    return reshape(input, flattened_shape)