nvalchemiops.jax.neighbors: Neighbor Lists
The neighbors module provides JAX bindings for the GPU-accelerated implementations of neighbor list algorithms.
Tip
For the underlying framework-agnostic Warp kernels, see nvalchemiops.neighbors: Neighbor Lists.
JAX neighbor list API.
This module provides JAX bindings for neighbor list computation and related utilities for both single and batched systems.
High-Level Interface
- nvalchemiops.jax.neighbors.neighbor_list(positions, cutoff, cell=None, pbc=None, batch_idx=None, batch_ptr=None, cutoff2=None, half_fill=False, fill_value=None, return_neighbor_list=False, method=None, wrap_positions=True, **kwargs)
Compute neighbor list using the appropriate method based on the provided parameters.
This is the main entry point for JAX users of the neighbor list API. It automatically selects the most appropriate algorithm (naive O(N²) or cell list O(N)) based on system size and parameters.
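The documented selection rule can be restated as a small plain-Python function. `choose_method` is a hypothetical helper, not part of nvalchemiops. Only the 2000-atom threshold on the average number of atoms per system comes from this page. How batched input and `cutoff2` map onto the `batch_*` and `*_dual_cutoff` method names is an assumption.

```python
def choose_method(total_atoms: int, num_systems: int = 1,
                  batched: bool = False, dual_cutoff: bool = False) -> str:
    """Hypothetical restatement of the documented auto-selection heuristic.

    Assumptions (not confirmed by the library): batched inputs select the
    ``batch_*`` variants, and a second cutoff selects the naive dual-cutoff
    variants, since no cell-list dual-cutoff method is listed.
    """
    if total_atoms <= 0 or num_systems <= 0:
        raise ValueError("total_atoms and num_systems must be positive")
    prefix = "batch_" if batched else ""
    if dual_cutoff:
        return prefix + "naive_dual_cutoff"
    avg_atoms = total_atoms / num_systems
    # Documented threshold: cell list when the average is >= 2000 atoms.
    base = "cell_list" if avg_atoms >= 2000 else "naive"
    return prefix + base


print(choose_method(500))                                # naive
print(choose_method(8000, num_systems=2, batched=True))  # batch_cell_list
```

Passing `method` explicitly bypasses this decision entirely. The docs recommend that when only `batch_idx` is available, because that case forces a device-to-host read.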
- Parameters:
positions (jax.Array, shape (total_atoms, 3)) – Concatenated atomic coordinates for all systems in Cartesian space. Each row represents one atom’s (x, y, z) position. Unwrapped (box-crossing) coordinates are supported when PBC is used; the kernel wraps positions internally.
cutoff (float) – Cutoff distance for neighbor detection in Cartesian units. Must be positive. Atoms within this distance are considered neighbors.
cell (jax.Array, shape (3, 3) or (num_systems, 3, 3), optional) – Cell matrix defining the simulation box.
pbc (jax.Array, shape (3,) or (num_systems, 3), dtype=bool, optional) – Periodic boundary condition flags for each dimension.
batch_idx (jax.Array, shape (total_atoms,), dtype=jnp.int32, optional) – System index for each atom.
batch_ptr (jax.Array, shape (num_systems + 1,), dtype=jnp.int32, optional) – Cumulative atom counts defining system boundaries.
cutoff2 (float, optional) – Second cutoff distance for neighbor detection in Cartesian units. Must be positive. When provided, a second set of neighbor data is computed for atoms within this distance, and the dual-cutoff return pattern is used.
half_fill (bool, optional) – If True, store only half of the neighbor relationships (one direction per pair) to avoid double counting. The other half can be reconstructed by swapping source and target indices and negating the unit shifts.
fill_value (int | None, optional) – Value to fill the neighbor matrix with. Default is total_atoms.
return_neighbor_list (bool, optional - default = False) – If True, convert the neighbor matrix to a neighbor list (idx_i, idx_j) format by creating a mask over the fill_value, which can incur a performance penalty. We recommend using the neighbor matrix format, and only convert to a neighbor list format if absolutely necessary.
method (str | None, optional) – Method to use for neighbor list computation. Choices: “naive”, “cell_list”, “batch_naive”, “batch_cell_list”, “naive_dual_cutoff”, “batch_naive_dual_cutoff”. If None, a default method is chosen based on the average number of atoms per system (cell_list when >= 2000, naive otherwise). When only batch_idx is provided (no batch_ptr or 3-D cell), auto-selection reads batch_idx[-1], which triggers a device-to-host synchronization. To avoid this, pass batch_ptr or a 3-D cell array, or specify method explicitly.
wrap_positions (bool, default=True) – If True, wrap input positions into the primary cell before neighbor search. Set to False when positions are already wrapped (e.g. by a preceding integration step) to save two GPU kernel launches per call. Only applies to naive methods; cell list methods handle wrapping internally.
**kwargs (dict, optional) –
Additional keyword arguments to pass to the method.
- max_neighbors (int, optional)
Maximum number of neighbors per atom. Can be provided to aid in allocation for both naive and cell list methods.
- max_neighbors2 (int, optional)
Maximum number of neighbors per atom within cutoff2. Can be provided to aid in allocation for the naive dual cutoff method.
- neighbor_matrix (jax.Array, optional)
Pre-shaped array of shape (total_atoms, max_neighbors) for neighbor indices. Can be provided to hint buffer reuse to XLA for both naive and cell list methods.
- neighbor_matrix_shifts (jax.Array, optional)
Pre-shaped array of shape (total_atoms, max_neighbors, 3) for shift vectors. Can be provided to hint buffer reuse to XLA for both naive and cell list methods.
- num_neighbors (jax.Array, optional)
Pre-shaped array of shape (total_atoms,) for neighbor counts. Can be provided to hint buffer reuse to XLA for both naive and cell list methods.
- shift_range_per_dimension (jax.Array, optional)
Pre-computed array of shape (1, 3) for the shift range in each dimension. Can be provided to avoid recomputation for naive methods.
- num_shifts_per_system (jax.Array, optional)
Pre-computed array of shape (num_systems,) for the number of periodic shifts per system. Can be provided to avoid recomputation for naive methods.
- max_shifts_per_system (int, optional)
Maximum per-system shift count. Can be provided to avoid recomputation for naive methods.
- cells_per_dimension (jax.Array, optional)
Pre-computed array of shape (3,) for the number of cells in the x, y, z directions. Can be provided to hint buffer reuse to XLA for cell list construction.
- neighbor_search_radius (jax.Array, optional)
Pre-computed array of shape (3,) for the radius of neighboring cells to search in each dimension. Can be provided to hint buffer reuse to XLA for cell list construction.
- atom_periodic_shifts (jax.Array, optional)
Pre-shaped array of shape (total_atoms, 3) for periodic boundary crossings for each atom. Can be provided to hint buffer reuse to XLA for cell list construction.
- atom_to_cell_mapping (jax.Array, optional)
Pre-shaped array of shape (total_atoms, 3) for the cell coordinates of each atom. Can be provided to hint buffer reuse to XLA for cell list construction.
- atoms_per_cell_count (jax.Array, optional)
Pre-shaped array of shape (max_total_cells,) for the number of atoms in each cell. Can be provided to hint buffer reuse to XLA for cell list construction.
- cell_atom_start_indices (jax.Array, optional)
Pre-shaped array of shape (max_total_cells,) for the starting index in cell_atom_list of each cell. Can be provided to hint buffer reuse to XLA for cell list construction.
- cell_atom_list (jax.Array, optional)
Pre-shaped array of shape (total_atoms,) for the flattened list of atom indices organized by cell. Can be provided to hint buffer reuse to XLA for cell list construction.
- max_atoms_per_system (int, optional)
Maximum number of atoms per system. Used in the batch naive implementation with PBC. If not provided, it is computed automatically. Can be provided to avoid CUDA synchronization.
- Returns:
results – Variable-length tuple depending on input parameters. The return pattern follows:
- Single cutoff:
No PBC, matrix format: (neighbor_matrix, num_neighbors)
No PBC, list format: (neighbor_list, neighbor_ptr)
With PBC, matrix format: (neighbor_matrix, num_neighbors, neighbor_matrix_shifts)
With PBC, list format: (neighbor_list, neighbor_ptr, neighbor_list_shifts)
- Dual cutoff:
No PBC, matrix format: (neighbor_matrix1, num_neighbors1, neighbor_matrix2, num_neighbors2)
No PBC, list format: (neighbor_list1, neighbor_ptr1, neighbor_list2, neighbor_ptr2)
With PBC, matrix format: (neighbor_matrix1, num_neighbors1, neighbor_matrix_shifts1, neighbor_matrix2, num_neighbors2, neighbor_matrix_shifts2)
With PBC, list format: (neighbor_list1, neighbor_ptr1, neighbor_list_shifts1, neighbor_list2, neighbor_ptr2, neighbor_list_shifts2)
Components returned:
neighbor_data (array): Neighbor indices; format depends on return_neighbor_list:
If return_neighbor_list=False (default): returns neighbor_matrix with shape (total_atoms, max_neighbors), dtype int32. Each row i contains the indices of atom i’s neighbors.
If return_neighbor_list=True: returns neighbor_list with shape (2, num_pairs), dtype int32, in COO format [source_atoms, target_atoms].
num_neighbor_data (array): Number of neighbors for each atom; format depends on return_neighbor_list:
If return_neighbor_list=False (default): returns num_neighbors with shape (total_atoms,), dtype int32, the count of neighbors found for each atom.
If return_neighbor_list=True: returns neighbor_ptr with shape (total_atoms + 1,), dtype int32, a CSR-style pointer array in which neighbor_ptr[i] to neighbor_ptr[i+1] gives the range of atom i’s neighbors in the flattened neighbor list.
neighbor_shift_data (array, optional): Periodic shift vectors, returned only when pbc is provided; format depends on return_neighbor_list:
If return_neighbor_list=False (default): returns neighbor_matrix_shifts with shape (total_atoms, max_neighbors, 3), dtype int32.
If return_neighbor_list=True: returns unit_shifts with shape (num_pairs, 3), dtype int32.
When cutoff2 is provided, the components for the second cutoff (neighbor_data2, num_neighbor_data2, neighbor_shift_data2) are appended after those for the first cutoff.
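To make the two output formats concrete, the following plain-Python sketch converts a padded neighbor matrix into the list format. It masks out `fill_value` entries to produce a COO pair array `[source_atoms, target_atoms]` and derives a CSR-style `neighbor_ptr` from the per-row counts. It works on nested lists rather than `jax.Array`s, and `matrix_to_list` is a hypothetical name: it illustrates the layout, not the library's conversion code.

```python
from itertools import accumulate


def matrix_to_list(neighbor_matrix, fill_value):
    """Convert a padded neighbor matrix to COO pairs plus a CSR pointer array."""
    sources, targets, counts = [], [], []
    for i, row in enumerate(neighbor_matrix):
        # Mask over the fill value: keep only real neighbor indices.
        valid = [j for j in row if j != fill_value]
        sources.extend([i] * len(valid))
        targets.extend(valid)
        counts.append(len(valid))
    # neighbor_ptr[i]:neighbor_ptr[i + 1] slices atom i's neighbors.
    neighbor_ptr = [0] + list(accumulate(counts))
    return [sources, targets], neighbor_ptr


# 3 atoms, max_neighbors = 2, fill_value = total_atoms = 3
nm = [[1, 2], [0, 3], [0, 3]]
nlist, ptr = matrix_to_list(nm, fill_value=3)
print(nlist)  # [[0, 0, 1, 2], [1, 2, 0, 0]]
print(ptr)    # [0, 2, 3, 4]
```

The list format has a data-dependent length (num_pairs), which is why the matrix format is the recommended default.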
Examples
Single cutoff, matrix format, with PBC:
>>> nm, num, shifts = neighbor_list(pos, 5.0, cell=cell, pbc=pbc)
Single cutoff, list format, no PBC:
>>> nlist, ptr = neighbor_list(pos, 5.0, return_neighbor_list=True)
Dual cutoff, matrix format, with PBC:
>>> nm1, num1, sh1, nm2, num2, sh2 = neighbor_list(
...     pos, 2.5, cutoff2=5.0, cell=cell, pbc=pbc
... )
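With `half_fill=True` only one direction of each pair is stored. The half_fill description says the other half is recovered by swapping source and target indices and inverting the unit shifts. The following is a minimal plain-Python sketch of that step on list-format output. `expand_half_list` is a hypothetical helper, and reading "inverting" as negating the integer shift vector is an assumption.

```python
def expand_half_list(neighbor_list, shifts):
    """Rebuild the full (symmetric) pair list from a half-filled one.

    neighbor_list: [sources, targets]; shifts: one (sx, sy, sz) per pair.
    """
    sources, targets = neighbor_list
    # Reverse direction: swap i and j, negate the integer cell shift.
    rev_shifts = [tuple(-s for s in shift) for shift in shifts]
    full = [list(sources) + list(targets), list(targets) + list(sources)]
    return full, [tuple(s) for s in shifts] + rev_shifts


half = [[0, 1], [1, 2]]               # pairs (0, 1) and (1, 2)
half_shifts = [(0, 0, 0), (1, 0, 0)]
full, full_shifts = expand_half_list(half, half_shifts)
print(full)         # [[0, 1, 1, 2], [1, 2, 0, 1]]
print(full_shifts)  # [(0, 0, 0), (1, 0, 0), (0, 0, 0), (-1, 0, 0)]
```

Negating the shift keeps the pair vector consistent: if (i, j, S) describes r_j + S·cell − r_i, then (j, i, −S) describes exactly the opposite vector.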
See also
naive_neighbor_list – Direct access to naive O(N²) algorithm
cell_list – Direct access to cell list O(N) algorithm
batch_naive_neighbor_list – Batched naive algorithm
batch_cell_list – Batched cell list algorithm
Unbatched Algorithms
Naive Algorithm
- nvalchemiops.jax.neighbors.naive_neighbor_list(positions, cutoff, cell=None, pbc=None, max_neighbors=None, half_fill=False, fill_value=None, return_neighbor_list=False, neighbor_matrix=None, neighbor_matrix_shifts=None, num_neighbors=None, shift_range_per_dimension=None, num_shifts_per_system=None, max_shifts_per_system=None, rebuild_flags=None, wrap_positions=True)
Compute neighbor list using naive O(N^2) algorithm.
Identifies all atom pairs within a specified cutoff distance using a brute-force pairwise distance calculation. Supports both non-periodic and periodic boundary conditions.
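The brute-force idea can be sketched in plain Python for the non-periodic case. The sketch tests every ordered pair, writes neighbor indices into a fixed-width row padded with `fill_value`, and keeps per-atom counts. This is an illustrative reference, not the Warp kernel. Two details are assumptions: overflow beyond `max_neighbors` is simply dropped, and the cutoff comparison is strict.

```python
import math


def naive_neighbor_matrix(positions, cutoff, max_neighbors,
                          half_fill=False, fill_value=None):
    """O(N^2) reference: every ordered pair (i, j) is tested once."""
    n = len(positions)
    fill = n if fill_value is None else fill_value  # default: total_atoms
    matrix = [[fill] * max_neighbors for _ in range(n)]
    counts = [0] * n
    for i in range(n):
        for j in range(n):
            if i == j or (half_fill and j < i):
                continue  # skip self-pairs; half_fill keeps only i < j
            if math.dist(positions[i], positions[j]) < cutoff:
                if counts[i] < max_neighbors:  # assumed: overflow is dropped
                    matrix[i][counts[i]] = j
                    counts[i] += 1
    return matrix, counts


pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 0.0, 0.0)]
nm, num = naive_neighbor_matrix(pos, cutoff=2.0, max_neighbors=2)
print(nm)   # [[1, 3], [0, 3], [3, 3]]
print(num)  # [1, 1, 0]
```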
- Parameters:
positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates in Cartesian space. Each row represents one atom’s (x, y, z) position.
cutoff (float) – Cutoff distance for neighbor detection in Cartesian units. Must be positive. Atoms within this distance are considered neighbors.
pbc (jax.Array, shape (3,) or (1, 3), dtype=bool, optional) – Periodic boundary condition flags for each dimension. True enables periodicity in that direction. Default is None (no PBC).
cell (jax.Array, shape (1, 3, 3), dtype=float32 or float64, optional) – Cell matrices defining lattice vectors in Cartesian coordinates. Required if pbc is provided. Default is None.
max_neighbors (int, optional) – Maximum number of neighbors per atom. Must be positive. If exceeded, excess neighbors are ignored. Must be provided if neighbor_matrix is not provided.
half_fill (bool, optional) – If True, only store relationships where i < j to avoid double counting. If False, store all neighbor relationships symmetrically. Default is False.
fill_value (int, optional) – Value to fill the neighbor matrix with. Default is total_atoms.
neighbor_matrix (jax.Array, shape (total_atoms, max_neighbors), dtype=int32, optional) – Neighbor matrix to be filled. Pass in a pre-shaped array to hint buffer reuse to XLA; note that JAX returns a new array rather than mutating the input. Must be provided if max_neighbors is not provided.
neighbor_matrix_shifts (jax.Array, shape (total_atoms, max_neighbors, 3), dtype=int32, optional) – Shift vectors for each neighbor relationship. Pass in a pre-shaped array to hint buffer reuse to XLA; note that JAX returns a new array rather than mutating the input. Must be provided if max_neighbors is not provided.
num_neighbors (jax.Array, shape (total_atoms,), dtype=int32, optional) – Number of neighbors found for each atom. Pass in a pre-shaped array to hint buffer reuse to XLA; note that JAX returns a new array rather than mutating the input. Must be provided if max_neighbors is not provided.
shift_range_per_dimension (jax.Array, shape (1, 3), dtype=int32, optional) – Shift range in each dimension for each system. Pass in a pre-computed value to avoid recomputation for PBC systems.
num_shifts_per_system (jax.Array, shape (1,), dtype=int32, optional) – Number of periodic shifts for the system. Pass in a pre-computed value to avoid recomputation for PBC systems.
max_shifts_per_system (int, optional) – Maximum per-system shift count. Pass in a pre-computed value to avoid recomputation for PBC systems.
return_neighbor_list (bool, optional - default = False) – If True, convert the neighbor matrix to a neighbor list (idx_i, idx_j) format by creating a mask over the fill_value, which can incur a performance penalty.
wrap_positions (bool, default=True) – If True, wrap input positions into the primary cell before neighbor search. Set to False when positions are already wrapped (e.g. by a preceding integration step) to save two GPU kernel launches per call.
rebuild_flags (Array | None)
- Returns:
results – Variable-length tuple depending on input parameters. The return pattern follows:
No PBC, matrix format: (neighbor_matrix, num_neighbors)
No PBC, list format: (neighbor_list, neighbor_ptr)
With PBC, matrix format: (neighbor_matrix, num_neighbors, neighbor_matrix_shifts)
With PBC, list format: (neighbor_list, neighbor_ptr, neighbor_list_shifts)
Components returned:
neighbor_data (array): Neighbor indices; format depends on return_neighbor_list:
If return_neighbor_list=False (default): returns neighbor_matrix with shape (total_atoms, max_neighbors), dtype int32. Each row i contains the indices of atom i’s neighbors.
If return_neighbor_list=True: returns neighbor_list with shape (2, num_pairs), dtype int32, in COO format [source_atoms, target_atoms].
num_neighbor_data (array, always returned): Number of neighbors for each atom; format depends on return_neighbor_list:
If return_neighbor_list=False (default): returns num_neighbors with shape (total_atoms,), dtype int32, the count of neighbors found for each atom.
If return_neighbor_list=True: returns neighbor_ptr with shape (total_atoms + 1,), dtype int32, a CSR-style pointer array in which neighbor_ptr[i] to neighbor_ptr[i+1] gives the range of atom i’s neighbors in the flattened neighbor list.
neighbor_shift_data (array, optional): Periodic shift vectors, returned only when pbc is provided; format depends on return_neighbor_list:
If return_neighbor_list=False (default): returns neighbor_matrix_shifts with shape (total_atoms, max_neighbors, 3), dtype int32.
If return_neighbor_list=True: returns unit_shifts with shape (num_pairs, 3), dtype int32.
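For periodic systems the naive method must also consider periodic images, which is what shift_range_per_dimension and num_shifts_per_system describe. A common way to size that search is to take `ceil(cutoff / L)` images in each periodic direction and enumerate every integer shift in the resulting range. This sketch assumes an orthorhombic box only. The library's exact formula, including any handling of triclinic cells or symmetry reduction, is not documented in this section, and `shift_range` and `enumerate_shifts` are hypothetical names.

```python
import itertools
import math


def shift_range(box_lengths, cutoff, pbc):
    """Images needed per dimension for an orthorhombic box (0 if not periodic)."""
    return [math.ceil(cutoff / L) if p else 0 for L, p in zip(box_lengths, pbc)]


def enumerate_shifts(ranges):
    """All integer shift vectors (sx, sy, sz) with |s_d| <= ranges[d]."""
    axes = [range(-r, r + 1) for r in ranges]
    return list(itertools.product(*axes))


ranges = shift_range([10.0, 10.0, 4.0], cutoff=5.0, pbc=[True, True, False])
shifts = enumerate_shifts(ranges)
print(ranges)       # [1, 1, 0]
print(len(shifts))  # 9 = (2*1 + 1) * (2*1 + 1) * (2*0 + 1)
```

With wrapped positions, coordinate differences lie in (−L, L), so no image beyond `ceil(cutoff / L)` box lengths can fall inside the cutoff.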
Examples
Basic usage without periodic boundary conditions:
>>> import jax
>>> import jax.numpy as jnp
>>> from nvalchemiops.jax.neighbors import naive_neighbor_list
>>> # 100 atoms spread uniformly over a 10 x 10 x 10 box
>>> positions = jax.random.uniform(
...     jax.random.PRNGKey(0), (100, 3), dtype=jnp.float32, minval=0.0, maxval=10.0
... )
>>> cutoff = 2.5
>>> max_neighbors = 50
>>> neighbor_matrix, num_neighbors = naive_neighbor_list(
...     positions, cutoff, max_neighbors=max_neighbors
... )
With periodic boundary conditions:
>>> cell = jnp.eye(3, dtype=jnp.float32).reshape(1, 3, 3) * 10.0
>>> pbc = jnp.array([[True, True, True]])
>>> neighbor_matrix, num_neighbors, shifts = naive_neighbor_list(
...     positions, cutoff, max_neighbors=max_neighbors, pbc=pbc, cell=cell
... )
Return as neighbor list instead of matrix:
>>> neighbor_list, neighbor_ptr = naive_neighbor_list(
...     positions, cutoff, max_neighbors=max_neighbors, return_neighbor_list=True
... )
>>> source_atoms, target_atoms = neighbor_list[0], neighbor_list[1]
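`wrap_positions=True` maps coordinates into the primary cell before the search. For a general cell this means converting to fractional coordinates, subtracting the integer part, and converting back. The sketch below assumes an orthorhombic (diagonal) cell so it stays dependency-free. `wrap_orthorhombic` is a hypothetical helper, and it is not the library's wrapping kernel.

```python
import math


def wrap_orthorhombic(positions, box_lengths, pbc):
    """Wrap Cartesian positions into [0, L) along each periodic axis.

    Returns the wrapped positions and the integer number of box lengths
    removed per axis (the image each atom was wrapped from).
    """
    wrapped, images = [], []
    for pos in positions:
        new_pos, image = [], []
        for x, L, periodic in zip(pos, box_lengths, pbc):
            n = math.floor(x / L) if periodic else 0  # integer part of x / L
            new_pos.append(x - n * L)
            image.append(n)
        wrapped.append(tuple(new_pos))
        images.append(tuple(image))
    return wrapped, images


pos = [(12.5, -0.5, 3.0)]
w, img = wrap_orthorhombic(pos, [10.0, 10.0, 10.0], [True, True, False])
print(w)    # [(2.5, 9.5, 3.0)]
print(img)  # [(1, -1, 0)]
```

If positions are already inside the cell, this step changes nothing, which is why `wrap_positions=False` is a safe saving after an integrator that wraps.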
See also
nvalchemiops.neighbors.naive.naive_neighbor_matrix – Core warp launcher (no PBC)
nvalchemiops.neighbors.naive.naive_neighbor_matrix_pbc – Core warp launcher (with PBC)
cell_list – O(N) cell list method for larger systems
Cell List Algorithm
- nvalchemiops.jax.neighbors.cell_list(positions, cutoff, cell=None, pbc=None, max_neighbors=None, max_total_cells=None, return_neighbor_list=False)
Build and query spatial cell list for efficient neighbor finding.
This is a convenience function that combines build_cell_list and query_cell_list in a single call.
- Parameters:
positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates in Cartesian space.
cutoff (float) – Cutoff distance for neighbor detection.
cell (jax.Array, shape (1, 3, 3), dtype=float32 or float64, optional) – Cell matrix defining lattice vectors. Default is identity matrix.
pbc (jax.Array, shape (3,) or (1, 3), dtype=bool, optional) – Periodic boundary condition flags. Default is all True.
max_neighbors (int, optional) – Maximum number of neighbors per atom. If None, will be estimated.
max_total_cells (int, optional) – Maximum number of cells to allocate. If None, will be estimated.
return_neighbor_list (bool, optional) – If True, convert result to COO neighbor list format. Default is False.
- Returns:
neighbor_data (jax.Array) – If return_neighbor_list=False (default): neighbor_matrix with shape (total_atoms, max_neighbors), dtype int32. If return_neighbor_list=True: neighbor_list with shape (2, num_pairs), dtype int32, in COO format.
neighbor_count (jax.Array) – If return_neighbor_list=False: num_neighbors with shape (total_atoms,), dtype int32. If return_neighbor_list=True: neighbor_ptr with shape (total_atoms + 1,), dtype int32.
shift_data (jax.Array) – If return_neighbor_list=False: neighbor_matrix_shifts with shape (total_atoms, max_neighbors, 3), dtype int32. If return_neighbor_list=True: neighbor_list_shifts with shape (num_pairs, 3), dtype int32.
See also
build_cell_list – Build cell list separately
query_cell_list – Query cell list separately
naive_neighbor_list – Naive O(N^2) method
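The cell list reaches O(N) scaling by binning atoms into a grid whose cells are at least as wide as the cutoff. Each atom then only needs to be compared with atoms in nearby cells. Below is a plain-Python sketch of that sizing step for an orthorhombic box. The exact rules build_cell_list uses to choose cells_per_dimension, neighbor_search_radius, and max_total_cells are not given on this page, so `grid_shape` is a hypothetical helper built on those assumptions.

```python
import math


def grid_shape(box_lengths, cutoff):
    """Cells per dimension (each cell >= cutoff wide) and search radius."""
    if cutoff <= 0:
        raise ValueError("cutoff must be positive")
    cells = [max(1, math.floor(L / cutoff)) for L in box_lengths]
    # With cell width >= cutoff one ring of neighboring cells suffices;
    # ceil(cutoff / width) is the general form for finer grids.
    radius = [math.ceil(cutoff / (L / c)) for L, c in zip(box_lengths, cells)]
    total_cells = math.prod(cells)
    return cells, radius, total_cells


cells, radius, total = grid_shape([20.0, 20.0, 7.0], cutoff=5.0)
print(cells)   # [4, 4, 1]
print(radius)  # [1, 1, 1]
print(total)   # 16
```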
- nvalchemiops.jax.neighbors.cell_list.build_cell_list(positions, cutoff, cell, pbc, cells_per_dimension=None, neighbor_search_radius=None, atom_periodic_shifts=None, atom_to_cell_mapping=None, atoms_per_cell_count=None, cell_atom_start_indices=None, cell_atom_list=None, max_total_cells=None)
Build spatial cell list for efficient neighbor searching.
- Parameters:
positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates in Cartesian space.
cutoff (float) – Cutoff distance for neighbor searching. Must be positive.
cell (jax.Array, shape (1, 3, 3), dtype=float32 or float64) – Cell matrix defining lattice vectors.
pbc (jax.Array, shape (3,) or (1, 3), dtype=bool) – Periodic boundary condition flags.
cells_per_dimension (jax.Array, shape (3,), dtype=int32, optional) – OUTPUT: Number of cells in x, y, z directions. If None, allocated.
neighbor_search_radius (jax.Array, shape (3,), dtype=int32, optional) – Search radius in neighboring cells. If None, allocated.
atom_periodic_shifts (jax.Array, shape (total_atoms, 3), dtype=int32, optional) – OUTPUT: Periodic boundary crossings for each atom. If None, allocated.
atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32, optional) – OUTPUT: 3D cell coordinates for each atom. If None, allocated.
atoms_per_cell_count (jax.Array, shape (max_total_cells,), dtype=int32, optional) – OUTPUT: Number of atoms in each cell. If None, allocated.
cell_atom_start_indices (jax.Array, shape (max_total_cells,), dtype=int32, optional) – OUTPUT: Starting index in cell_atom_list for each cell. If None, allocated.
cell_atom_list (jax.Array, shape (total_atoms,), dtype=int32, optional) – OUTPUT: Flattened list of atom indices organized by cell. If None, allocated.
max_total_cells (int, optional) – Maximum number of cells to allocate. If None, will be estimated.
- Returns:
cells_per_dimension (jax.Array, shape (3,), dtype=int32) – Number of cells in x, y, z directions.
atom_periodic_shifts (jax.Array, shape (total_atoms, 3), dtype=int32) – Periodic boundary crossings for each atom.
atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32) – 3D cell coordinates for each atom.
atoms_per_cell_count (jax.Array, shape (max_total_cells,), dtype=int32) – Number of atoms in each cell.
cell_atom_start_indices (jax.Array, shape (max_total_cells,), dtype=int32) – Starting index in cell_atom_list for each cell.
cell_atom_list (jax.Array, shape (total_atoms,), dtype=int32) – Flattened list of atom indices organized by cell.
neighbor_search_radius (jax.Array, shape (3,), dtype=int32) – Search radius in neighboring cells.
Notes
When calling inside jax.jit, max_total_cells must be provided to avoid calling estimate_cell_list_sizes, which is not JIT-compatible.
See also
query_cell_list – Query the built cell list for neighbors
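atoms_per_cell_count, cell_atom_start_indices, and cell_atom_list together form a compressed (CSR-like) layout. Each cell's start index is the exclusive prefix sum of the counts, and the atoms of cell `c` occupy `cell_atom_list[start[c]:start[c] + count[c]]`. The counting-sort sketch below shows that relationship in plain Python, given a linear cell index per atom. It does not reproduce the kernel. In particular, the order of atoms within a cell (here, ascending atom index) is an assumption.

```python
def build_cell_arrays(atom_cell, num_cells):
    """Counting sort of atoms by linear cell index into CSR-style arrays."""
    counts = [0] * num_cells
    for c in atom_cell:
        counts[c] += 1
    # Exclusive prefix sum: where each cell's block starts in the flat list.
    starts, running = [], 0
    for n in counts:
        starts.append(running)
        running += n
    cell_atom_list = [0] * len(atom_cell)
    fill = list(starts)  # next free slot per cell
    for atom, c in enumerate(atom_cell):
        cell_atom_list[fill[c]] = atom
        fill[c] += 1
    return counts, starts, cell_atom_list


# 5 atoms assigned to cells 2, 0, 2, 1, 0 of a 4-cell grid
counts, starts, flat = build_cell_arrays([2, 0, 2, 1, 0], num_cells=4)
print(counts)  # [2, 1, 2, 0]
print(starts)  # [0, 2, 3, 5]
print(flat)    # [1, 4, 3, 0, 2]
```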
- nvalchemiops.jax.neighbors.cell_list.query_cell_list(positions, cutoff, cell, pbc, cells_per_dimension, atom_periodic_shifts, atom_to_cell_mapping, atoms_per_cell_count, cell_atom_start_indices, cell_atom_list, neighbor_search_radius, max_neighbors=None, neighbor_matrix=None, neighbor_matrix_shifts=None, num_neighbors=None, rebuild_flags=None)
Query cell list to find neighbors within cutoff.
- Parameters:
positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates in Cartesian space.
cutoff (float) – Cutoff distance for neighbor detection.
cell (jax.Array, shape (1, 3, 3), dtype=float32 or float64) – Cell matrix defining lattice vectors.
pbc (jax.Array, shape (3,) or (1, 3), dtype=bool) – Periodic boundary condition flags.
cells_per_dimension (jax.Array, shape (3,), dtype=int32) – Number of cells in each dimension.
atom_periodic_shifts (jax.Array, shape (total_atoms, 3), dtype=int32) – Periodic boundary crossings for each atom (output from build_cell_list).
atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32) – 3D cell coordinates for each atom.
atoms_per_cell_count (jax.Array, shape (max_total_cells,), dtype=int32) – Number of atoms in each cell (output from build_cell_list).
cell_atom_start_indices (jax.Array, shape (max_total_cells,), dtype=int32) – Starting index in cell_atom_list for each cell.
cell_atom_list (jax.Array, shape (total_atoms,), dtype=int32) – Flattened list of atom indices organized by cell.
neighbor_search_radius (jax.Array, shape (3,), dtype=int32) – Search radius in neighboring cells.
max_neighbors (int, optional) – Maximum number of neighbors per atom.
neighbor_matrix (jax.Array, optional) – Pre-allocated neighbor matrix.
num_neighbors (jax.Array, optional) – Pre-allocated neighbors count array.
neighbor_matrix_shifts (Array | None)
rebuild_flags (Array | None)
- Returns:
neighbor_matrix (jax.Array, shape (total_atoms, max_neighbors), dtype=int32) – Neighbor matrix with neighbor atom indices.
num_neighbors (jax.Array, shape (total_atoms,), dtype=int32) – Number of neighbors found for each atom.
neighbor_matrix_shifts (jax.Array, shape (total_atoms, max_neighbors, 3), dtype=int32) – Periodic shift vectors for each neighbor relationship.
See also
build_cell_list – Build cell list before querying
cell_list – Combined build and query operation
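Querying then visits only the cells within neighbor_search_radius of each atom's cell. The self-contained sketch below builds a non-periodic orthorhombic grid (positions assumed to lie in `[0, L)`), queries it, and checks the result against an all-pairs search. `cell_list_pairs` and `brute_force_pairs` are hypothetical names. The sketch uses a dictionary of bins instead of the flat arrays build_cell_list returns.

```python
import itertools
import math
import random


def cell_list_pairs(positions, box_lengths, cutoff):
    """Ordered neighbor pairs (i, j) found by a non-periodic cell-list search."""
    cells = [max(1, math.floor(L / cutoff)) for L in box_lengths]
    width = [L / c for L, c in zip(box_lengths, cells)]
    radius = [math.ceil(cutoff / w) for w in width]

    def cell_of(p):
        # Clamp so points on the upper boundary stay in the last cell.
        return tuple(min(int(x / w), c - 1) for x, w, c in zip(p, width, cells))

    bins = {}
    for i, p in enumerate(positions):
        bins.setdefault(cell_of(p), []).append(i)

    pairs = set()
    for i, p in enumerate(positions):
        home = cell_of(p)
        for off in itertools.product(*(range(-r, r + 1) for r in radius)):
            nb = tuple(h + o for h, o in zip(home, off))
            if any(k < 0 or k >= c for k, c in zip(nb, cells)):
                continue  # no periodic images in this sketch
            for j in bins.get(nb, []):
                if j != i and math.dist(p, positions[j]) < cutoff:
                    pairs.add((i, j))
    return pairs


def brute_force_pairs(positions, cutoff):
    """All ordered pairs within the cutoff, by direct O(N^2) comparison."""
    n = len(positions)
    return {(i, j) for i in range(n) for j in range(n)
            if i != j and math.dist(positions[i], positions[j]) < cutoff}


rng = random.Random(0)
box = [12.0, 12.0, 12.0]
pos = [tuple(rng.uniform(0, L) for L in box) for _ in range(200)]
assert cell_list_pairs(pos, box, 2.5) == brute_force_pairs(pos, 2.5)
```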
Dual Cutoff Algorithm
- nvalchemiops.jax.neighbors.naive_neighbor_list_dual_cutoff(positions, cutoff1, cutoff2, pbc=None, cell=None, max_neighbors1=None, max_neighbors2=None, half_fill=False, fill_value=None, return_neighbor_list=False, neighbor_matrix1=None, neighbor_matrix2=None, neighbor_matrix_shifts1=None, neighbor_matrix_shifts2=None, num_neighbors1=None, num_neighbors2=None, shift_range_per_dimension=None, num_shifts_per_system=None, max_shifts_per_system=None, rebuild_flags=None, wrap_positions=True)
Compute neighbor lists for two cutoff distances using naive O(N^2) algorithm.
This function builds two neighbor matrices simultaneously for different cutoff distances, which is more efficient than calling the single-cutoff function twice.
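The saving comes from evaluating each pairwise distance once and testing it against both cutoffs. Here is a minimal non-periodic plain-Python sketch, assuming `cutoff1 <= cutoff2` as the "typically smaller/larger" convention below suggests. `dual_cutoff_pairs` is a hypothetical helper, not the Warp kernel.

```python
import math


def dual_cutoff_pairs(positions, cutoff1, cutoff2):
    """One O(N^2) sweep producing ordered pairs for two cutoffs."""
    pairs1, pairs2 = [], []
    n = len(positions)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            d = math.dist(positions[i], positions[j])  # computed once per pair
            if d < cutoff1:
                pairs1.append((i, j))
            if d < cutoff2:
                pairs2.append((i, j))
    return pairs1, pairs2


pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0)]
short, long_ = dual_cutoff_pairs(pos, cutoff1=1.5, cutoff2=2.5)
print(short)  # [(0, 1), (1, 0)]
print(long_)  # [(0, 1), (1, 0), (1, 2), (2, 1)]
```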
- Parameters:
positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates in Cartesian space.
cutoff1 (float) – First cutoff distance (typically smaller).
cutoff2 (float) – Second cutoff distance (typically larger).
pbc (jax.Array, shape (1, 3) or (3,), dtype=bool, optional) – Periodic boundary condition flags for each dimension.
cell (jax.Array, shape (1, 3, 3) or (3, 3), dtype=float32 or float64, optional) – Cell matrix defining lattice vectors in Cartesian coordinates.
max_neighbors1 (int, optional) – Maximum number of neighbors per atom for cutoff1.
max_neighbors2 (int, optional) – Maximum number of neighbors per atom for cutoff2.
half_fill (bool, optional - default = False) – If True, only store relationships where i < j to avoid double counting.
fill_value (int, optional) – Value to use for padding in neighbor matrices. Default is total_atoms.
return_neighbor_list (bool, optional - default = False) – If True, convert neighbor matrices to neighbor list (idx_i, idx_j) format.
neighbor_matrix1 (jax.Array, shape (total_atoms, max_neighbors1), dtype=int32, optional) – Pre-allocated first neighbor matrix.
neighbor_matrix2 (jax.Array, shape (total_atoms, max_neighbors2), dtype=int32, optional) – Pre-allocated second neighbor matrix.
neighbor_matrix_shifts1 (jax.Array, shape (total_atoms, max_neighbors1, 3), dtype=int32, optional) – Pre-allocated first shift matrix for PBC.
neighbor_matrix_shifts2 (jax.Array, shape (total_atoms, max_neighbors2, 3), dtype=int32, optional) – Pre-allocated second shift matrix for PBC.
num_neighbors1 (jax.Array, shape (total_atoms,), dtype=int32, optional) – Pre-allocated first neighbor count array.
num_neighbors2 (jax.Array, shape (total_atoms,), dtype=int32, optional) – Pre-allocated second neighbor count array.
shift_range_per_dimension (jax.Array, shape (1, 3), dtype=int32, optional) – Shift range in each dimension for the system. Pass in a pre-computed value to avoid recomputation for PBC systems.
num_shifts_per_system (jax.Array, shape (1,), dtype=int32, optional) – Number of periodic shifts for the system. Pass in a pre-computed value to avoid recomputation for PBC systems.
max_shifts_per_system (int, optional) – Maximum per-system shift count. Pass in a pre-computed value to avoid recomputation for PBC systems.
wrap_positions (bool, default=True) – If True, wrap input positions into the primary cell before neighbor search. Set to False when positions are already wrapped (e.g. by a preceding integration step) to save two GPU kernel launches per call.
rebuild_flags (Array | None)
- Returns:
results – Variable-length tuple depending on input parameters:
No PBC, matrix format: (neighbor_matrix1, num_neighbors1, neighbor_matrix2, num_neighbors2)
No PBC, list format: (neighbor_list1, neighbor_ptr1, neighbor_list2, neighbor_ptr2)
With PBC, matrix format: (neighbor_matrix1, num_neighbors1, neighbor_matrix_shifts1, neighbor_matrix2, num_neighbors2, neighbor_matrix_shifts2)
With PBC, list format: (neighbor_list1, neighbor_ptr1, unit_shifts1, neighbor_list2, neighbor_ptr2, unit_shifts2)
See also
nvalchemiops.neighbors.naive_dual_cutoff.naive_neighbor_matrix_dual_cutoff – Core warp launcher (no PBC)
nvalchemiops.neighbors.naive_dual_cutoff.naive_neighbor_matrix_pbc_dual_cutoff – Core warp launcher (with PBC)
naive_neighbor_list – Single cutoff version
Batched Algorithms
Batched Naive Algorithm
- nvalchemiops.jax.neighbors.batch_naive_neighbor_list(positions, cutoff, batch_idx=None, batch_ptr=None, pbc=None, cell=None, max_neighbors=None, half_fill=False, fill_value=None, return_neighbor_list=False, neighbor_matrix=None, neighbor_matrix_shifts=None, num_neighbors=None, shift_range_per_dimension=None, num_shifts_per_system=None, max_shifts_per_system=None, max_atoms_per_system=None, rebuild_flags=None, wrap_positions=True)
Compute neighbor list for batch of systems using naive O(N^2) algorithm.
Identifies all atom pairs within a specified cutoff distance for each system independently using a brute-force pairwise distance calculation. Supports both non-periodic and periodic boundary conditions.
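In the batched case, pairs are formed only between atoms of the same system, with `batch_ptr[s]:batch_ptr[s + 1]` delimiting system `s` in the concatenated `positions`. The following is a plain-Python, non-periodic sketch of that restriction. `batch_naive_pairs` is a hypothetical helper. Reporting pairs with global (concatenated) atom indices is assumed here, consistent with the (total_atoms, max_neighbors) matrix shape.

```python
import math


def batch_naive_pairs(positions, batch_ptr, cutoff):
    """Ordered pairs (i, j) using global indices, never crossing systems."""
    pairs = []
    for s in range(len(batch_ptr) - 1):
        start, end = batch_ptr[s], batch_ptr[s + 1]
        for i in range(start, end):
            for j in range(start, end):
                if i != j and math.dist(positions[i], positions[j]) < cutoff:
                    pairs.append((i, j))
    return pairs


# Two systems of two atoms; atoms 1 and 2 are close but in different systems.
pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0),
       (1.5, 0.0, 0.0), (2.5, 0.0, 0.0)]
print(batch_naive_pairs(pos, [0, 2, 4], cutoff=1.2))
# [(0, 1), (1, 0), (2, 3), (3, 2)]
```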
- Parameters:
positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Concatenated Cartesian coordinates for all systems.
cutoff (float) – Cutoff distance for neighbor detection in Cartesian units. Must be positive. Atoms within this distance are considered neighbors.
batch_idx (jax.Array, shape (total_atoms,), dtype=int32, optional) – System index for each atom. If None, batch_ptr must be provided.
batch_ptr (jax.Array, shape (num_systems + 1,), dtype=int32, optional) – Cumulative atom counts defining system boundaries. If None, batch_idx must be provided.
pbc (jax.Array, shape (num_systems, 3), dtype=bool, optional) – Periodic boundary condition flags for each system and dimension. True enables periodicity in that direction. Default is None (no PBC).
cell (jax.Array, shape (num_systems, 3, 3), dtype=float32 or float64, optional) – Cell matrices defining lattice vectors. Required if pbc is provided.
max_neighbors (int, optional) – Maximum number of neighbors per atom.
half_fill (bool, optional) – If True, only store relationships where i < j. Default is False.
fill_value (int, optional) – Value to fill the neighbor matrix with. Default is total_atoms.
neighbor_matrix (jax.Array, optional) – Pre-allocated neighbor matrix.
neighbor_matrix_shifts (jax.Array, optional) – Pre-allocated shift matrix for PBC.
num_neighbors (jax.Array, optional) – Pre-allocated neighbors count array.
shift_range_per_dimension (jax.Array, optional) – Pre-computed shift range for PBC systems.
num_shifts_per_system (jax.Array, optional) – Number of periodic shifts per system.
max_shifts_per_system (int, optional) – Maximum per-system shift count (launch dimension).
max_atoms_per_system (int, optional) – Maximum atoms in any system.
wrap_positions (bool, default=True) – If True, wrap input positions into the primary cell before neighbor search. Set to False when positions are already wrapped (e.g. by a preceding integration step) to save two GPU kernel launches per call.
return_neighbor_list (bool)
rebuild_flags (Array | None)
- Returns:
results – Variable-length tuple depending on input parameters.
Examples
Basic usage with batch_ptr:
>>> import jax
>>> import jax.numpy as jnp
>>> from nvalchemiops.jax.neighbors import batch_naive_neighbor_list
>>> # 2 systems of 100 atoms each, spread over a 10 x 10 x 10 box
>>> positions = jax.random.uniform(
...     jax.random.PRNGKey(0), (200, 3), dtype=jnp.float32, minval=0.0, maxval=10.0
... )
>>> batch_ptr = jnp.array([0, 100, 200], dtype=jnp.int32)
>>> cutoff = 2.5
>>> max_neighbors = 50
>>> neighbor_matrix, num_neighbors = batch_naive_neighbor_list(
...     positions, cutoff, batch_ptr=batch_ptr, max_neighbors=max_neighbors
... )
With PBC:
>>> cell = jnp.eye(3, dtype=jnp.float32)[jnp.newaxis, :, :] * 10.0
>>> cell = jnp.repeat(cell, 2, axis=0)
>>> pbc = jnp.ones((2, 3), dtype=jnp.bool_)
>>> neighbor_matrix, num_neighbors, shifts = batch_naive_neighbor_list(
...     positions, cutoff, batch_ptr=batch_ptr, max_neighbors=max_neighbors,
...     pbc=pbc, cell=cell
... )
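batch_idx and batch_ptr carry the same information when each system's atoms are stored contiguously, as the concatenated layout implies. Converting batch_idx to batch_ptr requires the number of systems, typically `batch_idx[-1] + 1`. Reading that value from a device array is the host synchronization mentioned for neighbor_list. The plain-Python sketch below shows both directions. `idx_to_ptr` and `ptr_to_idx` are hypothetical helpers, and sorted batch_idx is assumed.

```python
from itertools import accumulate


def idx_to_ptr(batch_idx, num_systems=None):
    """batch_idx (sorted system id per atom) -> batch_ptr (CSR offsets)."""
    if num_systems is None:
        # On a GPU array, this read is the device-to-host synchronization point.
        num_systems = batch_idx[-1] + 1 if batch_idx else 0
    counts = [0] * num_systems
    for s in batch_idx:
        counts[s] += 1
    return [0] + list(accumulate(counts))


def ptr_to_idx(batch_ptr):
    """batch_ptr -> batch_idx by repeating each system id by its atom count."""
    return [s for s in range(len(batch_ptr) - 1)
            for _ in range(batch_ptr[s + 1] - batch_ptr[s])]


print(idx_to_ptr([0, 0, 0, 1, 1, 2]))  # [0, 3, 5, 6]
print(ptr_to_idx([0, 3, 5, 6]))        # [0, 0, 0, 1, 1, 2]
```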
See also
nvalchemiops.neighbors.batch_naive.batch_naive_neighbor_matrix – Core warp launcher
nvalchemiops.jax.neighbors.naive.naive_neighbor_list – Non-batched version
batch_cell_list – Cell list method for large systems
Batched Cell List Algorithm
- nvalchemiops.jax.neighbors.batch_cell_list(positions, cutoff, cell=None, pbc=None, batch_idx=None, batch_ptr=None, max_neighbors=None, max_total_cells=None, neighbor_matrix_shifts=None, return_neighbor_list=False)
Build and query spatial cell lists for batch of systems.
- Parameters:
positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates.
cutoff (float) – Cutoff distance for neighbor detection.
cell (jax.Array, shape (num_systems, 3, 3), dtype=float32 or float64, optional) – Cell matrices defining lattice vectors. Default is identity matrix.
pbc (jax.Array, shape (num_systems, 3), dtype=bool, optional) – Periodic boundary condition flags. Default is all True.
batch_idx (jax.Array, shape (total_atoms,), dtype=int32, optional) – Batch indices for each atom.
batch_ptr (jax.Array, shape (num_systems + 1,), dtype=int32, optional) – Cumulative atom counts defining system boundaries.
max_neighbors (int, optional) – Maximum number of neighbors per atom. If None, will be estimated.
max_total_cells (int, optional) – Maximum number of cells to allocate. If None, will be estimated.
neighbor_matrix_shifts (jax.Array, shape (total_atoms, max_neighbors, 3), dtype=int32, optional) – Pre-allocated shift vectors array. If None, will be allocated internally. Pass in a pre-shaped array to hint buffer reuse to XLA; note that JAX returns a new array rather than mutating the input.
return_neighbor_list (bool, optional) – If True, convert result to COO neighbor list format. Default is False.
- Returns:
neighbor_data (jax.Array) – If return_neighbor_list=False (default): neighbor_matrix with shape (total_atoms, max_neighbors), dtype int32. If return_neighbor_list=True: neighbor_list with shape (2, num_pairs), dtype int32, in COO format.
neighbor_count (jax.Array) – If return_neighbor_list=False: num_neighbors with shape (total_atoms,), dtype int32. If return_neighbor_list=True: neighbor_ptr with shape (total_atoms + 1,), dtype int32.
shift_data (jax.Array) – If return_neighbor_list=False (default): neighbor_matrix_shifts with shape (total_atoms, max_neighbors, 3), dtype int32. If return_neighbor_list=True: neighbor_list_shifts with shape (num_pairs, 3), dtype int32. Periodic shift vectors for each neighbor relationship.
- Return type:
tuple[Array, Array, Array]
See also
batch_build_cell_list – Build cell list separately
batch_query_cell_list – Query cell list separately
batch_naive_neighbor_list – Naive O(N^2) method
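The shift data are integer lattice translations rather than Cartesian offsets. A minimal sketch of recovering per-pair displacement vectors from the COO output, assuming the rows of each cell matrix are lattice vectors and the convention r_ij = r_j - r_i + s_ij · cell (both are assumptions to verify against the library). pair_displacements is a hypothetical helper, not part of nvalchemiops.

```python
import jax.numpy as jnp


def pair_displacements(positions, neighbor_list, shifts, cell, batch_idx):
    """Hypothetical helper: r_j - r_i + shift @ cell for every COO pair.

    positions: (total_atoms, 3); neighbor_list: (2, num_pairs) with rows [i, j];
    shifts: (num_pairs, 3) int32; cell: (num_systems, 3, 3); batch_idx: (total_atoms,).
    Assumes the rows of each cell matrix are lattice vectors.
    """
    i, j = neighbor_list[0], neighbor_list[1]
    # Cell matrix of the system each source atom i belongs to.
    pair_cell = cell[batch_idx[i]]  # (num_pairs, 3, 3)
    # Integer lattice shift -> Cartesian offset, i.e. shift @ cell per pair.
    offset = jnp.einsum("pk,pkd->pd", shifts.astype(positions.dtype), pair_cell)
    return positions[j] - positions[i] + offset
```

For two atoms at x = 0.5 and x = 9.5 in a 10-unit cubic box, the pair (0, 1) with shift (-1, 0, 0) yields the short periodic-image displacement (-1, 0, 0).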
- nvalchemiops.jax.neighbors.batch_cell_list.batch_build_cell_list(positions, batch_idx=None, batch_ptr=None, cell=None, pbc=None, cutoff=5.0, max_total_cells=None)
Build spatial cell lists for a batch of systems.
- Parameters:
positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates.
batch_idx (jax.Array, shape (total_atoms,), dtype=int32, optional) – Batch indices.
batch_ptr (jax.Array, shape (num_systems + 1,), dtype=int32, optional) – Cumulative atom counts.
cell (jax.Array, shape (num_systems, 3, 3), dtype=float32 or float64, optional) – Cell matrices.
pbc (jax.Array, shape (num_systems, 3), dtype=bool, optional) – PBC flags.
cutoff (float, optional) – Cutoff distance. Default is 5.0.
max_total_cells (int, optional) – Maximum cells. If None, will be estimated.
- Returns:
cells_per_dimension (jax.Array, shape (num_systems, 3), dtype=int32) – Number of cells in x, y, z directions for each system.
atom_periodic_shifts (jax.Array, shape (total_atoms, 3), dtype=int32) – Periodic boundary crossings for each atom.
atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32) – 3D cell coordinates for each atom.
atoms_per_cell_count (jax.Array, shape (max_total_cells,), dtype=int32) – Number of atoms in each cell.
cell_atom_start_indices (jax.Array, shape (max_total_cells,), dtype=int32) – Starting index in cell_atom_list for each cell.
cell_atom_list (jax.Array, shape (total_atoms,), dtype=int32) – Flattened list of atom indices organized by cell.
neighbor_search_radius (jax.Array, shape (num_systems, 3), dtype=int32) – Search radius in neighboring cells for each system.
cell_origin (jax.Array, shape (3,), dtype same as positions) – Cell origin point (currently zeros).
- Return type:
tuple[Array, Array, Array, Array, Array, Array, Array, Array]
Notes
When calling inside jax.jit, max_total_cells must be provided to avoid calling estimate_batch_cell_list_sizes, which is not JIT-compatible.
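The cell-list outputs form a CSR-like layout: atoms_per_cell_count is a histogram over cells, cell_atom_start_indices is its exclusive prefix sum, and cell_atom_list groups atom indices by cell. A minimal single-system sketch of that layout in pure JAX, assuming fractional coordinates in [0, 1) and an x-fastest flattening of the 3D cell coordinates (the Warp kernels may order cells and atoms differently). build_cell_csr is hypothetical. The fixed length=max_total_cells buffer is why that size has to be concrete under jax.jit.

```python
import jax.numpy as jnp


def build_cell_csr(frac_positions, cells_per_dim, max_total_cells):
    """Hypothetical sketch: bin atoms (fractional coords in [0, 1)) into a flat cell grid."""
    # 3D integer cell coordinates, analogous to atom_to_cell_mapping.
    cell_coords = jnp.floor(frac_positions * cells_per_dim).astype(jnp.int32)
    cell_coords = jnp.clip(cell_coords, 0, cells_per_dim - 1)
    # Flatten (cx, cy, cz) -> linear index with x varying fastest (assumed ordering).
    nx, ny = cells_per_dim[0], cells_per_dim[1]
    flat = cell_coords[:, 0] + nx * (cell_coords[:, 1] + ny * cell_coords[:, 2])
    # atoms_per_cell_count: histogram over a fixed-size buffer.
    counts = jnp.bincount(flat, length=max_total_cells)
    # cell_atom_start_indices: exclusive prefix sum of the counts.
    starts = jnp.concatenate([jnp.zeros(1, jnp.int32), jnp.cumsum(counts)[:-1]]).astype(jnp.int32)
    # cell_atom_list: atom indices grouped by cell (JAX's argsort is stable by default).
    cell_atom_list = jnp.argsort(flat).astype(jnp.int32)
    return cell_coords, counts, starts, cell_atom_list
```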
- nvalchemiops.jax.neighbors.batch_cell_list.batch_query_cell_list(positions, batch_idx=None, batch_ptr=None, cutoff=5.0, cell=None, pbc=None, cells_per_dimension=None, atom_periodic_shifts=None, atom_to_cell_mapping=None, cell_atom_start_indices=None, cell_atom_list=None, atoms_per_cell_count=None, neighbor_search_radius=None, max_neighbors=None, neighbor_matrix=None, num_neighbors=None, neighbor_matrix_shifts=None, rebuild_flags=None)
Query batch cell lists to find neighbors.
- Parameters:
positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates.
batch_idx (jax.Array, shape (total_atoms,), dtype=int32, optional) – Batch indices.
batch_ptr (jax.Array, shape (num_systems + 1,), dtype=int32, optional) – Cumulative atom counts.
cutoff (float, optional) – Cutoff distance.
cell (jax.Array, shape (num_systems, 3, 3), dtype=float32 or float64, optional) – Cell matrices.
pbc (jax.Array, shape (num_systems, 3), dtype=bool, optional) – PBC flags.
cells_per_dimension (jax.Array, shape (num_systems, 3), dtype=int32, optional) – Cells per dimension.
atom_periodic_shifts (jax.Array, shape (total_atoms, 3), dtype=int32, optional) – Periodic shifts for each atom (output from batch_build_cell_list).
atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32, optional) – Cell mappings.
cell_atom_start_indices (jax.Array, shape (max_total_cells,), dtype=int32, optional) – Start indices.
cell_atom_list (jax.Array, shape (total_atoms,), dtype=int32, optional) – Cell atom list.
atoms_per_cell_count (jax.Array, shape (max_total_cells,), dtype=int32, optional) – Number of atoms assigned to each cell. Output from batch_build_cell_list.
neighbor_search_radius (jax.Array, shape (num_systems, 3), dtype=int32, optional) – Search radius.
max_neighbors (int, optional) – Maximum neighbors per atom.
neighbor_matrix (jax.Array, shape (total_atoms, max_neighbors), dtype=int32, optional) – Pre-allocated neighbor matrix.
num_neighbors (jax.Array, shape (total_atoms,), dtype=int32, optional) – Pre-allocated neighbors count array.
neighbor_matrix_shifts (jax.Array, shape (total_atoms, max_neighbors, 3), dtype=int32, optional) – Pre-allocated shift vectors array. Pass in a pre-shaped array to hint buffer reuse to XLA; note that JAX returns a new array rather than mutating the input.
rebuild_flags (Array | None)
- Returns:
neighbor_matrix (jax.Array, shape (total_atoms, max_neighbors), dtype=int32) – Neighbor matrix.
num_neighbors (jax.Array, shape (total_atoms,), dtype=int32) – Neighbors count.
neighbor_matrix_shifts (jax.Array, shape (total_atoms, max_neighbors, 3), dtype=int32) – Periodic shifts for each neighbor relationship.
- Return type:
tuple[Array, Array, Array]
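neighbor_search_radius is read here as the number of cells scanned on either side of an atom's own cell in each dimension. A pure-Python sketch of the resulting stencil for one atom, assuming periodic dimensions wrap with modulo and non-periodic ones discard out-of-range cells. neighbor_cells is hypothetical and ignores the periodic shift bookkeeping that the real query performs.

```python
import itertools

import jax.numpy as jnp


def neighbor_cells(center, cells_per_dim, search_radius, pbc):
    """Hypothetical sketch: cell coordinates scanned for an atom in cell `center`.

    Periodic dimensions wrap with modulo; non-periodic dimensions drop
    cells outside the grid. Returns an (n, 3) int32 array.
    """
    ranges = [range(-int(r), int(r) + 1) for r in search_radius]
    out = []
    for offset in itertools.product(*ranges):
        c = [int(center[d]) + offset[d] for d in range(3)]
        inside = True
        for d in range(3):
            n = int(cells_per_dim[d])
            if bool(pbc[d]):
                c[d] %= n  # wrap into the periodic grid
            elif not 0 <= c[d] < n:
                inside = False  # outside a non-periodic boundary
        if inside:
            out.append(c)
    return jnp.array(out, dtype=jnp.int32)
```

With radius 1 in every dimension, a fully periodic grid yields 3 × 3 × 3 = 27 candidate cells, while a corner cell of a non-periodic grid keeps only 8.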
Batched Dual Cutoff Algorithm
- nvalchemiops.jax.neighbors.batch_naive_neighbor_list_dual_cutoff(positions, cutoff1, cutoff2, batch_idx=None, batch_ptr=None, pbc=None, cell=None, max_neighbors1=None, max_neighbors2=None, half_fill=False, fill_value=None, return_neighbor_list=False, neighbor_matrix1=None, neighbor_matrix2=None, neighbor_matrix_shifts1=None, neighbor_matrix_shifts2=None, num_neighbors1=None, num_neighbors2=None, shift_range_per_dimension=None, num_shifts_per_system=None, max_shifts_per_system=None, max_atoms_per_system=None, rebuild_flags=None, wrap_positions=True)
Compute batched neighbor lists for two cutoff distances using naive O(N^2) algorithm.
This function builds two neighbor matrices simultaneously for different cutoff distances in a batched manner, which is more efficient than calling the single-cutoff function twice.
- Parameters:
positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Concatenated Cartesian coordinates for all systems.
cutoff1 (float) – First cutoff distance (typically smaller).
cutoff2 (float) – Second cutoff distance (typically larger).
batch_idx (jax.Array, shape (total_atoms,), dtype=int32, optional) – System index for each atom.
batch_ptr (jax.Array, shape (num_systems + 1,), dtype=int32, optional) – Cumulative atom counts defining system boundaries.
pbc (jax.Array, shape (num_systems, 3) or (1, 3), dtype=bool, optional) – Periodic boundary condition flags for each dimension.
cell (jax.Array, shape (num_systems, 3, 3) or (1, 3, 3), dtype=float32 or float64, optional) – Cell matrices defining lattice vectors in Cartesian coordinates.
max_neighbors1 (int, optional) – Maximum number of neighbors per atom for cutoff1.
max_neighbors2 (int, optional) – Maximum number of neighbors per atom for cutoff2.
half_fill (bool, optional - default = False) – If True, only store relationships where i < j to avoid double counting.
fill_value (int, optional) – Value to use for padding in neighbor matrices. Default is total_atoms.
return_neighbor_list (bool, optional - default = False) – If True, convert neighbor matrices to neighbor list (idx_i, idx_j) format.
neighbor_matrix1 (jax.Array, shape (total_atoms, max_neighbors1), dtype=int32, optional) – Pre-allocated first neighbor matrix.
neighbor_matrix2 (jax.Array, shape (total_atoms, max_neighbors2), dtype=int32, optional) – Pre-allocated second neighbor matrix.
neighbor_matrix_shifts1 (jax.Array, shape (total_atoms, max_neighbors1, 3), dtype=int32, optional) – Pre-allocated first shift matrix for PBC.
neighbor_matrix_shifts2 (jax.Array, shape (total_atoms, max_neighbors2, 3), dtype=int32, optional) – Pre-allocated second shift matrix for PBC.
num_neighbors1 (jax.Array, shape (total_atoms,), dtype=int32, optional) – Pre-allocated first neighbor count array.
num_neighbors2 (jax.Array, shape (total_atoms,), dtype=int32, optional) – Pre-allocated second neighbor count array.
shift_range_per_dimension (jax.Array, shape (num_systems, 3), dtype=int32, optional) – Pre-computed shift ranges for PBC.
num_shifts_per_system (jax.Array, shape (num_systems,), dtype=int32, optional) – Number of periodic shifts per system.
max_shifts_per_system (int, optional) – Maximum per-system shift count (launch dimension).
max_atoms_per_system (int, optional) – Maximum number of atoms in any system (for PBC batched dispatch).
wrap_positions (bool, default=True) – If True, wrap input positions into the primary cell before neighbor search. Set to False when positions are already wrapped (e.g. by a preceding integration step) to save two GPU kernel launches per call.
rebuild_flags (Array | None)
- Returns:
results – Variable-length tuple depending on input parameters:
No PBC, matrix format: (neighbor_matrix1, num_neighbors1, neighbor_matrix2, num_neighbors2)
No PBC, list format: (neighbor_list1, neighbor_ptr1, neighbor_list2, neighbor_ptr2)
With PBC, matrix format: (neighbor_matrix1, num_neighbors1, neighbor_matrix_shifts1, neighbor_matrix2, num_neighbors2, neighbor_matrix_shifts2)
With PBC, list format: (neighbor_list1, neighbor_ptr1, unit_shifts1, neighbor_list2, neighbor_ptr2, unit_shifts2)
- Return type:
See also
nvalchemiops.neighbors.batch_naive_dual_cutoff.batch_naive_neighbor_matrix_dual_cutoff – Core warp launcher (no PBC)
nvalchemiops.neighbors.batch_naive_dual_cutoff.batch_naive_neighbor_matrix_pbc_dual_cutoff – Core warp launcher (with PBC)
batch_naive_neighbor_list – Single cutoff version
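Both neighbor sets derive from the same pairwise distances, so a fused pass can compute distances once and threshold them twice. A brute-force O(N^2) sketch for a single non-periodic system that returns boolean masks instead of padded matrices. dual_cutoff_masks is hypothetical, and the strict "<" comparison against the cutoff is an assumption.

```python
import jax.numpy as jnp


def dual_cutoff_masks(positions, cutoff1, cutoff2, half_fill=False):
    """Hypothetical sketch: neighbor masks for two cutoffs from one distance matrix."""
    diff = positions[:, None, :] - positions[None, :, :]
    dist2 = jnp.sum(diff * diff, axis=-1)  # computed once, reused for both cutoffs
    n = positions.shape[0]
    row = jnp.arange(n)[:, None]
    col = jnp.arange(n)[None, :]
    # Exclude self-pairs; with half_fill keep only i < j, as half_fill=True does.
    keep = (row < col) if half_fill else (row != col)
    mask1 = (dist2 < cutoff1**2) & keep
    mask2 = (dist2 < cutoff2**2) & keep
    return mask1, mask2
```

When cutoff1 <= cutoff2, every pair in mask1 is also in mask2, and half_fill roughly halves the stored pairs.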
Rebuild Detection
- nvalchemiops.jax.neighbors.rebuild_detection.cell_list_needs_rebuild(current_positions, atom_to_cell_mapping, cells_per_dimension, cell, pbc)
Detect if spatial cell list requires rebuilding due to atomic motion.
- Parameters:
current_positions (jax.Array, shape (total_atoms, 3)) – Current atomic coordinates in Cartesian space.
atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32) – 3D cell coordinates for each atom from the existing cell list.
cells_per_dimension (jax.Array, shape (3,), dtype=int32) – Number of spatial cells in x, y, z directions.
cell (jax.Array, shape (1, 3, 3)) – Unit cell matrix for coordinate transformations.
pbc (jax.Array, shape (3,), dtype=bool) – Periodic boundary condition flags for x, y, z directions.
- Returns:
rebuild_needed – True if any atom has moved to a different cell requiring rebuild.
- Return type:
jax.Array, shape (1,), dtype=bool
Notes
This function is not differentiable and should not be used in JAX transformations that require gradients.
See also
nvalchemiops.neighbors.rebuild_detection.check_cell_list_rebuild – Core warp launcher
check_cell_list_rebuild_needed – Convenience wrapper that returns Python bool
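The rebuild test amounts to recomputing each atom's cell coordinates and comparing them with the stored atom_to_cell_mapping. A sketch for one system, assuming the rows of cell[0] are lattice vectors, periodic coordinates wrap into [0, 1), and non-periodic coordinates are clamped to the grid. cell_assignment_changed is hypothetical.

```python
import jax.numpy as jnp


def cell_assignment_changed(current_positions, atom_to_cell_mapping, cells_per_dimension, cell, pbc):
    """Hypothetical sketch: True if any atom now falls in a different cell."""
    # Fractional coordinates, assuming rows of cell[0] are lattice vectors.
    frac = current_positions @ jnp.linalg.inv(cell[0])
    # Periodic dimensions wrap into [0, 1); non-periodic ones are clamped below.
    frac = jnp.where(pbc, frac - jnp.floor(frac), frac)
    coords = jnp.floor(frac * cells_per_dimension).astype(jnp.int32)
    coords = jnp.clip(coords, 0, cells_per_dimension - 1)
    # Shape (1,) bool, matching the documented return shape.
    return jnp.any(coords != atom_to_cell_mapping).reshape(1)
```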
- nvalchemiops.jax.neighbors.rebuild_detection.neighbor_list_needs_rebuild(reference_positions, current_positions, skin_distance_threshold, cell=None, cell_inv=None, pbc=None)
Detect if neighbor list requires rebuilding due to excessive atomic motion.
When cell, cell_inv and pbc are all provided, uses the minimum-image convention (MIC) so that atoms crossing periodic boundaries are not spuriously flagged.
- Parameters:
reference_positions (jax.Array, shape (total_atoms, 3)) – Atomic positions when the neighbor list was last built.
current_positions (jax.Array, shape (total_atoms, 3)) – Current atomic positions to compare against reference.
skin_distance_threshold (float) – Maximum allowed displacement before neighbor list becomes invalid.
cell (jax.Array or None, optional) – Unit cell matrix, shape (1, 3, 3).
cell_inv (jax.Array or None, optional) – Inverse cell matrix, same shape as cell.
pbc (jax.Array or None, optional) – PBC flags, shape (3,), dtype=bool.
- Returns:
rebuild_needed – True if any atom has moved beyond skin distance.
- Return type:
jax.Array, shape (1,), dtype=bool
Notes
This function is not differentiable and should not be used in JAX transformations that require gradients.
See also
nvalchemiops.neighbors.rebuild_detection.check_neighbor_list_rebuild – Core warp launcher
check_neighbor_list_rebuild_needed – Convenience wrapper that returns Python bool
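A sketch of the skin test with the minimum-image correction, assuming the rows of cell[0] are lattice vectors and that the threshold is compared against each atom's displacement norm. moved_beyond_skin is hypothetical. A common convention (not confirmed by this page) is to pass half the Verlet skin as the threshold, since two atoms moving toward each other can each close half the gap.

```python
import jax.numpy as jnp


def moved_beyond_skin(reference_positions, current_positions, skin_distance_threshold,
                      cell=None, cell_inv=None, pbc=None):
    """Hypothetical sketch: True if any atom moved farther than the threshold."""
    disp = current_positions - reference_positions
    if cell is not None and cell_inv is not None and pbc is not None:
        # Minimum image: remove whole lattice translations along periodic axes.
        frac = disp @ cell_inv[0]
        frac = jnp.where(pbc, frac - jnp.round(frac), frac)
        disp = frac @ cell[0]
    dist = jnp.linalg.norm(disp, axis=-1)
    return jnp.any(dist > skin_distance_threshold).reshape(1)
```

An atom that crosses the boundary from x = 9.9 to x = 0.1 in a 10-unit box has moved only 0.2 under MIC, but 9.8 without it.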
- nvalchemiops.jax.neighbors.rebuild_detection.check_cell_list_rebuild_needed(current_positions, atom_to_cell_mapping, cells_per_dimension, cell, pbc)
Determine if spatial cell list requires rebuilding based on atomic motion.
This high-level convenience function determines whether a spatial cell list must be reconstructed because atoms have moved between spatial cells. The check runs on the GPU; unlike cell_list_needs_rebuild, the result is returned as a Python bool.
- Parameters:
current_positions (jax.Array, shape (total_atoms, 3)) – Current atomic coordinates to check against existing cell assignments.
atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32) – 3D cell coordinates assigned to each atom from existing cell list.
cells_per_dimension (jax.Array, shape (3,), dtype=int32) – Number of spatial cells in x, y, z directions from existing cell list.
cell (jax.Array, shape (1, 3, 3)) – Current unit cell matrix for coordinate transformations.
pbc (jax.Array, shape (3,), dtype=bool) – Current periodic boundary condition flags for x, y, z directions.
- Returns:
needs_rebuild – True if any atom has moved to a different cell requiring cell list rebuild.
- Return type:
bool
Notes
This function is not differentiable and should not be used in JAX transformations that require gradients.
See also
cell_list_needs_rebuild – Returns jax.Array instead of bool
- nvalchemiops.jax.neighbors.rebuild_detection.check_neighbor_list_rebuild_needed(reference_positions, current_positions, skin_distance_threshold, cell=None, cell_inv=None, pbc=None)
Determine if neighbor list requires rebuilding based on atomic motion.
When cell, cell_inv and pbc are all provided, uses the MIC displacement so that periodic boundary crossings are handled correctly.
- Parameters:
reference_positions (jax.Array, shape (total_atoms, 3)) – Atomic coordinates when the neighbor list was last constructed.
current_positions (jax.Array, shape (total_atoms, 3)) – Current atomic coordinates to compare against reference positions.
skin_distance_threshold (float) – Maximum allowed atomic displacement before neighbor list becomes invalid.
cell (jax.Array or None, optional) – Unit cell matrix, shape (1, 3, 3).
cell_inv (jax.Array or None, optional) – Inverse cell matrix, same shape as cell.
pbc (jax.Array or None, optional) – PBC flags, shape (3,), dtype=bool.
- Returns:
needs_rebuild – True if any atom has moved beyond skin distance requiring rebuild.
- Return type:
bool
See also
neighbor_list_needs_rebuild – Returns jax.Array instead of bool
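Returning a Python bool means moving the flag to the host, which blocks until the device result is ready and cannot happen inside jax.jit. Inside a compiled step, a (1,)-shaped flag such as the one returned by neighbor_list_needs_rebuild can instead drive jax.lax.cond. A toy sketch; maybe_rebuild is hypothetical and its "rebuild" branch only refreshes the reference positions.

```python
import jax
import jax.numpy as jnp


@jax.jit
def maybe_rebuild(flag, positions, reference_positions):
    """Hypothetical sketch: branch on a traced (1,)-shaped rebuild flag."""
    return jax.lax.cond(
        flag[0],
        lambda p, r: p,  # stand-in for a rebuild: new reference = current positions
        lambda p, r: r,  # no rebuild: keep the old reference
        positions,
        reference_positions,
    )
```

Outside jit, the bool-returning wrappers suit a simple host-side loop; inside jit, keep the flag as an array.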
Exceptions
- exception nvalchemiops.jax.neighbors.NeighborOverflowError(max_neighbors, num_neighbors)
Bases: Exception
Exception raised when the number of neighbors exceeds the maximum allowed.
This error indicates that the pre-allocated neighbor matrix is too small to hold all discovered neighbors. Increase the max_neighbors parameter or pass a larger pre-allocated array.
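One way to react to the error is to grow the capacity and retry. A minimal sketch, assuming the neighbor-list call raises NeighborOverflowError eagerly (outside jax.jit). build_with_retry is hypothetical and takes the exception class as an argument so it stays library-agnostic.

```python
def build_with_retry(build_fn, max_neighbors, overflow_error, max_attempts=4):
    """Hypothetical sketch: call build_fn(max_neighbors), doubling capacity on overflow.

    Returns (result, max_neighbors_used). Pass the library's
    NeighborOverflowError as `overflow_error`.
    """
    for _ in range(max_attempts):
        try:
            return build_fn(max_neighbors), max_neighbors
        except overflow_error:
            max_neighbors *= 2  # grow the per-atom neighbor capacity
    raise RuntimeError(f"neighbor capacity still too small after {max_attempts} attempts")
```

For example, build_fn could be lambda m: batch_naive_neighbor_list(positions, cutoff, batch_ptr=batch_ptr, max_neighbors=m). Each new max_neighbors value changes array shapes, so compiled callers will retrace.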
Utility Functions
Warning
The estimation and cell list building utilities work, but because the estimation functions
return data-dependent sizes, workflows that combine the two cannot be compiled with
jax.jit. To jax.jit an end-to-end workflow, pass max_total_cells explicitly
to the cell construction methods.
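The restriction comes from JAX itself: output shapes must be concrete at trace time. A toy sketch (not library code) of the usual workaround, marking the size as a static argument so that jax.jit receives a Python int. count_atoms_per_cell is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("max_total_cells",))
def count_atoms_per_cell(flat_cell_index, max_total_cells):
    """Hypothetical sketch: histogram with a compile-time buffer size."""
    # max_total_cells is a Python int fixed at trace time, so the output shape is static.
    return jnp.bincount(flat_cell_index, length=max_total_cells)
```

Each distinct max_total_cells value compiles a separate version, so it pays to choose one generous size, for example from an estimate computed once outside jit.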
- nvalchemiops.jax.neighbors.estimate_cell_list_sizes(positions, cell, cutoff, pbc=None, buffer_factor=1.5)
Estimate required cell list sizes based on atomic density.
- Parameters:
positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates in Cartesian space.
cell (jax.Array, shape (1, 3, 3), dtype=float32 or float64) – Cell matrix defining lattice vectors.
cutoff (float) – Cutoff distance for neighbor searching.
pbc (jax.Array, shape (3,) or (1, 3), dtype=bool, optional) – Periodic boundary condition flags. Default is all True.
buffer_factor (float, optional) – Buffer multiplier for cell count estimation. Default is 1.5.
- Returns:
max_total_cells (int) – Maximum total number of cells to allocate.
cells_per_dimension (jax.Array, shape (3,) or (1, 3), dtype=int32) – Estimated number of cells in each dimension.
neighbor_search_radius (jax.Array, shape (3,), dtype=int32) – Estimated search radius in neighboring cells.
- Return type:
tuple[int, Array, Array]
Notes
This function estimates cell list parameters based on atomic positions and density. The actual number of cells used will be determined during cell list construction.
Warning
This function is not compatible with jax.jit. The returned max_total_cells is used to determine array allocation sizes, which must be concrete (statically known) at JAX trace time. When using cell_list or build_cell_list inside jax.jit, provide max_total_cells explicitly to bypass this function.
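A sketch of one plausible geometric heuristic for sizing the cell grid, assuming the rows of cell[0] are lattice vectors and that each cell should be at least cutoff wide, measured between opposite cell faces. This is not the library's exact (density-based) rule; sketch_cell_counts is hypothetical. The int(...) conversion of a device value is precisely the step that cannot run under jax.jit.

```python
import math

import jax.numpy as jnp


def sketch_cell_counts(cell, cutoff, buffer_factor=1.5):
    """Hypothetical sketch: cells per dimension, search radius, and a buffered cell total."""
    a = cell[0]  # rows are lattice vectors (assumed)
    volume = jnp.abs(jnp.linalg.det(a))
    # Face normals: the face opposite lattice vector i is spanned by a_j and a_k.
    normals = jnp.stack([jnp.cross(a[1], a[2]), jnp.cross(a[2], a[0]), jnp.cross(a[0], a[1])])
    widths = volume / jnp.linalg.norm(normals, axis=-1)  # perpendicular box widths
    cells_per_dimension = jnp.maximum(jnp.floor(widths / cutoff), 1).astype(jnp.int32)
    # Cells to scan on either side so the cutoff sphere is covered.
    neighbor_search_radius = jnp.ceil(cutoff * cells_per_dimension / widths).astype(jnp.int32)
    # Concrete Python int: this host conversion is what breaks under jax.jit.
    max_total_cells = math.ceil(int(jnp.prod(cells_per_dimension)) * buffer_factor)
    return max_total_cells, cells_per_dimension, neighbor_search_radius
```

For a 10-unit cubic box and cutoff 3.0, this gives 3 cells per dimension, a search radius of 1, and ceil(27 × 1.5) = 41 cells.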
- nvalchemiops.jax.neighbors.estimate_batch_cell_list_sizes(positions, batch_ptr=None, batch_idx=None, cell=None, cutoff=5.0, pbc=None, buffer_factor=1.5)
Estimate required batch cell list sizes.
- Parameters:
positions (jax.Array, shape (total_atoms, 3), dtype=float32 or float64) – Atomic coordinates.
batch_ptr (jax.Array, shape (num_systems + 1,), dtype=int32, optional) – Cumulative atom counts.
batch_idx (jax.Array, shape (total_atoms,), dtype=int32, optional) – Batch indices for each atom.
cell (jax.Array, shape (num_systems, 3, 3), dtype=float32 or float64, optional) – Cell matrices for each system.
cutoff (float, optional) – Cutoff distance. Default is 5.0.
pbc (jax.Array, shape (num_systems, 3), dtype=bool, optional) – PBC flags.
buffer_factor (float, optional) – Buffer multiplier. Default is 1.5.
- Returns:
max_total_cells (int) – Maximum total cells to allocate.
cells_per_dimension (jax.Array, shape (num_systems, 3)) – Cells per dimension for each system.
neighbor_search_radius (jax.Array, shape (num_systems, 3)) – Search radius for each system.
- Return type:
tuple[int, Array, Array]
Warning
This function is not compatible with jax.jit. The returned max_total_cells is used to determine array allocation sizes, which must be concrete (statically known) at JAX trace time. When using batch_cell_list or batch_build_cell_list inside jax.jit, provide max_total_cells explicitly to bypass this function.
- nvalchemiops.jax.neighbors.neighbor_utils.allocate_cell_list(total_atoms, max_total_cells, neighbor_search_radius)
Allocate memory tensors for cell list data structures.
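- Parameters:
total_atoms (int) – Total number of atoms across all systems.
max_total_cells (int) – Maximum number of cells to allocate.
neighbor_search_radius (jax.Array, shape (3,) or (num_systems, 3), dtype=int32) – Search radius in neighboring cells. Its shape selects single-system or batched allocation.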
- Returns:
cells_per_dimension (jax.Array, shape (3,) or (num_systems, 3), dtype=int32) – Number of cells in x, y, z directions (to be filled by build_cell_list).
neighbor_search_radius (jax.Array, shape (3,) or (num_systems, 3), dtype=int32) – Radius of neighboring cells to search (passed through for convenience).
atom_periodic_shifts (jax.Array, shape (total_atoms, 3), dtype=int32) – Periodic boundary crossings for each atom (to be filled by build_cell_list).
atom_to_cell_mapping (jax.Array, shape (total_atoms, 3), dtype=int32) – 3D cell coordinates for each atom (to be filled by build_cell_list).
atoms_per_cell_count (jax.Array, shape (max_total_cells,), dtype=int32) – Number of atoms in each cell (to be filled by build_cell_list).
cell_atom_start_indices (jax.Array, shape (max_total_cells,), dtype=int32) – Starting index in cell_atom_list for each cell (to be filled by build_cell_list).
cell_atom_list (jax.Array, shape (total_atoms,), dtype=int32) – Flattened list of atom indices organized by cell (to be filled by build_cell_list).
- Return type:
tuple[Array, Array, Array, Array, Array, Array, Array]
Notes
This is a pure JAX utility function with no warp dependencies. It pre-allocates all tensors needed for cell list construction, supporting both single-system and batched operations based on the shape of neighbor_search_radius.
See also
nvalchemiops.neighbors.cell_list.build_cell_list – Warp launcher that uses these tensors
nvalchemiops.jax.neighbors.cell_list.build_cell_list – High-level JAX wrapper
nvalchemiops.jax.neighbors.batch_cell_list.batch_build_cell_list – Batched version
- nvalchemiops.jax.neighbors.neighbor_utils.prepare_batch_idx_ptr(batch_idx, batch_ptr, num_atoms)
Prepare batch index and pointer tensors from either representation.
Utility function to ensure both batch_idx and batch_ptr are available, computing one from the other if needed.
- Parameters:
batch_idx (jax.Array | None, shape (total_atoms,), dtype=int32) – Array indicating the batch index for each atom.
batch_ptr (jax.Array | None, shape (num_systems + 1,), dtype=int32) – Array indicating the start index of each batch in the atom list.
num_atoms (int) – Total number of atoms across all systems.
- Returns:
batch_idx (jax.Array, shape (total_atoms,), dtype=int32) – Prepared batch index tensor.
batch_ptr (jax.Array, shape (num_systems + 1,), dtype=int32) – Prepared batch pointer tensor.
- Raises:
ValueError – If both batch_idx and batch_ptr are None.
- Return type:
tuple[Array, Array]
Notes
This is a pure JAX utility function with no warp dependencies. It provides convenience for batch operations by converting between dense (batch_idx) and sparse (batch_ptr) batch representations.
See also
nvalchemiops.jax.neighbors.batch_naive.batch_naive_neighbor_list – Uses this for batch setup
nvalchemiops.jax.neighbors.batch_cell_list.batch_cell_list – Uses this for batch setup
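Converting between the two representations is a prefix sum in one direction and a repeat in the other. A minimal sketch, assuming the atoms of each system are stored contiguously and in system order. ptr_to_idx and idx_to_ptr are hypothetical helpers, not the library's implementation.

```python
import jax.numpy as jnp


def ptr_to_idx(batch_ptr, num_atoms):
    """Hypothetical sketch: batch_ptr (num_systems + 1,) -> batch_idx (num_atoms,)."""
    counts = jnp.diff(batch_ptr)  # atoms per system
    systems = jnp.arange(counts.shape[0], dtype=jnp.int32)
    # total_repeat_length fixes the output size, keeping the shape static.
    return jnp.repeat(systems, counts, total_repeat_length=num_atoms)


def idx_to_ptr(batch_idx, num_systems):
    """Hypothetical sketch: batch_idx (num_atoms,) -> batch_ptr (num_systems + 1,)."""
    counts = jnp.bincount(batch_idx, length=num_systems)
    return jnp.concatenate([jnp.zeros(1, jnp.int32), jnp.cumsum(counts).astype(jnp.int32)])
```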
- nvalchemiops.jax.neighbors.neighbor_utils.estimate_max_neighbors(cutoff, atomic_density=0.2, safety_factor=1.0)
Estimate maximum neighbors per atom based on volume calculations.
Uses atomic density and cutoff volume to estimate a conservative upper bound on the number of neighbors any atom could have. This is a pure Python function with no framework dependencies.
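- Parameters:
cutoff (float) – Cutoff distance in Cartesian units.
atomic_density (float, optional) – Number of atoms per unit volume. Default is 0.2.
safety_factor (float, optional) – Multiplier applied to the estimate. Default is 1.0.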
- Returns:
max_neighbors_estimate – Conservative estimate of maximum neighbors per atom. Returns 0 for empty systems.
- Return type:
int
Notes
The estimation uses the formula:
\[\text{neighbors} = \text{safety\_factor} \times \text{density} \times V_{\text{sphere}}\]
where the cutoff sphere volume is:
\[V_{\text{sphere}} = \frac{4}{3}\pi r^3\]
with \(r\) the cutoff distance. The result is rounded up to the next multiple of 16 for memory alignment.
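The formula written out in pure Python, as a sketch. Whether the library rounds to an integer before aligning to 16 exactly this way is an assumption; estimate_max_neighbors_sketch is hypothetical.

```python
import math


def estimate_max_neighbors_sketch(cutoff, atomic_density=0.2, safety_factor=1.0):
    """Hypothetical sketch: safety * density * (4/3) pi r^3, rounded up to a multiple of 16."""
    volume = 4.0 / 3.0 * math.pi * cutoff**3
    raw = math.ceil(safety_factor * atomic_density * volume)
    return 16 * math.ceil(raw / 16)
```

With the defaults and cutoff = 5.0: 0.2 × (4/3)π × 125 ≈ 104.7, which rounds up to 105 and then aligns to 112.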
- nvalchemiops.jax.neighbors.neighbor_utils.get_neighbor_list_from_neighbor_matrix(neighbor_matrix, num_neighbors, neighbor_shift_matrix=None, fill_value=-1)
Convert neighbor matrix format to neighbor list format.
- Parameters:
neighbor_matrix (jax.Array, shape (total_atoms, max_neighbors), dtype=int32) – The neighbor matrix with neighbor atom indices.
num_neighbors (jax.Array, shape (total_atoms,), dtype=int32) – The number of neighbors for each atom.
neighbor_shift_matrix (jax.Array | None, shape (total_atoms, max_neighbors, 3), dtype=int32) – Optional neighbor shift matrix with periodic shift vectors.
fill_value (int, default=-1) – The fill value used in the neighbor matrix to indicate empty slots. This is used to create a mask from the neighbor matrix.
- Returns:
neighbor_list (jax.Array, shape (2, num_pairs), dtype=int32) – The neighbor list in COO format [source_atoms, target_atoms].
neighbor_ptr (jax.Array, shape (total_atoms + 1,), dtype=int32) – CSR-style pointer array where neighbor_ptr[i]:neighbor_ptr[i+1] gives the range of neighbors for atom i in the flattened neighbor list.
neighbor_list_shifts (jax.Array, shape (num_pairs, 3), dtype=int32) – The neighbor shift vectors (only returned if neighbor_shift_matrix is not None).
- Raises:
ValueError – If the max number of neighbors is larger than the neighbor matrix width.
- Return type:
Notes
This is a pure JAX utility function with no warp dependencies. It converts from the fixed-width matrix format to the variable-width list format by masking out fill values and flattening the result.
See also
nvalchemiops.jax.neighbors.naive.naive_neighbor_list – Uses this for format conversion
nvalchemiops.jax.neighbors.cell_list.cell_list – Uses this for format conversion
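A pure-JAX sketch of the conversion, assuming valid entries are exactly those not equal to fill_value and that they occupy the leading num_neighbors slots of each row. matrix_to_coo is hypothetical. Because jnp.nonzero has a data-dependent output size, this version only runs eagerly; inside jax.jit it would need a static size= bound.

```python
import jax.numpy as jnp


def matrix_to_coo(neighbor_matrix, num_neighbors, fill_value=-1):
    """Hypothetical sketch: padded neighbor matrix -> COO list plus CSR pointer."""
    valid = neighbor_matrix != fill_value
    # Row-major order keeps pairs grouped by source atom.
    rows, cols = jnp.nonzero(valid)
    neighbor_list = jnp.stack([rows, neighbor_matrix[rows, cols]]).astype(jnp.int32)
    # neighbor_ptr[i]:neighbor_ptr[i + 1] spans atom i's pairs.
    neighbor_ptr = jnp.concatenate([jnp.zeros(1, jnp.int32), jnp.cumsum(num_neighbors).astype(jnp.int32)])
    return neighbor_list, neighbor_ptr
```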
- nvalchemiops.jax.neighbors.neighbor_utils.compute_naive_num_shifts(cell, cutoff, pbc)
Compute periodic image shifts needed for neighbor searching.
- Parameters:
cell (jax.Array, shape (num_systems, 3, 3)) – Cell matrices defining lattice vectors in Cartesian coordinates. Each 3x3 matrix represents one system’s periodic cell.
cutoff (float) – Cutoff distance for neighbor searching in Cartesian units. Must be positive and typically less than half the minimum cell dimension.
pbc (jax.Array, shape (num_systems, 3), dtype=bool) – Periodic boundary condition flags for each dimension. True enables periodicity in that direction.
- Returns:
shift_range (jax.Array, shape (num_systems, 3), dtype=int32) – Maximum shift indices in each dimension for each system.
num_shifts (jax.Array, shape (num_systems,), dtype=int32) – Number of periodic shifts for each system.
max_shifts (int) – Maximum per-system shift count across all systems.
- Raises:
ValueError – If any per-system shift count exceeds int32 range.
- Return type:
tuple[Array, Array, int]
See also
nvalchemiops.neighbors.neighbor_utils._compute_naive_num_shifts – Warp kernel
Notes
This function must be called outside the jax.jit scope. The returned max_shifts is a Python int needed to determine launch dimensions, which cannot be traced. This is an inherent limitation: array shapes must be known at trace time in JAX.
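A sketch of the per-dimension image count, assuming the rows of each cell are lattice vectors, so that the norm of reciprocal vector b_k equals one over the spacing between lattice planes along k. Each periodic dimension then needs ceil(cutoff · |b_k|) images on either side. The library may count shifts differently (for example, by exploiting inversion symmetry); naive_shift_counts is hypothetical.

```python
import jax.numpy as jnp


def naive_shift_counts(cell, cutoff, pbc):
    """Hypothetical sketch: (shift_range, num_shifts, max_shifts) for (num_systems, 3, 3) cells."""
    # Reciprocal vectors b_k are the rows of inv(cell).T (rows of cell = lattice vectors).
    recip = jnp.swapaxes(jnp.linalg.inv(cell), -1, -2)
    inv_spacing = jnp.linalg.norm(recip, axis=-1)  # 1 / plane spacing, (num_systems, 3)
    shift_range = jnp.ceil(cutoff * inv_spacing).astype(jnp.int32)
    shift_range = jnp.where(pbc, shift_range, 0)  # no images along non-periodic axes
    num_shifts = jnp.prod(2 * shift_range + 1, axis=-1).astype(jnp.int32)
    # Python int for launch sizing: this host conversion is why jit is excluded.
    return shift_range, num_shifts, int(num_shifts.max())
```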