Source code for nvalchemi.dynamics.hooks.bias
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Biased potential hooks for enhanced sampling workflows.

Provides :class:`BiasedPotentialHook`, which adds an external bias
potential to the forces and energies computed by the ML model.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from nvalchemi.dynamics.hooks._base import _PostComputeHook
if TYPE_CHECKING:
from collections.abc import Callable
from nvalchemi._typing import Energy, Forces
from nvalchemi.data import Batch
from nvalchemi.dynamics.base import BaseDynamics
# Public API of this module: the only name exported by a star-import.
__all__ = [
    "BiasedPotentialHook",
]
[docs]
class BiasedPotentialHook(_PostComputeHook):
    """Add an external bias potential to forces and energies after the forward pass.

    This hook enables enhanced sampling techniques by composing an
    arbitrary bias potential on top of the ML potential **without**
    modifying the model itself. The bias is added to ``batch.forces``
    and ``batch.energies`` at :attr:`~HookStageEnum.AFTER_COMPUTE`
    (in place by default, see ``inplace``), keeping the model output
    pure and the bias fully composable.

    The bias is specified via a callable ``bias_fn`` with signature::

        bias_fn(batch: Batch) -> tuple[Tensor, Tensor]

    It returns ``(bias_energy, bias_forces)`` where:

    - ``bias_energy``: ``Float[Tensor, "B 1"]`` — per-graph energy bias
    - ``bias_forces``: ``Float[Tensor, "V 3"]`` — per-atom force bias

    Both tensors must have exactly the same shapes as ``batch.energies``
    and ``batch.forces``; no broadcasting is allowed. The hook adds the
    bias terms to the existing batch values::

        batch.energies += bias_energy
        batch.forces += bias_forces

    This design supports a wide range of enhanced sampling methods:

    * **Harmonic restraints** — bias that penalizes deviation from a
      reference geometry (e.g. collective variable restraints).
    * **Umbrella sampling** — harmonic bias along a reaction
      coordinate, with the umbrella window center parameterizing
      ``bias_fn``.
    * **Metadynamics** — time-dependent Gaussian bias deposited along
      collective variables. ``bias_fn`` would maintain internal state
      (deposited Gaussians) and compute the accumulated bias.
    * **Steered MD** — time-dependent bias that pulls the system along
      a path. ``bias_fn`` is called with the batch only, so any time
      dependence (e.g. the current step count) must come from state the
      callable holds itself, such as a reference to the dynamics engine.
    * **Wall potentials** — repulsive bias that confines atoms to a
      region of space (e.g. preventing evaporation from a surface).

    Parameters
    ----------
    bias_fn : Callable[[Batch], tuple[Tensor, Tensor]]
        A callable that computes the bias energy and forces given the
        current batch. Must return a tuple of
        ``(bias_energy, bias_forces)`` whose shapes equal
        ``batch.energies.shape`` and ``batch.forces.shape`` (documented
        as ``(B, 1)`` and ``(V, 3)``), on the same device as the batch.
    frequency : int, optional
        Apply the bias every ``frequency`` steps. Default ``1``
        (every step).
    inplace : bool, optional
        If True (default), the bias is added to ``batch.energies`` and
        ``batch.forces`` in place with ``Tensor.add_``. If False, new
        tensors ``batch.energies + bias_energy`` and
        ``batch.forces + bias_forces`` are assigned back onto the batch,
        leaving the previous tensor objects unmodified.

    Attributes
    ----------
    bias_fn : Callable
        The bias potential function.
    frequency : int
        Bias application frequency in steps.
    inplace : bool
        Whether the bias is added in place or by replacing the tensors.
    stage : HookStageEnum
        Fixed to ``AFTER_COMPUTE``.

    Examples
    --------
    Harmonic restraint on the center of mass of a single-graph batch
    (the energy bias below has shape ``(1, 1)``, so it only matches
    ``batch.energies`` when the batch holds one graph):

    >>> import torch
    >>> from nvalchemi.dynamics.hooks import BiasedPotentialHook
    >>> def harmonic_restraint(batch):
    ...     # Restrain center of mass to origin with k=10 eV/A^2
    ...     k = 10.0
    ...     com = batch.positions.mean(dim=0, keepdim=True)
    ...     bias_energy = 0.5 * k * (com ** 2).sum().unsqueeze(0).unsqueeze(0)
    ...     bias_forces = -k * com.expand_as(batch.positions) / batch.num_nodes
    ...     return bias_energy, bias_forces
    >>> hook = BiasedPotentialHook(bias_fn=harmonic_restraint)

    Notes
    -----
    * The ``bias_fn`` is called **after** the model forward pass, so
      it has access to the model-computed forces and energies via the
      batch if needed (e.g. for force-matching penalties).
    * Because the bias is folded into ``batch.forces``, it interacts
      correctly with :class:`MaxForceClampHook` — register the clamp
      hook **after** the bias hook to clamp the total (model + bias)
      forces.
    * For metadynamics, the ``bias_fn`` closure should hold a reference
      to a mutable state object (e.g. a list of deposited Gaussians)
      that is updated externally or within the callable.
    * The hook does not differentiate ``bias_energy``. It adds whatever
      forces ``bias_fn`` returns, so those forces must be consistent
      with the bias energy (i.e. ``-dE_bias/dr``). If the model uses
      conservative forces (``forces_via_autograd=True``), the bias
      forces are added after ``torch.autograd.grad`` has already
      computed the model forces.
    """

    def __init__(
        self,
        bias_fn: Callable[[Batch], tuple[Energy, Forces]],
        frequency: int = 1,
        inplace: bool = True,
    ) -> None:
        """Store the bias callable and application mode.

        Parameters
        ----------
        bias_fn : Callable[[Batch], tuple[Energy, Forces]]
            Callable returning ``(bias_energy, bias_forces)`` for a batch.
        frequency : int, optional
            Apply the bias every ``frequency`` steps (forwarded to the
            base hook). Default ``1``.
        inplace : bool, optional
            Add the bias in place (True, default) or replace the batch
            tensors with new ones (False).
        """
        super().__init__(frequency=frequency)
        self.bias_fn = bias_fn
        self.inplace = inplace

    def __call__(self, batch: Batch, dynamics: BaseDynamics) -> None:
        """Compute the bias potential and add it to forces and energies.

        Parameters
        ----------
        batch : Batch
            The current batch of atomic data. ``batch.energies`` and
            ``batch.forces`` receive the bias. They are mutated in place
            when ``self.inplace`` is True; otherwise they are replaced
            with new tensors via ``batch["energies"]`` and
            ``batch["forces"]``.
        dynamics : BaseDynamics
            The dynamics engine instance. It is not used by this hook and
            is not forwarded to ``bias_fn``.

        Raises
        ------
        RuntimeError
            If ``bias_energy`` or ``bias_forces`` does not have exactly the
            same shape as ``batch.energies`` or ``batch.forces``.
        """
        bias_energy, bias_forces = self.bias_fn(batch)

        # Require exact shape equality: torch would otherwise broadcast a
        # mis-shaped bias (e.g. (1, 1) onto (B, 1)) and silently apply it
        # to every graph.
        if bias_energy.shape != batch.energies.shape:
            raise RuntimeError(
                f"bias_energy shape {bias_energy.shape} does not match "
                f"batch.energies shape {batch.energies.shape}"
            )
        if bias_forces.shape != batch.forces.shape:
            raise RuntimeError(
                f"bias_forces shape {bias_forces.shape} does not match "
                f"batch.forces shape {batch.forces.shape}"
            )

        if self.inplace:
            batch.energies.add_(bias_energy)
            batch.forces.add_(bias_forces)
        else:
            # Out-of-place: build new tensors and rebind them on the batch,
            # so the previous tensor objects are left untouched.
            batch["energies"] = batch.energies + bias_energy
            batch["forces"] = batch.forces + bias_forces