# Source code for nvalchemiops.torch.neighbors.batch_naive_dual_cutoff

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch bindings for batched naive dual cutoff neighbor list construction."""

from __future__ import annotations

import torch
import warp as wp

from nvalchemiops.neighbors.batch_naive_dual_cutoff import (
    batch_naive_neighbor_matrix_dual_cutoff,
    batch_naive_neighbor_matrix_pbc_dual_cutoff,
)
from nvalchemiops.neighbors.neighbor_utils import (
    estimate_max_neighbors,
)
from nvalchemiops.torch.neighbors.neighbor_utils import (
    compute_naive_num_shifts,
    get_neighbor_list_from_neighbor_matrix,
    prepare_batch_idx_ptr,
)
from nvalchemiops.torch.types import get_wp_dtype, get_wp_mat_dtype, get_wp_vec_dtype

__all__ = ["batch_naive_neighbor_list_dual_cutoff"]


@torch.library.custom_op(
    "nvalchemiops::_batch_naive_neighbor_matrix_no_pbc_dual_cutoff",
    mutates_args=(
        "neighbor_matrix1",
        "num_neighbors1",
        "neighbor_matrix2",
        "num_neighbors2",
    ),
)
def _batch_naive_neighbor_matrix_no_pbc_dual_cutoff(
    positions: torch.Tensor,
    cutoff1: float,
    cutoff2: float,
    batch_idx: torch.Tensor,
    batch_ptr: torch.Tensor,
    neighbor_matrix1: torch.Tensor,
    num_neighbors1: torch.Tensor,
    neighbor_matrix2: torch.Tensor,
    num_neighbors2: torch.Tensor,
    half_fill: bool,
) -> None:
    """Populate two batched neighbor matrices (no PBC) using the naive O(N^2) kernel.

    Registered as a ``torch.library`` custom op so that it can be used under
    ``torch.compile``. Each tensor is wrapped as a warp array via
    ``wp.from_torch`` and handed to the warp launcher, which writes its
    results into the four tensors listed in ``mutates_args``.

    See Also
    --------
    nvalchemiops.neighbors.batch_naive_dual_cutoff.batch_naive_neighbor_matrix_dual_cutoff : Core warp launcher
    batch_naive_neighbor_list_dual_cutoff : High-level wrapper function
    """
    target_device = positions.device
    vec_type = get_wp_vec_dtype(positions.dtype)
    scalar_type = get_wp_dtype(positions.dtype)

    def as_int32(tensor: torch.Tensor):
        # Every index / count buffer is handed to warp as int32.
        return wp.from_torch(tensor, dtype=wp.int32, return_ctype=True)

    warp_positions = wp.from_torch(positions, dtype=vec_type, return_ctype=True)
    warp_batch_idx = as_int32(batch_idx)
    warp_batch_ptr = as_int32(batch_ptr)
    warp_matrix_a = as_int32(neighbor_matrix1)
    warp_counts_a = as_int32(num_neighbors1)
    warp_matrix_b = as_int32(neighbor_matrix2)
    warp_counts_b = as_int32(num_neighbors2)

    batch_naive_neighbor_matrix_dual_cutoff(
        positions=warp_positions,
        cutoff1=cutoff1,
        cutoff2=cutoff2,
        batch_idx=warp_batch_idx,
        batch_ptr=warp_batch_ptr,
        neighbor_matrix1=warp_matrix_a,
        num_neighbors1=warp_counts_a,
        neighbor_matrix2=warp_matrix_b,
        num_neighbors2=warp_counts_b,
        wp_dtype=scalar_type,
        device=str(target_device),
        half_fill=half_fill,
    )


@torch.library.custom_op(
    "nvalchemiops::_batch_naive_neighbor_matrix_pbc_dual_cutoff",
    mutates_args=(
        "neighbor_matrix1",
        "neighbor_matrix2",
        "neighbor_matrix_shifts1",
        "neighbor_matrix_shifts2",
        "num_neighbors1",
        "num_neighbors2",
    ),
)
def _batch_naive_neighbor_matrix_pbc_dual_cutoff(
    positions: torch.Tensor,
    cell: torch.Tensor,
    cutoff1: float,
    cutoff2: float,
    batch_idx: torch.Tensor,
    batch_ptr: torch.Tensor,
    neighbor_matrix1: torch.Tensor,
    neighbor_matrix2: torch.Tensor,
    neighbor_matrix_shifts1: torch.Tensor,
    neighbor_matrix_shifts2: torch.Tensor,
    num_neighbors1: torch.Tensor,
    num_neighbors2: torch.Tensor,
    shift_range_per_dimension: torch.Tensor,
    num_shifts_per_system: torch.Tensor,
    max_shifts_per_system: int,
    half_fill: bool = False,
    max_atoms_per_system: int | None = None,
    wrap_positions: bool = True,
) -> None:
    """Populate two batched neighbor matrices with periodic boundary conditions.

    Registered as a ``torch.library`` custom op so that it can be used under
    ``torch.compile``. Tensors are wrapped as warp arrays and forwarded to the
    PBC warp launcher, which writes the two neighbor matrices, their shift
    matrices and their neighbor counts in place.

    When ``max_atoms_per_system`` is None, it is derived from ``batch_ptr`` with
    ``.item()``, which copies a value back to the host.

    See Also
    --------
    nvalchemiops.neighbors.batch_naive_dual_cutoff.batch_naive_neighbor_matrix_pbc_dual_cutoff : Core warp launcher
    batch_naive_neighbor_list_dual_cutoff : High-level wrapper function
    """
    target_device = positions.device
    vec_type = get_wp_vec_dtype(positions.dtype)
    mat_type = get_wp_mat_dtype(positions.dtype)
    scalar_type = get_wp_dtype(positions.dtype)

    def as_int32(tensor: torch.Tensor):
        return wp.from_torch(tensor, dtype=wp.int32, return_ctype=True)

    def as_vec3i(tensor: torch.Tensor):
        return wp.from_torch(tensor, dtype=wp.vec3i, return_ctype=True)

    warp_positions = wp.from_torch(positions, dtype=vec_type, return_ctype=True)
    warp_cell = wp.from_torch(cell, dtype=mat_type, return_ctype=True)
    warp_shift_range = as_vec3i(shift_range_per_dimension)
    warp_shift_counts = as_int32(num_shifts_per_system)
    warp_batch_idx = as_int32(batch_idx)
    warp_batch_ptr = as_int32(batch_ptr)
    warp_matrix_a = as_int32(neighbor_matrix1)
    warp_matrix_b = as_int32(neighbor_matrix2)
    warp_shifts_a = as_vec3i(neighbor_matrix_shifts1)
    warp_shifts_b = as_vec3i(neighbor_matrix_shifts2)
    warp_counts_a = as_int32(num_neighbors1)
    warp_counts_b = as_int32(num_neighbors2)

    if max_atoms_per_system is None:
        # Size of the largest system, from consecutive batch_ptr offsets.
        atoms_per_system = batch_ptr[1:] - batch_ptr[:-1]
        max_atoms_per_system = atoms_per_system.max().item()

    batch_naive_neighbor_matrix_pbc_dual_cutoff(
        positions=warp_positions,
        cell=warp_cell,
        cutoff1=cutoff1,
        cutoff2=cutoff2,
        batch_ptr=warp_batch_ptr,
        batch_idx=warp_batch_idx,
        shift_range=warp_shift_range,
        num_shifts_arr=warp_shift_counts,
        max_shifts_per_system=max_shifts_per_system,
        neighbor_matrix1=warp_matrix_a,
        neighbor_matrix2=warp_matrix_b,
        neighbor_matrix_shifts1=warp_shifts_a,
        neighbor_matrix_shifts2=warp_shifts_b,
        num_neighbors1=warp_counts_a,
        num_neighbors2=warp_counts_b,
        wp_dtype=scalar_type,
        device=str(target_device),
        max_atoms_per_system=max_atoms_per_system,
        half_fill=half_fill,
        wrap_positions=wrap_positions,
    )


@torch.library.custom_op(
    "nvalchemiops::_batch_naive_neighbor_matrix_no_pbc_dual_cutoff_selective",
    mutates_args=(
        "neighbor_matrix1",
        "num_neighbors1",
        "neighbor_matrix2",
        "num_neighbors2",
    ),
)
def _batch_naive_neighbor_matrix_no_pbc_dual_cutoff_selective(
    positions: torch.Tensor,
    cutoff1: float,
    cutoff2: float,
    batch_idx: torch.Tensor,
    batch_ptr: torch.Tensor,
    neighbor_matrix1: torch.Tensor,
    num_neighbors1: torch.Tensor,
    neighbor_matrix2: torch.Tensor,
    num_neighbors2: torch.Tensor,
    rebuild_flags: torch.Tensor,
    half_fill: bool = False,
) -> None:
    """Selectively rebuild two batched neighbor matrices (no PBC).

    Same as the non-selective op, except that ``rebuild_flags`` is also handed
    to the warp launcher as a ``wp.bool`` array. The per-system flags are
    evaluated on the device; this wrapper performs no host-side read of them.

    See Also
    --------
    nvalchemiops.neighbors.batch_naive_dual_cutoff.batch_naive_neighbor_matrix_dual_cutoff : Core warp launcher
    batch_naive_neighbor_list_dual_cutoff : High-level wrapper that dispatches here when rebuild_flags is set
    """
    warp_device = wp.device_from_torch(positions.device)
    vec_type = get_wp_vec_dtype(positions.dtype)
    scalar_type = get_wp_dtype(positions.dtype)

    def as_int32(tensor: torch.Tensor):
        return wp.from_torch(tensor, dtype=wp.int32, return_ctype=True)

    warp_positions = wp.from_torch(positions, dtype=vec_type, return_ctype=True)
    warp_batch_idx = as_int32(batch_idx)
    warp_batch_ptr = as_int32(batch_ptr)
    warp_matrix_a = as_int32(neighbor_matrix1)
    warp_counts_a = as_int32(num_neighbors1)
    warp_matrix_b = as_int32(neighbor_matrix2)
    warp_counts_b = as_int32(num_neighbors2)
    warp_flags = wp.from_torch(rebuild_flags, dtype=wp.bool, return_ctype=True)

    batch_naive_neighbor_matrix_dual_cutoff(
        positions=warp_positions,
        cutoff1=cutoff1,
        cutoff2=cutoff2,
        batch_idx=warp_batch_idx,
        batch_ptr=warp_batch_ptr,
        neighbor_matrix1=warp_matrix_a,
        num_neighbors1=warp_counts_a,
        neighbor_matrix2=warp_matrix_b,
        num_neighbors2=warp_counts_b,
        wp_dtype=scalar_type,
        device=str(warp_device),
        half_fill=half_fill,
        rebuild_flags=warp_flags,
    )


@torch.library.custom_op(
    "nvalchemiops::_batch_naive_neighbor_matrix_pbc_dual_cutoff_selective",
    mutates_args=(
        "neighbor_matrix1",
        "neighbor_matrix2",
        "neighbor_matrix_shifts1",
        "neighbor_matrix_shifts2",
        "num_neighbors1",
        "num_neighbors2",
    ),
)
def _batch_naive_neighbor_matrix_pbc_dual_cutoff_selective(
    positions: torch.Tensor,
    cell: torch.Tensor,
    cutoff1: float,
    cutoff2: float,
    batch_idx: torch.Tensor,
    batch_ptr: torch.Tensor,
    neighbor_matrix1: torch.Tensor,
    neighbor_matrix2: torch.Tensor,
    neighbor_matrix_shifts1: torch.Tensor,
    neighbor_matrix_shifts2: torch.Tensor,
    num_neighbors1: torch.Tensor,
    num_neighbors2: torch.Tensor,
    shift_range_per_dimension: torch.Tensor,
    num_shifts_per_system: torch.Tensor,
    max_shifts_per_system: int,
    rebuild_flags: torch.Tensor,
    half_fill: bool = False,
    max_atoms_per_system: int | None = None,
    wrap_positions: bool = True,
) -> None:
    """Selective batched naive dual cutoff PBC neighbor matrix custom op.

    Wraps each tensor as a warp array and forwards everything, including the
    per-system ``rebuild_flags``, to the PBC warp launcher. The launcher writes
    the neighbor matrices, shift matrices and neighbor counts in place.

    The rebuild flags themselves are evaluated on the device; this wrapper
    never reads them on the host. Note, however, that when
    ``max_atoms_per_system`` is None it is derived from ``batch_ptr`` with
    ``.item()``, which forces a CPU-GPU synchronisation. Pass
    ``max_atoms_per_system`` explicitly to keep the call fully sync-free.

    Parameters
    ----------
    positions, cell : torch.Tensor
        Atom positions and per-system cells. They are converted with the warp
        vector / matrix dtype matching ``positions.dtype``.
    cutoff1, cutoff2 : float
        Cutoff radii for the first and second neighbor matrix.
    batch_idx, batch_ptr : torch.Tensor
        Batch bookkeeping, passed to warp as int32.
    neighbor_matrix1, neighbor_matrix2, num_neighbors1, num_neighbors2 : torch.Tensor
        Output buffers (int32), mutated in place.
    neighbor_matrix_shifts1, neighbor_matrix_shifts2 : torch.Tensor
        Output shift buffers, viewed as ``wp.vec3i`` and mutated in place.
    shift_range_per_dimension, num_shifts_per_system : torch.Tensor
        Periodic-image metadata, viewed as ``wp.vec3i`` and int32 respectively.
    max_shifts_per_system : int
        Forwarded to the launcher.
    rebuild_flags : torch.Tensor
        Per-system flags, viewed as ``wp.bool``.
    half_fill, wrap_positions : bool
        Forwarded to the launcher.
    max_atoms_per_system : int, optional
        Largest number of atoms in any system; computed from ``batch_ptr``
        (with a host sync) when omitted.

    See Also
    --------
    nvalchemiops.neighbors.batch_naive_dual_cutoff.batch_naive_neighbor_matrix_pbc_dual_cutoff : Core warp launcher
    batch_naive_neighbor_list_dual_cutoff : High-level wrapper that dispatches here when rebuild_flags is set
    """
    device = positions.device
    wp_device = wp.device_from_torch(device)
    wp_vec_dtype = get_wp_vec_dtype(positions.dtype)
    wp_mat_dtype = get_wp_mat_dtype(positions.dtype)
    wp_dtype = get_wp_dtype(positions.dtype)

    wp_positions = wp.from_torch(positions, dtype=wp_vec_dtype, return_ctype=True)
    wp_cell = wp.from_torch(cell, dtype=wp_mat_dtype, return_ctype=True)
    wp_shift_range = wp.from_torch(
        shift_range_per_dimension, dtype=wp.vec3i, return_ctype=True
    )
    wp_num_shifts_arr = wp.from_torch(
        num_shifts_per_system, dtype=wp.int32, return_ctype=True
    )
    wp_batch_idx = wp.from_torch(batch_idx, dtype=wp.int32, return_ctype=True)
    wp_batch_ptr = wp.from_torch(batch_ptr, dtype=wp.int32, return_ctype=True)
    wp_neighbor_matrix1 = wp.from_torch(
        neighbor_matrix1, dtype=wp.int32, return_ctype=True
    )
    wp_neighbor_matrix2 = wp.from_torch(
        neighbor_matrix2, dtype=wp.int32, return_ctype=True
    )
    wp_neighbor_matrix_shifts1 = wp.from_torch(
        neighbor_matrix_shifts1, dtype=wp.vec3i, return_ctype=True
    )
    wp_neighbor_matrix_shifts2 = wp.from_torch(
        neighbor_matrix_shifts2, dtype=wp.vec3i, return_ctype=True
    )
    wp_num_neighbors1 = wp.from_torch(num_neighbors1, dtype=wp.int32, return_ctype=True)
    wp_num_neighbors2 = wp.from_torch(num_neighbors2, dtype=wp.int32, return_ctype=True)
    wp_rebuild_flags = wp.from_torch(rebuild_flags, dtype=wp.bool, return_ctype=True)

    if max_atoms_per_system is None:
        # NOTE: .item() copies the value to the host and therefore synchronises
        # with the device; callers wanting a sync-free path should pass
        # max_atoms_per_system explicitly.
        max_atoms_per_system = (batch_ptr[1:] - batch_ptr[:-1]).max().item()

    batch_naive_neighbor_matrix_pbc_dual_cutoff(
        positions=wp_positions,
        cell=wp_cell,
        cutoff1=cutoff1,
        cutoff2=cutoff2,
        batch_ptr=wp_batch_ptr,
        batch_idx=wp_batch_idx,
        shift_range=wp_shift_range,
        num_shifts_arr=wp_num_shifts_arr,
        max_shifts_per_system=max_shifts_per_system,
        neighbor_matrix1=wp_neighbor_matrix1,
        neighbor_matrix2=wp_neighbor_matrix2,
        neighbor_matrix_shifts1=wp_neighbor_matrix_shifts1,
        neighbor_matrix_shifts2=wp_neighbor_matrix_shifts2,
        num_neighbors1=wp_num_neighbors1,
        num_neighbors2=wp_num_neighbors2,
        wp_dtype=wp_dtype,
        device=str(wp_device),
        max_atoms_per_system=max_atoms_per_system,
        half_fill=half_fill,
        rebuild_flags=wp_rebuild_flags,
        wrap_positions=wrap_positions,
    )


def batch_naive_neighbor_list_dual_cutoff(
    positions: torch.Tensor,
    cutoff1: float,
    cutoff2: float,
    batch_idx: torch.Tensor | None = None,
    batch_ptr: torch.Tensor | None = None,
    pbc: torch.Tensor | None = None,
    cell: torch.Tensor | None = None,
    max_neighbors1: int | None = None,
    max_neighbors2: int | None = None,
    half_fill: bool = False,
    fill_value: int | None = None,
    return_neighbor_list: bool = False,
    neighbor_matrix1: torch.Tensor | None = None,
    neighbor_matrix2: torch.Tensor | None = None,
    neighbor_matrix_shifts1: torch.Tensor | None = None,
    neighbor_matrix_shifts2: torch.Tensor | None = None,
    num_neighbors1: torch.Tensor | None = None,
    num_neighbors2: torch.Tensor | None = None,
    shift_range_per_dimension: torch.Tensor | None = None,
    num_shifts_per_system: torch.Tensor | None = None,
    max_shifts_per_system: int | None = None,
    max_atoms_per_system: int | None = None,
    rebuild_flags: torch.Tensor | None = None,
    wrap_positions: bool = True,
) -> (
    tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]
    | tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]
    | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
):
    """Compute batch neighbor matrices using naive O(N^2) algorithm with dual cutoffs.

    Two neighbor matrices are built in one pass: one for ``cutoff1`` and one
    for ``cutoff2``. Output buffers that are not supplied are allocated here.
    Supplied buffers are reset before the computation unless ``rebuild_flags``
    is given.

    Parameters
    ----------
    positions : torch.Tensor
        Atom positions of all systems concatenated. ``positions.shape[0]`` is
        taken as the total atom count (presumably shape ``(total_atoms, 3)``).
    cutoff1, cutoff2 : float
        Cutoff radii for the first and second neighbor matrix. ``cutoff2`` is
        the one used for the default neighbor-count estimate and for the
        periodic shift computation, so it is expected to be the larger one.
    batch_idx, batch_ptr : torch.Tensor, optional
        Per-atom system index and per-system atom offsets. Missing values are
        filled in by ``prepare_batch_idx_ptr``.
    pbc, cell : torch.Tensor, optional
        Periodicity flags and cell matrices. They must be given together. A
        single cell (2-D) or single pbc row (1-D) is promoted to batched form.
    max_neighbors1, max_neighbors2 : int, optional
        Widths of neighbor matrices allocated by this function. If neither is
        given and a buffer must be allocated, both default to
        ``estimate_max_neighbors(cutoff2)``. If only one is given, the other
        defaults to it.
    half_fill : bool
        Forwarded to the kernels.
    fill_value : int, optional
        Padding value for unused matrix entries; defaults to the total atom
        count.
    return_neighbor_list : bool
        If True, convert the matrices to neighbor-list form before returning.
    neighbor_matrix1, neighbor_matrix2, neighbor_matrix_shifts1, neighbor_matrix_shifts2,
    num_neighbors1, num_neighbors2 : torch.Tensor, optional
        Preallocated output buffers. Newly allocated shift buffers take the
        width of their paired neighbor matrix.
    shift_range_per_dimension, num_shifts_per_system, max_shifts_per_system : optional
        Periodic-image metadata. If any of the three is missing, all three are
        recomputed with ``compute_naive_num_shifts(cell, cutoff2, pbc)``.
    max_atoms_per_system : int, optional
        Forwarded to the PBC op. When omitted, the op derives it from
        ``batch_ptr`` with a host synchronisation.
    rebuild_flags : torch.Tensor, optional
        Per-system rebuild flags. When given, the selective ops are used and
        caller-supplied buffers are not reset.
    wrap_positions : bool
        Forwarded to the PBC ops.

    Returns
    -------
    tuple
        Without PBC: ``(neighbor_matrix1, num_neighbors1, neighbor_matrix2,
        num_neighbors2)``, or ``(neighbor_list1, neighbor_ptr1,
        neighbor_list2, neighbor_ptr2)`` when ``return_neighbor_list``.

        With PBC: ``(neighbor_matrix1, num_neighbors1, neighbor_matrix_shifts1,
        neighbor_matrix2, num_neighbors2, neighbor_matrix_shifts2)``, or
        ``(neighbor_list1, neighbor_ptr1, unit_shifts1, neighbor_list2,
        neighbor_ptr2, unit_shifts2)`` when ``return_neighbor_list``.

    Raises
    ------
    ValueError
        If only one of ``pbc`` / ``cell`` is provided.
    RuntimeError
        If ``batch_idx`` does not have one entry per atom.

    See Also
    --------
    nvalchemiops.neighbors.batch_naive_dual_cutoff.batch_naive_neighbor_matrix_dual_cutoff : Core warp launcher (no PBC)
    nvalchemiops.neighbors.batch_naive_dual_cutoff.batch_naive_neighbor_matrix_pbc_dual_cutoff : Core warp launcher (with PBC)
    batch_naive_neighbor_list : Single cutoff version
    """
    if pbc is None and cell is not None:
        raise ValueError("If cell is provided, pbc must also be provided")
    if pbc is not None and cell is None:
        raise ValueError("If pbc is provided, cell must also be provided")

    # Promote a single-system cell / pbc to the batched layout.
    if cell is not None:
        cell = cell if cell.ndim == 3 else cell.unsqueeze(0)
    if pbc is not None:
        pbc = pbc if pbc.ndim == 2 else pbc.unsqueeze(0)

    if fill_value is None:
        fill_value = positions.shape[0]

    needs_allocation = (
        neighbor_matrix1 is None
        or neighbor_matrix2 is None
        or (neighbor_matrix_shifts1 is None and pbc is not None)
        or (neighbor_matrix_shifts2 is None and pbc is not None)
        or num_neighbors1 is None
        or num_neighbors2 is None
    )
    if max_neighbors1 is None and needs_allocation:
        # Only fall back to the estimate when the caller supplied neither
        # width; an explicit max_neighbors2 must not be overwritten.
        if max_neighbors2 is None:
            max_neighbors2 = estimate_max_neighbors(cutoff2)
        max_neighbors1 = max_neighbors2
    if max_neighbors2 is None:
        max_neighbors2 = max_neighbors1

    total_atoms = positions.shape[0]
    device = positions.device
    # With rebuild_flags, caller-supplied buffers are left untouched
    # (presumably so systems not flagged for rebuild keep previous results).
    reset_buffers = rebuild_flags is None

    if neighbor_matrix1 is None:
        neighbor_matrix1 = torch.full(
            (total_atoms, max_neighbors1),
            fill_value,
            dtype=torch.int32,
            device=device,
        )
    elif reset_buffers:
        neighbor_matrix1.fill_(fill_value)
    if num_neighbors1 is None:
        num_neighbors1 = torch.zeros(total_atoms, dtype=torch.int32, device=device)
    elif reset_buffers:
        num_neighbors1.zero_()

    if neighbor_matrix2 is None:
        neighbor_matrix2 = torch.full(
            (total_atoms, max_neighbors2),
            fill_value,
            dtype=torch.int32,
            device=device,
        )
    elif reset_buffers:
        neighbor_matrix2.fill_(fill_value)
    if num_neighbors2 is None:
        num_neighbors2 = torch.zeros(total_atoms, dtype=torch.int32, device=device)
    elif reset_buffers:
        num_neighbors2.zero_()

    if pbc is not None:
        # Size each shift buffer by its paired neighbor matrix: that matrix may
        # be caller-supplied with a width different from max_neighbors1/2.
        if neighbor_matrix_shifts1 is None:
            neighbor_matrix_shifts1 = torch.zeros(
                (total_atoms, neighbor_matrix1.shape[1], 3),
                dtype=torch.int32,
                device=device,
            )
        elif reset_buffers:
            neighbor_matrix_shifts1.zero_()
        if neighbor_matrix_shifts2 is None:
            neighbor_matrix_shifts2 = torch.zeros(
                (total_atoms, neighbor_matrix2.shape[1], 3),
                dtype=torch.int32,
                device=device,
            )
        elif reset_buffers:
            neighbor_matrix_shifts2.zero_()

        if (
            max_shifts_per_system is None
            or num_shifts_per_system is None
            or shift_range_per_dimension is None
        ):
            shift_range_per_dimension, num_shifts_per_system, max_shifts_per_system = (
                compute_naive_num_shifts(cell, cutoff2, pbc)
            )

    batch_idx, batch_ptr = prepare_batch_idx_ptr(
        batch_idx=batch_idx,
        batch_ptr=batch_ptr,
        num_atoms=total_atoms,
        device=device,
    )

    # Validate batch_idx size matches total_atoms (check here since prepare_batch_idx_ptr
    # is @torch.compile decorated and the check would be skipped during tracing)
    if batch_idx.shape[0] != total_atoms:
        raise RuntimeError(
            f"batch_idx length ({batch_idx.shape[0]}) does not match "
            f"num_atoms ({total_atoms}). batch_idx must have one entry per atom."
        )

    if pbc is None:
        if rebuild_flags is not None:
            _batch_naive_neighbor_matrix_no_pbc_dual_cutoff_selective(
                positions=positions,
                cutoff1=cutoff1,
                cutoff2=cutoff2,
                batch_idx=batch_idx,
                batch_ptr=batch_ptr,
                neighbor_matrix1=neighbor_matrix1,
                num_neighbors1=num_neighbors1,
                neighbor_matrix2=neighbor_matrix2,
                num_neighbors2=num_neighbors2,
                rebuild_flags=rebuild_flags,
                half_fill=half_fill,
            )
        else:
            _batch_naive_neighbor_matrix_no_pbc_dual_cutoff(
                positions=positions,
                cutoff1=cutoff1,
                cutoff2=cutoff2,
                batch_idx=batch_idx,
                batch_ptr=batch_ptr,
                neighbor_matrix1=neighbor_matrix1,
                num_neighbors1=num_neighbors1,
                neighbor_matrix2=neighbor_matrix2,
                num_neighbors2=num_neighbors2,
                half_fill=half_fill,
            )
        if return_neighbor_list:
            neighbor_list1, neighbor_ptr1 = get_neighbor_list_from_neighbor_matrix(
                neighbor_matrix1, num_neighbors=num_neighbors1, fill_value=fill_value
            )
            neighbor_list2, neighbor_ptr2 = get_neighbor_list_from_neighbor_matrix(
                neighbor_matrix2, num_neighbors=num_neighbors2, fill_value=fill_value
            )
            return (
                neighbor_list1,
                neighbor_ptr1,
                neighbor_list2,
                neighbor_ptr2,
            )
        return (
            neighbor_matrix1,
            num_neighbors1,
            neighbor_matrix2,
            num_neighbors2,
        )

    if rebuild_flags is not None:
        _batch_naive_neighbor_matrix_pbc_dual_cutoff_selective(
            positions=positions,
            cell=cell,
            cutoff1=cutoff1,
            cutoff2=cutoff2,
            batch_idx=batch_idx,
            batch_ptr=batch_ptr,
            neighbor_matrix1=neighbor_matrix1,
            neighbor_matrix2=neighbor_matrix2,
            neighbor_matrix_shifts1=neighbor_matrix_shifts1,
            neighbor_matrix_shifts2=neighbor_matrix_shifts2,
            num_neighbors1=num_neighbors1,
            num_neighbors2=num_neighbors2,
            shift_range_per_dimension=shift_range_per_dimension,
            num_shifts_per_system=num_shifts_per_system,
            max_shifts_per_system=max_shifts_per_system,
            rebuild_flags=rebuild_flags,
            half_fill=half_fill,
            max_atoms_per_system=max_atoms_per_system,
            wrap_positions=wrap_positions,
        )
    else:
        _batch_naive_neighbor_matrix_pbc_dual_cutoff(
            positions=positions,
            cell=cell,
            cutoff1=cutoff1,
            cutoff2=cutoff2,
            batch_idx=batch_idx,
            batch_ptr=batch_ptr,
            neighbor_matrix1=neighbor_matrix1,
            neighbor_matrix2=neighbor_matrix2,
            neighbor_matrix_shifts1=neighbor_matrix_shifts1,
            neighbor_matrix_shifts2=neighbor_matrix_shifts2,
            num_neighbors1=num_neighbors1,
            num_neighbors2=num_neighbors2,
            shift_range_per_dimension=shift_range_per_dimension,
            num_shifts_per_system=num_shifts_per_system,
            max_shifts_per_system=max_shifts_per_system,
            half_fill=half_fill,
            max_atoms_per_system=max_atoms_per_system,
            wrap_positions=wrap_positions,
        )
    if return_neighbor_list:
        neighbor_list1, neighbor_ptr1, unit_shifts1 = (
            get_neighbor_list_from_neighbor_matrix(
                neighbor_matrix1,
                num_neighbors=num_neighbors1,
                neighbor_shift_matrix=neighbor_matrix_shifts1,
                fill_value=fill_value,
            )
        )
        neighbor_list2, neighbor_ptr2, unit_shifts2 = (
            get_neighbor_list_from_neighbor_matrix(
                neighbor_matrix2,
                num_neighbors=num_neighbors2,
                neighbor_shift_matrix=neighbor_matrix_shifts2,
                fill_value=fill_value,
            )
        )
        return (
            neighbor_list1,
            neighbor_ptr1,
            unit_shifts1,
            neighbor_list2,
            neighbor_ptr2,
            unit_shifts2,
        )
    return (
        neighbor_matrix1,
        num_neighbors1,
        neighbor_matrix_shifts1,
        neighbor_matrix2,
        num_neighbors2,
        neighbor_matrix_shifts2,
    )