Note
Go to the end to download the full example code.
Variable-Cell FIRE2 Optimization with LJ Potential#
This example demonstrates joint optimization of atomic positions and cell parameters using the FIRE2 optimizer (Guenole et al., 2020) with manual pack/unpack of extended arrays.
Compared to FIRE (07_fire_variable_cell.py), FIRE2:
- Assumes unit mass (no masses / pack_masses_with_cell needed)
- Requires batch_idx (all zeros for a single system on extended arrays)
- Has simpler state: alpha, dt, nsteps_inc
- Hyperparameters are Python scalars
The workflow is otherwise the same:
align_cell() - Transform cell to upper-triangular form
LJ energy/forces/virial - Compute interatomic interactions
Virial -> Stress -> Cell Force - Convert virial to cell driving force
pack_*_with_cell() - Combine atomic + cell DOFs into extended arrays
fire2_step() - FIRE2 optimization on extended arrays
unpack_positions_with_cell() - Extract optimized geometry
We optimize an FCC argon crystal under external pressure.
from __future__ import annotations
import matplotlib.pyplot as plt
import numpy as np
import warp as wp
from _dynamics_utils import (
EPSILON_AR,
SIGMA_AR,
MDSystem,
create_fcc_lattice,
pressure_ev_per_a3_to_gpa,
pressure_gpa_to_ev_per_a3,
virial_to_stress,
)
from nvalchemiops.dynamics.optimizers import fire2_step
from nvalchemiops.dynamics.utils import (
align_cell,
compute_cell_volume,
pack_forces_with_cell,
pack_positions_with_cell,
stress_to_cell_force,
unpack_positions_with_cell,
wrap_positions_to_cell,
)
# Interaction cutoff for the argon LJ potential: 2.5 sigma (~8.5 Å).
CUTOFF = 2.5 * SIGMA_AR

# ==============================================================================
# Main Example
# ==============================================================================

wp.init()

# Run on the first CUDA device when Warp reports one, otherwise on the CPU.
if wp.is_cuda_available():
    device = "cuda:0"
else:
    device = "cpu"
print(f"Using device: {device}")
Using device: cuda:0
Create Initial System#
# Supercell definition: n_cells repeats per axis (3x3x3 = 108 atoms) and a
# starting lattice constant slightly expanded from equilibrium (~5.26 Å).
n_cells = 3
a_initial = 5.5  # Å
positions_np, cell_np = create_fcc_lattice(n_cells, a_initial)
num_atoms = len(positions_np)

# Target external pressure (positive = compression), given in GPa and
# converted once to eV/Å^3 for the stress calculation.
target_pressure_gpa = 0.01
target_pressure = pressure_gpa_to_ev_per_a3(target_pressure_gpa)

# Summary of the starting configuration, one line per quantity.
for summary_line in (
    f"System: {num_atoms} atoms in {n_cells}\u00b3 FCC lattice",
    f"Initial lattice constant: {a_initial:.3f} \u00c5",
    f"Initial density: {num_atoms / np.linalg.det(cell_np):.4f} atoms/\u00c5\u00b3",
    f"Target external pressure: {target_pressure_gpa:.3f} GPa",
    f"LJ parameters: \u03b5 = {EPSILON_AR:.4f} eV, \u03c3 = {SIGMA_AR:.2f} \u00c5",
):
    print(summary_line)
System: 108 atoms in 3³ FCC lattice
Initial lattice constant: 5.500 Å
Initial density: 0.0240 atoms/ų
Target external pressure: 0.010 GPa
LJ parameters: ε = 0.0104 eV, σ = 3.40 Å
Initialize System#
# Warp copies of the geometry: one vec3d per atom and a single mat33d cell
# (the cell is reshaped to (1, 3, 3) before upload).
positions = wp.array(positions_np, dtype=wp.vec3d, device=device)
cell_batched_np = cell_np.reshape(1, 3, 3)
cell = wp.array(cell_batched_np, dtype=wp.mat33d, device=device)

# MD system used for force computation. It is built from the NumPy geometry
# and the argon LJ parameters.
md_system_options = dict(
    positions=positions_np,
    cell=cell_np,
    epsilon=EPSILON_AR,
    sigma=SIGMA_AR,
    cutoff=CUTOFF,
    skin=0.5,
    switch_width=1.0,  # Smooth cutoff for optimization
    device=device,
)
md_system = MDSystem(**md_system_options)
Initialized MD system with 108 atoms
Cell: 16.50 x 16.50 x 16.50 Å
Cutoff: 8.50 Å (+ 0.50 Å skin)
LJ: ε = 0.0104 eV, σ = 3.40 Å
Device: cuda:0, dtype: <class 'numpy.float64'>
Units: x [Å], t [fs], E [eV], m [eV·fs²/Ų] (from amu), v [Å/fs]
Step 1: Align Cell#
--- Step 1: Align cell to upper-triangular form ---
Aligned cell:
[[16.5 0. 0. ]
[ 0. 16.5 0. ]
[ 0. 0. 16.5]]
Step 2: Pack Extended Arrays#
print("\n--- Step 2: Pack into extended arrays ---")

# The extended arrays hold the atoms followed by 2 extra vec3d entries for
# the cell DOFs. FIRE2 does NOT use masses, so no pack_masses_with_cell
# is needed: only positions, velocities and forces are extended.
N_ext = num_atoms + 2
vec3_on_device = {"dtype": wp.vec3d, "device": device}
ext_positions = wp.empty(N_ext, **vec3_on_device)
ext_velocities = wp.zeros(N_ext, **vec3_on_device)  # start at rest
ext_forces = wp.empty(N_ext, **vec3_on_device)  # filled every iteration

# Write the current atomic positions and cell into the extended array.
pack_positions_with_cell(positions, cell, extended=ext_positions, device=device)

print(
    f"Extended array size: {ext_positions.shape[0]} ({num_atoms} atoms + 2 cell DOFs)"
)
--- Step 2: Pack into extended arrays ---
Extended array size: 110 (108 atoms + 2 cell DOFs)
Step 3: FIRE2 Parameters#
# Keyword arguments shared by every float64 buffer below.
f64_on_device = {"dtype": wp.float64, "device": device}

# Per-system optimizer state, each of shape (1,) for the single system:
# mixing parameter, time step, and an integer step counter.
alpha = wp.array([0.09], **f64_on_device)
dt = wp.array([0.005], **f64_on_device)
nsteps_inc = wp.zeros(1, dtype=wp.int32, device=device)

# Four independent zero-initialized scratch buffers of shape (1,) that are
# handed to fire2_step on every call.
vf, v_sumsq, f_sumsq, max_norm_buf = (
    wp.zeros(1, **f64_on_device) for _ in range(4)
)

# batch_idx for the extended arrays (all zeros = single system)
ext_batch_idx = wp.zeros(N_ext, dtype=wp.int32, device=device)

# Reusable output buffers for unpacking, cell force and volume.
pos_scratch = wp.empty(num_atoms, dtype=wp.vec3d, device=device)
cell_scratch = wp.empty(1, dtype=wp.mat33d, device=device)
cell_force_scratch = wp.empty(1, dtype=wp.mat33d, device=device)
volume_scratch = wp.empty(1, **f64_on_device)
Step 4: Optimization Loop#
# Optimization budget, convergence threshold and reporting cadence.
max_steps = 1000
force_tol = 1e-4  # Convergence: max force/stress component
log_interval = 100  # print a table row every log_interval steps
check_interval = 50  # copy data to the host and test convergence this often

# History for plotting (one entry per convergence check)
energy_hist = []
max_force_hist = []
volume_hist = []
pressure_hist = []
lattice_const_hist = []

print("\n--- Step 4: Variable-cell FIRE2 optimization ---")
print(f"Force tolerance: {force_tol:.1e}")
print("FIRE2 defaults: delaystep=60, dtgrow=1.05, alpha0=0.09, maxstep=0.1")
print("=" * 90)
print(
    f"{'Step':>6} {'Energy':>12} {'max|F|':>10} {'Volume':>10} "
    f"{'ΔP (GPa)':>10} {'a (Å)':>10}"
)
print("=" * 90)

converged = False
for step in range(max_steps):
    # Unpack current state (atomic positions + cell) from the extended array
    pos_current, cell_current = unpack_positions_with_cell(
        ext_positions,
        positions=pos_scratch,
        cell=cell_scratch,
        num_atoms=num_atoms,
        device=device,
    )

    # Update MD system with current geometry
    wp.copy(md_system.wp_positions, pos_current)
    md_system.update_cell(cell_current)

    # Wrap positions into cell (important for PBC consistency)
    wrap_positions_to_cell(
        positions=md_system.wp_positions,
        cells=md_system.wp_cell,
        cells_inv=md_system.wp_cell_inv,
        device=device,
    )

    # Compute LJ forces and virial
    energies, forces, virial = md_system.compute_forces_virial()

    # Convert virial to stress with external pressure contribution
    stress = virial_to_stress(virial, md_system.wp_cell, target_pressure, device)

    # Convert stress to cell force (for optimization)
    compute_cell_volume(md_system.wp_cell, volumes=volume_scratch, device=device)
    stress_to_cell_force(
        stress,
        md_system.wp_cell,
        volume=volume_scratch,
        cell_force=cell_force_scratch,
        keep_aligned=True,
        device=device,
    )

    # Pack forces into extended array
    pack_forces_with_cell(
        forces, cell_force_scratch, extended=ext_forces, device=device
    )

    # Re-pack positions (after wrapping) so the optimizer advances the
    # wrapped coordinates rather than the pre-wrap ones.
    pack_positions_with_cell(
        md_system.wp_positions,
        md_system.wp_cell,
        extended=ext_positions,
        device=device,
    )

    # FIRE2 step on extended arrays
    fire2_step(
        positions=ext_positions,
        velocities=ext_velocities,
        forces=ext_forces,
        batch_idx=ext_batch_idx,
        alpha=alpha,
        dt=dt,
        nsteps_inc=nsteps_inc,
        vf=vf,
        v_sumsq=v_sumsq,
        f_sumsq=f_sumsq,
        max_norm=max_norm_buf,
    )

    # Check convergence and log only at intervals; the final iteration is
    # always checked so the last state ends up in the history.
    if step % check_interval == 0 or step == max_steps - 1:
        wp.synchronize()
        ext_forces_np = ext_forces.numpy()
        # Largest absolute component over atomic AND cell entries.
        max_force = np.max(np.abs(ext_forces_np))
        total_energy = float(energies.numpy().sum())
        compute_cell_volume(md_system.wp_cell, volumes=volume_scratch, device=device)
        volume = float(volume_scratch.numpy()[0])

        # Compute deviation from target pressure (trace of stress / 3)
        stress_np = stress.numpy()[0]
        stress_trace = (stress_np[0, 0] + stress_np[1, 1] + stress_np[2, 2]) / 3
        pressure_deviation_gpa = -pressure_ev_per_a3_to_gpa(stress_trace)

        # Effective lattice constant (cube root of volume per atom * 4 for FCC)
        lattice_const = (volume / num_atoms * 4) ** (1 / 3)

        energy_hist.append(total_energy)
        max_force_hist.append(max_force)
        volume_hist.append(volume)
        pressure_hist.append(pressure_deviation_gpa)
        lattice_const_hist.append(lattice_const)

        if step % log_interval == 0 or step == max_steps - 1:
            print(
                f"{step:>6d} {total_energy:>12.6f} {max_force:>10.2e} {volume:>10.2f} "
                f"{pressure_deviation_gpa:>10.4f} {lattice_const:>10.4f}"
            )

        if max_force < force_tol:
            print(f"\nConverged at step {step} (max|F| = {max_force:.2e})")
            converged = True
            break

# Make an exhausted step budget visible instead of silently falling through
# to the final report with an unconverged geometry.
if not converged:
    print(
        f"\nNot converged after {max_steps} steps "
        f"(force tolerance {force_tol:.1e})"
    )
--- Step 4: Variable-cell FIRE2 optimization ---
Force tolerance: 1.0e-04
FIRE2 defaults: delaystep=60, dtgrow=1.05, alpha0=0.09, maxstep=0.1
==========================================================================================
Step Energy max|F| Volume ΔP (GPa) a (Å)
==========================================================================================
0 -6.418945 3.32e-01 4492.12 -0.1756 5.5000
100 -6.459539 3.03e-01 4388.37 -0.1653 5.4573
200 -6.675806 1.74e-02 3597.40 -0.0111 5.1075
300 -6.728477 7.06e-03 3455.31 -0.0032 5.0394
400 -6.730184 4.71e-03 3404.94 -0.0019 5.0147
500 -6.729981 6.42e-03 3376.39 -0.0034 5.0007
600 -6.730163 6.33e-03 3353.31 -0.0035 4.9893
700 -6.730086 6.41e-03 3327.17 -0.0035 4.9763
800 -6.730273 6.22e-03 3309.31 -0.0035 4.9673
900 -6.730219 5.28e-03 3290.54 -0.0027 4.9579
999 -6.730160 5.56e-03 3275.04 -0.0030 4.9501
Final Results#
# Pull the optimized geometry back out of the extended array, synchronizing
# before the unpack and again before the host read.
wp.synchronize()
final_pos, final_cell = unpack_positions_with_cell(
    ext_positions,
    positions=pos_scratch,
    cell=cell_scratch,
    num_atoms=num_atoms,
    device=device,
)
wp.synchronize()

# Host-side summary quantities derived from the single 3x3 cell matrix.
final_cell_np = final_cell.numpy()[0]
final_volume = np.linalg.det(final_cell_np)
final_density = num_atoms / final_volume
# Effective lattice constant: cube root of (volume per atom * 4) for FCC.
final_a = (final_volume / num_atoms * 4) ** (1 / 3)

banner = "=" * 60
print("\n" + banner)
print("FINAL RESULTS")
print(banner)
for report_line in (
    f"Final cell:\n{final_cell_np}",
    f"\nFinal volume: {final_volume:.2f} \u00c5\u00b3",
    f"Final density: {final_density:.6f} atoms/\u00c5\u00b3",
    f"Effective lattice constant: {final_a:.4f} \u00c5",
    f"Target pressure: {target_pressure_gpa:.4f} GPa",
):
    print(report_line)
============================================================
FINAL RESULTS
============================================================
Final cell:
[[14.82897735 0. 0. ]
[ 0.03530234 14.98425518 0. ]
[ 0.03182169 -0.01607221 14.73677679]]
Final volume: 3274.53 ų
Final density: 0.032982 atoms/ų
Effective lattice constant: 4.9499 Å
Target pressure: 0.0100 GPa
Plot Convergence#
# 2x2 panel figure: energy, force convergence, volume and lattice constant.
fig, axes = plt.subplots(2, 2, figsize=(10, 8), constrained_layout=True)
(ax_energy, ax_force), (ax_volume, ax_lattice) = axes

# Step index of every history sample. Samples are taken every check_interval
# steps plus once on the final iteration (max_steps - 1), which is generally
# not a multiple of check_interval; clamp so that last sample is not drawn
# one interval too far to the right.
steps = np.minimum(np.arange(len(energy_hist)) * check_interval, max_steps - 1)

# Energy
ax_energy.plot(steps, energy_hist, "b-", lw=1.5)
ax_energy.set_xlabel("Step")
ax_energy.set_ylabel("Energy (eV)")
ax_energy.set_title("Total Energy")

# Force convergence (log scale, with the tolerance as a reference line)
ax_force.semilogy(steps, max_force_hist, "r-", lw=1.5)
ax_force.axhline(force_tol, color="k", ls="--", lw=1, label="tolerance")
ax_force.set_xlabel("Step")
ax_force.set_ylabel("max|F|")
ax_force.set_title("Force Convergence")
ax_force.legend()

# Volume
ax_volume.plot(steps, volume_hist, "g-", lw=1.5)
ax_volume.set_xlabel("Step")
ax_volume.set_ylabel("Volume (\u00c5\u00b3)")
ax_volume.set_title("Cell Volume")

# Lattice constant (with the approximate equilibrium value for reference)
ax_lattice.plot(steps, lattice_const_hist, "m-", lw=1.5)
ax_lattice.axhline(5.26, color="k", ls="--", lw=1, label="~equilibrium (5.26 \u00c5)")
ax_lattice.set_xlabel("Step")
ax_lattice.set_ylabel("Lattice constant (\u00c5)")
ax_lattice.set_title("Effective Lattice Constant")
ax_lattice.legend()

fig.suptitle(
    f"Variable-Cell FIRE2: LJ Argon at P = {target_pressure_gpa:.3f} GPa", fontsize=14
)
plt.show()

Total running time of the script: (0 minutes 1.201 seconds)