# Hooks: Observe & Modify
Hooks are the primary extensibility mechanism for dynamics simulations. They let you inject custom logic at any stage of the integration step without modifying the integrator itself.
## The Hook protocol
Any object that matches the `Hook` protocol can be registered:
```python
from nvalchemi.dynamics import Hook, HookStageEnum


class MyHook:
    """A minimal custom hook: no inheritance required."""

    frequency: int = 1  # fire on every step
    stage: HookStageEnum = HookStageEnum.AFTER_STEP  # fire at the end of each step

    def __call__(self, batch, dynamics):
        print(f"Step {dynamics.step_count}: energy = {batch.energies.mean():.4f}")
```
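Because `Hook` is a `runtime_checkable` `Protocol`, you can also use it as a type hint and check conformance with `isinstance`. The check is structural and shallow: it confirms that the required attributes exist, not their types or the `__call__` signature.

```python
assert isinstance(MyHook(), Hook)  # passes: MyHook has frequency, stage, and __call__
```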
Tip
No subclassing required. Because the protocol is structural, any class (even a frozen dataclass) that provides `frequency`, `stage`, and `__call__` works as a hook.
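To see the structural check in isolation, here is a self-contained sketch that mirrors the protocol with the standard library's `typing.Protocol`. The `Hook` and `HookStageEnum` defined below are illustrative stand-ins, not nvalchemi's actual definitions, and the real classes may declare more members. The example shows that a frozen dataclass passes the check while a class without `__call__` fails it.

```python
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Protocol, runtime_checkable


class HookStageEnum(IntEnum):
    """Stand-in enum, trimmed to one stage (values assumed from the stage table)."""
    AFTER_STEP = 7


@runtime_checkable
class Hook(Protocol):
    """Stand-in protocol: two data members plus __call__."""
    frequency: int
    stage: HookStageEnum

    def __call__(self, batch: Any, dynamics: Any) -> None: ...


@dataclass(frozen=True)
class PrintStepHook:
    """A frozen dataclass satisfies the protocol without inheriting from it."""
    frequency: int = 10
    stage: HookStageEnum = HookStageEnum.AFTER_STEP

    def __call__(self, batch: Any, dynamics: Any) -> None:
        print(f"step {dynamics.step_count}")


class NotAHook:
    """Has both data members but no __call__, so the check fails."""
    frequency = 1
    stage = HookStageEnum.AFTER_STEP


print(isinstance(PrintStepHook(), Hook))  # True
print(isinstance(NotAHook(), Hook))       # False
```

Being frozen also makes the hook's configuration immutable after construction, which can be handy when the same hook instance is shared between runs.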
## Hook stages
`HookStageEnum` defines nine insertion points that cover every phase of a dynamics step:
```text
BEFORE_STEP ───────────────────────────────────────────────────┐
│                                                              │
│  BEFORE_PRE_UPDATE  → pre_update()  → AFTER_PRE_UPDATE       │
│  BEFORE_COMPUTE     → compute()     → AFTER_COMPUTE          │
│  BEFORE_POST_UPDATE → post_update() → AFTER_POST_UPDATE      │
│                                                              │
AFTER_STEP ────────────────────────────────────────────────────┘
ON_CONVERGE  (fires only when convergence is detected)
```
| Stage | Value | When it fires |
|---|---|---|
| `BEFORE_STEP` | 0 | Very start of each step, before any operations. |
| `BEFORE_PRE_UPDATE` | 1 | Before the first integrator half-step (positions). |
| `AFTER_PRE_UPDATE` | 2 | After positions are updated, before the forward pass. |
| `BEFORE_COMPUTE` | 3 | Before the model forward pass. |
| `AFTER_COMPUTE` | 4 | After forces and energies are written to the batch. |
| `BEFORE_POST_UPDATE` | 5 | Before the second integrator half-step (velocities). |
| `AFTER_POST_UPDATE` | 6 | After velocities are updated. |
| `AFTER_STEP` | 7 | Very end of the step, after all operations. |
| `ON_CONVERGE` | 8 | Only when the convergence hook detects converged samples. |
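Because the values in this table increase in firing order, one standalone step can be modeled as an ordered walk through the enum. The sketch below re-declares the enum locally with the table's values. That is an assumption about the real class, which may not be an `IntEnum`. `stages_for_one_step` is a hypothetical helper, not part of nvalchemi.

```python
from enum import IntEnum
from typing import List


class HookStageEnum(IntEnum):
    """Local stand-in whose values mirror the stage table."""
    BEFORE_STEP = 0
    BEFORE_PRE_UPDATE = 1
    AFTER_PRE_UPDATE = 2
    BEFORE_COMPUTE = 3
    AFTER_COMPUTE = 4
    BEFORE_POST_UPDATE = 5
    AFTER_POST_UPDATE = 6
    AFTER_STEP = 7
    ON_CONVERGE = 8


def stages_for_one_step(converged: bool) -> List[HookStageEnum]:
    """Stages a standalone step passes through, in firing order."""
    order = [
        HookStageEnum.BEFORE_STEP,
        HookStageEnum.BEFORE_PRE_UPDATE,   # pre_update(): position half-step
        HookStageEnum.AFTER_PRE_UPDATE,
        HookStageEnum.BEFORE_COMPUTE,      # compute(): model forward pass
        HookStageEnum.AFTER_COMPUTE,
        HookStageEnum.BEFORE_POST_UPDATE,  # post_update(): velocity half-step
        HookStageEnum.AFTER_POST_UPDATE,
        HookStageEnum.AFTER_STEP,
    ]
    if converged:
        order.append(HookStageEnum.ON_CONVERGE)  # only when convergence is detected
    return order


print([stage.name for stage in stages_for_one_step(converged=False)])
```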
## Registration and execution
Hooks are registered either at construction or via `register_hook()`:

```python
from nvalchemi.dynamics import DemoDynamics
from nvalchemi.dynamics.hooks import LoggingHook, MaxForceClampHook, NaNDetectorHook

# `model` is assumed to be an already-constructed model instance.

# At construction (recommended for most cases)
dynamics = DemoDynamics(
    model=model,
    dt=0.5,
    hooks=[LoggingHook(frequency=100), NaNDetectorHook()],
)

# Or register later
dynamics.register_hook(MaxForceClampHook(max_force=50.0))
```
Hooks are dispatched by `BaseDynamics._call_hooks(stage, batch)`. At each stage, every hook registered for that stage fires in registration order, but only when `step_count % hook.frequency == 0`.
Note
At `step_count == 0` every hook passes the frequency check (since `0 % n == 0` for any `n`), so each hook fires at its stage on the first step. This makes step 0 a good point for initialization logic.
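The dispatch rule can be sketched in a few lines. `MiniDynamics`, `Stage`, and `RecordingHook` below are hypothetical stand-ins, not nvalchemi classes. The point in the step at which `step_count` advances is also an assumption. The sketch encodes only the two rules stated above: hooks at a stage fire in registration order, and only when `step_count % frequency == 0`.

```python
from enum import Enum
from typing import Any, Dict, List, Optional


class Stage(Enum):
    """Trimmed stand-in for HookStageEnum (only the stages used here)."""
    AFTER_COMPUTE = 4
    AFTER_STEP = 7


class RecordingHook:
    """Appends (name, step_count) to a shared log each time it fires."""

    def __init__(self, name: str, stage: Stage, frequency: int, log: List):
        self.name = name
        self.stage = stage
        self.frequency = frequency
        self.log = log

    def __call__(self, batch: Any, dynamics: Any) -> None:
        self.log.append((self.name, dynamics.step_count))


class MiniDynamics:
    """Hypothetical stand-in that applies the dispatch rule described above."""

    def __init__(self, hooks: Optional[List[Any]] = None):
        self.step_count = 0
        self._hooks: Dict[Stage, List[Any]] = {}
        for hook in hooks or []:
            self.register_hook(hook)

    def register_hook(self, hook: Any) -> None:
        # Group hooks by stage; list order preserves registration order.
        self._hooks.setdefault(hook.stage, []).append(hook)

    def _call_hooks(self, stage: Stage, batch: Any) -> None:
        for hook in self._hooks.get(stage, []):
            if self.step_count % hook.frequency == 0:
                hook(batch, self)

    def step(self, batch: Any) -> None:
        # Integrator work is omitted; only two dispatch points are kept.
        self._call_hooks(Stage.AFTER_COMPUTE, batch)
        self._call_hooks(Stage.AFTER_STEP, batch)
        self.step_count += 1  # assumed to advance after the step's hooks


log: List = []
dyn = MiniDynamics(hooks=[RecordingHook("clamp", Stage.AFTER_COMPUTE, 1, log),
                          RecordingHook("nan_check", Stage.AFTER_COMPUTE, 1, log)])
dyn.register_hook(RecordingHook("snapshot", Stage.AFTER_STEP, 3, log))
for _ in range(4):
    dyn.step(batch=None)
print(log)
```

In this trace the frequency-3 hook fires at steps 0 and 3, and step 0 fires every hook.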
## Built-in hooks reference
The `nvalchemi.dynamics.hooks` package ships eleven production-ready hooks organized into four categories.
### Pre-compute hooks (modify batch, fire at `BEFORE_COMPUTE`)
These hooks prepare the batch before the model forward pass.
| Hook | Purpose |
|---|---|
|  | Compute or refresh the neighbor list (…). |
### Observer hooks (read-only, fire at `AFTER_STEP`)
These hooks do not modify the batch — they record, log, or monitor simulation state.
| Hook | Purpose |
|---|---|
| `LoggingHook` | Log scalar observables (energy, fmax, temperature) to a configurable backend such as CSV. |
| `SnapshotHook` | Write the full batch state to a data sink. |
|  | Write only newly converged samples to a data sink. |
| `EnergyDriftMonitorHook` | Track cumulative energy drift in NVE runs; warn or halt on excessive drift. |
| `ProfilerHook` | Instrument steps with NVTX ranges and wall-clock timing for Nsight Systems profiling. Fires at … |
### Post-compute hooks (modify batch, fire at `AFTER_COMPUTE`)
These hooks inspect or modify the batch after the model forward pass and before the velocity update.
| Hook | Purpose |
|---|---|
| `NaNDetectorHook` | Detect NaN/Inf in forces and energies; raise with diagnostic info (affected graph indices, step count). |
| `MaxForceClampHook` | Clamp per-atom force magnitudes to a safe maximum, preserving force direction. Prevents numerical explosions. |
| `BiasedPotentialHook` | Add an external bias potential (energies + forces) for enhanced sampling: umbrella sampling, metadynamics, steered MD, harmonic restraints, wall potentials. |
|  | Wrap atomic positions back into the unit cell under PBC. Fires at … |
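Two of the batch modifications in this table are simple enough to sketch in plain Python: direction-preserving force clamping and periodic wrapping. The functions below are illustrative stand-ins operating on tuples, not the hooks' actual implementations. The wrapping sketch also assumes an orthorhombic box; a triclinic cell needs a detour through fractional coordinates.

```python
import math
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def clamp_forces(forces: Sequence[Vec3], max_force: float) -> List[Vec3]:
    """Rescale any per-atom force whose magnitude exceeds max_force.

    Direction is preserved because an offending vector is multiplied by
    the positive factor max_force / |F|.
    """
    clamped = []
    for fx, fy, fz in forces:
        norm = math.sqrt(fx * fx + fy * fy + fz * fz)
        if norm > max_force:  # False for NaN, so NaN forces pass through unchanged
            s = max_force / norm
            clamped.append((fx * s, fy * s, fz * s))
        else:
            clamped.append((fx, fy, fz))
    return clamped


def wrap_orthorhombic(positions: Sequence[Vec3], box: Vec3) -> List[Vec3]:
    """Map every coordinate into [0, L) for box edge lengths L."""
    return [tuple(x % length for x, length in zip(p, box)) for p in positions]


print(clamp_forces([(3.0, 4.0, 0.0), (0.1, 0.0, 0.0)], max_force=1.0))
print(wrap_orthorhombic([(-0.5, 10.5, 3.0)], (10.0, 10.0, 10.0)))
```

In this sketch a NaN force survives clamping untouched, because `NaN > max_force` is `False`. A NaN check registered after the clamp would therefore still see it.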
### Constraint hooks (modify batch, fire at `BEFORE_PRE_UPDATE`)
These hooks enforce geometric constraints across integration steps.
| Hook | Purpose |
|---|---|
|  | Freeze atoms by category (e.g. substrate, boundary). Snapshots positions at … |
## Usage examples
### Logging to CSV every 100 steps
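```python
from nvalchemi.dynamics import DemoDynamics
from nvalchemi.dynamics.hooks import LoggingHook

# `model` and `batch` are assumed to be defined already.
hook = LoggingHook(frequency=100, backend="csv", log_path="md_log.csv")
dynamics = DemoDynamics(model=model, n_steps=10_000, dt=0.5, hooks=[hook])
dynamics.run(batch)
```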
### Recording trajectories to a data sink
```python
from nvalchemi.dynamics import DemoDynamics, HostMemory
from nvalchemi.dynamics.hooks import SnapshotHook

sink = HostMemory(capacity=10_000)
hook = SnapshotHook(sink=sink, frequency=10)
dynamics = DemoDynamics(model=model, n_steps=1_000, dt=0.5, hooks=[hook])
dynamics.run(batch)  # 1_000 steps at frequency 10 -> 100 snapshots
trajectory = sink.read()
```
### Safety: NaN detection + force clamping
```python
from nvalchemi.dynamics import DemoDynamics
from nvalchemi.dynamics.hooks import MaxForceClampHook, NaNDetectorHook

dynamics = DemoDynamics(
    model=model,
    dt=0.5,
    hooks=[
        # Clamp first, then check: both fire at AFTER_COMPUTE,
        # in registration order.
        MaxForceClampHook(max_force=50.0, log_clamps=True),
        NaNDetectorHook(extra_keys=["stresses"]),  # also check stresses
    ],
)
```
### Enhanced sampling with a bias potential
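```python
from nvalchemi.dynamics import DemoDynamics
from nvalchemi.dynamics.hooks import BiasedPotentialHook


def harmonic_restraint(batch):
    """Restrain the centroid (unweighted mean position) of all atoms to the origin.

    Treats the whole batch as a single system; returns (bias_energy, bias_forces).
    """
    k = 10.0  # eV/Å²
    centroid = batch.positions.mean(dim=0, keepdim=True)  # shape (1, 3)
    # E = 0.5 * k * |c|^2, returned with shape (1, 1)
    bias_energy = 0.5 * k * (centroid ** 2).sum().unsqueeze(0).unsqueeze(0)
    # F_i = -dE/dr_i = -k * c / N, identical for every atom
    bias_forces = -k * centroid.expand_as(batch.positions) / batch.num_nodes
    return bias_energy, bias_forces


hook = BiasedPotentialHook(bias_fn=harmonic_restraint)
dynamics = DemoDynamics(model=model, dt=0.5, hooks=[hook])
```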
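The forces in `harmonic_restraint` follow from E = ½k|c|² with c = (1/N)Σrᵢ. Differentiating gives Fᵢ = -∂E/∂rᵢ = -k c / N for every atom. The pure-Python sketch below checks that derivation against a central finite difference. It uses plain lists instead of tensors and, like the example above, treats all atoms as one system.

```python
from typing import List, Tuple

Vec3 = List[float]


def centroid_restraint(positions: List[Vec3], k: float) -> Tuple[float, List[Vec3]]:
    """E = 0.5 * k * |c|^2 with c the unweighted centroid; returns (E, forces)."""
    n = len(positions)
    c = [sum(p[d] for p in positions) / n for d in range(3)]
    energy = 0.5 * k * sum(x * x for x in c)
    force = [-k * x / n for x in c]  # same force on every atom
    return energy, [list(force) for _ in range(n)]


def finite_difference_force(positions: List[Vec3], k: float, atom: int, dim: int,
                            h: float = 1e-6) -> float:
    """Central-difference estimate of -dE/dr[atom][dim]."""
    plus = [list(p) for p in positions]
    minus = [list(p) for p in positions]
    plus[atom][dim] += h
    minus[atom][dim] -= h
    e_plus, _ = centroid_restraint(plus, k)
    e_minus, _ = centroid_restraint(minus, k)
    return -(e_plus - e_minus) / (2 * h)


pos = [[1.0, 0.5, -0.2], [0.3, -1.1, 0.8], [2.0, 0.4, 0.1]]
_, forces = centroid_restraint(pos, k=10.0)
numeric = finite_difference_force(pos, k=10.0, atom=1, dim=0)
print(forces[1][0], numeric)  # the two values should agree closely
```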
### Profiling with Nsight Systems
```python
from nvalchemi.dynamics import DemoDynamics
from nvalchemi.dynamics.hooks import ProfilerHook

hook = ProfilerHook(enable_nvtx=True, enable_timer=True, frequency=10)
dynamics = DemoDynamics(model=model, n_steps=1_000, dt=0.5, hooks=[hook])
# Run under: nsys profile python my_script.py
dynamics.run(batch)
```
### NVE energy drift monitoring
```python
from nvalchemi.dynamics import DemoDynamics
from nvalchemi.dynamics.hooks import EnergyDriftMonitorHook

hook = EnergyDriftMonitorHook(
    threshold=1e-5,
    metric="per_atom_per_step",
    action="raise",  # or "warn" for production
    frequency=100,
)
dynamics = DemoDynamics(model=model, dt=0.5, hooks=[hook])
```
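The exact formula behind `metric="per_atom_per_step"` is not spelled out here. One plausible reading, sketched below as an assumption, is the absolute change in total energy since a reference step, divided by the atom count and the number of elapsed steps. Both helper functions are hypothetical. `check_drift` only mimics the `"warn"`/`"raise"` choice of the `action` option.

```python
import warnings


def per_atom_per_step_drift(e_ref: float, e_now: float, n_atoms: int, n_steps: int) -> float:
    """Hypothetical metric: |E_now - E_ref| / (n_atoms * n_steps).

    e_ref and e_now are total (kinetic + potential) energies of one system;
    n_steps is the number of steps elapsed since the reference.
    """
    if n_atoms <= 0 or n_steps <= 0:
        raise ValueError("n_atoms and n_steps must be positive")
    return abs(e_now - e_ref) / (n_atoms * n_steps)


def check_drift(drift: float, threshold: float, action: str = "warn") -> bool:
    """Return True when drift is within threshold; otherwise warn or raise."""
    if drift <= threshold:
        return True
    message = f"energy drift {drift:.3e} exceeds threshold {threshold:.1e}"
    if action == "raise":
        raise RuntimeError(message)
    warnings.warn(message)
    return False


# 64 atoms whose total energy moved by 0.1 eV over 100 steps:
drift = per_atom_per_step_drift(-500.0, -499.9, n_atoms=64, n_steps=100)
print(f"{drift:.4e}")  # 1.5625e-05, above a 1e-5 threshold
```

Normalizing by atoms and steps makes one threshold usable across system sizes and run lengths, which is presumably why a per-atom-per-step metric is offered.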
### Custom scalars via `LoggingHook`
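```python
from nvalchemi.dynamics.hooks import LoggingHook


def pressure(batch, dynamics):
    """Compute instantaneous pressure from the virial."""
    # compute_pressure is a placeholder for a user-supplied helper
    # (not defined in this example).
    return compute_pressure(batch.stresses, batch.cell)


hook = LoggingHook(
    frequency=50,
    custom_scalars={"pressure": pressure},  # name -> callable(batch, dynamics)
)
```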
### Writing a custom hook from scratch
```python
from nvalchemi.dynamics import HookStageEnum


class VelocityRescaleHook:
    """Rescale velocities toward a target temperature (Berendsen-like)."""

    frequency: int
    stage = HookStageEnum.AFTER_POST_UPDATE

    def __init__(self, target_temp: float, tau: float, frequency: int = 1):
        self.target_temp = target_temp
        self.tau = tau  # coupling time, same units as dynamics.dt
        self.frequency = frequency

    def __call__(self, batch, dynamics):
        # compute_temperature is a user-supplied helper (not defined here).
        current_temp = compute_temperature(batch)
        # Berendsen factor: lambda = sqrt(1 + (dt / tau) * (T0 / T - 1))
        scale = (1.0 + (dynamics.dt / self.tau)
                 * (self.target_temp / current_temp - 1.0)) ** 0.5
        batch.velocities.mul_(scale)  # in-place rescale
```
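The scale factor above follows the Berendsen weak-coupling rule. Kinetic temperature is proportional to v², so scaling velocities by λ gives T' = λ²T = T + (dt/τ)(T0 - T). The gap to the target therefore shrinks by a factor (1 - dt/τ) per call. The plain-Python sketch below works with temperatures directly instead of velocities; `berendsen_lambda` and `relax` are illustrative helpers, not library functions.

```python
import math


def berendsen_lambda(current_temp: float, target_temp: float, dt: float, tau: float) -> float:
    """Velocity scale factor sqrt(1 + (dt / tau) * (T0 / T - 1))."""
    if current_temp <= 0.0:
        raise ValueError("current temperature must be positive")
    return math.sqrt(1.0 + (dt / tau) * (target_temp / current_temp - 1.0))


def relax(temp: float, target_temp: float, dt: float, tau: float, n_steps: int) -> float:
    """Apply the rescale n_steps times; temperature scales as lambda**2."""
    for _ in range(n_steps):
        temp *= berendsen_lambda(temp, target_temp, dt, tau) ** 2
    return temp


print(relax(600.0, 300.0, dt=0.5, tau=100.0, n_steps=1))   # about 598.5
print(relax(600.0, 300.0, dt=0.5, tau=100.0, n_steps=50))
```

If the hook runs with `frequency > 1`, the interval between rescalings is `frequency * dt` rather than `dt`. Using that product in the factor keeps the effective coupling time close to `tau`.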
## Hooks inside `FusedStage`
When hooks are registered on sub-stage dynamics inside a `FusedStage`, their firing semantics differ from standalone execution, because all sub-stages share a single forward pass:
- Fired on each sub-stage: `BEFORE_STEP`, `AFTER_COMPUTE`, `BEFORE_PRE_UPDATE`, `AFTER_POST_UPDATE`, `AFTER_STEP`, `ON_CONVERGE`.
- Not fired on sub-stages (because the forward pass is shared): `BEFORE_COMPUTE`, `AFTER_PRE_UPDATE`, `BEFORE_POST_UPDATE`.
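This means safety hooks (`NaNDetectorHook`, `MaxForceClampHook`) and observer hooks (`LoggingHook`, `SnapshotHook`) work as expected inside fused stages, since they fire at `AFTER_COMPUTE` or `AFTER_STEP`. Conversely, hooks that fire at `BEFORE_COMPUTE`, such as the pre-compute hooks, do not fire when registered on a sub-stage.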
Hook ordering inside a fused step:
```text
for each sub-stage:
    BEFORE_STEP hooks
── single compute() ──
for each sub-stage:
    AFTER_COMPUTE hooks
for each sub-stage:
    BEFORE_PRE_UPDATE hooks
    masked_update()   (if any samples match status)
    AFTER_POST_UPDATE hooks
for each sub-stage:
    AFTER_STEP hooks
for each sub-stage:
    convergence check → ON_CONVERGE hooks
```
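The fused ordering can be replayed as a small pure-Python trace. The sub-stage names and the dictionary layout are hypothetical. The sketch encodes only which stages fire and in what order, and it assumes every sub-stage has samples to update.

```python
from typing import Dict, List, Sequence


def fused_step_trace(substages: Dict[str, Sequence[str]],
                     converged: Sequence[str]) -> List[str]:
    """Return the events of one fused step as 'substage:STAGE' strings.

    substages maps each sub-stage name to the stages its hooks are
    registered at; converged lists sub-stages whose samples converged.
    """
    trace: List[str] = []

    def fire(name: str, stage: str) -> None:
        # Record a hook event only if this sub-stage registered hooks there.
        if stage in substages[name]:
            trace.append(f"{name}:{stage}")

    for name in substages:
        fire(name, "BEFORE_STEP")
    trace.append("compute()")  # one forward pass shared by all sub-stages
    for name in substages:
        fire(name, "AFTER_COMPUTE")
    for name in substages:
        fire(name, "BEFORE_PRE_UPDATE")
        trace.append(f"{name}:masked_update()")  # assumes matching samples exist
        fire(name, "AFTER_POST_UPDATE")
    for name in substages:
        fire(name, "AFTER_STEP")
    for name in substages:
        if name in converged:
            fire(name, "ON_CONVERGE")
    # BEFORE_COMPUTE, AFTER_PRE_UPDATE and BEFORE_POST_UPDATE are never fired.
    return trace


trace = fused_step_trace(
    {"relax": ["BEFORE_COMPUTE", "AFTER_COMPUTE", "AFTER_STEP"],
     "md": ["AFTER_COMPUTE", "ON_CONVERGE"]},
    converged=["md"],
)
print(trace)
```

The `BEFORE_COMPUTE` hook registered on `relax` never appears in the trace, which matches the list of stages not fired on sub-stages.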