Multi-Criteria Convergence with Custom Operators

ConvergenceHook detects when geometry optimisation or MD has reached a desired stopping condition and removes converged systems from the active batch so subsequent steps spend compute only on unconverged systems.
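Removing converged systems amounts to boolean indexing on a flat, per-atom layout. The sketch below assumes the same layout the logging hook later on this page relies on: per-atom tensors concatenated across systems, plus a `[N]` graph-index vector (`batch.batch`). `drop_converged` is a hypothetical helper that illustrates the idea. It is not how nvalchemi implements pruning internally.

```python
from __future__ import annotations

import torch


def drop_converged(
    positions: torch.Tensor,  # [N, 3] per-atom data, all systems concatenated
    batch_idx: torch.Tensor,  # [N] graph index of each atom (int64)
    converged: torch.Tensor,  # [B] bool, one flag per system
) -> tuple[torch.Tensor, torch.Tensor]:
    """Keep only atoms of unconverged systems and renumber graphs 0..B'-1."""
    keep_sys = ~converged                 # [B] systems that stay active
    keep_atoms = keep_sys[batch_idx]      # [N] broadcast system flag to its atoms
    # New index of each kept system = number of kept systems up to and including it, minus 1.
    new_sys_id = torch.cumsum(keep_sys.long(), dim=0) - 1  # [B]
    new_batch_idx = new_sys_id[batch_idx[keep_atoms]]
    return positions[keep_atoms], new_batch_idx


# Three systems with 2, 3 and 1 atoms; the middle one has converged.
pos = torch.arange(18, dtype=torch.float32).reshape(6, 3)
bidx = torch.tensor([0, 0, 1, 1, 1, 2])
conv = torch.tensor([False, True, False])
new_pos, new_bidx = drop_converged(pos, bidx, conv)
print(new_bidx.tolist())  # [0, 0, 1]
```

The remaining systems are renumbered contiguously, so later scatter reductions can use a smaller `num_graphs`.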

A convergence hook holds a list of criteria, each of which evaluates a tensor attribute of the batch and returns a per-system boolean. A system is declared converged when every criterion is satisfied simultaneously. This AND-logic prevents false positives: a system with near-zero forces but still large energy fluctuations is not yet converged.
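The AND rule reduces to an elementwise `all` over the stacked per-system masks. The following minimal sketch uses hand-made masks for illustration; they are not nvalchemi output. System 1 has small forces but fails the energy criterion, so it is not converged.

```python
import torch

# Per-system results of two criteria for a batch of B = 3 systems.
forces_ok = torch.tensor([True, True, False])  # max force norm below threshold
energy_ok = torch.tensor([True, False, True])  # energy criterion satisfied

# A system is converged only if every criterion holds: AND across criteria.
converged = torch.stack([forces_ok, energy_ok]).all(dim=0)
print(converged.tolist())  # [True, False, False]
```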

Key concepts demonstrated

  • ConvergenceHook.from_forces(threshold) — the simplest one-liner.

  • Multi-criteria convergence combining force norm and a second criterion.

  • custom_op — a callable that receives the raw tensor and returns a [B] bool mask, used here to implement an energy-change criterion.

  • Combining force-norm and energy-change criteria in a FIRE optimisation.

from __future__ import annotations

import logging

import torch

from nvalchemi.data import AtomicData, Batch
from nvalchemi.dynamics import FIRE
from nvalchemi.dynamics.base import ConvergenceHook, HookStageEnum
from nvalchemi.dynamics.hooks import NeighborListHook
from nvalchemi.models.lj import LennardJonesModelWrapper

logging.basicConfig(level=logging.INFO)

Lennard-Jones model and cluster helper

LJ_EPSILON = 0.0104
LJ_SIGMA = 3.40
LJ_CUTOFF = 8.5
_R_MIN = 2 ** (1 / 6) * LJ_SIGMA

model = LennardJonesModelWrapper(
    epsilon=LJ_EPSILON,
    sigma=LJ_SIGMA,
    cutoff=LJ_CUTOFF,
    max_neighbors=32,
)


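def _make_cluster(
    n_per_side: int = 2, spacing_factor: float = 1.05, seed: int = 0
) -> AtomicData:
    """Build an n_per_side**3 argon cluster on a jittered cubic grid.

    The grid spacing is ``spacing_factor`` times the LJ pair minimum ``_R_MIN``.
    A seeded Gaussian jitter (std 0.05 Å) breaks the lattice symmetry.
    """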
    n = n_per_side**3
    spacing = _R_MIN * spacing_factor
    coords = torch.arange(n_per_side, dtype=torch.float32) * spacing
    gx, gy, gz = torch.meshgrid(coords, coords, coords, indexing="ij")
    positions = torch.stack([gx.flatten(), gy.flatten(), gz.flatten()], dim=-1)
    torch.manual_seed(seed)
    positions = positions + 0.05 * torch.randn_like(positions)
    return AtomicData(
        positions=positions,
        atomic_numbers=torch.full((n,), 18, dtype=torch.long),
        forces=torch.zeros(n, 3),
        energies=torch.zeros(1, 1),
        velocities=torch.zeros(n, 3),
    )

Built-in convergence: from_forces factory

ConvergenceHook.from_forces(threshold) is the standard one-liner for geometry optimisation. It reads the "forces" key from the batch, computes the per-atom Euclidean norm, takes the max over atoms within each graph (scatter reduce), and declares convergence when that max norm ≤ threshold.
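The reduction described above can be written in a few lines of plain PyTorch, which may help build intuition. The sketch assumes a flat `[N, 3]` forces tensor with an int64 `[N]` graph index, and it uses `<=` as stated in the text. `max_force_converged` is a hypothetical name and is not part of nvalchemi.

```python
import torch


def max_force_converged(
    forces: torch.Tensor,     # [N, 3] per-atom forces
    batch_idx: torch.Tensor,  # [N] graph index of each atom (int64)
    num_graphs: int,
    threshold: float,
) -> torch.Tensor:
    """Return a [B] bool mask: max per-atom force norm <= threshold."""
    atom_norm = forces.norm(dim=-1)  # [N] Euclidean norm per atom
    fmax = torch.zeros(num_graphs, dtype=forces.dtype, device=forces.device)
    # Scatter-max into one slot per graph; norms are >= 0, so a zero start is safe.
    fmax.scatter_reduce_(0, batch_idx, atom_norm, reduce="amax", include_self=True)
    return fmax <= threshold


forces = torch.tensor([
    [0.00, 0.00, 0.02],  # |F| = 0.02, system 0
    [0.00, 0.03, 0.00],  # |F| = 0.03, system 0
    [0.00, 0.30, 0.40],  # |F| = 0.50, system 1
])
mask = max_force_converged(forces, torch.tensor([0, 0, 1]), num_graphs=2, threshold=0.05)
print(mask.tolist())  # [True, False]
```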

simple_hook = ConvergenceHook.from_forces(threshold=0.05)
print("Simple hook:", simple_hook)
Simple hook: ConvergenceHook(criteria=[_ConvergenceCriterion(key='forces', threshold=0.05, reduce_op='norm', reduce_dims=-1)], frequency=1)

Multi-criteria convergence

Pass a list of criterion dicts. Each dict maps to a _ConvergenceCriterion.

reduce_op="norm" computes the per-atom vector norm along reduce_dims=-1 (the last, Cartesian axis), giving one scalar per atom (shape [N]). The criterion then scatter-reduces these values to graph level with max before comparing them to threshold.

You can add as many criteria as you like; all must be satisfied.

dual_hook = ConvergenceHook(
    criteria=[
        # Force criterion: max per-atom force norm ≤ 0.05 eV/Å.
        {
            "key": "forces",
            "threshold": 0.05,
            "reduce_op": "norm",
            "reduce_dims": -1,
        },
        # Energy criterion: per-system energy ≤ -0.1 eV (cluster is bound).
        # This guards against spurious convergence at high energy.
        {
            "key": "energies",
            "threshold": -0.1,
        },
    ]
)
print("Dual hook:", dual_hook)
Dual hook: ConvergenceHook(criteria=[_ConvergenceCriterion(key='forces', threshold=0.05, reduce_op='norm', reduce_dims=-1), _ConvergenceCriterion(key='energies', threshold=-0.1)], frequency=1)

Custom operator: energy-change criterion

Built-in reduce_op values cover many cases, but sometimes you need arbitrary logic. custom_op receives the raw tensor for that key (whatever shape the batch stores it in) and must return a [B] bool tensor.
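The contract is easiest to see with a stateless operator. The sketch below re-expresses the bound-cluster energy test from the dual hook as a single-argument callable that maps `[B, 1]` to `[B]`. The function name and the default -0.1 eV cutoff are illustrative, and the `<=` comparison is an assumption carried over from the energy criterion above.

```python
import torch


def below_energy(energies: torch.Tensor, cutoff: float = -0.1) -> torch.Tensor:
    """Stateless custom_op sketch: [B, 1] energies -> [B] bool mask."""
    return energies.squeeze(-1) <= cutoff


e = torch.tensor([[-0.20], [-0.05]])  # shape [B, 1] with B = 2
mask = below_energy(e)
print(mask.shape, mask.tolist())  # torch.Size([2]) [True, False]
```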

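Here we implement a relative energy-change criterion: a system is converged when \(|\Delta E / E| < \varepsilon\) between consecutive steps, where \(E\) is the previous step's energy and \(\varepsilon = 10^{-4}\). This requires state: the previous energies are kept in a module-level dict that the function reads and updates on every call.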

The custom_op callable is called with the full batch.energies tensor of shape [B, 1].
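Written out per system, the test that energy_change_criterion applies between step \(t-1\) and step \(t\) divides by the previous energy, clamped away from zero:

\[
\frac{\lvert E_t - E_{t-1} \rvert}{\max\left(\lvert E_{t-1} \rvert,\ 10^{-12}\right)} < \varepsilon, \qquad \varepsilon = 10^{-4}
\]

As an illustrative example (numbers chosen by hand, not taken from the run below), a step from \(-0.19725\) eV to \(-0.19730\) eV gives \(5 \times 10^{-5} / 0.19725 \approx 2.5 \times 10^{-4}\). This is still above \(\varepsilon\), so that system would not yet count as converged.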

prev_energies: dict[str, torch.Tensor] = {}  # mutable module-level state read and updated by the criterion


def energy_change_criterion(energies: torch.Tensor) -> torch.Tensor:
    """Return True for systems whose relative energy change is < 1e-4.

    Parameters
    ----------
    energies : torch.Tensor
        Shape ``[B, 1]`` — per-system total energies in eV.

    Returns
    -------
    torch.Tensor
        Shape ``[B]`` boolean — True where |ΔE/E| < 1e-4.
    """
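    e = energies.squeeze(-1)  # [B]
    last = prev_energies.get("last")
    if last is None or last.shape != e.shape:
        # First call, or the number of active systems changed (e.g. converged
        # systems were removed): no aligned history, so treat all as unconverged.
        prev_energies["last"] = e.detach().clone()
        return torch.zeros(e.shape[0], dtype=torch.bool, device=e.device)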

    delta = (e - prev_energies["last"]).abs()
    denom = prev_energies["last"].abs().clamp(min=1e-12)
    rel_change = delta / denom
    prev_energies["last"] = e.detach().clone()
    return rel_change < 1e-4


custom_hook = ConvergenceHook(
    criteria=[
        {
            "key": "energies",
            "threshold": 0.0,  # threshold is ignored when custom_op is set
            "custom_op": energy_change_criterion,
        }
    ]
)
print("Custom hook:", custom_hook)
Custom hook: ConvergenceHook(criteria=[_ConvergenceCriterion(key='energies', custom_op=energy_change_criterion)], frequency=1)
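The module-level dict works, but every hook that uses the function shares it, and it must be cleared by hand between runs (as the practical example below does). An alternative keeps the history on an object instead. This sketch assumes that `custom_op` accepts any callable object, since the text only calls it "a callable". `RelativeEnergyChange` is a hypothetical name and is not part of nvalchemi. The sketch also resets its history whenever the number of systems changes, on the assumption that converged systems may be dropped between calls.

```python
from __future__ import annotations

import torch


class RelativeEnergyChange:
    """Stateful custom_op sketch: True where |E_t - E_{t-1}| / |E_{t-1}| < tol."""

    def __init__(self, tol: float = 1e-4) -> None:
        self.tol = tol
        self.last: torch.Tensor | None = None  # energies from the previous call

    def reset(self) -> None:
        """Forget the history, e.g. before starting a new run."""
        self.last = None

    def __call__(self, energies: torch.Tensor) -> torch.Tensor:
        e = energies.detach().squeeze(-1)  # [B, 1] -> [B]
        if self.last is None or self.last.shape != e.shape:
            # No comparable history: store it and report "not converged".
            self.last = e.clone()
            return torch.zeros_like(e, dtype=torch.bool)
        rel = (e - self.last).abs() / self.last.abs().clamp(min=1e-12)
        self.last = e.clone()
        return rel < self.tol


op = RelativeEnergyChange(tol=1e-4)
print(op(torch.tensor([[-1.0], [-2.0]])).tolist())      # [False, False]
print(op(torch.tensor([[-1.00001], [-2.1]])).tolist())  # [True, False]
```

Each hook can then receive its own instance, for example `"custom_op": RelativeEnergyChange()`, so two hooks never share history.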

Practical example: dual force + energy-change convergence

Combine a force-norm criterion with the energy-change criterion so that FIRE stops only when the forces are small AND the energy has stopped changing.

The energy-change guard prevents an early exit when the optimizer passes through a configuration with near-zero forces while it still carries large velocities (momentum).

print("\n=== FIRE with dual force+energy-change convergence ===")

# Reset the shared module-level state so this run starts without history.
prev_energies.clear()

data_list = [
    _make_cluster(2, spacing_factor=1.05, seed=0),
    _make_cluster(2, spacing_factor=1.20, seed=1),
]
batch = Batch.from_data_list(data_list)
print(f"Batch: {batch.num_graphs} systems, {batch.num_nodes} atoms total\n")

dual_custom_hook = ConvergenceHook(
    criteria=[
        {
            "key": "forces",
            "threshold": 0.01,
            "reduce_op": "norm",
            "reduce_dims": -1,
        },
        {
            "key": "energies",
            "threshold": 0.0,
            "custom_op": energy_change_criterion,
        },
    ]
)

fire = FIRE(
    model=model,
    dt=0.5,
    n_steps=500,
    convergence_hook=dual_custom_hook,
)
fire.register_hook(NeighborListHook(model.model_card.neighbor_config))


class _LogHook:
    """Log energy and fmax every 50 steps."""

    stage = HookStageEnum.AFTER_STEP
    frequency = 50

    def __call__(self, batch: Batch, dynamics) -> None:
        step = dynamics.step_count + 1
        energies = batch.energies.squeeze(-1)
        fmax_per_sys = torch.zeros(batch.num_graphs, device=batch.device)
        fmax_per_sys.scatter_reduce_(
            0, batch.batch, batch.forces.norm(dim=-1), reduce="amax", include_self=True
        )
        rows = [
            f"  sys{i}: E={energies[i].item():+.5f} eV  fmax={fmax_per_sys[i].item():.5f} eV/Å"
            for i in range(batch.num_graphs)
        ]
        print(f"[step {step:4d}]\n" + "\n".join(rows))


fire.register_hook(_LogHook())
batch = fire.run(batch)
print(f"\nCompleted {fire.step_count} FIRE steps (dual convergence).")

final_energies = batch.energies.squeeze(-1)
for i in range(batch.num_graphs):
    print(f"  sys{i}: final E = {final_energies[i].item():+.6f} eV")
=== FIRE with dual force+energy-change convergence ===
Batch: 2 systems, 16 atoms total

[step    1]
  sys0: E=-0.14072 eV  fmax=0.01609 eV/Å
  sys1: E=-0.07711 eV  fmax=0.01232 eV/Å
[step   51]
  sys0: E=-0.15918 eV  fmax=0.00185 eV/Å
  sys1: E=-0.15530 eV  fmax=0.01766 eV/Å
[step  101]
  sys0: E=-0.19725 eV  fmax=0.00136 eV/Å
  sys1: E=-0.18039 eV  fmax=0.01052 eV/Å

Completed 143 FIRE steps (dual convergence).
  sys0: final E = -0.197351 eV
  sys1: final E = -0.196677 eV

Total running time of the script: (0 minutes 0.202 seconds)
