# Source code for nvalchemi.models.ewald

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ewald summation electrostatics model wrapper.

Wraps the ``nvalchemiops`` Ewald summation interaction (real-space +
reciprocal-space) as a :class:`~nvalchemi.models.base.BaseModelMixin`-compatible
model, ready to drop into any :class:`~nvalchemi.dynamics.base.BaseDynamics`
engine.

Usage
-----
::

    from nvalchemi.models.ewald import EwaldModelWrapper
    from nvalchemi.hooks import NeighborListHook
    from nvalchemi.dynamics.base import DynamicsStage

    model = EwaldModelWrapper(cutoff=10.0)

    nl_hook = NeighborListHook(model.model_config.neighbor_config, stage=DynamicsStage.BEFORE_COMPUTE)
    dynamics.register_hook(nl_hook)
    dynamics.model = model

Notes
-----
* Forces are computed **analytically** inside the Warp kernel using
  ``hybrid_forces=True``.  Direct kernel forces represent ``dE/dR|_q``
  (derivative at fixed charges).  ``"forces"`` is in ``autograd_outputs``
  so that the pipeline can add the charge chain-rule term
  ``(dE/dq)(dq/dR)`` via autograd on the energy.
* Energy supports ``backward()`` through the charge pathway: when
  ``charges.requires_grad``, the kernel injects analytical ``dE/dq``
  into the energy tensor via ``_InjectChargeGrad``.
* Virial/stress is also computed analytically by the kernel and returned
  detached (no ``grad_fn``), representing ``dE/d(strain)|_q``.  In a
  pipeline with geometry-dependent charges, the total stress is the sum
  of the direct kernel virial and the autograd chain-rule term
  ``(dE/dq)(dq/d(strain))``.
* Periodic boundary conditions are **required** (``needs_pbc=True``).
* Input charges are read from ``data.charges`` (shape ``[N]``).
* The Coulomb constant defaults to ``14.3996`` eV·Å/e², which gives energies
  in eV when positions are in Å and charges are in elementary charge units.
* k-vectors and Ewald parameters are cached per unique unit cell.  Call
  :meth:`invalidate_cache` to force recomputation.
"""

from __future__ import annotations

from collections import OrderedDict
from pathlib import Path
from typing import Any

import torch
from torch import nn

from nvalchemi._typing import ModelOutputs
from nvalchemi.data import AtomicData, Batch
from nvalchemi.models.base import (
    BaseModelMixin,
    ModelConfig,
    NeighborConfig,
    NeighborListFormat,
)

__all__ = ["EwaldModelWrapper"]


class EwaldModelWrapper(nn.Module, BaseModelMixin):
    """Ewald summation electrostatics potential as a model wrapper.

    Computes long-range Coulomb interactions via the Ewald method, splitting
    contributions into real-space (erfc-damped, handled by a neighbor matrix)
    and reciprocal-space (structure factor summation) components.

    Parameters
    ----------
    cutoff : float
        Real-space interaction cutoff in Å.
    accuracy : float, optional
        Target accuracy for automatic Ewald parameter estimation.
        Defaults to ``1e-6``.
    coulomb_constant : float, optional
        Coulomb prefactor :math:`k_e` in eV·Å/e². Defaults to ``14.3996``
        (standard value for Å/e/eV unit system).
    hybrid_forces : bool, optional
        When ``True`` (default), forces and virial are computed analytically
        inside the kernel (``dE/dR`` at fixed charges) and the charge
        chain-rule term is left to autograd on the energy.

    Attributes
    ----------
    model_config : ModelConfig
        Mutable configuration controlling which outputs are computed.
        ``model_config.autograd_outputs`` includes ``"forces"`` so the
        pipeline accumulates direct kernel forces with charge-path autograd
        forces in hybrid mode. Include ``"stress"`` in
        ``model_config.active_outputs`` to enable virial computation for
        NPT/NPH simulations. When ``charges.requires_grad=True``,
        ``energy.backward()`` propagates through the injected :math:`dE/dq`
        pathway while the wrapper returns detached direct kernel forces and
        detached virial/stress.
    """

    def __init__(
        self,
        cutoff: float,
        accuracy: float = 1e-6,
        coulomb_constant: float = 14.3996,
        hybrid_forces: bool = True,
    ) -> None:
        super().__init__()
        self.cutoff = cutoff
        self.accuracy = accuracy
        self.coulomb_constant = coulomb_constant
        self.hybrid_forces = hybrid_forces
        self.model_config = ModelConfig(
            outputs=frozenset({"energy", "forces", "stress"}),
            active_outputs={"energy", "forces"},
            # In hybrid mode only "forces" needs the autograd charge-path term;
            # otherwise both forces and stress come from autograd.
            autograd_outputs=frozenset({"forces"})
            if hybrid_forces
            else frozenset({"forces", "stress"}),
            autograd_inputs=frozenset({"positions"}),
            required_inputs=frozenset({"charges"}),
            optional_inputs=frozenset(),
            supports_pbc=True,
            needs_pbc=True,
            neighbor_config=NeighborConfig(
                cutoff=self.cutoff,
                format=NeighborListFormat.MATRIX,
                half_list=False,
            ),
        )
        # k-vector / parameter cache.
        # Invalidated automatically when cell changes, or manually via
        # invalidate_cache().
        self._cache_valid: bool = False
        self._cached_alpha: torch.Tensor | None = None
        self._cached_k_vectors: torch.Tensor | None = None
        # Cached cell for automatic invalidation detection (e.g. NPT).
        self._cached_cell: torch.Tensor | None = None
        # Pre-allocated energy accumulation buffer (shape [B]).
        self._energies_buf: torch.Tensor | None = None
        # Cached all-zero neighbor-shifts for non-PBC runs (shape [N, K, 3] int32).
        self._null_shifts: torch.Tensor | None = None
        self._null_shifts_shape: tuple[int, int] = (0, 0)

    # ------------------------------------------------------------------
    # BaseModelMixin required properties
    # ------------------------------------------------------------------
    @property
    def embedding_shapes(self) -> dict[str, tuple[int, ...]]:
        # Electrostatics produce no learned embeddings.
        return {}

    def compute_embeddings(
        self, data: AtomicData | Batch, **kwargs: Any
    ) -> AtomicData | Batch:
        """Compute embeddings is not meaningful for Ewald models."""
        raise NotImplementedError("EwaldModelWrapper does not produce embeddings.")

    def direct_derivative_keys(self) -> set[str]:
        """Analytical force/stress keys when ``hybrid_forces=True``."""
        if not self.hybrid_forces:
            return set()
        keys: set[str] = set()
        if "forces" in self.model_config.outputs:
            keys.add("forces")
        if "stress" in self.model_config.outputs:
            keys.add("stress")
        return keys

    # ------------------------------------------------------------------
    # Input / output key declarations
    # ------------------------------------------------------------------
    def input_data(self) -> set[str]:
        """Return required input keys (override to drop ``atomic_numbers``)."""
        return {"positions", "charges", "neighbor_matrix", "num_neighbors"}

    # ------------------------------------------------------------------
    # Cache management
    # ------------------------------------------------------------------
    def _cache_is_stale(self) -> bool:
        """Return ``True`` when cached Ewald parameters need recomputation.

        The cache is marked stale by :meth:`invalidate_cache`. Callers that
        modify the unit cell (e.g. NPT integrators) must call
        ``invalidate_cache()`` so that k-vectors and alpha are recomputed on
        the next forward pass.
        """
        return not self._cache_valid

    def invalidate_cache(self) -> None:
        """Force recomputation of Ewald parameters and k-vectors."""
        self._cache_valid = False
        self._cached_alpha = None
        self._cached_k_vectors = None
        # Note: _cached_cell is intentionally NOT cleared here.
        # The cell reference is used for change-detection in forward(); clearing
        # it would cause every call to look like a cell change and invalidate
        # the cache again immediately after recomputation.

    def _update_cache(
        self,
        positions: torch.Tensor,
        cell: torch.Tensor,
        batch_idx: torch.Tensor,
    ) -> None:
        """Recompute Ewald parameters and k-vectors for the given cell."""
        from nvalchemiops.torch.interactions.electrostatics.k_vectors import (  # lazy
            generate_k_vectors_ewald_summation,
        )
        from nvalchemiops.torch.interactions.electrostatics.parameters import (  # lazy
            estimate_ewald_parameters,
        )

        params = estimate_ewald_parameters(
            positions, cell, batch_idx=batch_idx, accuracy=self.accuracy
        )
        k_vectors = generate_k_vectors_ewald_summation(
            cell, params.reciprocal_space_cutoff
        )
        self._cache_valid = True
        self._cached_alpha = params.alpha
        self._cached_k_vectors = k_vectors

    # ------------------------------------------------------------------
    # Input adaptation
    # ------------------------------------------------------------------
    def adapt_input(self, data: AtomicData | Batch, **kwargs: Any) -> dict[str, Any]:
        """Collect required inputs from *data* without enabling gradients."""
        if not isinstance(data, Batch):
            raise TypeError(
                "EwaldModelWrapper requires a Batch input; "
                "got AtomicData. Use Batch.from_data_list([data]) to wrap it."
            )
        input_dict: dict[str, Any] = {}
        for key in self.input_data():
            value = getattr(data, key, None)
            if value is None:
                raise KeyError(f"'{key}' required but not found in input data.")
            input_dict[key] = value
        input_dict["batch_idx"] = data.batch_idx.to(torch.int32)
        input_dict["ptr"] = data.batch_ptr.to(torch.int32)
        input_dict["num_graphs"] = data.num_graphs
        input_dict["fill_value"] = data.num_nodes
        # PBC cell (required for Ewald).
        try:
            input_dict["cell"] = data.cell  # (B, 3, 3)
        except AttributeError as err:
            # Chain the original AttributeError so the traceback shows why
            # the cell lookup failed.
            raise ValueError(
                "EwaldModelWrapper requires periodic boundary conditions "
                "(data.cell must be present)."
            ) from err
        # neighbor_matrix and num_neighbors are already collected by the
        # input_data() loop above. In a pipeline, the pipeline adapts them
        # to this model's cutoff/format before calling forward().
        input_dict["neighbor_matrix_shifts"] = getattr(
            data, "neighbor_matrix_shifts", None
        )
        return input_dict

    # ------------------------------------------------------------------
    # Output adaptation
    # ------------------------------------------------------------------
    def adapt_output(self, model_output: Any, data: AtomicData | Batch) -> ModelOutputs:
        """Adapt the model output to the framework output format."""
        output: ModelOutputs = OrderedDict()
        output["energy"] = model_output["energy"]
        if "forces" in self.model_config.active_outputs:
            output["forces"] = model_output["forces"]
        if "stress" in self.model_config.active_outputs:
            if "stress" in model_output:
                output["stress"] = model_output["stress"]
            else:
                raise RuntimeError(
                    "'stress' is in active_outputs but missing from model output"
                )
        return output

    def output_data(self) -> set[str]:
        """Return the set of keys that the model produces."""
        keys: set[str] = {"energy"}
        if "forces" in self.model_config.active_outputs:
            keys.add("forces")
        if "stress" in self.model_config.active_outputs:
            keys.add("stress")
        return keys

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------
    def forward(self, data: AtomicData | Batch, **kwargs: Any) -> ModelOutputs:
        """Run the Ewald summation and return a :class:`ModelOutputs` dict.

        Parameters
        ----------
        data : Batch
            Batch containing ``positions``, ``charges``, ``cell``,
            ``neighbor_matrix``, and ``num_neighbors`` (populated by
            :class:`~nvalchemi.hooks.NeighborListHook`).

        Returns
        -------
        ModelOutputs
            OrderedDict with keys ``"energy"`` (shape ``[B, 1]``, eV),
            ``"forces"`` (shape ``[N, 3]``, eV/Å), and optionally
            ``"stress"`` (shape ``[B, 3, 3]``, eV/ų — Cauchy stress ``W/V``).
        """
        from nvalchemiops.torch.interactions.electrostatics.ewald import (  # lazy
            ewald_real_space,
            ewald_reciprocal_space,
        )

        inp = self.adapt_input(data, **kwargs)
        positions = inp["positions"]  # (N, 3)
        charges = inp["charges"].view(-1)  # (N,)
        cell = inp["cell"]  # (B, 3, 3)
        batch_idx = inp["batch_idx"]  # (N,) int32
        fill_value: int = inp["fill_value"]
        B: int = inp["num_graphs"]
        neighbor_matrix = inp["neighbor_matrix"].contiguous()
        neighbor_matrix_shifts = inp.get("neighbor_matrix_shifts")
        compute_forces = "forces" in self.model_config.active_outputs
        compute_stresses = "stress" in self.model_config.active_outputs

        # hybrid_forces=True: the kernel detaches positions and cell
        # internally and computes analytical forces/virial without a Warp
        # tape. Detach here too so that nvalchemiops' backward registration
        # (_register_runtime_state) does not expect a tape when inputs have
        # requires_grad=True (e.g. from prepare_strain in a pipeline).
        if self.hybrid_forces:
            positions = positions.detach()
            cell = cell.detach()

        # Automatically invalidate cache when cell changes (e.g. NPT simulation).
        if self._cached_cell is None or not torch.allclose(
            cell, self._cached_cell, rtol=1e-6, atol=1e-9
        ):
            self._cached_cell = cell.detach().clone()
            self._cache_valid = False

        # Update cached parameters if invalidated.
        if self._cache_is_stale():
            self._update_cache(positions, cell, batch_idx)
        alpha = self._cached_alpha  # (B,)
        k_vectors = self._cached_k_vectors

        # Prepare neighbor_matrix_shifts: reuse cached zero buffer for
        # non-PBC runs.
        if neighbor_matrix_shifts is None:
            K = neighbor_matrix.shape[1]
            N = positions.shape[0]
            if (
                self._null_shifts is None
                or self._null_shifts_shape != (N, K)
                or self._null_shifts.device != positions.device
            ):
                self._null_shifts = torch.zeros(
                    N, K, 3, dtype=torch.int32, device=positions.device
                )
                self._null_shifts_shape = (N, K)
            neighbor_matrix_shifts = self._null_shifts

        # --- Real-space contribution ---
        real_result = ewald_real_space(
            positions=positions,
            charges=charges,
            cell=cell,
            alpha=alpha,
            neighbor_matrix=neighbor_matrix,
            neighbor_matrix_shifts=neighbor_matrix_shifts.contiguous(),
            mask_value=fill_value,
            batch_idx=batch_idx,
            compute_forces=compute_forces,
            compute_virial=compute_stresses,
            hybrid_forces=self.hybrid_forces,
        )

        # --- Reciprocal-space contribution ---
        recip_result = ewald_reciprocal_space(
            positions=positions,
            charges=charges,
            cell=cell,
            k_vectors=k_vectors,
            alpha=alpha,
            batch_idx=batch_idx,
            compute_forces=compute_forces,
            compute_virial=compute_stresses,
            hybrid_forces=self.hybrid_forces,
        )

        # Unpack results (energies always first; forces and virial follow
        # in the order they were requested).
        def _unpack(result, compute_f: bool, compute_v: bool):
            """Extract (energies, forces_or_None, virial_or_None) from result."""
            if isinstance(result, torch.Tensor):
                return result, None, None
            result_list = list(result)
            e = result_list[0]
            f: torch.Tensor | None = None
            v: torch.Tensor | None = None
            idx = 1
            if compute_f and idx < len(result_list):
                f = result_list[idx]
                idx += 1
            if compute_v and idx < len(result_list):
                v = result_list[idx]
            return e, f, v

        e_real, f_real, v_real = _unpack(real_result, compute_forces, compute_stresses)
        e_recip, f_recip, v_recip = _unpack(
            recip_result, compute_forces, compute_stresses
        )

        # Sum real + reciprocal contributions.
        per_atom_energies = (e_real + e_recip).to(
            positions.dtype
        )  # (N,) float64->dtype
        forces: torch.Tensor | None = None
        if compute_forces and f_real is not None and f_recip is not None:
            forces = f_real + f_recip  # (N, 3)
        virial: torch.Tensor | None = None
        if compute_stresses and v_real is not None and v_recip is not None:
            virial = v_real + v_recip  # (B, 3, 3)

        # Scale by Coulomb constant.
        per_atom_energies = per_atom_energies * self.coulomb_constant
        if forces is not None:
            forces = forces * self.coulomb_constant
        if virial is not None:
            virial = virial * self.coulomb_constant

        # Scatter per-atom energies -> per-system totals using pre-allocated buffer.
        if (
            self._energies_buf is None
            or self._energies_buf.shape[0] != B
            or self._energies_buf.dtype != positions.dtype
            or self._energies_buf.device != positions.device
        ):
            self._energies_buf = torch.empty(
                B, dtype=positions.dtype, device=positions.device
            )
        self._energies_buf.zero_()
        # scatter_add_ requires an int64 index; batch_idx is int32 for the
        # Warp kernels, so widen it here.
        self._energies_buf.scatter_add_(0, batch_idx.long(), per_atom_energies)

        # Clone from pre-allocated buffer so the caller receives an independent
        # tensor. Without cloning, the next forward pass would overwrite this
        # tensor in-place.
        model_output: dict[str, Any] = {
            "energy": self._energies_buf.unsqueeze(-1).clone()
        }
        if forces is not None:
            model_output["forces"] = forces
        if virial is not None:
            # Cauchy stress sigma = W/V (eV/A^3).
            volume = torch.det(data.cell).abs().view(-1, 1, 1)
            model_output["stress"] = virial / volume
        elif compute_stresses:
            raise RuntimeError(
                "stress was requested but the kernel did not return a virial"
            )

        return self.adapt_output(model_output, data)

    def export_model(self, path: Path, as_state_dict: bool = False) -> None:
        """Export model is not implemented for Ewald models."""
        raise NotImplementedError