# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
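"""PyTorch neighbor list API.

This module provides :func:`neighbor_list`, the main entry point for PyTorch
users of the neighbor list API, and re-exports the underlying batched and
unbatched neighbor-search functions.
"""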
from __future__ import annotations
import torch
# Batch cell list functions
from nvalchemiops.torch.neighbors.batch_cell_list import (
batch_cell_list,
estimate_batch_cell_list_sizes,
)
# Batch naive functions
from nvalchemiops.torch.neighbors.batch_naive import (
batch_naive_neighbor_list,
)
# Batch naive dual cutoff functions
from nvalchemiops.torch.neighbors.batch_naive_dual_cutoff import (
batch_naive_neighbor_list_dual_cutoff,
)
# Unbatched cell list functions
from nvalchemiops.torch.neighbors.cell_list import (
cell_list,
estimate_cell_list_sizes,
)
# Unbatched naive functions
from nvalchemiops.torch.neighbors.naive import (
naive_neighbor_list,
)
# Unbatched naive dual cutoff functions
from nvalchemiops.torch.neighbors.naive_dual_cutoff import (
naive_neighbor_list_dual_cutoff,
)
# Utility functions
from nvalchemiops.torch.neighbors.neighbor_utils import prepare_batch_idx_ptr
# High-level API
def neighbor_list(
positions: torch.Tensor,
cutoff: float,
cell: torch.Tensor | None = None,
pbc: torch.Tensor | None = None,
batch_idx: torch.Tensor | None = None,
batch_ptr: torch.Tensor | None = None,
cutoff2: float | None = None,
half_fill: bool = False,
fill_value: int | None = None,
return_neighbor_list: bool = False,
method: str | None = None,
wrap_positions: bool = True,
**kwargs: dict,
):
"""Compute neighbor list using the appropriate method based on the provided parameters.
This is the main entry point for PyTorch users of the neighbor list API. It automatically
selects the most appropriate algorithm (naive O(N²) or cell list O(N)) based on system
size and parameters.
Parameters
----------
positions : torch.Tensor, shape (total_atoms, 3)
Concatenated atomic coordinates for all systems in Cartesian space.
Each row represents one atom's (x, y, z) position.
Unwrapped (box-crossing) coordinates are supported when PBC is used;
the kernel wraps positions internally.
cutoff : float
Cutoff distance for neighbor detection in Cartesian units.
Must be positive. Atoms within this distance are considered neighbors.
cell : torch.Tensor, shape (3, 3) or (num_systems, 3, 3), optional
Cell matrix defining the simulation box.
pbc : torch.Tensor, shape (3,) or (num_systems, 3), dtype=torch.bool, optional
Periodic boundary condition flags for each dimension.
batch_idx : torch.Tensor, shape (total_atoms,), dtype=torch.int32, optional
System index for each atom.
batch_ptr : torch.Tensor, shape (num_systems + 1,), dtype=torch.int32, optional
Cumulative atom counts defining system boundaries.
cutoff2 : float, optional
Second cutoff distance for neighbor detection in Cartesian units.
Must be positive. Atoms within this distance are considered neighbors.
half_fill : bool, optional
If True, only store half of the neighbor relationships to avoid double counting.
Another half could be reconstructed by swapping source and target indices and inverting unit shifts.
fill_value : int | None, optional
Value to fill the neighbor matrix with. Default is total_atoms.
return_neighbor_list : bool, optional - default = False
If True, convert the neighbor matrix to a neighbor list (idx_i, idx_j) format by
creating a mask over the fill_value, which can incur a performance penalty.
We recommend using the neighbor matrix format,
and only convert to a neighbor list format if absolutely necessary.
method : str | None, optional
Method to use for neighbor list computation.
Choices: "naive", "cell_list", "batch_naive", "batch_cell_list", "naive_dual_cutoff", "batch_naive_dual_cutoff".
If None, a default method will be chosen based on average atoms per
system (cell_list when >= 2000, naive otherwise). When only
``batch_idx`` is provided (no ``batch_ptr`` or 3-D ``cell``),
auto-selection reads ``batch_idx[-1]`` which triggers a
device-to-host synchronization. To avoid this, pass ``batch_ptr``,
a 3-D ``cell`` array, or specify ``method`` explicitly.
wrap_positions : bool, default=True
If True, wrap input positions into the primary cell before
neighbor search. Set to False when positions are already
wrapped (e.g. by a preceding integration step) to save two
GPU kernel launches per call. Only applies to naive methods; cell list
methods handle wrapping internally.
**kwargs : dict, optional
Additional keyword arguments to pass to the method.
max_neighbors : int, optional
Maximum number of neighbors per atom.
Can be provided to aid in allocation for both naive and cell list methods.
max_neighbors2 : int, optional
Maximum number of neighbors per atom within cutoff2.
Can be provided to aid in allocation for naive dual cutoff method.
neighbor_matrix : torch.Tensor, optional
Pre-allocated tensor of shape (total_atoms, max_neighbors) for neighbor indices.
Can be provided to avoid reallocation for both naive and cell list methods.
neighbor_matrix_shifts : torch.Tensor, optional
Pre-allocated tensor of shape (total_atoms, max_neighbors, 3) for shift vectors.
Can be provided to avoid reallocation for both naive and cell list methods.
num_neighbors : torch.Tensor, optional
Pre-allocated tensor of shape (total_atoms,) for neighbor counts.
Can be provided to avoid reallocation for both naive and cell list methods.
shift_range_per_dimension : torch.Tensor, optional
Pre-allocated tensor of shape (1, 3) for shift range in each dimension.
Can be provided to avoid reallocation for naive methods.
num_shifts_per_system : torch.Tensor, optional
Pre-computed tensor of shape (num_systems,) for the number of periodic
shifts per system. Can be provided to avoid recomputation for naive methods.
max_shifts_per_system : int, optional
Maximum per-system shift count.
Can be provided to avoid recomputation for naive methods.
cells_per_dimension : torch.Tensor, optional
Pre-allocated tensor of shape (3,) for number of cells in x, y, z directions.
Can be provided to avoid reallocation for cell list construction.
neighbor_search_radius : torch.Tensor, optional
Pre-allocated tensor of shape (3,) for radius of neighboring cells to search
in each dimension. Can be provided to avoid reallocation for cell list construction.
atom_periodic_shifts : torch.Tensor, optional
Pre-allocated tensor of shape (total_atoms, 3) for periodic boundary crossings
for each atom. Can be provided to avoid reallocation for cell list construction.
atom_to_cell_mapping : torch.Tensor, optional
Pre-allocated tensor of shape (total_atoms, 3) for cell coordinates for each atom.
Can be provided to avoid reallocation for cell list construction.
atoms_per_cell_count : torch.Tensor, optional
Pre-allocated tensor of shape (max_total_cells,) for number of atoms in each cell.
Can be provided to avoid reallocation for cell list construction.
cell_atom_start_indices : torch.Tensor, optional
Pre-allocated tensor of shape (max_total_cells,) for starting index in
cell_atom_list for each cell. Can be provided to avoid reallocation for
cell list construction.
cell_atom_list : torch.Tensor, optional
Pre-allocated tensor of shape (total_atoms,) for flattened list of atom
indices organized by cell. Can be provided to avoid reallocation for
cell list construction.
max_atoms_per_system : int, optional
Maximum number of atoms per system. Used in batch naive implementation
with PBC. If not provided, it will be computed automatically.
Can be provided to avoid CUDA synchronization.
Returns
-------
results : tuple of torch.Tensor
Variable-length tuple depending on input parameters. The return pattern follows:
**Single cutoff:**
- No PBC, matrix format: ``(neighbor_matrix, num_neighbors)``
- No PBC, list format: ``(neighbor_list, neighbor_ptr)``
- With PBC, matrix format: ``(neighbor_matrix, num_neighbors, neighbor_matrix_shifts)``
- With PBC, list format: ``(neighbor_list, neighbor_ptr, neighbor_list_shifts)``
**Dual cutoff:**
- No PBC, matrix format: ``(neighbor_matrix1, num_neighbors1, neighbor_matrix2, num_neighbors2)``
- No PBC, list format: ``(neighbor_list1, neighbor_ptr1, neighbor_list2, neighbor_ptr2)``
- With PBC, matrix format: ``(neighbor_matrix1, num_neighbors1, neighbor_matrix_shifts1, neighbor_matrix2, num_neighbors2, neighbor_matrix_shifts2)``
- With PBC, list format: ``(neighbor_list1, neighbor_ptr1, neighbor_list_shifts1, neighbor_list2, neighbor_ptr2, neighbor_list_shifts2)``
**Components returned:**
- **neighbor_data** (tensor): Neighbor indices, format depends on ``return_neighbor_list``:
- If ``return_neighbor_list=False`` (default): Returns ``neighbor_matrix``
with shape (total_atoms, max_neighbors), dtype int32. Each row i contains
indices of atom i's neighbors.
- If ``return_neighbor_list=True``: Returns ``neighbor_list`` with shape
(2, num_pairs), dtype int32, in COO format [source_atoms, target_atoms].
- **num_neighbor_data** (tensor): Information about the number of neighbors for each atom,
format depends on ``return_neighbor_list``:
- If ``return_neighbor_list=False`` (default): Returns ``num_neighbors`` with shape (total_atoms,), dtype int32.
Count of neighbors found for each atom.
- If ``return_neighbor_list=True``: Returns ``neighbor_ptr`` with shape (total_atoms + 1,), dtype int32.
CSR-style pointer arrays where ``neighbor_ptr_data[i]`` to ``neighbor_ptr_data[i+1]`` gives the range of
neighbors for atom i in the flattened neighbor list.
- **neighbor_shift_data** (tensor, optional): Periodic shift vectors, only when ``pbc`` is provided:
format depends on ``return_neighbor_list``:
- If ``return_neighbor_list=False`` (default): Returns ``neighbor_matrix_shifts`` with
shape (total_atoms, max_neighbors, 3), dtype int32.
- If ``return_neighbor_list=True``: Returns ``unit_shifts`` with shape
(num_pairs, 3), dtype int32.
When ``cutoff2`` is provided, the pattern repeats for the second cutoff with interleaved
components (neighbor_data2, num_neighbor_data2, neighbor_shift_data2) appended to the tuple.
Examples
--------
Single cutoff, matrix format, with PBC::
>>> nm, num, shifts = neighbor_list(pos, 5.0, cell=cell, pbc=pbc)
Single cutoff, list format, no PBC::
>>> nlist, ptr = neighbor_list(pos, 5.0, return_neighbor_list=True)
Dual cutoff, matrix format, with PBC::
>>> nm1, num1, sh1, nm2, num2, sh2 = neighbor_list(
... pos, 2.5, cutoff2=5.0, cell=cell, pbc=pbc
... )
See Also
--------
naive_neighbor_list : Direct access to naive O(N²) algorithm
cell_list : Direct access to cell list O(N) algorithm
batch_naive_neighbor_list : Batched naive algorithm
batch_cell_list : Batched cell list algorithm
"""
if cell is not None and pbc is None:
raise ValueError(
"`pbc` is required when `cell` is provided. "
"Pass a boolean tensor of shape (3,) or (num_systems, 3), "
"e.g. pbc=torch.tensor([True, True, True])."
)
if method is None:
total_atoms = positions.shape[0]
# Compute average atoms per system for method selection.
num_systems = 1
if cell is not None and cell.ndim == 3:
# cell shape is (num_systems, 3, 3)
num_systems = cell.shape[0]
elif batch_ptr is not None:
# batch_ptr shape is (num_systems + 1,)
num_systems = batch_ptr.shape[0] - 1
elif batch_idx is not None:
# NOTE: reading batch_idx[-1] triggers a GPU-to-CPU sync
# assume sorted batch_idx
num_systems = max(1, batch_idx[-1].item() + 1)
avg_atoms = total_atoms // num_systems
if cutoff2 is not None:
method = "naive_dual_cutoff"
elif avg_atoms >= 2000:
method = "cell_list"
else:
method = "naive"
if batch_idx is not None or batch_ptr is not None:
method = "batch_" + method
batch_idx, batch_ptr = prepare_batch_idx_ptr(
batch_idx, batch_ptr, total_atoms, positions.device
)
match method:
case "naive":
return naive_neighbor_list(
positions,
cutoff,
pbc=pbc,
cell=cell,
half_fill=half_fill,
fill_value=fill_value,
return_neighbor_list=return_neighbor_list,
wrap_positions=wrap_positions,
**kwargs,
)
case "cell_list":
if cell is None:
pos_min = positions.min(dim=0).values
positions = positions - pos_min
pos_max = positions.max(dim=0).values
cell_lengths = pos_max + 0.1 * cutoff
cell = torch.diag(cell_lengths).reshape(1, 3, 3)
pbc = torch.tensor(
[False, False, False], dtype=torch.bool, device=positions.device
)
return cell_list(
positions,
cutoff,
cell,
pbc,
half_fill=half_fill,
fill_value=fill_value,
return_neighbor_list=return_neighbor_list,
**kwargs,
)
case "batch_naive":
return batch_naive_neighbor_list(
positions,
cutoff,
pbc=pbc,
cell=cell,
batch_idx=batch_idx,
batch_ptr=batch_ptr,
half_fill=half_fill,
fill_value=fill_value,
return_neighbor_list=return_neighbor_list,
wrap_positions=wrap_positions,
**kwargs,
)
case "batch_cell_list":
if batch_idx is None or batch_ptr is None:
batch_idx, batch_ptr = prepare_batch_idx_ptr(
batch_idx, batch_ptr, positions.shape[0], positions.device
)
if cell is None:
num_systems = batch_ptr.shape[0] - 1
expanded_idx = batch_idx.unsqueeze(1).expand_as(positions)
pos_min = torch.full(
(num_systems, 3),
float("inf"),
dtype=positions.dtype,
device=positions.device,
)
pos_min.scatter_reduce_(0, expanded_idx, positions, reduce="amin")
pos_max = torch.full(
(num_systems, 3),
float("-inf"),
dtype=positions.dtype,
device=positions.device,
)
pos_max.scatter_reduce_(0, expanded_idx, positions, reduce="amax")
# TODO: switch to segment_ops once #17 is merged
positions = positions - torch.index_select(pos_min, 0, batch_idx)
cell_lengths = pos_max - pos_min + 0.1 * cutoff
cell = torch.diag_embed(cell_lengths)
pbc = torch.zeros(
(num_systems, 3),
dtype=torch.bool,
device=positions.device,
)
return batch_cell_list(
positions,
cutoff,
cell,
pbc,
batch_idx,
half_fill=half_fill,
fill_value=fill_value,
return_neighbor_list=return_neighbor_list,
**kwargs,
)
case "naive_dual_cutoff":
return naive_neighbor_list_dual_cutoff(
positions,
cutoff,
cutoff2,
pbc=pbc,
cell=cell,
half_fill=half_fill,
fill_value=fill_value,
return_neighbor_list=return_neighbor_list,
wrap_positions=wrap_positions,
**kwargs,
)
case "batch_naive_dual_cutoff":
return batch_naive_neighbor_list_dual_cutoff(
positions,
cutoff,
cutoff2,
pbc=pbc,
cell=cell,
batch_idx=batch_idx,
batch_ptr=batch_ptr,
half_fill=half_fill,
fill_value=fill_value,
return_neighbor_list=return_neighbor_list,
wrap_positions=wrap_positions,
**kwargs,
)
case _:
raise ValueError(f"Invalid method: {method}")
# Names exported by ``from <this module> import *``. Besides the dispatcher
# defined in this file, the list re-exports the algorithm and size-estimation
# functions imported at the top of the file. ``prepare_batch_idx_ptr`` is
# imported for internal use and is not listed here.
__all__ = [
    # High-level API
    "neighbor_list",
    # Unbatched algorithms
    "cell_list",
    "naive_neighbor_list",
    "naive_neighbor_list_dual_cutoff",
    "estimate_cell_list_sizes",
    # Batched algorithms
    "batch_cell_list",
    "batch_naive_neighbor_list",
    "batch_naive_neighbor_list_dual_cutoff",
    "estimate_batch_cell_list_sizes",
]