AtomicData and Batch: Graph-structured molecular data#

This example walks through the full API of AtomicData and Batch: construction, properties, indexing, mutation, device movement, and serialization.

import torch

from nvalchemi.data import AtomicData, Batch

AtomicData — Construction#

AtomicData requires positions (shape [n_nodes, 3]) and atomic_numbers (shape [n_nodes]). All other fields are optional.

positions = torch.randn(4, 3)
atomic_numbers = torch.tensor([1, 6, 6, 1], dtype=torch.long)
data = AtomicData(positions=positions, atomic_numbers=atomic_numbers)

# With edges (e.g. bonds or neighbor list): provide ``edge_index`` shape ``[2, n_edges]``.
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
data_with_edges = AtomicData(
    positions=positions,
    atomic_numbers=atomic_numbers,
    edge_index=edge_index,
)
print(f"With edges: num_edges={data_with_edges.num_edges}")

# With system-level fields (e.g. energy, cell, pbc for periodicity):
data_with_system = AtomicData(
    positions=positions,
    atomic_numbers=atomic_numbers,
    energies=torch.tensor([[0.5]]),
    cell=torch.eye(3).unsqueeze(0),
    pbc=torch.tensor([[True, True, False]]),
)
print(f"System energies shape: {data_with_system.energies.shape}")
With edges: num_edges=4
System energies shape: torch.Size([1, 1])

AtomicData — Properties#

Core properties: num_nodes, num_edges, device, dtype.

print(f"num_nodes={data.num_nodes}, num_edges={data.num_edges}")
print(f"device={data.device}, dtype={data.dtype}")

# Level-wise property views (dicts of set fields):
# :attr:`~nvalchemi.data.AtomicData.node_properties`,
# :attr:`~nvalchemi.data.AtomicData.edge_properties`,
# :attr:`~nvalchemi.data.AtomicData.system_properties`.
print("node_properties keys:", list(data.node_properties.keys()))
print("system_properties keys:", list(data_with_system.system_properties.keys()))
num_nodes=4, num_edges=0
device=cpu, dtype=torch.float32
node_properties keys: ['atomic_numbers', 'positions', 'atomic_masses', 'atom_categories', 'velocities']
system_properties keys: ['cell', 'pbc', 'energies']

AtomicData — Dict-like access and mutation#

Use __getitem__() / __setitem__() for dict-style access to attributes by name.

assert data["positions"] is data.positions
data["positions"] = torch.randn(4, 3)
assert data.positions.shape == (4, 3)

# Add custom node/edge/system properties with
# :meth:`~nvalchemi.data.AtomicData.add_node_property`,
# :meth:`~nvalchemi.data.AtomicData.add_edge_property`,
# :meth:`~nvalchemi.data.AtomicData.add_system_property`.
data.add_node_property("custom_node_feat", torch.randn(4, 2))
data_with_edges.add_edge_property("edge_weights", torch.ones(data_with_edges.num_edges))
data_with_system.add_system_property("temperature", torch.tensor([[300.0]]))
print(
    "After add_*_property, 'custom_node_feat' in node_properties:",
    "custom_node_feat" in data.node_properties,
)
After add_*_property, 'custom_node_feat' in node_properties: True

AtomicData — Chemical hash and equality#

chemical_hash returns a hash of the structure and composition; __eq__() compares two AtomicData objects by this hash.

h = data.chemical_hash
print(f"chemical_hash length: {len(h)}")
data2 = AtomicData(
    positions=data.positions.clone(), atomic_numbers=data.atomic_numbers.clone()
)
print(f"Same structure equal: {data == data2}")
chemical_hash length: 64
Same structure equal: True
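The library's exact hashing scheme is not shown here, but the 64-character length is consistent with a SHA-256 hex digest. As an illustration only (this is not AtomicData.chemical_hash's actual algorithm, and toy_structure_hash is a hypothetical helper), a structure/composition hash can be sketched like this:

```python
import hashlib

import torch


def toy_structure_hash(positions: torch.Tensor, atomic_numbers: torch.Tensor) -> str:
    """Hypothetical sketch: hash composition + geometry into a hex digest.

    NOT AtomicData.chemical_hash's real algorithm -- only shows how a
    64-character (SHA-256 hex) structure hash can arise.
    """
    payload = atomic_numbers.numpy().tobytes() + positions.numpy().tobytes()
    return hashlib.sha256(payload).hexdigest()


pos = torch.randn(4, 3)
nums = torch.tensor([1, 6, 6, 1], dtype=torch.long)
h1 = toy_structure_hash(pos, nums)
h2 = toy_structure_hash(pos.clone(), nums.clone())
print(len(h1), h1 == h2)  # 64 True: identical structures hash identically
```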

AtomicData — Device and clone#

Use to() and clone() for device movement and copying; .cpu() and .cuda() are also available from the mixin.

on_cpu = data.to("cpu")
cloned = data.clone()
print(f"to('cpu').device: {on_cpu.device}, clone is new object: {cloned is not data}")
to('cpu').device: cpu, clone is new object: True

AtomicData — Serialization#

Pydantic serialization: model_dump() and model_dump_json() (tensors become lists in JSON).

data_vanilla = AtomicData(
    positions=torch.randn(2, 3), atomic_numbers=torch.ones(2, dtype=torch.long)
)
d = data_vanilla.model_dump(exclude_none=True)
print("model_dump keys (sample):", list(d.keys())[:4])
json_str = data_vanilla.model_dump_json()
print(f"model_dump_json length: {len(json_str)}")
model_dump keys (sample): ['atomic_numbers', 'positions', 'atomic_masses', 'atom_categories']
model_dump_json length: 701
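The key point of model_dump_json() is that tensors are serialized as nested lists. A minimal standalone illustration of that tensor-to-JSON round trip (using plain torch and the json module, not AtomicData's internals):

```python
import json

import torch

# Tensors become nested Python lists via .tolist(), which json can encode.
record = {
    "atomic_numbers": torch.tensor([1, 6]).tolist(),
    "positions": torch.zeros(2, 3).tolist(),
}
json_str = json.dumps(record)

# Round trip: decode the JSON and rebuild tensors from the lists.
restored = json.loads(json_str)
positions = torch.tensor(restored["positions"])
print(positions.shape)  # torch.Size([2, 3])
```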

Batch — Construction#

Build a Batch with from_data_list(). Optionally pass device to place the batch, or exclude_keys to omit certain attributes.

data_list = [
    AtomicData(
        positions=torch.randn(2, 3),
        atomic_numbers=torch.ones(2, dtype=torch.long),
        energies=torch.tensor([[0.0]]),
    ),
    AtomicData(
        positions=torch.randn(3, 3),
        atomic_numbers=torch.ones(3, dtype=torch.long),
        energies=torch.tensor([[0.0]]),
    ),
    AtomicData(
        positions=torch.randn(1, 3),
        atomic_numbers=torch.ones(1, dtype=torch.long),
        energies=torch.tensor([[0.0]]),
    ),
]
batch = Batch.from_data_list(data_list)

# exclude_keys: e.g. skip a key when batching
data_with_extra = AtomicData(
    positions=torch.randn(2, 3),
    atomic_numbers=torch.ones(2, dtype=torch.long),
)
data_with_extra.add_node_property("skip_me", torch.zeros(2, 1))
batch_slim = Batch.from_data_list([data_with_extra], exclude_keys=["skip_me"])
print(f"Batch num_graphs={batch.num_graphs}, num_nodes={batch.num_nodes}")
Batch num_graphs=3, num_nodes=6

Batch — Size and shape properties#

print(f"num_graphs={batch.num_graphs}, batch_size={batch.batch_size}")
print(f"num_nodes_list={batch.num_nodes_list}, num_edges_list={batch.num_edges_list}")
print(
    f"batch (graph index per node) shape: {batch.batch.shape}, ptr: {batch.ptr.tolist()}"
)
print(f"max_num_nodes={batch.max_num_nodes}")
num_graphs=3, batch_size=3
num_nodes_list=[2, 3, 1], num_edges_list=[]
batch (graph index per node) shape: torch.Size([6]), ptr: [0, 2, 5, 6]
max_num_nodes=3
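The batch vector and ptr follow the usual graph-batching convention: batch[i] is the graph index of node i, and ptr holds cumulative node offsets. A standalone torch sketch (independent of nvalchemi) showing how they are built for node counts [2, 3, 1] and how they enable per-graph reductions:

```python
import torch

# batch[i] = graph index of node i; ptr = cumulative node counts per graph.
num_nodes_list = [2, 3, 1]
batch_vec = torch.repeat_interleave(
    torch.arange(len(num_nodes_list)), torch.tensor(num_nodes_list)
)
ptr = torch.cat(
    [torch.zeros(1, dtype=torch.long), torch.tensor(num_nodes_list).cumsum(0)]
)
print(batch_vec.tolist())  # [0, 0, 1, 1, 1, 2]
print(ptr.tolist())        # [0, 2, 5, 6]

# Per-graph sums of a node feature via index_add (a scatter-style reduction):
node_feat = torch.ones(6, 3)
per_graph = torch.zeros(3, 3).index_add(0, batch_vec, node_feat)
print(per_graph[:, 0].tolist())  # [2.0, 3.0, 1.0] -- node count per graph
```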

Batch — Reconstructing graphs#

first = batch.get_data(0)
last = batch.get_data(-1)
all_graphs = batch.to_data_list()
print(
    f"get_data(0).num_nodes={first.num_nodes}, get_data(-1).num_nodes={last.num_nodes}"
)
print(f"len(to_data_list())={len(all_graphs)}")
get_data(0).num_nodes=2, get_data(-1).num_nodes=1
len(to_data_list())=3

Batch — Indexing (single graph, sub-batch, attribute)#

one = batch[0]
sub = batch[1:3]
sub2 = batch[torch.tensor([0, 2])]
sub3 = batch[[0, 2]]
mask = torch.tensor([True, False, True])
sub4 = batch[mask]
positions_tensor = batch["positions"]
print(f"batch[0] num_nodes={one.num_nodes}, batch[1:3] num_graphs={len(sub)}")
print(f"batch[[0,2]] num_nodes_list={sub3.num_nodes_list}")
print(f"batch['positions'].shape={positions_tensor.shape}")
batch[0] num_nodes=2, batch[1:3] num_graphs=2
batch[[0,2]] num_nodes_list=[2, 1]
batch['positions'].shape=torch.Size([6, 3])
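Graph-level indexing such as batch[[0, 2]] implies selecting exactly the node rows that belong to the chosen graphs. A standalone torch sketch of that selection, driven only by the batch vector (not the library's implementation):

```python
import torch

# Pick out node rows belonging to selected graphs, using only the
# graph-index-per-node vector.
batch_vec = torch.tensor([0, 0, 1, 1, 1, 2])  # 3 graphs with 2, 3, 1 nodes
positions = torch.arange(18.0).reshape(6, 3)

keep_graphs = torch.tensor([0, 2])
node_mask = torch.isin(batch_vec, keep_graphs)  # True for nodes of graphs 0, 2
sub_positions = positions[node_mask]
print(node_mask.tolist())   # [True, True, False, False, False, True]
print(sub_positions.shape)  # torch.Size([3, 3])
```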

Batch — Containment, length, iteration#

print(f"'positions' in batch: {'positions' in batch}")
print(f"len(batch)={len(batch)}")
keys_from_iter = [k for k, _ in batch]
print(f"Keys from iteration (sample): {keys_from_iter[:3]}")
'positions' in batch: True
len(batch)=3
Keys from iteration (sample): ['atomic_numbers', 'velocities', 'atom_categories']

Batch — Setting attributes and adding keys#

batch.add_key(
    "node_feat",
    [torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)],
    level="node",
)
batch.add_key(
    "temperature",
    [torch.tensor([[300.0]]), torch.tensor([[350.0]]), torch.tensor([[400.0]])],
    level="system",
)
data_a = AtomicData(
    positions=torch.randn(2, 3),
    atomic_numbers=torch.ones(2, dtype=torch.long),
    edge_index=torch.tensor([[0], [1]], dtype=torch.long),
)
data_b = AtomicData(
    positions=torch.randn(3, 3),
    atomic_numbers=torch.ones(3, dtype=torch.long),
    edge_index=torch.tensor([[0, 1], [1, 0]], dtype=torch.long),
)
batch_with_edges = Batch.from_data_list([data_a, data_b])
batch_with_edges.add_key(
    "edge_attr",
    [torch.randn(1, 4), torch.randn(2, 4)],
    level="edge",
)
print(
    f"After add_key: 'node_feat' in batch, 'temperature' in batch, 'edge_attr' in batch_with_edges: {'node_feat' in batch, 'temperature' in batch, 'edge_attr' in batch_with_edges}"
)
After add_key: 'node_feat' in batch, 'temperature' in batch, 'edge_attr' in batch_with_edges: (True, True, True)

Batch — Append and append_data#

extra = Batch.from_data_list(
    [
        AtomicData(
            positions=torch.randn(2, 3), atomic_numbers=torch.ones(2, dtype=torch.long)
        ),
    ]
)
batch.append(extra)
print(f"After append: num_graphs={batch.num_graphs}")

batch.append_data(
    [
        AtomicData(
            positions=torch.randn(1, 3), atomic_numbers=torch.ones(1, dtype=torch.long)
        ),
    ]
)
print(
    f"After append_data: num_graphs={batch.num_graphs}, num_nodes_list={batch.num_nodes_list}"
)
After append: num_graphs=4
After append_data: num_graphs=5, num_nodes_list=[2, 3, 1, 2, 1]

Batch — put and defrag#

put() copies the systems selected by mask from a source batch into free slots of a preallocated buffer (created with Batch.empty()), marking the transferred systems in copied_mask; defrag() then compacts the source batch by dropping the systems that were copied out.

def _tiny_graph(energy: float):
    return AtomicData(
        positions=torch.randn(2, 3),
        atomic_numbers=torch.ones(2, dtype=torch.long),
        energies=torch.tensor([[energy]]),
    )


buffer = Batch.empty(
    num_systems=40, num_nodes=80, num_edges=80, template=_tiny_graph(0.0)
)
print(
    f"Empty buffer: num_graphs={buffer.num_graphs}, system_capacity={buffer.system_capacity}"
)

src_batch = Batch.from_data_list([_tiny_graph(1.0), _tiny_graph(2.0)])
mask = torch.tensor([True, False])
copied_mask = torch.zeros(2, dtype=torch.bool)
dest_mask = torch.zeros(buffer.system_capacity, dtype=torch.bool)
buffer.put(src_batch, mask, copied_mask=copied_mask, dest_mask=dest_mask)
print(
    f"After put: buffer has {buffer.num_graphs} graphs; copied_mask={copied_mask.tolist()}"
)

src_batch.defrag(copied_mask=copied_mask)
print(f"After defrag: src_batch has {src_batch.num_graphs} graph(s)")
print(f"Remaining graph energy: {src_batch['energies']}")
Empty buffer: num_graphs=0, system_capacity=40
After put: buffer has 1 graphs; copied_mask=[True, False]
After defrag: src_batch has 1 graph(s)
Remaining graph energy: tensor([[2.],
        [0.]], dtype=torch.float64)
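The compaction that defrag() performs can be pictured with plain torch: drop the systems flagged in copied_mask and keep the rest contiguous. A standalone sketch of that idea (not Batch.defrag's internals):

```python
import torch

# Mask-based compaction: system 0 was already copied into the buffer,
# so only system 1 (energy 2.0) survives.
energies = torch.tensor([[1.0], [2.0]])
num_nodes_list = torch.tensor([2, 2])
copied_mask = torch.tensor([True, False])

keep = ~copied_mask
compacted_energies = energies[keep]
compacted_nodes = num_nodes_list[keep]
print(compacted_energies.tolist())  # [[2.0]]
print(compacted_nodes.tolist())     # [2]
```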

Batch — Device, clone, contiguous, pin_memory#

batch_cpu = batch.to("cpu")
batch_cloned = batch.clone()
batch_contig = batch.contiguous()
batch_pinned = batch.pin_memory()
print(
    f"to('cpu').device: {batch_cpu.device}, clone is new: {batch_cloned is not batch}"
)
print(f"pin_memory: {batch_pinned['positions'].is_pinned()}")
to('cpu').device: cpu, clone is new: True
pin_memory: True

Batch — Serialization#

flat = batch.model_dump()
print("model_dump keys (sample):", list(flat.keys())[:6])
flat_slim = batch.model_dump(exclude_none=True)
print(f"model_dump(exclude_none=True) has 'device': {'device' in flat_slim}")
model_dump keys (sample): ['device', 'keys', 'batch', 'ptr', 'num_graphs', 'num_nodes_list']
model_dump(exclude_none=True) has 'device': True

Round-trip summary#

reconstructed = batch.to_data_list()
batch_again = Batch.from_data_list(reconstructed)
print(
    f"Round-trip: num_graphs {batch.num_graphs} -> {len(reconstructed)} -> {batch_again.num_graphs}"
)
print(
    f"First graph has 'node_feat' after round-trip: {'node_feat' in reconstructed[0].model_dump()}"
)
Round-trip: num_graphs 5 -> 5 -> 5
First graph has 'node_feat' after round-trip: False

Total running time of the script: (0 minutes 0.170 seconds)