# Source code for accvlab.dali_pipeline_framework.processing_steps.unneeded_fields_remover

# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union, List, Tuple

try:
    from typing import override
except ImportError:
    from typing_extensions import override

from ..pipeline.sample_data_group import SampleDataGroup

from .pipeline_step_base import PipelineStepBase


class UnneededFieldRemover(PipelineStepBase):
    '''Processing step for removing unneeded fields from the data.

    This step does not add any processing steps to the DALI graph, i.e. it is fully performed at DALI graph
    construction time and does not have any overhead at runtime. This means that this step can be used
    inside the pipeline multiple times to ensure a clean data structure without any performance penalty
    (apart from the overhead at graph construction time).

    Note:
        For pipelines which use data which is not needed in the final output (e.g. intermediate results,
        image size on the CPU, etc.), it is advisable to perform this step at least once, directly before
        outputting the data, in order to avoid unneeded copies & clutter in the final output.
    '''

    def __init__(self, unneeded_field_names: Union[Tuple[Union[str, int], ...], List[Union[str, int]]]):
        '''
        Args:
            unneeded_field_names: Names of the fields to be removed. All fields with those names are
                removed. The names are copied at construction time, so modifying a passed list
                afterwards does not change which fields this step removes.
        '''
        assert isinstance(
            unneeded_field_names, (tuple, list)
        ), f"`unneeded_field_names` has to be a tuple or a list; got: {type(unneeded_field_names)}"
        # Store an immutable snapshot instead of the caller's object. Both the data format check and the
        # actual processing iterate over these names; if a caller-owned list were mutated in between,
        # the two could disagree about which fields are removed.
        self._unneeded_field_names = tuple(unneeded_field_names)

    @override
    def _process(self, data: SampleDataGroup) -> SampleDataGroup:
        # The actual processing is exactly the same as what
        # `_check_and_adjust_data_format_input_to_output()` needs to do. Therefore, it is implemented
        # here once and `_check_and_adjust_data_format_input_to_output()` delegates to this method.
        for field_name in self._unneeded_field_names:
            data.remove_all_occurrences(field_name)
        # `data` is the same object that was passed in; the removal calls above operate on it directly.
        return data

    @override
    def _check_and_adjust_data_format_input_to_output(self, data_empty: SampleDataGroup) -> SampleDataGroup:
        # Applying the removal to the blueprint passed in as `data_empty` yields the output data format.
        return self._process(data_empty)