Source code for nvalchemi.hooks.periodic

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Periodic boundary condition hook for coordinate wrapping.

Provides :class:`WrapPeriodicHook`, which wraps atomic positions back
into the unit cell under periodic boundary conditions.
"""

from __future__ import annotations

from enum import Enum

import torch
import warp as wp
from jaxtyping import Float
from nvalchemiops.dynamics.utils import compute_cell_inverse, wrap_positions_to_cell

from nvalchemi.data import Batch
from nvalchemi.hooks._context import HookContext

__all__ = ["WrapPeriodicHook"]


# ---------------------------------------------------------------------------
# Custom op + helper for position wrapping (moved from dynamics/hooks/_utils)
# ---------------------------------------------------------------------------


@torch.library.custom_op("nvalchemi_hooks::wrap_positions", mutates_args=())
def _wrap_positions(
    positions: torch.Tensor,
    cell: torch.Tensor,
    batch_idx: torch.Tensor,
) -> torch.Tensor:
    """Wrap every atom into its system's cell along all three lattice vectors.

    Out-of-place custom op around the ``nvalchemiops`` warp kernels. No
    periodicity flags are passed to the kernel, so callers that need
    per-dimension PBC must post-process the result themselves.

    Parameters
    ----------
    positions : Tensor
        Per-atom Cartesian positions, shape ``(V, 3)``. ``float64`` selects
        double-precision warp types; any other dtype uses single precision.
        Not modified (a private clone is wrapped).
    cell : Tensor
        Lattice vectors as rows, shape ``(B, 3, 3)``. Cast to
        ``positions.dtype`` before being handed to warp.
    batch_idx : Tensor
        Per-atom system index, shape ``(V,)``; converted to ``int32``.

    Returns
    -------
    Tensor
        Newly allocated, contiguous tensor of wrapped positions.
    """
    is_double = positions.dtype == torch.float64
    vec_dtype = wp.vec3d if is_double else wp.vec3f
    mat_dtype = wp.mat33d if is_double else wp.mat33f

    # Transpose cell from row-convention (nvalchemi) to column-convention
    # (nvalchemiops). The cast keeps the cell's precision in sync with the
    # warp matrix type chosen from ``positions``; a cell of a different
    # floating dtype is not compatible with the requested warp type.
    cell_T = cell.to(positions.dtype).transpose(-2, -1).contiguous()

    # Convert to warp arrays. The clone keeps the op free of input mutation,
    # as promised by ``mutates_args=()``.
    num_systems = cell_T.shape[0]
    wp_pos = wp.from_torch(positions.clone().contiguous(), dtype=vec_dtype)
    wp_cell = wp.from_torch(cell_T, dtype=mat_dtype)
    wp_cell_inv = wp.zeros(num_systems, dtype=mat_dtype, device=wp_pos.device)
    wp_batch_idx = wp.from_torch(batch_idx.to(torch.int32))

    compute_cell_inverse(wp_cell, wp_cell_inv)
    wrap_positions_to_cell(wp_pos, wp_cell, wp_cell_inv, wp_batch_idx)

    return wp.to_torch(wp_pos)


@_wrap_positions.register_fake
def _(
    positions: torch.Tensor,
    cell: torch.Tensor,
    batch_idx: torch.Tensor,
) -> torch.Tensor:
    """Fake (meta) implementation of ``nvalchemi_hooks::wrap_positions``.

    Used during tracing to propagate shape, dtype and device. The real op
    returns ``wp.to_torch`` of a ``.contiguous()`` clone of ``positions``,
    so its output is always contiguous. The fake therefore requests a
    contiguous layout explicitly instead of inheriting the (possibly
    non-contiguous) strides of ``positions``.
    """
    return torch.empty_like(positions, memory_format=torch.contiguous_format)


def wrap_positions_into_cell(
    positions: Float[torch.Tensor, "V 3"],
    cell: Float[torch.Tensor, "B 3 3"],
    pbc: torch.Tensor,
    batch_idx: torch.Tensor,
) -> Float[torch.Tensor, "V 3"]:
    """Wrap positions into the unit cell using fractional coordinates.

    Respects per-dimension periodicity: an atom is only translated along
    lattice vectors whose ``pbc`` flag is ``True``. No translation is ever
    applied along a non-periodic lattice vector, so the coordinate along
    that direction is left unchanged.

    This function modifies ``positions`` **in-place** and returns the
    same tensor. It delegates to
    ``nvalchemiops.dynamics.utils.wrap_positions_to_cell`` for GPU-optimized
    wrapping along all three lattice vectors. It then recovers, in pure
    PyTorch, the integer lattice shift the kernel applied to each atom and
    re-applies only its periodic components.

    Parameters
    ----------
    positions : Float[Tensor, "V 3"]
        Per-atom Cartesian positions. Modified in-place.
    cell : Float[Tensor, "B 3 3"]
        Lattice vectors as rows, one ``(3, 3)`` matrix per graph. Cast to
        ``positions.dtype``.
    pbc : Tensor
        Per-lattice-vector periodicity flags, shape ``(B, 3)``. Converted
        to ``bool``.
    batch_idx : Tensor
        Per-atom graph membership indices of shape ``(V,)``.

    Returns
    -------
    Float[Tensor, "V 3"]
        The same ``positions`` tensor (modified in-place).

    Notes
    -----
    Masking is done in lattice (fractional) space, not on Cartesian
    components. In a triclinic cell a shift along one lattice vector can
    change every Cartesian component. Restoring only "the matching Cartesian
    column" would therefore leave the atom displaced by a non-lattice
    translation.
    """
    # Single precision for the kernel and the matrix products below.
    cell = cell.to(positions.dtype)
    # The custom op is out-of-place, so ``positions`` still holds the
    # unwrapped coordinates after this call.
    wrapped = _wrap_positions(positions, cell, batch_idx)

    # Lattice shift the kernel applied to each atom, in units of the lattice
    # vectors (row convention: cart = frac @ cell => frac = cart @ inv(cell)).
    # ``inv_ex`` does not raise on singular cells (e.g. all-zero cells of
    # non-periodic systems); any non-finite values it yields are masked
    # out below together with the other non-periodic directions.
    cell_inv, _ = torch.linalg.inv_ex(cell)
    displacement = wrapped - positions
    lattice_shift = torch.round(
        torch.einsum("vi,vij->vj", displacement, cell_inv[batch_idx])
    )

    # Keep only shifts along periodic lattice vectors.
    per_atom_pbc = pbc.to(torch.bool)[batch_idx]  # (V, 3)
    lattice_shift = torch.where(
        per_atom_pbc, lattice_shift, torch.zeros_like(lattice_shift)
    )

    positions.add_(torch.einsum("vi,vij->vj", lattice_shift, cell[batch_idx]))
    return positions


class WrapPeriodicHook:
    """Wrap atomic positions back into the simulation cell under PBC.

    During long molecular dynamics trajectories, atomic positions drift
    away from the unit cell as the integrator applies unbounded
    displacements. While physically valid (forces are invariant under
    lattice translations), large coordinates can cause problems:

    * **Neighbor list overflow** — distance calculations may exceed the
      numerical range of the cell-shift representation, leading to missed
      interactions or incorrect forces.
    * **Precision loss** — large coordinate magnitudes reduce the effective
      floating-point precision available for inter-atomic distances.
    * **Visualization artifacts** — trajectories with unwrapped coordinates
      are difficult to analyze and visualize.

    This hook wraps positions back into the unit cell by computing
    fractional coordinates, taking their modulo, and converting back to
    Cartesian::

        frac = positions @ inv(cell)
        frac = frac % 1.0
        positions = frac @ cell

    The wrapping is applied in-place to ``batch.positions``. It respects the
    per-system periodicity flags in ``batch.pbc``, which hold one flag per
    lattice vector:

    * If ``batch.pbc`` is ``[True, True, True]``, positions are wrapped
      along all three lattice vectors.
    * If ``batch.pbc`` is ``[True, True, False]`` (e.g. a slab), positions
      are wrapped along the first two lattice vectors only. The third
      direction is left unwrapped to preserve vacuum gaps. For the usual slab
      cell whose in-plane vectors lie in the *xy* plane, this wraps *x* and
      *y* and leaves *z* untouched.
    * If ``batch.pbc`` is ``[False, False, False]`` (non-periodic), the
      hook leaves that system's positions unchanged.

    The intended stage is :attr:`~DynamicsStage.AFTER_POST_UPDATE`: after
    velocities have been updated, but before the next step begins. This
    ensures the neighbor list built at the start of the next step sees
    wrapped coordinates. The stage is not set by default; pass it via
    ``stage``.

    Parameters
    ----------
    frequency : int, optional
        Wrap positions every ``frequency`` steps. Default ``1`` (every
        step). For simulations with moderate drift, wrapping every
        10--100 steps is sufficient and reduces overhead.
    stage : Enum | None, optional
        The workflow stage at which this hook runs. Defaults to ``None``
        (stage-agnostic until registered with a specific engine).

    Attributes
    ----------
    frequency : int
        Wrapping frequency in steps.
    stage : Enum | None
        The stage at which this hook fires.

    Examples
    --------
    >>> from nvalchemi.hooks import WrapPeriodicHook
    >>> from nvalchemi.dynamics.base import DynamicsStage
    >>> hook = WrapPeriodicHook(frequency=10, stage=DynamicsStage.AFTER_POST_UPDATE)
    >>> dynamics = DemoDynamics(model=model, n_steps=10_000, dt=0.5, hooks=[hook])  # doctest: +SKIP
    >>> dynamics.run(batch)  # doctest: +SKIP

    Notes
    -----
    * Wrapping does **not** modify velocities, momenta, or forces — only
      positions. This is correct because forces depend on relative
      distances (invariant under translation) and velocities are already
      in Cartesian space.
    * For triclinic (non-orthorhombic) cells, the fractional-coordinate
      approach naturally handles skewed lattice vectors.
    * This hook assumes ``batch.cell`` has shape ``(B, 3, 3)`` (or
      ``(B, 1, 3, 3)``) with lattice vectors as **rows** (consistent with
      ASE convention).
    * In batched simulations, wrapping is applied per-graph using
      ``batch.batch_idx`` to associate each atom with its cell.
    """

    def __init__(
        self,
        frequency: int = 1,
        stage: Enum | None = None,
    ) -> None:
        self.frequency = frequency
        self.stage = stage

    def _wrap_positions(self, batch: Batch) -> None:
        """Wrap ``batch.positions`` into the unit cell in-place.

        ``batch.cell`` and ``batch.pbc`` may carry a leading singleton
        per-system dimension (``(B, 1, 3, 3)`` / ``(B, 1, 3)``). It is
        squeezed away before wrapping.
        """
        cell = batch.cell
        pbc = batch.pbc
        # System-level tensors may have a leading singleton dim: (B, 1, 3, 3) -> (B, 3, 3)
        if cell.dim() == 4:
            cell = cell.squeeze(1)
        if pbc.dim() == 3:
            pbc = pbc.squeeze(1)
        wrap_positions_into_cell(batch.positions, cell, pbc, batch.batch_idx)

    def __call__(self, ctx: HookContext, stage: Enum) -> None:
        """Wrap ``ctx.batch.positions`` into the unit cell in-place.

        ``stage`` is accepted for the hook-call interface but not used.
        """
        self._wrap_positions(ctx.batch)