Source code for nvalchemiops.neighborlist.neighbor_utils

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# type: ignore

import math
from typing import Any

import torch
import warp as wp

from nvalchemiops.types import get_wp_dtype, get_wp_mat_dtype


@wp.kernel(enable_backward=False)
def _expand_naive_shifts(
    shift_range: wp.array(dtype=wp.vec3i),
    shift_offset: wp.array(dtype=int),
    shifts: wp.array(dtype=wp.vec3i),
    shift_system_idx: wp.array(dtype=int),
) -> None:
    """Expand shift ranges into actual shift vectors for all systems in the batch.

    Converts the compact shift range representation into a flattened array
    of explicit shift vectors. Only the lexicographically non-negative half
    of the shift box is emitted (see Notes). Therefore a shift ``s`` and its
    opposite ``-s`` are never both present, which avoids double counting of
    periodic images.

    Parameters
    ----------
    shift_range : wp.array, shape (num_systems, 3), dtype=wp.vec3i
        Array of shift ranges in each dimension for each system.
    shift_offset : wp.array, shape (num_systems+1,), dtype=wp.int32
        Cumulative sum of number of shifts for each system.
    shifts : wp.array, shape (total_shifts, 3), dtype=wp.vec3i
        OUTPUT: Flattened array to store the shift vectors.
    shift_system_idx : wp.array, shape (total_shifts,), dtype=wp.int32
        OUTPUT: System index mapping for each shift vector.

    Notes
    -----
    - Thread launch: One thread per system in the batch (dim=num_systems)
    - Modifies: shifts, shift_system_idx
    - total_shifts = shift_offset[-1]
    - Each thread writes its system's shifts contiguously, starting at
      ``shift_offset[tid]``. No bounds check is performed, so
      ``shift_offset[tid + 1] - shift_offset[tid]`` must equal the number of
      shifts this thread emits (given below).
    - A shift (k0, k1, k2) with k0 in [0, s0], k1 in [-s1, s1] and
      k2 in [-s2, s2] is kept iff it is lexicographically >= (0, 0, 0).
      That means k0 > 0, or k0 == 0 and k1 > 0, or k0 == k1 == 0 and k2 >= 0.
      The zero shift is included.
    - The resulting count per system is
      s0*(2*s1+1)*(2*s2+1) + s1*(2*s2+1) + s2 + 1. This is the value
      ``_compute_naive_num_shifts`` stores in ``num_shifts``.
    - Shift vectors generated in lexicographic order of (k0, k1, k2) (increasing)
    - All shift vectors are integer lattice coordinates
    """
    tid = wp.tid()
    # First output slot reserved for this system.
    pos = shift_offset[tid]
    _shift_range = shift_range[tid]
    # k0 only spans 0..s0: shifts with k0 < 0 are the negatives of kept ones.
    for k0 in range(0, _shift_range[0] + 1):
        for k1 in range(-_shift_range[1], _shift_range[1] + 1):
            for k2 in range(-_shift_range[2], _shift_range[2] + 1):
                # Keep only the lexicographically non-negative half (incl. zero).
                if k0 > 0 or (k0 == 0 and k1 > 0) or (k0 == 0 and k1 == 0 and k2 >= 0):
                    shifts[pos] = wp.vec3i(k0, k1, k2)
                    shift_system_idx[pos] = tid
                    pos += 1


@wp.func
def _update_neighbor_matrix(
    i: int,
    j: int,
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    max_neighbors: int,
    half_fill: bool,
):
    """
    Record the pair (i, j) in the neighbor matrix.

    Appends ``j`` to row ``i``. Unless ``half_fill`` is set, it also appends
    ``i`` to row ``j``, but only when ``i < j``.

    Parameters
    ----------
    i: int
        The index of the source atom.
    j: int
        The index of the target atom.
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2)
        OUTPUT: The neighbor matrix to be updated.
    num_neighbors: wp.array(dtype=wp.int32)
        OUTPUT: The number of neighbors for each atom.
    max_neighbors: int
        The maximum number of neighbors for each atom (row capacity).
    half_fill: bool
        If True, only the ``i -> j`` entry is recorded. If False, the mirrored
        ``j -> i`` entry is also recorded, but only when ``i < j``.

    Notes
    -----
    - Slots are reserved with ``wp.atomic_add``, whose return value is the
      previous count. Concurrent appends to the same row therefore receive
      distinct positions.
    - ``num_neighbors`` is incremented even when the row is already full
      (``pos >= max_neighbors``); in that case the entry itself is dropped.
      The count thus tracks every attempted insertion and can exceed
      ``max_neighbors``. ``assert_max_neighbors`` relies on this to detect
      overflow.
    - NOTE(review): because of the ``i < j`` guard, pairs with ``i >= j`` never
      get a mirrored entry. Presumably callers visit each unordered pair once
      with ``i < j``; confirm against the calling kernels.
    """
    # Reserve the next slot in row i; the counter advances even on overflow.
    pos = wp.atomic_add(num_neighbors, i, 1)
    if pos < max_neighbors:
        neighbor_matrix[i, pos] = j
    # Mirror the pair into row j to build a full (symmetric) matrix.
    if not half_fill and i < j:
        pos = wp.atomic_add(num_neighbors, j, 1)
        if pos < max_neighbors:
            neighbor_matrix[j, pos] = i


@wp.func
def _update_neighbor_matrix_pbc(
    i: int,
    j: int,
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    neighbor_matrix_shifts: wp.array(dtype=wp.vec3i, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    unit_shift: wp.vec3i,
    max_neighbors: int,
    half_fill: bool,
):
    """
    Record the pair (i, j) and its periodic shift in the neighbor matrix.

    Appends ``j`` together with ``unit_shift`` to row ``i``. Unless
    ``half_fill`` is set, it also appends ``i`` together with ``-unit_shift``
    to row ``j``.

    Parameters
    ----------
    i: int
        The index of the source atom.
    j: int
        The index of the target atom.
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2)
        OUTPUT: The neighbor matrix to be updated.
    neighbor_matrix_shifts: wp.array(dtype=wp.vec3i, ndim=2)
        OUTPUT: The neighbor matrix shifts to be updated (same layout as
        ``neighbor_matrix``).
    num_neighbors: wp.array(dtype=wp.int32)
        OUTPUT: The number of neighbors for each atom.
    unit_shift: wp.vec3i
        The integer lattice shift associated with the ``i -> j`` entry.
    max_neighbors: int
        The maximum number of neighbors for each atom (row capacity).
    half_fill: bool
        If True, only the ``i -> j`` entry is recorded. If False, the mirrored
        ``j -> i`` entry with the negated shift is recorded as well.

    Notes
    -----
    - As in the non-periodic variant, ``num_neighbors`` is incremented even
      when the row is full; the entry is then dropped, and the count can
      exceed ``max_neighbors``.
    - Unlike the non-periodic variant, there is no ``i < j`` guard on the
      mirrored entry. When ``i == j`` (an atom and its own periodic image),
      row ``i`` receives both ``+unit_shift`` and ``-unit_shift``.
      NOTE(review): presumably callers enumerate each (i, j, shift) image
      pair exactly once; confirm against the calling kernels.
    """
    # Reserve the next slot in row i; the counter advances even on overflow.
    pos = wp.atomic_add(num_neighbors, i, 1)
    if pos < max_neighbors:
        neighbor_matrix[i, pos] = j
        neighbor_matrix_shifts[i, pos] = unit_shift
    # Mirrored entry: seen from j, atom i sits at the opposite image.
    if not half_fill:
        pos = wp.atomic_add(num_neighbors, j, 1)
        if pos < max_neighbors:
            neighbor_matrix[j, pos] = i
            neighbor_matrix_shifts[j, pos] = -unit_shift


@wp.kernel(enable_backward=False)
def _compute_naive_num_shifts(
    cell: wp.array(dtype=Any),
    cutoff: Any,
    pbc: wp.array2d(dtype=wp.bool),
    num_shifts: wp.array(dtype=int),
    shift_range: wp.array(dtype=wp.vec3i),
) -> None:
    """Compute periodic image shifts needed for neighbor searching.

    Calculates the number and range of periodic boundary shifts required
    to ensure all atoms within the cutoff distance are found, taking into
    account the geometry of the simulation cell.

    Parameters
    ----------
    cell : wp.array, shape (num_systems, 3, 3), dtype=wp.mat33*
        Cell matrices defining lattice vectors in Cartesian coordinates.
        Each 3x3 matrix represents one system's periodic cell.
    cutoff : float
        Cutoff distance for neighbor searching in Cartesian units.
        Expected to be positive; it is not validated here. Larger cutoffs
        simply produce larger shift ranges.
    pbc : wp.array, shape (num_systems, 3), dtype=wp.bool
        Periodic boundary condition flags for each dimension.
        True enables periodicity in that direction; non-periodic directions
        get a shift range of 0.
    num_shifts : wp.array, shape (num_systems,), dtype=int
        OUTPUT: Total number of periodic shifts needed for each system.
        Updated with calculated shift counts.
    shift_range : wp.array, shape (num_systems, 3), dtype=wp.vec3i
        OUTPUT: Maximum shift indices in each dimension for each system.
        Updated with calculated shift ranges.

    Returns
    -------
    None
        This function modifies the input arrays in-place:

        - num_shifts : Updated with total shift counts per system
        - shift_range : Updated with shift ranges per dimension

    Notes
    -----
    The count written to ``num_shifts`` is the number of lexicographically
    non-negative shifts in [0, s0] x [-s1, s1] x [-s2, s2]. That is
    s0*(2*s1+1)*(2*s2+1) + s1*(2*s2+1) + s2 + 1, exactly what
    ``_expand_naive_shifts`` emits for a system.

    See Also
    --------
    _expand_naive_shifts : Expands shift ranges into explicit shift vectors
    """
    tid = wp.tid()

    _cell = cell[tid]
    _pbc = pbc[tid]

    # Rows of inverse(cell)^T are the reciprocal vectors b_k (assuming the
    # lattice vectors are stored as rows of `cell` -- TODO confirm). |b_k| is
    # the inverse spacing between lattice planes, so cutoff * |b_k| is how many
    # cell widths the cutoff sphere can reach along direction k.
    _cell_inv = wp.transpose(wp.inverse(_cell))
    # Zero of the cell's scalar type (the kernel is generic over float16/32/64);
    # a zero inverse spacing yields no shifts along non-periodic directions.
    _d_inv_0 = wp.length(_cell_inv[0]) if _pbc[0] else type(_cell_inv[0, 0])(0.0)
    _d_inv_1 = wp.length(_cell_inv[1]) if _pbc[1] else type(_cell_inv[1, 0])(0.0)
    _d_inv_2 = wp.length(_cell_inv[2]) if _pbc[2] else type(_cell_inv[2, 0])(0.0)
    _s = wp.vec3i(
        wp.int32(wp.ceil(_d_inv_0 * type(_d_inv_0)(cutoff))),
        wp.int32(wp.ceil(_d_inv_1 * type(_d_inv_1)(cutoff))),
        wp.int32(wp.ceil(_d_inv_2 * type(_d_inv_2)(cutoff))),
    )
    # Number of k1 / k2 values in [-s, s].
    k1 = 2 * _s[1] + 1
    k2 = 2 * _s[2] + 1
    shift_range[tid] = _s
    # Half-space count; must stay in sync with the filter in _expand_naive_shifts.
    num_shifts[tid] = _s[0] * k1 * k2 + _s[1] * k2 + _s[2] + 1


## Generate overloads
# Scalar, 3-vector and 3x3-matrix warp types, listed in lockstep by precision.
T = [wp.float32, wp.float64, wp.float16]
V = [wp.vec3f, wp.vec3d, wp.vec3h]
M = [wp.mat33f, wp.mat33d, wp.mat33h]
# One concrete overload of the generic `_compute_naive_num_shifts` kernel per
# scalar precision, keyed by that scalar type. The argument list mirrors the
# kernel signature: (cell, cutoff, pbc, num_shifts, shift_range).
_compute_naive_num_shifts_overload = {
    scalar_type: wp.overload(
        _compute_naive_num_shifts,
        [
            wp.array(dtype=matrix_type),
            scalar_type,
            wp.array2d(dtype=wp.bool),
            wp.array(dtype=int),
            wp.array(dtype=wp.vec3i),
        ],
    )
    for scalar_type, matrix_type in zip(T, M)
}


# interface
def compute_naive_num_shifts(
    cell: torch.Tensor,
    cutoff: float,
    pbc: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Compute periodic image shifts needed for neighbor searching.

    Launches the ``_compute_naive_num_shifts`` warp kernel (one thread per
    system). It then turns the per-system shift counts into CSR-style offsets
    that can be passed to ``_expand_naive_shifts``.

    Parameters
    ----------
    cell: torch.Tensor
        Cell matrices defining lattice vectors in Cartesian coordinates,
        shape (num_systems, 3, 3); one 3x3 matrix per system. Its dtype
        determines the warp scalar/matrix types used for the launch.
    cutoff: float
        Cutoff distance for neighbor searching in Cartesian units. It is
        converted to the cell's scalar dtype before the launch. Expected to be
        positive; not validated here.
    pbc: torch.Tensor
        Periodic boundary condition flags, shape (num_systems, 3), read as
        booleans. True enables periodicity in that direction; non-periodic
        directions get a shift range of 0.

    Returns
    -------
    shift_range: torch.Tensor
        Shape (num_systems, 3), dtype int32. Maximum shift index along each
        lattice direction for each system.
    shift_offset: torch.Tensor
        Shape (num_systems + 1,), dtype int32. Exclusive prefix sum of the
        per-system shift counts: the shifts of system ``s`` occupy
        ``shift_offset[s]:shift_offset[s + 1]``.
    total_shifts: int
        Total number of periodic shifts summed over ALL systems, i.e.
        ``shift_offset[-1]`` as a Python int. Reading it back via ``.item()``
        synchronizes with the device.
    """
    num_systems = cell.shape[0]
    device = cell.device

    # Per-system kernel outputs; int32 matches the kernel's int / vec3i arrays.
    num_shifts = torch.empty(num_systems, dtype=torch.int32, device=device)
    shift_range = torch.empty((num_systems, 3), dtype=torch.int32, device=device)

    # The kernel is generic; these types select the matching specialization.
    wp_dtype = get_wp_dtype(cell.dtype)
    wp_mat_dtype = get_wp_mat_dtype(cell.dtype)
    wp_device = wp.device_from_torch(device)

    # Zero-copy warp views of the torch tensors.
    wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype)
    wp_pbc = wp.from_torch(pbc, dtype=wp.bool)
    wp_num_shifts = wp.from_torch(num_shifts)
    wp_shift_range = wp.from_torch(shift_range, dtype=wp.vec3i)

    wp.launch(
        kernel=_compute_naive_num_shifts,
        dim=num_systems,
        inputs=[
            wp_cell,
            wp_dtype(cutoff),
            wp_pbc,
            wp_num_shifts,
            wp_shift_range,
        ],
        device=wp_device,
    )

    # Exclusive prefix sum: shift_offset[s] is where system s's shifts start.
    shift_offset = torch.empty((num_systems + 1,), dtype=torch.int32, device=device)
    shift_offset[0] = 0
    torch.cumsum(num_shifts, dim=0, out=shift_offset[1:])
    return shift_range, shift_offset, shift_offset[-1].item()


def estimate_max_neighbors(
    cutoff: float,
    atomic_density: float = 0.35,
    safety_factor: float = 5.0,
) -> int:
    r"""Estimate maximum neighbors per atom based on volume calculations.

    Uses a uniform atomic density and the volume of the cutoff sphere to
    estimate a conservative upper bound on the number of neighbors any atom
    could have. The computation is plain Python arithmetic on scalars; no
    tensors are created and no system/cell information is used.

    Parameters
    ----------
    cutoff : float
        Maximum distance for considering atoms as neighbors.
    atomic_density : float, optional
        Atomic density in atoms per unit volume. Default is 0.35.
    safety_factor : float, optional
        Factor multiplying the expected number of neighbors. Default is 5.0.

    Returns
    -------
    int
        Conservative estimate of the maximum neighbors per atom. For a
        positive cutoff it is rounded up to a multiple of 16 and is at least
        16. Returns 0 when ``cutoff <= 0``.

    Notes
    -----
    The estimate is

    .. math::

        N = \mathrm{safety\_factor} \cdot \rho \cdot \tfrac{4}{3} \pi r_c^3

    where :math:`\rho` is ``atomic_density`` and :math:`r_c` is ``cutoff``.
    :math:`N` is clamped to at least 1 and then rounded up to the next
    multiple of 16 for memory alignment.
    """
    if cutoff <= 0:
        return 0
    # Expected neighbor count: density times the volume (4/3)*pi*r^3 of the
    # cutoff sphere.
    cutoff_sphere_volume = atomic_density * (4.0 / 3.0) * math.pi * (cutoff**3)
    # Inflate by the safety factor; never estimate fewer than one neighbor.
    expected_neighbors = max(1, safety_factor * cutoff_sphere_volume)
    # Round up to the next multiple of 16 (not a power of two) for alignment.
    max_neighbors_estimate = int(math.ceil(expected_neighbors / 16)) * 16
    return max_neighbors_estimate
class NeighborOverflowError(Exception): """Exception raised when the number of neighbors larger than the maximum allowed.""" def __init__(self, max_neighbors: int, num_neighbors: int): super().__init__( f"The number of neighbors is larger than the maximum allowed: {num_neighbors} > {max_neighbors}." ) def assert_max_neighbors(neighbor_matrix: torch.Tensor, num_neighbors: torch.Tensor): """Assert that the number of neighbors is not larger than size of the neighbor matrix.""" max_neighbors = 0 if num_neighbors.numel() == 0 else num_neighbors.max() if max_neighbors > neighbor_matrix.shape[1]: raise NeighborOverflowError( neighbor_matrix.shape[1], max_neighbors if isinstance(max_neighbors, int) else max_neighbors.item(), ) def get_neighbor_list_from_neighbor_matrix( neighbor_matrix: torch.Tensor, num_neighbors: torch.Tensor, neighbor_shift_matrix: torch.Tensor | None = None, fill_value: int = -1, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Get the neighbor list from the neighbor matrix. Parameters ---------- neighbor_matrix: torch.Tensor The neighbor matrix with shape (total_atoms, max_neighbors), dtype int32. num_neighbors: torch.Tensor The number of neighbors for each atom with shape (total_atoms,), dtype int32. neighbor_shift_matrix: torch.Tensor | None Optional neighbor shift matrix with shape (total_atoms, max_neighbors, 3), dtype int32. fill_value: int The fill value for the neighbor matrix. This is used to create a mask from the neighbor matrix. Returns ------- neighbor_list: torch.Tensor The neighbor list with shape (2, num_pairs), dtype int32, in COO format [source_atoms, target_atoms]. neighbor_ptr: torch.Tensor The neighbor pointer with shape (total_atoms + 1,), dtype int32. CSR-style pointer arrays where ``neighbor_ptr_data[i]`` to ``neighbor_ptr_data[i+1]`` gives the range of neighbors for atom i in the flattened neighbor list. 
neighbor_shift_matrix: torch.Tensor | None The neighbor shift matrix with shape (total_atoms, max_neighbors, 3), dtype int32. If input neighbor_shift_matrix is None, returns None. Raises ------ ValueError If the max number of neighbors is larger than the neighbor matrix. """ # Raise ValueError if the max number of neighbors is larger than the neighbor matrix if num_neighbors.shape[0] == 0: neighbor_list = torch.zeros( 2, 0, dtype=neighbor_matrix.dtype, device=neighbor_matrix.device ) neighbor_ptr = torch.zeros(1, dtype=torch.int32, device=neighbor_matrix.device) neighbor_shift_matrix = ( None if neighbor_shift_matrix is None else torch.empty( 0, 2, 3, dtype=neighbor_shift_matrix.dtype, device=neighbor_shift_matrix.device, ) ) returns = ( (neighbor_list, neighbor_ptr, neighbor_shift_matrix) if neighbor_shift_matrix is not None else (neighbor_list, neighbor_ptr) ) return returns # Raise NeighborOverflowError if the number of neighbors is larger than the neighbor matrix assert_max_neighbors(neighbor_matrix, num_neighbors) mask = neighbor_matrix != fill_value dtype = neighbor_matrix.dtype i_idx = torch.where(mask)[0].to(dtype) j_idx = neighbor_matrix[mask].to(dtype) neighbor_list = torch.stack([i_idx, j_idx], dim=0) neighbor_ptr = torch.zeros( num_neighbors.shape[0] + 1, dtype=torch.int32, device=neighbor_matrix.device ) torch.cumsum(num_neighbors, dim=0, out=neighbor_ptr[1:]) if neighbor_shift_matrix is not None: neighbor_list_shifts = neighbor_shift_matrix[mask] return neighbor_list, neighbor_ptr, neighbor_list_shifts else: return neighbor_list, neighbor_ptr @torch.compile def _prepare_batch_idx_ptr( batch_idx: torch.Tensor | None, batch_ptr: torch.Tensor | None, num_atoms: int, device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: """ Utility function to prepare batch index and pointer tensors. Parameters ---------- batch_idx: torch.Tensor | None Tensor indicating the batch index for each atom. 
batch_ptr: torch.Tensor | None Tensor indicating the start index of each batch in the atom list. num_atoms: int Total number of atoms. num_systems: int | None Total number of systems. device: torch.device Device on which to create tensors if needed. Returns ------- batch_idx: torch.Tensor Prepared batch index tensor. batch_ptr: torch.Tensor Prepared batch pointer tensor. """ if batch_idx is None and batch_ptr is None: raise ValueError("Either batch_idx or batch_ptr must be provided.") if batch_idx is None: num_systems = batch_ptr.shape[0] - 1 num_atoms_per_system = batch_ptr[1:] - batch_ptr[:-1] batch_idx = torch.repeat_interleave( torch.arange(num_systems, dtype=torch.int32, device=device), num_atoms_per_system, ) elif batch_ptr is None: num_systems = batch_idx.max() + 1 num_atoms_per_system = torch.bincount(batch_idx, minlength=num_systems) batch_ptr = torch.zeros(num_systems + 1, dtype=torch.int32, device=device) torch.cumsum(num_atoms_per_system, dim=0, out=batch_ptr[1:]) return batch_idx, batch_ptr
def allocate_cell_list(
    total_atoms: int,
    max_total_cells: int,
    neighbor_search_radius: torch.Tensor,
    device: torch.device,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Allocate memory for the cell list.

    Every newly allocated buffer is a zero-initialized int32 tensor on
    ``device``.

    Parameters
    ----------
    total_atoms : int
        Number of atoms across all systems.
    max_total_cells : int
        Capacity of the per-cell buffers.
    neighbor_search_radius : torch.Tensor
        Passed through unchanged. A 1-D tensor denotes a single system;
        otherwise ``neighbor_search_radius.shape[0]`` is taken as the number
        of systems.
    device : torch.device
        Device on which the new buffers are created.

    Returns
    -------
    cells_per_dimension : torch.Tensor
        Shape (3,) for a single system, else (num_systems, 3).
    neighbor_search_radius : torch.Tensor
        The input tensor itself.
    atom_periodic_shifts : torch.Tensor
        Shape (total_atoms, 3).
    atom_to_cell_mapping : torch.Tensor
        Shape (total_atoms, 3).
    atoms_per_cell_count : torch.Tensor
        Shape (max_total_cells,).
    cell_atom_start_indices : torch.Tensor
        Shape (max_total_cells,).
    cell_atom_list : torch.Tensor
        Shape (total_atoms,).
    """

    def _int32_zeros(*shape: int) -> torch.Tensor:
        # Zero-filled int32 buffer of the given shape on the target device.
        return torch.zeros(shape, dtype=torch.int32, device=device)

    # A 1-D radius means a single system; otherwise the leading dim counts systems.
    if neighbor_search_radius.ndim == 1:
        grid_shape = (3,)
    else:
        grid_shape = (neighbor_search_radius.shape[0], 3)

    return (
        _int32_zeros(*grid_shape),
        neighbor_search_radius,
        _int32_zeros(total_atoms, 3),
        _int32_zeros(total_atoms, 3),
        _int32_zeros(max_total_cells),
        _int32_zeros(max_total_cells),
        _int32_zeros(total_atoms),
    )