# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Particle Mesh Ewald (PME) electrostatics model wrapper.
Wraps the ``nvalchemiops`` PME interaction as a
:class:`~nvalchemi.models.base.BaseModelMixin`-compatible model, ready to
drop into any :class:`~nvalchemi.dynamics.base.BaseDynamics` engine.
Usage
-----
::
from nvalchemi.models.pme import PMEModelWrapper
from nvalchemi.dynamics.hooks import NeighborListHook
model = PMEModelWrapper(cutoff=10.0)
nl_hook = NeighborListHook(model.model_card.neighbor_config)
dynamics.register_hook(nl_hook)
dynamics.model = model
Notes
-----
* Forces are computed **analytically** inside the Warp kernel (not via
autograd), so :attr:`~ModelCard.forces_via_autograd` is ``False``.
* Periodic boundary conditions are **required** (``needs_pbc=True``).
* Input charges are read from ``data.node_charges`` (shape ``[N]``).
* The Coulomb constant defaults to ``14.3996`` eV·Å/e², which gives energies
in eV when positions are in Å and charges are in elementary charge units.
* PME achieves :math:`O(N \\log N)` scaling via FFT-based reciprocal space
calculations, making it more efficient than Ewald for large systems.
* Mesh k-vectors and Ewald parameters are cached per unique unit cell. Call
:meth:`invalidate_cache` to force recomputation.
"""
from __future__ import annotations
from collections import OrderedDict
from pathlib import Path
from typing import Any
import torch
from torch import nn
from nvalchemi._typing import ModelOutputs
from nvalchemi.data import AtomicData, Batch
from nvalchemi.models._ops.neighbor_filter import prepare_neighbors_for_model
from nvalchemi.models.base import (
BaseModelMixin,
ModelCard,
ModelConfig,
NeighborConfig,
NeighborListFormat,
)
__all__ = ["PMEModelWrapper"]
[docs]
class PMEModelWrapper(nn.Module, BaseModelMixin):
"""Particle Mesh Ewald electrostatics potential as a model wrapper.
Computes long-range Coulomb interactions via the PME method using
B-spline charge interpolation and FFT-based reciprocal space evaluation,
achieving :math:`O(N \\log N)` computational scaling.
Parameters
----------
cutoff : float
Real-space interaction cutoff in Å.
mesh_spacing : float, optional
Target mesh spacing in Å used to determine mesh dimensions when
``mesh_dimensions`` is not provided. Defaults to ``1.0`` Å.
mesh_dimensions : tuple[int, int, int] or None, optional
Explicit mesh dimensions ``(nx, ny, nz)``. When set, overrides
automatic estimation from ``mesh_spacing``. Defaults to ``None``
(auto-estimated).
spline_order : int, optional
B-spline interpolation order. Higher values give greater accuracy
at increased cost. Defaults to ``4``.
alpha : float or None, optional
Ewald splitting parameter (inverse Å). ``None`` causes automatic
estimation via the Kolafa-Perram formula each time the cell changes.
Defaults to ``None``.
accuracy : float, optional
Target accuracy for automatic parameter estimation. Defaults to
``1e-6``.
coulomb_constant : float, optional
Coulomb prefactor :math:`k_e` in eV·Å/e².
Defaults to ``14.3996`` (standard value for Å/e/eV unit system).
max_neighbors : int, optional
Maximum neighbors per atom for the dense neighbor matrix.
Defaults to 256.
Attributes
----------
model_config : ModelConfig
Mutable configuration controlling which outputs are computed.
Set ``model.model_config.compute_stresses = True`` to enable
virial computation for NPT/NPH simulations.
"""
def __init__(
self,
cutoff: float,
mesh_spacing: float = 1.0,
mesh_dimensions: tuple[int, int, int] | None = None,
spline_order: int = 4,
alpha: float | None = None,
accuracy: float = 1e-6,
coulomb_constant: float = 14.3996,
max_neighbors: int = 256,
) -> None:
super().__init__()
self.cutoff = cutoff
self.mesh_spacing = mesh_spacing
self.mesh_dimensions = mesh_dimensions
self.spline_order = spline_order
self.alpha = alpha
self.accuracy = accuracy
self.coulomb_constant = coulomb_constant
self.max_neighbors = max_neighbors
self.model_config = ModelConfig()
self._model_card: ModelCard = self._build_model_card()
# PME k-vector / parameter cache.
# Automatically invalidated when cell changes, or manually via invalidate_cache().
self._cache_valid: bool = False
self._cached_alpha: torch.Tensor | None = None
self._cached_k_vectors: torch.Tensor | None = None
self._cached_k_squared: torch.Tensor | None = None
self._cached_mesh_dims: tuple[int, int, int] | None = None
# Cached cell for automatic invalidation detection (e.g. NPT).
self._cached_cell: torch.Tensor | None = None
# Pre-allocated energy accumulation buffer (shape [B]).
self._energies_buf: torch.Tensor | None = None
# Cached all-zero neighbor-shifts for non-PBC runs (shape [N, K, 3] int32).
self._null_shifts: torch.Tensor | None = None
self._null_shifts_shape: tuple[int, int] = (0, 0)
# ------------------------------------------------------------------
# BaseModelMixin required properties
# ------------------------------------------------------------------
def _build_model_card(self) -> ModelCard:
return ModelCard(
forces_via_autograd=False,
supports_energies=True,
supports_forces=True,
supports_stresses=True,
supports_pbc=True,
needs_pbc=True,
supports_non_batch=False,
needs_node_charges=True,
neighbor_config=NeighborConfig(
cutoff=self.cutoff,
format=NeighborListFormat.MATRIX,
half_list=False,
max_neighbors=self.max_neighbors,
),
)
@property
def model_card(self) -> ModelCard:
return self._model_card
@property
def embedding_shapes(self) -> dict[str, tuple[int, ...]]:
return {}
[docs]
def compute_embeddings(
self, data: AtomicData | Batch, **kwargs: Any
) -> AtomicData | Batch:
raise NotImplementedError("PMEModelWrapper does not produce embeddings.")
# ------------------------------------------------------------------
# Input / output key declarations
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# Cache management
# ------------------------------------------------------------------
def _cache_is_stale(self) -> bool:
"""Return ``True`` when cached PME parameters need recomputation.
The cache is marked stale by :meth:`invalidate_cache`. Callers that
modify the unit cell (e.g. NPT integrators) must call
``invalidate_cache()`` so that k-vectors and alpha are recomputed on
the next forward pass.
"""
return not self._cache_valid
[docs]
def invalidate_cache(self) -> None:
"""Force recomputation of PME parameters, k-vectors, and mesh."""
self._cache_valid = False
self._cached_alpha = None
self._cached_k_vectors = None
self._cached_k_squared = None
self._cached_mesh_dims = None
# Note: _cached_cell is intentionally NOT cleared here.
# The cell reference is used for change-detection in forward(); clearing
# it would cause every call to look like a cell change and invalidate
# the cache again immediately after recomputation.
def _update_cache(
self,
positions: torch.Tensor,
cell: torch.Tensor,
batch_idx: torch.Tensor,
) -> None:
"""Recompute PME parameters and k-vectors for the given cell."""
from nvalchemiops.torch.interactions.electrostatics.k_vectors import ( # lazy
generate_k_vectors_pme,
)
from nvalchemiops.torch.interactions.electrostatics.parameters import ( # lazy
estimate_pme_parameters,
)
B = cell.shape[0] if cell.dim() == 3 else 1
# Determine alpha and mesh_dimensions.
need_params = (self.alpha is None) or (self.mesh_dimensions is None)
params = None
if need_params:
params = estimate_pme_parameters(
positions, cell, batch_idx=batch_idx, accuracy=self.accuracy
)
if self.alpha is not None:
alpha_val = float(self.alpha)
else:
# Take the mean across batch systems.
alpha_val = float(params.alpha.mean().item())
if self.mesh_dimensions is not None:
dims = self.mesh_dimensions
else:
dims = params.mesh_dimensions
# Build per-system alpha tensor.
alpha_tensor = torch.full((B,), alpha_val, dtype=cell.dtype, device=cell.device)
k_vectors, k_squared = generate_k_vectors_pme(cell, dims)
self._cache_valid = True
self._cached_alpha = alpha_tensor
self._cached_k_vectors = k_vectors
self._cached_k_squared = k_squared
self._cached_mesh_dims = dims
# ------------------------------------------------------------------
# Input adaptation
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# Output adaptation
# ------------------------------------------------------------------
[docs]
def adapt_output(self, model_output: Any, data: AtomicData | Batch) -> ModelOutputs:
"""Adapt the model output to the framework output format."""
output: ModelOutputs = OrderedDict()
output["energies"] = model_output["energies"]
if self.model_config.compute_forces:
output["forces"] = model_output["forces"]
if self.model_config.compute_stresses:
if "stresses" in model_output:
output["stresses"] = model_output["stresses"]
return output
[docs]
def output_data(self) -> set[str]:
"""Return the set of keys that the model produces."""
keys: set[str] = {"energies"}
if self.model_config.compute_forces:
keys.add("forces")
if self.model_config.compute_stresses:
keys.add("stresses")
return keys
# ------------------------------------------------------------------
# Forward pass
# ------------------------------------------------------------------
[docs]
def forward(self, data: AtomicData | Batch, **kwargs: Any) -> ModelOutputs:
"""Run the PME kernel and return a :class:`ModelOutputs` dict.
Parameters
----------
data : Batch
Batch containing ``positions``, ``node_charges``, ``cell``,
``neighbor_matrix``, and ``num_neighbors`` (populated by
:class:`~nvalchemi.dynamics.hooks.NeighborListHook`).
Returns
-------
ModelOutputs
OrderedDict with keys ``"energies"`` (shape ``[B, 1]``, eV),
``"forces"`` (shape ``[N, 3]``, eV/Å), and optionally
``"stress"`` (shape ``[B, 3, 3]``, eV — the raw virial
:math:`W_{phys}`).
"""
from nvalchemiops.torch.interactions.electrostatics.pme import ( # lazy
particle_mesh_ewald,
)
inp = self.adapt_input(data, **kwargs)
positions = inp["positions"] # (N, 3)
charges = inp["node_charges"] # (N,)
cell = inp["cell"] # (B, 3, 3)
batch_idx = inp["batch_idx"] # (N,) int32
fill_value: int = inp["fill_value"]
B: int = inp["num_graphs"]
neighbor_matrix = inp["neighbor_matrix"].contiguous()
neighbor_shifts = inp.get("neighbor_shifts")
compute_forces = self.model_config.compute_forces
compute_stresses = self.model_config.compute_stresses
# Automatically invalidate cache when cell changes (e.g. NPT simulation).
if self._cached_cell is None or not torch.allclose(
cell, self._cached_cell, rtol=1e-6, atol=1e-9
):
self._cached_cell = cell.detach().clone()
self._cache_valid = False
# Warn when using a single mean α for heterogeneous batch cell volumes.
if self.alpha is None and data.num_graphs > 1:
vols = torch.linalg.det(cell).abs()
if vols.min() > 0 and (vols.max() / vols.min()) > 1.1:
import warnings as _warnings
_warnings.warn(
"PMEModelWrapper: using a single mean α for a batch of systems with "
"heterogeneous cell volumes (max/min volume ratio > 1.1). "
"This may introduce systematic errors in the Ewald real/reciprocal "
"balance. For accurate results, use a homogeneous batch.",
UserWarning,
stacklevel=2,
)
# Update cache if invalidated.
if self._cache_is_stale():
self._update_cache(positions, cell, batch_idx)
# Prepare neighbor_matrix_shifts: reuse cached zero buffer for non-PBC runs.
if neighbor_shifts is None:
K = neighbor_matrix.shape[1]
N = positions.shape[0]
if (
self._null_shifts is None
or self._null_shifts_shape != (N, K)
or self._null_shifts.device != positions.device
):
self._null_shifts = torch.zeros(
N, K, 3, dtype=torch.int32, device=positions.device
)
self._null_shifts_shape = (N, K)
neighbor_shifts = self._null_shifts
result = particle_mesh_ewald(
positions=positions,
charges=charges.view(
-1,
),
cell=cell,
alpha=self._cached_alpha,
mesh_dimensions=self._cached_mesh_dims,
spline_order=self.spline_order,
batch_idx=batch_idx,
k_vectors=self._cached_k_vectors,
k_squared=self._cached_k_squared,
neighbor_matrix=neighbor_matrix,
neighbor_matrix_shifts=neighbor_shifts.contiguous(),
mask_value=fill_value,
compute_forces=compute_forces,
compute_virial=compute_stresses,
accuracy=self.accuracy,
)
# Unpack tuple: (energies, [forces], [virial]).
def _unpack(res, compute_f: bool, compute_v: bool):
"""Extract (per_atom_energies, forces_or_None, virial_or_None)."""
if isinstance(res, torch.Tensor):
return res, None, None
res_list = list(res)
e = res_list[0]
f: torch.Tensor | None = None
v: torch.Tensor | None = None
idx = 1
if compute_f and idx < len(res_list):
f = res_list[idx]
idx += 1
if compute_v and idx < len(res_list):
v = res_list[idx]
return e, f, v
per_atom_energies, forces, virial = _unpack(
result, compute_forces, compute_stresses
)
# Scale by Coulomb constant.
per_atom_energies = (
per_atom_energies.to(positions.dtype) * self.coulomb_constant
)
if forces is not None:
forces = forces * self.coulomb_constant
if virial is not None:
virial = virial * self.coulomb_constant
# Scatter per-atom energies → per-system totals using pre-allocated buffer.
if (
self._energies_buf is None
or self._energies_buf.shape[0] != B
or self._energies_buf.dtype != positions.dtype
or self._energies_buf.device != positions.device
):
self._energies_buf = torch.empty(
B, dtype=positions.dtype, device=positions.device
)
self._energies_buf.zero_()
self._energies_buf.scatter_add_(0, batch_idx.long(), per_atom_energies)
# Clone from pre-allocated buffer so the caller receives an independent tensor.
# Without cloning, the next forward pass would overwrite this tensor in-place.
model_output: dict[str, Any] = {
"energies": self._energies_buf.unsqueeze(-1).clone()
}
if forces is not None:
model_output["forces"] = forces
if virial is not None:
# particle_mesh_ewald accumulates W = Σ r_ij ⊗ F_ij (positive convention).
# Store directly as stresses (W_phys) — the barostat divides by V.
model_output["stresses"] = virial
return self.adapt_output(model_output, data)
[docs]
def export_model(self, path: Path, as_state_dict: bool = False) -> None:
"""Export model is not implemented for PME models."""
raise NotImplementedError