nvalchemiops.jax.interactions.electrostatics: Electrostatics

The electrostatics module provides JAX bindings for GPU-accelerated implementations of long-range electrostatic interactions in molecular simulations. All functions accept standard jax.Array inputs.

Tip

For the underlying framework-agnostic Warp kernels, see nvalchemiops.interactions.electrostatics: Electrostatic Interactions (Warp).

High-Level Interface

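These are the primary entry points for most users. They are intended to be jax.jit compatible; under JIT, arguments that determine array shapes (for example miller_bounds and max_atoms_per_system in ewald_summation) must be supplied as concrete values.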

nvalchemiops.jax.interactions.electrostatics.ewald_summation(positions, charges, cell, alpha=None, k_vectors=None, k_cutoff=None, miller_bounds=None, batch_idx=None, max_atoms_per_system=None, neighbor_list=None, neighbor_ptr=None, neighbor_shifts=None, neighbor_matrix=None, neighbor_matrix_shifts=None, mask_value=None, compute_forces=False, compute_virial=False, accuracy=1e-6)

Compute the complete Ewald summation (real + reciprocal space).

The Ewald method decomposes the total Coulomb energy as:

E_total = E_real + E_reciprocal - E_self - E_background

Parameters:
  • positions (jax.Array, shape (N, 3)) – Atomic coordinates.

  • charges (jax.Array, shape (N,)) – Atomic partial charges.

  • cell (jax.Array, shape (1, 3, 3) or (B, 3, 3)) – Unit cell matrices.

  • alpha (float or jax.Array or None) – Ewald splitting parameter. If None, estimated automatically.

  • k_vectors (jax.Array | None) – Reciprocal lattice vectors. If None, generated automatically. Shape (K, 3) for single system, (B, K, 3) for batch.

  • k_cutoff (float | None) – K-space cutoff. Used only if k_vectors is None.

  • miller_bounds (tuple[int, int, int] | None, optional) – Precomputed maximum Miller indices (M_h, M_k, M_l). Forwarded to generate_k_vectors_ewald_summation() when k_vectors is None. When provided, makes k-vector generation compatible with jax.jit. Use generate_miller_indices() to precompute. Ignored when k_vectors is explicitly provided.

  • batch_idx (jax.Array | None, shape (N,)) – System index for each atom.

  • max_atoms_per_system (int | None, optional) – Maximum number of atoms in any single system in the batch. Required when using jax.jit with batched inputs. If None, inferred from data (fails under JIT).

  • neighbor_list (jax.Array | None, shape (2, M)) – Neighbor list in COO format.

  • neighbor_ptr (jax.Array | None, shape (N+1,)) – CSR row pointers for neighbor list.

  • neighbor_shifts (jax.Array | None, shape (M, 3)) – Periodic image shifts for neighbor list.

  • neighbor_matrix (jax.Array | None, shape (N, max_neighbors)) – Dense neighbor matrix format.

  • neighbor_matrix_shifts (jax.Array | None, shape (N, max_neighbors, 3)) – Periodic image shifts for neighbor_matrix.

  • mask_value (int | None) – Value indicating invalid entries in neighbor_matrix.

  • compute_forces (bool, default=False) – Whether to compute forces.

  • compute_virial (bool, default=False) – Whether to compute the virial tensor.

  • accuracy (float, default=1e-6) – Target accuracy for automatic parameter estimation.

Returns:

  • energies (jax.Array, shape (N,)) – Per-atom total Ewald energy.

  • forces (jax.Array, shape (N, 3), optional) – Forces (if compute_forces=True).

  • virial (jax.Array, shape (1, 3, 3) or (B, 3, 3), optional) – Virial tensor (if compute_virial=True). Always last in the return tuple.

Return type:

Array | tuple[Array, …]

Examples

>>> # Complete Ewald summation with automatic parameters
>>> energies, forces = ewald_summation(
...     positions, charges, cell,
...     neighbor_list=nl, neighbor_ptr=neighbor_ptr, neighbor_shifts=shifts,
...     accuracy=1e-6,
...     compute_forces=True,
... )

nvalchemiops.jax.interactions.electrostatics.particle_mesh_ewald(positions, charges, cell, alpha=None, mesh_spacing=None, mesh_dimensions=None, spline_order=4, batch_idx=None, k_vectors=None, k_squared=None, neighbor_list=None, neighbor_ptr=None, neighbor_shifts=None, neighbor_matrix=None, neighbor_matrix_shifts=None, mask_value=None, compute_forces=False, compute_charge_gradients=False, compute_virial=False, accuracy=1e-6)

Complete Particle Mesh Ewald (PME) calculation for long-range electrostatics.

Computes the total Coulomb energy using the PME method, which achieves O(N log N) scaling through FFT-based reciprocal-space calculations. It combines:
  1. Real-space contribution (short-range, erfc-damped)

  2. Reciprocal-space contribution (long-range, FFT + B-spline interpolation)

  3. Self-energy and background corrections

Total Energy Formula:

E_total = E_real + E_reciprocal - E_self - E_background

Parameters:
  • positions (jax.Array, shape (N, 3)) – Atomic coordinates.

  • charges (jax.Array, shape (N,)) – Atomic partial charges in elementary charge units.

  • cell (jax.Array, shape (3, 3) or (B, 3, 3)) – Unit cell matrices with lattice vectors as rows. Shape (3, 3) is automatically promoted to (1, 3, 3) for single-system mode.

  • alpha (float, jax.Array, or None, default=None) – Ewald splitting parameter controlling the real/reciprocal-space balance. A float applies the same α to all systems; an array of shape (B,) gives per-system α values; None estimates α automatically using the Kolafa-Perram formula.

  • mesh_spacing (float, optional) – Target mesh spacing. Mesh dimensions are computed as ceil(cell_length / mesh_spacing).

  • mesh_dimensions (tuple[int, int, int], optional) – Explicit FFT mesh dimensions (nx, ny, nz). Power-of-2 values recommended.

  • spline_order (int, default=4) – B-spline interpolation order (4 = cubic B-splines, recommended).

  • batch_idx (jax.Array, shape (N,), dtype=int32, optional) – System index for each atom (0 to B-1). Determines the execution mode: None uses the single-system optimized kernels; when provided, batched kernels handle multiple independent systems.

  • k_vectors (jax.Array, optional) – Precomputed k-vectors from generate_k_vectors_pme. Providing this along with k_squared skips k-vector generation.

  • k_squared (jax.Array, optional) – Precomputed k² values from generate_k_vectors_pme.

  • neighbor_list (jax.Array, optional) – Neighbor list indices, used together with the CSR row pointers in neighbor_ptr. See ewald_real_space.

  • neighbor_ptr (jax.Array, optional) – CSR-format neighbor list pointers. See ewald_real_space.

  • neighbor_shifts (jax.Array, optional) – Periodic image shifts for neighbor list. See ewald_real_space.

  • neighbor_matrix (jax.Array, optional) – Dense neighbor matrix. Alternative to CSR format.

  • neighbor_matrix_shifts (jax.Array, optional) – Shifts for dense neighbor matrix.

  • mask_value (int, optional) – Mask value for invalid neighbors in dense format.

  • compute_forces (bool, default=False) – If True, compute per-atom forces.

  • compute_charge_gradients (bool, default=False) – If True, compute per-atom charge gradients dE/dq.

  • compute_virial (bool, default=False) – If True, compute the virial tensor W = -dE/d(epsilon), where epsilon is the strain tensor. Stress = virial / volume.

  • accuracy (float, default=1e-6) – Target accuracy for automatic parameter estimation.

Returns:

  • energies (jax.Array, shape (N,)) – Per-atom total electrostatic energies.

  • forces (jax.Array, shape (N, 3), optional) – Per-atom forces (only if compute_forces=True).

  • charge_gradients (jax.Array, shape (N,), optional) – Per-atom charge gradients (only if compute_charge_gradients=True).

  • virial (jax.Array, shape (1, 3, 3) or (B, 3, 3), optional) – Virial tensor (only if compute_virial=True). Always last in the return tuple.

Return type:

Array | tuple[Array, Array] | tuple[Array, Array, Array] | tuple[Array, Array, Array, Array]
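
The Returns list above implies a fixed tuple order. As a minimal sketch, assuming the tuple follows the order in which the outputs are listed (energies, forces, charge gradients, with the virial always last), the hypothetical helpers below, which are not part of nvalchemiops, map a particle_mesh_ewald result to named outputs. They are PME-specific: ewald_real_space and ewald_reciprocal_space also return forces when only compute_charge_gradients=True.

```python
def pme_output_names(compute_forces=False, compute_charge_gradients=False,
                     compute_virial=False):
    """Hypothetical helper: output names in the assumed return order."""
    names = ["energies"]
    if compute_forces:
        names.append("forces")
    if compute_charge_gradients:
        names.append("charge_gradients")
    if compute_virial:
        names.append("virial")  # documented as always last
    return names


def unpack_outputs(result, **flags):
    """Hypothetical helper: map a particle_mesh_ewald result to a dict by name."""
    names = pme_output_names(**flags)
    if len(names) == 1:
        # With no optional outputs, a single Array is returned rather than a tuple.
        return {"energies": result}
    return dict(zip(names, result))
```

For example, unpack_outputs(result, compute_forces=True, compute_virial=True) would return a dict with the keys energies, forces and virial.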

Notes

Automatic Parameter Estimation (when alpha is None):

Uses the Kolafa-Perram formula to select an optimal α and mesh dimensions for the requested accuracy.
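
The mesh_spacing rule from the parameter list can be sketched as follows. The helper mesh_dims_from_spacing is hypothetical, and it assumes that cell_length means the length of each lattice vector; the library may define it differently (for example via plane spacings in triclinic cells). Because the result is a tuple of Python ints computed from cell values, it needs a concrete cell, so passing mesh_dimensions explicitly is likely the safer choice inside jax.jit.

```python
import math

import jax.numpy as jnp


def mesh_dims_from_spacing(cell, mesh_spacing):
    """Hypothetical helper: mesh size per axis as ceil(cell_length / mesh_spacing).

    Assumes cell_length is the norm of each lattice vector (each row of a
    single-system (3, 3) cell). Returns plain Python ints, so it must run on
    a concrete cell, outside jax.jit.
    """
    lengths = jnp.linalg.norm(jnp.asarray(cell), axis=-1)  # (3,) lattice-vector lengths
    return tuple(int(math.ceil(float(length) / mesh_spacing)) for length in lengths)
```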

Examples

Basic usage:

>>> energies = particle_mesh_ewald(
...     positions, charges, cell, alpha=0.3,
...     mesh_dimensions=(32, 32, 32),
...     neighbor_list=nl, neighbor_ptr=ptr, neighbor_shifts=shifts,
... )

With forces and automatic parameters:

>>> energies, forces = particle_mesh_ewald(
...     positions, charges, cell,
...     mesh_spacing=1.0, accuracy=1e-5,
...     neighbor_list=nl, neighbor_ptr=ptr, neighbor_shifts=shifts,
...     compute_forces=True,
... )

Batched systems:

>>> energies = particle_mesh_ewald(
...     positions, charges, cell,
...     batch_idx=batch_idx,
...     neighbor_list=nl, neighbor_ptr=ptr, neighbor_shifts=shifts,
... )

See also

pme_reciprocal_space

Reciprocal-space component only

ewald_real_space

Real-space component

estimate_pme_parameters

Automatic parameter estimation

Coulomb Interactions

Direct pairwise Coulomb interactions.

nvalchemiops.jax.interactions.electrostatics.coulomb_energy(positions, charges, cell, cutoff, alpha=0.0, neighbor_list=None, neighbor_ptr=None, neighbor_shifts=None, neighbor_matrix=None, neighbor_matrix_shifts=None, fill_value=None, batch_idx=None)

Compute Coulomb electrostatic energies.

Computes pairwise electrostatic energies using the Coulomb law, with optional erfc damping for Ewald/PME real-space calculations.
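
For checking results on small test systems, the pair energy described above can be written densely in plain JAX. This is a sketch, not the library's kernel: dense_coulomb_energies is a hypothetical name, it ignores periodic images and the cutoff, works in Gaussian units (no Coulomb constant), and assumes each pair energy q_i q_j erfc(α r_ij)/r_ij is split evenly between the two atoms so that the per-atom energies sum to the total. With alpha=0.0, erfc(0) = 1 recovers the undamped Coulomb law.

```python
import jax.numpy as jnp
from jax.scipy.special import erfc


def dense_coulomb_energies(positions, charges, alpha=0.0):
    """Per-atom damped Coulomb energies for a small, non-periodic system.

    E_i = 1/2 * sum_{j != i} q_i q_j erfc(alpha * r_ij) / r_ij
    (Gaussian units, no Coulomb constant; each pair split evenly).
    """
    n = positions.shape[0]
    eye = jnp.eye(n, dtype=bool)
    diff = positions[:, None, :] - positions[None, :, :]   # (N, N, 3)
    # Pad the diagonal before the sqrt so r_ii is never zero.
    r = jnp.sqrt(jnp.sum(diff**2, axis=-1) + eye)
    pair = charges[:, None] * charges[None, :] * erfc(alpha * r) / r
    pair = jnp.where(eye, 0.0, pair)                         # drop self-pairs
    return 0.5 * jnp.sum(pair, axis=1)
```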

Parameters:
  • positions (jax.Array, shape (N, 3)) – Atomic coordinates.

  • charges (jax.Array, shape (N,)) – Atomic charges.

  • cell (jax.Array, shape (1, 3, 3) or (B, 3, 3)) – Unit cell matrix. Shape (B, 3, 3) for batched calculations.

  • cutoff (float) – Cutoff distance for interactions.

  • alpha (float, default=0.0) – Ewald splitting parameter. Use 0.0 for undamped Coulomb.

  • neighbor_list (jax.Array | None, shape (2, num_pairs)) – Neighbor pairs in COO format. Row 0 = source, Row 1 = target.

  • neighbor_ptr (jax.Array | None, shape (N+1,)) – CSR row pointers for neighbor list. Required with neighbor_list; provided by the neighborlist module.

  • neighbor_shifts (jax.Array | None, shape (num_pairs, 3)) – Integer unit cell shifts for neighbor list format.

  • neighbor_matrix (jax.Array | None, shape (N, max_neighbors)) – Neighbor indices in matrix format.

  • neighbor_matrix_shifts (jax.Array | None, shape (N, max_neighbors, 3)) – Integer unit cell shifts for matrix format.

  • fill_value (int | None) – Fill value for neighbor matrix padding.

  • batch_idx (jax.Array | None, shape (N,)) – Batch indices for each atom.

Returns:

energies – Per-atom energies; sum them to obtain the total energy.

Return type:

jax.Array, shape (N,)

Examples

>>> # Direct Coulomb (undamped)
>>> energies = coulomb_energy(
...     positions, charges, cell, cutoff=10.0, alpha=0.0,
...     neighbor_list=neighbor_list, neighbor_ptr=neighbor_ptr,
...     neighbor_shifts=neighbor_shifts
... )
>>> total_energy = energies.sum()
>>> # Ewald/PME real-space (damped)
>>> energies = coulomb_energy(
...     positions, charges, cell, cutoff=10.0, alpha=0.3,
...     neighbor_list=neighbor_list, neighbor_ptr=neighbor_ptr,
...     neighbor_shifts=neighbor_shifts
... )
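
The neighbor_ptr array relates to neighbor_list in the usual CSR way. A minimal sketch, assuming the COO pairs are sorted by source atom so that the neighbors of atom i occupy columns neighbor_ptr[i] to neighbor_ptr[i + 1] - 1; csr_pointers_from_coo is a hypothetical helper, and in practice the neighborlist module provides neighbor_ptr directly.

```python
import jax.numpy as jnp


def csr_pointers_from_coo(neighbor_list, num_atoms):
    """Hypothetical helper: CSR row pointers (N+1,) for a COO list sorted by source."""
    sources = neighbor_list[0]                           # row 0 = source atom of each pair
    counts = jnp.bincount(sources, length=num_atoms)     # number of neighbors per atom
    # ptr[i]..ptr[i+1] brackets the pairs whose source is atom i.
    return jnp.concatenate([jnp.zeros(1, dtype=counts.dtype), jnp.cumsum(counts)])
```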

nvalchemiops.jax.interactions.electrostatics.coulomb_forces(positions, charges, cell, cutoff, alpha=0.0, neighbor_list=None, neighbor_ptr=None, neighbor_shifts=None, neighbor_matrix=None, neighbor_matrix_shifts=None, fill_value=None, batch_idx=None)

Compute Coulomb electrostatic forces.

Convenience wrapper that returns only forces (no energies).

Parameters (see coulomb_energy for descriptions):
  • positions (Array)

  • charges (Array)

  • cell (Array)

  • cutoff (float)

  • alpha (float)

  • neighbor_list (Array | None)

  • neighbor_ptr (Array | None)

  • neighbor_shifts (Array | None)

  • neighbor_matrix (Array | None)

  • neighbor_matrix_shifts (Array | None)

  • fill_value (int | None)

  • batch_idx (Array | None)

Returns:

forces – Forces on each atom.

Return type:

jax.Array, shape (N, 3)

See also

coulomb_energy_forces

Compute both energies and forces

nvalchemiops.jax.interactions.electrostatics.coulomb_energy_forces(positions, charges, cell, cutoff, alpha=0.0, neighbor_list=None, neighbor_ptr=None, neighbor_shifts=None, neighbor_matrix=None, neighbor_matrix_shifts=None, fill_value=None, batch_idx=None)

Compute Coulomb electrostatic energies and forces.

Computes pairwise electrostatic energies and forces using the Coulomb law, with optional erfc damping for Ewald/PME real-space calculations.

Parameters:
  • positions (jax.Array, shape (N, 3)) – Atomic coordinates.

  • charges (jax.Array, shape (N,)) – Atomic charges.

  • cell (jax.Array, shape (1, 3, 3) or (B, 3, 3)) – Unit cell matrix. Shape (B, 3, 3) for batched calculations.

  • cutoff (float) – Cutoff distance for interactions.

  • alpha (float, default=0.0) – Ewald splitting parameter. Use 0.0 for undamped Coulomb.

  • neighbor_list (jax.Array | None, shape (2, num_pairs)) – Neighbor pairs in COO format.

  • neighbor_ptr (jax.Array | None, shape (N+1,)) – CSR row pointers for neighbor list. Required with neighbor_list; provided by the neighborlist module.

  • neighbor_shifts (jax.Array | None, shape (num_pairs, 3)) – Integer unit cell shifts for neighbor list format.

  • neighbor_matrix (jax.Array | None, shape (N, max_neighbors)) – Neighbor indices in matrix format.

  • neighbor_matrix_shifts (jax.Array | None, shape (N, max_neighbors, 3)) – Integer unit cell shifts for matrix format.

  • fill_value (int | None) – Fill value for neighbor matrix padding.

  • batch_idx (jax.Array | None, shape (N,)) – Batch indices for each atom.

Returns:

  • energies (jax.Array, shape (N,)) – Per-atom energies.

  • forces (jax.Array, shape (N, 3)) – Forces on each atom.

Return type:

tuple[Array, Array]

Examples

>>> # Direct Coulomb
>>> energies, forces = coulomb_energy_forces(
...     positions, charges, cell, cutoff=10.0, alpha=0.0,
...     neighbor_list=neighbor_list, neighbor_ptr=neighbor_ptr,
...     neighbor_shifts=neighbor_shifts
... )
>>> # Ewald/PME real-space
>>> energies, forces = coulomb_energy_forces(
...     positions, charges, cell, cutoff=10.0, alpha=0.3,
...     neighbor_matrix=neighbor_matrix, neighbor_matrix_shifts=neighbor_matrix_shifts,
...     fill_value=num_atoms
... )
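
Forces follow F_i = -∂E/∂r_i, so a dense reference energy can be differentiated with jax.grad to cross-check coulomb_energy_forces on a small system. The sketch below is hypothetical (total_energy and forces_by_autodiff are illustrative names) and assumes Gaussian units with no periodic images or cutoff.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erfc


def total_energy(positions, charges, alpha):
    """Total damped Coulomb energy of a small non-periodic system (Gaussian units)."""
    n = positions.shape[0]
    eye = jnp.eye(n, dtype=bool)
    diff = positions[:, None, :] - positions[None, :, :]
    r = jnp.sqrt(jnp.sum(diff**2, axis=-1) + eye)    # pad diagonal to avoid r = 0
    pair = charges[:, None] * charges[None, :] * erfc(alpha * r) / r
    return 0.5 * jnp.sum(jnp.where(eye, 0.0, pair))


def forces_by_autodiff(positions, charges, alpha=0.0):
    # F_i = -dE/dr_i, differentiating with respect to positions (argument 0).
    return -jax.grad(total_energy)(positions, charges, alpha)
```

By Newton's third law the forces of such an isolated system should sum to zero, which is a cheap sanity check on any force output.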

Ewald Components

Individual components of the Ewald summation method.

nvalchemiops.jax.interactions.electrostatics.ewald_real_space(positions, charges, cell, alpha, neighbor_list=None, neighbor_ptr=None, neighbor_shifts=None, neighbor_matrix=None, neighbor_matrix_shifts=None, mask_value=None, batch_idx=None, compute_forces=False, compute_charge_gradients=False, compute_virial=False)

Compute real-space Ewald energy and optionally forces, charge gradients, and virial.

Computes the damped Coulomb interactions for atom pairs within the real-space cutoff. The complementary error function (erfc) damping ensures rapid convergence in real space.

Parameters:
  • positions (jax.Array, shape (N, 3)) – Atomic coordinates.

  • charges (jax.Array, shape (N,)) – Atomic partial charges.

  • cell (jax.Array, shape (1, 3, 3) or (B, 3, 3)) – Unit cell matrices.

  • alpha (float or jax.Array) – Ewald splitting parameter. Can be a float or array of shape (1,) or (B,).

  • neighbor_list (jax.Array | None, shape (2, M)) – Neighbor list in COO format.

  • neighbor_ptr (jax.Array | None, shape (N+1,)) – CSR row pointers for neighbor list.

  • neighbor_shifts (jax.Array | None, shape (M, 3)) – Periodic image shifts for neighbor list.

  • neighbor_matrix (jax.Array | None, shape (N, max_neighbors)) – Dense neighbor matrix format.

  • neighbor_matrix_shifts (jax.Array | None, shape (N, max_neighbors, 3)) – Periodic image shifts for neighbor_matrix.

  • mask_value (int | None, optional) – Value indicating invalid entries in neighbor_matrix. If None (default), uses num_atoms as the mask value.

  • batch_idx (jax.Array | None, shape (N,)) – System index for each atom.

  • compute_forces (bool, default=False) – Whether to compute explicit forces.

  • compute_charge_gradients (bool, default=False) – Whether to compute charge gradients.

  • compute_virial (bool, default=False) – Whether to compute the virial tensor.

Returns:

  • energies (jax.Array, shape (N,)) – Per-atom real-space energy.

  • forces (jax.Array, shape (N, 3), optional) – Forces (if compute_forces=True or compute_charge_gradients=True).

  • charge_gradients (jax.Array, shape (N,), optional) – Charge gradients (if compute_charge_gradients=True).

  • virial (jax.Array, shape (1, 3, 3) or (B, 3, 3), optional) – Virial tensor (if compute_virial=True). Always last in the return tuple.

Return type:

Array | tuple[Array, …]
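
For the pair energy E = ½ Σ_{i≠j} q_i q_j erfc(α r_ij)/r_ij, the charge gradient dE/dq_i is the damped potential at atom i, φ_i = Σ_{j≠i} q_j erfc(α r_ij)/r_ij. The sketch below checks this with jax.grad on a dense, non-periodic toy system in Gaussian units; the function names are hypothetical, and the library's compute_charge_gradients output presumably also includes the periodic-image pairs from the neighbor list.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erfc


def pair_kernel(positions, alpha):
    """erfc(alpha r_ij) / r_ij with the diagonal zeroed (non-periodic)."""
    n = positions.shape[0]
    eye = jnp.eye(n, dtype=bool)
    diff = positions[:, None, :] - positions[None, :, :]
    r = jnp.sqrt(jnp.sum(diff**2, axis=-1) + eye)   # pad diagonal to avoid r = 0
    return jnp.where(eye, 0.0, erfc(alpha * r) / r)


def real_space_energy(charges, positions, alpha):
    # E = 1/2 q^T K q
    return 0.5 * charges @ pair_kernel(positions, alpha) @ charges


def charge_gradients(charges, positions, alpha):
    # dE/dq_i, differentiating with respect to charges (argument 0).
    return jax.grad(real_space_energy)(charges, positions, alpha)
```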

nvalchemiops.jax.interactions.electrostatics.ewald_reciprocal_space(positions, charges, cell, k_vectors, alpha, batch_idx=None, max_atoms_per_system=None, compute_forces=False, compute_charge_gradients=False, compute_virial=False)

Compute reciprocal-space Ewald energy and optionally forces, charge gradients, and virial.

Computes the smooth long-range electrostatic contribution using structure factors in reciprocal space. Includes self-energy and background corrections.
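
The reciprocal-space term can be written with structure factors S(k) = Σ_j q_j exp(i k·r_j). Assuming Gaussian units and the standard Ewald expression E_recip = (2π/V) Σ_{k≠0} exp(-k²/(4α²))/k² |S(k)|², a brute-force sketch over a full box of Miller indices follows. It is not how ewald_reciprocal_space is implemented (which, per the k-vector documentation, uses half-space k-vectors with a doubled prefactor), it returns a total rather than per-atom energy, and it omits the background term, which vanishes for neutral systems. reciprocal_energy and max_index are hypothetical names.

```python
import itertools

import jax.numpy as jnp


def reciprocal_energy(positions, charges, cell, alpha, max_index=4):
    """Total reciprocal-space Ewald energy minus the self term (Gaussian units).

    Sums over all Miller indices with |h|, |k|, |l| <= max_index, excluding k = 0.
    """
    recip = 2 * jnp.pi * jnp.linalg.inv(cell.T)                 # rows are a*, b*, c*
    rng = range(-max_index, max_index + 1)
    miller = jnp.array([m for m in itertools.product(rng, rng, rng) if m != (0, 0, 0)],
                       dtype=positions.dtype)
    kvec = miller @ recip                                        # (K, 3)
    k2 = jnp.sum(kvec**2, axis=-1)
    phase = kvec @ positions.T                                   # (K, N): k . r_j
    s2 = (jnp.cos(phase) @ charges) ** 2 + (jnp.sin(phase) @ charges) ** 2  # |S(k)|^2
    volume = jnp.abs(jnp.linalg.det(cell))
    e_recip = (2 * jnp.pi / volume) * jnp.sum(jnp.exp(-k2 / (4 * alpha**2)) / k2 * s2)
    e_self = alpha / jnp.sqrt(jnp.pi) * jnp.sum(charges**2)
    return e_recip - e_self
```

Because |S(k)|² only depends on relative positions modulo the lattice, the result is unchanged by a rigid translation or by moving one atom by a lattice vector, which makes a useful consistency check.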

Parameters:
  • positions (jax.Array, shape (N, 3)) – Atomic coordinates.

  • charges (jax.Array, shape (N,)) – Atomic partial charges.

  • cell (jax.Array, shape (1, 3, 3) or (B, 3, 3)) – Unit cell matrices.

  • k_vectors (jax.Array) – Reciprocal lattice vectors. Shape (K, 3) for single system, (B, K, 3) for batch.

  • alpha (float or jax.Array) – Ewald splitting parameter. Can be a float or array of shape (1,) or (B,).

  • batch_idx (jax.Array | None, shape (N,)) – System index for each atom.

  • max_atoms_per_system (int | None, optional) – Maximum number of atoms in any single system in the batch. Required when using jax.jit with batched inputs. If None, inferred from data (fails under JIT).

  • compute_forces (bool, default=False) – Whether to compute explicit forces.

  • compute_charge_gradients (bool, default=False) – Whether to compute charge gradients.

  • compute_virial (bool, default=False) – Whether to compute the virial tensor.

Returns:

  • energies (jax.Array, shape (N,)) – Per-atom reciprocal-space energy (with corrections applied).

  • forces (jax.Array, shape (N, 3), optional) – Forces (if compute_forces=True or compute_charge_gradients=True).

  • charge_gradients (jax.Array, shape (N,), optional) – Charge gradients (if compute_charge_gradients=True).

  • virial (jax.Array, shape (1, 3, 3) or (B, 3, 3), optional) – Virial tensor (if compute_virial=True). Always last in the return tuple.

Return type:

Array | tuple[Array, …]

PME Components

Individual components of the Particle Mesh Ewald method.

nvalchemiops.jax.interactions.electrostatics.pme_reciprocal_space(positions, charges, cell, alpha, mesh_dimensions=None, mesh_spacing=None, spline_order=4, batch_idx=None, k_vectors=None, k_squared=None, compute_forces=False, compute_charge_gradients=False, compute_virial=False)

Compute PME reciprocal-space contribution.

Implements the FFT-based long-range component of PME using B-spline interpolation and convolution with the Green’s function.

Pipeline:
  1. Spread charges to mesh (spline_spread)

  2. FFT → frequency space

  3. Compute Green’s function and structure factor

  4. Convolve: mesh_fft * G(k) / C²(k)

  5. IFFT → potential mesh

  6. Gather potential at atoms (spline_gather)

  7. Apply self-energy and background corrections

  8. (Optional) Compute forces via Fourier gradient
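
Steps 2 to 5 of the pipeline reduce to a pointwise product in Fourier space. A minimal sketch, assuming the charge mesh from step 1 and an influence array G(k)/C²(k) on the rfft grid are already available; convolve_on_mesh is a hypothetical name. Zeroing the k = 0 entry of the influence array removes the mesh average, which is the effect of setting G(0) = 0.

```python
import jax.numpy as jnp


def convolve_on_mesh(charge_mesh, influence):
    """Hypothetical sketch of steps 2-5: FFT, multiply by G(k)/C^2(k), inverse FFT.

    charge_mesh: real array (nx, ny, nz) produced by charge spreading (step 1).
    influence:   real array (nx, ny, nz // 2 + 1), G(k) / C^2(k) on the rfft grid.
    """
    mesh_fft = jnp.fft.rfftn(charge_mesh)          # (nx, ny, nz//2 + 1), complex
    potential_fft = mesh_fft * influence            # pointwise product = convolution
    return jnp.fft.irfftn(potential_fft, s=charge_mesh.shape)  # real (nx, ny, nz)
```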

Parameters:
  • positions (jax.Array, shape (N, 3)) – Atomic coordinates.

  • charges (jax.Array, shape (N,)) – Atomic partial charges.

  • cell (jax.Array, shape (3, 3) or (B, 3, 3)) – Unit cell matrices with lattice vectors as rows.

  • alpha (jax.Array) – Ewald splitting parameter. Shape (1,) or scalar for a single system; shape (B,) for a batch.

  • mesh_dimensions (tuple[int, int, int], optional) – FFT mesh dimensions (nx, ny, nz).

  • mesh_spacing (float, optional) – Target mesh spacing. Used to compute mesh_dimensions if not provided.

  • spline_order (int, default=4) – B-spline interpolation order (4 = cubic).

  • batch_idx (jax.Array | None, default=None) – System index for each atom.

  • k_vectors (jax.Array, optional) – Precomputed k-vectors from generate_k_vectors_pme.

  • k_squared (jax.Array, optional) – Precomputed k² values from generate_k_vectors_pme.

  • compute_forces (bool, default=False) – If True, compute forces via Fourier gradient.

  • compute_charge_gradients (bool, default=False) – If True, compute charge gradients dE/dq.

  • compute_virial (bool, default=False) – If True, compute the virial tensor W = -dE/d(epsilon), where epsilon is the strain tensor. Stress = virial / volume.

Returns:

  • energies (jax.Array, shape (N,)) – Per-atom reciprocal-space energies.

  • forces (jax.Array, shape (N, 3), optional) – Per-atom forces (only if compute_forces=True).

  • charge_gradients (jax.Array, shape (N,), optional) – Per-atom charge gradients (only if compute_charge_gradients=True).

  • virial (jax.Array, shape (1, 3, 3) or (B, 3, 3), optional) – Virial tensor (only if compute_virial=True). Always last in the return tuple.

Return type:

Array | tuple[Array, Array] | tuple[Array, Array, Array] | tuple[Array, Array, Array, Array]

Notes

  • Output dtype for energy/forces matches the input positions dtype

  • FFT/convolution and spline operations all respect the input dtype

  • Automatically determines mesh_dimensions if not provided

  • Virial is computed in k-space and uses the same dtype as k_squared
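
The virial W = -dE/dε can be obtained by differentiating the energy with respect to a homogeneous strain. The sketch below deforms positions as r → r(I + ε) for a small non-periodic toy system in Gaussian units (pair_energy and virial_by_strain are hypothetical names); for a periodic system the cell would be deformed in the same way. For such an isolated system, W equals Σ_i r_i ⊗ F_i, which can serve as a consistency check.

```python
import jax
import jax.numpy as jnp


def pair_energy(positions, charges):
    """Undamped Coulomb energy of a small non-periodic system (Gaussian units)."""
    n = positions.shape[0]
    eye = jnp.eye(n, dtype=bool)
    diff = positions[:, None, :] - positions[None, :, :]
    r = jnp.sqrt(jnp.sum(diff**2, axis=-1) + eye)    # pad diagonal to avoid r = 0
    return 0.5 * jnp.sum(jnp.where(eye, 0.0, charges[:, None] * charges[None, :] / r))


def virial_by_strain(positions, charges):
    """W = -dE/d(epsilon) at epsilon = 0, with row-vector positions r -> r (I + epsilon)."""
    def strained_energy(eps):
        return pair_energy(positions @ (jnp.eye(3) + eps), charges)
    return -jax.grad(strained_energy)(jnp.zeros((3, 3)))
```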

nvalchemiops.jax.interactions.electrostatics.pme_green_structure_factor(k_squared, mesh_dimensions, alpha, cell, spline_order=4, batch_idx=None)

Compute Green’s function and B-spline structure factor correction.

Computes the Coulomb Green’s function with volume normalization and the B-spline aliasing correction factor for PME.

Green’s function (volume-normalized):

G(k) = (2π/V) * exp(-k²/(4α²)) / k²

Structure factor correction (for B-spline deconvolution):

C²(k) = [sinc(m_x/N_x) · sinc(m_y/N_y) · sinc(m_z/N_z)]^(2p)

where p is the spline order.
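
The two formulas above can be evaluated on the rfft grid as follows. This sketch covers a single system only, assumes the normalized sinc (jnp.sinc, sin(πx)/(πx)) and integer frequencies m in jnp.fft.fftfreq order, and uses the hypothetical name green_and_structure_factor; it is not the library's kernel.

```python
import jax.numpy as jnp


def green_and_structure_factor(cell, mesh_dimensions, alpha, spline_order=4):
    """Sketch of G(k) and C^2(k) on the rfft grid for one (3, 3) cell.

    G(k)   = (2 pi / V) exp(-k^2 / (4 alpha^2)) / k^2, with G(0) = 0
    C^2(k) = [sinc(mx/Nx) sinc(my/Ny) sinc(mz/Nz)]^(2p), normalized sinc assumed
    """
    nx, ny, nz = mesh_dimensions
    # Integer frequencies in fftfreq order (0, 1, ..., -2, -1); z keeps nz//2 + 1.
    mx = jnp.fft.fftfreq(nx) * nx
    my = jnp.fft.fftfreq(ny) * ny
    mz = jnp.fft.rfftfreq(nz) * nz
    m = jnp.stack(jnp.meshgrid(mx, my, mz, indexing="ij"), axis=-1)  # (nx, ny, nz//2+1, 3)
    recip = 2 * jnp.pi * jnp.linalg.inv(cell.T)                     # rows are a*, b*, c*
    k2 = jnp.sum((m @ recip) ** 2, axis=-1)
    volume = jnp.abs(jnp.linalg.det(cell))
    k2_safe = jnp.where(k2 == 0, 1.0, k2)                           # avoid 0/0 at k = 0
    green = (2 * jnp.pi / volume) * jnp.exp(-k2_safe / (4 * alpha**2)) / k2_safe
    green = jnp.where(k2 == 0, 0.0, green)                          # G(k=0) = 0
    sinc_prod = (jnp.sinc(m[..., 0] / nx) * jnp.sinc(m[..., 1] / ny)
                 * jnp.sinc(m[..., 2] / nz))
    return green, sinc_prod ** (2 * spline_order)
```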

Parameters:
  • k_squared (jax.Array) – |k|² values at each FFT grid point. Shape (Nx, Ny, Nz_rfft) for a single system; (B, Nx, Ny, Nz_rfft) for a batch.

  • mesh_dimensions (tuple[int, int, int]) – Full mesh dimensions (Nx, Ny, Nz) before rfft.

  • alpha (jax.Array) – Ewald splitting parameter. Shape (1,) or scalar for a single system; shape (B,) for a batch.

  • cell (jax.Array) – Unit cell matrices. Shape (3, 3) or (1, 3, 3) for a single system; (B, 3, 3) for a batch.

  • spline_order (int, default=4) – B-spline interpolation order (typically 4 for cubic B-splines).

  • batch_idx (jax.Array | None, default=None) – If provided, dispatches to batch kernels.

Returns:

  • green_function (jax.Array) – Volume-normalized Green’s function G(k). Shape (Nx, Ny, Nz_rfft) for a single system; (B, Nx, Ny, Nz_rfft) for a batch.

  • structure_factor_sq (jax.Array) – Squared structure factor C²(k) for B-spline deconvolution. Shape (Nx, Ny, Nz_rfft), shared across batch.

Return type:

tuple[Array, Array]

Notes

  • G(k=0) is set to zero to avoid the singularity at k = 0

  • Folding the volume normalization into G(k) avoids separate divisions by V later

  • The structure factor depends only on the mesh, so it is shared across the batch

nvalchemiops.jax.interactions.electrostatics.pme_energy_corrections(raw_energies, charges, cell, alpha, batch_idx=None)

Apply self-energy and background corrections to PME energies.

Converts the raw interpolated potential to energy and subtracts the corrections:

E_i = q_i φ_i - E_self,i - E_background,i

Self-energy correction (removes Gaussian self-interaction):

E_self,i = (α/√π) q_i²

Background correction (for non-neutral systems):

E_background,i = (π/(2α²V)) q_i Q_total
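
A per-atom sketch of the corrections for a single system, following the three formulas above. pme_corrections_sketch is a hypothetical name and takes the cell volume directly instead of the cell matrix. Summing the background term over atoms gives π Q_total²/(2α²V), which vanishes for a neutral system.

```python
import jax.numpy as jnp


def pme_corrections_sketch(raw_potential, charges, volume, alpha):
    """Per-atom E_i = q_i phi_i - E_self,i - E_background,i (single system)."""
    e_self = alpha / jnp.sqrt(jnp.pi) * charges**2                    # Gaussian self-interaction
    total_charge = jnp.sum(charges)
    e_background = jnp.pi / (2 * alpha**2 * volume) * charges * total_charge
    return charges * raw_potential - e_self - e_background
```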

Parameters:
  • raw_energies (jax.Array, shape (N,) or (N_total,)) – Raw potential values φ_i from mesh interpolation.

  • charges (jax.Array, shape (N,) or (N_total,)) – Atomic charges.

  • cell (jax.Array) – Unit cell matrices. Shape (3, 3) or (1, 3, 3) for a single system; (B, 3, 3) for a batch.

  • alpha (jax.Array) – Ewald splitting parameter. Shape (1,) or scalar for a single system; shape (B,) for a batch.

  • batch_idx (jax.Array | None, default=None) – System index for each atom. If provided, uses batch kernels.

Returns:

corrected_energies – Final per-atom reciprocal-space energy with corrections applied.

Return type:

jax.Array, shape (N,) or (N_total,)

Notes

  • For neutral systems, the background correction is zero

  • Supports both float32 and float64 dtypes

nvalchemiops.jax.interactions.electrostatics.pme_energy_corrections_with_charge_grad(raw_energies, charges, cell, alpha, batch_idx=None)

Apply energy corrections and compute charge gradients.

Same as pme_energy_corrections but also returns dE/dq for each atom.
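
For the correction terms alone, the charge derivatives have a closed form: the self term contributes 2α q_i/√π and the background term contributes π Q_total/(α²V). Because both corrections are subtracted from the energy, they enter dE/dq with a minus sign. The sketch below (hypothetical names, single system) compares the closed form with jax.grad; it does not include the q_i φ_i part.

```python
import jax
import jax.numpy as jnp


def correction_energy(charges, volume, alpha):
    """Total self + background correction: (alpha/sqrt(pi)) sum q^2 + pi Q^2 / (2 alpha^2 V)."""
    e_self = alpha / jnp.sqrt(jnp.pi) * jnp.sum(charges**2)
    q_total = jnp.sum(charges)
    e_background = jnp.pi / (2 * alpha**2 * volume) * q_total**2
    return e_self + e_background


def correction_charge_gradient(charges, volume, alpha):
    """Closed form of d(correction_energy)/dq_i."""
    return (2 * alpha / jnp.sqrt(jnp.pi)) * charges + jnp.pi * jnp.sum(charges) / (alpha**2 * volume)
```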

Parameters:
  • raw_energies (jax.Array, shape (N,) or (N_total,)) – Raw potential values φ_i from mesh interpolation.

  • charges (jax.Array, shape (N,) or (N_total,)) – Atomic charges.

  • cell (jax.Array) – Unit cell matrices. Shape (3, 3) or (1, 3, 3) for a single system; (B, 3, 3) for a batch.

  • alpha (jax.Array) – Ewald splitting parameter. Shape (1,) or scalar for a single system; shape (B,) for a batch.

  • batch_idx (jax.Array | None, default=None) – System index for each atom. If provided, uses batch kernels.

Returns:

  • corrected_energies (jax.Array, shape (N,) or (N_total,)) – Final per-atom reciprocal-space energy with corrections applied.

  • charge_gradients (jax.Array, shape (N,) or (N_total,)) – Per-atom charge gradients dE/dq.

Return type:

tuple[Array, Array]

Notes

  • Useful for training models that predict partial charges

  • Supports both float32 and float64 dtypes

K-Vector Generation

nvalchemiops.jax.interactions.electrostatics.generate_miller_indices(cell, k_cutoff)

Generate Miller index bounds for Ewald summation.

Parameters:
  • cell (jax.Array, shape (B, 3, 3)) – Unit cell matrices with lattice vectors as rows.

  • k_cutoff (float | jax.Array) – Maximum magnitude of k-vectors to include in reciprocal summation.

Returns:

Array of shape (3,) containing the maximum Miller indices (M_h, M_k, M_l) for each lattice direction.

Return type:

jax.Array

Notes

For batch mode, one shared set of Miller bounds is used for all systems. If k_cutoff is provided per system, the maximum cutoff across the batch is used to build those shared bounds.
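
One way to obtain bounds with this property: since k·a_i = 2π h_i for the lattice vector a_i, any k-vector with |k| ≤ k_cutoff has |h_i| ≤ k_cutoff |a_i|/(2π). The sketch below applies that bound with the largest cutoff in the batch, as the note describes. miller_bounds_sketch is hypothetical, and the library's rounding may differ (for example, rounding up).

```python
import jax.numpy as jnp


def miller_bounds_sketch(cell, k_cutoff):
    """Shared Miller bounds (M_h, M_k, M_l) for one cell or a (B, 3, 3) batch.

    Since k . a_i = 2 pi * h_i, every k with |k| <= k_cutoff satisfies
    |h_i| <= k_cutoff * |a_i| / (2 pi). With per-system cutoffs, the largest
    cutoff is applied to every cell, then the max over the batch is taken.
    """
    cell = jnp.asarray(cell).reshape(-1, 3, 3)
    k_max = jnp.max(jnp.asarray(k_cutoff))                # scalar or per-system cutoff
    lengths = jnp.linalg.norm(cell, axis=-1)              # (B, 3) lattice-vector lengths
    bounds = jnp.floor(k_max * lengths / (2 * jnp.pi))    # (B, 3)
    return tuple(int(b) for b in jnp.max(bounds, axis=0))
```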

nvalchemiops.jax.interactions.electrostatics.generate_k_vectors_ewald_summation(cell, k_cutoff, miller_bounds=None)

Generate reciprocal lattice vectors for Ewald summation (half-space).

Creates k-vectors within the specified cutoff for the reciprocal space summation in the Ewald method. Uses half-space optimization to reduce computational cost by approximately 2x.

Half-Space Optimization

This function generates k-vectors in the positive half-space only, exploiting the symmetry S(-k) = S*(k) where S(k) is the structure factor. For each pair of k-vectors (k, -k), only one is included.

The half-space condition selects k-vectors where:
  • h > 0, OR

  • (h == 0 AND k > 0), OR

  • (h == 0 AND k == 0 AND l > 0)

The kernels in ewald_kernels.py compensate by doubling the Green’s function (using \(8\pi\) instead of \(4\pi\)), so energies, forces, and charge gradients are computed correctly.
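
The half-space rule and the doubled prefactor can be checked numerically: because |S(-k)|² = |S(k)|² and the Gaussian weight depends only on |k|, the sum over the full k-space equals twice the sum over the half-space. A brute-force sketch with hypothetical helper names, Gaussian units, and a cubic box of Miller indices:

```python
import itertools

import jax.numpy as jnp


def in_half_space(h, k, l):
    """Half-space rule: h > 0, or (h == 0 and k > 0), or (h == k == 0 and l > 0)."""
    return h > 0 or (h == 0 and k > 0) or (h == 0 and k == 0 and l > 0)


def half_and_full_sums(positions, charges, cell, alpha, max_index=3):
    """Return (full-space sum, 2 * half-space sum, #full, #half) of G(k)|S(k)|^2."""
    recip = 2 * jnp.pi * jnp.linalg.inv(cell.T)          # rows are a*, b*, c*
    rng = range(-max_index, max_index + 1)
    full = [m for m in itertools.product(rng, rng, rng) if m != (0, 0, 0)]
    half = [m for m in full if in_half_space(*m)]

    def weighted_sum(millers):
        kvec = jnp.array(millers, dtype=positions.dtype) @ recip
        k2 = jnp.sum(kvec**2, axis=-1)
        phase = kvec @ positions.T
        s2 = (jnp.cos(phase) @ charges) ** 2 + (jnp.sin(phase) @ charges) ** 2
        return jnp.sum(jnp.exp(-k2 / (4 * alpha**2)) / k2 * s2)

    return weighted_sum(full), 2 * weighted_sum(half), len(full), len(half)
```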

Mathematical Background

For a direct lattice defined by basis vectors {a, b, c} (rows of cell matrix), the reciprocal lattice vectors are:

\[ \begin{aligned} \mathbf{a}^* &= \frac{2\pi (\mathbf{b} \times \mathbf{c})}{V} \\ \mathbf{b}^* &= \frac{2\pi (\mathbf{c} \times \mathbf{a})}{V} \\ \mathbf{c}^* &= \frac{2\pi (\mathbf{a} \times \mathbf{b})}{V} \end{aligned} \]

where \(V = \mathbf{a} \cdot (\mathbf{b} \times \mathbf{c})\) is the cell volume.

In matrix form: \(\mathrm{reciprocal\_matrix} = 2\pi \cdot (\text{cell}^T)^{-1}\)

Each k-vector is: \(\mathbf{k} = h \mathbf{a}^* + k \mathbf{b}^* + l \mathbf{c}^*\) where (h, k, l) are Miller indices (integers).
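
The matrix form can be checked directly: with lattice vectors as the rows of cell, the rows of 2π (cell^T)^{-1} are a*, b*, c*, so that cell · reciprocal_matrix^T = 2π I and k·a = 2π h. A short sketch with hypothetical helper names:

```python
import jax.numpy as jnp


def reciprocal_matrix(cell):
    """Rows are a*, b*, c* for a cell whose rows are a, b, c: 2 pi (cell^T)^{-1}."""
    return 2 * jnp.pi * jnp.linalg.inv(jnp.swapaxes(cell, -1, -2))


def k_vector(cell, h, k, l):
    """k = h a* + k b* + l c* for integer Miller indices (row-vector convention)."""
    return jnp.array([h, k, l], dtype=cell.dtype) @ reciprocal_matrix(cell)
```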

Parameters:
  • cell (jax.Array) – Unit cell matrix with lattice vectors as rows. Shape (3, 3) for single system or (B, 3, 3) for batch.

  • k_cutoff (float or jax.Array) – Maximum magnitude of k-vectors to include (\(|\mathbf{k}| \leq k_{\text{cutoff}}\)). Typical values: 8-12 \(\text{\AA}^{-1}\) for molecular systems. Higher values increase accuracy but also computational cost.

  • miller_bounds (tuple[int, int, int] | None, optional) – Precomputed maximum Miller indices (M_h, M_k, M_l) for each lattice direction. When provided, the function skips the internal computation of bounds from cell and k_cutoff, making it compatible with jax.jit (which requires static array shapes). Use generate_miller_indices() to compute these bounds eagerly before entering a JIT context. When None (default), bounds are computed automatically from cell and k_cutoff.

Returns:

Reciprocal lattice vectors within the cutoff. Shape (K, 3) for single system or (B, K, 3) for batch. Excludes k=0 and includes only half-space vectors.

Return type:

jax.Array

Examples

Single system with explicit k_cutoff:

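>>> import jax.numpy as jnp
>>> from nvalchemiops.jax.interactions.electrostatics import generate_k_vectors_ewald_summation
>>> cell = jnp.eye(3, dtype=jnp.float64) * 10.0
>>> k_vectors = generate_k_vectors_ewald_summation(cell, k_cutoff=8.0)
>>> k_vectors.shape[-1]  # shape is (K, 3); K depends on the cell size and cutoff
3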

With automatic parameter estimation:

>>> from nvalchemiops.jax.interactions.electrostatics import estimate_ewald_parameters
>>> params = estimate_ewald_parameters(positions, cell)
>>> k_vectors = generate_k_vectors_ewald_summation(cell, params.reciprocal_space_cutoff)

JIT-compatible usage with precomputed bounds:

>>> from nvalchemiops.jax.interactions.electrostatics import generate_miller_indices
>>> cell = jnp.eye(3, dtype=jnp.float64)[None, ...] * 10.0
>>> bounds = generate_miller_indices(cell, k_cutoff=8.0)
>>> miller_bounds = (int(bounds[0]), int(bounds[1]), int(bounds[2]))
>>> # This can now be called inside @jax.jit
>>> k_vectors = generate_k_vectors_ewald_summation(cell, k_cutoff=8.0, miller_bounds=miller_bounds)

Notes

  • The k=0 vector is always excluded (causes division by zero in Green’s function).

  • For batch mode, the same set of Miller indices is used for all systems but transformed using each system’s reciprocal cell. If k_cutoff is given per system, the maximum cutoff across the batch determines the shared Miller bounds.

  • The number of k-vectors K scales as O(k_cutoff³ · V) where V is the cell volume.

  • When using inside jax.jit, you must provide miller_bounds as a concrete tuple[int, int, int]. The bounds determine array shapes (via jnp.arange), which must be statically known at trace time.
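
The static-shape requirement can be illustrated without the library: jnp.arange needs concrete sizes, so a jitted function that builds a Miller-index grid must treat the bounds as a static argument. miller_grid below is a hypothetical example; each distinct bounds tuple triggers one compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("miller_bounds",))
def miller_grid(miller_bounds):
    """All (h, k, l) with |h| <= M_h, |k| <= M_k, |l| <= M_l.

    miller_bounds is static (a hashable tuple of ints), so the sizes passed
    to jnp.arange are concrete at trace time.
    """
    axes = [jnp.arange(-m, m + 1) for m in miller_bounds]
    grid = jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1)
    return grid.reshape(-1, 3)
```

Presumably the same pattern, marking miller_bounds as static in a jitted wrapper, applies when forwarding it to generate_k_vectors_ewald_summation or ewald_summation.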

See also

ewald_reciprocal_space

Uses these k-vectors for reciprocal space energy.

estimate_ewald_parameters

Automatic parameter estimation including k_cutoff.

generate_miller_indices

Compute Miller bounds for JIT-compatible usage.


nvalchemiops.jax.interactions.electrostatics.generate_k_vectors_pme(cell, mesh_dimensions, reciprocal_cell=None)

Generate reciprocal lattice vectors for Particle Mesh Ewald (PME).

Creates k-vectors on a regular grid compatible with FFT-based reciprocal space calculations in PME. Uses rfft conventions (half-size in z-dimension) to exploit Hermitian symmetry of real-valued charge densities.

Notes

For a direct lattice defined by basis vectors {a, b, c} (rows of cell matrix), the reciprocal lattice vectors are:

\[\begin{split}\begin{aligned} \mathbf{a}^* &= \frac{2\pi (\mathbf{b} \times \mathbf{c})}{V} \\ \mathbf{b}^* &= \frac{2\pi (\mathbf{c} \times \mathbf{a})}{V} \\ \mathbf{c}^* &= \frac{2\pi (\mathbf{a} \times \mathbf{b})}{V} \end{aligned}\end{split}\]

where \(V = \mathbf{a} \cdot (\mathbf{b} \times \mathbf{c})\) is the cell volume.

In matrix form:

\[\mathrm{reciprocal\_matrix} = 2\pi \cdot (\text{cell}^T)^{-1}\]

Each k-vector is then:

\[\mathbf{k} = h \mathbf{a}^* + k \mathbf{b}^* + l \mathbf{c}^*\]

where (h, k, l) are Miller indices (integers).

Parameters:
  • cell (jax.Array) – Unit cell matrix with lattice vectors as rows. Shape (3, 3) for single system or (B, 3, 3) for batch.

  • mesh_dimensions (tuple[int, int, int]) – PME mesh grid dimensions (nx, ny, nz). Should typically be chosen such that mesh spacing is \(\sim 1 \text{\AA}\) or finer. Power-of-2 dimensions are optimal for FFT performance.

  • reciprocal_cell (jax.Array, optional) – Precomputed reciprocal cell matrix (\(2\pi \cdot \text{cell}^{-1}\)). If provided, skips the inverse computation. Shape (3, 3) or (B, 3, 3).

Returns:

  • k_vectors (jax.Array, shape (nx, ny, nz//2+1, 3)) – Cartesian k-vectors at each grid point. Following the rfft convention, only nz//2+1 points are stored along z because of Hermitian symmetry.

  • k_squared_safe (jax.Array, shape (nx, ny, nz//2+1)) – Squared magnitude \(|\mathbf{k}|^2\) for each k-vector, with k=0 set to a small positive value (1e-12) to avoid division by zero.

Return type:

tuple[Array, Array]

Examples

Basic usage:

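>>> import jax
>>> import jax.numpy as jnp
>>> from nvalchemiops.jax.interactions.electrostatics import generate_k_vectors_pme
>>> jax.config.update("jax_enable_x64", True)  # honor the float64 dtype below
>>> cell = jnp.eye(3, dtype=jnp.float64) * 10.0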
>>> mesh_dims = (32, 32, 32)
>>> k_vectors, k_squared = generate_k_vectors_pme(cell, mesh_dims)
>>> k_vectors.shape
(32, 32, 17, 3)

With precomputed reciprocal cell:

>>> reciprocal_cell = 2 * jnp.pi * jnp.linalg.inv(cell)
>>> k_vectors, k_squared = generate_k_vectors_pme(
...     cell, mesh_dims, reciprocal_cell=reciprocal_cell
... )

Notes

  • The z-dimension output size is nz//2+1 due to rfft symmetry.

  • Miller indices along x and y follow the jnp.fft.fftfreq convention (0, 1, 2, …, -2, -1); along z they follow jnp.fft.rfftfreq (0, 1, …, nz//2).

  • k_squared_safe has k=0 replaced with 1e-12 to prevent division by zero in Green’s function calculations.
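The conventions listed in these notes can be reproduced in a few lines of jax.numpy. pme_k_grid_sketch is a hypothetical re-derivation under the documented conventions, not the library's implementation. It takes the rows of 2π (cell^T)^-1 as a*, b*, c*, uses fftfreq order on x and y and rfftfreq order on z, and replaces k = 0 with 1e-12. It also ignores the optional reciprocal_cell argument.

```python
import jax.numpy as jnp


def pme_k_grid_sketch(cell, mesh_dimensions):
    """Hypothetical re-derivation of the documented PME k-grid conventions."""
    nx, ny, nz = mesh_dimensions
    # Integer Miller indices: fftfreq order on x and y, rfftfreq order on z.
    mx = jnp.round(jnp.fft.fftfreq(nx) * nx)   # 0, 1, ..., -2, -1
    my = jnp.round(jnp.fft.fftfreq(ny) * ny)
    mz = jnp.round(jnp.fft.rfftfreq(nz) * nz)  # 0, 1, ..., nz // 2
    miller = jnp.stack(jnp.meshgrid(mx, my, mz, indexing="ij"), axis=-1)
    reciprocal = 2.0 * jnp.pi * jnp.linalg.inv(cell.T)  # rows are a*, b*, c*
    k_vectors = miller.astype(cell.dtype) @ reciprocal   # (nx, ny, nz // 2 + 1, 3)
    k_squared = jnp.sum(k_vectors**2, axis=-1)
    # Replace the k = 0 entry so a later 1 / k^2 Green's function stays finite.
    k_squared_safe = jnp.where(k_squared == 0, 1e-12, k_squared)
    return k_vectors, k_squared_safe
```

With the 10 Å cubic cell and (32, 32, 32) mesh from the example, this sketch returns arrays of shape (32, 32, 17, 3) and (32, 32, 17), matching the documented output shapes.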

See also

pme_reciprocal_space

Uses these k-vectors for PME reciprocal space energy.

pme_green_structure_factor

Computes Green’s function using k_squared.

Parameter Estimation

Functions for automatic parameter estimation from a target accuracy.

nvalchemiops.jax.interactions.electrostatics.estimate_ewald_parameters(positions, cell, batch_idx=None, accuracy=1e-6)

Estimate optimal Ewald summation parameters for a given accuracy.

Uses the Kolafa-Perram formula to balance real-space and reciprocal-space contributions for optimal efficiency at the target accuracy.

Parameters:
  • positions (jax.Array, shape (N, 3)) – Atomic coordinates.

  • cell (jax.Array, shape (3, 3) or (B, 3, 3)) – Unit cell matrix.

  • batch_idx (jax.Array, shape (N,), dtype=int32, optional) – System index for each atom. If None, single-system mode.

  • accuracy (float, default=1e-6) – Target accuracy (relative error tolerance).

Returns:

An EwaldParameters dataclass whose alpha, real_space_cutoff, and reciprocal_space_cutoff fields are jax.Array objects of shape (B,).

Return type:

EwaldParameters
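To give a sense of what such an estimate returns, here is a deliberately simplified sketch. It is not the Kolafa-Perram formula used by estimate_ewald_parameters. Instead it uses the commonly quoted cost-balancing choice α = √π (N/V²)^(1/6) and sets only the leading exponential error factors, exp(-α² r_c²) and exp(-k_c² / (4α²)), equal to the target accuracy, dropping the prefactors a full error estimate carries. The function name and the num_systems argument are hypothetical, and the sketch uses atom counts from batch_idx instead of positions. The outputs have the per-system shape (B,) used by EwaldParameters.

```python
import jax.numpy as jnp


def ewald_parameters_sketch(cell, batch_idx, num_systems, accuracy=1e-6):
    """Illustrative heuristic with per-system (B,) outputs; NOT the library's formula."""
    volume = jnp.abs(jnp.linalg.det(cell))                 # (B,) cell volumes
    n_atoms = jnp.bincount(batch_idx, length=num_systems)  # (B,) atoms per system
    # Commonly quoted cost-balancing choice: alpha = sqrt(pi) * (N / V^2)^(1/6).
    alpha = jnp.sqrt(jnp.pi) * (n_atoms / volume**2) ** (1.0 / 6.0)
    log_term = jnp.sqrt(-jnp.log(accuracy))
    # Pick cutoffs so exp(-alpha^2 r_c^2) and exp(-k_c^2 / (4 alpha^2)) equal `accuracy`.
    real_space_cutoff = log_term / alpha
    reciprocal_space_cutoff = 2.0 * alpha * log_term
    return alpha, real_space_cutoff, reciprocal_space_cutoff
```

Raising α shortens the real-space cutoff and lengthens the reciprocal-space cutoff, which is the trade-off any estimator of this kind has to balance.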

nvalchemiops.jax.interactions.electrostatics.estimate_pme_parameters(positions, cell, batch_idx=None, accuracy=1e-6)

Estimate optimal PME parameters for a given accuracy.

Parameters:
  • positions (jax.Array, shape (N, 3)) – Atomic coordinates.

  • cell (jax.Array, shape (3, 3) or (B, 3, 3)) – Unit cell matrix.

  • batch_idx (jax.Array, shape (N,), dtype=int32, optional) – System index for each atom.

  • accuracy (float, default=1e-6) – Target accuracy.

Returns:

A PMEParameters dataclass containing alpha, mesh_dimensions, mesh_spacing, and real_space_cutoff. The array-valued fields are jax.Array objects; mesh_dimensions is a tuple of ints.

Return type:

PMEParameters

nvalchemiops.jax.interactions.electrostatics.estimate_pme_mesh_dimensions(cell, alpha, accuracy=1e-6)

Estimate optimal PME mesh dimensions for a given accuracy.

Parameters:
  • cell (jax.Array, shape (3, 3) or (B, 3, 3)) – Unit cell matrix.

  • alpha (jax.Array, shape (B,)) – Ewald splitting parameter.

  • accuracy (float, default=1e-6) – Target accuracy.

Returns:

Maximum mesh dimensions (nx, ny, nz) across all systems in the batch.

Return type:

tuple[int, int, int]

nvalchemiops.jax.interactions.electrostatics.mesh_spacing_to_dimensions(cell, mesh_spacing)

Convert mesh spacing to mesh dimensions.

Returns:

Mesh dimensions (nx, ny, nz), each rounded up to a power of 2.

Return type:

tuple[int, int, int]
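Here is a minimal sketch of the spacing-to-dimensions conversion. It assumes a scalar spacing and takes the number of points along each direction as ceil(|a_i| / mesh_spacing). It then takes the maximum over a batch, in the same spirit as the batch maximum returned by estimate_pme_mesh_dimensions, and rounds up to the next power of 2. The library may measure spacing differently, for example along the perpendicular cell heights, and the function names here are hypothetical. The computation runs on the host with NumPy because the result must be plain Python ints.

```python
import numpy as np


def next_power_of_two(n):
    """Smallest power of 2 that is >= n (for n >= 1), using exact integer math."""
    return 1 << (int(n) - 1).bit_length()


def mesh_spacing_to_dimensions_sketch(cell, mesh_spacing):
    """Hypothetical: dims from lattice-vector lengths, shared across the batch."""
    cell = np.asarray(cell, dtype=float).reshape(-1, 3, 3)  # (B, 3, 3)
    lengths = np.linalg.norm(cell, axis=-1)                 # (B, 3): |a|, |b|, |c|
    points = np.ceil(lengths / mesh_spacing).astype(int)    # points needed per direction
    largest = points.max(axis=0)                            # one mesh for every system
    # Plain Python ints, so the result can serve as a static shape under jax.jit.
    return tuple(next_power_of_two(n) for n in largest)
```

For example, a 10 Å cubic cell with a 1.0 Å spacing needs 10 points per direction, which rounds up to (16, 16, 16).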

class nvalchemiops.jax.interactions.electrostatics.EwaldParameters(alpha, real_space_cutoff, reciprocal_space_cutoff)

Container for Ewald summation parameters.

All values are arrays of shape (B,); for single-system calculations, the shape is (1,).

alpha

Ewald splitting parameter (inverse length units).

Type:

jax.Array, shape (B,)

real_space_cutoff

Real-space cutoff distance.

Type:

jax.Array, shape (B,)

reciprocal_space_cutoff

Reciprocal-space cutoff (\(|k|\) in inverse length units).

Type:

jax.Array, shape (B,)

class nvalchemiops.jax.interactions.electrostatics.PMEParameters(alpha, mesh_dimensions, mesh_spacing, real_space_cutoff)

Container for PME parameters.

alpha

Ewald splitting parameter.

Type:

jax.Array, shape (B,)

mesh_dimensions

Mesh dimensions (nx, ny, nz).

Type:

tuple[int, int, int]

mesh_spacing

Actual mesh spacing in each direction.

Type:

jax.Array, shape (B, 3)

real_space_cutoff

Real-space cutoff distance.

Type:

jax.Array, shape (B,)
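One reason a plain tuple is convenient for mesh_dimensions under jax.jit is that array shapes must be static, so mesh sizes have to be available as Python ints at trace time. The sketch below shows one way to express this with a pytree: the arrays are leaves that get traced, and mesh_dimensions is static auxiliary data. PMEParamsSketch and empty_mesh are hypothetical. This does not claim that PMEParameters is registered as a pytree or implemented this way.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class PMEParamsSketch:
    """Hypothetical container mirroring the PMEParameters fields (not the library class)."""

    def __init__(self, alpha, mesh_dimensions, mesh_spacing, real_space_cutoff):
        self.alpha = alpha                             # jax.Array, shape (B,)
        self.mesh_dimensions = tuple(mesh_dimensions)  # Python ints -> static metadata
        self.mesh_spacing = mesh_spacing               # jax.Array, shape (B, 3)
        self.real_space_cutoff = real_space_cutoff     # jax.Array, shape (B,)

    def tree_flatten(self):
        # Arrays become traced leaves; the mesh shape travels as auxiliary (static) data.
        children = (self.alpha, self.mesh_spacing, self.real_space_cutoff)
        return children, self.mesh_dimensions

    @classmethod
    def tree_unflatten(cls, mesh_dimensions, children):
        alpha, mesh_spacing, real_space_cutoff = children
        return cls(alpha, mesh_dimensions, mesh_spacing, real_space_cutoff)


@jax.jit
def empty_mesh(params):
    # mesh_dimensions is concrete inside jit, so it can define an array shape.
    batch = params.alpha.shape[0]
    return jnp.zeros((batch, *params.mesh_dimensions))
```

Changing mesh_dimensions triggers a retrace, as any shape change would, while new alpha or cutoff values reuse the compiled function.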