Defensive MD: Safety Hooks and Performance Monitoring
Numerical instabilities are a fact of life in machine-learning molecular dynamics. ML potentials are trained on a finite region of configuration space; geometries outside that region can produce enormous forces, NaN gradients, or energy drift that silently corrupts a long trajectory.
This example demonstrates four hooks that make simulations more robust:
- NaNDetectorHook: raises RuntimeError immediately when forces or energies contain non-finite values (NaN or Inf). Prevents corrupted state from propagating.
- MaxForceClampHook: rescales per-atom force vectors whose L2 norm exceeds a threshold, preventing integration blow-ups from bad initial geometries or extrapolation errors.
- EnergyDriftMonitorHook: monitors the total energy drift in NVE simulations and warns or raises when it exceeds a threshold. Drift indicates timestep or potential issues.
- ProfilerHook: records wall-clock time per hook stage and writes a CSV timing log. Essential for identifying bottlenecks in the simulation loop.
Recommended registration order (when using multiple safety hooks):
1. MaxForceClampHook: clamp forces first so downstream checks see bounded values.
2. NaNDetectorHook: detect any NaN that slipped through clamping.
3. EnergyDriftMonitorHook: monitor cumulative drift (AFTER_STEP).
4. ProfilerHook: spans all stages, so register last or use stages="all".
import logging
import tempfile
from pathlib import Path
import torch
from nvalchemi.data import AtomicData, Batch
from nvalchemi.dynamics import NVE, NVTLangevin
from nvalchemi.dynamics.hooks import (
    EnergyDriftMonitorHook,
    MaxForceClampHook,
    NaNDetectorHook,
    NeighborListHook,
    ProfilerHook,
    WrapPeriodicHook,
)
from nvalchemi.models.demo import DemoModelWrapper
from nvalchemi.models.lj import LennardJonesModelWrapper
logging.basicConfig(level=logging.INFO)
NaNDetectorHook: catching non-finite values early
A pathological configuration (two atoms at nearly identical positions) causes the
LJ potential to diverge. NaNDetectorHook detects this on the first step and
raises RuntimeError with a diagnostic message listing the affected fields and
graph indices.
Without this hook, NaN forces would silently propagate through velocity updates, poisoning the entire trajectory.
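To see why such a geometry is pathological, the arithmetic can be worked through in plain Python. This is only a sketch: it assumes the standard 12-6 Lennard-Jones form, and `lj_energy` is a hypothetical helper written for this illustration, not part of the library. It evaluates the pair energy at the 0.001 Å separation used in this example, with the same epsilon and sigma, and compares it with the largest finite float32 value.

```python
FLOAT32_MAX = 3.4028234663852886e38  # largest finite IEEE 754 single-precision value


def lj_energy(r, epsilon=0.0104, sigma=3.40):
    """Pair energy 4*eps*[(sigma/r)**12 - (sigma/r)**6], evaluated in double precision."""
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 * sr6 - sr6)


# Separation used for the two argon atoms in this example (in Å).
e_close = lj_energy(0.001)

# For comparison: the separation at the potential minimum, r = 2**(1/6) * sigma.
e_min = lj_energy(2 ** (1 / 6) * 3.40)

# True when the energy is too large to be stored as a finite float32 number.
overflows_float32 = e_close > FLOAT32_MAX
```

At the minimum the energy is about -epsilon, whereas at 0.001 Å it exceeds the float32 range, so single-precision tensors cannot hold it as a finite number.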
lj_model_nan = LennardJonesModelWrapper(epsilon=0.0104, sigma=3.40, cutoff=8.5)
lj_model_nan.eval()
# Place two atoms at nearly identical positions — guaranteed LJ blow-up.
bad_data = AtomicData(
    positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.001]], dtype=torch.float32),
    atomic_numbers=torch.tensor([18, 18], dtype=torch.long),
    atomic_masses=torch.full((2,), 39.948),
    forces=torch.zeros(2, 3),
    energies=torch.zeros(1, 1),
    cell=torch.eye(3).unsqueeze(0) * 20.0,
    pbc=torch.tensor([[True, True, True]]),
)
bad_data.add_node_property("velocities", torch.zeros(2, 3))
bad_batch = Batch.from_data_list([bad_data])
nan_hook = NaNDetectorHook()
nl_hook_nan = NeighborListHook(lj_model_nan.model_card.neighbor_config)
nvt_nan = NVTLangevin(
    model=lj_model_nan,
    dt=1.0,
    temperature=50.0,
    friction=0.1,
    random_seed=1,
    n_steps=5,
    hooks=[nl_hook_nan, nan_hook],
)
logging.info("Running NaNDetectorHook demo (expect RuntimeError)...")
try:
    nvt_nan.run(bad_batch)
    logging.info("No NaN detected (unexpected for this configuration).")
except RuntimeError as exc:
    logging.info("NaNDetectorHook correctly caught: %s", str(exc)[:120])
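The kind of check described in this section can be pictured with a small, dependency-free sketch. It is an illustration only and not the library's implementation: `find_non_finite`, `check_or_raise`, the list-based data layout, and the message format are all assumptions made for this example.

```python
import math


def find_non_finite(fields, graph_index):
    """Return {field_name: sorted graph ids} for fields holding NaN or Inf.

    fields      -- hypothetical mapping of field name to per-atom rows of floats
    graph_index -- graph (system) id of each atom, one entry per row
    """
    bad = {}
    for name, rows in fields.items():
        for atom, row in enumerate(rows):
            if not all(math.isfinite(x) for x in row):
                bad.setdefault(name, set()).add(graph_index[atom])
    return {name: sorted(ids) for name, ids in bad.items()}


def check_or_raise(fields, graph_index):
    """Raise RuntimeError naming the affected fields and graphs, if any."""
    bad = find_non_finite(fields, graph_index)
    if bad:
        details = "; ".join(f"{name} in graphs {ids}" for name, ids in bad.items())
        raise RuntimeError(f"Non-finite values detected: {details}")


# Two graphs with two atoms each; the second graph has an Inf force component.
forces = [[0.1, 0.0, 0.0], [0.0, 0.2, 0.0], [float("inf"), 0.0, 0.0], [0.0, 0.0, 0.3]]
report = find_non_finite({"forces": forces}, graph_index=[0, 0, 1, 1])
```

Here `report` maps `forces` to graph 1 only, which is the sort of field-and-graph information a diagnostic message can carry. Failing fast at this point keeps the bad values out of the velocity update.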
MaxForceClampHook: surviving bad initial geometries
MaxForceClampHook rescales any force
vector whose L2 norm exceeds max_force, preserving the direction but
bounding the magnitude. This lets the integrator take a step even when
the potential energy surface produces pathologically large gradients.
Register MaxForceClampHook before NaNDetectorHook so that forces
are bounded before the NaN check fires (both run at AFTER_COMPUTE).
lj_model_clamp = LennardJonesModelWrapper(epsilon=0.0104, sigma=3.40, cutoff=8.5)
lj_model_clamp.eval()
# Atoms randomly placed in a 5 Å box; some pairs will be very close.
# _lj_system_bad is a helper whose definition is not shown in this section.
clamp_data = _lj_system_bad(n_atoms=8, seed=77, box=5.0)
clamp_batch = Batch.from_data_list([clamp_data])
clamp_hook = MaxForceClampHook(max_force=10.0) # eV/Å
nl_hook_clamp = NeighborListHook(lj_model_clamp.model_card.neighbor_config)
wrap_hook_clamp = WrapPeriodicHook()
nvt_clamp = NVTLangevin(
    model=lj_model_clamp,
    dt=0.5,
    temperature=50.0,
    friction=0.1,
    random_seed=2,
    n_steps=20,
    hooks=[nl_hook_clamp, clamp_hook, wrap_hook_clamp],
)
logging.info("Running MaxForceClampHook demo (20 steps with force clamping)...")
clamp_batch = nvt_clamp.run(clamp_batch)
final_fmax = clamp_batch.forces.norm(dim=-1).max().item()
logging.info(
    "Completed %d steps. Final fmax=%.4f eV/Å (clamped at 10.0 eV/Å per step).",
    nvt_clamp.step_count,
    final_fmax,
)
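The rescaling rule described in this section (shrink any force vector whose L2 norm exceeds the threshold, keep its direction) can be written out in a few lines. The function below is a plain-Python sketch under that description; `clamp_forces` and the tuple-based layout are hypothetical and do not reflect how the hook is implemented.

```python
import math


def clamp_forces(forces, max_force):
    """Rescale each force vector whose L2 norm exceeds max_force.

    Direction is preserved; vectors at or below the threshold are unchanged.
    """
    clamped = []
    for fx, fy, fz in forces:
        norm = math.sqrt(fx * fx + fy * fy + fz * fz)
        if norm > max_force:
            scale = max_force / norm
            clamped.append((fx * scale, fy * scale, fz * scale))
        else:
            clamped.append((fx, fy, fz))
    return clamped


# One ordinary force (norm 5) and one pathological force (norm 500), in eV/Å.
forces = [(3.0, 4.0, 0.0), (300.0, 400.0, 0.0)]
clamped = clamp_forces(forces, max_force=10.0)
```

In this sketch a NaN norm fails the `norm > max_force` comparison, so a NaN vector passes through unchanged. That is one reason a separate non-finite check is still useful after clamping.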
EnergyDriftMonitorHook: NVE energy conservation
In a well-integrated NVE (microcanonical) simulation the total energy (kinetic + potential) should be conserved. Drift exceeding a threshold indicates an overly large timestep or a non-smooth potential.
metric="per_atom_per_step" normalises the drift by atom count and step
number, making it comparable across systems of different sizes and trajectories of different lengths.
action="warn" emits a log warning rather than stopping the simulation.
demo_model = DemoModelWrapper()
demo_model.eval()
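# Provide a system with non-zero initial velocities for kinetic energy.
# _demo_system is a helper whose definition is not shown in this section.
nve_data = _demo_system(n_atoms=6, seed=10)
KB_EV = 8.617333e-5  # Boltzmann constant in eV/K
kT = 200.0 * KB_EV  # thermal energy at 200 K
g = torch.Generator()
g.manual_seed(11)
# Gaussian velocities with standard deviation sqrt(kT / m), using m = 1.0.
nve_data["velocities"] = torch.randn(6, 3, generator=g) * (kT / 1.0) ** 0.5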
nve_batch = Batch.from_data_list([nve_data])
drift_hook = EnergyDriftMonitorHook(
    threshold=1e-3,
    metric="per_atom_per_step",
    action="warn",
    frequency=10,
    include_kinetic=True,
)
nve = NVE(
    model=demo_model,
    dt=0.5,
    n_steps=100,
    hooks=[drift_hook],
)
logging.info("Running EnergyDriftMonitorHook demo (100 NVE steps)...")
nve_batch = nve.run(nve_batch)
logging.info(
    "NVE run complete. Reference energy captured; drift monitored every 10 steps."
)
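One plausible reading of the two metrics used on this page is sketched below. The formulas are assumptions inferred from the description above (drift normalised by atom count and step number) and the library's exact definitions may differ. `energy_drift` is a hypothetical helper and the energy samples are made-up numbers chosen only to exercise it.

```python
def energy_drift(e_ref, e_now, n_atoms, step, metric="per_atom_per_step"):
    """Drift of the total energy relative to a reference value.

    Assumed definitions (not taken from the library):
      absolute          -> |E_now - E_ref|
      per_atom_per_step -> |E_now - E_ref| / (n_atoms * step)
    """
    drift = abs(e_now - e_ref)
    if metric == "absolute":
        return drift
    if metric == "per_atom_per_step":
        if step <= 0:
            return 0.0  # no steps taken yet, so nothing to normalise by
        return drift / (n_atoms * step)
    raise ValueError(f"unknown metric: {metric!r}")


# Made-up total energies (kinetic + potential, eV) sampled every 10 steps
# of a 6-atom run, compared against a reference energy captured at the start.
e_ref = -1.2345
samples = {10: -1.2340, 20: -1.2310, 30: -1.2100}
threshold = 1e-4
flagged = [
    step
    for step, e_now in samples.items()
    if energy_drift(e_ref, e_now, n_atoms=6, step=step) > threshold
]
```

With these numbers only step 30 crosses the threshold. Normalising by both atom count and step number lets one threshold value be reused for runs of different size and length.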
ProfilerHook: timing the simulation loop
ProfilerHook records wall-clock time at
each hook stage. stages="step" instruments only BEFORE_STEP and
AFTER_STEP, giving a clean per-step elapsed time without the overhead
of timing every sub-stage.
log_path writes a CSV with columns: rank, step, stage, t_since_init_s,
delta_s. show_console=True additionally prints a formatted table via
loguru every console_frequency-th profiled step.
profiler_out = Path(tempfile.mkdtemp()) / "profile.csv"
profiler_hook = ProfilerHook(
    stages="step",
    timer_backend="auto",
    log_path=str(profiler_out),
    show_console=True,
    console_frequency=10,
    frequency=1,
)
prof_data = _demo_system(n_atoms=8, seed=20)
prof_batch = Batch.from_data_list([prof_data])
nvt_prof = NVTLangevin(
    model=demo_model,
    dt=0.5,
    temperature=300.0,
    friction=0.1,
    random_seed=3,
    n_steps=50,
    hooks=[profiler_hook],
)
logging.info("Running ProfilerHook demo (50 NVT steps)...")
prof_batch = nvt_prof.run(prof_batch)
profiler_hook.close()
summary = profiler_hook.summary()
for transition, stats in summary.items():
    logging.info(
        " %s: mean=%.3f ms std=%.3f ms n=%d",
        transition,
        stats["mean_s"] * 1000,
        stats["std_s"] * 1000,
        int(stats["n_samples"]),
    )
logging.info("Profile CSV written to: %s", profiler_out)
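For readers who want to see what step-level profiling amounts to, here is a standard-library sketch that times a dummy step function and writes rows with the column names listed above. It is not the hook's implementation. `profile_steps` is a hypothetical helper, and the meaning given to `delta_s` here (time elapsed since the matching BEFORE_STEP event) is an assumption.

```python
import csv
import io
import statistics
import time


def profile_steps(step_fn, n_steps, out):
    """Time n_steps calls of step_fn and write one CSV row per stage event."""
    writer = csv.writer(out)
    writer.writerow(["rank", "step", "stage", "t_since_init_s", "delta_s"])
    t_init = time.perf_counter()
    deltas = []
    for step in range(n_steps):
        t_before = time.perf_counter()
        writer.writerow([0, step, "BEFORE_STEP", t_before - t_init, 0.0])
        step_fn()
        t_after = time.perf_counter()
        delta = t_after - t_before
        writer.writerow([0, step, "AFTER_STEP", t_after - t_init, delta])
        deltas.append(delta)
    # Summary keys mirror the ones read from summary() in the example above.
    return {
        "mean_s": statistics.fmean(deltas),
        "std_s": statistics.pstdev(deltas),
        "n_samples": len(deltas),
    }


buffer = io.StringIO()  # stands in for the CSV file on disk
stats = profile_steps(lambda: sum(range(1000)), n_steps=5, out=buffer)
rows = list(csv.reader(io.StringIO(buffer.getvalue())))
```

Timing only the step boundaries keeps the measurement overhead to two clock reads per step, which matches the motivation given above for stages="step".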
Defensive setup pattern: all four hooks together
When running an unfamiliar potential on new structures, combine all four hooks. Registration order matters; the list below also shows where NeighborListHook, used in the Lennard-Jones examples above, fits in:
1. NeighborListHook (BEFORE_COMPUTE): must run before the model.
2. MaxForceClampHook (AFTER_COMPUTE): clamp before the NaN check.
3. NaNDetectorHook (AFTER_COMPUTE): detect remaining bad values.
4. EnergyDriftMonitorHook (AFTER_STEP): cumulative drift check.
5. ProfilerHook (all stages): spans the full step loop.
This ordering ensures each hook sees the most up-to-date (and safest) state when it fires.
logging.info("=== Defensive setup pattern example ===")
demo_model2 = DemoModelWrapper()
demo_model2.eval()
safe_data = _demo_system(n_atoms=5, seed=99)
safe_batch = Batch.from_data_list([safe_data])
clamp = MaxForceClampHook(max_force=50.0)
nan_check = NaNDetectorHook(frequency=1)
drift_check = EnergyDriftMonitorHook(
    threshold=1.0,  # generous for demo purposes
    metric="absolute",
    action="warn",
    frequency=5,
)
profiler = ProfilerHook(stages="step", timer_backend="auto", frequency=1)
safe_nvt = NVTLangevin(
    model=demo_model2,
    dt=0.5,
    temperature=300.0,
    friction=0.1,
    random_seed=42,
    n_steps=30,
    # Hooks fire in registration order within the same stage.
    hooks=[clamp, nan_check, drift_check, profiler],
)
safe_batch = safe_nvt.run(safe_batch)
profiler.close()
logging.info(
    "Defensive run complete: %d steps, no exceptions raised.",
    safe_nvt.step_count,
)
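The ordering rule used in this section (stages run in a fixed sequence, and hooks sharing a stage fire in registration order) can be modelled with a toy dispatcher. Everything below is a hypothetical sketch: the `Stage` enum reuses stage names mentioned on this page, but `HookRunner` and the tuple-based hook representation are inventions for illustration, not the library's classes.

```python
from enum import Enum, auto


class Stage(Enum):
    BEFORE_COMPUTE = auto()
    AFTER_COMPUTE = auto()
    AFTER_STEP = auto()


class HookRunner:
    """Fires hooks stage by stage, in registration order within each stage."""

    def __init__(self, hooks):
        self.hooks = list(hooks)  # list order is registration order

    def fire(self, stage, log):
        for name, hook_stage in self.hooks:
            if hook_stage is stage:
                log.append(name)


# Registered in a deliberately scrambled order to show what matters and what does not.
hooks = [
    ("clamp", Stage.AFTER_COMPUTE),
    ("drift", Stage.AFTER_STEP),
    ("nan_check", Stage.AFTER_COMPUTE),
    ("neighbor_list", Stage.BEFORE_COMPUTE),
]
runner = HookRunner(hooks)
fired = []
for stage in Stage:  # Enum iteration follows definition order
    runner.fire(stage, fired)
```

In this model `fired` ends up as neighbor_list, clamp, nan_check, drift. The stage decides the position across stages, so registration order matters only between hooks that share a stage, such as the clamp and the NaN check at AFTER_COMPUTE.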