nvalchemiops.jax.neighbors: Neighbor Lists

The neighbors module provides JAX bindings for the GPU-accelerated implementations of neighbor list algorithms.

Tip

For the underlying framework-agnostic Warp kernels, see nvalchemiops.neighbors: Neighbor Lists.

The module covers neighbor list computation and related utilities for both single and batched systems.

High-Level Interface

nvalchemiops.jax.neighbors.neighbor_list(positions, cutoff, cell=None, pbc=None, batch_idx=None, batch_ptr=None, cutoff2=None, half_fill=False, fill_value=None, return_neighbor_list=False, method=None, wrap_positions=True, **kwargs)

Compute neighbor list using the appropriate method based on the provided parameters.

This is the main entry point for JAX users of the neighbor list API. It automatically selects the most appropriate algorithm (naive O(N²) or cell list O(N)) based on system size and parameters.

Parameters:
  • positions (jax.Array, shape (total_atoms, 3)) – Concatenated atomic coordinates for all systems in Cartesian space. Each row represents one atom’s (x, y, z) position. Unwrapped (box-crossing) coordinates are supported when PBC is used; the kernel wraps positions internally.

  • cutoff (float) – Cutoff distance for neighbor detection in Cartesian units. Must be positive. Atoms within this distance are considered neighbors.

  • cell (jax.Array, shape (3, 3) or (num_systems, 3, 3), optional) – Cell matrix defining the simulation box.

  • pbc (jax.Array, shape (3,) or (num_systems, 3), dtype=bool, optional) – Periodic boundary condition flags for each dimension.

  • batch_idx (jax.Array, shape (total_atoms,), dtype=jnp.int32, optional) – System index for each atom.

  • batch_ptr (jax.Array, shape (num_systems + 1,), dtype=jnp.int32, optional) – Cumulative atom counts defining system boundaries.

  • cutoff2 (float, optional) – Second cutoff distance in Cartesian units. Must be positive. When provided, a second set of neighbor data is returned for atoms within this distance (see Returns).

  • half_fill (bool, optional) – If True, store only half of the neighbor relationships to avoid double counting. The other half can be reconstructed by swapping source and target indices and negating the unit shifts. Default is False.

  • fill_value (int | None, optional) – Value to fill the neighbor matrix with. Default is total_atoms.

  • return_neighbor_list (bool, optional - default = False) – If True, convert the neighbor matrix to a neighbor list (idx_i, idx_j) format by creating a mask over the fill_value, which can incur a performance penalty. We recommend using the neighbor matrix format and converting to a neighbor list only when absolutely necessary.

  • method (str | None, optional) – Method to use for neighbor list computation. Choices: "naive", "cell_list", "batch_naive", "batch_cell_list", "naive_dual_cutoff", "batch_naive_dual_cutoff". If None, a default method is chosen based on the average number of atoms per system (cell_list when >= 2000, naive otherwise). When only batch_idx is provided (no batch_ptr or 3-D cell), auto-selection reads batch_idx[-1], which triggers a device-to-host synchronization. To avoid this, pass batch_ptr, a 3-D cell array, or specify method explicitly.

  • wrap_positions (bool, default=True) – If True, wrap input positions into the primary cell before neighbor search. Set to False when positions are already wrapped (e.g. by a preceding integration step) to save two GPU kernel launches per call. Only applies to naive methods; cell list methods handle wrapping internally.

  • **kwargs (dict, optional) –

    Additional keyword arguments to pass to the method.

    max_neighbors (int, optional)

    Maximum number of neighbors per atom. Can be provided to aid in allocation for both naive and cell list methods.

    max_neighbors2 (int, optional)

    Maximum number of neighbors per atom within cutoff2. Can be provided to aid in allocation for naive dual cutoff method.

    neighbor_matrix (jax.Array, optional)

    Pre-shaped array of shape (total_atoms, max_neighbors) for neighbor indices. Can be provided to hint buffer reuse to XLA for both naive and cell list methods.

    neighbor_matrix_shifts (jax.Array, optional)

    Pre-shaped array of shape (total_atoms, max_neighbors, 3) for shift vectors. Can be provided to hint buffer reuse to XLA for both naive and cell list methods.

    num_neighbors (jax.Array, optional)

    Pre-shaped array of shape (total_atoms,) for neighbor counts. Can be provided to hint buffer reuse to XLA for both naive and cell list methods.

    shift_range_per_dimension (jax.Array, optional)

    Pre-computed array of shape (1, 3) for shift range in each dimension. Can be provided to avoid recomputation for naive methods.

    num_shifts_per_system (jax.Array, optional)

    Pre-computed array of shape (num_systems,) for the number of periodic shifts per system. Can be provided to avoid recomputation for naive methods.

    max_shifts_per_system (int, optional)

    Maximum per-system shift count. Can be provided to avoid recomputation for naive methods.

    cells_per_dimension (jax.Array, optional)

    Pre-computed array of shape (3,) for number of cells in x, y, z directions. Can be provided to hint buffer reuse to XLA for cell list construction.

    neighbor_search_radius (jax.Array, optional)

    Pre-computed array of shape (3,) for radius of neighboring cells to search in each dimension. Can be provided to hint buffer reuse to XLA for cell list construction.

    atom_periodic_shifts (jax.Array, optional)

    Pre-shaped array of shape (total_atoms, 3) for periodic boundary crossings for each atom. Can be provided to hint buffer reuse to XLA for cell list construction.

    atom_to_cell_mapping (jax.Array, optional)

    Pre-shaped array of shape (total_atoms, 3) for cell coordinates for each atom. Can be provided to hint buffer reuse to XLA for cell list construction.

    atoms_per_cell_count (jax.Array, optional)

    Pre-shaped array of shape (max_total_cells,) for number of atoms in each cell. Can be provided to hint buffer reuse to XLA for cell list construction.

    cell_atom_start_indices (jax.Array, optional)

    Pre-shaped array of shape (max_total_cells,) for starting index in cell_atom_list for each cell. Can be provided to hint buffer reuse to XLA for cell list construction.

    cell_atom_list (jax.Array, optional)

    Pre-shaped array of shape (total_atoms,) for flattened list of atom indices organized by cell. Can be provided to hint buffer reuse to XLA for cell list construction.

    max_atoms_per_system (int, optional)

    Maximum number of atoms per system. Used in batch naive implementation with PBC. If not provided, it will be computed automatically. Can be provided to avoid CUDA synchronization.
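
The default-method rule described under method can be sketched in plain Python. This is only an illustration of the documented threshold, not the library's implementation: choose_method and its arguments are hypothetical names, and the way batching and a second cutoff combine with the threshold is an assumption of this sketch.

```python
def choose_method(total_atoms, num_systems=1, batched=False, dual_cutoff=False):
    """Hypothetical helper mirroring the documented default-method rule."""
    avg_atoms = total_atoms / num_systems
    if dual_cutoff:
        # Only naive variants are listed among the dual-cutoff choices.
        base = "naive_dual_cutoff"
    elif avg_atoms >= 2000:
        base = "cell_list"
    else:
        base = "naive"
    # Assumption: batched inputs map to the "batch_" variant of the same method.
    return f"batch_{base}" if batched else base


single = choose_method(500)                                   # small single system
batched = choose_method(8000, num_systems=2, batched=True)    # 4000 atoms per system
```

With batch_ptr or a 3-D cell, the number of systems is available from the array shape alone, which is presumably why those inputs avoid the device-to-host read mentioned above.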

Returns:

results – Variable-length tuple depending on input parameters. The return pattern follows:

Single cutoff:
  • No PBC, matrix format: (neighbor_matrix, num_neighbors)

  • No PBC, list format: (neighbor_list, neighbor_ptr)

  • With PBC, matrix format: (neighbor_matrix, num_neighbors, neighbor_matrix_shifts)

  • With PBC, list format: (neighbor_list, neighbor_ptr, neighbor_list_shifts)

Dual cutoff:
  • No PBC, matrix format: (neighbor_matrix1, num_neighbors1, neighbor_matrix2, num_neighbors2)

  • No PBC, list format: (neighbor_list1, neighbor_ptr1, neighbor_list2, neighbor_ptr2)

  • With PBC, matrix format: (neighbor_matrix1, num_neighbors1, neighbor_matrix_shifts1, neighbor_matrix2, num_neighbors2, neighbor_matrix_shifts2)

  • With PBC, list format: (neighbor_list1, neighbor_ptr1, neighbor_list_shifts1, neighbor_list2, neighbor_ptr2, neighbor_list_shifts2)

Components returned:

  • neighbor_data (array): Neighbor indices, format depends on return_neighbor_list:

    • If return_neighbor_list=False (default): Returns neighbor_matrix with shape (total_atoms, max_neighbors), dtype int32. Each row i contains indices of atom i’s neighbors.

    • If return_neighbor_list=True: Returns neighbor_list with shape (2, num_pairs), dtype int32, in COO format [source_atoms, target_atoms].

  • num_neighbor_data (array): Information about the number of neighbors for each atom, format depends on return_neighbor_list:

    • If return_neighbor_list=False (default): Returns num_neighbors with shape (total_atoms,), dtype int32. Count of neighbors found for each atom.

    • If return_neighbor_list=True: Returns neighbor_ptr with shape (total_atoms + 1,), dtype int32. CSR-style pointer array where neighbor_ptr[i] to neighbor_ptr[i+1] gives the range of neighbors for atom i in the flattened neighbor list.

  • neighbor_shift_data (array, optional): Periodic shift vectors, returned only when pbc is provided. Format depends on return_neighbor_list:

    • If return_neighbor_list=False (default): Returns neighbor_matrix_shifts with shape (total_atoms, max_neighbors, 3), dtype int32.

    • If return_neighbor_list=True: Returns unit_shifts with shape (num_pairs, 3), dtype int32.

When cutoff2 is provided, the same components are returned for the second cutoff (neighbor_data2, num_neighbor_data2, neighbor_shift_data2) and appended to the tuple after those of the first cutoff.

Return type:

tuple of jax.Array
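
The relationship between the two return formats can be illustrated with a small NumPy sketch. It masks out the fill_value padding to obtain COO pairs and builds the CSR-style pointer from the per-atom counts. The helper name matrix_to_list is hypothetical, and the library's own conversion may differ in detail (for example in pair ordering).

```python
import numpy as np


def matrix_to_list(neighbor_matrix, fill_value):
    """Convert a padded neighbor matrix to COO pairs plus a CSR pointer."""
    mask = neighbor_matrix != fill_value          # True where a real neighbor is stored
    src, col = np.nonzero(mask)                   # row index is the source atom
    dst = neighbor_matrix[src, col]               # stored value is the target atom
    neighbor_list = np.stack([src, dst]).astype(np.int32)   # shape (2, num_pairs)
    counts = mask.sum(axis=1)                     # neighbors per atom
    neighbor_ptr = np.concatenate([[0], np.cumsum(counts)]).astype(np.int32)
    return neighbor_list, neighbor_ptr


# 3 atoms, max_neighbors = 2, padding value = total_atoms = 3
nm = np.array([[1, 2], [0, 3], [0, 3]], dtype=np.int32)
nlist, ptr = matrix_to_list(nm, fill_value=3)
```

Because the number of pairs depends on the data, the list format has a data-dependent shape, whereas the matrix format keeps a fixed (total_atoms, max_neighbors) shape.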

Examples

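Setup used by the examples below (illustrative values):

>>> import jax
>>> import jax.numpy as jnp
>>> from nvalchemiops.jax.neighbors import neighbor_list
>>> pos = jax.random.uniform(jax.random.PRNGKey(0), (100, 3)) * 10.0
>>> cell = jnp.eye(3, dtype=jnp.float32) * 10.0
>>> pbc = jnp.array([True, True, True])

Single cutoff, matrix format, with PBC:

>>> nm, num, shifts = neighbor_list(pos, 5.0, cell=cell, pbc=pbc)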

Single cutoff, list format, no PBC:

>>> nlist, ptr = neighbor_list(pos, 5.0, return_neighbor_list=True)

Dual cutoff, matrix format, with PBC:

>>> nm1, num1, sh1, nm2, num2, sh2 = neighbor_list(
...     pos, 2.5, cutoff2=5.0, cell=cell, pbc=pbc
... )
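
For half_fill=True, the parameter description says the omitted half can be recovered by swapping source and target indices and inverting the unit shifts. A minimal NumPy sketch of that step on list-format output follows; expand_half_list is a hypothetical helper name and the arrays are made-up illustrative data.

```python
import numpy as np


def expand_half_list(neighbor_list, unit_shifts):
    """Rebuild the full pair list from a half-filled COO list."""
    src, dst = neighbor_list
    # Append the mirrored pairs (j, i) after the stored pairs (i, j).
    full_list = np.concatenate([neighbor_list, np.stack([dst, src])], axis=1)
    # A mirrored pair crosses the same boundaries in the opposite direction.
    full_shifts = np.concatenate([unit_shifts, -unit_shifts], axis=0)
    return full_list, full_shifts


half = np.array([[0, 0], [1, 2]], dtype=np.int32)            # pairs (0, 1) and (0, 2)
shifts = np.array([[0, 0, 0], [1, 0, 0]], dtype=np.int32)    # second pair crosses +x
full, full_shifts = expand_half_list(half, shifts)
```

The expanded list is no longer sorted by source atom, so a CSR pointer for it would have to be rebuilt after sorting.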

See also

naive_neighbor_list

Direct access to naive O(N²) algorithm

cell_list

Direct access to cell list O(N) algorithm

batch_naive_neighbor_list

Batched naive algorithm

batch_cell_list

Batched cell list algorithm

Unbatched Algorithms

Naive Algorithm

nvalchemiops.jax.neighbors.naive_neighbor_list(positions, cutoff, cell=None, pbc=None, max_neighbors=None, half_fill=False, fill_value=None, return_neighbor_list=False, neighbor_matrix=None, neighbor_matrix_shifts=None, num_neighbors=None, shift_range_per_dimension=None, num_shifts_per_system=None, max_shifts_per_system=None, rebuild_flags=None, wrap_positions=True)

Compute neighbor list using naive O(N^2) algorithm.

Identifies all atom pairs within a specified cutoff distance using a brute-force pairwise distance calculation. Supports both non-periodic and periodic boundary conditions.
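
As a point of reference, the brute-force search for a non-periodic system can be written in a few lines of NumPy. This is a CPU sketch for checking small cases, not the GPU kernel: the function name is hypothetical, and the strict `<` comparison, the exclusion of self-pairs, the neighbor ordering, and capping the count at max_neighbors are assumptions of the sketch.

```python
import numpy as np


def naive_neighbor_matrix_reference(positions, cutoff, max_neighbors, half_fill=False):
    """Brute-force O(N^2) neighbor matrix for a non-periodic system."""
    n = positions.shape[0]
    fill_value = n                                    # documented default padding value
    neighbor_matrix = np.full((n, max_neighbors), fill_value, dtype=np.int32)
    num_neighbors = np.zeros(n, dtype=np.int32)
    for i in range(n):
        for j in range(n):
            # Skip self-pairs; with half_fill keep only pairs with i < j.
            if i == j or (half_fill and j < i):
                continue
            if np.linalg.norm(positions[j] - positions[i]) < cutoff:
                if num_neighbors[i] < max_neighbors:  # excess neighbors are ignored
                    neighbor_matrix[i, num_neighbors[i]] = j
                    num_neighbors[i] += 1
    return neighbor_matrix, num_neighbors


positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
nm, num = naive_neighbor_matrix_reference(positions, cutoff=2.5, max_neighbors=2)
nm_half, num_half = naive_neighbor_matrix_reference(
    positions, cutoff=2.5, max_neighbors=2, half_fill=True
)
```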

Parameters:
  • positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates in Cartesian space. Each row represents one atom’s (x, y, z) position.

  • cutoff (float) – Cutoff distance for neighbor detection in Cartesian units. Must be positive. Atoms within this distance are considered neighbors.

  • pbc (jax.Array, shape (3,) or (1, 3), dtype=bool, optional) – Periodic boundary condition flags for each dimension. True enables periodicity in that direction. Default is None (no PBC).

  • cell (jax.Array, shape (1, 3, 3), dtype=float32 or float64, optional) – Cell matrices defining lattice vectors in Cartesian coordinates. Required if pbc is provided. Default is None.

  • max_neighbors (int, optional) – Maximum number of neighbors per atom. Must be positive. If exceeded, excess neighbors are ignored. Must be provided if neighbor_matrix is not provided.

  • half_fill (bool, optional) – If True, only store relationships where i < j to avoid double counting. If False, store all neighbor relationships symmetrically. Default is False.

  • fill_value (int, optional) – Value to fill the neighbor matrix with. Default is total_atoms.

  • neighbor_matrix (jax.Array, shape (total_atoms, max_neighbors), dtype=int32, optional) – Neighbor matrix to be filled. Pass in a pre-shaped array to hint buffer reuse to XLA; note that JAX returns a new array rather than mutating the input. Must be provided if max_neighbors is not provided.

  • neighbor_matrix_shifts (jax.Array, shape (total_atoms, max_neighbors, 3), dtype=int32, optional) – Shift vectors for each neighbor relationship. Pass in a pre-shaped array to hint buffer reuse to XLA; note that JAX returns a new array rather than mutating the input. Must be provided if max_neighbors is not provided.

  • num_neighbors (jax.Array, shape (total_atoms,), dtype=int32, optional) – Number of neighbors found for each atom. Pass in a pre-shaped array to hint buffer reuse to XLA; note that JAX returns a new array rather than mutating the input. Must be provided if max_neighbors is not provided.

  • shift_range_per_dimension (jax.Array, shape (1, 3), dtype=int32, optional) – Shift range in each dimension for each system. Pass in a pre-computed value to avoid recomputation for PBC systems.

  • num_shifts_per_system (jax.Array, shape (1,), dtype=int32, optional) – Number of periodic shifts for the system. Pass in a pre-computed value to avoid recomputation for PBC systems.

  • max_shifts_per_system (int, optional) – Maximum per-system shift count. Pass in a pre-computed value to avoid recomputation for PBC systems.

  • return_neighbor_list (bool, optional - default = False) – If True, convert the neighbor matrix to a neighbor list (idx_i, idx_j) format by creating a mask over the fill_value, which can incur a performance penalty.

  • wrap_positions (bool, default=True) – If True, wrap input positions into the primary cell before neighbor search. Set to False when positions are already wrapped (e.g. by a preceding integration step) to save two GPU kernel launches per call.

  • rebuild_flags (Array | None)

Returns:

results – Variable-length tuple depending on input parameters. The return pattern follows:

  • No PBC, matrix format: (neighbor_matrix, num_neighbors)

  • No PBC, list format: (neighbor_list, neighbor_ptr)

  • With PBC, matrix format: (neighbor_matrix, num_neighbors, neighbor_matrix_shifts)

  • With PBC, list format: (neighbor_list, neighbor_ptr, neighbor_list_shifts)

Components returned:

  • neighbor_data (array): Neighbor indices, format depends on return_neighbor_list:

    • If return_neighbor_list=False (default): Returns neighbor_matrix with shape (total_atoms, max_neighbors), dtype int32. Each row i contains indices of atom i’s neighbors.

    • If return_neighbor_list=True: Returns neighbor_list with shape (2, num_pairs), dtype int32, in COO format [source_atoms, target_atoms].

  • num_neighbor_data (array): Information about the number of neighbors for each atom, format depends on return_neighbor_list:

    • If return_neighbor_list=False (default): Returns num_neighbors with shape (total_atoms,), dtype int32. Count of neighbors found for each atom. Always returned.

    • If return_neighbor_list=True: Returns neighbor_ptr with shape (total_atoms + 1,), dtype int32. CSR-style pointer array where neighbor_ptr[i] to neighbor_ptr[i+1] gives the range of neighbors for atom i in the flattened neighbor list.

  • neighbor_shift_data (array, optional): Periodic shift vectors, returned only when pbc is provided. Format depends on return_neighbor_list:

    • If return_neighbor_list=False (default): Returns neighbor_matrix_shifts with shape (total_atoms, max_neighbors, 3), dtype int32.

    • If return_neighbor_list=True: Returns unit_shifts with shape (num_pairs, 3), dtype int32.
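
The CSR-style pointer described above is used by slicing. The following NumPy sketch uses made-up arrays and a hypothetical helper neighbors_of, and assumes the pairs are grouped by source atom as the pointer layout implies.

```python
import numpy as np

neighbor_list = np.array([[0, 0, 1, 2], [1, 2, 0, 0]], dtype=np.int32)  # COO pairs
neighbor_ptr = np.array([0, 2, 3, 4], dtype=np.int32)                   # CSR pointer


def neighbors_of(i, neighbor_list, neighbor_ptr):
    """Target atoms of source atom i, read from the flattened pair list."""
    start, stop = neighbor_ptr[i], neighbor_ptr[i + 1]
    return neighbor_list[1, start:stop]


first = neighbors_of(0, neighbor_list, neighbor_ptr)
counts = np.diff(neighbor_ptr)        # per-atom neighbor counts recovered from the pointer
```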

Return type:

tuple of jax.Array
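
Downstream code usually has to ignore the padding slots of the matrix format. One way to do that, sketched in NumPy for a non-periodic system, is to build a mask from fill_value and replace padded indices with a safe index before gathering. The helper name pair_distances and the sample arrays are illustrative only.

```python
import numpy as np


def pair_distances(positions, neighbor_matrix, fill_value):
    """Distance for every slot of a padded neighbor matrix; padding slots give 0."""
    mask = neighbor_matrix != fill_value
    safe_idx = np.where(mask, neighbor_matrix, 0)          # avoid out-of-range gathers
    delta = positions[safe_idx] - positions[:, None, :]    # (total_atoms, max_neighbors, 3)
    dist = np.linalg.norm(delta, axis=-1)
    return np.where(mask, dist, 0.0), mask


positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
neighbor_matrix = np.array([[1, 3], [0, 3], [3, 3]], dtype=np.int32)   # fill_value = 3
dist, mask = pair_distances(positions, neighbor_matrix, fill_value=3)
```

The row sums of the mask reproduce num_neighbors, which gives a cheap consistency check on the two outputs.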

Examples

Basic usage without periodic boundary conditions:

>>> import jax.numpy as jnp
>>> from nvalchemiops.jax.neighbors import naive_neighbor_list
>>> positions = jnp.zeros((100, 3), dtype=jnp.float32)
>>> cutoff = 2.5
>>> max_neighbors = 50
>>> neighbor_matrix, num_neighbors = naive_neighbor_list(
...     positions, cutoff, max_neighbors=max_neighbors
... )

With periodic boundary conditions:

>>> cell = jnp.eye(3, dtype=jnp.float32).reshape(1, 3, 3) * 10.0
>>> pbc = jnp.array([[True, True, True]])
>>> neighbor_matrix, num_neighbors, shifts = naive_neighbor_list(
...     positions, cutoff, max_neighbors=max_neighbors, pbc=pbc, cell=cell
... )

Return as neighbor list instead of matrix:

>>> neighbor_list, neighbor_ptr = naive_neighbor_list(
...     positions, cutoff, max_neighbors=max_neighbors, return_neighbor_list=True
... )
>>> source_atoms, target_atoms = neighbor_list[0], neighbor_list[1]

See also

nvalchemiops.neighbors.naive.naive_neighbor_matrix

Core Warp launcher (no PBC)

nvalchemiops.neighbors.naive.naive_neighbor_matrix_pbc

Core Warp launcher (with PBC)

cell_list

O(N) cell list method for larger systems

Cell List Algorithm

nvalchemiops.jax.neighbors.cell_list(positions, cutoff, cell=None, pbc=None, max_neighbors=None, max_total_cells=None, return_neighbor_list=False)

Build and query spatial cell list for efficient neighbor finding.

This is a convenience function that combines build_cell_list and query_cell_list in a single call.
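
The two-step form can be sketched from the signatures documented on this page: build_cell_list returns seven arrays in the same order in which query_cell_list accepts them after pbc, so the build output can be unpacked directly. The wrapper name cell_list_two_step is hypothetical, and whether cell_list performs exactly these calls internally is an assumption.

```python
def cell_list_two_step(positions, cutoff, cell, pbc, max_neighbors, max_total_cells):
    """Build a cell list, then query it (hypothetical wrapper around the documented calls)."""
    # Imported lazily so the sketch can be defined without the library installed.
    from nvalchemiops.jax.neighbors.cell_list import build_cell_list, query_cell_list

    # Seven arrays: cells_per_dimension, atom_periodic_shifts, atom_to_cell_mapping,
    # atoms_per_cell_count, cell_atom_start_indices, cell_atom_list, neighbor_search_radius.
    built = build_cell_list(positions, cutoff, cell, pbc, max_total_cells=max_total_cells)

    # Documented return: (neighbor_matrix, num_neighbors, neighbor_matrix_shifts).
    return query_cell_list(positions, cutoff, cell, pbc, *built, max_neighbors=max_neighbors)
```

Splitting the steps is mainly useful when the build output is to be inspected or kept around between queries.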

Parameters:
  • positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates in Cartesian space.

  • cutoff (float) – Cutoff distance for neighbor detection.

  • cell (jax.Array, shape (1, 3, 3), dtype=float32 or float64, optional) – Cell matrix defining lattice vectors. Default is identity matrix.

  • pbc (jax.Array, shape (3,) or (1, 3), dtype=bool, optional) – Periodic boundary condition flags. Default is all True.

  • max_neighbors (int, optional) – Maximum number of neighbors per atom. If None, will be estimated.

  • max_total_cells (int, optional) – Maximum number of cells to allocate. If None, will be estimated.

  • return_neighbor_list (bool, optional) – If True, convert result to COO neighbor list format. Default is False.

Returns:

  • neighbor_data (jax.Array) – If return_neighbor_list=False (default): neighbor_matrix with shape (total_atoms, max_neighbors), dtype int32. If return_neighbor_list=True: neighbor_list with shape (2, num_pairs), dtype int32, in COO format.

  • neighbor_count (jax.Array) – If return_neighbor_list=False: num_neighbors with shape (total_atoms,), dtype int32. If return_neighbor_list=True: neighbor_ptr with shape (total_atoms + 1,), dtype int32.

  • shift_data (jax.Array) – If return_neighbor_list=False: neighbor_matrix_shifts with shape (total_atoms, max_neighbors, 3), dtype int32. If return_neighbor_list=True: neighbor_list_shifts with shape (num_pairs, 3), dtype int32.

Return type:

tuple[Array, Array, Array]
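
Integer shift vectors are typically turned into Cartesian displacements by multiplying with the cell matrix. The NumPy sketch below does this for list-format output under two assumptions that should be checked against the library: lattice vectors are stored as rows of cell, and the shift applies to the target atom, giving r_j + S·cell − r_i. The helper name pair_vectors and the sample data are illustrative.

```python
import numpy as np


def pair_vectors(positions, cell, neighbor_list, unit_shifts):
    """Displacement of each pair under the assumed convention r_j + S @ cell - r_i."""
    src, dst = neighbor_list
    cartesian_shift = unit_shifts @ cell          # integer shifts to Cartesian offsets
    return positions[dst] + cartesian_shift - positions[src]


cell = np.eye(3) * 10.0
positions = np.array([[9.5, 0.0, 0.0], [0.5, 0.0, 0.0]])
neighbor_list = np.array([[0], [1]])              # single pair 0 -> 1
unit_shifts = np.array([[1, 0, 0]])               # neighbor image lies one cell along +x
vec = pair_vectors(positions, cell, neighbor_list, unit_shifts)
```

If distances computed this way exceed the cutoff, the sign or row/column convention assumed here is probably the opposite of the library's.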

See also

build_cell_list

Build cell list separately

query_cell_list

Query cell list separately

naive_neighbor_list

Naive O(N^2) method

nvalchemiops.jax.neighbors.cell_list.build_cell_list(positions, cutoff, cell, pbc, cells_per_dimension=None, neighbor_search_radius=None, atom_periodic_shifts=None, atom_to_cell_mapping=None, atoms_per_cell_count=None, cell_atom_start_indices=None, cell_atom_list=None, max_total_cells=None)

Build spatial cell list for efficient neighbor searching.

Parameters:
  • positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates in Cartesian space.

  • cutoff (float) – Cutoff distance for neighbor searching. Must be positive.

  • cell (jax.Array, shape (1, 3, 3), dtype=float32 or float64) – Cell matrix defining lattice vectors.

  • pbc (jax.Array, shape (3,) or (1, 3), dtype=bool) – Periodic boundary condition flags.

  • cells_per_dimension (jax.Array, shape (3,), dtype=int32, optional) – OUTPUT: Number of cells in x, y, z directions. If None, allocated.

  • neighbor_search_radius (jax.Array, shape (3,), dtype=int32, optional) – Search radius in neighboring cells. If None, allocated.

  • atom_periodic_shifts (jax.Array, shape (total_atoms, 3), dtype=int32, optional) – OUTPUT: Periodic boundary crossings for each atom. If None, allocated.

  • atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32, optional) – OUTPUT: 3D cell coordinates for each atom. If None, allocated.

  • atoms_per_cell_count (jax.Array, shape (max_total_cells,), dtype=int32, optional) – OUTPUT: Number of atoms in each cell. If None, allocated.

  • cell_atom_start_indices (jax.Array, shape (max_total_cells,), dtype=int32, optional) – OUTPUT: Starting index in cell_atom_list for each cell. If None, allocated.

  • cell_atom_list (jax.Array, shape (total_atoms,), dtype=int32, optional) – OUTPUT: Flattened list of atom indices organized by cell. If None, allocated.

  • max_total_cells (int, optional) – Maximum number of cells to allocate. If None, will be estimated.

Returns:

  • cells_per_dimension (jax.Array, shape (3,), dtype=int32) – Number of cells in x, y, z directions.

  • atom_periodic_shifts (jax.Array, shape (total_atoms, 3), dtype=int32) – Periodic boundary crossings for each atom.

  • atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32) – 3D cell coordinates for each atom.

  • atoms_per_cell_count (jax.Array, shape (max_total_cells,), dtype=int32) – Number of atoms in each cell.

  • cell_atom_start_indices (jax.Array, shape (max_total_cells,), dtype=int32) – Starting index in cell_atom_list for each cell.

  • cell_atom_list (jax.Array, shape (total_atoms,), dtype=int32) – Flattened list of atom indices organized by cell.

  • neighbor_search_radius (jax.Array, shape (3,), dtype=int32) – Search radius in neighboring cells.

Return type:

tuple[Array, Array, Array, Array, Array, Array, Array]
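
How atom_to_cell_mapping, atoms_per_cell_count, cell_atom_start_indices and cell_atom_list relate to each other can be shown with a NumPy sketch for a cubic box. This is a generic cell-binning illustration, not the library's kernel: the function name, the cubic-box restriction, and the x-major ordering of the flattened cell index are assumptions.

```python
import numpy as np


def build_cell_list_reference(positions, box_length, cells_per_dimension):
    """Bin atoms of a cubic box into cells and return the bookkeeping arrays."""
    cells_per_dimension = np.asarray(cells_per_dimension)
    cell_size = box_length / cells_per_dimension
    # 3D cell coordinates of each atom; the modulo wraps atoms outside the box.
    atom_to_cell = np.floor(positions / cell_size).astype(np.int32) % cells_per_dimension
    # Flattened cell id (x-major ordering is an assumption of this sketch).
    ny, nz = cells_per_dimension[1], cells_per_dimension[2]
    flat = (atom_to_cell[:, 0] * ny + atom_to_cell[:, 1]) * nz + atom_to_cell[:, 2]
    total_cells = int(np.prod(cells_per_dimension))
    atoms_per_cell_count = np.bincount(flat, minlength=total_cells).astype(np.int32)
    # Exclusive prefix sum: first slot of each cell inside cell_atom_list.
    cell_atom_start_indices = (
        np.cumsum(atoms_per_cell_count) - atoms_per_cell_count
    ).astype(np.int32)
    # Atom indices grouped by cell.
    cell_atom_list = np.argsort(flat, kind="stable").astype(np.int32)
    return atom_to_cell, atoms_per_cell_count, cell_atom_start_indices, cell_atom_list


pos = np.array([[3.0, 0.0, 0.0], [0.5, 0.0, 0.0], [3.5, 0.0, 0.0]])
mapping, counts, starts, atoms = build_cell_list_reference(pos, 4.0, (2, 1, 1))
```

The atoms of cell c are then cell_atom_list[starts[c] : starts[c] + counts[c]], which is the lookup a query step performs for each neighboring cell.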

Notes

When calling inside jax.jit, max_total_cells must be provided to avoid calling estimate_cell_list_sizes, which is not JIT-compatible.

See also

query_cell_list

Query the built cell list for neighbors

nvalchemiops.jax.neighbors.cell_list.query_cell_list(positions, cutoff, cell, pbc, cells_per_dimension, atom_periodic_shifts, atom_to_cell_mapping, atoms_per_cell_count, cell_atom_start_indices, cell_atom_list, neighbor_search_radius, max_neighbors=None, neighbor_matrix=None, neighbor_matrix_shifts=None, num_neighbors=None, rebuild_flags=None)

Query cell list to find neighbors within cutoff.

Parameters:
  • positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates in Cartesian space.

  • cutoff (float) – Cutoff distance for neighbor detection.

  • cell (jax.Array, shape (1, 3, 3), dtype=float32 or float64) – Cell matrix defining lattice vectors.

  • pbc (jax.Array, shape (3,) or (1, 3), dtype=bool) – Periodic boundary condition flags.

  • cells_per_dimension (jax.Array, shape (3,), dtype=int32) – Number of cells in each dimension.

  • atom_periodic_shifts (jax.Array, shape (total_atoms, 3), dtype=int32) – Periodic boundary crossings for each atom (output from build_cell_list).

  • atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32) – 3D cell coordinates for each atom.

  • atoms_per_cell_count (jax.Array, shape (max_total_cells,), dtype=int32) – Number of atoms in each cell (output from build_cell_list).

  • cell_atom_start_indices (jax.Array, shape (max_total_cells,), dtype=int32) – Starting index in cell_atom_list for each cell.

  • cell_atom_list (jax.Array, shape (total_atoms,), dtype=int32) – Flattened list of atom indices organized by cell.

  • neighbor_search_radius (jax.Array, shape (3,), dtype=int32) – Search radius in neighboring cells.

  • max_neighbors (int, optional) – Maximum number of neighbors per atom.

  • neighbor_matrix (jax.Array, optional) – Pre-allocated neighbor matrix.

  • num_neighbors (jax.Array, optional) – Pre-allocated neighbors count array.

  • neighbor_matrix_shifts (Array | None)

  • rebuild_flags (Array | None)

Returns:

  • neighbor_matrix (jax.Array, shape (total_atoms, max_neighbors), dtype=int32) – Neighbor matrix with neighbor atom indices.

  • num_neighbors (jax.Array, shape (total_atoms,), dtype=int32) – Number of neighbors found for each atom.

  • neighbor_matrix_shifts (jax.Array, shape (total_atoms, max_neighbors, 3), dtype=int32) – Periodic shift vectors for each neighbor relationship.

Return type:

tuple[Array, Array, Array]

See also

build_cell_list

Build cell list before querying

cell_list

Combined build and query operation

Dual Cutoff Algorithm

nvalchemiops.jax.neighbors.naive_neighbor_list_dual_cutoff(positions, cutoff1, cutoff2, pbc=None, cell=None, max_neighbors1=None, max_neighbors2=None, half_fill=False, fill_value=None, return_neighbor_list=False, neighbor_matrix1=None, neighbor_matrix2=None, neighbor_matrix_shifts1=None, neighbor_matrix_shifts2=None, num_neighbors1=None, num_neighbors2=None, shift_range_per_dimension=None, num_shifts_per_system=None, max_shifts_per_system=None, rebuild_flags=None, wrap_positions=True)

Compute neighbor lists for two cutoff distances using naive O(N^2) algorithm.

This function builds two neighbor matrices simultaneously for different cutoff distances, which is more efficient than calling the single-cutoff function twice.
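
The saving comes from sharing the pairwise distance computation between the two cutoffs. A minimal NumPy sketch of the idea for a non-periodic system follows; dual_cutoff_masks is a hypothetical name, and the strict `<` comparison and self-pair exclusion are assumptions.

```python
import numpy as np


def dual_cutoff_masks(positions, cutoff1, cutoff2):
    """Compute the distance matrix once and threshold it twice."""
    delta = positions[None, :, :] - positions[:, None, :]
    dist = np.linalg.norm(delta, axis=-1)              # shared O(N^2) work
    not_self = ~np.eye(len(positions), dtype=bool)     # drop i == j pairs
    within1 = (dist < cutoff1) & not_self
    within2 = (dist < cutoff2) & not_self
    return within1, within2


positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [4.0, 0.0, 0.0]])
m1, m2 = dual_cutoff_masks(positions, cutoff1=2.5, cutoff2=5.0)
```

When cutoff1 is not larger than cutoff2, every pair in the first set also appears in the second.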

Parameters:
  • positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates in Cartesian space.

  • cutoff1 (float) – First cutoff distance (typically smaller).

  • cutoff2 (float) – Second cutoff distance (typically larger).

  • pbc (jax.Array, shape (1, 3) or (3,), dtype=bool, optional) – Periodic boundary condition flags for each dimension.

  • cell (jax.Array, shape (1, 3, 3) or (3, 3), dtype=float32 or float64, optional) – Cell matrix defining lattice vectors in Cartesian coordinates.

  • max_neighbors1 (int, optional) – Maximum number of neighbors per atom for cutoff1.

  • max_neighbors2 (int, optional) – Maximum number of neighbors per atom for cutoff2.

  • half_fill (bool, optional - default = False) – If True, only store relationships where i < j to avoid double counting.

  • fill_value (int, optional) – Value to use for padding in neighbor matrices. Default is total_atoms.

  • return_neighbor_list (bool, optional - default = False) – If True, convert neighbor matrices to neighbor list (idx_i, idx_j) format.

  • neighbor_matrix1 (jax.Array, shape (total_atoms, max_neighbors1), dtype=int32, optional) – Pre-allocated first neighbor matrix.

  • neighbor_matrix2 (jax.Array, shape (total_atoms, max_neighbors2), dtype=int32, optional) – Pre-allocated second neighbor matrix.

  • neighbor_matrix_shifts1 (jax.Array, shape (total_atoms, max_neighbors1, 3), dtype=int32, optional) – Pre-allocated first shift matrix for PBC.

  • neighbor_matrix_shifts2 (jax.Array, shape (total_atoms, max_neighbors2, 3), dtype=int32, optional) – Pre-allocated second shift matrix for PBC.

  • num_neighbors1 (jax.Array, shape (total_atoms,), dtype=int32, optional) – Pre-allocated first neighbor count array.

  • num_neighbors2 (jax.Array, shape (total_atoms,), dtype=int32, optional) – Pre-allocated second neighbor count array.

  • shift_range_per_dimension (jax.Array, shape (1, 3), dtype=int32, optional) – Shift range in each dimension for the system. Pass in a pre-computed value to avoid recomputation for PBC systems.

  • num_shifts_per_system (jax.Array, shape (1,), dtype=int32, optional) – Number of periodic shifts for the system. Pass in a pre-computed value to avoid recomputation for PBC systems.

  • max_shifts_per_system (int, optional) – Maximum per-system shift count. Pass in a pre-computed value to avoid recomputation for PBC systems.

  • wrap_positions (bool, default=True) – If True, wrap input positions into the primary cell before neighbor search. Set to False when positions are already wrapped (e.g. by a preceding integration step) to save two GPU kernel launches per call.

  • rebuild_flags (Array | None)

Returns:

results – Variable-length tuple depending on input parameters:

  • No PBC, matrix format: (neighbor_matrix1, num_neighbors1, neighbor_matrix2, num_neighbors2)

  • No PBC, list format: (neighbor_list1, neighbor_ptr1, neighbor_list2, neighbor_ptr2)

  • With PBC, matrix format: (neighbor_matrix1, num_neighbors1, neighbor_matrix_shifts1, neighbor_matrix2, num_neighbors2, neighbor_matrix_shifts2)

  • With PBC, list format: (neighbor_list1, neighbor_ptr1, unit_shifts1, neighbor_list2, neighbor_ptr2, unit_shifts2)

Return type:

tuple of jax.Array

Batched Algorithms

Batched Naive Algorithm

nvalchemiops.jax.neighbors.batch_naive_neighbor_list(positions, cutoff, batch_idx=None, batch_ptr=None, pbc=None, cell=None, max_neighbors=None, half_fill=False, fill_value=None, return_neighbor_list=False, neighbor_matrix=None, neighbor_matrix_shifts=None, num_neighbors=None, shift_range_per_dimension=None, num_shifts_per_system=None, max_shifts_per_system=None, max_atoms_per_system=None, rebuild_flags=None, wrap_positions=True)

Compute neighbor list for batch of systems using naive O(N^2) algorithm.

Identifies all atom pairs within a specified cutoff distance for each system independently using a brute-force pairwise distance calculation. Supports both non-periodic and periodic boundary conditions.

Parameters:
  • positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Concatenated Cartesian coordinates for all systems.

  • cutoff (float) – Cutoff distance for neighbor detection in Cartesian units. Must be positive. Atoms within this distance are considered neighbors.

  • batch_idx (jax.Array, shape (total_atoms,), dtype=int32, optional) – System index for each atom. If None, batch_ptr must be provided.

  • batch_ptr (jax.Array, shape (num_systems + 1,), dtype=int32, optional) – Cumulative atom counts defining system boundaries. If None, batch_idx must be provided.

  • pbc (jax.Array, shape (num_systems, 3), dtype=bool, optional) – Periodic boundary condition flags for each system and dimension. True enables periodicity in that direction. Default is None (no PBC).

  • cell (jax.Array, shape (num_systems, 3, 3), dtype=float32 or float64, optional) – Cell matrices defining lattice vectors. Required if pbc is provided.

  • max_neighbors (int, optional) – Maximum number of neighbors per atom.

  • half_fill (bool, optional) – If True, only store relationships where i < j. Default is False.

  • fill_value (int, optional) – Value to fill the neighbor matrix with. Default is total_atoms.

  • neighbor_matrix (jax.Array, optional) – Pre-allocated neighbor matrix.

  • neighbor_matrix_shifts (jax.Array, optional) – Pre-allocated shift matrix for PBC.

  • num_neighbors (jax.Array, optional) – Pre-allocated neighbors count array.

  • shift_range_per_dimension (jax.Array, optional) – Pre-computed shift range for PBC systems.

  • num_shifts_per_system (jax.Array, optional) – Number of periodic shifts per system.

  • max_shifts_per_system (int, optional) – Maximum per-system shift count (launch dimension).

  • max_atoms_per_system (int, optional) – Maximum atoms in any system.

  • wrap_positions (bool, default=True) – If True, wrap input positions into the primary cell before neighbor search. Set to False when positions are already wrapped (e.g. by a preceding integration step) to save two GPU kernel launches per call.

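  • return_neighbor_list (bool, optional - default = False) – If True, convert the neighbor matrix to a neighbor list (idx_i, idx_j) format.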

  • rebuild_flags (Array | None)
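
batch_idx and batch_ptr describe the same partition of the concatenated atoms, so either can be derived from the other. The NumPy sketch below shows both directions and a host-side value for max_atoms_per_system. The helper names are hypothetical, and batch_idx is assumed to be sorted so that each system's atoms are contiguous.

```python
import numpy as np


def ptr_from_idx(batch_idx, num_systems):
    """Cumulative atom counts from per-atom system indices."""
    counts = np.bincount(batch_idx, minlength=num_systems)
    return np.concatenate([[0], np.cumsum(counts)]).astype(np.int32)


def idx_from_ptr(batch_ptr):
    """Per-atom system indices from cumulative atom counts."""
    counts = np.diff(batch_ptr)
    return np.repeat(np.arange(len(counts)), counts).astype(np.int32)


batch_ptr = np.array([0, 100, 200], dtype=np.int32)    # 2 systems of 100 atoms each
batch_idx = idx_from_ptr(batch_ptr)
max_atoms_per_system = int(np.diff(batch_ptr).max())   # computed once on the host
```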

Returns:

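results – Variable-length tuple depending on input parameters. In matrix format, as in the examples below: (neighbor_matrix, num_neighbors) without PBC, and (neighbor_matrix, num_neighbors, neighbor_matrix_shifts) with PBC.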

Return type:

tuple of jax.Array

Examples

Basic usage with batch_ptr:

>>> import jax.numpy as jnp
>>> from nvalchemiops.jax.neighbors import batch_naive_neighbor_list
>>> positions = jnp.zeros((200, 3), dtype=jnp.float32)
>>> batch_ptr = jnp.array([0, 100, 200], dtype=jnp.int32)  # 2 systems
>>> cutoff = 2.5
>>> max_neighbors = 50
>>> neighbor_matrix, num_neighbors = batch_naive_neighbor_list(
...     positions, cutoff, batch_ptr=batch_ptr, max_neighbors=max_neighbors
... )

With PBC:

>>> cell = jnp.eye(3, dtype=jnp.float32)[jnp.newaxis, :, :] * 10.0
>>> cell = jnp.repeat(cell, 2, axis=0)
>>> pbc = jnp.ones((2, 3), dtype=jnp.bool_)
>>> neighbor_matrix, num_neighbors, shifts = batch_naive_neighbor_list(
...     positions, cutoff, batch_ptr=batch_ptr, max_neighbors=max_neighbors,
...     pbc=pbc, cell=cell
... )

See also

nvalchemiops.neighbors.batch_naive.batch_naive_neighbor_matrix

Core Warp launcher

nvalchemiops.jax.neighbors.naive.naive_neighbor_list

Non-batched version

batch_cell_list

Cell list method for large systems

Batched Cell List Algorithm#

nvalchemiops.jax.neighbors.batch_cell_list(positions, cutoff, cell=None, pbc=None, batch_idx=None, batch_ptr=None, max_neighbors=None, max_total_cells=None, neighbor_matrix_shifts=None, return_neighbor_list=False)[source]#

Build and query spatial cell lists for a batch of systems.

Parameters:
  • positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates.

  • cutoff (float) – Cutoff distance for neighbor detection.

  • cell (jax.Array, shape (num_systems, 3, 3), dtype=float32 or float64, optional) – Cell matrices defining lattice vectors. Default is identity matrix.

  • pbc (jax.Array, shape (num_systems, 3), dtype=bool, optional) – Periodic boundary condition flags. Default is all True.

  • batch_idx (jax.Array, shape (total_atoms,), dtype=int32, optional) – Batch indices for each atom.

  • batch_ptr (jax.Array, shape (num_systems + 1,), dtype=int32, optional) – Cumulative atom counts defining system boundaries.

  • max_neighbors (int, optional) – Maximum number of neighbors per atom. If None, will be estimated.

  • max_total_cells (int, optional) – Maximum number of cells to allocate. If None, will be estimated.

  • neighbor_matrix_shifts (jax.Array, shape (total_atoms, max_neighbors, 3), dtype=int32, optional) – Pre-allocated shift vectors array. If None, will be allocated internally. Pass in a pre-shaped array to hint buffer reuse to XLA; note that JAX returns a new array rather than mutating the input.

  • return_neighbor_list (bool, optional) – If True, convert result to COO neighbor list format. Default is False.

Returns:

  • neighbor_data (jax.Array) – If return_neighbor_list=False (default): neighbor_matrix with shape (total_atoms, max_neighbors), dtype int32. If return_neighbor_list=True: neighbor_list with shape (2, num_pairs), dtype int32, in COO format.

  • neighbor_count (jax.Array) – If return_neighbor_list=False: num_neighbors with shape (total_atoms,), dtype int32. If return_neighbor_list=True: neighbor_ptr with shape (total_atoms + 1,), dtype int32.

  • shift_data (jax.Array) – If return_neighbor_list=False (default): neighbor_matrix_shifts with shape (total_atoms, max_neighbors, 3), dtype int32. If return_neighbor_list=True: neighbor_list_shifts with shape (num_pairs, 3), dtype int32. Periodic shift vectors for each neighbor relationship.

Return type:

tuple[Array, Array] | tuple[Array, Array, tuple]

See also

batch_build_cell_list

Build cell list separately

batch_query_cell_list

Query cell list separately

batch_naive_neighbor_list

Naive O(N^2) method

nvalchemiops.jax.neighbors.batch_cell_list.batch_build_cell_list(positions, batch_idx=None, batch_ptr=None, cell=None, pbc=None, cutoff=5.0, max_total_cells=None)[source]#

Build spatial cell lists for a batch of systems.

Parameters:
  • positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates.

  • batch_idx (jax.Array, shape (total_atoms,), dtype=int32, optional) – Batch indices.

  • batch_ptr (jax.Array, shape (num_systems + 1,), dtype=int32, optional) – Cumulative atom counts.

  • cell (jax.Array, shape (num_systems, 3, 3), dtype=float32 or float64, optional) – Cell matrices.

  • pbc (jax.Array, shape (num_systems, 3), dtype=bool, optional) – PBC flags.

  • cutoff (float, optional) – Cutoff distance. Default is 5.0.

  • max_total_cells (int, optional) – Maximum cells. If None, will be estimated.

Returns:

  • cells_per_dimension (jax.Array, shape (num_systems, 3), dtype=int32) – Number of cells in x, y, z directions for each system.

  • atom_periodic_shifts (jax.Array, shape (total_atoms, 3), dtype=int32) – Periodic boundary crossings for each atom.

  • atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32) – 3D cell coordinates for each atom.

  • atoms_per_cell_count (jax.Array, shape (max_total_cells,), dtype=int32) – Number of atoms in each cell.

  • cell_atom_start_indices (jax.Array, shape (max_total_cells,), dtype=int32) – Starting index in cell_atom_list for each cell.

  • cell_atom_list (jax.Array, shape (total_atoms,), dtype=int32) – Flattened list of atom indices organized by cell.

  • neighbor_search_radius (jax.Array, shape (num_systems, 3), dtype=int32) – Search radius in neighboring cells for each system.

  • cell_origin (jax.Array, shape (3,), dtype same as positions) – Cell origin point (currently zeros).

Return type:

tuple[Array, Array, Array, Array, Array, Array, Array, Array]

Notes

When calling inside jax.jit, max_total_cells must be provided to avoid calling estimate_batch_cell_list_sizes, which is not JIT-compatible.
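To make the returned arrays concrete, here is a minimal NumPy reference sketch for a single, fully periodic system. It is not the library's kernel: the function name build_cell_list_reference, the row-vector lattice convention, and the x-fastest linear cell ordering are all assumptions made for illustration.

```python
import numpy as np


def build_cell_list_reference(positions, cell, cells_per_dimension):
    """Bin atoms of one fully periodic system into cells (illustrative only)."""
    cpd = np.asarray(cells_per_dimension, dtype=np.int32)
    # Fractional coordinates; assumes lattice vectors are the rows of `cell`.
    frac = np.asarray(positions) @ np.linalg.inv(cell)
    # Whole-cell boundary crossings, then coordinates wrapped into [0, 1).
    atom_periodic_shifts = np.floor(frac).astype(np.int32)
    wrapped = frac - atom_periodic_shifts
    # 3D cell coordinates per atom, clamped against round-off at the upper edge.
    atom_to_cell_mapping = np.minimum((wrapped * cpd).astype(np.int32), cpd - 1)
    # Linear cell index (x fastest is an arbitrary choice for this sketch).
    linear = atom_to_cell_mapping[:, 0] + cpd[0] * (
        atom_to_cell_mapping[:, 1] + cpd[1] * atom_to_cell_mapping[:, 2]
    )
    atoms_per_cell_count = np.bincount(linear, minlength=int(cpd.prod())).astype(np.int32)
    # Exclusive prefix sum: first slot of each cell inside cell_atom_list.
    cell_atom_start_indices = (
        np.cumsum(atoms_per_cell_count) - atoms_per_cell_count
    ).astype(np.int32)
    # A stable sort groups atom indices by cell while keeping their order.
    cell_atom_list = np.argsort(linear, kind="stable").astype(np.int32)
    return (
        atom_periodic_shifts,
        atom_to_cell_mapping,
        atoms_per_cell_count,
        cell_atom_start_indices,
        cell_atom_list,
    )
```

Every array shape here follows from total_atoms and the total cell count, which is why a fixed max_total_cells is what lets the real function run under jax.jit.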

nvalchemiops.jax.neighbors.batch_cell_list.batch_query_cell_list(positions, batch_idx=None, batch_ptr=None, cutoff=5.0, cell=None, pbc=None, cells_per_dimension=None, atom_periodic_shifts=None, atom_to_cell_mapping=None, cell_atom_start_indices=None, cell_atom_list=None, atoms_per_cell_count=None, neighbor_search_radius=None, max_neighbors=None, neighbor_matrix=None, num_neighbors=None, neighbor_matrix_shifts=None, rebuild_flags=None)[source]#

Query batch cell lists to find neighbors.

Parameters:
  • positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates.

  • batch_idx (jax.Array, shape (total_atoms,), dtype=int32, optional) – Batch indices.

  • batch_ptr (jax.Array, shape (num_systems + 1,), dtype=int32, optional) – Cumulative atom counts.

  • cutoff (float, optional) – Cutoff distance.

  • cell (jax.Array, shape (num_systems, 3, 3), dtype=float32 or float64, optional) – Cell matrices.

  • pbc (jax.Array, shape (num_systems, 3), dtype=bool, optional) – PBC flags.

  • cells_per_dimension (jax.Array, shape (num_systems, 3), dtype=int32, optional) – Cells per dimension.

  • atom_periodic_shifts (jax.Array, shape (total_atoms, 3), dtype=int32, optional) – Periodic shifts for each atom (output from batch_build_cell_list).

  • atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32, optional) – Cell mappings.

  • cell_atom_start_indices (jax.Array, shape (max_total_cells,), dtype=int32, optional) – Start indices.

  • cell_atom_list (jax.Array, shape (total_atoms,), dtype=int32, optional) – Cell atom list.

  • atoms_per_cell_count (jax.Array, shape (max_total_cells,), dtype=int32, optional) – Number of atoms assigned to each cell. Output from batch_build_cell_list.

  • neighbor_search_radius (jax.Array, shape (num_systems, 3), dtype=int32, optional) – Search radius.

  • max_neighbors (int, optional) – Maximum neighbors per atom.

  • neighbor_matrix (jax.Array, shape (total_atoms, max_neighbors), dtype=int32, optional) – Pre-allocated neighbor matrix.

  • num_neighbors (jax.Array, shape (total_atoms,), dtype=int32, optional) – Pre-allocated neighbors count array.

  • neighbor_matrix_shifts (jax.Array, shape (total_atoms, max_neighbors, 3), dtype=int32, optional) – Pre-allocated shift vectors array. Pass in a pre-shaped array to hint buffer reuse to XLA; note that JAX returns a new array rather than mutating the input.

  • rebuild_flags (Array | None)

Returns:

  • neighbor_matrix (jax.Array, shape (total_atoms, max_neighbors), dtype=int32) – Neighbor matrix.

  • num_neighbors (jax.Array, shape (total_atoms,), dtype=int32) – Neighbors count.

  • neighbor_matrix_shifts (jax.Array, shape (total_atoms, max_neighbors, 3), dtype=int32) – Periodic shifts for each neighbor relationship.

Return type:

tuple[Array, Array, Array]

Batched Dual Cutoff Algorithm#

nvalchemiops.jax.neighbors.batch_naive_neighbor_list_dual_cutoff(positions, cutoff1, cutoff2, batch_idx=None, batch_ptr=None, pbc=None, cell=None, max_neighbors1=None, max_neighbors2=None, half_fill=False, fill_value=None, return_neighbor_list=False, neighbor_matrix1=None, neighbor_matrix2=None, neighbor_matrix_shifts1=None, neighbor_matrix_shifts2=None, num_neighbors1=None, num_neighbors2=None, shift_range_per_dimension=None, num_shifts_per_system=None, max_shifts_per_system=None, max_atoms_per_system=None, rebuild_flags=None, wrap_positions=True)[source]#

Compute batched neighbor lists for two cutoff distances using the naive O(N^2) algorithm.

This function builds two neighbor matrices simultaneously for different cutoff distances in a batched manner, which is more efficient than calling the single-cutoff function twice.

Parameters:
  • positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Concatenated Cartesian coordinates for all systems.

  • cutoff1 (float) – First cutoff distance (typically smaller).

  • cutoff2 (float) – Second cutoff distance (typically larger).

  • batch_idx (jax.Array, shape (total_atoms,), dtype=int32, optional) – System index for each atom.

  • batch_ptr (jax.Array, shape (num_systems + 1,), dtype=int32, optional) – Cumulative atom counts defining system boundaries.

  • pbc (jax.Array, shape (num_systems, 3) or (1, 3), dtype=bool, optional) – Periodic boundary condition flags for each dimension.

  • cell (jax.Array, shape (num_systems, 3, 3) or (1, 3, 3), dtype=float32 or float64, optional) – Cell matrices defining lattice vectors in Cartesian coordinates.

  • max_neighbors1 (int, optional) – Maximum number of neighbors per atom for cutoff1.

  • max_neighbors2 (int, optional) – Maximum number of neighbors per atom for cutoff2.

  • half_fill (bool, default=False) – If True, only store relationships where i < j to avoid double counting.

  • fill_value (int, optional) – Value to use for padding in neighbor matrices. Default is total_atoms.

  • return_neighbor_list (bool, default=False) – If True, convert neighbor matrices to neighbor list (idx_i, idx_j) format.

  • neighbor_matrix1 (jax.Array, shape (total_atoms, max_neighbors1), dtype=int32, optional) – Pre-allocated first neighbor matrix.

  • neighbor_matrix2 (jax.Array, shape (total_atoms, max_neighbors2), dtype=int32, optional) – Pre-allocated second neighbor matrix.

  • neighbor_matrix_shifts1 (jax.Array, shape (total_atoms, max_neighbors1, 3), dtype=int32, optional) – Pre-allocated first shift matrix for PBC.

  • neighbor_matrix_shifts2 (jax.Array, shape (total_atoms, max_neighbors2, 3), dtype=int32, optional) – Pre-allocated second shift matrix for PBC.

  • num_neighbors1 (jax.Array, shape (total_atoms,), dtype=int32, optional) – Pre-allocated first neighbor count array.

  • num_neighbors2 (jax.Array, shape (total_atoms,), dtype=int32, optional) – Pre-allocated second neighbor count array.

  • shift_range_per_dimension (jax.Array, shape (num_systems, 3), dtype=int32, optional) – Pre-computed shift ranges for PBC.

  • num_shifts_per_system (jax.Array, shape (num_systems,), dtype=int32, optional) – Number of periodic shifts per system.

  • max_shifts_per_system (int, optional) – Maximum per-system shift count (launch dimension).

  • max_atoms_per_system (int, optional) – Maximum number of atoms in any system (for PBC batched dispatch).

  • wrap_positions (bool, default=True) – If True, wrap input positions into the primary cell before neighbor search. Set to False when positions are already wrapped (e.g. by a preceding integration step) to save two GPU kernel launches per call.

  • rebuild_flags (Array | None)

Returns:

results – Variable-length tuple depending on input parameters:

  • No PBC, matrix format: (neighbor_matrix1, num_neighbors1, neighbor_matrix2, num_neighbors2)

  • No PBC, list format: (neighbor_list1, neighbor_ptr1, neighbor_list2, neighbor_ptr2)

  • With PBC, matrix format: (neighbor_matrix1, num_neighbors1, neighbor_matrix_shifts1, neighbor_matrix2, num_neighbors2, neighbor_matrix_shifts2)

  • With PBC, list format: (neighbor_list1, neighbor_ptr1, unit_shifts1, neighbor_list2, neighbor_ptr2, unit_shifts2)

Return type:

tuple of jax.Array
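The benefit of a dual-cutoff pass is that pair distances are computed once and thresholded twice. The NumPy sketch below illustrates this for a single system without PBC. It is a reference for the semantics only, not the library's implementation: the function name dual_cutoff_reference and the strict "less than" comparison are assumptions. Padding uses total_atoms, matching the documented default fill_value.

```python
import numpy as np


def dual_cutoff_reference(positions, cutoff1, cutoff2, max_neighbors1, max_neighbors2):
    """Build two neighbor matrices from one set of pair distances (no PBC)."""
    positions = np.asarray(positions, dtype=np.float64)
    total_atoms = positions.shape[0]
    fill_value = total_atoms  # documented default padding value
    # Squared pair distances, computed once and shared by both cutoffs.
    diff = positions[:, None, :] - positions[None, :, :]
    dist2 = np.sum(diff * diff, axis=-1)
    np.fill_diagonal(dist2, np.inf)  # an atom is not its own neighbor

    def to_matrix(cutoff, max_neighbors):
        mask = dist2 < cutoff * cutoff  # strict comparison is an assumption
        num_neighbors = mask.sum(axis=1).astype(np.int32)
        if num_neighbors.max(initial=0) > max_neighbors:
            raise ValueError("max_neighbors is too small for this cutoff")
        matrix = np.full((total_atoms, max_neighbors), fill_value, dtype=np.int32)
        for i in range(total_atoms):
            matrix[i, : num_neighbors[i]] = np.nonzero(mask[i])[0]
        return matrix, num_neighbors

    matrix1, num1 = to_matrix(cutoff1, max_neighbors1)
    matrix2, num2 = to_matrix(cutoff2, max_neighbors2)
    # Same ordering as the documented "No PBC, matrix format" tuple.
    return matrix1, num1, matrix2, num2
```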

Rebuild Detection#

nvalchemiops.jax.neighbors.rebuild_detection.cell_list_needs_rebuild(current_positions, atom_to_cell_mapping, cells_per_dimension, cell, pbc)[source]#

Detect if spatial cell list requires rebuilding due to atomic motion.

Parameters:
  • current_positions (jax.Array, shape (total_atoms, 3)) – Current atomic coordinates in Cartesian space.

  • atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32) – 3D cell coordinates for each atom from the existing cell list.

  • cells_per_dimension (jax.Array, shape (3,), dtype=int32) – Number of spatial cells in x, y, z directions.

  • cell (jax.Array, shape (1, 3, 3)) – Unit cell matrix for coordinate transformations.

  • pbc (jax.Array, shape (3,), dtype=bool) – Periodic boundary condition flags for x, y, z directions.

Returns:

rebuild_needed – True if any atom has moved to a different cell requiring rebuild.

Return type:

jax.Array, shape (1,), dtype=bool

Notes

This function is not differentiable and should not be used in JAX transformations that require gradients.

See also

nvalchemiops.neighbors.rebuild_detection.check_cell_list_rebuild

Core warp launcher

check_cell_list_rebuild_needed

Convenience wrapper that returns Python bool

nvalchemiops.jax.neighbors.rebuild_detection.neighbor_list_needs_rebuild(reference_positions, current_positions, skin_distance_threshold, cell=None, cell_inv=None, pbc=None)[source]#

Detect if neighbor list requires rebuilding due to excessive atomic motion.

When cell, cell_inv and pbc are all provided, uses the minimum-image convention (MIC) so atoms crossing periodic boundaries are not spuriously flagged.

Parameters:
  • reference_positions (jax.Array, shape (total_atoms, 3)) – Atomic positions when the neighbor list was last built.

  • current_positions (jax.Array, shape (total_atoms, 3)) – Current atomic positions to compare against reference.

  • skin_distance_threshold (float) – Maximum allowed displacement before neighbor list becomes invalid.

  • cell (jax.Array or None, optional) – Unit cell matrix, shape (1, 3, 3).

  • cell_inv (jax.Array or None, optional) – Inverse cell matrix, same shape as cell.

  • pbc (jax.Array or None, optional) – PBC flags, shape (3,), dtype=bool.

Returns:

rebuild_needed – True if any atom has moved beyond skin distance.

Return type:

jax.Array, shape (1,), dtype=bool

Notes

This function is not differentiable and should not be used in JAX transformations that require gradients.

See also

nvalchemiops.neighbors.rebuild_detection.check_neighbor_list_rebuild

Core warp launcher

check_neighbor_list_rebuild_needed

Convenience wrapper that returns Python bool
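The effect of the minimum-image convention on the skin check can be sketched in NumPy as follows. This is an illustrative reference, not the library's kernel. Its assumptions: the function name needs_rebuild_reference, a single (3, 3) cell with lattice vectors as rows (the documented shape is (1, 3, 3)), and MIC applied by rounding fractional displacements, which is adequate for the small displacements a skin check deals with.

```python
import numpy as np


def needs_rebuild_reference(reference_positions, current_positions,
                            skin_distance_threshold,
                            cell=None, cell_inv=None, pbc=None):
    """Return True if any atom moved farther than the skin distance."""
    disp = np.asarray(current_positions) - np.asarray(reference_positions)
    if cell is not None and cell_inv is not None and pbc is not None:
        # Convert to fractional displacements (rows of `cell` are lattice vectors).
        frac = disp @ cell_inv
        # Remove whole-cell jumps along periodic directions only.
        frac = frac - np.where(pbc, np.round(frac), 0.0)
        disp = frac @ cell
    dist2 = np.sum(disp * disp, axis=1)
    return bool(np.any(dist2 > skin_distance_threshold ** 2))
```

For example, in a 10-unit cubic cell an atom that moves from x = 9.9 to x = 0.1 after wrapping has a raw displacement of 9.8 but a minimum-image displacement of 0.2, so only the raw comparison would trigger a rebuild for a skin of 0.5.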

nvalchemiops.jax.neighbors.rebuild_detection.check_cell_list_rebuild_needed(current_positions, atom_to_cell_mapping, cells_per_dimension, cell, pbc)[source]#

Determine if spatial cell list requires rebuilding based on atomic motion.

This high-level convenience function determines if a spatial cell list needs to be reconstructed due to atomic movement. It uses GPU acceleration to efficiently detect when atoms have moved between spatial cells.

Parameters:
  • current_positions (jax.Array, shape (total_atoms, 3)) – Current atomic coordinates to check against existing cell assignments.

  • atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32) – 3D cell coordinates assigned to each atom from existing cell list.

  • cells_per_dimension (jax.Array, shape (3,), dtype=int32) – Number of spatial cells in x, y, z directions from existing cell list.

  • cell (jax.Array, shape (1, 3, 3)) – Current unit cell matrix for coordinate transformations.

  • pbc (jax.Array, shape (3,), dtype=bool) – Current periodic boundary condition flags for x, y, z directions.

Returns:

needs_rebuild – True if any atom has moved to a different cell requiring cell list rebuild.

Return type:

bool

Notes

This function is not differentiable and should not be used in JAX transformations that require gradients.

See also

cell_list_needs_rebuild

Returns jax.Array instead of bool

nvalchemiops.jax.neighbors.rebuild_detection.check_neighbor_list_rebuild_needed(reference_positions, current_positions, skin_distance_threshold, cell=None, cell_inv=None, pbc=None)[source]#

Determine if neighbor list requires rebuilding based on atomic motion.

When cell, cell_inv and pbc are all provided, uses MIC displacement so periodic boundary crossings are handled correctly.

Parameters:
  • reference_positions (jax.Array, shape (total_atoms, 3)) – Atomic coordinates when the neighbor list was last constructed.

  • current_positions (jax.Array, shape (total_atoms, 3)) – Current atomic coordinates to compare against reference positions.

  • skin_distance_threshold (float) – Maximum allowed atomic displacement before neighbor list becomes invalid.

  • cell (jax.Array or None, optional) – Unit cell matrix, shape (1, 3, 3).

  • cell_inv (jax.Array or None, optional) – Inverse cell matrix, same shape as cell.

  • pbc (jax.Array or None, optional) – PBC flags, shape (3,), dtype=bool.

Returns:

needs_rebuild – True if any atom has moved beyond skin distance requiring rebuild.

Return type:

bool

See also

neighbor_list_needs_rebuild

Returns jax.Array instead of bool

Exceptions#

exception nvalchemiops.jax.neighbors.NeighborOverflowError(max_neighbors, num_neighbors)[source]#

Bases: Exception

Exception raised when the number of neighbors exceeds the maximum allowed.

This error indicates that the pre-allocated neighbor matrix is too small to hold all discovered neighbors. Increase the max_neighbors parameter or pass a larger pre-allocated array.

Parameters:
  • max_neighbors (int) – The maximum number of neighbors the matrix can hold.

  • num_neighbors (int) – The actual number of neighbors found.
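One way to handle an overflow is to grow max_neighbors and retry. The pure-Python sketch below shows that pattern with stand-ins so it runs without the library. NeighborOverflowStandIn only mirrors the documented (max_neighbors, num_neighbors) constructor arguments, and storing them as attributes is an assumption. build_with_retry and fake_build are hypothetical helpers, not part of the API.

```python
class NeighborOverflowStandIn(Exception):
    """Local stand-in mirroring the documented constructor arguments."""

    def __init__(self, max_neighbors, num_neighbors):
        super().__init__(
            f"found {num_neighbors} neighbors but max_neighbors is {max_neighbors}"
        )
        # Keeping the values as attributes is an assumption of this sketch.
        self.max_neighbors = max_neighbors
        self.num_neighbors = num_neighbors


def build_with_retry(build, max_neighbors, max_attempts=4):
    """Call build(max_neighbors), widening the matrix when it overflows."""
    for _ in range(max_attempts):
        try:
            return build(max_neighbors), max_neighbors
        except NeighborOverflowStandIn as err:
            # Round the required width up to a multiple of 16.
            max_neighbors = -(-err.num_neighbors // 16) * 16
    raise RuntimeError("neighbor matrix still overflows after retries")


def fake_build(max_neighbors, required=40):
    """Hypothetical builder that needs room for `required` neighbors."""
    if required > max_neighbors:
        raise NeighborOverflowStandIn(max_neighbors, required)
    return "ok"
```

Because max_neighbors determines an array shape, each new value would trigger a fresh compilation under jax.jit, so growing in coarse steps keeps the number of retraces small.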

Utility Functions#

Warning

The estimation and cell list building utilities each work on their own. However, the estimators return allocation sizes that depend on the input data, so a workflow that combines estimation and cell list construction cannot be compiled with jax.jit. To jax.jit a workflow end to end, pass max_total_cells explicitly to the cell list construction functions.

nvalchemiops.jax.neighbors.estimate_cell_list_sizes(positions, cell, cutoff, pbc=None, buffer_factor=1.5)[source]#

Estimate required cell list sizes based on atomic density.

Parameters:
  • positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates in Cartesian space.

  • cell (jax.Array, shape (1, 3, 3), dtype=float32 or float64) – Cell matrix defining lattice vectors.

  • cutoff (float) – Cutoff distance for neighbor searching.

  • pbc (jax.Array, shape (3,) or (1, 3), dtype=bool, optional) – Periodic boundary condition flags. Default is all True.

  • buffer_factor (float, optional) – Buffer multiplier for cell count estimation. Default is 1.5.

Returns:

  • max_total_cells (int) – Maximum total number of cells to allocate.

  • cells_per_dimension (jax.Array, shape (3,) or (1, 3), dtype=int32) – Estimated number of cells in each dimension.

  • neighbor_search_radius (jax.Array, shape (3,), dtype=int32) – Estimated search radius in neighboring cells.

Return type:

tuple[int, Array, Array]

Notes

This function estimates cell list parameters based on atomic positions and density. The actual number of cells used will be determined during cell list construction.

Warning

This function is not compatible with jax.jit. The returned max_total_cells is used to determine array allocation sizes, which must be concrete (statically known) at JAX trace time. When using cell_list or build_cell_list inside jax.jit, provide max_total_cells explicitly to bypass this function.

nvalchemiops.jax.neighbors.estimate_batch_cell_list_sizes(positions, batch_ptr=None, batch_idx=None, cell=None, cutoff=5.0, pbc=None, buffer_factor=1.5)[source]#

Estimate required batch cell list sizes.

Parameters:
  • positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates.

  • batch_ptr (jax.Array, shape (num_systems + 1,), dtype=int32, optional) – Cumulative atom counts.

  • batch_idx (jax.Array, shape (total_atoms,), dtype=int32, optional) – Batch indices for each atom.

  • cell (jax.Array, shape (num_systems, 3, 3), dtype=float32 or float64, optional) – Cell matrices for each system.

  • cutoff (float, optional) – Cutoff distance. Default is 5.0.

  • pbc (jax.Array, shape (num_systems, 3), dtype=bool, optional) – PBC flags.

  • buffer_factor (float, optional) – Buffer multiplier. Default is 1.5.

Returns:

  • max_total_cells (int) – Maximum total cells to allocate.

  • cells_per_dimension (jax.Array, shape (num_systems, 3)) – Cells per dimension for each system.

  • neighbor_search_radius (jax.Array, shape (num_systems, 3)) – Search radius for each system.

Return type:

tuple[int, Array, Array]

Warning

This function is not compatible with jax.jit. The returned max_total_cells is used to determine array allocation sizes, which must be concrete (statically known) at JAX trace time. When using batch_cell_list or batch_build_cell_list inside jax.jit, provide max_total_cells explicitly to bypass this function.

nvalchemiops.jax.neighbors.neighbor_utils.allocate_cell_list(total_atoms, max_total_cells, neighbor_search_radius)[source]#

Allocate memory tensors for cell list data structures.

Parameters:
  • total_atoms (int) – Total number of atoms across all systems.

  • max_total_cells (int) – Maximum number of cells to allocate.

  • neighbor_search_radius (jax.Array, shape (3,) or (num_systems, 3), dtype=int32) – Radius of neighboring cells to search in each dimension.

Returns:

  • cells_per_dimension (jax.Array, shape (3,) or (num_systems, 3), dtype=int32) – Number of cells in x, y, z directions (to be filled by build_cell_list).

  • neighbor_search_radius (jax.Array, shape (3,) or (num_systems, 3), dtype=int32) – Radius of neighboring cells to search (passed through for convenience).

  • atom_periodic_shifts (jax.Array, shape (total_atoms, 3), dtype=int32) – Periodic boundary crossings for each atom (to be filled by build_cell_list).

  • atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32) – 3D cell coordinates for each atom (to be filled by build_cell_list).

  • atoms_per_cell_count (jax.Array, shape (max_total_cells,), dtype=int32) – Number of atoms in each cell (to be filled by build_cell_list).

  • cell_atom_start_indices (jax.Array, shape (max_total_cells,), dtype=int32) – Starting index in cell_atom_list for each cell (to be filled by build_cell_list).

  • cell_atom_list (jax.Array, shape (total_atoms,), dtype=int32) – Flattened list of atom indices organized by cell (to be filled by build_cell_list).

Return type:

tuple[Array, Array, Array, Array, Array, Array, Array]

Notes

This is a pure JAX utility function with no warp dependencies. It pre-allocates all tensors needed for cell list construction, supporting both single-system and batched operations based on the shape of neighbor_search_radius.

nvalchemiops.jax.neighbors.neighbor_utils.prepare_batch_idx_ptr(batch_idx, batch_ptr, num_atoms)[source]#

Prepare batch index and pointer tensors from either representation.

Utility function to ensure both batch_idx and batch_ptr are available, computing one from the other if needed.

Parameters:
  • batch_idx (jax.Array | None, shape (total_atoms,), dtype=int32) – Array indicating the batch index for each atom.

  • batch_ptr (jax.Array | None, shape (num_systems + 1,), dtype=int32) – Array indicating the start index of each batch in the atom list.

  • num_atoms (int) – Total number of atoms across all systems.

Returns:

  • batch_idx (jax.Array, shape (total_atoms,), dtype=int32) – Prepared batch index tensor.

  • batch_ptr (jax.Array, shape (num_systems + 1,), dtype=int32) – Prepared batch pointer tensor.

Raises:

ValueError – If both batch_idx and batch_ptr are None.

Return type:

tuple[Array, Array]

Notes

This is a pure JAX utility function with no warp dependencies. It provides convenience for batch operations by converting between dense (batch_idx) and sparse (batch_ptr) batch representations.

See also

nvalchemiops.jax.neighbors.batch_naive.batch_naive_neighbor_list

Uses this for batch setup

nvalchemiops.jax.neighbors.batch_cell_list.batch_cell_list

Uses this for batch setup
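The two batch representations carry the same information when atoms of each system are stored contiguously. A minimal NumPy sketch of the conversion in both directions follows. The helper names ptr_to_idx and idx_to_ptr are hypothetical, and the code illustrates the relationship rather than the library's implementation.

```python
import numpy as np


def ptr_to_idx(batch_ptr):
    """Expand cumulative atom counts into a per-atom system index."""
    counts = np.diff(batch_ptr)  # atoms per system
    return np.repeat(np.arange(counts.size, dtype=np.int32), counts)


def idx_to_ptr(batch_idx, num_systems=0):
    """Compress a sorted per-atom system index into cumulative counts."""
    counts = np.bincount(batch_idx, minlength=num_systems)
    return np.concatenate([[0], np.cumsum(counts)]).astype(np.int32)
```

For two systems with 3 and 2 atoms, batch_ptr is [0, 3, 5] and batch_idx is [0, 0, 0, 1, 1].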

nvalchemiops.jax.neighbors.neighbor_utils.estimate_max_neighbors(cutoff, atomic_density=0.2, safety_factor=1.0)[source]#

Estimate maximum neighbors per atom based on volume calculations.

Uses atomic density and cutoff volume to estimate a conservative upper bound on the number of neighbors any atom could have. This is a pure Python function with no framework dependencies.

Parameters:
  • cutoff (float) – Maximum distance for considering atoms as neighbors.

  • atomic_density (float, optional) – Atomic density in atoms per unit volume. Default is 0.2.

  • safety_factor (float) – Safety factor to multiply the estimated number of neighbors. Default is 1.0.

Returns:

max_neighbors_estimate – Conservative estimate of maximum neighbors per atom. Returns 0 for empty systems.

Return type:

int

Notes

The estimation uses the formula:

\[\text{neighbors} = \text{safety\_factor} \times \text{density} \times V_{\text{sphere}}\]

where the cutoff sphere volume is:

\[V_{\text{sphere}} = \frac{4}{3}\pi r^3\]

The result is rounded up to the next multiple of 16 for memory alignment.
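As a worked instance of the formula, the pure-Python sketch below reproduces the arithmetic. The function name estimate_max_neighbors_reference is hypothetical, and the library's exact rounding behavior may differ in edge cases.

```python
import math


def estimate_max_neighbors_reference(cutoff, atomic_density=0.2, safety_factor=1.0):
    """Density times cutoff-sphere volume, rounded up to a multiple of 16."""
    sphere_volume = 4.0 / 3.0 * math.pi * cutoff ** 3
    estimate = safety_factor * atomic_density * sphere_volume
    return int(math.ceil(estimate / 16.0)) * 16
```

With cutoff = 5.0 and the default density of 0.2, the sphere volume is about 523.6, the raw estimate is about 104.7, and rounding up to a multiple of 16 gives 112.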

nvalchemiops.jax.neighbors.neighbor_utils.get_neighbor_list_from_neighbor_matrix(neighbor_matrix, num_neighbors, neighbor_shift_matrix=None, fill_value=-1)[source]#

Convert neighbor matrix format to neighbor list format.

Parameters:
  • neighbor_matrix (jax.Array, shape (total_atoms, max_neighbors), dtype=int32) – The neighbor matrix with neighbor atom indices.

  • num_neighbors (jax.Array, shape (total_atoms,), dtype=int32) – The number of neighbors for each atom.

  • neighbor_shift_matrix (jax.Array | None, shape (total_atoms, max_neighbors, 3), dtype=int32) – Optional neighbor shift matrix with periodic shift vectors.

  • fill_value (int, default=-1) – The fill value used in the neighbor matrix to indicate empty slots. This is used to create a mask from the neighbor matrix.

Returns:

  • neighbor_list (jax.Array, shape (2, num_pairs), dtype=int32) – The neighbor list in COO format [source_atoms, target_atoms].

  • neighbor_ptr (jax.Array, shape (total_atoms + 1,), dtype=int32) – CSR-style pointer array where neighbor_ptr[i]:neighbor_ptr[i+1] gives the range of neighbors for atom i in the flattened neighbor list.

  • neighbor_list_shifts (jax.Array, shape (num_pairs, 3), dtype=int32) – The neighbor shift vectors (only returned if neighbor_shift_matrix is not None).

Raises:

ValueError – If the max number of neighbors is larger than the neighbor matrix width.

Return type:

tuple[Array, Array] | tuple[Array, Array, Array]

Notes

This is a pure JAX utility function with no warp dependencies. It converts from the fixed-width matrix format to the variable-width list format by masking out fill values and flattening the result.

See also

nvalchemiops.jax.neighbors.naive.naive_neighbor_list

Uses this for format conversion

nvalchemiops.jax.neighbors.cell_list.cell_list

Uses this for format conversion
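The conversion can be pictured with a short NumPy reference. It is a sketch of the semantics described above (mask out fill values, then flatten), not the library's code, and the function name matrix_to_list_reference is hypothetical.

```python
import numpy as np


def matrix_to_list_reference(neighbor_matrix, num_neighbors,
                             neighbor_shift_matrix=None, fill_value=-1):
    """Flatten a padded neighbor matrix into COO pairs plus a CSR pointer."""
    mask = neighbor_matrix != fill_value
    # Row-major traversal keeps pairs grouped by source atom.
    source, slot = np.nonzero(mask)
    target = neighbor_matrix[source, slot]
    neighbor_list = np.stack([source, target]).astype(np.int32)
    # neighbor_ptr[i]:neighbor_ptr[i + 1] spans the pairs of atom i.
    neighbor_ptr = np.concatenate([[0], np.cumsum(num_neighbors)]).astype(np.int32)
    if neighbor_shift_matrix is None:
        return neighbor_list, neighbor_ptr
    return neighbor_list, neighbor_ptr, neighbor_shift_matrix[source, slot]
```

Note that num_pairs depends on the data, so the output shape of such a conversion is not static.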

nvalchemiops.jax.neighbors.neighbor_utils.compute_naive_num_shifts(cell, cutoff, pbc)[source]#

Compute periodic image shifts needed for neighbor searching.

Parameters:
  • cell (jax.Array, shape (num_systems, 3, 3)) – Cell matrices defining lattice vectors in Cartesian coordinates. Each 3x3 matrix represents one system’s periodic cell.

  • cutoff (float) – Cutoff distance for neighbor searching in Cartesian units. Must be positive and typically less than half the minimum cell dimension.

  • pbc (jax.Array, shape (num_systems, 3), dtype=bool) – Periodic boundary condition flags for each dimension. True enables periodicity in that direction.

Returns:

  • shift_range (jax.Array, shape (num_systems, 3), dtype=int32) – Maximum shift indices in each dimension for each system.

  • num_shifts (jax.Array, shape (num_systems,), dtype=int32) – Number of periodic shifts for each system.

  • max_shifts (int) – Maximum per-system shift count across all systems.

Raises:

ValueError – If any per-system shift count exceeds int32 range.

Return type:

tuple[Array, Array, int]

See also

nvalchemiops.neighbors.neighbor_utils._compute_naive_num_shifts

Warp kernel

Notes

This function must be called outside jax.jit scope. The returned max_shifts is a Python int needed for determining launch dimensions, which cannot be traced. This is an inherent limitation: array shapes must be known at trace time in JAX.
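To show why max_shifts ends up as a Python int, here is a NumPy sketch of one common way to count periodic images. It rests on stated assumptions rather than the library's exact rule: lattice vectors are the rows of each cell, the shift range per direction is the cutoff divided by the perpendicular cell height (rounded up), and all images in the range are counted. The library may count differently, for example by using only half of the images. The function name naive_num_shifts_reference is hypothetical.

```python
import numpy as np


def naive_num_shifts_reference(cell, cutoff, pbc):
    """Estimate per-system periodic image ranges and counts (illustrative)."""
    cell = np.asarray(cell, dtype=np.float64)  # (num_systems, 3, 3)
    pbc = np.asarray(pbc, dtype=bool)          # (num_systems, 3)
    inv = np.linalg.inv(cell)
    # Perpendicular height along lattice direction i is 1 / |column i of inv|.
    heights = 1.0 / np.linalg.norm(inv, axis=1)
    # Images needed per direction; non-periodic directions need none.
    shift_range = np.where(pbc, np.ceil(cutoff / heights), 0).astype(np.int32)
    # Count every image in [-s, s] along each direction (assumption).
    num_shifts = np.prod(2 * shift_range + 1, axis=1).astype(np.int32)
    # A concrete Python int, usable as a static launch dimension.
    return shift_range, num_shifts, int(num_shifts.max())
```

The final int() call needs concrete values, which is the reason a computation of this kind has to run outside jax.jit.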