# Source code for nvalchemiops.spline

# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""
B-Spline Interpolation Module
=============================

This module provides B-spline interpolation functions for mesh-based calculations,
commonly used in Particle Mesh Ewald (PME) and similar methods.

SUPPORTED ORDERS
================

- Order 1: Constant (Nearest Grid Point)
- Order 2: Linear
- Order 3: Quadratic
- Order 4: Cubic (recommended for PME)

OPERATIONS
==========

1. SPREAD: Scatter atom values to mesh grid
   mesh[g] += value[atom] * weight(atom, g)

2. GATHER: Collect mesh values at atom positions
   value[atom] = Σ_g mesh[g] * weight(atom, g)

3. GATHER_VEC3: Collect charge-weighted 3D vector field values at atom positions
   vector[atom] = q[atom] * Σ_g mesh[g] * weight(atom, g)

4. GATHER_GRADIENT: Collect mesh values with weight gradients (forces)
   grad[atom] = Σ_g mesh[g] * grad_weight(atom, g)

5. SPREAD_CHANNELS: Scatter multi-channel values (e.g., multipoles) to mesh
   mesh[c, g] += values[atom, c] * weight(atom, g)

6. GATHER_CHANNELS: Collect multi-channel values from mesh
   values[atom, c] = Σ_g mesh[c, g] * weight(atom, g)

7. DECONVOLUTION: Correct B-spline approximation in Fourier space
   Used in FFT-based methods to remove B-spline smoothing artifacts.

USAGE
=====

Single-system:
    from nvalchemiops.spline import spline_spread, spline_gather, spline_gather_gradient

    # Spread charges to mesh
    mesh = spline_spread(positions, charges, cell, mesh_dims, spline_order=4)

    # Gather potential from mesh
    potentials = spline_gather(positions, potential_mesh, cell, spline_order=4)

    # Gather forces
    forces = spline_gather_gradient(positions, charges, potential_mesh, cell, spline_order=4)

Multi-channel (multipoles):
    from nvalchemiops.spline import spline_spread_channels, spline_gather_channels

    # multipoles has shape (N, num_channels) e.g. (N, 9) for L_max=2
    mesh = spline_spread_channels(positions, multipoles, cell, mesh_dims, spline_order=4)

    # Gather multi-channel potential from mesh
    potentials = spline_gather_channels(positions, potential_mesh, cell, spline_order=4)

Batched (multiple systems):
    # Spread charges to batched mesh
    mesh = spline_spread(positions, charges, cell, mesh_dims, spline_order=4, batch_idx=batch_idx)

    # Gather potential from batched mesh
    potentials = spline_gather(positions, potential_mesh, cell, spline_order=4, batch_idx=batch_idx)

Deconvolution:
    from nvalchemiops.spline import compute_bspline_deconvolution

    # Get deconvolution factors for mesh
    deconv = compute_bspline_deconvolution(mesh_dims, spline_order=4, device=device)

    # Apply in Fourier space: mesh_corrected_k = mesh_k * deconv
    mesh_fft = torch.fft.fftn(mesh)
    mesh_corrected_fft = mesh_fft * deconv
    mesh_corrected = torch.fft.ifftn(mesh_corrected_fft).real

REFERENCES
==========

- Essmann et al. (1995). J. Chem. Phys. 103, 8577 (PME B-splines)
"""

from __future__ import annotations

import math
from typing import Any

import torch
import warp as wp

from nvalchemiops.autograd import (
    OutputSpec,
    WarpAutogradContextManager,
    attach_for_backward,
    needs_grad,
    warp_custom_op,
    warp_from_torch,
)
from nvalchemiops.types import get_wp_dtype, get_wp_mat_dtype, get_wp_vec_dtype

# Circle constants: pi and 2*pi
PI = math.pi
TWOPI = 2.0 * math.pi


###########################################################################################
########################### B-Spline Weight Functions #####################################
###########################################################################################


@wp.func
def bspline_weight(u: Any, order: wp.int32) -> Any:
    """Evaluate the B-spline basis function M_n at ``u``.

    Each supported order is a piecewise polynomial defined on unit-width
    intervals covering [0, order); the value is zero everywhere else.

    Parameters
    ----------
    u : float (Any)
        Evaluation point, expected in [0, order). Type-generic
        (float32 or float64).
    order : wp.int32
        Spline order n (1=constant, 2=linear, 3=quadratic, 4=cubic).

    Returns
    -------
    float (Any)
        M_n(u) in the same scalar type as ``u``. Zero when ``u`` lies outside
        [0, order) or when ``order`` is not one of 1-4.
    """
    # Literals are cast to the scalar type of u so float32/float64 both work.
    zero = type(u)(0.0)
    one = type(u)(1.0)
    two = type(u)(2.0)
    three = type(u)(3.0)
    four = type(u)(4.0)
    six = type(u)(6.0)

    # Within each order the intervals are tested in ascending order, so every
    # "u < bound" test implicitly carries the lower bound of its interval.
    if order == 1:
        if u < zero:
            return zero
        if u < one:
            return one
        return zero

    if order == 2:
        if u < zero:
            return zero
        if u < one:
            return u
        if u < two:
            return two - u
        return zero

    if order == 3:
        if u < zero:
            return zero
        if u < one:
            return u * u / two
        if u < two:
            # Central parabola, symmetric about u = 1.5.
            d = u - type(u)(1.5)
            return type(u)(0.75) - d * d
        if u < three:
            v = three - u
            return v * v / two
        return zero

    if order == 4:
        if u < zero:
            return zero
        if u < one:
            return u * u * u / six
        if u < two:
            sq = u * u
            cu = sq * u
            return (
                type(u)(-3.0) * cu + type(u)(12.0) * sq - type(u)(12.0) * u + four
            ) / six
        if u < three:
            sq = u * u
            cu = sq * u
            return (
                three * cu - type(u)(24.0) * sq + type(u)(60.0) * u - type(u)(44.0)
            ) / six
        if u < four:
            # Mirror image of the first piece.
            v = four - u
            return v * v * v / six
        return zero

    # Unsupported order.
    return zero


@wp.func
def bspline_derivative(u: Any, order: wp.int32) -> Any:
    """Evaluate the derivative dM_n(u)/du of the B-spline basis function.

    The pieces are the term-by-term derivatives of the polynomials used by
    ``bspline_weight`` on the same unit-width intervals.

    Parameters
    ----------
    u : float (Any)
        Evaluation point, expected in [0, order). Type-generic
        (float32 or float64).
    order : wp.int32
        Spline order.

    Returns
    -------
    float (Any)
        Derivative value in the same scalar type as ``u``. Zero outside
        [0, order), and zero for every ``u`` when ``order`` is not 2, 3 or 4
        (this includes order 1).
    """
    # Literals are cast to the scalar type of u so float32/float64 both work.
    zero = type(u)(0.0)
    one = type(u)(1.0)
    two = type(u)(2.0)
    three = type(u)(3.0)
    four = type(u)(4.0)
    six = type(u)(6.0)

    # Intervals are tested in ascending order; each "u < bound" test therefore
    # implies the lower bound of its interval.
    if order == 2:
        if u < zero:
            return zero
        if u < one:
            return one
        if u < two:
            return -one
        return zero

    if order == 3:
        if u < zero:
            return zero
        if u < one:
            return u
        if u < two:
            return -two * (u - type(u)(1.5))
        if u < three:
            return -(three - u)
        return zero

    if order == 4:
        if u < zero:
            return zero
        if u < one:
            return u * u / two
        if u < two:
            return (type(u)(-9.0) * u * u + type(u)(24.0) * u - type(u)(12.0)) / six
        if u < three:
            return (type(u)(9.0) * u * u - type(u)(48.0) * u + type(u)(60.0)) / six
        if u < four:
            v = four - u
            return -three * v * v / six
        return zero

    # Order 1 and any unsupported order.
    return zero


###########################################################################################
########################### Grid Utility Functions ########################################
###########################################################################################


@wp.func
def compute_fractional_coords(
    position: Any,
    cell_inv_t: Any,
    mesh_dims: wp.vec3i,
) -> Any:
    """Map a Cartesian position to mesh coordinates and split off the cell index.

    The position is first converted to fractional coordinates with
    ``cell_inv_t * position``, then scaled by the mesh dimensions. The result
    is split per axis into an integer part (floor) and the remainder.

    Parameters
    ----------
    position : vec3 (Any)
        Atomic position. Type-generic (vec3f or vec3d).
    cell_inv_t : mat33 (Any)
        Transpose of inverse cell. Type-generic (mat33f or mat33d).
    mesh_dims : wp.vec3i
        Mesh dimensions.

    Returns
    -------
    base_grid : wp.vec3i
        Base grid point (floor of mesh coords). Not wrapped into the mesh
        range here.
    theta : vec3 (Any)
        Fractional part [0, 1) in each dimension. Same type as position.

    Note: Returns (base_grid, theta) as a tuple via multiple return values.
    """
    # One component is read only to recover the scalar float type of position.
    scalar = position[0]

    # Cartesian -> fractional (cell) coordinates.
    frac = cell_inv_t * position

    # Fractional -> continuous mesh coordinates along each axis.
    grid_x = frac[0] * type(scalar)(mesh_dims[0])
    grid_y = frac[1] * type(scalar)(mesh_dims[1])
    grid_z = frac[2] * type(scalar)(mesh_dims[2])

    # Integer index of the mesh cell containing the point.
    ix = wp.int32(wp.floor(grid_x))
    iy = wp.int32(wp.floor(grid_y))
    iz = wp.int32(wp.floor(grid_z))

    # Remainder inside that cell.
    frac_x = grid_x - type(scalar)(ix)
    frac_y = grid_y - type(scalar)(iy)
    frac_z = grid_z - type(scalar)(iz)

    base_grid = wp.vec3i(ix, iy, iz)
    theta = type(position)(frac_x, frac_y, frac_z)
    return base_grid, theta


@wp.func
def bspline_grid_offset(
    point_idx: wp.int32,
    order: wp.int32,
    theta: Any,
) -> wp.vec3i:
    """Return the grid offset of one stencil point relative to the base grid point.

    The order^3 stencil points are numbered 0 to order^3-1 and laid out as a
    cube; ``point_idx`` is decoded into per-axis indices (i, j, k) with the
    x index varying slowest. Each axis index is then shifted by

        offset_start = floor(theta - (n-2)/2)

    so that the B-spline parameter u = n/2 + theta - offset stays in [0, n)
    for any theta in [0, 1).

    Parameters
    ----------
    point_idx : wp.int32
        Linear point index (0 to order^3-1).
    order : wp.int32
        Spline order n.
    theta : vec3 (Any)
        Fractional position within the base grid cell [0, 1) in each dimension.
        Type-generic (vec3f or vec3d).

    Returns
    -------
    wp.vec3i
        Grid offset (relative to base grid point).
    """
    # Decode the linear stencil index into (i, j, k).
    n_sq = order * order
    in_plane = point_idx % n_sq
    i = point_idx // n_sq
    j = in_plane // order
    k = point_idx % order

    # Scalar sample of theta, also used to recover its float type.
    scalar = theta[0]

    # (n - 2) / 2: how far the stencil start sits below theta.
    shift = type(scalar)(order - 2) * type(scalar)(0.5)

    start_x = wp.int32(wp.floor(scalar - shift))
    start_y = wp.int32(wp.floor(theta[1] - shift))
    start_z = wp.int32(wp.floor(theta[2] - shift))

    return wp.vec3i(i + start_x, j + start_y, k + start_z)


@wp.func
def bspline_weight_3d(
    theta: Any,
    offset: wp.vec3i,
    order: wp.int32,
) -> Any:
    """Return the separable 3D B-spline weight for one stencil point.

    Along each axis the B-spline parameter is

    .. math::

        u = \\text{order}/2 + \\theta - \\text{offset}

    With offsets produced by ``bspline_grid_offset`` the u values fall in
    [0, n), the weights sum to 1, and the stencil is centered at the atom
    position.

    Parameters
    ----------
    theta : vec3 (Any)
        Fractional position within the base grid cell [0, 1).
        Type-generic (vec3f or vec3d).
    offset : wp.vec3i
        Grid offset from base grid point (includes offset_start adjustment).
    order : wp.int32
        Spline order.

    Returns
    -------
    float (Any)
        Weight = M(u_x) * M(u_y) * M(u_z), or zero if any u lies outside
        [0, order). Same scalar type as theta.
    """
    # Scalar sample of theta, used for its value and its float type.
    scalar = theta[0]
    zero = type(scalar)(0.0)
    upper = type(scalar)(order)
    center = type(scalar)(order) * type(scalar)(0.5)

    # Per-axis spline parameter: u = n/2 + theta - offset.
    ux = center + scalar - type(scalar)(offset[0])
    uy = center + theta[1] - type(scalar)(offset[1])
    uz = center + theta[2] - type(scalar)(offset[2])

    # Outside the spline support on any axis the product is zero.
    if ux < zero or ux >= upper:
        return zero
    if uy < zero or uy >= upper:
        return zero
    if uz < zero or uz >= upper:
        return zero

    wx = bspline_weight(ux, order)
    wy = bspline_weight(uy, order)
    wz = bspline_weight(uz, order)
    return wx * wy * wz


@wp.func
def bspline_weight_gradient_3d(
    theta: Any,
    offset: wp.vec3i,
    order: wp.int32,
    mesh_dims: wp.vec3i,
) -> Any:
    """Return the gradient of the separable 3D B-spline weight.

    Along each axis the B-spline parameter is

    .. math::

        u = \\text{order}/2 + \\theta - \\text{offset}

    so that

    .. math::

        \\begin{aligned}
        \\frac{\\partial u}{\\partial \\theta} &= +1 \\\\
        \\frac{\\partial \\text{weight}}{\\partial \\theta} &= \\frac{\\partial M}{\\partial u} \\cdot \\frac{\\partial u}{\\partial \\theta} = \\frac{\\partial M}{\\partial u}
        \\end{aligned}

    Each component differentiates the weight along one axis and keeps the
    plain weights on the other two.

    Parameters
    ----------
    theta : vec3 (Any)
        Fractional position within the base grid cell [0, 1).
        Type-generic (vec3f or vec3d).
    offset : wp.vec3i
        Grid offset from base grid point (includes offset_start adjustment).
    order : wp.int32
        Spline order.
    mesh_dims : wp.vec3i
        Mesh dimensions (for scaling to Cartesian coordinates).

    Returns
    -------
    vec3 (Any)
        Gradient :math:`\\nabla` weight in fractional coordinates (scaled by mesh_dims).
        Zero vector if any u lies outside [0, order). Same type as theta.
    """
    # Scalar sample of theta, used for its value and its float type.
    scalar = theta[0]
    zero = type(scalar)(0.0)
    upper = type(scalar)(order)
    center = type(scalar)(order) * type(scalar)(0.5)

    # Per-axis spline parameter: u = n/2 + theta - offset.
    ux = center + scalar - type(scalar)(offset[0])
    uy = center + theta[1] - type(scalar)(offset[1])
    uz = center + theta[2] - type(scalar)(offset[2])

    # Outside the spline support on any axis the gradient vanishes.
    if ux < zero or ux >= upper:
        return type(theta)(zero, zero, zero)
    if uy < zero or uy >= upper:
        return type(theta)(zero, zero, zero)
    if uz < zero or uz >= upper:
        return type(theta)(zero, zero, zero)

    wx = bspline_weight(ux, order)
    wy = bspline_weight(uy, order)
    wz = bspline_weight(uz, order)

    # du/dtheta = +1, so the sign of dM/du carries through unchanged; the
    # mesh_dims factor converts a per-cell derivative to a fractional one.
    gx = bspline_derivative(ux, order) * type(scalar)(mesh_dims[0])
    gy = bspline_derivative(uy, order) * type(scalar)(mesh_dims[1])
    gz = bspline_derivative(uz, order) * type(scalar)(mesh_dims[2])

    return type(theta)(gx * wy * wz, wx * gy * wz, wx * wy * gz)


@wp.func
def wrap_grid_index(idx: wp.int32, dim: wp.int32) -> wp.int32:
    """Wrap a grid index into the periodic range for a mesh axis of size ``dim``."""
    # The first remainder may be negative for a negative idx; adding dim and
    # reducing again gives a non-negative result.
    rem = idx % dim
    return (rem + dim) % dim


###########################################################################################
########################### Single-System Warp Kernels ####################################
###########################################################################################


@wp.kernel
def _bspline_spread_kernel(
    positions: wp.array(dtype=Any),
    values: wp.array(dtype=Any),
    cell_inv_t: wp.array(dtype=Any),
    order: wp.int32,
    mesh: wp.array3d(dtype=Any),
):
    """Spread (scatter) values from atoms to a 3D mesh using B-spline interpolation.

    For each atom, distributes its value to nearby grid points weighted by the
    B-spline basis function. This is the adjoint operation to gathering.

    Formula: mesh[g] += value[atom] * w(atom, g)

    where w(atom, g) is the product of 1D B-spline weights in each dimension.

    Launch Grid
    -----------
    dim = [num_atoms, order^3]

    Each thread handles one (atom, grid_point) pair within the atom's stencil.

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates in Cartesian space.
    values : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Values to spread (e.g., charges).
    cell_inv_t : wp.array, shape (1, 3, 3), dtype=wp.mat33f or wp.mat33d
        Transpose of inverse cell matrix for fractional coordinate conversion.
    order : wp.int32
        B-spline order (1-4). Order 4 (cubic) recommended for PME.
    mesh : wp.array3d, shape (nx, ny, nz), dtype=wp.float32 or wp.float64
        OUTPUT: 3D mesh to accumulate values into. Must be zero-initialized.

    Notes
    -----
    - Uses atomic adds for thread-safe accumulation to shared grid points.
    - Grid indices are wrapped using periodic boundary conditions.
    - Threads with 1e-8 weight or less skip the atomic add for efficiency.
      The cutoff is the same one used by ``_bspline_gather_kernel`` and
      ``_batch_bspline_spread_kernel``, so spread and gather act on the same
      set of stencil points and the single-system and batched paths agree.
    """
    atom_idx, point_idx = wp.tid()

    mesh_dims = wp.vec3i(mesh.shape[0], mesh.shape[1], mesh.shape[2])
    position = positions[atom_idx]
    value = values[atom_idx]

    # Single-system kernel: the only cell matrix lives at index 0.
    base_grid, theta = compute_fractional_coords(position, cell_inv_t[0], mesh_dims)
    offset = bspline_grid_offset(point_idx, order, theta)
    weight = bspline_weight_3d(theta, offset, order)

    # Skip negligible contributions; threshold kept identical to the gather
    # and batched spread kernels.
    if weight > type(value)(1e-8):
        gx = wrap_grid_index(base_grid[0] + offset[0], mesh_dims[0])
        gy = wrap_grid_index(base_grid[1] + offset[1], mesh_dims[1])
        gz = wrap_grid_index(base_grid[2] + offset[2], mesh_dims[2])

        # Several atoms may touch the same grid point, hence the atomic add.
        wp.atomic_add(mesh, gx, gy, gz, value * weight)


@wp.kernel
def _bspline_gather_kernel(
    positions: wp.array(dtype=Any),
    cell_inv_t: wp.array(dtype=Any),
    order: wp.int32,
    mesh: wp.array3d(dtype=Any),
    output: wp.array(dtype=Any),
):
    """Gather (interpolate) values from a 3D mesh to atom positions using B-splines.

    Each atom's value is the B-spline-weighted sum of the mesh values at the
    grid points of its stencil.

    Formula: output[atom] = Σ_g mesh[g] * w(atom, g)

    where the sum is over the order^3 grid points in the atom's stencil.

    Launch Grid
    -----------
    dim = [num_atoms, order^3]

    Each thread handles one (atom, grid_point) pair within the atom's stencil.

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates in Cartesian space.
    cell_inv_t : wp.array, shape (1, 3, 3), dtype=wp.mat33f or wp.mat33d
        Transpose of inverse cell matrix for fractional coordinate conversion.
    order : wp.int32
        B-spline order (1-4). Order 4 (cubic) recommended for PME.
    mesh : wp.array3d, shape (nx, ny, nz), dtype=wp.float32 or wp.float64
        3D mesh containing values to interpolate (e.g., electrostatic potential).
    output : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        OUTPUT: Interpolated values per atom. Must be zero-initialized.

    Notes
    -----
    - Uses atomic adds since multiple threads contribute to each atom's output.
    - Grid indices are wrapped using periodic boundary conditions.
    - Threads with 1e-8 weight or less skip the atomic add for efficiency.
    """
    atom_idx, point_idx = wp.tid()

    position = positions[atom_idx]
    mesh_dims = wp.vec3i(mesh.shape[0], mesh.shape[1], mesh.shape[2])

    # Single-system kernel: the only cell matrix lives at index 0.
    base_grid, theta = compute_fractional_coords(position, cell_inv_t[0], mesh_dims)
    offset = bspline_grid_offset(point_idx, order, theta)
    weight = bspline_weight_3d(theta, offset, order)

    # One mesh element is read only to obtain the mesh scalar type for the cutoff.
    type_ref = mesh[0, 0, 0]
    cutoff = type(type_ref)(1e-8)

    if weight > cutoff:
        # Unwrapped stencil point, then periodic wrap per axis.
        ix = base_grid[0] + offset[0]
        iy = base_grid[1] + offset[1]
        iz = base_grid[2] + offset[2]
        gx = wrap_grid_index(ix, mesh_dims[0])
        gy = wrap_grid_index(iy, mesh_dims[1])
        gz = wrap_grid_index(iz, mesh_dims[2])

        sample = mesh[gx, gy, gz]
        wp.atomic_add(output, atom_idx, sample * weight)


@wp.kernel
def _bspline_gather_vec3_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    cell_inv_t: wp.array(dtype=Any),
    order: wp.int32,
    mesh: wp.array3d(dtype=Any),
    output: wp.array(dtype=Any),
):
    """Gather charge-weighted 3D vector values from mesh to atoms using B-splines.

    Works like _bspline_gather_kernel, except that every contribution is
    multiplied by the atom's charge and the mesh/output elements are 3D
    vectors (for use with vector-valued mesh fields).

    Formula: output[atom] = q[atom] * Σ_g mesh[g] * w(atom, g)

    Launch Grid
    -----------
    dim = [num_atoms, order^3]

    Each thread handles one (atom, grid_point) pair within the atom's stencil.

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates in Cartesian space.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges (or other scalar weights).
    cell_inv_t : wp.array, shape (1, 3, 3), dtype=wp.mat33f or wp.mat33d
        Transpose of inverse cell matrix for fractional coordinate conversion.
    order : wp.int32
        B-spline order (1-4). Order 4 (cubic) recommended for PME.
    mesh : wp.array3d, shape (nx, ny, nz), dtype=wp.vec3f or wp.vec3d
        3D mesh containing vector values to interpolate.
    output : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Charge-weighted interpolated vectors per atom. Must be zero-initialized.

    Notes
    -----
    - Uses atomic adds since multiple threads contribute to each atom's output.
    - Grid indices are wrapped using periodic boundary conditions.
    - Threads with 1e-8 weight or less skip the atomic add for efficiency.
    """
    atom_idx, point_idx = wp.tid()

    charge = charges[atom_idx]
    position = positions[atom_idx]
    mesh_dims = wp.vec3i(mesh.shape[0], mesh.shape[1], mesh.shape[2])

    # Single-system kernel: the only cell matrix lives at index 0.
    base_grid, theta = compute_fractional_coords(position, cell_inv_t[0], mesh_dims)
    offset = bspline_grid_offset(point_idx, order, theta)
    weight = bspline_weight_3d(theta, offset, order)

    # Cutoff expressed in the scalar type of the charges.
    cutoff = type(charge)(1e-8)

    if weight > cutoff:
        # Unwrapped stencil point, then periodic wrap per axis.
        ix = base_grid[0] + offset[0]
        iy = base_grid[1] + offset[1]
        iz = base_grid[2] + offset[2]
        gx = wrap_grid_index(ix, mesh_dims[0])
        gy = wrap_grid_index(iy, mesh_dims[1])
        gz = wrap_grid_index(iz, mesh_dims[2])

        field = mesh[gx, gy, gz]
        wp.atomic_add(output, atom_idx, charge * field * weight)


@wp.kernel
def _bspline_gather_gradient_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    cell_inv_t: wp.array(dtype=Any),
    order: wp.int32,
    mesh: wp.array3d(dtype=Any),
    forces: wp.array(dtype=Any),
):
    """Compute forces by gathering mesh gradients using B-spline derivatives.

    Computes:

    .. math::

        F_i = -q_i \\sum_g \\phi(g) \\nabla w(r_i, g)

    The gradient ∇w is evaluated in fractional coordinates and then mapped
    to Cartesian coordinates with the transpose of ``cell_inv_t``.

    Launch Grid
    -----------
    dim = [num_atoms, order^3]

    Each thread handles one (atom, grid_point) pair within the atom's stencil.

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates in Cartesian space.
    charges : wp.array, shape (N,), dtype=wp.float32 or wp.float64
        Atomic charges.
    cell_inv_t : wp.array, shape (1, 3, 3), dtype=wp.mat33f or wp.mat33d
        Transpose of inverse cell matrix for fractional coordinate conversion.
    order : wp.int32
        B-spline order (1-4). Order 4 (cubic) recommended for PME.
    mesh : wp.array3d, shape (nx, ny, nz), dtype=wp.float32 or wp.float64
        3D mesh containing potential values (e.g., electrostatic potential φ).
    forces : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Forces per atom in Cartesian coordinates. Must be zero-initialized.

    Notes
    -----
    - Uses atomic adds since multiple threads contribute to each atom's force.
    - The gradient is computed in fractional coordinates, then transformed:
      F_cart = cell_inv_t^T * F_frac
    - Threads with zero gradient magnitude skip the atomic add for efficiency.
    - Grid indices are wrapped using periodic boundary conditions.
    """
    atom_idx, point_idx = wp.tid()

    charge = charges[atom_idx]
    position = positions[atom_idx]
    mesh_dims = wp.vec3i(mesh.shape[0], mesh.shape[1], mesh.shape[2])

    # Single-system kernel: the only cell matrix lives at index 0.
    base_grid, theta = compute_fractional_coords(position, cell_inv_t[0], mesh_dims)
    offset = bspline_grid_offset(point_idx, order, theta)
    grad_frac = bspline_weight_gradient_3d(theta, offset, order, mesh_dims)

    # L1 norm of the gradient; exactly zero means this point contributes nothing.
    grad_l1 = wp.abs(grad_frac[0]) + wp.abs(grad_frac[1]) + wp.abs(grad_frac[2])

    if grad_l1 > type(charge)(0.0):
        # Unwrapped stencil point, then periodic wrap per axis.
        ix = base_grid[0] + offset[0]
        iy = base_grid[1] + offset[1]
        iz = base_grid[2] + offset[2]
        gx = wrap_grid_index(ix, mesh_dims[0])
        gy = wrap_grid_index(iy, mesh_dims[1])
        gz = wrap_grid_index(iz, mesh_dims[2])

        phi = mesh[gx, gy, gz]

        # Common prefactor -q * phi(g) shared by the three components.
        scale = -charge * phi
        force_frac = type(position)(
            scale * grad_frac[0],
            scale * grad_frac[1],
            scale * grad_frac[2],
        )

        # Fractional -> Cartesian.
        force_cart = wp.transpose(cell_inv_t[0]) * force_frac

        wp.atomic_add(forces, atom_idx, force_cart)


###########################################################################################
########################### Batch Warp Kernels #############################################
###########################################################################################


@wp.kernel
def _batch_bspline_spread_kernel(
    positions: wp.array(dtype=Any),
    values: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    cell_inv_t: wp.array(dtype=Any),  # (B, 3, 3)
    order: wp.int32,
    mesh: wp.array(dtype=Any, ndim=4),  # (B, nx, ny, nz)
):
    """Spread values from atoms to a batched 4D mesh using B-splines.

    Multi-system counterpart of _bspline_spread_kernel. ``batch_idx`` maps
    each atom to a system; the atom's value is scattered into that system's
    slice of the mesh using that system's cell matrix.

    Formula: mesh[sys, g] += value[atom] * w(atom, g)

    Launch Grid
    -----------
    dim = [num_atoms_total, order^3]

    Each thread handles one (atom, grid_point) pair within the atom's stencil.

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates for all systems concatenated.
    values : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Values to spread (e.g., charges) for all systems.
    batch_idx : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1).
    cell_inv_t : wp.array, shape (B, 3, 3), dtype=wp.mat33f or wp.mat33d
        Per-system transpose of inverse cell matrix.
    order : wp.int32
        B-spline order (1-4). Order 4 (cubic) recommended for PME.
    mesh : wp.array4d, shape (B, nx, ny, nz), dtype=wp.float32 or wp.float64
        OUTPUT: 4D mesh (batch × spatial) to accumulate values. Must be zero-initialized.

    Notes
    -----
    - Uses atomic adds for thread-safe accumulation to shared grid points.
    - Each system uses its own cell matrix for fractional coordinate conversion.
    - Grid indices are wrapped using periodic boundary conditions.
    - Threads with 1e-8 weight or less skip the atomic add for efficiency.
    """
    atom_idx, point_idx = wp.tid()

    # System this atom belongs to; selects both the cell and the mesh slice.
    sys_idx = batch_idx[atom_idx]
    value = values[atom_idx]
    position = positions[atom_idx]

    # Axis 0 of the mesh is the batch axis; the spatial extents follow.
    mesh_dims = wp.vec3i(mesh.shape[1], mesh.shape[2], mesh.shape[3])

    base_grid, theta = compute_fractional_coords(
        position, cell_inv_t[sys_idx], mesh_dims
    )
    offset = bspline_grid_offset(point_idx, order, theta)
    weight = bspline_weight_3d(theta, offset, order)

    # Cutoff expressed in the scalar type of the spread values.
    cutoff = type(value)(1e-8)

    if weight > cutoff:
        # Unwrapped stencil point, then periodic wrap per axis.
        ix = base_grid[0] + offset[0]
        iy = base_grid[1] + offset[1]
        iz = base_grid[2] + offset[2]
        gx = wrap_grid_index(ix, mesh_dims[0])
        gy = wrap_grid_index(iy, mesh_dims[1])
        gz = wrap_grid_index(iz, mesh_dims[2])

        wp.atomic_add(mesh, sys_idx, gx, gy, gz, value * weight)


@wp.kernel
def _batch_bspline_gather_kernel(
    positions: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    cell_inv_t: wp.array(dtype=Any),  # (B, 3, 3)
    order: wp.int32,
    mesh: wp.array(dtype=Any, ndim=4),  # (B, nx, ny, nz)
    output: wp.array(dtype=Any),
):
    """Gather values from a batched 4D mesh to atom positions using B-splines.

    Multi-system counterpart of _bspline_gather_kernel. ``batch_idx`` maps
    each atom to a system; the atom interpolates from that system's slice of
    the mesh using that system's cell matrix.

    Formula: output[atom] = Σ_g mesh[sys, g] * w(atom, g)

    Launch Grid
    -----------
    dim = [num_atoms_total, order^3]

    Each thread handles one (atom, grid_point) pair within the atom's stencil.

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates for all systems concatenated.
    batch_idx : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1).
    cell_inv_t : wp.array, shape (B, 3, 3), dtype=wp.mat33f or wp.mat33d
        Per-system transpose of inverse cell matrix.
    order : wp.int32
        B-spline order (1-4). Order 4 (cubic) recommended for PME.
    mesh : wp.array4d, shape (B, nx, ny, nz), dtype=wp.float32 or wp.float64
        4D mesh (batch × spatial) containing values to interpolate.
    output : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        OUTPUT: Interpolated values per atom. Must be zero-initialized.

    Notes
    -----
    - Uses atomic adds since multiple threads contribute to each atom's output.
    - Each system uses its own cell matrix for fractional coordinate conversion.
    - Grid indices are wrapped using periodic boundary conditions.
    - Threads with 1e-8 weight or less skip the atomic add for efficiency.
    """
    atom_idx, point_idx = wp.tid()

    # System this atom belongs to; selects both the cell and the mesh slice.
    sys_idx = batch_idx[atom_idx]
    position = positions[atom_idx]

    # Axis 0 of the mesh is the batch axis; the spatial extents follow.
    mesh_dims = wp.vec3i(mesh.shape[1], mesh.shape[2], mesh.shape[3])

    base_grid, theta = compute_fractional_coords(
        position, cell_inv_t[sys_idx], mesh_dims
    )
    offset = bspline_grid_offset(point_idx, order, theta)
    weight = bspline_weight_3d(theta, offset, order)

    # One mesh element is read only to obtain the mesh scalar type for the cutoff.
    type_ref = mesh[0, 0, 0, 0]
    cutoff = type(type_ref)(1e-8)

    if weight > cutoff:
        # Unwrapped stencil point, then periodic wrap per axis.
        ix = base_grid[0] + offset[0]
        iy = base_grid[1] + offset[1]
        iz = base_grid[2] + offset[2]
        gx = wrap_grid_index(ix, mesh_dims[0])
        gy = wrap_grid_index(iy, mesh_dims[1])
        gz = wrap_grid_index(iz, mesh_dims[2])

        sample = mesh[sys_idx, gx, gy, gz]
        wp.atomic_add(output, atom_idx, sample * weight)


@wp.kernel
def _batch_bspline_gather_vec3_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    cell_inv_t: wp.array(dtype=Any),  # (B, 3, 3)
    order: wp.int32,
    mesh: wp.array(dtype=Any, ndim=4),  # (B, nx, ny, nz)
    output: wp.array(dtype=Any),
):
    """Gather charge-weighted 3D vector values from batched mesh using B-splines.

    Multi-system counterpart of _bspline_gather_vec3_kernel. ``batch_idx``
    maps each atom to a system; the atom reads vectors from that system's
    slice of the mesh using that system's cell matrix.

    Formula: output[atom] = q[atom] * Σ_g mesh[sys, g] * w(atom, g)

    Launch Grid
    -----------
    dim = [num_atoms_total, order^3]

    Each thread handles one (atom, grid_point) pair within the atom's stencil.

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates for all systems concatenated.
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges (or other scalar weights).
    batch_idx : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1).
    cell_inv_t : wp.array, shape (B, 3, 3), dtype=wp.mat33f or wp.mat33d
        Per-system transpose of inverse cell matrix.
    order : wp.int32
        B-spline order (1-4). Order 4 (cubic) recommended for PME.
    mesh : wp.array4d, shape (B, nx, ny, nz), dtype=wp.vec3f or wp.vec3d
        4D mesh (batch × spatial) containing vector values.
    output : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Charge-weighted interpolated vectors per atom. Must be zero-initialized.

    Notes
    -----
    - Uses atomic adds since multiple threads contribute to each atom's output.
    - Each system uses its own cell matrix for fractional coordinate conversion.
    - Grid indices are wrapped using periodic boundary conditions.
    - Threads with 1e-8 weight or less skip the atomic add for efficiency.
    """
    atom_idx, point_idx = wp.tid()

    # System this atom belongs to; selects both the cell and the mesh slice.
    sys_idx = batch_idx[atom_idx]
    charge = charges[atom_idx]
    position = positions[atom_idx]

    # Axis 0 of the mesh is the batch axis; the spatial extents follow.
    mesh_dims = wp.vec3i(mesh.shape[1], mesh.shape[2], mesh.shape[3])

    base_grid, theta = compute_fractional_coords(
        position, cell_inv_t[sys_idx], mesh_dims
    )
    offset = bspline_grid_offset(point_idx, order, theta)
    weight = bspline_weight_3d(theta, offset, order)

    # Cutoff expressed in the scalar type of the charges.
    cutoff = type(charge)(1e-8)

    if weight > cutoff:
        # Unwrapped stencil point, then periodic wrap per axis.
        ix = base_grid[0] + offset[0]
        iy = base_grid[1] + offset[1]
        iz = base_grid[2] + offset[2]
        gx = wrap_grid_index(ix, mesh_dims[0])
        gy = wrap_grid_index(iy, mesh_dims[1])
        gz = wrap_grid_index(iz, mesh_dims[2])

        field = mesh[sys_idx, gx, gy, gz]
        wp.atomic_add(output, atom_idx, charge * field * weight)


@wp.kernel
def _batch_bspline_gather_gradient_kernel(
    positions: wp.array(dtype=Any),
    charges: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    cell_inv_t: wp.array(dtype=Any),  # (B, 3, 3): one 3x3 matrix per system
    order: wp.int32,
    mesh: wp.array(dtype=Any, ndim=4),  # (B, nx, ny, nz)
    forces: wp.array(dtype=Any),
):
    """Compute forces by gathering mesh gradients from batched mesh using B-spline derivatives.

    Computes:

    .. math::

        F_i = -q_i \\sum_g \\phi(g) \\nabla w(r_i, g)

    The gradient ∇w is computed in fractional coordinates and then transformed
    to Cartesian coordinates via each system's cell matrix.

    Launch Grid
    -----------
    dim = [num_atoms_total, order^3]

    Each thread handles one (atom, grid_point) pair within the atom's stencil
    and atomically adds its partial force to ``forces[atom]``.

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates for all systems concatenated.
    charges : wp.array, shape (N_total,), dtype=wp.float32 or wp.float64
        Atomic charges.
    batch_idx : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1). Used to select both the cell
        matrix and the leading mesh slice.
    cell_inv_t : wp.array, shape (B, 3, 3), dtype=wp.mat33f or wp.mat33d
        Per-system transpose of inverse cell matrix. Inside the kernel this is
        a 1-D array of matrices indexed as ``cell_inv_t[sys_idx]``.
    order : wp.int32
        B-spline order (1-4). Order 4 (cubic) recommended for PME.
    mesh : wp.array4d, shape (B, nx, ny, nz), dtype=wp.float32 or wp.float64
        4D mesh (batch × spatial) containing potential values.
    forces : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        OUTPUT: Forces per atom in Cartesian coordinates. Must be zero-initialized
        because contributions are accumulated, never assigned.

    Notes
    -----
    - Uses atomic adds since multiple threads contribute to each atom's force.
    - The gradient is computed in fractional coordinates, then transformed:
      F_cart = cell_inv_t[sys]^T * F_frac
    - Each system uses its own cell matrix for the transformation.
    - Threads with zero gradient magnitude skip the atomic add for efficiency.
      The magnitude test is the sum of absolute components compared against
      exactly zero (no tolerance).
    - Grid indices are wrapped using periodic boundary conditions.
    """
    # One thread per (atom, stencil point) pair.
    atom_idx, point_idx = wp.tid()

    sys_idx = batch_idx[atom_idx]
    # Spatial extents only; axis 0 of the mesh is the batch axis.
    mesh_dims = wp.vec3i(mesh.shape[1], mesh.shape[2], mesh.shape[3])
    position = positions[atom_idx]
    charge = charges[atom_idx]

    base_grid, theta = compute_fractional_coords(
        position, cell_inv_t[sys_idx], mesh_dims
    )
    # presumably maps the flat stencil index to a 3D grid offset - helper is
    # defined elsewhere in this module
    offset = bspline_grid_offset(point_idx, order, theta)
    grad_frac = bspline_weight_gradient_3d(theta, offset, order, mesh_dims)

    # Sum of absolute components: zero only when every component is zero.
    grad_mag = wp.abs(grad_frac[0]) + wp.abs(grad_frac[1]) + wp.abs(grad_frac[2])

    # type(charge)(...) builds the literal in the kernel's scalar precision.
    if grad_mag > type(charge)(0.0):
        gx = wrap_grid_index(base_grid[0] + offset[0], mesh_dims[0])
        gy = wrap_grid_index(base_grid[1] + offset[1], mesh_dims[1])
        gz = wrap_grid_index(base_grid[2] + offset[2], mesh_dims[2])

        mesh_val = mesh[sys_idx, gx, gy, gz]

        # Force contribution in fractional coordinates: -q * phi(g) * grad_w.
        force_frac = type(position)(
            -charge * mesh_val * grad_frac[0],
            -charge * mesh_val * grad_frac[1],
            -charge * mesh_val * grad_frac[2],
        )
        # Transform the fractional-space force to Cartesian space.
        force = wp.transpose(cell_inv_t[sys_idx]) * force_frac

        wp.atomic_add(forces, atom_idx, force)


###########################################################################################
########################### Multi-Channel Warp Kernels ####################################
###########################################################################################


@wp.kernel
def _bspline_spread_channels_kernel(
    positions: wp.array(dtype=Any),
    values: wp.array2d(dtype=Any),  # (N, C)
    cell_inv_t: wp.array(dtype=Any),
    order: wp.int32,
    mesh: wp.array(dtype=Any, ndim=4),  # (C, nx, ny, nz)
):
    """Spread multi-channel values from atoms to mesh using B-splines.

    Similar to _bspline_spread_kernel but handles multiple channels per atom,
    useful for multipole moments (e.g., monopole + dipole + quadrupole).

    Formula: mesh[c, g] += values[atom, c] * w(atom, g)

    for each channel c = 0, 1, ..., C-1.

    Launch Grid
    -----------
    dim = [num_atoms, order^3]

    Each thread handles one (atom, grid_point) pair and iterates over all channels.
    The B-spline weight is computed once per thread and reused for every channel.

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates in Cartesian space.
    values : wp.array2d, shape (N, C), dtype=wp.float32 or wp.float64
        Multi-channel values to spread (e.g., multipole moments). The channel
        count C is read from ``values.shape[1]``.
    cell_inv_t : wp.array, shape (1, 3, 3), dtype=wp.mat33f or wp.mat33d
        Transpose of inverse cell matrix for fractional coordinate conversion.
        Only element 0 is read.
    order : wp.int32
        B-spline order (1-4). Order 4 (cubic) recommended for PME.
    mesh : wp.array4d, shape (C, nx, ny, nz), dtype=wp.float32 or wp.float64
        OUTPUT: 4D mesh (channels × spatial) to accumulate values. Must be zero-initialized.

    Notes
    -----
    - Uses atomic adds for thread-safe accumulation to shared grid points.
    - Each channel is spread independently to its own mesh slice.
    - Grid indices are wrapped using periodic boundary conditions.
    - Threads with 1e-8 weight or less skip the atomic adds for efficiency.
    """
    # One thread per (atom, stencil point) pair.
    atom_idx, point_idx = wp.tid()

    num_channels = values.shape[1]
    # Spatial extents only; axis 0 of the mesh is the channel axis.
    mesh_dims = wp.vec3i(mesh.shape[1], mesh.shape[2], mesh.shape[3])
    position = positions[atom_idx]

    # Single system: the one cell matrix lives at index 0.
    base_grid, theta = compute_fractional_coords(position, cell_inv_t[0], mesh_dims)
    offset = bspline_grid_offset(point_idx, order, theta)
    weight = bspline_weight_3d(theta, offset, order)

    # Read one element only so type(val) yields the array's scalar type for
    # the threshold literal below; the value itself is not used.
    val = values[0, 0]  # Get type reference
    if weight > type(val)(1e-8):
        gx = wrap_grid_index(base_grid[0] + offset[0], mesh_dims[0])
        gy = wrap_grid_index(base_grid[1] + offset[1], mesh_dims[1])
        gz = wrap_grid_index(base_grid[2] + offset[2], mesh_dims[2])

        # Spread each channel into its own mesh slice at the same grid point.
        for c in range(num_channels):
            val = values[atom_idx, c]
            wp.atomic_add(mesh, c, gx, gy, gz, val * weight)


@wp.kernel
def _bspline_gather_channels_kernel(
    positions: wp.array(dtype=Any),
    cell_inv_t: wp.array(dtype=Any),
    order: wp.int32,
    mesh: wp.array(dtype=Any, ndim=4),  # (C, nx, ny, nz)
    output: wp.array2d(dtype=Any),  # (N, C)
):
    """Gather multi-channel values from mesh to atoms using B-splines.

    Similar to _bspline_gather_kernel but handles multiple channels,
    useful for multipole-based methods.

    Formula: output[atom, c] = Σ_g mesh[c, g] * w(atom, g)

    for each channel c = 0, 1, ..., C-1.

    Launch Grid
    -----------
    dim = [num_atoms, order^3]

    Each thread handles one (atom, grid_point) pair and iterates over all channels.
    The B-spline weight is computed once per thread and reused for every channel.

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates in Cartesian space.
    cell_inv_t : wp.array, shape (1, 3, 3), dtype=wp.mat33f or wp.mat33d
        Transpose of inverse cell matrix for fractional coordinate conversion.
        Only element 0 is read.
    order : wp.int32
        B-spline order (1-4). Order 4 (cubic) recommended for PME.
    mesh : wp.array4d, shape (C, nx, ny, nz), dtype=wp.float32 or wp.float64
        4D mesh (channels × spatial) containing values to interpolate. The
        channel count C is read from ``mesh.shape[0]``.
    output : wp.array2d, shape (N, C), dtype=wp.float32 or wp.float64
        OUTPUT: Interpolated multi-channel values per atom. Must be zero-initialized.

    Notes
    -----
    - Uses atomic adds since multiple threads contribute to each atom's output.
    - Each channel is gathered independently from its own mesh slice.
    - Grid indices are wrapped using periodic boundary conditions.
    - Threads with 1e-8 weight or less skip the atomic adds for efficiency.
    """
    # One thread per (atom, stencil point) pair.
    atom_idx, point_idx = wp.tid()

    num_channels = mesh.shape[0]
    # Spatial extents only; axis 0 of the mesh is the channel axis.
    mesh_dims = wp.vec3i(mesh.shape[1], mesh.shape[2], mesh.shape[3])
    position = positions[atom_idx]

    # Single system: the one cell matrix lives at index 0.
    base_grid, theta = compute_fractional_coords(position, cell_inv_t[0], mesh_dims)
    offset = bspline_grid_offset(point_idx, order, theta)
    weight = bspline_weight_3d(theta, offset, order)

    # Read one element only so type(mesh_val) yields the mesh scalar type for
    # the threshold literal below; the value itself is not used.
    mesh_val = mesh[0, 0, 0, 0]  # Get type reference
    if weight > type(mesh_val)(1e-8):
        gx = wrap_grid_index(base_grid[0] + offset[0], mesh_dims[0])
        gy = wrap_grid_index(base_grid[1] + offset[1], mesh_dims[1])
        gz = wrap_grid_index(base_grid[2] + offset[2], mesh_dims[2])

        # Gather each channel from its own mesh slice at the same grid point.
        for c in range(num_channels):
            mesh_val = mesh[c, gx, gy, gz]
            wp.atomic_add(output, atom_idx, c, mesh_val * weight)


@wp.kernel
def _batch_bspline_spread_channels_kernel(
    positions: wp.array(dtype=Any),
    values: wp.array2d(dtype=Any),  # (N, C)
    batch_idx: wp.array(dtype=wp.int32),
    cell_inv_t: wp.array(dtype=Any),  # (B, 3, 3)
    order: wp.int32,
    num_channels: wp.int32,
    mesh: wp.array4d(dtype=Any),  # (B*C, nx, ny, nz) - flattened batch*channel
):
    """Spread multi-channel values from atoms to batched mesh using B-splines.

    Batched version of _bspline_spread_channels_kernel. Due to Warp's 4D array
    limit, the batch and channel dimensions are flattened into a single dimension.

    Formula: mesh[sys*C + c, g] += values[atom, c] * w(atom, g)

    Launch Grid
    -----------
    dim = [num_atoms_total, order^3]

    Each thread handles one (atom, grid_point) pair and iterates over all channels.
    The B-spline weight is computed once per thread and reused for every channel.

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates for all systems concatenated.
    values : wp.array2d, shape (N_total, C), dtype=wp.float32 or wp.float64
        Multi-channel values to spread (e.g., multipole moments).
    batch_idx : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1).
    cell_inv_t : wp.array, shape (B, 3, 3), dtype=wp.mat33f or wp.mat33d
        Per-system transpose of inverse cell matrix, indexed as
        ``cell_inv_t[sys_idx]``.
    order : wp.int32
        B-spline order (1-4). Order 4 (cubic) recommended for PME.
    num_channels : wp.int32
        Number of channels (C). Passed explicitly (unlike the single-system
        kernel) and used both as the loop bound and as the stride of the
        flattened leading mesh axis.
    mesh : wp.array4d, shape (B*C, nx, ny, nz), dtype=wp.float32 or wp.float64
        OUTPUT: Flattened 4D mesh to accumulate values. Must be zero-initialized.

    Notes
    -----
    - Mesh storage: (B*C, nx, ny, nz) with flat_idx = sys_idx * C + channel_idx.
    - Uses atomic adds for thread-safe accumulation to shared grid points.
    - Each system uses its own cell matrix for fractional coordinate conversion.
    - Grid indices are wrapped using periodic boundary conditions.
    - Threads with 1e-8 weight or less skip the atomic adds for efficiency.
    """
    # One thread per (atom, stencil point) pair.
    atom_idx, point_idx = wp.tid()

    sys_idx = batch_idx[atom_idx]
    # Spatial extents only; axis 0 of the mesh is the flattened batch*channel axis.
    mesh_dims = wp.vec3i(mesh.shape[1], mesh.shape[2], mesh.shape[3])
    position = positions[atom_idx]

    base_grid, theta = compute_fractional_coords(
        position, cell_inv_t[sys_idx], mesh_dims
    )
    offset = bspline_grid_offset(point_idx, order, theta)
    weight = bspline_weight_3d(theta, offset, order)

    # Read one element only so type(val) yields the array's scalar type for
    # the threshold literal below; the value itself is not used.
    val = values[0, 0]  # Get type reference
    if weight > type(val)(1e-8):
        gx = wrap_grid_index(base_grid[0] + offset[0], mesh_dims[0])
        gy = wrap_grid_index(base_grid[1] + offset[1], mesh_dims[1])
        gz = wrap_grid_index(base_grid[2] + offset[2], mesh_dims[2])

        # Spread each channel using flattened batch*channel indexing
        for c in range(num_channels):
            # Channels of one system occupy a contiguous run of the leading axis.
            flat_idx = sys_idx * num_channels + c
            val = values[atom_idx, c]
            wp.atomic_add(mesh, flat_idx, gx, gy, gz, val * weight)


@wp.kernel
def _batch_bspline_gather_channels_kernel(
    positions: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    cell_inv_t: wp.array(dtype=Any),  # (B, 3, 3)
    order: wp.int32,
    num_channels: wp.int32,
    mesh: wp.array4d(dtype=Any),  # (B*C, nx, ny, nz) - flattened batch*channel
    output: wp.array2d(dtype=Any),  # (N, C)
):
    """Gather multi-channel values from batched mesh to atoms using B-splines.

    Batched version of _bspline_gather_channels_kernel. Due to Warp's 4D array
    limit, the batch and channel dimensions are flattened into a single dimension.

    Formula: output[atom, c] = Σ_g mesh[sys*C + c, g] * w(atom, g)

    Launch Grid
    -----------
    dim = [num_atoms_total, order^3]

    Each thread handles one (atom, grid_point) pair and iterates over all channels.
    The B-spline weight is computed once per thread and reused for every channel.

    Parameters
    ----------
    positions : wp.array, shape (N_total,), dtype=wp.vec3f or wp.vec3d
        Atomic coordinates for all systems concatenated.
    batch_idx : wp.array, shape (N_total,), dtype=wp.int32
        System index for each atom (0 to B-1).
    cell_inv_t : wp.array, shape (B, 3, 3), dtype=wp.mat33f or wp.mat33d
        Per-system transpose of inverse cell matrix, indexed as
        ``cell_inv_t[sys_idx]``.
    order : wp.int32
        B-spline order (1-4). Order 4 (cubic) recommended for PME.
    num_channels : wp.int32
        Number of channels (C). Passed explicitly (unlike the single-system
        kernel) and used both as the loop bound and as the stride of the
        flattened leading mesh axis.
    mesh : wp.array4d, shape (B*C, nx, ny, nz), dtype=wp.float32 or wp.float64
        Flattened 4D mesh (batch*channels × spatial) containing values.
    output : wp.array2d, shape (N_total, C), dtype=wp.float32 or wp.float64
        OUTPUT: Interpolated multi-channel values per atom. Must be zero-initialized.

    Notes
    -----
    - Mesh storage: (B*C, nx, ny, nz) with flat_idx = sys_idx * C + channel_idx.
    - Uses atomic adds since multiple threads contribute to each atom's output.
    - Each system uses its own cell matrix for fractional coordinate conversion.
    - Grid indices are wrapped using periodic boundary conditions.
    - Threads with 1e-8 weight or less skip the atomic adds for efficiency.
    """
    # One thread per (atom, stencil point) pair.
    atom_idx, point_idx = wp.tid()

    sys_idx = batch_idx[atom_idx]
    # Spatial extents only; axis 0 of the mesh is the flattened batch*channel axis.
    mesh_dims = wp.vec3i(mesh.shape[1], mesh.shape[2], mesh.shape[3])
    position = positions[atom_idx]

    base_grid, theta = compute_fractional_coords(
        position, cell_inv_t[sys_idx], mesh_dims
    )
    offset = bspline_grid_offset(point_idx, order, theta)
    weight = bspline_weight_3d(theta, offset, order)

    # Read one element only so type(mesh_val) yields the mesh scalar type for
    # the threshold literal below; the value itself is not used.
    mesh_val = mesh[0, 0, 0, 0]  # Get type reference
    if weight > type(mesh_val)(1e-8):
        gx = wrap_grid_index(base_grid[0] + offset[0], mesh_dims[0])
        gy = wrap_grid_index(base_grid[1] + offset[1], mesh_dims[1])
        gz = wrap_grid_index(base_grid[2] + offset[2], mesh_dims[2])

        # Gather each channel using flattened batch*channel indexing
        for c in range(num_channels):
            # Channels of one system occupy a contiguous run of the leading axis.
            flat_idx = sys_idx * num_channels + c
            mesh_val = mesh[flat_idx, gx, gy, gz]
            wp.atomic_add(output, atom_idx, c, mesh_val * weight)


###########################################################################################
########################### Kernel Overloads for Dtype Flexibility #########################
###########################################################################################

# Supported precisions. The three lists are index-aligned: position i of each
# list names the scalar, 3-vector and 3x3-matrix type of the same precision.
_T = [wp.float32, wp.float64]
_V = [wp.vec3f, wp.vec3d]
_M = [wp.mat33f, wp.mat33d]

# Registries of concrete kernel overloads, keyed by the Warp scalar dtype.
# The host-side wrappers look up the kernel to launch with
# ``<registry>[wp_dtype]``.

# -- single-system kernels
_bspline_spread_kernel_overload = {}
_bspline_gather_kernel_overload = {}
_bspline_gather_vec3_kernel_overload = {}
_bspline_gather_gradient_kernel_overload = {}

# -- batched kernels
_batch_bspline_spread_kernel_overload = {}
_batch_bspline_gather_kernel_overload = {}
_batch_bspline_gather_vec3_kernel_overload = {}
_batch_bspline_gather_gradient_kernel_overload = {}

# -- multi-channel kernels (single-system and batched)
_bspline_spread_channels_kernel_overload = {}
_bspline_gather_channels_kernel_overload = {}
_batch_bspline_spread_channels_kernel_overload = {}
_batch_bspline_gather_channels_kernel_overload = {}

# Register one overload of every generic kernel per precision. Each argument
# list mirrors the kernel's parameter order; the trailing comment on each entry
# names the parameter it binds.
for t, v, m in zip(_T, _V, _M):
    # ---- single-system: scalar spread ----
    _bspline_spread_kernel_overload[t] = wp.overload(
        _bspline_spread_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # values
            wp.array(dtype=m),  # cell_inv_t
            wp.int32,  # order
            wp.array(dtype=t, ndim=3),  # mesh
        ],
    )
    # ---- single-system: scalar gather ----
    _bspline_gather_kernel_overload[t] = wp.overload(
        _bspline_gather_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=m),  # cell_inv_t
            wp.int32,  # order
            wp.array(dtype=t, ndim=3),  # mesh
            wp.array(dtype=t),  # output
        ],
    )
    # ---- single-system: vector-field gather ----
    _bspline_gather_vec3_kernel_overload[t] = wp.overload(
        _bspline_gather_vec3_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # charges
            wp.array(dtype=m),  # cell_inv_t
            wp.int32,  # order
            wp.array(dtype=v, ndim=3),  # mesh
            wp.array(dtype=v),  # output
        ],
    )
    # ---- single-system: gradient (force) gather ----
    _bspline_gather_gradient_kernel_overload[t] = wp.overload(
        _bspline_gather_gradient_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # charges
            wp.array(dtype=m),  # cell_inv_t
            wp.int32,  # order
            wp.array(dtype=t, ndim=3),  # mesh
            wp.array(dtype=v),  # forces
        ],
    )

    # ---- batched: scalar spread ----
    _batch_bspline_spread_kernel_overload[t] = wp.overload(
        _batch_bspline_spread_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # values
            wp.array(dtype=wp.int32),  # batch_idx
            wp.array(dtype=m),  # cell_inv_t
            wp.int32,  # order
            wp.array(dtype=t, ndim=4),  # mesh
        ],
    )
    # ---- batched: scalar gather ----
    _batch_bspline_gather_kernel_overload[t] = wp.overload(
        _batch_bspline_gather_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=wp.int32),  # batch_idx
            wp.array(dtype=m),  # cell_inv_t
            wp.int32,  # order
            wp.array(dtype=t, ndim=4),  # mesh
            wp.array(dtype=t),  # output
        ],
    )
    # ---- batched: vector-field gather ----
    _batch_bspline_gather_vec3_kernel_overload[t] = wp.overload(
        _batch_bspline_gather_vec3_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # charges
            wp.array(dtype=wp.int32),  # batch_idx
            wp.array(dtype=m),  # cell_inv_t
            wp.int32,  # order
            wp.array(dtype=v, ndim=4),  # mesh
            wp.array(dtype=v),  # output
        ],
    )
    # ---- batched: gradient (force) gather ----
    _batch_bspline_gather_gradient_kernel_overload[t] = wp.overload(
        _batch_bspline_gather_gradient_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t),  # charges
            wp.array(dtype=wp.int32),  # batch_idx
            wp.array(dtype=m),  # cell_inv_t
            wp.int32,  # order
            wp.array(dtype=t, ndim=4),  # mesh
            wp.array(dtype=v),  # forces
        ],
    )

    # ---- multi-channel: single-system spread ----
    _bspline_spread_channels_kernel_overload[t] = wp.overload(
        _bspline_spread_channels_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t, ndim=2),  # values
            wp.array(dtype=m),  # cell_inv_t
            wp.int32,  # order
            wp.array(dtype=t, ndim=4),  # mesh
        ],
    )
    # ---- multi-channel: single-system gather ----
    _bspline_gather_channels_kernel_overload[t] = wp.overload(
        _bspline_gather_channels_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=m),  # cell_inv_t
            wp.int32,  # order
            wp.array(dtype=t, ndim=4),  # mesh
            wp.array(dtype=t, ndim=2),  # output
        ],
    )
    # ---- multi-channel: batched spread ----
    _batch_bspline_spread_channels_kernel_overload[t] = wp.overload(
        _batch_bspline_spread_channels_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=t, ndim=2),  # values
            wp.array(dtype=wp.int32),  # batch_idx
            wp.array(dtype=m),  # cell_inv_t
            wp.int32,  # order
            wp.int32,  # num_channels
            wp.array(dtype=t, ndim=4),  # mesh
        ],
    )
    # ---- multi-channel: batched gather ----
    _batch_bspline_gather_channels_kernel_overload[t] = wp.overload(
        _batch_bspline_gather_channels_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=wp.int32),  # batch_idx
            wp.array(dtype=m),  # cell_inv_t
            wp.int32,  # order
            wp.int32,  # num_channels
            wp.array(dtype=t, ndim=4),  # mesh
            wp.array(dtype=t, ndim=2),  # output
        ],
    )


###########################################################################################
########################### Internal Custom Ops: _spline_* (Single-System) #################
###########################################################################################


@warp_custom_op(
    name="alchemiops::_spline_spread",
    outputs=[
        OutputSpec(
            "mesh",
            wp.array(dtype=Any, ndim=3),
            lambda pos, values, cell, mesh_nx, mesh_ny, mesh_nz, spline_order, *_: (
                mesh_nx,
                mesh_ny,
                mesh_nz,
            ),
        ),
    ],
    grad_arrays=[
        "mesh",
        "positions",
        "values",
        "cell_inv_t",
    ],
)
def _spline_spread(
    positions: torch.Tensor,
    values: torch.Tensor,
    cell: torch.Tensor,
    mesh_nx: int,
    mesh_ny: int,
    mesh_nz: int,
    spline_order: int,
    cell_inv_t: torch.Tensor | None = None,
) -> torch.Tensor:
    """Internal: scatter per-atom scalars onto a single-system mesh.

    The working precision is taken from ``positions.dtype``; ``values`` is cast
    to that dtype before the kernel launch.

    Parameters
    ----------
    positions : torch.Tensor
        Atom coordinates; the first dimension is the atom count.
    values : torch.Tensor
        Per-atom scalars to spread.
    cell : torch.Tensor
        Cell matrix. A 2-D tensor gets a leading axis of size 1 added.
    mesh_nx, mesh_ny, mesh_nz : int
        Mesh extents along the three grid axes.
    spline_order : int
        B-spline order; the kernel is launched with ``spline_order**3``
        stencil points per atom.
    cell_inv_t : torch.Tensor, optional
        Precomputed transpose of the inverse cell. Derived from ``cell`` when
        omitted.

    Returns
    -------
    torch.Tensor
        Freshly allocated mesh of shape ``(mesh_nx, mesh_ny, mesh_nz)`` in the
        dtype and on the device of ``positions``.
    """
    wp_device = wp.device_from_torch(positions.device)
    torch_dtype = positions.dtype
    scalar_t = get_wp_dtype(torch_dtype)
    vec_t = get_wp_vec_dtype(torch_dtype)
    mat_t = get_wp_mat_dtype(torch_dtype)

    # One thread per (atom, stencil point).
    launch_dim = (positions.shape[0], spline_order**3)
    track_grad = needs_grad(positions, values, cell)

    if cell.ndim == 2:
        cell = cell.unsqueeze(0)

    if cell_inv_t is None:
        # NOTE(review): the sibling gather ops use torch.linalg.inv here;
        # inv_ex does not raise on a singular cell - confirm this difference
        # is intentional.
        inverse = torch.linalg.inv_ex(cell)[0]
        cell_inv_t = inverse.transpose(-1, -2).contiguous()

    wp_positions = warp_from_torch(positions, vec_t, requires_grad=track_grad)
    wp_values = warp_from_torch(
        values.to(torch_dtype), scalar_t, requires_grad=track_grad
    )
    wp_cell_inv_t = warp_from_torch(cell_inv_t, mat_t, requires_grad=track_grad)

    # The kernel accumulates with atomic adds, so the mesh starts at zero.
    mesh = torch.zeros(
        (mesh_nx, mesh_ny, mesh_nz), dtype=torch_dtype, device=positions.device
    )
    wp_mesh = warp_from_torch(mesh, scalar_t, requires_grad=track_grad)

    launch_inputs = [wp_positions, wp_values, wp_cell_inv_t, wp.int32(spline_order)]

    with WarpAutogradContextManager(track_grad) as tape:
        wp.launch(
            _bspline_spread_kernel_overload[scalar_t],
            dim=launch_dim,
            inputs=launch_inputs,
            outputs=[wp_mesh],
            device=wp_device,
        )

    if not track_grad:
        return mesh

    # Keep the tape and the Warp arrays with the result for the backward pass.
    attach_for_backward(
        mesh,
        tape=tape,
        mesh=wp_mesh,
        positions=wp_positions,
        values=wp_values,
        cell_inv_t=wp_cell_inv_t,
    )
    return mesh


@warp_custom_op(
    name="alchemiops::_spline_gather",
    outputs=[
        OutputSpec(
            "values",
            wp.array(dtype=Any),
            lambda pos, *_: (pos.shape[0],),
        ),
    ],
    grad_arrays=[
        "values",
        "positions",
        "mesh",
        "cell_inv_t",
    ],
)
def _spline_gather(
    positions: torch.Tensor,
    mesh: torch.Tensor,
    cell: torch.Tensor,
    spline_order: int,
    cell_inv_t: torch.Tensor | None = None,
) -> torch.Tensor:
    """Internal: interpolate a single-system scalar mesh at atom positions.

    The working precision is taken from ``positions.dtype``; ``mesh`` is cast
    to that dtype before the kernel launch.

    Parameters
    ----------
    positions : torch.Tensor
        Atom coordinates; the first dimension is the atom count.
    mesh : torch.Tensor
        Scalar mesh to read from.
    cell : torch.Tensor
        Cell matrix. A 2-D tensor gets a leading axis of size 1 added.
    spline_order : int
        B-spline order; the kernel is launched with ``spline_order**3``
        stencil points per atom.
    cell_inv_t : torch.Tensor, optional
        Precomputed transpose of the inverse cell. Derived from ``cell`` with
        ``torch.linalg.inv`` when omitted.

    Returns
    -------
    torch.Tensor
        Freshly allocated 1-D tensor with one interpolated value per atom, in
        the dtype and on the device of ``positions``.
    """
    wp_device = wp.device_from_torch(positions.device)
    torch_dtype = positions.dtype
    scalar_t = get_wp_dtype(torch_dtype)
    vec_t = get_wp_vec_dtype(torch_dtype)
    mat_t = get_wp_mat_dtype(torch_dtype)

    atom_count = positions.shape[0]
    stencil_size = spline_order**3
    track_grad = needs_grad(positions, mesh, cell)

    if cell.ndim == 2:
        cell = cell.unsqueeze(0)

    if cell_inv_t is None:
        cell_inv_t = torch.linalg.inv(cell).transpose(-1, -2).contiguous()

    wp_positions = warp_from_torch(positions, vec_t, requires_grad=track_grad)
    wp_cell_inv_t = warp_from_torch(cell_inv_t, mat_t, requires_grad=track_grad)
    wp_mesh = warp_from_torch(
        mesh.to(torch_dtype), scalar_t, requires_grad=track_grad
    )

    # The kernel accumulates with atomic adds, so the output starts at zero.
    values = torch.zeros(atom_count, dtype=torch_dtype, device=positions.device)
    wp_values = warp_from_torch(values, scalar_t, requires_grad=track_grad)

    launch_inputs = [wp_positions, wp_cell_inv_t, wp.int32(spline_order), wp_mesh]

    with WarpAutogradContextManager(track_grad) as tape:
        wp.launch(
            _bspline_gather_kernel_overload[scalar_t],
            dim=(atom_count, stencil_size),
            inputs=launch_inputs,
            outputs=[wp_values],
            device=wp_device,
        )

    if not track_grad:
        return values

    # Keep the tape and the Warp arrays with the result for the backward pass.
    attach_for_backward(
        values,
        tape=tape,
        values=wp_values,
        positions=wp_positions,
        cell_inv_t=wp_cell_inv_t,
        mesh=wp_mesh,
    )
    return values


@warp_custom_op(
    name="alchemiops::_spline_gather_vec3",
    outputs=[
        OutputSpec(
            "values", wp.array(dtype=Any, ndim=2), lambda pos, *_: (pos.shape[0], 3)
        ),
    ],
    grad_arrays=[
        "values",
        "positions",
        "charges",
        "mesh",
        "cell_inv_t",
    ],
)
def _spline_gather_vec3(
    positions: torch.Tensor,
    charges: torch.Tensor,
    mesh: torch.Tensor,
    cell: torch.Tensor,
    spline_order: int,
    cell_inv_t: torch.Tensor | None = None,
) -> torch.Tensor:
    """Internal: Single-system vec3 spline gather with dtype flexibility.

    Interpolates a vector-valued mesh at the atom positions. The working
    precision is taken from ``positions.dtype``; ``charges`` and ``mesh`` are
    cast to that dtype before the kernel launch.

    Parameters
    ----------
    positions : torch.Tensor
        Atom coordinates; the first dimension is the atom count.
    charges : torch.Tensor
        Per-atom scalars passed to the kernel alongside the mesh.
    mesh : torch.Tensor
        Vector-valued mesh, reinterpreted as a Warp array of 3-vectors
        (presumably shape ``(nx, ny, nz, 3)`` - confirm against the caller).
    cell : torch.Tensor
        Cell matrix. A 2-D tensor gets a leading axis of size 1 added.
    spline_order : int
        B-spline order; the kernel is launched with ``spline_order**3``
        stencil points per atom.
    cell_inv_t : torch.Tensor, optional
        Precomputed transpose of the inverse cell. Derived from ``cell`` with
        ``torch.linalg.inv`` when omitted.

    Returns
    -------
    torch.Tensor
        Freshly allocated tensor of shape ``(num_atoms, 3)`` in the dtype and
        on the device of ``positions``.
    """
    device = wp.device_from_torch(positions.device)
    input_dtype = positions.dtype
    wp_dtype = get_wp_dtype(input_dtype)
    wp_vec_dtype = get_wp_vec_dtype(input_dtype)
    wp_mat_dtype = get_wp_mat_dtype(input_dtype)

    num_atoms = positions.shape[0]
    num_points = spline_order**3
    # `charges` is a differentiable input (it is listed in grad_arrays and
    # attached for backward below), so it must take part in the decision to
    # record a tape. Without it, a call where only `charges` requires grad
    # would skip recording and silently drop that gradient.
    needs_grad_flag = needs_grad(positions, charges, mesh, cell)

    if cell.dim() == 2:
        cell = cell.unsqueeze(0)

    if cell_inv_t is None:
        cell_inv = torch.linalg.inv(cell)
        cell_inv_t = cell_inv.transpose(-1, -2).contiguous()

    wp_positions = warp_from_torch(
        positions, wp_vec_dtype, requires_grad=needs_grad_flag
    )
    wp_charges = warp_from_torch(
        charges.to(input_dtype), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_cell_inv_t = warp_from_torch(
        cell_inv_t, wp_mat_dtype, requires_grad=needs_grad_flag
    )
    wp_mesh = warp_from_torch(
        mesh.to(input_dtype), wp_vec_dtype, requires_grad=needs_grad_flag
    )

    # The kernel accumulates with atomic adds, so the output starts at zero.
    values = torch.zeros((num_atoms, 3), device=positions.device, dtype=input_dtype)
    wp_values = warp_from_torch(values, wp_vec_dtype, requires_grad=needs_grad_flag)

    kernel = _bspline_gather_vec3_kernel_overload[wp_dtype]

    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            kernel,
            dim=(num_atoms, num_points),
            inputs=[
                wp_positions,
                wp_charges,
                wp_cell_inv_t,
                wp.int32(spline_order),
                wp_mesh,
            ],
            outputs=[wp_values],
            device=device,
        )

    if needs_grad_flag:
        # Keep the tape and the Warp arrays with the result for backward.
        attach_for_backward(
            values,
            tape=tape,
            values=wp_values,
            positions=wp_positions,
            charges=wp_charges,
            cell_inv_t=wp_cell_inv_t,
            mesh=wp_mesh,
        )
    return values


@warp_custom_op(
    name="alchemiops::_spline_gather_gradient",
    outputs=[
        OutputSpec(
            "forces", wp.array(dtype=Any, ndim=2), lambda pos, *_: (pos.shape[0], 3)
        ),
    ],
    grad_arrays=[
        "forces",
        "positions",
        "charges",
        "mesh",
        "cell_inv_t",
    ],
)
def _spline_gather_gradient(
    positions: torch.Tensor,
    charges: torch.Tensor,
    mesh: torch.Tensor,
    cell: torch.Tensor,
    spline_order: int,
    cell_inv_t: torch.Tensor | None = None,
) -> torch.Tensor:
    """Internal: per-atom forces from a single-system scalar mesh.

    Launches the B-spline gradient-gather kernel. The working precision is
    taken from ``positions.dtype``; ``charges`` and ``mesh`` are cast to that
    dtype before the kernel launch.

    Parameters
    ----------
    positions : torch.Tensor
        Atom coordinates; the first dimension is the atom count.
    charges : torch.Tensor
        Per-atom charges.
    mesh : torch.Tensor
        Scalar mesh to read from.
    cell : torch.Tensor
        Cell matrix. A 2-D tensor gets a leading axis of size 1 added.
    spline_order : int
        B-spline order; the kernel is launched with ``spline_order**3``
        stencil points per atom.
    cell_inv_t : torch.Tensor, optional
        Precomputed transpose of the inverse cell. Derived from ``cell`` with
        ``torch.linalg.inv`` when omitted.

    Returns
    -------
    torch.Tensor
        Freshly allocated tensor of shape ``(num_atoms, 3)`` in the dtype and
        on the device of ``positions``.
    """
    wp_device = wp.device_from_torch(positions.device)
    torch_dtype = positions.dtype
    scalar_t = get_wp_dtype(torch_dtype)
    vec_t = get_wp_vec_dtype(torch_dtype)
    mat_t = get_wp_mat_dtype(torch_dtype)

    atom_count = positions.shape[0]
    stencil_size = spline_order**3
    track_grad = needs_grad(positions, charges, mesh, cell)

    if cell.ndim == 2:
        cell = cell.unsqueeze(0)

    if cell_inv_t is None:
        cell_inv_t = torch.linalg.inv(cell).transpose(-1, -2).contiguous()

    wp_positions = warp_from_torch(positions, vec_t, requires_grad=track_grad)
    wp_charges = warp_from_torch(
        charges.to(torch_dtype), scalar_t, requires_grad=track_grad
    )
    wp_cell_inv_t = warp_from_torch(cell_inv_t, mat_t, requires_grad=track_grad)
    wp_mesh = warp_from_torch(
        mesh.to(torch_dtype), scalar_t, requires_grad=track_grad
    )

    # The kernel accumulates with atomic adds, so the output starts at zero.
    forces = torch.zeros(
        (atom_count, 3), dtype=torch_dtype, device=positions.device
    )
    wp_forces = warp_from_torch(forces, vec_t, requires_grad=track_grad)

    launch_inputs = [
        wp_positions,
        wp_charges,
        wp_cell_inv_t,
        wp.int32(spline_order),
        wp_mesh,
    ]

    with WarpAutogradContextManager(track_grad) as tape:
        wp.launch(
            _bspline_gather_gradient_kernel_overload[scalar_t],
            dim=(atom_count, stencil_size),
            inputs=launch_inputs,
            outputs=[wp_forces],
            device=wp_device,
        )

    if not track_grad:
        return forces

    # Keep the tape and the Warp arrays with the result for the backward pass.
    attach_for_backward(
        forces,
        tape=tape,
        forces=wp_forces,
        positions=wp_positions,
        charges=wp_charges,
        cell_inv_t=wp_cell_inv_t,
        mesh=wp_mesh,
    )
    return forces


###########################################################################################
########################### Internal Custom Ops: _batch_spline_* (Batch) ###################
###########################################################################################


@warp_custom_op(
    name="alchemiops::_batch_spline_spread",
    outputs=[
        OutputSpec(
            "mesh",
            wp.array(dtype=Any, ndim=4),
            # Presumably invoked with the op's positional arguments; only the
            # system count and the three mesh extents shape the output.
            lambda _pos, _vals, _bidx, _cell, n_sys, nx, ny, nz, _order, *_: (
                n_sys,
                nx,
                ny,
                nz,
            ),
        ),
    ],
    grad_arrays=["mesh", "positions", "values", "cell_inv_t"],
)
def _batch_spline_spread(
    positions: torch.Tensor,
    values: torch.Tensor,
    batch_idx: torch.Tensor,
    cell: torch.Tensor,
    num_systems: int,
    mesh_nx: int,
    mesh_ny: int,
    mesh_nz: int,
    spline_order: int,
    cell_inv_t: torch.Tensor | None = None,
) -> torch.Tensor:
    """Internal: batched spline spread of per-atom values onto per-system meshes.

    A zero-initialised ``(num_systems, mesh_nx, mesh_ny, mesh_nz)`` mesh is
    allocated in the dtype of ``positions`` and filled by the dtype-matched
    batch spread kernel, launched over a ``(num_atoms, spline_order**3)`` grid.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic positions; their dtype and device drive every other conversion.
    values : torch.Tensor
        Per-atom values; cast to the dtype of ``positions`` before launch.
    batch_idx : torch.Tensor
        Per-atom system index, handed to Warp as ``int32``.
    cell : torch.Tensor
        Cell matrices; only inverted when ``cell_inv_t`` is not supplied.
    num_systems, mesh_nx, mesh_ny, mesh_nz : int
        Extents of the output mesh.
    spline_order : int
        B-spline order; its cube is the second launch dimension.
    cell_inv_t : torch.Tensor | None
        Optional precomputed transpose of the inverse cell.

    Returns
    -------
    torch.Tensor
        The filled mesh, with the Warp tape attached when any of
        ``positions``, ``values`` or ``cell`` needs gradients.
    """
    torch_dtype = positions.dtype
    scalar_t = get_wp_dtype(torch_dtype)
    vec_t = get_wp_vec_dtype(torch_dtype)
    mat_t = get_wp_mat_dtype(torch_dtype)
    wp_device = wp.device_from_torch(positions.device)

    track_grad = needs_grad(positions, values, cell)
    launch_dims = (positions.shape[0], spline_order**3)

    # Fall back to inverting the cell only when the caller gave no inverse.
    if cell_inv_t is None:
        cell_inv_t = torch.linalg.inv(cell).transpose(-1, -2).contiguous()

    mesh_shape = (num_systems, mesh_nx, mesh_ny, mesh_nz)
    mesh = torch.zeros(mesh_shape, dtype=torch_dtype, device=positions.device)

    wp_positions = warp_from_torch(positions, vec_t, requires_grad=track_grad)
    wp_values = warp_from_torch(
        values.to(torch_dtype), scalar_t, requires_grad=track_grad
    )
    wp_batch_idx = warp_from_torch(batch_idx, wp.int32)
    wp_cell_inv_t = warp_from_torch(cell_inv_t, mat_t, requires_grad=track_grad)
    wp_mesh = warp_from_torch(mesh, scalar_t, requires_grad=track_grad)

    with WarpAutogradContextManager(track_grad) as tape:
        wp.launch(
            _batch_bspline_spread_kernel_overload[scalar_t],
            dim=launch_dims,
            inputs=[
                wp_positions,
                wp_values,
                wp_batch_idx,
                wp_cell_inv_t,
                wp.int32(spline_order),
            ],
            outputs=[wp_mesh],
            device=wp_device,
        )

    if not track_grad:
        return mesh

    attach_for_backward(
        mesh,
        tape=tape,
        mesh=wp_mesh,
        positions=wp_positions,
        values=wp_values,
        cell_inv_t=wp_cell_inv_t,
    )
    return mesh


@warp_custom_op(
    name="alchemiops::_batch_spline_gather",
    outputs=[
        OutputSpec(
            "values",
            wp.array(dtype=Any),
            # One scalar per atom.
            lambda pos, *_: (pos.shape[0],),
        ),
    ],
    grad_arrays=["values", "positions", "mesh", "cell_inv_t"],
)
def _batch_spline_gather(
    positions: torch.Tensor,
    mesh: torch.Tensor,
    batch_idx: torch.Tensor,
    cell: torch.Tensor,
    spline_order: int,
    cell_inv_t: torch.Tensor | None = None,
) -> torch.Tensor:
    """Internal: batched spline gather of mesh values at atom positions.

    A zero-initialised ``(num_atoms,)`` tensor is allocated in the dtype of
    ``positions`` and filled by the dtype-matched batch gather kernel, launched
    over a ``(num_atoms, spline_order**3)`` grid.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic positions; their dtype and device drive every other conversion.
    mesh : torch.Tensor
        Mesh to read from; cast to the dtype of ``positions`` before launch.
    batch_idx : torch.Tensor
        Per-atom system index, handed to Warp as ``int32``.
    cell : torch.Tensor
        Cell matrices; only inverted when ``cell_inv_t`` is not supplied.
    spline_order : int
        B-spline order; its cube is the second launch dimension.
    cell_inv_t : torch.Tensor | None
        Optional precomputed transpose of the inverse cell.

    Returns
    -------
    torch.Tensor
        Gathered per-atom values, with the Warp tape attached when any of
        ``positions``, ``mesh`` or ``cell`` needs gradients.
    """
    torch_dtype = positions.dtype
    scalar_t = get_wp_dtype(torch_dtype)
    vec_t = get_wp_vec_dtype(torch_dtype)
    mat_t = get_wp_mat_dtype(torch_dtype)
    wp_device = wp.device_from_torch(positions.device)

    n_atoms = positions.shape[0]
    track_grad = needs_grad(positions, mesh, cell)

    # Fall back to inverting the cell only when the caller gave no inverse.
    if cell_inv_t is None:
        cell_inv_t = torch.linalg.inv(cell).transpose(-1, -2).contiguous()

    values = torch.zeros(n_atoms, dtype=torch_dtype, device=positions.device)

    wp_positions = warp_from_torch(positions, vec_t, requires_grad=track_grad)
    wp_batch_idx = warp_from_torch(batch_idx, wp.int32)
    wp_cell_inv_t = warp_from_torch(cell_inv_t, mat_t, requires_grad=track_grad)
    wp_mesh = warp_from_torch(
        mesh.to(torch_dtype), scalar_t, requires_grad=track_grad
    )
    wp_values = warp_from_torch(values, scalar_t, requires_grad=track_grad)

    with WarpAutogradContextManager(track_grad) as tape:
        wp.launch(
            _batch_bspline_gather_kernel_overload[scalar_t],
            dim=(n_atoms, spline_order**3),
            inputs=[
                wp_positions,
                wp_batch_idx,
                wp_cell_inv_t,
                wp.int32(spline_order),
                wp_mesh,
            ],
            outputs=[wp_values],
            device=wp_device,
        )

    if not track_grad:
        return values

    attach_for_backward(
        values,
        tape=tape,
        values=wp_values,
        positions=wp_positions,
        cell_inv_t=wp_cell_inv_t,
        mesh=wp_mesh,
    )
    return values


@warp_custom_op(
    name="alchemiops::_batch_spline_gather_vec3",
    outputs=[
        OutputSpec(
            "values", wp.array(dtype=Any, ndim=2), lambda pos, *_: (pos.shape[0], 3)
        ),
    ],
    grad_arrays=[
        "values",
        "positions",
        "charges",
        "mesh",
        "cell_inv_t",
    ],
)
def _batch_spline_gather_vec3(
    positions: torch.Tensor,
    charges: torch.Tensor,
    mesh: torch.Tensor,
    batch_idx: torch.Tensor,
    cell: torch.Tensor,
    spline_order: int,
    cell_inv_t: torch.Tensor | None = None,
) -> torch.Tensor:
    """Internal: Batch vec3 spline gather with dtype flexibility.

    A zero-initialised ``(num_atoms, 3)`` tensor is allocated in the dtype of
    ``positions`` and filled by the dtype-matched batch vec3 gather kernel,
    launched over a ``(num_atoms, spline_order**3)`` grid.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic positions; their dtype and device drive every other conversion.
    charges : torch.Tensor
        Per-atom charges; cast to the dtype of ``positions`` before launch.
    mesh : torch.Tensor
        Vector-valued mesh; cast to the dtype of ``positions`` and handed to
        Warp with the vector dtype.
    batch_idx : torch.Tensor
        Per-atom system index, handed to Warp as ``int32``.
    cell : torch.Tensor
        Cell matrices; only inverted when ``cell_inv_t`` is not supplied.
    spline_order : int
        B-spline order; its cube is the second launch dimension.
    cell_inv_t : torch.Tensor | None
        Optional precomputed transpose of the inverse cell.

    Returns
    -------
    torch.Tensor
        Gathered per-atom vectors, with the Warp tape attached when any of
        ``positions``, ``charges``, ``mesh`` or ``cell`` needs gradients.
    """
    device = wp.device_from_torch(positions.device)
    input_dtype = positions.dtype
    wp_dtype = get_wp_dtype(input_dtype)
    wp_vec_dtype = get_wp_vec_dtype(input_dtype)
    wp_mat_dtype = get_wp_mat_dtype(input_dtype)

    num_atoms = positions.shape[0]
    num_points = spline_order**3
    # ``charges`` is declared in grad_arrays and attached for backward below,
    # so it must take part in the decision to record a tape; otherwise a
    # gradient requested only w.r.t. charges would never be recorded.
    needs_grad_flag = needs_grad(positions, charges, mesh, cell)

    if cell_inv_t is None:
        cell_inv = torch.linalg.inv(cell)
        cell_inv_t = cell_inv.transpose(-1, -2).contiguous()

    wp_positions = warp_from_torch(
        positions, wp_vec_dtype, requires_grad=needs_grad_flag
    )
    wp_charges = warp_from_torch(
        charges.to(input_dtype), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_batch_idx = warp_from_torch(batch_idx, wp.int32)
    wp_cell_inv_t = warp_from_torch(
        cell_inv_t, wp_mat_dtype, requires_grad=needs_grad_flag
    )
    wp_mesh = warp_from_torch(
        mesh.to(input_dtype), wp_vec_dtype, requires_grad=needs_grad_flag
    )

    values = torch.zeros((num_atoms, 3), device=positions.device, dtype=input_dtype)
    wp_values = warp_from_torch(values, wp_vec_dtype, requires_grad=needs_grad_flag)

    kernel = _batch_bspline_gather_vec3_kernel_overload[wp_dtype]

    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            kernel,
            dim=(num_atoms, num_points),
            inputs=[
                wp_positions,
                wp_charges,
                wp_batch_idx,
                wp_cell_inv_t,
                wp.int32(spline_order),
                wp_mesh,
            ],
            outputs=[wp_values],
            device=device,
        )

    if needs_grad_flag:
        attach_for_backward(
            values,
            tape=tape,
            values=wp_values,
            positions=wp_positions,
            charges=wp_charges,
            cell_inv_t=wp_cell_inv_t,
            mesh=wp_mesh,
        )
    return values


@warp_custom_op(
    name="alchemiops::_batch_spline_gather_gradient",
    outputs=[
        OutputSpec(
            "forces",
            wp.array(dtype=Any, ndim=2),
            # One 3-vector per atom.
            lambda pos, *_: (pos.shape[0], 3),
        ),
    ],
    grad_arrays=["forces", "positions", "charges", "mesh", "cell_inv_t"],
)
def _batch_spline_gather_gradient(
    positions: torch.Tensor,
    charges: torch.Tensor,
    mesh: torch.Tensor,
    batch_idx: torch.Tensor,
    cell: torch.Tensor,
    spline_order: int,
    cell_inv_t: torch.Tensor | None = None,
) -> torch.Tensor:
    """Internal: batched spline gather-gradient producing per-atom 3-vectors.

    A zero-initialised ``(num_atoms, 3)`` tensor is allocated in the dtype of
    ``positions`` and filled by the dtype-matched batch gather-gradient kernel,
    launched over a ``(num_atoms, spline_order**3)`` grid.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic positions; their dtype and device drive every other conversion.
    charges : torch.Tensor
        Per-atom charges; cast to the dtype of ``positions`` before launch.
    mesh : torch.Tensor
        Mesh to read from; cast to the dtype of ``positions`` before launch.
    batch_idx : torch.Tensor
        Per-atom system index, handed to Warp as ``int32``.
    cell : torch.Tensor
        Cell matrices; only inverted when ``cell_inv_t`` is not supplied.
    spline_order : int
        B-spline order; its cube is the second launch dimension.
    cell_inv_t : torch.Tensor | None
        Optional precomputed transpose of the inverse cell.

    Returns
    -------
    torch.Tensor
        Per-atom force vectors, with the Warp tape attached when any of
        ``positions``, ``charges``, ``mesh`` or ``cell`` needs gradients.
    """
    torch_dtype = positions.dtype
    scalar_t = get_wp_dtype(torch_dtype)
    vec_t = get_wp_vec_dtype(torch_dtype)
    mat_t = get_wp_mat_dtype(torch_dtype)
    wp_device = wp.device_from_torch(positions.device)

    n_atoms = positions.shape[0]
    track_grad = needs_grad(positions, charges, mesh, cell)

    # Fall back to inverting the cell only when the caller gave no inverse.
    if cell_inv_t is None:
        cell_inv_t = torch.linalg.inv(cell).transpose(-1, -2).contiguous()

    forces = torch.zeros((n_atoms, 3), dtype=torch_dtype, device=positions.device)

    wp_positions = warp_from_torch(positions, vec_t, requires_grad=track_grad)
    wp_charges = warp_from_torch(
        charges.to(torch_dtype), scalar_t, requires_grad=track_grad
    )
    wp_batch_idx = warp_from_torch(batch_idx, wp.int32)
    wp_cell_inv_t = warp_from_torch(cell_inv_t, mat_t, requires_grad=track_grad)
    wp_mesh = warp_from_torch(
        mesh.to(torch_dtype), scalar_t, requires_grad=track_grad
    )
    wp_forces = warp_from_torch(forces, vec_t, requires_grad=track_grad)

    with WarpAutogradContextManager(track_grad) as tape:
        wp.launch(
            _batch_bspline_gather_gradient_kernel_overload[scalar_t],
            dim=(n_atoms, spline_order**3),
            inputs=[
                wp_positions,
                wp_charges,
                wp_batch_idx,
                wp_cell_inv_t,
                wp.int32(spline_order),
                wp_mesh,
            ],
            outputs=[wp_forces],
            device=wp_device,
        )

    if not track_grad:
        return forces

    attach_for_backward(
        forces,
        tape=tape,
        forces=wp_forces,
        positions=wp_positions,
        charges=wp_charges,
        cell_inv_t=wp_cell_inv_t,
        mesh=wp_mesh,
    )
    return forces


###########################################################################################
########################### Internal Custom Ops: Multi-Channel (Single-System) #############
###########################################################################################


@warp_custom_op(
    name="alchemiops::_spline_spread_channels",
    outputs=[
        OutputSpec(
            "mesh",
            wp.array(dtype=Any, ndim=4),
            lambda pos,
            values,
            cell,
            num_channels,
            mesh_nx,
            mesh_ny,
            mesh_nz,
            spline_order,
            *_: (num_channels, mesh_nx, mesh_ny, mesh_nz),
        ),
    ],
    grad_arrays=[
        "mesh",
        "positions",
        "values",
        "cell_inv_t",
    ],
)
def _spline_spread_channels(
    positions: torch.Tensor,
    values: torch.Tensor,
    cell: torch.Tensor,
    num_channels: int,
    mesh_nx: int,
    mesh_ny: int,
    mesh_nz: int,
    spline_order: int,
) -> torch.Tensor:
    """Internal: Single-system multi-channel spline spread with dtype flexibility.

    A zero-initialised ``(num_channels, mesh_nx, mesh_ny, mesh_nz)`` mesh is
    allocated in the dtype of ``positions`` and filled by the dtype-matched
    multi-channel spread kernel, launched over a
    ``(num_atoms, spline_order**3)`` grid.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic positions; their dtype and device drive every other conversion.
    values : torch.Tensor
        Multi-channel per-atom values; cast to the dtype of ``positions``.
    cell : torch.Tensor
        Cell matrix; a 2-D cell gets a leading axis added before inversion.
    num_channels, mesh_nx, mesh_ny, mesh_nz : int
        Extents of the output mesh.
    spline_order : int
        B-spline order; its cube is the second launch dimension.

    Returns
    -------
    torch.Tensor
        The filled mesh, with the Warp tape attached when any of
        ``positions``, ``values`` or ``cell`` needs gradients.

    Raises
    ------
    torch.linalg.LinAlgError
        If ``cell`` is not invertible.
    """
    device = wp.device_from_torch(positions.device)
    input_dtype = positions.dtype
    wp_dtype = get_wp_dtype(input_dtype)
    wp_vec_dtype = get_wp_vec_dtype(input_dtype)
    wp_mat_dtype = get_wp_mat_dtype(input_dtype)

    num_atoms = positions.shape[0]
    num_points = spline_order**3
    needs_grad_flag = needs_grad(positions, values, cell)

    if cell.dim() == 2:
        cell = cell.unsqueeze(0)

    # Use the checked inverse, like the sibling ops in this module: inv_ex
    # with its status output discarded would silently hand a garbage inverse
    # to the kernel when the cell is singular.
    cell_inv = torch.linalg.inv(cell)
    cell_inv_t = cell_inv.transpose(-1, -2).contiguous()

    wp_positions = warp_from_torch(
        positions, wp_vec_dtype, requires_grad=needs_grad_flag
    )
    wp_values = warp_from_torch(
        values.to(input_dtype), wp_dtype, requires_grad=needs_grad_flag
    )
    wp_cell_inv_t = warp_from_torch(
        cell_inv_t, wp_mat_dtype, requires_grad=needs_grad_flag
    )

    mesh = torch.zeros(
        (num_channels, mesh_nx, mesh_ny, mesh_nz),
        device=positions.device,
        dtype=input_dtype,
    )
    wp_mesh = warp_from_torch(mesh, wp_dtype, requires_grad=needs_grad_flag)

    kernel = _bspline_spread_channels_kernel_overload[wp_dtype]

    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            kernel,
            dim=(num_atoms, num_points),
            inputs=[wp_positions, wp_values, wp_cell_inv_t, wp.int32(spline_order)],
            outputs=[wp_mesh],
            device=device,
        )

    if needs_grad_flag:
        attach_for_backward(
            mesh,
            tape=tape,
            mesh=wp_mesh,
            positions=wp_positions,
            values=wp_values,
            cell_inv_t=wp_cell_inv_t,
        )
    return mesh


@warp_custom_op(
    name="alchemiops::_spline_gather_channels",
    outputs=[
        OutputSpec(
            "values",
            wp.array(dtype=Any, ndim=2),
            # The single-system mesh is (C, nx, ny, nz): the channel count is
            # axis 0, matching the (num_atoms, mesh.shape[0]) tensor the body
            # allocates. Axis 1 would be nx.
            lambda pos, mesh, *_: (pos.shape[0], mesh.shape[0]),
        ),
    ],
    grad_arrays=[
        "values",
        "positions",
        "mesh",
        "cell_inv_t",
    ],
)
def _spline_gather_channels(
    positions: torch.Tensor,
    mesh: torch.Tensor,
    cell: torch.Tensor,
    spline_order: int,
) -> torch.Tensor:
    """Internal: Single-system multi-channel spline gather with dtype flexibility.

    A zero-initialised ``(num_atoms, mesh.shape[0])`` tensor is allocated in the
    dtype of ``positions`` and filled by the dtype-matched multi-channel gather
    kernel, launched over a ``(num_atoms, spline_order**3)`` grid.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic positions; their dtype and device drive every other conversion.
    mesh : torch.Tensor
        Multi-channel mesh whose leading axis is the channel axis; cast to the
        dtype of ``positions`` before launch.
    cell : torch.Tensor
        Cell matrix; a 2-D cell gets a leading axis added before inversion.
    spline_order : int
        B-spline order; its cube is the second launch dimension.

    Returns
    -------
    torch.Tensor
        Gathered per-atom, per-channel values, with the Warp tape attached when
        any of ``positions``, ``mesh`` or ``cell`` needs gradients.
    """
    device = wp.device_from_torch(positions.device)
    input_dtype = positions.dtype
    wp_dtype = get_wp_dtype(input_dtype)
    wp_vec_dtype = get_wp_vec_dtype(input_dtype)
    wp_mat_dtype = get_wp_mat_dtype(input_dtype)

    num_atoms = positions.shape[0]
    num_channels = mesh.shape[0]
    num_points = spline_order**3
    needs_grad_flag = needs_grad(positions, mesh, cell)

    if cell.dim() == 2:
        cell = cell.unsqueeze(0)

    cell_inv = torch.linalg.inv(cell)
    cell_inv_t = cell_inv.transpose(-1, -2).contiguous()

    wp_positions = warp_from_torch(
        positions, wp_vec_dtype, requires_grad=needs_grad_flag
    )
    wp_cell_inv_t = warp_from_torch(
        cell_inv_t, wp_mat_dtype, requires_grad=needs_grad_flag
    )
    wp_mesh = warp_from_torch(
        mesh.to(input_dtype), wp_dtype, requires_grad=needs_grad_flag
    )

    values = torch.zeros(
        (num_atoms, num_channels), device=positions.device, dtype=input_dtype
    )
    wp_values = warp_from_torch(values, wp_dtype, requires_grad=needs_grad_flag)

    kernel = _bspline_gather_channels_kernel_overload[wp_dtype]

    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            kernel,
            dim=(num_atoms, num_points),
            inputs=[wp_positions, wp_cell_inv_t, wp.int32(spline_order), wp_mesh],
            outputs=[wp_values],
            device=device,
        )

    if needs_grad_flag:
        attach_for_backward(
            values,
            tape=tape,
            values=wp_values,
            positions=wp_positions,
            cell_inv_t=wp_cell_inv_t,
            mesh=wp_mesh,
        )
    return values


###########################################################################################
########################### Internal Custom Ops: Multi-Channel (Batch) #####################
###########################################################################################


def _batch_spline_spread_channels_output_shape(
    position,
    values,
    batch_idx,
    cell,
    num_systems,
    num_channels,
    mesh_nx,
    mesh_ny,
    mesh_nz,
    spline_order,
):
    return (num_systems, num_channels, mesh_nx, mesh_ny, mesh_nz)


@warp_custom_op(
    name="alchemiops::_batch_spline_spread_channels",
    outputs=[
        OutputSpec(
            "mesh",
            # NOTE(review): the spec declares ndim=4 while the shape function
            # returns five dims (B, C, nx, ny, nz) - confirm this is intended.
            wp.array(dtype=Any, ndim=4),
            _batch_spline_spread_channels_output_shape,
        ),
    ],
    grad_arrays=["mesh", "positions", "values", "cell_inv_t"],
)
def _batch_spline_spread_channels(
    positions: torch.Tensor,
    values: torch.Tensor,
    batch_idx: torch.Tensor,
    cell: torch.Tensor,
    num_systems: int,
    num_channels: int,
    mesh_nx: int,
    mesh_ny: int,
    mesh_nz: int,
    spline_order: int,
) -> torch.Tensor:
    """Internal: batched multi-channel spline spread onto per-system meshes.

    The kernel writes into a flattened ``(B*C, nx, ny, nz)`` buffer, which is
    then viewed as ``(B, C, nx, ny, nz)`` for the caller. The launch grid is
    ``(num_atoms, spline_order**3)``.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic positions; their dtype and device drive every other conversion.
    values : torch.Tensor
        Multi-channel per-atom values; cast to the dtype of ``positions``.
    batch_idx : torch.Tensor
        Per-atom system index, handed to Warp as ``int32``.
    cell : torch.Tensor
        Cell matrices, inverted here.
    num_systems, num_channels, mesh_nx, mesh_ny, mesh_nz : int
        Extents of the output mesh.
    spline_order : int
        B-spline order; its cube is the second launch dimension.

    Returns
    -------
    torch.Tensor
        The ``(num_systems, num_channels, mesh_nx, mesh_ny, mesh_nz)`` mesh,
        with the Warp tape attached when any of ``positions``, ``values`` or
        ``cell`` needs gradients.
    """
    torch_dtype = positions.dtype
    scalar_t = get_wp_dtype(torch_dtype)
    vec_t = get_wp_vec_dtype(torch_dtype)
    mat_t = get_wp_mat_dtype(torch_dtype)
    wp_device = wp.device_from_torch(positions.device)

    track_grad = needs_grad(positions, values, cell)
    launch_dims = (positions.shape[0], spline_order**3)

    cell_inv_t = torch.linalg.inv(cell).transpose(-1, -2).contiguous()

    # Systems and channels share one leading axis because of the Warp 4D limit.
    flat_shape = (num_systems * num_channels, mesh_nx, mesh_ny, mesh_nz)
    mesh_flat = torch.zeros(flat_shape, dtype=torch_dtype, device=positions.device)

    wp_positions = warp_from_torch(positions, vec_t, requires_grad=track_grad)
    wp_values = warp_from_torch(
        values.to(torch_dtype), scalar_t, requires_grad=track_grad
    )
    wp_batch_idx = warp_from_torch(batch_idx, wp.int32)
    wp_cell_inv_t = warp_from_torch(cell_inv_t, mat_t, requires_grad=track_grad)
    wp_mesh = warp_from_torch(mesh_flat, scalar_t, requires_grad=track_grad)

    with WarpAutogradContextManager(track_grad) as tape:
        wp.launch(
            _batch_bspline_spread_channels_kernel_overload[scalar_t],
            dim=launch_dims,
            inputs=[
                wp_positions,
                wp_values,
                wp_batch_idx,
                wp_cell_inv_t,
                wp.int32(spline_order),
                wp.int32(num_channels),
            ],
            outputs=[wp_mesh],
            device=wp_device,
        )

    # Split the fused leading axis back into (B, C) without copying.
    mesh = mesh_flat.view(num_systems, num_channels, mesh_nx, mesh_ny, mesh_nz)

    if not track_grad:
        return mesh

    attach_for_backward(
        mesh,
        tape=tape,
        mesh=wp_mesh,
        positions=wp_positions,
        values=wp_values,
        cell_inv_t=wp_cell_inv_t,
    )
    return mesh


@warp_custom_op(
    name="alchemiops::_batch_spline_gather_channels",
    outputs=[
        OutputSpec(
            "values",
            wp.array(dtype=Any, ndim=2),
            lambda pos, mesh, *_: (pos.shape[0], mesh.shape[1]),
        ),
    ],
    grad_arrays=[
        "values",
        "positions",
        "mesh",
        "cell_inv_t",
    ],
)
def _batch_spline_gather_channels(
    positions: torch.Tensor,
    mesh: torch.Tensor,
    batch_idx: torch.Tensor,
    cell: torch.Tensor,
    spline_order: int,
) -> torch.Tensor:
    """Internal: Batch multi-channel spline gather with dtype flexibility.

    The ``(B, C, nx, ny, nz)`` mesh is flattened to ``(B*C, nx, ny, nz)`` for
    the kernel, which fills a zero-initialised ``(num_atoms, C)`` tensor. The
    launch grid is ``(num_atoms, spline_order**3)``.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic positions; their dtype and device drive every other conversion.
    mesh : torch.Tensor
        Five-dimensional ``(B, C, nx, ny, nz)`` mesh; cast to the dtype of
        ``positions``. It does not need to be contiguous.
    batch_idx : torch.Tensor
        Per-atom system index, handed to Warp as ``int32``.
    cell : torch.Tensor
        Cell matrices, inverted here.
    spline_order : int
        B-spline order; its cube is the second launch dimension.

    Returns
    -------
    torch.Tensor
        Gathered per-atom, per-channel values, with the Warp tape attached when
        any of ``positions``, ``mesh`` or ``cell`` needs gradients.
    """
    device = wp.device_from_torch(positions.device)
    input_dtype = positions.dtype
    wp_dtype = get_wp_dtype(input_dtype)
    wp_vec_dtype = get_wp_vec_dtype(input_dtype)
    wp_mat_dtype = get_wp_mat_dtype(input_dtype)

    num_atoms = positions.shape[0]
    num_systems = mesh.shape[0]  # (B, C, nx, ny, nz)
    num_channels = mesh.shape[1]
    mesh_nx, mesh_ny, mesh_nz = mesh.shape[2], mesh.shape[3], mesh.shape[4]
    num_points = spline_order**3
    needs_grad_flag = needs_grad(positions, mesh, cell)

    cell_inv = torch.linalg.inv(cell)
    cell_inv_t = cell_inv.transpose(-1, -2).contiguous()

    wp_positions = warp_from_torch(
        positions, wp_vec_dtype, requires_grad=needs_grad_flag
    )
    wp_batch_idx = warp_from_torch(batch_idx, wp.int32)
    wp_cell_inv_t = warp_from_torch(
        cell_inv_t, wp_mat_dtype, requires_grad=needs_grad_flag
    )

    # Flatten mesh from (B, C, nx, ny, nz) to (B*C, nx, ny, nz) for Warp 4D limit.
    # reshape (not view) so a mesh whose strides cannot be merged, such as a
    # permuted or sliced tensor, is copied instead of raising. Whenever a view
    # is possible, reshape returns that same view.
    mesh_flat = (
        mesh.to(input_dtype)
        .reshape(num_systems * num_channels, mesh_nx, mesh_ny, mesh_nz)
        .contiguous()
    )
    wp_mesh = warp_from_torch(mesh_flat, wp_dtype, requires_grad=needs_grad_flag)

    values = torch.zeros(
        (num_atoms, num_channels), device=positions.device, dtype=input_dtype
    )
    wp_values = warp_from_torch(values, wp_dtype, requires_grad=needs_grad_flag)

    kernel = _batch_bspline_gather_channels_kernel_overload[wp_dtype]

    with WarpAutogradContextManager(needs_grad_flag) as tape:
        wp.launch(
            kernel,
            dim=(num_atoms, num_points),
            inputs=[
                wp_positions,
                wp_batch_idx,
                wp_cell_inv_t,
                wp.int32(spline_order),
                wp.int32(num_channels),
                wp_mesh,
            ],
            outputs=[wp_values],
            device=device,
        )

    if needs_grad_flag:
        attach_for_backward(
            values,
            tape=tape,
            values=wp_values,
            positions=wp_positions,
            cell_inv_t=wp_cell_inv_t,
            mesh=wp_mesh,
        )
    return values


###########################################################################################
########################### Unified Public API #############################################
###########################################################################################


def spline_spread(
    positions: torch.Tensor,
    values: torch.Tensor,
    cell: torch.Tensor,
    mesh_dims: tuple[int, int, int],
    spline_order: int = 4,
    batch_idx: torch.Tensor | None = None,
    cell_inv_t: torch.Tensor | None = None,
) -> torch.Tensor:
    """Spread values from atoms to mesh grid using B-spline interpolation.

    Parameters
    ----------
    positions : torch.Tensor, shape (N, 3)
        Atomic positions.
    values : torch.Tensor, shape (N,)
        Values to spread (e.g., charges).
    cell : torch.Tensor, shape (3, 3), (1, 3, 3), or (B, 3, 3)
        Unit cell matrix. For batched, shape should be (B, 3, 3). A 2-D
        (3, 3) cell passed together with ``batch_idx`` is replicated for
        every system, the system count being ``batch_idx.max() + 1``.
    mesh_dims : tuple[int, int, int]
        Mesh dimensions (nx, ny, nz).
    spline_order : int, default=4
        B-spline order (1-4, where 4=cubic).
    batch_idx : torch.Tensor | None, shape (N,), dtype=int32, default=None
        System index for each atom. If None, uses single-system kernel.
    cell_inv_t : torch.Tensor | None, default=None
        Precomputed transpose of cell inverse. If provided, skips inverse
        computation. Shape (1, 3, 3) for single-system or (B, 3, 3) for batch.

    Returns
    -------
    mesh : torch.Tensor
        For single-system: shape (nx, ny, nz)
        For batch: shape (B, nx, ny, nz)
    """
    mesh_nx, mesh_ny, mesh_nz = mesh_dims

    if batch_idx is None:
        return _spline_spread(
            positions, values, cell, mesh_nx, mesh_ny, mesh_nz, spline_order, cell_inv_t
        )

    if cell.dim() == 2:
        # For a single (3, 3) cell, cell.shape[0] is the matrix row count (3),
        # not the number of systems, so the count must come from batch_idx.
        num_systems = int(batch_idx.max().item()) + 1
        cell = cell.unsqueeze(0).expand(num_systems, -1, -1).contiguous()
    else:
        num_systems = cell.shape[0]

    return _batch_spline_spread(
        positions,
        values,
        batch_idx,
        cell,
        num_systems,
        mesh_nx,
        mesh_ny,
        mesh_nz,
        spline_order,
        cell_inv_t,
    )
def spline_gather(
    positions: torch.Tensor,
    mesh: torch.Tensor,
    cell: torch.Tensor,
    spline_order: int = 4,
    batch_idx: torch.Tensor | None = None,
    cell_inv_t: torch.Tensor | None = None,
) -> torch.Tensor:
    """Interpolate mesh values at atomic positions with B-spline weights.

    Parameters
    ----------
    positions : torch.Tensor, shape (N, 3)
        Atomic positions.
    mesh : torch.Tensor
        Shape (nx, ny, nz) for a single system, (B, nx, ny, nz) for a batch.
    cell : torch.Tensor, shape (3, 3), (1, 3, 3), or (B, 3, 3)
        Unit cell matrix. A 2-D cell given with ``batch_idx`` is replicated
        ``batch_idx.max() + 1`` times.
    spline_order : int, default=4
        B-spline order.
    batch_idx : torch.Tensor | None, shape (N,), dtype=int32, default=None
        System index of each atom. ``None`` selects the single-system kernel.
    cell_inv_t : torch.Tensor | None, default=None
        Precomputed transpose of the cell inverse; when given, the inverse is
        not recomputed. Shape (1, 3, 3) for single-system or (B, 3, 3) for
        batch.

    Returns
    -------
    values : torch.Tensor, shape (N,)
        Interpolated values at atomic positions.
    """
    if batch_idx is not None:
        if cell.dim() == 2:
            # The batch op expects one cell per system.
            n_systems = int(batch_idx.max().item()) + 1
            cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous()
        return _batch_spline_gather(
            positions, mesh, batch_idx, cell, spline_order, cell_inv_t
        )

    return _spline_gather(positions, mesh, cell, spline_order, cell_inv_t)
def spline_gather_vec3(
    positions: torch.Tensor,
    charges: torch.Tensor,
    mesh: torch.Tensor,
    cell: torch.Tensor,
    spline_order: int = 4,
    batch_idx: torch.Tensor | None = None,
    cell_inv_t: torch.Tensor | None = None,
) -> torch.Tensor:
    """Interpolate a 3D vector field from the mesh at atomic positions.

    Useful for vector fields such as electric fields.

    Parameters
    ----------
    positions : torch.Tensor, shape (N, 3)
        Atomic positions.
    charges : torch.Tensor
        Per-atom charges, forwarded unchanged to the gather op (presumably
        shape (N,); how the kernel applies them is not visible here).
    mesh : torch.Tensor
        Shape (nx, ny, nz, 3) for a single system, (B, nx, ny, nz, 3) for a
        batch.
    cell : torch.Tensor, shape (3, 3), (1, 3, 3), or (B, 3, 3)
        Unit cell matrix. A 2-D cell given with ``batch_idx`` is replicated
        ``batch_idx.max() + 1`` times.
    spline_order : int, default=4
        B-spline order.
    batch_idx : torch.Tensor | None, shape (N,), dtype=int32, default=None
        System index of each atom. ``None`` selects the single-system kernel.
    cell_inv_t : torch.Tensor | None, default=None
        Precomputed transpose of the cell inverse; when given, the inverse is
        not recomputed. Shape (1, 3, 3) for single-system or (B, 3, 3) for
        batch.

    Returns
    -------
    vectors : torch.Tensor, shape (N, 3)
        Interpolated 3D vectors at atomic positions.
    """
    if batch_idx is not None:
        if cell.dim() == 2:
            # The batch op expects one cell per system.
            n_systems = int(batch_idx.max().item()) + 1
            cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous()
        return _batch_spline_gather_vec3(
            positions, charges, mesh, batch_idx, cell, spline_order, cell_inv_t
        )

    return _spline_gather_vec3(
        positions, charges, mesh, cell, spline_order, cell_inv_t
    )
def spline_gather_gradient(
    positions: torch.Tensor,
    charges: torch.Tensor,
    mesh: torch.Tensor,
    cell: torch.Tensor,
    spline_order: int = 4,
    batch_idx: torch.Tensor | None = None,
    cell_inv_t: torch.Tensor | None = None,
) -> torch.Tensor:
    """Gather the mesh gradient at atomic positions via B-spline derivatives.

    Computes forces:

    .. math::
        F_i = -q_i \\sum_g \\phi(g) \\nabla w(r_i, g)

    Parameters
    ----------
    positions : torch.Tensor, shape (N, 3)
        Atomic positions.
    charges : torch.Tensor, shape (N,)
        Atomic charges.
    mesh : torch.Tensor
        Shape (nx, ny, nz) for a single system, (B, nx, ny, nz) for a batch.
    cell : torch.Tensor, shape (3, 3), (1, 3, 3), or (B, 3, 3)
        Unit cell matrix. A 2-D cell given with ``batch_idx`` is replicated
        ``batch_idx.max() + 1`` times.
    spline_order : int, default=4
        B-spline order.
    batch_idx : torch.Tensor | None, shape (N,), dtype=int32, default=None
        System index of each atom. ``None`` selects the single-system kernel.
    cell_inv_t : torch.Tensor | None, default=None
        Precomputed transpose of the cell inverse; when given, the inverse is
        not recomputed. Shape (1, 3, 3) for single-system or (B, 3, 3) for
        batch.

    Returns
    -------
    forces : torch.Tensor, shape (N, 3)
        Forces on atoms.
    """
    if batch_idx is not None:
        if cell.dim() == 2:
            # The batch op expects one cell per system.
            n_systems = int(batch_idx.max().item()) + 1
            cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous()
        return _batch_spline_gather_gradient(
            positions, charges, mesh, batch_idx, cell, spline_order, cell_inv_t
        )

    return _spline_gather_gradient(
        positions, charges, mesh, cell, spline_order, cell_inv_t
    )
def spline_spread_channels(
    positions: torch.Tensor,
    values: torch.Tensor,
    cell: torch.Tensor,
    mesh_dims: tuple[int, int, int],
    spline_order: int = 4,
    batch_idx: torch.Tensor | None = None,
) -> torch.Tensor:
    """Spread multi-channel values from atoms to mesh grid using B-spline interpolation.

    This is useful for spreading multipole coefficients (e.g., 9 channels for
    L_max=2: 1 monopole + 3 dipoles + 5 quadrupoles).

    Parameters
    ----------
    positions : torch.Tensor, shape (N, 3)
        Atomic positions.
    values : torch.Tensor, shape (N, C)
        Multi-channel values to spread. C is the number of channels and is
        read from ``values.shape[1]``.
    cell : torch.Tensor, shape (3, 3), (1, 3, 3), or (B, 3, 3)
        Unit cell matrix. For batched, shape should be (B, 3, 3); a 2D cell
        is replicated once per system.
    mesh_dims : tuple[int, int, int]
        Mesh dimensions (nx, ny, nz).
    spline_order : int, default=4
        B-spline order (1-4, where 4=cubic).
    batch_idx : torch.Tensor | None, shape (N,), dtype=int32, default=None
        System index for each atom. If None, uses single-system kernel.

    Returns
    -------
    mesh : torch.Tensor
        For single-system: shape (C, nx, ny, nz)
        For batch: shape (B, C, nx, ny, nz)

    Example
    -------
    >>> # Spread 9-channel multipole coefficients
    >>> multipoles = torch.randn(100, 9, dtype=torch.float64, device="cuda")
    >>> mesh = spline_spread_channels(positions, multipoles, cell, (16, 16, 16))
    >>> print(mesh.shape)  # (9, 16, 16, 16)
    """
    mesh_nx, mesh_ny, mesh_nz = mesh_dims
    num_channels = values.shape[1]

    if batch_idx is None:
        return _spline_spread_channels(
            positions,
            values,
            cell,
            num_channels,
            mesh_nx,
            mesh_ny,
            mesh_nz,
            spline_order,
        )

    if cell.dim() == 2:
        # A single shared cell: infer the system count from the largest
        # batch index and replicate the cell for every system.
        num_systems = int(batch_idx.max().item()) + 1
        cell = cell.unsqueeze(0).expand(num_systems, -1, -1).contiguous()
    else:
        # An already-batched cell defines the system count directly.
        num_systems = cell.shape[0]

    return _batch_spline_spread_channels(
        positions,
        values,
        batch_idx,
        cell,
        num_systems,
        num_channels,
        mesh_nx,
        mesh_ny,
        mesh_nz,
        spline_order,
    )


def spline_gather_channels(
    positions: torch.Tensor,
    mesh: torch.Tensor,
    cell: torch.Tensor,
    spline_order: int = 4,
    batch_idx: torch.Tensor | None = None,
) -> torch.Tensor:
    """Gather multi-channel values from mesh to atoms using B-spline interpolation.

    This is the inverse of spline_spread_channels.

    Parameters
    ----------
    positions : torch.Tensor, shape (N, 3)
        Atomic positions.
    mesh : torch.Tensor
        For single-system: shape (C, nx, ny, nz)
        For batch: shape (B, C, nx, ny, nz)
    cell : torch.Tensor, shape (3, 3), (1, 3, 3), or (B, 3, 3)
        Unit cell matrix. In batch mode a 2D cell is replicated once per
        system.
    spline_order : int, default=4
        B-spline order.
    batch_idx : torch.Tensor | None, shape (N,), dtype=int32, default=None
        System index for each atom. If None, uses single-system kernel.

    Returns
    -------
    values : torch.Tensor, shape (N, C)
        Interpolated multi-channel values at atomic positions.

    Example
    -------
    >>> # Gather 9-channel potential from mesh
    >>> potential_mesh = torch.randn(9, 16, 16, 16, dtype=torch.float64, device="cuda")
    >>> potentials = spline_gather_channels(positions, potential_mesh, cell)
    >>> print(potentials.shape)  # (100, 9)
    """
    if batch_idx is None:
        return _spline_gather_channels(positions, mesh, cell, spline_order)

    # Ensure cell is 3D for batch operations
    if cell.dim() == 2:
        num_systems = int(batch_idx.max().item()) + 1
        cell = cell.unsqueeze(0).expand(num_systems, -1, -1).contiguous()
    return _batch_spline_gather_channels(
        positions, mesh, batch_idx, cell, spline_order
    )


###########################################################################################
########################### Deconvolution Functions #######################################
###########################################################################################


def _fft_integer_frequencies(n: int, device) -> torch.Tensor:
    """Return the FFT frequency indices of an ``n``-point grid as exact float64 integers.

    The ordering is that of ``torch.fft.fftfreq``. Rounding removes the
    last-bit error that ``(k / n) * n`` can otherwise leave behind.
    """
    return torch.round(torch.fft.fftfreq(n, dtype=torch.float64, device=device) * n)


def _bspline_modulus(k: torch.Tensor, n: int, order: int) -> torch.Tensor:
    """Compute the squared modulus of the B-spline Fourier coefficient.

    Following Essmann et al. (1995), the coefficient for frequency index
    :math:`k` on a grid of :math:`n` points is

    .. math::
        b(k) = \\sum_{j=0}^{p-1} M_p(j+1) \\exp(2\\pi i j k / n)

    where :math:`p` is the spline order and :math:`M_p` is the cardinal
    B-spline. All arithmetic is done in float64.

    Parameters
    ----------
    k : torch.Tensor
        Frequency indices (integer-valued, any dtype).
    n : int
        Grid dimension.
    order : int
        B-spline order.

    Returns
    -------
    torch.Tensor
        :math:`|b(k)|^2` with the same shape as ``k`` and dtype float64.
        Entries with ``k == 0`` are exactly 1.
    """
    # Promote explicitly: multiplying a float32 tensor by a 0-dim float64
    # scalar tensor would keep float32 and silently lose precision.
    w = k.to(torch.float64) * (2.0 * math.pi / n)

    # M_order(1), ..., M_order(order)
    m_values = _compute_bspline_coefficients(order, k.device)

    # phase[..., j] = w * j for j = 0, ..., order-1
    j = torch.arange(order, dtype=torch.float64, device=k.device)
    phase = w.unsqueeze(-1) * j

    # b(k) = Σ_j M_order(j+1) * exp(i w j), split into real and imaginary parts
    b_real = (m_values * torch.cos(phase)).sum(dim=-1)
    b_imag = (m_values * torch.sin(phase)).sum(dim=-1)
    b_sq = b_real**2 + b_imag**2

    # The k = 0 limit is 1; pin it exactly rather than rely on rounding.
    return torch.where(k != 0, b_sq, torch.ones_like(b_sq))


def _compute_bspline_coefficients(order: int, device) -> torch.Tensor:
    """Compute B-spline basis function values at integer points.

    For a B-spline of order n, we need M_n(1), M_n(2), ..., M_n(n).
    These are used in the Fourier transform computation.

    The values are built with the recursion

        M_n(u) = u/(n-1) * M_{n-1}(u) + (n-u)/(n-1) * M_{n-1}(u-1)

    carried out on integer numerators over the common denominator (n-1)!,
    so each returned value is a single correctly rounded division. For
    n >= 2 the last entry M_n(n) is 0 (end of the spline support); for
    example order 4 gives [1/6, 4/6, 1/6, 0].

    Parameters
    ----------
    order : int
        B-spline order (>= 1).
    device
        PyTorch device.

    Returns
    -------
    torch.Tensor
        B-spline values [M_n(1), M_n(2), ..., M_n(n)], dtype float64,
        shape (order,).

    Raises
    ------
    ValueError
        If ``order`` is smaller than 1.
    """
    if order < 1:
        raise ValueError(f"B-spline order must be >= 1, got {order}")

    # Order 1 base case: the single coefficient is 1.
    numerators = [1]
    for n in range(2, order + 1):
        prev = numerators  # numerators of M_{n-1}(1..n-1) over (n-2)!
        numerators = [
            (j + 1) * (prev[j] if j < n - 1 else 0)
            + (n - 1 - j) * (prev[j - 1] if j > 0 else 0)
            for j in range(n)
        ]

    denominator = math.factorial(order - 1)
    return torch.tensor(
        [a / denominator for a in numerators],
        dtype=torch.float64,
        device=device,
    )


def compute_bspline_deconvolution(
    mesh_dims: tuple[int, int, int],
    spline_order: int = 4,
    device=None,
) -> torch.Tensor:
    """Compute B-spline deconvolution factors for Fourier space correction.

    In FFT-based methods (like PME), the B-spline interpolation introduces
    smoothing in the charge distribution. This function computes the
    deconvolution factors to correct for this smoothing in Fourier space.

    The correction is: mesh_corrected_k = mesh_k * deconv

    Parameters
    ----------
    mesh_dims : tuple[int, int, int]
        Mesh dimensions (nx, ny, nz).
    spline_order : int, default=4
        B-spline order.
    device : torch.device, optional
        Device for the output tensor. Default: CPU.

    Returns
    -------
    deconv : torch.Tensor, shape (nx, ny, nz)
        Deconvolution factors (float64). Multiply with FFT of mesh to correct.

    Example
    -------
    >>> deconv = compute_bspline_deconvolution((16, 16, 16), spline_order=4)
    >>> mesh_fft = torch.fft.fftn(charge_mesh)
    >>> mesh_corrected_fft = mesh_fft * deconv
    >>> charge_mesh_corrected = torch.fft.ifftn(mesh_corrected_fft).real

    Notes
    -----
    The deconvolution factor for a given k-vector is:

    .. math::
        D(k_x, k_y, k_z) = \\frac{1}{|b(k_x)|^2 \\cdot |b(k_y)|^2 \\cdot |b(k_z)|^2}

    where :math:`b(k)` is the Fourier transform of the 1D B-spline.

    For efficiency, this uses the separable property of the 3D B-spline.
    The product of moduli is clamped to at least 1e-15 before inversion, so
    a vanishing modulus yields a very large factor rather than ``inf``.
    """
    if device is None:
        device = torch.device("cpu")

    nx, ny, nz = mesh_dims

    # Integer frequency indices in torch.fft.fftfreq order
    # (non-negative indices first, then the negative ones).
    kx = _fft_integer_frequencies(nx, device)
    ky = _fft_integer_frequencies(ny, device)
    kz = _fft_integer_frequencies(nz, device)

    # |b(k)|^2 for each dimension, reshaped so the product broadcasts to 3D
    bx_sq = _bspline_modulus(kx, nx, spline_order).view(nx, 1, 1)
    by_sq = _bspline_modulus(ky, ny, spline_order).view(1, ny, 1)
    bz_sq = _bspline_modulus(kz, nz, spline_order).view(1, 1, nz)

    # deconv = 1 / (bx^2 * by^2 * bz^2), guarded against division by zero
    b_sq_3d = torch.clamp(bx_sq * by_sq * bz_sq, min=1e-15)
    return 1.0 / b_sq_3d


def compute_bspline_deconvolution_1d(
    n: int,
    spline_order: int = 4,
    device=None,
) -> torch.Tensor:
    """Compute 1D B-spline deconvolution factors.

    Useful for separable operations or debugging.

    Parameters
    ----------
    n : int
        Grid dimension.
    spline_order : int, default=4
        B-spline order.
    device : torch.device, optional
        Device for the output tensor. Default: CPU.

    Returns
    -------
    deconv_1d : torch.Tensor, shape (n,)
        1D deconvolution factors (float64), i.e. ``1 / |b(k)|^2`` with the
        modulus clamped to at least 1e-15.
    """
    if device is None:
        device = torch.device("cpu")

    k = _fft_integer_frequencies(n, device)
    b_sq = torch.clamp(_bspline_modulus(k, n, spline_order), min=1e-15)
    return 1.0 / b_sq


###########################################################################################
########################### Convenience Exports ###########################################
###########################################################################################

__all__ = [
    # Unified PyTorch API (scalar)
    "spline_spread",
    "spline_gather",
    "spline_gather_vec3",
    "spline_gather_gradient",
    # Unified PyTorch API (multi-channel)
    "spline_spread_channels",
    "spline_gather_channels",
    # Deconvolution
    "compute_bspline_deconvolution",
    "compute_bspline_deconvolution_1d",
    # Warp functions (for custom kernels)
    "bspline_weight",
    "bspline_derivative",
    "bspline_weight_3d",
    "bspline_weight_gradient_3d",
    "compute_fractional_coords",
    "bspline_grid_offset",
    "wrap_grid_index",
    # Warp kernels (single-system, scalar)
    "_bspline_spread_kernel",
    "_bspline_gather_kernel",
    "_bspline_gather_vec3_kernel",
    "_bspline_gather_gradient_kernel",
    # Warp kernels (batch, scalar)
    "_batch_bspline_spread_kernel",
    "_batch_bspline_gather_kernel",
    "_batch_bspline_gather_vec3_kernel",
    "_batch_bspline_gather_gradient_kernel",
    # Warp kernels (single-system, multi-channel)
    "_bspline_spread_channels_kernel",
    "_bspline_gather_channels_kernel",
    # Warp kernels (batch, multi-channel)
    "_batch_bspline_spread_channels_kernel",
    "_batch_bspline_gather_channels_kernel",
]