FIRE2 Geometry Optimization (LJ Cluster)

This example demonstrates geometry optimization with the FIRE2 optimizer (Guénolé et al., 2020) using:

  • The package LJ implementation (neighbor-list accelerated)

  • The package FIRE2 kernels (nvalchemiops.dynamics.optimizers.fire2_step())

  • The shared example utilities in examples.dynamics._dynamics_utils

Compared to FIRE (06_fire_optimization.py), FIRE2:

  • Assumes unit mass (no masses parameter)

  • Requires batch_idx even in single-system mode

  • Uses fewer per-system state arrays (alpha, dt, nsteps_inc)

  • Takes its hyperparameters as Python scalars (delaystep, dtgrow, etc.)

We optimize the same small Lennard-Jones cluster and plot convergence.
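
For orientation, the following is a minimal NumPy sketch of a FIRE-style update with unit mass, applied to a quadratic bowl. It is not the nvalchemiops kernel: the function name fire_like_step, the state dictionary, and the parameters dt_grow, dt_shrink, alpha_shrink, dt_max, and delay_steps are all assumptions made for illustration, and refinements such as a cap on the step length (maxstep) are omitted.

```python
import numpy as np


def fire_like_step(x, v, f, state, *, alpha0=0.09, dt_grow=1.05,
                   dt_shrink=0.5, alpha_shrink=0.99, dt_max=0.08,
                   delay_steps=5):
    """One simplified FIRE-style update with unit mass (illustrative only)."""
    power = float(np.sum(f * v))  # P = F . v
    if power > 0.0:
        # Moving downhill: after a delay, grow dt and reduce the mixing.
        state["n_pos"] += 1
        if state["n_pos"] > delay_steps:
            state["dt"] = min(state["dt"] * dt_grow, dt_max)
            state["alpha"] *= alpha_shrink
    else:
        # Moving uphill (or at rest): stop, shrink dt, reset the mixing.
        state["n_pos"] = 0
        state["dt"] *= dt_shrink
        state["alpha"] = alpha0
        v = np.zeros_like(v)

    dt, alpha = state["dt"], state["alpha"]
    v = v + dt * f  # unit mass: acceleration equals force
    f_norm = np.linalg.norm(f)
    if f_norm > 0.0:
        # Steer the velocity toward the force direction.
        v = (1.0 - alpha) * v + alpha * np.linalg.norm(v) * f / f_norm
    x = x + dt * v
    return x, v


# Hypothetical test problem: U(x) = 0.5 * |x|^2, so the force is -x.
x = np.array([1.0, -2.0, 0.5])
v = np.zeros_like(x)
state = {"dt": 0.005, "alpha": 0.09, "n_pos": 0}
for _ in range(2000):
    x, v = fire_like_step(x, v, -x, state)

final_force_norm = float(np.linalg.norm(x))
print(f"final |F| = {final_force_norm:.3e}, dt = {state['dt']:.4f}")
```

The three per-system quantities in the sketch (dt, alpha, and the counter of consecutive downhill steps) play the same role as the alpha, dt, and nsteps_inc arrays passed to fire2_step() below.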

from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
import torch
import warp as wp
from _dynamics_utils import MDSystem, create_random_cluster

from nvalchemiops.dynamics.optimizers import fire2_step

wp.init()

device = "cuda:0" if wp.is_cuda_available() else "cpu"
print(f"Using device: {device}")
Using device: cuda:0

Create a Lennard-Jones Cluster

num_atoms = 32
epsilon = 0.0104  # eV (argon-like)
sigma = 3.40  # Å
cutoff = 2.5 * sigma
skin = 0.5
box_L = 80.0  # Å (large to avoid self-interaction across PBC)

cell = np.eye(3, dtype=np.float64) * box_L
initial_positions = create_random_cluster(
    num_atoms=num_atoms,
    radius=12.0,
    min_dist=0.9 * sigma,
    center=np.array([0.5 * box_L, 0.5 * box_L, 0.5 * box_L]),
    seed=42,
)

system = MDSystem(
    positions=initial_positions,
    cell=cell,
    masses=np.full(num_atoms, 39.948, dtype=np.float64),  # amu (argon)
    epsilon=epsilon,
    sigma=sigma,
    cutoff=cutoff,
    skin=skin,
    switch_width=0.0,
    device=device,
    dtype=np.float64,
)
Initialized MD system with 32 atoms
  Cell: 80.00 x 80.00 x 80.00 Å
  Cutoff: 8.50 Å (+ 0.50 Å skin)
  LJ: ε = 0.0104 eV, σ = 3.40 Å
  Device: cuda:0, dtype: <class 'numpy.float64'>
  Units: x [Å], t [fs], E [eV], m [eV·fs²/Ų] (from amu), v [Å/fs]
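
To put the parameters above in context, here is a small sketch of the standard 12-6 Lennard-Jones pair energy and force. The helper names lj_pair_energy and lj_pair_force are hypothetical; this is not the package's neighbor-list implementation, and it applies no cutoff shift or switching.

```python
import numpy as np

epsilon = 0.0104  # eV (argon-like)
sigma = 3.40  # Å


def lj_pair_energy(r):
    """12-6 pair energy U(r) = 4*eps*((sigma/r)**12 - (sigma/r)**6)."""
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 * sr6 - sr6)


def lj_pair_force(r):
    """Radial force -dU/dr (positive means repulsive)."""
    sr6 = (sigma / r) ** 6
    return 24.0 * epsilon * (2.0 * sr6 * sr6 - sr6) / r


r_min = 2.0 ** (1.0 / 6.0) * sigma  # location of the pair minimum
r_cut = 2.5 * sigma  # cutoff used in this example

print(f"r_min = {r_min:.3f} Å, U(r_min) = {lj_pair_energy(r_min):.4f} eV")
print(f"U(r_cut) / epsilon = {lj_pair_energy(r_cut) / epsilon:.4f}")
```

The minimum distance of 0.9 * sigma used when generating the cluster lies on the repulsive side of the pair minimum, so the random starting geometry has sizeable forces for the optimizer to relax.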

FIRE2 Optimization Loop

FIRE2 uses a simpler state than FIRE: each system keeps only alpha, dt, and nsteps_inc, plus four scratch buffers (vf, v_sumsq, f_sumsq, max_norm). Hyperparameters are passed as Python scalars. batch_idx is always required; for a single system it is all zeros.
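
The role of batch_idx can be illustrated with a short NumPy sketch: it maps each atom to its system so that per-atom quantities can be reduced per system. The reading of the scratch buffers as per-system sums (vf as the sum of v·f, v_sumsq and f_sumsq as sums of squares) is an assumption inferred from their names, not taken from the package documentation, and the two-system layout is hypothetical.

```python
import numpy as np

# Hypothetical batch: atoms 0-2 belong to system 0, atoms 3-4 to system 1.
batch_idx = np.array([0, 0, 0, 1, 1], dtype=np.int32)
num_systems = int(batch_idx.max()) + 1

rng = np.random.default_rng(0)
velocities = rng.normal(size=(5, 3))
forces = rng.normal(size=(5, 3))

# Per-atom contributions, followed by a segmented sum keyed by batch_idx.
vf = np.bincount(
    batch_idx, weights=np.sum(velocities * forces, axis=1), minlength=num_systems
)
v_sumsq = np.bincount(
    batch_idx, weights=np.sum(velocities**2, axis=1), minlength=num_systems
)
f_sumsq = np.bincount(
    batch_idx, weights=np.sum(forces**2, axis=1), minlength=num_systems
)

# Single-system mode is the special case where batch_idx is zero everywhere,
# so every reduction has shape (1,).
single_idx = np.zeros(5, dtype=np.int32)
vf_single = np.bincount(single_idx, weights=np.sum(velocities * forces, axis=1))

print(vf.shape, vf_single.shape)
```

This is why the state arrays and scratch buffers in the code below have shape (1,): there is one entry per system, and here there is a single system.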

max_steps = 3000
force_tolerance = 1e-3  # eV/Å (max force)

wp_dtype = system.wp_dtype

# Per-system state arrays (shape (1,) for single system)
alpha = wp.array([0.09], dtype=wp_dtype, device=device)
dt = wp.array([0.005], dtype=wp_dtype, device=device)
nsteps_inc = wp.zeros(1, dtype=wp.int32, device=device)

# Scratch buffers (shape (1,) for single system)
vf = wp.zeros(1, dtype=wp_dtype, device=device)
v_sumsq = wp.zeros(1, dtype=wp_dtype, device=device)
f_sumsq = wp.zeros(1, dtype=wp_dtype, device=device)
max_norm = wp.zeros(1, dtype=wp_dtype, device=device)

# batch_idx: all zeros for single system
batch_idx = wp.zeros(num_atoms, dtype=wp.int32, device=device)

# Velocities (FIRE2 uses unit mass, so velocities are just momenta)
velocities = wp.zeros(num_atoms, dtype=system.wp_vec_dtype, device=device)

# History for plotting
energy_hist: list[float] = []
maxf_hist: list[float] = []
dt_hist: list[float] = []
alpha_hist: list[float] = []

print("\n" + "=" * 95)
print("FIRE2 GEOMETRY OPTIMIZATION (LJ cluster)")
print("=" * 95)
print(f"  atoms: {num_atoms}, cutoff={cutoff:.2f} Å, box={box_L:.1f} Å")
print(f"  max_steps={max_steps}, force_tol={force_tolerance:.2e} eV/Å")
print("  FIRE2 defaults: delaystep=60, dtgrow=1.05, alpha0=0.09, maxstep=0.1")

log_interval = 100
check_interval = 50

for step in range(max_steps):
    # Compute forces at current positions
    energies = system.compute_forces()

    # FIRE2 step: updates positions, velocities, alpha, dt, nsteps_inc in-place
    fire2_step(
        positions=system.wp_positions,
        velocities=velocities,
        forces=system.wp_forces,
        batch_idx=batch_idx,
        alpha=alpha,
        dt=dt,
        nsteps_inc=nsteps_inc,
        vf=vf,
        v_sumsq=v_sumsq,
        f_sumsq=f_sumsq,
        max_norm=max_norm,
    )

    # Logging / stopping criteria (host read only at intervals)
    if step % check_interval == 0 or step == max_steps - 1:
        pe = float(energies.numpy().sum())
        fmax = float(
            torch.linalg.norm(wp.to_torch(system.wp_forces), dim=1).max().item()
        )

        energy_hist.append(pe)
        maxf_hist.append(fmax)
        dt_hist.append(float(dt.numpy()[0]))
        alpha_hist.append(float(alpha.numpy()[0]))

        if step % log_interval == 0 or step == max_steps - 1:
            print(
                f"step={step:5d}  PE={pe:12.6f} eV  max|F|={fmax:10.3e} eV/Å  "
                f"dt={dt_hist[-1]:8.5f}  alpha={alpha_hist[-1]:7.4f}  "
                f"n+={int(nsteps_inc.numpy()[0]):3d}"
            )

        if fmax < force_tolerance:
            print(f"\nConverged at step {step} (max|F|={fmax:.3e} eV/Å).")
            break
===============================================================================================
FIRE2 GEOMETRY OPTIMIZATION (LJ cluster)
===============================================================================================
  atoms: 32, cutoff=8.50 Å, box=80.0 Å
  max_steps=3000, force_tol=1.00e-03 eV/Å
  FIRE2 defaults: delaystep=60, dtgrow=1.05, alpha0=0.09, maxstep=0.1
step=    0  PE=   -0.022508 eV  max|F|= 3.541e-01 eV/Å  dt= 0.00500  alpha= 0.0900  n+=  1
step=  100  PE=   -0.169683 eV  max|F|= 1.031e-01 eV/Å  dt= 0.03696  alpha= 0.0484  n+=101
step=  200  PE=   -0.291762 eV  max|F|= 2.022e-02 eV/Å  dt= 0.08000  alpha= 0.0107  n+=201
step=  300  PE=   -0.328476 eV  max|F|= 9.843e-03 eV/Å  dt= 0.06000  alpha= 0.0900  n+= 52
step=  400  PE=   -0.368472 eV  max|F|= 9.107e-03 eV/Å  dt= 0.08000  alpha= 0.0224  n+=152
step=  500  PE=   -0.400717 eV  max|F|= 2.077e-02 eV/Å  dt= 0.08000  alpha= 0.0049  n+=252
step=  600  PE=   -0.443818 eV  max|F|= 1.520e-02 eV/Å  dt= 0.08000  alpha= 0.0011  n+=352
step=  700  PE=   -0.459287 eV  max|F|= 5.966e-03 eV/Å  dt= 0.08000  alpha= 0.0598  n+= 87
step=  800  PE=   -0.485679 eV  max|F|= 9.677e-03 eV/Å  dt= 0.08000  alpha= 0.0132  n+=187
step=  900  PE=   -0.519129 eV  max|F|= 3.104e-02 eV/Å  dt= 0.08000  alpha= 0.0029  n+=287
step= 1000  PE=   -0.544707 eV  max|F|= 1.331e-02 eV/Å  dt= 0.06000  alpha= 0.0900  n+= 39
step= 1100  PE=   -0.572975 eV  max|F|= 8.447e-03 eV/Å  dt= 0.08000  alpha= 0.0273  n+=139
step= 1200  PE=   -0.613114 eV  max|F|= 1.308e-02 eV/Å  dt= 0.08000  alpha= 0.0060  n+=239
step= 1300  PE=   -0.626956 eV  max|F|= 4.091e-03 eV/Å  dt= 0.07293  alpha= 0.0847  n+= 64
step= 1400  PE=   -0.644412 eV  max|F|= 5.236e-03 eV/Å  dt= 0.08000  alpha= 0.0187  n+=164
step= 1500  PE=   -0.676661 eV  max|F|= 1.813e-02 eV/Å  dt= 0.08000  alpha= 0.0041  n+=264
step= 1600  PE=   -0.714668 eV  max|F|= 2.611e-02 eV/Å  dt= 0.08000  alpha= 0.0009  n+=364
step= 1700  PE=   -0.738105 eV  max|F|= 7.718e-03 eV/Å  dt= 0.08000  alpha= 0.0563  n+= 91
step= 1800  PE=   -0.772027 eV  max|F|= 6.722e-03 eV/Å  dt= 0.08000  alpha= 0.0124  n+=191
step= 1900  PE=   -0.797784 eV  max|F|= 1.953e-02 eV/Å  dt= 0.08000  alpha= 0.0027  n+=291
step= 2000  PE=   -0.813366 eV  max|F|= 9.130e-03 eV/Å  dt= 0.06000  alpha= 0.0900  n+= 45
step= 2100  PE=   -0.834695 eV  max|F|= 6.301e-03 eV/Å  dt= 0.08000  alpha= 0.0249  n+=145
step= 2200  PE=   -0.867851 eV  max|F|= 1.347e-02 eV/Å  dt= 0.08000  alpha= 0.0055  n+=245
step= 2300  PE=   -0.886634 eV  max|F|= 3.405e-02 eV/Å  dt= 0.08000  alpha= 0.0012  n+=345
step= 2400  PE=   -0.908705 eV  max|F|= 9.988e-03 eV/Å  dt= 0.06000  alpha= 0.0900  n+= 56
step= 2500  PE=   -0.927332 eV  max|F|= 6.808e-03 eV/Å  dt= 0.08000  alpha= 0.0211  n+=156
step= 2600  PE=   -0.955947 eV  max|F|= 1.362e-02 eV/Å  dt= 0.08000  alpha= 0.0047  n+=256
step= 2700  PE=   -0.978268 eV  max|F|= 3.703e-02 eV/Å  dt= 0.08000  alpha= 0.0010  n+=356
step= 2800  PE=   -0.995579 eV  max|F|= 8.869e-03 eV/Å  dt= 0.06000  alpha= 0.0900  n+= 57
step= 2900  PE=   -1.016717 eV  max|F|= 4.816e-03 eV/Å  dt= 0.08000  alpha= 0.0208  n+=157
step= 2999  PE=   -1.033258 eV  max|F|= 6.567e-03 eV/Å  dt= 0.08000  alpha= 0.0047  n+=256

In the run shown, the loop ends at the step limit rather than at the tolerance: at step 2999, max|F| = 6.567e-03 eV/Å is still above the 1e-3 eV/Å threshold, so the "Converged" message is not printed.

Plot convergence

steps = np.arange(len(energy_hist))

fig, ax = plt.subplots(2, 1, figsize=(7.0, 5.5), sharex=True, constrained_layout=True)
ax[0].plot(steps, energy_hist, lw=1.5)
ax[0].set_ylabel("Potential Energy (eV)")
ax[0].set_title("FIRE2 Optimization Convergence")

ax[1].semilogy(steps, maxf_hist, lw=1.5)
ax[1].axhline(force_tolerance, color="k", ls="--", lw=1.0, label="tolerance")
ax[1].set_xlabel("Sample index (one sample per check_interval steps)")
ax[1].set_ylabel(r"max$|F|$ (eV/$\AA$)")
ax[1].legend(frameon=False, loc="best")
[Figure: FIRE2 Optimization Convergence, showing potential energy (top) and max|F| (bottom)]
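
Because the histories are recorded only at check steps, the x-axis of the convergence plot is a sample index rather than an optimizer step. The sketch below recovers the step numbers by mirroring the check condition used in the optimization loop; the helper name sampled_steps is hypothetical.

```python
import numpy as np


def sampled_steps(max_steps, check_interval, n_samples):
    """Step numbers at which the loop records history (first n_samples)."""
    steps = [
        s
        for s in range(max_steps)
        if s % check_interval == 0 or s == max_steps - 1
    ]
    return np.asarray(steps[:n_samples])


# With max_steps=3000 and check_interval=50, a run that reaches the step
# limit records 61 samples: steps 0, 50, ..., 2950, and the final step 2999.
step_axis = sampled_steps(max_steps=3000, check_interval=50, n_samples=61)
print(step_axis[:3], step_axis[-2:])
```

Passing len(energy_hist) as n_samples gives an array that can replace steps in the plotting calls, which also handles a run that stops early on convergence.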

Visualize initial vs final geometry (XY projection)

pos0 = initial_positions
pos1 = wp.to_torch(system.wp_positions).cpu().numpy()

fig2, ax2 = plt.subplots(
    1, 2, figsize=(8.0, 3.5), sharex=True, sharey=True, constrained_layout=True
)
ax2[0].scatter(pos0[:, 0], pos0[:, 1], s=20)
ax2[0].set_title("Initial (XY)")
ax2[0].set_xlabel("x (Å)")
ax2[0].set_ylabel("y (Å)")
ax2[1].scatter(pos1[:, 0], pos1[:, 1], s=20)
ax2[1].set_title("Optimized (XY)")
ax2[1].set_xlabel("x (Å)")

plt.show()
[Figure: Initial (XY) and Optimized (XY) projections of the cluster]

Total running time of the script: (0 minutes 2.040 seconds)
