Source code for nvalchemi.data.datapipes.backends.zarr

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Zarr backend for AtomicData (de)serialization.

This module provides the concrete implementation of ``AtomicData``
(de)serialization using high performance ``zarr`` array I/O.

The ``AtomicDataZarrWriter`` class is designed to allow for efficient,
amortized data writes with the ability to directly save/append ``Batch``
objects to disk.

The ``AtomicDataZarrReader`` provides a concrete ``Reader`` implementation
that reads in arrays from disk, and maps them to ``torch.Tensor``s that
are intended to be composed with :class:`nvalchemi.data.datapipes.Dataset`.

To understand usage, users should refer to ``examples/data/datapipes/read_zarr_store.py``.
"""

from __future__ import annotations

import re
from collections.abc import Mapping
from pathlib import Path
from typing import Annotated, Any, Literal, TypeAlias

import numpy as np
import torch
import zarr
import zarr.abc.codec
from plum import dispatch, overload
from pydantic import BaseModel, ConfigDict, Field, model_validator
from zarr.abc.store import Store
from zarr.storage import StorePath

# These need to be available at runtime for plum dispatch
from nvalchemi.data.atomic_data import AtomicData
from nvalchemi.data.batch import Batch
from nvalchemi.data.datapipes.backends.base import Reader

# Type alias for zarr store-like objects
StoreLike: TypeAlias = Store | StorePath | Path | str | dict[str, Any]

# TODO: make classes inherit from PNM when stable


class ZarrArrayConfig(BaseModel):
    """Configuration for Zarr array compression, chunking, and sharding.

    Parameters
    ----------
    compressors : tuple[zarr.abc.codec.Codec, ...] | None
        Compressor codec(s) to apply. E.g. ``(zarr.codecs.ZstdCodec(level=3),)``.
    filters : tuple[zarr.abc.codec.Codec, ...] | None
        Array-to-array filter codec(s).
        E.g. ``(zarr.codecs.TransposeCodec(order=(1, 0)),)``.
    serializer : zarr.abc.codec.Codec | None
        Bytes serializer codec. E.g. ``zarr.codecs.BytesCodec(endian="little")``.
    chunk_size : int | None
        Chunk length along dimension 0. Other dimensions use their full
        extent. ``None`` uses Zarr defaults. Must be positive when given.
    shard_size : int | None
        Shard length along dimension 0. When set, multiple chunks are stored
        in a single storage object. Must be positive, and a multiple of
        ``chunk_size`` when both are specified. ``None`` disables sharding.
    write_empty_chunks : bool
        Whether to write chunks that are entirely fill-valued. Default ``True``.

    Raises
    ------
    ValueError
        If ``chunk_size`` or ``shard_size`` is not a positive integer, or if
        ``shard_size`` is not a multiple of ``chunk_size``.
    """

    compressors: Annotated[
        tuple[zarr.abc.codec.Codec, ...] | None,
        Field(description="Compressor codec(s) to apply."),
    ] = None
    filters: Annotated[
        tuple[zarr.abc.codec.Codec, ...] | None,
        Field(description="Array-to-array filter codec(s)."),
    ] = None
    serializer: Annotated[
        zarr.abc.codec.Codec | None,
        Field(description="Bytes serializer codec."),
    ] = None
    chunk_size: Annotated[
        int | None,
        Field(
            description="Chunk length along dimension 0. Other dims use full extent."
        ),
    ] = None
    shard_size: Annotated[
        int | None,
        Field(
            description=(
                "Shard length along dimension 0. "
                "When set, multiple chunks are stored in a single storage object. "
                "Must be a multiple of chunk_size when both are specified."
            ),
        ),
    ] = None
    write_empty_chunks: Annotated[
        bool,
        Field(description="Whether to write chunks that are entirely fill-valued."),
    ] = True

    # Codec instances are not pydantic-native types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @model_validator(mode="after")
    def _validate_shard_chunk_alignment(self) -> ZarrArrayConfig:
        """Validate chunk/shard sizes.

        Sizes must be positive, and ``shard_size`` must be a multiple of
        ``chunk_size`` when both are set.
        """
        # Reject non-positive sizes first: a zero chunk_size would otherwise
        # surface as a bare ZeroDivisionError from the modulo below instead of
        # a validation error, and negative sizes would pass unnoticed.
        for name in ("chunk_size", "shard_size"):
            value = getattr(self, name)
            if value is not None and value <= 0:
                msg = f"{name} ({value}) must be a positive integer"
                raise ValueError(msg)

        if self.shard_size is not None and self.chunk_size is not None:
            if self.shard_size % self.chunk_size != 0:
                msg = (
                    f"shard_size ({self.shard_size}) must be a multiple of "
                    f"chunk_size ({self.chunk_size})"
                )
                raise ValueError(msg)
        return self
class ZarrWriteConfig(BaseModel):
    """Top-level write configuration for ``AtomicDataZarrWriter``.

    Holds one :class:`ZarrArrayConfig` per store group plus optional
    per-field overrides.

    Parameters
    ----------
    meta : ZarrArrayConfig
        Config for metadata arrays (pointers, masks). Usually no compression.
    core : ZarrArrayConfig
        Config for core data arrays (positions, energy, etc.).
    custom : ZarrArrayConfig
        Config for user-added custom arrays.
    field_overrides : dict[str, ZarrArrayConfig]
        Per-field overrides. Keys are field names (e.g. ``"positions"``).
        Takes precedence over group-level config.

    Examples
    --------
    >>> from zarr.codecs import ZstdCodec, BloscCodec
    >>> config = ZarrWriteConfig(
    ...     core=ZarrArrayConfig(compressors=(ZstdCodec(level=3),), chunk_size=1024),
    ...     field_overrides={
    ...         "positions": ZarrArrayConfig(compressors=(BloscCodec(cname="lz4"),))
    ...     },
    ... )
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Group-level defaults: each group gets a fresh default ZarrArrayConfig.
    meta: ZarrArrayConfig = Field(
        default_factory=ZarrArrayConfig,
        description="Config for metadata arrays (pointers, masks).",
    )
    core: ZarrArrayConfig = Field(
        default_factory=ZarrArrayConfig,
        description="Config for core data arrays (positions, energy, etc.).",
    )
    custom: ZarrArrayConfig = Field(
        default_factory=ZarrArrayConfig,
        description="Config for user-added custom arrays.",
    )

    # Field-name keyed overrides, consulted before the group-level config.
    field_overrides: dict[str, ZarrArrayConfig] = Field(
        default_factory=dict,
        description="Per-field overrides. Takes precedence over group-level config.",
    )
def _get_field_level(key: str) -> str: """Return 'atom', 'edge', or 'system' for a core field key. Parameters ---------- key : str Field name. Returns ------- str One of 'atom', 'edge', or 'system'. """ match key: case k if k in AtomicData._default_node_keys: return "atom" case k if k in AtomicData._default_edge_keys: return "edge" case k if k in AtomicData._default_system_keys: return "system" case _: # Default to atom level for unknown keys return "atom" # NOTE: the generic *index*/*face* regex fallback returning -1 is local to # the Zarr backend. No current AtomicData edge field reaches it, and the Zarr # read paths (_slice_edge_array) reject cat_dim != 0 with a RuntimeError. def _get_cat_dim(key: str) -> int: """Return concatenation dimension for a field. Parameters ---------- key : str Field name. Returns ------- int Concatenation dimension. """ if key == "neighbor_list": return 0 if bool(re.search("(index|face)", key)): return -1 return 0 def _slice_edge_array(arr: Any, key: str, edge_start: int, edge_end: int) -> Any: """Slice an edge-level array on dim 0, rejecting non-zero cat dims. Parameters ---------- arr : Any Numpy array or zarr array to slice. key : str Field name (used for error messages and cat_dim lookup). edge_start : int Start index along the edge dimension. edge_end : int End index along the edge dimension. Returns ------- Any Sliced array ``arr[edge_start:edge_end]``. Raises ------ RuntimeError If ``_get_cat_dim(key)`` returns anything other than 0. """ cat_dim = _get_cat_dim(key) if cat_dim != 0: raise RuntimeError( f"Unexpected cat_dim={cat_dim} for edge field '{key}'. " "All edge fields should use (E, ...) layout with cat_dim=0." ) return arr[edge_start:edge_end]
class AtomicDataZarrWriter:
    """Writer for serializing AtomicData into Zarr stores.

    Writes AtomicData objects into a structured Zarr store with CSR-style
    pointer arrays for variable-size graph data. Supports single writes,
    batch writes, appending, custom fields, soft-delete, and defragmentation.

    The Zarr store layout is::

        dataset.zarr/
        ├── meta/                 # Pointer arrays + masks
        │   ├── atoms_ptr         # int64 [N+1] — cumulative node counts
        │   ├── edges_ptr         # int64 [N+1] — cumulative edge counts
        │   ├── samples_mask      # bool [N] — False = deleted sample
        │   ├── atoms_mask        # bool [V_total] — False = deleted atom
        │   └── edges_mask        # bool [E_total] — False = deleted edge
        ├── core/                 # AtomicData fields (auto-populated)
        │   ├── atomic_numbers    # int64 [V_total]
        │   ├── positions         # float32 [V_total, 3]
        │   └── ...
        ├── custom/               # User-defined arrays (optional)
        │   └── <user_key>        # any dtype, any shape
        └── .zattrs               # root metadata

    Parameters
    ----------
    store : StoreLike
        Any zarr-compatible store: filesystem path (str or Path), or a zarr
        Store instance (LocalStore, MemoryStore, FsspecStore, etc.),
        StorePath, or a dict for in-memory buffer storage.
    config : ZarrWriteConfig | Mapping[str, Any] | None
        Compression/chunking configuration. Can be a ``ZarrWriteConfig``
        instance or a dict that will be converted to one. Default is ``None``
        (use Zarr defaults).

    Attributes
    ----------
    _store : StoreLike
        The zarr store used for I/O.
    _config : ZarrWriteConfig
        The write configuration for compression and chunking.
    """

    # Batch/AtomicData attributes that are never persisted as core arrays.
    _EXCLUDED_KEYS = frozenset(
        {"batch_idx", "batch_ptr", "device", "dtype", "info", "num_nodes"}
    )

    def __init__(
        self,
        store: StoreLike,
        config: ZarrWriteConfig | Mapping[str, Any] | None = None,
    ) -> None:
        """Initialize the writer with a target store.

        Parameters
        ----------
        store : StoreLike
            Any zarr-compatible store: filesystem path (str or Path), or a
            zarr Store instance (LocalStore, MemoryStore, FsspecStore, etc.),
            StorePath, or a dict for in-memory buffer storage.
        config : ZarrWriteConfig | Mapping[str, Any] | None
            Compression/chunking configuration. Can be a ``ZarrWriteConfig``
            instance or a dict that will be converted to one. Default is
            ``None`` (use Zarr defaults).
        """
        self._store: StoreLike = store
        if isinstance(config, Mapping):
            config = ZarrWriteConfig.model_validate(config)
        if config is None:
            config = ZarrWriteConfig()
        self._config = config

    def _open(self, mode: Literal["r", "r+", "w", "w-", "a"] = "r") -> zarr.Group:
        """Open the zarr store with the given mode.

        Parameters
        ----------
        mode : Literal["r", "r+", "w", "w-", "a"]
            Zarr access mode ('r', 'r+', 'w', 'w-', 'a').

        Returns
        -------
        zarr.Group
            The opened zarr group.
        """
        return zarr.open(self._store, mode=mode)  # type: ignore[return-value]

    def _store_exists(self) -> bool:
        """Check whether the store already contains data.

        For filesystem paths, checks if the path exists. For any other store
        object, attempts to open it read-only and looks for child groups or
        arrays.

        Returns
        -------
        bool
            True if the store exists and contains data.
        """
        if isinstance(self._store, (str, Path)):
            return Path(self._store).exists()

        # Deliberately best-effort: any failure to open or list the store is
        # interpreted as "no store present".
        try:
            root = zarr.open(self._store, mode="r")  # type: ignore[call-overload]
            # If we can list any members, the store has data.
            return len(list(root.group_keys())) > 0 or len(list(root.array_keys())) > 0
        except Exception:
            return False

    def _resolve_array_kwargs(
        self, key: str, group: str, data: np.ndarray, *, cat_dim: int = 0
    ) -> dict[str, Any]:
        """Resolve compression/chunking kwargs for a ``create_array`` call.

        A per-field override (``field_overrides[key]``) replaces the
        group-level config entirely when present.

        Parameters
        ----------
        key : str
            Array name (e.g. ``"positions"``, ``"atoms_ptr"``).
        group : str
            Group name: ``"meta"``, ``"core"``, or ``"custom"``.
        data : np.ndarray
            The data to be written (used to determine chunk shape).
        cat_dim : int, optional
            The concatenation axis (variable-length dimension) for chunking.
            Defaults to 0. For ``neighbor_list`` (stored as ``[E, 2]``), use 0.

        Returns
        -------
        dict[str, Any]
            Keyword arguments to pass to ``zarr.Group.create_array``.
        """
        base_cfg: ZarrArrayConfig = getattr(self._config, group)
        cfg = self._config.field_overrides.get(key, base_cfg)

        kwargs: dict[str, Any] = {}
        if cfg.compressors is not None:
            kwargs["compressors"] = cfg.compressors
        if cfg.filters is not None:
            kwargs["filters"] = cfg.filters
        if cfg.serializer is not None:
            kwargs["serializer"] = cfg.serializer
        if cfg.chunk_size is not None:
            # Chunk only along the concatenation axis; other axes keep the
            # full extent of ``data``.
            shape = list(data.shape)
            shape[cat_dim] = cfg.chunk_size
            kwargs["chunks"] = tuple(shape)
        if cfg.shard_size is not None:
            shape = list(data.shape)
            shape[cat_dim] = cfg.shard_size
            kwargs["shards"] = tuple(shape)
        if not cfg.write_empty_chunks:
            kwargs["config"] = {"write_empty_chunks": False}
        return kwargs

    @overload
    def write(self, data: AtomicData) -> None:  # noqa: F811
        """Write a single AtomicData by wrapping it in a one-element list."""
        self.write([data])

    @overload
    def write(self, data: list[AtomicData]) -> None:  # noqa: F811
        """Write a list of AtomicData to a new Zarr store.

        The list is collated into a CPU ``Batch`` and written via the
        ``Batch`` overload.
        """
        self.write(Batch.from_data_list(data, device="cpu"))

    @overload
    def write(self, data: Batch) -> None:  # noqa: F811
        """Write a Batch to a new Zarr store.

        This is the efficient bulk-write path. Since a Batch already has all
        tensors concatenated (node/edge level) or stacked (system level),
        each field is written to zarr in a single I/O operation with no
        per-sample iteration.

        Parameters
        ----------
        data : Batch
            Batched atomic data to write.

        Raises
        ------
        FileExistsError
            If store already exists.
        ValueError
            If batch is empty.
        """
        if self._store_exists():
            raise FileExistsError(f"Zarr store already exists at {self._store}")

        num_samples = data.num_graphs
        if num_samples is None or num_samples == 0:
            raise ValueError("No data provided to write.")

        root = self._open(mode="w")
        meta_group = root.create_group("meta")
        core_group = root.create_group("core")
        root.create_group("custom")

        # Build pointer arrays directly from batch metadata — no iteration.
        nodes_tensor = torch.tensor(data.num_nodes_list, dtype=torch.long)
        if data.num_edges_list:
            edges_tensor = torch.tensor(data.num_edges_list, dtype=torch.long)
        else:
            # No edges in the data: every sample contributes zero edges.
            edges_tensor = torch.zeros(num_samples, dtype=torch.long)

        atoms_ptr = torch.cat(
            [torch.zeros(1, dtype=torch.long), torch.cumsum(nodes_tensor, dim=0)]
        )
        edges_ptr = torch.cat(
            [torch.zeros(1, dtype=torch.long), torch.cumsum(edges_tensor, dim=0)]
        )
        total_atoms = int(atoms_ptr[-1].item())
        total_edges = int(edges_ptr[-1].item())

        # Write meta arrays; all masks start fully "alive".
        meta_arrays = {
            "atoms_ptr": self._to_numpy(atoms_ptr),
            "edges_ptr": self._to_numpy(edges_ptr),
            "samples_mask": np.ones(num_samples, dtype=bool),
            "atoms_mask": np.ones(total_atoms, dtype=bool),
            "edges_mask": np.ones(total_edges, dtype=bool),
        }
        for name, array in meta_arrays.items():
            meta_group.create_array(
                name,
                data=array,
                **self._resolve_array_kwargs(name, "meta", array),
            )

        fields_metadata: dict[str, dict[str, str]] = {"core": {}, "custom": {}}

        # data.keys maps a level name to the field names present at that
        # level. The batch uses "node"/"edge"/"system"; the zarr format uses
        # "atom"/"edge"/"system".
        level_map: dict[str, str] = {}  # key -> "atom"/"edge"/"system"
        for level_name, key_set in (data.keys or {}).items():
            for k in key_set:
                level_map[k] = "atom" if level_name == "node" else level_name

        # Get all tensor attributes from the batch.
        batch_dict = data.to_dict()

        # One zarr write per field, no per-sample loop. Sorted so that array
        # creation order does not depend on set iteration order.
        for key in sorted(level_map):
            val = batch_dict.get(key)
            if val is None or not isinstance(val, torch.Tensor):
                continue
            if key in self._EXCLUDED_KEYS:
                continue

            level = level_map[key]
            fields_metadata["core"][key] = level

            # System-level tensors may carry singleton axes after the sample
            # axis (e.g. (N, 1, 3, 3)); drop them so zarr holds (N, 3, 3).
            if level == "system" and val.dim() > 2:
                while val.dim() > 2 and val.shape[1] == 1:
                    val = val.squeeze(1)

            np_val = self._to_numpy(val)
            cat_dim = _get_cat_dim(key)
            if cat_dim < 0:
                # Normalise a negative axis for the chunk-shape computation.
                cat_dim += np_val.ndim
            core_group.create_array(
                key,
                data=np_val,
                **self._resolve_array_kwargs(key, "core", np_val, cat_dim=cat_dim),
            )

        root.attrs["num_samples"] = num_samples
        root.attrs["fields"] = fields_metadata

    @dispatch
    def write(self, data: AtomicData | list[AtomicData] | Batch) -> None:  # noqa: F811
        """Write atomic data to a new Zarr store.

        Creates the Zarr store with core/, meta/, custom/ groups. Builds
        atoms_ptr, edges_ptr, and initializes all masks to True. A single
        AtomicData is wrapped in a list; a list of AtomicData is collated
        into a CPU ``Batch``; the ``Batch`` overload performs the write.

        Parameters
        ----------
        data : AtomicData | list[AtomicData] | Batch
            Data to write.

        Raises
        ------
        FileExistsError
            If store already exists.
        """
        pass

    @overload
    def append(self, data: AtomicData) -> None:  # noqa: F811
        """Append a single AtomicData to an existing Zarr store.

        While this dispatch is available for convenience, we recommend users
        to try and amortize I/O operations by packing multiple data to write,
        instead of one at a time. This can be achieved by passing either a
        ``Batch`` object, or a list of ``AtomicData`` which will automatically
        form a batch.

        Parameters
        ----------
        data : AtomicData
            Single atomic data to append.

        Raises
        ------
        FileNotFoundError
            If store does not exist.
        """
        if not self._store_exists():
            raise FileNotFoundError(f"Zarr store does not exist at {self._store}")

        root = self._open(mode="r+")
        meta_group = root["meta"]
        core_group = root["core"]

        # Read existing pointer tails.
        old_atoms_ptr = torch.from_numpy(meta_group["atoms_ptr"][:])
        old_edges_ptr = torch.from_numpy(meta_group["edges_ptr"][:])
        old_num_samples = int(root.attrs["num_samples"])
        last_atom_ptr = int(old_atoms_ptr[-1].item())
        last_edge_ptr = int(old_edges_ptr[-1].item())

        # Get counts directly from the single AtomicData.
        data_dict = data.to_dict()
        num_atoms = int(data.num_nodes)

        # The edge count is the leading dimension of neighbor_list, if present.
        neighbor_list = data_dict.get("neighbor_list")
        if neighbor_list is not None and isinstance(neighbor_list, torch.Tensor):
            num_edges = neighbor_list.shape[0]
        else:
            num_edges = 0

        # Extend pointer arrays with single new entries.
        new_atom_ptr = torch.tensor([last_atom_ptr + num_atoms], dtype=torch.long)
        new_edge_ptr = torch.tensor([last_edge_ptr + num_edges], dtype=torch.long)
        self._extend_array(meta_group["atoms_ptr"], self._to_numpy(new_atom_ptr))
        self._extend_array(meta_group["edges_ptr"], self._to_numpy(new_edge_ptr))

        # Extend masks (single sample, its atoms, its edges).
        self._extend_array(
            meta_group["samples_mask"],
            self._to_numpy(torch.ones(1, dtype=torch.bool)),
        )
        self._extend_array(
            meta_group["atoms_mask"],
            self._to_numpy(torch.ones(num_atoms, dtype=torch.bool)),
        )
        self._extend_array(
            meta_group["edges_mask"],
            self._to_numpy(torch.ones(num_edges, dtype=torch.bool)),
        )

        # Extend each existing core field.
        for key in core_group.keys():
            val = data_dict.get(key)
            if val is None or not isinstance(val, torch.Tensor):
                continue
            if key in self._EXCLUDED_KEYS:
                continue

            # Scalar (0-d) system values get a leading sample axis; every
            # other tensor is appended as-is.
            level = _get_field_level(key)
            if level == "system":
                val = val.unsqueeze(0) if val.dim() == 0 else val

            cat_dim = _get_cat_dim(key)
            self._extend_array(core_group[key], self._to_numpy(val), axis=cat_dim)

        root.attrs["num_samples"] = old_num_samples + 1

    @overload
    def append(self, data: list[AtomicData]) -> None:  # noqa: F811
        """Append a list of AtomicData to an existing Zarr store.

        The list is collated into a ``Batch`` on the device of its first
        element and appended via the ``Batch`` overload. An empty list is a
        no-op, mirroring the behaviour for an empty ``Batch``.

        Raises
        ------
        FileNotFoundError
            If store does not exist.
        """
        if not data:
            # data[0] below would raise IndexError; treat "nothing to append"
            # exactly like the empty-Batch path (existence check, then no-op).
            if not self._store_exists():
                raise FileNotFoundError(f"Zarr store does not exist at {self._store}")
            return
        device = data[0].device
        self.append(Batch.from_data_list(data, device))

    @overload
    def append(self, data: Batch) -> None:  # noqa: F811
        """Append a Batch to an existing Zarr store.

        This is the efficient bulk-append path. Since a Batch already has all
        tensors concatenated (node/edge level) or stacked (system level),
        each field is extended in a single I/O operation with no per-sample
        iteration. A batch with no graphs is a no-op.

        Parameters
        ----------
        data : Batch
            Batched atomic data to append.

        Raises
        ------
        FileNotFoundError
            If store does not exist.
        """
        if not self._store_exists():
            raise FileNotFoundError(f"Zarr store does not exist at {self._store}")

        num_samples = data.num_graphs
        if num_samples is None or num_samples == 0:
            return

        root = self._open(mode="r+")
        meta_group = root["meta"]
        core_group = root["core"]

        # Read existing pointer tails.
        old_atoms_ptr = torch.from_numpy(meta_group["atoms_ptr"][:])
        old_edges_ptr = torch.from_numpy(meta_group["edges_ptr"][:])
        old_num_samples = int(root.attrs["num_samples"])
        last_atom_ptr = int(old_atoms_ptr[-1].item())
        last_edge_ptr = int(old_edges_ptr[-1].item())

        # Compute new pointer entries from batch metadata.
        nodes_tensor = torch.tensor(data.num_nodes_list, dtype=torch.long)
        if data.num_edges_list:
            edges_tensor = torch.tensor(data.num_edges_list, dtype=torch.long)
        else:
            # No edges in the data: every sample contributes zero edges.
            edges_tensor = torch.zeros(num_samples, dtype=torch.long)

        # New pointers continue from the last stored offsets.
        new_atoms_ptr = last_atom_ptr + torch.cumsum(nodes_tensor, dim=0)
        new_edges_ptr = last_edge_ptr + torch.cumsum(edges_tensor, dim=0)
        new_total_atoms = int(new_atoms_ptr[-1].item())
        new_total_edges = int(new_edges_ptr[-1].item())

        # Extend pointer arrays.
        self._extend_array(meta_group["atoms_ptr"], self._to_numpy(new_atoms_ptr))
        self._extend_array(meta_group["edges_ptr"], self._to_numpy(new_edges_ptr))

        # Extend masks.
        self._extend_array(
            meta_group["samples_mask"],
            self._to_numpy(torch.ones(num_samples, dtype=torch.bool)),
        )
        self._extend_array(
            meta_group["atoms_mask"],
            self._to_numpy(
                torch.ones(new_total_atoms - last_atom_ptr, dtype=torch.bool)
            ),
        )
        self._extend_array(
            meta_group["edges_mask"],
            self._to_numpy(
                torch.ones(new_total_edges - last_edge_ptr, dtype=torch.bool)
            ),
        )

        # Get all tensor attributes from the batch.
        batch_dict = data.to_dict()

        # Extend each field — single I/O per field.
        for key in core_group.keys():
            val = batch_dict.get(key)
            if val is None or not isinstance(val, torch.Tensor):
                continue
            if key in self._EXCLUDED_KEYS:
                continue

            # Same singleton-axis squeeze as in write(), so shapes line up.
            level = _get_field_level(key)
            if level == "system" and val.dim() > 2:
                while val.dim() > 2 and val.shape[1] == 1:
                    val = val.squeeze(1)

            cat_dim = _get_cat_dim(key)
            self._extend_array(core_group[key], self._to_numpy(val), axis=cat_dim)

        root.attrs["num_samples"] = old_num_samples + num_samples

    @dispatch
    def append(self, data: AtomicData | list[AtomicData] | Batch) -> None:  # noqa: F811
        """Append data to an existing Zarr store.

        Extends all arrays along concatenation axis. Extends pointer arrays
        and masks. Updates num_samples in .zattrs.

        Parameters
        ----------
        data : AtomicData | list[AtomicData] | Batch
            Data to append.

        Raises
        ------
        FileNotFoundError
            If store does not exist.
        """
        pass

    def add_custom(
        self, key: str, data: torch.Tensor, level: Literal["atom", "edge", "system"]
    ) -> None:
        """Add a custom array to the custom/ group.

        Parameters
        ----------
        key : str
            Name for the custom array.
        data : torch.Tensor
            Tensor data. First dimension must match:

            - num_samples for "system" level
            - total atoms for "atom" level
            - total edges for "edge" level
        level : str
            One of "atom", "edge", "system".

        Raises
        ------
        ValueError
            If level is invalid or data shape doesn't match.
        FileNotFoundError
            If store does not exist.
        """
        if level not in ("atom", "edge", "system"):
            raise ValueError(
                f"Invalid level '{level}'. Must be 'atom', 'edge', or 'system'."
            )
        if not self._store_exists():
            raise FileNotFoundError(f"Zarr store does not exist at {self._store}")

        root = self._open(mode="r+")
        meta_group = root["meta"]
        custom_group = root["custom"]

        # Validate shape: the last pointer entry is the running total.
        num_samples = int(root.attrs["num_samples"])
        atoms_ptr = meta_group["atoms_ptr"][:]
        edges_ptr = meta_group["edges_ptr"][:]
        total_atoms = int(atoms_ptr[-1])
        total_edges = int(edges_ptr[-1])

        expected_size = {
            "system": num_samples,
            "atom": total_atoms,
            "edge": total_edges,
        }[level]
        if data.shape[0] != expected_size:
            raise ValueError(
                f"Data shape[0]={data.shape[0]} does not match expected "
                f"size={expected_size} for level='{level}'."
            )

        # Write to custom group (convert to numpy at zarr boundary).
        np_data = self._to_numpy(data)
        custom_group.create_array(
            key, data=np_data, **self._resolve_array_kwargs(key, "custom", np_data)
        )

        # Record the level so delete()/defragment() know how to slice it.
        fields_metadata = dict(root.attrs.get("fields", {"core": {}, "custom": {}}))
        if "custom" not in fields_metadata:
            fields_metadata["custom"] = {}
        fields_metadata["custom"][key] = level
        root.attrs["fields"] = fields_metadata

    def delete(self, indices: list[int] | torch.Tensor) -> None:
        """Soft-delete samples by index.

        Sets masks to False and zeros out data slices in core/ and custom/.
        Pointer arrays are NOT modified. All indices are validated before
        anything is modified, so an invalid index leaves the store untouched.

        Parameters
        ----------
        indices : list[int] | torch.Tensor
            Sample indices to delete. Negative indices count from the end,
            as in Python sequence indexing.

        Raises
        ------
        FileNotFoundError
            If store does not exist.
        IndexError
            If any index is out of range for the number of stored samples.
        """
        if not self._store_exists():
            raise FileNotFoundError(f"Zarr store does not exist at {self._store}")

        # Flatten so that 0-d tensors are handled the same as 1-d ones.
        indices_tensor = torch.as_tensor(indices, dtype=torch.long).reshape(-1)
        if indices_tensor.numel() == 0:
            return

        root = self._open(mode="r+")
        meta_group = root["meta"]
        core_group = root["core"]

        atoms_ptr = meta_group["atoms_ptr"][:]
        edges_ptr = meta_group["edges_ptr"][:]
        samples_mask = meta_group["samples_mask"][:]
        atoms_mask = meta_group["atoms_mask"][:]
        edges_mask = meta_group["edges_mask"][:]
        fields_metadata = dict(root.attrs.get("fields", {"core": {}, "custom": {}}))

        # Validate and normalise every index BEFORE mutating anything. A raw
        # negative index would flip samples_mask[idx] but yield an inverted
        # pointer range (ptr[idx], ptr[idx + 1]), and an out-of-range index
        # would abort after data had been zeroed but before the masks were
        # written back.
        num_samples = int(samples_mask.shape[0])
        sample_ids: list[int] = []
        for raw in indices_tensor.tolist():
            idx = raw + num_samples if raw < 0 else raw
            if not 0 <= idx < num_samples:
                raise IndexError(
                    f"Sample index {raw} is out of range for a store with "
                    f"{num_samples} samples."
                )
            sample_ids.append(idx)
        # Deleting is idempotent; drop duplicates to avoid redundant writes.
        sample_ids = list(dict.fromkeys(sample_ids))

        # Resolve (array, level, axis) once rather than once per sample.
        core_levels = fields_metadata.get("core", {})
        core_targets = [
            (
                core_group[key],
                core_levels.get(key, _get_field_level(key)),
                _get_cat_dim(key),
            )
            for key in core_group.keys()
        ]
        custom_targets: list[tuple[Any, str]] = []
        if "custom" in root:
            custom_group = root["custom"]
            custom_levels = fields_metadata.get("custom", {})
            custom_targets = [
                (custom_group[key], custom_levels.get(key, "system"))
                for key in custom_group.keys()
            ]

        for idx in sample_ids:
            # Mark sample as deleted.
            samples_mask[idx] = False

            # Get slice ranges.
            atom_start, atom_end = int(atoms_ptr[idx]), int(atoms_ptr[idx + 1])
            edge_start, edge_end = int(edges_ptr[idx]), int(edges_ptr[idx + 1])

            atoms_mask[atom_start:atom_end] = False
            edges_mask[edge_start:edge_end] = False

            # Zero out core fields.
            for arr, level, cat_dim in core_targets:
                if level == "atom":
                    self._zero_slice(arr, atom_start, atom_end, axis=0)
                elif level == "edge":
                    self._zero_slice(arr, edge_start, edge_end, axis=cat_dim)
                elif level == "system":
                    self._zero_slice(arr, idx, idx + 1, axis=0)

            # Zero out custom fields (always sliced along axis 0).
            for arr, level in custom_targets:
                if level == "atom":
                    self._zero_slice(arr, atom_start, atom_end, axis=0)
                elif level == "edge":
                    self._zero_slice(arr, edge_start, edge_end, axis=0)
                elif level == "system":
                    self._zero_slice(arr, idx, idx + 1, axis=0)

        # Write back masks.
        meta_group["samples_mask"][:] = samples_mask
        meta_group["atoms_mask"][:] = atoms_mask
        meta_group["edges_mask"][:] = edges_mask

    def defragment(
        self, config: ZarrWriteConfig | Mapping[str, Any] | None = None
    ) -> None:
        """Rewrite store excluding deleted samples.

        Rebuilds all arrays, pointer arrays, and resets all masks to True.
        All existing arrays are read into memory before the store is
        overwritten.

        Parameters
        ----------
        config : ZarrWriteConfig | Mapping[str, Any] | None
            Optional new write configuration for the rebuilt arrays. When
            provided, also updates the writer's stored config for future
            operations. When ``None``, reuses the existing writer config.

        Raises
        ------
        FileNotFoundError
            If store does not exist.
        RuntimeError
            If an edge-level field resolves to a non-zero concatenation axis.
        """
        if config is not None:
            if isinstance(config, Mapping):
                config = ZarrWriteConfig.model_validate(config)
            self._config = config

        if not self._store_exists():
            raise FileNotFoundError(f"Zarr store does not exist at {self._store}")

        root = self._open(mode="r")
        meta_group = root["meta"]
        core_group = root["core"]

        atoms_ptr = meta_group["atoms_ptr"][:]
        edges_ptr = meta_group["edges_ptr"][:]
        samples_mask = meta_group["samples_mask"][:]
        fields_metadata = dict(root.attrs.get("fields", {"core": {}, "custom": {}}))

        # Find active sample indices.
        active_indices = np.where(samples_mask)[0]

        if len(active_indices) == 0:
            # All samples deleted: overwrite with an empty skeleton. Only the
            # groups and root attrs are recreated here; no meta/core arrays.
            new_root = self._open(mode="w")
            new_root.create_group("meta")
            new_root.create_group("core")
            new_root.create_group("custom")
            new_root.attrs["num_samples"] = 0
            new_root.attrs["fields"] = {"core": {}, "custom": {}}
            return

        # Collect active data for each field.
        new_core_data: dict[str, list[np.ndarray]] = {
            key: [] for key in core_group.keys()
        }
        new_custom_data: dict[str, list[np.ndarray]] = {}
        if "custom" in root:
            custom_group = root["custom"]
            new_custom_data = {key: [] for key in custom_group.keys()}

        new_num_nodes: list[int] = []
        new_num_edges: list[int] = []

        # Pre-read all arrays once to avoid re-reading per sample.
        core_arrays = {key: core_group[key][:] for key in core_group.keys()}
        custom_arrays: dict[str, np.ndarray] = {}
        if "custom" in root:
            custom_arrays = {
                key: root["custom"][key][:] for key in root["custom"].keys()
            }

        for idx in active_indices:
            idx = int(idx)
            atom_start, atom_end = int(atoms_ptr[idx]), int(atoms_ptr[idx + 1])
            edge_start, edge_end = int(edges_ptr[idx]), int(edges_ptr[idx + 1])

            new_num_nodes.append(atom_end - atom_start)
            new_num_edges.append(edge_end - edge_start)

            for key in core_group.keys():
                level = fields_metadata.get("core", {}).get(key, _get_field_level(key))
                arr = core_arrays[key]
                if level == "atom":
                    new_core_data[key].append(arr[atom_start:atom_end])
                elif level == "edge":
                    new_core_data[key].append(
                        _slice_edge_array(arr, key, edge_start, edge_end)
                    )
                elif level == "system":
                    # System level: one row per sample.
                    new_core_data[key].append(arr[idx : idx + 1])

            if custom_arrays:
                for key in custom_arrays:
                    level = fields_metadata.get("custom", {}).get(key, "system")
                    arr = custom_arrays[key]
                    if level == "atom":
                        new_custom_data[key].append(arr[atom_start:atom_end])
                    elif level == "edge":
                        new_custom_data[key].append(
                            _slice_edge_array(arr, key, edge_start, edge_end)
                        )
                    elif level == "system":
                        new_custom_data[key].append(arr[idx : idx + 1])

        # Clear store and create new structure (mode="w" clears existing data).
        new_root = self._open(mode="w")
        new_meta = new_root.create_group("meta")
        new_core = new_root.create_group("core")
        new_custom = new_root.create_group("custom")

        # Build new pointer arrays.
        new_atoms_ptr = np.array([0] + list(np.cumsum(new_num_nodes)), dtype=np.int64)
        new_edges_ptr = np.array([0] + list(np.cumsum(new_num_edges)), dtype=np.int64)

        new_total_atoms = int(new_atoms_ptr[-1])
        new_total_edges = int(new_edges_ptr[-1])
        new_num_samples = len(active_indices)

        # Every surviving sample/atom/edge is alive again.
        new_meta_arrays = {
            "atoms_ptr": new_atoms_ptr,
            "edges_ptr": new_edges_ptr,
            "samples_mask": np.ones(new_num_samples, dtype=np.bool_),
            "atoms_mask": np.ones(new_total_atoms, dtype=np.bool_),
            "edges_mask": np.ones(new_total_edges, dtype=np.bool_),
        }
        for name, array in new_meta_arrays.items():
            new_meta.create_array(
                name,
                data=array,
                **self._resolve_array_kwargs(name, "meta", array),
            )

        # Concatenate and write core arrays.
        # NOTE(review): per-sample pieces above are always sliced along axis 0
        # but are joined along _get_cat_dim(key); these agree for 1-D arrays
        # and whenever cat_dim is 0 — confirm no multi-dimensional non-edge
        # field name matches the index/face pattern.
        for key, arrays in new_core_data.items():
            if arrays:
                cat_dim = _get_cat_dim(key)
                concatenated = np.concatenate(arrays, axis=cat_dim)
                resolved_cat_dim = (
                    cat_dim if cat_dim >= 0 else cat_dim + concatenated.ndim
                )
                new_core.create_array(
                    key,
                    data=concatenated,
                    **self._resolve_array_kwargs(
                        key, "core", concatenated, cat_dim=resolved_cat_dim
                    ),
                )

        # Concatenate and write custom arrays.
        for key, arrays in new_custom_data.items():
            if arrays:
                concatenated = np.concatenate(arrays, axis=0)
                new_custom.create_array(
                    key,
                    data=concatenated,
                    **self._resolve_array_kwargs(key, "custom", concatenated),
                )

        # Update metadata.
        new_root.attrs["num_samples"] = new_num_samples
        new_root.attrs["fields"] = fields_metadata

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Convert a torch tensor to numpy for zarr I/O.

        The tensor is detached and moved to CPU before conversion.

        Parameters
        ----------
        tensor : torch.Tensor
            Tensor to convert.

        Returns
        -------
        np.ndarray
            Numpy array for zarr storage.
        """
        return tensor.detach().cpu().numpy()

    @staticmethod
    def _extend_array(arr: zarr.Array, data: np.ndarray, axis: int = 0) -> None:
        """Extend a zarr array along an axis.

        The array is resized in place and ``data`` is written into the newly
        added region.

        Parameters
        ----------
        arr : zarr.Array
            Zarr array to extend.
        data : np.ndarray
            Data to append.
        axis : int
            Axis along which to extend.
        """
        old_shape = arr.shape
        new_len = data.shape[axis]

        # Build new shape.
        new_shape = list(old_shape)
        new_shape[axis] = old_shape[axis] + new_len

        # Resize array.
        arr.resize(tuple(new_shape))

        # Write new data into the tail region only.
        slices: list[slice | int] = [slice(None)] * len(old_shape)
        slices[axis] = slice(old_shape[axis], new_shape[axis])
        arr[tuple(slices)] = data

    @staticmethod
    def _zero_slice(arr: zarr.Array, start: int, end: int, axis: int = 0) -> None:
        """Zero out a slice of a zarr array.

        Does nothing when the range is empty (``start >= end``).

        Parameters
        ----------
        arr : zarr.Array
            Zarr array to modify.
        start : int
            Start index.
        end : int
            End index.
        axis : int
            Axis along which to slice.
        """
        if start >= end:
            return

        slices: list[slice | int] = [slice(None)] * len(arr.shape)
        slices[axis] = slice(start, end)

        # Zeros matching the selected region's shape and the array's dtype.
        shape = list(arr.shape)
        shape[axis] = end - start
        zeros = np.zeros(shape, dtype=arr.dtype)
        arr[tuple(slices)] = zeros
class AtomicDataZarrReader(Reader):
    """Reader for loading AtomicData from Zarr stores.

    This reader provides random-access loading of AtomicData samples from
    Zarr stores created by :class:`AtomicDataZarrWriter`. It supports
    soft-deleted samples via the samples_mask and provides efficient
    random access using pointer arrays.

    The Zarr store layout expected is::

        dataset.zarr/
        ├── meta/                  # Pointer arrays + masks
        │   ├── atoms_ptr          # int64 [N+1] — cumulative node counts
        │   ├── edges_ptr          # int64 [N+1] — cumulative edge counts
        │   └── samples_mask       # bool [N] — False = deleted sample
        ├── core/                  # AtomicData fields
        │   ├── atomic_numbers     # int64 [V_total]
        │   ├── positions          # float32 [V_total, 3]
        │   └── ...
        └── custom/                # User-defined arrays (optional)

    Parameters
    ----------
    store : StoreLike
        Any zarr-compatible store: filesystem path (str or Path), or a zarr
        Store instance (LocalStore, MemoryStore, FsspecStore, etc.),
        StorePath, or a dict for in-memory buffer storage. URL-style
        strings (e.g. ``"s3://bucket/data.zarr"``) are handed to zarr
        without a local existence check.
    pin_memory : bool, default=False
        If True, place tensors in pinned (page-locked) memory for faster
        async CPU→GPU transfers.
    include_index_in_metadata : bool, default=True
        If True, include sample index in the metadata dict.

    Attributes
    ----------
    _store : StoreLike
        The underlying zarr store reference.

    Examples
    --------
    >>> from nvalchemi.data.datapipes.backends.zarr import AtomicDataZarrReader  # doctest: +SKIP
    >>> reader = AtomicDataZarrReader(store="dataset.zarr")  # doctest: +SKIP
    >>> data_dict, metadata = reader[0]  # returns dict and metadata  # doctest: +SKIP
    >>> atomic_data = AtomicDataZarrReader.to_atomic_data(data_dict)  # doctest: +SKIP
    """

    def __init__(
        self,
        store: StoreLike,
        *,
        pin_memory: bool = False,
        include_index_in_metadata: bool = True,
    ) -> None:
        """Initialize the reader with a Zarr store.

        Parameters
        ----------
        store : StoreLike
            Any zarr-compatible store: filesystem path (str or Path), or a
            zarr Store instance (LocalStore, MemoryStore, FsspecStore,
            etc.), StorePath, or a dict for in-memory buffer storage.
        pin_memory : bool, default=False
            If True, place tensors in pinned (page-locked) memory.
        include_index_in_metadata : bool, default=True
            If True, include sample index in the metadata dict.

        Raises
        ------
        FileNotFoundError
            If the Zarr store does not exist (for filesystem paths).
        ValueError
            If the store is missing required groups (meta, core).
        """
        super().__init__(
            pin_memory=pin_memory,
            include_index_in_metadata=include_index_in_metadata,
        )
        self._store: StoreLike = store

        # Friendly existence check, restricted to genuine filesystem paths so
        # that remote URLs are not rejected before zarr can resolve them.
        if self._is_local_path(store) and not Path(store).exists():
            raise FileNotFoundError(f"Zarr store does not exist at {store}")

        # Open the Zarr store in read mode
        self._root = zarr.open(self._store, mode="r")

        # Validate store structure
        for group_name in ("meta", "core"):
            if group_name not in self._root:
                raise ValueError(
                    f"Zarr store at {self._store} is missing '{group_name}' group"
                )

        # Load cached state from the store
        self.refresh()

    @staticmethod
    def _is_local_path(store: StoreLike) -> bool:
        """Return True if ``store`` should be treated as a filesystem path.

        ``Path`` instances always qualify. Strings qualify unless they look
        like an fsspec URL (``scheme://...``) or a chained URL (``a::b``);
        zarr is expected to resolve those itself.
        """
        if isinstance(store, Path):
            return True
        return isinstance(store, str) and "://" not in store and "::" not in store

    def refresh(self) -> None:
        """Reload cached pointer arrays, masks, and metadata from the store.

        Call this method after external modifications to the Zarr store
        (e.g., appending or deleting samples via
        :class:`AtomicDataZarrWriter`) to ensure the reader reflects the
        current state of the data.

        Raises
        ------
        RuntimeError
            If the reader has been closed.
        """
        if self._root is None:
            raise RuntimeError("Cannot refresh a closed reader.")

        # Re-open the store to pick up structural changes
        self._root = zarr.open(self._store, mode="r")
        meta_group = self._root["meta"]

        # Cache pointer arrays as torch tensors
        self._atoms_ptr = torch.from_numpy(meta_group["atoms_ptr"][:]).to(torch.long)
        self._edges_ptr = torch.from_numpy(meta_group["edges_ptr"][:]).to(torch.long)

        # Cache samples mask
        self._samples_mask = torch.from_numpy(meta_group["samples_mask"][:]).to(
            torch.bool
        )

        # Build logical->physical index mapping (indices where mask is True)
        self._active_indices = torch.where(self._samples_mask)[0]

        # Cache fields metadata
        self._fields_metadata: dict[str, dict[str, str]] = dict(
            self._root.attrs.get("fields", {"core": {}, "custom": {}})
        )

    @staticmethod
    def _read_field(
        arr: Any,
        key: str,
        level: str,
        physical_idx: int,
        atom_span: slice,
        edge_span: slice,
    ) -> torch.Tensor:
        """Slice one sample's portion of ``arr`` according to its level.

        ``"atom"`` fields are cut by the atom range, ``"edge"`` fields are
        cut through ``_slice_edge_array`` with the edge range, and any other
        level is treated as system-level: a single row is taken at
        ``physical_idx`` with its leading dimension kept.
        """
        if level == "atom":
            return torch.from_numpy(arr[atom_span])
        if level == "edge":
            return torch.from_numpy(
                _slice_edge_array(arr, key, edge_span.start, edge_span.stop)
            )
        # System level: keep batch dim
        return torch.from_numpy(arr[physical_idx : physical_idx + 1])

    def _load_sample(self, index: int) -> dict[str, torch.Tensor]:
        """Load raw data for a single sample.

        Parameters
        ----------
        index : int
            Logical sample index (0 to len-1), accounting for deleted
            samples.

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary mapping field names to CPU tensors.

        Raises
        ------
        IndexError
            If index is out of range.
        RuntimeError
            If the reader has been closed.
        """
        if self._root is None:
            raise RuntimeError("Cannot read from a closed reader.")

        # Map logical index to physical index
        physical_idx = int(self._active_indices[index].item())

        # Get slice ranges from pointer arrays
        atom_start = int(self._atoms_ptr[physical_idx].item())
        atom_span = slice(atom_start, int(self._atoms_ptr[physical_idx + 1].item()))
        edge_span = slice(
            int(self._edges_ptr[physical_idx].item()),
            int(self._edges_ptr[physical_idx + 1].item()),
        )

        data: dict[str, torch.Tensor] = {}

        # Load core fields
        core_group = self._root["core"]
        core_levels = self._fields_metadata.get("core", {})
        for key in core_group.array_keys():
            level = core_levels.get(key, _get_field_level(key))
            tensor = self._read_field(
                core_group[key], key, level, physical_idx, atom_span, edge_span
            )
            # neighbor_list needs to be converted from global to local indices
            # by subtracting the atom offset for this sample
            if level == "edge" and key == "neighbor_list":
                tensor = tensor - atom_start
            data[key] = tensor

        # Load custom fields if present
        if "custom" in self._root:
            custom_group = self._root["custom"]
            custom_levels = self._fields_metadata.get("custom", {})
            for key in custom_group.array_keys():
                # Custom fields without recorded metadata default to system level
                level = custom_levels.get(key, "system")
                data[key] = self._read_field(
                    custom_group[key], key, level, physical_idx, atom_span, edge_span
                )

        return data

    def __len__(self) -> int:
        """Return the number of active (non-deleted) samples.

        Returns
        -------
        int
            Number of samples available for reading.
        """
        return len(self._active_indices)

    def _get_sample_metadata(self, index: int) -> dict[str, str]:
        """Return metadata for a sample.

        Parameters
        ----------
        index : int
            Logical sample index.

        Returns
        -------
        dict[str, str]
            Dictionary with the stringified store (``source_file``) and the
            physical index of the sample in the store (``physical_index``).
        """
        physical_idx = int(self._active_indices[index].item())
        return {
            "source_file": str(self._store),
            "physical_index": str(physical_idx),
        }

    def close(self) -> None:
        """Release the Zarr store reference and clean up resources."""
        # A None root marks the reader as closed for later calls.
        self._root = None
        super().close()