# Source code for nvalchemiops.jax.interactions.electrostatics.coulomb

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JAX Coulomb electrostatics implementation.

Wraps the framework-agnostic Warp kernels from
``nvalchemiops.interactions.electrostatics.coulomb`` with JAX bindings.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
from warp.jax_experimental import jax_kernel

from nvalchemiops.interactions.electrostatics.coulomb import (
    _batch_coulomb_energy_forces_kernel,
    _batch_coulomb_energy_forces_matrix_kernel,
    _batch_coulomb_energy_kernel,
    _batch_coulomb_energy_matrix_kernel,
    _coulomb_energy_forces_kernel,
    _coulomb_energy_forces_matrix_kernel,
    _coulomb_energy_kernel,
    _coulomb_energy_matrix_kernel,
)

# Public API of this module. The ``_jax_*`` kernel wrappers defined below are
# internal helpers and are intentionally not exported.
__all__ = [
    "coulomb_energy",
    "coulomb_forces",
    "coulomb_energy_forces",
]

# ==============================================================================
# JAX Kernel Wrappers
# ==============================================================================


def _forward_only_jax_kernel(kernel, *in_out_argnames):
    """Wrap a Warp ``kernel`` with ``jax_kernel`` for forward evaluation only.

    Every name in ``in_out_argnames`` is registered as an in/out argument, and
    ``num_outputs`` is set to the number of those names. The call sites in this
    module unpack exactly that many arrays from the resulting callable.
    Adjoint generation is disabled (``enable_backward=False``).
    """
    return jax_kernel(
        kernel,
        num_outputs=len(in_out_argnames),
        in_out_argnames=list(in_out_argnames),
        enable_backward=False,
    )


# --- Neighbor List (CSR) Format ---

_jax_coulomb_energy_list = _forward_only_jax_kernel(
    _coulomb_energy_kernel, "energies"
)
_jax_coulomb_energy_forces_list = _forward_only_jax_kernel(
    _coulomb_energy_forces_kernel, "energies", "forces"
)
_jax_batch_coulomb_energy_list = _forward_only_jax_kernel(
    _batch_coulomb_energy_kernel, "energies"
)
_jax_batch_coulomb_energy_forces_list = _forward_only_jax_kernel(
    _batch_coulomb_energy_forces_kernel, "energies", "forces"
)

# --- Neighbor Matrix Format ---
# The matrix variants use ``atomic_energies`` / ``atomic_forces`` as their
# in/out argument names.

_jax_coulomb_energy_matrix = _forward_only_jax_kernel(
    _coulomb_energy_matrix_kernel, "atomic_energies"
)
_jax_coulomb_energy_forces_matrix = _forward_only_jax_kernel(
    _coulomb_energy_forces_matrix_kernel, "atomic_energies", "atomic_forces"
)
_jax_batch_coulomb_energy_matrix = _forward_only_jax_kernel(
    _batch_coulomb_energy_matrix_kernel, "atomic_energies"
)
_jax_batch_coulomb_energy_forces_matrix = _forward_only_jax_kernel(
    _batch_coulomb_energy_forces_matrix_kernel, "atomic_energies", "atomic_forces"
)


# ==============================================================================
# Public API
# ==============================================================================


[docs] def coulomb_energy( positions: jax.Array, charges: jax.Array, cell: jax.Array, cutoff: float, alpha: float = 0.0, neighbor_list: jax.Array | None = None, neighbor_ptr: jax.Array | None = None, neighbor_shifts: jax.Array | None = None, neighbor_matrix: jax.Array | None = None, neighbor_matrix_shifts: jax.Array | None = None, fill_value: int | None = None, batch_idx: jax.Array | None = None, ) -> jax.Array: """Compute Coulomb electrostatic energies. Computes pairwise electrostatic energies using the Coulomb law, with optional erfc damping for Ewald/PME real-space calculations. Parameters ---------- positions : jax.Array, shape (N, 3) Atomic coordinates. charges : jax.Array, shape (N,) Atomic charges. cell : jax.Array, shape (1, 3, 3) or (B, 3, 3) Unit cell matrix. Shape (B, 3, 3) for batched calculations. cutoff : float Cutoff distance for interactions. alpha : float, default=0.0 Ewald splitting parameter. Use 0.0 for undamped Coulomb. neighbor_list : jax.Array | None, shape (2, num_pairs) Neighbor pairs in COO format. Row 0 = source, Row 1 = target. neighbor_ptr : jax.Array | None, shape (N+1,) CSR row pointers for neighbor list. Required with neighbor_list. Provided by neighborlist module. neighbor_shifts : jax.Array | None, shape (num_pairs, 3) Integer unit cell shifts for neighbor list format. neighbor_matrix : jax.Array | None, shape (N, max_neighbors) Neighbor indices in matrix format. neighbor_matrix_shifts : jax.Array | None, shape (N, max_neighbors, 3) Integer unit cell shifts for matrix format. fill_value : int | None Fill value for neighbor matrix padding. batch_idx : jax.Array | None, shape (N,) Batch indices for each atom. Returns ------- energies : jax.Array, shape (N,) Per-atom energies. Sum to get total energy. Examples -------- >>> # Direct Coulomb (undamped) >>> energies = coulomb_energy( ... positions, charges, cell, cutoff=10.0, alpha=0.0, ... neighbor_list=neighbor_list, neighbor_ptr=neighbor_ptr, ... neighbor_shifts=neighbor_shifts ... 
) >>> total_energy = energies.sum() >>> # Ewald/PME real-space (damped) >>> energies = coulomb_energy( ... positions, charges, cell, cutoff=10.0, alpha=0.3, ... neighbor_list=neighbor_list, neighbor_ptr=neighbor_ptr, ... neighbor_shifts=neighbor_shifts ... ) """ # Validate inputs use_list = neighbor_list is not None and neighbor_shifts is not None use_matrix = neighbor_matrix is not None and neighbor_matrix_shifts is not None if not use_list and not use_matrix: raise ValueError( "Must provide either neighbor_list/neighbor_shifts or " "neighbor_matrix/neighbor_matrix_shifts" ) if use_list and use_matrix: raise ValueError( "Cannot provide both neighbor list and neighbor matrix formats" ) # Store original dtype for output original_dtype = positions.dtype # Convert to float64 for numerical stability positions_f64 = positions.astype(jnp.float64) charges_f64 = charges.astype(jnp.float64) cell_f64 = cell.astype(jnp.float64) # Ensure cell is (B, 3, 3) if cell_f64.ndim == 2: cell_f64 = cell_f64[jnp.newaxis, :, :] num_atoms = positions_f64.shape[0] is_batched = batch_idx is not None # Allocate output energies = jnp.zeros(num_atoms, dtype=jnp.float64) if use_list: if neighbor_ptr is None: raise ValueError("neighbor_ptr is required when using neighbor_list format") # Extract idx_j (target indices) from neighbor_list idx_j = neighbor_list[1].astype(jnp.int32) neighbor_ptr_i32 = neighbor_ptr.astype(jnp.int32) neighbor_shifts_i32 = neighbor_shifts.astype(jnp.int32) if is_batched: batch_idx_i32 = batch_idx.astype(jnp.int32) (energies,) = _jax_batch_coulomb_energy_list( positions_f64, charges_f64, cell_f64, idx_j, neighbor_ptr_i32, neighbor_shifts_i32, batch_idx_i32, float(cutoff), float(alpha), energies, launch_dims=(num_atoms,), ) else: (energies,) = _jax_coulomb_energy_list( positions_f64, charges_f64, cell_f64, idx_j, neighbor_ptr_i32, neighbor_shifts_i32, float(cutoff), float(alpha), energies, launch_dims=(num_atoms,), ) else: neighbor_matrix_i32 = 
neighbor_matrix.astype(jnp.int32) neighbor_matrix_shifts_i32 = neighbor_matrix_shifts.astype(jnp.int32) if fill_value is None: fill_value = num_atoms if is_batched: batch_idx_i32 = batch_idx.astype(jnp.int32) (energies,) = _jax_batch_coulomb_energy_matrix( positions_f64, charges_f64, cell_f64, neighbor_matrix_i32, neighbor_matrix_shifts_i32, batch_idx_i32, float(cutoff), float(alpha), int(fill_value), energies, launch_dims=(num_atoms,), ) else: (energies,) = _jax_coulomb_energy_matrix( positions_f64, charges_f64, cell_f64, neighbor_matrix_i32, neighbor_matrix_shifts_i32, float(cutoff), float(alpha), int(fill_value), energies, launch_dims=(num_atoms,), ) return energies.astype(original_dtype)
def coulomb_forces(
    positions: jax.Array,
    charges: jax.Array,
    cell: jax.Array,
    cutoff: float,
    alpha: float = 0.0,
    neighbor_list: jax.Array | None = None,
    neighbor_ptr: jax.Array | None = None,
    neighbor_shifts: jax.Array | None = None,
    neighbor_matrix: jax.Array | None = None,
    neighbor_matrix_shifts: jax.Array | None = None,
    fill_value: int | None = None,
    batch_idx: jax.Array | None = None,
) -> jax.Array:
    """Compute Coulomb electrostatic forces only.

    Thin convenience wrapper around :func:`coulomb_energy_forces` that
    discards the per-atom energies. The energies are still computed by the
    underlying call, so use :func:`coulomb_energy_forces` directly when both
    are needed.

    Parameters
    ----------
    See :func:`coulomb_energy` for parameter descriptions.

    Returns
    -------
    forces : jax.Array, shape (N, 3)
        Forces on each atom.

    See Also
    --------
    coulomb_energy_forces : Compute both energies and forces
    """
    _discarded_energies, atom_forces = coulomb_energy_forces(
        positions,
        charges,
        cell,
        cutoff,
        alpha,
        neighbor_list=neighbor_list,
        neighbor_ptr=neighbor_ptr,
        neighbor_shifts=neighbor_shifts,
        neighbor_matrix=neighbor_matrix,
        neighbor_matrix_shifts=neighbor_matrix_shifts,
        fill_value=fill_value,
        batch_idx=batch_idx,
    )
    return atom_forces
[docs] def coulomb_energy_forces( positions: jax.Array, charges: jax.Array, cell: jax.Array, cutoff: float, alpha: float = 0.0, neighbor_list: jax.Array | None = None, neighbor_ptr: jax.Array | None = None, neighbor_shifts: jax.Array | None = None, neighbor_matrix: jax.Array | None = None, neighbor_matrix_shifts: jax.Array | None = None, fill_value: int | None = None, batch_idx: jax.Array | None = None, ) -> tuple[jax.Array, jax.Array]: """Compute Coulomb electrostatic energies and forces. Computes pairwise electrostatic energies and forces using the Coulomb law, with optional erfc damping for Ewald/PME real-space calculations. Parameters ---------- positions : jax.Array, shape (N, 3) Atomic coordinates. charges : jax.Array, shape (N,) Atomic charges. cell : jax.Array, shape (1, 3, 3) or (B, 3, 3) Unit cell matrix. Shape (B, 3, 3) for batched calculations. cutoff : float Cutoff distance for interactions. alpha : float, default=0.0 Ewald splitting parameter. Use 0.0 for undamped Coulomb. neighbor_list : jax.Array | None, shape (2, num_pairs) Neighbor pairs in COO format. neighbor_ptr : jax.Array | None, shape (N+1,) CSR row pointers for neighbor list. Required with neighbor_list. Provided by neighborlist module. neighbor_shifts : jax.Array | None, shape (num_pairs, 3) Integer unit cell shifts for neighbor list format. neighbor_matrix : jax.Array | None, shape (N, max_neighbors) Neighbor indices in matrix format. neighbor_matrix_shifts : jax.Array | None, shape (N, max_neighbors, 3) Integer unit cell shifts for matrix format. fill_value : int | None Fill value for neighbor matrix padding. batch_idx : jax.Array | None, shape (N,) Batch indices for each atom. Returns ------- energies : jax.Array, shape (N,) Per-atom energies. forces : jax.Array, shape (N, 3) Forces on each atom. Examples -------- >>> # Direct Coulomb >>> energies, forces = coulomb_energy_forces( ... positions, charges, cell, cutoff=10.0, alpha=0.0, ... 
neighbor_list=neighbor_list, neighbor_ptr=neighbor_ptr, ... neighbor_shifts=neighbor_shifts ... ) >>> # Ewald/PME real-space >>> energies, forces = coulomb_energy_forces( ... positions, charges, cell, cutoff=10.0, alpha=0.3, ... neighbor_matrix=neighbor_matrix, neighbor_matrix_shifts=neighbor_matrix_shifts, ... fill_value=num_atoms ... ) """ # Validate inputs use_list = neighbor_list is not None and neighbor_shifts is not None use_matrix = neighbor_matrix is not None and neighbor_matrix_shifts is not None if not use_list and not use_matrix: raise ValueError( "Must provide either neighbor_list/neighbor_shifts or " "neighbor_matrix/neighbor_matrix_shifts" ) if use_list and use_matrix: raise ValueError( "Cannot provide both neighbor list and neighbor matrix formats" ) # Store original dtype for output original_dtype = positions.dtype # Convert to float64 for numerical stability positions_f64 = positions.astype(jnp.float64) charges_f64 = charges.astype(jnp.float64) cell_f64 = cell.astype(jnp.float64) # Ensure cell is (B, 3, 3) if cell_f64.ndim == 2: cell_f64 = cell_f64[jnp.newaxis, :, :] num_atoms = positions_f64.shape[0] is_batched = batch_idx is not None # Allocate outputs energies = jnp.zeros(num_atoms, dtype=jnp.float64) forces = jnp.zeros((num_atoms, 3), dtype=jnp.float64) if use_list: if neighbor_ptr is None: raise ValueError("neighbor_ptr is required when using neighbor_list format") # Extract idx_j (target indices) from neighbor_list idx_j = neighbor_list[1].astype(jnp.int32) neighbor_ptr_i32 = neighbor_ptr.astype(jnp.int32) neighbor_shifts_i32 = neighbor_shifts.astype(jnp.int32) if is_batched: batch_idx_i32 = batch_idx.astype(jnp.int32) energies, forces = _jax_batch_coulomb_energy_forces_list( positions_f64, charges_f64, cell_f64, idx_j, neighbor_ptr_i32, neighbor_shifts_i32, batch_idx_i32, float(cutoff), float(alpha), energies, forces, launch_dims=(num_atoms,), ) else: energies, forces = _jax_coulomb_energy_forces_list( positions_f64, charges_f64, cell_f64, 
idx_j, neighbor_ptr_i32, neighbor_shifts_i32, float(cutoff), float(alpha), energies, forces, launch_dims=(num_atoms,), ) else: neighbor_matrix_i32 = neighbor_matrix.astype(jnp.int32) neighbor_matrix_shifts_i32 = neighbor_matrix_shifts.astype(jnp.int32) if fill_value is None: fill_value = num_atoms if is_batched: batch_idx_i32 = batch_idx.astype(jnp.int32) energies, forces = _jax_batch_coulomb_energy_forces_matrix( positions_f64, charges_f64, cell_f64, neighbor_matrix_i32, neighbor_matrix_shifts_i32, batch_idx_i32, float(cutoff), float(alpha), int(fill_value), energies, forces, launch_dims=(num_atoms,), ) else: energies, forces = _jax_coulomb_energy_forces_matrix( positions_f64, charges_f64, cell_f64, neighbor_matrix_i32, neighbor_matrix_shifts_i32, float(cutoff), float(alpha), int(fill_value), energies, forces, launch_dims=(num_atoms,), ) return energies.astype(original_dtype), forces.astype(original_dtype)