.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/basic/01_data_structures.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_basic_01_data_structures.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_basic_01_data_structures.py:

AtomicData and Batch: Graph-structured molecular data
=====================================================

This example walks through the full API of :class:`~nvalchemi.data.AtomicData`
and :class:`~nvalchemi.data.Batch`: construction, properties, indexing,
mutation, device movement, and serialization.

.. GENERATED FROM PYTHON SOURCE LINES 23-28

.. code-block:: Python

    import torch

    from nvalchemi.data import AtomicData, Batch

.. GENERATED FROM PYTHON SOURCE LINES 29-33

AtomicData — Construction
-------------------------

:class:`~nvalchemi.data.AtomicData` requires ``positions`` (shape ``[n_nodes, 3]``)
and ``atomic_numbers`` (shape ``[n_nodes]``). All other fields are optional.

.. GENERATED FROM PYTHON SOURCE LINES 33-57

.. code-block:: Python

    positions = torch.randn(4, 3)
    atomic_numbers = torch.tensor([1, 6, 6, 1], dtype=torch.long)
    data = AtomicData(positions=positions, atomic_numbers=atomic_numbers)

    # With edges (e.g. bonds or neighbor list): provide ``edge_index`` shape ``[2, n_edges]``.
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
    data_with_edges = AtomicData(
        positions=positions,
        atomic_numbers=atomic_numbers,
        edge_index=edge_index,
    )
    print(f"With edges: num_edges={data_with_edges.num_edges}")

    # With system-level fields (e.g. energy, cell, pbc for periodicity):
    data_with_system = AtomicData(
        positions=positions,
        atomic_numbers=atomic_numbers,
        energies=torch.tensor([[0.5]]),
        cell=torch.eye(3).unsqueeze(0),
        pbc=torch.tensor([[True, True, False]]),
    )
    print(f"System energies shape: {data_with_system.energies.shape}")
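The construction example above passes a hand-written ``edge_index``. In practice,
edges are often derived from a distance cutoff over the positions; a minimal
pure-``torch`` sketch of such a neighbor list (the ``radius_edges`` helper is
hypothetical and not part of the ``nvalchemi`` API):

.. code-block:: Python

    import torch

    def radius_edges(positions: torch.Tensor, cutoff: float) -> torch.Tensor:
        """Return a [2, n_edges] edge_index of all pairs closer than `cutoff`."""
        dist = torch.cdist(positions, positions)  # pairwise distance matrix
        # keep close pairs, drop self-loops on the diagonal
        mask = (dist < cutoff) & ~torch.eye(len(positions), dtype=torch.bool)
        src, dst = mask.nonzero(as_tuple=True)    # both directions are kept
        return torch.stack([src, dst])

    pos = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
    edge_index = radius_edges(pos, cutoff=2.0)
    print(edge_index.shape)  # torch.Size([2, 2]): edges 0->1 and 1->0

Because both directions survive the mask, the result is bidirectional, like the
hand-written ``edge_index`` in the example above.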
.. rst-class:: sphx-glr-script-out

.. code-block:: none

    With edges: num_edges=4
    System energies shape: torch.Size([1, 1])

.. GENERATED FROM PYTHON SOURCE LINES 58-63

AtomicData — Properties
-----------------------

Core properties: :attr:`~nvalchemi.data.AtomicData.num_nodes`,
:attr:`~nvalchemi.data.AtomicData.num_edges`,
:attr:`~nvalchemi.data.AtomicData.device`,
:attr:`~nvalchemi.data.AtomicData.dtype`.

.. GENERATED FROM PYTHON SOURCE LINES 63-74

.. code-block:: Python

    print(f"num_nodes={data.num_nodes}, num_edges={data.num_edges}")
    print(f"device={data.device}, dtype={data.dtype}")

    # Level-wise property views (dicts of set fields):
    # :attr:`~nvalchemi.data.AtomicData.node_properties`,
    # :attr:`~nvalchemi.data.AtomicData.edge_properties`,
    # :attr:`~nvalchemi.data.AtomicData.system_properties`.
    print("node_properties keys:", list(data.node_properties.keys()))
    print("system_properties keys:", list(data_with_system.system_properties.keys()))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    num_nodes=4, num_edges=0
    device=cpu, dtype=torch.float32
    node_properties keys: ['atomic_numbers', 'positions', 'atomic_masses', 'atom_categories', 'velocities']
    system_properties keys: ['cell', 'pbc', 'energies']

.. GENERATED FROM PYTHON SOURCE LINES 75-79

AtomicData — Dict-like access and mutation
------------------------------------------

Use :meth:`~nvalchemi.data.AtomicData.__getitem__` /
:meth:`~nvalchemi.data.AtomicData.__setitem__` for attribute access by name.

.. GENERATED FROM PYTHON SOURCE LINES 79-96

.. code-block:: Python

    assert data["positions"] is data.positions
    data["positions"] = torch.randn(4, 3)
    assert data.positions.shape == (4, 3)

    # Add custom node/edge/system properties with
    # :meth:`~nvalchemi.data.AtomicData.add_node_property`,
    # :meth:`~nvalchemi.data.AtomicData.add_edge_property`,
    # :meth:`~nvalchemi.data.AtomicData.add_system_property`.
    data.add_node_property("custom_node_feat", torch.randn(4, 2))
    data_with_edges.add_edge_property("edge_weights", torch.ones(data_with_edges.num_edges))
    data_with_system.add_system_property("temperature", torch.tensor([[300.0]]))
    print(
        "After add_*_property, 'custom_node_feat' in node_properties:",
        "custom_node_feat" in data.node_properties,
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    After add_*_property, 'custom_node_feat' in node_properties: True

.. GENERATED FROM PYTHON SOURCE LINES 97-101

AtomicData — Chemical hash and equality
---------------------------------------

:attr:`~nvalchemi.data.AtomicData.chemical_hash` gives a structure/composition
hash; :meth:`~nvalchemi.data.AtomicData.__eq__` compares by chemical hash.

.. GENERATED FROM PYTHON SOURCE LINES 101-109

.. code-block:: Python

    h = data.chemical_hash
    print(f"chemical_hash length: {len(h)}")

    data2 = AtomicData(
        positions=data.positions.clone(), atomic_numbers=data.atomic_numbers.clone()
    )
    print(f"Same structure equal: {data == data2}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    chemical_hash length: 64
    Same structure equal: True

.. GENERATED FROM PYTHON SOURCE LINES 110-114

AtomicData — Device and clone
-----------------------------

:meth:`~nvalchemi.data.AtomicData.to` and :meth:`~nvalchemi.data.data.DataMixin.clone`
(and ``.cpu()`` / ``.cuda()`` from the mixin) for device movement and copying.

.. GENERATED FROM PYTHON SOURCE LINES 114-119

.. code-block:: Python

    on_cpu = data.to("cpu")
    cloned = data.clone()
    print(f"to('cpu').device: {on_cpu.device}, clone is new object: {cloned is not data}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    to('cpu').device: cpu, clone is new object: True

.. GENERATED FROM PYTHON SOURCE LINES 120-124

AtomicData — Serialization
--------------------------

Pydantic serialization: :meth:`~pydantic.BaseModel.model_dump` and
:meth:`~pydantic.BaseModel.model_dump_json` (tensors become lists in JSON).
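The tensors-become-lists behaviour can be illustrated independently of
``nvalchemi``, using only ``torch`` and the standard ``json`` module; this sketch
shows the round-trip that JSON serialization of tensors relies on conceptually:

.. code-block:: Python

    import json

    import torch

    positions = torch.randn(2, 3)
    as_lists = positions.tolist()             # nested Python lists, JSON-safe
    payload = json.dumps({"positions": as_lists})
    restored = torch.tensor(json.loads(payload)["positions"])
    print(torch.allclose(positions, restored))  # True

Float32 values survive the trip exactly, because the intermediate Python floats
are wider than the original dtype.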
.. GENERATED FROM PYTHON SOURCE LINES 124-133

.. code-block:: Python

    data_vanilla = AtomicData(
        positions=torch.randn(2, 3), atomic_numbers=torch.ones(2, dtype=torch.long)
    )
    d = data_vanilla.model_dump(exclude_none=True)
    print("model_dump keys (sample):", list(d.keys())[:4])

    json_str = data_vanilla.model_dump_json()
    print(f"model_dump_json length: {len(json_str)}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    model_dump keys (sample): ['atomic_numbers', 'positions', 'atomic_masses', 'atom_categories']
    model_dump_json length: 701

.. GENERATED FROM PYTHON SOURCE LINES 134-139

Batch — Construction
--------------------

Build a :class:`~nvalchemi.data.Batch` with :meth:`~nvalchemi.data.Batch.from_data_list`.
Optionally pass ``device`` to place the batch, or ``exclude_keys`` to omit
certain attributes.

.. GENERATED FROM PYTHON SOURCE LINES 139-168

.. code-block:: Python

    data_list = [
        AtomicData(
            positions=torch.randn(2, 3),
            atomic_numbers=torch.ones(2, dtype=torch.long),
            energies=torch.tensor([[0.0]]),
        ),
        AtomicData(
            positions=torch.randn(3, 3),
            atomic_numbers=torch.ones(3, dtype=torch.long),
            energies=torch.tensor([[0.0]]),
        ),
        AtomicData(
            positions=torch.randn(1, 3),
            atomic_numbers=torch.ones(1, dtype=torch.long),
            energies=torch.tensor([[0.0]]),
        ),
    ]
    batch = Batch.from_data_list(data_list)

    # exclude_keys: e.g. skip a key when batching
    data_with_extra = AtomicData(
        positions=torch.randn(2, 3),
        atomic_numbers=torch.ones(2, dtype=torch.long),
    )
    data_with_extra.add_node_property("skip_me", torch.zeros(2, 1))
    batch_slim = Batch.from_data_list([data_with_extra], exclude_keys=["skip_me"])

    print(f"Batch num_graphs={batch.num_graphs}, num_nodes={batch.num_nodes}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Batch num_graphs=3, num_nodes=6

.. GENERATED FROM PYTHON SOURCE LINES 169-171

Batch — Size and shape properties
---------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 171-179
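The ``batch`` vector (graph index per node) and the ``ptr`` offsets reported in
this section follow directly from the per-graph node counts. A pure-``torch``
sketch of that relationship, for the ``[2, 3, 1]``-node batch built above
(illustrative only, not the ``Batch`` internals):

.. code-block:: Python

    import torch

    num_nodes_list = [2, 3, 1]                 # nodes per graph, as in the example
    counts = torch.tensor(num_nodes_list)
    # graph index for every node: graph i contributes counts[i] copies of i
    batch_vec = torch.repeat_interleave(torch.arange(len(counts)), counts)
    # ptr: cumulative offsets delimiting each graph's node slice
    ptr = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)])
    print(batch_vec.tolist(), ptr.tolist())  # [0, 0, 1, 1, 1, 2] [0, 2, 5, 6]

Graph ``i``'s nodes therefore live in the slice ``ptr[i]:ptr[i + 1]`` of any
node-level tensor.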
.. code-block:: Python

    print(f"num_graphs={batch.num_graphs}, batch_size={batch.batch_size}")
    print(f"num_nodes_list={batch.num_nodes_list}, num_edges_list={batch.num_edges_list}")
    print(
        f"batch (graph index per node) shape: {batch.batch.shape}, ptr: {batch.ptr.tolist()}"
    )
    print(f"max_num_nodes={batch.max_num_nodes}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    num_graphs=3, batch_size=3
    num_nodes_list=[2, 3, 1], num_edges_list=[]
    batch (graph index per node) shape: torch.Size([6]), ptr: [0, 2, 5, 6]
    max_num_nodes=3

.. GENERATED FROM PYTHON SOURCE LINES 180-182

Batch — Reconstructing graphs
-----------------------------

.. GENERATED FROM PYTHON SOURCE LINES 182-191

.. code-block:: Python

    first = batch.get_data(0)
    last = batch.get_data(-1)
    all_graphs = batch.to_data_list()
    print(
        f"get_data(0).num_nodes={first.num_nodes}, get_data(-1).num_nodes={last.num_nodes}"
    )
    print(f"len(to_data_list())={len(all_graphs)}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    get_data(0).num_nodes=2, get_data(-1).num_nodes=1
    len(to_data_list())=3

.. GENERATED FROM PYTHON SOURCE LINES 192-194

Batch — Indexing (single graph, sub-batch, attribute)
-----------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 194-206

.. code-block:: Python

    one = batch[0]
    sub = batch[1:3]
    sub2 = batch[torch.tensor([0, 2])]
    sub3 = batch[[0, 2]]
    mask = torch.tensor([True, False, True])
    sub4 = batch[mask]
    positions_tensor = batch["positions"]
    print(f"batch[0] num_nodes={one.num_nodes}, batch[1:3] num_graphs={len(sub)}")
    print(f"batch[[0,2]] num_nodes_list={sub3.num_nodes_list}")
    print(f"batch['positions'].shape={positions_tensor.shape}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    batch[0] num_nodes=2, batch[1:3] num_graphs=2
    batch[[0,2]] num_nodes_list=[2, 1]
    batch['positions'].shape=torch.Size([6, 3])

.. GENERATED FROM PYTHON SOURCE LINES 207-209

Batch — Containment, length, iteration
--------------------------------------
.. GENERATED FROM PYTHON SOURCE LINES 209-215

.. code-block:: Python

    print(f"'positions' in batch: {'positions' in batch}")
    print(f"len(batch)={len(batch)}")
    keys_from_iter = [k for k, _ in batch]
    print(f"Keys from iteration (sample): {keys_from_iter[:3]}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    'positions' in batch: True
    len(batch)=3
    Keys from iteration (sample): ['atomic_numbers', 'velocities', 'atom_categories']

.. GENERATED FROM PYTHON SOURCE LINES 216-218

Batch — Setting attributes and adding keys
------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 218-249

.. code-block:: Python

    batch.add_key(
        "node_feat",
        [torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)],
        level="node",
    )
    batch.add_key(
        "temperature",
        [torch.tensor([[300.0]]), torch.tensor([[350.0]]), torch.tensor([[400.0]])],
        level="system",
    )

    data_a = AtomicData(
        positions=torch.randn(2, 3),
        atomic_numbers=torch.ones(2, dtype=torch.long),
        edge_index=torch.tensor([[0], [1]], dtype=torch.long),
    )
    data_b = AtomicData(
        positions=torch.randn(3, 3),
        atomic_numbers=torch.ones(3, dtype=torch.long),
        edge_index=torch.tensor([[0, 1], [1, 0]], dtype=torch.long),
    )
    batch_with_edges = Batch.from_data_list([data_a, data_b])
    batch_with_edges.add_key(
        "edge_attr",
        [torch.randn(1, 4), torch.randn(2, 4)],
        level="edge",
    )
    print(
        f"After add_key: 'node_feat' in batch, 'temperature' in batch, 'edge_attr' in batch_with_edges: {'node_feat' in batch, 'temperature' in batch, 'edge_attr' in batch_with_edges}"
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    After add_key: 'node_feat' in batch, 'temperature' in batch, 'edge_attr' in batch_with_edges: (True, True, True)

.. GENERATED FROM PYTHON SOURCE LINES 250-252

Batch — Append and append_data
------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 252-274
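Conceptually, appending a graph extends the node-level tensors and shifts the
batch bookkeeping. This pure-``torch`` sketch mimics that bookkeeping for the
``[2, 3, 1]``-node batch used throughout; it is an illustration of the idea,
not the actual ``Batch.append`` implementation:

.. code-block:: Python

    import torch

    # bookkeeping for a 3-graph batch with [2, 3, 1] nodes
    batch_vec = torch.tensor([0, 0, 1, 1, 1, 2])
    ptr = torch.tensor([0, 2, 5, 6])
    num_graphs = 3

    # appending one new graph with k nodes: its nodes get the next graph index,
    # and ptr gains one more offset at the end
    k = 2
    batch_vec = torch.cat([batch_vec, torch.full((k,), num_graphs)])
    ptr = torch.cat([ptr, ptr[-1:] + k])
    num_graphs += 1
    print(batch_vec.tolist(), ptr.tolist())  # [0, 0, 1, 1, 1, 2, 3, 3] [0, 2, 5, 6, 8]

The node-level tensors themselves (``positions``, ``atomic_numbers``, ...) would
be concatenated along dimension 0 in the same step.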
.. code-block:: Python

    extra = Batch.from_data_list(
        [
            AtomicData(
                positions=torch.randn(2, 3), atomic_numbers=torch.ones(2, dtype=torch.long)
            ),
        ]
    )
    batch.append(extra)
    print(f"After append: num_graphs={batch.num_graphs}")

    batch.append_data(
        [
            AtomicData(
                positions=torch.randn(1, 3), atomic_numbers=torch.ones(1, dtype=torch.long)
            ),
        ]
    )
    print(
        f"After append_data: num_graphs={batch.num_graphs}, num_nodes_list={batch.num_nodes_list}"
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    After append: num_graphs=4
    After append_data: num_graphs=5, num_nodes_list=[2, 3, 1, 2, 1]

.. GENERATED FROM PYTHON SOURCE LINES 275-277

Batch — put and defrag
----------------------

.. GENERATED FROM PYTHON SOURCE LINES 277-307

.. code-block:: Python

    def _tiny_graph(energy: float):
        return AtomicData(
            positions=torch.randn(2, 3),
            atomic_numbers=torch.ones(2, dtype=torch.long),
            energies=torch.tensor([[energy]]),
        )

    buffer = Batch.empty(
        num_systems=40, num_nodes=80, num_edges=80, template=_tiny_graph(0.0)
    )
    print(
        f"Empty buffer: num_graphs={buffer.num_graphs}, system_capacity={buffer.system_capacity}"
    )

    src_batch = Batch.from_data_list([_tiny_graph(1.0), _tiny_graph(2.0)])
    mask = torch.tensor([True, False])
    copied_mask = torch.zeros(2, dtype=torch.bool)
    dest_mask = torch.zeros(buffer.system_capacity, dtype=torch.bool)
    buffer.put(src_batch, mask, copied_mask=copied_mask, dest_mask=dest_mask)
    print(
        f"After put: buffer has {buffer.num_graphs} graphs; copied_mask={copied_mask.tolist()}"
    )

    src_batch.defrag(copied_mask=copied_mask)
    print(f"After defrag: src_batch has {src_batch.num_graphs} graph(s)")
    print(f"Remaining graph energy: {src_batch['energies']}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Empty buffer: num_graphs=0, system_capacity=40
    After put: buffer has 1 graphs; copied_mask=[True, False]
    After defrag: src_batch has 1 graph(s)
    Remaining graph energy: tensor([[2.],
            [0.]], dtype=torch.float64)
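The ``put``/``defrag`` pattern above amounts to a masked copy into free buffer
slots followed by compaction of the source. On a flat per-system tensor, the
idea can be sketched in pure ``torch`` (an analogy only, not the library's
implementation):

.. code-block:: Python

    import torch

    capacity = 4
    buffer = torch.zeros(capacity)                # preallocated system slots
    occupied = torch.zeros(capacity, dtype=torch.bool)

    src = torch.tensor([1.0, 2.0])
    mask = torch.tensor([True, False])            # which source systems to copy

    # "put": copy the selected source entries into the first free slots
    free = (~occupied).nonzero(as_tuple=True)[0]
    n_copy = int(mask.sum())
    dest = free[:n_copy]
    buffer[dest] = src[mask]
    occupied[dest] = True

    # "defrag": drop the copied entries from the source, keeping the rest packed
    src = src[~mask]
    print(buffer.tolist(), src.tolist())  # [1.0, 0.0, 0.0, 0.0] [2.0]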
.. GENERATED FROM PYTHON SOURCE LINES 308-310

Batch — Device, clone, contiguous, pin_memory
---------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 310-320

.. code-block:: Python

    batch_cpu = batch.to("cpu")
    batch_cloned = batch.clone()
    batch_contig = batch.contiguous()
    batch_pinned = batch.pin_memory()
    print(
        f"to('cpu').device: {batch_cpu.device}, clone is new: {batch_cloned is not batch}"
    )
    print(f"pin_memory: {batch_pinned['positions'].is_pinned()}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    to('cpu').device: cpu, clone is new: True
    pin_memory: True

.. GENERATED FROM PYTHON SOURCE LINES 321-323

Batch — Serialization
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 323-329

.. code-block:: Python

    flat = batch.model_dump()
    print("model_dump keys (sample):", list(flat.keys())[:6])

    flat_slim = batch.model_dump(exclude_none=True)
    print(f"model_dump(exclude_none=True) has 'device': {'device' in flat_slim}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    model_dump keys (sample): ['device', 'keys', 'batch', 'ptr', 'num_graphs', 'num_nodes_list']
    model_dump(exclude_none=True) has 'device': True

.. GENERATED FROM PYTHON SOURCE LINES 330-332

Round-trip summary
------------------

.. GENERATED FROM PYTHON SOURCE LINES 332-341

.. code-block:: Python

    reconstructed = batch.to_data_list()
    batch_again = Batch.from_data_list(reconstructed)
    print(
        f"Round-trip: num_graphs {batch.num_graphs} -> {len(reconstructed)} -> {batch_again.num_graphs}"
    )
    print(
        f"First graph has 'node_feat' after round-trip: {'node_feat' in reconstructed[0].model_dump()}"
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Round-trip: num_graphs 5 -> 5 -> 5
    First graph has 'node_feat' after round-trip: False

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 0.170 seconds)

.. _sphx_glr_download_examples_basic_01_data_structures.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example
    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 01_data_structures.ipynb <01_data_structures.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 01_data_structures.py <01_data_structures.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 01_data_structures.zip <01_data_structures.zip>`

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_