Source code for nvalchemiops.neighborlist.neighborlist

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from nvalchemiops.neighborlist.batch_cell_list import (
    batch_cell_list,
)
from nvalchemiops.neighborlist.batch_naive import (
    batch_naive_neighbor_list,
)
from nvalchemiops.neighborlist.batch_naive_dual_cutoff import (
    batch_naive_neighbor_list_dual_cutoff,
)
from nvalchemiops.neighborlist.cell_list import (
    cell_list,
)
from nvalchemiops.neighborlist.naive import (
    naive_neighbor_list,
)
from nvalchemiops.neighborlist.naive_dual_cutoff import (
    naive_neighbor_list_dual_cutoff,
)
from nvalchemiops.neighborlist.neighbor_utils import (
    _prepare_batch_idx_ptr,
)


[docs] def neighbor_list( positions: torch.Tensor, cutoff: float, cell: torch.Tensor | None = None, pbc: torch.Tensor | None = None, batch_idx: torch.Tensor | None = None, batch_ptr: torch.Tensor | None = None, cutoff2: float | None = None, half_fill: bool = False, fill_value: int | None = None, return_neighbor_list: bool = False, method: str | None = None, **kwargs: dict, ): """Compute neighbor list using the appropriate method based on the provided parameters. Parameters ---------- positions : torch.Tensor, shape (total_atoms, 3) Concatenated atomic coordinates for all systems in Cartesian space. Each row represents one atom's (x, y, z) position. Must be wrapped into the unit cell if PBC is used. cutoff : float Cutoff distance for neighbor detection in Cartesian units. Must be positive. Atoms within this distance are considered neighbors. cell : torch.Tensor, shape (3, 3), dtype=torch.float32 or torch.float64, optional Cell matrix defining the simulation box. pbc : torch.Tensor, shape (3,), dtype=torch.bool, optional Periodic boundary condition flags for each dimension. batch_idx : torch.Tensor, shape (total_atoms,), dtype=torch.int32, optional System index for each atom. batch_ptr : torch.Tensor, shape (num_systems + 1,), dtype=torch.int32, optional Cumulative atom counts defining system boundaries. cutoff2 : float, optional Second cutoff distance for neighbor detection in Cartesian units. Must be positive. Atoms within this distance are considered neighbors. half_fill : bool, optional If True, only store half of the neighbor relationships to avoid double counting. Another half could be reconstructed by swapping source and target indices and inverting unit shifts. fill_value : int | None, optional Value to fill the neighbor matrix with. Default is total_atoms. return_neighbor_list : bool, optional - default = False If True, convert the neighbor matrix to a neighbor list (idx_i, idx_j) format by creating a mask over the fill_value, which can incur a performance penalty. We recommend using the neighbor matrix format, and only convert to a neighbor list format if absolutely necessary. method : str | None, optional Method to use for neighbor list computation. Choices: "naive", "cell_list", "batch_naive", "batch_cell_list", "naive_dual_cutoff", "batch_naive_dual_cutoff". If None, a default method will be chosen based on the number of atoms. **kwargs : dict, optional Additional keyword arguments to pass to the method. - max_neighbors: int, optional Maximum number of neighbors per atom. Can be provided to aid in allocation for both naive and cell list methods. - max_neighbors2: int, optional Maximum number of neighbors per atom within cutoff2. Can be provided to aid in allocation for naive dual cutoff method. - neighbor_matrix: torch.Tensor, optional Pre-allocated tensor of shape (total_atoms, max_neighbors) for neighbor indices. Can be provided to avoid reallocation for both naive and cell list methods. - neighbor_matrix_shifts: torch.Tensor, optional Pre-allocated tensor of shape (total_atoms, max_neighbors, 3) for shift vectors. Can be provided to avoid reallocation for both naive and cell list methods. - num_neighbors: torch.Tensor, optional Pre-allocated tensor of shape (total_atoms,) for neighbor counts. Can be provided to avoid reallocation for both naive and cell list methods. - shift_range_per_dimension: torch.Tensor, optional Pre-allocated tensor of shape (1, 3) for shift range in each dimension. Can be provided to avoid reallocation for naive methods. - shift_offset: torch.Tensor, optional Pre-allocated tensor of shape (2,) for cumulative sum of number of shifts for each system. Can be provided to avoid reallocation for naive methods. - total_shifts: int, optional Total number of shifts. Can be provided to avoid reallocation for naive methods. - cells_per_dimension: torch.Tensor, optional Pre-allocated tensor of shape (3,) for number of cells in x, y, z directions. Can be provided to avoid reallocation for cell list construction. - neighbor_search_radius: torch.Tensor, optional Pre-allocated tensor of shape (3,) for radius of neighboring cells to search in each dimension. Can be provided to avoid reallocation for cell list construction. - atom_periodic_shifts: torch.Tensor, optional Pre-allocated tensor of shape (total_atoms, 3) for periodic boundary crossings for each atom. Can be provided to avoid reallocation for cell list construction. - atom_to_cell_mapping: torch.Tensor, optional Pre-allocated tensor of shape (total_atoms, 3) for cell coordinates for each atom. Can be provided to avoid reallocation for cell list construction. - atoms_per_cell_count: torch.Tensor, optional Pre-allocated tensor of shape (max_total_cells,) for number of atoms in each cell. Can be provided to avoid reallocation for cell list construction. - cell_atom_start_indices: torch.Tensor, optional Pre-allocated tensor of shape (max_total_cells,) for starting index in cell_atom_list for each cell. Can be provided to avoid reallocation for cell list construction. - cell_atom_list: torch.Tensor, optional Pre-allocated tensor of shape (total_atoms,) for flattened list of atom indices organized by cell. Can be provided to avoid reallocation for cell list construction. - max_atoms_per_system: int, optional Maximum number of atoms per system. Used in batch naive implementation with PBC. If not provided, it will be computed automaticaly. Can be provided to avoid CUDA synchronization. Returns ------- results : tuple of torch.Tensor Variable-length tuple depending on input parameters. The return pattern follows: **Single cutoff:** - No PBC, matrix format: ``(neighbor_matrix, num_neighbors)`` - No PBC, list format: ``(neighbor_list, neighbor_ptr)`` - With PBC, matrix format: ``(neighbor_matrix, num_neighbors, neighbor_matrix_shifts)`` - With PBC, list format: ``(neighbor_list, neighbor_ptr, neighbor_list_shifts)`` **Dual cutoff:** - No PBC, matrix format: ``(neighbor_matrix1, num_neighbors1, neighbor_matrix2, num_neighbors2)`` - No PBC, list format: ``(neighbor_list1, neighbor_ptr1, neighbor_list2, neighbor_ptr2)`` - With PBC, matrix format: ``(neighbor_matrix1, num_neighbors1, neighbor_matrix_shifts1, neighbor_matrix2, num_neighbors2, neighbor_matrix_shifts2)`` - With PBC, list format: ``(neighbor_list1, neighbor_ptr1, neighbor_list_shifts1, neighbor_list2, neighbor_ptr2, neighbor_list_shifts2)`` **Components returned:** - **neighbor_data** (tensor): Neighbor indices, format depends on ``return_neighbor_list``: - If ``return_neighbor_list=False`` (default): Returns ``neighbor_matrix`` with shape (total_atoms, max_neighbors), dtype int32. Each row i contains indices of atom i's neighbors. - If ``return_neighbor_list=True``: Returns ``neighbor_list`` with shape (2, num_pairs), dtype int32, in COO format [source_atoms, target_atoms]. - **num_neighbor_data** (tensor): Information about the number of neighbors for each atom, format depends on ``return_neighbor_list``: - If ``return_neighbor_list=False`` (default): Returns ``num_neighbors`` with shape (total_atoms,), dtype int32. Count of neighbors found for each atom. - If ``return_neighbor_list=True``: Returns ``neighbor_ptr`` with shape (total_atoms + 1,), dtype int32. CSR-style pointer arrays where ``neighbor_ptr_data[i]`` to ``neighbor_ptr_data[i+1]`` gives the range of neighbors for atom i in the flattened neighbor list. - **neighbor_shift_data** (tensor, optional): Periodic shift vectors, only when ``pbc`` is provided: format depends on ``return_neighbor_list``: - If ``return_neighbor_list=False`` (default): Returns ``neighbor_matrix_shifts`` with shape (total_atoms, max_neighbors, 3), dtype int32. - If ``return_neighbor_list=True``: Returns ``unit_shifts`` with shape (num_pairs, 3), dtype int32. When ``cutoff2`` is provided, the pattern repeats for the second cutoff with interleaved components (neighbor_data2, num_neighbor_data2, neighbor_shift_data2) appended to the tuple. Examples -------- Single cutoff, matrix format, with PBC:: >>> nm, num, shifts = neighbor_list(pos, 5.0, cell=cell, pbc=pbc) Single cutoff, list format, no PBC:: >>> nlist, ptr = neighbor_list(pos, 5.0, return_neighbor_list=True) Dual cutoff, matrix format, with PBC:: >>> nm1, num1, sh1, nm2, num2, sh2 = neighbor_list( ... pos, 2.5, cutoff2=5.0, cell=cell, pbc=pbc ... ) See Also -------- naive_neighbor_list : Direct access to naive O(N²) algorithm cell_list : Direct access to cell list O(N) algorithm """ if method is None: total_atoms = positions.shape[0] if cutoff2 is not None: method = "naive_dual_cutoff" elif total_atoms >= 5000: method = "cell_list" if cell is None or pbc is None: cell = torch.eye( 3, dtype=positions.dtype, device=positions.device ).reshape(1, 3, 3) pbc = torch.tensor( [False, False, False], dtype=torch.bool, device=positions.device ) else: method = "naive" if batch_idx is not None or batch_ptr is not None: method = "batch_" + method batch_idx, batch_ptr = _prepare_batch_idx_ptr( batch_idx, batch_ptr, total_atoms, positions.device ) match method: case "naive": return naive_neighbor_list( positions, cutoff, pbc=pbc, cell=cell, half_fill=half_fill, fill_value=fill_value, return_neighbor_list=return_neighbor_list, **kwargs, ) case "cell_list": return cell_list( positions, cutoff, cell, pbc, half_fill=half_fill, fill_value=fill_value, return_neighbor_list=return_neighbor_list, **kwargs, ) case "batch_naive": return batch_naive_neighbor_list( positions, cutoff, pbc=pbc, cell=cell, batch_idx=batch_idx, batch_ptr=batch_ptr, half_fill=half_fill, fill_value=fill_value, return_neighbor_list=return_neighbor_list, **kwargs, ) case "batch_cell_list": return batch_cell_list( positions, cutoff, cell, pbc, batch_idx, half_fill=half_fill, fill_value=fill_value, return_neighbor_list=return_neighbor_list, **kwargs, ) case "naive_dual_cutoff": return naive_neighbor_list_dual_cutoff( positions, cutoff, cutoff2, pbc=pbc, cell=cell, half_fill=half_fill, fill_value=fill_value, return_neighbor_list=return_neighbor_list, **kwargs, ) case "batch_naive_dual_cutoff": return batch_naive_neighbor_list_dual_cutoff( positions, cutoff, cutoff2, pbc=pbc, cell=cell, batch_idx=batch_idx, batch_ptr=batch_ptr, half_fill=half_fill, fill_value=fill_value, return_neighbor_list=return_neighbor_list, **kwargs, ) case _: raise ValueError(f"Invalid method: {method}")