nvalchemiops.jax.interactions.dispersion: Dispersion Corrections

The dispersion module provides JAX bindings for the GPU-accelerated implementations of dispersion interactions.

JAX dispersion interactions API.

This module provides JAX bindings for dispersion corrections (DFT-D3).

Tip

For the underlying framework-agnostic Warp kernels, see nvalchemiops.interactions.dispersion: Dispersion Corrections.

High-Level Interface

DFT-D3(BJ) Dispersion Corrections

The DFT-D3 implementation supports two neighbor representation formats:

  • Neighbor matrix (dense): [num_atoms, max_neighbors] with padding

  • Neighbor list (sparse CSR): neighbor_list [2, num_pairs] together with CSR row pointers neighbor_ptr [num_atoms+1]; internally, neighbor_list[1, :] serves as the destination indices (idx_j)

Both formats produce identical results and support all features, including periodic boundary conditions, batching, and smooth cutoff functions. The high-level wrapper dispatches to the appropriate kernels based on which format is provided, and is intended to be compatible with jax.jit.
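To illustrate how the two formats relate, the sketch below converts a padded neighbor matrix into the [2, num_pairs] neighbor list and neighbor_ptr row pointers, carrying neighbor_matrix_shifts over to per-pair unit_shifts when PBC is used. The helper name neighbor_matrix_to_csr is hypothetical and not part of nvalchemiops. It runs on the host with NumPy, because the output sizes depend on the data and so cannot be traced under jax.jit; the results can be wrapped with jnp.asarray before calling dftd3().

```python
import numpy as np


def neighbor_matrix_to_csr(neighbor_matrix, fill_value, neighbor_matrix_shifts=None):
    """Convert a padded [num_atoms, max_neighbors] neighbor matrix to a
    [2, num_pairs] neighbor list, [num_atoms + 1] row pointers, and
    optional [num_pairs, 3] unit shifts (hypothetical helper)."""
    nm = np.asarray(neighbor_matrix)
    num_atoms = nm.shape[0]
    valid = nm != fill_value                 # drop padding slots
    counts = valid.sum(axis=1)               # neighbors per atom
    neighbor_ptr = np.zeros(num_atoms + 1, dtype=np.int32)
    neighbor_ptr[1:] = np.cumsum(counts)     # atom i owns entries ptr[i]:ptr[i+1]
    sources = np.repeat(np.arange(num_atoms, dtype=np.int32), counts)
    targets = nm[valid].astype(np.int32)     # boolean masking keeps row-major order
    neighbor_list = np.stack([sources, targets])
    unit_shifts = None
    if neighbor_matrix_shifts is not None:
        # Same mask, so the shifts stay aligned with the pairs they describe.
        unit_shifts = np.asarray(neighbor_matrix_shifts)[valid].astype(np.int32)
    return neighbor_list, neighbor_ptr, unit_shifts
```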

nvalchemiops.jax.interactions.dispersion.dftd3(positions, numbers, a1, a2, s8, k1=16.0, k3=-4.0, s6=1.0, s5_smoothing_on=1e10, s5_smoothing_off=1e10, fill_value=None, d3_params=None, covalent_radii=None, r4r2=None, c6_reference=None, coord_num_ref=None, batch_idx=None, cell=None, neighbor_matrix=None, neighbor_matrix_shifts=None, neighbor_list=None, neighbor_ptr=None, unit_shifts=None, compute_virial=False, num_systems=None)

Compute DFT-D3(BJ) dispersion energies, forces, and coordination numbers (plus the virial when requested) using Warp kernels on JAX arrays.

DFT-D3 parameters must be explicitly provided using one of three methods:

  1. D3Parameters dataclass: Supply a D3Parameters instance (recommended). Individual parameters can override dataclass values if both are provided.

  2. Explicit parameters: Supply all four parameters individually: covalent_radii, r4r2, c6_reference, and coord_num_ref.

  3. Dictionary: Provide a d3_params dictionary with keys: "rcov", "r4r2", "c6ab", and "cn_ref". Individual parameters can override dictionary values if both are provided.

See examples/dispersion/utils.py for parameter generation utilities.
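The precedence rules above can be sketched as follows. resolve_d3_params is a hypothetical helper written for illustration, not the library's internal code. It accepts either a dictionary or any object exposing rcov, r4r2, c6ab, and cn_ref attributes (such as a D3Parameters instance), lets explicit arguments win, and raises RuntimeError when a parameter is still missing, mirroring the Raises section.

```python
def resolve_d3_params(d3_params=None, covalent_radii=None, r4r2=None,
                      c6_reference=None, coord_num_ref=None):
    """Merge d3_params with explicit overrides (illustrative sketch)."""
    # dftd3() argument name -> key used by the dictionary / D3Parameters field
    key_for = {
        "covalent_radii": "rcov",
        "r4r2": "r4r2",
        "c6_reference": "c6ab",
        "coord_num_ref": "cn_ref",
    }
    explicit = {
        "covalent_radii": covalent_radii,
        "r4r2": r4r2,
        "c6_reference": c6_reference,
        "coord_num_ref": coord_num_ref,
    }
    resolved = {}
    for arg_name, key in key_for.items():
        value = explicit[arg_name]           # an explicit argument always wins
        if value is None and d3_params is not None:
            if isinstance(d3_params, dict):
                value = d3_params.get(key)
            else:                            # e.g. a D3Parameters instance
                value = getattr(d3_params, key, None)
        if value is None:
            raise RuntimeError(f"DFT-D3 parameter {arg_name!r} (key {key!r}) not provided")
        resolved[arg_name] = value
    return resolved
```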

Parameters:
  • positions (jax.Array) – Atomic coordinates [num_atoms, 3] as float32 or float64, in consistent distance units (conventionally Bohr when using standard D3 parameters)

  • numbers (jax.Array) – Atomic numbers [num_atoms] as int32

  • a1 (float) – Becke-Johnson damping parameter 1 (functional-dependent, dimensionless)

  • a2 (float) – Becke-Johnson damping parameter 2 (functional-dependent), in same units as positions

  • s8 (float) – C8 term scaling factor (functional-dependent, dimensionless)

  • k1 (float, optional) – CN counting function steepness parameter, in inverse distance units (typically 16.0 1/Bohr for atomic units). Default: 16.0

  • k3 (float, optional) – CN interpolation Gaussian width parameter (typically -4.0, dimensionless). Default: -4.0

  • s6 (float, optional) – C6 term scaling factor (typically 1.0, dimensionless). Default: 1.0

  • s5_smoothing_on (float, optional) – Distance where S5 switching begins, in same units as positions. Default: 1e10

  • s5_smoothing_off (float, optional) – Distance where S5 switching completes, in same units as positions. Default: 1e10 (effectively no cutoff)

  • fill_value (int | None, optional) – Value indicating padding in neighbor_matrix. If None, defaults to num_atoms. Default: None

  • d3_params (D3Parameters | dict[str, jax.Array] | None, optional) – DFT-D3 parameters, provided either as a D3Parameters dataclass instance (recommended) or as a dictionary with keys "rcov", "r4r2", "c6ab", and "cn_ref". The individual parameters below override values from d3_params. Default: None

  • covalent_radii (jax.Array | None, optional) – Covalent radii [max_Z+1] as float32, indexed by atomic number, in same units as positions. If provided, overrides the value in d3_params.

  • r4r2 (jax.Array | None, optional) – <r4>/<r2> expectation values [max_Z+1] as float32 for C8 computation (dimensionless). If provided, overrides the value in d3_params.

  • c6_reference (jax.Array | None, optional) – C6 reference values [max_Z+1, max_Z+1, 5, 5] as float32 in energy × distance^6 units. If provided, overrides the value in d3_params.

  • coord_num_ref (jax.Array | None, optional) – CN reference grid [max_Z+1, max_Z+1, 5, 5] as float32 (dimensionless). If provided, overrides the value in d3_params.

  • batch_idx (jax.Array or None, optional) – Batch indices [num_atoms] as int32. If None, all atoms are assumed to be in a single system (batch 0). Default: None

  • cell (jax.Array or None, optional) – Unit cell lattice vectors [num_systems, 3, 3] for PBC, in same dtype and units as positions. Convention: cell[s, i, :] is i-th lattice vector for system s. If None, non-periodic calculation. Default: None

  • neighbor_matrix (jax.Array | None, optional) – Neighbor indices [num_atoms, max_neighbors] as int32. Each row i contains indices of atom i’s neighbors, padded with fill_value for unused slots. Mutually exclusive with neighbor_list. Default: None

  • neighbor_matrix_shifts (jax.Array or None, optional) – Integer unit cell shifts [num_atoms, max_neighbors, 3] as int32 for PBC with neighbor_matrix format. If None, non-periodic calculation. Mutually exclusive with unit_shifts. Default: None

  • neighbor_list (jax.Array or None, optional) – Neighbor pairs [2, num_pairs] as int32 in COO format, where row 0 contains source atom indices and row 1 contains target atom indices. Alternative to neighbor_matrix for sparse neighbor representations. Mutually exclusive with neighbor_matrix. Must be used together with neighbor_ptr. Default: None

  • neighbor_ptr (jax.Array or None, optional) – CSR row pointers [num_atoms+1] as int32. Required when using neighbor_list; the neighbors of atom i are neighbor_list[1, neighbor_ptr[i]:neighbor_ptr[i+1]]. Default: None

  • unit_shifts (jax.Array or None, optional) – Integer unit cell shifts [num_pairs, 3] as int32 for PBC with neighbor_list format. If None, non-periodic calculation. Mutually exclusive with neighbor_matrix_shifts. Default: None

  • compute_virial (bool, optional) – If True, compute and return virial tensor. Default: False

  • num_systems (int, optional) – Number of systems in batch. If None, inferred from cell or from batch_idx (introduces device synchronization overhead). Default: None
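For orientation, here is a per-pair sketch of how a1, a2, s6, s8, r4r2, and the S5 smoothing distances typically enter a D3(BJ) energy. It assumes the convention of Grimme's reference code, where r4r2 holds precomputed per-element values so that C8 = 3 · C6 · r4r2[Z_i] · r4r2[Z_j], and it assumes S5 is the quintic smoothstep switch; the Warp kernels may differ in detail. When summing over a symmetric neighbor list, each pair is visited twice, so a pairwise total built from this function would need a factor of 1/2.

```python
import numpy as np


def s5_switch(r, r_on, r_off):
    """Quintic switch: 1 for r <= r_on, 0 for r >= r_off, smooth in between."""
    t = np.clip((r - r_on) / (r_off - r_on), 0.0, 1.0)
    # 1 - (10 t^3 - 15 t^4 + 6 t^5); first and second derivatives vanish at both ends
    return 1.0 - t**3 * (10.0 - 15.0 * t + 6.0 * t**2)


def bj_pair_energy(r, c6, r4r2_i, r4r2_j, a1, a2, s8, s6=1.0,
                   s5_smoothing_on=1e10, s5_smoothing_off=1e10):
    """Two-body D3(BJ) energy of one i-j pair at distance r (sketch)."""
    qq = 3.0 * r4r2_i * r4r2_j        # assumed C8/C6 ratio (Grimme convention)
    c8 = c6 * qq
    r0 = np.sqrt(qq)                  # BJ cutoff radius sqrt(C8/C6)
    f = a1 * r0 + a2                  # damping length, same units as r and a2
    energy = -(s6 * c6 / (r**6 + f**6) + s8 * c8 / (r**8 + f**8))
    if s5_smoothing_off > s5_smoothing_on:  # defaults (1e10, 1e10) disable switching
        energy = energy * s5_switch(r, s5_smoothing_on, s5_smoothing_off)
    return energy
```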

Returns:

  • energy (jax.Array) – Total dispersion energy [num_systems] as float32. Units are energy (Hartree when using standard D3 parameters).

  • forces (jax.Array) – Atomic forces [num_atoms, 3] as float32. Units are energy/distance (Hartree/Bohr when using standard D3 parameters).

  • coord_num (jax.Array) – Coordination numbers [num_atoms] as float32 (dimensionless)

  • virial (jax.Array, optional) – Virial tensor [num_systems, 3, 3] as float32. Only returned if compute_virial=True.

Return type:

tuple[Array, Array, Array] | tuple[Array, Array, Array, Array]

Notes

  • Unit consistency: All inputs must use consistent units. Standard D3 parameters from the Grimme group use atomic units (Bohr for distances, Hartree for energy).

  • Precision: positions and cell may be float32 or float64; outputs are always float32

  • Neighbor formats: Supports both neighbor_matrix (dense) and neighbor_list (sparse) formats. Choose neighbor_list for sparse systems or when memory efficiency is important.

  • Padding atoms are indicated by numbers[i] == 0

  • Requires a symmetric neighbor representation: each pair i-j must appear twice, once as (i, j) and once as (j, i)

  • Two-body only: Computes pairwise C6 and C8 dispersion terms; three-body Axilrod-Teller-Muto (ATM/C9) terms are not included

  • Virial computation requires periodic boundary conditions.
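Because the kernels expect a symmetric neighbor representation, a neighbor builder that emits each pair only once must be expanded first. The sketch below uses a hypothetical helper, symmetrize_pairs, with NumPy on the host. It adds the reverse of every pair and negates its periodic shift (valid for any convention in which the shift is applied consistently to the second atom of a pair), sorts entries by source atom, and builds neighbor_ptr.

```python
import numpy as np


def symmetrize_pairs(half_pairs, num_atoms, half_shifts=None):
    """Expand a half neighbor list [2, P] into a symmetric, CSR-ordered list."""
    half_pairs = np.asarray(half_pairs, dtype=np.int32)
    src = np.concatenate([half_pairs[0], half_pairs[1]])   # i->j and j->i
    dst = np.concatenate([half_pairs[1], half_pairs[0]])
    shifts = None
    if half_shifts is not None:
        half_shifts = np.asarray(half_shifts, dtype=np.int32)
        # The reverse pair sees the periodic image in the opposite direction.
        shifts = np.concatenate([half_shifts, -half_shifts])
    order = np.argsort(src, kind="stable")                 # group by source atom
    src, dst = src[order], dst[order]
    if shifts is not None:
        shifts = shifts[order]
    counts = np.bincount(src, minlength=num_atoms)
    neighbor_ptr = np.zeros(num_atoms + 1, dtype=np.int32)
    neighbor_ptr[1:] = np.cumsum(counts)
    return np.stack([src, dst]), neighbor_ptr, shifts
```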

Raises:
  • ValueError – If neighbor format is invalid or PBC requirements are not met

  • RuntimeError – If DFT-D3 parameters are not provided


Examples

Using neighbor matrix format:

>>> energy, forces, coord_num = dftd3(
...     positions, numbers,
...     neighbor_matrix=neighbor_matrix,
...     a1=0.3981, a2=4.4211, s8=1.9889,
...     d3_params=params,
... )

Using neighbor list format with PBC:

>>> energy, forces, coord_num, virial = dftd3(
...     positions, numbers,
...     neighbor_list=neighbor_list,
...     neighbor_ptr=neighbor_ptr,
...     a1=0.3981, a2=4.4211, s8=1.9889,
...     d3_params=params,
...     cell=cell,
...     unit_shifts=unit_shifts,
...     compute_virial=True,
... )

Data Structures

This data structure is not required to use the kernels; it is provided for convenience, and the dataclass validates the shapes and keys of the parameters the kernels require.

class nvalchemiops.jax.interactions.dispersion.D3Parameters(rcov, r4r2, c6ab, cn_ref, interp_mesh=5)

DFT-D3 reference parameters for dispersion correction calculations.

This dataclass encapsulates all element-specific parameters required for DFT-D3 dispersion corrections. Its main purpose is validation: it checks that all required arrays are present and have the correct shapes and dtypes. These parameters are used by dftd3().

Parameters:
  • rcov (jax.Array) – Covalent radii [max_Z+1] as float32. Units should be consistent with position coordinates. Index 0 is reserved for padding; valid atomic numbers are 1 to max_Z.

  • r4r2 (jax.Array) – <r⁴>/<r²> expectation values [max_Z+1] as float32. Dimensionless ratio used for computing C8 coefficients from C6 values.

  • c6ab (jax.Array) – C6 reference coefficients [max_Z+1, max_Z+1, interp_mesh, interp_mesh] as float32. Units are energy x distance^6. Indexed by atomic numbers and coordination number reference indices.

  • cn_ref (jax.Array) – Coordination number reference grid [max_Z+1, max_Z+1, interp_mesh, interp_mesh] as float32. Dimensionless CN values for Gaussian interpolation.

  • interp_mesh (int, optional) – Size of the coordination number interpolation mesh. Default: 5 (standard DFT-D3 uses a 5x5 grid)
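The cn_ref grid serves the Gaussian interpolation that turns the C6 reference grid into a coordination-number-dependent C6, weighted by exp(k3 · ((CN_i - CN_i,ref)² + (CN_j - CN_j,ref)²)) with k3 from dftd3(). The sketch below treats a single pair and takes the per-atom reference CN grids as separate inputs. How those grids are sliced out of the [max_Z+1, max_Z+1, interp_mesh, interp_mesh] cn_ref array, and the rule that non-positive C6 entries mark unused grid slots, are assumptions rather than documented behavior.

```python
import numpy as np


def interpolate_c6(cn_i, cn_j, c6_grid, cn_ref_i, cn_ref_j, k3=-4.0):
    """CN-dependent C6 for one pair via Gaussian weights (sketch).

    c6_grid[a, b]  : reference C6 for reference state a of atom i, b of atom j
    cn_ref_i[a, b] : reference CN of atom i at grid point (a, b)
    cn_ref_j[a, b] : reference CN of atom j at grid point (a, b)
    """
    c6_grid = np.asarray(c6_grid, dtype=np.float64)
    valid = c6_grid > 0.0                    # assumed marker for unused slots
    dist2 = (cn_i - np.asarray(cn_ref_i)) ** 2 + (cn_j - np.asarray(cn_ref_j)) ** 2
    weights = np.where(valid, np.exp(k3 * dist2), 0.0)   # k3 < 0, so weights decay
    norm = weights.sum()
    if norm == 0.0:                          # no usable reference: return 0 in this sketch
        return 0.0
    return float((weights * c6_grid).sum() / norm)
```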

Raises:
  • ValueError – If parameter shapes are inconsistent or invalid

  • TypeError – If parameters are not jax.Array or have invalid dtypes

Notes

  • Parameters should use consistent units matching your coordinate system. Standard D3 parameters from the Grimme group use atomic units (Bohr for distances, Hartree x Bohr^6 for C6 coefficients).

  • Index 0 in all arrays is reserved for padding atoms (atomic number 0)

  • Valid atomic numbers range from 1 to max_z

  • The standard DFT-D3 implementation supports elements 1-94 (H to Pu)

  • Parameters should be float32 for efficiency
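The shape and dtype rules documented above can be checked with a few lines of NumPy. check_d3_shapes is a hypothetical helper that mirrors the documented layout (index 0 reserved for padding, so max_z = len(rcov) - 1) and the documented exception types; this sketch enforces float32, and the checks performed by D3Parameters itself may be stricter or differ in detail.

```python
import numpy as np


def check_d3_shapes(rcov, r4r2, c6ab, cn_ref, interp_mesh=5):
    """Validate the documented D3Parameters layout and return max_z (sketch)."""
    arrays = {"rcov": rcov, "r4r2": r4r2, "c6ab": c6ab, "cn_ref": cn_ref}
    for name, arr in arrays.items():
        # Anything without shape/dtype is not an array (TypeError, as documented).
        if not (hasattr(arr, "shape") and hasattr(arr, "dtype")):
            raise TypeError(f"{name} must be an array, got {type(arr).__name__}")
        # This sketch enforces float32, matching the documented dtype.
        if arr.dtype != np.float32:
            raise TypeError(f"{name} must be float32, got {arr.dtype}")
    if rcov.ndim != 1:
        raise ValueError("rcov must be 1-D with shape [max_Z+1]")
    n = rcov.shape[0]  # max_Z + 1, index 0 reserved for padding
    if r4r2.shape != (n,):
        raise ValueError(f"r4r2 must have shape ({n},), got {r4r2.shape}")
    expected = (n, n, interp_mesh, interp_mesh)
    for name in ("c6ab", "cn_ref"):
        if arrays[name].shape != expected:
            raise ValueError(f"{name} must have shape {expected}, got {arrays[name].shape}")
    return n - 1
```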

Examples

Create parameters from individual arrays:

>>> params = D3Parameters(
...     rcov=jnp.array([...]),
...     r4r2=jnp.array([...]),
...     c6ab=jnp.array([...]),
...     cn_ref=jnp.array([...]),
... )

c6ab: Array
cn_ref: Array
interp_mesh: int = 5
property max_z: int

Maximum atomic number supported by these parameters.

r4r2: Array
rcov: Array

Internal Implementation

These are low-level implementation functions that wrap the Warp kernels for JAX. For most use cases, prefer the high-level dftd3() wrapper above.

Neighbor Matrix Implementation

nvalchemiops.jax.interactions.dispersion._dftd3._dftd3_nm_impl(positions, numbers, neighbor_matrix, covalent_radii, r4r2, c6_reference, coord_num_ref, a1, a2, s8, k1=16.0, k3=-4.0, s6=1.0, s5_smoothing_on=1e10, s5_smoothing_off=1e10, fill_value=None, batch_idx=None, cell=None, neighbor_matrix_shifts=None, compute_virial=False, num_systems=None)

Internal implementation for neighbor matrix format using jax_kernel wrappers.

Return type:

tuple[Array, Array, Array] | tuple[Array, Array, Array, Array]

Neighbor List Implementation

nvalchemiops.jax.interactions.dispersion._dftd3._dftd3_nl_impl(positions, numbers, idx_j, neighbor_ptr, covalent_radii, r4r2, c6_reference, coord_num_ref, a1, a2, s8, k1=16.0, k3=-4.0, s6=1.0, s5_smoothing_on=1e10, s5_smoothing_off=1e10, batch_idx=None, cell=None, unit_shifts=None, compute_virial=False, num_systems=None)

Internal implementation for neighbor list format using jax_kernel wrappers.

Return type:

tuple[Array, Array, Array] | tuple[Array, Array, Array, Array]