# Source code for nvalchemiops.torch.interactions.electrostatics.ewald

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
Ewald Summation - PyTorch Bindings
==================================

This module provides PyTorch bindings for Ewald summation calculations.
It wraps the framework-agnostic Warp launchers from
``nvalchemiops.interactions.electrostatics.ewald_kernels``.

Public API
----------
- ``ewald_real_space()``: Real-space component of Ewald summation
- ``ewald_reciprocal_space()``: Reciprocal-space component
- ``ewald_summation()``: Complete Ewald summation (real + reciprocal)

Mathematical Formulation
------------------------
The Ewald method splits long-range Coulomb interactions into components:

.. math::

    E_{\text{total}} = E_{\text{real}} + E_{\text{reciprocal}} - E_{\text{self}} - E_{\text{background}}

All functions support:

- Both neighbor list (CSR) and neighbor matrix formats
- Batched calculations
- Full autograd support
- Optional explicit forces and charge gradients

Examples
--------
>>> # Complete Ewald summation
>>> energies, forces = ewald_summation(
...     positions, charges, cell,
...     neighbor_list=nl, neighbor_ptr=neighbor_ptr, neighbor_shifts=shifts,
...     accuracy=1e-6,
...     compute_forces=True,
... )

>>> # Separate real and reciprocal components
>>> e_real, f_real = ewald_real_space(
...     positions, charges, cell, alpha,
...     neighbor_list=nl, neighbor_ptr=neighbor_ptr, neighbor_shifts=shifts,
...     compute_forces=True,
... )
>>> e_recip, f_recip = ewald_reciprocal_space(
...     positions, charges, cell, k_vectors, alpha,
...     compute_forces=True,
... )
"""

from __future__ import annotations

import math

import torch
import warp as wp

from nvalchemiops.interactions.electrostatics.ewald_kernels import (
    BATCH_BLOCK_SIZE,
    _batch_ewald_real_space_energy_forces_charge_grad_kernel_overload,
    _batch_ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel_overload,
    _batch_ewald_real_space_energy_forces_kernel_overload,
    _batch_ewald_real_space_energy_forces_neighbor_matrix_kernel_overload,
    _batch_ewald_real_space_energy_kernel_overload,
    _batch_ewald_real_space_energy_neighbor_matrix_kernel_overload,
    _batch_ewald_reciprocal_space_energy_forces_charge_grad_kernel_overload,
    _batch_ewald_reciprocal_space_energy_forces_kernel_overload,
    _batch_ewald_reciprocal_space_energy_kernel_compute_energy_overload,
    _batch_ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload,
    _batch_ewald_reciprocal_space_virial_kernel_overload,
    _batch_ewald_subtract_self_energy_kernel_overload,
    _ewald_real_space_energy_forces_charge_grad_kernel_overload,
    _ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel_overload,
    _ewald_real_space_energy_forces_kernel_overload,
    _ewald_real_space_energy_forces_neighbor_matrix_kernel_overload,
    _ewald_real_space_energy_kernel_overload,
    _ewald_real_space_energy_neighbor_matrix_kernel_overload,
    _ewald_reciprocal_space_energy_forces_charge_grad_kernel_overload,
    _ewald_reciprocal_space_energy_forces_kernel_overload,
    _ewald_reciprocal_space_energy_kernel_compute_energy_overload,
    _ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload,
    _ewald_reciprocal_space_virial_kernel_overload,
    _ewald_subtract_self_energy_kernel_overload,
)
from nvalchemiops.torch.autograd import (
    OutputSpec,
    WarpAutogradContextManager,
    attach_for_backward,
    needs_grad,
    warp_custom_op,
    warp_from_torch,
)
from nvalchemiops.torch.interactions.electrostatics.k_vectors import (
    generate_k_vectors_ewald_summation,
)
from nvalchemiops.torch.interactions.electrostatics.parameters import (
    estimate_ewald_parameters,
)
from nvalchemiops.torch.types import get_wp_dtype, get_wp_mat_dtype, get_wp_vec_dtype

# Public entry points exposed by ``from <module> import *``.
__all__ = ["ewald_real_space", "ewald_reciprocal_space", "ewald_summation"]


###########################################################################################
########################### Helper Functions ##############################################
###########################################################################################


def _prepare_alpha(
    alpha: float | torch.Tensor,
    num_systems: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Convert alpha to a per-system tensor.

    Parameters
    ----------
    alpha : float or torch.Tensor
        Ewald splitting parameter. Can be:
        - A scalar float (broadcast to all systems)
        - A 0-d tensor (broadcast to all systems)
        - A 1-d tensor of shape (num_systems,) for per-system values
    num_systems : int
        Number of systems in the batch.
    dtype : torch.dtype
        Target dtype for the output tensor.
    device : torch.device
        Target device for the output tensor.

    Returns
    -------
    torch.Tensor, shape (num_systems,)
        Per-system alpha values.
    """
    if isinstance(alpha, (int, float)):
        return torch.full((num_systems,), float(alpha), dtype=dtype, device=device)
    elif isinstance(alpha, torch.Tensor):
        if alpha.dim() == 0:
            return alpha.expand(num_systems).to(dtype=dtype, device=device)
        elif alpha.shape[0] != num_systems:
            raise ValueError(
                f"alpha has {alpha.shape[0]} values but there are {num_systems} systems"
            )
        return alpha.to(dtype=dtype, device=device)
    else:
        raise TypeError(f"alpha must be float or torch.Tensor, got {type(alpha)}")


def _prepare_cell(cell: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Ensure cell is 3D (B, 3, 3) and return number of systems.

    Parameters
    ----------
    cell : torch.Tensor
        Unit cell matrix. Shape (3, 3) for single system or (B, 3, 3) for batch.

    Returns
    -------
    cell : torch.Tensor, shape (B, 3, 3)
        Cell with batch dimension.
    num_systems : int
        Number of systems (B).
    """
    if cell.dim() == 2:
        cell = cell.unsqueeze(0)
    return cell, cell.shape[0]


###########################################################################################
########################### Real-Space Internal Custom Ops ################################
###########################################################################################

# Output dtype convention:
#   - Energies: always wp.float64 for numerical stability during accumulation.
#   - Charge gradients: always wp.float64, like energies.
#   - Forces: match input precision via get_wp_vec_dtype(pos.dtype) -- vec3f for
#     float32 inputs, vec3d for float64.  This replaced an earlier hardcoded
#     wp.vec3d, which caused a dtype mismatch when positions were float32.
#   - Virial: match input precision via get_wp_mat_dtype(pos.dtype) -- mat33f for
#     float32 inputs, mat33d for float64.


@warp_custom_op(
    name="alchemiops::_ewald_real_space_energy",
    outputs=[
        OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
    ],
    grad_arrays=["energies", "positions", "charges", "cell", "alpha"],
)
def _ewald_real_space_energy(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: torch.Tensor,
    neighbor_list: torch.Tensor,
    neighbor_ptr: torch.Tensor,
    neighbor_shifts: torch.Tensor,
) -> torch.Tensor:
    """Internal: Compute real-space Ewald energies (single system, neighbor list CSR).

    Parameters
    ----------
    positions : torch.Tensor
        Atomic positions, wrapped as a Warp vec3 array. Their dtype selects
        the Warp scalar/vector/matrix types and the kernel overload.
    charges : torch.Tensor
        Atomic charges, wrapped with the input scalar dtype.
    cell : torch.Tensor
        Unit cell, wrapped as a Warp mat33 array. Presumably shape
        (1, 3, 3) for this single-system op -- confirm against callers.
    alpha : torch.Tensor
        Ewald splitting parameter, wrapped with the input scalar dtype.
    neighbor_list : torch.Tensor
        Neighbor pairs. Only row 1 (neighbor indices) is passed to the
        kernel; ``shape[1] == 0`` is treated as "no pairs".
    neighbor_ptr : torch.Tensor
        CSR row pointer (converted to int32).
    neighbor_shifts : torch.Tensor
        Integer cell shifts per pair (wrapped as ``wp.vec3i``).

    Returns
    -------
    torch.Tensor, shape (N,), float64
        Per-atom real-space energies. All zeros when there are no pairs,
        because the kernel is then not launched.

    Notes
    -----
    Warp gradient tracking is enabled when positions, charges, or cell
    require grad. NOTE(review): ``alpha.requires_grad`` is not consulted
    even though ``alpha`` is listed in ``grad_arrays`` -- confirm intended.
    """
    num_atoms = positions.shape[0]
    input_dtype = positions.dtype
    # With no pairs the kernel launch below is skipped entirely.
    empty_nl = neighbor_list.shape[1] == 0

    # Only the neighbor (j) row is consumed; the i side is presumably implied
    # by the CSR neighbor_ptr ranges.
    idx_j = neighbor_list[1]
    device = wp.device_from_torch(positions.device)
    needs_grad_flag = needs_grad(positions, charges, cell)

    wp_scalar = get_wp_dtype(input_dtype)
    wp_vec = get_wp_vec_dtype(input_dtype)
    wp_mat = get_wp_mat_dtype(input_dtype)

    wp_positions = warp_from_torch(positions, wp_vec, requires_grad=needs_grad_flag)
    wp_charges = warp_from_torch(charges, wp_scalar, requires_grad=needs_grad_flag)
    wp_cell = warp_from_torch(cell, wp_mat, requires_grad=needs_grad_flag)
    wp_alpha = warp_from_torch(alpha, wp_scalar, requires_grad=needs_grad_flag)
    wp_idx_j = warp_from_torch(idx_j, wp.int32)
    wp_neighbor_ptr = warp_from_torch(neighbor_ptr, wp.int32)
    wp_unit_shifts = warp_from_torch(neighbor_shifts, wp.vec3i)

    # Energies are accumulated in float64 regardless of input precision.
    energies = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=needs_grad_flag)

    # Presumably records the launch on a Warp tape only when needs_grad_flag
    # is set (see WarpAutogradContextManager).
    with WarpAutogradContextManager(needs_grad_flag) as tape:
        if not empty_nl:
            # wp.launch binds ``inputs`` positionally to the kernel parameters,
            # so this order must match the kernel signature.
            wp.launch(
                _ewald_real_space_energy_kernel_overload[wp_scalar],
                dim=[num_atoms],
                inputs=[
                    wp_positions,
                    wp_charges,
                    wp_cell,
                    wp_idx_j,
                    wp_neighbor_ptr,
                    wp_unit_shifts,
                    wp_alpha,
                    wp_energies,
                ],
                device=device,
            )

    if needs_grad_flag:
        # Hand the tape and the exact Warp arrays used above to the backward
        # machinery; keyword names mirror ``grad_arrays`` in the decorator.
        attach_for_backward(
            energies,
            tape=tape,
            energies=wp_energies,
            positions=wp_positions,
            charges=wp_charges,
            cell=wp_cell,
            alpha=wp_alpha,
        )
    return energies


@warp_custom_op(
    name="alchemiops::_ewald_real_space_energy_forces",
    outputs=[
        OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "forces",
            lambda pos, *_: get_wp_vec_dtype(pos.dtype),
            lambda pos, *_: (pos.shape[0], 3),
        ),
        OutputSpec(
            "virial",
            lambda pos, *_: get_wp_mat_dtype(pos.dtype),
            lambda pos, charges, cell, *_: (cell.shape[0], 3, 3),
        ),
    ],
    grad_arrays=[
        "energies",
        "forces",
        "virial",
        "positions",
        "charges",
        "cell",
        "alpha",
    ],
)
def _ewald_real_space_energy_forces(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: torch.Tensor,
    neighbor_list: torch.Tensor,
    neighbor_ptr: torch.Tensor,
    neighbor_shifts: torch.Tensor,
    compute_virial: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Internal: Compute real-space Ewald energies, forces, and optionally virial (single, CSR).

    Parameters
    ----------
    positions, charges, cell, alpha, neighbor_list, neighbor_ptr, neighbor_shifts
        As for ``_ewald_real_space_energy``. Only row 1 of ``neighbor_list``
        is passed to the kernel.
    compute_virial : bool, default False
        Forwarded to the kernel. It also controls whether the virial Warp
        array participates in gradient tracking.

    Returns
    -------
    energies : torch.Tensor, shape (N,), float64
        Per-atom real-space energies.
    forces : torch.Tensor, shape (N, 3), dtype of ``positions``
        Per-atom forces.
    virial : torch.Tensor, shape (1, 3, 3), dtype of ``positions``
        Virial buffer written by the kernel.

    All outputs stay zero when the neighbor list is empty, because the kernel
    is then not launched.

    Notes
    -----
    NOTE(review): the virial is allocated with a leading dimension of 1, while
    the ``OutputSpec`` declares ``(cell.shape[0], 3, 3)``. These agree only
    when ``cell`` is passed as ``(1, 3, 3)`` -- confirm callers guarantee it.
    """
    num_atoms = positions.shape[0]
    input_dtype = positions.dtype
    # With no pairs the kernel launch below is skipped entirely.
    empty_nl = neighbor_list.shape[1] == 0

    # Only the neighbor (j) row is consumed by the kernel.
    idx_j = neighbor_list[1]
    device = wp.device_from_torch(positions.device)
    needs_grad_flag = needs_grad(positions, charges, cell)
    # The virial array only needs a gradient when it is actually computed.
    virial_grad = needs_grad_flag and compute_virial

    wp_scalar = get_wp_dtype(input_dtype)
    wp_vec = get_wp_vec_dtype(input_dtype)
    wp_mat = get_wp_mat_dtype(input_dtype)

    wp_positions = warp_from_torch(positions, wp_vec, requires_grad=needs_grad_flag)
    wp_charges = warp_from_torch(charges, wp_scalar, requires_grad=needs_grad_flag)
    wp_cell = warp_from_torch(cell, wp_mat, requires_grad=needs_grad_flag)
    wp_alpha = warp_from_torch(alpha, wp_scalar, requires_grad=needs_grad_flag)
    wp_idx_j = warp_from_torch(idx_j, wp.int32)
    wp_neighbor_ptr = warp_from_torch(neighbor_ptr, wp.int32)
    wp_unit_shifts = warp_from_torch(neighbor_shifts, wp.vec3i)

    # Energies in float64; forces/virial follow the input precision.
    energies = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    forces = torch.zeros(num_atoms, 3, device=positions.device, dtype=input_dtype)
    virial = torch.zeros(1, 3, 3, device=positions.device, dtype=input_dtype)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=needs_grad_flag)
    wp_forces = warp_from_torch(forces, wp_vec, requires_grad=needs_grad_flag)
    wp_virial = warp_from_torch(virial, wp_mat, requires_grad=virial_grad)

    # Presumably records the launch on a Warp tape only when needs_grad_flag
    # is set (see WarpAutogradContextManager).
    with WarpAutogradContextManager(needs_grad_flag) as tape:
        if not empty_nl:
            # wp.launch binds ``inputs`` positionally to the kernel parameters,
            # so this order must match the kernel signature.
            wp.launch(
                _ewald_real_space_energy_forces_kernel_overload[wp_scalar],
                dim=[num_atoms],
                inputs=[
                    wp_positions,
                    wp_charges,
                    wp_cell,
                    wp_idx_j,
                    wp_neighbor_ptr,
                    wp_unit_shifts,
                    wp_alpha,
                    compute_virial,
                    wp_energies,
                    wp_forces,
                    wp_virial,
                ],
                device=device,
            )

    if needs_grad_flag:
        # Keyword names mirror ``grad_arrays``; the virial is registered only
        # when it carries a gradient.
        backward_kw = dict(
            energies=wp_energies,
            forces=wp_forces,
            positions=wp_positions,
            charges=wp_charges,
            cell=wp_cell,
            alpha=wp_alpha,
        )
        if virial_grad:
            backward_kw["virial"] = wp_virial
        attach_for_backward(energies, tape=tape, **backward_kw)
    return energies, forces, virial


@warp_custom_op(
    name="alchemiops::_ewald_real_space_energy_matrix",
    outputs=[OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],))],
    grad_arrays=["energies", "positions", "charges", "cell", "alpha"],
)
def _ewald_real_space_energy_matrix(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: torch.Tensor,
    neighbor_matrix: torch.Tensor,
    neighbor_matrix_shifts: torch.Tensor,
    mask_value: int,
) -> torch.Tensor:
    """Internal: Compute real-space Ewald energies (single system, neighbor matrix).

    Parameters
    ----------
    positions, charges, cell, alpha : torch.Tensor
        As for ``_ewald_real_space_energy``.
    neighbor_matrix : torch.Tensor
        Dense neighbor table (converted to int32). One kernel thread is
        launched per row (``dim = neighbor_matrix.shape[0]``).
    neighbor_matrix_shifts : torch.Tensor
        Integer cell shifts per entry (wrapped as ``wp.vec3i``).
    mask_value : int
        Sentinel passed to the kernel as ``wp.int32``. Presumably it marks
        padded entries of ``neighbor_matrix`` -- confirm in the kernel.

    Returns
    -------
    torch.Tensor, shape (N,), float64
        Per-atom real-space energies. All zeros when ``neighbor_matrix`` has
        no rows, because the kernel is then not launched.
    """
    num_atoms = positions.shape[0]
    input_dtype = positions.dtype
    # Only an empty row dimension skips the launch; zero columns still launch.
    empty_nm = neighbor_matrix.shape[0] == 0

    device = wp.device_from_torch(positions.device)
    needs_grad_flag = needs_grad(positions, charges, cell)

    wp_scalar = get_wp_dtype(input_dtype)
    wp_vec = get_wp_vec_dtype(input_dtype)
    wp_mat = get_wp_mat_dtype(input_dtype)

    wp_positions = warp_from_torch(positions, wp_vec, requires_grad=needs_grad_flag)
    wp_charges = warp_from_torch(charges, wp_scalar, requires_grad=needs_grad_flag)
    wp_cell = warp_from_torch(cell, wp_mat, requires_grad=needs_grad_flag)
    wp_alpha = warp_from_torch(alpha, wp_scalar, requires_grad=needs_grad_flag)
    wp_neighbor_matrix = warp_from_torch(neighbor_matrix, wp.int32)
    wp_unit_shifts_matrix = warp_from_torch(neighbor_matrix_shifts, wp.vec3i)

    # Energies are accumulated in float64 regardless of input precision.
    energies = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=needs_grad_flag)

    # Presumably records the launch on a Warp tape only when needs_grad_flag
    # is set (see WarpAutogradContextManager).
    with WarpAutogradContextManager(needs_grad_flag) as tape:
        if not empty_nm:
            # wp.launch binds ``inputs`` positionally to the kernel parameters,
            # so this order must match the kernel signature.
            wp.launch(
                _ewald_real_space_energy_neighbor_matrix_kernel_overload[wp_scalar],
                dim=[neighbor_matrix.shape[0]],
                inputs=[
                    wp_positions,
                    wp_charges,
                    wp_cell,
                    wp_neighbor_matrix,
                    wp_unit_shifts_matrix,
                    wp.int32(mask_value),
                    wp_alpha,
                    wp_energies,
                ],
                device=device,
            )

    if needs_grad_flag:
        # Keyword names mirror ``grad_arrays`` in the decorator.
        attach_for_backward(
            energies,
            tape=tape,
            energies=wp_energies,
            positions=wp_positions,
            charges=wp_charges,
            cell=wp_cell,
            alpha=wp_alpha,
        )
    return energies


@warp_custom_op(
    name="alchemiops::_ewald_real_space_energy_forces_matrix",
    outputs=[
        OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "forces",
            lambda pos, *_: get_wp_vec_dtype(pos.dtype),
            lambda pos, *_: (pos.shape[0], 3),
        ),
        OutputSpec(
            "virial",
            lambda pos, *_: get_wp_mat_dtype(pos.dtype),
            lambda pos, charges, cell, *_: (cell.shape[0], 3, 3),
        ),
    ],
    grad_arrays=[
        "energies",
        "forces",
        "virial",
        "positions",
        "charges",
        "cell",
        "alpha",
    ],
)
def _ewald_real_space_energy_forces_matrix(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: torch.Tensor,
    neighbor_matrix: torch.Tensor,
    neighbor_matrix_shifts: torch.Tensor,
    mask_value: int,
    compute_virial: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Internal: Compute real-space Ewald energies, forces, and optionally virial (single, matrix).

    Parameters
    ----------
    positions, charges, cell, alpha : torch.Tensor
        As for ``_ewald_real_space_energy``.
    neighbor_matrix, neighbor_matrix_shifts, mask_value
        As for ``_ewald_real_space_energy_matrix``. One kernel thread is
        launched per row of ``neighbor_matrix``.
    compute_virial : bool, default False
        Forwarded to the kernel. It also controls whether the virial Warp
        array participates in gradient tracking.

    Returns
    -------
    energies : torch.Tensor, shape (N,), float64
    forces : torch.Tensor, shape (N, 3), dtype of ``positions``
    virial : torch.Tensor, shape (1, 3, 3), dtype of ``positions``

    All outputs stay zero when ``neighbor_matrix`` has no rows.

    Notes
    -----
    NOTE(review): the virial is allocated with a leading dimension of 1, while
    the ``OutputSpec`` declares ``(cell.shape[0], 3, 3)``. Confirm that
    callers always pass ``cell`` as ``(1, 3, 3)``.
    """
    num_atoms = positions.shape[0]
    input_dtype = positions.dtype
    # Only an empty row dimension skips the launch.
    empty_nm = neighbor_matrix.shape[0] == 0

    device = wp.device_from_torch(positions.device)
    needs_grad_flag = needs_grad(positions, charges, cell)
    # The virial array only needs a gradient when it is actually computed.
    virial_grad = needs_grad_flag and compute_virial

    wp_scalar = get_wp_dtype(input_dtype)
    wp_vec = get_wp_vec_dtype(input_dtype)
    wp_mat = get_wp_mat_dtype(input_dtype)

    wp_positions = warp_from_torch(positions, wp_vec, requires_grad=needs_grad_flag)
    wp_charges = warp_from_torch(charges, wp_scalar, requires_grad=needs_grad_flag)
    wp_cell = warp_from_torch(cell, wp_mat, requires_grad=needs_grad_flag)
    wp_alpha = warp_from_torch(alpha, wp_scalar, requires_grad=needs_grad_flag)
    wp_neighbor_matrix = warp_from_torch(neighbor_matrix, wp.int32)
    wp_unit_shifts_matrix = warp_from_torch(neighbor_matrix_shifts, wp.vec3i)

    # Energies in float64; forces/virial follow the input precision.
    energies = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    forces = torch.zeros(num_atoms, 3, device=positions.device, dtype=input_dtype)
    virial = torch.zeros(1, 3, 3, device=positions.device, dtype=input_dtype)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=needs_grad_flag)
    wp_forces = warp_from_torch(forces, wp_vec, requires_grad=needs_grad_flag)
    wp_virial = warp_from_torch(virial, wp_mat, requires_grad=virial_grad)

    # Presumably records the launch on a Warp tape only when needs_grad_flag
    # is set (see WarpAutogradContextManager).
    with WarpAutogradContextManager(needs_grad_flag) as tape:
        if not empty_nm:
            # wp.launch binds ``inputs`` positionally to the kernel parameters,
            # so this order must match the kernel signature.
            wp.launch(
                _ewald_real_space_energy_forces_neighbor_matrix_kernel_overload[
                    wp_scalar
                ],
                dim=[neighbor_matrix.shape[0]],
                inputs=[
                    wp_positions,
                    wp_charges,
                    wp_cell,
                    wp_neighbor_matrix,
                    wp_unit_shifts_matrix,
                    wp.int32(mask_value),
                    wp_alpha,
                    compute_virial,
                    wp_energies,
                    wp_forces,
                    wp_virial,
                ],
                device=device,
            )

    if needs_grad_flag:
        # Keyword names mirror ``grad_arrays``; the virial is registered only
        # when it carries a gradient.
        backward_kw = dict(
            energies=wp_energies,
            forces=wp_forces,
            positions=wp_positions,
            charges=wp_charges,
            cell=wp_cell,
            alpha=wp_alpha,
        )
        if virial_grad:
            backward_kw["virial"] = wp_virial
        attach_for_backward(energies, tape=tape, **backward_kw)
    return energies, forces, virial


###########################################################################################
################## Real-Space with Charge Gradients Internal Custom Ops ###################
###########################################################################################


@warp_custom_op(
    name="alchemiops::_ewald_real_space_energy_forces_charge_grad",
    outputs=[
        OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "forces",
            lambda pos, *_: get_wp_vec_dtype(pos.dtype),
            lambda pos, *_: (pos.shape[0], 3),
        ),
        OutputSpec("charge_gradients", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "virial",
            lambda pos, *_: get_wp_mat_dtype(pos.dtype),
            lambda pos, charges, cell, *_: (cell.shape[0], 3, 3),
        ),
    ],
    grad_arrays=[
        "energies",
        "forces",
        "charge_gradients",
        "virial",
        "positions",
        "charges",
        "cell",
        "alpha",
    ],
)
def _ewald_real_space_energy_forces_charge_grad(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: torch.Tensor,
    neighbor_list: torch.Tensor,
    neighbor_ptr: torch.Tensor,
    neighbor_shifts: torch.Tensor,
    compute_virial: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Internal: Compute real-space Ewald E+F+charge_grad+virial (single, CSR).

    Parameters
    ----------
    positions, charges, cell, alpha, neighbor_list, neighbor_ptr, neighbor_shifts
        As for ``_ewald_real_space_energy``. Only row 1 of ``neighbor_list``
        is passed to the kernel.
    compute_virial : bool, default False
        Forwarded to the kernel. It also controls whether the virial Warp
        array participates in gradient tracking.

    Returns
    -------
    energies : torch.Tensor, shape (N,), float64
    forces : torch.Tensor, shape (N, 3), dtype of ``positions``
    charge_grads : torch.Tensor, shape (N,), float64
        Per-atom charge-gradient buffer written by the kernel.
    virial : torch.Tensor, shape (1, 3, 3), dtype of ``positions``

    All outputs stay zero when the neighbor list is empty, because the kernel
    is then not launched.

    Notes
    -----
    NOTE(review): the virial is allocated with a leading dimension of 1, while
    the ``OutputSpec`` declares ``(cell.shape[0], 3, 3)``. Confirm that
    callers always pass ``cell`` as ``(1, 3, 3)``.
    """
    num_atoms = positions.shape[0]
    input_dtype = positions.dtype
    # With no pairs the kernel launch below is skipped entirely.
    empty_nl = neighbor_list.shape[1] == 0

    # Only the neighbor (j) row is consumed by the kernel.
    idx_j = neighbor_list[1]
    device = wp.device_from_torch(positions.device)
    needs_grad_flag = needs_grad(positions, charges, cell)
    # The virial array only needs a gradient when it is actually computed.
    virial_grad = needs_grad_flag and compute_virial

    wp_scalar = get_wp_dtype(input_dtype)
    wp_vec = get_wp_vec_dtype(input_dtype)
    wp_mat = get_wp_mat_dtype(input_dtype)

    wp_positions = warp_from_torch(positions, wp_vec, requires_grad=needs_grad_flag)
    wp_charges = warp_from_torch(charges, wp_scalar, requires_grad=needs_grad_flag)
    wp_cell = warp_from_torch(cell, wp_mat, requires_grad=needs_grad_flag)
    wp_alpha = warp_from_torch(alpha, wp_scalar, requires_grad=needs_grad_flag)
    wp_idx_j = warp_from_torch(idx_j, wp.int32)
    wp_neighbor_ptr = warp_from_torch(neighbor_ptr, wp.int32)
    wp_unit_shifts = warp_from_torch(neighbor_shifts, wp.vec3i)

    # Energies and charge gradients in float64; forces/virial follow the
    # input precision.
    energies = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    forces = torch.zeros(num_atoms, 3, device=positions.device, dtype=input_dtype)
    charge_grads = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    virial = torch.zeros(1, 3, 3, device=positions.device, dtype=input_dtype)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=needs_grad_flag)
    wp_forces = warp_from_torch(forces, wp_vec, requires_grad=needs_grad_flag)
    wp_charge_grads = warp_from_torch(
        charge_grads, wp.float64, requires_grad=needs_grad_flag
    )
    wp_virial = warp_from_torch(virial, wp_mat, requires_grad=virial_grad)

    # Presumably records the launch on a Warp tape only when needs_grad_flag
    # is set (see WarpAutogradContextManager).
    with WarpAutogradContextManager(needs_grad_flag) as tape:
        if not empty_nl:
            # wp.launch binds ``inputs`` positionally to the kernel parameters,
            # so this order must match the kernel signature.
            wp.launch(
                _ewald_real_space_energy_forces_charge_grad_kernel_overload[wp_scalar],
                dim=[num_atoms],
                inputs=[
                    wp_positions,
                    wp_charges,
                    wp_cell,
                    wp_idx_j,
                    wp_neighbor_ptr,
                    wp_unit_shifts,
                    wp_alpha,
                    compute_virial,
                    wp_energies,
                    wp_forces,
                    wp_charge_grads,
                    wp_virial,
                ],
                device=device,
            )

    if needs_grad_flag:
        # Keyword names mirror ``grad_arrays``; the virial is registered only
        # when it carries a gradient.
        backward_kw = dict(
            energies=wp_energies,
            forces=wp_forces,
            charge_gradients=wp_charge_grads,
            positions=wp_positions,
            charges=wp_charges,
            cell=wp_cell,
            alpha=wp_alpha,
        )
        if virial_grad:
            backward_kw["virial"] = wp_virial
        attach_for_backward(energies, tape=tape, **backward_kw)
    return energies, forces, charge_grads, virial


@warp_custom_op(
    name="alchemiops::_ewald_real_space_energy_forces_charge_grad_matrix",
    outputs=[
        OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "forces",
            lambda pos, *_: get_wp_vec_dtype(pos.dtype),
            lambda pos, *_: (pos.shape[0], 3),
        ),
        OutputSpec("charge_gradients", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "virial",
            lambda pos, *_: get_wp_mat_dtype(pos.dtype),
            lambda pos, charges, cell, *_: (cell.shape[0], 3, 3),
        ),
    ],
    grad_arrays=[
        "energies",
        "forces",
        "charge_gradients",
        "virial",
        "positions",
        "charges",
        "cell",
        "alpha",
    ],
)
def _ewald_real_space_energy_forces_charge_grad_matrix(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: torch.Tensor,
    neighbor_matrix: torch.Tensor,
    neighbor_matrix_shifts: torch.Tensor,
    mask_value: int,
    compute_virial: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Internal: Compute real-space Ewald E+F+charge_grad+virial (single, matrix).

    Parameters
    ----------
    positions, charges, cell, alpha : torch.Tensor
        As for ``_ewald_real_space_energy``.
    neighbor_matrix, neighbor_matrix_shifts, mask_value
        As for ``_ewald_real_space_energy_matrix``. One kernel thread is
        launched per row of ``neighbor_matrix``.
    compute_virial : bool, default False
        Forwarded to the kernel. It also controls whether the virial Warp
        array participates in gradient tracking.

    Returns
    -------
    energies : torch.Tensor, shape (N,), float64
    forces : torch.Tensor, shape (N, 3), dtype of ``positions``
    charge_grads : torch.Tensor, shape (N,), float64
    virial : torch.Tensor, shape (1, 3, 3), dtype of ``positions``

    All outputs stay zero when ``neighbor_matrix`` has no rows.

    Notes
    -----
    NOTE(review): the virial is allocated with a leading dimension of 1, while
    the ``OutputSpec`` declares ``(cell.shape[0], 3, 3)``. Confirm that
    callers always pass ``cell`` as ``(1, 3, 3)``.
    """
    num_atoms = positions.shape[0]
    input_dtype = positions.dtype
    # Only an empty row dimension skips the launch.
    empty_nm = neighbor_matrix.shape[0] == 0

    device = wp.device_from_torch(positions.device)
    needs_grad_flag = needs_grad(positions, charges, cell)
    # The virial array only needs a gradient when it is actually computed.
    virial_grad = needs_grad_flag and compute_virial

    wp_scalar = get_wp_dtype(input_dtype)
    wp_vec = get_wp_vec_dtype(input_dtype)
    wp_mat = get_wp_mat_dtype(input_dtype)

    wp_positions = warp_from_torch(positions, wp_vec, requires_grad=needs_grad_flag)
    wp_charges = warp_from_torch(charges, wp_scalar, requires_grad=needs_grad_flag)
    wp_cell = warp_from_torch(cell, wp_mat, requires_grad=needs_grad_flag)
    wp_alpha = warp_from_torch(alpha, wp_scalar, requires_grad=needs_grad_flag)
    wp_neighbor_matrix = warp_from_torch(neighbor_matrix, wp.int32)
    wp_unit_shifts_matrix = warp_from_torch(neighbor_matrix_shifts, wp.vec3i)

    # Energies and charge gradients in float64; forces/virial follow the
    # input precision.
    energies = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    forces = torch.zeros(num_atoms, 3, device=positions.device, dtype=input_dtype)
    charge_grads = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    virial = torch.zeros(1, 3, 3, device=positions.device, dtype=input_dtype)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=needs_grad_flag)
    wp_forces = warp_from_torch(forces, wp_vec, requires_grad=needs_grad_flag)
    wp_charge_grads = warp_from_torch(
        charge_grads, wp.float64, requires_grad=needs_grad_flag
    )
    wp_virial = warp_from_torch(virial, wp_mat, requires_grad=virial_grad)

    # Presumably records the launch on a Warp tape only when needs_grad_flag
    # is set (see WarpAutogradContextManager).
    with WarpAutogradContextManager(needs_grad_flag) as tape:
        if not empty_nm:
            # wp.launch binds ``inputs`` positionally to the kernel parameters,
            # so this order must match the kernel signature.
            wp.launch(
                _ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel_overload[
                    wp_scalar
                ],
                dim=[neighbor_matrix.shape[0]],
                inputs=[
                    wp_positions,
                    wp_charges,
                    wp_cell,
                    wp_neighbor_matrix,
                    wp_unit_shifts_matrix,
                    wp.int32(mask_value),
                    wp_alpha,
                    compute_virial,
                    wp_energies,
                    wp_forces,
                    wp_charge_grads,
                    wp_virial,
                ],
                device=device,
            )

    if needs_grad_flag:
        # Keyword names mirror ``grad_arrays``; the virial is registered only
        # when it carries a gradient.
        backward_kw = dict(
            energies=wp_energies,
            forces=wp_forces,
            charge_gradients=wp_charge_grads,
            positions=wp_positions,
            charges=wp_charges,
            cell=wp_cell,
            alpha=wp_alpha,
        )
        if virial_grad:
            backward_kw["virial"] = wp_virial
        attach_for_backward(energies, tape=tape, **backward_kw)
    return energies, forces, charge_grads, virial


###########################################################################################
########################### Batch Real-Space Internal Custom Ops ##########################
###########################################################################################


@warp_custom_op(
    name="alchemiops::_batch_ewald_real_space_energy",
    outputs=[OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],))],
    grad_arrays=["energies", "positions", "charges", "cell", "alpha"],
)
def _batch_ewald_real_space_energy(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: torch.Tensor,
    batch_idx: torch.Tensor,
    neighbor_list: torch.Tensor,
    neighbor_ptr: torch.Tensor,
    neighbor_shifts: torch.Tensor,
) -> torch.Tensor:
    """Internal: Compute real-space Ewald energies (batch, neighbor list CSR).

    Parameters
    ----------
    positions, charges : torch.Tensor
        Concatenated atoms of all systems. The dtype of ``positions`` selects
        the Warp types and the kernel overload.
    cell : torch.Tensor
        Unit cells, wrapped as a Warp mat33 array (presumably ``(B, 3, 3)``).
    alpha : torch.Tensor
        Ewald splitting parameter(s), wrapped with the input scalar dtype.
        Presumably one value per system -- confirm against the kernel.
    batch_idx : torch.Tensor
        Converted to int32. Presumably the system index of each atom.
    neighbor_list, neighbor_ptr, neighbor_shifts : torch.Tensor
        As for ``_ewald_real_space_energy``. Only row 1 of ``neighbor_list``
        is passed to the kernel.

    Returns
    -------
    torch.Tensor, shape (N,), float64
        Per-atom real-space energies. All zeros when the neighbor list is
        empty, because the kernel is then not launched.
    """
    num_atoms = positions.shape[0]
    input_dtype = positions.dtype
    device = wp.device_from_torch(positions.device)
    needs_grad_flag = needs_grad(positions, charges, cell)
    # With no pairs the kernel launch below is skipped entirely.
    empty_nl = neighbor_list.shape[1] == 0

    # Only the neighbor (j) row is consumed by the kernel.
    idx_j = neighbor_list[1]

    wp_scalar = get_wp_dtype(input_dtype)
    wp_vec = get_wp_vec_dtype(input_dtype)
    wp_mat = get_wp_mat_dtype(input_dtype)

    wp_positions = warp_from_torch(positions, wp_vec, requires_grad=needs_grad_flag)
    wp_charges = warp_from_torch(charges, wp_scalar, requires_grad=needs_grad_flag)
    wp_cell = warp_from_torch(cell, wp_mat, requires_grad=needs_grad_flag)
    wp_alpha = warp_from_torch(alpha, wp_scalar, requires_grad=needs_grad_flag)
    wp_batch_idx = warp_from_torch(batch_idx, wp.int32)
    wp_idx_j = warp_from_torch(idx_j, wp.int32)
    wp_neighbor_ptr = warp_from_torch(neighbor_ptr, wp.int32)
    wp_unit_shifts = warp_from_torch(neighbor_shifts, wp.vec3i)

    # Energies are accumulated in float64 regardless of input precision.
    energies = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=needs_grad_flag)

    # Presumably records the launch on a Warp tape only when needs_grad_flag
    # is set (see WarpAutogradContextManager).
    with WarpAutogradContextManager(needs_grad_flag) as tape:
        if not empty_nl:
            # wp.launch binds ``inputs`` positionally to the kernel parameters,
            # so this order must match the kernel signature.
            wp.launch(
                _batch_ewald_real_space_energy_kernel_overload[wp_scalar],
                dim=[num_atoms],
                inputs=[
                    wp_positions,
                    wp_charges,
                    wp_cell,
                    wp_batch_idx,
                    wp_idx_j,
                    wp_neighbor_ptr,
                    wp_unit_shifts,
                    wp_alpha,
                    wp_energies,
                ],
                device=device,
            )

    if needs_grad_flag:
        # Keyword names mirror ``grad_arrays`` in the decorator.
        attach_for_backward(
            energies,
            tape=tape,
            energies=wp_energies,
            positions=wp_positions,
            charges=wp_charges,
            cell=wp_cell,
            alpha=wp_alpha,
        )
    return energies


@warp_custom_op(
    name="alchemiops::_batch_ewald_real_space_energy_forces",
    outputs=[
        OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "forces",
            lambda pos, *_: get_wp_vec_dtype(pos.dtype),
            lambda pos, *_: (pos.shape[0], 3),
        ),
        OutputSpec(
            "virial",
            lambda pos, *_: get_wp_mat_dtype(pos.dtype),
            lambda pos, charges, cell, *_: (cell.shape[0], 3, 3),
        ),
    ],
    grad_arrays=[
        "energies",
        "forces",
        "virial",
        "positions",
        "charges",
        "cell",
        "alpha",
    ],
)
def _batch_ewald_real_space_energy_forces(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: torch.Tensor,
    batch_idx: torch.Tensor,
    neighbor_list: torch.Tensor,
    neighbor_ptr: torch.Tensor,
    neighbor_shifts: torch.Tensor,
    compute_virial: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Internal: real-space Ewald energies, forces and optional virial (batch, CSR).

    Parameters
    ----------
    positions, charges, cell, alpha : torch.Tensor
        Differentiable inputs. The Warp scalar/vector/matrix types are
        selected from ``positions.dtype``.
    batch_idx : torch.Tensor
        Per-atom system index, handed to the kernel as ``int32``.
    neighbor_list, neighbor_ptr, neighbor_shifts : torch.Tensor
        CSR neighbor data. Only row 1 of ``neighbor_list`` is forwarded to the
        kernel, together with the row pointer and the integer shift vectors.
    compute_virial : bool
        Forwarded to the kernel. The virial array is registered for backward
        only when this is set and gradients are required.

    Returns
    -------
    tuple of torch.Tensor
        ``(energies, forces, virial)``: ``float64`` energies of shape
        ``(num_atoms,)``, forces ``(num_atoms, 3)`` and virial
        ``(cell.shape[0], 3, 3)`` in the input dtype. All stay zero when the
        neighbor list has no columns, because the kernel launch is skipped.
    """
    has_pairs = neighbor_list.shape[1] != 0
    dtype = positions.dtype
    torch_device = positions.device
    wp_device = wp.device_from_torch(torch_device)
    n_atoms = positions.shape[0]
    n_systems = cell.shape[0]

    track_grad = needs_grad(positions, charges, cell)
    track_virial_grad = track_grad and compute_virial

    scalar_t = get_wp_dtype(dtype)
    vec_t = get_wp_vec_dtype(dtype)
    mat_t = get_wp_mat_dtype(dtype)

    # Differentiable inputs.
    wp_positions = warp_from_torch(positions, vec_t, requires_grad=track_grad)
    wp_charges = warp_from_torch(charges, scalar_t, requires_grad=track_grad)
    wp_cell = warp_from_torch(cell, mat_t, requires_grad=track_grad)
    wp_alpha = warp_from_torch(alpha, scalar_t, requires_grad=track_grad)

    # Integer topology arrays (never differentiated).
    wp_batch_idx = warp_from_torch(batch_idx, wp.int32)
    wp_neighbors_j = warp_from_torch(neighbor_list[1], wp.int32)
    wp_row_ptr = warp_from_torch(neighbor_ptr, wp.int32)
    wp_shifts = warp_from_torch(neighbor_shifts, wp.vec3i)

    # Zero-initialised outputs; returned untouched if the launch is skipped.
    energies = torch.zeros(n_atoms, device=torch_device, dtype=torch.float64)
    forces = torch.zeros(n_atoms, 3, device=torch_device, dtype=dtype)
    virial = torch.zeros(n_systems, 3, 3, device=torch_device, dtype=dtype)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=track_grad)
    wp_forces = warp_from_torch(forces, vec_t, requires_grad=track_grad)
    wp_virial = warp_from_torch(virial, mat_t, requires_grad=track_virial_grad)

    with WarpAutogradContextManager(track_grad) as tape:
        if has_pairs:
            wp.launch(
                _batch_ewald_real_space_energy_forces_kernel_overload[scalar_t],
                dim=[n_atoms],
                inputs=[
                    wp_positions,
                    wp_charges,
                    wp_cell,
                    wp_batch_idx,
                    wp_neighbors_j,
                    wp_row_ptr,
                    wp_shifts,
                    wp_alpha,
                    compute_virial,
                    wp_energies,
                    wp_forces,
                    wp_virial,
                ],
                device=wp_device,
            )

    if not track_grad:
        return energies, forces, virial

    backward_arrays = {
        "energies": wp_energies,
        "forces": wp_forces,
        "positions": wp_positions,
        "charges": wp_charges,
        "cell": wp_cell,
        "alpha": wp_alpha,
    }
    if track_virial_grad:
        backward_arrays["virial"] = wp_virial
    attach_for_backward(energies, tape=tape, **backward_arrays)
    return energies, forces, virial


@warp_custom_op(
    name="alchemiops::_batch_ewald_real_space_energy_matrix",
    outputs=[OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],))],
    grad_arrays=["energies", "positions", "charges", "cell", "alpha"],
)
def _batch_ewald_real_space_energy_matrix(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: torch.Tensor,
    batch_idx: torch.Tensor,
    neighbor_matrix: torch.Tensor,
    neighbor_matrix_shifts: torch.Tensor,
    mask_value: int,
) -> torch.Tensor:
    """Internal: real-space Ewald energies only (batch, neighbor matrix).

    Parameters
    ----------
    positions, charges, cell, alpha : torch.Tensor
        Differentiable inputs. Warp types follow ``positions.dtype``.
    batch_idx : torch.Tensor
        Per-atom system index, handed to the kernel as ``int32``.
    neighbor_matrix, neighbor_matrix_shifts : torch.Tensor
        Dense neighbor indices and matching integer shift vectors. The kernel
        is launched over ``neighbor_matrix.shape[0]`` rows.
    mask_value : int
        Sentinel forwarded to the kernel as ``wp.int32``. Presumably it marks
        padded entries (TODO confirm against the kernel).

    Returns
    -------
    torch.Tensor
        ``float64`` energies of shape ``(num_atoms,)``. They stay zero when
        ``neighbor_matrix`` has no rows.
    """
    has_rows = neighbor_matrix.shape[0] != 0
    dtype = positions.dtype
    torch_device = positions.device
    wp_device = wp.device_from_torch(torch_device)
    n_atoms = positions.shape[0]

    track_grad = needs_grad(positions, charges, cell)

    scalar_t = get_wp_dtype(dtype)
    vec_t = get_wp_vec_dtype(dtype)
    mat_t = get_wp_mat_dtype(dtype)

    # Differentiable inputs.
    wp_positions = warp_from_torch(positions, vec_t, requires_grad=track_grad)
    wp_charges = warp_from_torch(charges, scalar_t, requires_grad=track_grad)
    wp_cell = warp_from_torch(cell, mat_t, requires_grad=track_grad)
    wp_alpha = warp_from_torch(alpha, scalar_t, requires_grad=track_grad)

    # Integer topology arrays.
    wp_batch_idx = warp_from_torch(batch_idx, wp.int32)
    wp_neighbors = warp_from_torch(neighbor_matrix, wp.int32)
    wp_shifts = warp_from_torch(neighbor_matrix_shifts, wp.vec3i)

    energies = torch.zeros(n_atoms, device=torch_device, dtype=torch.float64)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=track_grad)

    with WarpAutogradContextManager(track_grad) as tape:
        if has_rows:
            wp.launch(
                _batch_ewald_real_space_energy_neighbor_matrix_kernel_overload[
                    scalar_t
                ],
                dim=[neighbor_matrix.shape[0]],
                inputs=[
                    wp_positions,
                    wp_charges,
                    wp_cell,
                    wp_batch_idx,
                    wp_neighbors,
                    wp_shifts,
                    wp.int32(mask_value),
                    wp_alpha,
                    wp_energies,
                ],
                device=wp_device,
            )

    if not track_grad:
        return energies

    backward_arrays = {
        "energies": wp_energies,
        "positions": wp_positions,
        "charges": wp_charges,
        "cell": wp_cell,
        "alpha": wp_alpha,
    }
    attach_for_backward(energies, tape=tape, **backward_arrays)
    return energies


@warp_custom_op(
    name="alchemiops::_batch_ewald_real_space_energy_forces_matrix",
    outputs=[
        OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "forces",
            lambda pos, *_: get_wp_vec_dtype(pos.dtype),
            lambda pos, *_: (pos.shape[0], 3),
        ),
        OutputSpec(
            "virial",
            lambda pos, *_: get_wp_mat_dtype(pos.dtype),
            lambda pos, charges, cell, *_: (cell.shape[0], 3, 3),
        ),
    ],
    grad_arrays=[
        "energies",
        "forces",
        "virial",
        "positions",
        "charges",
        "cell",
        "alpha",
    ],
)
def _batch_ewald_real_space_energy_forces_matrix(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: torch.Tensor,
    batch_idx: torch.Tensor,
    neighbor_matrix: torch.Tensor,
    neighbor_matrix_shifts: torch.Tensor,
    mask_value: int,
    compute_virial: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Internal: real-space Ewald energies, forces and optional virial (batch, matrix).

    Parameters
    ----------
    positions, charges, cell, alpha : torch.Tensor
        Differentiable inputs. Warp types follow ``positions.dtype``.
    batch_idx : torch.Tensor
        Per-atom system index, handed to the kernel as ``int32``.
    neighbor_matrix, neighbor_matrix_shifts : torch.Tensor
        Dense neighbor indices and integer shift vectors. The kernel is
        launched over ``neighbor_matrix.shape[0]`` rows.
    mask_value : int
        Sentinel forwarded to the kernel as ``wp.int32``. Presumably it marks
        padded entries (TODO confirm against the kernel).
    compute_virial : bool
        Forwarded to the kernel. It also controls whether the virial array is
        registered for backward.

    Returns
    -------
    tuple of torch.Tensor
        ``(energies, forces, virial)``: ``float64`` energies, and forces and
        ``(cell.shape[0], 3, 3)`` virial in the input dtype. All stay zero
        when ``neighbor_matrix`` has no rows.
    """
    has_rows = neighbor_matrix.shape[0] != 0
    dtype = positions.dtype
    torch_device = positions.device
    wp_device = wp.device_from_torch(torch_device)
    n_atoms = positions.shape[0]
    n_systems = cell.shape[0]

    track_grad = needs_grad(positions, charges, cell)
    track_virial_grad = track_grad and compute_virial

    scalar_t = get_wp_dtype(dtype)
    vec_t = get_wp_vec_dtype(dtype)
    mat_t = get_wp_mat_dtype(dtype)

    # Differentiable inputs.
    wp_positions = warp_from_torch(positions, vec_t, requires_grad=track_grad)
    wp_charges = warp_from_torch(charges, scalar_t, requires_grad=track_grad)
    wp_cell = warp_from_torch(cell, mat_t, requires_grad=track_grad)
    wp_alpha = warp_from_torch(alpha, scalar_t, requires_grad=track_grad)

    # Integer topology arrays.
    wp_batch_idx = warp_from_torch(batch_idx, wp.int32)
    wp_neighbors = warp_from_torch(neighbor_matrix, wp.int32)
    wp_shifts = warp_from_torch(neighbor_matrix_shifts, wp.vec3i)

    # Zero-initialised outputs; returned untouched if the launch is skipped.
    energies = torch.zeros(n_atoms, device=torch_device, dtype=torch.float64)
    forces = torch.zeros(n_atoms, 3, device=torch_device, dtype=dtype)
    virial = torch.zeros(n_systems, 3, 3, device=torch_device, dtype=dtype)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=track_grad)
    wp_forces = warp_from_torch(forces, vec_t, requires_grad=track_grad)
    wp_virial = warp_from_torch(virial, mat_t, requires_grad=track_virial_grad)

    with WarpAutogradContextManager(track_grad) as tape:
        if has_rows:
            wp.launch(
                _batch_ewald_real_space_energy_forces_neighbor_matrix_kernel_overload[
                    scalar_t
                ],
                dim=[neighbor_matrix.shape[0]],
                inputs=[
                    wp_positions,
                    wp_charges,
                    wp_cell,
                    wp_batch_idx,
                    wp_neighbors,
                    wp_shifts,
                    wp.int32(mask_value),
                    wp_alpha,
                    compute_virial,
                    wp_energies,
                    wp_forces,
                    wp_virial,
                ],
                device=wp_device,
            )

    if not track_grad:
        return energies, forces, virial

    backward_arrays = {
        "energies": wp_energies,
        "forces": wp_forces,
        "positions": wp_positions,
        "charges": wp_charges,
        "cell": wp_cell,
        "alpha": wp_alpha,
    }
    if track_virial_grad:
        backward_arrays["virial"] = wp_virial
    attach_for_backward(energies, tape=tape, **backward_arrays)
    return energies, forces, virial


###########################################################################################
################ Batch Real-Space with Charge Gradients Internal Custom Ops ###############
###########################################################################################


@warp_custom_op(
    name="alchemiops::_batch_ewald_real_space_energy_forces_charge_grad",
    outputs=[
        OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "forces",
            lambda pos, *_: get_wp_vec_dtype(pos.dtype),
            lambda pos, *_: (pos.shape[0], 3),
        ),
        OutputSpec("charge_gradients", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "virial",
            lambda pos, *_: get_wp_mat_dtype(pos.dtype),
            lambda pos, charges, cell, *_: (cell.shape[0], 3, 3),
        ),
    ],
    grad_arrays=[
        "energies",
        "forces",
        "charge_gradients",
        "virial",
        "positions",
        "charges",
        "cell",
        "alpha",
    ],
)
def _batch_ewald_real_space_energy_forces_charge_grad(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: torch.Tensor,
    batch_idx: torch.Tensor,
    neighbor_list: torch.Tensor,
    neighbor_ptr: torch.Tensor,
    neighbor_shifts: torch.Tensor,
    compute_virial: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Internal: real-space Ewald E+F+charge_grad+virial (batch, CSR).

    Parameters
    ----------
    positions, charges, cell, alpha : torch.Tensor
        Differentiable inputs. Warp types follow ``positions.dtype``.
    batch_idx : torch.Tensor
        Per-atom system index, handed to the kernel as ``int32``.
    neighbor_list, neighbor_ptr, neighbor_shifts : torch.Tensor
        CSR neighbor data. Only row 1 of ``neighbor_list`` goes to the kernel.
    compute_virial : bool
        Forwarded to the kernel. It also controls whether the virial array is
        registered for backward.

    Returns
    -------
    tuple of torch.Tensor
        ``(energies, forces, charge_gradients, virial)``. Energies and charge
        gradients are ``float64`` of shape ``(num_atoms,)``. Forces and the
        ``(cell.shape[0], 3, 3)`` virial use the input dtype. All stay zero
        when the neighbor list has no columns.
    """
    has_pairs = neighbor_list.shape[1] != 0
    dtype = positions.dtype
    torch_device = positions.device
    wp_device = wp.device_from_torch(torch_device)
    n_atoms = positions.shape[0]
    n_systems = cell.shape[0]

    track_grad = needs_grad(positions, charges, cell)
    track_virial_grad = track_grad and compute_virial

    scalar_t = get_wp_dtype(dtype)
    vec_t = get_wp_vec_dtype(dtype)
    mat_t = get_wp_mat_dtype(dtype)

    # Differentiable inputs.
    wp_positions = warp_from_torch(positions, vec_t, requires_grad=track_grad)
    wp_charges = warp_from_torch(charges, scalar_t, requires_grad=track_grad)
    wp_cell = warp_from_torch(cell, mat_t, requires_grad=track_grad)
    wp_alpha = warp_from_torch(alpha, scalar_t, requires_grad=track_grad)

    # Integer topology arrays.
    wp_batch_idx = warp_from_torch(batch_idx, wp.int32)
    wp_neighbors_j = warp_from_torch(neighbor_list[1], wp.int32)
    wp_row_ptr = warp_from_torch(neighbor_ptr, wp.int32)
    wp_shifts = warp_from_torch(neighbor_shifts, wp.vec3i)

    # Zero-initialised outputs; returned untouched if the launch is skipped.
    energies = torch.zeros(n_atoms, device=torch_device, dtype=torch.float64)
    forces = torch.zeros(n_atoms, 3, device=torch_device, dtype=dtype)
    charge_grads = torch.zeros(n_atoms, device=torch_device, dtype=torch.float64)
    virial = torch.zeros(n_systems, 3, 3, device=torch_device, dtype=dtype)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=track_grad)
    wp_forces = warp_from_torch(forces, vec_t, requires_grad=track_grad)
    wp_charge_grads = warp_from_torch(
        charge_grads, wp.float64, requires_grad=track_grad
    )
    wp_virial = warp_from_torch(virial, mat_t, requires_grad=track_virial_grad)

    with WarpAutogradContextManager(track_grad) as tape:
        if has_pairs:
            wp.launch(
                _batch_ewald_real_space_energy_forces_charge_grad_kernel_overload[
                    scalar_t
                ],
                dim=[n_atoms],
                inputs=[
                    wp_positions,
                    wp_charges,
                    wp_cell,
                    wp_batch_idx,
                    wp_neighbors_j,
                    wp_row_ptr,
                    wp_shifts,
                    wp_alpha,
                    compute_virial,
                    wp_energies,
                    wp_forces,
                    wp_charge_grads,
                    wp_virial,
                ],
                device=wp_device,
            )

    if not track_grad:
        return energies, forces, charge_grads, virial

    backward_arrays = {
        "energies": wp_energies,
        "forces": wp_forces,
        "charge_gradients": wp_charge_grads,
        "positions": wp_positions,
        "charges": wp_charges,
        "cell": wp_cell,
        "alpha": wp_alpha,
    }
    if track_virial_grad:
        backward_arrays["virial"] = wp_virial
    attach_for_backward(energies, tape=tape, **backward_arrays)
    return energies, forces, charge_grads, virial


@warp_custom_op(
    name="alchemiops::_batch_ewald_real_space_energy_forces_charge_grad_matrix",
    outputs=[
        OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "forces",
            lambda pos, *_: get_wp_vec_dtype(pos.dtype),
            lambda pos, *_: (pos.shape[0], 3),
        ),
        OutputSpec("charge_gradients", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "virial",
            lambda pos, *_: get_wp_mat_dtype(pos.dtype),
            lambda pos, charges, cell, *_: (cell.shape[0], 3, 3),
        ),
    ],
    grad_arrays=[
        "energies",
        "forces",
        "charge_gradients",
        "virial",
        "positions",
        "charges",
        "cell",
        "alpha",
    ],
)
def _batch_ewald_real_space_energy_forces_charge_grad_matrix(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: torch.Tensor,
    batch_idx: torch.Tensor,
    neighbor_matrix: torch.Tensor,
    neighbor_matrix_shifts: torch.Tensor,
    mask_value: int,
    compute_virial: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Internal: real-space Ewald E+F+charge_grad+virial (batch, matrix).

    Parameters
    ----------
    positions, charges, cell, alpha : torch.Tensor
        Differentiable inputs. Warp types follow ``positions.dtype``.
    batch_idx : torch.Tensor
        Per-atom system index, handed to the kernel as ``int32``.
    neighbor_matrix, neighbor_matrix_shifts : torch.Tensor
        Dense neighbor indices and integer shift vectors. The kernel is
        launched over ``neighbor_matrix.shape[0]`` rows.
    mask_value : int
        Sentinel forwarded to the kernel as ``wp.int32``. Presumably it marks
        padded entries (TODO confirm against the kernel).
    compute_virial : bool
        Forwarded to the kernel. It also controls whether the virial array is
        registered for backward.

    Returns
    -------
    tuple of torch.Tensor
        ``(energies, forces, charge_gradients, virial)``. Energies and charge
        gradients are ``float64``. Forces and the ``(cell.shape[0], 3, 3)``
        virial use the input dtype. All stay zero when ``neighbor_matrix``
        has no rows.
    """
    has_rows = neighbor_matrix.shape[0] != 0
    dtype = positions.dtype
    torch_device = positions.device
    wp_device = wp.device_from_torch(torch_device)
    n_atoms = positions.shape[0]
    n_systems = cell.shape[0]

    track_grad = needs_grad(positions, charges, cell)
    track_virial_grad = track_grad and compute_virial

    scalar_t = get_wp_dtype(dtype)
    vec_t = get_wp_vec_dtype(dtype)
    mat_t = get_wp_mat_dtype(dtype)

    # Differentiable inputs.
    wp_positions = warp_from_torch(positions, vec_t, requires_grad=track_grad)
    wp_charges = warp_from_torch(charges, scalar_t, requires_grad=track_grad)
    wp_cell = warp_from_torch(cell, mat_t, requires_grad=track_grad)
    wp_alpha = warp_from_torch(alpha, scalar_t, requires_grad=track_grad)

    # Integer topology arrays.
    wp_batch_idx = warp_from_torch(batch_idx, wp.int32)
    wp_neighbors = warp_from_torch(neighbor_matrix, wp.int32)
    wp_shifts = warp_from_torch(neighbor_matrix_shifts, wp.vec3i)

    # Zero-initialised outputs; returned untouched if the launch is skipped.
    energies = torch.zeros(n_atoms, device=torch_device, dtype=torch.float64)
    forces = torch.zeros(n_atoms, 3, device=torch_device, dtype=dtype)
    charge_grads = torch.zeros(n_atoms, device=torch_device, dtype=torch.float64)
    virial = torch.zeros(n_systems, 3, 3, device=torch_device, dtype=dtype)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=track_grad)
    wp_forces = warp_from_torch(forces, vec_t, requires_grad=track_grad)
    wp_charge_grads = warp_from_torch(
        charge_grads, wp.float64, requires_grad=track_grad
    )
    wp_virial = warp_from_torch(virial, mat_t, requires_grad=track_virial_grad)

    with WarpAutogradContextManager(track_grad) as tape:
        if has_rows:
            wp.launch(
                _batch_ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel_overload[
                    scalar_t
                ],
                dim=[neighbor_matrix.shape[0]],
                inputs=[
                    wp_positions,
                    wp_charges,
                    wp_cell,
                    wp_batch_idx,
                    wp_neighbors,
                    wp_shifts,
                    wp.int32(mask_value),
                    wp_alpha,
                    compute_virial,
                    wp_energies,
                    wp_forces,
                    wp_charge_grads,
                    wp_virial,
                ],
                device=wp_device,
            )

    if not track_grad:
        return energies, forces, charge_grads, virial

    backward_arrays = {
        "energies": wp_energies,
        "forces": wp_forces,
        "charge_gradients": wp_charge_grads,
        "positions": wp_positions,
        "charges": wp_charges,
        "cell": wp_cell,
        "alpha": wp_alpha,
    }
    if track_virial_grad:
        backward_arrays["virial"] = wp_virial
    attach_for_backward(energies, tape=tape, **backward_arrays)
    return energies, forces, charge_grads, virial


###########################################################################################
########################### Reciprocal-Space Internal Custom Ops ##########################
###########################################################################################


@warp_custom_op(
    name="alchemiops::_ewald_reciprocal_space_energy",
    outputs=[
        OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "virial",
            lambda pos, *_: get_wp_mat_dtype(pos.dtype),
            lambda pos, charges, cell, *_: (cell.shape[0], 3, 3),
        ),
    ],
    grad_arrays=[
        "energies",
        "virial",
        "positions",
        "charges",
        "cell",
        "k_vectors",
        "alpha",
    ],
)
def _ewald_reciprocal_space_energy(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    k_vectors: torch.Tensor,
    alpha: torch.Tensor,
    compute_virial: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Internal: reciprocal-space Ewald energies and optional virial (single system).

    Three Warp kernels are launched in order:

    1. Fill the structure factors over ``num_k`` k-vectors, also writing the
       per-(k, atom) cos/sin terms and an accumulated total charge.
    2. Compute raw per-atom energies from those structure factors.
    3. Apply ``_ewald_subtract_self_energy_kernel`` to obtain the final
       energies.

    A fourth kernel fills the virial when ``compute_virial`` is set.

    Parameters
    ----------
    positions, charges, cell, alpha : torch.Tensor
        Differentiable inputs. Warp types follow ``positions.dtype``. The
        virial path uses ``|det(cell[0])|`` as the volume.
    k_vectors : torch.Tensor
        Reciprocal vectors, cast to ``positions.dtype`` before conversion.
    compute_virial : bool
        Whether to launch the virial kernel.

    Returns
    -------
    tuple of torch.Tensor
        ``(energies, virial)``: ``float64`` energies of shape ``(num_atoms,)``
        and a ``(1, 3, 3)`` virial in the input dtype. The virial is all zeros
        unless ``compute_virial`` is set.
    """
    dtype = positions.dtype
    torch_device = positions.device
    wp_device = wp.device_from_torch(torch_device)
    n_k = k_vectors.shape[0]
    n_atoms = positions.shape[0]

    track_grad = needs_grad(positions, charges, cell)
    track_virial_grad = track_grad and compute_virial

    scalar_t = get_wp_dtype(dtype)
    vec_t = get_wp_vec_dtype(dtype)
    mat_t = get_wp_mat_dtype(dtype)

    # Differentiable inputs.
    wp_positions = warp_from_torch(positions, vec_t, requires_grad=track_grad)
    wp_charges = warp_from_torch(charges, scalar_t, requires_grad=track_grad)
    wp_cell = warp_from_torch(cell, mat_t, requires_grad=track_grad)
    k_vectors_cast = k_vectors.to(dtype)
    wp_k_vectors = warp_from_torch(k_vectors_cast, vec_t, requires_grad=track_grad)
    wp_alpha = warp_from_torch(alpha, scalar_t, requires_grad=track_grad)

    def f64_scratch(*shape: int):
        # Zero-filled float64 intermediate, grad-tracked like the inputs.
        buf = torch.zeros(shape, device=torch_device, dtype=torch.float64)
        return warp_from_torch(buf, wp.float64, requires_grad=track_grad)

    wp_cos_k_dot_r = f64_scratch(n_k, n_atoms)
    wp_sin_k_dot_r = f64_scratch(n_k, n_atoms)
    wp_real_sf = f64_scratch(n_k)
    wp_imag_sf = f64_scratch(n_k)
    wp_total_charge = f64_scratch(1)
    wp_raw_energies = f64_scratch(n_atoms)

    energies = torch.zeros(n_atoms, device=torch_device, dtype=torch.float64)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=track_grad)

    # Returned as zeros when the virial kernel is not launched.
    virial = torch.zeros(1, 3, 3, device=torch_device, dtype=dtype)
    wp_virial = None

    with WarpAutogradContextManager(track_grad) as tape:
        wp.launch(
            _ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload[
                scalar_t
            ],
            dim=n_k,
            inputs=[wp_positions, wp_charges, wp_k_vectors, wp_cell, wp_alpha],
            outputs=[
                wp_total_charge,
                wp_cos_k_dot_r,
                wp_sin_k_dot_r,
                wp_real_sf,
                wp_imag_sf,
            ],
            device=wp_device,
        )
        wp.launch(
            _ewald_reciprocal_space_energy_kernel_compute_energy_overload[scalar_t],
            dim=n_atoms,
            inputs=[wp_charges, wp_cos_k_dot_r, wp_sin_k_dot_r, wp_real_sf, wp_imag_sf],
            outputs=[wp_raw_energies],
            device=wp_device,
        )
        wp.launch(
            _ewald_subtract_self_energy_kernel_overload[scalar_t],
            dim=n_atoms,
            inputs=[wp_charges, wp_alpha, wp_total_charge, wp_raw_energies],
            outputs=[wp_energies],
            device=wp_device,
        )
        if compute_virial:
            wp_virial = warp_from_torch(virial, mat_t, requires_grad=track_virial_grad)
            cell_volume = torch.det(cell[0].to(torch.float64)).abs().view(1)
            wp.launch(
                _ewald_reciprocal_space_virial_kernel_overload[scalar_t],
                dim=n_k,
                inputs=[
                    wp_k_vectors,
                    wp_alpha,
                    warp_from_torch(cell_volume, wp.float64),
                    wp_real_sf,
                    wp_imag_sf,
                    wp_virial,
                ],
                device=wp_device,
            )

    if not track_grad:
        return energies, virial

    backward_arrays = {
        "energies": wp_energies,
        "positions": wp_positions,
        "charges": wp_charges,
        "cell": wp_cell,
        "k_vectors": wp_k_vectors,
        "alpha": wp_alpha,
    }
    # track_virial_grad implies compute_virial, so wp_virial is set here.
    if track_virial_grad:
        backward_arrays["virial"] = wp_virial
    attach_for_backward(energies, tape=tape, **backward_arrays)
    return energies, virial


@warp_custom_op(
    name="alchemiops::_ewald_reciprocal_space_energy_forces",
    outputs=[
        OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "forces",
            lambda pos, *_: get_wp_vec_dtype(pos.dtype),
            lambda pos, *_: (pos.shape[0], 3),
        ),
        OutputSpec(
            "virial",
            lambda pos, *_: get_wp_mat_dtype(pos.dtype),
            lambda pos, charges, cell, *_: (cell.shape[0], 3, 3),
        ),
    ],
    grad_arrays=[
        "energies",
        "forces",
        "virial",
        "positions",
        "charges",
        "cell",
        "k_vectors",
        "alpha",
    ],
)
def _ewald_reciprocal_space_energy_forces(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    k_vectors: torch.Tensor,
    alpha: torch.Tensor,
    compute_virial: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Internal: Compute reciprocal-space Ewald energies, forces, and optionally virial (single).

    Pipeline of Warp kernels:

    1. Fill structure factors over the k-vectors, writing cos/sin terms and
       an accumulated total charge.
    2. Compute raw per-atom energies and forces from the structure factors.
    3. Apply ``_ewald_subtract_self_energy_kernel`` to produce the final
       energies.
    4. Optionally, run the virial kernel, using ``|det(cell[0])|`` as the
       volume.

    Parameters
    ----------
    positions, charges, cell, alpha : torch.Tensor
        Differentiable inputs. Warp types follow ``positions.dtype``.
    k_vectors : torch.Tensor
        Reciprocal vectors, cast to ``positions.dtype`` before conversion.
    compute_virial : bool
        Whether to launch the virial kernel.

    Returns
    -------
    tuple of torch.Tensor
        ``(energies, forces, virial)``: ``float64`` energies of shape
        ``(num_atoms,)``, forces ``(num_atoms, 3)`` and a ``(1, 3, 3)``
        virial in the input dtype. When there are no k-vectors or no atoms,
        zero tensors with these same dtypes are returned without launching
        any kernel.
    """
    num_k = k_vectors.shape[0]
    num_atoms = positions.shape[0]
    input_dtype = positions.dtype

    if num_k == 0 or num_atoms == 0:
        # Energies must be float64 here as well. Both the declared OutputSpec
        # and the normal path produce float64, so returning input_dtype would
        # make the output dtype depend on whether the inputs are empty.
        return (
            torch.zeros(num_atoms, device=positions.device, dtype=torch.float64),
            torch.zeros(num_atoms, 3, device=positions.device, dtype=input_dtype),
            torch.zeros(1, 3, 3, device=positions.device, dtype=input_dtype),
        )

    device = wp.device_from_torch(positions.device)
    needs_grad_flag = needs_grad(positions, charges, cell)
    virial_grad = needs_grad_flag and compute_virial

    wp_scalar = get_wp_dtype(input_dtype)
    wp_vec = get_wp_vec_dtype(input_dtype)
    wp_mat = get_wp_mat_dtype(input_dtype)

    wp_positions = warp_from_torch(positions, wp_vec, requires_grad=needs_grad_flag)
    wp_charges = warp_from_torch(charges, wp_scalar, requires_grad=needs_grad_flag)
    wp_cell = warp_from_torch(cell, wp_mat, requires_grad=needs_grad_flag)
    k_vectors_typed = k_vectors.to(input_dtype)
    wp_k_vectors = warp_from_torch(
        k_vectors_typed, wp_vec, requires_grad=needs_grad_flag
    )
    wp_alpha = warp_from_torch(alpha, wp_scalar, requires_grad=needs_grad_flag)

    # Intermediate arrays (float64, grad-tracked like the inputs).
    wp_cos_k_dot_r = warp_from_torch(
        torch.zeros((num_k, num_atoms), device=positions.device, dtype=torch.float64),
        wp.float64,
        requires_grad=needs_grad_flag,
    )
    wp_sin_k_dot_r = warp_from_torch(
        torch.zeros((num_k, num_atoms), device=positions.device, dtype=torch.float64),
        wp.float64,
        requires_grad=needs_grad_flag,
    )
    real_sf = torch.zeros(num_k, device=positions.device, dtype=torch.float64)
    imag_sf = torch.zeros(num_k, device=positions.device, dtype=torch.float64)
    wp_real_sf = warp_from_torch(real_sf, wp.float64, requires_grad=needs_grad_flag)
    wp_imag_sf = warp_from_torch(imag_sf, wp.float64, requires_grad=needs_grad_flag)
    total_charge = torch.zeros(1, device=positions.device, dtype=torch.float64)
    wp_total_charge = warp_from_torch(
        total_charge, wp.float64, requires_grad=needs_grad_flag
    )
    wp_raw_energies = warp_from_torch(
        torch.zeros(num_atoms, device=positions.device, dtype=torch.float64),
        wp.float64,
        requires_grad=needs_grad_flag,
    )

    # Outputs.
    energies = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    forces = torch.zeros(num_atoms, 3, device=positions.device, dtype=input_dtype)
    # Stays zero unless the virial kernel is launched below.
    virial = torch.zeros(1, 3, 3, device=positions.device, dtype=input_dtype)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=needs_grad_flag)
    wp_forces = warp_from_torch(forces, wp_vec, requires_grad=needs_grad_flag)

    wp_virial = None
    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            _ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload[
                wp_scalar
            ],
            dim=num_k,
            inputs=[wp_positions, wp_charges, wp_k_vectors, wp_cell, wp_alpha],
            outputs=[
                wp_total_charge,
                wp_cos_k_dot_r,
                wp_sin_k_dot_r,
                wp_real_sf,
                wp_imag_sf,
            ],
            device=device,
        )
        wp.launch(
            _ewald_reciprocal_space_energy_forces_kernel_overload[wp_scalar],
            dim=num_atoms,
            inputs=[
                wp_charges,
                wp_k_vectors,
                wp_cos_k_dot_r,
                wp_sin_k_dot_r,
                wp_real_sf,
                wp_imag_sf,
            ],
            outputs=[wp_raw_energies, wp_forces],
            device=device,
        )
        wp.launch(
            _ewald_subtract_self_energy_kernel_overload[wp_scalar],
            dim=num_atoms,
            inputs=[wp_charges, wp_alpha, wp_total_charge, wp_raw_energies],
            outputs=[wp_energies],
            device=device,
        )
        if compute_virial:
            wp_virial = warp_from_torch(virial, wp_mat, requires_grad=virial_grad)
            volume = torch.abs(torch.det(cell[0].to(torch.float64))).view(1)
            wp_volume = warp_from_torch(volume, wp.float64)
            wp.launch(
                _ewald_reciprocal_space_virial_kernel_overload[wp_scalar],
                dim=num_k,
                inputs=[
                    wp_k_vectors,
                    wp_alpha,
                    wp_volume,
                    wp_real_sf,
                    wp_imag_sf,
                    wp_virial,
                ],
                device=device,
            )

    if needs_grad_flag:
        backward_kw = dict(
            energies=wp_energies,
            forces=wp_forces,
            positions=wp_positions,
            charges=wp_charges,
            cell=wp_cell,
            k_vectors=wp_k_vectors,
            alpha=wp_alpha,
        )
        # virial_grad implies compute_virial, so wp_virial is always set here.
        if virial_grad:
            backward_kw["virial"] = wp_virial
        attach_for_backward(energies, tape=tape, **backward_kw)
    return energies, forces, virial


@warp_custom_op(
    name="alchemiops::_ewald_reciprocal_space_energy_forces_charge_grad",
    outputs=[
        OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "forces",
            lambda pos, *_: get_wp_vec_dtype(pos.dtype),
            lambda pos, *_: (pos.shape[0], 3),
        ),
        OutputSpec("charge_gradients", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "virial",
            lambda pos, *_: get_wp_mat_dtype(pos.dtype),
            lambda pos, charges, cell, *_: (cell.shape[0], 3, 3),
        ),
    ],
    grad_arrays=[
        "energies",
        "forces",
        "charge_gradients",
        "virial",
        "positions",
        "charges",
        "cell",
        "k_vectors",
        "alpha",
    ],
)
def _ewald_reciprocal_space_energy_forces_charge_grad(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    k_vectors: torch.Tensor,
    alpha: torch.Tensor,
    compute_virial: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Internal: Compute reciprocal-space Ewald E+F+charge_grad+virial (single).

    Launches, in order:

    1. A structure-factor fill kernel over the k-vectors. It produces the
       total charge, per-(k, atom) cos/sin of k·r, and the real/imaginary
       structure factors.
    2. A per-atom kernel that produces raw energies, forces and raw charge
       gradients.
    3. A per-atom kernel that subtracts the self-energy term from the raw
       energies.
    4. Optionally, a per-k virial kernel.

    The self-energy and background corrections to the charge gradients are
    then applied on the torch side.

    Parameters
    ----------
    positions : torch.Tensor, shape (N, 3)
        Atomic coordinates. Their dtype selects the Warp scalar/vec/mat types.
    charges : torch.Tensor, shape (N,)
        Atomic partial charges.
    cell : torch.Tensor
        Unit cell. Only ``cell[0]`` is used for the volume, so presumably
        shape (1, 3, 3).
    k_vectors : torch.Tensor, shape (K, 3)
        Reciprocal-space vectors. They are cast to the positions dtype.
    alpha : torch.Tensor
        Ewald splitting parameter. ``alpha[0]`` is used for the charge-gradient
        corrections.
    compute_virial : bool, default=False
        Whether to run the reciprocal-space virial kernel.

    Returns
    -------
    energies : torch.Tensor, shape (N,), float64
        Per-atom reciprocal-space energy with the self-energy term subtracted.
    forces : torch.Tensor, shape (N, 3), input dtype
    charge_gradients : torch.Tensor, shape (N,), input dtype
        Charge gradients after the self-energy and background corrections.
    virial : torch.Tensor, shape (1, 3, 3), input dtype
        Zeros unless ``compute_virial`` is True.
    """
    num_k = k_vectors.shape[0]
    num_atoms = positions.shape[0]
    device = wp.device_from_torch(positions.device)
    input_dtype = positions.dtype

    if num_k == 0 or num_atoms == 0:
        # Energies are float64 on the regular path and in the ``energies``
        # OutputSpec; keep the degenerate path consistent with that.
        return (
            torch.zeros(num_atoms, device=positions.device, dtype=torch.float64),
            torch.zeros(num_atoms, 3, device=positions.device, dtype=input_dtype),
            torch.zeros(num_atoms, device=positions.device, dtype=input_dtype),
            torch.zeros(1, 3, 3, device=positions.device, dtype=input_dtype),
        )

    needs_grad_flag = needs_grad(positions, charges, cell)
    # The virial only takes part in the backward pass when it was both
    # requested and some input requires grad.
    virial_grad = needs_grad_flag and compute_virial

    wp_scalar = get_wp_dtype(input_dtype)
    wp_vec = get_wp_vec_dtype(input_dtype)
    wp_mat = get_wp_mat_dtype(input_dtype)

    wp_positions = warp_from_torch(positions, wp_vec, requires_grad=needs_grad_flag)
    wp_charges = warp_from_torch(charges, wp_scalar, requires_grad=needs_grad_flag)
    wp_cell = warp_from_torch(cell, wp_mat, requires_grad=needs_grad_flag)
    k_vectors_typed = k_vectors.to(input_dtype)
    wp_k_vectors = warp_from_torch(
        k_vectors_typed, wp_vec, requires_grad=needs_grad_flag
    )
    wp_alpha = warp_from_torch(alpha, wp_scalar, requires_grad=needs_grad_flag)

    # Intermediate arrays (float64 accumulation buffers)
    cos_k_dot_r = torch.zeros(
        num_k, num_atoms, device=positions.device, dtype=torch.float64
    )
    sin_k_dot_r = torch.zeros(
        num_k, num_atoms, device=positions.device, dtype=torch.float64
    )
    real_sf = torch.zeros(num_k, device=positions.device, dtype=torch.float64)
    imag_sf = torch.zeros(num_k, device=positions.device, dtype=torch.float64)
    wp_cos_k_dot_r = warp_from_torch(
        cos_k_dot_r, wp.float64, requires_grad=needs_grad_flag
    )
    wp_sin_k_dot_r = warp_from_torch(
        sin_k_dot_r, wp.float64, requires_grad=needs_grad_flag
    )
    wp_real_sf = warp_from_torch(real_sf, wp.float64, requires_grad=needs_grad_flag)
    wp_imag_sf = warp_from_torch(imag_sf, wp.float64, requires_grad=needs_grad_flag)
    # Filled by the structure-factor kernel; read back below for the
    # background correction of the charge gradients.
    total_charge = torch.zeros(1, device=positions.device, dtype=torch.float64)
    wp_total_charge = warp_from_torch(
        total_charge, wp.float64, requires_grad=needs_grad_flag
    )
    wp_raw_energies = warp_from_torch(
        torch.zeros(num_atoms, device=positions.device, dtype=torch.float64),
        wp.float64,
        requires_grad=needs_grad_flag,
    )
    energies = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    forces = torch.zeros(num_atoms, 3, device=positions.device, dtype=input_dtype)
    charge_grads = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=needs_grad_flag)
    wp_forces = warp_from_torch(forces, wp_vec, requires_grad=needs_grad_flag)
    wp_charge_grads = warp_from_torch(
        charge_grads, wp.float64, requires_grad=needs_grad_flag
    )

    wp_virial = None
    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            _ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload[
                wp_scalar
            ],
            dim=num_k,
            inputs=[wp_positions, wp_charges, wp_k_vectors, wp_cell, wp_alpha],
            outputs=[
                wp_total_charge,
                wp_cos_k_dot_r,
                wp_sin_k_dot_r,
                wp_real_sf,
                wp_imag_sf,
            ],
            device=device,
        )
        wp.launch(
            _ewald_reciprocal_space_energy_forces_charge_grad_kernel_overload[
                wp_scalar
            ],
            dim=num_atoms,
            inputs=[
                wp_charges,
                wp_k_vectors,
                wp_cos_k_dot_r,
                wp_sin_k_dot_r,
                wp_real_sf,
                wp_imag_sf,
            ],
            outputs=[wp_raw_energies, wp_forces, wp_charge_grads],
            device=device,
        )
        wp.launch(
            _ewald_subtract_self_energy_kernel_overload[wp_scalar],
            dim=num_atoms,
            inputs=[wp_charges, wp_alpha, wp_total_charge, wp_raw_energies],
            outputs=[wp_energies],
            device=device,
        )
        if compute_virial:
            virial = torch.zeros(1, 3, 3, device=positions.device, dtype=input_dtype)
            wp_virial = warp_from_torch(virial, wp_mat, requires_grad=virial_grad)
            volume = torch.abs(torch.det(cell[0].to(torch.float64))).view(1)
            wp_volume = warp_from_torch(volume, wp.float64)
            wp.launch(
                _ewald_reciprocal_space_virial_kernel_overload[wp_scalar],
                dim=num_k,
                inputs=[
                    wp_k_vectors,
                    wp_alpha,
                    wp_volume,
                    wp_real_sf,
                    wp_imag_sf,
                    wp_virial,
                ],
                device=device,
            )
        else:
            virial = torch.zeros(1, 3, 3, device=positions.device, dtype=input_dtype)

    # Apply self-energy and background corrections to charge gradients.
    # Use detached tensor math instead of ``alpha[0].item()``. This avoids a
    # host-device sync, matches the batched implementation, and keeps the
    # original's "alpha is treated as a constant here" semantics.
    alpha_0 = alpha[0].detach().to(torch.float64)
    self_energy_grad = 2.0 / math.sqrt(math.pi) * alpha_0 * charges
    # NOTE(review): no explicit 1/V factor appears here; confirm this matches
    # the convention used by the self-energy/background kernel.
    background_grad = math.pi / (alpha_0 * alpha_0) * total_charge[0]
    charge_grads = charge_grads - self_energy_grad - background_grad

    if needs_grad_flag:
        backward_kw = dict(
            energies=wp_energies,
            forces=wp_forces,
            charge_gradients=wp_charge_grads,
            positions=wp_positions,
            charges=wp_charges,
            cell=wp_cell,
            k_vectors=wp_k_vectors,
            alpha=wp_alpha,
        )
        if virial_grad and wp_virial is not None:
            backward_kw["virial"] = wp_virial
        attach_for_backward(energies, tape=tape, **backward_kw)
    return energies, forces, charge_grads.to(input_dtype), virial


###########################################################################################
########################### Batch Reciprocal-Space Internal Custom Ops ####################
###########################################################################################


@warp_custom_op(
    name="alchemiops::_batch_ewald_reciprocal_space_energy",
    outputs=[
        OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "virial",
            lambda pos, *_: get_wp_mat_dtype(pos.dtype),
            lambda pos, charges, cell, *_: (cell.shape[0], 3, 3),
        ),
    ],
    grad_arrays=[
        "energies",
        "virial",
        "positions",
        "charges",
        "cell",
        "k_vectors",
        "alpha",
    ],
)
def _batch_ewald_reciprocal_space_energy(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    k_vectors: torch.Tensor,
    alpha: torch.Tensor,
    batch_idx: torch.Tensor,
    compute_virial: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Internal: Compute reciprocal-space Ewald energies and optionally virial (batch).

    Per-system atom ranges ``[atom_start, atom_end)`` are derived from a
    cumulative count of ``batch_idx``. This assumes atoms of each system are
    stored contiguously and ordered by system index.

    Parameters
    ----------
    positions : torch.Tensor, shape (N, 3)
        Atomic coordinates of all systems concatenated.
    charges : torch.Tensor, shape (N,)
        Atomic partial charges.
    cell : torch.Tensor, shape (B, 3, 3)
        Per-system unit cells. ``cell.shape[0]`` is the number of systems.
    k_vectors : torch.Tensor
        Per-system reciprocal-space vectors. ``k_vectors.shape[1]`` is the
        number of k-vectors, presumably shape (B, K, 3). Cast to the
        positions dtype.
    alpha : torch.Tensor
        Ewald splitting parameter(s), presumably one per system.
    batch_idx : torch.Tensor, shape (N,)
        System index of each atom. Passed to Warp as int32; presumably it
        must already be int32.
    compute_virial : bool, default=False
        Whether to run the reciprocal-space virial kernel.

    Returns
    -------
    energies : torch.Tensor, shape (N,), float64
        Per-atom reciprocal-space energy with the self-energy term subtracted.
    virial : torch.Tensor, shape (B, 3, 3), input dtype
        Zeros unless ``compute_virial`` is True.
    """
    num_k = k_vectors.shape[1]
    num_atoms = positions.shape[0]
    num_systems = cell.shape[0]
    device = wp.device_from_torch(positions.device)
    input_dtype = positions.dtype

    if num_k == 0 or num_atoms == 0:
        # Energies are float64 on the regular path and in the ``energies``
        # OutputSpec; keep the degenerate path consistent with that.
        return (
            torch.zeros(num_atoms, device=positions.device, dtype=torch.float64),
            torch.zeros(num_systems, 3, 3, device=positions.device, dtype=input_dtype),
        )

    wp_scalar = get_wp_dtype(input_dtype)
    wp_vec = get_wp_vec_dtype(input_dtype)
    wp_mat = get_wp_mat_dtype(input_dtype)

    # Contiguous atom range per system, plus the launch width needed to cover
    # the largest system in blocks of BATCH_BLOCK_SIZE atoms.
    atom_counts = torch.bincount(batch_idx, minlength=num_systems)
    atom_end = torch.cumsum(atom_counts, dim=0).to(torch.int32)
    atom_start = torch.cat(
        [torch.zeros(1, device=positions.device, dtype=torch.int32), atom_end[:-1]]
    )
    max_atoms_per_system = atom_counts.max().item()
    max_blocks_per_system = (
        max_atoms_per_system + BATCH_BLOCK_SIZE - 1
    ) // BATCH_BLOCK_SIZE

    needs_grad_flag = needs_grad(positions, charges, cell)
    virial_grad = needs_grad_flag and compute_virial

    wp_positions = warp_from_torch(positions, wp_vec, requires_grad=needs_grad_flag)
    wp_charges = warp_from_torch(charges, wp_scalar, requires_grad=needs_grad_flag)
    wp_cell = warp_from_torch(cell, wp_mat, requires_grad=needs_grad_flag)
    k_vectors_typed = k_vectors.to(input_dtype)
    wp_k_vectors = warp_from_torch(
        k_vectors_typed, wp_vec, requires_grad=needs_grad_flag
    )
    wp_alpha = warp_from_torch(alpha, wp_scalar, requires_grad=needs_grad_flag)
    wp_batch_idx = warp_from_torch(batch_idx, wp.int32)
    wp_atom_start = warp_from_torch(atom_start, wp.int32)
    wp_atom_end = warp_from_torch(atom_end, wp.int32)

    # Intermediate arrays (float64 accumulation buffers)
    wp_cos_k_dot_r = warp_from_torch(
        torch.zeros((num_k, num_atoms), device=positions.device, dtype=torch.float64),
        wp.float64,
        requires_grad=needs_grad_flag,
    )
    wp_sin_k_dot_r = warp_from_torch(
        torch.zeros((num_k, num_atoms), device=positions.device, dtype=torch.float64),
        wp.float64,
        requires_grad=needs_grad_flag,
    )
    real_sf = torch.zeros(
        (num_systems, num_k), device=positions.device, dtype=torch.float64
    )
    imag_sf = torch.zeros(
        (num_systems, num_k), device=positions.device, dtype=torch.float64
    )
    wp_real_sf = warp_from_torch(real_sf, wp.float64, requires_grad=needs_grad_flag)
    wp_imag_sf = warp_from_torch(imag_sf, wp.float64, requires_grad=needs_grad_flag)
    wp_total_charge = warp_from_torch(
        torch.zeros(num_systems, device=positions.device, dtype=torch.float64),
        wp.float64,
        requires_grad=needs_grad_flag,
    )
    raw_energies = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    wp_raw_energies = warp_from_torch(
        raw_energies, wp.float64, requires_grad=needs_grad_flag
    )
    energies = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=needs_grad_flag)

    wp_virial = None
    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            _batch_ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload[
                wp_scalar
            ],
            dim=(num_k, num_systems, max_blocks_per_system),
            inputs=[
                wp_positions,
                wp_charges,
                wp_k_vectors,
                wp_cell,
                wp_alpha,
                wp_atom_start,
                wp_atom_end,
            ],
            outputs=[
                wp_total_charge,
                wp_cos_k_dot_r,
                wp_sin_k_dot_r,
                wp_real_sf,
                wp_imag_sf,
            ],
            device=device,
        )
        wp.launch(
            _batch_ewald_reciprocal_space_energy_kernel_compute_energy_overload[
                wp_scalar
            ],
            dim=num_atoms,
            inputs=[
                wp_charges,
                wp_batch_idx,
                wp_cos_k_dot_r,
                wp_sin_k_dot_r,
                wp_real_sf,
                wp_imag_sf,
            ],
            outputs=[wp_raw_energies],
            device=device,
        )
        wp.launch(
            _batch_ewald_subtract_self_energy_kernel_overload[wp_scalar],
            dim=num_atoms,
            inputs=[
                wp_charges,
                wp_batch_idx,
                wp_alpha,
                wp_total_charge,
                wp_raw_energies,
            ],
            outputs=[wp_energies],
            device=device,
        )
        if compute_virial:
            virial = torch.zeros(
                num_systems, 3, 3, device=positions.device, dtype=input_dtype
            )
            wp_virial = warp_from_torch(virial, wp_mat, requires_grad=virial_grad)
            # Per-system cell volume, shape (B,)
            volume = torch.abs(torch.det(cell.to(torch.float64)))
            wp_volume = warp_from_torch(volume, wp.float64)
            wp.launch(
                _batch_ewald_reciprocal_space_virial_kernel_overload[wp_scalar],
                dim=(num_k, num_systems),
                inputs=[
                    wp_k_vectors,
                    wp_alpha,
                    wp_volume,
                    wp_real_sf,
                    wp_imag_sf,
                    wp_virial,
                ],
                device=device,
            )
        else:
            virial = torch.zeros(
                num_systems, 3, 3, device=positions.device, dtype=input_dtype
            )

    if needs_grad_flag:
        backward_kw = dict(
            energies=wp_energies,
            positions=wp_positions,
            charges=wp_charges,
            cell=wp_cell,
            k_vectors=wp_k_vectors,
            alpha=wp_alpha,
        )
        if virial_grad and wp_virial is not None:
            backward_kw["virial"] = wp_virial
        attach_for_backward(energies, tape=tape, **backward_kw)
    return energies, virial


@warp_custom_op(
    name="alchemiops::_batch_ewald_reciprocal_space_energy_forces",
    outputs=[
        OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "forces",
            lambda pos, *_: get_wp_vec_dtype(pos.dtype),
            lambda pos, *_: (pos.shape[0], 3),
        ),
        OutputSpec(
            "virial",
            lambda pos, *_: get_wp_mat_dtype(pos.dtype),
            lambda pos, charges, cell, *_: (cell.shape[0], 3, 3),
        ),
    ],
    grad_arrays=[
        "energies",
        "forces",
        "virial",
        "positions",
        "charges",
        "cell",
        "k_vectors",
        "alpha",
    ],
)
def _batch_ewald_reciprocal_space_energy_forces(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    k_vectors: torch.Tensor,
    alpha: torch.Tensor,
    batch_idx: torch.Tensor,
    compute_virial: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Internal: Compute reciprocal-space Ewald energies, forces, and optionally virial (batch).

    Per-system atom ranges ``[atom_start, atom_end)`` are derived from a
    cumulative count of ``batch_idx``. This assumes atoms of each system are
    stored contiguously and ordered by system index.

    Parameters
    ----------
    positions : torch.Tensor, shape (N, 3)
        Atomic coordinates of all systems concatenated.
    charges : torch.Tensor, shape (N,)
        Atomic partial charges.
    cell : torch.Tensor, shape (B, 3, 3)
        Per-system unit cells. ``cell.shape[0]`` is the number of systems.
    k_vectors : torch.Tensor
        Per-system reciprocal-space vectors. ``k_vectors.shape[1]`` is the
        number of k-vectors, presumably shape (B, K, 3). Cast to the
        positions dtype.
    alpha : torch.Tensor
        Ewald splitting parameter(s), presumably one per system.
    batch_idx : torch.Tensor, shape (N,)
        System index of each atom. Passed to Warp as int32; presumably it
        must already be int32.
    compute_virial : bool, default=False
        Whether to run the reciprocal-space virial kernel.

    Returns
    -------
    energies : torch.Tensor, shape (N,), float64
        Per-atom reciprocal-space energy with the self-energy term subtracted.
    forces : torch.Tensor, shape (N, 3), input dtype
    virial : torch.Tensor, shape (B, 3, 3), input dtype
        Zeros unless ``compute_virial`` is True.
    """
    num_k = k_vectors.shape[1]
    num_atoms = positions.shape[0]
    num_systems = cell.shape[0]
    device = wp.device_from_torch(positions.device)
    input_dtype = positions.dtype

    if num_k == 0 or num_atoms == 0:
        # Energies are float64 on the regular path and in the ``energies``
        # OutputSpec; keep the degenerate path consistent with that.
        return (
            torch.zeros(num_atoms, device=positions.device, dtype=torch.float64),
            torch.zeros(num_atoms, 3, device=positions.device, dtype=input_dtype),
            torch.zeros(num_systems, 3, 3, device=positions.device, dtype=input_dtype),
        )

    wp_scalar = get_wp_dtype(input_dtype)
    wp_vec = get_wp_vec_dtype(input_dtype)
    wp_mat = get_wp_mat_dtype(input_dtype)

    # Contiguous atom range per system, plus the launch width needed to cover
    # the largest system in blocks of BATCH_BLOCK_SIZE atoms.
    atom_counts = torch.bincount(batch_idx, minlength=num_systems)
    atom_end = torch.cumsum(atom_counts, dim=0).to(torch.int32)
    atom_start = torch.cat(
        [torch.zeros(1, device=positions.device, dtype=torch.int32), atom_end[:-1]]
    )
    max_atoms_per_system = atom_counts.max().item()
    max_blocks_per_system = (
        max_atoms_per_system + BATCH_BLOCK_SIZE - 1
    ) // BATCH_BLOCK_SIZE

    needs_grad_flag = needs_grad(positions, charges, cell)
    virial_grad = needs_grad_flag and compute_virial

    wp_positions = warp_from_torch(positions, wp_vec, requires_grad=needs_grad_flag)
    wp_charges = warp_from_torch(charges, wp_scalar, requires_grad=needs_grad_flag)
    wp_cell = warp_from_torch(cell, wp_mat, requires_grad=needs_grad_flag)
    k_vectors_typed = k_vectors.to(input_dtype)
    wp_k_vectors = warp_from_torch(
        k_vectors_typed, wp_vec, requires_grad=needs_grad_flag
    )
    wp_alpha = warp_from_torch(alpha, wp_scalar, requires_grad=needs_grad_flag)
    wp_batch_idx = warp_from_torch(batch_idx, wp.int32)
    wp_atom_start = warp_from_torch(atom_start, wp.int32)
    wp_atom_end = warp_from_torch(atom_end, wp.int32)

    # Intermediate arrays (float64 accumulation buffers)
    wp_cos_k_dot_r = warp_from_torch(
        torch.zeros((num_k, num_atoms), device=positions.device, dtype=torch.float64),
        wp.float64,
        requires_grad=needs_grad_flag,
    )
    wp_sin_k_dot_r = warp_from_torch(
        torch.zeros((num_k, num_atoms), device=positions.device, dtype=torch.float64),
        wp.float64,
        requires_grad=needs_grad_flag,
    )
    real_sf = torch.zeros(
        (num_systems, num_k), device=positions.device, dtype=torch.float64
    )
    imag_sf = torch.zeros(
        (num_systems, num_k), device=positions.device, dtype=torch.float64
    )
    wp_real_sf = warp_from_torch(real_sf, wp.float64, requires_grad=needs_grad_flag)
    wp_imag_sf = warp_from_torch(imag_sf, wp.float64, requires_grad=needs_grad_flag)
    wp_total_charge = warp_from_torch(
        torch.zeros(num_systems, device=positions.device, dtype=torch.float64),
        wp.float64,
        requires_grad=needs_grad_flag,
    )
    raw_energies = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    wp_raw_energies = warp_from_torch(
        raw_energies, wp.float64, requires_grad=needs_grad_flag
    )
    energies = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    forces = torch.zeros(num_atoms, 3, device=positions.device, dtype=input_dtype)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=needs_grad_flag)
    wp_forces = warp_from_torch(forces, wp_vec, requires_grad=needs_grad_flag)

    wp_virial = None
    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            _batch_ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload[
                wp_scalar
            ],
            dim=(num_k, num_systems, max_blocks_per_system),
            inputs=[
                wp_positions,
                wp_charges,
                wp_k_vectors,
                wp_cell,
                wp_alpha,
                wp_atom_start,
                wp_atom_end,
            ],
            outputs=[
                wp_total_charge,
                wp_cos_k_dot_r,
                wp_sin_k_dot_r,
                wp_real_sf,
                wp_imag_sf,
            ],
            device=device,
        )
        wp.launch(
            _batch_ewald_reciprocal_space_energy_forces_kernel_overload[wp_scalar],
            dim=num_atoms,
            inputs=[
                wp_charges,
                wp_batch_idx,
                wp_k_vectors,
                wp_cos_k_dot_r,
                wp_sin_k_dot_r,
                wp_real_sf,
                wp_imag_sf,
            ],
            outputs=[wp_raw_energies, wp_forces],
            device=device,
        )
        wp.launch(
            _batch_ewald_subtract_self_energy_kernel_overload[wp_scalar],
            dim=num_atoms,
            inputs=[
                wp_charges,
                wp_batch_idx,
                wp_alpha,
                wp_total_charge,
                wp_raw_energies,
            ],
            outputs=[wp_energies],
            device=device,
        )
        if compute_virial:
            virial = torch.zeros(
                num_systems, 3, 3, device=positions.device, dtype=input_dtype
            )
            wp_virial = warp_from_torch(virial, wp_mat, requires_grad=virial_grad)
            # Per-system cell volume, shape (B,)
            volume = torch.abs(torch.det(cell.to(torch.float64)))
            wp_volume = warp_from_torch(volume, wp.float64)
            wp.launch(
                _batch_ewald_reciprocal_space_virial_kernel_overload[wp_scalar],
                dim=(num_k, num_systems),
                inputs=[
                    wp_k_vectors,
                    wp_alpha,
                    wp_volume,
                    wp_real_sf,
                    wp_imag_sf,
                    wp_virial,
                ],
                device=device,
            )
        else:
            virial = torch.zeros(
                num_systems, 3, 3, device=positions.device, dtype=input_dtype
            )

    if needs_grad_flag:
        backward_kw = dict(
            energies=wp_energies,
            forces=wp_forces,
            positions=wp_positions,
            charges=wp_charges,
            cell=wp_cell,
            k_vectors=wp_k_vectors,
            alpha=wp_alpha,
        )
        if virial_grad and wp_virial is not None:
            backward_kw["virial"] = wp_virial
        attach_for_backward(energies, tape=tape, **backward_kw)
    return energies, forces, virial


@warp_custom_op(
    name="alchemiops::_batch_ewald_reciprocal_space_energy_forces_charge_grad",
    outputs=[
        OutputSpec("energies", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "forces",
            lambda pos, *_: get_wp_vec_dtype(pos.dtype),
            lambda pos, *_: (pos.shape[0], 3),
        ),
        OutputSpec("charge_gradients", wp.float64, lambda pos, *_: (pos.shape[0],)),
        OutputSpec(
            "virial",
            lambda pos, *_: get_wp_mat_dtype(pos.dtype),
            lambda pos, charges, cell, *_: (cell.shape[0], 3, 3),
        ),
    ],
    grad_arrays=[
        "energies",
        "forces",
        "charge_gradients",
        "virial",
        "positions",
        "charges",
        "cell",
        "k_vectors",
        "alpha",
    ],
)
def _batch_ewald_reciprocal_space_energy_forces_charge_grad(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    k_vectors: torch.Tensor,
    alpha: torch.Tensor,
    batch_idx: torch.Tensor,
    compute_virial: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Internal: Compute reciprocal-space Ewald E+F+charge_grad+virial (batch).

    Per-system atom ranges ``[atom_start, atom_end)`` are derived from a
    cumulative count of ``batch_idx``. This assumes atoms of each system are
    stored contiguously and ordered by system index. The self-energy and
    background corrections to the charge gradients are applied on the torch
    side after the Warp kernels run.

    Parameters
    ----------
    positions : torch.Tensor, shape (N, 3)
        Atomic coordinates of all systems concatenated.
    charges : torch.Tensor, shape (N,)
        Atomic partial charges.
    cell : torch.Tensor, shape (B, 3, 3)
        Per-system unit cells. ``cell.shape[0]`` is the number of systems.
    k_vectors : torch.Tensor
        Per-system reciprocal-space vectors. ``k_vectors.shape[1]`` is the
        number of k-vectors, presumably shape (B, K, 3). Cast to the
        positions dtype.
    alpha : torch.Tensor, shape (B,)
        Per-system Ewald splitting parameter (indexed by ``batch_idx``).
    batch_idx : torch.Tensor, shape (N,)
        System index of each atom. Passed to Warp as int32; presumably it
        must already be int32.
    compute_virial : bool, default=False
        Whether to run the reciprocal-space virial kernel.

    Returns
    -------
    energies : torch.Tensor, shape (N,), float64
        Per-atom reciprocal-space energy with the self-energy term subtracted.
    forces : torch.Tensor, shape (N, 3), input dtype
    charge_gradients : torch.Tensor, shape (N,), input dtype
        Charge gradients after the self-energy and background corrections.
    virial : torch.Tensor, shape (B, 3, 3), input dtype
        Zeros unless ``compute_virial`` is True.
    """
    num_k = k_vectors.shape[1]
    num_atoms = positions.shape[0]
    num_systems = cell.shape[0]
    device = wp.device_from_torch(positions.device)
    input_dtype = positions.dtype

    if num_k == 0 or num_atoms == 0:
        # Energies are float64 on the regular path and in the ``energies``
        # OutputSpec; keep the degenerate path consistent with that.
        return (
            torch.zeros(num_atoms, device=positions.device, dtype=torch.float64),
            torch.zeros(num_atoms, 3, device=positions.device, dtype=input_dtype),
            torch.zeros(num_atoms, device=positions.device, dtype=input_dtype),
            torch.zeros(num_systems, 3, 3, device=positions.device, dtype=input_dtype),
        )

    wp_scalar = get_wp_dtype(input_dtype)
    wp_vec = get_wp_vec_dtype(input_dtype)
    wp_mat = get_wp_mat_dtype(input_dtype)

    # Contiguous atom range per system, plus the launch width needed to cover
    # the largest system in blocks of BATCH_BLOCK_SIZE atoms.
    atom_counts = torch.bincount(batch_idx, minlength=num_systems)
    atom_end = torch.cumsum(atom_counts, dim=0).to(torch.int32)
    atom_start = torch.cat(
        [torch.zeros(1, device=positions.device, dtype=torch.int32), atom_end[:-1]]
    )
    max_atoms_per_system = atom_counts.max().item()
    max_blocks_per_system = (
        max_atoms_per_system + BATCH_BLOCK_SIZE - 1
    ) // BATCH_BLOCK_SIZE

    needs_grad_flag = needs_grad(positions, charges, cell)
    virial_grad = needs_grad_flag and compute_virial

    wp_positions = warp_from_torch(positions, wp_vec, requires_grad=needs_grad_flag)
    wp_charges = warp_from_torch(charges, wp_scalar, requires_grad=needs_grad_flag)
    wp_cell = warp_from_torch(cell, wp_mat, requires_grad=needs_grad_flag)
    k_vectors_typed = k_vectors.to(input_dtype)
    wp_k_vectors = warp_from_torch(
        k_vectors_typed, wp_vec, requires_grad=needs_grad_flag
    )
    wp_alpha = warp_from_torch(alpha, wp_scalar, requires_grad=needs_grad_flag)
    wp_batch_idx = warp_from_torch(batch_idx, wp.int32)
    wp_atom_start = warp_from_torch(atom_start, wp.int32)
    wp_atom_end = warp_from_torch(atom_end, wp.int32)

    # Intermediate arrays (float64 accumulation buffers)
    wp_cos_k_dot_r = warp_from_torch(
        torch.zeros((num_k, num_atoms), device=positions.device, dtype=torch.float64),
        wp.float64,
        requires_grad=needs_grad_flag,
    )
    wp_sin_k_dot_r = warp_from_torch(
        torch.zeros((num_k, num_atoms), device=positions.device, dtype=torch.float64),
        wp.float64,
        requires_grad=needs_grad_flag,
    )
    real_sf = torch.zeros(
        (num_systems, num_k), device=positions.device, dtype=torch.float64
    )
    imag_sf = torch.zeros(
        (num_systems, num_k), device=positions.device, dtype=torch.float64
    )
    wp_real_sf = warp_from_torch(real_sf, wp.float64, requires_grad=needs_grad_flag)
    wp_imag_sf = warp_from_torch(imag_sf, wp.float64, requires_grad=needs_grad_flag)
    # Kept as a named tensor: read back below for the background correction.
    total_charge_batch = torch.zeros(
        num_systems, device=positions.device, dtype=torch.float64
    )
    wp_total_charge = warp_from_torch(
        total_charge_batch,
        wp.float64,
        requires_grad=needs_grad_flag,
    )
    raw_energies = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    wp_raw_energies = warp_from_torch(
        raw_energies, wp.float64, requires_grad=needs_grad_flag
    )
    energies = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    forces = torch.zeros(num_atoms, 3, device=positions.device, dtype=input_dtype)
    charge_grads = torch.zeros(num_atoms, device=positions.device, dtype=torch.float64)
    wp_energies = warp_from_torch(energies, wp.float64, requires_grad=needs_grad_flag)
    wp_forces = warp_from_torch(forces, wp_vec, requires_grad=needs_grad_flag)
    wp_charge_grads = warp_from_torch(
        charge_grads, wp.float64, requires_grad=needs_grad_flag
    )

    wp_virial = None
    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            _batch_ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload[
                wp_scalar
            ],
            dim=(num_k, num_systems, max_blocks_per_system),
            inputs=[
                wp_positions,
                wp_charges,
                wp_k_vectors,
                wp_cell,
                wp_alpha,
                wp_atom_start,
                wp_atom_end,
            ],
            outputs=[
                wp_total_charge,
                wp_cos_k_dot_r,
                wp_sin_k_dot_r,
                wp_real_sf,
                wp_imag_sf,
            ],
            device=device,
        )
        wp.launch(
            _batch_ewald_reciprocal_space_energy_forces_charge_grad_kernel_overload[
                wp_scalar
            ],
            dim=num_atoms,
            inputs=[
                wp_charges,
                wp_batch_idx,
                wp_k_vectors,
                wp_cos_k_dot_r,
                wp_sin_k_dot_r,
                wp_real_sf,
                wp_imag_sf,
            ],
            outputs=[wp_raw_energies, wp_forces, wp_charge_grads],
            device=device,
        )
        wp.launch(
            _batch_ewald_subtract_self_energy_kernel_overload[wp_scalar],
            dim=num_atoms,
            inputs=[
                wp_charges,
                wp_batch_idx,
                wp_alpha,
                wp_total_charge,
                wp_raw_energies,
            ],
            outputs=[wp_energies],
            device=device,
        )
        if compute_virial:
            virial = torch.zeros(
                num_systems, 3, 3, device=positions.device, dtype=input_dtype
            )
            wp_virial = warp_from_torch(virial, wp_mat, requires_grad=virial_grad)
            # Per-system cell volume, shape (B,)
            volume = torch.abs(torch.det(cell.to(torch.float64)))
            wp_volume = warp_from_torch(volume, wp.float64)
            wp.launch(
                _batch_ewald_reciprocal_space_virial_kernel_overload[wp_scalar],
                dim=(num_k, num_systems),
                inputs=[
                    wp_k_vectors,
                    wp_alpha,
                    wp_volume,
                    wp_real_sf,
                    wp_imag_sf,
                    wp_virial,
                ],
                device=device,
            )
        else:
            virial = torch.zeros(
                num_systems, 3, 3, device=positions.device, dtype=input_dtype
            )

    # Apply self-energy and background corrections to charge gradients,
    # broadcasting per-system alpha / total charge to each atom.
    alpha_per_atom = alpha[batch_idx]
    total_charge_per_atom = total_charge_batch[batch_idx]

    self_energy_grad = 2.0 / math.sqrt(math.pi) * alpha_per_atom * charges
    # NOTE(review): no explicit 1/V factor appears here; confirm this matches
    # the convention used by the self-energy/background kernel.
    background_grad = (
        math.pi / (alpha_per_atom * alpha_per_atom) * total_charge_per_atom
    )
    charge_grads = charge_grads - self_energy_grad - background_grad

    if needs_grad_flag:
        backward_kw = dict(
            energies=wp_energies,
            forces=wp_forces,
            charge_gradients=wp_charge_grads,
            positions=wp_positions,
            charges=wp_charges,
            cell=wp_cell,
            k_vectors=wp_k_vectors,
            alpha=wp_alpha,
        )
        if virial_grad and wp_virial is not None:
            backward_kw["virial"] = wp_virial
        attach_for_backward(energies, tape=tape, **backward_kw)
    return energies, forces, charge_grads.to(input_dtype), virial


###########################################################################################
########################### Public Wrapper APIs ###########################################
###########################################################################################


[docs] def ewald_real_space( positions: torch.Tensor, charges: torch.Tensor, cell: torch.Tensor, alpha: torch.Tensor, neighbor_list: torch.Tensor | None = None, neighbor_ptr: torch.Tensor | None = None, neighbor_shifts: torch.Tensor | None = None, neighbor_matrix: torch.Tensor | None = None, neighbor_matrix_shifts: torch.Tensor | None = None, mask_value: int | None = None, batch_idx: torch.Tensor | None = None, compute_forces: bool = False, compute_charge_gradients: bool = False, compute_virial: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Compute real-space Ewald energy and optionally forces, charge gradients, and virial. Computes the damped Coulomb interactions for atom pairs within the real-space cutoff. The complementary error function (erfc) damping ensures rapid convergence in real space. Parameters ---------- positions : torch.Tensor, shape (N, 3) Atomic coordinates. charges : torch.Tensor, shape (N,) Atomic partial charges. cell : torch.Tensor, shape (3, 3) or (B, 3, 3) Unit cell matrices. alpha : torch.Tensor, shape (1,) or (B,) Ewald splitting parameter(s). neighbor_list : torch.Tensor, shape (2, M), optional Neighbor list in COO format. neighbor_ptr : torch.Tensor, shape (N+1,), optional CSR row pointers for neighbor list. neighbor_shifts : torch.Tensor, shape (M, 3), optional Periodic image shifts for neighbor list. neighbor_matrix : torch.Tensor, shape (N, max_neighbors), optional Dense neighbor matrix format. neighbor_matrix_shifts : torch.Tensor, shape (N, max_neighbors, 3), optional Periodic image shifts for neighbor_matrix. mask_value : int, optional Value indicating invalid entries in neighbor_matrix. Defaults to N. batch_idx : torch.Tensor, shape (N,), optional System index for each atom. compute_forces : bool, default=False Whether to compute explicit forces. compute_charge_gradients : bool, default=False Whether to compute charge gradients. 
compute_virial : bool, default=False Whether to compute the virial tensor W = -dE/d(epsilon). Stress = virial / volume. Returns ------- energies : torch.Tensor, shape (N,) Per-atom real-space energy. forces : torch.Tensor, shape (N, 3), optional Forces (if compute_forces=True). charge_gradients : torch.Tensor, shape (N,), optional Charge gradients (if compute_charge_gradients=True). virial : torch.Tensor, shape (1, 3, 3) or (B, 3, 3), optional Virial tensor (if compute_virial=True). Always last in the tuple. Note ---- Energies are always float64 for numerical stability during accumulation. Forces and virial match the input dtype (float32 or float64). """ if mask_value is None: mask_value = positions.shape[0] is_batch = batch_idx is not None # The virial tensor is computed as the outer product of separation vectors and # pair forces (W += r_ij ⊗ F_ij), which is accumulated inside the force kernel. # Therefore, even when only virial is requested (compute_forces=False, # compute_virial=True), we must dispatch a force-capable kernel. need_force_kernel = compute_forces or compute_virial # Helper to build the return tuple from raw outputs using match dispatch. 
def _build_result(energies, forces=None, charge_grads=None, virial=None): match ( compute_forces and forces is not None, compute_charge_gradients and charge_grads is not None, compute_virial and virial is not None, ): case (True, True, True): return energies, forces, charge_grads, virial case (True, True, False): return energies, forces, charge_grads case (True, False, True): return energies, forces, virial case (True, False, False): return energies, forces case (False, True, True): return energies, charge_grads, virial case (False, True, False): return energies, charge_grads case (False, False, True): return energies, virial case _: return energies if compute_charge_gradients: if neighbor_list is not None: if neighbor_ptr is None: raise ValueError( "neighbor_ptr is required when using neighbor_list format" ) if is_batch: energies, forces, charge_grads, virial = ( _batch_ewald_real_space_energy_forces_charge_grad( positions, charges, cell, alpha, batch_idx, neighbor_list, neighbor_ptr, neighbor_shifts, compute_virial=compute_virial, ) ) else: energies, forces, charge_grads, virial = ( _ewald_real_space_energy_forces_charge_grad( positions, charges, cell, alpha, neighbor_list, neighbor_ptr, neighbor_shifts, compute_virial=compute_virial, ) ) elif neighbor_matrix is not None: if is_batch: energies, forces, charge_grads, virial = ( _batch_ewald_real_space_energy_forces_charge_grad_matrix( positions, charges, cell, alpha, batch_idx, neighbor_matrix, neighbor_matrix_shifts, mask_value, compute_virial=compute_virial, ) ) else: energies, forces, charge_grads, virial = ( _ewald_real_space_energy_forces_charge_grad_matrix( positions, charges, cell, alpha, neighbor_matrix, neighbor_matrix_shifts, mask_value, compute_virial=compute_virial, ) ) else: raise ValueError("Either neighbor_list or neighbor_matrix must be provided") return _build_result(energies, forces, charge_grads, virial) # No charge gradients requested if neighbor_list is not None: if neighbor_ptr is None: raise 
ValueError("neighbor_ptr is required when using neighbor_list format") if is_batch: if need_force_kernel: energies, forces, virial = _batch_ewald_real_space_energy_forces( positions, charges, cell, alpha, batch_idx, neighbor_list, neighbor_ptr, neighbor_shifts, compute_virial=compute_virial, ) return _build_result(energies, forces, virial=virial) else: energies = _batch_ewald_real_space_energy( positions, charges, cell, alpha, batch_idx, neighbor_list, neighbor_ptr, neighbor_shifts, ) return _build_result(energies) else: if need_force_kernel: energies, forces, virial = _ewald_real_space_energy_forces( positions, charges, cell, alpha, neighbor_list, neighbor_ptr, neighbor_shifts, compute_virial=compute_virial, ) return _build_result(energies, forces, virial=virial) else: energies = _ewald_real_space_energy( positions, charges, cell, alpha, neighbor_list, neighbor_ptr, neighbor_shifts, ) return _build_result(energies) elif neighbor_matrix is not None: if is_batch: if need_force_kernel: energies, forces, virial = _batch_ewald_real_space_energy_forces_matrix( positions, charges, cell, alpha, batch_idx, neighbor_matrix, neighbor_matrix_shifts, mask_value, compute_virial=compute_virial, ) return _build_result(energies, forces, virial=virial) else: energies = _batch_ewald_real_space_energy_matrix( positions, charges, cell, alpha, batch_idx, neighbor_matrix, neighbor_matrix_shifts, mask_value, ) return _build_result(energies) else: if need_force_kernel: energies, forces, virial = _ewald_real_space_energy_forces_matrix( positions, charges, cell, alpha, neighbor_matrix, neighbor_matrix_shifts, mask_value, compute_virial=compute_virial, ) return _build_result(energies, forces, virial=virial) else: energies = _ewald_real_space_energy_matrix( positions, charges, cell, alpha, neighbor_matrix, neighbor_matrix_shifts, mask_value, ) return _build_result(energies) else: raise ValueError("Either neighbor_list or neighbor_matrix must be provided")
[docs] def ewald_reciprocal_space( positions: torch.Tensor, charges: torch.Tensor, cell: torch.Tensor, k_vectors: torch.Tensor, alpha: torch.Tensor, batch_idx: torch.Tensor | None = None, compute_forces: bool = False, compute_charge_gradients: bool = False, compute_virial: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, ...]: r"""Compute reciprocal-space Ewald energy and optionally forces, charge gradients, virial. Computes the smooth long-range electrostatic contribution using structure factors in reciprocal space. Parameters ---------- positions : torch.Tensor, shape (N, 3) Atomic coordinates. charges : torch.Tensor, shape (N,) Atomic partial charges. cell : torch.Tensor, shape (3, 3) or (B, 3, 3) Unit cell matrices. k_vectors : torch.Tensor Reciprocal lattice vectors. Shape (K, 3) for single system, (B, K, 3) for batch. alpha : torch.Tensor, shape (1,) or (B,) Ewald splitting parameter(s). batch_idx : torch.Tensor, shape (N,), optional System index for each atom. compute_forces : bool, default=False Whether to compute explicit forces. compute_charge_gradients : bool, default=False Whether to compute charge gradients. compute_virial : bool, default=False Whether to compute the virial tensor W = -dE/d(epsilon). Stress = virial / volume. Returns ------- energies : torch.Tensor, shape (N,) Per-atom reciprocal-space energy. forces : torch.Tensor, shape (N, 3), optional Forces (if compute_forces=True). charge_gradients : torch.Tensor, shape (N,), optional Charge gradients (if compute_charge_gradients=True). virial : torch.Tensor, shape (1, 3, 3) or (B, 3, 3), optional Virial tensor (if compute_virial=True). Always last in the tuple. Note ---- Energies are always float64 for numerical stability during accumulation. Forces and virial match the input dtype (float32 or float64). """ is_batch = batch_idx is not None # Normalize k-vector rank based on dispatch mode. # Batch kernels expect (B, K, 3), single kernels expect (K, 3). 
if is_batch and k_vectors.dim() == 2: k_vectors = k_vectors.unsqueeze(0) elif not is_batch and k_vectors.dim() == 3 and k_vectors.shape[0] == 1: k_vectors = k_vectors.squeeze(0) # Helper to build the return tuple from raw outputs using match dispatch. def _build_result(energies, forces=None, charge_grads=None, virial=None): match ( compute_forces and forces is not None, compute_charge_gradients and charge_grads is not None, compute_virial and virial is not None, ): case (True, True, True): return energies, forces, charge_grads, virial case (True, True, False): return energies, forces, charge_grads case (True, False, True): return energies, forces, virial case (True, False, False): return energies, forces case (False, True, True): return energies, charge_grads, virial case (False, True, False): return energies, charge_grads case (False, False, True): return energies, virial case _: return energies if compute_charge_gradients: if is_batch: energies, forces, charge_grads, virial = ( _batch_ewald_reciprocal_space_energy_forces_charge_grad( positions, charges, cell, k_vectors, alpha, batch_idx, compute_virial=compute_virial, ) ) else: energies, forces, charge_grads, virial = ( _ewald_reciprocal_space_energy_forces_charge_grad( positions, charges, cell, k_vectors, alpha, compute_virial=compute_virial, ) ) return _build_result(energies, forces, charge_grads, virial) # No charge gradients if is_batch: if compute_forces: energies, forces, virial = _batch_ewald_reciprocal_space_energy_forces( positions, charges, cell, k_vectors, alpha, batch_idx, compute_virial=compute_virial, ) return _build_result(energies, forces, virial=virial) else: energies, virial = _batch_ewald_reciprocal_space_energy( positions, charges, cell, k_vectors, alpha, batch_idx, compute_virial=compute_virial, ) return _build_result(energies, virial=virial) else: if compute_forces: energies, forces, virial = _ewald_reciprocal_space_energy_forces( positions, charges, cell, k_vectors, alpha, 
compute_virial=compute_virial, ) return _build_result(energies, forces, virial=virial) else: energies, virial = _ewald_reciprocal_space_energy( positions, charges, cell, k_vectors, alpha, compute_virial=compute_virial, ) return _build_result(energies, virial=virial)
def ewald_summation(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cell: torch.Tensor,
    alpha: float | torch.Tensor | None = None,
    k_vectors: torch.Tensor | None = None,
    k_cutoff: float | None = None,
    batch_idx: torch.Tensor | None = None,
    neighbor_list: torch.Tensor | None = None,
    neighbor_ptr: torch.Tensor | None = None,
    neighbor_shifts: torch.Tensor | None = None,
    neighbor_matrix: torch.Tensor | None = None,
    neighbor_matrix_shifts: torch.Tensor | None = None,
    mask_value: int | None = None,
    compute_forces: bool = False,
    compute_charge_gradients: bool = False,
    compute_virial: bool = False,
    accuracy: float = 1e-6,
) -> tuple[torch.Tensor, ...] | torch.Tensor:
    """Complete Ewald summation for long-range electrostatics.

    Runs ``ewald_real_space`` and ``ewald_reciprocal_space`` with the same
    flags and returns the element-wise sum of their outputs. No separate
    self-energy or neutralizing-background term is subtracted in this
    function; presumably those corrections are folded into the
    reciprocal-space kernels (TODO confirm).

    Parameters
    ----------
    positions : torch.Tensor, shape (N, 3)
        Atomic coordinates.
    charges : torch.Tensor, shape (N,)
        Atomic partial charges.
    cell : torch.Tensor, shape (3, 3) or (B, 3, 3)
        Unit cell matrices.
    alpha : float, torch.Tensor, or None, default=None
        Ewald splitting parameter. Estimated from ``accuracy`` if None.
    k_vectors : torch.Tensor, optional
        Pre-computed reciprocal lattice vectors. Generated from ``k_cutoff``
        if not given.
    k_cutoff : float, optional
        K-space cutoff used to generate ``k_vectors``. Estimated if neither
        it nor ``k_vectors`` is given.
    batch_idx : torch.Tensor, shape (N,), optional
        System index for each atom.
    neighbor_list : torch.Tensor, shape (2, M), optional
        Neighbor pairs, used together with ``neighbor_ptr``.
    neighbor_ptr : torch.Tensor, shape (N+1,), optional
        CSR row pointers.
    neighbor_shifts : torch.Tensor, shape (M, 3), optional
        Periodic image shifts for the neighbor list.
    neighbor_matrix : torch.Tensor, shape (N, max_neighbors), optional
        Dense neighbor matrix.
    neighbor_matrix_shifts : torch.Tensor, shape (N, max_neighbors, 3), optional
        Periodic image shifts for ``neighbor_matrix``.
    mask_value : int, optional
        Value indicating invalid entries. Defaults to N.
    compute_forces : bool, default=False
        Whether to compute explicit forces.
    compute_charge_gradients : bool, default=False
        Whether to compute charge gradients dE/dq_i.
    compute_virial : bool, default=False
        Whether to compute the virial tensor W = -dE/d(epsilon).
        Stress = virial / volume.
    accuracy : float, default=1e-6
        Target accuracy for parameter estimation.

    Returns
    -------
    energies : torch.Tensor, shape (N,)
        Per-atom total Ewald energy.
    forces : torch.Tensor, shape (N, 3), optional
        Forces (if compute_forces=True).
    charge_gradients : torch.Tensor, shape (N,), optional
        Charge gradients (if compute_charge_gradients=True).
    virial : torch.Tensor, shape (1, 3, 3) or (B, 3, 3), optional
        Virial tensor (if compute_virial=True). Always last in the tuple.

    If no optional output is requested, the bare ``energies`` tensor is
    returned rather than a 1-tuple.

    Raises
    ------
    ValueError
        Propagated from ``ewald_real_space`` / ``ewald_reciprocal_space``
        for invalid neighbor or k-vector inputs.
    RuntimeError
        If the real- and reciprocal-space results do not have the same
        number of components, so they cannot be summed element-wise.

    Note
    ----
    Energies are always float64 for numerical stability during accumulation.
    Forces, charge gradients, and virial match the input dtype
    (float32 or float64).
    """
    device = positions.device
    dtype = positions.dtype
    num_atoms = positions.shape[0]
    cell, num_systems = _prepare_cell(cell)

    # Estimate whichever of alpha / k-space cutoff the caller did not supply.
    # NOTE(review): the estimator is not given a caller-supplied alpha, so when
    # alpha is supplied but neither k_cutoff nor k_vectors is, k_cutoff is
    # taken from the estimator's own parameters - confirm this is intended.
    if alpha is None or (k_cutoff is None and k_vectors is None):
        params = estimate_ewald_parameters(positions, cell, batch_idx, accuracy)
        if alpha is None:
            alpha = params.alpha
        if k_cutoff is None:
            k_cutoff = params.reciprocal_space_cutoff

    alpha_tensor = _prepare_alpha(alpha, num_systems, dtype, device)

    if k_vectors is None:
        k_vectors = generate_k_vectors_ewald_summation(cell, k_cutoff)

    if mask_value is None:
        mask_value = num_atoms

    # Compute real-space contribution.
    rs = ewald_real_space(
        positions=positions,
        charges=charges,
        cell=cell,
        alpha=alpha_tensor,
        neighbor_list=neighbor_list,
        neighbor_ptr=neighbor_ptr,
        neighbor_shifts=neighbor_shifts,
        neighbor_matrix=neighbor_matrix,
        neighbor_matrix_shifts=neighbor_matrix_shifts,
        mask_value=mask_value,
        batch_idx=batch_idx,
        compute_forces=compute_forces,
        compute_charge_gradients=compute_charge_gradients,
        compute_virial=compute_virial,
    )

    # Compute reciprocal-space contribution.
    rec = ewald_reciprocal_space(
        positions=positions,
        charges=charges,
        cell=cell,
        k_vectors=k_vectors,
        alpha=alpha_tensor,
        batch_idx=batch_idx,
        compute_forces=compute_forces,
        compute_charge_gradients=compute_charge_gradients,
        compute_virial=compute_virial,
    )

    # Normalize both results to tuples so they can be combined element-wise.
    rs_tuple = rs if isinstance(rs, tuple) else (rs,)
    rec_tuple = rec if isinstance(rec, tuple) else (rec,)

    # Both calls use identical flags, so their outputs should line up
    # component by component. Fail loudly instead of letting zip() silently
    # truncate or misalign (e.g. add forces to a virial).
    if len(rs_tuple) != len(rec_tuple):
        raise RuntimeError(
            "ewald_summation: real-space and reciprocal-space results have "
            f"different numbers of components ({len(rs_tuple)} vs "
            f"{len(rec_tuple)}); cannot combine them element-wise"
        )

    results = tuple(r + s for r, s in zip(rs_tuple, rec_tuple))
    return results[0] if len(results) == 1 else results