# Source code for nvalchemi.data.batch

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Graph-aware Pydantic batch backed by :class:`MultiLevelStorage`.

This module provides a :class:`Batch` class that combines the Pydantic-model
interface of ``nvalchemi.data.batch.Batch`` with the performant tensor storage
of :class:`~nvalchemi.data.level_storage.MultiLevelStorage`.

Performance advantages over the Pydantic-based ``nvalchemi.data.batch.Batch``:

* **index_select** operates directly on concatenated tensors via segment
  selection -- no per-graph object reconstruction and re-batching.
* **to / clone** move / copy tensors in a single pass -- no
  ``model_dump`` / ``map_structure`` / ``model_validate`` round-trip.
* **batch / ptr** are lazily derived from ``segment_lengths`` -- never
  eagerly built or manually maintained.
* **No slices / cumsum bookkeeping** -- edge-index offsets are recovered
  from ``atoms.batch_ptr`` at unbatching time.
"""

from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterator, Sequence
from typing import Any

import numpy as np
import torch
from tensordict import TensorDict
from torch import Tensor
from torch import distributed as dist
from torch.distributed import ProcessGroup, Work

from nvalchemi.data.atomic_data import AtomicData
from nvalchemi.data.data import DataMixin
from nvalchemi.data.level_storage import (
    LevelSchema,
    MultiLevelStorage,
    SegmentedLevelStorage,
    UniformLevelStorage,
)

# Edge-level keys whose values are node indices: ``Batch.from_data_list``
# shifts them by the running node offset and concatenates them along the
# last dimension instead of the first.
_INDEX_KEYS = frozenset({"edge_index"})
# Keys that ``Batch.from_data_list`` always removes from the per-sample key
# set before batching (in addition to any caller-supplied ``exclude_keys``).
_EXCLUDED_KEYS = frozenset({"batch", "ptr", "device", "dtype", "info"})


# Attribute names that ``Batch.__setattr__`` always stores on the instance
# itself, never in the tensor storage.
_OWN_ATTRS = frozenset({"device", "keys", "_storage", "_data_class"})


class Batch(DataMixin):
    """Graph-aware batch built on :class:`MultiLevelStorage`.

    All tensors live in a :class:`MultiLevelStorage` that holds up to three
    attribute groups:

    * ``"atoms"`` (:class:`SegmentedLevelStorage`) -- node-level tensors
    * ``"edges"`` (:class:`SegmentedLevelStorage`) -- edge-level tensors
    * ``"system"`` (:class:`UniformLevelStorage`) -- graph-level tensors

    ``batch``, ``ptr``, ``num_nodes_list``, and ``num_edges_list`` are not
    stored; they are derived on access from the segmented groups.

    Attributes
    ----------
    device : torch.device
        Device of the underlying storage.
    keys : dict[str, set[str]] | None
        Level categorisation: ``{"node": ..., "edge": ..., "system": ...}``.
    """

    def __init__(
        self,
        *,
        device: torch.device | str,
        storage: MultiLevelStorage | None = None,
        keys: dict[str, set[str]] | None = None,
    ) -> None:
        """Create a batch around *storage* (a fresh, empty one if ``None``)."""
        if storage is None:
            storage = MultiLevelStorage()
        if isinstance(device, str):
            device = torch.device(device)
        # Bookkeeping fields are written with ``object.__setattr__``,
        # bypassing this class's ``__setattr__`` override.
        object.__setattr__(self, "_storage", storage)
        object.__setattr__(self, "_data_class", AtomicData)
        object.__setattr__(self, "device", device)
        object.__setattr__(self, "keys", keys)

    def __setattr__(self, name: str, value: Any) -> None:
        """Route tensor assignments to storage; keep everything else local."""
        if name not in _OWN_ATTRS and isinstance(value, torch.Tensor):
            self._storage[name] = value
            return
        object.__setattr__(self, name, value)

    @classmethod
    def _construct(
        cls,
        *,
        device: torch.device | str,
        keys: dict[str, set[str]] | None,
        storage: MultiLevelStorage,
        data_class: type = AtomicData,
    ) -> Batch:
        """Fast constructor that bypasses ``__init__``."""
        resolved = torch.device(device) if isinstance(device, str) else device
        instance = cls.__new__(cls)
        object.__setattr__(instance, "_storage", storage)
        object.__setattr__(instance, "_data_class", data_class)
        object.__setattr__(instance, "device", resolved)
        object.__setattr__(instance, "keys", keys)
        return instance

    # ------------------------------------------------------------------
    # Properties derived from storage
    # ------------------------------------------------------------------

    @property
    def num_graphs(self) -> int:
        """Number of graphs in the batch (``len`` of the storage)."""
        return len(self._storage)

    @property
    def batch_size(self) -> int:
        """Alias for :attr:`num_graphs`."""
        return self.num_graphs

    @property
    def num_nodes(self) -> int:
        """Total number of nodes across all graphs (0 without an atoms group)."""
        node_group = self._atoms_group
        if node_group is None:
            return 0
        return node_group.num_elements()

    @property
    def num_edges(self) -> int:
        """Total number of edges across all graphs (0 without an edges group)."""
        edge_group = self._edges_group
        if edge_group is None:
            return 0
        return edge_group.num_elements()

    @property
    def system_capacity(self) -> int:
        """Maximum number of systems (graphs) this buffer can hold.

        Read from the leading dimension of the system group's data
        (e.g. as allocated by :meth:`empty`); 0 without a system group.
        """
        system_group = self._system_group
        return 0 if system_group is None else system_group._data.shape[0]

    @property
    def batch(self) -> Tensor:
        """Per-node graph assignment tensor (int64), derived on access."""
        node_group = self._atoms_group
        if node_group is not None:
            return node_group.batch_idx.long()
        return torch.tensor([], dtype=torch.long, device=self.device)

    @property
    def ptr(self) -> Tensor:
        """Cumulative node count per graph, derived on access.

        A single-element int32 zero tensor is returned without an atoms group.
        """
        node_group = self._atoms_group
        if node_group is not None:
            return node_group.batch_ptr
        return torch.zeros(1, dtype=torch.int32, device=self.device)

    @property
    def edge_ptr(self) -> Tensor:
        """Per-atom CSR-style pointer into the edge list, shape ``(N + 1,)``, int32.

        Built by counting, for every atom, how many rows of ``edge_index``
        name it in column 0 (the sender) and accumulating those counts, so
        ``edge_ptr[i + 1] - edge_ptr[i]`` is the number of edges sent by
        atom ``i``. An all-zeros pointer of length ``num_nodes + 1`` is
        returned when the edges group is absent or empty.

        NOTE(review): reading ``edge_ptr[i] : edge_ptr[i + 1]`` as a slice of
        edge rows presumes the edges are stored grouped by sender in
        ascending order -- confirm against the neighbor-list producer.
        """
        edge_group = self._edges_group
        if edge_group is None or edge_group.num_elements() == 0:
            return torch.zeros(
                self.num_nodes + 1, dtype=torch.int32, device=self.device
            )

        # edge_index is stored as (E, 2); column 0 holds the sender indices.
        edge_index = edge_group["edge_index"]
        n_atoms = self.num_nodes
        senders = edge_index[:, 0].long()
        per_atom = torch.zeros(n_atoms, dtype=torch.int32, device=self.device)
        per_atom.scatter_add_(
            0,
            senders,
            torch.ones(senders.shape[0], dtype=torch.int32, device=self.device),
        )
        pointer = torch.zeros(n_atoms + 1, dtype=torch.int32, device=self.device)
        pointer[1:] = per_atom.cumsum(0)
        return pointer

    @property
    def num_nodes_list(self) -> list[int]:
        """Per-graph node counts as a Python list."""
        node_group = self._atoms_group
        if node_group is None:
            return []
        return node_group.segment_lengths[: len(node_group)].tolist()

    @property
    def num_edges_list(self) -> list[int]:
        """Per-graph edge counts as a Python list."""
        edge_group = self._edges_group
        if edge_group is None:
            return []
        return edge_group.segment_lengths[: len(edge_group)].tolist()

    @property
    def num_nodes_per_graph(self) -> Tensor:
        """Per-graph node counts as an int64 tensor."""
        node_group = self._atoms_group
        if node_group is None:
            return torch.tensor([], dtype=torch.long, device=self.device)
        return node_group.segment_lengths[: len(node_group)].long()

    @property
    def num_edges_per_graph(self) -> Tensor:
        """Per-graph edge counts as an int64 tensor."""
        edge_group = self._edges_group
        if edge_group is None:
            return torch.tensor([], dtype=torch.long, device=self.device)
        return edge_group.segment_lengths[: len(edge_group)].long()

    @property
    def max_num_nodes(self) -> int:
        """Largest node count of any graph (0 for an empty batch)."""
        counts = self.num_nodes_list
        if not counts:
            return 0
        return max(counts)

    # ------------------------------------------------------------------
    # Internal group accessors
    # ------------------------------------------------------------------

    @property
    def _atoms_group(self) -> SegmentedLevelStorage | None:
        candidate = self._storage.groups.get("atoms")
        if isinstance(candidate, SegmentedLevelStorage):
            return candidate
        return None

    @property
    def _edges_group(self) -> SegmentedLevelStorage | None:
        candidate = self._storage.groups.get("edges")
        if isinstance(candidate, SegmentedLevelStorage):
            return candidate
        return None

    @property
    def _system_group(self) -> UniformLevelStorage | None:
        return self._storage.groups.get("system")

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------
@classmethod
def from_data_list(
    cls,
    data_list: list[AtomicData],
    device: torch.device | str | None = None,
    skip_validation: bool = False,
    attr_map: LevelSchema | None = None,
    exclude_keys: list[str] | None = None,
) -> Batch:
    """Build a batch by concatenating a list of :class:`AtomicData` graphs.

    The attribute set is taken from the first element's
    ``model_dump(exclude_none=True)``; on every sample only attributes that
    are tensors are batched. Each key is assigned to the node, edge, or
    system level according to the data class's ``__node_keys__``,
    ``__edge_keys__`` and ``__system_keys__``.

    Parameters
    ----------
    data_list : list[AtomicData]
        Individual graphs to batch.
    device : torch.device | str, optional
        Target device. Taken from the first element if ``None``.
    skip_validation : bool
        If ``True``, the storage objects are built with ``validate=False``.
    attr_map : LevelSchema, optional
        Attribute registry. Defaults to ``LevelSchema()``.
    exclude_keys : list[str], optional
        Keys to exclude from batching.

    Returns
    -------
    Batch

    Raises
    ------
    ValueError
        If *data_list* is empty.
    """
    if not data_list:
        raise ValueError("Cannot create batch from empty data list")

    first = data_list[0]
    if device is None:
        device = first.device
    if isinstance(device, str):
        device = torch.device(device)
    if attr_map is None:
        attr_map = LevelSchema()

    data_cls = first.__class__
    node_keys = data_cls.__node_keys__
    edge_keys = data_cls.__edge_keys__
    system_keys = data_cls.__system_keys__

    skipped = _EXCLUDED_KEYS | set(exclude_keys or [])
    batched_keys = set(first.model_dump(exclude_none=True).keys()) - skipped

    per_node: dict[str, list[Tensor]] = {}
    per_edge: dict[str, list[Tensor]] = {}
    per_system: dict[str, list[Tensor]] = {}
    node_counts: list[int] = []
    edge_counts: list[int] = []

    # Number of nodes contributed by the samples already processed; index
    # keys are shifted by it so they address the concatenated node axis.
    running_offset = 0
    for sample in data_list:
        sample_nodes = sample.num_nodes
        sample_edges = sample.num_edges
        node_counts.append(sample_nodes)
        edge_counts.append(sample_edges)

        for key in batched_keys:
            tensor = getattr(sample, key, None)
            if not isinstance(tensor, Tensor):
                continue
            tensor = tensor.to(device)
            if key in node_keys:
                per_node.setdefault(key, []).append(tensor)
            elif key in edge_keys:
                if key in _INDEX_KEYS:
                    tensor = tensor + running_offset
                per_edge.setdefault(key, []).append(tensor)
            elif key in system_keys:
                per_system.setdefault(key, []).append(tensor)

        running_offset += sample_nodes

    atoms_data = {name: torch.cat(parts, dim=0) for name, parts in per_node.items()}
    edges_data: dict[str, Tensor] = {}
    for name, parts in per_edge.items():
        if name in _INDEX_KEYS:
            # Index keys are joined along the last axis, then transposed so
            # the stored layout has one row per edge.
            edges_data[name] = torch.cat(parts, dim=-1).transpose(0, 1)
        else:
            edges_data[name] = torch.cat(parts, dim=0)
    system_data = {
        name: torch.cat(parts, dim=0) for name, parts in per_system.items()
    }

    validate = not skip_validation
    groups: dict[str, UniformLevelStorage | SegmentedLevelStorage] = {}
    if atoms_data:
        groups["atoms"] = SegmentedLevelStorage(
            data=atoms_data,
            device=device,
            segment_lengths=node_counts,
            validate=validate,
            attr_map=attr_map,
        )
    if edges_data:
        groups["edges"] = SegmentedLevelStorage(
            data=edges_data,
            device=device,
            segment_lengths=edge_counts,
            validate=validate,
            attr_map=attr_map,
        )
    if system_data:
        groups["system"] = UniformLevelStorage(
            data=system_data,
            device=device,
            validate=validate,
            attr_map=attr_map,
        )

    new_batch = cls._construct(
        device=device,
        keys={
            "node": set(per_node.keys()),
            "edge": set(per_edge.keys()),
            "system": set(per_system.keys()),
        },
        storage=MultiLevelStorage(
            groups=groups, attr_map=attr_map, validate=validate
        ),
        data_class=data_cls,
    )
    return new_batch._make_contiguous()
@classmethod
def empty(
    cls,
    *,
    num_systems: int,
    num_nodes: int,
    num_edges: int,
    template: AtomicData | Batch | None = None,
    device: torch.device | str = "cpu",
    attr_map: LevelSchema | None = None,
) -> Batch:
    """Allocate a zero-graph batch with fixed storage capacity.

    Every key of the template's storage groups gets a zero-filled tensor
    whose leading dimension is the requested capacity and whose trailing
    dimensions and dtype are copied from the template. No graphs are
    stored initially (``num_graphs == 0``). Use :meth:`put` to copy graphs
    into the buffer; pass ``dest_mask`` of shape ``(num_systems,)`` with
    ``False`` for empty slots.

    Parameters
    ----------
    num_systems : int
        Maximum number of systems (graphs) the buffer can hold.
    num_nodes : int
        Total node (atom) capacity across all graphs.
    num_edges : int
        Total edge capacity across all graphs.
    template : AtomicData or Batch, optional
        Template for attribute keys and per-key shapes/dtypes. If ``None``,
        a minimal :class:`AtomicData` with ``positions``,
        ``atomic_numbers``, and ``energies`` is used.
    device : torch.device or str, optional
        Device for allocated tensors.
    attr_map : LevelSchema, optional
        Attribute registry; defaults to ``LevelSchema()``.

    Returns
    -------
    Batch
        Batch with ``num_graphs == 0`` and capacity for the given sizes.

    Raises
    ------
    ValueError
        If any of the capacities is negative.
    """
    if num_systems < 0 or num_nodes < 0 or num_edges < 0:
        raise ValueError(
            "num_systems, num_nodes, and num_edges must be non-negative"
        )
    if isinstance(device, str):
        device = torch.device(device)
    if attr_map is None:
        attr_map = LevelSchema()
    if template is None:
        template = AtomicData(
            positions=torch.zeros(1, 3),
            atomic_numbers=torch.zeros(1, dtype=torch.long),
            energies=torch.tensor([[0.0]]),
        )

    # A single AtomicData is batched first so that both template kinds
    # expose the same storage-group layout.
    if isinstance(template, AtomicData):
        ref = cls.from_data_list([template], device=device, attr_map=attr_map)
    else:
        ref = template

    def _zeros_for(source, names, rows):
        # One zero tensor per key: `rows` leading entries, trailing shape
        # and dtype taken from the template tensor.
        return {
            name: torch.zeros(
                (rows,) + source[name].shape[1:],
                device=device,
                dtype=source[name].dtype,
            )
            for name in names
        }

    groups: dict[str, UniformLevelStorage | SegmentedLevelStorage] = {}
    for name, group in ref._storage.groups.items():
        names = list(group.keys())
        if not names:
            continue

        if name == "system":
            uniform = UniformLevelStorage(
                data=_zeros_for(group, names, num_systems),
                device=device,
                validate=False,
                attr_map=attr_map,
            )
            # Mark every pre-allocated system slot as unused.
            object.__setattr__(uniform, "_num_kept", 0)
            groups[name] = uniform
            continue

        # "atoms" uses the node capacity; any other group is sized like edges.
        rows = num_nodes if name == "atoms" else num_edges
        groups[name] = SegmentedLevelStorage(
            data=_zeros_for(group, names, rows),
            segment_lengths=torch.tensor([], device=device, dtype=torch.int32),
            device=device,
            batch_ptr_capacity=max(num_systems + 2, 2),
            validate=False,
            attr_map=attr_map,
        )

    return cls._construct(
        device=device,
        keys=ref.keys,
        storage=MultiLevelStorage(groups=groups, attr_map=attr_map, validate=False),
        data_class=ref._data_class,
    )
def zero(self) -> None:
    """Reset this batch to an empty-but-allocated state.

    Every leaf tensor of every storage group is zeroed in place, and the
    per-group bookkeeping is cleared: ``_num_kept`` is set to 0 where it
    exists, ``segment_lengths`` becomes an empty tensor of the same dtype
    and device, an existing ``_batch_ptr`` is replaced by zeros of the same
    length, and the cached ``_batch_idx`` / ``_batch_ptr_np`` are dropped.
    Data tensors keep their shape, so ``system_capacity`` is unchanged.

    Intended for resetting pre-allocated buffers (created via
    :meth:`empty`) between uses without reallocating memory.

    Examples
    --------
    >>> batch = Batch.empty(num_systems=10, num_nodes=100, num_edges=200)
    >>> batch.zero()
    >>> batch.num_graphs
    0
    >>> batch.system_capacity
    10
    """
    for level in self._storage.groups.values():
        level._data.apply_(lambda leaf: leaf.zero_())

        if hasattr(level, "_num_kept"):
            object.__setattr__(level, "_num_kept", 0)

        if hasattr(level, "segment_lengths"):
            previous = level.segment_lengths
            level.segment_lengths = torch.empty(
                0, dtype=previous.dtype, device=previous.device
            )
            if level._batch_ptr is not None:
                # Keep the pointer buffer's length, only clear its content.
                level._batch_ptr = torch.zeros(
                    level._batch_ptr.shape[0],
                    dtype=torch.int32,
                    device=level.device,
                )

        if hasattr(level, "_batch_idx"):
            level._batch_idx = None
            level._batch_ptr_np = None
# ------------------------------------------------------------------
# Per-graph reconstruction
# ------------------------------------------------------------------
def get_data(self, idx: int) -> AtomicData:
    """Reconstruct the :class:`AtomicData` object at position *idx*.

    Node and edge tensors are sliced with the groups' batch pointers and
    system tensors are indexed by graph, keeping a leading dimension of 1.
    The node offset of the graph is subtracted from index keys, undoing the
    shift applied during batching, and they are transposed back from the
    stored one-row-per-edge layout.

    Parameters
    ----------
    idx : int
        Graph index (supports negative indexing).

    Returns
    -------
    AtomicData

    Raises
    ------
    IndexError
        If *idx* does not address one of the ``num_graphs`` stored graphs.
    """
    n_graphs = self.num_graphs
    requested = idx
    if idx < 0:
        idx += n_graphs
    # Without this check an out-of-range index would silently slice with
    # whatever the batch-pointer buffer holds beyond the valid entries.
    if not 0 <= idx < n_graphs:
        raise IndexError(
            f"Graph index {requested} is out of range for a batch "
            f"with {n_graphs} graphs"
        )

    data: dict[str, Any] = {}

    atoms = self._atoms_group
    if atoms is not None:
        atoms._lazy_init_batch_ptr()
        node_start = atoms._batch_ptr[idx].item()
        node_end = atoms._batch_ptr[idx + 1].item()
        for key, tensor in atoms.items():
            data[key] = tensor[node_start:node_end]

    edges = self._edges_group
    if edges is not None and edges.num_elements() > 0:
        edges._lazy_init_batch_ptr()
        edge_start = edges._batch_ptr[idx].item()
        edge_end = edges._batch_ptr[idx + 1].item()
        # First node of this graph on the concatenated node axis.
        node_offset = atoms._batch_ptr[idx] if atoms is not None else 0
        for key, tensor in edges.items():
            if key in _INDEX_KEYS:
                data[key] = (
                    tensor[edge_start:edge_end].transpose(0, 1) - node_offset
                )
            else:
                data[key] = tensor[edge_start:edge_end]

    system = self._system_group
    if system is not None:
        for key, tensor in system.items():
            data[key] = tensor[idx].unsqueeze(0)

    return self._data_class(**data)
def to_data_list(self) -> list[AtomicData]:
    """Rebuild every graph of the batch via :meth:`get_data`.

    Returns
    -------
    list[AtomicData]
        One object per graph, in batch order.
    """
    graphs = []
    for position in range(self.num_graphs):
        graphs.append(self.get_data(position))
    return graphs
# ------------------------------------------------------------------
# Selection / indexing
# ------------------------------------------------------------------
def index_select(
    self,
    idx: int | slice | Tensor | list[int] | np.ndarray | Sequence[int],
) -> Batch:
    """Select a subset of graphs by index.

    Each storage group is asked to ``select`` the requested graphs; no
    per-graph :class:`AtomicData` objects are rebuilt. Because the selected
    nodes move to new positions, ``edge_index`` is corrected per edge by the
    difference between a graph's old and new node offset.

    Parameters
    ----------
    idx : int, slice, Tensor, list[int], np.ndarray, or Sequence[int]
        Graph-level index specification.

    Returns
    -------
    Batch
    """
    graph_ids = torch.tensor(
        self._normalize_index(idx), dtype=torch.int32, device=self.device
    )

    selected: dict[str, UniformLevelStorage | SegmentedLevelStorage] = {}
    # Per selected graph: (node offset in this batch) - (node offset in the
    # selection). Stays None when there is no atoms group.
    node_shift: Tensor | None = None

    source_atoms = self._atoms_group
    if source_atoms is not None:
        picked_atoms = source_atoms.select(graph_ids)
        offsets_before = source_atoms.batch_ptr[graph_ids]
        picked_atoms._lazy_init_batch_ptr()
        offsets_after = picked_atoms._batch_ptr[:-1]
        node_shift = offsets_before - offsets_after
        selected["atoms"] = picked_atoms

    source_edges = self._edges_group
    if source_edges is not None:
        picked_edges = source_edges.select(graph_ids)
        if "edge_index" in picked_edges and node_shift is not None:
            picked_edges._lazy_init_batch_ptr()
            edge_index = picked_edges["edge_index"]
            edge_owner = picked_edges.batch_idx
            per_edge_shift = node_shift[edge_owner.long()]
            picked_edges._data["edge_index"] = (
                edge_index - per_edge_shift.unsqueeze(1)
            )
        selected["edges"] = picked_edges

    source_system = self._system_group
    if source_system is not None:
        selected["system"] = source_system.select(graph_ids)

    if self.keys:
        keys_copy = {level: names.copy() for level, names in self.keys.items()}
    else:
        keys_copy = None

    return Batch._construct(
        device=self.device,
        keys=keys_copy,
        storage=MultiLevelStorage(
            groups=selected,
            attr_map=self._storage.attr_map,
            validate=False,
        ),
        data_class=self._data_class,
    )
def put(
    self,
    src_batch: Batch,
    mask: Tensor,
    *,
    copied_mask: Tensor | None = None,
    dest_mask: Tensor | None = None,
) -> None:
    """Put graphs where ``mask[i]`` is True from *src_batch* into this buffer.

    A per-level fit mask is computed for every level present in both
    batches (system / atoms / edges); their logical AND becomes the copy
    mask, so a graph is copied only if it fits in every level. The copy is
    then delegated to each level's ``put``.

    Note (carried over from the original documentation; the storage
    implementation is not visible here): the levels use Warp buffer kernels
    and only float32 attributes are copied.

    Parameters
    ----------
    src_batch : Batch
        Source batch; levels missing on either side are skipped.
    mask : Tensor
        ``(num_graphs,)`` bool, True = consider copying this graph.
    copied_mask : Tensor, optional
        ``(num_graphs,)`` bool; if provided, it is updated in place with the
        actual copy mask (fit in all levels), e.g. for :meth:`defrag`.
        If ``None``, the copy mask is stored on *src_batch* as
        ``_copied_mask``.
    dest_mask : Tensor, optional
        For the system level: ``(len(self),)`` bool, True = slot occupied.
        Passed through to the system level unchanged.

    Raises
    ------
    ValueError
        If *mask* or *copied_mask* does not have one entry per source graph.
    """
    device = self.device
    n = src_batch.num_graphs
    if mask.shape[0] != n:
        raise ValueError(f"mask shape {mask.shape[0]} != num_graphs {n}")
    mask = mask.to(device=device, dtype=torch.bool)

    if copied_mask is not None:
        if copied_mask.shape[0] != n:
            raise ValueError(f"copied_mask shape {copied_mask.shape[0]} != {n}")
        # ``.to`` returns the same tensor only if device and dtype already
        # match; otherwise ``copy_mask`` is a private copy (see write-back).
        copy_mask = copied_mask.to(device=device, dtype=torch.bool)
    else:
        copy_mask = torch.zeros(n, device=device, dtype=torch.bool)
        object.__setattr__(src_batch, "_copied_mask", copy_mask)

    fit_mask = torch.ones(n, device=device, dtype=torch.bool)

    system = self._system_group
    src_system = src_batch._system_group
    if system is not None and src_system is not None:
        level_fit = torch.empty(n, device=device, dtype=torch.bool)
        system.compute_put_per_system_fit_mask(
            src_system, mask, dest_mask, level_fit
        )
        fit_mask.logical_and_(level_fit)

    atoms = self._atoms_group
    src_atoms = src_batch._atoms_group
    if atoms is not None and src_atoms is not None:
        level_fit = torch.empty(n, device=device, dtype=torch.bool)
        atoms.compute_put_per_system_fit_mask(src_atoms, mask, None, level_fit)
        fit_mask.logical_and_(level_fit)

    edges = self._edges_group
    src_edges = src_batch._edges_group
    if edges is not None and src_edges is not None:
        level_fit = torch.empty(n, device=device, dtype=torch.bool)
        edges.compute_put_per_system_fit_mask(src_edges, mask, None, level_fit)
        fit_mask.logical_and_(level_fit)

    copy_mask.copy_(fit_mask)

    if system is not None and src_system is not None:
        system.put(
            src_system, copy_mask, copied_mask=copy_mask, dest_mask=dest_mask
        )
    if atoms is not None and src_atoms is not None:
        atoms.put(src_atoms, copy_mask, copied_mask=copy_mask)
    if edges is not None and src_edges is not None:
        edges.put(src_edges, copy_mask, copied_mask=copy_mask)

    # Honour the "updated in place" contract even when the caller's mask
    # lives on another device or has another dtype than the working mask.
    if copied_mask is not None and copy_mask is not copied_mask:
        copied_mask.copy_(copy_mask)
def defrag(
    self,
    copied_mask: Tensor | None = None,
) -> Batch:
    """Defrag this batch in place by removing graphs that were put elsewhere.

    The mask is forwarded to the ``defrag`` method of the system, atoms and
    edges levels (in that order, skipping absent levels); afterwards any
    ``_copied_mask`` stored on this batch by :meth:`put` is discarded.

    Note (carried over from the original documentation; the storage
    implementation is not visible here): the levels use Warp buffer
    kernels, need one host sync per group to trim, and compact only
    float32 attributes.

    Parameters
    ----------
    copied_mask : Tensor, optional
        ``(num_graphs,)`` bool, True = drop this graph; if ``None``, the
        value stored by the last :meth:`put` is used.

    Returns
    -------
    Self
        For method chaining.

    Raises
    ------
    ValueError
        If no mask is given and none is stored on the batch.
    """
    drop_mask = copied_mask
    if drop_mask is None:
        drop_mask = getattr(self, "_copied_mask", None)
    if drop_mask is None:
        raise ValueError("defrag requires copied_mask or a prior put")

    for accessor in ("_system_group", "_atoms_group", "_edges_group"):
        level = getattr(self, accessor)
        if level is not None:
            level.defrag(copied_mask=drop_mask)

    if hasattr(self, "_copied_mask"):
        object.__delattr__(self, "_copied_mask")
    return self
def trim(
    self,
    copied_mask: Tensor | None = None,
) -> Batch | None:
    """Remove marked graphs and return a new :class:`Batch` of the rest.

    Unlike :meth:`defrag`, which works in place on this batch's storage,
    ``trim`` leaves this batch untouched and builds the result with
    :meth:`index_select` over the graphs that are *not* marked.

    Use :meth:`defrag` when the buffer must stay alive for further
    :meth:`put` / :meth:`defrag` cycles (e.g. communication buffers). Use
    ``trim`` when the remaining graphs are to be consumed as a batch of
    their own.

    Parameters
    ----------
    copied_mask : Tensor, optional
        ``(num_graphs,)`` boolean tensor where ``True`` marks graphs to
        remove. If *None*, uses the ``_copied_mask`` stored by the most
        recent :meth:`put`.

    Returns
    -------
    Batch or None
        A new :class:`Batch` containing only the kept graphs, or *None* if
        every graph was removed.

    Raises
    ------
    ValueError
        If no *copied_mask* is provided and no prior :meth:`put` has
        stored one.

    See Also
    --------
    defrag : In-place compaction on the existing storage.
    """
    drop_mask = copied_mask
    if drop_mask is None:
        drop_mask = getattr(self, "_copied_mask", None)
    if drop_mask is None:
        raise ValueError("trim requires copied_mask or a prior put")

    survivors = ~drop_mask
    if survivors.any():
        return self.index_select(torch.where(survivors)[0])
    return None
def _normalize_index( self, idx: int | slice | Tensor | list[int] | np.ndarray | Sequence[int], ) -> list[int]: """Convert various index types to a flat list of integer indices.""" match idx: case int(): result = [idx] case slice(): result = list(range(self.num_graphs)[idx]) case Tensor(): if idx.dtype == torch.bool: result = idx.flatten().nonzero(as_tuple=False).flatten().tolist() elif idx.dtype.is_floating_point: raise IndexError( f"Tensor index must be integer or bool, got {idx.dtype}" ) else: result = idx.flatten().tolist() case np.ndarray(): if idx.dtype == np.bool_: result = idx.flatten().nonzero()[0].flatten().tolist() else: result = idx.flatten().tolist() case list(): result = idx case _ if isinstance(idx, Sequence) and not isinstance(idx, str): result = list(idx) case _: raise IndexError(f"Unsupported index type: {type(idx).__name__}") if not result: raise IndexError("Index is empty") return [self.num_graphs + i if i < 0 else i for i in result] def __getitem__(self, key: str | int | slice | Tensor | list) -> Any: """Access an attribute by name, or select graphs by index. Parameters ---------- key : str or index Attribute name (returns tensor) or graph index (returns :class:`AtomicData` for int, :class:`Batch` for slice/tensor). 
""" match key: case str(): return self._get_attr(key) case int(): return self.get_data(key) case _: return self.index_select(key) def __setitem__(self, key: str, value: Any) -> None: """Set an attribute, routing to the correct group.""" self._storage[key] = value def __contains__(self, key: str) -> bool: return key in self._storage def __len__(self) -> int: return self.num_graphs def __iter__(self) -> Iterator[tuple[str, Any]]: yield from self._storage.items() def __repr__(self) -> str: return ( f"Batch(num_graphs={self.num_graphs}, " f"num_nodes={self.num_nodes}, " f"num_edges={self.num_edges}, " f"device={self.device})" ) def _get_attr(self, key: str) -> Tensor: """Look up *key* across all groups.""" for group in self._storage.groups.values(): if key in group: return group[key] raise KeyError(f"Attribute '{key}' not found in batch") def __getattr__(self, name: str) -> Any: """Delegate unknown attribute access to the storage groups.""" if name.startswith("_") or name in {"device", "keys"}: raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") try: return self._get_attr(name) except KeyError: raise AttributeError( f"'{type(self).__name__}' has no attribute '{name}'" ) from None def __delitem__(self, key: str) -> None: """Delete an attribute from the underlying storage.""" del self._storage[key] # ------------------------------------------------------------------ # Mutation # ------------------------------------------------------------------
def append(self, other: Batch) -> None:
    """Append another batch (this batch is extended in place).

    ``edge_index`` of *other* is shifted by this batch's current node count
    for the duration of the concatenation and restored afterwards, so
    *other* is left unchanged and can be appended again.

    If *other* is missing a group that this batch has (e.g. system-level
    data), this batch's group is extended for the appended graphs via
    ``extend_for_appended_graphs`` so that the per-graph dimension stays
    aligned. Groups that only *other* has are not added.

    Parameters
    ----------
    other : Batch
        Batch to append.
    """
    if other is self:
        # Appending a batch to itself must read from an unmodified copy.
        other = other.clone()

    shifted_edges = None
    original_edge_index = None
    atoms = self._atoms_group
    other_atoms = other._atoms_group
    if atoms is not None and other_atoms is not None:
        total_nodes = atoms.num_elements()
        other_edges = other._edges_group
        if other_edges is not None and "edge_index" in other_edges:
            original_edge_index = other_edges["edge_index"]
            other_edges._data["edge_index"] = original_edge_index + total_nodes
            shifted_edges = other_edges

    n_other = other.num_graphs
    try:
        for group_name, group in self._storage.groups.items():
            other_group = other._storage.groups.get(group_name)
            if other_group is not None:
                group.concatenate(other_group)
            else:
                group.extend_for_appended_graphs(n_other)
    finally:
        # Undo the temporary shift so the caller's batch is not corrupted.
        if shifted_edges is not None:
            shifted_edges._data["edge_index"] = original_edge_index
def append_data(
    self,
    data_list: list[AtomicData],
    exclude_keys: list[str] | None = None,
) -> None:
    """Batch *data_list* on this batch's device and append the result.

    Parameters
    ----------
    data_list : list[AtomicData]
        Data objects to append.
    exclude_keys : list[str], optional
        Keys to exclude when batching *data_list*.

    Raises
    ------
    ValueError
        If *data_list* is empty.
    """
    if len(data_list) == 0:
        raise ValueError("No data provided to append.")

    self.append(
        Batch.from_data_list(
            data_list,
            device=self.device,
            exclude_keys=exclude_keys,
        )
    )
def add_key(
    self,
    key: str,
    values: list[Tensor],
    level: str = "node",
    overwrite: bool = False,
) -> None:
    """Add a new key-value pair to the batch.

    Node- and edge-level values are concatenated along the first dimension;
    system-level values have a leading singleton dimension squeezed and are
    stacked into one row per graph.

    Parameters
    ----------
    key : str
        Name of the new attribute.
    values : list[Tensor]
        One value per graph.
    level : str
        One of ``"node"``, ``"edge"``, ``"system"``.
    overwrite : bool
        If ``True``, overwrite existing keys.

    Raises
    ------
    ValueError
        If *level* is not a known level, if the key exists and *overwrite*
        is ``False``, if the number of values does not match the batch
        size, or if the batch has no storage group for *level*.
    """
    level_to_group = {"node": "atoms", "edge": "edges", "system": "system"}
    # Reject unknown levels before touching any data; they used to fall
    # back to the atoms group silently.
    if level not in level_to_group:
        raise ValueError(
            f"Unknown level '{level}'; expected one of {sorted(level_to_group)}"
        )
    if key in self._storage and not overwrite:
        raise ValueError(
            f"Key '{key}' already exists in batch. "
            "Set overwrite=True to replace existing values."
        )
    if len(values) != self.num_graphs:
        raise ValueError(
            f"Number of values ({len(values)}) must match "
            f"number of graphs in batch ({self.num_graphs})"
        )

    device = self.device
    values = [v.to(device) if isinstance(v, Tensor) else v for v in values]

    group_name = level_to_group[level]
    group = self._storage.groups.get(group_name)
    if group is None:
        raise ValueError(f"Group '{group_name}' not found in batch")

    if level == "system":
        # squeeze (1, *trailing) per-graph to (num_graphs, *trailing)
        squeezed = [
            v.squeeze(0) if v.dim() >= 1 and v.shape[0] == 1 else v for v in values
        ]
        group._data[key] = torch.stack(squeezed, dim=0)
    else:
        group._data[key] = torch.cat(values, dim=0)

    if self.keys is not None:
        # Tolerate a keys mapping that has no entry for this level yet.
        self.keys.setdefault(level, set()).add(key)
# ------------------------------------------------------------------
# DataMixin overrides (performance-critical)
# ------------------------------------------------------------------
def to(
    self,
    device: torch.device | str,
    dtype: torch.dtype | None = None,
    non_blocking: bool = False,
) -> Batch:
    """Return a copy of this batch with all tensors on *device*.

    Overrides :meth:`DataMixin.to`: the batch is cloned and the clone's
    storage is moved with ``to_device``, avoiding a ``model_dump`` /
    ``map_structure`` / ``model_validate`` round-trip.

    Parameters
    ----------
    device : torch.device | str
        Target device.
    dtype : torch.dtype, optional
        Ignored (present for API compatibility).
    non_blocking : bool
        Ignored (present for API compatibility).

    Returns
    -------
    Batch
        The moved copy; this batch is left on its device.
    """
    moved = self.clone()
    moved._storage.to_device(device)
    if isinstance(device, str):
        moved.device = torch.device(device)
    else:
        moved.device = device
    return moved
def clone(self) -> Batch:
    """Return a copy backed by a cloned storage.

    Overrides :meth:`DataMixin.clone` for performance. The ``keys`` mapping
    is rebuilt with copies of its sets, so the clone's bookkeeping is
    independent of this batch.

    Returns
    -------
    Batch
    """
    if self.keys:
        keys_copy = {level: names.copy() for level, names in self.keys.items()}
    else:
        keys_copy = None
    return Batch._construct(
        device=self.device,
        keys=keys_copy,
        storage=self._storage.clone(),
        data_class=self._data_class,
    )
def cpu(self) -> Batch:
    """Return a copy on the CPU (shorthand for ``self.to("cpu")``)."""
    target = "cpu"
    return self.to(target)
def cuda(self, device: int | None = None, non_blocking: bool = False) -> Batch:
    """Return a copy on a CUDA device.

    Parameters
    ----------
    device : int, optional
        CUDA device ordinal; ``None`` targets ``"cuda"`` without an ordinal.
    non_blocking : bool
        Not used by this method (not forwarded to :meth:`to`).

    Returns
    -------
    Batch
    """
    if device is None:
        return self.to("cuda")
    return self.to(f"cuda:{device}")
def contiguous(self) -> Batch:
    """Make every stored tensor contiguous in memory (in place).

    Public wrapper around ``_make_contiguous``.

    Returns
    -------
    Self
        For method chaining.
    """
    self._make_contiguous()
    return self
def pin_memory(self) -> Batch:
    """Replace every stored tensor by its ``pin_memory()`` result.

    Returns
    -------
    Self
        For method chaining.
    """
    for level in self._storage.groups.values():
        # Snapshot the items first: entries are reassigned while looping.
        for name, tensor in tuple(level.items()):
            level._data[name] = tensor.pin_memory()
    return self
def _make_contiguous(self) -> Batch:
    """Replace every non-contiguous stored tensor with a contiguous copy.

    Tensors that are already contiguous are left untouched. Returns self
    for chaining.
    """
    for level_storage in self._storage.groups.values():
        # Collect the offenders first; the group is written to afterwards.
        strided = [
            (name, value)
            for name, value in level_storage.items()
            if not value.is_contiguous()
        ]
        for name, value in strided:
            level_storage._data[name] = value.contiguous()
    return self

# ------------------------------------------------------------------
# Custom serialization
# ------------------------------------------------------------------
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
    """Serialize the batch into a flat dictionary.

    The result starts with the metadata fields (``device``, ``keys``,
    ``batch``, ``ptr``, ``num_graphs``, ``num_nodes_list``,
    ``num_edges_list``) and is then extended with every ``key -> tensor``
    entry of every group in the underlying :class:`MultiLevelStorage`.
    A tensor key that equals an earlier key replaces that earlier value.

    Parameters
    ----------
    **kwargs : Any
        Only ``exclude_none`` is consulted; when truthy, entries whose value
        is ``None`` are dropped from the result.

    Returns
    -------
    dict[str, Any]
    """
    dumped: dict[str, Any] = {
        "device": self.device,
        "keys": self.keys,
        "batch": self.batch,
        "ptr": self.ptr,
        "num_graphs": self.num_graphs,
        "num_nodes_list": self.num_nodes_list,
        "num_edges_list": self.num_edges_list,
    }

    tensors: dict[str, Any] = {}
    for level_storage in self._storage.groups.values():
        for name, tensor in level_storage.items():
            tensors[name] = tensor
    dumped.update(tensors)

    if kwargs.get("exclude_none", False):
        return {name: value for name, value in dumped.items() if value is not None}
    return dumped
# ------------------------------------------------------------------ # Distributed communication # ------------------------------------------------------------------
def isend(
    self,
    dst: int,
    *,
    tag: int = 0,
    group: ProcessGroup | None = None,
) -> _BatchSendHandle:
    """Non-blocking send of this batch to *dst*.

    Transmits a 3-int metadata header (``num_graphs``, ``num_nodes``,
    ``num_edges``), per-group segment lengths for segmented groups, and the
    bulk tensor data via ``TensorDict.isend()``.

    A batch with zero graphs sends only the header.

    Parameters
    ----------
    dst : int
        Destination rank.
    tag : int
        Base message tag. Incremented deterministically per group.
    group : ProcessGroup, optional
        Process group. ``None`` uses the default group.

    Returns
    -------
    _BatchSendHandle
        Handle whose ``.wait()`` blocks until all sends complete.
    """
    handles: list[Work | list[Work] | int | None] = []

    # Header: lets the receiver size its buffers before posting receives.
    meta = torch.tensor(
        [self.num_graphs, self.num_nodes, self.num_edges],
        dtype=torch.int64,
        device=self.device,
    )
    handles.append(dist.isend(meta, dst=dst, tag=tag, group=group))
    tag_offset = 1

    if self.num_graphs == 0:
        return _BatchSendHandle(handles)

    # Segment lengths, one message per segmented group.
    for name in ("atoms", "edges"):
        grp = self._storage.groups.get(name)
        if isinstance(grp, SegmentedLevelStorage):
            # The receiving side (_BatchRecvHandle.wait) allocates int32
            # buffers for segment lengths, so pin the wire dtype here rather
            # than relying on whatever dtype the storage happens to hold.
            seg_len = (
                grp.segment_lengths[: self.num_graphs]
                .to(torch.int32)
                .contiguous()
            )
            handles.append(
                dist.isend(seg_len, dst=dst, tag=tag + tag_offset, group=group)
            )
            tag_offset += 1

    # Bulk data, one TensorDict send per group.
    for name in ("atoms", "edges", "system"):
        grp = self._storage.groups.get(name)
        if grp is None:
            # Keep the tag sequence in step with the receiver, which also
            # advances by one for a missing group.
            tag_offset += 1
            continue
        if isinstance(grp, SegmentedLevelStorage):
            n = grp.num_elements()
        else:
            n = self.num_graphs
        # Only the first n rows are sent.
        occupied_td = grp._data[:n]
        result = occupied_td.isend(
            dst=dst,
            init_tag=tag + tag_offset,
            group=group,
            return_early=True,
        )
        if isinstance(result, list):
            handles.extend(result)
        else:
            handles.append(result)
        # Presumably the TensorDict send consumes one tag per key (TODO
        # confirm); the extra +1 keeps consecutive groups on distinct tags
        # and mirrors the receiver's bookkeeping.
        tag_offset += len(list(grp.keys())) + 1

    return _BatchSendHandle(handles)
@classmethod
def irecv(
    cls,
    src: int,
    device: torch.device | str,
    *,
    template: Batch | None = None,
    tag: int = 0,
    group: ProcessGroup | None = None,
) -> _BatchRecvHandle:
    """Non-blocking receive of a batch from *src*.

    Posts a non-blocking receive for the 3-element int64 metadata header
    only, then returns a :class:`_BatchRecvHandle` whose ``.wait()`` blocks
    until all data arrives and reconstructs a :class:`Batch`.

    Parameters
    ----------
    src : int
        Source rank.
    device : torch.device | str
        Device to receive tensors onto. Strings are converted to
        :class:`torch.device`.
    template : Batch, optional
        Template batch providing attribute keys, dtypes, and group
        structure. Required for the first receive; may be cached by the
        caller for subsequent calls.
    tag : int
        Base message tag.
    group : ProcessGroup, optional
        Process group.

    Returns
    -------
    _BatchRecvHandle
        Handle whose ``.wait()`` returns the received :class:`Batch`.
    """
    if isinstance(device, str):
        device = torch.device(device)
    header = torch.empty(3, dtype=torch.int64, device=device)
    header_work = dist.irecv(header, src=src, tag=tag, group=group)
    return _BatchRecvHandle(
        meta=header,
        meta_handle=header_work,
        src=src,
        device=device,
        template=template,
        base_tag=tag,
        group=group,
    )
def send(
    self,
    dst: int,
    *,
    tag: int = 0,
    group: ProcessGroup | None = None,
) -> None:
    """Blocking send to *dst*.

    Starts the send with :meth:`isend` and immediately waits on the
    returned handle.

    Parameters
    ----------
    dst : int
        Destination rank.
    tag : int
        Base message tag.
    group : ProcessGroup, optional
        Process group.
    """
    pending = self.isend(dst=dst, tag=tag, group=group)
    pending.wait()
@classmethod
def recv(
    cls,
    src: int,
    device: torch.device | str,
    *,
    template: Batch | None = None,
    tag: int = 0,
    group: ProcessGroup | None = None,
) -> Batch:
    """Blocking receive from *src*.

    Starts the receive with :meth:`irecv` and immediately waits on the
    returned handle.

    Parameters
    ----------
    src : int
        Source rank.
    device : torch.device | str
        Device to receive tensors onto.
    template : Batch, optional
        Template batch.
    tag : int
        Base message tag.
    group : ProcessGroup, optional
        Process group.

    Returns
    -------
    Batch
    """
    pending = cls.irecv(
        src=src,
        device=device,
        template=template,
        tag=tag,
        group=group,
    )
    return pending.wait()
@classmethod
def empty_like(
    cls,
    batch: Batch,
    *,
    device: torch.device | str | None = None,
) -> Batch:
    """Create an empty batch (0 graphs) with the same schema as *batch*.

    Delegates to ``cls.empty`` with zero systems, nodes and edges, passing
    *batch* as the template.

    Parameters
    ----------
    batch : Batch
        Template batch for attribute keys and dtypes.
    device : torch.device | str, optional
        Device for the new batch. Defaults to ``batch.device``.

    Returns
    -------
    Batch
        A batch with ``num_graphs == 0``.
    """
    target = batch.device if device is None else device
    return cls.empty(
        num_systems=0,
        num_nodes=0,
        num_edges=0,
        template=batch,
        device=target,
    )
# ======================================================================
# Distributed communication handle classes
# ======================================================================


class _BatchSendHandle:
    """Aggregates multiple async distributed send handles.

    Calling ``.wait()`` blocks until all underlying sends have completed.

    Parameters
    ----------
    handles : list
        A list of ``torch.distributed.Work`` objects (or ``int`` / ``None``
        values which are silently skipped).
    """

    def __init__(self, handles: list) -> None:
        self._handles = handles

    def wait(self) -> None:
        """Block until all sends complete."""
        for h in self._handles:
            # Entries without a ``wait`` method (ints, None) are skipped.
            if h is not None and hasattr(h, "wait"):
                h.wait()


class _BatchRecvHandle:
    """Deferred receive that reconstructs a :class:`Batch` on ``.wait()``.

    Created by :meth:`Batch.irecv`. The metadata header receive is already
    posted; ``.wait()`` blocks on it, then posts and completes the
    segment-length and bulk-data receives.

    Parameters
    ----------
    meta : Tensor
        Pre-allocated ``(3,)`` int64 tensor for the metadata header.
    meta_handle : Work
        Async receive handle for *meta*.
    src : int
        Source rank.
    device : torch.device
        Device to receive tensors onto.
    template : Batch | None
        Template batch for attribute keys and dtypes.
    base_tag : int
        Base message tag (must match sender's *tag*).
    group : ProcessGroup | None
        Process group.
    """

    def __init__(
        self,
        *,
        meta: Tensor,
        meta_handle: Work,
        src: int,
        device: torch.device,
        template: Batch | None,
        base_tag: int,
        group: ProcessGroup | None,
    ) -> None:
        self._meta = meta
        self._meta_handle = meta_handle
        self._src = src
        self._device = device
        self._template = template
        self._base_tag = base_tag
        self._group = group

    def wait(self) -> Batch:
        """Block until all data arrives and return the received :class:`Batch`.

        All receives are posted first, then waited on, and only afterwards
        are the level-storage objects built from the filled buffers.

        Returns
        -------
        Batch
            The reconstructed batch. If the sender sent a sentinel (0-graph
            batch), returns ``Batch.empty(...)`` with 0 capacity.
        """
        self._meta_handle.wait()
        num_graphs, num_nodes, num_edges = (int(v) for v in self._meta.tolist())
        template = self._template

        if num_graphs == 0:
            if template is not None:
                return Batch.empty(
                    num_systems=0,
                    num_nodes=0,
                    num_edges=0,
                    template=template,
                    device=self._device,
                )
            return Batch(device=self._device)

        # NOTE(review): without a template no group receives are posted
        # below, so a non-empty payload is never consumed and the returned
        # batch has no groups -- confirm callers always pass a template.
        template_groups = template._storage.groups if template is not None else {}

        handles: list = []
        tag_offset = 1

        # Segment lengths: one int32 message per segmented group, atoms then
        # edges, each taking the next tag.
        seg_lengths: dict[str, Tensor | None] = {"atoms": None, "edges": None}
        for name in ("atoms", "edges"):
            if isinstance(template_groups.get(name), SegmentedLevelStorage):
                buf = torch.empty(num_graphs, dtype=torch.int32, device=self._device)
                handles.append(
                    dist.irecv(
                        buf,
                        src=self._src,
                        tag=self._base_tag + tag_offset,
                        group=self._group,
                    )
                )
                seg_lengths[name] = buf
                tag_offset += 1

        attr_map = (
            template._storage.attr_map if template is not None else LevelSchema()
        )

        # Bulk data: post every TensorDict receive, remembering what is
        # needed to build the storage once the data has actually arrived.
        pending: list[tuple[str, TensorDict, list, Tensor | None]] = []
        for name, capacity, seg_lens in (
            ("atoms", num_nodes, seg_lengths["atoms"]),
            ("edges", num_edges, seg_lengths["edges"]),
            ("system", num_graphs, None),
        ):
            template_grp = template_groups.get(name)
            if template_grp is None:
                # Sender advances by one tag for a missing group.
                tag_offset += 1
                continue
            keys = list(template_grp.keys())
            if not keys:
                tag_offset += 1
                continue

            recv_data = {}
            for k in keys:
                ref_tensor = template_grp[k]
                trailing_shape = ref_tensor.shape[1:]
                recv_data[k] = torch.empty(
                    (capacity,) + trailing_shape,
                    dtype=ref_tensor.dtype,
                    device=self._device,
                )
            recv_td = TensorDict(recv_data, batch_size=[capacity], device=self._device)
            td_handles = recv_td.irecv(
                src=self._src,
                init_tag=self._base_tag + tag_offset,
                group=self._group,
                return_premature=True,
            )
            if isinstance(td_handles, list):
                handles.extend(td_handles)
            else:
                handles.append(td_handles)
            tag_offset += len(keys) + 1
            pending.append((name, recv_td, keys, seg_lens))

        # Complete every receive before any buffer is handed to a storage
        # constructor, so nothing can observe uninitialised memory.
        for h in handles:
            if h is not None and hasattr(h, "wait"):
                h.wait()

        groups: dict[str, UniformLevelStorage | SegmentedLevelStorage] = {}
        for name, recv_td, keys, seg_lens in pending:
            data = {k: recv_td[k] for k in keys}
            if name == "system":
                groups[name] = UniformLevelStorage(
                    data=data,
                    device=self._device,
                    validate=False,
                    attr_map=attr_map,
                )
            elif seg_lens is not None:
                groups[name] = SegmentedLevelStorage(
                    data=data,
                    segment_lengths=seg_lens,
                    device=self._device,
                    validate=False,
                    attr_map=attr_map,
                )
            # A non-segmented atoms/edges template group has no segment
            # lengths; its data is received but no storage is created.

        mls = MultiLevelStorage(groups=groups, attr_map=attr_map, validate=False)
        return Batch._construct(
            device=self._device,
            keys=(
                {k: v.copy() for k, v in template.keys.items()}
                if template is not None and template.keys is not None
                else None
            ),
            storage=mls,
            data_class=(template._data_class if template is not None else AtomicData),
        )