Note
Go to the end to download the full example code.
Distributed Multi-GPU Pipeline: Parallel FIRE → Langevin#
This example orchestrates two independent FIRE → NVTLangevin pipelines
running in parallel across 4 GPUs using
DistributedPipeline.
Topology
Rank 0 (FIRE, sampler_a) ─→ Rank 1 (NVTLangevin, sink_a)
Rank 2 (FIRE, sampler_b) ─→ Rank 3 (NVTLangevin, sink_b)
Each FIRE rank draws molecules from a dataset, optimises them until
convergence, and sends them to the paired Langevin rank for short MD
production. A ConvergedSnapshotHook
on the Langevin ranks writes completed trajectories to a
HostMemory sink.
Note
This example requires 4 GPUs. Run with:
torchrun --nproc_per_node=4 examples/distributed/01_distributed_pipeline.py
For CPU-only testing, change backend="nccl" to backend="gloo".
from __future__ import annotations
import logging
import os
import torch
import torch.distributed as dist
from ase.build import molecule
from loguru import logger
from nvalchemi.data import AtomicData, Batch
from nvalchemi.dynamics import (
FIRE,
ConvergenceHook,
DistributedPipeline,
HostMemory,
NVTLangevin,
SizeAwareSampler,
)
from nvalchemi.dynamics.base import BufferConfig, HookStageEnum
from nvalchemi.dynamics.hooks import ConvergedSnapshotHook
from nvalchemi.models.demo import DemoModelWrapper
logging.basicConfig(level=logging.INFO)
# When run outside ``torchrun`` (e.g. during a Sphinx docs build), the
# distributed environment variables ``RANK`` and ``WORLD_SIZE`` are absent.
# We detect this and skip the pipeline launch so the example renders in
# the gallery without requiring multiple GPUs.
_DISTRIBUTED_ENV = "RANK" in os.environ and "WORLD_SIZE" in os.environ
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def atoms_to_data(atoms) -> AtomicData:
"""Convert an ASE Atoms object to AtomicData with dynamics fields."""
data = AtomicData.from_atoms(atoms)
n = data.num_nodes
data.forces = torch.zeros(n, 3)
data.energies = torch.zeros(1, 1)
data.add_node_property("velocities", torch.zeros(n, 3))
return data
class InMemoryDataset:
"""Minimal dataset wrapper for ``SizeAwareSampler``."""
def __init__(self, data_list: list[AtomicData]) -> None:
self._data = data_list
def __len__(self) -> int:
return len(self._data)
def __getitem__(self, idx: int) -> tuple[AtomicData, dict]:
d = self._data[idx]
return d, {"num_atoms": d.num_nodes, "num_edges": d.num_edges}
def get_metadata(self, idx: int) -> tuple[int, int]:
"""Get the metadata associated with idx"""
d = self._data[idx]
return d.num_nodes, d.num_edges
class DownstreamDoneHook:
"""Set ``stage.done = True`` after *patience* consecutive idle steps.
Downstream (non-inflight) stages never set ``done`` on their own.
This hook counts consecutive steps where the batch is empty (no
graphs to integrate) and marks the stage as finished once the
patience limit is reached.
"""
stage = HookStageEnum.AFTER_STEP
frequency = 1
def __init__(self, patience: int = 5) -> None:
self.patience = patience
self._idle_steps = 0
def __call__(self, batch: Batch, dynamics) -> None:
if batch.num_graphs == 0:
self._idle_steps += 1
else:
self._idle_steps = 0
if self._idle_steps >= self.patience:
dynamics.done = True
# ---------------------------------------------------------------------------
# Build molecules
# ---------------------------------------------------------------------------
def build_dataset() -> list[AtomicData]:
"""Create a handful of small rattled molecules."""
names = ["H2O", "CH4", "NH3", "H2O", "CH4", "NH3", "H2O", "CH4"]
data_list = []
for name in names:
atoms = molecule(name)
atoms.rattle(stdev=0.15)
data_list.append(atoms_to_data(atoms))
return data_list
DistributedPipeline topology#
DistributedPipeline maps integer GPU ranks
to dynamics stage instances. When pipeline.run() is called, each
process (launched by torchrun) looks up its own rank and executes
only the corresponding stage. Inter-rank communication is handled
transparently via NCCL isend/irecv calls.
Here we create two independent sub-pipelines: ranks 0→1 and 2→3.
Each sub-pipeline is a FIRE optimiser feeding into a Langevin MD stage.
The stages dict is built on every rank but only the local stage is
ever executed; constructing all stages on every rank keeps the code
identical across processes, which simplifies debugging.
Dataset and samplers#
SizeAwareSampler draws molecules from a
dataset and packs them into variable-size batches that fit within atom
and edge count budgets. In a distributed setting, each upstream rank
owns its own sampler and dataset partition so that work is distributed
evenly. Downstream ranks (the Langevin stages) do not need a sampler —
they receive systems directly from the paired FIRE rank via NCCL.
BufferConfig: fixed-size communication#
NCCL requires that every isend/irecv pair transfers an identical
number of bytes. BufferConfig specifies
the fixed sizes (num_systems, num_nodes, num_edges) used to
pre-allocate communication buffers on both sender and receiver. Choose
values that are at least as large as the largest batch you expect to send
in a single step; excess capacity is padded with zeros and stripped on
receipt.
Stage construction#
Each stage is wired to its neighbours via prior_rank and next_rank.
An upstream stage has prior_rank=None and next_rank=<downstream>;
a downstream stage has prior_rank=<upstream> and next_rank=None.
The buffer_config must be identical for all stages that communicate
with each other.
def make_fire(model: DemoModelWrapper, rank: int, **kwargs) -> FIRE:
"""Create a FIRE optimiser stage."""
return FIRE(
model=model,
dt=1.0,
n_steps=50,
convergence_hook=ConvergenceHook(
criteria=[
{
"key": "forces",
"threshold": 0.05,
"reduce_op": "norm",
"reduce_dims": -1,
}
],
),
**kwargs,
)
def make_langevin(
model: DemoModelWrapper,
sink: HostMemory,
rank: int,
**kwargs,
) -> NVTLangevin:
"""Create an NVTLangevin MD stage with a snapshot hook."""
return NVTLangevin(
model=model,
dt=0.5,
temperature=300.0,
friction=0.01,
n_steps=20,
hooks=[
ConvergedSnapshotHook(sink=sink, frequency=1),
DownstreamDoneHook(patience=10),
],
convergence_hook=ConvergenceHook(
criteria=[
{
"key": "forces",
"threshold": 0.01,
"reduce_op": "norm",
"reduce_dims": -1,
}
],
),
**kwargs,
)
Running the pipeline#
DistributedPipeline is used as a context
manager. On __enter__ it initialises the PyTorch process group
(dist.init_process_group) and assigns each rank its GPU device.
On __exit__ it tears down the process group gracefully.
pipeline.run() blocks until the local stage signals completion
(dynamics.done = True). Upstream ranks finish when their sampler
is exhausted; downstream ranks finish via DownstreamDoneHook after
a configurable number of idle steps with no incoming systems.
def main() -> None:
"""Launch two parallel FIRE -> Langevin pipelines on 4 GPUs."""
model = DemoModelWrapper()
# Sinks (only used by ranks 1 and 3, but created on all for simplicity)
sink_a = HostMemory(capacity=100)
sink_b = HostMemory(capacity=100)
# Dataset (only used by ranks 0 and 2)
all_data = build_dataset()
mid = len(all_data) // 2
dataset_a = InMemoryDataset(all_data[:mid])
dataset_b = InMemoryDataset(all_data[mid:])
sampler_a = SizeAwareSampler(
dataset=dataset_a,
max_atoms=50,
max_edges=0,
max_batch_size=4,
)
sampler_b = SizeAwareSampler(
dataset=dataset_b,
max_atoms=50,
max_edges=0,
max_batch_size=4,
)
# Buffer config — matches sampler capacities and ensures fixed-size comm
# buffers for NCCL compatibility (identical message count every step).
buffer_cfg = BufferConfig(num_systems=4, num_nodes=50, num_edges=0)
# Stages — one per rank.
# By default prior_rank / next_rank are -1 (unset) and
# DistributedPipeline.setup() would auto-wire a linear chain
# 0 -> 1 -> 2 -> 3. Setting them explicitly here creates two
# independent sub-pipelines: 0 -> 1 and 2 -> 3.
stages = {
0: make_fire(
model,
rank=0,
sampler=sampler_a,
refill_frequency=1,
prior_rank=None,
next_rank=1,
buffer_config=buffer_cfg,
),
1: make_langevin(
model,
sink=sink_a,
rank=1,
prior_rank=0,
next_rank=None,
buffer_config=buffer_cfg,
),
2: make_fire(
model,
rank=2,
sampler=sampler_b,
refill_frequency=1,
prior_rank=None,
next_rank=3,
buffer_config=buffer_cfg,
),
3: make_langevin(
model,
sink=sink_b,
rank=3,
prior_rank=2,
next_rank=None,
buffer_config=buffer_cfg,
),
}
if not _DISTRIBUTED_ENV:
logger.info(
"Not running under torchrun — skipping pipeline launch. "
"Run with: torchrun --nproc_per_node=4 "
"examples/distributed/01_distributed_pipeline.py",
)
return
backend = "nccl" # use gloo when testing with CPUs
# debug mode will provide insight into what rank is doing what
pipeline = DistributedPipeline(stages=stages, backend=backend, debug_mode=True)
with pipeline:
pipeline.run()
rank = dist.get_rank() if dist.is_initialized() else 0
if rank == 1:
logger.info(f"Rank 1 sink collected {len(sink_a)} samples")
elif rank == 3:
logger.info(f"Rank 3 sink collected {len(sink_b)} samples")
if __name__ == "__main__":
main()
Total running time of the script: (0 minutes 0.012 seconds)