# Source code for nvalchemi.data.datapipes.dataloader

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AtomicData-native DataLoader with CUDA-stream prefetching.

The ``DataLoader`` class is designed to be a drop-in replacement
for ``torch.utils.data.DataLoader``, specializing for ``nvalchemi``
and atomistic systems by emitting ``Batch`` data.

Additionally, the ``DataLoader`` provides two mechanisms for
performant data loading: an asynchronous prefetching mechanism
and the use of CUDA streams; both can be used to build highly
performant data loading and preprocessing workflows.
"""

from __future__ import annotations

from collections.abc import Iterator

import torch
from torch.utils.data import RandomSampler, Sampler, SequentialSampler

from nvalchemi.data.batch import Batch
from nvalchemi.data.datapipes.dataset import Dataset


class DataLoader:
    """Batch-iterating data loader that yields :class:`~nvalchemi.data.batch.Batch`.

    Wraps a :class:`Dataset` and yields ``Batch`` objects built via
    :meth:`Batch.from_data_list`. CUDA-stream prefetching is supported
    for overlapping I/O with computation.

    Parameters
    ----------
    dataset : Dataset
        AtomicData-native dataset to load from.
    batch_size : int, default=1
        Number of samples per batch.
    shuffle : bool, default=False
        Randomize sample order each epoch.
    drop_last : bool, default=False
        Drop the last incomplete batch.
    sampler : torch.utils.data.Sampler | None, default=None
        Custom sampler (overrides ``shuffle``).
    prefetch_factor : int, default=2
        How many batches to prefetch ahead.
    num_streams : int, default=4
        Number of CUDA streams for prefetching.
    use_streams : bool, default=True
        Enable CUDA-stream prefetching.

    Examples
    --------
    >>> from nvalchemi.data.datapipes import AtomicDataZarrReader, Dataset, DataLoader
    >>> reader = AtomicDataZarrReader("dataset.zarr")  # doctest: +SKIP
    >>> ds = Dataset(reader, device="cpu")  # doctest: +SKIP
    >>> loader = DataLoader(ds, batch_size=4)  # doctest: +SKIP
    >>> for batch in loader:  # doctest: +SKIP
    ...     print(batch.positions.shape)
    """

    def __init__(
        self,
        dataset: Dataset,
        *,
        batch_size: int = 1,
        shuffle: bool = False,
        drop_last: bool = False,
        sampler: Sampler | None = None,
        prefetch_factor: int = 2,
        num_streams: int = 4,
        use_streams: bool = True,
    ) -> None:
        """Initialize the AtomicData-native DataLoader.

        Parameters
        ----------
        dataset : Dataset
            AtomicData-native dataset to load from.
        batch_size : int, default=1
            Number of samples per batch.
        shuffle : bool, default=False
            Randomize sample order each epoch.
        drop_last : bool, default=False
            Drop the last incomplete batch.
        sampler : torch.utils.data.Sampler | None, default=None
            Custom sampler (overrides ``shuffle``).
        prefetch_factor : int, default=2
            How many batches to prefetch ahead.
        num_streams : int, default=4
            Number of CUDA streams for prefetching.
        use_streams : bool, default=True
            Enable CUDA-stream prefetching.

        Raises
        ------
        ValueError
            If batch_size < 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Set up attributes directly (standalone class)
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.prefetch_factor = prefetch_factor
        self.num_streams = num_streams
        # Streams are only usable when CUDA is actually present, regardless
        # of what the caller requested.
        self.use_streams = use_streams and torch.cuda.is_available()

        # Handle sampler: an explicit sampler takes precedence over ``shuffle``.
        if sampler is not None:
            self.sampler = sampler
        elif shuffle:
            self.sampler = RandomSampler(dataset)
        else:
            self.sampler = SequentialSampler(dataset)

        # Create CUDA streams for prefetching
        self._streams: list[torch.cuda.Stream] = []
        if self.use_streams:
            for _ in range(num_streams):
                self._streams.append(torch.cuda.Stream())

    def __len__(self) -> int:
        """Return the number of batches.

        Returns
        -------
        int
            Number of batches in the dataloader.
        """
        n_samples = len(self.dataset)
        if self.drop_last:
            return n_samples // self.batch_size
        # Ceiling division: the final partial batch counts as one batch.
        return (n_samples + self.batch_size - 1) // self.batch_size

    def __iter__(self) -> Iterator[Batch]:
        """Iterate over batches.

        Uses stream-based prefetching when enabled to overlap IO,
        GPU transfers, and computation.

        Yields
        ------
        Batch
            Batched AtomicData as a disjoint graph.
        """
        if self.prefetch_factor > 0 and self.use_streams:
            yield from self._iter_prefetch()
        else:
            yield from self._iter_simple()

    def _generate_batches(self) -> Iterator[list[int]]:
        """Generate batches of indices.

        Yields
        ------
        list[int]
            List of sample indices for each batch.
        """
        batch: list[int] = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # Emit the trailing partial batch unless the caller opted out.
        if batch and not self.drop_last:
            yield batch

    def _iter_simple(self) -> Iterator[Batch]:
        """Simple synchronous iteration without prefetching.

        Yields
        ------
        Batch
            Collated batch of AtomicData.
        """
        for batch_indices in self._generate_batches():
            samples = [self.dataset[idx] for idx in batch_indices]
            # Extract AtomicData from (AtomicData, metadata) tuples
            data_list = [atomic_data for atomic_data, _ in samples]
            batch = Batch.from_data_list(data_list, skip_validation=True)
            yield batch

    def _iter_prefetch(self) -> Iterator[Batch]:
        """Iteration with stream-based prefetching.

        Strategy:
        1. Prefetch ``prefetch_factor`` batches worth of samples
        2. As we yield batches, prefetch more to keep the pipeline full
        3. Each sample in a batch uses a different stream for overlap

        Yields
        ------
        Batch
            Collated batch of AtomicData.
        """
        # Collect all batches upfront for prefetch planning
        all_batches = list(self._generate_batches())
        if not all_batches:
            return

        num_prefetch_batches = min(self.prefetch_factor, len(all_batches))
        stream_idx = 0
        prefetched_up_to = 0
        try:
            # Start initial prefetch, round-robining samples over the streams.
            for batch_idx in range(num_prefetch_batches):
                for sample_idx in all_batches[batch_idx]:
                    stream = self._streams[stream_idx % self.num_streams]
                    self.dataset.prefetch(sample_idx, stream=stream)
                    stream_idx += 1
                prefetched_up_to = batch_idx + 1

            # Yield batches and prefetch more to keep the pipeline full.
            for batch_indices in all_batches:
                # Collect samples (uses prefetched if available)
                samples = [self.dataset[idx] for idx in batch_indices]
                # Extract AtomicData from (AtomicData, metadata) tuples
                data_list = [atomic_data for atomic_data, _ in samples]
                batch = Batch.from_data_list(data_list, skip_validation=True)

                # Prefetch the next not-yet-requested batch, if any remain.
                if prefetched_up_to < len(all_batches):
                    for sample_idx in all_batches[prefetched_up_to]:
                        stream = self._streams[stream_idx % self.num_streams]
                        self.dataset.prefetch(sample_idx, stream=stream)
                        stream_idx += 1
                    prefetched_up_to += 1

                yield batch
        finally:
            # Clean up prefetch state even when the consumer stops iterating
            # early (generator close) or an exception propagates — the
            # original only cleaned up on normal exhaustion, leaking
            # in-flight prefetch work otherwise.
            self.dataset.cancel_prefetch()

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for the sampler (used in distributed training).

        Parameters
        ----------
        epoch : int
            Current epoch number.
        """
        # DistributedSampler and friends expose set_epoch; plain samplers
        # don't, so this is deliberately a best-effort no-op for them.
        if hasattr(self.sampler, "set_epoch"):
            self.sampler.set_epoch(epoch)