JAX Neighbor List Example
This example demonstrates how to use the JAX neighbor list API in nvalchemiops for computing neighbor lists in periodic systems.
In this example you will learn:
- How to use the unified neighbor_list() API with JAX arrays
- Matrix format vs COO (list) format outputs
- Comparing naive_neighbor_list and cell_list algorithms
- Using half_fill mode for symmetric neighbor lists
- Validating neighbor distances are within cutoff
- jax.jit compilation of the neighbor matrix
Important
This example is for educational purposes. Do not use it for performance benchmarking, as the code includes print statements and small system sizes that are not representative of production workloads.
import sys

try:
    import jax
    import jax.numpy as jnp
except ImportError:
    print(
        "This example requires JAX. Install with: pip install 'nvalchemi-toolkit-ops[jax]'"
    )
    sys.exit(0)

try:
    from nvalchemiops.jax.neighbors import neighbor_list
except Exception as exc:
    print(
        f"JAX/Warp backend unavailable ({exc}). This example requires a CUDA-backed runtime."
    )
    sys.exit(0)

from nvalchemiops.jax.neighbors.cell_list import cell_list
from nvalchemiops.jax.neighbors.naive import naive_neighbor_list
Setup
JAX handles device placement automatically. We’ll create a random periodic system to demonstrate the neighbor list API.
print("=" * 70)
print("JAX NEIGHBOR LIST EXAMPLE")
print("=" * 70)
# System parameters
num_atoms = 200
box_size = 15.0
cutoff = 5.0
# Create random atomic positions using JAX random
key = jax.random.PRNGKey(42)
positions = jax.random.uniform(key, (num_atoms, 3), dtype=jnp.float32) * box_size
# Create a cubic periodic cell: (1, 3, 3) shape
cell = jnp.eye(3, dtype=jnp.float32)[None, ...] * box_size
# Enable periodic boundary conditions in all directions: (1, 3) shape
pbc = jnp.array([[True, True, True]])
print("\nSystem configuration:")
print(f" Number of atoms: {num_atoms}")
print(f" Box size: {box_size} Å")
print(f" Cutoff distance: {cutoff} Å")
print(f" Positions shape: {positions.shape}")
print(f" Cell shape: {cell.shape}")
print(f" PBC shape: {pbc.shape}")
======================================================================
JAX NEIGHBOR LIST EXAMPLE
======================================================================
System configuration:
Number of atoms: 200
Box size: 15.0 Å
Cutoff distance: 5.0 Å
Positions shape: (200, 3)
Cell shape: (1, 3, 3)
PBC shape: (1, 3)
Unified API - Matrix Format (default)
The neighbor_list() function automatically selects the best algorithm
based on system size. For small systems (< 5000 atoms), it uses the naive
O(N²) algorithm. For larger systems, it uses the cell list O(N) algorithm.
print("\n" + "=" * 70)
print("UNIFIED API - MATRIX FORMAT")
print("=" * 70)
# Call the unified API (returns matrix format by default)
neighbor_matrix, num_neighbors, shifts = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc
)
print("\nReturned neighbor matrix format:")
print(f" neighbor_matrix shape: {neighbor_matrix.shape}")
print(f" num_neighbors shape: {num_neighbors.shape}")
print(f" shifts shape: {shifts.shape}")
print("\nStatistics:")
print(f" Total neighbor pairs: {int(num_neighbors.sum())}")
print(f" Average neighbors per atom: {float(num_neighbors.mean()):.2f}")
print(f" Max neighbors for any atom: {int(num_neighbors.max())}")
print(f" Min neighbors for any atom: {int(num_neighbors.min())}")
# Show first few neighbors of atom 0
print("\nFirst 5 neighbors of atom 0:")
for i in range(min(5, int(num_neighbors[0]))):
    neighbor_idx = int(neighbor_matrix[0, i])
    shift = shifts[0, i].tolist()
    print(f" Neighbor {i}: atom {neighbor_idx}, shift {shift}")
======================================================================
UNIFIED API - MATRIX FORMAT
======================================================================
Returned neighbor matrix format:
neighbor_matrix shape: (200, 112)
num_neighbors shape: (200,)
shifts shape: (200, 112, 3)
Statistics:
Total neighbor pairs: 6166
Average neighbors per atom: 30.83
Max neighbors for any atom: 43
Min neighbors for any atom: 13
First 5 neighbors of atom 0:
Neighbor 0: atom 145, shift [0, 0, 0]
Neighbor 1: atom 148, shift [0, 0, 0]
Neighbor 2: atom 151, shift [0, 0, 0]
Neighbor 3: atom 155, shift [0, 0, 0]
Neighbor 4: atom 156, shift [0, 0, 0]
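For intuition about what an O(N²) search does, the sketch below counts neighbors by brute force in pure JAX. It is not the nvalchemiops implementation: the function name `brute_force_neighbor_counts` is hypothetical, the box is assumed cubic, the strict `<` comparison is an assumption, and the minimum-image wrap is only valid when the cutoff is at most half the box length (as with the 5.0 Å cutoff and 15.0 Å box used here).

```python
import jax
import jax.numpy as jnp


def brute_force_neighbor_counts(positions, box_size, cutoff):
    """Count neighbors per atom in a cubic periodic box (hypothetical helper)."""
    # Pairwise displacement vectors r_j - r_i, shape (N, N, 3)
    diff = positions[None, :, :] - positions[:, None, :]
    # Minimum-image convention: wrap each component to the nearest periodic image
    diff = diff - box_size * jnp.round(diff / box_size)
    dist = jnp.linalg.norm(diff, axis=-1)
    # Exclude self-pairs on the diagonal
    not_self = ~jnp.eye(positions.shape[0], dtype=bool)
    return jnp.sum((dist < cutoff) & not_self, axis=1)


key = jax.random.PRNGKey(0)
pts = jax.random.uniform(key, (50, 3)) * 15.0
counts = brute_force_neighbor_counts(pts, box_size=15.0, cutoff=5.0)
print(counts.shape, int(counts.sum()))
```

Because every pair is visited, memory and time grow quadratically with the number of atoms, which is why a cell-based search becomes preferable for larger systems.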
Unified API - COO Format
The COO (coordinate) format is often preferred for graph neural networks.
Set return_neighbor_list=True to get this format.
print("\n" + "=" * 70)
print("UNIFIED API - COO FORMAT")
print("=" * 70)
# Get neighbor list in COO format
neighbor_list_coo, neighbor_ptr, shifts_coo = neighbor_list(
positions, cutoff, cell=cell, pbc=pbc, return_neighbor_list=True
)
print("\nReturned COO format:")
print(f" neighbor_list shape: {neighbor_list_coo.shape} (2 x num_pairs)")
print(f" neighbor_ptr shape: {neighbor_ptr.shape} (CSR pointers)")
print(f" shifts shape: {shifts_coo.shape}")
source_atoms = neighbor_list_coo[0]
target_atoms = neighbor_list_coo[1]
print("\nStatistics:")
print(f" Total pairs: {neighbor_list_coo.shape[1]}")
print(f" Source atoms range: [{int(source_atoms.min())}, {int(source_atoms.max())}]")
print(f" Target atoms range: [{int(target_atoms.min())}, {int(target_atoms.max())}]")
# Show first few pairs
print("\nFirst 5 neighbor pairs:")
for i in range(min(5, neighbor_list_coo.shape[1])):
    src = int(source_atoms[i])
    tgt = int(target_atoms[i])
    shift = shifts_coo[i].tolist()
    print(f" Pair {i}: atom {src} -> atom {tgt}, shift {shift}")
======================================================================
UNIFIED API - COO FORMAT
======================================================================
Returned COO format:
neighbor_list shape: (2, 6166) (2 x num_pairs)
neighbor_ptr shape: (201,) (CSR pointers)
shifts shape: (6166, 3)
Statistics:
Total pairs: 6166
Source atoms range: [0, 199]
Target atoms range: [0, 199]
First 5 neighbor pairs:
Pair 0: atom 0 -> atom 145, shift [0, 0, 0]
Pair 1: atom 0 -> atom 148, shift [0, 0, 0]
Pair 2: atom 0 -> atom 151, shift [0, 0, 0]
Pair 3: atom 0 -> atom 155, shift [0, 0, 0]
Pair 4: atom 0 -> atom 156, shift [0, 0, 0]
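The two output formats carry the same information, and the relationship can be sketched with a small conversion routine. This is an illustration only: `matrix_to_coo` is a hypothetical helper, not part of nvalchemiops, and the `-1` padding value in the toy matrix is an assumption (the conversion itself never reads padded slots).

```python
import jax.numpy as jnp


def matrix_to_coo(neighbor_matrix, num_neighbors):
    """Convert a padded (N, max_neighbors) matrix to COO pairs and CSR pointers."""
    max_neighbors = neighbor_matrix.shape[1]
    # A slot is valid when its column index is below that row's neighbor count
    valid = jnp.arange(max_neighbors)[None, :] < num_neighbors[:, None]
    # Output length depends on the data, so this step is not jit-compatible as written
    rows, cols = jnp.nonzero(valid)
    pairs = jnp.stack([rows, neighbor_matrix[rows, cols]])
    # CSR pointers: pairs of atom i occupy pairs[:, ptr[i]:ptr[i + 1]]
    ptr = jnp.concatenate(
        [jnp.zeros(1, dtype=num_neighbors.dtype), jnp.cumsum(num_neighbors)]
    )
    return pairs, ptr


# Toy system: atom 0 neighbors atoms 1 and 2; atoms 1 and 2 each neighbor atom 0
toy_matrix = jnp.array([[1, 2, -1], [0, -1, -1], [0, -1, -1]], dtype=jnp.int32)
toy_counts = jnp.array([2, 1, 1], dtype=jnp.int32)
pairs, ptr = matrix_to_coo(toy_matrix, toy_counts)
print(pairs)  # [[0 0 1 2], [1 2 0 0]]
print(ptr)    # [0 2 3 4]
```

The pointer array has N + 1 entries, which matches the (201,) shape reported for the 200-atom system.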
Algorithm Comparison
The nvalchemiops library provides two main algorithms:
- naive_neighbor_list: O(N²) - best for small systems
- cell_list: O(N) - best for large systems
Both should produce identical results.
print("\n" + "=" * 70)
print("ALGORITHM COMPARISON")
print("=" * 70)
# Direct call to naive algorithm
nm_naive, num_naive, shifts_naive = naive_neighbor_list(
positions, cutoff, cell=cell, pbc=pbc
)
# Direct call to cell list algorithm
nm_cell, num_cell, shifts_cell = cell_list(positions, cutoff, cell=cell, pbc=pbc)
print("\nNaive algorithm (O(N²)):")
print(f" Total pairs: {int(num_naive.sum())}")
print(f" Average neighbors: {float(num_naive.mean()):.2f}")
print("\nCell list algorithm (O(N)):")
print(f" Total pairs: {int(num_cell.sum())}")
print(f" Average neighbors: {float(num_cell.mean()):.2f}")
# Verify they find the same number of pairs per atom
pairs_match = jnp.allclose(num_naive, num_cell)
print(f"\nResults match: {pairs_match}")
======================================================================
ALGORITHM COMPARISON
======================================================================
Naive algorithm (O(N²)):
Total pairs: 6166
Average neighbors: 30.83
Cell list algorithm (O(N)):
Total pairs: 6166
Average neighbors: 30.83
Results match: True
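The check above compares only the per-atom neighbor counts. A stricter comparison would also confirm that the same neighbor indices appear, regardless of the order in which each algorithm emits them. The sketch below shows one way to do that on toy data; `canonical_rows` is a hypothetical helper, both matrices are assumed to have the same width, and periodic shifts are ignored for brevity (a full check would compare them as well).

```python
import jax.numpy as jnp


def canonical_rows(neighbor_matrix, num_neighbors, fill):
    """Overwrite padded slots with `fill`, then sort each row."""
    max_neighbors = neighbor_matrix.shape[1]
    valid = jnp.arange(max_neighbors)[None, :] < num_neighbors[:, None]
    return jnp.sort(jnp.where(valid, neighbor_matrix, fill), axis=1)


# Two toy matrices listing the same neighbors in a different order,
# with different garbage in the padded slots
a = jnp.array([[2, 1, 0], [0, 7, 7]])
counts_a = jnp.array([2, 1])
b = jnp.array([[1, 2, 9], [0, 9, 9]])
counts_b = jnp.array([2, 1])

same = jnp.array_equal(
    canonical_rows(a, counts_a, fill=99), canonical_rows(b, counts_b, fill=99)
)
print(bool(same))  # True
```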
Distance Validation
Let’s verify that all neighbor pairs are actually within the cutoff distance.
print("\n" + "=" * 70)
print("DISTANCE VALIDATION")
print("=" * 70)
# Get neighbor list in COO format for easy distance computation
nlist, nptr, nshifts = naive_neighbor_list(
positions, cutoff, cell=cell, pbc=pbc, return_neighbor_list=True
)
if nlist.shape[1] > 0:
    # Extract source and target positions
    src_idx = nlist[0]
    tgt_idx = nlist[1]
    pos_src = positions[src_idx]
    pos_tgt = positions[tgt_idx]
    # Compute Cartesian shift from lattice shift
    # shifts are in lattice coordinates, multiply by cell vectors
    cell_squeezed = cell.squeeze(0)  # (3, 3)
    cartesian_shifts = jnp.einsum(
        "ij,jk->ik", nshifts.astype(jnp.float32), cell_squeezed
    )
    # Compute distances: r_j - r_i + shift
    diff = pos_tgt - pos_src + cartesian_shifts
    distances = jnp.linalg.norm(diff, axis=1)
    print(f"\nComputed distances for {len(distances)} neighbor pairs:")
    print(f" Min distance: {float(distances.min()):.4f} Å")
    print(f" Max distance: {float(distances.max()):.4f} Å")
    print(f" Mean distance: {float(distances.mean()):.4f} Å")
    print(f" Cutoff: {cutoff} Å")
    # Check if all distances are within cutoff (with small tolerance)
    within_cutoff = jnp.all(distances <= cutoff + 1e-5)
    print(f"\n All distances within cutoff: {within_cutoff}")
    # Show distribution of first 10 distances
    print("\nFirst 10 neighbor distances:")
    for i in range(min(10, len(distances))):
        src = int(src_idx[i])
        tgt = int(tgt_idx[i])
        dist = float(distances[i])
        print(f" Atom {src} -> {tgt}: {dist:.4f} Å")
else:
    print("\nNo neighbor pairs found (empty system or cutoff too small)")
======================================================================
DISTANCE VALIDATION
======================================================================
Computed distances for 6166 neighbor pairs:
Min distance: 0.4676 Å
Max distance: 4.9992 Å
Mean distance: 3.7425 Å
Cutoff: 5.0 Å
All distances within cutoff: True
First 10 neighbor distances:
Atom 0 -> 145: 2.7335 Å
Atom 0 -> 148: 3.8481 Å
Atom 0 -> 151: 3.7375 Å
Atom 0 -> 155: 4.6850 Å
Atom 0 -> 156: 4.9675 Å
Atom 0 -> 105: 3.8244 Å
Atom 0 -> 127: 4.6600 Å
Atom 0 -> 163: 4.1604 Å
Atom 0 -> 167: 3.6606 Å
Atom 0 -> 173: 2.9997 Å
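All pairs printed above happen to have a zero shift, so it may help to see what a nonzero lattice shift does. The hand-built example below follows the same formula as the validation code (r_j - r_i + shift · cell, with the rows of the cell matrix as lattice vectors). The two positions and the shift value are made up for illustration and are not taken from the system above.

```python
import jax.numpy as jnp

box = 15.0
cell_matrix = jnp.eye(3) * box  # rows are the lattice vectors of a cubic cell

# Two atoms near opposite faces of the box along x
r_i = jnp.array([0.5, 1.0, 1.0])
r_j = jnp.array([14.5, 1.0, 1.0])
# Integer lattice shift selecting the image of atom j one cell to the left
shift = jnp.array([-1, 0, 0])

direct = jnp.linalg.norm(r_j - r_i)
imaged = jnp.linalg.norm(r_j - r_i + shift.astype(jnp.float32) @ cell_matrix)
print(float(direct), float(imaged))  # 14.0 1.0
```

Without the shift the pair looks 14 Å apart and would fail the cutoff check; with it, the separation to the nearest periodic image is 1 Å.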
JIT compilation
This section shows how to wrap the neighbor computation in a jax.jit-compiled function.
print("\n" + "=" * 70)
print("JIT compilation example")
print("=" * 70)
@jax.jit
def run_compute_loop(
    positions,
    cell,
    pbc,
    max_neighbors: int = 128,
    max_total_cells: int = 16,
    cutoff: float = 6.0,
    max_num_atoms: int = 200,
) -> jax.Array:
    """Example of encapsulating a compute loop."""
    num_loops = 100
    # neighbor indices are integers, so store them in an integer array
    all_neighbors = jnp.zeros(
        (num_loops, max_num_atoms, max_neighbors), dtype=jnp.int32
    )
    # generate some random positions
    key = jax.random.PRNGKey(64)
    for i in range(num_loops):
        # derive a distinct key per iteration so each loop gets a different perturbation
        step_key = jax.random.fold_in(key, i)
        new_positions = (
            jax.random.normal(step_key, (max_num_atoms, 3), dtype=positions.dtype)
            + positions
        )
        # for JIT compilation, max_neighbors and max_total_cells **must** be specified
        # to accommodate static array shapes
        neighbor_matrix, num_neighbors, neighbor_matrix_shifts = cell_list(
            new_positions,
            cutoff,
            cell * 1.5,
            pbc,
            max_neighbors=max_neighbors,
            max_total_cells=max_total_cells,
        )
        # in this example we don't do any additional computation
        # other than neighborhoods; include your computation logic
        # within this scope
        all_neighbors = all_neighbors.at[i].set(neighbor_matrix)
    return all_neighbors
# run the compute loop N times
num_loops = 100
print(f"\nRun neighbor computation loop {num_loops} times.")
all_neighbors = run_compute_loop(positions, cell, pbc)
print(f"Returned neighbor matrix shape: {all_neighbors.shape}")
======================================================================
JIT compilation example
======================================================================
Run neighbor computation loop 100 times.
Returned neighbor matrix shape: (100, 200, 128)
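The requirement to fix max_neighbors up front comes from jax.jit itself: every array shape must be known at trace time, while the true number of neighbors depends on the data. The sketch below illustrates the idea with a non-periodic toy search in pure JAX. It is not how nvalchemiops is implemented; `padded_neighbors` is a hypothetical function, and padding with the atom count `n` is an assumed convention.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("max_neighbors",))
def padded_neighbors(positions, cutoff, max_neighbors):
    """Fixed-capacity neighbor matrix for a small non-periodic system."""
    n = positions.shape[0]
    diff = positions[:, None, :] - positions[None, :, :]
    dist = jnp.linalg.norm(diff, axis=-1)
    mask = (dist < cutoff) & ~jnp.eye(n, dtype=bool)

    def row_indices(row):
        # size= fixes the output length; unused slots are filled with n
        return jnp.nonzero(row, size=max_neighbors, fill_value=n)[0]

    return jax.vmap(row_indices)(mask), mask.sum(axis=1)


# Four atoms on a line; the last one is isolated
line = jnp.array([[0.0, 0, 0], [1.0, 0, 0], [2.0, 0, 0], [10.0, 0, 0]])
neighbors, counts = padded_neighbors(line, 1.5, max_neighbors=3)
print(neighbors)  # [[1 4 4], [0 2 4], [1 4 4], [4 4 4]]
print(counts)     # [1 2 1 0]
```

If an atom had more than max_neighbors neighbors, the extra entries would be silently dropped in this sketch, so comparing the counts against the capacity is a sensible safeguard.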
Summary
This example demonstrated the JAX neighbor list API in nvalchemiops:
- Unified API: neighbor_list() automatically selects the best algorithm
- Matrix format: Dense (N, max_neighbors) format for neighbor indices
- COO format: Sparse (2, num_pairs) format for graph neural networks
- Algorithm choice: O(N²) naive vs O(N) cell list for different system sizes
- Half-fill mode: Store only unique pairs to save memory
- Distance validation: Verify all pairs are within cutoff
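To make the half-fill idea concrete, the sketch below reduces a symmetric COO list to unique pairs and then restores the full list. It works on hand-written toy data and does not call the library. Keeping pairs with i < j is one possible convention and may differ from what half_fill=True produces; in a periodic system the mirrored pairs would also need their shifts negated, which is omitted here.

```python
import jax.numpy as jnp

# Full symmetric COO list for three mutually neighboring atoms
full_pairs = jnp.array([[0, 0, 1, 1, 2, 2],
                        [1, 2, 0, 2, 0, 1]])

# Keep each unordered pair once (assumed convention: source index < target index)
keep = full_pairs[0] < full_pairs[1]
half_pairs = full_pairs[:, keep]
print(half_pairs)  # [[0 0 1], [1 2 2]]

# Restore the full list by appending the mirrored pairs (rows swapped)
restored = jnp.concatenate([half_pairs, half_pairs[::-1]], axis=1)
print(restored.shape)  # (2, 6)
```

Storing half the pairs halves the memory for indices and shifts, at the cost of handling both directions explicitly wherever a pairwise quantity is accumulated.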
print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)
print("\nKey takeaways:")
print(" - Use neighbor_list() for automatic algorithm selection")
print(" - Use return_neighbor_list=True for COO format (GNNs)")
print(" - Use half_fill=True to store only unique pairs")
print(" - naive_neighbor_list: O(N²), best for < 5000 atoms")
print(" - cell_list: O(N), best for >= 5000 atoms")
print("\nExample completed successfully!")
======================================================================
SUMMARY
======================================================================
Key takeaways:
- Use neighbor_list() for automatic algorithm selection
- Use return_neighbor_list=True for COO format (GNNs)
- Use half_fill=True to store only unique pairs
- naive_neighbor_list: O(N²), best for < 5000 atoms
- cell_list: O(N), best for >= 5000 atoms
Example completed successfully!
Total running time of the script: (0 minutes 8.285 seconds)