# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Neighbor list hook for on-the-fly neighbor list construction.
This module provides :class:`NeighborListHook`, which runs at the
``BEFORE_COMPUTE`` stage to compute or refresh the neighbor list stored in
the batch before the model forward pass. It supports an optional Verlet
skin buffer to avoid recomputing neighbors every step.
Both ``MATRIX`` and ``COO`` neighbor formats are supported for dynamic
updates (i.e. updates each dynamics step). For ``COO`` format the hook
creates or replaces the edges group on the batch each step so that
``batch.neighbor_list`` (shape ``(E, 2)``) and ``batch.neighbor_list_shifts``
(shape ``(E, 3)``, PBC only) are always up to date. The companion
``Batch.edge_ptr`` property derives the per-atom CSR pointer on demand.
Pre-allocation
--------------
The hook maintains *staging buffers* — persistent GPU tensors that are
refreshed each step via ``Tensor.copy_()`` — to avoid per-step dynamic
allocation inside the ``neighbor_list`` dispatcher.
``neighbor_list`` selects between ``batch_naive`` (avg < 2000 atoms/system)
and ``batch_cell_list`` (avg >= 2000), see https://nvidia.github.io/nvalchemi-toolkit-ops/userguide/components/neighborlist.html.
Both paths normally allocate auxiliary tensors on-demand with CPU-GPU syncs
(e.g. ``.item()`` calls). :meth:`NeighborListHook._alloc_nl_kwargs`
computes these **once** when the batch shape is first seen (or changes)
and caches them in ``NeighborListHook._buf_nl_kwargs``:
* *Naive, no PBC*: no extra kwargs needed.
* *Naive, PBC*: ``shift_range_per_dimension``, ``num_shifts_per_system``,
``max_shifts_per_system``, and ``max_atoms_per_system``.
* *Cell list*: seven cell-list scratch tensors via ``allocate_cell_list``.
**NPT note**: geometry-dependent kwargs (shift ranges, cell-list sizes) are
fixed when the staging buffers are first allocated for a given ``(N, B)``
shape. For NPT (variable-cell) simulations the pre-computed values may
become stale as the cell changes; accuracy is maintained by keeping the
cutoff + skin well below the shortest cell dimension throughout the run.
"""
from __future__ import annotations
from enum import Enum
import torch
from nvalchemiops.neighbors.neighbor_utils import estimate_max_neighbors
from nvalchemiops.torch.neighbors import neighbor_list
try:
from nvalchemiops.torch.neighbors.batch_cell_list import (
estimate_batch_cell_list_sizes,
)
from nvalchemiops.torch.neighbors.neighbor_utils import (
allocate_cell_list,
compute_naive_num_shifts,
)
except ImportError:
allocate_cell_list = None
compute_naive_num_shifts = None
estimate_batch_cell_list_sizes = None
try:
from nvalchemiops.torch.neighbors.rebuild_detection import (
batch_neighbor_list_needs_rebuild as _batch_nl_needs_rebuild,
)
except ImportError:
_batch_nl_needs_rebuild = None
try:
from nvalchemi.dynamics._ops.neighbor_list_rebuild import (
batch_neighbor_list_rebuild_inplace as _batch_nl_rebuild_inplace,
)
except ImportError:
_batch_nl_rebuild_inplace = None
from nvalchemi.data import Batch
from nvalchemi.hooks._context import HookContext
from nvalchemi.models.base import NeighborConfig, NeighborListFormat
from nvalchemi.neighbors import _write_neighbor_data_to_batch
# [docs]  (Sphinx cross-reference artifact from the rendered documentation)
class NeighborListHook:
"""Compute and cache neighbor lists before each model evaluation.
This hook runs at :attr:`~DynamicsStage.BEFORE_COMPUTE` and writes
neighbor data into the batch so that the model's ``adapt_input`` can
read it. An optional Verlet skin buffer avoids rebuilding the list
every step: the list is only recomputed when the maximum atomic
displacement since the last build exceeds ``config.skin / 2``, or when
the set of active systems changes (detected via ``system_id``).
For ``MATRIX`` format the following tensors are written to the atoms
group of the batch (and thus accessible as ``batch.neighbor_matrix``
etc.):
* ``neighbor_matrix`` — shape ``(N, max_neighbors)``, int32
* ``num_neighbors`` — shape ``(N,)``, int32
* ``neighbor_matrix_shifts`` — shape ``(N, max_neighbors, 3)``, int32
(only written when PBC is active)
For ``COO`` format the edges group of the batch is created or replaced
on every rebuild, making the following accessible:
* ``batch.neighbor_list`` — shape ``(E, 2)``, int32 (nvalchemi convention)
* ``batch.neighbor_list_shifts`` — shape ``(E, 3)``, int32 (only when PBC active)
* ``batch.edge_ptr`` — shape ``(N+1,)``, int32, derived on demand via
the :attr:`~nvalchemi.data.Batch.edge_ptr` property
Parameters
----------
config : NeighborConfig
Neighbor list configuration read from the model config.
skin : float, optional
Verlet skin distance in the same length units as positions.
The neighbor list is searched out to ``cutoff + skin`` so that
atoms crossing the skin boundary but not the bare cutoff are
already included. The list is only rebuilt when any atom has
moved more than ``skin / 2`` since the previous build (requires
``nvalchemiops >= 0.4``); set to ``0.0`` (default) to rebuild
every step.
max_neighbors : int | None, optional
Maximum number of neighbors per atom for MATRIX format. When
``None`` (default), auto-estimated from the cutoff via
``estimate_max_neighbors(cutoff)``. Ignored for COO format.
stage : Enum | None, optional
The workflow stage at which this hook runs. Defaults to
``DynamicsStage.BEFORE_COMPUTE``.
"""
    # [docs]  (Sphinx cross-reference artifact from the rendered documentation)
def __init__(
self,
config: NeighborConfig,
skin: float = 0.0,
max_neighbors: int | None = None,
stage: Enum | None = None,
) -> None:
self.config = config
self.skin = skin
self.stage = stage
self._max_neighbors_override = max_neighbors
self.frequency = 1
self._neighbor_list_flag = config.format == NeighborListFormat.COO
# Skin-buffer state: populated after the first build.
self._ref_positions: torch.Tensor | None = None
self._rebuild_flags: torch.Tensor | None = None
# Neighbor Matrix state: populated after the first build.
self._neighbor_matrix: torch.Tensor | None = None
self._col_range: torch.Tensor | None = None
self._num_neighbors: torch.Tensor | None = None
self._neighbor_matrix_shifts: torch.Tensor | None = None
# Shape the staging buffers were allocated for; used to detect when
# re-allocation is needed (e.g. inflight batching with variable load).
self._alloc_N: int | None = None
self._alloc_B: int | None = None
# Staging buffers — persistent GPU tensors refreshed each step via
# copy_() to avoid per-step dynamic allocation inside the dispatcher.
self._buf_positions: torch.Tensor | None = None
self._buf_batch_idx: torch.Tensor | None = None
self._buf_batch_ptr: torch.Tensor | None = None
self._buf_cell: torch.Tensor | None = None # PBC only
self._buf_pbc: torch.Tensor | None = None # PBC only
# Algorithm-specific pre-allocated kwargs forwarded to neighbor_list.
self._buf_nl_kwargs: dict[str, torch.Tensor] = {}
# Adaptive K-dimension state.
self._actual_max_k: torch.Tensor | None = None # GPU scalar from last build
self._first_build: bool = True # Force sync check after first kernel call
# ------------------------------------------------------------------
# Main hook entry point
# ------------------------------------------------------------------
@torch.compile(fullgraph=False, mode="max-autotune-no-cudagraphs")
def __call__(self, ctx: HookContext, stage: Enum) -> None:
"""Recompute the neighbor list if needed and write it to the batch.
When ``skin > 0`` and ``nvalchemiops`` provides
:func:`~nvalchemiops.torch.neighbors.rebuild_detection.batch_neighbor_list_needs_rebuild`,
the list is only rebuilt when at least one atom has moved more than
``skin / 2`` since the previous build. The reference positions are
updated in-place on the GPU (no clone) whenever a rebuild occurs.
"""
self._rebuild(ctx.batch)
# First build: initialise the skin-buffer reference (one-time clone).
if self.skin > 0.0 and self._ref_positions is None:
self._init_ref_positions(ctx.batch.positions)
@torch.compiler.disable
def _init_ref_positions(self, positions: torch.Tensor) -> None:
"""One-time clone of positions into the skin-buffer reference.
Marked ``@torch.compiler.disable`` because the attribute assignment
is a Python mutation that creates a graph break. Called only on the
first step for a given batch shape.
"""
self._ref_positions = positions.detach().clone()
# ------------------------------------------------------------------
# Neighbor list construction
# ------------------------------------------------------------------
    def _rebuild(self, batch: Batch) -> None:
        """Build the neighbor list and write results into the batch.

        Order of operations:

        1. (Re)allocate the output tensors when the atom count ``N`` changes.
        2. (Re)allocate staging buffers + algorithm kwargs on ``(N, B)`` change.
        3. Refresh staging buffers from the live batch via ``copy_()``.
        4. Skin check: set per-system ``_rebuild_flags`` from displacement.
        5. Run ``neighbor_list`` into the pre-allocated outputs.
        6. First build only: sync on the observed max K and re-run the kernel
           if the neighbor matrix overflowed.
        7. Mask stale matrix entries and write results back to ``batch``.
        """
        positions = batch.positions  # (N, 3)
        batch_ptr = batch.batch_ptr  # (B+1,)
        N = batch.num_nodes
        B = batch.num_graphs
        # Detect PBC. getattr avoids a try/except which is a graph break.
        pbc = getattr(batch, "pbc", None)  # (B, 3) bool or None
        cell = getattr(batch, "cell", None)  # (B, 3, 3) float or None
        # ------------------------------------------------------------------
        # Allocate (or reallocate) the output tensors when shape changes.
        # Reallocation also resets the skin-buffer state so that the first
        # subsequent step forces a full rebuild and re-initialises
        # _ref_positions for the new atom count.
        # ------------------------------------------------------------------
        if self._neighbor_matrix is None or self._neighbor_matrix.shape[0] != N:
            self._alloc_output_tensors(N, batch, pbc)
        # ------------------------------------------------------------------
        # (Re)allocate staging buffers and algorithm kwargs on shape change.
        # ------------------------------------------------------------------
        if self._alloc_N != N or self._alloc_B != B:
            # Composition changed — check K before staging realloc.
            self._check_and_resize_k(N, batch.device, pbc)
            self._alloc_staging_buffers(
                N, B, positions.dtype, batch.device, cell, pbc, batch_ptr
            )
            self._alloc_N = N
            self._alloc_B = B
        # Refresh staging buffers from the current batch.
        self._copy_to_staging_buffers(positions, batch_ptr, batch.batch_idx, cell, pbc)
        # ------------------------------------------------------------------
        # Skin check: decide per-system whether the neighbor list needs
        # rebuilding based on atomic displacement since the last build.
        # Uses the in-place variant to avoid per-step allocation of the
        # rebuild_flags tensor. Falls back to the upstream function if the
        # in-place op is not available (nvalchemiops < 0.4 or custom op not
        # loaded).
        # ------------------------------------------------------------------
        if self.skin > 0.0 and self._ref_positions is not None:
            # inv_ex avoids the error-check sync of torch.linalg.inv.
            cell_inv = (
                torch.linalg.inv_ex(self._buf_cell)[0].contiguous()
                if self._buf_cell is not None
                else None
            )
            if _batch_nl_rebuild_inplace is not None:
                _batch_nl_rebuild_inplace(
                    reference_positions=self._ref_positions,
                    current_positions=self._buf_positions,
                    batch_idx=self._buf_batch_idx,
                    rebuild_flags=self._rebuild_flags,
                    skin_distance_threshold=self.skin / 2,
                    update_reference_positions=True,
                    cell=self._buf_cell,
                    cell_inv=cell_inv,
                    pbc=self._buf_pbc,
                )
            elif _batch_nl_needs_rebuild is not None:
                self._rebuild_flags = _batch_nl_needs_rebuild(
                    reference_positions=self._ref_positions,
                    current_positions=self._buf_positions,
                    batch_idx=self._buf_batch_idx,
                    skin_distance_threshold=self.skin / 2,
                    update_reference_positions=True,
                    cell=self._buf_cell,
                    cell_inv=cell_inv,
                    pbc=self._buf_pbc,
                )
        # ------------------------------------------------------------------
        # Build the neighbor list using pre-allocated buffers. Note the
        # search radius is cutoff + skin so the list stays valid until an
        # atom crosses the skin shell.
        # ------------------------------------------------------------------
        neighbor_list(
            positions=self._buf_positions,
            cutoff=self.config.cutoff + self.skin,
            cell=self._buf_cell,
            pbc=self._buf_pbc,
            max_neighbors=self._max_neighbors,
            half_fill=self.config.half_list,
            batch_ptr=self._buf_batch_ptr,
            batch_idx=self._buf_batch_idx,
            neighbor_matrix=self._neighbor_matrix,
            num_neighbors=self._num_neighbors,
            neighbor_matrix_shifts=self._neighbor_matrix_shifts,
            rebuild_flags=self._rebuild_flags,
            **self._buf_nl_kwargs,
        )
        # ------------------------------------------------------------------
        # Adaptive K: first-build check (runs once, then never again).
        # This is the only per-step adaptive K code. After the first
        # build, all checks are gated on structural events (N/B change)
        # inside _alloc_output_tensors / _alloc_staging_buffers.
        # ------------------------------------------------------------------
        if self._first_build:
            self._first_build = False
            self._actual_max_k = self._num_neighbors.max()
            grew = self._check_and_resize_k(N, batch.device, pbc)
            if grew:
                # K was too small — re-run kernel with larger buffers.
                neighbor_list(
                    positions=self._buf_positions,
                    cutoff=self.config.cutoff + self.skin,
                    cell=self._buf_cell,
                    pbc=self._buf_pbc,
                    max_neighbors=self._max_neighbors,
                    half_fill=self.config.half_list,
                    batch_ptr=self._buf_batch_ptr,
                    batch_idx=self._buf_batch_idx,
                    neighbor_matrix=self._neighbor_matrix,
                    num_neighbors=self._num_neighbors,
                    neighbor_matrix_shifts=self._neighbor_matrix_shifts,
                    rebuild_flags=None,  # Force full rebuild
                    **self._buf_nl_kwargs,
                )
        # ------------------------------------------------------------------
        # Mark Stale Entries: columns at or past each atom's neighbor count
        # are filled with the sentinel index N (and zero shifts) so
        # downstream consumers can mask them out uniformly.
        # ------------------------------------------------------------------
        stale = self._col_range.unsqueeze(0) >= self._num_neighbors.unsqueeze(1)
        self._neighbor_matrix[stale] = batch.num_nodes
        if self._neighbor_matrix_shifts is not None:
            self._neighbor_matrix_shifts[stale] = 0
        # ------------------------------------------------------------------
        # Post-processing: write results to batch (shared with compute_neighbors)
        # ------------------------------------------------------------------
        _write_neighbor_data_to_batch(
            batch=batch,
            neighbor_matrix=self._neighbor_matrix,
            num_neighbors=self._num_neighbors,
            neighbor_matrix_shifts=self._neighbor_matrix_shifts,
            format=NeighborListFormat.COO
            if self._neighbor_list_flag
            else NeighborListFormat.MATRIX,
            cutoff=self.config.cutoff,
        )
# ------------------------------------------------------------------
# Staging buffer management
# ------------------------------------------------------------------
@torch.compiler.disable
def _alloc_output_tensors(
self,
N: int,
batch: "Batch",
pbc: torch.Tensor | None,
) -> None:
"""Allocate neighbor-matrix output tensors for atom count *N*.
Marked ``@torch.compiler.disable`` because it calls
``estimate_max_neighbors`` (CPU work), allocates tensors with
dynamic shapes, and mutates Python attributes — all graph breaks.
Called only when the atom count changes.
"""
device = batch.device
max_nbrs = self._max_neighbors_override
if max_nbrs is None:
max_nbrs = estimate_max_neighbors(
cutoff=self.config.cutoff + self.skin,
)
# Non-PBC hard cap: an atom can see at most (N_system - 1)
# neighbors without periodic images. We use max_num_nodes
# (not max_num_nodes - 1) so that K has one sentinel slot
# to distinguish "all used" from "overflow" in the adaptive check.
# Round up to nearest 16 for memory-aligned kernel performance.
if pbc is None and batch.max_num_nodes > 0:
cap = ((batch.max_num_nodes + 15) // 16) * 16
max_nbrs = min(max_nbrs, cap)
self._max_neighbors = max_nbrs
self._neighbor_matrix = torch.full(
(N, max_nbrs), N, dtype=torch.int32, device=device
)
self._col_range = torch.arange(max_nbrs, device=device, dtype=torch.int32)
self._num_neighbors = torch.zeros(N, dtype=torch.int32, device=device)
if pbc is not None:
self._neighbor_matrix_shifts = torch.zeros(
N, max_nbrs, 3, dtype=torch.int32, device=device
)
# Reset skin-buffer state so __call__ re-initialises _ref_positions.
self._ref_positions = None
self._rebuild_flags = None
# Reset adaptive K state so first build triggers a sync check.
self._first_build = True
self._actual_max_k = None
@torch.compiler.disable
def _check_and_resize_k(
self,
N: int,
device: torch.device,
pbc: torch.Tensor | None,
) -> bool:
"""Sync on actual max K and grow/shrink the neighbor matrix if needed.
Called on structural events (first build, N/B change, cell volume
change). The sync cost is acceptable because these events are
infrequent and the calling code path is already off the compile graph.
Returns ``True`` if K was grown (caller must re-run the kernel).
Shrinking trims the existing buffers in-place — no re-run needed.
"""
if self._actual_max_k is None:
return False
actual = int(self._actual_max_k.item())
if actual >= self._max_neighbors:
# Overflow — grow with 1.5x headroom and round to nearest 16. Must re-run kernel.
self._max_neighbors = ((int(actual * 1.5) + 15) // 16) * 16
self._realloc_k(N, device, pbc)
return True
elif actual < (1 / 2) * self._max_neighbors and actual > 0:
# 2x+ overestimate — trim existing buffers in-place.
new_k = ((int(actual * 2) + 15) // 16) * 16
# Never shrink below the user-provided override — it serves as a
# hard floor. We may grow above it on overflow, but not below.
if self._max_neighbors_override is not None:
new_k = max(new_k, self._max_neighbors_override)
if new_k < self._max_neighbors:
self._max_neighbors = new_k
self._neighbor_matrix = self._neighbor_matrix[:, :new_k].contiguous()
self._col_range = self._col_range[:new_k]
if self._neighbor_matrix_shifts is not None:
self._neighbor_matrix_shifts = self._neighbor_matrix_shifts[
:, :new_k
].contiguous()
# num_neighbors unchanged — still valid.
return False
@torch.compiler.disable
def _realloc_k(
self,
N: int,
device: torch.device,
pbc: torch.Tensor | None,
) -> None:
"""Reallocate neighbor-matrix buffers at the current N with a new K.
Preserves N (no staging-buffer realloc needed) but resets the
skin state to force a full rebuild on the next step.
"""
max_nbrs = self._max_neighbors
self._neighbor_matrix = torch.full(
(N, max_nbrs), N, dtype=torch.int32, device=device
)
self._col_range = torch.arange(max_nbrs, device=device, dtype=torch.int32)
self._num_neighbors = torch.zeros(N, dtype=torch.int32, device=device)
if pbc is not None:
self._neighbor_matrix_shifts = torch.zeros(
N, max_nbrs, 3, dtype=torch.int32, device=device
)
else:
self._neighbor_matrix_shifts = None
# Reset skin state to force a full rebuild.
self._ref_positions = None
self._rebuild_flags = None
@torch.compiler.disable
def _alloc_staging_buffers(
self,
N: int,
B: int,
dtype: torch.dtype,
device: torch.device,
cell: torch.Tensor | None,
pbc: torch.Tensor | None,
batch_ptr: torch.Tensor | None = None,
) -> None:
"""Allocate persistent staging buffers for the current (N, B) shape."""
self._buf_positions = torch.zeros(N, 3, dtype=dtype, device=device)
self._buf_batch_idx = torch.zeros(N, dtype=torch.int32, device=device)
self._buf_batch_ptr = torch.zeros(B + 1, dtype=torch.int32, device=device)
if cell is not None:
self._buf_cell = torch.zeros(B, 3, 3, dtype=dtype, device=device)
self._buf_pbc = torch.zeros(B, 3, dtype=torch.bool, device=device)
else:
self._buf_cell = None
self._buf_pbc = None
# Pre-allocate rebuild_flags as all-True so that the very first step
# (before _ref_positions is set and the skin check runs) forces a full
# neighbor-list build for every system. The in-place op zeroes this
# buffer at the start of each subsequent call before writing fresh values.
self._rebuild_flags = torch.ones(B, dtype=torch.bool, device=device)
# Pre-allocate algorithm-specific kwargs to eliminate on-demand CPU syncs
# from the neighbor_list dispatcher. Use the actual batch_ptr (if provided)
# to compute max_atoms_per_system correctly — the staging buffer is still
# all-zeros at this point and would give max_atoms = 0.
ptr = batch_ptr if batch_ptr is not None else self._buf_batch_ptr
self._alloc_nl_kwargs(N, B, self._buf_positions, ptr, cell, pbc, device, dtype)
def _copy_to_staging_buffers(
self,
positions: torch.Tensor,
batch_ptr: torch.Tensor,
batch_idx: torch.Tensor,
cell: torch.Tensor | None,
pbc: torch.Tensor | None,
) -> None:
"""Refresh staging buffers from the current batch."""
self._buf_positions.copy_(positions)
self._buf_batch_ptr.copy_(batch_ptr)
self._buf_batch_idx.copy_(batch_idx)
if self._buf_cell is not None and cell is not None:
self._buf_cell.copy_(cell)
if self._buf_pbc is not None and pbc is not None:
self._buf_pbc.copy_(pbc)
# ------------------------------------------------------------------
# Algorithm-specific pre-allocation
# ------------------------------------------------------------------
def _alloc_nl_kwargs(
self,
N: int,
B: int,
positions: torch.Tensor,
batch_ptr: torch.Tensor,
cell: torch.Tensor | None,
pbc: torch.Tensor | None,
device: torch.device,
dtype: torch.dtype,
) -> None:
"""Pre-allocate algorithm-specific kwargs to remove CPU-GPU syncs.
The ``neighbor_list`` dispatcher normally infers geometry-dependent
values (shift ranges, cell-list sizes) at call time using ``.item()``
synchronisations. This method computes them **once** when the staging
buffers are allocated (or re-allocated after a shape change) and stores
the resulting tensors in ``_buf_nl_kwargs`` so they can be forwarded as
``**kwargs`` on every ``neighbor_list`` call.
Algorithm selection mirrors the dispatcher threshold:
* ``avg_atoms < 2000`` -> ``batch_naive``
* ``avg_atoms >= 2000`` -> ``batch_cell_list``
Parameters
----------
N, B : int
Total atom count and number of systems at alloc time.
positions : torch.Tensor
Staging buffer for positions (used to estimate bounding box for
non-PBC cell-list systems).
batch_ptr : torch.Tensor
Staging buffer for batch_ptr (used to get max_atoms_per_system).
cell, pbc : torch.Tensor or None
Cell and PBC flag tensors at alloc time.
device, dtype : torch.device, torch.dtype
Allocation target.
"""
self._buf_nl_kwargs = {}
avg_atoms = N // max(B, 1)
use_cell_list = avg_atoms >= 2000
if use_cell_list:
if estimate_batch_cell_list_sizes is None or allocate_cell_list is None:
return # nvalchemiops too old; fall back to dynamic allocation
if cell is not None and pbc is not None:
# PBC: use the actual cell geometry.
alloc_cell = cell.to(dtype).contiguous()
alloc_pbc = pbc
else:
# Non-PBC: synthesise a bounding-box cell from current positions
# with a 1.5x pad so that position drift during the simulation
# doesn't overflow the pre-allocated cell-list arrays.
expanded_idx = self._buf_batch_idx.unsqueeze(1).expand_as(positions)
pos_min = torch.full((B, 3), float("inf"), dtype=dtype, device=device)
pos_min.scatter_reduce_(0, expanded_idx, positions, reduce="amin")
pos_max = torch.full((B, 3), float("-inf"), dtype=dtype, device=device)
pos_max.scatter_reduce_(0, expanded_idx, positions, reduce="amax")
cell_lengths = (pos_max - pos_min) * 1.5 + 0.1 * (
self.config.cutoff + self.skin
)
alloc_cell = torch.diag_embed(cell_lengths) # (B, 3, 3)
alloc_pbc = torch.zeros(B, 3, dtype=torch.bool, device=device)
max_total_cells, neighbor_search_radius = estimate_batch_cell_list_sizes(
alloc_cell, alloc_pbc, self.config.cutoff + self.skin
)
(
cells_per_dimension,
neighbor_search_radius,
atom_periodic_shifts,
atom_to_cell_mapping,
atoms_per_cell_count,
cell_atom_start_indices,
cell_atom_list,
) = allocate_cell_list(
N, int(max_total_cells), neighbor_search_radius, device
)
self._buf_nl_kwargs = {
"cells_per_dimension": cells_per_dimension,
"neighbor_search_radius": neighbor_search_radius,
"atom_periodic_shifts": atom_periodic_shifts,
"atom_to_cell_mapping": atom_to_cell_mapping,
"atoms_per_cell_count": atoms_per_cell_count,
"cell_atom_start_indices": cell_atom_start_indices,
"cell_atom_list": cell_atom_list,
}
else:
# Naive algorithm.
if cell is not None and pbc is not None:
# PBC naive: pre-compute shift-range tensors so the dispatcher
# does not call compute_naive_num_shifts (which has .item()) on
# the hot path.
if compute_naive_num_shifts is None:
return
shift_range, num_shifts, max_shifts = compute_naive_num_shifts(
cell.to(dtype).contiguous(),
self.config.cutoff + self.skin,
pbc,
)
max_atoms = int((batch_ptr[1:] - batch_ptr[:-1]).max().item())
self._buf_nl_kwargs = {
"shift_range_per_dimension": shift_range,
"num_shifts_per_system": num_shifts,
"max_shifts_per_system": max_shifts,
"max_atoms_per_system": max_atoms,
}
# No-PBC naive: no extra kwargs required — the kernel has no
# CPU-sync allocations in this branch.