# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DFT-D3(BJ) Dispersion Correction - Warp Kernel Implementation
This module implements the DFT-D3 dispersion correction with Becke-Johnson (BJ)
damping as Warp GPU/CPU kernels. The implementation provides efficient computation
of dispersion energies and forces using a multi-pass algorithm with support for
periodic boundary conditions, batched systems, and smooth cutoff functions.
For detailed theory, usage examples, and parameter setup, see the
:doc:`DFT-D3 User Guide </userguide/components/dispersion>`.
Multi-Pass Kernel Architecture
-------------------------------
The implementation uses four kernel passes to efficiently handle the chain rule
dependency in force calculations:
1. **Pass 0 (_compute_cartesian_shifts)**: [PBC only] Convert unit cell shifts
to Cartesian coordinates
2. **Pass 1 (_cn_kernel)**: Compute coordination numbers using geometric counting
function
3. **Pass 2 (_direct_forces_and_dE_dCN_kernel)**: Compute C6 interpolation,
dispersion energy, direct forces, and accumulate :math:`\\partial E/\\partial \\text{CN}`
4. **Pass 3 (_cn_forces_contrib_kernel)**: Add CN-dependent force contribution
using precomputed :math:`\\partial E/\\partial \\text{CN}` values
Warp Launchers (Framework-Agnostic)
------------------------------------
This module provides four framework-agnostic warp launcher functions that accept
warp arrays directly, with distinct signatures based on neighbor format and PBC support.
These are called by framework-specific wrappers (PyTorch, JAX) after converting
framework tensors to warp arrays:
**Neighbor Matrix Format:**
- ``dftd3_matrix`` - Non-periodic systems (no PBC parameters)
- ``dftd3_matrix_pbc`` - Periodic systems (requires cell and neighbor_matrix_shifts)
**Neighbor List (CSR) Format:**
- ``dftd3`` - Non-periodic systems (no PBC parameters)
- ``dftd3_pbc`` - Periodic systems (requires cell and unit_shifts)
.. code-block:: python
from nvalchemiops.interactions.dispersion._dftd3 import (
dftd3_matrix,
dftd3_matrix_pbc,
dftd3,
dftd3_pbc,
)
# Neighbor matrix format - non-periodic
dftd3_matrix(
positions=positions_wp, # warp array
numbers=numbers_wp,
neighbor_matrix=neighbor_matrix_wp,
covalent_radii=covalent_radii_wp,
r4r2=r4r2_wp,
c6_reference=c6_reference_wp,
coord_num_ref=coord_num_ref_wp,
a1=0.3981, a2=4.4211, s8=1.9889,
coord_num=coord_num_wp, # pre-allocated output
forces=forces_wp, # pre-allocated output
energy=energy_wp, # pre-allocated output
virial=virial_wp, # pre-allocated output (not computed for non-PBC)
vec_dtype=wp.vec3f,
)
# Neighbor matrix format - periodic (PBC)
dftd3_matrix_pbc(
positions=positions_wp,
numbers=numbers_wp,
neighbor_matrix=neighbor_matrix_wp,
cell=cell_wp, # REQUIRED for PBC
neighbor_matrix_shifts=shifts_wp, # REQUIRED for PBC
covalent_radii=covalent_radii_wp,
r4r2=r4r2_wp,
c6_reference=c6_reference_wp,
coord_num_ref=coord_num_ref_wp,
a1=0.3981, a2=4.4211, s8=1.9889,
coord_num=coord_num_wp,
forces=forces_wp,
energy=energy_wp,
virial=virial_wp,
vec_dtype=wp.vec3f,
compute_virial=True, # Optional: enable virial computation
)
# Neighbor list format (CSR) - non-periodic
dftd3(
positions=positions_wp,
numbers=numbers_wp,
idx_j=idx_j_wp,
neighbor_ptr=neighbor_ptr_wp,
covalent_radii=covalent_radii_wp,
r4r2=r4r2_wp,
c6_reference=c6_reference_wp,
coord_num_ref=coord_num_ref_wp,
a1=0.3981, a2=4.4211, s8=1.9889,
coord_num=coord_num_wp,
forces=forces_wp,
energy=energy_wp,
virial=virial_wp, # pre-allocated output (not computed for non-PBC)
vec_dtype=wp.vec3f,
)
# Neighbor list format (CSR) - periodic (PBC)
dftd3_pbc(
positions=positions_wp,
numbers=numbers_wp,
idx_j=idx_j_wp,
neighbor_ptr=neighbor_ptr_wp,
cell=cell_wp, # REQUIRED for PBC
unit_shifts=unit_shifts_wp, # REQUIRED for PBC
covalent_radii=covalent_radii_wp,
r4r2=r4r2_wp,
c6_reference=c6_reference_wp,
coord_num_ref=coord_num_ref_wp,
a1=0.3981, a2=4.4211, s8=1.9889,
coord_num=coord_num_wp,
forces=forces_wp,
energy=energy_wp,
virial=virial_wp,
vec_dtype=wp.vec3f,
compute_virial=True, # Optional: enable virial computation
)
PyTorch Interface
-----------------
For PyTorch integration, use the high-level wrapper in the torch namespace:
.. code-block:: python
from nvalchemiops.torch.interactions.dispersion import dftd3, D3Parameters
# Using neighbor matrix format
energy, forces, coord_num = dftd3(
positions, numbers,
neighbor_matrix=neighbor_matrix,
a1=0.3981, a2=4.4211, s8=1.9889,
d3_params=d3_params, # D3Parameters instance or dict
cell=cell, # Optional for PBC
neighbor_matrix_shifts=neighbor_matrix_shifts, # Optional for PBC
)
# Using neighbor list format (sparse COO)
energy, forces, coord_num = dftd3(
positions, numbers,
neighbor_list=neighbor_list, # shape (2, num_pairs)
a1=0.3981, a2=4.4211, s8=1.9889,
d3_params=d3_params,
cell=cell, # Optional for PBC
unit_shifts=unit_shifts, # Optional for PBC, shape (num_pairs, 3)
)
Data Structure Requirements
---------------------------
**Neighbor Formats**
The implementation supports two neighbor representation formats:
1. **Neighbor Matrix Format** (dense): `[num_atoms, max_neighbors]` where
`neighbor_matrix[i, k]` is the k-th neighbor of atom i. Padding entries use
values >= `fill_value` (typically `num_atoms`).
2. **Neighbor List Format** (sparse COO): `[2, num_pairs]` where row 0 contains
source atom indices and row 1 contains target atom indices. No padding needed.
Both formats can be generated by :func:`nvalchemiops.neighborlist.neighbor_list` using
the `return_neighbor_list` parameter.
**Parameter Arrays**
- `covalent_radii`: `[max_Z+1]` float32
- `r4r2`: `[max_Z+1]` float32
- `c6_reference`: `[max_Z+1, max_Z+1, 5, 5]` float32
- `coord_num_ref`: `[max_Z+1, max_Z+1, 5, 5]` float32
Index 0 reserved for padding; valid atomic numbers 1 to max_Z.
**Periodic Boundary Conditions**
- `cell`: `[num_systems, 3, 3]` lattice vectors (row format)
- For neighbor matrix: `neighbor_matrix_shifts`: `[num_atoms, max_neighbors, 3]` int32 unit cell shifts
- For neighbor list: `unit_shifts`: `[num_pairs, 3]` int32 unit cell shifts
Units
-----
Kernels are **unit-agnostic** but require consistency. Standard Grimme group
parameters use **atomic units (Bohr, Hartree)**, which is recommended:
- Positions, covalent radii, `a2`, cutoffs: Bohr
- Energy output: Hartree
- Forces output: Hartree/Bohr
- Parameter `k1`: 1/Bohr
Technical Notes
---------------
- Supports float32 and float64 positions and cell. Outputs are always float32.
- **Two-body only**: Axilrod-Teller-Muto (C9) three-body terms not included
See Also
--------
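:class:`nvalchemiops.torch.interactions.dispersion.D3Parameters` : Dataclass for parameter validation and management (PyTorch interface)
:func:`nvalchemiops.torch.interactions.dispersion.dftd3` : Main PyTorch interface function (not to be confused with this module's warp-level ``dftd3`` launcher)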
:doc:`/userguide/components/dispersion` : Complete user guide with theory and examples
"""
from __future__ import annotations
from typing import Any
import warp as wp
# Public API of this module. Besides the four warp launchers, the kernel
# overload dictionaries are exported (despite their leading underscore) so
# that framework bindings can launch the individual passes directly.
# NOTE(review): the overload dictionaries are not defined in the portion of
# the module shown here; presumably they are built further down — confirm.
__all__ = [
    # Warp launchers (framework-agnostic public API)
    "dftd3_matrix",
    "dftd3_matrix_pbc",
    "dftd3",
    "dftd3_pbc",
    # Kernel overload dictionaries (for framework bindings)
    # -- neighbor matrix format
    "_compute_cartesian_shifts_matrix_overload",
    "_cn_kernel_matrix_overload",
    "_direct_forces_and_dE_dCN_kernel_matrix_overload",
    "_cn_forces_contrib_kernel_matrix_overload",
    # -- neighbor list (CSR) format
    "_compute_cartesian_shifts_overload",
    "_cn_kernel_overload",
    "_direct_forces_and_dE_dCN_kernel_overload",
    "_cn_forces_contrib_kernel_overload",
]
# ==============================================================================
# Helper Functions
# ==============================================================================
@wp.func
def _s5_switch(
    r: wp.float32,
    r_on: wp.float32,
    r_off: wp.float32,
    inv_w: wp.float32,
) -> tuple[wp.float32, wp.float32]:
    """
    Quintic (S5) switching weight with C² continuity at both window edges.

    The weight is exactly 1 for ``r <= r_on``, exactly 0 for ``r >= r_off``,
    and follows ``1 - S5(t)`` inside the window. A degenerate window
    (``r_off <= r_on``) disables switching: the weight is then always 1.

    Parameters
    ----------
    r : float32
        Distance between atoms
    r_on : float32
        Distance where switching begins
    r_off : float32
        Distance where switching completes
    inv_w : float32
        Precomputed 1/(r_off - r_on), supplied by the caller

    Returns
    -------
    Sw : float32
        Switching function value Sw(r) ∈ [0, 1]
    dSw_dr : float32
        Derivative of the switching function with respect to r

    Notes
    -----
    .. math::

        S_w(r) = \\begin{cases}
        1 & \\text{if } r \\leq r_{\\text{on}} \\\\
        1 - S_5(t) & \\text{if } r_{\\text{on}} < r < r_{\\text{off}} \\\\
        0 & \\text{if } r \\geq r_{\\text{off}}
        \\end{cases}

    with :math:`t = (r - r_{\\text{on}})/(r_{\\text{off}} - r_{\\text{on}})` and

    .. math::

        S_5(t) = 10t^3 - 15t^4 + 6t^5, \\qquad
        \\frac{dS_w}{dr} = -(30t^2 - 60t^3 + 30t^4)\\,\\frac{1}{r_{\\text{off}} - r_{\\text{on}}}

    First and second derivatives of :math:`S_5` vanish at :math:`t=0` and
    :math:`t=1`, giving :math:`C^2` continuity at both boundaries.

    See Also
    --------
    :func:`_direct_forces_and_dE_dCN_kernel_matrix` : Uses this switching function
        for cutoff smoothing (neighbor matrix)
    :func:`_direct_forces_and_dE_dCN_kernel` : Uses this switching function
        for cutoff smoothing (neighbor list)
    """
    # Flat region below the window, or switching disabled by a degenerate window.
    if r_off <= r_on or r <= r_on:
        return 1.0, 0.0
    # Flat region beyond the window.
    if r >= r_off:
        return 0.0, 0.0
    x = (r - r_on) * inv_w  # normalized position inside the window, x in (0, 1)
    x_sq = x * x
    x_cu = x_sq * x
    x_qu = x_cu * x
    x_qi = x_qu * x
    weight = 1.0 - (10.0 * x_cu - 15.0 * x_qu + 6.0 * x_qi)
    # -dS5/dx, then chain rule through dx/dr = inv_w.
    neg_ds5_dx = -30.0 * x_sq + 60.0 * x_cu - 30.0 * x_qu  # NOSONAR (S125) "math formula"
    return weight, neg_ds5_dx * inv_w
@wp.func
def _c6ab_interpolate(
    cn_i: wp.float32,
    cn_j: wp.float32,
    c6ab_mat: wp.array2d(dtype=wp.float32),
    cnref_i_mat: wp.array2d(dtype=wp.float32),
    cnref_j_mat: wp.array2d(dtype=wp.float32),
    k3: wp.float32,
) -> tuple[wp.float32, wp.float32, wp.float32]:
    """
    Interpolate C6 coefficient and CN derivatives using Gaussian weighting.

    This function performs Gaussian interpolation over a 5x5 reference grid
    to compute the environment-dependent C6 coefficient for an atom pair,
    along with derivatives with respect to coordination numbers.

    Parameters
    ----------
    cn_i : float32
        Coordination number of atom i
    cn_j : float32
        Coordination number of atom j
    c6ab_mat : wp.array2d(dtype=float32)
        C6 reference values [5, 5] for this element pair. Grid points whose
        C6 reference value is exactly 0.0 are skipped.
    cnref_i_mat : wp.array2d(dtype=float32)
        CN reference grid [5, 5] for atom i, read as ``cnref_i_mat[p, q]``
    cnref_j_mat : wp.array2d(dtype=float32)
        CN reference grid [5, 5] for atom j, read transposed as
        ``cnref_j_mat[q, p]``
    k3 : float32
        Gaussian width parameter (typically -4.0). The weight is
        ``exp(k3 * distance²)``, so ``k3`` must be negative for the weights
        to decay away from the reference points.

    Returns
    -------
    c6_ij : float32
        Interpolated C6 coefficient (0.0 if the total weight is <= 1e-12)
    dC6_dCNi : float32
        Derivative :math:`\\partial C_6/\\partial \\text{CN}_i`
    dC6_dCNj : float32
        Derivative :math:`\\partial C_6/\\partial \\text{CN}_j`

    Notes
    -----
    With :math:`d_i = \\text{CN}_i - \\text{CN}_{\\text{ref},i}[p,q]` and
    :math:`d_j = \\text{CN}_j - \\text{CN}_{\\text{ref},j}[q,p]`, the Gaussian
    weights are:

    .. math::

        L_{pq} = \\exp\\left(k_3 \\left[d_i^2 + d_j^2\\right]\\right)

    The interpolated C6 and its CN derivatives are:

    .. math::

        C_6 = \\frac{z}{w}, \\qquad
        \\frac{\\partial C_6}{\\partial \\text{CN}_i} = \\frac{2k_3}{w} (z_{d_i} - C_6 w_{d_i})

    with accumulators :math:`w = \\sum_{pq} L_{pq}`,
    :math:`z = \\sum_{pq} C_6^{\\text{ref}}[p,q] L_{pq}`,
    :math:`w_{d_i} = \\sum_{pq} L_{pq} d_i` and
    :math:`z_{d_i} = \\sum_{pq} C_6^{\\text{ref}}[p,q] L_{pq} d_i`
    (analogously for j). All accumulators are built in a single pass.

    For numerical stability every exponent is shifted by its maximum over the
    grid (log-sum-exp trick). The shift cancels in the ratios above. Terms whose
    shifted exponent is below -12 are dropped.

    See Also
    --------
    :func:`_direct_forces_and_dE_dCN_kernel_matrix` : Calls this function for C6
        coefficient interpolation (neighbor matrix)
    :func:`_direct_forces_and_dE_dCN_kernel` : Calls this function for C6
        coefficient interpolation (neighbor list)
    """
    # Pass A: find the largest exponent (log-sum-exp shift) over used grid points.
    max_exp = wp.float32(-1e20)
    for p in range(5):
        for q in range(5):
            c6_val = c6ab_mat[p, q]
            if c6_val == 0.0:  # NOSONAR (S1244) "gpu kernel"
                continue
            cnref_i = cnref_i_mat[p, q]
            cnref_j = cnref_j_mat[q, p]
            di = cn_i - cnref_i
            dj = cn_j - cnref_j
            exp_arg = k3 * (di * di + dj * dj)
            if exp_arg > max_exp:
                max_exp = exp_arg
    # Pass B: accumulate shifted Gaussian weights and their CN moments.
    w = float(0.0)
    z = float(0.0)
    w_di = float(0.0)
    w_dj = float(0.0)
    z_di = float(0.0)
    z_dj = float(0.0)
    for p in range(5):
        for q in range(5):
            c6_val = c6ab_mat[p, q]
            if c6_val == 0.0:  # NOSONAR (S1244) "gpu kernel"
                continue
            cnref_i = cnref_i_mat[p, q]
            cnref_j = cnref_j_mat[q, p]  # Note transpose indexing
            di = cn_i - cnref_i
            dj = cn_j - cnref_j
            # Shifted exponent; relative weights below exp(-12) are negligible
            exp_arg = k3 * (di * di + dj * dj) - max_exp
            if exp_arg < -12.0:
                continue
            L = wp.exp(exp_arg)
            w += L
            z += c6_val * L
            w_di += L * di
            w_dj += L * dj
            z_di += c6_val * L * di
            z_dj += c6_val * L * dj
    eps_w = 1e-12
    if w > eps_w:
        w_inv = 1.0 / w
        c6_ij = z * w_inv
        s_i = z_di - c6_ij * w_di
        s_j = z_dj - c6_ij * w_dj
        # dL/dCN = 2 * k3 * d * L, hence the common 2*k3/w prefactor.
        k3_w_w_inv = (2.0 * k3) * w_inv
        dC6_dCNi = k3_w_w_inv * s_i  # NOSONAR (S125) "math formula"
        dC6_dCNj = k3_w_w_inv * s_j  # NOSONAR (S125) "math formula"
        return c6_ij, dC6_dCNi, dC6_dCNj
    else:
        return 0.0, 0.0, 0.0
@wp.func
def _compute_distance_vector_pbc(
    pos_i: Any,
    pos_j: Any,
    cartesian_shift: Any,
    periodic: bool,
    compute_vectors: bool,
) -> tuple[wp.float32, wp.float32, wp.vec3f, wp.vec3f]:
    """
    Return the i->j separation, its length, inverse length and unit vector.

    The separation ``pos_j - pos_i`` (plus ``cartesian_shift`` when
    ``periodic``) is computed in the native precision of the positions and
    then narrowed to float32.

    Parameters
    ----------
    pos_i, pos_j : vec3
        Atomic positions
    cartesian_shift : vec3
        PBC image shift added to the separation (ignored if periodic=False)
    periodic : bool
        Whether to add ``cartesian_shift``
    compute_vectors : bool
        If False, the unit vector is not formed and a zero vector is returned

    Returns
    -------
    r : float32
        Distance
    r_inv : float32
        Inverse distance (0 if r < 1e-12)
    r_hat : vec3f
        Unit vector (zero vector if compute_vectors=False or r < 1e-12)
    r_ij : vec3f
        Separation vector (always returned)
    """
    if periodic:
        sep_native = (pos_j - pos_i) + cartesian_shift
    else:
        sep_native = pos_j - pos_i
    # Narrow to float32 component-wise (positions may be float64).
    sep = wp.vec3f(
        wp.float32(sep_native[0]),
        wp.float32(sep_native[1]),
        wp.float32(sep_native[2]),
    )
    dist = wp.length(sep)
    zero_vec = wp.vec3f(0.0, 0.0, 0.0)
    # Coincident atoms: avoid dividing by zero.
    if dist < 1e-12:
        return dist, wp.float32(0.0), zero_vec, sep
    inv_dist = 1.0 / dist
    if not compute_vectors:
        return dist, inv_dist, zero_vec, sep
    return dist, inv_dist, sep * inv_dist, sep
@wp.func
def _cn_counting(
    r_inv: wp.float32,
    rcov_i: wp.float32,
    rcov_j: wp.float32,
    k1: wp.float32,
    compute_derivative: bool,
) -> tuple[wp.float32, wp.float32]:
    """
    Evaluate the Fermi-type CN counting function and, optionally, its r-derivative.

    .. math::

        f = \\frac{1}{1 + \\exp\\left[-k_1\\left(\\frac{r_{\\text{cov},i} + r_{\\text{cov},j}}{r} - 1\\right)\\right]},
        \\qquad
        \\frac{df}{dr} = -f(1 - f)\\,k_1\\,\\frac{r_{\\text{cov},i} + r_{\\text{cov},j}}{r^2}

    Parameters
    ----------
    r_inv : float32
        Inverse distance 1/r
    rcov_i, rcov_j : float32
        Covalent radii of the two atoms
    k1 : float32
        Steepness parameter
    compute_derivative : bool
        If False, the derivative is skipped and 0.0 is returned in its place

    Returns
    -------
    f_cn : float32
        Counting function value
    dCN_dr : float32
        df/dr (0.0 if compute_derivative=False)
    """
    radius_sum = rcov_i + rcov_j
    ratio = radius_sum * r_inv  # (rcov_i + rcov_j) / r
    count = 1.0 / (1.0 + wp.exp(-k1 * (ratio - 1.0)))
    if not compute_derivative:
        return count, wp.float32(0.0)
    return count, -count * (1.0 - count) * k1 * ratio * r_inv  # NOSONAR (S125)
@wp.func
def _bj_damping(
    r: wp.float32,
    r4r2_i: wp.float32,
    r4r2_j: wp.float32,
    a1: wp.float32,
    a2: wp.float32,
    s6: wp.float32,
    s8: wp.float32,
) -> tuple[wp.float32, wp.float32, wp.float32, wp.float32, wp.float32, wp.float32]:
    """
    Evaluate the Becke-Johnson damped C6/C8 prefactor and reusable intermediates.

    .. math::

        q_{ij} = 3\\, r4r2_i\\, r4r2_j, \\qquad
        R_0 = a_1 \\sqrt{q_{ij}} + a_2

        \\text{damp\\_sum} = \\frac{s_6}{r^6 + R_0^6} + \\frac{s_8\\, q_{ij}}{r^8 + R_0^8}

    so that the pair energy is ``-C6 * damp_sum``.

    Returns
    -------
    damp_sum, r4r2_ij, r6, r4, den6_inv, den8_inv : float32
        Damped prefactor, :math:`q_{ij}`, :math:`r^6`, :math:`r^4`,
        :math:`1/(r^6 + R_0^6)` and :math:`1/(r^8 + R_0^8)`.
    """
    # Powers of the interatomic distance
    r_sq = r * r
    r_pow4 = r_sq * r_sq
    r_pow6 = r_pow4 * r_sq
    r_pow8 = r_pow4 * r_pow4
    # Becke-Johnson critical radius R0 and its powers
    q_ij = 3.0 * r4r2_i * r4r2_j
    crit = a1 * wp.sqrt(q_ij) + a2
    crit_sq = crit * crit
    crit_pow4 = crit_sq * crit_sq
    crit_pow6 = crit_pow4 * crit_sq
    crit_pow8 = crit_pow4 * crit_pow4
    inv_den6 = 1.0 / (r_pow6 + crit_pow6)
    inv_den8 = 1.0 / (r_pow8 + crit_pow8)
    total = s6 * inv_den6 + s8 * q_ij * inv_den8
    return total, q_ij, r_pow6, r_pow4, inv_den6, inv_den8
@wp.func
def _dispersion_energy_force(
    c6_ij: wp.float32,
    r: wp.float32,
    r_hat: wp.vec3f,
    damp_sum: wp.float32,
    r4r2_ij: wp.float32,
    r6: wp.float32,
    r4: wp.float32,
    den6_inv: wp.float32,
    den8_inv: wp.float32,
    s6: wp.float32,
    s8: wp.float32,
    s5_smoothing_on: wp.float32,
    s5_smoothing_off: wp.float32,
    inv_w: wp.float32,
) -> tuple[wp.float32, wp.vec3f]:
    """
    Switched pair energy and its constant-CN force along ``r_hat``.

    The raw pair energy is ``e = -c6_ij * damp_sum``. It is multiplied by the
    S5 switching weight ``sw(r)``, and the product rule gives the radial
    derivative ``sw * de/dr + e * dsw/dr``. The returned force vector is that
    derivative times ``r_hat``.

    Returns
    -------
    e_ij_sw : float32
        Switched pair energy
    F_direct : vec3f
        Direct (constant-CN) force vector
    """
    sw, dsw_dr = _s5_switch(r, s5_smoothing_on, s5_smoothing_off, inv_w)
    e_raw = -c6_ij * damp_sum
    r_pow5 = r4 * r
    r_pow7 = r6 * r
    # d/dr of s6/(r^6+R0^6) and of s8*q/(r^8+R0^8)
    d_term6 = -6.0 * s6 * r_pow5 * den6_inv * den6_inv  # NOSONAR (S125) "math formula"
    d_term8 = -8.0 * s8 * r4r2_ij * r_pow7 * den8_inv * den8_inv  # NOSONAR (S125)
    de_dr = -c6_ij * (d_term6 + d_term8)  # NOSONAR (S125) "math formula"
    # Product rule for the switched energy e * sw
    de_dr_switched = sw * de_dr + e_raw * dsw_dr  # NOSONAR (S125)
    return e_raw * sw, de_dr_switched * r_hat
@wp.func
def _unit_shift_to_cartesian(
    unit_shift: wp.vec3i,
    cell_mat: Any,
) -> Any:
    """
    Convert an integer unit cell shift to a Cartesian shift vector.

    Parameters
    ----------
    unit_shift : wp.vec3i
        Integer lattice shift (n_a, n_b, n_c).
    cell_mat : mat33
        Lattice matrix. The callers document that its rows are the lattice
        vectors. Its scalar type determines the output precision.

    Returns
    -------
    vec3
        ``unit_shift`` (cast to the cell's scalar type) multiplied from the left
        by ``cell_mat``, i.e. n_a * row0 + n_b * row1 + n_c * row2.
    """
    # Build the shift as a vector of the same type as a cell row
    # (type(cell_mat[0])) with components of the cell's scalar type
    # (type(cell_mat[0, 0])), so the product below does not mix precisions.
    unit_shift_float = type(cell_mat[0])(
        type(cell_mat[0, 0])(unit_shift[0]),
        type(cell_mat[0, 0])(unit_shift[1]),
        type(cell_mat[0, 0])(unit_shift[2]),
    )
    # Row-vector times matrix: sum_k unit_shift[k] * cell_mat[k, :]
    return unit_shift_float * cell_mat
# ==============================================================================
# Kernels
# ==============================================================================
@wp.kernel(enable_backward=False)
def _compute_cartesian_shifts_matrix(
    cell: wp.array(dtype=Any),
    unit_shifts: wp.array2d(dtype=wp.vec3i),
    neighbor_matrix: wp.array2d(dtype=wp.int32),
    batch_idx: wp.array(dtype=wp.int32),
    fill_value: wp.int32,
    cartesian_shifts: wp.array2d(dtype=Any),
):
    """
    Pass 0 (neighbor matrix, PBC only): turn integer image shifts into Cartesian shifts.

    For every non-padding neighbor slot ``(i, k)`` this kernel writes the
    Cartesian vector that, added to ``pos_j - pos_i``, yields the periodic
    image of atom j seen by atom i. Padding slots are left untouched.

    Parameters
    ----------
    cell : wp.array(dtype=mat33)
        Unit cell lattice vectors [num_systems, 3, 3]. Convention: cell[s, i, :]
        is the i-th lattice vector for system s (row vectors). Units should match
        position coordinates.
    unit_shifts : wp.array2d(dtype=vec3i)
        Integer unit cell shifts [num_atoms, max_neighbors] as vec3i
    neighbor_matrix : wp.array2d(dtype=int32)
        Neighbor indices [num_atoms, max_neighbors]. See module docstring
        for more details.
    batch_idx : wp.array(dtype=int32)
        System index [num_atoms] for each atom
    fill_value : int32
        Value indicating padding in neighbor_matrix (typically num_atoms)
    cartesian_shifts : wp.array2d(dtype=vec3)
        Output: Cartesian shift vectors [num_atoms, max_neighbors], in the same
        units as the cell vectors

    Notes
    -----
    .. math::

        \\mathbf{s} = n_a \\mathbf{a} + n_b \\mathbf{b} + n_c \\mathbf{c}

    where :math:`\\mathbf{a}, \\mathbf{b}, \\mathbf{c}` are the lattice vectors
    of atom i's system (from ``batch_idx``) and :math:`n_a, n_b, n_c` the
    integer shifts.

    Launch with dim=(num_atoms, max_neighbors) (one thread per atom-neighbor pair).

    See Also
    --------
    :func:`_cn_kernel_matrix` : Pass 1 - Uses computed Cartesian shifts for PBC (neighbor matrix)
    :func:`_cn_kernel` : Pass 1 - Uses computed Cartesian shifts for PBC (neighbor list)
    :func:`_direct_forces_and_dE_dCN_kernel_matrix` : Pass 2 - Uses computed
        Cartesian shifts for PBC (neighbor matrix)
    :func:`_direct_forces_and_dE_dCN_kernel` : Pass 2 - Uses computed
        Cartesian shifts for PBC (neighbor list)
    :func:`_cn_forces_contrib_kernel_matrix` : Pass 3 - Uses computed Cartesian shifts for PBC (neighbor matrix)
    :func:`_cn_forces_contrib_kernel` : Pass 3 - Uses computed Cartesian shifts for PBC (neighbor list)
    :func:`dftd3` : High-level wrapper that orchestrates all passes
    """
    atom_i, slot = wp.tid()
    # Guard against over-sized launches along the neighbor axis.
    if slot >= neighbor_matrix.shape[1]:
        return
    # Padding entries carry no neighbor; leave their output slot untouched.
    if neighbor_matrix[atom_i, slot] >= fill_value:
        return
    lattice = cell[batch_idx[atom_i]]
    cartesian_shifts[atom_i, slot] = _unit_shift_to_cartesian(
        unit_shifts[atom_i, slot], lattice
    )
@wp.kernel(enable_backward=False)
def _cn_kernel_matrix(
    positions: wp.array(dtype=Any),
    numbers: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array2d(dtype=wp.int32),
    cartesian_shifts: wp.array2d(dtype=Any),
    covalent_radii: wp.array(dtype=wp.float32),
    k1: wp.float32,
    fill_value: wp.int32,
    periodic: bool,
    coord_num: wp.array(dtype=wp.float32),
):
    """
    Compute coordination numbers using geometric counting function.

    This kernel computes the coordination number (CN) for each atom based on
    a smooth counting function that depends on interatomic distances and
    covalent radii. Supports periodic boundary conditions via Cartesian shifts.

    Algorithm
    ---------
    For each atom i, iterate over its neighbors (neighbor matrix format) and accumulate:

    .. math::

        \\text{CN}_i = \\sum_{j \\in \\text{neighbors}(i)} f(r_{ij})

        f(r) = \\frac{1}{1 + \\exp\\left[-k_1\\left(\\frac{r_{\\text{cov}}}{r} - 1\\right)\\right]}

    where :math:`r_{\\text{cov}} = r_{\\text{cov}}[Z_i] + r_{\\text{cov}}[Z_j]`.
    For positive :math:`k_1` the counting function smoothly transitions from
    1 (:math:`r \\ll r_{\\text{cov}}`, bonded) to 0 (:math:`r \\gg r_{\\text{cov}}`,
    non-bonded).

    Parameters
    ----------
    positions : wp.array(dtype=vec3f)
        Atomic coordinates [num_atoms]
    numbers : wp.array(dtype=int32)
        Atomic numbers [num_atoms]
    neighbor_matrix : wp.array2d(dtype=int32)
        Neighbor indices [num_atoms, max_neighbors]. See module docstring
        for more details.
    cartesian_shifts : wp.array2d(dtype=vec3f)
        Cartesian shifts [num_atoms, max_neighbors] as vec3 for PBC (ignored if periodic=False), in same units as positions
    covalent_radii : wp.array(dtype=float32)
        Covalent radii [max_Z+1] indexed by atomic number, in same units as positions
    k1 : float32
        Steepness parameter for counting function (typically 16.0 1/Bohr)
    fill_value : int32
        Value indicating padding in neighbor_matrix (typically num_atoms)
    periodic : bool
        If True, apply PBC using cartesian_shifts; if False, non-periodic
    coord_num : wp.array(dtype=float32)
        Output: coordination numbers [num_atoms] (dimensionless). Entries of
        padding atoms (numbers[i] == 0) are not written.

    Notes
    -----
    - Launch with dim=num_atoms (one thread per atom)
    - Each thread iterates over all neighbors and accumulates CN in a local variable
    - Padding atoms indicated by numbers[i] == 0 are skipped
    - Neighbor entries with j >= fill_value are padding and are skipped
    - Coincident pairs (r < 1e-12) are skipped

    See Also
    --------
    :func:`_compute_cartesian_shifts_matrix` : Pass 0 - Computes Cartesian shifts for PBC
    :func:`_direct_forces_and_dE_dCN_kernel_matrix` : Pass 2 - Uses coordination numbers
        computed here
    :func:`dftd3` : High-level wrapper that orchestrates all passes
    """
    atom_i = wp.tid()
    if atom_i >= numbers.shape[0]:
        return
    # skip padding atoms
    if numbers[atom_i] == 0:
        return
    max_neighbors = neighbor_matrix.shape[1]
    pos_i = positions[atom_i]
    rcov_i = covalent_radii[numbers[atom_i]]
    # Accumulate coordination number in a thread-local variable
    cn_acc = wp.float32(0.0)
    for neighbor_idx in range(max_neighbors):
        atom_j = neighbor_matrix[atom_i, neighbor_idx]
        # Padding check must precede the numbers[atom_j] read below, since a
        # padding index (>= fill_value) may be out of bounds for `numbers`.
        if atom_j >= fill_value:
            continue
        # skip padding
        if numbers[atom_j] == 0:
            continue
        # Compute distance with optional PBC shift (unit vector not needed here;
        # r_hat and r_ij are unpacked only because the helper returns them)
        r, r_inv, r_hat, r_ij = _compute_distance_vector_pbc(
            pos_i,
            positions[atom_j],
            cartesian_shifts[atom_i, neighbor_idx],
            periodic,
            False,
        )
        if r < 1e-12:
            continue
        # Compute coordination number contribution (derivative not needed in Pass 1)
        f_cn, dCN_dr = _cn_counting(
            r_inv, rcov_i, covalent_radii[numbers[atom_j]], k1, False
        )
        cn_acc += f_cn
    # Write final coordination number once
    coord_num[atom_i] = cn_acc
@wp.kernel(enable_backward=False)
def _direct_forces_and_dE_dCN_kernel_matrix(  # NOSONAR (S1542) "math formula"
    positions: wp.array(dtype=Any),
    numbers: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array2d(dtype=wp.int32),
    cartesian_shifts: wp.array2d(dtype=Any),
    coord_num: wp.array(dtype=wp.float32),
    r4r2: wp.array(dtype=wp.float32),
    c6_reference: wp.array4d(dtype=wp.float32),
    coord_num_ref: wp.array4d(dtype=wp.float32),
    k3: wp.float32,
    a1: wp.float32,
    a2: wp.float32,
    s6: wp.float32,
    s8: wp.float32,
    s5_smoothing_on: wp.float32,
    s5_smoothing_off: wp.float32,
    inv_w: wp.float32,
    fill_value: wp.int32,
    periodic: bool,
    batch_idx: wp.array(dtype=wp.int32),
    compute_virial: bool,
    dE_dCN: wp.array(dtype=wp.float32),  # NOSONAR (S125) "math formula"
    forces: wp.array(dtype=wp.vec3f),
    energy: wp.array(dtype=wp.float32),
    virial: wp.array(dtype=Any),
):
    """
    Pass 2: Compute direct forces, energy, and accumulate dE/dCN per atom.

    Computes dispersion energy and forces at constant CN, and accumulates
    dE/dCN contributions for each atom for use in Pass 3. The S5 switching
    weight multiplies the pair energy, so it also scales both the direct force
    (via the product rule) and the dE/dCN contribution of each pair.

    Parameters
    ----------
    positions : wp.array(dtype=vec3f)
        Atomic coordinates [num_atoms]
    numbers : wp.array(dtype=int32)
        Atomic numbers [num_atoms]
    neighbor_matrix : wp.array2d(dtype=int32)
        Neighbor indices [num_atoms, max_neighbors]. See module docstring
        for more details.
    cartesian_shifts : wp.array2d(dtype=vec3f)
        Cartesian shifts [num_atoms, max_neighbors] as vec3 for PBC, in same units as positions
    coord_num : wp.array(dtype=float32)
        Coordination numbers [num_atoms] from Pass 1 (dimensionless)
    r4r2 : wp.array(dtype=float32)
        <r⁴>/<r²> expectation values [max_Z+1] (dimensionless)
    c6_reference : wp.array4d(dtype=float32)
        C6 reference [max_Z+1, max_Z+1, 5, 5] in energy x distance^6 units
    coord_num_ref : wp.array4d(dtype=float32)
        CN reference grid [max_Z+1, max_Z+1, 5, 5] (dimensionless)
    k3 : float32
        CN interpolation width (typically -4.0, dimensionless)
    a1, a2 : float32
        Becke-Johnson damping parameters (a1 dimensionless, a2 in distance units)
    s6, s8 : float32
        Scaling factors for C6 and C8 terms (dimensionless)
    s5_smoothing_on, s5_smoothing_off : float32
        S5 switching radii in same units as positions
    inv_w : float32
        Precomputed 1/(s5_off - s5_on) in inverse distance units
    fill_value : int32
        Value indicating padding in neighbor_matrix (typically num_atoms)
    periodic : bool
        If True, apply PBC using cartesian_shifts
    batch_idx : wp.array(dtype=int32)
        System index [num_atoms]
    compute_virial : bool
        If True, accumulate the direct-force virial into ``virial``
    dE_dCN : wp.array(dtype=float32)
        Output: accumulated dE/dCN [num_atoms] in energy units
    forces : wp.array(dtype=vec3f)
        Output: direct forces [num_atoms] in energy/distance units (overwritten,
        not accumulated)
    energy : wp.array(dtype=float32)
        Output: system energies [num_systems] in energy units (atomically added to)
    virial : wp.array(dtype=mat33)
        Output: per-system virial, atomically added to with
        ``-0.5 * sum_j outer(F_direct, r_ij)`` when ``compute_virial`` is True

    Notes
    -----
    - Launch with dim=num_atoms (one thread per atom)
    - Each thread iterates over all neighbors and accumulates results in thread-local variables
    - Direct forces are F = :math:`-(\\partial E/\\partial r)|_\\text{CN}`, without chain rule term
    - dE_dCN[i] = :math:`\\sum_j S_w(r_{ij})\\, \\partial E_{ij}/\\partial \\text{CN}_i`
      accumulated over all pairs containing atom i
    - Energy and virial are scaled by 0.5 to compensate for double counting;
      this assumes every pair appears in the neighbor matrix of both atoms
    - Neighbor entries with j >= fill_value are padding and are skipped

    See Also
    --------
    :func:`_compute_cartesian_shifts_matrix` : Pass 0 - Computes Cartesian shifts for PBC
    :func:`_cn_kernel_matrix` : Pass 1 - Computes coordination numbers used here
    :func:`_cn_forces_contrib_kernel_matrix` : Pass 3 - Uses dE/dCN values accumulated here
    :func:`_c6ab_interpolate` : Called to interpolate C6 coefficients
    :func:`_s5_switch` : Called for cutoff smoothing
    :func:`dftd3` : High-level wrapper that orchestrates all passes
    """
    atom_i = wp.tid()
    if atom_i >= numbers.shape[0]:
        return
    # skip padding atoms
    if numbers[atom_i] == 0:
        return
    max_neighbors = neighbor_matrix.shape[1]
    pos_i = positions[atom_i]
    cn_i = coord_num[atom_i]
    z_i = numbers[atom_i]
    r4r2_i = r4r2[z_i]
    # Thread-local accumulators: force and energy in float64 for precision,
    # dE/dCN in float32.
    F_acc = wp.vec3d()  # NOSONAR (S117) "math formula"
    dE_dCN_acc = wp.float32(0.0)  # NOSONAR (S117) "math formula"
    energy_acc = wp.float64(0.0)
    # Initialize virial accumulator
    if compute_virial:
        virial_acc = wp.mat33d()
    for neighbor_idx in range(max_neighbors):
        atom_j = neighbor_matrix[atom_i, neighbor_idx]
        # Padding check must precede any read indexed by atom_j.
        if atom_j >= fill_value:
            continue
        # skip padding atoms
        if numbers[atom_j] == 0:
            continue
        # Geometry
        r, r_inv, r_hat, r_ij = _compute_distance_vector_pbc(
            pos_i,
            positions[atom_j],
            cartesian_shifts[atom_i, neighbor_idx],
            periodic,
            True,
        )
        if r < 1e-12:
            continue
        cn_j = coord_num[atom_j]
        z_j = numbers[atom_j]
        # C6 interpolation
        c6ab_mat = c6_reference[z_i, z_j]
        cnref_i_mat = coord_num_ref[z_i, z_j]
        cnref_j_mat = coord_num_ref[z_j, z_i]
        c6_ij, dC6_dCNi, dC6_dCNj = _c6ab_interpolate(  # NOSONAR (S125) "math formula"
            cn_i, cn_j, c6ab_mat, cnref_i_mat, cnref_j_mat, k3
        )
        if c6_ij < 1e-12:
            continue
        # BJ damping
        damp_sum, r4r2_ij, r6, r4, den6_inv, den8_inv = _bj_damping(
            r, r4r2_i, r4r2[z_j], a1, a2, s6, s8
        )
        # Energy and direct force
        e_ij_sw, F_direct = _dispersion_energy_force(
            c6_ij,
            r,
            r_hat,
            damp_sum,
            r4r2_ij,
            r6,
            r4,
            den6_inv,
            den8_inv,
            s6,
            s8,
            s5_smoothing_on,
            s5_smoothing_off,
            inv_w,
        )
        # The pair energy is e_ij * sw(r), so its CN derivative must carry the
        # same switching factor: d(e_ij * sw)/dCN_i = sw * de_ij/dCN_i.
        # Without it, pairs in (or beyond) the switching window would feed
        # dE/dCN, and hence Pass-3 forces, inconsistently with the energy.
        # With smoothing disabled sw is exactly 1.0, so that path is unchanged.
        sw, dsw_dr = _s5_switch(r, s5_smoothing_on, s5_smoothing_off, inv_w)
        # Accumulate in thread-local variables
        F_acc += wp.vec3d(F_direct)  # NOSONAR (S117) "math formula"
        energy_acc += wp.float64(e_ij_sw)
        dE_dCN_acc += -damp_sum * dC6_dCNi * sw  # NOSONAR (S117) "math formula"
        # Accumulate virial if requested
        if compute_virial:
            virial_acc += wp.mat33d(wp.outer(F_direct, r_ij))
    # Write final results once (atomic only for shared batch array)
    # Convert from float64 accumulation to float32 output
    forces[atom_i] = wp.vec3f(F_acc)
    dE_dCN[atom_i] = dE_dCN_acc
    wp.atomic_add(energy, batch_idx[atom_i], 0.5 * wp.float32(energy_acc))
    # Add virial contribution with -0.5 scaling for correct sign and double counting
    if compute_virial:
        wp.atomic_add(virial, batch_idx[atom_i], -0.5 * wp.mat33f(virial_acc))
@wp.kernel(enable_backward=False)
def _cn_forces_contrib_kernel_matrix(
    positions: wp.array(dtype=Any),
    numbers: wp.array(dtype=wp.int32),
    neighbor_matrix: wp.array2d(dtype=wp.int32),
    cartesian_shifts: wp.array2d(dtype=Any),
    covalent_radii: wp.array(dtype=wp.float32),
    dE_dCN: wp.array(dtype=wp.float32),  # NOSONAR (S125) "math formula"
    k1: wp.float32,
    fill_value: wp.int32,
    periodic: bool,
    batch_idx: wp.array(dtype=wp.int32),
    compute_virial: bool,
    forces: wp.array(dtype=wp.vec3f),
    virial: wp.array(dtype=Any),
):
    """
    Pass 3: Add CN-dependent force contribution.

    Adds the CN-dependent term to forces computed in Pass 2. Computes
    distances and CN derivatives without repeating C6 interpolation and
    damping calculations.

    For each neighbor j the chain-rule force on atom i is

    .. math::

        \\mathbf{F}_i \\mathrel{+}= (\\partial E/\\partial \\text{CN}_i +
        \\partial E/\\partial \\text{CN}_j)\\, \\frac{df}{dr}(r_{ij})\\, \\hat{\\mathbf{r}}_{ij}

    Parameters
    ----------
    positions : wp.array(dtype=vec3f)
        Atomic coordinates [num_atoms] in consistent distance units
    numbers : wp.array(dtype=int32)
        Atomic numbers [num_atoms]
    neighbor_matrix : wp.array2d(dtype=int32)
        Neighbor indices [num_atoms, max_neighbors]. See module docstring
        for more details.
    cartesian_shifts : wp.array2d(dtype=vec3f)
        Cartesian shifts [num_atoms, max_neighbors] as vec3 for PBC, in same units as positions
    covalent_radii : wp.array(dtype=float32)
        Covalent radii [max_Z+1] in same units as positions
    dE_dCN : wp.array(dtype=float32)
        Precomputed dE/dCN [num_atoms] from Pass 2 in energy units
    k1 : float32
        CN counting steepness in inverse distance units (typically 16.0 1/Bohr)
    fill_value : int32
        Value indicating padding in neighbor_matrix (typically num_atoms)
    periodic : bool
        If True, apply PBC using cartesian_shifts
    batch_idx : wp.array(dtype=int32)
        System index [num_atoms]; selects the virial entry to update
    compute_virial : bool
        If True, accumulate the chain-rule virial into ``virial``
    forces : wp.array(dtype=vec3f)
        Input/Output: add chain term to direct forces [num_atoms] in energy/distance units
    virial : wp.array(dtype=mat33)
        Input/Output: per-system virial, atomically added to with
        ``-0.5 * sum_j outer(F_chain, r_ij)`` when ``compute_virial`` is True

    Notes
    -----
    - Launch with dim=num_atoms (one thread per atom)
    - Each thread iterates over all neighbors and accumulates results in thread-local variables
    - Skips C6 interpolation and damping calculations
    - Uses precomputed dE_dCN[i] = :math:`\\sum_k \\partial E_{ik}/\\partial \\text{CN}_i` from all pairs
    - Neighbor entries with j >= fill_value are padding and are skipped
    - forces[i] is updated with a plain read-modify-write; this is safe only
      because each thread writes solely its own atom's entry

    See Also
    --------
    :func:`_compute_cartesian_shifts_matrix` : Pass 0 - Computes Cartesian shifts for PBC
    :func:`_direct_forces_and_dE_dCN_kernel_matrix` : Pass 2 - Computes dE/dCN values used here
    :func:`dftd3` : High-level wrapper that orchestrates all passes
    """
    atom_i = wp.tid()
    if atom_i >= numbers.shape[0]:
        return
    # skip padding atoms
    if numbers[atom_i] == 0:
        return
    max_neighbors = neighbor_matrix.shape[1]
    dE_dCN_i = dE_dCN[atom_i]  # NOSONAR (S125) "math formula"
    pos_i = positions[atom_i]
    rcov_i = covalent_radii[numbers[atom_i]]
    # Accumulate force in a thread-local variable (using float64 for better precision)
    F_chain_acc = wp.vec3d()  # NOSONAR (S117) "math formula"
    # Initialize virial accumulator
    if compute_virial:
        virial_chain_acc = wp.mat33d()
    for neighbor_idx in range(max_neighbors):
        atom_j = neighbor_matrix[atom_i, neighbor_idx]
        # Padding check must precede any read indexed by atom_j.
        if atom_j >= fill_value:
            continue
        if numbers[atom_j] == 0:
            continue
        # Distance
        r, r_inv, r_hat, r_ij = _compute_distance_vector_pbc(
            pos_i,
            positions[atom_j],
            cartesian_shifts[atom_i, neighbor_idx],
            periodic,
            True,
        )
        if r < 1e-12:
            continue
        # CN derivative df/dr of the counting function for this pair
        f_cn, dCN_dr = _cn_counting(
            r_inv, rcov_i, covalent_radii[numbers[atom_j]], k1, True
        )
        # CN-dependent force contribution: the pair term f(r_ij) enters both
        # CN_i and CN_j, hence the sum of both atoms' dE/dCN.
        dE_dCN_j = dE_dCN[atom_j]  # NOSONAR (S125) "math formula"
        dE_dr_chain = (dE_dCN_i + dE_dCN_j) * dCN_dr  # NOSONAR (S125) "math formula"
        F_chain = dE_dr_chain * r_hat  # NOSONAR (S125) "math formula"
        F_chain_acc += wp.vec3d(F_chain)
        # Accumulate virial if requested
        if compute_virial:
            virial_chain_acc += wp.mat33d(wp.outer(F_chain, r_ij))
    # Add accumulated force to existing forces (direct read-modify-write)
    # Convert from float64 accumulation to float32 output
    forces[atom_i] = forces[atom_i] + wp.vec3f(F_chain_acc)
    # Add virial contribution with -0.5 scaling for correct sign and double counting
    if compute_virial:
        wp.atomic_add(virial, batch_idx[atom_i], -0.5 * wp.mat33f(virial_chain_acc))
# ==============================================================================
# Neighbor List Kernels
# ==============================================================================
@wp.kernel(enable_backward=False)
def _compute_cartesian_shifts(
    cell: wp.array(dtype=Any),
    unit_shifts: wp.array(dtype=wp.vec3i),
    neighbor_ptr: wp.array(dtype=wp.int32),
    batch_idx: wp.array(dtype=wp.int32),
    cartesian_shifts: wp.array(dtype=Any),
):
    """
    Turn integer periodic-image shifts into Cartesian shift vectors (CSR layout).

    One thread is assigned per atom. The thread walks every CSR edge owned by
    its atom and stores, for that edge, the Cartesian offset to apply to the
    destination atom's position.

    Parameters
    ----------
    cell : wp.array(dtype=mat33)
        Lattice vectors per system, [num_systems, 3, 3]. Row convention:
        cell[s, i, :] is lattice vector i of system s. Same length units as
        the atomic positions.
    unit_shifts : wp.array(dtype=vec3i)
        Integer image shift of every edge, [num_edges].
    neighbor_ptr : wp.array(dtype=int32)
        CSR row offsets, [num_atoms + 1].
    batch_idx : wp.array(dtype=int32)
        System index of every atom, [num_atoms]. Its length also bounds the
        thread range.
    cartesian_shifts : wp.array(dtype=vec3)
        Output. Cartesian shift per edge, [num_edges], in cell-vector units.

    Notes
    -----
    Launch with dim=num_atoms. Threads with an index past batch_idx.shape[0]
    do nothing.

    See Also
    --------
    :func:`_cn_kernel` : Consumes the Cartesian shifts for PBC
    :func:`_direct_forces_and_dE_dCN_kernel` : Consumes the Cartesian shifts for PBC
    :func:`_cn_forces_contrib_kernel` : Consumes the Cartesian shifts for PBC
    """
    tid = wp.tid()
    if tid < batch_idx.shape[0]:
        # Every edge of this atom uses the lattice of the atom's own system.
        lattice = cell[batch_idx[tid]]
        # The edges of atom `tid` occupy [neighbor_ptr[tid], neighbor_ptr[tid + 1]).
        for edge in range(neighbor_ptr[tid], neighbor_ptr[tid + 1]):
            cartesian_shifts[edge] = _unit_shift_to_cartesian(unit_shifts[edge], lattice)
@wp.kernel(enable_backward=False)
def _cn_kernel(
    positions: wp.array(dtype=Any),
    numbers: wp.array(dtype=wp.int32),
    idx_j: wp.array(dtype=wp.int32),
    neighbor_ptr: wp.array(dtype=wp.int32),
    cartesian_shifts: wp.array(dtype=Any),
    covalent_radii: wp.array(dtype=wp.float32),
    k1: wp.float32,
    periodic: bool,
    coord_num: wp.array(dtype=wp.float32),
):
    """
    Pass 1 (CSR): compute coordination numbers using CSR neighbor list format.

    Each thread owns one atom i and sums the counting-function value ``f_cn``
    returned by ``_cn_counting`` over the CSR edge range
    ``neighbor_ptr[i]:neighbor_ptr[i + 1]``.

    Parameters
    ----------
    positions : wp.array(dtype=vec3)
        Atomic coordinates [num_atoms] (vec3f or vec3d, per the registered overloads)
    numbers : wp.array(dtype=int32)
        Atomic numbers [num_atoms]; 0 marks a padding atom
    idx_j : wp.array(dtype=int32)
        Destination atom indices [num_edges] in CSR format
    neighbor_ptr : wp.array(dtype=int32)
        CSR row pointers [num_atoms+1]
    cartesian_shifts : wp.array(dtype=vec3)
        Cartesian shifts [num_edges] as vec3 for PBC
    covalent_radii : wp.array(dtype=float32)
        Covalent radii [max_Z+1] indexed by atomic number
    k1 : float32
        Steepness parameter for counting function
    periodic : bool
        If True, apply PBC using cartesian_shifts
    coord_num : wp.array(dtype=float32)
        Output: coordination numbers [num_atoms]

    Notes
    -----
    Launch with dim=num_atoms (one thread per atom).

    Threads for padding atoms (numbers[i] == 0) return without writing, so
    coord_num[i] keeps whatever the caller stored there. Neighbors that are
    padding atoms, or that lie at distance r < 1e-12, contribute nothing.
    """
    atom_i = wp.tid()
    if atom_i >= numbers.shape[0]:
        return
    # skip padding atoms (coord_num[atom_i] is left untouched)
    if numbers[atom_i] == 0:
        return
    pos_i = positions[atom_i]
    rcov_i = covalent_radii[numbers[atom_i]]
    # Accumulate coordination number in local register
    cn_acc = wp.float32(0.0)
    # Iterate over neighbors using CSR pointers
    j_range_start = neighbor_ptr[atom_i]
    j_range_end = neighbor_ptr[atom_i + 1]
    for edge_idx in range(j_range_start, j_range_end):
        atom_j = idx_j[edge_idx]
        # skip padding atoms
        if numbers[atom_j] == 0:
            continue
        # Compute distance with optional PBC shift.
        # The trailing flag is False here but True in the force passes, which use
        # r_hat / r_ij; presumably it toggles computing the direction vectors
        # (TODO confirm in _compute_distance_vector_pbc).
        r, r_inv, r_hat, r_ij = _compute_distance_vector_pbc(
            pos_i, positions[atom_j], cartesian_shifts[edge_idx], periodic, False
        )
        # Guard against coincident positions (r_inv would blow up)
        if r < 1e-12:
            continue
        # Compute coordination number contribution.
        # Only f_cn is used here; the trailing False presumably skips the
        # derivative dCN_dr (TODO confirm in _cn_counting).
        f_cn, dCN_dr = _cn_counting(
            r_inv, rcov_i, covalent_radii[numbers[atom_j]], k1, False
        )
        cn_acc += f_cn
    # Write final coordination number once
    coord_num[atom_i] = cn_acc
@wp.kernel(enable_backward=False)
def _direct_forces_and_dE_dCN_kernel(  # NOSONAR (S1542) "math formula"
    positions: wp.array(dtype=Any),
    numbers: wp.array(dtype=wp.int32),
    idx_j: wp.array(dtype=wp.int32),
    neighbor_ptr: wp.array(dtype=wp.int32),
    cartesian_shifts: wp.array(dtype=Any),
    coord_num: wp.array(dtype=wp.float32),
    r4r2: wp.array(dtype=wp.float32),
    c6_reference: wp.array4d(dtype=wp.float32),
    coord_num_ref: wp.array4d(dtype=wp.float32),
    k3: wp.float32,
    a1: wp.float32,
    a2: wp.float32,
    s6: wp.float32,
    s8: wp.float32,
    s5_smoothing_on: wp.float32,
    s5_smoothing_off: wp.float32,
    inv_w: wp.float32,
    periodic: bool,
    batch_idx: wp.array(dtype=wp.int32),
    compute_virial: bool,
    dE_dCN: wp.array(dtype=wp.float32),  # NOSONAR (S125) "math formula"
    forces: wp.array(dtype=wp.vec3f),
    energy: wp.array(dtype=wp.float32),
    virial: wp.array(dtype=Any),
):
    """
    Pass 2: Compute direct forces, energy, and accumulate dE/dCN using
    CSR neighbor list.

    For every edge (i, j), the kernel does the following:

    - interpolates C6 from the coordination numbers of Pass 1 (``_c6ab_interpolate``);
    - applies Becke-Johnson damping (``_bj_damping``);
    - evaluates the pair energy and direct force (``_dispersion_energy_force``).

    Parameters
    ----------
    positions : wp.array(dtype=vec3)
        Atomic coordinates [num_atoms] (vec3f or vec3d per the overloads)
    numbers : wp.array(dtype=int32)
        Atomic numbers [num_atoms]; 0 marks a padding atom
    idx_j : wp.array(dtype=int32)
        Destination atom indices [num_edges] in CSR format
    neighbor_ptr : wp.array(dtype=int32)
        CSR row pointers [num_atoms+1]
    cartesian_shifts : wp.array(dtype=vec3)
        Cartesian shift per edge [num_edges], used when periodic is True
    coord_num : wp.array(dtype=float32)
        Coordination numbers [num_atoms] produced by Pass 1
    r4r2 : wp.array(dtype=float32)
        <r^4>/<r^2> values [max_Z+1] indexed by atomic number
    c6_reference : wp.array4d(dtype=float32)
        Reference C6 grid; c6_reference[z_i, z_j] is passed to _c6ab_interpolate
    coord_num_ref : wp.array4d(dtype=float32)
        Reference CN grid; [z_i, z_j] and [z_j, z_i] are passed to _c6ab_interpolate
    k3 : float32
        CN interpolation width parameter
    a1, a2 : float32
        Becke-Johnson damping parameters
    s6, s8 : float32
        Scaling factors of the C6 and C8 terms
    s5_smoothing_on, s5_smoothing_off, inv_w : float32
        S5 switching window and its precomputed reciprocal width
    periodic : bool
        If True, apply PBC using cartesian_shifts
    batch_idx : wp.array(dtype=int32)
        System index [num_atoms] of every atom
    compute_virial : bool
        If True, accumulate the direct-force virial
    dE_dCN : wp.array(dtype=float32)
        Output: dE/dCN_i [num_atoms], overwritten for every non-padding atom
    forces : wp.array(dtype=vec3f)
        Output: direct forces [num_atoms]. OVERWRITTEN (not added to) for every
        non-padding atom; Pass 3 then adds the CN chain-rule term.
    energy : wp.array(dtype=float32)
        Output: per-system energy [num_systems], accumulated atomically
    virial : wp.array(dtype=mat33)
        Output: per-system virial [num_systems], accumulated atomically only if
        compute_virial is True

    Notes
    -----
    Launch with dim=num_atoms (one thread per atom).

    The energy and virial are scaled by 0.5 because each pair is visited from
    both of its atoms. The neighbor list is therefore expected to contain both
    i->j and j->i (confirm with the neighbor-list builder).
    """
    atom_i = wp.tid()
    if atom_i >= numbers.shape[0]:
        return
    # skip padding atoms
    if numbers[atom_i] == 0:
        return
    pos_i = positions[atom_i]
    cn_i = coord_num[atom_i]
    z_i = numbers[atom_i]
    r4r2_i = r4r2[z_i]
    # Accumulate in local registers (using float64 for better precision).
    # Note: dE_dCN_acc is a float32 accumulator, unlike F_acc / energy_acc.
    F_acc = wp.vec3d()  # NOSONAR (S117) "math formula"
    dE_dCN_acc = wp.float32(0.0)  # NOSONAR (S117) "math formula"
    energy_acc = wp.float64(0.0)
    # Initialize virial accumulator
    if compute_virial:
        virial_acc = wp.mat33d()
    # Iterate over neighbors using CSR pointers
    j_range_start = neighbor_ptr[atom_i]
    j_range_end = neighbor_ptr[atom_i + 1]
    for edge_idx in range(j_range_start, j_range_end):
        atom_j = idx_j[edge_idx]
        # skip padding atoms
        if numbers[atom_j] == 0:
            continue
        # Geometry (direction vectors are needed for forces and virial)
        r, r_inv, r_hat, r_ij = _compute_distance_vector_pbc(
            pos_i, positions[atom_j], cartesian_shifts[edge_idx], periodic, True
        )
        # Guard against coincident positions
        if r < 1e-12:
            continue
        cn_j = coord_num[atom_j]
        z_j = numbers[atom_j]
        # C6 interpolation (C6 and its derivatives w.r.t. CN_i and CN_j)
        c6ab_mat = c6_reference[z_i, z_j]
        cnref_i_mat = coord_num_ref[z_i, z_j]
        cnref_j_mat = coord_num_ref[z_j, z_i]
        c6_ij, dC6_dCNi, dC6_dCNj = _c6ab_interpolate(  # NOSONAR (S125) "math formula"
            cn_i, cn_j, c6ab_mat, cnref_i_mat, cnref_j_mat, k3
        )
        # Negligible C6: skip the pair entirely (no energy, force or dE/dCN)
        if c6_ij < 1e-12:
            continue
        # BJ damping
        damp_sum, r4r2_ij, r6, r4, den6_inv, den8_inv = _bj_damping(
            r, r4r2_i, r4r2[z_j], a1, a2, s6, s8
        )
        # Energy and direct force
        e_ij_sw, F_direct = _dispersion_energy_force(
            c6_ij,
            r,
            r_hat,
            damp_sum,
            r4r2_ij,
            r6,
            r4,
            den6_inv,
            den8_inv,
            s6,
            s8,
            s5_smoothing_on,
            s5_smoothing_off,
            inv_w,
        )
        # Accumulate in registers
        F_acc += wp.vec3d(F_direct)
        energy_acc += wp.float64(e_ij_sw)
        # No 0.5 factor here: with E = 0.5 * sum_i sum_j e_ij(CN_i, CN_j) and
        # (assumed) symmetric e_ij, dE/dCN_i collects the terms from both e_ij
        # and e_ji, which cancels the 0.5.
        # NOTE(review): this uses damp_sum alone. e_ij_sw receives the S5
        # switching parameters, so if it includes a switching factor, dE/dCN
        # may need the same factor while smoothing is active. Confirm against
        # _dispersion_energy_force.
        dE_dCN_acc += -damp_sum * dC6_dCNi  # NOSONAR (S117) "math formula"
        # Accumulate virial if requested
        if compute_virial:
            virial_acc += wp.mat33d(wp.outer(F_direct, r_ij))
    # Write final results once (atomic only for shared batch array)
    # Convert from float64 accumulation to float32 output.
    # forces/dE_dCN are per-atom and owned by this thread, so plain stores suffice.
    forces[atom_i] = wp.vec3f(F_acc)
    dE_dCN[atom_i] = wp.float32(dE_dCN_acc)
    # Several atoms share one system entry, hence the atomic add.
    wp.atomic_add(energy, batch_idx[atom_i], 0.5 * wp.float32(energy_acc))
    # Add virial contribution with -0.5 scaling for correct sign and double counting
    if compute_virial:
        wp.atomic_add(virial, batch_idx[atom_i], -0.5 * wp.mat33f(virial_acc))
@wp.kernel(enable_backward=False)
def _cn_forces_contrib_kernel(
    positions: wp.array(dtype=Any),
    numbers: wp.array(dtype=wp.int32),
    idx_j: wp.array(dtype=wp.int32),
    neighbor_ptr: wp.array(dtype=wp.int32),
    cartesian_shifts: wp.array(dtype=Any),
    covalent_radii: wp.array(dtype=wp.float32),
    dE_dCN: wp.array(dtype=wp.float32),  # NOSONAR (S125) "math formula"
    k1: wp.float32,
    periodic: bool,
    batch_idx: wp.array(dtype=wp.int32),
    compute_virial: bool,
    forces: wp.array(dtype=wp.vec3f),
    virial: wp.array(dtype=Any),
):
    """
    Pass 3: Add CN-dependent force contribution using CSR neighbor list.

    For every edge (i, j), the derivative of the counting function is combined
    with dE/dCN_i + dE/dCN_j, and the result is added to the force on atom i.
    Reading dE_dCN[atom_j] of other atoms is the reason this is a separate
    launch: Pass 2 must have finished writing dE_dCN for all atoms.

    Parameters
    ----------
    positions : wp.array(dtype=vec3)
        Atomic coordinates [num_atoms] (vec3f or vec3d per the overloads)
    numbers : wp.array(dtype=int32)
        Atomic numbers [num_atoms]; 0 marks a padding atom
    idx_j : wp.array(dtype=int32)
        Destination atom indices [num_edges] in CSR format
    neighbor_ptr : wp.array(dtype=int32)
        CSR row pointers [num_atoms+1]
    cartesian_shifts : wp.array(dtype=vec3)
        Cartesian shift per edge [num_edges], used when periodic is True
    covalent_radii : wp.array(dtype=float32)
        Covalent radii [max_Z+1] indexed by atomic number
    dE_dCN : wp.array(dtype=float32)
        dE/dCN per atom [num_atoms], as written by Pass 2
    k1 : float32
        Steepness parameter for counting function
    periodic : bool
        If True, apply PBC using cartesian_shifts
    batch_idx : wp.array(dtype=int32)
        System index [num_atoms] of every atom
    compute_virial : bool
        If True, accumulate the chain-rule virial
    forces : wp.array(dtype=vec3f)
        In/out: forces [num_atoms]; the chain-rule term is ADDED to the values
        already stored (the direct forces from Pass 2)
    virial : wp.array(dtype=mat33)
        Output: per-system virial [num_systems], accumulated atomically only if
        compute_virial is True

    Notes
    -----
    Launch with dim=num_atoms (one thread per atom).
    """
    atom_i = wp.tid()
    if atom_i >= numbers.shape[0]:
        return
    # skip padding atoms
    if numbers[atom_i] == 0:
        return
    dE_dCN_i = dE_dCN[atom_i]  # NOSONAR (S125) "math formula"
    pos_i = positions[atom_i]
    rcov_i = covalent_radii[numbers[atom_i]]
    # Accumulate force in local register (using float64 for better precision)
    F_chain_acc = wp.vec3d()  # NOSONAR (S117) "math formula"
    # Initialize virial accumulator
    if compute_virial:
        virial_chain_acc = wp.mat33d()
    # Iterate over neighbors using CSR pointers
    j_range_start = neighbor_ptr[atom_i]
    j_range_end = neighbor_ptr[atom_i + 1]
    for edge_idx in range(j_range_start, j_range_end):
        atom_j = idx_j[edge_idx]
        # skip padding atoms
        if numbers[atom_j] == 0:
            continue
        # Distance (direction vectors needed for force and virial)
        r, r_inv, r_hat, r_ij = _compute_distance_vector_pbc(
            pos_i, positions[atom_j], cartesian_shifts[edge_idx], periodic, True
        )
        # Guard against coincident positions
        if r < 1e-12:
            continue
        # CN derivative (trailing True requests dCN_dr; f_cn itself is unused)
        f_cn, dCN_dr = _cn_counting(
            r_inv, rcov_i, covalent_radii[numbers[atom_j]], k1, True
        )
        # CN-dependent force contribution.
        # The distance r_ij enters both CN_i and CN_j (assuming a symmetric
        # neighbor list), so the chain rule combines both dE/dCN values.
        dE_dCN_j = dE_dCN[atom_j]  # NOSONAR (S125) "math formula"
        dE_dr_chain = (dE_dCN_i + dE_dCN_j) * dCN_dr  # NOSONAR (S125) "math formula"
        F_chain = dE_dr_chain * r_hat  # NOSONAR (S125) "math formula"
        F_chain_acc += wp.vec3d(F_chain)
        # Accumulate virial if requested
        if compute_virial:
            virial_chain_acc += wp.mat33d(wp.outer(F_chain, r_ij))
    # Add accumulated force to existing forces (direct read-modify-write).
    # Race-free: each thread reads and writes only forces[atom_i] for its own atom.
    # Convert from float64 accumulation to float32 output
    forces[atom_i] = forces[atom_i] + wp.vec3f(F_chain_acc)
    # Add virial contribution with -0.5 scaling for correct sign and double counting
    if compute_virial:
        wp.atomic_add(virial, batch_idx[atom_i], -0.5 * wp.mat33f(virial_chain_acc))
# ==============================================================================
# Kernel Overload Registration
# ==============================================================================
# Type constants for overload generation.
# T, V and M are zipped together below, so each index pairs a scalar type with
# the vector/matrix type of the same precision:
# (float32, vec3f, mat33f) and (float64, vec3d, mat33d).
T = [wp.float32, wp.float64]
V = [wp.vec3f, wp.vec3d]
M = [wp.mat33f, wp.mat33d]
# Overload dictionaries keyed by scalar type (wp.float32 / wp.float64).
# The launchers select a concrete kernel with e.g. `_cn_kernel_matrix_overload[wp_dtype]`.
# Neighbor matrix format (dense)
_compute_cartesian_shifts_matrix_overload = {}
_cn_kernel_matrix_overload = {}
_direct_forces_and_dE_dCN_kernel_matrix_overload = {}
_cn_forces_contrib_kernel_matrix_overload = {}
# Neighbor list kernel overload dictionaries (CSR format) - default naming convention
_compute_cartesian_shifts_overload = {}
_cn_kernel_overload = {}
_direct_forces_and_dE_dCN_kernel_overload = {}
_cn_forces_contrib_kernel_overload = {}
# Register overloads for all kernel variants.
# Only positions / cartesian shifts (v) and cell (m) change with precision.
# Every other argument has a fixed dtype, and the forces / energy / virial
# outputs are always vec3f / float32 / mat33f.
# Each type list is matched positionally to the kernel's parameters, so its
# order must follow the kernel signature exactly.
for t, v, m in zip(T, V, M):
    # Neighbor matrix format (dense).
    # The argument labels below follow the positional order used by the
    # dftd3_matrix / dftd3_matrix_pbc launches.
    _compute_cartesian_shifts_matrix_overload[t] = wp.overload(
        _compute_cartesian_shifts_matrix,
        [
            wp.array(dtype=m),  # cell
            wp.array2d(dtype=wp.vec3i),  # neighbor_matrix_shifts
            wp.array2d(dtype=wp.int32),  # neighbor_matrix
            wp.array(dtype=wp.int32),  # batch_idx
            wp.int32,  # fill_value
            wp.array2d(dtype=v),  # cartesian_shifts (output)
        ],
    )
    _cn_kernel_matrix_overload[t] = wp.overload(
        _cn_kernel_matrix,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=wp.int32),  # numbers
            wp.array2d(dtype=wp.int32),  # neighbor_matrix
            wp.array2d(dtype=v),  # cartesian_shifts
            wp.array(dtype=wp.float32),  # covalent_radii
            wp.float32,  # k1
            wp.int32,  # fill_value
            wp.bool,  # periodic
            wp.array(dtype=wp.float32),  # coord_num (output)
        ],
    )
    _direct_forces_and_dE_dCN_kernel_matrix_overload[t] = wp.overload(
        _direct_forces_and_dE_dCN_kernel_matrix,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=wp.int32),  # numbers
            wp.array2d(dtype=wp.int32),  # neighbor_matrix
            wp.array2d(dtype=v),  # cartesian_shifts
            wp.array(dtype=wp.float32),  # coord_num
            wp.array(dtype=wp.float32),  # r4r2
            wp.array4d(dtype=wp.float32),  # c6_reference
            wp.array4d(dtype=wp.float32),  # coord_num_ref
            wp.float32,  # k3
            wp.float32,  # a1
            wp.float32,  # a2
            wp.float32,  # s6
            wp.float32,  # s8
            wp.float32,  # s5_smoothing_on
            wp.float32,  # s5_smoothing_off
            wp.float32,  # inv_w
            wp.int32,  # fill_value
            wp.bool,  # periodic
            wp.array(dtype=wp.int32),  # batch_idx
            wp.bool,  # compute_virial
            wp.array(dtype=wp.float32),  # dE_dCN (output)
            wp.array(dtype=wp.vec3f),  # forces (output)
            wp.array(dtype=wp.float32),  # energy (output)
            wp.array(dtype=wp.mat33f),  # virial (output)
        ],
    )
    _cn_forces_contrib_kernel_matrix_overload[t] = wp.overload(
        _cn_forces_contrib_kernel_matrix,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=wp.int32),  # numbers
            wp.array2d(dtype=wp.int32),  # neighbor_matrix
            wp.array2d(dtype=v),  # cartesian_shifts
            wp.array(dtype=wp.float32),  # covalent_radii
            wp.array(dtype=wp.float32),  # dE_dCN
            wp.float32,  # k1
            wp.int32,  # fill_value
            wp.bool,  # periodic
            wp.array(dtype=wp.int32),  # batch_idx
            wp.bool,  # compute_virial
            wp.array(dtype=wp.vec3f),  # forces (in/out)
            wp.array(dtype=wp.mat33f),  # virial (output)
        ],
    )
    # Neighbor list kernel overloads (CSR format) - default naming convention.
    # Labels are the parameter names of the CSR kernels.
    _compute_cartesian_shifts_overload[t] = wp.overload(
        _compute_cartesian_shifts,
        [
            wp.array(dtype=m),  # cell
            wp.array(dtype=wp.vec3i),  # unit_shifts
            wp.array(dtype=wp.int32),  # neighbor_ptr
            wp.array(dtype=wp.int32),  # batch_idx
            wp.array(dtype=v),  # cartesian_shifts (output)
        ],
    )
    _cn_kernel_overload[t] = wp.overload(
        _cn_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=wp.int32),  # numbers
            wp.array(dtype=wp.int32),  # idx_j
            wp.array(dtype=wp.int32),  # neighbor_ptr
            wp.array(dtype=v),  # cartesian_shifts
            wp.array(dtype=wp.float32),  # covalent_radii
            wp.float32,  # k1
            wp.bool,  # periodic
            wp.array(dtype=wp.float32),  # coord_num (output)
        ],
    )
    _direct_forces_and_dE_dCN_kernel_overload[t] = wp.overload(
        _direct_forces_and_dE_dCN_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=wp.int32),  # numbers
            wp.array(dtype=wp.int32),  # idx_j
            wp.array(dtype=wp.int32),  # neighbor_ptr
            wp.array(dtype=v),  # cartesian_shifts
            wp.array(dtype=wp.float32),  # coord_num
            wp.array(dtype=wp.float32),  # r4r2
            wp.array4d(dtype=wp.float32),  # c6_reference
            wp.array4d(dtype=wp.float32),  # coord_num_ref
            wp.float32,  # k3
            wp.float32,  # a1
            wp.float32,  # a2
            wp.float32,  # s6
            wp.float32,  # s8
            wp.float32,  # s5_smoothing_on
            wp.float32,  # s5_smoothing_off
            wp.float32,  # inv_w
            wp.bool,  # periodic
            wp.array(dtype=wp.int32),  # batch_idx
            wp.bool,  # compute_virial
            wp.array(dtype=wp.float32),  # dE_dCN (output)
            wp.array(dtype=wp.vec3f),  # forces (output)
            wp.array(dtype=wp.float32),  # energy (output)
            wp.array(dtype=wp.mat33f),  # virial (output)
        ],
    )
    _cn_forces_contrib_kernel_overload[t] = wp.overload(
        _cn_forces_contrib_kernel,
        [
            wp.array(dtype=v),  # positions
            wp.array(dtype=wp.int32),  # numbers
            wp.array(dtype=wp.int32),  # idx_j
            wp.array(dtype=wp.int32),  # neighbor_ptr
            wp.array(dtype=v),  # cartesian_shifts
            wp.array(dtype=wp.float32),  # covalent_radii
            wp.array(dtype=wp.float32),  # dE_dCN
            wp.float32,  # k1
            wp.bool,  # periodic
            wp.array(dtype=wp.int32),  # batch_idx
            wp.bool,  # compute_virial
            wp.array(dtype=wp.vec3f),  # forces (in/out)
            wp.array(dtype=wp.mat33f),  # virial (output)
        ],
    )
# ==============================================================================
# Warp Launchers (Framework-Agnostic)
# ==============================================================================
[docs]
def dftd3_matrix(
    positions: wp.array,
    numbers: wp.array,
    neighbor_matrix: wp.array,
    covalent_radii: wp.array,
    r4r2: wp.array,
    c6_reference: wp.array,
    coord_num_ref: wp.array,
    a1: float,
    a2: float,
    s8: float,
    coord_num: wp.array,
    forces: wp.array,
    energy: wp.array,
    virial: wp.array,
    batch_idx: wp.array,
    cartesian_shifts: wp.array,
    dE_dCN: wp.array,
    wp_dtype: type,
    device: str,
    k1: float = 16.0,
    k3: float = -4.0,
    s6: float = 1.0,
    s5_smoothing_on: float = 0.0,
    s5_smoothing_off: float = 0.0,
    fill_value: int | None = None,
) -> None:
    """
    Compute DFT-D3(BJ) dispersion for non-periodic systems from a dense neighbor matrix.

    This is a framework-agnostic entry point. Every array argument is already a
    Warp array, and all results are written into caller-owned buffers.
    PyTorch/JAX front ends convert their tensors and delegate here. Periodic
    systems should use :func:`dftd3_matrix_pbc` instead.

    Kernel sequence
    ---------------
    1. **Pass 1**: coordination numbers from the geometric counting function
    2. **Pass 2**: pair energies, direct forces and per-atom dE/dCN
    3. **Pass 3**: chain-rule force term flowing through the coordination numbers

    Parameters
    ----------
    positions : wp.array(dtype=vec3f or vec3d), shape [num_atoms]
        Atomic coordinates in one consistent length unit (usually Bohr), in
        single or double precision.
    numbers : wp.array(dtype=int32), shape [num_atoms]
        Atomic numbers; 0 marks a padding atom.
    neighbor_matrix : wp.array2d(dtype=int32), shape [num_atoms, max_neighbors]
        Neighbor indices per atom; any entry >= fill_value is padding.
    covalent_radii : wp.array(dtype=float32), shape [max_Z+1]
        Covalent radii looked up by atomic number, in position units.
    r4r2 : wp.array(dtype=float32), shape [max_Z+1]
        <r⁴>/<r²> expectation values used to build C8 (dimensionless).
    c6_reference : wp.array4d(dtype=float32), shape [max_Z+1, max_Z+1, 5, 5]
        Reference C6 coefficients (energy x distance^6).
    coord_num_ref : wp.array4d(dtype=float32), shape [max_Z+1, max_Z+1, 5, 5]
        Reference coordination-number grid (dimensionless).
    a1 : float
        Becke-Johnson damping parameter a1 (functional-specific, dimensionless).
    a2 : float
        Becke-Johnson damping parameter a2 (functional-specific), in position units.
    s8 : float
        Scaling of the C8 term (functional-specific, dimensionless).
    coord_num : wp.array(dtype=float32), shape [num_atoms]
        OUTPUT: coordination numbers. Pre-allocate and zero.
    forces : wp.array(dtype=vec3f), shape [num_atoms]
        OUTPUT: forces in energy/distance units. Pre-allocate and zero.
    energy : wp.array(dtype=float32), shape [num_systems]
        OUTPUT: per-system dispersion energy. Pre-allocate and zero.
    virial : wp.array(dtype=mat33f), shape [num_systems]
        Must be allocated, but is never modified here (no virial without PBC).
    batch_idx : wp.array(dtype=int32), shape [num_atoms]
        System index of every atom.
    cartesian_shifts : wp.array(dtype=vec3f or vec3d), shape [num_atoms, max_neighbors]
        SCRATCH: its values are ignored without PBC, but the kernels still take
        it, so it must be allocated with the shape of neighbor_matrix.
    dE_dCN : wp.array(dtype=float32), shape [num_atoms]
        SCRATCH: carries dE/dCN from Pass 2 to Pass 3. Pre-allocate and zero.
    wp_dtype : type
        wp.float32 or wp.float64, matching the precision of positions.
    device : str
        Warp device string, e.g. 'cpu' or 'cuda:0'.
    k1 : float, optional
        Steepness of the CN counting function, in inverse distance units
        (typically 16.0 1/Bohr). Default: 16.0
    k3 : float, optional
        Gaussian width of the CN interpolation (dimensionless). Default: -4.0
    s6 : float, optional
        Scaling of the C6 term (dimensionless). Default: 1.0
    s5_smoothing_on : float, optional
        Distance at which S5 switching starts, in position units. Default: 0.0
    s5_smoothing_off : float, optional
        Distance at which S5 switching ends, in position units. Default: 0.0
        When it is not larger than s5_smoothing_on, inv_w = 0.0 is passed to
        the kernels.
    fill_value : int or None, optional
        Padding sentinel of neighbor_matrix; num_atoms when None.

    Returns
    -------
    None
        Results are written into coord_num, forces and energy.

    Notes
    -----
    - The caller pre-allocates and zeroes every output buffer.
    - float32 and float64 positions are accepted; outputs are always float32.
    - Atoms with numbers[i] == 0 are treated as padding and skipped.
    - **Two-body only**: the three-body Axilrod-Teller-Muto term is not evaluated.
    - Units must be consistent; standard D3 parameters assume atomic units
      (Bohr for distances, Hartree for energy).
    - No virial is produced here; use dftd3_matrix_pbc for periodic systems.

    See Also
    --------
    dftd3_matrix_pbc : Dense neighbor matrix with PBC support
    dftd3 : CSR neighbor list, non-periodic
    dftd3_pbc : CSR neighbor list with PBC support
    """
    n_atoms = positions.shape[0]
    if n_atoms == 0:
        # Empty input: there is nothing to launch.
        return
    pad = n_atoms if fill_value is None else fill_value
    # Reciprocal width of the S5 switching window (0.0 when the window is empty).
    inv_w = (
        1.0 / (s5_smoothing_off - s5_smoothing_on)
        if s5_smoothing_off > s5_smoothing_on
        else 0.0
    )
    # Leading arguments shared by all three kernels, in signature order.
    pair_geometry = [positions, numbers, neighbor_matrix, cartesian_shifts]
    periodic = False

    # Pass 1: coordination numbers.
    wp.launch(
        kernel=_cn_kernel_matrix_overload[wp_dtype],
        dim=n_atoms,
        inputs=[
            *pair_geometry,
            covalent_radii,
            wp.float32(k1),
            wp.int32(pad),
            periodic,
        ],
        outputs=[coord_num],
        device=device,
    )

    # Pass 2: energy, direct forces and dE/dCN (virial disabled without PBC).
    wp.launch(
        kernel=_direct_forces_and_dE_dCN_kernel_matrix_overload[wp_dtype],
        dim=n_atoms,
        inputs=[
            *pair_geometry,
            coord_num,
            r4r2,
            c6_reference,
            coord_num_ref,
            *(
                wp.float32(value)
                for value in (
                    k3,
                    a1,
                    a2,
                    s6,
                    s8,
                    s5_smoothing_on,
                    s5_smoothing_off,
                    inv_w,
                )
            ),
            wp.int32(pad),
            periodic,
            batch_idx,
            False,  # compute_virial
        ],
        outputs=[dE_dCN, forces, energy, virial],
        device=device,
    )

    # Pass 3: add the force term that flows through dE/dCN.
    wp.launch(
        kernel=_cn_forces_contrib_kernel_matrix_overload[wp_dtype],
        dim=n_atoms,
        inputs=[
            *pair_geometry,
            covalent_radii,
            dE_dCN,
            wp.float32(k1),
            wp.int32(pad),
            periodic,
            batch_idx,
            False,  # compute_virial
        ],
        outputs=[forces, virial],
        device=device,
    )
[docs]
def dftd3_matrix_pbc(
    positions: wp.array,
    numbers: wp.array,
    neighbor_matrix: wp.array,
    cell: wp.array,
    neighbor_matrix_shifts: wp.array,
    covalent_radii: wp.array,
    r4r2: wp.array,
    c6_reference: wp.array,
    coord_num_ref: wp.array,
    a1: float,
    a2: float,
    s8: float,
    coord_num: wp.array,
    forces: wp.array,
    energy: wp.array,
    virial: wp.array,
    batch_idx: wp.array,
    cartesian_shifts: wp.array,
    dE_dCN: wp.array,
    wp_dtype: type,
    device: str,
    k1: float = 16.0,
    k3: float = -4.0,
    s6: float = 1.0,
    s5_smoothing_on: float = 0.0,
    s5_smoothing_off: float = 0.0,
    fill_value: int | None = None,
    compute_virial: bool = False,
) -> None:
    """
    Compute DFT-D3(BJ) dispersion for periodic systems from a dense neighbor matrix.

    This is a framework-agnostic entry point. Every array argument is already a
    Warp array, and energy, forces, coordination numbers and (optionally) the
    virial are written into caller-owned buffers. PyTorch/JAX front ends convert
    their tensors and delegate here. Non-periodic systems should use
    :func:`dftd3_matrix` instead.

    Kernel sequence
    ---------------
    1. **Pass 0**: integer image shifts -> Cartesian shift vectors
    2. **Pass 1**: coordination numbers from the geometric counting function
    3. **Pass 2**: pair energies, direct forces and per-atom dE/dCN
    4. **Pass 3**: chain-rule force term flowing through the coordination numbers

    Parameters
    ----------
    positions : wp.array(dtype=vec3f or vec3d), shape [num_atoms]
        Atomic coordinates in one consistent length unit (usually Bohr), in
        single or double precision.
    numbers : wp.array(dtype=int32), shape [num_atoms]
        Atomic numbers; 0 marks a padding atom.
    neighbor_matrix : wp.array2d(dtype=int32), shape [num_atoms, max_neighbors]
        Neighbor indices per atom; any entry >= fill_value is padding.
    cell : wp.array(dtype=mat33f or mat33d), shape [num_systems]
        Lattice vectors, with the same dtype/units as positions. Row
        convention: cell[s, i, :] is lattice vector i of system s.
    neighbor_matrix_shifts : wp.array2d(dtype=vec3i), shape [num_atoms, max_neighbors]
        Integer image shifts; shift[i, k] belongs to the k-th neighbor of atom i.
    covalent_radii : wp.array(dtype=float32), shape [max_Z+1]
        Covalent radii looked up by atomic number, in position units.
    r4r2 : wp.array(dtype=float32), shape [max_Z+1]
        <r⁴>/<r²> expectation values used to build C8 (dimensionless).
    c6_reference : wp.array4d(dtype=float32), shape [max_Z+1, max_Z+1, 5, 5]
        Reference C6 coefficients (energy x distance^6).
    coord_num_ref : wp.array4d(dtype=float32), shape [max_Z+1, max_Z+1, 5, 5]
        Reference coordination-number grid (dimensionless).
    a1 : float
        Becke-Johnson damping parameter a1 (functional-specific, dimensionless).
    a2 : float
        Becke-Johnson damping parameter a2 (functional-specific), in position units.
    s8 : float
        Scaling of the C8 term (functional-specific, dimensionless).
    coord_num : wp.array(dtype=float32), shape [num_atoms]
        OUTPUT: coordination numbers. Pre-allocate and zero.
    forces : wp.array(dtype=vec3f), shape [num_atoms]
        OUTPUT: forces in energy/distance units. Pre-allocate and zero.
    energy : wp.array(dtype=float32), shape [num_systems]
        OUTPUT: per-system dispersion energy. Pre-allocate and zero.
    virial : wp.array(dtype=mat33f), shape [num_systems]
        OUTPUT: virial tensor in energy units. Pre-allocate and zero. It is
        only filled when compute_virial=True.
    batch_idx : wp.array(dtype=int32), shape [num_atoms]
        System index of every atom.
    cartesian_shifts : wp.array(dtype=vec3f or vec3d), shape [num_atoms, max_neighbors]
        SCRATCH: filled by Pass 0 from the integer shifts. Allocate it with the
        shape of neighbor_matrix.
    dE_dCN : wp.array(dtype=float32), shape [num_atoms]
        SCRATCH: carries dE/dCN from Pass 2 to Pass 3. Pre-allocate and zero.
    wp_dtype : type
        wp.float32 or wp.float64, matching the precision of positions.
    device : str
        Warp device string, e.g. 'cpu' or 'cuda:0'.
    k1 : float, optional
        Steepness of the CN counting function, in inverse distance units
        (typically 16.0 1/Bohr). Default: 16.0
    k3 : float, optional
        Gaussian width of the CN interpolation (dimensionless). Default: -4.0
    s6 : float, optional
        Scaling of the C6 term (dimensionless). Default: 1.0
    s5_smoothing_on : float, optional
        Distance at which S5 switching starts, in position units. Default: 0.0
    s5_smoothing_off : float, optional
        Distance at which S5 switching ends, in position units. Default: 0.0
        When it is not larger than s5_smoothing_on, inv_w = 0.0 is passed to
        the kernels.
    fill_value : int or None, optional
        Padding sentinel of neighbor_matrix; num_atoms when None.
    compute_virial : bool, optional
        Whether to accumulate the virial tensor. Default: False

    Returns
    -------
    None
        Results are written into coord_num, forces, energy and virial.

    Notes
    -----
    - The caller pre-allocates and zeroes every output buffer.
    - float32 and float64 positions/cell are accepted; outputs are always float32.
    - Atoms with numbers[i] == 0 are treated as padding and skipped.
    - **Two-body only**: the three-body Axilrod-Teller-Muto term is not evaluated.
    - Units must be consistent; standard D3 parameters assume atomic units
      (Bohr for distances, Hartree for energy).
    - The virial is only accumulated when compute_virial=True.

    See Also
    --------
    dftd3_matrix : Dense neighbor matrix, non-periodic
    dftd3 : CSR neighbor list, non-periodic
    dftd3_pbc : CSR neighbor list with PBC support
    """
    n_atoms = positions.shape[0]
    if n_atoms == 0:
        # Empty input: there is nothing to launch.
        return
    n_slots = neighbor_matrix.shape[1]
    pad = n_atoms if fill_value is None else fill_value
    # Reciprocal width of the S5 switching window (0.0 when the window is empty).
    inv_w = (
        1.0 / (s5_smoothing_off - s5_smoothing_on)
        if s5_smoothing_off > s5_smoothing_on
        else 0.0
    )
    # Leading arguments shared by Passes 1-3, in signature order.
    # cartesian_shifts is the buffer that Pass 0 fills.
    pair_geometry = [positions, numbers, neighbor_matrix, cartesian_shifts]
    periodic = True

    # Pass 0: one thread per (atom, neighbor slot) converts image shifts.
    wp.launch(
        kernel=_compute_cartesian_shifts_matrix_overload[wp_dtype],
        dim=(n_atoms, n_slots),
        inputs=[
            cell,
            neighbor_matrix_shifts,
            neighbor_matrix,
            batch_idx,
            wp.int32(pad),
        ],
        outputs=[cartesian_shifts],
        device=device,
    )

    # Pass 1: coordination numbers.
    wp.launch(
        kernel=_cn_kernel_matrix_overload[wp_dtype],
        dim=n_atoms,
        inputs=[
            *pair_geometry,
            covalent_radii,
            wp.float32(k1),
            wp.int32(pad),
            periodic,
        ],
        outputs=[coord_num],
        device=device,
    )

    # Pass 2: energy, direct forces, dE/dCN and (optionally) direct virial.
    wp.launch(
        kernel=_direct_forces_and_dE_dCN_kernel_matrix_overload[wp_dtype],
        dim=n_atoms,
        inputs=[
            *pair_geometry,
            coord_num,
            r4r2,
            c6_reference,
            coord_num_ref,
            *(
                wp.float32(value)
                for value in (
                    k3,
                    a1,
                    a2,
                    s6,
                    s8,
                    s5_smoothing_on,
                    s5_smoothing_off,
                    inv_w,
                )
            ),
            wp.int32(pad),
            periodic,
            batch_idx,
            compute_virial,
        ],
        outputs=[dE_dCN, forces, energy, virial],
        device=device,
    )

    # Pass 3: add the force (and virial) term that flows through dE/dCN.
    wp.launch(
        kernel=_cn_forces_contrib_kernel_matrix_overload[wp_dtype],
        dim=n_atoms,
        inputs=[
            *pair_geometry,
            covalent_radii,
            dE_dCN,
            wp.float32(k1),
            wp.int32(pad),
            periodic,
            batch_idx,
            compute_virial,
        ],
        outputs=[forces, virial],
        device=device,
    )
[docs]
def dftd3(
    positions: wp.array,
    numbers: wp.array,
    idx_j: wp.array,
    neighbor_ptr: wp.array,
    covalent_radii: wp.array,
    r4r2: wp.array,
    c6_reference: wp.array,
    coord_num_ref: wp.array,
    a1: float,
    a2: float,
    s8: float,
    coord_num: wp.array,
    forces: wp.array,
    energy: wp.array,
    virial: wp.array,
    batch_idx: wp.array,
    cartesian_shifts: wp.array,
    dE_dCN: wp.array,
    wp_dtype: type,
    device: str,
    k1: float = 16.0,
    k3: float = -4.0,
    s6: float = 1.0,
    s5_smoothing_on: float = 0.0,
    s5_smoothing_off: float = 0.0,
) -> None:
    """
    Run DFT-D3(BJ) for non-periodic systems described by a CSR neighbor list.

    Framework-agnostic launcher: it receives warp arrays unchanged and issues
    the kernel passes that produce coordination numbers, dispersion energy and
    forces. Framework wrappers (PyTorch, JAX) convert their tensors to warp
    arrays and then call this function. Periodic systems are handled by
    :func:`dftd3_pbc`.

    Multi-Pass Algorithm
    --------------------
    1. **Pass 1**: coordination numbers from the geometric counting function.
    2. **Pass 2**: direct forces and energy, accumulating dE/dCN on the way.
    3. **Pass 3**: CN-dependent force term via the chain rule, reusing the
       dE/dCN values from pass 2.

    Parameters
    ----------
    positions : wp.array(dtype=vec3f or vec3d), shape [num_atoms]
        Atomic coordinates in consistent distance units (typically Bohr),
        in float32 (vec3f) or float64 (vec3d) precision.
    numbers : wp.array(dtype=int32), shape [num_atoms]
        Atomic numbers.
    idx_j : wp.array(dtype=int32), shape [num_edges]
        Neighbor (destination) atom index for every CSR edge.
    neighbor_ptr : wp.array(dtype=int32), shape [num_atoms+1]
        CSR row offsets; edges neighbor_ptr[i]:neighbor_ptr[i+1] belong to
        atom i.
    covalent_radii : wp.array(dtype=float32), shape [max_Z+1]
        Covalent radius per atomic number, in the units of ``positions``.
    r4r2 : wp.array(dtype=float32), shape [max_Z+1]
        Dimensionless <r⁴>/<r²> expectation values used to build C8.
    c6_reference : wp.array4d(dtype=float32), shape [max_Z+1, max_Z+1, 5, 5]
        Reference C6 coefficients in energy x distance^6 units.
    coord_num_ref : wp.array4d(dtype=float32), shape [max_Z+1, max_Z+1, 5, 5]
        Dimensionless reference coordination-number grid.
    a1 : float
        Becke-Johnson damping parameter a1 (functional-specific,
        dimensionless).
    a2 : float
        Becke-Johnson damping parameter a2 (functional-specific), in the units
        of ``positions``.
    s8 : float
        Scaling of the C8 term (functional-specific, dimensionless).
    coord_num : wp.array(dtype=float32), shape [num_atoms]
        OUTPUT. Coordination numbers; must be allocated and zeroed by the
        caller.
    forces : wp.array(dtype=vec3f), shape [num_atoms]
        OUTPUT. Atomic forces in energy/distance units; must be allocated and
        zeroed by the caller.
    energy : wp.array(dtype=float32), shape [num_systems]
        OUTPUT. Per-system dispersion energy; must be allocated and zeroed by
        the caller.
    virial : wp.array(dtype=mat33f), shape [num_systems]
        OUTPUT placeholder. It must be allocated because the kernels take it,
        but this launcher always disables virial accumulation, so it is not
        modified.
    batch_idx : wp.array(dtype=int32), shape [num_atoms]
        System index of every atom.
    cartesian_shifts : wp.array(dtype=vec3f or vec3d), shape [num_edges]
        SCRATCH. Must be allocated by the caller with the same length as
        ``idx_j``. Its values are not used without PBC, but the kernels still
        take it as an argument.
    dE_dCN : wp.array(dtype=float32), shape [num_atoms]
        SCRATCH. Chain-rule intermediate dE/dCN; must be allocated and zeroed
        by the caller.
    wp_dtype : type
        Warp scalar type (wp.float32 or wp.float64) matching ``positions``;
        selects the kernel overloads.
    device : str
        Warp device string, e.g. 'cuda:0' or 'cpu'.
    k1 : float, optional
        Steepness of the CN counting function, in inverse distance units
        (typically 16.0 1/Bohr). Default: 16.0
    k3 : float, optional
        Dimensionless Gaussian width used for CN interpolation (typically
        -4.0). Default: -4.0
    s6 : float, optional
        Dimensionless scaling of the C6 term (typically 1.0). Default: 1.0
    s5_smoothing_on : float, optional
        Distance at which S5 switching starts, in the units of ``positions``.
        Default: 0.0
    s5_smoothing_off : float, optional
        Distance at which S5 switching ends, in the units of ``positions``.
        Default: 0.0

    Returns
    -------
    None
        Results are written into the pre-allocated ``coord_num``, ``forces``
        and ``energy`` arrays. No virial is produced for non-periodic systems.

    Notes
    -----
    - Nothing is launched when there are no atoms or no neighbor edges.
    - Every output array must be allocated and zeroed by the caller.
    - Positions may be float32 or float64; outputs are always float32.
    - Padding atoms (numbers[i] == 0) are skipped.
    - **Two-body only**: the three-body Axilrod-Teller-Muto term is not
      included.
    - Units must be consistent. Standard D3 parameters assume atomic units
      (Bohr for distances, Hartree for energy).
    - The CSR layout is more memory-efficient for sparse neighbor lists.
    - The virial is NOT computed here; use dftd3_pbc for periodic systems.

    See Also
    --------
    dftd3_pbc : Neighbor list (CSR) format with PBC support
    dftd3_matrix : Neighbor matrix format, non-periodic
    dftd3_matrix_pbc : Neighbor matrix format with PBC support
    """
    n_atoms = positions.shape[0]
    n_edges = idx_j.shape[0]

    # No atoms or no neighbor pairs: there is nothing to launch.
    if not n_atoms or not n_edges:
        return

    f32 = wp.float32
    periodic = False
    compute_virial = False  # non-periodic: virial accumulation stays off

    # Reciprocal width of the S5 switching window; 0.0 when the window is
    # empty or inverted.
    inv_w = (
        1.0 / (s5_smoothing_off - s5_smoothing_on)
        if s5_smoothing_off > s5_smoothing_on
        else 0.0
    )

    # Leading arguments shared by all three CSR kernels, in kernel order.
    edge_args = [positions, numbers, idx_j, neighbor_ptr, cartesian_shifts]

    # Pass 1: coordination numbers.
    wp.launch(
        kernel=_cn_kernel_overload[wp_dtype],
        dim=n_atoms,
        inputs=[*edge_args, covalent_radii, f32(k1), periodic],
        outputs=[coord_num],
        device=device,
    )

    # Pass 2: direct forces and energy, also accumulating dE/dCN.
    # The scalar order below must match the kernel signature exactly.
    bj_scalars = [
        f32(value)
        for value in (
            k3,
            a1,
            a2,
            s6,
            s8,
            s5_smoothing_on,
            s5_smoothing_off,
            inv_w,
        )
    ]
    wp.launch(
        kernel=_direct_forces_and_dE_dCN_kernel_overload[wp_dtype],
        dim=n_atoms,
        inputs=[
            *edge_args,
            coord_num,
            r4r2,
            c6_reference,
            coord_num_ref,
            *bj_scalars,
            periodic,
            batch_idx,
            compute_virial,
        ],
        outputs=[dE_dCN, forces, energy, virial],
        device=device,
    )

    # Pass 3: chain-rule force contribution through the coordination numbers.
    wp.launch(
        kernel=_cn_forces_contrib_kernel_overload[wp_dtype],
        dim=n_atoms,
        inputs=[
            *edge_args,
            covalent_radii,
            dE_dCN,
            f32(k1),
            periodic,
            batch_idx,
            compute_virial,
        ],
        outputs=[forces, virial],
        device=device,
    )
[docs]
def dftd3_pbc(
    positions: wp.array,
    numbers: wp.array,
    idx_j: wp.array,
    neighbor_ptr: wp.array,
    cell: wp.array,
    unit_shifts: wp.array,
    covalent_radii: wp.array,
    r4r2: wp.array,
    c6_reference: wp.array,
    coord_num_ref: wp.array,
    a1: float,
    a2: float,
    s8: float,
    coord_num: wp.array,
    forces: wp.array,
    energy: wp.array,
    virial: wp.array,
    batch_idx: wp.array,
    cartesian_shifts: wp.array,
    dE_dCN: wp.array,
    wp_dtype: type,
    device: str,
    k1: float = 16.0,
    k3: float = -4.0,
    s6: float = 1.0,
    s5_smoothing_on: float = 0.0,
    s5_smoothing_off: float = 0.0,
    compute_virial: bool = False,
) -> None:
    """
    Run DFT-D3(BJ) for periodic systems described by a CSR neighbor list.

    Framework-agnostic launcher for periodic boundary conditions (PBC). It
    receives warp arrays unchanged and issues the kernel passes that produce
    coordination numbers, dispersion energy, forces and (optionally) the
    virial. Framework wrappers (PyTorch, JAX) convert their tensors to warp
    arrays and then call this function. Non-periodic systems are handled by
    :func:`dftd3`.

    Multi-Pass Algorithm
    --------------------
    1. **Pass 0**: convert integer unit-cell shifts into Cartesian shifts.
    2. **Pass 1**: coordination numbers from the geometric counting function.
    3. **Pass 2**: direct forces and energy, accumulating dE/dCN on the way.
    4. **Pass 3**: CN-dependent force term via the chain rule, reusing the
       dE/dCN values from pass 2.

    Parameters
    ----------
    positions : wp.array(dtype=vec3f or vec3d), shape [num_atoms]
        Atomic coordinates in consistent distance units (typically Bohr),
        in float32 (vec3f) or float64 (vec3d) precision.
    numbers : wp.array(dtype=int32), shape [num_atoms]
        Atomic numbers.
    idx_j : wp.array(dtype=int32), shape [num_edges]
        Neighbor (destination) atom index for every CSR edge.
    neighbor_ptr : wp.array(dtype=int32), shape [num_atoms+1]
        CSR row offsets; edges neighbor_ptr[i]:neighbor_ptr[i+1] belong to
        atom i.
    cell : wp.array(dtype=mat33f or mat33d), shape [num_systems]
        Lattice vectors per system, in the same dtype and units as
        ``positions``. Row convention: cell[s, i, :] is the i-th lattice
        vector of system s.
    unit_shifts : wp.array(dtype=vec3i), shape [num_edges]
        Integer unit-cell shift per edge; unit_shifts[e] belongs to edge e.
    covalent_radii : wp.array(dtype=float32), shape [max_Z+1]
        Covalent radius per atomic number, in the units of ``positions``.
    r4r2 : wp.array(dtype=float32), shape [max_Z+1]
        Dimensionless <r⁴>/<r²> expectation values used to build C8.
    c6_reference : wp.array4d(dtype=float32), shape [max_Z+1, max_Z+1, 5, 5]
        Reference C6 coefficients in energy x distance^6 units.
    coord_num_ref : wp.array4d(dtype=float32), shape [max_Z+1, max_Z+1, 5, 5]
        Dimensionless reference coordination-number grid.
    a1 : float
        Becke-Johnson damping parameter a1 (functional-specific,
        dimensionless).
    a2 : float
        Becke-Johnson damping parameter a2 (functional-specific), in the units
        of ``positions``.
    s8 : float
        Scaling of the C8 term (functional-specific, dimensionless).
    coord_num : wp.array(dtype=float32), shape [num_atoms]
        OUTPUT. Coordination numbers; must be allocated and zeroed by the
        caller.
    forces : wp.array(dtype=vec3f), shape [num_atoms]
        OUTPUT. Atomic forces in energy/distance units; must be allocated and
        zeroed by the caller.
    energy : wp.array(dtype=float32), shape [num_systems]
        OUTPUT. Per-system dispersion energy; must be allocated and zeroed by
        the caller.
    virial : wp.array(dtype=mat33f), shape [num_systems]
        OUTPUT. Virial tensor in energy units; must be allocated and zeroed by
        the caller. Only filled in when ``compute_virial`` is True.
    batch_idx : wp.array(dtype=int32), shape [num_atoms]
        System index of every atom.
    cartesian_shifts : wp.array(dtype=vec3f or vec3d), shape [num_edges]
        SCRATCH. Filled by pass 0 from ``unit_shifts``; must be allocated by
        the caller with the same length as ``idx_j``.
    dE_dCN : wp.array(dtype=float32), shape [num_atoms]
        SCRATCH. Chain-rule intermediate dE/dCN; must be allocated and zeroed
        by the caller.
    wp_dtype : type
        Warp scalar type (wp.float32 or wp.float64) matching ``positions``;
        selects the kernel overloads.
    device : str
        Warp device string, e.g. 'cuda:0' or 'cpu'.
    k1 : float, optional
        Steepness of the CN counting function, in inverse distance units
        (typically 16.0 1/Bohr). Default: 16.0
    k3 : float, optional
        Dimensionless Gaussian width used for CN interpolation (typically
        -4.0). Default: -4.0
    s6 : float, optional
        Dimensionless scaling of the C6 term (typically 1.0). Default: 1.0
    s5_smoothing_on : float, optional
        Distance at which S5 switching starts, in the units of ``positions``.
        Default: 0.0
    s5_smoothing_off : float, optional
        Distance at which S5 switching ends, in the units of ``positions``.
        Default: 0.0
    compute_virial : bool, optional
        Whether to accumulate the virial tensor. Default: False

    Returns
    -------
    None
        Results are written into the pre-allocated ``coord_num``, ``forces``,
        ``energy`` and ``virial`` arrays.

    Notes
    -----
    - Nothing is launched when there are no atoms or no neighbor edges.
    - Every output array must be allocated and zeroed by the caller.
    - Positions and cell may be float32 or float64; outputs are always
      float32.
    - Padding atoms (numbers[i] == 0) are skipped.
    - **Two-body only**: the three-body Axilrod-Teller-Muto term is not
      included.
    - Units must be consistent. Standard D3 parameters assume atomic units
      (Bohr for distances, Hartree for energy).
    - The virial tensor is computed only when compute_virial=True.

    See Also
    --------
    dftd3 : Neighbor list (CSR) format, non-periodic
    dftd3_matrix : Neighbor matrix format, non-periodic
    dftd3_matrix_pbc : Neighbor matrix format with PBC support
    """
    n_atoms = positions.shape[0]
    n_edges = idx_j.shape[0]

    # No atoms or no neighbor pairs: there is nothing to launch.
    if not n_atoms or not n_edges:
        return

    f32 = wp.float32
    periodic = True

    # Reciprocal width of the S5 switching window; 0.0 when the window is
    # empty or inverted.
    inv_w = (
        1.0 / (s5_smoothing_off - s5_smoothing_on)
        if s5_smoothing_off > s5_smoothing_on
        else 0.0
    )

    # Pass 0: integer unit-cell shifts -> Cartesian shift vectors, written
    # into the cartesian_shifts scratch buffer used by every later pass.
    wp.launch(
        kernel=_compute_cartesian_shifts_overload[wp_dtype],
        dim=n_atoms,
        inputs=[cell, unit_shifts, neighbor_ptr, batch_idx],
        outputs=[cartesian_shifts],
        device=device,
    )

    # Leading arguments shared by the three CSR kernels, in kernel order.
    edge_args = [positions, numbers, idx_j, neighbor_ptr, cartesian_shifts]

    # Pass 1: coordination numbers.
    wp.launch(
        kernel=_cn_kernel_overload[wp_dtype],
        dim=n_atoms,
        inputs=[*edge_args, covalent_radii, f32(k1), periodic],
        outputs=[coord_num],
        device=device,
    )

    # Pass 2: direct forces, energy and (optionally) virial, also
    # accumulating dE/dCN. The scalar order must match the kernel signature.
    bj_scalars = [
        f32(value)
        for value in (
            k3,
            a1,
            a2,
            s6,
            s8,
            s5_smoothing_on,
            s5_smoothing_off,
            inv_w,
        )
    ]
    wp.launch(
        kernel=_direct_forces_and_dE_dCN_kernel_overload[wp_dtype],
        dim=n_atoms,
        inputs=[
            *edge_args,
            coord_num,
            r4r2,
            c6_reference,
            coord_num_ref,
            *bj_scalars,
            periodic,
            batch_idx,
            compute_virial,
        ],
        outputs=[dE_dCN, forces, energy, virial],
        device=device,
    )

    # Pass 3: chain-rule force contribution through the coordination numbers.
    wp.launch(
        kernel=_cn_forces_contrib_kernel_overload[wp_dtype],
        dim=n_atoms,
        inputs=[
            *edge_args,
            covalent_radii,
            dE_dCN,
            f32(k1),
            periodic,
            batch_idx,
            compute_virial,
        ],
        outputs=[forces, virial],
        device=device,
    )