# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DFT-D3(BJ) Dispersion Correction - Warp Kernel Implementation
This module implements the DFT-D3 dispersion correction with Becke-Johnson (BJ)
damping as Warp GPU/CPU kernels. The implementation provides efficient computation
of dispersion energies and forces using a multi-pass algorithm with support for
periodic boundary conditions, batched systems, and smooth cutoff functions.
For detailed theory, usage examples, and parameter setup, see the
:doc:`DFT-D3 User Guide </userguide/components/dispersion>`.
Multi-Pass Kernel Architecture
-------------------------------
The implementation uses four kernel passes to efficiently handle the chain rule
dependency in force calculations:
1. **Pass 0 (_compute_cartesian_shifts)**: [PBC only] Convert unit cell shifts
to Cartesian coordinates
2. **Pass 1 (_cn_kernel)**: Compute coordination numbers using geometric counting
function
3. **Pass 2 (_direct_forces_and_dE_dCN_kernel)**: Compute C6 interpolation,
dispersion energy, direct forces, and accumulate :math:`\\partial E/\\partial \\text{CN}`
4. **Pass 3 (_cn_forces_contrib_kernel)**: Add CN-dependent force contribution
using precomputed :math:`\\partial E/\\partial \\text{CN}` values
PyTorch Interface
-----------------
Use :func:`dftd3` as the main entry point for PyTorch integration:
.. code-block:: python
from nvalchemiops.interactions.dispersion.dftd3 import (
dftd3,
D3Parameters,
)
# Using neighbor matrix format
energy, forces, coord_num = dftd3(
positions, numbers,
neighbor_matrix=neighbor_matrix,
a1=0.3981, a2=4.4211, s8=1.9889,
d3_params=d3_params, # D3Parameters instance or dict
cell=cell, # Optional for PBC
neighbor_matrix_shifts=neighbor_matrix_shifts, # Optional for PBC
)
# Using neighbor list format (sparse COO)
energy, forces, coord_num = dftd3(
positions, numbers,
neighbor_list=neighbor_list, # shape (2, num_pairs)
a1=0.3981, a2=4.4211, s8=1.9889,
d3_params=d3_params,
cell=cell, # Optional for PBC
unit_shifts=unit_shifts, # Optional for PBC, shape (num_pairs, 3)
)
Data Structure Requirements
---------------------------
**Neighbor Formats**
The implementation supports two neighbor representation formats:
1. **Neighbor Matrix Format** (dense): `[num_atoms, max_neighbors]` where
`neighbor_matrix[i, k]` is the k-th neighbor of atom i. Padding entries use
values >= `fill_value` (typically `num_atoms`).
2. **Neighbor List Format** (sparse COO): `[2, num_pairs]` where row 0 contains
source atom indices and row 1 contains target atom indices. No padding needed.
Both formats can be generated by :func:`nvalchemiops.neighborlist.neighbor_list` using
the `return_neighbor_list` parameter.
**Parameter Arrays**
- `covalent_radii`: `[max_Z+1]` float32
- `r4r2`: `[max_Z+1]` float32
- `c6_reference`: `[max_Z+1, max_Z+1, 5, 5]` float32
- `coord_num_ref`: `[max_Z+1, max_Z+1, 5, 5]` float32
Index 0 reserved for padding; valid atomic numbers 1 to max_Z.
**Periodic Boundary Conditions**
- `cell`: `[num_systems, 3, 3]` lattice vectors (row format)
- For neighbor matrix: `neighbor_matrix_shifts`: `[num_atoms, max_neighbors, 3]` int32 unit cell shifts
- For neighbor list: `unit_shifts`: `[num_pairs, 3]` int32 unit cell shifts
Units
-----
Kernels are **unit-agnostic** but require consistency. Standard Grimme group
parameters use **atomic units (Bohr, Hartree)**, which is recommended:
- Positions, covalent radii, `a2`, cutoffs: Bohr
- Energy output: Hartree
- Forces output: Hartree/Bohr
- Parameter `k1`: 1/Bohr
Technical Notes
---------------
- Supports float32 and float64 positions and cell. Outputs are always float32
- **Two-body only**: Axilrod-Teller-Muto (C9) three-body terms not included
See Also
--------
:class:`D3Parameters` : Dataclass for parameter validation and management
:func:`dftd3` : Main PyTorch interface function
:doc:`/userguide/components/dispersion` : Complete user guide with theory and examples
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import torch
import warp as wp
__all__ = [
"D3Parameters",
"dftd3",
]
# ==============================================================================
# Data Structures
# ==============================================================================
@dataclass
class D3Parameters:
    """
    Validated container for DFT-D3 reference parameters.

    This dataclass bundles the element-specific tensors consumed by
    :func:`dftd3` and checks, at construction time, that they are tensors of
    a floating dtype, that their shapes agree with each other and with
    ``interp_mesh``, and that they all live on the same device.

    Parameters
    ----------
    rcov : torch.Tensor
        Covalent radii ``[max_Z+1]`` as float32 or float64. Units should be
        consistent with position coordinates. Index 0 is reserved for
        padding; valid atomic numbers are 1 to max_Z.
    r4r2 : torch.Tensor
        <r⁴>/<r²> expectation values ``[max_Z+1]`` as float32 or float64.
        Dimensionless ratio used for computing C8 coefficients from C6 values.
    c6ab : torch.Tensor
        C6 reference coefficients
        ``[max_Z+1, max_Z+1, interp_mesh, interp_mesh]`` as float32 or
        float64. Units are energy x distance^6. Indexed by atomic numbers and
        coordination number reference indices.
    cn_ref : torch.Tensor
        Coordination number reference grid
        ``[max_Z+1, max_Z+1, interp_mesh, interp_mesh]`` as float32 or
        float64. Dimensionless CN values for Gaussian interpolation.
    interp_mesh : int, optional
        Size of the coordination number interpolation mesh. Must be >= 1.
        Default: 5 (standard DFT-D3 uses a 5x5 grid).

    Raises
    ------
    ValueError
        If parameter shapes are inconsistent or invalid, if ``interp_mesh``
        is smaller than 1, or if the tensors are on different devices.
    TypeError
        If parameters are not torch.Tensor or have invalid dtypes.

    Notes
    -----
    - Parameters should use consistent units matching your coordinate system.
      Standard D3 parameters from the Grimme group use atomic units (Bohr for
      distances, Hartree x Bohr^6 for C6 coefficients).
    - Index 0 in all arrays is reserved for padding atoms (atomic number 0).
    - Valid atomic numbers range from 1 to max_z.
    - The standard DFT-D3 implementation supports elements 1-94 (H to Pu).
    - Parameters can be float32 or float64; they will be converted to float32
      during computation for efficiency.

    Examples
    --------
    Create parameters from individual tensors:

    >>> params = D3Parameters(
    ...     rcov=torch.rand(95),  # 94 elements + padding
    ...     r4r2=torch.rand(95),
    ...     c6ab=torch.rand(95, 95, 5, 5),
    ...     cn_ref=torch.rand(95, 95, 5, 5),
    ... )

    Create from a dictionary (e.g., loaded from file):

    >>> state_dict = torch.load("dftd3_parameters.pt")
    >>> params = D3Parameters(
    ...     rcov=state_dict["rcov"],
    ...     r4r2=state_dict["r4r2"],
    ...     c6ab=state_dict["c6ab"],
    ...     cn_ref=state_dict["cn_ref"],
    ... )
    """

    rcov: torch.Tensor
    r4r2: torch.Tensor
    c6ab: torch.Tensor
    cn_ref: torch.Tensor
    interp_mesh: int = 5

    def __post_init__(self) -> None:
        """Validate tensor types, dtypes, shapes, mesh size, and devices."""
        named_tensors = (
            ("rcov", self.rcov),
            ("r4r2", self.r4r2),
            ("c6ab", self.c6ab),
            ("cn_ref", self.cn_ref),
        )

        # Type and dtype validation
        for name, tensor in named_tensors:
            if not isinstance(tensor, torch.Tensor):
                raise TypeError(
                    f"Parameter '{name}' must be a torch.Tensor, got {type(tensor)}"
                )
            if tensor.dtype not in (torch.float32, torch.float64):
                raise TypeError(
                    f"Parameter '{name}' must be float32 or float64, got {tensor.dtype}"
                )

        # Shape validation: rcov defines max_Z, everything else must match it.
        if self.rcov.ndim != 1:
            raise ValueError(
                f"rcov must be 1D tensor [max_Z+1], got shape {self.rcov.shape}"
            )
        max_z = self.rcov.size(0) - 1
        if max_z < 1:
            raise ValueError(
                f"rcov must have at least 2 elements (padding + 1 element), got {self.rcov.size(0)}"
            )
        if self.r4r2.shape != (max_z + 1,):
            raise ValueError(
                f"r4r2 must have shape [{max_z + 1}] to match rcov, got {self.r4r2.shape}"
            )

        # An empty interpolation grid can never yield a C6 value, so reject it
        # here instead of letting zero-sized reference tensors slip through.
        if self.interp_mesh < 1:
            raise ValueError(
                f"interp_mesh must be a positive integer, got {self.interp_mesh}"
            )

        # c6ab and cn_ref share the same expected grid shape.
        expected_grid_shape = (
            max_z + 1,
            max_z + 1,
            self.interp_mesh,
            self.interp_mesh,
        )
        if self.c6ab.shape != expected_grid_shape:
            raise ValueError(
                f"c6ab must have shape {expected_grid_shape}, got {self.c6ab.shape}"
            )
        if self.cn_ref.shape != expected_grid_shape:
            raise ValueError(
                f"cn_ref must have shape {expected_grid_shape}, got {self.cn_ref.shape}"
            )

        # Device consistency validation (compared by string form)
        if len({str(tensor.device) for _, tensor in named_tensors}) > 1:
            raise ValueError(
                f"All parameters must be on the same device. "
                f"Got devices: rcov={self.rcov.device}, r4r2={self.r4r2.device}, "
                f"c6ab={self.c6ab.device}, cn_ref={self.cn_ref.device}"
            )

    @property
    def max_z(self) -> int:
        """Maximum atomic number supported by these parameters."""
        return self.rcov.size(0) - 1

    @property
    def device(self) -> torch.device:
        """Device where parameters are stored (taken from ``rcov``)."""
        return self.rcov.device

    def to(
        self,
        device: str | torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> D3Parameters:
        """
        Move all parameters to the specified device and/or convert to specified dtype.

        Parameters
        ----------
        device : str or torch.device or None, optional
            Target device (e.g., 'cpu', 'cuda', 'cuda:0'). If None, keeps current device.
        dtype : torch.dtype or None, optional
            Target dtype (e.g., torch.float32, torch.float64). If None, keeps current dtype.

        Returns
        -------
        D3Parameters
            New instance with parameters on the target device and/or dtype.
            The new instance is re-validated, so a non-floating ``dtype``
            raises ``TypeError``.

        Examples
        --------
        Move to GPU:

        >>> params_gpu = params.to(device='cuda')

        Convert to float32:

        >>> params_f32 = params.to(dtype=torch.float32)

        Move to GPU and convert to float32:

        >>> params_gpu_f32 = params.to(device='cuda', dtype=torch.float32)
        """
        return D3Parameters(
            rcov=self.rcov.to(device=device, dtype=dtype),
            r4r2=self.r4r2.to(device=device, dtype=dtype),
            c6ab=self.c6ab.to(device=device, dtype=dtype),
            cn_ref=self.cn_ref.to(device=device, dtype=dtype),
            interp_mesh=self.interp_mesh,
        )
# ==============================================================================
# Helper Functions
# ==============================================================================
@wp.func
def _s5_switch(
    r: wp.float32,
    r_on: wp.float32,
    r_off: wp.float32,
    inv_w: wp.float32,
) -> tuple[wp.float32, wp.float32]:
    r"""
    Evaluate the quintic cutoff switch and its radial derivative.

    The switch goes from 1 to 0 across ``[r_on, r_off]`` using the polynomial
    :math:`S_5(t) = 10t^3 - 15t^4 + 6t^5`, whose first and second derivatives
    vanish at both ends of the interval.

    Parameters
    ----------
    r : float32
        Distance between atoms
    r_on : float32
        Distance where switching begins
    r_off : float32
        Distance where switching completes
    inv_w : float32
        Precomputed 1/(r_off - r_on)

    Returns
    -------
    Sw : float32
        Switching function value in [0, 1]
    dSw_dr : float32
        Derivative of the switching function with respect to r

    Notes
    -----
    .. math::

        S_w(r) = \begin{cases}
        1 & \text{if } r \leq r_{\text{on}} \\
        1 - S_5(t) & \text{if } r_{\text{on}} < r < r_{\text{off}} \\
        0 & \text{if } r \geq r_{\text{off}}
        \end{cases}

    with :math:`t = (r - r_{\text{on}}) \cdot \text{inv\_w}` and

    .. math::

        \frac{dS_w}{dr} = -\left(30t^2 - 60t^3 + 30t^4\right) \cdot \text{inv\_w}

    If ``r_off <= r_on`` the switch is treated as disabled and ``(1, 0)`` is
    returned for every ``r``.

    See Also
    --------
    :func:`_dispersion_energy_force` : Applies this switch to the pair energy
    """
    # A degenerate window disables switching; inside r_on the switch is flat at 1.
    if r_off <= r_on or r <= r_on:
        return 1.0, 0.0
    # Beyond the outer radius the interaction is fully switched off.
    if r >= r_off:
        return 0.0, 0.0

    # Reduced coordinate inside the switching window, x in (0, 1)
    x = (r - r_on) * inv_w
    x2 = x * x
    x3 = x2 * x
    x4 = x3 * x
    x5 = x4 * x

    s5 = 10.0 * x3 - 15.0 * x4 + 6.0 * x5
    sw = 1.0 - s5

    # -dS5/dx, then chain rule dx/dr = inv_w
    neg_ds5_dx = -30.0 * x2 + 60.0 * x3 - 30.0 * x4  # NOSONAR (S125) "math formula"
    dsw_dr = neg_ds5_dx * inv_w
    return sw, dsw_dr
@wp.func
def _c6ab_interpolate(
    cn_i: wp.float32,
    cn_j: wp.float32,
    c6ab_mat: wp.array2d(dtype=wp.float32),
    cnref_i_mat: wp.array2d(dtype=wp.float32),
    cnref_j_mat: wp.array2d(dtype=wp.float32),
    k3: wp.float32,
) -> tuple[wp.float32, wp.float32, wp.float32]:
    r"""
    Gaussian-weighted interpolation of C6 over the reference grid.

    Returns the environment-dependent C6 coefficient of an atom pair together
    with its derivatives with respect to both coordination numbers.

    Parameters
    ----------
    cn_i : float32
        Coordination number of atom i
    cn_j : float32
        Coordination number of atom j
    c6ab_mat : wp.array2d(dtype=float32)
        C6 reference values for this element pair; entries [0..4, 0..4] are read
    cnref_i_mat : wp.array2d(dtype=float32)
        CN reference grid for atom i, read at ``[p, q]``
    cnref_j_mat : wp.array2d(dtype=float32)
        CN reference grid for atom j, read with transposed indices ``[q, p]``
    k3 : float32
        Gaussian width parameter (typically -4.0); it multiplies the squared
        CN distance directly, so it must be negative for decaying weights

    Returns
    -------
    c6_ij : float32
        Interpolated C6 coefficient
    dC6_dCNi : float32
        Derivative :math:`\partial C_6/\partial \text{CN}_i`
    dC6_dCNj : float32
        Derivative :math:`\partial C_6/\partial \text{CN}_j`

    Notes
    -----
    The weights are

    .. math::

        L_{pq} = \exp\left(k_3 \left[(\text{CN}_i - \text{CN}_{\text{ref},i}[p,q])^2 +
                 (\text{CN}_j - \text{CN}_{\text{ref},j}[q,p])^2\right]\right)

    and the interpolated value and derivative are

    .. math::

        C_6 = \frac{\sum_{pq} C_6^{\text{ref}}[p,q] L_{pq}}{\sum_{pq} L_{pq}}

        \frac{\partial C_6}{\partial \text{CN}_i} = \frac{2k_3}{w} (z_{d_i} - C_6 w_{d_i})

    - The largest exponent is subtracted from all exponents before calling
      ``exp``; the common factor cancels in the ratios above.
    - Grid entries whose reference C6 is exactly zero are skipped, as are
      entries whose shifted exponent is below -12.
    - If the weight sum does not exceed 1e-12, all three outputs are zero.
    - The loops are hard-coded to a 5x5 grid.

    See Also
    --------
    :func:`_direct_forces_and_dE_dCN_kernel_nm` : Calls this function for C6
        coefficient interpolation (neighbor matrix)
    """
    # First sweep: find the largest exponent so the second sweep can shift by it.
    max_exp = wp.float(-1e20)
    for p in range(5):
        for q in range(5):
            ref_c6 = c6ab_mat[p, q]
            if ref_c6 == 0.0:  # NOSONAR (S1244) "gpu kernel"
                continue
            delta_i = cn_i - cnref_i_mat[p, q]
            delta_j = cn_j - cnref_j_mat[q, p]  # transposed indexing for atom j
            exponent = k3 * (delta_i * delta_i + delta_j * delta_j)
            if exponent > max_exp:
                max_exp = exponent

    # Second sweep: accumulate weight sums and their first moments.
    w_sum = float(0.0)
    z_sum = float(0.0)
    w_mom_i = float(0.0)
    w_mom_j = float(0.0)
    z_mom_i = float(0.0)
    z_mom_j = float(0.0)
    for p in range(5):
        for q in range(5):
            ref_c6 = c6ab_mat[p, q]
            if ref_c6 == 0.0:  # NOSONAR (S1244) "gpu kernel"
                continue
            delta_i = cn_i - cnref_i_mat[p, q]
            delta_j = cn_j - cnref_j_mat[q, p]  # transposed indexing for atom j

            # Shifted exponent; drop contributions that are negligible
            exponent = k3 * (delta_i * delta_i + delta_j * delta_j) - max_exp
            if exponent < -12.0:
                continue

            weight = wp.exp(exponent)
            w_sum += weight
            z_sum += ref_c6 * weight
            w_mom_i += weight * delta_i
            w_mom_j += weight * delta_j
            z_mom_i += ref_c6 * weight * delta_i
            z_mom_j += ref_c6 * weight * delta_j

    eps_w = 1e-12
    if w_sum > eps_w:
        inv_w_sum = 1.0 / w_sum
        c6_ij = z_sum * inv_w_sum
        num_i = z_mom_i - c6_ij * w_mom_i
        num_j = z_mom_j - c6_ij * w_mom_j
        prefactor = (2.0 * k3) * inv_w_sum
        dC6_dCNi = prefactor * num_i  # NOSONAR (S125) "math formula"
        dC6_dCNj = prefactor * num_j  # NOSONAR (S125) "math formula"
        return c6_ij, dC6_dCNi, dC6_dCNj

    # No usable reference point contributed
    return 0.0, 0.0, 0.0
@wp.func
def _compute_distance_vector_pbc(
    pos_i: Any,
    pos_j: Any,
    cartesian_shift: Any,
    periodic: bool,
    compute_vectors: bool,
) -> tuple[wp.float32, wp.float32, wp.vec3f, wp.vec3f]:
    """
    Build the i -> j separation in float32, optionally shifted for PBC.

    Parameters
    ----------
    pos_i, pos_j : vec3
        Atomic positions
    cartesian_shift : vec3
        Shift added to ``pos_j - pos_i`` when ``periodic`` is True; ignored
        otherwise
    periodic : bool
        Whether to add ``cartesian_shift``
    compute_vectors : bool
        If True, also compute the unit vector; if False, a zero vector is
        returned in its place

    Returns
    -------
    r : float32
        Distance
    r_inv : float32
        Inverse distance (0 if r < 1e-12)
    r_hat : vec3f
        Unit vector from i to j (zero vector if compute_vectors=False or
        r < 1e-12)
    r_ij : vec3f
        Separation vector (always returned)

    Notes
    -----
    The difference is formed in the dtype of the inputs and only then cast
    component-wise to float32.
    """
    if periodic:
        delta = (pos_j - pos_i) + cartesian_shift
    else:
        delta = pos_j - pos_i

    # Cast each component to float32 after the subtraction
    r_ij = wp.vec3f(
        wp.float32(delta[0]),
        wp.float32(delta[1]),
        wp.float32(delta[2]),
    )
    r = wp.length(r_ij)

    # Coincident points: avoid dividing by (almost) zero
    if r < 1e-12:
        return r, wp.float32(0.0), wp.vec3f(0.0, 0.0, 0.0), r_ij

    r_inv = 1.0 / r
    if compute_vectors:
        r_hat = r_ij * r_inv
        return r, r_inv, r_hat, r_ij

    return r, r_inv, wp.vec3f(0.0, 0.0, 0.0), r_ij
@wp.func
def _cn_counting(
    r_inv: wp.float32,
    rcov_i: wp.float32,
    rcov_j: wp.float32,
    k1: wp.float32,
    compute_derivative: bool,
) -> tuple[wp.float32, wp.float32]:
    r"""
    Evaluate the logistic CN counting function and, optionally, its slope.

    Parameters
    ----------
    r_inv : float32
        Inverse distance
    rcov_i, rcov_j : float32
        Covalent radii
    k1 : float32
        Steepness parameter
    compute_derivative : bool
        If True, also compute dCN/dr; if False, zero is returned in its place

    Returns
    -------
    f_cn : float32
        Counting function value
    dCN_dr : float32
        Derivative with respect to distance (zero if compute_derivative=False)

    Notes
    -----
    With :math:`x = (r_{\text{cov},i} + r_{\text{cov},j}) / r`:

    .. math::

        f = \frac{1}{1 + \exp\left[-k_1 (x - 1)\right]}, \qquad
        \frac{df}{dr} = -f (1 - f)\, k_1\, \frac{x}{r}
    """
    # Ratio of summed covalent radii to the distance
    ratio = (rcov_i + rcov_j) * r_inv
    f_cn = 1.0 / (1.0 + wp.exp(-k1 * (ratio - 1.0)))

    if compute_derivative:
        dCN_dr = -f_cn * (1.0 - f_cn) * k1 * ratio * r_inv  # NOSONAR (S125)
        return f_cn, dCN_dr

    return f_cn, wp.float32(0.0)
@wp.func
def _bj_damping(
    r: wp.float32,
    r4r2_i: wp.float32,
    r4r2_j: wp.float32,
    a1: wp.float32,
    a2: wp.float32,
    s6: wp.float32,
    s8: wp.float32,
) -> tuple[wp.float32, wp.float32, wp.float32, wp.float32, wp.float32, wp.float32]:
    r"""
    Evaluate the Becke-Johnson damped radial factors for one pair.

    Parameters
    ----------
    r : float32
        Pair distance
    r4r2_i, r4r2_j : float32
        Per-element r4r2 values of the two atoms
    a1, a2 : float32
        Becke-Johnson damping parameters
    s6, s8 : float32
        Scaling factors of the C6 and C8 terms

    Returns
    -------
    damp_sum : float32
        :math:`s_6/(r^6 + R_0^6) + s_8\, Q_{ij}/(r^8 + R_0^8)`
    r4r2_ij : float32
        :math:`Q_{ij} = 3\, \text{r4r2}_i\, \text{r4r2}_j`
    r6 : float32
        :math:`r^6`
    r4 : float32
        :math:`r^4`
    den6_inv : float32
        :math:`1/(r^6 + R_0^6)`
    den8_inv : float32
        :math:`1/(r^8 + R_0^8)`

    Notes
    -----
    The damping radius is :math:`R_0 = a_1 \sqrt{Q_{ij}} + a_2`. The
    intermediate powers are returned so the caller can form the radial
    derivative without recomputing them.
    """
    r4r2_ij = 3.0 * r4r2_i * r4r2_j
    damp_radius = a1 * wp.sqrt(r4r2_ij) + a2

    # Even powers of the distance
    r_sq = r * r
    r4 = r_sq * r_sq
    r6 = r4 * r_sq
    r8 = r4 * r4

    # Matching powers of the damping radius
    r0_sq = damp_radius * damp_radius
    r0_4 = r0_sq * r0_sq
    r0_6 = r0_4 * r0_sq
    r0_8 = r0_4 * r0_4

    den6_inv = 1.0 / (r6 + r0_6)
    den8_inv = 1.0 / (r8 + r0_8)

    damp_sum = s6 * den6_inv + s8 * r4r2_ij * den8_inv
    return damp_sum, r4r2_ij, r6, r4, den6_inv, den8_inv
@wp.func
def _dispersion_energy_force(
    c6_ij: wp.float32,
    r: wp.float32,
    r_hat: wp.vec3f,
    damp_sum: wp.float32,
    r4r2_ij: wp.float32,
    r6: wp.float32,
    r4: wp.float32,
    den6_inv: wp.float32,
    den8_inv: wp.float32,
    s6: wp.float32,
    s8: wp.float32,
    s5_smoothing_on: wp.float32,
    s5_smoothing_off: wp.float32,
    inv_w: wp.float32,
) -> tuple[wp.float32, wp.vec3f]:
    r"""
    Switched pair dispersion energy and the matching fixed-C6 force.

    Parameters
    ----------
    c6_ij : float32
        C6 coefficient of the pair
    r : float32
        Pair distance
    r_hat : vec3f
        Unit vector the radial derivative is projected on
    damp_sum, r4r2_ij, r6, r4, den6_inv, den8_inv : float32
        Quantities returned by :func:`_bj_damping` for this pair
    s6, s8 : float32
        Scaling factors of the C6 and C8 terms
    s5_smoothing_on, s5_smoothing_off, inv_w : float32
        Arguments forwarded to :func:`_s5_switch`

    Returns
    -------
    e_ij_sw : float32
        Pair energy ``-c6_ij * damp_sum`` multiplied by the switch value
    F_direct : vec3f
        ``d(e_ij_sw)/dr`` at fixed ``c6_ij``, multiplied by ``r_hat``
    """
    sw, dsw_dr = _s5_switch(r, s5_smoothing_on, s5_smoothing_off, inv_w)

    # Unswitched pair energy
    e_pair = -c6_ij * damp_sum

    # Radial derivatives of the two damped terms
    r5 = r4 * r
    r7 = r6 * r
    dD6_dr = -6.0 * s6 * r5 * den6_inv * den6_inv  # NOSONAR (S125) "math formula"
    dD8_dr = -8.0 * s8 * r4r2_ij * r7 * den8_inv * den8_inv  # NOSONAR (S125)
    dE_dr = -c6_ij * (dD6_dr + dD8_dr)  # NOSONAR (S125) "math formula"

    # Product rule for energy * switch
    dE_dr_sw = sw * dE_dr + e_pair * dsw_dr  # NOSONAR (S125)

    e_ij_sw = e_pair * sw
    F_direct = dE_dr_sw * r_hat  # NOSONAR (S125) "math formula"
    return e_ij_sw, F_direct
@wp.func
def _unit_shift_to_cartesian(
    unit_shift: wp.vec3i,
    cell_mat: Any,
) -> Any:
    """
    Convert an integer unit-cell shift into a Cartesian shift vector.

    The three integer components are converted to the scalar type taken from
    ``cell_mat[0]`` and the resulting vector is left-multiplied onto
    ``cell_mat``.

    Parameters
    ----------
    unit_shift : vec3i
        Integer shift along the three cell directions
    cell_mat : Any
        Cell matrix; presumably a 3x3 matrix whose rows are the lattice
        vectors (confirm against the caller)

    Returns
    -------
    Any
        ``shift_vec * cell_mat`` in the dtype of ``cell_mat``
    """
    # Promote the integer shift to the cell's scalar type before the product
    shift_vec = wp.vec(
        cell_mat[0].dtype(unit_shift[0]),
        cell_mat[0].dtype(unit_shift[1]),
        cell_mat[0].dtype(unit_shift[2]),
    )
    return shift_vec * cell_mat
# ==============================================================================
# Kernels
# ==============================================================================
@wp.kernel(enable_backward=False)
def _compute_cartesian_shifts(
    cell: wp.array(dtype=Any),
    unit_shifts: wp.array2d(dtype=wp.vec3i),
    neighbor_matrix: wp.array2d(dtype=wp.int32),
    batch_idx: wp.array(dtype=wp.int32),
    fill_value: wp.int32,
    cartesian_shifts: wp.array2d(dtype=Any),
):
    r"""
    Pass 0: turn integer unit-cell shifts into Cartesian shift vectors.

    One thread handles one ``(atom, neighbor slot)`` entry of the neighbor
    matrix and writes the Cartesian shift for that entry.

    Parameters
    ----------
    cell : wp.array(dtype=Any)
        One cell matrix per system, indexed by system id; presumably a 3x3
        matrix type with lattice vectors as rows (confirm against the caller).
        Units should match position coordinates.
    unit_shifts : wp.array2d(dtype=vec3i)
        Integer unit cell shifts [num_atoms, max_neighbors]
    neighbor_matrix : wp.array2d(dtype=int32)
        Neighbor indices [num_atoms, max_neighbors]. See module docstring
        for more details.
    batch_idx : wp.array(dtype=int32)
        System index [num_atoms] for each atom
    fill_value : int32
        Entries of ``neighbor_matrix`` that are >= this value are padding
    cartesian_shifts : wp.array2d(dtype=Any)
        Output: Cartesian shift vectors [num_atoms, max_neighbors] in the
        same units as the cell

    Notes
    -----
    - The shift is

      .. math::

          \mathbf{s} = n_a \mathbf{a} + n_b \mathbf{b} + n_c \mathbf{c}

      where :math:`n_a, n_b, n_c` are the integer shifts and
      :math:`\mathbf{a}, \mathbf{b}, \mathbf{c}` the lattice vectors of the
      system that atom i belongs to.
    - Launch with dim=(num_atoms, max_neighbors).
    - Output entries for padding slots are not written.

    See Also
    --------
    :func:`_unit_shift_to_cartesian` : Performs the conversion for one entry
    :func:`_cn_kernel_nm` : Pass 1 - consumes the shifts written here
    :func:`dftd3` : High-level wrapper that orchestrates all passes
    """
    row, slot = wp.tid()

    # Guard against threads beyond the neighbor dimension
    if slot >= neighbor_matrix.shape[1]:
        return

    # Padding slots carry no neighbor; leave their output untouched
    if neighbor_matrix[row, slot] >= fill_value:
        return

    # The cell is selected through the system of the central atom
    lattice = cell[batch_idx[row]]
    shift = unit_shifts[row, slot]
    cartesian_shifts[row, slot] = _unit_shift_to_cartesian(shift, lattice)
@wp.kernel(enable_backward=False)
def _cn_kernel_nm(
    positions: wp.array(dtype=Any),
    numbers: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array2d(dtype=wp.int32),
    cartesian_shifts: wp.array2d(dtype=Any),
    covalent_radii: wp.array(dtype=wp.float32),
    k1: wp.float32,
    fill_value: wp.int32,
    periodic: bool,
    coord_num: wp.array(dtype=wp.float32),
):
    r"""
    Pass 1: coordination numbers from the neighbor matrix.

    Each thread owns one atom, sums the counting function over that atom's
    neighbor row, and stores the total.

    Algorithm
    ---------
    .. math::

        \text{CN}_i = \sum_{j \in \text{neighbors}(i)} f(r_{ij})

        f(r) = \frac{1}{1 + \exp\left[-k_1\left(\frac{r_{\text{cov}}}{r} - 1\right)\right]}

    where :math:`r_{\text{cov}} = r_{\text{cov}}[Z_i] + r_{\text{cov}}[Z_j]`.

    Parameters
    ----------
    positions : wp.array(dtype=Any)
        Atomic coordinates [num_atoms] (vec3 elements)
    numbers : wp.array(dtype=int32)
        Atomic numbers [num_atoms]; 0 marks a padding atom
    neighbor_matrix : wp.array2d(dtype=int32)
        Neighbor indices [num_atoms, max_neighbors]. See module docstring
        for more details.
    cartesian_shifts : wp.array2d(dtype=Any)
        Cartesian shifts [num_atoms, max_neighbors] for PBC (ignored if
        periodic=False), in same units as positions
    covalent_radii : wp.array(dtype=float32)
        Covalent radii indexed by atomic number, in same units as positions
    k1 : float32
        Steepness parameter for counting function (typically 16.0)
    fill_value : int32
        Entries of ``neighbor_matrix`` that are >= this value are padding
    periodic : bool
        If True, apply PBC using cartesian_shifts; if False, non-periodic
    coord_num : wp.array(dtype=float32)
        Output: coordination numbers [num_atoms] (dimensionless)

    Notes
    -----
    - Launch with dim=num_atoms (one thread per atom).
    - Threads for padding atoms (``numbers[i] == 0``) return without writing
      ``coord_num[i]``.
    - Neighbors that are padding slots, padding atoms, or closer than 1e-12
      contribute nothing.

    See Also
    --------
    :func:`_compute_cartesian_shifts` : Pass 0 - Computes Cartesian shifts for PBC
    :func:`_cn_counting` : Counting function evaluated per neighbor
    :func:`dftd3` : High-level wrapper that orchestrates all passes
    """
    atom_i = wp.tid()
    if atom_i >= numbers.shape[0]:
        return

    z_i = numbers[atom_i]
    # Padding atoms have no coordination number
    if z_i == 0:
        return

    num_slots = neighbor_matrix.shape[1]
    pos_i = positions[atom_i]
    rcov_i = covalent_radii[z_i]

    # Running sum kept in a local variable; written out once at the end
    cn_total = wp.float32(0.0)
    for slot in range(num_slots):
        atom_j = neighbor_matrix[atom_i, slot]
        # Padding slot: must be tested before atom_j is used as an index
        if atom_j >= fill_value:
            continue

        z_j = numbers[atom_j]
        # Padding atom
        if z_j == 0:
            continue

        # Distance only; the unit vector is not needed in this pass
        r, r_inv, r_hat, r_ij = _compute_distance_vector_pbc(
            pos_i,
            positions[atom_j],
            cartesian_shifts[atom_i, slot],
            periodic,
            False,
        )
        if r < 1e-12:
            continue

        # Counting function value only; derivative not requested
        f_cn, dCN_dr = _cn_counting(r_inv, rcov_i, covalent_radii[z_j], k1, False)
        cn_total += f_cn

    coord_num[atom_i] = cn_total
@wp.kernel(enable_backward=False)
def _direct_forces_and_dE_dCN_kernel_nm(  # NOSONAR (S1542) "math formula"
    positions: wp.array(dtype=Any),
    numbers: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array2d(dtype=wp.int32),
    cartesian_shifts: wp.array2d(dtype=Any),
    coord_num: wp.array(dtype=wp.float32),
    r4r2: wp.array(dtype=wp.float32),
    c6_reference: wp.array4d(dtype=wp.float32),
    coord_num_ref: wp.array4d(dtype=wp.float32),
    k3: wp.float32,
    a1: wp.float32,
    a2: wp.float32,
    s6: wp.float32,
    s8: wp.float32,
    s5_smoothing_on: wp.float32,
    s5_smoothing_off: wp.float32,
    inv_w: wp.float32,
    fill_value: wp.int32,
    periodic: bool,
    batch_idx: wp.array(dtype=wp.int32),
    compute_virial: bool,
    dE_dCN: wp.array(dtype=wp.float32),  # NOSONAR (S125) "math formula"
    forces: wp.array(dtype=wp.vec3f),
    energy: wp.array(dtype=wp.float32),
    virial: wp.array(dtype=Any),
):
    r"""
    Pass 2: Compute direct forces, energy, and accumulate dE/dCN per atom.

    Computes the switched dispersion energy and the forces at constant CN,
    and accumulates the dE/dCN contributions of each atom for use in Pass 3.

    Parameters
    ----------
    positions : wp.array(dtype=Any)
        Atomic coordinates [num_atoms] (vec3 elements)
    numbers : wp.array(dtype=int32)
        Atomic numbers [num_atoms]; 0 marks a padding atom
    neighbor_matrix : wp.array2d(dtype=int32)
        Neighbor indices [num_atoms, max_neighbors]. See module docstring
        for more details.
    cartesian_shifts : wp.array2d(dtype=Any)
        Cartesian shifts [num_atoms, max_neighbors] for PBC, in same units as positions
    coord_num : wp.array(dtype=float32)
        Coordination numbers [num_atoms] from Pass 1 (dimensionless)
    r4r2 : wp.array(dtype=float32)
        <r⁴>/<r²> expectation values [max_Z+1] (dimensionless)
    c6_reference : wp.array4d(dtype=float32)
        C6 reference [max_Z+1, max_Z+1, 5, 5] in energy x distance^6 units
    coord_num_ref : wp.array4d(dtype=float32)
        CN reference grid [max_Z+1, max_Z+1, 5, 5] (dimensionless)
    k3 : float32
        CN interpolation width (typically -4.0, dimensionless)
    a1, a2 : float32
        Becke-Johnson damping parameters (a1 dimensionless, a2 in distance units)
    s6, s8 : float32
        Scaling factors for C6 and C8 terms (dimensionless)
    s5_smoothing_on, s5_smoothing_off : float32
        S5 switching radii in same units as positions
    inv_w : float32
        Precomputed 1/(s5_off - s5_on) in inverse distance units
    fill_value : int32
        Entries of ``neighbor_matrix`` that are >= this value are padding
    periodic : bool
        If True, apply PBC using cartesian_shifts
    batch_idx : wp.array(dtype=int32)
        System index [num_atoms]
    compute_virial : bool
        If True, also accumulate the direct-force virial into ``virial``
    dE_dCN : wp.array(dtype=float32)
        Output: accumulated dE/dCN [num_atoms] in energy units
    forces : wp.array(dtype=vec3f)
        Output: direct forces [num_atoms] in energy/distance units
    energy : wp.array(dtype=float32)
        Output: system energies, indexed by ``batch_idx``; added atomically
    virial : wp.array(dtype=Any)
        Output: per-system virial, indexed by ``batch_idx``; receives
        ``-0.5 * sum_j outer(F_direct, r_ij)`` per atom, added atomically.
        Only touched when ``compute_virial`` is True.

    Notes
    -----
    - Launch with dim=num_atoms (one thread per atom).
    - Force, energy, and virial are accumulated in float64 locals and cast to
      float32 on output; dE/dCN is accumulated in float32.
    - Direct forces are :math:`-(\partial E/\partial r)|_\text{CN}`, without
      the chain rule term.
    - ``dE_dCN[i]`` is :math:`\sum_j \partial E_{ij}/\partial \text{CN}_i`
      over all pairs containing atom i. Each term carries the same S5
      switching factor as the pair energy, so the chain-rule forces of Pass 3
      stay consistent with the switched energy and vanish for pairs beyond
      ``s5_smoothing_off``.
    - Each pair energy is weighted by 0.5 because the pair is visited from
      both of its atoms.
    - Neighbors that are padding slots, padding atoms, closer than 1e-12, or
      whose interpolated C6 is below 1e-12 are skipped.
    - Threads for padding atoms return without writing any output.

    See Also
    --------
    :func:`_compute_cartesian_shifts` : Pass 0 - Computes Cartesian shifts for PBC
    :func:`_cn_kernel_nm` : Pass 1 - Computes coordination numbers used here
    :func:`_cn_forces_contrib_kernel_nm` : Pass 3 - Uses dE/dCN values accumulated here
    :func:`_c6ab_interpolate` : Called to interpolate C6 coefficients
    :func:`_s5_switch` : Called for cutoff smoothing
    :func:`dftd3` : High-level wrapper that orchestrates all passes
    """
    atom_i = wp.tid()
    if atom_i >= numbers.shape[0]:
        return

    # skip padding atoms
    if numbers[atom_i] == 0:
        return

    max_neighbors = neighbor_matrix.shape[1]
    pos_i = positions[atom_i]
    cn_i = coord_num[atom_i]
    z_i = numbers[atom_i]
    r4r2_i = r4r2[z_i]

    # Local accumulators (float64 for force and energy; float32 for dE/dCN)
    F_acc = wp.vec3d()  # NOSONAR (S117) "math formula"
    dE_dCN_acc = wp.float32(0.0)  # NOSONAR (S117) "math formula"
    energy_acc = wp.float64(0.0)

    # Initialize virial accumulator
    if compute_virial:
        virial_acc = wp.mat33d()

    for neighbor_idx in range(max_neighbors):
        atom_j = neighbor_matrix[atom_i, neighbor_idx]
        # Padding slot: must be tested before atom_j is used as an index
        if atom_j >= fill_value:
            continue

        # skip padding atoms
        if numbers[atom_j] == 0:
            continue

        # Geometry
        r, r_inv, r_hat, r_ij = _compute_distance_vector_pbc(
            pos_i,
            positions[atom_j],
            cartesian_shifts[atom_i, neighbor_idx],
            periodic,
            True,
        )
        if r < 1e-12:
            continue

        cn_j = coord_num[atom_j]
        z_j = numbers[atom_j]

        # C6 interpolation
        c6ab_mat = c6_reference[z_i, z_j]
        cnref_i_mat = coord_num_ref[z_i, z_j]
        cnref_j_mat = coord_num_ref[z_j, z_i]
        c6_ij, dC6_dCNi, dC6_dCNj = _c6ab_interpolate(  # NOSONAR (S125) "math formula"
            cn_i, cn_j, c6ab_mat, cnref_i_mat, cnref_j_mat, k3
        )
        if c6_ij < 1e-12:
            continue

        # BJ damping
        damp_sum, r4r2_ij, r6, r4, den6_inv, den8_inv = _bj_damping(
            r, r4r2_i, r4r2[z_j], a1, a2, s6, s8
        )

        # Energy and direct force (already include the S5 switch)
        e_ij_sw, F_direct = _dispersion_energy_force(
            c6_ij,
            r,
            r_hat,
            damp_sum,
            r4r2_ij,
            r6,
            r4,
            den6_inv,
            den8_inv,
            s6,
            s8,
            s5_smoothing_on,
            s5_smoothing_off,
            inv_w,
        )

        # The pair energy is -C6 * damp_sum * sw, so its derivative with
        # respect to CN_i must carry the same switching factor; otherwise the
        # Pass 3 forces are not the gradient of the switched energy.
        sw, dsw_dr = _s5_switch(r, s5_smoothing_on, s5_smoothing_off, inv_w)

        # Accumulate in local variables
        F_acc += wp.vec3d(F_direct)  # NOSONAR (S117) "math formula"
        energy_acc += wp.float64(e_ij_sw)
        dE_dCN_acc += -damp_sum * dC6_dCNi * sw  # NOSONAR (S117) "math formula"

        # Accumulate virial if requested
        if compute_virial:
            virial_acc += wp.mat33d(wp.outer(F_direct, r_ij))

    # Per-atom outputs are written once; per-system outputs need atomics
    # because several atoms share the same batch entry.
    forces[atom_i] = wp.vec3f(F_acc)
    dE_dCN[atom_i] = dE_dCN_acc
    # 0.5 compensates for every pair being visited from both atoms
    wp.atomic_add(energy, batch_idx[atom_i], 0.5 * wp.float32(energy_acc))

    # Add virial contribution with -0.5 scaling for correct sign and double counting
    if compute_virial:
        wp.atomic_add(virial, batch_idx[atom_i], -0.5 * wp.mat33f(virial_acc))
@wp.kernel(enable_backward=False)
def _cn_forces_contrib_kernel_nm(
    positions: wp.array(dtype=Any),
    numbers: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array2d(dtype=wp.int32),
    cartesian_shifts: wp.array2d(dtype=Any),
    covalent_radii: wp.array(dtype=wp.float32),
    dE_dCN: wp.array(dtype=wp.float32),  # NOSONAR (S125) "math formula"
    k1: wp.float32,
    fill_value: wp.int32,
    periodic: bool,
    batch_idx: wp.array(dtype=wp.int32),
    compute_virial: bool,
    forces: wp.array(dtype=wp.vec3f),
    virial: wp.array(dtype=Any),
):
    """
    Pass 3: add the CN-dependent (chain-rule) force term, neighbor-matrix format.

    The direct forces written in Pass 2 are augmented in place with the term
    that arises because the energy depends on coordination numbers, which in
    turn depend on interatomic distances. Only distances and the counting
    function derivative are evaluated here; C6 interpolation and damping are
    not repeated.

    Parameters
    ----------
    positions : wp.array(dtype=vec3)
        Atomic coordinates [num_atoms]
    numbers : wp.array(dtype=int32)
        Atomic numbers [num_atoms]; entries equal to 0 mark padding atoms
    neighbor_matrix : wp.array2d(dtype=int32)
        Neighbor indices [num_atoms, max_neighbors]
    cartesian_shifts : wp.array2d(dtype=vec3)
        Cartesian shifts [num_atoms, max_neighbors] for PBC, same units as positions
    covalent_radii : wp.array(dtype=float32)
        Covalent radii indexed by atomic number
    dE_dCN : wp.array(dtype=float32)
        Precomputed dE/dCN [num_atoms] from Pass 2
    k1 : float32
        Steepness parameter forwarded to the CN counting function
    fill_value : int32
        Neighbor entries with index >= fill_value are padding and are skipped
    periodic : bool
        Forwarded to the distance helper; if True, cartesian_shifts are applied
    batch_idx : wp.array(dtype=int32)
        System index [num_atoms] selecting the virial slot of each atom
    compute_virial : bool
        If True, the chain-term virial is accumulated into ``virial``
    forces : wp.array(dtype=vec3f)
        Input/Output: chain term is added to the direct forces [num_atoms]
    virial : wp.array(dtype=mat33)
        Input/Output: per-system virial; receives ``-0.5 * sum_j outer(F_chain, r_ij)``
        through an atomic add when ``compute_virial`` is True

    Notes
    -----
    - Launch with dim=num_atoms (one thread per atom)
    - Each thread walks its own neighbor row and accumulates in local variables
    - Uses precomputed dE_dCN[i] = :math:`\\sum_k \\partial E_{ik}/\\partial \\text{CN}_i`
    - Pairs closer than 1e-12 are ignored

    See Also
    --------
    :func:`_direct_forces_and_dE_dCN_kernel_nm` : Pass 2 - Computes dE/dCN values used here
    :func:`dftd3` : High-level wrapper that orchestrates all passes
    """
    atom_i = wp.tid()
    if atom_i >= numbers.shape[0]:
        return

    # Atomic number 0 marks a padding atom: nothing to add for it
    z_i = numbers[atom_i]
    if z_i == 0:
        return

    num_slots = neighbor_matrix.shape[1]
    grad_cn_i = dE_dCN[atom_i]
    origin = positions[atom_i]
    radius_i = covalent_radii[z_i]

    # Per-thread double-precision accumulators
    chain_force_acc = wp.vec3d()
    if compute_virial:
        chain_virial_acc = wp.mat33d()

    for slot in range(num_slots):
        atom_j = neighbor_matrix[atom_i, slot]
        # Unused slots of the neighbor row
        if atom_j >= fill_value:
            continue

        # Padding neighbor
        z_j = numbers[atom_j]
        if z_j == 0:
            continue

        # Separation of the pair (with optional periodic image shift)
        dist, inv_dist, unit_vec, sep_vec = _compute_distance_vector_pbc(
            origin,
            positions[atom_j],
            cartesian_shifts[atom_i, slot],
            periodic,
            True,
        )
        if dist < 1e-12:
            continue

        # Counting function and its radial derivative for this pair
        count_val, dcount_dr = _cn_counting(
            inv_dist, radius_i, covalent_radii[z_j], k1, True
        )

        # Chain rule: both atoms' dE/dCN couple to the same pair distance
        grad_cn_j = dE_dCN[atom_j]
        chain_slope = (grad_cn_i + grad_cn_j) * dcount_dr
        pair_force = chain_slope * unit_vec
        chain_force_acc += wp.vec3d(pair_force)

        if compute_virial:
            chain_virial_acc += wp.mat33d(wp.outer(pair_force, sep_vec))

    # One read-modify-write per atom; narrow the float64 sum to the float32 output
    forces[atom_i] = forces[atom_i] + wp.vec3f(chain_force_acc)

    # -0.5 factor: sign convention and double counting of pairs
    if compute_virial:
        wp.atomic_add(virial, batch_idx[atom_i], -0.5 * wp.mat33f(chain_virial_acc))
# ==============================================================================
# Neighbor List Kernels
# ==============================================================================
@wp.kernel(enable_backward=False)
def _compute_cartesian_shifts_nl(
    cell: wp.array(dtype=Any),
    unit_shifts: wp.array(dtype=wp.vec3i),
    neighbor_ptr: wp.array(dtype=wp.int32),
    batch_idx: wp.array(dtype=wp.int32),
    cartesian_shifts: wp.array(dtype=Any),
):
    """
    Pass 0 (CSR format): turn integer unit-cell shifts into Cartesian vectors.

    Every edge owned by an atom is converted with the cell matrix of the
    system that atom belongs to.

    Parameters
    ----------
    cell : wp.array(dtype=mat33)
        Unit cell lattice vectors, one matrix per system
    unit_shifts : wp.array(dtype=vec3i)
        Integer unit cell shifts [num_edges]
    neighbor_ptr : wp.array(dtype=int32)
        CSR row pointers [num_atoms+1]; edges of atom i are
        neighbor_ptr[i] .. neighbor_ptr[i+1]-1
    batch_idx : wp.array(dtype=int32)
        System index [num_atoms] for each atom; its length bounds the thread range
    cartesian_shifts : wp.array(dtype=vec3)
        Output: Cartesian shift vectors [num_edges]

    Notes
    -----
    Launch with dim=num_atoms (one thread per atom). A thread writes only the
    edges inside its own CSR row.

    See Also
    --------
    :func:`_cn_kernel_nl` : Uses computed Cartesian shifts for PBC
    :func:`_direct_forces_and_dE_dCN_kernel_nl` : Uses computed Cartesian shifts for PBC
    :func:`_cn_forces_contrib_kernel_nl` : Uses computed Cartesian shifts for PBC
    """
    atom_i = wp.tid()

    # The atom count is taken from the length of batch_idx
    if atom_i >= batch_idx.shape[0]:
        return

    # Lattice of the system this atom belongs to
    lattice = cell[batch_idx[atom_i]]

    # CSR row of this atom
    row_begin = neighbor_ptr[atom_i]
    row_end = neighbor_ptr[atom_i + 1]

    for edge in range(row_begin, row_end):
        cartesian_shifts[edge] = _unit_shift_to_cartesian(unit_shifts[edge], lattice)
@wp.kernel(enable_backward=False)
def _cn_kernel_nl(
    positions: wp.array(dtype=Any),
    numbers: wp.array(dtype=wp.int32),
    idx_j: wp.array(dtype=wp.int32),
    neighbor_ptr: wp.array(dtype=wp.int32),
    cartesian_shifts: wp.array(dtype=Any),
    covalent_radii: wp.array(dtype=wp.float32),
    k1: wp.float32,
    periodic: bool,
    coord_num: wp.array(dtype=wp.float32),
):
    """
    Pass 1 (CSR format): accumulate the coordination number of every atom.

    Parameters
    ----------
    positions : wp.array(dtype=vec3)
        Atomic coordinates [num_atoms]
    numbers : wp.array(dtype=int32)
        Atomic numbers [num_atoms]; entries equal to 0 mark padding atoms
    idx_j : wp.array(dtype=int32)
        Destination atom indices [num_edges] in CSR format
    neighbor_ptr : wp.array(dtype=int32)
        CSR row pointers [num_atoms+1]
    cartesian_shifts : wp.array(dtype=vec3)
        Cartesian shifts [num_edges] for PBC
    covalent_radii : wp.array(dtype=float32)
        Covalent radii indexed by atomic number
    k1 : float32
        Steepness parameter for the counting function
    periodic : bool
        Forwarded to the distance helper; if True, cartesian_shifts are applied
    coord_num : wp.array(dtype=float32)
        Output: coordination numbers [num_atoms]; only non-padding atoms are written

    Notes
    -----
    Launch with dim=num_atoms (one thread per atom). Pairs closer than 1e-12
    and pairs involving padding atoms do not contribute.
    """
    atom_i = wp.tid()
    if atom_i >= numbers.shape[0]:
        return

    # Padding atoms (atomic number 0) are left untouched
    z_i = numbers[atom_i]
    if z_i == 0:
        return

    origin = positions[atom_i]
    radius_i = covalent_radii[z_i]

    # Running sum kept in a thread-local variable
    total_cn = wp.float32(0.0)

    # CSR row of this atom
    row_begin = neighbor_ptr[atom_i]
    row_end = neighbor_ptr[atom_i + 1]

    for edge in range(row_begin, row_end):
        atom_j = idx_j[edge]

        # Padding neighbor
        z_j = numbers[atom_j]
        if z_j == 0:
            continue

        # Pair distance, shifted by the periodic image when requested
        dist, inv_dist, unit_vec, sep_vec = _compute_distance_vector_pbc(
            origin, positions[atom_j], cartesian_shifts[edge], periodic, False
        )
        if dist < 1e-12:
            continue

        # Counting-function value of this pair
        count_val, dcount_dr = _cn_counting(
            inv_dist, radius_i, covalent_radii[z_j], k1, False
        )
        total_cn += count_val

    # Single store per atom
    coord_num[atom_i] = total_cn
@wp.kernel(enable_backward=False)
def _direct_forces_and_dE_dCN_kernel_nl(  # NOSONAR (S1542) "math formula"
    positions: wp.array(dtype=Any),
    numbers: wp.array(dtype=wp.int32),
    idx_j: wp.array(dtype=wp.int32),
    neighbor_ptr: wp.array(dtype=wp.int32),
    cartesian_shifts: wp.array(dtype=Any),
    coord_num: wp.array(dtype=wp.float32),
    r4r2: wp.array(dtype=wp.float32),
    c6_reference: wp.array4d(dtype=wp.float32),
    coord_num_ref: wp.array4d(dtype=wp.float32),
    k3: wp.float32,
    a1: wp.float32,
    a2: wp.float32,
    s6: wp.float32,
    s8: wp.float32,
    s5_smoothing_on: wp.float32,
    s5_smoothing_off: wp.float32,
    inv_w: wp.float32,
    periodic: bool,
    batch_idx: wp.array(dtype=wp.int32),
    compute_virial: bool,
    dE_dCN: wp.array(dtype=wp.float32),  # NOSONAR (S125) "math formula"
    forces: wp.array(dtype=wp.vec3f),
    energy: wp.array(dtype=wp.float32),
    virial: wp.array(dtype=Any),
):
    """
    Pass 2: Compute direct forces, energy, and accumulate dE/dCN using
    CSR neighbor list.

    For every edge of the thread's atom the kernel interpolates C6, evaluates
    the BJ damping terms and the pair energy / direct force, and sums the
    results per atom. Padding atoms (atomic number 0), pairs closer than
    1e-12 and pairs whose interpolated C6 is below 1e-12 are skipped.

    Parameters
    ----------
    positions : wp.array(dtype=vec3)
        Atomic coordinates [num_atoms]
    numbers : wp.array(dtype=int32)
        Atomic numbers [num_atoms]; entries equal to 0 mark padding atoms
    idx_j : wp.array(dtype=int32)
        Destination atom indices [num_edges] in CSR format
    neighbor_ptr : wp.array(dtype=int32)
        CSR row pointers [num_atoms+1]
    cartesian_shifts : wp.array(dtype=vec3)
        Cartesian shifts [num_edges] for PBC
    coord_num : wp.array(dtype=float32)
        Coordination numbers [num_atoms] (read-only here)
    r4r2 : wp.array(dtype=float32)
        Per-element values indexed by atomic number, forwarded to ``_bj_damping``
    c6_reference : wp.array4d(dtype=float32)
        C6 reference table; indexed by ``[z_i, z_j]`` to obtain the pair slice
    coord_num_ref : wp.array4d(dtype=float32)
        CN reference table; indexed by ``[z_i, z_j]`` and ``[z_j, z_i]``
    k3 : float32
        Parameter forwarded to ``_c6ab_interpolate``
    a1, a2 : float32
        Becke-Johnson damping parameters forwarded to ``_bj_damping``
    s6, s8 : float32
        Scaling factors forwarded to ``_bj_damping`` and ``_dispersion_energy_force``
    s5_smoothing_on, s5_smoothing_off, inv_w : float32
        Switching-function parameters forwarded to ``_dispersion_energy_force``
    periodic : bool
        Forwarded to the distance helper; if True, cartesian_shifts are applied
    batch_idx : wp.array(dtype=int32)
        System index [num_atoms]; selects the energy / virial slot of each atom
    compute_virial : bool
        If True, the direct-force virial is accumulated into ``virial``
    dE_dCN : wp.array(dtype=float32)
        Output: per-atom sum over neighbors of ``-damp_sum * dC6_dCNi`` [num_atoms]
    forces : wp.array(dtype=vec3f)
        Output: direct force per atom [num_atoms] (overwritten, not accumulated)
    energy : wp.array(dtype=float32)
        Input/Output: receives ``0.5 * sum_j e_ij`` per atom via atomic add
    virial : wp.array(dtype=mat33)
        Input/Output: receives ``-0.5 * sum_j outer(F_direct, r_ij)`` via atomic
        add when ``compute_virial`` is True

    Notes
    -----
    Launch with dim=num_atoms (one thread per atom).
    """
    atom_i = wp.tid()
    if atom_i >= numbers.shape[0]:
        return
    # skip padding atoms
    if numbers[atom_i] == 0:
        return
    # Per-atom quantities reused for every neighbor
    pos_i = positions[atom_i]
    cn_i = coord_num[atom_i]
    z_i = numbers[atom_i]
    r4r2_i = r4r2[z_i]
    # Accumulate in local variables: force, energy and virial sums use float64
    # for better precision, while the dE/dCN sum stays in float32
    F_acc = wp.vec3d()  # NOSONAR (S117) "math formula"
    dE_dCN_acc = wp.float32(0.0)  # NOSONAR (S117) "math formula"
    energy_acc = wp.float64(0.0)
    # Initialize virial accumulator
    if compute_virial:
        virial_acc = wp.mat33d()
    # Iterate over neighbors using CSR pointers
    j_range_start = neighbor_ptr[atom_i]
    j_range_end = neighbor_ptr[atom_i + 1]
    for edge_idx in range(j_range_start, j_range_end):
        atom_j = idx_j[edge_idx]
        # skip padding atoms
        if numbers[atom_j] == 0:
            continue
        # Geometry
        r, r_inv, r_hat, r_ij = _compute_distance_vector_pbc(
            pos_i, positions[atom_j], cartesian_shifts[edge_idx], periodic, True
        )
        # Guard against coincident atoms
        if r < 1e-12:
            continue
        cn_j = coord_num[atom_j]
        z_j = numbers[atom_j]
        # C6 interpolation: reference slices for the (z_i, z_j) element pair;
        # the CN reference is looked up in both index orders
        c6ab_mat = c6_reference[z_i, z_j]
        cnref_i_mat = coord_num_ref[z_i, z_j]
        cnref_j_mat = coord_num_ref[z_j, z_i]
        c6_ij, dC6_dCNi, dC6_dCNj = _c6ab_interpolate(  # NOSONAR (S125) "math formula"
            cn_i, cn_j, c6ab_mat, cnref_i_mat, cnref_j_mat, k3
        )
        # Negligible C6: the pair contributes nothing
        if c6_ij < 1e-12:
            continue
        # BJ damping
        damp_sum, r4r2_ij, r6, r4, den6_inv, den8_inv = _bj_damping(
            r, r4r2_i, r4r2[z_j], a1, a2, s6, s8
        )
        # Energy and direct force
        e_ij_sw, F_direct = _dispersion_energy_force(
            c6_ij,
            r,
            r_hat,
            damp_sum,
            r4r2_ij,
            r6,
            r4,
            den6_inv,
            den8_inv,
            s6,
            s8,
            s5_smoothing_on,
            s5_smoothing_off,
            inv_w,
        )
        # Accumulate in local variables
        F_acc += wp.vec3d(F_direct)
        energy_acc += wp.float64(e_ij_sw)
        # Only the derivative with respect to this atom's CN is summed here
        dE_dCN_acc += -damp_sum * dC6_dCNi  # NOSONAR (S117) "math formula"
        # Accumulate virial if requested
        if compute_virial:
            virial_acc += wp.mat33d(wp.outer(F_direct, r_ij))
    # Write final results once (atomic only for shared batch array)
    # Convert from float64 accumulation to float32 output
    forces[atom_i] = wp.vec3f(F_acc)
    dE_dCN[atom_i] = wp.float32(dE_dCN_acc)
    # 0.5 factor on the per-atom energy sum (presumably each pair appears in
    # both atoms' rows -- confirm against the neighbor-list builder)
    wp.atomic_add(energy, batch_idx[atom_i], 0.5 * wp.float32(energy_acc))
    # Add virial contribution with -0.5 scaling for correct sign and double counting
    if compute_virial:
        wp.atomic_add(virial, batch_idx[atom_i], -0.5 * wp.mat33f(virial_acc))
@wp.kernel(enable_backward=False)
def _cn_forces_contrib_kernel_nl(
    positions: wp.array(dtype=Any),
    numbers: wp.array(dtype=wp.int32),
    idx_j: wp.array(dtype=wp.int32),
    neighbor_ptr: wp.array(dtype=wp.int32),
    cartesian_shifts: wp.array(dtype=Any),
    covalent_radii: wp.array(dtype=wp.float32),
    dE_dCN: wp.array(dtype=wp.float32),  # NOSONAR (S125) "math formula"
    k1: wp.float32,
    periodic: bool,
    batch_idx: wp.array(dtype=wp.int32),
    compute_virial: bool,
    forces: wp.array(dtype=wp.vec3f),
    virial: wp.array(dtype=Any),
):
    """
    Pass 3 (CSR format): add the CN-dependent (chain-rule) force term.

    Parameters
    ----------
    positions : wp.array(dtype=vec3)
        Atomic coordinates [num_atoms]
    numbers : wp.array(dtype=int32)
        Atomic numbers [num_atoms]; entries equal to 0 mark padding atoms
    idx_j : wp.array(dtype=int32)
        Destination atom indices [num_edges] in CSR format
    neighbor_ptr : wp.array(dtype=int32)
        CSR row pointers [num_atoms+1]
    cartesian_shifts : wp.array(dtype=vec3)
        Cartesian shifts [num_edges] for PBC
    covalent_radii : wp.array(dtype=float32)
        Covalent radii indexed by atomic number
    dE_dCN : wp.array(dtype=float32)
        Precomputed dE/dCN [num_atoms] from Pass 2
    k1 : float32
        Steepness parameter forwarded to the CN counting function
    periodic : bool
        Forwarded to the distance helper; if True, cartesian_shifts are applied
    batch_idx : wp.array(dtype=int32)
        System index [num_atoms] selecting the virial slot of each atom
    compute_virial : bool
        If True, the chain-term virial is accumulated into ``virial``
    forces : wp.array(dtype=vec3f)
        Input/Output: chain term is added to the direct forces [num_atoms]
    virial : wp.array(dtype=mat33)
        Input/Output: receives ``-0.5 * sum_j outer(F_chain, r_ij)`` via atomic
        add when ``compute_virial`` is True

    Notes
    -----
    Launch with dim=num_atoms (one thread per atom).
    """
    atom_i = wp.tid()
    if atom_i >= numbers.shape[0]:
        return

    # Atomic number 0 marks a padding atom: nothing to add for it
    z_i = numbers[atom_i]
    if z_i == 0:
        return

    grad_cn_i = dE_dCN[atom_i]
    origin = positions[atom_i]
    radius_i = covalent_radii[z_i]

    # Per-thread double-precision accumulators
    chain_force_acc = wp.vec3d()
    if compute_virial:
        chain_virial_acc = wp.mat33d()

    # CSR row of this atom
    row_begin = neighbor_ptr[atom_i]
    row_end = neighbor_ptr[atom_i + 1]

    for edge in range(row_begin, row_end):
        atom_j = idx_j[edge]

        # Padding neighbor
        z_j = numbers[atom_j]
        if z_j == 0:
            continue

        # Separation of the pair (with optional periodic image shift)
        dist, inv_dist, unit_vec, sep_vec = _compute_distance_vector_pbc(
            origin, positions[atom_j], cartesian_shifts[edge], periodic, True
        )
        if dist < 1e-12:
            continue

        # Counting function and its radial derivative for this pair
        count_val, dcount_dr = _cn_counting(
            inv_dist, radius_i, covalent_radii[z_j], k1, True
        )

        # Chain rule: both atoms' dE/dCN couple to the same pair distance
        grad_cn_j = dE_dCN[atom_j]
        chain_slope = (grad_cn_i + grad_cn_j) * dcount_dr
        pair_force = chain_slope * unit_vec
        chain_force_acc += wp.vec3d(pair_force)

        if compute_virial:
            chain_virial_acc += wp.mat33d(wp.outer(pair_force, sep_vec))

    # One read-modify-write per atom; narrow the float64 sum to the float32 output
    forces[atom_i] = forces[atom_i] + wp.vec3f(chain_force_acc)

    # -0.5 factor: sign convention and double counting of pairs
    if compute_virial:
        wp.atomic_add(virial, batch_idx[atom_i], -0.5 * wp.mat33f(chain_virial_acc))
# ==============================================================================
# Kernel Overload Registration
# ==============================================================================
# Type constants for overload generation.
# The three lists are index-aligned: entry 0 is the float32 family, entry 1 the
# float64 family (scalar, 3-vector, 3x3 matrix).
T = [wp.float32, wp.float64]
V = [wp.vec3f, wp.vec3d]
M = [wp.mat33f, wp.mat33d]
# Overload dictionaries keyed by scalar type (an element of T).
# Neighbor matrix kernel overload dictionaries
_compute_cartesian_shifts_overload = {}
_cn_kernel_nm_overload = {}
_direct_forces_and_dE_dCN_kernel_nm_overload = {}
_cn_forces_contrib_kernel_nm_overload = {}
# Neighbor list kernel overload dictionaries (CSR format)
_compute_cartesian_shifts_nl_overload = {}
_cn_kernel_nl_overload = {}
_direct_forces_and_dE_dCN_kernel_nl_overload = {}
_cn_forces_contrib_kernel_nl_overload = {}
# Register overloads for all kernel variants.
# Each list gives the concrete argument types in kernel-parameter order; only
# the position/cell-related arrays vary with precision (v, m), everything else
# is fixed to float32 / int32. The trailing comments name the parameter each
# entry binds (taken from the kernel signatures and launch calls in this file).
for t, v, m in zip(T, V, M):
    _compute_cartesian_shifts_overload[t] = wp.overload(
        _compute_cartesian_shifts,
        [
            wp.array(dtype=m),  # cell
            wp.array2d(dtype=wp.vec3i),  # unit shifts
            wp.array2d(dtype=wp.int32),  # neighbor_matrix
            wp.array(dtype=wp.int32),  # batch_idx
            wp.int32,  # fill_value
            wp.array2d(dtype=v),  # cartesian_shifts (output)
        ],
    )
    _cn_kernel_nm_overload[t] = wp.overload(
        _cn_kernel_nm,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=wp.int32),  # numbers
            wp.array2d(dtype=wp.int32),  # neighbor_matrix
            wp.array2d(dtype=v),  # cartesian_shifts
            wp.array(dtype=wp.float32),  # covalent_radii
            wp.float32,  # k1
            wp.int32,  # fill_value
            wp.bool,  # periodic
            wp.array(dtype=wp.float32),  # coord_num (output)
        ],
    )
    _direct_forces_and_dE_dCN_kernel_nm_overload[t] = wp.overload(
        _direct_forces_and_dE_dCN_kernel_nm,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=wp.int32),  # numbers
            wp.array2d(dtype=wp.int32),  # neighbor_matrix
            wp.array2d(dtype=v),  # cartesian_shifts
            wp.array(dtype=wp.float32),  # coord_num
            wp.array(dtype=wp.float32),  # r4r2
            wp.array4d(dtype=wp.float32),  # c6_reference
            wp.array4d(dtype=wp.float32),  # coord_num_ref
            wp.float32,  # k3
            wp.float32,  # a1
            wp.float32,  # a2
            wp.float32,  # s6
            wp.float32,  # s8
            wp.float32,  # s5_smoothing_on
            wp.float32,  # s5_smoothing_off
            wp.float32,  # inv_w
            wp.int32,  # fill_value
            wp.bool,  # periodic
            wp.array(dtype=wp.int32),  # batch_idx
            wp.bool,  # compute_virial
            wp.array(dtype=wp.float32),  # dE_dCN (output)
            wp.array(dtype=wp.vec3f),  # forces (output)
            wp.array(dtype=wp.float32),  # energy (output)
            wp.array(dtype=wp.mat33f),  # virial (output)
        ],
    )
    _cn_forces_contrib_kernel_nm_overload[t] = wp.overload(
        _cn_forces_contrib_kernel_nm,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=wp.int32),  # numbers
            wp.array2d(dtype=wp.int32),  # neighbor_matrix
            wp.array2d(dtype=v),  # cartesian_shifts
            wp.array(dtype=wp.float32),  # covalent_radii
            wp.array(dtype=wp.float32),  # dE_dCN
            wp.float32,  # k1
            wp.int32,  # fill_value
            wp.bool,  # periodic
            wp.array(dtype=wp.int32),  # batch_idx
            wp.bool,  # compute_virial
            wp.array(dtype=wp.vec3f),  # forces (in/out)
            wp.array(dtype=wp.mat33f),  # virial (in/out)
        ],
    )
    # Neighbor list kernel overloads (CSR format)
    _compute_cartesian_shifts_nl_overload[t] = wp.overload(
        _compute_cartesian_shifts_nl,
        [
            wp.array(dtype=m),  # cell
            wp.array(dtype=wp.vec3i),  # unit_shifts
            wp.array(dtype=wp.int32),  # neighbor_ptr
            wp.array(dtype=wp.int32),  # batch_idx
            wp.array(dtype=v),  # cartesian_shifts (output)
        ],
    )
    _cn_kernel_nl_overload[t] = wp.overload(
        _cn_kernel_nl,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=wp.int32),  # numbers
            wp.array(dtype=wp.int32),  # idx_j
            wp.array(dtype=wp.int32),  # neighbor_ptr
            wp.array(dtype=v),  # cartesian_shifts
            wp.array(dtype=wp.float32),  # covalent_radii
            wp.float32,  # k1
            wp.bool,  # periodic
            wp.array(dtype=wp.float32),  # coord_num (output)
        ],
    )
    _direct_forces_and_dE_dCN_kernel_nl_overload[t] = wp.overload(
        _direct_forces_and_dE_dCN_kernel_nl,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=wp.int32),  # numbers
            wp.array(dtype=wp.int32),  # idx_j
            wp.array(dtype=wp.int32),  # neighbor_ptr
            wp.array(dtype=v),  # cartesian_shifts
            wp.array(dtype=wp.float32),  # coord_num
            wp.array(dtype=wp.float32),  # r4r2
            wp.array4d(dtype=wp.float32),  # c6_reference
            wp.array4d(dtype=wp.float32),  # coord_num_ref
            wp.float32,  # k3
            wp.float32,  # a1
            wp.float32,  # a2
            wp.float32,  # s6
            wp.float32,  # s8
            wp.float32,  # s5_smoothing_on
            wp.float32,  # s5_smoothing_off
            wp.float32,  # inv_w
            wp.bool,  # periodic
            wp.array(dtype=wp.int32),  # batch_idx
            wp.bool,  # compute_virial
            wp.array(dtype=wp.float32),  # dE_dCN (output)
            wp.array(dtype=wp.vec3f),  # forces (output)
            wp.array(dtype=wp.float32),  # energy (output)
            wp.array(dtype=wp.mat33f),  # virial (output)
        ],
    )
    _cn_forces_contrib_kernel_nl_overload[t] = wp.overload(
        _cn_forces_contrib_kernel_nl,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=wp.int32),  # numbers
            wp.array(dtype=wp.int32),  # idx_j
            wp.array(dtype=wp.int32),  # neighbor_ptr
            wp.array(dtype=v),  # cartesian_shifts
            wp.array(dtype=wp.float32),  # covalent_radii
            wp.array(dtype=wp.float32),  # dE_dCN
            wp.float32,  # k1
            wp.bool,  # periodic
            wp.array(dtype=wp.int32),  # batch_idx
            wp.bool,  # compute_virial
            wp.array(dtype=wp.vec3f),  # forces (in/out)
            wp.array(dtype=wp.mat33f),  # virial (in/out)
        ],
    )
# ==============================================================================
# PyTorch Wrapper
# ==============================================================================
@torch.library.custom_op(
"nvalchemiops::dftd3_nm",
mutates_args=("energy", "forces", "coord_num", "virial"),
)
def _dftd3_nm_op(
positions: torch.Tensor,
numbers: torch.Tensor,
neighbor_matrix: torch.Tensor,
covalent_radii: torch.Tensor,
r4r2: torch.Tensor,
c6_reference: torch.Tensor,
coord_num_ref: torch.Tensor,
a1: float,
a2: float,
s8: float,
energy: torch.Tensor,
forces: torch.Tensor,
coord_num: torch.Tensor,
virial: torch.Tensor,
k1: float = 16.0,
k3: float = -4.0,
s6: float = 1.0,
s5_smoothing_on: float = 1e10,
s5_smoothing_off: float = 1e10,
fill_value: int | None = None,
batch_idx: torch.Tensor | None = None,
cell: torch.Tensor | None = None,
neighbor_matrix_shifts: torch.Tensor | None = None,
compute_virial: bool = False,
device: str | None = None,
) -> None:
"""Internal custom op for DFT-D3(BJ) dispersion energy and forces computation.
This is a low-level custom operator that performs DFT-D3(BJ) dispersion
calculations using Warp kernels. Output tensors must be pre-allocated by
the caller and are modified in-place. For most use cases, prefer the
higher-level :func:`dftd3` wrapper function instead of calling
this method directly.
This function is torch.compile compatible.
Parameters
----------
positions : torch.Tensor, shape (num_atoms, 3)
Atomic coordinates as float32 or float64, in consistent distance units
(conventionally Bohr)
numbers : torch.Tensor, shape (num_atoms), dtype=int32
Atomic numbers
neighbor_matrix : torch.Tensor, shape (num_atoms, max_neighbors), dtype=int32
Neighbor indices. See module docstring for format details.
Padding entries have values >= fill_value.
covalent_radii : torch.Tensor, shape (max_Z+1), dtype=float32
Covalent radii indexed by atomic number, in same units as positions
r4r2 : torch.Tensor, shape (max_Z+1), dtype=float32
<r⁴>/<r²> expectation values for C8 computation (dimensionless)
c6_reference : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32
C6 reference values in energy x distance^6 units
coord_num_ref : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32
CN reference grid (dimensionless)
a1 : float
Becke-Johnson damping parameter 1 (functional-dependent, dimensionless)
a2 : float
Becke-Johnson damping parameter 2 (functional-dependent), in same units as positions
s8 : float
C8 term scaling factor (functional-dependent, dimensionless)
energy : torch.Tensor, shape (num_systems,), dtype=float32
OUTPUT: Total dispersion energy. Must be pre-allocated. Units are energy
(Hartree when using standard D3 parameters).
forces : torch.Tensor, shape (num_atoms, 3), dtype=float32
OUTPUT: Atomic forces. Must be pre-allocated. Units are energy/distance
(Hartree/Bohr when using standard D3 parameters).
coord_num : torch.Tensor, shape (num_atoms,), dtype=float32
OUTPUT: Coordination numbers (dimensionless). Must be pre-allocated.
virial : torch.Tensor, shape (num_systems, 3, 3), dtype=float32
OUTPUT: Virial tensor. Must be pre-allocated. Units are energy
(Hartree when using standard D3 parameters).
k1 : float, optional
CN counting function steepness parameter, in inverse distance units
(typically 16.0 1/Bohr for atomic units)
k3 : float, optional
CN interpolation Gaussian width parameter (typically -4.0, dimensionless)
s6 : float, optional
C6 term scaling factor (typically 1.0, dimensionless)
s5_smoothing_on : float, optional
Distance where S5 switching begins, in same units as positions. Default: 1e10
s5_smoothing_off : float, optional
Distance where S5 switching completes, in same units as positions. Default: 1e10
fill_value : int | None, optional
Value indicating padding in neighbor_matrix. If None, defaults to num_atoms.
batch_idx : torch.Tensor, shape (num_atoms,), dtype=int32, optional
Batch indices. If None, all atoms are in a single system (batch 0).
cell : torch.Tensor, shape (num_systems, 3, 3), dtype=float32 or float64, optional
Unit cell lattice vectors for PBC, in same dtype and units as positions.
neighbor_matrix_shifts : torch.Tensor, shape (num_atoms, max_neighbors, 3), dtype=int32, optional
Integer unit cell shifts for PBC.
device : str, optional
Warp device string (e.g., 'cuda:0', 'cpu'). If None, inferred from positions.
Returns
-------
None
Modifies input tensors in-place: energy, forces, coord_num, virial (if compute_virial=True)
Notes
-----
- All input tensors should use consistent units. Standard D3 parameters use
atomic units (Bohr for distances, Hartree for energy).
- Float32 or float64 precision for positions and cell; outputs always float32
- Padding atoms indicated by numbers[i] == 0
- **Two-body only**: Computes pairwise C6 and C8 dispersion terms; three-body
Axilrod-Teller-Muto (ATM/C9) terms are not included
- Bulk stress tensor can be obtained by dividing virial by system volume.
See Also
--------
:func:`dftd3` : Higher-level wrapper that handles allocation
:class:`D3Parameters` : Dataclass for organizing DFT-D3 reference parameters
"""
# Ensure all parameters are on correct device/dtype
covalent_radii = covalent_radii.to(device=positions.device, dtype=torch.float32)
r4r2 = r4r2.to(device=positions.device, dtype=torch.float32)
c6_reference = c6_reference.to(device=positions.device, dtype=torch.float32)
coord_num_ref = coord_num_ref.to(device=positions.device, dtype=torch.float32)
# Get shapes
num_atoms = positions.size(0)
max_neighbors = neighbor_matrix.size(1) if num_atoms > 0 else 0
# Set fill_value if not provided
if fill_value is None:
fill_value = num_atoms
# Handle empty case
if num_atoms == 0:
return
# Infer device from positions if not provided
if device is None:
device = str(positions.device)
# Zero output tensors
energy.zero_()
forces.zero_()
coord_num.zero_()
virial.zero_()
# Detect dtype and set appropriate Warp vector types
if positions.dtype == torch.float64:
vec_dtype = wp.vec3d
mat_dtype = wp.mat33d
wp_dtype = wp.float64
else:
vec_dtype = wp.vec3f
mat_dtype = wp.mat33f
wp_dtype = wp.float32
# Create batch indices if not provided (single system)
if batch_idx is None:
batch_idx = torch.zeros(num_atoms, dtype=torch.int32, device=positions.device)
# Convert PyTorch tensors to Warp arrays (detach positions)
positions_wp = wp.from_torch(positions.detach(), dtype=vec_dtype, return_ctype=True)
numbers_wp = wp.from_torch(numbers, dtype=wp.int32, return_ctype=True)
neighbor_matrix_wp = wp.from_torch(
neighbor_matrix, dtype=wp.int32, return_ctype=True
)
batch_idx_wp = wp.from_torch(batch_idx, dtype=wp.int32, return_ctype=True)
# Convert parameter tensors to Warp arrays (ensure float32)
covalent_radii_wp = wp.from_torch(
covalent_radii.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
r4r2_wp = wp.from_torch(
r4r2.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
c6_reference_wp = wp.from_torch(
c6_reference.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
coord_num_ref_wp = wp.from_torch(
coord_num_ref.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
# Precompute inv_w for S5 switching
if s5_smoothing_off > s5_smoothing_on:
inv_w = 1.0 / (s5_smoothing_off - s5_smoothing_on)
else:
inv_w = 0.0
# Pass 0: Handle PBC - determine if periodic and compute cartesian shifts
if cell is not None and neighbor_matrix_shifts is not None:
# Case 1: Cell and neighbor_matrix_shifts provided - compute cartesian shifts
periodic = True
# Detach and convert cell
cell_wp = wp.from_torch(
cell.detach().to(dtype=positions.dtype, device=positions.device),
dtype=mat_dtype,
return_ctype=True,
)
# Convert unit shifts to vec3i format
unit_shifts_wp = wp.from_torch(
neighbor_matrix_shifts.to(dtype=torch.int32, device=positions.device),
dtype=wp.vec3i,
return_ctype=True,
)
# Create output array for cartesian shifts [num_atoms, max_neighbors, 3]
cartesian_shifts = torch.empty(
(num_atoms, max_neighbors, 3),
dtype=positions.dtype,
device=positions.device,
)
cartesian_shifts_wp = wp.from_torch(
cartesian_shifts, dtype=vec_dtype, return_ctype=True
)
wp.launch(
kernel=_compute_cartesian_shifts_overload[wp_dtype],
dim=(num_atoms, max_neighbors),
inputs=[
cell_wp,
unit_shifts_wp,
neighbor_matrix_wp,
batch_idx_wp,
wp.int32(fill_value),
],
outputs=[cartesian_shifts_wp],
device=device,
)
else:
# Case 2: No PBC - create zero shifts array (not used but need correct shape for kernel)
periodic = False
cartesian_shifts = torch.zeros(
(num_atoms, max_neighbors, 3),
dtype=positions.dtype,
device=positions.device,
)
cartesian_shifts_wp = wp.from_torch(
cartesian_shifts, dtype=vec_dtype, return_ctype=True
)
# Convert pre-allocated output arrays to Warp
coord_num_wp = wp.from_torch(coord_num, dtype=wp.float32, return_ctype=True)
forces_wp = wp.from_torch(forces, dtype=wp.vec3f, return_ctype=True)
energy_wp = wp.from_torch(energy, dtype=wp.float32, return_ctype=True)
virial_wp = wp.from_torch(virial, dtype=wp.mat33f, return_ctype=True)
# Allocate dE_dCN array for chain rule computation (internal temporary)
dE_dCN = torch.zeros( # NOSONAR (S125) "math formula"
num_atoms, dtype=torch.float32, device=positions.device
)
dE_dCN_wp = wp.from_torch( # NOSONAR (S125) "math formula"
dE_dCN, dtype=wp.float32, return_ctype=True
)
# Pass 1: Compute coordination numbers
wp.launch(
kernel=_cn_kernel_nm_overload[wp_dtype],
dim=num_atoms,
inputs=[
positions_wp,
numbers_wp,
neighbor_matrix_wp,
cartesian_shifts_wp,
covalent_radii_wp,
wp.float32(k1),
wp.int32(fill_value),
periodic,
],
outputs=[coord_num_wp],
device=device,
)
# Pass 2: Compute direct forces, energy, and accumulate dE/dCN
wp.launch(
kernel=_direct_forces_and_dE_dCN_kernel_nm_overload[wp_dtype],
dim=num_atoms,
inputs=[
positions_wp,
numbers_wp,
neighbor_matrix_wp,
cartesian_shifts_wp,
coord_num_wp,
r4r2_wp,
c6_reference_wp,
coord_num_ref_wp,
wp.float32(k3),
wp.float32(a1),
wp.float32(a2),
wp.float32(s6),
wp.float32(s8),
wp.float32(s5_smoothing_on),
wp.float32(s5_smoothing_off),
wp.float32(inv_w),
wp.int32(fill_value),
periodic,
batch_idx_wp,
wp.bool(compute_virial),
],
outputs=[dE_dCN_wp, forces_wp, energy_wp, virial_wp],
device=device,
)
# Pass 3: Add CN-dependent force contribution
wp.launch(
kernel=_cn_forces_contrib_kernel_nm_overload[wp_dtype],
dim=num_atoms,
inputs=[
positions_wp,
numbers_wp,
neighbor_matrix_wp,
cartesian_shifts_wp,
covalent_radii_wp,
dE_dCN_wp,
wp.float32(k1),
wp.int32(fill_value),
periodic,
batch_idx_wp,
wp.bool(compute_virial),
],
outputs=[forces_wp, virial_wp],
device=device,
)
@torch.library.custom_op(
"nvalchemiops::dftd3_nl",
mutates_args=("energy", "forces", "coord_num", "virial"),
)
def _dftd3_nl_op(
positions: torch.Tensor,
numbers: torch.Tensor,
idx_j: torch.Tensor,
neighbor_ptr: torch.Tensor,
covalent_radii: torch.Tensor,
r4r2: torch.Tensor,
c6_reference: torch.Tensor,
coord_num_ref: torch.Tensor,
a1: float,
a2: float,
s8: float,
energy: torch.Tensor,
forces: torch.Tensor,
coord_num: torch.Tensor,
virial: torch.Tensor,
k1: float = 16.0,
k3: float = -4.0,
s6: float = 1.0,
s5_smoothing_on: float = 1e10,
s5_smoothing_off: float = 1e10,
batch_idx: torch.Tensor | None = None,
cell: torch.Tensor | None = None,
unit_shifts: torch.Tensor | None = None,
compute_virial: bool = False,
device: str | None = None,
) -> None:
"""Internal custom op for DFT-D3(BJ) using CSR neighbor list format.
This is a low-level custom operator that performs DFT-D3(BJ) dispersion
calculations using CSR (Compressed Sparse Row) neighbor list format with
idx_j (destination indices) and neighbor_ptr (row pointers). Output tensors
must be pre-allocated by the caller and are modified in-place. For most use
cases, prefer the higher-level :func:`dftd3` wrapper
function instead of calling this method directly.
This function is torch.compile compatible.
Parameters
----------
positions : torch.Tensor, shape (num_atoms, 3)
Atomic coordinates as float32 or float64
numbers : torch.Tensor, shape (num_atoms), dtype=int32
Atomic numbers
idx_j : torch.Tensor, shape (num_edges,), dtype=int32
Destination atom indices (flattened neighbor list in CSR format)
neighbor_ptr : torch.Tensor, shape (num_atoms+1,), dtype=int32
CSR row pointers where neighbor_ptr[i]:neighbor_ptr[i+1] gives neighbors of atom i
covalent_radii : torch.Tensor, shape (max_Z+1), dtype=float32
Covalent radii indexed by atomic number
r4r2 : torch.Tensor, shape (max_Z+1), dtype=float32
<r⁴>/<r²> expectation values
c6_reference : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32
C6 reference values
coord_num_ref : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32
CN reference grid
a1 : float
Becke-Johnson damping parameter 1
a2 : float
Becke-Johnson damping parameter 2
s8 : float
C8 term scaling factor
energy : torch.Tensor, shape (num_systems,), dtype=float32
OUTPUT: Total dispersion energy
forces : torch.Tensor, shape (num_atoms, 3), dtype=float32
OUTPUT: Atomic forces
coord_num : torch.Tensor, shape (num_atoms,), dtype=float32
OUTPUT: Coordination numbers
virial : torch.Tensor, shape (num_systems, 3, 3), dtype=float32
OUTPUT: Virial tensor. Must be pre-allocated. Units are energy
(Hartree when using standard D3 parameters).
k1 : float, optional
CN counting function steepness parameter
k3 : float, optional
CN interpolation Gaussian width parameter
s6 : float, optional
C6 term scaling factor
s5_smoothing_on : float, optional
Distance where S5 switching begins
s5_smoothing_off : float, optional
Distance where S5 switching completes
batch_idx : torch.Tensor, shape (num_atoms,), dtype=int32, optional
Batch indices
cell : torch.Tensor, shape (num_systems, 3, 3), dtype=float32 or float64, optional
Unit cell lattice vectors for PBC, in same dtype and units as positions.
unit_shifts : torch.Tensor, shape (num_edges, 3), dtype=int32, optional
Integer unit cell shifts for PBC
device : str, optional
Warp device string
Returns
-------
None
Modifies input tensors in-place: energy, forces, coord_num, virial (if compute_virial=True)
Notes
-----
- All input tensors should use consistent units. Standard D3 parameters use
atomic units (Bohr for distances, Hartree for energy).
- Float32 or float64 precision for positions and cell; outputs always float32
- Padding atoms indicated by numbers[i] == 0
- **Two-body only**: Computes pairwise C6 and C8 dispersion terms; three-body
Axilrod-Teller-Muto (ATM/C9) terms are not included
- Bulk stress tensor can be obtained by dividing virial by system volume.
See Also
--------
:func:`dftd3` : Higher-level wrapper that handles allocation
:class:`D3Parameters` : Dataclass for organizing DFT-D3 reference parameters
"""
# Ensure all parameters are on correct device/dtype
covalent_radii = covalent_radii.to(device=positions.device, dtype=torch.float32)
r4r2 = r4r2.to(device=positions.device, dtype=torch.float32)
c6_reference = c6_reference.to(device=positions.device, dtype=torch.float32)
coord_num_ref = coord_num_ref.to(device=positions.device, dtype=torch.float32)
# Get shapes
num_atoms = positions.size(0)
num_edges = idx_j.size(0)
# Handle empty case
if num_atoms == 0 or num_edges == 0:
return
# Infer device from positions if not provided
if device is None:
device = str(positions.device)
# Zero output tensors
energy.zero_()
forces.zero_()
coord_num.zero_()
virial.zero_()
# Detect dtype and set appropriate Warp vector types
if positions.dtype == torch.float64:
vec_dtype = wp.vec3d
wp_dtype = wp.float64
else:
vec_dtype = wp.vec3f
wp_dtype = wp.float32
# Create batch indices if not provided (single system)
if batch_idx is None:
batch_idx = torch.zeros(num_atoms, dtype=torch.int32, device=positions.device)
# Convert PyTorch tensors to Warp arrays
positions_wp = wp.from_torch(positions.detach(), dtype=vec_dtype, return_ctype=True)
numbers_wp = wp.from_torch(numbers, dtype=wp.int32, return_ctype=True)
idx_j_wp = wp.from_torch(idx_j, dtype=wp.int32, return_ctype=True)
neighbor_ptr_wp = wp.from_torch(neighbor_ptr, dtype=wp.int32, return_ctype=True)
batch_idx_wp = wp.from_torch(batch_idx, dtype=wp.int32, return_ctype=True)
# Convert parameter tensors to Warp arrays
covalent_radii_wp = wp.from_torch(
covalent_radii.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
r4r2_wp = wp.from_torch(
r4r2.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
c6_reference_wp = wp.from_torch(
c6_reference.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
coord_num_ref_wp = wp.from_torch(
coord_num_ref.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
# Precompute inv_w for S5 switching
if s5_smoothing_off > s5_smoothing_on:
inv_w = 1.0 / (s5_smoothing_off - s5_smoothing_on)
else:
inv_w = 0.0
# Pass 0: Handle PBC - compute cartesian shifts if needed
if unit_shifts is not None and cell is not None:
# PBC case - convert unit shifts to Cartesian using Warp kernel
periodic = True
# Convert cell to Warp
cell_wp = wp.from_torch(
cell.detach().to(dtype=positions.dtype, device=positions.device),
dtype=wp.mat33d if positions.dtype == torch.float64 else wp.mat33f,
return_ctype=True,
)
# Convert unit shifts to vec3i format
unit_shifts_wp = wp.from_torch(
unit_shifts.to(dtype=torch.int32, device=positions.device),
dtype=wp.vec3i,
return_ctype=True,
)
# Convert neighbor_ptr to Warp
neighbor_ptr_wp_shifts = wp.from_torch(
neighbor_ptr, dtype=wp.int32, return_ctype=True
)
# Create batch_idx if not provided (single system)
if batch_idx is None:
batch_idx_shifts = torch.zeros(
num_atoms, dtype=torch.int32, device=positions.device
)
else:
batch_idx_shifts = batch_idx
batch_idx_shifts_wp = wp.from_torch(
batch_idx_shifts, dtype=wp.int32, return_ctype=True
)
# Allocate output for Cartesian shifts
cartesian_shifts = torch.empty(
(num_edges, 3),
dtype=positions.dtype,
device=positions.device,
)
cartesian_shifts_wp = wp.from_torch(
cartesian_shifts, dtype=vec_dtype, return_ctype=True
)
# Launch kernel to compute Cartesian shifts
wp.launch(
kernel=_compute_cartesian_shifts_nl_overload[wp_dtype],
dim=num_atoms,
inputs=[
cell_wp,
unit_shifts_wp,
neighbor_ptr_wp_shifts,
batch_idx_shifts_wp,
],
outputs=[cartesian_shifts_wp],
device=device,
)
else:
# Non-periodic case - create zero shifts
periodic = False
cartesian_shifts = torch.zeros(
(num_edges, 3),
dtype=positions.dtype,
device=positions.device,
)
# Convert cartesian shifts to Warp
cartesian_shifts_wp = wp.from_torch(
cartesian_shifts, dtype=vec_dtype, return_ctype=True
)
# Convert pre-allocated output arrays to Warp
coord_num_wp = wp.from_torch(coord_num, dtype=wp.float32, return_ctype=True)
forces_wp = wp.from_torch(forces, dtype=wp.vec3f, return_ctype=True)
energy_wp = wp.from_torch(energy, dtype=wp.float32, return_ctype=True)
virial_wp = wp.from_torch(virial, dtype=wp.mat33f, return_ctype=True)
# Allocate dE_dCN array for chain rule computation
dE_dCN = torch.zeros( # NOSONAR (S125) "math formula"
num_atoms, dtype=torch.float32, device=positions.device
)
dE_dCN_wp = wp.from_torch( # NOSONAR (S125) "math formula"
dE_dCN, dtype=wp.float32, return_ctype=True
)
# Pass 1: Compute coordination numbers
wp.launch(
kernel=_cn_kernel_nl_overload[wp_dtype],
dim=num_atoms,
inputs=[
positions_wp,
numbers_wp,
idx_j_wp,
neighbor_ptr_wp,
cartesian_shifts_wp,
covalent_radii_wp,
wp.float32(k1),
periodic,
],
outputs=[coord_num_wp],
device=device,
)
# Pass 2: Compute direct forces, energy, and accumulate dE/dCN
wp.launch(
kernel=_direct_forces_and_dE_dCN_kernel_nl_overload[wp_dtype],
dim=num_atoms,
inputs=[
positions_wp,
numbers_wp,
idx_j_wp,
neighbor_ptr_wp,
cartesian_shifts_wp,
coord_num_wp,
r4r2_wp,
c6_reference_wp,
coord_num_ref_wp,
wp.float32(k3),
wp.float32(a1),
wp.float32(a2),
wp.float32(s6),
wp.float32(s8),
wp.float32(s5_smoothing_on),
wp.float32(s5_smoothing_off),
wp.float32(inv_w),
periodic,
batch_idx_wp,
wp.bool(compute_virial),
],
outputs=[dE_dCN_wp, forces_wp, energy_wp, virial_wp],
device=device,
)
# Pass 3: Add CN-dependent force contribution
wp.launch(
kernel=_cn_forces_contrib_kernel_nl_overload[wp_dtype],
dim=num_atoms,
inputs=[
positions_wp,
numbers_wp,
idx_j_wp,
neighbor_ptr_wp,
cartesian_shifts_wp,
covalent_radii_wp,
dE_dCN_wp,
wp.float32(k1),
periodic,
batch_idx_wp,
wp.bool(compute_virial),
],
outputs=[forces_wp, virial_wp],
device=device,
)
# Public entry point for PyTorch integration.
def dftd3(
    positions: torch.Tensor,
    numbers: torch.Tensor,
    a1: float,
    a2: float,
    s8: float,
    k1: float = 16.0,
    k3: float = -4.0,
    s6: float = 1.0,
    s5_smoothing_on: float = 1e10,
    s5_smoothing_off: float = 1e10,
    fill_value: int | None = None,
    d3_params: D3Parameters | dict[str, torch.Tensor] | None = None,
    covalent_radii: torch.Tensor | None = None,
    r4r2: torch.Tensor | None = None,
    c6_reference: torch.Tensor | None = None,
    coord_num_ref: torch.Tensor | None = None,
    batch_idx: torch.Tensor | None = None,
    cell: torch.Tensor | None = None,
    neighbor_matrix: torch.Tensor | None = None,
    neighbor_matrix_shifts: torch.Tensor | None = None,
    neighbor_list: torch.Tensor | None = None,
    neighbor_ptr: torch.Tensor | None = None,
    unit_shifts: torch.Tensor | None = None,
    compute_virial: bool = False,
    num_systems: int | None = None,
    device: str | None = None,
) -> (
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
):
    """
    Compute DFT-D3(BJ) dispersion energy and forces using Warp
    with optional periodic boundary condition support and smoothing function.

    **DFT-D3 parameters must be explicitly provided** using one of three methods:

    1. **D3Parameters dataclass**: Supply a :class:`D3Parameters` instance (recommended).
       Individual parameters can override dataclass values if both are provided.
    2. **Explicit parameters**: Supply all four parameters individually:
       ``covalent_radii``, ``r4r2``, ``c6_reference``, and ``coord_num_ref``.
    3. **Dictionary**: Provide a ``d3_params`` dictionary with keys:
       ``"rcov"``, ``"r4r2"``, ``"c6ab"``, and ``"cn_ref"``.
       Individual parameters can override dictionary values if both are provided.

    See ``examples/interactions/utils.py`` for parameter generation utilities.

    This wrapper can be launched by either supplying a neighbor matrix or a
    neighbor list, both of which can be generated by the
    :func:`nvalchemiops.neighborlist.neighbor_list` function where the latter
    can be returned by setting the `return_neighbor_list` parameter to True.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic coordinates [num_atoms, 3] as float32 or float64, in consistent distance
        units (conventionally Bohr when using standard D3 parameters)
    numbers : torch.Tensor
        Atomic numbers [num_atoms] as int32
    a1 : float
        Becke-Johnson damping parameter 1 (functional-dependent, dimensionless)
    a2 : float
        Becke-Johnson damping parameter 2 (functional-dependent), in same units as positions
    s8 : float
        C8 term scaling factor (functional-dependent, dimensionless)
    k1 : float, optional
        CN counting function steepness parameter, in inverse distance units
        (typically 16.0 1/Bohr for atomic units)
    k3 : float, optional
        CN interpolation Gaussian width parameter (typically -4.0, dimensionless)
    s6 : float, optional
        C6 term scaling factor (typically 1.0, dimensionless)
    s5_smoothing_on : float, optional
        Distance where S5 switching begins, in same units as positions. Set greater or
        equal to s5_smoothing_off to disable smoothing. Default: 1e10
    s5_smoothing_off : float, optional
        Distance where S5 switching completes, in same units as positions.
        Default: 1e10 (effectively no cutoff)
    fill_value : int | None, optional
        Value indicating padding in neighbor_matrix. If None, defaults to num_atoms.
        Entries with neighbor_matrix[i, k] >= fill_value are treated as padding. Default: None
    d3_params : D3Parameters | dict[str, torch.Tensor] | None, optional
        DFT-D3 parameters provided as either:

        - :class:`D3Parameters` dataclass instance (recommended)
        - Dictionary with keys: "rcov", "r4r2", "c6ab", "cn_ref"

        Individual parameters below can override values from d3_params.
    covalent_radii : torch.Tensor | None, optional
        Covalent radii [max_Z+1] as float32, indexed by atomic number, in same units
        as positions. If provided, overrides the value in d3_params.
    r4r2 : torch.Tensor | None, optional
        <r4>/<r2> expectation values [max_Z+1] as float32 for C8 computation (dimensionless).
        If provided, overrides the value in d3_params.
    c6_reference : torch.Tensor | None, optional
        C6 reference values [max_Z+1, max_Z+1, 5, 5] as float32 in energy × distance^6 units.
        If provided, overrides the value in d3_params.
    coord_num_ref : torch.Tensor | None, optional
        CN reference grid [max_Z+1, max_Z+1, 5, 5] as float32 (dimensionless).
        If provided, overrides the value in d3_params.
    batch_idx : torch.Tensor or None, optional
        Batch indices [num_atoms] as int32. If None, all atoms are assumed
        to be in a single system (batch 0). For batched calculations, atoms with
        the same batch index belong to the same system. Default: None
    cell : torch.Tensor or None, optional, as float32 or float64
        Unit cell lattice vectors [num_systems, 3, 3] for PBC, in same dtype and units as positions.
        Convention: cell[s, i, :] is i-th lattice vector for system s.
        If None, non-periodic calculation. Default: None
    neighbor_matrix : torch.Tensor | None, optional
        Neighbor indices [num_atoms, max_neighbors] as int32. See module docstring for
        details on the format. Padding entries have values >= fill_value.
        Mutually exclusive with neighbor_list. Default: None
    neighbor_matrix_shifts : torch.Tensor or None, optional
        Integer unit cell shifts [num_atoms, max_neighbors, 3] as int32 for PBC with
        neighbor_matrix format. If None, non-periodic calculation. If provided along
        with cell, Cartesian shifts are computed. Mutually exclusive with unit_shifts.
        Default: None
    neighbor_list : torch.Tensor or None, optional
        Neighbor pairs [2, num_pairs] as int32 in COO format, where row 0 contains
        source atom indices and row 1 contains target atom indices. Alternative to
        neighbor_matrix for sparse neighbor representations. Mutually exclusive with
        neighbor_matrix. Must be used together with `neighbor_ptr` (both are returned
        by the neighbor list API when `return_neighbor_list=True`).
        Default: None
    neighbor_ptr : torch.Tensor or None, optional
        CSR row pointers [num_atoms+1] as int32. Required when using `neighbor_list`.
        Indicates that `neighbor_list[1, :]` contains destination atoms in CSR
        format where
        `neighbor_ptr[i]:neighbor_ptr[i+1]` gives the range of neighbors for atom i.
        Returned by the neighbor list API when `return_neighbor_list=True`.
        Default: None
    unit_shifts : torch.Tensor or None, optional
        Integer unit cell shifts [num_pairs, 3] as int32 for PBC with neighbor_list
        format. If None, non-periodic calculation. If provided along with cell,
        Cartesian shifts are computed. Mutually exclusive with neighbor_matrix_shifts.
        Default: None
    compute_virial : bool, optional
        If True, allocate, compute, and return the virial tensor. Default: False
    num_systems : int, optional
        Number of systems in batch. If not provided, it is 1 when batch_idx is
        None or empty, otherwise inferred from cell, or from batch_idx
        (introduces CUDA synchronization overhead). An explicitly provided
        value is always honored, including for empty inputs. Default: None
    device : str or None, optional
        Warp device string (e.g., 'cuda:0', 'cpu'). If None, inferred from
        positions tensor. Default: None

    Returns
    -------
    energy : torch.Tensor
        Total dispersion energy [num_systems] as float32. Units are energy
        (Hartree when using standard D3 parameters).
    forces : torch.Tensor
        Atomic forces [num_atoms, 3] as float32. Units are energy/distance
        (Hartree/Bohr when using standard D3 parameters).
    coord_num : torch.Tensor
        Coordination numbers [num_atoms] as float32 (dimensionless)
    virial : torch.Tensor, optional
        Virial tensor [num_systems, 3, 3] as float32.
        Units are energy (Hartree when using standard D3 parameters). Only returned
        if compute_virial=True.

    Raises
    ------
    ValueError
        If both or neither of neighbor_matrix / neighbor_list are given, if the
        shift tensor does not match the neighbor format, if neighbor_ptr is
        missing for the list format, if a1/a2/s8 is None, or if compute_virial
        is True without cell and the matching shift tensor.
    RuntimeError
        If the D3 reference parameters are neither fully given individually
        nor available through d3_params.
    KeyError
        If d3_params lacks one of "rcov", "r4r2", "c6ab", "cn_ref" that is not
        overridden by an individual parameter.

    Notes
    -----
    - **Unit consistency**: All inputs must use consistent units. Standard D3 parameters
      from the Grimme group use atomic units (Bohr for distances, Hartree for energy),
      so using atomic units throughout is recommended and conventional.
    - Float32 or float64 precision for positions and cell; outputs always float32
    - **Neighbor formats**: Supports both neighbor_matrix (dense) and neighbor_list (sparse COO)
      formats. Choose neighbor_list for sparse systems or when memory efficiency is important.
    - Padding atoms indicated by numbers[i] == 0
    - Requires symmetric neighbor representation (each pair appears twice)
    - **Two-body only**: Computes pairwise C6 and C8 dispersion terms; three-body
      Axilrod-Teller-Muto (ATM/C9) terms are not included
    - Virial computation requires periodic boundary conditions.
    - Bulk stress tensor can be obtained by dividing virial by system volume.
    - For empty inputs (num_atoms == 0) no kernel is launched and zero-filled
      outputs of the documented shapes are returned.

    **Neighbor Format Selection**:

    - Use neighbor_matrix for dense systems or when max_neighbors is small
    - Use neighbor_list for sparse systems, large cutoffs, or memory-constrained scenarios
    - Both formats produce identical results and support PBC

    **PBC Handling**:

    - Matrix format: Provide cell and neighbor_matrix_shifts
    - List format: Provide cell and unit_shifts
    - Non-periodic: Omit both cell and shift parameters

    See Also
    --------
    :class:`D3Parameters` : Dataclass for organizing DFT-D3 reference parameters
    :func:`_dftd3_nm_op` : Internal custom operator for neighbor matrix format
    :func:`_dftd3_nl_op` : Internal custom operator for neighbor list format
    :func:`_compute_cartesian_shifts` : Pass 0 - Converts unit cell shifts to Cartesian
    :func:`_cn_kernel_nm` : Pass 1 - Computes coordination numbers (neighbor matrix)
    :func:`_cn_kernel_nl` : Pass 1 - Computes coordination numbers (neighbor list)
    :func:`_direct_forces_and_dE_dCN_kernel_nm` : Pass 2 - neighbor matrix format
    :func:`_direct_forces_and_dE_dCN_kernel_nl` : Pass 2 - neighbor list format
    :func:`_cn_forces_contrib_kernel_nm` : Pass 3 - neighbor matrix format
    :func:`_cn_forces_contrib_kernel_nl` : Pass 3 - neighbor list format
    """
    # ---- Validate neighbor format inputs ---------------------------------
    matrix_provided = neighbor_matrix is not None
    list_provided = neighbor_list is not None

    if matrix_provided and list_provided:
        raise ValueError(
            "Cannot provide both neighbor_matrix and neighbor_list. "
            "Please provide only one neighbor representation format."
        )
    if not matrix_provided and not list_provided:
        raise ValueError("Must provide either neighbor_matrix or neighbor_list.")

    # Shift tensors are format-specific; reject the mismatched combination.
    if matrix_provided and unit_shifts is not None:
        raise ValueError(
            "unit_shifts is for neighbor_list format. "
            "Use neighbor_matrix_shifts for neighbor_matrix format."
        )
    if list_provided and neighbor_matrix_shifts is not None:
        raise ValueError(
            "neighbor_matrix_shifts is for neighbor_matrix format. "
            "Use unit_shifts for neighbor_list format."
        )

    # The list path consumes CSR row pointers alongside the COO pairs.
    if list_provided and neighbor_ptr is None:
        raise ValueError(
            "neighbor_ptr must be provided when using neighbor_list format. "
            "Obtain it from the neighbor list API by setting return_neighbor_list=True."
        )

    # ---- Validate functional parameters ----------------------------------
    if a1 is None or a2 is None or s8 is None:
        raise ValueError(
            "Functional parameters a1, a2, and s8 must be provided. "
            "These are functional-dependent parameters required for DFT-D3(BJ) calculations."
        )

    # ---- Virial requires a periodic setup --------------------------------
    if compute_virial:
        if cell is None:
            raise ValueError(
                "Virial computation requires periodic boundary conditions. "
                "Please provide unit cell parameters (cell) and shifts "
                "(neighbor_matrix_shifts or unit_shifts) when compute_virial=True "
                "or when passing a virial tensor."
            )
        if matrix_provided and neighbor_matrix_shifts is None:
            raise ValueError(
                "Virial computation requires periodic boundary conditions. "
                "Please provide neighbor_matrix_shifts along with cell when using "
                "neighbor_matrix format and compute_virial=True or passing a virial tensor."
            )
        if list_provided and unit_shifts is None:
            raise ValueError(
                "Virial computation requires periodic boundary conditions. "
                "Please provide unit_shifts along with cell when using "
                "neighbor_list format and compute_virial=True or passing a virial tensor."
            )

    # ---- Resolve the D3 reference parameters -----------------------------
    all_explicit = all(
        param is not None
        for param in (covalent_radii, r4r2, c6_reference, coord_num_ref)
    )
    if not all_explicit:
        if d3_params is None:
            raise RuntimeError(
                "DFT-D3 parameters must be explicitly provided. "
                "Either supply all individual parameters (covalent_radii, r4r2, "
                "c6_reference, coord_num_ref), provide a D3Parameters instance, "
                "or provide a d3_params dictionary. See the function docstring for details."
            )
        # Convert D3Parameters to dictionary for consistent access
        if isinstance(d3_params, D3Parameters):
            d3_params = d3_params.__dict__
        # Explicit arguments win; missing keys deliberately raise KeyError.
        if covalent_radii is None:
            covalent_radii = d3_params["rcov"]
        if r4r2 is None:
            r4r2 = d3_params["r4r2"]
        if c6_reference is None:
            c6_reference = d3_params["c6ab"]
        if coord_num_ref is None:
            coord_num_ref = d3_params["cn_ref"]

    num_atoms = positions.size(0)
    out_device = positions.device

    # ---- Determine the number of systems ---------------------------------
    # An explicit num_systems is always honored (also for empty inputs).
    if num_systems is None:
        if batch_idx is None or batch_idx.numel() == 0:
            num_systems = 1
        elif cell is not None:
            num_systems = cell.size(0)
        else:
            # Reads the maximum back to the host (synchronizes on CUDA).
            num_systems = int(batch_idx.max().item()) + 1

    # ---- Allocate zero-initialized outputs -------------------------------
    energy = torch.zeros(num_systems, dtype=torch.float32, device=out_device)
    forces = torch.zeros((num_atoms, 3), dtype=torch.float32, device=out_device)
    coord_num = torch.zeros(num_atoms, dtype=torch.float32, device=out_device)
    # When the virial is not requested, pass a zero-length placeholder.
    virial_rows = num_systems if compute_virial else 0
    virial = torch.zeros((virial_rows, 3, 3), dtype=torch.float32, device=out_device)

    # ---- Dispatch (skipped for empty inputs: outputs are already zero) ----
    if num_atoms > 0:
        if matrix_provided:
            # Matrix format - call custom op
            _dftd3_nm_op(
                positions=positions,
                numbers=numbers,
                neighbor_matrix=neighbor_matrix,
                covalent_radii=covalent_radii,
                r4r2=r4r2,
                c6_reference=c6_reference,
                coord_num_ref=coord_num_ref,
                a1=a1,
                a2=a2,
                s8=s8,
                energy=energy,
                forces=forces,
                coord_num=coord_num,
                k1=k1,
                k3=k3,
                s6=s6,
                s5_smoothing_on=s5_smoothing_on,
                s5_smoothing_off=s5_smoothing_off,
                fill_value=fill_value,
                batch_idx=batch_idx,
                cell=cell,
                neighbor_matrix_shifts=neighbor_matrix_shifts,
                virial=virial,
                compute_virial=compute_virial,
                device=device,
            )
        else:
            # List format: row 1 of the COO pairs holds the destination
            # atoms, which together with neighbor_ptr forms the CSR layout.
            idx_j_csr = neighbor_list[1]
            _dftd3_nl_op(
                positions=positions,
                numbers=numbers,
                idx_j=idx_j_csr,
                neighbor_ptr=neighbor_ptr,
                covalent_radii=covalent_radii,
                r4r2=r4r2,
                c6_reference=c6_reference,
                coord_num_ref=coord_num_ref,
                a1=a1,
                a2=a2,
                s8=s8,
                energy=energy,
                forces=forces,
                coord_num=coord_num,
                k1=k1,
                k3=k3,
                s6=s6,
                s5_smoothing_on=s5_smoothing_on,
                s5_smoothing_off=s5_smoothing_off,
                batch_idx=batch_idx,
                cell=cell,
                unit_shifts=unit_shifts,
                virial=virial,
                compute_virial=compute_virial,
                device=device,
            )

    if compute_virial:
        return energy, forces, coord_num, virial
    return energy, forces, coord_num