Neighbor Lists#
Neighbor lists identify atom pairs within a cutoff distance. They are the foundation for computing interatomic interactions, including machine-learned interatomic potentials and dispersion corrections. ALCHEMI Toolkit-Ops provides GPU-accelerated neighbor list algorithms via NVIDIA Warp, with bindings for both PyTorch and JAX.
Tip
Start with the unified neighbor_list function
(nvalchemiops.torch.neighbors.neighbor_list for PyTorch,
nvalchemiops.jax.neighbors.neighbor_list for JAX).
It automatically selects the best algorithm for your system size and handles
both single and batched inputs.
Why Neighbor Lists Matter for Performance#
Neighbor list construction can dominate runtime in atomistic foundation models:
Naive algorithms scale as \(O(N^2)\): checking all atom pairs becomes prohibitive for large systems (roughly 2000 atoms and above, depending on structure and hardware)
Repeated construction: Training loops and MD simulations rebuild neighbor lists frequently—every step or every few steps
Memory bandwidth: Large neighbor matrices can bottleneck GPU throughput
ALCHEMI Toolkit-Ops addresses these challenges with O(N) cell list algorithms, efficient batch processing for heterogeneous datasets, and memory layouts optimized for GPU access patterns. See performance considerations for guidance.
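To make the scaling argument concrete, here is a minimal pure-Python sketch comparing an all-pairs search with a cell-based search in a non-periodic box. It is an illustration only, not the library's GPU implementation; `naive_pairs` and `cell_list_pairs` are hypothetical names, and the strict `<` cutoff test is an assumption.

```python
import itertools
import math
import random
from collections import defaultdict


def naive_pairs(positions, cutoff):
    """O(N^2) reference: test every unordered pair once."""
    pairs = set()
    for i, j in itertools.combinations(range(len(positions)), 2):
        if math.dist(positions[i], positions[j]) < cutoff:
            pairs.add((i, j))
    return pairs


def cell_list_pairs(positions, cutoff):
    """Bin atoms into cubic cells of edge `cutoff`, then test only the
    27 cells surrounding each occupied cell (non-periodic box)."""
    cells = defaultdict(list)
    for idx, point in enumerate(positions):
        key = tuple(int(math.floor(c / cutoff)) for c in point)
        cells[key].append(idx)

    pairs = set()
    for (cx, cy, cz), members in cells.items():
        for dx, dy, dz in itertools.product((-1, 0, 1), repeat=3):
            # .get() does not create empty cells, so the dict is not mutated here.
            for j in cells.get((cx + dx, cy + dy, cz + dz), ()):
                for i in members:
                    if i < j and math.dist(positions[i], positions[j]) < cutoff:
                        pairs.add((i, j))
    return pairs


random.seed(0)
positions = [[random.uniform(0.0, 20.0) for _ in range(3)] for _ in range(300)]
same_result = naive_pairs(positions, 5.0) == cell_list_pairs(positions, 5.0)
```

Because the cell edge equals the cutoff, two atoms within the cutoff can differ by at most one cell index along each axis, so the 27-cell stencil finds every pair. At roughly uniform density the work per atom is then bounded, which is where the \(O(N)\) behavior comes from.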
Quick Start#
The neighbor_list function provides a unified interface that automatically
dispatches to the optimal algorithm based on system size and whether batch
indices are provided.
Single system with >2000 atoms
from nvalchemiops.torch.neighbors import neighbor_list
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc, method="cell_list"
)
Dispatches to cell_list(), an \(O(N)\) algorithm
using spatial decomposition.
Single system with <2000 atoms
from nvalchemiops.torch.neighbors import neighbor_list
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc, method="naive"
)
Dispatches to naive_neighbor_list(), an \(O(N^2)\)
algorithm with lower overhead.
Multiple systems with >2000 atoms each
from nvalchemiops.torch.neighbors import neighbor_list
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cells, pbc=pbc,
batch_idx=batch_idx, method="batch_cell_list"
)
Dispatches to batch_cell_list(), an \(O(N)\)
algorithm for heterogeneous batches.
Multiple systems with <2000 atoms each
from nvalchemiops.torch.neighbors import neighbor_list
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cells, pbc=pbc,
batch_idx=batch_idx, method="batch_naive"
)
Dispatches to batch_naive_neighbor_list(), an
\(O(N^2)\) algorithm for batched small systems.
Single system with >5000 atoms
from nvalchemiops.jax.neighbors import neighbor_list
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc, method="cell_list"
)
Dispatches to cell_list(), an \(O(N)\) algorithm
using spatial decomposition.
Single system with <5000 atoms
from nvalchemiops.jax.neighbors import neighbor_list
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc, method="naive"
)
Dispatches to naive_neighbor_list(), an \(O(N^2)\)
algorithm with lower overhead.
Multiple systems with >5000 atoms each
from nvalchemiops.jax.neighbors import neighbor_list
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cells, pbc=pbc,
batch_idx=batch_idx, method="batch_cell_list"
)
Dispatches to batch_cell_list(), an \(O(N)\)
algorithm for heterogeneous batches.
Multiple systems with <5000 atoms each
from nvalchemiops.jax.neighbors import neighbor_list
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cells, pbc=pbc,
batch_idx=batch_idx, method="batch_naive"
)
Dispatches to batch_naive_neighbor_list(), an
\(O(N^2)\) algorithm for batched small systems.
Note
When method is not specified, neighbor_list automatically selects based on
average system size and whether batch_idx is provided.
The crossover point depends on system density and cutoff radius—benchmark
your workload to find the optimal threshold. The default threshold is 2000
atoms for PyTorch and 5000 atoms for JAX.
Data Formats#
ALCHEMI Toolkit-Ops supports two output formats for neighbor data:
Neighbor Matrix (default)
: Fixed-size array of shape (num_atoms, max_neighbors) where each row
contains the neighbor indices for that atom, padded with a fill value.
Returns (neighbor_matrix, num_neighbors, neighbor_matrix_shifts).
Neighbor List (COO format)
: Sparse array of shape (2, num_pairs) containing [source_atoms, target_atoms].
Returns (neighbor_list, neighbor_ptr, neighbor_list_shifts) where
neighbor_ptr is a CSR-style pointer array. The first set of atoms (nominally
source_atoms) is guaranteed to be sorted.
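The relationship between the two formats can be sketched in plain Python. This is an illustration under simple assumptions (lists instead of tensors, no shift vectors); `matrix_to_coo` is a hypothetical helper, not the library's conversion routine.

```python
def matrix_to_coo(neighbor_matrix, num_neighbors):
    """Flatten a padded neighbor matrix into COO pairs plus a CSR-style pointer.

    Hypothetical helper for illustration only.
    """
    sources, targets, ptr = [], [], [0]
    for i, (row, count) in enumerate(zip(neighbor_matrix, num_neighbors)):
        # Only the first `count` entries of a row are real neighbors;
        # everything after that is padding.
        sources.extend([i] * count)
        targets.extend(row[:count])
        ptr.append(ptr[-1] + count)
    return [sources, targets], ptr


# Three atoms, max_neighbors = 2, padded with the fill value 3 (= num_atoms).
neighbor_matrix = [[1, 2], [0, 3], [0, 3]]
num_neighbors = [2, 1, 1]
coo, ptr = matrix_to_coo(neighbor_matrix, num_neighbors)
# coo == [[0, 0, 1, 2], [1, 2, 0, 0]] and ptr == [0, 2, 3, 4]
```

Walking the rows in order is what makes the source indices sorted, and the neighbors of atom i are `coo[1][ptr[i]:ptr[i + 1]]`.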
When to Use Each Format#
Neighbor Matrix is preferred when:
Using torch.compile or jax.jit (fixed memory layout avoids graph breaks)
Systems have dense, uniform neighbor distributions
Cache-friendly access patterns are important
Neighbor List (COO) is preferred when:
Integrating with graph neural network libraries (PyG, DGL)
Systems are sparse with highly variable neighbors per atom
Memory efficiency is critical
Switching Formats#
# PyTorch: get COO format directly
neighbor_list_coo, neighbor_ptr, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc, return_neighbor_list=True
)
# Or convert from matrix format
from nvalchemiops.torch.neighbors.neighbor_utils import get_neighbor_list_from_neighbor_matrix
neighbor_list_coo, neighbor_ptr, shifts_coo = get_neighbor_list_from_neighbor_matrix(
neighbor_matrix, num_neighbors, neighbor_matrix_shifts, fill_value=num_atoms
)
# JAX: get COO format directly
neighbor_list_coo, neighbor_ptr, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc, return_neighbor_list=True
)
# Or convert from matrix format
from nvalchemiops.jax.neighbors.neighbor_utils import get_neighbor_list_from_neighbor_matrix
neighbor_list_coo, neighbor_ptr, shifts_coo = get_neighbor_list_from_neighbor_matrix(
neighbor_matrix, num_neighbors, neighbor_matrix_shifts, fill_value=num_atoms
)
Warning
Setting return_neighbor_list=True incurs a conversion overhead. If you need
both formats, compute the matrix format first and convert as needed.
Method Dispatch#
When method=None, neighbor_list selects an algorithm using the following
logic:
If cutoff2 is provided, then the dual cutoff method
If average atoms per system exceeds the threshold, then "cell_list"
Otherwise, "naive" (\(N^2\) scaling algorithm)
If batch_idx or batch_ptr is provided, then prepend "batch_" to the method
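The selection rules listed above can be written as a small decision function. This is a sketch of the documented logic, not the library's dispatch code: `select_method` and its arguments are hypothetical, and the `"dual_cutoff"` label is a placeholder because the actual dual cutoff method strings are not shown here.

```python
def select_method(avg_atoms_per_system, threshold, batched=False, dual_cutoff=False):
    """Mirror the documented selection rules (illustrative only)."""
    if dual_cutoff:
        # Placeholder label: the real method string is an assumption.
        base = "dual_cutoff"
    elif avg_atoms_per_system > threshold:
        base = "cell_list"
    else:
        base = "naive"
    # A batch index or pointer switches to the batched variant.
    return f"batch_{base}" if batched else base


torch_choice = select_method(3000, threshold=2000)
jax_choice = select_method(3000, threshold=5000)
batched_choice = select_method(6000, threshold=5000, batched=True)
```

With the default thresholds quoted earlier, the same 3000-atom system would go to the cell list path in PyTorch and to the naive path in JAX.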
Available Methods#
| Method | Algorithm | Use Case |
|---|---|---|
| "naive" | \(O(N^2)\) pairwise | Small single systems |
| "cell_list" | \(O(N)\) spatial decomposition | Large single systems |
| "batch_naive" | \(O(N^2)\) per system | Batched small systems |
| "batch_cell_list" | \(O(N)\) per system | Batched large systems |
| dual cutoff (single system) | \(O(N^2)\) with two cutoffs | Multi-range potentials |
| dual cutoff (batched) | Batched dual cutoff | Batched multi-range |
Override automatic selection by passing the method parameter:
# Force cell_list on a small system for testing
from nvalchemiops.torch.neighbors import neighbor_list
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc, method="cell_list"
)
# Force cell_list on a small system for testing
from nvalchemiops.jax.neighbors import neighbor_list
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc, method="cell_list"
)
Performance Tuning#
Key Parameters#
max_neighbors
: Maximum neighbors per atom; determines the width of neighbor_matrix.
Auto-estimated if not provided. If you know an accurate value, pass it
explicitly to neighbor_list to reduce memory requirements and improve
kernel performance; otherwise estimate_max_neighbors() supplies a very
conservative estimate based on atomic density.
atomic_density
: Atomic density in atoms per unit volume, used by estimate_max_neighbors().
Default is 0.2. Increase for dense systems to avoid truncated neighbor lists.
safety_factor
: Multiplier applied to the neighbor estimate. Default is 1.0. Provides
headroom for density fluctuations.
max_nbins
: Maximum number of spatial cells for cell list decomposition. Default is 1000.
Limits memory usage for very large simulation boxes.
wrap_positions
: Controls whether positions are wrapped into the primary cell before neighbor
search. Default is True. Set to False when positions are already wrapped
(e.g. after an integration step that keeps coordinates inside the box) to skip
two GPU kernel launches per call.
Only applies to naive methods; cell list methods handle wrapping internally.
shift_range_per_dimension, num_shifts_per_system, max_shifts_per_system
: Optional cached naive-PBC metadata for advanced workflows. Use
compute_naive_num_shifts() to compute these values outside repeated calls,
especially for JAX where max_shifts_per_system must be concrete outside
jax.jit. Older shift_offset and total_shifts inputs are no longer part
of the public Torch/JAX API.
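For wrap_positions, "wrapping" means mapping every coordinate back into the primary cell along the periodic axes. A minimal sketch, assuming an orthorhombic box described only by its edge lengths (a general triclinic cell would need fractional coordinates); `wrap_orthorhombic` is a hypothetical helper, not a library function.

```python
import math


def wrap_orthorhombic(positions, box_lengths, pbc=(True, True, True)):
    """Map each coordinate into [0, L) along periodic axes; leave others as-is."""
    wrapped = []
    for point in positions:
        wrapped.append([
            x - length * math.floor(x / length) if periodic else x
            for x, length, periodic in zip(point, box_lengths, pbc)
        ])
    return wrapped


wrapped = wrap_orthorhombic([[21.5, -0.5, 3.0]], box_lengths=(20.0, 20.0, 20.0))
# wrapped == [[1.5, 19.5, 3.0]]
```

If an integrator already keeps coordinates inside the box, this step is a no-op, which is why skipping it can save work.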
Estimation Utilities#
The estimate_max_neighbors() function estimates
the maximum number of neighbors \(n\) any atom could have from the volume of the
cutoff sphere (radius \(r\)), the atomic density \(\rho\), and an additional safety factor \(S\).
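One plausible reading of this description is \(n \approx S \, \rho \, \tfrac{4}{3} \pi r^3\), that is, the expected number of atoms inside the cutoff sphere scaled by the safety factor. The sketch below implements that reading; it is an assumption, and any extra rounding or padding the library applies is not shown. `sketch_max_neighbors` is a hypothetical name.

```python
import math


def sketch_max_neighbors(cutoff, atomic_density=0.2, safety_factor=1.0):
    """Expected atom count inside the cutoff sphere, scaled and rounded up."""
    sphere_volume = 4.0 / 3.0 * math.pi * cutoff**3
    return math.ceil(safety_factor * atomic_density * sphere_volume)


estimate = sketch_max_neighbors(5.0, atomic_density=0.15)
# estimate == 79  (523.6 cubic length units * 0.15 atoms per unit volume, rounded up)
```

The cubic dependence on the cutoff is the main practical point: raising the cutoff from 5.0 to 6.0 increases the estimate by a factor of about 1.7.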
from nvalchemiops.neighbors.neighbor_utils import estimate_max_neighbors
from nvalchemiops.torch.neighbors import estimate_cell_list_sizes
max_neighbors = estimate_max_neighbors(
cutoff,
atomic_density=0.15,
safety_factor=1.0
)
max_total_cells, neighbor_search_radius = estimate_cell_list_sizes(
cell, pbc, cutoff, max_nbins=1000
)
from nvalchemiops.neighbors.neighbor_utils import estimate_max_neighbors
from nvalchemiops.jax.neighbors import estimate_cell_list_sizes
max_neighbors = estimate_max_neighbors(
cutoff,
atomic_density=0.15,
safety_factor=1.0
)
max_total_cells, neighbor_search_radius, _ = estimate_cell_list_sizes(
positions, cell, cutoff, pbc=pbc, buffer_factor=1.5
)
Note
The JAX estimate_cell_list_sizes takes positions as its first argument
(to infer array sizes) and uses a buffer_factor parameter instead of
max_nbins. It also returns a 3-tuple.
This function is not compatible with jax.jit because it derives
concrete array sizes from traced data.
Setting atomic_density: This should reflect the expected atomic density of
your system in atoms per unit volume (using the same length units as cutoff).
If set too low, the neighbor matrix may be too narrow and a
NeighborOverflowError will be raised at runtime. If set too high, memory is
wasted on unused columns.
Setting safety_factor: This multiplier provides headroom for local density
fluctuations (e.g., atoms clustering in one region). The default of 1.0 is
typically sufficient for systems with reasonably uniform density (e.g. standard
public datasets). Increase it for systems with significant density variation.
Tip
Check the “convergence” of the neighbor list computation by comparing the array holding the number of neighbors per atom against the estimated maximum number of neighbors. For optimal performance the two should be close. If the actual number of neighbors per atom is low relative to the estimate, the allocated neighbor matrix is very sparse and memory inefficient (i.e. most elements are padding). If the actual number exceeds the estimate, neighborhoods will be truncated and there is no guarantee that the nearest neighbors are included.
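A check along these lines can be sketched as follows. It works on plain Python lists (for example `num_neighbors.tolist()`), and `neighbor_matrix_usage` is a hypothetical helper. Whether the reported counts can exceed max_neighbors or are capped at it is not stated here, so the sketch flags both cases with `>=`.

```python
def neighbor_matrix_usage(num_neighbors, max_neighbors):
    """Summarize how well `max_neighbors` matches the observed neighbor counts."""
    observed_max = max(num_neighbors)
    filled = sum(min(n, max_neighbors) for n in num_neighbors)
    return {
        "observed_max": observed_max,
        # Fraction of matrix entries holding real neighbors rather than padding.
        "fill_fraction": filled / (len(num_neighbors) * max_neighbors),
        # Reaching the limit is treated as a sign of possible truncation.
        "at_capacity": observed_max >= max_neighbors,
    }


report = neighbor_matrix_usage([12, 15, 9, 14], max_neighbors=16)
```

A low fill fraction suggests lowering max_neighbors; an at-capacity flag suggests raising it.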
Pre-allocation for Repeated Calculations#
Pre-allocating output arrays avoids repeated memory allocation overhead when computing neighbor lists in a loop (e.g., during MD simulation or training).
Pre-allocation also enables torch.compile compatibility by ensuring fixed
tensor shapes.
import torch
from nvalchemiops.torch.neighbors import neighbor_list
from nvalchemiops.neighbors.neighbor_utils import estimate_max_neighbors
num_atoms = positions.shape[0]
max_neighbors = estimate_max_neighbors(cutoff, atomic_density=0.15)
# Pre-allocate tensors
neighbor_matrix = torch.full(
(num_atoms, max_neighbors), num_atoms, dtype=torch.int32, device="cuda"
)
neighbor_matrix_shifts = torch.zeros(
(num_atoms, max_neighbors, 3), dtype=torch.int32, device="cuda"
)
num_neighbors = torch.zeros(num_atoms, dtype=torch.int32, device="cuda")
# Pass pre-allocated tensors
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc,
neighbor_matrix=neighbor_matrix,
neighbor_matrix_shifts=neighbor_matrix_shifts,
num_neighbors=num_neighbors,
fill_value=num_atoms
)
For cell list methods, you can also pre-allocate the spatial data structures:
from nvalchemiops.torch.neighbors import neighbor_list
from nvalchemiops.torch.neighbors.cell_list import estimate_cell_list_sizes
from nvalchemiops.torch.neighbors.neighbor_utils import allocate_cell_list
max_total_cells, neighbor_search_radius = estimate_cell_list_sizes(cell, pbc, cutoff)
(
cells_per_dimension, neighbor_search_radius,
atom_periodic_shifts, atom_to_cell_mapping,
atoms_per_cell_count, cell_atom_start_indices, cell_atom_list
) = allocate_cell_list(num_atoms, max_total_cells, neighbor_search_radius, device)
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc,
cells_per_dimension=cells_per_dimension,
neighbor_search_radius=neighbor_search_radius,
atom_periodic_shifts=atom_periodic_shifts,
atom_to_cell_mapping=atom_to_cell_mapping,
atoms_per_cell_count=atoms_per_cell_count,
cell_atom_start_indices=cell_atom_start_indices,
cell_atom_list=cell_atom_list
)
Pre-allocating arrays with known shapes enables jax.jit compilation by
ensuring static array dimensions. Note that JAX functions always return
new arrays rather than mutating inputs in-place.
import jax.numpy as jnp
from nvalchemiops.jax.neighbors import neighbor_list
from nvalchemiops.neighbors.neighbor_utils import estimate_max_neighbors
num_atoms = positions.shape[0]
max_neighbors = estimate_max_neighbors(cutoff, atomic_density=0.15)
# Pre-allocate arrays (used as shape hints for XLA)
neighbor_matrix = jnp.full(
(num_atoms, max_neighbors), num_atoms, dtype=jnp.int32
)
neighbor_matrix_shifts = jnp.zeros(
(num_atoms, max_neighbors, 3), dtype=jnp.int32
)
num_neighbors = jnp.zeros(num_atoms, dtype=jnp.int32)
# Pass pre-allocated arrays
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc,
neighbor_matrix=neighbor_matrix,
neighbor_matrix_shifts=neighbor_matrix_shifts,
num_neighbors=num_neighbors,
fill_value=num_atoms
)
For cell list methods, you can also pre-allocate the spatial data structures:
from nvalchemiops.jax.neighbors import estimate_cell_list_sizes
from nvalchemiops.jax.neighbors.neighbor_utils import allocate_cell_list
max_total_cells, neighbor_search_radius, _ = estimate_cell_list_sizes(
positions, cell, cutoff, pbc=pbc
)
(
cells_per_dimension, neighbor_search_radius,
atom_periodic_shifts, atom_to_cell_mapping,
atoms_per_cell_count, cell_atom_start_indices, cell_atom_list
) = allocate_cell_list(num_atoms, max_total_cells, neighbor_search_radius)
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc,
cells_per_dimension=cells_per_dimension,
neighbor_search_radius=neighbor_search_radius,
atom_periodic_shifts=atom_periodic_shifts,
atom_to_cell_mapping=atom_to_cell_mapping,
atoms_per_cell_count=atoms_per_cell_count,
cell_atom_start_indices=cell_atom_start_indices,
cell_atom_list=cell_atom_list
)
Warning
If max_neighbors is too small, neighbors beyond that limit are silently
dropped. Monitor num_neighbors.max() (PyTorch) or jnp.max(num_neighbors)
(JAX) against your max_neighbors setting to detect truncation.
Usage Patterns#
Basic Single System#
import torch
from nvalchemiops.torch.neighbors import neighbor_list
# Create atomic system
positions = torch.rand(1000, 3, device="cuda") * 20.0
cell = torch.eye(3, device="cuda").unsqueeze(0) * 20.0
pbc = torch.tensor([True, True, True], device="cuda")
cutoff = 5.0
# Compute neighbors (automatic method selection)
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc
)
print(f"Average neighbors: {num_neighbors.float().mean():.1f}")
import jax
import jax.numpy as jnp
from nvalchemiops.jax.neighbors import neighbor_list
# Create atomic system
key = jax.random.PRNGKey(0)
positions = jax.random.uniform(key, (1000, 3), dtype=jnp.float32) * 20.0
cell = jnp.eye(3, dtype=jnp.float32)[None, ...] * 20.0
pbc = jnp.array([[True, True, True]])
cutoff = 5.0
# Compute neighbors (automatic method selection)
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc
)
print(f"Average neighbors: {jnp.mean(num_neighbors.astype(jnp.float32)):.1f}")
Batch Processing#
import torch
from nvalchemiops.torch.neighbors import neighbor_list
# Three systems of different sizes
positions = torch.cat([
torch.rand(100, 3, device="cuda"), # System 0
torch.rand(150, 3, device="cuda"), # System 1
torch.rand(80, 3, device="cuda"), # System 2
])
batch_idx = torch.cat([
torch.zeros(100, dtype=torch.int32, device="cuda"),
torch.ones(150, dtype=torch.int32, device="cuda"),
torch.full((80,), 2, dtype=torch.int32, device="cuda"),
])
cells = torch.stack([
torch.eye(3, device="cuda") * 10.0,
torch.eye(3, device="cuda") * 12.0,
torch.eye(3, device="cuda") * 8.0,
])
pbc = torch.tensor([
[True, True, True],
[True, True, False],
[False, False, False],
], device="cuda")
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff=5.0, cell=cells, pbc=pbc, batch_idx=batch_idx
)
import jax
import jax.numpy as jnp
from nvalchemiops.jax.neighbors import neighbor_list
# Three systems of different sizes
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
positions = jnp.concatenate([
jax.random.uniform(k1, (100, 3), dtype=jnp.float32), # System 0
jax.random.uniform(k2, (150, 3), dtype=jnp.float32), # System 1
jax.random.uniform(k3, (80, 3), dtype=jnp.float32), # System 2
])
batch_idx = jnp.concatenate([
jnp.zeros(100, dtype=jnp.int32),
jnp.ones(150, dtype=jnp.int32),
jnp.full((80,), 2, dtype=jnp.int32),
])
batch_ptr = jnp.array([0, 100, 250, 330], dtype=jnp.int32)
cells = jnp.stack([
jnp.eye(3, dtype=jnp.float32) * 10.0,
jnp.eye(3, dtype=jnp.float32) * 12.0,
jnp.eye(3, dtype=jnp.float32) * 8.0,
])
pbc = jnp.array([
[True, True, True],
[True, True, False],
[False, False, False],
])
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff=5.0, cell=cells, pbc=pbc,
batch_idx=batch_idx, batch_ptr=batch_ptr
)
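The batch_idx and batch_ptr arrays used above carry the same information in two layouts. A small sketch of how both follow from the per-system atom counts, using plain lists and hypothetical helper names:

```python
from itertools import accumulate


def batch_ptr_from_sizes(sizes):
    """Cumulative atom counts with a leading zero (one entry more than systems)."""
    return [0, *accumulate(sizes)]


def batch_idx_from_sizes(sizes):
    """System index for every atom, in concatenation order."""
    return [system for system, n in enumerate(sizes) for _ in range(n)]


sizes = [100, 150, 80]
batch_ptr = batch_ptr_from_sizes(sizes)
batch_idx = batch_idx_from_sizes(sizes)
# batch_ptr == [0, 100, 250, 330], matching the JAX example above
```

Atoms of system s occupy the slice `batch_ptr[s]:batch_ptr[s + 1]` of the concatenated positions array.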
Half-Fill Mode#
Store only half of neighbor pairs to avoid double-counting in symmetric calculations:
# Full: stores both (i,j) and (j,i)
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc, half_fill=False
)
# Half: stores only (i,j) where i < j (or with non-zero periodic shift)
neighbor_matrix_half, num_neighbors_half, shifts_half = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc, half_fill=True
)
# half_fill=True produces ~50% of the pairs
Note
The half_fill parameter is currently supported only by the naive and
batch_naive methods in JAX. The cell_list and batch_cell_list methods
silently ignore this parameter and always produce full neighbor lists.
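The double-counting point can be illustrated with a toy pair sum. This sketch is non-periodic, so the half list is simply the pairs with i < j; periodic images are left out, and all function names are hypothetical.

```python
import itertools
import math


def full_pairs(positions, cutoff):
    """Both (i, j) and (j, i) for every pair within the cutoff."""
    n = len(positions)
    return [
        (i, j)
        for i, j in itertools.permutations(range(n), 2)
        if math.dist(positions[i], positions[j]) < cutoff
    ]


def half_pairs(positions, cutoff):
    """Only (i, j) with i < j."""
    return [(i, j) for i, j in full_pairs(positions, cutoff) if i < j]


def pair_energy(positions, pairs, weight):
    """Sum a toy symmetric pair term 1/r, scaled by `weight`."""
    return weight * sum(1.0 / math.dist(positions[i], positions[j]) for i, j in pairs)


positions = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [9.0, 9.0, 9.0]]
full = full_pairs(positions, 3.0)
half = half_pairs(positions, 3.0)
# The full list needs a factor 1/2 to undo double counting; the half list does not.
e_full = pair_energy(positions, full, weight=0.5)
e_half = pair_energy(positions, half, weight=1.0)
```

Both sums give the same energy, but the half list visits half as many pairs.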
Build/Query Separation for MD Workflows#
For molecular dynamics, separate building and querying allows caching the spatial data structure:
import torch
from nvalchemiops.torch.neighbors.cell_list import (
build_cell_list, query_cell_list, estimate_cell_list_sizes
)
from nvalchemiops.torch.neighbors.neighbor_utils import (
allocate_cell_list, estimate_max_neighbors
)
# Setup (once)
max_total_cells, neighbor_search_radius = estimate_cell_list_sizes(cell, pbc, cutoff)
cell_list_cache = allocate_cell_list(num_atoms, max_total_cells, neighbor_search_radius, device)
max_neighbors = estimate_max_neighbors(cutoff)
neighbor_matrix = torch.full((num_atoms, max_neighbors), -1, dtype=torch.int32, device=device)
neighbor_shifts = torch.zeros((num_atoms, max_neighbors, 3), dtype=torch.int32, device=device)
num_neighbors = torch.zeros(num_atoms, dtype=torch.int32, device=device)
# MD loop
for step in range(num_steps):
# Build cell list (expensive, done when atoms change cells)
build_cell_list(positions, cutoff, cell, pbc, *cell_list_cache)
# Query neighbors (cheaper)
neighbor_matrix.fill_(-1)
neighbor_shifts.zero_()
num_neighbors.zero_()
query_cell_list(
positions, cutoff, cell, pbc, *cell_list_cache,
neighbor_matrix, neighbor_shifts, num_neighbors
)
forces = compute_forces(positions, neighbor_matrix, num_neighbors, ...)
positions = integrate(positions, forces, dt)
from nvalchemiops.jax.neighbors import (
build_cell_list, query_cell_list, estimate_cell_list_sizes
)
from nvalchemiops.jax.neighbors.neighbor_utils import allocate_cell_list
from nvalchemiops.neighbors.neighbor_utils import estimate_max_neighbors
# Setup (once, outside jit)
max_total_cells, neighbor_search_radius, _ = estimate_cell_list_sizes(
positions, cell, cutoff, pbc=pbc
)
cell_list_cache = allocate_cell_list(num_atoms, max_total_cells, neighbor_search_radius)
max_neighbors = estimate_max_neighbors(cutoff)
# MD loop (JAX returns new arrays each step; no in-place mutation)
for step in range(num_steps):
# Build cell list (expensive, done when atoms change cells)
cell_list_cache = build_cell_list(
positions, cutoff, cell, pbc, *cell_list_cache
)
# Query neighbors (cheaper)
(
cells_per_dimension, neighbor_search_radius,
atom_periodic_shifts, atom_to_cell_mapping,
atoms_per_cell_count, cell_atom_start_indices, cell_atom_list
) = cell_list_cache
neighbor_matrix, num_neighbors, neighbor_shifts = query_cell_list(
positions, cutoff, cell, pbc,
cells_per_dimension, atom_periodic_shifts, atom_to_cell_mapping,
atoms_per_cell_count, cell_atom_start_indices, cell_atom_list,
neighbor_search_radius, max_neighbors=max_neighbors
)
forces = compute_forces(positions, neighbor_matrix, num_neighbors, ...)
positions = integrate(positions, forces, dt)
Note
JAX follows a functional paradigm: build_cell_list and query_cell_list
return new arrays rather than mutating buffers in-place. Reassign the
returned values each step.
Rebuild Detection with Skin Distance#
Avoid rebuilding neighbor lists every step by using a skin distance:
from nvalchemiops.torch.neighbors.cell_list import (
build_cell_list, query_cell_list, estimate_cell_list_sizes
)
from nvalchemiops.torch.neighbors.neighbor_utils import allocate_cell_list
from nvalchemiops.torch.neighbors.rebuild_detection import cell_list_needs_rebuild
cutoff = 5.0
skin_distance = 1.0
effective_cutoff = cutoff + skin_distance
# Build with effective cutoff (includes skin)
max_total_cells, neighbor_search_radius = estimate_cell_list_sizes(
cell, pbc, effective_cutoff
)
cell_list_cache = allocate_cell_list(num_atoms, max_total_cells, neighbor_search_radius, device)
(
cells_per_dimension, neighbor_search_radius,
atom_periodic_shifts, atom_to_cell_mapping,
atoms_per_cell_count, cell_atom_start_indices, cell_atom_list
) = cell_list_cache
build_cell_list(positions, effective_cutoff, cell, pbc, *cell_list_cache)
for step in range(num_steps):
positions = integrate(positions, forces, dt)
# Check if any atom moved to a different cell
needs_rebuild = cell_list_needs_rebuild(
positions, atom_to_cell_mapping, cells_per_dimension, cell, pbc
)
if needs_rebuild.item():
build_cell_list(positions, effective_cutoff, cell, pbc, *cell_list_cache)
# Query with actual cutoff (not effective)
query_cell_list(positions, cutoff, cell, pbc, *cell_list_cache, ...)
from nvalchemiops.jax.neighbors import (
build_cell_list, query_cell_list, estimate_cell_list_sizes
)
from nvalchemiops.jax.neighbors.neighbor_utils import allocate_cell_list
from nvalchemiops.jax.neighbors.rebuild_detection import cell_list_needs_rebuild
cutoff = 5.0
skin_distance = 1.0
effective_cutoff = cutoff + skin_distance
# Build with effective cutoff (includes skin)
max_total_cells, neighbor_search_radius, _ = estimate_cell_list_sizes(
positions, cell, effective_cutoff, pbc=pbc
)
cell_list_cache = allocate_cell_list(num_atoms, max_total_cells, neighbor_search_radius)
(
cells_per_dimension, neighbor_search_radius,
atom_periodic_shifts, atom_to_cell_mapping,
atoms_per_cell_count, cell_atom_start_indices, cell_atom_list
) = cell_list_cache
cell_list_cache = build_cell_list(
positions, effective_cutoff, cell, pbc, *cell_list_cache
)
(
cells_per_dimension, neighbor_search_radius,
atom_periodic_shifts, atom_to_cell_mapping,
atoms_per_cell_count, cell_atom_start_indices, cell_atom_list
) = cell_list_cache
for step in range(num_steps):
positions = integrate(positions, forces, dt)
# Check if any atom moved to a different cell
needs_rebuild = cell_list_needs_rebuild(
positions, atom_to_cell_mapping, cells_per_dimension, cell, pbc
)
if needs_rebuild.item():
cell_list_cache = build_cell_list(
positions, effective_cutoff, cell, pbc, *cell_list_cache
)
(
cells_per_dimension, neighbor_search_radius,
atom_periodic_shifts, atom_to_cell_mapping,
atoms_per_cell_count, cell_atom_start_indices, cell_atom_list
) = cell_list_cache
# Query with actual cutoff (not effective)
neighbor_matrix, num_neighbors, neighbor_shifts = query_cell_list(
positions, cutoff, cell, pbc,
cells_per_dimension, atom_periodic_shifts, atom_to_cell_mapping,
atoms_per_cell_count, cell_atom_start_indices, cell_atom_list,
neighbor_search_radius
)
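The comment "Check if any atom moved to a different cell" suggests a test like the one sketched below. This is an illustration of that idea for an orthorhombic box with plain lists, not the implementation of cell_list_needs_rebuild; `cell_indices` and `needs_rebuild` are hypothetical names.

```python
import math


def cell_indices(positions, box_lengths, cells_per_dimension):
    """Integer cell coordinates of each atom in an orthorhombic periodic box."""
    return [
        tuple(
            int(math.floor(x / length * n)) % n
            for x, length, n in zip(point, box_lengths, cells_per_dimension)
        )
        for point in positions
    ]


def needs_rebuild(positions, cached_cells, box_lengths, cells_per_dimension):
    """True if any atom now sits in a different cell than when the list was built."""
    return cell_indices(positions, box_lengths, cells_per_dimension) != cached_cells


box, grid = (20.0, 20.0, 20.0), (4, 4, 4)
start = [[1.0, 1.0, 1.0], [12.0, 3.0, 7.0]]
cached = cell_indices(start, box, grid)

small_move = [[1.2, 1.1, 0.9], [12.1, 3.0, 7.0]]  # every atom stays in its cell
big_move = [[5.5, 1.0, 1.0], [12.0, 3.0, 7.0]]    # atom 0 crosses into the next cell
```

Building with the larger effective cutoff gives atoms room to move before a rebuild becomes necessary, while queries still use the physical cutoff.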
Dual Cutoff#
Compute two neighbor lists with different cutoffs simultaneously:
from nvalchemiops.torch.neighbors import neighbor_list
cutoff1, cutoff2 = 3.0, 6.0
(
neighbor_matrix1, num_neighbors1, shifts1,
neighbor_matrix2, num_neighbors2, shifts2
) = neighbor_list(
positions, cutoff1, cutoff2=cutoff2, cell=cell, pbc=pbc
)
# neighbor_matrix1: neighbors within cutoff1
# neighbor_matrix2: neighbors within cutoff2 (superset of cutoff1)
from nvalchemiops.jax.neighbors import neighbor_list
cutoff1, cutoff2 = 3.0, 6.0
(
neighbor_matrix1, num_neighbors1, shifts1,
neighbor_matrix2, num_neighbors2, shifts2
) = neighbor_list(
positions, cutoff1, cutoff2=cutoff2, cell=cell, pbc=pbc
)
# neighbor_matrix1: neighbors within cutoff1
# neighbor_matrix2: neighbors within cutoff2 (superset of cutoff1)
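The benefit of a combined dual cutoff pass is that each pair distance is computed once and reused for both lists. A minimal non-periodic sketch of that idea, assuming cutoff1 <= cutoff2 and using the hypothetical name `dual_cutoff_pairs`:

```python
import itertools
import math


def dual_cutoff_pairs(positions, cutoff1, cutoff2):
    """Compute each pair distance once and fill both lists (cutoff1 <= cutoff2)."""
    inner, outer = [], []
    for i, j in itertools.permutations(range(len(positions)), 2):
        r = math.dist(positions[i], positions[j])
        if r < cutoff2:
            outer.append((i, j))
            # Anything inside the short cutoff is also inside the long one.
            if r < cutoff1:
                inner.append((i, j))
    return inner, outer


positions = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 20.0]]
inner, outer = dual_cutoff_pairs(positions, 3.0, 6.0)
```

Here the short-range list holds only the pair at distance 2.0 (in both directions), while the long-range list also includes the pairs at distances 5.0 and about 5.4.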
This concludes the high-level documentation for neighbor lists. You should now
be able to integrate nvalchemiops routines for your neighbor list requirements;
consult the API references for PyTorch, JAX, and Warp for further details.