# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX utilities for neighbor list construction.
This module contains JAX-specific helper functions for neighbor list operations.
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental import jax_kernel
from nvalchemiops.neighbors.neighbor_utils import (
NeighborOverflowError,
_compute_naive_num_shifts_overload,
estimate_max_neighbors,
)
# Public names of this module. ``estimate_max_neighbors`` and
# ``NeighborOverflowError`` are not defined here: they are imported from
# ``nvalchemiops.neighbors.neighbor_utils`` and re-exported for convenience.
__all__ = [
    "compute_naive_num_shifts",
    "get_neighbor_list_from_neighbor_matrix",
    "prepare_batch_idx_ptr",
    "allocate_cell_list",
    "estimate_max_neighbors",
    "NeighborOverflowError",
]
# ==============================================================================
# JAX Kernel Wrappers
# ==============================================================================
# The float32 and float64 kernel overloads are wrapped with identical
# ``jax_kernel`` settings, so a single factory builds both callables.
# jax_kernel handles the bool-to-int conversion internally.
def _wrap_naive_num_shifts_kernel(scalar_type):
    """Return the ``jax_kernel`` wrapper for one scalar-type overload.

    Both ``num_shifts`` and ``shift_range`` are declared as in/out arguments,
    which is why callers pass pre-allocated arrays and receive two outputs.
    No backward pass is generated.
    """
    return jax_kernel(
        _compute_naive_num_shifts_overload[scalar_type],
        num_outputs=2,
        in_out_argnames=["num_shifts", "shift_range"],
        enable_backward=False,
    )


_jax_compute_naive_num_shifts_f32 = _wrap_naive_num_shifts_kernel(wp.float32)
_jax_compute_naive_num_shifts_f64 = _wrap_naive_num_shifts_kernel(wp.float64)
# ==============================================================================
# Public API
# ==============================================================================
[docs]
def compute_naive_num_shifts(
    cell: jax.Array,
    cutoff: float,
    pbc: jax.Array,
) -> tuple[jax.Array, jax.Array, int]:
    """Compute periodic image shifts needed for neighbor searching.

    Parameters
    ----------
    cell : jax.Array, shape (num_systems, 3, 3)
        Cell matrices defining lattice vectors in Cartesian coordinates.
        Each 3x3 matrix represents one system's periodic cell. A float64
        input is dispatched to the float64 kernel; any other dtype is cast
        to float32.
    cutoff : float
        Cutoff distance for neighbor searching in Cartesian units.
        Must be positive and typically less than half the minimum cell dimension.
    pbc : jax.Array, shape (num_systems, 3), dtype=bool
        Periodic boundary condition flags for each dimension.
        True enables periodicity in that direction.

    Returns
    -------
    shift_range : jax.Array, shape (num_systems, 3), dtype=int32
        Maximum shift indices in each dimension for each system.
    num_shifts : jax.Array, shape (num_systems,), dtype=int32
        Number of periodic shifts for each system.
    max_shifts : int
        Maximum per-system shift count across all systems (0 when there are
        no systems).

    Raises
    ------
    ValueError
        If any per-system shift count exceeds int32 range.

    See Also
    --------
    nvalchemiops.neighbors.neighbor_utils._compute_naive_num_shifts : Warp kernel

    Notes
    -----
    This function must be called outside ``jax.jit`` scope. The returned
    ``max_shifts`` is a Python int needed for determining launch dimensions,
    which cannot be traced. This is an inherent limitation: array shapes must
    be known at trace time in JAX.

    The per-system shift counts are evaluated on the host with Python
    integers. JAX only provides 64-bit integers when ``jax_enable_x64`` is
    set; otherwise a requested ``int64`` is truncated to ``int32`` and the
    product below could wrap around before the range check ever sees it.
    """
    num_systems = cell.shape[0]

    # Pre-allocated in/out buffers for the kernel.
    num_shifts_buffer = jnp.zeros(num_systems, dtype=jnp.int32)
    shift_range = jnp.zeros((num_systems, 3), dtype=jnp.int32)

    # Ensure pbc is bool dtype (jax_kernel handles bool arrays directly)
    pbc_bool = pbc.astype(jnp.bool_)

    # Select the kernel overload matching the input precision.
    if cell.dtype == jnp.float64:
        kernel = _jax_compute_naive_num_shifts_f64
        cell_dtype = jnp.float64
    else:
        kernel = _jax_compute_naive_num_shifts_f32
        cell_dtype = jnp.float32

    # The kernel's own int32 count is discarded: the count is recomputed below
    # from ``shift_range`` so that the int32 range check is exact.
    _, shift_range = kernel(
        cell.astype(cell_dtype),
        float(cutoff),
        pbc_bool,
        num_shifts_buffer,
        shift_range,
        launch_dims=(num_systems,),
    )

    # Exact arithmetic with arbitrary-precision Python ints (forces a host
    # transfer, which is fine because this function cannot run under jit).
    shift_counts = []
    for s0, s1, s2 in shift_range.tolist():
        k1 = 2 * s1 + 1
        k2 = 2 * s2 + 1
        shift_counts.append(s0 * k1 * k2 + s1 * k2 + s2 + 1)

    max_shifts = max(shift_counts, default=0)
    if max_shifts > 2**31 - 1:
        raise ValueError(
            f"Per-system shift count ({max_shifts}) exceeds int32 max "
            f"(2^31 - 1). Reduce the cutoff, increase cell size, or use a "
            f"cell-list method for very small cells."
        )

    # Every count is now known to fit in int32.
    num_shifts = jnp.asarray(shift_counts, dtype=jnp.int32)
    return shift_range, num_shifts, max_shifts
[docs]
def get_neighbor_list_from_neighbor_matrix(
    neighbor_matrix: jax.Array,
    num_neighbors: jax.Array,
    neighbor_shift_matrix: jax.Array | None = None,
    fill_value: int = -1,
) -> tuple[jax.Array, jax.Array] | tuple[jax.Array, jax.Array, jax.Array]:
    """Flatten a fixed-width neighbor matrix into a COO neighbor list.

    Parameters
    ----------
    neighbor_matrix : jax.Array, shape (total_atoms, max_neighbors), dtype=int32
        Neighbor atom indices, one row per atom, padded with ``fill_value``.
    num_neighbors : jax.Array, shape (total_atoms,), dtype=int32
        Number of neighbors of each atom.
    neighbor_shift_matrix : jax.Array | None, shape (total_atoms, max_neighbors, 3), dtype=int32
        Optional periodic shift vectors aligned with ``neighbor_matrix``.
    fill_value : int, default=-1
        Padding value marking unused slots of ``neighbor_matrix``. Every entry
        different from it is treated as a real neighbor.

    Returns
    -------
    neighbor_list : jax.Array, shape (2, num_pairs)
        COO pairs ``[source_atoms, target_atoms]`` in the dtype of
        ``neighbor_matrix``.
    neighbor_ptr : jax.Array, shape (total_atoms + 1,), dtype=int32
        CSR-style offsets built from the cumulative sum of ``num_neighbors``;
        ``neighbor_ptr[i]:neighbor_ptr[i+1]`` is the slice of atom ``i``.
    neighbor_list_shifts : jax.Array, shape (num_pairs, 3)
        Shift vectors of the selected pairs (only returned if
        ``neighbor_shift_matrix`` is not None).

    Raises
    ------
    NeighborOverflowError
        If the largest value in ``num_neighbors`` exceeds the width of
        ``neighbor_matrix``. The check is skipped when that value is a tracer
        and cannot be converted to a Python int.

    Notes
    -----
    This is a pure JAX utility function with no warp dependencies. The pair
    list is derived from the ``fill_value`` mask, while the pointer array is
    derived from ``num_neighbors``.

    See Also
    --------
    nvalchemiops.jax.neighbors.naive.naive_neighbor_list : Uses this for format conversion
    nvalchemiops.jax.neighbors.cell_list.cell_list : Uses this for format conversion
    """
    index_dtype = neighbor_matrix.dtype
    with_shifts = neighbor_shift_matrix is not None

    # No atoms: return correctly shaped empty results right away.
    if neighbor_matrix.shape[0] == 0:
        no_pairs = jnp.zeros((2, 0), dtype=index_dtype)
        zero_ptr = jnp.zeros(1, dtype=jnp.int32)
        if not with_shifts:
            return no_pairs, zero_ptr
        no_shifts = jnp.empty((0, 3), dtype=neighbor_shift_matrix.dtype)
        return no_pairs, zero_ptr, no_shifts

    # Capacity check. Converting a tracer to int fails while tracing, in
    # which case the largest count is unknown and the check is skipped.
    capacity = neighbor_matrix.shape[1]
    try:
        largest_count = int(jnp.max(num_neighbors))
    except (
        jax.errors.ConcretizationTypeError,
        jax.errors.TracerIntegerConversionError,
    ):
        largest_count = None
    if largest_count is not None and largest_count > capacity:
        raise NeighborOverflowError(capacity, largest_count)

    # Select the occupied slots; the row index of a slot is its source atom.
    occupied = neighbor_matrix != fill_value
    source_atoms = jnp.nonzero(occupied)[0].astype(index_dtype)
    target_atoms = neighbor_matrix[occupied].astype(index_dtype)
    pairs = jnp.stack((source_atoms, target_atoms), axis=0)

    # Offsets: a leading zero followed by the running total of the counts.
    running_total = jnp.cumsum(num_neighbors, dtype=jnp.int32)
    offsets = jnp.zeros(num_neighbors.shape[0] + 1, dtype=jnp.int32)
    offsets = offsets.at[1:].set(running_total)

    if with_shifts:
        return pairs, offsets, neighbor_shift_matrix[occupied]
    return pairs, offsets
[docs]
def prepare_batch_idx_ptr(
    batch_idx: jax.Array | None,
    batch_ptr: jax.Array | None,
    num_atoms: int,
) -> tuple[jax.Array, jax.Array]:
    """Prepare batch index and pointer tensors from either representation.

    Utility function to ensure both batch_idx and batch_ptr are available,
    computing one from the other if needed. Arguments that are already
    provided are returned unchanged.

    Parameters
    ----------
    batch_idx : jax.Array | None, shape (total_atoms,), dtype=int32
        Array indicating the batch index for each atom.
    batch_ptr : jax.Array | None, shape (num_systems + 1,), dtype=int32
        Array indicating the start index of each batch in the atom list.
    num_atoms : int
        Total number of atoms across all systems. Accepted for interface
        compatibility; this function does not read it.

    Returns
    -------
    batch_idx : jax.Array, shape (total_atoms,), dtype=int32
        Prepared batch index tensor.
    batch_ptr : jax.Array, shape (num_systems + 1,), dtype=int32
        Prepared batch pointer tensor. For an empty ``batch_idx`` with no
        ``batch_ptr`` this is the single-element array ``[0]`` (zero systems).

    Raises
    ------
    ValueError
        If both batch_idx and batch_ptr are None, or if ``batch_ptr`` has to
        be inferred from a traced ``batch_idx`` inside ``jax.jit``.

    Notes
    -----
    This is a pure JAX utility function with no warp dependencies. It provides
    convenience for batch operations by converting between dense (batch_idx) and
    sparse (batch_ptr) batch representations.

    See Also
    --------
    nvalchemiops.jax.neighbors.batch_naive.batch_naive_neighbor_list : Uses this for batch setup
    nvalchemiops.jax.neighbors.batch_cell_list.batch_cell_list : Uses this for batch setup
    """
    if batch_idx is None and batch_ptr is None:
        raise ValueError("Either batch_idx or batch_ptr must be provided.")

    if batch_idx is None:
        # Expand the pointer array: system ``i`` is repeated once per atom.
        num_systems = batch_ptr.shape[0] - 1
        num_atoms_per_system = batch_ptr[1:] - batch_ptr[:-1]
        batch_idx = jnp.repeat(
            jnp.arange(num_systems, dtype=jnp.int32),
            num_atoms_per_system,
        )
    elif batch_ptr is None:
        if batch_idx.size == 0:
            # ``jnp.max`` has no identity for a zero-size array and would
            # raise an unrelated error; no atoms means no systems.
            return batch_idx, jnp.zeros(1, dtype=jnp.int32)

        # The number of systems fixes an output shape, so it must be a
        # concrete Python int; a traced batch_idx cannot provide one.
        try:
            num_systems = int(jnp.max(batch_idx)) + 1
        except (
            jax.errors.ConcretizationTypeError,
            jax.errors.TracerIntegerConversionError,
        ):
            raise ValueError(
                "Cannot infer num_systems from batch_idx inside jax.jit. "
                "Please provide batch_ptr explicitly when using jax.jit."
            ) from None

        # Use bincount to compute atoms per system
        num_atoms_per_system = jnp.bincount(
            batch_idx, minlength=num_systems, length=num_systems
        )
        batch_ptr = jnp.zeros(num_systems + 1, dtype=jnp.int32)
        batch_ptr = batch_ptr.at[1:].set(
            jnp.cumsum(num_atoms_per_system, dtype=jnp.int32)
        )

    return batch_idx, batch_ptr
[docs]
def allocate_cell_list(
    total_atoms: int,
    max_total_cells: int,
    neighbor_search_radius: jax.Array,
) -> tuple[
    jax.Array,
    jax.Array,
    jax.Array,
    jax.Array,
    jax.Array,
    jax.Array,
    jax.Array,
]:
    """Create the zero-initialised int32 buffers used by a cell list.

    Parameters
    ----------
    total_atoms : int
        Total number of atoms across all systems.
    max_total_cells : int
        Maximum number of cells to allocate.
    neighbor_search_radius : jax.Array, shape (3,) or (num_systems, 3), dtype=int32
        Radius of neighboring cells to search in each dimension. A
        2-dimensional array selects the batched layout.

    Returns
    -------
    cells_per_dimension : jax.Array, shape (3,) or (num_systems, 3), dtype=int32
        Number of cells in x, y, z directions (to be filled by build_cell_list).
    neighbor_search_radius : jax.Array, shape (3,) or (num_systems, 3), dtype=int32
        The input array itself, passed through for convenience.
    atom_periodic_shifts : jax.Array, shape (total_atoms, 3), dtype=int32
        Periodic boundary crossings for each atom (to be filled by build_cell_list).
    atom_to_cell_mapping : jax.Array, shape (total_atoms, 3), dtype=int32
        3D cell coordinates for each atom (to be filled by build_cell_list).
    atoms_per_cell_count : jax.Array, shape (max_total_cells,), dtype=int32
        Number of atoms in each cell (to be filled by build_cell_list).
    cell_atom_start_indices : jax.Array, shape (max_total_cells,), dtype=int32
        Starting index in cell_atom_list for each cell (to be filled by build_cell_list).
    cell_atom_list : jax.Array, shape (total_atoms,), dtype=int32
        Flattened list of atom indices organized by cell (to be filled by build_cell_list).

    Notes
    -----
    This is a pure JAX utility function with no warp dependencies. Whether the
    single-system or the batched layout is produced depends only on the number
    of dimensions of ``neighbor_search_radius``.

    See Also
    --------
    nvalchemiops.neighbors.cell_list.build_cell_list : Warp launcher that uses these tensors
    nvalchemiops.jax.neighbors.cell_list.build_cell_list : High-level JAX wrapper
    nvalchemiops.jax.neighbors.batch_cell_list.batch_build_cell_list : Batched version
    """

    def _int32_zeros(shape: tuple[int, ...]) -> jax.Array:
        # Every buffer shares the same dtype and initial value.
        return jnp.zeros(shape, dtype=jnp.int32)

    # A 2-D radius array carries one row per system; otherwise one system.
    if neighbor_search_radius.ndim == 2:
        grid_shape = (neighbor_search_radius.shape[0], 3)
    else:
        grid_shape = (3,)

    return (
        _int32_zeros(grid_shape),  # cells_per_dimension
        neighbor_search_radius,  # passed through unchanged
        _int32_zeros((total_atoms, 3)),  # atom_periodic_shifts
        _int32_zeros((total_atoms, 3)),  # atom_to_cell_mapping
        _int32_zeros((max_total_cells,)),  # atoms_per_cell_count
        _int32_zeros((max_total_cells,)),  # cell_atom_start_indices
        _int32_zeros((total_atoms,)),  # cell_atom_list
    )