# Source code for nvalchemiops.dynamics.optimizers.fire2

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
FIRE2 Optimizer Kernels
=======================

GPU-accelerated Warp kernels for the FIRE2 (Fast Inertial Relaxation
Engine v2) geometry optimizer.

This module provides three highly-fused kernels that implement a complete
FIRE2 step in only **3 kernel launches**, minimizing Python-side and
launch overhead.

FIRE2 ALGORITHM (Guenole et al., 2020)
=======================================

Given positions *r*, velocities *v*, and forces *f*:

1. Half-step velocity update:  v += f * dt
2. Compute power:  P = sum(v . f)  per system
3. Adaptive parameter update:
   - If P > 0: increment counter, optionally grow dt, shrink alpha
   - If P <= 0: reset counter, shrink dt, reset alpha
4. Velocity mixing:
   v = (1 - alpha) * v + alpha * sqrt(v.v / f.f) * f
5. Compute step:  step = v * dt
6. Uphill correction:
   if P <= 0:  step = -0.5 * dt * v_mixed;  v = 0
7. Step clamping + position update + coupled dt scaling

KERNEL STRUCTURE
================

Kernel 1 (_fire2_reduce_only):
    Runs-based triple inner-product reduction (vf, v.v, f.f) with
    deferred half-step computed in registers only (no velocity write).

Kernel 2 (_fire2_fused_mix_maxnorm):
    Fuses per-system parameter update, deferred half-step, velocity
    mixing, and runs-based max-norm reduction into a single launch.
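    Each thread redundantly computes the parameter update for its
    segment from the kernel-1 reductions (vf, v.v, f.f) and the current
    per-system state (alpha, dt, nsteps_inc), avoiding inter-thread
    synchronization. The first atom of each segment also writes the
    updated state back during this same launch, so a thread whose range
    starts mid-segment is not guaranteed to read the pre-update values
    (to be confirmed / addressed).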

Kernel 3 (_fire2_clamp_apply_recompute):
    Recomputes step from mixed velocities, applies step clamping,
    position update, deferred velocity zeroing for uphill systems,
    and coupled dt scaling.

REFERENCES
==========

- Guenole et al. (2020). Comput. Mater. Sci. 175, 109584 (FIRE2)
- Bitzek et al. (2006). Phys. Rev. Lett. 97, 170201 (FIRE)
"""

from __future__ import annotations

import os
from typing import Any

import warp as wp

from nvalchemiops.dynamics.utils.kernel_functions import compute_vf_vv_ff
from nvalchemiops.segment_ops import compute_ept

# =============================================================================
# Kernel 1: Triple inner-product reduction (deferred half-step)
# =============================================================================

def _read_tile_dim(default: int = 256) -> int:
    """Read the tile block size from ``NVALCHEMIOPS_DYNAMICS_TILE_DIM``.

    Parameters
    ----------
    default : int
        Value used when the environment variable is unset.

    Returns
    -------
    int
        A strictly positive block size.

    Raises
    ------
    ValueError
        If the variable is set to something that is not a positive integer.
        A zero block size would make the tiled reduction's
        ``atom_idx % TILE_DIM`` a division by zero, and a negative one is
        meaningless, so both are rejected up front with a message naming the
        variable.
    """
    raw = os.getenv("NVALCHEMIOPS_DYNAMICS_TILE_DIM")
    if raw is None:
        return default
    try:
        value = int(raw)
    except ValueError as err:
        raise ValueError(
            "NVALCHEMIOPS_DYNAMICS_TILE_DIM must be a positive integer, "
            f"got {raw!r}"
        ) from err
    if value <= 0:
        raise ValueError(
            f"NVALCHEMIOPS_DYNAMICS_TILE_DIM must be a positive integer, got {value}"
        )
    return value


# Tile block size for cooperative reductions
TILE_DIM = _read_tile_dim()


@wp.kernel(enable_backward=False)
def _fire2_reduce_only_kernel(
    velocities: wp.array(dtype=Any),
    forces: wp.array(dtype=Any),
    dt: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    vf: wp.array(dtype=Any),
    v_sumsq: wp.array(dtype=Any),
    f_sumsq: wp.array(dtype=Any),
    N: wp.int32,
    elems_per_thread: wp.int32,
):
    """Triple inner-product reduction with deferred velocity half-step.

    Computes three inner products per segment without modifying velocities:
    - ``vf[s] = sum(dot(v_upd[i], f[i]) for i where batch_idx[i] == s)``
    - ``v_sumsq[s] = sum(dot(v_upd[i], v_upd[i]) for i where batch_idx[i] == s)``
    - ``f_sumsq[s] = sum(dot(f[i], f[i]) for i where batch_idx[i] == s)``

    where ``v_upd[i] = velocities[i] + forces[i] * dt[batch_idx[i]]``.

    The half-step velocity update is computed in registers only and NOT written
    back to the velocities array. This deferred write is performed by the
    subsequent fused mixing kernel, which algebraically combines the half-step
    with the velocity mixing operation.

    Each thread walks a contiguous chunk ``[start, end)`` of atoms, keeps
    running sums for the current segment, and flushes them with one
    ``wp.atomic_add`` per output whenever the segment id changes and once at
    the end of the chunk.

    Launch Grid
    -----------
    dim = ceil(N / elems_per_thread)

    Parameters
    ----------
    velocities : wp.array, shape (N,), dtype vec3f/vec3d
        Atomic velocities, read-only (not modified by this kernel).
    forces : wp.array, shape (N,), dtype vec3f/vec3d
        Forces on atoms.
    dt : wp.array, shape (M,), dtype float32/float64
        Per-system timestep (scalar dtype matching vector precision).
    batch_idx : wp.array, shape (N,), dtype int32
        Sorted system index per atom in [0, M).
    vf : wp.array, shape (M,), dtype float32/float64
        OUTPUT: v_upd·f per segment, accumulated with ``wp.atomic_add``.
        This kernel does NOT zero it; the caller must zero it before launch
        (``fire2_step`` does).
    v_sumsq : wp.array, shape (M,), dtype float32/float64
        OUTPUT: v_upd·v_upd per segment, accumulated with ``wp.atomic_add``.
        Must be zeroed by the caller before launch.
    f_sumsq : wp.array, shape (M,), dtype float32/float64
        OUTPUT: f·f per segment, accumulated with ``wp.atomic_add``.
        Must be zeroed by the caller before launch.
    N : int32
        Total number of atoms.
    elems_per_thread : int32
        Elements processed per thread (chosen by the caller; ``fire2_step``
        derives it from N and the device SM count via ``compute_ept``).

    Notes
    -----
    - batch_idx must be sorted in non-decreasing order for correctness
    - Uses run-length encoding to minimize atomic operations
    - Part of the FIRE2 3-kernel optimization strategy
    - The deferred half-step approach avoids an intermediate velocity write
    """
    t = wp.tid()
    start = t * elems_per_thread
    # Trailing threads past the last atom have no work.
    if start >= N:
        return
    end = wp.min(start + elems_per_thread, N)

    # First element -- compute v_upd in register, do NOT write back
    s_cur = batch_idx[start]
    v_upd = velocities[start] + forces[start] * dt[s_cur]
    acc_vf, acc_vv, acc_ff = compute_vf_vv_ff(v_upd, forces[start])

    for i in range(start + 1, end):
        s = batch_idx[i]
        v_upd = velocities[i] + forces[i] * dt[s]
        val_vf, val_vv, val_ff = compute_vf_vv_ff(v_upd, forces[i])
        if s == s_cur:
            # Same segment: keep accumulating in registers.
            acc_vf = acc_vf + val_vf
            acc_vv = acc_vv + val_vv
            acc_ff = acc_ff + val_ff
        else:
            # Segment boundary: flush the finished run (one atomic per
            # output), then restart the accumulators for the new segment.
            wp.atomic_add(vf, s_cur, acc_vf)
            wp.atomic_add(v_sumsq, s_cur, acc_vv)
            wp.atomic_add(f_sumsq, s_cur, acc_ff)
            s_cur = s
            acc_vf = val_vf
            acc_vv = val_vv
            acc_ff = val_ff

    # Flush the last (possibly partial) run of this chunk; other threads may
    # contribute to the same segment, hence atomics.
    wp.atomic_add(vf, s_cur, acc_vf)
    wp.atomic_add(v_sumsq, s_cur, acc_vv)
    wp.atomic_add(f_sumsq, s_cur, acc_ff)


@wp.kernel(enable_backward=False)
def _fire2_reduce_only_tiled_kernel(
    velocities: wp.array(dtype=Any),
    forces: wp.array(dtype=Any),
    dt: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    vf: wp.array(dtype=Any),
    v_sumsq: wp.array(dtype=Any),
    f_sumsq: wp.array(dtype=Any),
):
    """Triple inner-product reduction with tile reductions (per-atom processing).

    Computes three inner products per system using block-level tile reductions:
    - vf[s] = sum(dot(v_upd[i], f[i]) for i where batch_idx[i] == s)
    - v_sumsq[s] = sum(dot(v_upd[i], v_upd[i]) for i where batch_idx[i] == s)
    - f_sumsq[s] = sum(dot(f[i], f[i]) for i where batch_idx[i] == s)

    where v_upd[i] = velocities[i] + forces[i] * dt[batch_idx[i]]. The
    half-step is computed in registers only; velocities are not modified.

    Launch Grid: dim = [N_atoms], block_dim = TILE_DIM

    Parameters
    ----------
    velocities, forces : wp.array, shape (N,), dtype vec3f/vec3d
        Atomic velocities (read-only) and forces.
    dt : wp.array, shape (M,), dtype float32/float64
        Per-system timestep.
    batch_idx : wp.array, shape (N,), dtype int32
        System index per atom, sorted in non-decreasing order.
    vf, v_sumsq, f_sumsq : wp.array, shape (M,), dtype float32/float64
        OUTPUT: accumulated with ``wp.atomic_add``; must be zeroed by the
        caller before launch.

    Notes
    -----
    - Simpler per-atom processing (no RLE complexity)
    - Uses wp.tile() and wp.tile_sum() for cooperative reduction
    - A block whose atoms all belong to one system issues 3 atomics (from its
      first thread) instead of 3 per atom. Because batch_idx is sorted, the
      block is single-system exactly when its first and last atom share a
      system id.
    - A block that straddles a system boundary cannot use the block-wide sum
      (it mixes systems), so every thread adds its own contribution to its
      own system instead. The tile sums are still computed by all threads so
      that the tile operations never sit inside a divergent branch.
    - NOTE(review): the block-range computation assumes the launch uses
      block_dim == TILE_DIM and that every thread of a block reaches the tile
      operations (e.g. N handled as a multiple of TILE_DIM) — confirm at the
      call site.
    """
    atom_idx = wp.tid()
    system_id = batch_idx[atom_idx]

    # Compute deferred half-step in register only
    v_upd = velocities[atom_idx] + forces[atom_idx] * dt[system_id]

    # Compute local contributions
    local_vf, local_vv, local_ff = compute_vf_vv_ff(v_upd, forces[atom_idx])

    # Convert to tiles for block-level reduction (executed by every thread)
    t_vf = wp.tile(local_vf)
    t_vv = wp.tile(local_vv)
    t_ff = wp.tile(local_ff)

    # Cooperative sum within block
    s_vf = wp.tile_sum(t_vf)
    s_vv = wp.tile_sum(t_vv)
    s_ff = wp.tile_sum(t_ff)

    # Extract scalar values from tile sums
    sum_vf = s_vf[0]
    sum_vv = s_vv[0]
    sum_ff = s_ff[0]

    # Atom range [block_first, block_last] covered by this thread's block
    block_first = atom_idx - atom_idx % TILE_DIM
    block_last = wp.min(block_first + TILE_DIM, batch_idx.shape[0]) - 1

    if batch_idx[block_first] == batch_idx[block_last]:
        # Single-system block: the block sums belong entirely to system_id,
        # so only the block's first thread writes (3 atomics per block).
        if atom_idx == block_first:
            wp.atomic_add(vf, system_id, sum_vf)
            wp.atomic_add(v_sumsq, system_id, sum_vv)
            wp.atomic_add(f_sumsq, system_id, sum_ff)
    else:
        # Block spans several systems: credit each atom to its own system.
        wp.atomic_add(vf, system_id, local_vf)
        wp.atomic_add(v_sumsq, system_id, local_vv)
        wp.atomic_add(f_sumsq, system_id, local_ff)


# =============================================================================
# Kernel 2: Fused param update + deferred halfstep + mix + max-norm
# =============================================================================


@wp.kernel(enable_backward=False)
def _fire2_fused_mix_maxnorm_kernel(
    velocities: wp.array(dtype=Any),
    forces: wp.array(dtype=Any),
    dt: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    vf: wp.array(dtype=Any),
    v_sumsq: wp.array(dtype=Any),
    f_sumsq: wp.array(dtype=Any),
    alpha: wp.array(dtype=Any),
    nsteps_inc: wp.array(dtype=wp.int32),
    max_norm: wp.array(dtype=Any),
    N: wp.int32,
    elems_per_thread: wp.int32,
    delaystep: wp.int32,
    dtgrow: Any,
    dtshrink: Any,
    alphashrink: Any,
    alpha0: Any,
    tmax: Any,
    tmin: Any,
):
    """Fused adaptive parameter update, deferred half-step, velocity mixing, and max-norm reduction.

    This kernel performs four operations in a single launch:

    1. **Adaptive parameter update** (per-segment, redundantly computed):
       - If ``vf[s] > 0`` (downhill): increment counter; once the counter
         exceeds ``delaystep``, grow dt (capped at ``tmax``) and shrink alpha
       - If ``vf[s] <= 0`` (uphill): reset counter, shrink dt (floored at
         ``tmin``), reset alpha to ``alpha0``

    2. **Deferred half-step + velocity mixing** (algebraically combined):
       ``v[i] = mix_a * v[i] + (mix_a * dt_old + mix_b) * f[i]``
       where ``mix_a = 1 - alpha``, ``mix_b = alpha * sqrt(v·v / f·f)``
       (``mix_b = 0`` when ``f·f == 0``), using the UPDATED alpha and the
       PRE-update ``dt_old``. This equals ``(1 - alpha) * (v + dt_old * f)
       + mix_b * f``.

    3. **State updates** (first atom per segment writes):
       Updates ``alpha[s]``, ``dt[s]``, and ``nsteps_inc[s]``

    4. **Max-norm reduction** (run-length encoded):
       Computes ``max_norm[s] = max(||step[i]|| for i where batch_idx[i] == s)``
       with ``step = dt_new * v_mixed`` (downhill) or
       ``step = -0.5 * dt_new * v_mixed`` (uphill).

    Each thread redundantly computes the per-system parameter update from the
    read-only reductions (vf, v_sumsq, f_sumsq) and the current per-system
    state (alpha, dt, nsteps_inc), avoiding inter-thread synchronization.

    NOTE(review): alpha, dt and nsteps_inc are NOT read-only here. Every
    thread reads them for its segment(s), while the thread that owns the
    segment's first atom overwrites them in this same launch, and nothing in
    the kernel orders those accesses. A thread whose chunk starts mid-segment
    may therefore read already-updated values (double-applied dt/alpha
    update, wrong ``dt_old``), e.g. when threads for one segment do not run
    concurrently. This looks like a cross-thread read-after-write race —
    confirm. A fix would need separate output buffers or a separate
    per-system state-update launch.

    Launch Grid
    -----------
    dim = ceil(N / elems_per_thread)

    Parameters
    ----------
    velocities : wp.array, shape (N,), dtype vec3f/vec3d
        Atomic velocities, modified in-place. Must hold pre-halfstep values
        (kernel 1 did not modify them).
    forces : wp.array, shape (N,), dtype vec3f/vec3d
        Forces on atoms (read-only).
    dt : wp.array, shape (M,), dtype float32/float64
        Per-system timestep. Modified in-place by first atom per segment.
    batch_idx : wp.array, shape (N,), dtype int32
        Sorted system index per atom in [0, M).
    vf : wp.array, shape (M,), dtype float32/float64
        v·f inner product per segment from kernel 1 (read-only).
    v_sumsq : wp.array, shape (M,), dtype float32/float64
        v·v inner product per segment from kernel 1 (read-only).
    f_sumsq : wp.array, shape (M,), dtype float32/float64
        f·f inner product per segment from kernel 1 (read-only).
    alpha : wp.array, shape (M,), dtype float32/float64
        FIRE2 mixing parameter. Modified in-place by first atom per segment.
    nsteps_inc : wp.array, shape (M,), dtype int32
        Consecutive positive-power step counter. Modified by first atom per segment.
    max_norm : wp.array, shape (M,), dtype float32/float64
        OUTPUT: Maximum step norm per segment, accumulated with
        ``wp.atomic_max``. This kernel does NOT zero it; the caller must zero
        it before launch (zero is a valid identity since norms are >= 0).
    N : int32
        Total number of atoms.
    elems_per_thread : int32
        Elements processed per thread (chosen by the caller).
    delaystep : int32
        Minimum consecutive positive steps before dt growth.
    dtgrow : float32/float64
        Timestep growth factor (typically 1.05).
    dtshrink : float32/float64
        Timestep shrink factor (typically 0.75).
    alphashrink : float32/float64
        Alpha decay factor (typically 0.985).
    alpha0 : float32/float64
        Alpha reset value (typically 0.09).
    tmax : float32/float64
        Maximum allowed timestep.
    tmin : float32/float64
        Minimum allowed timestep.

    Notes
    -----
    - batch_idx must be sorted in non-decreasing order
    - Only the first atom in each segment writes to alpha, dt, nsteps_inc
    - Parameter updates are computed redundantly by each thread to avoid synchronization
    - The algebraic combination of half-step and mixing eliminates one velocity write
    - For uphill systems (vf <= 0), step norm uses factor -0.5 for the correction
    """
    t = wp.tid()
    start = t * elems_per_thread
    if start >= N:
        return
    end = wp.min(start + elems_per_thread, N)

    s_cur = batch_idx[start]

    # --- Redundant param-update computation for the first segment ---
    # NOTE(review): alpha/dt (and nsteps_inc below) may already have been
    # overwritten by the segment's first-atom thread if this chunk starts
    # mid-segment; see the race note in the docstring.
    _vf = vf[s_cur]
    _vv = v_sumsq[s_cur]
    _ff = f_sumsq[s_cur]
    _a = alpha[s_cur]
    _dt = dt[s_cur]
    dt_old = _dt  # pre-update dt for the deferred half-step

    zero = type(_dt)(0.0)
    one = type(_dt)(1.0)
    # Positive power (downhill) for this segment
    w_inc = _vf > zero

    if w_inc:
        _nsi = nsteps_inc[s_cur] + 1
        # dt growth and alpha decay only start after `delaystep` positive steps
        if _nsi > delaystep:
            _dt = wp.min(dtgrow * _dt, tmax)
            _a = alphashrink * _a
    else:
        _nsi = 0
        _a = alpha0
        _dt = wp.max(dtshrink * _dt, tmin)

    # First atom per segment writes updated params (it is the first atom when
    # it is atom 0 or its predecessor belongs to a different segment)
    if start == 0 or batch_idx[start - 1] != s_cur:
        alpha[s_cur] = _a
        dt[s_cur] = _dt
        nsteps_inc[s_cur] = _nsi

    # sqrt(v·v / f·f), guarded against a zero-force segment
    if _ff > zero:
        ratio = wp.sqrt(_vv / _ff)
    else:
        ratio = zero
    mix_a = one - _a
    mix_b = _a * ratio
    w_dec = not w_inc

    # --- Process first element: deferred halfstep + mix (algebraic combo) ---
    # mix_a * (v + dt_old * f) + mix_b * f == mix_a * v + (mix_a * dt_old + mix_b) * f
    f_coeff = mix_a * dt_old + mix_b
    velocities[start] = mix_a * velocities[start] + f_coeff * forces[start]
    if w_dec:
        max_val = wp.length(type(_dt)(-0.5) * _dt * velocities[start])
    else:
        max_val = wp.length(_dt * velocities[start])

    for i in range(start + 1, end):
        s = batch_idx[i]
        if s != s_cur:
            # Flush max_norm for previous segment
            wp.atomic_max(max_norm, s_cur, max_val)
            s_cur = s

            # --- Redundant param-update computation for new segment ---
            _vf = vf[s]
            _vv = v_sumsq[s]
            _ff = f_sumsq[s]
            _a = alpha[s]
            _dt = dt[s]
            dt_old = _dt

            w_inc = _vf > zero
            if w_inc:
                _nsi = nsteps_inc[s] + 1
                if _nsi > delaystep:
                    _dt = wp.min(dtgrow * _dt, tmax)
                    _a = alphashrink * _a
            else:
                _nsi = 0
                _a = alpha0
                _dt = wp.max(dtshrink * _dt, tmin)

            # Always true inside this branch (atom i-1 belonged to the
            # previous segment), i.e. atom i is the first atom of segment s.
            if batch_idx[i - 1] != s:
                alpha[s] = _a
                dt[s] = _dt
                nsteps_inc[s] = _nsi

            if _ff > zero:
                ratio = wp.sqrt(_vv / _ff)
            else:
                ratio = zero
            mix_a = one - _a
            mix_b = _a * ratio
            w_dec = not w_inc
            f_coeff = mix_a * dt_old + mix_b
            # Restart the running max for the new segment
            max_val = type(_dt)(0.0)

        # Deferred halfstep + mix (algebraic combo)
        velocities[i] = mix_a * velocities[i] + f_coeff * forces[i]
        if w_dec:
            norm = wp.length(type(_dt)(-0.5) * _dt * velocities[i])
        else:
            norm = wp.length(_dt * velocities[i])
        max_val = wp.max(max_val, norm)

    # Flush the last run; other threads may cover the same segment
    wp.atomic_max(max_norm, s_cur, max_val)


# =============================================================================
# Kernel 3: Step recompute + clamping + position update + velocity zeroing
# =============================================================================


@wp.kernel(enable_backward=False)
def _fire2_clamp_apply_recompute_kernel(
    positions: wp.array(dtype=Any),
    velocities: wp.array(dtype=Any),
    dt: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    max_norm: wp.array(dtype=Any),
    vf: wp.array(dtype=Any),
    maxstep: Any,
):
    """Step recomputation, clamping, position update, velocity zeroing, and coupled dt scaling.

    This kernel performs the final operations of the FIRE2 step:

    1. **Step recomputation** from mixed velocities (avoids storing step buffer)
    2. **Uphill correction**: For ``vf[s] <= 0``, applies ``step = -0.5 * dt * v``
    3. **Step clamping**: Scales step by ``min(1.0, maxstep / max_norm[s])``
    4. **Position update**: ``positions[i] += clamped_step``
    5. **Velocity zeroing**: Sets ``velocities[i] = 0`` for uphill systems
    6. **Coupled dt scaling**: Scales ``dt[s]`` by the same clamping factor

    The algorithm:
    ```
    local_dt = dt[s]  (per-thread copy of the current value)
    inv = min(1.0, maxstep / max_norm[s])
    if vf[s] <= 0:  # uphill
        step = -0.5 * local_dt * v[i]
        v[i] = 0
    else:  # downhill
        step = local_dt * v[i]
    positions[i] += step * inv
    if first_atom_in_segment:
        dt[s] = local_dt * inv
    ```

    NOTE(review): ``dt[s]`` is read by every thread of segment ``s`` and
    written by the segment's first-atom thread in this same launch, with
    nothing in the kernel ordering those accesses. A thread that reads
    ``dt[s]`` after the write sees ``dt * inv`` and applies the clamp twice
    (step scaled by ``inv**2``). The write is harmless only when ``inv == 1``
    (value unchanged). This looks like a cross-thread read-after-write race —
    confirm. A fix would need the dt scaling moved to a separate launch or
    buffer.

    Launch Grid
    -----------
    dim = N (total atoms)

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype vec3f/vec3d
        Atomic positions, modified in-place.
    velocities : wp.array, shape (N,), dtype vec3f/vec3d
        Atomic velocities, modified in-place (zeroed for uphill systems).
    dt : wp.array, shape (M,), dtype float32/float64
        Per-system timestep, modified in-place by first atom per segment
        (clamped proportionally to step scaling).
    batch_idx : wp.array, shape (N,), dtype int32
        Sorted system index per atom in [0, M).
    max_norm : wp.array, shape (M,), dtype float32/float64
        Maximum step norm per segment from kernel 2. If it is 0, the IEEE
        result of ``maxstep / 0`` is +inf for positive maxstep, so ``inv`` is 1.
    vf : wp.array, shape (M,), dtype float32/float64
        v·f inner product per segment from kernel 1. System is uphill if vf[s] <= 0.
    maxstep : float32/float64
        Maximum allowed step size (FIRE2 hyperparameter).

    Notes
    -----
    - Each thread copies dt[s] into a register before its own (possible) write;
      this does NOT protect against another thread's earlier write (see the
      NOTE above)
    - Only the first atom in each segment (batch_idx[i-1] != batch_idx[i]) writes dt[s]
    - Velocity zeroing for uphill systems is deferred to this kernel for efficiency
    - Coupled dt scaling ensures consistency between step size and timestep
    - The -0.5 factor for uphill correction is part of the FIRE2 algorithm
    """
    tid = wp.tid()
    s = batch_idx[tid]
    # Per-thread copy of dt. NOTE(review): only guards against this thread's
    # own later write, not against the first-atom thread having already
    # scaled dt[s] — see docstring.
    local_dt = dt[s]
    mn = max_norm[s]
    # Clamp factor: shrink the whole system's step so its largest atom step is maxstep
    inv = wp.min(type(mn)(1.0), maxstep / mn)

    if vf[s] <= type(mn)(0.0):
        # Uphill: half step back along the mixed velocity, then stop
        local_step = type(mn)(-0.5) * local_dt * velocities[tid]
        velocities[tid] = type(velocities[tid])()
    else:
        local_step = local_dt * velocities[tid]

    positions[tid] = positions[tid] + local_step * inv
    # Only first atom in segment updates dt (idx is sorted)
    if tid == 0 or batch_idx[tid - 1] != s:
        dt[s] = local_dt * inv


# =============================================================================
# Overloads
# =============================================================================

_T = [wp.float32, wp.float64]
_V = [wp.vec3f, wp.vec3d]

_fire2_reduce_only_overloads = {}
_fire2_reduce_only_tiled_overloads = {}
_fire2_fused_mix_maxnorm_overloads = {}
_fire2_clamp_apply_recompute_overloads = {}


def _reduction_arg_types(scalar, vector):
    """Build fresh argument types for the seven leading reduction arrays.

    Kernel 1 (both variants) and kernel 2 all start with
    (velocities, forces, dt, batch_idx, vf, v_sumsq, f_sumsq); this returns
    a new list of those types for the given scalar/vector precision.
    """
    return [
        wp.array(dtype=vector),  # velocities
        wp.array(dtype=vector),  # forces
        wp.array(dtype=scalar),  # dt
        wp.array(dtype=wp.int32),  # batch_idx
        wp.array(dtype=scalar),  # vf
        wp.array(dtype=scalar),  # v_sumsq
        wp.array(dtype=scalar),  # f_sumsq
    ]


# Register one overload per precision, keyed by the vector dtype.
for _t, _v in zip(_T, _V):
    # Kernel 1, runs-based: shared prefix + (N, elems_per_thread)
    _fire2_reduce_only_overloads[_v] = wp.overload(
        _fire2_reduce_only_kernel,
        _reduction_arg_types(_t, _v) + [wp.int32, wp.int32],
    )

    # Kernel 1, tiled: shared prefix only
    _fire2_reduce_only_tiled_overloads[_v] = wp.overload(
        _fire2_reduce_only_tiled_kernel,
        _reduction_arg_types(_t, _v),
    )

    # Kernel 2: shared prefix + state/scratch arrays + launch ints + scalars
    _fire2_fused_mix_maxnorm_overloads[_v] = wp.overload(
        _fire2_fused_mix_maxnorm_kernel,
        _reduction_arg_types(_t, _v)
        + [
            wp.array(dtype=_t),  # alpha
            wp.array(dtype=wp.int32),  # nsteps_inc
            wp.array(dtype=_t),  # max_norm
            wp.int32,  # N
            wp.int32,  # elems_per_thread
            wp.int32,  # delaystep
        ]
        # dtgrow, dtshrink, alphashrink, alpha0, tmax, tmin
        + [_t] * 6,
    )

    # Kernel 3: clamp + apply
    _fire2_clamp_apply_recompute_overloads[_v] = wp.overload(
        _fire2_clamp_apply_recompute_kernel,
        [
            wp.array(dtype=_v),  # positions
            wp.array(dtype=_v),  # velocities
            wp.array(dtype=_t),  # dt
            wp.array(dtype=wp.int32),  # batch_idx
            wp.array(dtype=_t),  # max_norm
            wp.array(dtype=_t),  # vf (v.f inner product)
            _t,  # maxstep
        ],
    )


# =============================================================================
# Public API
# =============================================================================


def fire2_step(
    # Per-atom arrays (N,)
    positions: wp.array,
    velocities: wp.array,
    forces: wp.array,
    batch_idx: wp.array,
    # Per-system state (M,)
    alpha: wp.array,
    dt: wp.array,
    nsteps_inc: wp.array,
    # Scratch buffers (M,)
    vf: wp.array,
    v_sumsq: wp.array,
    f_sumsq: wp.array,
    max_norm: wp.array,
    # Hyperparameters (Python scalars)
    delaystep: int = 60,
    dtgrow: float = 1.05,
    dtshrink: float = 0.75,
    alphashrink: float = 0.985,
    alpha0: float = 0.09,
    tmax: float = 0.08,
    tmin: float = 0.005,
    maxstep: float = 0.1,
) -> None:
    """Complete FIRE2 optimization step.

    Modifies *positions*, *velocities*, *alpha*, *dt*, and *nsteps_inc*
    in-place using three kernel launches: a triple inner-product reduction,
    a fused parameter-update / velocity-mixing / max-norm kernel, and a
    clamp / position-update kernel.

    Parameters
    ----------
    positions : wp.array, shape (N,), dtype vec3f/vec3d
        Atomic positions.
    velocities : wp.array, shape (N,), dtype vec3f/vec3d
        Atomic velocities.
    forces : wp.array, shape (N,), dtype vec3f/vec3d
        Forces on atoms (read-only).
    batch_idx : wp.array, shape (N,), dtype int32
        Sorted system index per atom (required).
    alpha : wp.array, shape (M,), dtype float*
        FIRE2 mixing parameter.
    dt : wp.array, shape (M,), dtype float*
        Per-system timestep.
    nsteps_inc : wp.array, shape (M,), dtype int32
        Consecutive positive-power step counter.
    vf, v_sumsq, f_sumsq, max_norm : wp.array, shape (M,), dtype float*
        Scratch buffers for reductions. Zeroed here before each use; each
        must have length >= M because the kernels index them by system id.
    delaystep : int
        Minimum positive steps before dt growth.
    dtgrow, dtshrink : float
        Timestep growth/shrink factors.
    alphashrink : float
        Alpha decay factor.
    alpha0 : float
        Alpha reset value.
    tmax, tmin : float
        Timestep bounds.
    maxstep : float
        Maximum step magnitude per system. Must be positive: the final
        kernel multiplies dt by ``min(1, maxstep / max_norm)``, so a
        non-positive value would drive dt to zero or below.

    Raises
    ------
    ValueError
        If ``batch_idx`` is None, a per-atom array's length differs from
        ``positions``, ``dt``/``nsteps_inc`` length differs from ``alpha``,
        a scratch buffer is shorter than ``alpha``, or ``maxstep`` is not
        positive.

    Notes
    -----
    - ``batch_idx`` must be sorted; segment reductions assume contiguous atom
      ranges per system.
    - NOTE(review): the second and third kernels read per-system state
      (alpha/dt/nsteps_inc and dt, respectively) in every thread while the
      first-atom thread of each system overwrites it in the same launch;
      confirm this is race-free for systems spanning several threads.

    Examples
    --------
    >>> fire2_step(positions, velocities, forces, batch_idx,
    ...            alpha, dt, nsteps_inc,
    ...            vf, v_sumsq, f_sumsq, max_norm)
    """
    # --- Input validation ---
    N = positions.shape[0]
    if batch_idx is None:
        raise ValueError("batch_idx is required for fire2_step")
    if velocities.shape[0] != N:
        raise ValueError(
            f"velocities length {velocities.shape[0]} != positions length {N}"
        )
    if forces.shape[0] != N:
        raise ValueError(f"forces length {forces.shape[0]} != positions length {N}")
    if batch_idx.shape[0] != N:
        raise ValueError(
            f"batch_idx length {batch_idx.shape[0]} != positions length {N}"
        )
    M = alpha.shape[0]
    if dt.shape[0] != M:
        raise ValueError(f"dt length {dt.shape[0]} != alpha length {M}")
    if nsteps_inc.shape[0] != M:
        raise ValueError(f"nsteps_inc length {nsteps_inc.shape[0]} != alpha length {M}")
    # The kernels index the scratch buffers by system id in [0, M); a shorter
    # buffer would be read/written out of bounds on the device.
    for name, buf in (
        ("vf", vf),
        ("v_sumsq", v_sumsq),
        ("f_sumsq", f_sumsq),
        ("max_norm", max_norm),
    ):
        if buf.shape[0] < M:
            raise ValueError(f"{name} length {buf.shape[0]} < alpha length {M}")
    # `not x > 0` also rejects NaN
    if not maxstep > 0:
        raise ValueError(f"maxstep must be positive, got {maxstep}")

    vec_dtype = positions.dtype
    device = positions.device

    vf.zero_()
    v_sumsq.zero_()
    f_sumsq.zero_()
    max_norm.zero_()

    if N == 0:
        # No atoms: nothing to reduce, mix, or move; per-system state is untouched.
        return

    sm = max(device.sm_count, 1)
    # Kernels 1 and 2 share the same per-thread chunking, so tune it once.
    ept = compute_ept(N, sm, True)
    dim = (N + ept - 1) // ept

    # Kernel 1: reduce only (no velocity write, deferred to fused kernel)
    wp.launch(
        _fire2_reduce_only_overloads[vec_dtype],
        dim=dim,
        inputs=[velocities, forces, dt, batch_idx, vf, v_sumsq, f_sumsq, N, ept],
        device=device,
    )

    # Kernel 2: param update + deferred halfstep + mix + maxnorm
    wp.launch(
        _fire2_fused_mix_maxnorm_overloads[vec_dtype],
        dim=dim,
        inputs=[
            velocities,
            forces,
            dt,
            batch_idx,
            vf,
            v_sumsq,
            f_sumsq,
            alpha,
            nsteps_inc,
            max_norm,
            N,
            ept,
            delaystep,
            dtgrow,
            dtshrink,
            alphashrink,
            alpha0,
            tmax,
            tmin,
        ],
        device=device,
    )

    # Kernel 3: recompute step + clamp + position update + velocity zeroing
    wp.launch(
        _fire2_clamp_apply_recompute_overloads[vec_dtype],
        dim=N,
        inputs=[
            positions,
            velocities,
            dt,
            batch_idx,
            max_norm,
            vf,  # vf holds v.f; uphill if <= 0
            maxstep,
        ],
        device=device,
    )