nvalchemiops.jax.interactions.dispersion: Dispersion Corrections#
The dispersion module provides JAX bindings for the GPU-accelerated implementations of dispersion corrections (DFT-D3).
Tip
For the underlying framework-agnostic Warp kernels, see nvalchemiops.interactions.dispersion: Dispersion Corrections.
High-Level Interface#
DFT-D3(BJ) Dispersion Corrections#
The DFT-D3 implementation supports two neighbor representation formats:
Neighbor matrix (dense): [num_atoms, max_neighbors] with padding
Neighbor list (sparse CSR): Compressed sparse row format with idx_j and neighbor_ptr
Both formats produce identical results and support all features including periodic
boundary conditions, batching, and smooth cutoff functions. The high-level wrapper
automatically dispatches to the appropriate kernels based on which format is provided.
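The relationship between the two formats can be sketched on the host with NumPy. This is an illustration only: the helper name neighbor_matrix_to_csr is hypothetical and not part of nvalchemiops, and the padding value is assumed to be num_atoms, which is the documented default for fill_value. The resulting arrays could be converted with jnp.asarray before use.

```python
import numpy as np


def neighbor_matrix_to_csr(neighbor_matrix, fill_value):
    """Convert a padded [num_atoms, max_neighbors] matrix to (idx_j, neighbor_ptr).

    Hypothetical helper for illustration; not part of nvalchemiops.
    """
    neighbor_matrix = np.asarray(neighbor_matrix, dtype=np.int32)
    # True where a slot holds a real neighbor rather than padding.
    valid = neighbor_matrix != fill_value
    # Number of real neighbors in each row.
    counts = valid.sum(axis=1)
    # CSR row pointers: neighbors of atom i live in idx_j[ptr[i]:ptr[i + 1]].
    neighbor_ptr = np.zeros(neighbor_matrix.shape[0] + 1, dtype=np.int32)
    neighbor_ptr[1:] = np.cumsum(counts)
    # Boolean masking flattens in row-major order, so each row stays contiguous.
    idx_j = neighbor_matrix[valid]
    return idx_j, neighbor_ptr


# Three atoms in a chain 0-1-2; unused slots are padded with num_atoms = 3.
nm = np.array([[1, 3], [0, 2], [1, 3]], dtype=np.int32)
idx_j, neighbor_ptr = neighbor_matrix_to_csr(nm, fill_value=3)
```

Note that the example matrix is symmetric: the pair (0, 1) appears in both row 0 and row 1, and likewise for (1, 2).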
The function should be compatible with jax.jit.
- nvalchemiops.jax.interactions.dispersion.dftd3(positions, numbers, a1, a2, s8, k1=16.0, k3=-4.0, s6=1.0, s5_smoothing_on=1e10, s5_smoothing_off=1e10, fill_value=None, d3_params=None, covalent_radii=None, r4r2=None, c6_reference=None, coord_num_ref=None, batch_idx=None, cell=None, neighbor_matrix=None, neighbor_matrix_shifts=None, neighbor_list=None, neighbor_ptr=None, unit_shifts=None, compute_virial=False, num_systems=None)[source]#
Compute DFT-D3(BJ) dispersion energy and forces using Warp with JAX arrays.
DFT-D3 parameters must be explicitly provided using one of three methods:
D3Parameters dataclass: Supply a D3Parameters instance (recommended). Individual parameters can override dataclass values if both are provided.
Explicit parameters: Supply all four parameters individually: covalent_radii, r4r2, c6_reference, and coord_num_ref.
Dictionary: Provide a d3_params dictionary with keys: "rcov", "r4r2", "c6ab", and "cn_ref". Individual parameters can override dictionary values if both are provided.
See examples/dispersion/utils.py for parameter generation utilities.
- Parameters:
positions (jax.Array) – Atomic coordinates [num_atoms, 3] as float32 or float64, in consistent distance units (conventionally Bohr when using standard D3 parameters)
numbers (jax.Array) – Atomic numbers [num_atoms] as int32
a1 (float) – Becke-Johnson damping parameter 1 (functional-dependent, dimensionless)
a2 (float) – Becke-Johnson damping parameter 2 (functional-dependent), in same units as positions
s8 (float) – C8 term scaling factor (functional-dependent, dimensionless)
k1 (float, optional) – CN counting function steepness parameter, in inverse distance units (typically 16.0 1/Bohr for atomic units). Default: 16.0
k3 (float, optional) – CN interpolation Gaussian width parameter (typically -4.0, dimensionless). Default: -4.0
s6 (float, optional) – C6 term scaling factor (typically 1.0, dimensionless). Default: 1.0
s5_smoothing_on (float, optional) – Distance where S5 switching begins, in same units as positions. Default: 1e10
s5_smoothing_off (float, optional) – Distance where S5 switching completes, in same units as positions. Default: 1e10 (effectively no cutoff)
fill_value (int | None, optional) – Value indicating padding in neighbor_matrix. If None, defaults to num_atoms. Default: None
d3_params (D3Parameters | dict[str, jax.Array] | None, optional) – DFT-D3 parameters provided as either a D3Parameters dataclass instance (recommended) or a dictionary with keys "rcov", "r4r2", "c6ab", and "cn_ref". Individual parameters below can override values from d3_params.
covalent_radii (jax.Array | None, optional) – Covalent radii [max_Z+1] as float32, indexed by atomic number, in same units as positions. If provided, overrides the value in d3_params.
r4r2 (jax.Array | None, optional) – <r4>/<r2> expectation values [max_Z+1] as float32 for C8 computation (dimensionless). If provided, overrides the value in d3_params.
c6_reference (jax.Array | None, optional) – C6 reference values [max_Z+1, max_Z+1, 5, 5] as float32 in energy × distance^6 units. If provided, overrides the value in d3_params.
coord_num_ref (jax.Array | None, optional) – CN reference grid [max_Z+1, max_Z+1, 5, 5] as float32 (dimensionless). If provided, overrides the value in d3_params.
batch_idx (jax.Array or None, optional) – Batch indices [num_atoms] as int32. If None, all atoms are assumed to be in a single system (batch 0). Default: None
cell (jax.Array or None, optional) – Unit cell lattice vectors [num_systems, 3, 3] for PBC, in same dtype and units as positions. Convention: cell[s, i, :] is i-th lattice vector for system s. If None, non-periodic calculation. Default: None
neighbor_matrix (jax.Array | None, optional) – Neighbor indices [num_atoms, max_neighbors] as int32. Each row i contains indices of atom i’s neighbors, padded with fill_value for unused slots. Mutually exclusive with neighbor_list. Default: None
neighbor_matrix_shifts (jax.Array or None, optional) – Integer unit cell shifts [num_atoms, max_neighbors, 3] as int32 for PBC with neighbor_matrix format. If None, non-periodic calculation. Mutually exclusive with unit_shifts. Default: None
neighbor_list (jax.Array or None, optional) – Neighbor pairs [2, num_pairs] as int32 in COO format, where row 0 contains source atom indices and row 1 contains target atom indices. Alternative to neighbor_matrix for sparse neighbor representations. Mutually exclusive with neighbor_matrix. Must be used together with neighbor_ptr. Default: None
neighbor_ptr (jax.Array or None, optional) – CSR row pointers [num_atoms+1] as int32. Required when using neighbor_list. Indicates that neighbor_list[1, :] contains destination atoms in CSR format. Default: None
unit_shifts (jax.Array or None, optional) – Integer unit cell shifts [num_pairs, 3] as int32 for PBC with neighbor_list format. If None, non-periodic calculation. Mutually exclusive with neighbor_matrix_shifts. Default: None
compute_virial (bool, optional) – If True, compute and return virial tensor. Default: False
num_systems (int, optional) – Number of systems in batch. If None, inferred from cell or from batch_idx (introduces device synchronization overhead). Default: None
- Returns:
energy (jax.Array) – Total dispersion energy [num_systems] as float32. Units are energy (Hartree when using standard D3 parameters).
forces (jax.Array) – Atomic forces [num_atoms, 3] as float32. Units are energy/distance (Hartree/Bohr when using standard D3 parameters).
coord_num (jax.Array) – Coordination numbers [num_atoms] as float32 (dimensionless)
virial (jax.Array, optional) – Virial tensor [num_systems, 3, 3] as float32. Only returned if compute_virial=True.
- Return type:
tuple[Array, Array, Array] | tuple[Array, Array, Array, Array]
Notes
Unit consistency: All inputs must use consistent units. Standard D3 parameters from the Grimme group use atomic units (Bohr for distances, Hartree for energy).
Float32 or float64 precision for positions and cell; outputs always float32
Neighbor formats: Supports both neighbor_matrix (dense) and neighbor_list (sparse) formats. Choose neighbor_list for sparse systems or when memory efficiency is important.
Padding atoms indicated by numbers[i] == 0
Requires symmetric neighbor representation (each pair appears twice)
Two-body only: Computes pairwise C6 and C8 dispersion terms; three-body Axilrod-Teller-Muto (ATM/C9) terms are not included
Virial computation requires periodic boundary conditions.
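For orientation, the two-body term with Becke-Johnson rational damping can be sketched for a single pair as below. This follows the standard D3(BJ) form from the literature, with the pair coefficients c6 and c8 taken as given inputs. It is not the library's kernel: the coordination-number interpolation of C6, the derivation of C8 from r4r2, and the S5 switching function are all omitted, and the function name bj_pair_energy is hypothetical.

```python
import numpy as np


def bj_pair_energy(r, c6, c8, a1, a2, s8, s6=1.0):
    """Two-body D3(BJ) energy of one pair at distance r.

    Hypothetical illustration; all quantities are assumed to be in one
    consistent unit system (e.g. Bohr and Hartree).
    """
    # BJ critical radius derived from the ratio of C8 to C6.
    r0 = np.sqrt(c8 / c6)
    # Damping length; keeps the energy finite as r approaches zero.
    damp = a1 * r0 + a2
    e6 = s6 * c6 / (r**6 + damp**6)
    e8 = s8 * c8 / (r**8 + damp**8)
    return -(e6 + e8)


# With damping switched off (a1 = a2 = 0) and unit coefficients at r = 1,
# the energy reduces to -(s6 + s8).
undamped = bj_pair_energy(1.0, 1.0, 1.0, a1=0.0, a2=0.0, s8=2.0)

# With the damping parameters from the examples on this page, the energy
# stays finite even at zero separation.
at_contact = bj_pair_energy(0.0, 1.0, 1.0, a1=0.3981, a2=4.4211, s8=1.9889)
far_apart = bj_pair_energy(10.0, 1.0, 1.0, a1=0.3981, a2=4.4211, s8=1.9889)
```

The damping length is why a2 must be given in the same distance units as positions, while a1 is dimensionless.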
- Raises:
ValueError – If neighbor format is invalid or PBC requirements are not met
RuntimeError – If DFT-D3 parameters are not provided
Examples
Using neighbor matrix format:
>>> energy, forces, coord_num = dftd3(
...     positions, numbers,
...     neighbor_matrix=neighbor_matrix,
...     a1=0.3981, a2=4.4211, s8=1.9889,
...     d3_params=params,
... )
Using neighbor list format with PBC:
>>> energy, forces, coord_num, virial = dftd3(
...     positions, numbers,
...     neighbor_list=neighbor_list,
...     neighbor_ptr=neighbor_ptr,
...     a1=0.3981, a2=4.4211, s8=1.9889,
...     d3_params=params,
...     cell=cell,
...     unit_shifts=unit_shifts,
...     compute_virial=True,
... )
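The batched inputs described in the parameter list (batch_idx, cell, and num_systems) could be assembled roughly as follows. This is a sketch under stated assumptions: the two systems, their coordinates, and the cell sizes are made up for illustration, and the arrays are built with NumPy on the host; they could be converted with jnp.asarray before being passed on.

```python
import numpy as np

# Two hypothetical systems with 2 and 3 atoms; coordinates are placeholders.
pos_a = np.zeros((2, 3), dtype=np.float32)
pos_b = np.ones((3, 3), dtype=np.float32)
num_a = np.array([1, 1], dtype=np.int32)
num_b = np.array([8, 1, 1], dtype=np.int32)

# Concatenate atoms from all systems along the first axis.
positions = np.concatenate([pos_a, pos_b], axis=0)  # [num_atoms, 3]
numbers = np.concatenate([num_a, num_b])            # [num_atoms]

# batch_idx maps each atom to the index of the system it belongs to.
batch_idx = np.repeat(
    np.arange(2, dtype=np.int32), [len(num_a), len(num_b)]
)

# One cell per system; cell[s, i, :] is the i-th lattice vector of system s.
cell = np.stack([
    10.0 * np.eye(3, dtype=np.float32),
    12.0 * np.eye(3, dtype=np.float32),
])  # [num_systems, 3, 3]

# Passing num_systems explicitly avoids inferring it at call time.
num_systems = cell.shape[0]
```

Neighbor indices would then presumably refer to the concatenated atom numbering; this is an assumption based on the documented [num_atoms, ...] shapes.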
Data Structures#
This data structure is not required to use the kernels, but it is provided
for convenience: the dataclass validates the shapes and keys of the parameters
required by the kernels.
- class nvalchemiops.jax.interactions.dispersion.D3Parameters(rcov, r4r2, c6ab, cn_ref, interp_mesh=5)[source]#
DFT-D3 reference parameters for dispersion correction calculations.
This dataclass encapsulates all element-specific parameters required for DFT-D3 dispersion corrections. The main purpose for this structure is to provide validation, ensuring the correct shapes, dtypes, and keys are present and complete. These parameters are used by dftd3().
- Parameters:
rcov (jax.Array) – Covalent radii [max_Z+1] as float32. Units should be consistent with position coordinates. Index 0 is reserved for padding; valid atomic numbers are 1 to max_Z.
r4r2 (jax.Array) – <r⁴>/<r²> expectation values [max_Z+1] as float32. Dimensionless ratio used for computing C8 coefficients from C6 values.
c6ab (jax.Array) – C6 reference coefficients [max_Z+1, max_Z+1, interp_mesh, interp_mesh] as float32. Units are energy × distance^6. Indexed by atomic numbers and coordination number reference indices.
cn_ref (jax.Array) – Coordination number reference grid [max_Z+1, max_Z+1, interp_mesh, interp_mesh] as float32. Dimensionless CN values for Gaussian interpolation.
interp_mesh (int, optional) – Size of the coordination number interpolation mesh. Default: 5 (standard DFT-D3 uses a 5x5 grid)
- Raises:
ValueError – If parameter shapes are inconsistent or invalid
TypeError – If parameters are not jax.Array or have invalid dtypes
Notes
Parameters should use consistent units matching your coordinate system. Standard D3 parameters from the Grimme group use atomic units (Bohr for distances, Hartree × Bohr^6 for C6 coefficients).
Index 0 in all arrays is reserved for padding atoms (atomic number 0)
Valid atomic numbers range from 1 to max_Z
The standard DFT-D3 implementation supports elements 1-94 (H to Pu)
Parameters should be float32 for efficiency
Examples
Create parameters from individual arrays:
>>> params = D3Parameters(
...     rcov=jnp.array([...]),
...     r4r2=jnp.array([...]),
...     c6ab=jnp.array([...]),
...     cn_ref=jnp.array([...]),
... )
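The shape conventions above can be illustrated with a small standalone check. This is a sketch only: check_d3_shapes is a hypothetical helper, not the validation that D3Parameters performs, and the zero-filled arrays are placeholders that show the layout rather than real D3 reference data.

```python
import numpy as np


def check_d3_shapes(params, interp_mesh=5):
    """Check a parameter dictionary against the documented array layouts.

    Hypothetical helper for illustration; returns the inferred max_Z.
    """
    required = ("rcov", "r4r2", "c6ab", "cn_ref")
    for key in required:
        if key not in params:
            raise ValueError(f"missing key: {key}")
    # Length of the per-element arrays is max_Z + 1 (index 0 is padding).
    n = params["rcov"].shape[0]
    expected = {
        "rcov": (n,),
        "r4r2": (n,),
        "c6ab": (n, n, interp_mesh, interp_mesh),
        "cn_ref": (n, n, interp_mesh, interp_mesh),
    }
    for key, shape in expected.items():
        if params[key].shape != shape:
            raise ValueError(
                f"{key}: expected shape {shape}, got {params[key].shape}"
            )
    return n - 1


# Placeholder arrays for elements 1-94; the values are NOT real D3 parameters.
n = 94 + 1
placeholder = {
    "rcov": np.zeros(n, dtype=np.float32),
    "r4r2": np.zeros(n, dtype=np.float32),
    "c6ab": np.zeros((n, n, 5, 5), dtype=np.float32),
    "cn_ref": np.zeros((n, n, 5, 5), dtype=np.float32),
}
max_z = check_d3_shapes(placeholder)
```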
Internal Implementation#
These are low-level implementation functions that wrap the Warp kernels for JAX.
For most use cases, prefer the high-level dftd3() wrapper above.
Neighbor Matrix Implementation#
- nvalchemiops.jax.interactions.dispersion._dftd3._dftd3_nm_impl(positions, numbers, neighbor_matrix, covalent_radii, r4r2, c6_reference, coord_num_ref, a1, a2, s8, k1=16.0, k3=-4.0, s6=1.0, s5_smoothing_on=1e10, s5_smoothing_off=1e10, fill_value=None, batch_idx=None, cell=None, neighbor_matrix_shifts=None, compute_virial=False, num_systems=None)[source]#
Internal implementation for neighbor matrix format using jax_kernel wrappers.
- Parameters:
positions (Array)
numbers (Array)
neighbor_matrix (Array)
covalent_radii (Array)
r4r2 (Array)
c6_reference (Array)
coord_num_ref (Array)
a1 (float)
a2 (float)
s8 (float)
k1 (float)
k3 (float)
s6 (float)
s5_smoothing_on (float)
s5_smoothing_off (float)
fill_value (int | None)
batch_idx (Array | None)
cell (Array | None)
neighbor_matrix_shifts (Array | None)
compute_virial (bool)
num_systems (int | None)
- Return type:
tuple[Array, Array, Array] | tuple[Array, Array, Array, Array]
Neighbor List Implementation#
- nvalchemiops.jax.interactions.dispersion._dftd3._dftd3_nl_impl(positions, numbers, idx_j, neighbor_ptr, covalent_radii, r4r2, c6_reference, coord_num_ref, a1, a2, s8, k1=16.0, k3=-4.0, s6=1.0, s5_smoothing_on=1e10, s5_smoothing_off=1e10, batch_idx=None, cell=None, unit_shifts=None, compute_virial=False, num_systems=None)[source]#
Internal implementation for neighbor list format using jax_kernel wrappers.
- Parameters:
positions (Array)
numbers (Array)
idx_j (Array)
neighbor_ptr (Array)
covalent_radii (Array)
r4r2 (Array)
c6_reference (Array)
coord_num_ref (Array)
a1 (float)
a2 (float)
s8 (float)
k1 (float)
k3 (float)
s6 (float)
s5_smoothing_on (float)
s5_smoothing_off (float)
batch_idx (Array | None)
cell (Array | None)
unit_shifts (Array | None)
compute_virial (bool)
num_systems (int | None)
- Return type:
tuple[Array, Array, Array] | tuple[Array, Array, Array, Array]