# Source code for nvalchemi.dynamics.hooks.freeze
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Freeze atoms hook for constraining selected atoms during dynamics.
Provides :class:`FreezeAtomsHook`, which freezes atoms by category,
restoring their positions and zeroing velocities each step.
"""
from __future__ import annotations
from enum import Enum
import torch
from nvalchemi._typing import AtomCategory
from nvalchemi.data import Batch
from nvalchemi.dynamics.base import DynamicsStage
from nvalchemi.hooks._context import HookContext
__all__ = ["FreezeAtomsHook"]
# [docs] (documentation-link marker left over from the rendered source page)
class FreezeAtomsHook:
    """Freeze selected atoms during molecular dynamics simulation.

    During dynamics, certain atoms may need to remain fixed in place,
    such as substrate atoms in surface simulations, boundary atoms in
    slab models, or anchor atoms in constrained optimization. This
    hook identifies atoms by their ``atom_categories`` field and
    constrains them by:

    1. **Snapshotting positions** — At ``BEFORE_PRE_UPDATE``, the hook
       snapshots all atomic positions.
    2. **Restoring positions** — At ``AFTER_POST_UPDATE``, the hook
       restores positions of frozen atoms using ``torch.where``,
       effectively undoing any displacement applied by the integrator.
    3. **Zeroing velocities** — Velocities of frozen atoms are set to
       zero to prevent momentum accumulation.
    4. **Optionally zeroing forces** — By default, forces on frozen
       atoms are also zeroed. This prevents force contributions from
       propagating through the integrator and ensures clean energy
       conservation diagnostics.

    The hook fires at two stages: ``BEFORE_PRE_UPDATE`` (to snapshot
    positions) and ``AFTER_POST_UPDATE`` (to restore frozen positions
    and zero velocities/forces). This two-stage design enables
    ``torch.compile(fullgraph=True)`` compatibility by avoiding
    data-dependent branching.

    Parameters
    ----------
    frequency : int, optional
        Apply constraints every ``frequency`` steps. Default ``1``
        (every step). Setting this higher than 1 is not recommended
        as frozen atoms will drift between constraint applications.
        The value is only stored on the hook; ``__call__`` does not
        consult it.
    freeze_category : int, optional
        The ``atom_categories`` value that identifies frozen atoms.
        Default is ``AtomCategory.SPECIAL.value`` (-1). Atoms with
        ``batch.atom_categories == freeze_category`` will be frozen.
    zero_forces : bool, optional
        Whether to zero forces on frozen atoms. Default ``True``.
        Set to ``False`` if you need to measure forces on frozen
        atoms for analysis purposes.
    stage : Enum, optional
        Primary stage reported by the hook. Default
        ``DynamicsStage.BEFORE_PRE_UPDATE``. The set of stages the hook
        actually acts on is fixed and does not depend on this value.

    Attributes
    ----------
    frequency : int
        Constraint application frequency in steps.
    freeze_category : int
        Category value identifying frozen atoms.
    zero_forces : bool
        Whether forces are zeroed on frozen atoms.
    stage : Enum
        Primary stage, ``BEFORE_PRE_UPDATE`` by default for protocol
        compliance.

    Examples
    --------
    Freeze atoms marked as SPECIAL (default):

    >>> from nvalchemi.dynamics.hooks import FreezeAtomsHook
    >>> hook = FreezeAtomsHook()
    >>> dynamics = DemoDynamics(model=model, n_steps=1000, hooks=[hook])
    >>> dynamics.run(batch)

    Freeze bulk atoms instead:

    >>> from nvalchemi._typing import AtomCategory
    >>> hook = FreezeAtomsHook(freeze_category=AtomCategory.BULK.value)

    Keep forces for analysis:

    >>> hook = FreezeAtomsHook(zero_forces=False)

    Notes
    -----
    * Fires at two stages: ``BEFORE_PRE_UPDATE`` (snapshot all positions)
      and ``AFTER_POST_UPDATE`` (restore frozen positions via ``torch.where``).
      Calls with any other stage are ignored.
    * Uses ``torch.where`` for branchless GPU-vectorized restore, enabling
      ``torch.compile(fullgraph=True)`` compatibility.
    * All positions are snapshotted each step (not just frozen ones) to
      avoid shape-dependent logic.
    * When using with :class:`WrapPeriodicHook`, both hooks fire at
      ``AFTER_POST_UPDATE``. Registration order determines execution order;
      register this hook **before** the periodic wrapping hook to ensure
      frozen positions are restored before wrapping is applied.
    """

    def __init__(
        self,
        frequency: int = 1,
        freeze_category: int = AtomCategory.SPECIAL.value,
        zero_forces: bool = True,
        stage: Enum = DynamicsStage.BEFORE_PRE_UPDATE,
    ) -> None:
        """Store the hook configuration and initialise the empty snapshot.

        See the class docstring for the meaning of each parameter.
        """
        self.frequency = frequency
        self.freeze_category = freeze_category
        self.zero_forces = zero_forces
        self.stage = stage
        # Multi-stage hooks need both a primary stage and a list of active stages
        self._active_stages = frozenset(
            {DynamicsStage.BEFORE_PRE_UPDATE, DynamicsStage.AFTER_POST_UPDATE}
        )
        # Filled at BEFORE_PRE_UPDATE, consumed at AFTER_POST_UPDATE.
        self._saved_positions: torch.Tensor | None = None

    def _runs_on_stage(self, stage: Enum) -> bool:
        """Check if this hook should run on the given stage.

        Parameters
        ----------
        stage : Enum
            The stage to check.

        Returns
        -------
        bool
            True if this hook runs on the given stage.
        """
        return stage in self._active_stages

    def _restore(self, batch: Batch) -> None:
        """Restore frozen atom positions and zero velocities/forces.

        The restore logic runs under :func:`torch.no_grad` because
        ``positions`` may carry ``requires_grad=True`` from the model's
        conservative-force computation.

        Parameters
        ----------
        batch : Batch
            The current batch of atomic data. ``batch.positions``,
            ``batch.velocities``, and optionally ``batch.forces`` are
            modified in-place.

        Raises
        ------
        RuntimeError
            If no position snapshot has been taken yet, i.e. the hook
            was not called at ``BEFORE_PRE_UPDATE`` first.
        """
        saved = self._saved_positions
        if saved is None:
            # Fail with an actionable message instead of letting
            # ``torch.where`` reject a ``None`` argument.
            raise RuntimeError(
                "FreezeAtomsHook has no position snapshot to restore from; "
                "the hook must be called at BEFORE_PRE_UPDATE before "
                "AFTER_POST_UPDATE."
            )
        with torch.no_grad():
            # mask shape: [V] -> [V, 1] for broadcasting with [V, 3]
            mask = (batch.atom_categories == self.freeze_category).unsqueeze(-1)
            zeros = torch.zeros_like(batch.positions)
            batch.positions.copy_(torch.where(mask, saved, batch.positions))
            batch.velocities.copy_(torch.where(mask, zeros, batch.velocities))
            if self.zero_forces:
                batch.forces.copy_(torch.where(mask, zeros, batch.forces))

    def __call__(self, ctx: HookContext, stage: Enum) -> None:
        """Snapshot or restore frozen atom positions.

        Parameters
        ----------
        ctx : HookContext
            Hook context; only ``ctx.batch`` is used.
        stage : Enum
            Stage the hook is being invoked for. ``BEFORE_PRE_UPDATE``
            takes the snapshot, ``AFTER_POST_UPDATE`` applies the
            restore, and any other stage is a no-op.
        """
        if stage == DynamicsStage.BEFORE_PRE_UPDATE:
            # Detach so the snapshot does not stay attached to the autograd
            # graph when positions require grad.
            self._saved_positions = ctx.batch.positions.detach().clone()
        elif stage == DynamicsStage.AFTER_POST_UPDATE:
            self._restore(ctx.batch)