# Source code for accvlab.dali_pipeline_framework.inputs.sampler_input_callable
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from nvidia.dali import types
from ..pipeline import SampleDataGroup
from .callable_base import CallableBase
from .sampler_base import SamplerBase
from .data_provider import DataProvider
try:
from typing import override
except ImportError:
from typing_extensions import override
class SamplerInputCallable(CallableBase):
    '''Input callable using a sampler to provide data according to the sampler (also see
    :class:`~accvlab.dali_pipeline_framework.inputs.SamplerBase`).

    This callable also handles indicating the end of an epoch (by raising :class:`StopIteration`).
    Information on when an epoch ends is obtained from the sampler (which in turn should indicate this
    by raising :class:`StopIteration`, see documentation of :class:`SamplerBase`).

    As the sampler can have an internal state (while the callable is expected to be stateless), a look-up
    table is pre-generated at construction, leading to overhead and the need to know the maximum number of
    iterations in advance.

    Note:
        To avoid the overhead of pre-generating the look-up table, it is recommended to only use this class
        if a single process for data loading is not enough and prefer
        :class:`~accvlab.dali_pipeline_framework.inputs.SamplerInputIterable` in general.
    '''

    def __init__(
        self,
        data_provider: DataProvider,
        sampler: SamplerBase,
        max_num_iterations: int,
        pre_fetch_queue_length: int,
        shard_id: int = 0,
        num_shards: int = 1,
    ):
        '''
        Args:
            data_provider: Data provider to use (following the interface defined in :class:`DataProvider`).
            sampler: Sampler to use (following the interface defined in :class:`SamplerBase`).
            max_num_iterations: Maximum number of iterations that will be performed.
            pre_fetch_queue_length: Length of the pre-fetch queue depth of the DALI pipeline using this
                input callable. Used together with ``max_num_iterations`` to ensure that the sampling
                look-up table is large enough.
            shard_id: Shard ID (default value of 0 should be used if sharding is not used)
            num_shards: Total of shards (default value of 1 should be used if sharding is not used)

        Raises:
            ValueError: If fewer than one batch is requested overall, or if the total batch size is not
                divisible by ``num_shards``.
            RuntimeError: If the sampler raises :class:`StopIteration` without providing a single batch
                (which would otherwise make look-up table generation loop forever).
        '''
        self._data_provider = data_provider
        self._shard_id = shard_id
        self._num_shards = num_shards
        self._max_num_iterations = max_num_iterations
        self._pre_fetch_queue_length = pre_fetch_queue_length
        self._max_num_iterations_inc_queue = max_num_iterations + pre_fetch_queue_length
        if self._max_num_iterations_inc_queue <= 0:
            # Without at least one batch the first-epoch look-up below would fail with an
            # opaque IndexError; fail fast with a clear message instead.
            raise ValueError(
                f"max_num_iterations ({max_num_iterations}) plus pre_fetch_queue_length "
                f"({pre_fetch_queue_length}) must be at least 1."
            )
        # Pre-generate the sampling look-up table: a list of epochs, each epoch being a list of
        # batches of sample indices. The callable itself must stay stateless, so all sampler
        # state is consumed here, at construction time.
        self._look_up_table = []
        i = 0
        curr_epoch_look_up_table = []
        while i < self._max_num_iterations_inc_queue:
            try:
                batch = sampler.get_next_batch_indices()
                curr_epoch_look_up_table.append(batch)
                i += 1
            except StopIteration:
                if not curr_epoch_look_up_table:
                    # The sampler signalled end-of-epoch without yielding any batch. Since the
                    # except-branch does not advance ``i``, continuing would loop forever.
                    raise RuntimeError(
                        "Sampler raised StopIteration without providing any batch; "
                        "cannot generate the sampling look-up table."
                    )
                self._look_up_table.append(curr_epoch_look_up_table)
                curr_epoch_look_up_table = []
                sampler.reset()
        # The last (possibly partial) epoch is cut off at the total batch budget.
        self._look_up_table.append(curr_epoch_look_up_table)
        # The sampler provides full (global) batches; each shard serves an equal slice of them.
        self._total_batch_size = len(self._look_up_table[0][0])
        self._local_batch_size = self._total_batch_size // num_shards
        if self._local_batch_size * self._num_shards != self._total_batch_size:
            # Plain raise instead of assert: this is input validation and must not be
            # stripped when running with ``python -O``.
            raise ValueError(
                f"Total batch size ({self._total_batch_size}) not divisible by number of shards "
                f"({self._num_shards})."
            )

    @property
    @override
    def used_sample_data_structure(self) -> SampleDataGroup:
        '''Data format blueprint used for the individual samples.

        The blueprint is obtained from the data provider with mapping application disabled.
        NOTE(review): ``set_apply_mapping`` is called on the object returned by the provider —
        presumably this mutates shared provider state; confirm against :class:`DataProvider`.
        '''
        res = self._data_provider.sample_data_structure
        res.set_apply_mapping(False)
        return res

    @override
    def __call__(self, sample_info: types.SampleInfo) -> tuple:
        '''Return the data of one sample for the position described by ``sample_info``.

        Raises:
            StopIteration: At the end of an epoch.
            RuntimeError: If DALI requests a batch beyond the pre-generated look-up table.
        '''
        epoch_idx = sample_info.epoch_idx
        batch_idx = sample_info.idx_in_epoch // self._local_batch_size
        idx_in_local_batch = sample_info.idx_in_batch
        if epoch_idx >= len(self._look_up_table):
            raise RuntimeError(
                f"Maximum iteration number or pre-fetch queue length exceeded. SamplerInputCallable can only be used for the maximum number of iterations defined at construction\n"
                + f"({self._max_num_iterations} in this case) and a maximum pre-fetch queue depth defined at construction ({self._pre_fetch_queue_length} in this case). These two values\n"
                + f"define the total maximum number of batches which the callable can provide ({self._max_num_iterations_inc_queue} in this case)."
            )
        epoch_size = len(self._look_up_table[epoch_idx])
        if batch_idx >= epoch_size:
            # Signals end-of-epoch to DALI.
            raise StopIteration
        batch_of_indices = self._look_up_table[epoch_idx][batch_idx]
        # Each shard serves a contiguous slice of the full (global) batch.
        idx_in_full_batch = idx_in_local_batch + self._shard_id * self._local_batch_size
        index_to_use = batch_of_indices[idx_in_full_batch]
        sample_data = self._data_provider.get_data(index_to_use)
        return sample_data.get_data()

    @property
    @override
    def length(self) -> Optional[int]:
        '''Number of batches in one epoch.

        If the underlying sampler is not epoch-based, the length is the overall number of batches
        that can be generated (i.e. the maximum number of iterations defined at construction plus
        the pre-fetch queue length).
        '''
        return len(self._look_up_table[0])