# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base classes and protocols for molecular dynamics simulations.
This module provides the foundational abstractions for running dynamics
simulations, including hook protocols for extensibility, the base
dynamics class that coordinates model evaluation with integrator updates,
and the ``FusedStage`` class for fusing multiple dynamics stages on a
single GPU with shared batch and forward pass.
Inheritance structure::
object
└── _CommunicationMixin # inter-rank communication base
└── BaseDynamics(_CommunicationMixin)
└── FusedStage(BaseDynamics)
``BaseDynamics`` inherits from ``_CommunicationMixin``, so all dynamics
subclasses automatically have communication capabilities for pipeline
execution without needing explicit multiple inheritance.
"""
from __future__ import annotations
import sys
from collections import defaultdict
from collections.abc import Callable, Sequence
from enum import Enum
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Literal,
Protocol,
TypeAlias,
runtime_checkable,
)
import torch
from jaxtyping import Bool
from loguru import logger
from pydantic import BaseModel, ConfigDict, Field
from torch import distributed as dist
from nvalchemi._typing import AtomsLike, ModelOutputs
from nvalchemi.data import AtomicData, Batch
from nvalchemi.models.base import BaseModelMixin
if TYPE_CHECKING:
from nvalchemi.dynamics.sampler import SizeAwareSampler
from nvalchemi.dynamics.sinks import DataSink
# Public names exported by this module.
# NOTE(review): ``ConvergenceHook`` and ``DistributedPipeline`` are listed
# here but not defined in this portion of the file — presumably defined
# further down in the module; confirm.
__all__ = [
    "Hook",
    "HookStageEnum",
    "ConvergenceHook",
    "DistributedPipeline",
    "BufferConfig",
]
class BufferConfig(BaseModel):
    """Capacities used to pre-allocate pipeline communication buffers.

    :class:`_CommunicationMixin` requires an instance of this model
    whenever a stage participates in inter-rank communication (i.e.
    ``prior_rank`` or ``next_rank`` is set). The buffers themselves are
    created lazily through ``Batch.empty()`` on the first simulation
    step, once a concrete batch is available to use as a template.

    Attributes
    ----------
    num_systems : int
        Maximum number of graphs the buffer can hold.
    num_nodes : int
        Total node (atom) capacity across all graphs.
    num_edges : int
        Total edge capacity across all graphs.
    """

    num_systems: int = Field(
        ge=0, description="Maximum number of graphs the buffer can hold."
    )
    num_nodes: int = Field(
        ge=0, description="Total node (atom) capacity across all graphs."
    )
    num_edges: int = Field(
        ge=0, description="Total edge capacity across all graphs."
    )
class HookStageEnum(Enum):
    """
    Enumeration of stages in the dynamics step where hooks can be executed.

    Each stage corresponds to a specific point in the simulation step,
    allowing hooks to be triggered before or after key operations.

    Attributes
    ----------
    BEFORE_STEP : int
        Fired at the very beginning of a step, before any operations.
    BEFORE_PRE_UPDATE : int
        Fired before the pre_update (first half of integrator) is called.
    AFTER_PRE_UPDATE : int
        Fired after the pre_update completes.
    BEFORE_COMPUTE : int
        Fired before the model forward pass (force/energy computation).
    AFTER_COMPUTE : int
        Fired after the model forward pass completes.
    BEFORE_POST_UPDATE : int
        Fired before the post_update (second half of integrator) is called.
    AFTER_POST_UPDATE : int
        Fired after the post_update completes.
    AFTER_STEP : int
        Fired at the very end of a step, after all operations.
    ON_CONVERGE : int
        Fired when a convergence criterion is met (e.g., for optimizers).
    """

    # Values follow the chronological order of a single dynamics step.
    BEFORE_STEP = 0
    BEFORE_PRE_UPDATE = 1
    AFTER_PRE_UPDATE = 2
    BEFORE_COMPUTE = 3
    AFTER_COMPUTE = 4
    BEFORE_POST_UPDATE = 5
    AFTER_POST_UPDATE = 6
    AFTER_STEP = 7
    ON_CONVERGE = 8
@runtime_checkable
class Hook(Protocol):
    """
    Protocol defining the interface for dynamics hooks.

    Hooks are callable objects that can be registered with a dynamics
    engine to perform custom operations at specific stages of the
    simulation. They are executed in-place and can modify the batch.

    Users are expected to be able to develop their own hooks either
    by subclassing the `Hook` protocol class, or simply by ensuring
    that the class they intend to use as a hook provides the expected
    signature.

    Attributes
    ----------
    frequency : int
        Execute the hook every N steps. A frequency of 1 means every step,
        2 means every other step, etc.
    stage : HookStageEnum
        The stage at which this hook should be fired.
    """

    frequency: int
    stage: HookStageEnum

    def __call__(self, batch: Batch, dynamics: BaseDynamics) -> None:
        """
        Execute the hook operation.

        Parameters
        ----------
        batch : Batch
            The current batch of atomic data, modified in-place.
        dynamics : BaseDynamics
            The dynamics engine instance, providing access to model,
            step count, and other state.
        """
        ...
class _ConvergenceCriterion(BaseModel):
    """A single convergence criterion evaluated against a tensor key on ``Batch``.

    This is an internal model and should not be instantiated directly by
    users. Instead, pass ``dict`` mappings to :class:`ConvergenceHook`,
    which will validate and construct ``_ConvergenceCriterion`` instances
    automatically.

    The evaluation pipeline (see ``__call__``) is:

    1. Retrieve ``getattr(batch, key)``; raise ``KeyError`` if absent.
    2. If ``custom_op`` is provided, delegate entirely to it and return.
    3. If ``reduce_op`` is not ``None``, apply the requested reduction
       along ``reduce_dims`` within each node / graph entry.
    4. If the (possibly reduced) tensor is node-level (its first
       dimension matches ``batch.num_nodes``), scatter-reduce it to
       graph-level using ``batch.batch`` as the group index.
    5. Otherwise the tensor is assumed to be graph-level and is squeezed
       to 1-D ``(B,)`` if it has a trailing singleton dimension.
    6. Compare the resulting ``(B,)`` tensor against ``threshold``.

    Attributes
    ----------
    key : str
        Tensor key to measure convergence against (e.g. ``"fmax"``).
    threshold : float
        Convergence threshold; values ≤ this are considered converged.
    reduce_dims : int | list[int]
        Dimension(s) to reduce over when ``reduce_op`` is not ``None``.
        Defaults to ``-1``.
    reduce_op : {``"min"``, ``"max"``, ``"norm"``, ``"mean"``, ``"sum"``} or ``None``
        Reduction applied to the raw tensor before the graph-level
        aggregation. ``None`` (default) skips this step and expects
        the key to already be at graph-level or to be a node-level
        vector that will be scatter-reduced.
    custom_op : Callable[[torch.Tensor], Bool[Tensor, " B"]] | None
        Custom callable that receives the raw tensor and must return
        a boolean ``(B,)`` mask. When provided, ``reduce_op``,
        ``reduce_dims``, and ``threshold`` are ignored.
    """

    # torch.Tensor / Callable fields are not pydantic-native types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    key: Annotated[str, Field(description="Tensor key to measure convergence against.")]
    threshold: Annotated[
        float,
        Field(
            description="Threshold for convergence; values"
            " smaller than or equal to `threshold` are considered converged."
        ),
    ]
    reduce_dims: Annotated[
        int | list[int], Field(description="Dimension(s) to reduce over.")
    ] = -1
    reduce_op: Annotated[
        Literal["min", "max", "norm", "mean", "sum"] | None,
        Field(
            description="Operation used to reduce non-scalar criteria."
            " None skips reductions and expects the key to already be"
            " graph-level or a node-level vector suitable for"
            " scatter-reduce."
        ),
    ] = None
    custom_op: Annotated[
        Callable[[torch.Tensor], Bool[torch.Tensor, " B"]] | None,  # noqa: F722, F821
        Field(
            description="Custom operation that wraps the convergence"
            " logic and returns a bool tensor indicating which samples"
            " have converged."
        ),
    ] = None

    def __repr__(self) -> str:
        """Return a human-readable summary of this criterion."""
        if self.custom_op is not None:
            # A custom op supersedes threshold/reduce settings, so only
            # report the key and the op's name.
            op_name = getattr(self.custom_op, "__name__", repr(self.custom_op))
            return f"_ConvergenceCriterion(key={self.key!r}, custom_op={op_name})"
        parts = [f"key={self.key!r}", f"threshold={self.threshold}"]
        if self.reduce_op is not None:
            parts.append(f"reduce_op={self.reduce_op!r}")
            parts.append(f"reduce_dims={self.reduce_dims!r}")
        return f"_ConvergenceCriterion({', '.join(parts)})"

    def _reduce_within_entry(self, target: torch.Tensor) -> torch.Tensor:
        """Apply ``reduce_op`` along ``reduce_dims`` within each entry.

        Parameters
        ----------
        target : torch.Tensor
            The raw tensor retrieved from the batch.

        Returns
        -------
        torch.Tensor
            The reduced tensor. If ``reduce_op`` is ``None``, the
            input is returned unchanged.
        """
        if self.reduce_op is None:
            return target
        match self.reduce_op:
            case "min":
                return torch.amin(target, self.reduce_dims)
            case "max":
                return torch.amax(target, self.reduce_dims)
            case "norm":
                return torch.linalg.vector_norm(target, dim=self.reduce_dims)
            case "mean":
                return torch.mean(target, dim=self.reduce_dims)
            case "sum":
                return torch.sum(target, dim=self.reduce_dims)
        # Unreachable because of the Literal type, but satisfies the
        # type checker and gives a clear error for bad runtime values.
        raise ValueError(f"Unknown reduce_op: {self.reduce_op!r}")

    @staticmethod
    def _scatter_reduce_to_graph(
        values: torch.Tensor,
        batch_idx: torch.Tensor,
        num_graphs: int,
    ) -> torch.Tensor:
        """Scatter-reduce a 1-D node-level tensor to graph-level via max.

        Parameters
        ----------
        values : torch.Tensor
            1-D tensor of shape ``(V,)`` with per-node values.
        batch_idx : torch.Tensor
            Integer tensor mapping each node to its graph index.
        num_graphs : int
            Number of graphs in the batch.

        Returns
        -------
        torch.Tensor
            1-D tensor of shape ``(B,)`` with per-graph reduced values
            (using ``max`` as the scatter operation).

        Notes
        -----
        The output is initialized to ``-inf`` and ``include_self=False``
        is used, so only scattered node values contribute.
        NOTE(review): a graph with zero nodes would retain ``-inf`` and
        therefore compare as converged against any threshold — confirm
        empty graphs cannot occur here.
        """
        out = torch.full(
            (num_graphs,), float("-inf"), dtype=values.dtype, device=values.device
        )
        out.scatter_reduce_(0, batch_idx, values, reduce="amax", include_self=False)
        return out

    def __call__(self, batch: Batch) -> Bool[torch.Tensor, " B"]:  # noqa: F722, F821
        """Evaluate this criterion against a batch.

        Parameters
        ----------
        batch : Batch
            The current batch of atomic data.

        Returns
        -------
        Bool[Tensor, " B"]
            Per-sample boolean mask where ``True`` indicates that the
            sample satisfies this convergence criterion.

        Raises
        ------
        KeyError
            If ``self.key`` is not present on ``batch``.
        """
        target: torch.Tensor | None = getattr(batch, self.key, None)
        if target is None:
            # Enumerate the populated keys so the error is actionable.
            available = list(batch.model_dump(exclude_none=True))
            raise KeyError(
                "Key for convergence check not found;"
                f" expected={self.key!r}, available={available}"
            )
        if self.custom_op is not None:
            # Custom op owns the entire evaluation; skip built-in logic.
            return self.custom_op(target)
        # Per-entry reduction happens BEFORE the node/graph-level routing.
        target = self._reduce_within_entry(target)
        # Heuristic: a leading dim equal to num_nodes marks a node-level
        # tensor, unless num_nodes == num_graphs (ambiguous, treated as
        # graph-level).
        is_node_level = (
            target.shape[0] == batch.num_nodes and batch.num_nodes != batch.num_graphs
        )
        if is_node_level:
            if target.dim() > 1:
                # Collapse any remaining trailing dims with max so the
                # scatter below receives a 1-D per-node tensor.
                target = target.view(target.shape[0], -1).amax(dim=-1)
            reduced = self._scatter_reduce_to_graph(
                target, batch.batch, batch.num_graphs
            )
        else:
            # Graph-level: drop a trailing singleton dim, e.g. (B, 1) -> (B,).
            reduced = target.squeeze(-1) if target.dim() == 2 else target
        return reduced <= self.threshold
# Allowed inter-rank communication modes used by ``_CommunicationMixin``:
# "sync" blocks on receive inline, "async_recv" defers the receive wait,
# "fully_async" additionally defers the send wait.
CommMode: TypeAlias = Literal["sync", "async_recv", "fully_async"]
class _CommunicationMixin:
"""Base class providing inter-rank communication and buffer management.
``BaseDynamics`` inherits from this class, so all dynamics subclasses
automatically have communication capabilities for pipeline execution.
This class manages active batch buffers, overflow sinks, and inter-rank
communication for distributed pipeline execution.
Parameters
----------
prior_rank : int | None
Rank to receive data from. ``None`` marks this stage as the
first in its sub-pipeline (no upstream). Defaults to ``-1``
(unset), which tells :meth:`DistributedPipeline.setup` to
auto-assign based on stage ordering. Set explicitly to
``None`` or a rank integer to prevent auto-assignment.
next_rank : int | None
Rank to send graduated samples to. ``None`` marks this stage
as the last in its sub-pipeline (no downstream). Defaults to
``-1`` (unset), with the same auto-assignment semantics as
``prior_rank``.
sinks : list[DataSink] | None
Priority-ordered overflow sinks.
active_batch : Batch | None
The currently active working batch.
max_batch_size : int
Maximum samples in the active batch.
done : bool
Whether this stage has no more work.
device_type : str
Device type string (e.g., ``"cuda"``, ``"cpu"``).
comm_mode : CommMode
Communication mode for inter-rank buffer synchronization.
One of ``"sync"``, ``"async_recv"``, or ``"fully_async"``.
Default ``"async_recv"``.
buffer_config : BufferConfig
Pre-allocation capacities for send/recv communication buffers.
**Required** when ``prior_rank`` or ``next_rank`` is set to a
valid rank. A ``ValueError`` is raised at construction if
omitted for a stage that has neighbors.
**kwargs
Forwarded to the next class in the MRO (cooperative init).
Attributes
----------
prior_rank : int | None
Rank of the previous pipeline stage. ``-1`` means unset
(will be auto-assigned by :meth:`DistributedPipeline.setup`),
``None`` means no upstream, and a non-negative integer is the
explicit source rank.
next_rank : int | None
Rank of the next pipeline stage, with the same conventions
as ``prior_rank``.
sinks : list[DataSink]
Overflow sinks in priority order.
active_batch : Batch | None
Current working batch.
max_batch_size : int
Maximum active batch capacity.
done : bool
Whether this stage is finished.
sampler : SizeAwareSampler | None
Size-aware sampler for inflight batching, or ``None`` for
external batch handling (i.e. the typical looping over
dataloader approach). Defaults to ``None``.
refill_frequency : int
How often to check and refill graduated samples from the
sampler; no-op if ``sampler`` is not provided.
device_type : str
Device type string (e.g., ``"cuda"``, ``"cpu"``).
comm_mode : CommMode
Communication mode for inter-rank buffer synchronization.
buffer_config : BufferConfig | None
Buffer capacities, or ``None`` for isolated stages.
_pending_recv_handle : Any
Stored ``irecv`` handle when receive is deferred (non-sync modes).
``None`` when no receive is pending.
_pending_send_handle : Any
Stored ``isend`` handle when send is deferred (``"fully_async"``).
``None`` when no send is pending.
_stream : torch.cuda.Stream | None
The CUDA stream created when entering the context manager.
``None`` when outside a ``with`` block or on non-CUDA devices.
_stream_ctx : torch.cuda.StreamContext | None
The active stream context wrapping ``_stream``.
``None`` when outside a ``with`` block or on non-CUDA devices.
Examples
--------
>>> from nvalchemi.dynamics.base import BaseDynamics, BufferConfig
>>> cfg = BufferConfig(num_systems=10, num_nodes=500, num_edges=2000)
>>> dyn = BaseDynamics(model=model, prior_rank=0, buffer_config=cfg, max_batch_size=50)
>>> dyn.is_first_stage
False
"""
def __init__(
    self,
    *,
    prior_rank: int | None = -1,
    next_rank: int | None = -1,
    sinks: Sequence[DataSink] | None = None,
    active_batch: Batch | None = None,
    max_batch_size: int = 100,
    done: bool = False,
    sampler: SizeAwareSampler | None = None,
    refill_frequency: int = 1,
    device_type: str | None = None,
    comm_mode: CommMode = "async_recv",
    buffer_config: BufferConfig | None = None,
    debug_mode: bool = False,
    **kwargs: Any,
) -> None:
    """Initialize the communication mixin.

    Parameters
    ----------
    prior_rank : int | None, optional
        Rank to receive data from (previous stage). Default ``-1``
        (unset), which allows pipeline setup to auto-assign the rank;
        ``None`` marks this stage as the first in its sub-pipeline.
    next_rank : int | None, optional
        Rank to send graduated samples to (next stage). Default ``-1``
        (unset), with the same semantics as ``prior_rank``; ``None``
        marks this stage as the last in its sub-pipeline.
    sinks : Sequence[DataSink] | None, optional
        Priority-ordered overflow sinks. Default None (empty list).
    active_batch : Batch | None, optional
        The currently active working batch. Default None.
    max_batch_size : int, optional
        Maximum samples in the active batch. Default 100.
    done : bool, optional
        Whether this stage has no more work. Default False.
    sampler : SizeAwareSampler | None, optional
        Size-aware sampler for inflight batching. When provided,
        enables inflight batching where graduated (converged/finished)
        samples are automatically replaced with fresh ones from the dataset.
        Default ``None``, which expects batches to come from dataloaders.
    refill_frequency : int, optional
        How often (in steps) to check for graduated samples and
        request replacements from the sampler. Only used when
        ``sampler`` is not None. Default 1.
    device_type : str, optional
        Device type string (e.g., ``"cuda"``, ``"cpu"``). Defaults
        to ``None``, which will perform auto placement.
    comm_mode : CommMode, optional
        Communication mode controlling blocking behavior of inter-rank
        buffer synchronization. ``"sync"`` blocks on receive
        immediately. ``"async_recv"`` (the default) defers the receive
        wait until ``_complete_pending_recv`` is called.
        ``"fully_async"`` additionally stores the send handle and
        drains it at the start of the next ``_prestep_sync_buffers``
        call.
    buffer_config : BufferConfig | None, optional
        Pre-allocation capacities for send/recv buffers. Buffers
        are created lazily via ``Batch.empty()`` on the first step
        using the first concrete batch as a template. **Required**
        when ``prior_rank`` or ``next_rank`` is a valid rank;
        raises ``ValueError`` otherwise. Default ``None`` (only
        valid for isolated stages with no neighbors).
    debug_mode : bool, optional
        When ``True``, emit detailed ``loguru.debug`` diagnostics
        for inter-rank communication. Default ``False``.
    **kwargs : Any
        Forwarded to the next class in the MRO (cooperative init).

    Raises
    ------
    ValueError
        If ``comm_mode`` is not a recognized mode, or if the stage has
        a neighbor rank but no ``buffer_config``.
    TypeError
        If ``buffer_config`` is neither a dict, a ``BufferConfig``,
        nor ``None``.
    """
    # Cooperative init: pass unconsumed kwargs up the MRO.
    super().__init__(**kwargs)
    self.prior_rank = prior_rank
    self.next_rank = next_rank
    # Copy into a fresh list so callers' sequences are never mutated.
    self.sinks: list[DataSink] = list(sinks) if sinks is not None else []
    self.active_batch = active_batch
    self.max_batch_size = max_batch_size
    self.done = done
    self.sampler = sampler
    self.refill_frequency = refill_frequency
    # Auto placement: prefer CUDA when available.
    if not device_type:
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
    self.device_type = device_type
    if comm_mode not in ("sync", "async_recv", "fully_async"):
        raise ValueError(
            f"Invalid comm_mode={comm_mode!r}. "
            f"Expected one of: 'sync', 'async_recv', 'fully_async'."
        )
    self.comm_mode: CommMode = comm_mode
    # Deferred communication handles (used by the non-sync modes).
    self._pending_recv_handle: Any = None
    self._pending_send_handle: Any = None
    # CUDA stream state managed by __enter__/__exit__.
    self._stream: torch.cuda.Stream | None = None
    self._stream_ctx: torch.cuda.StreamContext | None = None
    # Accept plain dicts for convenience (e.g. from config files).
    if isinstance(buffer_config, dict):
        buffer_config = BufferConfig(**buffer_config)
    if buffer_config is not None and not isinstance(buffer_config, BufferConfig):
        raise TypeError(
            f"Buffer configuration invalid; got a {type(buffer_config)} object."
        )
    self.buffer_config = buffer_config
    # Stages with neighbors must pre-declare buffer capacities.
    if self.has_neighbor and self.buffer_config is None:
        raise ValueError(
            "buffer_config is required when prior_rank or next_rank is set. "
            "Pre-allocated buffers are mandatory for inter-rank communication."
        )
    # Buffers are created lazily by _ensure_buffers on the first step.
    self.send_buffer: Batch | None = None
    self.recv_buffer: Batch | None = None
    self._recv_template: Batch | None = None
    self.debug_mode = debug_mode
@property
def has_neighbor(self) -> bool:
"""Convenient property to see if rank is isolated"""
next_rank = self.next_rank is not None and self.next_rank != -1
prior_rank = self.prior_rank is not None and self.prior_rank != -1
return next_rank or prior_rank
def _ensure_buffers(self, template: Batch) -> None:
    """Lazily allocate the send/recv buffers from a concrete batch.

    Invoked at the start of the first communication step when
    ``buffer_config`` is set. The *template* batch supplies attribute
    keys, dtypes, and trailing shapes for ``Batch.empty()``.

    Parameters
    ----------
    template : Batch
        A concrete batch used as a template for buffer creation.
    """
    cfg = self.buffer_config
    if cfg is None:
        return

    def _fresh_buffer() -> Batch:
        # Both buffers share the same capacities and template.
        return Batch.empty(
            num_systems=cfg.num_systems,
            num_nodes=cfg.num_nodes,
            num_edges=cfg.num_edges,
            template=template,
            device=self.device,
        )

    if self.send_buffer is None and self.next_rank is not None:
        self.send_buffer = _fresh_buffer()
    if self.recv_buffer is None and self.prior_rank is not None:
        self.recv_buffer = _fresh_buffer()
@property
def is_final_stage(self) -> bool:
"""Return whether this is the last stage in the pipeline.
Returns
-------
bool
``True`` if ``next_rank`` is ``None``.
"""
return self.next_rank is None
@property
def is_first_stage(self) -> bool:
"""Return whether this is the first stage in the pipeline.
Returns
-------
bool
``True`` if ``prior_rank`` is ``None``.
"""
return self.prior_rank is None
@property
def inflight_mode(self) -> bool:
"""Return whether inflight batching is enabled.
Inflight batching is enabled when a sampler is configured.
In this mode, graduated samples are automatically replaced with
fresh ones from the dataset.
Returns
-------
bool
``True`` if ``sampler`` is not ``None``.
"""
return self.sampler is not None
@property
def local_rank(self) -> int:
"""Get the node-local rank for this process.
Returns ``0`` when ``torch.distributed`` is not initialized.
Returns
-------
int
The local rank on this node.
"""
if dist.is_initialized():
return dist.get_node_local_rank()
return 0
@property
def device(self) -> torch.device:
"""Compute the torch device for this rank.
For CUDA-like devices (``"cuda"``, ``"xpu"``, etc.), returns
``torch.device(f"{device_type}:{local_rank}")``. For CPU,
returns ``torch.device("cpu")`` because PyTorch CPU devices
do not use ordinal indices — all ranks share a single CPU
device.
Returns
-------
torch.device
Device for this rank.
"""
match self.device_type:
case "cuda":
if not torch.cuda.is_available():
raise RuntimeError(
"Requested CUDA device type, but not available to PyTorch."
)
return torch.device(f"cuda:{self.local_rank}")
case "cpu":
return torch.device("cpu")
case _:
try:
device = torch.device(f"{self.device_type}:{self.local_rank}")
return device
except Exception as e:
raise RuntimeError(
f"Unable to create device={self.device_type}:{self.local_rank}"
f" with exception: {e}"
)
@property
def stream(self) -> torch.cuda.Stream | None:
"""Return the active CUDA stream, if any.
Returns ``None`` when outside a ``with`` block or on non-CUDA
devices.
Returns
-------
torch.cuda.Stream | None
The CUDA stream created by ``__enter__``, or ``None``.
"""
return self._stream
def __enter__(self) -> _CommunicationMixin:
"""Enter the stream context manager.
On CUDA devices, creates a new ``torch.cuda.Stream`` and enters
a ``torch.cuda.StreamContext`` so that all subsequent GPU
operations execute on the dedicated stream. On non-CUDA devices
this is a no-op.
Returns
-------
_CommunicationMixin
This instance.
"""
if self.device_type == "cuda" and torch.cuda.is_available():
self._stream = torch.cuda.Stream(device=self.device)
self._stream_ctx = torch.cuda.stream(self._stream)
self._stream_ctx.__enter__()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> None:
"""Exit the stream context manager.
Exits the ``torch.cuda.StreamContext`` (if one was entered) and
clears the stored stream references.
Parameters
----------
exc_type : type[BaseException] | None
Exception type, if any.
exc_val : BaseException | None
Exception value, if any.
exc_tb : Any
Exception traceback, if any.
"""
if self._stream_ctx is not None:
self._stream_ctx.__exit__(exc_type, exc_val, exc_tb)
self._stream = None
self._stream_ctx = None
@property
def active_batch_size(self) -> int:
"""Return the number of samples currently in the active batch.
Returns
-------
int
Number of graphs in the active batch, or 0 if no batch.
"""
if self.active_batch is None:
return 0
return self.active_batch.num_graphs or 0
@property
def active_batch_has_room(self) -> bool:
"""Return whether the active batch can accept more samples.
Returns
-------
bool
``True`` if the active batch is below ``max_batch_size``.
"""
return self.active_batch_size < self.max_batch_size
@property
def room_in_active_batch(self) -> int:
"""Return the number of additional samples the active batch can hold.
Returns
-------
int
Remaining capacity.
"""
return max(0, self.max_batch_size - self.active_batch_size)
@property
def _send_buffer_capacity(self) -> int:
"""Return the number of additional graphs the send buffer can accept.
When ``send_buffer`` is ``None`` (no pre-allocated buffer), returns
``sys.maxsize`` to indicate no capacity constraint—the system sends
live batches directly without a fixed-size buffer.
Returns
-------
int
Remaining capacity in the send buffer, or ``sys.maxsize`` when
there is no pre-allocated send buffer.
"""
if self.send_buffer is None:
return sys.maxsize
return self.send_buffer.system_capacity - self.send_buffer.num_graphs
def _buffer_to_batch(self, incoming_batch: Batch) -> None:
    """Route received data into the active batch or overflow sinks.

    If the active batch has room, samples are appended directly.
    Otherwise, excess samples are written to overflow sinks in
    priority order. Appending always rebuilds the active batch from
    data lists (``to_data_list`` / ``from_data_list``), which also
    strips any padding from the incoming batch.

    Parameters
    ----------
    incoming_batch : Batch
        Batch of samples received from the prior stage.
    """
    # Nothing to route for an empty batch.
    if incoming_batch.num_graphs == 0:
        return
    if self.active_batch is None:
        if incoming_batch.num_graphs <= self.max_batch_size:
            # reform the batch without padding
            self.active_batch = Batch.from_data_list(
                incoming_batch.to_data_list(), device=incoming_batch.device
            )
        else:
            # slice out samples that will fit in the active batch
            # and move the rest to overflow
            data_list = incoming_batch.to_data_list()
            fit = data_list[: self.max_batch_size]
            overflow = data_list[self.max_batch_size :]
            self.active_batch = Batch.from_data_list(
                fit, device=incoming_batch.device
            )
            self._overflow_to_sinks(
                Batch.from_data_list(overflow, device=incoming_batch.device)
            )
        return
    room = self.room_in_active_batch
    if room <= 0:
        # Active batch already full: the whole incoming batch is
        # forwarded to the sinks as-is (no rebuild needed).
        self._overflow_to_sinks(incoming_batch)
        return
    data_list = incoming_batch.to_data_list()
    if len(data_list) <= room:
        # Everything fits: merge incoming samples into the active batch.
        existing = self.active_batch.to_data_list()
        self.active_batch = Batch.from_data_list(
            existing + data_list, device=incoming_batch.device
        )
    else:
        # Partial fit: append what fits, overflow the remainder.
        fit = data_list[:room]
        overflow = data_list[room:]
        existing = self.active_batch.to_data_list()
        self.active_batch = Batch.from_data_list(
            existing + fit, device=incoming_batch.device
        )
        self._overflow_to_sinks(
            Batch.from_data_list(overflow, device=incoming_batch.device)
        )
def _recv_to_batch(self, incoming: Batch) -> None:
"""Stage incoming data through the recv buffer into the active batch.
When ``recv_buffer`` is available, copies *incoming* data into the
pre-allocated receive buffer via :meth:`Batch.put`, then routes the
buffer contents into the active batch via :meth:`_buffer_to_batch`.
When ``recv_buffer`` is ``None``, falls back to routing *incoming*
directly.
Parameters
----------
incoming : Batch
Batch received from the prior stage (via ``irecv`` / ``wait``).
"""
if incoming.num_graphs > 0 and self._recv_template is None:
self._recv_template = incoming
if self.recv_buffer is not None and incoming.num_graphs > 0:
mask = torch.ones(incoming.num_graphs, dtype=torch.bool, device=self.device)
self.recv_buffer.put(incoming, mask=mask)
self._buffer_to_batch(self.recv_buffer)
self.recv_buffer.zero()
else:
self._buffer_to_batch(incoming)
def _overflow_to_sinks(
self, batch: Batch, mask: torch.Tensor | None = None
) -> None:
"""Write overflow samples to the first sink with available capacity.
Parameters
----------
batch : Batch
Overflow samples to store.
mask : torch.Tensor | None, optional
Boolean mask for selective writing. Forwarded to sink.write().
Raises
------
RuntimeError
If no sink has capacity for the overflow.
"""
for sink in self.sinks:
if not sink.is_full:
sink.write(batch, mask=mask)
return
raise RuntimeError(
f"All sinks are full. Cannot store {batch.num_graphs} overflow samples."
)
def _batch_to_buffer(self, mask: torch.Tensor) -> None:
"""Move graduated samples from the active batch into the send buffer.
Uses ``send_buffer.put`` to copy samples where *mask* is ``True``
into the pre-allocated send buffer, then trims the active batch
to a new tight :class:`~nvalchemi.data.Batch` without the
graduated samples (or *None* if all were graduated).
Parameters
----------
mask : torch.Tensor
Boolean mask of shape ``(active_batch.num_graphs,)`` where
``True`` marks a graduated (converged) sample.
Raises
------
RuntimeError
If ``active_batch`` or ``send_buffer`` is ``None``.
"""
if self.active_batch is None:
raise RuntimeError("No active batch to extract from.")
if self.send_buffer is None:
raise RuntimeError("No send buffer to write to.")
self.send_buffer.put(self.active_batch, mask=mask)
self.active_batch = self.active_batch.trim(copied_mask=mask)
def _drain_sinks_to_batch(self) -> None:
"""Pull samples from overflow sinks into the active batch.
Iterates through sinks in priority order. For each non-empty
sink, drains its contents and routes them into the active batch
via :meth:`_buffer_to_batch`. Stops early when the active batch
has no more room.
If the drained batch is larger than the remaining room,
``_buffer_to_batch`` handles the partial-fit logic (accepts what
fits, overflows the rest back to sinks).
Notes
-----
Called by ``_prestep_sync_buffers`` after processing incoming
data from the prior rank, to backfill any remaining capacity
with previously overflowed samples.
"""
for sink in self.sinks:
if self.room_in_active_batch <= 0:
break
if len(sink) == 0:
continue
overflow = sink.drain()
self._buffer_to_batch(overflow)
def _prestep_sync_buffers(self) -> None:
    """Synchronize buffers before a dynamics step.

    If this stage has a prior rank, zeros the send buffer and receive
    buffer, then receives data from the prior stage via ``Batch.irecv``.

    In ``"sync"`` mode the receive completes inline and incoming
    data is staged through the receive buffer (if present) into the
    active batch via :meth:`_recv_to_batch`. In ``"async_recv"`` and
    ``"fully_async"`` modes the receive handle is stored in
    ``_pending_recv_handle`` and the caller must invoke
    ``_complete_pending_recv`` before accessing ``active_batch``.

    In ``"fully_async"`` mode, any pending send handle from the
    previous iteration is drained (awaited) at the top of this
    method before posting the new receive.

    After processing incoming data (or when there is no prior rank),
    any remaining capacity in the active batch is backfilled from
    overflow sinks via :meth:`_drain_sinks_to_batch`.

    Notes
    -----
    This method should be called before ``dynamics.step()`` in the
    pipeline loop. When using a non-sync ``comm_mode``, call
    ``_complete_pending_recv`` between this method and ``step()``.

    NOTE(review): ``self.global_rank`` is referenced below but not
    defined in this class as shown — presumably provided elsewhere in
    the class or by a subclass; confirm.
    """
    # Drain last iteration's deferred send before posting a new recv.
    if self.comm_mode == "fully_async" and self._pending_send_handle is not None:
        if self.debug_mode:
            logger.debug("[rank {}] draining pending async send", self.global_rank)
        self._pending_send_handle.wait()
        self._pending_send_handle = None
    if self.prior_rank is not None:
        # Reset both buffers so stale data never leaks into this step.
        if self.send_buffer is not None:
            self.send_buffer.zero()
        if self.recv_buffer is not None:
            self.recv_buffer.zero()
        # Prefer the allocated recv buffer as the irecv template; fall
        # back to the first real batch ever received.
        template = self.recv_buffer or self._recv_template
        if self.debug_mode:
            logger.debug(
                "[rank {}] posting irecv from rank {} (template={})",
                self.global_rank,
                self.prior_rank,
                template is not None,
            )
        handle = Batch.irecv(
            src=self.prior_rank, device=self.device, template=template
        )
        if self.comm_mode == "sync":
            # Block until data arrives, then route it immediately.
            incoming = handle.wait()
            if self.debug_mode:
                logger.debug(
                    "[rank {}] sync recv complete, {} graphs from rank {}",
                    self.global_rank,
                    incoming.num_graphs,
                    self.prior_rank,
                )
            self._recv_to_batch(incoming)
        else:
            # Defer the wait; _complete_pending_recv finishes the recv.
            self._pending_recv_handle = handle
        # In async modes, drain happens in _complete_pending_recv after recv.
    if self.comm_mode == "sync" or self.prior_rank is None:
        self._drain_sinks_to_batch()
def _complete_pending_recv(self) -> None:
"""Finalize any deferred receive before compute needs the data.
In ``"sync"`` mode this is a no-op because
``_prestep_sync_buffers`` already completed the receive inline.
In ``"async_recv"`` and ``"fully_async"`` modes, this method
calls ``wait()`` on the stored receive handle and stages the
incoming batch through the receive buffer (if present) into the
active batch via :meth:`_recv_to_batch`.
After processing incoming data, any remaining capacity in the
active batch is backfilled from overflow sinks via
:meth:`_drain_sinks_to_batch`.
Notes
-----
Must be called after ``_prestep_sync_buffers`` and before any
method that reads ``active_batch`` (e.g., ``step()``).
"""
if self._pending_recv_handle is not None:
if self.debug_mode:
logger.debug(
"[rank {}] waiting on pending async recv", self.global_rank
)
incoming = self._pending_recv_handle.wait()
if self.debug_mode:
logger.debug(
"[rank {}] async recv complete, {} graphs",
self.global_rank,
incoming.num_graphs,
)
self._recv_to_batch(incoming)
self._pending_recv_handle = None
self._drain_sinks_to_batch()
def _manage_send_handle(self, handle: Any) -> None:
"""Store or wait on a send handle based on communication mode.
Parameters
----------
handle
The send handle returned by ``Batch.isend``.
"""
if self.comm_mode == "fully_async":
self._pending_send_handle = handle
else:
handle.wait()
def _populate_send_buffer(self, converged_indices: torch.Tensor) -> None:
"""Populate the send buffer with converged graphs.
Creates a boolean mask from the converged indices and copies those
graphs into the send buffer. Does NOT send — the caller is responsible
for issuing the ``isend``.
Parameters
----------
converged_indices : torch.Tensor
Integer indices of converged samples (already truncated to capacity).
"""
mask = torch.zeros(
self.active_batch.num_graphs,
dtype=torch.bool,
device=self.device,
)
mask[converged_indices] = True
self._batch_to_buffer(mask)
if self.debug_mode:
logger.debug(
"[rank {}] populated send buffer with {} converged graphs",
self.global_rank,
converged_indices.numel(),
)
def _remove_converged_final_stage(self, converged_indices: torch.Tensor) -> None:
"""Remove converged graphs on the final stage and route to sinks.
Extracts converged graphs, removes them from the active batch,
and writes them to configured sinks if available.
Parameters
----------
converged_indices : torch.Tensor
Integer indices of converged samples.
"""
graduated = self.active_batch.index_select(converged_indices)
all_indices = set(range(self.active_batch.num_graphs))
remaining = sorted(all_indices - set(converged_indices.tolist()))
if remaining:
self.active_batch = self.active_batch.index_select(remaining)
else:
self.active_batch = None
if self.debug_mode:
logger.debug(
"[rank {}] final stage, {} converged graphs removed",
self.global_rank,
converged_indices.numel(),
)
if self.sinks:
self._overflow_to_sinks(graduated)
def _poststep_sync_buffers(
self, converged_indices: torch.Tensor | None = None
) -> None:
"""Synchronize buffers after a dynamics step.
If ``converged_indices`` is provided and a next rank exists with
available capacity, the converged samples are copied into
``send_buffer`` via :meth:`_populate_send_buffer`. The send
buffer is then unconditionally sent to the next rank — even if
empty (``num_graphs == 0`` after zeroing) — so the downstream
``irecv`` always completes without deadlock.
On the final stage, converged samples are extracted via
``index_select`` and written to the first available sink.
Back-pressure behavior
----------------------
Only as many converged samples as fit in the remaining buffer
capacity are copied and sent. Excess converged samples remain
in the active batch and become no-ops until the next step when
buffer capacity may be available.
Parameters
----------
converged_indices : torch.Tensor | None, optional
Integer indices of converged samples in the active batch.
Typically obtained from ``BaseDynamics._check_convergence()``.
If ``None``, no samples are graduated.
"""
has_converged = converged_indices is not None and converged_indices.numel() > 0
if has_converged:
if self.next_rank is not None:
send_capacity = self._send_buffer_capacity
if send_capacity > 0:
if converged_indices.numel() > send_capacity:
converged_indices = converged_indices[:send_capacity]
self._populate_send_buffer(converged_indices)
if self.is_final_stage:
self._remove_converged_final_stage(converged_indices)
if self.next_rank is not None:
handle = self.send_buffer.isend(dst=self.next_rank)
self._manage_send_handle(handle)
@property
def global_rank(self) -> int:
    """Global rank of this process.

    Returns
    -------
    int
        ``dist.get_rank()`` when torch.distributed is initialized,
        otherwise 0 (single-process execution).
    """
    return dist.get_rank() if dist.is_initialized() else 0
def __or__(self, other: BaseDynamics) -> DistributedPipeline:
    """Compose two stages into a ``DistributedPipeline`` via ``stage_a | stage_b``.

    Chaining is supported::

        a | b | c  →  (a | b) | c

    The first ``|`` creates a two-stage pipeline via this method.
    Subsequent ``|`` calls hit ``DistributedPipeline.__or__``, which
    appends stages and re-wires source/sink dependencies.

    Parameters
    ----------
    other : BaseDynamics
        The next stage to chain after this one.

    Returns
    -------
    DistributedPipeline
        A pipeline containing both stages mapped to sequential ranks
        (``self`` → 0, ``other`` → 1). The ``prior_rank`` /
        ``next_rank`` wiring happens later, in
        ``DistributedPipeline.setup()`` (e.g. via the context manager
        or ``run()``).
    """
    return DistributedPipeline(stages={0: self, 1: other})
def __add__(self, other: BaseDynamics) -> FusedStage:
    """Fuse two dynamics into a ``FusedStage`` via ``dyn_a + dyn_b``.

    The resulting ``FusedStage`` runs both dynamics on a single shared
    batch and forward pass; each sub-dynamics applies masked updates to
    the samples matching its status code (0 for ``self``, 1 for
    ``other``).

    Parameters
    ----------
    other : BaseDynamics
        The dynamics to fuse with this one.

    Returns
    -------
    FusedStage
        A fused stage containing both dynamics with status codes 0 and 1.

    Raises
    ------
    TypeError
        If either ``self`` or ``other`` is not a ``BaseDynamics``
        instance.
    """
    # FusedStage is defined later in this file.
    for label, operand in (("self", self), ("other", other)):
        if not isinstance(operand, BaseDynamics):
            raise TypeError(
                "Both operands of + must be BaseDynamics instances. "
                f"{label} is {type(operand).__name__}, not BaseDynamics."
            )
    return FusedStage(sub_stages=[(0, self), (1, other)])
[docs]
class BaseDynamics(_CommunicationMixin):
    """Base class for all dynamics simulations.

    Coordinates a ``BaseModelMixin`` model with a numerical integrator
    to evolve a ``Batch`` of atomic systems over time, managing the
    step loop, hook execution at stage boundaries, and model
    evaluation. Inheriting from ``_CommunicationMixin`` gives every
    dynamics subclass inter-rank communication and buffer management
    for pipeline execution without explicit multiple inheritance.

    The public interface centers on three methods:

    * ``run(batch)`` — the top-level entry point; repeatedly calls
      ``step()`` for ``n_steps`` iterations (set at construction or
      passed to ``run()``). The only method most users need.
    * ``step(batch)`` — a single simulation step: the hook-wrapped
      sequence ``pre_update → compute → post_update``, followed by
      convergence checking. Subclasses should generally NOT override.
    * ``compute(batch)`` — the model forward pass: calls
      ``model(batch)`` (which must return a fully adapted
      ``ModelOutputs`` dict), validates outputs against
      ``__needs_keys__``, and writes results (forces, energies,
      stresses) back to the batch in-place. Subclasses should
      generally NOT override.

    Attributes
    ----------
    model : BaseModelMixin
        The neural network potential model.
    step_count : int
        The current step number, starting from 0.
    hooks : dict[HookStageEnum, list[Hook]]
        Registered hooks, keyed by stage.
    model_is_conservative : bool
        Whether the model uses automatic differentiation to obtain
        forces.
    convergence_hook : ConvergenceHook
        Hook that evaluates composable convergence criteria. Defaults
        to a single ``fmax`` criterion with threshold ``0.05``.
    n_steps : int | None
        Total number of simulation steps for ``run()``; ``None`` means
        the count must be supplied when calling ``run()``.
    exit_status : int
        Status code threshold for graduated samples. Samples with
        ``status >= exit_status`` are treated as no-ops during
        ``step()`` — their positions and velocities are preserved
        through the integrator. Default is 1.
    __needs_keys__ : set[str]
        Model output keys this dynamics requires. Empty by default;
        subclasses declare their own requirements (typically forces
        for optimization and MD). Checked in
        ``_validate_model_outputs()`` after each forward pass.
    __provides_keys__ : set[str]
        Keys this dynamics produces or updates on the batch beyond
        model outputs (e.g. ``{"velocities", "positions"}`` for
        velocity verlet). Used for validation and buffer
        preallocation.

    Notes
    -----
    New integrators override ``pre_update(batch)`` and
    ``post_update(batch)``, which are called around ``compute()`` —
    e.g. Velocity Verlet updates positions in ``pre_update`` and
    velocities in ``post_update``. ``masked_update(batch, mask)`` is
    used by ``FusedStage`` to apply updates to only a subset of
    samples in a batched setting. Models must be ``BaseModelMixin``
    instances — plain ``nn.Module`` is not accepted.

    Examples
    --------
    >>> model = MyPotentialModel()
    >>> dynamics = BaseDynamics(model, n_steps=1000)
    >>> dynamics.run(batch)
    """

    # Model output keys this dynamics requires; validated after each forward.
    __needs_keys__: set[str] = set()
    # Extra batch state this dynamics produces beyond the model outputs.
    __provides_keys__: set[str] = set()
    # Fields snapshotted/restored for inactive samples in step()/masked_update().
    _mutable_fields: tuple[str, ...] = ("positions", "velocities")
    # Graph-level bookkeeping fields rebuilt across refill_check(); each
    # factory maps (n_systems, device) -> default tensor of shape (n, 1).
    _bookkeeping_keys: dict[str, Callable[[int, torch.device], torch.Tensor]] = {
        "status": lambda n, dev: torch.zeros(n, 1, dtype=torch.long, device=dev),
        "fmax": lambda n, dev: torch.full(
            (n, 1), float("inf"), dtype=torch.float32, device=dev
        ),
        "system_id": lambda n, dev: torch.full(
            (n, 1), -1, dtype=torch.long, device=dev
        ),
    }
@classmethod
def register_bookkeeping_key(
    cls,
    key: str,
    init_fn: Callable[[int, torch.device], torch.Tensor],
) -> None:
    """Register a graph-level bookkeeping field to survive refill_check.

    The registry is rebuilt (copy-on-write) rather than mutated in
    place, so registering on a subclass never alters the parent's
    registry.

    Parameters
    ----------
    key : str
        Field name on Batch.
    init_fn : Callable[[int, torch.device], torch.Tensor]
        Factory that creates a zero-initialized tensor of shape (n, 1)
        for n systems on the given device.
    """
    updated = dict(cls._bookkeeping_keys)
    updated[key] = init_fn
    cls._bookkeeping_keys = updated
[docs]
def __init__(
    self,
    model: BaseModelMixin,
    hooks: list[Hook] | None = None,
    convergence_hook: Any = None,
    n_steps: int | None = None,
    exit_status: int = 1,
    **kwargs: Any,
) -> None:
    """Initialize the dynamics engine.

    Parameters
    ----------
    model : BaseModelMixin
        The neural network potential model. Must be a
        ``BaseModelMixin``; a plain ``nn.Module`` is rejected.
    hooks : list[Hook] | None, optional
        Initial hooks to register; each is filed under its ``stage``
        (or ``stages``) attribute.
    convergence_hook : ConvergenceHook | dict | None, optional
        Hook that evaluates composable convergence criteria. A dict is
        unpacked as ``ConvergenceHook(**convergence_hook)``; ``None``
        disables convergence assessment.
    n_steps : int | None, optional
        Total number of simulation steps; ``run()`` uses this value
        when called without an explicit ``n_steps``. Default ``None``.
    exit_status : int, optional
        Status code threshold for graduated samples. Samples with
        ``status >= exit_status`` are no-ops during ``step()``.
        Default is 1; subclasses like ``FusedStage`` may compute this
        dynamically.
    **kwargs : Any
        Forwarded to the next class in the MRO (cooperative multiple
        inheritance).

    Raises
    ------
    TypeError
        If ``model`` is not a ``BaseModelMixin`` instance.
    """
    super().__init__(**kwargs)
    if not isinstance(model, BaseModelMixin):
        raise TypeError(
            f"Expected a `BaseModelMixin` instance, got {type(model).__name__}."
            " Please wrap your model with a `BaseModelMixin` subclass."
        )
    self.model = model
    self.model_card = model.model_card
    self.step_count: int = 0
    self.n_steps = n_steps
    self.exit_status = exit_status
    if isinstance(convergence_hook, dict):
        convergence_hook = ConvergenceHook(**convergence_hook)
    self.convergence_hook = convergence_hook
    self.hooks: dict[HookStageEnum, list[Hook]] = defaultdict(list)
    self.current_hook_stage: HookStageEnum | None = None
    for hook in hooks if hooks is not None else ():
        self.register_hook(hook)
    self._last_converged: torch.Tensor | None = None
@property
def model_is_conservative(self) -> bool:
    """Whether the model obtains forces via automatic differentiation."""
    card = self.model_card
    return card.forces_via_autograd
def __repr__(self) -> str:
"""Return a human-readable summary of the dynamics engine."""
cls = type(self).__name__
model_cls = type(self.model).__name__
conservative = self.model_is_conservative
n_hooks = sum(len(h) for h in self.hooks.values())
return (
f"{cls}("
f"model={model_cls}, "
f"n_steps={self.n_steps}, "
f"step_count={self.step_count}, "
f"conservative={conservative}, "
f"convergence_hook={self.convergence_hook!r}, "
f"hooks={n_hooks})"
)
def register_hook(self, hook: Hook) -> None:
    """Register *hook* at every stage it declares.

    Hooks exposing a ``stages`` iterable (of :class:`HookStageEnum`)
    are filed under each listed stage; otherwise the single
    ``hook.stage`` is used.

    Parameters
    ----------
    hook : Hook
        The hook to register. Must have ``stage`` (or ``stages``) and
        ``frequency`` attributes.

    Raises
    ------
    ValueError
        If ``hook.frequency`` is not a positive integer (>= 1).
    """
    freq = hook.frequency
    if not isinstance(freq, int) or freq < 1:
        raise ValueError(
            f"Hook {hook!r} has frequency={hook.frequency!r}. "
            "frequency must be a positive integer (>= 1)."
        )
    multi = getattr(hook, "stages", None)
    target_stages = multi if multi is not None else (hook.stage,)
    for stage in target_stages:
        self.hooks[stage].append(hook)
def _call_hooks(self, stage: HookStageEnum, batch: Batch) -> None:
"""
Execute all hooks registered for a given stage.
Hooks are only executed if the current step count is divisible
by their frequency. At step_count == 0, all hooks fire since
0 % n == 0 for any n.
The current stage is stored on ``self.current_hook_stage`` so
that multi-stage hooks (registered at several stages via
``stages``) can determine which stage triggered the call.
Parameters
----------
stage : HookStageEnum
The stage for which to execute hooks.
batch : Batch
The current batch of atomic data.
"""
self.current_hook_stage = stage
for hook in self.hooks[stage]:
if self.step_count % hook.frequency == 0:
hook(batch, self)
def _open_hooks(self) -> None:
"""Enter context-manager hooks registered on this stage.
Calls ``__enter__`` on every hook that supports the context-manager
protocol. A ``seen`` set prevents double-entering hooks registered
at multiple stages.
Called automatically at the start of :meth:`run`.
"""
seen: set[int] = set()
for hooks_list in self.hooks.values():
for hook in hooks_list:
hook_id = id(hook)
if hook_id not in seen and hasattr(hook, "__enter__"):
seen.add(hook_id)
hook.__enter__()
def _close_hooks(self) -> None:
"""Exit context-manager hooks, falling back to ``close()`` otherwise.
For hooks that support the context-manager protocol, calls
``__exit__(None, None, None)``. For hooks that only expose a
``close()`` method (e.g. ``ProfilerHook``), calls ``close()``
directly. A ``seen`` set prevents double-closing hooks registered
at multiple stages.
Called automatically at the end of :meth:`run`.
"""
seen: set[int] = set()
for hooks_list in self.hooks.values():
for hook in hooks_list:
hook_id = id(hook)
if hook_id in seen:
continue
seen.add(hook_id)
if hasattr(hook, "__exit__"):
hook.__exit__(None, None, None)
elif hasattr(hook, "close"):
hook.close()
def _check_convergence(self, batch: Batch) -> torch.Tensor | None:
"""Return indices of converged samples, or None if none converged.
Delegates to ``self.convergence_hook.evaluate(batch)`` if a
convergence hook is configured.
Parameters
----------
batch : Batch
The current batch of atomic data.
Returns
-------
torch.Tensor | None
Integer tensor of converged sample indices, or ``None``.
"""
if self.convergence_hook is None:
return None
return self.convergence_hook.evaluate(batch)
def _validate_model_outputs(self, outputs: ModelOutputs) -> None:
"""Validate that model outputs satisfy the dynamics requirements.
Iterates over ``__needs_keys__`` and checks that each declared key
is present and not ``None`` in the model outputs.
Parameters
----------
outputs : ModelOutputs
The model outputs to validate.
Raises
------
RuntimeError
If a required output key is missing or ``None``.
"""
for key in self.__needs_keys__:
if outputs.get(key) is None:
raise RuntimeError(
f"{type(self).__name__} requires '{key}' "
f"(declared in __needs_keys__), but the model did not "
f"produce it. Check your model's ModelConfig and "
f"ModelCard to ensure '{key}' is supported and enabled,"
" or that no hooks are missing."
)
def _validate_batch_keys(self, batch: Batch) -> None:
"""Validate that the batch contains all keys declared in ``__provides_keys__``.
This is a diagnostic helper — not wired into the hot path. It can be
called after ``step()`` to verify that the dynamics produced the keys
it claims to provide.
Parameters
----------
batch : Batch
The batch to validate.
Raises
------
RuntimeError
If a declared provides-key is ``None`` on the batch.
"""
for key in self.__provides_keys__:
val = getattr(batch, key, None)
if val is None:
raise RuntimeError(
f"{type(self).__name__} declares '{key}' in "
f"__provides_keys__, but batch.{key} is None after "
f"compute. This may indicate a misconfigured model or "
f"dynamics."
)
# ------------------------------------------------------------------
# Per-system integrator state management
# ------------------------------------------------------------------
def _init_state(self, batch: Batch) -> None:
    """Allocate per-system integrator state from the first concrete batch.

    No-op in the base class. Subclasses that require per-system state
    (e.g. all warp-kernel integrators) override this to build a
    system-only :class:`~nvalchemi.data.Batch` and assign it to
    ``self._state``.

    Parameters
    ----------
    batch : Batch
        The first concrete batch; used to determine M, device, and
        dtype.
    """
def _make_new_state(self, n: int, template_batch: Batch) -> "Batch | None":
"""Return default state for *n* newly admitted systems.
No-op in the base class (returns ``None``). Subclasses override
to produce a system-only :class:`~nvalchemi.data.Batch` with
default/reset state for *n* replacement systems.
Parameters
----------
n : int
Number of new systems to create state for.
template_batch : Batch
The updated active batch; provides device and dtype.
Returns
-------
Batch | None
A system-only batch with *n* rows, or ``None`` if this
dynamics does not maintain per-system state.
"""
return None
def _ensure_state_initialized(self, batch: Batch) -> None:
"""Lazily initialize per-system integrator state on the first call.
Calls :meth:`_init_state` the first time this method is invoked
(i.e. when ``self._state`` does not yet exist). Subsequent calls
are no-ops. This is invoked automatically at the start of
:meth:`step` and :meth:`masked_update` so that concrete subclasses
never need to call it explicitly.
Parameters
----------
batch : Batch
The current batch; forwarded to ``_init_state`` if needed.
"""
if not hasattr(self, "_state"):
self._init_state(batch)
def _sync_state_to_batch(
self,
remaining_indices: "torch.Tensor",
n_new: int,
template_batch: Batch,
) -> None:
"""Synchronize ``self._state`` after an inflight batch refill.
Called by :meth:`_refill_check` after graduated systems have been
removed and replacement systems appended. Removes state rows for
graduated systems and appends fresh default state for the new ones.
If this dynamics has no ``_state`` (e.g. :class:`DemoDynamics`),
this method is a no-op.
Parameters
----------
remaining_indices : torch.Tensor
Integer indices of systems that remain in the new batch, in
order. Used to slice ``self._state`` via ``index_select``.
n_new : int
Number of newly admitted replacement systems appended after
the remaining ones. State for these is produced by
:meth:`_make_new_state`.
template_batch : Batch
The updated batch (remaining + replacements); provides device
and dtype for :meth:`_make_new_state`.
"""
if not hasattr(self, "_state"):
return
if remaining_indices.numel() > 0:
remaining_state: "Batch | None" = self._state.index_select(
remaining_indices
)
else:
remaining_state = None
new_state: "Batch | None" = None
if n_new > 0:
new_state = self._make_new_state(n_new, template_batch)
if remaining_state is not None and new_state is not None:
remaining_state.append(new_state)
self._state = remaining_state
elif remaining_state is not None:
self._state = remaining_state
elif new_state is not None:
self._state = new_state
else:
del self._state
def pre_update(self, batch: Batch) -> None:
    """Perform the first half of the integration step.

    No-op in the base class; integrator subclasses override it (e.g.
    Velocity Verlet updates positions here, before the force
    evaluation in ``compute``).

    Parameters
    ----------
    batch : Batch
        The current batch of atomic data, modified in-place.
    """
    pass
def post_update(self, batch: Batch) -> None:
    """Perform the second half of the integration step.

    No-op in the base class; integrator subclasses override it (e.g.
    Velocity Verlet updates velocities here, after the force
    evaluation in ``compute``).

    Parameters
    ----------
    batch : Batch
        The current batch of atomic data, modified in-place.
    """
    pass
def compute(self, batch: Batch | AtomsLike) -> ModelOutputs:
    """Run the model forward pass and write results onto the batch.

    The model's ``forward()`` is responsible for returning a fully
    adapted ``ModelOutputs`` dict — ``adapt_output()`` must NOT be
    called again here, since each wrapper handles adaptation
    internally and returns canonical keys directly. Outputs are
    validated against ``__needs_keys__``, then energies, forces, and
    stresses are copied in-place onto the batch.

    Parameters
    ----------
    batch : Batch
        The current batch of atomic data; its ``energies``,
        ``forces``, and ``stress`` tensors are updated in-place when
        the model provides them.

    Returns
    -------
    ModelOutputs
        OrderedDict of the model outputs (energies, forces, and any
        other computed properties).

    Raises
    ------
    RuntimeError
        If the model outputs do not satisfy ``__needs_keys__``.
    """
    outputs: ModelOutputs = self.model(batch)
    self._validate_model_outputs(outputs)
    energies = outputs.get("energies")
    if energies is not None:
        # view() absorbs benign shape mismatches (e.g. model [M,1] vs
        # batch [M,1,1]).
        batch.energies.copy_(energies.view(batch.energies.shape))
    forces = outputs.get("forces")
    if forces is not None:
        batch.forces.copy_(forces)
    stresses = outputs.get("stresses")
    if stresses is not None:
        # batch.stress must be pre-allocated (e.g.
        # AtomicData(stress=zeros(1,3,3))). NPT/NPH read this after
        # each compute(); variable-cell optimizers also use it.
        batch.stress.copy_(stresses.view(batch.stress.shape))
    return outputs
def step(self, batch: Batch) -> tuple[Batch, torch.Tensor | None]:
    """Execute a single dynamics step with the full hook-wrapped sequence.

    The step proceeds as follows:

    1. BEFORE_STEP hooks
    2. BEFORE_PRE_UPDATE hooks -> pre_update() -> AFTER_PRE_UPDATE hooks
    3. BEFORE_COMPUTE hooks -> compute() -> AFTER_COMPUTE hooks
    4. BEFORE_POST_UPDATE hooks -> post_update() -> AFTER_POST_UPDATE hooks
    5. AFTER_STEP hooks
    6. Convergence check; ON_CONVERGE hooks fire if any sample converged
    7. Increment ``step_count``

    Samples with ``status >= exit_status`` are treated as no-ops for
    the integrator: their mutable fields are snapshotted before
    ``pre_update`` and restored after ``post_update``, so positions
    and velocities are preserved through the step. This enables
    back-pressure handling in pipeline mode where converged samples
    may remain in the active batch when the send buffer is full.

    Parameters
    ----------
    batch : Batch
        The current batch of atomic data.

    Returns
    -------
    tuple[Batch, torch.Tensor | None]
        The updated batch after the step, and a 1-D integer tensor of
        converged sample indices (or ``None`` if nothing converged).
    """
    self._ensure_state_initialized(batch)
    self._call_hooks(HookStageEnum.BEFORE_STEP, batch)
    # active_mask is True for samples the integrator should update;
    # it stays None when the batch carries no status field.
    active_mask: torch.Tensor | None = None
    if hasattr(batch, "status") and batch.status is not None:
        status = (
            batch.status.squeeze(-1) if batch.status.dim() == 2 else batch.status
        )
        active_mask = status[: batch.num_graphs] < self.exit_status
    saved: dict[str, torch.Tensor] = {}
    if active_mask is not None:
        # Expand the graph-level mask to node level; nodes past the
        # occupied region (padding) remain False.
        node_mask_occupied = torch.repeat_interleave(
            active_mask, batch.num_nodes_per_graph
        )
        node_mask = torch.zeros(
            batch.num_nodes, dtype=torch.bool, device=batch.device
        )
        node_mask[: len(node_mask_occupied)] = node_mask_occupied
        sys_mask = ~active_mask
        # Snapshot mutable fields of inactive samples so the
        # integrator update can be undone for them afterwards.
        for field in self._mutable_fields:
            val = getattr(batch, field, None)
            if val is None:
                continue
            if val.shape[0] == batch.num_nodes:
                saved[field] = val[~node_mask].clone()
            elif val.shape[0] == batch.num_graphs:
                saved[field] = val[sys_mask].clone()
    self._call_hooks(HookStageEnum.BEFORE_PRE_UPDATE, batch)
    self.pre_update(batch)
    self._call_hooks(HookStageEnum.AFTER_PRE_UPDATE, batch)
    self._call_hooks(HookStageEnum.BEFORE_COMPUTE, batch)
    self.compute(batch)
    self._call_hooks(HookStageEnum.AFTER_COMPUTE, batch)
    self._call_hooks(HookStageEnum.BEFORE_POST_UPDATE, batch)
    self.post_update(batch)
    self._call_hooks(HookStageEnum.AFTER_POST_UPDATE, batch)
    if active_mask is not None:
        # Restore the snapshotted state of inactive samples.
        with torch.no_grad():
            for field, sv in saved.items():
                val = getattr(batch, field)
                if val.shape[0] == batch.num_nodes:
                    val[~node_mask] = sv
                else:
                    val[sys_mask] = sv
    self._call_hooks(HookStageEnum.AFTER_STEP, batch)
    converged = self._check_convergence(batch)
    self._last_converged = converged
    if converged is not None:
        self._call_hooks(HookStageEnum.ON_CONVERGE, batch)
    self.step_count += 1
    return batch, converged
def run(self, batch: Batch, n_steps: int | None = None) -> Batch:
    """Drive the simulation by calling ``step()`` repeatedly.

    The step budget comes from the ``n_steps`` argument when given,
    otherwise from ``self.n_steps``. Context-manager hooks are entered
    before the loop and always closed afterwards, even on error. In
    sampler-free mode (Mode 1) the loop exits early once every system
    in the batch has converged.

    Parameters
    ----------
    batch : Batch
        The initial batch of atomic data.
    n_steps : int | None, optional
        Number of steps to run; falls back to ``self.n_steps``.

    Returns
    -------
    Batch
        The batch after all executed steps.

    Raises
    ------
    ValueError
        If no step count is available (both the argument and
        ``self.n_steps`` are ``None``).
    """
    total = self.n_steps if n_steps is None else n_steps
    if total is None:
        raise ValueError(
            "No step count provided. Either pass `n_steps` to run() "
            "or set it at construction time via "
            f"`{type(self).__name__}(..., n_steps=N)`."
        )
    self._open_hooks()
    try:
        for _ in range(total):
            batch, converged = self.step(batch)
            # Early exit when every system has satisfied the
            # convergence criteria (sampler-free / Mode 1 only).
            all_done = (
                self.sampler is None
                and converged is not None
                and converged.numel() == batch.num_graphs
            )
            if all_done:
                break
    finally:
        self._close_hooks()
    return batch
def refill_check(self, batch: Batch, exit_status: int) -> Batch | None:
    """Replace graduated samples via index-select and append.

    Graduated graphs (``status >= exit_status``) are written to sinks,
    then removed via :meth:`Batch.index_select` on the remaining
    indices. Replacement samples from the sampler are appended via
    :meth:`Batch.append`, and dynamics-specific bookkeeping fields are
    rebuilt on the result via the ``_bookkeeping_keys`` registry.

    Parameters
    ----------
    batch : Batch
        The current batch with a ``status`` field.
    exit_status : int
        Status code indicating graduation.

    Returns
    -------
    Batch | None
        A new batch with graduated graphs replaced by fresh samples,
        or ``None`` if no active samples remain (sampler exhausted and
        all graduated) — in which case ``self.done`` is set to
        ``True``.

    Raises
    ------
    RuntimeError
        If ``self.sampler`` is ``None``.
    """
    if self.sampler is None:
        raise RuntimeError("refill_check requires a sampler to be configured.")
    status = batch.status
    if status.dim() == 2:
        status = status.squeeze(-1)
    graduated_mask = status >= exit_status
    # Fast path: nothing graduated, keep the batch as-is.
    if not graduated_mask.any():
        return batch
    graduated_indices = torch.where(graduated_mask)[0]
    remaining_indices = torch.where(~graduated_mask)[0]
    if self.sinks and graduated_mask.any():
        self._overflow_to_sinks(batch, mask=graduated_mask)
    # Replacements are requested with the graduated graphs' node/edge
    # counts — presumably so they fit the freed buffer slots (confirm
    # against the sampler's contract).
    grad_node_counts = batch.num_nodes_per_graph[graduated_indices].tolist()
    edges_per_graph = batch.num_edges_per_graph
    if edges_per_graph.numel() > 0:
        grad_edge_counts = edges_per_graph[graduated_indices].tolist()
    else:
        grad_edge_counts = [0] * len(grad_node_counts)
    n_remaining = remaining_indices.numel()
    if remaining_indices.numel() > 0:
        result = batch.index_select(remaining_indices)
    else:
        result = None
    # The sampler may return fewer replacements than requested (None
    # entries are skipped).
    replacements: list[AtomicData] = []
    for n_atoms, n_edges in zip(grad_node_counts, grad_edge_counts):
        repl = self.sampler.request_replacement(n_atoms, n_edges)
        if repl is not None:
            replacements.append(repl)
    if result is not None and replacements:
        repl_batch = Batch.from_data_list(replacements, device=batch.device)
        result.append(repl_batch)
    elif result is None and replacements:
        result = Batch.from_data_list(replacements, device=batch.device)
    if result is not None:
        n_total = result.num_graphs
        device = result.device
        # Rebuild bookkeeping fields: defaults for everyone, then the
        # surviving systems' values copied into the leading rows
        # (survivors precede replacements in `result` by construction).
        for key, default_fn in self._bookkeeping_keys.items():
            new_tensor = default_fn(n_total, device)
            remaining_vals = getattr(batch, key, None)
            if remaining_vals is not None and n_remaining > 0:
                src = remaining_vals[remaining_indices]
                src = src.unsqueeze(-1) if src.dim() == 1 else src
                new_tensor[:n_remaining] = src
            result[key] = new_tensor
        self._sync_state_to_batch(remaining_indices, len(replacements), result)
        return result
    # Nothing left: no survivors and no replacements were produced.
    if self.sampler.exhausted:
        self.done = True
    self._sync_state_to_batch(remaining_indices, 0, batch)
    return None
def masked_update(
    self,
    batch: Batch,
    mask: Bool[torch.Tensor, "B"],  # noqa: F722, F821
) -> None:
    """Apply pre_update and post_update to only the masked graphs.

    Runs the integrator on the whole batch, but snapshots the mutable
    fields (positions/velocities) of unmasked graphs beforehand and
    restores them afterwards, so only the graphs selected by *mask*
    are effectively updated.

    Parameters
    ----------
    batch : Batch
        The current batch of atomic data, modified in-place.
    mask : Bool[Tensor, "B"]
        Boolean mask of shape (B,) over graphs, where B is the number
        of graphs; True means "update this graph".

    Notes
    -----
    The graph-level mask is expanded to node level via
    ``num_nodes_per_graph`` so per-node tensors like positions and
    velocities are indexed correctly; trailing padded nodes beyond the
    occupied region are treated as unmasked.
    """
    # Lazy init — FusedStage sub-stages never get step() called on
    # them directly.
    self._ensure_state_initialized(batch)
    occupied = torch.repeat_interleave(mask, batch.num_nodes_per_graph)
    node_mask = torch.zeros(
        batch.num_nodes, dtype=torch.bool, device=batch.device
    )
    node_mask[: len(occupied)] = occupied
    frozen_nodes = ~node_mask
    frozen_graphs = ~mask
    snapshot: dict[str, torch.Tensor] = {}
    for name in self._mutable_fields:
        tensor = getattr(batch, name, None)
        if tensor is None:
            continue
        if tensor.shape[0] == batch.num_nodes:
            snapshot[name] = tensor[frozen_nodes].clone()
        elif tensor.shape[0] == batch.num_graphs:
            snapshot[name] = tensor[frozen_graphs].clone()
    self.pre_update(batch)
    self.post_update(batch)
    with torch.no_grad():
        for name, kept in snapshot.items():
            tensor = getattr(batch, name)
            if tensor.shape[0] == batch.num_nodes:
                tensor[frozen_nodes] = kept
            else:
                tensor[frozen_graphs] = kept
def _masked_pre_update(
self,
batch: Batch,
mask: Bool[torch.Tensor, "B"], # noqa: F722, F821
) -> None:
"""Run only pre_update for masked samples, restoring non-masked state.
Used by :class:`FusedStage` to interleave pre_update across all
sub-stages before the shared compute, so that forces are evaluated at
the post-pre_update positions (required for BAOAB Langevin and
velocity-Verlet-based integrators).
"""
self._ensure_state_initialized(batch)
node_mask = mask[batch.batch]
sys_mask = ~mask
saved: dict[str, torch.Tensor] = {}
for field in self._mutable_fields:
val = getattr(batch, field, None)
if val is None:
continue
if val.shape[0] == batch.num_nodes:
saved[field] = val[~node_mask].clone()
elif val.shape[0] == batch.num_graphs:
saved[field] = val[sys_mask].clone()
self.pre_update(batch)
with torch.no_grad():
for field, sv in saved.items():
val = getattr(batch, field)
if val.shape[0] == batch.num_nodes:
val[~node_mask] = sv
else:
val[sys_mask] = sv
def _masked_post_update(
self,
batch: Batch,
mask: Bool[torch.Tensor, "B"], # noqa: F722, F821
) -> None:
"""Run only post_update for masked samples, restoring non-masked state.
Called by :class:`FusedStage` after the shared compute so that
post_update (e.g. the final BAOAB velocity half-kick) uses forces at
the new positions.
"""
node_mask = mask[batch.batch]
sys_mask = ~mask
saved: dict[str, torch.Tensor] = {}
for field in self._mutable_fields:
val = getattr(batch, field, None)
if val is None:
continue
if val.shape[0] == batch.num_nodes:
saved[field] = val[~node_mask].clone()
elif val.shape[0] == batch.num_graphs:
saved[field] = val[sys_mask].clone()
self.post_update(batch)
with torch.no_grad():
for field, sv in saved.items():
val = getattr(batch, field)
if val.shape[0] == batch.num_nodes:
val[~node_mask] = sv
else:
val[sys_mask] = sv
[docs]
class ConvergenceHook:
    """Composable convergence detector with optional status migration.

    Combines one or more :class:`_ConvergenceCriterion` instances with AND
    semantics: a sample counts as converged only when **every** criterion
    is satisfied.

    When both ``source_status`` and ``target_status`` are provided, the
    hook also performs status migration — converged samples whose
    ``batch.status`` equals ``source_status`` are reassigned
    ``target_status``. This enables the single-loop execution strategy
    used by :class:`FusedStage`. With both left as ``None`` the hook is a
    plain convergence detector: call :meth:`evaluate` directly or let
    :class:`BaseDynamics` use it via ``_check_convergence``.

    Attributes
    ----------
    criteria : list[_ConvergenceCriterion]
        The individual convergence criteria.
    frequency : int
        Execute every N steps.
    stage : HookStageEnum
        The stage at which this hook fires (``AFTER_STEP``).
    source_status : int | None
        Status code of samples to check for convergence. ``None``
        disables status migration.
    target_status : int | None
        Status code to assign to converged samples. ``None`` disables
        status migration.

    Examples
    --------
    >>> # Backward-compatible fmax-only hook
    >>> hook = ConvergenceHook.from_fmax(0.05)
    >>> converged = hook.evaluate(batch)
    >>> # Multi-criteria hook for FusedStage with status migration
    >>> hook = ConvergenceHook(
    ...     criteria=[
    ...         {"key": "fmax", "threshold": 0.05},
    ...         {"key": "energy_change", "threshold": 1e-6},
    ...     ],
    ...     source_status=0,
    ...     target_status=1,
    ... )
    """

    def __init__(
        self,
        criteria: (
            _ConvergenceCriterion
            | list[_ConvergenceCriterion]
            | dict
            | list[dict]
            | None
        ) = None,
        source_status: int | None = None,
        target_status: int | None = None,
        frequency: int = 1,
    ) -> None:
        """Initialize the convergence hook.

        Parameters
        ----------
        criteria : _ConvergenceCriterion | list[...] | dict | list[dict] | None
            Convergence criterion specification(s). Dicts are validated and
            converted to ``_ConvergenceCriterion`` instances. ``None``
            defaults to a single fmax criterion with threshold ``0.05``.
        source_status : int | None, optional
            Status code to check. ``None`` disables status migration.
        target_status : int | None, optional
            Status code to assign on convergence. ``None`` disables
            status migration.
        frequency : int, optional
            Execute every N steps. Default 1.
        """
        self.frequency = frequency
        self.stage = HookStageEnum.AFTER_STEP
        self.source_status = source_status
        self.target_status = target_status
        # Normalize the accepted input shapes down to a list, then coerce
        # each element individually.
        if criteria is None:
            criteria = [_ConvergenceCriterion(key="fmax", threshold=0.05)]
        elif isinstance(criteria, _ConvergenceCriterion):
            criteria = [criteria]
        elif isinstance(criteria, dict):
            criteria = [_ConvergenceCriterion(**criteria)]
        elif not isinstance(criteria, (list, tuple)):
            raise TypeError(
                "criteria must be a dict, _ConvergenceCriterion, or list"
                f" thereof, got {type(criteria).__name__}"
            )
        self.criteria: list[_ConvergenceCriterion] = [
            self._coerce_criterion(item) for item in criteria
        ]

    @staticmethod
    def _coerce_criterion(item: object) -> _ConvergenceCriterion:
        """Convert a dict spec to ``_ConvergenceCriterion``; pass instances through."""
        if isinstance(item, dict):
            return _ConvergenceCriterion(**item)
        if isinstance(item, _ConvergenceCriterion):
            return item
        raise TypeError(
            "Each criterion must be a dict or"
            f" _ConvergenceCriterion, got {type(item).__name__}"
        )

    def __repr__(self) -> str:
        """Return a human-readable summary of the convergence hook."""
        fields = [f"criteria=[{', '.join(repr(c) for c in self.criteria)}]"]
        for label, value in (
            ("source_status", self.source_status),
            ("target_status", self.target_status),
        ):
            if value is not None:
                fields.append(f"{label}={value}")
        fields.append(f"frequency={self.frequency}")
        return f"ConvergenceHook({', '.join(fields)})"

    @classmethod
    def from_fmax(
        cls,
        threshold: float = 0.05,
        source_status: int | None = None,
        target_status: int | None = None,
        frequency: int = 1,
    ) -> ConvergenceHook:
        """Create a single fmax-based convergence hook.

        Convenience constructor for backward compatibility with the
        original ``convergence_fmax`` parameter.

        Parameters
        ----------
        threshold : float, optional
            Maximum force threshold. Default ``0.05``.
        source_status : int | None, optional
            Status code to check. ``None`` disables status migration.
        target_status : int | None, optional
            Status code to assign on convergence. ``None`` disables
            status migration.
        frequency : int, optional
            Execute every N steps. Default 1.

        Returns
        -------
        ConvergenceHook
            Hook with a single ``fmax`` criterion.
        """
        return cls(
            criteria=[{"key": "fmax", "threshold": threshold}],
            source_status=source_status,
            target_status=target_status,
            frequency=frequency,
        )

    @classmethod
    def from_forces(
        cls,
        threshold: float,
        frequency: int = 1,
        source_status: int | None = None,
        target_status: int | None = None,
    ) -> ConvergenceHook:
        """Construct from a force-norm threshold ('forces' key, norm reduction).

        Parameters
        ----------
        threshold : float
            fmax threshold; systems with max force norm <= threshold are
            converged.
        frequency : int, optional
            Evaluate every N steps. Default 1.
        source_status : int | None, optional
            Status code that eligible systems must have. Default None
            (any status).
        target_status : int | None, optional
            Status code to assign to converged systems. Default None
            (no status change).

        Returns
        -------
        ConvergenceHook
            Hook that evaluates max per-atom force norm against ``threshold``.
        """
        spec = {
            "key": "forces",
            "threshold": threshold,
            "reduce_op": "norm",
            "reduce_dims": -1,
        }
        return cls(
            criteria=[spec],
            frequency=frequency,
            source_status=source_status,
            target_status=target_status,
        )

    @property
    def num_criteria(self) -> int:
        """Return the number of individual criteria."""
        return len(self.criteria)

    def evaluate(self, batch: Batch) -> torch.Tensor | None:
        """Evaluate all criteria and return indices of converged samples.

        AND-reduces the per-graph boolean result of every criterion and
        returns the integer indices of samples satisfying all of them.

        Parameters
        ----------
        batch : Batch
            The current batch of atomic data.

        Returns
        -------
        torch.Tensor | None
            1-D integer tensor of converged sample indices, or ``None``
            if no samples satisfy all criteria.
        """
        converged = torch.ones(
            batch.num_graphs,
            dtype=torch.bool,
            device=batch.positions.device,
        )
        # Every criterion is evaluated (no short-circuit) and folded in.
        for criterion in self.criteria:
            converged = converged & criterion(batch)
        if not converged.any():
            return None
        return converged.nonzero(as_tuple=True)[0]

    def __call__(self, batch: Batch, dynamics: BaseDynamics) -> None:
        """Evaluate convergence and optionally migrate sample status.

        When ``source_status`` and ``target_status`` are both set,
        converged samples whose ``batch.status`` matches ``source_status``
        are migrated to ``target_status``. If ``batch`` lacks ``status``
        when migration is configured, the migration step is silently
        skipped.

        Parameters
        ----------
        batch : Batch
            The current batch, modified in-place.
        dynamics : BaseDynamics
            The dynamics engine (unused).
        """
        converged = self.evaluate(batch)
        if converged is None:
            return
        if self.source_status is None or self.target_status is None:
            return
        status = getattr(batch, "status", None)
        if status is None:
            return
        flat = status.squeeze(-1) if status.dim() == 2 else status
        eligible = flat == self.source_status
        hit = torch.zeros(batch.num_graphs, dtype=torch.bool, device=flat.device)
        hit[converged] = True
        migrate = hit & eligible
        if migrate.any():
            target = (
                batch.status.view(-1) if batch.status.dim() == 2 else batch.status
            )
            target[migrate] = self.target_status
[docs]
class FusedStage(BaseDynamics):
"""Composite dynamics engine fusing multiple sub-stages on a single GPU.
``FusedStage`` composes multiple ``BaseDynamics`` sub-stages to share one
``Batch`` and one model forward pass per step, avoiding redundant forward
passes when multiple simulation phases (e.g., relaxation then MD) operate
on the same batch.
Unlike ``BaseDynamics``, **``step(batch)``** is overridden. Instead of the
standard ``pre_update → compute → post_update`` loop, ``FusedStage``
performs: (1) a single ``compute()`` call on the full batch, then (2)
iterates over sub-stages, applying ``masked_update(batch, mask)`` on each
sub-stage's dynamics for samples whose ``batch.status`` matches that
sub-stage's status code. Only ONE forward pass happens per step regardless
of the number of sub-stages. **``run(batch)``** is also overridden —
the ``n_steps`` attribute (inherited from ``BaseDynamics``) and any
``n_steps`` argument passed to ``run()`` are both the **maximum** number
of steps; the loop runs until all samples have migrated to the
``exit_status``, the sampler is exhausted, or ``n_steps`` is reached. Convergence-driven migration is handled
by ``ConvergenceHook`` instances auto-registered between adjacent
sub-stages: when samples converge in sub-stage *i*, their ``batch.status``
is updated to sub-stage *i+1*'s code, causing them to be processed by the
next dynamics on the following step. The ``+`` operator composes
sub-stages: ``dyn_a + dyn_b`` creates a ``FusedStage``, and
``fused + dyn_c`` appends a third sub-stage. The ``|`` operator (inherited
from ``BaseDynamics`` via ``_CommunicationMixin``) creates a ``DistributedPipeline``
for multi-rank execution instead.
Developers generally do NOT subclass ``FusedStage``. Instead, create
``BaseDynamics`` subclasses (integrators) and compose them using ``+``.
``FusedStage`` handles orchestration automatically. The key requirement is
that sub-stage dynamics must implement ``masked_update`` correctly
(inherited from ``BaseDynamics``) and that the batch must have a
``status`` tensor.
Hook Firing Semantics
~~~~~~~~~~~~~~~~~~~~~
Because ``FusedStage`` shares a single forward pass across all sub-stages,
hook firing differs from standalone ``BaseDynamics`` execution. The
following hooks fire **on each sub-stage** during ``_step_impl``:
**Fired on sub-stages (in order):**
- ``BEFORE_STEP`` — at the start of each fused step, before any work.
- ``AFTER_COMPUTE`` — after the shared model forward pass completes.
- ``BEFORE_PRE_UPDATE`` — before each sub-stage's ``masked_update``
(fires even when no samples match the sub-stage's status code).
- ``AFTER_POST_UPDATE`` — after each sub-stage's ``masked_update``
(fires even when no samples match the sub-stage's status code).
- ``AFTER_STEP`` — after all masked updates are complete.
- ``ON_CONVERGE`` — when a sub-stage's ``_check_convergence`` detects
converged samples.
**NOT fired on sub-stages:**
- ``BEFORE_COMPUTE`` — the forward pass is shared across all sub-stages,
not executed per-sub-stage; there is no meaningful "before compute"
point for individual sub-stages.
- ``AFTER_PRE_UPDATE`` — ``masked_update`` combines ``pre_update`` and
``post_update`` atomically; there is no intermediate hook point.
- ``BEFORE_POST_UPDATE`` — same reason as ``AFTER_PRE_UPDATE``.
**Step count semantics:** Each sub-stage's ``step_count`` is incremented
alongside the ``FusedStage``'s own ``step_count`` after every fused step,
ensuring that hook frequency (e.g., ``every_n_steps``) is respected
correctly across all sub-stages.
Parameters
----------
sub_stages : list[tuple[int, BaseDynamics]]
Ordered ``(status_code, dynamics)`` pairs. Status codes are
auto-assigned starting from 0 when using the ``+`` operator.
entry_status : int
Status code assigned to incoming samples (default: 0).
exit_status : int
Status code that triggers graduation to the next pipeline stage.
Auto-set to ``len(sub_stages)`` (one past the last sub-stage code).
compile_step : bool
If ``True``, replace ``self.step`` with
``torch.compile(self.step, **compile_kwargs)``.
compile_kwargs : dict
Keyword arguments forwarded to ``torch.compile``.
**kwargs
Additional keyword arguments forwarded to ``BaseDynamics``.
Attributes
----------
sub_stages : list[tuple[int, BaseDynamics]]
Ordered ``(status_code, dynamics)`` pairs.
entry_status : int
Status code for incoming samples.
exit_status : int
Status code that triggers graduation.
compile_step : bool
Whether the step method is compiled.
compile_kwargs : dict
Arguments passed to ``torch.compile``.
__needs_keys__ : set[str]
Union of all sub-stage ``__needs_keys__`` sets. Populated
automatically during ``__init__``.
__provides_keys__ : set[str]
Union of all sub-stage ``__provides_keys__`` sets. Populated
automatically during ``__init__``.
Examples
--------
>>> from nvalchemi.dynamics import FusedStage, BaseDynamics
>>> dynamics0 = BaseDynamics(model=model)
>>> dynamics1 = BaseDynamics(model=model)
>>> fused = FusedStage(sub_stages=[(0, dynamics0), (1, dynamics1)])
>>> fused.exit_status
2
"""
[docs]
def __init__(
    self,
    sub_stages: list[tuple[int, BaseDynamics]],
    *,
    entry_status: int = 0,
    exit_status: int = -1,
    compile_step: bool = False,
    compile_kwargs: dict[str, Any] | None = None,
    init_fn: Callable[[Batch], None] | None = None,
    **kwargs: Any,
) -> None:
    """Initialize the fused stage.

    Parameters
    ----------
    sub_stages : list[tuple[int, BaseDynamics]]
        Ordered ``(status_code, dynamics)`` pairs. Must be non-empty; the
        first sub-stage's model becomes the shared model.
    entry_status : int, optional
        Status code assigned to incoming samples. Default 0.
    exit_status : int, optional
        Status code that triggers graduation. ``-1`` is a sentinel meaning
        "auto-set to ``len(sub_stages)``". Default -1.
    compile_step : bool, optional
        If ``True``, compile the step method with ``torch.compile``
        immediately at the end of construction. Default ``False``.
    compile_kwargs : dict[str, Any] | None, optional
        Keyword arguments for ``torch.compile``. Default ``None``.
    init_fn : Callable[[Batch], None] | None, optional
        Optional callback invoked on the initial batch immediately after
        ``sampler.build_initial_batch()`` returns, before the first step.
        Use this to populate fields that the sampler does not set, such
        as ``velocities`` or ``forces``. Only called in Mode 2 (inflight
        batching with ``batch=None``). Default ``None``.
    **kwargs : Any
        Additional keyword arguments forwarded to ``BaseDynamics``.

    Raises
    ------
    ValueError
        If sub-stages have different ``device_type`` values.
    """
    # The first sub-stage's model is shared by the whole fused stage:
    # _step_impl performs one forward pass with it per step.
    first_dynamics = sub_stages[0][1]
    model = first_dynamics.model
    # All sub-stages must live on the same device type — the fused stage
    # shares one batch and one forward pass.
    device_types = {dyn.device_type for _, dyn in sub_stages}
    if len(device_types) > 1:
        per_stage = {code: dyn.device_type for code, dyn in sub_stages}
        raise ValueError(
            f"All sub-stages in a FusedStage must share the same "
            f"device_type, but got: {per_stage}. A FusedStage runs "
            f"on a single device with a shared batch and forward pass."
        )
    super().__init__(model=model, **kwargs)
    self.sub_stages = sub_stages
    # Aggregate the data-dependency declarations of all sub-stages so
    # pipeline plumbing sees the union of required/produced batch keys.
    self.__needs_keys__ = set().union(
        *(dyn.__needs_keys__ for _, dyn in sub_stages)
    )
    self.__provides_keys__ = set().union(
        *(dyn.__provides_keys__ for _, dyn in sub_stages)
    )
    self.entry_status = entry_status
    self.compile_kwargs: dict[str, Any] = (
        compile_kwargs if compile_kwargs is not None else {}
    )
    self.compile_step = compile_step
    # Populated by compile(); step() dispatches to it when set.
    self._compiled_step: (
        Callable[[Batch], tuple[Batch, torch.Tensor | None]] | None
    ) = None
    # -1 sentinel: exit status is one past the last sub-stage code.
    if exit_status == -1:
        self.exit_status = len(self.sub_stages)
    else:
        self.exit_status = exit_status
    self.convergence_check_frequency: int = 1
    self.init_fn = init_fn
    # Hooks that fire on the full batch at the FusedStage level
    # (see register_fused_hook / _call_fused_hooks).
    self.fused_hooks: dict[HookStageEnum, list[Hook]] = defaultdict(list)
    # Auto-register a migration ConvergenceHook between each pair of
    # adjacent sub-stages: convergence in stage i moves samples to the
    # status code of stage i+1.
    for i in range(len(self.sub_stages) - 1):
        source_code, source_dynamics = self.sub_stages[i]
        target_code, _ = self.sub_stages[i + 1]
        # Remove duplicate migration hooks with the same
        # (source_status, target_status) to prevent double-fire after
        # __add__ reconstruction.
        existing = source_dynamics.hooks[HookStageEnum.AFTER_STEP]
        source_dynamics.hooks[HookStageEnum.AFTER_STEP] = [
            h
            for h in existing
            if not (
                isinstance(h, ConvergenceHook)
                and hasattr(h, "source_status")
                and h.source_status == source_code
                and hasattr(h, "target_status")
                and h.target_status == target_code
            )
        ]
        # Reuse the sub-stage's own convergence criteria when available;
        # otherwise ConvergenceHook falls back to its default criterion.
        criteria = None
        if source_dynamics.convergence_hook is not None:
            criteria = source_dynamics.convergence_hook.criteria
        hook = ConvergenceHook(
            criteria=criteria,
            source_status=source_code,
            target_status=target_code,
        )
        source_dynamics.register_hook(hook)
    # Sub-stages with a per-system step budget (n_steps set) get a
    # per-graph counter registered in the batch bookkeeping; the counter
    # drives automatic migration in _step_impl.
    for status_code, dynamics in self.sub_stages:
        if dynamics.n_steps is not None:
            counter_key = f"n_steps_counter_{status_code}"
            BaseDynamics.register_bookkeeping_key(
                counter_key,
                lambda n, dev: torch.zeros(n, 1, dtype=torch.long, device=dev),
            )
    if self.compile_step:
        self.compile()
def __repr__(self) -> str:
    """Return a human-readable summary of the fused stage."""
    inner = ", ".join(
        f"{code}:{type(dyn).__name__}" for code, dyn in self.sub_stages
    )
    return (
        "FusedStage("
        f"sub_stages=[{inner}], "
        f"entry_status={self.entry_status}, "
        f"exit_status={self.exit_status}, "
        f"compiled={self._compiled_step is not None}, "
        f"step_count={self.step_count})"
    )
def compile(self, **kwargs: Any) -> FusedStage:
    """Compile the fused step with ``torch.compile``.

    Merges *kwargs* with any ``compile_kwargs`` stored at construction
    time (values passed here take precedence), then wraps ``_step_impl``
    with ``torch.compile``. Also sets ``compile_step = True`` so the
    ``step`` dispatch path uses the compiled callable.

    Calling this again (e.g. with different kwargs) triggers a fresh
    compilation.

    Parameters
    ----------
    **kwargs : Any
        Keyword arguments forwarded to ``torch.compile``. Merged with
        ``compile_kwargs`` from ``__init__``; values here win.

    Returns
    -------
    FusedStage
        This instance, enabling fluent chaining such as
        ``fused.compile(fullgraph=True).run(batch)``.
    """
    self.compile_kwargs = {**self.compile_kwargs, **kwargs}
    self.compile_step = True
    self._compiled_step = torch.compile(self._step_impl, **self.compile_kwargs)
    return self
def __enter__(self) -> FusedStage:
    """Enter the stream context and propagate it to all sub-stages.

    Delegates to the parent ``__enter__`` to create and enter a CUDA
    stream context, then points every sub-stage's ``_stream`` at the
    same stream so all computation runs on one dedicated stream.

    If ``compile_step`` is ``True`` but no compiled step exists yet
    (e.g. the stage was created via ``__add__``), compilation happens
    here automatically.

    Returns
    -------
    FusedStage
        This instance.
    """
    super().__enter__()
    shared_stream = self._stream
    for _, dyn in self.sub_stages:
        dyn._stream = shared_stream
    if self.compile_step and self._compiled_step is None:
        self.compile()
    return self
def __exit__(
    self,
    exc_type: type[BaseException] | None,
    exc_val: BaseException | None,
    exc_tb: Any,
) -> None:
    """Exit the stream context and clear sub-stage stream references.

    Detaches every sub-stage from the shared stream, then delegates to
    the parent ``__exit__`` to exit the ``StreamContext`` and clean up.

    Parameters
    ----------
    exc_type : type[BaseException] | None
        Exception type, if any.
    exc_val : BaseException | None
        Exception value, if any.
    exc_tb : Any
        Exception traceback, if any.
    """
    for _, dyn in self.sub_stages:
        dyn._stream = None
    super().__exit__(exc_type, exc_val, exc_tb)
def _sync_state_to_batch(
    self,
    remaining_indices: "torch.Tensor",
    n_new: int,
    template_batch: Batch,
) -> None:
    """Fan out state sync to every sub-stage.

    ``FusedStage`` holds no ``_state`` of its own; each sub-stage does.
    This override delegates to all of them so that inflight batch
    refills (via :meth:`~BaseDynamics._refill_check`) keep each
    sub-stage's ``_state`` aligned with the new batch composition.

    Parameters
    ----------
    remaining_indices : torch.Tensor
        Integer indices of systems that remain after graduation.
    n_new : int
        Number of newly admitted replacement systems.
    template_batch : Batch
        The updated batch; provides device/dtype for new-state init.
    """
    for _, stage in self.sub_stages:
        stage._sync_state_to_batch(remaining_indices, n_new, template_batch)
def _ensure_bookkeeping_fields(self, batch: Batch) -> None:
    """Auto-initialize registered bookkeeping fields that are absent.

    Parameters
    ----------
    batch : Batch
        The batch to check and initialize fields on.
    """
    for name, factory in self._bookkeeping_keys.items():
        if getattr(batch, name, None) is None:
            batch[name] = factory(batch.num_graphs, batch.device)
def register_fused_hook(self, hook: Hook) -> None:
    """Register a hook that fires at the FusedStage level on the full batch.

    Unlike hooks registered on individual sub-stages (which only receive
    the sub-batched view), fused hooks observe the complete batch. During
    ``_step_impl`` fused hooks are invoked only at the ``BEFORE_STEP``
    and ``AFTER_STEP`` stages of every fused step.

    Parameters
    ----------
    hook : Hook
        The hook to register. Only ``BEFORE_STEP`` and ``AFTER_STEP``
        stages are meaningful at the fused level; hooks registered for
        other stages are silently accepted but will not fire during
        normal execution.

    Raises
    ------
    ValueError
        If ``hook.frequency`` is not a positive integer.
    """
    if not isinstance(hook.frequency, int) or hook.frequency < 1:
        raise ValueError(
            f"Hook {hook!r} has frequency={hook.frequency!r}. "
            "frequency must be a positive integer (>= 1)."
        )
    self.fused_hooks[hook.stage].append(hook)
def _call_fused_hooks(self, stage: HookStageEnum, batch: Batch) -> None:
    """Invoke all fused hooks registered for the given stage.

    A hook fires only when the fused ``step_count`` is a multiple of
    its ``frequency``.

    Parameters
    ----------
    stage : HookStageEnum
        The hook stage to fire.
    batch : Batch
        The current full batch.
    """
    count = self.step_count
    for hook in self.fused_hooks[stage]:
        if count % hook.frequency:
            continue
        hook(batch, self)
def _step_impl(self, batch: Batch) -> tuple[Batch, torch.Tensor | None]:
    """Internal step implementation (may be compiled).

    Performs the following sequence:

    1. Fire BEFORE_STEP hooks (fused, self, sub-stages).
    2. For each sub-stage: fire BEFORE_PRE_UPDATE, run pre_update
       (positions advance to r(t+dt)).
    3. Fire BEFORE_COMPUTE hooks (fused / self, e.g. NeighborListHook
       with new positions) → single shared forward pass → AFTER_COMPUTE.
    4. For each sub-stage: run post_update (final velocity kick at
       r(t+dt) forces), fire AFTER_POST_UPDATE.
    5. Fire AFTER_STEP hooks on each sub-stage.
    6. Advance per-system step-budget counters and migrate systems whose
       budget (sub-stage ``n_steps``) is exhausted.
    7. Snapshot status, check convergence per sub-stage and fire
       ON_CONVERGE if triggered.
    8. Increment step_count for FusedStage and all sub-stages.
    9. Identify samples that newly graduated during this step.

    Parameters
    ----------
    batch : Batch
        The batch with a ``status`` field.

    Returns
    -------
    tuple[Batch, torch.Tensor | None]
        The updated batch, and a 1-D integer tensor of sample indices
        that newly graduated (reached ``exit_status``) during this step,
        or ``None`` if no samples graduated.
    """
    self._ensure_bookkeeping_fields(batch)
    self._call_fused_hooks(HookStageEnum.BEFORE_STEP, batch)
    self._call_hooks(HookStageEnum.BEFORE_STEP, batch)
    for _, dynamics in self.sub_stages:
        dynamics._call_hooks(HookStageEnum.BEFORE_STEP, batch)
    # Phase 1 — pre_update for each sub-stage.
    # This moves positions to r(t+dt) so that the shared compute can
    # evaluate forces at the correct (updated) positions.
    # NOTE: `status` is snapshotted once here and reused for phase 3, so
    # migrations during the step do not change which sub-stage processes
    # a sample within this step.
    status = batch.status
    if status.dim() == 2:
        status = status.squeeze(-1)
    for status_code, dynamics in self.sub_stages:
        mask = status == status_code
        # BEFORE_PRE_UPDATE fires even when no samples match this code.
        dynamics._call_hooks(HookStageEnum.BEFORE_PRE_UPDATE, batch)
        if mask.any():
            dynamics._masked_pre_update(batch, mask)
    # Phase 2 — shared forward pass at the updated positions.
    self._call_hooks(HookStageEnum.BEFORE_COMPUTE, batch)
    outputs: ModelOutputs = self.compute(batch)
    # TODO: update this when `batch` structure is done
    # Forces/energies are handled by the integrators; copy everything
    # else onto the batch so hooks and criteria can read it.
    for key, tensor in outputs.items():
        if key not in ("forces", "energies"):
            batch[key] = tensor
    self._call_hooks(HookStageEnum.AFTER_COMPUTE, batch)
    for _, dynamics in self.sub_stages:
        dynamics._call_hooks(HookStageEnum.AFTER_COMPUTE, batch)
    # Phase 3 — post_update for each sub-stage, now with forces at r(t+dt).
    for status_code, dynamics in self.sub_stages:
        mask = status == status_code
        if mask.any():
            dynamics._masked_post_update(batch, mask)
        # AFTER_POST_UPDATE fires even when no samples matched.
        dynamics._call_hooks(HookStageEnum.AFTER_POST_UPDATE, batch)
    for _, dynamics in self.sub_stages:
        dynamics._call_hooks(HookStageEnum.AFTER_STEP, batch)
    self._call_hooks(HookStageEnum.AFTER_STEP, batch)
    self._call_fused_hooks(HookStageEnum.AFTER_STEP, batch)
    # Per-system step budgets: sub-stages with n_steps set migrate their
    # active systems to the next status code once the counter reaches
    # the budget.
    for i, (status_code, dynamics) in enumerate(self.sub_stages):
        if dynamics.n_steps is None:
            continue
        counter_key = f"n_steps_counter_{status_code}"
        if getattr(batch, counter_key, None) is None:
            batch[counter_key] = torch.zeros(
                batch.num_graphs, 1, dtype=torch.long, device=batch.device
            )
        counter = getattr(batch, counter_key)
        # Re-read status: AFTER_STEP hooks (e.g. ConvergenceHook) may
        # already have migrated samples this step.
        cur_status = (
            batch.status.squeeze(-1) if batch.status.dim() == 2 else batch.status
        )
        active = cur_status == status_code
        counter[active] += 1
        next_status = (
            self.sub_stages[i + 1][0]
            if i + 1 < len(self.sub_stages)
            else self.exit_status
        )
        migrate = active & (counter.squeeze(-1) >= dynamics.n_steps)
        if migrate.any():
            batch.status.view(-1)[migrate] = next_status
            counter[migrate] = 0  # Reset for next system in this slot
    # Snapshot status before the convergence checks so that newly
    # graduated samples can be identified by comparison afterwards.
    pre_converge_status = batch.status.clone()
    if pre_converge_status.dim() == 2:
        pre_converge_status = pre_converge_status.squeeze(-1)
    for _, dynamics in self.sub_stages:
        converged = dynamics._check_convergence(batch)
        dynamics._last_converged = converged
        if converged is not None:
            dynamics._call_hooks(HookStageEnum.ON_CONVERGE, batch)
    # Keep every sub-stage's step_count in lockstep with the fused one
    # so hook frequencies are respected everywhere.
    self.step_count += 1
    for _, dynamics in self.sub_stages:
        dynamics.step_count += 1
    post_status = batch.status
    if post_status.dim() == 2:
        post_status = post_status.squeeze(-1)
    newly_graduated = (pre_converge_status < self.exit_status) & (
        post_status >= self.exit_status
    )
    exit_converged: torch.Tensor | None = (
        torch.where(newly_graduated)[0] if newly_graduated.any() else None
    )
    return batch, exit_converged
def step(self, batch: Batch) -> tuple[Batch, torch.Tensor | None]:
    """Execute one fused step: single forward pass + masked updates.

    Dispatches to the ``torch.compile``-wrapped implementation when one
    exists (see :meth:`compile`), otherwise to the eager ``_step_impl``.

    Parameters
    ----------
    batch : Batch
        The batch with a ``status`` field.

    Returns
    -------
    tuple[Batch, torch.Tensor | None]
        The updated batch, and a 1-D integer tensor of sample indices
        that newly graduated (reached ``exit_status``) during this step,
        or ``None`` if no samples graduated.
    """
    impl = self._compiled_step if self._compiled_step is not None else self._step_impl
    return impl(batch)
def __call__(self, batch: Batch) -> tuple[Batch, torch.Tensor | None]:
    """Delegate to :meth:`step`, making the stage directly callable.

    Parameters
    ----------
    batch : Batch
        The batch with a ``status`` field.

    Returns
    -------
    tuple[Batch, torch.Tensor | None]
        Whatever :meth:`step` returns: the updated batch and the indices
        of newly graduated samples (or ``None``).
    """
    return self.step(batch)
@staticmethod
def all_complete(batch: Batch, exit_status: int) -> bool:
    """Check whether every sample has reached the exit status.

    Parameters
    ----------
    batch : Batch
        The current batch.
    exit_status : int
        The status code that indicates completion.

    Returns
    -------
    bool
        ``True`` if every sample has ``status == exit_status``;
        ``False`` when the batch has no ``status`` field at all.
    """
    status = getattr(batch, "status", None)
    if status is None:
        return False
    flat = status.squeeze(-1) if status.dim() == 2 else status
    return bool((flat == exit_status).all())
def _open_hooks(self) -> None:
    """Enter context-manager hooks on this stage, fused hooks, and sub-stages.

    Each distinct fused hook object is entered at most once, even when
    registered under several stages.
    """
    super()._open_hooks()
    entered: set[int] = set()
    for bucket in self.fused_hooks.values():
        for hook in bucket:
            key = id(hook)
            if key in entered or not hasattr(hook, "__enter__"):
                continue
            entered.add(key)
            hook.__enter__()
    for _, dyn in self.sub_stages:
        dyn._open_hooks()
def _close_hooks(self) -> None:
    """Exit context-manager hooks on this stage, fused hooks, and sub-stages.

    Each distinct fused hook object is closed at most once; hooks without
    ``__exit__`` fall back to a ``close()`` method when present.
    """
    super()._close_hooks()
    closed: set[int] = set()
    for bucket in self.fused_hooks.values():
        for hook in bucket:
            key = id(hook)
            if key in closed:
                continue
            closed.add(key)
            exit_fn = getattr(hook, "__exit__", None)
            if exit_fn is not None:
                exit_fn(None, None, None)
            elif hasattr(hook, "close"):
                hook.close()
    for _, dyn in self.sub_stages:
        dyn._close_hooks()
def run(
    self, batch: Batch | None = None, n_steps: int | None = None
) -> Batch | None:
    """Run the fused stage until all samples converge or the sampler is exhausted.

    Supports two modes of execution:

    **Mode 1 (external batch loop):** When ``batch`` is provided,
    runs the dynamics until ``all_complete`` or until ``n_steps``
    have been executed (whichever comes first).

    **Mode 2 (inflight batching):** When ``batch is None`` and a
    sampler is configured, builds the initial batch from the
    sampler and replaces graduated samples every
    ``refill_frequency`` steps.

    .. note::
        In Mode 2, ``refill_check`` replaces graduated samples by
        extracting remaining graphs via :meth:`Batch.index_select`,
        requesting replacements from the sampler, and appending them
        via :meth:`Batch.append`. This produces a **new** ``Batch``
        object; the ``batch = result`` reassignment in the loop body
        updates the local reference. ``None`` is returned when the
        sampler is exhausted and no active samples remain, which
        triggers termination.

    Parameters
    ----------
    batch : Batch | None, optional
        The initial batch. If ``None``, uses the sampler to build one.
    n_steps : int | None, optional
        Maximum number of steps to run. When ``None``, falls back to
        ``self.n_steps``. When both are ``None``, the loop runs until
        ``all_complete`` (Mode 1) or sampler exhaustion (Mode 2).
        Sub-stages that have no exit criterion (e.g. a plain MD stage)
        will loop forever without a step limit, so always pass
        ``n_steps`` when such a stage is the final sub-stage.
        Note: sub-stages with ``n_steps`` set use that value as a
        per-system step budget for automatic migration to the next stage.

    Returns
    -------
    Batch | None
        The batch after all steps, or ``None`` if the sampler was
        exhausted and all samples graduated.

    Raises
    ------
    ValueError
        If ``batch is None`` and no sampler is configured.
    """
    if batch is None:
        if self.sampler is None:
            raise ValueError("No batch provided and no sampler configured.")
        batch = self.sampler.build_initial_batch()
        # init_fn is a Mode-2-only escape hatch to populate fields the
        # sampler does not set (e.g. velocities).
        if self.init_fn is not None:
            self.init_fn(batch)
    self.active_batch = batch
    # Ensure bookkeeping fields are present before the loop begins.
    self._ensure_bookkeeping_fields(batch)
    resolved_steps = n_steps if n_steps is not None else self.n_steps
    self._open_hooks()
    try:
        # Prime forces before the first step so that pre_update can use
        # them. _step_impl now runs pre_update BEFORE compute, so without
        # this initial forward pass the first step would integrate with
        # zero (uninitialised) forces.
        self._call_hooks(HookStageEnum.BEFORE_COMPUTE, batch)
        self.compute(batch)
        self._call_hooks(HookStageEnum.AFTER_COMPUTE, batch)
        step_num = 0
        while True:
            batch, _converged = self.step(batch)
            if (
                self.sampler is not None
                and (step_num + 1) % self.refill_frequency == 0
            ):
                # Mode 2: swap graduated samples for fresh ones.
                result = self.refill_check(batch, self.exit_status)
                if result is None:
                    # Sampler exhausted and nothing left in flight.
                    self.active_batch = None
                    return None
                batch = result
                self.active_batch = batch
            elif (
                self.sampler is None
                and (step_num + 1) % self.convergence_check_frequency == 0
                and self.all_complete(batch, self.exit_status)
            ):
                # Mode 1: everyone reached exit_status.
                break
            step_num += 1
            if resolved_steps is not None and step_num >= resolved_steps:
                break
        return batch
    finally:
        self._close_hooks()
def __add__(self, other: BaseDynamics) -> FusedStage:
"""Append a sub-stage to this fused stage via ``fused + dyn``.
Parameters
----------
other : BaseDynamics
The dynamics to append to this fused stage.
Returns
-------
FusedStage
A new fused stage with the additional sub-stage appended.
Raises
------
TypeError
If ``other`` is not a ``BaseDynamics`` instance.
Notes
-----
Compilation is deferred when composing via ``+``. If the source
``FusedStage`` had ``compile_step=True``, the returned stage
preserves that intent but does **not** compile eagerly. Call
``.compile()`` explicitly or enter the context manager to trigger
compilation.
"""
if not isinstance(other, BaseDynamics):
raise TypeError(
"Cannot append stage: other must be a BaseDynamics instance. "
f"Got {type(other).__name__} instead."
)
next_code = len(self.sub_stages)
new_sub_stages = list(self.sub_stages) + [(next_code, other)]
new_fused = FusedStage(
sub_stages=new_sub_stages,
entry_status=self.entry_status,
compile_step=False,
compile_kwargs=self.compile_kwargs,
)
# Defer compilation to __enter__ or an explicit .compile() call.
new_fused.compile_step = self.compile_step
return new_fused
class DistributedPipeline:
"""Orchestrates multi-rank pipeline execution.
Maps GPU ranks to pipeline stages and coordinates the distributed
step loop. Each rank executes only its assigned stage.
Parameters
----------
stages : dict[int, BaseDynamics]
Mapping from rank to its assigned pipeline stage.
synchronized : bool
If ``True``, insert a global ``dist.barrier()`` across **all**
pipeline ranks after every ``step()`` call, forcing every rank to
complete its current step before any rank proceeds to the next one.
This is primarily useful for debugging ordering or deadlock issues
because it eliminates all inter-rank timing skew.
.. note::
This is distinct from the per-stage ``comm_mode`` parameter on
``_CommunicationMixin``, which controls the blocking behavior
of *pairwise* ``isend``/``irecv`` between adjacent stages.
``synchronized`` enforces a *global* synchronization point
across the entire pipeline and will significantly reduce
throughput; it should be disabled (``False``) in production.
Attributes
----------
stages : dict[int, BaseDynamics]
Rank-to-stage mapping.
synchronized : bool
Whether a global ``dist.barrier()`` is inserted after every step.
_dist_initialized : bool
Whether this DistributedPipeline instance initialized the distributed
process group (used to determine cleanup responsibility).
Examples
--------
>>> # Context manager usage (recommended):
>>> pipeline = DistributedPipeline(stages={0: opt_stage, 1: md_stage})
>>> with pipeline:
... pipeline.run()
...
>>> # Manual usage:
>>> pipeline = DistributedPipeline(stages={0: opt_stage, 1: md_stage})
>>> pipeline.init_distributed()
>>> pipeline.setup()
>>> pipeline.run()
>>> pipeline.cleanup()
>>> # Composing multiple pipelines together
>>> full_pipeline = pipe1 | pipe2 | pipe3
>>> with full_pipeline:
... full_pipeline.run()
...
"""
def __init__(
self,
stages: dict[int, BaseDynamics],
synchronized: bool = False,
debug_mode: bool = False,
**dist_kwargs: Any,
) -> None:
"""Initialize the pipeline.
Parameters
----------
stages : dict[int, BaseDynamics]
Mapping from global rank to pipeline stage.
synchronized : bool, optional
If ``True``, insert a global ``dist.barrier()`` across all
pipeline ranks after every step, preventing any rank from
advancing until all ranks have completed the current step.
Useful for debugging but significantly reduces throughput.
See the class-level docstring for how this differs from the
per-stage ``comm_mode``. Default ``False``.
debug_mode : bool, optional
When ``True``, emit detailed ``loguru.debug`` diagnostics
for inter-rank communication and pipeline orchestration.
Propagated to all stages during ``setup()``. Default ``False``.
**dist_kwargs : Any
Additional keyword arguments for ``torch.distributed.init_process_group``.
"""
dist_kwargs.setdefault(
"backend", "nccl" if dist.is_nccl_available() else "gloo"
)
self.stages = stages
self.synchronized = synchronized
self._dist_initialized: bool = False
self._dist_kwargs = dist_kwargs
self._done_tensor: torch.Tensor | None = None
self.debug_mode = debug_mode
self._templates_shared: bool = False
def __or__(self, other: BaseDynamics | DistributedPipeline) -> DistributedPipeline:
"""Append a stage or merge another pipeline via the ``|`` operator.
Supports three composition patterns::
pipeline | dynamics # append one stage
pipeline | pipeline # merge two pipelines
dyn1 | dyn2 | dyn3 | ... # left-associative chaining
Ranks in the resulting pipeline are renumbered to form a
contiguous sequence. Source/sink dependencies (``prior_rank`` /
``next_rank``) are wired when ``setup()`` is called (e.g. via the
context manager or ``run()``).
Parameters
----------
other : BaseDynamics | DistributedPipeline
A single dynamics stage to append, or another pipeline
whose stages will be absorbed (renumbered) after this
pipeline's stages.
Returns
-------
DistributedPipeline
A new pipeline containing all stages with stages mapped to
contiguous ranks.
Raises
------
TypeError
If ``other`` is not a ``BaseDynamics`` or
``DistributedPipeline`` instance.
RuntimeError
If ``torch.distributed`` is already initialized. Pipeline
composition must occur before entering the pipeline context
or calling ``run()``.
"""
if dist.is_initialized():
raise RuntimeError(
"Cannot compose pipelines after torch.distributed has been "
"initialized. Build the full pipeline topology before "
"entering the pipeline context or calling run()."
)
if isinstance(other, DistributedPipeline):
base_rank = max(self.stages.keys()) + 1
new_stages = {**self.stages}
for i, rank in enumerate(sorted(other.stages.keys())):
new_stages[base_rank + i] = other.stages[rank]
pipeline = DistributedPipeline(
stages=new_stages,
synchronized=self.synchronized or other.synchronized,
**self._dist_kwargs,
)
elif isinstance(other, BaseDynamics):
next_rank = max(self.stages.keys()) + 1
new_stages = {**self.stages, next_rank: other}
pipeline = DistributedPipeline(
stages=new_stages,
synchronized=self.synchronized,
**self._dist_kwargs,
)
else:
raise TypeError(
f"Right operand of | must be a BaseDynamics or "
f"DistributedPipeline instance, got {type(other).__name__}."
)
return pipeline
def _validate_world_size(self) -> None:
"""Validate that the distributed world size matches the expected ranks.
Compares ``torch.distributed.get_world_size()`` against the number
of configured pipeline stages. A mismatch indicates that the
``torchrun`` launch configuration does not match the pipeline
topology.
This method is a no-op if ``torch.distributed`` is not initialized
(e.g., during local testing).
Raises
------
RuntimeError
If the world size does not match the number of configured
pipeline stages.
"""
if not dist.is_initialized():
return
world_size = dist.get_world_size()
expected = len(self.stages)
if world_size != expected:
raise RuntimeError(
f"DistributedPipeline expects {expected} ranks (stages configured "
f"for ranks {sorted(self.stages.keys())}), but "
f"torch.distributed world_size is {world_size}. "
)
def setup(self) -> None:
"""Wire up ``prior_rank`` / ``next_rank`` between adjacent stages.
Sorts stages by rank and connects each stage to its
predecessor and successor.
Raises
------
ValueError
If fewer than 2 stages are provided, or if adjacent stages
have mismatched buffer configurations.
RuntimeError
If the world size does not match the number of configured
pipeline stages.
"""
sorted_ranks = sorted(self.stages.keys())
if len(sorted_ranks) < 2:
raise ValueError("Pipeline requires at least 2 stages.")
self._validate_world_size()
for i, rank in enumerate(sorted_ranks):
stage = self.stages[rank]
if stage.prior_rank == -1:
stage.prior_rank = sorted_ranks[i - 1] if i > 0 else None
if stage.next_rank == -1:
stage.next_rank = (
sorted_ranks[i + 1] if i < len(sorted_ranks) - 1 else None
)
for i in range(len(sorted_ranks) - 1):
rank = sorted_ranks[i]
next_rank = sorted_ranks[i + 1]
sender = self.stages[rank]
receiver = self.stages[next_rank]
s_cfg = getattr(sender, "buffer_config", None)
r_cfg = getattr(receiver, "buffer_config", None)
if s_cfg is None or r_cfg is None:
raise ValueError(
"All stages in a DistributedPipeline must have buffer_config set. "
f"Stage on rank {rank} has buffer_config={s_cfg}, "
f"stage on rank {next_rank} has buffer_config={r_cfg}."
)
if s_cfg != r_cfg:
raise ValueError(
f"Buffer configuration mismatch between rank {rank} "
f"and rank {next_rank}: sender has "
f"BufferConfig(num_systems={s_cfg.num_systems}, "
f"num_nodes={s_cfg.num_nodes}, num_edges={s_cfg.num_edges}), "
f"receiver has "
f"BufferConfig(num_systems={r_cfg.num_systems}, "
f"num_nodes={r_cfg.num_nodes}, num_edges={r_cfg.num_edges}). "
f"Adjacent stages must use identical buffer configurations."
)
n_stages = len(sorted_ranks)
device = self.local_stage.device
self._done_tensor = torch.zeros(n_stages, dtype=torch.int32, device=device)
# move model to device if it isn't there already
model = self.local_stage.model
if not callable(getattr(model, "to", None)):
raise RuntimeError(
"Model expected to possess `to()` method for device"
f" and casting behavior. Passed model is type {type(model)}"
" so ensure class contains this method."
)
else:
self.local_stage.model = model.to(device)
for stage in self.stages.values():
stage.debug_mode = self.debug_mode
def _share_templates(self) -> None:
"""Compute batch schema templates for all stages via local iteration.
Since ``DistributedPipeline`` has all stages in ``self.stages`` on
every rank, templates can be computed locally without inter-rank
communication. Inflight (first) stages build their initial batch
from the sampler and cache an ``empty_like`` template. Downstream
stages derive their template from the upstream stage's cached
template.
This method is idempotent; repeated calls are no-ops once
templates have been computed.
"""
if self._templates_shared:
return
self._templates_shared = True
for rank in sorted(self.stages.keys()):
stage = self.stages[rank]
if stage.is_first_stage and stage.inflight_mode:
if stage.active_batch is None:
stage.active_batch = stage.sampler.build_initial_batch()
if stage.active_batch is not None:
if stage.active_batch.device != stage.device:
stage.active_batch = stage.active_batch.to(stage.device)
stage._recv_template = Batch.empty_like(
stage.active_batch, device=stage.device
)
if self.debug_mode:
logger.debug(
"[rank {}] computed template from inflight sampler",
rank,
)
elif stage.prior_rank is not None:
upstream = self.stages[stage.prior_rank]
if upstream._recv_template is not None:
stage._recv_template = Batch.empty_like(
upstream._recv_template, device=stage.device
)
if self.debug_mode:
logger.debug(
"[rank {}] computed template from upstream rank {}",
rank,
stage.prior_rank,
)
@property
def local_rank(self) -> int:
"""Get the local rank for this process."""
rank = 0
if dist.is_initialized():
rank = dist.get_node_local_rank()
return rank
@property
def global_rank(self) -> int:
"""Get the global rank for this process."""
rank = 0
if dist.is_initialized():
rank = dist.get_rank()
return rank
@property
def local_stage(self) -> BaseDynamics:
"""Get the stage associated with the rank this is executed on."""
return self.stages[self.global_rank]
    def step(self) -> None:
        """Execute one timestep for the local rank's stage.

        The stage (a ``BaseDynamics`` subclass) handles both the dynamics
        step and buffer synchronization.

        Supports two modes for the first stage:

        **Mode 1 (external batch loop):** Standard flow where the first
        stage receives from ``_prestep_sync_buffers`` like other stages.

        **Mode 2 (inflight batching):** When the first stage has
        ``inflight_mode=True`` (i.e., a sampler is configured), it builds
        the initial batch from the sampler and refills graduated samples
        instead of receiving from a prior stage.

        When ``self.synchronized`` is ``True``, a global
        ``dist.barrier()`` is issued at the end of each step so that no
        rank advances until every rank in the pipeline has finished the
        current step.

        Raises
        ------
        RuntimeError
            If ``torch.distributed`` is not initialized, or if the world
            size does not match the number of configured pipeline stages.
        KeyError
            If the current rank is not in the global rank stage mapping.
        """
        if not dist.is_initialized():
            raise RuntimeError(
                "torch.distributed is not initialized. "
                "Call torch.distributed.init_process_group() first."
            )
        rank = self.global_rank
        if rank not in self.stages:
            raise KeyError(f"Rank {rank} is not assigned to any pipeline stage.")
        stage = self.stages[rank]
        stage_type = type(stage).__name__
        if stage.is_first_stage and stage.inflight_mode:
            # --- Mode 2: first stage draws work from its sampler. ---
            n_graphs = stage.active_batch.num_graphs if stage.active_batch else 0
            if self.debug_mode:
                logger.debug(
                    "[rank {}] inflight step begin | stage={} batch_size={}",
                    rank,
                    stage_type,
                    n_graphs,
                )
            # Lazily (re)build the working batch. A RuntimeError from the
            # sampler is treated as exhaustion, not a hard failure.
            if stage.active_batch is None and not stage.done:
                try:
                    stage.active_batch = stage.sampler.build_initial_batch()
                except RuntimeError:
                    stage.active_batch = None
                if stage.active_batch is not None:
                    if stage.active_batch.device != stage.device:
                        stage.active_batch = stage.active_batch.to(stage.device)
                    if self.debug_mode:
                        logger.debug(
                            "[rank {}] built initial batch, {} graphs",
                            rank,
                            stage.active_batch.num_graphs,
                        )
                else:
                    if self.debug_mode:
                        logger.debug(
                            "[rank {}] sampler exhausted at build, marking done", rank
                        )
                    stage.done = True
            # Size the communication buffers from whichever schema source
            # is available: the live batch, or the cached recv template.
            if stage.active_batch is not None:
                stage._ensure_buffers(stage.active_batch)
            elif stage._recv_template is not None:
                stage._ensure_buffers(stage._recv_template)
            if stage.active_batch is not None:
                # Advance dynamics, then push converged samples downstream.
                stage.active_batch, converged_indices = stage.step(stage.active_batch)
                n_conv = (
                    converged_indices.numel() if converged_indices is not None else 0
                )
                if self.debug_mode:
                    logger.debug(
                        "[rank {}] step done | converged={} remaining={}",
                        rank,
                        n_conv,
                        stage.active_batch.num_graphs if stage.active_batch else 0,
                    )
                stage._poststep_sync_buffers(converged_indices)
                # Stages without an explicit exit_status default to 1.
                if hasattr(stage, "exit_status"):
                    exit_status = stage.exit_status
                else:
                    exit_status = 1
                if stage.active_batch is not None:
                    # Top up graduated samples from the sampler; a None
                    # result means the sampler ran dry.
                    result = stage.refill_check(stage.active_batch, exit_status)
                    stage.active_batch = result
                    if result is None:
                        if self.debug_mode:
                            logger.debug(
                                "[rank {}] sampler exhausted, marking done", rank
                            )
                        stage.done = True
                else:
                    # The whole batch graduated during poststep; try to
                    # rebuild a fresh batch from scratch.
                    if self.debug_mode:
                        logger.debug(
                            "[rank {}] active_batch is None after poststep, "
                            "rebuilding from sampler",
                            rank,
                        )
                    try:
                        stage.active_batch = stage.sampler.build_initial_batch()
                    except RuntimeError:
                        stage.active_batch = None
                    if stage.active_batch is not None:
                        if stage.active_batch.device != stage.device:
                            stage.active_batch = stage.active_batch.to(stage.device)
                    else:
                        if self.debug_mode:
                            logger.debug(
                                "[rank {}] sampler exhausted, marking done", rank
                            )
                        stage.done = True
            elif stage.next_rank is not None:
                # No work left but a downstream rank is still expecting a
                # message: send the (empty) buffer so its recv completes.
                if self.debug_mode:
                    logger.debug(
                        "[rank {}] done, sending empty buffer to rank {}",
                        rank,
                        stage.next_rank,
                    )
                stage.send_buffer.isend(dst=stage.next_rank).wait()
        else:
            # --- Mode 1 / downstream stages: receive work from upstream. ---
            n_graphs = stage.active_batch.num_graphs if stage.active_batch else 0
            if self.debug_mode:
                logger.debug(
                    "[rank {}] downstream step begin | stage={} batch_size={}",
                    rank,
                    stage_type,
                    n_graphs,
                )
            if stage.active_batch is not None:
                stage._ensure_buffers(stage.active_batch)
            stage._prestep_sync_buffers()
            stage._complete_pending_recv()
            converged_indices = None
            # Only step when there is actual work in the batch.
            if stage.active_batch is not None and stage.active_batch.num_graphs > 0:
                stage.active_batch, converged_indices = stage.step(stage.active_batch)
                n_conv = (
                    converged_indices.numel() if converged_indices is not None else 0
                )
                if self.debug_mode:
                    logger.debug(
                        "[rank {}] step done | converged={} remaining={}",
                        rank,
                        n_conv,
                        stage.active_batch.num_graphs if stage.active_batch else 0,
                    )
            elif stage.active_batch is not None:
                if self.debug_mode:
                    logger.debug(
                        "[rank {}] skipping step, active_batch has 0 graphs", rank
                    )
            stage._poststep_sync_buffers(converged_indices)
            # Auto-terminate downstream stages: if the upstream is done
            # and this stage has no remaining work, mark it as done.
            n_active = (
                stage.active_batch.num_graphs if stage.active_batch is not None else 0
            )
            # NOTE(review): indexing _done_tensor by prior_rank assumes
            # contiguous ranks starting at 0 (as produced by | composition).
            upstream_done = (
                self._done_tensor is not None
                and stage.prior_rank is not None
                and bool(self._done_tensor[stage.prior_rank])
            )
            if upstream_done and n_active == 0 and not stage.done:
                if self.debug_mode:
                    logger.debug(
                        "[rank {}] upstream rank {} done and no active work, marking done",
                        rank,
                        stage.prior_rank,
                    )
                stage.done = True
        # Optional global lockstep across all ranks (debugging aid).
        if self.synchronized:
            if self.debug_mode:
                logger.debug("[rank {}] waiting at barrier", rank)
            dist.barrier()
def _sync_done_flags(self) -> bool:
"""Synchronize ``done`` flags across all ranks via ``all_reduce``.
Each rank writes its local stage's ``done`` status into the
shared ``_done_tensor`` at its position, then an ``all_reduce``
(``MAX``) broadcasts the flags so every rank sees the global
state.
Returns
-------
bool
``True`` if **all** stages report ``done``.
"""
if self._done_tensor is None:
raise RuntimeError("_done_tensor is not initialized. Call setup() first.")
stage = self.local_stage
self._done_tensor[self.global_rank] = int(stage.done)
if dist.is_initialized():
dist.all_reduce(self._done_tensor, op=dist.ReduceOp.MAX)
all_done = bool(self._done_tensor.all())
if self.debug_mode:
logger.debug(
"[rank {}] done_flags={} all_done={}",
self.global_rank,
self._done_tensor.tolist(),
all_done,
)
return all_done
def run(self) -> None:
"""Run the pipeline loop until all stages report done.
After each ``step()``, an ``all_reduce`` synchronizes the
``done`` flags across all ranks so that every process can
observe the global termination state.
"""
self.setup()
self._share_templates()
iteration = 0
while True:
if self.debug_mode:
logger.debug(
"[rank {}] === pipeline iteration {} ===",
self.global_rank,
iteration,
)
self.step()
if self._sync_done_flags():
if self.debug_mode:
logger.debug("[rank {}] all stages done, exiting", self.global_rank)
break
iteration += 1
def init_distributed(self) -> None:
"""Initialize the ``torch.distributed`` process group.
If ``torch.distributed`` is already initialized, this method is
a no-op. Otherwise, it calls
``torch.distributed.init_process_group(**self._dist_kwargs)``.
The backend and other distributed options are configured via the
constructor's ``**dist_kwargs`` parameter.
Notes
-----
When launching with ``torchrun``, the process group is typically
already initialized. This method provides a convenient fallback
for scripts that do not use ``torchrun``.
"""
if dist.is_initialized():
return
dist.init_process_group(**self._dist_kwargs)
self._dist_initialized = True
def cleanup(self) -> None:
"""Destroy the ``torch.distributed`` process group.
Only destroys the process group if it was initialized by this
``DistributedPipeline`` instance (via :meth:`init_distributed`). If the
process group was externally initialized (e.g., by ``torchrun``),
this method is a no-op.
"""
if self._dist_initialized and dist.is_initialized():
dist.destroy_process_group()
self._dist_initialized = False
    def __enter__(self) -> DistributedPipeline:
        """Enter the pipeline context manager.

        Calls :meth:`init_distributed` and :meth:`setup` in sequence.
        The ``setup()`` call also initializes the distributed
        ``_done_tensor`` used for coordinated termination. Templates
        are then shared locally so ``run()`` can skip that step.

        Returns
        -------
        DistributedPipeline
            This pipeline instance.
        """
        self.init_distributed()
        self.setup()
        self._share_templates()
        return self
    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Exit the pipeline context manager.

        Calls :meth:`cleanup` to destroy the process group if it was
        initialized by this pipeline. Exceptions are never suppressed
        (this method returns ``None``).

        Parameters
        ----------
        exc_type : type[BaseException] | None
            Exception type, if any.
        exc_val : BaseException | None
            Exception value, if any.
        exc_tb : Any
            Exception traceback, if any.
        """
        self.cleanup()