Particle Mesh Ewald (PME) with JAX
This example demonstrates how to compute long-range electrostatic interactions using the Particle Mesh Ewald (PME) method with the JAX backend. PME reaches O(N log N) scaling by interpolating the charges onto a regular mesh and evaluating the reciprocal-space sum with FFTs.
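For reference, the standard Ewald decomposition that PME accelerates is written out below, in units where the Coulomb constant is 1 (the energies printed later in this example are consistent with that convention, but treat it as an assumption). Here the vectors n run over lattice translations, the prime excludes the i = j term when n = 0, and r_ij = r_i - r_j. The real-space term is what the neighbor list serves; PME approximates the structure factor S(k) by spreading charges onto a mesh and applying FFTs. How the library assigns the self term to its individual components is not specified here.

```latex
\begin{aligned}
E &= E_{\text{real}} + E_{\text{recip}} - E_{\text{self}} \\
E_{\text{real}} &= \frac{1}{2} \sum_{i,j} {\sum_{\mathbf{n}}}' \, q_i q_j \,
  \frac{\operatorname{erfc}\!\left(\alpha \lvert \mathbf{r}_{ij} + \mathbf{n} \rvert\right)}
       {\lvert \mathbf{r}_{ij} + \mathbf{n} \rvert} \\
E_{\text{recip}} &= \frac{2\pi}{V} \sum_{\mathbf{k} \neq 0}
  \frac{e^{-k^2 / 4\alpha^2}}{k^2} \, \lvert S(\mathbf{k}) \rvert^2,
  \qquad S(\mathbf{k}) = \sum_j q_j \, e^{i \mathbf{k} \cdot \mathbf{r}_j} \\
E_{\text{self}} &= \frac{\alpha}{\sqrt{\pi}} \sum_i q_i^2
\end{aligned}
```

Increasing alpha shifts work from the real-space sum (shorter cutoff) to the reciprocal-space sum (finer mesh); the total E does not depend on alpha once both sums are converged.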
In this example you will learn:
How to set up and run PME with automatic parameter estimation in JAX
Using neighbor list (COO) and neighbor matrix formats
Accessing real-space and reciprocal-space components separately
Computing charge gradients for ML potential training
jax.jit compilation of the full neighbor list + PME pipeline
Important
This script is intended as an API demonstration. Do not use this script for performance benchmarking; refer to the benchmarks folder instead.
Setup and Imports
Import JAX and the nvalchemiops electrostatics API.
from __future__ import annotations
import sys
import time
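try:
    import jax
    import jax.numpy as jnp
except ImportError:
    print(
        "This example requires JAX. Install with: pip install 'nvalchemi-toolkit-ops[jax]'"
    )
    sys.exit(0)
# The float64 arrays used below require JAX's 64-bit mode; without it,
# jnp.float64 requests are truncated to float32. Enabling it is harmless
# if it is already on.
jax.config.update("jax_enable_x64", True)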
import numpy as np
try:
    from nvalchemiops.jax.interactions.electrostatics import (
        estimate_pme_parameters,
        ewald_real_space,
        particle_mesh_ewald,
        pme_reciprocal_space,
    )
    from nvalchemiops.jax.neighbors import neighbor_list
    from nvalchemiops.jax.neighbors.naive import naive_neighbor_list
    from nvalchemiops.jax.neighbors.neighbor_utils import compute_naive_num_shifts
except Exception as exc:
    print(
        f"JAX/Warp backend unavailable ({exc}). This example requires a CUDA-backed runtime."
    )
    sys.exit(0)
Check Device
print("=" * 70)
print("JAX PME ELECTROSTATICS EXAMPLE")
print("=" * 70)
devices = jax.devices()
print(f"\nJAX devices: {devices}")
print(f"Default device: {jax.default_backend()}")
======================================================================
JAX PME ELECTROSTATICS EXAMPLE
======================================================================
JAX devices: [CudaDevice(id=0)]
Default device: gpu
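Create a NaCl Crystal System
We define a helper function to create supercells of a cubic Na+/Cl- crystal. Note that the two-ion basis used here (Na+ at the cell corner, Cl- at the body center) is the CsCl-type (B2) arrangement with NaCl composition, not the rock-salt (B1) structure; the 5.64 Å rock-salt lattice constant is simply reused as the cubic cell edge.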
def create_nacl_system(n_cells: int = 3, lattice_constant: float = 5.64):
    """Create a supercell of a two-ion cubic Na+/Cl- lattice.

    Each cubic cell holds Na+ at the corner and Cl- at the body center,
    i.e. a CsCl-type (B2) arrangement with NaCl composition.

    Parameters
    ----------
    n_cells : int
        Number of unit cells in each direction.
    lattice_constant : float
        Cubic cell edge in Angstroms (5.64 is the rock-salt NaCl value).

    Returns
    -------
    positions, charges, cell, pbc : jax.Array
        System arrays; positions, charges and cell are float64, pbc is boolean.
    """
    base_positions = np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])
    base_charges = np.array([1.0, -1.0])
    positions_list = []
    charges_list = []
    for i in range(n_cells):
        for j in range(n_cells):
            for k in range(n_cells):
                offset = np.array([i, j, k])
                for pos, charge in zip(base_positions, base_charges):
                    positions_list.append((pos + offset) * lattice_constant)
                    charges_list.append(charge)
    # Convert to JAX arrays with float64 for electrostatics accuracy
    positions = jnp.array(positions_list, dtype=jnp.float64)
    charges = jnp.array(charges_list, dtype=jnp.float64)
    cell = jnp.eye(3, dtype=jnp.float64) * lattice_constant * n_cells
    cell = cell[None, ...]  # Add batch dimension: (1, 3, 3)
    pbc = jnp.array([[True, True, True]])
    return positions, charges, cell, pbc
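The structure can be checked independently of the library. The NumPy sketch below (not part of the original script; `build_lattice` is a hypothetical helper) rebuilds the same positions and measures minimum-image distances from the first Na+ ion. A CsCl-type lattice gives 8 Cl- nearest neighbors at a√3/2 ≈ 4.88 Å, whereas a rock-salt lattice would give 6 neighbors at a/2 = 2.82 Å.

```python
import itertools

import numpy as np


def build_lattice(n_cells=3, a=5.64):
    """NumPy copy of the construction in create_nacl_system."""
    base = np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])
    q_base = np.array([1.0, -1.0])
    pos, q = [], []
    for offset in itertools.product(range(n_cells), repeat=3):
        for p, c in zip(base, q_base):
            pos.append((p + np.array(offset)) * a)
            q.append(c)
    return np.array(pos), np.array(q), a * n_cells


pos_np, q_np, box = build_lattice()

# Minimum-image displacements from atom 0 (Na+) to every other atom
d = pos_np[1:] - pos_np[0]
d -= box * np.round(d / box)
r = np.linalg.norm(d, axis=1)

r_nn = r.min()
nn_mask = np.isclose(r, r_nn)
n_nn = int(nn_mask.sum())
nn_are_cl = bool(np.all(q_np[1:][nn_mask] < 0))
print(f"{len(pos_np)} atoms, nearest-neighbor distance {r_nn:.3f} Å, "
      f"coordination {n_nn}, all Cl-: {nn_are_cl}")
```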
Basic PME with Automatic Parameters
The simplest way to use PME is with automatic parameter estimation.
print("\n" + "=" * 70)
print("BASIC PME WITH AUTOMATIC PARAMETERS")
print("=" * 70)
# Create a NaCl crystal (3×3×3 unit cells = 54 atoms)
positions, charges, cell, pbc = create_nacl_system(n_cells=3)
print(f"\nSystem: {len(positions)} atoms NaCl crystal")
print(f"Cell size: {float(cell[0, 0, 0]):.2f} Å")
print(f"Total charge: {float(charges.sum()):.1f} (should be 0 for neutral)")
======================================================================
BASIC PME WITH AUTOMATIC PARAMETERS
======================================================================
System: 54 atoms NaCl crystal
Cell size: 16.92 Å
Total charge: 0.0 (should be 0 for neutral)
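Estimate PME parameters (the Ewald splitting parameter alpha, the mesh dimensions, and the real-space cutoff) for a target accuracy: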
params = estimate_pme_parameters(positions, cell, accuracy=1e-6)
print("\nEstimated parameters (accuracy=1e-6):")
print(f" alpha = {float(params.alpha[0]):.4f}")
print(f" mesh_dimensions = {params.mesh_dimensions}")
print(
    f" mesh_spacing = ({float(params.mesh_spacing[0, 0]):.2f}, "
    f"{float(params.mesh_spacing[0, 1]):.2f}, {float(params.mesh_spacing[0, 2]):.2f}) Å"
)
print(f" real_space_cutoff = {float(params.real_space_cutoff[0]):.2f} Å")
Estimated parameters (accuracy=1e-6):
alpha = 0.2037
mesh_dimensions = (64, 64, 64)
mesh_spacing = (0.26, 0.26, 0.26) Å
real_space_cutoff = 18.25 Å
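The printed alpha and cutoff happen to match a common heuristic: alpha = √π (N/V²)^(1/6), which is meant to balance real- and reciprocal-space work, and r_c = √(-ln ε)/alpha, which places the real-space tail exp(-(alpha r_c)²) at the target accuracy ε. The sketch below (with a hypothetical helper `heuristic_pme_params`) reproduces 0.2037 and 18.25 Å for this system. That suggests, but does not confirm, that `estimate_pme_parameters` uses the same rule; its mesh-size rule is not reproduced here.

```python
import math


def heuristic_pme_params(n_atoms, volume, accuracy):
    """Assumed heuristic; not necessarily what estimate_pme_parameters does."""
    alpha = math.sqrt(math.pi) * (n_atoms / volume**2) ** (1.0 / 6.0)
    cutoff = math.sqrt(-math.log(accuracy)) / alpha  # exp(-(alpha*rc)^2) == accuracy
    return alpha, cutoff


box_edge = 3 * 5.64  # Å, the 3x3x3 supercell used above
for acc in (1e-4, 1e-5, 1e-6):
    alpha, rc = heuristic_pme_params(54, box_edge**3, acc)
    print(f"accuracy={acc:.0e}: alpha={alpha:.4f} 1/Å, cutoff={rc:.2f} Å")
```

Note that, under this rule, alpha depends only on the atom density and a looser accuracy only shortens the cutoff.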
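Build a neighbor list with the estimated real-space cutoff and run PME. Only accuracy is passed here (no alpha or mesh_dimensions), so particle_mesh_ewald presumably estimates them internally from the same target accuracy: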
cutoff = float(params.real_space_cutoff[0])
nl, nptr, ns = neighbor_list(
    positions,
    cutoff,
    cell=cell,
    pbc=pbc,
    return_neighbor_list=True,
)
energies, forces = particle_mesh_ewald(
    positions=positions,
    charges=charges,
    cell=cell,
    neighbor_list=nl,
    neighbor_ptr=nptr,
    neighbor_shifts=ns,
    compute_forces=True,
    accuracy=1e-6,
)
total_energy = float(energies.sum())
max_force = float(jnp.linalg.norm(forces, axis=1).max())
print("\nPME Results:")
print(f" Total energy: {total_energy:.6f}")
print(f" Energy per atom: {total_energy / len(positions):.6f}")
print(f" Max force magnitude: {max_force:.6f}")
PME Results:
Total energy: -9.743761
Energy per atom: -0.180440
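Max force magnitude: 0.000000
The forces vanish because every ion in this perfect lattice sits at a center of inversion symmetry, so the forces exerted by its neighbors cancel in pairs.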
Neighbor Matrix vs COO Format Comparison
PME accepts both neighbor formats and gives the same results with either, up to floating-point round-off.
print("\n" + "=" * 70)
print("NEIGHBOR FORMAT COMPARISON")
print("=" * 70)
# Build both formats using the estimated real-space cutoff
# COO format (neighbor list)
nl_coo, nptr_coo, ns_coo = neighbor_list(
    positions,
    cutoff,
    cell=cell,
    pbc=pbc,
    return_neighbor_list=True,
)
# Dense format (neighbor matrix)
nm_dense, num_dense, ns_dense = neighbor_list(
    positions,
    cutoff,
    cell=cell,
    pbc=pbc,
    return_neighbor_list=False,
)
print(f"\nUsing alpha={float(params.alpha[0]):.4f}, mesh_dims={params.mesh_dimensions}")
======================================================================
NEIGHBOR FORMAT COMPARISON
======================================================================
Using alpha=0.2037, mesh_dims=(64, 64, 64)
Using neighbor list (COO) format:
energies_coo, forces_coo = particle_mesh_ewald(
    positions=positions,
    charges=charges,
    cell=cell,
    neighbor_list=nl_coo,
    neighbor_ptr=nptr_coo,
    neighbor_shifts=ns_coo,
    compute_forces=True,
    accuracy=1e-6,
)
print(f" COO format: E={float(energies_coo.sum()):.6f}")
COO format: E=-9.743761
Using neighbor matrix (dense) format:
energies_dense, forces_dense = particle_mesh_ewald(
    positions=positions,
    charges=charges,
    cell=cell,
    neighbor_matrix=nm_dense,
    neighbor_matrix_shifts=ns_dense,
    compute_forces=True,
    accuracy=1e-6,
)
print(f" Dense format: E={float(energies_dense.sum()):.6f}")
# Compare results
energy_diff = abs(float(energies_coo.sum()) - float(energies_dense.sum()))
force_diff = float(jnp.abs(forces_coo - forces_dense).max())
print(f"\nEnergy difference: {energy_diff:.2e}")
print(f"Max force difference: {force_diff:.2e}")
Dense format: E=-9.743761
Energy difference: 0.00e+00
Max force difference: 2.64e-17
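To make the two layouts concrete, the sketch below converts between a COO pair list and a padded neighbor matrix in NumPy. The layouts are assumptions for illustration (COO as a (2, P) array of (i, j) pairs sorted by i; dense as an (N, max_neighbors) array padded with a fill value), and the helper names are hypothetical; the library's exact conventions, including how the periodic shift arrays travel alongside, may differ. The explicit capacity check matters: silently truncating a row would drop interactions.

```python
import numpy as np


def coo_to_dense(pairs, n_atoms, max_neighbors, fill_value):
    """(2, P) pair list sorted by i -> (n_atoms, max_neighbors) padded matrix."""
    i, j = pairs
    counts = np.bincount(i, minlength=n_atoms)
    if counts.max(initial=0) > max_neighbors:
        raise ValueError("max_neighbors too small: rows would be truncated")
    matrix = np.full((n_atoms, max_neighbors), fill_value, dtype=pairs.dtype)
    order = np.argsort(i, kind="stable")  # keep the j order within each row
    i_sorted, j_sorted = i[order], j[order]
    row_start = np.concatenate([[0], np.cumsum(counts)])
    slot = np.arange(len(i_sorted)) - row_start[i_sorted]  # column within the row
    matrix[i_sorted, slot] = j_sorted
    return matrix, counts


def dense_to_coo(matrix, fill_value):
    """Padded matrix -> (2, P) pair list, row-major order."""
    i, slot = np.nonzero(matrix != fill_value)
    return np.stack([i, matrix[i, slot]])


pairs = np.array([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
matrix, counts = coo_to_dense(pairs, n_atoms=3, max_neighbors=3, fill_value=3)
print(matrix)                     # padded with 3 (= n_atoms) where a row is short
print(dense_to_coo(matrix, 3))    # recovers the original pairs
```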
Real-Space and Reciprocal-Space Components
The real-space and reciprocal-space contributions can also be computed separately; their sum should match the full PME energy, as checked below.
print("\n" + "=" * 70)
print("ENERGY COMPONENTS")
print("=" * 70)
# Use a looser accuracy (1e-4) here: it gives a shorter real-space cutoff and
# typically a coarser mesh, so the component calculations below are cheaper
params_comp = estimate_pme_parameters(positions, cell, accuracy=1e-4)
cutoff_comp = float(params_comp.real_space_cutoff[0])
nl_comp, nptr_comp, ns_comp = neighbor_list(
    positions,
    cutoff_comp,
    cell=cell,
    pbc=pbc,
    return_neighbor_list=True,
)
======================================================================
ENERGY COMPONENTS
======================================================================
Real-space component (uses same kernel as Ewald):
real_energy = ewald_real_space(
    positions=positions,
    charges=charges,
    cell=cell,
    alpha=params_comp.alpha,
    neighbor_list=nl_comp,
    neighbor_ptr=nptr_comp,
    neighbor_shifts=ns_comp,
)
print(f"\n Real-space: {float(real_energy.sum()):.6f}")
Real-space: -3.549335
PME reciprocal-space component (FFT-based):
recip_energy = pme_reciprocal_space(
    positions=positions,
    charges=charges,
    cell=cell,
    alpha=params_comp.alpha,
    mesh_dimensions=params_comp.mesh_dimensions,
)
print(f" Reciprocal-space (PME): {float(recip_energy.sum()):.6f}")
print(f" Total (sum): {float(real_energy.sum() + recip_energy.sum()):.6f}")
Reciprocal-space (PME): -6.194456
Total (sum): -9.743791
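The reciprocal-space step starts by spreading each point charge onto the mesh with smooth interpolation weights, after which an FFT gives the mesh structure factor. The NumPy sketch below uses cubic (order-4) cardinal B-spline weights on the four nearest grid points per axis. The spline order and grid-point convention are assumptions, not necessarily what `pme_reciprocal_space` uses, and the later steps (multiplying by the Ewald influence function and transforming back) are omitted. The weights sum to one and preserve the first moment, so each charge's total and center carry over to the mesh.

```python
import numpy as np


def cubic_bspline_weights(u):
    """Order-4 B-spline weights for grid points floor(x)-1 .. floor(x)+2.

    u = x - floor(x) is the fractional offset in grid units (array-like).
    """
    u = np.asarray(u, dtype=float)
    return np.stack(
        [
            (1.0 - u) ** 3 / 6.0,
            (3.0 * u**3 - 6.0 * u**2 + 4.0) / 6.0,
            (-3.0 * u**3 + 3.0 * u**2 + 3.0 * u + 1.0) / 6.0,
            u**3 / 6.0,
        ],
        axis=-1,
    )


def spread_charges(frac_positions, charges, mesh_dims):
    """Spread point charges onto a periodic mesh (fractional coordinates)."""
    dims = np.asarray(mesh_dims)
    mesh = np.zeros(tuple(mesh_dims))
    x = np.asarray(frac_positions) * dims       # grid units, shape (N, 3)
    base = np.floor(x).astype(int)
    weights = cubic_bspline_weights(x - base)   # shape (N, 3, 4)
    offsets = np.arange(-1, 3)
    for n, q in enumerate(charges):
        # Periodic grid indices along each axis for this charge
        idx = [(base[n, d] + offsets) % dims[d] for d in range(3)]
        w3d = np.einsum("a,b,c->abc", weights[n, 0], weights[n, 1], weights[n, 2])
        np.add.at(mesh, np.ix_(*idx), q * w3d)  # accumulate, even if indices repeat
    return mesh


frac = np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5], [0.13, 0.71, 0.42]])
q = np.array([1.0, -1.0, 0.25])
rho = spread_charges(frac, q, (8, 8, 8))
rho_k = np.fft.fftn(rho)  # mesh structure factor; PME would apply the influence function next
print(rho.sum(), rho_k[0, 0, 0].real)  # both equal the total charge (0.25) up to round-off
```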
Compare with full PME:
full_pme_energy = particle_mesh_ewald(
    positions=positions,
    charges=charges,
    cell=cell,
    neighbor_list=nl_comp,
    neighbor_ptr=nptr_comp,
    neighbor_shifts=ns_comp,
    accuracy=1e-4,
)
print(f" Full PME: {float(full_pme_energy.sum()):.6f}")
component_diff = abs(
    float(real_energy.sum() + recip_energy.sum()) - float(full_pme_energy.sum())
)
print(f"\n Component sum vs full PME difference: {component_diff:.2e}")
Full PME: -9.743791
Component sum vs full PME difference: 0.00e+00
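As an independent check of the decomposition, the sketch below evaluates the three Ewald terms directly (no mesh) for a single two-ion cell of the same lattice, in units where the Coulomb constant is 1. The total should not depend on alpha, which only moves work between the real- and reciprocal-space terms, and it should come out near -0.3609 per ion pair; 27 pairs give about -9.744, consistent with the PME results above. This is plain Ewald summation with brute-force loops, written for clarity rather than speed, and `ewald_energy` is a hypothetical helper.

```python
import itertools
import math

import numpy as np


def ewald_energy(positions, charges, box_length, alpha, n_images=3, k_max=8):
    """Direct Ewald sum for a cubic box (Coulomb constant = 1)."""
    volume = box_length**3
    n_atoms = len(charges)

    # Real space: erfc-screened pair terms, including periodic images
    e_real = 0.0
    for image in itertools.product(range(-n_images, n_images + 1), repeat=3):
        shift = np.array(image) * box_length
        for i in range(n_atoms):
            for j in range(n_atoms):
                if i == j and image == (0, 0, 0):
                    continue  # skip the self-pair in the home cell
                r = np.linalg.norm(positions[i] - positions[j] + shift)
                e_real += 0.5 * charges[i] * charges[j] * math.erfc(alpha * r) / r

    # Reciprocal space: smooth long-range part from the structure factor S(k)
    e_recip = 0.0
    for m in itertools.product(range(-k_max, k_max + 1), repeat=3):
        if m == (0, 0, 0):
            continue  # k = 0 term is omitted for a neutral cell
        k = 2.0 * math.pi * np.array(m) / box_length
        k2 = float(k @ k)
        s_k = np.sum(charges * np.exp(1j * (positions @ k)))
        e_recip += (2.0 * math.pi / volume) * math.exp(-k2 / (4.0 * alpha**2)) / k2 * abs(s_k) ** 2

    # Self term: removes each screening Gaussian's interaction with its own charge
    e_self = alpha / math.sqrt(math.pi) * float(np.sum(charges**2))
    return e_real + e_recip - e_self


a = 5.64
cell_pos = np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]) * a
cell_q = np.array([1.0, -1.0])
ewald_results = {alpha: ewald_energy(cell_pos, cell_q, a, alpha) for alpha in (0.35, 0.5)}
for alpha, e in ewald_results.items():
    print(f"alpha={alpha}: E per ion pair = {e:.6f}, x27 pairs = {27 * e:.4f}")
```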
Charge Gradients for ML Potentials
PME supports computing analytical charge gradients (∂E/∂q_i), which are useful for training machine learning potentials that predict atomic partial charges.
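Why ∂E/∂q_i is useful: the electrostatic energy is a quadratic form in the charges, E = ½ Σ_ij q_i A_ij q_j, so ∂E/∂q_i = Σ_j A_ij q_j is the electrostatic potential felt by atom i. In training, these are the factors the chain rule multiplies with ∂q_i/∂θ of a charge-predicting model. The NumPy sketch below checks the identity with finite differences for a small open-boundary (non-periodic) cluster; the same reasoning applies to the Ewald/PME energy, which is also quadratic in the charges. The function names are hypothetical.

```python
import numpy as np


def pair_inverse_distances(positions):
    """1/r_ij for all pairs, with the i == j entries set to zero."""
    diff = positions[:, None, :] - positions[None, :, :]
    r = np.linalg.norm(diff, axis=-1)
    np.fill_diagonal(r, np.inf)  # 1/inf = 0 removes self-interaction
    return 1.0 / r


def coulomb_energy(positions, charges):
    """Open-boundary Coulomb energy, Coulomb constant = 1."""
    return 0.5 * charges @ pair_inverse_distances(positions) @ charges


rng = np.random.default_rng(0)
cluster_pos = rng.uniform(0.0, 6.0, size=(5, 3))
cluster_q = rng.normal(size=5)
cluster_q -= cluster_q.mean()  # make the cluster neutral

# phi_i = sum_j q_j / r_ij, the potential at atom i
potentials = pair_inverse_distances(cluster_pos) @ cluster_q

# Central finite differences of E with respect to each charge
h = 1e-5
fd_grads = np.empty_like(cluster_q)
for i in range(len(cluster_q)):
    dq = np.zeros_like(cluster_q)
    dq[i] = h
    e_plus = coulomb_energy(cluster_pos, cluster_q + dq)
    e_minus = coulomb_energy(cluster_pos, cluster_q - dq)
    fd_grads[i] = (e_plus - e_minus) / (2 * h)

print(np.max(np.abs(fd_grads - potentials)))  # should be at round-off level
```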
print("\n" + "=" * 70)
print("CHARGE GRADIENTS")
print("=" * 70)
# Compute PME with charge gradients
energies_cg, forces_cg, charge_grads = particle_mesh_ewald(
    positions=positions,
    charges=charges,
    cell=cell,
    neighbor_list=nl_comp,
    neighbor_ptr=nptr_comp,
    neighbor_shifts=ns_comp,
    compute_forces=True,
    compute_charge_gradients=True,
    accuracy=1e-4,
)
print(f"\n Charge gradients shape: {charge_grads.shape}")
print(
    f" Charge gradients range: [{float(charge_grads.min()):.4f}, "
    f"{float(charge_grads.max()):.4f}]"
)
print(f" Charge gradients mean: {float(charge_grads.mean()):.4f}")
# dE/dq_i is the electrostatic potential felt by atom i. Here every Na+ site
# has the same potential, every Cl- site the opposite one, and the two counts
# are equal, so the sum cancels; this is not a general property of neutral systems
print(f" Sum of charge gradients: {float(charge_grads.sum()):.4e}")
======================================================================
CHARGE GRADIENTS
======================================================================
Charge gradients shape: (54,)
Charge gradients range: [-0.3609, 0.3609]
Charge gradients mean: 0.0000
Sum of charge gradients: 3.3307e-16
Verify by checking gradient symmetry for Na+ and Cl- ions:
na_grads = charge_grads[charges > 0] # Na+ ions
cl_grads = charge_grads[charges < 0] # Cl- ions
print(f"\n Na+ charge gradients mean: {float(na_grads.mean()):.4f}")
print(f" Cl- charge gradients mean: {float(cl_grads.mean()):.4f}")
Na+ charge gradients mean: -0.3609
Cl- charge gradients mean: 0.3609
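The ±0.3609 values can be checked by hand, assuming energies in units where the Coulomb constant is 1. For this two-ion cubic (CsCl-type) lattice, the potential at an ion is ∓M/r_nn, with the CsCl Madelung constant M ≈ 1.762675 and nearest-neighbor distance r_nn = a√3/2 ≈ 4.884 Å. The total E = ½ Σ_i q_i φ_i then reproduces the roughly -9.744 found above for 27 ion pairs.

```python
import math

madelung_cscl = 1.762675      # CsCl Madelung constant, referred to the nearest-neighbor distance
a = 5.64                      # cubic cell edge used in create_nacl_system, in Å
r_nn = a * math.sqrt(3) / 2   # Na+/Cl- nearest-neighbor distance

phi_na = -madelung_cscl / r_nn  # potential at a Na+ site, i.e. dE/dq for Na+
phi_cl = -phi_na                # potential at a Cl- site

n_pairs = 27  # 3 x 3 x 3 cells, one Na+/Cl- pair each
total = 0.5 * n_pairs * (1.0 * phi_na + (-1.0) * phi_cl)  # E = 1/2 sum_i q_i phi_i
print(f"phi(Na+) = {phi_na:.4f}, phi(Cl-) = {phi_cl:.4f}, E = {total:.4f}")
```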
JIT Compilation
This section combines the neighbor list build and the PME calculation into a single jax.jit-compiled function, so the whole pipeline is traced and compiled once and later calls avoid op-by-op Python dispatch.
For JIT compatibility:
max_neighbors must be specified (static array shapes).
mesh_dimensions must be a concrete tuple (static FFT sizes).
alpha can be a traced JAX array.
compute_forces and other boolean flags must be static.
Parameter estimation (estimate_pme_parameters) should happen outside the jitted function, since it determines array shapes.
Periodic shift metadata (shift_range, num_shifts_per_system, max_shifts_per_system) must be precomputed outside jit using compute_naive_num_shifts, since the launch dimensions must be concrete.
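The static-versus-traced distinction can be seen without the library. The toy function below is hypothetical and is not PME (it uses nearest-grid-point assignment and an arbitrary Gaussian-damped 1/k² weight). It marks `mesh_dimensions` as static through `static_argnames`, because that value fixes array shapes and FFT sizes, while `alpha` stays a traced argument. A Python counter that runs only during tracing shows that a new alpha reuses the compiled function, whereas a new mesh size triggers a retrace.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the function


@partial(jax.jit, static_argnames=("mesh_dimensions",))
def toy_mesh_energy(frac_positions, charges, alpha, mesh_dimensions):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    dims = jnp.array(mesh_dimensions)
    # Nearest-grid-point charge assignment (toy stand-in for B-spline spreading)
    idx = jnp.floor(frac_positions * dims).astype(jnp.int32) % dims
    rho = jnp.zeros(mesh_dimensions).at[idx[:, 0], idx[:, 1], idx[:, 2]].add(charges)
    rho_k = jnp.fft.fftn(rho)  # shape fixed by the static mesh_dimensions
    # Angular wave vectors for a unit box, then a Gaussian-damped 1/k^2 weight
    freqs = [2 * jnp.pi * jnp.fft.fftfreq(n, d=1.0 / n) for n in mesh_dimensions]
    kx, ky, kz = jnp.meshgrid(*freqs, indexing="ij")
    k2 = kx**2 + ky**2 + kz**2
    safe_k2 = jnp.where(k2 > 0, k2, 1.0)  # avoid dividing by zero at k = 0
    weight = jnp.where(k2 > 0, jnp.exp(-k2 / (4 * alpha**2)) / safe_k2, 0.0)
    return jnp.sum(weight * jnp.abs(rho_k) ** 2)


frac = jnp.array([[0.1, 0.2, 0.3], [0.6, 0.7, 0.8]])
q = jnp.array([1.0, -1.0])
e1 = toy_mesh_energy(frac, q, 2.0, mesh_dimensions=(8, 8, 8))
e2 = toy_mesh_energy(frac, q, 3.0, mesh_dimensions=(8, 8, 8))     # new alpha: cached
e3 = toy_mesh_energy(frac, q, 2.0, mesh_dimensions=(16, 16, 16))  # new mesh: retrace
print(trace_count)  # 2
```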
print("\n" + "=" * 70)
print("JIT COMPILATION")
print("=" * 70)
# First, estimate parameters outside jit (determines static shapes)
jit_positions, jit_charges, jit_cell, jit_pbc = create_nacl_system(n_cells=3)
jit_params = estimate_pme_parameters(jit_positions, jit_cell, accuracy=1e-5)
jit_cutoff = float(jit_params.real_space_cutoff[0])
jit_mesh_dims = tuple(int(x) for x in jit_params.mesh_dimensions)
jit_alpha = jit_params.alpha
# Pre-compute shift metadata outside jit (launch sizes must be concrete)
shift_range, num_shifts_per_system, max_shifts_per_system = compute_naive_num_shifts(
    jit_cell, jit_cutoff, jit_pbc
)


# Define a function that builds neighbors and computes PME.
# We will compare the performance of the jitted and non-jitted versions.
def compute_pme_energy_forces(
    positions: jax.Array,
    charges: jax.Array,
    cell: jax.Array,
    pbc: jax.Array,
    alpha: jax.Array,
    shift_range: jax.Array = shift_range,
    num_shifts_per_system: jax.Array = num_shifts_per_system,
    cutoff: float = jit_cutoff,
    max_neighbors: int = 128,
    max_shifts_per_system: int = max_shifts_per_system,
    mesh_dimensions: tuple[int, int, int] = jit_mesh_dims,
) -> tuple[jax.Array, jax.Array]:
    """Neighbor list + PME pipeline; wrapped with jax.jit below."""
    # Build neighbor matrix inside jit (max_neighbors must be static,
    # shift metadata pre-computed outside jit)
    neighbor_matrix, _, neighbor_matrix_shifts = naive_neighbor_list(
        positions,
        cutoff,
        cell=cell,
        pbc=pbc,
        max_neighbors=max_neighbors,
        shift_range_per_dimension=shift_range,
        num_shifts_per_system=num_shifts_per_system,
        max_shifts_per_system=max_shifts_per_system,
    )
    # Compute PME (mesh_dimensions is static, alpha is traced)
    energies, forces = particle_mesh_ewald(
        positions=positions,
        charges=charges,
        cell=cell,
        alpha=alpha,
        mesh_dimensions=mesh_dimensions,
        neighbor_matrix=neighbor_matrix,
        neighbor_matrix_shifts=neighbor_matrix_shifts,
        compute_forces=True,
    )
    return energies, forces


jit_compute_pme_energy_forces = jax.jit(compute_pme_energy_forces)
======================================================================
JIT COMPILATION
======================================================================
Run the non-jitted function:
energies, forces = compute_pme_energy_forces(
    jit_positions, jit_charges, jit_cell, jit_pbc, jit_alpha
)
total_energy = float(energies.sum())
max_force = float(jnp.linalg.norm(forces, axis=1).max())
print(f" Non-jitted total energy: {total_energy:.6f}")
print(f" Non-jitted max force: {max_force:.6f}")
# Measure performance
# Warm-up calls (not timed)
for _ in range(10):
    energies, forces = compute_pme_energy_forces(
        jit_positions, jit_charges, jit_cell, jit_pbc, jit_alpha
    )
    energies.block_until_ready()
    forces.block_until_ready()
# Timed calls; block_until_ready() waits for asynchronously dispatched work
start_time = time.perf_counter()
for _ in range(50):
    energies, forces = compute_pme_energy_forces(
        jit_positions, jit_charges, jit_cell, jit_pbc, jit_alpha
    )
    energies.block_until_ready()
    forces.block_until_ready()
total_time = time.perf_counter() - start_time
print(f" Non-jitted average time per call: {total_time / 50:.6f} seconds")
Non-jitted total energy: -8.492044
Non-jitted max force: 0.059720
Non-jitted average time per call: 0.142695 seconds
Run the jitted function:
print("\nCompiling and running jitted PME pipeline...")
jit_energies, jit_forces = jit_compute_pme_energy_forces(
    jit_positions, jit_charges, jit_cell, jit_pbc, jit_alpha
)
jit_total_energy = float(jit_energies.sum())
jit_max_force = float(jnp.linalg.norm(jit_forces, axis=1).max())
print(f" JIT total energy: {jit_total_energy:.6f}")
print(f" JIT max force: {jit_max_force:.6f}")
# Measure performance
# Warm-up calls (not timed; the first call above already compiled the function)
for _ in range(10):
    jit_energies, jit_forces = jit_compute_pme_energy_forces(
        jit_positions, jit_charges, jit_cell, jit_pbc, jit_alpha
    )
    jit_energies.block_until_ready()
    jit_forces.block_until_ready()
# Timed calls; block_until_ready() waits for asynchronously dispatched work
start_time = time.perf_counter()
for _ in range(50):
    jit_energies, jit_forces = jit_compute_pme_energy_forces(
        jit_positions, jit_charges, jit_cell, jit_pbc, jit_alpha
    )
    jit_energies.block_until_ready()
    jit_forces.block_until_ready()
total_time = time.perf_counter() - start_time
print(f" JIT average time per call: {total_time / 50:.6f} seconds")
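# Both calls use identical inputs and parameters, so this difference is not an
# accuracy effect. A likely cause is neighbor truncation: at this cutoff each
# ion probably has more than max_neighbors=128 neighbors, so some pairs are
# dropped, and which ones are dropped may differ between runs. The nonzero
# forces in a perfect lattice and the gap to the ~-9.744 energy obtained above
# point the same way; a larger max_neighbors should remove the discrepancy.
energy_diff_jit = abs(jit_total_energy - total_energy)
print(f" Difference vs non-jitted: {energy_diff_jit:.2e}")
Compiling and running jitted PME pipeline...
JIT total energy: -8.585994
JIT max force: 0.059910
JIT average time per call: 0.000817 seconds
Difference vs non-jitted: 9.40e-02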
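A rough way to size max_neighbors: the sketch below counts, by brute force over periodic images, how many ions of this lattice lie within a given cutoff of one Na+ ion. It assumes each periodic image within the cutoff takes its own slot in the neighbor matrix, as the separate shift arrays suggest. At the 18.25 Å cutoff printed earlier the count is far above 128; the accuracy=1e-5 cutoff used in the JIT section is not printed, but passing jit_cutoff to the same function would show whether 128 slots suffice. `count_neighbors` is a hypothetical helper.

```python
import itertools
import math

import numpy as np


def count_neighbors(cutoff, a=5.64):
    """Ions within `cutoff` (Å) of a Na+ ion at the origin in the two-ion cubic lattice."""
    n_max = math.ceil(cutoff / a) + 1  # enough periodic images to cover the sphere
    basis = np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]) * a  # Na+, Cl-
    count = 0
    for cell in itertools.product(range(-n_max, n_max + 1), repeat=3):
        for site in basis:
            r = np.linalg.norm(np.array(cell) * a + site)
            if 0.0 < r <= cutoff:  # exclude the central ion itself
                count += 1
    return count


for rc in (4.9, 10.0, 18.25):
    print(f"cutoff {rc:5.2f} Å -> {count_neighbors(rc)} neighbors")
```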
Summary
This example demonstrated:
Automatic parameter estimation for alpha and mesh dimensions using estimate_pme_parameters with a target accuracy
Neighbor format flexibility with COO (list) and dense (matrix) formats
Component access for real-space and reciprocal-space separately
Charge gradients (∂E/∂q_i) for ML potential training
JIT compilation of the full neighbor list + PME pipeline
Key JAX-specific patterns:
Use jnp.float64 for electrostatics calculations (this requires JAX's 64-bit mode, jax_enable_x64)
Cell shape is (1, 3, 3), with a batch dimension
Use float() to extract scalar values from JAX arrays for printing
Parameters from estimate_pme_parameters are JAX arrays
For jax.jit: estimate parameters outside the jitted function, and pass max_neighbors and mesh_dimensions as static values
print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)
print("\nKey takeaways:")
print(" - Use estimate_pme_parameters() for automatic parameter selection")
print(" - Both COO and dense neighbor formats produce identical results")
print(" - Real and reciprocal components can be computed separately")
print(" - Charge gradients are available for ML potential training")
print(" - Use jax.jit to fuse neighbor list + PME into one compiled function")
print("\nJAX PME example completed successfully!")
======================================================================
SUMMARY
======================================================================
Key takeaways:
- Use estimate_pme_parameters() for automatic parameter selection
- Both COO and dense neighbor formats produce identical results
- Real and reciprocal components can be computed separately
- Charge gradients are available for ML potential training
- Use jax.jit to fuse neighbor list + PME into one compiled function
JAX PME example completed successfully!
Total running time of the script: (0 minutes 20.259 seconds)