# Source code for nvalchemiops.neighbors.cell_list

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Core warp kernels and launchers for cell list neighbor construction.

This module contains warp kernels for O(N) cell-based neighbor list computation.
See `nvalchemiops.torch.neighbors` for PyTorch bindings.
"""

from typing import Any

import warp as wp

from nvalchemiops.math import wpdivmod
from nvalchemiops.neighbors.neighbor_utils import (
    _update_neighbor_matrix_pbc,
    zero_array,
)

# Public API: only the host-side launchers are exported. The warp kernels in
# this module are underscore-prefixed and intentionally left out.
__all__ = [
    "build_cell_list",
    "query_cell_list",
]

###########################################################################################
########################### Cell List Construction ########################################
###########################################################################################


@wp.kernel(enable_backward=False)
def _estimate_cell_list_sizes(
    cell: wp.array(dtype=Any),
    pbc: wp.array(dtype=Any),
    cell_size: Any,
    max_nbins: Any,
    number_of_cells: wp.array(dtype=Any),
    neighbor_search_radius: wp.array(dtype=Any),
) -> None:
    """Estimate the cell count and per-axis search radius for a cell list.

    Computes how many cells of edge ``cell_size`` fit along each lattice
    direction, derives how many neighboring cells must be searched along each
    direction, and caps the total number of cells at ``max_nbins`` by
    repeatedly halving the per-direction counts.

    Parameters
    ----------
    cell : wp.array(dtype=Any), shape (1, 3, 3)
        Unit cell matrix defining the simulation box. Only ``cell[0]`` is read.
    pbc : wp.array(dtype=Any), shape (3,), dtype=bool
        Flags indicating periodic boundary conditions in x, y, z directions.
        True enables PBC, False disables it for that dimension.
    cell_size : Any
        Target edge length of a cell (floating point scalar).
    max_nbins : Any
        Maximum number of cells to allocate. Values below 1 are treated as 1
        so that the reduction loop always terminates.
    number_of_cells : wp.array(dtype=Any), shape (1,)
        OUTPUT: Total number of cells after the cap has been applied.
    neighbor_search_radius : wp.array(dtype=Any), shape (3,)
        OUTPUT: Radius (in cells) of neighboring cells to search in each
        dimension.

    Notes
    -----
    - Thread launch: the kernel never calls ``wp.tid()``; every launched
      thread performs the same writes, so a single thread (dim=1) suffices.
    - Modifies: number_of_cells, neighbor_search_radius
    - neighbor_search_radius is computed from the per-direction cell counts
      *before* the ``max_nbins`` cap is applied.
    - For a non-periodic direction that holds a single cell the search radius
      is set to 0.
    - The total cell count is accumulated in 64-bit integers so that very
      large boxes cannot wrap around in int32 and bypass the cap.
    """
    # Rows of the inverse-transpose are the reciprocal lattice vectors.
    inverse_cell_transpose = wp.transpose(wp.inverse(cell[0]))

    cells_per_dimension = wp.vec3i(0, 0, 0)
    # Calculate optimal number of cells in each dimension
    for i in range(3):
        # Perpendicular distance between opposite faces of the unit cell
        face_distance = type(cell_size)(1.0) / wp.length(inverse_cell_transpose[i])
        cells_per_dimension[i] = max(wp.int32(face_distance / cell_size), 1)

        if cells_per_dimension[i] == 1 and not pbc[i]:
            # A single non-periodic cell has no neighbors along this axis.
            neighbor_search_radius[i] = 0
        else:
            neighbor_search_radius[i] = wp.int32(
                wp.ceil(
                    cell_size * type(cell_size)(cells_per_dimension[i]) / face_distance
                )
            )

    # Never allow a limit below one cell: the halving loop cannot go lower
    # than 1 x 1 x 1 and would otherwise spin forever.
    cell_limit = max(wp.int64(max_nbins), wp.int64(1))

    # Use 64-bit arithmetic: the product of three int32 counts can overflow.
    total_cells = (
        wp.int64(cells_per_dimension[0])
        * wp.int64(cells_per_dimension[1])
        * wp.int64(cells_per_dimension[2])
    )

    # Reduce cell count if necessary by halving dimensions iteratively
    while total_cells > cell_limit:
        for i in range(3):
            cells_per_dimension[i] = max(cells_per_dimension[i] // 2, 1)
        total_cells = (
            wp.int64(cells_per_dimension[0])
            * wp.int64(cells_per_dimension[1])
            * wp.int64(cells_per_dimension[2])
        )

    # After the loop total_cells <= cell_limit, which fits in int32.
    number_of_cells[0] = wp.int32(total_cells)


@wp.kernel(enable_backward=False)
def _cell_list_construct_bin_size(
    cell: wp.array(dtype=Any),
    pbc: wp.array(dtype=Any),
    cells_per_dimension: wp.array(dtype=Any),
    target_cell_size: Any,
    max_cells_allowed: Any,
) -> None:
    """Determine the number of cells per direction for cell list construction.

    This kernel calculates the number of cells needed in each spatial
    dimension based on the simulation cell geometry and target cell size.
    Assumes a single system (not batched).

    The algorithm:
    1. Computes the cell count per dimension from the distance between
       opposite faces of the unit cell.
    2. Halves all three counts (never below 1) until the total no longer
       exceeds ``max_cells_allowed``.

    Parameters
    ----------
    cell : wp.array, shape (1, 3, 3), dtype=wp.mat33*
        Unit cell matrix defining simulation box geometry. Only ``cell[0]``
        is read.
    pbc : wp.array, shape (3,), dtype=bool
        Periodic boundary condition flags for x, y, z directions. Not read by
        this kernel.
    cells_per_dimension : wp.array, shape (3,), dtype=wp.int32
        OUTPUT: Number of cells to create in x, y, z directions.
    target_cell_size : float
        Desired cell size, typically the neighbor cutoff distance.
    max_cells_allowed : int
        Maximum total number of cells allowed (nx * ny * nz ≤ max_cells_allowed).
        Values below 1 are treated as 1 so that the reduction loop always
        terminates.

    Notes
    -----
    - Modifies: cells_per_dimension
    - Thread launch: Single thread (dim=1)
    - This kernel does not compute a neighbor search radius.
    - The total cell count is accumulated in 64-bit integers so that very
      large boxes cannot wrap around in int32 and bypass the cap.
    """

    # Rows of the inverse-transpose are the reciprocal lattice vectors.
    inverse_cell_transpose = wp.transpose(wp.inverse(cell[0]))

    # Calculate optimal number of cells in each dimension
    for i in range(3):
        # Perpendicular distance between opposite faces of the unit cell
        face_distance = type(target_cell_size)(1.0) / wp.length(
            inverse_cell_transpose[i]
        )
        cells_per_dimension[i] = max(wp.int32(face_distance / target_cell_size), 1)

    # Never allow a limit below one cell: the halving loop cannot go lower
    # than 1 x 1 x 1 and would otherwise spin forever.
    cell_limit = max(wp.int64(max_cells_allowed), wp.int64(1))

    # Use 64-bit arithmetic: the product of three int32 counts can overflow.
    total_cells = (
        wp.int64(cells_per_dimension[0])
        * wp.int64(cells_per_dimension[1])
        * wp.int64(cells_per_dimension[2])
    )

    # Reduce cell count if necessary by halving dimensions iteratively
    while total_cells > cell_limit:
        for i in range(3):
            cells_per_dimension[i] = max(cells_per_dimension[i] // 2, 1)
        total_cells = (
            wp.int64(cells_per_dimension[0])
            * wp.int64(cells_per_dimension[1])
            * wp.int64(cells_per_dimension[2])
        )


@wp.kernel(enable_backward=False)
def _cell_list_count_atoms_per_bin(
    positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    pbc: wp.array(dtype=Any),
    cells_per_dimension: wp.array(dtype=Any),
    atoms_per_cell_count: wp.array(dtype=Any),
    atom_periodic_shifts: wp.array(dtype=Any),
) -> None:
    """First pass of cell list construction: histogram atoms over cells.

    One thread handles one atom. The atom's position is converted to
    fractional coordinates, mapped to an integer cell index along each axis,
    wrapped (periodic axes) or clamped (non-periodic axes) into the valid
    range, and the counter of the resulting cell is incremented atomically.
    For periodic axes the number of whole-box wraps is stored as well.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms, 3), dtype=wp.vec3*
        Atomic coordinates in Cartesian space.
    cell : wp.array, shape (1, 3, 3), dtype=wp.mat33*
        Unit cell matrix; only ``cell[0]`` is read.
    pbc : wp.array, shape (3,), dtype=bool
        Periodic boundary condition flags for x, y, z directions.
    cells_per_dimension : wp.array, shape (3,), dtype=wp.int32
        Number of spatial cells in x, y, z directions.
    atoms_per_cell_count : wp.array, shape (total_cells,), dtype=wp.int32
        OUTPUT: Per-cell atom counter, incremented atomically.
    atom_periodic_shifts : wp.array, shape (total_atoms, 3), dtype=wp.vec3i
        OUTPUT: Per-atom wrap count along each axis (0 on non-periodic axes).

    Notes
    -----
    - Thread launch: One thread per atom (dim=total_atoms)
    - Modifies: atoms_per_cell_count, atom_periodic_shifts
    - Counting relies on ``wp.atomic_add`` because several atoms may land in
      the same cell.
    """
    tid = wp.tid()

    # Cartesian -> fractional coordinates
    frac = wp.transpose(wp.inverse(cell[0])) * positions[tid]

    bin_coords = wp.vec3i(0, 0, 0)
    for axis in range(3):
        n_bins = cells_per_dimension[axis]

        # Unwrapped integer cell index along this axis
        raw_bin = wp.int32(wp.floor(frac[axis] * type(frac[axis])(n_bins)))

        if pbc[axis]:
            # Periodic axis: fold into [0, n_bins) and remember the wrap count
            wraps, wrapped_bin = wpdivmod(raw_bin, n_bins)
            atom_periodic_shifts[tid][axis] = wraps
            bin_coords[axis] = wrapped_bin
        else:
            # Non-periodic axis: no wrap, out-of-box atoms go to the edge cell
            atom_periodic_shifts[tid][axis] = 0
            bin_coords[axis] = wp.clamp(raw_bin, 0, n_bins - 1)

    # Flatten (x, y, z) with x varying fastest
    nx = cells_per_dimension[0]
    ny = cells_per_dimension[1]
    flat_bin = (bin_coords[2] * ny + bin_coords[1]) * nx + bin_coords[0]

    wp.atomic_add(atoms_per_cell_count, flat_bin, 1)


@wp.kernel(enable_backward=False)
def _cell_list_compute_cell_offsets(
    atoms_per_cell_count: wp.array(dtype=wp.int32),
    cell_atom_start_indices: wp.array(dtype=wp.int32),
    total_cells: int,
) -> None:
    """Exclusive prefix sum of the per-cell atom counts.

    After this kernel, cell ``i`` owns the slice of the flattened atom list
    that begins at ``cell_atom_start_indices[i]`` and spans
    ``atoms_per_cell_count[i]`` entries.

    Parameters
    ----------
    atoms_per_cell_count : wp.array, shape (total_cells,), dtype=wp.int32
        Number of atoms assigned to each cell.
    cell_atom_start_indices : wp.array, shape (total_cells,), dtype=wp.int32
        OUTPUT: Offset of each cell's first atom in the flattened atom list.
    total_cells : int
        Number of cells; threads with an index at or beyond it do nothing.

    Notes
    -----
    - Thread launch: One thread per cell (dim=total_cells)
    - Modifies: cell_atom_start_indices
    - Each thread sums all counts before its own cell, so total work is
      O(n²); acceptable for small arrays, a parallel scan is preferable for
      large ones.
    """
    tid = wp.tid()
    if tid >= total_cells:
        return

    # Sum the counts of every cell that precedes this one.
    offset = wp.int32(0)
    for preceding in range(tid):
        offset = offset + atoms_per_cell_count[preceding]

    cell_atom_start_indices[tid] = offset


@wp.kernel(enable_backward=False)
def _cell_list_bin_atoms(
    positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    pbc: wp.array(dtype=Any),
    cells_per_dimension: wp.array(dtype=Any),
    atom_to_cell_mapping: wp.array(dtype=Any),
    atoms_per_cell_count: wp.array(dtype=Any),
    cell_atom_start_indices: wp.array(dtype=Any),
    cell_atom_list: wp.array(dtype=Any),
) -> None:
    """Second pass of cell list construction: scatter atoms into their cells.

    One thread handles one atom. The atom's cell is recomputed exactly as in
    the counting pass, stored in ``atom_to_cell_mapping``, and the atom index
    is written into the slice of ``cell_atom_list`` that belongs to that cell.
    The slot inside the slice is claimed with an atomic increment of the
    per-cell counter.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms, 3), dtype=wp.vec3*
        Atomic coordinates in Cartesian space.
    cell : wp.array, shape (1, 3, 3), dtype=wp.mat33*
        Unit cell matrix; only ``cell[0]`` is read.
    pbc : wp.array, shape (3,), dtype=bool
        Periodic boundary condition flags for x, y, z directions.
    cells_per_dimension : wp.array, shape (3,), dtype=wp.int32
        Number of spatial cells in x, y, z directions.
    atom_to_cell_mapping : wp.array, shape (total_atoms, 3), dtype=wp.vec3i
        OUTPUT: 3D cell coordinates for each atom.
    atoms_per_cell_count : wp.array, shape (total_cells,), dtype=wp.int32
        MODIFIED: Running per-cell fill counter; must be zero on entry.
    cell_atom_start_indices : wp.array, shape (total_cells,), dtype=wp.int32
        Offset of each cell's first atom in ``cell_atom_list``.
    cell_atom_list : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Flattened list of atom indices organized by cell.

    Notes
    -----
    - Thread launch: One thread per atom (dim=total_atoms)
    - Modifies: atom_to_cell_mapping, atoms_per_cell_count, cell_atom_list
    - The order of atoms inside a cell depends on the order in which threads
      win the atomic increment.
    """
    tid = wp.tid()

    # Ignore surplus threads
    if tid >= positions.shape[0]:
        return

    # Cartesian -> fractional coordinates
    frac = wp.transpose(wp.inverse(cell[0])) * positions[tid]

    bin_coords = wp.vec3i(0, 0, 0)
    for axis in range(3):
        n_bins = cells_per_dimension[axis]

        # Unwrapped integer cell index along this axis
        raw_bin = wp.int32(wp.floor(frac[axis] * type(frac[axis])(n_bins)))

        if pbc[axis]:
            # Periodic axis: fold into [0, n_bins); the wrap count is not needed here
            _, wrapped_bin = wpdivmod(raw_bin, n_bins)
            bin_coords[axis] = wrapped_bin
        else:
            # Non-periodic axis: out-of-box atoms go to the edge cell
            bin_coords[axis] = wp.clamp(raw_bin, 0, n_bins - 1)

    atom_to_cell_mapping[tid] = bin_coords

    # Flatten (x, y, z) with x varying fastest
    nx = cells_per_dimension[0]
    ny = cells_per_dimension[1]
    flat_bin = (bin_coords[2] * ny + bin_coords[1]) * nx + bin_coords[0]

    # atomic_add returns the previous counter value, i.e. this atom's slot
    slot = wp.atomic_add(atoms_per_cell_count, flat_bin, 1)

    cell_atom_list[cell_atom_start_indices[flat_bin] + slot] = tid


@wp.func
def _cell_list_query_body(
    atom_idx: int,
    positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    pbc: wp.array(dtype=wp.bool),
    cutoff: Any,
    cells_per_dimension: wp.array(dtype=wp.int32),
    neighbor_search_radius: wp.array(dtype=wp.int32),
    atom_periodic_shifts: wp.array(dtype=wp.vec3i),
    atom_to_cell_mapping: wp.array(dtype=wp.vec3i),
    atoms_per_cell_count: wp.array(dtype=wp.int32),
    cell_atom_start_indices: wp.array(dtype=wp.int32),
    cell_atom_list: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    neighbor_matrix_shifts: wp.array(dtype=Any, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    half_fill: wp.bool,
):
    """Find the neighbors of one atom by scanning nearby cells of a cell list.

    Shared per-atom body of ``_cell_list_build_neighbor_matrix`` and
    ``_cell_list_build_neighbor_matrix_selective``. For the atom ``atom_idx``
    it walks over a half shell of cell offsets around the atom's own cell,
    and for every atom stored in a visited cell it computes the periodic
    shift, tests the squared distance against the squared cutoff, and hands
    accepted pairs to ``_update_neighbor_matrix_pbc``.

    Parameters
    ----------
    atom_idx : int
        Index of the central atom handled by the calling thread.
    positions : wp.array, shape (total_atoms, 3), dtype=wp.vec3*
        Atomic coordinates in Cartesian space.
    cell : wp.array, shape (1, 3, 3), dtype=wp.mat33*
        Unit cell matrix; only ``cell[0]`` is read.
    pbc : wp.array, shape (3,), dtype=wp.bool
        Periodic boundary condition flags.
    cutoff : float
        Maximum distance for considering atoms as neighbors (strict ``<``).
    cells_per_dimension : wp.array, shape (3,), dtype=wp.int32
        Number of spatial cells in x, y, z directions.
    neighbor_search_radius : wp.array, shape (3,), dtype=wp.int32
        Radius of neighboring cells to search in each dimension.
    atom_periodic_shifts : wp.array, shape (total_atoms,), dtype=wp.vec3i
        Periodic boundary crossings recorded for each atom.
    atom_to_cell_mapping : wp.array, shape (total_atoms,), dtype=wp.vec3i
        3D cell coordinates for each atom.
    atoms_per_cell_count : wp.array, shape (total_cells,), dtype=wp.int32
        Number of atoms in each cell.
    cell_atom_start_indices : wp.array, shape (total_cells,), dtype=wp.int32
        Starting index in ``cell_atom_list`` for each cell.
    cell_atom_list : wp.array, shape (total_atoms,), dtype=wp.int32
        Flattened list of atom indices organized by cell.
    neighbor_matrix : wp.array, shape (total_atoms, max_neighbors), dtype=wp.int32
        OUTPUT: Passed through to ``_update_neighbor_matrix_pbc``.
    neighbor_matrix_shifts : wp.array, 2D, dtype=wp.vec3i
        OUTPUT: Passed through to ``_update_neighbor_matrix_pbc``.
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Passed through to ``_update_neighbor_matrix_pbc``.
    half_fill : wp.bool
        Passed through to ``_update_neighbor_matrix_pbc``.

    Notes
    -----
    - Only cell offsets that are lexicographically >= (0, 0, 0) are visited,
      and inside the atom's own cell only partners with a larger index are
      considered, so each candidate pair/image is examined from one side only.
    - How a pair is recorded (one or both directions, overflow handling) is
      decided by ``_update_neighbor_matrix_pbc``, which is defined elsewhere.
      NOTE(review): presumably it writes both (i, j) and (j, i) unless
      ``half_fill`` is set - confirm against neighbor_utils.
    """
    # Compare squared distances so no square root is needed.
    cutoff_distance_sq = cutoff * cutoff
    central_atom_position = positions[atom_idx]
    central_atom_cell = atom_to_cell_mapping[atom_idx]
    central_atom_shift = atom_periodic_shifts[atom_idx]
    # Row capacity of the neighbor matrix, forwarded to the update helper.
    max_neighbors = neighbor_matrix.shape[1]

    # Used below to turn an integer (fractional) shift into a Cartesian offset.
    cell_mat = cell[0]
    cell_transpose = wp.transpose(cell_mat)

    # Read the small per-system arrays once, outside the cell loops.
    cpd_x = cells_per_dimension[0]
    cpd_y = cells_per_dimension[1]
    cpd_z = cells_per_dimension[2]

    pbc_x = pbc[0]
    pbc_y = pbc[1]
    pbc_z = pbc[2]

    # dx starts at 0 (not -radius): only the forward half of the shell is walked.
    for dx in range(0, neighbor_search_radius[0] + 1):
        for dy in range(-neighbor_search_radius[1], neighbor_search_radius[1] + 1):
            for dz in range(-neighbor_search_radius[2], neighbor_search_radius[2] + 1):
                # Keep only offsets that are lexicographically >= (0, 0, 0).
                if not (
                    dx > 0 or (dx == 0 and dy > 0) or (dx == 0 and dy == 0 and dz >= 0)
                ):
                    continue
                target_x = central_atom_cell[0] + dx
                target_y = central_atom_cell[1] + dy
                target_z = central_atom_cell[2] + dz

                # On non-periodic axes a cell outside the grid does not exist.
                if not pbc_x and (target_x < 0 or target_x >= cpd_x):
                    continue
                if not pbc_y and (target_y < 0 or target_y >= cpd_y):
                    continue
                if not pbc_z and (target_z < 0 or target_z >= cpd_z):
                    continue

                # cs_* = number of box wraps, wc_* = wrapped cell coordinate.
                # NOTE(review): assumes wpdivmod is a floor-style divmod so that
                # wc_* is non-negative for negative targets - confirm in nvalchemiops.math.
                cs_x, wc_x = wpdivmod(target_x, cpd_x)
                cs_y, wc_y = wpdivmod(target_y, cpd_y)
                cs_z, wc_z = wpdivmod(target_z, cpd_z)

                # Same flattening as used when the cell list was built (x fastest).
                linear_cell_index = wc_x + cpd_x * (wc_y + cpd_y * wc_z)

                cell_start_index = cell_atom_start_indices[linear_cell_index]
                num_atoms_in_cell = atoms_per_cell_count[linear_cell_index]

                for cell_atom_idx in range(num_atoms_in_cell):
                    neighbor_atom_idx = cell_atom_list[cell_start_index + cell_atom_idx]

                    neighbor_atom_shift = atom_periodic_shifts[neighbor_atom_idx]

                    # Start from the wrap of the target cell ...
                    shift_x = cs_x
                    shift_y = cs_y
                    shift_z = cs_z

                    # ... and, on periodic axes, correct for the wraps that were
                    # applied to each atom when it was binned. Non-periodic axes
                    # never carry a shift.
                    if pbc_x:
                        shift_x += central_atom_shift[0] - neighbor_atom_shift[0]
                    else:
                        shift_x = 0

                    if pbc_y:
                        shift_y += central_atom_shift[1] - neighbor_atom_shift[1]
                    else:
                        shift_y = 0

                    if pbc_z:
                        shift_z += central_atom_shift[2] - neighbor_atom_shift[2]
                    else:
                        shift_z = 0

                    # Inside the atom's own cell, skip itself and lower-indexed
                    # partners so the pair is only examined once.
                    if dx == 0 and dy == 0 and dz == 0:
                        if neighbor_atom_idx <= atom_idx:
                            continue

                    # Build the shift as a vector of the positions' float type.
                    fractional_shift = type(central_atom_position)(
                        type(central_atom_position[0])(shift_x),
                        type(central_atom_position[0])(shift_y),
                        type(central_atom_position[0])(shift_z),
                    )
                    cartesian_shift = cell_transpose * fractional_shift

                    neighbor_pos = positions[neighbor_atom_idx]
                    dr = neighbor_pos - central_atom_position + cartesian_shift
                    distance_sq = wp.dot(dr, dr)

                    # Strict inequality: pairs exactly at the cutoff are excluded.
                    if distance_sq < cutoff_distance_sq:
                        _update_neighbor_matrix_pbc(
                            atom_idx,
                            neighbor_atom_idx,
                            neighbor_matrix,
                            neighbor_matrix_shifts,
                            num_neighbors,
                            wp.vec3i(shift_x, shift_y, shift_z),
                            max_neighbors,
                            half_fill,
                        )


@wp.kernel(enable_backward=False)
def _cell_list_build_neighbor_matrix(
    positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    pbc: wp.array(dtype=bool),
    cutoff: Any,
    cells_per_dimension: wp.array(dtype=wp.int32),
    neighbor_search_radius: wp.array(dtype=wp.int32),
    atom_periodic_shifts: wp.array(dtype=wp.vec3i),
    atom_to_cell_mapping: wp.array(dtype=wp.vec3i),
    atoms_per_cell_count: wp.array(dtype=wp.int32),
    cell_atom_start_indices: wp.array(dtype=wp.int32),
    cell_atom_list: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    neighbor_matrix_shifts: wp.array(dtype=Any, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    half_fill: bool,
) -> None:
    """Build neighbor matrix with atom pairs and periodic shifts.

    For each atom, searches through neighboring cells and records all neighbor
    atoms within the cutoff distance into a fixed-size matrix format. Stores
    neighbor indices and their periodic shift vectors. The kernel itself only
    obtains the thread index; all work is done by ``_cell_list_query_body``.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms, 3), dtype=wp.vec3*
        Atomic coordinates in Cartesian space.
    cell : wp.array, shape (1, 3, 3), dtype=wp.mat33*
        Unit cell matrix for periodic boundary coordinate shifts.
    pbc : wp.array, shape (3,), dtype=bool
        Periodic boundary condition flags.
    cutoff : float
        Maximum distance for considering atoms as neighbors.
    cells_per_dimension : wp.array, shape (3,), dtype=wp.int32
        Number of spatial cells in x, y, z directions.
    neighbor_search_radius : wp.array, shape (3,), dtype=wp.int32
        Radius of neighboring cells to search in each dimension.
    atom_periodic_shifts : wp.array, shape (total_atoms, 3), dtype=wp.vec3i
        Periodic boundary crossings for each atom.
    atom_to_cell_mapping : wp.array, shape (total_atoms, 3), dtype=wp.vec3i
        3D cell coordinates for each atom.
    atoms_per_cell_count : wp.array, shape (total_cells,), dtype=wp.int32
        Number of atoms in each cell.
    cell_atom_start_indices : wp.array, shape (total_cells,), dtype=wp.int32
        Starting index in cell_atom_list for each cell.
    cell_atom_list : wp.array, shape (total_atoms,), dtype=wp.int32
        Flattened list of atom indices organized by cell.
    neighbor_matrix : wp.array, shape (total_atoms, max_neighbors), dtype=wp.int32
        OUTPUT: Neighbor matrix to be filled with neighbor atom indices.
    neighbor_matrix_shifts : wp.array, shape (total_atoms, max_neighbors), dtype=wp.vec3i
        OUTPUT: Shift vectors for each neighbor relationship (one vec3i per
        matrix entry).
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors found for each atom.
    half_fill : bool
        If True, only store half of the neighbor relationships. Forwarded
        unchanged to ``_cell_list_query_body``.

    Notes
    -----
    - Thread launch: One thread per atom (dim=total_atoms)
    - Each thread loops over all neighbor cell shifts internally
    - Modifies: neighbor_matrix, neighbor_matrix_shifts, num_neighbors
    - Handling of rows that exceed max_neighbors happens inside
      ``_update_neighbor_matrix_pbc`` (defined elsewhere).
      NOTE(review): extra neighbors are presumably ignored - confirm there.

    Implementation notes (see ``_cell_list_query_body``):
    - Compares squared distances against the squared cutoff, so no sqrt is taken
    - Reads cells_per_dimension and pbc into scalar locals once, before the
      cell loops
    - Computes the transposed cell matrix once per atom, before the cell loops
    - Keeps the per-pair periodic shift in three scalar integers until a
      vector is needed
    """
    atom_idx = wp.tid()
    # Delegate to the shared per-atom query routine.
    _cell_list_query_body(
        atom_idx,
        positions,
        cell,
        pbc,
        cutoff,
        cells_per_dimension,
        neighbor_search_radius,
        atom_periodic_shifts,
        atom_to_cell_mapping,
        atoms_per_cell_count,
        cell_atom_start_indices,
        cell_atom_list,
        neighbor_matrix,
        neighbor_matrix_shifts,
        num_neighbors,
        half_fill,
    )


@wp.kernel(enable_backward=False)
def _cell_list_build_neighbor_matrix_selective(
    positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    pbc: wp.array(dtype=wp.bool),
    cutoff: Any,
    cells_per_dimension: wp.array(dtype=wp.int32),
    neighbor_search_radius: wp.array(dtype=wp.int32),
    atom_periodic_shifts: wp.array(dtype=wp.vec3i),
    atom_to_cell_mapping: wp.array(dtype=wp.vec3i),
    atoms_per_cell_count: wp.array(dtype=wp.int32),
    cell_atom_start_indices: wp.array(dtype=wp.int32),
    cell_atom_list: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    neighbor_matrix_shifts: wp.array(dtype=Any, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    half_fill: bool,
    rebuild_flags: wp.array(dtype=wp.bool),
) -> None:
    """Selective single-system cell list neighbor matrix kernel.

    Identical to ``_cell_list_build_neighbor_matrix`` but checks
    ``rebuild_flags[0]`` on the GPU and returns immediately when False.
    The decision is taken inside the kernel, so the host does not have to
    read the flag back before launching.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms, 3), dtype=wp.vec3*
        Cartesian coordinates.
    cell : wp.array, shape (1, 3, 3), dtype=wp.mat33*
        Unit cell matrix for periodic boundary coordinate shifts.
    pbc : wp.array, shape (3,), dtype=bool
        Periodic boundary condition flags.
    cutoff : float
        Maximum distance for considering atoms as neighbors.
    cells_per_dimension : wp.array, shape (3,), dtype=wp.int32
        Number of spatial cells in x, y, z directions.
    neighbor_search_radius : wp.array, shape (3,), dtype=wp.int32
        Radius of neighboring cells to search in each dimension.
    atom_periodic_shifts : wp.array, shape (total_atoms, 3), dtype=wp.vec3i
        Periodic boundary crossings for each atom.
    atom_to_cell_mapping : wp.array, shape (total_atoms, 3), dtype=wp.vec3i
        3D cell coordinates for each atom.
    atoms_per_cell_count : wp.array, shape (total_cells,), dtype=wp.int32
        Number of atoms in each cell.
    cell_atom_start_indices : wp.array, shape (total_cells,), dtype=wp.int32
        Starting index in cell_atom_list for each cell.
    cell_atom_list : wp.array, shape (total_atoms,), dtype=wp.int32
        Flattened list of atom indices organized by cell.
    neighbor_matrix : wp.array, shape (total_atoms, max_neighbors), dtype=wp.int32
        OUTPUT: Neighbor matrix (untouched when not rebuilding).
    neighbor_matrix_shifts : wp.array, shape (total_atoms, max_neighbors), dtype=wp.vec3i
        OUTPUT: Shift vectors, one vec3i per matrix entry (untouched when not
        rebuilding).
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Neighbor counts (pre-zeroed by caller when rebuilding).
    half_fill : bool
        If True, only store half of the neighbor relationships.
    rebuild_flags : wp.array, shape (1,), dtype=wp.bool
        GPU-resident flag; False → kernel returns immediately.

    Notes
    -----
    - Thread launch: One thread per atom (dim=total_atoms)
    - Only ``rebuild_flags[0]`` is read; every thread reads the same element.
    - Modifies (only when rebuilding): neighbor_matrix,
      neighbor_matrix_shifts, num_neighbors
    """
    atom_idx = wp.tid()
    # Every thread bails out before touching any output when no rebuild is requested.
    if not rebuild_flags[0]:
        return
    # Otherwise run the same per-atom query as the non-selective kernel.
    _cell_list_query_body(
        atom_idx,
        positions,
        cell,
        pbc,
        cutoff,
        cells_per_dimension,
        neighbor_search_radius,
        atom_periodic_shifts,
        atom_to_cell_mapping,
        atoms_per_cell_count,
        cell_atom_start_indices,
        cell_atom_list,
        neighbor_matrix,
        neighbor_matrix_shifts,
        num_neighbors,
        half_fill,
    )


# Concrete overloads of the generic (dtype=Any) kernels above.
# T, V and M are parallel lists: scalar type, matching 3-vector type and
# matching 3x3 matrix type. Only float32 and float64 are registered.
T = [wp.float32, wp.float64]
V = [wp.vec3f, wp.vec3d]
M = [wp.mat33f, wp.mat33d]
# One lookup table per kernel, keyed by the scalar type; launchers index these
# with their ``wp_dtype`` argument.
_estimate_cell_list_sizes_overload = {}
_cell_list_construct_bin_size_overload = {}
_cell_list_count_atoms_per_bin_overload = {}
_cell_list_bin_atoms_overload = {}
_cell_list_build_neighbor_matrix_overload = {}
_cell_list_build_neighbor_matrix_selective_overload = {}
# Each type list below is positional and must follow the parameter order of
# the corresponding kernel signature exactly.
for t, v, m in zip(T, V, M):
    # (cell, pbc, cell_size, max_nbins, number_of_cells, neighbor_search_radius)
    _estimate_cell_list_sizes_overload[t] = wp.overload(
        _estimate_cell_list_sizes,
        [
            wp.array(dtype=m),
            wp.array(dtype=wp.bool),
            t,
            wp.int32,
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
        ],
    )
    # (cell, pbc, cells_per_dimension, target_cell_size, max_cells_allowed)
    _cell_list_construct_bin_size_overload[t] = wp.overload(
        _cell_list_construct_bin_size,
        [
            wp.array(dtype=m),
            wp.array(dtype=wp.bool),
            wp.array(dtype=wp.int32),
            t,
            wp.int32,
        ],
    )
    # (positions, cell, pbc, cells_per_dimension, atoms_per_cell_count,
    #  atom_periodic_shifts)
    _cell_list_count_atoms_per_bin_overload[t] = wp.overload(
        _cell_list_count_atoms_per_bin,
        [
            wp.array(dtype=v),
            wp.array(dtype=m),
            wp.array(dtype=wp.bool),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.vec3i),
        ],
    )
    # (positions, cell, pbc, cells_per_dimension, atom_to_cell_mapping,
    #  atoms_per_cell_count, cell_atom_start_indices, cell_atom_list)
    _cell_list_bin_atoms_overload[t] = wp.overload(
        _cell_list_bin_atoms,
        [
            wp.array(dtype=v),
            wp.array(dtype=m),
            wp.array(dtype=wp.bool),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.vec3i),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
        ],
    )
    # (positions, cell, pbc, cutoff, cells_per_dimension,
    #  neighbor_search_radius, atom_periodic_shifts, atom_to_cell_mapping,
    #  atoms_per_cell_count, cell_atom_start_indices, cell_atom_list,
    #  neighbor_matrix, neighbor_matrix_shifts, num_neighbors, half_fill)
    _cell_list_build_neighbor_matrix_overload[t] = wp.overload(
        _cell_list_build_neighbor_matrix,
        [
            wp.array(dtype=v),
            wp.array(dtype=m),
            wp.array(dtype=wp.bool),
            t,
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.vec3i),
            wp.array(dtype=wp.vec3i),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array2d(dtype=wp.int32),
            wp.array2d(dtype=wp.vec3i),
            wp.array(dtype=wp.int32),
            wp.bool,
        ],
    )
    # Same as above plus the trailing rebuild_flags array.
    _cell_list_build_neighbor_matrix_selective_overload[t] = wp.overload(
        _cell_list_build_neighbor_matrix_selective,
        [
            wp.array(dtype=v),
            wp.array(dtype=m),
            wp.array(dtype=wp.bool),
            t,
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.vec3i),
            wp.array(dtype=wp.vec3i),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array2d(dtype=wp.int32),
            wp.array2d(dtype=wp.vec3i),
            wp.array(dtype=wp.int32),
            wp.bool,
            wp.array(dtype=wp.bool),
        ],
    )


###########################################################################################
################################ Core Warp Launchers #######################################
###########################################################################################


def build_cell_list(
    positions: wp.array,
    cell: wp.array,
    pbc: wp.array,
    cutoff: float,
    cells_per_dimension: wp.array,
    atom_periodic_shifts: wp.array,
    atom_to_cell_mapping: wp.array,
    atoms_per_cell_count: wp.array,
    cell_atom_start_indices: wp.array,
    cell_atom_list: wp.array,
    wp_dtype: type,
    device: str,
) -> None:
    """Core warp launcher for building spatial cell list.

    Constructs a spatial decomposition data structure for efficient neighbor
    searching using pure warp operations. The following steps are launched in
    order:

    1. ``_cell_list_construct_bin_size`` (single thread) chooses the number of
       cells per direction, capped by the length of ``atoms_per_cell_count``.
    2. ``_cell_list_count_atoms_per_bin`` (one thread per atom) counts the
       atoms in every cell and records each atom's periodic shift.
    3. ``wp.utils.array_scan`` turns the counts into exclusive start offsets.
    4. ``zero_array`` resets the counts.
    5. ``_cell_list_bin_atoms`` (one thread per atom) writes the atom indices
       into ``cell_atom_list`` and re-accumulates the counts.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms, 3), dtype=wp.vec3*
        Atomic coordinates in Cartesian space.
    cell : wp.array, shape (1, 3, 3), dtype=wp.mat33*
        Unit cell matrix defining the simulation box.
    pbc : wp.array, shape (3,), dtype=wp.bool
        Periodic boundary condition flags for x, y, z directions.
    cutoff : float
        Maximum distance for neighbor search. Used as the target cell size.
    cells_per_dimension : wp.array, shape (3,), dtype=wp.int32
        OUTPUT: Number of cells created in x, y, z directions.
    atom_periodic_shifts : wp.array, shape (total_atoms, 3), dtype=wp.vec3i
        OUTPUT: Periodic boundary crossings for each atom.
    atom_to_cell_mapping : wp.array, shape (total_atoms, 3), dtype=wp.vec3i
        OUTPUT: 3D cell coordinates assigned to each atom.
    atoms_per_cell_count : wp.array, shape (max_total_cells,), dtype=wp.int32
        OUTPUT: Number of atoms in each cell. Must be zeroed by the caller
        before the call. Its length is the maximum number of cells that may
        be created and must be at least 1.
    cell_atom_start_indices : wp.array, shape (max_total_cells,), dtype=wp.int32
        OUTPUT: Starting index in cell_atom_list for each cell's atoms.
        Filled by this function; the caller does not need to initialize it.
    cell_atom_list : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Flattened list of atom indices organized by cell.
    wp_dtype : type
        Warp scalar dtype selecting the kernel overloads: ``wp.float32`` or
        ``wp.float64``.
    device : str
        Warp device string (e.g., 'cuda:0', 'cpu').

    Raises
    ------
    ValueError
        If ``atoms_per_cell_count`` has zero length, i.e. no cell may be
        created.
    KeyError
        If no kernel overload is registered for ``wp_dtype``.

    Notes
    -----
    - This is a low-level warp interface.
    - atoms_per_cell_count must be zeroed before calling this function.
    - This function handles the cumsum internally using wp.utils.array_scan.
    - For framework bindings, use the torch/jax wrappers instead.

    See Also
    --------
    query_cell_list : Query cell list to build neighbor matrix (call after this)
    wp.utils.array_scan : Warp utility for computing prefix sums
    _cell_list_construct_bin_size : Kernel for computing cell dimensions
    _cell_list_count_atoms_per_bin : Kernel for counting atoms per cell
    _cell_list_bin_atoms : Kernel for binning atoms into cells
    """
    total_atoms = positions.shape[0]
    max_total_cells = atoms_per_cell_count.shape[0]

    # With a capacity of zero cells the bin-size kernel can never satisfy its
    # "total cells <= capacity" loop condition; fail early on the host instead.
    if max_total_cells < 1:
        raise ValueError(
            "atoms_per_cell_count must have room for at least one cell, "
            f"got length {max_total_cells}"
        )

    wp_cutoff = wp_dtype(cutoff)

    # Construct cell dimensions
    wp.launch(
        _cell_list_construct_bin_size_overload[wp_dtype],
        dim=1,
        device=device,
        inputs=[
            cell,
            pbc,
            cells_per_dimension,
            wp_cutoff,
            max_total_cells,
        ],
    )

    # Count atoms per bin (expects atoms_per_cell_count to be zeroed by caller)
    wp.launch(
        _cell_list_count_atoms_per_bin_overload[wp_dtype],
        dim=total_atoms,
        inputs=[
            positions,
            cell,
            pbc,
            cells_per_dimension,
            atoms_per_cell_count,
            atom_periodic_shifts,
        ],
        device=device,
    )

    # Compute exclusive scan to get starting indices for each cell
    # This converts [3, 5, 2, 0, 4, ...] -> [0, 3, 8, 10, 10, ...]
    wp.utils.array_scan(atoms_per_cell_count, cell_atom_start_indices, inclusive=False)

    # Zero counts before binning atoms (second pass needs fresh counts)
    zero_array(atoms_per_cell_count, device)

    # Bin atoms (expects atoms_per_cell_count to be zeroed)
    wp.launch(
        _cell_list_bin_atoms_overload[wp_dtype],
        dim=total_atoms,
        inputs=[
            positions,
            cell,
            pbc,
            cells_per_dimension,
            atom_to_cell_mapping,
            atoms_per_cell_count,
            cell_atom_start_indices,
            cell_atom_list,
        ],
        device=device,
    )
def query_cell_list(
    positions: wp.array,
    cell: wp.array,
    pbc: wp.array,
    cutoff: float,
    cells_per_dimension: wp.array,
    neighbor_search_radius: wp.array,
    atom_periodic_shifts: wp.array,
    atom_to_cell_mapping: wp.array,
    atoms_per_cell_count: wp.array,
    cell_atom_start_indices: wp.array,
    cell_atom_list: wp.array,
    neighbor_matrix: wp.array,
    neighbor_matrix_shifts: wp.array,
    num_neighbors: wp.array,
    wp_dtype: type,
    device: str,
    half_fill: bool = False,
    rebuild_flags: wp.array | None = None,
) -> None:
    """Launch the warp kernel that fills a neighbor matrix from a cell list.

    Consumes the data structures produced by ``build_cell_list`` and launches
    one thread per atom to find atom pairs within ``cutoff``. When
    ``rebuild_flags`` is supplied, the selective kernel variant is launched
    with the flag array as an additional trailing input; otherwise the
    regular kernel is launched.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms, 3), dtype=wp.vec3*
        Atomic coordinates in Cartesian space.
    cell : wp.array, shape (1, 3, 3), dtype=wp.mat33*
        Unit cell matrix for periodic boundary coordinate shifts.
    pbc : wp.array, shape (3,), dtype=wp.bool
        Periodic boundary condition flags.
    cutoff : float
        Maximum distance for considering atoms as neighbors. Converted with
        ``wp_dtype`` before the launch.
    cells_per_dimension : wp.array, shape (3,), dtype=wp.int32
        Number of cells in x, y, z directions from build_cell_list.
    neighbor_search_radius : wp.array, shape (3,), dtype=wp.int32
        Radius of neighboring cells to search in each dimension.
    atom_periodic_shifts : wp.array, shape (total_atoms, 3), dtype=wp.vec3i
        Periodic boundary crossings for each atom from build_cell_list.
    atom_to_cell_mapping : wp.array, shape (total_atoms, 3), dtype=wp.vec3i
        3D cell coordinates for each atom from build_cell_list.
    atoms_per_cell_count : wp.array, shape (max_total_cells,), dtype=wp.int32
        Number of atoms in each cell from build_cell_list.
    cell_atom_start_indices : wp.array, shape (max_total_cells,), dtype=wp.int32
        Starting index in cell_atom_list for each cell from build_cell_list.
    cell_atom_list : wp.array, shape (total_atoms,), dtype=wp.int32
        Flattened list of atom indices organized by cell from build_cell_list.
    neighbor_matrix : wp.array, shape (total_atoms, max_neighbors), dtype=wp.int32
        OUTPUT: Neighbor matrix to be filled with neighbor atom indices.
    neighbor_matrix_shifts : wp.array, shape (total_atoms, max_neighbors, 3), dtype=wp.vec3i
        OUTPUT: Matrix storing shift vectors for each neighbor relationship.
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors found for each atom.
    wp_dtype : type
        Warp dtype (wp.float32, wp.float64, or wp.float16). Used both to
        select the kernel overload and to convert ``cutoff``.
    device : str
        Warp device string (e.g., 'cuda:0', 'cpu').
    half_fill : bool, default=False
        If True, only store half of the neighbor relationships (i < j).
    rebuild_flags : wp.array or None, shape (1,), dtype=wp.bool, default=None
        GPU-resident flag; False -> kernel returns immediately. Passing None
        launches the non-selective kernel.

    Notes
    -----
    - This is a low-level warp interface. For framework bindings, use
      torch/jax wrappers.
    - Output arrays (neighbor_matrix, neighbor_matrix_shifts, num_neighbors)
      should be pre-allocated by caller.

    See Also
    --------
    build_cell_list : Build cell list data structures (call before this)
    _cell_list_build_neighbor_matrix : Kernel that performs the neighbor search
    """
    n_atoms = positions.shape[0]
    typed_cutoff = wp_dtype(cutoff)

    # Arguments shared by both kernel variants, in kernel-signature order.
    kernel_args = [
        positions,
        cell,
        pbc,
        typed_cutoff,
        cells_per_dimension,
        neighbor_search_radius,
        atom_periodic_shifts,
        atom_to_cell_mapping,
        atoms_per_cell_count,
        cell_atom_start_indices,
        cell_atom_list,
        neighbor_matrix,
        neighbor_matrix_shifts,
        num_neighbors,
        half_fill,
    ]

    if rebuild_flags is not None:
        # Selective variant takes the rebuild flag as its last input.
        kernel = _cell_list_build_neighbor_matrix_selective_overload[wp_dtype]
        kernel_args.append(rebuild_flags)
    else:
        # Unconditional neighbor matrix build.
        kernel = _cell_list_build_neighbor_matrix_overload[wp_dtype]

    wp.launch(kernel, dim=n_atoms, inputs=kernel_args, device=device)