Source code for nvalchemiops.jax.interactions.electrostatics.ewald

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JAX Ewald summation implementation.

Wraps the framework-agnostic Warp kernels from
``nvalchemiops.interactions.electrostatics.ewald_kernels`` with JAX bindings.

The Ewald method splits long-range Coulomb interactions into:
    E_total = E_real + E_reciprocal - E_self - E_background
"""

from __future__ import annotations

import math

import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental import jax_kernel

from nvalchemiops.interactions.electrostatics.ewald_kernels import (
    BATCH_BLOCK_SIZE,
    _batch_ewald_real_space_energy_forces_charge_grad_kernel_overload,
    _batch_ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel_overload,
    _batch_ewald_real_space_energy_forces_kernel_overload,
    _batch_ewald_real_space_energy_forces_neighbor_matrix_kernel_overload,
    _batch_ewald_real_space_energy_kernel_overload,
    _batch_ewald_real_space_energy_neighbor_matrix_kernel_overload,
    _batch_ewald_reciprocal_space_energy_forces_charge_grad_kernel_overload,
    _batch_ewald_reciprocal_space_energy_forces_kernel_overload,
    _batch_ewald_reciprocal_space_energy_kernel_compute_energy_overload,
    _batch_ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload,
    _batch_ewald_reciprocal_space_virial_kernel_overload,
    _batch_ewald_subtract_self_energy_kernel_overload,
    _ewald_real_space_energy_forces_charge_grad_kernel_overload,
    _ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel_overload,
    _ewald_real_space_energy_forces_kernel_overload,
    _ewald_real_space_energy_forces_neighbor_matrix_kernel_overload,
    _ewald_real_space_energy_kernel_overload,
    _ewald_real_space_energy_neighbor_matrix_kernel_overload,
    _ewald_reciprocal_space_energy_forces_charge_grad_kernel_overload,
    _ewald_reciprocal_space_energy_forces_kernel_overload,
    _ewald_reciprocal_space_energy_kernel_compute_energy_overload,
    _ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload,
    _ewald_reciprocal_space_virial_kernel_overload,
    _ewald_subtract_self_energy_kernel_overload,
)
from nvalchemiops.jax.interactions.electrostatics.k_vectors import (
    generate_k_vectors_ewald_summation,
)
from nvalchemiops.jax.interactions.electrostatics.parameters import (
    estimate_ewald_parameters,
)

# Names exported by ``from <this module> import *``.
__all__ = [
    "ewald_real_space",
    "ewald_reciprocal_space",
    "ewald_summation",
]

# Module-level alias for ``math.pi``.
PI = math.pi

# ==============================================================================
# Helper for Creating JAX Kernel Dictionaries
# ==============================================================================

# Identity table over the two floating-point JAX dtypes handled by this module
# (dtype normalization for kernel lookup).
_DTYPE_MAP = {float_type: float_type for float_type in (jnp.float32, jnp.float64)}


def _normalize_dtype(dtype):
    """Map a dtype-like value onto the JAX float type used as a kernel-table key.

    Parameters
    ----------
    dtype : dtype-like
        Dtype to classify, e.g. the ``dtype`` attribute of a JAX array. A value
        matches when it compares equal to the JAX scalar type or when its
        string form is ``"float32"`` / ``"float64"``.

    Returns
    -------
    jnp.float32 or jnp.float64
        The matching JAX scalar type.

    Raises
    ------
    ValueError
        If ``dtype`` matches neither ``float32`` nor ``float64``.
    """
    supported = ((jnp.float32, "float32"), (jnp.float64, "float64"))
    for jax_type, type_name in supported:
        # Equality is tried first; the string form is the fallback.
        if dtype == jax_type or str(dtype) == type_name:
            return jax_type
    raise ValueError(f"Unsupported dtype: {dtype}")


def _make_jax_kernels(
    wp_overload_dict: dict,
    num_outputs: int,
    in_out_argnames: list[str],
) -> dict:
    """Build one ``jax_kernel`` wrapper per supported floating-point precision.

    Each Warp overload is looked up by its Warp dtype and wrapped with
    ``jax_kernel``; the result is stored under the corresponding JAX dtype.
    All wrappers are created with ``enable_backward=False``.

    Parameters
    ----------
    wp_overload_dict : dict
        Warp kernel overload dictionary keyed by wp.float32/wp.float64.
    num_outputs : int
        Number of output arrays returned by the kernel.
    in_out_argnames : list of str
        Names of in-place output arguments.

    Returns
    -------
    dict
        Dictionary mapping jnp.float32/jnp.float64 to jax_kernel instances.
    """
    # (JAX dtype, Warp dtype) pairs, float32 first.
    dtype_pairs = (
        (jnp.float32, wp.float32),
        (jnp.float64, wp.float64),
    )
    kernels = {}
    for jax_dtype, wp_dtype in dtype_pairs:
        kernels[jax_dtype] = jax_kernel(
            wp_overload_dict[wp_dtype],
            num_outputs=num_outputs,
            in_out_argnames=in_out_argnames,
            enable_backward=False,
        )
    return kernels


# ==============================================================================
# JAX Kernel Wrappers - Real Space
# ==============================================================================

# --- Neighbor List (CSR) Format ---
# Every table below is the dict returned by ``_make_jax_kernels``: it maps
# jnp.float32 / jnp.float64 to the wrapped kernel for that precision.

# Energy only.
_jax_ewald_real_space_energy_list = _make_jax_kernels(
    wp_overload_dict=_ewald_real_space_energy_kernel_overload,
    num_outputs=1,
    in_out_argnames=["pair_energies"],
)

# Energy + forces + virial.
_jax_ewald_real_space_energy_forces_list = _make_jax_kernels(
    wp_overload_dict=_ewald_real_space_energy_forces_kernel_overload,
    num_outputs=3,
    in_out_argnames=["pair_energies", "atomic_forces", "virial"],
)

# Energy + forces + charge gradients + virial.
_jax_ewald_real_space_energy_forces_charge_grad_list = _make_jax_kernels(
    wp_overload_dict=_ewald_real_space_energy_forces_charge_grad_kernel_overload,
    num_outputs=4,
    in_out_argnames=["pair_energies", "atomic_forces", "charge_gradients", "virial"],
)

# ``_batch_*`` counterparts of the three tables above.
_jax_batch_ewald_real_space_energy_list = _make_jax_kernels(
    wp_overload_dict=_batch_ewald_real_space_energy_kernel_overload,
    num_outputs=1,
    in_out_argnames=["pair_energies"],
)

_jax_batch_ewald_real_space_energy_forces_list = _make_jax_kernels(
    wp_overload_dict=_batch_ewald_real_space_energy_forces_kernel_overload,
    num_outputs=3,
    in_out_argnames=["pair_energies", "atomic_forces", "virial"],
)

_jax_batch_ewald_real_space_energy_forces_charge_grad_list = _make_jax_kernels(
    wp_overload_dict=_batch_ewald_real_space_energy_forces_charge_grad_kernel_overload,
    num_outputs=4,
    in_out_argnames=["pair_energies", "atomic_forces", "charge_gradients", "virial"],
)

# --- Neighbor Matrix Format ---
# Same layout as the neighbor-list tables, built from the
# ``*_neighbor_matrix_kernel_overload`` dictionaries.

# Energy only.
_jax_ewald_real_space_energy_matrix = _make_jax_kernels(
    wp_overload_dict=_ewald_real_space_energy_neighbor_matrix_kernel_overload,
    num_outputs=1,
    in_out_argnames=["pair_energies"],
)

# Energy + forces + virial.
_jax_ewald_real_space_energy_forces_matrix = _make_jax_kernels(
    wp_overload_dict=_ewald_real_space_energy_forces_neighbor_matrix_kernel_overload,
    num_outputs=3,
    in_out_argnames=["pair_energies", "atomic_forces", "virial"],
)

# Energy + forces + charge gradients + virial.
_jax_ewald_real_space_energy_forces_charge_grad_matrix = _make_jax_kernels(
    wp_overload_dict=_ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel_overload,
    num_outputs=4,
    in_out_argnames=["pair_energies", "atomic_forces", "charge_gradients", "virial"],
)

# ``_batch_*`` counterparts of the three tables above.
_jax_batch_ewald_real_space_energy_matrix = _make_jax_kernels(
    wp_overload_dict=_batch_ewald_real_space_energy_neighbor_matrix_kernel_overload,
    num_outputs=1,
    in_out_argnames=["pair_energies"],
)

_jax_batch_ewald_real_space_energy_forces_matrix = _make_jax_kernels(
    wp_overload_dict=_batch_ewald_real_space_energy_forces_neighbor_matrix_kernel_overload,
    num_outputs=3,
    in_out_argnames=["pair_energies", "atomic_forces", "virial"],
)

_jax_batch_ewald_real_space_energy_forces_charge_grad_matrix = _make_jax_kernels(
    wp_overload_dict=_batch_ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel_overload,
    num_outputs=4,
    in_out_argnames=["pair_energies", "atomic_forces", "charge_gradients", "virial"],
)

# ==============================================================================
# JAX Kernel Wrappers - Reciprocal Space
# ==============================================================================

# --- Structure Factor Computation ---
# Five in-place outputs per kernel. The two tables differ only in the overload
# dictionary and in the first output name ("total_charge" vs "total_charges").

_jax_ewald_reciprocal_fill_structure_factors = _make_jax_kernels(
    wp_overload_dict=_ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload,
    num_outputs=5,
    in_out_argnames=[
        "total_charge",
        "cos_k_dot_r",
        "sin_k_dot_r",
        "real_structure_factors",
        "imag_structure_factors",
    ],
)

_jax_batch_ewald_reciprocal_fill_structure_factors = _make_jax_kernels(
    wp_overload_dict=_batch_ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload,
    num_outputs=5,
    in_out_argnames=[
        "total_charges",
        "cos_k_dot_r",
        "sin_k_dot_r",
        "real_structure_factors",
        "imag_structure_factors",
    ],
)

# --- Energy Computation ---
# Single in-place output: "reciprocal_energies".

_jax_ewald_reciprocal_compute_energy = _make_jax_kernels(
    wp_overload_dict=_ewald_reciprocal_space_energy_kernel_compute_energy_overload,
    num_outputs=1,
    in_out_argnames=["reciprocal_energies"],
)

_jax_batch_ewald_reciprocal_compute_energy = _make_jax_kernels(
    wp_overload_dict=_batch_ewald_reciprocal_space_energy_kernel_compute_energy_overload,
    num_outputs=1,
    in_out_argnames=["reciprocal_energies"],
)

# --- Energy + Forces ---

_jax_ewald_reciprocal_energy_forces = _make_jax_kernels(
    wp_overload_dict=_ewald_reciprocal_space_energy_forces_kernel_overload,
    num_outputs=2,
    in_out_argnames=["reciprocal_energies", "atomic_forces"],
)

_jax_batch_ewald_reciprocal_energy_forces = _make_jax_kernels(
    wp_overload_dict=_batch_ewald_reciprocal_space_energy_forces_kernel_overload,
    num_outputs=2,
    in_out_argnames=["reciprocal_energies", "atomic_forces"],
)

# --- Energy + Forces + Charge Gradients ---

_jax_ewald_reciprocal_energy_forces_charge_grad = _make_jax_kernels(
    wp_overload_dict=_ewald_reciprocal_space_energy_forces_charge_grad_kernel_overload,
    num_outputs=3,
    in_out_argnames=["reciprocal_energies", "atomic_forces", "charge_gradients"],
)

_jax_batch_ewald_reciprocal_energy_forces_charge_grad = _make_jax_kernels(
    wp_overload_dict=_batch_ewald_reciprocal_space_energy_forces_charge_grad_kernel_overload,
    num_outputs=3,
    in_out_argnames=["reciprocal_energies", "atomic_forces", "charge_gradients"],
)

# --- Self-Energy Correction ---
# Single in-place output: "energy_out".

_jax_ewald_subtract_self_energy = _make_jax_kernels(
    wp_overload_dict=_ewald_subtract_self_energy_kernel_overload,
    num_outputs=1,
    in_out_argnames=["energy_out"],
)

_jax_batch_ewald_subtract_self_energy = _make_jax_kernels(
    wp_overload_dict=_batch_ewald_subtract_self_energy_kernel_overload,
    num_outputs=1,
    in_out_argnames=["energy_out"],
)

# --- Reciprocal-Space Virial ---
# Single in-place output: "virial".

_jax_ewald_reciprocal_virial = _make_jax_kernels(
    wp_overload_dict=_ewald_reciprocal_space_virial_kernel_overload,
    num_outputs=1,
    in_out_argnames=["virial"],
)

_jax_batch_ewald_reciprocal_virial = _make_jax_kernels(
    wp_overload_dict=_batch_ewald_reciprocal_space_virial_kernel_overload,
    num_outputs=1,
    in_out_argnames=["virial"],
)


# ==============================================================================
# Helper Functions
# ==============================================================================


def _prepare_alpha_array(
    alpha: float | jax.Array,
    num_systems: int,
    dtype: jnp.dtype = jnp.float64,
) -> jax.Array:
    """Convert alpha to a per-system array of shape (B,) or (1,).

    Parameters
    ----------
    alpha : float or jax.Array
        Ewald splitting parameter.
    num_systems : int
        Number of systems.
    dtype : jnp.dtype, optional
        Data type for the output array. Defaults to jnp.float64.

    Returns
    -------
    jax.Array
        Alpha array of shape (B,) or (1,).
    """
    if isinstance(alpha, (int, float)):
        return jnp.full(num_systems, float(alpha), dtype=dtype)
    elif isinstance(alpha, jax.Array):
        # generate elements from scalar
        if alpha.ndim == 0:
            return jnp.full(num_systems, alpha[0], dtype=dtype)
        elif len(alpha) != num_systems:
            raise ValueError(
                f"alpha has {alpha.shape[0]} values but there are {num_systems} systems"
            )
        else:
            return alpha.astype(dtype)
    else:
        raise TypeError(f"alpha must be float or jax.Array, got {type(alpha)}")


def _compute_total_charge(
    charges: jax.Array, batch_idx: jax.Array | None, num_systems: int = 1
) -> jax.Array:
    """Compute total charge (per system if batched).

    Parameters
    ----------
    charges : jax.Array, shape (N,)
        Atomic charges.
    batch_idx : jax.Array | None, shape (N,)
        Batch indices.
    num_systems : int, optional
        Number of systems in the batch. Only used when batch_idx is not None.
        Default is 1.

    Returns
    -------
    jax.Array
        Total charge, shape (1,) for single system or (B,) for batch.
    """
    if batch_idx is None:
        return jnp.array([charges.sum()], dtype=jnp.float64)
    else:
        total_charges = jnp.zeros(num_systems, dtype=jnp.float64)
        total_charges = total_charges.at[batch_idx].add(charges)
        return total_charges


# ==============================================================================
# Public API
# ==============================================================================


def ewald_real_space(
    positions: jax.Array,
    charges: jax.Array,
    cell: jax.Array,
    alpha: float | jax.Array,
    neighbor_list: jax.Array | None = None,
    neighbor_ptr: jax.Array | None = None,
    neighbor_shifts: jax.Array | None = None,
    neighbor_matrix: jax.Array | None = None,
    neighbor_matrix_shifts: jax.Array | None = None,
    mask_value: int | None = None,
    batch_idx: jax.Array | None = None,
    compute_forces: bool = False,
    compute_charge_gradients: bool = False,
    compute_virial: bool = False,
) -> jax.Array | tuple[jax.Array, ...]:
    """Compute real-space Ewald energy and optionally forces, charge gradients, and virial.

    Computes the damped Coulomb interactions for atom pairs within the
    real-space cutoff. The complementary error function (erfc) damping ensures
    rapid convergence in real space.

    Exactly one neighbor format must be supplied: either the list format
    (``neighbor_list``, ``neighbor_ptr``, ``neighbor_shifts``) or the matrix
    format (``neighbor_matrix``, ``neighbor_matrix_shifts``).

    Parameters
    ----------
    positions : jax.Array, shape (N, 3)
        Atomic coordinates.
    charges : jax.Array, shape (N,)
        Atomic partial charges.
    cell : jax.Array, shape (1, 3, 3) or (B, 3, 3)
        Unit cell matrices. A bare (3, 3) cell is promoted to (1, 3, 3).
    alpha : float or jax.Array
        Ewald splitting parameter. Can be a float or array of shape (1,) or (B,).
    neighbor_list : jax.Array | None, shape (2, M)
        Neighbor list in COO format. Only row 1 (the neighbor index) is used.
    neighbor_ptr : jax.Array | None, shape (N+1,)
        CSR row pointers for neighbor list.
    neighbor_shifts : jax.Array | None, shape (M, 3)
        Periodic image shifts for neighbor list.
    neighbor_matrix : jax.Array | None, shape (N, max_neighbors)
        Dense neighbor matrix format.
    neighbor_matrix_shifts : jax.Array | None, shape (N, max_neighbors, 3)
        Periodic image shifts for neighbor_matrix.
    mask_value : int | None, optional
        Value indicating invalid entries in neighbor_matrix. If None (default),
        uses num_atoms as the mask value.
    batch_idx : jax.Array | None, shape (N,)
        System index for each atom. When given, the batched kernels are used.
    compute_forces : bool, default=False
        Whether to compute explicit forces.
    compute_charge_gradients : bool, default=False
        Whether to compute charge gradients.
    compute_virial : bool, default=False
        Whether to compute the virial tensor.

    Returns
    -------
    energies : jax.Array, shape (N,)
        Per-atom real-space energy.
    forces : jax.Array, shape (N, 3), optional
        Forces. Returned only if compute_forces=True (they are computed
        internally, but not returned, when only charge gradients or the
        virial are requested).
    charge_gradients : jax.Array, shape (N,), optional
        Charge gradients (if compute_charge_gradients=True).
    virial : jax.Array, shape (1, 3, 3) or (B, 3, 3), optional
        Virial tensor (if compute_virial=True). Always last in the return tuple.

    A bare array is returned when only energies are requested; otherwise a
    tuple in the order listed above.

    Raises
    ------
    ValueError
        If neither or both neighbor formats are provided, or if
        ``neighbor_ptr`` is missing for the list format.
    """
    # Validate inputs: exactly one neighbor format must be fully specified.
    use_list = neighbor_list is not None and neighbor_shifts is not None
    use_matrix = neighbor_matrix is not None and neighbor_matrix_shifts is not None

    if not use_list and not use_matrix:
        raise ValueError(
            "Must provide either neighbor_list/neighbor_shifts or "
            "neighbor_matrix/neighbor_matrix_shifts"
        )
    if use_list and use_matrix:
        raise ValueError(
            "Cannot provide both neighbor list and neighbor matrix formats"
        )

    # Normalized input dtype drives kernel dispatch and the force/virial dtype.
    dtype = _normalize_dtype(positions.dtype)

    positions_cast = positions.astype(dtype)
    charges_cast = charges.astype(dtype)
    cell_cast = cell.astype(dtype)

    # Ensure cell is (B, 3, 3)
    if cell_cast.ndim == 2:
        cell_cast = cell_cast[jnp.newaxis, :, :]

    num_atoms = positions_cast.shape[0]
    is_batched = batch_idx is not None

    # Default mask_value to num_atoms (matches cell_list fill_value convention)
    if mask_value is None:
        mask_value = num_atoms

    alpha_arr = _prepare_alpha_array(alpha, cell_cast.shape[0], dtype=dtype)

    # Energies are always allocated as float64; forces/virial follow ``dtype``.
    energies = jnp.zeros(num_atoms, dtype=jnp.float64)

    # Neighbor-format specific kernel inputs and kernel tables.
    if use_list:
        if neighbor_ptr is None:
            raise ValueError("neighbor_ptr is required when using neighbor_list format")
        neighbor_args = (
            neighbor_list[1].astype(jnp.int32),  # idx_j: neighbor atom indices
            neighbor_ptr.astype(jnp.int32),
            neighbor_shifts.astype(jnp.int32),
        )
        if is_batched:
            energy_kernel = _jax_batch_ewald_real_space_energy_list
            forces_kernel = _jax_batch_ewald_real_space_energy_forces_list
            charge_grad_kernel = (
                _jax_batch_ewald_real_space_energy_forces_charge_grad_list
            )
        else:
            energy_kernel = _jax_ewald_real_space_energy_list
            forces_kernel = _jax_ewald_real_space_energy_forces_list
            charge_grad_kernel = _jax_ewald_real_space_energy_forces_charge_grad_list
    else:
        neighbor_args = (
            neighbor_matrix.astype(jnp.int32),
            neighbor_matrix_shifts.astype(jnp.int32),
            int(mask_value),
        )
        if is_batched:
            energy_kernel = _jax_batch_ewald_real_space_energy_matrix
            forces_kernel = _jax_batch_ewald_real_space_energy_forces_matrix
            charge_grad_kernel = (
                _jax_batch_ewald_real_space_energy_forces_charge_grad_matrix
            )
        else:
            energy_kernel = _jax_ewald_real_space_energy_matrix
            forces_kernel = _jax_ewald_real_space_energy_forces_matrix
            charge_grad_kernel = (
                _jax_ewald_real_space_energy_forces_charge_grad_matrix
            )

    # Batched kernels take the per-atom system index right after the cell.
    system_args = (batch_idx.astype(jnp.int32),) if is_batched else ()
    kernel_inputs = (
        positions_cast,
        charges_cast,
        cell_cast,
        *system_args,
        *neighbor_args,
        alpha_arr,
    )

    # One virial per system; num_systems comes from the (B, 3, 3) cell shape.
    virial_shape = (cell_cast.shape[0] if is_batched else 1, 3, 3)

    forces = None
    charge_grads = None
    virial = None

    if compute_charge_gradients:
        forces = jnp.zeros((num_atoms, 3), dtype=dtype)
        charge_grads = jnp.zeros(num_atoms, dtype=jnp.float64)
        virial = jnp.zeros(virial_shape, dtype=dtype)
        (energies, forces, charge_grads, virial) = charge_grad_kernel[dtype](
            *kernel_inputs,
            int(compute_virial),
            energies,
            forces,
            charge_grads,
            virial,
            launch_dims=(num_atoms,),
        )
    elif compute_forces or compute_virial:
        # The virial is produced by the force kernel, so it is needed for either.
        forces = jnp.zeros((num_atoms, 3), dtype=dtype)
        virial = jnp.zeros(virial_shape, dtype=dtype)
        (energies, forces, virial) = forces_kernel[dtype](
            *kernel_inputs,
            int(compute_virial),
            energies,
            forces,
            virial,
            launch_dims=(num_atoms,),
        )
    else:
        (energies,) = energy_kernel[dtype](
            *kernel_inputs,
            energies,
            launch_dims=(num_atoms,),
        )

    # Assemble the result: energies first, virial always last when requested.
    result = [energies]
    if compute_forces:
        result.append(forces)
    if compute_charge_gradients:
        result.append(charge_grads)
    if compute_virial:
        result.append(virial)
    return tuple(result) if len(result) > 1 else result[0]


def ewald_reciprocal_space(
    positions: jax.Array,
    charges: jax.Array,
    cell: jax.Array,
    k_vectors: jax.Array,
    alpha: float | jax.Array,
    batch_idx: jax.Array | None = None,
    max_atoms_per_system: int | None = None,
    compute_forces: bool = False,
    compute_charge_gradients: bool = False,
    compute_virial: bool = False,
) -> jax.Array | tuple[jax.Array, ...]:
    """Compute reciprocal-space Ewald energy and optionally forces, charge gradients, and virial.

    Computes the smooth long-range electrostatic contribution using structure
    factors in reciprocal space. Includes self-energy and background
    corrections.

    Parameters
    ----------
    positions : jax.Array, shape (N, 3)
        Atomic coordinates.
    charges : jax.Array, shape (N,)
        Atomic partial charges.
    cell : jax.Array, shape (1, 3, 3) or (B, 3, 3)
        Unit cell matrices. A bare (3, 3) cell is promoted to (1, 3, 3).
    k_vectors : jax.Array
        Reciprocal lattice vectors. Shape (K, 3) for single system, (B, K, 3)
        for batch. In batched mode a (K, 3) array is tiled across all systems.
    alpha : float or jax.Array
        Ewald splitting parameter. Can be a float or array of shape (1,) or (B,).
    batch_idx : jax.Array | None, shape (N,)
        System index for each atom.
    max_atoms_per_system : int | None, optional
        Maximum number of atoms in any single system in the batch. Required
        when using ``jax.jit`` with batched inputs. If None, inferred from
        data (fails under JIT).
    compute_forces : bool, default=False
        Whether to compute explicit forces.
    compute_charge_gradients : bool, default=False
        Whether to compute charge gradients.
    compute_virial : bool, default=False
        Whether to compute the virial tensor.

    Returns
    -------
    energies : jax.Array, shape (N,)
        Per-atom reciprocal-space energy (with corrections applied).
    forces : jax.Array, shape (N, 3), optional
        Forces. Returned only if compute_forces=True (they are computed
        internally, but not returned, when only charge gradients are
        requested).
    charge_gradients : jax.Array, shape (N,), optional
        Charge gradients (if compute_charge_gradients=True).
    virial : jax.Array, shape (1, 3, 3) or (B, 3, 3), optional
        Virial tensor (if compute_virial=True). Always last in the return tuple.

    A bare array is returned when only energies are requested; otherwise a
    tuple in the order listed above.

    Raises
    ------
    ValueError
        If ``max_atoms_per_system`` is None and cannot be inferred because the
        function is being traced by ``jax.jit``.
    """
    # Normalized input dtype drives kernel dispatch and the force/virial dtype.
    dtype = _normalize_dtype(positions.dtype)

    positions_cast = positions.astype(dtype)
    charges_cast = charges.astype(dtype)
    cell_cast = cell.astype(dtype)
    k_vectors_cast = k_vectors.astype(dtype)

    # Ensure cell is (B, 3, 3)
    if cell_cast.ndim == 2:
        cell_cast = cell_cast[jnp.newaxis, :, :]

    num_atoms = positions_cast.shape[0]
    is_batched = batch_idx is not None
    # Cast once; every batched kernel below takes the int32 system index.
    batch_idx_i32 = batch_idx.astype(jnp.int32) if is_batched else None

    alpha_arr = _prepare_alpha_array(alpha, cell_cast.shape[0], dtype=dtype)

    # Net charge per system; num_systems comes from the (B, 3, 3) cell shape.
    total_charge = _compute_total_charge(
        charges_cast, batch_idx, num_systems=cell_cast.shape[0]
    )

    # Determine k-vector dimensions and the structure-factor shape.
    if is_batched:
        # k_vectors should be (B, K, 3); expand from (K, 3) if necessary
        if k_vectors_cast.ndim == 2:
            k_vectors_cast = jnp.tile(
                k_vectors_cast[jnp.newaxis, :, :],
                (cell_cast.shape[0], 1, 1),
            )
        num_k = k_vectors_cast.shape[1]
        num_systems = k_vectors_cast.shape[0]
        sf_shape = (num_systems, num_k)
    else:
        # k_vectors: (K, 3)
        num_k = k_vectors_cast.shape[0]
        num_systems = 1
        sf_shape = (num_k,)

    # Intermediate structure-factor buffers and energy outputs (all float64).
    cos_k_dot_r = jnp.zeros((num_k, num_atoms), dtype=jnp.float64)
    sin_k_dot_r = jnp.zeros((num_k, num_atoms), dtype=jnp.float64)
    real_sf = jnp.zeros(sf_shape, dtype=jnp.float64)
    imag_sf = jnp.zeros(sf_shape, dtype=jnp.float64)
    raw_energies = jnp.zeros(num_atoms, dtype=jnp.float64)
    energies = jnp.zeros(num_atoms, dtype=jnp.float64)

    # Step 1: Fill structure factors
    if is_batched:
        # Per-system atom ranges [atom_start, atom_end) from the batch indices.
        # NOTE(review): this presumes atoms of a system are contiguous and
        # ordered by system index -- confirm against callers.
        atom_counts = jnp.bincount(batch_idx_i32, length=num_systems)
        atom_end = jnp.cumsum(atom_counts).astype(jnp.int32)
        atom_start = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), atom_end[:-1]])

        if max_atoms_per_system is None:
            try:
                max_atoms_per_system = int(atom_counts.max())
            except (
                jax.errors.ConcretizationTypeError,
                jax.errors.TracerIntegerConversionError,
            ):
                raise ValueError(
                    "Cannot infer max_atoms_per_system inside jax.jit. "
                    "Please provide max_atoms_per_system explicitly when "
                    "using jax.jit."
                ) from None

        # Ceil-divide so the largest system is fully covered by atom blocks.
        max_blocks_per_system = (
            max_atoms_per_system + BATCH_BLOCK_SIZE - 1
        ) // BATCH_BLOCK_SIZE

        (total_charge, cos_k_dot_r, sin_k_dot_r, real_sf, imag_sf) = (
            _jax_batch_ewald_reciprocal_fill_structure_factors[dtype](
                positions_cast,
                charges_cast,
                k_vectors_cast,
                cell_cast,
                alpha_arr,
                atom_start,
                atom_end,
                total_charge,
                cos_k_dot_r,
                sin_k_dot_r,
                real_sf,
                imag_sf,
                launch_dims=(num_k, num_systems, max_blocks_per_system),
            )
        )
    else:
        (total_charge, cos_k_dot_r, sin_k_dot_r, real_sf, imag_sf) = (
            _jax_ewald_reciprocal_fill_structure_factors[dtype](
                positions_cast,
                charges_cast,
                k_vectors_cast,
                cell_cast,
                alpha_arr,
                total_charge,
                cos_k_dot_r,
                sin_k_dot_r,
                real_sf,
                imag_sf,
                launch_dims=(num_k,),
            )
        )

    # Step 2: Compute energy (and forces/charge_grads if requested)
    # Batched kernels take the system index right after the charges.
    system_args = (batch_idx_i32,) if is_batched else ()
    sf_args = (cos_k_dot_r, sin_k_dot_r, real_sf, imag_sf)

    forces = None
    charge_grads = None

    if compute_charge_gradients:
        charge_grad_kernel = (
            _jax_batch_ewald_reciprocal_energy_forces_charge_grad
            if is_batched
            else _jax_ewald_reciprocal_energy_forces_charge_grad
        )
        forces = jnp.zeros((num_atoms, 3), dtype=dtype)
        charge_grads = jnp.zeros(num_atoms, dtype=jnp.float64)
        (raw_energies, forces, charge_grads) = charge_grad_kernel[dtype](
            charges_cast,
            *system_args,
            k_vectors_cast,
            *sf_args,
            raw_energies,
            forces,
            charge_grads,
            launch_dims=(num_atoms,),
        )
    elif compute_forces:
        forces_kernel = (
            _jax_batch_ewald_reciprocal_energy_forces
            if is_batched
            else _jax_ewald_reciprocal_energy_forces
        )
        forces = jnp.zeros((num_atoms, 3), dtype=dtype)
        (raw_energies, forces) = forces_kernel[dtype](
            charges_cast,
            *system_args,
            k_vectors_cast,
            *sf_args,
            raw_energies,
            forces,
            launch_dims=(num_atoms,),
        )
    else:
        # The energy-only kernels do not take the k-vectors.
        energy_kernel = (
            _jax_batch_ewald_reciprocal_compute_energy
            if is_batched
            else _jax_ewald_reciprocal_compute_energy
        )
        (raw_energies,) = energy_kernel[dtype](
            charges_cast,
            *system_args,
            *sf_args,
            raw_energies,
            launch_dims=(num_atoms,),
        )

    # Step 3: Apply self-energy and background corrections
    self_energy_kernel = (
        _jax_batch_ewald_subtract_self_energy
        if is_batched
        else _jax_ewald_subtract_self_energy
    )
    (energies,) = self_energy_kernel[dtype](
        charges_cast,
        *system_args,
        alpha_arr,
        total_charge,
        raw_energies,
        energies,
        launch_dims=(num_atoms,),
    )

    # Per-system cell volume, shared by the virial and charge-gradient steps.
    volume = None
    if compute_virial or compute_charge_gradients:
        volume = jnp.abs(jnp.linalg.det(cell_cast)).astype(jnp.float64)

    # Step 4: Compute virial if requested
    virial = None
    if compute_virial:
        if is_batched:
            virial = jnp.zeros((num_systems, 3, 3), dtype=dtype)
            (virial,) = _jax_batch_ewald_reciprocal_virial[dtype](
                k_vectors_cast,  # (B, K, 3)
                alpha_arr,
                volume,
                real_sf,  # (B, K)
                imag_sf,  # (B, K)
                virial,
                launch_dims=(num_k, num_systems),
            )
        else:
            virial = jnp.zeros((1, 3, 3), dtype=dtype)
            (virial,) = _jax_ewald_reciprocal_virial[dtype](
                k_vectors_cast,  # (K, 3)
                alpha_arr,
                volume,
                real_sf,  # (K,)
                imag_sf,  # (K,)
                virial,
                launch_dims=(num_k,),
            )

    # Apply corrections to charge gradients if requested
    if compute_charge_gradients:
        # Gather per-atom alpha, net charge and volume of each atom's system.
        if is_batched:
            alpha_val = alpha_arr[batch_idx]
            total_charge_per_atom = total_charge[batch_idx]
            volume_per_atom = volume[batch_idx]
        else:
            alpha_val = alpha_arr[0]
            total_charge_per_atom = total_charge[0]
            volume_per_atom = volume[0]

        # Self-energy gradient: 2 * alpha / sqrt(pi) * q
        self_energy_grad = 2.0 * alpha_val / jnp.sqrt(PI) * charges_cast

        # Background gradient: pi / (2 * alpha^2 * V) * Q_total
        # NOTE(review): if the background energy is pi / (2 alpha^2 V) * Q^2,
        # its derivative w.r.t. q_i would be pi / (alpha^2 V) * Q -- confirm
        # this prefactor against the self-energy kernel's convention.
        background_grad = (
            PI / (2.0 * alpha_val * alpha_val * volume_per_atom) * total_charge_per_atom
        )

        charge_grads = charge_grads - self_energy_grad - background_grad

    # Assemble the result: energies first, virial always last when requested.
    result = [energies]
    if compute_forces:
        result.append(forces)
    if compute_charge_gradients:
        result.append(charge_grads)
    if compute_virial:
        result.append(virial)
    return tuple(result) if len(result) > 1 else result[0]


def ewald_summation(
    positions: jax.Array,
    charges: jax.Array,
    cell: jax.Array,
    alpha: float | jax.Array | None = None,
    k_vectors: jax.Array | None = None,
    k_cutoff: float | jax.Array | None = None,
    miller_bounds: tuple[int, int, int] | None = None,
    batch_idx: jax.Array | None = None,
    max_atoms_per_system: int | None = None,
    neighbor_list: jax.Array | None = None,
    neighbor_ptr: jax.Array | None = None,
    neighbor_shifts: jax.Array | None = None,
    neighbor_matrix: jax.Array | None = None,
    neighbor_matrix_shifts: jax.Array | None = None,
    mask_value: int | None = None,
    compute_forces: bool = False,
    compute_virial: bool = False,
    accuracy: float = 1e-6,
) -> jax.Array | tuple[jax.Array, ...]:
    """Compute complete Ewald summation (real + reciprocal space).

    The Ewald method splits long-range Coulomb into:
        E_total = E_real + E_reciprocal - E_self - E_background

    Parameters
    ----------
    positions : jax.Array, shape (N, 3)
        Atomic coordinates.
    charges : jax.Array, shape (N,)
        Atomic partial charges.
    cell : jax.Array, shape (1, 3, 3) or (B, 3, 3)
        Unit cell matrices.
    alpha : float or jax.Array or None
        Ewald splitting parameter. If None, estimated automatically.
    k_vectors : jax.Array | None
        Reciprocal lattice vectors. If None, generated automatically.
        Shape (K, 3) for single system, (B, K, 3) for batch.
    k_cutoff : float | jax.Array | None
        K-space cutoff. Used only if k_vectors is None; if it is also None it
        is estimated automatically.
    miller_bounds : tuple[int, int, int] | None, optional
        Precomputed maximum Miller indices (M_h, M_k, M_l). Forwarded to
        :func:`generate_k_vectors_ewald_summation` when ``k_vectors`` is
        ``None``. When provided, makes k-vector generation compatible with
        ``jax.jit``. Use :func:`generate_miller_indices` to precompute.
        Ignored when ``k_vectors`` is explicitly provided.
    batch_idx : jax.Array | None, shape (N,)
        System index for each atom.
    max_atoms_per_system : int | None, optional
        Maximum number of atoms in any single system in the batch. Required
        when using ``jax.jit`` with batched inputs. If None, inferred from
        data (fails under JIT).
    neighbor_list : jax.Array | None, shape (2, M)
        Neighbor list in COO format.
    neighbor_ptr : jax.Array | None, shape (N+1,)
        CSR row pointers for neighbor list.
    neighbor_shifts : jax.Array | None, shape (M, 3)
        Periodic image shifts for neighbor list.
    neighbor_matrix : jax.Array | None, shape (N, max_neighbors)
        Dense neighbor matrix format.
    neighbor_matrix_shifts : jax.Array | None, shape (N, max_neighbors, 3)
        Periodic image shifts for neighbor_matrix.
    mask_value : int | None
        Value indicating invalid entries in neighbor_matrix.
    compute_forces : bool, default=False
        Whether to compute forces.
    compute_virial : bool, default=False
        Whether to compute the virial tensor.
    accuracy : float, default=1e-6
        Target accuracy for automatic parameter estimation.

    Returns
    -------
    energies : jax.Array, shape (N,)
        Per-atom total Ewald energy.
    forces : jax.Array, shape (N, 3), optional
        Forces (if compute_forces=True).
    virial : jax.Array, shape (1, 3, 3) or (B, 3, 3), optional
        Virial tensor (if compute_virial=True). Always last in the return tuple.

    Raises
    ------
    ValueError
        If k-vectors must be generated but no ``k_cutoff`` is available.

    Examples
    --------
    >>> # Complete Ewald summation with automatic parameters
    >>> energies, forces = ewald_summation(
    ...     positions, charges, cell,
    ...     neighbor_list=nl, neighbor_ptr=neighbor_ptr, neighbor_shifts=shifts,
    ...     accuracy=1e-6,
    ...     compute_forces=True,
    ... )
    """
    # k_cutoff only matters when k-vectors still have to be generated, so do
    # not run the parameter estimation just to fill in an unused cutoff.
    needs_k_cutoff = k_vectors is None and k_cutoff is None
    if alpha is None or needs_k_cutoff:
        # Ensure cell is (B, 3, 3) for parameter estimation
        cell_3d = cell if cell.ndim == 3 else cell[jnp.newaxis, :, :]
        params = estimate_ewald_parameters(
            positions=positions,
            cell=cell_3d,
            batch_idx=batch_idx,
            accuracy=accuracy,
        )
        if alpha is None:
            alpha = params.alpha
        if k_cutoff is None:
            k_cutoff = params.reciprocal_space_cutoff

    # Generate k_vectors if not provided
    if k_vectors is None:
        # Ensure cell is (B, 3, 3)
        cell_3d = cell if cell.ndim == 3 else cell[jnp.newaxis, :, :]
        # Defensive: the estimate above should always have supplied a cutoff.
        if k_cutoff is None:
            raise ValueError("k_cutoff must be provided if k_vectors is None")
        k_vectors = generate_k_vectors_ewald_summation(
            cell=cell_3d,
            k_cutoff=k_cutoff,
            miller_bounds=miller_bounds,
        )

    # Compute real-space component
    real_result = ewald_real_space(
        positions=positions,
        charges=charges,
        cell=cell,
        alpha=alpha,
        neighbor_list=neighbor_list,
        neighbor_ptr=neighbor_ptr,
        neighbor_shifts=neighbor_shifts,
        neighbor_matrix=neighbor_matrix,
        neighbor_matrix_shifts=neighbor_matrix_shifts,
        mask_value=mask_value,
        batch_idx=batch_idx,
        compute_forces=compute_forces,
        compute_charge_gradients=False,
        compute_virial=compute_virial,
    )

    # Compute reciprocal-space component
    recip_result = ewald_reciprocal_space(
        positions=positions,
        charges=charges,
        cell=cell,
        k_vectors=k_vectors,
        alpha=alpha,
        batch_idx=batch_idx,
        max_atoms_per_system=max_atoms_per_system,
        compute_forces=compute_forces,
        compute_charge_gradients=False,
        compute_virial=compute_virial,
    )

    # Sum contributions. With only energies requested both results are bare
    # arrays; otherwise both are tuples with the same layout
    # (energies, [forces], [virial]), so they can be added element-wise.
    if not (compute_forces or compute_virial):
        return real_result + recip_result
    return tuple(
        real_part + recip_part
        for real_part, recip_part in zip(real_result, recip_result)
    )