Source code for nvalchemi.dynamics.hooks.neighbor_list

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Neighbor list hook for on-the-fly neighbor list construction.

This module provides :class:`NeighborListHook`, which runs at the
``BEFORE_COMPUTE`` stage to compute or refresh the neighbor list stored in
the batch before the model forward pass.  The list is rebuilt on every
invocation; no Verlet skin-buffer skipping is currently applied.

Both ``MATRIX`` and ``COO`` neighbor formats are supported for dynamic
updates (i.e. updates each dynamics step).  For ``COO`` format the hook
creates or replaces the edges group on the batch each step so that
``batch.edge_index`` (shape ``(E, 2)``) and ``batch.unit_shifts``
(shape ``(E, 3)``, PBC only) are always up to date.  The companion
``Batch.edge_ptr`` property derives the per-atom CSR pointer on demand.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from nvalchemiops.torch.neighbors import neighbor_list

from nvalchemi.dynamics.base import HookStageEnum
from nvalchemi.models.base import NeighborConfig, NeighborListFormat

if TYPE_CHECKING:
    from nvalchemi.data import Batch
    from nvalchemi.dynamics.base import BaseDynamics


class NeighborListHook:
    """Compute neighbor lists before each model evaluation.

    This hook runs at :attr:`~HookStageEnum.BEFORE_COMPUTE` and writes
    neighbor data into the batch so that the model's ``adapt_input`` can
    read it.

    The neighbor list is rebuilt on **every** call.  After each rebuild the
    hook records the current positions and, when the batch exposes one, its
    ``system_id`` tensor as reference state.  This implementation does not
    use that state to skip rebuilds, and it does not read ``config.skin``.

    For ``MATRIX`` format the following tensors are written to the atoms
    group of the batch (and are thus accessible as ``batch.neighbor_matrix``
    etc.):

    * ``neighbor_matrix`` — shape ``(N, max_neighbors)``, int32
    * ``num_neighbors`` — shape ``(N,)``, int32
    * ``neighbor_shifts`` — shape ``(N, max_neighbors, 3)``, int32
      (only written when ``neighbor_list`` returns shifts, i.e. with PBC)

    For ``COO`` format the edges group of the batch is created or replaced
    on every rebuild, making the following accessible:

    * ``batch.edge_index`` — shape ``(E, 2)``, int32 (nvalchemi convention)
    * ``batch.unit_shifts`` — shape ``(E, 3)``, int32 (only when PBC active)
    * ``batch.edge_ptr`` — shape ``(N+1,)``, int32, derived on demand via
      the :attr:`~nvalchemi.data.Batch.edge_ptr` property

    In both formats the batch is stamped with ``_neighbor_list_cutoff``
    equal to ``config.cutoff``.

    Parameters
    ----------
    config : NeighborConfig
        Neighbor list configuration read from the model card.  ``cutoff``,
        ``half_list`` and ``max_neighbors`` are forwarded to
        :func:`nvalchemiops.torch.neighbors.neighbor_list`; ``format``
        selects COO (``NeighborListFormat.COO``) or matrix output (any
        other value).

    Raises
    ------
    RuntimeError
        From ``__call__`` when the matrix format is used and the batch has
        no atoms group to store the neighbor data in.

    Notes
    -----
    ``config.max_neighbors`` is not validated by this hook; it is passed
    through unchanged (possibly ``None``) to ``neighbor_list``.
    """

    stage: HookStageEnum = HookStageEnum.BEFORE_COMPUTE
    frequency: int = 1

    def __init__(self, config: NeighborConfig) -> None:
        """Store *config* and initialise empty reference state.

        Parameters
        ----------
        config : NeighborConfig
            Neighbor list configuration (see the class docstring).
        """
        self.config = config
        # True -> request COO output (``return_neighbor_list=True``);
        # every non-COO format is handled as the dense matrix format.
        self._neighbor_list_flag = config.format == NeighborListFormat.COO
        # NOTE(review): no check that ``max_neighbors`` is set for the matrix
        # format happens here — confirm that NeighborConfig or nvalchemiops
        # handles ``max_neighbors=None`` as intended.
        # Reference state captured after each rebuild in ``__call__``.
        self._ref_positions: torch.Tensor | None = None
        self._ref_system_ids: torch.Tensor | None = None

    # ------------------------------------------------------------------
    # Main hook entry point
    # ------------------------------------------------------------------
    def __call__(self, batch: Batch, dynamics: BaseDynamics) -> None:
        """Recompute the neighbor list and write it to *batch*.

        Parameters
        ----------
        batch : Batch
            Batch updated in place (neighbor data and cutoff stamp).
        dynamics : BaseDynamics
            Owning dynamics object; not used by this hook.
        """
        self._rebuild(batch)

        # Record reference state after the rebuild.  Nothing in this class
        # reads it back; rebuilds happen unconditionally.
        self._ref_positions = batch.positions.detach().clone()
        try:
            self._ref_system_ids = batch.system_id.detach().clone()
        except AttributeError:
            # Batch has no ``system_id`` (or it is None, which also lacks
            # ``.detach``).
            self._ref_system_ids = None

    # ------------------------------------------------------------------
    # Neighbor list construction
    # ------------------------------------------------------------------
    def _rebuild(self, batch: Batch) -> None:
        """Build the neighbor list and write the results into *batch*."""
        positions = batch.positions  # (N, 3)
        batch_ptr = batch.ptr.to(torch.int32)  # (B+1,) int32
        pbc, cell = self._periodic_info(batch)

        result = neighbor_list(
            positions=positions,
            cutoff=self.config.cutoff,
            cell=cell,
            pbc=pbc,
            max_neighbors=self.config.max_neighbors,
            half_fill=self.config.half_list,
            batch_ptr=batch_ptr,
            return_neighbor_list=self._neighbor_list_flag,
        )

        if self._neighbor_list_flag:
            self._write_coo(batch, result)
        else:
            self._write_matrix(batch, result)

        # Stamp the cutoff so that prepare_neighbors_for_model can detect when
        # filtering is needed for sub-models with a tighter cutoff.
        batch._neighbor_list_cutoff = self.config.cutoff

    @staticmethod
    def _periodic_info(batch: Batch) -> tuple:
        """Return ``(pbc, cell)`` from *batch*, or ``(None, None)`` if either is missing."""
        try:
            pbc = batch.pbc  # (B, 3) bool
            cell = batch.cell  # (B, 3, 3) float
        except AttributeError:
            # Non-periodic batch: both must be None together.
            return None, None
        return pbc, cell

    @staticmethod
    def _write_coo(batch: Batch, result: tuple[torch.Tensor, ...]) -> None:
        """Replace the edges group of *batch* with COO neighbor data.

        *result* is the tuple returned by ``neighbor_list`` with
        ``return_neighbor_list=True``: ``(edge_index, edge_ptr[, unit_shifts])``.
        """
        from nvalchemi.data.level_storage import SegmentedLevelStorage

        device = batch.positions.device
        edge_index = result[0]  # (2, E) int32
        # result[1] is the per-atom edge_ptr — not stored explicitly;
        # batch.edge_ptr is a computed property that derives it from the
        # edges group on demand.
        unit_shifts = result[2] if len(result) > 2 else None  # (E, 3) int32

        # Per-graph segment lengths: count edges by the graph of their
        # source atom (batch.batch maps atom -> graph).
        num_edges = edge_index.shape[1]
        graph_per_edge = batch.batch.long()[edge_index[0].long()]  # (E,)
        seg_lengths = torch.zeros(batch.num_graphs, dtype=torch.int32, device=device)
        seg_lengths.scatter_add_(
            0,
            graph_per_edge,
            torch.ones(num_edges, dtype=torch.int32, device=device),
        )

        # Store edge_index in nvalchemi's (E, 2) convention so that
        # model adapt_input methods (e.g. MACEWrapper) can read it
        # directly with a .T transpose.
        data_dict: dict[str, torch.Tensor] = {
            "edge_index": edge_index.T.contiguous(),  # (E, 2)
        }
        if unit_shifts is not None:
            data_dict["unit_shifts"] = unit_shifts  # (E, 3)

        # Replace (or create) the edges group. validate=False is required
        # because the edge count changes between neighbor-list rebuilds.
        batch._storage.groups["edges"] = SegmentedLevelStorage(
            data=data_dict,
            device=device,
            segment_lengths=seg_lengths,
            validate=False,
        )

    @staticmethod
    def _write_matrix(batch: Batch, result: tuple[torch.Tensor, ...]) -> None:
        """Write dense neighbor-matrix data into the atoms group of *batch*.

        *result* is the tuple returned by ``neighbor_list`` with
        ``return_neighbor_list=False``:
        ``(neighbor_matrix, num_neighbors[, neighbor_shifts])``.

        Raises
        ------
        RuntimeError
            If *batch* has no atoms group.
        """
        neighbor_matrix = result[0]  # (N, max_neighbors) int32
        num_neighbors = result[1]  # (N,) int32
        neighbor_shifts = result[2] if len(result) > 2 else None

        # Write into the atoms group so that `batch.neighbor_matrix` etc. work.
        atoms_group = batch._atoms_group
        if atoms_group is None:
            raise RuntimeError(
                "NeighborListHook: batch has no atoms group — cannot store "
                "neighbor data."
            )
        atoms_group["neighbor_matrix"] = neighbor_matrix
        atoms_group["num_neighbors"] = num_neighbors
        if neighbor_shifts is not None:
            atoms_group["neighbor_shifts"] = neighbor_shifts