.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/neighbors/05_jax_neighbor_list.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_neighbors_05_jax_neighbor_list.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_neighbors_05_jax_neighbor_list.py:

JAX Neighbor List Example
=========================

This example demonstrates how to use the JAX neighbor list API in nvalchemiops
to compute neighbor lists in periodic systems.

In this example you will learn:

- How to use the unified ``neighbor_list()`` API with JAX arrays
- Matrix format vs COO (list) format outputs
- Comparing ``naive_neighbor_list`` and ``cell_list`` algorithms
- Using ``half_fill`` mode for symmetric neighbor lists
- Validating that neighbor distances are within the cutoff
- Calling the neighbor computation inside a ``jax.jit``-compiled function

.. important::
    This example is for educational purposes. Do not use it for performance
    benchmarking, as the code includes print statements and small system sizes
    that are not representative of production workloads.

.. GENERATED FROM PYTHON SOURCE LINES 38-60

.. code-block:: Python

    import sys

    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        print(
            "This example requires JAX. Install with: pip install 'nvalchemi-toolkit-ops[jax]'"
        )
        sys.exit(0)

    try:
        from nvalchemiops.jax.neighbors import neighbor_list
    except Exception as exc:
        print(
            f"JAX/Warp backend unavailable ({exc}). This example requires a CUDA-backed runtime."
        )
        sys.exit(0)

    from nvalchemiops.jax.neighbors.cell_list import cell_list
    from nvalchemiops.jax.neighbors.naive import naive_neighbor_list

.. GENERATED FROM PYTHON SOURCE LINES 61-65

Setup
=====
JAX handles device placement automatically. We'll create a random periodic
system to demonstrate the neighbor list API.

.. GENERATED FROM PYTHON SOURCE LINES 65-93

.. code-block:: Python

    print("=" * 70)
    print("JAX NEIGHBOR LIST EXAMPLE")
    print("=" * 70)

    # System parameters
    num_atoms = 200
    box_size = 15.0
    cutoff = 5.0

    # Create random atomic positions using JAX random
    key = jax.random.PRNGKey(42)
    positions = jax.random.uniform(key, (num_atoms, 3), dtype=jnp.float32) * box_size

    # Create a cubic periodic cell: (1, 3, 3) shape
    cell = jnp.eye(3, dtype=jnp.float32)[None, ...] * box_size

    # Enable periodic boundary conditions in all directions: (1, 3) shape
    pbc = jnp.array([[True, True, True]])

    print("\nSystem configuration:")
    print(f" Number of atoms: {num_atoms}")
    print(f" Box size: {box_size} Å")
    print(f" Cutoff distance: {cutoff} Å")
    print(f" Positions shape: {positions.shape}")
    print(f" Cell shape: {cell.shape}")
    print(f" PBC shape: {pbc.shape}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ======================================================================
    JAX NEIGHBOR LIST EXAMPLE
    ======================================================================

    System configuration:
     Number of atoms: 200
     Box size: 15.0 Å
     Cutoff distance: 5.0 Å
     Positions shape: (200, 3)
     Cell shape: (1, 3, 3)
     PBC shape: (1, 3)

.. GENERATED FROM PYTHON SOURCE LINES 94-99

Unified API - Matrix Format (default)
=====================================
The ``neighbor_list()`` function automatically selects an algorithm based on
system size. For small systems (< 5000 atoms), it uses the naive O(N²)
algorithm. For larger systems, it uses the cell list O(N) algorithm. In the
default matrix format it returns a dense ``(N, max_neighbors)`` matrix of
neighbor indices, the number of neighbors per atom, and the lattice shift of
each neighbor.

.. GENERATED FROM PYTHON SOURCE LINES 99-126

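
The naive O(N²) strategy mentioned above can be illustrated without the
library. The sketch below is a plain-Python reference, not the nvalchemiops
implementation: ``brute_force_neighbor_matrix``, its ``fill`` value of ``-1``,
and the minimum-image shortcut (valid only for a cubic box whose length is at
least twice the cutoff) are assumptions made for illustration.

.. code-block:: Python

    def brute_force_neighbor_matrix(positions, box, cutoff, max_neighbors, fill=-1):
        """Return (matrix, counts) for a cubic periodic box via an O(N^2) scan."""
        n = len(positions)
        matrix = [[fill] * max_neighbors for _ in range(n)]
        counts = [0] * n
        for i in range(n):
            for j in range(n):
                if i == j:
                    continue
                # Minimum-image displacement along each axis of the cubic box.
                dist_sq = 0.0
                for axis in range(3):
                    d = positions[j][axis] - positions[i][axis]
                    d -= box * round(d / box)
                    dist_sq += d * d
                if dist_sq <= cutoff * cutoff:
                    if counts[i] == max_neighbors:
                        raise ValueError("max_neighbors is too small")
                    matrix[i][counts[i]] = j
                    counts[i] += 1
        return matrix, counts


    # Atoms 0 and 1 are 1.0 apart through the periodic boundary; atom 2 is isolated.
    toy_positions = [(0.5, 0.5, 0.5), (9.5, 0.5, 0.5), (5.0, 5.0, 5.0)]
    toy_matrix, toy_counts = brute_force_neighbor_matrix(
        toy_positions, box=10.0, cutoff=2.0, max_neighbors=4
    )
    print(toy_counts)  # [1, 1, 0]
    print(toy_matrix[0])  # [1, -1, -1, -1]

Each row holds the valid neighbor indices first and padding afterwards, which
is why a dense index matrix is paired with a per-atom neighbor count.
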
.. code-block:: Python

    print("\n" + "=" * 70)
    print("UNIFIED API - MATRIX FORMAT")
    print("=" * 70)

    # Call the unified API (returns matrix format by default)
    neighbor_matrix, num_neighbors, shifts = neighbor_list(
        positions, cutoff, cell=cell, pbc=pbc
    )

    print("\nReturned neighbor matrix format:")
    print(f" neighbor_matrix shape: {neighbor_matrix.shape}")
    print(f" num_neighbors shape: {num_neighbors.shape}")
    print(f" shifts shape: {shifts.shape}")

    print("\nStatistics:")
    print(f" Total neighbor pairs: {int(num_neighbors.sum())}")
    print(f" Average neighbors per atom: {float(num_neighbors.mean()):.2f}")
    print(f" Max neighbors for any atom: {int(num_neighbors.max())}")
    print(f" Min neighbors for any atom: {int(num_neighbors.min())}")

    # Show first few neighbors of atom 0
    print("\nFirst 5 neighbors of atom 0:")
    for i in range(min(5, int(num_neighbors[0]))):
        neighbor_idx = int(neighbor_matrix[0, i])
        shift = shifts[0, i].tolist()
        print(f" Neighbor {i}: atom {neighbor_idx}, shift {shift}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ======================================================================
    UNIFIED API - MATRIX FORMAT
    ======================================================================

    Returned neighbor matrix format:
     neighbor_matrix shape: (200, 112)
     num_neighbors shape: (200,)
     shifts shape: (200, 112, 3)

    Statistics:
     Total neighbor pairs: 6166
     Average neighbors per atom: 30.83
     Max neighbors for any atom: 43
     Min neighbors for any atom: 13

    First 5 neighbors of atom 0:
     Neighbor 0: atom 145, shift [0, 0, 0]
     Neighbor 1: atom 148, shift [0, 0, 0]
     Neighbor 2: atom 151, shift [0, 0, 0]
     Neighbor 3: atom 155, shift [0, 0, 0]
     Neighbor 4: atom 156, shift [0, 0, 0]

.. GENERATED FROM PYTHON SOURCE LINES 127-131

Unified API - COO Format
========================
The COO (coordinate) format is often preferred for graph neural networks.
Set ``return_neighbor_list=True`` to get this format: a ``(2, num_pairs)``
array of source and target indices, a CSR-style pointer array with ``N + 1``
entries, and one lattice shift per pair.

.. GENERATED FROM PYTHON SOURCE LINES 131-162

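
The relationship between the matrix and COO layouts described above can be
sketched in plain Python. ``matrix_to_coo`` is a hypothetical helper, not part
of nvalchemiops; it assumes each matrix row stores its valid neighbors first
and that the pointer array has ``N + 1`` entries.

.. code-block:: Python

    from itertools import accumulate


    def matrix_to_coo(neighbor_matrix, num_neighbors):
        """Flatten a padded neighbor matrix into COO pairs plus CSR pointers."""
        sources, targets = [], []
        for i, count in enumerate(num_neighbors):
            # Only the first `count` entries of row i are real neighbors.
            for j in neighbor_matrix[i][:count]:
                sources.append(i)
                targets.append(j)
        # ptr[i]:ptr[i + 1] is the slice of pairs whose source is atom i.
        ptr = [0] + list(accumulate(num_neighbors))
        return [sources, targets], ptr


    padded = [[1, 2, -1], [0, -1, -1], [0, -1, -1]]
    counts = [2, 1, 1]
    coo, ptr = matrix_to_coo(padded, counts)
    print(coo)  # [[0, 0, 1, 2], [1, 2, 0, 0]]
    print(ptr)  # [0, 2, 3, 4]

With this layout the neighbors of atom ``i`` are ``coo[1][ptr[i]:ptr[i + 1]]``,
so no padding entries need to be stored or skipped.
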
.. code-block:: Python

    print("\n" + "=" * 70)
    print("UNIFIED API - COO FORMAT")
    print("=" * 70)

    # Get neighbor list in COO format
    neighbor_list_coo, neighbor_ptr, shifts_coo = neighbor_list(
        positions, cutoff, cell=cell, pbc=pbc, return_neighbor_list=True
    )

    print("\nReturned COO format:")
    print(f" neighbor_list shape: {neighbor_list_coo.shape} (2 x num_pairs)")
    print(f" neighbor_ptr shape: {neighbor_ptr.shape} (CSR pointers)")
    print(f" shifts shape: {shifts_coo.shape}")

    source_atoms = neighbor_list_coo[0]
    target_atoms = neighbor_list_coo[1]

    print("\nStatistics:")
    print(f" Total pairs: {neighbor_list_coo.shape[1]}")
    print(f" Source atoms range: [{int(source_atoms.min())}, {int(source_atoms.max())}]")
    print(f" Target atoms range: [{int(target_atoms.min())}, {int(target_atoms.max())}]")

    # Show first few pairs
    print("\nFirst 5 neighbor pairs:")
    for i in range(min(5, neighbor_list_coo.shape[1])):
        src = int(source_atoms[i])
        tgt = int(target_atoms[i])
        shift = shifts_coo[i].tolist()
        print(f" Pair {i}: atom {src} -> atom {tgt}, shift {shift}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ======================================================================
    UNIFIED API - COO FORMAT
    ======================================================================

    Returned COO format:
     neighbor_list shape: (2, 6166) (2 x num_pairs)
     neighbor_ptr shape: (201,) (CSR pointers)
     shifts shape: (6166, 3)

    Statistics:
     Total pairs: 6166
     Source atoms range: [0, 199]
     Target atoms range: [0, 199]

    First 5 neighbor pairs:
     Pair 0: atom 0 -> atom 145, shift [0, 0, 0]
     Pair 1: atom 0 -> atom 148, shift [0, 0, 0]
     Pair 2: atom 0 -> atom 151, shift [0, 0, 0]
     Pair 3: atom 0 -> atom 155, shift [0, 0, 0]
     Pair 4: atom 0 -> atom 156, shift [0, 0, 0]

.. GENERATED FROM PYTHON SOURCE LINES 163-170

Algorithm Comparison
====================
The nvalchemiops library provides two main algorithms:

- ``naive_neighbor_list``: O(N²) - best for small systems
- ``cell_list``: O(N) - best for large systems

Both should find the same neighbors. The code below checks that the
per-atom neighbor counts agree.

.. GENERATED FROM PYTHON SOURCE LINES 170-195

.. code-block:: Python

    print("\n" + "=" * 70)
    print("ALGORITHM COMPARISON")
    print("=" * 70)

    # Direct call to naive algorithm
    nm_naive, num_naive, shifts_naive = naive_neighbor_list(
        positions, cutoff, cell=cell, pbc=pbc
    )

    # Direct call to cell list algorithm
    nm_cell, num_cell, shifts_cell = cell_list(positions, cutoff, cell=cell, pbc=pbc)

    print("\nNaive algorithm (O(N²)):")
    print(f" Total pairs: {int(num_naive.sum())}")
    print(f" Average neighbors: {float(num_naive.mean()):.2f}")

    print("\nCell list algorithm (O(N)):")
    print(f" Total pairs: {int(num_cell.sum())}")
    print(f" Average neighbors: {float(num_cell.mean()):.2f}")

    # Verify they find the same number of pairs per atom
    pairs_match = jnp.allclose(num_naive, num_cell)
    print(f"\nResults match: {pairs_match}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ======================================================================
    ALGORITHM COMPARISON
    ======================================================================

    Naive algorithm (O(N²)):
     Total pairs: 6166
     Average neighbors: 30.83

    Cell list algorithm (O(N)):
     Total pairs: 6166
     Average neighbors: 30.83

    Results match: True

.. GENERATED FROM PYTHON SOURCE LINES 196-199

Distance Validation
===================
Let's verify that all neighbor pairs are actually within the cutoff distance.

.. GENERATED FROM PYTHON SOURCE LINES 199-249

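
The distance check described above hinges on converting integer lattice shifts
into Cartesian offsets. The following plain-Python sketch assumes the
convention used by the validation code in this section (cell rows are lattice
vectors, and the pair vector is ``r_j - r_i + shift @ cell``);
``pair_distance`` is a hypothetical helper written only for illustration.

.. code-block:: Python

    import math


    def pair_distance(r_i, r_j, shift, cell):
        """Distance from atom i to the periodic image of atom j given by shift."""
        # Cartesian offset: integer shift times the lattice vectors (rows of cell).
        offset = [sum(shift[k] * cell[k][axis] for k in range(3)) for axis in range(3)]
        diff = [r_j[axis] - r_i[axis] + offset[axis] for axis in range(3)]
        return math.sqrt(sum(d * d for d in diff))


    toy_cell = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]
    r_i = (0.5, 0.5, 0.5)
    r_j = (9.5, 0.5, 0.5)

    direct = pair_distance(r_i, r_j, (0, 0, 0), toy_cell)
    wrapped = pair_distance(r_i, r_j, (-1, 0, 0), toy_cell)
    print(direct)  # 9.0
    print(wrapped)  # 1.0

Only the shifted image falls inside a cutoff of, say, 2.0, which is why a
neighbor list for a periodic system has to carry a shift for every pair.
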
.. code-block:: Python

    print("\n" + "=" * 70)
    print("DISTANCE VALIDATION")
    print("=" * 70)

    # Get neighbor list in COO format for easy distance computation
    nlist, nptr, nshifts = naive_neighbor_list(
        positions, cutoff, cell=cell, pbc=pbc, return_neighbor_list=True
    )

    if nlist.shape[1] > 0:
        # Extract source and target positions
        src_idx = nlist[0]
        tgt_idx = nlist[1]

        pos_src = positions[src_idx]
        pos_tgt = positions[tgt_idx]

        # Compute Cartesian shift from lattice shift
        # shifts are in lattice coordinates, multiply by cell vectors
        cell_squeezed = cell.squeeze(0)  # (3, 3)
        cartesian_shifts = jnp.einsum(
            "ij,jk->ik", nshifts.astype(jnp.float32), cell_squeezed
        )

        # Compute distances: r_j - r_i + shift
        diff = pos_tgt - pos_src + cartesian_shifts
        distances = jnp.linalg.norm(diff, axis=1)

        print(f"\nComputed distances for {len(distances)} neighbor pairs:")
        print(f" Min distance: {float(distances.min()):.4f} Å")
        print(f" Max distance: {float(distances.max()):.4f} Å")
        print(f" Mean distance: {float(distances.mean()):.4f} Å")
        print(f" Cutoff: {cutoff} Å")

        # Check if all distances are within cutoff (with small tolerance)
        within_cutoff = jnp.all(distances <= cutoff + 1e-5)
        print(f"\n All distances within cutoff: {within_cutoff}")

        # Show distribution of first 10 distances
        print("\nFirst 10 neighbor distances:")
        for i in range(min(10, len(distances))):
            src = int(src_idx[i])
            tgt = int(tgt_idx[i])
            dist = float(distances[i])
            print(f" Atom {src} -> {tgt}: {dist:.4f} Å")
    else:
        print("\nNo neighbor pairs found (empty system or cutoff too small)")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ======================================================================
    DISTANCE VALIDATION
    ======================================================================

    Computed distances for 6166 neighbor pairs:
     Min distance: 0.4676 Å
     Max distance: 4.9992 Å
     Mean distance: 3.7425 Å
     Cutoff: 5.0 Å

     All distances within cutoff: True

    First 10 neighbor distances:
     Atom 0 -> 145: 2.7335 Å
     Atom 0 -> 148: 3.8481 Å
     Atom 0 -> 151: 3.7375 Å
     Atom 0 -> 155: 4.6850 Å
     Atom 0 -> 156: 4.9675 Å
     Atom 0 -> 105: 3.8244 Å
     Atom 0 -> 127: 4.6600 Å
     Atom 0 -> 163: 4.1604 Å
     Atom 0 -> 167: 3.6606 Å
     Atom 0 -> 173: 2.9997 Å

.. GENERATED FROM PYTHON SOURCE LINES 250-253

JIT Compilation
===============
The neighbor computation can be called inside a function compiled with
``jax.jit``. Because compiled code needs static array shapes,
``max_neighbors`` and ``max_total_cells`` must be passed explicitly.

.. GENERATED FROM PYTHON SOURCE LINES 253-305

.. code-block:: Python

    print("\n" + "=" * 70)
    print("JIT compilation example")
    print("=" * 70)


    @jax.jit
    def run_compute_loop(
        positions,
        cell,
        pbc,
        max_neighbors: int = 128,
        max_total_cells: int = 16,
        cutoff: float = 6.0,
        max_num_atoms: int = 200,
    ) -> jax.Array:
        """Example of encapsulating a compute loop"""
        num_loops = 100
        all_neighbors = jnp.zeros(
            (num_loops, max_num_atoms, max_neighbors), dtype=positions.dtype
        )
        # generate some random positions
        key = jax.random.PRNGKey(64)
        for i in range(num_loops):
            # derive a distinct key per iteration; reusing `key` unchanged
            # would add the same perturbation on every pass
            step_key = jax.random.fold_in(key, i)
            new_positions = (
                jax.random.normal(step_key, (max_num_atoms, 3), dtype=positions.dtype)
                + positions
            )
            # for JIT compilation, max_neighbors and total cells **must** be specified to
            # accommodate static array shapes
            neighbor_matrix, num_neighbors, neighbor_matrix_shifts = cell_list(
                new_positions,
                cutoff,
                cell * 1.5,
                pbc,
                max_neighbors=max_neighbors,
                max_total_cells=max_total_cells,
            )
            # in this example we don't do any additional computation
            # other than neighborhoods; include your computation logic
            # within this scope
            all_neighbors = all_neighbors.at[i].set(neighbor_matrix)
        return all_neighbors


    # run the compute loop N times
    num_loops = 100
    print(f"\nRun neighbor computation loop {num_loops} times.")
    all_neighbors = run_compute_loop(positions, cell, pbc)
    print(f"Returned neighbor matrix shape: {all_neighbors.shape}")

.. rst-class:: sphx-glr-script-out

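.. code-block:: none

    ======================================================================
    JIT compilation example
    ======================================================================

    Run neighbor computation loop 100 times.
    Returned neighbor matrix shape: (100, 200, 128)

.. GENERATED FROM PYTHON SOURCE LINES 306-316

Summary
=======
This example demonstrated the JAX neighbor list API in nvalchemiops:

- **Unified API**: ``neighbor_list()`` automatically selects the best algorithm
- **Matrix format**: Dense (N, max_neighbors) format for neighbor indices
- **COO format**: Sparse (2, num_pairs) format for graph neural networks
- **Algorithm choice**: O(N²) naive vs O(N) cell list for different system sizes
- **Half-fill mode**: Store only unique pairs to save memory
- **Distance validation**: Verify all pairs are within cutoff
- **JIT compilation**: Pass ``max_neighbors`` and ``max_total_cells`` so array
  shapes stay static under ``jax.jit``

.. GENERATED FROM PYTHON SOURCE LINES 316-327

.. code-block:: Python

    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    print("\nKey takeaways:")
    print(" - Use neighbor_list() for automatic algorithm selection")
    print(" - Use return_neighbor_list=True for COO format (GNNs)")
    print(" - Use half_fill=True to store only unique pairs")
    print(" - naive_neighbor_list: O(N²), best for < 5000 atoms")
    print(" - cell_list: O(N), best for >= 5000 atoms")
    print("\nExample completed successfully!")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ======================================================================
    SUMMARY
    ======================================================================

    Key takeaways:
     - Use neighbor_list() for automatic algorithm selection
     - Use return_neighbor_list=True for COO format (GNNs)
     - Use half_fill=True to store only unique pairs
     - naive_neighbor_list: O(N²), best for < 5000 atoms
     - cell_list: O(N), best for >= 5000 atoms

    Example completed successfully!

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 8.285 seconds)

.. _sphx_glr_download_examples_neighbors_05_jax_neighbor_list.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 05_jax_neighbor_list.ipynb <05_jax_neighbor_list.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 05_jax_neighbor_list.py <05_jax_neighbor_list.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 05_jax_neighbor_list.zip <05_jax_neighbor_list.zip>`

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_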