nvalchemi.data.Batch#

class nvalchemi.data.Batch(*, device, storage=None, keys=None)[source]#

Graph-aware batch built on MultiLevelStorage.

Internally stores three attribute groups via a MultiLevelStorage:

  • "atoms" (SegmentedLevelStorage) – node-level tensors

  • "edges" (SegmentedLevelStorage) – edge-level tensors

  • "system" (UniformLevelStorage) – graph-level tensors

batch, ptr, num_nodes_list, and num_edges_list are derived lazily from the segmented groups.
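The lazy derivation can be sketched in plain Python (illustrative only; the actual computation is delegated to the segmented storage groups):

```python
# Per-graph node counts, e.g. three graphs with 3, 2, and 4 atoms.
num_nodes_list = [3, 2, 4]

# ptr: cumulative node offsets, shape (num_graphs + 1,).
ptr = [0]
for n in num_nodes_list:
    ptr.append(ptr[-1] + n)

# batch: per-node graph assignment, shape (num_nodes,).
batch = [g for g, n in enumerate(num_nodes_list) for _ in range(n)]

print(ptr)    # [0, 3, 5, 9]
print(batch)  # [0, 0, 0, 1, 1, 2, 2, 2, 2]
```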

Parameters:
  • device (torch.device | str)

  • storage (MultiLevelStorage | None)

  • keys (dict[str, set[str]] | None)

device#

Device of the underlying storage.

Type:

torch.device

keys#

Level categorisation: {"node": ..., "edge": ..., "system": ...}.

Type:

dict[str, set[str]] | None

add_key(key, values, level='node', overwrite=False)[source]#

Add a new key-value pair to the batch.

Parameters:
  • key (str) – Name of the new attribute.

  • values (list[Tensor]) – One value per graph.

  • level (str) – One of "node", "edge", "system".

  • overwrite (bool) – If True, overwrite existing keys.

Raises:

ValueError – If key exists and overwrite is False, or if the number of values does not match the batch size.

Return type:

None
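A minimal usage sketch (the key name "charges" and the tensor shapes here are illustrative, not part of the API):

>>> charges = [torch.zeros(n) for n in batch.num_nodes_list]  # one value per graph
>>> batch.add_key("charges", charges, level="node")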

append(other)[source]#

Append another batch (in-place via concatenation).

If other is missing a group that this batch has (e.g. system-level data), this batch’s tensors in that group are extended with zeros so that the first dimension (num graphs) stays aligned.

Parameters:

other (Batch) – Batch to append.

Return type:

None

append_data(data_list, exclude_keys=None)[source]#

Append individual AtomicData objects to this batch.

Parameters:
  • data_list (list[AtomicData]) – Data objects to append.

  • exclude_keys (list[str], optional) – Keys to exclude.

Raises:

ValueError – If data_list is empty.

Return type:

None

property batch: Tensor#

Per-node graph assignment tensor (lazily computed).

property batch_size: int#

Alias for num_graphs.

clone()[source]#

Return a deep copy.

Overrides DataMixin.clone() for performance.

Return type:

Batch

contiguous()[source]#

Ensure contiguous memory layout for all tensors.

Returns:

This batch, for method chaining.

Return type:

Self

cpu()[source]#

Return a copy on CPU.

Return type:

Batch

cuda(device=None, non_blocking=False)[source]#

Return a copy on CUDA.

Parameters:
  • device (int | None)

  • non_blocking (bool)

Return type:

Batch

defrag(copied_mask=None)[source]#

Defrag this batch in-place by removing graphs that have been copied out by put().

Drops graphs where copied_mask[i] is True (e.g. from a prior put()). Uses Warp buffer kernels; one host sync per group to trim. Only float32 attributes are compacted.

Parameters:

copied_mask (Tensor, optional) – (num_graphs,) bool; if None, uses stored value from last put().

Returns:

This batch, for method chaining.

Return type:

Self

property edge_ptr: Tensor#

Per-atom CSR pointer into the edge list (N+1,), int32.

Returns a tensor where edge_ptr[i] : edge_ptr[i+1] is the slice of edge rows in edge_index that belong to atom i (i.e. where atom i is the sender). Valid only after a COO-format NeighborListHook has populated the edges group.

An all-zeros pointer of length num_nodes + 1 is returned when the edges group is absent or empty.
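The CSR layout can be illustrated with plain Python lists (a sketch; the real edge_ptr is an int32 tensor on the batch device):

```python
# Sender atom of each edge, sorted by sender (CSR precondition).
senders = [0, 0, 1, 2, 2, 2]  # 6 edges over 3 atoms
num_nodes = 3

# Build the (N+1,) pointer: edge_ptr[i] is the first edge sent by atom i.
edge_ptr = [0] * (num_nodes + 1)
for s in senders:
    edge_ptr[s + 1] += 1
for i in range(num_nodes):
    edge_ptr[i + 1] += edge_ptr[i]

# Edges sent by atom 2 occupy rows edge_ptr[2]:edge_ptr[3].
print(edge_ptr)                          # [0, 2, 3, 6]
print(senders[edge_ptr[2]:edge_ptr[3]])  # [2, 2, 2]
```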

classmethod empty(*, num_systems, num_nodes, num_edges, template=None, device='cpu', attr_map=None)[source]#

Construct an empty batch with pre-allocated capacity (zero graphs, fixed storage).

Storage tensors are allocated with the given capacities; no graphs are stored initially (num_graphs == 0). Use put() to copy graphs into the buffer; pass dest_mask of shape (num_systems,) with False for empty slots.

Parameters:
  • num_systems (int) – Maximum number of systems (graphs) the buffer can hold.

  • num_nodes (int) – Total node (atom) capacity across all graphs.

  • num_edges (int) – Total edge capacity across all graphs.

  • template (AtomicData or Batch, optional) – Template for attribute keys and per-key shapes/dtypes. If None, a minimal AtomicData with positions, atomic_numbers, and energies is used.

  • device (torch.device or str, optional) – Device for allocated tensors.

  • attr_map (LevelSchema, optional) – Attribute registry; used when template is provided.

Returns:

Batch with num_graphs == 0 and capacity for the given sizes.

Return type:

Batch

classmethod empty_like(batch, *, device=None)[source]#

Create an empty batch (0 graphs) with the same schema as batch.

Parameters:
  • batch (Batch) – Template batch for attribute keys and dtypes.

  • device (torch.device | str, optional) – Device for the new batch. Defaults to batch.device.

Returns:

A batch with num_graphs == 0.

Return type:

Batch

classmethod from_data_list(data_list, device=None, skip_validation=False, attr_map=None, exclude_keys=None)[source]#

Construct a batch from a list of AtomicData objects.

Parameters:
  • data_list (list[AtomicData]) – Individual graphs to batch.

  • device (torch.device | str, optional) – Target device. Inferred from data_list if None.

  • skip_validation (bool) – If True, skip shape validation for speed.

  • attr_map (LevelSchema, optional) – Attribute registry. Defaults to LevelSchema().

  • exclude_keys (list[str], optional) – Keys to exclude from batching.

Return type:

Batch

get_data(idx)[source]#

Reconstruct the AtomicData object at position idx.

Edge-index offsets applied during batching are automatically undone.

Parameters:

idx (int) – Graph index (supports negative indexing).

Return type:

AtomicData
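Undoing the offset amounts to subtracting the graph's node offset from its slice of edge_index; a sketch with Python lists:

```python
# ptr gives each graph's node offset in the concatenated node tensor.
ptr = [0, 3, 5]                 # two graphs with 3 and 2 nodes
# Global edge-index pairs for graph 1 (its nodes are 3 and 4 globally).
global_edges = [(3, 4), (4, 3)]

idx = 1
offset = ptr[idx]
# Local edge indices as they appear in the reconstructed AtomicData.
local_edges = [(i - offset, j - offset) for i, j in global_edges]
print(local_edges)              # [(0, 1), (1, 0)]
```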

index_select(idx)[source]#

Select a subset of graphs by index.

Operates directly on concatenated tensors via segment selection – no per-graph AtomicData reconstruction.

Parameters:

idx (int, slice, Tensor, list[int], np.ndarray, or Sequence[int]) – Graph-level index specification.

Return type:

Batch

classmethod irecv(src, device, *, template=None, tag=0, group=None)[source]#

Non-blocking receive of a batch from src.

Posts non-blocking receives for the metadata header, then returns a _BatchRecvHandle whose .wait() blocks until all data arrives and reconstructs a Batch.

Parameters:
  • src (int) – Source rank.

  • device (torch.device | str) – Device to receive tensors onto.

  • template (Batch, optional) – Template batch providing attribute keys, dtypes, and group structure. Required for the first receive; may be cached by the caller for subsequent calls.

  • tag (int) – Base message tag.

  • group (ProcessGroup, optional) – Process group.

Returns:

Handle whose .wait() returns the received Batch.

Return type:

_BatchRecvHandle

isend(dst, *, tag=0, group=None)[source]#

Non-blocking send of this batch to dst.

Transmits a 3-int metadata header (num_graphs, num_nodes, num_edges), per-group segment lengths for segmented groups, and the bulk tensor data via TensorDict.isend().

Parameters:
  • dst (int) – Destination rank.

  • tag (int) – Base message tag. Incremented deterministically per group.

  • group (ProcessGroup, optional) – Process group. None uses the default group.

Returns:

Handle whose .wait() blocks until all sends complete.

Return type:

_BatchSendHandle

property max_num_nodes: int#

Maximum node count in any graph.

model_dump(**kwargs)[source]#

Serialize the batch into a flat dictionary.

Collects all tensors from the underlying MultiLevelStorage groups, plus metadata fields (device, keys, batch, ptr, num_nodes_list, num_edges_list, num_graphs).

Parameters:

kwargs (Any)

Return type:

dict[str, Any]

property num_edges: int#

Total number of edges across all graphs.

property num_edges_list: list[int]#

Per-graph edge counts as a Python list.

property num_edges_per_graph: Tensor#

Per-graph edge counts as a tensor.

property num_graphs: int#

Number of graphs in the batch.

property num_nodes: int#

Total number of nodes across all graphs.

property num_nodes_list: list[int]#

Per-graph node counts as a Python list.

property num_nodes_per_graph: Tensor#

Per-graph node counts as a tensor.

pin_memory()[source]#

Pin all tensors to page-locked memory.

Returns:

This batch, for method chaining.

Return type:

Self

property ptr: Tensor#

Cumulative node count per graph (lazily computed).

put(src_batch, mask, *, copied_mask=None, dest_mask=None)[source]#

Put graphs from src_batch into this batch (a pre-allocated buffer) wherever mask[i] is True.

Computes per-level fit masks (system/atoms/edges) and takes their logical_and as the copy mask, so that every level copies only those systems that fit at all levels. Uses Warp buffer kernels; only float32 attributes are copied. If copied_mask is provided, it is updated with the copy mask for use by defrag().

Parameters:
  • src_batch (Batch) – Source batch; must have same groups (atoms/edges/system).

  • mask (Tensor) – (num_graphs,) bool, True = consider copying this graph.

  • copied_mask (Tensor, optional) – (num_graphs,) bool; if provided, modified in place with the actual copy mask (fit in all levels). If None, stored on src_batch.

  • dest_mask (Tensor, optional) – For uniform (system) level: (len(self),) bool, True = slot occupied. If None, system level treats all slots as empty.

Return type:

None
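The fit-mask selection can be sketched greedily in plain Python (illustrative only; the actual masks are computed per level by Warp kernels):

```python
# Candidate graphs and remaining buffer capacity (sketch values).
mask = [True, True, True, False]   # graphs the caller wants to copy
node_counts = [10, 50, 10, 10]
edge_counts = [20, 80, 20, 20]
free_nodes, free_edges, free_systems = 25, 100, 2

copy_mask = []
for want, n, e in zip(mask, node_counts, edge_counts):
    fits = want and n <= free_nodes and e <= free_edges and free_systems > 0
    copy_mask.append(fits)
    if fits:
        free_nodes -= n
        free_edges -= e
        free_systems -= 1

# Graph 1 exceeds the node budget; graph 3 was never requested.
print(copy_mask)   # [True, False, True, False]
```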

classmethod recv(src, device, *, template=None, tag=0, group=None)[source]#

Blocking receive from src.

Equivalent to cls.irecv(src, device, ...).wait().

Parameters:
  • src (int) – Source rank.

  • device (torch.device | str) – Device to receive tensors onto.

  • template (Batch, optional) – Template batch.

  • tag (int) – Base message tag.

  • group (ProcessGroup, optional) – Process group.

Return type:

Batch

send(dst, *, tag=0, group=None)[source]#

Blocking send to dst.

Equivalent to self.isend(dst, tag=tag, group=group).wait().

Parameters:
  • dst (int) – Destination rank.

  • tag (int) – Base message tag.

  • group (ProcessGroup, optional) – Process group.

Return type:

None

property system_capacity: int#

Maximum number of systems (graphs) this buffer can hold (e.g. from empty()).

to(device, dtype=None, non_blocking=False)[source]#

Move all tensors to device.

Overrides DataMixin.to() for performance: delegates to MultiLevelStorage.to_device() instead of the model_dump / map_structure / model_validate round-trip.

Parameters:
  • device (torch.device | str) – Target device.

  • dtype (torch.dtype, optional) – Ignored (present for API compatibility).

  • non_blocking (bool) – Ignored (present for API compatibility).

Return type:

Batch

to_data_list()[source]#

Reconstruct all individual AtomicData objects.

Return type:

list[AtomicData]

trim(copied_mask=None)[source]#

Remove marked graphs and return a new Batch with tight storage.

Unlike defrag(), which compacts data to the front of pre-allocated buffers while preserving their capacity (ideal for fixed-size GPU buffers that will be reused with put()), trim produces a brand-new Batch whose underlying storage tensors are sized to exactly fit the remaining graphs — no padding, no unused trailing slots.

Use defrag() when you need to keep the buffer alive for further put() / defrag() cycles (e.g. communication buffers). Use trim when the batch will be consumed directly by a model or integrator and must have self-consistent tensor shapes across all storage groups.

Parameters:

copied_mask (Tensor, optional) – (num_graphs,) boolean tensor where True marks graphs to remove. If None, uses the _copied_mask stored by the most recent put().

Returns:

A new Batch containing only the kept graphs with all tensors sized to exactly fit, or None if every graph was removed.

Return type:

Batch or None

Raises:

ValueError – If no copied_mask is provided and no prior put() has stored one.

See also

defrag

In-place compaction that preserves buffer capacity.
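A sketch contrasting the two workflows, assuming buf was created via empty() and src is a source batch whose graphs were copied by put() (which stores the copy mask on src):

>>> buf.put(src, mask)   # mask: (src.num_graphs,) bool
>>> src.defrag()         # in place; src keeps its storage capacity
>>> # ...or, to hand the survivors directly to a model:
>>> tight = src.trim()   # new Batch, tensors sized exactly to fit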

zero()[source]#

Reset this batch to an empty-but-allocated state.

Zeros all leaf data tensors while preserving the allocated storage capacity. After calling zero(), num_graphs returns 0 but system_capacity remains unchanged.

This method is used to reset pre-allocated communication buffers (created via empty()) between pipeline steps without reallocating memory.

Notes

Modeled after GPUBuffer.zero() in nvalchemi.dynamics.sinks. Resets bookkeeping for both UniformLevelStorage (_num_kept) and SegmentedLevelStorage (segment_lengths, _batch_ptr).

Examples

>>> batch = Batch.empty(num_systems=10, num_nodes=100, num_edges=200)
>>> batch.zero()
>>> batch.num_graphs
0
>>> batch.system_capacity
10
Return type:

None