# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch bindings for unbatched cell list neighbor construction."""
from __future__ import annotations
import torch
import warp as wp
from nvalchemiops.neighbors.cell_list import (
build_cell_list as wp_build_cell_list,
)
from nvalchemiops.neighbors.cell_list import (
query_cell_list as wp_query_cell_list,
)
from nvalchemiops.neighbors.neighbor_utils import (
estimate_max_neighbors,
selective_zero_num_neighbors_single,
)
from nvalchemiops.torch.neighbors.neighbor_utils import (
allocate_cell_list,
get_neighbor_list_from_neighbor_matrix,
)
from nvalchemiops.torch.types import get_wp_dtype, get_wp_mat_dtype, get_wp_vec_dtype
__all__ = [
"estimate_cell_list_sizes",
"build_cell_list",
"query_cell_list",
"cell_list",
]
[docs]
def estimate_cell_list_sizes(
    cell: torch.Tensor,
    pbc: torch.Tensor,
    cutoff: float,
    max_nbins: int = 8192,
) -> tuple[int, torch.Tensor]:
    """Estimate allocation sizes for torch.compile-friendly cell list construction.

    Provides conservative estimates for maximum memory allocations needed when
    building cell lists with fixed-size tensors to avoid dynamic allocation
    and graph breaks in torch.compile.

    This function is not torch.compile compatible because it returns an integer
    received from using torch.Tensor.item()

    Parameters
    ----------
    cell : torch.Tensor, shape (1, 3, 3)
        Unit cell matrix defining the simulation box. A 2-D ``(3, 3)`` tensor
        is promoted to ``(1, 3, 3)`` before it is handed to the kernel.
    pbc : torch.Tensor, shape (1, 3), dtype=bool
        Flags indicating periodic boundary conditions in x, y, z directions.
    cutoff : float
        Maximum distance for neighbor search, determines minimum cell size.
    max_nbins : int, default=8192
        Maximum number of cells to allocate.

    Returns
    -------
    max_total_cells : int
        Estimated maximum number of cells needed for spatial decomposition.
        ``1`` is returned when ``cell`` has an empty leading dimension or
        ``cutoff <= 0``.
    neighbor_search_radius : torch.Tensor, shape (3,), dtype=int32
        Radius of neighboring cells to search in each dimension. All zeros
        when ``cell`` has an empty leading dimension or ``cutoff <= 0``.

    Raises
    ------
    RuntimeError
        If ``cell`` is non-empty and its determinant is not strictly positive.

    Notes
    -----
    - Cell size is determined by the cutoff distance to ensure neighboring
      cells contain all potential neighbors. The estimation assumes roughly
      cubic cells and uniform atomic distribution.
    - Currently, only unit cells with a positive determinant (i.e. with
      positive volume) are supported. For non-periodic systems, pass an identity
      cell.

    See Also
    --------
    nvalchemiops.neighbors.cell_list.build_cell_list : Core warp launcher
    allocate_cell_list : Allocates tensors based on these estimates
    build_cell_list : High-level wrapper that uses these estimates
    """
    if cell.numel() > 0 and cell.det() <= 0.0:
        raise RuntimeError(
            "Cell with volume <= 0.0 detected and is not supported."
            " Please pass unit cells with `det(cell) > 0.0`."
        )

    device = cell.device

    # Nothing to estimate: bail out before any warp views are created.
    if (cell.ndim == 3 and cell.shape[0] == 0) or cutoff <= 0:
        return 1, torch.zeros((3,), dtype=torch.int32, device=device)

    # Promote (3, 3) -> (1, 3, 3) *before* wrapping the tensor for warp.
    # Doing it afterwards (as the previous revision did) has no effect on
    # what the kernel receives.
    if cell.ndim == 2:
        cell = cell.unsqueeze(0)

    dtype = cell.dtype
    wp_device = str(device)
    wp_dtype = get_wp_dtype(dtype)
    wp_mat_dtype = get_wp_mat_dtype(dtype)

    wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype, return_ctype=True)
    wp_pbc = wp.from_torch(pbc, dtype=wp.bool, return_ctype=True)

    # Output buffers written by the kernel; the torch tensors and the warp
    # views are created from the same storage via wp.from_torch.
    max_total_cells = torch.zeros(1, device=device, dtype=torch.int32)
    wp_max_total_cells = wp.from_torch(
        max_total_cells, dtype=wp.int32, return_ctype=True
    )
    neighbor_search_radius = torch.zeros((3,), dtype=torch.int32, device=device)
    wp_neighbor_search_radius = wp.from_torch(
        neighbor_search_radius, dtype=wp.int32, return_ctype=True
    )

    # Note: Using the _estimate_cell_list_sizes kernel from cell_list module
    from nvalchemiops.neighbors.cell_list import _estimate_cell_list_sizes_overload

    wp.launch(
        _estimate_cell_list_sizes_overload[wp_dtype],
        dim=1,
        inputs=[
            wp_cell,
            wp_pbc,
            wp_dtype(cutoff),
            max_nbins,
            wp_max_total_cells,
            wp_neighbor_search_radius,
        ],
        device=wp_device,
    )

    # .item() yields a Python int (this is what breaks torch.compile).
    return (
        max_total_cells.item(),
        neighbor_search_radius,
    )
@torch.library.custom_op(
"nvalchemiops::build_cell_list",
mutates_args=(
"cells_per_dimension",
"atom_periodic_shifts",
"atom_to_cell_mapping",
"atoms_per_cell_count",
"cell_atom_start_indices",
"cell_atom_list",
),
)
def _build_cell_list_op(
positions: torch.Tensor,
cutoff: float,
cell: torch.Tensor,
pbc: torch.Tensor,
cells_per_dimension: torch.Tensor,
atom_periodic_shifts: torch.Tensor,
atom_to_cell_mapping: torch.Tensor,
atoms_per_cell_count: torch.Tensor,
cell_atom_start_indices: torch.Tensor,
cell_atom_list: torch.Tensor,
) -> None:
"""Internal custom op for building spatial cell list.
This function is torch compilable.
Notes
-----
The neighbor_search_radius is not an input parameter because it's computed
internally by the warp launcher and doesn't need to be passed in.
See Also
--------
nvalchemiops.neighbors.cell_list.build_cell_list : Core warp launcher
build_cell_list : High-level wrapper function
"""
total_atoms = positions.shape[0]
device = positions.device
# Handle empty case
if total_atoms == 0:
return
cell = cell if cell.ndim == 3 else cell.unsqueeze(0)
pbc = pbc.squeeze(0)
# Get warp dtypes and arrays
wp_dtype = get_wp_dtype(positions.dtype)
wp_vec_dtype = get_wp_vec_dtype(positions.dtype)
wp_mat_dtype = get_wp_mat_dtype(positions.dtype)
wp_device = str(device)
wp_positions = wp.from_torch(positions, dtype=wp_vec_dtype, return_ctype=True)
wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype, return_ctype=True)
wp_pbc = wp.from_torch(pbc, dtype=wp.bool, return_ctype=True)
wp_cells_per_dimension = wp.from_torch(
cells_per_dimension, dtype=wp.int32, return_ctype=True
)
wp_atom_periodic_shifts = wp.from_torch(
atom_periodic_shifts, dtype=wp.vec3i, return_ctype=True
)
wp_atom_to_cell_mapping = wp.from_torch(
atom_to_cell_mapping, dtype=wp.vec3i, return_ctype=True
)
# underlying warp launcher relies on Python API for array_scan
# so `return_ctype` is omitted
wp_atoms_per_cell_count = wp.from_torch(atoms_per_cell_count, dtype=wp.int32)
wp_cell_atom_start_indices = wp.from_torch(cell_atom_start_indices, dtype=wp.int32)
wp_cell_atom_list = wp.from_torch(cell_atom_list, dtype=wp.int32, return_ctype=True)
# Zero atoms_per_cell_count before building
atoms_per_cell_count.zero_()
# Call core warp launcher
wp_build_cell_list(
positions=wp_positions,
cell=wp_cell,
pbc=wp_pbc,
cutoff=cutoff,
cells_per_dimension=wp_cells_per_dimension,
atom_periodic_shifts=wp_atom_periodic_shifts,
atom_to_cell_mapping=wp_atom_to_cell_mapping,
atoms_per_cell_count=wp_atoms_per_cell_count,
cell_atom_start_indices=wp_cell_atom_start_indices,
cell_atom_list=wp_cell_atom_list,
wp_dtype=wp_dtype,
device=wp_device,
)
# Compute cell atom start indices using cumsum
max_total_cells = atoms_per_cell_count.shape[0]
cell_atom_start_indices[0] = 0
if max_total_cells > 1:
torch.cumsum(atoms_per_cell_count[:-1], dim=0, out=cell_atom_start_indices[1:])
[docs]
def build_cell_list(
    positions: torch.Tensor,
    cutoff: float,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    cells_per_dimension: torch.Tensor,
    neighbor_search_radius: torch.Tensor,
    atom_periodic_shifts: torch.Tensor,
    atom_to_cell_mapping: torch.Tensor,
    atoms_per_cell_count: torch.Tensor,
    cell_atom_start_indices: torch.Tensor,
    cell_atom_list: torch.Tensor,
) -> None:
    """Populate pre-allocated cell-list buffers for a single (unbatched) system.

    Public entry point around the ``nvalchemiops::build_cell_list`` custom op.
    All outputs are written into caller-supplied, fixed-size tensors so that no
    tensors are created dynamically, which would cause graph breaks in
    torch.compile.

    Parameters
    ----------
    positions : torch.Tensor, shape (total_atoms, 3)
        Atomic coordinates in Cartesian space where total_atoms is the number
        of atoms. Must be float32, float64, or float16 dtype.
    cutoff : float
        Maximum distance for neighbor search. Determines minimum cell size.
    cell : torch.Tensor, shape (1, 3, 3)
        Unit cell matrix defining the simulation box. Each row represents a
        lattice vector in Cartesian coordinates. Must match positions dtype.
    pbc : torch.Tensor, shape (3,), dtype=bool
        Flags indicating periodic boundary conditions in x, y, z directions.
        True enables PBC, False disables it for that dimension.
    cells_per_dimension : torch.Tensor, shape (3,), dtype=int32
        OUTPUT: Number of cells created in x, y, z directions.
    neighbor_search_radius : torch.Tensor, shape (3,), dtype=int32
        Radius of neighboring cells to search in each dimension. Accepted so
        the output of allocate_cell_list can be splatted straight into this
        function; it is not forwarded to the underlying op.
    atom_periodic_shifts : torch.Tensor, shape (total_atoms, 3), dtype=int32
        OUTPUT: Periodic boundary crossings for each atom.
    atom_to_cell_mapping : torch.Tensor, shape (total_atoms, 3), dtype=int32
        OUTPUT: 3D cell coordinates assigned to each atom.
    atoms_per_cell_count : torch.Tensor, shape (max_total_cells,), dtype=int32
        OUTPUT: Number of atoms in each cell. Only first 'total_cells' entries
        are valid.
    cell_atom_start_indices : torch.Tensor, shape (max_total_cells,), dtype=int32
        OUTPUT: Starting index in cell_atom_list for each cell's atoms.
    cell_atom_list : torch.Tensor, shape (total_atoms,), dtype=int32
        OUTPUT: Flattened list of atom indices organized by cell. Use with
        start_indices to extract atoms for each cell.

    Notes
    -----
    - This function is torch.compile compatible and uses only static tensor shapes
    - Memory usage is determined by max_total_cells
    - For optimal performance, use estimates from estimate_cell_list_sizes()
    - Cell list must be rebuilt when atoms move between cells or PBC/cell changes

    See Also
    --------
    nvalchemiops.neighbors.cell_list.build_cell_list : Core warp launcher
    estimate_cell_list_sizes : Estimate memory requirements
    query_cell_list : Query the built cell list for neighbors
    cell_list : High-level function that builds and queries in one call
    """
    # Geometry inputs first, then the six output buffers; the search radius
    # is intentionally left out of the op call.
    geometry = (positions, cutoff, cell, pbc)
    output_buffers = (
        cells_per_dimension,
        atom_periodic_shifts,
        atom_to_cell_mapping,
        atoms_per_cell_count,
        cell_atom_start_indices,
        cell_atom_list,
    )
    return _build_cell_list_op(*geometry, *output_buffers)
@torch.library.custom_op(
"nvalchemiops::query_cell_list",
mutates_args=("neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"),
)
def _query_cell_list_op(
positions: torch.Tensor,
cutoff: float,
cell: torch.Tensor,
pbc: torch.Tensor,
cells_per_dimension: torch.Tensor,
neighbor_search_radius: torch.Tensor,
atom_periodic_shifts: torch.Tensor,
atom_to_cell_mapping: torch.Tensor,
atoms_per_cell_count: torch.Tensor,
cell_atom_start_indices: torch.Tensor,
cell_atom_list: torch.Tensor,
neighbor_matrix: torch.Tensor,
neighbor_matrix_shifts: torch.Tensor,
num_neighbors: torch.Tensor,
half_fill: bool = False,
rebuild_flags: torch.Tensor | None = None,
) -> None:
"""Internal custom op for querying spatial cell list to build neighbor matrix.
This function is torch compilable.
See Also
--------
nvalchemiops.neighbors.cell_list.query_cell_list : Core warp launcher
query_cell_list : High-level wrapper function
"""
total_atoms = positions.shape[0]
device = positions.device
# Handle empty case
if total_atoms == 0:
return
cell = cell if cell.ndim == 3 else cell.unsqueeze(0)
pbc = pbc.squeeze(0)
# Get warp dtypes and arrays
wp_dtype = get_wp_dtype(positions.dtype)
wp_vec_dtype = get_wp_vec_dtype(positions.dtype)
wp_mat_dtype = get_wp_mat_dtype(positions.dtype)
wp_device = str(device)
wp_positions = wp.from_torch(positions, dtype=wp_vec_dtype, return_ctype=True)
wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype, return_ctype=True)
wp_pbc = wp.from_torch(pbc, dtype=wp.bool, return_ctype=True)
wp_cells_per_dimension = wp.from_torch(
cells_per_dimension, dtype=wp.int32, return_ctype=True
)
wp_neighbor_search_radius = wp.from_torch(
neighbor_search_radius, dtype=wp.int32, return_ctype=True
)
wp_atom_periodic_shifts = wp.from_torch(
atom_periodic_shifts, dtype=wp.vec3i, return_ctype=True
)
wp_atom_to_cell_mapping = wp.from_torch(
atom_to_cell_mapping, dtype=wp.vec3i, return_ctype=True
)
wp_atoms_per_cell_count = wp.from_torch(
atoms_per_cell_count, dtype=wp.int32, return_ctype=True
)
wp_cell_atom_start_indices = wp.from_torch(
cell_atom_start_indices, dtype=wp.int32, return_ctype=True
)
wp_cell_atom_list = wp.from_torch(cell_atom_list, dtype=wp.int32, return_ctype=True)
wp_neighbor_matrix = wp.from_torch(
neighbor_matrix, dtype=wp.int32, return_ctype=True
)
wp_neighbor_matrix_shifts = wp.from_torch(
neighbor_matrix_shifts, dtype=wp.vec3i, return_ctype=True
)
wp_num_neighbors = wp.from_torch(num_neighbors, dtype=wp.int32, return_ctype=True)
if rebuild_flags is not None:
wp_rebuild_flags = wp.from_torch(
rebuild_flags, dtype=wp.bool, return_ctype=True
)
selective_zero_num_neighbors_single(
wp_num_neighbors, wp_rebuild_flags, wp_device
)
else:
wp_rebuild_flags = None
# Call core warp launcher
wp_query_cell_list(
positions=wp_positions,
cell=wp_cell,
pbc=wp_pbc,
cutoff=cutoff,
cells_per_dimension=wp_cells_per_dimension,
neighbor_search_radius=wp_neighbor_search_radius,
atom_periodic_shifts=wp_atom_periodic_shifts,
atom_to_cell_mapping=wp_atom_to_cell_mapping,
atoms_per_cell_count=wp_atoms_per_cell_count,
cell_atom_start_indices=wp_cell_atom_start_indices,
cell_atom_list=wp_cell_atom_list,
neighbor_matrix=wp_neighbor_matrix,
neighbor_matrix_shifts=wp_neighbor_matrix_shifts,
num_neighbors=wp_num_neighbors,
wp_dtype=wp_dtype,
device=wp_device,
half_fill=half_fill,
rebuild_flags=wp_rebuild_flags,
)
[docs]
def query_cell_list(
    positions: torch.Tensor,
    cutoff: float,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    cells_per_dimension: torch.Tensor,
    neighbor_search_radius: torch.Tensor,
    atom_periodic_shifts: torch.Tensor,
    atom_to_cell_mapping: torch.Tensor,
    atoms_per_cell_count: torch.Tensor,
    cell_atom_start_indices: torch.Tensor,
    cell_atom_list: torch.Tensor,
    neighbor_matrix: torch.Tensor,
    neighbor_matrix_shifts: torch.Tensor,
    num_neighbors: torch.Tensor,
    half_fill: bool = False,
    rebuild_flags: torch.Tensor | None = None,
) -> None:
    """Fill a pre-allocated neighbor matrix from a previously built cell list.

    Public entry point around the ``nvalchemiops::query_cell_list`` custom op.
    Uses the cell-list buffers produced by build_cell_list to find all atom
    pairs within the cutoff distance, handling periodic boundary conditions,
    and writes the result in neighbor-matrix format into the supplied output
    tensors.

    This function is torch compilable.

    Parameters
    ----------
    positions : torch.Tensor, shape (total_atoms, 3)
        Atomic coordinates in Cartesian space.
    cutoff : float
        Maximum distance for considering atoms as neighbors.
    cell : torch.Tensor, shape (1, 3, 3)
        Unit cell matrix for periodic boundary coordinate shifts.
    pbc : torch.Tensor, shape (3,), dtype=bool
        Periodic boundary condition flags.
    cells_per_dimension : torch.Tensor, shape (3,), dtype=int32
        Number of cells in x, y, z directions from build_cell_list.
    neighbor_search_radius : torch.Tensor, shape (3,), dtype=int32
        Shifts to search from build_cell_list.
    atom_periodic_shifts : torch.Tensor, shape (total_atoms, 3), dtype=int32
        Periodic boundary crossings for each atom from build_cell_list.
    atom_to_cell_mapping : torch.Tensor, shape (total_atoms, 3), dtype=int32
        3D cell coordinates for each atom from build_cell_list.
    atoms_per_cell_count : torch.Tensor, shape (max_total_cells,), dtype=int32
        Number of atoms in each cell from build_cell_list.
    cell_atom_start_indices : torch.Tensor, shape (max_total_cells,), dtype=int32
        Starting index in cell_atom_list for each cell from build_cell_list.
    cell_atom_list : torch.Tensor, shape (total_atoms,), dtype=int32
        Flattened list of atom indices organized by cell from build_cell_list.
    neighbor_matrix : torch.Tensor, shape (total_atoms, max_neighbors), dtype=int32
        OUTPUT: Neighbor matrix to be filled with neighbor atom indices.
        Must be pre-allocated.
    neighbor_matrix_shifts : torch.Tensor, shape (total_atoms, max_neighbors, 3), dtype=int32
        OUTPUT: Matrix storing shift vectors for each neighbor relationship.
        Must be pre-allocated.
    num_neighbors : torch.Tensor, shape (total_atoms,), dtype=int32
        OUTPUT: Number of neighbors found for each atom.
        Must be pre-allocated.
    half_fill : bool, default=False
        If True, only store half of the neighbor relationships.
    rebuild_flags : torch.Tensor, shape () or (1,), dtype=torch.bool, optional
        If provided, controls whether the neighbor list is recomputed.
        When the flag is False the kernel is skipped and the pre-allocated
        output tensors are returned unchanged. When the flag is True (or when
        this argument is None) the query proceeds as normal.
        Note: providing this argument disables torch.compile compatibility.

    See Also
    --------
    nvalchemiops.neighbors.cell_list.query_cell_list : Core warp launcher
    build_cell_list : Builds the cell list data structures
    cell_list : High-level function that builds and queries in one call
    """
    # Group the positional arguments by role; the op receives them in the
    # same order as this function's own signature.
    geometry = (positions, cutoff, cell, pbc)
    cell_list_buffers = (
        cells_per_dimension,
        neighbor_search_radius,
        atom_periodic_shifts,
        atom_to_cell_mapping,
        atoms_per_cell_count,
        cell_atom_start_indices,
        cell_atom_list,
    )
    output_buffers = (neighbor_matrix, neighbor_matrix_shifts, num_neighbors)
    return _query_cell_list_op(
        *geometry,
        *cell_list_buffers,
        *output_buffers,
        half_fill,
        rebuild_flags,
    )
[docs]
def cell_list(
positions: torch.Tensor,
cutoff: float,
cell: torch.Tensor,
pbc: torch.Tensor,
max_neighbors: int | None = None,
half_fill: bool = False,
fill_value: int | None = None,
return_neighbor_list: bool = False,
neighbor_matrix: torch.Tensor | None = None,
neighbor_matrix_shifts: torch.Tensor | None = None,
num_neighbors: torch.Tensor | None = None,
cells_per_dimension: torch.Tensor | None = None,
neighbor_search_radius: torch.Tensor | None = None,
atom_periodic_shifts: torch.Tensor | None = None,
atom_to_cell_mapping: torch.Tensor | None = None,
atoms_per_cell_count: torch.Tensor | None = None,
cell_atom_start_indices: torch.Tensor | None = None,
cell_atom_list: torch.Tensor | None = None,
rebuild_flags: torch.Tensor | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor]
):
"""Build complete neighbor matrix using spatial cell list acceleration.
High-level convenience function that automatically estimates memory requirements,
builds spatial cell list data structures, and queries them to produce a complete
neighbor matrix. Combines build_cell_list and query_cell_list operations.
Parameters
----------
positions : torch.Tensor, shape (total_atoms, 3)
Atomic coordinates in Cartesian space where total_atoms is the number of atoms.
cutoff : float
Maximum distance for neighbor search.
cell : torch.Tensor, shape (1, 3, 3)
Unit cell matrix defining the simulation box. Each row represents a
lattice vector in Cartesian coordinates.
pbc : torch.Tensor, shape (1, 3), dtype=bool
Flags indicating periodic boundary conditions in x, y, z directions.
max_neighbors : int, optional
Maximum number of neighbors per atom. If not provided, will be estimated automatically.
half_fill : bool, optional
If True, only fill half of the neighbor matrix. Default is False.
fill_value : int | None, optional
Value to fill the neighbor matrix with. Default is total_atoms.
return_neighbor_list : bool, optional - default = False
If True, convert the neighbor matrix to a neighbor list (idx_i, idx_j) format by
creating a mask over the fill_value, which can incur a performance penalty.
We recommend using the neighbor matrix format,
and only convert to a neighbor list format if absolutely necessary.
neighbor_matrix : torch.Tensor, optional
Pre-allocated tensor of shape (total_atoms, max_neighbors) for neighbor indices.
If None, allocated internally.
neighbor_matrix_shifts : torch.Tensor, optional
Pre-allocated tensor of shape (total_atoms, max_neighbors, 3) for shift vectors.
If None, allocated internally.
num_neighbors : torch.Tensor, optional
Pre-allocated tensor of shape (total_atoms,) for neighbor counts.
If None, allocated internally.
cells_per_dimension : torch.Tensor, shape (3,), dtype=int32, optional
Number of cells in x, y, z directions.
Pass a pre-allocated tensor to avoid reallocation for cell list construction.
If None, allocated internally to build the cell list.
neighbor_search_radius : torch.Tensor, shape (3,), dtype=int32, optional
Radius of neighboring cells to search in each dimension.
Pass a pre-allocated tensor to avoid reallocation for cell list construction.
If None, allocated internally to build the cell list.
atom_periodic_shifts : torch.Tensor, shape (total_atoms, 3), dtype=int32, optional
Periodic boundary crossings for each atom.
Pass a pre-allocated tensor to avoid reallocation for cell list construction.
If None, allocated internally to build the cell list.
atom_to_cell_mapping : torch.Tensor, shape (total_atoms, 3), dtype=int32, optional
Cell coordinates for each atom.
Pass a pre-allocated tensor to avoid reallocation for cell list construction.
If None, allocated internally to build the cell list.
atoms_per_cell_count : torch.Tensor, shape (max_total_cells,), dtype=int32, optional
Number of atoms in each cell.
Pass a pre-allocated tensor to avoid reallocation for cell list construction.
If None, allocated internally to build the cell list.
cell_atom_start_indices : torch.Tensor, shape (max_total_cells,), dtype=int32, optional
Starting index in cell_atom_list for each cell.
Pass a pre-allocated tensor to avoid reallocation for cell list construction.
If None, allocated internally to build the cell list.
cell_atom_list : torch.Tensor, shape (total_atoms,), dtype=int32, optional
Flattened list of atom indices organized by cell.
Pass a pre-allocated tensor to avoid reallocation for cell list construction.
If None, allocated internally to build the cell list.
rebuild_flags : torch.Tensor, shape () or (1,), dtype=torch.bool, optional
If provided, controls whether the neighbor list is recomputed.
When the flag is False the existing ``neighbor_matrix``, ``num_neighbors``,
and ``neighbor_matrix_shifts`` tensors are returned unchanged and all
kernel launches are skipped. When the flag is True (or when this argument
is None) the neighbor list is recomputed as normal.
Returns
-------
results : tuple of torch.Tensor
Variable-length tuple depending on input parameters. The return pattern follows:
- Matrix format (default): ``(neighbor_matrix, num_neighbors, neighbor_matrix_shifts)``
- List format (return_neighbor_list=True): ``(neighbor_list, neighbor_ptr, neighbor_list_shifts)``
Notes
-----
- This is the main user-facing API for cell list neighbor construction
- Uses automatic memory allocation estimation for torch.compile compatibility
- For advanced users who want to cache cell lists, use build_cell_list and query_cell_list separately
- Returns appropriate empty tensors for systems with <= 1 atom or cutoff <= 0
See Also
--------
nvalchemiops.neighbors.cell_list.build_cell_list : Core warp launcher for building
nvalchemiops.neighbors.cell_list.query_cell_list : Core warp launcher for querying
naive_neighbor_list : O(N²) method for small systems
"""
total_atoms = positions.shape[0]
device = positions.device
if pbc is None:
raise ValueError(
"cell_list requires `pbc` to be specified. "
"Pass a boolean tensor of shape (3,) or (1, 3), "
"e.g. pbc=torch.tensor([True, True, True])."
)
cell = cell if cell.ndim == 3 else cell.unsqueeze(0)
pbc = pbc.squeeze(0)
if fill_value is None:
fill_value = total_atoms
# Handle empty case
if total_atoms <= 0 or cutoff <= 0:
if return_neighbor_list:
return (
torch.zeros((2, 0), dtype=torch.int32, device=device),
torch.zeros((total_atoms + 1,), dtype=torch.int32, device=device),
torch.zeros((0, 3), dtype=torch.int32, device=device),
)
else:
return (
torch.full(
(total_atoms, 0), fill_value, dtype=torch.int32, device=device
),
torch.zeros((total_atoms,), dtype=torch.int32, device=device),
torch.zeros((total_atoms, 0, 3), dtype=torch.int32, device=device),
)
if max_neighbors is None and (
neighbor_matrix is None
or neighbor_matrix_shifts is None
or num_neighbors is None
):
max_neighbors = estimate_max_neighbors(cutoff)
if neighbor_matrix is None:
neighbor_matrix = torch.full(
(total_atoms, max_neighbors), fill_value, dtype=torch.int32, device=device
)
elif rebuild_flags is None:
neighbor_matrix.fill_(fill_value)
if neighbor_matrix_shifts is None:
neighbor_matrix_shifts = torch.zeros(
(total_atoms, max_neighbors, 3), dtype=torch.int32, device=device
)
elif rebuild_flags is None:
neighbor_matrix_shifts.zero_()
if num_neighbors is None:
num_neighbors = torch.zeros((total_atoms,), dtype=torch.int32, device=device)
elif rebuild_flags is None:
num_neighbors.zero_()
# Allocate cell list if needed
if (
cells_per_dimension is None
or neighbor_search_radius is None
or atom_periodic_shifts is None
or atom_to_cell_mapping is None
or atoms_per_cell_count is None
or cell_atom_start_indices is None
or cell_atom_list is None
):
max_total_cells, neighbor_search_radius = estimate_cell_list_sizes(
cell, pbc, cutoff
)
cell_list_cache = allocate_cell_list(
total_atoms,
max_total_cells,
neighbor_search_radius,
device,
)
else:
cells_per_dimension.zero_()
atom_periodic_shifts.zero_()
atom_to_cell_mapping.zero_()
atoms_per_cell_count.zero_()
cell_atom_start_indices.zero_()
cell_atom_list.zero_()
cell_list_cache = (
cells_per_dimension,
neighbor_search_radius,
atom_periodic_shifts,
atom_to_cell_mapping,
atoms_per_cell_count,
cell_atom_start_indices,
cell_atom_list,
)
build_cell_list(
positions,
cutoff,
cell,
pbc,
*cell_list_cache,
)
query_cell_list(
positions,
cutoff,
cell,
pbc,
*cell_list_cache,
neighbor_matrix,
neighbor_matrix_shifts,
num_neighbors,
half_fill,
rebuild_flags,
)
if return_neighbor_list:
neighbor_list, neighbor_ptr, neighbor_list_shifts = (
get_neighbor_list_from_neighbor_matrix(
neighbor_matrix,
num_neighbors=num_neighbors,
neighbor_shift_matrix=neighbor_matrix_shifts,
fill_value=fill_value,
)
)
return neighbor_list, neighbor_ptr, neighbor_list_shifts
else:
return neighbor_matrix, num_neighbors, neighbor_matrix_shifts