# Source code for nvalchemiops.torch.interactions.electrostatics.pme

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
PyTorch Bindings for Particle Mesh Ewald (PME)
==============================================

This module provides PyTorch bindings for the Particle Mesh Ewald algorithm,
wrapping Warp kernels with PyTorch custom operators for autograd support.

PME poses a particular challenge: it requires FFT operations, which Warp does
not support. The Warp layer therefore provides the building blocks (Green's
function, energy corrections), while the complete PME workflow, including the
FFTs, remains in the framework bindings and relies on PyTorch.

The module provides a unified GPU-accelerated API for Particle Mesh Ewald that
handles both single-system and batched calculations transparently. PME scales
as :math:`O(N \\log N)`, compared to :math:`O(N^2)` for direct summation, which
makes it efficient for large systems.

The output dtype convention follows ewald.py: energies in float64, forces/virial
match input precision.

API STRUCTURE
=============

Primary APIs (public, with autograd support):
    particle_mesh_ewald(): Complete PME calculation (real + reciprocal)
    pme_reciprocal_space(): Reciprocal-space FFT-based component only

Helper APIs:
    pme_green_structure_factor(): Green's function and B-spline correction
    pme_energy_corrections(): Self-energy and background corrections
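
As an illustration of the B-spline correction returned by ``pme_green_structure_factor()``,
the squared structure factor can be evaluated on the rfft grid as below. This is a sketch
that assumes "sinc" denotes the normalized sinc, sin(pi*x)/(pi*x), as computed by
``torch.sinc``, and builds integer Miller indices with ``torch.fft.fftfreq``/``rfftfreq``
the same way this module does. ``structure_factor_sq`` is a hypothetical helper, not part
of the module.

```python
import torch


def structure_factor_sq(mesh_dims, spline_order=4, dtype=torch.float64):
    nx, ny, nz = mesh_dims
    # Integer Miller indices in FFT order; the last axis uses the rfft half-grid.
    mx = torch.fft.fftfreq(nx, d=1.0 / nx, dtype=dtype)
    my = torch.fft.fftfreq(ny, d=1.0 / ny, dtype=dtype)
    mz = torch.fft.rfftfreq(nz, d=1.0 / nz, dtype=dtype)
    # Product of normalized sincs along each axis: sin(pi x) / (pi x).
    c = (
        torch.sinc(mx / nx)[:, None, None]
        * torch.sinc(my / ny)[None, :, None]
        * torch.sinc(mz / nz)[None, None, :]
    )
    # C^2(k) = [product of sincs]^(2p)
    return c ** (2 * spline_order)


c2 = structure_factor_sq((8, 8, 8), spline_order=4)
```

Because C^2(k) is at most 1 and shrinks toward the Nyquist frequency, dividing by it boosts
the high-k components that B-spline interpolation attenuates.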

The batch_idx parameter determines kernel dispatch:
    batch_idx=None → Single-system kernels
    batch_idx provided → Batch kernels (multiple independent systems)
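
For illustration, a minimal sketch of the batched layout this dispatch assumes: atoms from
all systems are concatenated along dimension 0 and ``batch_idx`` stores each atom's system
index. The tensors below are hypothetical inputs, and the per-system reduction is one
plausible way to obtain per-system quantities such as the total charge; it does not call
the library.

```python
import torch

# Hypothetical batch: system 0 has 3 atoms, system 1 has 2 atoms.
charges = torch.tensor([1.0, -1.0, 0.5, 0.5, -0.5], dtype=torch.float64)
batch_idx = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int64)
num_systems = int(batch_idx.max().item()) + 1

# Per-system total charge: accumulate each atom's charge into its system slot.
total_charges = torch.zeros(num_systems, dtype=charges.dtype)
total_charges.index_add_(0, batch_idx, charges)

# A scalar alpha broadcast to one value per system; a tensor of shape
# (num_systems,) would instead give each system its own alpha.
alpha = torch.tensor(0.3, dtype=torch.float64).expand(num_systems)

# Gather per-system quantities back to atoms with plain indexing.
alpha_per_atom = alpha[batch_idx]
total_charge_per_atom = total_charges[batch_idx]
```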

MATHEMATICAL FORMULATION
========================

PME uses B-spline interpolation to assign charges to a mesh, computes the
convolution with the Coulomb kernel efficiently via FFT, then interpolates
back to get energies and forces.
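
The FFT convolution at the heart of PME can be sketched on a small mesh. This is a
simplified illustration that assumes a cubic cell of side ``box_length`` and the
volume-normalized kernel G(k) given in step 3 below; it omits the B-spline deconvolution
C^2(k) and any mesh-normalization factors the actual kernels may apply.
``green_function`` and ``mesh_potential`` are hypothetical helpers, not part of this module.

```python
import math

import torch


def green_function(mesh, box_length, alpha):
    # Reciprocal vectors k = 2*pi*m/L for a cubic cell, on the rfft grid
    # (the last axis keeps only non-negative frequencies).
    nx, ny, nz = mesh
    volume = box_length**3
    kx = 2.0 * math.pi * torch.fft.fftfreq(nx, d=box_length / nx, dtype=torch.float64)
    ky = 2.0 * math.pi * torch.fft.fftfreq(ny, d=box_length / ny, dtype=torch.float64)
    kz = 2.0 * math.pi * torch.fft.rfftfreq(nz, d=box_length / nz, dtype=torch.float64)
    k2 = kx[:, None, None] ** 2 + ky[None, :, None] ** 2 + kz[None, None, :] ** 2
    # G(k) = (2*pi/V) * exp(-k^2 / (4*alpha^2)) / k^2, with G(0) set to zero.
    nonzero = k2 > 0
    safe_k2 = torch.where(nonzero, k2, torch.ones_like(k2))
    g = 2.0 * math.pi / volume * torch.exp(-safe_k2 / (4.0 * alpha**2)) / safe_k2
    return torch.where(nonzero, g, torch.zeros_like(g))


def mesh_potential(q_mesh, g):
    # Convolution theorem: multiply by G(k) in k-space, then transform back.
    return torch.fft.irfftn(torch.fft.rfftn(q_mesh) * g, s=q_mesh.shape)


torch.manual_seed(0)
q_mesh = torch.randn(4, 4, 4, dtype=torch.float64)
g = green_function((4, 4, 4), box_length=6.0, alpha=0.5)
phi = mesh_potential(q_mesh, g)  # same shape as q_mesh
```

The FFT route costs O(M log M) for M mesh points, whereas evaluating the same circular
convolution directly in real space costs O(M^2).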

.. math::

    E_{\\text{total}} = E_{\\text{real}} + E_{\\text{reciprocal}} - E_{\\text{self}} - E_{\\text{background}}

Reciprocal-Space Steps:

1. Charge assignment:

.. math::

    Q(x) = \\sum_i q_i M_p(x - r_i)

where :math:`M_p` is the cardinal B-spline of order :math:`p`.

2. FFT:

.. math::

    \\tilde{Q}(k) = \\text{FFT}[Q(x)]

3. Convolution in k-space:

.. math::

    \\tilde{\\Phi}(k) = \\frac{G(k)}{C^2(k)} \\tilde{Q}(k)

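where :math:`G(k) = \\frac{2\\pi}{V} \\frac{\\exp(-k^2/(4\\alpha^2))}{k^2}` and :math:`C(k) = \\left[\\text{sinc}(m_x/N_x) \\cdot \\text{sinc}(m_y/N_y) \\cdot \\text{sinc}(m_z/N_z)\\right]^p` is the B-spline correction, with Miller indices :math:`m_{x,y,z}` and mesh dimensions :math:`N_{x,y,z}`.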

4. Inverse FFT for potential and field:

.. math::

    \\begin{aligned}
    \\Phi(x) &= \\text{IFFT}[\\tilde{\\Phi}(k)] \\\\
    E(x) &= \\text{IFFT}[-ik \\tilde{\\Phi}(k)]
    \\end{aligned}

5. Energy and force interpolation:

.. math::

    \\begin{aligned}
    E_i &= q_i \\cdot \\text{interpolate}(\\Phi, r_i) \\\\
    F_i &= q_i \\cdot \\text{interpolate}(E, r_i)
    \\end{aligned}
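
Steps 1 and 5 use the same B-spline weights. The sketch below illustrates them in one
dimension, assuming the Essmann et al. recursion for M_p and their indexing convention
(weight M_p(u_i - k) for mesh point k, where u_i is the coordinate scaled to mesh units).
``cardinal_bspline``, ``spread_1d`` and ``gather_1d`` are hypothetical helpers, not this
module's ``spline_spread``/``spline_gather``.

```python
import torch


def cardinal_bspline(u: torch.Tensor, p: int) -> torch.Tensor:
    # Recursion: M_2(u) = 1 - |u - 1| on [0, 2] (zero elsewhere), and
    # M_n(u) = (u * M_{n-1}(u) + (n - u) * M_{n-1}(u - 1)) / (n - 1).
    if p == 2:
        return torch.clamp(1.0 - torch.abs(u - 1.0), min=0.0)
    return (u * cardinal_bspline(u, p - 1) + (p - u) * cardinal_bspline(u - 1.0, p - 1)) / (p - 1)


def _stencil(positions, n_mesh, box_length, p):
    # Scaled coordinate u in [0, n_mesh); M_p(u - k) is nonzero only for the
    # p mesh points k = floor(u), floor(u) - 1, ..., floor(u) - p + 1.
    u = positions / box_length * n_mesh
    k = torch.floor(u).long()[:, None] - torch.arange(p)[None, :]
    weights = cardinal_bspline(u[:, None] - k.to(u.dtype), p)
    return torch.remainder(k, n_mesh), weights  # periodic wrap-around


def spread_1d(positions, charges, n_mesh, box_length, p=4):
    # Step 1: Q[k] = sum_i q_i * M_p(u_i - k)
    idx, w = _stencil(positions, n_mesh, box_length, p)
    mesh = torch.zeros(n_mesh, dtype=charges.dtype)
    mesh.index_add_(0, idx.reshape(-1), (charges[:, None] * w).reshape(-1))
    return mesh


def gather_1d(mesh, positions, n_mesh, box_length, p=4):
    # Step 5: phi_i = sum_k M_p(u_i - k) * Phi[k]
    idx, w = _stencil(positions, n_mesh, box_length, p)
    return (mesh[idx] * w).sum(dim=1)


positions = torch.tensor([0.3, 2.7, 4.95], dtype=torch.float64)
charges = torch.tensor([1.0, -0.4, 0.25], dtype=torch.float64)
q_mesh = spread_1d(positions, charges, n_mesh=8, box_length=5.0)
```

Because the weights of each charge sum to one (partition of unity), spreading conserves
the total charge and gathering a constant mesh returns that constant.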

Corrections:

.. math::

    \\begin{aligned}
    E_{\\text{self}} &= \\sum_i \\frac{\\alpha}{\\sqrt{\\pi}} q_i^2 \\\\
    E_{\\text{background}} &= \\sum_i \\frac{\\pi}{2\\alpha^2 V} q_i Q_{\\text{total}}
    \\end{aligned}
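
A minimal single-system sketch of the two corrections, written directly from the formulas
above rather than from the library's kernels. Summing the per-atom background term gives
pi * Q_total**2 / (2 * alpha**2 * V), which vanishes for a neutral system.
``pme_corrections_sketch`` is a hypothetical helper.

```python
import math

import torch


def pme_corrections_sketch(charges, cell, alpha):
    # Cell volume from the 3x3 cell matrix.
    volume = torch.abs(torch.linalg.det(cell))
    total_charge = charges.sum()
    # E_self,i = alpha / sqrt(pi) * q_i^2
    e_self = alpha / math.sqrt(math.pi) * charges**2
    # E_background,i = pi / (2 alpha^2 V) * q_i * Q_total
    e_background = math.pi / (2.0 * alpha**2 * volume) * charges * total_charge
    return e_self, e_background


cell = 10.0 * torch.eye(3, dtype=torch.float64)
neutral = torch.tensor([1.0, -1.0], dtype=torch.float64)
e_self, e_bg = pme_corrections_sketch(neutral, cell, alpha=0.3)
```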

USAGE EXAMPLES
==============

Automatic parameter estimation::

    >>> from nvalchemiops.torch.interactions.electrostatics import particle_mesh_ewald
    >>> energies, forces = particle_mesh_ewald(
    ...     positions, charges, cell,
    ...     neighbor_list=nl, neighbor_shifts=shifts,
    ...     accuracy=1e-6,  # alpha and mesh estimated automatically
    ... )

Explicit parameters::

    >>> energies, forces = particle_mesh_ewald(
    ...     positions, charges, cell,
    ...     alpha=0.3,
    ...     mesh_dimensions=(32, 32, 32),
    ...     spline_order=4,
    ...     neighbor_list=nl, neighbor_shifts=shifts,
    ... )

Batched systems::

    >>> energies, forces = particle_mesh_ewald(
    ...     positions, charges, cells,  # cells shape (B, 3, 3)
    ...     alpha=torch.tensor([0.3, 0.35]),
    ...     batch_idx=batch_idx,
    ...     mesh_dimensions=(32, 32, 32),
    ...     neighbor_list=nl, neighbor_shifts=shifts,
    ... )

Reciprocal-space only (no real-space)::

    >>> energies = pme_reciprocal_space(
    ...     positions, charges, cell,
    ...     alpha=0.3, mesh_dimensions=(32, 32, 32),
    ... )

REFERENCES
==========

- Essmann et al. (1995). J. Chem. Phys. 103, 8577 (SPME paper)
- Darden et al. (1993). J. Chem. Phys. 98, 10089 (Original PME)
- torchpme: https://github.com/lab-cosmo/torch-pme (Reference implementation)
"""

import math
from typing import Any

import torch
import warp as wp

from nvalchemiops.interactions.electrostatics.pme_kernels import (
    _batch_pme_energy_corrections_kernel_overload,
    _batch_pme_energy_corrections_with_charge_grad_kernel_overload,
    _batch_pme_green_structure_factor_kernel_overload,
    _pme_energy_corrections_kernel_overload,
    _pme_energy_corrections_with_charge_grad_kernel_overload,
    _pme_green_structure_factor_kernel_overload,
)
from nvalchemiops.torch.autograd import (
    OutputSpec,
    WarpAutogradContextManager,
    attach_for_backward,
    needs_grad,
    warp_custom_op,
    warp_from_torch,
)
from nvalchemiops.torch.interactions.electrostatics.ewald import ewald_real_space
from nvalchemiops.torch.interactions.electrostatics.k_vectors import (
    generate_k_vectors_pme,
)
from nvalchemiops.torch.interactions.electrostatics.parameters import (
    estimate_pme_mesh_dimensions,
    estimate_pme_parameters,
    mesh_spacing_to_dimensions,
)
from nvalchemiops.torch.spline import (
    spline_gather,
    spline_gather_vec3,
    spline_spread,
)
from nvalchemiops.torch.types import get_wp_dtype

# Mathematical constants
PI = math.pi
TWOPI = 2.0 * PI
FOURPI = 4.0 * PI


###########################################################################################
########################### Helper Functions ##############################################
###########################################################################################


def _prepare_alpha(
    alpha: float | torch.Tensor,
    num_systems: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Convert alpha to a per-system tensor.

    Parameters
    ----------
    alpha : float or torch.Tensor
        Ewald splitting parameter. Can be:

        - A Python scalar (int or float), broadcast to all systems
        - A 0-d tensor, broadcast to all systems
        - A 1-d tensor of shape (num_systems,) for per-system values
    num_systems : int
        Number of systems in the batch.
    dtype : torch.dtype
        Target dtype (typically float64).
    device : torch.device
        Target device.

    Returns
    -------
    torch.Tensor, shape (num_systems,)
        Per-system alpha values.
    """
    if isinstance(alpha, (int, float)):
        return torch.full((num_systems,), float(alpha), dtype=dtype, device=device)
    elif isinstance(alpha, torch.Tensor):
        if alpha.dim() == 0:
            return alpha.expand(num_systems).to(dtype=dtype, device=device)
        elif alpha.shape[0] != num_systems:
            raise ValueError(
                f"alpha has {alpha.shape[0]} values but there are {num_systems} systems"
            )
        return alpha.to(dtype=dtype, device=device)
    else:
        raise TypeError(f"alpha must be float or torch.Tensor, got {type(alpha)}")


def _prepare_cell(cell: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Ensure cell is 3D (B, 3, 3) and return number of systems.

    Parameters
    ----------
    cell : torch.Tensor
        Unit cell matrix. Shape (3, 3) for single system or (B, 3, 3) for batch.

    Returns
    -------
    cell : torch.Tensor, shape (B, 3, 3)
        Cell with batch dimension.
    num_systems : int
        Number of systems (B).
    """
    if cell.dim() == 2:
        cell = cell.unsqueeze(0)
    return cell, cell.shape[0]


###########################################################################################
########################### Green Function & Structure Factor Custom Ops ##################
###########################################################################################


def _green_output_shape(k_squared, *_):
    """Helper to compute output shape for Green's function.

    Uses k_squared.shape directly since it already has shape (Nx, Ny, Nz_rfft).
    """
    return k_squared.shape


def _struct_output_shape(k_squared, *_):
    """Helper to compute output shape for structure factor.

    Uses k_squared.shape directly since it already has shape (Nx, Ny, Nz_rfft).
    """
    return k_squared.shape


@warp_custom_op(
    name="alchemiops::_pme_green_structure_factor",
    outputs=[
        OutputSpec("green_function", wp.array(dtype=Any, ndim=3), _green_output_shape),
        OutputSpec(
            "structure_factor_sq", wp.array(dtype=Any, ndim=3), _struct_output_shape
        ),
    ],
    grad_arrays=["green_function", "k_squared", "alpha", "volume"],
)
def _pme_green_structure_factor(
    k_squared: torch.Tensor,
    miller_x: torch.Tensor,
    miller_y: torch.Tensor,
    miller_z: torch.Tensor,
    alpha: torch.Tensor,
    volume: torch.Tensor,
    mesh_nx: int,
    mesh_ny: int,
    mesh_nz: int,
    spline_order: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Compute Green's function and structure factor for single-system PME.

    The Green's function includes volume normalization:

    .. math::

        G(k) = \frac{2\pi}{V} \frac{\exp(-k^2/(4\alpha^2))}{k^2}

    Supports both float32 and float64 dtypes via kernel overloads.

    Parameters
    ----------
    k_squared : torch.Tensor, shape (Nx, Ny, Nz_rfft)
        :math:`|k|^2` for each grid point.
    miller_x : torch.Tensor, shape (Nx,)
        Miller indices in x direction.
    miller_y : torch.Tensor, shape (Ny,)
        Miller indices in y direction.
    miller_z : torch.Tensor, shape (Nz_rfft,)
        Miller indices in z direction.
    alpha : torch.Tensor, shape (1,)
        Ewald splitting parameter.
    volume : torch.Tensor, shape (1,)
        Cell volume.
    mesh_nx, mesh_ny, mesh_nz : int
        Full mesh dimensions.
    spline_order : int
        B-spline order.

    Returns
    -------
    green_function : torch.Tensor, shape (Nx, Ny, Nz_rfft)
        Green's function values (volume-normalized).
    structure_factor_sq : torch.Tensor, shape (Nx, Ny, Nz_rfft)
        Structure factor squared.
    """
    device = wp.device_from_torch(k_squared.device)
    input_dtype = k_squared.dtype
    wp_dtype = get_wp_dtype(input_dtype)
    nx, ny, nz_rfft = k_squared.shape
    needs_grad_flag = needs_grad(k_squared, alpha, volume)

    # Prepare inputs with appropriate dtype
    wp_k_squared = warp_from_torch(
        k_squared.contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_miller_x = warp_from_torch(
        miller_x.to(input_dtype).contiguous(), wp_dtype, requires_grad=False
    )
    wp_miller_y = warp_from_torch(
        miller_y.to(input_dtype).contiguous(), wp_dtype, requires_grad=False
    )
    wp_miller_z = warp_from_torch(
        miller_z.to(input_dtype).contiguous(), wp_dtype, requires_grad=False
    )
    wp_alpha = warp_from_torch(
        alpha.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_volume = warp_from_torch(
        volume.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )

    # Allocate outputs with input dtype
    green_function = torch.zeros(
        (nx, ny, nz_rfft), dtype=input_dtype, device=k_squared.device
    )
    structure_factor_sq = torch.zeros(
        (nx, ny, nz_rfft), dtype=input_dtype, device=k_squared.device
    )

    wp_green = warp_from_torch(green_function, wp_dtype, requires_grad=needs_grad_flag)
    wp_struct = warp_from_torch(structure_factor_sq, wp_dtype, requires_grad=False)

    # Select kernel based on dtype
    kernel = _pme_green_structure_factor_kernel_overload[wp_dtype]

    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            kernel,
            dim=(nx, ny, nz_rfft),
            inputs=[
                wp_k_squared,
                wp_miller_x,
                wp_miller_y,
                wp_miller_z,
                wp_alpha,
                wp_volume,
                wp.int32(mesh_nx),
                wp.int32(mesh_ny),
                wp.int32(mesh_nz),
                wp.int32(spline_order),
            ],
            outputs=[wp_green, wp_struct],
            device=device,
        )

    if needs_grad_flag:
        attach_for_backward(
            green_function,
            tape=tape,
            green_function=wp_green,
            k_squared=wp_k_squared,
            alpha=wp_alpha,
            volume=wp_volume,
        )

    return green_function, structure_factor_sq


def _batch_green_output_shape(
    k_squared, miller_x, miller_y, miller_z, alpha, volumes, *_
):
    """Helper to compute output shapes for batch Green's function."""
    if k_squared.dim() == 3:
        _, nx, ny, nz_rfft = (1,) + k_squared.shape
    else:
        _, nx, ny, nz_rfft = k_squared.shape
    num_systems = volumes.shape[0]
    return (num_systems, nx, ny, nz_rfft)


def _batch_struct_output_shape(k_squared, *_):
    """Helper to compute output shape for structure factor in batch case."""
    if k_squared.dim() == 3:
        return k_squared.shape
    else:
        return k_squared.shape[1:]  # Remove batch dim


@warp_custom_op(
    name="alchemiops::_batch_pme_green_structure_factor",
    outputs=[
        OutputSpec(
            "green_function", wp.array(dtype=Any, ndim=4), _batch_green_output_shape
        ),
        OutputSpec(
            "structure_factor_sq",
            wp.array(dtype=Any, ndim=3),
            _batch_struct_output_shape,
        ),
    ],
    grad_arrays=["green_function", "k_squared", "alpha", "volumes"],
)
def _batch_pme_green_structure_factor(
    k_squared: torch.Tensor,
    miller_x: torch.Tensor,
    miller_y: torch.Tensor,
    miller_z: torch.Tensor,
    alpha: torch.Tensor,
    volumes: torch.Tensor,
    mesh_nx: int,
    mesh_ny: int,
    mesh_nz: int,
    spline_order: int,
    num_systems: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Green's function and structure factor for batch PME.

    The Green's function includes volume normalization:

    .. math::

        G(k) = \\frac{2\\pi}{V} \\frac{\\exp(-k^2/(4\\alpha^2))}{k^2}

    Supports both float32 and float64 dtypes via kernel overloads.

    Parameters
    ----------
    k_squared : torch.Tensor, shape (B, Nx, Ny, Nz_rfft)
        :math:`|k|^2` for each grid point per system.
    miller_x, miller_y, miller_z : torch.Tensor
        Miller indices for each dimension.
    alpha : torch.Tensor, shape (B,)
        Per-system Ewald splitting parameter.
    volumes : torch.Tensor, shape (B,)
        Per-system cell volumes.
    mesh_nx, mesh_ny, mesh_nz : int
        Full mesh dimensions.
    spline_order : int
        B-spline order.
    num_systems : int
        Number of systems.

    Returns
    -------
    green_function : torch.Tensor, shape (B, Nx, Ny, Nz_rfft)
        Green's function values per system (volume-normalized).
    structure_factor_sq : torch.Tensor, shape (Nx, Ny, Nz_rfft)
        Squared structure factor :math:`C^2(k)`, shared by all systems.
    """
    device = wp.device_from_torch(k_squared.device)
    if k_squared.dim() == 3:
        k_squared = k_squared.unsqueeze(0)
    input_dtype = k_squared.dtype
    wp_dtype = get_wp_dtype(input_dtype)
    _, nx, ny, nz_rfft = k_squared.shape
    needs_grad_flag = needs_grad(k_squared, alpha, volumes)

    # Prepare inputs with appropriate dtype
    wp_k_squared = warp_from_torch(
        k_squared.contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_miller_x = warp_from_torch(
        miller_x.to(input_dtype).contiguous(), wp_dtype, requires_grad=False
    )
    wp_miller_y = warp_from_torch(
        miller_y.to(input_dtype).contiguous(), wp_dtype, requires_grad=False
    )
    wp_miller_z = warp_from_torch(
        miller_z.to(input_dtype).contiguous(), wp_dtype, requires_grad=False
    )
    wp_alpha = warp_from_torch(
        alpha.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_volumes = warp_from_torch(
        volumes.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )

    # Allocate outputs with input dtype
    green_function = torch.zeros(
        (num_systems, nx, ny, nz_rfft), dtype=input_dtype, device=k_squared.device
    )
    structure_factor_sq = torch.zeros(
        (nx, ny, nz_rfft), dtype=input_dtype, device=k_squared.device
    )

    wp_green = warp_from_torch(green_function, wp_dtype, requires_grad=needs_grad_flag)
    wp_struct = warp_from_torch(structure_factor_sq, wp_dtype, requires_grad=False)

    # Select kernel based on dtype
    kernel = _batch_pme_green_structure_factor_kernel_overload[wp_dtype]

    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            kernel,
            dim=(num_systems, nx, ny, nz_rfft),
            inputs=[
                wp_k_squared,
                wp_miller_x,
                wp_miller_y,
                wp_miller_z,
                wp_alpha,
                wp_volumes,
                wp.int32(mesh_nx),
                wp.int32(mesh_ny),
                wp.int32(mesh_nz),
                wp.int32(spline_order),
            ],
            outputs=[wp_green, wp_struct],
            device=device,
        )

    if needs_grad_flag:
        attach_for_backward(
            green_function,
            tape=tape,
            green_function=wp_green,
            k_squared=wp_k_squared,
            alpha=wp_alpha,
            volumes=wp_volumes,
        )
    return green_function, structure_factor_sq


def pme_green_structure_factor(
    k_squared: torch.Tensor,
    mesh_dimensions: tuple[int, int, int],
    alpha: torch.Tensor,
    cell: torch.Tensor,
    spline_order: int = 4,
    batch_idx: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Green's function and B-spline structure factor correction.

    Computes the Coulomb Green's function with volume normalization and the
    B-spline aliasing correction factor for PME.

    Green's function (volume-normalized):

    .. math::

        G(k) = \\frac{2\\pi}{V} \\frac{\\exp(-k^2/(4\\alpha^2))}{k^2}

    Structure factor correction (for B-spline deconvolution):

    .. math::

        C^2(k) = \\left[\\text{sinc}(m_x/N_x) \\cdot \\text{sinc}(m_y/N_y) \\cdot \\text{sinc}(m_z/N_z)\\right]^{2p}

    where :math:`p` is the spline order.

    Supports both float32 and float64 dtypes.

    Parameters
    ----------
    k_squared : torch.Tensor
        :math:`|k|^2` values at each FFT grid point.

        - Single-system: shape (Nx, Ny, Nz_rfft)
        - Batch: shape (B, Nx, Ny, Nz_rfft)
    mesh_dimensions : tuple[int, int, int]
        Full mesh dimensions (Nx, Ny, Nz) before rfft.
    alpha : torch.Tensor
        Ewald splitting parameter.

        - Single-system: shape (1,)
        - Batch: shape (B,)
    cell : torch.Tensor
        Unit cell matrices.

        - Single-system: shape (3, 3) or (1, 3, 3)
        - Batch: shape (B, 3, 3)
    spline_order : int, default=4
        B-spline interpolation order (typically 4 for cubic B-splines).
    batch_idx : torch.Tensor | None, default=None
        If provided, dispatches to batch kernels.

    Returns
    -------
    green_function : torch.Tensor
        Volume-normalized Green's function :math:`G(k)`.

        - Single-system: shape (Nx, Ny, Nz_rfft)
        - Batch: shape (B, Nx, Ny, Nz_rfft)
    structure_factor_sq : torch.Tensor
        Squared structure factor :math:`C^2(k)` for B-spline deconvolution.
        Shape (Nx, Ny, Nz_rfft), shared across the batch.

    Notes
    -----
    - :math:`G(k=0)` is set to zero to avoid the singularity.
    - The volume normalization in :math:`G(k)` eliminates later divisions.
    - The structure factor depends only on the mesh, so it is shared across the batch.
    """
    mesh_nx, mesh_ny, mesh_nz = mesh_dimensions
    device = k_squared.device
    input_dtype = k_squared.dtype

    # Ensure cell is correct shape
    cell = cell if cell.dim() == 3 else cell.unsqueeze(0)
    volume = torch.abs(torch.det(cell)).to(input_dtype)

    # Generate Miller indices in input dtype
    miller_x = torch.fft.fftfreq(
        mesh_nx, d=1.0 / mesh_nx, device=device, dtype=input_dtype
    )
    miller_y = torch.fft.fftfreq(
        mesh_ny, d=1.0 / mesh_ny, device=device, dtype=input_dtype
    )
    miller_z = torch.fft.rfftfreq(
        mesh_nz, d=1.0 / mesh_nz, device=device, dtype=input_dtype
    )

    if batch_idx is None:
        # Single system
        result = _pme_green_structure_factor(
            k_squared,
            miller_x,
            miller_y,
            miller_z,
            alpha.to(input_dtype),
            volume,
            mesh_nx,
            mesh_ny,
            mesh_nz,
            spline_order,
        )
    else:
        # Batch - num_systems from the cell shape
        num_systems = cell.shape[0]
        result = _batch_pme_green_structure_factor(
            k_squared,
            miller_x,
            miller_y,
            miller_z,
            alpha.to(input_dtype),
            volume,
            mesh_nx,
            mesh_ny,
            mesh_nz,
            spline_order,
            num_systems,
        )

    return result
###########################################################################################
########################### PME Energy Corrections Custom Ops #############################
###########################################################################################


@warp_custom_op(
    name="alchemiops::_pme_energy_corrections",
    outputs=[
        OutputSpec(
            "corrected_energies",
            wp.array(dtype=Any, ndim=1),
            lambda raw_energies, *_: (raw_energies.shape[0],),
        ),
    ],
    grad_arrays=[
        "corrected_energies",
        "raw_energies",
        "charges",
        "volume",
        "alpha",
        "total_charge",
    ],
)
def _pme_energy_corrections(
    raw_energies: torch.Tensor,
    charges: torch.Tensor,
    volume: torch.Tensor,
    alpha: torch.Tensor,
    total_charge: torch.Tensor,
) -> torch.Tensor:
    """Apply self-energy and background corrections to PME energies.

    Applies per-atom corrections: E_i = q_i * phi_i - self - background.
    The 1/2 pair-counting factor is already included in the Green's function
    (G = 2*pi/(V*k^2), not 4*pi/(V*k^2)), so no extra 0.5 is needed.

    Supports both float32 and float64 dtypes via kernel overloads.

    Parameters
    ----------
    raw_energies : torch.Tensor, shape (N,)
        Raw interpolated potential phi_i from the potential mesh.
    charges : torch.Tensor, shape (N,)
        Atomic charges.
    volume : torch.Tensor, shape (1,)
        Cell volume.
    alpha : torch.Tensor, shape (1,)
        Ewald splitting parameter.
    total_charge : torch.Tensor, shape (1,)
        Total system charge.

    Returns
    -------
    corrected_energies : torch.Tensor, shape (N,)
        Corrected energies per atom.
    """
    device = wp.device_from_torch(raw_energies.device)
    input_dtype = raw_energies.dtype
    wp_dtype = get_wp_dtype(input_dtype)
    num_atoms = raw_energies.shape[0]
    needs_grad_flag = needs_grad(raw_energies, charges, volume, alpha, total_charge)

    wp_raw = warp_from_torch(
        raw_energies.contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_charges = warp_from_torch(
        charges.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_volume = warp_from_torch(
        volume.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_alpha = warp_from_torch(
        alpha.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_total_charge = warp_from_torch(
        total_charge.to(input_dtype).contiguous(),
        wp_dtype,
        requires_grad=needs_grad_flag,
    )

    corrected_energies = torch.zeros(
        num_atoms, dtype=input_dtype, device=raw_energies.device
    )
    wp_corrected = warp_from_torch(
        corrected_energies, wp_dtype, requires_grad=needs_grad_flag
    )

    # Select kernel based on dtype
    kernel = _pme_energy_corrections_kernel_overload[wp_dtype]

    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            kernel,
            dim=num_atoms,
            inputs=[
                wp_raw,
                wp_charges,
                wp_volume,
                wp_alpha,
                wp_total_charge,
            ],
            outputs=[wp_corrected],
            device=device,
        )

    if needs_grad_flag:
        attach_for_backward(
            corrected_energies,
            tape=tape,
            corrected_energies=wp_corrected,
            raw_energies=wp_raw,
            charges=wp_charges,
            volume=wp_volume,
            alpha=wp_alpha,
            total_charge=wp_total_charge,
        )

    return corrected_energies


@warp_custom_op(
    name="alchemiops::_batch_pme_energy_corrections",
    outputs=[
        OutputSpec(
            "corrected_energies",
            wp.array(dtype=Any, ndim=1),
            lambda raw_energies, *_: (raw_energies.shape[0],),
        ),
    ],
    grad_arrays=[
        "corrected_energies",
        "raw_energies",
        "charges",
        "volumes",
        "alpha",
        "total_charges",
    ],
)
def _batch_pme_energy_corrections(
    raw_energies: torch.Tensor,
    charges: torch.Tensor,
    batch_idx: torch.Tensor,
    volumes: torch.Tensor,
    alpha: torch.Tensor,
    total_charges: torch.Tensor,
) -> torch.Tensor:
    """Apply self-energy and background corrections for batch PME.

    Batched counterpart of :func:`_pme_energy_corrections`: each atom is
    corrected with the volume, alpha, and total charge of its own system
    (selected via ``batch_idx``). As in the single-system case, the 1/2
    pair-counting factor is already included in the Green's function, so no
    extra 0.5 is needed.

    Supports both float32 and float64 dtypes via kernel overloads.

    Parameters
    ----------
    raw_energies : torch.Tensor, shape (N_total,)
        Raw interpolated potential phi_i.
    charges : torch.Tensor, shape (N_total,)
        Atomic charges.
    batch_idx : torch.Tensor, shape (N_total,)
        System index for each atom.
    volumes : torch.Tensor, shape (B,)
        Cell volumes per system.
    alpha : torch.Tensor, shape (B,)
        Per-system Ewald splitting parameter.
    total_charges : torch.Tensor, shape (B,)
        Total charge per system.

    Returns
    -------
    corrected_energies : torch.Tensor, shape (N_total,)
        Corrected energies per atom.
    """
    device = wp.device_from_torch(raw_energies.device)
    input_dtype = raw_energies.dtype
    wp_dtype = get_wp_dtype(input_dtype)
    num_atoms = raw_energies.shape[0]
    needs_grad_flag = needs_grad(raw_energies, charges, volumes, alpha, total_charges)

    wp_raw = warp_from_torch(
        raw_energies.contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_charges = warp_from_torch(
        charges.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_batch_idx = warp_from_torch(
        batch_idx.contiguous(), wp.int32, requires_grad=False
    )
    wp_volumes = warp_from_torch(
        volumes.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_alpha = warp_from_torch(
        alpha.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_total_charges = warp_from_torch(
        total_charges.to(input_dtype).contiguous(),
        wp_dtype,
        requires_grad=needs_grad_flag,
    )

    corrected_energies = torch.zeros(
        num_atoms, dtype=input_dtype, device=raw_energies.device
    )
    wp_corrected = warp_from_torch(
        corrected_energies, wp_dtype, requires_grad=needs_grad_flag
    )

    # Select kernel based on dtype
    kernel = _batch_pme_energy_corrections_kernel_overload[wp_dtype]

    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            kernel,
            dim=num_atoms,
            inputs=[
                wp_raw,
                wp_charges,
                wp_batch_idx,
                wp_volumes,
                wp_alpha,
                wp_total_charges,
            ],
            outputs=[wp_corrected],
            device=device,
        )

    if needs_grad_flag:
        attach_for_backward(
            corrected_energies,
            tape=tape,
            corrected_energies=wp_corrected,
            raw_energies=wp_raw,
            charges=wp_charges,
            volumes=wp_volumes,
            alpha=wp_alpha,
            total_charges=wp_total_charges,
        )

    return corrected_energies


@warp_custom_op(
    name="alchemiops::_pme_energy_corrections_with_charge_grad",
    outputs=[
        OutputSpec(
            "corrected_energies",
            wp.float64,
            lambda raw_energies, *_: (raw_energies.shape[0],),
        ),
        OutputSpec(
            "charge_gradients",
            wp.float64,
            lambda raw_energies, *_: (raw_energies.shape[0],),
        ),
    ],
    grad_arrays=[
        "corrected_energies",
        "charge_gradients",
        "raw_energies",
        "charges",
        "volume",
        "alpha",
        "total_charge",
    ],
)
def _pme_energy_corrections_with_charge_grad(
    raw_energies: torch.Tensor,
    charges: torch.Tensor,
    volume: torch.Tensor,
    alpha: torch.Tensor,
    total_charge: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply self-energy and background corrections and compute charge gradients.

    Computes both corrected energies and analytical charge gradients for PME.

    Parameters
    ----------
    raw_energies : torch.Tensor, shape (N,)
        Raw interpolated potential φ_i from the mesh.
    charges : torch.Tensor, shape (N,)
        Atomic charges.
    volume : torch.Tensor, shape (1,)
        Cell volume.
    alpha : torch.Tensor, shape (1,)
        Ewald splitting parameter.
    total_charge : torch.Tensor, shape (1,)
        Total system charge.

    Returns
    -------
    corrected_energies : torch.Tensor, shape (N,)
        Corrected energies per atom.
    charge_gradients : torch.Tensor, shape (N,)
        Analytical charge gradients ∂E/∂q_i.
    """
    device = wp.device_from_torch(raw_energies.device)
    input_dtype = raw_energies.dtype
    wp_dtype = get_wp_dtype(input_dtype)
    num_atoms = raw_energies.shape[0]
    needs_grad_flag = needs_grad(raw_energies, charges, volume, alpha, total_charge)

    wp_raw = warp_from_torch(
        raw_energies.contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_charges = warp_from_torch(
        charges.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_volume = warp_from_torch(
        volume.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_alpha = warp_from_torch(
        alpha.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_total_charge = warp_from_torch(
        total_charge.to(input_dtype).contiguous(),
        wp_dtype,
        requires_grad=needs_grad_flag,
    )

    corrected_energies = torch.zeros(
        num_atoms, dtype=input_dtype, device=raw_energies.device
    )
    charge_gradients = torch.zeros(
        num_atoms, dtype=input_dtype, device=raw_energies.device
    )
    wp_corrected = warp_from_torch(
        corrected_energies, wp_dtype, requires_grad=needs_grad_flag
    )
    wp_charge_grads = warp_from_torch(
        charge_gradients, wp_dtype, requires_grad=needs_grad_flag
    )

    kernel = _pme_energy_corrections_with_charge_grad_kernel_overload[wp_dtype]

    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            kernel,
            dim=num_atoms,
            inputs=[
                wp_raw,
                wp_charges,
                wp_volume,
                wp_alpha,
                wp_total_charge,
            ],
            outputs=[wp_corrected, wp_charge_grads],
            device=device,
        )

    if needs_grad_flag:
        attach_for_backward(
            corrected_energies,
            tape=tape,
            corrected_energies=wp_corrected,
            charge_gradients=wp_charge_grads,
            raw_energies=wp_raw,
            charges=wp_charges,
            volume=wp_volume,
            alpha=wp_alpha,
            total_charge=wp_total_charge,
        )

    return corrected_energies, charge_gradients


@warp_custom_op(
    name="alchemiops::_batch_pme_energy_corrections_with_charge_grad",
    outputs=[
        OutputSpec(
            "corrected_energies",
            wp.float64,
            lambda raw_energies, *_: (raw_energies.shape[0],),
        ),
        OutputSpec(
            "charge_gradients",
            wp.float64,
            lambda raw_energies, *_: (raw_energies.shape[0],),
        ),
    ],
    grad_arrays=[
        "corrected_energies",
        "charge_gradients",
        "raw_energies",
        "charges",
        "volumes",
        "alpha",
        "total_charges",
    ],
)
def _batch_pme_energy_corrections_with_charge_grad(
    raw_energies: torch.Tensor,
    charges: torch.Tensor,
    batch_idx: torch.Tensor,
    volumes: torch.Tensor,
    alpha: torch.Tensor,
    total_charges: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply corrections and compute charge gradients for batch PME.

    Parameters
    ----------
    raw_energies : torch.Tensor, shape (N_total,)
        Raw interpolated potential.
    charges : torch.Tensor, shape (N_total,)
        Atomic charges.
    batch_idx : torch.Tensor, shape (N_total,)
        System index for each atom.
    volumes : torch.Tensor, shape (B,)
        Cell volumes per system.
    alpha : torch.Tensor, shape (B,)
        Per-system Ewald splitting parameter.
    total_charges : torch.Tensor, shape (B,)
        Total charge per system.

    Returns
    -------
    corrected_energies : torch.Tensor, shape (N_total,)
        Corrected energies per atom.
    charge_gradients : torch.Tensor, shape (N_total,)
        Analytical charge gradients ∂E/∂q_i.
    """
    device = wp.device_from_torch(raw_energies.device)
    input_dtype = raw_energies.dtype
    wp_dtype = get_wp_dtype(input_dtype)
    num_atoms = raw_energies.shape[0]
    needs_grad_flag = needs_grad(raw_energies, charges, volumes, alpha, total_charges)

    wp_raw = warp_from_torch(
        raw_energies.contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_charges = warp_from_torch(
        charges.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_batch_idx = warp_from_torch(
        batch_idx.contiguous(), wp.int32, requires_grad=False
    )
    wp_volumes = warp_from_torch(
        volumes.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_alpha = warp_from_torch(
        alpha.to(input_dtype).contiguous(), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_total_charges = warp_from_torch(
        total_charges.to(input_dtype).contiguous(),
        wp_dtype,
        requires_grad=needs_grad_flag,
    )

    corrected_energies = torch.zeros(
        num_atoms, dtype=input_dtype, device=raw_energies.device
    )
    charge_gradients = torch.zeros(
        num_atoms, dtype=input_dtype, device=raw_energies.device
    )
    wp_corrected = warp_from_torch(
        corrected_energies, wp_dtype, requires_grad=needs_grad_flag
    )
    wp_charge_grads = warp_from_torch(
        charge_gradients, wp_dtype, requires_grad=needs_grad_flag
    )

    kernel = _batch_pme_energy_corrections_with_charge_grad_kernel_overload[wp_dtype]

    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            kernel,
            dim=num_atoms,
            inputs=[
                wp_raw,
                wp_charges,
                wp_batch_idx,
                wp_volumes,
                wp_alpha,
                wp_total_charges,
            ],
            outputs=[wp_corrected, wp_charge_grads],
            device=device,
        )

    if needs_grad_flag:
        attach_for_backward(
            corrected_energies,
            tape=tape,
            corrected_energies=wp_corrected,
            charge_gradients=wp_charge_grads,
            raw_energies=wp_raw,
            charges=wp_charges,
            volumes=wp_volumes,
            alpha=wp_alpha,
            total_charges=wp_total_charges,
        )

    return corrected_energies, charge_gradients


def pme_energy_corrections(
    raw_energies: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: torch.Tensor,
    batch_idx: torch.Tensor | None = None,
) -> torch.Tensor:
    """Apply self-energy and background corrections to PME energies.

    Converts raw interpolated potential to energy and subtracts corrections:

    .. math::

        E_i = q_i \\phi_i - E_{\\text{self},i} - E_{\\text{background},i}

    Self-energy correction (removes Gaussian self-interaction):

    .. math::

        E_{\\text{self},i} = \\frac{\\alpha}{\\sqrt{\\pi}} q_i^2

    Background correction (for non-neutral systems):

    .. math::

        E_{\\text{background},i} = \\frac{\\pi}{2\\alpha^2 V} q_i Q_{\\text{total}}

    Parameters
    ----------
    raw_energies : torch.Tensor, shape (N,) or (N_total,)
        Raw potential values :math:`\\phi_i` from mesh interpolation.
    charges : torch.Tensor, shape (N,) or (N_total,)
        Atomic charges.
    cell : torch.Tensor
        Unit cell matrices.

        - Single-system: shape (3, 3) or (1, 3, 3)
        - Batch: shape (B, 3, 3)
    alpha : torch.Tensor
        Ewald splitting parameter.

        - Single-system: shape (1,)
        - Batch: shape (B,)
    batch_idx : torch.Tensor | None, default=None
        System index for each atom. If provided, uses batch kernels.

    Returns
    -------
    corrected_energies : torch.Tensor, shape (N,) or (N_total,)
        Final per-atom reciprocal-space energy with corrections applied.

    Notes
    -----
    - For neutral systems, the background correction is zero.
    - Matches torchpme's self_contribution and background_correction formulas.
    - Supports both float32 and float64 dtypes.
    """
    input_dtype = raw_energies.dtype

    if batch_idx is None:
        # Single system - ensure tensors are 1D for kernel indexing
        total_charge = charges.sum().reshape(1)
        volume = torch.abs(torch.det(cell)).reshape(1)
        result = _pme_energy_corrections(
            raw_energies,
            charges.to(input_dtype),
            volume.to(input_dtype),
            alpha.to(input_dtype),
            total_charge.to(input_dtype),
        )
    else:
        # Batch
        num_systems = cell.shape[0]
        volumes = torch.abs(torch.linalg.det(cell)).to(input_dtype)

        # Compute total charge per system (scatter_add_ requires int64 indices)
        total_charges = torch.zeros(
            num_systems, dtype=input_dtype, device=raw_energies.device
        )
        total_charges.scatter_add_(0, batch_idx.long(), charges.to(input_dtype))

        result = _batch_pme_energy_corrections(
            raw_energies,
            charges.to(input_dtype),
            batch_idx,
            volumes,
            alpha.to(input_dtype),
            total_charges,
        )

    return result
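Where the α/√π self-energy coefficient comes from can be checked numerically. A Gaussian screening cloud of total charge q and width parameter α produces the potential q·erf(αr)/r, whose limit at r → 0 is 2αq/√π; half of q times that value (from the usual E = ½ Σ q_i φ_i) is exactly the self term above. The sketch below uses only the standard library; `gaussian_self_energy` is a hypothetical helper that evaluates the potential at a small radius instead of taking the limit analytically.

```python
import math


def gaussian_self_energy(q, alpha, r_small=1e-6):
    """Self-interaction energy of a charge with its own Gaussian cloud.

    The cloud density q*(alpha/sqrt(pi))**3*exp(-(alpha*r)**2) has potential
    q*erf(alpha*r)/r; we sample it at a tiny radius to approximate r -> 0.
    """
    phi_at_center = q * math.erf(alpha * r_small) / r_small
    # Factor 1/2 from the standard E = (1/2) * sum_i q_i * phi_i expression.
    return 0.5 * q * phi_at_center
```

The result agrees with (α/√π)·q² to high relative precision, because erf(x)/x deviates from 2/√π only at order x².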


def pme_energy_corrections_with_charge_grad(
    raw_energies: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: torch.Tensor,
    batch_idx: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply corrections and compute charge gradients for PME energies.

    Computes both corrected per-atom energies and analytical gradients of the
    total reciprocal-space energy E = ∑_j E_j with respect to each charge:

        E_i = q_i * φ_i - E_self_i - E_background_i
        ∂E/∂q_i = 2*φ_i - 2*(α/√π)*q_i - (π/(α²V))*Q_total

    The factor of 2 on φ_i arises because changing q_i affects both the direct
    energy term (q_i * φ_i) and all other potentials through the structure
    factor: the potential is linear in the charges through a symmetric kernel,
    so ∑_j q_j * ∂φ_j/∂q_i = φ_i.

    Parameters
    ----------
    raw_energies : torch.Tensor, shape (N,) or (N_total,)
        Raw potential values φ_i from mesh interpolation.
    charges : torch.Tensor, shape (N,) or (N_total,)
        Atomic charges.
    cell : torch.Tensor
        Unit cell matrices.

        - Single-system: shape (3, 3) or (1, 3, 3)
        - Batch: shape (B, 3, 3)
    alpha : torch.Tensor
        Ewald splitting parameter.

        - Single-system: shape (1,)
        - Batch: shape (B,)
    batch_idx : torch.Tensor | None, default=None
        System index for each atom. If provided, uses batch kernels.

    Returns
    -------
    corrected_energies : torch.Tensor, shape (N,) or (N_total,)
        Final per-atom reciprocal-space energy with corrections applied.
    charge_gradients : torch.Tensor, shape (N,) or (N_total,)
        Analytical charge gradients ∂E/∂q_i.
    """
    input_dtype = raw_energies.dtype

    if batch_idx is None:
        # Single system
        total_charge = charges.sum().reshape(1)
        volume = torch.abs(torch.det(cell)).reshape(1)
        return _pme_energy_corrections_with_charge_grad(
            raw_energies,
            charges.to(input_dtype),
            volume.to(input_dtype),
            alpha.to(input_dtype),
            total_charge.to(input_dtype),
        )
    else:
        # Batch
        num_systems = cell.shape[0]
        volumes = torch.abs(torch.linalg.det(cell)).to(input_dtype)

        # Compute total charge per system (scatter_add_ requires int64 indices)
        total_charges = torch.zeros(
            num_systems, dtype=input_dtype, device=raw_energies.device
        )
        total_charges.scatter_add_(0, batch_idx.long(), charges.to(input_dtype))

        return _batch_pme_energy_corrections_with_charge_grad(
            raw_energies,
            charges.to(input_dtype),
            batch_idx,
            volumes,
            alpha.to(input_dtype),
            total_charges,
        )
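The factor of 2 on φ_i in the charge gradient can be checked with central finite differences on a toy model. Here the mesh convolution is replaced by an arbitrary symmetric matrix A so that φ = A q (an assumption made only for illustration); the helper names are hypothetical. Because the total energy is quadratic in q, central differences reproduce the analytical gradient up to rounding.

```python
import numpy as np


def toy_energy(q, A, alpha, volume):
    """Total corrected energy sum_i E_i for a toy linear potential phi = A @ q."""
    phi = A @ q
    total_charge = q.sum()
    e_self = alpha / np.sqrt(np.pi) * np.sum(q**2)
    e_background = np.pi / (2.0 * alpha**2 * volume) * total_charge**2
    return q @ phi - e_self - e_background


def analytic_charge_grad(q, A, alpha, volume):
    """dE/dq_i = 2*phi_i - 2*(alpha/sqrt(pi))*q_i - pi/(alpha**2*V)*Q_total."""
    phi = A @ q
    return (
        2.0 * phi
        - 2.0 * alpha / np.sqrt(np.pi) * q
        - np.pi / (alpha**2 * volume) * q.sum()
    )


def finite_difference_grad(f, q, h=1e-5):
    """Central-difference gradient of a scalar function f at q."""
    grad = np.zeros_like(q)
    for i in range(q.size):
        step = np.zeros_like(q)
        step[i] = h
        grad[i] = (f(q + step) - f(q - step)) / (2.0 * h)
    return grad
```

A usage check with a random symmetric kernel:

```python
rng = np.random.default_rng(0)
M = rng.standard_normal((5, 5))
A = 0.5 * (M + M.T)  # symmetric, like a pair interaction kernel
q = rng.standard_normal(5)
fd = finite_difference_grad(lambda x: toy_energy(x, A, 0.3, 1000.0), q)
np.allclose(fd, analytic_charge_grad(q, A, 0.3, 1000.0), atol=1e-6)
```

The symmetry of A is what turns d(qᵀAq)/dq into 2Aq = 2φ.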


###########################################################################################
########################### Unified PME Reciprocal Space ##################################
###########################################################################################


def _compute_pme_reciprocal_virial(
    mesh_fft_raw: torch.Tensor,
    convolved_mesh: torch.Tensor,
    k_vectors: torch.Tensor,
    k_squared: torch.Tensor,
    alpha: torch.Tensor,
    mesh_dimensions: tuple[int, int, int],
    is_batch: bool,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Compute PME reciprocal-space virial tensor in k-space.

    Uses the exact spectral pair from the pipeline (mesh_fft_raw before
    deconvolution, and convolved_mesh after Green's function multiplication)
    to compute the per-k energy density directly via Parseval's theorem.

    The virial per k-point is W_ab(k) = E_k * sigma_ab(k), where:

    - E_k = weight(k) * Re(mesh_fft_raw(k) * convolved_mesh(k)*), with weight(k)
      the rfft multiplicity (2 for interior k_z, 1 for k_z = 0 and the Nyquist
      plane); the 2*pi/V prefactor is already folded into G(k)
    - sigma_ab(k) = delta_ab - 2*k_a*k_b/k^2 * (1 + k^2/(4*alpha^2))
      (sign reflects W = -dE/dε convention)

    Parameters
    ----------
    mesh_fft_raw : torch.Tensor
        Raw rfftn output before B-spline deconvolution.
        Shape (nx, ny, nz//2+1) or (B, nx, ny, nz//2+1), complex.
    convolved_mesh : torch.Tensor
        Deconvolved mesh FFT multiplied by Green's function: (mesh_fft/B^2)*G.
        Shape matching mesh_fft_raw.
    k_vectors : torch.Tensor
        k-vectors on the mesh. Shape (..., nx, ny, nz//2+1, 3).
    k_squared : torch.Tensor
        |k|^2. Shape (..., nx, ny, nz//2+1).
    alpha : torch.Tensor
        Ewald splitting parameter.
    mesh_dimensions : tuple
        (nx, ny, nz).
    is_batch : bool
        Whether this is a batched calculation.
    device : torch.device
        Computation device.
    dtype : torch.dtype
        Output dtype.

    Returns
    -------
    virial : torch.Tensor, shape (B, 3, 3) or (1, 3, 3)
        Per-system virial tensor.
    """
    mesh_nx, mesh_ny, mesh_nz = mesh_dimensions

    # Per-k energy density from exact pipeline spectral pair.
    # Re(mesh_fft_raw * convolved_mesh*) = |mesh_fft_raw|^2 * G / B^2
    #
    # Explicit complex/real dtype mapping is needed because `dtype` is a
    # real-valued dtype (float32 or float64) but the FFT mesh data is complex.
    # PyTorch has no implicit real-to-complex dtype promotion, so we map
    # float32 -> complex64 and float64 -> complex128 explicitly.
    complex_dtype = torch.complex64 if dtype == torch.float32 else torch.complex128
    acc_dtype = dtype  # real accumulation dtype matches input precision

    fft_raw_cast = mesh_fft_raw.to(complex_dtype)
    conv_cast = convolved_mesh.to(complex_dtype)
    energy_density = (fft_raw_cast * conv_cast.conj()).real

    # Weight for rfft symmetry: 2 for interior k_z, 1 for boundary
    weight = torch.full_like(energy_density, 2.0)
    weight[..., 0] = 1.0  # k_z = 0
    if mesh_nz % 2 == 0:
        weight[..., -1] = 1.0  # k_z = nz//2 (Nyquist)

    # Weighted energy density
    weighted_energy = weight * energy_density

    # Virial W = -dE/dε, so sigma_ab = delta_ab - 2*k_a*k_b/k^2 * (1 + k^2/(4*alpha^2))
    k_sq_acc = k_squared.to(acc_dtype)
    alpha_acc = alpha.to(acc_dtype)

    # Handle alpha broadcasting: alpha may be (B,) for batch
    if is_batch and alpha_acc.dim() == 1:
        alpha_view = alpha_acc.view(-1, 1, 1, 1)
    else:
        alpha_view = alpha_acc.view(-1) if alpha_acc.dim() == 0 else alpha_acc

    exp_factor = 0.25 / (alpha_view**2)

    # Avoid division by zero at k=0
    safe_k_sq = k_sq_acc.clamp(min=1e-30)
    k_factor = 2.0 * (1.0 + k_sq_acc * exp_factor) / safe_k_sq

    # Zero out k=0 contribution (no virial from k=0)
    k_mask = k_sq_acc > 1e-10

    # Vectorized virial computation using einsum
    # virial_ab = sum_k weighted_energy * (delta_ab - k_factor * k_a * k_b) * k_mask
    #           = delta_ab * sum_k (weighted_energy * k_mask)
    #             - sum_k (weighted_energy * k_mask * k_factor) * k_a * k_b
    k_vecs_acc = k_vectors.to(acc_dtype)  # (..., nx, ny, nz//2+1, 3)
    masked_energy = weighted_energy * k_mask  # (..., nx, ny, nz//2+1)
    masked_energy_kf = masked_energy * k_factor  # (..., nx, ny, nz//2+1)

    # Sum dimensions depend on batch vs single
    if is_batch:
        sum_dims = (1, 2, 3)  # sum over (nx, ny, nz//2+1)
    else:
        sum_dims = (0, 1, 2)  # sum over (nx, ny, nz//2+1)

    # Trace term: delta_ab * sum_k masked_energy
    trace_term = masked_energy.sum(dim=sum_dims)  # scalar or (B,)

    # kk term: sum_k masked_energy_kf * k_a * k_b
    # k_vecs_acc has shape (..., nx, ny, nz//2+1, 3)
    # masked_energy_kf has shape (..., nx, ny, nz//2+1)
    # Use einsum for vectorized outer product + reduction
    if is_batch:
        # k_vecs: (B, nx, ny, nz_half, 3), masked_energy_kf: (B, nx, ny, nz_half)
        kk_term = torch.einsum(
            "b...i,b...j,b...->bij", k_vecs_acc, k_vecs_acc, masked_energy_kf
        )  # (B, 3, 3)
        eye = torch.eye(3, device=device, dtype=acc_dtype)
        virial = eye * trace_term[:, None, None] - kk_term  # (B, 3, 3)
    else:
        kk_term = torch.einsum(
            "...i,...j,...->ij", k_vecs_acc, k_vecs_acc, masked_energy_kf
        )  # (3, 3)
        eye = torch.eye(3, device=device, dtype=acc_dtype)
        virial = (eye * trace_term - kk_term).unsqueeze(0)  # (1, 3, 3)

    return virial.to(dtype)


def _pme_reciprocal_space_impl(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: torch.Tensor,
    mesh_dimensions: tuple[int, int, int],
    spline_order: int,
    batch_idx: torch.Tensor | None,
    compute_forces: bool = False,
    compute_charge_gradients: bool = False,
    compute_virial: bool = False,
    k_vectors: torch.Tensor | None = None,
    k_squared: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
    """Internal implementation of PME reciprocal space calculation.

    Uses unified spline functions from nvalchemiops.spline for charge assignment
    and potential interpolation, and Warp kernels for Green's function and
    corrections.

    Supports both float32 and float64 dtypes - all operations are performed
    in the input dtype without conversion.
    """
    device = positions.device
    input_dtype = positions.dtype
    num_atoms = positions.shape[0]
    is_batch = batch_idx is not None
    fft_dims = (1, 2, 3) if is_batch else (0, 1, 2)

    if num_atoms == 0:
        energies = torch.zeros(num_atoms, device=device, dtype=input_dtype)
        forces = (
            torch.zeros(num_atoms, 3, device=device, dtype=input_dtype)
            if compute_forces
            else None
        )
        charge_grads = (
            torch.zeros(num_atoms, device=device, dtype=input_dtype)
            if compute_charge_gradients
            else None
        )
        num_systems = cell.shape[0] if is_batch else 1
        virial = (
            torch.zeros(num_systems, 3, 3, device=device, dtype=input_dtype)
            if compute_virial
            else None
        )
        return energies, forces, charge_grads, virial

    mesh_nx, mesh_ny, mesh_nz = mesh_dimensions

    # Precompute cell inverse ONCE and derive what we need for all operations
    cell_inv = torch.linalg.inv_ex(cell)[0]
    cell_inv_t = cell_inv.transpose(-1, -2).contiguous()
    reciprocal_cell = TWOPI * cell_inv

    # Step 1: Charge assignment using unified spline_spread API
    mesh_grid = spline_spread(
        positions,
        charges,
        cell,
        mesh_dims=(mesh_nx, mesh_ny, mesh_nz),
        spline_order=spline_order,
        batch_idx=batch_idx,
        cell_inv_t=cell_inv_t,
    )

    # Step 2: FFT of charge mesh
    mesh_fft = torch.fft.rfftn(mesh_grid, norm="backward", dim=fft_dims)

    # Step 3: Generate k-space grid and compute Green's function + structure factor
    # Green's function: G(k) = 2*pi * exp(-k^2/(4*alpha^2)) / (V * k^2)
    # (includes 1/2 pair-counting factor; see pme_kernels.py)
    # Use precomputed k_vectors/k_squared if provided, otherwise generate them
    if k_vectors is None or k_squared is None:
        k_vectors, k_squared = generate_k_vectors_pme(
            cell, mesh_dimensions=mesh_dimensions, reciprocal_cell=reciprocal_cell
        )

    green_function, structure_factor_sq = pme_green_structure_factor(
        k_squared,
        mesh_dimensions,
        alpha,
        cell,
        spline_order,
        batch_idx=batch_idx,
    )

    # Save reference to raw FFT before deconvolution (needed for virial).
    # No clone needed: the reassignment below creates a new tensor.
    mesh_fft_raw = mesh_fft if compute_virial else None

    # Step 4: Apply B-spline deconvolution and convolve with Green's function
    mesh_fft = mesh_fft / structure_factor_sq
    convolved_mesh = mesh_fft * green_function

    # Step 5: Inverse FFT to get potential mesh
    potential_mesh = torch.fft.irfftn(
        convolved_mesh, norm="forward", s=mesh_dimensions, dim=fft_dims
    )
    potential_mesh = potential_mesh.to(input_dtype)

    # Step 6: Interpolate potential to atomic positions using unified spline_gather API
    # Note: raw_energies are already volume-normalized from Green's function
    raw_energies = spline_gather(
        positions,
        potential_mesh,
        cell,
        spline_order=spline_order,
        batch_idx=batch_idx,
        cell_inv_t=cell_inv_t,
    )

    # Step 7: Apply corrections using Warp kernel
    # Use charge gradient version if requested
    charge_grads = None
    if compute_charge_gradients:
        reciprocal_energies, charge_grads = pme_energy_corrections_with_charge_grad(
            raw_energies, charges, cell, alpha, batch_idx
        )
    else:
        reciprocal_energies = pme_energy_corrections(
            raw_energies, charges, cell, alpha, batch_idx
        )

    # Step 8: Compute virial before forces to allow early release of mesh_fft_raw
    # (virial needs mesh_fft_raw; forces only need convolved_mesh)
    virial = None
    if compute_virial:
        virial = _compute_pme_reciprocal_virial(
            mesh_fft_raw=mesh_fft_raw,
            convolved_mesh=convolved_mesh,
            k_vectors=k_vectors,
            k_squared=k_squared,
            alpha=alpha,
            mesh_dimensions=mesh_dimensions,
            is_batch=is_batch,
            device=device,
            dtype=input_dtype,
        )
        del mesh_fft_raw  # Free before force field meshes are allocated

    # Step 9: Compute forces if needed
    forces = None
    if compute_forces:
        # Compute electric field by taking gradient in Fourier space
        # Note: convolved_mesh is already volume-normalized from Green's function
        Ex_fft = -1j * k_vectors[..., 0] * convolved_mesh
        Ey_fft = -1j * k_vectors[..., 1] * convolved_mesh
        Ez_fft = -1j * k_vectors[..., 2] * convolved_mesh

        Ex = torch.fft.irfftn(Ex_fft, norm="forward", s=mesh_dimensions, dim=fft_dims)
        Ey = torch.fft.irfftn(Ey_fft, norm="forward", s=mesh_dimensions, dim=fft_dims)
        Ez = torch.fft.irfftn(Ez_fft, norm="forward", s=mesh_dimensions, dim=fft_dims)
        electric_field_mesh = torch.stack([Ex, Ey, Ez], dim=-1).to(input_dtype)

        # Use unified spline_gather_vec3 API to interpolate electric field
        interpolated_field = spline_gather_vec3(
            positions,
            charges,
            electric_field_mesh,
            cell,
            spline_order=spline_order,
            batch_idx=batch_idx,
            cell_inv_t=cell_inv_t,
        )

        # Forces: F_i = 2 * q_i * E(r_i).
        # - spline_gather_vec3 receives `charges` and returns the charge-weighted
        #   field, so q_i is not applied again here.
        # - No explicit 1/V: the volume normalization is already in G(k).
        # - The factor 2 compensates the 1/2 pair-counting factor folded into
        #   G(k) (same symmetry argument as the 2*phi_i charge-gradient term).
        forces = 2.0 * interpolated_field

    return reciprocal_energies, forces, charge_grads, virial
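Steps 8 and 9 rely on two properties of the rfftn half-spectrum: interior k_z columns stand for a ±k_z pair (hence the 2/1 weights in the virial), and a spatial derivative becomes a multiplication by i·k (hence E = -ikφ for the field). The NumPy sketch below checks both on small grids; `rfft_half_weights` and `spectral_field_x` are hypothetical helpers. It uses NumPy's default normalization on both transforms (round trip is the identity), which differs from the `norm="backward"`/`norm="forward"` pairing used above; `torch.fft` follows the same sign convention as `numpy.fft`.

```python
import numpy as np


def rfft_half_weights(nz):
    """Multiplicity of each k_z column of an rfftn half-spectrum (last axis)."""
    w = np.full(nz // 2 + 1, 2.0)
    w[0] = 1.0  # k_z = 0 has no mirror partner
    if nz % 2 == 0:
        w[-1] = 1.0  # the Nyquist column is its own mirror
    return w


def spectral_field_x(phi, box_length):
    """E_x = -d(phi)/dx on a periodic cubic grid via multiplication by -i*k_x."""
    n = phi.shape[0]
    phi_k = np.fft.rfftn(phi)  # axis 0 keeps all n frequencies
    kx = 2.0 * np.pi * np.fft.fftfreq(n, d=box_length / n)  # angular k along axis 0
    ex_k = -1j * kx[:, None, None] * phi_k
    return np.fft.irfftn(ex_k, s=phi.shape)
```

With these weights, the weighted sum of |F|² over the half-spectrum equals the full-spectrum sum (Parseval), and for φ = sin(2πx/L) the spectral field equals -(2π/L)·cos(2πx/L).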


def pme_reciprocal_space(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: float | torch.Tensor,
    mesh_dimensions: tuple[int, int, int] | None = None,
    mesh_spacing: float | None = None,
    spline_order: int = 4,
    batch_idx: torch.Tensor | None = None,
    k_vectors: torch.Tensor | None = None,
    k_squared: torch.Tensor | None = None,
    compute_forces: bool = False,
    compute_charge_gradients: bool = False,
    compute_virial: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """Compute PME reciprocal-space energy and, optionally, forces, charge gradients, and virial.

    Performs the FFT-based reciprocal-space calculation using the Particle Mesh
    Ewald algorithm. This achieves O(N log N) scaling through:

    1. B-spline charge interpolation to mesh (spreading)
    2. FFT of charge mesh to reciprocal space
    3. B-spline deconvolution and convolution with the Green's function
       (multiply by G(k) / B^2(k))
    4. Inverse FFT back to real space (potential mesh)
    5. B-spline interpolation of potential to atoms (gathering)
    6. Self-energy and background corrections

    Formula
    -------
    The reciprocal-space energy is computed via the mesh potential:

    .. math::

        \\varphi_{\\text{mesh}}(k) = \\frac{G(k)}{B^2(k)} \\, \\rho_{\\text{mesh}}(k)

    where:

    - :math:`G(k) = \\frac{2\\pi}{V k^2} \\exp(-k^2/(4\\alpha^2))` is the Green's
      function as implemented: the Coulomb kernel :math:`4\\pi/k^2` with the
      :math:`1/V` normalization and the 1/2 pair-counting factor folded in
    - :math:`B^2(k)` is the squared B-spline structure factor; dividing by it
      corrects for the smoothing introduced by spline interpolation
    - :math:`\\rho_{\\text{mesh}}(k)` is the FFT of interpolated charges

    Parameters
    ----------
    positions : torch.Tensor, shape (N, 3)
        Atomic coordinates. Supports float32 or float64 dtype.
    charges : torch.Tensor, shape (N,)
        Atomic partial charges in elementary charge units.
    cell : torch.Tensor, shape (3, 3) or (B, 3, 3)
        Unit cell matrices with lattice vectors as rows. Shape (3, 3) is
        automatically promoted to (1, 3, 3).
    alpha : float or torch.Tensor
        Ewald splitting parameter controlling real/reciprocal space balance.

        - float: Same α for all systems
        - Tensor shape (B,): Per-system α values
    mesh_dimensions : tuple[int, int, int], optional
        Explicit FFT mesh dimensions (nx, ny, nz). Power-of-2 values are
        optimal for FFT performance. Either mesh_dimensions or mesh_spacing
        must be provided.
    mesh_spacing : float, optional
        Target mesh spacing in same units as cell. Mesh dimensions are computed
        as ceil(cell_length / mesh_spacing); for batched input, the lattice
        vector lengths of the first cell are used. Typical value: ~1 Å.
    spline_order : int, default=4
        B-spline interpolation order. Higher orders are more accurate but slower.

        - 4: Cubic B-splines (good balance, most common)
        - 5-6: Higher accuracy for demanding applications
        - Must be ≥ 3 for smooth interpolation
    batch_idx : torch.Tensor, shape (N,), dtype=int32, optional
        System index for each atom (0 to B-1). Determines kernel dispatch:

        - None: Single-system optimized kernels
        - Provided: Batched kernels for multiple independent systems
    k_vectors : torch.Tensor, shape (nx, ny, nz//2+1, 3), optional
        Precomputed k-vectors from ``generate_k_vectors_pme``. Providing this
        along with k_squared skips k-vector generation (~15% speedup). Can be
        precomputed once and reused when cell and mesh are unchanged.
    k_squared : torch.Tensor, shape (nx, ny, nz//2+1), optional
        Precomputed :math:`|k|^2` values. Must be provided together with
        k_vectors.
    compute_forces : bool, default=False
        Whether to compute explicit reciprocal-space forces.
    compute_charge_gradients : bool, default=False
        Whether to compute analytical charge gradients ∂E/∂q_i. Useful for
        computing charge Hessians in ML potential training.
    compute_virial : bool, default=False
        Whether to compute the virial tensor W = -dE/d(epsilon).
        Stress = virial / volume.

    Returns
    -------
    energies : torch.Tensor, shape (N,)
        Per-atom reciprocal-space energy (includes self and background
        corrections).
    forces : torch.Tensor, shape (N, 3), optional
        Reciprocal-space forces. Only returned if compute_forces=True.
    charge_gradients : torch.Tensor, shape (N,), optional
        Charge gradients ∂E_recip/∂q_i. Only returned if
        compute_charge_gradients=True.
    virial : torch.Tensor, shape (1, 3, 3) or (B, 3, 3), optional
        Virial tensor. Only returned if compute_virial=True. Always last in tuple.

    Note
    ----
    Energies are always float64 for numerical stability during accumulation.
    Forces and virial match the input dtype (float32 or float64).

    Return Patterns
    ---------------
    Enabled flags are appended in order: energies, [forces], [charge_gradients],
    [virial]. A single output is returned unwrapped; multiple outputs as a tuple.

    Raises
    ------
    ValueError
        If neither mesh_dimensions nor mesh_spacing is provided.

    Examples
    --------
    Energy only with explicit mesh dimensions::

        >>> energies = pme_reciprocal_space(
        ...     positions, charges, cell,
        ...     alpha=0.3, mesh_dimensions=(32, 32, 32),
        ... )
        >>> total_recip_energy = energies.sum()

    With forces using mesh spacing::

        >>> energies, forces = pme_reciprocal_space(
        ...     positions, charges, cell,
        ...     alpha=0.3, mesh_spacing=1.0,
        ...     compute_forces=True,
        ... )

    Precomputed k-vectors for MD loop (fixed cell)::

        >>> from nvalchemiops.torch.interactions.electrostatics import generate_k_vectors_pme
        >>> mesh_dims = (32, 32, 32)
        >>> k_vectors, k_squared = generate_k_vectors_pme(cell, mesh_dims)
        >>> for step in range(num_steps):
        ...     energies = pme_reciprocal_space(
        ...         positions, charges, cell,
        ...         alpha=0.3, mesh_dimensions=mesh_dims,
        ...         k_vectors=k_vectors, k_squared=k_squared,
        ...     )

    With charge gradients for ML training::

        >>> energies, charge_grads = pme_reciprocal_space(
        ...     positions, charges, cell,
        ...     alpha=0.3, mesh_dimensions=(32, 32, 32),
        ...     compute_charge_gradients=True,
        ... )

    See Also
    --------
    particle_mesh_ewald : Complete PME calculation (real + reciprocal).
    generate_k_vectors_pme : Generate k-vectors for this function.
    pme_green_structure_factor : Compute Green's function on mesh.
    """
    cell, num_systems = _prepare_cell(cell)
    alpha_tensor = _prepare_alpha(alpha, num_systems, torch.float64, positions.device)

    # Determine mesh dimensions
    if mesh_dimensions is None:
        if mesh_spacing is None:
            raise ValueError("Either mesh_dimensions or mesh_spacing must be provided")
        cell_lengths = torch.norm(cell[0], dim=1)
        mesh_dimensions = tuple(
            int(torch.ceil(length / mesh_spacing).item()) for length in cell_lengths
        )

    energies, forces, charge_grads, virial = _pme_reciprocal_space_impl(
        positions,
        charges,
        cell,
        alpha_tensor,
        mesh_dimensions,
        spline_order,
        batch_idx,
        compute_forces=compute_forces,
        compute_charge_gradients=compute_charge_gradients,
        compute_virial=compute_virial,
        k_vectors=k_vectors,
        k_squared=k_squared,
    )

    # Build return tuple based on flags
    match (compute_forces, compute_charge_gradients, compute_virial):
        case (True, True, True):
            return energies, forces, charge_grads, virial
        case (True, True, False):
            return energies, forces, charge_grads
        case (True, False, True):
            return energies, forces, virial
        case (True, False, False):
            return energies, forces
        case (False, True, True):
            return energies, charge_grads, virial
        case (False, True, False):
            return energies, charge_grads
        case (False, False, True):
            return energies, virial
        case _:
            return energies
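Steps 2 to 4 of the pipeline amount to solving a Gaussian-screened Poisson equation by FFT. The NumPy sketch below applies the kernel 4π/k²·exp(-k²/(4α²)) to a single-mode density on a periodic cubic grid and compares the result with the analytical answer. It is an illustration under simplifying assumptions: grid values are used directly (no B-spline spreading, so no 1/B² deconvolution), it uses the plain 4π/k² normalization rather than the library's 2π/(Vk²) form, and `screened_poisson_fft` is a hypothetical helper.

```python
import numpy as np


def screened_poisson_fft(rho, box_length, alpha):
    """phi(k) = 4*pi/k^2 * exp(-k^2/(4*alpha^2)) * rho(k) on a periodic cubic grid."""
    n = rho.shape[0]
    k_full = 2.0 * np.pi * np.fft.fftfreq(n, d=box_length / n)   # axes 0 and 1
    k_half = 2.0 * np.pi * np.fft.rfftfreq(n, d=box_length / n)  # last rfftn axis
    kx, ky, kz = np.meshgrid(k_full, k_full, k_half, indexing="ij")
    k2 = kx**2 + ky**2 + kz**2

    green = np.zeros_like(k2)
    nonzero = k2 > 0  # G(0) = 0 drops the mean charge (neutralizing background)
    green[nonzero] = 4.0 * np.pi / k2[nonzero] * np.exp(-k2[nonzero] / (4.0 * alpha**2))

    return np.fft.irfftn(np.fft.rfftn(rho) * green, s=rho.shape)
```

For ρ = cos(k₀x) with k₀ = 2π/L, the result is (4π/k₀²)·exp(-k₀²/(4α²))·cos(k₀x), and a uniform density yields zero potential.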


###########################################################################################
########################### Unified PME API ###############################################
###########################################################################################


def particle_mesh_ewald(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: float | torch.Tensor | None = None,
    mesh_spacing: float | None = None,
    mesh_dimensions: tuple[int, int, int] | None = None,
    spline_order: int = 4,
    batch_idx: torch.Tensor | None = None,
    k_vectors: torch.Tensor | None = None,
    k_squared: torch.Tensor | None = None,
    neighbor_list: torch.Tensor | None = None,
    neighbor_ptr: torch.Tensor | None = None,
    neighbor_shifts: torch.Tensor | None = None,
    neighbor_matrix: torch.Tensor | None = None,
    neighbor_matrix_shifts: torch.Tensor | None = None,
    mask_value: int | None = None,
    compute_forces: bool = False,
    compute_charge_gradients: bool = False,
    compute_virial: bool = False,
    accuracy: float = 1e-6,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """Complete Particle Mesh Ewald (PME) calculation for long-range electrostatics.

    Computes total Coulomb energy using the PME method, which achieves
    :math:`O(N \\log N)` scaling through FFT-based reciprocal space calculations.
    Combines:

    1. Real-space contribution (short-range, erfc-damped)
    2. Reciprocal-space contribution (long-range, FFT + B-spline interpolation)
    3. Self-energy and background corrections

    Total Energy Formula:

    .. math::

        E_{\\text{total}} = E_{\\text{real}} + E_{\\text{reciprocal}} - E_{\\text{self}} - E_{\\text{background}}

    where

    .. math::

        E_{\\text{real}} = \\frac{1}{2} \\sum_{i \\neq j} q_i q_j \\frac{\\text{erfc}(\\alpha r_{ij})}{r_{ij}}

    .. math::

        E_{\\text{self}} = \\frac{\\alpha}{\\sqrt{\\pi}} \\sum_i q_i^2

    .. math::

        E_{\\text{background}} = \\frac{\\pi}{2\\alpha^2 V} Q_{\\text{total}}^2

    and :math:`E_{\\text{reciprocal}}` is the smooth long-range contribution
    evaluated on the FFT mesh. These conventions match the self-energy and
    background terms used by ``pme_energy_corrections``.

    Parameters
    ----------
    positions : torch.Tensor, shape (N, 3)
        Atomic coordinates. Supports float32 or float64 dtype.
    charges : torch.Tensor, shape (N,)
        Atomic partial charges in elementary charge units.
    cell : torch.Tensor, shape (3, 3) or (B, 3, 3)
        Unit cell matrices with lattice vectors as rows.
        Shape (3, 3) is automatically promoted to (1, 3, 3) for single-system mode.
    alpha : float, torch.Tensor, or None, default=None
        Ewald splitting parameter controlling real/reciprocal space balance.

        - float: Same α for all systems
        - Tensor shape (B,): Per-system α values
        - None: Automatically estimated using Kolafa-Perram formula

        Larger α shifts more computation to reciprocal space.
    mesh_spacing : float, optional
        Target mesh spacing in same units as cell (typically Å).
        Mesh dimensions computed as ceil(cell_length / mesh_spacing).
        Typical value: 0.8-1.2 Å.
    mesh_dimensions : tuple[int, int, int], optional
        Explicit FFT mesh dimensions (nx, ny, nz). Power-of-2 values
        recommended for optimal FFT performance. If None and mesh_spacing
        is None, computed from accuracy parameter.
    spline_order : int, default=4
        B-spline interpolation order. Higher orders are more accurate but slower.

        - 4: Cubic B-splines (standard, good accuracy/speed balance)
        - 5-6: Higher accuracy for demanding applications
    batch_idx : torch.Tensor, shape (N,), dtype=int32, optional
        System index for each atom (0 to B-1). Determines execution mode:

        - None: Single-system optimized kernels
        - Provided: Batched kernels for multiple independent systems
    k_vectors : torch.Tensor, shape (nx, ny, nz//2+1, 3), optional
        Precomputed k-vectors from ``generate_k_vectors_pme``. Providing this
        along with k_squared skips k-vector generation (~15% speedup).
        Useful for fixed-cell MD simulations (NVT/NVE).
    k_squared : torch.Tensor, shape (nx, ny, nz//2+1), optional
        Precomputed :math:`|k|^2` values. Must be provided together with
        k_vectors.
    neighbor_list : torch.Tensor, shape (2, M), dtype=int32, optional
        Neighbor pairs for real-space in COO format. Row 0 = source indices,
        row 1 = target indices. Mutually exclusive with neighbor_matrix.
    neighbor_ptr : torch.Tensor, shape (N+1,), dtype=int32, optional
        CSR row pointers for neighbor_list. neighbor_ptr[i] gives the starting
        index in neighbor_list for atom i's neighbors.
        Required with neighbor_list.
    neighbor_shifts : torch.Tensor, shape (M, 3), dtype=int32, optional
        Periodic image shifts for neighbor_list. Required with neighbor_list.
    neighbor_matrix : torch.Tensor, shape (N, max_neighbors), dtype=int32, optional
        Dense neighbor matrix format. Entry [i, k] = j means j is k-th neighbor of i.
        Invalid entries should be set to mask_value.
        Mutually exclusive with neighbor_list.
    neighbor_matrix_shifts : torch.Tensor, shape (N, max_neighbors, 3), dtype=int32, optional
        Periodic image shifts for neighbor_matrix. Required with neighbor_matrix.
    mask_value : int, optional
        Value indicating invalid entries in neighbor_matrix. Defaults to N.
    compute_forces : bool, default=False
        Whether to compute explicit analytical forces.
    compute_charge_gradients : bool, default=False
        Whether to compute analytical charge gradients ∂E/∂q_i. Useful for
        training ML potentials that require second derivatives (charge Hessians).
    compute_virial : bool, default=False
        Whether to compute the virial tensor W = -dE/d(epsilon).
        Stress = virial / volume.
    accuracy : float, default=1e-6
        Target relative accuracy for automatic parameter estimation (α, mesh dims).
        Only used when alpha or mesh_dimensions is None.
        Smaller values increase accuracy but also computational cost.

    Returns
    -------
    energies : torch.Tensor, shape (N,)
        Per-atom contribution to total PME energy. Sum gives total energy.
    forces : torch.Tensor, shape (N, 3), optional
        Forces on each atom. Only returned if compute_forces=True.
    charge_gradients : torch.Tensor, shape (N,), optional
        Charge gradients ∂E/∂q_i. Only returned if compute_charge_gradients=True.
    virial : torch.Tensor, shape (1, 3, 3) or (B, 3, 3), optional
        Virial tensor. Only returned if compute_virial=True. Always last in tuple.

    Note
    ----
    Energies are always float64 for numerical stability during accumulation.
    Forces and virial match the input dtype (float32 or float64).

    Return Patterns
    ---------------
    Enabled flags are appended in order: energies, [forces], [charge_gradients],
    [virial]. A single output is returned unwrapped; multiple outputs as a tuple.

    Raises
    ------
    ValueError
        If neither neighbor_list nor neighbor_matrix is provided for real-space.
    TypeError
        If alpha has an unsupported type.

    Examples
    --------
    Automatic parameter estimation (recommended for most cases)::

        >>> energies = particle_mesh_ewald(
        ...     positions, charges, cell,
        ...     neighbor_list=nl, neighbor_shifts=shifts,
        ...     neighbor_ptr=nptr, accuracy=1e-6,
        ... )
        >>> total_energy = energies.sum()

    Explicit parameters for reproducibility::

        >>> energies, forces = particle_mesh_ewald(
        ...     positions, charges, cell,
        ...     alpha=0.3, mesh_dimensions=(32, 32, 32),
        ...     spline_order=4, neighbor_list=nl,
        ...     neighbor_shifts=shifts, neighbor_ptr=nptr,
        ...     compute_forces=True,
        ... )

    Using mesh spacing for automatic mesh sizing::

        >>> energies, forces = particle_mesh_ewald(
        ...     positions, charges, cell,
        ...     alpha=0.3, mesh_spacing=1.0,  # ~1 Å spacing
        ...     neighbor_list=nl, neighbor_shifts=shifts,
        ...     neighbor_ptr=nptr, compute_forces=True,
        ... )

    Batched systems (multiple independent structures)::

        >>> # positions: concatenated atoms from all systems
        >>> # batch_idx: [0,0,0,0, 1,1,1,1, 2,2,2,2] for 4 atoms × 3 systems
        >>> energies, forces = particle_mesh_ewald(
        ...     positions, charges, cells,  # cells shape (3, 3, 3)
        ...     alpha=torch.tensor([0.3, 0.35, 0.3]),
        ...     batch_idx=batch_idx,
        ...     mesh_dimensions=(32, 32, 32),
        ...     neighbor_list=nl,
        ...     neighbor_shifts=shifts, neighbor_ptr=nptr,
        ...     compute_forces=True,
        ... )

    Precomputed k-vectors for MD loop (fixed cell)::

        >>> from nvalchemiops.torch.interactions.electrostatics import generate_k_vectors_pme
        >>> mesh_dims = (32, 32, 32)
        >>> k_vectors, k_squared = generate_k_vectors_pme(cell, mesh_dims)
        >>> for step in range(num_steps):
        ...     energies, forces = particle_mesh_ewald(
        ...         positions, charges, cell,
        ...         alpha=0.3, mesh_dimensions=mesh_dims,
        ...         k_vectors=k_vectors, k_squared=k_squared,
        ...         neighbor_list=nl, neighbor_shifts=shifts,
        ...         neighbor_ptr=nptr,
        ...         compute_forces=True,
        ...     )

    With charge gradients for ML training::

        >>> energies, forces, charge_grads = particle_mesh_ewald(
        ...     positions, charges, cell,
        ...     alpha=0.3, mesh_dimensions=(32, 32, 32),
        ...     neighbor_list=nl, neighbor_shifts=shifts,
        ...     neighbor_ptr=nptr,
        ...     compute_forces=True, compute_charge_gradients=True,
        ... )
        >>> # Use charge_grads for training on ∂E/∂q

    Using PyTorch autograd::

        >>> positions.requires_grad_(True)
        >>> energies = particle_mesh_ewald(
        ...     positions, charges, cell,
        ...     alpha=0.3, mesh_dimensions=(32, 32, 32),
        ...     neighbor_list=nl, neighbor_shifts=shifts,
        ...     neighbor_ptr=nptr,
        ... )
        >>> total_energy = energies.sum()
        >>> total_energy.backward()
        >>> autograd_forces = -positions.grad  # Should match explicit forces

    Notes
    -----
    Automatic Parameter Estimation (when alpha is None):
        Uses Kolafa-Perram formula:

        .. math::

            \\begin{aligned}
            \\eta &= \\frac{(V^2 / N)^{1/6}}{\\sqrt{2\\pi}} \\\\
            \\alpha &= \\frac{1}{2\\eta}
            \\end{aligned}

        Mesh dimensions (when mesh_dimensions is None):

        .. math::

            n_x = \\left\\lceil \\frac{2 \\alpha L_x}{3 \\varepsilon^{1/5}} \\right\\rceil

    Autograd Support:
        All inputs (positions, charges, cell) support gradient computation.

    See Also
    --------
    pme_reciprocal_space : Reciprocal-space component only
    ewald_real_space : Real-space component (used internally)
    estimate_pme_parameters : Automatic parameter estimation
    PMEParameters : Container for PME parameters
    """
    num_atoms = positions.shape[0]

    # Prepare cell
    cell, num_systems = _prepare_cell(cell)

    # Estimate parameters if not provided
    if alpha is None:
        params = estimate_pme_parameters(positions, cell, batch_idx, accuracy)
        alpha = params.alpha
        if mesh_dimensions is None and mesh_spacing is None:
            mesh_dimensions = tuple(params.mesh_dimensions)  # Ensure a plain tuple

    # Prepare alpha tensor
    alpha = _prepare_alpha(alpha, num_systems, positions.dtype, positions.device)

    if mask_value is None:
        mask_value = num_atoms

    # Determine mesh dimensions
    if mesh_dimensions is None:
        if mesh_spacing is not None:
            mesh_dimensions = mesh_spacing_to_dimensions(cell, mesh_spacing)
        else:
            # Use accuracy-based estimation
            mesh_dimensions = estimate_pme_mesh_dimensions(cell, alpha, accuracy)

    # Compute real-space contribution
    rs = ewald_real_space(
        positions=positions,
        charges=charges,
        cell=cell,
        alpha=alpha,
        neighbor_list=neighbor_list,
        neighbor_ptr=neighbor_ptr,
        neighbor_shifts=neighbor_shifts,
        neighbor_matrix=neighbor_matrix,
        neighbor_matrix_shifts=neighbor_matrix_shifts,
        mask_value=mask_value,
        batch_idx=batch_idx,
        compute_forces=compute_forces,
        compute_charge_gradients=compute_charge_gradients,
        compute_virial=compute_virial,
    )

    # Compute reciprocal-space contribution
    rec = pme_reciprocal_space(
        positions=positions,
        charges=charges,
        cell=cell,
        alpha=alpha,
        mesh_dimensions=mesh_dimensions,
        spline_order=spline_order,
        batch_idx=batch_idx,
        compute_forces=compute_forces,
        compute_charge_gradients=compute_charge_gradients,
        compute_virial=compute_virial,
        k_vectors=k_vectors,
        k_squared=k_squared,
    )

    # Normalize return tuples for easy combination
    # Both rs and rec return: energies, [forces], [charge_grads], [virial]
    # where virial is always last if present
    rs_tuple = rs if isinstance(rs, tuple) else (rs,)
    rec_tuple = rec if isinstance(rec, tuple) else (rec,)

    # Both calls receive the same flags, so their outputs line up element-wise
    results = []
    for rs_out, rec_out in zip(rs_tuple, rec_tuple):
        results.append(rs_out + rec_out)

    if len(results) == 1:
        return results[0]
    return tuple(results)
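The automatic-parameter formulas quoted in the Notes of `particle_mesh_ewald` can be evaluated by hand. The sketch below transcribes them literally for an orthorhombic cell (an assumption made here to get V from three lengths); `estimate_alpha_and_mesh` is a hypothetical helper, and the library's `estimate_pme_parameters` may add details not stated in the docstring.

```python
import math


def estimate_alpha_and_mesh(cell_lengths, num_atoms, accuracy=1e-6):
    """Evaluate the documented Kolafa-Perram alpha and mesh-size formulas.

    Assumes an orthorhombic cell, so V = Lx * Ly * Lz.
    """
    lx, ly, lz = cell_lengths
    volume = lx * ly * lz
    eta = (volume**2 / num_atoms) ** (1.0 / 6.0) / math.sqrt(2.0 * math.pi)
    alpha = 1.0 / (2.0 * eta)
    # n_x = ceil(2 * alpha * L_x / (3 * accuracy**(1/5))), and likewise for y, z
    mesh = tuple(
        math.ceil(2.0 * alpha * length / (3.0 * accuracy ** (1.0 / 5.0)))
        for length in cell_lengths
    )
    return alpha, mesh
```

For a 10 × 10 × 10 box with 100 atoms and the default accuracy of 1e-6, this gives α ≈ 0.27 and a 29 × 29 × 29 mesh; tightening the accuracy only ever enlarges the mesh, because ε^{1/5} shrinks.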


__all__ = [
    # Public APIs
    "particle_mesh_ewald",
    "pme_reciprocal_space",
    "pme_green_structure_factor",
    "pme_energy_corrections",
    "pme_energy_corrections_with_charge_grad",
]