.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/dispersion/03_jax_dftd3_molecule.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_dispersion_03_jax_dftd3_molecule.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_dispersion_03_jax_dftd3_molecule.py:


JAX DFT-D3 Dispersion Correction for a Molecule
================================================

This example demonstrates how to compute the DFT-D3 dispersion energy and forces
for a single molecular system using the JAX API with GPU-accelerated Warp kernels.

The DFT-D3 method provides London dispersion corrections to standard DFT
calculations. These corrections are essential for accurately modeling
non-covalent interactions. This implementation uses environment-dependent C6
coefficients and includes Becke-Johnson damping (D3-BJ).

In this example you will learn how to:

- Load DFT-D3 parameters and convert them for the JAX API
- Load molecular coordinates from an XYZ file into JAX arrays
- Compute neighbor lists for non-periodic systems using the JAX API
- Calculate dispersion energies and forces with the JAX DFT-D3 function
- Compile the full neighbor list + DFT-D3 pipeline with ``jax.jit``

.. important::
    This script is intended as an API demonstration. Do not use this script
    for performance benchmarking; refer to the ``benchmarks`` folder instead.

.. GENERATED FROM PYTHON SOURCE LINES 43-48

Setup and Parameter Loading
----------------------------
First, we import the necessary modules and load the DFT-D3 parameters.
The parameters contain element-specific C6 coefficients and radii that are
used in the dispersion energy calculation. They are cached as PyTorch tensors,
so they are converted to JAX arrays before being wrapped in ``D3Parameters``.

.. GENERATED FROM PYTHON SOURCE LINES 48-111

.. code-block:: Python


    from __future__ import annotations

    import os
    import sys
    from pathlib import Path

    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        print(
            "This example requires JAX. Install with: pip install 'nvalchemi-toolkit-ops[jax]'"
        )
        sys.exit(0)

    import numpy as np
    import torch

    try:
        from nvalchemiops.jax.interactions.dispersion import D3Parameters, dftd3
        from nvalchemiops.jax.neighbors import neighbor_list
        from nvalchemiops.jax.neighbors.naive import naive_neighbor_list
    except Exception as exc:
        print(
            f"JAX/Warp backend unavailable ({exc}). This example requires a CUDA-backed runtime."
        )
        sys.exit(0)

    # Unit conversion constants (CODATA 2022)
    BOHR_TO_ANGSTROM = 0.529177210544
    HARTREE_TO_EV = 27.211386245981
    ANGSTROM_TO_BOHR = 1.0 / BOHR_TO_ANGSTROM

    # Check for cached parameters, download if needed
    param_file = (
        Path(os.path.expanduser("~")) / ".cache" / "nvalchemiops" / "dftd3_parameters.pt"
    )
    if not param_file.exists():
        print("Downloading DFT-D3 parameters...")
        try:
            _import_dir = str(Path(__file__).parent)
        except NameError:
            _import_dir = str(Path.cwd())
        sys.path.insert(0, _import_dir)
        from utils import extract_dftd3_parameters, save_dftd3_parameters

        params_torch = extract_dftd3_parameters()
        save_dftd3_parameters(params_torch)
    else:
        params_torch = torch.load(param_file, weights_only=True)
        print("Loaded cached DFT-D3 parameters")

    # Convert PyTorch tensors to JAX arrays
    d3_params = D3Parameters(
        rcov=jnp.array(params_torch["rcov"].numpy(), dtype=jnp.float32),
        r4r2=jnp.array(params_torch["r4r2"].numpy(), dtype=jnp.float32),
        c6ab=jnp.array(params_torch["c6ab"].numpy(), dtype=jnp.float32),
        cn_ref=jnp.array(params_torch["cn_ref"].numpy(), dtype=jnp.float32),
    )

    print(f"Loaded D3 parameters for elements 1-{d3_params.max_z}")





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Loaded cached DFT-D3 parameters
    Loaded D3 parameters for elements 1-94




.. GENERATED FROM PYTHON SOURCE LINES 112-117

Load Molecular Structure
------------------------
We'll load a molecular dimer from an XYZ file. This is a simple text format
where the first line contains the number of atoms, the second line is a comment,
and each subsequent line contains an element symbol followed by its x, y, z
coordinates (here in Angstrom).

.. GENERATED FROM PYTHON SOURCE LINES 117-154

.. code-block:: Python


    try:
        _script_dir = Path(__file__).parent
    except NameError:
        _script_dir = Path.cwd()

    xyz_file = _script_dir / "dimer.xyz"

    with open(xyz_file) as f:
        lines = f.readlines()
        num_atoms = int(lines[0])

    coords_angstrom = np.zeros((num_atoms, 3), dtype=np.float32)
    atomic_numbers_np = np.zeros(num_atoms, dtype=np.int32)

    # Read exactly num_atoms atom lines so trailing blank lines are ignored
    for i, line in enumerate(lines[2 : 2 + num_atoms]):
        parts = line.split()
        symbol = parts[0]

        # Map element symbols to atomic numbers
        # (this dimer contains only carbon and hydrogen)
        atomic_number = 6 if symbol == "C" else 1  # Carbon or Hydrogen
        atomic_numbers_np[i] = atomic_number

        # Store coordinates (in Angstrom)
        coords_angstrom[i, 0] = float(parts[1])
        coords_angstrom[i, 1] = float(parts[2])
        coords_angstrom[i, 2] = float(parts[3])

    # Convert to JAX arrays
    coords_angstrom_jax = jnp.array(coords_angstrom)
    numbers = jnp.array(atomic_numbers_np, dtype=jnp.int32)

    # Convert coordinates to Bohr for DFT-D3 calculation
    positions_bohr = coords_angstrom_jax * ANGSTROM_TO_BOHR

    print(f"Loaded molecule with {num_atoms} atoms")
    print(f"Coordinates shape: {positions_bohr.shape}")





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Loaded molecule with 36 atoms
    Coordinates shape: (36, 3)




.. GENERATED FROM PYTHON SOURCE LINES 155-162

Compute Neighbor List
---------------------
The DFT-D3 calculation requires knowing which atoms are within interaction range
of each other. We use the GPU-accelerated neighbor list from nvalchemiops.

For a non-periodic (molecular) system, we use the ``naive`` method, which
computes pairwise distances directly and needs no cell or periodic boundary
conditions (PBC). The ``cell_list`` method would require ``cell`` and ``pbc``
arguments even for a non-periodic system.

.. GENERATED FROM PYTHON SOURCE LINES 162-181

.. code-block:: Python


    # For a non-periodic (molecular) system, we simply compute pairwise distances
    # without periodic boundary conditions.

    # Cutoff of 20 Angstrom in Bohr
    cutoff_bohr = 20.0 * ANGSTROM_TO_BOHR

    # Compute neighbor list using naive method (better for small non-periodic systems)
    # The cell_list method requires cell/pbc even for non-periodic systems
    neighbor_matrix, num_neighbors_per_atom = neighbor_list(
        positions_bohr,
        cutoff=cutoff_bohr,
        method="naive",
        max_neighbors=64,
    )

    print(f"Neighbor matrix shape: {neighbor_matrix.shape}")
    print(f"Average neighbors per atom: {float(jnp.mean(num_neighbors_per_atom)):.1f}")





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Neighbor matrix shape: (36, 64)
    Average neighbors per atom: 35.0




.. GENERATED FROM PYTHON SOURCE LINES 182-195

Calculate Dispersion Energy and Forces
---------------------------------------
Now we can compute the DFT-D3 dispersion correction. The function returns:

- ``energy``: total dispersion energy ``[num_systems]`` in Hartree
- ``forces``: atomic forces ``[num_atoms, 3]`` in Hartree/Bohr
- ``coord_num``: coordination numbers ``[num_atoms]`` (dimensionless)

We use PBE0 functional parameters:

- ``s6 = 1.0`` (C6 term coefficient, standard for D3-BJ)
- ``s8 = 1.2177`` (C8 term coefficient, PBE0-specific)
- ``a1 = 0.4145`` (BJ damping parameter, PBE0-specific)
- ``a2 = 4.8593`` (BJ damping radius, PBE0-specific)

.. GENERATED FROM PYTHON SOURCE LINES 195-208

.. code-block:: Python


    energy, forces, coord_num = dftd3(
        positions=positions_bohr,
        numbers=numbers,
        a1=0.4145,
        a2=4.8593,
        s8=1.2177,
        s6=1.0,
        d3_params=d3_params,
        neighbor_matrix=neighbor_matrix,
        fill_value=num_atoms,
    )








.. GENERATED FROM PYTHON SOURCE LINES 209-214

Results
-------
Convert outputs to conventional units for display:

- Energy: Hartree -> eV
- Forces: Hartree/Bohr -> eV/Angstrom

.. GENERATED FROM PYTHON SOURCE LINES 214-227

.. code-block:: Python


    # Convert energy to eV
    energy_ev = float(energy[0]) * HARTREE_TO_EV

    # Convert forces to eV/Angstrom
    forces_ev_angstrom = forces * (HARTREE_TO_EV / BOHR_TO_ANGSTROM)
    max_force = float(jnp.max(jnp.linalg.norm(forces_ev_angstrom, axis=1)))

    print(f"\nDispersion Energy: {energy_ev:.6f} eV")
    print(f"Energy per atom: {energy_ev / num_atoms:.6f} eV")
    print(f"Maximum force magnitude: {max_force:.6f} eV/Angstrom")
    print(f"\nCoordination numbers: {np.array(coord_num)}")





.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    Dispersion Energy: -1.340935 eV
    Energy per atom: -0.037248 eV
    Maximum force magnitude: 0.047128 eV/Angstrom

    Coordination numbers: [3.2349308 3.2337396 3.2349308 3.2771623 3.1036298 3.2771626 3.0788736
     3.1036294 3.0788739 3.0804553 1.0007788 1.0004165 1.0007788 1.0047569
     1.004581  1.0047569 1.0045811 1.0048931 3.2349308 3.2337396 3.2349308
     3.2771623 3.1036296 3.2771626 3.0788736 3.1036294 3.0788739 3.0804553
     1.0007789 1.0004165 1.0007789 1.004757  1.004581  1.004757  1.004581
     1.0048932]




.. GENERATED FROM PYTHON SOURCE LINES 228-242

JIT Compilation
---------------
This section combines the neighbor list build and the DFT-D3 calculation into a
single ``jax.jit``-compiled function, so the entire pipeline is compiled as one
optimized computation.

For JIT compatibility:

- ``max_neighbors`` must be specified (static array shapes)
- Functional parameters (``a1``, ``a2``, ``s8``, etc.) must be **static literals**
  inside the jitted function (required by Warp FFI kernels)
- ``D3Parameters`` should be constructed inside the jitted function from
  traced arrays
- ``fill_value`` and ``num_systems`` should be static

In the function below, ``cutoff``, ``max_neighbors``, and ``fill_value`` are
never passed at the call site. They keep their Python default values, so they
enter the trace as plain Python constants rather than traced arrays.

.. GENERATED FROM PYTHON SOURCE LINES 242-284

.. code-block:: Python


    print("\n" + "=" * 70)
    print("JIT COMPILATION")
    print("=" * 70)


    # Define a jitted function that builds neighbors and computes DFT-D3
    @jax.jit
    def compute_d3_energy_forces(
        positions: jax.Array,
        numbers: jax.Array,
        rcov: jax.Array,
        r4r2: jax.Array,
        c6ab: jax.Array,
        cn_ref: jax.Array,
        cutoff: float = cutoff_bohr,
        max_neighbors: int = 64,
        fill_value: int = num_atoms,
    ) -> tuple[jax.Array, jax.Array, jax.Array]:
        """JIT-compiled neighbor list + DFT-D3 pipeline."""
        # Build neighbor matrix inside jit (max_neighbors must be static)
        nbmat, _ = naive_neighbor_list(positions, cutoff, max_neighbors=max_neighbors)

        # Construct D3Parameters inside jit from traced arrays
        params = D3Parameters(rcov=rcov, r4r2=r4r2, c6ab=c6ab, cn_ref=cn_ref)

        # Compute DFT-D3 with PBE0 parameters as static literals
        energy, forces, coord_num = dftd3(
            positions=positions,
            numbers=numbers,
            a1=0.4145,
            a2=4.8593,
            s8=1.2177,
            s6=1.0,
            d3_params=params,
            neighbor_matrix=nbmat,
            fill_value=fill_value,
        )
        return energy, forces, coord_num






.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    ======================================================================
    JIT COMPILATION
    ======================================================================




.. GENERATED FROM PYTHON SOURCE LINES 285-286

Run the jitted function:

.. GENERATED FROM PYTHON SOURCE LINES 286-322

.. code-block:: Python


    print("\nCompiling and running jitted DFT-D3 pipeline...")
    jit_energy, jit_forces, jit_cn = compute_d3_energy_forces(
        positions_bohr,
        numbers,
        d3_params.rcov,
        d3_params.r4r2,
        d3_params.c6ab,
        d3_params.cn_ref,
    )

    jit_energy_ev = float(jit_energy[0]) * HARTREE_TO_EV
    jit_forces_ev = jit_forces * (HARTREE_TO_EV / BOHR_TO_ANGSTROM)
    jit_max_force = float(jnp.max(jnp.linalg.norm(jit_forces_ev, axis=1)))

    print(f"  JIT dispersion energy: {jit_energy_ev:.6f} eV")
    print(f"  JIT max force: {jit_max_force:.6f} eV/Angstrom")

    # Compare with non-jitted result
    energy_diff = abs(jit_energy_ev - energy_ev)
    print(f"  Energy difference vs non-jitted: {energy_diff:.2e} eV")

    # Second call should be fast (already compiled)
    print("\nRunning jitted function again (should reuse compiled code)...")
    jit_energy_2, jit_forces_2, _ = compute_d3_energy_forces(
        positions_bohr,
        numbers,
        d3_params.rcov,
        d3_params.r4r2,
        d3_params.c6ab,
        d3_params.cn_ref,
    )
    print(
        f"  JIT dispersion energy (2nd call): {float(jit_energy_2[0]) * HARTREE_TO_EV:.6f} eV"
    )





.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    Compiling and running jitted DFT-D3 pipeline...
      JIT dispersion energy: -1.340935 eV
      JIT max force: 0.047128 eV/Angstrom
      Energy difference vs non-jitted: 0.00e+00 eV

    Running jitted function again (should reuse compiled code)...
      JIT dispersion energy (2nd call): -1.340935 eV




.. GENERATED FROM PYTHON SOURCE LINES 323-342

Summary
-------
This example demonstrated:

1. **Parameter loading** from cached DFT-D3 reference data (Grimme group)
2. **Molecular structure** loading from XYZ files into JAX arrays
3. **Neighbor list** construction for non-periodic systems
4. **DFT-D3 energy and forces** with PBE0 functional parameters
5. **Unit conversions** between atomic (Bohr/Hartree) and conventional
   (Angstrom/eV) units
6. **JIT compilation** of the full neighbor list + DFT-D3 pipeline

Key JAX-specific patterns:

- Load PyTorch parameters and convert to JAX arrays via ``jnp.array``
- Construct ``D3Parameters`` from JAX arrays
- For ``jax.jit``: use static literals for functional parameters (``a1``, ``a2``,
  ``s8``), specify ``max_neighbors`` for static shapes, and construct
  ``D3Parameters`` inside the jitted function

.. GENERATED FROM PYTHON SOURCE LINES 342-353

.. code-block:: Python


    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    print("\nKey takeaways:")
    print("  - DFT-D3 works in atomic units (Bohr, Hartree) internally")
    print("  - Convert Angstrom -> Bohr for positions, Hartree -> eV for energy")
    print("  - D3Parameters holds element-specific reference data")
    print("  - PBE0 parameters: a1=0.4145, a2=4.8593, s8=1.2177, s6=1.0")
    print("  - Use jax.jit to fuse neighbor list + DFT-D3 into one compiled function")
    print("\nJAX DFT-D3 example completed successfully!")




.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    ======================================================================
    SUMMARY
    ======================================================================

    Key takeaways:
      - DFT-D3 works in atomic units (Bohr, Hartree) internally
      - Convert Angstrom -> Bohr for positions, Hartree -> eV for energy
      - D3Parameters holds element-specific reference data
      - PBE0 parameters: a1=0.4145, a2=4.8593, s8=1.2177, s6=1.0
      - Use jax.jit to fuse neighbor list + DFT-D3 into one compiled function

    JAX DFT-D3 example completed successfully!





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.164 seconds)


.. _sphx_glr_download_examples_dispersion_03_jax_dftd3_molecule.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 03_jax_dftd3_molecule.ipynb <03_jax_dftd3_molecule.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 03_jax_dftd3_molecule.py <03_jax_dftd3_molecule.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 03_jax_dftd3_molecule.zip <03_jax_dftd3_molecule.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_