Source code for nvalchemi.dynamics.sinks

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data sink abstractions for storing and retrieving batched atomic data.

This module provides storage backends for Batch data used in dynamics
simulations. Implementations include GPU buffers, CPU memory, and
disk-backed Zarr storage.
"""

from __future__ import annotations

import shutil
from abc import ABC, abstractmethod
from pathlib import Path

import torch
from torch import distributed as dist

from nvalchemi.data import AtomicData, Batch
from nvalchemi.data.datapipes.backends.zarr import (
    AtomicDataZarrReader,
    AtomicDataZarrWriter,
    StoreLike,
)


class DataSink(ABC):
    """
    Abstract base class for local storage of Batch data.

    DataSink provides a unified interface for storing and retrieving
    batched atomic data. Implementations can target different storage
    backends such as GPU memory, CPU memory, or disk.

    Attributes
    ----------
    capacity : int
        Maximum number of samples that can be stored.

    Methods
    -------
    write(batch)
        Store a batch of data.
    read()
        Retrieve all stored data as a Batch.
    drain()
        Read all stored data and clear the sink.
    zero()
        Clear all stored data.
    __len__()
        Return the number of samples currently stored.

    Examples
    --------
    >>> sink = HostMemory(capacity=100)
    >>> sink.write(batch)
    >>> len(sink)
    2
    >>> retrieved = sink.read()
    """

    @abstractmethod
    def write(self, batch: Batch, mask: torch.Tensor | None = None) -> None:
        """
        Store a batch of atomic data.

        Parameters
        ----------
        batch : Batch
            The batch of atomic data to store.
        mask : torch.Tensor | None, optional
            Boolean tensor of shape ``(batch.num_graphs,)`` indicating which
            samples to copy (``True`` = copy). If ``None``, all samples are
            copied. Default is ``None``.

        Raises
        ------
        RuntimeError
            If the buffer is full and cannot accept more data.
        """
        ...

    @abstractmethod
    def read(self) -> Batch:
        """
        Retrieve all stored data as a single Batch.

        Returns
        -------
        Batch
            A batch containing all stored atomic data.

        Raises
        ------
        RuntimeError
            If no data has been stored (buffer is empty).
        """
        ...

    @abstractmethod
    def zero(self) -> None:
        """
        Clear all stored data and reset the buffer.

        After calling this method, `len(self)` returns 0.
        """
        ...

    @abstractmethod
    def __len__(self) -> int:
        """
        Return the number of samples currently stored.

        Returns
        -------
        int
            Number of atomic data samples in the buffer.
        """
        ...

    @property
    @abstractmethod
    def capacity(self) -> int:
        """
        Return the maximum storage capacity.

        Returns
        -------
        int
            Maximum number of samples that can be stored.
        """
        ...

    @property
    def is_full(self) -> bool:
        """
        Check if the buffer has reached capacity.

        Returns
        -------
        bool
            True if the buffer is at or over capacity, False otherwise.
        """
        return len(self) >= self.capacity

    def drain(self) -> Batch:
        """
        Read all stored samples and clear the sink.

        This is equivalent to calling :meth:`read` followed by :meth:`zero`,
        but subclasses may override for a more efficient atomic operation.

        Returns
        -------
        Batch
            All samples that were stored in the sink.

        Raises
        ------
        RuntimeError
            If the sink is empty.
        """
        batch = self.read()
        self.zero()
        return batch

    @property
    def local_rank(self) -> int:
        """
        Return the node-local rank of this data sink.

        Returns
        -------
        int
            The value of ``torch.distributed.get_node_local_rank()`` when a
            process group is initialized, otherwise ``0``.
        """
        if dist.is_available() and dist.is_initialized():
            return dist.get_node_local_rank()
        return 0

    @property
    def global_rank(self) -> int:
        """
        Return the global rank of this data sink.

        Returns
        -------
        int
            The rank of the current process in the default process group
            when ``torch.distributed`` is initialized, otherwise ``0``.
        """
        if dist.is_available() and dist.is_initialized():
            # ``dist.get_global_rank(group, group_rank)`` translates a
            # group-local rank and cannot be called without arguments;
            # ``dist.get_rank()`` is the call that reports this process's
            # rank in the default (world) group.
            return dist.get_rank()
        return 0
class GPUBuffer(DataSink):
    """GPU-resident buffer for storing batched atomic data.

    This buffer lazily pre-allocates a :class:`Batch` with fixed maximum
    sizes for atoms and edges on the first :meth:`write` call. The incoming
    batch serves as a template for attribute keys and dtypes, ensuring all
    fields are preserved (not just positions and atomic_numbers).

    Parameters
    ----------
    capacity : int
        Maximum number of samples (graphs) to store.
    max_atoms : int
        Maximum number of atoms per sample.
    max_edges : int
        Maximum number of edges per sample.
    device : torch.device | str, optional
        CUDA device to store data on. Default is "cuda".

    Attributes
    ----------
    capacity : int
        Maximum storage capacity.
    device : torch.device
        Target CUDA device for stored tensors.

    Examples
    --------
    >>> buffer = GPUBuffer(capacity=100, max_atoms=50, max_edges=200, device="cuda:0")
    >>> buffer.write(batch)
    >>> len(buffer)
    2
    >>> retrieved = buffer.read()
    """

    def __init__(
        self,
        capacity: int,
        max_atoms: int,
        max_edges: int,
        device: torch.device | str = "cuda",
    ) -> None:
        """Initialize the GPU buffer.

        Parameters
        ----------
        capacity : int
            Maximum number of samples (graphs) to store.
        max_atoms : int
            Maximum number of atoms per sample.
        max_edges : int
            Maximum number of edges per sample.
        device : torch.device | str, optional
            CUDA device to store data on. Default is "cuda".

        Raises
        ------
        RuntimeError
            If CUDA is not available or a non-CUDA device is specified.
        """
        if not torch.cuda.is_available():
            raise RuntimeError(
                "GPUBuffer requires available CUDA devices:"
                f" found CUDA available: {torch.cuda.is_available()}"
                f" with device count={torch.cuda.device_count()}"
            )
        if isinstance(device, str) and "cuda" not in device:
            raise RuntimeError(f"GPUBuffer requires a CUDA device, got: '{device}'")
        if isinstance(device, torch.device) and "cuda" not in device.type:
            raise RuntimeError(
                f"GPUBuffer requires a CUDA device, got: '{device.type}'"
            )
        self._capacity = capacity
        self._max_atoms = max_atoms
        self._max_edges = max_edges
        self._device = torch.device(device) if isinstance(device, str) else device
        self._buffer: Batch | None = None
        # _copied_mask is allocated fresh on each write (per-write output mask)
        self._copied_mask: torch.Tensor | None = None
        # Pre-allocated dest_mask tracks which buffer slots are occupied (capacity-sized)
        self._dest_mask: torch.Tensor | None = None

    def _ensure_buffer(self, template: Batch) -> None:
        """Create the internal Batch buffer on first use.

        Allocates a pre-sized buffer using :meth:`Batch.empty` with capacity
        derived from constructor parameters. The template batch provides
        attribute keys and dtypes.

        Parameters
        ----------
        template : Batch
            A concrete batch to derive attribute keys and dtypes from.
        """
        if self._buffer is not None:
            return
        self._buffer = Batch.empty(
            num_systems=self._capacity,
            num_nodes=self._capacity * self._max_atoms,
            num_edges=self._capacity * self._max_edges,
            template=template,
            device=self._device,
        )
        # Pre-allocate dest_mask for system-level occupancy tracking
        self._dest_mask = torch.zeros(
            self._capacity, dtype=torch.bool, device=self._device
        )
        # Trigger lazy init of _batch_ptr for all groups so zero() can preserve it
        for group in self._buffer._storage.groups.values():
            if hasattr(group, "_lazy_init_batch_ptr"):
                group._lazy_init_batch_ptr()
        # Extend _batch_ptr to capacity + 2 (Batch.empty uses capacity + 1,
        # but compute_put_per_system_fit_mask requires capacity + 2)
        self._restore_batch_ptr_capacity()

    def _restore_batch_ptr_capacity(self) -> None:
        """Restore ``_batch_ptr`` to full pre-allocated capacity after put.

        :meth:`SegmentedLevelStorage.put` trims ``_batch_ptr`` to the number
        of active segments via ``.clone()``, which destroys the headroom
        needed for subsequent appends. This method re-extends each segmented
        group's ``_batch_ptr`` to ``capacity + 2`` while preserving the
        meaningful prefix written by the Warp kernels.

        The ``+2`` accounts for the formula used by
        :meth:`SegmentedLevelStorage.compute_put_per_system_fit_mask`: it
        requires ``num_dest_segments + n_seg + 2`` entries to accommodate
        segment boundaries plus safety margin.
        """
        if self._buffer is None:
            return
        # Need capacity + 2 to satisfy compute_put_per_system_fit_mask formula
        batch_ptr_cap = self._capacity + 2
        for group in self._buffer._storage.groups.values():
            if not hasattr(group, "segment_lengths"):
                continue  # skip UniformLevelStorage
            bp = group._batch_ptr
            if bp is not None and bp.shape[0] < batch_ptr_cap:
                # Zero-padded copy: the existing prefix is kept verbatim.
                new_bp = torch.zeros(batch_ptr_cap, dtype=bp.dtype, device=bp.device)
                new_bp[: bp.shape[0]] = bp
                group._batch_ptr = new_bp

    def write(self, batch: Batch, mask: torch.Tensor | None = None) -> None:
        """Store atomic data into the buffer.

        When *mask* is provided, only samples where ``mask[i]`` is ``True``
        are copied into the buffer. When *mask* is ``None``, all samples in
        *batch* are copied.

        Uses :meth:`Batch.put` for efficient in-place copying without tensor
        allocation. This method will set values for ``_copied_mask`` and
        ``_dest_mask``.

        Parameters
        ----------
        batch : Batch
            The source batch of atomic data.
        mask : torch.Tensor | None, optional
            Boolean tensor of shape ``(batch.num_graphs,)`` indicating which
            samples to copy (``True`` = copy). If ``None``, all samples are
            copied.

        Raises
        ------
        RuntimeError
            If adding the selected samples would exceed capacity, or if a
            system exceeds the configured max_atoms or max_edges limits.
        ValueError
            If mask length does not match batch.num_graphs.
        """
        num_total = batch.num_graphs or 0
        if num_total == 0:
            return

        # Build mask if not provided
        if mask is None:
            mask = torch.ones(num_total, dtype=torch.bool, device=batch.device)
        else:
            mask = mask.to(device=batch.device, dtype=torch.bool)
            if mask.shape[0] != num_total:
                raise ValueError(
                    f"mask length {mask.shape[0]} != num_graphs {num_total}"
                )

        # Ensure buffer is allocated with full capacity (lazy init on first write)
        self._ensure_buffer(template=batch)

        # Count how many graphs we're trying to write
        num_to_write = int(mask.sum().item())
        if num_to_write == 0:
            return

        # Validate graph capacity
        current_count = len(self)
        if current_count + num_to_write > self._capacity:
            raise RuntimeError(
                f"Buffer is full. Cannot add {num_to_write} samples to buffer "
                f"with {current_count}/{self._capacity} samples."
            )

        # Validate atom capacity for masked graphs
        nodes_per_graph = batch.num_nodes_per_graph
        max_atoms_in_batch = int(nodes_per_graph[mask].max().item())
        if max_atoms_in_batch > self._max_atoms:
            raise RuntimeError(
                f"Atom capacity exceeded: a system has {max_atoms_in_batch} atoms "
                f"but buffer max_atoms={self._max_atoms}"
            )

        # Validate edge capacity for masked graphs (only if edges exist in batch)
        if self._max_edges > 0 and batch.num_edges > 0:
            edges_per_graph = batch.num_edges_per_graph
            max_edges_in_batch = int(edges_per_graph[mask].max().item())
            if max_edges_in_batch > self._max_edges:
                raise RuntimeError(
                    f"Edge capacity exceeded: a system has {max_edges_in_batch} edges "
                    f"but buffer max_edges={self._max_edges}"
                )

        # Allocate fresh per-write output mask indicating which src graphs were placed
        self._copied_mask = torch.zeros(
            num_total, dtype=torch.bool, device=self._device
        )
        self._buffer.put(
            batch,
            mask,
            copied_mask=self._copied_mask,
            dest_mask=self._dest_mask,
        )
        # Restore _batch_ptr capacity after put() trims it
        self._restore_batch_ptr_capacity()

    def read(self) -> Batch:
        """Retrieve stored (non-padding) data as a single Batch.

        The pre-allocated buffer may have more capacity than stored samples.
        This method extracts only the filled graphs, excluding zero-padded
        slots.

        Returns
        -------
        Batch
            A batch containing the stored atomic data (no padding).

        Raises
        ------
        RuntimeError
            If the buffer is empty.
        """
        if self._buffer is None or len(self) == 0:
            raise RuntimeError("Cannot read from empty buffer.")

        if len(self) == self._capacity:
            # Buffer is full — return clone of entire buffer
            return self._buffer.clone()

        # Cast int32 batch_ptr → int64 for index_select compatibility.
        # Warp stores batch_ptr as int32; PyTorch index_select expects int64.
        # Save originals and restore afterwards to keep Warp kernel compatibility.
        saved_ptrs: dict[str, torch.Tensor] = {}
        for name, group in self._buffer._storage.groups.items():
            bp = getattr(group, "_batch_ptr", None)
            if bp is not None and bp.dtype != torch.int64:
                saved_ptrs[name] = bp
                group._batch_ptr = bp.to(torch.int64)
        try:
            indices = torch.arange(len(self), dtype=torch.long, device=self._device)
            result = self._buffer.index_select(indices)
        finally:
            # Restore int32 batch_ptr for subsequent Warp put() calls
            for name, original in saved_ptrs.items():
                self._buffer._storage.groups[name]._batch_ptr = original
        return result

    def zero(self) -> None:
        """Clear all stored data and reset the buffer.

        Zeros all subtensors within the pre-allocated buffer while preserving
        the data structure and allocated memory. This avoids re-allocation on
        the next :meth:`write` and keeps the buffer shape intact for
        ``isend``/``irecv`` symmetry.
        """
        # Reset occupancy masks
        if self._dest_mask is not None:
            self._dest_mask.zero_()

        if self._buffer is None:
            return

        # Delegate buffer reset to Batch.zero() which properly handles
        # both UniformLevelStorage and SegmentedLevelStorage bookkeeping.
        self._buffer.zero()

        # Clear _num_segments / _num_elements_kept caches (not handled by Batch.zero)
        for group in self._buffer._storage.groups.values():
            if hasattr(group, "_num_segments"):
                object.__delattr__(group, "_num_segments")
                # Guard the second deletion so that a group carrying only
                # ``_num_segments`` does not abort the reset with an
                # AttributeError half-way through the loop.
                if hasattr(group, "_num_elements_kept"):
                    object.__delattr__(group, "_num_elements_kept")

    def __len__(self) -> int:
        """Return the number of samples currently stored.

        Returns
        -------
        int
            Number of atomic data samples in the buffer; ``0`` when the
            buffer has not been allocated yet.
        """
        if self._buffer is None:
            return 0
        # write() treats a falsy ``num_graphs`` as zero samples; do the same
        # here so ``len()`` always receives an int.
        return self._buffer.num_graphs or 0

    @property
    def capacity(self) -> int:
        """Return the maximum storage capacity.

        Returns
        -------
        int
            Maximum number of samples that can be stored.
        """
        return self._capacity

    @property
    def device(self) -> torch.device:
        """Return the storage device.

        Returns
        -------
        torch.device
            Device where data is stored.
        """
        return self._device
class HostMemory(DataSink):
    """
    CPU-resident buffer for storing batched atomic data.

    Every sample handed to this sink is moved to CPU memory before it is
    kept, whatever device the incoming batch lives on. Typical uses are
    staging data ahead of disk I/O or CPU-side post-processing.

    Parameters
    ----------
    capacity : int
        Maximum number of samples to store.

    Attributes
    ----------
    capacity : int
        Maximum storage capacity.

    Examples
    --------
    >>> host_buffer = HostMemory(capacity=1000)
    >>> host_buffer.write(gpu_batch)  # Data moved to CPU
    >>> cpu_batch = host_buffer.read()
    """

    def __init__(self, capacity: int) -> None:
        """
        Initialize the host memory buffer.

        Parameters
        ----------
        capacity : int
            Maximum number of samples to store.
        """
        self._device = torch.device("cpu")
        self._data_list: list[AtomicData] = []
        self._capacity = capacity

    def write(self, batch: Batch, mask: torch.Tensor | None = None) -> None:
        """
        Store a batch of atomic data on CPU.

        The (optionally masked) batch is split into individual AtomicData
        objects, each of which is moved to CPU and appended to the internal
        list.

        Parameters
        ----------
        batch : Batch
            The batch of atomic data to store.
        mask : torch.Tensor | None, optional
            Boolean tensor of shape ``(batch.num_graphs,)`` indicating which
            samples to write (``True`` = write). If ``None``, all samples
            are written. Default is ``None``.

        Raises
        ------
        RuntimeError
            If adding the selected samples would exceed capacity.
        ValueError
            If mask length does not match batch.num_graphs.
        """
        num_total = batch.num_graphs or 0
        if not num_total:
            return

        if mask is not None:
            selection = mask.to(device=batch.device, dtype=torch.bool)
            if selection.shape[0] != num_total:
                raise ValueError(
                    f"mask length {selection.shape[0]} != num_graphs {num_total}"
                )
            selected_count = int(selection.sum().item())
            if selected_count == 0:
                return
            # A mask that keeps everything needs no sub-selection.
            if selected_count < num_total:
                keep = torch.nonzero(selection, as_tuple=True)[0]
                _ = batch.ptr  # trigger lazy init for SegmentedLevelStorage
                batch = batch.index_select(keep)

        samples = batch.to_data_list()
        stored = len(self._data_list)
        if stored + len(samples) > self._capacity:
            raise RuntimeError(
                f"Buffer is full. Cannot add {len(samples)} samples "
                f"to buffer with {stored}/{self._capacity} samples."
            )

        # Move each sample to CPU as it is stored.
        self._data_list.extend(sample.to(self._device) for sample in samples)

    def read(self) -> Batch:
        """
        Retrieve all stored data as a CPU-resident Batch.

        Returns
        -------
        Batch
            A batch containing all stored atomic data on CPU.

        Raises
        ------
        RuntimeError
            If the buffer is empty.
        """
        if not self._data_list:
            raise RuntimeError("Cannot read from empty buffer.")
        return Batch.from_data_list(self._data_list, device=self._device)

    def zero(self) -> None:
        """Clear all stored data and reset the buffer."""
        self._data_list.clear()

    def __len__(self) -> int:
        """
        Return the number of samples currently stored.

        Returns
        -------
        int
            Number of atomic data samples in the buffer.
        """
        return len(self._data_list)

    @property
    def capacity(self) -> int:
        """
        Return the maximum storage capacity.

        Returns
        -------
        int
            Maximum number of samples that can be stored.
        """
        return self._capacity
class ZarrData(DataSink):
    """
    Zarr-backed storage for batched atomic data.

    This sink persists atomic data using the Zarr format, supporting both
    local filesystem and remote/in-memory stores via ``StoreLike``.
    Delegates serialization to :class:`AtomicDataZarrWriter` for efficient,
    amortized I/O with CSR-style pointer arrays.

    Supports any zarr-compatible store: filesystem paths (str or Path),
    zarr Store instances (LocalStore, MemoryStore, FsspecStore for remote
    storage like S3/GCS), StorePath, or dict for in-memory buffers.

    Parameters
    ----------
    store : StoreLike
        Any zarr-compatible store: filesystem path (str or Path), zarr Store
        instance, StorePath, or dict for in-memory buffer storage.
    capacity : int, optional
        Maximum number of samples to store. Default is 1,000,000.

    Attributes
    ----------
    capacity : int
        Maximum storage capacity.
    store : StoreLike
        The backing zarr store.

    Examples
    --------
    >>> zarr_sink = ZarrData("/path/to/store", capacity=100000)
    >>> zarr_sink.write(batch)
    >>> loaded_batch = zarr_sink.read()

    Using an in-memory store:

    >>> zarr_sink = ZarrData({}, capacity=1000)  # dict acts as memory store
    """

    def __init__(self, store: StoreLike, capacity: int = 1_000_000) -> None:
        """
        Initialize the Zarr data sink.

        Parameters
        ----------
        store : StoreLike
            Any zarr-compatible store: filesystem path (str or Path), zarr
            Store instance, StorePath, or dict for in-memory buffer storage.
        capacity : int, optional
            Maximum number of samples to store. Default is 1,000,000.
        """
        self._store: StoreLike = store
        self._capacity = capacity
        self._count = 0
        self._written_once = False
        # Lazily create writer — don't create store until first write
        self._writer: AtomicDataZarrWriter | None = None

    def _get_writer(self) -> AtomicDataZarrWriter:
        """Get or create the AtomicDataZarrWriter instance.

        Returns
        -------
        AtomicDataZarrWriter
            The writer instance for this sink.
        """
        if self._writer is None:
            self._writer = AtomicDataZarrWriter(self._store)
        return self._writer

    def write(self, batch: Batch, mask: torch.Tensor | None = None) -> None:
        """
        Store a batch of atomic data to Zarr.

        Uses :class:`AtomicDataZarrWriter` for efficient bulk writes. The
        first write uses ``write()`` (creates store), subsequent writes use
        ``append()`` (extends existing store).

        Parameters
        ----------
        batch : Batch
            The batch of atomic data to store.
        mask : torch.Tensor | None, optional
            Boolean tensor of shape ``(batch.num_graphs,)`` indicating which
            samples to write (``True`` = write). If ``None``, all samples
            are written. Default is ``None``.

        Raises
        ------
        RuntimeError
            If adding the selected samples would exceed capacity.
        ValueError
            If mask length does not match batch.num_graphs.
        """
        num_total = batch.num_graphs or 0
        if num_total == 0:
            return  # Nothing to write

        # Number of samples that will actually be persisted by this call.
        num_graphs = num_total

        # Apply mask to select samples
        if mask is not None:
            mask = mask.to(device=batch.device, dtype=torch.bool)
            if mask.shape[0] != num_total:
                raise ValueError(
                    f"mask length {mask.shape[0]} != num_graphs {num_total}"
                )
            num_selected = int(mask.sum().item())
            if num_selected == 0:
                return
            if num_selected < num_total:
                indices = torch.nonzero(mask, as_tuple=True)[0]
                _ = batch.ptr  # trigger lazy init for SegmentedLevelStorage
                batch = batch.index_select(indices)
            num_graphs = num_selected

        if self._count + num_graphs > self._capacity:
            raise RuntimeError(
                f"Store is full. Cannot add {num_graphs} samples "
                f"to store with {self._count}/{self._capacity} samples."
            )

        writer = self._get_writer()
        if not self._written_once:
            writer.write(batch)
            self._written_once = True
        else:
            writer.append(batch)
        self._count += num_graphs

    def read(self) -> Batch:
        """
        Load all stored data from Zarr as a Batch.

        Delegates to :class:`AtomicDataZarrReader` for efficient reading of
        samples from the CSR-style layout created by
        :class:`AtomicDataZarrWriter`.

        Returns
        -------
        Batch
            A batch containing all stored atomic data.

        Raises
        ------
        RuntimeError
            If the store is empty.
        """
        if self._count == 0:
            raise RuntimeError("Cannot read from empty store.")

        with AtomicDataZarrReader(self._store) as reader:
            # TODO: optimize this by adding index_select/slicing to amortize overhead
            data_list = [AtomicData(**reader[i][0]) for i in range(len(reader))]
        return Batch.from_data_list(data_list)

    def zero(self) -> None:
        """
        Clear all stored data and reset the store.

        Filesystem-path stores are removed from disk and dict stores are
        emptied. The cached writer is dropped, so the next :meth:`write`
        obtains a fresh one through :meth:`_get_writer`.
        """
        # Handle different store types for cleanup
        if isinstance(self._store, (str, Path)):
            # Filesystem path — delete directory if it exists
            store_path = Path(self._store)
            if store_path.exists():
                shutil.rmtree(store_path)
        elif isinstance(self._store, dict):
            # In-memory dict store — clear all contents
            self._store.clear()
        # For other Store types (LocalStore, MemoryStore, FsspecStore, etc.),
        # the writer will handle overwriting when opened in write mode.

        # Reset state. The writer is dropped rather than rebuilt here so the
        # sink keeps the lazy-creation behaviour set up in __init__: nothing
        # touches the store again until the next write.
        self._writer = None
        self._count = 0
        self._written_once = False

    def __len__(self) -> int:
        """
        Return the number of samples currently stored.

        Returns
        -------
        int
            Number of atomic data samples in the store.
        """
        return self._count

    @property
    def capacity(self) -> int:
        """
        Return the maximum storage capacity.

        Returns
        -------
        int
            Maximum number of samples that can be stored.
        """
        return self._capacity