# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch bindings for rebuild detection.
This module provides PyTorch custom operators for detecting when cell lists and
neighbor lists need to be rebuilt.
"""
from __future__ import annotations
import torch
import warp as wp
from nvalchemiops.neighbors.rebuild_detection import (
check_batch_cell_list_rebuild,
check_batch_neighbor_list_rebuild,
check_cell_list_rebuild,
check_neighbor_list_rebuild,
)
from nvalchemiops.torch.types import get_wp_dtype, get_wp_mat_dtype, get_wp_vec_dtype
__all__ = [
"cell_list_needs_rebuild",
"neighbor_list_needs_rebuild",
"check_cell_list_rebuild_needed",
"check_neighbor_list_rebuild_needed",
"batch_neighbor_list_needs_rebuild",
"batch_cell_list_needs_rebuild",
]
###########################################################################################
########################### Cell List Rebuild Detection ###################################
###########################################################################################
@torch.library.custom_op("nvalchemiops::_cell_list_needs_rebuild", mutates_args=())
def _cell_list_needs_rebuild(
current_positions: torch.Tensor,
atom_to_cell_mapping: torch.Tensor,
cells_per_dimension: torch.Tensor,
cell: torch.Tensor,
pbc: torch.Tensor,
) -> torch.Tensor:
"""Detect if spatial cell list requires rebuilding due to atomic motion.
Parameters
----------
current_positions : torch.Tensor, shape (total_atoms, 3)
Current atomic coordinates in Cartesian space.
atom_to_cell_mapping : torch.Tensor, shape (total_atoms, 3), dtype=int32
3D cell coordinates for each atom from the existing cell list.
cells_per_dimension : torch.Tensor, shape (3,), dtype=int32
Number of spatial cells in x, y, z directions.
cell : torch.Tensor, shape (1, 3, 3)
Unit cell matrix for coordinate transformations.
pbc : torch.Tensor, shape (3,), dtype=bool
Periodic boundary condition flags for x, y, z directions.
Returns
-------
rebuild_needed : torch.Tensor, shape (1,), dtype=bool
True if any atom has moved to a different cell requiring rebuild.
See Also
--------
nvalchemiops.neighborlist.rebuild_detection.wp_check_cell_list_rebuild : Core warp launcher
cell_list_needs_rebuild : High-level wrapper function
"""
total_atoms = current_positions.shape[0]
device = current_positions.device
pbc = pbc.squeeze(0)
if total_atoms == 0:
return torch.tensor([False], device=device, dtype=torch.bool)
# Get warp data types for the input tensor precision
wp_dtype = get_wp_dtype(current_positions.dtype)
wp_vec_dtype = get_wp_vec_dtype(current_positions.dtype)
wp_mat_dtype = get_wp_mat_dtype(current_positions.dtype)
# Convert PyTorch tensors to warp arrays
wp_current_positions = wp.from_torch(
current_positions, dtype=wp_vec_dtype, return_ctype=True
)
wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype, return_ctype=True)
wp_pbc = wp.from_torch(pbc, dtype=wp.bool, return_ctype=True)
wp_atom_to_cell_mapping = wp.from_torch(
atom_to_cell_mapping, dtype=wp.vec3i, return_ctype=True
)
wp_cells_per_dimension = wp.from_torch(
cells_per_dimension, dtype=wp.int32, return_ctype=True
)
# Initialize rebuild flag (False = no rebuild needed)
rebuild_needed = torch.tensor([False], device=device, dtype=torch.bool)
wp_rebuild_flag = wp.from_torch(rebuild_needed, dtype=wp.bool, return_ctype=True)
# Call core warp launcher
check_cell_list_rebuild(
current_positions=wp_current_positions,
atom_to_cell_mapping=wp_atom_to_cell_mapping,
cells_per_dimension=wp_cells_per_dimension,
cell=wp_cell,
pbc=wp_pbc,
rebuild_flag=wp_rebuild_flag,
wp_dtype=wp_dtype,
device=str(device),
)
return rebuild_needed
@_cell_list_needs_rebuild.register_fake
def _cell_list_needs_rebuild_fake(
current_positions: torch.Tensor,
atom_to_cell_mapping: torch.Tensor,
cells_per_dimension: torch.Tensor,
cell: torch.Tensor,
pbc: torch.Tensor,
) -> torch.Tensor:
"""Fake implementation for torch.compile compatibility.
Returns a conservative default (no rebuild needed) for compilation tracing.
The actual implementation will be called during runtime execution.
"""
return torch.tensor([False], device=current_positions.device, dtype=torch.bool)
[docs]
def cell_list_needs_rebuild(
current_positions: torch.Tensor,
atom_to_cell_mapping: torch.Tensor,
cells_per_dimension: torch.Tensor,
cell: torch.Tensor,
pbc: torch.Tensor,
) -> torch.Tensor:
"""Detect if spatial cell list requires rebuilding due to atomic motion.
This torch.compile-compatible custom operator efficiently determines if any atoms
have moved between spatial cells since the last cell list construction. Uses GPU
acceleration with early termination for optimal performance.
Parameters
----------
current_positions : torch.Tensor, shape (total_atoms, 3)
Current atomic coordinates in Cartesian space.
atom_to_cell_mapping : torch.Tensor, shape (total_atoms, 3), dtype=int32
3D cell coordinates for each atom from the existing cell list.
Typically obtained from build_cell_list.
cells_per_dimension : torch.Tensor, shape (3,), dtype=int32
Number of spatial cells in x, y, z directions.
cell : torch.Tensor, shape (1, 3, 3)
Unit cell matrix for coordinate transformations.
pbc : torch.Tensor, shape (3,), dtype=bool
Periodic boundary condition flags for x, y, z directions.
Returns
-------
rebuild_needed : torch.Tensor, shape (1,), dtype=bool
True if any atom has moved to a different cell requiring rebuild.
Notes
-----
- Currently only supports single system.
- torch.compile compatible custom operation
- Uses GPU kernels for parallel cell assignment computation
- Early termination optimization stops computation once rebuild is detected
- Handles periodic boundary conditions correctly
- Returns tensor (not Python bool) for compilation compatibility
See Also
--------
nvalchemiops.neighborlist.rebuild_detection.wp_check_cell_list_rebuild : Core warp launcher
check_cell_list_rebuild_needed : Convenience wrapper that returns Python bool
"""
return _cell_list_needs_rebuild(
current_positions,
atom_to_cell_mapping,
cells_per_dimension,
cell,
pbc,
)
###########################################################################################
########################### Neighbor List Rebuild Detection ##############################
###########################################################################################
@torch.library.custom_op(
"nvalchemiops::_neighbor_list_needs_rebuild", mutates_args=("reference_positions",)
)
def _neighbor_list_needs_rebuild(
reference_positions: torch.Tensor,
current_positions: torch.Tensor,
skin_distance_threshold: float,
update_reference_positions: bool = False,
cell: torch.Tensor | None = None,
cell_inv: torch.Tensor | None = None,
pbc: torch.Tensor | None = None,
) -> torch.Tensor:
"""Detect if neighbor list requires rebuilding due to excessive atomic motion.
Parameters
----------
reference_positions : torch.Tensor, shape (total_atoms, 3)
Atomic positions when the neighbor list was last built.
current_positions : torch.Tensor, shape (total_atoms, 3)
Current atomic positions to compare against reference.
skin_distance_threshold : float
Maximum allowed displacement before neighbor list becomes invalid.
update_reference_positions : bool, default=False
If True, overwrite ``reference_positions`` with ``current_positions``
after a rebuild is detected. Uses a separate deterministic kernel launch.
cell : torch.Tensor or None, optional
Unit cell matrix, shape (1, 3, 3). Required together with
``cell_inv`` and ``pbc`` to enable MIC displacement.
cell_inv : torch.Tensor or None, optional
Inverse cell matrix, same shape as ``cell``.
pbc : torch.Tensor or None, optional
PBC flags, shape (1, 3) or (3,), dtype=bool.
Returns
-------
rebuild_needed : torch.Tensor, shape (1,), dtype=bool
True if any atom has moved beyond skin distance.
See Also
--------
neighbor_list_needs_rebuild : High-level wrapper function
"""
if reference_positions.shape != current_positions.shape:
return torch.tensor([True], device=current_positions.device, dtype=torch.bool)
total_atoms = reference_positions.shape[0]
device = reference_positions.device
if total_atoms == 0:
return torch.tensor([False], device=device, dtype=torch.bool)
wp_dtype = get_wp_dtype(reference_positions.dtype)
wp_vec_dtype = get_wp_vec_dtype(reference_positions.dtype)
wp_reference_positions = wp.from_torch(
reference_positions, dtype=wp_vec_dtype, return_ctype=True
)
wp_current_positions = wp.from_torch(
current_positions, dtype=wp_vec_dtype, return_ctype=True
)
rebuild_needed = torch.tensor([False], device=device, dtype=torch.bool)
wp_rebuild_flag = wp.from_torch(rebuild_needed, dtype=wp.bool, return_ctype=True)
wp_cell = wp_cell_inv = wp_pbc = None
if cell is not None and cell_inv is not None and pbc is not None:
wp_mat_dtype = get_wp_mat_dtype(reference_positions.dtype)
wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype, return_ctype=True)
wp_cell_inv = wp.from_torch(cell_inv, dtype=wp_mat_dtype, return_ctype=True)
wp_pbc = wp.from_torch(pbc.squeeze(0), dtype=wp.bool, return_ctype=True)
check_neighbor_list_rebuild(
reference_positions=wp_reference_positions,
current_positions=wp_current_positions,
skin_distance_threshold=skin_distance_threshold,
rebuild_flag=wp_rebuild_flag,
wp_dtype=wp_dtype,
device=str(device),
update_reference_positions=update_reference_positions,
cell=wp_cell,
cell_inv=wp_cell_inv,
pbc=wp_pbc,
)
return rebuild_needed
@_neighbor_list_needs_rebuild.register_fake
def _neighbor_list_needs_rebuild_fake(
reference_positions: torch.Tensor,
current_positions: torch.Tensor,
skin_distance_threshold: float,
update_reference_positions: bool = False,
cell: torch.Tensor | None = None,
cell_inv: torch.Tensor | None = None,
pbc: torch.Tensor | None = None,
) -> torch.Tensor:
"""Fake implementation for torch.compile compatibility."""
return torch.tensor([False], device=current_positions.device, dtype=torch.bool)
[docs]
def neighbor_list_needs_rebuild(
reference_positions: torch.Tensor,
current_positions: torch.Tensor,
skin_distance_threshold: float,
update_reference_positions: bool = False,
cell: torch.Tensor | None = None,
cell_inv: torch.Tensor | None = None,
pbc: torch.Tensor | None = None,
) -> torch.Tensor:
"""Detect if neighbor list requires rebuilding due to excessive atomic motion.
This torch.compile-compatible custom operator efficiently determines if any atoms
have moved beyond the skin distance since the neighbor list was last built. Uses
GPU acceleration with early termination for optimal performance in MD simulations.
When ``cell``, ``cell_inv`` and ``pbc`` are all provided, uses minimum-image
convention (MIC) so atoms crossing periodic boundaries are not spuriously
flagged.
Parameters
----------
reference_positions : torch.Tensor, shape (total_atoms, 3)
Atomic coordinates when the neighbor list was last constructed.
current_positions : torch.Tensor, shape (total_atoms, 3)
Current atomic coordinates to compare against reference.
skin_distance_threshold : float
Maximum allowed atomic displacement before neighbor list becomes invalid.
Typically set to (cutoff_radius - cutoff) / 2 for safety.
update_reference_positions : bool, default=False
If True, overwrite ``reference_positions`` with ``current_positions``
after a rebuild is detected. Uses a separate deterministic kernel launch
so all atoms are guaranteed to be updated.
cell : torch.Tensor or None, optional
Unit cell matrix, shape (1, 3, 3). Required together with
``cell_inv`` and ``pbc`` to enable MIC displacement.
cell_inv : torch.Tensor or None, optional
Inverse cell matrix, same shape as ``cell``.
pbc : torch.Tensor or None, optional
PBC flags, shape (1, 3) or (3,), dtype=bool.
Returns
-------
rebuild_needed : torch.Tensor, shape (1,), dtype=bool
True if any atom has moved beyond skin distance requiring rebuild.
Notes
-----
- Currently only supports single system.
- torch.compile compatible custom operation
- Uses GPU kernels for parallel displacement computation
- Early termination optimization stops computation once rebuild is detected
- When cell/cell_inv/pbc are supplied, uses MIC displacement; otherwise
Euclidean distance.
- Returns tensor (not Python bool) for compilation compatibility
See Also
--------
check_neighbor_list_rebuild_needed : Convenience wrapper that returns Python bool
"""
return _neighbor_list_needs_rebuild(
reference_positions,
current_positions,
skin_distance_threshold,
update_reference_positions,
cell,
cell_inv,
pbc,
)
###########################################################################################
########################### High-level API Functions ######################################
###########################################################################################
[docs]
def check_cell_list_rebuild_needed(
current_positions: torch.Tensor,
atom_to_cell_mapping: torch.Tensor,
cells_per_dimension: torch.Tensor,
cell: torch.Tensor,
pbc: torch.Tensor,
) -> bool:
"""Determine if spatial cell list requires rebuilding based on atomic motion.
This high-level convenience function determines if a spatial cell list needs to be
reconstructed due to atomic movement. It uses GPU acceleration to efficiently detect
when atoms have moved between spatial cells.
The function checks if any atoms have moved to different spatial cells since the
cell list was last built by comparing current positions against the stored cell
assignments from the existing cell list.
This function is not torch.compile compatible (use cell_list_needs_rebuild for that).
Parameters
----------
current_positions : torch.Tensor, shape (total_atoms, 3)
Current atomic coordinates to check against existing cell assignments.
atom_to_cell_mapping : torch.Tensor, shape (total_atoms, 3), dtype=int32
3D cell coordinates assigned to each atom from existing cell list.
This is the key tensor used for comparison with current positions.
Typically obtained from build_cell_list.
cells_per_dimension : torch.Tensor, shape (3,), dtype=int32
Number of spatial cells in x, y, z directions from existing cell list.
cell : torch.Tensor, shape (1, 3, 3)
Current unit cell matrix for coordinate transformations.
pbc : torch.Tensor, shape (3,), dtype=bool
Current periodic boundary condition flags for x, y, z directions.
Returns
-------
needs_rebuild : torch.Tensor, shape (1,), dtype=bool
True if any atom has moved to a different cell requiring cell list rebuild.
Notes
-----
- Currently only supports single system.
- Uses GPU kernels for efficient parallel computation
- Primary check: atomic motion between spatial cells
- Early termination optimization for performance
- Returns Python bool (calls .item() on tensor result)
See Also
--------
cell_list_needs_rebuild : Returns tensor instead of bool (torch.compile compatible)
nvalchemiops.neighborlist.rebuild_detection.wp_check_cell_list_rebuild : Core warp launcher
"""
rebuild_tensor = cell_list_needs_rebuild(
current_positions,
atom_to_cell_mapping,
cells_per_dimension,
cell,
pbc,
)
return rebuild_tensor
[docs]
def check_neighbor_list_rebuild_needed(
reference_positions: torch.Tensor,
current_positions: torch.Tensor,
skin_distance_threshold: float,
update_reference_positions: bool = False,
cell: torch.Tensor | None = None,
cell_inv: torch.Tensor | None = None,
pbc: torch.Tensor | None = None,
) -> torch.Tensor:
"""Determine if neighbor list requires rebuilding based on atomic motion.
When ``cell``, ``cell_inv`` and ``pbc`` are all provided, uses MIC
displacement so periodic boundary crossings are handled correctly.
This function is not torch.compile compatible.
Parameters
----------
reference_positions : torch.Tensor, shape (total_atoms, 3)
Atomic coordinates when the neighbor list was last constructed.
current_positions : torch.Tensor, shape (total_atoms, 3)
Current atomic coordinates to compare against reference positions.
skin_distance_threshold : float
Maximum allowed atomic displacement before neighbor list becomes invalid.
update_reference_positions : bool, default=False
If True, overwrite ``reference_positions`` after a rebuild is detected.
cell : torch.Tensor or None, optional
Unit cell matrix, shape (1, 3, 3).
cell_inv : torch.Tensor or None, optional
Inverse cell matrix, same shape as ``cell``.
pbc : torch.Tensor or None, optional
PBC flags, shape (1, 3) or (3,), dtype=bool.
Returns
-------
needs_rebuild : torch.Tensor, shape (1,), dtype=bool
True if any atom has moved beyond skin distance requiring rebuild.
See Also
--------
neighbor_list_needs_rebuild : Returns tensor (torch.compile compatible)
"""
rebuild_tensor = neighbor_list_needs_rebuild(
reference_positions,
current_positions,
skin_distance_threshold,
update_reference_positions,
cell,
cell_inv,
pbc,
)
return rebuild_tensor
###########################################################################################
########################### Batch Neighbor List Rebuild Detection ########################
###########################################################################################
@torch.library.custom_op(
"nvalchemiops::_batch_neighbor_list_needs_rebuild",
mutates_args=("reference_positions",),
)
def _batch_neighbor_list_needs_rebuild(
reference_positions: torch.Tensor,
current_positions: torch.Tensor,
batch_idx: torch.Tensor,
skin_distance_threshold: float,
update_reference_positions: bool = False,
cell: torch.Tensor | None = None,
cell_inv: torch.Tensor | None = None,
pbc: torch.Tensor | None = None,
) -> torch.Tensor:
"""Detect per-system if neighbor lists require rebuilding due to atomic motion.
Parameters
----------
reference_positions : torch.Tensor, shape (total_atoms, 3)
Atomic positions when each system's neighbor list was last built.
current_positions : torch.Tensor, shape (total_atoms, 3)
Current atomic positions to compare against reference.
batch_idx : torch.Tensor, shape (total_atoms,), dtype=int32
System index for each atom.
skin_distance_threshold : float
Maximum allowed displacement before neighbor list becomes invalid.
update_reference_positions : bool, default=False
If True, overwrite ``reference_positions`` with ``current_positions``
after a rebuild is detected. Uses a separate deterministic kernel launch.
cell : torch.Tensor or None, optional
Per-system cell matrices, shape (num_systems, 3, 3).
cell_inv : torch.Tensor or None, optional
Inverse cell matrices, same shape as ``cell``.
pbc : torch.Tensor or None, optional
PBC flags, shape (num_systems, 3), dtype=bool.
Returns
-------
rebuild_flags : torch.Tensor, shape (num_systems,), dtype=bool
Per-system flags: True if any atom in that system moved beyond skin distance.
See Also
--------
batch_neighbor_list_needs_rebuild : High-level wrapper function
"""
if reference_positions.shape != current_positions.shape:
num_systems = int(batch_idx.max().item()) + 1 if batch_idx.numel() > 0 else 1
return torch.ones(
num_systems, device=current_positions.device, dtype=torch.bool
)
total_atoms = reference_positions.shape[0]
device = reference_positions.device
num_systems = int(batch_idx.max().item()) + 1 if total_atoms > 0 else 1
rebuild_flags = torch.zeros(num_systems, device=device, dtype=torch.bool)
if total_atoms == 0:
return rebuild_flags
wp_dtype = get_wp_dtype(reference_positions.dtype)
wp_vec_dtype = get_wp_vec_dtype(reference_positions.dtype)
wp_reference = wp.from_torch(
reference_positions, dtype=wp_vec_dtype, return_ctype=True
)
wp_current = wp.from_torch(current_positions, dtype=wp_vec_dtype, return_ctype=True)
wp_batch_idx = wp.from_torch(
batch_idx.to(dtype=torch.int32), dtype=wp.int32, return_ctype=True
)
wp_rebuild_flags = wp.from_torch(rebuild_flags, dtype=wp.bool, return_ctype=True)
wp_cell = wp_cell_inv = wp_pbc = None
if cell is not None and cell_inv is not None and pbc is not None:
wp_mat_dtype = get_wp_mat_dtype(reference_positions.dtype)
wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype, return_ctype=True)
wp_cell_inv = wp.from_torch(cell_inv, dtype=wp_mat_dtype, return_ctype=True)
wp_pbc = wp.from_torch(pbc, dtype=wp.bool, return_ctype=True)
check_batch_neighbor_list_rebuild(
reference_positions=wp_reference,
current_positions=wp_current,
batch_idx=wp_batch_idx,
skin_distance_threshold=skin_distance_threshold,
rebuild_flags=wp_rebuild_flags,
wp_dtype=wp_dtype,
device=str(device),
update_reference_positions=update_reference_positions,
cell=wp_cell,
cell_inv=wp_cell_inv,
pbc=wp_pbc,
)
return rebuild_flags
@_batch_neighbor_list_needs_rebuild.register_fake
def _batch_neighbor_list_needs_rebuild_fake(
reference_positions: torch.Tensor,
current_positions: torch.Tensor,
batch_idx: torch.Tensor,
skin_distance_threshold: float,
update_reference_positions: bool = False,
cell: torch.Tensor | None = None,
cell_inv: torch.Tensor | None = None,
pbc: torch.Tensor | None = None,
) -> torch.Tensor:
"""Fake implementation for torch.compile compatibility."""
num_systems = batch_idx.max() + 1 if batch_idx.numel() > 0 else 1
return torch.zeros(num_systems, device=reference_positions.device, dtype=torch.bool)
def batch_neighbor_list_needs_rebuild(
reference_positions: torch.Tensor,
current_positions: torch.Tensor,
batch_idx: torch.Tensor,
skin_distance_threshold: float,
update_reference_positions: bool = False,
cell: torch.Tensor | None = None,
cell_inv: torch.Tensor | None = None,
pbc: torch.Tensor | None = None,
) -> torch.Tensor:
"""Detect per-system if neighbor lists require rebuilding due to atomic motion.
This torch.compile-compatible custom operator efficiently determines which systems
in a batch need their neighbor list rebuilt based on atomic displacements. Uses
GPU-side flagging with no CPU-GPU synchronization.
When ``cell``, ``cell_inv`` and ``pbc`` are all provided, uses MIC
displacement so periodic boundary crossings are handled correctly.
Parameters
----------
reference_positions : torch.Tensor, shape (total_atoms, 3)
Atomic positions when each system's neighbor list was last built.
current_positions : torch.Tensor, shape (total_atoms, 3)
Current Cartesian coordinates to compare against reference.
batch_idx : torch.Tensor, shape (total_atoms,), dtype=int32
System index for each atom.
skin_distance_threshold : float
Maximum allowed atomic displacement before neighbor list becomes invalid.
update_reference_positions : bool, default=False
If True, overwrite ``reference_positions`` with ``current_positions``
after a rebuild is detected. Uses a separate deterministic kernel launch
so all atoms in rebuilt systems are guaranteed to be updated.
cell : torch.Tensor or None, optional
Per-system cell matrices, shape (num_systems, 3, 3).
cell_inv : torch.Tensor or None, optional
Inverse cell matrices, same shape as ``cell``.
pbc : torch.Tensor or None, optional
PBC flags, shape (num_systems, 3), dtype=bool.
Returns
-------
rebuild_flags : torch.Tensor, shape (num_systems,), dtype=bool
Per-system flags: True if any atom in that system moved beyond the skin
distance.
Notes
-----
- torch.compile compatible custom operation
- No CPU-GPU synchronization required; all flag writes happen on GPU
- ``num_systems`` is inferred as ``batch_idx.max() + 1``
See Also
--------
neighbor_list_needs_rebuild : Single-system version
"""
return _batch_neighbor_list_needs_rebuild(
reference_positions,
current_positions,
batch_idx,
skin_distance_threshold,
update_reference_positions,
cell,
cell_inv,
pbc,
)
###########################################################################################
########################### Batch Cell List Rebuild Detection ############################
###########################################################################################
@torch.library.custom_op(
"nvalchemiops::_batch_cell_list_needs_rebuild", mutates_args=()
)
def _batch_cell_list_needs_rebuild(
current_positions: torch.Tensor,
atom_to_cell_mapping: torch.Tensor,
batch_idx: torch.Tensor,
cells_per_dimension: torch.Tensor,
cell: torch.Tensor,
pbc: torch.Tensor,
) -> torch.Tensor:
"""Detect per-system if spatial cell lists require rebuilding due to atomic motion.
Parameters
----------
current_positions : torch.Tensor, shape (total_atoms, 3)
Current Cartesian coordinates.
atom_to_cell_mapping : torch.Tensor, shape (total_atoms, 3), dtype=int32
3D cell coordinates for each atom from the existing cell lists.
batch_idx : torch.Tensor, shape (total_atoms,), dtype=int32
System index for each atom.
cells_per_dimension : torch.Tensor, shape (num_systems, 3), dtype=int32
Number of spatial cells in x, y, z directions for each system.
cell : torch.Tensor, shape (num_systems, 3, 3)
Per-system unit cell matrices for coordinate transformations.
pbc : torch.Tensor, shape (num_systems, 3), dtype=bool
Per-system periodic boundary condition flags.
Returns
-------
rebuild_flags : torch.Tensor, shape (num_systems,), dtype=bool
Per-system flags: True if any atom in that system changed cells.
See Also
--------
batch_cell_list_needs_rebuild : High-level wrapper function
"""
total_atoms = current_positions.shape[0]
num_systems = cell.shape[0]
device = current_positions.device
rebuild_flags = torch.zeros(num_systems, device=device, dtype=torch.bool)
if total_atoms == 0:
return rebuild_flags
wp_dtype = get_wp_dtype(current_positions.dtype)
wp_vec_dtype = get_wp_vec_dtype(current_positions.dtype)
wp_mat_dtype = get_wp_mat_dtype(current_positions.dtype)
wp_current = wp.from_torch(current_positions, dtype=wp_vec_dtype, return_ctype=True)
wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype, return_ctype=True)
wp_atom_to_cell_mapping = wp.from_torch(
atom_to_cell_mapping, dtype=wp.vec3i, return_ctype=True
)
wp_batch_idx = wp.from_torch(
batch_idx.to(dtype=torch.int32), dtype=wp.int32, return_ctype=True
)
wp_cells_per_dimension = wp.from_torch(
cells_per_dimension, dtype=wp.vec3i, return_ctype=True
)
# pbc shape (num_systems, 3) → 2D warp array of bool
wp_pbc = wp.from_torch(pbc, dtype=wp.bool, return_ctype=True)
wp_rebuild_flags = wp.from_torch(rebuild_flags, dtype=wp.bool, return_ctype=True)
check_batch_cell_list_rebuild(
current_positions=wp_current,
atom_to_cell_mapping=wp_atom_to_cell_mapping,
batch_idx=wp_batch_idx,
cells_per_dimension=wp_cells_per_dimension,
cell=wp_cell,
pbc=wp_pbc,
rebuild_flags=wp_rebuild_flags,
wp_dtype=wp_dtype,
device=str(device),
)
return rebuild_flags
@_batch_cell_list_needs_rebuild.register_fake
def _batch_cell_list_needs_rebuild_fake(
current_positions: torch.Tensor,
atom_to_cell_mapping: torch.Tensor,
batch_idx: torch.Tensor,
cells_per_dimension: torch.Tensor,
cell: torch.Tensor,
pbc: torch.Tensor,
) -> torch.Tensor:
"""Fake implementation for torch.compile compatibility."""
num_systems = cell.shape[0]
return torch.zeros(num_systems, device=current_positions.device, dtype=torch.bool)
def batch_cell_list_needs_rebuild(
current_positions: torch.Tensor,
atom_to_cell_mapping: torch.Tensor,
batch_idx: torch.Tensor,
cells_per_dimension: torch.Tensor,
cell: torch.Tensor,
pbc: torch.Tensor,
) -> torch.Tensor:
"""Detect per-system if spatial cell lists require rebuilding due to atomic motion.
This torch.compile-compatible custom operator efficiently determines which systems
in a batch need their cell list rebuilt by checking if any atoms have moved between
spatial cells. Uses GPU-side flagging with no CPU-GPU synchronization.
Parameters
----------
current_positions : torch.Tensor, shape (total_atoms, 3)
Current Cartesian coordinates.
atom_to_cell_mapping : torch.Tensor, shape (total_atoms, 3), dtype=int32
3D cell coordinates for each atom from the existing cell lists.
Typically obtained from batch_build_cell_list.
batch_idx : torch.Tensor, shape (total_atoms,), dtype=int32
System index for each atom.
cells_per_dimension : torch.Tensor, shape (num_systems, 3), dtype=int32
Number of spatial cells in x, y, z directions for each system.
cell : torch.Tensor, shape (num_systems, 3, 3)
Per-system unit cell matrices for coordinate transformations.
pbc : torch.Tensor, shape (num_systems, 3), dtype=bool
Per-system periodic boundary condition flags.
Returns
-------
rebuild_flags : torch.Tensor, shape (num_systems,), dtype=bool
Per-system flags: True if any atom in that system changed cells.
Notes
-----
- torch.compile compatible custom operation
- No CPU-GPU synchronization required; all flag writes happen on GPU
- Returns tensor (not Python bool) for compilation compatibility
See Also
--------
cell_list_needs_rebuild : Single-system version
batch_neighbor_list_needs_rebuild : Skin-distance based alternative
"""
return _batch_cell_list_needs_rebuild(
current_positions,
atom_to_cell_mapping,
batch_idx,
cells_per_dimension,
cell,
pbc,
)