Source code for nvalchemiops.interactions.electrostatics.ewald_kernels

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Unified Ewald Summation Kernels
===============================

This module provides GPU-accelerated Warp kernels for Ewald summation,
enabling efficient calculation of long-range Coulomb interactions. All kernels
support both single-system and batched calculations via the batch_idx parameter.

DTYPE FLEXIBILITY
=================

All kernels support both float32 and float64 input types via Warp's overload system:

- Input tensors (positions, charges, cell, alpha): float32 or float64
- Accumulators (energies, structure factors): Always float64 for numerical stability
- Forces: Match input positions dtype (float32 or float64)
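
A minimal sketch of why the accumulators stay in float64, using only the
standard library. The helper `to_f32` is hypothetical (it is not part of this
module) and emulates float32 rounding by round-tripping through `struct`:

```python
import struct


def to_f32(x: float) -> float:
    # Round a Python float (float64) to the nearest float32 value.
    return struct.unpack("f", struct.pack("f", x))[0]


big = float(2**24)  # 16777216.0, exactly representable in float32

# float32 accumulation: 2**24 + 1 is not representable, so the update is lost.
acc32 = to_f32(to_f32(big) + 1.0)

# float64 accumulation keeps the small contribution.
acc64 = big + 1.0

print(acc32, acc64)
```

The same effect applies when many small pair energies are added to a large
running total.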

Use the `_*_overload` dictionaries to select the appropriate kernel based on dtype.

MATHEMATICAL FORMULATION
========================

The Ewald method splits the Coulomb energy into tractable components:

.. math::

    E_{\\text{total}}(s) = E_{\\text{real}}(s) + E_{\\text{reciprocal}}(s) - E_{\\text{self}}(s) - E_{\\text{background}}(s)

Real-Space Component (damped short-range):

.. math::

    E_{\\text{real}}(s) = \\frac{1}{2} \\sum_{i \\neq j \\in s} q_i q_j \\frac{\\text{erfc}(\\alpha r_{ij})}{r_{ij}}

The erfc damping rapidly suppresses interactions beyond a cutoff distance.
Force:

.. math::

    F_{ij} = q_i q_j \\left[\\frac{\\text{erfc}(\\alpha r_{ij})}{r_{ij}^2} + \\frac{2\\alpha}{\\sqrt{\\pi}} \\frac{\\exp(-\\alpha^2 r_{ij}^2)}{r_{ij}}\\right] \\hat{r}_{ij}
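
As a sanity check, the force expression above can be compared against a
central finite difference of the pair energy. This is a pure-Python sketch
with illustrative values; `pair_energy` and `pair_force` are hypothetical
helpers, not functions of this module:

```python
import math


def pair_energy(qi, qj, r, alpha):
    # Energy of a single pair, counted once: q_i q_j erfc(alpha r) / r.
    return qi * qj * math.erfc(alpha * r) / r


def pair_force(qi, qj, r, alpha):
    # Radial force magnitude: q_i q_j [erfc/r^2 + (2 alpha/sqrt(pi)) exp(-a^2 r^2)/r].
    gauss = 2.0 * alpha / math.sqrt(math.pi) * math.exp(-((alpha * r) ** 2))
    return qi * qj * (math.erfc(alpha * r) / r**2 + gauss / r)


qi, qj, r, alpha, h = 1.0, -1.0, 1.5, 0.35, 1e-5

# F = -dE/dr, approximated with a central difference of step h.
fd = -(pair_energy(qi, qj, r + h, alpha) - pair_energy(qi, qj, r - h, alpha)) / (2 * h)

print(pair_force(qi, qj, r, alpha), fd)
```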

Reciprocal-Space Component (smooth long-range):

.. math::

    E_{\\text{reciprocal}}(s) = \\frac{1}{2} \\sum_{i \\in s} q_i \\phi_i

where :math:`\\phi_i = \\frac{1}{V} \\sum_{k \\neq 0} G(k) [S_{\\text{real}}(k) \\cos(k \\cdot r_i) + S_{\\text{imag}}(k) \\sin(k \\cdot r_i)]`

Green's function:

.. math::

    G(k) = \\frac{8\\pi}{k^2} \\exp\\left(-\\frac{k^2}{4\\alpha^2}\\right)

Structure factors:

.. math::

    S(k) = \\sum_j q_j \\exp(ik \\cdot r_j)

Note: G(k) uses 8*pi (not 4*pi) because we use half-space k-vectors, exploiting
the symmetry S(-k) = S*(k). This halves the number of k-vectors while
maintaining correct energies/forces.
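
The half-space convention can be illustrated with a small pure-Python sketch.
The box size, charges, positions, and k-vector range are illustrative
assumptions, and the common 1/(2V) factor is omitted on both sides. A full
sum with 4*pi and a half-space sum with 8*pi should agree:

```python
import cmath
import itertools
import math

box = 5.0  # cubic box edge (illustrative)
alpha = 0.6
charges = [1.0, -1.0, 0.5]
positions = [(0.1, 0.2, 0.3), (2.0, 1.0, 4.0), (3.3, 0.7, 1.9)]


def structure_factor(k):
    # S(k) = sum_j q_j exp(i k . r_j)
    return sum(
        q * cmath.exp(1j * sum(a * b for a, b in zip(k, r)))
        for q, r in zip(charges, positions)
    )


def term(n, prefactor):
    # prefactor / k^2 * exp(-k^2 / (4 alpha^2)) * |S(k)|^2 for integer triple n
    k = tuple(2.0 * math.pi * c / box for c in n)
    k2 = sum(c * c for c in k)
    return prefactor / k2 * math.exp(-k2 / (4.0 * alpha**2)) * abs(structure_factor(k)) ** 2


grid = [n for n in itertools.product(range(-3, 4), repeat=3) if n != (0, 0, 0)]
# Half space: lexicographic tuple comparison keeps exactly one of each (n, -n) pair.
half = [n for n in grid if n > (0, 0, 0)]

full_sum = sum(term(n, 4.0 * math.pi) for n in grid)
half_sum = sum(term(n, 8.0 * math.pi) for n in half)
print(full_sum, half_sum)
```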

Self-Energy Correction (removes spurious self-interaction):

.. math::

    E_{\\text{self}}(s) = \\sum_{i \\in s} \\frac{\\alpha}{\\sqrt{\\pi}} q_i^2

Background Correction (for non-neutral systems):

.. math::

    E_{\\text{background}}(s) = \\sum_{i \\in s} \\frac{\\pi}{2\\alpha^2 V} q_i Q_{\\text{total}}
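
The two corrections above are simple closed-form sums. The sketch below
evaluates them in plain Python; the function names and the example values are
hypothetical and only mirror the formulas, not the Warp kernels:

```python
import math


def self_energy(charges, alpha):
    # E_self = (alpha / sqrt(pi)) * sum_i q_i^2
    return alpha / math.sqrt(math.pi) * sum(q * q for q in charges)


def background_energy(charges, alpha, volume):
    # E_background = pi / (2 alpha^2 V) * sum_i q_i * Q_total
    q_total = sum(charges)
    return math.pi / (2.0 * alpha**2 * volume) * sum(q * q_total for q in charges)


neutral = [1.0, -1.0]
charged = [1.0, 1.0]

print(self_energy(neutral, 0.5), background_energy(neutral, 0.5, 100.0))
print(background_energy(charged, 0.5, 100.0))
```

For a neutral system the background term vanishes because Q_total is zero.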

KERNEL ORGANIZATION
===================

Real-Space Kernels:
    - _ewald_real_space_energy_kernel: Single-system, neighbor list format
    - _ewald_real_space_energy_forces_kernel: Single-system with forces
    - _ewald_real_space_energy_neighbor_matrix_kernel: Neighbor matrix format
    - _ewald_real_space_energy_forces_neighbor_matrix_kernel: Matrix with forces
    - _batch_ewald_real_space_*: Batched versions of above
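
The two neighbor formats carry the same information. The following
pure-Python sketch walks a tiny full neighbor list in both layouts and
accumulates per-atom real-space energies. The distances, the value
`mask_value = -1`, and the helper `pair` are illustrative assumptions:

```python
import math

alpha = 0.4
charges = [1.0, -1.0, 0.5]
# Hypothetical precomputed distances for a full neighbor list (i->j and j->i).
dist = {(0, 1): 1.2, (1, 0): 1.2, (1, 2): 2.0, (2, 1): 2.0}

# CSR layout: idx_j holds neighbor indices; row i is
# idx_j[neighbor_ptr[i]:neighbor_ptr[i + 1]].
idx_j = [1, 0, 2, 1]
neighbor_ptr = [0, 1, 3, 4]

# Padded matrix layout: one row per atom, unused slots hold mask_value.
mask_value = -1
neighbor_matrix = [[1, mask_value], [0, 2], [1, mask_value]]


def pair(i, j):
    # Half of the pair energy, since a full list visits each pair twice.
    r = dist[(i, j)]
    return 0.5 * charges[i] * charges[j] * math.erfc(alpha * r) / r


e_csr = [
    sum(pair(i, idx_j[e]) for e in range(neighbor_ptr[i], neighbor_ptr[i + 1]))
    for i in range(len(charges))
]
e_mat = [
    sum(pair(i, j) for j in row if j != mask_value)
    for i, row in enumerate(neighbor_matrix)
]
print(e_csr, e_mat)
```

Summing the per-atom values recovers each pair energy exactly once.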

Reciprocal-Space Kernels:
    - _ewald_reciprocal_space_energy_kernel_fill_structure_factors: Compute S(k)
    - _ewald_reciprocal_space_energy_kernel_compute_energy: Energy from S(k)
    - _ewald_reciprocal_space_energy_forces_kernel: Energy + forces from S(k)
    - _ewald_subtract_self_energy_kernel: Apply self + background corrections
    - _batch_ewald_reciprocal_space_*: Batched versions of above

PERFORMANCE TUNING
==================

Environment variables for performance tuning:

ALCH_EWALD_BATCH_BLOCK_SIZE (default: 16)
    Block size for batched structure factor computation. Each thread processes
    a block of atoms, reducing atomic contention. Benchmark results show:

    - 16 is optimal for most scenarios (2-3x faster than atom-major)
    - Atom-major (no blocking) only wins for very large atom counts (>100K atoms)
    - Tune this if you have unusual workloads (many small or few large systems)
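
A short sketch of how the block size is resolved. The function
`read_block_size` is a hypothetical stand-in that mirrors the module-level
parsing (integer conversion, fallback to 16 for non-positive values); a
non-integer string would raise ValueError in `int`:

```python
import os


def read_block_size(environ, default=16):
    # Parse as int; fall back to the default when the value is not positive.
    value = int(environ.get("ALCH_EWALD_BATCH_BLOCK_SIZE", default))
    return value if value > 0 else default


print(read_block_size({}))
print(read_block_size({"ALCH_EWALD_BATCH_BLOCK_SIZE": "32"}))
print(read_block_size({"ALCH_EWALD_BATCH_BLOCK_SIZE": "0"}))

# The module reads the variable once at import time, so set it beforehand.
os.environ["ALCH_EWALD_BATCH_BLOCK_SIZE"] = "32"
```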

REFERENCES
==========

- Ewald, P. P. (1921). Ann. Phys. 369, 253-287 (Original Ewald method)
- Kolafa, J. & Perram, J. W. (1992). Mol. Sim. 9, 351-368 (Parameter optimization)
- Essmann et al. (1995). J. Chem. Phys. 103, 8577 (PME method)
"""

import math
import os
from typing import Any

import warp as wp

from nvalchemiops.math import wp_erfc, wp_exp_kernel

# Mathematical constants
PI = math.pi
TWOPI = 2.0 * PI
FOURPI = 4.0 * PI
EIGHTPI = 8.0 * PI  # Used for half-space k-vector optimization (2x FOURPI)

# Block size for batch structure factor accumulation
# Benchmark results show 16 is optimal for most cases (except very large atom counts)
BATCH_BLOCK_SIZE = int(os.environ.get("ALCH_EWALD_BATCH_BLOCK_SIZE", 16))
BATCH_BLOCK_SIZE = BATCH_BLOCK_SIZE if BATCH_BLOCK_SIZE > 0 else 16


###########################################################################################
########################### Helper Functions (always float64) #############################
###########################################################################################


@wp.func
def _ewald_real_space_energy_kernel_compute_energy(
    qi: wp.float64,
    qj: wp.float64,
    distance: wp.float64,
    alpha: wp.float64,
) -> wp.float64:
    """Compute damped Coulomb energy for a single pair.

    Formula:

    .. math::

        E_{ij} = \\frac{1}{2} q_i q_j \\frac{\\text{erfc}(\\alpha r)}{r}

    The 0.5 factor accounts for pair double-counting when iterating
    over all (i,j) pairs.

    Parameters
    ----------
    qi, qj : wp.float64
        Charges of atoms i and j.
    distance : wp.float64
        Distance |r_j - r_i|.
    alpha : wp.float64
        Ewald splitting parameter.

    Returns
    -------
    wp.float64
        Damped Coulomb energy contribution.
    """
    return wp.float64(0.5) * qi * qj * wp_erfc(alpha * distance) / distance


@wp.func
def _ewald_real_space_force_magnitude(
    qi: wp.float64,
    qj: wp.float64,
    distance: wp.float64,
    alpha: wp.float64,
) -> wp.float64:
    """Compute damped Coulomb force magnitude factor for a single pair.

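    Returns the scalar factor that multiplies the separation vector:

    .. math::

        F = \\frac{1}{2} q_i q_j \\left[\\frac{\\text{erfc}(\\alpha r)}{r^3} + \\frac{2\\alpha}{\\sqrt{\\pi}} \\frac{\\exp(-\\alpha^2 r^2)}{r^2}\\right]

    This is the force magnitude divided by r, so multiplying by the separation
    vector gives the force vector. The 1/2 prefactor compensates for each pair
    being visited twice in a full neighbor list.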

    Parameters
    ----------
    qi, qj : wp.float64
        Charges of atoms i and j.
    distance : wp.float64
        Distance |r_j - r_i|.
    alpha : wp.float64
        Ewald splitting parameter.

    Returns
    -------
    wp.float64
        Force magnitude factor.
    """
    two_over_sqrt_pi = wp.float64(2.0 / 1.7724538509055159)

    prefactor = wp.float64(0.5) * qi * qj
    alpha_r = alpha * distance
    alpha_r_squared = alpha_r * alpha_r

    erfc_alpha_r = wp_erfc(alpha_r)
    exp_term = wp.exp(-alpha_r_squared)

    # Force magnitude divided by r (multiply by the separation vector to get the force)
    force_mag_over_r = erfc_alpha_r / (
        distance * distance * distance
    ) + two_over_sqrt_pi * alpha * exp_term / (distance * distance)
    return prefactor * force_mag_over_r


@wp.func
def _ewald_real_space_charge_grad_potential(
    distance: wp.float64,
    alpha: wp.float64,
) -> wp.float64:
    """Compute the damped Coulomb potential for charge gradient.

    Returns (1/2) * erfc(α·r) / r, which when multiplied by q_j gives
    the charge gradient contribution to atom i.

    For pair (i,j) with energy E_ij = (1/2) * q_i * q_j * erfc(α·r) / r:
        ∂E_ij/∂q_i = (1/2) * q_j * erfc(α·r) / r = potential * q_j
        ∂E_ij/∂q_j = (1/2) * q_i * erfc(α·r) / r = potential * q_i

    Parameters
    ----------
    distance : wp.float64
        Distance |r_j - r_i|.
    alpha : wp.float64
        Ewald splitting parameter.

    Returns
    -------
    wp.float64
        Potential factor for charge gradient computation.
    """
    return wp.float64(0.5) * wp_erfc(alpha * distance) / distance


###########################################################################################
########################### Real-Space Kernels (dtype-flexible) ###########################
###########################################################################################


@wp.kernel
def _ewald_real_space_energy_neighbor_matrix_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    neighbor_matrix: wp.array2d(dtype=wp.int32),
    unit_shifts_matrix: wp.array2d(dtype=wp.vec3i),
    mask_value: wp.int32,
    alpha: wp.array(dtype=Any),
    pair_energies: wp.array(dtype=wp.float64),
):
    """Compute real-space Ewald energies using neighbor matrix format.

    Each thread processes one atom and loops over all its neighbors in the
    neighbor matrix. This 1D launch pattern is more efficient than 2D launch
    as it reduces thread divergence and improves memory access patterns.
    Invalid neighbors (marked with mask_value) are skipped. Pairs that are
    too close (less than 1e-8 distance) are also skipped.

    Launch Grid
    -----------
    dim = [N_atoms]

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (1, 3, 3), dtype=wp.mat33f or wp.mat33d
        Unit cell matrix.
    neighbor_matrix : wp.array2d, shape (N, max_neighbors), dtype=wp.int32
        Neighbor indices. Entry [i, k] = j means atom j is the k-th neighbor of i.
        Invalid entries contain mask_value.
    unit_shifts_matrix : wp.array2d, shape (N, max_neighbors), dtype=wp.vec3i
        Periodic image shifts for each neighbor pair.
    mask_value : wp.int32
        Value indicating invalid/padded neighbor entries.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    pair_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Accumulated real-space energy per atom.

    Notes
    -----
    Energy is accumulated in a local register then written once, reducing atomic
    contention. Internal computations use float64 for numerical stability.
    """
    atom_idx = wp.tid()

    # Load atom i data once
    qi = wp.float64(charges[atom_idx])
    pos_i = positions[atom_idx]
    alpha_ = wp.float64(alpha[0])
    cell_t = wp.transpose(cell[0])

    # Accumulate energy in local register
    energy_acc = wp.float64(0.0)
    max_neighbors = neighbor_matrix.shape[1]

    for neighbor_idx in range(max_neighbors):
        j = neighbor_matrix[atom_idx, neighbor_idx]
        if j == mask_value:
            continue

        qj = wp.float64(charges[j])
        pos_j = positions[j]

        # Compute periodic shift (in input precision, then cast)
        shift_vec = unit_shifts_matrix[atom_idx, neighbor_idx]
        periodic_shift = cell_t * type(pos_j)(
            type(pos_j[0])(shift_vec[0]),
            type(pos_j[0])(shift_vec[1]),
            type(pos_j[0])(shift_vec[2]),
        )

        separation_vector = pos_j - pos_i + periodic_shift
        distance = wp.float64(wp.length(separation_vector))

        if distance > wp.float64(1e-8):
            energy_acc += _ewald_real_space_energy_kernel_compute_energy(
                qi, qj, distance, alpha_
            )

    # Write accumulated energy once
    wp.atomic_add(pair_energies, atom_idx, energy_acc)


@wp.kernel
def _ewald_real_space_energy_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    idx_j: wp.array(dtype=wp.int32),
    neighbor_ptr: wp.array(dtype=wp.int32),
    unit_shifts: wp.array(dtype=wp.vec3i),
    alpha: wp.array(dtype=Any),
    pair_energies: wp.array(dtype=wp.float64),
):
    """Compute real-space Ewald energies using neighbor list (CSR) format.

    Each thread processes one atom and loops over its neighbors using CSR
    pointers. This 1D launch pattern is more efficient than one-thread-per-pair
    as it reduces atomic contention and allows local accumulation.
    Pairs too close (less than 1e-8 distance) are skipped.

    Launch Grid
    -----------
    dim = [num_atoms]

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (1, 3, 3), dtype=wp.mat33f or wp.mat33d
        Unit cell matrix.
    idx_j : wp.array, shape (M,), dtype=wp.int32
        Target atom indices for each pair (flattened CSR data).
    neighbor_ptr : wp.array, shape (N+1,), dtype=wp.int32
        CSR row pointers. neighbor_ptr[i] to neighbor_ptr[i+1] gives the range
        of neighbors for atom i in idx_j.
    unit_shifts : wp.array, shape (M,), dtype=wp.vec3i
        Periodic image shifts for each pair.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    pair_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Accumulated real-space energy per atom.

    Notes
    -----
    Energy is accumulated in a local register then written once, reducing
    atomic contention. Internal computations use float64 for numerical stability.
    """
    atom_i = wp.tid()

    # Load atom i data once
    qi = wp.float64(charges[atom_i])
    pos_i = positions[atom_i]
    alpha_ = wp.float64(alpha[0])
    cell_t = wp.transpose(cell[0])

    # Accumulate energy in local register
    energy_acc = wp.float64(0.0)

    # Iterate over neighbors using CSR pointers
    j_range_start = neighbor_ptr[atom_i]
    j_range_end = neighbor_ptr[atom_i + 1]

    for edge_idx in range(j_range_start, j_range_end):
        j = idx_j[edge_idx]

        qj = wp.float64(charges[j])
        pos_j = positions[j]

        # Compute periodic shift
        shift_vec = unit_shifts[edge_idx]
        periodic_shift = cell_t * type(pos_i)(
            type(pos_i[0])(shift_vec[0]),
            type(pos_i[0])(shift_vec[1]),
            type(pos_i[0])(shift_vec[2]),
        )

        # Compute separation vector
        separation_vector = pos_j - pos_i + periodic_shift
        distance = wp.float64(wp.length(separation_vector))

        # Compute real-space energy with erfc damping
        if distance > wp.float64(1e-8):
            energy_acc += _ewald_real_space_energy_kernel_compute_energy(
                qi, qj, distance, alpha_
            )

    # Write accumulated energy once
    wp.atomic_add(pair_energies, atom_i, energy_acc)


@wp.kernel
def _ewald_real_space_energy_forces_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    idx_j: wp.array(dtype=wp.int32),
    neighbor_ptr: wp.array(dtype=wp.int32),
    unit_shifts: wp.array(dtype=wp.vec3i),
    alpha: wp.array(dtype=Any),
    compute_virial: bool,
    pair_energies: wp.array(dtype=wp.float64),
    atomic_forces: wp.array(dtype=Any),
    virial: wp.array(dtype=Any),
):
    """Compute real-space Ewald energy and forces using neighbor list (CSR) format.

    Each thread processes one atom and loops over its neighbors using CSR
    pointers. Energy and force on atom i are accumulated locally. Force on
    atom j uses atomic_add. Pairs too close (less than 1e-8 distance) are skipped.

    Launch Grid
    -----------
    dim = [num_atoms]

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (1, 3, 3), dtype=wp.mat33f or wp.mat33d
        Unit cell matrix.
    idx_j : wp.array, shape (M,), dtype=wp.int32
        Target atom indices for each pair (flattened CSR data).
    neighbor_ptr : wp.array, shape (N+1,), dtype=wp.int32
        CSR row pointers. neighbor_ptr[i] to neighbor_ptr[i+1] gives the range
        of neighbors for atom i in idx_j.
    unit_shifts : wp.array, shape (M,), dtype=wp.vec3i
        Periodic image shifts for each pair.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    compute_virial : bool
        Whether to compute the virial tensor.
    pair_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Accumulated real-space energy per atom.
    atomic_forces : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Accumulated forces per atom (matches positions dtype).
    virial : wp.array, shape (1,), dtype=wp.mat33f or wp.mat33d
        OUTPUT: Accumulated virial tensor per system (if compute_virial=True).

    Notes
    -----
    Energy accumulated locally then written once. Force on atom i accumulated
    locally; force on atom j uses atomic_add (Newton's 3rd law).
    """
    atom_i = wp.tid()

    # Load atom i data once
    qi = wp.float64(charges[atom_i])
    pos_i = positions[atom_i]
    alpha_ = wp.float64(alpha[0])
    cell_t = wp.transpose(cell[0])

    # Accumulators for energy and force on atom i
    energy_acc = wp.float64(0.0)
    force_i_acc = type(pos_i)(
        type(pos_i[0])(0.0), type(pos_i[0])(0.0), type(pos_i[0])(0.0)
    )

    # Initialize virial accumulator
    if compute_virial:
        virial_acc = wp.mat33d()

    # Iterate over neighbors using CSR pointers
    j_range_start = neighbor_ptr[atom_i]
    j_range_end = neighbor_ptr[atom_i + 1]
    for edge_idx in range(j_range_start, j_range_end):
        j = idx_j[edge_idx]
        qj = wp.float64(charges[j])
        pos_j = positions[j]

        # Apply periodic shift
        shift_vec = unit_shifts[edge_idx]
        periodic_shift = cell_t * type(pos_i)(
            type(pos_i[0])(shift_vec[0]),
            type(pos_i[0])(shift_vec[1]),
            type(pos_i[0])(shift_vec[2]),
        )
        separation_vector = pos_j - pos_i + periodic_shift
        distance = wp.float64(wp.length(separation_vector))

        if distance > wp.float64(1e-8):
            # Compute damped Coulomb energy
            energy_acc += _ewald_real_space_energy_kernel_compute_energy(
                qi, qj, distance, alpha_
            )

            # Compute force magnitude (in float64)
            force_mag = _ewald_real_space_force_magnitude(qi, qj, distance, alpha_)

            # Apply force in positions dtype
            force = type(pos_i)(
                type(pos_i[0])(force_mag) * separation_vector[0],
                type(pos_i[0])(force_mag) * separation_vector[1],
                type(pos_i[0])(force_mag) * separation_vector[2],
            )
            force_i_acc -= force
            wp.atomic_add(atomic_forces, j, force)

            # Accumulate virial if requested
            if compute_virial:
                virial_acc += wp.mat33d(
                    wp.outer(
                        wp.vec3d(
                            wp.float64(separation_vector[0]),
                            wp.float64(separation_vector[1]),
                            wp.float64(separation_vector[2]),
                        ),
                        wp.vec3d(
                            wp.float64(force[0]),
                            wp.float64(force[1]),
                            wp.float64(force[2]),
                        ),
                    )
                )

    # Write accumulated values once
    wp.atomic_add(pair_energies, atom_i, energy_acc)
    wp.atomic_add(atomic_forces, atom_i, force_i_acc)

    # Virial contribution: force already includes 0.5 factor for full-NL pair counting,
    # so virial_acc = sum_{i<j} outer(r_ij, F_pair) = W = -dE/dε (no extra scaling).
    if compute_virial:
        wp.atomic_add(virial, 0, type(cell_t)(virial_acc))


@wp.kernel
def _ewald_real_space_energy_forces_neighbor_matrix_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    neighbor_matrix: wp.array2d(dtype=wp.int32),
    unit_shifts_matrix: wp.array2d(dtype=wp.vec3i),
    mask_value: wp.int32,
    alpha: wp.array(dtype=Any),
    compute_virial: bool,
    pair_energies: wp.array(dtype=wp.float64),
    atomic_forces: wp.array(dtype=Any),
    virial: wp.array(dtype=Any),
):
    """Compute real-space Ewald energy and forces using neighbor matrix format.

    Each thread processes one atom and loops over all its neighbors. This 1D
    launch pattern is more efficient than 2D launch as it reduces thread
    divergence and improves memory access patterns. Energy is accumulated in
    a local register and written once. Force on atom j uses atomic_add.
    Pairs too close (less than 1e-8 distance) or invalid are skipped.

    Launch Grid
    -----------
    dim = [N_atoms]

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (1, 3, 3), dtype=wp.mat33f or wp.mat33d
        Unit cell matrix.
    neighbor_matrix : wp.array2d, shape (N, max_neighbors), dtype=wp.int32
        Neighbor indices. Invalid entries contain mask_value.
    unit_shifts_matrix : wp.array2d, shape (N, max_neighbors), dtype=wp.vec3i
        Periodic image shifts for each neighbor pair.
    mask_value : wp.int32
        Value indicating invalid/padded neighbor entries.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    compute_virial : bool
        Whether to compute the virial tensor.
    pair_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Accumulated real-space energy per atom.
    atomic_forces : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Accumulated forces per atom (matches positions dtype).
    virial : wp.array, shape (1,), dtype=wp.mat33f or wp.mat33d
        OUTPUT: Accumulated virial tensor per system (if compute_virial=True).

    Notes
    -----
    Energy accumulated locally then written once. Forces on atom i accumulated
    locally; forces on atom j use atomic_add (Newton's 3rd law).
    """
    atom_idx = wp.tid()

    # Load atom i data once
    qi = wp.float64(charges[atom_idx])
    pos_i = positions[atom_idx]
    alpha_ = wp.float64(alpha[0])
    cell_t = wp.transpose(cell[0])

    # Accumulators for energy and force on atom i
    energy_acc = wp.float64(0.0)
    force_i_acc = type(pos_i)(
        type(pos_i[0])(0.0), type(pos_i[0])(0.0), type(pos_i[0])(0.0)
    )
    max_neighbors = neighbor_matrix.shape[1]

    # Initialize virial accumulator
    if compute_virial:
        virial_acc = wp.mat33d()

    for neighbor_idx in range(max_neighbors):
        j = neighbor_matrix[atom_idx, neighbor_idx]
        if j == mask_value:
            continue

        qj = wp.float64(charges[j])
        pos_j = positions[j]

        shift_vec = unit_shifts_matrix[atom_idx, neighbor_idx]
        periodic_shift = cell_t * type(pos_j)(
            type(pos_j[0])(shift_vec[0]),
            type(pos_j[0])(shift_vec[1]),
            type(pos_j[0])(shift_vec[2]),
        )
        separation_vector = pos_j - pos_i + periodic_shift
        distance = wp.float64(wp.length(separation_vector))

        if distance > wp.float64(1e-8):
            energy_acc += _ewald_real_space_energy_kernel_compute_energy(
                qi, qj, distance, alpha_
            )

            force_mag = _ewald_real_space_force_magnitude(qi, qj, distance, alpha_)
            force = type(pos_i)(
                type(pos_i[0])(force_mag) * separation_vector[0],
                type(pos_i[0])(force_mag) * separation_vector[1],
                type(pos_i[0])(force_mag) * separation_vector[2],
            )
            force_i_acc -= force
            wp.atomic_add(atomic_forces, j, force)

            # Accumulate virial if requested
            if compute_virial:
                virial_acc += wp.mat33d(
                    wp.outer(
                        wp.vec3d(
                            wp.float64(separation_vector[0]),
                            wp.float64(separation_vector[1]),
                            wp.float64(separation_vector[2]),
                        ),
                        wp.vec3d(
                            wp.float64(force[0]),
                            wp.float64(force[1]),
                            wp.float64(force[2]),
                        ),
                    )
                )

    # Write accumulated values once
    wp.atomic_add(pair_energies, atom_idx, energy_acc)
    wp.atomic_add(atomic_forces, atom_idx, force_i_acc)

    # Virial contribution: force already includes 0.5 factor for full-NL pair counting,
    # so virial_acc = sum_{i<j} outer(r_ij, F_pair) = W = -dE/dε (no extra scaling).
    if compute_virial:
        wp.atomic_add(virial, 0, type(cell_t)(virial_acc))


###########################################################################################
#################### Real-Space Kernels with Charge Gradients #############################
###########################################################################################


@wp.kernel
def _ewald_real_space_energy_forces_charge_grad_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    idx_j: wp.array(dtype=wp.int32),
    neighbor_ptr: wp.array(dtype=wp.int32),
    unit_shifts: wp.array(dtype=wp.vec3i),
    alpha: wp.array(dtype=Any),
    compute_virial: bool,
    pair_energies: wp.array(dtype=wp.float64),
    atomic_forces: wp.array(dtype=Any),
    charge_gradients: wp.array(dtype=wp.float64),
    virial: wp.array(dtype=Any),
):
    """Compute real-space Ewald energy, forces, AND charge gradients (neighbor list CSR).

    Each thread processes one atom and loops over its neighbors using CSR pointers.
    Energy, force, and charge gradient for atom i are accumulated locally. Forces
    and charge gradients on atom j use atomic_add. Pairs too close (less than 1e-8
    distance) are skipped.

    Launch Grid
    -----------
    dim = [num_atoms]

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (1, 3, 3), dtype=wp.mat33f or wp.mat33d
        Unit cell matrix.
    idx_j : wp.array, shape (M,), dtype=wp.int32
        Target atom indices for each pair (flattened CSR data).
    neighbor_ptr : wp.array, shape (N+1,), dtype=wp.int32
        CSR row pointers.
    unit_shifts : wp.array, shape (M,), dtype=wp.vec3i
        Periodic image shifts for each pair.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    compute_virial : bool
        Whether to compute the virial tensor.
    pair_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Accumulated real-space energy per atom.
    atomic_forces : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Accumulated forces per atom (matches positions dtype).
    charge_gradients : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Accumulated charge gradients dE/dq per atom.
    virial : wp.array, shape (1,), dtype=wp.mat33f or wp.mat33d
        OUTPUT: Accumulated virial tensor per system (if compute_virial=True).

    Notes
    -----
    Energy, force, charge gradient on atom i accumulated locally then written once.
    Forces and charge gradients on atom j use atomic_add.
    """
    atom_i = wp.tid()

    # Load atom i data once
    qi = wp.float64(charges[atom_i])
    pos_i = positions[atom_i]
    alpha_ = wp.float64(alpha[0])
    cell_t = wp.transpose(cell[0])

    # Accumulators for energy, force, and charge gradient on atom i
    energy_acc = wp.float64(0.0)
    force_i_acc = type(pos_i)(
        type(pos_i[0])(0.0), type(pos_i[0])(0.0), type(pos_i[0])(0.0)
    )
    cg_i_acc = wp.float64(0.0)

    # Initialize virial accumulator
    if compute_virial:
        virial_acc = wp.mat33d()

    # Iterate over neighbors using CSR pointers
    j_range_start = neighbor_ptr[atom_i]
    j_range_end = neighbor_ptr[atom_i + 1]

    for edge_idx in range(j_range_start, j_range_end):
        j = idx_j[edge_idx]

        qj = wp.float64(charges[j])
        pos_j = positions[j]

        # Apply periodic shift
        shift_vec = unit_shifts[edge_idx]
        periodic_shift = cell_t * type(pos_i)(
            type(pos_i[0])(shift_vec[0]),
            type(pos_i[0])(shift_vec[1]),
            type(pos_i[0])(shift_vec[2]),
        )
        separation_vector = pos_j - pos_i + periodic_shift
        distance = wp.float64(wp.length(separation_vector))

        if distance > wp.float64(1e-8):
            # Compute damped Coulomb energy
            energy_acc += _ewald_real_space_energy_kernel_compute_energy(
                qi, qj, distance, alpha_
            )

            # Compute force magnitude (in float64)
            force_mag = _ewald_real_space_force_magnitude(qi, qj, distance, alpha_)

            # Apply force in positions dtype
            force = type(pos_i)(
                type(pos_i[0])(force_mag) * separation_vector[0],
                type(pos_i[0])(force_mag) * separation_vector[1],
                type(pos_i[0])(force_mag) * separation_vector[2],
            )
            force_i_acc -= force
            wp.atomic_add(atomic_forces, j, force)

            # Compute charge gradients
            potential = _ewald_real_space_charge_grad_potential(distance, alpha_)
            cg_i_acc += qj * potential
            cg_j = qi * potential
            wp.atomic_add(charge_gradients, j, cg_j)

            # Accumulate virial if requested
            if compute_virial:
                virial_acc += wp.mat33d(
                    wp.outer(
                        wp.vec3d(
                            wp.float64(separation_vector[0]),
                            wp.float64(separation_vector[1]),
                            wp.float64(separation_vector[2]),
                        ),
                        wp.vec3d(
                            wp.float64(force[0]),
                            wp.float64(force[1]),
                            wp.float64(force[2]),
                        ),
                    )
                )

    # Write accumulated values once
    wp.atomic_add(pair_energies, atom_i, energy_acc)
    wp.atomic_add(atomic_forces, atom_i, force_i_acc)
    wp.atomic_add(charge_gradients, atom_i, cg_i_acc)

    # Virial contribution: force already includes 0.5 factor for full-NL pair counting,
    # so virial_acc = sum_{i<j} outer(r_ij, F_pair) = W = -dE/dε (no extra scaling).
    if compute_virial:
        wp.atomic_add(virial, 0, type(cell_t)(virial_acc))


@wp.kernel
def _ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    neighbor_matrix: wp.array2d(dtype=wp.int32),
    unit_shifts_matrix: wp.array2d(dtype=wp.vec3i),
    mask_value: wp.int32,
    alpha: wp.array(dtype=Any),
    compute_virial: bool,
    pair_energies: wp.array(dtype=wp.float64),
    atomic_forces: wp.array(dtype=Any),
    charge_gradients: wp.array(dtype=wp.float64),
    virial: wp.array(dtype=Any),
):
    """Compute real-space Ewald energy, forces, AND charge gradients (neighbor matrix).

    Each thread processes one atom and loops over all its neighbors. This 1D
    launch pattern is more efficient than 2D launch. Energy, force, and charge
    gradient for atom i are accumulated locally and written once. Forces and
    charge gradients on atom j use atomic_add.

    Launch Grid
    -----------
    dim = [N_atoms]

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (1, 3, 3), dtype=wp.mat33f or wp.mat33d
        Unit cell matrix.
    neighbor_matrix : wp.array2d, shape (N, max_neighbors), dtype=wp.int32
        Neighbor indices. Invalid entries contain mask_value.
    unit_shifts_matrix : wp.array2d, shape (N, max_neighbors), dtype=wp.vec3i
        Periodic image shifts for each neighbor pair.
    mask_value : wp.int32
        Value indicating invalid/padded neighbor entries.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    compute_virial : bool
        Whether to compute the virial tensor.
    pair_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Accumulated real-space energy per atom.
    atomic_forces : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Accumulated forces per atom (matches positions dtype).
    charge_gradients : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Accumulated charge gradients dE/dq per atom.
    virial : wp.array, shape (1,), dtype=wp.mat33f or wp.mat33d
        OUTPUT: Accumulated virial tensor per system (if compute_virial=True).
    """
    atom_idx = wp.tid()

    # Load atom i data once
    qi = wp.float64(charges[atom_idx])
    pos_i = positions[atom_idx]
    alpha_ = wp.float64(alpha[0])
    cell_t = wp.transpose(cell[0])

    # Accumulators for energy, force, and charge gradient on atom i
    energy_acc = wp.float64(0.0)
    force_i_acc = type(pos_i)(
        type(pos_i[0])(0.0), type(pos_i[0])(0.0), type(pos_i[0])(0.0)
    )
    cg_i_acc = wp.float64(0.0)
    max_neighbors = neighbor_matrix.shape[1]

    # Initialize virial accumulator
    if compute_virial:
        virial_acc = wp.mat33d()

    for neighbor_idx in range(max_neighbors):
        j = neighbor_matrix[atom_idx, neighbor_idx]
        if j == mask_value:
            continue

        qj = wp.float64(charges[j])
        pos_j = positions[j]

        shift_vec = unit_shifts_matrix[atom_idx, neighbor_idx]
        periodic_shift = cell_t * type(pos_j)(
            type(pos_j[0])(shift_vec[0]),
            type(pos_j[0])(shift_vec[1]),
            type(pos_j[0])(shift_vec[2]),
        )
        separation_vector = pos_j - pos_i + periodic_shift
        distance = wp.float64(wp.length(separation_vector))

        if distance > wp.float64(1e-8):
            energy_acc += _ewald_real_space_energy_kernel_compute_energy(
                qi, qj, distance, alpha_
            )

            force_mag = _ewald_real_space_force_magnitude(qi, qj, distance, alpha_)
            force = type(pos_i)(
                type(pos_i[0])(force_mag) * separation_vector[0],
                type(pos_i[0])(force_mag) * separation_vector[1],
                type(pos_i[0])(force_mag) * separation_vector[2],
            )
            force_i_acc -= force
            wp.atomic_add(atomic_forces, j, force)

            # Compute charge gradients
            potential = _ewald_real_space_charge_grad_potential(distance, alpha_)
            cg_i_acc += qj * potential
            cg_j = qi * potential
            wp.atomic_add(charge_gradients, j, cg_j)

            # Accumulate virial if requested
            if compute_virial:
                virial_acc += wp.mat33d(
                    wp.outer(
                        wp.vec3d(
                            wp.float64(separation_vector[0]),
                            wp.float64(separation_vector[1]),
                            wp.float64(separation_vector[2]),
                        ),
                        wp.vec3d(
                            wp.float64(force[0]),
                            wp.float64(force[1]),
                            wp.float64(force[2]),
                        ),
                    )
                )

    # Write accumulated values once
    wp.atomic_add(pair_energies, atom_idx, energy_acc)
    wp.atomic_add(atomic_forces, atom_idx, force_i_acc)
    wp.atomic_add(charge_gradients, atom_idx, cg_i_acc)

    # Virial contribution: force already includes 0.5 factor for full-NL pair counting,
    # so virial_acc = sum_{i<j} outer(r_ij, F_pair) = W = -dE/dε (no extra scaling).
    if compute_virial:
        wp.atomic_add(virial, 0, type(cell_t)(virial_acc))


###########################################################################################
########################### Batch Real-Space Kernels ######################################
###########################################################################################


@wp.kernel
def _batch_ewald_real_space_energy_neighbor_matrix_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    batch_id: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array2d(dtype=wp.int32),
    unit_shifts_matrix: wp.array2d(dtype=wp.vec3i),
    mask_value: wp.int32,
    alpha: wp.array(dtype=Any),
    pair_energies: wp.array(dtype=wp.float64),
):
    """Compute real-space Ewald energies for batched systems (neighbor matrix).

    Each thread processes one atom and loops over its neighbors. This 1D launch
    pattern is more efficient than a 2D launch. Per-system cell and alpha are
    looked up using batch_id. Energy is accumulated locally and written once.

    Launch Grid
    -----------
    dim = [N_total]

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates for all systems concatenated.
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges for all systems concatenated.
    cell : wp.array, shape (B, 3, 3), dtype=wp.mat33f or wp.mat33d
        Unit cell matrices for each system.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1).
    neighbor_matrix : wp.array2d, shape (N_total, max_neighbors), dtype=wp.int32
        Neighbor indices. Invalid entries contain mask_value.
    unit_shifts_matrix : wp.array2d, shape (N_total, max_neighbors), dtype=wp.vec3i
        Periodic image shifts for each neighbor pair.
    mask_value : wp.int32
        Value indicating invalid/padded neighbor entries.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    pair_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Accumulated real-space energy per atom.
    """
    atom_idx = wp.tid()

    # Load atom i data once
    qi = wp.float64(charges[atom_idx])
    pos_i = positions[atom_idx]
    system_id = batch_id[atom_idx]
    system_cell = cell[system_id]
    system_alpha = wp.float64(alpha[system_id])
    cell_t = wp.transpose(system_cell)

    # Accumulate energy in local register
    energy_acc = wp.float64(0.0)
    max_neighbors = neighbor_matrix.shape[1]

    for neighbor_idx in range(max_neighbors):
        j = neighbor_matrix[atom_idx, neighbor_idx]
        if j == mask_value:
            continue

        qj = wp.float64(charges[j])
        pos_j = positions[j]

        shift_vec = unit_shifts_matrix[atom_idx, neighbor_idx]
        periodic_shift = cell_t * type(pos_j)(
            type(pos_j[0])(shift_vec[0]),
            type(pos_j[0])(shift_vec[1]),
            type(pos_j[0])(shift_vec[2]),
        )
        separation_vector = pos_j - pos_i + periodic_shift
        distance = wp.float64(wp.length(separation_vector))

        if distance > wp.float64(1e-8):
            energy_acc += _ewald_real_space_energy_kernel_compute_energy(
                qi, qj, distance, system_alpha
            )

    # Write accumulated energy once
    wp.atomic_add(pair_energies, atom_idx, energy_acc)


@wp.kernel
def _batch_ewald_real_space_energy_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    batch_id: wp.array(dtype=wp.int32),
    idx_j: wp.array(dtype=wp.int32),
    neighbor_ptr: wp.array(dtype=wp.int32),
    unit_shifts: wp.array(dtype=wp.vec3i),
    alpha: wp.array(dtype=Any),
    pair_energies: wp.array(dtype=wp.float64),
):
    """Compute real-space Ewald energies for batched systems (neighbor list CSR).

    Each thread processes one atom and loops over its neighbors using CSR
    pointers. Per-system cell and alpha are looked up using batch_id.
    Energy is accumulated locally and written once.

    Launch Grid
    -----------
    dim = [N_total]

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates for all systems concatenated.
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges for all systems concatenated.
    cell : wp.array, shape (B, 3, 3), dtype=wp.mat33f or wp.mat33d
        Unit cell matrices for each system.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1).
    idx_j : wp.array, shape (M,), dtype=wp.int32
        Target atom indices for each pair (flattened CSR data).
    neighbor_ptr : wp.array, shape (N_total+1,), dtype=wp.int32
        CSR row pointers.
    unit_shifts : wp.array, shape (M,), dtype=wp.vec3i
        Periodic image shifts for each pair.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    pair_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Accumulated real-space energy per atom.
    """
    atom_i = wp.tid()

    # Load atom i data once
    qi = wp.float64(charges[atom_i])
    pos_i = positions[atom_i]
    system_id = batch_id[atom_i]
    system_cell = cell[system_id]
    system_alpha = wp.float64(alpha[system_id])
    cell_t = wp.transpose(system_cell)

    # Accumulate energy in local register
    energy_acc = wp.float64(0.0)

    # Iterate over neighbors using CSR pointers
    j_range_start = neighbor_ptr[atom_i]
    j_range_end = neighbor_ptr[atom_i + 1]

    for edge_idx in range(j_range_start, j_range_end):
        j = idx_j[edge_idx]

        qj = wp.float64(charges[j])
        pos_j = positions[j]

        # Convert unit shifts to Cartesian using system cell
        shift_vec = unit_shifts[edge_idx]
        periodic_shift = cell_t * type(pos_i)(
            type(pos_i[0])(shift_vec[0]),
            type(pos_i[0])(shift_vec[1]),
            type(pos_i[0])(shift_vec[2]),
        )

        # Compute separation vector
        separation_vector = pos_j - pos_i + periodic_shift
        distance = wp.float64(wp.length(separation_vector))

        # Compute real-space energy with erfc damping
        if distance > wp.float64(1e-8):
            energy_acc += _ewald_real_space_energy_kernel_compute_energy(
                qi, qj, distance, system_alpha
            )

    # Write accumulated energy once
    wp.atomic_add(pair_energies, atom_i, energy_acc)
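

# Reference sketch (illustrative only, not part of the kernel API): a plain
# NumPy/CPU version of the CSR traversal in
# _batch_ewald_real_space_energy_kernel above, which may be useful for checking
# kernel output on small systems. The function name is hypothetical, and the
# pair term 0.5 * q_i * q_j * erfc(alpha * r) / r is an assumption taken from
# the E_real formula in the module docstring; the actual
# _ewald_real_space_energy_kernel_compute_energy helper is defined elsewhere.
import math

import numpy as np


def _reference_batch_real_space_energy_csr(
    positions, charges, cell, batch_id, idx_j, neighbor_ptr, unit_shifts, alpha
):
    """Per-atom real-space energies via a Python loop over CSR rows (sketch)."""
    positions = np.asarray(positions, dtype=np.float64)
    unit_shifts = np.asarray(unit_shifts, dtype=np.float64)
    cell = np.asarray(cell, dtype=np.float64)
    pair_energies = np.zeros(len(positions), dtype=np.float64)
    for i in range(len(positions)):
        system_id = batch_id[i]
        # Same convention as wp.transpose(cell[system_id]) in the kernel
        cell_t = cell[system_id].T
        # Row i of the CSR structure: edges neighbor_ptr[i] .. neighbor_ptr[i + 1] - 1
        for edge in range(neighbor_ptr[i], neighbor_ptr[i + 1]):
            j = idx_j[edge]
            r_vec = positions[j] - positions[i] + cell_t @ unit_shifts[edge]
            r = float(np.linalg.norm(r_vec))
            if r > 1e-8:  # skip coincident points, as the kernel does
                damped = math.erfc(alpha[system_id] * r) / r
                pair_energies[i] += 0.5 * charges[i] * charges[j] * damped
    return pair_energies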


@wp.kernel
def _batch_ewald_real_space_energy_forces_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    batch_id: wp.array(dtype=wp.int32),
    idx_j: wp.array(dtype=wp.int32),
    neighbor_ptr: wp.array(dtype=wp.int32),
    unit_shifts: wp.array(dtype=wp.vec3i),
    alpha: wp.array(dtype=Any),
    compute_virial: bool,
    pair_energies: wp.array(dtype=wp.float64),
    atomic_forces: wp.array(dtype=Any),
    virial: wp.array(dtype=Any),
):
    """Compute real-space Ewald energy and forces for batched systems (neighbor list CSR).

    Each thread processes one atom and loops over its neighbors using CSR
    pointers. Energy and force on atom i are accumulated locally. Forces on
    atom j use atomic_add.

    Launch Grid
    -----------
    dim = [N_total]

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates for all systems concatenated.
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges for all systems concatenated.
    cell : wp.array, shape (B, 3, 3), dtype=wp.mat33f or wp.mat33d
        Unit cell matrices for each system.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1).
    idx_j : wp.array, shape (M,), dtype=wp.int32
        Target atom indices for each pair (flattened CSR data).
    neighbor_ptr : wp.array, shape (N_total+1,), dtype=wp.int32
        CSR row pointers.
    unit_shifts : wp.array, shape (M,), dtype=wp.vec3i
        Periodic image shifts for each pair.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    compute_virial : bool
        Whether to compute the virial tensor.
    pair_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Accumulated real-space energy per atom.
    atomic_forces : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Accumulated forces per atom.
    virial : wp.array, shape (B,), dtype=wp.mat33f or wp.mat33d
        OUTPUT: Accumulated virial tensor per system (if compute_virial=True).
    """
    atom_i = wp.tid()

    # Load atom i data once
    qi = wp.float64(charges[atom_i])
    pos_i = positions[atom_i]
    system_id = batch_id[atom_i]
    system_cell = cell[system_id]
    system_alpha = wp.float64(alpha[system_id])
    cell_t = wp.transpose(system_cell)

    # Accumulators for energy and force on atom i
    energy_acc = wp.float64(0.0)
    force_i_acc = type(pos_i)(
        type(pos_i[0])(0.0), type(pos_i[0])(0.0), type(pos_i[0])(0.0)
    )

    # Initialize virial accumulator
    if compute_virial:
        virial_acc = wp.mat33d()

    # Iterate over neighbors using CSR pointers
    j_range_start = neighbor_ptr[atom_i]
    j_range_end = neighbor_ptr[atom_i + 1]

    for edge_idx in range(j_range_start, j_range_end):
        j = idx_j[edge_idx]

        qj = wp.float64(charges[j])
        pos_j = positions[j]

        # Apply periodic shift using system cell
        shift_vec = unit_shifts[edge_idx]
        periodic_shift = cell_t * type(pos_i)(
            type(pos_i[0])(shift_vec[0]),
            type(pos_i[0])(shift_vec[1]),
            type(pos_i[0])(shift_vec[2]),
        )
        separation_vector = pos_j - pos_i + periodic_shift
        distance = wp.float64(wp.length(separation_vector))

        if distance > wp.float64(1e-8):
            # Compute damped Coulomb energy
            energy_acc += _ewald_real_space_energy_kernel_compute_energy(
                qi, qj, distance, system_alpha
            )

            force_mag = _ewald_real_space_force_magnitude(
                qi, qj, distance, system_alpha
            )
            force = type(pos_i)(
                type(pos_i[0])(force_mag) * separation_vector[0],
                type(pos_i[0])(force_mag) * separation_vector[1],
                type(pos_i[0])(force_mag) * separation_vector[2],
            )
            force_i_acc -= force
            wp.atomic_add(atomic_forces, j, force)

            # Accumulate virial if requested
            if compute_virial:
                virial_acc += wp.mat33d(
                    wp.outer(
                        wp.vec3d(
                            wp.float64(separation_vector[0]),
                            wp.float64(separation_vector[1]),
                            wp.float64(separation_vector[2]),
                        ),
                        wp.vec3d(
                            wp.float64(force[0]),
                            wp.float64(force[1]),
                            wp.float64(force[2]),
                        ),
                    )
                )

    # Write accumulated values once
    wp.atomic_add(pair_energies, atom_i, energy_acc)
    wp.atomic_add(atomic_forces, atom_i, force_i_acc)

    # Virial contribution: force already includes 0.5 factor for full-NL pair counting,
    # so virial_acc = sum_{i<j} outer(r_ij, F_pair) = W = -dE/dε (no extra scaling).
    if compute_virial:
        wp.atomic_add(virial, system_id, type(cell_t)(virial_acc))


@wp.kernel
def _batch_ewald_real_space_energy_forces_neighbor_matrix_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    batch_id: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array2d(dtype=wp.int32),
    unit_shifts_matrix: wp.array2d(dtype=wp.vec3i),
    mask_value: wp.int32,
    alpha: wp.array(dtype=Any),
    compute_virial: bool,
    pair_energies: wp.array(dtype=wp.float64),
    atomic_forces: wp.array(dtype=Any),
    virial: wp.array(dtype=Any),
):
    """Compute real-space Ewald energy and forces for batched systems (neighbor matrix).

    Each thread processes one atom and loops over its neighbors. This 1D launch
    pattern is more efficient than a 2D launch. Energy and force on atom i are
    accumulated locally. Forces on atom j use atomic_add.

    Launch Grid
    -----------
    dim = [N_total]

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates for all systems concatenated.
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges for all systems concatenated.
    cell : wp.array, shape (B, 3, 3), dtype=wp.mat33f or wp.mat33d
        Unit cell matrices for each system.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1).
    neighbor_matrix : wp.array2d, shape (N_total, max_neighbors), dtype=wp.int32
        Neighbor indices. Invalid entries contain mask_value.
    unit_shifts_matrix : wp.array2d, shape (N_total, max_neighbors), dtype=wp.vec3i
        Periodic image shifts for each neighbor pair.
    mask_value : wp.int32
        Value indicating invalid/padded neighbor entries.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    compute_virial : bool
        Whether to compute the virial tensor.
    pair_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Accumulated real-space energy per atom.
    atomic_forces : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Accumulated forces per atom.
    virial : wp.array, shape (B,), dtype=wp.mat33f or wp.mat33d
        OUTPUT: Accumulated virial tensor per system (if compute_virial=True).
    """
    atom_idx = wp.tid()

    # Load atom i data once
    qi = wp.float64(charges[atom_idx])
    pos_i = positions[atom_idx]
    system_id = batch_id[atom_idx]
    system_cell = cell[system_id]
    system_alpha = wp.float64(alpha[system_id])
    cell_t = wp.transpose(system_cell)

    # Accumulators for energy and force on atom i
    energy_acc = wp.float64(0.0)
    force_i_acc = type(pos_i)(
        type(pos_i[0])(0.0), type(pos_i[0])(0.0), type(pos_i[0])(0.0)
    )
    max_neighbors = neighbor_matrix.shape[1]

    # Initialize virial accumulator
    if compute_virial:
        virial_acc = wp.mat33d()

    for neighbor_idx in range(max_neighbors):
        j = neighbor_matrix[atom_idx, neighbor_idx]
        if j == mask_value:
            continue

        qj = wp.float64(charges[j])
        pos_j = positions[j]

        shift_vec = unit_shifts_matrix[atom_idx, neighbor_idx]
        periodic_shift = cell_t * type(pos_j)(
            type(pos_j[0])(shift_vec[0]),
            type(pos_j[0])(shift_vec[1]),
            type(pos_j[0])(shift_vec[2]),
        )
        separation_vector = pos_j - pos_i + periodic_shift
        distance = wp.float64(wp.length(separation_vector))

        if distance > wp.float64(1e-8):
            energy_acc += _ewald_real_space_energy_kernel_compute_energy(
                qi, qj, distance, system_alpha
            )

            force_mag = _ewald_real_space_force_magnitude(
                qi, qj, distance, system_alpha
            )
            force = type(pos_i)(
                type(pos_i[0])(force_mag) * separation_vector[0],
                type(pos_i[0])(force_mag) * separation_vector[1],
                type(pos_i[0])(force_mag) * separation_vector[2],
            )
            force_i_acc -= force
            wp.atomic_add(atomic_forces, j, force)

            # Accumulate virial if requested
            if compute_virial:
                virial_acc += wp.mat33d(
                    wp.outer(
                        wp.vec3d(
                            wp.float64(separation_vector[0]),
                            wp.float64(separation_vector[1]),
                            wp.float64(separation_vector[2]),
                        ),
                        wp.vec3d(
                            wp.float64(force[0]),
                            wp.float64(force[1]),
                            wp.float64(force[2]),
                        ),
                    )
                )

    # Write accumulated values once
    wp.atomic_add(pair_energies, atom_idx, energy_acc)
    wp.atomic_add(atomic_forces, atom_idx, force_i_acc)

    # Virial contribution: force already includes 0.5 factor for full-NL pair counting,
    # so virial_acc = sum_{i<j} outer(r_ij, F_pair) = W = -dE/dε (no extra scaling).
    if compute_virial:
        wp.atomic_add(virial, system_id, type(cell_t)(virial_acc))


###########################################################################################
#################### Batch Real-Space Kernels with Charge Gradients #######################
###########################################################################################


@wp.kernel
def _batch_ewald_real_space_energy_forces_charge_grad_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    batch_id: wp.array(dtype=wp.int32),
    idx_j: wp.array(dtype=wp.int32),
    neighbor_ptr: wp.array(dtype=wp.int32),
    unit_shifts: wp.array(dtype=wp.vec3i),
    alpha: wp.array(dtype=Any),
    compute_virial: bool,
    pair_energies: wp.array(dtype=wp.float64),
    atomic_forces: wp.array(dtype=Any),
    charge_gradients: wp.array(dtype=wp.float64),
    virial: wp.array(dtype=Any),
):
    """Compute real-space Ewald energy, forces, AND charge gradients (batch, CSR).

    Each thread processes one atom and loops over its neighbors using CSR
    pointers. Energy, force, and charge gradient for atom i are accumulated
    locally. Forces and charge gradients on atom j use atomic_add.

    Launch Grid
    -----------
    dim = [N_total]

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates for all systems concatenated.
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges for all systems concatenated.
    cell : wp.array, shape (B, 3, 3), dtype=wp.mat33f or wp.mat33d
        Unit cell matrices for each system.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1).
    idx_j : wp.array, shape (M,), dtype=wp.int32
        Target atom indices for each pair (flattened CSR data).
    neighbor_ptr : wp.array, shape (N_total+1,), dtype=wp.int32
        CSR row pointers.
    unit_shifts : wp.array, shape (M,), dtype=wp.vec3i
        Periodic image shifts for each pair.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    compute_virial : bool
        Whether to compute the virial tensor.
    pair_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Accumulated real-space energy per atom.
    atomic_forces : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Accumulated forces per atom.
    charge_gradients : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Accumulated charge gradients dE/dq per atom.
    virial : wp.array, shape (B,), dtype=wp.mat33f or wp.mat33d
        OUTPUT: Accumulated virial tensor per system (if compute_virial=True).
    """
    atom_i = wp.tid()

    # Load atom i data once
    qi = wp.float64(charges[atom_i])
    pos_i = positions[atom_i]
    system_id = batch_id[atom_i]
    system_cell = cell[system_id]
    system_alpha = wp.float64(alpha[system_id])
    cell_t = wp.transpose(system_cell)

    # Accumulators for energy, force, and charge gradient on atom i
    energy_acc = wp.float64(0.0)
    force_i_acc = type(pos_i)(
        type(pos_i[0])(0.0), type(pos_i[0])(0.0), type(pos_i[0])(0.0)
    )
    cg_i_acc = wp.float64(0.0)

    # Initialize virial accumulator
    if compute_virial:
        virial_acc = wp.mat33d()

    # Iterate over neighbors using CSR pointers
    j_range_start = neighbor_ptr[atom_i]
    j_range_end = neighbor_ptr[atom_i + 1]

    for edge_idx in range(j_range_start, j_range_end):
        j = idx_j[edge_idx]

        qj = wp.float64(charges[j])
        pos_j = positions[j]

        # Apply periodic shift using system cell
        shift_vec = unit_shifts[edge_idx]
        periodic_shift = cell_t * type(pos_i)(
            type(pos_i[0])(shift_vec[0]),
            type(pos_i[0])(shift_vec[1]),
            type(pos_i[0])(shift_vec[2]),
        )
        separation_vector = pos_j - pos_i + periodic_shift
        distance = wp.float64(wp.length(separation_vector))

        if distance > wp.float64(1e-8):
            # Compute damped Coulomb energy
            energy_acc += _ewald_real_space_energy_kernel_compute_energy(
                qi, qj, distance, system_alpha
            )

            force_mag = _ewald_real_space_force_magnitude(
                qi, qj, distance, system_alpha
            )
            force = type(pos_i)(
                type(pos_i[0])(force_mag) * separation_vector[0],
                type(pos_i[0])(force_mag) * separation_vector[1],
                type(pos_i[0])(force_mag) * separation_vector[2],
            )
            force_i_acc -= force
            wp.atomic_add(atomic_forces, j, force)

            # Compute charge gradients
            potential = _ewald_real_space_charge_grad_potential(distance, system_alpha)
            cg_i_acc += qj * potential
            cg_j = qi * potential
            wp.atomic_add(charge_gradients, j, cg_j)

            # Accumulate virial if requested
            if compute_virial:
                virial_acc += wp.mat33d(
                    wp.outer(
                        wp.vec3d(
                            wp.float64(separation_vector[0]),
                            wp.float64(separation_vector[1]),
                            wp.float64(separation_vector[2]),
                        ),
                        wp.vec3d(
                            wp.float64(force[0]),
                            wp.float64(force[1]),
                            wp.float64(force[2]),
                        ),
                    )
                )

    # Write accumulated values once
    wp.atomic_add(pair_energies, atom_i, energy_acc)
    wp.atomic_add(atomic_forces, atom_i, force_i_acc)
    wp.atomic_add(charge_gradients, atom_i, cg_i_acc)

    # Virial contribution: force already includes 0.5 factor for full-NL pair counting,
    # so virial_acc = sum_{i<j} outer(r_ij, F_pair) = W = -dE/dε (no extra scaling).
    if compute_virial:
        wp.atomic_add(virial, system_id, type(cell_t)(virial_acc))
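

# Reference sketch (illustrative only, not part of the kernel API): a NumPy/CPU
# version of the accumulation pattern in
# _batch_ewald_real_space_energy_forces_charge_grad_kernel above, intended for
# small-system checks. The function name is hypothetical. The closed forms used
# for the force multiplier and the charge-gradient potential are assumptions
# derived from the E_real formula in the module docstring, with the 0.5
# full-neighbor-list factor mentioned in the virial comment; the actual
# _ewald_real_space_force_magnitude and _ewald_real_space_charge_grad_potential
# helpers are defined elsewhere and may differ.
import math

import numpy as np


def _reference_batch_real_space_terms_csr(
    positions, charges, cell, batch_id, idx_j, neighbor_ptr, unit_shifts, alpha
):
    """Return (energies, forces, charge_gradients, virial) from a CSR loop (sketch)."""
    positions = np.asarray(positions, dtype=np.float64)
    unit_shifts = np.asarray(unit_shifts, dtype=np.float64)
    cell = np.asarray(cell, dtype=np.float64)
    n_atoms = len(positions)
    energies = np.zeros(n_atoms)
    forces = np.zeros((n_atoms, 3))
    charge_gradients = np.zeros(n_atoms)
    virial = np.zeros((len(cell), 3, 3))
    for i in range(n_atoms):
        s = batch_id[i]
        a = alpha[s]
        for edge in range(neighbor_ptr[i], neighbor_ptr[i + 1]):
            j = idx_j[edge]
            r_vec = positions[j] - positions[i] + cell[s].T @ unit_shifts[edge]
            r = float(np.linalg.norm(r_vec))
            if r <= 1e-8:  # skip coincident points, as the kernel does
                continue
            erfc_term = math.erfc(a * r)
            gauss_term = 2.0 * a / math.sqrt(math.pi) * math.exp(-((a * r) ** 2))
            half_qq = 0.5 * charges[i] * charges[j]
            energies[i] += half_qq * erfc_term / r
            # Assumed force multiplier: -dE/dr along r_vec, expressed per unit r_vec
            force = half_qq * (erfc_term / r**3 + gauss_term / r**2) * r_vec
            forces[i] -= force
            forces[j] += force
            # Assumed potential: half of erfc(alpha r) / r per directed pair
            potential = 0.5 * erfc_term / r
            charge_gradients[i] += charges[j] * potential
            charge_gradients[j] += charges[i] * potential
            virial[s] += np.outer(r_vec, force)
    return energies, forces, charge_gradients, virial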


@wp.kernel
def _batch_ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    batch_id: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array2d(dtype=wp.int32),
    unit_shifts_matrix: wp.array2d(dtype=wp.vec3i),
    mask_value: wp.int32,
    alpha: wp.array(dtype=Any),
    compute_virial: bool,
    pair_energies: wp.array(dtype=wp.float64),
    atomic_forces: wp.array(dtype=Any),
    charge_gradients: wp.array(dtype=wp.float64),
    virial: wp.array(dtype=Any),
):
    """Compute real-space Ewald energy, forces, AND charge gradients (batch, neighbor matrix).

    Each thread processes one atom and loops over its neighbors. This 1D launch
    pattern is more efficient than a 2D launch. Energy, force, and charge
    gradient for atom i are accumulated locally. Forces and charge gradients on
    atom j use atomic_add.

    Launch Grid
    -----------
    dim = [N_total]

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates for all systems concatenated.
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges for all systems concatenated.
    cell : wp.array, shape (B, 3, 3), dtype=wp.mat33f or wp.mat33d
        Unit cell matrices for each system.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1).
    neighbor_matrix : wp.array2d, shape (N_total, max_neighbors), dtype=wp.int32
        Neighbor indices. Invalid entries contain mask_value.
    unit_shifts_matrix : wp.array2d, shape (N_total, max_neighbors), dtype=wp.vec3i
        Periodic image shifts for each neighbor pair.
    mask_value : wp.int32
        Value indicating invalid/padded neighbor entries.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    compute_virial : bool
        Whether to compute the virial tensor.
    pair_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Accumulated real-space energy per atom.
    atomic_forces : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Accumulated forces per atom.
    charge_gradients : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Accumulated charge gradients dE/dq per atom.
    virial : wp.array, shape (B,), dtype=wp.mat33f or wp.mat33d
        OUTPUT: Accumulated virial tensor per system (if compute_virial=True).
    """
    atom_idx = wp.tid()

    # Load atom i data once
    qi = wp.float64(charges[atom_idx])
    pos_i = positions[atom_idx]
    system_id = batch_id[atom_idx]
    system_cell = cell[system_id]
    system_alpha = wp.float64(alpha[system_id])
    cell_t = wp.transpose(system_cell)

    # Accumulators for energy, force, and charge gradient on atom i
    energy_acc = wp.float64(0.0)
    force_i_acc = type(pos_i)(
        type(pos_i[0])(0.0), type(pos_i[0])(0.0), type(pos_i[0])(0.0)
    )
    cg_i_acc = wp.float64(0.0)
    max_neighbors = neighbor_matrix.shape[1]

    # Initialize virial accumulator
    if compute_virial:
        virial_acc = wp.mat33d()

    for neighbor_idx in range(max_neighbors):
        j = neighbor_matrix[atom_idx, neighbor_idx]
        if j == mask_value:
            continue

        qj = wp.float64(charges[j])
        pos_j = positions[j]

        shift_vec = unit_shifts_matrix[atom_idx, neighbor_idx]
        periodic_shift = cell_t * type(pos_j)(
            type(pos_j[0])(shift_vec[0]),
            type(pos_j[0])(shift_vec[1]),
            type(pos_j[0])(shift_vec[2]),
        )
        separation_vector = pos_j - pos_i + periodic_shift
        distance = wp.float64(wp.length(separation_vector))

        if distance > wp.float64(1e-8):
            energy_acc += _ewald_real_space_energy_kernel_compute_energy(
                qi, qj, distance, system_alpha
            )

            force_mag = _ewald_real_space_force_magnitude(
                qi, qj, distance, system_alpha
            )
            force = type(pos_i)(
                type(pos_i[0])(force_mag) * separation_vector[0],
                type(pos_i[0])(force_mag) * separation_vector[1],
                type(pos_i[0])(force_mag) * separation_vector[2],
            )
            force_i_acc -= force
            wp.atomic_add(atomic_forces, j, force)

            # Compute charge gradients
            potential = _ewald_real_space_charge_grad_potential(distance, system_alpha)
            cg_i_acc += qj * potential
            cg_j = qi * potential
            wp.atomic_add(charge_gradients, j, cg_j)

            # Accumulate virial if requested
            if compute_virial:
                virial_acc += wp.mat33d(
                    wp.outer(
                        wp.vec3d(
                            wp.float64(separation_vector[0]),
                            wp.float64(separation_vector[1]),
                            wp.float64(separation_vector[2]),
                        ),
                        wp.vec3d(
                            wp.float64(force[0]),
                            wp.float64(force[1]),
                            wp.float64(force[2]),
                        ),
                    )
                )

    # Write accumulated values once
    wp.atomic_add(pair_energies, atom_idx, energy_acc)
    wp.atomic_add(atomic_forces, atom_idx, force_i_acc)
    wp.atomic_add(charge_gradients, atom_idx, cg_i_acc)

    # Virial contribution: force already includes 0.5 factor for full-NL pair counting,
    # so virial_acc = sum_{i<j} outer(r_ij, F_pair) = W = -dE/dε (no extra scaling).
    if compute_virial:
        wp.atomic_add(virial, system_id, type(cell_t)(virial_acc))


###########################################################################################
########################### Reciprocal-Space Kernels ######################################
###########################################################################################


@wp.kernel
def _ewald_reciprocal_space_energy_kernel_fill_structure_factors(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    k_vectors: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    alpha: wp.array(dtype=Any),
    total_charge: wp.array(dtype=wp.float64),
    cos_k_dot_r: wp.array2d(dtype=wp.float64),
    sin_k_dot_r: wp.array2d(dtype=wp.float64),
    real_structure_factors: wp.array(dtype=wp.float64),
    imag_structure_factors: wp.array(dtype=wp.float64),
):
    """Compute structure factors for reciprocal-space Ewald summation.

    This kernel uses K-major iteration: each thread processes one k-vector
    over all atoms. Structure-factor writes therefore need no atomics, since
    each thread fully owns its k-vector's output.

    The weighted structure factors are:

    .. math::

        \\begin{aligned}
        S_{\\text{real}}(k) &= \\frac{G(k)}{V} \\sum_i q_i \\cos(k \\cdot r_i) \\\\
        S_{\\text{imag}}(k) &= \\frac{G(k)}{V} \\sum_i q_i \\sin(k \\cdot r_i)
        \\end{aligned}

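    where :math:`G(k) = \\frac{8\\pi}{k^2} \\exp(-k^2/(4\\alpha^2))` is the Green's function
    for half-space k-vectors (the 8π prefactor doubles the usual 4π to account for the omitted -k vectors).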

    Launch Grid
    -----------
    dim = [K]

    Each thread processes one k-vector over all N atoms.

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    k_vectors : wp.array, shape (K,), dtype=wp.vec3f or wp.vec3d
        Half-space reciprocal lattice vectors (excludes -k for each k).
    cell : wp.array, shape (1, 3, 3), dtype=wp.mat33f or wp.mat33d
        Unit cell matrix (for computing volume).
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    total_charge : wp.array, shape (1,), dtype=wp.float64
        OUTPUT: Accumulated total charge divided by volume (Q/V) for
        background correction. Only the thread with k_idx == 1 accumulates this.
    cos_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        OUTPUT: :math:`\\cos(k \\cdot r_i)` for each (k, atom) pair.
    sin_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        OUTPUT: :math:`\\sin(k \\cdot r_i)` for each (k, atom) pair.
    real_structure_factors : wp.array, shape (K,), dtype=wp.float64
        OUTPUT: :math:`(G(k)/V) \\sum_i q_i \\cos(k \\cdot r_i)`.
    imag_structure_factors : wp.array, shape (K,), dtype=wp.float64
        OUTPUT: :math:`(G(k)/V) \\sum_i q_i \\sin(k \\cdot r_i)`.

    Notes
    -----
    - K-major iteration avoids atomics (each thread owns its k output).
    - k=0 is skipped (early return) to avoid division by zero in G(k).
    - The thread with k_idx == 1 accumulates total_charge as Q/V for background correction.
    - All internal computations use float64 for numerical stability.
    - cos_k_dot_r and sin_k_dot_r store unweighted phases for charge gradient computation.
    - Half-space k-vectors, compensated by the 8π (rather than 4π) Green's function
      prefactor, halve the k-vector count for a ~2x speedup.
    """
    k_idx = wp.tid()
    num_atoms = positions.shape[0]

    alpha_ = wp.float64(alpha[0])
    exp_factor = wp.float64(0.25) / (alpha_ * alpha_)
    volume = wp.float64(wp.abs(wp.determinant(cell[0])))

    k_vector = k_vectors[k_idx]
    # Cast k_vector components to float64 for precision
    kx = wp.float64(k_vector[0])
    ky = wp.float64(k_vector[1])
    kz = wp.float64(k_vector[2])
    k_squared = kx * kx + ky * ky + kz * kz

    # Skip k=0 (would cause division by zero)
    if k_squared < wp.float64(1e-10):
        return

    # Compute Green's function: (8*pi/V) * exp(-k^2/(4*alpha^2)) / k^2
    green_function = wp_exp_kernel(k_squared, exp_factor) * wp.float64(EIGHTPI) / volume

    # Accumulate structure factors in registers (no atomics!)
    real_sum = wp.float64(0.0)
    imag_sum = wp.float64(0.0)

    for atom_idx in range(num_atoms):
        position = positions[atom_idx]
        charge = wp.float64(charges[atom_idx])

        # The k_idx == 1 thread accumulates total charge for background correction
        if k_idx == 1:
            tc = charge / volume
            wp.atomic_add(total_charge, 0, tc)

        # Compute k*r in float64
        k_dot_r = (
            kx * wp.float64(position[0])
            + ky * wp.float64(position[1])
            + kz * wp.float64(position[2])
        )
        cos_kr = wp.cos(k_dot_r)
        sin_kr = wp.sin(k_dot_r)

        # Store per-(k, atom) UNWEIGHTED phase factors (for charge gradients)
        cos_k_dot_r[k_idx, atom_idx] = cos_kr
        sin_k_dot_r[k_idx, atom_idx] = sin_kr

        # Accumulate structure factors (charge-weighted) in registers
        real_sum += charge * cos_kr * green_function
        imag_sum += charge * sin_kr * green_function

    # Write final structure factors (no atomics needed)
    real_structure_factors[k_idx] = real_sum
    imag_structure_factors[k_idx] = imag_sum


@wp.kernel
def _ewald_reciprocal_space_energy_kernel_compute_energy(
    charges: wp.array(dtype=Any),
    cos_k_dot_r: wp.array2d(dtype=wp.float64),
    sin_k_dot_r: wp.array2d(dtype=wp.float64),
    real_structure_factors: wp.array(dtype=wp.float64),
    imag_structure_factors: wp.array(dtype=wp.float64),
    reciprocal_energies: wp.array(dtype=wp.float64),
):
    """Compute per-atom reciprocal-space energies from structure factors.

    This kernel uses atom-major iteration: each thread processes one atom
    over all k-vectors. This avoids atomics since each thread fully owns
    its atom's output.

    For each atom i:

    .. math::

        E_i = \\frac{1}{2} \\sum_k [S_{\\text{real}}(k) \\cos(k \\cdot r_i) + S_{\\text{imag}}(k) \\sin(k \\cdot r_i)] q_i

    The 0.5 factor avoids double counting in the pair energy sum: :math:`E = \\frac{1}{2} \\sum_i q_i \\phi_i`.

    Launch Grid
    -----------
    dim = [N]

    Each thread processes one atom over all K k-vectors.

    Parameters
    ----------
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cos_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        :math:`\\cos(k \\cdot r_i)` from structure factor computation.
    sin_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        :math:`\\sin(k \\cdot r_i)` from structure factor computation.
    real_structure_factors : wp.array, shape (K,), dtype=wp.float64
        Precomputed :math:`S_{\\text{real}}(k) = (G(k)/V) \\sum_j q_j \\cos(k \\cdot r_j)`.
    imag_structure_factors : wp.array, shape (K,), dtype=wp.float64
        Precomputed :math:`S_{\\text{imag}}(k) = (G(k)/V) \\sum_j q_j \\sin(k \\cdot r_j)`.
    reciprocal_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Reciprocal-space energy per atom.

    Notes
    -----
    - Atom-major iteration avoids atomics (each thread owns its atom output)
    - The 0.5 factor is applied here (not in structure factor computation)
    - cos_k_dot_r and sin_k_dot_r are unweighted; charge is multiplied here
    - All computations in float64
    """
    atom_idx = wp.tid()
    num_k = real_structure_factors.shape[0]
    charge = wp.float64(charges[atom_idx])

    # Accumulate potential in register (no atomics!)
    local_potential = wp.float64(0.0)

    for k_idx in range(num_k):
        cos_kr = cos_k_dot_r[k_idx, atom_idx]
        sin_kr = sin_k_dot_r[k_idx, atom_idx]
        s_real = real_structure_factors[k_idx]
        s_imag = imag_structure_factors[k_idx]

        phase_sum = s_real * cos_kr + s_imag * sin_kr
        local_potential += charge * phase_sum

    # Write final energy: E_i = (1/2) * q_i * phi_i (no atomics needed)
    reciprocal_energies[atom_idx] = wp.float64(0.5) * local_potential


@wp.kernel
def _ewald_subtract_self_energy_kernel(
    charges: wp.array(dtype=Any),
    alpha: wp.array(dtype=Any),
    total_charge: wp.array(dtype=wp.float64),
    energy_in: wp.array(dtype=wp.float64),
    energy_out: wp.array(dtype=wp.float64),
):
    """Apply self-energy and background corrections to reciprocal-space energies.

    For each atom i:

    .. math::

        E_{\\text{out},i} = E_{\\text{in},i} - E_{\\text{self},i} - E_{\\text{background},i}

    where:

    .. math::

        \\begin{aligned}
        E_{\\text{self},i} &= \\frac{\\alpha}{\\sqrt{\\pi}} q_i^2 \\\\
        E_{\\text{background},i} &= \\frac{\\pi}{2\\alpha^2} q_i \\frac{Q_{\\text{total}}}{V}
        \\end{aligned}

    The self-energy removes the spurious interaction of each Gaussian charge
    distribution with itself. The background correction accounts for the
    uniform neutralizing background charge for non-neutral systems.

    Launch Grid
    -----------
    dim = [N]

    Parameters
    ----------
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    total_charge : wp.array, shape (1,), dtype=wp.float64
        Total charge divided by volume (Q_total/V), precomputed in
        _ewald_reciprocal_space_energy_kernel_fill_structure_factors.
    energy_in : wp.array, shape (N,), dtype=wp.float64
        Raw reciprocal-space energy per atom (before self-energy and background corrections).
    energy_out : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Corrected reciprocal-space energy per atom.

    Notes
    -----
    - Uses separate input/output arrays to avoid in-place modification,
      which would cause incorrect gradient accumulation in Warp's autodiff
    - For neutral systems, the background correction is zero
    - All computations in float64
    """
    atom_index = wp.tid()
    charge = wp.float64(charges[atom_index])
    alpha_ = wp.float64(alpha[0])
    # Compute self-energy: alpha * q^2 / sqrt(pi)
    self_energy = alpha_ * charge * charge / wp.sqrt(wp.float64(PI))

    # Background correction: pi / (2*alpha^2) * q * (Q_total/V)
    neutralization_energy = (
        wp.float64(PI) * charge * total_charge[0] / (wp.float64(2.0) * alpha_ * alpha_)
    )

    # Subtract self-energy and background correction
    # (separate input/output arrays to avoid autodiff issues)
    energy_out[atom_index] = energy_in[atom_index] - self_energy - neutralization_energy


@wp.kernel
def _ewald_reciprocal_space_energy_forces_kernel(
    charges: wp.array(dtype=Any),
    k_vectors: wp.array(dtype=Any),
    cos_k_dot_r: wp.array2d(dtype=wp.float64),
    sin_k_dot_r: wp.array2d(dtype=wp.float64),
    real_structure_factors: wp.array(dtype=wp.float64),
    imag_structure_factors: wp.array(dtype=wp.float64),
    reciprocal_energies: wp.array(dtype=wp.float64),
    atomic_forces: wp.array(dtype=Any),
):
    """Compute reciprocal-space Ewald energies and forces simultaneously.

    This kernel uses atom-major iteration: each thread processes one atom
    over all k-vectors. This avoids atomics since each thread fully owns
    its atom's output.

    For each atom i:

    .. math::

        \\begin{aligned}
        E_i &= \\frac{1}{2} \\sum_k [S_{\\text{real}}(k) \\cos(k \\cdot r_i) + S_{\\text{imag}}(k) \\sin(k \\cdot r_i)] q_i \\\\
        F_i &= \\sum_k k [S_{\\text{real}}(k) \\sin(k \\cdot r_i) - S_{\\text{imag}}(k) \\cos(k \\cdot r_i)] q_i
        \\end{aligned}

    The force formula comes from :math:`-\\nabla_i E`, where the gradient acts on the
    :math:`\\cos(k \\cdot r_i)` and :math:`\\sin(k \\cdot r_i)` terms:

    .. math::

        \\begin{aligned}
        \\frac{\\partial}{\\partial r_i} \\cos(k \\cdot r_i) &= -k \\sin(k \\cdot r_i) \\\\
        \\frac{\\partial}{\\partial r_i} \\sin(k \\cdot r_i) &= k \\cos(k \\cdot r_i)
        \\end{aligned}

    Launch Grid
    -----------
    dim = [N]

    Each thread processes one atom over all K k-vectors.

    Parameters
    ----------
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    k_vectors : wp.array, shape (K,), dtype=wp.vec3f or wp.vec3d
        Reciprocal lattice vectors.
    cos_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        :math:`\\cos(k \\cdot r_i)` from structure factor computation.
    sin_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        :math:`\\sin(k \\cdot r_i)` from structure factor computation.
    real_structure_factors : wp.array, shape (K,), dtype=wp.float64
        Precomputed S_real(k) including Green's function.
    imag_structure_factors : wp.array, shape (K,), dtype=wp.float64
        Precomputed S_imag(k) including Green's function.
    reciprocal_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Reciprocal-space energy per atom.
    atomic_forces : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Reciprocal-space forces per atom (matches k_vectors dtype).

    Notes
    -----
    - Atom-major iteration avoids atomics (each thread owns its atom output)
    - The 0.5 factor is applied to energy but not to forces
    - cos_k_dot_r and sin_k_dot_r are unweighted; charge is multiplied here
    - Energy computed in float64, forces in k_vectors dtype
    """
    atom_idx = wp.tid()
    num_k = real_structure_factors.shape[0]
    charge = wp.float64(charges[atom_idx])

    # Reference element, used only to infer the output force vector type
    k0 = k_vectors[0]

    # Accumulate in registers (no atomics!)
    local_potential = wp.float64(0.0)
    local_force_x = wp.float64(0.0)
    local_force_y = wp.float64(0.0)
    local_force_z = wp.float64(0.0)

    for k_idx in range(num_k):
        cos_kr = charge * cos_k_dot_r[k_idx, atom_idx]
        sin_kr = charge * sin_k_dot_r[k_idx, atom_idx]

        # Load precomputed structure factors (already include green function)
        s_real = real_structure_factors[k_idx]
        s_imag = imag_structure_factors[k_idx]

        # Potential contribution
        phase_sum = s_real * cos_kr + s_imag * sin_kr
        local_potential += phase_sum

        # Force contribution
        force_scalar = s_real * sin_kr - s_imag * cos_kr
        k_vec = k_vectors[k_idx]
        local_force_x += force_scalar * wp.float64(k_vec[0])
        local_force_y += force_scalar * wp.float64(k_vec[1])
        local_force_z += force_scalar * wp.float64(k_vec[2])

    # Write final results (charge is already folded into cos_kr/sin_kr; no atomics needed)
    reciprocal_energies[atom_idx] = wp.float64(0.5) * local_potential
    atomic_forces[atom_idx] = type(k0)(
        type(k0[0])(local_force_x),
        type(k0[0])(local_force_y),
        type(k0[0])(local_force_z),
    )


@wp.kernel
def _ewald_reciprocal_space_energy_forces_charge_grad_kernel(
    charges: wp.array(dtype=Any),
    k_vectors: wp.array(dtype=Any),
    cos_k_dot_r: wp.array2d(dtype=wp.float64),
    sin_k_dot_r: wp.array2d(dtype=wp.float64),
    real_structure_factors: wp.array(dtype=wp.float64),
    imag_structure_factors: wp.array(dtype=wp.float64),
    reciprocal_energies: wp.array(dtype=wp.float64),
    atomic_forces: wp.array(dtype=Any),
    charge_gradients: wp.array(dtype=wp.float64),
):
    """Compute reciprocal-space energies, forces, AND charge gradients.

    This kernel computes all three quantities in a single pass:

    .. math::

        \\begin{aligned}
        E_i &= \\frac{1}{2} \\sum_k [S_{\\text{real}}(k) \\cos(k \\cdot r_i) + S_{\\text{imag}}(k) \\sin(k \\cdot r_i)] q_i \\\\
        F_i &= \\sum_k k [S_{\\text{real}}(k) \\sin(k \\cdot r_i) - S_{\\text{imag}}(k) \\cos(k \\cdot r_i)] q_i \\\\
        \\frac{\\partial E}{\\partial q_i} &= \\phi_i = \\sum_k [S_{\\text{real}}(k) \\cos(k \\cdot r_i) + S_{\\text{imag}}(k) \\sin(k \\cdot r_i)]
        \\end{aligned}

    where :math:`\\phi_i = \\sum_k [S_{\\text{real}}(k) \\cos(k \\cdot r_i) + S_{\\text{imag}}(k) \\sin(k \\cdot r_i)]` is the
    electrostatic potential at atom i.

    Launch Grid
    -----------
    dim = [N]

    Each thread processes one atom over all K k-vectors.

    Parameters
    ----------
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    k_vectors : wp.array, shape (K,), dtype=wp.vec3f or wp.vec3d
        Reciprocal lattice vectors.
    cos_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        :math:`\\cos(k \\cdot r_i)` from structure factor computation (unweighted).
    sin_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        :math:`\\sin(k \\cdot r_i)` from structure factor computation (unweighted).
    real_structure_factors : wp.array, shape (K,), dtype=wp.float64
        Precomputed :math:`S_{\\text{real}}(k)` including Green's function.
    imag_structure_factors : wp.array, shape (K,), dtype=wp.float64
        Precomputed :math:`S_{\\text{imag}}(k)` including Green's function.
    reciprocal_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Reciprocal-space energy per atom.
    atomic_forces : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Reciprocal-space forces per atom.
    charge_gradients : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Electrostatic potential :math:`\\phi_i` per atom (reciprocal part of charge gradient).
    """
    atom_idx = wp.tid()
    num_k = real_structure_factors.shape[0]
    charge = wp.float64(charges[atom_idx])

    # Reference element, used only to infer the output force vector type
    k0 = k_vectors[0]

    # Accumulate in registers (no atomics!)
    local_potential = wp.float64(0.0)
    local_potential_uncharged = wp.float64(0.0)
    local_force_x = wp.float64(0.0)
    local_force_y = wp.float64(0.0)
    local_force_z = wp.float64(0.0)

    for k_idx in range(num_k):
        cos_kr = cos_k_dot_r[k_idx, atom_idx]
        sin_kr = sin_k_dot_r[k_idx, atom_idx]

        # Load precomputed structure factors (already include green function)
        s_real = real_structure_factors[k_idx]
        s_imag = imag_structure_factors[k_idx]

        # Potential contribution
        phase_sum = s_real * cos_kr + s_imag * sin_kr
        local_potential += charge * phase_sum
        local_potential_uncharged += phase_sum

        # Force contribution
        force_scalar = charge * (s_real * sin_kr - s_imag * cos_kr)
        k_vec = k_vectors[k_idx]
        local_force_x += force_scalar * wp.float64(k_vec[0])
        local_force_y += force_scalar * wp.float64(k_vec[1])
        local_force_z += force_scalar * wp.float64(k_vec[2])

    # Write final results (no atomics needed)
    # Energy: E_i = (1/2) * q_i * φ_i
    reciprocal_energies[atom_idx] = wp.float64(0.5) * local_potential

    # Forces
    atomic_forces[atom_idx] = type(k0)(
        type(k0[0])(local_force_x),
        type(k0[0])(local_force_y),
        type(k0[0])(local_force_z),
    )

    # Charge gradient
    # Self-energy and background corrections applied in higher-level code
    charge_gradients[atom_idx] = local_potential_uncharged


###########################################################################################
#################### Reciprocal-Space Virial Kernels ######################################
###########################################################################################


@wp.kernel
def _ewald_reciprocal_space_virial_kernel(
    k_vectors: wp.array(dtype=Any),
    alpha: wp.array(dtype=Any),
    volume: wp.array(dtype=wp.float64),
    real_structure_factors: wp.array(dtype=wp.float64),
    imag_structure_factors: wp.array(dtype=wp.float64),
    virial: wp.array(dtype=Any),
):
    """Compute the reciprocal-space virial tensor from precomputed structure factors.

    For each k-vector, the virial contribution is:

    .. math::

        W_{ab}(k) = E(k) \\left[ \\delta_{ab} - \\frac{2 k_a k_b}{k^2} \\left(1 + \\frac{k^2}{4\\alpha^2}\\right) \\right]

    where the per-k energy is :math:`E(k) = \\frac{|S(k)|^2}{2 G(k)}` and
    :math:`G(k) = \\frac{8\\pi}{V} \\frac{\\exp(-k^2/(4\\alpha^2))}{k^2}`.

    Launch Grid
    -----------
    dim = [K]

    Parameters
    ----------
    k_vectors : wp.array, shape (K,), dtype=wp.vec3f or wp.vec3d
        Reciprocal lattice vectors.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    volume : wp.array, shape (1,), dtype=wp.float64
        Unit cell volume |det(cell)|.
    real_structure_factors : wp.array, shape (K,), dtype=wp.float64
        Precomputed S_real(k) = G(k) * sum_i q_i cos(k.r_i).
    imag_structure_factors : wp.array, shape (K,), dtype=wp.float64
        Precomputed S_imag(k) = G(k) * sum_i q_i sin(k.r_i).
    virial : wp.array, shape (1,), dtype=wp.mat33f or wp.mat33d
        OUTPUT: Accumulated virial tensor.
    """
    k_idx = wp.tid()

    k_vec = k_vectors[k_idx]
    kx = wp.float64(k_vec[0])
    ky = wp.float64(k_vec[1])
    kz = wp.float64(k_vec[2])
    k_sq = kx * kx + ky * ky + kz * kz

    if k_sq < wp.float64(1e-10):
        return

    alpha_ = wp.float64(alpha[0])
    vol = volume[0]
    s_real = real_structure_factors[k_idx]
    s_imag = imag_structure_factors[k_idx]

    # |S(k)|^2
    s_sq = s_real * s_real + s_imag * s_imag

    # Green's function G(k) = (8*pi/V) * exp(-k^2/(4*alpha^2)) / k^2
    exp_factor = wp.float64(0.25) / (alpha_ * alpha_)
    green = wp.float64(EIGHTPI) / vol * wp.exp(-k_sq * exp_factor) / k_sq

    # Per-k energy: E(k) = |S|^2 / (2*G)
    energy_k = wp.float64(0.5) * s_sq / green

    # Virial W = -dE/dε.  d ln G / dε_ab = -δ_ab + 2 k_a k_b / k² (1 + k²/(4α²)),
    # so W_ab(k) = E(k) * [δ_ab - 2 k_a k_b / k² (1 + k²/(4α²))].
    k_factor = wp.float64(2.0) * (wp.float64(1.0) + k_sq * exp_factor) / k_sq

    # Build 3x3 virial contribution: W_ab = E_k * (δ_ab - k_factor * k_a * k_b)
    w00 = energy_k * (wp.float64(1.0) - k_factor * kx * kx)
    w01 = energy_k * (-k_factor * kx * ky)
    w02 = energy_k * (-k_factor * kx * kz)
    w10 = energy_k * (-k_factor * ky * kx)
    w11 = energy_k * (wp.float64(1.0) - k_factor * ky * ky)
    w12 = energy_k * (-k_factor * ky * kz)
    w20 = energy_k * (-k_factor * kz * kx)
    w21 = energy_k * (-k_factor * kz * ky)
    w22 = energy_k * (wp.float64(1.0) - k_factor * kz * kz)

    # Cast to virial element type (mat33f or mat33d, matching input precision)
    # Use type() inline as constructor (Warp resolves at compile time)
    _virial_ref = virial[0]
    virial_k = type(_virial_ref)(
        type(k_vec[0])(w00),
        type(k_vec[0])(w01),
        type(k_vec[0])(w02),
        type(k_vec[0])(w10),
        type(k_vec[0])(w11),
        type(k_vec[0])(w12),
        type(k_vec[0])(w20),
        type(k_vec[0])(w21),
        type(k_vec[0])(w22),
    )
    wp.atomic_add(virial, 0, virial_k)


@wp.kernel
def _batch_ewald_reciprocal_space_virial_kernel(
    k_vectors: wp.array2d(dtype=Any),
    alpha: wp.array(dtype=Any),
    volume: wp.array(dtype=wp.float64),
    real_structure_factors: wp.array2d(dtype=wp.float64),
    imag_structure_factors: wp.array2d(dtype=wp.float64),
    virial: wp.array(dtype=Any),
):
    """Compute the reciprocal-space virial tensor for batched systems.

    Same formula as single-system version, but with per-system k-vectors,
    structure factors, alpha, and volume.

    Launch Grid
    -----------
    dim = [K, B]

    Parameters
    ----------
    k_vectors : wp.array2d, shape (B, K), dtype=wp.vec3f or wp.vec3d
        Per-system reciprocal lattice vectors.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    volume : wp.array, shape (B,), dtype=wp.float64
        Per-system unit cell volume |det(cell)|.
    real_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        Per-system S_real(k).
    imag_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        Per-system S_imag(k).
    virial : wp.array, shape (B,), dtype=wp.mat33f or wp.mat33d
        OUTPUT: Per-system accumulated virial tensor.
    """
    k_idx, system_id = wp.tid()

    k_vec = k_vectors[system_id, k_idx]
    kx = wp.float64(k_vec[0])
    ky = wp.float64(k_vec[1])
    kz = wp.float64(k_vec[2])
    k_sq = kx * kx + ky * ky + kz * kz

    if k_sq < wp.float64(1e-10):
        return

    alpha_ = wp.float64(alpha[system_id])
    vol = volume[system_id]
    s_real = real_structure_factors[system_id, k_idx]
    s_imag = imag_structure_factors[system_id, k_idx]

    # |S(k)|^2
    s_sq = s_real * s_real + s_imag * s_imag

    # Green's function G(k) = (8*pi/V) * exp(-k^2/(4*alpha^2)) / k^2
    exp_factor = wp.float64(0.25) / (alpha_ * alpha_)
    green = wp.float64(EIGHTPI) / vol * wp.exp(-k_sq * exp_factor) / k_sq

    # Per-k energy: E(k) = |S|^2 / (2*G)
    energy_k = wp.float64(0.5) * s_sq / green

    # Virial W = -dE/dε.  d ln G / dε_ab = -δ_ab + 2 k_a k_b / k² (1 + k²/(4α²)),
    # so W_ab(k) = E(k) * [δ_ab - 2 k_a k_b / k² (1 + k²/(4α²))].
    k_factor = wp.float64(2.0) * (wp.float64(1.0) + k_sq * exp_factor) / k_sq

    # Build 3x3 virial contribution: W_ab = E_k * (δ_ab - k_factor * k_a * k_b)
    w00 = energy_k * (wp.float64(1.0) - k_factor * kx * kx)
    w01 = energy_k * (-k_factor * kx * ky)
    w02 = energy_k * (-k_factor * kx * kz)
    w10 = energy_k * (-k_factor * ky * kx)
    w11 = energy_k * (wp.float64(1.0) - k_factor * ky * ky)
    w12 = energy_k * (-k_factor * ky * kz)
    w20 = energy_k * (-k_factor * kz * kx)
    w21 = energy_k * (-k_factor * kz * ky)
    w22 = energy_k * (wp.float64(1.0) - k_factor * kz * kz)

    # Cast to virial element type (mat33f or mat33d, matching input precision)
    # Use type() inline as constructor (Warp resolves at compile time)
    _virial_ref = virial[system_id]
    virial_k = type(_virial_ref)(
        type(k_vec[0])(w00),
        type(k_vec[0])(w01),
        type(k_vec[0])(w02),
        type(k_vec[0])(w10),
        type(k_vec[0])(w11),
        type(k_vec[0])(w12),
        type(k_vec[0])(w20),
        type(k_vec[0])(w21),
        type(k_vec[0])(w22),
    )
    wp.atomic_add(virial, system_id, virial_k)


###########################################################################################
########################### Batch Reciprocal-Space Kernels ################################
###########################################################################################


@wp.kernel
def _batch_ewald_reciprocal_space_energy_kernel_fill_structure_factors(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    k_vectors: wp.array2d(dtype=Any),
    cell: wp.array(dtype=Any),
    alpha: wp.array(dtype=Any),
    atom_start: wp.array(dtype=wp.int32),
    atom_end: wp.array(dtype=wp.int32),
    total_charges: wp.array(dtype=wp.float64),
    cos_k_dot_r: wp.array2d(dtype=wp.float64),
    sin_k_dot_r: wp.array2d(dtype=wp.float64),
    real_structure_factors: wp.array2d(dtype=wp.float64),
    imag_structure_factors: wp.array2d(dtype=wp.float64),
):
    """Compute structure factors for batched reciprocal-space Ewald summation.

    This kernel uses a blocked strategy: each thread handles one (k-vector, system,
    atom_block) triplet. This significantly reduces atomic contention compared to
    atom-major iteration while maintaining parallelism.

    The block size is controlled by ALCH_EWALD_BATCH_BLOCK_SIZE environment variable
    (default: 16, which benchmarks show is optimal for most scenarios).

    For each system s and atom i in that system:

    .. math::

        \\begin{aligned}
        S_{\\text{real}}(s, k) &+= \\frac{G_s(k)}{V_s} q_i \\cos(k \\cdot r_i) \\\\
        S_{\\text{imag}}(s, k) &+= \\frac{G_s(k)}{V_s} q_i \\sin(k \\cdot r_i)
        \\end{aligned}

    where :math:`G_s(k) = \\frac{8\\pi}{k^2} \\exp(-k^2/(4\\alpha_s^2))` uses half-space k-vectors.

    Launch Grid
    -----------
    dim = [K, B, max_blocks_per_system]

    where max_blocks_per_system = ceil(max_atoms_per_system / BATCH_BLOCK_SIZE)

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates for all systems concatenated.
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges for all systems concatenated.
    k_vectors : wp.array2d, shape (B, K), dtype=wp.vec3f or wp.vec3d
        Per-system half-space reciprocal lattice vectors.
    cell : wp.array, shape (B, 3, 3), dtype=wp.mat33f or wp.mat33d
        Per-system unit cell matrices.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    atom_start : wp.array, shape (B,), dtype=wp.int32
        First atom index for each system.
    atom_end : wp.array, shape (B,), dtype=wp.int32
        Last atom index (exclusive) for each system.
    total_charges : wp.array, shape (B,), dtype=wp.float64
        OUTPUT: Accumulated (Q_total/V) per system for background correction.
    cos_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        OUTPUT: :math:`\\cos(k \\cdot r_i)` for each (k, atom) pair.
    sin_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        OUTPUT: :math:`\\sin(k \\cdot r_i)` for each (k, atom) pair.
    real_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        OUTPUT: Per-system :math:`(G(k)/V) \\sum_i q_i \\cos(k \\cdot r_i)`.
    imag_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        OUTPUT: Per-system :math:`(G(k)/V) \\sum_i q_i \\sin(k \\cdot r_i)`.

    Notes
    -----
    - Blocked iteration reduces atomic contention vs atom-major.
    - Each block computes partial sums in registers before one atomic add.
    - BATCH_BLOCK_SIZE=16 is optimal for most cases (set via environment variable ALCH_EWALD_BATCH_BLOCK_SIZE).
    - k=0 causes early return (would cause division by zero in G(k)).
    - Blocks beyond the system's atoms cause early return.
    - Threads with k_idx == 1 accumulate total_charges as Q/V for background correction.
    - All internal computations use float64 for numerical stability.
    - Half-space k-vectors, compensated by the 8π (rather than 4π) Green's function
      prefactor, halve the k-vector count for a ~2x speedup.
    """
    k_idx, system_id, block_idx = wp.tid()

    system_cell = cell[system_id]
    system_alpha = wp.float64(alpha[system_id])

    a_start = atom_start[system_id]
    a_end = atom_end[system_id]

    # Compute atom range for this block
    block_start = a_start + block_idx * BATCH_BLOCK_SIZE
    block_end = wp.min(block_start + BATCH_BLOCK_SIZE, a_end)

    # Skip if this block is beyond the system's atoms
    if block_start >= a_end:
        return

    exp_factor = wp.float64(0.25) / (system_alpha * system_alpha)
    volume = wp.float64(wp.abs(wp.determinant(system_cell)))

    k_vector = k_vectors[system_id, k_idx]
    kx = wp.float64(k_vector[0])
    ky = wp.float64(k_vector[1])
    kz = wp.float64(k_vector[2])
    k_squared = kx * kx + ky * ky + kz * kz

    # Skip k=0 (would cause division by zero)
    if k_squared < wp.float64(1e-10):
        return

    # Compute Green's function: (8*pi/V) * exp(-k^2/(4*alpha^2)) / k^2
    # (8*pi rather than 4*pi because only half-space k-vectors are summed)
    green_function = wp_exp_kernel(k_squared, exp_factor) * wp.float64(EIGHTPI) / volume

    # Accumulate partial sums for this block in registers
    local_real = wp.float64(0.0)
    local_imag = wp.float64(0.0)
    local_charge = wp.float64(0.0)

    for atom_idx in range(block_start, block_end):
        position = positions[atom_idx]
        charge = wp.float64(charges[atom_idx])

        # Only the k_idx == 1 thread of each block accumulates total charge
        # (once per atom rather than once per k-vector)
        if k_idx == 1:
            local_charge += charge / volume

        # Compute unweighted cos(k.r) and sin(k.r); charge is applied during accumulation below
        k_dot_r = (
            kx * wp.float64(position[0])
            + ky * wp.float64(position[1])
            + kz * wp.float64(position[2])
        )
        cos_kr = wp.cos(k_dot_r)
        sin_kr = wp.sin(k_dot_r)

        # Store per-(k, atom) UNWEIGHTED phase factors (for charge gradients)
        cos_k_dot_r[k_idx, atom_idx] = cos_kr
        sin_k_dot_r[k_idx, atom_idx] = sin_kr

        # Accumulate structure factors (charge-weighted) in registers
        local_real += charge * cos_kr * green_function
        local_imag += charge * sin_kr * green_function

    # One atomic add per block (far fewer atomics than atom-major iteration)
    wp.atomic_add(real_structure_factors, system_id, k_idx, local_real)
    wp.atomic_add(imag_structure_factors, system_id, k_idx, local_imag)

    if k_idx == 1:
        wp.atomic_add(total_charges, system_id, local_charge)
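

# Illustrative sketch (not part of this module): a pure-Python reference for the
# blocked structure-factor accumulation above, for a single system. It assumes
# that wp_exp_kernel(k^2, 1/(4*alpha^2)) evaluates exp(-k^2/(4*alpha^2)) / k^2 and
# that the caller has already removed k = 0. The function name and argument
# layout are hypothetical.
```python
import math


def reference_structure_factors(positions, charges, k_vectors, alpha, volume, block_size=16):
    """Return (s_real, s_imag), each a list of length K, for one system."""
    s_real = [0.0] * len(k_vectors)
    s_imag = [0.0] * len(k_vectors)
    for k_idx, (kx, ky, kz) in enumerate(k_vectors):
        k_sq = kx * kx + ky * ky + kz * kz
        # Half-space convention: 8*pi instead of 4*pi in the Green's function.
        green = 8.0 * math.pi * math.exp(-k_sq / (4.0 * alpha * alpha)) / (k_sq * volume)
        # One partial sum per atom block, mirroring one atomic add per block.
        for start in range(0, len(charges), block_size):
            local_real = 0.0
            local_imag = 0.0
            for i in range(start, min(start + block_size, len(charges))):
                x, y, z = positions[i]
                k_dot_r = kx * x + ky * y + kz * z
                local_real += charges[i] * math.cos(k_dot_r) * green
                local_imag += charges[i] * math.sin(k_dot_r) * green
            s_real[k_idx] += local_real
            s_imag[k_idx] += local_imag
    return s_real, s_imag
```
# The block size only changes how the partial sums are grouped, so any block
# size should give the same result up to floating-point rounding.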


@wp.kernel
def _batch_ewald_reciprocal_space_energy_kernel_compute_energy(
    charges: wp.array(dtype=Any),
    batch_id: wp.array(dtype=wp.int32),
    cos_k_dot_r: wp.array2d(dtype=wp.float64),
    sin_k_dot_r: wp.array2d(dtype=wp.float64),
    real_structure_factors: wp.array2d(dtype=wp.float64),
    imag_structure_factors: wp.array2d(dtype=wp.float64),
    reciprocal_energies: wp.array(dtype=wp.float64),
):
    """Compute per-atom reciprocal-space energies for batched systems.

    This kernel uses atom-major iteration: each thread processes one atom
    over all k-vectors. This avoids atomics since each thread fully owns
    its atom's output.

    For each atom i in system s:

    .. math::

        E_i = \\frac{1}{2} \\sum_k [S_{\\text{real}}(s,k) \\cos(k \\cdot r_i) + S_{\\text{imag}}(s,k) \\sin(k \\cdot r_i)] q_i

    Uses batch_id to look up the correct system's structure factors.

    Launch Grid
    -----------
    dim = [N_total]

    Each thread processes one atom over all K k-vectors.

    Parameters
    ----------
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1).
    cos_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        :math:`\\cos(k \\cdot r_i)` from structure factor computation.
    sin_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        :math:`\\sin(k \\cdot r_i)` from structure factor computation.
    real_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        Per-system :math:`S_{\\text{real}}(s, k)` including Green's function.
    imag_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        Per-system :math:`S_{\\text{imag}}(s, k)` including Green's function.
    reciprocal_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Reciprocal-space energy per atom.

    Notes
    -----
    - Atom-major iteration avoids atomics (each thread owns its atom output)
    - cos_k_dot_r and sin_k_dot_r are unweighted; charge is multiplied here
    - All computations in float64
    """
    atom_idx = wp.tid()
    num_k = real_structure_factors.shape[1]
    charge = wp.float64(charges[atom_idx])

    system_id = batch_id[atom_idx]

    # Accumulate potential in register (no atomics!)
    local_potential = wp.float64(0.0)

    for k_idx in range(num_k):
        cos_kr = cos_k_dot_r[k_idx, atom_idx]
        sin_kr = sin_k_dot_r[k_idx, atom_idx]
        s_real = real_structure_factors[system_id, k_idx]
        s_imag = imag_structure_factors[system_id, k_idx]

        phase_sum = s_real * cos_kr + s_imag * sin_kr
        local_potential += charge * phase_sum

    # Write final energy: E_i = (1/2) * q_i * phi_i (no atomics needed)
    reciprocal_energies[atom_idx] = wp.float64(0.5) * local_potential
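

# Illustrative sketch (not part of this module): the per-atom energy formula from
# the docstring above, written as plain Python loops. Nested lists stand in for
# the Warp arrays and the function name is hypothetical.
```python
def reference_reciprocal_energy(charges, batch_id, cos_kr, sin_kr, s_real, s_imag):
    """cos_kr/sin_kr are indexed [k][atom]; s_real/s_imag are indexed [system][k]."""
    energies = []
    for atom, q in enumerate(charges):
        system = batch_id[atom]  # selects this atom's row of structure factors
        phi = 0.0
        for k in range(len(s_real[system])):
            phi += s_real[system][k] * cos_kr[k][atom] + s_imag[system][k] * sin_kr[k][atom]
        energies.append(0.5 * q * phi)  # E_i = (1/2) * q_i * phi_i
    return energies
```
# Atoms of different systems share the same k index range but read different
# rows of the structure-factor arrays through batch_id.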


@wp.kernel
def _batch_ewald_subtract_self_energy_kernel(
    charges: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    alpha: wp.array(dtype=Any),
    total_charges: wp.array(dtype=wp.float64),
    energy_in: wp.array(dtype=wp.float64),
    energy_out: wp.array(dtype=wp.float64),
):
    """Apply self-energy and background corrections for batched systems.

    For each atom i in system s:

    .. math::

        E_{\\text{out},i} = E_{\\text{in},i} - E_{\\text{self},i} - E_{\\text{background},i}

    where:

    .. math::

        \\begin{aligned}
        E_{\\text{self},i} &= \\frac{\\alpha_s}{\\sqrt{\\pi}} q_i^2 \\\\
        E_{\\text{background},i} &= \\frac{\\pi}{2\\alpha_s^2} q_i \\frac{Q_{s,\\text{total}}}{V_s}
        \\end{aligned}

    Uses per-system alpha and total_charge values looked up via batch_idx.

    Launch Grid
    -----------
    dim = [N_total]

    Parameters
    ----------
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges for all systems concatenated.
    batch_idx : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1).
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    total_charges : wp.array, shape (B,), dtype=wp.float64
        Per-system (Q_total/V), precomputed in structure factor kernel.
    energy_in : wp.array, shape (N_total,), dtype=wp.float64
        Raw reciprocal-space energy per atom.
    energy_out : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Corrected reciprocal-space energy per atom.

    Notes
    -----
    - Uses separate input/output arrays for autodiff compatibility
    - Each system may have different alpha and total charge values
    - All computations in float64
    """
    atom_index = wp.tid()
    charge = wp.float64(charges[atom_index])
    system_id = batch_idx[atom_index]
    system_alpha = wp.float64(alpha[system_id])
    system_total_charge = total_charges[system_id]

    # Compute self-energy: alpha * q^2 / sqrt(pi)
    self_energy = system_alpha * charge * charge / wp.sqrt(wp.float64(PI))

    # Background correction: pi / (2*alpha^2) * q * (Q_total/V)
    neutralization_energy = (
        wp.float64(PI)
        * charge
        * system_total_charge
        / (wp.float64(2.0) * system_alpha * system_alpha)
    )

    # Subtract self-energy and background (separate input/output to avoid autodiff issues)
    energy_out[atom_index] = energy_in[atom_index] - self_energy - neutralization_energy
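

# Illustrative sketch (not part of this module): the self-energy and background
# corrections above in plain Python. The name `charge_density` is hypothetical
# and stands for the per-system Q_total/V values held in total_charges.
```python
import math


def reference_corrections(energy_in, charges, batch_idx, alpha, charge_density):
    """Subtract self-energy and background terms from per-atom energies."""
    corrected = []
    for energy, q, system in zip(energy_in, charges, batch_idx):
        self_energy = alpha[system] * q * q / math.sqrt(math.pi)
        background = math.pi * q * charge_density[system] / (2.0 * alpha[system] ** 2)
        corrected.append(energy - self_energy - background)
    return corrected
```
# For a charge-neutral system Q_total/V is zero, so only the self-energy term
# alpha * q^2 / sqrt(pi) is removed from each atom.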


@wp.kernel
def _batch_ewald_reciprocal_space_energy_forces_kernel(
    charges: wp.array(dtype=Any),
    batch_id: wp.array(dtype=wp.int32),
    k_vectors: wp.array2d(dtype=Any),
    cos_k_dot_r: wp.array2d(dtype=wp.float64),
    sin_k_dot_r: wp.array2d(dtype=wp.float64),
    real_structure_factors: wp.array2d(dtype=wp.float64),
    imag_structure_factors: wp.array2d(dtype=wp.float64),
    reciprocal_energies: wp.array(dtype=wp.float64),
    atomic_forces: wp.array(dtype=Any),
):
    """Compute reciprocal-space energies and forces for batched systems.

    This kernel uses atom-major iteration: each thread processes one atom
    over all k-vectors. This avoids atomics since each thread fully owns
    its atom's output.

    For each atom i in system s:

    .. math::

        \\begin{aligned}
        E_i &= \\frac{1}{2} \\sum_k [S_{\\text{real}}(s,k) \\cos(k \\cdot r_i) + S_{\\text{imag}}(s,k) \\sin(k \\cdot r_i)] q_i \\\\
        F_i &= \\sum_k k [S_{\\text{real}}(s,k) \\sin(k \\cdot r_i) - S_{\\text{imag}}(s,k) \\cos(k \\cdot r_i)] q_i
        \\end{aligned}

    Uses batch_id to look up the correct system's k-vectors and structure factors.

    Launch Grid
    -----------
    dim = [N_total]

    Each thread processes one atom over all K k-vectors.

    Parameters
    ----------
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1).
    k_vectors : wp.array2d, shape (B, K), dtype=wp.vec3f or wp.vec3d
        Per-system reciprocal lattice vectors.
    cos_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        :math:`\\cos(k \\cdot r_i)` from structure factor computation.
    sin_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        :math:`\\sin(k \\cdot r_i)` from structure factor computation.
    real_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        Per-system :math:`S_{\\text{real}}(s, k)` including Green's function.
    imag_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        Per-system :math:`S_{\\text{imag}}(s, k)` including Green's function.
    reciprocal_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Reciprocal-space energy per atom.
    atomic_forces : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Reciprocal-space forces per atom.

    Notes
    -----
    - Atom-major iteration avoids atomics (each thread owns its atom output)
    - cos_k_dot_r and sin_k_dot_r are unweighted; charge is multiplied here
    - Energy and forces are accumulated in float64; forces are cast to the k_vectors dtype on write
    """
    atom_idx = wp.tid()
    num_k = real_structure_factors.shape[1]
    charge = wp.float64(charges[atom_idx])

    system_id = batch_id[atom_idx]

    # Fetch the first k-vector only to infer the output vector and scalar types
    k0 = k_vectors[system_id, 0]

    # Accumulate in registers (no atomics!)
    local_potential = wp.float64(0.0)
    local_force_x = wp.float64(0.0)
    local_force_y = wp.float64(0.0)
    local_force_z = wp.float64(0.0)

    for k_idx in range(num_k):
        cos_kr = charge * cos_k_dot_r[k_idx, atom_idx]
        sin_kr = charge * sin_k_dot_r[k_idx, atom_idx]

        # Load precomputed structure factors (already include green function)
        s_real = real_structure_factors[system_id, k_idx]
        s_imag = imag_structure_factors[system_id, k_idx]

        # Potential contribution
        phase_sum = s_real * cos_kr + s_imag * sin_kr
        local_potential += phase_sum

        # Force contribution
        force_scalar = s_real * sin_kr - s_imag * cos_kr
        k_vec = k_vectors[system_id, k_idx]
        local_force_x += force_scalar * wp.float64(k_vec[0])
        local_force_y += force_scalar * wp.float64(k_vec[1])
        local_force_z += force_scalar * wp.float64(k_vec[2])

    # Write final results (charge was already applied inside the loop; no atomics needed)
    reciprocal_energies[atom_idx] = wp.float64(0.5) * local_potential
    atomic_forces[atom_idx] = type(k0)(
        type(k0[0])(local_force_x),
        type(k0[0])(local_force_y),
        type(k0[0])(local_force_z),
    )
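

# Illustrative sketch (not part of this module): a plain-Python version of the
# energy and force sums above for one system, useful for checking that the force
# formula is the negative gradient of the summed energy. `green` is a
# hypothetical list holding the per-k prefactor that the structure-factor kernel
# folds into S_real and S_imag.
```python
import math


def energy_and_forces(positions, charges, k_vectors, green):
    """Return (total reciprocal energy, per-atom forces) for one system."""
    n = len(charges)
    energy = 0.0
    forces = [[0.0, 0.0, 0.0] for _ in range(n)]
    for k_vec, g in zip(k_vectors, green):
        phases = [sum(kc * rc for kc, rc in zip(k_vec, r)) for r in positions]
        s_real = g * sum(q * math.cos(p) for q, p in zip(charges, phases))
        s_imag = g * sum(q * math.sin(p) for q, p in zip(charges, phases))
        for i in range(n):
            c, s = math.cos(phases[i]), math.sin(phases[i])
            energy += 0.5 * charges[i] * (s_real * c + s_imag * s)
            scalar = charges[i] * (s_real * s - s_imag * c)
            for axis in range(3):
                forces[i][axis] += scalar * k_vec[axis]
    return energy, forces
```
# A central finite difference of the total energy with respect to one
# coordinate should agree with the matching force component, up to sign.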


@wp.kernel
def _batch_ewald_reciprocal_space_energy_forces_charge_grad_kernel(
    charges: wp.array(dtype=Any),
    batch_id: wp.array(dtype=wp.int32),
    k_vectors: wp.array2d(dtype=Any),
    cos_k_dot_r: wp.array2d(dtype=wp.float64),
    sin_k_dot_r: wp.array2d(dtype=wp.float64),
    real_structure_factors: wp.array2d(dtype=wp.float64),
    imag_structure_factors: wp.array2d(dtype=wp.float64),
    reciprocal_energies: wp.array(dtype=wp.float64),
    atomic_forces: wp.array(dtype=Any),
    charge_gradients: wp.array(dtype=wp.float64),
):
    """Compute reciprocal-space energies, forces, AND charge gradients for batched systems.

    This kernel computes all three quantities in a single pass:

    .. math::

        \\begin{aligned}
        E_i &= \\frac{1}{2} \\sum_k [S_{\\text{real}}(s,k) \\cos(k \\cdot r_i) + S_{\\text{imag}}(s,k) \\sin(k \\cdot r_i)] q_i \\\\
        F_i &= \\sum_k k [S_{\\text{real}}(s,k) \\sin(k \\cdot r_i) - S_{\\text{imag}}(s,k) \\cos(k \\cdot r_i)] q_i \\\\
        \\partial E / \\partial q_i &= \\sum_k [S_{\\text{real}}(s,k) \\cos(k \\cdot r_i) + S_{\\text{imag}}(s,k) \\sin(k \\cdot r_i)] = \\phi_i
        \\end{aligned}

    where :math:`E = \\sum_j E_j` is the total reciprocal-space energy of system s and
    :math:`\\phi_i` is the electrostatic potential at atom i from system s.

    Launch Grid
    -----------
    dim = [N_total]

    Each thread processes one atom over all K k-vectors.

    Parameters
    ----------
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1).
    k_vectors : wp.array2d, shape (B, K), dtype=wp.vec3f or wp.vec3d
        Per-system reciprocal lattice vectors.
    cos_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        :math:`\\cos(k \\cdot r_i)` from structure factor computation (unweighted).
    sin_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        :math:`\\sin(k \\cdot r_i)` from structure factor computation (unweighted).
    real_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        Per-system :math:`S_{\\text{real}}(s, k)` including Green's function.
    imag_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        Per-system :math:`S_{\\text{imag}}(s, k)` including Green's function.
    reciprocal_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Reciprocal-space energy per atom.
    atomic_forces : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Reciprocal-space forces per atom.
    charge_gradients : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Electrostatic potential :math:`\\phi_i` per atom (reciprocal part of charge gradient).
    """
    atom_idx = wp.tid()
    num_k = real_structure_factors.shape[1]
    charge = wp.float64(charges[atom_idx])

    system_id = batch_id[atom_idx]

    # Fetch the first k-vector only to infer the output vector and scalar types
    k0 = k_vectors[system_id, 0]

    # Accumulate in registers (no atomics!)
    local_potential = wp.float64(0.0)
    local_potential_uncharged = wp.float64(0.0)
    local_force_x = wp.float64(0.0)
    local_force_y = wp.float64(0.0)
    local_force_z = wp.float64(0.0)

    for k_idx in range(num_k):
        cos_kr = cos_k_dot_r[k_idx, atom_idx]
        sin_kr = sin_k_dot_r[k_idx, atom_idx]

        # Load precomputed structure factors (already include green function)
        s_real = real_structure_factors[system_id, k_idx]
        s_imag = imag_structure_factors[system_id, k_idx]

        # Potential contribution
        phase_sum = s_real * cos_kr + s_imag * sin_kr
        local_potential += charge * phase_sum
        local_potential_uncharged += phase_sum

        # Force contribution
        force_scalar = charge * (s_real * sin_kr - s_imag * cos_kr)
        k_vec = k_vectors[system_id, k_idx]
        local_force_x += force_scalar * wp.float64(k_vec[0])
        local_force_y += force_scalar * wp.float64(k_vec[1])
        local_force_z += force_scalar * wp.float64(k_vec[2])

    # Write final results (no atomics needed)
    # Energy
    reciprocal_energies[atom_idx] = wp.float64(0.5) * local_potential

    # Forces
    atomic_forces[atom_idx] = type(k0)(
        type(k0[0])(local_force_x),
        type(k0[0])(local_force_y),
        type(k0[0])(local_force_z),
    )

    # Charge gradient
    # Self-energy and background corrections applied in higher-level code
    charge_gradients[atom_idx] = local_potential_uncharged
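

# Illustrative sketch (not part of this module): a plain-Python check that the
# unweighted potential written to charge_gradients is the derivative of the
# summed reciprocal energy with respect to each charge. As in the structure
# factor kernel, S_real and S_imag are built from the same charges; `green` is a
# hypothetical list of per-k prefactors.
```python
import math


def potential_and_energy(positions, charges, k_vectors, green):
    """Return (phi, total energy) with phi[i] the potential at atom i."""
    n = len(charges)
    phi = [0.0] * n
    for k_vec, g in zip(k_vectors, green):
        phases = [sum(kc * rc for kc, rc in zip(k_vec, r)) for r in positions]
        s_real = g * sum(q * math.cos(p) for q, p in zip(charges, phases))
        s_imag = g * sum(q * math.sin(p) for q, p in zip(charges, phases))
        for i in range(n):
            phi[i] += s_real * math.cos(phases[i]) + s_imag * math.sin(phases[i])
    energy = 0.5 * sum(q * p for q, p in zip(charges, phi))
    return phi, energy
```
# Because the summed energy is quadratic in the charges, the factor 1/2 in E_i
# cancels when differentiating the total, leaving phi_i.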


###########################################################################################
########################### Warp Launchers (Framework-Agnostic) ############################
###########################################################################################


def ewald_real_space_energy(
    positions: wp.array,
    charges: wp.array,
    cell: wp.array,
    idx_j: wp.array,
    neighbor_ptr: wp.array,
    unit_shifts: wp.array,
    alpha: wp.array,
    pair_energies: wp.array,
    wp_dtype: type,
    device: str | None = None,
) -> None:
    """Launch Ewald real-space energy kernel using CSR neighbor list format.

    This is a framework-agnostic launcher that accepts warp arrays directly.

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic positions.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (1,), dtype=wp.mat33f or wp.mat33d
        Unit cell matrix.
    idx_j : wp.array, shape (M,), dtype=wp.int32
        Target atom indices (CSR data).
    neighbor_ptr : wp.array, shape (N+1,), dtype=wp.int32
        CSR row pointers.
    unit_shifts : wp.array, shape (M,), dtype=wp.vec3i
        Periodic image shifts.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    pair_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Per-atom energies. Must be pre-allocated and zeroed.
    wp_dtype : type
        Warp scalar type (wp.float32 or wp.float64).
    device : str, optional
        Warp device. If None, inferred from positions.
    """
    num_atoms = positions.shape[0]
    if device is None:
        device = str(positions.device)

    wp.launch(
        _ewald_real_space_energy_kernel_overload[wp_dtype],
        dim=num_atoms,
        inputs=[
            positions,
            charges,
            cell,
            idx_j,
            neighbor_ptr,
            unit_shifts,
            alpha,
            pair_energies,
        ],
        device=device,
    )


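# Illustrative sketch (not part of this module): building the CSR pair
# (idx_j, neighbor_ptr) from an unsorted pair list in plain Python. It assumes
# that row i of the CSR structure holds the neighbors of atom i. The names
# `idx_i` and `pairs_to_csr` are hypothetical, and in practice unit_shifts would
# have to be permuted with the same ordering.
```python
def pairs_to_csr(idx_i, idx_j, num_atoms):
    """Group (i, j) pairs by source atom i; return (idx_j_sorted, neighbor_ptr)."""
    counts = [0] * num_atoms
    for i in idx_i:
        counts[i] += 1
    # neighbor_ptr[i]:neighbor_ptr[i + 1] is the slice of neighbors of atom i.
    neighbor_ptr = [0] * (num_atoms + 1)
    for i in range(num_atoms):
        neighbor_ptr[i + 1] = neighbor_ptr[i] + counts[i]
    # sorted() is stable, so pairs of the same atom keep their input order.
    order = sorted(range(len(idx_i)), key=lambda p: idx_i[p])
    return [idx_j[p] for p in order], neighbor_ptr


def neighbors_of(atom, idx_j_sorted, neighbor_ptr):
    """Return the neighbor indices stored for one atom."""
    return idx_j_sorted[neighbor_ptr[atom]:neighbor_ptr[atom + 1]]
```
# With N atoms and M pairs this yields M neighbor indices and N + 1 row
# pointers, matching the shapes documented for idx_j and neighbor_ptr.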
def ewald_real_space_energy_forces(
    positions: wp.array,
    charges: wp.array,
    cell: wp.array,
    idx_j: wp.array,
    neighbor_ptr: wp.array,
    unit_shifts: wp.array,
    alpha: wp.array,
    pair_energies: wp.array,
    atomic_forces: wp.array,
    virial: wp.array,
    wp_dtype: type,
    device: str | None = None,
    compute_virial: bool = False,
) -> None:
    """Launch Ewald real-space energy and forces kernel using CSR neighbor list.

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic positions.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (1,), dtype=wp.mat33f or wp.mat33d
        Unit cell matrix.
    idx_j : wp.array, shape (M,), dtype=wp.int32
        Target atom indices (CSR data).
    neighbor_ptr : wp.array, shape (N+1,), dtype=wp.int32
        CSR row pointers.
    unit_shifts : wp.array, shape (M,), dtype=wp.vec3i
        Periodic image shifts.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    pair_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    atomic_forces : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Per-atom forces.
    virial : wp.array, shape (1,), dtype=wp.mat33f or wp.mat33d
        OUTPUT: Virial tensor. Only written when compute_virial=True.
        Must be pre-allocated and zeroed by caller.
    wp_dtype : type
        Warp scalar type (wp.float32 or wp.float64).
    device : str, optional
        Warp device.
    compute_virial : bool, optional
        Whether to compute the virial tensor. Default False.
    """
    num_atoms = positions.shape[0]
    if device is None:
        device = str(positions.device)

    wp.launch(
        _ewald_real_space_energy_forces_kernel_overload[wp_dtype],
        dim=num_atoms,
        inputs=[
            positions,
            charges,
            cell,
            idx_j,
            neighbor_ptr,
            unit_shifts,
            alpha,
            compute_virial,
            pair_energies,
            atomic_forces,
            virial,
        ],
        device=device,
    )


def ewald_real_space_energy_matrix(
    positions: wp.array,
    charges: wp.array,
    cell: wp.array,
    neighbor_matrix: wp.array,
    unit_shifts_matrix: wp.array,
    mask_value: int,
    alpha: wp.array,
    pair_energies: wp.array,
    wp_dtype: type,
    device: str | None = None,
) -> None:
    """Launch Ewald real-space energy kernel using neighbor matrix format.

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic positions.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (1,), dtype=wp.mat33f or wp.mat33d
        Unit cell matrix.
    neighbor_matrix : wp.array2d, shape (N, max_neighbors), dtype=wp.int32
        Neighbor indices.
    unit_shifts_matrix : wp.array2d, shape (N, max_neighbors), dtype=wp.vec3i
        Periodic image shifts.
    mask_value : int
        Value indicating invalid/padded entries.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    pair_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    wp_dtype : type
        Warp scalar type (wp.float32 or wp.float64).
    device : str, optional
        Warp device.
    """
    num_atoms = neighbor_matrix.shape[0]
    if device is None:
        device = str(positions.device)

    wp.launch(
        _ewald_real_space_energy_neighbor_matrix_kernel_overload[wp_dtype],
        dim=num_atoms,
        inputs=[
            positions,
            charges,
            cell,
            neighbor_matrix,
            unit_shifts_matrix,
            wp.int32(mask_value),
            alpha,
            pair_energies,
        ],
        device=device,
    )


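# Illustrative sketch (not part of this module): converting a CSR neighbor list
# into a padded (N, max_neighbors) neighbor matrix in plain Python. Padding at
# the end of each row with mask_value is an assumption about the layout; the
# helper name is hypothetical and -1 is only an example mask value.
```python
def csr_to_neighbor_matrix(idx_j, neighbor_ptr, mask_value):
    """Return a list of rows, each padded with mask_value to equal length."""
    num_atoms = len(neighbor_ptr) - 1
    max_neighbors = max(
        (neighbor_ptr[i + 1] - neighbor_ptr[i] for i in range(num_atoms)),
        default=0,
    )
    matrix = []
    for i in range(num_atoms):
        row = list(idx_j[neighbor_ptr[i]:neighbor_ptr[i + 1]])
        # Entries equal to mask_value mark padding that should be skipped.
        matrix.append(row + [mask_value] * (max_neighbors - len(row)))
    return matrix
```
# Every row has max_neighbors entries, so padded slots must be recognised by
# comparing against mask_value before they are used as atom indices.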
def ewald_real_space_energy_forces_matrix(
    positions: wp.array,
    charges: wp.array,
    cell: wp.array,
    neighbor_matrix: wp.array,
    unit_shifts_matrix: wp.array,
    mask_value: int,
    alpha: wp.array,
    pair_energies: wp.array,
    atomic_forces: wp.array,
    virial: wp.array,
    wp_dtype: type,
    device: str | None = None,
    compute_virial: bool = False,
) -> None:
    """Launch Ewald real-space energy and forces kernel using neighbor matrix.

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic positions.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (1,), dtype=wp.mat33f or wp.mat33d
        Unit cell matrix.
    neighbor_matrix : wp.array2d, shape (N, max_neighbors), dtype=wp.int32
        Neighbor indices.
    unit_shifts_matrix : wp.array2d, shape (N, max_neighbors), dtype=wp.vec3i
        Periodic image shifts.
    mask_value : int
        Value indicating invalid/padded entries.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    pair_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    atomic_forces : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Per-atom forces.
    wp_dtype : type
        Warp scalar type (wp.float32 or wp.float64).
    device : str, optional
        Warp device.
    compute_virial : bool, optional
        Whether to compute the virial tensor. Default False.
    virial : wp.array, optional
        OUTPUT: Virial tensor. Must be pre-allocated by caller.
    """
    num_atoms = neighbor_matrix.shape[0]
    if device is None:
        device = str(positions.device)

    wp.launch(
        _ewald_real_space_energy_forces_neighbor_matrix_kernel_overload[wp_dtype],
        dim=num_atoms,
        inputs=[
            positions,
            charges,
            cell,
            neighbor_matrix,
            unit_shifts_matrix,
            wp.int32(mask_value),
            alpha,
            compute_virial,
            pair_energies,
            atomic_forces,
            virial,
        ],
        device=device,
    )


def ewald_real_space_energy_forces_charge_grad(
    positions: wp.array,
    charges: wp.array,
    cell: wp.array,
    idx_j: wp.array,
    neighbor_ptr: wp.array,
    unit_shifts: wp.array,
    alpha: wp.array,
    pair_energies: wp.array,
    atomic_forces: wp.array,
    charge_gradients: wp.array,
    virial: wp.array,
    wp_dtype: type,
    device: str | None = None,
    compute_virial: bool = False,
) -> None:
    """Launch Ewald real-space energy, forces, and charge gradients kernel (CSR).

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic positions.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (1,), dtype=wp.mat33f or wp.mat33d
        Unit cell matrix.
    idx_j : wp.array, shape (M,), dtype=wp.int32
        Target atom indices (CSR data).
    neighbor_ptr : wp.array, shape (N+1,), dtype=wp.int32
        CSR row pointers.
    unit_shifts : wp.array, shape (M,), dtype=wp.vec3i
        Periodic image shifts.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    pair_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    atomic_forces : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Per-atom forces.
    charge_gradients : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Per-atom charge gradients.
    wp_dtype : type
        Warp scalar type (wp.float32 or wp.float64).
    device : str, optional
        Warp device.
    compute_virial : bool, optional
        Whether to compute the virial tensor. Default False.
    virial : wp.array, optional
        OUTPUT: Virial tensor. Must be pre-allocated by caller.
    """
    num_atoms = positions.shape[0]
    if device is None:
        device = str(positions.device)

    wp.launch(
        _ewald_real_space_energy_forces_charge_grad_kernel_overload[wp_dtype],
        dim=num_atoms,
        inputs=[
            positions,
            charges,
            cell,
            idx_j,
            neighbor_ptr,
            unit_shifts,
            alpha,
            compute_virial,
            pair_energies,
            atomic_forces,
            charge_gradients,
            virial,
        ],
        device=device,
    )


def ewald_real_space_energy_forces_charge_grad_matrix(
    positions: wp.array,
    charges: wp.array,
    cell: wp.array,
    neighbor_matrix: wp.array,
    unit_shifts_matrix: wp.array,
    mask_value: int,
    alpha: wp.array,
    pair_energies: wp.array,
    atomic_forces: wp.array,
    charge_gradients: wp.array,
    virial: wp.array,
    wp_dtype: type,
    device: str | None = None,
    compute_virial: bool = False,
) -> None:
    """Launch Ewald real-space energy, forces, and charge gradients kernel (matrix).

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic positions.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (1,), dtype=wp.mat33f or wp.mat33d
        Unit cell matrix.
    neighbor_matrix : wp.array2d, shape (N, max_neighbors), dtype=wp.int32
        Neighbor indices.
    unit_shifts_matrix : wp.array2d, shape (N, max_neighbors), dtype=wp.vec3i
        Periodic image shifts.
    mask_value : int
        Value indicating invalid/padded entries.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    pair_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    atomic_forces : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Per-atom forces.
    charge_gradients : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Per-atom charge gradients.
    virial : wp.array, shape (1,), dtype=wp.mat33f or wp.mat33d
        OUTPUT: Virial tensor. Only written when compute_virial=True.
        Must be pre-allocated by caller.
    wp_dtype : type
        Warp scalar type (wp.float32 or wp.float64).
    device : str, optional
        Warp device.
    compute_virial : bool, optional
        Whether to compute the virial tensor. Default False.
    """
    num_atoms = neighbor_matrix.shape[0]
    if device is None:
        device = str(positions.device)

    wp.launch(
        _ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel_overload[
            wp_dtype
        ],
        dim=num_atoms,
        inputs=[
            positions,
            charges,
            cell,
            neighbor_matrix,
            unit_shifts_matrix,
            wp.int32(mask_value),
            alpha,
            compute_virial,
            pair_energies,
            atomic_forces,
            charge_gradients,
            virial,
        ],
        device=device,
    )


# ==================== Batch Real-Space Launchers ====================


def batch_ewald_real_space_energy(
    positions: wp.array,
    charges: wp.array,
    cell: wp.array,
    batch_id: wp.array,
    idx_j: wp.array,
    neighbor_ptr: wp.array,
    unit_shifts: wp.array,
    alpha: wp.array,
    pair_energies: wp.array,
    wp_dtype: type,
    device: str | None = None,
) -> None:
    """Launch batched Ewald real-space energy kernel using CSR neighbor list.

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic positions (all systems concatenated).
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (B,), dtype=wp.mat33f or wp.mat33d
        Unit cell matrices for each system.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom.
    idx_j : wp.array, shape (M,), dtype=wp.int32
        Target atom indices (CSR data).
    neighbor_ptr : wp.array, shape (N_total+1,), dtype=wp.int32
        CSR row pointers.
    unit_shifts : wp.array, shape (M,), dtype=wp.vec3i
        Periodic image shifts.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    pair_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    wp_dtype : type
        Warp scalar type (wp.float32 or wp.float64).
    device : str, optional
        Warp device.
    """
    num_atoms = positions.shape[0]
    if device is None:
        device = str(positions.device)

    wp.launch(
        _batch_ewald_real_space_energy_kernel_overload[wp_dtype],
        dim=num_atoms,
        inputs=[
            positions,
            charges,
            cell,
            batch_id,
            idx_j,
            neighbor_ptr,
            unit_shifts,
            alpha,
            pair_energies,
        ],
        device=device,
    )


def batch_ewald_real_space_energy_forces(
    positions: wp.array,
    charges: wp.array,
    cell: wp.array,
    batch_id: wp.array,
    idx_j: wp.array,
    neighbor_ptr: wp.array,
    unit_shifts: wp.array,
    alpha: wp.array,
    pair_energies: wp.array,
    atomic_forces: wp.array,
    virial: wp.array,
    wp_dtype: type,
    device: str | None = None,
    compute_virial: bool = False,
) -> None:
    """Launch batched Ewald real-space energy and forces kernel (CSR).

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic positions.
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (B,), dtype=wp.mat33f or wp.mat33d
        Unit cell matrices.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom.
    idx_j : wp.array, shape (M,), dtype=wp.int32
        Target atom indices.
    neighbor_ptr : wp.array, shape (N_total+1,), dtype=wp.int32
        CSR row pointers.
    unit_shifts : wp.array, shape (M,), dtype=wp.vec3i
        Periodic image shifts.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    pair_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    atomic_forces : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Per-atom forces.
    wp_dtype : type
        Warp scalar type.
    device : str, optional
        Warp device.
    compute_virial : bool, optional
        Whether to compute the virial tensor. Default False.
    virial : wp.array, optional
        OUTPUT: Virial tensor, shape (B,). If None, a dummy array is created.
    """
    num_atoms = positions.shape[0]
    if device is None:
        device = str(positions.device)

    wp.launch(
        _batch_ewald_real_space_energy_forces_kernel_overload[wp_dtype],
        dim=num_atoms,
        inputs=[
            positions,
            charges,
            cell,
            batch_id,
            idx_j,
            neighbor_ptr,
            unit_shifts,
            alpha,
            compute_virial,
            pair_energies,
            atomic_forces,
            virial,
        ],
        device=device,
    )


def batch_ewald_real_space_energy_matrix(
    positions: wp.array,
    charges: wp.array,
    cell: wp.array,
    batch_id: wp.array,
    neighbor_matrix: wp.array,
    unit_shifts_matrix: wp.array,
    mask_value: int,
    alpha: wp.array,
    pair_energies: wp.array,
    wp_dtype: type,
    device: str | None = None,
) -> None:
    """Launch batched Ewald real-space energy kernel using neighbor matrix.

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic positions.
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (B,), dtype=wp.mat33f or wp.mat33d
        Unit cell matrices.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom.
    neighbor_matrix : wp.array2d, shape (N_total, max_neighbors), dtype=wp.int32
        Neighbor indices.
    unit_shifts_matrix : wp.array2d, shape (N_total, max_neighbors), dtype=wp.vec3i
        Periodic image shifts.
    mask_value : int
        Value indicating invalid entries.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    pair_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    wp_dtype : type
        Warp scalar type.
    device : str, optional
        Warp device.
    """
    num_atoms = neighbor_matrix.shape[0]
    if device is None:
        device = str(positions.device)

    wp.launch(
        _batch_ewald_real_space_energy_neighbor_matrix_kernel_overload[wp_dtype],
        dim=num_atoms,
        inputs=[
            positions,
            charges,
            cell,
            batch_id,
            neighbor_matrix,
            unit_shifts_matrix,
            wp.int32(mask_value),
            alpha,
            pair_energies,
        ],
        device=device,
    )


def batch_ewald_real_space_energy_forces_matrix(
    positions: wp.array,
    charges: wp.array,
    cell: wp.array,
    batch_id: wp.array,
    neighbor_matrix: wp.array,
    unit_shifts_matrix: wp.array,
    mask_value: int,
    alpha: wp.array,
    pair_energies: wp.array,
    atomic_forces: wp.array,
    virial: wp.array,
    wp_dtype: type,
    device: str | None = None,
    compute_virial: bool = False,
) -> None:
    """Launch batched Ewald real-space energy and forces kernel (matrix).

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic positions.
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (B,), dtype=wp.mat33f or wp.mat33d
        Unit cell matrices.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom.
    neighbor_matrix : wp.array2d, shape (N_total, max_neighbors), dtype=wp.int32
        Neighbor indices.
    unit_shifts_matrix : wp.array2d, shape (N_total, max_neighbors), dtype=wp.vec3i
        Periodic image shifts.
    mask_value : int
        Value indicating invalid entries.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    pair_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    atomic_forces : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Per-atom forces.
    wp_dtype : type
        Warp scalar type.
    device : str, optional
        Warp device.
    compute_virial : bool, optional
        Whether to compute the virial tensor. Default False.
    virial : wp.array, optional
        OUTPUT: Virial tensor, shape (B,). If None, a dummy array is created.
    """
    num_atoms = neighbor_matrix.shape[0]
    if device is None:
        device = str(positions.device)

    wp.launch(
        _batch_ewald_real_space_energy_forces_neighbor_matrix_kernel_overload[wp_dtype],
        dim=num_atoms,
        inputs=[
            positions,
            charges,
            cell,
            batch_id,
            neighbor_matrix,
            unit_shifts_matrix,
            wp.int32(mask_value),
            alpha,
            compute_virial,
            pair_energies,
            atomic_forces,
            virial,
        ],
        device=device,
    )


def batch_ewald_real_space_energy_forces_charge_grad(
    positions: wp.array,
    charges: wp.array,
    cell: wp.array,
    batch_id: wp.array,
    idx_j: wp.array,
    neighbor_ptr: wp.array,
    unit_shifts: wp.array,
    alpha: wp.array,
    pair_energies: wp.array,
    atomic_forces: wp.array,
    charge_gradients: wp.array,
    virial: wp.array,
    wp_dtype: type,
    device: str | None = None,
    compute_virial: bool = False,
) -> None:
    """Launch batched Ewald real-space energy, forces, charge gradients kernel (CSR).

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic positions.
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (B,), dtype=wp.mat33f or wp.mat33d
        Unit cell matrices.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom.
    idx_j : wp.array, shape (M,), dtype=wp.int32
        Target atom indices.
    neighbor_ptr : wp.array, shape (N_total+1,), dtype=wp.int32
        CSR row pointers.
    unit_shifts : wp.array, shape (M,), dtype=wp.vec3i
        Periodic image shifts.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    pair_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    atomic_forces : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Per-atom forces.
    charge_gradients : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Per-atom charge gradients.
    virial : wp.array, shape (B,), dtype=wp.mat33f or wp.mat33d
        OUTPUT: Virial tensor. This launcher passes the array to the kernel
        unchanged; it does not create a dummy array.
    wp_dtype : type
        Warp scalar type.
    device : str, optional
        Warp device.
    compute_virial : bool, optional
        Whether to compute the virial tensor. Default False.
    """
    num_atoms = positions.shape[0]
    if device is None:
        device = str(positions.device)

    wp.launch(
        _batch_ewald_real_space_energy_forces_charge_grad_kernel_overload[wp_dtype],
        dim=num_atoms,
        inputs=[
            positions,
            charges,
            cell,
            batch_id,
            idx_j,
            neighbor_ptr,
            unit_shifts,
            alpha,
            compute_virial,
            pair_energies,
            atomic_forces,
            charge_gradients,
            virial,
        ],
        device=device,
    )


def batch_ewald_real_space_energy_forces_charge_grad_matrix(
    positions: wp.array,
    charges: wp.array,
    cell: wp.array,
    batch_id: wp.array,
    neighbor_matrix: wp.array,
    unit_shifts_matrix: wp.array,
    mask_value: int,
    alpha: wp.array,
    pair_energies: wp.array,
    atomic_forces: wp.array,
    charge_gradients: wp.array,
    virial: wp.array,
    wp_dtype: type,
    device: str | None = None,
    compute_virial: bool = False,
) -> None:
    """Launch batched Ewald real-space energy, forces, charge gradients kernel (matrix).

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic positions.
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell : wp.array, shape (B,), dtype=wp.mat33f or wp.mat33d
        Unit cell matrices.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom.
    neighbor_matrix : wp.array2d, shape (N_total, max_neighbors), dtype=wp.int32
        Neighbor indices.
    unit_shifts_matrix : wp.array2d, shape (N_total, max_neighbors), dtype=wp.vec3i
        Periodic image shifts.
    mask_value : int
        Value indicating invalid entries.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    pair_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    atomic_forces : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Per-atom forces.
    charge_gradients : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Per-atom charge gradients.
    virial : wp.array, shape (B,), dtype=wp.mat33f or wp.mat33d
        OUTPUT: Virial tensor. This launcher passes the array to the kernel
        unchanged; it does not create a dummy array.
    wp_dtype : type
        Warp scalar type.
    device : str, optional
        Warp device.
    compute_virial : bool, optional
        Whether to compute the virial tensor. Default False.
    """
    num_atoms = neighbor_matrix.shape[0]
    if device is None:
        device = str(positions.device)

    wp.launch(
        _batch_ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel_overload[
            wp_dtype
        ],
        dim=num_atoms,
        inputs=[
            positions,
            charges,
            cell,
            batch_id,
            neighbor_matrix,
            unit_shifts_matrix,
            wp.int32(mask_value),
            alpha,
            compute_virial,
            pair_energies,
            atomic_forces,
            charge_gradients,
            virial,
        ],
        device=device,
    )


# ==================== Reciprocal-Space Launchers ====================

def ewald_reciprocal_space_fill_structure_factors(
    positions: wp.array,
    charges: wp.array,
    k_vectors: wp.array,
    cell: wp.array,
    alpha: wp.array,
    total_charge: wp.array,
    cos_k_dot_r: wp.array,
    sin_k_dot_r: wp.array,
    real_structure_factors: wp.array,
    imag_structure_factors: wp.array,
    wp_dtype: type,
    device: str | None = None,
) -> None:
    """Launch kernel to compute structure factors for reciprocal-space Ewald.

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic positions.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    k_vectors : wp.array, shape (K,), dtype=wp.vec3f or wp.vec3d
        Half-space reciprocal lattice vectors.
    cell : wp.array, shape (1,), dtype=wp.mat33f or wp.mat33d
        Unit cell matrix.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    total_charge : wp.array, shape (1,), dtype=wp.float64
        OUTPUT: Q_total/V for background correction.
    cos_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        OUTPUT: cos(k.r) for each (k, atom) pair.
    sin_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        OUTPUT: sin(k.r) for each (k, atom) pair.
    real_structure_factors : wp.array, shape (K,), dtype=wp.float64
        OUTPUT: Real part of weighted structure factors.
    imag_structure_factors : wp.array, shape (K,), dtype=wp.float64
        OUTPUT: Imaginary part of weighted structure factors.
    wp_dtype : type
        Warp scalar type.
    device : str, optional
        Warp device.
    """
    num_k = k_vectors.shape[0]
    if device is None:
        device = str(positions.device)

    wp.launch(
        _ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload[wp_dtype],
        dim=num_k,
        inputs=[
            positions,
            charges,
            k_vectors,
            cell,
            alpha,
            total_charge,
            cos_k_dot_r,
            sin_k_dot_r,
            real_structure_factors,
            imag_structure_factors,
        ],
        device=device,
    )

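As a reading aid for the output shapes documented above, the following pure-Python sketch fills `(K, N)` cosine and sine tables and accumulates one complex structure factor per k-vector. It is only a reference under assumptions: the docstring describes the kernel's structure factors as "weighted", and the weighting is not shown in this section, so this sketch computes the plain sum S(k) = sum_j q_j exp(i k.r_j). The function name `fill_structure_factors_reference` is hypothetical.

```python
import math


def fill_structure_factors_reference(positions, charges, k_vectors):
    """Unweighted structure factors plus the cos/sin tables, shape (K, N)."""
    num_k, num_atoms = len(k_vectors), len(positions)
    cos_k_dot_r = [[0.0] * num_atoms for _ in range(num_k)]
    sin_k_dot_r = [[0.0] * num_atoms for _ in range(num_k)]
    real_sf = [0.0] * num_k
    imag_sf = [0.0] * num_k
    for k_idx, k in enumerate(k_vectors):
        for i, (r, q) in enumerate(zip(positions, charges)):
            phase = k[0] * r[0] + k[1] * r[1] + k[2] * r[2]
            cos_k_dot_r[k_idx][i] = math.cos(phase)
            sin_k_dot_r[k_idx][i] = math.sin(phase)
            # Accumulate the real and imaginary parts of sum_j q_j exp(i k.r_j).
            real_sf[k_idx] += q * cos_k_dot_r[k_idx][i]
            imag_sf[k_idx] += q * sin_k_dot_r[k_idx][i]
    return cos_k_dot_r, sin_k_dot_r, real_sf, imag_sf


# Opposite charges half a box length apart along x, first k-vector of a box of length 4.
box = 4.0
positions = [(0.0, 0.0, 0.0), (box / 2.0, 0.0, 0.0)]
charges = [1.0, -1.0]
k_vectors = [(2.0 * math.pi / box, 0.0, 0.0)]
cos_table, sin_table, real_sf, imag_sf = fill_structure_factors_reference(
    positions, charges, k_vectors
)
```

Keeping the per-atom tables around is what lets the later energy and force launchers reuse `cos_k_dot_r` and `sin_k_dot_r` instead of recomputing the phases.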
def ewald_reciprocal_space_compute_energy(
    charges: wp.array,
    cos_k_dot_r: wp.array,
    sin_k_dot_r: wp.array,
    real_structure_factors: wp.array,
    imag_structure_factors: wp.array,
    reciprocal_energies: wp.array,
    wp_dtype: type,
    device: str | None = None,
) -> None:
    """Launch kernel to compute per-atom reciprocal-space energies.

    Parameters
    ----------
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cos_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        cos(k.r) from structure factor computation.
    sin_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        sin(k.r) from structure factor computation.
    real_structure_factors : wp.array, shape (K,), dtype=wp.float64
        Real structure factors.
    imag_structure_factors : wp.array, shape (K,), dtype=wp.float64
        Imaginary structure factors.
    reciprocal_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    wp_dtype : type
        Warp scalar type.
    device : str, optional
        Warp device.
    """
    num_atoms = charges.shape[0]
    if device is None:
        device = str(charges.device)

    wp.launch(
        _ewald_reciprocal_space_energy_kernel_compute_energy_overload[wp_dtype],
        dim=num_atoms,
        inputs=[
            charges,
            cos_k_dot_r,
            sin_k_dot_r,
            real_structure_factors,
            imag_structure_factors,
            reciprocal_energies,
        ],
        device=device,
    )

def ewald_subtract_self_energy(
    charges: wp.array,
    alpha: wp.array,
    total_charge: wp.array,
    energy_in: wp.array,
    energy_out: wp.array,
    wp_dtype: type,
    device: str | None = None,
) -> None:
    """Launch kernel to apply self-energy and background corrections.

    Parameters
    ----------
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    alpha : wp.array, shape (1,), dtype=wp.float32 or wp.float64
        Ewald splitting parameter.
    total_charge : wp.array, shape (1,), dtype=wp.float64
        Q_total/V from structure factor computation.
    energy_in : wp.array, shape (N,), dtype=wp.float64
        Raw reciprocal-space energies.
    energy_out : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Corrected energies.
    wp_dtype : type
        Warp scalar type.
    device : str, optional
        Warp device.
    """
    num_atoms = charges.shape[0]
    if device is None:
        device = str(charges.device)

    wp.launch(
        _ewald_subtract_self_energy_kernel_overload[wp_dtype],
        dim=num_atoms,
        inputs=[charges, alpha, total_charge, energy_in, energy_out],
        device=device,
    )

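The correction step above can be sketched in pure Python using the textbook Ewald expressions in Gaussian units (Coulomb constant equal to 1): a self-energy of alpha/sqrt(pi) times the sum of squared charges, and a neutralizing-background term of pi * Q_total^2 / (2 * alpha^2 * V). How the kernel splits the background term across atoms is not shown in this section; the per-atom split below, which uses the documented `Q_total/V` input, is an assumption, and the name `subtract_self_energy_reference` is hypothetical.

```python
import math


def subtract_self_energy_reference(charges, alpha, total_charge_over_volume, energy_in):
    """Subtract per-atom self-energy and (assumed) per-atom background share."""
    energy_out = []
    for q, e in zip(charges, energy_in):
        self_term = alpha / math.sqrt(math.pi) * q * q
        # Assumed per-atom split: summing over atoms gives pi * Q^2 / (2 alpha^2 V).
        background_term = math.pi / (2.0 * alpha * alpha) * q * total_charge_over_volume
        energy_out.append(e - self_term - background_term)
    return energy_out


alpha = 0.5

# Neutral pair: only the self-energy term contributes.
neutral = subtract_self_energy_reference([1.0, -1.0], alpha, 0.0, [0.0, 0.0])

# Charged pair in a volume of 8: Q_total = 2, so Q_total / V = 0.25.
charged = subtract_self_energy_reference([1.0, 1.0], alpha, 0.25, [0.0, 0.0])
```

For the charged pair, the background contributions sum to pi * 4 / (2 * 0.25 * 8) = pi, which matches the closed-form expression for the whole system.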
def ewald_reciprocal_space_energy_forces(
    charges: wp.array,
    k_vectors: wp.array,
    cos_k_dot_r: wp.array,
    sin_k_dot_r: wp.array,
    real_structure_factors: wp.array,
    imag_structure_factors: wp.array,
    reciprocal_energies: wp.array,
    atomic_forces: wp.array,
    wp_dtype: type,
    device: str | None = None,
) -> None:
    """Launch kernel to compute reciprocal-space energies and forces.

    Parameters
    ----------
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    k_vectors : wp.array, shape (K,), dtype=wp.vec3f or wp.vec3d
        Reciprocal lattice vectors.
    cos_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        cos(k.r) from structure factor computation.
    sin_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        sin(k.r) from structure factor computation.
    real_structure_factors : wp.array, shape (K,), dtype=wp.float64
        Real structure factors.
    imag_structure_factors : wp.array, shape (K,), dtype=wp.float64
        Imaginary structure factors.
    reciprocal_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    atomic_forces : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Per-atom forces.
    wp_dtype : type
        Warp scalar type.
    device : str, optional
        Warp device.
    """
    num_atoms = charges.shape[0]
    if device is None:
        device = str(charges.device)

    wp.launch(
        _ewald_reciprocal_space_energy_forces_kernel_overload[wp_dtype],
        dim=num_atoms,
        inputs=[
            charges,
            k_vectors,
            cos_k_dot_r,
            sin_k_dot_r,
            real_structure_factors,
            imag_structure_factors,
            reciprocal_energies,
            atomic_forces,
        ],
        device=device,
    )

def ewald_reciprocal_space_energy_forces_charge_grad(
    charges: wp.array,
    k_vectors: wp.array,
    cos_k_dot_r: wp.array,
    sin_k_dot_r: wp.array,
    real_structure_factors: wp.array,
    imag_structure_factors: wp.array,
    reciprocal_energies: wp.array,
    atomic_forces: wp.array,
    charge_gradients: wp.array,
    wp_dtype: type,
    device: str | None = None,
) -> None:
    """Launch kernel to compute reciprocal-space energies, forces, and charge gradients.

    Parameters
    ----------
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    k_vectors : wp.array, shape (K,), dtype=wp.vec3f or wp.vec3d
        Reciprocal lattice vectors.
    cos_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        cos(k.r) from structure factor computation.
    sin_k_dot_r : wp.array2d, shape (K, N), dtype=wp.float64
        sin(k.r) from structure factor computation.
    real_structure_factors : wp.array, shape (K,), dtype=wp.float64
        Real structure factors.
    imag_structure_factors : wp.array, shape (K,), dtype=wp.float64
        Imaginary structure factors.
    reciprocal_energies : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    atomic_forces : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Per-atom forces.
    charge_gradients : wp.array, shape (N,), dtype=wp.float64
        OUTPUT: Per-atom charge gradients.
    wp_dtype : type
        Warp scalar type.
    device : str, optional
        Warp device.
    """
    num_atoms = charges.shape[0]
    if device is None:
        device = str(charges.device)

    wp.launch(
        _ewald_reciprocal_space_energy_forces_charge_grad_kernel_overload[wp_dtype],
        dim=num_atoms,
        inputs=[
            charges,
            k_vectors,
            cos_k_dot_r,
            sin_k_dot_r,
            real_structure_factors,
            imag_structure_factors,
            reciprocal_energies,
            atomic_forces,
            charge_gradients,
        ],
        device=device,
    )


# ==================== Batch Reciprocal-Space Launchers ====================

def batch_ewald_reciprocal_space_fill_structure_factors(
    positions: wp.array,
    charges: wp.array,
    k_vectors: wp.array,
    cell: wp.array,
    alpha: wp.array,
    atom_start: wp.array,
    atom_end: wp.array,
    total_charges: wp.array,
    cos_k_dot_r: wp.array,
    sin_k_dot_r: wp.array,
    real_structure_factors: wp.array,
    imag_structure_factors: wp.array,
    num_k: int,
    num_systems: int,
    max_blocks_per_system: int,
    wp_dtype: type,
    device: str | None = None,
) -> None:
    """Launch batched kernel to compute structure factors for reciprocal-space Ewald.

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic positions.
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges.
    k_vectors : wp.array2d, shape (B, K), dtype=wp.vec3f or wp.vec3d
        Per-system reciprocal lattice vectors.
    cell : wp.array, shape (B,), dtype=wp.mat33f or wp.mat33d
        Per-system unit cell matrices.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    atom_start : wp.array, shape (B,), dtype=wp.int32
        First atom index for each system.
    atom_end : wp.array, shape (B,), dtype=wp.int32
        Last atom index (exclusive) for each system.
    total_charges : wp.array, shape (B,), dtype=wp.float64
        OUTPUT: Per-system Q_total/V.
    cos_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        OUTPUT: cos(k.r) for each (k, atom) pair.
    sin_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        OUTPUT: sin(k.r) for each (k, atom) pair.
    real_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        OUTPUT: Per-system real structure factors.
    imag_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        OUTPUT: Per-system imaginary structure factors.
    num_k : int
        Number of k-vectors per system.
    num_systems : int
        Number of systems in the batch.
    max_blocks_per_system : int
        Maximum atom blocks per system.
    wp_dtype : type
        Warp scalar type.
    device : str, optional
        Warp device.
    """
    if device is None:
        device = str(positions.device)

    wp.launch(
        _batch_ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload[
            wp_dtype
        ],
        dim=(num_k, num_systems, max_blocks_per_system),
        inputs=[
            positions,
            charges,
            k_vectors,
            cell,
            alpha,
            atom_start,
            atom_end,
            total_charges,
            cos_k_dot_r,
            sin_k_dot_r,
            real_structure_factors,
            imag_structure_factors,
        ],
        device=device,
    )

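The batched launcher above takes `atom_start`, `atom_end`, and `max_blocks_per_system` as precomputed inputs. A minimal sketch of how a caller might derive them from `batch_id` follows. It assumes atoms are stored contiguously per system and that each block covers a fixed number of atoms; the block size `atoms_per_block` and the helper name `batch_launch_layout` are hypothetical, since the block size used by the kernel is not shown in this section.

```python
def batch_launch_layout(batch_id, num_systems, atoms_per_block):
    """Derive per-system atom ranges and the block count for the launch grid."""
    counts = [0] * num_systems
    for system in batch_id:
        counts[system] += 1

    atom_start, atom_end, offset = [], [], 0
    for n in counts:
        atom_start.append(offset)  # first atom of this system
        offset += n
        atom_end.append(offset)  # one past the last atom (exclusive)

    # Ceiling division: the largest system decides the third grid dimension.
    max_blocks_per_system = max(-(-n // atoms_per_block) for n in counts)
    return atom_start, atom_end, max_blocks_per_system


# Five atoms in two systems (3 + 2), with a hypothetical block size of 2 atoms.
atom_start, atom_end, max_blocks = batch_launch_layout([0, 0, 0, 1, 1], 2, 2)
```

With these values the launch grid would be `(num_k, 2, max_blocks)`; threads whose block falls past `atom_end` for a smaller system would have no atoms to process.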
def batch_ewald_reciprocal_space_compute_energy(
    charges: wp.array,
    batch_id: wp.array,
    cos_k_dot_r: wp.array,
    sin_k_dot_r: wp.array,
    real_structure_factors: wp.array,
    imag_structure_factors: wp.array,
    reciprocal_energies: wp.array,
    wp_dtype: type,
    device: str | None = None,
) -> None:
    """Launch batched kernel to compute per-atom reciprocal-space energies.

    Parameters
    ----------
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom.
    cos_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        cos(k.r) from structure factor computation.
    sin_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        sin(k.r) from structure factor computation.
    real_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        Per-system real structure factors.
    imag_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        Per-system imaginary structure factors.
    reciprocal_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    wp_dtype : type
        Warp scalar type.
    device : str, optional
        Warp device.
    """
    num_atoms = charges.shape[0]
    if device is None:
        device = str(charges.device)

    wp.launch(
        _batch_ewald_reciprocal_space_energy_kernel_compute_energy_overload[wp_dtype],
        dim=num_atoms,
        inputs=[
            charges,
            batch_id,
            cos_k_dot_r,
            sin_k_dot_r,
            real_structure_factors,
            imag_structure_factors,
            reciprocal_energies,
        ],
        device=device,
    )

def batch_ewald_subtract_self_energy(
    charges: wp.array,
    batch_idx: wp.array,
    alpha: wp.array,
    total_charges: wp.array,
    energy_in: wp.array,
    energy_out: wp.array,
    wp_dtype: type,
    device: str | None = None,
) -> None:
    """Launch batched kernel to apply self-energy and background corrections.

    Parameters
    ----------
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges.
    batch_idx : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom.
    alpha : wp.array, shape (B,), dtype=wp.float32 or wp.float64
        Per-system Ewald splitting parameter.
    total_charges : wp.array, shape (B,), dtype=wp.float64
        Per-system Q_total/V.
    energy_in : wp.array, shape (N_total,), dtype=wp.float64
        Raw reciprocal-space energies.
    energy_out : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Corrected energies.
    wp_dtype : type
        Warp scalar type.
    device : str, optional
        Warp device.
    """
    num_atoms = charges.shape[0]
    if device is None:
        device = str(charges.device)

    wp.launch(
        _batch_ewald_subtract_self_energy_kernel_overload[wp_dtype],
        dim=num_atoms,
        inputs=[charges, batch_idx, alpha, total_charges, energy_in, energy_out],
        device=device,
    )

def batch_ewald_reciprocal_space_energy_forces(
    charges: wp.array,
    batch_id: wp.array,
    k_vectors: wp.array,
    cos_k_dot_r: wp.array,
    sin_k_dot_r: wp.array,
    real_structure_factors: wp.array,
    imag_structure_factors: wp.array,
    reciprocal_energies: wp.array,
    atomic_forces: wp.array,
    wp_dtype: type,
    device: str | None = None,
) -> None:
    """Launch batched kernel to compute reciprocal-space energies and forces.

    Parameters
    ----------
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom.
    k_vectors : wp.array2d, shape (B, K), dtype=wp.vec3f or wp.vec3d
        Per-system reciprocal lattice vectors.
    cos_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        cos(k.r) from structure factor computation.
    sin_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        sin(k.r) from structure factor computation.
    real_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        Per-system real structure factors.
    imag_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        Per-system imaginary structure factors.
    reciprocal_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    atomic_forces : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Per-atom forces.
    wp_dtype : type
        Warp scalar type.
    device : str, optional
        Warp device.
    """
    num_atoms = charges.shape[0]
    if device is None:
        device = str(charges.device)

    wp.launch(
        _batch_ewald_reciprocal_space_energy_forces_kernel_overload[wp_dtype],
        dim=num_atoms,
        inputs=[
            charges,
            batch_id,
            k_vectors,
            cos_k_dot_r,
            sin_k_dot_r,
            real_structure_factors,
            imag_structure_factors,
            reciprocal_energies,
            atomic_forces,
        ],
        device=device,
    )

def batch_ewald_reciprocal_space_energy_forces_charge_grad(
    charges: wp.array,
    batch_id: wp.array,
    k_vectors: wp.array,
    cos_k_dot_r: wp.array,
    sin_k_dot_r: wp.array,
    real_structure_factors: wp.array,
    imag_structure_factors: wp.array,
    reciprocal_energies: wp.array,
    atomic_forces: wp.array,
    charge_gradients: wp.array,
    wp_dtype: type,
    device: str | None = None,
) -> None:
    """Launch batched kernel for reciprocal-space energies, forces, charge gradients.

    Parameters
    ----------
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges.
    batch_id : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom.
    k_vectors : wp.array2d, shape (B, K), dtype=wp.vec3f or wp.vec3d
        Per-system reciprocal lattice vectors.
    cos_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        cos(k.r) from structure factor computation.
    sin_k_dot_r : wp.array2d, shape (K, N_total), dtype=wp.float64
        sin(k.r) from structure factor computation.
    real_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        Per-system real structure factors.
    imag_structure_factors : wp.array2d, shape (B, K), dtype=wp.float64
        Per-system imaginary structure factors.
    reciprocal_energies : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Per-atom energies.
    atomic_forces : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Per-atom forces.
    charge_gradients : wp.array, shape (N_total,), dtype=wp.float64
        OUTPUT: Per-atom charge gradients.
    wp_dtype : type
        Warp scalar type.
    device : str, optional
        Warp device.
    """
    num_atoms = charges.shape[0]
    if device is None:
        device = str(charges.device)

    wp.launch(
        _batch_ewald_reciprocal_space_energy_forces_charge_grad_kernel_overload[
            wp_dtype
        ],
        dim=num_atoms,
        inputs=[
            charges,
            batch_id,
            k_vectors,
            cos_k_dot_r,
            sin_k_dot_r,
            real_structure_factors,
            imag_structure_factors,
            reciprocal_energies,
            atomic_forces,
            charge_gradients,
        ],
        device=device,
    )


###########################################################################################
########################### Kernel Overloads (float32/float64) ############################
###########################################################################################

# Type aliases for clarity
_T = [wp.float32, wp.float64]
_V = [wp.vec3f, wp.vec3d]
_M = [wp.mat33f, wp.mat33d]

# Dictionaries to store overloads, keyed by scalar type (wp.float32 or wp.float64)

# Real-space single-system kernels
_ewald_real_space_energy_kernel_overload = {}
_ewald_real_space_energy_forces_kernel_overload = {}
_ewald_real_space_energy_neighbor_matrix_kernel_overload = {}
_ewald_real_space_energy_forces_neighbor_matrix_kernel_overload = {}

# Real-space single-system kernels with charge gradients
_ewald_real_space_energy_forces_charge_grad_kernel_overload = {}
_ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel_overload = {}

# Real-space batch kernels
_batch_ewald_real_space_energy_kernel_overload = {}
_batch_ewald_real_space_energy_forces_kernel_overload = {}
_batch_ewald_real_space_energy_neighbor_matrix_kernel_overload = {}
_batch_ewald_real_space_energy_forces_neighbor_matrix_kernel_overload = {}

# Real-space batch kernels with charge gradients
_batch_ewald_real_space_energy_forces_charge_grad_kernel_overload = {}
_batch_ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel_overload = {}

# Reciprocal-space single-system kernels
_ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload = {}
_ewald_reciprocal_space_energy_kernel_compute_energy_overload = {}
_ewald_reciprocal_space_energy_forces_kernel_overload = {}
_ewald_reciprocal_space_energy_forces_charge_grad_kernel_overload = {}
_ewald_subtract_self_energy_kernel_overload = {}
_ewald_reciprocal_space_virial_kernel_overload = {}

# Reciprocal-space batch kernels
_batch_ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload = {}
_batch_ewald_reciprocal_space_energy_kernel_compute_energy_overload = {}
_batch_ewald_reciprocal_space_energy_forces_kernel_overload = {}
_batch_ewald_reciprocal_space_energy_forces_charge_grad_kernel_overload = {}
_batch_ewald_subtract_self_energy_kernel_overload = {}
_batch_ewald_reciprocal_space_virial_kernel_overload = {}

for t, v, m in zip(_T, _V, _M):
    # ==================== Real-space single-system kernels ====================
    _ewald_real_space_energy_kernel_overload[t] = wp.overload(
        _ewald_real_space_energy_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # charges
            wp.array(dtype=m),  # cell
            wp.array(dtype=wp.int32),  # idx_i
            wp.array(dtype=wp.int32),  # idx_j
            wp.array(dtype=wp.vec3i),  # unit_shifts
            wp.array(dtype=t),  # alpha
            wp.array(dtype=wp.float64),  # pair_energies (always float64)
        ],
    )
    _ewald_real_space_energy_forces_kernel_overload[t] = wp.overload(
        _ewald_real_space_energy_forces_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # charges
            wp.array(dtype=m),  # cell
            wp.array(dtype=wp.int32),  # idx_j
            wp.array(dtype=wp.int32),  # neighbor_ptr
            wp.array(dtype=wp.vec3i),  # unit_shifts
            wp.array(dtype=t),  # alpha
            wp.bool,  # compute_virial
            wp.array(dtype=wp.float64),  # pair_energies
            wp.array(dtype=v),  # atomic_forces (matches positions dtype)
            wp.array(dtype=m),  # virial
        ],
    )
    _ewald_real_space_energy_neighbor_matrix_kernel_overload[t] = wp.overload(
        _ewald_real_space_energy_neighbor_matrix_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # charges
            wp.array(dtype=m),  # cell
            wp.array2d(dtype=wp.int32),  # neighbor_matrix
            wp.array2d(dtype=wp.vec3i),  # unit_shifts_matrix
            wp.int32,  # mask_value
            wp.array(dtype=t),  # alpha
            wp.array(dtype=wp.float64),  # pair_energies
        ],
    )
    _ewald_real_space_energy_forces_neighbor_matrix_kernel_overload[t] = wp.overload(
        _ewald_real_space_energy_forces_neighbor_matrix_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # charges
            wp.array(dtype=m),  # cell
            wp.array2d(dtype=wp.int32),  # neighbor_matrix
            wp.array2d(dtype=wp.vec3i),  # unit_shifts_matrix
            wp.int32,  # mask_value
            wp.array(dtype=t),  # alpha
            wp.bool,  # compute_virial
            wp.array(dtype=wp.float64),  # pair_energies
            wp.array(dtype=v),  # atomic_forces
            wp.array(dtype=m),  # virial
        ],
    )

    # ==================== Real-space single-system kernels with charge gradients ====================
    _ewald_real_space_energy_forces_charge_grad_kernel_overload[t] = wp.overload(
        _ewald_real_space_energy_forces_charge_grad_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # charges
            wp.array(dtype=m),  # cell
            wp.array(dtype=wp.int32),  # idx_j
            wp.array(dtype=wp.int32),  # neighbor_ptr
            wp.array(dtype=wp.vec3i),  # unit_shifts
            wp.array(dtype=t),  # alpha
            wp.bool,  # compute_virial
            wp.array(dtype=wp.float64),  # pair_energies
            wp.array(dtype=v),  # atomic_forces
            wp.array(dtype=wp.float64),  # charge_gradients
            wp.array(dtype=m),  # virial
        ],
    )
    _ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel_overload[t] = (
        wp.overload(
            _ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel,
            [
                wp.array(dtype=v),  # positions
                wp.array(dtype=t),  # charges
                wp.array(dtype=m),  # cell
                wp.array2d(dtype=wp.int32),  # neighbor_matrix
                wp.array2d(dtype=wp.vec3i),  # unit_shifts_matrix
                wp.int32,  # mask_value
                wp.array(dtype=t),  # alpha
                wp.bool,  # compute_virial
                wp.array(dtype=wp.float64),  # pair_energies
                wp.array(dtype=v),  # atomic_forces
                wp.array(dtype=wp.float64),  # charge_gradients
                wp.array(dtype=m),  # virial
            ],
        )
    )

    # ==================== Real-space batch kernels ====================
    _batch_ewald_real_space_energy_kernel_overload[t] = wp.overload(
        _batch_ewald_real_space_energy_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # charges
            wp.array(dtype=m),  # cell
            wp.array(dtype=wp.int32),  # batch_id
            wp.array(dtype=wp.int32),  # idx_i
            wp.array(dtype=wp.int32),  # idx_j
            wp.array(dtype=wp.vec3i),  # unit_shifts
            wp.array(dtype=t),  # alpha
            wp.array(dtype=wp.float64),  # pair_energies
        ],
    )
    _batch_ewald_real_space_energy_forces_kernel_overload[t] = wp.overload(
        _batch_ewald_real_space_energy_forces_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # charges
            wp.array(dtype=m),  # cell
            wp.array(dtype=wp.int32),  # batch_id
            wp.array(dtype=wp.int32),  # idx_j
            wp.array(dtype=wp.int32),  # neighbor_ptr
            wp.array(dtype=wp.vec3i),  # unit_shifts
            wp.array(dtype=t),  # alpha
            wp.bool,  # compute_virial
            wp.array(dtype=wp.float64),  # pair_energies
            wp.array(dtype=v),  # atomic_forces
            wp.array(dtype=m),  # virial
        ],
    )
    _batch_ewald_real_space_energy_neighbor_matrix_kernel_overload[t] = wp.overload(
        _batch_ewald_real_space_energy_neighbor_matrix_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # charges
            wp.array(dtype=m),  # cell
            wp.array(dtype=wp.int32),  # batch_id
            wp.array2d(dtype=wp.int32),  # neighbor_matrix
            wp.array2d(dtype=wp.vec3i),  # unit_shifts_matrix
            wp.int32,  # mask_value
            wp.array(dtype=t),  # alpha
            wp.array(dtype=wp.float64),  # pair_energies
        ],
    )
    _batch_ewald_real_space_energy_forces_neighbor_matrix_kernel_overload[t] = (
        wp.overload(
            _batch_ewald_real_space_energy_forces_neighbor_matrix_kernel,
            [
                wp.array(dtype=v),  # positions
                wp.array(dtype=t),  # charges
                wp.array(dtype=m),  # cell
                wp.array(dtype=wp.int32),  # batch_id
                wp.array2d(dtype=wp.int32),  # neighbor_matrix
                wp.array2d(dtype=wp.vec3i),  # unit_shifts_matrix
                wp.int32,  # mask_value
                wp.array(dtype=t),  # alpha
                wp.bool,  # compute_virial
                wp.array(dtype=wp.float64),  # pair_energies
                wp.array(dtype=v),  # atomic_forces
                wp.array(dtype=m),  # virial
            ],
        )
    )

    # ==================== Real-space batch kernels with charge gradients ====================
    _batch_ewald_real_space_energy_forces_charge_grad_kernel_overload[t] = wp.overload(
        _batch_ewald_real_space_energy_forces_charge_grad_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # charges
            wp.array(dtype=m),  # cell
            wp.array(dtype=wp.int32),  # batch_id
            wp.array(dtype=wp.int32),  # idx_j
            wp.array(dtype=wp.int32),  # neighbor_ptr
            wp.array(dtype=wp.vec3i),  # unit_shifts
            wp.array(dtype=t),  # alpha
            wp.bool,  # compute_virial
            wp.array(dtype=wp.float64),  # pair_energies
            wp.array(dtype=v),  # atomic_forces
            wp.array(dtype=wp.float64),  # charge_gradients
            wp.array(dtype=m),  # virial
        ],
    )
    _batch_ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel_overload[
        t
    ] = wp.overload(
        _batch_ewald_real_space_energy_forces_charge_grad_neighbor_matrix_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # charges
            wp.array(dtype=m),  # cell
            wp.array(dtype=wp.int32),  # batch_id
            wp.array2d(dtype=wp.int32),  # neighbor_matrix
            wp.array2d(dtype=wp.vec3i),  # unit_shifts_matrix
            wp.int32,  # mask_value
            wp.array(dtype=t),  # alpha
            wp.bool,  # compute_virial
            wp.array(dtype=wp.float64),  # pair_energies
            wp.array(dtype=v),  # atomic_forces
            wp.array(dtype=wp.float64),  # charge_gradients
            wp.array(dtype=m),  # virial
        ],
    )

    # ==================== Reciprocal-space single-system kernels ====================
    _ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload[t] = (
        wp.overload(
            _ewald_reciprocal_space_energy_kernel_fill_structure_factors,
            [
                wp.array(dtype=v),  # positions
                wp.array(dtype=t),  # charges
                wp.array(dtype=v),  # k_vectors
                wp.array(dtype=m),  # cell
                wp.array(dtype=t),  # alpha
                wp.array(dtype=wp.float64),  # total_charge
                wp.array2d(dtype=wp.float64),  # cos_k_dot_r
                wp.array2d(dtype=wp.float64),  # sin_k_dot_r
                wp.array(dtype=wp.float64),  # real_structure_factors
                wp.array(dtype=wp.float64),  # imag_structure_factors
            ],
        )
    )
    _ewald_reciprocal_space_energy_kernel_compute_energy_overload[t] = wp.overload(
        _ewald_reciprocal_space_energy_kernel_compute_energy,
        [
            wp.array(dtype=t),  # charges
            wp.array2d(dtype=wp.float64),  # cos_k_dot_r
            wp.array2d(dtype=wp.float64),  # sin_k_dot_r
            wp.array(dtype=wp.float64),  # real_structure_factors
            wp.array(dtype=wp.float64),  # imag_structure_factors
            wp.array(dtype=wp.float64),  # reciprocal_energies
        ],
    )
    _ewald_reciprocal_space_energy_forces_kernel_overload[t] = wp.overload(
        _ewald_reciprocal_space_energy_forces_kernel,
        [
            wp.array(dtype=t),  # charges
            wp.array(dtype=v),  # k_vectors
            wp.array2d(dtype=wp.float64),  # cos_k_dot_r
            wp.array2d(dtype=wp.float64),  # sin_k_dot_r
            wp.array(dtype=wp.float64),  # real_structure_factors
            wp.array(dtype=wp.float64),  # imag_structure_factors
            wp.array(dtype=wp.float64),  # reciprocal_energies
            wp.array(dtype=v),  # atomic_forces
        ],
    )
    _ewald_reciprocal_space_energy_forces_charge_grad_kernel_overload[t] = wp.overload(
        _ewald_reciprocal_space_energy_forces_charge_grad_kernel,
        [
            wp.array(dtype=t),  # charges
            wp.array(dtype=v),  # k_vectors
            wp.array2d(dtype=wp.float64),  # cos_k_dot_r
            wp.array2d(dtype=wp.float64),  # sin_k_dot_r
            wp.array(dtype=wp.float64),  # real_structure_factors
            wp.array(dtype=wp.float64),  # imag_structure_factors
            wp.array(dtype=wp.float64),  # reciprocal_energies
            wp.array(dtype=v),  # atomic_forces
            wp.array(dtype=wp.float64),  # charge_gradients
        ],
    )
    _ewald_subtract_self_energy_kernel_overload[t] = wp.overload(
        _ewald_subtract_self_energy_kernel,
        [
            wp.array(dtype=t),  # charges
            wp.array(dtype=t),  # alpha
            wp.array(dtype=wp.float64),  # total_charge
            wp.array(dtype=wp.float64),  # energy_in
            wp.array(dtype=wp.float64),  # energy_out
        ],
    )
    _ewald_reciprocal_space_virial_kernel_overload[t] = wp.overload(
        _ewald_reciprocal_space_virial_kernel,
        [
            wp.array(dtype=v),  # k_vectors
            wp.array(dtype=t),  # alpha
            wp.array(dtype=wp.float64),  # volume
            wp.array(dtype=wp.float64),  # real_structure_factors
            wp.array(dtype=wp.float64),  # imag_structure_factors
            wp.array(dtype=m),  # virial
        ],
    )

    # ==================== Reciprocal-space batch kernels ====================
    _batch_ewald_reciprocal_space_energy_kernel_fill_structure_factors_overload[t] = (
        wp.overload(
            _batch_ewald_reciprocal_space_energy_kernel_fill_structure_factors,
            [
                wp.array(dtype=v),  # positions
                wp.array(dtype=t),  # charges
                wp.array2d(dtype=v),  # k_vectors (B, K)
                wp.array(dtype=m),  # cell
                wp.array(dtype=t),  # alpha
                wp.array(dtype=wp.int32),  # atom_start
                wp.array(dtype=wp.int32),  # atom_end
                wp.array(dtype=wp.float64),  # total_charges
                wp.array2d(dtype=wp.float64),  # cos_k_dot_r
                wp.array2d(dtype=wp.float64),  # sin_k_dot_r
                wp.array2d(dtype=wp.float64),  # real_structure_factors
                wp.array2d(dtype=wp.float64),  # imag_structure_factors
            ],
        )
    )
    _batch_ewald_reciprocal_space_energy_kernel_compute_energy_overload[t] = (
        wp.overload(
            _batch_ewald_reciprocal_space_energy_kernel_compute_energy,
            [
                wp.array(dtype=t),  # charges
                wp.array(dtype=wp.int32),  # batch_id
                wp.array2d(dtype=wp.float64),  # cos_k_dot_r
                wp.array2d(dtype=wp.float64),  # sin_k_dot_r
                wp.array2d(dtype=wp.float64),  # real_structure_factors
                wp.array2d(dtype=wp.float64),  # imag_structure_factors
                wp.array(dtype=wp.float64),  # reciprocal_energies
            ],
        )
    )
    _batch_ewald_reciprocal_space_energy_forces_kernel_overload[t] = wp.overload(
        _batch_ewald_reciprocal_space_energy_forces_kernel,
        [
            wp.array(dtype=t),  # charges
            wp.array(dtype=wp.int32),  # batch_id
            wp.array2d(dtype=v),  # k_vectors (B, K)
            wp.array2d(dtype=wp.float64),  # cos_k_dot_r
            wp.array2d(dtype=wp.float64),  # sin_k_dot_r
            wp.array2d(dtype=wp.float64),  # real_structure_factors
            wp.array2d(dtype=wp.float64),  # imag_structure_factors
            wp.array(dtype=wp.float64),  # reciprocal_energies
            wp.array(dtype=v),  # atomic_forces
        ],
    )
    _batch_ewald_reciprocal_space_energy_forces_charge_grad_kernel_overload[t] = (
        wp.overload(
            _batch_ewald_reciprocal_space_energy_forces_charge_grad_kernel,
            [
                wp.array(dtype=t),  # charges
                wp.array(dtype=wp.int32),  # batch_id
                wp.array2d(dtype=v),  # k_vectors (B, K)
                wp.array2d(dtype=wp.float64),  # cos_k_dot_r
                wp.array2d(dtype=wp.float64),  # sin_k_dot_r
                wp.array2d(dtype=wp.float64),  # real_structure_factors
                wp.array2d(dtype=wp.float64),  # imag_structure_factors
                wp.array(dtype=wp.float64),  # reciprocal_energies
                wp.array(dtype=v),  # atomic_forces
                wp.array(dtype=wp.float64),  # charge_gradients
            ],
        )
    )
    _batch_ewald_subtract_self_energy_kernel_overload[t] = wp.overload(
        _batch_ewald_subtract_self_energy_kernel,
        [
            wp.array(dtype=t),  # charges
            wp.array(dtype=wp.int32),  # batch_idx
            wp.array(dtype=t),  # alpha
            wp.array(dtype=wp.float64),  # total_charges
            wp.array(dtype=wp.float64),  # energy_in
            wp.array(dtype=wp.float64),  # energy_out
        ],
    )
    _batch_ewald_reciprocal_space_virial_kernel_overload[t] = wp.overload(
        _batch_ewald_reciprocal_space_virial_kernel,
        [
            wp.array2d(dtype=v),  # k_vectors (B, K)
            wp.array(dtype=t),  # alpha
            wp.array(dtype=wp.float64),  # volume
            wp.array2d(dtype=wp.float64),  # real_structure_factors
            wp.array2d(dtype=wp.float64),  # imag_structure_factors
            wp.array(dtype=m),  # virial
        ],
    )
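The registration loop above follows a simple pattern: zip the scalar, vector, and matrix types together, build one signature per precision, and store it in a dictionary keyed by the scalar type. The Warp-free sketch below mirrors that pattern with plain strings so it can be run anywhere. Every name in it (`SCALARS`, `VECTORS`, `MATRICES`, `energy_signatures`) is a hypothetical stand-in, not part of the module.

```python
# Stand-ins for _T, _V and _M; strings replace the Warp types.
SCALARS = ["float32", "float64"]
VECTORS = ["vec3f", "vec3d"]
MATRICES = ["mat33f", "mat33d"]

# One entry per precision, keyed by the scalar type, as in the overload dicts.
energy_signatures = {}
for t, v, m in zip(SCALARS, VECTORS, MATRICES):
    energy_signatures[t] = {
        "positions": v,  # follows the input precision
        "charges": t,  # follows the input precision
        "cell": m,  # follows the input precision
        "pair_energies": "float64",  # accumulator stays float64 in both entries
    }

# A launcher-style lookup: the caller passes the scalar type as the key.
selected = energy_signatures["float32"]
```

Keying on the scalar type alone is enough because the vector and matrix types are paired with it by position in the three lists.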