Implementing Custom Dynamics#

This guide walks through the developer contract for creating a new integrator, using DemoDynamics (Velocity Verlet) as the running example.

The developer contract#

To implement a new integrator you must:

  1. Subclass BaseDynamics.

  2. Override pre_update(batch) and/or post_update(batch).

  3. Declare __needs_keys__ and __provides_keys__.

You should not override step(), compute(), or run() — the base class orchestrates hook firing, model evaluation, output validation, and convergence checking.

from nvalchemi.dynamics import BaseDynamics
from nvalchemi.data import Batch

class MyIntegrator(BaseDynamics):
    """Minimal skeleton for a custom integrator."""

    # Declare what the model must produce
    __needs_keys__: set[str] = {"forces"}

    # Declare what this integrator writes to the batch
    __provides_keys__: set[str] = {"velocities", "positions"}

    def pre_update(self, batch: Batch) -> None:
        """First half-step: update positions using current state."""
        ...

    def post_update(self, batch: Batch) -> None:
        """Second half-step: update velocities using new forces."""
        ...

__needs_keys__ and __provides_keys__#

These class-level sets drive automatic validation:

  • After every compute() call, BaseDynamics._validate_model_outputs checks that every key in __needs_keys__ is present and non-None in the model outputs. A clear RuntimeError is raised otherwise.

  • __provides_keys__ documents which additional batch fields the integrator writes (beyond model outputs like forces and energies). The diagnostic helper _validate_batch_keys can verify them.
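The needs-keys check can be pictured as a simple membership test over the model's output dict. This is a hypothetical sketch of the idea, not the actual `_validate_model_outputs` source:

```python
def validate_model_outputs(outputs: dict, needs_keys: set[str]) -> None:
    """Raise if any required key is missing or None -- mirrors the check
    BaseDynamics performs after every compute() call."""
    missing = [k for k in needs_keys if outputs.get(k) is None]
    if missing:
        raise RuntimeError(
            f"Model outputs missing required keys: {sorted(missing)}"
        )

validate_model_outputs({"forces": [0.0, 0.1]}, {"forces"})  # passes silently
```

Passing `{}` or `{"forces": None}` would raise, which is what surfaces a misconfigured model early instead of as a confusing downstream shape error.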

When dynamics are composed into a FusedStage, the fused stage computes the union of all sub-stage keys automatically:

fused = relax + md  # __needs_keys__ = relax.__needs_keys__ | md.__needs_keys__
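The union behavior can be illustrated with a stripped-down stand-in class (hypothetical; the real FusedStage also fuses hooks and stepping logic, not just keys):

```python
class Stage:
    """Minimal stand-in for a dynamics stage with declared keys."""

    def __init__(self, needs: set[str], provides: set[str]) -> None:
        self.needs_keys = set(needs)
        self.provides_keys = set(provides)

    def __add__(self, other: "Stage") -> "Stage":
        # The fused stage needs everything either sub-stage needs,
        # and provides everything either sub-stage provides.
        return Stage(
            self.needs_keys | other.needs_keys,
            self.provides_keys | other.provides_keys,
        )

relax = Stage(needs={"forces", "energies"}, provides={"positions"})
md = Stage(needs={"forces"}, provides={"positions", "velocities"})
fused = relax + md
print(fused.needs_keys)     # {'forces', 'energies'}
print(fused.provides_keys)  # {'positions', 'velocities'}
```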

Walkthrough: DemoDynamics (Velocity Verlet)#

The full implementation lives in nvalchemi/dynamics/demo.py. Let’s break it down section by section.

Class declaration and keys#

class DemoDynamics(BaseDynamics):
    """Velocity Verlet integrator for molecular dynamics simulations."""

    __needs_keys__: set[str] = {"forces"}
    __provides_keys__: set[str] = {"velocities", "positions"}

    _prev_accelerations: torch.Tensor | None

The integrator requires forces from the model and writes velocities and positions back to the batch. A private _prev_accelerations cache stores the previous step’s accelerations for the half-step update.

Constructor#

def __init__(
    self,
    model: BaseModelMixin,
    n_steps: int,
    dt: float = 1.0,
    hooks: list[Hook] | None = None,
    convergence_hook: ConvergenceHook | dict | None = None,
    **kwargs: Any,
) -> None:
    super().__init__(
        model=model,
        hooks=hooks,
        convergence_hook=convergence_hook,
        n_steps=n_steps,
        **kwargs,              # ← forwards communication kwargs
    )
    self.dt = dt
    self._prev_accelerations = None

The **kwargs forwarding is essential for cooperative MRO: BaseDynamics.__init__ forwards to _CommunicationMixin.__init__, which accepts prior_rank, next_rank, sinks, max_batch_size, sampler, etc. By forwarding **kwargs, a single constructor call configures both the integrator and the communication layer.
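The cooperative pattern can be demonstrated with a stripped-down class chain in plain Python (hypothetical names; only the **kwargs plumbing is the point):

```python
class CommunicationMixin:
    """Consumes communication kwargs, forwards the rest up the MRO."""

    def __init__(self, *, prior_rank=None, next_rank=None, **kwargs) -> None:
        super().__init__(**kwargs)  # continue the cooperative chain
        self.prior_rank = prior_rank
        self.next_rank = next_rank

class Dynamics(CommunicationMixin):
    """Consumes n_steps, forwards communication kwargs to the mixin."""

    def __init__(self, *, n_steps: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.n_steps = n_steps

class MyIntegrator(Dynamics):
    """Consumes dt, forwards everything else."""

    def __init__(self, *, dt: float = 1.0, **kwargs) -> None:
        super().__init__(**kwargs)
        self.dt = dt

# One call configures every level of the hierarchy
dyn = MyIntegrator(dt=0.5, n_steps=100, prior_rank=0)
```

Dropping the `**kwargs` forwarding at any level would raise a `TypeError` for unexpected keyword arguments, which is the failure mode the base-class contract is designed to avoid.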

Note that dt is not part of the base class — each subclass that needs a timestep should accept it explicitly and store it as self.dt:

# Works seamlessly in a pipeline context
dyn = DemoDynamics(
    model=model,
    n_steps=1000,
    dt=0.5,
    max_batch_size=64,
    comm_mode="async_recv",
)

pre_update: position half-step#

def pre_update(self, batch: Batch) -> None:
    positions: NodePositions = batch.positions
    velocities: NodeVelocities = batch.velocities
    forces: Forces | None = batch.forces
    masses = batch.atomic_masses.unsqueeze(-1)

    dt = self.dt

    with torch.no_grad():
        if forces is not None and not torch.all(forces == 0):
            accelerations = forces / masses
            self._prev_accelerations = accelerations.clone()
            # x(t+dt) = x(t) + v(t)*dt + 0.5*a(t)*dt²
            positions.add_(velocities * dt + 0.5 * accelerations * dt * dt)
        else:
            # First step: Euler fallback
            positions.add_(velocities * dt)

Key patterns:

  • In-place tensor ops (add_, copy_) — the batch is modified in-place; never reassign batch.positions = ....

  • torch.no_grad() context — keeps the update math out of the autograd graph, avoiding conflicts when the model uses conservative (autograd) forces.

  • Type annotations from nvalchemi._typing — NodePositions, NodeVelocities, and Forces provide jaxtyping shape documentation.
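Why in-place updates matter can be seen with a plain numpy sketch (numpy stands in for torch here): rebinding an attribute creates a new array and strands any other reference to the old one, while an in-place op keeps every reference consistent.

```python
import numpy as np

positions = np.zeros(3)
view = positions  # e.g. a hook or sink holding the same underlying tensor

# In-place update (analogue of tensor.add_): the shared reference sees it
positions += np.array([1.0, 2.0, 3.0])
print(view)       # [1. 2. 3.] -- still in sync

# Rebinding (analogue of batch.positions = ...): allocates a new array
positions = positions + 1.0
print(view)       # [1. 2. 3.] -- now stale; 'positions' points elsewhere
```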

post_update: velocity half-step#

def post_update(self, batch: Batch) -> None:
    velocities: NodeVelocities = batch.velocities
    forces: Forces = batch.forces
    masses = batch.atomic_masses.unsqueeze(-1)

    dt = self.dt

    with torch.no_grad():
        new_accelerations = forces / masses

        if self._prev_accelerations is not None:
            # v(t+dt) = v(t) + 0.5*(a(t) + a(t+dt))*dt
            velocities.add_(
                0.5 * (self._prev_accelerations + new_accelerations) * dt,
            )
        else:
            # First step: Euler fallback
            velocities.add_(new_accelerations * dt)

At this point, forces are the new forces from compute(), which ran between pre_update and post_update. The standard Velocity Verlet averaging of old and new accelerations gives symplectic, time-reversible integration.
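The same half-step scheme can be exercised on a 1-D harmonic oscillator in plain Python, with no batch machinery: unit mass, unit spring constant, so a(x) = -x and the exact total energy is 0.5 for x(0)=1, v(0)=0.

```python
def velocity_verlet(x, v, dt, n_steps, accel):
    """Scalar velocity Verlet: mirrors pre_update / compute / post_update."""
    a = accel(x)
    for _ in range(n_steps):
        # pre_update: x(t+dt) = x(t) + v(t)*dt + 0.5*a(t)*dt^2
        x = x + v * dt + 0.5 * a * dt * dt
        # compute(): fresh forces at the updated positions
        a_new = accel(x)
        # post_update: v(t+dt) = v(t) + 0.5*(a(t) + a(t+dt))*dt
        v = v + 0.5 * (a + a_new) * dt
        a = a_new
    return x, v

# Integrate roughly one period (T = 2*pi) of the unit oscillator
x, v = velocity_verlet(1.0, 0.0, dt=0.01, n_steps=628, accel=lambda x: -x)
energy = 0.5 * v * v + 0.5 * x * x
print(energy)  # stays very close to the exact value 0.5
```

The near-perfect energy conservation over a full period is the practical payoff of the symplectic, time-reversible averaging.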

How step() orchestrates everything#

You do not override step(). The base class runs this sequence on every call:

1.  BEFORE_STEP hooks
2.  BEFORE_PRE_UPDATE hooks  →  pre_update()  →  AFTER_PRE_UPDATE hooks
3.  BEFORE_COMPUTE hooks     →  compute()      →  AFTER_COMPUTE hooks
4.  BEFORE_POST_UPDATE hooks →  post_update()  →  AFTER_POST_UPDATE hooks
5.  AFTER_STEP hooks
6.  convergence check  →  ON_CONVERGE hooks (if any samples converged)
7.  step_count += 1

compute() handles the full model pipeline: forward pass → adapt_output() → _validate_model_outputs() → write forces/energies to batch via copy_().
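The ordering above can be sketched as a plain driver loop (a hypothetical simplification, not the real step() implementation; hook names are those from the sequence listed):

```python
def step(dyn, batch, fire) -> None:
    """Sketch of the base-class step() ordering: fire(name) runs the
    hooks registered for that event."""
    fire("BEFORE_STEP")
    fire("BEFORE_PRE_UPDATE");  dyn.pre_update(batch);  fire("AFTER_PRE_UPDATE")
    fire("BEFORE_COMPUTE");     dyn.compute(batch);     fire("AFTER_COMPUTE")
    fire("BEFORE_POST_UPDATE"); dyn.post_update(batch); fire("AFTER_POST_UPDATE")
    fire("AFTER_STEP")
    dyn.step_count += 1  # the convergence check would run just before this

class Recorder:
    """Stand-in dynamics that logs its own calls into the 'batch'."""
    step_count = 0
    def pre_update(self, trace):  trace.append("pre_update")
    def compute(self, trace):     trace.append("compute")
    def post_update(self, trace): trace.append("post_update")

trace: list[str] = []
step(Recorder(), trace, trace.append)
print(trace)  # hooks and updates interleaved in the documented order
```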

masked_update for FusedStage compatibility#

When your dynamics is composed via + into a FusedStage, the fused stage calls masked_update(batch, mask) instead of pre_update / post_update directly. The default implementation in BaseDynamics is:

def masked_update(self, batch, mask):
    # Expand graph-level mask → node-level via batch.batch
    node_mask = mask[batch.batch]

    # Snapshot unmasked state
    original_positions = batch.positions.clone()
    original_velocities = batch.velocities.clone() if ... else None

    # Run full updates
    self.pre_update(batch)
    self.post_update(batch)

    # Restore unmasked nodes
    with torch.no_grad():
        batch.positions[~node_mask] = original_positions[~node_mask]
        if original_velocities is not None:
            batch.velocities[~node_mask] = original_velocities[~node_mask]

This means your custom pre_update / post_update work correctly inside a FusedStage without any modifications. The mask selectively applies your updates only to samples at the corresponding status code.
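The graph-to-node mask expansion and restore logic can be checked in isolation with numpy (batch.batch maps each node to the index of the graph it belongs to):

```python
import numpy as np

# Two graphs: nodes 0-1 belong to graph 0, nodes 2-4 to graph 1
batch_index = np.array([0, 0, 1, 1, 1])
mask = np.array([True, False])     # only graph 0 should be updated

node_mask = mask[batch_index]      # expand graph-level -> node-level
print(node_mask)                   # [ True  True False False False]

positions = np.arange(5, dtype=float)
original = positions.copy()                   # snapshot unmasked state
positions += 10.0                             # "full update" on every node
positions[~node_mask] = original[~node_mask]  # restore masked-out nodes
print(positions)                   # [10. 11.  2.  3.  4.]
```

Only the nodes of graph 0 kept their update; graph 1's nodes were rolled back, which is exactly the snapshot-update-restore pattern shown above.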

Checklist for a new integrator#

☐  Subclass BaseDynamics
☐  Set __needs_keys__   (e.g. {"forces"})
☐  Set __provides_keys__ (e.g. {"velocities", "positions"})
☐  Override pre_update(batch)  — first half-step (positions)
☐  Override post_update(batch) — second half-step (velocities)
☐  Use in-place tensor ops (add_, copy_) — never reassign batch attrs
☐  Wrap updates in torch.no_grad() if model is conservative
☐  Forward **kwargs in __init__ for communication support
☐  Accept and store dt (or other integrator-specific params) directly
☐  Write tests using DemoModelWrapper fixtures