Source code for accvlab.dali_pipeline_framework.processing_steps.conditional_element_removal

# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Used to enable type hints using a class type inside the implementation of that class itself.
from __future__ import annotations

from typing import Union, Tuple, Sequence, List

try:
    from typing import override
except ImportError:
    from typing_extensions import override

from ..operators_impl import numba_operators as numba_op

from ..pipeline.sample_data_group import SampleDataGroup

from .pipeline_step_base import PipelineStepBase


class ConditionalElementRemover(PipelineStepBase):
    '''Remove elements from arrays (e.g., per‑object data) based on a boolean mask.

    Arrays are stored as (multi-dimensional) tensors; for each array a dimension index indicates the
    element axis (the axis along which the elements to be removed/retained are enumerated).

    Elements with mask value ``False`` are removed along the configured dimension for each target field.

    See also:
        Multiple classes are available which evaluate conditions of some kind and store the results as
        boolean masks. These masks can be used in this class:

        - :class:`AnnotationElementConditionEval`
        - :class:`VisibleBboxSelector`
        - :class:`PointsInRangeCheck`
        - :class:`BoundingBoxToHeatmapConverter`
    '''

    def __init__(
        self,
        annotation_field_name: Union[str, int],
        mask_field_name: Union[str, int],
        field_names_to_process: Sequence[Union[str, int]],
        field_dims_to_process: Sequence[int],
        fields_to_process_num_dims: Sequence[int],
        remove_mask_field: bool,
    ):
        '''

        Args:
            annotation_field_name: Name of the annotation data group field to process. Each annotation
                field is processed independently.
            mask_field_name: Name of the boolean mask indicating which elements to keep (``True``) or
                remove (``False``). Must be a child of each annotation field.
            field_names_to_process: Names of fields to process. The fields must be present in each
                annotation field.
            field_dims_to_process: For each field name, the dimension index along which elements are to
                be removed.
            fields_to_process_num_dims: For each field name, the number of dimensions in the tensor.
            remove_mask_field: Whether to remove the mask field after applying this step.
        '''

        # The three per-field sequences are consumed in lockstep (see `_process`), so they have to
        # be of equal length.
        assert len(field_names_to_process) == len(
            field_dims_to_process
        ), "Number of elements in `field_names_to_process` and `field_dims_to_process` do not match."
        assert len(field_names_to_process) == len(
            fields_to_process_num_dims
        ), "Number of elements in `field_names_to_process` and `fields_to_process_num_dims` do not match."

        self._annotation_field_name = annotation_field_name
        self._mask_field_name = mask_field_name
        self._field_names_to_process = field_names_to_process
        self._field_dims_to_process = field_dims_to_process
        self._fields_to_process_num_dims = fields_to_process_num_dims
        self._do_remove_mask_field = remove_mask_field

    @override
    def _process(self, data: SampleDataGroup) -> SampleDataGroup:
        '''Apply the mask-based element removal to every annotation data group field.

        Args:
            data: Data to be processed by the step. It is modified in place.

        Returns:
            The same ``data`` object, with the configured fields replaced by the results of
            ``numba_op.remove_inactive`` and, if configured, the mask field removed.
        '''
        annotation_paths = data.find_all_occurrences(self._annotation_field_name)
        for ap in annotation_paths:
            annotations = data.get_item_in_path(ap)
            # The mask is read before any field is touched; it is only removed (if requested)
            # after all annotation fields have been processed.
            is_active = annotations[self._mask_field_name]
            for name, dim, num_dims in zip(
                self._field_names_to_process,
                self._field_dims_to_process,
                self._fields_to_process_num_dims,
            ):
                curr_data = annotations[name]
                curr_data_type = annotations.get_type_of_field(name)
                curr_res = numba_op.remove_inactive(curr_data, is_active, dim, num_dims, curr_data_type)
                annotations[name] = curr_res

        if self._do_remove_mask_field:
            self._remove_mask(data)

        return data

    @override
    def _check_and_adjust_data_format_input_to_output(self, data_empty: SampleDataGroup) -> SampleDataGroup:
        '''Validate the input data format and derive the output data format.

        Args:
            data_empty: Data format description to check. It is modified in place if the mask field
                is configured to be removed.

        Returns:
            The same ``data_empty`` object, with the mask field removed if configured.

        Raises:
            KeyError: If no annotation data group field with the configured name is found, or if an
                annotation is missing one of the fields to process or the mask field.
        '''
        annotation_paths = data_empty.find_all_occurrences(self._annotation_field_name)
        if len(annotation_paths) == 0:
            # Raise a proper exception type here (consistent with the checks below) so that the
            # message actually reaches the user.
            raise KeyError(
                f"No occurrences of annotations found. Annotation data group fields are expected to have the name '{self._annotation_field_name}', as specified in the constructor."
            )

        for ap in annotation_paths:
            annotation = data_empty.get_item_in_path(ap)
            for field in self._field_names_to_process:
                if field not in annotation.contained_top_level_field_names:
                    raise KeyError(f"No field to process '{field}' in annotation at path '{ap}'")
            if self._mask_field_name not in annotation.contained_top_level_field_names:
                raise KeyError(f"No mask field '{self._mask_field_name}' in annotation at path `{ap}`")

        if self._do_remove_mask_field:
            self._remove_mask(data_empty)

        return data_empty

    def _remove_mask(self, data_inout: SampleDataGroup):
        '''Remove the mask field from the annotation data group.

        Args:
            data_inout: Data to be processed by the step.
        '''
        annotation_paths = data_inout.find_all_occurrences(self._annotation_field_name)
        for ap in annotation_paths:
            annotation = data_inout.get_item_in_path(ap)
            annotation.remove_field(self._mask_field_name)