# Source code for nvalchemiops.torch.neighbors.naive

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch bindings for unbatched naive neighbor list construction."""

from __future__ import annotations

import torch
import warp as wp

from nvalchemiops.neighbors.naive import (
    naive_neighbor_matrix,
    naive_neighbor_matrix_pbc,
)
from nvalchemiops.neighbors.neighbor_utils import (
    estimate_max_neighbors,
    selective_zero_num_neighbors_single,
)
from nvalchemiops.torch.neighbors.neighbor_utils import (
    compute_naive_num_shifts,
    get_neighbor_list_from_neighbor_matrix,
)
from nvalchemiops.torch.types import get_wp_dtype, get_wp_mat_dtype, get_wp_vec_dtype

__all__ = ["naive_neighbor_list"]


@torch.library.custom_op(
    "nvalchemiops::_naive_neighbor_matrix_no_pbc",
    mutates_args=("neighbor_matrix", "num_neighbors"),
)
def _naive_neighbor_matrix_no_pbc(
    positions: torch.Tensor,
    cutoff: float,
    neighbor_matrix: torch.Tensor,
    num_neighbors: torch.Tensor,
    half_fill: bool = False,
    rebuild_flags: torch.Tensor | None = None,
) -> None:
    """Fill neighbor matrix for atoms using naive O(N^2) algorithm.

    Custom PyTorch operator that computes pairwise distances and fills
    the neighbor matrix with atom indices within the cutoff distance.
    No periodic boundary conditions are applied.

    This function does not allocate any tensors, and it does not reset
    ``neighbor_matrix``: callers are expected to pre-fill it with their fill
    value (``naive_neighbor_list`` does this when ``rebuild_flags`` is None).

    This function is torch compilable.

    Parameters
    ----------
    positions : torch.Tensor, shape (total_atoms, 3), dtype=torch.float32 or torch.float64
        Atomic coordinates in Cartesian space. Each row represents one atom's
        (x, y, z) position.
    cutoff : float
        Cutoff distance for neighbor detection in Cartesian units.
        Must be positive. Atoms within this distance are considered neighbors.
    neighbor_matrix : torch.Tensor, shape (total_atoms, max_neighbors), dtype=torch.int32
        OUTPUT: Neighbor matrix to be filled with neighbor atom indices.
        Must be pre-allocated. Entries are filled with atom indices.
    num_neighbors : torch.Tensor, shape (total_atoms,), dtype=torch.int32
        OUTPUT: Number of neighbors found for each atom.
        Must be pre-allocated. Updated in-place with actual neighbor counts.
    half_fill : bool
        If True, only store relationships where i < j to avoid double counting.
        If False, store all neighbor relationships symmetrically.
    rebuild_flags : torch.Tensor, shape (1,), dtype=torch.bool, optional
        Rebuild flag for the single system. When provided, this function
        itself calls ``selective_zero_num_neighbors_single`` to reset
        ``num_neighbors`` on the device according to the flag, then forwards
        the flag to the warp launcher so unflagged work is skipped on the GPU
        without a CPU-GPU sync. Callers do not need to reset
        ``num_neighbors`` beforehand in this case.

    See Also
    --------
    nvalchemiops.neighbors.naive.naive_neighbor_matrix : Core warp launcher
    naive_neighbor_list : High-level wrapper function
    """
    device = positions.device
    wp_dtype = get_wp_dtype(positions.dtype)
    wp_vec_dtype = get_wp_vec_dtype(positions.dtype)

    wp_positions = wp.from_torch(positions, dtype=wp_vec_dtype, return_ctype=True)
    wp_neighbor_matrix = wp.from_torch(
        neighbor_matrix, dtype=wp.int32, return_ctype=True
    )
    wp_num_neighbors = wp.from_torch(num_neighbors, dtype=wp.int32, return_ctype=True)
    if rebuild_flags is not None:
        wp_rebuild_flags = wp.from_torch(
            rebuild_flags, dtype=wp.bool, return_ctype=True
        )
        # Reset counts on the device (no host sync) before the kernel
        # accumulates into them.
        selective_zero_num_neighbors_single(
            wp_num_neighbors, wp_rebuild_flags, str(device)
        )
    else:
        wp_rebuild_flags = None

    naive_neighbor_matrix(
        positions=wp_positions,
        cutoff=cutoff,
        neighbor_matrix=wp_neighbor_matrix,
        num_neighbors=wp_num_neighbors,
        wp_dtype=wp_dtype,
        device=str(device),
        half_fill=half_fill,
        rebuild_flags=wp_rebuild_flags,
    )


@torch.library.custom_op(
    "nvalchemiops::_naive_neighbor_matrix_pbc",
    mutates_args=("neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"),
)
def _naive_neighbor_matrix_pbc(
    positions: torch.Tensor,
    cutoff: float,
    cell: torch.Tensor,
    neighbor_matrix: torch.Tensor,
    neighbor_matrix_shifts: torch.Tensor,
    num_neighbors: torch.Tensor,
    shift_range_per_dimension: torch.Tensor,
    num_shifts_per_system: torch.Tensor,
    max_shifts_per_system: int,
    half_fill: bool = False,
    rebuild_flags: torch.Tensor | None = None,
    wrap_positions: bool = True,
) -> None:
    """Fill a periodic neighbor matrix with the naive O(N^2) algorithm.

    Wraps the torch inputs as warp arrays and dispatches to the warp
    launcher ``naive_neighbor_matrix_pbc``. No tensors are allocated here,
    and ``neighbor_matrix`` / ``neighbor_matrix_shifts`` are not reset by
    this function.

    This function is torch compilable.

    Parameters
    ----------
    positions : torch.Tensor, shape (total_atoms, 3)
        Atomic coordinates in Cartesian space.
    cutoff : float
        Cutoff distance for neighbor detection.
    cell : torch.Tensor, shape (1, 3, 3)
        Cell matrix defining lattice vectors.
    neighbor_matrix : torch.Tensor, shape (total_atoms, max_neighbors), dtype=torch.int32
        OUTPUT: Neighbor matrix to be filled.
    neighbor_matrix_shifts : torch.Tensor, shape (total_atoms, max_neighbors, 3), dtype=torch.int32
        OUTPUT: Integer lattice shift for each stored neighbor.
    num_neighbors : torch.Tensor, shape (total_atoms,), dtype=torch.int32
        OUTPUT: Number of neighbors found for each atom.
    shift_range_per_dimension : torch.Tensor, shape (1, 3), dtype=torch.int32
        Shift range in each dimension.
    num_shifts_per_system : torch.Tensor, shape (1,), dtype=torch.int32
        Number of periodic shifts for the system. Part of the operator
        schema but not forwarded to the warp launcher; only
        ``max_shifts_per_system`` is used.
    max_shifts_per_system : int
        Maximum shift count, passed to the launcher as ``num_shifts``.
    half_fill : bool, optional
        If True, only store relationships where i < j. Default is False.
    rebuild_flags : torch.Tensor, shape (1,), dtype=torch.bool, optional
        When given, ``num_neighbors`` is reset on the device via
        ``selective_zero_num_neighbors_single`` and the flag is forwarded to
        the launcher so the kernel can skip work with no CPU-GPU sync.
    wrap_positions : bool, default=True
        If True, wrap positions into the primary cell before neighbor search.

    See Also
    --------
    nvalchemiops.neighbors.naive.naive_neighbor_matrix_pbc : Core warp launcher
    naive_neighbor_list : High-level wrapper function
    """
    target = str(positions.device)
    scalar_type = get_wp_dtype(positions.dtype)
    point_type = get_wp_vec_dtype(positions.dtype)
    matrix_type = get_wp_mat_dtype(cell.dtype)

    def to_warp(tensor: torch.Tensor, warp_type):
        # Wrap a torch tensor as a warp ctype array of the requested type.
        return wp.from_torch(tensor, dtype=warp_type, return_ctype=True)

    coords = to_warp(positions, point_type)
    lattice = to_warp(cell, matrix_type)
    shift_bounds = to_warp(shift_range_per_dimension, wp.vec3i)
    index_out = to_warp(neighbor_matrix, wp.int32)
    shift_out = to_warp(neighbor_matrix_shifts, wp.vec3i)
    count_out = to_warp(num_neighbors, wp.int32)

    flag_view = None
    if rebuild_flags is not None:
        flag_view = to_warp(rebuild_flags, wp.bool)
        # Device-side reset of the counts; avoids a host round-trip.
        selective_zero_num_neighbors_single(count_out, flag_view, target)

    naive_neighbor_matrix_pbc(
        positions=coords,
        cutoff=cutoff,
        cell=lattice,
        shift_range=shift_bounds,
        num_shifts=max_shifts_per_system,
        neighbor_matrix=index_out,
        neighbor_matrix_shifts=shift_out,
        num_neighbors=count_out,
        wp_dtype=scalar_type,
        device=target,
        half_fill=half_fill,
        rebuild_flags=flag_view,
        wrap_positions=wrap_positions,
    )


[docs] def naive_neighbor_list( positions: torch.Tensor, cutoff: float, cell: torch.Tensor | None = None, pbc: torch.Tensor | None = None, max_neighbors: int | None = None, half_fill: bool = False, fill_value: int | None = None, return_neighbor_list: bool = False, neighbor_matrix: torch.Tensor | None = None, neighbor_matrix_shifts: torch.Tensor | None = None, num_neighbors: torch.Tensor | None = None, shift_range_per_dimension: torch.Tensor | None = None, num_shifts_per_system: torch.Tensor | None = None, max_shifts_per_system: int | None = None, rebuild_flags: torch.Tensor | None = None, wrap_positions: bool = True, ) -> ( tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor] ): """Compute neighbor list using naive O(N^2) algorithm. Identifies all atom pairs within a specified cutoff distance using a brute-force pairwise distance calculation. Supports both non-periodic and periodic boundary conditions. For non-pbc systems, this function is torch compilable. For pbc systems, precompute the shift metadata using compute_naive_num_shifts. Parameters ---------- positions : torch.Tensor, shape (total_atoms, 3), dtype=torch.float32 or torch.float64 Atomic coordinates in Cartesian space. Each row represents one atom's (x, y, z) position. cutoff : float Cutoff distance for neighbor detection in Cartesian units. Must be positive. Atoms within this distance are considered neighbors. pbc : torch.Tensor, shape (1, 3), dtype=torch.bool, optional Periodic boundary condition flags for each dimension. True enables periodicity in that direction. Default is None (no PBC). cell : torch.Tensor, shape (1, 3, 3), dtype=torch.float32 or torch.float64, optional Cell matrices defining lattice vectors in Cartesian coordinates. Required if pbc is provided. Default is None. max_neighbors : int, optional Maximum number of neighbors per atom. Must be positive. If exceeded, excess neighbors are ignored. 
Must be provided if neighbor_matrix is not provided. half_fill : bool, optional If True, only store relationships where i < j to avoid double counting. If False, store all neighbor relationships symmetrically. Default is False. fill_value : int, optional Value to fill the neighbor matrix with. Default is total_atoms. neighbor_matrix : torch.Tensor, shape (total_atoms, max_neighbors), dtype=torch.int32, optional Neighbor matrix to be filled. Pass in a pre-allocated tensor to avoid reallocation. Must be provided if max_neighbors is not provided. neighbor_matrix_shifts : torch.Tensor, shape (total_atoms, max_neighbors, 3), dtype=torch.int32, optional Shift vectors for each neighbor relationship. Pass in a pre-allocated tensor to avoid reallocation. Must be provided if max_neighbors is not provided. num_neighbors : torch.Tensor, shape (total_atoms,), dtype=torch.int32, optional Number of neighbors found for each atom. Pass in a pre-allocated tensor to avoid reallocation. Must be provided if max_neighbors is not provided. shift_range_per_dimension : torch.Tensor, shape (1, 3), dtype=torch.int32, optional Shift range in each dimension for each system. Pass in a pre-allocated tensor to avoid reallocation for pbc systems. num_shifts_per_system : torch.Tensor, shape (1,), dtype=torch.int32, optional Number of periodic shifts for the system. Pass in to avoid recomputation for pbc systems. max_shifts_per_system : int, optional Maximum shift count across all systems. Pass in to avoid recomputation for pbc systems. rebuild_flags : torch.Tensor, shape () or (1,), dtype=torch.bool, optional If provided, controls whether the neighbor list is recomputed. When the flag is False the existing ``neighbor_matrix``, ``num_neighbors``, and ``neighbor_matrix_shifts`` tensors are returned unchanged and all kernel launches are skipped. When the flag is True (or when this argument is None) the neighbor list is recomputed as normal. 
Note: providing this argument disables torch.compile compatibility. wrap_positions : bool, default=True If True, wrap input positions into the primary cell before neighbor search. Set to False when positions are already wrapped (e.g. by a preceding integration step) to save two GPU kernel launches per call. return_neighbor_list : bool, optional - default = False If True, convert the neighbor matrix to a neighbor list (idx_i, idx_j) format by creating a mask over the fill_value, which can incur a performance penalty. We recommend using the neighbor matrix format, and only convert to a neighbor list format if absolutely necessary. Returns ------- results : tuple of torch.Tensor Variable-length tuple depending on input parameters. The return pattern follows: - No PBC, matrix format: ``(neighbor_matrix, num_neighbors)`` - No PBC, list format: ``(neighbor_list, neighbor_ptr)`` - With PBC, matrix format: ``(neighbor_matrix, num_neighbors, neighbor_matrix_shifts)`` - With PBC, list format: ``(neighbor_list, neighbor_ptr, neighbor_list_shifts)`` **Components returned:** - **neighbor_data** (tensor): Neighbor indices, format depends on ``return_neighbor_list``: * If ``return_neighbor_list=False`` (default): Returns ``neighbor_matrix`` with shape (total_atoms, max_neighbors), dtype int32. Each row i contains indices of atom i's neighbors. * If ``return_neighbor_list=True``: Returns ``neighbor_list`` with shape (2, num_pairs), dtype int32, in COO format [source_atoms, target_atoms]. - **num_neighbor_data** (tensor): Information about the number of neighbors for each atom, format depends on ``return_neighbor_list``: * If ``return_neighbor_list=False`` (default): Returns ``num_neighbors`` with shape (total_atoms,), dtype int32. Count of neighbors found for each atom. Always returned. * If ``return_neighbor_list=True``: Returns ``neighbor_ptr`` with shape (total_atoms + 1,), dtype int32. 
CSR-style pointer arrays where ``neighbor_ptr_data[i]`` to ``neighbor_ptr_data[i+1]`` gives the range of neighbors for atom i in the flattened neighbor list. - **neighbor_shift_data** (tensor, optional): Periodic shift vectors, only when ``pbc`` is provided: format depends on ``return_neighbor_list``: * If ``return_neighbor_list=False`` (default): Returns ``neighbor_matrix_shifts`` with shape (total_atoms, max_neighbors, 3), dtype int32. * If ``return_neighbor_list=True``: Returns ``unit_shifts`` with shape (num_pairs, 3), dtype int32. Examples -------- Basic usage without periodic boundary conditions: >>> import torch >>> positions = torch.rand(100, 3) * 10.0 # 100 atoms in 10x10x10 box >>> cutoff = 2.5 >>> max_neighbors = 50 >>> neighbor_matrix, num_neighbors = naive_neighbor_list( ... positions, cutoff, max_neighbors ... ) >>> print(f"Found {num_neighbors.sum()} total neighbor pairs") With periodic boundary conditions: >>> cell = torch.eye(3).unsqueeze(0) * 10.0 # 10x10x10 cubic cell >>> pbc = torch.tensor([[True, True, True]]) # Periodic in all directions >>> neighbor_matrix, num_neighbors, shifts = naive_neighbor_list( ... positions, cutoff, max_neighbors, pbc=pbc, cell=cell ... ) Return as neighbor list instead of matrix: >>> neighbor_list, neighbor_ptr = naive_neighbor_list( ... positions, cutoff, max_neighbors, return_neighbor_list=True ... 
) >>> source_atoms, target_atoms = neighbor_list[0], neighbor_list[1] See Also -------- nvalchemiops.neighbors.naive.naive_neighbor_matrix : Core warp launcher (no PBC) nvalchemiops.neighbors.naive.naive_neighbor_matrix_pbc : Core warp launcher (with PBC) cell_list : O(N) cell list method for larger systems """ if pbc is None and cell is not None: raise ValueError("If cell is provided, pbc must also be provided") if pbc is not None and cell is None: raise ValueError("If pbc is provided, cell must also be provided") if cell is not None: cell = cell if cell.ndim == 3 else cell.unsqueeze(0) if pbc is not None: pbc = pbc if pbc.ndim == 2 else pbc.unsqueeze(0) if max_neighbors is None and ( neighbor_matrix is None or (neighbor_matrix_shifts is None and pbc is not None) or num_neighbors is None ): max_neighbors = estimate_max_neighbors(cutoff) if fill_value is None: fill_value = positions.shape[0] if neighbor_matrix is None: neighbor_matrix = torch.full( (positions.shape[0], max_neighbors), fill_value, dtype=torch.int32, device=positions.device, ) elif rebuild_flags is None: neighbor_matrix.fill_(fill_value) if num_neighbors is None: num_neighbors = torch.zeros( positions.shape[0], dtype=torch.int32, device=positions.device ) elif rebuild_flags is None: num_neighbors.zero_() if pbc is not None: if neighbor_matrix_shifts is None: neighbor_matrix_shifts = torch.zeros( (positions.shape[0], max_neighbors, 3), dtype=torch.int32, device=positions.device, ) elif rebuild_flags is None: neighbor_matrix_shifts.zero_() if ( max_shifts_per_system is None or num_shifts_per_system is None or shift_range_per_dimension is None ): shift_range_per_dimension, num_shifts_per_system, max_shifts_per_system = ( compute_naive_num_shifts(cell, cutoff, pbc) ) if cutoff <= 0: if return_neighbor_list: if pbc is not None: return ( torch.zeros((2, 0), dtype=torch.int32, device=positions.device), torch.zeros( (positions.shape[0] + 1,), dtype=torch.int32, device=positions.device, ), torch.zeros((0, 3), 
dtype=torch.int32, device=positions.device), ) else: return ( torch.zeros((2, 0), dtype=torch.int32, device=positions.device), torch.zeros( (positions.shape[0] + 1,), dtype=torch.int32, device=positions.device, ), ) else: if pbc is not None: return neighbor_matrix, num_neighbors, neighbor_matrix_shifts else: return neighbor_matrix, num_neighbors if pbc is None: _naive_neighbor_matrix_no_pbc( positions=positions, cutoff=cutoff, neighbor_matrix=neighbor_matrix, num_neighbors=num_neighbors, half_fill=half_fill, rebuild_flags=rebuild_flags, ) if return_neighbor_list: neighbor_list, neighbor_ptr = get_neighbor_list_from_neighbor_matrix( neighbor_matrix, num_neighbors=num_neighbors, fill_value=fill_value, ) return neighbor_list, neighbor_ptr else: return neighbor_matrix, num_neighbors else: _naive_neighbor_matrix_pbc( positions=positions, cutoff=cutoff, cell=cell, neighbor_matrix=neighbor_matrix, neighbor_matrix_shifts=neighbor_matrix_shifts, num_neighbors=num_neighbors, shift_range_per_dimension=shift_range_per_dimension, num_shifts_per_system=num_shifts_per_system, max_shifts_per_system=max_shifts_per_system, half_fill=half_fill, rebuild_flags=rebuild_flags, wrap_positions=wrap_positions, ) if return_neighbor_list: neighbor_list, neighbor_ptr, neighbor_list_shifts = ( get_neighbor_list_from_neighbor_matrix( neighbor_matrix, num_neighbors=num_neighbors, neighbor_shift_matrix=neighbor_matrix_shifts, fill_value=fill_value, ) ) return neighbor_list, neighbor_ptr, neighbor_list_shifts else: return neighbor_matrix, num_neighbors, neighbor_matrix_shifts