# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
import warp as wp
from nvalchemiops.neighborlist.neighbor_utils import (
_expand_naive_shifts,
_update_neighbor_matrix,
_update_neighbor_matrix_pbc,
compute_naive_num_shifts,
estimate_max_neighbors,
get_neighbor_list_from_neighbor_matrix,
)
from nvalchemiops.types import get_wp_dtype, get_wp_mat_dtype, get_wp_vec_dtype
###########################################################################################
########################### Naive Neighbor List Kernels ###################################
###########################################################################################
@wp.kernel(enable_backward=False)
def _fill_naive_neighbor_matrix_dual_cutoff(
    positions: wp.array(dtype=Any),
    cutoff1_sq: Any,
    cutoff2_sq: Any,
    neighbor_matrix1: wp.array2d(dtype=wp.int32),
    num_neighbors1: wp.array(dtype=wp.int32),
    neighbor_matrix2: wp.array2d(dtype=wp.int32),
    num_neighbors2: wp.array(dtype=wp.int32),
    half_fill: wp.bool,
) -> None:
    """Calculate two neighbor matrices using dual cutoffs with naive O(N^2) algorithm.

    Launched with one thread per atom. Thread ``i`` compares atom ``i`` with
    every atom ``j > i`` and records the pair in each neighbor matrix whose
    squared cutoff the squared distance falls strictly below. Sharing one
    distance computation between both cutoffs is cheaper than running two
    separate neighbor calculations when both lists are needed.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Atomic coordinates in Cartesian space, one (x, y, z) vector per atom.
    cutoff1_sq : float
        Squared first cutoff distance in Cartesian units (typically the
        smaller cutoff). Must be positive.
    cutoff2_sq : float
        Squared second cutoff distance in Cartesian units (typically the
        larger cutoff). Must be positive. The two cutoffs are tested
        independently, so results are correct whichever one is larger.
    neighbor_matrix1 : wp.array, shape (total_atoms, max_neighbors1), dtype=wp.int32
        OUTPUT: First neighbor matrix for cutoff1 to be filled with atom indices.
        Entries are filled with atom indices, remaining entries stay as initialized.
    num_neighbors1 : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors found for each atom within cutoff1.
        Updated in-place with actual neighbor counts.
    neighbor_matrix2 : wp.array, shape (total_atoms, max_neighbors2), dtype=wp.int32
        OUTPUT: Second neighbor matrix for cutoff2 to be filled with atom indices.
        Entries are filled with atom indices, remaining entries stay as initialized.
    num_neighbors2 : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors found for each atom within cutoff2.
        Updated in-place with actual neighbor counts.
    half_fill : wp.bool
        Forwarded to ``_update_neighbor_matrix``. If True, only store
        relationships where i < j to avoid double counting. If False, store
        all neighbor relationships symmetrically.

    Returns
    -------
    None
        This function modifies the input arrays in-place:
        - neighbor_matrix1 : Filled with neighbor atom indices within cutoff1
        - num_neighbors1 : Updated with neighbor counts per atom for cutoff1
        - neighbor_matrix2 : Filled with neighbor atom indices within cutoff2
        - num_neighbors2 : Updated with neighbor counts per atom for cutoff2

    See Also
    --------
    _fill_naive_neighbor_matrix : Single cutoff version
    _fill_naive_neighbor_matrix_pbc_dual_cutoff : Version with periodic boundaries
    """
    i = wp.tid()
    num_atoms = positions.shape[0]
    positions_i = positions[i]
    maxnb1 = neighbor_matrix1.shape[1]
    maxnb2 = neighbor_matrix2.shape[1]
    # Each unordered pair is visited exactly once: by the thread of its smaller index.
    for j in range(i + 1, num_atoms):
        dist_sq = wp.length_sq(positions_i - positions[j])
        # The cutoffs are checked independently (not nested). A nested check
        # would drop cutoff1 pairs whenever cutoff1 > cutoff2.
        if dist_sq < cutoff2_sq:
            _update_neighbor_matrix(
                i, j, neighbor_matrix2, num_neighbors2, maxnb2, half_fill
            )
        if dist_sq < cutoff1_sq:
            _update_neighbor_matrix(
                i, j, neighbor_matrix1, num_neighbors1, maxnb1, half_fill
            )
@wp.kernel(enable_backward=False)
def _fill_naive_neighbor_matrix_pbc_dual_cutoff(
    positions: wp.array(dtype=Any),
    cutoff1_sq: Any,
    cutoff2_sq: Any,
    cell: wp.array(dtype=Any),
    shifts: wp.array(dtype=wp.vec3i),
    neighbor_matrix1: wp.array(dtype=wp.int32, ndim=2),
    neighbor_matrix2: wp.array(dtype=wp.int32, ndim=2),
    neighbor_matrix_shifts1: wp.array(dtype=wp.vec3i, ndim=2),
    neighbor_matrix_shifts2: wp.array(dtype=wp.vec3i, ndim=2),
    num_neighbors1: wp.array(dtype=wp.int32),
    num_neighbors2: wp.array(dtype=wp.int32),
    half_fill: wp.bool,
) -> None:
    """Calculate two neighbor matrices with periodic boundary conditions using dual cutoffs and naive O(N^2) algorithm.

    Launched on a 2D grid ``(total_shifts, total_atoms)``. Thread
    ``(ishift, iatom)`` translates atom ``iatom`` by lattice shift
    ``shifts[ishift]`` and compares the image with every atom ``jatom``.
    For the zero shift, only ``jatom < iatom`` is visited, so an atom is never
    paired with itself and each in-cell pair is considered once. Both
    cutoffs share one distance computation.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms,), dtype=wp.vec3*
        Atomic coordinates in Cartesian space, one (x, y, z) vector per atom.
    cutoff1_sq : float
        Squared first cutoff distance in Cartesian units (typically the
        smaller cutoff). Must be positive.
    cutoff2_sq : float
        Squared second cutoff distance in Cartesian units (typically the
        larger cutoff). Must be positive. The two cutoffs are tested
        independently, so results are correct whichever one is larger
        (provided ``shifts`` covers the larger cutoff).
    cell : wp.array, dtype=wp.mat33*
        Cell matrix whose rows are the lattice vectors. Only ``cell[0]`` is read.
    shifts : wp.array, shape (total_shifts,), dtype=wp.vec3i
        Integer shift vectors (nx, ny, nz), multiples of the cell vectors.
    neighbor_matrix1 : wp.array, shape (total_atoms, max_neighbors1), dtype=wp.int32
        OUTPUT: First neighbor matrix for cutoff1 to be filled with atom indices.
        Entries are filled with atom indices, remaining entries stay as initialized.
    neighbor_matrix2 : wp.array, shape (total_atoms, max_neighbors2), dtype=wp.int32
        OUTPUT: Second neighbor matrix for cutoff2 to be filled with atom indices.
        Entries are filled with atom indices, remaining entries stay as initialized.
    neighbor_matrix_shifts1 : wp.array, shape (total_atoms, max_neighbors1), dtype=wp.vec3i
        OUTPUT: Shift vector for each entry of neighbor_matrix1.
    neighbor_matrix_shifts2 : wp.array, shape (total_atoms, max_neighbors2), dtype=wp.vec3i
        OUTPUT: Shift vector for each entry of neighbor_matrix2.
    num_neighbors1 : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors found for each atom within cutoff1.
    num_neighbors2 : wp.array, shape (total_atoms,), dtype=wp.int32
        OUTPUT: Number of neighbors found for each atom within cutoff2.
    half_fill : wp.bool
        Forwarded to ``_update_neighbor_matrix_pbc``. If True, only store
        relationships where i < j to avoid double counting. If False, store
        all neighbor relationships symmetrically.

    Returns
    -------
    None
        This function modifies the input arrays in-place:
        - neighbor_matrix1 / neighbor_matrix_shifts1 / num_neighbors1 for cutoff1
        - neighbor_matrix2 / neighbor_matrix_shifts2 / num_neighbors2 for cutoff2

    See Also
    --------
    _fill_naive_neighbor_matrix_dual_cutoff : Version without periodic boundary conditions
    _fill_naive_neighbor_matrix_pbc : Single cutoff PBC version
    """
    ishift, iatom = wp.tid()
    jatom_start = 0
    jatom_end = positions.shape[0]
    maxnb1 = neighbor_matrix1.shape[1]
    maxnb2 = neighbor_matrix2.shape[1]
    # Cartesian position of atom iatom translated by the integer lattice shift:
    # transpose(cell) * shift == sum_k shift[k] * cell[k].
    _positions = positions[iatom]
    _cell = cell[0]
    _shift = shifts[ishift]
    positions_shifted = wp.transpose(_cell) * type(_cell[0])(_shift) + _positions
    # Within the home cell, restrict to jatom < iatom (no self-pairs, no duplicates).
    _zero_shift = _shift[0] == 0 and _shift[1] == 0 and _shift[2] == 0
    if _zero_shift:
        jatom_end = iatom
    for jatom in range(jatom_start, jatom_end):
        dist_sq = wp.length_sq(positions_shifted - positions[jatom])
        # The cutoffs are checked independently (not nested). A nested check
        # would drop cutoff1 pairs whenever cutoff1 > cutoff2.
        if dist_sq < cutoff2_sq:
            _update_neighbor_matrix_pbc(
                jatom,
                iatom,
                neighbor_matrix2,
                neighbor_matrix_shifts2,
                num_neighbors2,
                _shift,
                maxnb2,
                half_fill,
            )
        if dist_sq < cutoff1_sq:
            _update_neighbor_matrix_pbc(
                jatom,
                iatom,
                neighbor_matrix1,
                neighbor_matrix_shifts1,
                num_neighbors1,
                _shift,
                maxnb1,
                half_fill,
            )
# Generate one concrete overload of each dual-cutoff kernel per floating-point
# precision. Index k of T, V and M describes one precision: the scalar type,
# the matching vec3 type and the matching mat33 type. The resulting dicts are
# keyed by the scalar type.
T = [wp.float32, wp.float64, wp.float16]
V = [wp.vec3f, wp.vec3d, wp.vec3h]
M = [wp.mat33f, wp.mat33d, wp.mat33h]
_fill_naive_neighbor_matrix_dual_cutoff_overload = {}
_fill_naive_neighbor_matrix_pbc_dual_cutoff_overload = {}
for t, v, m in zip(T, V, M):
    # Argument types, in the exact order of _fill_naive_neighbor_matrix_dual_cutoff.
    _no_pbc_arg_types = [
        wp.array(dtype=v),  # positions
        t,  # cutoff1_sq
        t,  # cutoff2_sq
        wp.array2d(dtype=wp.int32),  # neighbor_matrix1
        wp.array(dtype=wp.int32),  # num_neighbors1
        wp.array2d(dtype=wp.int32),  # neighbor_matrix2
        wp.array(dtype=wp.int32),  # num_neighbors2
        wp.bool,  # half_fill
    ]
    # Argument types, in the exact order of _fill_naive_neighbor_matrix_pbc_dual_cutoff.
    _pbc_arg_types = [
        wp.array(dtype=v),  # positions
        t,  # cutoff1_sq
        t,  # cutoff2_sq
        wp.array(dtype=m),  # cell
        wp.array(dtype=wp.vec3i),  # shifts
        wp.array2d(dtype=wp.int32),  # neighbor_matrix1
        wp.array2d(dtype=wp.int32),  # neighbor_matrix2
        wp.array2d(dtype=wp.vec3i),  # neighbor_matrix_shifts1
        wp.array2d(dtype=wp.vec3i),  # neighbor_matrix_shifts2
        wp.array(dtype=wp.int32),  # num_neighbors1
        wp.array(dtype=wp.int32),  # num_neighbors2
        wp.bool,  # half_fill
    ]
    _fill_naive_neighbor_matrix_dual_cutoff_overload[t] = wp.overload(
        _fill_naive_neighbor_matrix_dual_cutoff, _no_pbc_arg_types
    )
    _fill_naive_neighbor_matrix_pbc_dual_cutoff_overload[t] = wp.overload(
        _fill_naive_neighbor_matrix_pbc_dual_cutoff, _pbc_arg_types
    )
################################################################################
########################### Dual cutoff ########################################
################################################################################
@torch.library.custom_op(
    "nvalchemiops::_naive_neighbor_matrix_no_pbc_no_alloc_dual_cutoff",
    mutates_args=(
        "neighbor_matrix1",
        "num_neighbors1",
        "neighbor_matrix2",
        "num_neighbors2",
    ),
)
def _naive_neighbor_matrix_no_pbc_dual_cutoff(
    positions: torch.Tensor,
    cutoff1: float,
    cutoff2: float,
    neighbor_matrix1: torch.Tensor,
    num_neighbors1: torch.Tensor,
    neighbor_matrix2: torch.Tensor,
    num_neighbors2: torch.Tensor,
    half_fill: bool = False,
) -> None:
    """Fill two neighbor matrices for atoms using dual cutoffs with naive O(N^2) algorithm.

    Custom PyTorch operator that computes pairwise distances and fills two
    neighbor matrices with atom indices within different cutoff distances
    simultaneously. No periodic boundary conditions are applied.

    This function is torch compilable.
    This function does not allocate any tensors: the caller's buffers are
    wrapped as Warp arrays and written in place.

    Parameters
    ----------
    positions : torch.Tensor, shape (total_atoms, 3), dtype=torch.float32 or torch.float64
        Atomic coordinates in Cartesian space. Each row represents one atom's
        (x, y, z) position.
    cutoff1 : float
        First cutoff distance in Cartesian units (typically the smaller cutoff).
        Must be positive. Squared before being passed to the kernel.
    cutoff2 : float
        Second cutoff distance in Cartesian units (typically the larger cutoff).
        Must be positive. Squared before being passed to the kernel.
    neighbor_matrix1 : torch.Tensor, shape (total_atoms, max_neighbors1), dtype=torch.int32
        OUTPUT: First neighbor matrix for cutoff1 to be filled with atom indices.
        Must be pre-allocated (and pre-filled with the fill value by the caller).
        Its second dimension bounds how many neighbors are stored per atom.
    num_neighbors1 : torch.Tensor, shape (total_atoms,), dtype=torch.int32
        OUTPUT: Number of neighbors found for each atom within cutoff1.
        Must be pre-allocated (and zeroed by the caller).
    neighbor_matrix2 : torch.Tensor, shape (total_atoms, max_neighbors2), dtype=torch.int32
        OUTPUT: Second neighbor matrix for cutoff2 to be filled with atom indices.
        Must be pre-allocated (and pre-filled with the fill value by the caller).
        Its second dimension bounds how many neighbors are stored per atom.
    num_neighbors2 : torch.Tensor, shape (total_atoms,), dtype=torch.int32
        OUTPUT: Number of neighbors found for each atom within cutoff2.
        Must be pre-allocated (and zeroed by the caller).
    half_fill : bool
        If True, only store relationships where i < j to avoid double counting.
        If False, store all neighbor relationships symmetrically.
        Default is False.

    Returns
    -------
    None
        This function modifies the input tensors in-place:
        - neighbor_matrix1 : Filled with neighbor atom indices within cutoff1
        - num_neighbors1 : Updated with neighbor counts per atom for cutoff1
        - neighbor_matrix2 : Filled with neighbor atom indices within cutoff2
        - num_neighbors2 : Updated with neighbor counts per atom for cutoff2

    See Also
    --------
    naive_neighbor_list_dual_cutoff : Higher-level wrapper function
    _naive_neighbor_matrix_pbc_dual_cutoff : Periodic-boundary version
    """
    device = positions.device
    wp_vec_dtype = get_wp_vec_dtype(positions.dtype)
    wp_dtype = get_wp_dtype(positions.dtype)
    # Wrap the torch buffers as Warp arrays so the kernel writes into them directly.
    wp_positions = wp.from_torch(positions, dtype=wp_vec_dtype, return_ctype=True)
    wp_neighbor_matrix1 = wp.from_torch(
        neighbor_matrix1, dtype=wp.int32, return_ctype=True
    )
    wp_num_neighbors1 = wp.from_torch(num_neighbors1, dtype=wp.int32, return_ctype=True)
    wp_neighbor_matrix2 = wp.from_torch(
        neighbor_matrix2, dtype=wp.int32, return_ctype=True
    )
    wp_num_neighbors2 = wp.from_torch(num_neighbors2, dtype=wp.int32, return_ctype=True)
    # One thread per atom; the kernel compares squared distances.
    wp.launch(
        kernel=_fill_naive_neighbor_matrix_dual_cutoff_overload[wp_dtype],
        dim=positions.shape[0],
        inputs=[
            wp_positions,
            wp_dtype(cutoff1 * cutoff1),
            wp_dtype(cutoff2 * cutoff2),
            wp_neighbor_matrix1,
            wp_num_neighbors1,
            wp_neighbor_matrix2,
            wp_num_neighbors2,
            half_fill,
        ],
        device=wp.device_from_torch(device),
    )
@torch.library.custom_op(
    "nvalchemiops::_naive_neighbor_matrix_pbc_dual_cutoff",
    mutates_args=(
        "neighbor_matrix1",
        "neighbor_matrix2",
        "neighbor_matrix_shifts1",
        "neighbor_matrix_shifts2",
        "num_neighbors1",
        "num_neighbors2",
    ),
)
def _naive_neighbor_matrix_pbc_dual_cutoff(
    positions: torch.Tensor,
    cutoff1: float,
    cutoff2: float,
    cell: torch.Tensor,
    neighbor_matrix1: torch.Tensor,
    neighbor_matrix2: torch.Tensor,
    neighbor_matrix_shifts1: torch.Tensor,
    neighbor_matrix_shifts2: torch.Tensor,
    num_neighbors1: torch.Tensor,
    num_neighbors2: torch.Tensor,
    shift_range_per_dimension: torch.Tensor,
    shift_offset: torch.Tensor,
    total_shifts: int,
    half_fill: bool = False,
) -> None:
    """Compute two neighbor matrices with periodic boundary conditions using dual cutoffs and naive O(N^2) algorithm.

    Custom PyTorch operator that computes neighbor relationships between atoms
    across periodic boundaries for two different cutoff distances
    simultaneously. The periodicity is encoded entirely in the precomputed
    shift description (``shift_range_per_dimension``, ``shift_offset``,
    ``total_shifts``). That description is expanded into explicit shift
    vectors on the device, which keeps the op torch compilable.

    Parameters
    ----------
    positions : torch.Tensor, shape (total_atoms, 3), dtype=torch.float32 or torch.float64
        Atomic coordinates in Cartesian space. Each row represents one atom's
        (x, y, z) position.
    cutoff1 : float
        First cutoff distance in Cartesian units (typically the smaller cutoff).
        Must be positive.
    cutoff2 : float
        Second cutoff distance in Cartesian units (typically the larger cutoff).
        Must be positive.
    cell : torch.Tensor, shape (1, 3, 3), dtype matching positions
        Cell matrix whose rows are the lattice vectors. Only the first matrix is used.
    neighbor_matrix1 : torch.Tensor, shape (total_atoms, max_neighbors1), dtype=torch.int32
        OUTPUT: First neighbor matrix for cutoff1 to be filled with atom indices.
        Must be pre-allocated (and pre-filled with the fill value by the caller).
    neighbor_matrix2 : torch.Tensor, shape (total_atoms, max_neighbors2), dtype=torch.int32
        OUTPUT: Second neighbor matrix for cutoff2 to be filled with atom indices.
        Must be pre-allocated (and pre-filled with the fill value by the caller).
    neighbor_matrix_shifts1 : torch.Tensor, shape (total_atoms, max_neighbors1, 3), dtype=torch.int32
        OUTPUT: Shift vectors for each neighbor relationship in matrix1.
        Must be pre-allocated.
    neighbor_matrix_shifts2 : torch.Tensor, shape (total_atoms, max_neighbors2, 3), dtype=torch.int32
        OUTPUT: Shift vectors for each neighbor relationship in matrix2.
        Must be pre-allocated.
    num_neighbors1 : torch.Tensor, shape (total_atoms,), dtype=torch.int32
        OUTPUT: Number of neighbors found for each atom within cutoff1.
        Must be pre-allocated (and zeroed by the caller).
    num_neighbors2 : torch.Tensor, shape (total_atoms,), dtype=torch.int32
        OUTPUT: Number of neighbors found for each atom within cutoff2.
        Must be pre-allocated (and zeroed by the caller).
    shift_range_per_dimension : torch.Tensor, shape (1, 3), dtype=torch.int32
        Shift range in each dimension for each system.
    shift_offset : torch.Tensor, shape (2,), dtype=torch.int32
        Cumulative sum of number of shifts for each system.
    total_shifts : int
        Total number of shifts. Should be computed for the larger of the two cutoffs.
    half_fill : bool
        If True, only store relationships where i < j to avoid double counting.
        If False, store all neighbor relationships symmetrically.
        Default is False.

    See Also
    --------
    naive_neighbor_list_dual_cutoff : Higher-level wrapper function
    compute_naive_num_shifts : Computes shift_range_per_dimension, shift_offset and total_shifts
    """
    total_atoms = positions.shape[0]
    device = positions.device
    wp_device = wp.device_from_torch(device)
    wp_vec_dtype = get_wp_vec_dtype(positions.dtype)
    wp_mat_dtype = get_wp_mat_dtype(positions.dtype)
    wp_dtype = get_wp_dtype(positions.dtype)
    wp_positions = wp.from_torch(positions, dtype=wp_vec_dtype)
    wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype)

    # Expand the compact shift description into explicit integer shift vectors.
    shifts = torch.empty((total_shifts, 3), dtype=torch.int32, device=device)
    wp_shifts = wp.from_torch(shifts, dtype=wp.vec3i, return_ctype=True)
    shift_system_idx = torch.empty((total_shifts,), dtype=torch.int32, device=device)
    wp_shift_system_idx = wp.from_torch(
        shift_system_idx, dtype=wp.int32, return_ctype=True
    )
    wp.launch(
        kernel=_expand_naive_shifts,
        dim=1,
        inputs=[
            wp.from_torch(shift_range_per_dimension, dtype=wp.vec3i, return_ctype=True),
            wp.from_torch(shift_offset, dtype=wp.int32, return_ctype=True),
            wp_shifts,
            wp_shift_system_idx,
        ],
        device=wp_device,
    )

    # Wrap the caller's output buffers as Warp arrays (written in place by the kernel).
    wp_neighbor_matrix1 = wp.from_torch(
        neighbor_matrix1, dtype=wp.int32, return_ctype=True
    )
    wp_neighbor_matrix2 = wp.from_torch(
        neighbor_matrix2, dtype=wp.int32, return_ctype=True
    )
    wp_num_neighbors1 = wp.from_torch(num_neighbors1, dtype=wp.int32, return_ctype=True)
    wp_num_neighbors2 = wp.from_torch(num_neighbors2, dtype=wp.int32, return_ctype=True)
    wp_neighbor_matrix_shifts1 = wp.from_torch(
        neighbor_matrix_shifts1, dtype=wp.vec3i, return_ctype=True
    )
    wp_neighbor_matrix_shifts2 = wp.from_torch(
        neighbor_matrix_shifts2, dtype=wp.vec3i, return_ctype=True
    )
    # 2D launch: one thread per (shift, atom) combination.
    wp.launch(
        kernel=_fill_naive_neighbor_matrix_pbc_dual_cutoff_overload[wp_dtype],
        dim=(total_shifts, total_atoms),
        inputs=[
            wp_positions,
            wp_dtype(cutoff1 * cutoff1),
            wp_dtype(cutoff2 * cutoff2),
            wp_cell,
            wp_shifts,
            wp_neighbor_matrix1,
            wp_neighbor_matrix2,
            wp_neighbor_matrix_shifts1,
            wp_neighbor_matrix_shifts2,
            wp_num_neighbors1,
            wp_num_neighbors2,
            half_fill,
        ],
        device=wp_device,
    )
def naive_neighbor_list_dual_cutoff(
    positions: torch.Tensor,
    cutoff1: float,
    cutoff2: float,
    pbc: torch.Tensor | None = None,
    cell: torch.Tensor | None = None,
    max_neighbors1: int | None = None,
    max_neighbors2: int | None = None,
    half_fill: bool = False,
    fill_value: int | None = None,
    return_neighbor_list: bool = False,
    neighbor_matrix1: torch.Tensor | None = None,
    neighbor_matrix2: torch.Tensor | None = None,
    neighbor_matrix_shifts1: torch.Tensor | None = None,
    neighbor_matrix_shifts2: torch.Tensor | None = None,
    num_neighbors1: torch.Tensor | None = None,
    num_neighbors2: torch.Tensor | None = None,
    shift_range_per_dimension: torch.Tensor | None = None,
    shift_offset: torch.Tensor | None = None,
    total_shifts: int | None = None,
) -> (
    tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]
    | tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]
    | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
):
    """Compute neighbor list using naive O(N^2) algorithm with dual cutoffs.

    Identifies all atom pairs within two different cutoff distances using a
    single brute-force pairwise distance calculation. This is more efficient
    than running two separate neighbor calculations when both neighbor lists
    are needed. Supports both non-periodic and periodic boundary conditions.

    For non-pbc systems, this function is torch compilable. For pbc systems,
    precompute the shift description with ``compute_naive_num_shifts`` (using
    the larger of the two cutoffs) to maintain torch compilability.

    Parameters
    ----------
    positions : torch.Tensor, shape (total_atoms, 3), dtype=torch.float32 or torch.float64
        Atomic coordinates in Cartesian space. Each row represents one atom's
        (x, y, z) position.
    cutoff1 : float
        First cutoff distance in Cartesian units (typically the smaller cutoff).
        Must be positive. Atoms within this distance are considered neighbors.
    cutoff2 : float
        Second cutoff distance in Cartesian units (typically the larger cutoff).
        Must be positive. Either cutoff may be the larger one.
    pbc : torch.Tensor, shape (3,) or (1, 3), dtype=torch.bool, optional
        Periodic boundary condition flags for each dimension.
        True enables periodicity in that direction. Default is None (no PBC).
    cell : torch.Tensor, shape (3, 3) or (1, 3, 3), dtype matching positions, optional
        Cell matrix whose rows are the lattice vectors.
        Required if pbc is provided. Default is None.
    max_neighbors1 : int, optional
        Maximum number of neighbors per atom for the first neighbor matrix.
        Must be positive. If exceeded, excess neighbors are ignored. If None
        and some output buffer must be allocated, it is estimated with
        ``estimate_max_neighbors`` from the larger cutoff.
    max_neighbors2 : int, optional
        Maximum number of neighbors per atom for the second neighbor matrix.
        If None, defaults to max_neighbors1. Must be positive if provided.
    half_fill : bool, optional
        If True, only store relationships where i < j to avoid double counting.
        If False, store all neighbor relationships symmetrically. Default is False.
    fill_value : int | None, optional
        Value to fill the neighbor matrices with. Default is total_atoms.
    return_neighbor_list : bool, optional - default = False
        If True, convert the neighbor matrix to a neighbor list (idx_i, idx_j) format by
        creating a mask over the fill_value, which can incur a performance penalty.
        We recommend using the neighbor matrix format,
        and only convert to a neighbor list format if absolutely necessary.
    neighbor_matrix1 : torch.Tensor, shape (total_atoms, max_neighbors1), dtype=torch.int32, optional
        Preallocated first neighbor matrix for cutoff1. Reset to fill_value
        before use. Allocated if None.
    neighbor_matrix2 : torch.Tensor, shape (total_atoms, max_neighbors2), dtype=torch.int32, optional
        Preallocated second neighbor matrix for cutoff2. Reset to fill_value
        before use. Allocated if None.
    neighbor_matrix_shifts1 : torch.Tensor, shape (total_atoms, max_neighbors1, 3), dtype=torch.int32, optional
        Preallocated shift vectors for the first matrix (PBC only). Zeroed
        before use. Allocated if None.
    neighbor_matrix_shifts2 : torch.Tensor, shape (total_atoms, max_neighbors2, 3), dtype=torch.int32, optional
        Preallocated shift vectors for the second matrix (PBC only). Zeroed
        before use. Allocated if None.
    num_neighbors1 : torch.Tensor, shape (total_atoms,), dtype=torch.int32, optional
        Preallocated neighbor counts within cutoff1. Zeroed before use.
        Allocated if None.
    num_neighbors2 : torch.Tensor, shape (total_atoms,), dtype=torch.int32, optional
        Preallocated neighbor counts within cutoff2. Zeroed before use.
        Allocated if None.
    shift_range_per_dimension : torch.Tensor, shape (1, 3), dtype=torch.int32, optional
        Shift range in each dimension for each system.
    shift_offset : torch.Tensor, shape (2,), dtype=torch.int32, optional
        Cumulative sum of number of shifts for each system.
    total_shifts : int, optional
        Total number of shifts. If any of the three shift arguments is None,
        all three are recomputed from the larger cutoff.

    Returns
    -------
    results : tuple of torch.Tensor
        Variable-length tuple with interleaved results for cutoff1 and cutoff2. The return pattern follows:

        - No PBC, matrix format: ``(neighbor_matrix1, num_neighbors1, neighbor_matrix2, num_neighbors2)``
        - No PBC, list format: ``(neighbor_list1, neighbor_ptr1, neighbor_list2, neighbor_ptr2)``
        - With PBC, matrix format: ``(neighbor_matrix1, num_neighbors1, neighbor_matrix_shifts1, neighbor_matrix2, num_neighbors2, neighbor_matrix_shifts2)``
        - With PBC, list format: ``(neighbor_list1, neighbor_ptr1, neighbor_list_shifts1, neighbor_list2, neighbor_ptr2, neighbor_list_shifts2)``

        **Components returned (interleaved for each cutoff):**

        - **neighbor_data1, neighbor_data2** (tensors): Neighbor indices, format depends on ``return_neighbor_list``:

          * If ``return_neighbor_list=False`` (default): Returns ``neighbor_matrix1`` and ``neighbor_matrix2``
            with shapes (total_atoms, max_neighbors1) and (total_atoms, max_neighbors2), dtype int32.
            Each row i contains indices of atom i's neighbors within the respective cutoff.
          * If ``return_neighbor_list=True``: Returns ``neighbor_list1`` and ``neighbor_list2`` with shapes
            (2, num_pairs1) and (2, num_pairs2), dtype int32, in COO format [source_atoms, target_atoms].

        - **num_neighbor_data1, num_neighbor_data2** (tensors): Information about the number of neighbors for each atom,
          format depends on ``return_neighbor_list``:

          * If ``return_neighbor_list=False`` (default): Returns ``num_neighbors1`` and ``num_neighbors2`` with shape (total_atoms,), dtype int32.
            Count of neighbors found for each atom within cutoff1 and cutoff2 respectively.
          * If ``return_neighbor_list=True``: Returns ``neighbor_ptr1`` and ``neighbor_ptr2`` with shape (total_atoms + 1,), dtype int32.
            CSR-style pointer arrays where ``neighbor_ptr_data[i]`` to ``neighbor_ptr_data[i+1]`` gives the range of
            neighbors for atom i in the flattened neighbor list.

        - **neighbor_shift_data1, neighbor_shift_data2** (tensors, optional): Periodic shift vectors, only when ``pbc`` is provided.
          Format depends on ``return_neighbor_list``:

          * If ``return_neighbor_list=False`` (default): Returns ``neighbor_matrix_shifts1`` and ``neighbor_matrix_shifts2`` with
            shape (total_atoms, max_neighbors1, 3) and (total_atoms, max_neighbors2, 3), dtype int32.
          * If ``return_neighbor_list=True``: Returns ``unit_shifts1`` and ``unit_shifts2`` with shapes
            (num_pairs1, 3) and (num_pairs2, 3), dtype int32.

    Raises
    ------
    ValueError
        If exactly one of ``pbc`` and ``cell`` is provided.

    Examples
    --------
    Basic usage with dual cutoffs:

    >>> import torch
    >>> positions = torch.rand(100, 3) * 10.0  # 100 atoms in 10x10x10 box
    >>> cutoff1 = 2.0  # Short-range interactions
    >>> cutoff2 = 4.0  # Long-range interactions
    >>> max_neighbors1, max_neighbors2 = 20, 50
    >>> results = naive_neighbor_list_dual_cutoff(
    ...     positions, cutoff1, cutoff2, max_neighbors1=max_neighbors1, max_neighbors2=max_neighbors2
    ... )
    >>> neighbor_matrix1, num_neighbors1, neighbor_matrix2, num_neighbors2 = results

    With periodic boundary conditions:

    >>> cell = torch.eye(3).unsqueeze(0) * 10.0  # 10x10x10 cubic cell
    >>> pbc = torch.tensor([[True, True, True]])  # Periodic in all directions
    >>> results = naive_neighbor_list_dual_cutoff(
    ...     positions, cutoff1, cutoff2, max_neighbors1=max_neighbors1, max_neighbors2=max_neighbors2,
    ...     pbc=pbc, cell=cell
    ... )
    >>> (neighbor_matrix1, num_neighbors1, shifts1,
    ...  neighbor_matrix2, num_neighbors2, shifts2) = results

    Preallocate tensors for non-pbc systems:

    >>> n = positions.shape[0]
    >>> neighbor_matrix1 = torch.zeros((n, max_neighbors1), dtype=torch.int32, device=positions.device)
    >>> neighbor_matrix2 = torch.zeros((n, max_neighbors2), dtype=torch.int32, device=positions.device)
    >>> num_neighbors1 = torch.zeros(n, dtype=torch.int32, device=positions.device)
    >>> num_neighbors2 = torch.zeros(n, dtype=torch.int32, device=positions.device)
    >>> _ = naive_neighbor_list_dual_cutoff(
    ...     positions, cutoff1, cutoff2,
    ...     max_neighbors1=max_neighbors1, max_neighbors2=max_neighbors2,
    ...     neighbor_matrix1=neighbor_matrix1, neighbor_matrix2=neighbor_matrix2,
    ...     num_neighbors1=num_neighbors1, num_neighbors2=num_neighbors2,
    ... )

    Preallocate tensors for pbc systems (shifts must cover the larger cutoff):

    >>> shift_range_per_dimension, shift_offset, total_shifts = compute_naive_num_shifts(
    ...     cell, max(cutoff1, cutoff2), pbc
    ... )
    >>> neighbor_matrix_shifts1 = torch.zeros((n, max_neighbors1, 3), dtype=torch.int32, device=positions.device)
    >>> neighbor_matrix_shifts2 = torch.zeros((n, max_neighbors2, 3), dtype=torch.int32, device=positions.device)
    >>> _ = naive_neighbor_list_dual_cutoff(
    ...     positions, cutoff1, cutoff2, pbc=pbc, cell=cell,
    ...     max_neighbors1=max_neighbors1, max_neighbors2=max_neighbors2,
    ...     neighbor_matrix1=neighbor_matrix1, neighbor_matrix2=neighbor_matrix2,
    ...     neighbor_matrix_shifts1=neighbor_matrix_shifts1,
    ...     neighbor_matrix_shifts2=neighbor_matrix_shifts2,
    ...     num_neighbors1=num_neighbors1, num_neighbors2=num_neighbors2,
    ...     shift_range_per_dimension=shift_range_per_dimension,
    ...     shift_offset=shift_offset, total_shifts=total_shifts,
    ... )

    See Also
    --------
    naive_neighbor_list : Single cutoff version
    batch_neighbor_list_dual_cutoff : Batch version for multiple systems
    """
    if pbc is None and cell is not None:
        raise ValueError("If cell is provided, pbc must also be provided")
    if pbc is not None and cell is None:
        raise ValueError("If pbc is provided, cell must also be provided")
    if cell is not None:
        cell = cell if cell.ndim == 3 else cell.unsqueeze(0)
    if pbc is not None:
        pbc = pbc if pbc.ndim == 2 else pbc.unsqueeze(0)

    total_atoms = positions.shape[0]
    device = positions.device
    # cutoff2 is only *typically* the larger cutoff; periodic images and the
    # default capacity estimate must cover whichever one is larger.
    max_cutoff = max(cutoff1, cutoff2)

    if fill_value is None:
        fill_value = total_atoms

    needs_allocation = (
        neighbor_matrix1 is None
        or neighbor_matrix2 is None
        or num_neighbors1 is None
        or num_neighbors2 is None
        or (
            pbc is not None
            and (neighbor_matrix_shifts1 is None or neighbor_matrix_shifts2 is None)
        )
    )
    if max_neighbors1 is None and needs_allocation:
        max_neighbors1 = estimate_max_neighbors(max_cutoff)
    # Documented default: max_neighbors2 falls back to max_neighbors1.
    # An explicitly provided max_neighbors2 is never overridden.
    if max_neighbors2 is None:
        max_neighbors2 = max_neighbors1

    # Allocate missing buffers, or reset provided ones so stale results do not leak through.
    if neighbor_matrix1 is None:
        neighbor_matrix1 = torch.full(
            (total_atoms, max_neighbors1),
            fill_value,
            dtype=torch.int32,
            device=device,
        )
    else:
        neighbor_matrix1.fill_(fill_value)
    if num_neighbors1 is None:
        num_neighbors1 = torch.zeros(total_atoms, dtype=torch.int32, device=device)
    else:
        num_neighbors1.zero_()
    if neighbor_matrix2 is None:
        neighbor_matrix2 = torch.full(
            (total_atoms, max_neighbors2),
            fill_value,
            dtype=torch.int32,
            device=device,
        )
    else:
        neighbor_matrix2.fill_(fill_value)
    if num_neighbors2 is None:
        num_neighbors2 = torch.zeros(total_atoms, dtype=torch.int32, device=device)
    else:
        num_neighbors2.zero_()

    if pbc is not None:
        if neighbor_matrix_shifts1 is None:
            neighbor_matrix_shifts1 = torch.zeros(
                (total_atoms, max_neighbors1, 3),
                dtype=torch.int32,
                device=device,
            )
        else:
            neighbor_matrix_shifts1.zero_()
        if neighbor_matrix_shifts2 is None:
            neighbor_matrix_shifts2 = torch.zeros(
                (total_atoms, max_neighbors2, 3),
                dtype=torch.int32,
                device=device,
            )
        else:
            neighbor_matrix_shifts2.zero_()
        if (
            total_shifts is None
            or shift_offset is None
            or shift_range_per_dimension is None
        ):
            shift_range_per_dimension, shift_offset, total_shifts = (
                compute_naive_num_shifts(cell, max_cutoff, pbc)
            )

    if pbc is None:
        _naive_neighbor_matrix_no_pbc_dual_cutoff(
            positions=positions,
            cutoff1=cutoff1,
            cutoff2=cutoff2,
            neighbor_matrix1=neighbor_matrix1,
            num_neighbors1=num_neighbors1,
            neighbor_matrix2=neighbor_matrix2,
            num_neighbors2=num_neighbors2,
            half_fill=half_fill,
        )
        if return_neighbor_list:
            neighbor_list1, neighbor_ptr1 = get_neighbor_list_from_neighbor_matrix(
                neighbor_matrix1, num_neighbors=num_neighbors1, fill_value=fill_value
            )
            neighbor_list2, neighbor_ptr2 = get_neighbor_list_from_neighbor_matrix(
                neighbor_matrix2, num_neighbors=num_neighbors2, fill_value=fill_value
            )
            return (
                neighbor_list1,
                neighbor_ptr1,
                neighbor_list2,
                neighbor_ptr2,
            )
        return (
            neighbor_matrix1,
            num_neighbors1,
            neighbor_matrix2,
            num_neighbors2,
        )

    _naive_neighbor_matrix_pbc_dual_cutoff(
        positions=positions,
        cutoff1=cutoff1,
        cutoff2=cutoff2,
        cell=cell,
        neighbor_matrix1=neighbor_matrix1,
        neighbor_matrix2=neighbor_matrix2,
        neighbor_matrix_shifts1=neighbor_matrix_shifts1,
        neighbor_matrix_shifts2=neighbor_matrix_shifts2,
        num_neighbors1=num_neighbors1,
        num_neighbors2=num_neighbors2,
        shift_range_per_dimension=shift_range_per_dimension,
        shift_offset=shift_offset,
        total_shifts=total_shifts,
        half_fill=half_fill,
    )
    if return_neighbor_list:
        neighbor_list1, neighbor_ptr1, unit_shifts1 = (
            get_neighbor_list_from_neighbor_matrix(
                neighbor_matrix1,
                num_neighbors=num_neighbors1,
                neighbor_shift_matrix=neighbor_matrix_shifts1,
                fill_value=fill_value,
            )
        )
        neighbor_list2, neighbor_ptr2, unit_shifts2 = (
            get_neighbor_list_from_neighbor_matrix(
                neighbor_matrix2,
                num_neighbors=num_neighbors2,
                neighbor_shift_matrix=neighbor_matrix_shifts2,
                fill_value=fill_value,
            )
        )
        return (
            neighbor_list1,
            neighbor_ptr1,
            unit_shifts1,
            neighbor_list2,
            neighbor_ptr2,
            unit_shifts2,
        )
    return (
        neighbor_matrix1,
        num_neighbors1,
        neighbor_matrix_shifts1,
        neighbor_matrix2,
        num_neighbors2,
        neighbor_matrix_shifts2,
    )