.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/electrostatics/04_jax_pme_example.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_electrostatics_04_jax_pme_example.py:


Particle Mesh Ewald (PME) with JAX
==================================

This example demonstrates how to compute long-range electrostatic interactions
using the Particle Mesh Ewald (PME) method with the JAX backend. PME achieves
O(N log N) scaling by spreading the charges onto a mesh and evaluating the
reciprocal-space sum with FFTs.

In this example you will learn:

- How to set up and run PME with automatic parameter estimation in JAX
- How to use the neighbor list (COO) and neighbor matrix (dense) formats
- How to access the real-space and reciprocal-space components separately
- How to compute charge gradients for ML potential training
- How to ``jax.jit``-compile the full neighbor list + PME pipeline

.. important::

    This script is intended as an API demonstration. Do not use it for
    performance benchmarking; refer to the ``benchmarks`` folder instead.

.. GENERATED FROM PYTHON SOURCE LINES 38-41

Setup and Imports
-----------------
Import JAX and the nvalchemiops electrostatics API.

.. GENERATED FROM PYTHON SOURCE LINES 41-74

.. code-block:: Python


    from __future__ import annotations

    import sys
    import time

    try:
        import jax

        # float64 arrays require JAX's 64-bit mode; without it, jnp.float64
        # requests are downcast to float32 (with a warning)
        jax.config.update("jax_enable_x64", True)
        import jax.numpy as jnp
    except ImportError:
        print(
            "This example requires JAX. Install with: pip install 'nvalchemi-toolkit-ops[jax]'"
        )
        sys.exit(0)

    import numpy as np

    try:
        from nvalchemiops.jax.interactions.electrostatics import (
            estimate_pme_parameters,
            ewald_real_space,
            particle_mesh_ewald,
            pme_reciprocal_space,
        )
        from nvalchemiops.jax.neighbors import neighbor_list
        from nvalchemiops.jax.neighbors.naive import naive_neighbor_list
        from nvalchemiops.jax.neighbors.neighbor_utils import compute_naive_num_shifts
    except Exception as exc:
        print(
            f"JAX/Warp backend unavailable ({exc}). This example requires a CUDA-backed runtime."
        )
        sys.exit(0)

.. GENERATED FROM PYTHON SOURCE LINES 75-77

Check Device
------------

.. GENERATED FROM PYTHON SOURCE LINES 77-86

.. code-block:: Python


    print("=" * 70)
    print("JAX PME ELECTROSTATICS EXAMPLE")
    print("=" * 70)

    devices = jax.devices()
    print(f"\nJAX devices: {devices}")
    print(f"Default backend: {jax.default_backend()}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ======================================================================
    JAX PME ELECTROSTATICS EXAMPLE
    ======================================================================

    JAX devices: [CudaDevice(id=0)]
    Default backend: gpu

.. GENERATED FROM PYTHON SOURCE LINES 87-90

Create a NaCl Crystal System
----------------------------
We define a helper function that builds charge-neutral Na+/Cl- supercells.
Each cubic cell of edge ``lattice_constant`` holds Na+ at the corner and Cl-
at the body center. This two-ion simple-cubic basis is a CsCl-type
arrangement rather than the rock-salt structure of real NaCl; it serves here
as a convenient neutral test system.

.. GENERATED FROM PYTHON SOURCE LINES 90-130

.. code-block:: Python


    def create_nacl_system(n_cells: int = 3, lattice_constant: float = 5.64):
        """Create a neutral Na+/Cl- cubic supercell.

        Each cubic cell contains Na+ at fractional position (0, 0, 0) and Cl-
        at (0.5, 0.5, 0.5), i.e. a CsCl-type lattice carrying NaCl charges.

        Parameters
        ----------
        n_cells : int
            Number of unit cells in each direction.
        lattice_constant : float
            Cubic cell edge length in Angstroms (5.64 is the conventional
            rock-salt NaCl lattice constant).

        Returns
        -------
        positions, charges, cell, pbc : jax.Array
            System arrays with float64 dtype for electrostatics.
        """
        base_positions = np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])
        base_charges = np.array([1.0, -1.0])

        positions_list = []
        charges_list = []

        for i in range(n_cells):
            for j in range(n_cells):
                for k in range(n_cells):
                    offset = np.array([i, j, k])
                    for pos, charge in zip(base_positions, base_charges):
                        positions_list.append((pos + offset) * lattice_constant)
                        charges_list.append(charge)

        # Convert to JAX arrays with float64 for electrostatics accuracy
        positions = jnp.array(positions_list, dtype=jnp.float64)
        charges = jnp.array(charges_list, dtype=jnp.float64)
        cell = jnp.eye(3, dtype=jnp.float64) * lattice_constant * n_cells
        cell = cell[None, ...]  # Add batch dimension: (1, 3, 3)
        pbc = jnp.array([[True, True, True]])

        return positions, charges, cell, pbc

.. GENERATED FROM PYTHON SOURCE LINES 131-134

Basic PME with Automatic Parameters
-----------------------------------
The simplest way to use PME is with automatic parameter estimation.

.. GENERATED FROM PYTHON SOURCE LINES 134-146

.. code-block:: Python


    print("\n" + "=" * 70)
    print("BASIC PME WITH AUTOMATIC PARAMETERS")
    print("=" * 70)

    # Create a 3×3×3 supercell (27 cubic cells × 2 ions = 54 atoms)
    positions, charges, cell, pbc = create_nacl_system(n_cells=3)

    print(f"\nSystem: {len(positions)} atoms NaCl crystal")
    print(f"Cell size: {float(cell[0, 0, 0]):.2f} Å")
    print(f"Total charge: {float(charges.sum()):.1f} (should be 0 for neutral)")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ======================================================================
    BASIC PME WITH AUTOMATIC PARAMETERS
    ======================================================================

    System: 54 atoms NaCl crystal
    Cell size: 16.92 Å
    Total charge: 0.0 (should be 0 for neutral)

.. GENERATED FROM PYTHON SOURCE LINES 147-148

Next, estimate the PME parameters for a target accuracy of ``1e-6``. The
estimator returns the Ewald splitting parameter ``alpha``, the mesh
dimensions, the mesh spacing, and the real-space cutoff.

.. GENERATED FROM PYTHON SOURCE LINES 148-160
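As a rough cross-check of the values printed below, the following sketch
reproduces ``alpha`` and the real-space cutoff from a widely used Ewald
heuristic. Treat it as an assumption, not as the nvalchemiops estimator: the
helper ``heuristic_ewald_parameters`` is hypothetical. It takes
``alpha = sqrt(pi) * (N / V**2) ** (1/6)`` (the scaling that balances real- and
reciprocal-space work in standard Ewald summation) and
``r_c = sqrt(-ln(accuracy)) / alpha``, which makes the Gaussian damping factor
``exp(-(alpha * r_c)**2)`` equal to the target accuracy. It does not attempt to
reproduce the mesh dimensions.

.. code-block:: Python

    import math


    def heuristic_ewald_parameters(n_atoms: int, volume: float, accuracy: float):
        """Hypothetical cross-check, not the nvalchemiops estimator.

        alpha:  sqrt(pi) * (N / V**2) ** (1/6), in 1/Angstrom
        cutoff: sqrt(-ln(accuracy)) / alpha, in Angstrom, so that
                exp(-(alpha * cutoff) ** 2) == accuracy
        """
        alpha = math.sqrt(math.pi) * (n_atoms / volume**2) ** (1.0 / 6.0)
        cutoff = math.sqrt(-math.log(accuracy)) / alpha
        return alpha, cutoff


    box = 3 * 5.64  # edge of the 3x3x3 supercell, in Angstrom
    alpha, cutoff = heuristic_ewald_parameters(54, box**3, 1e-6)
    print(f"alpha ~ {alpha:.4f} 1/A, real-space cutoff ~ {cutoff:.2f} A")

For this cell the heuristic gives alpha ≈ 0.2037 Å⁻¹ and a cutoff of about
18.25 Å, in line with the estimator output below. Because the cutoff scales
with ``sqrt(-ln(accuracy))``, tightening the accuracy by an order of magnitude
lengthens it only modestly.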
.. code-block:: Python


    params = estimate_pme_parameters(positions, cell, accuracy=1e-6)

    print("\nEstimated parameters (accuracy=1e-6):")
    print(f"  alpha = {float(params.alpha[0]):.4f}")
    print(f"  mesh_dimensions = {params.mesh_dimensions}")
    print(
        f"  mesh_spacing = ({float(params.mesh_spacing[0, 0]):.2f}, "
        f"{float(params.mesh_spacing[0, 1]):.2f}, {float(params.mesh_spacing[0, 2]):.2f}) Å"
    )
    print(f"  real_space_cutoff = {float(params.real_space_cutoff[0]):.2f} Å")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Estimated parameters (accuracy=1e-6):
      alpha = 0.2037
      mesh_dimensions = (64, 64, 64)
      mesh_spacing = (0.26, 0.26, 0.26) Å
      real_space_cutoff = 18.25 Å

.. GENERATED FROM PYTHON SOURCE LINES 161-162

Build the neighbor list with the estimated real-space cutoff and run PME:

.. GENERATED FROM PYTHON SOURCE LINES 162-191

.. code-block:: Python


    cutoff = float(params.real_space_cutoff[0])
    nl, nptr, ns = neighbor_list(
        positions,
        cutoff,
        cell=cell,
        pbc=pbc,
        return_neighbor_list=True,
    )

    energies, forces = particle_mesh_ewald(
        positions=positions,
        charges=charges,
        cell=cell,
        neighbor_list=nl,
        neighbor_ptr=nptr,
        neighbor_shifts=ns,
        compute_forces=True,
        accuracy=1e-6,
    )

    total_energy = float(energies.sum())
    max_force = float(jnp.linalg.norm(forces, axis=1).max())

    print("\nPME Results:")
    print(f"  Total energy: {total_energy:.6f}")
    print(f"  Energy per atom: {total_energy / len(positions):.6f}")
    print(f"  Max force magnitude: {max_force:.6f}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    PME Results:
      Total energy: -9.743761
      Energy per atom: -0.180440
      Max force magnitude: 0.000000

Every ion in this perfect lattice is a center of inversion symmetry, so the
net force on it vanishes, which is why the maximum force is zero. The total
energy is also consistent with the Madelung constant of the CsCl-type lattice:
with a nearest-neighbor distance of √3/2 × 5.64 Å ≈ 4.8844 Å, the 27 ion pairs
give 27 × (-1.76267 / 4.8844) ≈ -9.7438, assuming energies are reported in
units of e²/Å (Coulomb constant set to 1).

.. GENERATED FROM PYTHON SOURCE LINES 192-195

Neighbor Matrix vs COO Format Comparison
----------------------------------------
PME accepts both neighbor formats, and the two give the same result to
floating-point precision.

.. GENERATED FROM PYTHON SOURCE LINES 195-221
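The two formats hold the same pairs in different layouts. As an illustration
only, the NumPy sketch below converts a padded neighbor matrix into COO pairs
plus a CSR-style row pointer. The helper ``dense_to_coo`` is hypothetical, and
its layout assumptions (the valid neighbors of atom ``i`` sit in the first
``num_neighbors[i]`` columns; COO pairs are stored as a ``(2, num_pairs)``
array) may not match the exact nvalchemiops conventions.

.. code-block:: Python

    import numpy as np


    def dense_to_coo(neighbor_matrix, num_neighbors, shifts=None):
        """Hypothetical helper: padded (N, max_neighbors) matrix -> COO + row pointer.

        Assumes row i holds its valid neighbors in the first num_neighbors[i]
        columns; the remaining columns are padding and are ignored.
        """
        neighbor_matrix = np.asarray(neighbor_matrix)
        num_neighbors = np.asarray(num_neighbors)
        n_atoms, max_neighbors = neighbor_matrix.shape

        # Mask of valid (non-padding) entries, shape (N, max_neighbors)
        valid = np.arange(max_neighbors)[None, :] < num_neighbors[:, None]

        # Boolean indexing walks the matrix row by row, matching np.repeat
        rows = np.repeat(np.arange(n_atoms), num_neighbors)
        cols = neighbor_matrix[valid]

        # ptr[i]:ptr[i + 1] indexes the pairs that belong to atom i
        ptr = np.concatenate([[0], np.cumsum(num_neighbors)])

        # Periodic image shifts of shape (N, max_neighbors, 3) follow the same mask
        pair_shifts = None if shifts is None else np.asarray(shifts)[valid]
        return np.stack([rows, cols]), ptr, pair_shifts


    # Three atoms, rows padded with -1 up to max_neighbors = 3
    matrix = np.array([[1, 2, -1], [0, -1, -1], [0, 1, -1]])
    pairs, ptr, _ = dense_to_coo(matrix, [2, 1, 2])
    print(pairs)  # row 0: atom index i, row 1: neighbor index j
    print(ptr)

With the row pointer, the neighbors of atom ``i`` are
``pairs[1, ptr[i]:ptr[i + 1]]``. The dense layout instead spends memory on
padding in exchange for a fixed shape, which is what makes it convenient under
``jax.jit``.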
.. code-block:: Python


    print("\n" + "=" * 70)
    print("NEIGHBOR FORMAT COMPARISON")
    print("=" * 70)

    # Build both formats using the estimated real-space cutoff
    # COO format (neighbor list)
    nl_coo, nptr_coo, ns_coo = neighbor_list(
        positions,
        cutoff,
        cell=cell,
        pbc=pbc,
        return_neighbor_list=True,
    )

    # Dense format (neighbor matrix)
    nm_dense, num_dense, ns_dense = neighbor_list(
        positions,
        cutoff,
        cell=cell,
        pbc=pbc,
        return_neighbor_list=False,
    )

    print(f"\nUsing alpha={float(params.alpha[0]):.4f}, mesh_dims={params.mesh_dimensions}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ======================================================================
    NEIGHBOR FORMAT COMPARISON
    ======================================================================

    Using alpha=0.2037, mesh_dims=(64, 64, 64)

.. GENERATED FROM PYTHON SOURCE LINES 222-223

Using the neighbor list (COO) format:

.. GENERATED FROM PYTHON SOURCE LINES 223-237

.. code-block:: Python


    energies_coo, forces_coo = particle_mesh_ewald(
        positions=positions,
        charges=charges,
        cell=cell,
        neighbor_list=nl_coo,
        neighbor_ptr=nptr_coo,
        neighbor_shifts=ns_coo,
        compute_forces=True,
        accuracy=1e-6,
    )
    print(f"  COO format: E={float(energies_coo.sum()):.6f}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

      COO format: E=-9.743761

.. GENERATED FROM PYTHON SOURCE LINES 238-239

Using the neighbor matrix (dense) format:

.. GENERATED FROM PYTHON SOURCE LINES 239-259

.. code-block:: Python


    energies_dense, forces_dense = particle_mesh_ewald(
        positions=positions,
        charges=charges,
        cell=cell,
        neighbor_matrix=nm_dense,
        neighbor_matrix_shifts=ns_dense,
        compute_forces=True,
        accuracy=1e-6,
    )
    print(f"  Dense format: E={float(energies_dense.sum()):.6f}")

    # Compare results
    energy_diff = abs(float(energies_coo.sum()) - float(energies_dense.sum()))
    force_diff = float(jnp.abs(forces_coo - forces_dense).max())
    print(f"\nEnergy difference: {energy_diff:.2e}")
    print(f"Max force difference: {force_diff:.2e}")

.. rst-class:: sphx-glr-script-out
.. code-block:: none

      Dense format: E=-9.743761

    Energy difference: 0.00e+00
    Max force difference: 2.64e-17

.. GENERATED FROM PYTHON SOURCE LINES 260-263

Real-Space and Reciprocal-Space Components
------------------------------------------
The real-space and reciprocal-space contributions can also be computed
separately.

.. GENERATED FROM PYTHON SOURCE LINES 263-280

.. code-block:: Python


    print("\n" + "=" * 70)
    print("ENERGY COMPONENTS")
    print("=" * 70)

    # Use a looser accuracy (1e-4) for this demo: it gives a shorter real-space
    # cutoff and therefore a smaller neighbor list
    params_comp = estimate_pme_parameters(positions, cell, accuracy=1e-4)
    cutoff_comp = float(params_comp.real_space_cutoff[0])

    nl_comp, nptr_comp, ns_comp = neighbor_list(
        positions,
        cutoff_comp,
        cell=cell,
        pbc=pbc,
        return_neighbor_list=True,
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ======================================================================
    ENERGY COMPONENTS
    ======================================================================

.. GENERATED FROM PYTHON SOURCE LINES 281-282

Real-space component (computed with the same kernel as standard Ewald summation):

.. GENERATED FROM PYTHON SOURCE LINES 282-295

.. code-block:: Python


    real_energy = ewald_real_space(
        positions=positions,
        charges=charges,
        cell=cell,
        alpha=params_comp.alpha,
        neighbor_list=nl_comp,
        neighbor_ptr=nptr_comp,
        neighbor_shifts=ns_comp,
    )

    print(f"\n  Real-space: {float(real_energy.sum()):.6f}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

      Real-space: -3.549335

.. GENERATED FROM PYTHON SOURCE LINES 296-297

PME reciprocal-space component (FFT-based):

.. GENERATED FROM PYTHON SOURCE LINES 297-309

.. code-block:: Python


    recip_energy = pme_reciprocal_space(
        positions=positions,
        charges=charges,
        cell=cell,
        alpha=params_comp.alpha,
        mesh_dimensions=params_comp.mesh_dimensions,
    )
    print(f"  Reciprocal-space (PME): {float(recip_energy.sum()):.6f}")
    print(f"  Total (sum): {float(real_energy.sum() + recip_energy.sum()):.6f}")

.. rst-class:: sphx-glr-script-out
.. code-block:: none

      Reciprocal-space (PME): -6.194456
      Total (sum): -9.743791

.. GENERATED FROM PYTHON SOURCE LINES 310-311

Compare with the full PME calculation:

.. GENERATED FROM PYTHON SOURCE LINES 311-329

.. code-block:: Python


    full_pme_energy = particle_mesh_ewald(
        positions=positions,
        charges=charges,
        cell=cell,
        neighbor_list=nl_comp,
        neighbor_ptr=nptr_comp,
        neighbor_shifts=ns_comp,
        accuracy=1e-4,
    )
    print(f"  Full PME: {float(full_pme_energy.sum()):.6f}")

    component_diff = abs(
        float(real_energy.sum() + recip_energy.sum()) - float(full_pme_energy.sum())
    )
    print(f"\n  Component sum vs full PME difference: {component_diff:.2e}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

      Full PME: -9.743791

      Component sum vs full PME difference: 0.00e+00

.. GENERATED FROM PYTHON SOURCE LINES 330-334

Charge Gradients for ML Potentials
----------------------------------
PME can compute analytical charge gradients (∂E/∂q_i), which are useful for
training machine learning potentials that predict atomic partial charges.

.. GENERATED FROM PYTHON SOURCE LINES 334-363

.. code-block:: Python


    print("\n" + "=" * 70)
    print("CHARGE GRADIENTS")
    print("=" * 70)

    # Compute PME with charge gradients
    energies_cg, forces_cg, charge_grads = particle_mesh_ewald(
        positions=positions,
        charges=charges,
        cell=cell,
        neighbor_list=nl_comp,
        neighbor_ptr=nptr_comp,
        neighbor_shifts=ns_comp,
        compute_forces=True,
        compute_charge_gradients=True,
        accuracy=1e-4,
    )

    print(f"\n  Charge gradients shape: {charge_grads.shape}")
    print(
        f"  Charge gradients range: [{float(charge_grads.min()):.4f}, "
        f"{float(charge_grads.max()):.4f}]"
    )
    print(f"  Charge gradients mean: {float(charge_grads.mean()):.4f}")

    # Each entry dE/dq_i is the electrostatic potential at atom i. In this
    # lattice the Na+ and Cl- sites are equivalent up to the sign of their
    # charge, so the gradients cancel and the sum is ~0; this is not a
    # general property of neutral systems.
    print(f"  Sum of charge gradients: {float(charge_grads.sum()):.4e}")

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    ======================================================================
    CHARGE GRADIENTS
    ======================================================================

      Charge gradients shape: (54,)
      Charge gradients range: [-0.3609, 0.3609]
      Charge gradients mean: 0.0000
      Sum of charge gradients: 3.3307e-16

.. GENERATED FROM PYTHON SOURCE LINES 364-365

Verify the expected symmetry between the Na+ and Cl- gradients:

.. GENERATED FROM PYTHON SOURCE LINES 365-372

.. code-block:: Python


    na_grads = charge_grads[charges > 0]  # Na+ ions
    cl_grads = charge_grads[charges < 0]  # Cl- ions

    print(f"\n  Na+ charge gradients mean: {float(na_grads.mean()):.4f}")
    print(f"  Cl- charge gradients mean: {float(cl_grads.mean()):.4f}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

      Na+ charge gradients mean: -0.3609
      Cl- charge gradients mean: 0.3609

Because the Ewald energy is a quadratic function of the charges, the gradients
also reproduce the total energy: E = ½ Σ_i q_i ∂E/∂q_i. Every ion here
contributes q_i ∂E/∂q_i = -0.3609, so E ≈ ½ × 54 × (-0.3609) ≈ -9.744,
consistent with the totals computed above.

.. GENERATED FROM PYTHON SOURCE LINES 373-390

JIT Compilation
---------------
Finally, combine the neighbor list build and the PME calculation in a single
``jax.jit``-compiled function, so that JAX compiles the entire pipeline into
one optimized computation.

For JIT compatibility:

- ``max_neighbors`` must be specified (static array shapes)
- ``mesh_dimensions`` must be a concrete tuple (static FFT sizes)
- ``alpha`` can be a traced JAX array
- ``compute_forces`` and other boolean flags must be static
- Parameter estimation (``estimate_pme_parameters``) should happen **outside**
  the jitted function, since it determines array shapes
- Periodic shift metadata (``shift_range``, ``num_shifts_per_system``,
  ``max_shifts_per_system``) must be pre-computed outside jit using
  ``compute_naive_num_shifts``, since the kernel launch dimensions must be
  concrete

.. GENERATED FROM PYTHON SOURCE LINES 390-454
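The static-versus-traced rules above are general ``jax.jit`` behavior. The toy
function below is hypothetical and not part of nvalchemiops: it writes a
variable-length list of neighbor indices into a fixed-width row. Because
``max_neighbors`` sets the output shape, it is declared static through
``static_argnames``, and any indices beyond ``max_neighbors`` are silently
dropped. This sketches the general pattern only, not how nvalchemiops builds
its neighbor matrix.

.. code-block:: Python

    from functools import partial

    import jax
    import jax.numpy as jnp


    # Hypothetical toy function, not part of nvalchemiops.
    @partial(jax.jit, static_argnames=("max_neighbors",))
    def pad_or_truncate(neighbor_ids: jax.Array, max_neighbors: int) -> jax.Array:
        # The output shape depends on max_neighbors, so it must be static:
        # jnp.full cannot build an array whose size is a traced value.
        row = jnp.full((max_neighbors,), -1, dtype=neighbor_ids.dtype)
        # neighbor_ids.shape is known at trace time, so n is a plain Python int
        n = min(neighbor_ids.shape[0], max_neighbors)
        # Indices beyond max_neighbors are silently dropped
        return row.at[:n].set(neighbor_ids[:n])


    ids = jnp.array([7, 3, 5, 9])
    print(pad_or_truncate(ids, max_neighbors=6))  # padded with -1
    print(pad_or_truncate(ids, max_neighbors=2))  # truncated: 5 and 9 are lost

Each distinct ``max_neighbors`` value triggers a separate compilation, which is
one reason to fix it once, outside the hot loop, at a value large enough for
every neighbor inside the cutoff.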
.. code-block:: Python


    print("\n" + "=" * 70)
    print("JIT COMPILATION")
    print("=" * 70)

    # First, estimate parameters outside jit (determines static shapes)
    jit_positions, jit_charges, jit_cell, jit_pbc = create_nacl_system(n_cells=3)
    jit_params = estimate_pme_parameters(jit_positions, jit_cell, accuracy=1e-5)
    jit_cutoff = float(jit_params.real_space_cutoff[0])
    jit_mesh_dims = tuple(int(x) for x in jit_params.mesh_dimensions)
    jit_alpha = jit_params.alpha

    # Pre-compute shift metadata outside jit (launch sizes must be concrete)
    shift_range, num_shifts_per_system, max_shifts_per_system = compute_naive_num_shifts(
        jit_cell, jit_cutoff, jit_pbc
    )


    # Define a function that builds neighbors and computes PME.
    # We will compare the performance of the jitted and non-jitted versions.
    # Static values are bound as default arguments: only the five arrays are
    # passed at call time, so only those are traced by jax.jit.
    def compute_pme_energy_forces(
        positions: jax.Array,
        charges: jax.Array,
        cell: jax.Array,
        pbc: jax.Array,
        alpha: jax.Array,
        shift_range: jax.Array = shift_range,
        num_shifts_per_system: jax.Array = num_shifts_per_system,
        cutoff: float = jit_cutoff,
        max_neighbors: int = 128,
        max_shifts_per_system: int = max_shifts_per_system,
        mesh_dimensions: tuple[int, int, int] = jit_mesh_dims,
    ) -> tuple[jax.Array, jax.Array]:
        """Neighbor list + PME pipeline (wrapped with jax.jit below)."""
        # Build neighbor matrix inside jit (max_neighbors must be static,
        # shift metadata pre-computed outside jit)
        neighbor_matrix, _, neighbor_matrix_shifts = naive_neighbor_list(
            positions,
            cutoff,
            cell=cell,
            pbc=pbc,
            max_neighbors=max_neighbors,
            shift_range_per_dimension=shift_range,
            num_shifts_per_system=num_shifts_per_system,
            max_shifts_per_system=max_shifts_per_system,
        )

        # Compute PME (mesh_dimensions is static, alpha is traced)
        energies, forces = particle_mesh_ewald(
            positions=positions,
            charges=charges,
            cell=cell,
            alpha=alpha,
            mesh_dimensions=mesh_dimensions,
            neighbor_matrix=neighbor_matrix,
            neighbor_matrix_shifts=neighbor_matrix_shifts,
            compute_forces=True,
        )
        return energies, forces

    jit_compute_pme_energy_forces = jax.jit(compute_pme_energy_forces)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ======================================================================
    JIT COMPILATION
    ======================================================================

.. GENERATED FROM PYTHON SOURCE LINES 455-456

Run the non-jitted function:

.. GENERATED FROM PYTHON SOURCE LINES 456-484

.. code-block:: Python


    energies, forces = compute_pme_energy_forces(
        jit_positions, jit_charges, jit_cell, jit_pbc, jit_alpha
    )
    total_energy = float(energies.sum())
    max_force = float(jnp.linalg.norm(forces, axis=1).max())
    print(f"  Non-jitted total energy: {total_energy:.6f}")
    print(f"  Non-jitted max force: {max_force:.6f}")

    # Measure performance
    # Warmup runs
    for _ in range(10):
        energies, forces = compute_pme_energy_forces(
            jit_positions, jit_charges, jit_cell, jit_pbc, jit_alpha
        )
        energies.block_until_ready()
        forces.block_until_ready()
    # Timed runs (block_until_ready waits for JAX's asynchronous dispatch)
    start_time = time.perf_counter()
    for _ in range(50):
        energies, forces = compute_pme_energy_forces(
            jit_positions, jit_charges, jit_cell, jit_pbc, jit_alpha
        )
        energies.block_until_ready()
        forces.block_until_ready()
    total_time = time.perf_counter() - start_time
    print(f"  Non-jitted average time per call: {total_time / 50:.6f} seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

      Non-jitted total energy: -8.492044
      Non-jitted max force: 0.059720
      Non-jitted average time per call: 0.142695 seconds

.. GENERATED FROM PYTHON SOURCE LINES 485-486

Run the jitted function:

.. GENERATED FROM PYTHON SOURCE LINES 486-524
.. code-block:: Python


    print("\nCompiling and running jitted PME pipeline...")
    jit_energies, jit_forces = jit_compute_pme_energy_forces(
        jit_positions, jit_charges, jit_cell, jit_pbc, jit_alpha
    )
    jit_total_energy = float(jit_energies.sum())
    jit_max_force = float(jnp.linalg.norm(jit_forces, axis=1).max())
    print(f"  JIT total energy: {jit_total_energy:.6f}")
    print(f"  JIT max force: {jit_max_force:.6f}")

    # Measure performance (the first call above already triggered compilation)
    # Warmup runs
    for _ in range(10):
        jit_energies, jit_forces = jit_compute_pme_energy_forces(
            jit_positions, jit_charges, jit_cell, jit_pbc, jit_alpha
        )
        jit_energies.block_until_ready()
        jit_forces.block_until_ready()
    # Timed runs
    start_time = time.perf_counter()
    for _ in range(50):
        jit_energies, jit_forces = jit_compute_pme_energy_forces(
            jit_positions, jit_charges, jit_cell, jit_pbc, jit_alpha
        )
        jit_energies.block_until_ready()
        jit_forces.block_until_ready()
    total_time = time.perf_counter() - start_time
    print(f"  JIT average time per call: {total_time / 50:.6f} seconds")

    # Compare with the non-jitted result. Both calls use the same inputs and
    # parameters, so a difference most likely comes from the neighbor matrix
    # being truncated at max_neighbors (see the note below).
    energy_diff_jit = abs(jit_total_energy - total_energy)
    print(f"  Difference vs non-jitted: {energy_diff_jit:.2e}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Compiling and running jitted PME pipeline...
      JIT total energy: -8.585994
      JIT max force: 0.059910
      JIT average time per call: 0.000817 seconds
      Difference vs non-jitted: 9.40e-02

Neither run reproduces the total energy of about -9.744 obtained earlier, and
both report a nonzero maximum force even though the forces in this perfect
lattice vanish by symmetry. The most likely cause is ``max_neighbors=128``: at
this ion density (54 ions in 16.92³ Å³), even a 15 Å sphere holds roughly 150
ions, and the ``accuracy=1e-5`` cutoff is presumably only somewhat shorter than
the 18.25 Å obtained for ``accuracy=1e-6``, so rows of the neighbor matrix are
truncated. If the set of dropped neighbors is not deterministic, that would
also explain why the jitted and non-jitted energies differ. In practice, choose
``max_neighbors`` comfortably above the largest expected neighbor count.

.. GENERATED FROM PYTHON SOURCE LINES 525-544

Summary
-------

This example demonstrated:

1. **Automatic parameter estimation** for alpha and mesh dimensions using
   ``estimate_pme_parameters`` with a target accuracy
2. **Neighbor format flexibility** with COO (list) and dense (matrix) formats
3. **Component access** for the real-space and reciprocal-space terms separately
4. **Charge gradients** (∂E/∂q_i) for ML potential training
5. **JIT compilation** of the full neighbor list + PME pipeline

Key JAX-specific patterns:

- Use ``jnp.float64`` for electrostatics calculations; JAX only provides
  float64 when 64-bit mode is enabled (``jax.config.update("jax_enable_x64", True)``
  or the ``JAX_ENABLE_X64`` environment variable)
- The cell has shape ``(1, 3, 3)``, with a leading batch dimension
- Use ``float()`` to extract scalar values from JAX arrays for printing
- Parameters from ``estimate_pme_parameters`` are JAX arrays
- For ``jax.jit``: estimate parameters outside, pass ``max_neighbors`` and
  ``mesh_dimensions`` as static values, and choose ``max_neighbors`` large
  enough to hold every neighbor within the cutoff

.. GENERATED FROM PYTHON SOURCE LINES 544-555

.. code-block:: Python


    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    print("\nKey takeaways:")
    print("  - Use estimate_pme_parameters() for automatic parameter selection")
    print("  - Both COO and dense neighbor formats produce identical results")
    print("  - Real and reciprocal components can be computed separately")
    print("  - Charge gradients are available for ML potential training")
    print("  - Use jax.jit to fuse neighbor list + PME into one compiled function")
    print("\nJAX PME example completed successfully!")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ======================================================================
    SUMMARY
    ======================================================================

    Key takeaways:
      - Use estimate_pme_parameters() for automatic parameter selection
      - Both COO and dense neighbor formats produce identical results
      - Real and reciprocal components can be computed separately
      - Charge gradients are available for ML potential training
      - Use jax.jit to fuse neighbor list + PME into one compiled function

    JAX PME example completed successfully!


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 20.259 seconds)


.. _sphx_glr_download_examples_electrostatics_04_jax_pme_example.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 04_jax_pme_example.ipynb <04_jax_pme_example.ipynb>`
    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 04_jax_pme_example.py <04_jax_pme_example.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 04_jax_pme_example.zip <04_jax_pme_example.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_