# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unified PME Kernels
===================
This module provides GPU-accelerated Warp kernels for Particle Mesh Ewald (PME)
calculations, specifically for Green's function and energy corrections.
Charge assignment and force interpolation are handled by the spline module.
MATHEMATICAL FORMULATION
========================
PME splits the Coulomb energy into components:
.. math::
E_{\\text{total}} = E_{\\text{real}} + E_{\\text{reciprocal}} - E_{\\text{self}} - E_{\\text{background}}
This module provides kernels for:
1. Green's Function and Structure Factor Correction:
.. math::
G(k) = \\frac{2\\pi}{V} \\frac{\\exp(-k^2/(4\\alpha^2))}{k^2}
The B-spline charge assignment introduces aliasing, corrected by:
.. math::
C(k) = \\left[\\text{sinc}(k_x/N_x) \\cdot \\text{sinc}(k_y/N_y) \\cdot \\text{sinc}(k_z/N_z)\\right]^{-2p}
where p is the spline order.
2. Energy Corrections:
- Self-energy: :math:`E_{\\text{self}} = \\frac{\\alpha}{\\sqrt{\\pi}} \\sum_i q_i^2`
- Background (for non-neutral systems): :math:`E_{\\text{background}} = \\frac{\\pi}{2\\alpha^2 V} \\sum_i q_i Q_{\\text{total}}`
DTYPE FLEXIBILITY
=================
All kernels support both float32 and float64 inputs via wp.Any type annotations
and explicit overloads. Use the overload dictionaries (e.g.,
_pme_green_structure_factor_kernel_overload) to select the appropriate kernel
based on input dtype.
KERNEL ORGANIZATION
===================
Green's Function Kernels:
_pme_green_structure_factor_kernel: Single-system G(k) and C(k)
_batch_pme_green_structure_factor_kernel: Batched version
Energy Correction Kernels:
_pme_energy_corrections_kernel: Single-system self + background correction
_batch_pme_energy_corrections_kernel: Batched version
.. warning
In contrast to the other electrostatic kernels that offer end-to-end
``warp`` launchers, PME requires FFT for the convolution step that is
currently not available in ``warp``. As a result, bindings must call
FFT within their own framework in between kernel launches. The sequence
of calls looks like the following:
1. Spread charges to mesh: ``spline_spread()``
2. Forward FFT: ``framework.fft.rfftn(mesh)``
3. Compute Green's function: ``pme_green_structure_factor()``
4. Convolution: ``mesh_fft * green_function / structure_factor_sq``
5. Inverse FFT: ``framework.fft.irfftn(...)``
6. Gather potential: ``spline_gather()``
7. Apply corrections: ``pme_energy_corrections()``
REFERENCES
==========
- Essmann et al. (1995). J. Chem. Phys. 103, 8577 (SPME paper)
- Darden et al. (1993). J. Chem. Phys. 98, 10089 (Original PME)
- torchpme: https://github.com/lab-cosmo/torch-pme (Reference implementation)
"""
import math
from typing import Any
import warp as wp
# Mathematical constants
PI = math.pi
TWOPI = 2.0 * PI
FOURPI = 4.0 * PI
###########################################################################################
########################### Helper Functions ##############################################
###########################################################################################
@wp.func
def compute_sinc(x: Any) -> Any:
"""Compute normalized sinc function: :math:`\\sin(\\pi x)/(\\pi x)`.
Uses Taylor expansion near zero for numerical stability.
"""
abs_x = wp.abs(x)
one = type(x)(1.0)
threshold = type(x)(1e-6)
if abs_x < threshold:
return one
pi_x = type(x)(PI) * x
return wp.sin(pi_x) / pi_x
@wp.func
def wp_exp_kernel(k_sq: Any, prefactor: Any) -> Any:
"""Compute exp(-prefactor * k_sq) / k_sq."""
return wp.exp(-prefactor * k_sq) / k_sq
###########################################################################################
########################### Green Function with Structure Factor ##########################
###########################################################################################
@wp.kernel
def _pme_green_structure_factor_kernel(
k_squared: wp.array3d(dtype=Any), # (Nx, Ny, Nz_rfft)
miller_x: wp.array(dtype=Any), # (Nx,)
miller_y: wp.array(dtype=Any), # (Ny,)
miller_z: wp.array(dtype=Any), # (Nz_rfft,)
alpha: wp.array(dtype=Any), # (1,)
volume: wp.array(dtype=Any), # (1,)
mesh_nx: wp.int32,
mesh_ny: wp.int32,
mesh_nz: wp.int32,
spline_order: wp.int32,
green_function: wp.array3d(dtype=Any), # (Nx, Ny, Nz_rfft)
structure_factor_sq: wp.array3d(dtype=Any), # (Nx, Ny, Nz_rfft)
):
"""Compute PME Green's function and B-spline structure factor correction.
Computes two arrays needed for PME reciprocal space:
1. Green's function: G(k) = (2π/V) * exp(-k²/(4α²)) / k²
2. Structure factor squared: :math:`|B(k)|^2` for B-spline dealiasing
The structure factor correction accounts for aliasing from B-spline
charge spreading: C(k) = [sinc(h/N_x) * sinc(k/N_y) * sinc(l/N_z)]^(2p)
Launch Grid
-----------
dim = [Nx, Ny, Nz_rfft]
Each thread processes one grid point in the FFT mesh (using rfft symmetry).
Parameters
----------
k_squared : wp.array3d, shape (Nx, Ny, Nz_rfft), dtype=wp.float32 or wp.float64
Squared magnitude of k-vectors at each grid point.
miller_x : wp.array, shape (Nx,), dtype=wp.float32 or wp.float64
Miller indices in x direction (from fftfreq).
miller_y : wp.array, shape (Ny,), dtype=wp.float32 or wp.float64
Miller indices in y direction (from fftfreq).
miller_z : wp.array, shape (Nz_rfft,), dtype=wp.float32 or wp.float64
Miller indices in z direction (from rfftfreq).
alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Ewald splitting parameter.
volume : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Unit cell volume.
mesh_nx, mesh_ny, mesh_nz : wp.int32
Full mesh dimensions (Nz is the full size, not rfft size).
spline_order : wp.int32
B-spline order (1-4). Order 4 (cubic) recommended.
green_function : wp.array3d, shape (Nx, Ny, Nz_rfft), dtype=wp.float32 or wp.float64
OUTPUT: Green's function G(k) at each grid point.
structure_factor_sq : wp.array3d, shape (Nx, Ny, Nz_rfft), dtype=wp.float32 or wp.float64
OUTPUT: :math:`|B(k)|^2` structure factor squared at each grid point.
Notes
-----
- k=0 (grid point [0,0,0]) is explicitly set to zero (tin-foil boundary conditions).
- Near-zero k² values are set to zero to avoid division by zero.
- Structure factor is clamped to avoid division by zero in dealiasing.
- Uses rfft symmetry: only Nz_rfft = Nz//2 + 1 points in z.
"""
i, j, k = wp.tid()
k_sq = k_squared[i, j, k]
alpha_ = alpha[0]
volume_ = volume[0]
mi_x = miller_x[i]
mi_y = miller_y[j]
mi_z = miller_z[k]
# Get dtype-specific constants
zero = type(k_sq)(0.0)
one = type(k_sq)(1.0)
four = type(k_sq)(4.0)
threshold = type(k_sq)(1e-10)
clamp_threshold = type(k_sq)(1e-10)
twopi = type(k_sq)(TWOPI)
# Green's function: G(k) = 2*pi * exp(-k^2/(4*alpha^2)) / (k^2 * V)
if k_sq < threshold:
green_function[i, j, k] = zero
else:
exp_factor = wp_exp_kernel(k_sq, one / (four * alpha_ * alpha_))
green_function[i, j, k] = twopi * exp_factor / volume_
if i == 0 and j == 0 and k == 0:
green_function[i, j, k] = zero
# Structure factor: sinc(mi_x/Nx) * sinc(mi_y/Ny) * sinc(mi_z/Nz)
sinc_x = compute_sinc(mi_x / type(mi_x)(mesh_nx))
sinc_y = compute_sinc(mi_y / type(mi_y)(mesh_ny))
sinc_z = compute_sinc(mi_z / type(mi_z)(mesh_nz))
sinc_product = sinc_x * sinc_y * sinc_z
# Raise to spline_order power
sf = sinc_product
for _ in range(1, 4): # Max order 4
if _ < spline_order:
sf = sf * sinc_product
# Clamp to avoid division by zero
if sf < clamp_threshold:
sf = clamp_threshold
structure_factor_sq[i, j, k] = sf * sf
@wp.kernel
def _batch_pme_green_structure_factor_kernel(
k_squared: wp.array4d(dtype=Any), # (B, Nx, Ny, Nz_rfft)
miller_x: wp.array(dtype=Any), # (Nx,)
miller_y: wp.array(dtype=Any), # (Ny,)
miller_z: wp.array(dtype=Any), # (Nz_rfft,)
alpha: wp.array(dtype=Any), # (B,)
volumes: wp.array(dtype=Any), # (B,)
mesh_nx: wp.int32,
mesh_ny: wp.int32,
mesh_nz: wp.int32,
spline_order: wp.int32,
green_function: wp.array4d(dtype=Any), # (B, Nx, Ny, Nz_rfft)
structure_factor_sq: wp.array3d(dtype=Any), # (Nx, Ny, Nz_rfft)
):
"""Compute PME Green's function and B-spline structure factor for batched systems.
Batched version of _pme_green_structure_factor_kernel. Each system can have
different alpha and volume values, but shares the same mesh dimensions.
Green's function: G_s(k) = (2π/V_s) * exp(-k²/(4α_s²)) / k²
Structure factor: :math:`|B(k)|^2` (computed once, shared across systems)
Launch Grid
-----------
dim = [B, Nx, Ny, Nz_rfft]
Each thread processes one (system, grid_point) pair.
Parameters
----------
k_squared : wp.array4d, shape (B, Nx, Ny, Nz_rfft), dtype=wp.float32 or wp.float64
Per-system squared magnitude of k-vectors at each grid point.
miller_x : wp.array, shape (Nx,), dtype=wp.float32 or wp.float64
Miller indices in x direction (shared across systems).
miller_y : wp.array, shape (Ny,), dtype=wp.float32 or wp.float64
Miller indices in y direction (shared across systems).
miller_z : wp.array, shape (Nz_rfft,), dtype=wp.float32 or wp.float64
Miller indices in z direction (shared across systems).
alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system Ewald splitting parameter.
volumes : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system unit cell volume.
mesh_nx, mesh_ny, mesh_nz : wp.int32
Full mesh dimensions (Nz is the full size, not rfft size).
spline_order : wp.int32
B-spline order (1-4). Order 4 (cubic) recommended.
green_function : wp.array4d, shape (B, Nx, Ny, Nz_rfft), dtype=wp.float32 or wp.float64
OUTPUT: Per-system Green's function G_s(k) at each grid point.
structure_factor_sq : wp.array3d, shape (Nx, Ny, Nz_rfft), dtype=wp.float32 or wp.float64
OUTPUT: :math:`|B(k)|^2` structure factor squared (computed only at batch_idx=0).
Notes
-----
- k=0 (grid point [0,0,0]) is explicitly set to zero for each system.
- Near-zero k² values are set to zero to avoid division by zero.
- Structure factor is computed only once (at batch_idx=0) since it depends
only on mesh dimensions and spline order, not on system parameters.
- Uses rfft symmetry: only Nz_rfft = Nz//2 + 1 points in z.
"""
batch_idx, i, j, k = wp.tid()
k_sq = k_squared[batch_idx, i, j, k]
system_alpha = alpha[batch_idx]
system_volume = volumes[batch_idx]
mi_x = miller_x[i]
mi_y = miller_y[j]
mi_z = miller_z[k]
# Get dtype-specific constants
zero = type(k_sq)(0.0)
one = type(k_sq)(1.0)
four = type(k_sq)(4.0)
threshold = type(k_sq)(1e-10)
clamp_threshold = type(k_sq)(1e-10)
twopi = type(k_sq)(TWOPI)
# Green's function: G(k) = 2*pi * exp(-k^2/(4*alpha^2)) / (k^2 * V)
if k_sq < threshold:
green_function[batch_idx, i, j, k] = zero
else:
exp_factor = wp_exp_kernel(k_sq, one / (four * system_alpha * system_alpha))
green_function[batch_idx, i, j, k] = twopi * exp_factor / system_volume
if i == 0 and j == 0 and k == 0:
green_function[batch_idx, i, j, k] = zero
# Structure factor (only compute once per k-point, at batch_idx=0)
if batch_idx == wp.int32(0):
sinc_x = compute_sinc(mi_x / type(mi_x)(mesh_nx))
sinc_y = compute_sinc(mi_y / type(mi_y)(mesh_ny))
sinc_z = compute_sinc(mi_z / type(mi_z)(mesh_nz))
sinc_product = sinc_x * sinc_y * sinc_z
sf = sinc_product
for _ in range(1, 4):
if _ < spline_order:
sf = sf * sinc_product
if sf < clamp_threshold:
sf = clamp_threshold
structure_factor_sq[i, j, k] = sf * sf
###########################################################################################
########################### PME Energy Corrections ########################################
###########################################################################################
@wp.kernel
def _pme_energy_corrections_kernel(
raw_energies: wp.array(dtype=Any),
charges: wp.array(dtype=Any),
volume: wp.array(dtype=Any),
alpha: wp.array(dtype=Any),
total_charge: wp.array(dtype=Any),
corrected_energies: wp.array(dtype=Any),
):
"""Apply self-energy and background corrections to PME energies.
Converts raw potential values (φ_i) to corrected per-atom energies by:
1. Multiplying potential by charge: E_pot = q_i * φ_i
2. Subtracting self-energy: E_self = (α/√π) * q_i²
3. Subtracting background: E_bg = (π/(2α²V)) * q_i * Q_total
Final: E_i = q_i * φ_i - (α/√π) * q_i² - (π/(2α²V)) * q_i * Q_total
Launch Grid
-----------
dim = [num_atoms]
Each thread processes one atom independently.
Parameters
----------
raw_energies : wp.array, shape (N,), dtype=wp.float32 or wp.float64
Raw potential values φ_i from mesh interpolation.
charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
Atomic charges.
volume : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Unit cell volume.
alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Ewald splitting parameter.
total_charge : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Sum of all charges (Q_total = ∑_i q_i).
corrected_energies : wp.array, shape (N,), dtype=wp.float32 or wp.float64
OUTPUT: Corrected per-atom energies.
Notes
-----
- Self-energy removes spurious interaction of each Gaussian with itself.
- Background correction accounts for uniform neutralizing background.
- For neutral systems (Q_total = 0), background correction is zero.
"""
atom_idx = wp.tid()
charge = charges[atom_idx]
raw_energy = raw_energies[atom_idx]
alpha_ = alpha[0]
total_charge_ = total_charge[0]
volume_ = volume[0]
# Get dtype-specific constants
pi = type(charge)(PI)
two = type(charge)(2.0)
# Convert potential to energy: E = q * phi, where phi = raw_energy
potential_energy = charge * raw_energy
# Self-energy correction: -q^2 * alpha / sqrt(pi)
self_contrib = charge * charge * alpha_ / wp.sqrt(pi)
# Background correction: -q * pi * Q_tot / (2*alpha^2 * V)
background_contrib = charge * pi * total_charge_ / (two * alpha_ * alpha_ * volume_)
# Final corrected energy per atom
corrected_energies[atom_idx] = potential_energy - self_contrib - background_contrib
@wp.kernel
def _pme_energy_corrections_with_charge_grad_kernel(
raw_energies: wp.array(dtype=Any),
charges: wp.array(dtype=Any),
volume: wp.array(dtype=Any),
alpha: wp.array(dtype=Any),
total_charge: wp.array(dtype=Any),
corrected_energies: wp.array(dtype=Any),
charge_gradients: wp.array(dtype=Any),
):
"""Apply corrections and compute charge gradients for PME energies.
Computes both corrected energies and analytical charge gradients in a single pass:
Energy: E_i = q_i * φ_i - (α/√π) * q_i² - (π/(2α²V)) * q_i * Q_total
Charge gradient: ∂E_total/∂q_i = 2*φ_i - 2*(α/√π)*q_i - (π/(α²V))*Q_total
The factor of 2 on φ_i arises because changing q_i affects:
1. The direct term: ∂(q_i * φ_i)/∂q_i = φ_i
2. All potentials: ∑_j q_j * ∂φ_j/∂q_i = φ_i (since ∂φ_j/∂q_i = φ_i/q_i)
Total: 2*φ_i
Launch Grid
-----------
dim = [num_atoms]
Each thread processes one atom independently.
Parameters
----------
raw_energies : wp.array, shape (N,), dtype=wp.float32 or wp.float64
Raw potential values φ_i from mesh interpolation.
charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
Atomic charges.
volume : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Unit cell volume.
alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Ewald splitting parameter.
total_charge : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Sum of all charges (Q_total = ∑_i q_i).
corrected_energies : wp.array, shape (N,), dtype=wp.float32 or wp.float64
OUTPUT: Corrected per-atom energies.
charge_gradients : wp.array, shape (N,), dtype=wp.float32 or wp.float64
OUTPUT: Analytical charge gradients ∂E_total/∂q_i.
Notes
-----
- Charge gradients are useful for second-derivative training in ML potentials.
- Combines energy and charge gradient computation for efficiency.
- Self-energy and background corrections are applied to both outputs.
"""
atom_idx = wp.tid()
charge = charges[atom_idx]
raw_energy = raw_energies[atom_idx] # This is φ_i (the potential)
alpha_ = alpha[0]
total_charge_ = total_charge[0]
volume_ = volume[0]
# Get dtype-specific constants
pi = type(charge)(PI)
two = type(charge)(2.0)
# === Energy calculation ===
# Convert potential to energy: E = q * φ
potential_energy = charge * raw_energy
# Self-energy correction: -q² * α / √π
self_contrib = charge * charge * alpha_ / wp.sqrt(pi)
# Background correction: -q * π * Q_tot / (2α² * V)
background_contrib = charge * pi * total_charge_ / (two * alpha_ * alpha_ * volume_)
corrected_energies[atom_idx] = potential_energy - self_contrib - background_contrib
# === Charge gradient calculation ===
# ∂E/∂q_i = 2*φ_i - 2*(α/√π)*q_i - (π/(α²V))*Q_total
# The 2*φ_i factor accounts for both direct contribution and induced potential changes
self_energy_grad = two * alpha_ * charge / wp.sqrt(pi)
background_grad = pi * total_charge_ / (alpha_ * alpha_ * volume_)
charge_gradients[atom_idx] = two * raw_energy - self_energy_grad - background_grad
@wp.kernel
def _batch_pme_energy_corrections_kernel(
raw_energies: wp.array(dtype=Any),
charges: wp.array(dtype=Any),
batch_idx: wp.array(dtype=wp.int32),
volumes: wp.array(dtype=Any), # (B,)
alpha: wp.array(dtype=Any), # (B,)
total_charges: wp.array(dtype=Any), # (B,)
corrected_energies: wp.array(dtype=Any),
):
"""Apply self-energy and background corrections for batched PME.
Batched version of _pme_energy_corrections_kernel. Each atom looks up its
system's parameters (volume, alpha, total_charge) via batch_idx.
Final: E_i = q_i * φ_i - (α_s/√π) * q_i² - (π/(2α_s²V_s)) * q_i * Q_s
where s = batch_idx[i] is the system index for atom i.
Launch Grid
-----------
dim = [num_atoms_total]
Each thread processes one atom independently.
Parameters
----------
raw_energies : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
Raw potential values φ_i from mesh interpolation.
charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
Atomic charges for all systems concatenated.
batch_idx : wp.array, shape (N_total,), dtype=wp.int32
System index for each atom (0 to B-1).
volumes : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system unit cell volume.
alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system Ewald splitting parameter.
total_charges : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system sum of charges (Q_s = ∑_{i∈s} q_i).
corrected_energies : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
OUTPUT: Corrected per-atom energies.
Notes
-----
- Each system can have different alpha, volume, and total charge.
- Atoms are assigned to systems via batch_idx array.
"""
atom_idx = wp.tid()
system_id = batch_idx[atom_idx]
charge = charges[atom_idx]
raw_energy = raw_energies[atom_idx]
volume = volumes[system_id]
system_alpha = alpha[system_id]
total_charge = total_charges[system_id]
# Get dtype-specific constants
pi = type(charge)(PI)
two = type(charge)(2.0)
# Convert potential to energy: E = q * phi, where phi = raw_energy
potential_energy = charge * raw_energy
# Self-energy correction: -q^2 * alpha / sqrt(pi)
self_contrib = charge * charge * system_alpha / wp.sqrt(pi)
# Background correction: -q * pi * Q_tot / (2*alpha^2 * V)
background_contrib = (
charge * pi * total_charge / (two * system_alpha * system_alpha * volume)
)
# Final corrected energy per atom
corrected_energies[atom_idx] = potential_energy - self_contrib - background_contrib
@wp.kernel
def _batch_pme_energy_corrections_with_charge_grad_kernel(
raw_energies: wp.array(dtype=Any),
charges: wp.array(dtype=Any),
batch_idx: wp.array(dtype=wp.int32),
volumes: wp.array(dtype=Any), # (B,)
alpha: wp.array(dtype=Any), # (B,)
total_charges: wp.array(dtype=Any), # (B,)
corrected_energies: wp.array(dtype=Any),
charge_gradients: wp.array(dtype=Any),
):
"""Apply corrections and compute charge gradients for batched PME.
Batched version of _pme_energy_corrections_with_charge_grad_kernel.
Computes both corrected energies and analytical charge gradients:
Energy: E_i = q_i * φ_i - (α_s/√π) * q_i² - (π/(2α_s²V_s)) * q_i * Q_s
Charge gradient: ∂E_total/∂q_i = 2*φ_i - 2*(α_s/√π)*q_i - (π/(α_s²V_s))*Q_s
where s = batch_idx[i] is the system index for atom i.
Launch Grid
-----------
dim = [num_atoms_total]
Each thread processes one atom independently.
Parameters
----------
raw_energies : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
Raw potential values φ_i from mesh interpolation.
charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
Atomic charges for all systems concatenated.
batch_idx : wp.array, shape (N_total,), dtype=wp.int32
System index for each atom (0 to B-1).
volumes : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system unit cell volume.
alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system Ewald splitting parameter.
total_charges : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system sum of charges (Q_s = ∑_{i∈s} q_i).
corrected_energies : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
OUTPUT: Corrected per-atom energies.
charge_gradients : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
OUTPUT: Analytical charge gradients ∂E_total/∂q_i.
Notes
-----
- Each system can have different alpha, volume, and total charge.
- Atoms are assigned to systems via batch_idx array.
- Charge gradients are useful for second-derivative training in ML potentials.
"""
atom_idx = wp.tid()
system_id = batch_idx[atom_idx]
charge = charges[atom_idx]
raw_energy = raw_energies[atom_idx] # This is φ_i (the potential)
volume = volumes[system_id]
system_alpha = alpha[system_id]
total_charge = total_charges[system_id]
# Get dtype-specific constants
pi = type(charge)(PI)
two = type(charge)(2.0)
# === Energy calculation ===
# Convert potential to energy: E = q * φ
potential_energy = charge * raw_energy
# Self-energy correction: -q² * α / √π
self_contrib = charge * charge * system_alpha / wp.sqrt(pi)
# Background correction: -q * π * Q_tot / (2α² * V)
background_contrib = (
charge * pi * total_charge / (two * system_alpha * system_alpha * volume)
)
corrected_energies[atom_idx] = potential_energy - self_contrib - background_contrib
# === Charge gradient calculation ===
# ∂E/∂q_i = 2*φ_i - 2*(α/√π)*q_i - (π/(α²V))*Q_total
# The 2*φ_i factor accounts for both direct contribution and induced potential changes
self_energy_grad = two * system_alpha * charge / wp.sqrt(pi)
background_grad = pi * total_charge / (system_alpha * system_alpha * volume)
charge_gradients[atom_idx] = two * raw_energy - self_energy_grad - background_grad
###########################################################################################
########################### Kernel Overloads for Dtype Flexibility ########################
###########################################################################################
# Type lists for creating overloads
_T = [wp.float32, wp.float64]
# Single-system kernel overloads
_pme_green_structure_factor_kernel_overload = {}
_pme_energy_corrections_kernel_overload = {}
_pme_energy_corrections_with_charge_grad_kernel_overload = {}
# Batch kernel overloads
_batch_pme_green_structure_factor_kernel_overload = {}
_batch_pme_energy_corrections_kernel_overload = {}
_batch_pme_energy_corrections_with_charge_grad_kernel_overload = {}
for t in _T:
# Green's function kernel overloads
_pme_green_structure_factor_kernel_overload[t] = wp.overload(
_pme_green_structure_factor_kernel,
[
wp.array3d(dtype=t), # k_squared
wp.array(dtype=t), # miller_x
wp.array(dtype=t), # miller_y
wp.array(dtype=t), # miller_z
wp.array(dtype=t), # alpha
wp.array(dtype=t), # volume
wp.int32, # mesh_nx
wp.int32, # mesh_ny
wp.int32, # mesh_nz
wp.int32, # spline_order
wp.array3d(dtype=t), # green_function
wp.array3d(dtype=t), # structure_factor_sq
],
)
_batch_pme_green_structure_factor_kernel_overload[t] = wp.overload(
_batch_pme_green_structure_factor_kernel,
[
wp.array4d(dtype=t), # k_squared
wp.array(dtype=t), # miller_x
wp.array(dtype=t), # miller_y
wp.array(dtype=t), # miller_z
wp.array(dtype=t), # alpha
wp.array(dtype=t), # volumes
wp.int32, # mesh_nx
wp.int32, # mesh_ny
wp.int32, # mesh_nz
wp.int32, # spline_order
wp.array4d(dtype=t), # green_function
wp.array3d(dtype=t), # structure_factor_sq
],
)
# Energy corrections kernel overloads
_pme_energy_corrections_kernel_overload[t] = wp.overload(
_pme_energy_corrections_kernel,
[
wp.array(dtype=t), # raw_energies
wp.array(dtype=t), # charges
wp.array(dtype=t), # volume
wp.array(dtype=t), # alpha
wp.array(dtype=t), # total_charge
wp.array(dtype=t), # corrected_energies
],
)
_batch_pme_energy_corrections_kernel_overload[t] = wp.overload(
_batch_pme_energy_corrections_kernel,
[
wp.array(dtype=t), # raw_energies
wp.array(dtype=t), # charges
wp.array(dtype=wp.int32), # batch_idx
wp.array(dtype=t), # volumes
wp.array(dtype=t), # alpha
wp.array(dtype=t), # total_charges
wp.array(dtype=t), # corrected_energies
],
)
# Energy corrections with charge gradient kernel overloads
_pme_energy_corrections_with_charge_grad_kernel_overload[t] = wp.overload(
_pme_energy_corrections_with_charge_grad_kernel,
[
wp.array(dtype=t), # raw_energies
wp.array(dtype=t), # charges
wp.array(dtype=t), # volume
wp.array(dtype=t), # alpha
wp.array(dtype=t), # total_charge
wp.array(dtype=t), # corrected_energies
wp.array(dtype=t), # charge_gradients
],
)
_batch_pme_energy_corrections_with_charge_grad_kernel_overload[t] = wp.overload(
_batch_pme_energy_corrections_with_charge_grad_kernel,
[
wp.array(dtype=t), # raw_energies
wp.array(dtype=t), # charges
wp.array(dtype=wp.int32), # batch_idx
wp.array(dtype=t), # volumes
wp.array(dtype=t), # alpha
wp.array(dtype=t), # total_charges
wp.array(dtype=t), # corrected_energies
wp.array(dtype=t), # charge_gradients
],
)
###########################################################################################
########################### Warp Launcher Functions (wp_*) ################################
###########################################################################################
[docs]
def pme_green_structure_factor(
k_squared: wp.array,
miller_x: wp.array,
miller_y: wp.array,
miller_z: wp.array,
alpha: wp.array,
volume: wp.array,
mesh_nx: int,
mesh_ny: int,
mesh_nz: int,
spline_order: int,
green_function: wp.array,
structure_factor_sq: wp.array,
wp_dtype: type,
device: str | None = None,
) -> None:
"""Compute PME Green's function and B-spline structure factor correction.
Framework-agnostic launcher for single-system Green's function computation.
Note: FFT Operations Offloaded to Framework
-------------------------------------------
This kernel computes the Green's function multipliers for PME.
The complete PME reciprocal-space workflow requires FFT operations
that are not available in Warp and must be performed by the calling
framework. The typical workflow is:
1. Spread charges to mesh: spline_spread()
2. Forward FFT: framework.fft.rfftn(mesh) <-- Framework-specific
3. Compute Green's function: pme_green_structure_factor()
4. Convolution: mesh_fft * green_function / structure_factor_sq
5. Inverse FFT: framework.fft.irfftn(...) <-- Framework-specific
6. Gather potential: spline_gather()
7. Apply corrections: pme_energy_corrections()
Parameters
----------
k_squared : wp.array, shape (Nx, Ny, Nz_rfft), dtype=wp.float32 or wp.float64
Squared magnitude of k-vectors at each grid point.
miller_x : wp.array, shape (Nx,), dtype=wp.float32 or wp.float64
Miller indices in x direction (from fftfreq).
miller_y : wp.array, shape (Ny,), dtype=wp.float32 or wp.float64
Miller indices in y direction (from fftfreq).
miller_z : wp.array, shape (Nz_rfft,), dtype=wp.float32 or wp.float64
Miller indices in z direction (from rfftfreq).
alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Ewald splitting parameter.
volume : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Unit cell volume.
mesh_nx, mesh_ny, mesh_nz : int
Full mesh dimensions (Nz is the full size, not rfft size).
spline_order : int
B-spline order (1-4). Order 4 (cubic) recommended.
green_function : wp.array, shape (Nx, Ny, Nz_rfft), dtype=wp.float32 or wp.float64
OUTPUT: Green's function G(k) at each grid point.
structure_factor_sq : wp.array, shape (Nx, Ny, Nz_rfft), dtype=wp.float32 or wp.float64
OUTPUT: :math:`|B(k)|^2` structure factor squared at each grid point.
wp_dtype : type
Warp scalar dtype (wp.float32 or wp.float64).
device : str | None
Warp device string. If None, inferred from arrays.
See Also
--------
nvalchemiops.torch.interactions.electrostatics.pme : Complete PyTorch implementation
"""
nx, ny, nz_rfft = k_squared.shape[0], k_squared.shape[1], k_squared.shape[2]
kernel = _pme_green_structure_factor_kernel_overload[wp_dtype]
wp.launch(
kernel,
dim=(nx, ny, nz_rfft),
inputs=[
k_squared,
miller_x,
miller_y,
miller_z,
alpha,
volume,
wp.int32(mesh_nx),
wp.int32(mesh_ny),
wp.int32(mesh_nz),
wp.int32(spline_order),
],
outputs=[green_function, structure_factor_sq],
device=device,
)
[docs]
def batch_pme_green_structure_factor(
k_squared: wp.array,
miller_x: wp.array,
miller_y: wp.array,
miller_z: wp.array,
alpha: wp.array,
volumes: wp.array,
mesh_nx: int,
mesh_ny: int,
mesh_nz: int,
spline_order: int,
green_function: wp.array,
structure_factor_sq: wp.array,
wp_dtype: type,
device: str | None = None,
) -> None:
"""Compute PME Green's function and B-spline structure factor for batched systems.
Framework-agnostic launcher for batched Green's function computation.
Each system can have different alpha and volume values, but shares
the same mesh dimensions.
Parameters
----------
k_squared : wp.array, shape (B, Nx, Ny, Nz_rfft), dtype=wp.float32 or wp.float64
Per-system squared magnitude of k-vectors at each grid point.
miller_x : wp.array, shape (Nx,), dtype=wp.float32 or wp.float64
Miller indices in x direction (shared across systems).
miller_y : wp.array, shape (Ny,), dtype=wp.float32 or wp.float64
Miller indices in y direction (shared across systems).
miller_z : wp.array, shape (Nz_rfft,), dtype=wp.float32 or wp.float64
Miller indices in z direction (shared across systems).
alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system Ewald splitting parameter.
volumes : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system unit cell volume.
mesh_nx, mesh_ny, mesh_nz : int
Full mesh dimensions (Nz is the full size, not rfft size).
spline_order : int
B-spline order (1-4). Order 4 (cubic) recommended.
green_function : wp.array, shape (B, Nx, Ny, Nz_rfft), dtype=wp.float32 or wp.float64
OUTPUT: Per-system Green's function G_s(k) at each grid point.
structure_factor_sq : wp.array, shape (Nx, Ny, Nz_rfft), dtype=wp.float32 or wp.float64
OUTPUT: :math:`|B(k)|^2` structure factor squared (computed only at batch_idx=0).
wp_dtype : type
Warp scalar dtype (wp.float32 or wp.float64).
device : str | None
Warp device string. If None, inferred from arrays.
See Also
--------
nvalchemiops.torch.interactions.electrostatics.pme : Complete PyTorch implementation
"""
num_systems = k_squared.shape[0]
nx, ny, nz_rfft = k_squared.shape[1], k_squared.shape[2], k_squared.shape[3]
kernel = _batch_pme_green_structure_factor_kernel_overload[wp_dtype]
wp.launch(
kernel,
dim=(num_systems, nx, ny, nz_rfft),
inputs=[
k_squared,
miller_x,
miller_y,
miller_z,
alpha,
volumes,
wp.int32(mesh_nx),
wp.int32(mesh_ny),
wp.int32(mesh_nz),
wp.int32(spline_order),
],
outputs=[green_function, structure_factor_sq],
device=device,
)
[docs]
def pme_energy_corrections(
raw_energies: wp.array,
charges: wp.array,
volume: wp.array,
alpha: wp.array,
total_charge: wp.array,
corrected_energies: wp.array,
wp_dtype: type,
device: str | None = None,
) -> None:
"""Apply self-energy and background corrections to PME energies.
Framework-agnostic launcher for single-system energy corrections.
Converts raw potential values (φ_i) to corrected per-atom energies by:
1. Multiplying potential by charge: E_pot = q_i * φ_i
2. Subtracting self-energy: E_self = (α/√π) * q_i²
3. Subtracting background: E_bg = (π/(2α²V)) * q_i * Q_total
Final: E_i = q_i * φ_i - (α/√π) * q_i² - (π/(2α²V)) * q_i * Q_total
Parameters
----------
raw_energies : wp.array, shape (N,), dtype=wp.float32 or wp.float64
Raw potential values φ_i from mesh interpolation.
charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
Atomic charges.
volume : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Unit cell volume.
alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Ewald splitting parameter.
total_charge : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Sum of all charges (Q_total = ∑_i q_i).
corrected_energies : wp.array, shape (N,), dtype=wp.float32 or wp.float64
OUTPUT: Corrected per-atom energies.
wp_dtype : type
Warp scalar dtype (wp.float32 or wp.float64).
device : str | None
Warp device string. If None, inferred from arrays.
"""
num_atoms = raw_energies.shape[0]
kernel = _pme_energy_corrections_kernel_overload[wp_dtype]
wp.launch(
kernel,
dim=num_atoms,
inputs=[raw_energies, charges, volume, alpha, total_charge],
outputs=[corrected_energies],
device=device,
)
[docs]
def batch_pme_energy_corrections(
raw_energies: wp.array,
charges: wp.array,
batch_idx: wp.array,
volumes: wp.array,
alpha: wp.array,
total_charges: wp.array,
corrected_energies: wp.array,
wp_dtype: type,
device: str | None = None,
) -> None:
"""Apply self-energy and background corrections for batched PME.
Framework-agnostic launcher for batched energy corrections.
Each atom looks up its system's parameters via batch_idx.
Parameters
----------
raw_energies : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
Raw potential values φ_i from mesh interpolation.
charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
Atomic charges for all systems concatenated.
batch_idx : wp.array, shape (N_total,), dtype=wp.int32
System index for each atom (0 to B-1).
volumes : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system unit cell volume.
alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system Ewald splitting parameter.
total_charges : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system sum of charges (Q_s = ∑_{i∈s} q_i).
corrected_energies : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
OUTPUT: Corrected per-atom energies.
wp_dtype : type
Warp scalar dtype (wp.float32 or wp.float64).
device : str | None
Warp device string. If None, inferred from arrays.
"""
num_atoms = raw_energies.shape[0]
kernel = _batch_pme_energy_corrections_kernel_overload[wp_dtype]
wp.launch(
kernel,
dim=num_atoms,
inputs=[raw_energies, charges, batch_idx, volumes, alpha, total_charges],
outputs=[corrected_energies],
device=device,
)
[docs]
def pme_energy_corrections_with_charge_grad(
raw_energies: wp.array,
charges: wp.array,
volume: wp.array,
alpha: wp.array,
total_charge: wp.array,
corrected_energies: wp.array,
charge_gradients: wp.array,
wp_dtype: type,
device: str | None = None,
) -> None:
"""Apply corrections and compute charge gradients for PME energies.
Framework-agnostic launcher for single-system energy corrections
with analytical charge gradient computation.
Computes both corrected energies and analytical charge gradients:
- Energy: E_i = q_i * φ_i - (α/√π) * q_i² - (π/(2α²V)) * q_i * Q_total
- Charge gradient: ∂E_total/∂q_i = 2*φ_i - 2*(α/√π)*q_i - (π/(α²V))*Q_total
Parameters
----------
raw_energies : wp.array, shape (N,), dtype=wp.float32 or wp.float64
Raw potential values φ_i from mesh interpolation.
charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
Atomic charges.
volume : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Unit cell volume.
alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Ewald splitting parameter.
total_charge : wp.array, shape (1,), dtype=wp.float32 or wp.float64
Sum of all charges (Q_total = ∑_i q_i).
corrected_energies : wp.array, shape (N,), dtype=wp.float32 or wp.float64
OUTPUT: Corrected per-atom energies.
charge_gradients : wp.array, shape (N,), dtype=wp.float32 or wp.float64
OUTPUT: Analytical charge gradients ∂E_total/∂q_i.
wp_dtype : type
Warp scalar dtype (wp.float32 or wp.float64).
device : str | None
Warp device string. If None, inferred from arrays.
"""
num_atoms = raw_energies.shape[0]
kernel = _pme_energy_corrections_with_charge_grad_kernel_overload[wp_dtype]
wp.launch(
kernel,
dim=num_atoms,
inputs=[raw_energies, charges, volume, alpha, total_charge],
outputs=[corrected_energies, charge_gradients],
device=device,
)
[docs]
def batch_pme_energy_corrections_with_charge_grad(
raw_energies: wp.array,
charges: wp.array,
batch_idx: wp.array,
volumes: wp.array,
alpha: wp.array,
total_charges: wp.array,
corrected_energies: wp.array,
charge_gradients: wp.array,
wp_dtype: type,
device: str | None = None,
) -> None:
"""Apply corrections and compute charge gradients for batched PME.
Framework-agnostic launcher for batched energy corrections
with analytical charge gradient computation.
Parameters
----------
raw_energies : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
Raw potential values φ_i from mesh interpolation.
charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
Atomic charges for all systems concatenated.
batch_idx : wp.array, shape (N_total,), dtype=wp.int32
System index for each atom (0 to B-1).
volumes : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system unit cell volume.
alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system Ewald splitting parameter.
total_charges : wp.array, shape (B,), dtype=wp.float32 or wp.float64
Per-system sum of charges (Q_s = ∑_{i∈s} q_i).
corrected_energies : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
OUTPUT: Corrected per-atom energies.
charge_gradients : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
OUTPUT: Analytical charge gradients ∂E_total/∂q_i.
wp_dtype : type
Warp scalar dtype (wp.float32 or wp.float64).
device : str | None
Warp device string. If None, inferred from arrays.
"""
num_atoms = raw_energies.shape[0]
kernel = _batch_pme_energy_corrections_with_charge_grad_kernel_overload[wp_dtype]
wp.launch(
kernel,
dim=num_atoms,
inputs=[raw_energies, charges, batch_idx, volumes, alpha, total_charges],
outputs=[corrected_energies, charge_gradients],
device=device,
)
###########################################################################################
########################### Module Exports #################################################
###########################################################################################
__all__ = [
# Kernel overloads
"_pme_green_structure_factor_kernel_overload",
"_batch_pme_green_structure_factor_kernel_overload",
"_pme_energy_corrections_kernel_overload",
"_batch_pme_energy_corrections_kernel_overload",
"_pme_energy_corrections_with_charge_grad_kernel_overload",
"_batch_pme_energy_corrections_with_charge_grad_kernel_overload",
# Warp launchers
"pme_green_structure_factor",
"batch_pme_green_structure_factor",
"pme_energy_corrections",
"batch_pme_energy_corrections",
"pme_energy_corrections_with_charge_grad",
"batch_pme_energy_corrections_with_charge_grad",
]