# Source code for nvalchemiops.neighbors.batch_naive

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Core warp kernels and launchers for batched naive neighbor list construction.

This module contains warp kernels for batched O(N²) neighbor list computation.
See `nvalchemiops.torch.neighbors` for PyTorch bindings.
"""

from typing import Any

import warp as wp

from nvalchemiops.neighbors.neighbor_utils import (
    _decode_shift_index,
    _update_neighbor_matrix,
    _update_neighbor_matrix_pbc,
    compute_inv_cells,
    selective_zero_num_neighbors,
    wrap_positions_batch,
)

__all__ = [
    "batch_naive_neighbor_matrix",
    "batch_naive_neighbor_matrix_pbc",
]

###########################################################################################
########################### Batch Naive Neighbor List Kernels ############################
###########################################################################################


@wp.func
def _batch_naive_neighbor_body(
    tid: int,
    positions: wp.array(dtype=Any),
    cutoff_sq: Any,
    batch_idx: wp.array(dtype=wp.int32),
    batch_ptr: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    half_fill: wp.bool,
):
    """Test atom ``tid`` against the later atoms of its own system.

    Shared body of the non-PBC batched kernels. Thread ``tid`` only visits
    atoms ``j`` with ``tid < j < batch_ptr[isys + 1]``, so each unordered
    pair inside a system is tested by exactly one thread (the one holding the
    smaller index), and atoms of different systems are never compared.
    This assumes the atoms of each system are stored contiguously and that
    ``batch_idx`` agrees with ``batch_ptr``.

    Parameters
    ----------
    tid : int
        Global index of the atom handled by this thread.
    positions : wp.array, dtype=wp.vec3*
        Concatenated positions of all systems.
    cutoff_sq : float
        Squared cutoff; a pair is accepted when its squared distance is
        strictly smaller.
    batch_idx : wp.array, dtype=wp.int32
        System index of each atom.
    batch_ptr : wp.array, dtype=wp.int32
        System boundaries; only the exclusive end ``batch_ptr[isys + 1]`` is
        read here.
    neighbor_matrix, num_neighbors
        Output arrays, forwarded unchanged to ``_update_neighbor_matrix``.
    half_fill : wp.bool
        Forwarded to ``_update_neighbor_matrix``, which decides how the
        accepted pair is stored.
    """
    isys = batch_idx[tid]
    # Exclusive end of this atom's system: the scan never leaves the system.
    j_end = batch_ptr[isys + 1]
    positions_i = positions[tid]
    # Row capacity of the output matrix, passed to the update helper.
    max_neighbors = neighbor_matrix.shape[1]
    # Starting at tid + 1 skips the self pair and avoids testing (j, i) again.
    for j in range(tid + 1, j_end):
        diff = positions_i - positions[j]
        dist_sq = wp.length_sq(diff)
        if dist_sq < cutoff_sq:
            _update_neighbor_matrix(
                tid, j, neighbor_matrix, num_neighbors, max_neighbors, half_fill
            )


@wp.func
def _batch_naive_neighbor_pbc_body(
    shift: wp.vec3i,
    iatom_global: int,
    isys: int,
    positions: wp.array(dtype=Any),
    per_atom_cell_offsets: wp.array(dtype=wp.vec3i),
    cell: wp.array(dtype=Any),
    cutoff_sq: Any,
    batch_ptr: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    neighbor_matrix_shifts: wp.array(dtype=wp.vec3i, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    half_fill: wp.bool,
):
    """Test one periodic image of an atom against atoms of its system.

    Shared body of the PBC kernels that run on freshly wrapped positions.
    The image of atom ``iatom_global`` translated by ``shift`` lattice
    vectors is compared with the wrapped positions of the atoms in system
    ``isys``. Accepted pairs are recorded with a shift corrected by the
    per-atom cell offsets, so that it refers to the caller's original
    (possibly unwrapped) positions.

    Parameters
    ----------
    shift : wp.vec3i
        Integer lattice translation applied to atom ``iatom_global``.
    iatom_global : int
        Global index of the central atom.
    isys : int
        System of the central atom; selects ``cell[isys]`` and the atom range
        ``[batch_ptr[isys], batch_ptr[isys + 1])``.
    positions : wp.array, dtype=wp.vec3*
        Wrapped positions of all atoms.
    per_atom_cell_offsets : wp.array, dtype=wp.vec3i
        Integer offsets produced by the wrapping step; used only to correct
        the stored shift.
    cell : wp.array, dtype=wp.mat33*
        Per-system cell matrices.
    cutoff_sq : float
        Squared cutoff (strict ``<`` comparison).
    batch_ptr : wp.array, dtype=wp.int32
        System boundaries.
    neighbor_matrix, neighbor_matrix_shifts, num_neighbors, half_fill
        Forwarded to ``_update_neighbor_matrix_pbc``.
    """
    # Atom range of system ``isys``.
    jatom_start = batch_ptr[isys]
    jatom_end = batch_ptr[isys + 1]
    maxnb = neighbor_matrix.shape[1]
    _cell = cell[isys]
    _pos_i = positions[iatom_global]
    _int_i = per_atom_cell_offsets[iatom_global]
    # Image of atom i. ``type(_cell[0])`` is the row-vector type of the cell
    # matrix, so the integer shift is first converted to the cell's float
    # precision, then multiplied as a row vector with the cell. Presumably the
    # rows of ``cell`` are the lattice vectors -- confirm against the cell
    # convention used by callers.
    positions_shifted = type(_cell[0])(shift) * _cell + _pos_i
    # In the home image (zero shift), only j < i is scanned. This skips the
    # self pair and tests each in-cell pair once. Nonzero shifts scan every
    # atom of the system, including atom i itself (its own periodic images).
    _zero_shift = shift[0] == 0 and shift[1] == 0 and shift[2] == 0
    if _zero_shift:
        jatom_end = iatom_global
    for jatom in range(jatom_start, jatom_end):
        _pos_j = positions[jatom]
        diff = positions_shifted - _pos_j
        dist_sq = wp.length_sq(diff)
        if dist_sq < cutoff_sq:
            # Correct the stored shift so that dist = pos_i - pos_j - shift*cell
            # holds for the original (potentially unwrapped) positions.
            _int_j = per_atom_cell_offsets[jatom]
            _corrected_shift = wp.vec3i(
                shift[0] - _int_i[0] + _int_j[0],
                shift[1] - _int_i[1] + _int_j[1],
                shift[2] - _int_i[2] + _int_j[2],
            )
            # Note the argument order: neighbor ``jatom`` first, then the
            # central atom. The storage layout is decided by the helper.
            _update_neighbor_matrix_pbc(
                jatom,
                iatom_global,
                neighbor_matrix,
                neighbor_matrix_shifts,
                num_neighbors,
                _corrected_shift,
                maxnb,
                half_fill,
            )


@wp.func
def _batch_naive_neighbor_pbc_body_prewrapped(
    shift: wp.vec3i,
    iatom_global: int,
    isys: int,
    positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    cutoff_sq: Any,
    batch_ptr: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    neighbor_matrix_shifts: wp.array(dtype=wp.vec3i, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    half_fill: wp.bool,
):
    """Test one periodic image of an atom, with positions already wrapped.

    Identical search to ``_batch_naive_neighbor_pbc_body``, except that no
    per-atom cell offsets exist. The lattice ``shift`` used for the image is
    therefore stored exactly as given.

    Parameters
    ----------
    shift : wp.vec3i
        Integer lattice translation applied to atom ``iatom_global``.
    iatom_global : int
        Global index of the central atom.
    isys : int
        System of the central atom; selects ``cell[isys]`` and the atom range
        ``[batch_ptr[isys], batch_ptr[isys + 1])``.
    positions : wp.array, dtype=wp.vec3*
        Positions the caller guarantees are already wrapped into the cell.
    cell : wp.array, dtype=wp.mat33*
        Per-system cell matrices.
    cutoff_sq : float
        Squared cutoff (strict ``<`` comparison).
    batch_ptr : wp.array, dtype=wp.int32
        System boundaries.
    neighbor_matrix, neighbor_matrix_shifts, num_neighbors, half_fill
        Forwarded to ``_update_neighbor_matrix_pbc``.
    """
    # Atom range of system ``isys``.
    jatom_start = batch_ptr[isys]
    jatom_end = batch_ptr[isys + 1]
    maxnb = neighbor_matrix.shape[1]
    _cell = cell[isys]
    _pos_i = positions[iatom_global]
    # Image of atom i: the integer shift is converted to the cell's row-vector
    # float type, then multiplied with the cell and added to the position.
    positions_shifted = type(_cell[0])(shift) * _cell + _pos_i
    # Zero shift: only j < i (no self pair, each in-cell pair once).
    # Nonzero shifts: every atom of the system, including i itself.
    _zero_shift = shift[0] == 0 and shift[1] == 0 and shift[2] == 0
    if _zero_shift:
        jatom_end = iatom_global
    for jatom in range(jatom_start, jatom_end):
        _pos_j = positions[jatom]
        diff = positions_shifted - _pos_j
        dist_sq = wp.length_sq(diff)
        if dist_sq < cutoff_sq:
            # The shift is stored uncorrected because there are no wrap offsets.
            _update_neighbor_matrix_pbc(
                jatom,
                iatom_global,
                neighbor_matrix,
                neighbor_matrix_shifts,
                num_neighbors,
                shift,
                maxnb,
                half_fill,
            )


@wp.kernel(enable_backward=False)
def _fill_batch_naive_neighbor_matrix(
    positions: wp.array(dtype=Any),
    cutoff_sq: Any,
    batch_idx: wp.array(dtype=wp.int32),
    batch_ptr: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    half_fill: wp.bool,
) -> None:
    """Calculate batch neighbor matrix using naive O(N^2) algorithm.

    Computes pairwise distances between atoms within each system in a batch
    and identifies neighbors within the specified cutoff distance. Atoms from
    different systems do not interact. No periodic boundary conditions are applied.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms, 3), dtype=wp.vec3*
        Concatenated Cartesian coordinates for all systems.
        Each row represents one atom's (x, y, z) position.
    cutoff_sq : float
        Squared cutoff distance for neighbor detection in Cartesian units.
        Atoms strictly closer than the cutoff are considered neighbors.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom. Atoms with the same index belong to
        the same system and can be neighbors.
    batch_ptr : wp.array, shape (num_systems + 1,), dtype=wp.int32
        Cumulative atom counts defining system boundaries.
        System i contains atoms from batch_ptr[i] to batch_ptr[i+1]-1.
    neighbor_matrix : wp.array, shape (total_atoms, max_neighbors), dtype=wp.int32
        OUTPUT: Neighbor matrix to be filled with neighbor atom indices.
        Entries are filled with atom indices, remaining entries stay as initialized.
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors found for each atom.
        Updated in-place with actual neighbor counts.
    half_fill : wp.bool
        If True, only store relationships where i < j to avoid double counting.
        If False, store all neighbor relationships symmetrically.

    Returns
    -------
    None
        This function modifies the input arrays in-place:

        - neighbor_matrix : Filled with neighbor atom indices
        - num_neighbors : Updated with neighbor counts per atom

    Notes
    -----
    - Thread launch: one thread per atom (``dim=total_atoms``). Thread ``tid``
      only tests atoms ``j > tid`` of its own system, so every intra-system
      pair is examined exactly once.

    See Also
    --------
    _fill_naive_neighbor_matrix : Single system version
    _fill_batch_naive_neighbor_matrix_pbc : Version with periodic boundary conditions
    """
    # One thread per atom; ``tid`` is the global atom index.
    tid = wp.tid()
    _batch_naive_neighbor_body(
        tid,
        positions,
        cutoff_sq,
        batch_idx,
        batch_ptr,
        neighbor_matrix,
        num_neighbors,
        half_fill,
    )


@wp.kernel(enable_backward=False)
def _fill_batch_naive_neighbor_matrix_pbc(
    positions: wp.array(dtype=Any),
    per_atom_cell_offsets: wp.array(dtype=wp.vec3i),
    cell: wp.array(dtype=Any),
    cutoff_sq: Any,
    batch_ptr: wp.array(dtype=wp.int32),
    shift_range: wp.array(dtype=wp.vec3i),
    num_shifts_arr: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    neighbor_matrix_shifts: wp.array(dtype=wp.vec3i, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    half_fill: wp.bool,
) -> None:
    """Calculate batch neighbor matrix with PBC using naive O(N^2) algorithm.

    Computes neighbor relationships between atoms across periodic boundaries by
    considering all periodic images within the cutoff distance. Processes multiple
    systems in a batch, where each system can have different periodic cells.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms, 3), dtype=wp.vec3*
        Assumed to be wrapped into the primary cell.
    per_atom_cell_offsets : wp.array, shape (total_atoms,), dtype=wp.vec3i
        Integer cell offsets for each atom, as produced by the wrapping step.
        They are used to correct stored shifts so that the shifts refer to
        the original (unwrapped) positions.
    cell : wp.array, shape (num_systems, 3, 3), dtype=wp.mat33*
        Cell matrices for each system.
    cutoff_sq : float
        Squared cutoff distance.
    batch_ptr : wp.array, shape (num_systems + 1,), dtype=wp.int32
        Cumulative atom counts defining system boundaries.
    shift_range : wp.array, shape (num_systems, 3), dtype=wp.vec3i
        Shift range per dimension per system.
    num_shifts_arr : wp.array, shape (num_systems,), dtype=wp.int32
        Number of shifts per system (for bounds checking).
    neighbor_matrix : wp.array, shape (total_atoms, max_neighbors), dtype=wp.int32
        OUTPUT: Neighbor matrix.
    neighbor_matrix_shifts : wp.array, shape (total_atoms, max_neighbors), dtype=wp.vec3i
        OUTPUT: Shift vectors for each neighbor.
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors per atom.
    half_fill : wp.bool
        If True, only store half of the neighbor relationships.

    Returns
    -------
    None
        This function modifies the input arrays in-place:

        - neighbor_matrix : Filled with neighbor atom indices
        - neighbor_matrix_shifts : Filled with corresponding shift vectors
        - num_neighbors : Updated with neighbor counts per atom

    Notes
    -----
    - Thread launch: 3D (num_systems, max_shifts_per_system, max_atoms_per_system).
      Each thread handles one (system, shift, atom) triple. Threads beyond a
      system's own shift or atom count exit immediately.

    See Also
    --------
    _fill_batch_naive_neighbor_matrix : Version without periodic boundary conditions
    _fill_naive_neighbor_matrix_pbc : Single system version
    """
    isys, ishift_local, iatom = wp.tid()

    # The launch uses the maximum shift count over all systems; drop padding.
    if ishift_local >= num_shifts_arr[isys]:
        return

    _natom = batch_ptr[isys + 1] - batch_ptr[isys]

    # The launch uses the maximum atom count over all systems; drop padding.
    if iatom >= _natom:
        return

    # Convert the system-local atom index to a global index.
    iatom_global = iatom + batch_ptr[isys]
    # Map the linear shift index to an integer lattice shift using this
    # system's shift_range (encoding defined in neighbor_utils).
    shift = _decode_shift_index(ishift_local, shift_range[isys])
    _batch_naive_neighbor_pbc_body(
        shift,
        iatom_global,
        isys,
        positions,
        per_atom_cell_offsets,
        cell,
        cutoff_sq,
        batch_ptr,
        neighbor_matrix,
        neighbor_matrix_shifts,
        num_neighbors,
        half_fill,
    )


@wp.kernel(enable_backward=False)
def _fill_batch_naive_neighbor_matrix_pbc_prewrapped(
    positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    cutoff_sq: Any,
    batch_ptr: wp.array(dtype=wp.int32),
    shift_range: wp.array(dtype=wp.vec3i),
    num_shifts_arr: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    neighbor_matrix_shifts: wp.array(dtype=wp.vec3i, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    half_fill: wp.bool,
) -> None:
    """Batch PBC neighbor matrix for pre-wrapped positions (no cell-offset correction).

    Same search as ``_fill_batch_naive_neighbor_matrix_pbc``. Because the caller
    guarantees that positions are already wrapped, lattice shifts are stored
    exactly as enumerated.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms, 3), dtype=wp.vec3*
        Positions already wrapped into the primary cell.
    cell : wp.array, shape (num_systems, 3, 3), dtype=wp.mat33*
        Cell matrices for each system.
    cutoff_sq : float
        Squared cutoff distance.
    batch_ptr : wp.array, shape (num_systems + 1,), dtype=wp.int32
        Cumulative atom counts defining system boundaries.
    shift_range : wp.array, shape (num_systems, 3), dtype=wp.vec3i
        Shift range per dimension per system.
    num_shifts_arr : wp.array, shape (num_systems,), dtype=wp.int32
        Number of shifts per system.
    neighbor_matrix : wp.array, shape (total_atoms, max_neighbors), dtype=wp.int32
        OUTPUT: Neighbor matrix.
    neighbor_matrix_shifts : wp.array, shape (total_atoms, max_neighbors), dtype=wp.vec3i
        OUTPUT: Shift vectors for each neighbor.
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors per atom.
    half_fill : wp.bool
        If True, only store half of the neighbor relationships.

    Notes
    -----
    - Thread launch: 3D (num_systems, max_shifts_per_system, max_atoms_per_system).
      Threads beyond a system's own shift or atom count exit immediately.
    """
    isys, ishift_local, iatom = wp.tid()

    # Padding thread: this system has fewer shifts than the launch maximum.
    if ishift_local >= num_shifts_arr[isys]:
        return

    _natom = batch_ptr[isys + 1] - batch_ptr[isys]

    # Padding thread: this system has fewer atoms than the launch maximum.
    if iatom >= _natom:
        return

    # System-local atom index -> global index.
    iatom_global = iatom + batch_ptr[isys]
    # Linear shift index -> integer lattice shift for this system.
    shift = _decode_shift_index(ishift_local, shift_range[isys])
    _batch_naive_neighbor_pbc_body_prewrapped(
        shift,
        iatom_global,
        isys,
        positions,
        cell,
        cutoff_sq,
        batch_ptr,
        neighbor_matrix,
        neighbor_matrix_shifts,
        num_neighbors,
        half_fill,
    )


# Precisions every kernel in this module is specialized for. Entries at the
# same index form one (scalar, vec3, mat33) family.
T = [wp.float32, wp.float64, wp.float16]
V = [wp.vec3f, wp.vec3d, wp.vec3h]
M = [wp.mat33f, wp.mat33d, wp.mat33h]

# Lookup tables: scalar dtype -> concrete overload of the generic kernel.
# The launchers index them as ``table[wp_dtype]``.
_fill_batch_naive_neighbor_matrix_overload = {
    scalar_t: wp.overload(
        _fill_batch_naive_neighbor_matrix,
        [
            wp.array(dtype=vec_t),  # positions
            scalar_t,  # cutoff_sq
            wp.array(dtype=wp.int32),  # batch_idx
            wp.array(dtype=wp.int32),  # batch_ptr
            wp.array(dtype=wp.int32, ndim=2),  # neighbor_matrix
            wp.array(dtype=wp.int32),  # num_neighbors
            wp.bool,  # half_fill
        ],
    )
    for scalar_t, vec_t in zip(T, V)
}

_fill_batch_naive_neighbor_matrix_pbc_overload = {
    scalar_t: wp.overload(
        _fill_batch_naive_neighbor_matrix_pbc,
        [
            wp.array(dtype=vec_t),  # positions
            wp.array(dtype=wp.vec3i),  # per_atom_cell_offsets
            wp.array(dtype=mat_t),  # cell
            scalar_t,  # cutoff_sq
            wp.array(dtype=wp.int32),  # batch_ptr
            wp.array(dtype=wp.vec3i),  # shift_range
            wp.array(dtype=wp.int32),  # num_shifts_arr
            wp.array(dtype=wp.int32, ndim=2),  # neighbor_matrix
            wp.array(dtype=wp.vec3i, ndim=2),  # neighbor_matrix_shifts
            wp.array(dtype=wp.int32),  # num_neighbors
            wp.bool,  # half_fill
        ],
    )
    for scalar_t, vec_t, mat_t in zip(T, V, M)
}

_fill_batch_naive_neighbor_matrix_pbc_prewrapped_overload = {
    scalar_t: wp.overload(
        _fill_batch_naive_neighbor_matrix_pbc_prewrapped,
        [
            wp.array(dtype=vec_t),  # positions
            wp.array(dtype=mat_t),  # cell
            scalar_t,  # cutoff_sq
            wp.array(dtype=wp.int32),  # batch_ptr
            wp.array(dtype=wp.vec3i),  # shift_range
            wp.array(dtype=wp.int32),  # num_shifts_arr
            wp.array(dtype=wp.int32, ndim=2),  # neighbor_matrix
            wp.array(dtype=wp.vec3i, ndim=2),  # neighbor_matrix_shifts
            wp.array(dtype=wp.int32),  # num_neighbors
            wp.bool,  # half_fill
        ],
    )
    for scalar_t, vec_t, mat_t in zip(T, V, M)
}

###########################################################################################
########################### Selective Skip Kernels #######################################
###########################################################################################


@wp.kernel(enable_backward=False)
def _fill_batch_naive_neighbor_matrix_selective(
    positions: wp.array(dtype=Any),
    cutoff_sq: Any,
    batch_idx: wp.array(dtype=wp.int32),
    batch_ptr: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    half_fill: wp.bool,
    rebuild_flags: wp.array(dtype=wp.bool),
) -> None:
    """Selective batch naive neighbor matrix kernel - skips non-rebuilt systems.

    Same computation as ``_fill_batch_naive_neighbor_matrix``, but atoms whose
    system is not flagged in ``rebuild_flags`` return immediately and leave
    their rows untouched.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms, 3), dtype=wp.vec3*
        Concatenated Cartesian coordinates for all systems.
    cutoff_sq : float
        Squared cutoff distance for neighbor detection.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom.
    batch_ptr : wp.array, shape (num_systems + 1,), dtype=wp.int32
        Cumulative atom counts defining system boundaries.
    neighbor_matrix : wp.array, shape (total_atoms, max_neighbors), dtype=wp.int32
        OUTPUT: Neighbor matrix to be filled with neighbor atom indices.
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors found for each atom.
    half_fill : wp.bool
        If True, only store relationships where i < j.
    rebuild_flags : wp.array, shape (num_systems,), dtype=wp.bool
        Per-system flags. Only systems with True are processed.

    Notes
    -----
    - Thread launch: One thread per atom (dim=total_atoms)
    - Atoms in systems where rebuild_flags[isys] is False are skipped
    """
    tid = wp.tid()
    isys = batch_idx[tid]
    # Skip atoms belonging to systems that are not being rebuilt.
    if not rebuild_flags[isys]:
        return
    _batch_naive_neighbor_body(
        tid,
        positions,
        cutoff_sq,
        batch_idx,
        batch_ptr,
        neighbor_matrix,
        num_neighbors,
        half_fill,
    )


@wp.kernel(enable_backward=False)
def _fill_batch_naive_neighbor_matrix_pbc_selective(
    positions: wp.array(dtype=Any),
    per_atom_cell_offsets: wp.array(dtype=wp.vec3i),
    cell: wp.array(dtype=Any),
    cutoff_sq: Any,
    batch_ptr: wp.array(dtype=wp.int32),
    shift_range: wp.array(dtype=wp.vec3i),
    num_shifts_arr: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    neighbor_matrix_shifts: wp.array(dtype=wp.vec3i, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    half_fill: wp.bool,
    rebuild_flags: wp.array(dtype=wp.bool),
) -> None:
    """Selective batch naive PBC neighbor matrix kernel - skips non-rebuilt systems.

    Same computation as ``_fill_batch_naive_neighbor_matrix_pbc``, but every
    thread of a system not flagged in ``rebuild_flags`` returns immediately.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms, 3), dtype=wp.vec3*
        Assumed to be wrapped into the primary cell.
    per_atom_cell_offsets : wp.array, shape (total_atoms,), dtype=wp.vec3i
        Integer cell offsets for each atom (from the wrapping step), used to
        correct the stored shifts.
    cell : wp.array, shape (num_systems, 3, 3), dtype=wp.mat33*
        Cell matrices for each system.
    cutoff_sq : float
        Squared cutoff distance.
    batch_ptr : wp.array, shape (num_systems + 1,), dtype=wp.int32
        Cumulative atom counts.
    shift_range : wp.array, shape (num_systems, 3), dtype=wp.vec3i
        Shift range per dimension per system.
    num_shifts_arr : wp.array, shape (num_systems,), dtype=wp.int32
        Number of shifts per system.
    neighbor_matrix : wp.array, shape (total_atoms, max_neighbors), dtype=wp.int32
        OUTPUT: Neighbor matrix.
    neighbor_matrix_shifts : wp.array, shape (total_atoms, max_neighbors), dtype=wp.vec3i
        OUTPUT: Shift vectors for each neighbor.
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors per atom.
    half_fill : wp.bool
        If True, only store half of the neighbor relationships.
    rebuild_flags : wp.array, shape (num_systems,), dtype=wp.bool
        Per-system flags. Only systems with True are processed.

    Notes
    -----
    - Thread launch: 3D (num_systems, max_shifts_per_system, max_atoms_per_system)
    """
    isys, ishift_local, iatom = wp.tid()

    # Whole system skipped when it is not flagged for rebuild.
    if not rebuild_flags[isys]:
        return

    # Padding thread: beyond this system's shift count.
    if ishift_local >= num_shifts_arr[isys]:
        return

    _natom = batch_ptr[isys + 1] - batch_ptr[isys]

    # Padding thread: beyond this system's atom count.
    if iatom >= _natom:
        return

    iatom_global = iatom + batch_ptr[isys]
    shift = _decode_shift_index(ishift_local, shift_range[isys])
    _batch_naive_neighbor_pbc_body(
        shift,
        iatom_global,
        isys,
        positions,
        per_atom_cell_offsets,
        cell,
        cutoff_sq,
        batch_ptr,
        neighbor_matrix,
        neighbor_matrix_shifts,
        num_neighbors,
        half_fill,
    )


@wp.kernel(enable_backward=False)
def _fill_batch_naive_neighbor_matrix_pbc_prewrapped_selective(
    positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    cutoff_sq: Any,
    batch_ptr: wp.array(dtype=wp.int32),
    shift_range: wp.array(dtype=wp.vec3i),
    num_shifts_arr: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    neighbor_matrix_shifts: wp.array(dtype=wp.vec3i, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    half_fill: wp.bool,
    rebuild_flags: wp.array(dtype=wp.bool),
) -> None:
    """Selective batch PBC kernel for pre-wrapped positions - skips non-rebuilt systems.

    Same computation as ``_fill_batch_naive_neighbor_matrix_pbc_prewrapped``,
    but every thread of a system not flagged in ``rebuild_flags`` returns
    immediately. Shifts are stored uncorrected.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms, 3), dtype=wp.vec3*
        Positions already wrapped into the primary cell.
    cell : wp.array, shape (num_systems, 3, 3), dtype=wp.mat33*
        Cell matrices for each system.
    cutoff_sq : float
        Squared cutoff distance.
    batch_ptr : wp.array, shape (num_systems + 1,), dtype=wp.int32
        Cumulative atom counts.
    shift_range : wp.array, shape (num_systems, 3), dtype=wp.vec3i
        Shift range per dimension per system.
    num_shifts_arr : wp.array, shape (num_systems,), dtype=wp.int32
        Number of shifts per system.
    neighbor_matrix : wp.array, shape (total_atoms, max_neighbors), dtype=wp.int32
        OUTPUT: Neighbor matrix.
    neighbor_matrix_shifts : wp.array, shape (total_atoms, max_neighbors), dtype=wp.vec3i
        OUTPUT: Shift vectors for each neighbor.
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors per atom.
    half_fill : wp.bool
        If True, only store half of the neighbor relationships.
    rebuild_flags : wp.array, shape (num_systems,), dtype=wp.bool
        Per-system flags. Only systems with True are processed.

    Notes
    -----
    - Thread launch: 3D (num_systems, max_shifts_per_system, max_atoms_per_system)
    """
    isys, ishift_local, iatom = wp.tid()

    # Whole system skipped when it is not flagged for rebuild.
    if not rebuild_flags[isys]:
        return

    # Padding thread: beyond this system's shift count.
    if ishift_local >= num_shifts_arr[isys]:
        return

    _natom = batch_ptr[isys + 1] - batch_ptr[isys]

    # Padding thread: beyond this system's atom count.
    if iatom >= _natom:
        return

    iatom_global = iatom + batch_ptr[isys]
    shift = _decode_shift_index(ishift_local, shift_range[isys])
    _batch_naive_neighbor_pbc_body_prewrapped(
        shift,
        iatom_global,
        isys,
        positions,
        cell,
        cutoff_sq,
        batch_ptr,
        neighbor_matrix,
        neighbor_matrix_shifts,
        num_neighbors,
        half_fill,
    )


# Lookup tables for the selective (rebuild_flags-aware) kernels, keyed by
# scalar dtype. Each signature is the matching non-selective one plus a
# trailing ``wp.array(dtype=wp.bool)`` for ``rebuild_flags``.
_fill_batch_naive_neighbor_matrix_selective_overload = {
    scalar_t: wp.overload(
        _fill_batch_naive_neighbor_matrix_selective,
        [
            wp.array(dtype=vec_t),  # positions
            scalar_t,  # cutoff_sq
            wp.array(dtype=wp.int32),  # batch_idx
            wp.array(dtype=wp.int32),  # batch_ptr
            wp.array(dtype=wp.int32, ndim=2),  # neighbor_matrix
            wp.array(dtype=wp.int32),  # num_neighbors
            wp.bool,  # half_fill
            wp.array(dtype=wp.bool),  # rebuild_flags
        ],
    )
    for scalar_t, vec_t in zip(T, V)
}

_fill_batch_naive_neighbor_matrix_pbc_selective_overload = {
    scalar_t: wp.overload(
        _fill_batch_naive_neighbor_matrix_pbc_selective,
        [
            wp.array(dtype=vec_t),  # positions
            wp.array(dtype=wp.vec3i),  # per_atom_cell_offsets
            wp.array(dtype=mat_t),  # cell
            scalar_t,  # cutoff_sq
            wp.array(dtype=wp.int32),  # batch_ptr
            wp.array(dtype=wp.vec3i),  # shift_range
            wp.array(dtype=wp.int32),  # num_shifts_arr
            wp.array(dtype=wp.int32, ndim=2),  # neighbor_matrix
            wp.array(dtype=wp.vec3i, ndim=2),  # neighbor_matrix_shifts
            wp.array(dtype=wp.int32),  # num_neighbors
            wp.bool,  # half_fill
            wp.array(dtype=wp.bool),  # rebuild_flags
        ],
    )
    for scalar_t, vec_t, mat_t in zip(T, V, M)
}

_fill_batch_naive_neighbor_matrix_pbc_prewrapped_selective_overload = {
    scalar_t: wp.overload(
        _fill_batch_naive_neighbor_matrix_pbc_prewrapped_selective,
        [
            wp.array(dtype=vec_t),  # positions
            wp.array(dtype=mat_t),  # cell
            scalar_t,  # cutoff_sq
            wp.array(dtype=wp.int32),  # batch_ptr
            wp.array(dtype=wp.vec3i),  # shift_range
            wp.array(dtype=wp.int32),  # num_shifts_arr
            wp.array(dtype=wp.int32, ndim=2),  # neighbor_matrix
            wp.array(dtype=wp.vec3i, ndim=2),  # neighbor_matrix_shifts
            wp.array(dtype=wp.int32),  # num_neighbors
            wp.bool,  # half_fill
            wp.array(dtype=wp.bool),  # rebuild_flags
        ],
    )
    for scalar_t, vec_t, mat_t in zip(T, V, M)
}

###########################################################################################
########################### Warp Launchers ###############################################
###########################################################################################


def batch_naive_neighbor_matrix(
    positions: wp.array,
    cutoff: float,
    batch_idx: wp.array,
    batch_ptr: wp.array,
    neighbor_matrix: wp.array,
    num_neighbors: wp.array,
    wp_dtype: type,
    device: str,
    half_fill: bool = False,
    rebuild_flags: wp.array | None = None,
) -> None:
    """Core warp launcher for batched naive neighbor matrix construction (no PBC).

    Computes pairwise distances and fills the neighbor matrix for multiple
    systems in a batch using pure warp operations. No periodic boundary
    conditions are applied.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms, 3), dtype=wp.vec3*
        Concatenated Cartesian coordinates for all systems.
    cutoff : float
        Cutoff distance for neighbor detection in Cartesian units.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom.
    batch_ptr : wp.array, shape (num_systems + 1,), dtype=wp.int32
        Cumulative atom counts defining system boundaries.
    neighbor_matrix : wp.array, shape (total_atoms, max_neighbors), dtype=wp.int32
        OUTPUT: Neighbor matrix to be filled with neighbor atom indices.
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors found for each atom.
    wp_dtype : type
        Warp dtype (wp.float32, wp.float64, or wp.float16).
    device : str
        Warp device string (e.g., 'cuda:0', 'cpu').
    half_fill : bool, default=False
        If True, only store relationships where i < j to avoid double counting.
    rebuild_flags : wp.array, shape (num_systems,), dtype=wp.bool, optional
        Per-system rebuild flags. If provided, only systems where
        rebuild_flags[i] is True are processed; others are skipped on the GPU
        without CPU sync. This launcher calls ``selective_zero_num_neighbors``
        itself before the kernel to reset the counts, so callers do not need
        to do so.

    Raises
    ------
    KeyError
        If ``wp_dtype`` has no compiled kernel overload.

    Notes
    -----
    - This is a low-level warp interface. For framework bindings, use torch/jax wrappers.
    - Output arrays must be pre-allocated by caller.
    - One thread is launched per atom (``dim=total_atoms``).

    See Also
    --------
    batch_naive_neighbor_matrix_pbc : Version with periodic boundary conditions
    _fill_batch_naive_neighbor_matrix : Kernel that performs the computation
    _fill_batch_naive_neighbor_matrix_selective : Selective-skip kernel variant
    """
    total_atoms = positions.shape[0]
    selective = rebuild_flags is not None

    if selective:
        # Reset counts before the selective kernel refills flagged systems.
        selective_zero_num_neighbors(num_neighbors, batch_idx, rebuild_flags, device)
        kernel = _fill_batch_naive_neighbor_matrix_selective_overload[wp_dtype]
    else:
        kernel = _fill_batch_naive_neighbor_matrix_overload[wp_dtype]

    inputs = [
        positions,
        # Kernels compare squared distances, avoiding a sqrt per pair.
        wp_dtype(cutoff * cutoff),
        batch_idx,
        batch_ptr,
        neighbor_matrix,
        num_neighbors,
        half_fill,
    ]
    if selective:
        # The selective kernel takes rebuild_flags as its trailing argument.
        inputs.append(rebuild_flags)

    wp.launch(kernel=kernel, dim=total_atoms, inputs=inputs, device=device)
def batch_naive_neighbor_matrix_pbc(
    positions: wp.array,
    cell: wp.array,
    cutoff: float,
    batch_ptr: wp.array,
    batch_idx: wp.array,
    shift_range: wp.array,
    num_shifts_arr: wp.array,
    max_shifts_per_system: int,
    neighbor_matrix: wp.array,
    neighbor_matrix_shifts: wp.array,
    num_neighbors: wp.array,
    wp_dtype: type,
    device: str,
    max_atoms_per_system: int,
    half_fill: bool = False,
    rebuild_flags: wp.array | None = None,
    wrap_positions: bool = True,
) -> None:
    """Core warp launcher for batched naive neighbor matrix construction with PBC.

    Computes neighbor relationships between atoms across periodic boundaries
    for multiple systems in a batch using pure warp operations.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms, 3), dtype=wp.vec3*
        Concatenated Cartesian coordinates for all systems.
    cell : wp.array, shape (num_systems, 3, 3), dtype=wp.mat33*
        Cell matrices for each system.
    cutoff : float
        Cutoff distance for neighbor detection.
    batch_ptr : wp.array, shape (num_systems + 1,), dtype=wp.int32
        Cumulative atom counts defining system boundaries.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom.
    shift_range : wp.array, shape (num_systems, 3), dtype=wp.vec3i
        Shift range per dimension per system.
    num_shifts_arr : wp.array, shape (num_systems,), dtype=wp.int32
        Number of shifts per system.
    max_shifts_per_system : int
        Maximum per-system shift count (launch dimension).
    neighbor_matrix : wp.array, shape (total_atoms, max_neighbors), dtype=wp.int32
        OUTPUT: Neighbor matrix.
    neighbor_matrix_shifts : wp.array, shape (total_atoms, max_neighbors), dtype=wp.vec3i
        OUTPUT: Shift vectors for each neighbor.
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors per atom.
    wp_dtype : type
        Warp dtype (wp.float32, wp.float64, or wp.float16).
    device : str
        Warp device string (e.g., 'cuda:0', 'cpu').
    max_atoms_per_system : int
        Maximum number of atoms in any single system (launch dimension).
    half_fill : bool, default=False
        If True, only store half of the neighbor relationships.
    rebuild_flags : wp.array, shape (num_systems,), dtype=wp.bool, optional
        Per-system rebuild flags. If provided, only flagged systems are
        processed, and ``selective_zero_num_neighbors`` is called first to
        reset their counts.
    wrap_positions : bool, default=True
        If True, wrap input positions into the primary cell first. Stored
        shifts are then corrected so they refer to the original positions.
        If False, ``positions`` must already be wrapped, and shifts are stored
        uncorrected.

    Raises
    ------
    ValueError
        If ``wrap_positions`` is True and ``wp_dtype`` is not one of
        wp.float32, wp.float64 or wp.float16.
    KeyError
        If ``wrap_positions`` is False and ``wp_dtype`` has no compiled
        kernel overload.

    Notes
    -----
    - This is a low-level warp interface. For framework bindings, use torch/jax wrappers.
    - Output arrays must be pre-allocated by caller.
    - When ``wrap_positions`` is True, positions are wrapped into the primary
      cell in a preprocessing step before the neighbor search kernel.
    - Launch dimension: (num_systems, max_shifts_per_system, max_atoms_per_system).

    See Also
    --------
    batch_naive_neighbor_matrix : Version without periodic boundary conditions
    _fill_batch_naive_neighbor_matrix_pbc : Kernel that performs the computation
    _fill_batch_naive_neighbor_matrix_pbc_selective : Selective-skip kernel variant
    wrap_positions_batch : Preprocessing step that wraps positions
    """
    total_atoms = positions.shape[0]
    num_systems = cell.shape[0]

    if wrap_positions:
        # vec3 / mat33 types matching the requested scalar precision. An
        # explicit table replaces a chained conditional that silently yielded
        # None (and an opaque downstream failure) for unsupported dtypes.
        precision_types = {
            wp.float32: (wp.vec3f, wp.mat33f),
            wp.float64: (wp.vec3d, wp.mat33d),
            wp.float16: (wp.vec3h, wp.mat33h),
        }
        if wp_dtype not in precision_types:
            raise ValueError(
                f"Unsupported wp_dtype {wp_dtype!r}; expected wp.float32, "
                "wp.float64, or wp.float16."
            )
        wp_vec_dtype, wp_mat_dtype = precision_types[wp_dtype]

        inv_cell = wp.empty((num_systems,), dtype=wp_mat_dtype, device=device)
        compute_inv_cells(cell, inv_cell, wp_dtype, device)
        positions_wrapped = wp.empty((total_atoms,), dtype=wp_vec_dtype, device=device)
        per_atom_cell_offsets = wp.empty(total_atoms, dtype=wp.vec3i, device=device)
        wrap_positions_batch(
            positions,
            cell,
            inv_cell,
            batch_idx,
            positions_wrapped,
            per_atom_cell_offsets,
            wp_dtype,
            device,
        )
        # Wrapped kernels also take the offsets to correct the stored shifts.
        leading_inputs = [positions_wrapped, per_atom_cell_offsets, cell]
        plain_kernels = _fill_batch_naive_neighbor_matrix_pbc_overload
        selective_kernels = _fill_batch_naive_neighbor_matrix_pbc_selective_overload
    else:
        leading_inputs = [positions, cell]
        plain_kernels = _fill_batch_naive_neighbor_matrix_pbc_prewrapped_overload
        selective_kernels = (
            _fill_batch_naive_neighbor_matrix_pbc_prewrapped_selective_overload
        )

    selective = rebuild_flags is not None
    if selective:
        # Reset counts before the selective kernel refills flagged systems.
        selective_zero_num_neighbors(num_neighbors, batch_idx, rebuild_flags, device)
        kernel = selective_kernels[wp_dtype]
    else:
        kernel = plain_kernels[wp_dtype]

    inputs = [
        *leading_inputs,
        # Kernels compare squared distances.
        wp_dtype(cutoff * cutoff),
        batch_ptr,
        shift_range,
        num_shifts_arr,
        neighbor_matrix,
        neighbor_matrix_shifts,
        num_neighbors,
        half_fill,
    ]
    if selective:
        # Selective kernels take rebuild_flags as their trailing argument.
        inputs.append(rebuild_flags)

    wp.launch(
        kernel=kernel,
        dim=(num_systems, max_shifts_per_system, max_atoms_per_system),
        inputs=inputs,
        device=device,
    )