Source code for accvlab.dali_pipeline_framework.pipeline.dali_structured_output_iterator

# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Type, Union, Any, Callable, Optional

from collections.abc import Iterator as ABCIterator
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator as PyTorchDALIGenericIterator

from torch.utils.data import DataLoader

from .sample_data_group import SampleDataGroup

try:
    from typing import override
except ImportError:
    from typing_extensions import override


class DALIStructuredOutputIterator(object):
    '''Iterable wrapper which exposes the output of a DALI pipeline in a structured form.

    Each element is delivered either as a nested :class:`dict` or as a :class:`SampleDataGroup`.
    The class is designed as a drop-in replacement for a :class:`torch.utils.data.DataLoader`.

    A user-defined, lightweight post-processing function can optionally be applied to each
    element (e.g., for conversions to types not supported by DALI).
    '''

    class SimpleIterator(ABCIterator):
        '''Iterator over the elements of a :class:`DALIStructuredOutputIterator`.

        It can e.g. be used as a drop-in replacement for a PyTorch DataLoader iterator.

        Note:
            Only a single iterator should be in use at any point in time. All iterators obtained
            from the same parent object share its state: Creating a new iterator resets all other
            iterators, and calling ``next`` on one iterator advances all iterators by one element.
        '''

        def __init__(self, obj: DALIStructuredOutputIterator):
            '''
            Args:
                obj: Parent object which provides the data and holds the iteration state.
            '''
            self._obj = obj
            # A freshly created iterator always starts the iteration from the beginning.
            self.reset()

        @override
        def __next__(self):
            '''Get the next element (delegates to the parent object).'''
            return self._obj._next()

        @override
        def __iter__(self):
            '''Get the iterator (i.e. return ``self`` without re-starting the iteration).'''
            return self

        def reset(self):
            '''Reset the iterator.

            This calls :meth:`DALIStructuredOutputIterator.reset` of the parent object.
            '''
            self._obj.reset()

        def __len__(self):
            '''Get the number of elements (as reported by the parent object).'''
            return len(self._obj)

    def __init__(
        self,
        num_batches_in_epoch: int,
        pipeline: Pipeline,
        sample_data_structure_blueprint: SampleDataGroup,
        contained_dataset: Optional[Any] = None,
        dali_generic_iterator_class: Union[
            Type[PyTorchDALIGenericIterator], Any
        ] = PyTorchDALIGenericIterator,
        convert_sample_data_group_to_dict: bool = True,
        post_process_func: Optional[
            Callable[[Union[SampleDataGroup, dict]], Union[SampleDataGroup, dict]]
        ] = None,
    ):
        '''
        Args:
            num_batches_in_epoch: Number of batches in an epoch. This value is only reported
                when ``len(obj)`` is called. It is not used internally and exists to ensure
                drop-in compatibility with :class:`torch.utils.data.DataLoader`.
            pipeline: DALI pipeline object.
            sample_data_structure_blueprint: Blueprint for the output data structure.
            contained_dataset: Dataset object which will be exposed via :attr:`dataset`
                (mirrors the PyTorch ``DataLoader`` behavior). Can be a PyTorch ``Dataset`` or
                any other compatible object. This object is not used internally. Also see
                :attr:`dataset`.
            dali_generic_iterator_class: Class of the internal DALI generic iterator. It has to
                follow the :class:`PyTorchDALIGenericIterator` interface, but may emit tensors
                for other frameworks. Defaults to :class:`PyTorchDALIGenericIterator`.
            convert_sample_data_group_to_dict: If ``True``, the output
                :class:`SampleDataGroup` is converted to a nested :class:`dict`. This ensures
                drop-in compatibility with ``DataLoader`` when no post-processing function is
                provided.
            post_process_func: Optional post-processing function for the output. It can e.g.
                be used to convert data to types not supported by DALI, or to perform other
                lightweight steps. Its input is a :class:`SampleDataGroup` if
                ``convert_sample_data_group_to_dict == False`` and a :class:`dict` otherwise.
                The function is executed when the data is accessed, in the thread accessing
                the data (typically the thread performing the training). It should therefore
                be kept lightweight to avoid performance penalties.
        '''
        self._num_batches_in_epoch = num_batches_in_epoch
        self._contained_dataset = contained_dataset
        self._convert_sample_data_group_to_dict = convert_sample_data_group_to_dict
        self._post_process_func = post_process_func

        # The flattened field names of the blueprint are handed to the DALI iterator as the
        # second constructor argument.
        output_names = list(sample_data_structure_blueprint.field_names_flat)
        self._iterator = dali_generic_iterator_class(pipeline, output_names)

        # Store the result of ``get_empty_like_self()`` instead of the passed object itself.
        self._blueprint = sample_data_structure_blueprint.get_empty_like_self()

    def __iter__(self) -> DALIStructuredOutputIterator.SimpleIterator:
        '''Get an iterator.

        Note:
            Only a single iterator should be in use at any point in time. All iterators share
            the state of this object: Getting a new iterator resets all other iterators, and
            calling ``next`` on one iterator advances all iterators by one element.
        '''
        return self.SimpleIterator(self)

    def _next(self) -> Union[SampleDataGroup, dict]:
        '''Fetch the next output of the internal iterator and bring it into the structured form.

        Returns:
            The structured data, converted to a dictionary and/or post-processed according to
            the configuration set in the constructor.
        '''
        raw_output = self._iterator.__next__()

        # Fill a fresh, empty structure for every element.
        # NOTE(review): the second argument (0) presumably selects the first entry of the
        # iterator output -- confirm against ``SampleDataGroup``.
        structured = self._blueprint.get_empty_like_self()
        structured.set_data_from_dali_generic_iterator_output(raw_output, 0)

        # The dictionary conversion happens before the post-processing, so the post-processing
        # function receives a dict whenever the conversion is enabled.
        if self._convert_sample_data_group_to_dict:
            structured = structured.to_dictionary()

        post_process = self._post_process_func
        if post_process is None:
            return structured
        return post_process(structured)

    def reset(self):
        '''Reset the current iteration progress (start over from the beginning).

        Note that this resets the iterators obtained from this object as well.
        '''
        self._iterator.reset()

    @property
    def sample_data_structure_blueprint(self) -> SampleDataGroup:
        '''Get the output data structure blueprint.

        The blueprint is a :class:`SampleDataGroup` representing the same nested data format as
        the output, without the actual data. See :class:`SampleDataGroup` for details.
        '''
        return self._blueprint.get_empty_like_self()

    @property
    def internal_iterator(self) -> Union[PyTorchDALIGenericIterator, Any]:
        '''Get the actual DALI iterator used internally to access the output data.

        By default, this is a :class:`nvidia.dali.plugin.pytorch.DALIGenericIterator`. If a
        different iterator class was passed to the constructor, the returned object is of that
        type instead.
        '''
        return self._iterator

    @property
    def dataset(self) -> Any:
        '''Get the dataset object.

        This is the dataset object set in the constructor (if any). If none was set, the object
        for which the property is accessed is returned instead. This property exists for
        compatibility with :class:`torch.utils.data.DataLoader`.
        '''
        contained = self._contained_dataset
        return self if contained is None else contained

    def __len__(self):
        '''Number of available batches.

        Important:
            This value is set manually in the constructor and is only reported here. It is
            provided because it is part of the :class:`torch.utils.data.DataLoader` interface.
            The value may not be the actual number of batches in the epoch, e.g. for
            non-epoch-based pipelines.
        '''
        return self._num_batches_in_epoch

    @classmethod
    def CreateAsDataLoaderObject(cls, *args, **kwargs):
        '''Create an instance using the type returned by ``get_masked_as_type(cls, DataLoader)``.

        Args:
            *args: Positional arguments, forwarded unchanged to the constructor.
            **kwargs: Keyword arguments, forwarded unchanged to the constructor.

        Returns:
            The newly constructed object.
        '''
        # Imported locally, as in the original implementation.
        from .dali_structured_output_iterator_data_loader_wrapper import get_masked_as_type

        # NOTE(review): presumably the resulting type makes the object pass as a
        # ``torch.utils.data.DataLoader`` (e.g. in ``isinstance`` checks) -- confirm against
        # ``get_masked_as_type``.
        masked_type = get_masked_as_type(cls, DataLoader)
        return masked_type(*args, **kwargs)