Source code for nvalchemiops.torch.neighbors.batch_cell_list

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch bindings for batched cell list neighbor construction."""

from __future__ import annotations

import warnings

import torch
import warp as wp

from nvalchemiops.neighbors.batch_cell_list import (
    _batch_estimate_cell_list_sizes_overload,
)
from nvalchemiops.neighbors.batch_cell_list import (
    batch_build_cell_list as wp_batch_build_cell_list,
)
from nvalchemiops.neighbors.batch_cell_list import (
    batch_query_cell_list as wp_batch_query_cell_list,
)
from nvalchemiops.neighbors.neighbor_utils import estimate_max_neighbors
from nvalchemiops.torch.neighbors.neighbor_utils import (
    allocate_cell_list,
    get_neighbor_list_from_neighbor_matrix,
)
from nvalchemiops.torch.types import get_wp_dtype, get_wp_mat_dtype, get_wp_vec_dtype

__all__ = [
    "estimate_batch_cell_list_sizes",
    "batch_build_cell_list",
    "batch_query_cell_list",
    "batch_cell_list",
]


def estimate_batch_cell_list_sizes(
    cell: torch.Tensor,
    pbc: torch.Tensor,
    cutoff: float,
    max_nbins: int = 8192,
) -> tuple[int, torch.Tensor]:
    """Estimate allocation sizes for building a batched cell list.

    Launches the warp size-estimation kernel once per system and reduces the
    per-system cell counts into a single total that can be used to allocate
    fixed-size cell list buffers.

    Parameters
    ----------
    cell : torch.Tensor, shape (num_systems, 3, 3)
        Unit cell matrix of every system in the batch.
    pbc : torch.Tensor, shape (num_systems, 3), dtype=bool
        Periodic boundary condition flags per system and dimension.
    cutoff : float
        Neighbor search cutoff distance.
    max_nbins : int, default=8192
        Maximum number of cells to allocate per system.

    Returns
    -------
    max_total_cells_across_batch : int
        Sum over all systems of the estimated number of cells.
    neighbor_search_radius : torch.Tensor, shape (num_systems, 3), dtype=int32
        Radius (in cells) to search around each cell, per system.

    Raises
    ------
    RuntimeError
        If any cell has a determinant less than or equal to zero.

    Notes
    -----
    - Only cells with strictly positive determinant are supported; pass an
      identity cell for non-periodic systems.
    - An empty batch or a non-positive ``cutoff`` short-circuits to the
      fallback ``(1, zeros((num_systems, 3)))`` without launching a kernel.
    - The summed cell count is read back with ``.item()``, which synchronizes
      with the device.

    See Also
    --------
    nvalchemiops.neighbors.batch_cell_list.batch_build_cell_list : Core warp launcher
    allocate_cell_list : Allocates tensors based on these estimates
    batch_build_cell_list : High-level wrapper that uses these estimates
    """
    # Validate geometry first: a bad cell raises even when cutoff <= 0.
    has_cells = cell.numel() > 0
    if has_cells and torch.any(cell.det() <= 0.0):
        raise RuntimeError(
            "Cells with volume <= 0 detected and are not supported."
            " Please pass unit cells with `det(cell) > 0.0`."
        )

    target = cell.device
    n_sys = cell.shape[0]
    search_radius = torch.zeros((n_sys, 3), dtype=torch.int32, device=target)

    # Nothing to estimate: hand back the conservative fallback.
    if n_sys == 0 or cutoff <= 0:
        return 1, search_radius

    scalar_t = get_wp_dtype(cell.dtype)
    matrix_t = get_wp_mat_dtype(cell.dtype)

    # Per-system cell counts, filled in by the kernel.
    cells_needed = torch.zeros(n_sys, device=target, dtype=torch.int32)

    kernel_inputs = [
        wp.from_torch(cell, dtype=matrix_t, return_ctype=True),
        wp.from_torch(pbc, dtype=wp.bool, return_ctype=True),
        scalar_t(cutoff),
        max_nbins,
        wp.from_torch(cells_needed, dtype=wp.int32, return_ctype=True),
        wp.from_torch(search_radius, dtype=wp.vec3i, return_ctype=True),
    ]
    wp.launch(
        _batch_estimate_cell_list_sizes_overload[scalar_t],
        dim=n_sys,
        inputs=kernel_inputs,
        device=str(target),
    )

    total_cells = cells_needed.sum().item()
    return total_cells, search_radius
@torch.library.custom_op( "nvalchemiops::batch_build_cell_list", mutates_args=( "cells_per_dimension", "atom_periodic_shifts", "atom_to_cell_mapping", "atoms_per_cell_count", "cell_atom_start_indices", "cell_atom_list", ), ) def _batch_build_cell_list_op( positions: torch.Tensor, cutoff: float, cell: torch.Tensor, pbc: torch.Tensor, batch_idx: torch.Tensor, cells_per_dimension: torch.Tensor, atom_periodic_shifts: torch.Tensor, atom_to_cell_mapping: torch.Tensor, atoms_per_cell_count: torch.Tensor, cell_atom_start_indices: torch.Tensor, cell_atom_list: torch.Tensor, ) -> None: """Internal custom op for building batch spatial cell lists. This function is torch compilable. Notes ----- The neighbor_search_radius is not an input parameter because it's not used during cell list building - it's only needed for querying the cell list. See Also -------- nvalchemiops.neighbors.batch_cell_list.batch_build_cell_list : Core warp launcher batch_build_cell_list : High-level wrapper function """ device = positions.device num_systems = cell.shape[0] # Handle empty case if positions.shape[0] == 0 or cutoff <= 0: return # Get warp dtype of input tensors wp_dtype = get_wp_dtype(positions.dtype) wp_vec_dtype = get_wp_vec_dtype(positions.dtype) wp_mat_dtype = get_wp_mat_dtype(positions.dtype) wp_device = str(device) # Convert to warp arrays wp_positions = wp.from_torch(positions, dtype=wp_vec_dtype, return_ctype=True) wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype, return_ctype=True) wp_pbc = wp.from_torch(pbc, dtype=wp.bool, return_ctype=True) wp_batch_idx = wp.from_torch( batch_idx.to(dtype=torch.int32), dtype=wp.int32, return_ctype=True ) wp_cells_per_dimension = wp.from_torch( cells_per_dimension, dtype=wp.vec3i, return_ctype=True ) # Allocate cell_offsets internally (shape num_systems, not num_systems+1) cell_offsets = torch.zeros(num_systems, dtype=torch.int32, device=device) wp_cell_offsets = wp.from_torch(cell_offsets, dtype=wp.int32) # Allocate cells_per_system scratch buffer 
cells_per_system = torch.zeros(num_systems, dtype=torch.int32, device=device) wp_cells_per_system = wp.from_torch(cells_per_system, dtype=wp.int32) wp_atom_periodic_shifts = wp.from_torch( atom_periodic_shifts, dtype=wp.vec3i, return_ctype=True ) wp_atom_to_cell_mapping = wp.from_torch( atom_to_cell_mapping, dtype=wp.vec3i, return_ctype=True ) # underlying warp launcher relies on Python API for array_scan # so `return_ctype` is omitted wp_atoms_per_cell_count = wp.from_torch(atoms_per_cell_count, dtype=wp.int32) wp_cell_atom_start_indices = wp.from_torch(cell_atom_start_indices, dtype=wp.int32) wp_cell_atom_list = wp.from_torch(cell_atom_list, dtype=wp.int32, return_ctype=True) # Zero atoms_per_cell_count before building atoms_per_cell_count.zero_() # Call core warp launcher wp_batch_build_cell_list( positions=wp_positions, cell=wp_cell, pbc=wp_pbc, cutoff=cutoff, batch_idx=wp_batch_idx, cells_per_dimension=wp_cells_per_dimension, cell_offsets=wp_cell_offsets, cells_per_system=wp_cells_per_system, atom_periodic_shifts=wp_atom_periodic_shifts, atom_to_cell_mapping=wp_atom_to_cell_mapping, atoms_per_cell_count=wp_atoms_per_cell_count, cell_atom_start_indices=wp_cell_atom_start_indices, cell_atom_list=wp_cell_atom_list, wp_dtype=wp_dtype, device=wp_device, )
def batch_build_cell_list(
    positions: torch.Tensor,
    cutoff: float,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    batch_idx: torch.Tensor,
    cells_per_dimension: torch.Tensor,
    neighbor_search_radius: torch.Tensor,
    atom_periodic_shifts: torch.Tensor,
    atom_to_cell_mapping: torch.Tensor,
    atoms_per_cell_count: torch.Tensor,
    cell_atom_start_indices: torch.Tensor,
    cell_atom_list: torch.Tensor,
) -> None:
    """Populate pre-allocated batch cell list buffers for a batch of systems.

    All outputs are written in place into caller-provided tensors of fixed
    size, which keeps the function compatible with ``torch.compile``.

    Parameters
    ----------
    positions : torch.Tensor, shape (total_atoms, 3)
        Atomic coordinates of all systems, concatenated.
    cutoff : float
        Neighbor search cutoff distance.
    cell : torch.Tensor, shape (num_systems, 3, 3)
        Unit cell matrix of every system.
    pbc : torch.Tensor, shape (num_systems, 3), dtype=bool
        Periodic boundary condition flags per system and dimension.
    batch_idx : torch.Tensor, shape (total_atoms,), dtype=int32
        Index of the system each atom belongs to.
    cells_per_dimension : torch.Tensor, shape (num_systems, 3), dtype=int32
        OUTPUT: Number of cells along x, y and z for every system.
    neighbor_search_radius : torch.Tensor, shape (num_systems, 3), dtype=int32
        Cell search radius per dimension. Accepted so the tuple returned by
        ``allocate_cell_list`` can be passed straight through; it is not
        forwarded to the build op.
    atom_periodic_shifts : torch.Tensor, shape (total_atoms, 3), dtype=int32
        OUTPUT: Periodic boundary crossings of every atom.
    atom_to_cell_mapping : torch.Tensor, shape (total_atoms, 3), dtype=int32
        OUTPUT: 3D cell coordinates assigned to every atom.
    atoms_per_cell_count : torch.Tensor, shape (max_total_cells,), dtype=int32
        OUTPUT: Number of atoms in each cell across all systems.
    cell_atom_start_indices : torch.Tensor, shape (max_total_cells,), dtype=int32
        OUTPUT: Starting indices into the global cell arrays (CSR format).
    cell_atom_list : torch.Tensor, shape (total_atoms,), dtype=int32
        OUTPUT: Flattened atom indices grouped by cell across all systems.

    See Also
    --------
    nvalchemiops.neighbors.batch_cell_list.batch_build_cell_list : Core warp launcher
    estimate_batch_cell_list_sizes : Estimate memory requirements
    batch_query_cell_list : Query the built cell list for neighbors
    batch_cell_list : High-level function that builds and queries in one call
    """
    # neighbor_search_radius is deliberately dropped: building never reads it.
    outcome = _batch_build_cell_list_op(
        positions,
        cutoff,
        cell,
        pbc,
        batch_idx,
        cells_per_dimension,
        atom_periodic_shifts,
        atom_to_cell_mapping,
        atoms_per_cell_count,
        cell_atom_start_indices,
        cell_atom_list,
    )
    return outcome
@torch.library.custom_op( "nvalchemiops::batch_query_cell_list", mutates_args=("neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"), ) def _batch_query_cell_list_op( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, cutoff: float, batch_idx: torch.Tensor, cells_per_dimension: torch.Tensor, neighbor_search_radius: torch.Tensor, atom_periodic_shifts: torch.Tensor, atom_to_cell_mapping: torch.Tensor, atoms_per_cell_count: torch.Tensor, cell_atom_start_indices: torch.Tensor, cell_atom_list: torch.Tensor, neighbor_matrix: torch.Tensor, neighbor_matrix_shifts: torch.Tensor, num_neighbors: torch.Tensor, half_fill: bool = False, ) -> None: """Internal custom op for querying batch spatial cell lists to build neighbor matrices. This function is torch compilable. See Also -------- nvalchemiops.neighbors.batch_cell_list.batch_query_cell_list : Core warp launcher batch_query_cell_list : High-level wrapper function """ device = positions.device num_systems = cell.shape[0] # Handle empty case if positions.shape[0] == 0 or cutoff <= 0: return # Get warp dtypes and arrays wp_dtype = get_wp_dtype(positions.dtype) wp_vec_dtype = get_wp_vec_dtype(positions.dtype) wp_mat_dtype = get_wp_mat_dtype(positions.dtype) wp_device = str(device) wp_positions = wp.from_torch(positions, dtype=wp_vec_dtype, return_ctype=True) wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype, return_ctype=True) wp_pbc = wp.from_torch(pbc, dtype=wp.bool, return_ctype=True) wp_batch_idx = wp.from_torch( batch_idx.to(dtype=torch.int32), dtype=wp.int32, return_ctype=True ) wp_cells_per_dimension = wp.from_torch( cells_per_dimension, dtype=wp.vec3i, return_ctype=True ) wp_neighbor_search_radius = wp.from_torch( neighbor_search_radius, dtype=wp.vec3i, return_ctype=True ) # cell_offsets[i] = sum of cells for systems 0..i-1 cells_per_system = cells_per_dimension.prod(dim=1) cell_offsets = torch.zeros(num_systems, dtype=torch.int32, device=device) if num_systems > 1: torch.cumsum(cells_per_system[:-1], 
dim=0, out=cell_offsets[1:]) # cell_offsets[0] is already 0 from zeros initialization wp_cell_offsets = wp.from_torch(cell_offsets, dtype=wp.int32, return_ctype=True) wp_atom_periodic_shifts = wp.from_torch( atom_periodic_shifts, dtype=wp.vec3i, return_ctype=True ) wp_atom_to_cell_mapping = wp.from_torch( atom_to_cell_mapping, dtype=wp.vec3i, return_ctype=True ) wp_atoms_per_cell_count = wp.from_torch( atoms_per_cell_count, dtype=wp.int32, return_ctype=True ) wp_cell_atom_start_indices = wp.from_torch( cell_atom_start_indices, dtype=wp.int32, return_ctype=True ) wp_cell_atom_list = wp.from_torch(cell_atom_list, dtype=wp.int32, return_ctype=True) wp_neighbor_matrix = wp.from_torch( neighbor_matrix, dtype=wp.int32, return_ctype=True ) wp_neighbor_matrix_shifts = wp.from_torch( neighbor_matrix_shifts, dtype=wp.vec3i, return_ctype=True ) wp_num_neighbors = wp.from_torch(num_neighbors, dtype=wp.int32, return_ctype=True) # Call core warp launcher wp_batch_query_cell_list( positions=wp_positions, cell=wp_cell, pbc=wp_pbc, cutoff=cutoff, batch_idx=wp_batch_idx, cells_per_dimension=wp_cells_per_dimension, neighbor_search_radius=wp_neighbor_search_radius, cell_offsets=wp_cell_offsets, atom_periodic_shifts=wp_atom_periodic_shifts, atom_to_cell_mapping=wp_atom_to_cell_mapping, atoms_per_cell_count=wp_atoms_per_cell_count, cell_atom_start_indices=wp_cell_atom_start_indices, cell_atom_list=wp_cell_atom_list, neighbor_matrix=wp_neighbor_matrix, neighbor_matrix_shifts=wp_neighbor_matrix_shifts, num_neighbors=wp_num_neighbors, wp_dtype=wp_dtype, device=wp_device, half_fill=half_fill, ) @torch.library.custom_op( "nvalchemiops::batch_query_cell_list_selective", mutates_args=("neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"), ) def _batch_query_cell_list_selective_op( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, cutoff: float, batch_idx: torch.Tensor, cells_per_dimension: torch.Tensor, neighbor_search_radius: torch.Tensor, atom_periodic_shifts: 
torch.Tensor, atom_to_cell_mapping: torch.Tensor, atoms_per_cell_count: torch.Tensor, cell_atom_start_indices: torch.Tensor, cell_atom_list: torch.Tensor, neighbor_matrix: torch.Tensor, neighbor_matrix_shifts: torch.Tensor, num_neighbors: torch.Tensor, rebuild_flags: torch.Tensor, half_fill: bool = False, ) -> None: """Internal custom op for querying batch cell lists with per-system selective skip. Only systems with rebuild_flags[i] == True are recomputed on the GPU. Existing neighbor data for non-rebuilt systems is preserved without CPU-GPU sync. This function is torch compilable. See Also -------- nvalchemiops.neighbors.batch_cell_list.batch_query_cell_list : Core warp launcher batch_query_cell_list : High-level wrapper function """ device = positions.device num_systems = cell.shape[0] if positions.shape[0] == 0 or cutoff <= 0: return wp_dtype = get_wp_dtype(positions.dtype) wp_vec_dtype = get_wp_vec_dtype(positions.dtype) wp_mat_dtype = get_wp_mat_dtype(positions.dtype) wp_device = str(device) wp_positions = wp.from_torch(positions, dtype=wp_vec_dtype, return_ctype=True) wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype, return_ctype=True) wp_pbc = wp.from_torch(pbc, dtype=wp.bool, return_ctype=True) wp_batch_idx = wp.from_torch( batch_idx.to(dtype=torch.int32), dtype=wp.int32, return_ctype=True ) wp_cells_per_dimension = wp.from_torch( cells_per_dimension, dtype=wp.vec3i, return_ctype=True ) wp_neighbor_search_radius = wp.from_torch( neighbor_search_radius, dtype=wp.vec3i, return_ctype=True ) cells_per_system = cells_per_dimension.prod(dim=1) cell_offsets = torch.zeros(num_systems, dtype=torch.int32, device=device) if num_systems > 1: torch.cumsum(cells_per_system[:-1], dim=0, out=cell_offsets[1:]) wp_cell_offsets = wp.from_torch(cell_offsets, dtype=wp.int32, return_ctype=True) wp_atom_periodic_shifts = wp.from_torch( atom_periodic_shifts, dtype=wp.vec3i, return_ctype=True ) wp_atom_to_cell_mapping = wp.from_torch( atom_to_cell_mapping, dtype=wp.vec3i, 
return_ctype=True ) wp_atoms_per_cell_count = wp.from_torch( atoms_per_cell_count, dtype=wp.int32, return_ctype=True ) wp_cell_atom_start_indices = wp.from_torch( cell_atom_start_indices, dtype=wp.int32, return_ctype=True ) wp_cell_atom_list = wp.from_torch(cell_atom_list, dtype=wp.int32, return_ctype=True) wp_neighbor_matrix = wp.from_torch( neighbor_matrix, dtype=wp.int32, return_ctype=True ) wp_neighbor_matrix_shifts = wp.from_torch( neighbor_matrix_shifts, dtype=wp.vec3i, return_ctype=True ) wp_num_neighbors = wp.from_torch(num_neighbors, dtype=wp.int32, return_ctype=True) wp_rebuild_flags = wp.from_torch(rebuild_flags, dtype=wp.bool, return_ctype=True) wp_batch_query_cell_list( positions=wp_positions, cell=wp_cell, pbc=wp_pbc, cutoff=cutoff, batch_idx=wp_batch_idx, cells_per_dimension=wp_cells_per_dimension, neighbor_search_radius=wp_neighbor_search_radius, cell_offsets=wp_cell_offsets, atom_periodic_shifts=wp_atom_periodic_shifts, atom_to_cell_mapping=wp_atom_to_cell_mapping, atoms_per_cell_count=wp_atoms_per_cell_count, cell_atom_start_indices=wp_cell_atom_start_indices, cell_atom_list=wp_cell_atom_list, neighbor_matrix=wp_neighbor_matrix, neighbor_matrix_shifts=wp_neighbor_matrix_shifts, num_neighbors=wp_num_neighbors, wp_dtype=wp_dtype, device=wp_device, half_fill=half_fill, rebuild_flags=wp_rebuild_flags, )
def batch_query_cell_list(
    positions: torch.Tensor,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    cutoff: float,
    batch_idx: torch.Tensor,
    cells_per_dimension: torch.Tensor,
    neighbor_search_radius: torch.Tensor,
    atom_periodic_shifts: torch.Tensor,
    atom_to_cell_mapping: torch.Tensor,
    atoms_per_cell_count: torch.Tensor,
    cell_atom_start_indices: torch.Tensor,
    cell_atom_list: torch.Tensor,
    neighbor_matrix: torch.Tensor,
    neighbor_matrix_shifts: torch.Tensor,
    num_neighbors: torch.Tensor,
    half_fill: bool = False,
    rebuild_flags: torch.Tensor | None = None,
) -> None:
    """Fill neighbor matrices for a batch of systems from a built cell list.

    Dispatches to the plain query op, or to the selective op when
    ``rebuild_flags`` is given. Outputs are written in place.

    Parameters
    ----------
    positions : torch.Tensor, shape (total_atoms, 3)
        Cartesian coordinates of all systems, concatenated.
    cell : torch.Tensor, shape (num_systems, 3, 3)
        Unit cell matrix of every system.
    pbc : torch.Tensor, shape (num_systems, 3), dtype=bool
        Periodic boundary condition flags.
    cutoff : float
        Neighbor search cutoff distance.
    batch_idx : torch.Tensor, shape (total_atoms,), dtype=int32
        Index of the system each atom belongs to.
    cells_per_dimension : torch.Tensor, shape (num_systems, 3), dtype=int32
        Number of cells along x, y and z for every system.
    neighbor_search_radius : torch.Tensor, shape (num_systems, 3), dtype=int32
        Radius of neighboring cells to search.
    atom_periodic_shifts : torch.Tensor, shape (total_atoms, 3), dtype=int32
        Per-atom periodic boundary crossings from batch_build_cell_list.
    atom_to_cell_mapping : torch.Tensor, shape (total_atoms, 3), dtype=int32
        Per-atom 3D cell coordinates from batch_build_cell_list.
    atoms_per_cell_count : torch.Tensor, shape (max_total_cells,), dtype=int32
        Atoms per cell from batch_build_cell_list.
    cell_atom_start_indices : torch.Tensor, shape (max_total_cells,), dtype=int32
        Starting index per cell from batch_build_cell_list.
    cell_atom_list : torch.Tensor, shape (total_atoms,), dtype=int32
        Atom indices grouped by cell from batch_build_cell_list.
    neighbor_matrix : torch.Tensor, shape (total_atoms, max_neighbors), dtype=int32
        OUTPUT: Neighbor matrix to be filled.
    neighbor_matrix_shifts : torch.Tensor, shape (total_atoms, max_neighbors, 3), dtype=int32
        OUTPUT: Shift vector of every neighbor relationship.
    num_neighbors : torch.Tensor, shape (total_atoms,), dtype=int32
        OUTPUT: Number of neighbors of every atom.
    half_fill : bool, default=False
        If True, only store half of the neighbor relationships.
    rebuild_flags : torch.Tensor, shape (num_systems,), dtype=torch.bool, optional
        Per-system rebuild flags. When given, only systems flagged True are
        processed on the GPU; neighbor data of the others is preserved.

    See Also
    --------
    nvalchemiops.neighbors.batch_cell_list.batch_query_cell_list : Core warp launcher
    batch_build_cell_list : Builds the cell list data structures
    batch_cell_list : High-level function that builds and queries in one call
    """
    # Leading arguments are identical for both ops; only the tail differs.
    shared_args = (
        positions,
        cell,
        pbc,
        cutoff,
        batch_idx,
        cells_per_dimension,
        neighbor_search_radius,
        atom_periodic_shifts,
        atom_to_cell_mapping,
        atoms_per_cell_count,
        cell_atom_start_indices,
        cell_atom_list,
        neighbor_matrix,
        neighbor_matrix_shifts,
        num_neighbors,
    )
    if rebuild_flags is not None:
        return _batch_query_cell_list_selective_op(
            *shared_args, rebuild_flags, half_fill
        )
    return _batch_query_cell_list_op(*shared_args, half_fill)
[docs] def batch_cell_list( positions: torch.Tensor, cutoff: float, cell: torch.Tensor, pbc: torch.Tensor, batch_idx: torch.Tensor, max_neighbors: int | None = None, half_fill: bool = False, fill_value: int | None = None, return_neighbor_list: bool = False, neighbor_matrix: torch.Tensor | None = None, neighbor_matrix_shifts: torch.Tensor | None = None, num_neighbors: torch.Tensor | None = None, cells_per_dimension: torch.Tensor | None = None, neighbor_search_radius: torch.Tensor | None = None, cell_offsets: torch.Tensor | None = None, atom_periodic_shifts: torch.Tensor | None = None, atom_to_cell_mapping: torch.Tensor | None = None, atoms_per_cell_count: torch.Tensor | None = None, cell_atom_start_indices: torch.Tensor | None = None, cell_atom_list: torch.Tensor | None = None, rebuild_flags: torch.Tensor | None = None, ) -> ( tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor] ): """Build complete batch neighbor matrices using spatial cell list acceleration. High-level convenience function that processes multiple systems simultaneously. Automatically estimates memory requirements, builds batch spatial cell list data structures, and queries them to produce complete neighbor matrices for all systems. Parameters ---------- positions : torch.Tensor, shape (total_atoms, 3) Concatenated atomic coordinates for all systems in the batch. cutoff : float Neighbor search cutoff distance. cell : torch.Tensor, shape (num_systems, 3, 3) Unit cell matrices for each system in the batch. pbc : torch.Tensor, shape (num_systems, 3), dtype=bool Periodic boundary condition flags for each system and dimension. batch_idx : torch.Tensor, shape (total_atoms,), dtype=int32 System index for each atom. max_neighbors : int or None, optional Maximum number of neighbors per atom. If None, automatically estimated. half_fill : bool, default=False If True, only fill half of the neighbor matrix. 
fill_value : int | None, optional Value to use for padding empty neighbor slots in the matrix. Default is total_atoms. return_neighbor_list : bool, optional - default=False If True, convert the neighbor matrix to a neighbor list (idx_i, idx_j) format. cells_per_dimension : torch.Tensor, shape (num_systems, 3), dtype=int32, optional Pre-allocated tensor for cell dimensions. neighbor_search_radius : torch.Tensor, shape (num_systems, 3), dtype=int32, optional Pre-allocated tensor for search radius. atom_periodic_shifts : torch.Tensor, shape (total_atoms, 3), dtype=int32, optional Pre-allocated tensor for periodic shifts. atom_to_cell_mapping : torch.Tensor, shape (total_atoms, 3), dtype=int32, optional Pre-allocated tensor for cell mapping. atoms_per_cell_count : torch.Tensor, shape (max_total_cells,), dtype=int32, optional Pre-allocated tensor for atom counts. cell_atom_start_indices : torch.Tensor, shape (max_total_cells,), dtype=int32, optional Pre-allocated tensor for start indices. cell_atom_list : torch.Tensor, shape (total_atoms,), dtype=int32, optional Pre-allocated tensor for atom list. rebuild_flags : torch.Tensor, shape (num_systems,), dtype=torch.bool, optional Per-system rebuild flags produced by ``batch_cell_list_needs_rebuild``. If provided, only systems where rebuild_flags[i] is True are recomputed; existing data in ``neighbor_matrix`` and ``num_neighbors`` is preserved for non-rebuilt systems entirely on the GPU (no CPU-GPU sync). When this is used, pre-allocated ``neighbor_matrix`` and ``num_neighbors`` tensors must be provided and will not be globally zeroed — only rebuilt-system entries are reset. Returns ------- results : tuple of torch.Tensor Variable-length tuple with neighbor data in matrix or list format. 
See Also -------- nvalchemiops.neighbors.batch_cell_list.batch_build_cell_list : Core warp launcher for building nvalchemiops.neighbors.batch_cell_list.batch_query_cell_list : Core warp launcher for querying batch_naive_neighbor_list : O(N²) method for small systems """ total_atoms = positions.shape[0] device = positions.device if device == "cpu": warnings.warn( "The CPU version of `batch_cell_list` is known to experience" " issues with memory allocation and under investigation. Please" " ensure tensor provided as `positions` is on GPU." ) # Handle empty case if total_atoms <= 0 or cutoff <= 0: if return_neighbor_list: return ( torch.zeros((2, 0), dtype=torch.int32, device=device), torch.zeros((total_atoms + 1,), dtype=torch.int32, device=device), torch.zeros((0, 3), dtype=torch.int32, device=device), ) else: return ( torch.full((total_atoms, 0), -1, dtype=torch.int32, device=device), torch.zeros((total_atoms,), dtype=torch.int32, device=device), torch.zeros((total_atoms, 0, 3), dtype=torch.int32, device=device), ) if max_neighbors is None and neighbor_matrix is None: max_neighbors = estimate_max_neighbors(cutoff) if fill_value is None: fill_value = total_atoms if neighbor_matrix is None: neighbor_matrix = torch.full( (total_atoms, max_neighbors), fill_value, dtype=torch.int32, device=device ) elif rebuild_flags is None: neighbor_matrix.fill_(fill_value) if neighbor_matrix_shifts is None: neighbor_matrix_shifts = torch.zeros( (total_atoms, max_neighbors, 3), dtype=torch.int32, device=device ) elif rebuild_flags is None: neighbor_matrix_shifts.zero_() if num_neighbors is None: num_neighbors = torch.zeros((total_atoms,), dtype=torch.int32, device=device) elif rebuild_flags is None: num_neighbors.zero_() # Allocate cell list if needed if ( cells_per_dimension is None or neighbor_search_radius is None or atom_periodic_shifts is None or atom_to_cell_mapping is None or atoms_per_cell_count is None or cell_atom_start_indices is None or cell_atom_list is None ): 
max_total_cells, neighbor_search_radius = estimate_batch_cell_list_sizes( cell, pbc, cutoff ) ( cells_per_dimension, neighbor_search_radius, atom_periodic_shifts, atom_to_cell_mapping, atoms_per_cell_count, cell_atom_start_indices, cell_atom_list, ) = allocate_cell_list( total_atoms, max_total_cells, neighbor_search_radius, device, ) cell_list_cache = ( cells_per_dimension, neighbor_search_radius, atom_periodic_shifts, atom_to_cell_mapping, atoms_per_cell_count, cell_atom_start_indices, cell_atom_list, ) else: cells_per_dimension.zero_() atom_periodic_shifts.zero_() atom_to_cell_mapping.zero_() atoms_per_cell_count.zero_() cell_atom_start_indices.zero_() cell_atom_list.zero_() cell_list_cache = ( cells_per_dimension, neighbor_search_radius, atom_periodic_shifts, atom_to_cell_mapping, atoms_per_cell_count, cell_atom_start_indices, cell_atom_list, ) # Build batch cell list with fixed allocations batch_build_cell_list( positions, cutoff, cell, pbc, batch_idx, *cell_list_cache, ) # Query neighbor lists batch_query_cell_list( positions, cell, pbc, cutoff, batch_idx, *cell_list_cache, neighbor_matrix, neighbor_matrix_shifts, num_neighbors, half_fill, rebuild_flags, ) if return_neighbor_list: neighbor_list, neighbor_ptr, neighbor_list_shifts = ( get_neighbor_list_from_neighbor_matrix( neighbor_matrix, num_neighbors=num_neighbors, neighbor_shift_matrix=neighbor_matrix_shifts, fill_value=fill_value, ) ) return neighbor_list, neighbor_ptr, neighbor_list_shifts else: return neighbor_matrix, num_neighbors, neighbor_matrix_shifts