Source code for nvalchemiops.jax.neighbors.cell_list

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JAX bindings for unbatched cell list O(N) neighbor list construction."""

from __future__ import annotations

import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental import jax_kernel

from nvalchemiops.jax.neighbors.neighbor_utils import (
    get_neighbor_list_from_neighbor_matrix,
)
from nvalchemiops.neighbors.cell_list import (
    _cell_list_bin_atoms_overload,
    _cell_list_build_neighbor_matrix_overload,
    _cell_list_build_neighbor_matrix_selective_overload,
    _cell_list_construct_bin_size_overload,
    _cell_list_count_atoms_per_bin_overload,
)
from nvalchemiops.neighbors.neighbor_utils import estimate_max_neighbors

# ==============================================================================
# JAX Kernel Wrappers
# ==============================================================================

# Build step 1: Construct bin sizes
_jax_construct_bin_size_f32 = jax_kernel(
    _cell_list_construct_bin_size_overload[wp.float32],
    num_outputs=1,
    in_out_argnames=["cells_per_dimension"],
    enable_backward=False,
)
_jax_construct_bin_size_f64 = jax_kernel(
    _cell_list_construct_bin_size_overload[wp.float64],
    num_outputs=1,
    in_out_argnames=["cells_per_dimension"],
    enable_backward=False,
)

# Build step 2: Count atoms per bin
_jax_count_atoms_per_bin_f32 = jax_kernel(
    _cell_list_count_atoms_per_bin_overload[wp.float32],
    num_outputs=2,
    in_out_argnames=["atoms_per_cell_count", "atom_periodic_shifts"],
    enable_backward=False,
)
_jax_count_atoms_per_bin_f64 = jax_kernel(
    _cell_list_count_atoms_per_bin_overload[wp.float64],
    num_outputs=2,
    in_out_argnames=["atoms_per_cell_count", "atom_periodic_shifts"],
    enable_backward=False,
)

# Build step 3: Bin atoms into cells
_jax_bin_atoms_f32 = jax_kernel(
    _cell_list_bin_atoms_overload[wp.float32],
    num_outputs=3,
    in_out_argnames=["atom_to_cell_mapping", "atoms_per_cell_count", "cell_atom_list"],
    enable_backward=False,
)
_jax_bin_atoms_f64 = jax_kernel(
    _cell_list_bin_atoms_overload[wp.float64],
    num_outputs=3,
    in_out_argnames=["atom_to_cell_mapping", "atoms_per_cell_count", "cell_atom_list"],
    enable_backward=False,
)

# Query: Build neighbor matrix from cell list
_jax_build_neighbor_matrix_f32 = jax_kernel(
    _cell_list_build_neighbor_matrix_overload[wp.float32],
    num_outputs=3,
    in_out_argnames=["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
    enable_backward=False,
)
_jax_build_neighbor_matrix_f64 = jax_kernel(
    _cell_list_build_neighbor_matrix_overload[wp.float64],
    num_outputs=3,
    in_out_argnames=["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
    enable_backward=False,
)

# Selective query: Build neighbor matrix from cell list (skips non-rebuilt systems)
_jax_build_neighbor_matrix_selective_f32 = jax_kernel(
    _cell_list_build_neighbor_matrix_selective_overload[wp.float32],
    num_outputs=3,
    in_out_argnames=["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
    enable_backward=False,
)
_jax_build_neighbor_matrix_selective_f64 = jax_kernel(
    _cell_list_build_neighbor_matrix_selective_overload[wp.float64],
    num_outputs=3,
    in_out_argnames=["neighbor_matrix", "neighbor_matrix_shifts", "num_neighbors"],
    enable_backward=False,
)

__all__ = [
    "cell_list",
    "build_cell_list",
    "query_cell_list",
    "estimate_cell_list_sizes",
]


[docs] def estimate_cell_list_sizes( positions: jax.Array, cell: jax.Array, cutoff: float, pbc: jax.Array | None = None, buffer_factor: float = 1.5, ) -> tuple[int, jax.Array, jax.Array]: """Estimate required cell list sizes based on atomic density. Parameters ---------- positions : jax.Array, shape (total_atoms, 3), dtype=float32 or float64 Atomic coordinates in Cartesian space. cell : jax.Array, shape (1, 3, 3), dtype=float32 or float64 Cell matrix defining lattice vectors. cutoff : float Cutoff distance for neighbor searching. pbc : jax.Array, shape (3,) or (1, 3), dtype=bool, optional Periodic boundary condition flags. Default is all True. buffer_factor : float, optional Buffer multiplier for cell count estimation. Default is 1.5. Returns ------- max_total_cells : int Maximum total number of cells to allocate. cells_per_dimension : jax.Array, shape (3,) or (1, 3), dtype=int32 Estimated number of cells in each dimension. neighbor_search_radius : jax.Array, shape (3,), dtype=int32 Estimated search radius in neighboring cells. Notes ----- This function estimates cell list parameters based on atomic positions and density. The actual number of cells used will be determined during cell list construction. .. warning:: This function is **not compatible with** ``jax.jit``. The returned ``max_total_cells`` is used to determine array allocation sizes, which must be concrete (statically known) at JAX trace time. When using ``cell_list`` or ``build_cell_list`` inside ``jax.jit``, provide ``max_total_cells`` explicitly to bypass this function. """ if cell.ndim == 2: cell = cell[jnp.newaxis, :, :] if pbc is None: pbc = jnp.ones((1, 3), dtype=jnp.bool_) if pbc.ndim == 1: pbc = pbc[jnp.newaxis, :] # Simple estimation: compute total volume and estimate cell volume # Cell volume = det(cell_matrix) det = jnp.linalg.det(cell[0]) volume = jnp.abs(det) cell_volume = cutoff**3 # TODO: This estimation derives array sizes from traced input data (cell # geometry), which is fundamentally incompatible with jax.jit compilation. # The JAX bindings need a refactored usage pattern where sizing is always # performed outside the JIT boundary, or a fixed upper-bound allocation # strategy is adopted. num_cells_est = jnp.int32(volume / cell_volume * buffer_factor) max_total_cells = jnp.max(jnp.array([num_cells_est, 8])) # Minimum 8 cells # Compute cells_per_dimension and neighbor_search_radius from cell geometry, # mirroring the Warp _estimate_cell_list_sizes kernel used by the Torch path. inverse_cell_transpose = jnp.linalg.inv(cell[0]).T face_distances = 1.0 / jnp.linalg.norm(inverse_cell_transpose, axis=1) cells_per_dimension = jnp.maximum(jnp.int32(face_distances / cutoff), 1) pbc_squeezed = pbc.squeeze()[:3] if pbc.ndim > 1 else pbc[:3] neighbor_search_radius = jnp.where( (cells_per_dimension == 1) & ~pbc_squeezed, jnp.zeros(3, dtype=jnp.int32), jnp.int32(jnp.ceil(cutoff * cells_per_dimension / face_distances)), ) return max_total_cells, cells_per_dimension, neighbor_search_radius
[docs] def build_cell_list( positions: jax.Array, cutoff: float, cell: jax.Array, pbc: jax.Array, cells_per_dimension: jax.Array | None = None, neighbor_search_radius: jax.Array | None = None, atom_periodic_shifts: jax.Array | None = None, atom_to_cell_mapping: jax.Array | None = None, atoms_per_cell_count: jax.Array | None = None, cell_atom_start_indices: jax.Array | None = None, cell_atom_list: jax.Array | None = None, max_total_cells: int | None = None, ) -> tuple[ jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, ]: """Build spatial cell list for efficient neighbor searching. Parameters ---------- positions : jax.Array, shape (total_atoms, 3), dtype=float32 or float64 Atomic coordinates in Cartesian space. cutoff : float Cutoff distance for neighbor searching. Must be positive. cell : jax.Array, shape (1, 3, 3), dtype=float32 or float64 Cell matrix defining lattice vectors. pbc : jax.Array, shape (3,) or (1, 3), dtype=bool Periodic boundary condition flags. cells_per_dimension : jax.Array, shape (3,), dtype=int32, optional OUTPUT: Number of cells in x, y, z directions. If None, allocated. neighbor_search_radius : jax.Array, shape (3,), dtype=int32, optional Search radius in neighboring cells. If None, allocated. atom_periodic_shifts : jax.Array, shape (total_atoms, 3), dtype=int32, optional OUTPUT: Periodic boundary crossings for each atom. If None, allocated. atom_to_cell_mapping : jax.Array, shape (total_atoms, 3), dtype=int32, optional OUTPUT: 3D cell coordinates for each atom. If None, allocated. atoms_per_cell_count : jax.Array, shape (max_total_cells,), dtype=int32, optional OUTPUT: Number of atoms in each cell. If None, allocated. cell_atom_start_indices : jax.Array, shape (max_total_cells,), dtype=int32, optional OUTPUT: Starting index in cell_atom_list for each cell. If None, allocated. cell_atom_list : jax.Array, shape (total_atoms,), dtype=int32, optional OUTPUT: Flattened list of atom indices organized by cell. If None, allocated. max_total_cells : int, optional Maximum number of cells to allocate. If None, will be estimated. Returns ------- cells_per_dimension : jax.Array, shape (3,), dtype=int32 Number of cells in x, y, z directions. atom_periodic_shifts : jax.Array, shape (total_atoms, 3), dtype=int32 Periodic boundary crossings for each atom. atom_to_cell_mapping : jax.Array, shape (total_atoms, 3), dtype=int32 3D cell coordinates for each atom. atoms_per_cell_count : jax.Array, shape (max_total_cells,), dtype=int32 Number of atoms in each cell. cell_atom_start_indices : jax.Array, shape (max_total_cells,), dtype=int32 Starting index in cell_atom_list for each cell. cell_atom_list : jax.Array, shape (total_atoms,), dtype=int32 Flattened list of atom indices organized by cell. neighbor_search_radius : jax.Array, shape (3,), dtype=int32 Search radius in neighboring cells. Notes ----- When calling inside ``jax.jit``, ``max_total_cells`` **must** be provided to avoid calling ``estimate_cell_list_sizes``, which is not JIT-compatible. See Also -------- query_cell_list : Query the built cell list for neighbors """ if cell.ndim == 2: cell = cell[jnp.newaxis, :, :] if pbc.ndim == 1: pbc = pbc[jnp.newaxis, :] if max_total_cells is None: max_total_cells, _, neighbor_search_radius_est = estimate_cell_list_sizes( positions, cell, cutoff, pbc ) if neighbor_search_radius is None: neighbor_search_radius = neighbor_search_radius_est else: if neighbor_search_radius is None: neighbor_search_radius = jnp.ones(3, dtype=jnp.int32) # Allocate cell list tensors if not provided if cells_per_dimension is None: cells_per_dimension = jnp.ones(3, dtype=jnp.int32) if atom_periodic_shifts is None: atom_periodic_shifts = jnp.zeros((positions.shape[0], 3), dtype=jnp.int32) if atom_to_cell_mapping is None: atom_to_cell_mapping = jnp.zeros((positions.shape[0], 3), dtype=jnp.int32) if atoms_per_cell_count is None: atoms_per_cell_count = jnp.zeros(max_total_cells, dtype=jnp.int32) if cell_atom_start_indices is None: cell_atom_start_indices = jnp.zeros(max_total_cells, dtype=jnp.int32) if cell_atom_list is None: cell_atom_list = jnp.zeros(positions.shape[0], dtype=jnp.int32) # Select kernels based on dtype if positions.dtype == jnp.float64: _construct_bin_size = _jax_construct_bin_size_f64 _count_atoms = _jax_count_atoms_per_bin_f64 _bin_atoms = _jax_bin_atoms_f64 else: _construct_bin_size = _jax_construct_bin_size_f32 _count_atoms = _jax_count_atoms_per_bin_f32 _bin_atoms = _jax_bin_atoms_f32 positions = positions.astype(jnp.float32) # Ensure cell dtype matches positions dtype so warp overload dispatch is consistent if cell.dtype != positions.dtype: cell = cell.astype(positions.dtype) total_atoms = positions.shape[0] # Squeeze pbc to 1D for kernel (kernels expect shape (3,)) pbc_1d = pbc.squeeze() if pbc.ndim == 2 else pbc pbc_bool = pbc_1d.astype(jnp.bool_) # Step 1: Construct bin sizes (cells_per_dimension,) = _construct_bin_size( cell, pbc_bool, cells_per_dimension, float(cutoff), int(max_total_cells), launch_dims=(1,), ) # Step 2: Count atoms per bin atoms_per_cell_count, atom_periodic_shifts = _count_atoms( positions, cell, pbc_bool, cells_per_dimension, atoms_per_cell_count, atom_periodic_shifts, launch_dims=(total_atoms,), ) # Step 3: Compute exclusive prefix sum (replaces wp.utils.array_scan) cell_atom_start_indices = jnp.concatenate( [ jnp.array([0], dtype=jnp.int32), jnp.cumsum(atoms_per_cell_count[:-1], dtype=jnp.int32), ] ) # Step 4: Zero counts before second pass atoms_per_cell_count = jnp.zeros_like(atoms_per_cell_count) # Step 5: Bin atoms atom_to_cell_mapping, atoms_per_cell_count, cell_atom_list = _bin_atoms( positions, cell, pbc_bool, cells_per_dimension, atom_to_cell_mapping, atoms_per_cell_count, cell_atom_start_indices, cell_atom_list, launch_dims=(total_atoms,), ) return ( cells_per_dimension, atom_periodic_shifts, atom_to_cell_mapping, atoms_per_cell_count, cell_atom_start_indices, cell_atom_list, neighbor_search_radius, )
[docs] def query_cell_list( positions: jax.Array, cutoff: float, cell: jax.Array, pbc: jax.Array, cells_per_dimension: jax.Array, atom_periodic_shifts: jax.Array, atom_to_cell_mapping: jax.Array, atoms_per_cell_count: jax.Array, cell_atom_start_indices: jax.Array, cell_atom_list: jax.Array, neighbor_search_radius: jax.Array, max_neighbors: int | None = None, neighbor_matrix: jax.Array | None = None, neighbor_matrix_shifts: jax.Array | None = None, num_neighbors: jax.Array | None = None, rebuild_flags: jax.Array | None = None, ) -> tuple[jax.Array, jax.Array, jax.Array]: """Query cell list to find neighbors within cutoff. Parameters ---------- positions : jax.Array, shape (total_atoms, 3), dtype=float32 or float64 Atomic coordinates in Cartesian space. cutoff : float Cutoff distance for neighbor detection. cell : jax.Array, shape (1, 3, 3), dtype=float32 or float64 Cell matrix defining lattice vectors. pbc : jax.Array, shape (3,) or (1, 3), dtype=bool Periodic boundary condition flags. cells_per_dimension : jax.Array, shape (3,), dtype=int32 Number of cells in each dimension. atom_periodic_shifts : jax.Array, shape (total_atoms, 3), dtype=int32 Periodic boundary crossings for each atom (output from ``build_cell_list``). atom_to_cell_mapping : jax.Array, shape (total_atoms, 3), dtype=int32 3D cell coordinates for each atom. atoms_per_cell_count : jax.Array, shape (max_total_cells,), dtype=int32 Number of atoms in each cell (output from ``build_cell_list``). cell_atom_start_indices : jax.Array, shape (max_total_cells,), dtype=int32 Starting index in cell_atom_list for each cell. cell_atom_list : jax.Array, shape (total_atoms,), dtype=int32 Flattened list of atom indices organized by cell. neighbor_search_radius : jax.Array, shape (3,), dtype=int32 Search radius in neighboring cells. max_neighbors : int, optional Maximum number of neighbors per atom. neighbor_matrix : jax.Array, optional Pre-allocated neighbor matrix. num_neighbors : jax.Array, optional Pre-allocated neighbors count array. Returns ------- neighbor_matrix : jax.Array, shape (total_atoms, max_neighbors), dtype=int32 Neighbor matrix with neighbor atom indices. num_neighbors : jax.Array, shape (total_atoms,), dtype=int32 Number of neighbors found for each atom. neighbor_matrix_shifts : jax.Array, shape (total_atoms, max_neighbors, 3), dtype=int32 Periodic shift vectors for each neighbor relationship. See Also -------- build_cell_list : Build cell list before querying cell_list : Combined build and query operation """ if max_neighbors is None: max_neighbors = estimate_max_neighbors(cutoff) if neighbor_matrix is None: neighbor_matrix = jnp.full( (positions.shape[0], max_neighbors), positions.shape[0], dtype=jnp.int32, ) elif rebuild_flags is None: neighbor_matrix = neighbor_matrix.at[:].set(jnp.int32(positions.shape[0])) if num_neighbors is None: num_neighbors = jnp.zeros(positions.shape[0], dtype=jnp.int32) elif rebuild_flags is None: num_neighbors = num_neighbors.at[:].set(jnp.int32(0)) if neighbor_matrix_shifts is None: neighbor_matrix_shifts = jnp.zeros( (positions.shape[0], max_neighbors, 3), dtype=jnp.int32, ) elif rebuild_flags is None: neighbor_matrix_shifts = neighbor_matrix_shifts.at[:].set(jnp.int32(0)) # Select kernel based on dtype if positions.dtype == jnp.float64: _query_kernel = _jax_build_neighbor_matrix_f64 _query_kernel_selective = _jax_build_neighbor_matrix_selective_f64 else: _query_kernel = _jax_build_neighbor_matrix_f32 _query_kernel_selective = _jax_build_neighbor_matrix_selective_f32 positions = positions.astype(jnp.float32) # Ensure cell dtype matches positions dtype so warp overload dispatch is consistent if cell.dtype != positions.dtype: cell = cell.astype(positions.dtype) total_atoms = positions.shape[0] # Squeeze pbc to 1D for kernel (kernels expect shape (3,)) pbc_1d = pbc.squeeze() if pbc.ndim == 2 else pbc pbc_bool = pbc_1d.astype(jnp.bool_) if rebuild_flags is not None: rf = rebuild_flags.flatten()[:1].astype(jnp.bool_) num_neighbors = jnp.where(rf[0], jnp.zeros_like(num_neighbors), num_neighbors) neighbor_matrix, neighbor_matrix_shifts, num_neighbors = ( _query_kernel_selective( positions, cell, pbc_bool, float(cutoff), cells_per_dimension, neighbor_search_radius, atom_periodic_shifts, atom_to_cell_mapping, atoms_per_cell_count, cell_atom_start_indices, cell_atom_list, neighbor_matrix, neighbor_matrix_shifts, num_neighbors, False, # half_fill rf, launch_dims=(total_atoms,), ) ) else: neighbor_matrix, neighbor_matrix_shifts, num_neighbors = _query_kernel( positions, cell, pbc_bool, float(cutoff), cells_per_dimension, neighbor_search_radius, atom_periodic_shifts, atom_to_cell_mapping, atoms_per_cell_count, cell_atom_start_indices, cell_atom_list, neighbor_matrix, neighbor_matrix_shifts, num_neighbors, False, # half_fill launch_dims=(total_atoms,), ) return neighbor_matrix, num_neighbors, neighbor_matrix_shifts
[docs] def cell_list( positions: jax.Array, cutoff: float, cell: jax.Array | None = None, pbc: jax.Array | None = None, max_neighbors: int | None = None, max_total_cells: int | None = None, return_neighbor_list: bool = False, ) -> tuple[jax.Array, jax.Array, jax.Array]: """Build and query spatial cell list for efficient neighbor finding. This is a convenience function that combines build_cell_list and query_cell_list in a single call. Parameters ---------- positions : jax.Array, shape (total_atoms, 3), dtype=float32 or float64 Atomic coordinates in Cartesian space. cutoff : float Cutoff distance for neighbor detection. cell : jax.Array, shape (1, 3, 3), dtype=float32 or float64, optional Cell matrix defining lattice vectors. Default is identity matrix. pbc : jax.Array, shape (3,) or (1, 3), dtype=bool, optional Periodic boundary condition flags. Default is all True. max_neighbors : int, optional Maximum number of neighbors per atom. If None, will be estimated. max_total_cells : int, optional Maximum number of cells to allocate. If None, will be estimated. return_neighbor_list : bool, optional If True, convert result to COO neighbor list format. Default is False. Returns ------- neighbor_data : jax.Array If ``return_neighbor_list=False`` (default): ``neighbor_matrix`` with shape (total_atoms, max_neighbors), dtype int32. If ``return_neighbor_list=True``: ``neighbor_list`` with shape (2, num_pairs), dtype int32, in COO format. neighbor_count : jax.Array If ``return_neighbor_list=False``: ``num_neighbors`` with shape (total_atoms,), dtype int32. If ``return_neighbor_list=True``: ``neighbor_ptr`` with shape (total_atoms + 1,), dtype int32. shift_data : jax.Array If ``return_neighbor_list=False``: ``neighbor_matrix_shifts`` with shape (total_atoms, max_neighbors, 3), dtype int32. If ``return_neighbor_list=True``: ``neighbor_list_shifts`` with shape (num_pairs, 3), dtype int32. See Also -------- build_cell_list : Build cell list separately query_cell_list : Query cell list separately naive_neighbor_list : Naive O(N^2) method """ if cell is None: cell = jnp.eye(3, dtype=jnp.float32)[jnp.newaxis, :, :] if pbc is None: pbc = jnp.ones((1, 3), dtype=jnp.bool_) # Build cell list ( cells_per_dimension, atom_periodic_shifts, atom_to_cell_mapping, atoms_per_cell_count, cell_atom_start_indices, cell_atom_list, neighbor_search_radius, ) = build_cell_list( positions, cutoff, cell, pbc, max_total_cells=max_total_cells, ) # Query cell list neighbor_matrix, num_neighbors, neighbor_matrix_shifts = query_cell_list( positions=positions, cutoff=cutoff, cell=cell, pbc=pbc, cells_per_dimension=cells_per_dimension, atom_periodic_shifts=atom_periodic_shifts, atom_to_cell_mapping=atom_to_cell_mapping, atoms_per_cell_count=atoms_per_cell_count, cell_atom_start_indices=cell_atom_start_indices, cell_atom_list=cell_atom_list, neighbor_search_radius=neighbor_search_radius, max_neighbors=max_neighbors, ) if return_neighbor_list: neighbor_list, neighbor_ptr, neighbor_list_shifts = ( get_neighbor_list_from_neighbor_matrix( neighbor_matrix, num_neighbors=num_neighbors, neighbor_shift_matrix=neighbor_matrix_shifts, fill_value=positions.shape[0], ) ) return ( neighbor_list, neighbor_ptr, neighbor_list_shifts, ) else: return ( neighbor_matrix, num_neighbors, neighbor_matrix_shifts, )