# Source code for nvalchemiops.neighbors.neighbor_utils

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Core warp utilities for neighbor list construction.

This module contains warp kernels and launchers for neighbor list operations.
See `nvalchemiops.torch.neighbors` for PyTorch bindings.
"""

import math
from typing import Any

import warp as wp


class NeighborOverflowError(Exception):
    """Exception raised when the number of neighbors exceeds the maximum allowed.

    This error indicates that the pre-allocated neighbor matrix is too small to
    hold all discovered neighbors. Users should increase the `max_neighbors`
    parameter or use a larger pre-allocated tensor.

    Parameters
    ----------
    max_neighbors : int
        The maximum number of neighbors the matrix can hold.
    num_neighbors : int
        The actual number of neighbors found.

    Attributes
    ----------
    max_neighbors : int
        Same as the constructor argument.
    num_neighbors : int
        Same as the constructor argument.
    """

    def __init__(self, max_neighbors: int, num_neighbors: int):
        super().__init__(
            f"The number of neighbors is larger than the maximum allowed: "
            f"{num_neighbors} > {max_neighbors}."
        )
        self.max_neighbors = max_neighbors
        self.num_neighbors = num_neighbors

    def __reduce__(self):
        """Support pickling and copying.

        ``BaseException`` reconstructs itself as ``cls(*self.args)``, but
        ``self.args`` only holds the formatted message, which does not match
        this ``__init__`` signature. Rebuild from the two integers instead, and
        carry the instance ``__dict__`` along so any extra attributes survive.
        """
        return (
            type(self),
            (self.max_neighbors, self.num_neighbors),
            self.__dict__,
        )
__all__ = [
    "NeighborOverflowError",
    "compute_naive_num_shifts",
    "compute_inv_cells",
    "zero_array",
    "selective_zero_num_neighbors",
    "selective_zero_num_neighbors_single",
    "estimate_max_neighbors",
    "wrap_positions_single",
    "wrap_positions_batch",
    "_expand_naive_shifts_selective",
    "update_ref_positions",
    "update_ref_positions_batch",
]


@wp.kernel(enable_backward=False)
def _expand_naive_shifts(
    shift_range: wp.array(dtype=wp.vec3i),
    shift_offset: wp.array(dtype=int),
    shifts: wp.array(dtype=wp.vec3i),
    shift_system_idx: wp.array(dtype=int),
) -> None:
    """Expand shift ranges into actual shift vectors for all systems in the batch.

    Converts the compact shift range representation into a flattened array of
    explicit shift vectors. Only shifts that are lexicographically >= (0, 0, 0)
    are emitted (see Notes), so for every non-zero image ``k`` exactly one of
    ``k`` / ``-k`` is written; this avoids double counting periodic images.

    Parameters
    ----------
    shift_range : wp.array, shape (num_systems,), dtype=wp.vec3i
        Maximum shift index in each lattice direction for each system.
    shift_offset : wp.array, shape (num_systems+1,), dtype=wp.int32
        Exclusive prefix sum of the per-system shift counts: system ``s``
        writes starting at ``shift_offset[s]``.
    shifts : wp.array, shape (total_shifts,), dtype=wp.vec3i
        OUTPUT: Flattened array to store the shift vectors.
    shift_system_idx : wp.array, shape (total_shifts,), dtype=wp.int32
        OUTPUT: System index mapping for each shift vector.

    Notes
    -----
    - Thread launch: One thread per system in the batch (dim=num_systems)
    - Modifies: shifts, shift_system_idx
    - total_shifts = shift_offset[-1]
    - A shift (k0, k1, k2) is emitted iff k0 > 0, or k0 == 0 and k1 > 0, or
      k0 == k1 == 0 and k2 >= 0.
    - Shifts are written in loop order: k0 outermost, then k1, then k2, each
      increasing. ``_decode_shift_index`` inverts exactly this order.
    - All shift vectors are integer lattice coordinates
    """
    tid = wp.tid()
    pos = shift_offset[tid]
    _shift_range = shift_range[tid]
    # k0 starts at 0 (not -range): negative-k0 images are the mirrors of
    # positive-k0 images and are covered by the half-space condition below.
    for k0 in range(0, _shift_range[0] + 1):
        for k1 in range(-_shift_range[1], _shift_range[1] + 1):
            for k2 in range(-_shift_range[2], _shift_range[2] + 1):
                if k0 > 0 or (k0 == 0 and k1 > 0) or (k0 == 0 and k1 == 0 and k2 >= 0):
                    shifts[pos] = wp.vec3i(k0, k1, k2)
                    shift_system_idx[pos] = tid
                    pos += 1


@wp.kernel(enable_backward=False)
def _expand_naive_shifts_selective(
    shift_range: wp.array(dtype=wp.vec3i),
    shift_offset: wp.array(dtype=int),
    shifts: wp.array(dtype=wp.vec3i),
    shift_system_idx: wp.array(dtype=int),
    rebuild_flags: wp.array(dtype=wp.bool),
) -> None:
    """Expand shift ranges into actual shift vectors, skipping non-rebuilt systems.

    Identical to ``_expand_naive_shifts`` but checks ``rebuild_flags[tid]`` on
    the device and returns immediately for systems that do not need rebuilding.
    No CPU-GPU synchronisation occurs.

    Parameters
    ----------
    shift_range : wp.array, shape (num_systems,), dtype=wp.vec3i
        Maximum shift index in each lattice direction for each system.
    shift_offset : wp.array, shape (num_systems+1,), dtype=wp.int32
        Exclusive prefix sum of the per-system shift counts.
    shifts : wp.array, shape (total_shifts,), dtype=wp.vec3i
        OUTPUT: Flattened array to store the shift vectors.
    shift_system_idx : wp.array, shape (total_shifts,), dtype=wp.int32
        OUTPUT: System index mapping for each shift vector.
    rebuild_flags : wp.array, shape (num_systems,), dtype=wp.bool
        Per-system rebuild flags. False -> the thread returns without writing,
        leaving that system's slots untouched.

    Notes
    -----
    - Thread launch: One thread per system in the batch (dim=num_systems)
    - Modifies: shifts, shift_system_idx (only for rebuilt systems)
    - total_shifts = shift_offset[-1]
    """
    tid = wp.tid()
    if not rebuild_flags[tid]:
        return
    pos = shift_offset[tid]
    _shift_range = shift_range[tid]
    # Same half-space enumeration as ``_expand_naive_shifts``.
    for k0 in range(0, _shift_range[0] + 1):
        for k1 in range(-_shift_range[1], _shift_range[1] + 1):
            for k2 in range(-_shift_range[2], _shift_range[2] + 1):
                if k0 > 0 or (k0 == 0 and k1 > 0) or (k0 == 0 and k1 == 0 and k2 >= 0):
                    shifts[pos] = wp.vec3i(k0, k1, k2)
                    shift_system_idx[pos] = tid
                    pos += 1


@wp.func
def _decode_shift_index(local_idx: int, shift_range: wp.vec3i) -> wp.vec3i:
    """Decode a flat shift index into a (kx, ky, kz) lattice shift vector.

    Reverses the enumeration order used by ``_expand_naive_shifts`` so that
    shift vectors can be computed on the fly from a thread index without
    materialising the full shifts array.

    The enumeration is split into:

    - group 0 (k0 == 0): first ``k1 == 0, k2 = 0..s2`` (``s2 + 1`` entries),
      then ``k1 = 1..s1`` with every ``k2 in [-s2, s2]``;
    - groups k0 = 1..s0: every ``(k1, k2)`` in ``[-s1, s1] x [-s2, s2]``.

    Parameters
    ----------
    local_idx : int
        Zero-based index into the per-system shift enumeration. Presumably
        ``0 <= local_idx < num_shifts`` for that system (not checked here).
    shift_range : wp.vec3i
        Shift range in each dimension (from ``_compute_naive_num_shifts``).

    Returns
    -------
    wp.vec3i
        The integer lattice shift vector ``(kx, ky, kz)``.
    """
    k2_size = 2 * shift_range[2] + 1
    k1_size = 2 * shift_range[1] + 1
    # Number of shifts with k0 == 0 in the half-space enumeration.
    group0_size = shift_range[1] * k2_size + shift_range[2] + 1
    k0 = wp.int32(0)
    k1 = wp.int32(0)
    k2 = wp.int32(0)
    # All remainders below are non-negative, so integer division and modulo on
    # these int operands give floor semantics.
    if local_idx < group0_size:
        if local_idx <= shift_range[2]:
            # k0 == 0, k1 == 0: only k2 >= 0 is enumerated.
            k2 = local_idx
        else:
            # k0 == 0, k1 >= 1: full k2 rows of width k2_size.
            rem = local_idx - (shift_range[2] + 1)
            k1 = rem / k2_size + 1
            k2 = rem % k2_size - shift_range[2]
    else:
        # k0 >= 1: full (k1, k2) planes of size k1_size * k2_size.
        rem = local_idx - group0_size
        k0 = rem / (k1_size * k2_size) + 1
        rem2 = rem % (k1_size * k2_size)
        k1 = rem2 / k2_size - shift_range[1]
        k2 = rem2 % k2_size - shift_range[2]
    return wp.vec3i(k0, k1, k2)


@wp.func
def _update_neighbor_matrix(
    i: int,
    j: int,
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    max_neighbors: int,
    half_fill: bool,
):
    """
    Record the pair (i, j) in the neighbor matrix (non-periodic variant).

    Parameters
    ----------
    i: int
        The index of the source atom.
    j: int
        The index of the target atom.
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2)
        OUTPUT: The neighbor matrix to be updated.
    num_neighbors: wp.array(dtype=wp.int32)
        OUTPUT: The number of neighbors for each atom. Incremented atomically
        even when the row is already full, so it can exceed ``max_neighbors``.
    max_neighbors: int
        The maximum number of neighbors for each atom (row capacity).
    half_fill: bool
        If True, only the ``i -> j`` entry is recorded. If False, the mirrored
        ``j -> i`` entry is also recorded, but only when ``i < j``.
    """
    # Reserve a slot in row i; the count always advances, but the write is
    # dropped once the row is full (presumably so callers can compare the
    # count against max_neighbors to detect overflow).
    pos = wp.atomic_add(num_neighbors, i, 1)
    if pos < max_neighbors:
        neighbor_matrix[i, pos] = j
    # NOTE(review): the PBC variant below mirrors unconditionally, while this
    # one requires i < j; confirm this matches how the calling kernels
    # enumerate pairs.
    if not half_fill and i < j:
        pos = wp.atomic_add(num_neighbors, j, 1)
        if pos < max_neighbors:
            neighbor_matrix[j, pos] = i


@wp.func
def _update_neighbor_matrix_pbc(
    i: int,
    j: int,
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2),
    neighbor_matrix_shifts: wp.array(dtype=wp.vec3i, ndim=2),
    num_neighbors: wp.array(dtype=wp.int32),
    unit_shift: wp.vec3i,
    max_neighbors: int,
    half_fill: bool,
):
    """
    Record the pair (i, j) and its periodic image shift in the neighbor matrix.

    Parameters
    ----------
    i: int
        The index of the source atom.
    j: int
        The index of the target atom.
    neighbor_matrix: wp.array(dtype=wp.int32, ndim=2)
        OUTPUT: The neighbor matrix to be updated.
    neighbor_matrix_shifts: wp.array(dtype=wp.vec3i, ndim=2)
        OUTPUT: Lattice shift stored alongside each neighbor entry.
    num_neighbors: wp.array(dtype=wp.int32)
        OUTPUT: The number of neighbors for each atom. Incremented atomically
        even when the row is already full, so it can exceed ``max_neighbors``.
    unit_shift: wp.vec3i
        Lattice shift of the image of ``j`` seen from ``i``.
    max_neighbors: int
        The maximum number of neighbors for each atom (row capacity).
    half_fill: bool
        If True, only the ``i -> j`` entry is recorded. If False, the mirrored
        ``j -> i`` entry is also recorded with the negated shift.
    """
    pos = wp.atomic_add(num_neighbors, i, 1)
    if pos < max_neighbors:
        neighbor_matrix[i, pos] = j
        neighbor_matrix_shifts[i, pos] = unit_shift
    if not half_fill:
        # Seen from j, the image of i lies in the opposite cell.
        pos = wp.atomic_add(num_neighbors, j, 1)
        if pos < max_neighbors:
            neighbor_matrix[j, pos] = i
            neighbor_matrix_shifts[j, pos] = -unit_shift


@wp.kernel(enable_backward=False)
def _compute_naive_num_shifts(
    cell: wp.array(dtype=Any),
    cutoff: Any,
    pbc: wp.array2d(dtype=wp.bool),
    num_shifts: wp.array(dtype=int),
    shift_range: wp.array(dtype=wp.vec3i),
) -> None:
    """Compute periodic image shifts needed for neighbor searching.

    For each periodic lattice direction ``d``, the shift range is
    ``ceil(cutoff * |b_d|)``, where ``b_d`` is row ``d`` of
    ``transpose(inverse(cell))``. With lattice vectors stored as rows of
    ``cell``, ``|b_d|`` is the inverse spacing between lattice planes.
    Non-periodic directions get a range of 0.

    Parameters
    ----------
    cell : wp.array, shape (num_systems,), dtype=wp.mat33*
        Cell matrices defining lattice vectors in Cartesian coordinates.
        Each 3x3 matrix represents one system's periodic cell.
    cutoff : float
        Cutoff distance for neighbor searching in Cartesian units.
        Expected to be positive. Cutoffs larger than the cell simply yield
        larger shift ranges.
    pbc : wp.array, shape (num_systems, 3), dtype=wp.bool
        Periodic boundary condition flags for each dimension.
        True enables periodicity in that direction.
    num_shifts : wp.array, shape (num_systems,), dtype=int
        OUTPUT: Number of half-space shifts enumerated by
        ``_expand_naive_shifts`` for each system.
    shift_range : wp.array, shape (num_systems,), dtype=wp.vec3i
        OUTPUT: Maximum shift indices in each dimension for each system.

    Returns
    -------
    None
        This function modifies the input arrays in-place:
        - num_shifts : Updated with total shift counts per system
        - shift_range : Updated with shift ranges per dimension

    See Also
    --------
    _expand_naive_shifts : Expands shift ranges into explicit shift vectors
    """
    tid = wp.tid()
    _cell = cell[tid]
    _pbc = pbc[tid]
    # Rows of transpose(inverse(cell)) are the reciprocal lattice vectors.
    _cell_inv = wp.transpose(wp.inverse(_cell))
    _d_inv_0 = wp.length(_cell_inv[0]) if _pbc[0] else type(_cell_inv[0, 0])(0.0)
    _d_inv_1 = wp.length(_cell_inv[1]) if _pbc[1] else type(_cell_inv[1, 0])(0.0)
    _d_inv_2 = wp.length(_cell_inv[2]) if _pbc[2] else type(_cell_inv[2, 0])(0.0)
    _s = wp.vec3i(
        wp.int32(wp.ceil(_d_inv_0 * type(_d_inv_0)(cutoff))),
        wp.int32(wp.ceil(_d_inv_1 * type(_d_inv_1)(cutoff))),
        wp.int32(wp.ceil(_d_inv_2 * type(_d_inv_2)(cutoff))),
    )
    k1 = 2 * _s[1] + 1
    k2 = 2 * _s[2] + 1
    shift_range[tid] = _s
    # Half-space count: s0 full (k1, k2) planes for k0 > 0, s1 full k2 rows for
    # k0 == 0 and k1 > 0, and s2 + 1 entries for k0 == k1 == 0, k2 >= 0.
    num_shifts[tid] = _s[0] * k1 * k2 + _s[1] * k2 + _s[2] + 1


## Generate overloads (float32 / float64 / float16 cells and cutoffs)
T = [wp.float32, wp.float64, wp.float16]
V = [wp.vec3f, wp.vec3d, wp.vec3h]
M = [wp.mat33f, wp.mat33d, wp.mat33h]
# Maps a Warp scalar dtype to the matching concrete kernel overload.
_compute_naive_num_shifts_overload = {}
for t, v, m in zip(T, V, M):
    _compute_naive_num_shifts_overload[t] = wp.overload(
        _compute_naive_num_shifts,
        [
            wp.array(dtype=m),
            t,
            wp.array2d(dtype=wp.bool),
            wp.array(dtype=int),
            wp.array(dtype=wp.vec3i),
        ],
    )


@wp.kernel(enable_backward=False)
def _zero_array_kernel(
    array: wp.array(dtype=Any),
) -> None:
    """Zero a 1-D array in parallel.

    Parameters
    ----------
    array : wp.array, dtype=Any
        OUTPUT: Array to be zeroed. Each element is set to ``dtype(0)``.

    Notes
    -----
    - Thread launch: One thread per element (dim=array.shape[0])
    - Modifies: array (sets all elements to 0)
    """
    tid = wp.tid()
    array[tid] = type(array[tid])(0)
def zero_array(
    array: wp.array,
    device: str,
) -> None:
    """Reset every element of a 1-D warp array to zero.

    Launches ``_zero_array_kernel`` with one thread per element along the
    first dimension. Each element is overwritten with ``dtype(0)`` for the
    array's own dtype.

    Parameters
    ----------
    array : wp.array, dtype=Any
        OUTPUT: Array whose elements are overwritten with zero in place.
    device : str
        Warp device string on which to launch (e.g. ``'cuda:0'``, ``'cpu'``).

    Notes
    -----
    - Low-level warp interface; no framework tensors involved.

    See Also
    --------
    _zero_array_kernel : Kernel that performs the zeroing
    """
    num_elements = array.shape[0]
    wp.launch(
        kernel=_zero_array_kernel,
        dim=num_elements,
        device=device,
        inputs=[array],
    )
@wp.kernel(enable_backward=False)
def _selective_zero_num_neighbors(
    num_neighbors: wp.array(dtype=wp.int32),
    batch_idx: wp.array(dtype=wp.int32),
    rebuild_flags: wp.array(dtype=wp.bool),
) -> None:
    """Zero num_neighbors entries for atoms in systems that need rebuilding.

    Parameters
    ----------
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors; zeroed for atoms in rebuilt systems.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom.
    rebuild_flags : wp.array, shape (num_systems,), dtype=wp.bool
        Per-system rebuild flags. True means this system needs rebuilding.

    Notes
    -----
    - Thread launch: One thread per atom (dim=total_atoms)
    - Modifies: num_neighbors (selective zero for rebuilt systems)
    """
    tid = wp.tid()
    isys = batch_idx[tid]
    if rebuild_flags[isys]:
        num_neighbors[tid] = 0


def selective_zero_num_neighbors(
    num_neighbors: wp.array,
    batch_idx: wp.array,
    rebuild_flags: wp.array,
    device: str,
) -> None:
    """Core warp launcher for selectively zeroing num_neighbors.

    Zeros the num_neighbors count for atoms belonging to systems where
    rebuild_flags is True, preserving counts for non-rebuilt systems. The
    decision is made on the device, so no CPU-GPU synchronization is needed.

    Parameters
    ----------
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Per-atom neighbor counts; selectively zeroed.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom.
    rebuild_flags : wp.array, shape (num_systems,), dtype=wp.bool
        Per-system flags indicating which systems need rebuilding.
    device : str
        Warp device string (e.g., 'cuda:0', 'cpu').

    See Also
    --------
    _selective_zero_num_neighbors : Kernel that performs the selective zeroing
    """
    total_atoms = num_neighbors.shape[0]
    wp.launch(
        kernel=_selective_zero_num_neighbors,
        dim=total_atoms,
        inputs=[num_neighbors, batch_idx, rebuild_flags],
        device=device,
    )


@wp.kernel(enable_backward=False)
def _selective_zero_num_neighbors_single(
    num_neighbors: wp.array(dtype=wp.int32),
    rebuild_flags: wp.array(dtype=wp.bool),
) -> None:
    """Zero num_neighbors entries when the single-system rebuild flag is set.

    Parameters
    ----------
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors; zeroed for all atoms when
        rebuild_flags[0] is True.
    rebuild_flags : wp.array, dtype=wp.bool
        Single-system flag; only element 0 is read. When True, all entries of
        num_neighbors are zeroed.

    Notes
    -----
    - Thread launch: One thread per atom (dim=total_atoms)
    - Modifies: num_neighbors (only when rebuild_flags[0] is True)
    """
    tid = wp.tid()
    if rebuild_flags[0]:
        num_neighbors[tid] = 0


def selective_zero_num_neighbors_single(
    num_neighbors: wp.array,
    rebuild_flags: wp.array,
    device: str,
) -> None:
    """Core warp launcher for selectively zeroing num_neighbors for a single system.

    Zeros all num_neighbors entries when rebuild_flags[0] is True. The flag is
    read on the device, so no CPU-GPU synchronization occurs; when it is False
    every thread simply does nothing.

    Parameters
    ----------
    num_neighbors : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Per-atom neighbor counts; zeroed when rebuild is needed.
    rebuild_flags : wp.array, dtype=wp.bool
        Single-system rebuild flag; only element 0 is read.
    device : str
        Warp device string (e.g., 'cuda:0', 'cpu').

    See Also
    --------
    _selective_zero_num_neighbors_single : Kernel that performs the selective zeroing
    selective_zero_num_neighbors : Batch variant using per-atom batch_idx
    """
    total_atoms = num_neighbors.shape[0]
    wp.launch(
        kernel=_selective_zero_num_neighbors_single,
        dim=total_atoms,
        inputs=[num_neighbors, rebuild_flags],
        device=device,
    )


@wp.kernel(enable_backward=False)
def _compute_inv_cells_kernel(
    cell: wp.array(dtype=Any),
    inv_cell: wp.array(dtype=Any),
) -> None:
    """Compute the inverse of each cell matrix.

    Parameters
    ----------
    cell : wp.array, shape (num_systems,), dtype=wp.mat33*
        Input cell matrices.
    inv_cell : wp.array, shape (num_systems,), dtype=wp.mat33*
        OUTPUT: Inverse of each cell matrix.

    Notes
    -----
    - Thread launch: One thread per system (dim=num_systems)
    """
    tid = wp.tid()
    inv_cell[tid] = wp.inverse(cell[tid])


# Maps a Warp scalar dtype to the matching concrete inverse-cell kernel.
_compute_inv_cells_overload = {}
for _t, _m in zip(
    [wp.float32, wp.float64, wp.float16],
    [wp.mat33f, wp.mat33d, wp.mat33h],
):
    _compute_inv_cells_overload[_t] = wp.overload(
        _compute_inv_cells_kernel,
        [wp.array(dtype=_m), wp.array(dtype=_m)],
    )


def compute_inv_cells(
    cell: wp.array,
    inv_cell: wp.array,
    wp_dtype: type,
    device: str,
) -> None:
    """Core warp launcher for computing inverse cell matrices.

    Inverts each cell matrix in the batch using pure warp operations. Call
    this once before launching naive PBC neighbor-list kernels to avoid
    redundant per-thread inversions inside those kernels.

    Parameters
    ----------
    cell : wp.array, shape (num_systems,), dtype=wp.mat33*
        Input cell matrices.
    inv_cell : wp.array, shape (num_systems,), dtype=wp.mat33*
        OUTPUT: Inverse of each cell matrix. Must be pre-allocated with the
        same shape and dtype as *cell*.
    wp_dtype : type
        Warp scalar dtype (wp.float32, wp.float64, or wp.float16).
    device : str
        Warp device string (e.g., ``'cuda:0'``, ``'cpu'``).

    Raises
    ------
    KeyError
        If ``wp_dtype`` is not one of the supported scalar dtypes.

    See Also
    --------
    _compute_inv_cells_kernel : Underlying warp kernel
    """
    num_systems = cell.shape[0]
    wp.launch(
        kernel=_compute_inv_cells_overload[wp_dtype],
        dim=num_systems,
        inputs=[cell, inv_cell],
        device=device,
    )


def compute_naive_num_shifts(
    cell: wp.array,
    cutoff: float,
    pbc: wp.array,
    num_shifts: wp.array,
    shift_range: wp.array,
    wp_dtype: type,
    device: str,
) -> None:
    """Core warp launcher for computing periodic image shifts.

    Calculates the number and range of periodic boundary shifts required to
    ensure all atoms within the cutoff distance are found, using pure warp
    operations.

    Parameters
    ----------
    cell : wp.array, shape (num_systems,), dtype=wp.mat33*
        Cell matrices defining lattice vectors in Cartesian coordinates.
        Each 3x3 matrix represents one system's periodic cell.
    cutoff : float
        Cutoff distance for neighbor searching in Cartesian units. Expected
        to be positive; cutoffs larger than the cell yield larger shift
        ranges rather than missed neighbors.
    pbc : wp.array, shape (num_systems, 3), dtype=wp.bool
        Periodic boundary condition flags for each dimension.
        True enables periodicity in that direction.
    num_shifts : wp.array, shape (num_systems,), dtype=wp.int32
        OUTPUT: Total number of periodic shifts needed for each system.
    shift_range : wp.array, shape (num_systems,), dtype=wp.vec3i
        OUTPUT: Maximum shift indices in each dimension for each system.
    wp_dtype : type
        Warp scalar dtype (wp.float32, wp.float64, or wp.float16). Selects the
        kernel overload and the precision in which ``cutoff`` is passed.
    device : str
        Warp device string (e.g., 'cuda:0', 'cpu').

    Raises
    ------
    KeyError
        If ``wp_dtype`` is not one of the supported scalar dtypes.

    Notes
    -----
    - This is a low-level warp interface. For framework bindings, use torch/jax wrappers.
    - Output arrays (num_shifts, shift_range) must be pre-allocated by caller.

    See Also
    --------
    _compute_naive_num_shifts : Kernel that performs the computation
    _expand_naive_shifts : Expands shift ranges into explicit shift vectors
    """
    num_systems = cell.shape[0]
    # Dispatch through the pre-registered overload table (like the other
    # launchers in this module) instead of relying on implicit instantiation
    # of the generic kernel.
    wp.launch(
        kernel=_compute_naive_num_shifts_overload[wp_dtype],
        dim=num_systems,
        inputs=[
            cell,
            wp_dtype(cutoff),
            pbc,
            num_shifts,
            shift_range,
        ],
        device=device,
    )
def estimate_max_neighbors(
    cutoff: float,
    atomic_density: float = 0.2,
    safety_factor: float = 1.0,
) -> int:
    r"""Estimate a per-atom neighbor capacity from a uniform atomic density.

    Treats the atoms as uniformly distributed and counts how many are
    expected inside the cutoff sphere, scaled by a safety factor. This is a
    pure Python function with no framework dependencies.

    Parameters
    ----------
    cutoff : float
        Maximum distance for considering atoms as neighbors.
    atomic_density : float, optional
        Atomic density in atoms per unit volume. Default is 0.2.
    safety_factor : float
        Multiplier applied to the expected neighbor count. Default is 1.0.

    Returns
    -------
    max_neighbors_estimate : int
        ``0`` when ``cutoff <= 0``. Otherwise the estimate, clamped to at
        least 1 and rounded up to a multiple of 16 (so the result is >= 16).

    Notes
    -----
    The expected count is

    .. math::

        \text{neighbors} = \text{safety\_factor} \times \text{density} \times V_{\text{sphere}}

    with the cutoff sphere volume

    .. math::

        V_{\text{sphere}} = \frac{4}{3}\pi r^3

    and the result is rounded up to the next multiple of 16 for memory
    alignment.
    """
    alignment = 16
    if cutoff <= 0:
        return 0

    # Keep the multiplication order fixed: it determines the floating-point
    # rounding and hence which multiple of 16 the ceiling lands on.
    expected = safety_factor * (atomic_density * (4.0 / 3.0) * math.pi * cutoff**3)
    # Never size for fewer than one neighbor before rounding.
    expected = max(1, expected)
    return int(math.ceil(expected / alignment)) * alignment
###########################################################################################
########################### Position Wrapping Kernels ####################################
###########################################################################################


@wp.kernel(enable_backward=False)
def _wrap_positions_single_kernel(
    positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    inv_cell: wp.array(dtype=Any),
    positions_wrapped: wp.array(dtype=Any),
    per_atom_cell_offsets: wp.array(dtype=wp.vec3i),
) -> None:
    """Wrap positions into the primary cell for a single system.

    Computes fractional coordinates (``position * inv_cell``, i.e. positions
    are row vectors and lattice vectors are rows of ``cell``) and floors them
    to integer cell offsets. It then subtracts ``offset * cell`` so each atom
    lands in the primary cell. The integer offsets are stored so that
    corrected shift vectors can be recovered for the original (unwrapped)
    positions.

    All three lattice directions are wrapped; this kernel does not read any
    periodicity flags.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Atomic coordinates in Cartesian space. May be unwrapped.
    cell : wp.array, shape (1,), dtype=wp.mat33*
        Cell matrix defining lattice vectors in Cartesian coordinates.
    inv_cell : wp.array, shape (1,), dtype=wp.mat33*
        Pre-computed inverse of the cell matrix.
    positions_wrapped : wp.array, shape (total_atoms,), dtype=wp.vec3*
        OUTPUT: Wrapped positions in Cartesian space.
    per_atom_cell_offsets : wp.array, shape (total_atoms,), dtype=wp.vec3i
        OUTPUT: Integer cell offsets for each atom (floor of fractional coordinates).

    Notes
    -----
    - Thread launch: One thread per atom (dim=total_atoms)
    - Modifies: positions_wrapped, per_atom_cell_offsets
    """
    i = wp.tid()
    _cell = cell[0]
    _inv_cell = inv_cell[0]
    _pos = positions[i]
    _frac = _pos * _inv_cell
    _int = wp.vec3i(
        wp.int32(wp.floor(_frac[0])),
        wp.int32(wp.floor(_frac[1])),
        wp.int32(wp.floor(_frac[2])),
    )
    positions_wrapped[i] = _pos - type(_pos)(_int) * _cell
    per_atom_cell_offsets[i] = _int


_wrap_positions_single_overload = {}
for _t, _v, _m in zip(
    [wp.float32, wp.float64, wp.float16],
    [wp.vec3f, wp.vec3d, wp.vec3h],
    [wp.mat33f, wp.mat33d, wp.mat33h],
):
    _wrap_positions_single_overload[_t] = wp.overload(
        _wrap_positions_single_kernel,
        [
            wp.array(dtype=_v),
            wp.array(dtype=_m),
            wp.array(dtype=_m),
            wp.array(dtype=_v),
            wp.array(dtype=wp.vec3i),
        ],
    )


def wrap_positions_single(
    positions: wp.array,
    cell: wp.array,
    inv_cell: wp.array,
    positions_wrapped: wp.array,
    per_atom_cell_offsets: wp.array,
    wp_dtype: type,
    device: str,
) -> None:
    """Core warp launcher for wrapping positions into the primary cell (single system).

    Computes per-atom integer cell offsets and wrapped positions in a single
    GPU pass. Call this before naive PBC neighbor-list kernels to move the
    wrapping out of the hot ishift x iatom loop.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Atomic coordinates in Cartesian space. May be unwrapped.
    cell : wp.array, shape (1,), dtype=wp.mat33*
        Cell matrix defining lattice vectors.
    inv_cell : wp.array, shape (1,), dtype=wp.mat33*
        Pre-computed inverse cell matrix, with the same shape and dtype as *cell*.
    positions_wrapped : wp.array, shape (total_atoms,), dtype=wp.vec3*
        OUTPUT: Wrapped positions. Must be pre-allocated with the same shape
        and dtype as *positions*.
    per_atom_cell_offsets : wp.array, shape (total_atoms,), dtype=wp.vec3i
        OUTPUT: Integer cell offsets per atom. Must be pre-allocated.
    wp_dtype : type
        Warp scalar dtype (wp.float32, wp.float64, or wp.float16).
    device : str
        Warp device string (e.g., ``'cuda:0'``, ``'cpu'``).

    See Also
    --------
    _wrap_positions_single_kernel : Underlying warp kernel
    wrap_positions_batch : Batch variant for multiple systems
    """
    total_atoms = positions.shape[0]
    wp.launch(
        kernel=_wrap_positions_single_overload[wp_dtype],
        dim=total_atoms,
        inputs=[positions, cell, inv_cell, positions_wrapped, per_atom_cell_offsets],
        device=device,
    )


@wp.kernel(enable_backward=False)
def _wrap_positions_batch_kernel(
    positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    inv_cell: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    positions_wrapped: wp.array(dtype=Any),
    per_atom_cell_offsets: wp.array(dtype=wp.vec3i),
) -> None:
    """Wrap positions into the primary cell for a batch of systems.

    Each atom uses the cell matrix of its system (indexed via batch_idx).
    Otherwise identical to ``_wrap_positions_single_kernel``: all three
    lattice directions are wrapped, and no periodicity flags are read.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Concatenated atomic coordinates for all systems. May be unwrapped.
    cell : wp.array, shape (num_systems,), dtype=wp.mat33*
        Cell matrices for each system.
    inv_cell : wp.array, shape (num_systems,), dtype=wp.mat33*
        Pre-computed inverse cell matrices.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom.
    positions_wrapped : wp.array, shape (total_atoms,), dtype=wp.vec3*
        OUTPUT: Wrapped positions in Cartesian space.
    per_atom_cell_offsets : wp.array, shape (total_atoms,), dtype=wp.vec3i
        OUTPUT: Integer cell offsets for each atom (floor of fractional coordinates).

    Notes
    -----
    - Thread launch: One thread per atom (dim=total_atoms)
    - Modifies: positions_wrapped, per_atom_cell_offsets
    """
    i = wp.tid()
    isys = batch_idx[i]
    _cell = cell[isys]
    _inv_cell = inv_cell[isys]
    _pos = positions[i]
    _frac = _pos * _inv_cell
    _int = wp.vec3i(
        wp.int32(wp.floor(_frac[0])),
        wp.int32(wp.floor(_frac[1])),
        wp.int32(wp.floor(_frac[2])),
    )
    positions_wrapped[i] = _pos - type(_pos)(_int) * _cell
    per_atom_cell_offsets[i] = _int


_wrap_positions_batch_overload = {}
for _t, _v, _m in zip(
    [wp.float32, wp.float64, wp.float16],
    [wp.vec3f, wp.vec3d, wp.vec3h],
    [wp.mat33f, wp.mat33d, wp.mat33h],
):
    _wrap_positions_batch_overload[_t] = wp.overload(
        _wrap_positions_batch_kernel,
        [
            wp.array(dtype=_v),
            wp.array(dtype=_m),
            wp.array(dtype=_m),
            wp.array(dtype=wp.int32),
            wp.array(dtype=_v),
            wp.array(dtype=wp.vec3i),
        ],
    )


def wrap_positions_batch(
    positions: wp.array,
    cell: wp.array,
    inv_cell: wp.array,
    batch_idx: wp.array,
    positions_wrapped: wp.array,
    per_atom_cell_offsets: wp.array,
    wp_dtype: type,
    device: str,
) -> None:
    """Core warp launcher for wrapping positions into the primary cell (batch of systems).

    Each atom uses the cell matrix of its system (indexed via batch_idx).
    Computes per-atom integer cell offsets and wrapped positions in a single
    GPU pass. Call this before batch naive PBC neighbor-list kernels to move
    the wrapping out of the hot ishift x iatom loop.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Concatenated atomic coordinates for all systems. May be unwrapped.
    cell : wp.array, shape (num_systems,), dtype=wp.mat33*
        Cell matrices for each system.
    inv_cell : wp.array, shape (num_systems,), dtype=wp.mat33*
        Pre-computed inverse cell matrices, with the same shape and dtype as *cell*.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom.
    positions_wrapped : wp.array, shape (total_atoms,), dtype=wp.vec3*
        OUTPUT: Wrapped positions. Must be pre-allocated with the same shape
        and dtype as *positions*.
    per_atom_cell_offsets : wp.array, shape (total_atoms,), dtype=wp.vec3i
        OUTPUT: Integer cell offsets per atom. Must be pre-allocated.
    wp_dtype : type
        Warp scalar dtype (wp.float32, wp.float64, or wp.float16).
    device : str
        Warp device string (e.g., ``'cuda:0'``, ``'cpu'``).

    See Also
    --------
    _wrap_positions_batch_kernel : Underlying warp kernel
    wrap_positions_single : Single-system variant
    """
    total_atoms = positions.shape[0]
    wp.launch(
        kernel=_wrap_positions_batch_overload[wp_dtype],
        dim=total_atoms,
        inputs=[
            positions,
            cell,
            inv_cell,
            batch_idx,
            positions_wrapped,
            per_atom_cell_offsets,
        ],
        device=device,
    )


###########################################################################################
########################### Reference Position Update Kernels ############################
###########################################################################################


@wp.kernel(enable_backward=False)
def _update_ref_positions_kernel(
    positions: wp.array(dtype=Any),
    rebuild_flag: wp.array(dtype=wp.bool),
    ref_positions: wp.array(dtype=Any),
) -> None:
    """Conditionally copy positions to ref_positions when rebuild_flag[0] is True.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Current atomic coordinates.
    rebuild_flag : wp.array, shape (1,), dtype=wp.bool
        Single-system rebuild flag. When True, ref_positions is updated.
    ref_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        OUTPUT: Reference positions updated when rebuild_flag[0] is True.

    Notes
    -----
    - Thread launch: One thread per atom (dim=total_atoms)
    - Modifies: ref_positions (only when rebuild_flag[0] is True)
    """
    i = wp.tid()
    if rebuild_flag[0]:
        ref_positions[i] = positions[i]


# float16 is registered as well so these launchers accept the same dtypes as
# the wrapping / inverse-cell / shift launchers in this module.
_update_ref_positions_overload = {}
for _t, _v in zip(
    [wp.float32, wp.float64, wp.float16],
    [wp.vec3f, wp.vec3d, wp.vec3h],
):
    _update_ref_positions_overload[_t] = wp.overload(
        _update_ref_positions_kernel,
        [wp.array(dtype=_v), wp.array(dtype=wp.bool), wp.array(dtype=_v)],
    )


def update_ref_positions(
    positions: wp.array,
    rebuild_flag: wp.array,
    ref_positions: wp.array,
    wp_dtype: type,
    device: str,
) -> None:
    """Core warp launcher for conditionally updating reference positions (single system).

    Copies current positions into reference positions only when
    rebuild_flag[0] is True. The flag is read on the device, so no CPU-GPU
    synchronization is required.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Current atomic coordinates.
    rebuild_flag : wp.array, shape (1,), dtype=wp.bool
        Single-system rebuild flag.
    ref_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        OUTPUT: Reference positions to update selectively.
    wp_dtype : type
        Warp scalar dtype (wp.float32, wp.float64, or wp.float16).
    device : str
        Warp device string (e.g., 'cuda:0', 'cpu').

    See Also
    --------
    _update_ref_positions_kernel : Underlying warp kernel
    update_ref_positions_batch : Batch variant
    """
    total_atoms = positions.shape[0]
    wp.launch(
        kernel=_update_ref_positions_overload[wp_dtype],
        dim=total_atoms,
        inputs=[positions, rebuild_flag, ref_positions],
        device=device,
    )


@wp.kernel(enable_backward=False)
def _update_ref_positions_batch_kernel(
    positions: wp.array(dtype=Any),
    rebuild_flags: wp.array(dtype=wp.bool),
    batch_idx: wp.array(dtype=wp.int32),
    ref_positions: wp.array(dtype=Any),
) -> None:
    """Conditionally copy positions to ref_positions per system (batch, no CPU sync).

    Parameters
    ----------
    positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Current atomic coordinates for all systems.
    rebuild_flags : wp.array, shape (num_systems,), dtype=wp.bool
        Per-system rebuild flags.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom.
    ref_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        OUTPUT: Reference positions; updated for atoms in rebuilt systems.

    Notes
    -----
    - Thread launch: One thread per atom (dim=total_atoms)
    - Modifies: ref_positions (only for atoms in rebuilt systems)
    """
    i = wp.tid()
    if rebuild_flags[batch_idx[i]]:
        ref_positions[i] = positions[i]


_update_ref_positions_batch_overload = {}
for _t, _v in zip(
    [wp.float32, wp.float64, wp.float16],
    [wp.vec3f, wp.vec3d, wp.vec3h],
):
    _update_ref_positions_batch_overload[_t] = wp.overload(
        _update_ref_positions_batch_kernel,
        [
            wp.array(dtype=_v),
            wp.array(dtype=wp.bool),
            wp.array(dtype=wp.int32),
            wp.array(dtype=_v),
        ],
    )


def update_ref_positions_batch(
    positions: wp.array,
    rebuild_flags: wp.array,
    batch_idx: wp.array,
    ref_positions: wp.array,
    wp_dtype: type,
    device: str,
) -> None:
    """Core warp launcher for conditionally updating reference positions (batch).

    Updates reference positions only for atoms in systems where rebuild_flags
    is True. Flags are read on the device, so no CPU-GPU synchronization is
    required.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Current atomic coordinates for all systems.
    rebuild_flags : wp.array, shape (num_systems,), dtype=wp.bool
        Per-system rebuild flags.
    batch_idx : wp.array, shape (total_atoms,), dtype=wp.int32
        System index for each atom.
    ref_positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        OUTPUT: Reference positions to update selectively.
    wp_dtype : type
        Warp scalar dtype (wp.float32, wp.float64, or wp.float16).
    device : str
        Warp device string (e.g., 'cuda:0', 'cpu').

    See Also
    --------
    _update_ref_positions_batch_kernel : Underlying warp kernel
    update_ref_positions : Single-system variant
    """
    total_atoms = positions.shape[0]
    wp.launch(
        kernel=_update_ref_positions_batch_overload[wp_dtype],
        dim=total_atoms,
        inputs=[positions, rebuild_flags, batch_idx, ref_positions],
        device=device,
    )