# Source code for nvalchemiops.neighbors.rebuild_detection

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Core warp kernels and launchers for rebuild detection.

This module provides warp kernels to determine when cell lists and neighbor lists
need to be rebuilt based on atomic positions, cell changes, and skin distance criteria.
See `nvalchemiops.torch.neighbors` for PyTorch bindings.
"""

from typing import Any

import warp as wp

from nvalchemiops.neighbors.neighbor_utils import (
    update_ref_positions,
    update_ref_positions_batch,
)

__all__ = [
    "check_cell_list_rebuild",
    "check_neighbor_list_rebuild",
    "check_batch_neighbor_list_rebuild",
    "check_batch_cell_list_rebuild",
]

###########################################################################################
########################### Cell List Rebuild Detection ###################################
###########################################################################################


@wp.kernel(enable_backward=False)
def _check_atoms_changed_cells(
    current_positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    atom_to_cell_mapping: wp.array(dtype=Any),
    cells_per_dimension: wp.array(dtype=Any),
    pbc: wp.array(dtype=Any),
    rebuild_flag: wp.array(dtype=wp.bool),
) -> None:
    """Flag a cell list rebuild when any atom's current cell differs from its stored cell.

    Each thread handles one atom: it converts the atom's Cartesian position to
    fractional coordinates using the inverse of ``cell[0]``, bins the result
    into the ``cells_per_dimension`` grid, and compares the resulting cell
    index with the entry stored in ``atom_to_cell_mapping``. On any mismatch
    the thread stores ``True`` into ``rebuild_flag[0]``.

    Parameters
    ----------
    current_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Current atomic coordinates in Cartesian space, one vector per atom.
    cell : wp.array, dtype=wp.mat33*
        Unit cell matrix. Only element 0 is read.
    atom_to_cell_mapping : wp.array, shape (total_atoms,), dtype=wp.vec3i
        Cell coordinates recorded for each atom when the cell list was built
        (an output of build_cell_list).
    cells_per_dimension : wp.array, shape (3,), dtype=wp.int32
        Number of cells along each of the three cell directions.
    pbc : wp.array, shape (3,), dtype=bool
        Periodic boundary condition flag for each direction.
    rebuild_flag : wp.array, shape (1,), dtype=bool
        OUTPUT: element 0 is set to True if any atom changed cells.

    Notes
    -----
    - Single system only: the kernel always reads ``cell[0]`` and writes
      ``rebuild_flag[0]``.
    - Thread launch: one thread per atom (dim=total_atoms).
    - The flag is written with a plain store, not an atomic operation. Every
      writing thread stores the same value (True), and the kernel never
      writes False, so the flag must already be False on entry for the result
      to be meaningful.
    - Threads return early when they observe the flag already set. This read
      is unsynchronized and serves only to skip redundant work.
    - The cell matrix is inverted by every thread.
    - Periodic directions wrap the cell index into ``[0, n)``; non-periodic
      directions clamp it to ``[0, n - 1]``.
    """
    atom_idx = wp.tid()

    # Defensive bounds check against the positions array length.
    if atom_idx >= current_positions.shape[0]:
        return

    # Skip computation if rebuild already flagged by another thread
    if rebuild_flag[0]:
        return

    # Transform current position to fractional coordinates.
    # Multiplying the transposed inverse by the position is mathematically the
    # row-vector product position * inverse(cell).
    inverse_cell_transpose = wp.transpose(wp.inverse(cell[0]))
    fractional_position = inverse_cell_transpose * current_positions[atom_idx]
    current_cell_coords = wp.vec3i(0, 0, 0)

    # Compute current cell coordinates for each dimension
    for dim in range(3):
        # Scale the fractional coordinate by the cell count (cast to the
        # position's scalar type) and take the floor to get the bin index.
        current_cell_coords[dim] = wp.int32(
            wp.floor(
                fractional_position[dim]
                * type(fractional_position[dim])(cells_per_dimension[dim])
            )
        )

        # Handle periodic boundary conditions
        if pbc[dim]:
            current_cell_coords[dim] = (
                current_cell_coords[dim] % cells_per_dimension[dim]
            )
            # The remainder can be negative for a negative index (presumably
            # because Warp's integer modulo truncates toward zero - confirm),
            # so shift it back into [0, n).
            if current_cell_coords[dim] < 0:
                current_cell_coords[dim] += cells_per_dimension[dim]
        else:
            # Clamp to valid cell range for non-periodic dimensions
            current_cell_coords[dim] = wp.clamp(
                current_cell_coords[dim], 0, cells_per_dimension[dim] - 1
            )

    # Compare with stored cell coordinates from existing cell list
    stored_cell_coords = atom_to_cell_mapping[atom_idx]

    # Check if atom has moved to a different cell
    if (
        current_cell_coords[0] != stored_cell_coords[0]
        or current_cell_coords[1] != stored_cell_coords[1]
        or current_cell_coords[2] != stored_cell_coords[2]
    ):
        # Atom crossed cell boundary - flag for rebuild
        rebuild_flag[0] = True


# Scalar, vector and matrix dtypes supported by the rebuild-detection kernels.
# The three lists are index-aligned: entry i of each has the same precision.
_T = [wp.float32, wp.float64, wp.float16]
_V = [wp.vec3f, wp.vec3d, wp.vec3h]
_M = [wp.mat33f, wp.mat33d, wp.mat33h]

# Concrete overloads of the cell list rebuild kernel, keyed by scalar dtype.
_check_atoms_changed_cells_overload = {
    scalar_type: wp.overload(
        _check_atoms_changed_cells,
        [
            wp.array(dtype=vector_type),  # current_positions
            wp.array(dtype=matrix_type),  # cell
            wp.array(dtype=wp.vec3i),  # atom_to_cell_mapping
            wp.array(dtype=wp.int32),  # cells_per_dimension
            wp.array(dtype=wp.bool),  # pbc
            wp.array(dtype=wp.bool),  # rebuild_flag
        ],
    )
    for scalar_type, vector_type, matrix_type in zip(_T, _V, _M)
}


###########################################################################################
########################### Neighbor List Rebuild Detection #############################
###########################################################################################


@wp.kernel(enable_backward=False)
def _check_atoms_moved_beyond_skin(
    reference_positions: wp.array(dtype=Any),
    current_positions: wp.array(dtype=Any),
    skin_distance_threshold: Any,
    rebuild_flag: wp.array(dtype=wp.bool),
) -> None:
    """Flag a neighbor list rebuild when any atom moved farther than the skin threshold.

    Each thread handles one atom: it computes the Euclidean length of
    ``current_positions[i] - reference_positions[i]`` and stores ``True`` into
    ``rebuild_flag[0]`` when that length is strictly greater than
    ``skin_distance_threshold``. No periodic wrapping is applied.

    Parameters
    ----------
    reference_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Atomic positions when the neighbor list was last built.
    current_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Current atomic positions to compare against reference.
    skin_distance_threshold : float*
        Maximum allowed displacement before the neighbor list is treated as
        invalid. The registered overloads use the same scalar precision as the
        positions. Commonly chosen as half the skin width, i.e. half the
        difference between the neighbor-list cutoff and the interaction cutoff.
    rebuild_flag : wp.array, shape (1,), dtype=bool
        OUTPUT: element 0 is set to True if any atom moved beyond the threshold.

    Notes
    -----
    - Single system only: the kernel always writes ``rebuild_flag[0]``.
    - Thread launch: one thread per atom (dim=total_atoms).
    - The flag is written with a plain store, not an atomic operation. Every
      writing thread stores the same value (True), and the kernel never
      writes False, so the flag must already be False on entry for the result
      to be meaningful.
    - Threads return early when they observe the flag already set. This read
      is unsynchronized and serves only to skip redundant work.
    - A displacement exactly equal to the threshold does not trigger a rebuild.
    """
    atom_idx = wp.tid()

    # Defensive bounds check against the reference array length.
    if atom_idx >= reference_positions.shape[0]:
        return

    # Skip computation if rebuild already flagged by another thread
    if rebuild_flag[0]:
        return

    # Calculate displacement vector from reference to current position
    displacement_vector = current_positions[atom_idx] - reference_positions[atom_idx]
    displacement_magnitude = wp.length(displacement_vector)

    # Check if atom has moved beyond the skin distance threshold
    if displacement_magnitude > skin_distance_threshold:
        # Neighbor list is no longer valid - flag for rebuild
        rebuild_flag[0] = True


# Concrete overloads of the Euclidean skin-distance kernel, keyed by scalar dtype.
_check_atoms_moved_beyond_skin_overload = {
    scalar_type: wp.overload(
        _check_atoms_moved_beyond_skin,
        [
            wp.array(dtype=vector_type),  # reference_positions
            wp.array(dtype=vector_type),  # current_positions
            scalar_type,  # skin_distance_threshold
            wp.array(dtype=wp.bool),  # rebuild_flag
        ],
    )
    for scalar_type, vector_type in zip(_T, _V)
}


###########################################################################################
############### PBC Neighbor List Rebuild Detection (Periodic-Aware) #####################
###########################################################################################


@wp.kernel(enable_backward=False)
def _check_atoms_moved_beyond_skin_pbc(
    reference_positions: wp.array(dtype=Any),
    current_positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    cell_inv: wp.array(dtype=Any),
    pbc: wp.array(dtype=wp.bool),
    skin_distance_threshold: Any,
    rebuild_flag: wp.array(dtype=wp.bool),
) -> None:
    """Flag a neighbor list rebuild using minimum-image displacements.

    Unlike ``_check_atoms_moved_beyond_skin``, which uses the raw Euclidean
    displacement, this kernel maps each displacement to fractional
    coordinates, removes the nearest whole-cell shift along every periodic
    direction, and maps the result back to Cartesian space before measuring
    its length. Atoms that were wrapped across a periodic boundary are
    therefore not spuriously flagged.

    Parameters
    ----------
    reference_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Atomic positions when the neighbor list was last built.
    current_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Current atomic positions to compare against reference.
    cell : wp.array, shape (1,), dtype=wp.mat33*
        Unit cell matrix (basis vectors as rows). Only element 0 is read.
    cell_inv : wp.array, shape (1,), dtype=wp.mat33*
        Precomputed inverse of the cell matrix. Only element 0 is read.
    pbc : wp.array, shape (3,), dtype=bool
        Periodic boundary condition flag for each direction.
    skin_distance_threshold : float*
        Maximum allowed displacement before the neighbor list is treated as
        invalid. The registered overloads use the same scalar precision as the
        positions.
    rebuild_flag : wp.array, shape (1,), dtype=bool
        OUTPUT: element 0 is set to True if any atom moved beyond the threshold.

    Notes
    -----
    - Thread launch: one thread per atom (dim=total_atoms).
    - The flag is written with a plain store, not an atomic operation. Every
      writing thread stores the same value (True), and the kernel never
      writes False.
    - Threads return early when they observe the flag already set. This read
      is unsynchronized and serves only to skip redundant work.
    - The inverse cell is supplied by the caller, so no matrix inversion is
      performed per thread.
    - Works in fractional coordinates, so non-orthogonal cells are handled.
      NOTE(review): rounding each fractional component independently gives
      the true minimum image only while the displacement is small relative to
      the cell - confirm this holds for strongly skewed cells.
    """
    atom_idx = wp.tid()

    # Defensive bounds check against the reference array length.
    if atom_idx >= reference_positions.shape[0]:
        return

    # Skip computation if rebuild already flagged by another thread
    if rebuild_flag[0]:
        return

    delta = current_positions[atom_idx] - reference_positions[atom_idx]

    # Convert displacement to fractional coordinates (row-vector convention)
    delta_frac = delta * cell_inv[0]

    # Apply minimum-image convention on periodic dimensions:
    # subtracting floor(x + 0.5) removes the nearest integer number of cells.
    for dim in range(3):
        if pbc[dim]:
            delta_frac[dim] -= wp.floor(delta_frac[dim] + type(delta_frac[dim])(0.5))

    # Convert back to Cartesian
    delta_cart = delta_frac * cell[0]
    displacement_magnitude = wp.length(delta_cart)

    # Strict comparison: a displacement equal to the threshold does not flag.
    if displacement_magnitude > skin_distance_threshold:
        rebuild_flag[0] = True


# Concrete overloads of the minimum-image skin-distance kernel, keyed by scalar dtype.
_check_atoms_moved_beyond_skin_pbc_overload = {
    scalar_type: wp.overload(
        _check_atoms_moved_beyond_skin_pbc,
        [
            wp.array(dtype=vector_type),  # reference_positions
            wp.array(dtype=vector_type),  # current_positions
            wp.array(dtype=matrix_type),  # cell
            wp.array(dtype=matrix_type),  # cell_inv
            wp.array(dtype=wp.bool),  # pbc
            scalar_type,  # skin_distance_threshold
            wp.array(dtype=wp.bool),  # rebuild_flag
        ],
    )
    for scalar_type, vector_type, matrix_type in zip(_T, _V, _M)
}


###########################################################################################
########################### Warp Launchers ###############################################
###########################################################################################


def check_cell_list_rebuild(
    current_positions: wp.array,
    atom_to_cell_mapping: wp.array,
    cells_per_dimension: wp.array,
    cell: wp.array,
    pbc: wp.array,
    rebuild_flag: wp.array,
    wp_dtype: type,
    device: str,
) -> None:
    """Core warp launcher for detecting if cell list needs rebuilding.

    Launches ``_check_atoms_changed_cells`` with one thread per atom to check
    whether any atom has moved to a different spatial cell since the cell
    list was built.

    Parameters
    ----------
    current_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Current atomic coordinates in Cartesian space.
    atom_to_cell_mapping : wp.array, shape (total_atoms,), dtype=wp.vec3i
        Previously computed cell coordinates for each atom.
    cells_per_dimension : wp.array, shape (3,), dtype=wp.int32
        Number of cells in x, y, z directions.
    cell : wp.array, shape (1,), dtype=wp.mat33*
        Unit cell matrix for coordinate transformations.
    pbc : wp.array, shape (3,), dtype=wp.bool
        Periodic boundary condition flags.
    rebuild_flag : wp.array, shape (1,), dtype=wp.bool
        OUTPUT: Flag set to True if rebuild is needed.
    wp_dtype : type
        Warp scalar dtype selecting the kernel overload
        (wp.float32, wp.float64, or wp.float16).
    device : str
        Warp device string (e.g., 'cuda:0', 'cpu').

    Raises
    ------
    KeyError
        If ``wp_dtype`` is not one of the registered overload dtypes.

    Notes
    -----
    - This is a low-level warp interface. For framework bindings, use torch/jax wrappers.
    - rebuild_flag must be pre-allocated and initialized to False by caller;
      the kernel only ever writes True.
    - The launch size is taken from ``current_positions.shape[0]``.

    See Also
    --------
    _check_atoms_changed_cells : Kernel that performs the check
    """
    total_atoms = current_positions.shape[0]

    # The input order must match the kernel signature, which takes ``cell``
    # before ``atom_to_cell_mapping`` (unlike this launcher's own signature).
    wp.launch(
        kernel=_check_atoms_changed_cells_overload[wp_dtype],
        dim=total_atoms,
        inputs=[
            current_positions,
            cell,
            atom_to_cell_mapping,
            cells_per_dimension,
            pbc,
            rebuild_flag,
        ],
        device=device,
    )


def check_neighbor_list_rebuild(
    reference_positions: wp.array,
    current_positions: wp.array,
    skin_distance_threshold: float,
    rebuild_flag: wp.array,
    wp_dtype: type,
    device: str,
    update_reference_positions: bool = False,
    cell: wp.array | None = None,
    cell_inv: wp.array | None = None,
    pbc: wp.array | None = None,
) -> None:
    """Core warp launcher for detecting if neighbor list needs rebuilding.

    Checks if any atoms have moved beyond the skin distance since the
    neighbor list was built. When ``cell``, ``cell_inv`` and ``pbc`` are all
    provided the check uses minimum-image convention (MIC) so that atoms
    crossing periodic boundaries are not spuriously flagged; when none are
    provided the raw Euclidean displacement is used.

    Parameters
    ----------
    reference_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Atomic positions when the neighbor list was last built.
    current_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Current atomic positions to compare against reference.
    skin_distance_threshold : float
        Maximum allowed displacement before neighbor list becomes invalid.
        Converted to ``wp_dtype`` before being passed to the kernel.
    rebuild_flag : wp.array, shape (1,), dtype=wp.bool
        OUTPUT: Flag set to True if rebuild is needed.
    wp_dtype : type
        Warp scalar dtype selecting the kernel overload
        (wp.float32, wp.float64, or wp.float16).
    device : str
        Warp device string (e.g., 'cuda:0', 'cpu').
    update_reference_positions : bool, optional
        If True, ``update_ref_positions`` is called after the detection
        launch with ``current_positions``, ``rebuild_flag`` and
        ``reference_positions``, so that the reference positions are
        overwritten with the current ones when a rebuild is detected. The
        update is issued as a separate step after detection rather than from
        inside the detection kernel. Default False.
    cell : wp.array or None, optional
        Unit cell matrix, shape (1,), dtype=wp.mat33*. Required together
        with ``cell_inv`` and ``pbc`` to enable MIC displacement.
    cell_inv : wp.array or None, optional
        Precomputed inverse of the cell matrix, same shape/dtype as ``cell``.
    pbc : wp.array or None, optional
        Periodic boundary condition flags, shape (3,), dtype=wp.bool.

    Raises
    ------
    ValueError
        If only a subset of ``cell``, ``cell_inv``, and ``pbc`` are provided.
        All three must be supplied together to enable MIC displacement.
    KeyError
        If ``wp_dtype`` is not one of the registered overload dtypes.

    Notes
    -----
    - This is a low-level warp interface. For framework bindings, use torch/jax wrappers.
    - rebuild_flag must be pre-allocated and initialized to False by caller;
      the kernels only ever write True.
    - The launch size is taken from ``reference_positions.shape[0]``.

    See Also
    --------
    _check_atoms_moved_beyond_skin : Euclidean kernel
    _check_atoms_moved_beyond_skin_pbc : PBC kernel for periodic systems
    update_ref_positions : Standalone reference-position update launcher
    """
    # The MIC path needs all three periodic inputs; reject partial sets early
    # instead of silently falling back to the Euclidean check.
    pbc_params = (cell, cell_inv, pbc)
    if any(p is not None for p in pbc_params) and not all(
        p is not None for p in pbc_params
    ):
        raise ValueError(
            "cell, cell_inv, and pbc must all be provided together to enable MIC "
            "displacement checking. Received a partial set."
        )

    total_atoms = reference_positions.shape[0]

    # After validation, ``cell`` being present implies the other two are too.
    use_pbc = cell is not None

    if use_pbc:
        wp.launch(
            kernel=_check_atoms_moved_beyond_skin_pbc_overload[wp_dtype],
            dim=total_atoms,
            inputs=[
                reference_positions,
                current_positions,
                cell,
                cell_inv,
                pbc,
                wp_dtype(skin_distance_threshold),
                rebuild_flag,
            ],
            device=device,
        )
    else:
        wp.launch(
            kernel=_check_atoms_moved_beyond_skin_overload[wp_dtype],
            dim=total_atoms,
            inputs=[
                reference_positions,
                current_positions,
                wp_dtype(skin_distance_threshold),
                rebuild_flag,
            ],
            device=device,
        )

    if update_reference_positions:
        # Issued after detection so the update sees the final flag value.
        update_ref_positions(
            current_positions, rebuild_flag, reference_positions, wp_dtype, device
        )


###########################################################################################
########################### Batch Neighbor List Rebuild Detection ########################
###########################################################################################


@wp.kernel(enable_backward=False)
def _check_batch_atoms_moved_beyond_skin(
    reference_positions: wp.array(dtype=Any),
    current_positions: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    skin_distance_threshold: Any,
    rebuild_flags: wp.array(dtype=wp.bool),
) -> None:
    """Detect per-system if atoms moved beyond skin distance requiring neighbor list rebuild.

    Checks each atom's Euclidean displacement from its reference position
    against the skin distance threshold. When any atom in a system exceeds
    this threshold (strictly greater), that system's rebuild flag is set to
    True.

    Parameters
    ----------
    reference_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Atomic positions when each system's neighbor list was last built.
    current_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Current atomic positions to compare against reference.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom. Used directly to index ``rebuild_flags``.
    skin_distance_threshold : float*
        Maximum allowed displacement before neighbor list becomes invalid.
        Commonly chosen as half the skin width, i.e. half the difference
        between the neighbor-list cutoff and the interaction cutoff.
    rebuild_flags : wp.array, shape (num_systems,), dtype=bool
        OUTPUT: Per-system flags set to True if any atom in that system moved
        beyond skin distance.

    Notes
    -----
    - Thread launch: One thread per atom (dim=total_atoms)
    - Modifies: rebuild_flags, using plain stores of True only; flags are
      never cleared by this kernel.
    - Early termination: Threads exit if their system's rebuild flag is
      already set. The read is unsynchronized and only skips redundant work.
    - Displacement calculation uses Euclidean distance (no periodic wrapping)
    - The kernel performs no host readback; flags are written on the launch
      device.
    """
    atom_idx = wp.tid()

    if atom_idx >= reference_positions.shape[0]:
        return

    isys = batch_idx[atom_idx]

    # Skip computation if rebuild already flagged for this system
    if rebuild_flags[isys]:
        return

    displacement_vector = current_positions[atom_idx] - reference_positions[atom_idx]
    displacement_magnitude = wp.length(displacement_vector)

    if displacement_magnitude > skin_distance_threshold:
        rebuild_flags[isys] = True


# Generate overload dictionary for batch neighbor list rebuild kernel
_check_batch_atoms_moved_beyond_skin_overload = {}
for t, v in zip(_T, _V):
    _check_batch_atoms_moved_beyond_skin_overload[t] = wp.overload(
        _check_batch_atoms_moved_beyond_skin,
        [
            wp.array(dtype=v),
            wp.array(dtype=v),
            wp.array(dtype=wp.int32),
            t,
            wp.array(dtype=wp.bool),
        ],
    )


###########################################################################################
############ PBC Batch Neighbor List Rebuild Detection (Periodic-Aware) ##################
###########################################################################################


@wp.kernel(enable_backward=False)
def _check_batch_atoms_moved_beyond_skin_pbc(
    reference_positions: wp.array(dtype=Any),
    current_positions: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    cell: wp.array(dtype=Any),
    cell_inv: wp.array(dtype=Any),
    pbc: wp.array2d(dtype=wp.bool),
    skin_distance_threshold: Any,
    rebuild_flags: wp.array(dtype=wp.bool),
) -> None:
    """Per-system PBC-aware skin-distance rebuild detection.

    Like ``_check_batch_atoms_moved_beyond_skin`` but applies minimum-image
    convention per system so atoms wrapping across periodic boundaries are
    not spuriously flagged.

    Parameters
    ----------
    reference_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Atomic positions when each system's neighbor list was last built.
    current_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Current atomic positions to compare against reference.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom.
    cell : wp.array, shape (num_systems,), dtype=wp.mat33*
        Per-system unit cell matrices (basis vectors as rows).
    cell_inv : wp.array, shape (num_systems,), dtype=wp.mat33*
        Precomputed per-system inverse cell matrices.
    pbc : wp.array2d, shape (num_systems, 3), dtype=bool
        Per-system periodic boundary condition flags.
    skin_distance_threshold : float*
        Maximum allowed displacement before neighbor list becomes invalid.
    rebuild_flags : wp.array, shape (num_systems,), dtype=bool
        OUTPUT: Per-system flags set to True if any atom moved beyond skin distance.

    Notes
    -----
    - Thread launch: One thread per atom (dim=total_atoms)
    - Modifies: rebuild_flags, using plain stores of True only.
    - The inverse cell is supplied by the caller, so no matrix inversion is
      performed per thread.
    - Works in fractional coordinates, so non-orthogonal cells are handled.
      NOTE(review): rounding each fractional component independently gives
      the true minimum image only while the displacement is small relative to
      the cell - confirm this holds for strongly skewed cells.
    """
    atom_idx = wp.tid()

    if atom_idx >= reference_positions.shape[0]:
        return

    isys = batch_idx[atom_idx]

    # Skip computation if rebuild already flagged for this system
    if rebuild_flags[isys]:
        return

    delta = current_positions[atom_idx] - reference_positions[atom_idx]

    # Convert displacement to fractional coordinates (row-vector convention)
    delta_frac = delta * cell_inv[isys]

    # Apply minimum-image convention on periodic dimensions:
    # subtracting floor(x + 0.5) removes the nearest integer number of cells.
    for dim in range(3):
        if pbc[isys, dim]:
            delta_frac[dim] -= wp.floor(delta_frac[dim] + type(delta_frac[dim])(0.5))

    # Convert back to Cartesian
    delta_cart = delta_frac * cell[isys]
    displacement_magnitude = wp.length(delta_cart)

    if displacement_magnitude > skin_distance_threshold:
        rebuild_flags[isys] = True


# Generate overload dictionary for PBC batch neighbor list rebuild kernel
_check_batch_atoms_moved_beyond_skin_pbc_overload = {}
for t, v, m in zip(_T, _V, _M):
    _check_batch_atoms_moved_beyond_skin_pbc_overload[t] = wp.overload(
        _check_batch_atoms_moved_beyond_skin_pbc,
        [
            wp.array(dtype=v),
            wp.array(dtype=v),
            wp.array(dtype=wp.int32),
            wp.array(dtype=m),
            wp.array(dtype=m),
            wp.array2d(dtype=wp.bool),
            t,
            wp.array(dtype=wp.bool),
        ],
    )


###########################################################################################
########################### Batch Cell List Rebuild Detection ############################
###########################################################################################


@wp.kernel(enable_backward=False)
def _check_batch_atoms_changed_cells(
    current_positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    atom_to_cell_mapping: wp.array(dtype=wp.vec3i),
    batch_idx: wp.array(dtype=wp.int32),
    cells_per_dimension: wp.array(dtype=wp.vec3i),
    pbc: wp.array2d(dtype=wp.bool),
    rebuild_flags: wp.array(dtype=wp.bool),
) -> None:
    """Detect per-system if atoms moved between cells requiring cell list rebuild.

    Computes current cell assignments for each atom and compares with stored
    cell assignments. When any atom in a system has crossed a cell boundary,
    that system's rebuild flag is set to True.

    Parameters
    ----------
    current_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Current Cartesian coordinates.
    cell : wp.array, shape (num_systems,), dtype=wp.mat33*
        Per-system unit cell matrices for coordinate transformations.
    atom_to_cell_mapping : wp.array, shape (total_atoms,), dtype=wp.vec3i
        Previously computed cell coordinates for each atom from existing cell lists.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom.
    cells_per_dimension : wp.array, shape (num_systems,), dtype=wp.vec3i
        Number of cells in x, y, z directions for each system.
    pbc : wp.array2d, shape (num_systems, 3), dtype=bool
        Per-system periodic boundary condition flags.
    rebuild_flags : wp.array, shape (num_systems,), dtype=bool
        OUTPUT: Per-system flags set to True if any atom changed cells.

    Notes
    -----
    - Thread launch: One thread per atom (dim=total_atoms)
    - Modifies: rebuild_flags, using plain stores of True only; flags are
      never cleared by this kernel.
    - Early termination: Threads exit if their system's rebuild flag is
      already set. The read is unsynchronized and only skips redundant work.
    - Periodic directions wrap the cell index into ``[0, n)``; non-periodic
      directions clamp it to ``[0, n - 1]``.
    - The system's cell matrix is inverted by every thread.
    - The kernel performs no host readback; flags are written on the launch
      device.
    """
    atom_idx = wp.tid()

    if atom_idx >= current_positions.shape[0]:
        return

    isys = batch_idx[atom_idx]

    # Skip computation if rebuild already flagged for this system
    if rebuild_flags[isys]:
        return

    _cell = cell[isys]
    _cpd = cells_per_dimension[isys]

    # Transform current position to fractional coordinates (row-vector convention)
    _inv_cell = wp.inverse(_cell)
    fractional_position = current_positions[atom_idx] * _inv_cell
    current_cell_coords = wp.vec3i(0, 0, 0)

    # Compute current cell coordinates for each dimension
    for dim in range(3):
        current_cell_coords[dim] = wp.int32(
            wp.floor(
                fractional_position[dim] * type(fractional_position[dim])(_cpd[dim])
            )
        )

        # Handle periodic boundary conditions
        if pbc[isys, dim]:
            current_cell_coords[dim] = current_cell_coords[dim] % _cpd[dim]
            # The remainder can be negative for a negative index (presumably
            # because Warp's integer modulo truncates toward zero - confirm),
            # so shift it back into [0, n).
            if current_cell_coords[dim] < 0:
                current_cell_coords[dim] += _cpd[dim]
        else:
            # Clamp to valid cell range for non-periodic dimensions
            current_cell_coords[dim] = wp.clamp(
                current_cell_coords[dim], 0, _cpd[dim] - 1
            )

    # Compare with stored cell coordinates from existing cell list
    stored_cell_coords = atom_to_cell_mapping[atom_idx]

    if (
        current_cell_coords[0] != stored_cell_coords[0]
        or current_cell_coords[1] != stored_cell_coords[1]
        or current_cell_coords[2] != stored_cell_coords[2]
    ):
        rebuild_flags[isys] = True


# Generate overload dictionary for batch cell list rebuild kernel
_check_batch_atoms_changed_cells_overload = {}
for t, v, m in zip(_T, _V, _M):
    _check_batch_atoms_changed_cells_overload[t] = wp.overload(
        _check_batch_atoms_changed_cells,
        [
            wp.array(dtype=v),
            wp.array(dtype=m),
            wp.array(dtype=wp.vec3i),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.vec3i),
            wp.array2d(dtype=wp.bool),
            wp.array(dtype=wp.bool),
        ],
    )


###########################################################################################
########################### Batch Warp Launchers #########################################
###########################################################################################


def check_batch_neighbor_list_rebuild(
    reference_positions: wp.array,
    current_positions: wp.array,
    batch_idx: wp.array,
    skin_distance_threshold: float,
    rebuild_flags: wp.array,
    wp_dtype: type,
    device: str,
    update_reference_positions: bool = False,
    cell: wp.array | None = None,
    cell_inv: wp.array | None = None,
    pbc: wp.array | None = None,
) -> None:
    """Core warp launcher for detecting per-system neighbor list rebuild needs.

    Checks if any atoms in each system have moved beyond the skin distance
    since the neighbor list was built and sets per-system rebuild flags.
    When ``cell``, ``cell_inv`` and ``pbc`` are all provided the check uses
    minimum-image convention (MIC) so that atoms crossing periodic boundaries
    are not spuriously flagged.

    Parameters
    ----------
    reference_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Atomic positions when each system's neighbor list was last built.
    current_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Current atomic positions to compare against reference.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom.
    skin_distance_threshold : float
        Maximum allowed displacement before neighbor list becomes invalid.
        Converted to ``wp_dtype`` before being passed to the kernel.
    rebuild_flags : wp.array, shape (num_systems,), dtype=wp.bool
        OUTPUT: Per-system flags set to True if rebuild is needed.
        Must be pre-allocated and initialized to False by caller.
    wp_dtype : type
        Warp scalar dtype selecting the kernel overload
        (wp.float32, wp.float64, or wp.float16).
    device : str
        Warp device string (e.g., 'cuda:0', 'cpu').
    update_reference_positions : bool, optional
        If True, ``update_ref_positions_batch`` is called after the detection
        launch with ``current_positions``, ``rebuild_flags``, ``batch_idx``
        and ``reference_positions``, so that reference positions of atoms in
        rebuilt systems are overwritten with the current ones. The update is
        issued as a separate step after detection rather than from inside
        the detection kernel. Default False.
    cell : wp.array or None, optional
        Per-system cell matrices, shape (num_systems,), dtype=wp.mat33*.
        Required together with ``cell_inv`` and ``pbc`` to enable MIC.
    cell_inv : wp.array or None, optional
        Precomputed per-system inverse cell matrices, same shape/dtype as ``cell``.
    pbc : wp.array or None, optional
        Per-system PBC flags, shape (num_systems, 3), dtype=wp.bool (2D).

    Raises
    ------
    ValueError
        If only a subset of ``cell``, ``cell_inv``, and ``pbc`` are provided.
        All three must be supplied together to enable MIC displacement.
    KeyError
        If ``wp_dtype`` is not one of the registered overload dtypes.

    Notes
    -----
    - This is a low-level warp interface. For framework bindings, use torch/jax wrappers.
    - rebuild_flags must be pre-allocated and initialized to False by caller.
    - This launcher never reads the flags back on the host.

    See Also
    --------
    _check_batch_atoms_moved_beyond_skin : Euclidean kernel
    _check_batch_atoms_moved_beyond_skin_pbc : PBC kernel for periodic systems
    update_ref_positions_batch : Standalone reference-position update launcher
    """
    # The MIC path needs all three periodic inputs; reject partial sets early
    # instead of silently falling back to the Euclidean check.
    pbc_params = (cell, cell_inv, pbc)
    if any(p is not None for p in pbc_params) and not all(
        p is not None for p in pbc_params
    ):
        raise ValueError(
            "cell, cell_inv, and pbc must all be provided together to enable MIC "
            "displacement checking. Received a partial set."
        )

    total_atoms = reference_positions.shape[0]

    # After validation, ``cell`` being present implies the other two are too.
    use_pbc = cell is not None

    if use_pbc:
        wp.launch(
            kernel=_check_batch_atoms_moved_beyond_skin_pbc_overload[wp_dtype],
            dim=total_atoms,
            inputs=[
                reference_positions,
                current_positions,
                batch_idx,
                cell,
                cell_inv,
                pbc,
                wp_dtype(skin_distance_threshold),
                rebuild_flags,
            ],
            device=device,
        )
    else:
        wp.launch(
            kernel=_check_batch_atoms_moved_beyond_skin_overload[wp_dtype],
            dim=total_atoms,
            inputs=[
                reference_positions,
                current_positions,
                batch_idx,
                wp_dtype(skin_distance_threshold),
                rebuild_flags,
            ],
            device=device,
        )

    if update_reference_positions:
        # Issued after detection so the update sees the final flag values.
        update_ref_positions_batch(
            current_positions,
            rebuild_flags,
            batch_idx,
            reference_positions,
            wp_dtype,
            device,
        )


def check_batch_cell_list_rebuild(
    current_positions: wp.array,
    atom_to_cell_mapping: wp.array,
    batch_idx: wp.array,
    cells_per_dimension: wp.array,
    cell: wp.array,
    pbc: wp.array,
    rebuild_flags: wp.array,
    wp_dtype: type,
    device: str,
) -> None:
    """Core warp launcher for detecting per-system cell list rebuild needs.

    Launches ``_check_batch_atoms_changed_cells`` with one thread per atom to
    check whether any atom in each system has moved to a different spatial
    cell since the cell list was built, and sets per-system rebuild flags.

    Parameters
    ----------
    current_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Current Cartesian coordinates.
    atom_to_cell_mapping : wp.array, shape (total_atoms,), dtype=wp.vec3i
        Previously computed cell coordinates for each atom.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom.
    cells_per_dimension : wp.array, shape (num_systems,), dtype=wp.vec3i
        Number of cells in x, y, z directions for each system.
    cell : wp.array, shape (num_systems,), dtype=wp.mat33*
        Per-system unit cell matrices for coordinate transformations.
    pbc : wp.array, shape (num_systems, 3), dtype=wp.bool
        Per-system periodic boundary condition flags (2D array).
    rebuild_flags : wp.array, shape (num_systems,), dtype=wp.bool
        OUTPUT: Per-system flags set to True if rebuild is needed.
        Must be pre-allocated and initialized to False by caller.
    wp_dtype : type
        Warp scalar dtype selecting the kernel overload
        (wp.float32, wp.float64, or wp.float16).
    device : str
        Warp device string (e.g., 'cuda:0', 'cpu').

    Raises
    ------
    KeyError
        If ``wp_dtype`` is not one of the registered overload dtypes.

    Notes
    -----
    - This is a low-level warp interface. For framework bindings, use torch/jax wrappers.
    - rebuild_flags must be pre-allocated and initialized to False by caller.
    - This launcher never reads the flags back on the host.

    See Also
    --------
    _check_batch_atoms_changed_cells : Kernel that performs the check
    """
    total_atoms = current_positions.shape[0]

    # The input order must match the kernel signature, which differs from
    # this launcher's own parameter order.
    wp.launch(
        kernel=_check_batch_atoms_changed_cells_overload[wp_dtype],
        dim=total_atoms,
        inputs=[
            current_positions,
            cell,
            atom_to_cell_mapping,
            batch_idx,
            cells_per_dimension,
            pbc,
            rebuild_flags,
        ],
        device=device,
    )