nvalchemiops.jax.interactions.electrostatics: Electrostatics
The electrostatics module provides GPU-accelerated implementations of
long-range electrostatic interactions for molecular simulations with JAX bindings.
These functions accept standard jax.Array inputs.
Tip
For the underlying framework-agnostic Warp kernels, see nvalchemiops.interactions.electrostatics: Electrostatic Interactions (Warp).
High-Level Interface
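These are the primary entry points for most users, and they should be compatible with jax.jit. Under JIT, inputs that determine array shapes, such as miller_bounds and max_atoms_per_system, must be provided explicitly (see the parameter descriptions below).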
- nvalchemiops.jax.interactions.electrostatics.ewald_summation(positions, charges, cell, alpha=None, k_vectors=None, k_cutoff=None, miller_bounds=None, batch_idx=None, max_atoms_per_system=None, neighbor_list=None, neighbor_ptr=None, neighbor_shifts=None, neighbor_matrix=None, neighbor_matrix_shifts=None, mask_value=None, compute_forces=False, compute_virial=False, accuracy=1e-6)
Compute complete Ewald summation (real + reciprocal space).
- The Ewald method splits the long-range Coulomb energy into:
E_total = E_real + E_reciprocal - E_self - E_background
- Parameters:
positions (jax.Array, shape (N, 3)) – Atomic coordinates.
charges (jax.Array, shape (N,)) – Atomic partial charges.
cell (jax.Array, shape (1, 3, 3) or (B, 3, 3)) – Unit cell matrices.
alpha (float or jax.Array or None) – Ewald splitting parameter. If None, estimated automatically.
k_vectors (jax.Array | None) – Reciprocal lattice vectors. If None, generated automatically. Shape (K, 3) for single system, (B, K, 3) for batch.
k_cutoff (float | None) – K-space cutoff. Used only if k_vectors is None.
miller_bounds (tuple[int, int, int] | None, optional) – Precomputed maximum Miller indices (M_h, M_k, M_l). Forwarded to generate_k_vectors_ewald_summation() when k_vectors is None. When provided, makes k-vector generation compatible with jax.jit. Use generate_miller_indices() to precompute. Ignored when k_vectors is explicitly provided.
batch_idx (jax.Array | None, shape (N,)) – System index for each atom.
max_atoms_per_system (int | None, optional) – Maximum number of atoms in any single system in the batch. Required when using jax.jit with batched inputs. If None, inferred from data (fails under JIT).
neighbor_list (jax.Array | None, shape (2, M)) – Neighbor list in COO format.
neighbor_ptr (jax.Array | None, shape (N+1,)) – CSR row pointers for neighbor list.
neighbor_shifts (jax.Array | None, shape (M, 3)) – Periodic image shifts for neighbor list.
neighbor_matrix (jax.Array | None, shape (N, max_neighbors)) – Dense neighbor matrix format.
neighbor_matrix_shifts (jax.Array | None, shape (N, max_neighbors, 3)) – Periodic image shifts for neighbor_matrix.
mask_value (int | None) – Value indicating invalid entries in neighbor_matrix.
compute_forces (bool, default=False) – Whether to compute forces.
compute_virial (bool, default=False) – Whether to compute the virial tensor.
accuracy (float, default=1e-6) – Target accuracy for automatic parameter estimation.
- Returns:
energies (jax.Array, shape (N,)) – Per-atom total Ewald energy.
forces (jax.Array, shape (N, 3), optional) – Forces (if compute_forces=True).
virial (jax.Array, shape (1, 3, 3) or (B, 3, 3), optional) – Virial tensor (if compute_virial=True). Always last in the return tuple.
- Return type:
Array | tuple[Array, Array] | tuple[Array, Array, Array]
Examples
>>> # Complete Ewald summation with automatic parameters
>>> energies, forces = ewald_summation(
...     positions, charges, cell,
...     neighbor_list=nl, neighbor_ptr=neighbor_ptr, neighbor_shifts=shifts,
...     accuracy=1e-6,
...     compute_forces=True,
... )
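For testing the GPU kernels on small systems, a slow brute-force reference of the decomposition E_total = E_real + E_reciprocal - E_self - E_background can be useful. The following is a minimal sketch in plain NumPy, not part of nvalchemiops. It assumes a Coulomb constant of 1 (Gaussian-style units), a full (not half-space) k-space sum, and the G(k) normalization documented for pme_green_structure_factor; check the library's unit convention before comparing numbers. The helper name ewald_reference_energy is hypothetical.

```python
import itertools
import math

import numpy as np

# NumPy has no erfc; otypes avoids errors when a periodic image contributes no pairs.
erfc = np.vectorize(math.erfc, otypes=[float])


def ewald_reference_energy(positions, charges, cell, alpha, r_cut, k_cut):
    """Total E_real + E_reciprocal - E_self - E_background (Coulomb constant = 1)."""
    positions = np.asarray(positions, dtype=np.float64)
    q = np.asarray(charges, dtype=np.float64)
    cell = np.asarray(cell, dtype=np.float64)               # lattice vectors as rows
    volume = abs(np.linalg.det(cell))
    recip = 2.0 * np.pi * np.linalg.inv(cell).T             # rows are a*, b*, c*

    # Real space: erfc-damped pairs over all periodic images within r_cut.
    heights = 2.0 * np.pi / np.linalg.norm(recip, axis=1)   # spacing between lattice planes
    n_img = [int(math.ceil(r_cut / h)) + 1 for h in heights]
    qq = q[:, None] * q[None, :]
    e_real = 0.0
    for shift in itertools.product(*(range(-m, m + 1) for m in n_img)):
        offset = np.asarray(shift, dtype=np.float64) @ cell
        r = np.linalg.norm(positions[None, :, :] - positions[:, None, :] + offset, axis=-1)
        mask = (r > 0.0) & (r <= r_cut)                     # r > 0 skips i == j in the home cell
        e_real += 0.5 * np.sum(qq[mask] * erfc(alpha * r[mask]) / r[mask])

    # Reciprocal space: full k-space sum with G(k) = (2 pi / V) exp(-k^2 / (4 alpha^2)) / k^2.
    m_max = [int(math.ceil(k_cut * np.linalg.norm(a) / (2.0 * np.pi))) for a in cell]
    e_recip = 0.0
    for hkl in itertools.product(*(range(-m, m + 1) for m in m_max)):
        k = np.asarray(hkl, dtype=np.float64) @ recip
        k2 = float(k @ k)
        if k2 == 0.0 or k2 > k_cut**2:
            continue
        s = np.sum(q * np.exp(1j * (positions @ k)))        # structure factor S(k)
        e_recip += 2.0 * np.pi / volume * math.exp(-k2 / (4.0 * alpha**2)) / k2 * abs(s) ** 2

    e_self = alpha / math.sqrt(math.pi) * np.sum(q**2)
    e_background = math.pi * np.sum(q) ** 2 / (2.0 * alpha**2 * volume)
    return float(e_real + e_recip - e_self - e_background)


# Rock-salt NaCl: conventional cubic cell with nearest-neighbour distance 1.
na = [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0)]
cl = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)]
positions = np.array(na + cl, dtype=np.float64)
charges = np.array([1.0] * 4 + [-1.0] * 4)
cell = 2.0 * np.eye(3)
energy = ewald_reference_energy(positions, charges, cell, alpha=1.5, r_cut=5.0, k_cut=18.0)
# energy should be close to -4 * 1.747565 (4 ion pairs, NaCl Madelung constant).
```

Once both sums are converged the total must not depend on alpha, so evaluating it at two alpha values (with cutoffs scaled accordingly) is a quick self-consistency check.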
- nvalchemiops.jax.interactions.electrostatics.particle_mesh_ewald(positions, charges, cell, alpha=None, mesh_spacing=None, mesh_dimensions=None, spline_order=4, batch_idx=None, k_vectors=None, k_squared=None, neighbor_list=None, neighbor_ptr=None, neighbor_shifts=None, neighbor_matrix=None, neighbor_matrix_shifts=None, mask_value=None, compute_forces=False, compute_charge_gradients=False, compute_virial=False, accuracy=1e-6)
Complete Particle Mesh Ewald (PME) calculation for long-range electrostatics.
Computes the total Coulomb energy using the PME method, which achieves O(N log N) scaling through FFT-based reciprocal-space calculations. It combines:
1. Real-space contribution (short-range, erfc-damped)
2. Reciprocal-space contribution (long-range, FFT + B-spline interpolation)
3. Self-energy and background corrections
- Total Energy Formula:
E_total = E_real + E_reciprocal - E_self - E_background
- Parameters:
positions (jax.Array, shape (N, 3)) – Atomic coordinates.
charges (jax.Array, shape (N,)) – Atomic partial charges in elementary charge units.
cell (jax.Array, shape (3, 3) or (B, 3, 3)) – Unit cell matrices with lattice vectors as rows. Shape (3, 3) is automatically promoted to (1, 3, 3) for single-system mode.
alpha (float, jax.Array, or None, default=None) – Ewald splitting parameter controlling the real/reciprocal-space balance. A float applies the same α to all systems; an array of shape (B,) gives per-system α values; None estimates α automatically with the Kolafa-Perram formula.
mesh_spacing (float, optional) – Target mesh spacing. Mesh dimensions are computed as ceil(cell_length / mesh_spacing).
mesh_dimensions (tuple[int, int, int], optional) – Explicit FFT mesh dimensions (nx, ny, nz). Power-of-2 values recommended.
spline_order (int, default=4) – B-spline interpolation order (4 = cubic B-splines, recommended).
batch_idx (jax.Array, shape (N,), dtype=int32, optional) – System index for each atom (0 to B-1). Determines the execution mode: None selects the single-system optimized kernels; an array selects the batched kernels for multiple independent systems.
k_vectors (jax.Array, optional) – Precomputed k-vectors from generate_k_vectors_pme. Providing this along with k_squared skips k-vector generation.
k_squared (jax.Array, optional) – Precomputed k² values from generate_k_vectors_pme.
neighbor_list (jax.Array, optional) – CSR-format neighbor list indices. See ewald_real_space.
neighbor_ptr (jax.Array, optional) – CSR-format neighbor list pointers. See ewald_real_space.
neighbor_shifts (jax.Array, optional) – Periodic image shifts for neighbor list. See ewald_real_space.
neighbor_matrix (jax.Array, optional) – Dense neighbor matrix. Alternative to CSR format.
neighbor_matrix_shifts (jax.Array, optional) – Shifts for dense neighbor matrix.
mask_value (int, optional) – Mask value for invalid neighbors in dense format.
compute_forces (bool, default=False) – If True, compute per-atom forces.
compute_charge_gradients (bool, default=False) – If True, compute per-atom charge gradients dE/dq.
compute_virial (bool, default=False) – If True, compute the virial tensor W = -dE/d(epsilon). Stress = virial / volume.
accuracy (float, default=1e-6) – Target accuracy for automatic parameter estimation.
- Returns:
energies (jax.Array, shape (N,)) – Per-atom total electrostatic energies.
forces (jax.Array, shape (N, 3), optional) – Per-atom forces (only if compute_forces=True).
charge_gradients (jax.Array, shape (N,), optional) – Per-atom charge gradients (only if compute_charge_gradients=True).
virial (jax.Array, shape (1, 3, 3) or (B, 3, 3), optional) – Virial tensor (only if compute_virial=True). Always last in the return tuple.
- Return type:
Array | tuple[Array, Array] | tuple[Array, Array, Array] | tuple[Array, Array, Array, Array]
Notes
- Automatic Parameter Estimation (when alpha is None):
Uses the Kolafa-Perram formula to choose α and the mesh dimensions for the requested accuracy.
Examples
Basic usage:
>>> energies = particle_mesh_ewald(
...     positions, charges, cell, alpha=0.3,
...     mesh_dimensions=(32, 32, 32),
...     neighbor_list=nl, neighbor_ptr=ptr, neighbor_shifts=shifts,
... )
With forces and automatic parameters:
>>> energies, forces = particle_mesh_ewald(
...     positions, charges, cell,
...     mesh_spacing=1.0, accuracy=1e-5,
...     neighbor_list=nl, neighbor_ptr=ptr, neighbor_shifts=shifts,
...     compute_forces=True,
... )
Batched systems:
>>> energies = particle_mesh_ewald(
...     positions, charges, cell,
...     batch_idx=batch_idx,
...     neighbor_list=nl, neighbor_ptr=ptr, neighbor_shifts=shifts,
... )
See also
pme_reciprocal_space – Reciprocal-space component only
ewald_real_space – Real-space component
estimate_pme_parameters – Automatic parameter estimation
Coulomb Interactions
Direct pairwise Coulomb interactions.
- nvalchemiops.jax.interactions.electrostatics.coulomb_energy(positions, charges, cell, cutoff, alpha=0.0, neighbor_list=None, neighbor_ptr=None, neighbor_shifts=None, neighbor_matrix=None, neighbor_matrix_shifts=None, fill_value=None, batch_idx=None)
Compute Coulomb electrostatic energies.
Computes pairwise electrostatic energies using Coulomb’s law, with optional erfc damping for Ewald/PME real-space calculations.
- Parameters:
positions (jax.Array, shape (N, 3)) – Atomic coordinates.
charges (jax.Array, shape (N,)) – Atomic charges.
cell (jax.Array, shape (1, 3, 3) or (B, 3, 3)) – Unit cell matrix. Shape (B, 3, 3) for batched calculations.
cutoff (float) – Cutoff distance for interactions.
alpha (float, default=0.0) – Ewald splitting parameter. Use 0.0 for undamped Coulomb.
neighbor_list (jax.Array | None, shape (2, num_pairs)) – Neighbor pairs in COO format. Row 0 = source, Row 1 = target.
neighbor_ptr (jax.Array | None, shape (N+1,)) – CSR row pointers for neighbor list. Required with neighbor_list. Provided by the neighborlist module.
neighbor_shifts (jax.Array | None, shape (num_pairs, 3)) – Integer unit cell shifts for neighbor list format.
neighbor_matrix (jax.Array | None, shape (N, max_neighbors)) – Neighbor indices in matrix format.
neighbor_matrix_shifts (jax.Array | None, shape (N, max_neighbors, 3)) – Integer unit cell shifts for matrix format.
fill_value (int | None) – Fill value for neighbor matrix padding.
batch_idx (jax.Array | None, shape (N,)) – Batch indices for each atom.
- Returns:
energies – Per-atom energies. Sum to get total energy.
- Return type:
jax.Array, shape (N,)
Examples
>>> # Direct Coulomb (undamped)
>>> energies = coulomb_energy(
...     positions, charges, cell, cutoff=10.0, alpha=0.0,
...     neighbor_list=neighbor_list, neighbor_ptr=neighbor_ptr,
...     neighbor_shifts=neighbor_shifts
... )
>>> total_energy = energies.sum()
>>> # Ewald/PME real-space (damped)
>>> energies = coulomb_energy(
...     positions, charges, cell, cutoff=10.0, alpha=0.3,
...     neighbor_list=neighbor_list, neighbor_ptr=neighbor_ptr,
...     neighbor_shifts=neighbor_shifts
... )
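To make the per-atom energies concrete, here is a minimal NumPy sketch of the sum coulomb_energy evaluates over a COO neighbor list. It is an illustration only, with a hypothetical helper name, and it assumes a full neighbor list (each pair stored in both directions, hence the factor 1/2 per directed pair), displacements r_j - r_i + shift @ cell, a (3, 3) cell, and a Coulomb constant of 1; the library kernels may use different conventions.

```python
import math

import numpy as np


def coulomb_energy_coo(positions, charges, cell, neighbor_list, neighbor_shifts, alpha=0.0):
    """Per-atom Coulomb energies, erfc-damped when alpha > 0 (Coulomb constant = 1)."""
    src, dst = neighbor_list                                   # row 0 = source, row 1 = target
    d = positions[dst] - positions[src] + neighbor_shifts @ cell   # periodic image via integer shifts
    r = np.linalg.norm(d, axis=1)
    damping = np.array([math.erfc(alpha * x) for x in r])     # erfc(0) = 1 gives plain Coulomb
    pair = charges[src] * charges[dst] * damping / r
    energies = np.zeros_like(charges)
    np.add.at(energies, src, 0.5 * pair)                       # each pair is stored twice
    return energies


positions = np.array([[1.0, 1.0, 1.0], [3.0, 1.0, 1.0]])
charges = np.array([1.0, -2.0])
cell = 20.0 * np.eye(3)
neighbor_list = np.array([[0, 1], [1, 0]])
neighbor_shifts = np.zeros((2, 3), dtype=np.int64)
energies = coulomb_energy_coo(positions, charges, cell, neighbor_list, neighbor_shifts)
total = energies.sum()                                          # q0 * q1 / r = -2 / 2
```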
- nvalchemiops.jax.interactions.electrostatics.coulomb_forces(positions, charges, cell, cutoff, alpha=0.0, neighbor_list=None, neighbor_ptr=None, neighbor_shifts=None, neighbor_matrix=None, neighbor_matrix_shifts=None, fill_value=None, batch_idx=None)
Compute Coulomb electrostatic forces.
Convenience wrapper that returns only forces (no energies).
- Parameters:
See coulomb_energy for parameter descriptions.
positions (Array)
charges (Array)
cell (Array)
cutoff (float)
alpha (float)
neighbor_list (Array | None)
neighbor_ptr (Array | None)
neighbor_shifts (Array | None)
neighbor_matrix (Array | None)
neighbor_matrix_shifts (Array | None)
fill_value (int | None)
batch_idx (Array | None)
- Returns:
forces – Forces on each atom.
- Return type:
jax.Array, shape (N, 3)
See also
coulomb_energy_forces – Compute both energies and forces
- nvalchemiops.jax.interactions.electrostatics.coulomb_energy_forces(positions, charges, cell, cutoff, alpha=0.0, neighbor_list=None, neighbor_ptr=None, neighbor_shifts=None, neighbor_matrix=None, neighbor_matrix_shifts=None, fill_value=None, batch_idx=None)
Compute Coulomb electrostatic energies and forces.
Computes pairwise electrostatic energies and forces using Coulomb’s law, with optional erfc damping for Ewald/PME real-space calculations.
- Parameters:
positions (jax.Array, shape (N, 3)) – Atomic coordinates.
charges (jax.Array, shape (N,)) – Atomic charges.
cell (jax.Array, shape (1, 3, 3) or (B, 3, 3)) – Unit cell matrix. Shape (B, 3, 3) for batched calculations.
cutoff (float) – Cutoff distance for interactions.
alpha (float, default=0.0) – Ewald splitting parameter. Use 0.0 for undamped Coulomb.
neighbor_list (jax.Array | None, shape (2, num_pairs)) – Neighbor pairs in COO format.
neighbor_ptr (jax.Array | None, shape (N+1,)) – CSR row pointers for neighbor list. Required with neighbor_list. Provided by the neighborlist module.
neighbor_shifts (jax.Array | None, shape (num_pairs, 3)) – Integer unit cell shifts for neighbor list format.
neighbor_matrix (jax.Array | None, shape (N, max_neighbors)) – Neighbor indices in matrix format.
neighbor_matrix_shifts (jax.Array | None, shape (N, max_neighbors, 3)) – Integer unit cell shifts for matrix format.
fill_value (int | None) – Fill value for neighbor matrix padding.
batch_idx (jax.Array | None, shape (N,)) – Batch indices for each atom.
- Returns:
energies (jax.Array, shape (N,)) – Per-atom energies.
forces (jax.Array, shape (N, 3)) – Forces on each atom.
- Return type:
tuple[jax.Array, jax.Array]
Examples
>>> # Direct Coulomb
>>> energies, forces = coulomb_energy_forces(
...     positions, charges, cell, cutoff=10.0, alpha=0.0,
...     neighbor_list=neighbor_list, neighbor_ptr=neighbor_ptr,
...     neighbor_shifts=neighbor_shifts
... )
>>> # Ewald/PME real-space
>>> energies, forces = coulomb_energy_forces(
...     positions, charges, cell, cutoff=10.0, alpha=0.3,
...     neighbor_matrix=neighbor_matrix, neighbor_matrix_shifts=neighbor_matrix_shifts,
...     fill_value=num_atoms
... )
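The forces follow from differentiating the damped pair energy E(r) = q_i q_j erfc(αr)/r. The sketch below writes out that derivative for a single pair in plain NumPy; the helper name is hypothetical and the Coulomb constant is assumed to be 1. With alpha = 0 it reduces to the bare Coulomb force q_i q_j r̂ / r².

```python
import math

import numpy as np


def damped_coulomb_pair(r_vec, qi, qj, alpha):
    """Energy and force on atom j for one pair, with r_vec = r_j - r_i (Coulomb constant = 1).

    E(r) = qi qj erfc(alpha r) / r
    F_j  = -dE/dr * r_hat
         = qi qj [erfc(alpha r) / r^2 + (2 alpha / sqrt(pi)) exp(-alpha^2 r^2) / r] * r_hat
    The force on atom i is -F_j (Newton's third law).
    """
    r_vec = np.asarray(r_vec, dtype=np.float64)
    r = float(np.linalg.norm(r_vec))
    energy = qi * qj * math.erfc(alpha * r) / r
    magnitude = qi * qj * (math.erfc(alpha * r) / r**2
                           + 2.0 * alpha / math.sqrt(math.pi) * math.exp(-(alpha * r) ** 2) / r)
    return energy, magnitude * r_vec / r


energy, force_j = damped_coulomb_pair([1.2, -0.7, 2.1], 0.8, -1.3, alpha=0.35)
```

A central finite difference of the energy with respect to r_vec is an easy way to confirm the analytic force.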
Ewald Components
Individual components of the Ewald summation method.
- nvalchemiops.jax.interactions.electrostatics.ewald_real_space(positions, charges, cell, alpha, neighbor_list=None, neighbor_ptr=None, neighbor_shifts=None, neighbor_matrix=None, neighbor_matrix_shifts=None, mask_value=None, batch_idx=None, compute_forces=False, compute_charge_gradients=False, compute_virial=False)
Compute real-space Ewald energy and optionally forces, charge gradients, and virial.
Computes the damped Coulomb interactions for atom pairs within the real-space cutoff. The complementary error function (erfc) damping ensures rapid convergence in real space.
- Parameters:
positions (jax.Array, shape (N, 3)) – Atomic coordinates.
charges (jax.Array, shape (N,)) – Atomic partial charges.
cell (jax.Array, shape (1, 3, 3) or (B, 3, 3)) – Unit cell matrices.
alpha (float or jax.Array) – Ewald splitting parameter. Can be a float or array of shape (1,) or (B,).
neighbor_list (jax.Array | None, shape (2, M)) – Neighbor list in COO format.
neighbor_ptr (jax.Array | None, shape (N+1,)) – CSR row pointers for neighbor list.
neighbor_shifts (jax.Array | None, shape (M, 3)) – Periodic image shifts for neighbor list.
neighbor_matrix (jax.Array | None, shape (N, max_neighbors)) – Dense neighbor matrix format.
neighbor_matrix_shifts (jax.Array | None, shape (N, max_neighbors, 3)) – Periodic image shifts for neighbor_matrix.
mask_value (int | None, optional) – Value indicating invalid entries in neighbor_matrix. If None (default), uses num_atoms as the mask value.
batch_idx (jax.Array | None, shape (N,)) – System index for each atom.
compute_forces (bool, default=False) – Whether to compute explicit forces.
compute_charge_gradients (bool, default=False) – Whether to compute charge gradients.
compute_virial (bool, default=False) – Whether to compute the virial tensor.
- Returns:
energies (jax.Array, shape (N,)) – Per-atom real-space energy.
forces (jax.Array, shape (N, 3), optional) – Forces (if compute_forces=True or compute_charge_gradients=True).
charge_gradients (jax.Array, shape (N,), optional) – Charge gradients (if compute_charge_gradients=True).
virial (jax.Array, shape (1, 3, 3) or (B, 3, 3), optional) – Virial tensor (if compute_virial=True). Always last in the return tuple.
- Return type:
Array | tuple[Array, Array] | tuple[Array, Array, Array] | tuple[Array, Array, Array, Array]
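The dense neighbor_matrix/mask_value format and the COO + CSR format (neighbor_list, neighbor_ptr, neighbor_shifts) describe the same pairs. As an illustration of how they relate, here is a hypothetical NumPy conversion. It assumes that neighbor_ptr[i]:neighbor_ptr[i + 1] slices the pairs whose source is atom i, and it uses the documented default mask_value = num_atoms.

```python
import numpy as np


def neighbor_matrix_to_coo(neighbor_matrix, neighbor_matrix_shifts, mask_value=None):
    """Convert a padded neighbor matrix to COO pairs plus CSR row pointers."""
    num_atoms = neighbor_matrix.shape[0]
    if mask_value is None:
        mask_value = num_atoms                                  # documented default padding value
    valid = neighbor_matrix != mask_value
    rows, cols = np.nonzero(valid)                              # row-major: pairs grouped by source atom
    neighbor_list = np.stack([rows, neighbor_matrix[rows, cols]])        # shape (2, M)
    neighbor_shifts = neighbor_matrix_shifts[rows, cols]                 # shape (M, 3)
    neighbor_ptr = np.concatenate([[0], np.cumsum(valid.sum(axis=1))])   # shape (N + 1,)
    return neighbor_list, neighbor_ptr, neighbor_shifts


# Three atoms, at most three neighbors each; 3 (= num_atoms) marks padding.
neighbor_matrix = np.array([[1, 2, 3], [0, 3, 3], [0, 3, 3]])
neighbor_matrix_shifts = np.zeros((3, 3, 3), dtype=np.int64)
neighbor_matrix_shifts[0, 1] = [1, 0, 0]
nl, ptr, shifts = neighbor_matrix_to_coo(neighbor_matrix, neighbor_matrix_shifts)
```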
- nvalchemiops.jax.interactions.electrostatics.ewald_reciprocal_space(positions, charges, cell, k_vectors, alpha, batch_idx=None, max_atoms_per_system=None, compute_forces=False, compute_charge_gradients=False, compute_virial=False)
Compute reciprocal-space Ewald energy and optionally forces, charge gradients, and virial.
Computes the smooth long-range electrostatic contribution using structure factors in reciprocal space. Includes self-energy and background corrections.
- Parameters:
positions (jax.Array, shape (N, 3)) – Atomic coordinates.
charges (jax.Array, shape (N,)) – Atomic partial charges.
cell (jax.Array, shape (1, 3, 3) or (B, 3, 3)) – Unit cell matrices.
k_vectors (jax.Array) – Reciprocal lattice vectors. Shape (K, 3) for single system, (B, K, 3) for batch.
alpha (float or jax.Array) – Ewald splitting parameter. Can be a float or array of shape (1,) or (B,).
batch_idx (jax.Array | None, shape (N,)) – System index for each atom.
max_atoms_per_system (int | None, optional) – Maximum number of atoms in any single system in the batch. Required when using jax.jit with batched inputs. If None, inferred from data (fails under JIT).
compute_forces (bool, default=False) – Whether to compute explicit forces.
compute_charge_gradients (bool, default=False) – Whether to compute charge gradients.
compute_virial (bool, default=False) – Whether to compute the virial tensor.
- Returns:
energies (jax.Array, shape (N,)) – Per-atom reciprocal-space energy (with corrections applied).
forces (jax.Array, shape (N, 3), optional) – Forces (if compute_forces=True or compute_charge_gradients=True).
charge_gradients (jax.Array, shape (N,), optional) – Charge gradients (if compute_charge_gradients=True).
virial (jax.Array, shape (1, 3, 3) or (B, 3, 3), optional) – Virial tensor (if compute_virial=True). Always last in the return tuple.
- Return type:
Array | tuple[Array, Array] | tuple[Array, Array, Array] | tuple[Array, Array, Array, Array]
PME Components
Individual components of the Particle Mesh Ewald method.
- nvalchemiops.jax.interactions.electrostatics.pme_reciprocal_space(positions, charges, cell, alpha, mesh_dimensions=None, mesh_spacing=None, spline_order=4, batch_idx=None, k_vectors=None, k_squared=None, compute_forces=False, compute_charge_gradients=False, compute_virial=False)
Compute PME reciprocal-space contribution.
Implements the FFT-based long-range component of PME using B-spline interpolation and convolution with the Green’s function.
- Pipeline:
1. Spread charges to the mesh (spline_spread)
2. FFT → frequency space
3. Compute Green’s function and structure factor
4. Convolve: mesh_fft * G(k) / C²(k)
5. IFFT → potential mesh
6. Gather potential at atoms (spline_gather)
7. Apply self-energy and background corrections
8. (Optional) Compute forces via Fourier gradient
- Parameters:
positions (jax.Array, shape (N, 3)) – Atomic coordinates.
charges (jax.Array, shape (N,)) – Atomic partial charges.
cell (jax.Array, shape (3, 3) or (B, 3, 3)) – Unit cell matrices with lattice vectors as rows.
alpha (jax.Array) – Ewald splitting parameter. Single-system: shape (1,) or scalar; batch: shape (B,).
mesh_dimensions (tuple[int, int, int], optional) – FFT mesh dimensions (nx, ny, nz).
mesh_spacing (float, optional) – Target mesh spacing. Used to compute mesh_dimensions if not provided.
spline_order (int, default=4) – B-spline interpolation order (4 = cubic).
batch_idx (jax.Array | None, default=None) – System index for each atom.
k_vectors (jax.Array, optional) – Precomputed k-vectors from generate_k_vectors_pme.
k_squared (jax.Array, optional) – Precomputed k² values from generate_k_vectors_pme.
compute_forces (bool, default=False) – If True, compute forces via Fourier gradient.
compute_charge_gradients (bool, default=False) – If True, compute charge gradients dE/dq.
compute_virial (bool, default=False) – If True, compute the virial tensor W = -dE/d(epsilon). Stress = virial / volume.
- Returns:
energies (jax.Array, shape (N,)) – Per-atom reciprocal-space energies.
forces (jax.Array, shape (N, 3), optional) – Per-atom forces (only if compute_forces=True).
charge_gradients (jax.Array, shape (N,), optional) – Per-atom charge gradients (only if compute_charge_gradients=True).
virial (jax.Array, shape (1, 3, 3) or (B, 3, 3), optional) – Virial tensor (only if compute_virial=True). Always last in the return tuple.
- Return type:
Array | tuple[Array, Array] | tuple[Array, Array, Array] | tuple[Array, Array, Array, Array]
Notes
Output dtype for energy/forces matches the input positions dtype
FFT/convolution and spline operations all respect the input dtype
Automatically determines mesh_dimensions if not provided
Virial is computed in k-space and uses the same dtype as k_squared
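The pipeline can be cross-checked against a direct reciprocal sum on a small system. The sketch below is plain NumPy and is not the library's implementation: it spreads charges with cardinal B-splines, FFTs the mesh, and evaluates E = Σ_k G(k) B(k) |Q̂(k)|² directly instead of gathering per-atom potentials. As a swapped-in choice it uses the exact smooth-PME B-spline modulus B(k) = Π|b(m)|² of Essmann et al., which plays the role of 1/C²(k), rather than the sinc-based C²(k) documented for pme_green_structure_factor. All helper names are hypothetical.

```python
import numpy as np


def bspline(u, order):
    """Cardinal B-spline M_order(u), supported on [0, order]; vectorized over u."""
    u = np.asarray(u, dtype=np.float64)
    if order == 2:
        return np.where((u >= 0.0) & (u <= 2.0), 1.0 - np.abs(u - 1.0), 0.0)
    return (u * bspline(u, order - 1) + (order - u) * bspline(u - 1.0, order - 1)) / (order - 1)


def spread_charges(positions, charges, cell, mesh, order):
    """Step 1: Q(g) = sum_i q_i prod_d M(u_id - g_d), periodic in the grid index g."""
    u = ((positions @ np.linalg.inv(cell)) % 1.0) * np.asarray(mesh)   # scaled fractional coords
    base = np.floor(u).astype(int)
    offsets = np.arange(order)
    grid = np.zeros(mesh)
    for i, q in enumerate(charges):
        idx = [(base[i, d] - offsets) % mesh[d] for d in range(3)]       # `order` grid points at or below u
        w = [bspline(u[i, d] - base[i, d] + offsets, order) for d in range(3)]
        np.add.at(grid, np.ix_(*idx), q * np.einsum("a,b,c->abc", *w))
    return grid


def bspline_modulus_sq(n_mesh, order):
    """|b(m)|^2 = 1 / |sum_{s=0}^{order-2} M(s+1) exp(2 pi i m s / n_mesh)|^2."""
    m = np.arange(n_mesh)[:, None]
    s = np.arange(order - 1)[None, :]
    denom = np.sum(bspline(s + 1.0, order) * np.exp(2j * np.pi * m * s / n_mesh), axis=1)
    return 1.0 / np.abs(denom) ** 2


def miller_grid(mesh, recip):
    """Cartesian k on the full FFT grid (Miller indices in FFT order)."""
    m = np.stack(np.meshgrid(*(np.fft.fftfreq(n, d=1.0 / n) for n in mesh), indexing="ij"), axis=-1)
    return m @ recip


def pme_reciprocal_energy(positions, charges, cell, alpha, mesh, order=4):
    """Steps 2-4, then E = sum_k G(k) B(k) |Q_hat(k)|^2 (no per-atom gather)."""
    volume = abs(np.linalg.det(cell))
    k2 = np.sum(miller_grid(mesh, 2.0 * np.pi * np.linalg.inv(cell).T) ** 2, axis=-1)
    k2[0, 0, 0] = np.inf                                                 # drops the k = 0 term
    green = 2.0 * np.pi / volume * np.exp(-k2 / (4.0 * alpha**2)) / k2
    b2 = np.einsum("a,b,c->abc", *(bspline_modulus_sq(n, order) for n in mesh))
    q_hat = np.fft.fftn(spread_charges(positions, charges, cell, mesh, order))
    return float(np.sum(green * b2 * np.abs(q_hat) ** 2))


def direct_reciprocal_energy(positions, charges, cell, alpha, mesh):
    """Reference on the same k-vectors with exact S(k) = sum_j q_j exp(i k . r_j)."""
    volume = abs(np.linalg.det(cell))
    k = miller_grid(mesh, 2.0 * np.pi * np.linalg.inv(cell).T).reshape(-1, 3)
    k2 = np.sum(k * k, axis=1)
    keep = k2 > 0.0
    s = charges @ np.exp(1j * (positions @ k[keep].T))
    return float(np.sum(2.0 * np.pi / volume * np.exp(-k2[keep] / (4.0 * alpha**2)) / k2[keep] * np.abs(s) ** 2))


cell = 10.0 * np.eye(3)
positions = np.array([[1.13, 0.87, 2.41], [4.52, 2.26, 3.79]])   # deliberately off-grid
charges = np.array([1.0, -1.0])
e_pme = pme_reciprocal_energy(positions, charges, cell, alpha=0.3, mesh=(20, 20, 20))
e_ref = direct_reciprocal_energy(positions, charges, cell, alpha=0.3, mesh=(20, 20, 20))
```

The two energies should agree closely; the residual is the B-spline interpolation error, which shrinks as the mesh is refined or the spline order is raised.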
- nvalchemiops.jax.interactions.electrostatics.pme_green_structure_factor(k_squared, mesh_dimensions, alpha, cell, spline_order=4, batch_idx=None)
Compute Green’s function and B-spline structure factor correction.
Computes the Coulomb Green’s function with volume normalization and the B-spline aliasing correction factor for PME.
- Green’s function (volume-normalized):
G(k) = (2π/V) * exp(-k²/(4α²)) / k²
- Structure factor correction (for B-spline deconvolution):
C²(k) = [sinc(m_x/N_x) · sinc(m_y/N_y) · sinc(m_z/N_z)]^(2p)
where p is the spline order.
- Parameters:
k_squared (jax.Array) – |k|² values at each FFT grid point. Single-system: shape (Nx, Ny, Nz_rfft); batch: shape (B, Nx, Ny, Nz_rfft).
mesh_dimensions (tuple[int, int, int]) – Full mesh dimensions (Nx, Ny, Nz) before rfft.
alpha (jax.Array) – Ewald splitting parameter. Single-system: shape (1,) or scalar; batch: shape (B,).
cell (jax.Array) – Unit cell matrices. Single-system: shape (3, 3) or (1, 3, 3); batch: shape (B, 3, 3).
spline_order (int, default=4) – B-spline interpolation order (typically 4 for cubic B-splines).
batch_idx (jax.Array | None, default=None) – If provided, dispatches to batch kernels.
- Returns:
green_function (jax.Array) – Volume-normalized Green’s function G(k). Single-system: shape (Nx, Ny, Nz_rfft); batch: shape (B, Nx, Ny, Nz_rfft).
structure_factor_sq (jax.Array) – Squared structure factor C²(k) for B-spline deconvolution. Shape (Nx, Ny, Nz_rfft), shared across the batch.
- Return type:
tuple[jax.Array, jax.Array]
Notes
G(k=0) is set to zero to avoid singularity
The volume normalization in G(k) eliminates later divisions
Structure factor is mesh-dependent only, so shared across batch
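Both formulas can be evaluated on the rfft grid in a few lines. This NumPy sketch is illustrative only, with a hypothetical helper name; it assumes a single system, a (3, 3) cell with lattice vectors as rows, Miller indices in FFT order, and that sinc is the normalized sinc(x) = sin(πx)/(πx), as implemented by np.sinc.

```python
import numpy as np


def green_and_structure_factor(cell, mesh_dimensions, alpha, spline_order=4):
    """G(k) and C^2(k) on the rfft grid for one system (cell rows = lattice vectors)."""
    nx, ny, nz = mesh_dimensions
    volume = abs(np.linalg.det(cell))
    recip = 2.0 * np.pi * np.linalg.inv(cell).T                  # rows are a*, b*, c*
    # Integer Miller indices in FFT order; the last axis is the rfft half grid.
    m = np.stack(np.meshgrid(np.fft.fftfreq(nx, d=1.0 / nx),
                             np.fft.fftfreq(ny, d=1.0 / ny),
                             np.fft.rfftfreq(nz, d=1.0 / nz), indexing="ij"), axis=-1)
    k2 = np.sum((m @ recip) ** 2, axis=-1)
    safe_k2 = np.where(k2 > 0.0, k2, 1.0)                        # placeholder, masked out below
    green = np.where(k2 > 0.0,
                     2.0 * np.pi / volume * np.exp(-k2 / (4.0 * alpha**2)) / safe_k2,
                     0.0)                                        # G(k = 0) = 0
    # np.sinc is the normalized sinc, sin(pi x) / (pi x).
    sinc = np.sinc(m[..., 0] / nx) * np.sinc(m[..., 1] / ny) * np.sinc(m[..., 2] / nz)
    structure_factor_sq = sinc ** (2 * spline_order)
    return green, structure_factor_sq


green, c2 = green_and_structure_factor(10.0 * np.eye(3), (16, 16, 16), alpha=0.3)
```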
- nvalchemiops.jax.interactions.electrostatics.pme_energy_corrections(raw_energies, charges, cell, alpha, batch_idx=None)
Apply self-energy and background corrections to PME energies.
Converts raw interpolated potential to energy and subtracts corrections:
E_i = q_i φ_i - E_self,i - E_background,i
- Self-energy correction (removes Gaussian self-interaction):
E_self,i = (α/√π) q_i²
- Background correction (for non-neutral systems):
E_background,i = (π/(2α²V)) q_i Q_total
- Parameters:
raw_energies (jax.Array, shape (N,) or (N_total,)) – Raw potential values φ_i from mesh interpolation.
charges (jax.Array, shape (N,) or (N_total,)) – Atomic charges.
cell (jax.Array) – Unit cell matrices. Single-system: shape (3, 3) or (1, 3, 3); batch: shape (B, 3, 3).
alpha (jax.Array) – Ewald splitting parameter. Single-system: shape (1,) or scalar; batch: shape (B,).
batch_idx (jax.Array | None, default=None) – System index for each atom. If provided, uses batch kernels.
- Returns:
corrected_energies – Final per-atom reciprocal-space energy with corrections applied.
- Return type:
jax.Array, shape (N,) or (N_total,)
Notes
For neutral systems, background correction is zero
Supports both float32 and float64 dtypes
- nvalchemiops.jax.interactions.electrostatics.pme_energy_corrections_with_charge_grad(raw_energies, charges, cell, alpha, batch_idx=None)
Apply energy corrections and compute charge gradients.
Same as pme_energy_corrections but also returns dE/dq for each atom.
- Parameters:
raw_energies (jax.Array, shape (N,) or (N_total,)) – Raw potential values φ_i from mesh interpolation.
charges (jax.Array, shape (N,) or (N_total,)) – Atomic charges.
cell (jax.Array) – Unit cell matrices. Single-system: shape (3, 3) or (1, 3, 3); batch: shape (B, 3, 3).
alpha (jax.Array) – Ewald splitting parameter. Single-system: shape (1,) or scalar; batch: shape (B,).
batch_idx (jax.Array | None, default=None) – System index for each atom. If provided, uses batch kernels.
- Returns:
corrected_energies (jax.Array, shape (N,) or (N_total,)) – Final per-atom reciprocal-space energy with corrections applied.
charge_gradients (jax.Array, shape (N,) or (N_total,)) – Per-atom charge gradients dE/dq.
- Return type:
tuple[jax.Array, jax.Array]
Notes
Useful for training models that predict partial charges
Supports both float32 and float64 dtypes
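The correction formulas and the charge gradient can be written out directly. The sketch below is NumPy with a hypothetical helper name; it assumes the raw potential is linear in the charges, φ = A q with A symmetric (as for an Ewald reciprocal-space sum), so that d(Σ q_i φ_i)/dq_i = 2φ_i. Differentiating the two correction terms then adds -2(α/√π) q_i and -π Q_total/(α² V).

```python
import math

import numpy as np


def pme_corrections(phi, charges, volume, alpha):
    """Per-atom E_i = q_i phi_i - E_self,i - E_background,i and the total-energy gradient dE/dq."""
    q_total = np.sum(charges)
    e_self = alpha / math.sqrt(math.pi) * charges**2
    e_background = math.pi / (2.0 * alpha**2 * volume) * charges * q_total
    energies = charges * phi - e_self - e_background
    # Assumes phi = A @ q with A symmetric, so d(sum_i q_i phi_i)/dq_i = 2 phi_i.
    grad = (2.0 * phi
            - 2.0 * alpha / math.sqrt(math.pi) * charges
            - math.pi * q_total / (alpha**2 * volume))
    return energies, grad
```

For a neutral system q_total is zero, so the background term drops out of both the energies and the gradient, matching the Notes above.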
K-Vector Generation
- nvalchemiops.jax.interactions.electrostatics.generate_miller_indices(cell, k_cutoff)
Generate Miller index bounds for Ewald summation.
- Parameters:
- Returns:
Array of shape (3,) containing the maximum Miller indices (M_h, M_k, M_l) for each lattice direction.
- Return type:
jax.Array
Notes
For batch mode, one shared set of Miller bounds is used for all systems. If k_cutoff is provided per system, the maximum cutoff across the batch is used to build those shared bounds.
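One simple way to obtain shared bounds uses the identity k · a = 2πh, which gives |h| ≤ k_cutoff |a| / (2π) for each lattice vector a. The NumPy sketch below applies it with the largest cutoff in the batch and returns plain Python ints, the form miller_bounds needs under jax.jit. It is an assumption-labeled illustration with a hypothetical helper name; generate_miller_indices may use a different (for example tighter) rule.

```python
import math

import numpy as np


def miller_bounds_sketch(cell, k_cutoff):
    """Shared (M_h, M_k, M_l) covering every k with |k| <= k_cutoff in every system."""
    cells = np.asarray(cell, dtype=np.float64).reshape(-1, 3, 3)   # (B, 3, 3)
    k_max = float(np.max(k_cutoff))                                # largest cutoff in the batch
    lengths = np.linalg.norm(cells, axis=2)                        # |a|, |b|, |c| per system
    bounds = np.ceil(k_max * lengths / (2.0 * math.pi)).astype(int).max(axis=0)
    return tuple(int(b) for b in bounds)                           # static ints for jax.jit


bounds = miller_bounds_sketch(10.0 * np.eye(3), 8.0)
```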
- nvalchemiops.jax.interactions.electrostatics.generate_k_vectors_ewald_summation(cell, k_cutoff, miller_bounds=None)
Generate reciprocal lattice vectors for Ewald summation (half-space).
Creates k-vectors within the specified cutoff for the reciprocal space summation in the Ewald method. Uses half-space optimization to reduce computational cost by approximately 2x.
Half-Space Optimization
This function generates k-vectors in the positive half-space only, exploiting the symmetry S(-k) = S*(k) where S(k) is the structure factor. For each pair of k-vectors (k, -k), only one is included.
- The half-space condition selects k-vectors where:
h > 0, OR
(h == 0 AND k > 0), OR
(h == 0 AND k == 0 AND l > 0)
The kernels in ewald_kernels.py compensate by doubling the Green’s function (using \(8\pi\) instead of \(4\pi\)), so energies, forces, and charge gradients are computed correctly.
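A small NumPy check of the half-space argument: summing the Gaussian-weighted |S(k)|² over all k-vectors with prefactor 2π/V gives the same energy as summing over the half-space with the prefactor doubled to 4π/V. Written as (1/(2V))·4π and (1/(2V))·8π, this corresponds to the 4π to 8π doubling described above, assuming the kernels carry a 1/(2V) factor. The helper names are hypothetical and this is not the library code.

```python
import itertools

import numpy as np


def in_half_space(h, k, l):
    """h > 0, or (h == 0 and k > 0), or (h == 0 and k == 0 and l > 0)."""
    return h > 0 or (h == 0 and k > 0) or (h == 0 and k == 0 and l > 0)


def reciprocal_energy(positions, charges, cell, alpha, k_cutoff, bounds, half_space):
    volume = abs(np.linalg.det(cell))
    recip = 2.0 * np.pi * np.linalg.inv(cell).T                    # rows are a*, b*, c*
    hkl = [m for m in itertools.product(*(range(-b, b + 1) for b in bounds))
           if any(m) and (not half_space or in_half_space(*m))]    # any(m) excludes k = 0
    k = np.array(hkl, dtype=np.float64) @ recip
    k2 = np.sum(k * k, axis=1)
    keep = k2 <= k_cutoff**2                                       # spherical cutoff, symmetric in k -> -k
    s = charges @ np.exp(1j * (positions @ k[keep].T))             # structure factors S(k)
    # |S(-k)|^2 == |S(k)|^2, so dropping -k is compensated by doubling the prefactor.
    prefactor = 4.0 * np.pi if half_space else 2.0 * np.pi
    weights = prefactor / volume * np.exp(-k2[keep] / (4.0 * alpha**2)) / k2[keep]
    return float(np.sum(weights * np.abs(s) ** 2)), int(keep.sum())


rng = np.random.default_rng(1)
cell = np.array([[9.0, 0.0, 0.0], [1.0, 10.0, 0.0], [0.5, 0.3, 11.0]])
positions = rng.uniform(0.0, 9.0, size=(5, 3))
charges = rng.normal(size=5)
e_full, n_full = reciprocal_energy(positions, charges, cell, 0.3, 2.0, (4, 4, 4), half_space=False)
e_half, n_half = reciprocal_energy(positions, charges, cell, 0.3, 2.0, (4, 4, 4), half_space=True)
```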
Mathematical Background
For a direct lattice defined by basis vectors {a, b, c} (rows of cell matrix), the reciprocal lattice vectors are:
\[ \begin{aligned}\mathbf{a}^* &= \frac{2\pi (\mathbf{b} \times \mathbf{c})}{V}\\ \mathbf{b}^* &= \frac{2\pi (\mathbf{c} \times \mathbf{a})}{V}\\ \mathbf{c}^* &= \frac{2\pi (\mathbf{a} \times \mathbf{b})}{V}\end{aligned} \]
where \(V = \mathbf{a} \cdot (\mathbf{b} \times \mathbf{c})\) is the cell volume.
In matrix form: \(\text{reciprocal\_matrix} = 2\pi \cdot (\text{cell}^T)^{-1}\), whose rows are \(\mathbf{a}^*, \mathbf{b}^*, \mathbf{c}^*\).
Each k-vector is: \(\mathbf{k} = h \mathbf{a}^* + k \mathbf{b}^* + l \mathbf{c}^*\) where (h, k, l) are Miller indices (integers).
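The cross-product and matrix forms of the reciprocal basis can be checked numerically for a triclinic cell. In this NumPy sketch the cell values and Miller indices are arbitrary examples; it verifies that the two forms agree and that a_i · b_j* = 2π δ_ij, which is why k · a = 2πh for k = h a* + k b* + l c*.

```python
import numpy as np

cell = np.array([[5.0, 0.0, 0.0],
                 [1.0, 6.0, 0.0],
                 [0.5, 0.2, 7.0]])                      # lattice vectors a, b, c as rows
volume = float(np.dot(cell[0], np.cross(cell[1], cell[2])))

# Cross-product form: a* = 2 pi (b x c) / V, and cyclic permutations.
recip_cross = 2.0 * np.pi * np.stack([np.cross(cell[1], cell[2]),
                                      np.cross(cell[2], cell[0]),
                                      np.cross(cell[0], cell[1])]) / volume
# Matrix form: 2 pi (cell^T)^{-1}; its rows are a*, b*, c*.
recip_matrix = 2.0 * np.pi * np.linalg.inv(cell.T)

# k-vector for Miller indices (h, k, l) = (1, -2, 3).
k_vec = np.array([1, -2, 3]) @ recip_matrix
```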
- Parameters:
cell (jax.Array) – Unit cell matrix with lattice vectors as rows. Shape (3, 3) for single system or (B, 3, 3) for batch.
k_cutoff (float or jax.Array) – Maximum magnitude of k-vectors to include (\(|\mathbf{k}| \leq k_{\text{cutoff}}\)). Typical values: 8-12 \(\text{\AA}^{-1}\) for molecular systems. Higher values increase accuracy but also computational cost.
miller_bounds (tuple[int, int, int] | None, optional) – Precomputed maximum Miller indices (M_h, M_k, M_l) for each lattice direction. When provided, the function skips the internal computation of bounds from cell and k_cutoff, making it compatible with jax.jit (which requires static array shapes). Use generate_miller_indices() to compute these bounds eagerly before entering a JIT context. When None (default), bounds are computed automatically from cell and k_cutoff.
- Returns:
Reciprocal lattice vectors within the cutoff. Shape (K, 3) for single system or (B, K, 3) for batch. Excludes k=0 and includes only half-space vectors.
- Return type:
jax.Array
Examples
Single system with explicit k_cutoff:
>>> cell = jnp.eye(3, dtype=jnp.float64) * 10.0
>>> k_vectors = generate_k_vectors_ewald_summation(cell, k_cutoff=8.0)
>>> k_vectors.shape  # (K, 3); K depends on cell size and cutoff
With automatic parameter estimation:
>>> from nvalchemiops.jax.interactions.electrostatics import estimate_ewald_parameters
>>> params = estimate_ewald_parameters(positions, cell)
>>> k_vectors = generate_k_vectors_ewald_summation(cell, params.reciprocal_space_cutoff)
JIT-compatible usage with precomputed bounds:
>>> from nvalchemiops.jax.interactions.electrostatics import generate_miller_indices
>>> cell = jnp.eye(3, dtype=jnp.float64)[None, ...] * 10.0
>>> bounds = generate_miller_indices(cell, k_cutoff=8.0)
>>> miller_bounds = (int(bounds[0]), int(bounds[1]), int(bounds[2]))
>>> # This can now be called inside @jax.jit
>>> k_vectors = generate_k_vectors_ewald_summation(cell, k_cutoff=8.0, miller_bounds=miller_bounds)
Notes
The k=0 vector is always excluded (causes division by zero in Green’s function).
For batch mode, the same set of Miller indices is used for all systems but transformed using each system’s reciprocal cell. If k_cutoff is given per system, the maximum cutoff across the batch determines the shared Miller bounds.
The number of k-vectors K scales as O(k_cutoff³ · V), where V is the cell volume.
When using this function inside jax.jit, you must provide miller_bounds as a concrete tuple[int, int, int]. The bounds determine array shapes (via jnp.arange), which must be statically known at trace time.
See also
ewald_reciprocal_space – Uses these k-vectors for reciprocal-space energy.
estimate_ewald_parameters – Automatic parameter estimation including k_cutoff.
generate_miller_indices – Compute Miller bounds for JIT-compatible usage.
- nvalchemiops.jax.interactions.electrostatics.generate_k_vectors_pme(cell, mesh_dimensions, reciprocal_cell=None)
Generate reciprocal lattice vectors for Particle Mesh Ewald (PME).
Creates k-vectors on a regular grid compatible with FFT-based reciprocal space calculations in PME. Uses rfft conventions (half-size in z-dimension) to exploit Hermitian symmetry of real-valued charge densities.
Notes
For a direct lattice defined by basis vectors {a, b, c} (rows of cell matrix), the reciprocal lattice vectors are:
\[\begin{split}\begin{aligned} \mathbf{a}^* &= \frac{2\pi (\mathbf{b} \times \mathbf{c})}{V} \\ \mathbf{b}^* &= \frac{2\pi (\mathbf{c} \times \mathbf{a})}{V} \\ \mathbf{c}^* &= \frac{2\pi (\mathbf{a} \times \mathbf{b})}{V} \end{aligned}\end{split}\]where \(V = \mathbf{a} \cdot (\mathbf{b} \times \mathbf{c})\) is the cell volume.
In matrix form:
\[\text{reciprocal\_matrix} = 2\pi \cdot (\text{cell}^T)^{-1}\]
Each k-vector is then:
\[\mathbf{k} = h \mathbf{a}^* + k \mathbf{b}^* + l \mathbf{c}^*\]
where (h, k, l) are Miller indices (integers).
- Parameters:
cell (jax.Array) – Unit cell matrix with lattice vectors as rows. Shape (3, 3) for single system or (B, 3, 3) for batch.
mesh_dimensions (tuple[int, int, int]) – PME mesh grid dimensions (nx, ny, nz). Should typically be chosen such that mesh spacing is \(\sim 1 \text{\AA}\) or finer. Power-of-2 dimensions are optimal for FFT performance.
reciprocal_cell (jax.Array, optional) – Precomputed reciprocal cell matrix (\(2\pi \cdot \text{cell}^{-1}\)). If provided, skips the inverse computation. Shape (3, 3) or (B, 3, 3).
- Returns:
k_vectors (jax.Array, shape (nx, ny, nz//2+1, 3)) – Cartesian k-vectors at each grid point. Uses rfft convention where z-dimension is halved due to Hermitian symmetry.
k_squared_safe (jax.Array, shape (nx, ny, nz//2+1)) – Squared magnitude \(|\mathbf{k}|^2\) for each k-vector, with k=0 set to a small positive value (1e-12) to avoid division by zero.
- Return type:
tuple[jax.Array, jax.Array]
Examples
Basic usage:
>>> cell = jnp.eye(3, dtype=jnp.float64) * 10.0
>>> mesh_dims = (32, 32, 32)
>>> k_vectors, k_squared = generate_k_vectors_pme(cell, mesh_dims)
>>> k_vectors.shape
(32, 32, 17, 3)
With precomputed reciprocal cell:
>>> reciprocal_cell = 2 * jnp.pi * jnp.linalg.inv(cell)
>>> k_vectors, k_squared = generate_k_vectors_pme(
...     cell, mesh_dims, reciprocal_cell=reciprocal_cell
... )
Notes
The z-dimension output size is nz//2+1 due to rfft symmetry.
Miller indices follow jnp.fft.fftfreq convention (0, 1, 2, …, -2, -1).
k_squared_safe has k=0 replaced with 1e-12 to prevent division by zero in Green’s function calculations.
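A NumPy sketch of the rfft k-grid described above, assuming the reciprocal vectors are the rows of 2π (cell^T)^{-1} as in the formulas in these Notes. The helper name is hypothetical, and the library may build the grid differently (for example from a user-supplied reciprocal_cell).

```python
import numpy as np


def pme_k_grid(cell, mesh_dimensions):
    """k-vectors (nx, ny, nz//2+1, 3) and k^2 (nx, ny, nz//2+1) on the rfft grid."""
    nx, ny, nz = mesh_dimensions
    recip = 2.0 * np.pi * np.linalg.inv(cell.T)                   # rows are a*, b*, c*
    h_idx = np.fft.fftfreq(nx, d=1.0 / nx)                        # 0, 1, ..., -2, -1
    k_idx = np.fft.fftfreq(ny, d=1.0 / ny)
    l_idx = np.fft.rfftfreq(nz, d=1.0 / nz)                       # 0, 1, ..., nz//2 (half grid)
    miller = np.stack(np.meshgrid(h_idx, k_idx, l_idx, indexing="ij"), axis=-1)
    k_vectors = miller @ recip
    k_squared = np.sum(k_vectors**2, axis=-1)
    k_squared_safe = np.where(k_squared > 0.0, k_squared, 1e-12)  # avoid 1/0 at k = 0
    return k_vectors, k_squared_safe


k_vectors, k_squared = pme_k_grid(10.0 * np.eye(3), (32, 32, 32))
```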
See also
pme_reciprocal_space – Uses these k-vectors for PME reciprocal-space energy.
pme_green_structure_factor – Computes Green’s function using k_squared.
Parameter Estimation
Functions for automatic parameter estimation based on desired accuracy tolerance.
- nvalchemiops.jax.interactions.electrostatics.estimate_ewald_parameters(positions, cell, batch_idx=None, accuracy=1e-6)
Estimate optimal Ewald summation parameters for a given accuracy.
Uses the Kolafa-Perram formula to balance real-space and reciprocal-space contributions for optimal efficiency at the target accuracy.
- Parameters:
positions (jax.Array, shape (N, 3)) – Atomic coordinates.
cell (jax.Array, shape (3, 3) or (B, 3, 3)) – Unit cell matrix.
batch_idx (jax.Array, shape (N,), dtype=int32, optional) – System index for each atom. If None, single-system mode.
accuracy (float, default=1e-6) – Target accuracy (relative error tolerance).
- Returns:
Dataclass containing alpha, real_space_cutoff, and reciprocal_space_cutoff as jax.Array objects.
- Return type:
EwaldParameters
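How an accuracy target translates into cutoffs can be seen from the Gaussian tails of the two sums. The sketch below uses only the leading-order estimates erfc(αr) ≈ exp(-α²r²) (ignoring prefactors) and exp(-k²/(4α²)); it is not the Kolafa-Perram formula the library uses, and the helper name is hypothetical.

```python
import math


def leading_order_cutoffs(alpha, accuracy=1e-6):
    """Cutoffs at which the Gaussian tails of both Ewald sums fall to `accuracy`.

    Real space:       exp(-(alpha r)^2) = eps         ->  r_cut = sqrt(-ln eps) / alpha
    Reciprocal space: exp(-k^2 / (4 alpha^2)) = eps   ->  k_cut = 2 alpha sqrt(-ln eps)
    """
    s = math.sqrt(-math.log(accuracy))
    return s / alpha, 2.0 * alpha * s


r_cut, k_cut = leading_order_cutoffs(0.3, accuracy=1e-6)
```

The product r_cut · k_cut = -2 ln(accuracy) does not depend on α, which is the trade-off such estimators balance: a larger α shortens the real-space cutoff and lengthens the reciprocal-space one.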
- nvalchemiops.jax.interactions.electrostatics.estimate_pme_parameters(positions, cell, batch_idx=None, accuracy=1e-6)
Estimate optimal PME parameters for a given accuracy.
- Parameters:
- Returns:
Dataclass containing alpha, mesh dimensions, spacing, and cutoffs. Tensor fields are jax.Array objects.
- Return type:
- nvalchemiops.jax.interactions.electrostatics.estimate_pme_mesh_dimensions(cell, alpha, accuracy=1e-6)
Estimate optimal PME mesh dimensions for a given accuracy.
- nvalchemiops.jax.interactions.electrostatics.mesh_spacing_to_dimensions(cell, mesh_spacing)
Convert mesh spacing to mesh dimensions.
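A sketch of the conversion stated for particle_mesh_ewald's mesh_spacing, ceil(cell_length / mesh_spacing), assuming cell_length means the length of each lattice vector (row) of a single (3, 3) cell. The helper is hypothetical; mesh_spacing_to_dimensions may also round to FFT-friendly sizes, which this sketch does not do.

```python
import math

import numpy as np


def mesh_dims_from_spacing(cell, mesh_spacing):
    """ceil(|lattice vector| / mesh_spacing) for each row of a single (3, 3) cell."""
    lengths = np.linalg.norm(np.asarray(cell, dtype=np.float64), axis=1)
    return tuple(int(math.ceil(float(length) / mesh_spacing)) for length in lengths)


dims = mesh_dims_from_spacing(np.diag([10.0, 12.5, 7.0]), mesh_spacing=1.0)
```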
- class nvalchemiops.jax.interactions.electrostatics.EwaldParameters(alpha, real_space_cutoff, reciprocal_space_cutoff)
Container for Ewald summation parameters.
All values are arrays of shape (B,); for single-system calculations, the shape is (1,).