Source code for accvlab.dali_pipeline_framework.inputs.sampler_input_iterable

# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Sequence, Optional

from nvidia.dali import types

from ..pipeline import SampleDataGroup

from .iterable_base import IterableBase
from .sampler_base import SamplerBase
from .data_provider import DataProvider

try:
    from typing import override
except ImportError:
    from typing_extensions import override


class SamplerInputIterable(IterableBase):
    '''Input iterable using a sampler to provide batches according to the sampler (also see
    :class:`SamplerBase`).

    The iterable can be used with a parallel external source. However, in this case, the data
    reading is performed in one worker process due to the serial nature of an iterable. This
    means that while the data reading is asynchronous to the main thread, it is not further
    parallelized.

    This iterable also handles indicating the end of an epoch (by raising
    :class:`StopIteration`). Information on when an epoch ends is obtained from the sampler
    (which in turn should indicate this by raising :class:`StopIteration`, see the
    documentation of :class:`SamplerBase`). After the end of the epoch, the iterable needs to
    be reset (by obtaining a new iterator) before data for the next epoch can be obtained.

    Note:
        If further parallelization is desired (i.e., more than one worker process),
        :class:`SamplerInputCallable` can be used instead of this class (at the cost of
        pre-computing look-up tables in advance, see the corresponding note in the
        documentation of :class:`SamplerInputCallable`).
    '''

    def __init__(
        self,
        data_provider: DataProvider,
        sampler: SamplerBase,
        shard_id: int = 0,
        num_shards: int = 1,
    ):
        '''
        Args:
            data_provider: Data provider to use (following the interface defined in
                :class:`DataProvider`).
            sampler: Sampler to use (following the interface defined in :class:`SamplerBase`).
            shard_id: Shard ID (the default value of 0 should be used if sharding is not used).
            num_shards: Total number of shards (the default value of 1 should be used if
                sharding is not used).
        '''
        self._data_provider = data_provider
        self._shard_id = shard_id
        self._num_shards = num_shards
        self._sampler = sampler
        self._local_batch_size = None
        self._total_batch_size = None
        self._epoch = 0
        self._before_first_iter_called = True
        self._sharding_set_up = False

    @property
    @override
    def used_sample_data_structure(self) -> SampleDataGroup:
        '''Data format blueprint used for the individual samples.'''
        res = self._data_provider.sample_data_structure
        res.set_apply_mapping(False)
        return res

    @override
    def __iter__(self) -> 'SamplerInputIterable':
        # The first __iter__ call must not reset the sampler; subsequent calls start a new
        # epoch for epoch-based samplers.
        if self._before_first_iter_called:
            self._before_first_iter_called = False
        elif self._sampler.is_epoch_based:
            self._sampler.reset()
        return self

    @override
    def __next__(self) -> tuple:
        batch_indices = self._sampler.get_next_batch_indices()
        if not self._sharding_set_up:
            # Derive the shard-local batch size from the first batch produced by the sampler.
            self._total_batch_size = len(batch_indices)
            self._local_batch_size = self._total_batch_size // self._num_shards
            assert (
                self._local_batch_size * self._num_shards == self._total_batch_size
            ), "Total batch size is not divisible by the number of used shards."
            self._sharding_set_up = True
        # Select this shard's contiguous slice of the global batch.
        min_index_in_total_batch = self._shard_id * self._local_batch_size
        indices_to_use = batch_indices[
            min_index_in_total_batch : min_index_in_total_batch + self._local_batch_size
        ]
        sample_data = [self._data_provider.get_data(idx) for idx in indices_to_use]
        batch_res_data = [sd.get_data() for sd in sample_data]
        return self._combine_res_to_batch(batch_res_data)

    @property
    @override
    def length(self) -> Optional[int]:
        '''Number of batches in one epoch.

        If the underlying sampler is not epoch-based, ``None`` is returned.
        '''
        return self._sampler.length

    @staticmethod
    def _combine_res_to_batch(
        per_sample_res: Sequence[Sequence[Any]],
    ) -> Sequence[Sequence[Any]]:
        '''Combine per-sample results into a batch format.

        Transposes the data structure from ``per_sample_res[sample_idx][field_idx]`` to
        ``output[field_idx][sample_idx]`` to match the expected batch format for DALI. Here,
        ``fields`` are the individual data fields of the sample data structure.

        Args:
            per_sample_res: List of sample results, where each sample is a list of field
                values.

        Returns:
            List of field results, where each field is a list of sample values.
        '''
        num_samples = len(per_sample_res)
        num_fields = len(per_sample_res[0])
        res = [None] * num_fields
        for f in range(num_fields):
            res[f] = [per_sample_res[s][f] for s in range(num_samples)]
        return res
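
For illustration, a minimal sketch of the transpose performed by ``_combine_res_to_batch``
(the sample values below are made up; any fixed number of fields behaves the same way):

# Two hypothetical samples, each with two fields (e.g. an image and a label).
per_sample = [["img0", 0], ["img1", 1]]
batched = SamplerInputIterable._combine_res_to_batch(per_sample)
# batched == [["img0", "img1"], [0, 1]]
#            field 0 for all samples, then field 1 for all samples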
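
And a minimal usage sketch wiring the iterable into a DALI pipeline through a parallel
external source. ``MyProvider`` and ``MySampler`` are hypothetical concrete implementations
of :class:`DataProvider` and :class:`SamplerBase`, and the two-field sample structure is
likewise an assumption:

from nvidia.dali import fn, pipeline_def

iterable = SamplerInputIterable(
    data_provider=MyProvider(),  # hypothetical DataProvider implementation
    sampler=MySampler(),         # hypothetical SamplerBase implementation
    shard_id=0,
    num_shards=1,
)

# batch_size must match the (shard-local) batch size produced by the sampler.
@pipeline_def(batch_size=8, num_threads=2, device_id=0,
              py_num_workers=1, py_start_method="spawn")
def my_pipeline():
    # One output per data field; parallel=True moves the data reading into a
    # worker process (a single one, as explained in the class docstring).
    images, labels = fn.external_source(
        source=iterable, num_outputs=2, batch=True, parallel=True
    )
    return images, labels

pipe = my_pipeline()
pipe.build()
images, labels = pipe.run()  # raises StopIteration at the end of an epoch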