# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch bindings for batched naive dual cutoff neighbor list construction."""
from __future__ import annotations
import torch
import warp as wp
from nvalchemiops.neighbors.batch_naive_dual_cutoff import (
batch_naive_neighbor_matrix_dual_cutoff,
batch_naive_neighbor_matrix_pbc_dual_cutoff,
)
from nvalchemiops.neighbors.neighbor_utils import (
estimate_max_neighbors,
)
from nvalchemiops.torch.neighbors.neighbor_utils import (
compute_naive_num_shifts,
get_neighbor_list_from_neighbor_matrix,
prepare_batch_idx_ptr,
)
from nvalchemiops.torch.types import get_wp_dtype, get_wp_mat_dtype, get_wp_vec_dtype
# Only the high-level wrapper is exported; the underscore-prefixed custom ops
# defined below are internal building blocks it dispatches to.
__all__ = ["batch_naive_neighbor_list_dual_cutoff"]
@torch.library.custom_op(
    "nvalchemiops::_batch_naive_neighbor_matrix_no_pbc_dual_cutoff",
    mutates_args=(
        "neighbor_matrix1",
        "num_neighbors1",
        "neighbor_matrix2",
        "num_neighbors2",
    ),
)
def _batch_naive_neighbor_matrix_no_pbc_dual_cutoff(
    positions: torch.Tensor,
    cutoff1: float,
    cutoff2: float,
    batch_idx: torch.Tensor,
    batch_ptr: torch.Tensor,
    neighbor_matrix1: torch.Tensor,
    num_neighbors1: torch.Tensor,
    neighbor_matrix2: torch.Tensor,
    num_neighbors2: torch.Tensor,
    half_fill: bool,
) -> None:
    """Fill two batched neighbor matrices (no PBC) with the naive O(N^2) kernel.

    Registered as a ``torch.library`` custom op, so this function is torch
    compilable. The four output buffers are written in place by the Warp
    launcher; nothing is returned.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic positions, handed to Warp using the vector dtype that matches
        ``positions.dtype``.
    cutoff1, cutoff2 : float
        The two cutoff radii.
    batch_idx : torch.Tensor
        Per-atom system index, viewed as int32.
    batch_ptr : torch.Tensor
        Per-system atom offsets, viewed as int32.
    neighbor_matrix1, num_neighbors1, neighbor_matrix2, num_neighbors2 : torch.Tensor
        int32 output buffers, mutated in place.
    half_fill : bool
        Forwarded unchanged to the Warp launcher.

    See Also
    --------
    nvalchemiops.neighbors.batch_naive_dual_cutoff.batch_naive_neighbor_matrix_dual_cutoff : Core warp launcher
    batch_naive_neighbor_list_dual_cutoff : High-level wrapper function
    """
    float_dtype = positions.dtype

    def int32_view(tensor):
        # Hand an integer torch tensor to Warp as an int32 array ctype.
        return wp.from_torch(tensor, dtype=wp.int32, return_ctype=True)

    batch_naive_neighbor_matrix_dual_cutoff(
        positions=wp.from_torch(
            positions, dtype=get_wp_vec_dtype(float_dtype), return_ctype=True
        ),
        cutoff1=cutoff1,
        cutoff2=cutoff2,
        batch_idx=int32_view(batch_idx),
        batch_ptr=int32_view(batch_ptr),
        neighbor_matrix1=int32_view(neighbor_matrix1),
        num_neighbors1=int32_view(num_neighbors1),
        neighbor_matrix2=int32_view(neighbor_matrix2),
        num_neighbors2=int32_view(num_neighbors2),
        wp_dtype=get_wp_dtype(float_dtype),
        device=str(positions.device),
        half_fill=half_fill,
    )
@torch.library.custom_op(
    "nvalchemiops::_batch_naive_neighbor_matrix_pbc_dual_cutoff",
    mutates_args=(
        "neighbor_matrix1",
        "neighbor_matrix2",
        "neighbor_matrix_shifts1",
        "neighbor_matrix_shifts2",
        "num_neighbors1",
        "num_neighbors2",
    ),
)
def _batch_naive_neighbor_matrix_pbc_dual_cutoff(
    positions: torch.Tensor,
    cell: torch.Tensor,
    cutoff1: float,
    cutoff2: float,
    batch_idx: torch.Tensor,
    batch_ptr: torch.Tensor,
    neighbor_matrix1: torch.Tensor,
    neighbor_matrix2: torch.Tensor,
    neighbor_matrix_shifts1: torch.Tensor,
    neighbor_matrix_shifts2: torch.Tensor,
    num_neighbors1: torch.Tensor,
    num_neighbors2: torch.Tensor,
    shift_range_per_dimension: torch.Tensor,
    num_shifts_per_system: torch.Tensor,
    max_shifts_per_system: int,
    half_fill: bool = False,
    max_atoms_per_system: int | None = None,
    wrap_positions: bool = True,
) -> None:
    """Fill two batched neighbor matrices under periodic boundary conditions.

    Registered as a ``torch.library`` custom op, so this function is torch
    compilable. The neighbor matrices, shift matrices and neighbor counts
    for both cutoffs are written in place by the Warp launcher.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic positions (Warp vector dtype chosen from ``positions.dtype``).
    cell : torch.Tensor
        Cell matrices (Warp matrix dtype chosen from ``positions.dtype``).
    cutoff1, cutoff2 : float
        The two cutoff radii.
    batch_idx, batch_ptr : torch.Tensor
        Per-atom system index and per-system atom offsets (int32).
    neighbor_matrix1, neighbor_matrix2 : torch.Tensor
        int32 neighbor-index outputs, mutated in place.
    neighbor_matrix_shifts1, neighbor_matrix_shifts2 : torch.Tensor
        Integer image-shift outputs viewed as ``wp.vec3i``, mutated in place.
    num_neighbors1, num_neighbors2 : torch.Tensor
        int32 neighbor-count outputs, mutated in place.
    shift_range_per_dimension, num_shifts_per_system, max_shifts_per_system
        Precomputed periodic-image enumeration forwarded to the launcher.
    half_fill, wrap_positions : bool
        Forwarded unchanged to the launcher.
    max_atoms_per_system : int, optional
        When ``None`` it is derived from ``batch_ptr`` with ``.item()``.

    See Also
    --------
    nvalchemiops.neighbors.batch_naive_dual_cutoff.batch_naive_neighbor_matrix_pbc_dual_cutoff : Core warp launcher
    batch_naive_neighbor_list_dual_cutoff : High-level wrapper function
    """
    float_dtype = positions.dtype

    def int32_view(tensor):
        # Integer tensor -> Warp int32 array ctype.
        return wp.from_torch(tensor, dtype=wp.int32, return_ctype=True)

    def vec3i_view(tensor):
        # Integer tensor -> Warp vec3i array ctype (shifts / shift ranges).
        return wp.from_torch(tensor, dtype=wp.vec3i, return_ctype=True)

    if max_atoms_per_system is None:
        # Largest system size = largest gap between consecutive offsets.
        max_atoms_per_system = torch.diff(batch_ptr).max().item()

    batch_naive_neighbor_matrix_pbc_dual_cutoff(
        positions=wp.from_torch(
            positions, dtype=get_wp_vec_dtype(float_dtype), return_ctype=True
        ),
        cell=wp.from_torch(
            cell, dtype=get_wp_mat_dtype(float_dtype), return_ctype=True
        ),
        cutoff1=cutoff1,
        cutoff2=cutoff2,
        batch_ptr=int32_view(batch_ptr),
        batch_idx=int32_view(batch_idx),
        shift_range=vec3i_view(shift_range_per_dimension),
        num_shifts_arr=int32_view(num_shifts_per_system),
        max_shifts_per_system=max_shifts_per_system,
        neighbor_matrix1=int32_view(neighbor_matrix1),
        neighbor_matrix2=int32_view(neighbor_matrix2),
        neighbor_matrix_shifts1=vec3i_view(neighbor_matrix_shifts1),
        neighbor_matrix_shifts2=vec3i_view(neighbor_matrix_shifts2),
        num_neighbors1=int32_view(num_neighbors1),
        num_neighbors2=int32_view(num_neighbors2),
        wp_dtype=get_wp_dtype(float_dtype),
        device=str(positions.device),
        max_atoms_per_system=max_atoms_per_system,
        half_fill=half_fill,
        wrap_positions=wrap_positions,
    )
@torch.library.custom_op(
    "nvalchemiops::_batch_naive_neighbor_matrix_no_pbc_dual_cutoff_selective",
    mutates_args=(
        "neighbor_matrix1",
        "num_neighbors1",
        "neighbor_matrix2",
        "num_neighbors2",
    ),
)
def _batch_naive_neighbor_matrix_no_pbc_dual_cutoff_selective(
    positions: torch.Tensor,
    cutoff1: float,
    cutoff2: float,
    batch_idx: torch.Tensor,
    batch_ptr: torch.Tensor,
    neighbor_matrix1: torch.Tensor,
    num_neighbors1: torch.Tensor,
    neighbor_matrix2: torch.Tensor,
    num_neighbors2: torch.Tensor,
    rebuild_flags: torch.Tensor,
    half_fill: bool = False,
) -> None:
    """Selectively rebuild batched dual-cutoff neighbor matrices (no PBC).

    Same as the non-selective op, but ``rebuild_flags`` (passed to Warp as a
    bool array) is forwarded to the launcher. Per-system flags are checked on
    the device, and this wrapper performs no CPU-GPU synchronisation itself.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic positions (Warp vector dtype chosen from ``positions.dtype``).
    cutoff1, cutoff2 : float
        The two cutoff radii.
    batch_idx, batch_ptr : torch.Tensor
        Per-atom system index and per-system atom offsets (int32).
    neighbor_matrix1, num_neighbors1, neighbor_matrix2, num_neighbors2 : torch.Tensor
        int32 output buffers, mutated in place.
    rebuild_flags : torch.Tensor
        Boolean per-system rebuild mask.
    half_fill : bool, default False
        Forwarded unchanged to the launcher.

    See Also
    --------
    nvalchemiops.neighbors.batch_naive_dual_cutoff.batch_naive_neighbor_matrix_dual_cutoff : Core warp launcher
    batch_naive_neighbor_list_dual_cutoff : High-level wrapper that dispatches here when rebuild_flags is set
    """
    warp_device = wp.device_from_torch(positions.device)
    float_dtype = positions.dtype

    def int32_view(tensor):
        # Integer tensor -> Warp int32 array ctype.
        return wp.from_torch(tensor, dtype=wp.int32, return_ctype=True)

    batch_naive_neighbor_matrix_dual_cutoff(
        positions=wp.from_torch(
            positions, dtype=get_wp_vec_dtype(float_dtype), return_ctype=True
        ),
        cutoff1=cutoff1,
        cutoff2=cutoff2,
        batch_idx=int32_view(batch_idx),
        batch_ptr=int32_view(batch_ptr),
        neighbor_matrix1=int32_view(neighbor_matrix1),
        num_neighbors1=int32_view(num_neighbors1),
        neighbor_matrix2=int32_view(neighbor_matrix2),
        num_neighbors2=int32_view(num_neighbors2),
        wp_dtype=get_wp_dtype(float_dtype),
        device=str(warp_device),
        half_fill=half_fill,
        rebuild_flags=wp.from_torch(rebuild_flags, dtype=wp.bool, return_ctype=True),
    )
@torch.library.custom_op(
    "nvalchemiops::_batch_naive_neighbor_matrix_pbc_dual_cutoff_selective",
    mutates_args=(
        "neighbor_matrix1",
        "neighbor_matrix2",
        "neighbor_matrix_shifts1",
        "neighbor_matrix_shifts2",
        "num_neighbors1",
        "num_neighbors2",
    ),
)
def _batch_naive_neighbor_matrix_pbc_dual_cutoff_selective(
    positions: torch.Tensor,
    cell: torch.Tensor,
    cutoff1: float,
    cutoff2: float,
    batch_idx: torch.Tensor,
    batch_ptr: torch.Tensor,
    neighbor_matrix1: torch.Tensor,
    neighbor_matrix2: torch.Tensor,
    neighbor_matrix_shifts1: torch.Tensor,
    neighbor_matrix_shifts2: torch.Tensor,
    num_neighbors1: torch.Tensor,
    num_neighbors2: torch.Tensor,
    shift_range_per_dimension: torch.Tensor,
    num_shifts_per_system: torch.Tensor,
    max_shifts_per_system: int,
    rebuild_flags: torch.Tensor,
    half_fill: bool = False,
    max_atoms_per_system: int | None = None,
    wrap_positions: bool = True,
) -> None:
    """Selective batched naive dual cutoff PBC neighbor matrix custom op.

    Per-system ``rebuild_flags`` are forwarded to the Warp launcher and checked
    on the device.

    Synchronisation note: if ``max_atoms_per_system`` is ``None`` it is derived
    here from ``batch_ptr`` via ``.item()``, which forces a device-to-host
    synchronisation. Pass ``max_atoms_per_system`` explicitly to keep this call
    free of host syncs.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic positions (Warp vector dtype chosen from ``positions.dtype``).
    cell : torch.Tensor
        Cell matrices (Warp matrix dtype chosen from ``positions.dtype``).
    cutoff1, cutoff2 : float
        The two cutoff radii.
    batch_idx, batch_ptr : torch.Tensor
        Per-atom system index and per-system atom offsets (int32).
    neighbor_matrix1, neighbor_matrix2 : torch.Tensor
        int32 neighbor-index outputs, mutated in place.
    neighbor_matrix_shifts1, neighbor_matrix_shifts2 : torch.Tensor
        Image-shift outputs viewed as ``wp.vec3i``, mutated in place.
    num_neighbors1, num_neighbors2 : torch.Tensor
        int32 neighbor-count outputs, mutated in place.
    shift_range_per_dimension, num_shifts_per_system, max_shifts_per_system
        Precomputed periodic-image enumeration forwarded to the launcher.
    rebuild_flags : torch.Tensor
        Boolean per-system rebuild mask.
    half_fill, wrap_positions : bool
        Forwarded unchanged to the launcher.
    max_atoms_per_system : int, optional
        Largest number of atoms in any system; see the synchronisation note.

    See Also
    --------
    nvalchemiops.neighbors.batch_naive_dual_cutoff.batch_naive_neighbor_matrix_pbc_dual_cutoff : Core warp launcher
    batch_naive_neighbor_list_dual_cutoff : High-level wrapper that dispatches here when rebuild_flags is set
    """
    device = positions.device
    wp_device = wp.device_from_torch(device)
    wp_vec_dtype = get_wp_vec_dtype(positions.dtype)
    wp_mat_dtype = get_wp_mat_dtype(positions.dtype)
    wp_dtype = get_wp_dtype(positions.dtype)
    wp_positions = wp.from_torch(positions, dtype=wp_vec_dtype, return_ctype=True)
    wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype, return_ctype=True)
    wp_shift_range = wp.from_torch(
        shift_range_per_dimension, dtype=wp.vec3i, return_ctype=True
    )
    wp_num_shifts_arr = wp.from_torch(
        num_shifts_per_system, dtype=wp.int32, return_ctype=True
    )
    wp_batch_idx = wp.from_torch(batch_idx, dtype=wp.int32, return_ctype=True)
    wp_batch_ptr = wp.from_torch(batch_ptr, dtype=wp.int32, return_ctype=True)
    wp_neighbor_matrix1 = wp.from_torch(
        neighbor_matrix1, dtype=wp.int32, return_ctype=True
    )
    wp_neighbor_matrix2 = wp.from_torch(
        neighbor_matrix2, dtype=wp.int32, return_ctype=True
    )
    wp_neighbor_matrix_shifts1 = wp.from_torch(
        neighbor_matrix_shifts1, dtype=wp.vec3i, return_ctype=True
    )
    wp_neighbor_matrix_shifts2 = wp.from_torch(
        neighbor_matrix_shifts2, dtype=wp.vec3i, return_ctype=True
    )
    wp_num_neighbors1 = wp.from_torch(num_neighbors1, dtype=wp.int32, return_ctype=True)
    wp_num_neighbors2 = wp.from_torch(num_neighbors2, dtype=wp.int32, return_ctype=True)
    wp_rebuild_flags = wp.from_torch(rebuild_flags, dtype=wp.bool, return_ctype=True)
    if max_atoms_per_system is None:
        # NOTE: `.item()` copies the value to the host, i.e. a device sync.
        max_atoms_per_system = (batch_ptr[1:] - batch_ptr[:-1]).max().item()
    batch_naive_neighbor_matrix_pbc_dual_cutoff(
        positions=wp_positions,
        cell=wp_cell,
        cutoff1=cutoff1,
        cutoff2=cutoff2,
        batch_ptr=wp_batch_ptr,
        batch_idx=wp_batch_idx,
        shift_range=wp_shift_range,
        num_shifts_arr=wp_num_shifts_arr,
        max_shifts_per_system=max_shifts_per_system,
        neighbor_matrix1=wp_neighbor_matrix1,
        neighbor_matrix2=wp_neighbor_matrix2,
        neighbor_matrix_shifts1=wp_neighbor_matrix_shifts1,
        neighbor_matrix_shifts2=wp_neighbor_matrix_shifts2,
        num_neighbors1=wp_num_neighbors1,
        num_neighbors2=wp_num_neighbors2,
        wp_dtype=wp_dtype,
        device=str(wp_device),
        max_atoms_per_system=max_atoms_per_system,
        half_fill=half_fill,
        rebuild_flags=wp_rebuild_flags,
        wrap_positions=wrap_positions,
    )
def _int32_buffer(
    buffer: torch.Tensor | None,
    shape: tuple,
    fill: int,
    device: torch.device,
    reset: bool,
) -> torch.Tensor:
    """Return ``buffer`` (refilled with ``fill`` if ``reset``) or a new int32 tensor.

    A fresh buffer of ``shape`` filled with ``fill`` is allocated on ``device``
    when ``buffer`` is ``None``; a caller-supplied buffer is only overwritten
    in place when ``reset`` is true.
    """
    if buffer is None:
        return torch.full(shape, fill, dtype=torch.int32, device=device)
    if reset:
        buffer.fill_(fill)
    return buffer


def batch_naive_neighbor_list_dual_cutoff(
    positions: torch.Tensor,
    cutoff1: float,
    cutoff2: float,
    batch_idx: torch.Tensor | None = None,
    batch_ptr: torch.Tensor | None = None,
    pbc: torch.Tensor | None = None,
    cell: torch.Tensor | None = None,
    max_neighbors1: int | None = None,
    max_neighbors2: int | None = None,
    half_fill: bool = False,
    fill_value: int | None = None,
    return_neighbor_list: bool = False,
    neighbor_matrix1: torch.Tensor | None = None,
    neighbor_matrix2: torch.Tensor | None = None,
    neighbor_matrix_shifts1: torch.Tensor | None = None,
    neighbor_matrix_shifts2: torch.Tensor | None = None,
    num_neighbors1: torch.Tensor | None = None,
    num_neighbors2: torch.Tensor | None = None,
    shift_range_per_dimension: torch.Tensor | None = None,
    num_shifts_per_system: torch.Tensor | None = None,
    max_shifts_per_system: int | None = None,
    max_atoms_per_system: int | None = None,
    rebuild_flags: torch.Tensor | None = None,
    wrap_positions: bool = True,
) -> (
    tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]
    | tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]
    | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
):
    """Compute batch neighbor matrices using naive O(N^2) algorithm with dual cutoffs.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic positions for all systems concatenated; ``positions.shape[0]``
        is the total atom count.
    cutoff1, cutoff2 : float
        The two cutoff radii. Periodic-image ranges and the default neighbor
        capacity are derived from the larger of the two.
    batch_idx, batch_ptr : torch.Tensor, optional
        Per-atom system index / per-system atom offsets, completed by
        ``prepare_batch_idx_ptr``. ``batch_idx`` must have one entry per atom.
    pbc, cell : torch.Tensor, optional
        Must be given together. A non-3-D ``cell`` and a non-2-D ``pbc`` get a
        leading batch dimension added.
    max_neighbors1, max_neighbors2 : int, optional
        Column capacity of the two neighbor matrices. When a buffer has to be
        allocated and ``max_neighbors1`` is ``None`` it is estimated; a missing
        ``max_neighbors2`` defaults to ``max_neighbors1``. A caller-supplied
        ``max_neighbors2`` is always respected.
    half_fill : bool, default False
        Forwarded to the underlying custom ops.
    fill_value : int, optional
        Padding value for empty neighbor slots; defaults to the total atom count.
    return_neighbor_list : bool, default False
        If true, convert the matrices with ``get_neighbor_list_from_neighbor_matrix``.
    neighbor_matrix1, neighbor_matrix2, neighbor_matrix_shifts1, neighbor_matrix_shifts2, num_neighbors1, num_neighbors2 : torch.Tensor, optional
        Preallocated int32 output buffers. Missing ones are allocated. Supplied
        ones are reset (matrices to ``fill_value``, counts/shifts to 0) unless
        ``rebuild_flags`` is given.
    shift_range_per_dimension, num_shifts_per_system, max_shifts_per_system : optional
        Periodic-image enumeration; recomputed with ``compute_naive_num_shifts``
        if any of the three is missing (PBC only).
    max_atoms_per_system : int, optional
        Forwarded to the PBC ops; if ``None`` they derive it with a host sync.
    rebuild_flags : torch.Tensor, optional
        Per-system rebuild mask. When given, the selective ops are used and
        supplied buffers are not reset here.
    wrap_positions : bool, default True
        Forwarded to the PBC ops.

    Returns
    -------
    tuple of torch.Tensor
        * no PBC, matrices: ``(neighbor_matrix1, num_neighbors1, neighbor_matrix2, num_neighbors2)``
        * no PBC, lists: ``(neighbor_list1, neighbor_ptr1, neighbor_list2, neighbor_ptr2)``
        * PBC, matrices: ``(neighbor_matrix1, num_neighbors1, neighbor_matrix_shifts1,
          neighbor_matrix2, num_neighbors2, neighbor_matrix_shifts2)``
        * PBC, lists: ``(neighbor_list1, neighbor_ptr1, unit_shifts1,
          neighbor_list2, neighbor_ptr2, unit_shifts2)``

    Raises
    ------
    ValueError
        If exactly one of ``pbc`` / ``cell`` is provided.
    RuntimeError
        If ``batch_idx`` does not have one entry per atom.

    See Also
    --------
    nvalchemiops.neighbors.batch_naive_dual_cutoff.batch_naive_neighbor_matrix_dual_cutoff : Core warp launcher (no PBC)
    nvalchemiops.neighbors.batch_naive_dual_cutoff.batch_naive_neighbor_matrix_pbc_dual_cutoff : Core warp launcher (with PBC)
    batch_naive_neighbor_list : Single cutoff version
    """
    if pbc is None and cell is not None:
        raise ValueError("If cell is provided, pbc must also be provided")
    if pbc is not None and cell is None:
        raise ValueError("If pbc is provided, cell must also be provided")
    if cell is not None:
        cell = cell if cell.ndim == 3 else cell.unsqueeze(0)
    if pbc is not None:
        pbc = pbc if pbc.ndim == 2 else pbc.unsqueeze(0)
    if fill_value is None:
        fill_value = positions.shape[0]

    use_pbc = pbc is not None
    # Both lists are searched over the same periodic images and share the
    # default capacity estimate, so both must be sized for the larger radius
    # (otherwise cutoff1 > cutoff2 would miss images for list 1).
    max_cutoff = max(cutoff1, cutoff2)

    needs_allocation = (
        neighbor_matrix1 is None
        or neighbor_matrix2 is None
        or num_neighbors1 is None
        or num_neighbors2 is None
        or (
            use_pbc
            and (neighbor_matrix_shifts1 is None or neighbor_matrix_shifts2 is None)
        )
    )
    if max_neighbors1 is None and needs_allocation:
        max_neighbors1 = estimate_max_neighbors(max_cutoff)
    # Only fill in what the caller left out; never override max_neighbors2.
    if max_neighbors2 is None:
        max_neighbors2 = max_neighbors1

    total_atoms = positions.shape[0]
    device = positions.device
    # Selective rebuilds leave untouched systems' data in place, so supplied
    # buffers are only cleared for a full rebuild.
    reset = rebuild_flags is None

    neighbor_matrix1 = _int32_buffer(
        neighbor_matrix1, (total_atoms, max_neighbors1), fill_value, device, reset
    )
    num_neighbors1 = _int32_buffer(num_neighbors1, (total_atoms,), 0, device, reset)
    neighbor_matrix2 = _int32_buffer(
        neighbor_matrix2, (total_atoms, max_neighbors2), fill_value, device, reset
    )
    num_neighbors2 = _int32_buffer(num_neighbors2, (total_atoms,), 0, device, reset)

    if use_pbc:
        neighbor_matrix_shifts1 = _int32_buffer(
            neighbor_matrix_shifts1, (total_atoms, max_neighbors1, 3), 0, device, reset
        )
        neighbor_matrix_shifts2 = _int32_buffer(
            neighbor_matrix_shifts2, (total_atoms, max_neighbors2, 3), 0, device, reset
        )
        if (
            max_shifts_per_system is None
            or num_shifts_per_system is None
            or shift_range_per_dimension is None
        ):
            shift_range_per_dimension, num_shifts_per_system, max_shifts_per_system = (
                compute_naive_num_shifts(cell, max_cutoff, pbc)
            )

    batch_idx, batch_ptr = prepare_batch_idx_ptr(
        batch_idx=batch_idx,
        batch_ptr=batch_ptr,
        num_atoms=total_atoms,
        device=device,
    )
    # Validate batch_idx size matches total_atoms (check here since prepare_batch_idx_ptr
    # is @torch.compile decorated and the check would be skipped during tracing)
    if batch_idx.shape[0] != total_atoms:
        raise RuntimeError(
            f"batch_idx length ({batch_idx.shape[0]}) does not match "
            f"num_atoms ({total_atoms}). batch_idx must have one entry per atom."
        )

    if not use_pbc:
        if rebuild_flags is not None:
            _batch_naive_neighbor_matrix_no_pbc_dual_cutoff_selective(
                positions=positions,
                cutoff1=cutoff1,
                cutoff2=cutoff2,
                batch_idx=batch_idx,
                batch_ptr=batch_ptr,
                neighbor_matrix1=neighbor_matrix1,
                num_neighbors1=num_neighbors1,
                neighbor_matrix2=neighbor_matrix2,
                num_neighbors2=num_neighbors2,
                rebuild_flags=rebuild_flags,
                half_fill=half_fill,
            )
        else:
            _batch_naive_neighbor_matrix_no_pbc_dual_cutoff(
                positions=positions,
                cutoff1=cutoff1,
                cutoff2=cutoff2,
                batch_idx=batch_idx,
                batch_ptr=batch_ptr,
                neighbor_matrix1=neighbor_matrix1,
                num_neighbors1=num_neighbors1,
                neighbor_matrix2=neighbor_matrix2,
                num_neighbors2=num_neighbors2,
                half_fill=half_fill,
            )
        if return_neighbor_list:
            neighbor_list1, neighbor_ptr1 = get_neighbor_list_from_neighbor_matrix(
                neighbor_matrix1, num_neighbors=num_neighbors1, fill_value=fill_value
            )
            neighbor_list2, neighbor_ptr2 = get_neighbor_list_from_neighbor_matrix(
                neighbor_matrix2, num_neighbors=num_neighbors2, fill_value=fill_value
            )
            return (
                neighbor_list1,
                neighbor_ptr1,
                neighbor_list2,
                neighbor_ptr2,
            )
        return (
            neighbor_matrix1,
            num_neighbors1,
            neighbor_matrix2,
            num_neighbors2,
        )

    if rebuild_flags is not None:
        _batch_naive_neighbor_matrix_pbc_dual_cutoff_selective(
            positions=positions,
            cell=cell,
            cutoff1=cutoff1,
            cutoff2=cutoff2,
            batch_idx=batch_idx,
            batch_ptr=batch_ptr,
            neighbor_matrix1=neighbor_matrix1,
            neighbor_matrix2=neighbor_matrix2,
            neighbor_matrix_shifts1=neighbor_matrix_shifts1,
            neighbor_matrix_shifts2=neighbor_matrix_shifts2,
            num_neighbors1=num_neighbors1,
            num_neighbors2=num_neighbors2,
            shift_range_per_dimension=shift_range_per_dimension,
            num_shifts_per_system=num_shifts_per_system,
            max_shifts_per_system=max_shifts_per_system,
            rebuild_flags=rebuild_flags,
            half_fill=half_fill,
            max_atoms_per_system=max_atoms_per_system,
            wrap_positions=wrap_positions,
        )
    else:
        _batch_naive_neighbor_matrix_pbc_dual_cutoff(
            positions=positions,
            cell=cell,
            cutoff1=cutoff1,
            cutoff2=cutoff2,
            batch_idx=batch_idx,
            batch_ptr=batch_ptr,
            neighbor_matrix1=neighbor_matrix1,
            neighbor_matrix2=neighbor_matrix2,
            neighbor_matrix_shifts1=neighbor_matrix_shifts1,
            neighbor_matrix_shifts2=neighbor_matrix_shifts2,
            num_neighbors1=num_neighbors1,
            num_neighbors2=num_neighbors2,
            shift_range_per_dimension=shift_range_per_dimension,
            num_shifts_per_system=num_shifts_per_system,
            max_shifts_per_system=max_shifts_per_system,
            half_fill=half_fill,
            max_atoms_per_system=max_atoms_per_system,
            wrap_positions=wrap_positions,
        )
    if return_neighbor_list:
        neighbor_list1, neighbor_ptr1, unit_shifts1 = (
            get_neighbor_list_from_neighbor_matrix(
                neighbor_matrix1,
                num_neighbors=num_neighbors1,
                neighbor_shift_matrix=neighbor_matrix_shifts1,
                fill_value=fill_value,
            )
        )
        neighbor_list2, neighbor_ptr2, unit_shifts2 = (
            get_neighbor_list_from_neighbor_matrix(
                neighbor_matrix2,
                num_neighbors=num_neighbors2,
                neighbor_shift_matrix=neighbor_matrix_shifts2,
                fill_value=fill_value,
            )
        )
        return (
            neighbor_list1,
            neighbor_ptr1,
            unit_shifts1,
            neighbor_list2,
            neighbor_ptr2,
            unit_shifts2,
        )
    return (
        neighbor_matrix1,
        num_neighbors1,
        neighbor_matrix_shifts1,
        neighbor_matrix2,
        num_neighbors2,
        neighbor_matrix_shifts2,
    )