Note
Go to the end to download the full example code.
DFT-D3 Dispersion Correction for Batched Crystals#
This example demonstrates GPU-accelerated DFT-D3 calculations for batches of periodic crystal structures. Batch processing is essential for high-throughput computational materials screening and allows efficient parallelization across multiple structures simultaneously.
We’ll load a crystal structure from the Crystallography Open Database (COD), replicate it to create a batch, and compute dispersion energies and forces for all structures at once using GPU acceleration.
Key concepts covered:
Loading crystal structures with periodic boundary conditions
Converting between fractional and Cartesian coordinates
Batch processing multiple structures efficiently
Using batch-aware neighbor lists
Working with atomic units (Bohr, Hartree) internally
Note
This example assumes 01_dftd3_molecule.py has been run first to cache
the DFT-D3 parameters.
Important
This script is intended as an API demonstration. Do not use this script for performance benchmarking; refer to the benchmarks folder instead.
Setup and Configuration#
First, we’ll import the necessary modules and set up our batch parameters. The low-level D3 kernels use atomic units internally (Bohr for distances, Hartree for energies), so we’ll need conversion constants.
from __future__ import annotations
import json
import math
import os
from pathlib import Path
import torch
import warp as wp
from utils import (
BOHR_TO_ANGSTROM, # Conversion: 1 Bohr = 0.529177 Angstrom
HARTREE_TO_EV, # Conversion: 1 Hartree = 27.2114 eV
extract_dftd3_parameters,
load_d3_parameters,
save_dftd3_parameters,
)
from nvalchemiops.interactions.dispersion.dftd3 import (
dftd3,
)
from nvalchemiops.neighborlist import (
estimate_max_neighbors,
neighbor_list,
)
# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float32
BATCH_SIZE = 8 # Number of structures to process simultaneously
CUTOFF = 20.0 # Interaction cutoff in Angstrom
print(f"Running on device: {DEVICE}")
print(f"Batch size: {BATCH_SIZE} structures")
print(f"Cutoff radius: {CUTOFF} Angstrom")
Running on device: cuda
Batch size: 8 structures
Cutoff radius: 20.0 Angstrom
Load DFT-D3 Parameters#
Load the cached DFT-D3 parameters that contain element-specific C6 coefficients and covalent radii. These were downloaded in example 01.
cache_dir = Path(os.path.expanduser("~")) / ".cache" / "nvalchemiops"
param_file = cache_dir / "dftd3_parameters.pt"
if not param_file.exists():
print("Generating DFT-D3 parameters...")
params = extract_dftd3_parameters()
save_dftd3_parameters(params)
else:
print("Using cached DFT-D3 parameters")
d3_params = load_d3_parameters(param_file)
# Set random seed for reproducibility
torch.manual_seed(42)
torch.set_default_device(DEVICE)
torch.set_default_dtype(DTYPE)
Using cached DFT-D3 parameters
Utility Function: Lattice Parameters to Matrix#
Crystals are often described by lattice parameters (lengths and angles). We need to convert these to a matrix representation for calculations.
def lattice_params_to_matrix(
a: float, b: float, c: float, alpha: float, beta: float, gamma: float
) -> torch.Tensor:
"""
Convert lattice parameters to a 3x3 lattice matrix.
Parameters
----------
a, b, c : float
Lattice lengths in Angstroms
alpha, beta, gamma : float
Lattice angles in degrees
Returns
-------
torch.Tensor, shape (3, 3)
Lattice matrix where each row is a lattice vector
"""
# Convert angles to radians
alpha_rad = math.radians(alpha)
beta_rad = math.radians(beta)
gamma_rad = math.radians(gamma)
# Compute lattice vectors using standard crystallographic convention
ax = a
ay = 0.0
az = 0.0
bx = b * math.cos(gamma_rad)
by = b * math.sin(gamma_rad)
bz = 0.0
cx = c * math.cos(beta_rad)
cy = (
c
* (math.cos(alpha_rad) - math.cos(beta_rad) * math.cos(gamma_rad))
/ math.sin(gamma_rad)
)
cz = math.sqrt(c**2 - cx**2 - cy**2)
lattice_matrix = torch.tensor([[ax, ay, az], [bx, by, bz], [cx, cy, cz]])
return lattice_matrix
Load Crystal Structure from COD#
We’ll load a crystal structure from a JSON file downloaded from the Crystallography Open Database (COD). The structure contains:
Lattice parameters (a, b, c, alpha, beta, gamma)
Atomic positions in fractional coordinates
Element symbols and atomic numbers
with open("./4300813.json") as read_file:
data = json.load(read_file)
num_atoms = len(data["atoms"])
print("\nLoaded crystal structure COD ID 4300813")
print(f"Number of atoms in unit cell: {num_atoms}")
print("Lattice parameters:")
print(f" a = {data['unit_cell']['a']:.3f} Angstrom")
print(f" b = {data['unit_cell']['b']:.3f} Angstrom")
print(f" c = {data['unit_cell']['c']:.3f} Angstrom")
print(f" alpha = {data['unit_cell']['alpha']:.1f}°")
print(f" beta = {data['unit_cell']['beta']:.1f}°")
print(f" gamma = {data['unit_cell']['gamma']:.1f}°")
# Convert lattice parameters to matrix
cell = lattice_params_to_matrix(
data["unit_cell"]["a"],
data["unit_cell"]["b"],
data["unit_cell"]["c"],
data["unit_cell"]["alpha"],
data["unit_cell"]["beta"],
data["unit_cell"]["gamma"],
).unsqueeze(0) # Add batch dimension
Loaded crystal structure COD ID 4300813
Number of atoms in unit cell: 56
Lattice parameters:
a = 18.762 Angstrom
b = 13.720 Angstrom
c = 32.702 Angstrom
alpha = 90.0°
beta = 98.0°
gamma = 90.0°
Extract Atomic Positions and Numbers#
Crystal structures are typically given in fractional coordinates (relative to lattice vectors). We’ll convert these to Cartesian coordinates for calculations.
fractional = torch.zeros((num_atoms, 3), device=DEVICE, dtype=DTYPE)
numbers = torch.zeros((num_atoms), device=DEVICE, dtype=torch.int32)
atomic_number_mapping = data["atomic_numbers"]
for i, atom in enumerate(data["atoms"]):
fractional[i, :] = torch.tensor(
[atom["x"], atom["y"], atom["z"]], device=DEVICE, dtype=DTYPE
)
numbers[i] = atomic_number_mapping[atom["element"]]
# Set periodic boundary conditions (all three directions)
pbc = torch.tensor([True, True, True], device=DEVICE, dtype=torch.bool).unsqueeze(0)
# Convert fractional to Cartesian coordinates: r_cart = fractional @ cell^T
cartesian = fractional @ cell[0].T
print("Converted fractional to Cartesian coordinates")
print(f"Positions shape: {cartesian.shape}")
Converted fractional to Cartesian coordinates
Positions shape: torch.Size([56, 3])
Create Batch of Structures#
To demonstrate batch processing, we’ll replicate the crystal structure multiple times. In practice, each structure in the batch could be:
Different crystal polymorphs
Structures at different temperatures
Molecular dynamics snapshots
High-throughput screening candidates
For this example, we’ll use the same structure replicated with small random perturbations.
# Replicate positions for batch
positions_batch = cartesian.repeat(BATCH_SIZE, 1)
# Add small random noise to distinguish structures (optional)
positions_batch = torch.normal(positions_batch, 0.00).to(DEVICE)
# Convert from Angstrom to Bohr (atomic units) for DFT-D3 kernels
positions_bohr = positions_batch / BOHR_TO_ANGSTROM
# Replicate atomic numbers
numbers_batch = numbers.repeat(BATCH_SIZE)
print(f"\nCreated batch of {BATCH_SIZE} structures")
print(f"Total atoms in batch: {len(numbers_batch)}")
print(f"Positions shape: {positions_batch.shape}")
Created batch of 8 structures
Total atoms in batch: 448
Positions shape: torch.Size([448, 3])
Set Up Batch Indexing#
For batch processing, we need to tell the code which atoms belong to which structure. We use two tensors for this:
ptr: Pointer array marking structure boundaries [0, N, 2N, …, BATCH_SIZE*N]batch_idx: Per-atom array indicating structure membership [0,0,…,1,1,…,2,2,…]
ptr = torch.arange(
0, (BATCH_SIZE + 1) * num_atoms, num_atoms, device=DEVICE, dtype=torch.int32
)
batch_idx = torch.repeat_interleave(
torch.arange(BATCH_SIZE, dtype=torch.int32, device=DEVICE), num_atoms
)
# Replicate cell and PBC for each structure in batch
# Convert cell to Bohr for consistency with positions
cell_batch = (cell / BOHR_TO_ANGSTROM).repeat(BATCH_SIZE, 1, 1)
pbc_batch = pbc.repeat(BATCH_SIZE, 1)
print("Batch indexing:")
print(f" ptr shape: {ptr.shape}")
print(f" batch_idx shape: {batch_idx.shape}")
Batch indexing:
ptr shape: torch.Size([9])
batch_idx shape: torch.Size([448])
Compute Batch Neighbor Lists#
The neighbor list efficiently identifies which atoms are within interaction
range. For batches, we use the batch-aware batch_cell_list method which
processes all structures simultaneously while respecting batch boundaries.
Key features:
O(N) scaling per structure using spatial cell lists
Handles periodic boundary conditions automatically
Returns shifts for atoms interacting across periodic boundaries
# Estimate the number of neighbors per atom using 'sensible' parameters;
# knowledge of the system will improve the estimate and overallocating
# neighbor matrices. The example structure is relatively sparse.
cutoff_bohr = CUTOFF / BOHR_TO_ANGSTROM
max_neighbors = estimate_max_neighbors(
cutoff_bohr,
atomic_density=0.05, # the example structure is quite sparse
safety_factor=0.5,
)
# Build neighbor list (in Bohr to match positions)
# Note that we can choose to either use the `max_neighbors` value, or
# if you know the exact number of neighbors per atom for a given cutoff,
# you can use that value instead for optimal performance
neighbor_matrix, num_neighbors_per_atom, neighbor_matrix_shifts = neighbor_list(
positions_bohr,
cutoff_bohr,
cell_batch,
pbc_batch,
batch_idx=batch_idx,
batch_ptr=ptr,
max_neighbors=max_neighbors, # can specify an integer
method="batch_cell_list", # Batch-aware O(N) algorithm
)
print("\nNeighbor list computed:")
print(f" Neighbor matrix shape: {neighbor_matrix.shape}")
print(f" Average neighbors per atom: {num_neighbors_per_atom.float().mean():.1f}")
print(f" Max neighbors: {num_neighbors_per_atom.max()}")
print(f" Estimated max neighbors: {max_neighbors}")
if num_neighbors_per_atom.max() > max_neighbors:
print(
f" WARNING: Actual max neighbors ({num_neighbors_per_atom.max()})"
f" exceeds estimated max neighbors ({max_neighbors})"
)
elif num_neighbors_per_atom.max() * 0.5 < max_neighbors:
print(
f" WARNING: Actual max neighbors ({num_neighbors_per_atom.max()})"
f" is less than 50% of estimated max neighbors ({max_neighbors})"
)
Neighbor list computed:
Neighbor matrix shape: torch.Size([448, 5664])
Average neighbors per atom: 249.3
Max neighbors: 279
Estimated max neighbors: 5664
WARNING: Actual max neighbors (279) is less than 50% of estimated max neighbors (5664)
Calculate DFT-D3 Dispersion#
Now we can compute the DFT-D3 dispersion energy and forces for all structures in the batch simultaneously. The computation uses GPU-accelerated Warp kernels.
We use PBE0 functional parameters:
s6 = 1.0: Always 1.0 for D3-BJ dampings8 = 1.2177: PBE0-specific C8 scalinga1 = 0.4145: PBE0-specific BJ damping parametera2 = 4.8593: PBE0-specific BJ damping radius
Note
All inputs to the low-level kernels must be in atomic units (Bohr for distances, Hartree for energies). We’ll convert back to eV/Angstrom for output.
energies, forces, coord_num = dftd3(
positions=positions_bohr,
numbers=numbers_batch,
neighbor_matrix=neighbor_matrix,
d3_params=d3_params,
a1=0.4145, # BJ damping (PBE0)
a2=4.8593, # BJ damping (PBE0)
s6=1.0, # C6 coefficient
s8=1.2177, # C8 coefficient (PBE0)
cell=cell_batch,
neighbor_matrix_shifts=neighbor_matrix_shifts,
batch_idx=batch_idx,
device=positions_bohr.device.type,
)
wp.synchronize() # Ensure GPU computation completes
print("\nDFT-D3 calculation completed")
print(f"Computed energies and forces for {BATCH_SIZE} structures")
DFT-D3 calculation completed
Computed energies and forces for 8 structures
Results and Analysis#
Convert results from atomic units (Hartree, Bohr) to conventional units (eV, Angstrom) for easier interpretation.
energies_eV = energies * HARTREE_TO_EV
forces_eV_Ang = forces * HARTREE_TO_EV / BOHR_TO_ANGSTROM
print("\n" + "=" * 60)
print("DISPERSION ENERGY RESULTS")
print("=" * 60)
print("\nPer-structure dispersion energies (eV):")
for i, energy in enumerate(energies_eV):
print(f" Structure {i}: {energy:.6f} eV")
print("\nEnergy statistics:")
print(f" Mean: {energies_eV.mean():.6f} eV")
print(f" Std: {energies_eV.std():.6f} eV")
print(f" Min: {energies_eV.min():.6f} eV")
print(f" Max: {energies_eV.max():.6f} eV")
print("\n" + "=" * 60)
print("FORCE RESULTS (first 5 atoms)")
print("=" * 60)
print("\nForces (eV/Angstrom):")
print(forces_eV_Ang[:5])
max_force_magnitude = forces_eV_Ang.norm(dim=1).max()
print(f"\nMaximum force magnitude: {max_force_magnitude:.6f} eV/Angstrom")
============================================================
DISPERSION ENERGY RESULTS
============================================================
Per-structure dispersion energies (eV):
Structure 0: -3.201655 eV
Structure 1: -3.201655 eV
Structure 2: -3.201655 eV
Structure 3: -3.201655 eV
Structure 4: -3.201655 eV
Structure 5: -3.201655 eV
Structure 6: -3.201655 eV
Structure 7: -3.201655 eV
Energy statistics:
Mean: -3.201655 eV
Std: 0.000000 eV
Min: -3.201655 eV
Max: -3.201655 eV
============================================================
FORCE RESULTS (first 5 atoms)
============================================================
Forces (eV/Angstrom):
tensor([[-0.0288, -0.0057, 0.0306],
[-0.0235, 0.0053, 0.0186],
[-0.0145, -0.0015, 0.0153],
[ 0.0547, -0.0245, 0.0090],
[ 0.0527, 0.0056, 0.0253]], device='cuda:0')
Maximum force magnitude: 0.071464 eV/Angstrom
Total running time of the script: (0 minutes 0.212 seconds)