Note: Go to the end of this page to download the full example code.
AtomicData and Batch: Graph-structured molecular data#
This example walks through the full API of AtomicData
and Batch: construction, properties, indexing,
mutation, device movement, and serialization.
import torch
from nvalchemi.data import AtomicData, Batch
AtomicData — Construction#
AtomicData requires positions (shape [n_nodes, 3])
and atomic_numbers (shape [n_nodes]). All other fields are optional.
# Minimal construction: a 4-atom system (H, C, C, H) with random Cartesian
# coordinates. Only ``positions`` ([n_nodes, 3]) and ``atomic_numbers``
# ([n_nodes]) are required.
positions = torch.randn(4, 3)
atomic_numbers = torch.tensor([1, 6, 6, 1], dtype=torch.long)
data = AtomicData(positions=positions, atomic_numbers=atomic_numbers)
# With edges (e.g. bonds or neighbor list): provide ``edge_index`` shape ``[2, n_edges]``.
# Each column is one directed edge; the two bonds (0-1 and 1-2) are listed in
# both directions here, which is why ``num_edges`` prints as 4 below.
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
data_with_edges = AtomicData(
    positions=positions,
    atomic_numbers=atomic_numbers,
    edge_index=edge_index,
)
print(f"With edges: num_edges={data_with_edges.num_edges}")
# With system-level fields (e.g. energy, cell, pbc for periodicity):
# system-level tensors carry a leading batch dimension of 1 (e.g. energies has
# shape [1, 1], cell [1, 3, 3], pbc [1, 3]).
data_with_system = AtomicData(
    positions=positions,
    atomic_numbers=atomic_numbers,
    energies=torch.tensor([[0.5]]),
    cell=torch.eye(3).unsqueeze(0),
    pbc=torch.tensor([[True, True, False]]),
)
print(f"System energies shape: {data_with_system.energies.shape}")
With edges: num_edges=4
System energies shape: torch.Size([1, 1])
AtomicData — Properties#
Core properties: num_nodes,
num_edges, device,
dtype.
# Scalar summary properties of a single AtomicData instance.
print(f"num_nodes={data.num_nodes}, num_edges={data.num_edges}")
print(f"device={data.device}, dtype={data.dtype}")
# Level-wise property views (dicts of set fields):
# :attr:`~nvalchemi.data.AtomicData.node_properties`,
# :attr:`~nvalchemi.data.AtomicData.edge_properties`,
# :attr:`~nvalchemi.data.AtomicData.system_properties`.
print("node_properties keys:", list(data.node_properties.keys()))
print("system_properties keys:", list(data_with_system.system_properties.keys()))
num_nodes=4, num_edges=0
device=cpu, dtype=torch.float32
node_properties keys: ['atomic_numbers', 'positions', 'atomic_masses', 'atom_categories', 'velocities']
system_properties keys: ['cell', 'pbc', 'energies']
AtomicData — Dict-like access and mutation#
Use __getitem__() / __setitem__()
for attribute access by name.
# Dict-style access aliases attribute access: both resolve to the same tensor
# object, and item assignment replaces the attribute.
assert data["positions"] is data.positions
data["positions"] = torch.randn(4, 3)
assert data.positions.shape == (4, 3)
# Add custom node/edge/system properties with
# :meth:`~nvalchemi.data.AtomicData.add_node_property`,
# :meth:`~nvalchemi.data.AtomicData.add_edge_property`,
# :meth:`~nvalchemi.data.AtomicData.add_system_property`.
# Node tensors are sized per atom (4 here), edge tensors per edge, and
# system tensors carry the leading [1, ...] batch dimension.
data.add_node_property("custom_node_feat", torch.randn(4, 2))
data_with_edges.add_edge_property("edge_weights", torch.ones(data_with_edges.num_edges))
data_with_system.add_system_property("temperature", torch.tensor([[300.0]]))
print(
    "After add_*_property, 'custom_node_feat' in node_properties:",
    "custom_node_feat" in data.node_properties,
)
After add_*_property, 'custom_node_feat' in node_properties: True
AtomicData — Chemical hash and equality#
chemical_hash gives a structure/composition hash;
__eq__() compares by chemical hash.
# Equality is defined via the chemical hash, so a freshly constructed copy of
# the same structure compares equal even though its tensors are new objects.
h = data.chemical_hash
print(f"chemical_hash length: {len(h)}")
data2 = AtomicData(
    positions=data.positions.clone(), atomic_numbers=data.atomic_numbers.clone()
)
print(f"Same structure equal: {data == data2}")
chemical_hash length: 64
Same structure equal: True
AtomicData — Device and clone#
to() and clone()
(and .cpu() / .cuda() from the mixin) for device movement and copying.
# ``to`` moves the data to the target device; ``clone`` returns a new object
# (distinct from the original, as asserted by the ``is not`` check below).
on_cpu = data.to("cpu")
cloned = data.clone()
print(f"to('cpu').device: {on_cpu.device}, clone is new object: {cloned is not data}")
to('cpu').device: cpu, clone is new object: True
AtomicData — Serialization#
Pydantic serialization: model_dump() and
model_dump_json() (tensors become lists in JSON).
# Pydantic serialization: ``model_dump`` -> dict of set fields,
# ``model_dump_json`` -> JSON string (tensors become nested lists).
data_vanilla = AtomicData(
    positions=torch.randn(2, 3), atomic_numbers=torch.ones(2, dtype=torch.long)
)
d = data_vanilla.model_dump(exclude_none=True)
print("model_dump keys (sample):", list(d.keys())[:4])
json_str = data_vanilla.model_dump_json()
print(f"model_dump_json length: {len(json_str)}")
model_dump keys (sample): ['atomic_numbers', 'positions', 'atomic_masses', 'atom_categories']
model_dump_json length: 701
Batch — Construction#
Build a Batch with
from_data_list(). Optionally pass device or
exclude_keys to omit certain attributes.
# Three graphs with 2, 3, and 1 atoms respectively, each carrying one
# system-level energy, batched into a single Batch object.
data_list = [
    AtomicData(
        positions=torch.randn(2, 3),
        atomic_numbers=torch.ones(2, dtype=torch.long),
        energies=torch.tensor([[0.0]]),
    ),
    AtomicData(
        positions=torch.randn(3, 3),
        atomic_numbers=torch.ones(3, dtype=torch.long),
        energies=torch.tensor([[0.0]]),
    ),
    AtomicData(
        positions=torch.randn(1, 3),
        atomic_numbers=torch.ones(1, dtype=torch.long),
        energies=torch.tensor([[0.0]]),
    ),
]
batch = Batch.from_data_list(data_list)
# exclude_keys: e.g. skip a key when batching
data_with_extra = AtomicData(
    positions=torch.randn(2, 3),
    atomic_numbers=torch.ones(2, dtype=torch.long),
)
data_with_extra.add_node_property("skip_me", torch.zeros(2, 1))
batch_slim = Batch.from_data_list([data_with_extra], exclude_keys=["skip_me"])
print(f"Batch num_graphs={batch.num_graphs}, num_nodes={batch.num_nodes}")
Batch num_graphs=3, num_nodes=6
Batch — Size and shape properties#
# ``batch.batch`` maps each node to its graph index; ``ptr`` holds cumulative
# node offsets ([0, 2, 5, 6] for node counts [2, 3, 1]).
print(f"num_graphs={batch.num_graphs}, batch_size={batch.batch_size}")
print(f"num_nodes_list={batch.num_nodes_list}, num_edges_list={batch.num_edges_list}")
print(
    f"batch (graph index per node) shape: {batch.batch.shape}, ptr: {batch.ptr.tolist()}"
)
print(f"max_num_nodes={batch.max_num_nodes}")
num_graphs=3, batch_size=3
num_nodes_list=[2, 3, 1], num_edges_list=[]
batch (graph index per node) shape: torch.Size([6]), ptr: [0, 2, 5, 6]
max_num_nodes=3
Batch — Reconstructing graphs#
# ``get_data(i)`` extracts a single graph (negative indices count from the
# end); ``to_data_list`` splits the whole batch back into its graphs.
first = batch.get_data(0)
last = batch.get_data(-1)
all_graphs = batch.to_data_list()
print(
    f"get_data(0).num_nodes={first.num_nodes}, get_data(-1).num_nodes={last.num_nodes}"
)
print(f"len(to_data_list())={len(all_graphs)}")
get_data(0).num_nodes=2, get_data(-1).num_nodes=1
len(to_data_list())=3
Batch — Indexing (single graph, sub-batch, attribute)#
# Indexing flavors: an integer selects a single graph; a slice, index list,
# index tensor, or boolean mask selects a sub-batch; a string key returns the
# concatenated attribute tensor across all graphs.
one = batch[0]
sub = batch[1:3]
sub2 = batch[torch.tensor([0, 2])]
sub3 = batch[[0, 2]]
mask = torch.tensor([True, False, True])
sub4 = batch[mask]
positions_tensor = batch["positions"]
print(f"batch[0] num_nodes={one.num_nodes}, batch[1:3] num_graphs={len(sub)}")
print(f"batch[[0,2]] num_nodes_list={sub3.num_nodes_list}")
print(f"batch['positions'].shape={positions_tensor.shape}")
batch[0] num_nodes=2, batch[1:3] num_graphs=2
batch[[0,2]] num_nodes_list=[2, 1]
batch['positions'].shape=torch.Size([6, 3])
Batch — Containment, length, iteration#
# ``in`` tests key membership, ``len`` is the number of graphs, and iterating
# a Batch yields (key, value) pairs.
print(f"'positions' in batch: {'positions' in batch}")
print(f"len(batch)={len(batch)}")
keys_from_iter = [k for k, _ in batch]
print(f"Keys from iteration (sample): {keys_from_iter[:3]}")
'positions' in batch: True
len(batch)=3
Keys from iteration (sample): ['atomic_numbers', 'velocities', 'atom_categories']
Batch — Setting attributes and adding keys#
# ``add_key`` attaches a new attribute from a per-graph list of tensors;
# ``level`` declares whether tensors are per-node, per-edge, or per-system.
# Node-level tensors must match each graph's node count ([2, 4], [3, 4], [1, 4]
# for this batch's 2/3/1-node graphs).
batch.add_key(
    "node_feat",
    [torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)],
    level="node",
)
batch.add_key(
    "temperature",
    [torch.tensor([[300.0]]), torch.tensor([[350.0]]), torch.tensor([[400.0]])],
    level="system",
)
# Edge-level keys need a batch whose graphs actually carry an ``edge_index``,
# so build one with 1-edge and 2-edge graphs first.
data_a = AtomicData(
    positions=torch.randn(2, 3),
    atomic_numbers=torch.ones(2, dtype=torch.long),
    edge_index=torch.tensor([[0], [1]], dtype=torch.long),
)
data_b = AtomicData(
    positions=torch.randn(3, 3),
    atomic_numbers=torch.ones(3, dtype=torch.long),
    edge_index=torch.tensor([[0, 1], [1, 0]], dtype=torch.long),
)
batch_with_edges = Batch.from_data_list([data_a, data_b])
batch_with_edges.add_key(
    "edge_attr",
    [torch.randn(1, 4), torch.randn(2, 4)],
    level="edge",
)
print(
    f"After add_key: 'node_feat' in batch, 'temperature' in batch, 'edge_attr' in batch_with_edges: {'node_feat' in batch, 'temperature' in batch, 'edge_attr' in batch_with_edges}"
)
After add_key: 'node_feat' in batch, 'temperature' in batch, 'edge_attr' in batch_with_edges: (True, True, True)
Batch — Append and append_data#
# ``append`` merges another Batch into this one in place; ``append_data``
# does the same starting from a list of AtomicData objects.
extra = Batch.from_data_list(
    [
        AtomicData(
            positions=torch.randn(2, 3), atomic_numbers=torch.ones(2, dtype=torch.long)
        ),
    ]
)
batch.append(extra)
print(f"After append: num_graphs={batch.num_graphs}")
batch.append_data(
    [
        AtomicData(
            positions=torch.randn(1, 3), atomic_numbers=torch.ones(1, dtype=torch.long)
        ),
    ]
)
print(
    f"After append_data: num_graphs={batch.num_graphs}, num_nodes_list={batch.num_nodes_list}"
)
After append: num_graphs=4
After append_data: num_graphs=5, num_nodes_list=[2, 3, 1, 2, 1]
Batch — put and defrag#
def _tiny_graph(energy: float):
    """Build a two-atom system (atomic number 1) with one system energy value.

    Serves as the ``template`` for :meth:`Batch.empty` and as source data for
    the ``put``/``defrag`` demonstration.
    """
    system_energy = torch.tensor([[energy]])
    return AtomicData(
        atomic_numbers=torch.ones(2, dtype=torch.long),
        positions=torch.randn(2, 3),
        energies=system_energy,
    )
# Pre-allocate a fixed-capacity buffer; ``template`` supplies the per-graph
# schema (fields, dtypes) without adding any graphs (num_graphs stays 0).
buffer = Batch.empty(
    num_systems=40, num_nodes=80, num_edges=80, template=_tiny_graph(0.0)
)
print(
    f"Empty buffer: num_graphs={buffer.num_graphs}, system_capacity={buffer.system_capacity}"
)
# ``put`` copies the graphs selected by ``mask`` from ``src_batch`` into the
# buffer. ``copied_mask`` is updated in place to flag which sources were
# copied (it prints [True, False] below); ``dest_mask`` presumably tracks the
# filled destination slots the same way — confirm against the API docs.
src_batch = Batch.from_data_list([_tiny_graph(1.0), _tiny_graph(2.0)])
mask = torch.tensor([True, False])
copied_mask = torch.zeros(2, dtype=torch.bool)
dest_mask = torch.zeros(buffer.system_capacity, dtype=torch.bool)
buffer.put(src_batch, mask, copied_mask=copied_mask, dest_mask=dest_mask)
print(
    f"After put: buffer has {buffer.num_graphs} graphs; copied_mask={copied_mask.tolist()}"
)
# ``defrag`` drops the graphs already copied out, compacting the source batch
# down to the remaining (energy 2.0) graph.
src_batch.defrag(copied_mask=copied_mask)
print(f"After defrag: src_batch has {src_batch.num_graphs} graph(s)")
print(f"Remaining graph energy: {src_batch['energies']}")
Empty buffer: num_graphs=0, system_capacity=40
After put: buffer has 1 graphs; copied_mask=[True, False]
After defrag: src_batch has 1 graph(s)
Remaining graph energy: tensor([[2.],
[0.]], dtype=torch.float64)
Batch — Device, clone, contiguous, pin_memory#
# Tensor-style device and memory helpers: ``to``/``clone``/``contiguous``/
# ``pin_memory`` mirror their torch.Tensor counterparts at the Batch level.
batch_cpu = batch.to("cpu")
batch_cloned = batch.clone()
batch_contig = batch.contiguous()
batch_pinned = batch.pin_memory()
print(
    f"to('cpu').device: {batch_cpu.device}, clone is new: {batch_cloned is not batch}"
)
print(f"pin_memory: {batch_pinned['positions'].is_pinned()}")
to('cpu').device: cpu, clone is new: True
pin_memory: True
Batch — Serialization#
# ``model_dump`` flattens the whole batch (including bookkeeping fields such
# as 'device', 'keys', 'batch', 'ptr') into a plain dict.
flat = batch.model_dump()
print("model_dump keys (sample):", list(flat.keys())[:6])
flat_slim = batch.model_dump(exclude_none=True)
print(f"model_dump(exclude_none=True) has 'device': {'device' in flat_slim}")
model_dump keys (sample): ['device', 'keys', 'batch', 'ptr', 'num_graphs', 'num_nodes_list']
model_dump(exclude_none=True) has 'device': True
Round-trip summary#
# Round trip: to_data_list -> from_data_list preserves the graph count.
# NOTE: the custom 'node_feat' key added via ``add_key`` does not survive the
# round trip (the final print shows False).
reconstructed = batch.to_data_list()
batch_again = Batch.from_data_list(reconstructed)
print(
    f"Round-trip: num_graphs {batch.num_graphs} -> {len(reconstructed)} -> {batch_again.num_graphs}"
)
print(
    f"First graph has 'node_feat' after round-trip: {'node_feat' in reconstructed[0].model_dump()}"
)
Round-trip: num_graphs 5 -> 5 -> 5
First graph has 'node_feat' after round-trip: False
Total running time of the script: (0 minutes 0.170 seconds)