# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AtomicData-native DataLoader with CUDA-stream prefetching.
The ``DataLoader`` class is designed to be a drop-in replacement
for ``torch.utils.data.DataLoader``, specialized for ``nvalchemi``
and atomistic systems by emitting ``Batch`` data.
Additionally, the ``DataLoader`` provides two mechanisms for
performant data loading: an asynchronous prefetching mechanism,
as well as the use of CUDA streams; both of which can be used
to develop high-performance data loading and preprocessing
workflows.
"""
from __future__ import annotations
from collections.abc import Iterator
import torch
from torch.utils.data import RandomSampler, Sampler, SequentialSampler
from nvalchemi.data.batch import Batch
from nvalchemi.data.datapipes.dataset import Dataset
class DataLoader:
    """Batch-iterating data loader that yields :class:`~nvalchemi.data.batch.Batch`.

    Wraps a :class:`Dataset` and yields ``Batch`` objects built via
    :meth:`Batch.from_data_list`. CUDA-stream prefetching is supported
    for overlapping I/O with computation.

    Parameters
    ----------
    dataset : Dataset
        AtomicData-native dataset to load from.
    batch_size : int, default=1
        Number of samples per batch.
    shuffle : bool, default=False
        Randomize sample order each epoch. Ignored when ``sampler`` is given.
    drop_last : bool, default=False
        Drop the last incomplete batch.
    sampler : torch.utils.data.Sampler | None, default=None
        Custom sampler (overrides ``shuffle``).
    prefetch_factor : int, default=2
        How many batches to prefetch ahead. Values ``<= 0`` disable
        prefetching.
    num_streams : int, default=4
        Number of CUDA streams for prefetching.
    use_streams : bool, default=True
        Enable CUDA-stream prefetching. Automatically disabled when CUDA
        is not available.

    Examples
    --------
    >>> from nvalchemi.data.datapipes import AtomicDataZarrReader, Dataset, DataLoader
    >>> reader = AtomicDataZarrReader("dataset.zarr")  # doctest: +SKIP
    >>> ds = Dataset(reader, device="cpu")  # doctest: +SKIP
    >>> loader = DataLoader(ds, batch_size=4)  # doctest: +SKIP
    >>> for batch in loader:  # doctest: +SKIP
    ...     print(batch.positions.shape)
    """

    def __init__(
        self,
        dataset: Dataset,
        *,
        batch_size: int = 1,
        shuffle: bool = False,
        drop_last: bool = False,
        sampler: Sampler | None = None,
        prefetch_factor: int = 2,
        num_streams: int = 4,
        use_streams: bool = True,
    ) -> None:
        """Initialize the AtomicData-native DataLoader.

        Parameters
        ----------
        dataset : Dataset
            AtomicData-native dataset to load from.
        batch_size : int, default=1
            Number of samples per batch.
        shuffle : bool, default=False
            Randomize sample order each epoch. Ignored when ``sampler``
            is given.
        drop_last : bool, default=False
            Drop the last incomplete batch.
        sampler : torch.utils.data.Sampler | None, default=None
            Custom sampler (overrides ``shuffle``).
        prefetch_factor : int, default=2
            How many batches to prefetch ahead.
        num_streams : int, default=4
            Number of CUDA streams for prefetching.
        use_streams : bool, default=True
            Enable CUDA-stream prefetching.

        Raises
        ------
        ValueError
            If batch_size < 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.prefetch_factor = prefetch_factor
        self.num_streams = num_streams
        # Short-circuit keeps torch.cuda untouched when streams are not requested.
        self.use_streams = use_streams and torch.cuda.is_available()

        # An explicit sampler wins over ``shuffle``.
        if sampler is not None:
            self.sampler = sampler
        elif shuffle:
            self.sampler = RandomSampler(dataset)
        else:
            self.sampler = SequentialSampler(dataset)

        # Streams are only created when stream prefetching is active.
        self._streams: list[torch.cuda.Stream] = []
        if self.use_streams:
            self._streams = [torch.cuda.Stream() for _ in range(num_streams)]

    def __len__(self) -> int:
        """Return the number of batches.

        The count is derived from the sampler, because the sampler (not
        the dataset) determines how many indices are iterated. When the
        sampler does not define ``__len__`` the dataset length is used.

        Returns
        -------
        int
            Number of batches in the dataloader.
        """
        try:
            n_samples = len(self.sampler)
        except TypeError:
            # Sampler without __len__: fall back to the dataset size.
            n_samples = len(self.dataset)
        if self.drop_last:
            return n_samples // self.batch_size
        # Ceiling division: a trailing partial batch counts as one batch.
        return (n_samples + self.batch_size - 1) // self.batch_size

    def __iter__(self) -> Iterator[Batch]:
        """Iterate over batches.

        Uses stream-based prefetching when it is enabled and at least one
        stream exists; otherwise falls back to synchronous iteration.

        Yields
        ------
        Batch
            Batched AtomicData as a disjoint graph.
        """
        # ``self._streams`` may be empty (e.g. num_streams=0); the prefetch
        # path needs at least one stream to round-robin over.
        if self.prefetch_factor > 0 and self.use_streams and self._streams:
            yield from self._iter_prefetch()
        else:
            yield from self._iter_simple()

    def _generate_batches(self) -> Iterator[list[int]]:
        """Generate batches of indices.

        Yields
        ------
        list[int]
            List of sample indices for each batch.
        """
        batch: list[int] = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # Emit the trailing partial batch unless the caller asked to drop it.
        if batch and not self.drop_last:
            yield batch

    def _collate(self, batch_indices: list[int]) -> Batch:
        """Load the samples for ``batch_indices`` and collate them into a Batch."""
        samples = [self.dataset[idx] for idx in batch_indices]
        # Each sample is an (AtomicData, metadata) tuple; only the data is batched.
        data_list = [atomic_data for atomic_data, _ in samples]
        return Batch.from_data_list(data_list, skip_validation=True)

    def _prefetch_batch(self, batch_indices: list[int], stream_idx: int) -> int:
        """Issue prefetches for one batch, round-robin over streams.

        Returns the stream counter advanced by the number of samples issued.
        """
        n_streams = len(self._streams)
        for sample_idx in batch_indices:
            stream = self._streams[stream_idx % n_streams]
            self.dataset.prefetch(sample_idx, stream=stream)
            stream_idx += 1
        return stream_idx

    def _iter_simple(self) -> Iterator[Batch]:
        """Simple synchronous iteration without prefetching.

        Yields
        ------
        Batch
            Collated batch of AtomicData.
        """
        for batch_indices in self._generate_batches():
            yield self._collate(batch_indices)

    def _iter_prefetch(self) -> Iterator[Batch]:
        """Iteration with stream-based prefetching.

        Strategy:

        1. Prefetch ``prefetch_factor`` batches worth of samples
        2. As we yield batches, prefetch more to keep the pipeline full
        3. Each sample in a batch uses a different stream for overlap

        ``dataset.cancel_prefetch()`` is always called once prefetching
        has started, including when the consumer stops early or loading
        raises.

        Yields
        ------
        Batch
            Collated batch of AtomicData.
        """
        # Collect all batches upfront for prefetch planning.
        all_batches = list(self._generate_batches())
        if not all_batches:
            return

        n_batches = len(all_batches)
        stream_idx = 0
        prefetched_up_to = 0
        try:
            # Warm-up: fill the pipeline before the first batch is consumed.
            warmup = min(self.prefetch_factor, n_batches)
            while prefetched_up_to < warmup:
                stream_idx = self._prefetch_batch(
                    all_batches[prefetched_up_to], stream_idx
                )
                prefetched_up_to += 1

            for batch_indices in all_batches:
                # Uses prefetched samples if available.
                batch = self._collate(batch_indices)
                # Top up the pipeline with the next not-yet-prefetched batch.
                if prefetched_up_to < n_batches:
                    stream_idx = self._prefetch_batch(
                        all_batches[prefetched_up_to], stream_idx
                    )
                    prefetched_up_to += 1
                yield batch
        finally:
            # Runs on normal exhaustion, early generator close, and errors.
            self.dataset.cancel_prefetch()

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for the sampler (used in distributed training).

        Does nothing when the sampler has no ``set_epoch`` method.

        Parameters
        ----------
        epoch : int
            Current epoch number.
        """
        if hasattr(self.sampler, "set_epoch"):
            self.sampler.set_epoch(epoch)