Source code for nvalchemi.dynamics.hooks.freeze
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Freeze atoms hook for constraining selected atoms during dynamics.
Provides :class:`FreezeAtomsHook`, which freezes atoms by category,
restoring their positions, zeroing their velocities, and optionally
zeroing their forces each step.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from nvalchemi._typing import AtomCategory
from nvalchemi.dynamics.base import HookStageEnum
if TYPE_CHECKING:
from nvalchemi.data import Batch
from nvalchemi.dynamics.base import BaseDynamics
__all__ = ["FreezeAtomsHook"]
[docs]
class FreezeAtomsHook:
    """Freeze selected atoms during molecular dynamics simulation.

    During dynamics, certain atoms may need to remain fixed in place,
    such as substrate atoms in surface simulations, boundary atoms in
    slab models, or anchor atoms in constrained optimization. This
    hook identifies atoms by their ``atom_categories`` field and
    constrains them by:

    1. **Snapshotting positions** — At ``BEFORE_PRE_UPDATE``, the hook
       snapshots all atomic positions (detached from autograd).
    2. **Restoring positions** — At ``AFTER_POST_UPDATE``, the hook
       restores positions of frozen atoms using ``torch.where``,
       effectively undoing any displacement applied by the integrator.
    3. **Zeroing velocities** — Velocities of frozen atoms are set to
       zero to prevent momentum accumulation.
    4. **Optionally zeroing forces** — By default, forces on frozen
       atoms are also zeroed. This prevents force contributions from
       propagating through the integrator and ensures clean energy
       conservation diagnostics.

    The hook fires at two stages: ``BEFORE_PRE_UPDATE`` (to snapshot
    positions) and ``AFTER_POST_UPDATE`` (to restore frozen positions
    and zero velocities/forces). This two-stage design enables
    ``torch.compile(fullgraph=True)`` compatibility by avoiding
    data-dependent branching.

    Parameters
    ----------
    frequency : int, optional
        Apply constraints every ``frequency`` steps. Default ``1``
        (every step). Setting this higher than 1 is not recommended
        as frozen atoms will drift between constraint applications.
    freeze_category : int, optional
        The ``atom_categories`` value that identifies frozen atoms.
        Default is ``AtomCategory.SPECIAL.value``. Atoms with
        ``batch.atom_categories == freeze_category`` will be frozen.
    zero_forces : bool, optional
        Whether to zero forces on frozen atoms. Default ``True``.
        Set to ``False`` if you need to measure forces on frozen
        atoms for analysis purposes.

    Attributes
    ----------
    frequency : int
        Constraint application frequency in steps.
    freeze_category : int
        Category value identifying frozen atoms.
    zero_forces : bool
        Whether forces are zeroed on frozen atoms.
    stage : HookStageEnum
        Primary stage, set to ``BEFORE_PRE_UPDATE`` for protocol compliance.
    stages : tuple[HookStageEnum, ...]
        Tuple of stages at which this hook fires: ``BEFORE_PRE_UPDATE``
        and ``AFTER_POST_UPDATE``.

    Examples
    --------
    Freeze atoms marked as SPECIAL (default):

    >>> from nvalchemi.dynamics.hooks import FreezeAtomsHook
    >>> hook = FreezeAtomsHook()
    >>> dynamics = DemoDynamics(model=model, n_steps=1000, hooks=[hook])  # doctest: +SKIP
    >>> dynamics.run(batch)  # doctest: +SKIP

    Freeze bulk atoms instead:

    >>> from nvalchemi._typing import AtomCategory
    >>> hook = FreezeAtomsHook(freeze_category=AtomCategory.BULK.value)

    Keep forces for analysis:

    >>> hook = FreezeAtomsHook(zero_forces=False)

    Notes
    -----
    * Fires at two stages: ``BEFORE_PRE_UPDATE`` (snapshot all positions)
      and ``AFTER_POST_UPDATE`` (restore frozen positions via ``torch.where``).
      Calls made at any other stage are no-ops.
    * Uses ``torch.where`` / ``masked_fill_`` for branchless, vectorized
      updates, enabling ``torch.compile(fullgraph=True)`` compatibility.
    * All positions are snapshotted each step (not just frozen ones) to
      avoid shape-dependent logic.
    * When using with :class:`WrapPeriodicHook`, both hooks fire at
      ``AFTER_POST_UPDATE``. Registration order determines execution order;
      register this hook **before** the periodic wrapping hook to ensure
      frozen positions are restored before wrapping is applied.
    """

    stage: HookStageEnum = HookStageEnum.BEFORE_PRE_UPDATE
    stages: tuple[HookStageEnum, ...] = (
        HookStageEnum.BEFORE_PRE_UPDATE,
        HookStageEnum.AFTER_POST_UPDATE,
    )

    def __init__(
        self,
        frequency: int = 1,
        freeze_category: int = AtomCategory.SPECIAL.value,
        zero_forces: bool = True,
    ) -> None:
        """Store the hook configuration; see the class docstring for details."""
        self.frequency = frequency
        self.freeze_category = freeze_category
        self.zero_forces = zero_forces
        # Detached copy of ``batch.positions`` taken at BEFORE_PRE_UPDATE.
        # ``None`` until the first snapshot has been taken.
        self._saved_positions: torch.Tensor | None = None

    def __call__(self, batch: Batch, dynamics: BaseDynamics) -> None:
        """Apply freeze constraints to the batch in-place.

        At ``BEFORE_PRE_UPDATE``, snapshots all positions. At
        ``AFTER_POST_UPDATE``, restores frozen atom positions and
        zeros their velocities (and optionally forces). Calls at any
        other stage do nothing.

        The restore stage runs under :func:`torch.no_grad` because
        ``positions`` may carry ``requires_grad=True`` from the model's
        conservative-force computation. This mirrors the pattern used by
        :meth:`BaseDynamics.step` when restoring graduated-sample state.

        Parameters
        ----------
        batch : Batch
            The current batch of atomic data. ``batch.positions``,
            ``batch.velocities``, and optionally ``batch.forces`` are
            modified in-place during the restore stage.
        dynamics : BaseDynamics
            The dynamics engine instance. Uses ``dynamics.current_hook_stage``
            to determine which stage is being executed.

        Raises
        ------
        RuntimeError
            At ``AFTER_POST_UPDATE`` if no snapshot was taken at a prior
            ``BEFORE_PRE_UPDATE``, or if the number/shape of positions
            changed between the snapshot and the restore.
        """
        current = dynamics.current_hook_stage
        if current == HookStageEnum.BEFORE_PRE_UPDATE:
            self._snapshot(batch)
        elif current == HookStageEnum.AFTER_POST_UPDATE:
            self._restore(batch)

    def _snapshot(self, batch: Batch) -> None:
        """Save a gradient-free copy of all current positions."""
        # detach() so the snapshot neither tracks gradients nor keeps an
        # autograd edge to the positions tensor from the model's forward pass.
        self._saved_positions = batch.positions.detach().clone()

    def _restore(self, batch: Batch) -> None:
        """Restore frozen positions and zero frozen velocities (and forces)."""
        saved = self._saved_positions
        if saved is None:
            raise RuntimeError(
                "FreezeAtomsHook: AFTER_POST_UPDATE reached without a position "
                "snapshot; the hook must also fire at BEFORE_PRE_UPDATE."
            )
        if saved.shape != batch.positions.shape:
            # Guards against the batch composition changing between stages;
            # otherwise torch.where could fail obscurely or broadcast silently.
            raise RuntimeError(
                "FreezeAtomsHook: position snapshot shape "
                f"{tuple(saved.shape)} does not match current positions "
                f"shape {tuple(batch.positions.shape)}."
            )
        # no_grad is required because positions may have requires_grad=True
        # from the model forward pass, which forbids in-place updates.
        with torch.no_grad():
            # mask shape: [V] -> [V, 1] for broadcasting with [V, 3]
            mask = (batch.atom_categories == self.freeze_category).unsqueeze(-1)
            batch.positions.copy_(torch.where(mask, saved, batch.positions))
            # In-place fill: no scratch allocation and dtype is preserved.
            batch.velocities.masked_fill_(mask, 0.0)
            if self.zero_forces:
                batch.forces.masked_fill_(mask, 0.0)