# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Zarr backend for AtomicData (de)serialization.
This module provides the concrete implementation of ``AtomicData``
(de)serialization using high performance ``zarr`` array I/O.
The ``AtomicDataZarrWriter`` class is designed to allow for efficient,
amortized data writes with the ability to directly save/append ``Batch``
objects to disk.
The ``AtomicDataZarrReader`` provides a concrete ``Reader`` implementation
that reads arrays from disk and maps them to ``torch.Tensor`` objects that
are intended to be composed with :class:`nvalchemi.data.datapipes.Dataset`.
To understand usage, users should refer to ``examples/data/datapipes/read_zarr_store.py``.
"""
from __future__ import annotations
import re
from pathlib import Path
from typing import Any, Literal, TypeAlias
import numpy as np
import torch
import zarr
from plum import dispatch, overload
from zarr.abc.store import Store
from zarr.storage import StorePath
# These need to be available at runtime for plum dispatch
from nvalchemi.data.atomic_data import AtomicData
from nvalchemi.data.batch import Batch
from nvalchemi.data.datapipes.backends.base import Reader
# Type alias for zarr store-like objects accepted by the reader and writer
# below: a zarr ``Store`` or ``StorePath``, a filesystem path (``str`` or
# ``Path``), or a ``dict`` used as an in-memory buffer. The value is handed
# to ``zarr.open`` unchanged.
StoreLike: TypeAlias = Store | StorePath | Path | str | dict[str, Any]
# TODO: make classes inherit from PNM when stable
def _get_field_level(key: str) -> str:
    """Classify a core field key by the entity it is attached to.

    The key is looked up in ``AtomicData.__node_keys__``, then
    ``AtomicData.__edge_keys__``, then ``AtomicData.__system_keys__``; the
    first collection containing it decides the level.

    Parameters
    ----------
    key : str
        Field name.

    Returns
    -------
    str
        ``"atom"`` for node keys, ``"edge"`` for edge keys and ``"system"``
        for system keys. Keys found in none of them are reported as
        ``"atom"``.
    """
    if key in AtomicData.__node_keys__:
        return "atom"
    if key in AtomicData.__edge_keys__:
        return "edge"
    if key in AtomicData.__system_keys__:
        return "system"
    # Unrecognized keys fall back to the per-atom level.
    return "atom"
def _get_cat_dim(key: str) -> int:
"""Return concatenation dimension for a field.
Returns -1 for keys containing 'index' or 'face', 0 otherwise.
(Mirrors DataMixin.__cat_dim__)
Parameters
----------
key : str
Field name.
Returns
-------
int
Concatenation dimension.
"""
if bool(re.search("(index|face)", key)):
return -1
return 0
class AtomicDataZarrWriter:
    """Writer for serializing AtomicData into Zarr stores.

    Writes AtomicData objects into a structured Zarr store with CSR-style
    pointer arrays for variable-size graph data. Supports single writes,
    batch writes, appending, custom fields, soft-delete, and defragmentation.

    The Zarr store layout is::

        dataset.zarr/
        ├── meta/                  # Pointer arrays + masks
        │   ├── atoms_ptr          # int64 [N+1] — cumulative node counts
        │   ├── edges_ptr          # int64 [N+1] — cumulative edge counts
        │   ├── samples_mask       # bool [N] — False = deleted sample
        │   ├── atoms_mask         # bool [V_total] — False = deleted atom
        │   └── edges_mask         # bool [E_total] — False = deleted edge
        │
        ├── core/                  # AtomicData fields (auto-populated)
        │   ├── atomic_numbers     # int64 [V_total]
        │   ├── positions          # float32 [V_total, 3]
        │   └── ...
        │
        ├── custom/                # User-defined arrays (optional)
        │   └── <user_key>         # any dtype, any shape
        │
        └── .zattrs                # root metadata

    Atom- and system-level arrays grow along axis 0. Edge-level core arrays
    grow along the axis given by ``_get_cat_dim`` (the last axis for
    ``edge_index``, which is stored as ``(2, E_total)``). Custom edge-level
    arrays grow along axis 0. ``edge_index`` holds store-global atom indices,
    i.e. each sample's local indices shifted by its ``atoms_ptr`` entry; this
    is what :class:`AtomicDataZarrReader` subtracts when loading a sample.

    Parameters
    ----------
    store : StoreLike
        Any zarr-compatible store: filesystem path (str or Path), or a zarr
        Store instance (LocalStore, MemoryStore, FsspecStore, etc.), StorePath,
        or a dict for in-memory buffer storage.

    Attributes
    ----------
    _store : StoreLike
        The zarr store used for I/O.
    """

    # Batch/AtomicData attributes that are bookkeeping rather than data
    # fields; they are never written to, or extended in, core/.
    _EXCLUDED_KEYS = frozenset({"batch", "ptr", "device", "dtype", "info", "num_nodes"})

    def __init__(self, store: StoreLike) -> None:
        """Initialize the writer with a target store.

        Parameters
        ----------
        store : StoreLike
            Any zarr-compatible store: filesystem path (str or Path), or a zarr
            Store instance (LocalStore, MemoryStore, FsspecStore, etc.),
            StorePath, or a dict for in-memory buffer storage.
        """
        self._store: StoreLike = store

    def _open(self, mode: Literal["r", "r+", "w", "w-", "a"] = "r") -> zarr.Group:
        """Open the zarr store with the given mode.

        Parameters
        ----------
        mode : Literal["r", "r+", "w", "w-", "a"]
            Zarr access mode ('r', 'r+', 'w', 'w-', 'a').

        Returns
        -------
        zarr.Group
            The opened zarr group.
        """
        return zarr.open(self._store, mode=mode)  # type: ignore[return-value]

    def _store_exists(self) -> bool:
        """Check whether the store already contains data.

        Local filesystem paths (``Path``, or ``str`` without a ``"://"``
        scheme) are checked with ``Path.exists``. Every other store (zarr
        Store objects, dicts, URL strings such as ``s3://`` or
        ``memory://``) is opened read-only and checked for members.

        Returns
        -------
        bool
            True if the store exists and contains data.
        """
        store = self._store
        if isinstance(store, Path) or (isinstance(store, str) and "://" not in store):
            return Path(store).exists()
        # Deliberate best-effort probe: any failure to open the store read-only
        # (missing store, wrong node type, ...) is treated as "no data".
        try:
            root = zarr.open(store, mode="r")  # type: ignore[call-overload]
            return len(list(root.group_keys())) > 0 or len(list(root.array_keys())) > 0
        except Exception:
            return False

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _read_fields_metadata(root: zarr.Group) -> dict[str, dict[str, str]]:
        """Return a mutable copy of the root ``fields`` attribute.

        The result always contains ``"core"`` and ``"custom"`` mappings of
        field name to level.
        """
        raw = dict(root.attrs.get("fields", {}))
        fields = {group: dict(levels) for group, levels in raw.items()}
        fields.setdefault("core", {})
        fields.setdefault("custom", {})
        return fields

    @staticmethod
    def _core_level(fields_metadata: dict[str, dict[str, str]], key: str) -> str:
        """Return the level of a core field: stored metadata first, then name."""
        return fields_metadata.get("core", {}).get(key, _get_field_level(key))

    @staticmethod
    def _concat_axis(key: str, level: str) -> int:
        """Return the axis along which core field ``key`` grows with samples.

        Only edge-level fields use ``_get_cat_dim`` (so ``edge_index`` grows
        along its last axis). Atom- and system-level fields always grow along
        axis 0, even when their name contains "index" or "face".
        """
        return _get_cat_dim(key) if level == "edge" else 0

    @staticmethod
    def _counts_to_ptr(counts: np.ndarray, start: int = 0) -> np.ndarray:
        """Return the int64 CSR pointer ``[start, start + c0, start + c0 + c1, ...]``."""
        return np.concatenate(
            (np.array([start], dtype=np.int64), start + np.cumsum(counts, dtype=np.int64))
        )

    @staticmethod
    def _batch_counts(data: Batch, num_samples: int) -> tuple[np.ndarray, np.ndarray]:
        """Return per-sample node and edge counts of a Batch as int64 arrays.

        When the batch has no ``num_edges_list`` (no edges), every sample is
        given an edge count of zero.
        """
        node_counts = torch.tensor(data.num_nodes_list, dtype=torch.long).numpy()
        if data.num_edges_list:
            edge_counts = torch.tensor(data.num_edges_list, dtype=torch.long).numpy()
        else:
            edge_counts = np.zeros(num_samples, dtype=np.int64)
        return node_counts, edge_counts

    @classmethod
    def _batch_field_to_numpy(
        cls, key: str, val: torch.Tensor, level: str, atom_offset: int = 0
    ) -> np.ndarray:
        """Convert one Batch field tensor into its on-disk layout.

        Parameters
        ----------
        key : str
            Field name.
        val : torch.Tensor
            The field as stored on the Batch.
        level : str
            ``"atom"``, ``"edge"`` or ``"system"``.
        atom_offset : int
            Number of atoms already in the store before this batch. It is
            added to ``edge_index`` so the stored indices stay store-global.

        Returns
        -------
        np.ndarray
            Array ready to be written or appended.
        """
        if key == "edge_index":
            # Batch stores edge_index as (num_edges, 2); zarr format is (2, num_edges).
            val = val.transpose(0, 1)
            if atom_offset:
                val = val + atom_offset
        if level == "system":
            # Batch stacks (1, 3, 3) -> (N, 1, 3, 3); squeeze dim 1 so zarr has (N, 3, 3).
            while val.dim() > 2 and val.shape[1] == 1:
                val = val.squeeze(1)
        return cls._to_numpy(val)

    @staticmethod
    def _select_along(arr: np.ndarray, keep: np.ndarray, axis: int) -> np.ndarray:
        """Return the entries of ``arr`` along ``axis`` where boolean ``keep`` is True."""
        selector: list[slice | np.ndarray] = [slice(None)] * arr.ndim
        selector[axis] = keep
        return arr[tuple(selector)]

    # ------------------------------------------------------------------ #
    # write
    # ------------------------------------------------------------------ #

    @overload
    def write(self, data: AtomicData) -> None:  # noqa: F811
        """Write a single AtomicData to a new Zarr store."""
        self.write([data])

    @overload
    def write(self, data: list[AtomicData]) -> None:  # noqa: F811
        """Write a list of AtomicData to a new Zarr store.

        The list is collated into a CPU ``Batch`` and written in bulk.

        Raises
        ------
        ValueError
            If ``data`` is empty.
        """
        if not data:
            raise ValueError("No data provided to write.")
        self.write(Batch.from_data_list(data, device="cpu"))

    @overload
    def write(self, data: Batch) -> None:  # noqa: F811
        """Write a Batch to a new Zarr store.

        This is the efficient bulk-write path. Since a Batch already has
        all tensors concatenated (node/edge level) or stacked (system level),
        each field is written to zarr in a single I/O operation with no
        per-sample iteration. ``edge_index`` is transposed to ``(2, E)`` and
        written as-is. Its values are assumed to already be batch-global atom
        indices, which coincide with store-global indices for a fresh store.

        Parameters
        ----------
        data : Batch
            Batched atomic data to write.

        Raises
        ------
        FileExistsError
            If store already exists.
        ValueError
            If batch is empty.
        """
        if self._store_exists():
            raise FileExistsError(f"Zarr store already exists at {self._store}")
        num_samples = data.num_graphs
        if num_samples is None or num_samples == 0:
            raise ValueError("No data provided to write.")
        num_samples = int(num_samples)

        # Build pointer arrays before touching the store, so a failure here
        # does not leave a truncated store behind.
        node_counts, edge_counts = self._batch_counts(data, num_samples)
        atoms_ptr = self._counts_to_ptr(node_counts)
        edges_ptr = self._counts_to_ptr(edge_counts)
        total_atoms = int(atoms_ptr[-1])
        total_edges = int(edges_ptr[-1])

        root = self._open(mode="w")
        meta_group = root.create_group("meta")
        core_group = root.create_group("core")
        root.create_group("custom")

        meta_group.create_array("atoms_ptr", data=atoms_ptr)
        meta_group.create_array("edges_ptr", data=edges_ptr)
        meta_group.create_array("samples_mask", data=np.ones(num_samples, dtype=np.bool_))
        meta_group.create_array("atoms_mask", data=np.ones(total_atoms, dtype=np.bool_))
        meta_group.create_array("edges_mask", data=np.ones(total_edges, dtype=np.bool_))

        # batch.keys is {"node": set, "edge": set, "system": set}: the field
        # names present in the batch, categorized by level. The zarr format
        # names the node level "atom".
        level_map: dict[str, str] = {}
        for level_name, key_set in (data.keys or {}).items():
            zarr_level = "atom" if level_name == "node" else level_name
            for key in key_set:
                level_map[key] = zarr_level

        batch_dict = data.to_dict()
        fields_metadata: dict[str, dict[str, str]] = {"core": {}, "custom": {}}
        # One zarr write per field, no per-sample loop.
        for key, level in level_map.items():
            if key in self._EXCLUDED_KEYS:
                continue
            val = batch_dict.get(key)
            if not isinstance(val, torch.Tensor):
                continue
            fields_metadata["core"][key] = level
            core_group.create_array(key, data=self._batch_field_to_numpy(key, val, level))

        root.attrs["num_samples"] = num_samples
        root.attrs["fields"] = fields_metadata

    @dispatch
    def write(self, data: AtomicData | list[AtomicData] | Batch) -> None:  # noqa: F811
        """Write atomic data to a new Zarr store.

        Creates the Zarr store with core/, meta/, custom/ groups, builds
        atoms_ptr and edges_ptr, and initializes all masks to True.
        A Batch is written directly in bulk. A list of AtomicData is first
        collated into a CPU Batch. A single AtomicData is wrapped in a list.

        Parameters
        ----------
        data : AtomicData | list[AtomicData] | Batch
            Data to write.

        Raises
        ------
        FileExistsError
            If store already exists.
        ValueError
            If no samples are provided.
        """
        pass

    # ------------------------------------------------------------------ #
    # append
    # ------------------------------------------------------------------ #

    @overload
    def append(self, data: AtomicData) -> None:  # noqa: F811
        """Append a single AtomicData to an existing Zarr store.

        While this dispatch is available for convenience, we recommend
        users to try and amortize I/O operations by packing multiple
        data to write, instead of one at a time. This can be achieved
        by passing either a ``Batch`` object, or a list of ``AtomicData``
        which will automatically form a batch.

        The sample's local ``edge_index`` is shifted by the number of atoms
        already stored so that it becomes store-global. Core fields that are
        absent from ``data`` are not extended.

        Parameters
        ----------
        data : AtomicData
            Single atomic data to append.

        Raises
        ------
        FileNotFoundError
            If store does not exist.
        """
        if not self._store_exists():
            raise FileNotFoundError(f"Zarr store does not exist at {self._store}")
        root = self._open(mode="r+")
        meta_group = root["meta"]
        core_group = root["core"]
        fields_metadata = self._read_fields_metadata(root)

        # Only the pointer tails are needed to place the new sample.
        last_atom_ptr = int(meta_group["atoms_ptr"][-1])
        last_edge_ptr = int(meta_group["edges_ptr"][-1])
        old_num_samples = int(root.attrs["num_samples"])

        data_dict = data.to_dict()
        num_atoms = int(data.num_nodes)
        edge_index = data_dict.get("edge_index")
        # AtomicData edge_index is taken as (2, num_edges), matching the store layout.
        num_edges = int(edge_index.shape[-1]) if isinstance(edge_index, torch.Tensor) else 0

        self._extend_array(
            meta_group["atoms_ptr"], np.array([last_atom_ptr + num_atoms], dtype=np.int64)
        )
        self._extend_array(
            meta_group["edges_ptr"], np.array([last_edge_ptr + num_edges], dtype=np.int64)
        )
        self._extend_array(meta_group["samples_mask"], np.ones(1, dtype=np.bool_))
        self._extend_array(meta_group["atoms_mask"], np.ones(num_atoms, dtype=np.bool_))
        self._extend_array(meta_group["edges_mask"], np.ones(num_edges, dtype=np.bool_))

        for key in core_group.keys():
            if key in self._EXCLUDED_KEYS:
                continue
            val = data_dict.get(key)
            if not isinstance(val, torch.Tensor):
                continue
            level = self._core_level(fields_metadata, key)
            # System-level scalars need a leading sample dimension.
            if level == "system" and val.dim() == 0:
                val = val.unsqueeze(0)
            if key == "edge_index":
                # Local atom indices -> store-global atom indices.
                val = val + last_atom_ptr
            self._extend_array(
                core_group[key], self._to_numpy(val), axis=self._concat_axis(key, level)
            )

        root.attrs["num_samples"] = old_num_samples + 1

    @overload
    def append(self, data: list[AtomicData]) -> None:  # noqa: F811
        """Append a list of AtomicData to an existing Zarr store.

        The list is collated into a Batch on the device of its first element.
        An empty list is a no-op, like an empty Batch.
        """
        if not data:
            return
        device = data[0].device
        self.append(Batch.from_data_list(data, device))

    @overload
    def append(self, data: Batch) -> None:  # noqa: F811
        """Append a Batch to an existing Zarr store.

        This is the efficient bulk-append path. Since a Batch already has
        all tensors concatenated (node/edge level) or stacked (system level),
        each field is extended in a single I/O operation with no per-sample
        iteration. ``edge_index`` is shifted by the number of atoms already
        stored, so its batch-global indices become store-global. An empty
        batch is a no-op.

        Parameters
        ----------
        data : Batch
            Batched atomic data to append.

        Raises
        ------
        FileNotFoundError
            If store does not exist.
        """
        if not self._store_exists():
            raise FileNotFoundError(f"Zarr store does not exist at {self._store}")
        num_samples = data.num_graphs
        if num_samples is None or num_samples == 0:
            return
        num_samples = int(num_samples)

        root = self._open(mode="r+")
        meta_group = root["meta"]
        core_group = root["core"]
        fields_metadata = self._read_fields_metadata(root)

        last_atom_ptr = int(meta_group["atoms_ptr"][-1])
        last_edge_ptr = int(meta_group["edges_ptr"][-1])
        old_num_samples = int(root.attrs["num_samples"])

        node_counts, edge_counts = self._batch_counts(data, num_samples)
        # Drop the leading entry: it equals the existing tail.
        new_atoms_ptr = self._counts_to_ptr(node_counts, start=last_atom_ptr)[1:]
        new_edges_ptr = self._counts_to_ptr(edge_counts, start=last_edge_ptr)[1:]

        self._extend_array(meta_group["atoms_ptr"], new_atoms_ptr)
        self._extend_array(meta_group["edges_ptr"], new_edges_ptr)
        self._extend_array(meta_group["samples_mask"], np.ones(num_samples, dtype=np.bool_))
        self._extend_array(
            meta_group["atoms_mask"], np.ones(int(node_counts.sum()), dtype=np.bool_)
        )
        self._extend_array(
            meta_group["edges_mask"], np.ones(int(edge_counts.sum()), dtype=np.bool_)
        )

        batch_dict = data.to_dict()
        # Single zarr write per field. Core fields absent from the batch are
        # not extended.
        for key in core_group.keys():
            if key in self._EXCLUDED_KEYS:
                continue
            val = batch_dict.get(key)
            if not isinstance(val, torch.Tensor):
                continue
            level = self._core_level(fields_metadata, key)
            self._extend_array(
                core_group[key],
                self._batch_field_to_numpy(key, val, level, atom_offset=last_atom_ptr),
                axis=self._concat_axis(key, level),
            )

        root.attrs["num_samples"] = old_num_samples + num_samples

    @dispatch
    def append(self, data: AtomicData | list[AtomicData] | Batch) -> None:  # noqa: F811
        """Append data to an existing Zarr store.

        Extends all core arrays along their concatenation axis, extends the
        pointer arrays and masks, and updates num_samples in .zattrs.
        Custom arrays are not extended.

        Parameters
        ----------
        data : AtomicData | list[AtomicData] | Batch
            Data to append.

        Raises
        ------
        FileNotFoundError
            If store does not exist.
        """
        pass

    # ------------------------------------------------------------------ #
    # custom fields, deletion, defragmentation
    # ------------------------------------------------------------------ #

    def add_custom(
        self, key: str, data: torch.Tensor, level: Literal["atom", "edge", "system"]
    ) -> None:
        """Add a custom array to the custom/ group.

        Parameters
        ----------
        key : str
            Name for the custom array.
        data : torch.Tensor
            Tensor data with at least one dimension. First dimension must match:
            - num_samples for "system" level
            - total atoms for "atom" level
            - total edges for "edge" level
        level : str
            One of "atom", "edge", "system".

        Raises
        ------
        ValueError
            If level is invalid or data shape doesn't match.
        FileNotFoundError
            If store does not exist.
        """
        if level not in ("atom", "edge", "system"):
            raise ValueError(
                f"Invalid level '{level}'. Must be 'atom', 'edge', or 'system'."
            )
        if not self._store_exists():
            raise FileNotFoundError(f"Zarr store does not exist at {self._store}")

        root = self._open(mode="r+")
        meta_group = root["meta"]
        custom_group = root["custom"]

        expected_size = {
            "system": int(root.attrs["num_samples"]),
            "atom": int(meta_group["atoms_ptr"][-1]),
            "edge": int(meta_group["edges_ptr"][-1]),
        }[level]
        if data.dim() == 0:
            raise ValueError(f"Custom array '{key}' must have at least one dimension.")
        if data.shape[0] != expected_size:
            raise ValueError(
                f"Data shape[0]={data.shape[0]} does not match expected "
                f"size={expected_size} for level='{level}'."
            )

        # Write to custom group (convert to numpy at zarr boundary)
        custom_group.create_array(key, data=self._to_numpy(data))

        fields_metadata = self._read_fields_metadata(root)
        fields_metadata["custom"][key] = level
        root.attrs["fields"] = fields_metadata

    def delete(self, indices: list[int] | torch.Tensor) -> None:
        """Soft-delete samples by physical index.

        Sets masks to False and zeros out data slices in core/ and custom/.
        Pointer arrays are NOT modified; use :meth:`defragment` to reclaim
        the space.

        Parameters
        ----------
        indices : list[int] | torch.Tensor
            Physical sample indices to delete, i.e. positions in
            ``samples_mask`` (already-deleted samples still count). Any
            integer sequence or tensor is accepted and flattened.

        Raises
        ------
        FileNotFoundError
            If store does not exist.
        IndexError
            If any index is outside ``[0, number of stored samples)``. The
            store is left untouched in that case.
        """
        if not self._store_exists():
            raise FileNotFoundError(f"Zarr store does not exist at {self._store}")
        if isinstance(indices, torch.Tensor):
            idx_array = indices.detach().to(device="cpu", dtype=torch.long).reshape(-1).numpy()
        else:
            idx_array = np.asarray(indices, dtype=np.int64).reshape(-1)
        if idx_array.size == 0:
            return

        root = self._open(mode="r+")
        meta_group = root["meta"]
        atoms_ptr = meta_group["atoms_ptr"][:]
        edges_ptr = meta_group["edges_ptr"][:]
        samples_mask = meta_group["samples_mask"][:]
        atoms_mask = meta_group["atoms_mask"][:]
        edges_mask = meta_group["edges_mask"][:]

        # Validate everything before the first write so a bad index cannot
        # leave zeroed data behind with unsaved masks.
        num_stored = samples_mask.shape[0]
        out_of_range = (idx_array < 0) | (idx_array >= num_stored)
        if out_of_range.any():
            raise IndexError(
                f"Sample indices {idx_array[out_of_range].tolist()} are out of range "
                f"for a store with {num_stored} samples."
            )

        # Resolve (array handle, level, axis) for every field once, not per sample.
        fields_metadata = self._read_fields_metadata(root)
        targets: list[tuple[zarr.Array, str, int]] = []
        core_group = root["core"]
        for key in core_group.keys():
            level = self._core_level(fields_metadata, key)
            targets.append((core_group[key], level, self._concat_axis(key, level)))
        if "custom" in root:
            custom_group = root["custom"]
            for key in custom_group.keys():
                # Custom arrays are validated on axis 0 by add_custom.
                level = fields_metadata["custom"].get(key, "system")
                targets.append((custom_group[key], level, 0))

        for idx in idx_array.tolist():
            atom_start, atom_end = int(atoms_ptr[idx]), int(atoms_ptr[idx + 1])
            edge_start, edge_end = int(edges_ptr[idx]), int(edges_ptr[idx + 1])
            samples_mask[idx] = False
            atoms_mask[atom_start:atom_end] = False
            edges_mask[edge_start:edge_end] = False
            ranges = {
                "atom": (atom_start, atom_end),
                "edge": (edge_start, edge_end),
                "system": (idx, idx + 1),
            }
            for arr, level, axis in targets:
                span = ranges.get(level)
                if span is not None:
                    self._zero_slice(arr, span[0], span[1], axis=axis)

        meta_group["samples_mask"][:] = samples_mask
        meta_group["atoms_mask"][:] = atoms_mask
        meta_group["edges_mask"][:] = edges_mask

    def defragment(self) -> None:
        """Rewrite the store, physically dropping soft-deleted samples.

        Rebuilds all core/ and custom/ arrays and the pointer arrays, and
        resets all masks to True. ``edge_index`` values are re-based so they
        stay consistent with the rebuilt ``atoms_ptr``. If every sample was
        deleted, the result has ``num_samples == 0`` but keeps zero-length
        arrays for every field, so it can still be appended to.

        Raises
        ------
        FileNotFoundError
            If store does not exist.
        """
        if not self._store_exists():
            raise FileNotFoundError(f"Zarr store does not exist at {self._store}")

        root = self._open(mode="r")
        meta_group = root["meta"]
        atoms_ptr = np.asarray(meta_group["atoms_ptr"][:], dtype=np.int64)
        edges_ptr = np.asarray(meta_group["edges_ptr"][:], dtype=np.int64)
        active = np.asarray(meta_group["samples_mask"][:], dtype=bool)
        fields_metadata = self._read_fields_metadata(root)

        # Expand the per-sample mask into per-atom / per-edge keep masks.
        atom_counts = np.diff(atoms_ptr)
        edge_counts = np.diff(edges_ptr)
        keep_by_level = {
            "atom": np.repeat(active, atom_counts),
            "edge": np.repeat(active, edge_counts),
            "system": active,
        }
        kept_edge_counts = edge_counts[active]
        new_atoms_ptr = self._counts_to_ptr(atom_counts[active])
        new_edges_ptr = self._counts_to_ptr(kept_edge_counts)
        # Per surviving edge: new atom offset of its sample minus the old one.
        edge_shift = np.repeat(new_atoms_ptr[:-1] - atoms_ptr[:-1][active], kept_edge_counts)

        # Read and compact everything before the store is truncated below.
        new_core: dict[str, np.ndarray] = {}
        core_group = root["core"]
        for key in core_group.keys():
            level = self._core_level(fields_metadata, key)
            keep = keep_by_level.get(level)
            if keep is None:
                continue
            arr = core_group[key][:]
            kept = self._select_along(arr, keep, self._concat_axis(key, level))
            if key == "edge_index" and level == "edge":
                kept = (kept + edge_shift).astype(arr.dtype, copy=False)
            new_core[key] = kept

        new_custom: dict[str, np.ndarray] = {}
        if "custom" in root:
            custom_group = root["custom"]
            for key in custom_group.keys():
                keep = keep_by_level.get(fields_metadata["custom"].get(key, "system"))
                if keep is None:
                    continue
                new_custom[key] = self._select_along(custom_group[key][:], keep, 0)

        new_num_samples = int(active.sum())
        new_total_atoms = int(new_atoms_ptr[-1])
        new_total_edges = int(new_edges_ptr[-1])

        # mode="w" clears the existing store.
        new_root = self._open(mode="w")
        new_meta = new_root.create_group("meta")
        new_core_group = new_root.create_group("core")
        new_custom_group = new_root.create_group("custom")

        new_meta.create_array("atoms_ptr", data=new_atoms_ptr)
        new_meta.create_array("edges_ptr", data=new_edges_ptr)
        new_meta.create_array("samples_mask", data=np.ones(new_num_samples, dtype=np.bool_))
        new_meta.create_array("atoms_mask", data=np.ones(new_total_atoms, dtype=np.bool_))
        new_meta.create_array("edges_mask", data=np.ones(new_total_edges, dtype=np.bool_))

        for key, values in new_core.items():
            new_core_group.create_array(key, data=values)
        for key, values in new_custom.items():
            new_custom_group.create_array(key, data=values)

        new_root.attrs["num_samples"] = new_num_samples
        new_root.attrs["fields"] = fields_metadata

    # ------------------------------------------------------------------ #
    # Low-level array utilities
    # ------------------------------------------------------------------ #

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Convert a torch tensor to numpy for zarr I/O.

        Parameters
        ----------
        tensor : torch.Tensor
            Tensor to convert.

        Returns
        -------
        np.ndarray
            Numpy array for zarr storage.
        """
        return tensor.detach().cpu().numpy()

    @staticmethod
    def _extend_array(arr: zarr.Array, data: np.ndarray, axis: int = 0) -> None:
        """Extend a zarr array along an axis.

        Zero-length ``data`` is a no-op.

        Parameters
        ----------
        arr : zarr.Array
            Zarr array to extend.
        data : np.ndarray
            Data to append.
        axis : int
            Axis along which to extend (negative values count from the end).
        """
        new_len = data.shape[axis]
        if new_len == 0:
            return
        old_shape = arr.shape
        new_shape = list(old_shape)
        new_shape[axis] = old_shape[axis] + new_len
        arr.resize(tuple(new_shape))
        selector: list[slice | int] = [slice(None)] * len(old_shape)
        selector[axis] = slice(old_shape[axis], new_shape[axis])
        arr[tuple(selector)] = data

    @staticmethod
    def _zero_slice(arr: zarr.Array, start: int, end: int, axis: int = 0) -> None:
        """Zero out a slice of a zarr array.

        Parameters
        ----------
        arr : zarr.Array
            Zarr array to modify.
        start : int
            Start index.
        end : int
            End index (exclusive). Empty ranges are a no-op.
        axis : int
            Axis along which to slice.
        """
        if start >= end:
            return
        selector: list[slice | int] = [slice(None)] * len(arr.shape)
        selector[axis] = slice(start, end)
        shape = list(arr.shape)
        shape[axis] = end - start
        arr[tuple(selector)] = np.zeros(shape, dtype=arr.dtype)
class AtomicDataZarrReader(Reader):
    """Reader for loading AtomicData from Zarr stores.

    This reader provides random-access loading of
    AtomicData samples from Zarr stores created by :class:`AtomicDataZarrWriter`.
    It supports soft-deleted samples via the samples_mask and provides
    efficient random access using pointer arrays.

    The Zarr store layout expected is::

        dataset.zarr/
        ├── meta/                  # Pointer arrays + masks
        │   ├── atoms_ptr          # int64 [N+1] — cumulative node counts
        │   ├── edges_ptr          # int64 [N+1] — cumulative edge counts
        │   └── samples_mask       # bool [N] — False = deleted sample
        │
        ├── core/                  # AtomicData fields
        │   ├── atomic_numbers     # int64 [V_total]
        │   ├── positions          # float32 [V_total, 3]
        │   └── ...
        │
        └── custom/                # User-defined arrays (optional)

    Parameters
    ----------
    store : StoreLike
        Any zarr-compatible store: filesystem path (str or Path), or a zarr
        Store instance (LocalStore, MemoryStore, FsspecStore, etc.), StorePath,
        or a dict for in-memory buffer storage.
    pin_memory : bool, default=False
        If True, place tensors in pinned (page-locked) memory for faster
        async CPU→GPU transfers.
    include_index_in_metadata : bool, default=True
        If True, include sample index in the metadata dict.

    Attributes
    ----------
    _store : StoreLike
        The underlying zarr store reference.

    Examples
    --------
    >>> from nvalchemi.data.datapipes.backends.zarr import AtomicDataZarrReader  # doctest: +SKIP
    >>> reader = AtomicDataZarrReader(store="dataset.zarr")  # doctest: +SKIP
    >>> data_dict, metadata = reader[0]  # returns dict and metadata  # doctest: +SKIP
    >>> atomic_data = AtomicDataZarrReader.to_atomic_data(data_dict)  # doctest: +SKIP
    """

    def __init__(
        self,
        store: StoreLike,
        *,
        pin_memory: bool = False,
        include_index_in_metadata: bool = True,
    ) -> None:
        """Initialize the reader with a Zarr store.

        Parameters
        ----------
        store : StoreLike
            Any zarr-compatible store: filesystem path (str or Path), or a zarr
            Store instance (LocalStore, MemoryStore, FsspecStore, etc.), StorePath,
            or a dict for in-memory buffer storage.
        pin_memory : bool, default=False
            If True, place tensors in pinned (page-locked) memory.
        include_index_in_metadata : bool, default=True
            If True, include sample index in the metadata dict.

        Raises
        ------
        FileNotFoundError
            If the Zarr store does not exist (for local filesystem paths).
        ValueError
            If the store is missing required groups (meta, core).
        """
        super().__init__(
            pin_memory=pin_memory,
            include_index_in_metadata=include_index_in_metadata,
        )
        self._store: StoreLike = store

        # Friendly existence check for local filesystem paths only. Strings
        # with a "://" scheme are URLs that zarr resolves itself.
        is_local_path = isinstance(store, Path) or (
            isinstance(store, str) and "://" not in store
        )
        if is_local_path and not Path(store).exists():
            raise FileNotFoundError(f"Zarr store does not exist at {store}")

        self._root = zarr.open(self._store, mode="r")

        if "meta" not in self._root:
            raise ValueError(f"Zarr store at {self._store} is missing 'meta' group")
        if "core" not in self._root:
            raise ValueError(f"Zarr store at {self._store} is missing 'core' group")

        self.refresh()

    def refresh(self) -> None:
        """Reload cached pointer arrays, masks, and metadata from the store.

        Call this method after external modifications to the Zarr store
        (e.g., appending or deleting samples via :class:`AtomicDataZarrWriter`)
        to ensure the reader reflects the current state of the data.

        Raises
        ------
        RuntimeError
            If the reader has been closed.
        """
        if self._root is None:
            raise RuntimeError("Cannot refresh a closed reader.")

        # Re-open the store to pick up structural changes.
        self._root = zarr.open(self._store, mode="r")
        meta_group = self._root["meta"]

        self._atoms_ptr = torch.from_numpy(meta_group["atoms_ptr"][:]).to(torch.long)
        self._edges_ptr = torch.from_numpy(meta_group["edges_ptr"][:]).to(torch.long)
        self._samples_mask = torch.from_numpy(meta_group["samples_mask"][:]).to(torch.bool)

        # Logical index i maps to the i-th physical index whose mask is True.
        self._active_indices = torch.where(self._samples_mask)[0]

        self._fields_metadata: dict[str, dict[str, str]] = dict(
            self._root.attrs.get("fields", {"core": {}, "custom": {}})
        )

    def _load_sample(self, index: int) -> dict[str, torch.Tensor]:
        """Load raw data for a single sample.

        Parameters
        ----------
        index : int
            Logical sample index (0 to len-1), accounting for deleted samples.

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary mapping field names to CPU tensors. ``edge_index`` is
            returned with sample-local atom indices. System-level fields keep
            a leading dimension of size 1.

        Raises
        ------
        IndexError
            If index is out of range.
        RuntimeError
            If the reader has been closed.
        """
        if self._root is None:
            raise RuntimeError("Cannot read from a closed reader.")

        physical_idx = int(self._active_indices[index].item())

        atom_start = int(self._atoms_ptr[physical_idx].item())
        atom_end = int(self._atoms_ptr[physical_idx + 1].item())
        edge_start = int(self._edges_ptr[physical_idx].item())
        edge_end = int(self._edges_ptr[physical_idx + 1].item())

        data: dict[str, torch.Tensor] = {}

        core_levels = self._fields_metadata.get("core", {})
        core_group = self._root["core"]
        for key in core_group.array_keys():
            level = core_levels.get(key, _get_field_level(key))
            arr = core_group[key]
            if level == "atom":
                data[key] = torch.from_numpy(arr[atom_start:atom_end])
            elif level == "edge":
                if _get_cat_dim(key) == -1:
                    # Shape is [..., E], slice on last dim
                    tensor = torch.from_numpy(arr[:, edge_start:edge_end])
                else:
                    # Shape is [E, ...], slice on first dim
                    tensor = torch.from_numpy(arr[edge_start:edge_end])
                if key == "edge_index":
                    # Stored indices are store-global; make them sample-local.
                    tensor = tensor - atom_start
                data[key] = tensor
            else:  # system level
                data[key] = torch.from_numpy(arr[physical_idx : physical_idx + 1])

        if "custom" in self._root:
            custom_levels = self._fields_metadata.get("custom", {})
            custom_group = self._root["custom"]
            for key in custom_group.array_keys():
                level = custom_levels.get(key, "system")
                arr = custom_group[key]
                if level == "atom":
                    data[key] = torch.from_numpy(arr[atom_start:atom_end])
                elif level == "edge":
                    # Custom edge arrays always carry edges on axis 0: the
                    # writer's add_custom requires shape[0] == E_total,
                    # regardless of the array's name.
                    data[key] = torch.from_numpy(arr[edge_start:edge_end])
                else:  # system level
                    data[key] = torch.from_numpy(arr[physical_idx : physical_idx + 1])

        return data

    def __len__(self) -> int:
        """Return the number of active (non-deleted) samples.

        Returns
        -------
        int
            Number of samples available for reading.
        """
        return len(self._active_indices)

    def _get_sample_metadata(self, index: int) -> dict[str, str]:
        """Return metadata for a sample.

        Parameters
        ----------
        index : int
            Logical sample index.

        Returns
        -------
        dict[str, str]
            ``source_file`` (string form of the store) and ``physical_index``
            (position in ``samples_mask``), both as strings.
        """
        physical_idx = int(self._active_indices[index].item())
        return {
            "source_file": str(self._store),
            "physical_index": str(physical_idx),
        }

    def close(self) -> None:
        """Release the Zarr store reference and clean up resources."""
        self._root = None
        super().close()