Source code for nvalchemiops.jax.neighbors.rebuild_detection

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JAX bindings for rebuild detection.

This module provides JAX functions for detecting when cell lists and neighbor lists
need to be rebuilt.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental import jax_kernel

from nvalchemiops.neighbors.rebuild_detection import (
    _check_atoms_changed_cells_overload,
    _check_atoms_moved_beyond_skin_overload,
    _check_atoms_moved_beyond_skin_pbc_overload,
    _check_batch_atoms_changed_cells_overload,
    _check_batch_atoms_moved_beyond_skin_overload,
    _check_batch_atoms_moved_beyond_skin_pbc_overload,
)

__all__ = [
    "cell_list_needs_rebuild",
    "neighbor_list_needs_rebuild",
    "check_cell_list_rebuild_needed",
    "check_neighbor_list_rebuild_needed",
    "batch_neighbor_list_needs_rebuild",
    "batch_cell_list_needs_rebuild",
    "check_batch_neighbor_list_rebuild_needed",
    "check_batch_cell_list_rebuild_needed",
]

# ==============================================================================
# JAX Kernel Wrappers
# ==============================================================================


def _wrap_rebuild_kernel(overload, flag_argname: str):
    """Expose one warp rebuild-detection overload as a JAX-callable kernel.

    All rebuild-detection kernels are wrapped identically:

    - exactly one output, the flag array named ``flag_argname``;
    - that flag array is declared as an in/out argument, so the caller
      supplies the initial flags;
    - backward generation is disabled, because the check is not
      differentiable.
    """
    return jax_kernel(
        overload,
        num_outputs=1,
        in_out_argnames=[flag_argname],
        enable_backward=False,
    )


# Single-system kernels take a "rebuild_flag" array; batched kernels take
# "rebuild_flags". Each overload is wrapped once per float precision.

# Cell list rebuild detection (single system)
_jax_check_cells_f32 = _wrap_rebuild_kernel(
    _check_atoms_changed_cells_overload[wp.float32], "rebuild_flag"
)
_jax_check_cells_f64 = _wrap_rebuild_kernel(
    _check_atoms_changed_cells_overload[wp.float64], "rebuild_flag"
)

# Neighbor list rebuild detection (single system, no PBC)
_jax_check_skin_f32 = _wrap_rebuild_kernel(
    _check_atoms_moved_beyond_skin_overload[wp.float32], "rebuild_flag"
)
_jax_check_skin_f64 = _wrap_rebuild_kernel(
    _check_atoms_moved_beyond_skin_overload[wp.float64], "rebuild_flag"
)

# Neighbor list rebuild detection (batched, no PBC)
_jax_batch_check_skin_f32 = _wrap_rebuild_kernel(
    _check_batch_atoms_moved_beyond_skin_overload[wp.float32], "rebuild_flags"
)
_jax_batch_check_skin_f64 = _wrap_rebuild_kernel(
    _check_batch_atoms_moved_beyond_skin_overload[wp.float64], "rebuild_flags"
)

# Neighbor list rebuild detection with minimum-image convention (single system)
_jax_check_skin_pbc_f32 = _wrap_rebuild_kernel(
    _check_atoms_moved_beyond_skin_pbc_overload[wp.float32], "rebuild_flag"
)
_jax_check_skin_pbc_f64 = _wrap_rebuild_kernel(
    _check_atoms_moved_beyond_skin_pbc_overload[wp.float64], "rebuild_flag"
)

# Neighbor list rebuild detection with minimum-image convention (batched)
_jax_batch_check_skin_pbc_f32 = _wrap_rebuild_kernel(
    _check_batch_atoms_moved_beyond_skin_pbc_overload[wp.float32], "rebuild_flags"
)
_jax_batch_check_skin_pbc_f64 = _wrap_rebuild_kernel(
    _check_batch_atoms_moved_beyond_skin_pbc_overload[wp.float64], "rebuild_flags"
)

# Cell list rebuild detection (batched)
_jax_batch_check_cells_f32 = _wrap_rebuild_kernel(
    _check_batch_atoms_changed_cells_overload[wp.float32], "rebuild_flags"
)
_jax_batch_check_cells_f64 = _wrap_rebuild_kernel(
    _check_batch_atoms_changed_cells_overload[wp.float64], "rebuild_flags"
)


# ==============================================================================
# Cell List Rebuild Detection
# ==============================================================================


def cell_list_needs_rebuild(
    current_positions: jax.Array,
    atom_to_cell_mapping: jax.Array,
    cells_per_dimension: jax.Array,
    cell: jax.Array,
    pbc: jax.Array,
) -> jax.Array:
    """Detect if spatial cell list requires rebuilding due to atomic motion.

    Parameters
    ----------
    current_positions : jax.Array, shape (total_atoms, 3)
        Current atomic coordinates in Cartesian space. float64 input selects
        the float64 kernel; any other dtype is cast to float32.
    atom_to_cell_mapping : jax.Array, shape (total_atoms, 3), dtype=int32
        3D cell coordinates for each atom from the existing cell list.
        Cast to int32 before launch.
    cells_per_dimension : jax.Array, shape (3,) or (1, 3), dtype=int32
        Number of spatial cells in x, y, z directions. A 2D input is squeezed
        to 1D and the result cast to int32.
    cell : jax.Array, shape (1, 3, 3)
        Unit cell matrix for coordinate transformations. Cast to the same
        floating-point dtype as the positions handed to the kernel.
    pbc : jax.Array, shape (3,), dtype=bool
        Periodic boundary condition flags for x, y, z directions.

    Returns
    -------
    rebuild_needed : jax.Array, shape (1,), dtype=bool
        True if any atom has moved to a different cell requiring rebuild.
        Always False when there are no atoms.

    Notes
    -----
    This function is not differentiable and should not be used in JAX
    transformations that require gradients.

    See Also
    --------
    nvalchemiops.neighbors.rebuild_detection.check_cell_list_rebuild : Core warp launcher
    check_cell_list_rebuild_needed : Convenience wrapper that returns Python bool
    """
    total_atoms = current_positions.shape[0]
    if total_atoms == 0:
        return jnp.array([False], dtype=jnp.bool_)

    # Decide the kernel precision first, then cast *every* float input to it.
    # Casting ``cell`` to the positions' original dtype (e.g. float16) while
    # the positions are downcast to float32 would hand the f32 kernel
    # mismatched dtypes.
    use_f64 = current_positions.dtype == jnp.float64
    float_dtype = jnp.float64 if use_f64 else jnp.float32
    _jax_check = _jax_check_cells_f64 if use_f64 else _jax_check_cells_f32

    current_positions = current_positions.astype(float_dtype)
    cell = cell.astype(float_dtype)
    pbc = pbc.astype(jnp.bool_)
    atom_to_cell_mapping = atom_to_cell_mapping.astype(jnp.int32)

    # Accept a (1, 3) batch-style shape as well as a flat (3,) vector.
    cells_1d = (
        cells_per_dimension.squeeze()
        if cells_per_dimension.ndim == 2
        else cells_per_dimension
    ).astype(jnp.int32)

    rebuild_flag = jnp.array([False], dtype=jnp.bool_)
    (rebuild_flag,) = _jax_check(
        current_positions,
        cell,
        atom_to_cell_mapping,
        cells_1d,
        pbc,
        rebuild_flag,
        launch_dims=(total_atoms,),
    )
    return rebuild_flag
[docs] def neighbor_list_needs_rebuild( reference_positions: jax.Array, current_positions: jax.Array, skin_distance_threshold: float, cell: jax.Array | None = None, cell_inv: jax.Array | None = None, pbc: jax.Array | None = None, ) -> jax.Array: """Detect if neighbor list requires rebuilding due to excessive atomic motion. When ``cell``, ``cell_inv`` and ``pbc`` are all provided, uses minimum-image convention (MIC) so atoms crossing periodic boundaries are not spuriously flagged. Parameters ---------- reference_positions : jax.Array, shape (total_atoms, 3) Atomic positions when the neighbor list was last built. current_positions : jax.Array, shape (total_atoms, 3) Current atomic positions to compare against reference. skin_distance_threshold : float Maximum allowed displacement before neighbor list becomes invalid. cell : jax.Array or None, optional Unit cell matrix, shape (1, 3, 3). cell_inv : jax.Array or None, optional Inverse cell matrix, same shape as ``cell``. pbc : jax.Array or None, optional PBC flags, shape (3,), dtype=bool. Returns ------- rebuild_needed : jax.Array, shape (1,), dtype=bool True if any atom has moved beyond skin distance. Notes ----- This function is not differentiable and should not be used in JAX transformations that require gradients. 
See Also -------- nvalchemiops.neighbors.rebuild_detection.check_neighbor_list_rebuild : Core warp launcher check_neighbor_list_rebuild_needed : Convenience wrapper that returns Python bool """ if reference_positions.shape != current_positions.shape: return jnp.array([True], dtype=jnp.bool_) total_atoms = reference_positions.shape[0] if total_atoms == 0: return jnp.array([False], dtype=jnp.bool_) rebuild_flag = jnp.array([False], dtype=jnp.bool_) use_pbc = cell is not None and cell_inv is not None and pbc is not None if use_pbc: if cell.dtype != reference_positions.dtype: cell = cell.astype(reference_positions.dtype) if cell_inv.dtype != reference_positions.dtype: cell_inv = cell_inv.astype(reference_positions.dtype) pbc = pbc.astype(jnp.bool_) if reference_positions.dtype == jnp.float64: _jax_check = _jax_check_skin_pbc_f64 else: _jax_check = _jax_check_skin_pbc_f32 reference_positions = reference_positions.astype(jnp.float32) current_positions = current_positions.astype(jnp.float32) (rebuild_flag,) = _jax_check( reference_positions, current_positions, cell, cell_inv, pbc, float(skin_distance_threshold), rebuild_flag, launch_dims=(total_atoms,), ) else: if reference_positions.dtype == jnp.float64: _jax_check = _jax_check_skin_f64 else: _jax_check = _jax_check_skin_f32 reference_positions = reference_positions.astype(jnp.float32) current_positions = current_positions.astype(jnp.float32) (rebuild_flag,) = _jax_check( reference_positions, current_positions, float(skin_distance_threshold), rebuild_flag, launch_dims=(total_atoms,), ) return rebuild_flag
# ==============================================================================
# High-level API Functions
# ==============================================================================


def check_cell_list_rebuild_needed(
    current_positions: jax.Array,
    atom_to_cell_mapping: jax.Array,
    cells_per_dimension: jax.Array,
    cell: jax.Array,
    pbc: jax.Array,
) -> bool:
    """Return, as a plain Python bool, whether the cell list must be rebuilt.

    Host-side convenience wrapper around :func:`cell_list_needs_rebuild`. It
    runs the array-valued check and converts its single flag to ``bool``,
    which forces the result to be materialized.

    Parameters
    ----------
    current_positions : jax.Array, shape (total_atoms, 3)
        Current atomic coordinates to check against existing cell assignments.
    atom_to_cell_mapping : jax.Array, shape (total_atoms, 3), dtype=int32
        3D cell coordinates assigned to each atom from existing cell list.
    cells_per_dimension : jax.Array, shape (3,), dtype=int32
        Number of spatial cells in x, y, z directions from existing cell list.
    cell : jax.Array, shape (1, 3, 3)
        Current unit cell matrix for coordinate transformations.
    pbc : jax.Array, shape (3,), dtype=bool
        Current periodic boundary condition flags for x, y, z directions.

    Returns
    -------
    needs_rebuild : bool
        True if any atom has moved to a different cell requiring cell list rebuild.

    Notes
    -----
    This function is not differentiable and should not be used in JAX
    transformations that require gradients.

    See Also
    --------
    cell_list_needs_rebuild : Returns jax.Array instead of bool
    """
    flag_array = cell_list_needs_rebuild(
        current_positions=current_positions,
        atom_to_cell_mapping=atom_to_cell_mapping,
        cells_per_dimension=cells_per_dimension,
        cell=cell,
        pbc=pbc,
    )
    first_flag = flag_array[0]
    return bool(first_flag)
def check_neighbor_list_rebuild_needed(
    reference_positions: jax.Array,
    current_positions: jax.Array,
    skin_distance_threshold: float,
    cell: jax.Array | None = None,
    cell_inv: jax.Array | None = None,
    pbc: jax.Array | None = None,
) -> bool:
    """Return, as a plain Python bool, whether the neighbor list must be rebuilt.

    Host-side convenience wrapper around :func:`neighbor_list_needs_rebuild`.
    When ``cell``, ``cell_inv`` and ``pbc`` are all given, the underlying check
    uses MIC displacement so periodic boundary crossings are handled correctly.

    Parameters
    ----------
    reference_positions : jax.Array, shape (total_atoms, 3)
        Atomic coordinates when the neighbor list was last constructed.
    current_positions : jax.Array, shape (total_atoms, 3)
        Current atomic coordinates to compare against reference positions.
    skin_distance_threshold : float
        Maximum allowed atomic displacement before neighbor list becomes invalid.
    cell : jax.Array or None, optional
        Unit cell matrix, shape (1, 3, 3).
    cell_inv : jax.Array or None, optional
        Inverse cell matrix, same shape as ``cell``.
    pbc : jax.Array or None, optional
        PBC flags, shape (3,), dtype=bool.

    Returns
    -------
    needs_rebuild : bool
        True if any atom has moved beyond skin distance requiring rebuild.

    See Also
    --------
    neighbor_list_needs_rebuild : Returns jax.Array instead of bool
    """
    flag_array = neighbor_list_needs_rebuild(
        reference_positions=reference_positions,
        current_positions=current_positions,
        skin_distance_threshold=skin_distance_threshold,
        cell=cell,
        cell_inv=cell_inv,
        pbc=pbc,
    )
    first_flag = flag_array[0]
    return bool(first_flag)
# ============================================================================== # Batch Rebuild Detection # ============================================================================== def batch_neighbor_list_needs_rebuild( reference_positions: jax.Array, current_positions: jax.Array, batch_idx: jax.Array, skin_distance_threshold: float, num_systems: int, cell: jax.Array | None = None, cell_inv: jax.Array | None = None, pbc: jax.Array | None = None, ) -> jax.Array: """Detect per-system if neighbor lists require rebuilding due to atomic motion. When ``cell``, ``cell_inv`` and ``pbc`` are all provided, uses MIC displacement so periodic boundary crossings are handled correctly. Parameters ---------- reference_positions : jax.Array, shape (total_atoms, 3) Atomic positions when each system's neighbor list was last built. current_positions : jax.Array, shape (total_atoms, 3) Current atomic positions to compare against reference. batch_idx : jax.Array, shape (total_atoms,), dtype=int32 System index for each atom. skin_distance_threshold : float Maximum allowed displacement before neighbor list becomes invalid. num_systems : int Number of systems in the batch. cell : jax.Array or None, optional Per-system cell matrices, shape (num_systems, 3, 3). cell_inv : jax.Array or None, optional Inverse cell matrices, same shape as ``cell``. pbc : jax.Array or None, optional PBC flags, shape (num_systems, 3), dtype=bool. Returns ------- rebuild_flags : jax.Array, shape (num_systems,), dtype=bool Per-system flags; True if any atom in that system moved beyond skin distance. Notes ----- This function is not differentiable and should not be used in JAX transformations that require gradients. 
See Also -------- neighbor_list_needs_rebuild : Single-system version check_batch_neighbor_list_rebuild_needed : Convenience wrapper """ total_atoms = reference_positions.shape[0] if total_atoms == 0: return jnp.zeros(num_systems, dtype=jnp.bool_) rebuild_flags = jnp.zeros(num_systems, dtype=jnp.bool_) use_pbc = cell is not None and cell_inv is not None and pbc is not None if use_pbc: if cell.dtype != reference_positions.dtype: cell = cell.astype(reference_positions.dtype) if cell_inv.dtype != reference_positions.dtype: cell_inv = cell_inv.astype(reference_positions.dtype) pbc = pbc.astype(jnp.bool_) if reference_positions.dtype == jnp.float64: _jax_check = _jax_batch_check_skin_pbc_f64 else: _jax_check = _jax_batch_check_skin_pbc_f32 reference_positions = reference_positions.astype(jnp.float32) current_positions = current_positions.astype(jnp.float32) (rebuild_flags,) = _jax_check( reference_positions, current_positions, batch_idx, cell, cell_inv, pbc, float(skin_distance_threshold), rebuild_flags, launch_dims=(total_atoms,), ) else: if reference_positions.dtype == jnp.float64: _jax_check = _jax_batch_check_skin_f64 else: _jax_check = _jax_batch_check_skin_f32 reference_positions = reference_positions.astype(jnp.float32) current_positions = current_positions.astype(jnp.float32) (rebuild_flags,) = _jax_check( reference_positions, current_positions, batch_idx, float(skin_distance_threshold), rebuild_flags, launch_dims=(total_atoms,), ) return rebuild_flags def batch_cell_list_needs_rebuild( current_positions: jax.Array, atom_to_cell_mapping: jax.Array, batch_idx: jax.Array, cells_per_dimension: jax.Array, cell: jax.Array, pbc: jax.Array, ) -> jax.Array: """Detect per-system if cell lists require rebuilding due to atomic motion. Parameters ---------- current_positions : jax.Array, shape (total_atoms, 3) Current atomic coordinates in Cartesian space. 
atom_to_cell_mapping : jax.Array, shape (total_atoms, 3), dtype=int32 3D cell coordinates for each atom from the existing cell lists. batch_idx : jax.Array, shape (total_atoms,), dtype=int32 System index for each atom. cells_per_dimension : jax.Array, shape (num_systems, 3), dtype=int32 Number of spatial cells in x, y, z directions per system. cell : jax.Array, shape (num_systems, 3, 3) Per-system unit cell matrices for coordinate transformations. pbc : jax.Array, shape (num_systems, 3), dtype=bool Per-system periodic boundary condition flags. Returns ------- rebuild_flags : jax.Array, shape (num_systems,), dtype=bool Per-system flags; True if any atom in that system changed cells. Notes ----- This function is not differentiable and should not be used in JAX transformations that require gradients. See Also -------- cell_list_needs_rebuild : Single-system version check_batch_cell_list_rebuild_needed : Convenience wrapper returning list[bool] """ total_atoms = current_positions.shape[0] num_systems = cell.shape[0] if total_atoms == 0: return jnp.zeros(num_systems, dtype=jnp.bool_) if cell.dtype != current_positions.dtype: cell = cell.astype(current_positions.dtype) pbc = pbc.astype(jnp.bool_) rebuild_flags = jnp.zeros(num_systems, dtype=jnp.bool_) if current_positions.dtype == jnp.float64: _jax_check = _jax_batch_check_cells_f64 else: _jax_check = _jax_batch_check_cells_f32 current_positions = current_positions.astype(jnp.float32) (rebuild_flags,) = _jax_check( current_positions, cell, atom_to_cell_mapping, batch_idx, cells_per_dimension, pbc, rebuild_flags, launch_dims=(total_atoms,), ) return rebuild_flags # ============================================================================== # High-level Batch API Functions # ============================================================================== def check_batch_neighbor_list_rebuild_needed( reference_positions: jax.Array, current_positions: jax.Array, batch_idx: jax.Array, skin_distance_threshold: float, 
num_systems: int, cell: jax.Array | None = None, cell_inv: jax.Array | None = None, pbc: jax.Array | None = None, ) -> list[bool]: """Determine per-system if neighbor lists require rebuilding. When ``cell``, ``cell_inv`` and ``pbc`` are all provided, uses MIC displacement so periodic boundary crossings are handled correctly. Parameters ---------- reference_positions : jax.Array, shape (total_atoms, 3) Atomic positions when each system's neighbor list was last built. current_positions : jax.Array, shape (total_atoms, 3) Current atomic positions to compare against reference. batch_idx : jax.Array, shape (total_atoms,), dtype=int32 System index for each atom. skin_distance_threshold : float Maximum allowed displacement before neighbor list becomes invalid. num_systems : int Number of systems in the batch. cell : jax.Array or None, optional Per-system cell matrices, shape (num_systems, 3, 3). cell_inv : jax.Array or None, optional Inverse cell matrices, same shape as ``cell``. pbc : jax.Array or None, optional PBC flags, shape (num_systems, 3), dtype=bool. Returns ------- needs_rebuild : list[bool] Per-system flags; True if neighbor list for that system needs rebuilding. See Also -------- batch_neighbor_list_needs_rebuild : Returns jax.Array instead of list[bool] """ rebuild_flags = batch_neighbor_list_needs_rebuild( reference_positions, current_positions, batch_idx, skin_distance_threshold, num_systems, cell, cell_inv, pbc, ) return [bool(flag) for flag in rebuild_flags] def check_batch_cell_list_rebuild_needed( current_positions: jax.Array, atom_to_cell_mapping: jax.Array, batch_idx: jax.Array, cells_per_dimension: jax.Array, cell: jax.Array, pbc: jax.Array, ) -> list[bool]: """Determine per-system if cell lists require rebuilding based on atomic motion. Parameters ---------- current_positions : jax.Array, shape (total_atoms, 3) Current atomic coordinates in Cartesian space. 
atom_to_cell_mapping : jax.Array, shape (total_atoms, 3), dtype=int32 3D cell coordinates for each atom from the existing cell lists. batch_idx : jax.Array, shape (total_atoms,), dtype=int32 System index for each atom. cells_per_dimension : jax.Array, shape (num_systems, 3), dtype=int32 Number of spatial cells in x, y, z directions per system. cell : jax.Array, shape (num_systems, 3, 3) Per-system unit cell matrices for coordinate transformations. pbc : jax.Array, shape (num_systems, 3), dtype=bool Per-system periodic boundary condition flags. Returns ------- needs_rebuild : list[bool] Per-system flags; True if cell list for that system needs rebuilding. Notes ----- This function is not differentiable and should not be used in JAX transformations that require gradients. See Also -------- batch_cell_list_needs_rebuild : Returns jax.Array instead of list[bool] """ rebuild_flags = batch_cell_list_needs_rebuild( current_positions, atom_to_cell_mapping, batch_idx, cells_per_dimension, cell, pbc, ) return [bool(flag) for flag in rebuild_flags]