# Source code for nvalchemiops.neighborlist.batch_naive

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch
import warp as wp

from nvalchemiops.neighborlist.neighbor_utils import (
    _expand_naive_shifts,
    _prepare_batch_idx_ptr,
    _update_neighbor_matrix,
    _update_neighbor_matrix_pbc,
    compute_naive_num_shifts,
    estimate_max_neighbors,
    get_neighbor_list_from_neighbor_matrix,
)
from nvalchemiops.types import get_wp_dtype, get_wp_mat_dtype, get_wp_vec_dtype

###########################################################################################
########################### Naive Neighbor List Kernels ################################
###########################################################################################


@wp.kernel(enable_backward=False)
def _fill_batch_naive_neighbor_matrix(
    positions: wp.array(dtype=Any),
    cutoff_sq: Any,
    batch_idx: wp.array(dtype=wp.int32),
    batch_ptr: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    half_fill: wp.bool,
) -> None:
    """Calculate batch neighbor matrix using naive O(N^2) algorithm.

    Computes pairwise distances between atoms within each system in a batch
    and identifies neighbors within the specified cutoff distance. Atoms from
    different systems do not interact. No periodic boundary conditions are applied.

    Launched with one thread per atom (``dim = total_atoms``). Thread ``i`` only
    scans atoms ``j`` with ``i < j < batch_ptr[batch_idx[i] + 1]``, so atoms of
    one system must be stored contiguously (i.e. ``batch_idx`` sorted); otherwise
    pairs are missed or cross-system pairs are produced.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms,), dtype=wp.vec3f / wp.vec3d / wp.vec3h
        Concatenated atomic coordinates for all systems in Cartesian space,
        one 3-vector (x, y, z) per atom.
    cutoff_sq : wp.float32 / wp.float64 / wp.float16
        Squared cutoff distance, in the same precision as ``positions``.
        Pairs with squared distance strictly less than this are neighbors.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom. Atoms with the same index belong to
        the same system and can be neighbors.
    batch_ptr : wp.array, shape (num_systems + 1,), dtype=wp.int32
        Cumulative atom counts defining system boundaries.
        System i contains atoms from batch_ptr[i] to batch_ptr[i+1]-1.
    neighbor_matrix : wp.array, shape (total_atoms, max_neighbors), dtype=wp.int32
        OUTPUT: Neighbor matrix to be filled with neighbor atom indices.
        Entries are filled with atom indices, remaining entries stay as initialized.
        Its second dimension is passed on as the per-atom capacity.
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors found for each atom.
        Updated in-place with actual neighbor counts.
    half_fill : wp.bool
        If True, only store relationships where i < j to avoid double counting.
        If False, store all neighbor relationships symmetrically.

    Returns
    -------
    None
        This function modifies the input arrays in-place:

        - neighbor_matrix : Filled with neighbor atom indices
        - num_neighbors : Updated with neighbor counts per atom

    See Also
    --------
    _fill_naive_neighbor_matrix : Single system version
    _fill_batch_naive_neighbor_matrix_pbc : Version with periodic boundary conditions
    """
    # One thread per atom; ``tid`` is the atom's index in the concatenated batch.
    tid = wp.tid()
    isys = batch_idx[tid]
    # Exclusive end of this atom's system. Only the end is needed because the
    # scan below starts right after ``tid`` (relies on contiguous systems).
    j_end = batch_ptr[isys + 1]

    positions_i = positions[tid]
    # Per-atom capacity handed to the shared update helper.
    max_neighbors = neighbor_matrix.shape[1]
    # Each unordered pair is visited exactly once (j > i). Whether the reverse
    # (j -> i) entry is also written is decided inside _update_neighbor_matrix
    # from ``half_fill`` (helper lives in neighbor_utils; not shown here).
    for j in range(tid + 1, j_end):
        diff = positions_i - positions[j]
        dist_sq = wp.length_sq(diff)
        # Strict comparison: pairs exactly at the cutoff are excluded.
        if dist_sq < cutoff_sq:
            _update_neighbor_matrix(
                tid, j, neighbor_matrix, num_neighbors, max_neighbors, half_fill
            )


@wp.kernel(enable_backward=False)
def _fill_batch_naive_neighbor_matrix_pbc(
    positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    cutoff_sq: Any,
    batch_ptr: wp.array(dtype=wp.int32),
    shifts: wp.array(dtype=wp.vec3i),
    shift_system_idx: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    neighbor_matrix_shifts: wp.array(dtype=wp.vec3i, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    half_fill: wp.bool,
) -> None:
    """Calculate batch neighbor matrix with periodic boundary conditions using naive O(N^2) algorithm.

    Computes neighbor relationships between atoms across periodic boundaries by
    considering all periodic images within the cutoff distance. Processes multiple
    systems in a batch, where each system can have different periodic cells.

    Launched on a 2-D grid ``(total_shifts, max_atoms_per_system)``: the first
    index selects one shift vector (and through ``shift_system_idx`` its system),
    the second is an atom index *local* to that system. Threads whose local
    index is not smaller than the system's atom count exit immediately, so the
    second launch dimension must be at least the largest system size or atoms
    are silently skipped.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms,), dtype=wp.vec3f / wp.vec3d / wp.vec3h
        Concatenated atomic coordinates for all systems in Cartesian space,
        one 3-vector (x, y, z) per atom.
    cell : wp.array, shape (num_systems,), dtype=wp.mat33f / wp.mat33d / wp.mat33h
        One 3x3 cell matrix per system. The image offset is computed as the
        integer shift (as a row vector) times this matrix, i.e. an integer
        combination of the matrix rows.
    cutoff_sq : wp.float32 / wp.float64 / wp.float16
        Squared cutoff distance, in the same precision as ``positions``.
        Pairs with squared distance strictly less than this are neighbors.
    batch_ptr : wp.array, shape (num_systems + 1,), dtype=wp.int32
        Cumulative sum of number of atoms per system in the batch.
        System i contains atoms from batch_ptr[i] to batch_ptr[i+1]-1.
    shifts : wp.array, shape (total_shifts,), dtype=wp.vec3i
        Integer shift vectors for periodic images.
    shift_system_idx : wp.array, shape (total_shifts,), dtype=wp.int32
        Array mapping each shift to its system index in the batch.
    neighbor_matrix : wp.array, shape (total_atoms, max_neighbors), dtype=wp.int32
        OUTPUT: Neighbor matrix to be filled with neighbor atom indices.
        Entries are filled with atom indices, remaining entries stay as initialized.
        Its second dimension is passed on as the per-atom capacity.
    neighbor_matrix_shifts : wp.array, shape (total_atoms, max_neighbors), dtype=wp.vec3i
        OUTPUT: Matrix storing shift vectors for each neighbor relationship.
        Each entry corresponds to the shift used for the neighbor in neighbor_matrix.
        Must be at least as large as ``neighbor_matrix`` in both dimensions.
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors found for each atom.
        Updated in-place with actual neighbor counts.
    half_fill : wp.bool
        If True, only store half of the neighbor relationships.
        The other half can be reconstructed by swapping indices and inverting shifts.
        If False, store all neighbor relationships symmetrically.

    Returns
    -------
    None
        This function modifies the input arrays in-place:

        - neighbor_matrix : Filled with neighbor atom indices
        - neighbor_matrix_shifts : Filled with corresponding shift vectors
        - num_neighbors : Updated with neighbor counts per atom

    See Also
    --------
    _fill_batch_naive_neighbor_matrix : Version without periodic boundary conditions
    _fill_naive_neighbor_matrix_pbc : Single system version
    """
    # ishift: which periodic image; iatom: atom index local to that image's system.
    ishift, iatom = wp.tid()

    isys = shift_system_idx[ishift]

    _natom = batch_ptr[isys + 1] - batch_ptr[isys]

    # The grid is sized for the largest system; surplus threads do nothing.
    if iatom >= _natom:
        return

    # Convert the local index to a global (batch-concatenated) atom index.
    start = batch_ptr[isys]
    iatom = iatom + start
    jatom_start = start
    jatom_end = batch_ptr[isys + 1]

    maxnb = neighbor_matrix.shape[1]
    _positions = positions[iatom]
    _shift = shifts[ishift]
    _cell = cell[isys]

    # Cast the integer shift to the cell's float row-vector type, then take the
    # row-vector x matrix product to get the Cartesian image offset of iatom.
    positions_shifted = type(_cell[0])(_shift) * _cell + _positions

    # For the zero shift only j < i is scanned: this drops the self-pair (i, i)
    # and visits each in-cell pair once. Non-zero shifts scan every atom of the
    # system, including i itself (its own periodic images).
    _zero_shift = _shift[0] == 0 and _shift[1] == 0 and _shift[2] == 0
    if _zero_shift:
        jatom_end = iatom
    for jatom in range(jatom_start, jatom_end):
        diff = positions_shifted - positions[jatom]
        dist_sq = wp.length_sq(diff)
        # Strict comparison: pairs exactly at the cutoff are excluded.
        if dist_sq < cutoff_sq:
            # jatom is passed first and iatom second, together with the shift
            # that was applied to iatom. How (and whether) the reverse entry is
            # written depends on _update_neighbor_matrix_pbc and ``half_fill``
            # (helper lives in neighbor_utils; not shown here).
            _update_neighbor_matrix_pbc(
                jatom,
                iatom,
                neighbor_matrix,
                neighbor_matrix_shifts,
                num_neighbors,
                _shift,
                maxnb,
                half_fill,
            )


T = [wp.float32, wp.float64, wp.float16]
V = [wp.vec3f, wp.vec3d, wp.vec3h]
M = [wp.mat33f, wp.mat33d, wp.mat33h]

# Concrete kernel instantiations, keyed by the Warp scalar type of the
# positions (float32 / float64 / float16). The Python wrappers below look the
# kernel up with ``get_wp_dtype(positions.dtype)``.
_fill_batch_naive_neighbor_matrix_overload = {}
_fill_batch_naive_neighbor_matrix_pbc_overload = {}

for scalar_t, vec_t, mat_t in zip(T, V, M):
    # Argument types, in kernel-parameter order, for the non-periodic kernel.
    _no_pbc_signature = [
        wp.array(dtype=vec_t),  # positions
        scalar_t,  # cutoff_sq
        wp.array(dtype=wp.int32),  # batch_idx
        wp.array(dtype=wp.int32),  # batch_ptr
        wp.array(dtype=wp.int32, ndim=2),  # neighbor_matrix
        wp.array(dtype=wp.int32),  # num_neighbors
        wp.bool,  # half_fill
    ]
    # Argument types, in kernel-parameter order, for the periodic kernel.
    _pbc_signature = [
        wp.array(dtype=vec_t),  # positions
        wp.array(dtype=mat_t),  # cell
        scalar_t,  # cutoff_sq
        wp.array(dtype=wp.int32),  # batch_ptr
        wp.array(dtype=wp.vec3i),  # shifts
        wp.array(dtype=wp.int32),  # shift_system_idx
        wp.array(dtype=wp.int32, ndim=2),  # neighbor_matrix
        wp.array(dtype=wp.vec3i, ndim=2),  # neighbor_matrix_shifts
        wp.array(dtype=wp.int32),  # num_neighbors
        wp.bool,  # half_fill
    ]
    _fill_batch_naive_neighbor_matrix_overload[scalar_t] = wp.overload(
        _fill_batch_naive_neighbor_matrix, _no_pbc_signature
    )
    _fill_batch_naive_neighbor_matrix_pbc_overload[scalar_t] = wp.overload(
        _fill_batch_naive_neighbor_matrix_pbc, _pbc_signature
    )

###########################################################################################
###################### Naive Neighbor List Python Wrapper ##############################
###########################################################################################


@torch.library.custom_op(
    "nvalchemiops::_naive_batch_neighbor_matrix_no_pbc",
    mutates_args=("neighbor_matrix", "num_neighbors"),
)
def _batch_naive_neighbor_matrix_no_pbc(
    positions: torch.Tensor,
    cutoff: float,
    batch_idx: torch.Tensor,
    batch_ptr: torch.Tensor,
    neighbor_matrix: torch.Tensor,
    num_neighbors: torch.Tensor,
    half_fill: bool,
) -> None:
    """Fill a pre-allocated neighbor matrix for a batch without periodic images.

    Thin PyTorch custom operator around the ``_fill_batch_naive_neighbor_matrix``
    Warp kernel. The kernel runs with one thread per atom and compares each
    atom against the later atoms of its own system (O(N^2) per system), so
    atoms belonging to different systems never become neighbors.

    The operator writes only into the tensors it is given; it allocates
    nothing. It is registered as a custom op so it can be used under
    ``torch.compile``.

    Parameters
    ----------
    positions : torch.Tensor, shape (total_atoms, 3), dtype=torch.float32 or torch.float64
        Cartesian coordinates of every atom, systems concatenated along dim 0.
    cutoff : float
        Neighbor cutoff distance (must be positive). The kernel receives its
        square and keeps pairs strictly closer than the cutoff.
    batch_idx : torch.Tensor, shape (total_atoms,), dtype=torch.int32
        System index of every atom. Atoms of one system must be contiguous.
    batch_ptr : torch.Tensor, shape (num_systems + 1,), dtype=torch.int32
        System boundaries: system i owns atoms batch_ptr[i] .. batch_ptr[i+1]-1.
    neighbor_matrix : torch.Tensor, shape (total_atoms, max_neighbors), dtype=torch.int32
        OUTPUT, pre-allocated: receives neighbor atom indices per row.
    num_neighbors : torch.Tensor, shape (total_atoms,), dtype=torch.int32
        OUTPUT, pre-allocated: receives the neighbor count of every atom.
    half_fill : bool
        True stores each pair once (i < j); False stores both directions.

    Returns
    -------
    None
        ``neighbor_matrix`` and ``num_neighbors`` are updated in place.

    See Also
    --------
    batch_naive_neighbor_list : Higher-level wrapper function
    _naive_neighbor_matrix_no_pbc : Single system version
    """
    scalar_t = get_wp_dtype(positions.dtype)
    vec_t = get_wp_vec_dtype(positions.dtype)

    def _as_int32(tensor: torch.Tensor):
        # Zero-copy view of an int32 torch tensor as a Warp array ctype.
        return wp.from_torch(tensor, dtype=wp.int32, return_ctype=True)

    # Order must match the kernel signature exactly.
    kernel_args = [
        wp.from_torch(positions, dtype=vec_t, return_ctype=True),
        scalar_t(cutoff * cutoff),
        _as_int32(batch_idx),
        _as_int32(batch_ptr),
        _as_int32(neighbor_matrix),
        _as_int32(num_neighbors),
        half_fill,
    ]

    wp.launch(
        kernel=_fill_batch_naive_neighbor_matrix_overload[scalar_t],
        dim=positions.shape[0],
        inputs=kernel_args,
        device=wp.device_from_torch(positions.device),
    )


@torch.library.custom_op(
    "nvalchemiops::_batch_naive_neighbor_matrix_pbc",
    mutates_args=("neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"),
)
def _batch_naive_neighbor_matrix_pbc(
    positions: torch.Tensor,
    cell: torch.Tensor,
    cutoff: float,
    batch_ptr: torch.Tensor,
    neighbor_matrix: torch.Tensor,
    neighbor_matrix_shifts: torch.Tensor,
    num_neighbors: torch.Tensor,
    shift_range_per_dimension: torch.Tensor,
    shift_offset: torch.Tensor,
    total_shifts: int,
    half_fill: bool = False,
    max_atoms_per_system: int | None = None,
) -> None:
    """Compute batch neighbor matrix with periodic boundary conditions using naive O(N^2) algorithm.

    Custom PyTorch operator that computes neighbor relationships between atoms
    across periodic boundaries for multiple systems in a batch. Uses pre-computed
    shift vectors for efficiency. Each system can have
    different periodic cells and boundary conditions.

    This function does not allocate any tensors.

    This function is torch compilable.

    Parameters
    ----------
    positions : torch.Tensor, shape (total_atoms, 3), dtype=torch.float32 or torch.float64
        Concatenated atomic coordinates for all systems in Cartesian space.
        Each row represents one atom's (x, y, z) position.
        Must be wrapped into the unit cell.
    cell : torch.Tensor, shape (num_systems, 3, 3), dtype=torch.float32 or torch.float64
        Cell matrices defining lattice vectors in Cartesian coordinates.
        Each 3x3 matrix represents one system's periodic cell.
    cutoff : float
        Cutoff distance for neighbor detection in Cartesian units.
        Must be positive. Atoms within this distance are considered neighbors.
    batch_ptr : torch.Tensor, shape (num_systems + 1,), dtype=torch.int32
        Cumulative atom counts defining system boundaries.
        System i contains atoms from batch_ptr[i] to batch_ptr[i+1]-1.
    neighbor_matrix : torch.Tensor, shape (total_atoms, max_neighbors), dtype=torch.int32
        OUTPUT: Neighbor matrix to be filled with neighbor atom indices.
        Must be pre-allocated. Entries are filled with atom indices.
    neighbor_matrix_shifts : torch.Tensor, shape (total_atoms, max_neighbors, 3), dtype=torch.int32
        OUTPUT: Matrix storing shift vectors for each neighbor relationship.
        Must be pre-allocated. Each entry corresponds to the shift used for the neighbor in neighbor_matrix.
    num_neighbors : torch.Tensor, shape (total_atoms,), dtype=torch.int32
        OUTPUT: Number of neighbors found for each atom.
        Must be pre-allocated. Updated in-place with actual neighbor counts.
    shift_range_per_dimension : torch.Tensor, shape (num_systems, 3), dtype=torch.int32
        Shift range in each dimension for each system.
    shift_offset : torch.Tensor, shape (num_systems + 1,), dtype=torch.int32
        Cumulative sum of shift counts, defining shift boundaries for each system.
        System i uses shifts from shift_offset[i] to shift_offset[i+1]-1.
    total_shifts : int
        Total number of periodic shifts across all systems.
        Must match the sum of shifts for all systems.
    half_fill : bool, optional
        If True, only store relationships where i < j to avoid double counting.
        If False, store all neighbor relationships symmetrically. Default is False.
    max_atoms_per_system : int, optional
        Maximum number of atoms per system.
        If not provided, it will be computed automaticaly.
        Can be provided to avoid CUDA synchronization.

    Returns
    -------
    None
        This function modifies the input tensors in-place:

        - neighbor_matrix : Filled with neighbor atom indices
        - neighbor_matrix_shifts : Filled with corresponding shift vectors
        - num_neighbors : Updated with neighbor counts per atom

    See Also
    --------
    batch_naive_neighbor_list : Higher-level wrapper function
    _batch_compute_total_shifts : Computes the required shift vectors
    _naive_neighbor_matrix_pbc : Single system version
    """
    num_systems = cell.shape[0]
    device = positions.device
    wp_device = wp.device_from_torch(device)
    wp_vec_dtype = get_wp_vec_dtype(positions.dtype)
    wp_mat_dtype = get_wp_mat_dtype(positions.dtype)
    wp_dtype = get_wp_dtype(positions.dtype)
    wp_positions = wp.from_torch(positions, dtype=wp_vec_dtype, return_ctype=True)
    wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype, return_ctype=True)

    shifts = torch.empty((total_shifts, 3), dtype=torch.int32, device=device)
    shift_system_idx = torch.empty((total_shifts,), dtype=torch.int32, device=device)
    wp_shifts = wp.from_torch(shifts, dtype=wp.vec3i, return_ctype=True)
    wp_shift_system_idx = wp.from_torch(
        shift_system_idx, dtype=wp.int32, return_ctype=True
    )

    wp.launch(
        kernel=_expand_naive_shifts,
        dim=num_systems,
        inputs=[
            wp.from_torch(shift_range_per_dimension, dtype=wp.vec3i, return_ctype=True),
            wp.from_torch(shift_offset, dtype=wp.int32, return_ctype=True),
            wp_shifts,
            wp_shift_system_idx,
        ],
        device=wp_device,
    )

    wp_neighbor_matrix = wp.from_torch(
        neighbor_matrix, dtype=wp.int32, return_ctype=True
    )
    wp_neighbor_matrix_shifts = wp.from_torch(
        neighbor_matrix_shifts, dtype=wp.vec3i, return_ctype=True
    )
    wp_num_neighbors = wp.from_torch(num_neighbors, dtype=wp.int32, return_ctype=True)
    wp_batch_ptr = wp.from_torch(batch_ptr, dtype=wp.int32, return_ctype=True)

    if max_atoms_per_system is None:
        max_atoms_per_system = (batch_ptr[1:] - batch_ptr[:-1]).max().item()

    wp.launch(
        kernel=_fill_batch_naive_neighbor_matrix_pbc_overload[wp_dtype],
        dim=(total_shifts, max_atoms_per_system),
        inputs=[
            wp_positions,
            wp_cell,
            wp_dtype(cutoff * cutoff),
            wp_batch_ptr,
            wp_shifts,
            wp_shift_system_idx,
            wp_neighbor_matrix,
            wp_neighbor_matrix_shifts,
            wp_num_neighbors,
            half_fill,
        ],
        device=wp_device,
    )


[docs] def batch_naive_neighbor_list( positions: torch.Tensor, cutoff: float, batch_idx: torch.Tensor | None = None, batch_ptr: torch.Tensor | None = None, pbc: torch.Tensor | None = None, cell: torch.Tensor | None = None, max_neighbors: int | None = None, half_fill: bool = False, fill_value: int | None = None, return_neighbor_list: bool = False, neighbor_matrix: torch.Tensor | None = None, neighbor_matrix_shifts: torch.Tensor | None = None, num_neighbors: torch.Tensor | None = None, shift_range_per_dimension: torch.Tensor | None = None, shift_offset: torch.Tensor | None = None, total_shifts: int | None = None, max_atoms_per_system: int | None = None, ) -> ( tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor] ): """Compute batch neighbor matrix using naive O(N^2) algorithm. Identifies all atom pairs within a specified cutoff distance for multiple systems processed in a batch. Each system is processed independently, supporting both non-periodic and periodic boundary conditions. For efficiency, this function supports in-place modification of the pre-allocated tensors. If not provided, the resulting tensors will be allocated. This function does not introduce CUDA graph breaks for non-PBC systems. For PBC systems, pre-compute unit shifts to avoid CUDA graph breaks: .. code-block:: python >> from nvalchemiops.neighborlist import compute_naive_num_shifts >> shift_range_per_dimension, shift_offset, total_shifts = compute_naive_num_shifts( ... cell, cutoff, pbc ... ) Parameters ---------- positions : torch.Tensor, shape (total_atoms, 3), dtype=torch.float32 or torch.float64 Concatenated atomic coordinates for all systems in Cartesian space. Each row represents one atom's (x, y, z) position. Must be wrapped into the unit cell if PBC is used. cutoff : float Cutoff distance for neighbor detection in Cartesian units. Must be positive. 
Atoms within this distance are considered neighbors. batch_idx : torch.Tensor, shape (total_atoms,), dtype=torch.int32, optional System index for each atom. Atoms with the same index belong to the same system and can be neighbors. Must be in sorted order. If not provided, assumes all atoms belong to a single system. batch_ptr : torch.Tensor, shape (num_systems + 1,), dtype=torch.int32, optional Cumulative atom counts defining system boundaries. System i contains atoms from batch_ptr[i] to batch_ptr[i+1]-1. If not provided and batch_idx is provided, it will be computed automatically. pbc : torch.Tensor, shape (num_systems, 3), dtype=torch.bool, optional Periodic boundary condition flags for each dimension of each system. True enables periodicity in that direction. Default is None (no PBC). cell : torch.Tensor, shape (num_systems, 3, 3), dtype=torch.float32 or torch.float64, optional Cell matrices defining lattice vectors in Cartesian coordinates. Required if pbc is provided. Default is None. max_neighbors : int, optional Maximum number of neighbors per atom. Must be positive. If exceeded, excess neighbors are ignored. Must be provided if neighbor_matrix is not provided. half_fill : bool, optional If True, only store half of the neighbor relationships to avoid double counting. Another half could be reconstructed by swapping source and target indices and inverting unit shifts. If False, store all neighbor relationships. Default is False. fill_value : int | None, optional Value to fill the neighbor matrix with. Default is total_atoms. return_neighbor_list : bool, optional - default = False If True, convert the neighbor matrix to a neighbor list (idx_i, idx_j) format by creating a mask over the fill_value, which can incur a performance penalty. We recommend using the neighbor matrix format, and only convert to a neighbor list format if absolutely necessary. 
neighbor_matrix : torch.Tensor, shape (total_atoms, max_neighbors), dtype=torch.int32, optional Optional pre-allocated tensor for the neighbor matrix. Must be provided if max_neighbors is not provided. If provided, return_neighbor_list must be False. neighbor_matrix_shifts : torch.Tensor, shape (total_atoms, max_neighbors, 3), dtype=torch.int32, optional Optional pre-allocated tensor for the shift vectors of the neighbor matrix. Must be provided if max_neighbors is not provided and pbc is not None. If provided, return_neighbor_list must be False. num_neighbors : torch.Tensor, shape (total_atoms,), dtype=torch.int32, optional Optional pre-allocated tensor for the number of neighbors in the neighbor matrix. Must be provided if max_neighbors is not provided. shift_range_per_dimension : torch.Tensor, shape (num_systems, 3), dtype=torch.int32, optional Optional pre-allocated tensor for the shift range in each dimension for each system. shift_offset : torch.Tensor, shape (num_systems + 1,), dtype=torch.int32, optional Optional pre-allocated tensor for the cumulative sum of number of shifts for each system. total_shifts : int, optional Total number of shifts. Pass in to avoid reallocation for pbc systems. max_atoms_per_system : int, optional Maximum number of atoms per system. If not provided, it will be computed automaticaly. Can be provided to avoid CUDA synchronization. Returns ------- results : tuple of torch.Tensor Variable-length tuple depending on input parameters. 
The return pattern follows: - No PBC, matrix format: ``(neighbor_matrix, num_neighbors)`` - No PBC, list format: ``(neighbor_list, neighbor_ptr)`` - With PBC, matrix format: ``(neighbor_matrix, num_neighbors, neighbor_matrix_shifts)`` - With PBC, list format: ``(neighbor_list, neighbor_ptr, neighbor_list_shifts)`` **Components returned:** - **neighbor_data** (tensor): Neighbor indices, format depends on ``return_neighbor_list``: * If ``return_neighbor_list=False`` (default): Returns ``neighbor_matrix`` with shape (total_atoms, max_neighbors), dtype int32. Each row i contains indices of atom i's neighbors. * If ``return_neighbor_list=True``: Returns ``neighbor_list`` with shape (2, num_pairs), dtype int32, in COO format [source_atoms, target_atoms]. - **num_neighbor_data** (tensor): Information about the number of neighbors for each atom, format depends on ``return_neighbor_list``: * If ``return_neighbor_list=False`` (default): Returns ``num_neighbors`` with shape (total_atoms,), dtype int32. Count of neighbors found for each atom. * If ``return_neighbor_list=True``: Returns ``neighbor_ptr`` with shape (total_atoms + 1,), dtype int32. CSR-style pointer arrays where ``neighbor_ptr_data[i]`` to ``neighbor_ptr_data[i+1]`` gives the range of neighbors for atom i in the flattened neighbor list. - **neighbor_shift_data** (tensor, optional): Periodic shift vectors, only when ``pbc`` is provided: format depends on ``return_neighbor_list``: * If ``return_neighbor_list=False`` (default): Returns ``neighbor_matrix_shifts`` with shape (total_atoms, max_neighbors, 3), dtype int32. * If ``return_neighbor_list=True``: Returns ``unit_shifts`` with shape (num_pairs, 3), dtype int32. 
Examples -------- Basic batch processing without periodic boundary conditions: >>> import torch >>> # Create batch with 2 systems: 50 and 30 atoms >>> coord1 = torch.rand(50, 3) * 5.0 >>> coord2 = torch.rand(30, 3) * 8.0 >>> positions = torch.cat([coord1, coord2], dim=0) >>> batch_idx = torch.cat([torch.zeros(50), torch.ones(30)]).int() >>> batch_ptr = torch.tensor([0, 50, 80], dtype=torch.int32) >>> >>> cutoff = 2.0 >>> max_neighbors = 40 >>> neighbor_matrix, num_neighbors, _ = batch_naive_neighbor_list( ... positions, cutoff, batch_idx, batch_ptr, max_neighbors=max_neighbors ... ) >>> # neighbor_matrix_shifts will be empty tensor for non-PBC systems With periodic boundary conditions: >>> # Different cells for each system >>> cell = torch.stack([ ... torch.eye(3) * 5.0, # System 0: 5x5x5 cubic cell ... torch.eye(3) * 8.0 # System 1: 8x8x8 cubic cell ... ]) >>> pbc = torch.tensor([[True, True, True], [True, True, False]]) >>> neighbor_matrix, num_neighbors, neighbor_matrix_shifts = batch_naive_neighbor_list( ... positions, cutoff, batch_idx, batch_ptr, ... pbc=pbc, cell=cell, max_neighbors=max_neighbors ... 
) See Also -------- naive_neighbor_list : Single system version batch_naive_neighbor_list_dual_cutoff : Version with two cutoff distances """ if pbc is None and cell is not None: raise ValueError("If cell is provided, pbc must also be provided") if pbc is not None and cell is None: raise ValueError("If pbc is provided, cell must also be provided") if cell is not None: cell = cell if cell.ndim == 3 else cell.unsqueeze(0) if pbc is not None: pbc = pbc if pbc.ndim == 2 else pbc.unsqueeze(0) if max_neighbors is None and ( neighbor_matrix is None or (neighbor_matrix_shifts is None and pbc is not None) or num_neighbors is None ): max_neighbors = estimate_max_neighbors(cutoff) total_atoms = positions.shape[0] if fill_value is None: fill_value = total_atoms if neighbor_matrix is None: neighbor_matrix = torch.full( (positions.shape[0], max_neighbors), fill_value, dtype=torch.int32, device=positions.device, ) else: neighbor_matrix.fill_(fill_value) if num_neighbors is None: num_neighbors = torch.zeros( positions.shape[0], dtype=torch.int32, device=positions.device ) else: num_neighbors.zero_() if pbc is not None: if neighbor_matrix_shifts is None: neighbor_matrix_shifts = torch.zeros( (positions.shape[0], max_neighbors, 3), dtype=torch.int32, device=positions.device, ) else: neighbor_matrix_shifts.zero_() if ( total_shifts is None or shift_offset is None or shift_range_per_dimension is None ): shift_range_per_dimension, shift_offset, total_shifts = ( compute_naive_num_shifts(cell, cutoff, pbc) ) batch_idx, batch_ptr = _prepare_batch_idx_ptr( batch_idx=batch_idx, batch_ptr=batch_ptr, num_atoms=total_atoms, device=positions.device, ) if pbc is None: _batch_naive_neighbor_matrix_no_pbc( positions=positions, cutoff=cutoff, batch_idx=batch_idx, batch_ptr=batch_ptr, neighbor_matrix=neighbor_matrix, num_neighbors=num_neighbors, half_fill=half_fill, ) if return_neighbor_list: neighbor_list, neighbor_ptr = get_neighbor_list_from_neighbor_matrix( neighbor_matrix, 
num_neighbors=num_neighbors, fill_value=fill_value, ) return neighbor_list, neighbor_ptr else: return neighbor_matrix, num_neighbors else: _batch_naive_neighbor_matrix_pbc( positions=positions, cell=cell, cutoff=cutoff, batch_ptr=batch_ptr, neighbor_matrix=neighbor_matrix, neighbor_matrix_shifts=neighbor_matrix_shifts, num_neighbors=num_neighbors, shift_range_per_dimension=shift_range_per_dimension, shift_offset=shift_offset, total_shifts=total_shifts, half_fill=half_fill, max_atoms_per_system=max_atoms_per_system, ) if return_neighbor_list: neighbor_list, neighbor_ptr, neighbor_list_shifts = ( get_neighbor_list_from_neighbor_matrix( neighbor_matrix, num_neighbors=num_neighbors, neighbor_shift_matrix=neighbor_matrix_shifts, fill_value=fill_value, ) ) return neighbor_list, neighbor_ptr, neighbor_list_shifts else: return neighbor_matrix, num_neighbors, neighbor_matrix_shifts