nvalchemiops.torch: Dynamics Optimizers
The dynamics module provides PyTorch bindings for GPU-accelerated geometry optimization algorithms.
Tip
For the underlying framework-agnostic Warp kernels and full MD integrators, see nvalchemiops.dynamics: Molecular Dynamics.
PyTorch Adapters for nvalchemiops
Thin wrappers that convert PyTorch tensors to Warp arrays, call the Warp-first core API, and manage scratch buffer allocation via PyTorch’s CUDA caching allocator.
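The adapters in this module share one scratch-buffer convention: a buffer is allocated and zeroed when the caller passes None, and zeroed in place when the caller supplies one. The plain-Python sketch below illustrates that convention only. `prepare_scratch` is a hypothetical helper, lists stand in for torch tensors, and this is not the nvalchemiops implementation.

```python
# Sketch of the "allocate if None, zero in place if provided" convention.
# `prepare_scratch` is hypothetical; Python lists stand in for torch tensors.
from typing import List, Optional


def prepare_scratch(buf: Optional[List[float]], num_systems: int) -> List[float]:
    """Return a zeroed buffer of length num_systems, reusing buf if given."""
    if buf is None:
        # No buffer supplied: allocate a fresh one (a new allocation each call).
        return [0.0] * num_systems
    if len(buf) != num_systems:
        raise ValueError(f"expected {num_systems} entries, got {len(buf)}")
    # Buffer supplied: zero it in place so the same storage is reused.
    for i in range(num_systems):
        buf[i] = 0.0
    return buf


if __name__ == "__main__":
    reused = [1.5, -2.0]
    for _ in range(3):
        # Each "step" gets a zeroed buffer without a new allocation.
        assert prepare_scratch(reused, 2) is reused
    print(prepare_scratch(None, 3))
```

Passing the same buffer on every step keeps object identity stable. That is why the tight-loop examples below pre-allocate once and pass the buffers in.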
Submodules
- fire2
PyTorch adapter for the FIRE2 optimizer (coordinate-only and variable-cell).
FIRE2 Optimizer
PyTorch adapter for the FIRE2 (Fast Inertial Relaxation Engine v2) geometry optimizer. These functions accept PyTorch tensors, allocate scratch buffers via PyTorch’s CUDA caching allocator, and call the pure-Warp FIRE2 kernels.
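The sketch below is a simplified single-system, unit-mass FIRE2-style step in plain Python. It is meant to give intuition for what the hyperparameters control. It follows the general FIRE 2.0 scheme (Guénolé et al., 2020): semi-implicit Euler integration, velocity mixing, and a back-off on uphill steps. The mapping of this module's names onto that scheme is an assumption: delaystep=N_delay, dtgrow=f_inc, dtshrink=f_dec, alphashrink=f_alpha, alpha0=alpha_start, tmax=dt_max, and tmin=dt_min. maxstep is assumed to cap the largest per-atom displacement. The pure-Warp kernels may order or batch these operations differently.

```python
# Simplified FIRE2-style step for ONE system with unit masses.
# Assumed hyperparameter mapping (not confirmed by the library docs):
# delaystep=N_delay, dtgrow=f_inc, dtshrink=f_dec, alphashrink=f_alpha,
# alpha0=alpha_start, tmax=dt_max, tmin=dt_min, maxstep=max displacement.
import math
from dataclasses import dataclass
from typing import List

Vec3List = List[List[float]]


@dataclass
class FireState:
    alpha: float
    dt: float
    nsteps_inc: int = 0


def fire2_reference_step(
    x: Vec3List, v: Vec3List, f: Vec3List, st: FireState, *,
    delaystep: int = 60, dtgrow: float = 1.05, dtshrink: float = 0.75,
    alphashrink: float = 0.985, alpha0: float = 0.09,
    tmax: float = 0.08, tmin: float = 0.005, maxstep: float = 0.1,
) -> None:
    """Update x, v and st in place."""
    power = sum(vi * fi for va, fa in zip(v, f) for vi, fi in zip(va, fa))
    if power > 0.0:
        # Downhill: extend the streak; after delaystep steps grow dt, shrink alpha.
        st.nsteps_inc += 1
        if st.nsteps_inc > delaystep:
            st.dt = min(st.dt * dtgrow, tmax)
            st.alpha *= alphashrink
    else:
        # Uphill: reset streak, shrink dt, restore alpha, back off half a step, stop.
        st.nsteps_inc = 0
        st.dt = max(st.dt * dtshrink, tmin)
        st.alpha = alpha0
        for xa, va in zip(x, v):
            for k in range(3):
                xa[k] -= 0.5 * st.dt * va[k]
                va[k] = 0.0
    # Semi-implicit Euler velocity update.
    for va, fa in zip(v, f):
        for k in range(3):
            va[k] += st.dt * fa[k]
    # FIRE mixing: v <- (1 - alpha) v + alpha |v| F / |F|, with system-wide norms.
    v_norm = math.sqrt(sum(c * c for va in v for c in va))
    f_norm = math.sqrt(sum(c * c for fa in f for c in fa))
    if f_norm > 0.0:
        scale = st.alpha * v_norm / f_norm
        for va, fa in zip(v, f):
            for k in range(3):
                va[k] = (1.0 - st.alpha) * va[k] + scale * fa[k]
    # Position update, rescaled so that no atom moves farther than maxstep.
    largest = max(math.sqrt(sum((st.dt * c) ** 2 for c in va)) for va in v)
    shrink = min(1.0, maxstep / largest) if largest > 0.0 else 1.0
    for xa, va in zip(x, v):
        for k in range(3):
            xa[k] += shrink * st.dt * va[k]


if __name__ == "__main__":
    # Relax one atom in a harmonic well, F = -x.
    x, v = [[1.0, 0.0, 0.0]], [[0.0, 0.0, 0.0]]
    st = FireState(alpha=0.09, dt=0.05)
    for _ in range(2000):
        f = [[-c for c in xa] for xa in x]
        fire2_reference_step(x, v, f, st)
    print(x[0][0], st.dt)
```

In the batched GPU version, each of the M systems carries its own alpha, dt, and nsteps_inc, which is why those arguments have shape (M,).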
Coordinate-Only Optimization
- nvalchemiops.torch.fire2.fire2_step_coord(positions, velocities, forces, batch_idx, alpha, dt, nsteps_inc, *, vf=None, v_sumsq=None, f_sumsq=None, max_norm=None, delaystep=60, dtgrow=1.05, dtshrink=0.75, alphashrink=0.985, alpha0=0.09, tmax=0.08, tmin=0.005, maxstep=0.1)
FIRE2 coordinate-only optimization step.
Converts PyTorch tensors to Warp arrays (zero-copy) and delegates to the pure-Warp fire2_step().
Modifies positions, velocities, alpha, dt, and nsteps_inc in-place.
- Parameters:
positions (Tensor, shape (N, 3), dtype float32/float64) – Atomic positions.
velocities (Tensor, shape (N, 3), dtype float32/float64) – Atomic velocities.
forces (Tensor, shape (N, 3), dtype float32/float64) – Forces on atoms (read-only).
batch_idx (Tensor, shape (N,), dtype int32) – Sorted system index per atom. Must be non-decreasing; segmented reductions rely on contiguous atom ranges.
alpha (Tensor, shape (M,), dtype float32/float64) – FIRE2 mixing parameter (one per system).
dt (Tensor, shape (M,), dtype float32/float64) – Per-system timestep.
nsteps_inc (Tensor, shape (M,), dtype int32) – Consecutive positive-power step counter.
vf, v_sumsq, f_sumsq, max_norm (Tensor, shape (M,), optional) – Scratch buffers for per-system reductions. Each is allocated and zeroed if None, or zeroed in-place if provided. Pre-allocate them once and pass them in tight loops to avoid repeated allocation:
    M = alpha.shape[0]
    vf = torch.empty(M, dtype=positions.dtype, device=positions.device)
    v_sumsq = torch.empty_like(vf)
    f_sumsq = torch.empty_like(vf)
    max_norm = torch.empty_like(vf)
delaystep (int), dtgrow (float), dtshrink (float), alphashrink (float), alpha0 (float), tmax (float), tmin (float), maxstep (float) – FIRE2 hyperparameters. See fire2_step() for defaults and descriptions.
- Return type:
None
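The requirement that batch_idx be non-decreasing exists because each system's atoms must form one contiguous range that can be reduced independently. The sketch below shows this in plain Python: it turns a sorted batch_idx into range pointers and computes per-system sums over each slice. The assumption that vf, v_sumsq, and f_sumsq hold sum(v·f), sum(|v|²), and sum(|f|²) per system is inferred from their names. max_norm is left out because its exact contents are not documented here.

```python
# Segmented (per-system) reductions over a sorted batch_idx, in plain Python.
# Assumed buffer meanings (inferred from names): vf = sum(v.f),
# v_sumsq = sum(|v|^2), f_sumsq = sum(|f|^2), one value per system.
from typing import List, Sequence, Tuple

Vec3List = Sequence[Sequence[float]]


def atom_ranges(batch_idx: Sequence[int], num_systems: int) -> List[int]:
    """Pointers such that system s owns atoms ptr[s]:ptr[s+1]."""
    if any(b < a for a, b in zip(batch_idx, batch_idx[1:])):
        raise ValueError("batch_idx must be non-decreasing")
    ptr = [0] * (num_systems + 1)
    for s in batch_idx:
        ptr[s + 1] += 1  # count atoms per system
    for s in range(num_systems):
        ptr[s + 1] += ptr[s]  # running sum turns counts into offsets
    return ptr


def segmented_reductions(
    v: Vec3List, f: Vec3List, batch_idx: Sequence[int], num_systems: int
) -> Tuple[List[float], List[float], List[float]]:
    ptr = atom_ranges(batch_idx, num_systems)
    vf, v_sumsq, f_sumsq = [], [], []
    for s in range(num_systems):
        lo, hi = ptr[s], ptr[s + 1]  # one contiguous slice per system
        vs, fs = v[lo:hi], f[lo:hi]
        vf.append(sum(a * b for va, fa in zip(vs, fs) for a, b in zip(va, fa)))
        v_sumsq.append(sum(a * a for va in vs for a in va))
        f_sumsq.append(sum(b * b for fa in fs for b in fa))
    return vf, v_sumsq, f_sumsq
```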
Notes
For variable-cell optimization (coordinates + cell DOFs), use fire2_step_coord_cell() instead.
Examples
Minimal single-step call:
>>> fire2_step_coord(
...     positions, velocities, forces,
...     batch_idx, alpha, dt, nsteps_inc,
... )
Tight optimization loop with pre-allocated scratch buffers:
>>> M = alpha.shape[0]
>>> vf = torch.empty(M, dtype=positions.dtype, device=positions.device)
>>> v_sumsq = torch.empty_like(vf)
>>> f_sumsq = torch.empty_like(vf)
>>> max_norm = torch.empty_like(vf)
>>> for step in range(num_steps):
...     # hypothetical: recompute forces for the updated positions
...     forces = compute_forces(positions)
...     fire2_step_coord(
...         positions, velocities, forces,
...         batch_idx, alpha, dt, nsteps_inc,
...         vf=vf, v_sumsq=v_sumsq,
...         f_sumsq=f_sumsq, max_norm=max_norm,
...     )
Variable-Cell Optimization
Use this variant to optimize atomic coordinates and simulation cell parameters simultaneously.
- nvalchemiops.torch.fire2.fire2_step_coord_cell(positions, velocities, forces, cell, cell_velocities, cell_force, batch_idx, alpha, dt, nsteps_inc, *, atom_ptr=None, ext_atom_ptr=None, ext_positions=None, ext_velocities=None, ext_forces=None, ext_batch_idx=None, vf=None, v_sumsq=None, f_sumsq=None, max_norm=None, delaystep=60, dtgrow=1.05, dtshrink=0.75, alphashrink=0.985, alpha0=0.09, tmax=0.08, tmin=0.005, maxstep=0.1)
FIRE2 variable-cell optimization step.
Performs a FIRE2 step on both atomic coordinates and cell degrees of freedom. Internally, it packs atomic and cell DOFs into extended arrays using an interleaved layout (each system’s atoms followed by its 2 cell vec3s), runs the 3-kernel FIRE2 algorithm, and unpacks the results.
The cell must be pre-aligned to upper-triangular form via align_cell() before the first call.
Modifies positions, velocities, cell, cell_velocities, alpha, dt, and nsteps_inc in-place.
- Parameters:
positions (Tensor, shape (N, 3), dtype float32/float64) – Atomic positions.
velocities (Tensor, shape (N, 3), dtype float32/float64) – Atomic velocities.
forces (Tensor, shape (N, 3), dtype float32/float64) – Forces on atoms (read-only).
cell (Tensor, shape (M, 3, 3), dtype float32/float64) – Cell matrices (upper-triangular, from align_cell()).
cell_velocities (Tensor, shape (M, 3, 3), dtype float32/float64) – Cell velocity matrices.
cell_force (Tensor, shape (M, 3, 3), dtype float32/float64) – Cell force matrices from stress_to_cell_force() (read-only).
batch_idx (Tensor, shape (N,), dtype int32) – Sorted (non-decreasing) system index per atom.
alpha (Tensor, shape (M,), dtype float32/float64) – FIRE2 mixing parameter.
dt (Tensor, shape (M,), dtype float32/float64) – Per-system timestep.
nsteps_inc (Tensor, shape (M,), dtype int32) – Consecutive positive-power counter.
atom_ptr (Tensor, shape (M+1,), dtype int32, optional) – CSR-style atom pointers derived from batch_idx. If None, computed internally on each call via batch_idx_to_atom_ptr(). Pre-compute once and pass in tight loops to avoid repeated allocation. See Notes for how to compute.
ext_atom_ptr (Tensor, shape (M+1,), dtype int32, optional) – Extended atom pointers (accounting for the 2 cell DOF entries per system). If None, computed from atom_ptr on each call via extend_atom_ptr(). See Notes for how to compute.
ext_positions, ext_velocities, ext_forces (Tensor, shape (N+2M, 3), optional) – Pre-allocated extended working arrays. Allocated if None; contents are overwritten on each call. Providing them avoids repeated allocation in tight loops.
ext_batch_idx (Tensor, shape (N+2M,), dtype int32, optional) – Pre-computed extended batch index (sorted, matching the interleaved pack layout). If None, computed from ext_atom_ptr on each call via atom_ptr_to_batch_idx(). If provided, it is assumed correct and reused without recomputation. See Notes for how to compute.
vf, v_sumsq, f_sumsq, max_norm (Tensor, shape (M,), optional) – Scratch buffers for reductions. Allocated and zeroed if None; zeroed in-place if provided.
delaystep (int), dtgrow (float), dtshrink (float), alphashrink (float), alpha0 (float), tmax (float), tmin (float), maxstep (float) – FIRE2 hyperparameters. See fire2_step() for defaults and descriptions.
- Return type:
None
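Because the cell must already be upper-triangular when fire2_step_coord_cell() is first called, a cheap sanity check before the loop can catch a missed align_cell() call. The plain-Python sketch below tests only the generic matrix property that every entry below the main diagonal is (near) zero. It does not perform the alignment. Both helper names are hypothetical.

```python
# Sanity check that each 3x3 cell is upper-triangular (entries below the main
# diagonal are ~0). Hypothetical helpers; alignment itself is align_cell()'s job.
from typing import Sequence

Matrix3 = Sequence[Sequence[float]]


def is_upper_triangular(cell: Matrix3, tol: float = 1e-10) -> bool:
    # Entry (i, j) lies below the diagonal when i > j.
    return all(abs(cell[i][j]) <= tol for i in range(3) for j in range(i))


def check_cells_upper_triangular(cells: Sequence[Matrix3], tol: float = 1e-10) -> None:
    bad = [m for m, cell in enumerate(cells) if not is_upper_triangular(cell, tol)]
    if bad:
        raise ValueError(f"cells {bad} are not upper-triangular; call align_cell() first")
```

On the batched (M, 3, 3) tensor itself, the equivalent test is `torch.tril(cell, diagonal=-1).abs().amax() <= tol`.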
Notes
Pre-computing static metadata for tight loops
When batch_idx does not change between steps (fixed system sizes), atom_ptr, ext_atom_ptr, and ext_batch_idx are constant and can be pre-computed once to eliminate per-step allocation and kernel launches:
    import torch
    import warp as wp
    from nvalchemiops.batch_utils import (
        atom_ptr_to_batch_idx,
        batch_idx_to_atom_ptr,
    )
    from nvalchemiops.dynamics.utils.cell_filter import extend_atom_ptr

    N, M = positions.shape[0], alpha.shape[0]
    N_ext = N + 2 * M
    device = positions.device

    # 1) atom_ptr from batch_idx (CSR pointers into atom array)
    atom_ptr = torch.zeros(M + 1, dtype=torch.int32, device=device)
    atom_counts = torch.zeros(M, dtype=torch.int32, device=device)
    batch_idx_to_atom_ptr(
        wp.from_torch(batch_idx, dtype=wp.int32),
        wp.from_torch(atom_counts, dtype=wp.int32),
        wp.from_torch(atom_ptr, dtype=wp.int32),
    )

    # 2) ext_atom_ptr (CSR pointers into extended array,
    #    each system's range grows by 2 for the cell DOFs)
    ext_atom_ptr = torch.zeros(M + 1, dtype=torch.int32, device=device)
    extend_atom_ptr(
        wp.from_torch(atom_ptr, dtype=wp.int32),
        wp.from_torch(ext_atom_ptr, dtype=wp.int32),
    )

    # 3) ext_batch_idx (sorted system index for extended array)
    ext_batch_idx = torch.empty(N_ext, dtype=torch.int32, device=device)
    atom_ptr_to_batch_idx(
        wp.from_torch(ext_atom_ptr, dtype=wp.int32),
        wp.from_torch(ext_batch_idx, dtype=wp.int32),
    )
Then pass all three on every step:
    fire2_step_coord_cell(
        ...,
        atom_ptr=atom_ptr,
        ext_atom_ptr=ext_atom_ptr,
        ext_batch_idx=ext_batch_idx,
    )
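To make the three pre-computed arrays concrete, the plain-Python sketch below computes what they should contain for a small batch. Their semantics are assumed from the descriptions above: CSR atom pointers, extended pointers where each system's range grows by 2 for the cell DOFs, and a sorted extended batch index. The `*_ref` functions are hypothetical stand-ins for the Warp helpers, useful for checking a hand-built ext_batch_idx (which fire2_step_coord_cell() reuses without validation).

```python
# Plain-Python reference for atom_ptr, ext_atom_ptr and ext_batch_idx.
# The *_ref helpers are hypothetical; they mirror the documented semantics only.
from typing import List, Sequence


def batch_idx_to_atom_ptr_ref(batch_idx: Sequence[int], num_systems: int) -> List[int]:
    ptr = [0] * (num_systems + 1)
    for s in batch_idx:  # count atoms per system
        ptr[s + 1] += 1
    for s in range(num_systems):  # running sum turns counts into offsets
        ptr[s + 1] += ptr[s]
    return ptr


def extend_atom_ptr_ref(atom_ptr: Sequence[int], cell_rows: int = 2) -> List[int]:
    # System s is shifted by the cell entries of the s systems before it.
    return [p + cell_rows * s for s, p in enumerate(atom_ptr)]


def atom_ptr_to_batch_idx_ref(ptr: Sequence[int]) -> List[int]:
    out: List[int] = []
    for s in range(len(ptr) - 1):
        out.extend([s] * (ptr[s + 1] - ptr[s]))
    return out


if __name__ == "__main__":
    batch_idx = [0, 0, 0, 1, 1]  # N = 5 atoms in M = 2 systems
    atom_ptr = batch_idx_to_atom_ptr_ref(batch_idx, 2)
    ext_atom_ptr = extend_atom_ptr_ref(atom_ptr)
    ext_batch_idx = atom_ptr_to_batch_idx_ref(ext_atom_ptr)
    print(atom_ptr)       # [0, 3, 5]
    print(ext_atom_ptr)   # [0, 5, 9]
    print(ext_batch_idx)  # [0, 0, 0, 0, 0, 1, 1, 1, 1]
```

The extended index has N + 2M = 9 entries and is sorted, matching the interleaved layout.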
Extended array layout (interleaved)
The packing places each system’s cell DOFs immediately after its atoms:
[sys0_atom0, ..., sys0_atomK, sys0_cell_row0, sys0_cell_row1, sys1_atom0, ..., sys1_atomJ, sys1_cell_row0, sys1_cell_row1, ...]
This ensures that ext_batch_idx is sorted (all DOFs for system 0 precede all DOFs for system 1, etc.), which fire2_step()'s segmented reductions require.
Examples
Minimal single-step call (all buffers allocated internally):
>>> fire2_step_coord_cell(
...     positions, velocities, forces,
...     cell, cell_velocities, cell_force,
...     batch_idx, alpha, dt, nsteps_inc,
... )
Tight optimization loop with pre-allocated buffers:
>>> # Pre-compute static metadata once
>>> atom_ptr = ...  # see Notes
>>> ext_atom_ptr = ...
>>> ext_batch_idx = ...
>>> N_ext = positions.shape[0] + 2 * alpha.shape[0]
>>> ext_pos = torch.empty(N_ext, 3, dtype=positions.dtype,
...                       device=positions.device)
>>> ext_vel = torch.empty_like(ext_pos)
>>> ext_forces = torch.empty_like(ext_pos)
>>> M = alpha.shape[0]
>>> vf = torch.empty(M, dtype=positions.dtype, device=positions.device)
>>> v_sumsq = torch.empty_like(vf)
>>> f_sumsq = torch.empty_like(vf)
>>> max_norm = torch.empty_like(vf)
>>> for step in range(num_steps):
...     # hypothetical: recompute forces and cell_force for the updated geometry
...     forces, cell_force = compute_forces_and_cell_force(positions, cell)
...     fire2_step_coord_cell(
...         positions, velocities, forces,
...         cell, cell_velocities, cell_force,
...         batch_idx, alpha, dt, nsteps_inc,
...         atom_ptr=atom_ptr,
...         ext_atom_ptr=ext_atom_ptr,
...         ext_positions=ext_pos,
...         ext_velocities=ext_vel,
...         ext_forces=ext_forces,
...         ext_batch_idx=ext_batch_idx,
...         vf=vf, v_sumsq=v_sumsq,
...         f_sumsq=f_sumsq, max_norm=max_norm,
...     )
Extended Array Interface
Use this lower-level interface when you manage packed extended arrays directly.
- nvalchemiops.torch.fire2.fire2_step_extended(ext_positions, ext_velocities, ext_forces, ext_batch_idx, alpha, dt, nsteps_inc, *, vf=None, v_sumsq=None, f_sumsq=None, max_norm=None, delaystep=60, dtgrow=1.05, dtshrink=0.75, alphashrink=0.985, alpha0=0.09, tmax=0.08, tmin=0.005, maxstep=0.1)
Run FIRE2 directly on pre-packed extended arrays (no pack/unpack).
This is a lower-level API for callers that maintain persistent extended arrays (positions + cell DOFs interleaved). The caller is responsible for packing data into the extended layout before the first call and unpacking results after the last call (or as needed).
This eliminates the per-step pack/unpack overhead that fire2_step_coord_cell() incurs.
- Parameters:
ext_positions (torch.Tensor, shape (N_ext, 3)) – Extended position array (atoms + cell DOFs interleaved).
ext_velocities (torch.Tensor, shape (N_ext, 3)) – Extended velocity array.
ext_forces (torch.Tensor, shape (N_ext, 3)) – Extended force array.
ext_batch_idx (torch.Tensor, shape (N_ext,), dtype=int32) – System index for each element in the extended arrays.
alpha (torch.Tensor, shape (M,)) – FIRE2 mixing parameter per system.
dt (torch.Tensor, shape (M,)) – Timestep per system.
nsteps_inc (torch.Tensor, shape (M,), dtype=int32) – Consecutive positive-power step counter per system.
vf, v_sumsq, f_sumsq, max_norm (torch.Tensor or None) – Per-system scratch buffers, shape (M,). Allocated internally if None.
delaystep (int), dtgrow (float), dtshrink (float), alphashrink (float), alpha0 (float), tmax (float), tmin (float), maxstep (float) – FIRE2 hyperparameters. See fire2_step() for defaults and descriptions.
- Return type:
None
Notes
Modifies ext_positions, ext_velocities, alpha, dt, and nsteps_inc in-place.
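Callers of fire2_step_extended() must pack into, and unpack from, the interleaved layout themselves. The plain-Python sketch below does that for positions plus two cell entries per system, following the layout described under fire2_step_coord_cell(). How the cell matrix maps onto those two vec3s is not specified in this section, so `cell_rows` is treated as a hypothetical opaque pair per system. Both function names are illustrative only.

```python
# Pack/unpack for the interleaved extended layout:
# [sys0 atoms..., sys0 cell entry 0, sys0 cell entry 1, sys1 atoms..., ...].
# cell_rows[s] is a hypothetical opaque pair of vec3s for system s.
from typing import List, Sequence, Tuple

Vec3 = List[float]


def pack_interleaved(
    positions: Sequence[Vec3],
    cell_rows: Sequence[Tuple[Vec3, Vec3]],
    atom_ptr: Sequence[int],
) -> List[Vec3]:
    ext: List[Vec3] = []
    for s in range(len(atom_ptr) - 1):
        ext.extend(list(p) for p in positions[atom_ptr[s]:atom_ptr[s + 1]])
        ext.extend(list(r) for r in cell_rows[s])  # 2 cell entries follow the atoms
    return ext


def unpack_interleaved(
    ext: Sequence[Vec3], atom_ptr: Sequence[int]
) -> Tuple[List[Vec3], List[Tuple[Vec3, Vec3]]]:
    positions: List[Vec3] = []
    cell_rows: List[Tuple[Vec3, Vec3]] = []
    offset = 0
    for s in range(len(atom_ptr) - 1):
        n = atom_ptr[s + 1] - atom_ptr[s]
        positions.extend(ext[offset:offset + n])
        cell_rows.append((ext[offset + n], ext[offset + n + 1]))
        offset += n + 2  # skip this system's atoms and its 2 cell entries
    return positions, cell_rows
```

A typical pattern is to pack once before the loop, call fire2_step_extended() repeatedly on the packed arrays, and unpack only when positions or cell data are needed outside the optimizer.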