Source code for nvalchemiops.dynamics.utils.constraints

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
SHAKE and RATTLE Constraint Algorithms
======================================

GPU-accelerated Warp kernels for holonomic bond constraints in molecular dynamics.

This module provides both mutating (in-place) and non-mutating versions
of each kernel for gradient tracking compatibility.

MATHEMATICAL FORMULATION
========================

SHAKE Algorithm (Position Constraints)
--------------------------------------

For each constrained bond (i, j) with target distance d_ij:

1. Compute constraint violation: σ = |r_ij|² - d_ij²
2. Compute Lagrange multiplier: λ = σ / (2 * (1/m_i + 1/m_j) * r_ij · r_ij_old)
3. Update positions:
   - r_i += λ * r_ij_old / m_i
   - r_j -= λ * r_ij_old / m_j

Iterate until all constraints are satisfied within tolerance.

RATTLE Algorithm (Velocity Constraints)
---------------------------------------

After SHAKE, velocities must also satisfy constraints:

For each constrained bond (i, j):

1. Compute velocity constraint violation: κ = v_ij · r_ij
2. Compute Lagrange multiplier: μ = κ / ((1/m_i + 1/m_j) * |r_ij|²)
3. Update velocities:
   - v_i -= μ * r_ij / m_i
   - v_j += μ * r_ij / m_j

USAGE
=====

SHAKE is applied after position updates but before force calculation:

    # Position update (unconstrained)
    velocity_verlet_position_update(positions, velocities, forces, masses, dt)

    # Apply SHAKE to fix bond lengths
    shake_converged = shake_constraints(
        positions, positions_old, masses,
        bond_pairs, bond_lengths, tolerance=1e-6, max_iter=100
    )

    # Compute forces at constrained positions
    forces = compute_forces(positions)

    # Velocity update
    velocity_verlet_velocity_finalize(velocities, forces, masses, dt)

    # Apply RATTLE to fix velocity constraints
    rattle_constraints(positions, velocities, masses, bond_pairs, bond_lengths)

REFERENCES
==========

- Ryckaert et al. (1977). J. Comput. Phys. 23, 327 (SHAKE)
- Andersen (1983). J. Comput. Phys. 52, 24 (RATTLE)
- Allen & Tildesley (1987). Computer Simulation of Liquids
"""

from __future__ import annotations

import os
from typing import Any

import warp as wp

__all__ = [
    # SHAKE - Mutating
    "shake_constraints",
    "shake_iteration",
    # SHAKE - Non-mutating
    "shake_constraints_out",
    "shake_iteration_out",
    # RATTLE - Mutating
    "rattle_constraints",
    "rattle_iteration",
    # RATTLE - Non-mutating
    "rattle_constraints_out",
    "rattle_iteration_out",
]


# ==============================================================================
# SHAKE Kernels
# ==============================================================================

# Tile block size for cooperative reductions
TILE_DIM = int(os.getenv("NVALCHEMIOPS_DYNAMICS_TILE_DIM", 256))


@wp.kernel
def _shake_iteration_kernel(
    positions: wp.array(dtype=Any),
    positions_old: wp.array(dtype=Any),
    masses: wp.array(dtype=Any),
    bond_atom_i: wp.array(dtype=wp.int32),
    bond_atom_j: wp.array(dtype=wp.int32),
    bond_lengths_sq: wp.array(dtype=Any),
    max_error: wp.array(dtype=wp.float64),
):
    """Single SHAKE iteration for one bond (in-place).

    Updates positions to satisfy bond length constraint.

    Launch Grid
    -----------
    dim = [num_bonds]

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Current positions. MODIFIED in-place.
    positions_old : wp.array(dtype=wp.vec3f or wp.vec3d)
        Positions before unconstrained update.
    masses : wp.array
        Atomic masses.
    bond_atom_i : wp.array(dtype=wp.int32)
        First atom index for each bond.
    bond_atom_j : wp.array(dtype=wp.int32)
        Second atom index for each bond.
    bond_lengths_sq : wp.array
        Target bond length squared for each bond.
    max_error : wp.array(dtype=wp.float64)
        Maximum constraint error (atomic max).

    Notes
    -----
    This kernel uses atomic operations on positions, which may cause
    race conditions when atoms participate in multiple bonds.
    """
    bond_idx = wp.tid()

    i = bond_atom_i[bond_idx]
    j = bond_atom_j[bond_idx]

    r_i = positions[i]
    r_j = positions[j]
    r_i_old = positions_old[i]
    r_j_old = positions_old[j]
    m_i = masses[i]
    m_j = masses[j]
    d_sq = bond_lengths_sq[bond_idx]

    # Current bond vector
    r_ij = r_i - r_j

    # Old bond vector
    r_ij_old = r_i_old - r_j_old

    # Current distance squared
    r_sq = wp.dot(r_ij, r_ij)

    # Constraint violation
    sigma = r_sq - d_sq

    # Track maximum error
    wp.atomic_max(max_error, 0, wp.abs(wp.float64(sigma)))

    # Dot product r_ij · r_ij_old
    dot = wp.dot(r_ij, r_ij_old)

    # Inverse masses
    inv_m_i = type(m_i)(1.0) / m_i
    inv_m_j = type(m_j)(1.0) / m_j

    # Lagrange multiplier
    denom = type(dot)(2.0) * (inv_m_i + inv_m_j) * dot
    if wp.abs(denom) > type(denom)(1e-30):
        lam = sigma / denom

        # Position corrections
        corr_i = lam * r_ij_old * inv_m_i
        corr_j = lam * r_ij_old * inv_m_j

        # Apply corrections atomically
        wp.atomic_sub(positions, i, corr_i)
        wp.atomic_add(positions, j, corr_j)


@wp.kernel
def _shake_iteration_tiled_kernel(
    positions: wp.array(dtype=Any),
    positions_old: wp.array(dtype=Any),
    masses: wp.array(dtype=Any),
    bond_atom_i: wp.array(dtype=wp.int32),
    bond_atom_j: wp.array(dtype=wp.int32),
    bond_lengths_sq: wp.array(dtype=Any),
    max_error: wp.array(dtype=wp.float64),
):
    """Single SHAKE iteration with tile reductions for error tracking.

    Updates positions to satisfy bond length constraint using atomic operations
    for position corrections. Uses tile reductions for max error computation.

    Launch Grid: dim = [num_bonds], block_dim = TILE_DIM

    Notes
    -----
    - Atomic operations on positions still required due to data dependencies
    - Tile reductions reduce atomic contention for max_error tracking
    - Data dependencies (atoms in multiple bonds) limit parallelization
    """
    bond_idx = wp.tid()

    i = bond_atom_i[bond_idx]
    j = bond_atom_j[bond_idx]

    r_i = positions[i]
    r_j = positions[j]
    r_i_old = positions_old[i]
    r_j_old = positions_old[j]
    m_i = masses[i]
    m_j = masses[j]
    d_sq = bond_lengths_sq[bond_idx]

    # Current bond vector
    r_ij = r_i - r_j

    # Old bond vector
    r_ij_old = r_i_old - r_j_old

    # Current distance squared
    r_sq = wp.dot(r_ij, r_ij)

    # Constraint violation
    sigma = r_sq - d_sq

    # Compute local error
    local_error = wp.abs(wp.float64(sigma))

    # Tile reduction for max error
    t_error = wp.tile(local_error)
    max_tile_error = wp.tile_reduce(wp.max, t_error)
    block_max_error = max_tile_error[0]

    # Only first thread in block updates max error
    if bond_idx % TILE_DIM == 0:
        wp.atomic_max(max_error, 0, block_max_error)

    # Dot product r_ij · r_ij_old
    dot = wp.dot(r_ij, r_ij_old)

    # Inverse masses
    inv_m_i = type(m_i)(1.0) / m_i
    inv_m_j = type(m_j)(1.0) / m_j

    # Lagrange multiplier
    denom = type(dot)(2.0) * (inv_m_i + inv_m_j) * dot
    if wp.abs(denom) > type(denom)(1e-30):
        lam = sigma / denom

        # Position corrections
        corr_i = lam * r_ij_old * inv_m_i
        corr_j = lam * r_ij_old * inv_m_j

        # Apply corrections atomically (required due to data dependencies)
        wp.atomic_sub(positions, i, corr_i)
        wp.atomic_add(positions, j, corr_j)


@wp.kernel
def _shake_iteration_out_kernel(
    positions: wp.array(dtype=Any),
    positions_old: wp.array(dtype=Any),
    masses: wp.array(dtype=Any),
    bond_atom_i: wp.array(dtype=wp.int32),
    bond_atom_j: wp.array(dtype=wp.int32),
    bond_lengths_sq: wp.array(dtype=Any),
    position_corrections: wp.array(dtype=Any),
    max_error: wp.array(dtype=wp.float64),
):
    """Single SHAKE iteration computing corrections (non-mutating).

    Computes position corrections without applying them.

    Launch Grid
    -----------
    dim = [num_bonds]

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Current positions.
    positions_old : wp.array(dtype=wp.vec3f or wp.vec3d)
        Positions before unconstrained update.
    masses : wp.array
        Atomic masses.
    bond_atom_i : wp.array(dtype=wp.int32)
        First atom index for each bond.
    bond_atom_j : wp.array(dtype=wp.int32)
        Second atom index for each bond.
    bond_lengths_sq : wp.array
        Target bond length squared.
    position_corrections : wp.array(dtype=wp.vec3f or wp.vec3d)
        Output corrections for each atom. Shape (N,).
    max_error : wp.array(dtype=wp.float64)
        Maximum constraint error.
    """
    bond_idx = wp.tid()

    i = bond_atom_i[bond_idx]
    j = bond_atom_j[bond_idx]

    r_i = positions[i]
    r_j = positions[j]
    r_i_old = positions_old[i]
    r_j_old = positions_old[j]
    m_i = masses[i]
    m_j = masses[j]
    d_sq = bond_lengths_sq[bond_idx]

    r_ij = r_i - r_j
    r_ij_old = r_i_old - r_j_old

    r_sq = wp.dot(r_ij, r_ij)
    sigma = r_sq - d_sq

    wp.atomic_max(max_error, 0, wp.abs(wp.float64(sigma)))

    dot = wp.dot(r_ij, r_ij_old)

    inv_m_i = type(m_i)(1.0) / m_i
    inv_m_j = type(m_j)(1.0) / m_j

    denom = type(dot)(2.0) * (inv_m_i + inv_m_j) * dot
    if wp.abs(denom) > type(denom)(1e-30):
        lam = sigma / denom

        corr_i = -lam * r_ij_old * inv_m_i
        corr_j = lam * r_ij_old * inv_m_j

        wp.atomic_add(position_corrections, i, corr_i)
        wp.atomic_add(position_corrections, j, corr_j)


@wp.kernel
def _shake_iteration_out_tiled_kernel(
    positions: wp.array(dtype=Any),
    positions_old: wp.array(dtype=Any),
    masses: wp.array(dtype=Any),
    bond_atom_i: wp.array(dtype=wp.int32),
    bond_atom_j: wp.array(dtype=wp.int32),
    bond_lengths_sq: wp.array(dtype=Any),
    position_corrections: wp.array(dtype=Any),
    max_error: wp.array(dtype=wp.float64),
):
    """Single SHAKE iteration with tile reductions (non-mutating).

    Launch Grid: dim = [num_bonds], block_dim = TILE_DIM
    """
    bond_idx = wp.tid()

    i = bond_atom_i[bond_idx]
    j = bond_atom_j[bond_idx]

    r_i = positions[i]
    r_j = positions[j]
    r_i_old = positions_old[i]
    r_j_old = positions_old[j]
    m_i = masses[i]
    m_j = masses[j]
    d_sq = bond_lengths_sq[bond_idx]

    r_ij = r_i - r_j
    r_ij_old = r_i_old - r_j_old

    r_sq = wp.dot(r_ij, r_ij)
    sigma = r_sq - d_sq

    # Compute local error
    local_error = wp.abs(wp.float64(sigma))

    # Tile reduction for max error
    t_error = wp.tile(local_error)
    max_tile_error = wp.tile_reduce(wp.max, t_error)
    block_max_error = max_tile_error[0]

    # Only first thread in block updates max error
    if bond_idx % TILE_DIM == 0:
        wp.atomic_max(max_error, 0, block_max_error)

    dot = wp.dot(r_ij, r_ij_old)

    inv_m_i = type(m_i)(1.0) / m_i
    inv_m_j = type(m_j)(1.0) / m_j

    denom = type(dot)(2.0) * (inv_m_i + inv_m_j) * dot
    if wp.abs(denom) > type(denom)(1e-30):
        lam = sigma / denom

        corr_i = -lam * r_ij_old * inv_m_i
        corr_j = lam * r_ij_old * inv_m_j

        # Atomic operations still required due to data dependencies
        wp.atomic_add(position_corrections, i, corr_i)
        wp.atomic_add(position_corrections, j, corr_j)


# ==============================================================================
# RATTLE Kernels
# ==============================================================================


@wp.kernel
def _rattle_iteration_kernel(
    positions: wp.array(dtype=Any),
    velocities: wp.array(dtype=Any),
    masses: wp.array(dtype=Any),
    bond_atom_i: wp.array(dtype=wp.int32),
    bond_atom_j: wp.array(dtype=wp.int32),
    max_error: wp.array(dtype=wp.float64),
):
    """Single RATTLE iteration for velocity constraints (in-place).

    Ensures velocities are perpendicular to bond vectors.

    Launch Grid
    -----------
    dim = [num_bonds]

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Current (constrained) positions.
    velocities : wp.array(dtype=wp.vec3f or wp.vec3d)
        Velocities. MODIFIED in-place.
    masses : wp.array
        Atomic masses.
    bond_atom_i : wp.array(dtype=wp.int32)
        First atom index for each bond.
    bond_atom_j : wp.array(dtype=wp.int32)
        Second atom index for each bond.
    max_error : wp.array(dtype=wp.float64)
        Maximum constraint error.
    """
    bond_idx = wp.tid()

    i = bond_atom_i[bond_idx]
    j = bond_atom_j[bond_idx]

    r_i = positions[i]
    r_j = positions[j]
    v_i = velocities[i]
    v_j = velocities[j]
    m_i = masses[i]
    m_j = masses[j]

    # Bond vector
    r_ij = r_i - r_j

    # Relative velocity
    v_ij = v_i - v_j

    # Velocity constraint: v_ij · r_ij = 0
    kappa = wp.dot(v_ij, r_ij)

    # Track maximum error
    wp.atomic_max(max_error, 0, wp.abs(wp.float64(kappa)))

    # Bond length squared
    r_sq = wp.dot(r_ij, r_ij)

    # Inverse masses
    inv_m_i = type(m_i)(1.0) / m_i
    inv_m_j = type(m_j)(1.0) / m_j

    # Lagrange multiplier
    denom = (inv_m_i + inv_m_j) * r_sq
    if wp.abs(denom) > type(denom)(1e-30):
        mu = kappa / denom

        # Velocity corrections
        corr_i = -mu * r_ij * inv_m_i
        corr_j = mu * r_ij * inv_m_j

        # Apply corrections atomically
        wp.atomic_sub(velocities, i, corr_i)
        wp.atomic_add(velocities, j, corr_j)


@wp.kernel
def _rattle_iteration_out_kernel(
    positions: wp.array(dtype=Any),
    velocities: wp.array(dtype=Any),
    masses: wp.array(dtype=Any),
    bond_atom_i: wp.array(dtype=wp.int32),
    bond_atom_j: wp.array(dtype=wp.int32),
    velocity_corrections: wp.array(dtype=Any),
    max_error: wp.array(dtype=wp.float64),
):
    """Single RATTLE iteration computing corrections (non-mutating).

    Launch Grid
    -----------
    dim = [num_bonds]
    """
    bond_idx = wp.tid()

    i = bond_atom_i[bond_idx]
    j = bond_atom_j[bond_idx]

    r_i = positions[i]
    r_j = positions[j]
    v_i = velocities[i]
    v_j = velocities[j]
    m_i = masses[i]
    m_j = masses[j]

    r_ij = r_i - r_j
    v_ij = v_i - v_j

    kappa = wp.dot(v_ij, r_ij)

    wp.atomic_max(max_error, 0, wp.abs(wp.float64(kappa)))

    r_sq = wp.dot(r_ij, r_ij)

    inv_m_i = type(m_i)(1.0) / m_i
    inv_m_j = type(m_j)(1.0) / m_j

    denom = (inv_m_i + inv_m_j) * r_sq
    if wp.abs(denom) > type(denom)(1e-30):
        mu = kappa / denom

        corr_i = -mu * r_ij * inv_m_i
        corr_j = mu * r_ij * inv_m_j

        wp.atomic_add(velocity_corrections, i, corr_i)
        wp.atomic_add(velocity_corrections, j, corr_j)


@wp.kernel
def _rattle_iteration_tiled_kernel(
    positions: wp.array(dtype=Any),
    velocities: wp.array(dtype=Any),
    masses: wp.array(dtype=Any),
    bond_atom_i: wp.array(dtype=wp.int32),
    bond_atom_j: wp.array(dtype=wp.int32),
    max_error: wp.array(dtype=wp.float64),
):
    """Single RATTLE iteration with tile reductions for error tracking.

    Updates velocities to satisfy velocity constraints using atomic operations
    for velocity corrections. Uses tile reductions for max error computation.

    Launch Grid: dim = [num_bonds], block_dim = TILE_DIM

    Notes
    -----
    - Atomic operations on velocities still required due to data dependencies
    - Tile reductions reduce atomic contention for max_error tracking
    - Data dependencies (atoms in multiple bonds) limit parallelization
    """
    bond_idx = wp.tid()

    i = bond_atom_i[bond_idx]
    j = bond_atom_j[bond_idx]

    r_i = positions[i]
    r_j = positions[j]
    v_i = velocities[i]
    v_j = velocities[j]
    m_i = masses[i]
    m_j = masses[j]

    # Bond vector
    r_ij = r_i - r_j

    # Relative velocity
    v_ij = v_i - v_j

    # Velocity constraint: v_ij · r_ij = 0
    kappa = wp.dot(v_ij, r_ij)

    # Compute local error
    local_error = wp.abs(wp.float64(kappa))

    # Tile reduction for max error
    t_error = wp.tile(local_error)
    max_tile_error = wp.tile_reduce(wp.max, t_error)
    block_max_error = max_tile_error[0]

    # Only first thread in block updates max error
    if bond_idx % TILE_DIM == 0:
        wp.atomic_max(max_error, 0, block_max_error)

    # Bond length squared
    r_sq = wp.dot(r_ij, r_ij)

    # Inverse masses
    inv_m_i = type(m_i)(1.0) / m_i
    inv_m_j = type(m_j)(1.0) / m_j

    # Lagrange multiplier
    denom = (inv_m_i + inv_m_j) * r_sq
    if wp.abs(denom) > type(denom)(1e-30):
        mu = kappa / denom

        # Velocity corrections
        corr_i = -mu * r_ij * inv_m_i
        corr_j = mu * r_ij * inv_m_j

        # Apply corrections atomically (required due to data dependencies)
        wp.atomic_sub(velocities, i, corr_i)
        wp.atomic_add(velocities, j, corr_j)


@wp.kernel
def _rattle_iteration_out_tiled_kernel(
    positions: wp.array(dtype=Any),
    velocities: wp.array(dtype=Any),
    masses: wp.array(dtype=Any),
    bond_atom_i: wp.array(dtype=wp.int32),
    bond_atom_j: wp.array(dtype=wp.int32),
    velocity_corrections: wp.array(dtype=Any),
    max_error: wp.array(dtype=wp.float64),
):
    """Single RATTLE iteration with tile reductions (non-mutating).

    Launch Grid: dim = [num_bonds], block_dim = TILE_DIM
    """
    bond_idx = wp.tid()

    i = bond_atom_i[bond_idx]
    j = bond_atom_j[bond_idx]

    r_i = positions[i]
    r_j = positions[j]
    v_i = velocities[i]
    v_j = velocities[j]
    m_i = masses[i]
    m_j = masses[j]

    r_ij = r_i - r_j
    v_ij = v_i - v_j

    kappa = wp.dot(v_ij, r_ij)

    # Compute local error
    local_error = wp.abs(wp.float64(kappa))

    # Tile reduction for max error
    t_error = wp.tile(local_error)
    max_tile_error = wp.tile_reduce(wp.max, t_error)
    block_max_error = max_tile_error[0]

    # Only first thread in block updates max error
    if bond_idx % TILE_DIM == 0:
        wp.atomic_max(max_error, 0, block_max_error)

    r_sq = wp.dot(r_ij, r_ij)

    inv_m_i = type(m_i)(1.0) / m_i
    inv_m_j = type(m_j)(1.0) / m_j

    denom = (inv_m_i + inv_m_j) * r_sq
    if wp.abs(denom) > type(denom)(1e-30):
        mu = kappa / denom

        corr_i = -mu * r_ij * inv_m_i
        corr_j = mu * r_ij * inv_m_j

        # Atomic operations still required due to data dependencies
        wp.atomic_add(velocity_corrections, i, corr_i)
        wp.atomic_add(velocity_corrections, j, corr_j)


# ==============================================================================
# Kernel Overloads for Explicit Typing
# ==============================================================================

_T = [wp.float32, wp.float64]  # Scalar types
_V = [wp.vec3f, wp.vec3d]  # Vector types

_shake_iteration_kernel_overload = {}
_shake_iteration_out_kernel_overload = {}
_shake_iteration_tiled_kernel_overload = {}
_shake_iteration_out_tiled_kernel_overload = {}
_rattle_iteration_kernel_overload = {}
_rattle_iteration_out_kernel_overload = {}
_rattle_iteration_tiled_kernel_overload = {}
_rattle_iteration_out_tiled_kernel_overload = {}

for t, v in zip(_T, _V):
    _shake_iteration_kernel_overload[v] = wp.overload(
        _shake_iteration_kernel,
        [
            wp.array(dtype=v),
            wp.array(dtype=v),
            wp.array(dtype=t),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array(dtype=t),
            wp.array(dtype=wp.float64),
        ],
    )
    _shake_iteration_out_kernel_overload[v] = wp.overload(
        _shake_iteration_out_kernel,
        [
            wp.array(dtype=v),
            wp.array(dtype=v),
            wp.array(dtype=t),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array(dtype=t),
            wp.array(dtype=v),
            wp.array(dtype=wp.float64),
        ],
    )
    _rattle_iteration_kernel_overload[v] = wp.overload(
        _rattle_iteration_kernel,
        [
            wp.array(dtype=v),
            wp.array(dtype=v),
            wp.array(dtype=t),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.float64),
        ],
    )
    _rattle_iteration_out_kernel_overload[v] = wp.overload(
        _rattle_iteration_out_kernel,
        [
            wp.array(dtype=v),
            wp.array(dtype=v),
            wp.array(dtype=t),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array(dtype=v),
            wp.array(dtype=wp.float64),
        ],
    )
    _shake_iteration_tiled_kernel_overload[v] = wp.overload(
        _shake_iteration_tiled_kernel,
        [
            wp.array(dtype=v),
            wp.array(dtype=v),
            wp.array(dtype=t),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array(dtype=t),
            wp.array(dtype=wp.float64),
        ],
    )
    _shake_iteration_out_tiled_kernel_overload[v] = wp.overload(
        _shake_iteration_out_tiled_kernel,
        [
            wp.array(dtype=v),
            wp.array(dtype=v),
            wp.array(dtype=t),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array(dtype=t),
            wp.array(dtype=v),
            wp.array(dtype=wp.float64),
        ],
    )
    _rattle_iteration_tiled_kernel_overload[v] = wp.overload(
        _rattle_iteration_tiled_kernel,
        [
            wp.array(dtype=v),
            wp.array(dtype=v),
            wp.array(dtype=t),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.float64),
        ],
    )
    _rattle_iteration_out_tiled_kernel_overload[v] = wp.overload(
        _rattle_iteration_out_tiled_kernel,
        [
            wp.array(dtype=v),
            wp.array(dtype=v),
            wp.array(dtype=t),
            wp.array(dtype=wp.int32),
            wp.array(dtype=wp.int32),
            wp.array(dtype=v),
            wp.array(dtype=wp.float64),
        ],
    )


# ==============================================================================
# SHAKE Functional Interface
# ==============================================================================


[docs] def shake_iteration( positions: wp.array, positions_old: wp.array, masses: wp.array, bond_atom_i: wp.array, bond_atom_j: wp.array, bond_lengths_sq: wp.array, max_error: wp.array, device: str = None, ) -> wp.array: """ Perform single SHAKE iteration (in-place). Parameters ---------- positions : wp.array(dtype=wp.vec3f or wp.vec3d) Current positions. Shape (N,). MODIFIED in-place. positions_old : wp.array(dtype=wp.vec3f or wp.vec3d) Positions before unconstrained update. Shape (N,). masses : wp.array Atomic masses. Shape (N,). bond_atom_i : wp.array(dtype=wp.int32) First atom index for each bond. Shape (M,). bond_atom_j : wp.array(dtype=wp.int32) Second atom index for each bond. Shape (M,). bond_lengths_sq : wp.array Target bond length squared. Shape (M,). max_error : wp.array(dtype=wp.float64) Output array to store max error. Shape (1,). Zeroed internally before each use. device : str, optional Warp device. Returns ------- wp.array(dtype=wp.float64) Maximum constraint error :math:`|r^2_{ij} - d^2_{ij}|`. Shape (1,). """ if device is None: device = positions.device max_error.zero_() num_bonds = bond_atom_i.shape[0] vec_dtype = positions.dtype wp.launch( _shake_iteration_tiled_kernel_overload[vec_dtype], dim=num_bonds, inputs=[ positions, positions_old, masses, bond_atom_i, bond_atom_j, bond_lengths_sq, max_error, ], device=device, block_dim=TILE_DIM, ) return max_error
[docs] def shake_constraints( positions: wp.array, positions_old: wp.array, masses: wp.array, bond_atom_i: wp.array, bond_atom_j: wp.array, bond_lengths_sq: wp.array, max_error: wp.array, num_iter: int = 10, device: str = None, ) -> wp.array: """ Apply SHAKE constraints for a fixed number of iterations (in-place). This function runs a fixed number of SHAKE iterations without convergence checking during the loop. The final error is returned as a wp.array. The caller can check convergence by inspecting the error value. Parameters ---------- positions : wp.array(dtype=wp.vec3f or wp.vec3d) Current positions. Shape (N,). MODIFIED in-place. positions_old : wp.array(dtype=wp.vec3f or wp.vec3d) Positions before unconstrained update. Shape (N,). masses : wp.array Atomic masses. Shape (N,). bond_atom_i : wp.array(dtype=wp.int32) First atom index for each bond. Shape (M,). bond_atom_j : wp.array(dtype=wp.int32) Second atom index for each bond. Shape (M,). bond_lengths_sq : wp.array Target bond length squared. Shape (M,). max_error : wp.array(dtype=wp.float64) Scratch array for max error tracking. Shape (1,). Caller must pre-allocate. Zeroed internally between iterations. num_iter : int, optional Number of iterations to run. Default 10. Typical values: 3-20 depending on constraint stiffness. device : str, optional Warp device. Returns ------- wp.array(dtype=wp.float64) Final constraint error :math:`|r^2_{ij} - d^2_{ij}|`. Shape (1,). Example ------- >>> # After unconstrained position update >>> max_error = wp.empty(1, dtype=wp.float64, device=device) >>> final_error = shake_constraints( ... positions, positions_old, masses, ... bond_i, bond_j, bond_lengths_sq, max_error, ... num_iter=10 ... ) """ if device is None: device = positions.device for _ in range(num_iter): max_error = shake_iteration( positions, positions_old, masses, bond_atom_i, bond_atom_j, bond_lengths_sq, max_error, device, ) return max_error
[docs] def shake_iteration_out( positions: wp.array, positions_old: wp.array, masses: wp.array, bond_atom_i: wp.array, bond_atom_j: wp.array, bond_lengths_sq: wp.array, position_corrections: wp.array, max_error: wp.array, device: str = None, ) -> tuple[wp.array, wp.array]: """ Compute SHAKE corrections without applying (non-mutating). Parameters ---------- positions : wp.array(dtype=wp.vec3f or wp.vec3d) Current positions. Shape (N,). positions_old : wp.array(dtype=wp.vec3f or wp.vec3d) Positions before unconstrained update. Shape (N,). masses : wp.array Atomic masses. Shape (N,). bond_atom_i : wp.array(dtype=wp.int32) First atom index for each bond. Shape (M,). bond_atom_j : wp.array(dtype=wp.int32) Second atom index for each bond. Shape (M,). bond_lengths_sq : wp.array Target bond length squared. Shape (M,). position_corrections : wp.array Output corrections. Shape (N,). Zeroed internally before each use. max_error : wp.array(dtype=wp.float64) Output max error. Shape (1,). Zeroed internally before each use. device : str, optional Warp device. Returns ------- tuple[wp.array, wp.array] (position_corrections, max_error) max_error is shape (1,) """ if device is None: device = positions.device position_corrections.zero_() max_error.zero_() num_bonds = bond_atom_i.shape[0] vec_dtype = positions.dtype wp.launch( _shake_iteration_out_tiled_kernel_overload[vec_dtype], dim=num_bonds, inputs=[ positions, positions_old, masses, bond_atom_i, bond_atom_j, bond_lengths_sq, position_corrections, max_error, ], device=device, block_dim=TILE_DIM, ) return position_corrections, max_error
[docs] def shake_constraints_out( positions: wp.array, positions_old: wp.array, masses: wp.array, bond_atom_i: wp.array, bond_atom_j: wp.array, bond_lengths_sq: wp.array, positions_out: wp.array, max_error: wp.array, num_iter: int = 10, device: str = None, ) -> tuple[wp.array, wp.array]: """ Apply SHAKE constraints (non-mutating). Parameters ---------- positions : wp.array(dtype=wp.vec3f or wp.vec3d) Current positions. Shape (N,). positions_old : wp.array(dtype=wp.vec3f or wp.vec3d) Positions before unconstrained update. Shape (N,). masses : wp.array Atomic masses. Shape (N,). bond_atom_i : wp.array(dtype=wp.int32) First atom index for each bond. Shape (M,). bond_atom_j : wp.array(dtype=wp.int32) Second atom index for each bond. Shape (M,). bond_lengths_sq : wp.array Target bond length squared. Shape (M,). positions_out : wp.array Output positions. Shape (N,). Caller must pre-allocate. Caller must copy input positions into this array before calling (e.g., via ``wp.copy(positions_out, positions)``). max_error : wp.array(dtype=wp.float64) Scratch array for max error tracking. Shape (1,). Caller must pre-allocate. Zeroed internally between iterations. num_iter : int, optional Number of iterations to run. Default 10. device : str, optional Warp device. Returns ------- tuple[wp.array, wp.array] (positions_out, final_error) final_error is shape (1,) """ if device is None: device = positions.device error = shake_constraints( positions_out, positions_old, masses, bond_atom_i, bond_atom_j, bond_lengths_sq, max_error, num_iter, device, ) return positions_out, error
# ============================================================================== # RATTLE Functional Interface # ==============================================================================
[docs] def rattle_iteration( positions: wp.array, velocities: wp.array, masses: wp.array, bond_atom_i: wp.array, bond_atom_j: wp.array, max_error: wp.array, device: str = None, ) -> wp.array: r""" Perform single RATTLE iteration (in-place). Parameters ---------- positions : wp.array(dtype=wp.vec3f or wp.vec3d) Current (constrained) positions. Shape (N,). velocities : wp.array(dtype=wp.vec3f or wp.vec3d) Velocities. Shape (N,). MODIFIED in-place. masses : wp.array Atomic masses. Shape (N,). bond_atom_i : wp.array(dtype=wp.int32) First atom index for each bond. Shape (M,). bond_atom_j : wp.array(dtype=wp.int32) Second atom index for each bond. Shape (M,). max_error : wp.array(dtype=wp.float64) Output array to store max error. Shape (1,). Zeroed internally before each use. device : str, optional Warp device. Returns ------- wp.array(dtype=wp.float64) Maximum velocity constraint error :math:`|v_{ij} \cdot r_{ij}|`. Shape (1,). """ if device is None: device = positions.device max_error.zero_() num_bonds = bond_atom_i.shape[0] vec_dtype = positions.dtype wp.launch( _rattle_iteration_tiled_kernel_overload[vec_dtype], dim=num_bonds, inputs=[positions, velocities, masses, bond_atom_i, bond_atom_j, max_error], device=device, block_dim=TILE_DIM, ) return max_error
[docs] def rattle_constraints( positions: wp.array, velocities: wp.array, masses: wp.array, bond_atom_i: wp.array, bond_atom_j: wp.array, max_error: wp.array, num_iter: int = 10, device: str = None, ) -> wp.array: r""" Apply RATTLE velocity constraints for a fixed number of iterations (in-place). This function runs a fixed number of RATTLE iterations without convergence checking during the loop. The final error is returned as a wp.array. The caller can check convergence by inspecting the error value. Parameters ---------- positions : wp.array(dtype=wp.vec3f or wp.vec3d) Current (constrained) positions. Shape (N,). velocities : wp.array(dtype=wp.vec3f or wp.vec3d) Velocities. Shape (N,). MODIFIED in-place. masses : wp.array Atomic masses. Shape (N,). bond_atom_i : wp.array(dtype=wp.int32) First atom index for each bond. Shape (M,). bond_atom_j : wp.array(dtype=wp.int32) Second atom index for each bond. Shape (M,). max_error : wp.array(dtype=wp.float64) Scratch array for max error tracking. Shape (1,). Caller must pre-allocate. Zeroed internally between iterations. num_iter : int, optional Number of iterations to run. Default 10. device : str, optional Warp device. Returns ------- wp.array(dtype=wp.float64) Final constraint error :math:`|v_{ij} \cdot r_{ij}|`. Shape (1,). """ if device is None: device = positions.device for _ in range(num_iter): max_error = rattle_iteration( positions, velocities, masses, bond_atom_i, bond_atom_j, max_error, device ) return max_error
[docs] def rattle_iteration_out( positions: wp.array, velocities: wp.array, masses: wp.array, bond_atom_i: wp.array, bond_atom_j: wp.array, velocity_corrections: wp.array, max_error: wp.array, device: str = None, ) -> tuple[wp.array, wp.array]: """ Compute RATTLE corrections without applying (non-mutating). Parameters ---------- positions : wp.array(dtype=wp.vec3f or wp.vec3d) Current (constrained) positions. Shape (N,). velocities : wp.array(dtype=wp.vec3f or wp.vec3d) Velocities. Shape (N,). masses : wp.array Atomic masses. Shape (N,). bond_atom_i : wp.array(dtype=wp.int32) First atom index for each bond. Shape (M,). bond_atom_j : wp.array(dtype=wp.int32) Second atom index for each bond. Shape (M,). velocity_corrections : wp.array Output corrections. Shape (N,). Zeroed internally before each use. max_error : wp.array(dtype=wp.float64) Output max error. Shape (1,). Zeroed internally before each use. device : str, optional Warp device. Returns ------- tuple[wp.array, wp.array] (velocity_corrections, max_error) max_error is shape (1,) """ if device is None: device = positions.device velocity_corrections.zero_() max_error.zero_() num_bonds = bond_atom_i.shape[0] vec_dtype = positions.dtype wp.launch( _rattle_iteration_out_tiled_kernel_overload[vec_dtype], dim=num_bonds, inputs=[ positions, velocities, masses, bond_atom_i, bond_atom_j, velocity_corrections, max_error, ], device=device, block_dim=TILE_DIM, ) return velocity_corrections, max_error
[docs] def rattle_constraints_out( positions: wp.array, velocities: wp.array, masses: wp.array, bond_atom_i: wp.array, bond_atom_j: wp.array, velocities_out: wp.array, max_error: wp.array, num_iter: int = 10, device: str = None, ) -> tuple[wp.array, wp.array]: """ Apply RATTLE velocity constraints (non-mutating). Parameters ---------- positions : wp.array(dtype=wp.vec3f or wp.vec3d) Current (constrained) positions. Shape (N,). velocities : wp.array(dtype=wp.vec3f or wp.vec3d) Velocities. Shape (N,). masses : wp.array Atomic masses. Shape (N,). bond_atom_i : wp.array(dtype=wp.int32) First atom index for each bond. Shape (M,). bond_atom_j : wp.array(dtype=wp.int32) Second atom index for each bond. Shape (M,). velocities_out : wp.array Output velocities. Shape (N,). Caller must pre-allocate. Caller must copy input velocities into this array before calling (e.g., via ``wp.copy(velocities_out, velocities)``). max_error : wp.array(dtype=wp.float64) Scratch array for max error tracking. Shape (1,). Caller must pre-allocate. Zeroed internally between iterations. num_iter : int, optional Number of iterations to run. Default 10. device : str, optional Warp device. Returns ------- tuple[wp.array, wp.array] (velocities_out, final_error) final_error is shape (1,) """ if device is None: device = positions.device error = rattle_constraints( positions, velocities_out, masses, bond_atom_i, bond_atom_j, max_error, num_iter, device, ) return velocities_out, error