# Implementing Custom Dynamics
This guide walks through the developer contract for creating a new integrator, using `DemoDynamics` (Velocity Verlet) as the running example.
## The developer contract
To implement a new integrator you must:

1. Subclass `BaseDynamics`.
2. Override `pre_update(batch)` and/or `post_update(batch)`.
3. Declare `__needs_keys__` and `__provides_keys__`.
You should not override `step()`, `compute()`, or `run()`: the base class orchestrates hook firing, model evaluation, output validation, and convergence checking.
```python
from nvalchemi.dynamics import BaseDynamics
from nvalchemi.data import Batch


class MyIntegrator(BaseDynamics):
    """Minimal skeleton for a custom integrator."""

    # Declare what the model must produce
    __needs_keys__: set[str] = {"forces"}
    # Declare what this integrator writes to the batch
    __provides_keys__: set[str] = {"velocities", "positions"}

    def pre_update(self, batch: Batch) -> None:
        """First half-step: update positions using current state."""
        ...

    def post_update(self, batch: Batch) -> None:
        """Second half-step: update velocities using new forces."""
        ...
```
## `__needs_keys__` and `__provides_keys__`
These class-level sets drive automatic validation:

- After every `compute()` call, `BaseDynamics._validate_model_outputs` checks that every key in `__needs_keys__` is present and non-`None` in the model outputs. A clear `RuntimeError` is raised otherwise.
- `__provides_keys__` documents which additional batch fields the integrator writes (beyond model outputs like forces and energies). The diagnostic helper `_validate_batch_keys` can verify them.
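The shape of that check can be sketched in a few lines. This is a hypothetical stand-in for `_validate_model_outputs`, not its real implementation; the actual method lives in `BaseDynamics` and may differ in details:

```python
# Hypothetical sketch of the post-compute() validation; the real
# BaseDynamics._validate_model_outputs may differ in details.
def validate_model_outputs(outputs: dict, needs: set[str]) -> None:
    missing = [k for k in sorted(needs) if outputs.get(k) is None]
    if missing:
        raise RuntimeError(f"model outputs missing required keys: {missing}")

validate_model_outputs({"forces": [0.1, -0.2]}, {"forces"})  # passes silently
try:
    validate_model_outputs({"forces": None}, {"forces"})  # None counts as missing
except RuntimeError as err:
    print(err)
```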
When dynamics are composed into a `FusedStage`, the fused stage computes the union of all sub-stage keys automatically:

```python
fused = relax + md  # __needs_keys__ = relax.__needs_keys__ | md.__needs_keys__
```
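The union rule itself is plain set algebra on the class-level declarations. In this dependency-free sketch, `Relax` and `MD` are illustrative stand-ins, not nvalchemi's real stages:

```python
# Illustrative stand-ins for two sub-stages; only the key declarations matter.
class Relax:
    __needs_keys__: set[str] = {"forces", "energies"}

class MD:
    __needs_keys__: set[str] = {"forces"}

# A fused stage needs everything any sub-stage needs: the set union.
fused_needs = Relax.__needs_keys__ | MD.__needs_keys__
print(sorted(fused_needs))  # ['energies', 'forces']
```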
## Walkthrough: DemoDynamics (Velocity Verlet)
The full implementation lives in nvalchemi/dynamics/demo.py. Let’s
break it down section by section.
### Class declaration and keys
```python
class DemoDynamics(BaseDynamics):
    """Velocity Verlet integrator for molecular dynamics simulations."""

    __needs_keys__: set[str] = {"forces"}
    __provides_keys__: set[str] = {"velocities", "positions"}

    _prev_accelerations: torch.Tensor | None
```
The integrator requires `forces` from the model and writes `velocities` and `positions` back to the batch. A private `_prev_accelerations` cache stores the previous step's accelerations for the half-step update.
### Constructor
```python
def __init__(
    self,
    model: BaseModelMixin,
    n_steps: int,
    dt: float = 1.0,
    hooks: list[Hook] | None = None,
    convergence_hook: ConvergenceHook | dict | None = None,
    **kwargs: Any,
) -> None:
    super().__init__(
        model=model,
        hooks=hooks,
        convergence_hook=convergence_hook,
        n_steps=n_steps,
        **kwargs,  # ← forwards communication kwargs
    )
    self.dt = dt
    self._prev_accelerations = None
```
The `**kwargs` forwarding is essential for cooperative MRO: `BaseDynamics.__init__` forwards to `_CommunicationMixin.__init__`, which accepts `prior_rank`, `next_rank`, `sinks`, `max_batch_size`, `sampler`, etc. By forwarding `**kwargs`, a single constructor call configures both the integrator and the communication layer.
Note that `dt` is not part of the base class; each subclass that needs a timestep should accept it explicitly and store it as `self.dt`:
```python
# Works seamlessly in a pipeline context
dyn = DemoDynamics(
    model=model,
    dt=0.5,
    max_batch_size=64,
    comm_mode="async_recv",
)
```
### `pre_update`: position half-step
```python
def pre_update(self, batch: Batch) -> None:
    positions: NodePositions = batch.positions
    velocities: NodeVelocities = batch.velocities
    forces: Forces | None = batch.forces
    masses = batch.atomic_masses.unsqueeze(-1)
    dt = self.dt

    with torch.no_grad():
        if forces is not None and not torch.all(forces == 0):
            accelerations = forces / masses
            self._prev_accelerations = accelerations.clone()
            # x(t+dt) = x(t) + v(t)*dt + 0.5*a(t)*dt²
            positions.add_(velocities * dt + 0.5 * accelerations * dt * dt)
        else:
            # First step: Euler fallback
            positions.add_(velocities * dt)
```
Key patterns:

- In-place tensor ops (`add_`, `copy_`): the batch is modified in place; never reassign `batch.positions = ...`.
- `torch.no_grad()` context: avoids conflicts when the model uses conservative (autograd) forces.
- Type annotations from `nvalchemi._typing`: `NodePositions`, `NodeVelocities`, and `Forces` provide jaxtyping shape documentation.
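The in-place requirement is about aliasing: other components that hold a reference to the batch's tensors see in-place edits but not rebinding. Plain Python lists stand in for tensors in this dependency-free sketch:

```python
positions = [1.0, 2.0, 3.0]
view = positions  # e.g. a hook caching a reference to batch.positions

# In-place mutation (analogous to positions.add_(delta)): the view sees it
for i in range(len(positions)):
    positions[i] += 0.5
print(view)  # [1.5, 2.5, 3.5]

# Rebinding (analogous to batch.positions = new_tensor): the view goes stale
positions = [0.0, 0.0, 0.0]
print(view)  # still [1.5, 2.5, 3.5]
```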
### `post_update`: velocity half-step
```python
def post_update(self, batch: Batch) -> None:
    velocities: NodeVelocities = batch.velocities
    forces: Forces = batch.forces
    masses = batch.atomic_masses.unsqueeze(-1)
    dt = self.dt

    with torch.no_grad():
        new_accelerations = forces / masses
        if self._prev_accelerations is not None:
            # v(t+dt) = v(t) + 0.5*(a(t) + a(t+dt))*dt
            velocities.add_(
                0.5 * (self._prev_accelerations + new_accelerations) * dt,
            )
        else:
            # First step: Euler fallback
            velocities.add_(new_accelerations * dt)
```
At this point, `forces` are the new forces from `compute()`, which ran between `pre_update` and `post_update`. The standard Velocity Verlet averaging of old and new accelerations gives symplectic, time-reversible integration.
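The same pre-update/compute/post-update split can be exercised on a 1-D harmonic oscillator (`F = -k*x`) without torch; this dependency-free sketch shows the hallmark of symplectic integration, bounded energy error over many steps:

```python
def force(x: float, k: float = 1.0) -> float:
    return -k * x  # harmonic restoring force

m, dt = 1.0, 0.01
x, v = 1.0, 0.0
a_prev = force(x) / m
e0 = 0.5 * m * v * v + 0.5 * x * x  # initial total energy (k = 1)

for _ in range(10_000):
    x += v * dt + 0.5 * a_prev * dt * dt  # pre_update: position half-step
    a_new = force(x) / m                  # compute: fresh forces at new x
    v += 0.5 * (a_prev + a_new) * dt      # post_update: velocity half-step
    a_prev = a_new

e = 0.5 * m * v * v + 0.5 * x * x
print(abs(e - e0) < 1e-4)  # True: energy drift stays bounded
```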
## How `step()` orchestrates everything
You do not override `step()`. The base class runs this sequence on every call:
1. BEFORE_STEP hooks
2. BEFORE_PRE_UPDATE hooks → pre_update() → AFTER_PRE_UPDATE hooks
3. BEFORE_COMPUTE hooks → compute() → AFTER_COMPUTE hooks
4. BEFORE_POST_UPDATE hooks → post_update() → AFTER_POST_UPDATE hooks
5. AFTER_STEP hooks
6. convergence check → ON_CONVERGE hooks (if any samples converged)
7. step_count += 1
`compute()` handles the full model pipeline: forward pass → `adapt_output()` → `_validate_model_outputs()` → write forces/energies to the batch via `copy_()`.
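The ordering can be made concrete with a toy class that only records the sequence; this mimics the fixed order above but is not the real `BaseDynamics`:

```python
# Illustrative trace of the fixed step() sequence; a toy, not BaseDynamics.
trace: list[str] = []

class ToyDynamics:
    step_count = 0

    def _fire(self, hook_point: str) -> None:
        trace.append(hook_point)  # stands in for running registered hooks

    def step(self) -> None:
        self._fire("BEFORE_STEP")
        for phase in ("pre_update", "compute", "post_update"):
            self._fire(f"BEFORE_{phase.upper()}")
            trace.append(phase)  # stands in for the real update/model call
            self._fire(f"AFTER_{phase.upper()}")
        self._fire("AFTER_STEP")
        self.step_count += 1

dyn = ToyDynamics()
dyn.step()
print(trace[0], trace[-1], dyn.step_count)  # BEFORE_STEP AFTER_STEP 1
```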
## `masked_update` for FusedStage compatibility
When your dynamics is composed via + into a
FusedStage, the fused stage calls
masked_update(batch, mask) instead of pre_update / post_update
directly. The default implementation in BaseDynamics is:
```python
def masked_update(self, batch, mask):
    # Expand graph-level mask → node-level via batch.batch
    node_mask = mask[batch.batch]

    # Snapshot unmasked state
    original_positions = batch.positions.clone()
    original_velocities = batch.velocities.clone() if ... else None

    # Run full updates
    self.pre_update(batch)
    self.post_update(batch)

    # Restore unmasked nodes
    with torch.no_grad():
        batch.positions[~node_mask] = original_positions[~node_mask]
        if original_velocities is not None:
            batch.velocities[~node_mask] = original_velocities[~node_mask]
```
This means your custom pre_update / post_update work correctly
inside a FusedStage without any modifications. The mask
selectively applies your updates only to samples at the corresponding
status code.
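The graph-to-node mask expansion is ordinary fancy indexing; plain lists stand in for tensors in this dependency-free sketch of `mask[batch.batch]`:

```python
batch_index = [0, 0, 1, 1, 1, 2]  # node i belongs to graph batch_index[i]
graph_mask = [True, False, True]  # which graphs this stage should update

# Equivalent of node_mask = mask[batch.batch]: look up each node's graph flag
node_mask = [graph_mask[g] for g in batch_index]
print(node_mask)  # [True, True, False, False, False, True]
```

Nodes of graph 1 keep their snapshotted state, while nodes of graphs 0 and 2 receive the integrator's updates.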
## Checklist for a new integrator
- [ ] Subclass `BaseDynamics`
- [ ] Set `__needs_keys__` (e.g. `{"forces"}`)
- [ ] Set `__provides_keys__` (e.g. `{"velocities", "positions"}`)
- [ ] Override `pre_update(batch)`: first half-step (positions)
- [ ] Override `post_update(batch)`: second half-step (velocities)
- [ ] Use in-place tensor ops (`add_`, `copy_`); never reassign batch attributes
- [ ] Wrap updates in `torch.no_grad()` if the model is conservative
- [ ] Forward `**kwargs` in `__init__` for communication support
- [ ] Accept and store `dt` (or other integrator-specific params) directly
- [ ] Write tests using `DemoModelWrapper` fixtures