# Source code for nvalchemi.dynamics.hooks.monitors

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Diagnostic monitor hooks for long-running simulations.

Provides :class:`EnergyDriftMonitorHook`, which tracks cumulative
energy drift over time and can warn or halt the simulation if the
drift exceeds a configurable threshold.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal

import torch
from loguru import logger

from nvalchemi.dynamics.hooks._base import _ObserverHook
from nvalchemi.dynamics.hooks._utils import kinetic_energy_per_graph

if TYPE_CHECKING:
    from nvalchemi.data import Batch
    from nvalchemi.dynamics.base import BaseDynamics

# Public API of this module: only the drift monitor hook is exported by
# ``from nvalchemi.dynamics.hooks.monitors import *``.
__all__ = ["EnergyDriftMonitorHook"]


class EnergyDriftMonitorHook(_ObserverHook):
    """Track energy drift and warn or stop if it exceeds a threshold.

    In a well-behaved NVE (microcanonical) simulation with a symplectic
    integrator, the total energy should be conserved to within numerical
    precision. Significant energy drift indicates problems with:

    * The integration timestep (too large for the force magnitudes).
    * The ML potential (non-smooth or discontinuous energy surface).
    * Numerical precision (single vs. double precision accumulation).
    * Force clamping or other hook-induced modifications breaking energy
      conservation.

    This hook monitors the **total energy** (potential + kinetic) over the
    simulation and computes drift metrics relative to a reference energy
    ``E(t0)`` captured at step ``t0`` (the first time the hook fires). It
    supports two modes:

    **Absolute drift mode** (``metric="absolute"``)
        Tracks ``|E(t) - E(t0)|``, the absolute deviation from the reference
        total energy. Suitable for NVE validation runs.

    **Per-atom-per-step drift mode** (``metric="per_atom_per_step"``)
        Tracks ``|E(t) - E(t0)| / (N_atoms * (t - t0))``, a normalized metric
        that allows comparison across systems of different size and
        simulation length. This is the standard metric reported in ML
        potential benchmarks.

    When the drift exceeds ``threshold`` (or is NaN, which indicates a
    diverged simulation), the hook either emits a warning or raises a
    :class:`RuntimeError`, controlled by the ``action`` parameter.

    The hook records the reference energy on the first firing and computes
    drift on all subsequent firings. For NVT or NPT simulations, energy drift
    is expected (the thermostat/barostat injects or removes energy), so use
    this hook primarily for NVE validation.

    Parameters
    ----------
    threshold : float
        Maximum acceptable drift before triggering the ``action``. Units
        depend on ``metric``: eV for ``"absolute"``, eV/atom/step for
        ``"per_atom_per_step"``.
    metric : {"absolute", "per_atom_per_step"}, optional
        Drift metric to use. Default ``"per_atom_per_step"``.
    action : {"warn", "raise"}, optional
        What to do when the threshold is exceeded. ``"warn"`` emits a
        :mod:`loguru` warning; ``"raise"`` raises a :class:`RuntimeError`.
        Default ``"warn"``.
    frequency : int, optional
        Evaluate drift every ``frequency`` steps. Default ``1``.
    include_kinetic : bool, optional
        Whether to include kinetic energy in the total energy calculation.
        Set to ``False`` if only monitoring potential energy drift (e.g. for
        optimizers). Default ``True``.

    Attributes
    ----------
    threshold : float
        Drift threshold.
    metric : str
        Drift metric mode.
    action : str
        Threshold violation behavior.
    include_kinetic : bool
        Whether kinetic energy is included.
    frequency : int
        Evaluation frequency in steps.
    stage : HookStageEnum
        Fixed to ``AFTER_STEP``.

    Raises
    ------
    ValueError
        At construction, if ``metric`` or ``action`` is not a supported value.

    Examples
    --------
    NVE validation with strict drift tolerance:

    >>> from nvalchemi.dynamics.hooks import EnergyDriftMonitorHook
    >>> hook = EnergyDriftMonitorHook(
    ...     threshold=1e-5,
    ...     metric="per_atom_per_step",
    ...     action="raise",
    ...     frequency=100,
    ... )
    >>> dynamics = DemoDynamics(model=model, n_steps=10_000, dt=0.5, hooks=[hook])  # doctest: +SKIP
    >>> dynamics.run(batch)  # doctest: +SKIP

    Soft monitoring during production:

    >>> hook = EnergyDriftMonitorHook(
    ...     threshold=1e-3,
    ...     action="warn",
    ...     frequency=1000,
    ... )

    Notes
    -----
    * The reference energy is captured on the **first** hook firing (step 0
      by default), not at construction time. This allows the hook to be
      registered before the batch is available. The per-step normalization
      counts steps elapsed since that first firing.
    * For batched simulations, drift is computed **per graph** and the
      maximum drift across all graphs is compared to the threshold.
    * The stored reference is detached from autograd, so it never keeps a
      model's computation graph alive.
    """

    _VALID_METRICS = ("absolute", "per_atom_per_step")
    _VALID_ACTIONS = ("warn", "raise")

    def __init__(
        self,
        threshold: float,
        metric: Literal["absolute", "per_atom_per_step"] = "per_atom_per_step",
        action: Literal["warn", "raise"] = "warn",
        frequency: int = 1,
        include_kinetic: bool = True,
    ) -> None:
        """Configure the drift monitor (parameters are documented on the class).

        Raises
        ------
        ValueError
            If ``metric`` or ``action`` is not one of the supported values.
        """
        # Fail fast on typos: otherwise an unknown ``metric`` would silently
        # behave like "absolute" and an unknown ``action`` would silently warn.
        if metric not in self._VALID_METRICS:
            raise ValueError(
                f"metric must be one of {self._VALID_METRICS}, got {metric!r}"
            )
        if action not in self._VALID_ACTIONS:
            raise ValueError(
                f"action must be one of {self._VALID_ACTIONS}, got {action!r}"
            )
        super().__init__(frequency=frequency)
        self.threshold = threshold
        self.metric = metric
        self.action = action
        self.include_kinetic = include_kinetic
        # Reference state, captured lazily on the first firing.
        self._reference_total_energy: torch.Tensor | None = None
        self._reference_step: int = 0

    def _total_energy(self, batch: Batch) -> torch.Tensor:
        """Return the per-graph total energy ``(B,)``, detached from autograd."""
        total = batch.energies.squeeze(-1)  # (B,)
        if self.include_kinetic and getattr(batch, "velocities", None) is not None:
            total = total + kinetic_energy_per_graph(
                batch.velocities,
                batch.atomic_masses,
                batch.batch,
                batch.num_graphs,
            ).squeeze(-1)  # (B,)
        # Detach so neither the stored reference nor the drift arithmetic
        # holds on to an autograd graph (relevant if energies were produced
        # with gradient tracking enabled, e.g. to derive forces).
        return total.detach()

    @torch.compiler.disable
    def __call__(self, batch: Batch, dynamics: BaseDynamics) -> None:
        """Compute energy drift and compare against the threshold.

        On the first firing, this method captures the reference total energy
        (and the step at which it was taken) and returns immediately. On all
        subsequent firings, it computes drift relative to that reference and
        compares against the configured threshold.

        Parameters
        ----------
        batch : Batch
            The current batch of atomic data. Must have ``energies`` (and
            ``velocities`` if ``include_kinetic=True``).
        dynamics : BaseDynamics
            The dynamics engine instance.

        Raises
        ------
        RuntimeError
            If ``action="raise"`` and drift exceeds the threshold or is NaN.
        """
        total = self._total_energy(batch)

        # Capture reference on first firing. ``clone`` is required because the
        # detached tensor may share storage with ``batch.energies``.
        if self._reference_total_energy is None:
            self._reference_total_energy = total.clone()
            self._reference_step = dynamics.step_count
            return

        drift = (total - self._reference_total_energy).abs()  # (B,)
        if self.metric == "per_atom_per_step":
            # Normalize by the steps elapsed since the reference was taken,
            # not the absolute step counter; identical when the reference is
            # captured at step 0, correct when it is captured later.
            elapsed = max(dynamics.step_count - self._reference_step, 1)
            drift = drift / (batch.num_nodes_per_graph * elapsed)

        if drift.numel() == 0:
            # Empty batch: ``Tensor.max()`` raises on an empty tensor.
            return

        max_drift = drift.max().item()
        # Written as ``not (x <= t)`` instead of ``x > t`` so that a NaN drift
        # (diverged simulation), which compares False against everything,
        # still triggers the action.
        if not max_drift <= self.threshold:
            msg = (
                f"Energy drift {max_drift:.2e} exceeds threshold "
                f"{self.threshold:.2e} at step {dynamics.step_count}"
                f" on rank {dynamics.global_rank}."
            )
            if self.action == "raise":
                raise RuntimeError(msg)
            # TODO: use a distributed aware logger
            logger.warning(msg)