# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX bindings for batched naive O(N^2) neighbor list construction."""
from __future__ import annotations
import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental import jax_kernel
from nvalchemiops.jax.neighbors.neighbor_utils import (
compute_naive_num_shifts,
get_neighbor_list_from_neighbor_matrix,
prepare_batch_idx_ptr,
)
from nvalchemiops.neighbors.batch_naive import (
_fill_batch_naive_neighbor_matrix_overload,
_fill_batch_naive_neighbor_matrix_pbc_overload,
_fill_batch_naive_neighbor_matrix_pbc_prewrapped_overload,
_fill_batch_naive_neighbor_matrix_pbc_prewrapped_selective_overload,
_fill_batch_naive_neighbor_matrix_pbc_selective_overload,
_fill_batch_naive_neighbor_matrix_selective_overload,
)
from nvalchemiops.neighbors.neighbor_utils import (
_wrap_positions_batch_overload,
estimate_max_neighbors,
)
__all__ = ["batch_naive_neighbor_list"]
# ==============================================================================
# JAX Kernel Wrappers
# ==============================================================================
# No-PBC batch naive neighbor matrix kernel wrappers
_jax_fill_batch_naive_f32 = jax_kernel(
_fill_batch_naive_neighbor_matrix_overload[wp.float32],
num_outputs=2,
in_out_argnames=["neighbor_matrix", "num_neighbors"],
enable_backward=False,
)
_jax_fill_batch_naive_f64 = jax_kernel(
_fill_batch_naive_neighbor_matrix_overload[wp.float64],
num_outputs=2,
in_out_argnames=["neighbor_matrix", "num_neighbors"],
enable_backward=False,
)
# PBC batch naive neighbor matrix kernel wrappers
_jax_fill_batch_naive_pbc_f32 = jax_kernel(
_fill_batch_naive_neighbor_matrix_pbc_overload[wp.float32],
num_outputs=3,
in_out_argnames=["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
enable_backward=False,
)
_jax_fill_batch_naive_pbc_f64 = jax_kernel(
_fill_batch_naive_neighbor_matrix_pbc_overload[wp.float64],
num_outputs=3,
in_out_argnames=["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
enable_backward=False,
)
# Selective no-PBC batch naive neighbor matrix kernel wrappers
_jax_fill_batch_naive_selective_f32 = jax_kernel(
_fill_batch_naive_neighbor_matrix_selective_overload[wp.float32],
num_outputs=2,
in_out_argnames=["neighbor_matrix", "num_neighbors"],
enable_backward=False,
)
_jax_fill_batch_naive_selective_f64 = jax_kernel(
_fill_batch_naive_neighbor_matrix_selective_overload[wp.float64],
num_outputs=2,
in_out_argnames=["neighbor_matrix", "num_neighbors"],
enable_backward=False,
)
# Selective PBC batch naive neighbor matrix kernel wrappers
_jax_fill_batch_naive_pbc_selective_f32 = jax_kernel(
_fill_batch_naive_neighbor_matrix_pbc_selective_overload[wp.float32],
num_outputs=3,
in_out_argnames=["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
enable_backward=False,
)
_jax_fill_batch_naive_pbc_selective_f64 = jax_kernel(
_fill_batch_naive_neighbor_matrix_pbc_selective_overload[wp.float64],
num_outputs=3,
in_out_argnames=["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
enable_backward=False,
)
# Prewrapped PBC batch naive neighbor matrix kernel wrappers
_jax_fill_batch_naive_pbc_prewrapped_f32 = jax_kernel(
_fill_batch_naive_neighbor_matrix_pbc_prewrapped_overload[wp.float32],
num_outputs=3,
in_out_argnames=["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
enable_backward=False,
)
_jax_fill_batch_naive_pbc_prewrapped_f64 = jax_kernel(
_fill_batch_naive_neighbor_matrix_pbc_prewrapped_overload[wp.float64],
num_outputs=3,
in_out_argnames=["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
enable_backward=False,
)
_jax_fill_batch_naive_pbc_prewrapped_selective_f32 = jax_kernel(
_fill_batch_naive_neighbor_matrix_pbc_prewrapped_selective_overload[wp.float32],
num_outputs=3,
in_out_argnames=["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
enable_backward=False,
)
_jax_fill_batch_naive_pbc_prewrapped_selective_f64 = jax_kernel(
_fill_batch_naive_neighbor_matrix_pbc_prewrapped_selective_overload[wp.float64],
num_outputs=3,
in_out_argnames=["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
enable_backward=False,
)
# Wrap positions batch kernel wrappers
_jax_wrap_positions_batch_f32 = jax_kernel(
_wrap_positions_batch_overload[wp.float32],
num_outputs=2,
in_out_argnames=["positions_wrapped", "per_atom_cell_offsets"],
enable_backward=False,
)
_jax_wrap_positions_batch_f64 = jax_kernel(
_wrap_positions_batch_overload[wp.float64],
num_outputs=2,
in_out_argnames=["positions_wrapped", "per_atom_cell_offsets"],
enable_backward=False,
)
[docs]
def batch_naive_neighbor_list(
positions: jax.Array,
cutoff: float,
batch_idx: jax.Array | None = None,
batch_ptr: jax.Array | None = None,
pbc: jax.Array | None = None,
cell: jax.Array | None = None,
max_neighbors: int | None = None,
half_fill: bool = False,
fill_value: int | None = None,
return_neighbor_list: bool = False,
neighbor_matrix: jax.Array | None = None,
neighbor_matrix_shifts: jax.Array | None = None,
num_neighbors: jax.Array | None = None,
shift_range_per_dimension: jax.Array | None = None,
num_shifts_per_system: jax.Array | None = None,
max_shifts_per_system: int | None = None,
max_atoms_per_system: int | None = None,
rebuild_flags: jax.Array | None = None,
wrap_positions: bool = True,
) -> (
tuple[jax.Array, jax.Array, jax.Array, jax.Array]
| tuple[jax.Array, jax.Array, jax.Array]
| tuple[jax.Array, jax.Array]
):
"""Compute neighbor list for batch of systems using naive O(N^2) algorithm.
Identifies all atom pairs within a specified cutoff distance for each system
independently using a brute-force pairwise distance calculation. Supports both
non-periodic and periodic boundary conditions.
Parameters
----------
positions : jax.Array, shape (total_atoms, 3), dtype=float32 or float64
Concatenated Cartesian coordinates for all systems.
cutoff : float
Cutoff distance for neighbor detection in Cartesian units.
Must be positive. Atoms within this distance are considered neighbors.
batch_idx : jax.Array, shape (total_atoms,), dtype=int32, optional
System index for each atom. If None, batch_ptr must be provided.
batch_ptr : jax.Array, shape (num_systems + 1,), dtype=int32, optional
Cumulative atom counts defining system boundaries. If None, batch_idx must be provided.
pbc : jax.Array, shape (num_systems, 3), dtype=bool, optional
Periodic boundary condition flags for each system and dimension.
True enables periodicity in that direction. Default is None (no PBC).
cell : jax.Array, shape (num_systems, 3, 3), dtype=float32 or float64, optional
Cell matrices defining lattice vectors. Required if pbc is provided.
max_neighbors : int, optional
Maximum number of neighbors per atom.
half_fill : bool, optional
If True, only store relationships where i < j. Default is False.
fill_value : int, optional
Value to fill the neighbor matrix with. Default is total_atoms.
neighbor_matrix : jax.Array, optional
Pre-allocated neighbor matrix.
neighbor_matrix_shifts : jax.Array, optional
Pre-allocated shift matrix for PBC.
num_neighbors : jax.Array, optional
Pre-allocated neighbors count array.
shift_range_per_dimension : jax.Array, optional
Pre-computed shift range for PBC systems.
num_shifts_per_system : jax.Array, optional
Number of periodic shifts per system.
max_shifts_per_system : int, optional
Maximum per-system shift count (launch dimension).
max_atoms_per_system : int, optional
Maximum atoms in any system.
wrap_positions : bool, default=True
If True, wrap input positions into the primary cell before
neighbor search. Set to False when positions are already
wrapped (e.g. by a preceding integration step) to save two
GPU kernel launches per call.
Returns
-------
results : tuple of jax.Array
Variable-length tuple depending on input parameters.
Examples
--------
Basic usage with batch_ptr:
>>> import jax.numpy as jnp
>>> from nvalchemiops.jax.neighbors import batch_naive_neighbor_list
>>> positions = jnp.zeros((200, 3), dtype=jnp.float32)
>>> batch_ptr = jnp.array([0, 100, 200], dtype=jnp.int32) # 2 systems
>>> cutoff = 2.5
>>> max_neighbors = 50
>>> neighbor_matrix, num_neighbors = batch_naive_neighbor_list(
... positions, cutoff, batch_ptr=batch_ptr, max_neighbors=max_neighbors
... )
With PBC:
>>> cell = jnp.eye(3, dtype=jnp.float32)[jnp.newaxis, :, :] * 10.0
>>> cell = jnp.repeat(cell, 2, axis=0)
>>> pbc = jnp.ones((2, 3), dtype=jnp.bool_)
>>> neighbor_matrix, num_neighbors, shifts = batch_naive_neighbor_list(
... positions, cutoff, batch_ptr=batch_ptr, max_neighbors=max_neighbors,
... pbc=pbc, cell=cell
... )
See Also
--------
nvalchemiops.neighbors.batch_naive.batch_naive_neighbor_matrix : Core warp launcher
nvalchemiops.jax.neighbors.naive.naive_neighbor_list : Non-batched version
batch_cell_list : Cell list method for large systems
"""
if pbc is None and cell is not None:
raise ValueError("If cell is provided, pbc must also be provided")
if pbc is not None and cell is None:
raise ValueError("If pbc is provided, cell must also be provided")
# Prepare batch indices and pointers
batch_idx, batch_ptr = prepare_batch_idx_ptr(
batch_idx, batch_ptr, positions.shape[0]
)
num_systems = batch_ptr.shape[0] - 1
if cell is not None:
cell = cell if cell.ndim == 3 else cell[jnp.newaxis, :, :]
# Ensure cell dtype matches positions dtype so warp overload dispatch is consistent
if cell.dtype != positions.dtype:
cell = cell.astype(positions.dtype)
if pbc is not None:
pbc = pbc if pbc.ndim == 2 else pbc[jnp.newaxis, :]
if max_neighbors is None:
max_neighbors = estimate_max_neighbors(cutoff)
if fill_value is None:
fill_value = jnp.int32(positions.shape[0])
if neighbor_matrix is None:
neighbor_matrix = jnp.full(
(positions.shape[0], max_neighbors),
fill_value,
dtype=jnp.int32,
)
elif rebuild_flags is None:
neighbor_matrix = neighbor_matrix.at[:].set(fill_value)
if num_neighbors is None:
num_neighbors = jnp.zeros(positions.shape[0], dtype=jnp.int32)
elif rebuild_flags is None:
num_neighbors = num_neighbors.at[:].set(jnp.int32(0))
if pbc is not None:
if neighbor_matrix_shifts is None:
neighbor_matrix_shifts = jnp.zeros(
(positions.shape[0], max_neighbors, 3),
dtype=jnp.int32,
)
elif rebuild_flags is None:
neighbor_matrix_shifts = neighbor_matrix_shifts.at[:].set(jnp.int32(0))
if (
max_shifts_per_system is None
or num_shifts_per_system is None
or shift_range_per_dimension is None
):
shift_range_per_dimension, num_shifts_per_system, max_shifts_per_system = (
compute_naive_num_shifts(cell, cutoff, pbc)
)
if cutoff <= 0:
if return_neighbor_list:
if pbc is not None:
return (
jnp.zeros((2, 0), dtype=jnp.int32),
jnp.zeros((positions.shape[0] + 1,), dtype=jnp.int32),
jnp.zeros((0, 3), dtype=jnp.int32),
)
else:
return (
jnp.zeros((2, 0), dtype=jnp.int32),
jnp.zeros((positions.shape[0] + 1,), dtype=jnp.int32),
)
else:
if pbc is not None:
return neighbor_matrix, num_neighbors, neighbor_matrix_shifts
else:
return neighbor_matrix, num_neighbors
# Select kernel based on dtype
if positions.dtype == jnp.float64:
_jax_fill = _jax_fill_batch_naive_f64
_jax_fill_pbc = _jax_fill_batch_naive_pbc_f64
_jax_fill_selective = _jax_fill_batch_naive_selective_f64
_jax_fill_pbc_selective = _jax_fill_batch_naive_pbc_selective_f64
_jax_fill_pbc_prewrapped = _jax_fill_batch_naive_pbc_prewrapped_f64
_jax_fill_pbc_prewrapped_selective = (
_jax_fill_batch_naive_pbc_prewrapped_selective_f64
)
_jax_wrap_batch = _jax_wrap_positions_batch_f64
else:
_jax_fill = _jax_fill_batch_naive_f32
_jax_fill_pbc = _jax_fill_batch_naive_pbc_f32
_jax_fill_selective = _jax_fill_batch_naive_selective_f32
_jax_fill_pbc_selective = _jax_fill_batch_naive_pbc_selective_f32
_jax_fill_pbc_prewrapped = _jax_fill_batch_naive_pbc_prewrapped_f32
_jax_fill_pbc_prewrapped_selective = (
_jax_fill_batch_naive_pbc_prewrapped_selective_f32
)
_jax_wrap_batch = _jax_wrap_positions_batch_f32
positions = positions.astype(jnp.float32)
total_atoms = positions.shape[0]
batch_idx_i32 = batch_idx.astype(jnp.int32)
batch_ptr_i32 = batch_ptr.astype(jnp.int32)
if pbc is None:
# No PBC case
if rebuild_flags is not None:
rf = rebuild_flags.astype(jnp.bool_)
atom_rebuild = rf[batch_idx_i32]
num_neighbors = jnp.where(
atom_rebuild, jnp.zeros_like(num_neighbors), num_neighbors
)
neighbor_matrix, num_neighbors = _jax_fill_selective(
positions,
float(cutoff * cutoff),
batch_idx_i32,
batch_ptr_i32,
neighbor_matrix,
num_neighbors,
half_fill,
rf,
launch_dims=(total_atoms,),
)
else:
neighbor_matrix, num_neighbors = _jax_fill(
positions,
float(cutoff * cutoff),
batch_idx_i32,
batch_ptr_i32,
neighbor_matrix,
num_neighbors,
half_fill,
launch_dims=(total_atoms,),
)
else:
if cell.dtype != positions.dtype:
cell = cell.astype(positions.dtype)
if max_atoms_per_system is None:
try:
max_atoms_per_system = int(jnp.max(batch_ptr[1:] - batch_ptr[:-1]))
except (
jax.errors.ConcretizationTypeError,
jax.errors.TracerIntegerConversionError,
):
raise ValueError(
"Cannot infer max_atoms_per_system inside jax.jit. "
"Please provide max_atoms_per_system explicitly when using jax.jit."
) from None
if wrap_positions:
inv_cell = jnp.linalg.inv(cell)
positions_wrapped = jnp.zeros_like(positions)
per_atom_cell_offsets = jnp.zeros((total_atoms, 3), dtype=jnp.int32)
positions_wrapped, per_atom_cell_offsets = _jax_wrap_batch(
positions,
cell,
inv_cell,
batch_idx_i32,
positions_wrapped,
per_atom_cell_offsets,
launch_dims=(total_atoms,),
)
if rebuild_flags is not None:
rf = rebuild_flags.astype(jnp.bool_)
atom_rebuild = rf[batch_idx_i32]
num_neighbors = jnp.where(
atom_rebuild, jnp.zeros_like(num_neighbors), num_neighbors
)
neighbor_matrix, neighbor_matrix_shifts, num_neighbors = (
_jax_fill_pbc_selective(
positions_wrapped,
per_atom_cell_offsets,
cell,
float(cutoff * cutoff),
batch_ptr_i32,
shift_range_per_dimension,
num_shifts_per_system,
neighbor_matrix,
neighbor_matrix_shifts,
num_neighbors,
half_fill,
rf,
launch_dims=(
num_systems,
max_shifts_per_system,
max_atoms_per_system,
),
)
)
else:
neighbor_matrix, neighbor_matrix_shifts, num_neighbors = _jax_fill_pbc(
positions_wrapped,
per_atom_cell_offsets,
cell,
float(cutoff * cutoff),
batch_ptr_i32,
shift_range_per_dimension,
num_shifts_per_system,
neighbor_matrix,
neighbor_matrix_shifts,
num_neighbors,
half_fill,
launch_dims=(
num_systems,
max_shifts_per_system,
max_atoms_per_system,
),
)
else:
if rebuild_flags is not None:
rf = rebuild_flags.astype(jnp.bool_)
atom_rebuild = rf[batch_idx_i32]
num_neighbors = jnp.where(
atom_rebuild, jnp.zeros_like(num_neighbors), num_neighbors
)
neighbor_matrix, neighbor_matrix_shifts, num_neighbors = (
_jax_fill_pbc_prewrapped_selective(
positions,
cell,
float(cutoff * cutoff),
batch_ptr_i32,
shift_range_per_dimension,
num_shifts_per_system,
neighbor_matrix,
neighbor_matrix_shifts,
num_neighbors,
half_fill,
rf,
launch_dims=(
num_systems,
max_shifts_per_system,
max_atoms_per_system,
),
)
)
else:
neighbor_matrix, neighbor_matrix_shifts, num_neighbors = (
_jax_fill_pbc_prewrapped(
positions,
cell,
float(cutoff * cutoff),
batch_ptr_i32,
shift_range_per_dimension,
num_shifts_per_system,
neighbor_matrix,
neighbor_matrix_shifts,
num_neighbors,
half_fill,
launch_dims=(
num_systems,
max_shifts_per_system,
max_atoms_per_system,
),
)
)
if return_neighbor_list:
if pbc is not None:
neighbor_list, neighbor_ptr, neighbor_list_shifts = (
get_neighbor_list_from_neighbor_matrix(
neighbor_matrix,
num_neighbors=num_neighbors,
neighbor_shift_matrix=neighbor_matrix_shifts,
fill_value=fill_value,
)
)
return neighbor_list, neighbor_ptr, neighbor_list_shifts
else:
neighbor_list, neighbor_ptr = get_neighbor_list_from_neighbor_matrix(
neighbor_matrix,
num_neighbors=num_neighbors,
fill_value=fill_value,
)
return neighbor_list, neighbor_ptr
else:
if pbc is not None:
return neighbor_matrix, num_neighbors, neighbor_matrix_shifts
else:
return neighbor_matrix, num_neighbors