# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
import torch
import warp as wp
from nvalchemiops.interactions.dispersion._dftd3 import (
dftd3 as wp_dftd3,
)
from nvalchemiops.interactions.dispersion._dftd3 import (
dftd3_matrix as wp_dftd3_matrix,
)
from nvalchemiops.interactions.dispersion._dftd3 import (
dftd3_matrix_pbc as wp_dftd3_matrix_pbc,
)
from nvalchemiops.interactions.dispersion._dftd3 import (
dftd3_pbc as wp_dftd3_pbc,
)
from nvalchemiops.torch.types import get_wp_dtype, get_wp_mat_dtype, get_wp_vec_dtype
__all__ = [
"D3Parameters",
"dftd3",
]
@dataclass
class D3Parameters:
    """
    DFT-D3 reference parameters for dispersion correction calculations.

    This dataclass encapsulates all element-specific parameters required for
    DFT-D3 dispersion corrections. The main purpose for this structure is to
    provide validation, ensuring the correct shapes, dtypes, and keys are
    present and complete. These parameters are used by :func:`dftd3`.

    Parameters
    ----------
    rcov : torch.Tensor
        Covalent radii [max_Z+1] as float32 or float64. Units should be
        consistent with position coordinates. Index 0 is reserved for
        padding; valid atomic numbers are 1 to max_Z.
    r4r2 : torch.Tensor
        <r⁴>/<r²> expectation values [max_Z+1] as float32 or float64.
        Dimensionless ratio used for computing C8 coefficients from C6 values.
    c6ab : torch.Tensor
        C6 reference coefficients [max_Z+1, max_Z+1, interp_mesh, interp_mesh]
        as float32 or float64. Units are energy x distance^6. Indexed by
        atomic numbers and coordination number reference indices.
    cn_ref : torch.Tensor
        Coordination number reference grid
        [max_Z+1, max_Z+1, interp_mesh, interp_mesh] as float32 or float64.
        Dimensionless CN values for Gaussian interpolation.
    interp_mesh : int, optional
        Size of the coordination number interpolation mesh. Must be at
        least 1. Default: 5 (standard DFT-D3 uses a 5x5 grid)

    Raises
    ------
    ValueError
        If parameter shapes are inconsistent or invalid, if ``interp_mesh``
        is smaller than 1, or if the tensors live on different devices
    TypeError
        If parameters are not torch.Tensor or have invalid dtypes

    Notes
    -----
    - Parameters should use consistent units matching your coordinate system.
      Standard D3 parameters from the Grimme group use atomic units (Bohr for
      distances, Hartree x Bohr^6 for C6 coefficients).
    - Index 0 in all arrays is reserved for padding atoms (atomic number 0)
    - Valid atomic numbers range from 1 to max_z
    - The standard DFT-D3 implementation supports elements 1-94 (H to Pu)
    - Parameters can be float32 or float64; they will be converted to float32
      during computation for efficiency

    Examples
    --------
    Create parameters from individual tensors:

    >>> params = D3Parameters(
    ...     rcov=torch.rand(95),  # 94 elements + padding
    ...     r4r2=torch.rand(95),
    ...     c6ab=torch.rand(95, 95, 5, 5),
    ...     cn_ref=torch.rand(95, 95, 5, 5),
    ... )

    Create from a dictionary (e.g., loaded from file):

    >>> state_dict = torch.load("dftd3_parameters.pt")
    >>> params = D3Parameters(
    ...     rcov=state_dict["rcov"],
    ...     r4r2=state_dict["r4r2"],
    ...     c6ab=state_dict["c6ab"],
    ...     cn_ref=state_dict["cn_ref"],
    ... )
    """

    rcov: torch.Tensor
    r4r2: torch.Tensor
    c6ab: torch.Tensor
    cn_ref: torch.Tensor
    interp_mesh: int = 5

    def __post_init__(self) -> None:
        """Validate parameter types, dtypes, shapes, and device placement."""
        tensor_fields = (
            ("rcov", self.rcov),
            ("r4r2", self.r4r2),
            ("c6ab", self.c6ab),
            ("cn_ref", self.cn_ref),
        )

        # Type validation: every table must be a floating-point torch tensor.
        for name, tensor in tensor_fields:
            if not isinstance(tensor, torch.Tensor):
                raise TypeError(
                    f"Parameter '{name}' must be a torch.Tensor, got {type(tensor)}"
                )
            if tensor.dtype not in (torch.float32, torch.float64):
                raise TypeError(
                    f"Parameter '{name}' must be float32 or float64, got {tensor.dtype}"
                )

        # Shape validation: rcov defines max_Z, everything else must match it.
        if self.rcov.ndim != 1:
            raise ValueError(
                f"rcov must be 1D tensor [max_Z+1], got shape {self.rcov.shape}"
            )
        max_z = self.rcov.size(0) - 1
        if max_z < 1:
            raise ValueError(
                f"rcov must have at least 2 elements (padding + 1 element), got {self.rcov.size(0)}"
            )
        if self.r4r2.shape != (max_z + 1,):
            raise ValueError(
                f"r4r2 must have shape [{max_z + 1}] to match rcov, got {self.r4r2.shape}"
            )

        # A mesh smaller than 1 would let empty (Z, Z, 0, 0) tables slip
        # through the shape comparison below, so reject it explicitly.
        if self.interp_mesh < 1:
            raise ValueError(
                f"interp_mesh must be a positive integer, got {self.interp_mesh}"
            )

        # c6ab and cn_ref share the same expected grid shape.
        expected_grid_shape = (
            max_z + 1,
            max_z + 1,
            self.interp_mesh,
            self.interp_mesh,
        )
        if self.c6ab.shape != expected_grid_shape:
            raise ValueError(
                f"c6ab must have shape {expected_grid_shape}, got {self.c6ab.shape}"
            )
        if self.cn_ref.shape != expected_grid_shape:
            raise ValueError(
                f"cn_ref must have shape {expected_grid_shape}, got {self.cn_ref.shape}"
            )

        # Device consistency validation (compared via their string form).
        distinct_devices = {str(tensor.device) for _, tensor in tensor_fields}
        if len(distinct_devices) > 1:
            raise ValueError(
                f"All parameters must be on the same device. "
                f"Got devices: rcov={self.rcov.device}, r4r2={self.r4r2.device}, "
                f"c6ab={self.c6ab.device}, cn_ref={self.cn_ref.device}"
            )

    @property
    def max_z(self) -> int:
        """Maximum atomic number supported by these parameters."""
        return self.rcov.size(0) - 1

    @property
    def device(self) -> torch.device:
        """Device where parameters are stored."""
        return self.rcov.device

    def to(
        self,
        device: str | torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> D3Parameters:
        """
        Move all parameters to the specified device and/or convert to specified dtype.

        Parameters
        ----------
        device : str or torch.device or None, optional
            Target device (e.g., 'cpu', 'cuda', 'cuda:0'). If None, keeps
            current device.
        dtype : torch.dtype or None, optional
            Target dtype (e.g., torch.float32, torch.float64). If None, keeps
            current dtype.

        Returns
        -------
        D3Parameters
            New instance with parameters on the target device and/or dtype.
            The new instance is re-validated by ``__post_init__``.

        Examples
        --------
        Move to GPU:

        >>> params_gpu = params.to(device='cuda')

        Convert to float32:

        >>> params_f32 = params.to(dtype=torch.float32)

        Move to GPU and convert to float32:

        >>> params_gpu_f32 = params.to(device='cuda', dtype=torch.float32)
        """
        return D3Parameters(
            rcov=self.rcov.to(device=device, dtype=dtype),
            r4r2=self.r4r2.to(device=device, dtype=dtype),
            c6ab=self.c6ab.to(device=device, dtype=dtype),
            cn_ref=self.cn_ref.to(device=device, dtype=dtype),
            interp_mesh=self.interp_mesh,
        )
# ==============================================================================
# PyTorch Wrapper
# ==============================================================================
@torch.library.custom_op(
"nvalchemiops::dftd3_matrix",
mutates_args=("energy", "forces", "coord_num", "virial"),
)
def _dftd3_matrix_op(
positions: torch.Tensor,
numbers: torch.Tensor,
neighbor_matrix: torch.Tensor,
covalent_radii: torch.Tensor,
r4r2: torch.Tensor,
c6_reference: torch.Tensor,
coord_num_ref: torch.Tensor,
a1: float,
a2: float,
s8: float,
energy: torch.Tensor,
forces: torch.Tensor,
coord_num: torch.Tensor,
virial: torch.Tensor,
k1: float = 16.0,
k3: float = -4.0,
s6: float = 1.0,
s5_smoothing_on: float = 1e10,
s5_smoothing_off: float = 1e10,
fill_value: int | None = None,
batch_idx: torch.Tensor | None = None,
device: str | None = None,
) -> None:
"""Internal custom op for DFT-D3(BJ) dispersion energy and forces
computation (non-PBC, neighbor matrix format).
This is a low-level custom operator that performs DFT-D3(BJ) dispersion
calculations using Warp kernels for non-periodic systems with neighbor matrix format.
Output tensors must be pre-allocated by the caller and are modified in-place.
For most use cases, prefer the higher-level :func:`dftd3` wrapper function
instead of calling this method directly.
This function is torch.compile compatible.
Parameters
----------
positions : torch.Tensor, shape (num_atoms, 3)
Atomic coordinates as float32 or float64, in consistent distance units
(conventionally Bohr)
numbers : torch.Tensor, shape (num_atoms), dtype=int32
Atomic numbers
neighbor_matrix : torch.Tensor, shape (num_atoms, max_neighbors), dtype=int32
Neighbor indices. See module docstring for format details.
Padding entries have values >= fill_value.
covalent_radii : torch.Tensor, shape (max_Z+1), dtype=float32
Covalent radii indexed by atomic number, in same units as positions
r4r2 : torch.Tensor, shape (max_Z+1), dtype=float32
<r⁴>/<r²> expectation values for C8 computation (dimensionless)
c6_reference : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32
C6 reference values in energy x distance^6 units
coord_num_ref : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32
CN reference grid (dimensionless)
a1 : float
Becke-Johnson damping parameter 1 (functional-dependent, dimensionless)
a2 : float
Becke-Johnson damping parameter 2 (functional-dependent), in same units as positions
s8 : float
C8 term scaling factor (functional-dependent, dimensionless)
energy : torch.Tensor, shape (num_systems,), dtype=float32
OUTPUT: Total dispersion energy. Must be pre-allocated. Units are energy
(Hartree when using standard D3 parameters).
forces : torch.Tensor, shape (num_atoms, 3), dtype=float32
OUTPUT: Atomic forces. Must be pre-allocated. Units are energy/distance
(Hartree/Bohr when using standard D3 parameters).
coord_num : torch.Tensor, shape (num_atoms,), dtype=float32
OUTPUT: Coordination numbers (dimensionless). Must be pre-allocated.
virial : torch.Tensor, shape (num_systems, 3, 3), dtype=float32
OUTPUT: Virial tensor (remains zeros for non-PBC). Must be pre-allocated.
k1 : float, optional
CN counting function steepness parameter, in inverse distance units
(typically 16.0 1/Bohr for atomic units)
k3 : float, optional
CN interpolation Gaussian width parameter (typically -4.0, dimensionless)
s6 : float, optional
C6 term scaling factor (typically 1.0, dimensionless)
s5_smoothing_on : float, optional
Distance where S5 switching begins, in same units as positions. Default: 1e10
s5_smoothing_off : float, optional
Distance where S5 switching completes, in same units as positions. Default: 1e10
fill_value : int | None, optional
Value indicating padding in neighbor_matrix. If None, defaults to num_atoms.
batch_idx : torch.Tensor, shape (num_atoms,), dtype=int32, optional
Batch indices. If None, all atoms are in a single system (batch 0).
device : str, optional
Warp device string (e.g., 'cuda:0', 'cpu'). If None, inferred from positions.
Returns
-------
None
Modifies input tensors in-place: energy, forces, coord_num, virial (remains zeros)
Notes
-----
- All input tensors should use consistent units. Standard D3 parameters use
atomic units (Bohr for distances, Hartree for energy).
- Float32 or float64 precision for positions; outputs always float32
- Padding atoms indicated by numbers[i] == 0
- **Two-body only**: Computes pairwise C6 and C8 dispersion terms; three-body
Axilrod-Teller-Muto (ATM/C9) terms are not included
- For PBC calculations, use :func:`_dftd3_matrix_pbc_op` instead
See Also
--------
:func:`dftd3` : Higher-level wrapper that handles allocation
:func:`_dftd3_matrix_pbc_op` : PBC variant with neighbor matrix format
"""
# Ensure all parameters are on correct device/dtype
covalent_radii = covalent_radii.to(device=positions.device, dtype=torch.float32)
r4r2 = r4r2.to(device=positions.device, dtype=torch.float32)
c6_reference = c6_reference.to(device=positions.device, dtype=torch.float32)
coord_num_ref = coord_num_ref.to(device=positions.device, dtype=torch.float32)
# Get shapes
num_atoms = positions.size(0)
# Set fill_value if not provided
if fill_value is None:
fill_value = num_atoms
# Handle empty case
if num_atoms == 0:
return
# Infer device from positions if not provided
if device is None:
device = str(positions.device)
# Zero output tensors
energy.zero_()
forces.zero_()
coord_num.zero_()
virial.zero_()
# Detect dtype and set appropriate Warp types
wp_dtype = get_wp_dtype(positions.dtype)
vec_dtype = get_wp_vec_dtype(positions.dtype)
# Create batch indices if not provided (single system)
if batch_idx is None:
batch_idx = torch.zeros(num_atoms, dtype=torch.int32, device=positions.device)
# Convert PyTorch tensors to Warp arrays (detach positions)
positions_wp = wp.from_torch(positions.detach(), dtype=vec_dtype, return_ctype=True)
numbers_wp = wp.from_torch(numbers, dtype=wp.int32, return_ctype=True)
neighbor_matrix_wp = wp.from_torch(
neighbor_matrix, dtype=wp.int32, return_ctype=True
)
batch_idx_wp = wp.from_torch(batch_idx, dtype=wp.int32, return_ctype=True)
# Convert parameter tensors to Warp arrays (ensure float32)
covalent_radii_wp = wp.from_torch(
covalent_radii.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
r4r2_wp = wp.from_torch(
r4r2.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
c6_reference_wp = wp.from_torch(
c6_reference.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
coord_num_ref_wp = wp.from_torch(
coord_num_ref.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
# Convert pre-allocated output arrays to Warp
coord_num_wp = wp.from_torch(coord_num, dtype=wp.float32, return_ctype=True)
forces_wp = wp.from_torch(forces, dtype=wp.vec3f, return_ctype=True)
energy_wp = wp.from_torch(energy, dtype=wp.float32, return_ctype=True)
virial_wp = wp.from_torch(virial, dtype=wp.mat33f, return_ctype=True)
# Allocate scratch buffers
max_neighbors = neighbor_matrix.shape[1]
cartesian_shifts = torch.zeros(
num_atoms, max_neighbors, 3, dtype=positions.dtype, device=positions.device
)
cartesian_shifts_wp = wp.from_torch(
cartesian_shifts, dtype=vec_dtype, return_ctype=True
)
dE_dCN = torch.zeros(num_atoms, dtype=torch.float32, device=positions.device)
dE_dCN_wp = wp.from_torch(dE_dCN, dtype=wp.float32, return_ctype=True)
# Call non-PBC warp launcher
wp_dftd3_matrix(
positions=positions_wp,
numbers=numbers_wp,
neighbor_matrix=neighbor_matrix_wp,
covalent_radii=covalent_radii_wp,
r4r2=r4r2_wp,
c6_reference=c6_reference_wp,
coord_num_ref=coord_num_ref_wp,
a1=a1,
a2=a2,
s8=s8,
coord_num=coord_num_wp,
forces=forces_wp,
energy=energy_wp,
virial=virial_wp,
batch_idx=batch_idx_wp,
cartesian_shifts=cartesian_shifts_wp,
dE_dCN=dE_dCN_wp,
wp_dtype=wp_dtype,
device=device,
k1=k1,
k3=k3,
s6=s6,
s5_smoothing_on=s5_smoothing_on,
s5_smoothing_off=s5_smoothing_off,
fill_value=fill_value,
)
@torch.library.custom_op(
"nvalchemiops::dftd3_matrix_pbc",
mutates_args=("energy", "forces", "coord_num", "virial"),
)
def _dftd3_matrix_pbc_op(
positions: torch.Tensor,
numbers: torch.Tensor,
neighbor_matrix: torch.Tensor,
cell: torch.Tensor,
neighbor_matrix_shifts: torch.Tensor,
covalent_radii: torch.Tensor,
r4r2: torch.Tensor,
c6_reference: torch.Tensor,
coord_num_ref: torch.Tensor,
a1: float,
a2: float,
s8: float,
energy: torch.Tensor,
forces: torch.Tensor,
coord_num: torch.Tensor,
virial: torch.Tensor,
k1: float = 16.0,
k3: float = -4.0,
s6: float = 1.0,
s5_smoothing_on: float = 1e10,
s5_smoothing_off: float = 1e10,
fill_value: int | None = None,
batch_idx: torch.Tensor | None = None,
compute_virial: bool = False,
device: str | None = None,
) -> None:
"""Internal custom op for DFT-D3(BJ) dispersion energy and forces computation (PBC, neighbor matrix format).
This is a low-level custom operator that performs DFT-D3(BJ) dispersion
calculations using Warp kernels for periodic systems with neighbor matrix format.
Output tensors must be pre-allocated by the caller and are modified in-place.
For most use cases, prefer the higher-level :func:`dftd3` wrapper function
instead of calling this method directly.
This function is torch.compile compatible.
Parameters
----------
positions : torch.Tensor, shape (num_atoms, 3)
Atomic coordinates as float32 or float64, in consistent distance units
(conventionally Bohr)
numbers : torch.Tensor, shape (num_atoms), dtype=int32
Atomic numbers
neighbor_matrix : torch.Tensor, shape (num_atoms, max_neighbors), dtype=int32
Neighbor indices. See module docstring for format details.
Padding entries have values >= fill_value.
cell : torch.Tensor, shape (num_systems, 3, 3), dtype=float32 or float64
Unit cell lattice vectors for PBC, in same dtype and units as positions.
neighbor_matrix_shifts : torch.Tensor, shape (num_atoms, max_neighbors, 3), dtype=int32
Integer unit cell shifts for PBC.
covalent_radii : torch.Tensor, shape (max_Z+1), dtype=float32
Covalent radii indexed by atomic number, in same units as positions
r4r2 : torch.Tensor, shape (max_Z+1), dtype=float32
<r⁴>/<r²> expectation values for C8 computation (dimensionless)
c6_reference : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32
C6 reference values in energy x distance^6 units
coord_num_ref : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32
CN reference grid (dimensionless)
a1 : float
Becke-Johnson damping parameter 1 (functional-dependent, dimensionless)
a2 : float
Becke-Johnson damping parameter 2 (functional-dependent), in same units as positions
s8 : float
C8 term scaling factor (functional-dependent, dimensionless)
energy : torch.Tensor, shape (num_systems,), dtype=float32
OUTPUT: Total dispersion energy. Must be pre-allocated. Units are energy
(Hartree when using standard D3 parameters).
forces : torch.Tensor, shape (num_atoms, 3), dtype=float32
OUTPUT: Atomic forces. Must be pre-allocated. Units are energy/distance
(Hartree/Bohr when using standard D3 parameters).
coord_num : torch.Tensor, shape (num_atoms,), dtype=float32
OUTPUT: Coordination numbers (dimensionless). Must be pre-allocated.
virial : torch.Tensor, shape (num_systems, 3, 3), dtype=float32
OUTPUT: Virial tensor. Must be pre-allocated. Units are energy
(Hartree when using standard D3 parameters).
k1 : float, optional
CN counting function steepness parameter, in inverse distance units
(typically 16.0 1/Bohr for atomic units)
k3 : float, optional
CN interpolation Gaussian width parameter (typically -4.0, dimensionless)
s6 : float, optional
C6 term scaling factor (typically 1.0, dimensionless)
s5_smoothing_on : float, optional
Distance where S5 switching begins, in same units as positions. Default: 1e10
s5_smoothing_off : float, optional
Distance where S5 switching completes, in same units as positions. Default: 1e10
fill_value : int | None, optional
Value indicating padding in neighbor_matrix. If None, defaults to num_atoms.
batch_idx : torch.Tensor, shape (num_atoms,), dtype=int32, optional
Batch indices. If None, all atoms are in a single system (batch 0).
compute_virial : bool, optional
If True, compute virial tensor. Default: False
device : str, optional
Warp device string (e.g., 'cuda:0', 'cpu'). If None, inferred from positions.
Returns
-------
None
Modifies input tensors in-place: energy, forces, coord_num, virial (if compute_virial=True)
Notes
-----
- All input tensors should use consistent units. Standard D3 parameters use
atomic units (Bohr for distances, Hartree for energy).
- Float32 or float64 precision for positions and cell; outputs always float32
- Padding atoms indicated by numbers[i] == 0
- **Two-body only**: Computes pairwise C6 and C8 dispersion terms; three-body
Axilrod-Teller-Muto (ATM/C9) terms are not included
- Bulk stress tensor can be obtained by dividing virial by system volume.
- For non-PBC calculations, use :func:`_dftd3_matrix_op` instead
See Also
--------
:func:`dftd3` : Higher-level wrapper that handles allocation
:func:`_dftd3_matrix_op` : Non-PBC variant with neighbor matrix format
"""
# Ensure all parameters are on correct device/dtype
covalent_radii = covalent_radii.to(device=positions.device, dtype=torch.float32)
r4r2 = r4r2.to(device=positions.device, dtype=torch.float32)
c6_reference = c6_reference.to(device=positions.device, dtype=torch.float32)
coord_num_ref = coord_num_ref.to(device=positions.device, dtype=torch.float32)
# Get shapes
num_atoms = positions.size(0)
# Set fill_value if not provided
if fill_value is None:
fill_value = num_atoms
# Handle empty case
if num_atoms == 0:
return
# Infer device from positions if not provided
if device is None:
device = str(positions.device)
# Zero output tensors
energy.zero_()
forces.zero_()
coord_num.zero_()
virial.zero_()
# Detect dtype and set appropriate Warp types
wp_dtype = get_wp_dtype(positions.dtype)
vec_dtype = get_wp_vec_dtype(positions.dtype)
mat_dtype = get_wp_mat_dtype(positions.dtype)
# Create batch indices if not provided (single system)
if batch_idx is None:
batch_idx = torch.zeros(num_atoms, dtype=torch.int32, device=positions.device)
# Convert PyTorch tensors to Warp arrays (detach positions)
positions_wp = wp.from_torch(positions.detach(), dtype=vec_dtype, return_ctype=True)
numbers_wp = wp.from_torch(numbers, dtype=wp.int32, return_ctype=True)
neighbor_matrix_wp = wp.from_torch(
neighbor_matrix, dtype=wp.int32, return_ctype=True
)
batch_idx_wp = wp.from_torch(batch_idx, dtype=wp.int32, return_ctype=True)
# Convert parameter tensors to Warp arrays (ensure float32)
covalent_radii_wp = wp.from_torch(
covalent_radii.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
r4r2_wp = wp.from_torch(
r4r2.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
c6_reference_wp = wp.from_torch(
c6_reference.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
coord_num_ref_wp = wp.from_torch(
coord_num_ref.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
# Convert cell and neighbor_matrix_shifts to warp for PBC
cell_wp = wp.from_torch(
cell.detach().to(dtype=positions.dtype, device=positions.device),
dtype=mat_dtype,
return_ctype=True,
)
neighbor_matrix_shifts_wp = wp.from_torch(
neighbor_matrix_shifts.to(dtype=torch.int32, device=positions.device),
dtype=wp.vec3i,
return_ctype=True,
)
# Convert pre-allocated output arrays to Warp
coord_num_wp = wp.from_torch(coord_num, dtype=wp.float32, return_ctype=True)
forces_wp = wp.from_torch(forces, dtype=wp.vec3f, return_ctype=True)
energy_wp = wp.from_torch(energy, dtype=wp.float32, return_ctype=True)
virial_wp = wp.from_torch(virial, dtype=wp.mat33f, return_ctype=True)
# Allocate scratch buffers
max_neighbors = neighbor_matrix.shape[1]
cartesian_shifts = torch.zeros(
num_atoms, max_neighbors, 3, dtype=positions.dtype, device=positions.device
)
cartesian_shifts_wp = wp.from_torch(
cartesian_shifts, dtype=vec_dtype, return_ctype=True
)
dE_dCN = torch.zeros(num_atoms, dtype=torch.float32, device=positions.device)
dE_dCN_wp = wp.from_torch(dE_dCN, dtype=wp.float32, return_ctype=True)
# Call PBC warp launcher
wp_dftd3_matrix_pbc(
positions=positions_wp,
numbers=numbers_wp,
neighbor_matrix=neighbor_matrix_wp,
cell=cell_wp,
neighbor_matrix_shifts=neighbor_matrix_shifts_wp,
covalent_radii=covalent_radii_wp,
r4r2=r4r2_wp,
c6_reference=c6_reference_wp,
coord_num_ref=coord_num_ref_wp,
a1=a1,
a2=a2,
s8=s8,
coord_num=coord_num_wp,
forces=forces_wp,
energy=energy_wp,
virial=virial_wp,
batch_idx=batch_idx_wp,
cartesian_shifts=cartesian_shifts_wp,
dE_dCN=dE_dCN_wp,
wp_dtype=wp_dtype,
device=device,
k1=k1,
k3=k3,
s6=s6,
s5_smoothing_on=s5_smoothing_on,
s5_smoothing_off=s5_smoothing_off,
fill_value=fill_value,
compute_virial=compute_virial,
)
@torch.library.custom_op(
"nvalchemiops::dftd3",
mutates_args=("energy", "forces", "coord_num", "virial"),
)
def _dftd3_op(
positions: torch.Tensor,
numbers: torch.Tensor,
idx_j: torch.Tensor,
neighbor_ptr: torch.Tensor,
covalent_radii: torch.Tensor,
r4r2: torch.Tensor,
c6_reference: torch.Tensor,
coord_num_ref: torch.Tensor,
a1: float,
a2: float,
s8: float,
energy: torch.Tensor,
forces: torch.Tensor,
coord_num: torch.Tensor,
virial: torch.Tensor,
k1: float = 16.0,
k3: float = -4.0,
s6: float = 1.0,
s5_smoothing_on: float = 1e10,
s5_smoothing_off: float = 1e10,
batch_idx: torch.Tensor | None = None,
device: str | None = None,
) -> None:
"""Internal custom op for DFT-D3(BJ) using CSR neighbor list format (non-PBC).
This is a low-level custom operator that performs DFT-D3(BJ) dispersion
calculations using CSR (Compressed Sparse Row) neighbor list format with
idx_j (destination indices) and neighbor_ptr (row pointers) for non-periodic
systems. Output tensors must be pre-allocated by the caller and are modified
in-place. For most use cases, prefer the higher-level :func:`dftd3` wrapper
function instead of calling this method directly.
This function is torch.compile compatible.
Parameters
----------
positions : torch.Tensor, shape (num_atoms, 3)
Atomic coordinates as float32 or float64
numbers : torch.Tensor, shape (num_atoms), dtype=int32
Atomic numbers
idx_j : torch.Tensor, shape (num_edges,), dtype=int32
Destination atom indices (flattened neighbor list in CSR format)
neighbor_ptr : torch.Tensor, shape (num_atoms+1,), dtype=int32
CSR row pointers where neighbor_ptr[i]:neighbor_ptr[i+1] gives neighbors of atom i
covalent_radii : torch.Tensor, shape (max_Z+1), dtype=float32
Covalent radii indexed by atomic number
r4r2 : torch.Tensor, shape (max_Z+1), dtype=float32
<r⁴>/<r²> expectation values
c6_reference : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32
C6 reference values
coord_num_ref : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32
CN reference grid
a1 : float
Becke-Johnson damping parameter 1
a2 : float
Becke-Johnson damping parameter 2
s8 : float
C8 term scaling factor
energy : torch.Tensor, shape (num_systems,), dtype=float32
OUTPUT: Total dispersion energy
forces : torch.Tensor, shape (num_atoms, 3), dtype=float32
OUTPUT: Atomic forces
coord_num : torch.Tensor, shape (num_atoms,), dtype=float32
OUTPUT: Coordination numbers
virial : torch.Tensor, shape (num_systems, 3, 3), dtype=float32
OUTPUT: Virial tensor (remains zeros for non-PBC). Must be pre-allocated.
k1 : float, optional
CN counting function steepness parameter
k3 : float, optional
CN interpolation Gaussian width parameter
s6 : float, optional
C6 term scaling factor
s5_smoothing_on : float, optional
Distance where S5 switching begins
s5_smoothing_off : float, optional
Distance where S5 switching completes
batch_idx : torch.Tensor, shape (num_atoms,), dtype=int32, optional
Batch indices
device : str, optional
Warp device string
Returns
-------
None
Modifies input tensors in-place: energy, forces, coord_num, virial (remains zeros)
Notes
-----
- All input tensors should use consistent units. Standard D3 parameters use
atomic units (Bohr for distances, Hartree for energy).
- Float32 or float64 precision for positions; outputs always float32
- Padding atoms indicated by numbers[i] == 0
- **Two-body only**: Computes pairwise C6 and C8 dispersion terms; three-body
Axilrod-Teller-Muto (ATM/C9) terms are not included
- For PBC calculations, use :func:`_dftd3_pbc_op` instead
See Also
--------
:func:`dftd3` : Higher-level wrapper that handles allocation
:func:`_dftd3_pbc_op` : PBC variant with CSR neighbor list format
"""
# Ensure all parameters are on correct device/dtype
covalent_radii = covalent_radii.to(device=positions.device, dtype=torch.float32)
r4r2 = r4r2.to(device=positions.device, dtype=torch.float32)
c6_reference = c6_reference.to(device=positions.device, dtype=torch.float32)
coord_num_ref = coord_num_ref.to(device=positions.device, dtype=torch.float32)
# Get shapes
num_atoms = positions.size(0)
num_edges = idx_j.size(0)
# Handle empty case
if num_atoms == 0 or num_edges == 0:
return
# Infer device from positions if not provided
if device is None:
device = str(positions.device)
# Zero output tensors
energy.zero_()
forces.zero_()
coord_num.zero_()
virial.zero_()
# Detect dtype and set appropriate Warp types
wp_dtype = get_wp_dtype(positions.dtype)
vec_dtype = get_wp_vec_dtype(positions.dtype)
# Create batch indices if not provided (single system)
if batch_idx is None:
batch_idx = torch.zeros(num_atoms, dtype=torch.int32, device=positions.device)
# Convert PyTorch tensors to Warp arrays
positions_wp = wp.from_torch(positions.detach(), dtype=vec_dtype, return_ctype=True)
numbers_wp = wp.from_torch(numbers, dtype=wp.int32, return_ctype=True)
idx_j_wp = wp.from_torch(idx_j, dtype=wp.int32, return_ctype=True)
neighbor_ptr_wp = wp.from_torch(neighbor_ptr, dtype=wp.int32, return_ctype=True)
batch_idx_wp = wp.from_torch(batch_idx, dtype=wp.int32, return_ctype=True)
# Convert parameter tensors to Warp arrays
covalent_radii_wp = wp.from_torch(
covalent_radii.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
r4r2_wp = wp.from_torch(
r4r2.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
c6_reference_wp = wp.from_torch(
c6_reference.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
coord_num_ref_wp = wp.from_torch(
coord_num_ref.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
# Convert pre-allocated output arrays to Warp
coord_num_wp = wp.from_torch(coord_num, dtype=wp.float32, return_ctype=True)
forces_wp = wp.from_torch(forces, dtype=wp.vec3f, return_ctype=True)
energy_wp = wp.from_torch(energy, dtype=wp.float32, return_ctype=True)
virial_wp = wp.from_torch(virial, dtype=wp.mat33f, return_ctype=True)
# Allocate scratch buffers
cartesian_shifts = torch.zeros(
num_edges, 3, dtype=positions.dtype, device=positions.device
)
cartesian_shifts_wp = wp.from_torch(
cartesian_shifts, dtype=vec_dtype, return_ctype=True
)
dE_dCN = torch.zeros(num_atoms, dtype=torch.float32, device=positions.device)
dE_dCN_wp = wp.from_torch(dE_dCN, dtype=wp.float32, return_ctype=True)
# Call non-PBC warp launcher
wp_dftd3(
positions=positions_wp,
numbers=numbers_wp,
idx_j=idx_j_wp,
neighbor_ptr=neighbor_ptr_wp,
covalent_radii=covalent_radii_wp,
r4r2=r4r2_wp,
c6_reference=c6_reference_wp,
coord_num_ref=coord_num_ref_wp,
a1=a1,
a2=a2,
s8=s8,
coord_num=coord_num_wp,
forces=forces_wp,
energy=energy_wp,
virial=virial_wp,
batch_idx=batch_idx_wp,
cartesian_shifts=cartesian_shifts_wp,
dE_dCN=dE_dCN_wp,
wp_dtype=wp_dtype,
device=device,
k1=k1,
k3=k3,
s6=s6,
s5_smoothing_on=s5_smoothing_on,
s5_smoothing_off=s5_smoothing_off,
)
@torch.library.custom_op(
"nvalchemiops::dftd3_pbc",
mutates_args=("energy", "forces", "coord_num", "virial"),
)
def _dftd3_pbc_op(
positions: torch.Tensor,
numbers: torch.Tensor,
idx_j: torch.Tensor,
neighbor_ptr: torch.Tensor,
cell: torch.Tensor,
unit_shifts: torch.Tensor,
covalent_radii: torch.Tensor,
r4r2: torch.Tensor,
c6_reference: torch.Tensor,
coord_num_ref: torch.Tensor,
a1: float,
a2: float,
s8: float,
energy: torch.Tensor,
forces: torch.Tensor,
coord_num: torch.Tensor,
virial: torch.Tensor,
k1: float = 16.0,
k3: float = -4.0,
s6: float = 1.0,
s5_smoothing_on: float = 1e10,
s5_smoothing_off: float = 1e10,
batch_idx: torch.Tensor | None = None,
compute_virial: bool = False,
device: str | None = None,
) -> None:
"""Internal custom op for DFT-D3(BJ) using CSR neighbor list format (PBC).
This is a low-level custom operator that performs DFT-D3(BJ) dispersion
calculations using CSR (Compressed Sparse Row) neighbor list format with
idx_j (destination indices) and neighbor_ptr (row pointers) for periodic
systems. Output tensors must be pre-allocated by the caller and are modified
in-place. For most use cases, prefer the higher-level :func:`dftd3` wrapper
function instead of calling this method directly.
This function is torch.compile compatible.
Parameters
----------
positions : torch.Tensor, shape (num_atoms, 3)
Atomic coordinates as float32 or float64
numbers : torch.Tensor, shape (num_atoms), dtype=int32
Atomic numbers
idx_j : torch.Tensor, shape (num_edges,), dtype=int32
Destination atom indices (flattened neighbor list in CSR format)
neighbor_ptr : torch.Tensor, shape (num_atoms+1,), dtype=int32
CSR row pointers where neighbor_ptr[i]:neighbor_ptr[i+1] gives neighbors of atom i
cell : torch.Tensor, shape (num_systems, 3, 3), dtype=float32 or float64
Unit cell lattice vectors for PBC, in same dtype and units as positions.
unit_shifts : torch.Tensor, shape (num_edges, 3), dtype=int32
Integer unit cell shifts for PBC
covalent_radii : torch.Tensor, shape (max_Z+1), dtype=float32
Covalent radii indexed by atomic number
r4r2 : torch.Tensor, shape (max_Z+1), dtype=float32
<r⁴>/<r²> expectation values
c6_reference : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32
C6 reference values
coord_num_ref : torch.Tensor, shape (max_Z+1, max_Z+1, 5, 5), dtype=float32
CN reference grid
a1 : float
Becke-Johnson damping parameter 1
a2 : float
Becke-Johnson damping parameter 2
s8 : float
C8 term scaling factor
energy : torch.Tensor, shape (num_systems,), dtype=float32
OUTPUT: Total dispersion energy
forces : torch.Tensor, shape (num_atoms, 3), dtype=float32
OUTPUT: Atomic forces
coord_num : torch.Tensor, shape (num_atoms,), dtype=float32
OUTPUT: Coordination numbers
virial : torch.Tensor, shape (num_systems, 3, 3), dtype=float32
OUTPUT: Virial tensor. Must be pre-allocated. Units are energy
(Hartree when using standard D3 parameters).
k1 : float, optional
CN counting function steepness parameter
k3 : float, optional
CN interpolation Gaussian width parameter
s6 : float, optional
C6 term scaling factor
s5_smoothing_on : float, optional
Distance where S5 switching begins
s5_smoothing_off : float, optional
Distance where S5 switching completes
batch_idx : torch.Tensor, shape (num_atoms,), dtype=int32, optional
Batch indices
compute_virial : bool, optional
If True, compute virial tensor. Default: False
device : str, optional
Warp device string
Returns
-------
None
Modifies input tensors in-place: energy, forces, coord_num, virial (if compute_virial=True)
Notes
-----
- All input tensors should use consistent units. Standard D3 parameters use
atomic units (Bohr for distances, Hartree for energy).
- Float32 or float64 precision for positions and cell; outputs always float32
- Padding atoms indicated by numbers[i] == 0
- **Two-body only**: Computes pairwise C6 and C8 dispersion terms; three-body
Axilrod-Teller-Muto (ATM/C9) terms are not included
- Bulk stress tensor can be obtained by dividing virial by system volume.
- For non-PBC calculations, use :func:`_dftd3_op` instead
See Also
--------
:func:`dftd3` : Higher-level wrapper that handles allocation
:func:`_dftd3_op` : Non-PBC variant with CSR neighbor list format
"""
# Ensure all parameters are on correct device/dtype
covalent_radii = covalent_radii.to(device=positions.device, dtype=torch.float32)
r4r2 = r4r2.to(device=positions.device, dtype=torch.float32)
c6_reference = c6_reference.to(device=positions.device, dtype=torch.float32)
coord_num_ref = coord_num_ref.to(device=positions.device, dtype=torch.float32)
# Get shapes
num_atoms = positions.size(0)
num_edges = idx_j.size(0)
# Handle empty case
if num_atoms == 0 or num_edges == 0:
return
# Infer device from positions if not provided
if device is None:
device = str(positions.device)
# Zero output tensors
energy.zero_()
forces.zero_()
coord_num.zero_()
virial.zero_()
# Detect dtype and set appropriate Warp types
wp_dtype = get_wp_dtype(positions.dtype)
vec_dtype = get_wp_vec_dtype(positions.dtype)
mat_dtype = get_wp_mat_dtype(positions.dtype)
# Create batch indices if not provided (single system)
if batch_idx is None:
batch_idx = torch.zeros(num_atoms, dtype=torch.int32, device=positions.device)
# Convert PyTorch tensors to Warp arrays
positions_wp = wp.from_torch(positions.detach(), dtype=vec_dtype, return_ctype=True)
numbers_wp = wp.from_torch(numbers, dtype=wp.int32, return_ctype=True)
idx_j_wp = wp.from_torch(idx_j, dtype=wp.int32, return_ctype=True)
neighbor_ptr_wp = wp.from_torch(neighbor_ptr, dtype=wp.int32, return_ctype=True)
batch_idx_wp = wp.from_torch(batch_idx, dtype=wp.int32, return_ctype=True)
# Convert parameter tensors to Warp arrays
covalent_radii_wp = wp.from_torch(
covalent_radii.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
r4r2_wp = wp.from_torch(
r4r2.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
c6_reference_wp = wp.from_torch(
c6_reference.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
coord_num_ref_wp = wp.from_torch(
coord_num_ref.to(dtype=torch.float32, device=positions.device),
dtype=wp.float32,
return_ctype=True,
)
# Convert cell and unit_shifts to warp for PBC
cell_wp = wp.from_torch(
cell.detach().to(dtype=positions.dtype, device=positions.device),
dtype=mat_dtype,
return_ctype=True,
)
unit_shifts_wp = wp.from_torch(
unit_shifts.to(dtype=torch.int32, device=positions.device),
dtype=wp.vec3i,
return_ctype=True,
)
# Convert pre-allocated output arrays to Warp
coord_num_wp = wp.from_torch(coord_num, dtype=wp.float32, return_ctype=True)
forces_wp = wp.from_torch(forces, dtype=wp.vec3f, return_ctype=True)
energy_wp = wp.from_torch(energy, dtype=wp.float32, return_ctype=True)
virial_wp = wp.from_torch(virial, dtype=wp.mat33f, return_ctype=True)
# Allocate scratch buffers
cartesian_shifts = torch.zeros(
num_edges, 3, dtype=positions.dtype, device=positions.device
)
cartesian_shifts_wp = wp.from_torch(
cartesian_shifts, dtype=vec_dtype, return_ctype=True
)
dE_dCN = torch.zeros(num_atoms, dtype=torch.float32, device=positions.device)
dE_dCN_wp = wp.from_torch(dE_dCN, dtype=wp.float32, return_ctype=True)
# Call PBC warp launcher
wp_dftd3_pbc(
positions=positions_wp,
numbers=numbers_wp,
idx_j=idx_j_wp,
neighbor_ptr=neighbor_ptr_wp,
cell=cell_wp,
unit_shifts=unit_shifts_wp,
covalent_radii=covalent_radii_wp,
r4r2=r4r2_wp,
c6_reference=c6_reference_wp,
coord_num_ref=coord_num_ref_wp,
a1=a1,
a2=a2,
s8=s8,
coord_num=coord_num_wp,
forces=forces_wp,
energy=energy_wp,
virial=virial_wp,
batch_idx=batch_idx_wp,
cartesian_shifts=cartesian_shifts_wp,
dE_dCN=dE_dCN_wp,
wp_dtype=wp_dtype,
device=device,
k1=k1,
k3=k3,
s6=s6,
s5_smoothing_on=s5_smoothing_on,
s5_smoothing_off=s5_smoothing_off,
compute_virial=compute_virial,
)
[docs]
def dftd3(
    positions: torch.Tensor,
    numbers: torch.Tensor,
    a1: float,
    a2: float,
    s8: float,
    k1: float = 16.0,
    k3: float = -4.0,
    s6: float = 1.0,
    s5_smoothing_on: float = 1e10,
    s5_smoothing_off: float = 1e10,
    fill_value: int | None = None,
    d3_params: D3Parameters | dict[str, torch.Tensor] | None = None,
    covalent_radii: torch.Tensor | None = None,
    r4r2: torch.Tensor | None = None,
    c6_reference: torch.Tensor | None = None,
    coord_num_ref: torch.Tensor | None = None,
    batch_idx: torch.Tensor | None = None,
    cell: torch.Tensor | None = None,
    neighbor_matrix: torch.Tensor | None = None,
    neighbor_matrix_shifts: torch.Tensor | None = None,
    neighbor_list: torch.Tensor | None = None,
    neighbor_ptr: torch.Tensor | None = None,
    unit_shifts: torch.Tensor | None = None,
    compute_virial: bool = False,
    num_systems: int | None = None,
    device: str | None = None,
) -> (
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
):
    """
    Compute DFT-D3(BJ) dispersion energy and forces using Warp
    with optional periodic boundary condition support and smoothing function.

    **DFT-D3 parameters must be explicitly provided** using one of three methods:

    1. **D3Parameters dataclass**: Supply a :class:`D3Parameters` instance (recommended).
       Individual parameters can override dataclass values if both are provided.
    2. **Explicit parameters**: Supply all four parameters individually:
       ``covalent_radii``, ``r4r2``, ``c6_reference``, and ``coord_num_ref``.
    3. **Dictionary**: Provide a ``d3_params`` dictionary with keys:
       ``"rcov"``, ``"r4r2"``, ``"c6ab"``, and ``"cn_ref"``.
       Individual parameters can override dictionary values if both are provided.

    See ``examples/dispersion/utils.py`` for parameter generation utilities.

    This wrapper can be launched by either supplying a neighbor matrix or a
    neighbor list, both of which can be generated by the
    :func:`nvalchemiops.neighborlist.neighbor_list` function where the latter
    can be returned by setting the `return_neighbor_list` parameter to True.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic coordinates [num_atoms, 3] as float32 or float64, in consistent distance
        units (conventionally Bohr when using standard D3 parameters)
    numbers : torch.Tensor
        Atomic numbers [num_atoms] as int32
    a1 : float
        Becke-Johnson damping parameter 1 (functional-dependent, dimensionless)
    a2 : float
        Becke-Johnson damping parameter 2 (functional-dependent), in same units as positions
    s8 : float
        C8 term scaling factor (functional-dependent, dimensionless)
    k1 : float, optional
        CN counting function steepness parameter, in inverse distance units
        (typically 16.0 1/Bohr for atomic units)
    k3 : float, optional
        CN interpolation Gaussian width parameter (typically -4.0, dimensionless)
    s6 : float, optional
        C6 term scaling factor (typically 1.0, dimensionless)
    s5_smoothing_on : float, optional
        Distance where S5 switching begins, in same units as positions. Set greater or
        equal to s5_smoothing_off to disable smoothing. Default: 1e10
    s5_smoothing_off : float, optional
        Distance where S5 switching completes, in same units as positions.
        Default: 1e10 (effectively no cutoff)
    fill_value : int | None, optional
        Value indicating padding in neighbor_matrix. If None, defaults to num_atoms.
        Entries with neighbor_matrix[i, k] >= fill_value are treated as padding. Default: None
    d3_params : D3Parameters | dict[str, torch.Tensor] | None, optional
        DFT-D3 parameters provided as either:

        - :class:`D3Parameters` dataclass instance (recommended)
        - Dictionary with keys: "rcov", "r4r2", "c6ab", "cn_ref"

        Individual parameters below can override values from d3_params.
    covalent_radii : torch.Tensor | None, optional
        Covalent radii [max_Z+1] as float32, indexed by atomic number, in same units
        as positions. If provided, overrides the value in d3_params.
    r4r2 : torch.Tensor | None, optional
        <r4>/<r2> expectation values [max_Z+1] as float32 for C8 computation (dimensionless).
        If provided, overrides the value in d3_params.
    c6_reference : torch.Tensor | None, optional
        C6 reference values [max_Z+1, max_Z+1, 5, 5] as float32 in energy × distance^6 units.
        If provided, overrides the value in d3_params.
    coord_num_ref : torch.Tensor | None, optional
        CN reference grid [max_Z+1, max_Z+1, 5, 5] as float32 (dimensionless).
        If provided, overrides the value in d3_params.
    batch_idx : torch.Tensor or None, optional
        Batch indices [num_atoms] as int32. If None, all atoms are assumed
        to be in a single system (batch 0). For batched calculations, atoms with
        the same batch index belong to the same system. Default: None
    cell : torch.Tensor or None, optional, as float32 or float64
        Unit cell lattice vectors [num_systems, 3, 3] for PBC, in same dtype and units as positions.
        Convention: cell[s, i, :] is i-th lattice vector for system s.
        If None, non-periodic calculation. Default: None
    neighbor_matrix : torch.Tensor | None, optional
        Neighbor indices [num_atoms, max_neighbors] as int32. See module docstring for
        details on the format. Padding entries have values >= fill_value.
        Mutually exclusive with neighbor_list. Default: None
    neighbor_matrix_shifts : torch.Tensor or None, optional
        Integer unit cell shifts [num_atoms, max_neighbors, 3] as int32 for PBC with
        neighbor_matrix format. If None, non-periodic calculation. If provided along
        with cell, Cartesian shifts are computed. Mutually exclusive with unit_shifts.
        Default: None
    neighbor_list : torch.Tensor or None, optional
        Neighbor pairs [2, num_pairs] as int32 in COO format, where row 0 contains
        source atom indices and row 1 contains target atom indices. Alternative to
        neighbor_matrix for sparse neighbor representations. Mutually exclusive with
        neighbor_matrix. Must be used together with `neighbor_ptr` (both are returned
        by the neighbor list API when `return_neighbor_list=True`).
        Default: None
    neighbor_ptr : torch.Tensor or None, optional
        CSR row pointers [num_atoms+1] as int32. Required when using `neighbor_list`.
        Indicates that `neighbor_list[1, :]` contains destination atoms in CSR
        format where
        `neighbor_ptr[i]:neighbor_ptr[i+1]` gives the range of neighbors for atom i.
        Returned by the neighbor list API when `return_neighbor_list=True`.
        Default: None
    unit_shifts : torch.Tensor or None, optional
        Integer unit cell shifts [num_pairs, 3] as int32 for PBC with neighbor_list
        format. If None, non-periodic calculation. If provided along with cell,
        Cartesian shifts are computed. Mutually exclusive with neighbor_matrix_shifts.
        Default: None
    compute_virial : bool, optional
        If True, allocate, compute and return the virial tensor. Requires ``cell``
        and the shifts matching the chosen neighbor format. Default: False
    num_systems : int, optional
        Number of systems in batch. If not provided, it is 1 when batch_idx is
        None, otherwise inferred from cell, or from batch_idx (introduces CUDA
        synchronization overhead). Default: None
    device : str or None, optional
        Warp device string (e.g., 'cuda:0', 'cpu'). If None, inferred from
        positions tensor. Default: None

    Returns
    -------
    energy : torch.Tensor
        Total dispersion energy [num_systems] as float32. Units are energy
        (Hartree when using standard D3 parameters).
    forces : torch.Tensor
        Atomic forces [num_atoms, 3] as float32. Units are energy/distance
        (Hartree/Bohr when using standard D3 parameters).
    coord_num : torch.Tensor
        Coordination numbers [num_atoms] as float32 (dimensionless)
    virial : torch.Tensor, optional
        Virial tensor [num_systems, 3, 3] as float32.
        Units are energy (Hartree when using standard D3 parameters). Only returned
        if compute_virial=True.

    Raises
    ------
    ValueError
        If the neighbor format inputs are missing, duplicated or paired with the
        wrong shift argument; if ``neighbor_ptr`` is missing for the list format;
        if ``a1``, ``a2`` or ``s8`` is None; or if ``compute_virial=True`` without
        ``cell`` and the matching shifts.
    RuntimeError
        If the reference parameters are neither all given individually nor
        available through ``d3_params``.
    KeyError
        If a ``d3_params`` dictionary lacks a key that is not overridden.

    Notes
    -----
    - **Unit consistency**: All inputs must use consistent units. Standard D3 parameters
      from the Grimme group use atomic units (Bohr for distances, Hartree for energy),
      so using atomic units throughout is recommended and conventional.
    - Float32 or float64 precision for positions and cell; outputs always float32
    - **Neighbor formats**: Supports both neighbor_matrix (dense) and neighbor_list (sparse COO)
      formats. Choose neighbor_list for sparse systems or when memory efficiency is important.
    - Padding atoms indicated by numbers[i] == 0
    - Requires symmetric neighbor representation (each pair appears twice)
    - **Two-body only**: Computes pairwise C6 and C8 dispersion terms; three-body
      Axilrod-Teller-Muto (ATM/C9) terms are not included
    - Virial computation requires periodic boundary conditions.
    - Bulk stress tensor can be obtained by dividing virial by system volume.
    - With zero atoms, zero-filled outputs are returned without launching a kernel.

    **Neighbor Format Selection**:

    - Use neighbor_matrix for dense systems or when max_neighbors is small
    - Use neighbor_list for sparse systems, large cutoffs, or memory-constrained scenarios
    - Both formats produce identical results and support PBC

    **PBC Handling**:

    - Matrix format: Provide cell and neighbor_matrix_shifts
    - List format: Provide cell and unit_shifts
    - Non-periodic: Omit both cell and shift parameters

    See Also
    --------
    :class:`D3Parameters` : Dataclass for organizing DFT-D3 reference parameters
    :func:`_dftd3_matrix_op` : Internal custom operator for neighbor matrix format (non-PBC)
    :func:`_dftd3_matrix_pbc_op` : Internal custom operator for neighbor matrix format (PBC)
    :func:`_dftd3_op` : Internal custom operator for neighbor list format (non-PBC)
    :func:`_dftd3_pbc_op` : Internal custom operator for neighbor list format (PBC)
    """
    # Validate neighbor format inputs
    matrix_provided = neighbor_matrix is not None
    list_provided = neighbor_list is not None
    if matrix_provided and list_provided:
        raise ValueError(
            "Cannot provide both neighbor_matrix and neighbor_list. "
            "Please provide only one neighbor representation format."
        )
    if not matrix_provided and not list_provided:
        raise ValueError("Must provide either neighbor_matrix or neighbor_list.")

    # Validate PBC shift inputs match neighbor format
    if matrix_provided and unit_shifts is not None:
        raise ValueError(
            "unit_shifts is for neighbor_list format. "
            "Use neighbor_matrix_shifts for neighbor_matrix format."
        )
    if list_provided and neighbor_matrix_shifts is not None:
        raise ValueError(
            "neighbor_matrix_shifts is for neighbor_matrix format. "
            "Use unit_shifts for neighbor_list format."
        )

    # Validate neighbor_ptr is provided when using neighbor_list format
    if list_provided and neighbor_ptr is None:
        raise ValueError(
            "neighbor_ptr must be provided when using neighbor_list format. "
            "Obtain it from the neighbor list API by setting return_neighbor_list=True."
        )

    # Validate functional parameters
    if a1 is None or a2 is None or s8 is None:
        raise ValueError(
            "Functional parameters a1, a2, and s8 must be provided. "
            "These are functional-dependent parameters required for DFT-D3(BJ) calculations."
        )

    # Validate virial computation requires PBC
    if compute_virial:
        if cell is None:
            raise ValueError(
                "Virial computation requires periodic boundary conditions. "
                "Please provide unit cell parameters (cell) and shifts "
                "(neighbor_matrix_shifts or unit_shifts) when compute_virial=True "
                "or when passing a virial tensor."
            )
        if matrix_provided and neighbor_matrix_shifts is None:
            raise ValueError(
                "Virial computation requires periodic boundary conditions. "
                "Please provide neighbor_matrix_shifts along with cell when using "
                "neighbor_matrix format and compute_virial=True or passing a virial tensor."
            )
        if list_provided and unit_shifts is None:
            raise ValueError(
                "Virial computation requires periodic boundary conditions. "
                "Please provide unit_shifts along with cell when using "
                "neighbor_list format and compute_virial=True or passing a virial tensor."
            )

    # Resolve the four reference-parameter tensors
    explicit = (covalent_radii, r4r2, c6_reference, coord_num_ref)
    if all(param is not None for param in explicit):
        # Case 1: everything was passed individually; nothing to look up
        pass
    elif d3_params is not None:
        # Case 2: fill the gaps from d3_params (dataclass or dictionary).
        # The dataclass is read through its attribute dictionary so both forms
        # share the same key-based access.
        if isinstance(d3_params, D3Parameters):
            d3_params = d3_params.__dict__
        # Plain indexing on purpose: a missing key surfaces as KeyError
        if covalent_radii is None:
            covalent_radii = d3_params["rcov"]
        if r4r2 is None:
            r4r2 = d3_params["r4r2"]
        if c6_reference is None:
            c6_reference = d3_params["c6ab"]
        if coord_num_ref is None:
            coord_num_ref = d3_params["cn_ref"]
    else:
        # Case 3: incomplete individual parameters and no d3_params fallback
        raise RuntimeError(
            "DFT-D3 parameters must be explicitly provided. "
            "Either supply all individual parameters (covalent_radii, r4r2, "
            "c6_reference, coord_num_ref), provide a D3Parameters instance, "
            "or provide a d3_params dictionary. See the function docstring for details."
        )

    num_atoms = positions.size(0)
    out_device = positions.device

    # Handle empty case: return zero-filled outputs without touching Warp
    if num_atoms == 0:
        # Respect a caller-supplied num_systems; only infer it when absent
        if num_systems is None:
            if batch_idx is None or (
                isinstance(batch_idx, torch.Tensor) and batch_idx.numel() == 0
            ):
                num_systems = 1
            else:
                num_systems = int(batch_idx.max().item()) + 1
        empty_energy = torch.zeros(num_systems, dtype=torch.float32, device=out_device)
        empty_forces = torch.zeros((0, 3), dtype=torch.float32, device=out_device)
        empty_cn = torch.zeros((0,), dtype=torch.float32, device=out_device)
        if compute_virial:
            # One 3x3 block per system, matching the shape of the energy output
            empty_virial = torch.zeros(
                (num_systems, 3, 3), dtype=torch.float32, device=out_device
            )
            return empty_energy, empty_forces, empty_cn, empty_virial
        return empty_energy, empty_forces, empty_cn

    # Determine number of systems for energy allocation
    if num_systems is None:
        if batch_idx is None:
            num_systems = 1
        elif cell is not None:
            num_systems = cell.size(0)
        else:
            # .item() forces a device-to-host synchronization
            num_systems = int(batch_idx.max().item()) + 1

    # Allocate output tensors
    energy = torch.zeros(num_systems, dtype=torch.float32, device=out_device)
    forces = torch.zeros((num_atoms, 3), dtype=torch.float32, device=out_device)
    coord_num = torch.zeros(num_atoms, dtype=torch.float32, device=out_device)
    # The custom ops always take a virial tensor; pass an empty placeholder
    # when the virial is not requested.
    virial_rows = num_systems if compute_virial else 0
    virial = torch.zeros((virial_rows, 3, 3), dtype=torch.float32, device=out_device)

    # Keyword arguments common to all four custom operators
    shared_kwargs = {
        "positions": positions,
        "numbers": numbers,
        "covalent_radii": covalent_radii,
        "r4r2": r4r2,
        "c6_reference": c6_reference,
        "coord_num_ref": coord_num_ref,
        "a1": a1,
        "a2": a2,
        "s8": s8,
        "energy": energy,
        "forces": forces,
        "coord_num": coord_num,
        "virial": virial,
        "k1": k1,
        "k3": k3,
        "s6": s6,
        "s5_smoothing_on": s5_smoothing_on,
        "s5_smoothing_off": s5_smoothing_off,
        "batch_idx": batch_idx,
        "device": device,
    }

    # Dispatch to appropriate implementation based on neighbor format and PBC
    if neighbor_matrix is not None:
        if cell is not None and neighbor_matrix_shifts is not None:
            # Matrix format, PBC variant
            _dftd3_matrix_pbc_op(
                neighbor_matrix=neighbor_matrix,
                cell=cell,
                neighbor_matrix_shifts=neighbor_matrix_shifts,
                fill_value=fill_value,
                compute_virial=compute_virial,
                **shared_kwargs,
            )
        else:
            # Matrix format, non-PBC variant
            _dftd3_matrix_op(
                neighbor_matrix=neighbor_matrix,
                fill_value=fill_value,
                **shared_kwargs,
            )
    else:
        # List format: row 1 of the [2, num_pairs] neighbor_list holds the
        # destination atoms, used together with the neighbor_ptr row pointers.
        idx_j_csr = neighbor_list[1]
        if cell is not None and unit_shifts is not None:
            # List format, PBC variant
            _dftd3_pbc_op(
                idx_j=idx_j_csr,
                neighbor_ptr=neighbor_ptr,
                cell=cell,
                unit_shifts=unit_shifts,
                compute_virial=compute_virial,
                **shared_kwargs,
            )
        else:
            # List format, non-PBC variant
            _dftd3_op(
                idx_j=idx_j_csr,
                neighbor_ptr=neighbor_ptr,
                **shared_kwargs,
            )

    if compute_virial:
        return energy, forces, coord_num, virial
    return energy, forces, coord_num