# Source code for nvalchemiops.jax.interactions.electrostatics.pme

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JAX Particle Mesh Ewald (PME) implementation.

This module provides JAX bindings for PME long-range electrostatics calculations.
PME achieves O(N log N) scaling through FFT-based reciprocal space computation
combined with real-space Ewald summation.

The implementation uses:
- JAX FFT operations (jnp.fft.rfftn/irfftn)
- B-spline interpolation from nvalchemiops.jax.spline
- Ewald real-space from nvalchemiops.jax.interactions.electrostatics.ewald
- Warp kernels for Green's function and energy corrections

Key Functions
-------------
particle_mesh_ewald : Complete PME calculation (real + reciprocal space)
pme_reciprocal_space : Reciprocal-space component only
pme_green_structure_factor : Green's function and structure factor
pme_energy_corrections : Self-energy and background corrections

See Also
--------
nvalchemiops.jax.interactions.electrostatics.ewald : Ewald real-space
nvalchemiops.jax.spline : B-spline interpolation
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental import jax_kernel

from nvalchemiops.interactions.electrostatics.pme_kernels import (
    _batch_pme_energy_corrections_kernel_overload,
    _batch_pme_energy_corrections_with_charge_grad_kernel_overload,
    _batch_pme_green_structure_factor_kernel_overload,
    _pme_energy_corrections_kernel_overload,
    _pme_energy_corrections_with_charge_grad_kernel_overload,
    _pme_green_structure_factor_kernel_overload,
)
from nvalchemiops.jax.interactions.electrostatics.ewald import ewald_real_space
from nvalchemiops.jax.interactions.electrostatics.k_vectors import (
    generate_k_vectors_pme,
)
from nvalchemiops.jax.interactions.electrostatics.parameters import (
    estimate_pme_mesh_dimensions,
    estimate_pme_parameters,
    mesh_spacing_to_dimensions,
)
from nvalchemiops.jax.spline import (
    spline_gather,
    spline_gather_vec3,
    spline_spread,
)

__all__ = [
    "particle_mesh_ewald",
    "pme_reciprocal_space",
    "pme_green_structure_factor",
    "pme_energy_corrections",
    "pme_energy_corrections_with_charge_grad",
]


# ==============================================================================
# Helper Function for JAX Kernel Creation
# ==============================================================================


def _make_jax_kernels(
    wp_overload_dict: dict,
    num_outputs: int,
    in_out_argnames: list[str],
) -> dict:
    """Wrap the float32 and float64 Warp overloads as JAX-callable kernels.

    Parameters
    ----------
    wp_overload_dict : dict
        Warp kernel overloads, looked up by ``wp.float32`` and ``wp.float64``.
    num_outputs : int
        Number of output arrays the kernel returns.
    in_out_argnames : list of str
        Names of the kernel arguments that act as in-place outputs.

    Returns
    -------
    dict
        ``{jnp.float32: kernel, jnp.float64: kernel}`` where each value is the
        ``jax_kernel`` wrapper of the matching Warp overload.
    """
    # (JAX dtype used as lookup key, Warp dtype used to pick the overload)
    precision_pairs = (
        (jnp.float32, wp.float32),
        (jnp.float64, wp.float64),
    )

    wrapped = {}
    for jax_dtype, wp_dtype in precision_pairs:
        # All wrappers are created with the backward pass disabled.
        wrapped[jax_dtype] = jax_kernel(
            wp_overload_dict[wp_dtype],
            num_outputs=num_outputs,
            in_out_argnames=in_out_argnames,
            enable_backward=False,
        )
    return wrapped


def _normalize_dtype(dtype):
    """Normalize dtype for kernel dictionary lookup.

    Parameters
    ----------
    dtype : dtype-like
        Input dtype from a JAX array.

    Returns
    -------
    jnp.float32 or jnp.float64
        Normalized JAX dtype for kernel lookup.
    """
    if dtype == jnp.float32 or str(dtype) == "float32":
        return jnp.float32
    elif dtype == jnp.float64 or str(dtype) == "float64":
        return jnp.float64
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")


# ==============================================================================
# JAX Kernel Wrappers
# ==============================================================================

# Single-system kernels.
# Each table maps jnp.float32 / jnp.float64 to the wrapped Warp kernel; the
# listed argument names are the buffers the kernel fills and returns.
_jax_pme_green_sf = _make_jax_kernels(
    wp_overload_dict=_pme_green_structure_factor_kernel_overload,
    num_outputs=2,
    in_out_argnames=["green_function", "structure_factor_sq"],
)

_jax_pme_energy_corrections = _make_jax_kernels(
    wp_overload_dict=_pme_energy_corrections_kernel_overload,
    num_outputs=1,
    in_out_argnames=["corrected_energies"],
)

_jax_pme_energy_corrections_charge_grad = _make_jax_kernels(
    wp_overload_dict=_pme_energy_corrections_with_charge_grad_kernel_overload,
    num_outputs=2,
    in_out_argnames=["corrected_energies", "charge_gradients"],
)

# Batch kernels (same outputs as their single-system counterparts).
_jax_batch_pme_green_sf = _make_jax_kernels(
    wp_overload_dict=_batch_pme_green_structure_factor_kernel_overload,
    num_outputs=2,
    in_out_argnames=["green_function", "structure_factor_sq"],
)

_jax_batch_pme_energy_corrections = _make_jax_kernels(
    wp_overload_dict=_batch_pme_energy_corrections_kernel_overload,
    num_outputs=1,
    in_out_argnames=["corrected_energies"],
)

_jax_batch_pme_energy_corrections_charge_grad = _make_jax_kernels(
    wp_overload_dict=_batch_pme_energy_corrections_with_charge_grad_kernel_overload,
    num_outputs=2,
    in_out_argnames=["corrected_energies", "charge_gradients"],
)


# ==============================================================================
# Public API Functions
# ==============================================================================


def pme_green_structure_factor(
    k_squared: jax.Array,
    mesh_dimensions: tuple[int, int, int],
    alpha: jax.Array,
    cell: jax.Array,
    spline_order: int = 4,
    batch_idx: jax.Array | None = None,
) -> tuple[jax.Array, jax.Array]:
    """Evaluate the PME Green's function and B-spline structure factor.

    The work is delegated to a Warp kernel which produces, on the rfft
    half-mesh, the volume-normalized Coulomb Green's function

        G(k) = (2π/V) * exp(-k²/(4α²)) / k²

    and the squared B-spline correction used for deconvolution

        C²(k) = [sinc(m_x/N_x) · sinc(m_y/N_y) · sinc(m_z/N_z)]^(2p)

    with p the spline order.

    Parameters
    ----------
    k_squared : jax.Array
        |k|² at every FFT grid point.

        - Single-system: shape (Nx, Ny, Nz_rfft)
        - Batch: shape (B, Nx, Ny, Nz_rfft); a 3D array is broadcast over
          the batch axis.
    mesh_dimensions : tuple[int, int, int]
        Full mesh size (Nx, Ny, Nz) before the real FFT.
    alpha : jax.Array
        Ewald splitting parameter.

        - Single-system: scalar or shape (1,)
        - Batch: shape (B,)
    cell : jax.Array
        Unit cell matrices.

        - Single-system: shape (3, 3) or (1, 3, 3)
        - Batch: shape (B, 3, 3)
    spline_order : int, default=4
        B-spline interpolation order (4 = cubic).
    batch_idx : jax.Array | None, default=None
        Only tested against ``None``: any other value selects the batch
        kernel.

    Returns
    -------
    green_function : jax.Array
        G(k), shape (Nx, Ny, Nz_rfft) or (B, Nx, Ny, Nz_rfft).
    structure_factor_sq : jax.Array
        C²(k), shape (Nx, Ny, Nz_rfft). It is mesh-dependent only and thus
        shared across the batch.

    Notes
    -----
    - G(k=0) is set to zero to avoid the singularity.
    - Volume normalization inside G(k) removes later divisions.
    """
    nx, ny, nz = mesh_dimensions
    half_shape = (nx, ny, nz // 2 + 1)
    real_dtype = _normalize_dtype(k_squared.dtype)

    # The determinant is taken over a leading system axis, so promote (3, 3).
    cells = cell[jnp.newaxis, :, :] if cell.ndim == 2 else cell
    volume = jnp.abs(jnp.linalg.det(cells)).astype(real_dtype)

    # With sample spacing d = 1/n the FFT frequencies are the integer
    # Miller indices; z uses the rfft layout.
    miller_x, miller_y = (
        jnp.fft.fftfreq(n, d=1.0 / n).astype(real_dtype) for n in (nx, ny)
    )
    miller_z = jnp.fft.rfftfreq(nz, d=1.0 / nz).astype(real_dtype)

    # The kernels take alpha as a 1D array.
    alpha_1d = alpha.reshape(1) if alpha.ndim == 0 else alpha
    alpha_1d = alpha_1d.astype(real_dtype)

    k_sq = k_squared.astype(real_dtype)
    sf_buffer = jnp.zeros(half_shape, dtype=real_dtype)

    if batch_idx is None:
        # Single system: one thread per half-mesh point.
        green_out, sf_out = _jax_pme_green_sf[real_dtype](
            k_sq,
            miller_x,
            miller_y,
            miller_z,
            alpha_1d,
            volume,
            int(nx),
            int(ny),
            int(nz),
            int(spline_order),
            jnp.zeros(half_shape, dtype=real_dtype),
            sf_buffer,
            launch_dims=half_shape,
        )
        return green_out, sf_out

    # Batch: one Green's function per system, one shared structure factor.
    num_systems = cells.shape[0]
    batched_shape = (num_systems, nx, ny, nz // 2 + 1)
    if k_sq.ndim == 3:
        k_sq = jnp.broadcast_to(k_sq[jnp.newaxis], batched_shape)

    green_out, sf_out = _jax_batch_pme_green_sf[real_dtype](
        k_sq,
        miller_x,
        miller_y,
        miller_z,
        alpha_1d,
        volume,
        int(nx),
        int(ny),
        int(nz),
        int(spline_order),
        jnp.zeros(batched_shape, dtype=real_dtype),
        sf_buffer,
        launch_dims=batched_shape,
    )
    return green_out, sf_out
def pme_energy_corrections(
    raw_energies: jax.Array,
    charges: jax.Array,
    cell: jax.Array,
    alpha: jax.Array,
    batch_idx: jax.Array | None = None,
) -> jax.Array:
    """Turn interpolated mesh potentials into corrected per-atom energies.

    The Warp kernel converts the raw potential into an energy and removes
    the self and background terms:

        E_i = q_i φ_i - E_self,i - E_background,i

    with the Gaussian self-interaction

        E_self,i = (α/√π) q_i²

    and the neutralizing-background term for charged systems

        E_background,i = (π/(2α²V)) q_i Q_total

    Parameters
    ----------
    raw_energies : jax.Array, shape (N,) or (N_total,)
        Potential φ_i gathered from the mesh. Its dtype selects the kernel
        precision and the output dtype.
    charges : jax.Array, shape (N,) or (N_total,)
        Atomic charges.
    cell : jax.Array
        Unit cell matrices.

        - Single-system: shape (3, 3) or (1, 3, 3)
        - Batch: shape (B, 3, 3)
    alpha : jax.Array
        Ewald splitting parameter.

        - Single-system: scalar or shape (1,)
        - Batch: shape (B,)
    batch_idx : jax.Array | None, default=None
        System index of every atom; when given, the batch kernel is used.

    Returns
    -------
    corrected_energies : jax.Array, shape (N,) or (N_total,)
        Per-atom reciprocal-space energy after both corrections.

    Notes
    -----
    - The background term vanishes for neutral systems.
    - float32 and float64 inputs are supported.
    """
    real_dtype = _normalize_dtype(raw_energies.dtype)
    num_atoms = raw_energies.shape[0]

    # The kernels take alpha as a 1D array.
    alpha_1d = alpha.reshape(1) if alpha.ndim == 0 else alpha
    alpha_1d = alpha_1d.astype(real_dtype)

    # Volumes are |det(cell)| per system, so a bare (3, 3) cell is promoted.
    cells = cell[jnp.newaxis, :, :] if cell.ndim == 2 else cell
    volumes = jnp.abs(jnp.linalg.det(cells)).astype(real_dtype)

    potentials = raw_energies.astype(real_dtype)
    q = charges.astype(real_dtype)
    out_buffer = jnp.zeros(num_atoms, dtype=real_dtype)

    if batch_idx is None:
        # Single system: the net charge is summed first, then cast.
        net_charge = charges.sum().reshape(1).astype(real_dtype)
        (corrected_out,) = _jax_pme_energy_corrections[real_dtype](
            potentials,
            q,
            volumes,
            alpha_1d,
            net_charge,
            out_buffer,
            launch_dims=(num_atoms,),
        )
        return corrected_out

    # Batch: accumulate each system's net charge with a scatter-add.
    num_systems = cells.shape[0] if cells.ndim == 3 else 1
    net_charges = jnp.zeros(num_systems, dtype=real_dtype).at[batch_idx].add(q)

    (corrected_out,) = _jax_batch_pme_energy_corrections[real_dtype](
        potentials,
        q,
        batch_idx.astype(jnp.int32),
        volumes,
        alpha_1d,
        net_charges,
        out_buffer,
        launch_dims=(num_atoms,),
    )
    return corrected_out
def pme_energy_corrections_with_charge_grad(
    raw_energies: jax.Array,
    charges: jax.Array,
    cell: jax.Array,
    alpha: jax.Array,
    batch_idx: jax.Array | None = None,
) -> tuple[jax.Array, jax.Array]:
    """Correct PME energies and also return the charge gradients dE/dq.

    Behaves like ``pme_energy_corrections`` but dispatches to the kernel
    variant that additionally fills a per-atom charge-gradient array.

    Parameters
    ----------
    raw_energies : jax.Array, shape (N,) or (N_total,)
        Potential φ_i gathered from the mesh. Its dtype selects the kernel
        precision and the output dtype.
    charges : jax.Array, shape (N,) or (N_total,)
        Atomic charges.
    cell : jax.Array
        Unit cell matrices.

        - Single-system: shape (3, 3) or (1, 3, 3)
        - Batch: shape (B, 3, 3)
    alpha : jax.Array
        Ewald splitting parameter.

        - Single-system: scalar or shape (1,)
        - Batch: shape (B,)
    batch_idx : jax.Array | None, default=None
        System index of every atom; when given, the batch kernel is used.

    Returns
    -------
    corrected_energies : jax.Array, shape (N,) or (N_total,)
        Per-atom reciprocal-space energy after corrections.
    charge_gradients : jax.Array, shape (N,) or (N_total,)
        Per-atom dE/dq.

    Notes
    -----
    - Useful when training models that predict partial charges.
    - float32 and float64 inputs are supported.
    """
    real_dtype = _normalize_dtype(raw_energies.dtype)
    num_atoms = raw_energies.shape[0]

    # The kernels take alpha as a 1D array.
    alpha_1d = alpha.reshape(1) if alpha.ndim == 0 else alpha
    alpha_1d = alpha_1d.astype(real_dtype)

    # Volumes are |det(cell)| per system, so a bare (3, 3) cell is promoted.
    cells = cell[jnp.newaxis, :, :] if cell.ndim == 2 else cell
    volumes = jnp.abs(jnp.linalg.det(cells)).astype(real_dtype)

    potentials = raw_energies.astype(real_dtype)
    q = charges.astype(real_dtype)
    energy_buffer = jnp.zeros(num_atoms, dtype=real_dtype)
    grad_buffer = jnp.zeros(num_atoms, dtype=real_dtype)

    if batch_idx is None:
        # Single system: the net charge is summed first, then cast.
        net_charge = charges.sum().reshape(1).astype(real_dtype)
        corrected_out, charge_grad_out = _jax_pme_energy_corrections_charge_grad[
            real_dtype
        ](
            potentials,
            q,
            volumes,
            alpha_1d,
            net_charge,
            energy_buffer,
            grad_buffer,
            launch_dims=(num_atoms,),
        )
        return corrected_out, charge_grad_out

    # Batch: accumulate each system's net charge with a scatter-add.
    num_systems = cells.shape[0] if cells.ndim == 3 else 1
    net_charges = jnp.zeros(num_systems, dtype=real_dtype).at[batch_idx].add(q)

    corrected_out, charge_grad_out = _jax_batch_pme_energy_corrections_charge_grad[
        real_dtype
    ](
        potentials,
        q,
        batch_idx.astype(jnp.int32),
        volumes,
        alpha_1d,
        net_charges,
        energy_buffer,
        grad_buffer,
        launch_dims=(num_atoms,),
    )
    return corrected_out, charge_grad_out
def _compute_pme_reciprocal_virial(
    mesh_fft_raw: jax.Array,
    convolved_mesh: jax.Array,
    k_vectors: jax.Array,
    k_squared: jax.Array,
    alpha: jax.Array,
    mesh_dimensions: tuple[int, int, int],
    is_batch: bool,
) -> jax.Array:
    """Compute the PME reciprocal-space virial tensor in k-space.

    Uses the exact spectral pair from the pipeline (``mesh_fft_raw`` before
    deconvolution and ``convolved_mesh`` after multiplication by the Green's
    function) to obtain the per-k energy density directly.

    The virial per k-point is W_ab(k) = E_k * sigma_ab(k) where:

    - E_k = weight(k) * Re(mesh_fft_raw(k) * convolved_mesh(k)*)
    - sigma_ab(k) = delta_ab - 2*k_a*k_b/k^2 * (1 + k^2/(4*alpha^2))
      (sign reflects the W = -dE/dε convention)

    Parameters
    ----------
    mesh_fft_raw : jax.Array
        Raw rfftn output before B-spline deconvolution.
        Shape (nx, ny, nz//2+1) or (B, nx, ny, nz//2+1), complex.
    convolved_mesh : jax.Array
        Deconvolved mesh FFT multiplied by the Green's function.
        Shape matching ``mesh_fft_raw``.
    k_vectors : jax.Array
        k-vectors on the mesh, shape (..., nx, ny, nz//2+1, 3). In batch
        mode an array without the leading batch axis is treated as shared
        by all systems and broadcast.
    k_squared : jax.Array
        |k|^2, shape (..., nx, ny, nz//2+1). Its dtype determines the
        accumulation and output dtype.
    alpha : jax.Array
        Ewald splitting parameter (scalar, (1,), or (B,) in batch mode).
    mesh_dimensions : tuple
        (nx, ny, nz); only nz is used, to detect the Nyquist plane.
    is_batch : bool
        Whether this is a batched calculation.

    Returns
    -------
    virial : jax.Array, shape (B, 3, 3) or (1, 3, 3)
        Per-system virial tensor.
    """
    # Only nz is needed below; unpacking still enforces a 3-element tuple.
    _, _, mesh_nz = mesh_dimensions

    # Accumulation dtype follows k_squared (float32 or float64).
    acc_dtype = _normalize_dtype(k_squared.dtype)
    complex_dtype = jnp.complex64 if acc_dtype == jnp.float32 else jnp.complex128

    # Per-k energy density from the exact pipeline spectral pair:
    # Re(mesh_fft_raw * convolved_mesh*) = |mesh_fft_raw|^2 * G / B^2
    fft_raw_cast = mesh_fft_raw.astype(complex_dtype)
    conv_cast = convolved_mesh.astype(complex_dtype)
    energy_density = (fft_raw_cast * jnp.conj(conv_cast)).real

    # rfft stores only half of the k_z axis: interior planes stand for two
    # conjugate points, the k_z = 0 plane (and Nyquist for even nz) for one.
    weight = jnp.full_like(energy_density, 2.0)
    weight = weight.at[..., 0].set(1.0)  # k_z = 0
    if mesh_nz % 2 == 0:
        weight = weight.at[..., -1].set(1.0)  # k_z = nz//2 (Nyquist)

    weighted_energy = weight * energy_density

    k_sq_acc = k_squared.astype(acc_dtype)
    alpha_acc = alpha.astype(acc_dtype)

    # Per-system alpha (B,) must broadcast against (B, nx, ny, nz//2+1).
    if is_batch and alpha_acc.ndim == 1:
        alpha_view = alpha_acc.reshape(-1, 1, 1, 1)
    else:
        alpha_view = alpha_acc.reshape(-1) if alpha_acc.ndim == 0 else alpha_acc

    exp_factor = 0.25 / (alpha_view**2)

    # Clamp k^2 so the division is finite at k = 0; that point is masked
    # out below, so the clamp value never reaches the result.
    safe_k_sq = jnp.maximum(k_sq_acc, 1e-30)
    k_factor = 2.0 * (1.0 + k_sq_acc * exp_factor) / safe_k_sq

    # No virial contribution from k = 0.
    k_mask = k_sq_acc > 1e-10

    # virial_ab = delta_ab * sum_k(E_k * mask)
    #             - sum_k(E_k * mask * k_factor * k_a * k_b)
    k_vecs_acc = k_vectors.astype(acc_dtype)  # (..., nx, ny, nz//2+1, 3)
    masked_energy = weighted_energy * k_mask  # (..., nx, ny, nz//2+1)
    masked_energy_kf = masked_energy * k_factor  # (..., nx, ny, nz//2+1)

    eye = jnp.eye(3, dtype=acc_dtype)

    if is_batch:
        # Shared k-vectors (no batch axis) would otherwise have their first
        # mesh axis matched against the batch label "b" of the einsum below.
        if k_vecs_acc.ndim == masked_energy_kf.ndim:
            k_vecs_acc = jnp.broadcast_to(
                k_vecs_acc, masked_energy_kf.shape + k_vecs_acc.shape[-1:]
            )

        trace_term = masked_energy.sum(axis=(1, 2, 3))  # (B,)
        kk_term = jnp.einsum(
            "b...i,b...j,b...->bij", k_vecs_acc, k_vecs_acc, masked_energy_kf
        )  # (B, 3, 3)
        virial = eye * trace_term[:, jnp.newaxis, jnp.newaxis] - kk_term
    else:
        trace_term = masked_energy.sum(axis=(0, 1, 2))  # scalar
        kk_term = jnp.einsum(
            "...i,...j,...->ij", k_vecs_acc, k_vecs_acc, masked_energy_kf
        )  # (3, 3)
        virial = (eye * trace_term - kk_term)[jnp.newaxis, :, :]  # (1, 3, 3)

    return virial.astype(acc_dtype)
def pme_reciprocal_space(
    positions: jax.Array,
    charges: jax.Array,
    cell: jax.Array,
    alpha: jax.Array,
    mesh_dimensions: tuple[int, int, int] | None = None,
    mesh_spacing: float | None = None,
    spline_order: int = 4,
    batch_idx: jax.Array | None = None,
    k_vectors: jax.Array | None = None,
    k_squared: jax.Array | None = None,
    compute_forces: bool = False,
    compute_charge_gradients: bool = False,
    compute_virial: bool = False,
) -> (
    jax.Array
    | tuple[jax.Array, jax.Array]
    | tuple[jax.Array, jax.Array, jax.Array]
    | tuple[jax.Array, jax.Array, jax.Array, jax.Array]
):
    """Compute the FFT-based reciprocal-space part of PME.

    Pipeline:

    1. Spread charges onto the mesh (``spline_spread``)
    2. Forward real FFT of the charge mesh
    3. Green's function and B-spline structure factor
    4. Deconvolve and convolve: mesh_fft / C²(k) * G(k)
    5. (Optional) virial from the k-space spectral pair
    6. Inverse FFT to the potential mesh, gathered at the atoms
       (``spline_gather``)
    7. Self-energy and background corrections
    8. (Optional) forces from the Fourier-space gradient

    Parameters
    ----------
    positions : jax.Array, shape (N, 3)
        Atomic coordinates. Their dtype fixes the working precision.
    charges : jax.Array, shape (N,)
        Atomic partial charges.
    cell : jax.Array, shape (3, 3) or (B, 3, 3)
        Unit cell matrices with lattice vectors as rows.
    alpha : jax.Array
        Ewald splitting parameter.

        - Single-system: scalar or shape (1,)
        - Batch: shape (B,)
    mesh_dimensions : tuple[int, int, int], optional
        FFT mesh dimensions (nx, ny, nz).
    mesh_spacing : float, optional
        Target mesh spacing, used only when ``mesh_dimensions`` is missing.
        If both are missing, the mesh is estimated for an accuracy of 1e-6.
    spline_order : int, default=4
        B-spline interpolation order (4 = cubic).
    batch_idx : jax.Array | None, default=None
        System index of every atom; enables batch mode when given.
    k_vectors : jax.Array, optional
        Precomputed k-vectors from ``generate_k_vectors_pme``.
    k_squared : jax.Array, optional
        Precomputed k² from ``generate_k_vectors_pme``. Both arrays are
        regenerated unless both are supplied.
    compute_forces : bool, default=False
        Also return forces obtained via the Fourier gradient.
    compute_charge_gradients : bool, default=False
        Also return the charge gradients dE/dq.
    compute_virial : bool, default=False
        Also return the virial tensor W = -dE/d(epsilon).
        Stress = virial / volume.

    Returns
    -------
    energies : jax.Array, shape (N,)
        Per-atom reciprocal-space energies.
    forces : jax.Array, shape (N, 3), optional
        Only if ``compute_forces=True``.
    charge_gradients : jax.Array, shape (N,), optional
        Only if ``compute_charge_gradients=True``.
    virial : jax.Array, shape (1, 3, 3) or (B, 3, 3), optional
        Only if ``compute_virial=True``; always last in the returned tuple.

    Notes
    -----
    - A bare array (not a tuple) is returned when no optional output is
      requested.
    - For zero atoms, zero-filled outputs of the requested kinds are
      returned without running the pipeline.
    - The virial is computed in k-space and uses the dtype of ``k_squared``.
    """
    num_atoms = positions.shape[0]
    input_dtype = _normalize_dtype(positions.dtype)
    is_batch = batch_idx is not None
    fft_dims = (1, 2, 3) if is_batch else (0, 1, 2)

    num_systems = 1 if cell.ndim == 2 else cell.shape[0]

    def _pack(energy_out, force_out, charge_grad_out, virial_out):
        # Fixed output order: energies, [forces], [charge_grads], [virial].
        selected = [energy_out]
        if compute_forces:
            selected.append(force_out)
        if compute_charge_gradients:
            selected.append(charge_grad_out)
        if compute_virial:
            selected.append(virial_out)
        return selected[0] if len(selected) == 1 else tuple(selected)

    # Empty system: nothing to spread or gather, return zeros right away.
    if num_atoms == 0:
        empty_forces = None
        empty_charge_grads = None
        empty_virial = None
        if compute_forces:
            empty_forces = jnp.zeros((num_atoms, 3), dtype=input_dtype)
        if compute_charge_gradients:
            empty_charge_grads = jnp.zeros(num_atoms, dtype=input_dtype)
        if compute_virial:
            empty_virial = jnp.zeros((num_systems, 3, 3), dtype=input_dtype)
        return _pack(
            jnp.zeros(num_atoms, dtype=input_dtype),
            empty_forces,
            empty_charge_grads,
            empty_virial,
        )

    # Resolve the mesh: explicit dimensions > spacing > default estimate.
    if mesh_dimensions is None:
        if mesh_spacing is None:
            mesh_dimensions = estimate_pme_mesh_dimensions(cell, alpha, accuracy=1e-6)
        else:
            mesh_dimensions = mesh_spacing_to_dimensions(cell, mesh_spacing)

    # The components are not used individually here; the unpacking fails
    # early if mesh_dimensions does not hold exactly three entries.
    mesh_nx, mesh_ny, mesh_nz = mesh_dimensions

    # Step 1: spread charges onto the mesh.
    charge_mesh = spline_spread(
        positions,
        charges,
        cell,
        mesh_dims=mesh_dimensions,
        spline_order=spline_order,
        batch_idx=batch_idx,
    )

    # Step 2: forward FFT of the charge mesh.
    mesh_fft = jnp.fft.rfftn(charge_mesh, axes=fft_dims, norm="backward")

    # Step 3: k-space grid, Green's function and structure factor.
    if k_vectors is None or k_squared is None:
        k_vectors, k_squared = generate_k_vectors_pme(cell, mesh_dimensions)

    green_function, structure_factor_sq = pme_green_structure_factor(
        k_squared,
        mesh_dimensions,
        alpha,
        cell,
        spline_order,
        batch_idx,
    )

    # The virial needs the spectrum from before deconvolution. Keeping a
    # reference is enough because mesh_fft is rebound to a new array below.
    mesh_fft_raw = mesh_fft if compute_virial else None

    # Step 4: B-spline deconvolution, then convolution with G(k). The cast
    # targets the complex counterpart of input_dtype so the imaginary part
    # is kept.
    if input_dtype == jnp.float32:
        complex_dtype = jnp.complex64
    else:
        complex_dtype = jnp.complex128
    mesh_fft = mesh_fft.astype(complex_dtype) / structure_factor_sq
    convolved_mesh = mesh_fft * green_function

    # Step 5: virial first, so the raw spectrum can be dropped before the
    # force-field meshes are allocated.
    virial = None
    if compute_virial:
        virial = _compute_pme_reciprocal_virial(
            mesh_fft_raw=mesh_fft_raw,
            convolved_mesh=convolved_mesh,
            k_vectors=k_vectors,
            k_squared=k_squared,
            alpha=alpha,
            mesh_dimensions=mesh_dimensions,
            is_batch=is_batch,
        )
    del mesh_fft_raw

    # Step 6: inverse FFT to the potential mesh and gather at the atoms.
    potential_mesh = jnp.fft.irfftn(
        convolved_mesh, s=mesh_dimensions, axes=fft_dims, norm="forward"
    )
    raw_energies = spline_gather(
        positions,
        potential_mesh,
        cell,
        spline_order=spline_order,
        batch_idx=batch_idx,
    )

    # Step 7: self-energy and background corrections.
    charge_grads = None
    if compute_charge_gradients:
        energies, charge_grads = pme_energy_corrections_with_charge_grad(
            raw_energies, charges, cell, alpha, batch_idx
        )
    else:
        energies = pme_energy_corrections(raw_energies, charges, cell, alpha, batch_idx)

    # Step 8: forces from the electric field, i.e. the Fourier-space
    # gradient (-i k) of the convolved mesh, one component at a time.
    forces = None
    if compute_forces:
        Ex_fft = -1j * k_vectors[..., 0] * convolved_mesh
        Ey_fft = -1j * k_vectors[..., 1] * convolved_mesh
        Ez_fft = -1j * k_vectors[..., 2] * convolved_mesh

        field_components = [
            jnp.fft.irfftn(
                component_fft, s=mesh_dimensions, axes=fft_dims, norm="forward"
            )
            for component_fft in (Ex_fft, Ey_fft, Ez_fft)
        ]
        electric_field_mesh = jnp.stack(field_components, axis=-1)

        # The gather receives the charges as well; F = 2 * q * E.
        interpolated_field = spline_gather_vec3(
            positions,
            charges,
            electric_field_mesh,
            cell,
            spline_order=spline_order,
            batch_idx=batch_idx,
        )
        forces = 2.0 * interpolated_field

    return _pack(energies, forces, charge_grads, virial)
def particle_mesh_ewald(
    positions: jax.Array,
    charges: jax.Array,
    cell: jax.Array,
    alpha: float | jax.Array | None = None,
    mesh_spacing: float | None = None,
    mesh_dimensions: tuple[int, int, int] | None = None,
    spline_order: int = 4,
    batch_idx: jax.Array | None = None,
    k_vectors: jax.Array | None = None,
    k_squared: jax.Array | None = None,
    neighbor_list: jax.Array | None = None,
    neighbor_ptr: jax.Array | None = None,
    neighbor_shifts: jax.Array | None = None,
    neighbor_matrix: jax.Array | None = None,
    neighbor_matrix_shifts: jax.Array | None = None,
    mask_value: int | None = None,
    compute_forces: bool = False,
    compute_charge_gradients: bool = False,
    compute_virial: bool = False,
    accuracy: float = 1e-6,
) -> (
    jax.Array
    | tuple[jax.Array, jax.Array]
    | tuple[jax.Array, jax.Array, jax.Array]
    | tuple[jax.Array, jax.Array, jax.Array, jax.Array]
):
    """Complete Particle Mesh Ewald (PME) calculation for long-range electrostatics.

    Adds the outputs of ``ewald_real_space`` (short-range, erfc-damped) and
    ``pme_reciprocal_space`` (long-range, FFT + B-spline interpolation, with
    self-energy and background corrections) element by element.

    Total Energy Formula:
        E_total = E_real + E_reciprocal - E_self - E_background

    Parameters
    ----------
    positions : jax.Array, shape (N, 3)
        Atomic coordinates.
    charges : jax.Array, shape (N,)
        Atomic partial charges in elementary charge units.
    cell : jax.Array, shape (3, 3) or (B, 3, 3)
        Unit cell matrices with lattice vectors as rows. Shape (3, 3) is
        automatically promoted to (1, 3, 3) for single-system mode.
    alpha : float, jax.Array, or None, default=None
        Ewald splitting parameter controlling real/reciprocal space balance.

        - float or 0-d array: same α for all systems (expanded to one
          entry per system)
        - Array shape (B,): per-system α values
        - None: estimated with ``estimate_pme_parameters`` for the
          requested ``accuracy``
    mesh_spacing : float, optional
        Target mesh spacing, converted with ``mesh_spacing_to_dimensions``
        when ``mesh_dimensions`` is not given.
    mesh_dimensions : tuple[int, int, int], optional
        Explicit FFT mesh dimensions (nx, ny, nz). Power-of-2 values
        recommended.
    spline_order : int, default=4
        B-spline interpolation order (4 = cubic B-splines, recommended).
    batch_idx : jax.Array, shape (N,), dtype=int32, optional
        System index for each atom (0 to B-1). Determines execution mode:

        - None: single-system kernels
        - Provided: batched kernels for multiple independent systems
    k_vectors : jax.Array, optional
        Precomputed k-vectors from ``generate_k_vectors_pme``. Providing
        this along with ``k_squared`` skips k-vector generation.
    k_squared : jax.Array, optional
        Precomputed k² values from ``generate_k_vectors_pme``.
    neighbor_list : jax.Array, optional
        CSR-format neighbor list indices. See ``ewald_real_space``.
    neighbor_ptr : jax.Array, optional
        CSR-format neighbor list pointers. See ``ewald_real_space``.
    neighbor_shifts : jax.Array, optional
        Periodic image shifts for neighbor list. See ``ewald_real_space``.
    neighbor_matrix : jax.Array, optional
        Dense neighbor matrix. Alternative to CSR format.
    neighbor_matrix_shifts : jax.Array, optional
        Shifts for dense neighbor matrix.
    mask_value : int, optional
        Mask value for invalid neighbors in dense format. Defaults to the
        number of atoms.
    compute_forces : bool, default=False
        If True, compute per-atom forces.
    compute_charge_gradients : bool, default=False
        If True, compute per-atom charge gradients dE/dq.
    compute_virial : bool, default=False
        If True, compute the virial tensor W = -dE/d(epsilon).
        Stress = virial / volume.
    accuracy : float, default=1e-6
        Target accuracy for automatic parameter estimation.

    Returns
    -------
    energies : jax.Array, shape (N,)
        Per-atom total electrostatic energies.
    forces : jax.Array, shape (N, 3), optional
        Per-atom forces (only if compute_forces=True).
    charge_gradients : jax.Array, shape (N,), optional
        Per-atom charge gradients (only if compute_charge_gradients=True).
    virial : jax.Array, shape (1, 3, 3) or (B, 3, 3), optional
        Virial tensor (only if compute_virial=True). Always last in the
        return tuple.

    Notes
    -----
    Automatic Parameter Estimation (when alpha is None):
        ``estimate_pme_parameters`` supplies α, and also the mesh dimensions
        when neither ``mesh_dimensions`` nor ``mesh_spacing`` is given.

    Examples
    --------
    Basic usage:

    >>> energies = particle_mesh_ewald(
    ...     positions, charges, cell, alpha=0.3,
    ...     mesh_dimensions=(32, 32, 32),
    ...     neighbor_list=nl, neighbor_ptr=ptr, neighbor_shifts=shifts,
    ... )

    With forces and automatic parameters:

    >>> energies, forces = particle_mesh_ewald(
    ...     positions, charges, cell,
    ...     mesh_spacing=1.0, accuracy=1e-5,
    ...     neighbor_list=nl, neighbor_ptr=ptr, neighbor_shifts=shifts,
    ...     compute_forces=True,
    ... )

    Batched systems:

    >>> energies = particle_mesh_ewald(
    ...     positions, charges, cell,
    ...     batch_idx=batch_idx,
    ...     neighbor_list=nl, neighbor_ptr=ptr, neighbor_shifts=shifts,
    ... )

    See Also
    --------
    pme_reciprocal_space : Reciprocal-space component only
    ewald_real_space : Real-space component
    estimate_pme_parameters : Automatic parameter estimation
    """
    num_atoms = positions.shape[0]

    # Work with a leading system axis throughout.
    if cell.ndim == 2:
        cell = cell[jnp.newaxis, :, :]
    num_systems = cell.shape[0]

    # Estimate parameters if not provided.
    if alpha is None:
        params = estimate_pme_parameters(positions, cell, batch_idx, accuracy)
        alpha = params.alpha
        if mesh_dimensions is None and mesh_spacing is None:
            # Convert to an explicit tuple[int, int, int].
            md = params.mesh_dimensions
            mesh_dimensions = (int(md[0]), int(md[1]), int(md[2]))

    # Normalize alpha to one entry per system. A 0-d array is handled like
    # a Python scalar; reshaping it to (1,) would leave a batched
    # calculation with a single alpha for num_systems systems.
    if isinstance(alpha, (int, float)):
        alpha = jnp.full((num_systems,), alpha, dtype=positions.dtype)
    elif alpha.ndim == 0:
        alpha = jnp.broadcast_to(alpha.reshape(1), (num_systems,))

    if mask_value is None:
        mask_value = num_atoms

    # Resolve the mesh: explicit dimensions > spacing > accuracy estimate.
    if mesh_dimensions is None:
        if mesh_spacing is not None:
            mesh_dimensions = mesh_spacing_to_dimensions(cell, mesh_spacing)
        else:
            mesh_dimensions = estimate_pme_mesh_dimensions(cell, alpha, accuracy)

    # Real-space contribution.
    rs = ewald_real_space(
        positions=positions,
        charges=charges,
        cell=cell,
        alpha=alpha,
        neighbor_list=neighbor_list,
        neighbor_ptr=neighbor_ptr,
        neighbor_shifts=neighbor_shifts,
        neighbor_matrix=neighbor_matrix,
        neighbor_matrix_shifts=neighbor_matrix_shifts,
        mask_value=mask_value,
        batch_idx=batch_idx,
        compute_forces=compute_forces,
        compute_charge_gradients=compute_charge_gradients,
        compute_virial=compute_virial,
    )

    # Reciprocal-space contribution.
    rec = pme_reciprocal_space(
        positions=positions,
        charges=charges,
        cell=cell,
        alpha=alpha,
        mesh_dimensions=mesh_dimensions,
        spline_order=spline_order,
        batch_idx=batch_idx,
        compute_forces=compute_forces,
        compute_charge_gradients=compute_charge_gradients,
        compute_virial=compute_virial,
        k_vectors=k_vectors,
        k_squared=k_squared,
    )

    # Both calls receive the same flags; their outputs are expected in the
    # order energies, [forces], [charge_grads], [virial], so they can be
    # summed position by position.
    rs_tuple = rs if isinstance(rs, tuple) else (rs,)
    rec_tuple = rec if isinstance(rec, tuple) else (rec,)

    results = tuple(r + s for r, s in zip(rs_tuple, rec_tuple))

    if len(results) == 1:
        return results[0]
    return results