Source code for nvalchemiops.interactions.electrostatics.k_vectors

# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import math

import torch

# Mathematical constants
PI = math.pi
TWOPI = 2.0 * PI


def _generate_miller_indices(
    cell: torch.Tensor,
    k_cutoff: float | torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate Miller indices for Ewald summation.

    Parameters
    ----------
    cell : torch.Tensor, shape (N, 3, 3)
        Unit cell matrices with lattice vectors as rows.
    k_cutoff : float | torch.Tensor
        Maximum magnitude of k-vectors to include in reciprocal summation.

    Notes
    -----
    if cell represents a single system, return max_h, max_k, max_l
    computed by taking the maximum reciprocal cell_lengths over the entire batch of systems.
    """
    cell_lengths = (torch.norm(cell, dim=-1).max(dim=0).values) / (
        2 * torch.pi
    )  # Length of each reciprocal vector
    return torch.ceil(k_cutoff * cell_lengths).long()


def generate_k_vectors_ewald_summation(
    cell: torch.Tensor,
    k_cutoff: float | torch.Tensor,
) -> torch.Tensor:
    r"""Generate half-space reciprocal lattice vectors for Ewald summation.

    Builds the k-vectors used by the reciprocal-space part of the Ewald
    method.  Only one member of every :math:`(\mathbf{k}, -\mathbf{k})` pair
    is returned, which roughly halves the number of vectors.

    Half-Space Optimization
    -----------------------
    The structure factor satisfies :math:`S(-\mathbf{k}) = S^*(\mathbf{k})`,
    so only Miller indices fulfilling one of the following are kept:

    - h > 0, OR
    - (h == 0 AND k > 0), OR
    - (h == 0 AND k == 0 AND l > 0)

    According to the module documentation, the kernels in ewald_kernels.py
    compensate by doubling the Green's function (:math:`8\pi` instead of
    :math:`4\pi`), so energies, forces, and charge gradients stay correct.

    Mathematical Background
    -----------------------
    For direct lattice vectors {a, b, c} (rows of the cell matrix) the
    reciprocal lattice vectors are:

    .. math::

        \mathbf{a}^* &= \frac{2\pi (\mathbf{b} \times \mathbf{c})}{V} \\
        \mathbf{b}^* &= \frac{2\pi (\mathbf{c} \times \mathbf{a})}{V} \\
        \mathbf{c}^* &= \frac{2\pi (\mathbf{a} \times \mathbf{b})}{V}

    with :math:`V = \mathbf{a} \cdot (\mathbf{b} \times \mathbf{c})` the cell
    volume.  In matrix form:
    :math:`\text{reciprocal_matrix} = 2\pi \cdot (\text{cell}^T)^{-1}`.

    Each k-vector is
    :math:`\mathbf{k} = h \mathbf{a}^* + k \mathbf{b}^* + l \mathbf{c}^*`
    where (h, k, l) are integer Miller indices.

    Parameters
    ----------
    cell : torch.Tensor
        Unit cell matrix with lattice vectors as rows.
        Shape (3, 3) for single system or (B, 3, 3) for batch.
    k_cutoff : float or torch.Tensor
        Reciprocal-space cutoff.  It sets the per-axis Miller index bounds
        (via ``_generate_miller_indices``).  Typical values: 8-12
        :math:`\text{\AA}^{-1}` for molecular systems.  Higher values
        increase accuracy but also computational cost.

    Returns
    -------
    torch.Tensor
        Half-space reciprocal lattice vectors, k=0 excluded.
        Shape (K, 3) for single system or (B, K, 3) for batch.  A leading
        batch dimension of size one is squeezed away.

    Examples
    --------
    Single system with explicit k_cutoff::

        >>> cell = torch.eye(3, dtype=torch.float64) * 10.0
        >>> k_vectors = generate_k_vectors_ewald_summation(cell, k_cutoff=8.0)
        >>> k_vectors.shape
        torch.Size([...])  # Number depends on cell size and cutoff

    With automatic parameter estimation::

        >>> from nvalchemiops.interactions.electrostatics import estimate_ewald_parameters
        >>> params = estimate_ewald_parameters(positions, cell)
        >>> k_vectors = generate_k_vectors_ewald_summation(cell, params.reciprocal_space_cutoff)

    Notes
    -----
    - The k=0 vector is always excluded (it would cause a division by zero
      in the Green's function).
    - For batch mode, the same set of Miller indices is used for all systems
      but transformed using each system's reciprocal cell.
    - Every index triple inside the per-axis bounds that passes the
      half-space filter is returned; this function applies no additional
      spherical :math:`|\mathbf{k}| \leq k_{\text{cutoff}}` mask.
    - The number of k-vectors K scales as O(k_cutoff³ · V) where V is the
      cell volume.

    See Also
    --------
    ewald_reciprocal_space : Uses these k-vectors for reciprocal space energy.
    estimate_ewald_parameters : Automatic parameter estimation including k_cutoff.
    """
    # Work on a batched view so single and batched inputs share one code path.
    batched_cell = cell.unsqueeze(0) if cell.ndim == 2 else cell

    # 2 * bound + 1 grid points per axis cover indices -bound..+bound.
    n_h, n_k, n_l = 2 * _generate_miller_indices(batched_cell, k_cutoff) + 1

    # fftfreq(n) * n enumerates the indices in FFT order: 0, 1, ..., -2, -1.
    axes = [
        torch.fft.fftfreq(n, device=batched_cell.device, dtype=batched_cell.dtype) * n
        for n in (n_h, n_k, n_l)
    ]
    grids = torch.meshgrid(*axes, indexing="ij")
    hkl = torch.stack([grid.flatten() for grid in grids], dim=1)

    # Half-space filter: keep exactly one vector out of each +-k pair and
    # drop the origin (h = k = l = 0 satisfies none of the three clauses).
    h_idx, k_idx, l_idx = hkl[:, 0], hkl[:, 1], hkl[:, 2]
    h_is_zero = h_idx == 0
    keep = (
        (h_idx > 0)
        | (h_is_zero & (k_idx > 0))
        | (h_is_zero & (k_idx == 0) & (l_idx > 0))
    )
    hkl = hkl[keep]

    # Rows of 2*pi*(cell^T)^-1 are the reciprocal lattice vectors a*, b*, c*.
    inverse_transposed, _ = torch.linalg.inv_ex(batched_cell.transpose(1, 2))
    reciprocal_cell = TWOPI * inverse_transposed

    # (K, 3) @ (B, 3, 3) -> (B, K, 3); squeeze drops a size-one batch axis.
    return (hkl.to(reciprocal_cell.dtype) @ reciprocal_cell).squeeze(0)
[docs] def generate_k_vectors_pme( cell: torch.Tensor, mesh_dimensions: tuple[int, int, int], reciprocal_cell: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Generate reciprocal lattice vectors for Particle Mesh Ewald (PME). Creates k-vectors on a regular grid compatible with FFT-based reciprocal space calculations in PME. Uses rfft conventions (half-size in z-dimension) to exploit Hermitian symmetry of real-valued charge densities. Notes ----- For a direct lattice defined by basis vectors {a, b, c} (rows of cell matrix), the reciprocal lattice vectors are: .. math:: \\begin{aligned} \\mathbf{a}^* &= \\frac{2\\pi (\\mathbf{b} \\times \\mathbf{c})}{V} \\\\ \\mathbf{b}^* &= \\frac{2\\pi (\\mathbf{c} \\times \\mathbf{a})}{V} \\\\ \\mathbf{c}^* &= \\frac{2\\pi (\\mathbf{a} \\times \\mathbf{b})}{V} \\end{aligned} where :math:`V = \\mathbf{a} \\cdot (\\mathbf{b} \\times \\mathbf{c})` is the cell volume. In matrix form: .. math:: \\text{reciprocal_matrix} = 2\\pi \\cdot (\\text{cell}^T)^{-1} Each k-vector is then: .. math:: \\mathbf{k} = h \\mathbf{a}^* + k \\mathbf{b}^* + l \\mathbf{c}^* where (h, k, l) are Miller indices (integers). Parameters ---------- cell : torch.Tensor Unit cell matrix with lattice vectors as rows. Shape (3, 3) for single system or (B, 3, 3) for batch. mesh_dimensions : tuple[int, int, int] PME mesh grid dimensions (nx, ny, nz). Should typically be chosen such that mesh spacing is :math:`\\sim 1 \\text{\\AA}` or finer. Power-of-2 dimensions are optimal for FFT performance. reciprocal_cell : torch.Tensor, optional Precomputed reciprocal cell matrix (:math:`2\\pi \\cdot \\text{cell}^{-1}`). If provided, skips the inverse computation. Shape (3, 3) or (B, 3, 3). Returns ------- k_vectors : torch.Tensor, shape (nx, ny, nz//2+1, 3) Cartesian k-vectors at each grid point. Uses rfft convention where z-dimension is halved due to Hermitian symmetry. 
k_squared_safe : torch.Tensor, shape (nx, ny, nz//2+1) Squared magnitude :math:`|\\mathbf{k}|^2` for each k-vector, with k=0 set to a small positive value (1e-12) to avoid division by zero. Examples -------- Basic usage:: >>> cell = torch.eye(3, dtype=torch.float64) * 10.0 >>> mesh_dims = (32, 32, 32) >>> k_vectors, k_squared = generate_k_vectors_pme(cell, mesh_dims) >>> k_vectors.shape torch.Size([32, 32, 17, 3]) With precomputed reciprocal cell:: >>> reciprocal_cell = 2 * torch.pi * torch.linalg.inv(cell) >>> k_vectors, k_squared = generate_k_vectors_pme( ... cell, mesh_dims, reciprocal_cell=reciprocal_cell ... ) Notes ----- - The z-dimension output size is nz//2+1 due to rfft symmetry. - Miller indices follow torch.fft.fftfreq convention (0, 1, 2, ..., -2, -1). - k_squared_safe has k=0 replaced with 1e-12 to prevent division by zero in Green's function calculations. See Also -------- pme_reciprocal_space : Uses these k-vectors for PME reciprocal space energy. pme_green_structure_factor : Computes Green's function using k_squared. 
""" device = cell.device dtype = cell.dtype # Ensure cell has batch dimension cell_3d = cell if cell.dim() == 3 else cell.unsqueeze(0) # Compute reciprocal lattice vectors (2*pi times reciprocal of direct lattice) if reciprocal_cell is None: reciprocal_cell = TWOPI * torch.linalg.inv_ex(cell_3d)[0] # Generate all combinations of Miller indices mesh_grid_x, mesh_grid_y, mesh_grid_z = mesh_dimensions # Generate Miller indices (h, k, l) for each FFT grid point # fftfreq gives frequencies normalized to sampling rate # Multiplying by n gives actual Miller indices kx = torch.fft.fftfreq(mesh_grid_x, d=1.0, device=device, dtype=dtype) * mesh_grid_x ky = torch.fft.fftfreq(mesh_grid_y, d=1.0, device=device, dtype=dtype) * mesh_grid_y kz = ( torch.fft.rfftfreq(mesh_grid_z, d=1.0, device=device, dtype=dtype) * mesh_grid_z ) kx_grid, ky_grid, kz_grid = torch.meshgrid(kx, ky, kz, indexing="ij") # Stack into Miller indices array (nx, ny, nz/2+1, 3) k_grid = torch.stack([kx_grid, ky_grid, kz_grid], dim=-1) # Transform Miller indices to Cartesian k-vectors # k_cart = [h, k, l] @ reciprocal_matrix^T # where reciprocal_matrix has reciprocal lattice vectors as rows k_vectors = torch.einsum("ijkd,bcd->bijkc", k_grid, reciprocal_cell).squeeze(0) # Compute k^2 for Green's function k_squared = torch.sum(k_vectors**2, dim=-1) # Avoid division by zero at k=0 k_squared_safe = torch.where( k_squared > 1e-12, k_squared, torch.tensor(1e-12, device=device) ) return k_vectors, k_squared_safe