# Source code for nvalchemiops.jax.neighbors.batch_cell_list

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JAX bindings for batched cell list O(N) neighbor list construction."""

from __future__ import annotations

import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental import jax_kernel

from nvalchemiops.jax.neighbors.neighbor_utils import (
    allocate_cell_list,
    get_neighbor_list_from_neighbor_matrix,
    prepare_batch_idx_ptr,
)
from nvalchemiops.neighbors.batch_cell_list import (
    _batch_cell_list_bin_atoms_overload,
    _batch_cell_list_build_neighbor_matrix_overload,
    _batch_cell_list_build_neighbor_matrix_selective_overload,
    _batch_cell_list_construct_bin_size_overload,
    _batch_cell_list_count_atoms_per_bin_overload,
    _compute_cells_per_system,
)
from nvalchemiops.neighbors.neighbor_utils import estimate_max_neighbors

# ==============================================================================
# JAX Kernel Wrappers
# ==============================================================================
#
# Every Warp kernel overload used by this module is wrapped once per
# floating-point precision (``_f32`` / ``_f64`` suffix). All wrappers are
# forward-only (``enable_backward=False``).


def _wrap_kernel(kernel, in_out_argnames):
    """Wrap a Warp kernel as a forward-only JAX callable.

    Every output of the kernels wrapped in this module is an in-out argument,
    so ``num_outputs`` is simply the number of in-out argument names.
    """
    return jax_kernel(
        kernel,
        num_outputs=len(in_out_argnames),
        in_out_argnames=list(in_out_argnames),
        enable_backward=False,
    )


# Build step 1: construct bin sizes (per system)
_jax_batch_construct_bin_size_f32 = _wrap_kernel(
    _batch_cell_list_construct_bin_size_overload[wp.float32],
    ["cells_per_dimension"],
)
_jax_batch_construct_bin_size_f64 = _wrap_kernel(
    _batch_cell_list_construct_bin_size_overload[wp.float64],
    ["cells_per_dimension"],
)

# Helper: compute the number of cells owned by each system
_jax_compute_cells_per_system = _wrap_kernel(
    _compute_cells_per_system,
    ["cells_per_system"],
)

# Build step 2: count atoms per bin
_jax_batch_count_atoms_per_bin_f32 = _wrap_kernel(
    _batch_cell_list_count_atoms_per_bin_overload[wp.float32],
    ["atoms_per_cell_count", "atom_periodic_shifts"],
)
_jax_batch_count_atoms_per_bin_f64 = _wrap_kernel(
    _batch_cell_list_count_atoms_per_bin_overload[wp.float64],
    ["atoms_per_cell_count", "atom_periodic_shifts"],
)

# Build step 3: bin atoms into cells
_jax_batch_bin_atoms_f32 = _wrap_kernel(
    _batch_cell_list_bin_atoms_overload[wp.float32],
    ["atom_to_cell_mapping", "atoms_per_cell_count", "cell_atom_list"],
)
_jax_batch_bin_atoms_f64 = _wrap_kernel(
    _batch_cell_list_bin_atoms_overload[wp.float64],
    ["atom_to_cell_mapping", "atoms_per_cell_count", "cell_atom_list"],
)

# Query: build the neighbor matrix from a batch cell list
_jax_batch_build_neighbor_matrix_f32 = _wrap_kernel(
    _batch_cell_list_build_neighbor_matrix_overload[wp.float32],
    ["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
)
_jax_batch_build_neighbor_matrix_f64 = _wrap_kernel(
    _batch_cell_list_build_neighbor_matrix_overload[wp.float64],
    ["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
)

# Selective query: same outputs as above, but skips non-rebuilt systems
_jax_batch_build_neighbor_matrix_selective_f32 = _wrap_kernel(
    _batch_cell_list_build_neighbor_matrix_selective_overload[wp.float32],
    ["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
)
_jax_batch_build_neighbor_matrix_selective_f64 = _wrap_kernel(
    _batch_cell_list_build_neighbor_matrix_selective_overload[wp.float64],
    ["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
)

# Names exported by ``from nvalchemiops.jax.neighbors.batch_cell_list import *``.
# The ``_jax_*`` kernel wrappers defined above are intentionally left out.
__all__ = [
    "batch_cell_list",
    "batch_build_cell_list",
    "batch_query_cell_list",
    "estimate_batch_cell_list_sizes",
]


def estimate_batch_cell_list_sizes(
    positions: jax.Array,
    batch_ptr: jax.Array | None = None,
    batch_idx: jax.Array | None = None,
    cell: jax.Array | None = None,
    cutoff: float = 5.0,
    pbc: jax.Array | None = None,
    buffer_factor: float = 1.5,
) -> tuple[int, jax.Array, jax.Array]:
    """Estimate required batch cell list sizes.

    For every non-empty system the number of cells is estimated as
    ``volume / cutoff**3 * buffer_factor`` (at least 8), where ``volume`` is
    ``|det(cell)|`` or 1000.0 when no cell is given. Empty systems contribute
    no cells to the total and are reported with one cell per dimension.

    Parameters
    ----------
    positions : jax.Array, shape (total_atoms, 3), dtype=float32 or float64
        Atomic coordinates. Only ``positions.shape[0]`` is used here.
    batch_ptr : jax.Array, shape (num_systems + 1,), dtype=int32, optional
        Cumulative atom counts.
    batch_idx : jax.Array, shape (total_atoms,), dtype=int32, optional
        Batch indices for each atom. Only forwarded to
        ``prepare_batch_idx_ptr``.
    cell : jax.Array, shape (num_systems, 3, 3), dtype=float32 or float64, optional
        Cell matrices for each system.
    cutoff : float, optional
        Cutoff distance. Default is 5.0.
    pbc : jax.Array, shape (num_systems, 3), dtype=bool, optional
        PBC flags. Accepted for signature symmetry; not used by the current
        estimate.
    buffer_factor : float, optional
        Buffer multiplier. Default is 1.5.

    Returns
    -------
    max_total_cells : int
        Maximum total cells to allocate (sum of the per-system estimates).
    cells_per_dimension : jax.Array, shape (num_systems, 3), dtype=int32
        Cells per dimension for each system.
    neighbor_search_radius : jax.Array, shape (num_systems, 3), dtype=int32
        Search radius for each system (all ones in the current implementation).

    .. warning::
        This function is **not compatible with** ``jax.jit``. The returned
        ``max_total_cells`` is used to determine array allocation sizes, which
        must be concrete (statically known) at JAX trace time. When using
        ``batch_cell_list`` or ``batch_build_cell_list`` inside ``jax.jit``,
        provide ``max_total_cells`` explicitly to bypass this function.
    """
    # Prepare batch info
    batch_idx, batch_ptr = prepare_batch_idx_ptr(
        batch_idx, batch_ptr, positions.shape[0]
    )
    num_systems = batch_ptr.shape[0] - 1

    # jnp.stack rejects an empty sequence, so handle "no systems" up front.
    if num_systems <= 0:
        empty = jnp.zeros((0, 3), dtype=jnp.int32)
        return 0, empty, empty

    # One host transfer for all per-system atom counts instead of indexing the
    # device array twice per system inside the loop.
    atoms_per_system = jnp.diff(batch_ptr).tolist()

    # Loop invariants.
    ones_triplet = jnp.ones(3, dtype=jnp.int32)
    cell_volume = cutoff**3

    # Simple estimation per system
    max_total_cells = 0
    cells_per_dim_list = []
    search_radius_list = []
    for sys_idx, num_atoms_in_sys in enumerate(atoms_per_system):
        search_radius_list.append(ones_triplet)

        if num_atoms_in_sys == 0:
            cells_per_dim_list.append(ones_triplet)
            continue

        # Volume estimation
        if cell is not None:
            volume = jnp.abs(jnp.linalg.det(cell[sys_idx]))
        else:
            volume = 1000.0  # Default assumption

        # TODO: This estimation derives array sizes from traced input data (cell
        # geometry), which is fundamentally incompatible with jax.jit compilation.
        # The JAX bindings need a refactored usage pattern where sizing is always
        # performed outside the JIT boundary, or a fixed upper-bound allocation
        # strategy is adopted.
        num_cells_est = max(int(volume / cell_volume * buffer_factor), 8)
        max_total_cells += num_cells_est

        cells_per_dim = jnp.ceil(num_cells_est ** (1 / 3)).astype(jnp.int32)
        cells_per_dim_list.append(cells_per_dim * ones_triplet)

    cells_per_dimension = jnp.stack(cells_per_dim_list, axis=0)
    neighbor_search_radius = jnp.stack(search_radius_list, axis=0)

    return max_total_cells, cells_per_dimension, neighbor_search_radius


def batch_build_cell_list(
    positions: jax.Array,
    batch_idx: jax.Array | None = None,
    batch_ptr: jax.Array | None = None,
    cell: jax.Array | None = None,
    pbc: jax.Array | None = None,
    cutoff: float = 5.0,
    max_total_cells: int | None = None,
) -> tuple[
    jax.Array,
    jax.Array,
    jax.Array,
    jax.Array,
    jax.Array,
    jax.Array,
    jax.Array,
    jax.Array,
]:
    """Build spatial cell lists for batch of systems.

    Parameters
    ----------
    positions : jax.Array, shape (total_atoms, 3), dtype=float32 or float64
        Atomic coordinates. Anything that is not float64 is cast to float32
        before the kernels are launched.
    batch_idx : jax.Array, shape (total_atoms,), dtype=int32, optional
        Batch indices. Cast to int32 before being handed to the kernels.
    batch_ptr : jax.Array, shape (num_systems + 1,), dtype=int32, optional
        Cumulative atom counts.
    cell : jax.Array, shape (num_systems, 3, 3), dtype=float32 or float64, optional
        Cell matrices. Cast to the dtype of ``positions`` when they differ.
        NOTE(review): ``None`` is forwarded to the kernels as-is; no default
        cell is substituted in this function -- confirm whether callers must
        always provide it.
    pbc : jax.Array, shape (num_systems, 3), dtype=bool, optional
        PBC flags. If None, all-True flags of shape (num_systems, 3) are used.
    cutoff : float, optional
        Cutoff distance. Default is 5.0.
    max_total_cells : int, optional
        Maximum cells. If None, will be estimated with
        ``estimate_batch_cell_list_sizes``.

    Returns
    -------
    cells_per_dimension : jax.Array, shape (num_systems, 3), dtype=int32
        Number of cells in x, y, z directions for each system.
    atom_periodic_shifts : jax.Array, shape (total_atoms, 3), dtype=int32
        Periodic boundary crossings for each atom.
    atom_to_cell_mapping : jax.Array, shape (total_atoms, 3), dtype=int32
        3D cell coordinates for each atom.
    atoms_per_cell_count : jax.Array, shape (max_total_cells,), dtype=int32
        Number of atoms in each cell.
    cell_atom_start_indices : jax.Array, shape (max_total_cells,), dtype=int32
        Starting index in ``cell_atom_list`` for each cell.
    cell_atom_list : jax.Array, shape (total_atoms,), dtype=int32
        Flattened list of atom indices organized by cell.
    neighbor_search_radius : jax.Array, shape (num_systems, 3), dtype=int32
        Search radius in neighboring cells for each system.
    cell_origin : jax.Array, shape (3,), dtype same as positions
        Cell origin point (currently zeros).

    Notes
    -----
    When calling inside ``jax.jit``, ``max_total_cells`` **must** be provided
    to avoid calling ``estimate_batch_cell_list_sizes``, which is not
    JIT-compatible.

    Only ``max_total_cells`` and ``neighbor_search_radius`` are taken from the
    estimate; ``cells_per_dimension`` is always computed by the bin-size
    kernel.
    """
    # Prepare batch info
    batch_idx, batch_ptr = prepare_batch_idx_ptr(
        batch_idx, batch_ptr, positions.shape[0]
    )
    # Same int32 cast that batch_query_cell_list applies before launching its
    # kernels; a no-op when batch_idx is already int32.
    batch_idx = batch_idx.astype(jnp.int32)
    num_systems = batch_ptr.shape[0] - 1
    total_atoms = positions.shape[0]

    if max_total_cells is None:
        # The estimated cells-per-dimension is discarded: step 1 below
        # recomputes it on device.
        max_total_cells, _, neighbor_search_radius = estimate_batch_cell_list_sizes(
            positions, batch_ptr, batch_idx, cell, cutoff, pbc
        )
    else:
        neighbor_search_radius = jnp.ones((num_systems, 3), dtype=jnp.int32)

    # Allocate cell list tensors
    (
        cells_per_dimension,
        neighbor_search_radius,
        atom_periodic_shifts,
        atom_to_cell_mapping,
        atoms_per_cell_count,
        cell_atom_start_indices,
        cell_atom_list,
    ) = allocate_cell_list(
        total_atoms,
        max_total_cells,
        neighbor_search_radius,
    )

    # Select kernels based on dtype
    if positions.dtype == jnp.float64:
        _construct = _jax_batch_construct_bin_size_f64
        _count = _jax_batch_count_atoms_per_bin_f64
        _bin = _jax_batch_bin_atoms_f64
    else:
        _construct = _jax_batch_construct_bin_size_f32
        _count = _jax_batch_count_atoms_per_bin_f32
        _bin = _jax_batch_bin_atoms_f32
        positions = positions.astype(jnp.float32)

    # Ensure cell dtype matches positions
    if cell is not None and cell.dtype != positions.dtype:
        cell = cell.astype(positions.dtype)

    # Ensure pbc is bool; default to fully periodic with shape (num_systems, 3)
    if pbc is not None:
        pbc_bool = pbc.astype(jnp.bool_)
    else:
        pbc_bool = jnp.ones((num_systems, 3), dtype=jnp.bool_)

    # Step 1: Construct bin sizes (one thread per system)
    (cells_per_dimension,) = _construct(
        cell,
        pbc_bool,
        cells_per_dimension,
        float(cutoff),
        int(max_total_cells),
        launch_dims=(num_systems,),
    )

    # Step 2: Compute cells_per_system and cell_offsets (exclusive prefix sum)
    cells_per_system = jnp.zeros(num_systems, dtype=jnp.int32)
    (cells_per_system,) = _jax_compute_cells_per_system(
        cells_per_dimension,
        cells_per_system,
        launch_dims=(num_systems,),
    )
    cell_offsets = jnp.concatenate(
        [
            jnp.array([0], dtype=jnp.int32),
            jnp.cumsum(cells_per_system[:-1], dtype=jnp.int32),
        ]
    )

    # Step 3: Count atoms per bin
    atoms_per_cell_count, atom_periodic_shifts = _count(
        positions,
        cell,
        pbc_bool,
        batch_idx,
        cells_per_dimension,
        cell_offsets,
        atoms_per_cell_count,
        atom_periodic_shifts,
        launch_dims=(total_atoms,),
    )

    # Step 4: Compute exclusive prefix sum (replaces wp.utils.array_scan)
    cell_atom_start_indices = jnp.concatenate(
        [
            jnp.array([0], dtype=jnp.int32),
            jnp.cumsum(atoms_per_cell_count[:-1], dtype=jnp.int32),
        ]
    )

    # Step 5: Zero counts before second pass
    atoms_per_cell_count = jnp.zeros_like(atoms_per_cell_count)

    # Step 6: Bin atoms
    atom_to_cell_mapping, atoms_per_cell_count, cell_atom_list = _bin(
        positions,
        cell,
        pbc_bool,
        batch_idx,
        cells_per_dimension,
        cell_offsets,
        atom_to_cell_mapping,
        atoms_per_cell_count,
        cell_atom_start_indices,
        cell_atom_list,
        launch_dims=(total_atoms,),
    )

    cell_origin = jnp.zeros(3, dtype=positions.dtype)

    return (
        cells_per_dimension,
        atom_periodic_shifts,
        atom_to_cell_mapping,
        atoms_per_cell_count,
        cell_atom_start_indices,
        cell_atom_list,
        neighbor_search_radius,
        cell_origin,
    )


def batch_query_cell_list(
    positions: jax.Array,
    batch_idx: jax.Array | None = None,
    batch_ptr: jax.Array | None = None,
    cutoff: float = 5.0,
    cell: jax.Array | None = None,
    pbc: jax.Array | None = None,
    cells_per_dimension: jax.Array | None = None,
    atom_periodic_shifts: jax.Array | None = None,
    atom_to_cell_mapping: jax.Array | None = None,
    cell_atom_start_indices: jax.Array | None = None,
    cell_atom_list: jax.Array | None = None,
    atoms_per_cell_count: jax.Array | None = None,
    neighbor_search_radius: jax.Array | None = None,
    max_neighbors: int | None = None,
    neighbor_matrix: jax.Array | None = None,
    num_neighbors: jax.Array | None = None,
    neighbor_matrix_shifts: jax.Array | None = None,
    rebuild_flags: jax.Array | None = None,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """Query batch cell lists to find neighbors.

    Parameters
    ----------
    positions : jax.Array, shape (total_atoms, 3), dtype=float32 or float64
        Atomic coordinates. Anything that is not float64 is cast to float32.
    batch_idx : jax.Array, shape (total_atoms,), dtype=int32, optional
        Batch indices.
    batch_ptr : jax.Array, shape (num_systems + 1,), dtype=int32, optional
        Cumulative atom counts.
    cutoff : float, optional
        Cutoff distance.
    cell : jax.Array, shape (num_systems, 3, 3), dtype=float32 or float64, optional
        Cell matrices. Cast to the dtype of ``positions`` when they differ.
    pbc : jax.Array, shape (num_systems, 3), dtype=bool, optional
        PBC flags. If None, all-True flags of shape (num_systems, 3) are used.
    cells_per_dimension : jax.Array, shape (num_systems, 3), dtype=int32, optional
        Cells per dimension.
    atom_periodic_shifts : jax.Array, shape (total_atoms, 3), dtype=int32, optional
        Periodic shifts for each atom (output from ``batch_build_cell_list``).
    atom_to_cell_mapping : jax.Array, shape (total_atoms, 3), dtype=int32, optional
        Cell mappings.
    cell_atom_start_indices : jax.Array, shape (max_total_cells,), dtype=int32, optional
        Start indices.
    cell_atom_list : jax.Array, shape (total_atoms,), dtype=int32, optional
        Cell atom list.
    atoms_per_cell_count : jax.Array, shape (max_total_cells,), dtype=int32, optional
        Number of atoms assigned to each cell. Output from
        ``batch_build_cell_list``. If None, a zero array with the length of
        ``cell_atom_start_indices`` is used.
    neighbor_search_radius : jax.Array, shape (num_systems, 3), dtype=int32, optional
        Search radius.
    max_neighbors : int, optional
        Maximum neighbors per atom. If None, it is taken from the second
        dimension of ``neighbor_matrix`` (or, failing that, of
        ``neighbor_matrix_shifts``) when one of them is supplied, so that
        internally allocated buffers match the supplied one; otherwise it is
        estimated from ``cutoff``.
    neighbor_matrix : jax.Array, shape (total_atoms, max_neighbors), dtype=int32, optional
        Pre-allocated neighbor matrix.
    num_neighbors : jax.Array, shape (total_atoms,), dtype=int32, optional
        Pre-allocated neighbors count array.
    neighbor_matrix_shifts : jax.Array, shape (total_atoms, max_neighbors, 3), dtype=int32, optional
        Pre-allocated shift vectors array. Pass in a pre-shaped array to hint
        buffer reuse to XLA; note that JAX returns a new array rather than
        mutating the input.
    rebuild_flags : jax.Array, dtype=bool, optional
        Per-system flags, indexed by the entries of ``batch_idx``. When given,
        the selective kernel is launched, supplied output buffers are not
        reset wholesale, and ``num_neighbors`` is zeroed only for atoms whose
        system is flagged. When None, supplied output buffers are fully reset
        before the query.

    Returns
    -------
    neighbor_matrix : jax.Array, shape (total_atoms, max_neighbors), dtype=int32
        Neighbor matrix. Unused slots hold ``total_atoms`` as fill value.
    num_neighbors : jax.Array, shape (total_atoms,), dtype=int32
        Neighbors count.
    neighbor_matrix_shifts : jax.Array, shape (total_atoms, max_neighbors, 3), dtype=int32
        Periodic shifts for each neighbor relationship.
    """
    if max_neighbors is None:
        # Prefer the width of a caller-supplied buffer so that any buffer
        # allocated below has a matching second dimension.
        if neighbor_matrix is not None:
            max_neighbors = neighbor_matrix.shape[1]
        elif neighbor_matrix_shifts is not None:
            max_neighbors = neighbor_matrix_shifts.shape[1]
        else:
            max_neighbors = estimate_max_neighbors(cutoff)

    # Prepare batch info
    batch_idx, batch_ptr = prepare_batch_idx_ptr(
        batch_idx, batch_ptr, positions.shape[0]
    )
    num_systems = batch_ptr.shape[0] - 1
    total_atoms = positions.shape[0]

    # Output buffers: allocate when missing; reset supplied ones only for a
    # full (non-selective) rebuild. ``total_atoms`` marks an empty slot.
    if neighbor_matrix is None:
        neighbor_matrix = jnp.full(
            (total_atoms, max_neighbors),
            total_atoms,
            dtype=jnp.int32,
        )
    elif rebuild_flags is None:
        neighbor_matrix = neighbor_matrix.at[:].set(jnp.int32(total_atoms))

    if num_neighbors is None:
        num_neighbors = jnp.zeros(total_atoms, dtype=jnp.int32)
    elif rebuild_flags is None:
        num_neighbors = num_neighbors.at[:].set(jnp.int32(0))

    if neighbor_matrix_shifts is None:
        neighbor_matrix_shifts = jnp.zeros(
            (total_atoms, max_neighbors, 3),
            dtype=jnp.int32,
        )
    elif rebuild_flags is None:
        neighbor_matrix_shifts = neighbor_matrix_shifts.at[:].set(jnp.int32(0))

    # Select kernel based on dtype
    if positions.dtype == jnp.float64:
        _query_kernel = _jax_batch_build_neighbor_matrix_f64
        _query_kernel_selective = _jax_batch_build_neighbor_matrix_selective_f64
    else:
        _query_kernel = _jax_batch_build_neighbor_matrix_f32
        _query_kernel_selective = _jax_batch_build_neighbor_matrix_selective_f32
        positions = positions.astype(jnp.float32)

    # Ensure cell dtype matches positions
    if cell is not None and cell.dtype != positions.dtype:
        cell = cell.astype(positions.dtype)

    # Ensure pbc is bool; default to fully periodic with shape (num_systems, 3)
    if pbc is not None:
        pbc_bool = pbc.astype(jnp.bool_)
    else:
        pbc_bool = jnp.ones((num_systems, 3), dtype=jnp.bool_)

    if atoms_per_cell_count is None:
        max_total_cells = cell_atom_start_indices.shape[0]
        atoms_per_cell_count = jnp.zeros(max_total_cells, dtype=jnp.int32)

    # Compute cell_offsets from cells_per_dimension (exclusive prefix sum)
    cells_per_system = jnp.prod(cells_per_dimension, axis=1)
    cell_offsets = jnp.concatenate(
        [
            jnp.array([0], dtype=jnp.int32),
            jnp.cumsum(cells_per_system[:-1], dtype=jnp.int32),
        ]
    )

    batch_idx_i32 = batch_idx.astype(jnp.int32)

    if rebuild_flags is not None:
        rf = rebuild_flags.astype(jnp.bool_)
        # Expand the per-system flag to one flag per atom and clear the
        # neighbor counts of atoms whose system is being rebuilt.
        atom_rebuild = rf[batch_idx_i32]
        num_neighbors = jnp.where(
            atom_rebuild, jnp.zeros_like(num_neighbors), num_neighbors
        )
        neighbor_matrix, neighbor_matrix_shifts, num_neighbors = (
            _query_kernel_selective(
                positions,
                cell,
                pbc_bool,
                batch_idx_i32,
                float(cutoff),
                cells_per_dimension,
                neighbor_search_radius,
                atom_periodic_shifts,
                atom_to_cell_mapping,
                atoms_per_cell_count,
                cell_atom_start_indices,
                cell_atom_list,
                cell_offsets,
                neighbor_matrix,
                neighbor_matrix_shifts,
                num_neighbors,
                False,  # half_fill
                rf,
                launch_dims=(total_atoms,),
            )
        )
    else:
        neighbor_matrix, neighbor_matrix_shifts, num_neighbors = _query_kernel(
            positions,
            cell,
            pbc_bool,
            batch_idx_i32,
            float(cutoff),
            cells_per_dimension,
            neighbor_search_radius,
            atom_periodic_shifts,
            atom_to_cell_mapping,
            atoms_per_cell_count,
            cell_atom_start_indices,
            cell_atom_list,
            cell_offsets,
            neighbor_matrix,
            neighbor_matrix_shifts,
            num_neighbors,
            False,  # half_fill
            launch_dims=(total_atoms,),
        )

    return neighbor_matrix, num_neighbors, neighbor_matrix_shifts


def batch_cell_list(
    positions: jax.Array,
    cutoff: float,
    cell: jax.Array | None = None,
    pbc: jax.Array | None = None,
    batch_idx: jax.Array | None = None,
    batch_ptr: jax.Array | None = None,
    max_neighbors: int | None = None,
    max_total_cells: int | None = None,
    neighbor_matrix_shifts: jax.Array | None = None,
    return_neighbor_list: bool = False,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """Build and query spatial cell lists for batch of systems.

    Convenience wrapper that runs ``batch_build_cell_list`` followed by
    ``batch_query_cell_list`` and optionally converts the neighbor matrix to
    a COO neighbor list. Three arrays are returned in both output formats.

    Parameters
    ----------
    positions : jax.Array, shape (total_atoms, 3), dtype=float32 or float64
        Atomic coordinates.
    cutoff : float
        Cutoff distance for neighbor detection.
    cell : jax.Array, shape (num_systems, 3, 3), dtype=float32 or float64, optional
        Cell matrices defining lattice vectors. Forwarded unchanged to the
        build and query steps. NOTE(review): no default cell is substituted
        in this function when ``cell`` is None -- confirm how ``None`` is
        handled downstream.
    pbc : jax.Array, shape (num_systems, 3), dtype=bool, optional
        Periodic boundary condition flags. Default is all True.
    batch_idx : jax.Array, shape (total_atoms,), dtype=int32, optional
        Batch indices for each atom.
    batch_ptr : jax.Array, shape (num_systems + 1,), dtype=int32, optional
        Cumulative atom counts defining system boundaries.
    max_neighbors : int, optional
        Maximum number of neighbors per atom. If None, will be estimated.
    max_total_cells : int, optional
        Maximum number of cells to allocate. If None, will be estimated
        (the estimate is not compatible with ``jax.jit``).
    neighbor_matrix_shifts : jax.Array, shape (total_atoms, max_neighbors, 3), dtype=int32, optional
        Pre-allocated shift vectors array. If None, will be allocated
        internally. Pass in a pre-shaped array to hint buffer reuse to XLA;
        note that JAX returns a new array rather than mutating the input.
    return_neighbor_list : bool, optional
        If True, convert result to COO neighbor list format. Default is False.

    Returns
    -------
    neighbor_data : jax.Array
        If ``return_neighbor_list=False`` (default): ``neighbor_matrix`` with
        shape (total_atoms, max_neighbors), dtype int32.
        If ``return_neighbor_list=True``: ``neighbor_list`` with shape
        (2, num_pairs), dtype int32, in COO format.
    neighbor_count : jax.Array
        If ``return_neighbor_list=False``: ``num_neighbors`` with shape
        (total_atoms,), dtype int32.
        If ``return_neighbor_list=True``: ``neighbor_ptr`` with shape
        (total_atoms + 1,), dtype int32.
    shift_data : jax.Array
        If ``return_neighbor_list=False`` (default): ``neighbor_matrix_shifts``
        with shape (total_atoms, max_neighbors, 3), dtype int32.
        If ``return_neighbor_list=True``: ``neighbor_list_shifts`` with shape
        (num_pairs, 3), dtype int32.
        Periodic shift vectors for each neighbor relationship.

    See Also
    --------
    batch_build_cell_list : Build cell list separately
    batch_query_cell_list : Query cell list separately
    batch_naive_neighbor_list : Naive O(N^2) method
    """
    # Prepare batch info
    batch_idx, batch_ptr = prepare_batch_idx_ptr(
        batch_idx, batch_ptr, positions.shape[0]
    )

    # Build cell list (the cell origin is not needed by the query step)
    (
        cells_per_dimension,
        atom_periodic_shifts,
        atom_to_cell_mapping,
        atoms_per_cell_count,
        cell_atom_start_indices,
        cell_atom_list,
        neighbor_search_radius,
        _cell_origin,
    ) = batch_build_cell_list(
        positions,
        batch_idx=batch_idx,
        batch_ptr=batch_ptr,
        cell=cell,
        pbc=pbc,
        cutoff=cutoff,
        max_total_cells=max_total_cells,
    )

    # Query cell list
    neighbor_matrix, num_neighbors, neighbor_matrix_shifts = batch_query_cell_list(
        positions=positions,
        batch_idx=batch_idx,
        batch_ptr=batch_ptr,
        cutoff=cutoff,
        cell=cell,
        pbc=pbc,
        cells_per_dimension=cells_per_dimension,
        atom_periodic_shifts=atom_periodic_shifts,
        atom_to_cell_mapping=atom_to_cell_mapping,
        atoms_per_cell_count=atoms_per_cell_count,
        cell_atom_start_indices=cell_atom_start_indices,
        cell_atom_list=cell_atom_list,
        neighbor_search_radius=neighbor_search_radius,
        max_neighbors=max_neighbors,
        neighbor_matrix_shifts=neighbor_matrix_shifts,
    )

    if not return_neighbor_list:
        return neighbor_matrix, num_neighbors, neighbor_matrix_shifts

    # Convert the padded matrix to COO; ``total_atoms`` is the padding value
    # used for unused slots in the neighbor matrix.
    neighbor_list, neighbor_ptr, neighbor_list_shifts = (
        get_neighbor_list_from_neighbor_matrix(
            neighbor_matrix,
            num_neighbors=num_neighbors,
            neighbor_shift_matrix=neighbor_matrix_shifts,
            fill_value=positions.shape[0],
        )
    )
    return neighbor_list, neighbor_ptr, neighbor_list_shifts