# Source code for nvalchemi.data.datapipes.backends.base

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract base class for datapipe readers."""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from collections.abc import Iterator
from typing import Any

import torch

logger = logging.getLogger(__name__)


class Reader(ABC):
    """Abstract base class for data readers.

    Readers are intentionally simple and transactional:

    - Load data from a source (file, database, etc.)
    - Return ``(dict[str, torch.Tensor], metadata_dict)`` tuples with CPU tensors
    - No threading, no prefetching, no device transfers

    Subclasses must implement :meth:`_load_sample` and :meth:`__len__`.

    Parameters
    ----------
    pin_memory : bool, default=False
        If True, pin loaded tensors to page-locked memory for faster async
        CPU→GPU transfers.
    include_index_in_metadata : bool, default=True
        If True, automatically add ``"index"`` to each sample's metadata dict.

    Examples
    --------
    >>> class MyReader(Reader):
    ...     def _load_sample(self, index: int) -> dict[str, torch.Tensor]:
    ...         return {"x": torch.randn(3)}
    ...     def __len__(self) -> int:
    ...         return 10
    >>> reader = MyReader()  # doctest: +SKIP
    >>> data, meta = reader[0]  # doctest: +SKIP
    """

    def __init__(
        self,
        *,
        pin_memory: bool = False,
        include_index_in_metadata: bool = True,
    ) -> None:
        """Initialize the Reader base class.

        Parameters
        ----------
        pin_memory : bool, default=False
            If True, pin loaded tensors to page-locked memory for faster
            async CPU→GPU transfers.
        include_index_in_metadata : bool, default=True
            If True, automatically add ``"index"`` to each sample's
            metadata dict.
        """
        self.pin_memory = pin_memory
        self.include_index_in_metadata = include_index_in_metadata

    @abstractmethod
    def _load_sample(self, index: int) -> dict[str, torch.Tensor]:
        """Load raw tensor data for a single sample.

        Parameters
        ----------
        index : int
            Sample index (0-based).

        Returns
        -------
        dict[str, torch.Tensor]
            Mapping of field names to CPU tensors.
        """
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Return the total number of available samples.

        Returns
        -------
        int
            Number of samples.
        """
        raise NotImplementedError

    def _get_field_names(self) -> list[str]:
        """Return field names by inspecting the first sample.

        Returns
        -------
        list[str]
            Field names from the first sample, or empty if reader is empty.
        """
        if len(self) == 0:
            return []
        # NOTE: this loads sample 0 on every call; subclasses with expensive
        # loads may want to override with a cheaper implementation.
        data = self._load_sample(0)
        return list(data.keys())

    def _get_sample_metadata(self, index: int) -> dict[str, Any]:
        """Return additional metadata for a sample.

        Override in subclasses to provide per-sample metadata such as
        source file paths or physical indices.

        Parameters
        ----------
        index : int
            Sample index.

        Returns
        -------
        dict[str, Any]
            Metadata dictionary. Empty by default.
        """
        return {}

    @property
    def field_names(self) -> list[str]:
        """Field names available in each sample.

        Returns
        -------
        list[str]
            Field names.
        """
        return self._get_field_names()

    def __getitem__(self, index: int) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
        """Load a sample and its metadata by index.

        Handles negative indexing, bounds checking, optional pin-memory,
        and automatic index injection into metadata.

        Parameters
        ----------
        index : int
            Sample index. Negative values are supported.

        Returns
        -------
        tuple[dict[str, torch.Tensor], dict[str, Any]]
            ``(data_dict, metadata)`` pair with CPU tensors.

        Raises
        ------
        IndexError
            If *index* is out of range.
        """
        # Hoist len(self): subclasses may compute length non-trivially
        # (e.g. querying a file or database), and it was previously
        # evaluated up to four times per lookup.
        num_samples = len(self)
        if index < 0:
            index = num_samples + index
        if index < 0 or index >= num_samples:
            raise IndexError(
                f"Index {index} out of range for reader with {num_samples} samples"
            )
        data_dict = self._load_sample(index)
        metadata = self._get_sample_metadata(index)
        if self.include_index_in_metadata:
            metadata["index"] = index
        if self.pin_memory:
            data_dict = {k: v.pin_memory() for k, v in data_dict.items()}
        return data_dict, metadata

    def __iter__(self) -> Iterator[tuple[dict[str, torch.Tensor], dict[str, Any]]]:
        """Iterate over all samples sequentially.

        Yields
        ------
        tuple[dict[str, torch.Tensor], dict[str, Any]]
            ``(data_dict, metadata)`` for each sample.

        Raises
        ------
        RuntimeError
            If any sample fails to load.
        """
        for i in range(len(self)):
            try:
                yield self[i]
            except Exception as e:
                error_msg = f"Sample {i} failed with exception: {type(e).__name__}: {e}"
                logger.error(error_msg)
                raise RuntimeError(error_msg) from e

    def close(self) -> None:
        """Release resources held by the reader.

        Override in subclasses to close file handles, connections, etc.
        """

    def __enter__(self) -> Reader:
        """Enter context manager.

        Returns
        -------
        Reader
            This reader instance.
        """
        return self

    def __exit__(
        self, exc_type: type | None, exc_val: BaseException | None, exc_tb: Any
    ) -> None:
        """Exit context manager, calling :meth:`close`.

        Parameters
        ----------
        exc_type : type | None
            Exception type, if any.
        exc_val : BaseException | None
            Exception value, if any.
        exc_tb : Any
            Exception traceback, if any.
        """
        self.close()

    def __repr__(self) -> str:
        """Return a string representation of the reader.

        Returns
        -------
        str
            Human-readable summary.
        """
        return (
            f"{self.__class__.__name__}(len={len(self)}, pin_memory={self.pin_memory})"
        )