JAX DFT-D3 Dispersion Correction for a Molecule
This example demonstrates how to compute the DFT-D3 dispersion energy and forces for a single molecular system using the JAX API with GPU-accelerated Warp kernels.
The DFT-D3 method adds London dispersion corrections to standard DFT calculations; these corrections are essential for accurately modeling non-covalent interactions. This implementation uses environment-dependent C6 coefficients and Becke-Johnson damping (D3-BJ).
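For orientation, the two-body D3(BJ) energy is usually written as below (Grimme et al.). This is the textbook form, not a statement of exactly what the nvalchemiops kernels evaluate; for example, whether a three-body term is included is not covered by this example. The symbols s6, s8, a1 and a2 are the functional parameters passed to dftd3 later in the script.

```latex
E_{\mathrm{disp}} = -\frac{1}{2} \sum_{i} \sum_{j \neq i}
  \left[ \frac{s_6\, C_6^{ij}}{r_{ij}^{6} + f_{ij}^{6}}
       + \frac{s_8\, C_8^{ij}}{r_{ij}^{8} + f_{ij}^{8}} \right],
\qquad
f_{ij} = a_1 \sqrt{\frac{C_8^{ij}}{C_6^{ij}}} + a_2,
\qquad
C_8^{ij} = 3\, C_6^{ij} \sqrt{Q_i Q_j}
```

Here C_6^{ij} is interpolated from tabulated reference values according to the coordination numbers of atoms i and j, Q is derived from atomic ⟨r⁴⟩/⟨r²⟩ expectation values, and the factor 1/2 compensates for counting each pair twice. Because the damping radius f_ij is finite, each pair term stays bounded as r_ij approaches zero.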
In this example you will learn:
How to load DFT-D3 parameters and convert them for the JAX API
Loading molecular coordinates from an XYZ file into JAX arrays
Computing neighbor lists for non-periodic systems using the JAX API
Calculating dispersion energies and forces with the JAX DFT-D3 function
jax.jit compilation of the full neighbor list + DFT-D3 pipeline
Important
This script is intended as an API demonstration. Do not use this script for performance benchmarking; refer to the benchmarks folder instead.
Setup and Parameter Loading
First, we need to import the necessary modules and load the DFT-D3 parameters. The parameters contain element-specific C6 coefficients and radii that are used in the dispersion energy calculation.
from __future__ import annotations
import os
import sys
from pathlib import Path
try:
    import jax
    import jax.numpy as jnp
except ImportError:
    print(
        "This example requires JAX. Install with: pip install 'nvalchemi-toolkit-ops[jax]'"
    )
    sys.exit(0)
import numpy as np
import torch
try:
    from nvalchemiops.jax.interactions.dispersion import D3Parameters, dftd3
    from nvalchemiops.jax.neighbors import neighbor_list
    from nvalchemiops.jax.neighbors.naive import naive_neighbor_list
except Exception as exc:
    print(
        f"JAX/Warp backend unavailable ({exc}). This example requires a CUDA-backed runtime."
    )
    sys.exit(0)
# Unit conversion constants (CODATA 2022)
BOHR_TO_ANGSTROM = 0.529177210544
HARTREE_TO_EV = 27.211386245981
ANGSTROM_TO_BOHR = 1.0 / BOHR_TO_ANGSTROM
# Check for cached parameters, download if needed
param_file = (
    Path(os.path.expanduser("~")) / ".cache" / "nvalchemiops" / "dftd3_parameters.pt"
)
if not param_file.exists():
    print("Downloading DFT-D3 parameters...")
    try:
        _import_dir = str(Path(__file__).parent)
    except NameError:
        _import_dir = str(Path.cwd())
    sys.path.insert(0, _import_dir)
    from utils import extract_dftd3_parameters, save_dftd3_parameters
    params_torch = extract_dftd3_parameters()
    save_dftd3_parameters(params_torch)
else:
    params_torch = torch.load(param_file, weights_only=True)
    print("Loaded cached DFT-D3 parameters")
# Convert PyTorch tensors to JAX arrays
d3_params = D3Parameters(
    rcov=jnp.array(params_torch["rcov"].numpy(), dtype=jnp.float32),
    r4r2=jnp.array(params_torch["r4r2"].numpy(), dtype=jnp.float32),
    c6ab=jnp.array(params_torch["c6ab"].numpy(), dtype=jnp.float32),
    cn_ref=jnp.array(params_torch["cn_ref"].numpy(), dtype=jnp.float32),
)
print(f"Loaded D3 parameters for elements 1-{d3_params.max_z}")
Loaded cached DFT-D3 parameters
Loaded D3 parameters for elements 1-94
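The covalent radii (rcov) enter the calculation through the coordination numbers, which in turn select the environment-dependent C6 values. A minimal NumPy sketch of the standard D3 counting function follows, assuming the constants k1 = 16 and k2 = 4/3 from the original D3 paper. The toolkit's rcov table may already include the 4/3 scaling (the reference Fortran code stores pre-scaled radii), which this example does not state, so the sketch exposes k2 as an argument. The radii in the usage lines are hypothetical placeholders, not D3 reference values.

```python
import numpy as np


def coordination_numbers(positions, numbers, rcov, k1=16.0, k2=4.0 / 3.0):
    """Fractional D3-style coordination numbers (all pairs, no cutoff).

    positions: (N, 3) Cartesian coordinates in Bohr.
    numbers:   (N,) atomic numbers, used to index rcov.
    rcov:      covalent radii in Bohr, indexed by atomic number.
    k2:        radius scaling; pass k2=1.0 if rcov is already scaled.
    """
    n = positions.shape[0]
    r = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
    rc = rcov[numbers]
    rcov_sum = k2 * (rc[:, None] + rc[None, :])
    off_diag = ~np.eye(n, dtype=bool)  # exclude self-interaction
    safe_r = np.where(off_diag, r, 1.0)  # avoid 0/0 on the diagonal
    # Smooth step: close to 1 when r << rcov_sum, close to 0 when r >> rcov_sum
    count = 1.0 / (1.0 + np.exp(-k1 * (rcov_sum / safe_r - 1.0)))
    return np.sum(np.where(off_diag, count, 0.0), axis=1)


# Usage with hypothetical radii (index 0 unused, index 1 stands in for H)
rcov_demo = np.array([0.0, 0.6])
h2 = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.4]])
print(coordination_numbers(h2, np.array([1, 1]), rcov_demo))
```

A bonded hydrogen pair gets a coordination number close to 1 each, matching the roughly 1.0 values reported for hydrogen atoms later in this example, while distant atoms contribute almost nothing.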
Load Molecular Structure
We’ll load a molecular dimer from an XYZ file. This is a simple text format: the first line contains the number of atoms, the second line is a comment, and each subsequent line contains an element symbol followed by the x, y, z coordinates (in Angstrom).
try:
    _script_dir = Path(__file__).parent
except NameError:
    _script_dir = Path.cwd()
xyz_file = _script_dir / "dimer.xyz"
with open(xyz_file) as f:
    lines = f.readlines()
num_atoms = int(lines[0])
coords_angstrom = np.zeros((num_atoms, 3), dtype=np.float32)
atomic_numbers_np = np.zeros(num_atoms, dtype=np.int32)
# Elements present in this dimer; an unknown symbol raises a KeyError
# instead of being silently treated as hydrogen
symbol_to_z = {"H": 1, "C": 6}
# Skip the count and comment lines and read exactly num_atoms atom lines,
# so trailing blank lines in the file do not break parsing
for i, line in enumerate(lines[2 : 2 + num_atoms]):
    parts = line.split()
    symbol = parts[0]
    # Map element symbols to atomic numbers
    atomic_numbers_np[i] = symbol_to_z[symbol]
    # Store coordinates (in Angstrom)
    coords_angstrom[i, 0] = float(parts[1])
    coords_angstrom[i, 1] = float(parts[2])
    coords_angstrom[i, 2] = float(parts[3])
# Convert to JAX arrays
coords_angstrom_jax = jnp.array(coords_angstrom)
numbers = jnp.array(atomic_numbers_np, dtype=jnp.int32)
# Convert coordinates to Bohr for DFT-D3 calculation
positions_bohr = coords_angstrom_jax * ANGSTROM_TO_BOHR
print(f"Loaded molecule with {num_atoms} atoms")
print(f"Coordinates shape: {positions_bohr.shape}")
Loaded molecule with 36 atoms
Coordinates shape: (36, 3)
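A slightly more general sketch reads XYZ text through a symbol table and checks that the file contains as many atom lines as its header promises. The table SYMBOL_TO_Z and the function read_xyz are hypothetical helpers for illustration, not part of nvalchemiops; extend the table for the elements in your own structures.

```python
import numpy as np

# Hypothetical, deliberately small symbol table; extend as needed.
SYMBOL_TO_Z = {"H": 1, "C": 6, "N": 7, "O": 8}


def read_xyz(text):
    """Parse XYZ text into (atomic_numbers[int32], coords_angstrom[float32])."""
    lines = text.splitlines()
    num_atoms = int(lines[0])
    # Skip the count and comment lines; ignore anything after the atom block
    body = lines[2 : 2 + num_atoms]
    if len(body) != num_atoms:
        raise ValueError(f"expected {num_atoms} atom lines, found {len(body)}")
    numbers = np.zeros(num_atoms, dtype=np.int32)
    coords = np.zeros((num_atoms, 3), dtype=np.float32)
    for i, line in enumerate(body):
        parts = line.split()
        symbol = parts[0].capitalize()  # accept "c", "C", "CL" -> "Cl", ...
        if symbol not in SYMBOL_TO_Z:
            raise KeyError(f"unknown element symbol: {parts[0]}")
        numbers[i] = SYMBOL_TO_Z[symbol]
        coords[i] = [float(x) for x in parts[1:4]]
    return numbers, coords


numbers_demo, coords_demo = read_xyz("2\nexample\nO 0.0 0.0 0.0\nH 0.0 0.0 0.96\n\n")
print(numbers_demo, coords_demo.shape)
```

The returned arrays can be passed to jnp.array exactly like atomic_numbers_np and coords_angstrom in the script.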
Compute Neighbor List
The DFT-D3 calculation requires knowing which atoms are within interaction range of each other. We use the GPU-accelerated neighbor list from nvalchemiops.
For a non-periodic (molecular) system, no simulation cell or periodic boundary conditions are needed: the naive method checks atom pairs directly against the cutoff.
# For a non-periodic (molecular) system, we simply compute pairwise distances
# without periodic boundary conditions.
# Cutoff of 20 Angstrom in Bohr
cutoff_bohr = 20.0 * ANGSTROM_TO_BOHR
# Compute neighbor list using naive method (better for small non-periodic systems)
# The cell_list method requires cell/pbc even for non-periodic systems
neighbor_matrix, num_neighbors_per_atom = neighbor_list(
    positions_bohr,
    cutoff=cutoff_bohr,
    method="naive",
    max_neighbors=64,
)
print(f"Neighbor matrix shape: {neighbor_matrix.shape}")
print(f"Average neighbors per atom: {float(jnp.mean(num_neighbors_per_atom)):.1f}")
Neighbor matrix shape: (36, 64)
Average neighbors per atom: 35.0
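An average of 35.0 neighbors per atom means that, with the 20 Angstrom cutoff, every one of the 36 atoms sees all 35 others, so each row of the 64-wide neighbor matrix has unused slots. The output does not show how those slots are filled; the dftd3 call below passes fill_value=num_atoms, which suggests they are padded with the out-of-range index num_atoms. A minimal NumPy sketch of a dense, all-pairs neighbor matrix under that assumption follows. The function name is hypothetical, and this sketch raises an error when max_neighbors is too small; the toolkit's own overflow behavior is not shown in this example.

```python
import numpy as np


def naive_neighbor_matrix(positions, cutoff, max_neighbors):
    """All-pairs neighbor matrix padded with fill_value = N (assumed convention).

    Returns (nbmat, num_neighbors): nbmat has shape (N, max_neighbors) and
    row i lists the indices j != i with |r_i - r_j| < cutoff.
    """
    n = positions.shape[0]
    fill_value = n  # out-of-range index marks padded slots
    dist = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
    within = (dist < cutoff) & ~np.eye(n, dtype=bool)
    num_neighbors = within.sum(axis=1).astype(np.int32)
    if np.any(num_neighbors > max_neighbors):
        raise ValueError("max_neighbors is too small for this cutoff")
    nbmat = np.full((n, max_neighbors), fill_value, dtype=np.int32)
    for i in range(n):
        idx = np.flatnonzero(within[i])
        nbmat[i, : idx.size] = idx
    return nbmat, num_neighbors


line_of_atoms = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
nbmat_demo, nn_demo = naive_neighbor_matrix(line_of_atoms, cutoff=2.0, max_neighbors=4)
print(nbmat_demo)
print(nn_demo)
```

The fixed width max_neighbors is what gives the matrix a static shape, which matters again in the JIT section.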
Calculate Dispersion Energy and Forces
Now we can compute the DFT-D3 dispersion correction. The function returns:
energy: total dispersion energy, shape [num_systems], in Hartree
forces: atomic forces, shape [num_atoms, 3], in Hartree/Bohr
coord_num: coordination numbers, shape [num_atoms] (dimensionless)
We use the PBE0 functional parameters:
s6 = 1.0 (C6 term coefficient, standard for D3-BJ)
s8 = 1.2177 (C8 term coefficient, PBE0-specific)
a1 = 0.4145 (BJ damping parameter, PBE0-specific)
a2 = 4.8593 (BJ damping radius, PBE0-specific)
energy, forces, coord_num = dftd3(
    positions=positions_bohr,
    numbers=numbers,
    a1=0.4145,
    a2=4.8593,
    s8=1.2177,
    s6=1.0,
    d3_params=d3_params,
    neighbor_matrix=neighbor_matrix,
    fill_value=num_atoms,
)
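To make the roles of neighbor_matrix and fill_value concrete, here is a simplified NumPy reference for the two-body BJ-damped energy summed over a padded neighbor matrix. It is a sketch under loud assumptions, not the toolkit's kernel: a single fixed C6 for every pair (real D3 interpolates C6 from coordination numbers using c6ab and cn_ref), C8 = 3 * C6 * r4r2_i * r4r2_j, and a full neighbor matrix in which each pair appears twice (consistent with the 35 neighbors per atom reported above). The C6 and r4r2 values in the usage lines are made up for illustration.

```python
import numpy as np


def bj_energy_from_neighbor_matrix(positions, nbmat, fill_value, c6, r4r2,
                                   s6, s8, a1, a2):
    """Two-body D3(BJ) energy in Hartree with a fixed C6 (no CN dependence).

    positions: (N, 3) in Bohr; nbmat: (N, M) neighbor indices padded with
    fill_value; c6: scalar C6 (hypothetical); r4r2: (N,) per-atom values
    assumed to satisfy C8_ij = 3 * C6 * r4r2_i * r4r2_j.
    """
    valid = nbmat != fill_value            # padded slots are ignored
    j = np.where(valid, nbmat, 0)          # safe index for padded slots
    r = np.linalg.norm(positions[:, None, :] - positions[j], axis=-1)
    c8 = 3.0 * c6 * r4r2[:, None] * r4r2[j]
    f = a1 * np.sqrt(c8 / c6) + a2         # BJ damping radius
    e_pair = -(s6 * c6 / (r**6 + f**6) + s8 * c8 / (r**8 + f**8))
    # Each pair appears twice in a full neighbor matrix (i->j and j->i)
    return 0.5 * float(np.sum(np.where(valid, e_pair, 0.0)))


pos_demo = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 3.0]])
nbmat_pair = np.array([[1, 2], [0, 2]])    # fill_value = 2 = number of atoms
e_demo = bj_energy_from_neighbor_matrix(
    pos_demo, nbmat_pair, fill_value=2, c6=10.0, r4r2=np.array([1.5, 1.5]),
    s6=1.0, s8=1.2177, a1=0.4145, a2=4.8593,
)
print(e_demo)
```

Masking with fill_value rather than truncating rows keeps every array at the fixed shape (N, max_neighbors), which is the same property the JIT section relies on.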
Results
Convert outputs to conventional units for display:
Energy: Hartree -> eV
Forces: Hartree/Bohr -> eV/Angstrom
# Convert energy to eV
energy_ev = float(energy[0]) * HARTREE_TO_EV
# Convert forces to eV/Angstrom
forces_ev_angstrom = forces * (HARTREE_TO_EV / BOHR_TO_ANGSTROM)
max_force = float(jnp.max(jnp.linalg.norm(forces_ev_angstrom, axis=1)))
print(f"\nDispersion Energy: {energy_ev:.6f} eV")
print(f"Energy per atom: {energy_ev / num_atoms:.6f} eV")
print(f"Maximum force magnitude: {max_force:.6f} eV/Angstrom")
print(f"\nCoordination numbers: {np.array(coord_num)}")
Dispersion Energy: -1.340935 eV
Energy per atom: -0.037248 eV
Maximum force magnitude: 0.047128 eV/Angstrom
Coordination numbers: [3.2349308 3.2337396 3.2349308 3.2771623 3.1036298 3.2771626 3.0788736
3.1036294 3.0788739 3.0804553 1.0007788 1.0004165 1.0007788 1.0047569
1.004581 1.0047569 1.0045811 1.0048931 3.2349308 3.2337396 3.2349308
3.2771623 3.1036296 3.2771626 3.0788736 3.1036294 3.0788739 3.0804553
1.0007789 1.0004165 1.0007789 1.004757 1.004581 1.004757 1.004581
1.0048932]
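The force conversion factor is HARTREE_TO_EV / BOHR_TO_ANGSTROM, roughly 51.42 eV/Angstrom per Hartree/Bohr. Because the dispersion energy depends only on interatomic distances, it is unchanged by a rigid translation of the molecule, so the forces on an isolated molecule should sum to zero up to round-off; this makes a cheap sanity check. The helper force_summary below is a hypothetical illustration, not part of nvalchemiops.

```python
import numpy as np

BOHR_TO_ANGSTROM = 0.529177210544  # CODATA 2022, as in the script
HARTREE_TO_EV = 27.211386245981    # CODATA 2022, as in the script
# 1 Hartree/Bohr is roughly 51.42 eV/Angstrom
HARTREE_PER_BOHR_TO_EV_PER_ANGSTROM = HARTREE_TO_EV / BOHR_TO_ANGSTROM


def force_summary(forces_hartree_per_bohr):
    """Return (max per-atom |F|, |sum of F|), both in eV/Angstrom."""
    f = np.asarray(forces_hartree_per_bohr, dtype=np.float64)
    f = f * HARTREE_PER_BOHR_TO_EV_PER_ANGSTROM
    max_norm = float(np.max(np.linalg.norm(f, axis=1)))
    net = float(np.linalg.norm(f.sum(axis=0)))
    return max_norm, net


demo_forces = np.array([[3e-4, 4e-4, 0.0], [-3e-4, -4e-4, 0.0]])
print(force_summary(demo_forces))
```

Applied to the script's output, for example force_summary(np.asarray(forces)), the first value should match the maximum force printed above up to float32 round-off, and the second should be close to zero.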
JIT Compilation
Next, we combine the neighbor list build and the DFT-D3 calculation into a
single jax.jit-compiled function, so the whole pipeline is traced and compiled
as one XLA computation.
For JIT compatibility:
max_neighbors must be specified (static array shapes)
Functional parameters (a1, a2, s8, etc.) must be static literals inside the jitted function (required by the Warp FFI kernels)
D3Parameters should be constructed inside the jitted function from traced arrays
fill_value and num_systems should be static
print("\n" + "=" * 70)
print("JIT COMPILATION")
print("=" * 70)
# Define a jitted function that builds neighbors and computes DFT-D3
@jax.jit
def compute_d3_energy_forces(
    positions: jax.Array,
    numbers: jax.Array,
    rcov: jax.Array,
    r4r2: jax.Array,
    c6ab: jax.Array,
    cn_ref: jax.Array,
    cutoff: float = cutoff_bohr,
    max_neighbors: int = 64,
    fill_value: int = num_atoms,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """JIT-compiled neighbor list + DFT-D3 pipeline."""
    # Defaults that are not passed at call time are not traced by jax.jit, so
    # cutoff, max_neighbors and fill_value act as compile-time constants here.
    # Passing them explicitly would turn them into traced values unless they
    # are declared static (e.g. via static_argnames).
    # Build neighbor matrix inside jit (max_neighbors must be static)
    nbmat, _ = naive_neighbor_list(positions, cutoff, max_neighbors=max_neighbors)
    # Construct D3Parameters inside jit from traced arrays
    params = D3Parameters(rcov=rcov, r4r2=r4r2, c6ab=c6ab, cn_ref=cn_ref)
    # Compute DFT-D3 with PBE0 parameters as static literals
    energy, forces, coord_num = dftd3(
        positions=positions,
        numbers=numbers,
        a1=0.4145,
        a2=4.8593,
        s8=1.2177,
        s6=1.0,
        d3_params=params,
        neighbor_matrix=nbmat,
        fill_value=fill_value,
    )
    return energy, forces, coord_num
======================================================================
JIT COMPILATION
======================================================================
Run the jitted function:
print("\nCompiling and running jitted DFT-D3 pipeline...")
jit_energy, jit_forces, jit_cn = compute_d3_energy_forces(
    positions_bohr,
    numbers,
    d3_params.rcov,
    d3_params.r4r2,
    d3_params.c6ab,
    d3_params.cn_ref,
)
jit_energy_ev = float(jit_energy[0]) * HARTREE_TO_EV
jit_forces_ev = jit_forces * (HARTREE_TO_EV / BOHR_TO_ANGSTROM)
jit_max_force = float(jnp.max(jnp.linalg.norm(jit_forces_ev, axis=1)))
print(f" JIT dispersion energy: {jit_energy_ev:.6f} eV")
print(f" JIT max force: {jit_max_force:.6f} eV/Angstrom")
# Compare with non-jitted result
energy_diff = abs(jit_energy_ev - energy_ev)
print(f" Energy difference vs non-jitted: {energy_diff:.2e} eV")
# Second call should be fast (already compiled)
print("\nRunning jitted function again (should reuse compiled code)...")
jit_energy_2, jit_forces_2, _ = compute_d3_energy_forces(
    positions_bohr,
    numbers,
    d3_params.rcov,
    d3_params.r4r2,
    d3_params.c6ab,
    d3_params.cn_ref,
)
print(
    f" JIT dispersion energy (2nd call): {float(jit_energy_2[0]) * HARTREE_TO_EV:.6f} eV"
)
Compiling and running jitted DFT-D3 pipeline...
JIT dispersion energy: -1.340935 eV
JIT max force: 0.047128 eV/Angstrom
Energy difference vs non-jitted: 0.00e+00 eV
Running jitted function again (should reuse compiled code)...
JIT dispersion energy (2nd call): -1.340935 eV
Summary
This example demonstrated:
Parameter loading from cached DFT-D3 reference data (Grimme group)
Molecular structure loading from XYZ files into JAX arrays
Neighbor list construction for non-periodic systems
DFT-D3 energy and forces with PBE0 functional parameters
Unit conversions between atomic (Bohr/Hartree) and conventional (Angstrom/eV) units
JIT compilation of the full neighbor list + DFT-D3 pipeline
Key JAX-specific patterns:
Load PyTorch parameters and convert them to JAX arrays via jnp.array
Construct D3Parameters from JAX arrays
For jax.jit: use static literals for functional parameters (a1, a2, s8), specify max_neighbors for static shapes, and construct D3Parameters inside the jitted function
print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)
print("\nKey takeaways:")
print(" - DFT-D3 works in atomic units (Bohr, Hartree) internally")
print(" - Convert Angstrom -> Bohr for positions, Hartree -> eV for energy")
print(" - D3Parameters holds element-specific reference data")
print(" - PBE0 parameters: a1=0.4145, a2=4.8593, s8=1.2177, s6=1.0")
print(" - Use jax.jit to fuse neighbor list + DFT-D3 into one compiled function")
print("\nJAX DFT-D3 example completed successfully!")
======================================================================
SUMMARY
======================================================================
Key takeaways:
- DFT-D3 works in atomic units (Bohr, Hartree) internally
- Convert Angstrom -> Bohr for positions, Hartree -> eV for energy
- D3Parameters holds element-specific reference data
- PBE0 parameters: a1=0.4145, a2=4.8593, s8=1.2177, s6=1.0
- Use jax.jit to fuse neighbor list + DFT-D3 into one compiled function
JAX DFT-D3 example completed successfully!
Total running time of the script: (0 minutes 1.164 seconds)