.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/intermediate/04_inflight_batching.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_intermediate_04_inflight_batching.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_intermediate_04_inflight_batching.py:


Processing Large Datasets with Inflight Batching
================================================

**The problem**: a production relaxation campaign may require running
thousands of structures through FIRE → NVT stages. Loading all of them into
GPU memory at once is impossible; processing one at a time wastes GPU
throughput.

**The solution**: *inflight batching* (Mode 2 in
:class:`~nvalchemi.dynamics.FusedStage`). A fixed-size *live batch* occupies
the GPU at all times. Whenever a system graduates (converges or exhausts its
step budget), it is evicted and a new sample from the dataset takes its slot.
The :class:`~nvalchemi.dynamics.SizeAwareSampler` handles bin-packing so the
replacement fits within the memory envelope of the slot it replaces.

Lifecycle of a system::

    Dataset → SizeAwareSampler.request_replacement()
        │
        ▼
    Live batch (GPU) — stage 0 (FIRE relaxation)
        │ converges
        ▼
    Live batch (GPU) — stage 1 (NVT equilibration)
        │ n_steps exhausted
        ▼
    ConvergedSnapshotHook → HostMemory sink
        │
        ▼
    Slot freed → next sample loaded from dataset

Key concept: **system_id**. Each sample loaded from the dataset receives a
monotonically increasing integer ``system_id`` stamped by the sampler. This
lets downstream code track individual trajectories across refill events.

This example uses :class:`~nvalchemi.models.demo.DemoModelWrapper` (a small
neural network) so no neighbor list is needed. For a real LJ or MACE model,
add :class:`~nvalchemi.dynamics.hooks.NeighborListHook` to each sub-stage.

.. GENERATED FROM PYTHON SOURCE LINES 57-71
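The lifecycle above can be sketched as a toy loop in plain Python. This is an
illustrative simplification only, not the actual ``FusedStage`` or
``SizeAwareSampler`` implementation; every name in it (``run_inflight``,
``refill``, the dict-based live batch) is hypothetical:

```python
# Toy sketch of inflight batching: keep a live batch as full as the atom
# budget allows; evict graduated systems and refill the freed capacity from
# the unconsumed part of the dataset. Not the real FusedStage logic.

def run_inflight(atom_counts, max_atoms, budget_per_system):
    pending = list(range(len(atom_counts)))  # dataset indices not yet loaded
    live = {}                                # system_id -> (dataset idx, steps left)
    graduated = []                           # system_ids in graduation order
    next_id = 0                              # system_id is stamped monotonically

    def refill():
        nonlocal next_id
        used = sum(atom_counts[idx] for idx, _ in live.values())
        for i in list(pending):              # admit pending samples that fit
            if used + atom_counts[i] <= max_atoms:
                pending.remove(i)
                live[next_id] = (i, budget_per_system)
                used += atom_counts[i]
                next_id += 1

    refill()                                 # build the initial live batch
    while live:
        for sid in list(live):               # advance every live system one step
            idx, left = live[sid]
            if left <= 1:
                graduated.append(sid)        # graduates: evict, slot freed
                del live[sid]
            else:
                live[sid] = (idx, left - 1)
        refill()                             # backfill freed capacity
    return graduated
```

The real pipeline replaces the fixed ``budget_per_system`` with per-stage
convergence checks and step budgets, but the eviction-then-refill rhythm is
the same.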
.. code-block:: Python

    import logging

    import torch

    from nvalchemi.data import AtomicData
    from nvalchemi.dynamics import FIRE, NVTLangevin, SizeAwareSampler
    from nvalchemi.dynamics.base import ConvergenceHook, FusedStage
    from nvalchemi.dynamics.hooks import ConvergedSnapshotHook
    from nvalchemi.dynamics.sinks import HostMemory
    from nvalchemi.models.demo import DemoModelWrapper

    logging.basicConfig(level=logging.INFO)

.. GENERATED FROM PYTHON SOURCE LINES 72-87

The dataset interface
---------------------

:class:`SizeAwareSampler` requires a dataset that implements exactly three
methods:

* ``__len__()`` — total number of samples.
* ``__getitem__(idx) -> (AtomicData, dict)`` — load one sample by index. The
  ``dict`` carries arbitrary per-sample metadata (empty is fine).
* ``get_metadata(idx) -> (num_atoms, num_edges)`` — return atom/edge counts
  **without** constructing the full ``AtomicData`` object.

The ``get_metadata`` method exists for efficiency: the sampler pre-scans the
entire dataset at construction time to build size-aware bins. If loading each
sample were necessary just for bin-packing, large datasets would be
prohibitively slow to initialise.

.. GENERATED FROM PYTHON SOURCE LINES 87-172

.. code-block:: Python

    class MixedSizeDataset:
        """A dataset of random molecular systems with varying atom counts.

        Systems have between ``min_atoms`` and ``max_atoms`` atoms, assigned
        by cycling through a range. This produces a realistic distribution of
        sizes that exercises the bin-packing logic.

        Parameters
        ----------
        n_samples : int
            Total number of systems.
        min_atoms : int
            Minimum atom count.
        max_atoms : int
            Maximum atom count.
        seed : int
            Base RNG seed.
        """

        def __init__(
            self,
            n_samples: int,
            min_atoms: int = 4,
            max_atoms: int = 6,
            seed: int = 0,
        ) -> None:
            self.n_samples = n_samples
            self.min_atoms = min_atoms
            self.max_atoms = max_atoms
            self.base_seed = seed
            # Pre-assign atom counts to each sample index.
            span = max_atoms - min_atoms + 1
            self._atom_counts = [min_atoms + (i % span) for i in range(n_samples)]

        def __len__(self) -> int:
            return self.n_samples

        def get_metadata(self, idx: int) -> tuple[int, int]:
            """Return ``(num_atoms, num_edges)`` without loading the sample.

            The sampler calls this for **every** sample at construction time.
            It must be cheap — no I/O, no tensor allocation.

            Parameters
            ----------
            idx : int
                Sample index.

            Returns
            -------
            tuple[int, int]
                ``(num_atoms, num_edges)``; ``num_edges`` is 0 for models
                without edge lists.
            """
            return self._atom_counts[idx], 0

        def __getitem__(self, idx: int) -> tuple[AtomicData, dict]:
            """Load one sample.

            Parameters
            ----------
            idx : int
                Sample index.

            Returns
            -------
            tuple[AtomicData, dict]
                The ``AtomicData`` and an (empty) metadata dict.
            """
            n = self._atom_counts[idx]
            g = torch.Generator()
            g.manual_seed(self.base_seed + idx)
            data = AtomicData(
                positions=torch.randn(n, 3, generator=g),
                atomic_numbers=torch.randint(1, 10, (n,), dtype=torch.long, generator=g),
                atomic_masses=torch.ones(n),
                forces=torch.zeros(n, 3),
                energies=torch.zeros(1, 1),
            )
            data.add_node_property("velocities", torch.zeros(n, 3))
            return data, {}

.. GENERATED FROM PYTHON SOURCE LINES 173-184

SizeAwareSampler
----------------

The sampler consumes the dataset lazily. It builds atom-count bins at
construction (calling ``get_metadata`` on every sample) and then serves
replacements via ``request_replacement(num_atoms, num_edges)``, which finds an
unconsumed sample small enough to fit in the vacated slot.

``max_atoms=24`` allows at most 4–6 systems of 4–6 atoms each in the live
batch. ``max_batch_size=6`` caps the graph count independently.
``max_edges=None`` disables the edge constraint (``DemoModelWrapper`` does not
use a neighbor list).

.. GENERATED FROM PYTHON SOURCE LINES 184-199
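Conceptually, the replacement lookup behaves like a best-fit search over those
bins. The sketch below is a hypothetical simplification (the real sampler also
enforces ``max_edges`` and ``max_batch_size``, which are ignored here):

```python
# Hypothetical sketch of best-fit replacement selection over atom-count bins.
# `bins` maps atom count -> list of unconsumed dataset indices of that size.

def request_replacement(bins, freed_atoms):
    """Return (num_atoms, dataset_idx) of the largest sample that fits, or None."""
    sizes = [n for n, idxs in bins.items() if idxs and n <= freed_atoms]
    if not sizes:
        return None            # nothing small enough is left unconsumed
    best = max(sizes)          # best fit: largest size within the freed envelope
    return best, bins[best].pop()
```

Choosing the largest sample that fits keeps the live batch close to the atom
budget, which is what makes the GPU utilisation argument in the introduction
work.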
.. code-block:: Python

    dataset = MixedSizeDataset(n_samples=30, min_atoms=4, max_atoms=6, seed=200)

    sampler = SizeAwareSampler(
        dataset,
        max_atoms=24,
        max_edges=None,
        max_batch_size=6,
    )

    logging.info(
        "Sampler created: %d samples, bins: %s",
        len(sampler),
        sorted(sampler._bins.keys()),
    )

.. GENERATED FROM PYTHON SOURCE LINES 200-214

Building the pipeline
---------------------

Two sub-stages:

* Stage 0 — FIRE geometry relaxation (convergence at fmax < 0.5 eV/Å;
  deliberately loose so most systems converge quickly in this demo).
* Stage 1 — NVT equilibration for 20 steps.

:class:`~nvalchemi.dynamics.hooks.ConvergedSnapshotHook` fires at
``ON_CONVERGE``. When a system exits stage 1 (its step budget is exhausted and
it graduates), the hook writes only that system's data to the
:class:`~nvalchemi.dynamics.sinks.HostMemory` sink. This is the recommended
pattern for collecting results in inflight runs without writing intermediate
states.

.. GENERATED FROM PYTHON SOURCE LINES 214-244

.. code-block:: Python

    torch.manual_seed(42)

    model = DemoModelWrapper()
    model.eval()

    results_sink = HostMemory(capacity=30)
    converged_hook = ConvergedSnapshotHook(sink=results_sink)

    fire_stage = FIRE(
        model=model,
        dt=0.1,
        convergence_hook=ConvergenceHook.from_forces(threshold=0.5),
    )

    nvt_stage = NVTLangevin(
        model=model,
        dt=0.5,
        temperature=300.0,
        friction=0.1,
        random_seed=11,
        n_steps=20,
        hooks=[converged_hook],
    )

    fused = FusedStage(
        sub_stages=[(0, fire_stage), (1, nvt_stage)],
        sampler=sampler,
        sinks=[results_sink],
        refill_frequency=5,
    )

.. GENERATED FROM PYTHON SOURCE LINES 245-254

Running in inflight mode (batch=None)
-------------------------------------

Passing ``batch=None`` tells :class:`FusedStage` to call
``sampler.build_initial_batch()`` to create the first live batch, then run
until the sampler is exhausted **and** all remaining systems have graduated
through all stages.

``n_steps=500`` is the maximum total step budget.
If the dataset is fully processed before this limit, the run terminates early
and returns ``None``.

.. GENERATED FROM PYTHON SOURCE LINES 254-265

.. code-block:: Python

    logging.info("Starting inflight batching run...")

    result = fused.run(batch=None, n_steps=500)

    if result is None:
        logging.info("All %d systems processed — sampler exhausted.", dataset.n_samples)
    else:
        logging.info(
            "Step budget exhausted with %d systems still active.", result.num_graphs
        )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /home/kinlongkelvi/Repos/nvalchemi-toolkit/.venv/lib/python3.13/site-packages/warp/_src/torch.py:280: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more information. (Triggered internally at /pytorch/build/aten/src/ATen/core/TensorBody.h:492.)
      if t.grad is not None:

.. GENERATED FROM PYTHON SOURCE LINES 266-271

Inspecting results
------------------

The :class:`~nvalchemi.dynamics.sinks.HostMemory` sink accumulates graduated
systems on CPU. Drain it to retrieve a single ``Batch`` containing all
collected structures.

.. GENERATED FROM PYTHON SOURCE LINES 271-286

.. code-block:: Python

    n_collected = len(results_sink)
    logging.info("Results sink contains %d systems.", n_collected)

    if n_collected > 0:
        results_batch = results_sink.drain()
        system_ids = results_batch.system_id.squeeze(-1).tolist()
        logging.info("Collected system_ids (first 10): %s", system_ids[:10])
        logging.info(
            "Results batch: num_graphs=%d, num_nodes=%d",
            results_batch.num_graphs,
            results_batch.num_nodes,
        )
    else:
        logging.info("No results collected (step budget may have been too small).")
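Because ``system_id`` values are stamped monotonically from 0, the drained id
list makes it easy to verify that every dataset sample graduated exactly once.
A small helper along these lines could run on the plain Python list above
(``check_coverage`` is a hypothetical name, not part of the toolkit):

```python
# Toy coverage check on collected system_ids: every id in 0..n_samples-1
# should appear exactly once if the whole dataset graduated.

from collections import Counter

def check_coverage(system_ids, n_samples):
    """Return (missing ids, duplicated ids) for a run over n_samples systems."""
    counts = Counter(system_ids)
    missing = [i for i in range(n_samples) if counts[i] == 0]
    duplicated = sorted(i for i, c in counts.items() if c > 1)
    return missing, duplicated
```

An empty pair of lists confirms the sampler was fully consumed and no system
was snapshotted twice.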
.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 0.167 seconds)


.. _sphx_glr_download_examples_intermediate_04_inflight_batching.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: 04_inflight_batching.ipynb <04_inflight_batching.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: 04_inflight_batching.py <04_inflight_batching.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: 04_inflight_batching.zip <04_inflight_batching.zip>`

.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_