# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch bindings for unbatched naive dual cutoff neighbor list construction."""
from __future__ import annotations
import torch
import warp as wp
from nvalchemiops.neighbors.naive_dual_cutoff import (
naive_neighbor_matrix_dual_cutoff,
naive_neighbor_matrix_pbc_dual_cutoff,
)
from nvalchemiops.neighbors.neighbor_utils import (
estimate_max_neighbors,
selective_zero_num_neighbors_single,
)
from nvalchemiops.torch.neighbors.neighbor_utils import (
compute_naive_num_shifts,
get_neighbor_list_from_neighbor_matrix,
)
from nvalchemiops.torch.types import get_wp_dtype, get_wp_mat_dtype, get_wp_vec_dtype
__all__ = ["naive_neighbor_list_dual_cutoff"]
@torch.library.custom_op(
"nvalchemiops::_naive_neighbor_matrix_no_pbc_dual_cutoff",
mutates_args=(
"neighbor_matrix1",
"num_neighbors1",
"neighbor_matrix2",
"num_neighbors2",
),
)
def _naive_neighbor_matrix_no_pbc_dual_cutoff(
positions: torch.Tensor,
cutoff1: float,
cutoff2: float,
neighbor_matrix1: torch.Tensor,
num_neighbors1: torch.Tensor,
neighbor_matrix2: torch.Tensor,
num_neighbors2: torch.Tensor,
half_fill: bool = False,
) -> None:
"""Fill two neighbor matrices for atoms using dual cutoffs with naive O(N^2) algorithm.
This function is torch compilable.
See Also
--------
nvalchemiops.neighbors.naive_dual_cutoff.naive_neighbor_matrix_dual_cutoff : Core warp launcher
naive_neighbor_list_dual_cutoff : High-level wrapper function
"""
device = positions.device
wp_vec_dtype = get_wp_vec_dtype(positions.dtype)
wp_dtype = get_wp_dtype(positions.dtype)
wp_positions = wp.from_torch(positions, dtype=wp_vec_dtype, return_ctype=True)
wp_neighbor_matrix1 = wp.from_torch(
neighbor_matrix1, dtype=wp.int32, return_ctype=True
)
wp_num_neighbors1 = wp.from_torch(num_neighbors1, dtype=wp.int32, return_ctype=True)
wp_neighbor_matrix2 = wp.from_torch(
neighbor_matrix2, dtype=wp.int32, return_ctype=True
)
wp_num_neighbors2 = wp.from_torch(num_neighbors2, dtype=wp.int32, return_ctype=True)
naive_neighbor_matrix_dual_cutoff(
positions=wp_positions,
cutoff1=cutoff1,
cutoff2=cutoff2,
neighbor_matrix1=wp_neighbor_matrix1,
num_neighbors1=wp_num_neighbors1,
neighbor_matrix2=wp_neighbor_matrix2,
num_neighbors2=wp_num_neighbors2,
wp_dtype=wp_dtype,
device=str(device),
half_fill=half_fill,
)
@torch.library.custom_op(
"nvalchemiops::_naive_neighbor_matrix_pbc_dual_cutoff",
mutates_args=(
"neighbor_matrix1",
"neighbor_matrix2",
"neighbor_matrix_shifts1",
"neighbor_matrix_shifts2",
"num_neighbors1",
"num_neighbors2",
),
)
def _naive_neighbor_matrix_pbc_dual_cutoff(
positions: torch.Tensor,
cutoff1: float,
cutoff2: float,
cell: torch.Tensor,
neighbor_matrix1: torch.Tensor,
neighbor_matrix2: torch.Tensor,
neighbor_matrix_shifts1: torch.Tensor,
neighbor_matrix_shifts2: torch.Tensor,
num_neighbors1: torch.Tensor,
num_neighbors2: torch.Tensor,
shift_range_per_dimension: torch.Tensor,
num_shifts_per_system: torch.Tensor,
max_shifts_per_system: int,
half_fill: bool = False,
wrap_positions: bool = True,
) -> None:
"""Compute two neighbor matrices with periodic boundary conditions using dual cutoffs.
This function is torch compilable.
See Also
--------
nvalchemiops.neighbors.naive_dual_cutoff.naive_neighbor_matrix_pbc_dual_cutoff : Core warp launcher
naive_neighbor_list_dual_cutoff : High-level wrapper function
"""
device = positions.device
wp_vec_dtype = get_wp_vec_dtype(positions.dtype)
wp_mat_dtype = get_wp_mat_dtype(positions.dtype)
wp_dtype = get_wp_dtype(positions.dtype)
wp_positions = wp.from_torch(positions, dtype=wp_vec_dtype, return_ctype=True)
wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype, return_ctype=True)
wp_shift_range = wp.from_torch(
shift_range_per_dimension, dtype=wp.vec3i, return_ctype=True
)
wp_neighbor_matrix1 = wp.from_torch(
neighbor_matrix1, dtype=wp.int32, return_ctype=True
)
wp_neighbor_matrix2 = wp.from_torch(
neighbor_matrix2, dtype=wp.int32, return_ctype=True
)
wp_neighbor_matrix_shifts1 = wp.from_torch(
neighbor_matrix_shifts1, dtype=wp.vec3i, return_ctype=True
)
wp_neighbor_matrix_shifts2 = wp.from_torch(
neighbor_matrix_shifts2, dtype=wp.vec3i, return_ctype=True
)
wp_num_neighbors1 = wp.from_torch(num_neighbors1, dtype=wp.int32, return_ctype=True)
wp_num_neighbors2 = wp.from_torch(num_neighbors2, dtype=wp.int32, return_ctype=True)
naive_neighbor_matrix_pbc_dual_cutoff(
positions=wp_positions,
cutoff1=cutoff1,
cutoff2=cutoff2,
cell=wp_cell,
shift_range=wp_shift_range,
num_shifts=max_shifts_per_system,
neighbor_matrix1=wp_neighbor_matrix1,
neighbor_matrix2=wp_neighbor_matrix2,
neighbor_matrix_shifts1=wp_neighbor_matrix_shifts1,
neighbor_matrix_shifts2=wp_neighbor_matrix_shifts2,
num_neighbors1=wp_num_neighbors1,
num_neighbors2=wp_num_neighbors2,
wp_dtype=wp_dtype,
device=str(device),
half_fill=half_fill,
wrap_positions=wrap_positions,
)
@torch.library.custom_op(
"nvalchemiops::_naive_neighbor_matrix_no_pbc_dual_cutoff_selective",
mutates_args=(
"neighbor_matrix1",
"num_neighbors1",
"neighbor_matrix2",
"num_neighbors2",
),
)
def _naive_neighbor_matrix_no_pbc_dual_cutoff_selective(
positions: torch.Tensor,
cutoff1: float,
cutoff2: float,
neighbor_matrix1: torch.Tensor,
num_neighbors1: torch.Tensor,
neighbor_matrix2: torch.Tensor,
num_neighbors2: torch.Tensor,
rebuild_flags: torch.Tensor,
half_fill: bool = False,
) -> None:
"""Selective naive dual cutoff neighbor matrix custom op (no PBC).
Wraps the GPU-side selective kernel: ``rebuild_flags[0]`` is checked on the
device — no CPU-GPU synchronisation occurs.
See Also
--------
nvalchemiops.neighbors.naive_dual_cutoff.naive_neighbor_matrix_dual_cutoff : Core warp launcher
naive_neighbor_list_dual_cutoff : High-level wrapper that dispatches here when rebuild_flags is set
"""
device = positions.device
wp_device = wp.device_from_torch(device)
wp_dtype = get_wp_dtype(positions.dtype)
wp_vec_dtype = get_wp_vec_dtype(positions.dtype)
wp_positions = wp.from_torch(positions, dtype=wp_vec_dtype, return_ctype=True)
wp_neighbor_matrix1 = wp.from_torch(
neighbor_matrix1, dtype=wp.int32, return_ctype=True
)
wp_num_neighbors1 = wp.from_torch(num_neighbors1, dtype=wp.int32, return_ctype=True)
wp_neighbor_matrix2 = wp.from_torch(
neighbor_matrix2, dtype=wp.int32, return_ctype=True
)
wp_num_neighbors2 = wp.from_torch(num_neighbors2, dtype=wp.int32, return_ctype=True)
wp_rebuild_flags = wp.from_torch(
rebuild_flags.view(-1)[:1].contiguous(), dtype=wp.bool, return_ctype=True
)
selective_zero_num_neighbors_single(
wp_num_neighbors1, wp_rebuild_flags, str(wp_device)
)
selective_zero_num_neighbors_single(
wp_num_neighbors2, wp_rebuild_flags, str(wp_device)
)
naive_neighbor_matrix_dual_cutoff(
positions=wp_positions,
cutoff1=cutoff1,
cutoff2=cutoff2,
neighbor_matrix1=wp_neighbor_matrix1,
num_neighbors1=wp_num_neighbors1,
neighbor_matrix2=wp_neighbor_matrix2,
num_neighbors2=wp_num_neighbors2,
wp_dtype=wp_dtype,
device=str(wp_device),
half_fill=half_fill,
rebuild_flags=wp_rebuild_flags,
)
@torch.library.custom_op(
"nvalchemiops::_naive_neighbor_matrix_pbc_dual_cutoff_selective",
mutates_args=(
"neighbor_matrix1",
"neighbor_matrix2",
"neighbor_matrix_shifts1",
"neighbor_matrix_shifts2",
"num_neighbors1",
"num_neighbors2",
),
)
def _naive_neighbor_matrix_pbc_dual_cutoff_selective(
positions: torch.Tensor,
cutoff1: float,
cutoff2: float,
cell: torch.Tensor,
neighbor_matrix1: torch.Tensor,
neighbor_matrix2: torch.Tensor,
neighbor_matrix_shifts1: torch.Tensor,
neighbor_matrix_shifts2: torch.Tensor,
num_neighbors1: torch.Tensor,
num_neighbors2: torch.Tensor,
shift_range_per_dimension: torch.Tensor,
num_shifts_per_system: torch.Tensor,
max_shifts_per_system: int,
rebuild_flags: torch.Tensor,
half_fill: bool = False,
wrap_positions: bool = True,
) -> None:
"""Selective naive dual cutoff PBC neighbor matrix custom op.
``rebuild_flags[0]`` is checked on the device — no CPU-GPU synchronisation occurs.
See Also
--------
nvalchemiops.neighbors.naive_dual_cutoff.naive_neighbor_matrix_pbc_dual_cutoff : Core warp launcher
naive_neighbor_list_dual_cutoff : High-level wrapper that dispatches here when rebuild_flags is set
"""
device = positions.device
wp_device = wp.device_from_torch(device)
wp_vec_dtype = get_wp_vec_dtype(positions.dtype)
wp_mat_dtype = get_wp_mat_dtype(positions.dtype)
wp_dtype = get_wp_dtype(positions.dtype)
wp_positions = wp.from_torch(positions, dtype=wp_vec_dtype, return_ctype=True)
wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype, return_ctype=True)
wp_shift_range = wp.from_torch(
shift_range_per_dimension, dtype=wp.vec3i, return_ctype=True
)
wp_neighbor_matrix1 = wp.from_torch(
neighbor_matrix1, dtype=wp.int32, return_ctype=True
)
wp_neighbor_matrix2 = wp.from_torch(
neighbor_matrix2, dtype=wp.int32, return_ctype=True
)
wp_neighbor_matrix_shifts1 = wp.from_torch(
neighbor_matrix_shifts1, dtype=wp.vec3i, return_ctype=True
)
wp_neighbor_matrix_shifts2 = wp.from_torch(
neighbor_matrix_shifts2, dtype=wp.vec3i, return_ctype=True
)
wp_num_neighbors1 = wp.from_torch(num_neighbors1, dtype=wp.int32, return_ctype=True)
wp_num_neighbors2 = wp.from_torch(num_neighbors2, dtype=wp.int32, return_ctype=True)
wp_rebuild_flags = wp.from_torch(
rebuild_flags.view(-1)[:1].contiguous(), dtype=wp.bool, return_ctype=True
)
selective_zero_num_neighbors_single(
wp_num_neighbors1, wp_rebuild_flags, str(wp_device)
)
selective_zero_num_neighbors_single(
wp_num_neighbors2, wp_rebuild_flags, str(wp_device)
)
naive_neighbor_matrix_pbc_dual_cutoff(
positions=wp_positions,
cutoff1=cutoff1,
cutoff2=cutoff2,
cell=wp_cell,
shift_range=wp_shift_range,
num_shifts=max_shifts_per_system,
neighbor_matrix1=wp_neighbor_matrix1,
neighbor_matrix2=wp_neighbor_matrix2,
neighbor_matrix_shifts1=wp_neighbor_matrix_shifts1,
neighbor_matrix_shifts2=wp_neighbor_matrix_shifts2,
num_neighbors1=wp_num_neighbors1,
num_neighbors2=wp_num_neighbors2,
wp_dtype=wp_dtype,
device=str(wp_device),
half_fill=half_fill,
rebuild_flags=wp_rebuild_flags,
wrap_positions=wrap_positions,
)
[docs]
def naive_neighbor_list_dual_cutoff(
positions: torch.Tensor,
cutoff1: float,
cutoff2: float,
pbc: torch.Tensor | None = None,
cell: torch.Tensor | None = None,
max_neighbors1: int | None = None,
max_neighbors2: int | None = None,
half_fill: bool = False,
fill_value: int | None = None,
return_neighbor_list: bool = False,
neighbor_matrix1: torch.Tensor | None = None,
neighbor_matrix2: torch.Tensor | None = None,
neighbor_matrix_shifts1: torch.Tensor | None = None,
neighbor_matrix_shifts2: torch.Tensor | None = None,
num_neighbors1: torch.Tensor | None = None,
num_neighbors2: torch.Tensor | None = None,
shift_range_per_dimension: torch.Tensor | None = None,
num_shifts_per_system: torch.Tensor | None = None,
max_shifts_per_system: int | None = None,
rebuild_flags: torch.Tensor | None = None,
wrap_positions: bool = True,
) -> (
tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]
| tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
):
"""Compute neighbor list using naive O(N^2) algorithm with dual cutoffs.
Identifies all atom pairs within two different cutoff distances using a
single brute-force pairwise distance calculation. This is more efficient
than running two separate neighbor calculations when both neighbor lists are needed.
See Also
--------
nvalchemiops.neighbors.naive_dual_cutoff.naive_neighbor_matrix_dual_cutoff : Core warp launcher (no PBC)
nvalchemiops.neighbors.naive_dual_cutoff.naive_neighbor_matrix_pbc_dual_cutoff : Core warp launcher (with PBC)
naive_neighbor_list : Single cutoff version
"""
if pbc is None and cell is not None:
raise ValueError("If cell is provided, pbc must also be provided")
if pbc is not None and cell is None:
raise ValueError("If pbc is provided, cell must also be provided")
if cell is not None:
cell = cell if cell.ndim == 3 else cell.unsqueeze(0)
if pbc is not None:
pbc = pbc if pbc.ndim == 2 else pbc.unsqueeze(0)
if fill_value is None:
fill_value = positions.shape[0]
if max_neighbors1 is None and (
neighbor_matrix1 is None
or neighbor_matrix2 is None
or (neighbor_matrix_shifts1 is None and pbc is not None)
or (neighbor_matrix_shifts2 is None and pbc is not None)
or num_neighbors1 is None
or num_neighbors2 is None
):
max_neighbors2 = estimate_max_neighbors(cutoff2)
max_neighbors1 = max_neighbors2
if max_neighbors2 is None:
max_neighbors2 = max_neighbors1
if neighbor_matrix1 is None:
neighbor_matrix1 = torch.full(
(positions.shape[0], max_neighbors1),
fill_value,
dtype=torch.int32,
device=positions.device,
)
elif rebuild_flags is None:
neighbor_matrix1.fill_(fill_value)
if num_neighbors1 is None:
num_neighbors1 = torch.zeros(
positions.shape[0], dtype=torch.int32, device=positions.device
)
elif rebuild_flags is None:
num_neighbors1.zero_()
if neighbor_matrix2 is None:
neighbor_matrix2 = torch.full(
(positions.shape[0], max_neighbors2),
fill_value,
dtype=torch.int32,
device=positions.device,
)
elif rebuild_flags is None:
neighbor_matrix2.fill_(fill_value)
if num_neighbors2 is None:
num_neighbors2 = torch.zeros(
positions.shape[0], dtype=torch.int32, device=positions.device
)
elif rebuild_flags is None:
num_neighbors2.zero_()
if pbc is not None:
if neighbor_matrix_shifts1 is None:
neighbor_matrix_shifts1 = torch.zeros(
(positions.shape[0], max_neighbors1, 3),
dtype=torch.int32,
device=positions.device,
)
elif rebuild_flags is None:
neighbor_matrix_shifts1.zero_()
if neighbor_matrix_shifts2 is None:
neighbor_matrix_shifts2 = torch.zeros(
(positions.shape[0], max_neighbors2, 3),
dtype=torch.int32,
device=positions.device,
)
elif rebuild_flags is None:
neighbor_matrix_shifts2.zero_()
if (
max_shifts_per_system is None
or num_shifts_per_system is None
or shift_range_per_dimension is None
):
shift_range_per_dimension, num_shifts_per_system, max_shifts_per_system = (
compute_naive_num_shifts(cell, cutoff2, pbc)
)
if pbc is None:
if rebuild_flags is not None:
_naive_neighbor_matrix_no_pbc_dual_cutoff_selective(
positions=positions,
cutoff1=cutoff1,
cutoff2=cutoff2,
neighbor_matrix1=neighbor_matrix1,
num_neighbors1=num_neighbors1,
neighbor_matrix2=neighbor_matrix2,
num_neighbors2=num_neighbors2,
rebuild_flags=rebuild_flags,
half_fill=half_fill,
)
else:
_naive_neighbor_matrix_no_pbc_dual_cutoff(
positions=positions,
cutoff1=cutoff1,
cutoff2=cutoff2,
neighbor_matrix1=neighbor_matrix1,
num_neighbors1=num_neighbors1,
neighbor_matrix2=neighbor_matrix2,
num_neighbors2=num_neighbors2,
half_fill=half_fill,
)
if return_neighbor_list:
neighbor_list1, neighbor_ptr1 = get_neighbor_list_from_neighbor_matrix(
neighbor_matrix1, num_neighbors=num_neighbors1, fill_value=fill_value
)
neighbor_list2, neighbor_ptr2 = get_neighbor_list_from_neighbor_matrix(
neighbor_matrix2, num_neighbors=num_neighbors2, fill_value=fill_value
)
return (
neighbor_list1,
neighbor_ptr1,
neighbor_list2,
neighbor_ptr2,
)
else:
return (
neighbor_matrix1,
num_neighbors1,
neighbor_matrix2,
num_neighbors2,
)
else:
if rebuild_flags is not None:
_naive_neighbor_matrix_pbc_dual_cutoff_selective(
positions=positions,
cutoff1=cutoff1,
cutoff2=cutoff2,
cell=cell,
neighbor_matrix1=neighbor_matrix1,
neighbor_matrix2=neighbor_matrix2,
neighbor_matrix_shifts1=neighbor_matrix_shifts1,
neighbor_matrix_shifts2=neighbor_matrix_shifts2,
num_neighbors1=num_neighbors1,
num_neighbors2=num_neighbors2,
shift_range_per_dimension=shift_range_per_dimension,
num_shifts_per_system=num_shifts_per_system,
max_shifts_per_system=max_shifts_per_system,
rebuild_flags=rebuild_flags,
half_fill=half_fill,
wrap_positions=wrap_positions,
)
else:
_naive_neighbor_matrix_pbc_dual_cutoff(
positions=positions,
cutoff1=cutoff1,
cutoff2=cutoff2,
cell=cell,
neighbor_matrix1=neighbor_matrix1,
neighbor_matrix2=neighbor_matrix2,
neighbor_matrix_shifts1=neighbor_matrix_shifts1,
neighbor_matrix_shifts2=neighbor_matrix_shifts2,
num_neighbors1=num_neighbors1,
num_neighbors2=num_neighbors2,
shift_range_per_dimension=shift_range_per_dimension,
num_shifts_per_system=num_shifts_per_system,
max_shifts_per_system=max_shifts_per_system,
half_fill=half_fill,
wrap_positions=wrap_positions,
)
if return_neighbor_list:
neighbor_list1, neighbor_ptr1, unit_shifts1 = (
get_neighbor_list_from_neighbor_matrix(
neighbor_matrix1,
num_neighbors=num_neighbors1,
neighbor_shift_matrix=neighbor_matrix_shifts1,
fill_value=fill_value,
)
)
neighbor_list2, neighbor_ptr2, unit_shifts2 = (
get_neighbor_list_from_neighbor_matrix(
neighbor_matrix2,
num_neighbors=num_neighbors2,
neighbor_shift_matrix=neighbor_matrix_shifts2,
fill_value=fill_value,
)
)
return (
neighbor_list1,
neighbor_ptr1,
unit_shifts1,
neighbor_list2,
neighbor_ptr2,
unit_shifts2,
)
else:
return (
neighbor_matrix1,
num_neighbors1,
neighbor_matrix_shifts1,
neighbor_matrix2,
num_neighbors2,
neighbor_matrix_shifts2,
)