# Source code for nvalchemiops.torch.interactions.electrostatics.dsf

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Damped Shifted Force (DSF) Electrostatics - PyTorch Bindings
=============================================================

This module provides PyTorch bindings for DSF electrostatic calculations.
It wraps the framework-agnostic Warp launchers from
``nvalchemiops.interactions.electrostatics.dsf``.

Public API
----------
- ``dsf_coulomb()``: Compute DSF electrostatic energy, forces, and virial

Features:
- Both undamped (alpha=0, shifted-force Coulomb) and damped DSF
- Both neighbor list (CSR) and neighbor matrix formats
- Batched calculations
- Charge gradient support for MLIP training (via straight-through trick)
- Optional forces and virial computation
- float32 and float64 precision support

Integration Pattern
-------------------
This module uses ``torch.library.custom_op`` with ``mutates_args`` and
``register_fake`` instead of the ``@warp_custom_op`` / ``WarpAutogradContextManager``
pattern used by the Coulomb bindings. This is intentional:

- DSF does not require double backward (Hessian / gradients of forces w.r.t.
  positions). Forces and charge gradients are computed analytically in the
  forward Warp kernel, so there is no need for a Warp backward tape.
- Charge gradients are propagated through PyTorch autograd via a
  "straight-through trick": a zero-valued correction term whose gradient
  equals the kernel-computed dE/dq is added to the energy tensor.
- ``register_fake`` enables ``torch.compile`` compatibility.

Examples
--------
>>> # Basic DSF energy and forces
>>> energy, forces = dsf_coulomb(
...     positions, charges, cutoff=10.0, alpha=0.2,
...     neighbor_list=neighbor_list, neighbor_ptr=neighbor_ptr,
... )

>>> # With charge gradients for MLIP training
>>> charges.requires_grad_(True)
>>> energy, forces = dsf_coulomb(
...     positions, charges, cutoff=10.0, alpha=0.2,
...     neighbor_list=neighbor_list, neighbor_ptr=neighbor_ptr,
... )
>>> energy.sum().backward()
>>> charge_grads = charges.grad  # dE/dq_i
"""

from __future__ import annotations

import torch
import warp as wp

from nvalchemiops.interactions.electrostatics.dsf import (
    dsf_csr as wp_dsf_csr,
)
from nvalchemiops.interactions.electrostatics.dsf import (
    dsf_matrix as wp_dsf_matrix,
)
from nvalchemiops.torch.types import get_wp_dtype, get_wp_mat_dtype, get_wp_vec_dtype

__all__ = [
    "dsf_coulomb",
]


# ==============================================================================
# Internal Custom Ops
# ==============================================================================


@torch.library.custom_op(
    "nvalchemiops::dsf_csr_op",
    mutates_args=("energy", "forces", "virial", "charge_grad"),
)
def _dsf_csr_op(
    positions: torch.Tensor,
    charges: torch.Tensor,
    idx_j: torch.Tensor,
    neighbor_ptr: torch.Tensor,
    cutoff: float,
    alpha: float,
    energy: torch.Tensor,
    forces: torch.Tensor,
    virial: torch.Tensor,
    charge_grad: torch.Tensor,
    compute_forces: bool = True,
    compute_virial: bool = False,
    compute_charge_grad: bool = False,
    cell: torch.Tensor | None = None,
    unit_shifts: torch.Tensor | None = None,
    batch_idx: torch.Tensor | None = None,
    device: str | None = None,
) -> None:
    """Internal custom op: DSF with CSR neighbor list (optional PBC)."""
    num_atoms = positions.size(0)
    if num_atoms == 0:
        return

    if device is None:
        device = str(positions.device)

    energy.zero_()
    if compute_forces:
        forces.zero_()
    if compute_virial:
        virial.zero_()
    if compute_charge_grad:
        charge_grad.zero_()

    input_dtype = positions.dtype
    wp_scalar = get_wp_dtype(input_dtype)
    wp_vec = get_wp_vec_dtype(input_dtype)
    wp_mat = get_wp_mat_dtype(input_dtype)

    positions_wp = wp.from_torch(positions.detach(), dtype=wp_vec, return_ctype=True)
    charges_wp = wp.from_torch(charges.detach(), dtype=wp_scalar, return_ctype=True)
    idx_j_wp = wp.from_torch(idx_j, dtype=wp.int32, return_ctype=True)
    neighbor_ptr_wp = wp.from_torch(neighbor_ptr, dtype=wp.int32, return_ctype=True)

    energy_wp = wp.from_torch(energy, dtype=wp.float64, return_ctype=True)
    forces_wp = wp.from_torch(forces, dtype=wp_vec, return_ctype=True)
    virial_wp = wp.from_torch(virial, dtype=wp_mat, return_ctype=True)
    charge_grad_wp = wp.from_torch(charge_grad, dtype=wp_scalar, return_ctype=True)

    cell_wp = None
    unit_shifts_wp = None
    if cell is not None:
        cell_wp = wp.from_torch(cell.detach(), dtype=wp_mat, return_ctype=True)
    if unit_shifts is not None:
        unit_shifts_wp = wp.from_torch(unit_shifts, dtype=wp.vec3i, return_ctype=True)

    batch_idx_wp = None
    if batch_idx is not None:
        batch_idx_wp = wp.from_torch(batch_idx, dtype=wp.int32, return_ctype=True)

    wp_dsf_csr(
        positions=positions_wp,
        charges=charges_wp,
        idx_j=idx_j_wp,
        neighbor_ptr=neighbor_ptr_wp,
        cutoff=cutoff,
        alpha=alpha,
        energy=energy_wp,
        forces=forces_wp,
        virial=virial_wp,
        charge_grad=charge_grad_wp,
        cell=cell_wp,
        unit_shifts=unit_shifts_wp,
        device=device,
        batch_idx=batch_idx_wp,
        compute_forces=compute_forces,
        compute_virial=compute_virial,
        compute_charge_grad=compute_charge_grad,
        wp_scalar_type=wp_scalar,
    )


@_dsf_csr_op.register_fake
def _dsf_csr_op_fake(
    positions: torch.Tensor,
    charges: torch.Tensor,
    idx_j: torch.Tensor,
    neighbor_ptr: torch.Tensor,
    cutoff: float,
    alpha: float,
    energy: torch.Tensor,
    forces: torch.Tensor,
    virial: torch.Tensor,
    charge_grad: torch.Tensor,
    compute_forces: bool = True,
    compute_virial: bool = False,
    compute_charge_grad: bool = False,
    cell: torch.Tensor | None = None,
    unit_shifts: torch.Tensor | None = None,
    batch_idx: torch.Tensor | None = None,
    device: str | None = None,
) -> None:
    """Fake implementation of the CSR DSF op, registered for tracing.

    The real op returns nothing and only writes into the caller-provided
    output buffers, so there is no output metadata to produce here; the
    signature simply mirrors the real op.
    """
    return None


@torch.library.custom_op(
    "nvalchemiops::dsf_matrix_op",
    mutates_args=("energy", "forces", "virial", "charge_grad"),
)
def _dsf_matrix_op(
    positions: torch.Tensor,
    charges: torch.Tensor,
    neighbor_matrix: torch.Tensor,
    cutoff: float,
    alpha: float,
    fill_value: int,
    energy: torch.Tensor,
    forces: torch.Tensor,
    virial: torch.Tensor,
    charge_grad: torch.Tensor,
    compute_forces: bool = True,
    compute_virial: bool = False,
    compute_charge_grad: bool = False,
    cell: torch.Tensor | None = None,
    neighbor_matrix_shifts: torch.Tensor | None = None,
    batch_idx: torch.Tensor | None = None,
    device: str | None = None,
) -> None:
    """Internal custom op: DSF with neighbor matrix (optional PBC)."""
    num_atoms = positions.size(0)
    if num_atoms == 0:
        return

    if device is None:
        device = str(positions.device)

    energy.zero_()
    if compute_forces:
        forces.zero_()
    if compute_virial:
        virial.zero_()
    if compute_charge_grad:
        charge_grad.zero_()

    input_dtype = positions.dtype
    wp_scalar = get_wp_dtype(input_dtype)
    wp_vec = get_wp_vec_dtype(input_dtype)
    wp_mat = get_wp_mat_dtype(input_dtype)

    positions_wp = wp.from_torch(positions.detach(), dtype=wp_vec, return_ctype=True)
    charges_wp = wp.from_torch(charges.detach(), dtype=wp_scalar, return_ctype=True)
    neighbor_matrix_wp = wp.from_torch(
        neighbor_matrix, dtype=wp.int32, return_ctype=True
    )

    energy_wp = wp.from_torch(energy, dtype=wp.float64, return_ctype=True)
    forces_wp = wp.from_torch(forces, dtype=wp_vec, return_ctype=True)
    virial_wp = wp.from_torch(virial, dtype=wp_mat, return_ctype=True)
    charge_grad_wp = wp.from_torch(charge_grad, dtype=wp_scalar, return_ctype=True)

    cell_wp = None
    neighbor_matrix_shifts_wp = None
    if cell is not None:
        cell_wp = wp.from_torch(cell.detach(), dtype=wp_mat, return_ctype=True)
    if neighbor_matrix_shifts is not None:
        neighbor_matrix_shifts_wp = wp.from_torch(
            neighbor_matrix_shifts, dtype=wp.vec3i, return_ctype=True
        )

    batch_idx_wp = None
    if batch_idx is not None:
        batch_idx_wp = wp.from_torch(batch_idx, dtype=wp.int32, return_ctype=True)

    wp_dsf_matrix(
        positions=positions_wp,
        charges=charges_wp,
        neighbor_matrix=neighbor_matrix_wp,
        cutoff=cutoff,
        alpha=alpha,
        fill_value=fill_value,
        energy=energy_wp,
        forces=forces_wp,
        virial=virial_wp,
        charge_grad=charge_grad_wp,
        cell=cell_wp,
        neighbor_matrix_shifts=neighbor_matrix_shifts_wp,
        device=device,
        batch_idx=batch_idx_wp,
        compute_forces=compute_forces,
        compute_virial=compute_virial,
        compute_charge_grad=compute_charge_grad,
        wp_scalar_type=wp_scalar,
    )


@_dsf_matrix_op.register_fake
def _dsf_matrix_op_fake(
    positions: torch.Tensor,
    charges: torch.Tensor,
    neighbor_matrix: torch.Tensor,
    cutoff: float,
    alpha: float,
    fill_value: int,
    energy: torch.Tensor,
    forces: torch.Tensor,
    virial: torch.Tensor,
    charge_grad: torch.Tensor,
    compute_forces: bool = True,
    compute_virial: bool = False,
    compute_charge_grad: bool = False,
    cell: torch.Tensor | None = None,
    neighbor_matrix_shifts: torch.Tensor | None = None,
    batch_idx: torch.Tensor | None = None,
    device: str | None = None,
) -> None:
    """Fake implementation of the neighbor-matrix DSF op, registered for tracing.

    The real op returns nothing and only writes into the caller-provided
    output buffers, so there is no output metadata to produce here; the
    signature simply mirrors the real op.
    """
    return None


# ==============================================================================
# Public API
# ==============================================================================


def dsf_coulomb(
    positions: torch.Tensor,
    charges: torch.Tensor,
    cutoff: float,
    alpha: float = 0.2,
    cell: torch.Tensor | None = None,
    batch_idx: torch.Tensor | None = None,
    # Neighbor list (CSR) format
    neighbor_list: torch.Tensor | None = None,
    neighbor_ptr: torch.Tensor | None = None,
    unit_shifts: torch.Tensor | None = None,
    # Neighbor matrix format
    neighbor_matrix: torch.Tensor | None = None,
    neighbor_matrix_shifts: torch.Tensor | None = None,
    fill_value: int | None = None,
    # Control flags
    compute_forces: bool = True,
    compute_virial: bool = False,
    num_systems: int | None = None,
    device: str | None = None,
) -> (
    tuple[torch.Tensor]
    | tuple[torch.Tensor, torch.Tensor]
    | tuple[torch.Tensor, torch.Tensor, torch.Tensor]
):
    """Compute DSF electrostatic energy, forces, and virial.

    The Damped Shifted Force (DSF) method is a pairwise O(N) electrostatic
    summation technique that ensures both potential energy and forces smoothly
    vanish at a defined cutoff radius.

    Supports float32 and float64 input precision. Energy is always returned in
    float64. Forces, virial, and charge gradients match the input precision.

    Parameters
    ----------
    positions : torch.Tensor, shape (num_atoms, 3)
        Atomic coordinates (float32 or float64).
    charges : torch.Tensor, shape (num_atoms,)
        Atomic charges (must match positions dtype). If requires_grad=True,
        charge gradients (dE/dq) will be propagated through autograd.
    cutoff : float
        Cutoff radius beyond which interactions are zero.
    alpha : float, default 0.2
        Damping parameter. Set to 0.0 for shifted-force bare Coulomb.
    cell : torch.Tensor, shape (num_systems, 3, 3), optional
        Unit cell matrices for periodic boundary conditions. Cast to the
        dtype of ``positions`` before use.
    batch_idx : torch.Tensor, shape (num_atoms,), dtype=int32, optional
        System index for each atom. If None, all atoms in one system.
    neighbor_list : torch.Tensor, shape (2, num_pairs), dtype=int32, optional
        Neighbor list in COO format. Row 1 contains destination atoms.
    neighbor_ptr : torch.Tensor, shape (num_atoms+1,), dtype=int32, optional
        CSR row pointers (required with neighbor_list).
    unit_shifts : torch.Tensor, shape (num_pairs, 3), dtype=int32, optional
        Integer unit cell shifts for PBC (required with neighbor_list + cell).
    neighbor_matrix : torch.Tensor, shape (num_atoms, max_neighbors), dtype=int32, optional
        Dense neighbor matrix format.
    neighbor_matrix_shifts : torch.Tensor, shape (num_atoms, max_neighbors, 3), dtype=int32, optional
        Integer unit cell shifts for matrix format PBC.
    fill_value : int, optional
        Padding indicator for neighbor_matrix. Defaults to num_atoms.
    compute_forces : bool, default True
        Whether to compute forces.
    compute_virial : bool, default False
        Whether to compute virial tensor (requires PBC and compute_forces).
    num_systems : int, optional
        Number of systems. If not given: 1 when ``batch_idx`` is None;
        otherwise ``cell.size(0)`` when ``cell`` is provided; otherwise
        ``batch_idx.max() + 1`` (or 1 when there are no atoms).
    device : str, optional
        Warp device string. Inferred from positions if not given.

    Returns
    -------
    energy : torch.Tensor, shape (num_systems,), dtype=float64
        Per-system electrostatic energy (always float64). If
        charges.requires_grad, this tensor is connected to the autograd graph
        for charge gradients.
    forces : torch.Tensor, shape (num_atoms, 3), dtype matches input
        Per-atom forces. Only returned if compute_forces=True.
    virial : torch.Tensor, shape (num_systems, 3, 3), dtype matches input
        Per-system virial tensor. Only returned if compute_virial=True.

    Raises
    ------
    ValueError
        If ``compute_virial`` is requested without ``compute_forces`` or
        without ``cell``; if neither or both neighbor formats are supplied;
        if ``neighbor_list`` is given without ``neighbor_ptr``; if ``cell`` is
        given without the shift tensor of the chosen neighbor format; or if
        ``charges`` and ``positions`` have different dtypes.

    Notes
    -----
    - Assumes a full neighbor list (each pair appears in both directions).
    - For MLIP training with geometry-dependent charges, set
      ``charges.requires_grad_(True)`` before calling. After
      ``energy.sum().backward()``, ``charges.grad`` will contain dE/dq.
    - Charge gradients (dE/dq) are computed when ``charges.requires_grad=True``,
      regardless of ``compute_forces``.
    - The returned ``energy`` tensor is **not** differentiable w.r.t.
      ``positions`` or ``cell`` through PyTorch autograd. Forces are computed
      analytically by the Warp kernel, not via autograd.
    - With zero atoms no kernel is launched and zero-filled outputs of the
      regular shapes are returned.

    Examples
    --------
    >>> # Basic energy + forces
    >>> energy, forces = dsf_coulomb(positions, charges, cutoff=10.0, alpha=0.2,
    ...                              neighbor_list=nl, neighbor_ptr=ptr)

    >>> # MLIP workflow with charge gradients
    >>> charges = model(positions)  # Predict charges from geometry
    >>> charges.requires_grad_(True)
    >>> energy, forces = dsf_coulomb(positions, charges, cutoff=10.0, alpha=0.2,
    ...                              neighbor_list=nl, neighbor_ptr=ptr)
    >>> loss = (energy - ref_energy).pow(2).sum()
    >>> loss.backward()  # charges.grad now contains dE/dq * dloss/dE
    """
    # Validate inputs
    if compute_virial and not compute_forces:
        raise ValueError("compute_virial=True requires compute_forces=True")
    if compute_virial and cell is None:
        raise ValueError(
            "compute_virial=True requires periodic boundary conditions (cell)"
        )

    # Validate neighbor format: exactly one of list or matrix must be provided
    use_list = neighbor_list is not None
    use_matrix = neighbor_matrix is not None
    if not use_list and not use_matrix:
        raise ValueError(
            "Must provide either neighbor_list (with neighbor_ptr) or neighbor_matrix"
        )
    if use_list and use_matrix:
        raise ValueError(
            "Cannot provide both neighbor list and neighbor matrix formats"
        )
    if use_list and neighbor_ptr is None:
        raise ValueError("neighbor_ptr is required when using neighbor_list format")

    # Validate PBC shift tensors when cell is provided
    if cell is not None:
        if use_list and unit_shifts is None:
            raise ValueError(
                "unit_shifts is required when using neighbor_list format with "
                "periodic boundary conditions (cell)"
            )
        if use_matrix and neighbor_matrix_shifts is None:
            raise ValueError(
                "neighbor_matrix_shifts is required when using neighbor_matrix format "
                "with periodic boundary conditions (cell)"
            )

    if charges.dtype != positions.dtype:
        raise ValueError(
            f"charges dtype ({charges.dtype}) must match positions dtype ({positions.dtype})"
        )
    input_dtype = positions.dtype

    # Charge gradients are computed whenever charges require grad,
    # regardless of whether forces are requested.
    compute_charge_grad = charges.requires_grad

    num_atoms = positions.size(0)

    # Determine number of systems. This runs before the empty-input shortcut
    # so that empty and non-empty inputs produce identically shaped outputs.
    if num_systems is None:
        if batch_idx is None:
            num_systems = 1
        elif cell is not None:
            num_systems = cell.size(0)
        elif num_atoms == 0:
            # batch_idx is empty here; max() on an empty tensor would raise.
            num_systems = 1
        else:
            num_systems = int(batch_idx.max().item()) + 1

    # Ensure cell matches input dtype
    if cell is not None:
        cell = cell.to(dtype=input_dtype)

    # Allocate output tensors. Outputs that are not requested get an empty
    # placeholder because the custom ops require a tensor for every buffer.
    dev = positions.device
    energy = torch.zeros(num_systems, dtype=torch.float64, device=dev)
    if compute_forces:
        forces_out = torch.zeros((num_atoms, 3), dtype=input_dtype, device=dev)
    else:
        forces_out = torch.empty((0, 3), dtype=input_dtype, device=dev)
    if compute_charge_grad:
        charge_grad_out = torch.zeros(num_atoms, dtype=input_dtype, device=dev)
    else:
        charge_grad_out = torch.empty(0, dtype=input_dtype, device=dev)
    if compute_virial:
        virial_out = torch.zeros((num_systems, 3, 3), dtype=input_dtype, device=dev)
    else:
        virial_out = torch.empty((0, 3, 3), dtype=input_dtype, device=dev)

    # With no atoms there is nothing to compute: the zero-filled buffers
    # allocated above already are the result.
    if num_atoms > 0:
        # Dispatch to appropriate custom op (2-way: by neighbor format)
        if use_matrix:
            if fill_value is None:
                fill_value = num_atoms
            _dsf_matrix_op(
                positions=positions,
                charges=charges,
                neighbor_matrix=neighbor_matrix,
                cutoff=cutoff,
                alpha=alpha,
                fill_value=fill_value,
                energy=energy,
                forces=forces_out,
                virial=virial_out,
                charge_grad=charge_grad_out,
                compute_forces=compute_forces,
                compute_virial=compute_virial,
                compute_charge_grad=compute_charge_grad,
                cell=cell,
                neighbor_matrix_shifts=neighbor_matrix_shifts,
                batch_idx=batch_idx,
                device=device,
            )
        else:
            # Row 1 of the COO list holds the destination (j) atoms.
            idx_j = neighbor_list[1].contiguous()
            _dsf_csr_op(
                positions=positions,
                charges=charges,
                idx_j=idx_j,
                neighbor_ptr=neighbor_ptr,
                cutoff=cutoff,
                alpha=alpha,
                energy=energy,
                forces=forces_out,
                virial=virial_out,
                charge_grad=charge_grad_out,
                compute_forces=compute_forces,
                compute_virial=compute_virial,
                compute_charge_grad=compute_charge_grad,
                cell=cell,
                unit_shifts=unit_shifts,
                batch_idx=batch_idx,
                device=device,
            )

        # Charge gradient support via straight-through trick.
        # This makes energy differentiable w.r.t. charges without Warp tape.
        # The correction term has value 0 but gradient dE/dq w.r.t. charges.
        if compute_charge_grad:
            # Cast to float64 for numerical stability in the correction
            # computation. charge_grad_out is in input precision (possibly
            # float32).
            cg_f64 = charge_grad_out.to(dtype=torch.float64)
            charges_f64 = charges.to(dtype=torch.float64)
            charges_detached_f64 = charges_f64.detach()

            # correction[i] = dE_dq[i] * (q[i] - q_detached[i]) = dE_dq[i] * 0 = 0
            # but d(correction)/d(q[i]) = dE_dq[i]
            correction = cg_f64.detach() * (charges_f64 - charges_detached_f64)

            if batch_idx is not None:
                system_correction = torch.zeros_like(energy)
                system_correction.scatter_add_(0, batch_idx.long(), correction)
            else:
                # Single system: sum all corrections
                system_correction = correction.sum().unsqueeze(0)

            energy = energy + system_correction

    # Build return tuple
    if not compute_forces:
        return (energy,)
    if not compute_virial:
        return energy, forces_out
    return energy, forces_out, virial_out