Processing Large Datasets with Inflight Batching#
The problem: a production relaxation campaign may require running thousands of structures through FIRE → NVT stages. Loading all of them into GPU memory at once is impossible; processing one at a time wastes GPU throughput.
The solution: inflight batching (Mode 2 in FusedStage).
A fixed-size live batch occupies the GPU at all times. Whenever a system
graduates (converges or exhausts its step budget), it is evicted and a new
sample from the dataset takes its slot. The
SizeAwareSampler handles bin-packing so the
replacement fits within the memory envelope of the slot it replaces.
Lifecycle of a system:
Dataset → SizeAwareSampler.request_replacement()
│
▼
Live batch (GPU) — stage 0 (FIRE relaxation)
│ converges
▼
Live batch (GPU) — stage 1 (NVT equilibration)
│ n_steps exhausted
▼
ConvergedSnapshotHook → HostMemory sink
│
▼
Slot freed → next sample loaded from dataset
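The lifecycle above can be sketched in pure Python. This is a toy model with hypothetical names (run_inflight, steps_to_graduate), not the nvalchemi API: systems "graduate" after a fixed number of fake steps, and freed slots are refilled greedily in dataset order while the atom budget allows (the real sampler bin-packs instead).

```python
from collections import deque

def run_inflight(samples, max_atoms, steps_to_graduate):
    """Toy inflight-batching loop. `samples` is a list of per-system
    atom counts; returns the order in which systems graduate."""
    pending = deque(range(len(samples)))   # dataset indices not yet loaded
    live = {}                              # dataset index -> remaining steps
    graduated = []                         # graduation order

    def refill():
        # Greedy, in dataset order: load the next sample while it still
        # fits under the max_atoms envelope.
        while pending and sum(samples[i] for i in live) + samples[pending[0]] <= max_atoms:
            live[pending.popleft()] = steps_to_graduate

    refill()
    while live:
        for idx in list(live):
            live[idx] -= 1                 # advance every live system one step
            if live[idx] == 0:             # graduated: evict and free the slot
                del live[idx]
                graduated.append(idx)
        refill()                           # replacements take the freed slots
    return graduated
```

With six systems of 4–6 atoms and a 10-atom budget, at most two systems are live at once, yet all six are eventually processed — the essence of Mode 2.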
Key concept: system_id. Each sample loaded from the dataset receives a
monotonically-increasing integer system_id stamped by the sampler. This
lets downstream code track individual trajectories across refill events.
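The stamping behaviour can be illustrated with a minimal sketch (StampingLoader is hypothetical, not the nvalchemi sampler): the system_id reflects load order, not dataset index, so re-used slots never collide.

```python
from itertools import count

class StampingLoader:
    """Hypothetical sketch: every sample served gets the next
    monotonically increasing system_id, regardless of its dataset
    index, so trajectories stay distinguishable across refills."""
    def __init__(self, dataset):
        self._dataset = dataset
        self._next_id = count()            # yields 0, 1, 2, ...

    def load(self, idx):
        data = self._dataset[idx]
        return {"system_id": next(self._next_id), "data": data}
```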
This example uses DemoModelWrapper (a small
neural network) so no neighbor list is needed. For a real LJ or MACE model,
add NeighborListHook to each sub-stage.
import logging
import torch
from nvalchemi.data import AtomicData
from nvalchemi.dynamics import FIRE, NVTLangevin, SizeAwareSampler
from nvalchemi.dynamics.base import ConvergenceHook, FusedStage
from nvalchemi.dynamics.hooks import ConvergedSnapshotHook
from nvalchemi.dynamics.sinks import HostMemory
from nvalchemi.models.demo import DemoModelWrapper
logging.basicConfig(level=logging.INFO)
The dataset interface#
SizeAwareSampler requires a dataset that implements exactly three
methods:
__len__() — total number of samples.
__getitem__(idx) -> (AtomicData, dict) — load one sample by index. The dict carries arbitrary per-sample metadata (an empty dict is fine).
get_metadata(idx) -> (num_atoms, num_edges) — return atom/edge counts without constructing the full AtomicData object.
The get_metadata method exists for efficiency: the sampler pre-scans the
entire dataset at construction time to build size-aware bins. If loading
each sample were necessary just for bin-packing, large datasets would be
prohibitively slow to initialise.
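A minimal sketch of that pre-scan, assuming only the three-method contract above (TinyMeta and build_size_bins are hypothetical stand-ins, not nvalchemi classes):

```python
from collections import defaultdict

class TinyMeta:
    """Metadata-only stand-in for a dataset (hypothetical)."""
    def __init__(self, atom_counts):
        self._counts = atom_counts

    def __len__(self):
        return len(self._counts)

    def get_metadata(self, idx):
        return self._counts[idx], 0        # (num_atoms, num_edges)

def build_size_bins(dataset):
    """Group sample indices by atom count using get_metadata alone --
    no __getitem__ call, so no tensors are ever constructed."""
    bins = defaultdict(list)
    for idx in range(len(dataset)):
        num_atoms, _num_edges = dataset.get_metadata(idx)
        bins[num_atoms].append(idx)
    return dict(bins)
```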
class MixedSizeDataset:
"""A dataset of random molecular systems with varying atom counts.
Systems have between ``min_atoms`` and ``max_atoms`` atoms, assigned
by cycling through a range. This produces a realistic distribution of
sizes that exercises the bin-packing logic.
Parameters
----------
n_samples : int
Total number of systems.
min_atoms : int
Minimum atom count.
max_atoms : int
Maximum atom count.
seed : int
Base RNG seed.
"""
def __init__(
self,
n_samples: int,
min_atoms: int = 4,
max_atoms: int = 6,
seed: int = 0,
) -> None:
self.n_samples = n_samples
self.min_atoms = min_atoms
self.max_atoms = max_atoms
self.base_seed = seed
# Pre-assign atom counts to each sample index.
span = max_atoms - min_atoms + 1
self._atom_counts = [min_atoms + (i % span) for i in range(n_samples)]
def __len__(self) -> int:
return self.n_samples
def get_metadata(self, idx: int) -> tuple[int, int]:
"""Return ``(num_atoms, num_edges)`` without loading the sample.
The sampler calls this for **every** sample at construction time.
It must be cheap — no I/O, no tensor allocation.
Parameters
----------
idx : int
Sample index.
Returns
-------
tuple[int, int]
``(num_atoms, num_edges)``; num_edges=0 for models without edge
lists.
"""
return self._atom_counts[idx], 0
def __getitem__(self, idx: int) -> tuple[AtomicData, dict]:
"""Load one sample.
Parameters
----------
idx : int
Sample index.
Returns
-------
tuple[AtomicData, dict]
The AtomicData and an (empty) metadata dict.
"""
n = self._atom_counts[idx]
g = torch.Generator()
g.manual_seed(self.base_seed + idx)
data = AtomicData(
positions=torch.randn(n, 3, generator=g),
atomic_numbers=torch.randint(1, 10, (n,), dtype=torch.long, generator=g),
atomic_masses=torch.ones(n),
forces=torch.zeros(n, 3),
energies=torch.zeros(1, 1),
)
data.add_node_property("velocities", torch.zeros(n, 3))
return data, {}
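The cycling assignment in __init__ yields a flat, deterministic size distribution. For example, with min_atoms=4 and max_atoms=6:

```python
min_atoms, max_atoms, n_samples = 4, 6, 8
span = max_atoms - min_atoms + 1           # three distinct sizes: 4, 5, 6
atom_counts = [min_atoms + (i % span) for i in range(n_samples)]
print(atom_counts)  # [4, 5, 6, 4, 5, 6, 4, 5]
```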
SizeAwareSampler#
The sampler consumes the dataset lazily. It builds atom-count bins at
construction (calling get_metadata on every sample) and then serves
replacements via request_replacement(num_atoms, num_edges), which
finds an unconsumed sample small enough to fit in the vacated slot.
max_atoms=24 fits between four and six systems of 4–6 atoms each in the
live batch (24/6 = 4, 24/4 = 6). max_batch_size=6 caps the graph count independently.
max_edges=None disables the edge constraint (DemoModelWrapper does
not use a neighbor list).
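One plausible fit policy, operating on bins like those the pre-scan builds, is to pop the largest unconsumed sample that still fits the vacated slot. This is a hypothetical sketch; the real sampler's heuristics may differ.

```python
def request_replacement(bins, freed_atoms):
    """Return a dataset index whose atom count fits `freed_atoms`, or
    None if no unconsumed sample fits. Preferring the largest that
    fits keeps the live batch close to the memory envelope."""
    # bins maps atom count -> list of unconsumed dataset indices.
    candidates = [n for n in bins if bins[n] and n <= freed_atoms]
    if not candidates:
        return None
    return bins[max(candidates)].pop(0)    # consume the sample
```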
dataset = MixedSizeDataset(n_samples=30, min_atoms=4, max_atoms=6, seed=200)
sampler = SizeAwareSampler(
dataset,
max_atoms=24,
max_edges=None,
max_batch_size=6,
)
logging.info(
"Sampler created: %d samples, bins: %s",
len(sampler),
sorted(sampler._bins.keys()),
)
Building the pipeline#
Two sub-stages:
Stage 0 — FIRE geometry relaxation (convergence at fmax < 0.5 eV/Å; deliberately loose so most systems converge quickly in this demo).
Stage 1 — NVT equilibration for 20 steps.
ConvergedSnapshotHook fires at
ON_CONVERGE. When a system exits stage 1 (its step budget is
exhausted and it graduates), the hook writes only that system’s data to
the HostMemory sink. This is the
recommended pattern for collecting results in inflight runs without
writing intermediate states.
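The hook-to-sink pattern can be sketched with simplified stand-ins (ListSink and SnapshotOnConverge are hypothetical interfaces, not the nvalchemi classes imported above):

```python
class ListSink:
    """Stand-in for a HostMemory-style sink: accumulate, then drain."""
    def __init__(self):
        self._items = []

    def write(self, item):
        self._items.append(item)

    def __len__(self):
        return len(self._items)

    def drain(self):
        # Hand back everything collected so far and reset the sink.
        items, self._items = self._items, []
        return items

class SnapshotOnConverge:
    """Fires once per graduating system and writes only that system's
    data to the sink -- no intermediate states are ever stored."""
    def __init__(self, sink):
        self.sink = sink

    def on_converge(self, system):
        self.sink.write(system)
```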
torch.manual_seed(42)
model = DemoModelWrapper()
model.eval()
results_sink = HostMemory(capacity=30)
converged_hook = ConvergedSnapshotHook(sink=results_sink)
fire_stage = FIRE(
model=model,
dt=0.1,
convergence_hook=ConvergenceHook.from_forces(threshold=0.5),
)
nvt_stage = NVTLangevin(
model=model,
dt=0.5,
temperature=300.0,
friction=0.1,
random_seed=11,
n_steps=20,
hooks=[converged_hook],
)
fused = FusedStage(
sub_stages=[(0, fire_stage), (1, nvt_stage)],
sampler=sampler,
sinks=[results_sink],
refill_frequency=5,
)
Running in inflight mode (batch=None)#
Passing batch=None tells FusedStage to call
sampler.build_initial_batch() to create the first live batch, then
run until the sampler is exhausted and all remaining systems have
graduated through all stages.
n_steps=500 is the maximum total step budget. If the dataset is fully
processed before this limit, the run terminates early and returns None.
logging.info("Starting inflight batching run...")
result = fused.run(batch=None, n_steps=500)
if result is None:
logging.info("All %d systems processed — sampler exhausted.", dataset.n_samples)
else:
logging.info(
"Step budget exhausted with %d systems still active.", result.num_graphs
)
Inspecting results#
The HostMemory sink accumulates
graduated systems on CPU. Drain it to retrieve a single Batch
containing all collected structures.
n_collected = len(results_sink)
logging.info("Results sink contains %d systems.", n_collected)
if n_collected > 0:
results_batch = results_sink.drain()
system_ids = results_batch.system_id.squeeze(-1).tolist()
logging.info("Collected system_ids (first 10): %s", system_ids[:10])
logging.info(
"Results batch: num_graphs=%d, num_nodes=%d",
results_batch.num_graphs,
results_batch.num_nodes,
)
else:
logging.info("No results collected (step budget may have been too small).")
Total running time of the script: (0 minutes 0.167 seconds)