Source code for nvalchemiops.interactions.dispersion.dftd3

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
DFT-D3(BJ) Dispersion Correction - Warp Kernel Implementation

This module implements the DFT-D3 dispersion correction with Becke-Johnson (BJ)
damping as Warp GPU/CPU kernels. The implementation provides efficient computation
of dispersion energies and forces using a multi-pass algorithm with support for
periodic boundary conditions, batched systems, and smooth cutoff functions.

For detailed theory, usage examples, and parameter setup, see the
:doc:`DFT-D3 User Guide </userguide/components/dispersion>`.

Multi-Pass Kernel Architecture
-------------------------------
The implementation uses four kernel passes to efficiently handle the chain rule
dependency in force calculations:

1. **Pass 0 (_compute_cartesian_shifts)**: [PBC only] Convert unit cell shifts
   to Cartesian coordinates

2. **Pass 1 (_cn_kernel)**: Compute coordination numbers using geometric counting
   function

3. **Pass 2 (_direct_forces_and_dE_dCN_kernel)**: Compute C6 interpolation,
   dispersion energy, direct forces, and accumulate :math:`\\partial E/\\partial \\text{CN}`

4. **Pass 3 (_cn_forces_contrib_kernel)**: Add CN-dependent force contribution
   using precomputed :math:`\\partial E/\\partial \\text{CN}` values

PyTorch Interface
-----------------
Use :func:`dftd3` as the main entry point for PyTorch integration:

.. code-block:: python

    from nvalchemiops.interactions.dispersion.dftd3 import (
        dftd3,
        D3Parameters,
    )

    # Using neighbor matrix format
    energy, forces, coord_num = dftd3(
        positions, numbers,
        neighbor_matrix=neighbor_matrix,
        a1=0.3981, a2=4.4211, s8=1.9889,
        d3_params=d3_params,  # D3Parameters instance or dict
        cell=cell,  # Optional for PBC
        neighbor_matrix_shifts=neighbor_matrix_shifts,  # Optional for PBC
    )

    # Using neighbor list format (sparse COO)
    energy, forces, coord_num = dftd3(
        positions, numbers,
        neighbor_list=neighbor_list,  # shape (2, num_pairs)
        a1=0.3981, a2=4.4211, s8=1.9889,
        d3_params=d3_params,
        cell=cell,  # Optional for PBC
        unit_shifts=unit_shifts,  # Optional for PBC, shape (num_pairs, 3)
    )

Data Structure Requirements
---------------------------
**Neighbor Formats**

The implementation supports two neighbor representation formats:

1. **Neighbor Matrix Format** (dense): `[num_atoms, max_neighbors]` where
   `neighbor_matrix[i, k]` is the k-th neighbor of atom i. Padding entries use
   values >= `fill_value` (typically `num_atoms`).

2. **Neighbor List Format** (sparse COO): `[2, num_pairs]` where row 0 contains
   source atom indices and row 1 contains target atom indices. No padding needed.

Both formats can be generated by :func:`nvalchemiops.neighborlist.neighbor_list` using
the `return_neighbor_list` parameter.

**Parameter Arrays**

- `covalent_radii`: `[max_Z+1]` float32
- `r4r2`: `[max_Z+1]` float32
- `c6_reference`: `[max_Z+1, max_Z+1, 5, 5]` float32
- `coord_num_ref`: `[max_Z+1, max_Z+1, 5, 5]` float32

Index 0 reserved for padding; valid atomic numbers 1 to max_Z.

**Periodic Boundary Conditions**

- `cell`: `[num_systems, 3, 3]` lattice vectors (row format)
- For neighbor matrix: `neighbor_matrix_shifts`: `[num_atoms, max_neighbors, 3]` int32 unit cell shifts
- For neighbor list: `unit_shifts`: `[num_pairs, 3]` int32 unit cell shifts

Units
-----
Kernels are **unit-agnostic** but require consistency. Standard Grimme group
parameters use **atomic units (Bohr, Hartree)**, which is recommended:

- Positions, covalent radii, `a2`, cutoffs: Bohr
- Energy output: Hartree
- Forces output: Hartree/Bohr
- Parameter `k1`: 1/Bohr

Technical Notes
---------------
- Supports float32 and float64 positions and cell. Outputs are always float32
- **Two-body only**: Axilrod-Teller-Muto (C9) three-body terms not included

See Also
--------
:class:`D3Parameters` : Dataclass for parameter validation and management
:func:`dftd3` : Main PyTorch interface function
:doc:`/userguide/components/dispersion` : Complete user guide with theory and examples
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import torch
import warp as wp

__all__ = [
    "D3Parameters",
    "dftd3",
]

# ==============================================================================
# Data Structures
# ==============================================================================


[docs] @dataclass class D3Parameters: """ DFT-D3 reference parameters for dispersion correction calculations. This dataclass encapsulates all element-specific parameters required for DFT-D3 dispersion corrections. The main purpose for this structure is to provide validation, ensuring the correct shapes, dtypes, and keys are present and complete. These parameters are used by :func:`dftd3`. Parameters ---------- rcov : torch.Tensor Covalent radii [max_Z+1] as float32 or float64. Units should be consistent with position coordinates. Index 0 is reserved for padding; valid atomic numbers are 1 to max_Z. r4r2 : torch.Tensor <r⁴>/<r²> expectation values [max_Z+1] as float32 or float64. Dimensionless ratio used for computing C8 coefficients from C6 values. c6ab : torch.Tensor C6 reference coefficients [max_Z+1, max_Z+1, interp_mesh, interp_mesh] as float32 or float64. Units are energy x distance^6. Indexed by atomic numbers and coordination number reference indices. cn_ref : torch.Tensor Coordination number reference grid [max_Z+1, max_Z+1, interp_mesh, interp_mesh] as float32 or float64. Dimensionless CN values for Gaussian interpolation. interp_mesh : int, optional Size of the coordination number interpolation mesh. Default: 5 (standard DFT-D3 uses a 5x5 grid) Raises ------ ValueError If parameter shapes are inconsistent or invalid TypeError If parameters are not torch.Tensor or have invalid dtypes Notes ----- - Parameters should use consistent units matching your coordinate system. Standard D3 parameters from the Grimme group use atomic units (Bohr for distances, Hartree x Bohr^6 for C6 coefficients). - Index 0 in all arrays is reserved for padding atoms (atomic number 0) - Valid atomic numbers range from 1 to max_z - The standard DFT-D3 implementation supports elements 1-94 (H to Pu) - Parameters can be float32 or float64; they will be converted to float32 during computation for efficiency Examples -------- Create parameters from individual tensors: >>> params = D3Parameters( ... rcov=torch.rand(95), # 94 elements + padding ... r4r2=torch.rand(95), ... c6ab=torch.rand(95, 95, 5, 5), ... cn_ref=torch.rand(95, 95, 5, 5), ... ) Create from a dictionary (e.g., loaded from file): >>> state_dict = torch.load("dftd3_parameters.pt") >>> params = D3Parameters( ... rcov=state_dict["rcov"], ... r4r2=state_dict["r4r2"], ... c6ab=state_dict["c6ab"], ... cn_ref=state_dict["cn_ref"], ... ) """ rcov: torch.Tensor r4r2: torch.Tensor c6ab: torch.Tensor cn_ref: torch.Tensor interp_mesh: int = 5 def __post_init__(self) -> None: """Validate parameter shapes, dtypes, and physical constraints.""" # Type validation for name, tensor in [ ("rcov", self.rcov), ("r4r2", self.r4r2), ("c6ab", self.c6ab), ("cn_ref", self.cn_ref), ]: if not isinstance(tensor, torch.Tensor): raise TypeError( f"Parameter '{name}' must be a torch.Tensor, got {type(tensor)}" ) if tensor.dtype not in (torch.float32, torch.float64): raise TypeError( f"Parameter '{name}' must be float32 or float64, got {tensor.dtype}" ) # Shape validation if self.rcov.ndim != 1: raise ValueError( f"rcov must be 1D tensor [max_Z+1], got shape {self.rcov.shape}" ) max_z = self.rcov.size(0) - 1 if max_z < 1: raise ValueError( f"rcov must have at least 2 elements (padding + 1 element), got {self.rcov.size(0)}" ) if self.r4r2.shape != (max_z + 1,): raise ValueError( f"r4r2 must have shape [{max_z + 1}] to match rcov, got {self.r4r2.shape}" ) expected_c6_shape = (max_z + 1, max_z + 1, self.interp_mesh, self.interp_mesh) if self.c6ab.shape != expected_c6_shape: raise ValueError( f"c6ab must have shape {expected_c6_shape}, got {self.c6ab.shape}" ) expected_cn_shape = (max_z + 1, max_z + 1, self.interp_mesh, self.interp_mesh) if self.cn_ref.shape != expected_cn_shape: raise ValueError( f"cn_ref must have shape {expected_cn_shape}, got {self.cn_ref.shape}" ) # Device consistency validation devices = [ self.rcov.device, self.r4r2.device, self.c6ab.device, self.cn_ref.device, ] if len({str(d) for d in devices}) > 1: raise ValueError( f"All parameters must be on the same device. " f"Got devices: rcov={self.rcov.device}, r4r2={self.r4r2.device}, " f"c6ab={self.c6ab.device}, cn_ref={self.cn_ref.device}" ) @property def max_z(self) -> int: """Maximum atomic number supported by these parameters.""" return self.rcov.size(0) - 1 @property def device(self) -> torch.device: """Device where parameters are stored.""" return self.rcov.device
[docs] def to( self, device: str | torch.device | None = None, dtype: torch.dtype | None = None, ) -> D3Parameters: """ Move all parameters to the specified device and/or convert to specified dtype. Parameters ---------- device : str or torch.device or None, optional Target device (e.g., 'cpu', 'cuda', 'cuda:0'). If None, keeps current device. dtype : torch.dtype or None, optional Target dtype (e.g., torch.float32, torch.float64). If None, keeps current dtype. Returns ------- D3Parameters New instance with parameters on the target device and/or dtype Examples -------- Move to GPU: >>> params_gpu = params.to(device='cuda') Convert to float32: >>> params_f32 = params.to(dtype=torch.float32) Move to GPU and convert to float32: >>> params_gpu_f32 = params.to(device='cuda', dtype=torch.float32) """ return D3Parameters( rcov=self.rcov.to(device=device, dtype=dtype), r4r2=self.r4r2.to(device=device, dtype=dtype), c6ab=self.c6ab.to(device=device, dtype=dtype), cn_ref=self.cn_ref.to(device=device, dtype=dtype), interp_mesh=self.interp_mesh, )
# ============================================================================== # Helper Functions # ============================================================================== @wp.func def _s5_switch( r: wp.float32, r_on: wp.float32, r_off: wp.float32, inv_w: wp.float32, ) -> tuple[wp.float32, wp.float32]: """ C² smooth switching function for cutoff smoothing. This function provides a smooth transition from 1 to 0 over the interval [r_on, r_off]. The switching polynomial S5(t) has continuous first and second derivatives at the boundaries. Parameters ---------- r : float32 Distance between atoms r_on : float32 Distance where switching begins r_off : float32 Distance where switching completes inv_w : float32 Precomputed 1/(r_off - r_on) for efficiency Returns ------- Sw : float32 Switching function value Sw(r) ∈ [0, 1] dSw_dr : float32 Derivative of switching function with respect to r Notes ----- The switching function is defined as: .. math:: S_w(r) = \\begin{cases} 1 & \\text{if } r \\leq r_{\\text{on}} \\\\ 1 - S_5(t) & \\text{if } r_{\\text{on}} < r < r_{\\text{off}} \\\\ 0 & \\text{if } r \\geq r_{\\text{off}} \\end{cases} where :math:`t = (r - r_{\\text{on}})/(r_{\\text{off}} - r_{\\text{on}}) \\in (0,1)` and .. math:: S_5(t) = 10t^3 - 15t^4 + 6t^5 The derivative is: .. math:: \\frac{dS_5}{dt} = 30t^2 - 60t^3 + 30t^4 \\frac{dS_w}{dr} = -\\frac{dS_5}{dt} \\cdot \\frac{1}{r_{\\text{off}} - r_{\\text{on}}} This ensures :math:`C^2` continuity (continuous function, first, and second derivatives) at both :math:`r_{\\text{on}}` and :math:`r_{\\text{off}}` boundaries. See Also -------- :func:`_direct_forces_and_dE_dCN_kernel_nm` : Uses this switching function for cutoff smoothing (neighbor matrix) :func:`_direct_forces_and_dE_dCN_kernel_nl` : Uses this switching function for cutoff smoothing (neighbor list) """ if r_off <= r_on: # disabled or degenerate: no switching return 1.0, 0.0 if r <= r_on: return 1.0, 0.0 if r >= r_off: return 0.0, 0.0 t = (r - r_on) * inv_w # t in (0,1) t2 = t * t t3 = t2 * t t4 = t3 * t t5 = t4 * t switch = 1.0 - (10.0 * t3 - 15.0 * t4 + 6.0 * t5) dSdt = -30.0 * t2 + 60.0 * t3 - 30.0 * t4 # NOSONAR (S125) "math formula" dSw_dr = dSdt * inv_w # NOSONAR (S125) "math formula" return switch, dSw_dr @wp.func def _c6ab_interpolate( cn_i: wp.float32, cn_j: wp.float32, c6ab_mat: wp.array2d(dtype=wp.float32), cnref_i_mat: wp.array2d(dtype=wp.float32), cnref_j_mat: wp.array2d(dtype=wp.float32), k3: wp.float32, ) -> tuple[wp.float32, wp.float32, wp.float32]: """ Interpolate C6 coefficient and CN derivatives using Gaussian weighting. This function performs Gaussian interpolation over a 5x5 reference grid to compute the environment-dependent C6 coefficient for an atom pair, along with derivatives with respect to coordination numbers. Parameters ---------- cn_i : float32 Coordination number of atom i cn_j : float32 Coordination number of atom j c6ab_mat : wp.array2d(dtype=float32) C6 reference values [5, 5] for this element pair cnref_i_mat : wp.array2d(dtype=float32) CN reference grid [5, 5] for atom i cnref_j_mat : wp.array2d(dtype=float32) CN reference grid [5, 5] for atom j k3 : float32 Gaussian width parameter (typically -4.0) Returns ------- c6_ij : float32 Interpolated C6 coefficient dC6_dCNi : float32 Derivative :math:`\\partial C_6/\\partial \text{CN}_i` dC6_dCNj : float32 Derivative :math:`\\partial C_6/\\partial \text{CN}_j` Notes ----- The Gaussian weights are: .. math:: L_{pq} = \\exp\\left(-k_3 \\left[(\\text{CN}_i - \\text{CN}_{\\text{ref},i}[p,q])^2 + (\\text{CN}_j - \\text{CN}_{\\text{ref},j}[p,q])^2\\right]\\right) The interpolated C6 and derivatives are: .. math:: C_6 = \\frac{\\sum_{pq} C_6^{\\text{ref}}[p,q] L_{pq}}{\\sum_{pq} L_{pq}} \\frac{\\partial C_6}{\\partial \\text{CN}_i} = \\frac{2k_3}{w} (z_{d_i} - C_6 w_{d_i}) where accumulators :math:`w`, :math:`z`, :math:`w_{d_i}`, :math:`z_{d_i}` are computed in a single pass over the 5x5 grid. See Also -------- :func:`_direct_forces_and_dE_dCN_kernel_nm` : Calls this function for C6 coefficient interpolation (neighbor matrix) :func:`_direct_forces_and_dE_dCN_kernel_nl` : Calls this function for C6 coefficient interpolation (neighbor list) """ # log-sum-exp trick to avoid numerical instability max_exp = wp.float(-1e20) for p in range(5): for q in range(5): c6_val = c6ab_mat[p, q] if c6_val == 0.0: # NOSONAR (S1244) "gpu kernel" continue cnref_i = cnref_i_mat[p, q] cnref_j = cnref_j_mat[q, p] di = cn_i - cnref_i dj = cn_j - cnref_j exp_arg = k3 * (di * di + dj * dj) if exp_arg > max_exp: max_exp = exp_arg w = float(0.0) z = float(0.0) w_di = float(0.0) w_dj = float(0.0) z_di = float(0.0) z_dj = float(0.0) for p in range(5): for q in range(5): c6_val = c6ab_mat[p, q] if c6_val == 0.0: # NOSONAR (S1244) "gpu kernel" continue cnref_i = cnref_i_mat[p, q] cnref_j = cnref_j_mat[q, p] # Note transpose indexing di = cn_i - cnref_i dj = cn_j - cnref_j # Compute exponent argument and skip negligible contributions exp_arg = k3 * (di * di + dj * dj) - max_exp if exp_arg < -12.0: continue L = wp.exp(exp_arg) w += L z += c6_val * L w_di += L * di w_dj += L * dj z_di += c6_val * L * di z_dj += c6_val * L * dj eps_w = 1e-12 if w > eps_w: w_inv = 1.0 / w c6_ij = z * w_inv s_i = z_di - c6_ij * w_di s_j = z_dj - c6_ij * w_dj k3_w_w_inv = (2.0 * k3) * w_inv dC6_dCNi = k3_w_w_inv * s_i # NOSONAR (S125) "math formula" dC6_dCNj = k3_w_w_inv * s_j # NOSONAR (S125) "math formula" return c6_ij, dC6_dCNi, dC6_dCNj else: return 0.0, 0.0, 0.0 @wp.func def _compute_distance_vector_pbc( pos_i: Any, pos_j: Any, cartesian_shift: Any, periodic: bool, compute_vectors: bool, ) -> tuple[wp.float32, wp.float32, wp.vec3f, wp.vec3f]: """ Compute distance with optional PBC and vector outputs. Parameters ---------- pos_i, pos_j : vec3 Atomic positions cartesian_shift : vec3 PBC shift (ignored if periodic=False) periodic : bool Apply PBC shift compute_vectors : bool If True, compute r_hat; if False, r_hat returns zero vector Returns ------- r : float32 Distance r_inv : float32 Inverse distance (0 if r < 1e-12) r_hat : vec3f Unit vector (zero vec if compute_vectors=False or r < 1e-12) r_ij : vec3f Distance vector (always returned) """ if periodic: r_ij_native = (pos_j - pos_i) + cartesian_shift else: r_ij_native = pos_j - pos_i r_ij = wp.vec3f( wp.float32(r_ij_native[0]), wp.float32(r_ij_native[1]), wp.float32(r_ij_native[2]), ) r = wp.length(r_ij) if r < 1e-12: return r, wp.float32(0.0), wp.vec3f(0.0, 0.0, 0.0), r_ij r_inv = 1.0 / r if compute_vectors: r_hat = r_ij * r_inv return r, r_inv, r_hat, r_ij else: return r, r_inv, wp.vec3f(0.0, 0.0, 0.0), r_ij @wp.func def _cn_counting( r_inv: wp.float32, rcov_i: wp.float32, rcov_j: wp.float32, k1: wp.float32, compute_derivative: bool, ) -> tuple[wp.float32, wp.float32]: """ Compute CN counting function with optional derivative. Parameters ---------- r_inv : float32 Inverse distance rcov_i, rcov_j : float32 Covalent radii k1 : float32 Steepness parameter compute_derivative : bool If True, compute dCN_dr; if False, returns None Returns ------- f_cn : float32 Counting function value dCN_dr : float32 Derivative (zero if compute_derivative=False) """ rcov_ij = rcov_i + rcov_j rcov_r_inv = rcov_ij * r_inv f_cn = 1.0 / (1.0 + wp.exp(-k1 * (rcov_r_inv - 1.0))) if compute_derivative: dCN_dr = -f_cn * (1.0 - f_cn) * k1 * rcov_r_inv * r_inv # NOSONAR (S125) return f_cn, dCN_dr else: return f_cn, wp.float32(0.0) @wp.func def _bj_damping( r: wp.float32, r4r2_i: wp.float32, r4r2_j: wp.float32, a1: wp.float32, a2: wp.float32, s6: wp.float32, s8: wp.float32, ) -> tuple[wp.float32, wp.float32, wp.float32, wp.float32, wp.float32, wp.float32]: """ Compute Becke-Johnson damping. Returns ------- damp_sum, r4r2_ij, r6, r4, den6_inv, den8_inv : float32 """ r4r2_ij = 3.0 * r4r2_i * r4r2_j r0 = a1 * wp.sqrt(r4r2_ij) + a2 r2 = r * r r4 = r2 * r2 r6 = r4 * r2 r8 = r4 * r4 r0_2 = r0 * r0 r0_4 = r0_2 * r0_2 r0_6 = r0_4 * r0_2 r0_8 = r0_4 * r0_4 den6 = r6 + r0_6 den8 = r8 + r0_8 den6_inv = 1.0 / den6 den8_inv = 1.0 / den8 damp_6 = s6 * den6_inv damp_8 = s8 * r4r2_ij * den8_inv damp_sum = damp_6 + damp_8 return damp_sum, r4r2_ij, r6, r4, den6_inv, den8_inv @wp.func def _dispersion_energy_force( c6_ij: wp.float32, r: wp.float32, r_hat: wp.vec3f, damp_sum: wp.float32, r4r2_ij: wp.float32, r6: wp.float32, r4: wp.float32, den6_inv: wp.float32, den8_inv: wp.float32, s6: wp.float32, s8: wp.float32, s5_smoothing_on: wp.float32, s5_smoothing_off: wp.float32, inv_w: wp.float32, ) -> tuple[wp.float32, wp.vec3f]: """ Compute dispersion energy and direct force with S5 switching. Returns ------- e_ij_sw : float32 Smoothed energy F_direct : vec3f Direct force vector """ e_ij = -c6_ij * damp_sum r5 = r4 * r r7 = r6 * r dD6_dr = -6.0 * s6 * r5 * den6_inv * den6_inv # NOSONAR (S125) "math formula" dD8_dr = -8.0 * s8 * r4r2_ij * r7 * den8_inv * den8_inv # NOSONAR (S125) dE_dr_direct = -c6_ij * (dD6_dr + dD8_dr) # NOSONAR (S125) "math formula" sw, dsw_dr = _s5_switch(r, s5_smoothing_on, s5_smoothing_off, inv_w) e_ij_sw = e_ij * sw dE_dr_direct_sw = sw * dE_dr_direct + e_ij * dsw_dr # NOSONAR (S125) F_direct = dE_dr_direct_sw * r_hat # NOSONAR (S125) "math formula" return e_ij_sw, F_direct @wp.func def _unit_shift_to_cartesian( unit_shift: wp.vec3i, cell_mat: Any, ) -> Any: """Convert integer unit cell shift to Cartesian coordinates.""" unit_shift_float = wp.vec( cell_mat[0].dtype(unit_shift[0]), cell_mat[0].dtype(unit_shift[1]), cell_mat[0].dtype(unit_shift[2]), ) return unit_shift_float * cell_mat # ============================================================================== # Kernels # ============================================================================== @wp.kernel(enable_backward=False) def _compute_cartesian_shifts( cell: wp.array(dtype=Any), unit_shifts: wp.array2d(dtype=wp.vec3i), neighbor_matrix: wp.array2d(dtype=wp.int32), batch_idx: wp.array(dtype=wp.int32), fill_value: wp.int32, cartesian_shifts: wp.array2d(dtype=Any), ): """ Convert unit cell shifts to Cartesian coordinates for periodic boundaries. For each neighbor in the neighbor matrix, this kernel computes the Cartesian shift vector that should be applied to atom j's position to obtain its periodic image closest to atom i. Parameters ---------- cell : wp.array3d(dtype=float32) Unit cell lattice vectors [num_systems, 3, 3]. Convention: cell[s, i, :] is the i-th lattice vector for system s (row vectors). Units should match position coordinates. unit_shifts : wp.array2d(dtype=vec3i) Integer unit cell shifts [num_atoms, max_neighbors] as vec3i neighbor_matrix : wp.array2d(dtype=int32) Neighbor indices [num_atoms, max_neighbors]. See module docstring for more details. batch_idx : wp.array(dtype=int32) System index [num_atoms] for each atom fill_value : int32 Value indicating padding in neighbor_matrix (typically num_atoms) cartesian_shifts : wp.array2d(dtype=vec3f) Output: Cartesian shift vectors [num_atoms, max_neighbors] as vec3 in same units as cell vectors Notes ----- The Cartesian shift is computed as: .. math:: \\mathbf{s} = n_a \\mathbf{a} + n_b \\mathbf{b} + n_c \\mathbf{c} where :math:`\\mathbf{a}, \\mathbf{b}, \\mathbf{c}` are lattice vectors and :math:`n_a, n_b, n_c` are integer shifts. The system ID is obtained from atom i's batch index. Launch with dim=(num_atoms, max_neighbors) (one thread per atom-neighbor pair). See Also -------- :func:`_cn_kernel_nm` : Pass 1 - Uses computed Cartesian shifts for PBC (neighbor matrix) :func:`_cn_kernel_nl` : Pass 1 - Uses computed Cartesian shifts for PBC (neighbor list) :func:`_direct_forces_and_dE_dCN_kernel_nm` : Pass 2 - Uses computed Cartesian shifts for PBC (neighbor matrix) :func:`_direct_forces_and_dE_dCN_kernel_nl` : Pass 2 - Uses computed Cartesian shifts for PBC (neighbor list) :func:`_cn_forces_contrib_kernel_nm` : Pass 3 - Uses computed Cartesian shifts for PBC (neighbor matrix) :func:`_cn_forces_contrib_kernel_nl` : Pass 3 - Uses computed Cartesian shifts for PBC (neighbor list) :func:`dftd3` : High-level wrapper that orchestrates all passes """ atom_i, neighbor_idx = wp.tid() max_neighbors = neighbor_matrix.shape[1] if neighbor_idx >= max_neighbors: return atom_j = neighbor_matrix[atom_i, neighbor_idx] if atom_j >= fill_value: return system_id = batch_idx[atom_i] cell_mat = cell[system_id] unit_shift = unit_shifts[atom_i, neighbor_idx] cartesian_shifts[atom_i, neighbor_idx] = _unit_shift_to_cartesian( unit_shift, cell_mat ) @wp.kernel(enable_backward=False) def _cn_kernel_nm( positions: wp.array(dtype=Any), numbers: wp.array(dtype=wp.int32), neighbor_matrix: wp.array2d(dtype=wp.int32), cartesian_shifts: wp.array2d(dtype=Any), covalent_radii: wp.array(dtype=wp.float32), k1: wp.float32, fill_value: wp.int32, periodic: bool, coord_num: wp.array(dtype=wp.float32), ): """ Compute coordination numbers using geometric counting function. This kernel computes the coordination number (CN) for each atom based on a smooth counting function that depends on interatomic distances and covalent radii. Supports periodic boundary conditions via Cartesian shifts. Algorithm --------- For each atom i, iterate over its neighbors (neighbor matrix format) and accumulate: .. math:: \\text{CN}_i = \\sum_{j \\in \\text{neighbors}(i)} f(r_{ij}) f(r) = \\frac{1}{1 + \\exp\\left[k_1\\left(\\frac{r_{\\text{cov}}}{r} - 1\\right)\\right]} where :math:`r_{\\text{cov}} = r_{\\text{cov}}[Z_i] + r_{\\text{cov}}[Z_j]`. The counting function smoothly transitions from 1 (bonded) to 0 (non-bonded). Parameters ---------- positions : wp.array(dtype=vec3f) Atomic coordinates [num_atoms] numbers : wp.array(dtype=int32) Atomic numbers [num_atoms] neighbor_matrix : wp.array2d(dtype=int32) Neighbor indices [num_atoms, max_neighbors]. See module docstring for more details. cartesian_shifts : wp.array2d(dtype=vec3f) Cartesian shifts [num_atoms, max_neighbors] as vec3 for PBC (ignored if periodic=False), in same units as positions covalent_radii : wp.array(dtype=float32) Covalent radii [max_Z+1] indexed by atomic number, in same units as positions k1 : float32 Steepness parameter for counting function (typically 16.0 1/Bohr) fill_value : int32 Value indicating padding in neighbor_matrix (typically num_atoms) periodic : bool If True, apply PBC using cartesian_shifts; if False, non-periodic coord_num : wp.array(dtype=float32) Output: coordination numbers [num_atoms] (dimensionless) Notes ----- - Launch with dim=num_atoms (one thread per atom) - Each thread iterates over all neighbors and accumulates CN in a local register - Padding atoms indicated by numbers[i] == 0 are skipped - Neighbor entries with j >= fill_value are padding and are skipped See Also -------- :func:`_compute_cartesian_shifts` : Pass 0 - Computes Cartesian shifts for PBC :func:`_direct_forces_and_dE_dCN_kernel_nm` : Pass 2 - Uses coordination numbers computed here :func:`dftd3` : High-level wrapper that orchestrates all passes """ atom_i = wp.tid() if atom_i >= numbers.shape[0]: return # skip padding atoms if numbers[atom_i] == 0: return max_neighbors = neighbor_matrix.shape[1] pos_i = positions[atom_i] rcov_i = covalent_radii[numbers[atom_i]] # Accumulate coordination number in local register cn_acc = wp.float32(0.0) for neighbor_idx in range(max_neighbors): atom_j = neighbor_matrix[atom_i, neighbor_idx] if atom_j >= fill_value: continue # skip padding if numbers[atom_j] == 0: continue # Compute distance with optional PBC shift r, r_inv, r_hat, r_ij = _compute_distance_vector_pbc( pos_i, positions[atom_j], cartesian_shifts[atom_i, neighbor_idx], periodic, False, ) if r < 1e-12: continue # Compute coordination number contribution f_cn, dCN_dr = _cn_counting( r_inv, rcov_i, covalent_radii[numbers[atom_j]], k1, False ) cn_acc += f_cn # Write final coordination number once coord_num[atom_i] = cn_acc @wp.kernel(enable_backward=False) def _direct_forces_and_dE_dCN_kernel_nm( # NOSONAR (S1542) "math formula" positions: wp.array(dtype=Any), numbers: wp.array(dtype=wp.int32), neighbor_matrix: wp.array2d(dtype=wp.int32), cartesian_shifts: wp.array2d(dtype=Any), coord_num: wp.array(dtype=wp.float32), r4r2: wp.array(dtype=wp.float32), c6_reference: wp.array4d(dtype=wp.float32), coord_num_ref: wp.array4d(dtype=wp.float32), k3: wp.float32, a1: wp.float32, a2: wp.float32, s6: wp.float32, s8: wp.float32, s5_smoothing_on: wp.float32, s5_smoothing_off: wp.float32, inv_w: wp.float32, fill_value: wp.int32, periodic: bool, batch_idx: wp.array(dtype=wp.int32), compute_virial: bool, dE_dCN: wp.array(dtype=wp.float32), # NOSONAR (S125) "math formula" forces: wp.array(dtype=wp.vec3f), energy: wp.array(dtype=wp.float32), virial: wp.array(dtype=Any), ): """ Pass 2: Compute direct forces, energy, and accumulate dE/dCN per atom. Computes dispersion energy and forces at constant CN, and accumulates dE/dCN contributions for each atom for use in Pass 3. Parameters ---------- positions : wp.array(dtype=vec3f) Atomic coordinates [num_atoms] numbers : wp.array(dtype=int32) Atomic numbers [num_atoms] neighbor_matrix : wp.array2d(dtype=int32) Neighbor indices [num_atoms, max_neighbors]. See module docstring for more details. cartesian_shifts : wp.array2d(dtype=vec3f) Cartesian shifts [num_atoms, max_neighbors] as vec3 for PBC, in same units as positions coord_num : wp.array(dtype=float32) Coordination numbers [num_atoms] from Pass 1 (dimensionless) r4r2 : wp.array(dtype=float32) <r⁴>/<r²> expectation values [max_Z+1] (dimensionless) c6_reference : wp.array4d(dtype=float32) C6 reference [max_Z+1, max_Z+1, 5, 5] in energy x distance^6 units coord_num_ref : wp.array4d(dtype=float32) CN reference grid [max_Z+1, max_Z+1, 5, 5] (dimensionless) k3 : float32 CN interpolation width (typically -4.0, dimensionless) a1, a2 : float32 Becke-Johnson damping parameters (a1 dimensionless, a2 in distance units) s6, s8 : float32 Scaling factors for C6 and C8 terms (dimensionless) s5_smoothing_on, s5_smoothing_off : float32 S5 switching radii in same units as positions inv_w : float32 Precomputed 1/(s5_off - s5_on) in inverse distance units fill_value : int32 Value indicating padding in neighbor_matrix (typically num_atoms) periodic : bool If True, apply PBC using cartesian_shifts batch_idx : wp.array(dtype=int32) System index [num_atoms] dE_dCN : wp.array(dtype=float32) Output: accumulated dE/dCN [num_atoms] in energy units forces : wp.array(dtype=vec3f) Output: direct forces [num_atoms] in energy/distance units energy : wp.array(dtype=float32) Output: system energies [num_systems] in energy units Notes ----- - Launch with dim=num_atoms (one thread per atom) - Each thread iterates over all neighbors and accumulates results in local registers - Direct forces are F = :math:`-(\\partial E/\\partial r)|_\text{CN}`, without chain rule term - dE_dCN[i] = :math:`\\sum_j \\partial E_{ij}/\\partial \text{CN}_i` accumulated over all pairs containing atom i - Neighbor entries with j >= fill_value are padding and are skipped See Also -------- :func:`_compute_cartesian_shifts` : Pass 0 - Computes Cartesian shifts for PBC :func:`_cn_kernel_nm` : Pass 1 - Computes coordination numbers used here :func:`_cn_forces_contrib_kernel_nm` : Pass 3 - Uses dE/dCN values accumulated here :func:`_c6ab_interpolate` : Called to interpolate C6 coefficients :func:`_s5_switch` : Called for cutoff smoothing :func:`dftd3` : High-level wrapper that orchestrates all passes """ atom_i = wp.tid() if atom_i >= numbers.shape[0]: return # skip padding atoms if numbers[atom_i] == 0: return max_neighbors = neighbor_matrix.shape[1] pos_i = positions[atom_i] cn_i = coord_num[atom_i] z_i = numbers[atom_i] r4r2_i = r4r2[z_i] # Accumulate in local registers (using float64 for better precision) F_acc = wp.vec3d() # NOSONAR (S117) "math formula" dE_dCN_acc = wp.float32(0.0) # NOSONAR (S117) "math formula" energy_acc = wp.float64(0.0) # Initialize virial accumulator if compute_virial: virial_acc = wp.mat33d() for neighbor_idx in range(max_neighbors): atom_j = neighbor_matrix[atom_i, neighbor_idx] if atom_j >= fill_value: continue # skip padding atoms if numbers[atom_j] == 0: continue # Geometry r, r_inv, r_hat, r_ij = _compute_distance_vector_pbc( pos_i, positions[atom_j], cartesian_shifts[atom_i, neighbor_idx], periodic, True, ) if r < 1e-12: continue cn_j = coord_num[atom_j] z_j = numbers[atom_j] # C6 interpolation c6ab_mat = c6_reference[z_i, z_j] cnref_i_mat = coord_num_ref[z_i, z_j] cnref_j_mat = coord_num_ref[z_j, z_i] c6_ij, dC6_dCNi, dC6_dCNj = _c6ab_interpolate( # NOSONAR (S125) "math formula" cn_i, cn_j, c6ab_mat, cnref_i_mat, cnref_j_mat, k3 ) if c6_ij < 1e-12: continue # BJ damping damp_sum, r4r2_ij, r6, r4, den6_inv, den8_inv = _bj_damping( r, r4r2_i, r4r2[z_j], a1, a2, s6, s8 ) # Energy and direct force e_ij_sw, F_direct = _dispersion_energy_force( c6_ij, r, r_hat, damp_sum, r4r2_ij, r6, r4, den6_inv, den8_inv, s6, s8, s5_smoothing_on, s5_smoothing_off, inv_w, ) # Accumulate in registers F_acc += wp.vec3d(F_direct) # NOSONAR (S117) "math formula" energy_acc += wp.float64(e_ij_sw) dE_dCN_acc += -damp_sum * dC6_dCNi # NOSONAR (S117) "math formula" # Accumulate virial if requested if compute_virial: virial_acc += wp.mat33d(wp.outer(F_direct, r_ij)) # Write final results once (atomic only for shared batch array) # Convert from float64 accumulation to float32 output forces[atom_i] = wp.vec3f(F_acc) dE_dCN[atom_i] = dE_dCN_acc wp.atomic_add(energy, batch_idx[atom_i], 0.5 * wp.float32(energy_acc)) # Add virial contribution with -0.5 scaling for correct sign and double counting if compute_virial: wp.atomic_add(virial, batch_idx[atom_i], -0.5 * wp.mat33f(virial_acc)) @wp.kernel(enable_backward=False) def _cn_forces_contrib_kernel_nm( positions: wp.array(dtype=Any), numbers: wp.array(dtype=wp.int32), neighbor_matrix: wp.array2d(dtype=wp.int32), cartesian_shifts: wp.array2d(dtype=Any), covalent_radii: wp.array(dtype=wp.float32), dE_dCN: wp.array(dtype=wp.float32), # NOSONAR (S125) "math formula" k1: wp.float32, fill_value: wp.int32, periodic: bool, batch_idx: wp.array(dtype=wp.int32), compute_virial: bool, forces: wp.array(dtype=wp.vec3f), virial: wp.array(dtype=Any), ): """ Pass 3: Add CN-dependent force contribution. Adds the CN-dependent term to forces computed in Pass 2. Computes distances and CN derivatives without repeating C6 interpolation and damping calculations. Parameters ---------- positions : wp.array(dtype=vec3f) Atomic coordinates [num_atoms] in consistent distance units numbers : wp.array(dtype=int32) Atomic numbers [num_atoms] neighbor_matrix : wp.array2d(dtype=int32) Neighbor indices [num_atoms, max_neighbors]. See module docstring for more details. cartesian_shifts : wp.array2d(dtype=vec3f) Cartesian shifts [num_atoms, max_neighbors] as vec3 for PBC, in same units as positions covalent_radii : wp.array(dtype=float32) Covalent radii [max_Z+1] in same units as positions dE_dCN : wp.array(dtype=float32) Precomputed dE/dCN [num_atoms] from Pass 2 in energy units k1 : float32 CN counting steepness in inverse distance units (typically 16.0 1/Bohr) fill_value : int32 Value indicating padding in neighbor_matrix (typically num_atoms) periodic : bool If True, apply PBC using cartesian_shifts forces : wp.array(dtype=vec3f) Input/Output: add chain term to direct forces [num_atoms] in energy/distance units Notes ----- - Launch with dim=num_atoms (one thread per atom) - Each thread iterates over all neighbors and accumulates results in local registers - Skips C6 interpolation and damping calculations - Uses precomputed dE_dCN[i] = :math:`\\sum_k \\partial E_{ik}/\\partial \text{CN}_i` from all pairs - Neighbor entries with j >= fill_value are padding and are skipped See Also -------- :func:`_compute_cartesian_shifts` : Pass 0 - Computes Cartesian shifts for PBC :func:`_direct_forces_and_dE_dCN_kernel_nm` : Pass 2 - Computes dE/dCN values used here :func:`dftd3` : High-level wrapper that orchestrates all passes """ atom_i = wp.tid() if atom_i >= numbers.shape[0]: return # skip padding atoms if numbers[atom_i] == 0: return max_neighbors = neighbor_matrix.shape[1] dE_dCN_i = dE_dCN[atom_i] # NOSONAR (S125) "math formula" pos_i = positions[atom_i] rcov_i = covalent_radii[numbers[atom_i]] # Accumulate force in local register (using float64 for better precision) F_chain_acc = wp.vec3d() # NOSONAR (S117) "math formula" # Initialize virial accumulator if compute_virial: virial_chain_acc = wp.mat33d() for neighbor_idx in range(max_neighbors): atom_j = neighbor_matrix[atom_i, neighbor_idx] if atom_j >= fill_value: continue if numbers[atom_j] == 0: continue # Distance r, r_inv, r_hat, r_ij = _compute_distance_vector_pbc( pos_i, positions[atom_j], cartesian_shifts[atom_i, neighbor_idx], periodic, True, ) if r < 1e-12: continue # CN derivative f_cn, dCN_dr = _cn_counting( r_inv, rcov_i, covalent_radii[numbers[atom_j]], k1, True ) # CN-dependent force contribution dE_dCN_j = dE_dCN[atom_j] # NOSONAR (S125) "math formula" dE_dr_chain = (dE_dCN_i + dE_dCN_j) * dCN_dr # NOSONAR (S125) "math formula" F_chain = dE_dr_chain * r_hat # NOSONAR (S125) "math formula" F_chain_acc += wp.vec3d(F_chain) # Accumulate virial if requested if compute_virial: virial_chain_acc += wp.mat33d(wp.outer(F_chain, r_ij)) # Add accumulated force to existing forces (direct read-modify-write) # Convert from float64 accumulation to float32 output forces[atom_i] = forces[atom_i] + wp.vec3f(F_chain_acc) # Add virial contribution with -0.5 scaling for correct sign and double counting if compute_virial: wp.atomic_add(virial, batch_idx[atom_i], -0.5 * wp.mat33f(virial_chain_acc)) # ============================================================================== # Neighbor List Kernels # ============================================================================== @wp.kernel(enable_backward=False) def _compute_cartesian_shifts_nl( cell: wp.array(dtype=Any), unit_shifts: wp.array(dtype=wp.vec3i), neighbor_ptr: wp.array(dtype=wp.int32), batch_idx: wp.array(dtype=wp.int32), cartesian_shifts: wp.array(dtype=Any), ): """ Convert unit cell shifts to Cartesian coordinates for CSR neighbor lists. For each edge in the CSR neighbor list, this kernel computes the Cartesian shift vector that should be applied to the destination atom's position. Parameters ---------- cell : wp.array(dtype=mat33) Unit cell lattice vectors [num_systems, 3, 3]. Convention: cell[s, i, :] is the i-th lattice vector for system s (row vectors). Units should match position coordinates. unit_shifts : wp.array(dtype=vec3i) Integer unit cell shifts [num_edges] as vec3i neighbor_ptr : wp.array(dtype=int32) CSR row pointers [num_atoms+1] batch_idx : wp.array(dtype=int32) System index [num_atoms] for each atom cartesian_shifts : wp.array(dtype=vec3) Output: Cartesian shift vectors [num_edges] as vec3 in same units as cell vectors Notes ----- Launch with dim=num_atoms (one thread per atom). Each thread processes all edges for that atom using the CSR pointers. See Also -------- :func:`_cn_kernel_nl` : Uses computed Cartesian shifts for PBC :func:`_direct_forces_and_dE_dCN_kernel_nl` : Uses computed Cartesian shifts for PBC :func:`_cn_forces_contrib_kernel_nl` : Uses computed Cartesian shifts for PBC """ atom_i = wp.tid() # Get number of atoms from batch_idx size if atom_i >= batch_idx.shape[0]: return system_id = batch_idx[atom_i] cell_mat = cell[system_id] # Get range of edges for this atom j_range_start = neighbor_ptr[atom_i] j_range_end = neighbor_ptr[atom_i + 1] # Convert all unit shifts for this atom's neighbors to Cartesian for edge_idx in range(j_range_start, j_range_end): unit_shift = unit_shifts[edge_idx] cartesian_shifts[edge_idx] = _unit_shift_to_cartesian(unit_shift, cell_mat) @wp.kernel(enable_backward=False) def _cn_kernel_nl( positions: wp.array(dtype=Any), numbers: wp.array(dtype=wp.int32), idx_j: wp.array(dtype=wp.int32), neighbor_ptr: wp.array(dtype=wp.int32), cartesian_shifts: wp.array(dtype=Any), covalent_radii: wp.array(dtype=wp.float32), k1: wp.float32, periodic: bool, coord_num: wp.array(dtype=wp.float32), ): """ Compute coordination numbers using CSR neighbor list format. Parameters ---------- positions : wp.array(dtype=vec3) Atomic coordinates [num_atoms] numbers : wp.array(dtype=int32) Atomic numbers [num_atoms] idx_j : wp.array(dtype=int32) Destination atom indices [num_edges] in CSR format neighbor_ptr : wp.array(dtype=int32) CSR row pointers [num_atoms+1] cartesian_shifts : wp.array(dtype=vec3) Cartesian shifts [num_edges] as vec3 for PBC covalent_radii : wp.array(dtype=float32) Covalent radii [max_Z+1] indexed by atomic number k1 : float32 Steepness parameter for counting function periodic : bool If True, apply PBC using cartesian_shifts coord_num : wp.array(dtype=float32) Output: coordination numbers [num_atoms] Notes ----- Launch with dim=num_atoms (one thread per atom). """ atom_i = wp.tid() if atom_i >= numbers.shape[0]: return # skip padding atoms if numbers[atom_i] == 0: return pos_i = positions[atom_i] rcov_i = covalent_radii[numbers[atom_i]] # Accumulate coordination number in local register cn_acc = wp.float32(0.0) # Iterate over neighbors using CSR pointers j_range_start = neighbor_ptr[atom_i] j_range_end = neighbor_ptr[atom_i + 1] for edge_idx in range(j_range_start, j_range_end): atom_j = idx_j[edge_idx] # skip padding atoms if numbers[atom_j] == 0: continue # Compute distance with optional PBC shift r, r_inv, r_hat, r_ij = _compute_distance_vector_pbc( pos_i, positions[atom_j], cartesian_shifts[edge_idx], periodic, False ) if r < 1e-12: continue # Compute coordination number contribution f_cn, dCN_dr = _cn_counting( r_inv, rcov_i, covalent_radii[numbers[atom_j]], k1, False ) cn_acc += f_cn # Write final coordination number once coord_num[atom_i] = cn_acc @wp.kernel(enable_backward=False) def _direct_forces_and_dE_dCN_kernel_nl( # NOSONAR (S1542) "math formula" positions: wp.array(dtype=Any), numbers: wp.array(dtype=wp.int32), idx_j: wp.array(dtype=wp.int32), neighbor_ptr: wp.array(dtype=wp.int32), cartesian_shifts: wp.array(dtype=Any), coord_num: wp.array(dtype=wp.float32), r4r2: wp.array(dtype=wp.float32), c6_reference: wp.array4d(dtype=wp.float32), coord_num_ref: wp.array4d(dtype=wp.float32), k3: wp.float32, a1: wp.float32, a2: wp.float32, s6: wp.float32, s8: wp.float32, s5_smoothing_on: wp.float32, s5_smoothing_off: wp.float32, inv_w: wp.float32, periodic: bool, batch_idx: wp.array(dtype=wp.int32), compute_virial: bool, dE_dCN: wp.array(dtype=wp.float32), # NOSONAR (S125) "math formula" forces: wp.array(dtype=wp.vec3f), energy: wp.array(dtype=wp.float32), virial: wp.array(dtype=Any), ): """ Pass 2: Compute direct forces, energy, and accumulate dE/dCN using CSR neighbor list. Notes ----- Launch with dim=num_atoms (one thread per atom). """ atom_i = wp.tid() if atom_i >= numbers.shape[0]: return # skip padding atoms if numbers[atom_i] == 0: return pos_i = positions[atom_i] cn_i = coord_num[atom_i] z_i = numbers[atom_i] r4r2_i = r4r2[z_i] # Accumulate in local registers (using float64 for better precision) F_acc = wp.vec3d() # NOSONAR (S117) "math formula" dE_dCN_acc = wp.float32(0.0) # NOSONAR (S117) "math formula" energy_acc = wp.float64(0.0) # Initialize virial accumulator if compute_virial: virial_acc = wp.mat33d() # Iterate over neighbors using CSR pointers j_range_start = neighbor_ptr[atom_i] j_range_end = neighbor_ptr[atom_i + 1] for edge_idx in range(j_range_start, j_range_end): atom_j = idx_j[edge_idx] # skip padding atoms if numbers[atom_j] == 0: continue # Geometry r, r_inv, r_hat, r_ij = _compute_distance_vector_pbc( pos_i, positions[atom_j], cartesian_shifts[edge_idx], periodic, True ) if r < 1e-12: continue cn_j = coord_num[atom_j] z_j = numbers[atom_j] # C6 interpolation c6ab_mat = c6_reference[z_i, z_j] cnref_i_mat = coord_num_ref[z_i, z_j] cnref_j_mat = coord_num_ref[z_j, z_i] c6_ij, dC6_dCNi, dC6_dCNj = _c6ab_interpolate( # NOSONAR (S125) "math formula" cn_i, cn_j, c6ab_mat, cnref_i_mat, cnref_j_mat, k3 ) if c6_ij < 1e-12: continue # BJ damping damp_sum, r4r2_ij, r6, r4, den6_inv, den8_inv = _bj_damping( r, r4r2_i, r4r2[z_j], a1, a2, s6, s8 ) # Energy and direct force e_ij_sw, F_direct = _dispersion_energy_force( c6_ij, r, r_hat, damp_sum, r4r2_ij, r6, r4, den6_inv, den8_inv, s6, s8, s5_smoothing_on, s5_smoothing_off, inv_w, ) # Accumulate in registers F_acc += wp.vec3d(F_direct) energy_acc += wp.float64(e_ij_sw) dE_dCN_acc += -damp_sum * dC6_dCNi # NOSONAR (S117) "math formula" # Accumulate virial if requested if compute_virial: virial_acc += wp.mat33d(wp.outer(F_direct, r_ij)) # Write final results once (atomic only for shared batch array) # Convert from float64 accumulation to float32 output forces[atom_i] = wp.vec3f(F_acc) dE_dCN[atom_i] = wp.float32(dE_dCN_acc) wp.atomic_add(energy, batch_idx[atom_i], 0.5 * wp.float32(energy_acc)) # Add virial contribution with -0.5 scaling for correct sign and double counting if compute_virial: wp.atomic_add(virial, batch_idx[atom_i], -0.5 * wp.mat33f(virial_acc)) @wp.kernel(enable_backward=False) def _cn_forces_contrib_kernel_nl( positions: wp.array(dtype=Any), numbers: wp.array(dtype=wp.int32), idx_j: wp.array(dtype=wp.int32), neighbor_ptr: wp.array(dtype=wp.int32), cartesian_shifts: wp.array(dtype=Any), covalent_radii: wp.array(dtype=wp.float32), dE_dCN: wp.array(dtype=wp.float32), # NOSONAR (S125) "math formula" k1: wp.float32, periodic: bool, batch_idx: wp.array(dtype=wp.int32), compute_virial: bool, forces: wp.array(dtype=wp.vec3f), virial: wp.array(dtype=Any), ): """ Pass 3: Add CN-dependent force contribution using CSR neighbor list. Notes ----- Launch with dim=num_atoms (one thread per atom). """ atom_i = wp.tid() if atom_i >= numbers.shape[0]: return # skip padding atoms if numbers[atom_i] == 0: return dE_dCN_i = dE_dCN[atom_i] # NOSONAR (S125) "math formula" pos_i = positions[atom_i] rcov_i = covalent_radii[numbers[atom_i]] # Accumulate force in local register (using float64 for better precision) F_chain_acc = wp.vec3d() # NOSONAR (S117) "math formula" # Initialize virial accumulator if compute_virial: virial_chain_acc = wp.mat33d() # Iterate over neighbors using CSR pointers j_range_start = neighbor_ptr[atom_i] j_range_end = neighbor_ptr[atom_i + 1] for edge_idx in range(j_range_start, j_range_end): atom_j = idx_j[edge_idx] if numbers[atom_j] == 0: continue # Distance r, r_inv, r_hat, r_ij = _compute_distance_vector_pbc( pos_i, positions[atom_j], cartesian_shifts[edge_idx], periodic, True ) if r < 1e-12: continue # CN derivative f_cn, dCN_dr = _cn_counting( r_inv, rcov_i, covalent_radii[numbers[atom_j]], k1, True ) # CN-dependent force contribution dE_dCN_j = dE_dCN[atom_j] # NOSONAR (S125) "math formula" dE_dr_chain = (dE_dCN_i + dE_dCN_j) * dCN_dr # NOSONAR (S125) "math formula" F_chain = dE_dr_chain * r_hat # NOSONAR (S125) "math formula" F_chain_acc += wp.vec3d(F_chain) # Accumulate virial if requested if compute_virial: virial_chain_acc += wp.mat33d(wp.outer(F_chain, r_ij)) # Add accumulated force to existing forces (direct read-modify-write) # Convert from float64 accumulation to float32 output forces[atom_i] = forces[atom_i] + wp.vec3f(F_chain_acc) # Add virial contribution with -0.5 scaling for correct sign and double counting if compute_virial: wp.atomic_add(virial, batch_idx[atom_i], -0.5 * wp.mat33f(virial_chain_acc)) # ============================================================================== # Kernel Overload Registration # ============================================================================== # Type constants for overload generation T = [wp.float32, wp.float64] V = [wp.vec3f, wp.vec3d] M = [wp.mat33f, wp.mat33d] # Overload dictionaries keyed by scalar type _compute_cartesian_shifts_overload = {} _cn_kernel_nm_overload = {} _direct_forces_and_dE_dCN_kernel_nm_overload = {} _cn_forces_contrib_kernel_nm_overload = {} # Neighbor list kernel overload dictionaries (CSR format) _compute_cartesian_shifts_nl_overload = {} _cn_kernel_nl_overload = {} _direct_forces_and_dE_dCN_kernel_nl_overload = {} _cn_forces_contrib_kernel_nl_overload = {} # Register overloads for all kernel variants for t, v, m in zip(T, V, M): _compute_cartesian_shifts_overload[t] = wp.overload( _compute_cartesian_shifts, [ wp.array(dtype=m), wp.array2d(dtype=wp.vec3i), wp.array2d(dtype=wp.int32), wp.array(dtype=wp.int32), wp.int32, wp.array2d(dtype=v), ], ) _cn_kernel_nm_overload[t] = wp.overload( _cn_kernel_nm, [ wp.array(dtype=v), wp.array(dtype=wp.int32), wp.array2d(dtype=wp.int32), wp.array2d(dtype=v), wp.array(dtype=wp.float32), wp.float32, wp.int32, wp.bool, wp.array(dtype=wp.float32), ], ) _direct_forces_and_dE_dCN_kernel_nm_overload[t] = wp.overload( _direct_forces_and_dE_dCN_kernel_nm, [ wp.array(dtype=v), wp.array(dtype=wp.int32), wp.array2d(dtype=wp.int32), wp.array2d(dtype=v), wp.array(dtype=wp.float32), wp.array(dtype=wp.float32), wp.array4d(dtype=wp.float32), wp.array4d(dtype=wp.float32), wp.float32, wp.float32, wp.float32, wp.float32, wp.float32, wp.float32, wp.float32, wp.float32, wp.int32, wp.bool, wp.array(dtype=wp.int32), wp.bool, wp.array(dtype=wp.float32), wp.array(dtype=wp.vec3f), wp.array(dtype=wp.float32), wp.array(dtype=wp.mat33f), ], ) _cn_forces_contrib_kernel_nm_overload[t] = wp.overload( _cn_forces_contrib_kernel_nm, [ wp.array(dtype=v), wp.array(dtype=wp.int32), wp.array2d(dtype=wp.int32), wp.array2d(dtype=v), wp.array(dtype=wp.float32), wp.array(dtype=wp.float32), wp.float32, wp.int32, wp.bool, wp.array(dtype=wp.int32), wp.bool, wp.array(dtype=wp.vec3f), wp.array(dtype=wp.mat33f), ], ) # Neighbor list kernel overloads (CSR format) _compute_cartesian_shifts_nl_overload[t] = wp.overload( _compute_cartesian_shifts_nl, [ wp.array(dtype=m), wp.array(dtype=wp.vec3i), wp.array(dtype=wp.int32), wp.array(dtype=wp.int32), wp.array(dtype=v), ], ) _cn_kernel_nl_overload[t] = wp.overload( _cn_kernel_nl, [ wp.array(dtype=v), wp.array(dtype=wp.int32), wp.array(dtype=wp.int32), wp.array(dtype=wp.int32), wp.array(dtype=v), wp.array(dtype=wp.float32), wp.float32, wp.bool, wp.array(dtype=wp.float32), ], ) _direct_forces_and_dE_dCN_kernel_nl_overload[t] = wp.overload( _direct_forces_and_dE_dCN_kernel_nl, [ wp.array(dtype=v), wp.array(dtype=wp.int32), wp.array(dtype=wp.int32), wp.array(dtype=wp.int32), wp.array(dtype=v), wp.array(dtype=wp.float32), wp.array(dtype=wp.float32), wp.array4d(dtype=wp.float32), wp.array4d(dtype=wp.float32), wp.float32, wp.float32, wp.float32, wp.float32, wp.float32, wp.float32, wp.float32, wp.float32, wp.bool, wp.array(dtype=wp.int32), wp.bool, wp.array(dtype=wp.float32), wp.array(dtype=wp.vec3f), wp.array(dtype=wp.float32), wp.array(dtype=wp.mat33f), ], ) _cn_forces_contrib_kernel_nl_overload[t] = wp.overload( _cn_forces_contrib_kernel_nl, [ wp.array(dtype=v), wp.array(dtype=wp.int32), wp.array(dtype=wp.int32), wp.array(dtype=wp.int32), wp.array(dtype=v), wp.array(dtype=wp.float32), wp.array(dtype=wp.float32), wp.float32, wp.bool, wp.array(dtype=wp.int32), wp.bool, wp.array(dtype=wp.vec3f), wp.array(dtype=wp.mat33f), ], ) # ============================================================================== # PyTorch Wrapper # ============================================================================== @torch.library.custom_op( "nvalchemiops::dftd3_nm", mutates_args=("energy", "forces", "coord_num", "virial"), ) def _dftd3_nm_op( positions: torch.Tensor, numbers: torch.Tensor, neighbor_matrix: torch.Tensor, covalent_radii: torch.Tensor, r4r2: torch.Tensor, c6_reference: torch.Tensor, coord_num_ref: torch.Tensor, a1: float, a2: float, s8: float, energy: torch.Tensor, forces: torch.Tensor, coord_num: torch.Tensor, virial: torch.Tensor, k1: float = 16.0, k3: float = -4.0, s6: float = 1.0, s5_smoothing_on: float = 1e10, s5_smoothing_off: float = 1e10, fill_value: int | None = None, batch_idx: torch.Tensor | None = None, cell: torch.Tensor | None = None, neighbor_matrix_shifts: torch.Tensor | None = None, compute_virial: bool = False, device: str | None = None, ) -> None: """Internal custom op for DFT-D3(BJ) dispersion energy and forces computation. This is a low-level custom operator that performs DFT-D3(BJ) dispersion calculations using Warp kernels. Output tensors must be pre-allocated by the caller and are modified in-place. For most use cases, prefer the higher-level :func:`dftd3` wrapper function instead of calling this method directly. This function is torch.compile compatible. Parameters ---------- positions : torch.Tensor, shape (num_atoms, 3) Atomic coordinates as float32 or float64, in consistent distance units (conventionally Bohr) numbers : torch.Tensor, shape (num_atoms), dtype=int32 Atomic numbers neighbor_matrix : torch.Tensor, shape (num_atoms, max_neighbors), dtype=int32 Neighbor indices. See module docstring for format details. Padding entries have values >= fill_value. covalent_radii : torch.Tensor, shape (max_Z+1), dtype=float32 Covalent radii indexed by atomic number, in same units as positions r4r2 : torch.Tensor, shape (max_Z+1), dtype=float32 <r⁴>/<r²> expectation values for C8 computation (dimensionless) c6_reference : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32 C6 reference values in energy x distance^6 units coord_num_ref : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32 CN reference grid (dimensionless) a1 : float Becke-Johnson damping parameter 1 (functional-dependent, dimensionless) a2 : float Becke-Johnson damping parameter 2 (functional-dependent), in same units as positions s8 : float C8 term scaling factor (functional-dependent, dimensionless) energy : torch.Tensor, shape (num_systems,), dtype=float32 OUTPUT: Total dispersion energy. Must be pre-allocated. Units are energy (Hartree when using standard D3 parameters). forces : torch.Tensor, shape (num_atoms, 3), dtype=float32 OUTPUT: Atomic forces. Must be pre-allocated. Units are energy/distance (Hartree/Bohr when using standard D3 parameters). coord_num : torch.Tensor, shape (num_atoms,), dtype=float32 OUTPUT: Coordination numbers (dimensionless). Must be pre-allocated. virial : torch.Tensor, shape (num_systems, 3, 3), dtype=float32 OUTPUT: Virial tensor. Must be pre-allocated. Units are energy (Hartree when using standard D3 parameters). k1 : float, optional CN counting function steepness parameter, in inverse distance units (typically 16.0 1/Bohr for atomic units) k3 : float, optional CN interpolation Gaussian width parameter (typically -4.0, dimensionless) s6 : float, optional C6 term scaling factor (typically 1.0, dimensionless) s5_smoothing_on : float, optional Distance where S5 switching begins, in same units as positions. Default: 1e10 s5_smoothing_off : float, optional Distance where S5 switching completes, in same units as positions. Default: 1e10 fill_value : int | None, optional Value indicating padding in neighbor_matrix. If None, defaults to num_atoms. batch_idx : torch.Tensor, shape (num_atoms,), dtype=int32, optional Batch indices. If None, all atoms are in a single system (batch 0). cell : torch.Tensor, shape (num_systems, 3, 3), dtype=float32 or float64, optional Unit cell lattice vectors for PBC, in same dtype and units as positions. neighbor_matrix_shifts : torch.Tensor, shape (num_atoms, max_neighbors, 3), dtype=int32, optional Integer unit cell shifts for PBC. device : str, optional Warp device string (e.g., 'cuda:0', 'cpu'). If None, inferred from positions. Returns ------- None Modifies input tensors in-place: energy, forces, coord_num, virial (if compute_virial=True) Notes ----- - All input tensors should use consistent units. Standard D3 parameters use atomic units (Bohr for distances, Hartree for energy). - Float32 or float64 precision for positions and cell; outputs always float32 - Padding atoms indicated by numbers[i] == 0 - **Two-body only**: Computes pairwise C6 and C8 dispersion terms; three-body Axilrod-Teller-Muto (ATM/C9) terms are not included - Bulk stress tensor can be obtained by dividing virial by system volume. See Also -------- :func:`dftd3` : Higher-level wrapper that handles allocation :class:`D3Parameters` : Dataclass for organizing DFT-D3 reference parameters """ # Ensure all parameters are on correct device/dtype covalent_radii = covalent_radii.to(device=positions.device, dtype=torch.float32) r4r2 = r4r2.to(device=positions.device, dtype=torch.float32) c6_reference = c6_reference.to(device=positions.device, dtype=torch.float32) coord_num_ref = coord_num_ref.to(device=positions.device, dtype=torch.float32) # Get shapes num_atoms = positions.size(0) max_neighbors = neighbor_matrix.size(1) if num_atoms > 0 else 0 # Set fill_value if not provided if fill_value is None: fill_value = num_atoms # Handle empty case if num_atoms == 0: return # Infer device from positions if not provided if device is None: device = str(positions.device) # Zero output tensors energy.zero_() forces.zero_() coord_num.zero_() virial.zero_() # Detect dtype and set appropriate Warp vector types if positions.dtype == torch.float64: vec_dtype = wp.vec3d mat_dtype = wp.mat33d wp_dtype = wp.float64 else: vec_dtype = wp.vec3f mat_dtype = wp.mat33f wp_dtype = wp.float32 # Create batch indices if not provided (single system) if batch_idx is None: batch_idx = torch.zeros(num_atoms, dtype=torch.int32, device=positions.device) # Convert PyTorch tensors to Warp arrays (detach positions) positions_wp = wp.from_torch(positions.detach(), dtype=vec_dtype, return_ctype=True) numbers_wp = wp.from_torch(numbers, dtype=wp.int32, return_ctype=True) neighbor_matrix_wp = wp.from_torch( neighbor_matrix, dtype=wp.int32, return_ctype=True ) batch_idx_wp = wp.from_torch(batch_idx, dtype=wp.int32, return_ctype=True) # Convert parameter tensors to Warp arrays (ensure float32) covalent_radii_wp = wp.from_torch( covalent_radii.to(dtype=torch.float32, device=positions.device), dtype=wp.float32, return_ctype=True, ) r4r2_wp = wp.from_torch( r4r2.to(dtype=torch.float32, device=positions.device), dtype=wp.float32, return_ctype=True, ) c6_reference_wp = wp.from_torch( c6_reference.to(dtype=torch.float32, device=positions.device), dtype=wp.float32, return_ctype=True, ) coord_num_ref_wp = wp.from_torch( coord_num_ref.to(dtype=torch.float32, device=positions.device), dtype=wp.float32, return_ctype=True, ) # Precompute inv_w for S5 switching if s5_smoothing_off > s5_smoothing_on: inv_w = 1.0 / (s5_smoothing_off - s5_smoothing_on) else: inv_w = 0.0 # Pass 0: Handle PBC - determine if periodic and compute cartesian shifts if cell is not None and neighbor_matrix_shifts is not None: # Case 1: Cell and neighbor_matrix_shifts provided - compute cartesian shifts periodic = True # Detach and convert cell cell_wp = wp.from_torch( cell.detach().to(dtype=positions.dtype, device=positions.device), dtype=mat_dtype, return_ctype=True, ) # Convert unit shifts to vec3i format unit_shifts_wp = wp.from_torch( neighbor_matrix_shifts.to(dtype=torch.int32, device=positions.device), dtype=wp.vec3i, return_ctype=True, ) # Create output array for cartesian shifts [num_atoms, max_neighbors, 3] cartesian_shifts = torch.empty( (num_atoms, max_neighbors, 3), dtype=positions.dtype, device=positions.device, ) cartesian_shifts_wp = wp.from_torch( cartesian_shifts, dtype=vec_dtype, return_ctype=True ) wp.launch( kernel=_compute_cartesian_shifts_overload[wp_dtype], dim=(num_atoms, max_neighbors), inputs=[ cell_wp, unit_shifts_wp, neighbor_matrix_wp, batch_idx_wp, wp.int32(fill_value), ], outputs=[cartesian_shifts_wp], device=device, ) else: # Case 2: No PBC - create zero shifts array (not used but need correct shape for kernel) periodic = False cartesian_shifts = torch.zeros( (num_atoms, max_neighbors, 3), dtype=positions.dtype, device=positions.device, ) cartesian_shifts_wp = wp.from_torch( cartesian_shifts, dtype=vec_dtype, return_ctype=True ) # Convert pre-allocated output arrays to Warp coord_num_wp = wp.from_torch(coord_num, dtype=wp.float32, return_ctype=True) forces_wp = wp.from_torch(forces, dtype=wp.vec3f, return_ctype=True) energy_wp = wp.from_torch(energy, dtype=wp.float32, return_ctype=True) virial_wp = wp.from_torch(virial, dtype=wp.mat33f, return_ctype=True) # Allocate dE_dCN array for chain rule computation (internal temporary) dE_dCN = torch.zeros( # NOSONAR (S125) "math formula" num_atoms, dtype=torch.float32, device=positions.device ) dE_dCN_wp = wp.from_torch( # NOSONAR (S125) "math formula" dE_dCN, dtype=wp.float32, return_ctype=True ) # Pass 1: Compute coordination numbers wp.launch( kernel=_cn_kernel_nm_overload[wp_dtype], dim=num_atoms, inputs=[ positions_wp, numbers_wp, neighbor_matrix_wp, cartesian_shifts_wp, covalent_radii_wp, wp.float32(k1), wp.int32(fill_value), periodic, ], outputs=[coord_num_wp], device=device, ) # Pass 2: Compute direct forces, energy, and accumulate dE/dCN wp.launch( kernel=_direct_forces_and_dE_dCN_kernel_nm_overload[wp_dtype], dim=num_atoms, inputs=[ positions_wp, numbers_wp, neighbor_matrix_wp, cartesian_shifts_wp, coord_num_wp, r4r2_wp, c6_reference_wp, coord_num_ref_wp, wp.float32(k3), wp.float32(a1), wp.float32(a2), wp.float32(s6), wp.float32(s8), wp.float32(s5_smoothing_on), wp.float32(s5_smoothing_off), wp.float32(inv_w), wp.int32(fill_value), periodic, batch_idx_wp, wp.bool(compute_virial), ], outputs=[dE_dCN_wp, forces_wp, energy_wp, virial_wp], device=device, ) # Pass 3: Add CN-dependent force contribution wp.launch( kernel=_cn_forces_contrib_kernel_nm_overload[wp_dtype], dim=num_atoms, inputs=[ positions_wp, numbers_wp, neighbor_matrix_wp, cartesian_shifts_wp, covalent_radii_wp, dE_dCN_wp, wp.float32(k1), wp.int32(fill_value), periodic, batch_idx_wp, wp.bool(compute_virial), ], outputs=[forces_wp, virial_wp], device=device, ) @torch.library.custom_op( "nvalchemiops::dftd3_nl", mutates_args=("energy", "forces", "coord_num", "virial"), ) def _dftd3_nl_op( positions: torch.Tensor, numbers: torch.Tensor, idx_j: torch.Tensor, neighbor_ptr: torch.Tensor, covalent_radii: torch.Tensor, r4r2: torch.Tensor, c6_reference: torch.Tensor, coord_num_ref: torch.Tensor, a1: float, a2: float, s8: float, energy: torch.Tensor, forces: torch.Tensor, coord_num: torch.Tensor, virial: torch.Tensor, k1: float = 16.0, k3: float = -4.0, s6: float = 1.0, s5_smoothing_on: float = 1e10, s5_smoothing_off: float = 1e10, batch_idx: torch.Tensor | None = None, cell: torch.Tensor | None = None, unit_shifts: torch.Tensor | None = None, compute_virial: bool = False, device: str | None = None, ) -> None: """Internal custom op for DFT-D3(BJ) using CSR neighbor list format. This is a low-level custom operator that performs DFT-D3(BJ) dispersion calculations using CSR (Compressed Sparse Row) neighbor list format with idx_j (destination indices) and neighbor_ptr (row pointers). Output tensors must be pre-allocated by the caller and are modified in-place. For most use cases, prefer the higher-level :func:`dftd3` wrapper function instead of calling this method directly. This function is torch.compile compatible. Parameters ---------- positions : torch.Tensor, shape (num_atoms, 3) Atomic coordinates as float32 or float64 numbers : torch.Tensor, shape (num_atoms), dtype=int32 Atomic numbers idx_j : torch.Tensor, shape (num_edges,), dtype=int32 Destination atom indices (flattened neighbor list in CSR format) neighbor_ptr : torch.Tensor, shape (num_atoms+1,), dtype=int32 CSR row pointers where neighbor_ptr[i]:neighbor_ptr[i+1] gives neighbors of atom i covalent_radii : torch.Tensor, shape (max_Z+1), dtype=float32 Covalent radii indexed by atomic number r4r2 : torch.Tensor, shape (max_Z+1), dtype=float32 <r⁴>/<r²> expectation values c6_reference : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32 C6 reference values coord_num_ref : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32 CN reference grid a1 : float Becke-Johnson damping parameter 1 a2 : float Becke-Johnson damping parameter 2 s8 : float C8 term scaling factor energy : torch.Tensor, shape (num_systems,), dtype=float32 OUTPUT: Total dispersion energy forces : torch.Tensor, shape (num_atoms, 3), dtype=float32 OUTPUT: Atomic forces coord_num : torch.Tensor, shape (num_atoms,), dtype=float32 OUTPUT: Coordination numbers virial : torch.Tensor, shape (num_systems, 3, 3), dtype=float32 OUTPUT: Virial tensor. Must be pre-allocated. Units are energy (Hartree when using standard D3 parameters). k1 : float, optional CN counting function steepness parameter k3 : float, optional CN interpolation Gaussian width parameter s6 : float, optional C6 term scaling factor s5_smoothing_on : float, optional Distance where S5 switching begins s5_smoothing_off : float, optional Distance where S5 switching completes batch_idx : torch.Tensor, shape (num_atoms,), dtype=int32, optional Batch indices cell : torch.Tensor, shape (num_systems, 3, 3), dtype=float32 or float64, optional Unit cell lattice vectors for PBC, in same dtype and units as positions. unit_shifts : torch.Tensor, shape (num_edges, 3), dtype=int32, optional Integer unit cell shifts for PBC device : str, optional Warp device string Returns ------- None Modifies input tensors in-place: energy, forces, coord_num, virial (if compute_virial=True) Notes ----- - All input tensors should use consistent units. Standard D3 parameters use atomic units (Bohr for distances, Hartree for energy). - Float32 or float64 precision for positions and cell; outputs always float32 - Padding atoms indicated by numbers[i] == 0 - **Two-body only**: Computes pairwise C6 and C8 dispersion terms; three-body Axilrod-Teller-Muto (ATM/C9) terms are not included - Bulk stress tensor can be obtained by dividing virial by system volume. See Also -------- :func:`dftd3` : Higher-level wrapper that handles allocation :class:`D3Parameters` : Dataclass for organizing DFT-D3 reference parameters """ # Ensure all parameters are on correct device/dtype covalent_radii = covalent_radii.to(device=positions.device, dtype=torch.float32) r4r2 = r4r2.to(device=positions.device, dtype=torch.float32) c6_reference = c6_reference.to(device=positions.device, dtype=torch.float32) coord_num_ref = coord_num_ref.to(device=positions.device, dtype=torch.float32) # Get shapes num_atoms = positions.size(0) num_edges = idx_j.size(0) # Handle empty case if num_atoms == 0 or num_edges == 0: return # Infer device from positions if not provided if device is None: device = str(positions.device) # Zero output tensors energy.zero_() forces.zero_() coord_num.zero_() virial.zero_() # Detect dtype and set appropriate Warp vector types if positions.dtype == torch.float64: vec_dtype = wp.vec3d wp_dtype = wp.float64 else: vec_dtype = wp.vec3f wp_dtype = wp.float32 # Create batch indices if not provided (single system) if batch_idx is None: batch_idx = torch.zeros(num_atoms, dtype=torch.int32, device=positions.device) # Convert PyTorch tensors to Warp arrays positions_wp = wp.from_torch(positions.detach(), dtype=vec_dtype, return_ctype=True) numbers_wp = wp.from_torch(numbers, dtype=wp.int32, return_ctype=True) idx_j_wp = wp.from_torch(idx_j, dtype=wp.int32, return_ctype=True) neighbor_ptr_wp = wp.from_torch(neighbor_ptr, dtype=wp.int32, return_ctype=True) batch_idx_wp = wp.from_torch(batch_idx, dtype=wp.int32, return_ctype=True) # Convert parameter tensors to Warp arrays covalent_radii_wp = wp.from_torch( covalent_radii.to(dtype=torch.float32, device=positions.device), dtype=wp.float32, return_ctype=True, ) r4r2_wp = wp.from_torch( r4r2.to(dtype=torch.float32, device=positions.device), dtype=wp.float32, return_ctype=True, ) c6_reference_wp = wp.from_torch( c6_reference.to(dtype=torch.float32, device=positions.device), dtype=wp.float32, return_ctype=True, ) coord_num_ref_wp = wp.from_torch( coord_num_ref.to(dtype=torch.float32, device=positions.device), dtype=wp.float32, return_ctype=True, ) # Precompute inv_w for S5 switching if s5_smoothing_off > s5_smoothing_on: inv_w = 1.0 / (s5_smoothing_off - s5_smoothing_on) else: inv_w = 0.0 # Pass 0: Handle PBC - compute cartesian shifts if needed if unit_shifts is not None and cell is not None: # PBC case - convert unit shifts to Cartesian using Warp kernel periodic = True # Convert cell to Warp cell_wp = wp.from_torch( cell.detach().to(dtype=positions.dtype, device=positions.device), dtype=wp.mat33d if positions.dtype == torch.float64 else wp.mat33f, return_ctype=True, ) # Convert unit shifts to vec3i format unit_shifts_wp = wp.from_torch( unit_shifts.to(dtype=torch.int32, device=positions.device), dtype=wp.vec3i, return_ctype=True, ) # Convert neighbor_ptr to Warp neighbor_ptr_wp_shifts = wp.from_torch( neighbor_ptr, dtype=wp.int32, return_ctype=True ) # Create batch_idx if not provided (single system) if batch_idx is None: batch_idx_shifts = torch.zeros( num_atoms, dtype=torch.int32, device=positions.device ) else: batch_idx_shifts = batch_idx batch_idx_shifts_wp = wp.from_torch( batch_idx_shifts, dtype=wp.int32, return_ctype=True ) # Allocate output for Cartesian shifts cartesian_shifts = torch.empty( (num_edges, 3), dtype=positions.dtype, device=positions.device, ) cartesian_shifts_wp = wp.from_torch( cartesian_shifts, dtype=vec_dtype, return_ctype=True ) # Launch kernel to compute Cartesian shifts wp.launch( kernel=_compute_cartesian_shifts_nl_overload[wp_dtype], dim=num_atoms, inputs=[ cell_wp, unit_shifts_wp, neighbor_ptr_wp_shifts, batch_idx_shifts_wp, ], outputs=[cartesian_shifts_wp], device=device, ) else: # Non-periodic case - create zero shifts periodic = False cartesian_shifts = torch.zeros( (num_edges, 3), dtype=positions.dtype, device=positions.device, ) # Convert cartesian shifts to Warp cartesian_shifts_wp = wp.from_torch( cartesian_shifts, dtype=vec_dtype, return_ctype=True ) # Convert pre-allocated output arrays to Warp coord_num_wp = wp.from_torch(coord_num, dtype=wp.float32, return_ctype=True) forces_wp = wp.from_torch(forces, dtype=wp.vec3f, return_ctype=True) energy_wp = wp.from_torch(energy, dtype=wp.float32, return_ctype=True) virial_wp = wp.from_torch(virial, dtype=wp.mat33f, return_ctype=True) # Allocate dE_dCN array for chain rule computation dE_dCN = torch.zeros( # NOSONAR (S125) "math formula" num_atoms, dtype=torch.float32, device=positions.device ) dE_dCN_wp = wp.from_torch( # NOSONAR (S125) "math formula" dE_dCN, dtype=wp.float32, return_ctype=True ) # Pass 1: Compute coordination numbers wp.launch( kernel=_cn_kernel_nl_overload[wp_dtype], dim=num_atoms, inputs=[ positions_wp, numbers_wp, idx_j_wp, neighbor_ptr_wp, cartesian_shifts_wp, covalent_radii_wp, wp.float32(k1), periodic, ], outputs=[coord_num_wp], device=device, ) # Pass 2: Compute direct forces, energy, and accumulate dE/dCN wp.launch( kernel=_direct_forces_and_dE_dCN_kernel_nl_overload[wp_dtype], dim=num_atoms, inputs=[ positions_wp, numbers_wp, idx_j_wp, neighbor_ptr_wp, cartesian_shifts_wp, coord_num_wp, r4r2_wp, c6_reference_wp, coord_num_ref_wp, wp.float32(k3), wp.float32(a1), wp.float32(a2), wp.float32(s6), wp.float32(s8), wp.float32(s5_smoothing_on), wp.float32(s5_smoothing_off), wp.float32(inv_w), periodic, batch_idx_wp, wp.bool(compute_virial), ], outputs=[dE_dCN_wp, forces_wp, energy_wp, virial_wp], device=device, ) # Pass 3: Add CN-dependent force contribution wp.launch( kernel=_cn_forces_contrib_kernel_nl_overload[wp_dtype], dim=num_atoms, inputs=[ positions_wp, numbers_wp, idx_j_wp, neighbor_ptr_wp, cartesian_shifts_wp, covalent_radii_wp, dE_dCN_wp, wp.float32(k1), periodic, batch_idx_wp, wp.bool(compute_virial), ], outputs=[forces_wp, virial_wp], device=device, )
[docs] def dftd3( positions: torch.Tensor, numbers: torch.Tensor, a1: float, a2: float, s8: float, k1: float = 16.0, k3: float = -4.0, s6: float = 1.0, s5_smoothing_on: float = 1e10, s5_smoothing_off: float = 1e10, fill_value: int | None = None, d3_params: D3Parameters | dict[str, torch.Tensor] | None = None, covalent_radii: torch.Tensor | None = None, r4r2: torch.Tensor | None = None, c6_reference: torch.Tensor | None = None, coord_num_ref: torch.Tensor | None = None, batch_idx: torch.Tensor | None = None, cell: torch.Tensor | None = None, neighbor_matrix: torch.Tensor | None = None, neighbor_matrix_shifts: torch.Tensor | None = None, neighbor_list: torch.Tensor | None = None, neighbor_ptr: torch.Tensor | None = None, unit_shifts: torch.Tensor | None = None, compute_virial: bool = False, num_systems: int | None = None, device: str | None = None, ) -> ( tuple[torch.Tensor, torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] ): """ Compute DFT-D3(BJ) dispersion energy and forces using Warp with optional periodic boundary condition support and smoothing function. **DFT-D3 parameters must be explicitly provided** using one of three methods: 1. **D3Parameters dataclass**: Supply a :class:`D3Parameters` instance (recommended). Individual parameters can override dataclass values if both are provided. 2. **Explicit parameters**: Supply all four parameters individually: ``covalent_radii``, ``r4r2``, ``c6_reference``, and ``coord_num_ref``. 3. **Dictionary**: Provide a ``d3_params`` dictionary with keys: ``"rcov"``, ``"r4r2"``, ``"c6ab"``, and ``"cn_ref"``. Individual parameters can override dictionary values if both are provided. See ``examples/interactions/utils.py`` for parameter generation utilities. This wrapper can be launched by either supplying a neighbor matrix or a neighbor list, both of which can be generated by the :func:`nvalchemiops.neighborlist.neighbor_list` function where the latter can be returned by setting the `return_neighbor_list` parameter to True. Parameters ---------- positions : torch.Tensor Atomic coordinates [num_atoms, 3] as float32 or float64, in consistent distance units (conventionally Bohr when using standard D3 parameters) numbers : torch.Tensor Atomic numbers [num_atoms] as int32 a1 : float Becke-Johnson damping parameter 1 (functional-dependent, dimensionless) a2 : float Becke-Johnson damping parameter 2 (functional-dependent), in same units as positions s8 : float C8 term scaling factor (functional-dependent, dimensionless) k1 : float, optional CN counting function steepness parameter, in inverse distance units (typically 16.0 1/Bohr for atomic units) k3 : float, optional CN interpolation Gaussian width parameter (typically -4.0, dimensionless) s6 : float, optional C6 term scaling factor (typically 1.0, dimensionless) s5_smoothing_on : float, optional Distance where S5 switching begins, in same units as positions. Set greater or equal to s5_smoothing_off to disable smoothing. Default: 1e10 s5_smoothing_off : float, optional Distance where S5 switching completes, in same units as positions. Default: 1e10 (effectively no cutoff) fill_value : int | None, optional Value indicating padding in neighbor_matrix. If None, defaults to num_atoms. Entries with neighbor_matrix[i, k] >= fill_value are treated as padding. Default: None d3_params : D3Parameters | dict[str, torch.Tensor] | None, optional DFT-D3 parameters provided as either: - :class:`D3Parameters` dataclass instance (recommended) - Dictionary with keys: "rcov", "r4r2", "c6ab", "cn_ref" Individual parameters below can override values from d3_params. covalent_radii : torch.Tensor | None, optional Covalent radii [max_Z+1] as float32, indexed by atomic number, in same units as positions. If provided, overrides the value in d3_params. r4r2 : torch.Tensor | None, optional <r4>/<r2> expectation values [max_Z+1] as float32 for C8 computation (dimensionless). If provided, overrides the value in d3_params. c6_reference : torch.Tensor | None, optional C6 reference values [max_Z+1, max_Z+1, 5, 5] as float32 in energy × distance^6 units. If provided, overrides the value in d3_params. coord_num_ref : torch.Tensor | None, optional CN reference grid [max_Z+1, max_Z+1, 5, 5] as float32 (dimensionless). If provided, overrides the value in d3_params. batch_idx : torch.Tensor or None, optional Batch indices [num_atoms] as int32. If None, all atoms are assumed to be in a single system (batch 0). For batched calculations, atoms with the same batch index belong to the same system. Default: None cell : torch.Tensor or None, optional, as float32 or float64 Unit cell lattice vectors [num_systems, 3, 3] for PBC, in same dtype and units as positions. Convention: cell[s, i, :] is i-th lattice vector for system s. If None, non-periodic calculation. Default: None neighbor_matrix : torch.Tensor | None, optional Neighbor indices [num_atoms, max_neighbors] as int32. See module docstring for details on the format. Padding entries have values >= fill_value. Mutually exclusive with neighbor_list. Default: None neighbor_matrix_shifts : torch.Tensor or None, optional Integer unit cell shifts [num_atoms, max_neighbors, 3] as int32 for PBC with neighbor_matrix format. If None, non-periodic calculation. If provided along with cell, Cartesian shifts are computed. Mutually exclusive with unit_shifts. Default: None neighbor_list : torch.Tensor or None, optional Neighbor pairs [2, num_pairs] as int32 in COO format, where row 0 contains source atom indices and row 1 contains target atom indices. Alternative to neighbor_matrix for sparse neighbor representations. Mutually exclusive with neighbor_matrix. Must be used together with `neighbor_ptr` (both are returned by the neighbor list API when `return_neighbor_list=True`). Default: None neighbor_ptr : torch.Tensor or None, optional CSR row pointers [num_atoms+1] as int32. Required when using `neighbor_list`. Indicates that `neighbor_list[1, :]` contains destination atoms in CSR format where `neighbor_ptr[i]:neighbor_ptr[i+1]` gives the range of neighbors for atom i. Returned by the neighbor list API when `return_neighbor_list=True`. Default: None unit_shifts : torch.Tensor or None, optional Integer unit cell shifts [num_pairs, 3] as int32 for PBC with neighbor_list format. If None, non-periodic calculation. If provided along with cell, Cartesian shifts are computed. Mutually exclusive with neighbor_matrix_shifts. Default: None compute_virial : bool, optional If True, allocate and compute virial tensor. Ignored if virial parameter is provided. Default: False num_systems : int, optional Number of systems in batch. In none provided, inferred from cell or from batch_idx (introcudes CUDA synchronization overhead). Default: None device : str or None, optional Warp device string (e.g., 'cuda:0', 'cpu'). If None, inferred from positions tensor. Default: None Returns ------- energy : torch.Tensor Total dispersion energy [num_systems] as float32. Units are energy (Hartree when using standard D3 parameters). forces : torch.Tensor Atomic forces [num_atoms, 3] as float32. Units are energy/distance (Hartree/Bohr when using standard D3 parameters). coord_num : torch.Tensor Coordination numbers [num_atoms] as float32 (dimensionless) virial : torch.Tensor, optional Virial tensor [num_systems, 3, 3] as float32. Units are energy (Hartree when using standard D3 parameters). Only returned if compute_virial=True. Notes ----- - **Unit consistency**: All inputs must use consistent units. Standard D3 parameters from the Grimme group use atomic units (Bohr for distances, Hartree for energy), so using atomic units throughout is recommended and conventional. - Float32 or float64 precision for positions and cell; outputs always float32 - **Neighbor formats**: Supports both neighbor_matrix (dense) and neighbor_list (sparse COO) formats. Choose neighbor_list for sparse systems or when memory efficiency is important. - Padding atoms indicated by numbers[i] == 0 - Requires symmetric neighbor representation (each pair appears twice) - **Two-body only**: Computes pairwise C6 and C8 dispersion terms; three-body Axilrod-Teller-Muto (ATM/C9) terms are not included - Virial computation requires periodic boundary conditions. - Bulk stress tensor can be obtained by dividing virial by system volume. **Neighbor Format Selection**: - Use neighbor_matrix for dense systems or when max_neighbors is small - Use neighbor_list for sparse systems, large cutoffs, or memory-constrained scenarios - Both formats produce identical results and support PBC **PBC Handling**: - Matrix format: Provide cell and neighbor_matrix_shifts - List format: Provide cell and unit_shifts - Non-periodic: Omit both cell and shift parameters See Also -------- :class:`D3Parameters` : Dataclass for organizing DFT-D3 reference parameters :func:`_dftd3_nm_op` : Internal custom operator for neighbor matrix format :func:`_dftd3_nl_op` : Internal custom operator for neighbor list format :func:`_compute_cartesian_shifts` : Pass 0 - Converts unit cell shifts to Cartesian :func:`_cn_kernel_nm` : Pass 1 - Computes coordination numbers (neighbor matrix) :func:`_cn_kernel_nl` : Pass 1 - Computes coordination numbers (neighbor list) :func:`_direct_forces_and_dE_dCN_kernel_nm` : Pass 2 - neighbor matrix format :func:`_direct_forces_and_dE_dCN_kernel_nl` : Pass 2 - neighbor list format :func:`_cn_forces_contrib_kernel_nm` : Pass 3 - neighbor matrix format :func:`_cn_forces_contrib_kernel_nl` : Pass 3 - neighbor list format """ # Validate neighbor format inputs matrix_provided = neighbor_matrix is not None list_provided = neighbor_list is not None if matrix_provided and list_provided: raise ValueError( "Cannot provide both neighbor_matrix and neighbor_list. " "Please provide only one neighbor representation format." ) if not matrix_provided and not list_provided: raise ValueError("Must provide either neighbor_matrix or neighbor_list.") # Validate PBC shift inputs match neighbor format if matrix_provided and unit_shifts is not None: raise ValueError( "unit_shifts is for neighbor_list format. " "Use neighbor_matrix_shifts for neighbor_matrix format." ) if list_provided and neighbor_matrix_shifts is not None: raise ValueError( "neighbor_matrix_shifts is for neighbor_matrix format. " "Use unit_shifts for neighbor_list format." ) # Validate neighbor_ptr is provided when using neighbor_list format if list_provided and neighbor_ptr is None: raise ValueError( "neighbor_ptr must be provided when using neighbor_list format. " "Obtain it from the neighbor list API by setting return_neighbor_list=True." ) # Validate functional parameters if a1 is None or a2 is None or s8 is None: raise ValueError( "Functional parameters a1, a2, and s8 must be provided. " "These are functional-dependent parameters required for DFT-D3(BJ) calculations." ) # Validate virial computation requires PBC if compute_virial: if cell is None: raise ValueError( "Virial computation requires periodic boundary conditions. " "Please provide unit cell parameters (cell) and shifts " "(neighbor_matrix_shifts or unit_shifts) when compute_virial=True " "or when passing a virial tensor." ) if matrix_provided and neighbor_matrix_shifts is None: raise ValueError( "Virial computation requires periodic boundary conditions. " "Please provide neighbor_matrix_shifts along with cell when using " "neighbor_matrix format and compute_virial=True or passing a virial tensor." ) if list_provided and unit_shifts is None: raise ValueError( "Virial computation requires periodic boundary conditions. " "Please provide unit_shifts along with cell when using " "neighbor_list format and compute_virial=True or passing a virial tensor." ) # Determine how parameters are being supplied # Case 1: All individual parameters provided explicitly if all( param is not None for param in [covalent_radii, r4r2, c6_reference, coord_num_ref] ): # Use explicit parameters directly (already assigned) pass # Case 2: d3_params provided (D3Parameters or dictionary, with optional overrides) elif d3_params is not None: # Convert D3Parameters to dictionary for consistent access if isinstance(d3_params, D3Parameters): d3_params = d3_params.__dict__ # these are written to throw KeyError if the keys are not present if covalent_radii is None: covalent_radii = d3_params["rcov"] if r4r2 is None: r4r2 = d3_params["r4r2"] if c6_reference is None: c6_reference = d3_params["c6ab"] if coord_num_ref is None: coord_num_ref = d3_params["cn_ref"] # Case 3: No parameters provided - raise error else: raise RuntimeError( "DFT-D3 parameters must be explicitly provided. " "Either supply all individual parameters (covalent_radii, r4r2, " "c6_reference, coord_num_ref), provide a D3Parameters instance, " "or provide a d3_params dictionary. See the function docstring for details." ) # Get shapes num_atoms = positions.size(0) # Handle empty case if num_atoms == 0: if batch_idx is None or ( isinstance(batch_idx, torch.Tensor) and batch_idx.numel() == 0 ): num_systems = 1 else: num_systems = int(batch_idx.max().item()) + 1 empty_energy = torch.zeros( num_systems, dtype=torch.float32, device=positions.device ) empty_forces = torch.zeros((0, 3), dtype=torch.float32, device=positions.device) empty_cn = torch.zeros((0,), dtype=torch.float32, device=positions.device) # Handle virial for empty case if compute_virial is True if compute_virial: empty_virial = torch.zeros( (0, 3, 3), dtype=torch.float32, device=positions.device ) return empty_energy, empty_forces, empty_cn, empty_virial else: return empty_energy, empty_forces, empty_cn # Determine number of systems for energy allocation if num_systems is None: if batch_idx is None: num_systems = 1 elif cell is not None: num_systems = cell.size(0) else: num_systems = int(batch_idx.max().item()) + 1 # Allocate output tensors energy = torch.zeros(num_systems, dtype=torch.float32, device=positions.device) forces = torch.zeros((num_atoms, 3), dtype=torch.float32, device=positions.device) coord_num = torch.zeros(num_atoms, dtype=torch.float32, device=positions.device) if compute_virial: virial = torch.zeros( (num_systems, 3, 3), dtype=torch.float32, device=positions.device ) else: virial = torch.zeros((0, 3, 3), dtype=torch.float32, device=positions.device) # Dispatch to appropriate implementation based on neighbor format if neighbor_matrix is not None: # Matrix format - call custom op _dftd3_nm_op( positions=positions, numbers=numbers, neighbor_matrix=neighbor_matrix, covalent_radii=covalent_radii, r4r2=r4r2, c6_reference=c6_reference, coord_num_ref=coord_num_ref, a1=a1, a2=a2, s8=s8, energy=energy, forces=forces, coord_num=coord_num, k1=k1, k3=k3, s6=s6, s5_smoothing_on=s5_smoothing_on, s5_smoothing_off=s5_smoothing_off, fill_value=fill_value, batch_idx=batch_idx, cell=cell, neighbor_matrix_shifts=neighbor_matrix_shifts, virial=virial, compute_virial=compute_virial, device=device, ) else: # List format - use CSR format from neighbor list API # neighbor_list: [2, num_pairs] in COO format where row 1 is idx_j (destination atoms) # neighbor_ptr: [num_atoms+1] CSR row pointers (required, from neighbor list API) # Extract idx_j from neighbor_list (row 1 contains destination atoms) idx_j_csr = neighbor_list[1] _dftd3_nl_op( positions=positions, numbers=numbers, idx_j=idx_j_csr, neighbor_ptr=neighbor_ptr, covalent_radii=covalent_radii, r4r2=r4r2, c6_reference=c6_reference, coord_num_ref=coord_num_ref, a1=a1, a2=a2, s8=s8, energy=energy, forces=forces, coord_num=coord_num, k1=k1, k3=k3, s6=s6, s5_smoothing_on=s5_smoothing_on, s5_smoothing_off=s5_smoothing_off, batch_idx=batch_idx, cell=cell, unit_shifts=unit_shifts, virial=virial, compute_virial=compute_virial, device=device, ) if compute_virial: return energy, forces, coord_num, virial else: return energy, forces, coord_num