# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Cell Filter Utilities for Variable-Cell Optimization.
This module provides utilities for combining atomic and cell degrees of freedom
into extended arrays, enabling standard optimizers (FIRE, BFGS, etc.) to perform
variable-cell optimization without modification.
The approach follows the "filter" pattern:
- Atomic positions (3N DOFs) + cell parameters (6 DOFs) → extended positions (3N + 6)
- Atomic forces + stress tensor → extended forces (3N + 6)
- Standard optimizer operates on extended arrays
- Results are unpacked back to atomic positions and cell
Key features:
- Cell alignment to upper-triangular form for stability
- 6-DOF cell representation (upper-triangular: a, b*cos(γ), b*sin(γ), c1, c2, c3)
- Stress-to-cell-force conversion with proper volume scaling
- batch_idx/atom_ptr extension for batched systems
Usage workflow:
1. align_cell() - One-time preprocessing to put cell in standard form
2. extend_batch_idx() or extend_atom_ptr() - Update batching arrays for extended DOFs
3. pack_*() - Combine atomic + cell DOFs into extended arrays
4. Run optimizer step on extended arrays
5. unpack_*() - Extract atomic positions and cell from extended arrays
6. Compute forces/stress with your calculator
7. pack_forces_with_cell() - Combine forces and stress for next step
"""
from __future__ import annotations
from typing import Any
import warp as wp
from nvalchemiops.batch_utils import atom_ptr_to_batch_idx
__all__ = [
# Cell alignment
"align_cell",
# Batch index extension
"extend_batch_idx",
"extend_atom_ptr",
# Pack utilities
"pack_positions_with_cell",
"pack_velocities_with_cell",
"pack_forces_with_cell",
"pack_masses_with_cell",
# Unpack utilities
"unpack_positions_with_cell",
"unpack_velocities_with_cell",
# Stress conversion
"stress_to_cell_force",
]
# ==============================================================================
# Cell Alignment Kernel
# ==============================================================================
@wp.kernel
def _align_cell_kernel(
    cell: wp.array(dtype=Any),
    transform: wp.array(dtype=Any),
):
    r"""Align cell to the standard triangular (right-handed) form.

    Each cell matrix (rows are the lattice vectors) is rewritten as

    .. math::

        \mathbf{H} = \begin{pmatrix}
            a & 0 & 0 \\
            b\cos\gamma & b\sin\gamma & 0 \\
            c_1 & c_2 & c_3
        \end{pmatrix}

    where a, b, c are lattice vector lengths and γ is the angle between a and b.
    The rest of this module calls this layout "upper-triangular". It has
    6 independent parameters instead of 9, which removes the rotational
    ambiguity of a general cell.

    The transformation matrix stored per system is
    ``inverse(aligned_cell) * original_cell``. The apply-transform kernels in
    this module use it as ``new_position = transform * old_position``
    (matrix times column vector).

    Parameters
    ----------
    cell : wp.array, shape (B,), dtype=wp.mat33*
        Cell matrices (in-place, overwritten with aligned cells; zero-volume
        cells are left untouched).
    transform : wp.array, shape (B,), dtype=wp.mat33*
        OUTPUT: transformation matrices for the position update. Every entry
        is written: the rotation for a regular cell, identity for a
        zero-volume cell.

    Launch Grid
    -----------
    dim = [num_systems]

    Notes
    -----
    - Adapted from alchemistudio2 implementation.
    - A negative-volume (left-handed) cell is handled by negating all three
      lattice vectors before alignment.
    - A zero-volume cell cannot be aligned. Its transform is set to identity
      so that the subsequent position update is a no-op even when the caller
      allocated ``transform`` with zeros.
    """
    tid = wp.tid()
    if tid >= cell.shape[0]:
        return

    _cell = cell[tid]
    _one = type(_cell[0, 0])(1.0)
    _zero = type(_cell[0, 0])(0.0)

    vol = wp.determinant(_cell)

    # Handle zero volume (degenerate cell): leave the cell as it is, but make
    # the transform an explicit identity so positions are never multiplied by
    # whatever the scratch buffer happened to contain.
    if vol == _zero:
        transform[tid] = type(_cell)(
            _one,
            _zero,
            _zero,
            _zero,
            _one,
            _zero,
            _zero,
            _zero,
            _one,
        )
        return

    # Ensure right-handed cell
    if vol < _zero:
        _cell = type(_cell[0, 0])(-1.0) * _cell

    # Compute lattice parameters
    a = wp.length(_cell[0])
    b = wp.length(_cell[1])
    c = wp.length(_cell[2])

    # Compute angles (cosines)
    cos_alpha = wp.dot(_cell[1], _cell[2]) / (b * c)  # angle between b and c
    cos_beta = wp.dot(_cell[0], _cell[2]) / (a * c)  # angle between a and c
    cos_gamma = wp.dot(_cell[0], _cell[1]) / (a * b)  # angle between a and b
    # Clamp at zero so rounding error cannot produce sqrt of a negative number.
    sin_gamma = wp.sqrt(wp.max(_zero, _one - cos_gamma * cos_gamma))

    # Compute c vector components in aligned frame
    c1 = c * cos_beta
    c2 = (c * (cos_alpha - cos_beta * cos_gamma)) / sin_gamma
    c3 = wp.sqrt(wp.max(_zero, c * c - c1 * c1 - c2 * c2))

    # Construct aligned cell (rows are the lattice vectors a, b, c)
    cell_r = type(_cell)(
        a,
        _zero,
        _zero,
        b * cos_gamma,
        b * sin_gamma,
        _zero,
        c1,
        c2,
        c3,
    )

    # Compute transformation matrix: r = cell_r_inv @ original_cell
    cell_r_inv = wp.inverse(cell_r)
    r = cell_r_inv * _cell

    # Store results
    cell[tid] = cell_r
    transform[tid] = r
@wp.kernel
def _apply_transform_single_kernel(
    positions: wp.array(dtype=Any),
    transform: wp.array(dtype=Any),
):
    """Rotate every atom of a single system by the alignment transform.

    Each thread replaces one position with ``transform[0] @ position``
    (matrix times column vector). Only the first transform entry is read.

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=vec3f or vec3d
        Atomic positions. Modified in-place.
    transform : wp.array, shape (1,), dtype=mat33f or mat33d
        Transformation matrix produced by _align_cell_kernel.

    Launch Grid
    -----------
    dim = num_atoms
    """
    atom = wp.tid()
    rotation = transform[0]
    positions[atom] = wp.mul(rotation, positions[atom])
@wp.kernel
def _apply_transform_kernel(
    positions: wp.array(dtype=Any),
    transform: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
):
    """Rotate every atom of a batch by the transform of its own system.

    Each thread looks up the system of its atom through ``batch_idx`` and
    replaces the position with ``transform[system] @ position``.

    Parameters
    ----------
    positions : wp.array, shape (total_atoms,), dtype=vec3f or vec3d
        Concatenated atomic positions. Modified in-place.
    transform : wp.array, shape (num_systems,), dtype=mat33f or mat33d
        Per-system transformation matrices produced by _align_cell_kernel.
    batch_idx : wp.array, shape (total_atoms,), dtype=int32
        System index for each atom.

    Launch Grid
    -----------
    dim = total_atoms
    """
    atom = wp.tid()
    rotation = transform[batch_idx[atom]]
    positions[atom] = wp.mul(rotation, positions[atom])
# ==============================================================================
# Pack/Unpack Kernels for Extended Arrays
# ==============================================================================
@wp.kernel
def _pack_positions_kernel(
    positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    extended: wp.array(dtype=Any),
    num_atoms: wp.int32,
):
    """Write positions followed by the 6 cell parameters (single system).

    The first ``num_atoms`` threads copy one position each. The two threads
    after that each write one vec3 built from the non-zero entries of
    ``cell[0]``:

        extended[N]     = [H[0,0], H[1,0], H[2,0]] = [a, b*cos(γ), c1]
        extended[N + 1] = [H[1,1], H[2,1], H[2,2]] = [b*sin(γ), c2, c3]

    Entries H[0,1], H[0,2] and H[1,2] are never read, so the cell is expected
    to already be in the aligned form produced by align_cell().

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype=vec3f or vec3d
        Atomic positions.
    cell : wp.array, shape (1,), dtype=mat33f or mat33d
        Cell matrix; only element 0 is read.
    extended : wp.array, shape (N + 2,), dtype=vec3f or vec3d
        OUTPUT: Extended position array. Modified in-place.
    num_atoms : wp.int32
        Number of atoms (N).

    Launch Grid
    -----------
    dim = num_atoms + 2
    """
    slot = wp.tid()

    # Atom slots are a plain copy.
    if slot < num_atoms:
        extended[slot] = positions[slot]
        return

    # Anything past the two cell slots has nothing to do.
    if slot > num_atoms + 1:
        return

    H = cell[0]
    if slot == num_atoms:
        # First cell vec3: first column of the aligned cell.
        extended[slot] = type(positions[0])(H[0, 0], H[1, 0], H[2, 0])
    else:
        # Second cell vec3: remaining non-zero entries.
        extended[slot] = type(positions[0])(H[1, 1], H[2, 1], H[2, 2])
@wp.kernel
def _unpack_positions_kernel(
    extended: wp.array(dtype=Any),
    positions: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    num_atoms: wp.int32,
):
    """Split an extended position array into positions and cell (single system).

    Inverse of _pack_positions_kernel. The first ``num_atoms`` threads copy one
    position each. Thread ``num_atoms`` alone rebuilds the whole cell matrix,
    because it needs both packed vec3s; thread ``num_atoms + 1`` does nothing.

    Parameters
    ----------
    extended : wp.array, shape (N + 2,), dtype=vec3f or vec3d
        Extended position array.
    positions : wp.array, shape (N,), dtype=vec3f or vec3d
        OUTPUT: Atomic positions. Modified in-place.
    cell : wp.array, shape (1,), dtype=mat33f or mat33d
        OUTPUT: Reconstructed cell matrix, written to element 0.
    num_atoms : wp.int32
        Number of atoms (N).

    Launch Grid
    -----------
    dim = num_atoms + 2
    """
    slot = wp.tid()

    if slot < num_atoms:
        positions[slot] = extended[slot]
        return

    if slot != num_atoms:
        return

    first = extended[num_atoms]  # [a, b*cos(γ), c1]
    second = extended[num_atoms + 1]  # [b*sin(γ), c2, c3]
    zero = type(first[0])(0.0)

    # Rows of the rebuilt cell:
    #   [a,        0,        0 ]
    #   [b*cos(γ), b*sin(γ), 0 ]
    #   [c1,       c2,       c3]
    cell[0] = type(cell[0])(
        first[0],
        zero,
        zero,
        first[1],
        second[0],
        zero,
        first[2],
        second[1],
        second[2],
    )
@wp.kernel
def _pack_forces_kernel(
    forces: wp.array(dtype=Any),
    cell_force: wp.array(dtype=Any),
    extended: wp.array(dtype=Any),
    num_atoms: wp.int32,
):
    """Write forces followed by the 6 cell-force components (single system).

    Uses the same slot layout as the position packing kernel:

        extended[N]     = [Fc[0,0], Fc[1,0], Fc[2,0]]
        extended[N + 1] = [Fc[1,1], Fc[2,1], Fc[2,2]]

    where ``Fc = cell_force[0]``. Entries Fc[0,1], Fc[0,2] and Fc[1,2] are
    not packed.

    Parameters
    ----------
    forces : wp.array, shape (N,), dtype=vec3f or vec3d
        Atomic forces.
    cell_force : wp.array, shape (1,), dtype=mat33f or mat33d
        Cell force matrix (typically from stress_to_cell_force).
    extended : wp.array, shape (N + 2,), dtype=vec3f or vec3d
        OUTPUT: Extended force array. Modified in-place.
    num_atoms : wp.int32
        Number of atoms (N).

    Launch Grid
    -----------
    dim = num_atoms + 2
    """
    slot = wp.tid()

    if slot < num_atoms:
        extended[slot] = forces[slot]
        return

    if slot > num_atoms + 1:
        return

    Fc = cell_force[0]
    if slot == num_atoms:
        # First cell force vec3
        extended[slot] = type(forces[0])(Fc[0, 0], Fc[1, 0], Fc[2, 0])
    else:
        # Second cell force vec3
        extended[slot] = type(forces[0])(Fc[1, 1], Fc[2, 1], Fc[2, 2])
@wp.kernel
def _pack_masses_kernel(
    masses: wp.array(dtype=Any),
    cell_mass: wp.array(dtype=Any),
    extended: wp.array(dtype=Any),
    num_atoms: wp.int32,
):
    """Write atomic masses followed by the cell mass (single system).

    Threads below ``num_atoms`` copy one atomic mass. Every other thread
    writes ``cell_mass[0]``, so with the documented launch size the two cell
    slots both receive the same scalar.

    Parameters
    ----------
    masses : wp.array, shape (N,), dtype=float32 or float64
        Atomic masses.
    cell_mass : wp.array, shape (1,), dtype=float32 or float64
        Mass for the cell DOFs; only element 0 is read.
    extended : wp.array, shape (N + 2,), dtype=float32 or float64
        OUTPUT: Extended mass array. Modified in-place.
    num_atoms : wp.int32
        Number of atoms (N).

    Launch Grid
    -----------
    dim = num_atoms + 2
    """
    slot = wp.tid()

    if slot >= num_atoms:
        # Cell DOFs get the cell mass
        extended[slot] = cell_mass[0]
        return

    extended[slot] = masses[slot]
# ==============================================================================
# Batched Pack/Unpack Kernels (for use with atom_ptr + batch_idx)
# ==============================================================================
# Two-kernel pattern per pack/unpack operation:
# - Atom kernel (dim=N): each thread copies one atom position/force/velocity
# - Cell kernel (dim=M): each thread writes/reads 2 cell DOFs per system
@wp.kernel(enable_backward=False)
def _pack_atoms_batched_kernel(
    src: wp.array(dtype=Any),
    extended: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    atom_ptr: wp.array(dtype=wp.int32),
    ext_atom_ptr: wp.array(dtype=wp.int32),
):
    """Scatter one per-atom value into the interleaved extended array.

    The atom keeps its offset within its system; only the system's start
    moves from ``atom_ptr[s]`` to ``ext_atom_ptr[s]``.

    Launch Grid: dim = N (total atoms).
    """
    atom = wp.tid()
    system = batch_idx[atom]
    offset_in_system = atom - atom_ptr[system]
    target = ext_atom_ptr[system] + offset_in_system
    extended[target] = src[atom]
@wp.kernel(enable_backward=False)
def _pack_cell_dofs_kernel(
    cells: wp.array(dtype=Any),
    extended: wp.array(dtype=Any),
    atom_ptr: wp.array(dtype=wp.int32),
    ext_atom_ptr: wp.array(dtype=wp.int32),
):
    """Append one system's 6 cell parameters right after its atoms.

    The two vec3 slots directly follow the system's atom block in the
    extended array and hold [H[0,0], H[1,0], H[2,0]] and
    [H[1,1], H[2,1], H[2,2]] of that system's cell.

    Launch Grid: dim = M (num_systems).
    """
    system = wp.tid()
    atom_count = atom_ptr[system + 1] - atom_ptr[system]
    first_cell_slot = ext_atom_ptr[system] + atom_count

    H = cells[system]
    extended[first_cell_slot] = type(extended[0])(H[0, 0], H[1, 0], H[2, 0])
    extended[first_cell_slot + 1] = type(extended[0])(H[1, 1], H[2, 1], H[2, 2])
@wp.kernel(enable_backward=False)
def _pack_cell_force_dofs_kernel(
    cell_forces: wp.array(dtype=Any),
    extended: wp.array(dtype=Any),
    atom_ptr: wp.array(dtype=wp.int32),
    ext_atom_ptr: wp.array(dtype=wp.int32),
):
    """Append one system's 6 cell-force components right after its atoms.

    Same slot layout as the cell-parameter packing: the two vec3 slots hold
    [Fc[0,0], Fc[1,0], Fc[2,0]] and [Fc[1,1], Fc[2,1], Fc[2,2]].

    Launch Grid: dim = M (num_systems).
    """
    system = wp.tid()
    atom_count = atom_ptr[system + 1] - atom_ptr[system]
    first_cell_slot = ext_atom_ptr[system] + atom_count

    Fc = cell_forces[system]
    extended[first_cell_slot] = type(extended[0])(Fc[0, 0], Fc[1, 0], Fc[2, 0])
    extended[first_cell_slot + 1] = type(extended[0])(Fc[1, 1], Fc[2, 1], Fc[2, 2])
@wp.kernel(enable_backward=False)
def _unpack_atoms_batched_kernel(
    extended: wp.array(dtype=Any),
    dst: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    atom_ptr: wp.array(dtype=wp.int32),
    ext_atom_ptr: wp.array(dtype=wp.int32),
):
    """Gather one per-atom value back out of the interleaved extended array.

    Mirror image of the batched atom packing kernel: the source index is the
    system's extended start plus the atom's offset within its system.

    Launch Grid: dim = N (total atoms).
    """
    atom = wp.tid()
    system = batch_idx[atom]
    offset_in_system = atom - atom_ptr[system]
    source = ext_atom_ptr[system] + offset_in_system
    dst[atom] = extended[source]
@wp.kernel(enable_backward=False)
def _unpack_cell_dofs_kernel(
    extended: wp.array(dtype=Any),
    cells: wp.array(dtype=Any),
    atom_ptr: wp.array(dtype=wp.int32),
    ext_atom_ptr: wp.array(dtype=wp.int32),
):
    """Rebuild one system's cell matrix from its two packed vec3 slots.

    The slots sit directly after the system's atoms in the extended array.
    Entries [0,1], [0,2] and [1,2] of the rebuilt matrix are set to zero.

    Launch Grid: dim = M (num_systems).
    """
    system = wp.tid()
    atom_count = atom_ptr[system + 1] - atom_ptr[system]
    first_cell_slot = ext_atom_ptr[system] + atom_count

    first = extended[first_cell_slot]  # [H[0,0], H[1,0], H[2,0]]
    second = extended[first_cell_slot + 1]  # [H[1,1], H[2,1], H[2,2]]
    zero = type(first[0])(0.0)

    cells[system] = type(cells[0])(
        first[0],
        zero,
        zero,
        first[1],
        second[0],
        zero,
        first[2],
        second[1],
        second[2],
    )
@wp.kernel
def _pack_masses_batched_kernel(
    masses: wp.array(dtype=Any),
    cell_masses: wp.array(dtype=Any),
    extended: wp.array(dtype=Any),
    atom_ptr: wp.array(dtype=wp.int32),
    ext_atom_ptr: wp.array(dtype=wp.int32),
):
    """Write each system's atomic masses followed by its cell mass (batched).

    One thread owns one whole system: it copies that system's atomic masses
    in a serial loop and then fills the two trailing cell slots with the
    system's cell mass.

    Parameters
    ----------
    masses : wp.array, shape (total_atoms,), dtype=float32 or float64
        Concatenated atomic masses for all systems.
    cell_masses : wp.array, shape (num_systems,), dtype=float32 or float64
        Cell mass for each system.
    extended : wp.array, shape (total_atoms + 2*num_systems,), dtype=float32 or float64
        OUTPUT: Extended mass array. Modified in-place.
    atom_ptr : wp.array, shape (num_systems + 1,), dtype=int32
        CSR-style pointer for original masses.
    ext_atom_ptr : wp.array, shape (num_systems + 1,), dtype=int32
        CSR-style pointer for extended array (from extend_atom_ptr).

    Launch Grid
    -----------
    dim = num_systems
    """
    system = wp.tid()
    src_begin = atom_ptr[system]
    atom_count = atom_ptr[system + 1] - src_begin
    dst_begin = ext_atom_ptr[system]

    # Atomic masses, copied one after another by this thread.
    for k in range(atom_count):
        extended[dst_begin + k] = masses[src_begin + k]

    # Both cell slots of this system share the same mass.
    first_cell_slot = dst_begin + atom_count
    extended[first_cell_slot] = cell_masses[system]
    extended[first_cell_slot + 1] = cell_masses[system]
# ==============================================================================
# Batch Index Extension Kernels
# ==============================================================================
@wp.kernel
def _extend_batch_idx_kernel(
    batch_idx: wp.array(dtype=wp.int32),
    extended_batch_idx: wp.array(dtype=wp.int32),
    num_atoms: wp.int32,
    num_systems: wp.int32,
):
    """Build a batch index that also covers the cell DOF slots.

    The first ``num_atoms`` entries are copied from ``batch_idx``. The entries
    after them are the cell slots, two per system, in system order:

        [atom_0_sys, ..., atom_N-1_sys,   <- original atoms
         0, 0,                            <- system 0 cell DOFs
         1, 1,                            <- system 1 cell DOFs
         ...]

    NOTE(review): this places all cell slots after all atoms, whereas the
    batched pack kernels in this module place each system's cell slots
    directly after that system's atoms. The two layouts only coincide for a
    single system - confirm which one the consumer expects.

    Parameters
    ----------
    batch_idx : wp.array, shape (num_atoms,), dtype=int32
        Original system index for each atom.
    extended_batch_idx : wp.array, shape (num_atoms + 2*num_systems,), dtype=int32
        OUTPUT: Extended batch index including cell DOFs. Modified in-place.
    num_atoms : wp.int32
        Total number of atoms across all systems (N).
    num_systems : wp.int32
        Number of systems (B). Not read by the kernel body; the launch size
        determines how many cell slots are written.

    Launch Grid
    -----------
    dim = num_atoms + 2 * num_systems
    """
    slot = wp.tid()

    if slot < num_atoms:
        # Atom slots keep their original system index.
        extended_batch_idx[slot] = batch_idx[slot]
        return

    # Cell slots: slot = num_atoms + 2*system + offset (offset = 0 or 1)
    extended_batch_idx[slot] = (slot - num_atoms) / 2
@wp.kernel
def _extend_atom_ptr_kernel(
    atom_ptr: wp.array(dtype=wp.int32),
    extended_atom_ptr: wp.array(dtype=wp.int32),
):
    """Shift CSR atom pointers to make room for 2 cell slots per system.

    Every system contributes two extra entries (its 6 cell DOFs as 2 vec3s),
    so pointer ``i`` moves forward by the slots of the ``i`` systems before it:

        extended_atom_ptr[i] = atom_ptr[i] + 2 * i

    Example:
        atom_ptr     = [0, 50, 100]   # 2 systems with 50 atoms each
        ext_atom_ptr = [0, 52, 104]   # 50+2=52, 100+4=104

    Parameters
    ----------
    atom_ptr : wp.array, shape (num_systems + 1,), dtype=int32
        Original CSR-style atom pointers.
    extended_atom_ptr : wp.array, shape (num_systems + 1,), dtype=int32
        OUTPUT: Extended CSR-style pointers. Modified in-place.

    Launch Grid
    -----------
    dim = num_systems + 1
    """
    entry = wp.tid()
    cell_slots_before = 2 * entry
    extended_atom_ptr[entry] = atom_ptr[entry] + cell_slots_before
# ==============================================================================
# Stress to Cell Force Conversion Kernel
# ==============================================================================
@wp.kernel
def _stress_to_cell_force_kernel(
    stress: wp.array(dtype=Any),
    cell: wp.array(dtype=Any),
    volume: wp.array(dtype=Any),
    cell_force: wp.array(dtype=Any),
    keep_aligned: wp.bool,
):
    r"""Turn a stress tensor into the generalized force acting on the cell.

    For every system the kernel evaluates

    .. math::

        \mathbf{F}_{\text{cell}} = -V \cdot \boldsymbol{\sigma} \cdot (\mathbf{H}^{-1})^T

    with V the cell volume, σ the stress tensor and H the cell matrix. The
    inverse is taken with the general 3x3 ``wp.inverse``.

    Parameters
    ----------
    stress : wp.array, shape (B,), dtype=wp.mat33*
        Stress tensor in tension-positive (negative for compression)
        convention, in energy/volume units. For zero-pressure relaxation
        this is typically ``virial / V`` where virial = −Σ r⊗F from the
        LJ kernel. For finite external pressure use
        ``P_ext − P_internal`` (see ``virial_to_stress``).
    cell : wp.array, shape (B,), dtype=wp.mat33*
        Cell matrices (should be in the aligned form from align_cell).
    volume : wp.array, shape (B,), dtype=wp.float*
        Cell volumes.
    cell_force : wp.array, shape (B,), dtype=wp.mat33*
        OUTPUT: cell force matrices.
    keep_aligned : wp.bool
        If True, entries [0,1], [0,2] and [1,2] of the result are replaced
        by zero so that a step along the force cannot populate the cell
        entries that the aligned form keeps at zero.

    Launch Grid
    -----------
    dim = [num_systems]

    Notes
    -----
    - Because of the negative prefactor, negative stress (compression) gives
      a positive cell force that expands the cell, and positive stress
      (tension) gives a negative cell force that contracts it.
    - The pack kernels in this module never read the entries zeroed by
      ``keep_aligned``; the flag matters for consumers of the full matrix.
    """
    system = wp.tid()

    sigma = stress[system]
    # (H^{-1})^T: the formula needs the transposed inverse on the right.
    inv_cell_T = wp.transpose(wp.inverse(cell[system]))

    minus_volume = type(sigma[0, 0])(-1.0) * volume[system]
    force = minus_volume * wp.mul(sigma, inv_cell_T)

    if not keep_aligned:
        cell_force[system] = force
        return

    # Keep the lower triangle (including the diagonal), zero everything above.
    zero = type(sigma[0, 0])(0.0)
    cell_force[system] = type(force)(
        force[0, 0],
        zero,
        zero,
        force[1, 0],
        force[1, 1],
        zero,
        force[2, 0],
        force[2, 1],
        force[2, 2],
    )
# ==============================================================================
# Kernel Overloads for Explicit Typing
# ==============================================================================
_T = [wp.float32, wp.float64]  # Scalar types
_V = [wp.vec3f, wp.vec3d]  # Vector types
_M = [wp.mat33f, wp.mat33d]  # Matrix types

# One registry per generic kernel. Each maps a concrete Warp dtype to the
# kernel overload specialised for that precision; the key type (scalar,
# vector or matrix) is noted next to each registry.

# Cell alignment
_align_cell_kernel_overload = {}  # keyed by matrix dtype
_apply_transform_single_kernel_overload = {}  # keyed by vector dtype
_apply_transform_kernel_overload = {}  # keyed by vector dtype

# Pack/unpack, single system
_pack_positions_kernel_overload = {}  # keyed by vector dtype
_unpack_positions_kernel_overload = {}  # keyed by vector dtype
_pack_forces_kernel_overload = {}  # keyed by vector dtype
_pack_masses_kernel_overload = {}  # keyed by scalar dtype

# Pack/unpack, batched (atom_ptr + batch_idx)
_pack_atoms_batched_kernel_overload = {}  # keyed by vector dtype
_pack_cell_dofs_kernel_overload = {}  # keyed by vector dtype
_pack_cell_force_dofs_kernel_overload = {}  # keyed by vector dtype
_unpack_atoms_batched_kernel_overload = {}  # keyed by vector dtype
_unpack_cell_dofs_kernel_overload = {}  # keyed by vector dtype
_pack_masses_batched_kernel_overload = {}  # keyed by scalar dtype

# Stress to cell force
_stress_to_cell_force_kernel_overload = {}  # keyed by matrix dtype

# int32 index-array annotation shared by all batched signatures.
_i32 = wp.array(dtype=wp.int32)

# Walk the float32 and float64 families in lockstep: (scalar, vec3, mat33).
for t, v, m in zip(_T, _V, _M):
    _scalar_arr = wp.array(dtype=t)
    _vec_arr = wp.array(dtype=v)
    _mat_arr = wp.array(dtype=m)

    # Cell alignment kernels
    _align_cell_kernel_overload[m] = wp.overload(
        _align_cell_kernel, [_mat_arr, _mat_arr]
    )
    _apply_transform_single_kernel_overload[v] = wp.overload(
        _apply_transform_single_kernel, [_vec_arr, _mat_arr]
    )
    _apply_transform_kernel_overload[v] = wp.overload(
        _apply_transform_kernel, [_vec_arr, _mat_arr, _i32]
    )

    # Single-system pack/unpack kernels
    _pack_positions_kernel_overload[v] = wp.overload(
        _pack_positions_kernel, [_vec_arr, _mat_arr, _vec_arr, wp.int32]
    )
    _unpack_positions_kernel_overload[v] = wp.overload(
        _unpack_positions_kernel, [_vec_arr, _vec_arr, _mat_arr, wp.int32]
    )
    _pack_forces_kernel_overload[v] = wp.overload(
        _pack_forces_kernel, [_vec_arr, _mat_arr, _vec_arr, wp.int32]
    )
    _pack_masses_kernel_overload[t] = wp.overload(
        _pack_masses_kernel, [_scalar_arr, _scalar_arr, _scalar_arr, wp.int32]
    )

    # Batched pack/unpack kernels (atom_ptr + batch_idx)
    _pack_atoms_batched_kernel_overload[v] = wp.overload(
        _pack_atoms_batched_kernel, [_vec_arr, _vec_arr, _i32, _i32, _i32]
    )
    _pack_cell_dofs_kernel_overload[v] = wp.overload(
        _pack_cell_dofs_kernel, [_mat_arr, _vec_arr, _i32, _i32]
    )
    _pack_cell_force_dofs_kernel_overload[v] = wp.overload(
        _pack_cell_force_dofs_kernel, [_mat_arr, _vec_arr, _i32, _i32]
    )
    _unpack_atoms_batched_kernel_overload[v] = wp.overload(
        _unpack_atoms_batched_kernel, [_vec_arr, _vec_arr, _i32, _i32, _i32]
    )
    _unpack_cell_dofs_kernel_overload[v] = wp.overload(
        _unpack_cell_dofs_kernel, [_vec_arr, _mat_arr, _i32, _i32]
    )
    _pack_masses_batched_kernel_overload[t] = wp.overload(
        _pack_masses_batched_kernel,
        [_scalar_arr, _scalar_arr, _scalar_arr, _i32, _i32],
    )

    # Stress to cell force kernel
    _stress_to_cell_force_kernel_overload[m] = wp.overload(
        _stress_to_cell_force_kernel,
        [_mat_arr, _mat_arr, _scalar_arr, _mat_arr, wp.bool],
    )
# ==============================================================================
# Functional Interfaces
# ==============================================================================
[docs]
def align_cell(
    positions: wp.array,
    cell: wp.array,
    transform: wp.array,
    batch_idx: wp.array | None = None,
    device: str | None = None,
) -> tuple[wp.array, wp.array]:
    """
    Align cell to upper-triangular form and transform positions accordingly.

    This is a one-time preprocessing step before variable-cell optimization.
    The cell is transformed to the standard upper-triangular form, and
    positions are rotated to maintain their fractional coordinates.

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic positions. Shape (N,). Modified in-place.
    cell : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell matrices. Shape (B,). Modified in-place.
    transform : wp.array(dtype=wp.mat33f or wp.mat33d)
        Scratch array for rotation transform. Shape (B,).
        Caller must pre-allocate. Initialising it to identity is the safe
        choice: a zero-volume cell is not aligned, and with an identity
        transform its positions are then left unchanged.
    batch_idx : wp.array(dtype=wp.int32), optional
        System index for each atom. Shape (N,). If None, assumes single system.
    device : str, optional
        Warp device. If None, inferred from positions.

    Returns
    -------
    tuple[wp.array, wp.array]
        (positions, cell) - same arrays, modified in-place for convenience.

    Raises
    ------
    ValueError
        If ``batch_idx`` is None while ``cell`` holds more than one system.
        Without a batch index every atom would be rotated by the transform
        of system 0.
    KeyError
        If the dtype of ``cell`` or ``positions`` has no registered kernel
        overload.

    Example
    -------
    >>> # Before optimization loop
    >>> identity = wp.mat33d(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
    >>> transform = wp.array([identity], dtype=wp.mat33d, device=device)
    >>> positions, cell = align_cell(positions, cell, transform)
    """
    num_systems = cell.shape[0]
    # The single-system path only ever reads transform[0]; refuse a batch
    # without batch_idx before any array is modified.
    if batch_idx is None and num_systems > 1:
        raise ValueError(
            "batch_idx is required when cell contains more than one system"
        )

    if device is None:
        device = positions.device

    num_atoms = positions.shape[0]
    mat_dtype = cell.dtype
    vec_dtype = positions.dtype

    # Align cell and compute transform
    wp.launch(
        _align_cell_kernel_overload[mat_dtype],
        dim=num_systems,
        inputs=[cell, transform],
        device=device,
    )

    # Apply transform to positions: pos_new = transform @ pos
    if batch_idx is None:
        # Single system
        wp.launch(
            _apply_transform_single_kernel_overload[vec_dtype],
            dim=num_atoms,
            inputs=[positions, transform],
            device=device,
        )
    else:
        wp.launch(
            _apply_transform_kernel_overload[vec_dtype],
            dim=num_atoms,
            inputs=[positions, transform, batch_idx],
            device=device,
        )

    return positions, cell
[docs]
def extend_batch_idx(
    batch_idx: wp.array,
    num_atoms: int,
    num_systems: int,
    extended_batch_idx: wp.array,
    device: str = None,
) -> wp.array:
    """
    Build the batch index for an array that also holds the cell DOFs.

    Every system gains 2 extra "atoms" (its 6 cell DOFs stored as 2 vec3s).
    The output starts with a copy of ``batch_idx`` and continues with two
    entries per system, in system order, each carrying that system's index.

    NOTE(review): the cell entries are appended after all atoms. The batched
    pack functions in this module instead place each system's cell entries
    right after that system's atoms; the layouts agree only for a single
    system - confirm which layout the consumer expects.

    Parameters
    ----------
    batch_idx : wp.array(dtype=wp.int32)
        Original batch index for atoms. Shape (N,).
    num_atoms : int
        Number of atoms (N).
    num_systems : int
        Number of systems (B).
    extended_batch_idx : wp.array
        Output extended batch index. Shape (N + 2*B,).
        Caller must pre-allocate.
    device : str, optional
        Warp device. If None, taken from ``batch_idx``.

    Returns
    -------
    wp.array
        ``extended_batch_idx``, filled in-place. Shape (N + 2*B,).

    Example
    -------
    >>> # Original: 100 atoms across 2 systems
    >>> # Extended: 100 + 4 = 104 "atoms" (2 cell DOFs per system)
    >>> ext_batch_idx = wp.zeros(104, dtype=wp.int32, device=device)
    >>> extend_batch_idx(batch_idx, num_atoms=100, num_systems=2, extended_batch_idx=ext_batch_idx)
    """
    launch_device = batch_idx.device if device is None else device

    # One thread per output entry: N atom entries plus 2 per system.
    wp.launch(
        _extend_batch_idx_kernel,
        dim=num_atoms + 2 * num_systems,
        inputs=[batch_idx, extended_batch_idx, num_atoms, num_systems],
        device=launch_device,
    )
    return extended_batch_idx
[docs]
def extend_atom_ptr(
    atom_ptr: wp.array,
    extended_atom_ptr: wp.array,
    device: str = None,
) -> wp.array:
    """
    Shift CSR atom pointers so each system has room for its cell DOFs.

    Each system gains 2 entries (6 cell parameters as 2 vec3s), so the
    pointers become ``extended_atom_ptr[sys] = atom_ptr[sys] + 2*sys``.

    Parameters
    ----------
    atom_ptr : wp.array(dtype=wp.int32)
        Original CSR pointers. Shape (B+1,).
    extended_atom_ptr : wp.array
        Output extended CSR pointers. Shape (B+1,).
        Caller must pre-allocate.
    device : str, optional
        Warp device. If None, taken from ``atom_ptr``.

    Returns
    -------
    wp.array
        ``extended_atom_ptr``, filled in-place. Shape (B+1,).

    Example
    -------
    >>> # Original: atom_ptr = [0, 50, 100] (2 systems, 50 atoms each)
    >>> # Extended: [0, 52, 104] (50+2=52, 100+4=104)
    >>> ext_atom_ptr = wp.zeros(3, dtype=wp.int32, device=device)
    >>> extend_atom_ptr(atom_ptr, ext_atom_ptr)
    """
    launch_device = atom_ptr.device if device is None else device

    # One thread per pointer entry, i.e. num_systems + 1 threads.
    wp.launch(
        _extend_atom_ptr_kernel,
        dim=atom_ptr.shape[0],
        inputs=[atom_ptr, extended_atom_ptr],
        device=launch_device,
    )
    return extended_atom_ptr
[docs]
def pack_positions_with_cell(
    positions: wp.array,
    cell: wp.array,
    extended: wp.array,
    atom_ptr: wp.array | None = None,
    ext_atom_ptr: wp.array | None = None,
    device: str | None = None,
    batch_idx: wp.array | None = None,
) -> wp.array:
    """
    Pack atomic positions and cell into extended position array.

    Single-system mode (atom_ptr=None):
        The extended array has shape (N + 2,) with dtype vec3*, where:
        - First N entries: atomic positions
        - Entry N: [a, b*cos(γ), c1] (first 3 cell parameters)
        - Entry N+1: [b*sin(γ), c2, c3] (remaining 3 cell parameters)

    Batched mode (atom_ptr provided):
        Positions are concatenated across systems, cells have shape (B,).
        The extended array interleaves each system's positions with its cell DOFs.
        Both atom_ptr and ext_atom_ptr must be provided. Optionally pass
        batch_idx to avoid recomputing it internally.

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic positions. Shape (N,) for single system or (total_atoms,) for batched.
    cell : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell matrix (should be upper-triangular from align_cell).
        Shape (1,) for single system or (B,) for batched.
    extended : wp.array
        Output extended array. Caller must pre-allocate.
        Shape (N+2,) for single, (N+2*B,) for batched.
    atom_ptr : wp.array(dtype=wp.int32), optional
        CSR-style atom pointers. Shape (B+1,). If provided, enables batched mode.
    ext_atom_ptr : wp.array(dtype=wp.int32), optional
        Extended atom pointers from extend_atom_ptr(). Shape (B+1,).
        Required if atom_ptr is provided.
    device : str, optional
        Warp device. If None, taken from ``positions``.
    batch_idx : wp.array(dtype=wp.int32), optional
        Sorted system index per atom. Shape (N,). Computed from atom_ptr
        if not provided.

    Returns
    -------
    wp.array
        Extended position array (``extended``, filled in-place).

    Raises
    ------
    ValueError
        If ``atom_ptr`` is provided without ``ext_atom_ptr``.
    KeyError
        If the dtype of ``positions`` has no registered kernel overload.
    """
    # Fail early with a clear message instead of handing a missing pointer
    # array to the batched kernels, which index into it unconditionally.
    if atom_ptr is not None and ext_atom_ptr is None:
        raise ValueError("ext_atom_ptr is required when atom_ptr is provided")

    if device is None:
        device = positions.device

    vec_dtype = positions.dtype

    if atom_ptr is None:
        # Single system mode: one launch covers atoms and both cell slots.
        num_atoms = positions.shape[0]
        wp.launch(
            _pack_positions_kernel_overload[vec_dtype],
            dim=num_atoms + 2,
            inputs=[positions, cell, extended, num_atoms],
            device=device,
        )
    else:
        # Batched mode: one launch over atoms, one over systems.
        N = positions.shape[0]
        M = atom_ptr.shape[0] - 1
        if batch_idx is None:
            batch_idx = wp.empty(N, dtype=wp.int32, device=device)
            atom_ptr_to_batch_idx(atom_ptr, batch_idx)
        wp.launch(
            _pack_atoms_batched_kernel_overload[vec_dtype],
            dim=N,
            inputs=[positions, extended, batch_idx, atom_ptr, ext_atom_ptr],
            device=device,
        )
        wp.launch(
            _pack_cell_dofs_kernel_overload[vec_dtype],
            dim=M,
            inputs=[cell, extended, atom_ptr, ext_atom_ptr],
            device=device,
        )

    return extended
[docs]
def unpack_positions_with_cell(
    extended: wp.array,
    positions: wp.array,
    cell: wp.array,
    num_atoms: int | None = None,
    atom_ptr: wp.array | None = None,
    ext_atom_ptr: wp.array | None = None,
    device: str | None = None,
    batch_idx: wp.array | None = None,
) -> tuple[wp.array, wp.array]:
    """
    Unpack extended position array to atomic positions and cell.

    Single-system mode (atom_ptr=None):
        Unpacks extended array of shape (N + 2,) to positions (N,) and cell (1,).
        Requires num_atoms to be specified.

    Batched mode (atom_ptr provided):
        Unpacks extended array to concatenated positions (total_atoms,) and
        cells (B,). Both atom_ptr and ext_atom_ptr must be provided.
        Optionally pass batch_idx to avoid recomputing it internally.

    Parameters
    ----------
    extended : wp.array(dtype=wp.vec3f or wp.vec3d)
        Extended position array.
    positions : wp.array
        Output atomic positions. Caller must pre-allocate.
    cell : wp.array
        Output cell matrix. Caller must pre-allocate.
    num_atoms : int, optional
        Number of atoms (N). Required for single-system mode.
    atom_ptr : wp.array(dtype=wp.int32), optional
        CSR-style atom pointers. Shape (B+1,). If provided, enables batched mode.
    ext_atom_ptr : wp.array(dtype=wp.int32), optional
        Extended atom pointers from extend_atom_ptr(). Shape (B+1,).
        Required if atom_ptr is provided.
    device : str, optional
        Warp device. Defaults to the device of ``extended``.
    batch_idx : wp.array(dtype=wp.int32), optional
        Sorted system index per atom. Shape (N,). Computed from atom_ptr
        if not provided.

    Returns
    -------
    tuple[wp.array, wp.array]
        (positions, cell) -- the same output arrays that were passed in.

    Raises
    ------
    ValueError
        If ``num_atoms`` is missing in single-system mode, or if
        ``ext_atom_ptr`` is missing in batched mode.
    """
    # Check the mode-specific required arguments before touching any array,
    # so a missing argument fails with a clear message rather than being
    # handed to the kernel launch.
    if atom_ptr is None:
        if num_atoms is None:
            raise ValueError("num_atoms is required for single-system mode")
    elif ext_atom_ptr is None:
        raise ValueError("ext_atom_ptr is required when atom_ptr is provided")

    if device is None:
        device = extended.device
    # The kernel overload tables are keyed by the vector dtype of the input.
    vec_dtype = extended.dtype

    if atom_ptr is None:
        # Single system mode: one thread per extended entry (N atoms + 2 cell rows).
        wp.launch(
            _unpack_positions_kernel_overload[vec_dtype],
            dim=num_atoms + 2,
            inputs=[extended, positions, cell, num_atoms],
            device=device,
        )
    else:
        # Batched mode
        N = positions.shape[0]
        M = atom_ptr.shape[0] - 1
        if batch_idx is None:
            # Derive the per-atom system index from the CSR pointers.
            batch_idx = wp.empty(N, dtype=wp.int32, device=device)
            atom_ptr_to_batch_idx(atom_ptr, batch_idx)
        # First launch: one thread per atom.
        wp.launch(
            _unpack_atoms_batched_kernel_overload[vec_dtype],
            dim=N,
            inputs=[extended, positions, batch_idx, atom_ptr, ext_atom_ptr],
            device=device,
        )
        # Second launch: one thread per system for the cell DOFs.
        wp.launch(
            _unpack_cell_dofs_kernel_overload[vec_dtype],
            dim=M,
            inputs=[extended, cell, atom_ptr, ext_atom_ptr],
            device=device,
        )
    return positions, cell
[docs]
def pack_velocities_with_cell(
    velocities: wp.array,
    cell_velocity: wp.array,
    extended: wp.array,
    atom_ptr: wp.array = None,
    ext_atom_ptr: wp.array = None,
    device: str = None,
    batch_idx: wp.array = None,
) -> wp.array:
    """
    Combine atomic velocities and the cell velocity into one extended array.

    This is a thin wrapper: it forwards every argument unchanged to
    ``pack_positions_with_cell`` because velocities use the same packing
    format as positions.

    Single-system mode (atom_ptr=None):
        Extended array has shape (N + 2,).

    Batched mode (atom_ptr provided):
        Velocities are concatenated across systems, cell velocities have
        shape (B,). Both atom_ptr and ext_atom_ptr must be provided.
        Optionally pass batch_idx to avoid recomputing it internally.

    Parameters
    ----------
    velocities : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic velocities. Shape (N,) for single system or (total_atoms,)
        for batched.
    cell_velocity : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell velocity matrix. Shape (1,) for single system or (B,) for batched.
    extended : wp.array
        Output extended array. Caller must pre-allocate.
    atom_ptr : wp.array(dtype=wp.int32), optional
        CSR-style atom pointers. Shape (B+1,). If provided, enables batched mode.
    ext_atom_ptr : wp.array(dtype=wp.int32), optional
        Extended atom pointers from extend_atom_ptr(). Shape (B+1,).
        Required if atom_ptr is provided.
    device : str, optional
        Warp device.
    batch_idx : wp.array(dtype=wp.int32), optional
        Sorted system index per atom. Computed from atom_ptr if not provided.

    Returns
    -------
    wp.array
        Extended velocity array.
    """
    # Same layout as positions, so the position packer does all the work.
    packed = pack_positions_with_cell(
        velocities,
        cell_velocity,
        extended,
        atom_ptr,
        ext_atom_ptr,
        device,
        batch_idx,
    )
    return packed
[docs]
def unpack_velocities_with_cell(
    extended: wp.array,
    velocities: wp.array,
    cell_velocity: wp.array,
    num_atoms: int = None,
    atom_ptr: wp.array = None,
    ext_atom_ptr: wp.array = None,
    device: str = None,
    batch_idx: wp.array = None,
) -> tuple[wp.array, wp.array]:
    """
    Split an extended velocity array into atomic and cell velocities.

    This is a thin wrapper that delegates to ``unpack_positions_with_cell``
    with the velocity arrays taking the place of the position/cell outputs.

    Single-system mode (atom_ptr=None):
        Unpacks extended array of shape (N + 2,). Requires num_atoms.

    Batched mode (atom_ptr provided):
        Unpacks to concatenated velocities (total_atoms,) and cell
        velocities (B,). Both atom_ptr and ext_atom_ptr must be provided.
        Optionally pass batch_idx to avoid recomputing it internally.

    Parameters
    ----------
    extended : wp.array(dtype=wp.vec3f or wp.vec3d)
        Extended velocity array.
    velocities : wp.array
        Output atomic velocities. Caller must pre-allocate.
    cell_velocity : wp.array
        Output cell velocity matrix. Caller must pre-allocate.
    num_atoms : int, optional
        Number of atoms (N). Required for single-system mode.
    atom_ptr : wp.array(dtype=wp.int32), optional
        CSR-style atom pointers. Shape (B+1,). If provided, enables batched mode.
    ext_atom_ptr : wp.array(dtype=wp.int32), optional
        Extended atom pointers from extend_atom_ptr(). Shape (B+1,).
        Required if atom_ptr is provided.
    device : str, optional
        Warp device.
    batch_idx : wp.array(dtype=wp.int32), optional
        Sorted system index per atom. Computed from atom_ptr if not provided.

    Returns
    -------
    tuple[wp.array, wp.array]
        (velocities, cell_velocity)
    """
    # Keyword forwarding makes the argument mapping explicit; the callee's
    # parameter order is the same as the positional order used previously.
    return unpack_positions_with_cell(
        extended,
        velocities,
        cell_velocity,
        num_atoms=num_atoms,
        atom_ptr=atom_ptr,
        ext_atom_ptr=ext_atom_ptr,
        device=device,
        batch_idx=batch_idx,
    )
[docs]
def pack_forces_with_cell(
    forces: wp.array,
    cell_force: wp.array,
    extended: wp.array,
    atom_ptr: wp.array | None = None,
    ext_atom_ptr: wp.array | None = None,
    device: str | None = None,
    batch_idx: wp.array | None = None,
) -> wp.array:
    """
    Pack atomic forces and cell force into extended force array.

    Single-system mode (atom_ptr=None):
        Extended array has shape (N + 2,).

    Batched mode (atom_ptr provided):
        Forces are concatenated across systems, cell forces have shape (B,).
        Both atom_ptr and ext_atom_ptr must be provided. Optionally pass
        batch_idx to avoid recomputing it internally.

    Parameters
    ----------
    forces : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic forces. Shape (N,) for single system or (total_atoms,) for batched.
    cell_force : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell force matrix (from stress_to_cell_force).
        Shape (1,) for single system or (B,) for batched.
    extended : wp.array
        Output extended array. Caller must pre-allocate.
        Shape (N+2,) for single, (N+2*B,) for batched.
    atom_ptr : wp.array(dtype=wp.int32), optional
        CSR-style atom pointers. Shape (B+1,). If provided, enables batched mode.
    ext_atom_ptr : wp.array(dtype=wp.int32), optional
        Extended atom pointers from extend_atom_ptr(). Shape (B+1,).
        Required if atom_ptr is provided.
    device : str, optional
        Warp device. Defaults to the device of ``forces``.
    batch_idx : wp.array(dtype=wp.int32), optional
        Sorted system index per atom. Computed from atom_ptr if not provided.

    Returns
    -------
    wp.array
        Extended force array (the ``extended`` argument).

    Raises
    ------
    ValueError
        If ``atom_ptr`` is provided without ``ext_atom_ptr``.
    """
    # Batched mode needs both pointer arrays; reject a missing one here with
    # a clear message instead of passing it on to the kernel launch.
    if atom_ptr is not None and ext_atom_ptr is None:
        raise ValueError("ext_atom_ptr is required when atom_ptr is provided")

    if device is None:
        device = forces.device
    # The kernel overload tables are keyed by the vector dtype of the input.
    vec_dtype = forces.dtype

    if atom_ptr is None:
        # Single system mode: one thread per extended entry (N atoms + 2 cell rows).
        num_atoms = forces.shape[0]
        wp.launch(
            _pack_forces_kernel_overload[vec_dtype],
            dim=num_atoms + 2,
            inputs=[forces, cell_force, extended, num_atoms],
            device=device,
        )
    else:
        # Batched mode
        N = forces.shape[0]
        M = atom_ptr.shape[0] - 1
        if batch_idx is None:
            # Derive the per-atom system index from the CSR pointers.
            batch_idx = wp.empty(N, dtype=wp.int32, device=device)
            atom_ptr_to_batch_idx(atom_ptr, batch_idx)
        # Atomic entries reuse the generic per-atom packing kernel.
        wp.launch(
            _pack_atoms_batched_kernel_overload[vec_dtype],
            dim=N,
            inputs=[forces, extended, batch_idx, atom_ptr, ext_atom_ptr],
            device=device,
        )
        # Cell entries use the force-specific kernel, one thread per system.
        wp.launch(
            _pack_cell_force_dofs_kernel_overload[vec_dtype],
            dim=M,
            inputs=[cell_force, extended, atom_ptr, ext_atom_ptr],
            device=device,
        )
    return extended
[docs]
def pack_masses_with_cell(
    masses: wp.array,
    cell_mass_arr: wp.array,
    extended: wp.array,
    atom_ptr: wp.array | None = None,
    ext_atom_ptr: wp.array | None = None,
    device: str | None = None,
) -> wp.array:
    """
    Pack atomic masses and cell mass into extended mass array.

    Single-system mode (atom_ptr=None):
        Extended array has shape (N + 2,).

    Batched mode (atom_ptr provided):
        Masses are concatenated across systems. Cell mass is applied to all
        systems. Both atom_ptr and ext_atom_ptr must be provided.

    Parameters
    ----------
    masses : wp.array(dtype=wp.float32 or wp.float64)
        Atomic masses. Shape (N,) for single system or (total_atoms,) for batched.
    cell_mass_arr : wp.array
        Cell mass as a warp array. Shape (1,) for single system or (B,) for batched.
        Caller must pre-allocate.
    extended : wp.array
        Output extended array. Caller must pre-allocate.
        Shape (N+2,) for single, (N+2*B,) for batched.
    atom_ptr : wp.array(dtype=wp.int32), optional
        CSR-style atom pointers. Shape (B+1,). If provided, enables batched mode.
    ext_atom_ptr : wp.array(dtype=wp.int32), optional
        Extended atom pointers from extend_atom_ptr(). Shape (B+1,).
        Required if atom_ptr is provided.
    device : str, optional
        Warp device. Defaults to the device of ``masses``.

    Returns
    -------
    wp.array
        Extended mass array (the ``extended`` argument).

    Raises
    ------
    ValueError
        If ``atom_ptr`` is provided without ``ext_atom_ptr``.
    """
    # Batched mode needs both pointer arrays; reject a missing one here with
    # a clear message instead of passing it on to the kernel launch.
    if atom_ptr is not None and ext_atom_ptr is None:
        raise ValueError("ext_atom_ptr is required when atom_ptr is provided")

    if device is None:
        device = masses.device
    # Mass kernels are keyed by the scalar dtype of the mass array.
    scalar_dtype = masses.dtype

    if atom_ptr is None:
        # Single system mode: one thread per extended entry (N atoms + 2 cell slots).
        num_atoms = masses.shape[0]
        wp.launch(
            _pack_masses_kernel_overload[scalar_dtype],
            dim=num_atoms + 2,
            inputs=[masses, cell_mass_arr, extended, num_atoms],
            device=device,
        )
    else:
        # Batched mode with atom_ptr
        num_systems = atom_ptr.shape[0] - 1
        # Launch with num_systems threads (each handles one system)
        wp.launch(
            _pack_masses_batched_kernel_overload[scalar_dtype],
            dim=num_systems,
            inputs=[masses, cell_mass_arr, extended, atom_ptr, ext_atom_ptr],
            device=device,
        )
    return extended
[docs]
def stress_to_cell_force(
    stress: wp.array,
    cell: wp.array,
    volume: wp.array,
    cell_force: wp.array,
    keep_aligned: bool = True,
    device: str = None,
) -> wp.array:
    r"""
    Turn per-system stress tensors into cell forces for optimization.

    Computes: F_cell = -V * σ * (H^{-1})^T

    This is the "force" on the cell that, when minimized, leads to
    zero stress (pressure equilibration). The computation itself happens in
    a Warp kernel selected by the dtype of ``stress`` and launched with one
    thread per entry of ``stress``.

    Parameters
    ----------
    stress : wp.array(dtype=wp.mat33f or wp.mat33d)
        Stress tensor. Shape (B,).
        Convention: positive values indicate compression.
    cell : wp.array(dtype=wp.mat33f or wp.mat33d)
        Cell matrices. Shape (B,).
    volume : wp.array
        Cell volumes. Shape (B,). Caller must pre-compute via
        ``compute_cell_volume``.
    cell_force : wp.array
        Output cell force matrices. Shape (B,). Caller must pre-allocate.
    keep_aligned : bool, default=True
        If True, zero out upper-triangular off-diagonal elements [0,1], [0,2],
        [1,2] of the cell force. This is **essential** to prevent the cell
        from rotating away from the upper-triangular form established by
        `align_cell()`. Only set to False if you know what you're doing.
    device : str, optional
        Warp device. Defaults to the device of ``stress``.

    Returns
    -------
    wp.array
        Cell force matrices. Shape (B,). This is the ``cell_force`` argument.

    Notes
    -----
    The `keep_aligned=True` behavior zeros out forces on the upper off-diagonal
    elements of the cell matrix:

    .. code-block:: text

        Cell force structure (keep_aligned=True):
        [F00, 0,   0  ]
        [F10, F11, 0  ]
        [F20, F21, F22]

    This prevents the optimizer from introducing rotations that would break
    the upper-triangular cell representation from `align_cell()`.
    """
    # Fall back to wherever the stress array lives when no device is given.
    launch_device = stress.device if device is None else device
    # One thread per system; the kernel table is keyed by the matrix dtype.
    system_count = stress.shape[0]
    kernel = _stress_to_cell_force_kernel_overload[stress.dtype]
    wp.launch(
        kernel,
        dim=system_count,
        inputs=[stress, cell, volume, cell_force, keep_aligned],
        device=launch_device,
    )
    return cell_force