# Source code for nvalchemi.data.datapipes.backends.zarr

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Zarr backend for AtomicData (de)serialization.

This module provides the concrete implementation of ``AtomicData``
(de)serialization using high performance ``zarr`` array I/O.

The ``AtomicDataZarrWriter`` class is designed to allow for efficient,
amortized data writes with the ability to directly save/append ``Batch``
objects to disk.

The ``AtomicDataZarrReader`` provides a concrete ``Reader`` implementation
that reads in arrays from disk, and maps them to ``torch.Tensor``s that
are intended to be composed with :class:`nvalchemi.data.datapipes.Dataset`.

To understand usage, users should refer to ``examples/data/datapipes/read_zarr_store.py``.
"""

from __future__ import annotations

import re
from pathlib import Path
from typing import Any, Literal, TypeAlias

import numpy as np
import torch
import zarr
from plum import dispatch, overload
from zarr.abc.store import Store
from zarr.storage import StorePath

# These need to be available at runtime for plum dispatch
from nvalchemi.data.atomic_data import AtomicData
from nvalchemi.data.batch import Batch
from nvalchemi.data.datapipes.backends.base import Reader

# Type alias for zarr store-like objects accepted by the writer/reader:
# a filesystem path (str or Path), a zarr Store/StorePath instance, or a
# plain dict used as an in-memory buffer store.
StoreLike: TypeAlias = Store | StorePath | Path | str | dict[str, Any]

# TODO: make classes inherit from PNM when stable


def _get_field_level(key: str) -> str:
    """Classify a core field key as 'atom', 'edge', or 'system' level.

    Membership is checked against the key sets declared on ``AtomicData``.

    Parameters
    ----------
    key : str
        Field name.

    Returns
    -------
    str
        One of 'atom', 'edge', or 'system'. Keys not found in any of the
        ``AtomicData`` key sets default to 'atom'.
    """
    if key in AtomicData.__node_keys__:
        return "atom"
    if key in AtomicData.__edge_keys__:
        return "edge"
    if key in AtomicData.__system_keys__:
        return "system"
    # Unknown keys are treated as per-atom data by default
    return "atom"


def _get_cat_dim(key: str) -> int:
    """Return concatenation dimension for a field.

    Returns -1 for keys containing 'index' or 'face', 0 otherwise.
    (Mirrors DataMixin.__cat_dim__)

    Parameters
    ----------
    key : str
        Field name.

    Returns
    -------
    int
        Concatenation dimension.
    """
    if bool(re.search("(index|face)", key)):
        return -1
    return 0


class AtomicDataZarrWriter:
    """Writer for serializing AtomicData into Zarr stores.

    Writes AtomicData objects into a structured Zarr store with CSR-style
    pointer arrays for variable-size graph data. Supports single writes,
    batch writes, appending, custom fields, soft-delete, and defragmentation.

    The Zarr store layout is::

        dataset.zarr/
        ├── meta/                # Pointer arrays + masks
        │   ├── atoms_ptr        # int64 [N+1] — cumulative node counts
        │   ├── edges_ptr        # int64 [N+1] — cumulative edge counts
        │   ├── samples_mask     # bool [N] — False = deleted sample
        │   ├── atoms_mask       # bool [V_total] — False = deleted atom
        │   └── edges_mask       # bool [E_total] — False = deleted edge
        ├── core/                # AtomicData fields (auto-populated)
        │   ├── atomic_numbers   # int64 [V_total]
        │   ├── positions        # float32 [V_total, 3]
        │   └── ...
        ├── custom/              # User-defined arrays (optional)
        │   └── <user_key>       # any dtype, any shape
        └── .zattrs              # root metadata

    Parameters
    ----------
    store : StoreLike
        Any zarr-compatible store: filesystem path (str or Path), or a
        zarr Store instance (LocalStore, MemoryStore, FsspecStore, etc.),
        StorePath, or a dict for in-memory buffer storage.

    Attributes
    ----------
    _store : StoreLike
        The zarr store used for I/O.
    """

    def __init__(self, store: StoreLike) -> None:
        """Initialize the writer with a target store.

        Parameters
        ----------
        store : StoreLike
            Any zarr-compatible store: filesystem path (str or Path), or a
            zarr Store instance (LocalStore, MemoryStore, FsspecStore, etc.),
            StorePath, or a dict for in-memory buffer storage.
        """
        # No I/O happens here; the store is only opened on write/append/etc.
        self._store: StoreLike = store

    def _open(self, mode: Literal["r", "r+", "w", "w-", "a"] = "r") -> zarr.Group:
        """Open the zarr store with the given mode.

        Parameters
        ----------
        mode : Literal["r", "r+", "w", "w-", "a"]
            Zarr access mode ('r', 'r+', 'w', 'w-', 'a').

        Returns
        -------
        zarr.Group
            The opened zarr group.
        """
        return zarr.open(self._store, mode=mode)  # type: ignore[return-value]

    def _store_exists(self) -> bool:
        """Check whether the store already contains data.

        For filesystem paths, checks if the path exists. For abstract
        stores (MemoryStore, FsspecStore, etc.), attempts to open in read
        mode and check for content.

        Returns
        -------
        bool
            True if the store exists and contains data.
        """
        if isinstance(self._store, (str, Path)):
            return Path(self._store).exists()
        # For abstract stores (MemoryStore, FsspecStore, etc.),
        # try opening read-only and check for content
        try:
            root = zarr.open(self._store, mode="r")  # type: ignore[call-overload]
            # If we can list any members, the store has data
            return len(list(root.group_keys())) > 0 or len(list(root.array_keys())) > 0
        except Exception:
            # Best-effort probe: any failure to open/list means "no data"
            return False

    @overload
    def write(self, data: AtomicData) -> None:  # noqa: F811
        """Write a single AtomicData by wrapping it in a one-element list."""
        self.write([data])

    @overload
    def write(self, data: list[AtomicData]) -> None:  # noqa: F811
        """Write a list of AtomicData to a new Zarr store via a Batch."""
        self.write(Batch.from_data_list(data, device="cpu"))

    @overload
    def write(self, data: Batch) -> None:  # noqa: F811
        """Write a Batch to a new Zarr store.

        This is the efficient bulk-write path. Since a Batch already has all
        tensors concatenated (node/edge level) or stacked (system level),
        each field is written to zarr in a single I/O operation with no
        per-sample iteration.

        Parameters
        ----------
        data : Batch
            Batched atomic data to write.

        Raises
        ------
        FileExistsError
            If store already exists.
        ValueError
            If batch is empty.
        """
        if self._store_exists():
            raise FileExistsError(f"Zarr store already exists at {self._store}")

        num_samples = data.num_graphs
        if num_samples is None or num_samples == 0:
            raise ValueError("No data provided to write.")

        root = self._open(mode="w")
        meta_group = root.create_group("meta")
        core_group = root.create_group("core")
        root.create_group("custom")

        # Build pointer arrays directly from batch metadata — no iteration
        nodes_tensor = torch.tensor(data.num_nodes_list, dtype=torch.long)
        # Handle case where num_edges_list is empty (no edges in data)
        if data.num_edges_list:
            edges_tensor = torch.tensor(data.num_edges_list, dtype=torch.long)
        else:
            # No edges: create zeros for each sample
            edges_tensor = torch.zeros(num_samples, dtype=torch.long)

        # CSR-style pointers: ptr[i]..ptr[i+1] is sample i's slice
        atoms_ptr = torch.cat(
            [torch.zeros(1, dtype=torch.long), torch.cumsum(nodes_tensor, dim=0)]
        )
        edges_ptr = torch.cat(
            [torch.zeros(1, dtype=torch.long), torch.cumsum(edges_tensor, dim=0)]
        )
        total_atoms = int(atoms_ptr[-1].item())
        total_edges = int(edges_ptr[-1].item())

        # Write meta arrays (all masks start as "active")
        meta_group.create_array("atoms_ptr", data=self._to_numpy(atoms_ptr))
        meta_group.create_array("edges_ptr", data=self._to_numpy(edges_ptr))
        meta_group.create_array(
            "samples_mask",
            data=self._to_numpy(torch.ones(num_samples, dtype=torch.bool)),
        )
        meta_group.create_array(
            "atoms_mask", data=self._to_numpy(torch.ones(total_atoms, dtype=torch.bool))
        )
        meta_group.create_array(
            "edges_mask", data=self._to_numpy(torch.ones(total_edges, dtype=torch.bool))
        )

        # Build field level mapping
        fields_metadata: dict[str, dict[str, str]] = {"core": {}, "custom": {}}

        # Collect all field keys from the batch's level categorization.
        # batch.keys is {"node": set, "edge": set, "system": set} — these are the
        # field names present in the batch, already categorized by level.
        excluded = {"batch", "ptr", "device", "dtype", "info", "num_nodes"}
        all_field_keys: set[str] = set()
        level_map: dict[str, str] = {}  # key -> "atom"/"edge"/"system"
        for level_name, key_set in (data.keys or {}).items():
            for k in key_set:
                all_field_keys.add(k)
                # batch.keys uses "node"/"edge"/"system"; zarr format uses
                # "atom"/"edge"/"system"
                level_map[k] = "atom" if level_name == "node" else level_name

        # Get all tensor attributes from the batch (Pydantic's to_dict works)
        batch_dict = data.to_dict()

        # Write each field: one zarr I/O per field, no per-sample loop
        for key in all_field_keys:
            val = batch_dict.get(key)
            if val is None or not isinstance(val, torch.Tensor):
                continue
            if key in excluded:
                continue
            level = level_map.get(key, _get_field_level(key))
            fields_metadata["core"][key] = level
            # The tensor is already in concatenated/stacked form from the Batch.
            # Batch stores edge_index as (num_edges, 2); zarr format is
            # (2, num_edges).
            if key == "edge_index":
                val = val.transpose(0, 1)
            # System-level: Batch stacks (1, 3, 3) -> (N, 1, 3, 3); squeeze
            # dim 1 so zarr has (N, 3, 3)
            if level == "system" and val.dim() > 2:
                while val.dim() > 2 and val.shape[1] == 1:
                    val = val.squeeze(1)
            core_group.create_array(key, data=self._to_numpy(val))

        root.attrs["num_samples"] = num_samples
        root.attrs["fields"] = fields_metadata
    @dispatch
    def write(self, data: AtomicData | list[AtomicData] | Batch) -> None:  # noqa: F811
        """Write atomic data to a new Zarr store.

        Creates the Zarr store with core/, meta/, custom/ groups. Builds
        atoms_ptr, edges_ptr, and initializes all masks to True. A single
        AtomicData is wrapped in a list; a list is batched via
        ``Batch.from_data_list`` so the bulk Batch path performs the I/O.

        Parameters
        ----------
        data : AtomicData | list[AtomicData] | Batch
            Data to write.

        Raises
        ------
        FileExistsError
            If store already exists.
        """
        # plum's @dispatch resolves calls to the typed @overload
        # implementations above; this stub only anchors the combined
        # signature and docstring, so the body is intentionally empty.
        pass
@overload def append(self, data: AtomicData) -> None: # noqa: F811 """Append a single AtomicData to an existing Zarr store. While this dispatch is available for convenience, we recommend users to try and amortize I/O operations by packing multiple data to write, instead of one at a time. This can be achieved by passing either a ``Batch`` object, or a list of ``AtomicData`` which will automatically form a batch. Parameters ---------- data : AtomicData Single atomic data to append. Raises ------ FileNotFoundError If store does not exist. """ if not self._store_exists(): raise FileNotFoundError(f"Zarr store does not exist at {self._store}") root = self._open(mode="r+") meta_group = root["meta"] core_group = root["core"] # Read existing pointer tails old_atoms_ptr = torch.from_numpy(meta_group["atoms_ptr"][:]) old_edges_ptr = torch.from_numpy(meta_group["edges_ptr"][:]) old_num_samples = int(root.attrs["num_samples"]) last_atom_ptr = int(old_atoms_ptr[-1].item()) last_edge_ptr = int(old_edges_ptr[-1].item()) # Get counts directly from the single AtomicData data_dict = data.to_dict() num_atoms = int(data.num_nodes) # Determine num_edges from edge_index if present edge_index = data_dict.get("edge_index") if edge_index is not None and isinstance(edge_index, torch.Tensor): num_edges = edge_index.shape[-1] else: num_edges = 0 # Extend pointer arrays with single new entries new_atom_ptr = torch.tensor([last_atom_ptr + num_atoms], dtype=torch.long) new_edge_ptr = torch.tensor([last_edge_ptr + num_edges], dtype=torch.long) self._extend_array(meta_group["atoms_ptr"], self._to_numpy(new_atom_ptr)) self._extend_array(meta_group["edges_ptr"], self._to_numpy(new_edge_ptr)) # Extend masks (single sample, its atoms, its edges) self._extend_array( meta_group["samples_mask"], self._to_numpy(torch.ones(1, dtype=torch.bool)), ) self._extend_array( meta_group["atoms_mask"], self._to_numpy(torch.ones(num_atoms, dtype=torch.bool)), ) self._extend_array( meta_group["edges_mask"], 
self._to_numpy(torch.ones(num_edges, dtype=torch.bool)), ) # Extend each existing core field excluded = {"batch", "ptr", "device", "dtype", "info", "num_nodes"} for key in core_group.keys(): val = data_dict.get(key) if val is None or not isinstance(val, torch.Tensor): continue if key in excluded: continue # System-level fields need unsqueeze(0) to add the sample dimension level = _get_field_level(key) if level == "system": val = val.unsqueeze(0) if val.dim() == 0 else val cat_dim = _get_cat_dim(key) self._extend_array(core_group[key], self._to_numpy(val), axis=cat_dim) root.attrs["num_samples"] = old_num_samples + 1 @overload def append(self, data: list[AtomicData]) -> None: # noqa: F811 """Append a list of AtomicData to an existing Zarr store.""" device = data[0].device self.append(Batch.from_data_list(data, device)) @overload def append(self, data: Batch) -> None: # noqa: F811 """Append a Batch to an existing Zarr store. This is the efficient bulk-append path. Since a Batch already has all tensors concatenated (node/edge level) or stacked (system level), each field is extended in a single I/O operation with no per-sample iteration. Parameters ---------- data : Batch Batched atomic data to append. Raises ------ FileNotFoundError If store does not exist. 
""" if not self._store_exists(): raise FileNotFoundError(f"Zarr store does not exist at {self._store}") num_samples = data.num_graphs if num_samples is None or num_samples == 0: return root = self._open(mode="r+") meta_group = root["meta"] core_group = root["core"] # Read existing pointer tails old_atoms_ptr = torch.from_numpy(meta_group["atoms_ptr"][:]) old_edges_ptr = torch.from_numpy(meta_group["edges_ptr"][:]) old_num_samples = int(root.attrs["num_samples"]) last_atom_ptr = int(old_atoms_ptr[-1].item()) last_edge_ptr = int(old_edges_ptr[-1].item()) # Compute new pointer entries from batch metadata nodes_tensor = torch.tensor(data.num_nodes_list, dtype=torch.long) # Handle case where num_edges_list is empty (no edges in data) if data.num_edges_list: edges_tensor = torch.tensor(data.num_edges_list, dtype=torch.long) else: # No edges: create zeros for each sample edges_tensor = torch.zeros(num_samples, dtype=torch.long) new_atoms_ptr = last_atom_ptr + torch.cumsum(nodes_tensor, dim=0) new_edges_ptr = last_edge_ptr + torch.cumsum(edges_tensor, dim=0) new_total_atoms = int(new_atoms_ptr[-1].item()) new_total_edges = int(new_edges_ptr[-1].item()) # Extend pointer arrays self._extend_array(meta_group["atoms_ptr"], self._to_numpy(new_atoms_ptr)) self._extend_array(meta_group["edges_ptr"], self._to_numpy(new_edges_ptr)) # Extend masks self._extend_array( meta_group["samples_mask"], self._to_numpy(torch.ones(num_samples, dtype=torch.bool)), ) self._extend_array( meta_group["atoms_mask"], self._to_numpy( torch.ones(new_total_atoms - last_atom_ptr, dtype=torch.bool) ), ) self._extend_array( meta_group["edges_mask"], self._to_numpy( torch.ones(new_total_edges - last_edge_ptr, dtype=torch.bool) ), ) # Get all tensor attributes from the batch (Pydantic's to_dict works) batch_dict = data.to_dict() # Extend each field — single I/O per field excluded = {"batch", "ptr", "device", "dtype", "info", "num_nodes"} for key in core_group.keys(): val = batch_dict.get(key) if val is None 
or not isinstance(val, torch.Tensor): continue if key in excluded: continue # Batch stores edge_index as (num_edges, 2); zarr format is (2, num_edges) if key == "edge_index": val = val.transpose(0, 1) level = _get_field_level(key) if level == "system" and val.dim() > 2: while val.dim() > 2 and val.shape[1] == 1: val = val.squeeze(1) cat_dim = _get_cat_dim(key) self._extend_array(core_group[key], self._to_numpy(val), axis=cat_dim) root.attrs["num_samples"] = old_num_samples + num_samples
    @dispatch
    def append(self, data: AtomicData | list[AtomicData] | Batch) -> None:  # noqa: F811
        """Append data to an existing Zarr store.

        Extends all arrays along concatenation axis. Extends pointer arrays
        and masks. Updates num_samples in .zattrs.

        Parameters
        ----------
        data : AtomicData | list[AtomicData] | Batch
            Data to append.

        Raises
        ------
        FileNotFoundError
            If store does not exist.
        """
        # plum's @dispatch resolves calls to the typed @overload
        # implementations above; this stub only anchors the combined
        # signature and docstring, so the body is intentionally empty.
        pass
    def add_custom(
        self, key: str, data: torch.Tensor, level: Literal["atom", "edge", "system"]
    ) -> None:
        """Add a custom array to the custom/ group.

        Parameters
        ----------
        key : str
            Name for the custom array.
        data : torch.Tensor
            Tensor data. First dimension must match:
            - num_samples for "system" level
            - total atoms for "atom" level
            - total edges for "edge" level
        level : str
            One of "atom", "edge", "system".

        Raises
        ------
        ValueError
            If level is invalid or data shape doesn't match.
        FileNotFoundError
            If store does not exist.
        """
        if level not in ("atom", "edge", "system"):
            raise ValueError(
                f"Invalid level '{level}'. Must be 'atom', 'edge', or 'system'."
            )
        if not self._store_exists():
            raise FileNotFoundError(f"Zarr store does not exist at {self._store}")

        root = self._open(mode="r+")
        meta_group = root["meta"]
        custom_group = root["custom"]

        # Validate shape against the store's current totals: the last pointer
        # entry is the cumulative count of atoms/edges across all samples.
        num_samples = int(root.attrs["num_samples"])
        atoms_ptr = meta_group["atoms_ptr"][:]
        edges_ptr = meta_group["edges_ptr"][:]
        total_atoms = int(atoms_ptr[-1])
        total_edges = int(edges_ptr[-1])
        expected_size = {
            "system": num_samples,
            "atom": total_atoms,
            "edge": total_edges,
        }[level]
        if data.shape[0] != expected_size:
            raise ValueError(
                f"Data shape[0]={data.shape[0]} does not match expected "
                f"size={expected_size} for level='{level}'."
            )

        # Write to custom group (convert to numpy at zarr boundary)
        custom_group.create_array(key, data=self._to_numpy(data))

        # Update fields metadata so readers know this key's level
        fields_metadata = dict(root.attrs.get("fields", {"core": {}, "custom": {}}))
        if "custom" not in fields_metadata:
            fields_metadata["custom"] = {}
        fields_metadata["custom"][key] = level
        root.attrs["fields"] = fields_metadata
    def delete(self, indices: list[int] | torch.Tensor) -> None:
        """Soft-delete samples by index.

        Sets masks to False and zeros out data slices in core/ and custom/.
        Pointer arrays are NOT modified, so physical indices of remaining
        samples are unchanged; use :meth:`defragment` to reclaim space.

        Parameters
        ----------
        indices : list[int] | torch.Tensor
            Sample indices to delete.

        Raises
        ------
        FileNotFoundError
            If store does not exist.
        """
        if not self._store_exists():
            raise FileNotFoundError(f"Zarr store does not exist at {self._store}")

        # Convert to torch tensor for consistent handling
        if isinstance(indices, list):
            indices_tensor = torch.as_tensor(indices, dtype=torch.long)
        else:
            indices_tensor = indices.to(torch.long)
        if len(indices_tensor) == 0:
            return

        root = self._open(mode="r+")
        meta_group = root["meta"]
        core_group = root["core"]

        # Read masks into memory once; write them back in one pass at the end
        atoms_ptr = meta_group["atoms_ptr"][:]
        edges_ptr = meta_group["edges_ptr"][:]
        samples_mask = meta_group["samples_mask"][:]
        atoms_mask = meta_group["atoms_mask"][:]
        edges_mask = meta_group["edges_mask"][:]
        fields_metadata = dict(root.attrs.get("fields", {"core": {}, "custom": {}}))

        for idx in indices_tensor:
            idx = int(idx)
            # Mark sample as deleted
            samples_mask[idx] = False
            # Get slice ranges
            atom_start, atom_end = int(atoms_ptr[idx]), int(atoms_ptr[idx + 1])
            edge_start, edge_end = int(edges_ptr[idx]), int(edges_ptr[idx + 1])
            # Zero out atoms_mask and edges_mask
            atoms_mask[atom_start:atom_end] = False
            edges_mask[edge_start:edge_end] = False

            # Zero out core fields
            for key in core_group.keys():
                level = fields_metadata.get("core", {}).get(key, _get_field_level(key))
                arr = core_group[key]
                if level == "atom":
                    self._zero_slice(arr, atom_start, atom_end, axis=0)
                elif level == "edge":
                    # index-like fields (e.g. edge_index) store edges on the
                    # last axis, so zero along their concatenation dim
                    cat_dim = _get_cat_dim(key)
                    self._zero_slice(arr, edge_start, edge_end, axis=cat_dim)
                elif level == "system":
                    self._zero_slice(arr, idx, idx + 1, axis=0)

            # Zero out custom fields
            if "custom" in root:
                custom_group = root["custom"]
                for key in custom_group.keys():
                    # Custom arrays default to "system" level when metadata
                    # is missing (matches add_custom's metadata contract)
                    level = fields_metadata.get("custom", {}).get(key, "system")
                    arr = custom_group[key]
                    if level == "atom":
                        self._zero_slice(arr, atom_start, atom_end, axis=0)
                    elif level == "edge":
                        self._zero_slice(arr, edge_start, edge_end, axis=0)
                    elif level == "system":
                        self._zero_slice(arr, idx, idx + 1, axis=0)

        # Write back masks
        meta_group["samples_mask"][:] = samples_mask
        meta_group["atoms_mask"][:] = atoms_mask
        meta_group["edges_mask"][:] = edges_mask
[docs] def defragment(self) -> None: """Rewrite store excluding deleted samples. Rebuilds all arrays, pointer arrays, and resets all masks to True. """ if not self._store_exists(): raise FileNotFoundError(f"Zarr store does not exist at {self._store}") root = self._open(mode="r") meta_group = root["meta"] core_group = root["core"] atoms_ptr = meta_group["atoms_ptr"][:] edges_ptr = meta_group["edges_ptr"][:] samples_mask = meta_group["samples_mask"][:] fields_metadata = dict(root.attrs.get("fields", {"core": {}, "custom": {}})) # Find active sample indices active_indices = np.where(samples_mask)[0] if len(active_indices) == 0: # All samples deleted, just reset to empty # Clear store by opening with mode="w" (overwrite) new_root = self._open(mode="w") new_root.create_group("meta") new_root.create_group("core") new_root.create_group("custom") new_root.attrs["num_samples"] = 0 new_root.attrs["fields"] = {"core": {}, "custom": {}} return # Collect active data for each field new_core_data: dict[str, list[np.ndarray]] = { key: [] for key in core_group.keys() } new_custom_data: dict[str, list[np.ndarray]] = {} if "custom" in root: custom_group = root["custom"] new_custom_data = {key: [] for key in custom_group.keys()} new_num_nodes: list[int] = [] new_num_edges: list[int] = [] # Pre-read all arrays once to avoid re-reading per sample core_arrays = {key: core_group[key][:] for key in core_group.keys()} custom_arrays: dict[str, np.ndarray] = {} if "custom" in root: custom_arrays = { key: root["custom"][key][:] for key in root["custom"].keys() } for idx in active_indices: idx = int(idx) atom_start, atom_end = int(atoms_ptr[idx]), int(atoms_ptr[idx + 1]) edge_start, edge_end = int(edges_ptr[idx]), int(edges_ptr[idx + 1]) new_num_nodes.append(atom_end - atom_start) new_num_edges.append(edge_end - edge_start) for key in core_group.keys(): level = fields_metadata.get("core", {}).get(key, _get_field_level(key)) arr = core_arrays[key] if level == "atom": 
new_core_data[key].append(arr[atom_start:atom_end]) elif level == "edge": cat_dim = _get_cat_dim(key) if cat_dim == -1: # edge_index: shape [2, E] new_core_data[key].append(arr[:, edge_start:edge_end]) else: new_core_data[key].append(arr[edge_start:edge_end]) elif level == "system": # System level: index by sample new_core_data[key].append(arr[idx : idx + 1]) if custom_arrays: for key in custom_arrays: level = fields_metadata.get("custom", {}).get(key, "system") arr = custom_arrays[key] if level == "atom": new_custom_data[key].append(arr[atom_start:atom_end]) elif level == "edge": new_custom_data[key].append(arr[edge_start:edge_end]) elif level == "system": new_custom_data[key].append(arr[idx : idx + 1]) # Clear store and create new structure (mode="w" clears existing data) new_root = self._open(mode="w") new_meta = new_root.create_group("meta") new_core = new_root.create_group("core") new_custom = new_root.create_group("custom") # Build new pointer arrays new_atoms_ptr = np.array([0] + list(np.cumsum(new_num_nodes)), dtype=np.int64) new_edges_ptr = np.array([0] + list(np.cumsum(new_num_edges)), dtype=np.int64) new_total_atoms = int(new_atoms_ptr[-1]) new_total_edges = int(new_edges_ptr[-1]) new_num_samples = len(active_indices) new_meta.create_array("atoms_ptr", data=new_atoms_ptr) new_meta.create_array("edges_ptr", data=new_edges_ptr) new_meta.create_array( "samples_mask", data=np.ones(new_num_samples, dtype=np.bool_) ) new_meta.create_array( "atoms_mask", data=np.ones(new_total_atoms, dtype=np.bool_) ) new_meta.create_array( "edges_mask", data=np.ones(new_total_edges, dtype=np.bool_) ) # Concatenate and write core arrays for key, arrays in new_core_data.items(): if arrays: cat_dim = _get_cat_dim(key) concatenated = np.concatenate(arrays, axis=cat_dim) new_core.create_array(key, data=concatenated) # Concatenate and write custom arrays for key, arrays in new_custom_data.items(): if arrays: concatenated = np.concatenate(arrays, axis=0) new_custom.create_array(key, 
data=concatenated) # Update metadata new_root.attrs["num_samples"] = new_num_samples new_root.attrs["fields"] = fields_metadata
@staticmethod def _to_numpy(tensor: torch.Tensor) -> np.ndarray: """Convert a torch tensor to numpy for zarr I/O. Parameters ---------- tensor : torch.Tensor Tensor to convert. Returns ------- np.ndarray Numpy array for zarr storage. """ return tensor.detach().cpu().numpy() @staticmethod def _extend_array(arr: zarr.Array, data: np.ndarray, axis: int = 0) -> None: """Extend a zarr array along an axis. Parameters ---------- arr : zarr.Array Zarr array to extend. data : np.ndarray Data to append. axis : int Axis along which to extend. """ old_shape = arr.shape new_len = data.shape[axis] # Build new shape new_shape = list(old_shape) new_shape[axis] = old_shape[axis] + new_len # Resize array arr.resize(tuple(new_shape)) # Write new data slices: list[slice | int] = [slice(None)] * len(old_shape) slices[axis] = slice(old_shape[axis], new_shape[axis]) arr[tuple(slices)] = data @staticmethod def _zero_slice(arr: zarr.Array, start: int, end: int, axis: int = 0) -> None: """Zero out a slice of a zarr array. Parameters ---------- arr : zarr.Array Zarr array to modify. start : int Start index. end : int End index. axis : int Axis along which to slice. """ if start >= end: return slices: list[slice | int] = [slice(None)] * len(arr.shape) slices[axis] = slice(start, end) # Create zeros with correct shape shape = list(arr.shape) shape[axis] = end - start zeros = np.zeros(shape, dtype=arr.dtype) arr[tuple(slices)] = zeros
class AtomicDataZarrReader(Reader):
    """Reader for loading AtomicData from Zarr stores.

    This reader provides random-access loading of AtomicData samples from
    Zarr stores created by :class:`AtomicDataZarrWriter`. It supports
    soft-deleted samples via the samples_mask and provides efficient random
    access using pointer arrays.

    The Zarr store layout expected is::

        dataset.zarr/
        ├── meta/                # Pointer arrays + masks
        │   ├── atoms_ptr        # int64 [N+1] — cumulative node counts
        │   ├── edges_ptr        # int64 [N+1] — cumulative edge counts
        │   └── samples_mask     # bool [N] — False = deleted sample
        ├── core/                # AtomicData fields
        │   ├── atomic_numbers   # int64 [V_total]
        │   ├── positions        # float32 [V_total, 3]
        │   └── ...
        └── custom/              # User-defined arrays (optional)

    Parameters
    ----------
    store : StoreLike
        Any zarr-compatible store: filesystem path (str or Path), or a
        zarr Store instance (LocalStore, MemoryStore, FsspecStore, etc.),
        StorePath, or a dict for in-memory buffer storage.
    pin_memory : bool, default=False
        If True, place tensors in pinned (page-locked) memory for faster
        async CPU→GPU transfers.
    include_index_in_metadata : bool, default=True
        If True, include sample index in the metadata dict.

    Attributes
    ----------
    _store : StoreLike
        The underlying zarr store reference.

    Examples
    --------
    >>> from nvalchemi.data.datapipes.backends.zarr import AtomicDataZarrReader  # doctest: +SKIP
    >>> reader = AtomicDataZarrReader(store="dataset.zarr")  # doctest: +SKIP
    >>> data_dict, metadata = reader[0]  # returns dict and metadata  # doctest: +SKIP
    >>> atomic_data = AtomicDataZarrReader.to_atomic_data(data_dict)  # doctest: +SKIP
    """

    def __init__(
        self,
        store: StoreLike,
        *,
        pin_memory: bool = False,
        include_index_in_metadata: bool = True,
    ) -> None:
        """Initialize the reader with a Zarr store.

        Parameters
        ----------
        store : StoreLike
            Any zarr-compatible store: filesystem path (str or Path), or a
            zarr Store instance (LocalStore, MemoryStore, FsspecStore, etc.),
            StorePath, or a dict for in-memory buffer storage.
        pin_memory : bool, default=False
            If True, place tensors in pinned (page-locked) memory.
        include_index_in_metadata : bool, default=True
            If True, include sample index in the metadata dict.

        Raises
        ------
        FileNotFoundError
            If the Zarr store does not exist (for filesystem paths).
        ValueError
            If the store is missing required groups (meta, core).
        """
        super().__init__(
            pin_memory=pin_memory,
            include_index_in_metadata=include_index_in_metadata,
        )
        self._store: StoreLike = store

        # For filesystem paths, provide a friendly existence check
        if isinstance(store, (str, Path)) and not Path(store).exists():
            raise FileNotFoundError(f"Zarr store does not exist at {store}")

        # Open the Zarr store in read mode
        self._root = zarr.open(self._store, mode="r")

        # Validate store structure
        if "meta" not in self._root:
            raise ValueError(f"Zarr store at {self._store} is missing 'meta' group")
        if "core" not in self._root:
            raise ValueError(f"Zarr store at {self._store} is missing 'core' group")

        # Load cached state from the store (pointers, masks, field metadata)
        self.refresh()
    def refresh(self) -> None:
        """Reload cached pointer arrays, masks, and metadata from the store.

        Call this method after external modifications to the Zarr store
        (e.g., appending or deleting samples via
        :class:`AtomicDataZarrWriter`) to ensure the reader reflects the
        current state of the data.

        Raises
        ------
        RuntimeError
            If the reader has been closed.
        """
        if self._root is None:
            raise RuntimeError("Cannot refresh a closed reader.")

        # Re-open the store to pick up structural changes
        self._root = zarr.open(self._store, mode="r")

        # Cache pointer arrays as torch tensors
        self._atoms_ptr = torch.from_numpy(self._root["meta"]["atoms_ptr"][:]).to(
            torch.long
        )
        self._edges_ptr = torch.from_numpy(self._root["meta"]["edges_ptr"][:]).to(
            torch.long
        )

        # Cache samples mask
        self._samples_mask = torch.from_numpy(self._root["meta"]["samples_mask"][:]).to(
            torch.bool
        )

        # Build logical->physical index mapping (indices where mask is True)
        self._active_indices = torch.where(self._samples_mask)[0]

        # Cache fields metadata
        self._fields_metadata: dict[str, dict[str, str]] = dict(
            self._root.attrs.get("fields", {"core": {}, "custom": {}})
        )
    def _load_sample(self, index: int) -> dict[str, torch.Tensor]:
        """Load raw data for a single sample.

        Parameters
        ----------
        index : int
            Logical sample index (0 to len-1), accounting for deleted
            samples.

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary mapping field names to CPU tensors.

        Raises
        ------
        IndexError
            If index is out of range.
        """
        # Map logical index to physical index
        physical_idx = int(self._active_indices[index].item())

        # Get slice ranges from pointer arrays
        atom_start = int(self._atoms_ptr[physical_idx].item())
        atom_end = int(self._atoms_ptr[physical_idx + 1].item())
        edge_start = int(self._edges_ptr[physical_idx].item())
        edge_end = int(self._edges_ptr[physical_idx + 1].item())

        data: dict[str, torch.Tensor] = {}

        # Load core fields
        core_group = self._root["core"]
        for key in core_group.array_keys():
            level = self._fields_metadata.get("core", {}).get(
                key, _get_field_level(key)
            )
            arr = core_group[key]
            if level == "atom":
                data[key] = torch.from_numpy(arr[atom_start:atom_end])
            elif level == "edge":
                cat_dim = _get_cat_dim(key)
                if cat_dim == -1:
                    # Shape is [..., E], slice on last dim
                    tensor = torch.from_numpy(arr[:, edge_start:edge_end])
                else:
                    # Shape is [E, ...], slice on first dim
                    tensor = torch.from_numpy(arr[edge_start:edge_end])
                # edge_index needs to be converted from global to local indices
                # by subtracting the atom offset for this sample
                if key == "edge_index":
                    tensor = tensor - atom_start
                data[key] = tensor
            else:  # system level
                # Keep batch dim for system-level fields
                data[key] = torch.from_numpy(arr[physical_idx : physical_idx + 1])

        # Load custom fields if present
        if "custom" in self._root:
            custom_group = self._root["custom"]
            for key in custom_group.array_keys():
                # Custom arrays default to "system" level when metadata
                # is missing
                level = self._fields_metadata.get("custom", {}).get(key, "system")
                arr = custom_group[key]
                if level == "atom":
                    data[key] = torch.from_numpy(arr[atom_start:atom_end])
                elif level == "edge":
                    cat_dim = _get_cat_dim(key)
                    if cat_dim == -1:
                        data[key] = torch.from_numpy(arr[:, edge_start:edge_end])
                    else:
                        data[key] = torch.from_numpy(arr[edge_start:edge_end])
                else:  # system level
                    data[key] = torch.from_numpy(arr[physical_idx : physical_idx + 1])

        return data

    def __len__(self) -> int:
        """Return the number of active (non-deleted) samples.

        Returns
        -------
        int
            Number of samples available for reading.
        """
        return len(self._active_indices)

    def _get_sample_metadata(self, index: int) -> dict[str, str]:
        """Return metadata for a sample.

        Parameters
        ----------
        index : int
            Logical sample index.

        Returns
        -------
        dict[str, str]
            Dictionary containing source file information and the sample's
            physical (on-disk) index.
        """
        physical_idx = int(self._active_indices[index].item())
        return {
            "source_file": str(self._store),
            "physical_index": str(physical_idx),
        }
    def close(self) -> None:
        """Release the Zarr store reference and clean up resources."""
        # Dropping the root reference marks this reader closed;
        # refresh() guards on self._root being None.
        self._root = None
        super().close()