# Source code for nvalchemiops.dynamics.integrators.nose_hoover

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Nosé-Hoover Chain (NHC) Thermostat for NVT Ensemble.

This module implements the Nosé-Hoover chain thermostat following the
Martyna-Tobias-Klein (MTK) equations of motion with time-reversible
integration based on Tuckerman et al.

References
----------
- Martyna, Tobias, Klein. J Chem Phys, 101, 4177 (1994)
- Tuckerman et al. J Phys A: Math Gen, 39, 5629 (2006)

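The Nosé-Hoover chain equations of motion for a chain of M thermostats:
    ṙᵢ = vᵢ
    v̇ᵢ = Fᵢ/mᵢ - η̇₁·vᵢ
    η̈₁ = (2·KE - Ndof·kT) / Q₁ - η̇₁·η̇₂
    η̈ₖ = (Qₖ₋₁·η̇²ₖ₋₁ - kT) / Qₖ - η̇ₖ·η̇ₖ₊₁   for 1 < k ≤ M
The coupling term -η̇ₖ·η̇ₖ₊₁ is absent for the last thermostat (k = M); for
M = 1 the first equation reduces to η̈₁ = (2·KE - Ndof·kT) / Q₁. (The kernels
index the chain from 0, so η₁ above is η_0 in code.)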

Where:
    η   : thermostat chain positions (unitless)
    η̇   : thermostat chain velocities (1/time)
    Q   : thermostat chain masses (energy·time²)
    Ndof: degrees of freedom (typically 3N - 3)
    kT  : target temperature in energy units (k_B = 1)

BATCH MODE
==========

All functions in this module support three execution modes:

**Single System Mode**::

    # Simple position and velocity updates
    nhc_velocity_half_step(velocities, forces, masses, dt)
    nhc_position_update(positions, velocities, dt)

**Batch Mode with batch_idx** (atomic operations)::

    batch_idx = wp.array([0]*N0 + [1]*N1 + [2]*N2, dtype=wp.int32, device="cuda:0")
    dt = wp.array([dt0, dt1, dt2], dtype=wp.float64, device="cuda:0")

    nhc_velocity_half_step(velocities, forces, masses, dt, batch_idx=batch_idx)
    nhc_position_update(positions, velocities, dt, batch_idx=batch_idx)

**Batch Mode with atom_ptr** (sequential per-system)::

    atom_ptr = wp.array([0, N0, N0+N1, N0+N1+N2], dtype=wp.int32, device="cuda:0")

    nhc_velocity_half_step(velocities, forces, masses, dt, atom_ptr=atom_ptr)
    nhc_position_update(positions, velocities, dt, atom_ptr=atom_ptr)
"""

from __future__ import annotations

import os
from typing import Any

import warp as wp

from nvalchemiops.dynamics.utils.launch_helpers import dispatch_family
from nvalchemiops.dynamics.utils.shared_kernels import (
    position_update_families,
    velocity_kick_families,
)
from nvalchemiops.warp_dispatch import validate_out_array

# Public names exported by ``from ... import *``.
__all__ = [
    # Functions that modify their input arrays in place
    "nhc_thermostat_chain_update",
    "nhc_velocity_half_step",
    "nhc_position_update",
    "nhc_compute_chain_energy",
    # Non-mutating variants that write results to output arrays
    "nhc_thermostat_chain_update_out",
    "nhc_velocity_half_step_out",
    "nhc_position_update_out",
    # Helpers
    "nhc_compute_masses",
]


# ==============================================================================
# Constants
# ==============================================================================

# Maximum supported chain length (typically 3-5 in practice)
MAX_CHAIN_LENGTH = 8

# Yoshida-Suzuki Integration Weights
# These weights provide 4th-order accurate, time-reversible integration
# for the thermostat chain propagation.

# 3-step Yoshida-Suzuki weights
_YS3_W0 = 1.0 / (2.0 - 2.0 ** (1.0 / 3.0))
_YS3_W1 = 1.0 - 2.0 * _YS3_W0
YOSHIDA_SUZUKI_3 = [_YS3_W0, _YS3_W1, _YS3_W0]

# 5-step Yoshida-Suzuki weights (higher accuracy)
_YS5_W0 = 1.0 / (4.0 - 4.0 ** (1.0 / 3.0))
_YS5_W1 = _YS5_W0
_YS5_W2 = 1.0 - 4.0 * _YS5_W0
YOSHIDA_SUZUKI_5 = [_YS5_W0, _YS5_W1, _YS5_W2, _YS5_W1, _YS5_W0]


# ==============================================================================
# Diagnostic Kernels
# ==============================================================================

# Tile block size for cooperative reductions
TILE_DIM = int(os.getenv("NVALCHEMIOPS_DYNAMICS_TILE_DIM", 256))


@wp.kernel
def _compute_2ke_kernel(
    velocities: wp.array(dtype=Any),
    masses: wp.array(dtype=Any),
    ke2: wp.array(dtype=Any),
):
    """Accumulate twice the kinetic energy, sum(m * |v|^2), into ``ke2[0]``.

    Each thread atomically adds its atom's contribution. The contribution is
    cast to ke2's dtype first. ``ke2`` is added to, not overwritten.

    Launch Grid
    -----------
    dim = [num_atoms]
    """
    i = wp.tid()
    vel = velocities[i]
    weighted = masses[i] * wp.dot(vel, vel)
    wp.atomic_add(ke2, 0, type(ke2[0])(weighted))


@wp.kernel
def _compute_2ke_tiled_kernel(
    velocities: wp.array(dtype=Any),
    masses: wp.array(dtype=Any),
    ke2: wp.array(dtype=Any),
):
    """Compute 2*KE with tile reductions (single system).

    Each thread computes its atom's ``m * dot(v, v)`` and casts it to ke2's
    dtype. The thread block sums these values cooperatively with
    ``wp.tile_sum``, and one thread per block atomically adds the block total
    into ``ke2[0]``. ``ke2`` is accumulated into rather than overwritten, so
    it must be zeroed before launch.

    NOTE(review): the writer test ``atom_idx % TILE_DIM == 0`` assumes the
    launch uses ``block_dim == TILE_DIM``. With a larger block, several
    threads per block would add the block total. With a smaller block, some
    blocks would have no writer. Confirm at the launch site. The kernel also
    presumably relies on the launcher handling a final partial block when
    num_atoms is not a multiple of TILE_DIM — verify.

    Launch Grid: dim = [num_atoms], block_dim = TILE_DIM
    """
    atom_idx = wp.tid()

    v = velocities[atom_idx]
    m = masses[atom_idx]

    v_sq = wp.dot(v, v)
    # Cast to ke2's dtype, which may differ from the mass/velocity precision.
    local_2ke = type(ke2[0])(m * v_sq)

    # Convert to tile for block-level reduction
    t = wp.tile(local_2ke)

    # Cooperative sum within block
    s = wp.tile_sum(t)

    # Extract scalar from tile sum
    sum_2ke = s[0]

    # Only first thread in block writes (one atomic per block, not per atom)
    if atom_idx % TILE_DIM == 0:
        wp.atomic_add(ke2, 0, sum_2ke)


@wp.kernel(enable_backward=False)
def _batch_compute_2ke_kernel(
    velocities: wp.array(dtype=Any),
    masses: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    ke2: wp.array(dtype=Any),
):
    """Accumulate 2*KE = sum(m * |v|^2) per system of a batch.

    Atom ``i`` adds its contribution to ``ke2[batch_idx[i]]`` atomically.
    The contribution is cast to ke2's dtype first.

    Launch Grid
    -----------
    dim = [num_atoms_total]
    """
    i = wp.tid()
    system = batch_idx[i]
    vel = velocities[i]
    contribution = type(ke2[system])(masses[i] * wp.dot(vel, vel))
    wp.atomic_add(ke2, system, contribution)


@wp.kernel
def _nhc_compute_masses_kernel(
    ndof: wp.array(dtype=wp.int32),
    target_temp: wp.array(dtype=Any),
    tau: wp.array(dtype=Any),
    masses: wp.array(dtype=Any),
):
    """Fill the thermostat masses Q_k of a single Nosé-Hoover chain.

    The first thermostat couples to every particle degree of freedom:
        Q_0 = ndof * kT * tau^2
    and the remaining thermostats each couple to one:
        Q_k = kT * tau^2     (k > 0)

    Parameters
    ----------
    ndof : wp.array(dtype=wp.int32)
        Degrees of freedom, shape (1,).
    target_temp : wp.array(dtype=Any)
        Target temperature kT, shape (1,).
    tau : wp.array(dtype=Any)
        Thermostat time constant, shape (1,).
    masses : wp.array(dtype=Any)
        Output chain masses, shape (chain_length,).

    Launch Grid
    -----------
    dim = [chain_length]
    """
    chain_idx = wp.tid()
    kT = target_temp[0]
    t = tau[0]
    t_sq = t * t
    if chain_idx == 0:
        masses[chain_idx] = type(t)(ndof[0]) * kT * t_sq
    else:
        masses[chain_idx] = kT * t_sq


@wp.kernel
def _batch_nhc_compute_masses_kernel(
    ndof: wp.array(dtype=wp.int32),
    target_temp: wp.array(dtype=Any),
    tau: wp.array(dtype=Any),
    masses: wp.array2d(dtype=Any),
):
    """Fill the thermostat masses Q_k for every system in a batch.

    Row ``s`` of ``masses`` receives Q_0 = ndof[s] * kT[s] * tau[s]^2 in
    column 0 and kT[s] * tau[s]^2 in every later column.

    Parameters
    ----------
    ndof : wp.array(dtype=wp.int32)
        Degrees of freedom per system, shape (num_systems,).
    target_temp : wp.array(dtype=Any)
        Target temperature kT per system, shape (num_systems,).
    tau : wp.array(dtype=Any)
        Thermostat time constant per system, shape (num_systems,).
    masses : wp.array2d(dtype=Any)
        Output chain masses, shape (num_systems, chain_length).

    Launch Grid
    -----------
    dim = [num_systems, chain_length]
    """
    system, chain_idx = wp.tid()
    kT = target_temp[system]
    t = tau[system]
    t_sq = t * t
    if chain_idx == 0:
        masses[system, chain_idx] = type(t)(ndof[system]) * kT * t_sq
    else:
        masses[system, chain_idx] = kT * t_sq


# ==============================================================================
# Chain Propagation Kernels (Pure Warp Implementation)
# ==============================================================================


@wp.kernel
def _nhc_chain_propagate_kernel(
    eta: wp.array2d(dtype=Any),
    eta_dot: wp.array2d(dtype=Any),
    eta_mass: wp.array2d(dtype=Any),
    ke2: wp.array(dtype=Any),
    target_temp: wp.array(dtype=Any),
    ndof: wp.array(dtype=Any),
    dt_chain: wp.array(dtype=Any),
    chain_length: int,
    vel_scale: wp.array(dtype=Any),
):
    """Propagate Nosé-Hoover chain for one Yoshida-Suzuki sub-step.

    This kernel implements the time-reversible Martyna-Tobias-Klein (MTK)
    integration scheme for Nosé-Hoover chains.

    With ``dt = dt_chain[sys]``, the particle velocities are scaled by
    exp(-dt/2 * η̇_0), so the thermostat acts over an interval of dt/2. The
    thermostat velocities receive two dt/4 kicks, which also total dt/2. The
    thermostat positions must therefore advance by dt/2 in total, split
    symmetrically as dt/4 before and dt/4 after the velocity update. Otherwise
    the chain potential energy ndof*kT*η_0 + kT*sum_{k>0} η_k, which enters
    the conserved extended energy, drifts relative to the velocity scaling.

    Algorithm (for each system):
    1. Quarter-step position update: η_k += 0.25 * dt * η̇_k
    2. Backward sweep: update η̇ from chain end to start. Each dt/4 kick is
       sandwiched by friction exp(-dt/8 * η̇_{k+1}).
    3. Compute velocity scale factor exp(-0.5 * dt * η̇_0) and scale 2*KE by
       its square.
    4. Forward sweep: update η̇ from start to chain end with the new forces.
    5. Quarter-step position update: η_k += 0.25 * dt * η̇_k

    Launch Grid
    -----------
    dim = [num_systems]

    Parameters
    ----------
    eta : wp.array2d(dtype=Any)
        Chain positions, shape (num_systems, chain_length). MODIFIED in-place.
    eta_dot : wp.array2d(dtype=Any)
        Chain velocities, shape (num_systems, chain_length). MODIFIED in-place.
    eta_mass : wp.array2d(dtype=Any)
        Chain masses, shape (num_systems, chain_length).
    ke2 : wp.array(dtype=Any)
        2*KE for each system, shape (num_systems,). MODIFIED to reflect scaled KE.
    target_temp : wp.array(dtype=Any)
        Target temperature (kT), shape (num_systems,). Same scalar dtype as eta.
    ndof : wp.array(dtype=Any)
        Degrees of freedom, shape (num_systems,). Same scalar dtype as eta.
    dt_chain : wp.array(dtype=Any)
        Time step for this sub-step (weight * dt), shape (num_systems,).
    chain_length : int
        Number of thermostats in the chain (at most MAX_CHAIN_LENGTH).
    vel_scale : wp.array(dtype=Any)
        Output velocity scale factors, shape (num_systems,). MODIFIED.
    """
    sys_id = wp.tid()

    kT = target_temp[sys_id]
    ndof_sys = ndof[sys_id]
    dt = dt_chain[sys_id]
    half_dt = type(dt)(0.5) * dt
    quarter_dt = type(dt)(0.25) * dt
    eighth_dt = type(dt)(0.125) * dt
    ke2_sys = ke2[sys_id]

    # Local copies for chain state (we'll write back at the end)
    # Using fixed-size local arrays for the chain
    eta_local = wp.vector(dtype=eta.dtype, length=MAX_CHAIN_LENGTH)
    eta_dot_local = wp.vector(dtype=eta_dot.dtype, length=MAX_CHAIN_LENGTH)
    eta_mass_local = wp.vector(dtype=eta_mass.dtype, length=MAX_CHAIN_LENGTH)

    # Load chain state
    for k in range(chain_length):
        eta_local[k] = eta[sys_id, k]
        eta_dot_local[k] = eta_dot[sys_id, k]
        eta_mass_local[k] = eta_mass[sys_id, k]

    # ========== Step 1: First quarter-step position update ==========
    # Positions advance dt/2 in total (matching the dt/2 velocity scaling),
    # split symmetrically around the velocity update.
    for k in range(chain_length):
        eta_local[k] = eta_local[k] + quarter_dt * eta_dot_local[k]

    # ========== Step 2: Backward sweep (chain_length-1 down to 0) ==========

    # Update last thermostat (no friction from above)
    if chain_length > 1:
        G_last = (
            eta_mass_local[chain_length - 2]
            * eta_dot_local[chain_length - 2]
            * eta_dot_local[chain_length - 2]
            - kT
        ) / eta_mass_local[chain_length - 1]
        eta_dot_local[chain_length - 1] = (
            eta_dot_local[chain_length - 1] + quarter_dt * G_last
        )

    # Update intermediate thermostats (chain_length-2 down to 1)
    for k in range(chain_length - 2, 0, -1):
        G_k = (
            eta_mass_local[k - 1] * eta_dot_local[k - 1] * eta_dot_local[k - 1] - kT
        ) / eta_mass_local[k]
        # Apply friction from k+1
        scale = wp.exp(-eighth_dt * eta_dot_local[k + 1])
        eta_dot_local[k] = eta_dot_local[k] * scale
        eta_dot_local[k] = eta_dot_local[k] + quarter_dt * G_k
        eta_dot_local[k] = eta_dot_local[k] * scale

    # Update first thermostat (couples to particle KE)
    G_0 = (ke2_sys - ndof_sys * kT) / eta_mass_local[0]

    if chain_length > 1:
        scale = wp.exp(-eighth_dt * eta_dot_local[1])
        eta_dot_local[0] = eta_dot_local[0] * scale
        eta_dot_local[0] = eta_dot_local[0] + quarter_dt * G_0
        eta_dot_local[0] = eta_dot_local[0] * scale
    else:
        eta_dot_local[0] = eta_dot_local[0] + quarter_dt * G_0

    # ========== Step 3: Compute velocity scale factor ==========
    vel_scale_factor = wp.exp(-half_dt * eta_dot_local[0])
    vel_scale[sys_id] = vel_scale_factor

    # Update ke2 for the forward sweep
    ke2_sys = ke2_sys * vel_scale_factor * vel_scale_factor
    ke2[sys_id] = ke2_sys

    # ========== Step 4: Forward sweep (0 to chain_length-1) ==========

    # Update first thermostat with new force
    G_0_new = (ke2_sys - ndof_sys * kT) / eta_mass_local[0]

    if chain_length > 1:
        scale = wp.exp(-eighth_dt * eta_dot_local[1])
        eta_dot_local[0] = eta_dot_local[0] * scale
        eta_dot_local[0] = eta_dot_local[0] + quarter_dt * G_0_new
        eta_dot_local[0] = eta_dot_local[0] * scale
    else:
        eta_dot_local[0] = eta_dot_local[0] + quarter_dt * G_0_new

    # Update intermediate thermostats (1 to chain_length-2)
    for k in range(1, chain_length - 1):
        G_k = (
            eta_mass_local[k - 1] * eta_dot_local[k - 1] * eta_dot_local[k - 1] - kT
        ) / eta_mass_local[k]
        scale = wp.exp(-eighth_dt * eta_dot_local[k + 1])
        eta_dot_local[k] = eta_dot_local[k] * scale
        eta_dot_local[k] = eta_dot_local[k] + quarter_dt * G_k
        eta_dot_local[k] = eta_dot_local[k] * scale

    # Update last thermostat
    if chain_length > 1:
        G_last = (
            eta_mass_local[chain_length - 2]
            * eta_dot_local[chain_length - 2]
            * eta_dot_local[chain_length - 2]
            - kT
        ) / eta_mass_local[chain_length - 1]
        eta_dot_local[chain_length - 1] = (
            eta_dot_local[chain_length - 1] + quarter_dt * G_last
        )

    # ========== Step 5: Second quarter-step position update ==========
    for k in range(chain_length):
        eta_local[k] = eta_local[k] + quarter_dt * eta_dot_local[k]

    # ========== Write back chain state ==========
    for k in range(chain_length):
        eta[sys_id, k] = eta_local[k]
        eta_dot[sys_id, k] = eta_dot_local[k]


@wp.kernel
def _nhc_chain_propagate_single_kernel(
    eta: wp.array(dtype=Any),
    eta_dot: wp.array(dtype=Any),
    eta_mass: wp.array(dtype=Any),
    ke2: wp.array(dtype=Any),
    target_temp: wp.array(dtype=Any),
    ndof: wp.array(dtype=Any),
    dt_chain: wp.array(dtype=Any),
    chain_length: int,
    vel_scale: wp.array(dtype=Any),
):
    """Propagate Nosé-Hoover chain for single system (non-batched).

    Same algorithm as the batched version, but for 1D arrays. With
    ``dt = dt_chain[0]``, the particle velocities are scaled by
    exp(-dt/2 * η̇_0), and the thermostat velocities receive two dt/4 kicks.
    The thermostat positions therefore advance dt/2 in total, split as dt/4
    before and dt/4 after the velocity update. This keeps the chain potential
    energy ndof*kT*η_0 + kT*sum_{k>0} η_k consistent with the velocity
    scaling.

    Parameters
    ----------
    eta : wp.array(dtype=Any)
        Chain positions, shape (chain_length,). MODIFIED in-place.
    eta_dot : wp.array(dtype=Any)
        Chain velocities, shape (chain_length,). MODIFIED in-place.
    eta_mass : wp.array(dtype=Any)
        Chain masses, shape (chain_length,).
    ke2 : wp.array(dtype=Any)
        2*KE for the system, shape (1,). MODIFIED to reflect scaled KE.
    target_temp : wp.array(dtype=Any)
        Target temperature (kT), shape (1,).
    ndof : wp.array(dtype=Any)
        Degrees of freedom, shape (1,).
    dt_chain : wp.array(dtype=Any)
        Time step for this sub-step (weight * dt), shape (1,).
    chain_length : int
        Number of thermostats in the chain (at most MAX_CHAIN_LENGTH).
    vel_scale : wp.array(dtype=Any)
        Output velocity scale factors, shape (1,). MODIFIED.

    Launch Grid
    -----------
    dim = [1]
    """
    kT = target_temp[0]
    ndof_sys = ndof[0]
    dt = dt_chain[0]
    half_dt = type(dt)(0.5) * dt
    quarter_dt = type(dt)(0.25) * dt
    eighth_dt = type(dt)(0.125) * dt
    ke2_sys = ke2[0]

    # Local copies for chain state
    eta_local = wp.vector(dtype=eta.dtype, length=MAX_CHAIN_LENGTH)
    eta_dot_local = wp.vector(dtype=eta_dot.dtype, length=MAX_CHAIN_LENGTH)
    eta_mass_local = wp.vector(dtype=eta_mass.dtype, length=MAX_CHAIN_LENGTH)

    # Load chain state
    for k in range(chain_length):
        eta_local[k] = eta[k]
        eta_dot_local[k] = eta_dot[k]
        eta_mass_local[k] = eta_mass[k]

    # ========== Step 1: First quarter-step position update ==========
    # Positions advance dt/2 in total (matching the dt/2 velocity scaling),
    # split symmetrically around the velocity update.
    for k in range(chain_length):
        eta_local[k] = eta_local[k] + quarter_dt * eta_dot_local[k]

    # ========== Step 2: Backward sweep ==========
    if chain_length > 1:
        G_last = (
            eta_mass_local[chain_length - 2]
            * eta_dot_local[chain_length - 2]
            * eta_dot_local[chain_length - 2]
            - kT
        ) / eta_mass_local[chain_length - 1]
        eta_dot_local[chain_length - 1] = (
            eta_dot_local[chain_length - 1] + quarter_dt * G_last
        )

    for k in range(chain_length - 2, 0, -1):
        G_k = (
            eta_mass_local[k - 1] * eta_dot_local[k - 1] * eta_dot_local[k - 1] - kT
        ) / eta_mass_local[k]
        scale = wp.exp(-eighth_dt * eta_dot_local[k + 1])
        eta_dot_local[k] = eta_dot_local[k] * scale
        eta_dot_local[k] = eta_dot_local[k] + quarter_dt * G_k
        eta_dot_local[k] = eta_dot_local[k] * scale

    G_0 = (ke2_sys - ndof_sys * kT) / eta_mass_local[0]

    if chain_length > 1:
        scale = wp.exp(-eighth_dt * eta_dot_local[1])
        eta_dot_local[0] = eta_dot_local[0] * scale
        eta_dot_local[0] = eta_dot_local[0] + quarter_dt * G_0
        eta_dot_local[0] = eta_dot_local[0] * scale
    else:
        eta_dot_local[0] = eta_dot_local[0] + quarter_dt * G_0

    # ========== Step 3: Compute velocity scale factor ==========
    vel_scale_factor = wp.exp(-half_dt * eta_dot_local[0])
    vel_scale[0] = vel_scale_factor
    ke2_sys = ke2_sys * vel_scale_factor * vel_scale_factor
    ke2[0] = ke2_sys
    # ========== Step 4: Forward sweep ==========
    G_0_new = (ke2_sys - ndof_sys * kT) / eta_mass_local[0]
    if chain_length > 1:
        scale = wp.exp(-eighth_dt * eta_dot_local[1])
        eta_dot_local[0] = eta_dot_local[0] * scale
        eta_dot_local[0] = eta_dot_local[0] + quarter_dt * G_0_new
        eta_dot_local[0] = eta_dot_local[0] * scale
    else:
        eta_dot_local[0] = eta_dot_local[0] + quarter_dt * G_0_new
    for k in range(1, chain_length - 1):
        G_k = (
            eta_mass_local[k - 1] * eta_dot_local[k - 1] * eta_dot_local[k - 1] - kT
        ) / eta_mass_local[k]
        scale = wp.exp(-eighth_dt * eta_dot_local[k + 1])
        eta_dot_local[k] = eta_dot_local[k] * scale
        eta_dot_local[k] = eta_dot_local[k] + quarter_dt * G_k
        eta_dot_local[k] = eta_dot_local[k] * scale
    if chain_length > 1:
        G_last = (
            eta_mass_local[chain_length - 2]
            * eta_dot_local[chain_length - 2]
            * eta_dot_local[chain_length - 2]
            - kT
        ) / eta_mass_local[chain_length - 1]
        eta_dot_local[chain_length - 1] = (
            eta_dot_local[chain_length - 1] + quarter_dt * G_last
        )

    # ========== Step 5: Second quarter-step position update ==========
    for k in range(chain_length):
        eta_local[k] = eta_local[k] + quarter_dt * eta_dot_local[k]

    for k in range(chain_length):
        eta[k] = eta_local[k]
        eta_dot[k] = eta_dot_local[k]


# ==============================================================================
# Chain Energy Kernels
# ==============================================================================


@wp.kernel
def _nhc_compute_chain_energy_kernel(
    eta: wp.array(dtype=Any),
    eta_dot: wp.array(dtype=Any),
    eta_mass: wp.array(dtype=Any),
    target_temp: wp.array(dtype=Any),
    ndof: wp.array(dtype=Any),
    chain_length: int,
    ke_chain: wp.array(dtype=Any),
    pe_chain: wp.array(dtype=Any),
):
    """Evaluate the thermostat chain's kinetic and potential energy (one system).

    KE_chain = sum_k 0.5 * Q_k * η̇_k²
    PE_chain = ndof * kT * η_0 + kT * sum_{k>0} η_k

    Both sums are accumulated in eta's scalar type.

    Parameters
    ----------
    eta : wp.array(dtype=Any)
        Chain positions, shape (chain_length,).
    eta_dot : wp.array(dtype=Any)
        Chain velocities, shape (chain_length,).
    eta_mass : wp.array(dtype=Any)
        Chain masses, shape (chain_length,).
    target_temp : wp.array(dtype=Any)
        Target temperature (kT), shape (1,).
    ndof : wp.array(dtype=Any)
        Degrees of freedom, shape (1,).
    chain_length : int
        Number of thermostats in the chain.
    ke_chain : wp.array(dtype=Any)
        Output chain kinetic energy, shape (1,).
    pe_chain : wp.array(dtype=Any)
        Output chain potential energy, shape (1,).

    Launch Grid
    -----------
    dim = [1]
    """
    kT = target_temp[0]
    n_dof = ndof[0]
    half = type(eta[0])(0.5)

    kinetic = type(eta[0])(0.0)
    potential = type(eta[0])(0.0)

    for j in range(chain_length):
        kinetic = kinetic + half * eta_mass[j] * eta_dot[j] * eta_dot[j]
        if j == 0:
            # First thermostat is weighted by every particle degree of freedom.
            potential = potential + type(eta[0])(n_dof) * type(eta[0])(kT) * eta[j]
        else:
            potential = potential + type(eta[0])(kT) * eta[j]

    ke_chain[0] = type(eta[0])(kinetic)
    pe_chain[0] = type(eta[0])(potential)


@wp.kernel
def _batch_nhc_compute_chain_energy_kernel(
    eta: wp.array2d(dtype=Any),
    eta_dot: wp.array2d(dtype=Any),
    eta_mass: wp.array2d(dtype=Any),
    target_temp: wp.array(dtype=Any),
    ndof: wp.array(dtype=Any),
    chain_length: int,
    ke_chain: wp.array(dtype=Any),
    pe_chain: wp.array(dtype=Any),
):
    """Evaluate thermostat chain kinetic and potential energy per batched system.

    For system ``s``:
        KE_chain[s] = sum_k 0.5 * Q[s, k] * η̇[s, k]²
        PE_chain[s] = ndof[s] * kT[s] * η[s, 0] + kT[s] * sum_{k>0} η[s, k]

    Parameters
    ----------
    eta : wp.array2d(dtype=Any)
        Chain positions, shape (num_systems, chain_length).
    eta_dot : wp.array2d(dtype=Any)
        Chain velocities, shape (num_systems, chain_length).
    eta_mass : wp.array2d(dtype=Any)
        Chain masses, shape (num_systems, chain_length).
    target_temp : wp.array(dtype=Any)
        Target temperature (kT), shape (num_systems,).
    ndof : wp.array(dtype=Any)
        Degrees of freedom, shape (num_systems,).
    chain_length : int
        Number of thermostats in the chain.
    ke_chain : wp.array(dtype=Any)
        Output chain kinetic energy, shape (num_systems,).
    pe_chain : wp.array(dtype=Any)
        Output chain potential energy, shape (num_systems,).

    Launch Grid
    -----------
    dim = [num_systems]
    """
    s = wp.tid()

    # eta[s, 0] yields a scalar for type inference; eta[s] would yield a row.
    kT = type(eta[s, 0])(target_temp[s])
    n_dof = type(eta[s, 0])(ndof[s])
    half = type(kT)(0.5)

    kinetic = type(kT)(0.0)
    potential = type(kT)(0.0)

    for j in range(chain_length):
        qdot = eta_dot[s, j]
        kinetic = kinetic + half * eta_mass[s, j] * qdot * qdot
        if j == 0:
            potential = potential + n_dof * kT * eta[s, j]
        else:
            potential = potential + kT * eta[s, j]

    ke_chain[s] = kinetic
    pe_chain[s] = potential


# ==============================================================================
# Velocity Scaling Kernels
# ==============================================================================


@wp.kernel
def _scale_velocities_kernel(
    velocities: wp.array(dtype=Any),
    scale_factor: wp.array(dtype=Any),
):
    """Multiply every velocity in place by the single factor ``scale_factor[0]``.

    The factor is first cast to the velocity component type, so a float64
    factor can scale float32 velocities.

    Launch Grid
    -----------
    dim = [num_atoms]
    """
    i = wp.tid()
    vel = velocities[i]
    factor = type(vel[0])(scale_factor[0])
    velocities[i] = vel * factor


@wp.kernel
def _batch_scale_velocities_kernel(
    velocities: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    scale_factors: wp.array(dtype=Any),
):
    """Multiply each velocity in place by its system's factor.

    Atom ``i`` uses ``scale_factors[batch_idx[i]]``, cast to the velocity
    component type.

    Launch Grid
    -----------
    dim = [num_atoms_total]
    """
    i = wp.tid()
    vel = velocities[i]
    factor = type(vel[0])(scale_factors[batch_idx[i]])
    velocities[i] = vel * factor


@wp.kernel
def _scale_velocities_out_kernel(
    velocities: wp.array(dtype=Any),
    scale_factor: wp.array(dtype=Any),
    velocities_out: wp.array(dtype=Any),
):
    """Write ``velocities * scale_factor[0]`` into ``velocities_out``.

    The input velocities are left untouched. The factor is cast to the
    velocity component type.

    Launch Grid
    -----------
    dim = [num_atoms]
    """
    i = wp.tid()
    vel = velocities[i]
    velocities_out[i] = vel * type(vel[0])(scale_factor[0])


@wp.kernel
def _batch_scale_velocities_out_kernel(
    velocities: wp.array(dtype=Any),
    batch_idx: wp.array(dtype=wp.int32),
    scale_factors: wp.array(dtype=Any),
    velocities_out: wp.array(dtype=Any),
):
    """Write per-system scaled velocities into ``velocities_out``.

    Atom ``i`` is scaled by ``scale_factors[batch_idx[i]]``, cast to the
    velocity component type. The input array is not modified.

    Launch Grid
    -----------
    dim = [num_atoms_total]
    """
    i = wp.tid()
    vel = velocities[i]
    velocities_out[i] = vel * type(vel[0])(scale_factors[batch_idx[i]])


# ==============================================================================
# Multiply Scale Factors Kernel
# ==============================================================================


@wp.kernel
def _multiply_scale_factors_kernel(
    total_scale: wp.array(dtype=Any),
    step_scale: wp.array(dtype=Any),
):
    """Fold one sub-step's scale factor into the running product, per system.

    Launch Grid
    -----------
    dim = [num_systems]
    """
    s = wp.tid()
    total_scale[s] = step_scale[s] * total_scale[s]


# Explicit float32 / float64 instantiations of _multiply_scale_factors_kernel,
# keyed by scalar dtype.
_multiply_scale_factors_kernel_overload = {
    scalar: wp.overload(
        _multiply_scale_factors_kernel,
        [wp.array(dtype=scalar), wp.array(dtype=scalar)],
    )
    for scalar in (wp.float32, wp.float64)
}


# ==============================================================================
# Kernel Overloads for Explicit Typing
# ==============================================================================
# Each kernel is registered with concrete argument types through wp.overload,
# which sidesteps Warp's type inference problems with dtype=Any parameters.

_T = [wp.float32, wp.float64]  # Scalar types
_V = [wp.vec3f, wp.vec3d]  # Vector types

# 2*KE kernel overloads, keyed by (velocity dtype, output dtype); filled below.
_compute_2ke_kernel_overload = {}
_compute_2ke_tiled_kernel_overload = {}
_batch_compute_2ke_kernel_overload = {}

# Velocity-scaling kernel overloads, keyed by (velocity dtype, scale dtype);
# filled below.
_scale_velocities_kernel_overload = {}
_batch_scale_velocities_kernel_overload = {}
_scale_velocities_out_kernel_overload = {}
_batch_scale_velocities_out_kernel_overload = {}

# Chain-mass kernel overloads, keyed by scalar dtype. ndof is always int32.
_nhc_compute_masses_kernel_overload = {
    t: wp.overload(
        _nhc_compute_masses_kernel,
        [
            wp.array(dtype=wp.int32),  # ndof
            wp.array(dtype=t),  # target_temp
            wp.array(dtype=t),  # tau
            wp.array(dtype=t),  # masses
        ],
    )
    for t in _T
}
_batch_nhc_compute_masses_kernel_overload = {
    t: wp.overload(
        _batch_nhc_compute_masses_kernel,
        [
            wp.array(dtype=wp.int32),  # ndof
            wp.array(dtype=t),  # target_temp
            wp.array(dtype=t),  # tau
            wp.array2d(dtype=t),  # masses
        ],
    )
    for t in _T
}

# 2*KE kernel overloads for every (velocity dtype, output dtype) pair. Masses
# share the velocity precision. The ke2 output follows the chain-state dtype,
# which can differ from the velocity dtype.
for v in _V:
    t_mass = {wp.vec3f: wp.float32, wp.vec3d: wp.float64}[v]
    for t_out in _T:
        key = (v, t_out)
        _compute_2ke_kernel_overload[key] = wp.overload(
            _compute_2ke_kernel,
            [wp.array(dtype=v), wp.array(dtype=t_mass), wp.array(dtype=t_out)],
        )
        _compute_2ke_tiled_kernel_overload[key] = wp.overload(
            _compute_2ke_tiled_kernel,
            [wp.array(dtype=v), wp.array(dtype=t_mass), wp.array(dtype=t_out)],
        )
        _batch_compute_2ke_kernel_overload[key] = wp.overload(
            _batch_compute_2ke_kernel,
            [
                wp.array(dtype=v),  # velocities
                wp.array(dtype=t_mass),  # masses
                wp.array(dtype=wp.int32),  # batch_idx
                wp.array(dtype=t_out),  # ke2
            ],
        )

# Velocity-scaling kernel overloads for every (velocity dtype, scale dtype)
# pair, so scale factors may use a different precision than the velocities.
for v in _V:
    for t_scale in _T:
        key = (v, t_scale)
        _scale_velocities_kernel_overload[key] = wp.overload(
            _scale_velocities_kernel,
            [wp.array(dtype=v), wp.array(dtype=t_scale)],
        )
        _batch_scale_velocities_kernel_overload[key] = wp.overload(
            _batch_scale_velocities_kernel,
            [
                wp.array(dtype=v),  # velocities
                wp.array(dtype=wp.int32),  # batch_idx
                wp.array(dtype=t_scale),  # scale_factors
            ],
        )
        _scale_velocities_out_kernel_overload[key] = wp.overload(
            _scale_velocities_out_kernel,
            [
                wp.array(dtype=v),  # velocities
                wp.array(dtype=t_scale),  # scale_factor
                wp.array(dtype=v),  # velocities_out
            ],
        )
        _batch_scale_velocities_out_kernel_overload[key] = wp.overload(
            _batch_scale_velocities_out_kernel,
            [
                wp.array(dtype=v),  # velocities
                wp.array(dtype=wp.int32),  # batch_idx
                wp.array(dtype=t_scale),  # scale_factors
                wp.array(dtype=v),  # velocities_out
            ],
        )

# Single-system chain-propagation overloads, keyed by scalar dtype. Every
# array shares that dtype; chain_length is passed as int32.
_nhc_chain_propagate_single_kernel_overload = {
    t: wp.overload(
        _nhc_chain_propagate_single_kernel,
        [
            wp.array(dtype=t),  # eta
            wp.array(dtype=t),  # eta_dot
            wp.array(dtype=t),  # eta_mass
            wp.array(dtype=t),  # ke2
            wp.array(dtype=t),  # target_temp
            wp.array(dtype=t),  # ndof
            wp.array(dtype=t),  # dt_chain
            wp.int32,  # chain_length
            wp.array(dtype=t),  # vel_scale
        ],
    )
    for t in _T
}

# Chain-energy kernel overloads (single and batched), keyed by scalar dtype.
_nhc_compute_chain_energy_kernel_overload = {
    t: wp.overload(
        _nhc_compute_chain_energy_kernel,
        [
            wp.array(dtype=t),  # eta
            wp.array(dtype=t),  # eta_dot
            wp.array(dtype=t),  # eta_mass
            wp.array(dtype=t),  # target_temp
            wp.array(dtype=t),  # ndof
            wp.int32,  # chain_length
            wp.array(dtype=t),  # ke_chain
            wp.array(dtype=t),  # pe_chain
        ],
    )
    for t in _T
}
_batch_nhc_compute_chain_energy_kernel_overload = {
    t: wp.overload(
        _batch_nhc_compute_chain_energy_kernel,
        [
            wp.array2d(dtype=t),  # eta
            wp.array2d(dtype=t),  # eta_dot
            wp.array2d(dtype=t),  # eta_mass
            wp.array(dtype=t),  # target_temp
            wp.array(dtype=t),  # ndof
            wp.int32,  # chain_length
            wp.array(dtype=t),  # ke_chain
            wp.array(dtype=t),  # pe_chain
        ],
    )
    for t in _T
}


# ==============================================================================
# Functional Interfaces
# ==============================================================================


def nhc_compute_masses(
    ndof: wp.array,
    target_temp: wp.array,
    tau: wp.array,
    chain_length: int,
    masses: wp.array,
    num_systems: int = 1,
    device: str = None,
    dtype=wp.float64,
) -> wp.array:
    """Compute Nosé-Hoover chain masses using GPU kernel.

    Computes Q_k values for Nosé-Hoover chain:
        Q_0 = ndof * kT * tau^2
        Q_k = kT * tau^2 for k > 0

    Parameters
    ----------
    ndof : wp.array(dtype=wp.int32)
        Number of degrees of freedom per system. Shape (1,) for single
        system, (num_systems,) for batched.
    target_temp : wp.array
        Target temperature (kT) per system. Shape (1,) for single system,
        (num_systems,) for batched. Same scalar dtype as ``masses``.
    tau : wp.array
        Time constant per system. Shape (1,) for single system,
        (num_systems,) for batched. Same scalar dtype as ``masses``.
    chain_length : int
        Number of thermostats in the chain.
    masses : wp.array
        Chain masses output. Caller must pre-allocate. Shape (chain_length,)
        for single system, (num_systems, chain_length) for batched.
    num_systems : int, optional
        Number of systems for batched mode. Default: 1.
    device : str, optional
        Warp device. If None, inferred from masses.
    dtype : dtype, optional
        Kept for backward compatibility. The kernel overload is selected from
        ``masses.dtype``, so float32 ``masses`` work without also passing
        ``dtype=wp.float32``. Default: wp.float64.

    Returns
    -------
    wp.array
        ``masses``, filled in place. Shape (chain_length,) for single system,
        (num_systems, chain_length) for batched.

    Raises
    ------
    TypeError
        If ``masses.dtype`` is not wp.float32 or wp.float64.
    ValueError
        If ``chain_length`` (or ``num_systems`` in batched mode) exceeds the
        corresponding dimension of ``masses``.
    """
    if device is None:
        device = masses.device

    is_batched = masses.ndim == 2
    overloads = (
        _batch_nhc_compute_masses_kernel_overload
        if is_batched
        else _nhc_compute_masses_kernel_overload
    )

    # Select the overload from the array that is actually written. Using the
    # ``dtype`` argument alone mismatches float32 masses whenever the caller
    # relies on the float64 default.
    scalar_type = masses.dtype
    if scalar_type not in overloads:
        raise TypeError(
            f"masses must have dtype wp.float32 or wp.float64, got {scalar_type}"
        )

    # The kernels write masses[k] / masses[sys, k] for the whole launch grid.
    # Out-of-range writes are not caught outside Warp's debug mode, so check
    # the grid against the output shape here.
    if chain_length > masses.shape[-1]:
        raise ValueError(
            f"chain_length={chain_length} exceeds masses.shape[-1]={masses.shape[-1]}"
        )

    if is_batched:
        if num_systems > masses.shape[0]:
            raise ValueError(
                f"num_systems={num_systems} exceeds masses.shape[0]={masses.shape[0]}"
            )
        wp.launch(
            overloads[scalar_type],
            dim=(num_systems, chain_length),
            inputs=[ndof, target_temp, tau, masses],
            device=device,
        )
    else:
        wp.launch(
            overloads[scalar_type],
            dim=chain_length,
            inputs=[ndof, target_temp, tau, masses],
            device=device,
        )
    return masses
def nhc_thermostat_chain_update(
    velocities: wp.array,
    masses: wp.array,
    eta: wp.array,
    eta_dot: wp.array,
    eta_mass: wp.array,
    target_temp: wp.array,
    dt: wp.array,
    ndof: wp.array,
    ke2: wp.array,
    total_scale: wp.array,
    step_scale: wp.array,
    dt_chain: wp.array,
    nloops: int = 1,
    batch_idx: wp.array = None,
    num_systems: int = 1,
    device: str = None,
) -> None:
    """
    Propagate Nosé-Hoover chain and scale velocities (in-place).

    Uses Yoshida-Suzuki factorization for time-reversible integration.
    All computations are performed on GPU using Warp kernels.

    The sequence is:

    1. Compute 2*KE once into ``ke2``.
    2. For each Yoshida-Suzuki weight ``w``, set ``dt_chain = w * dt``,
       propagate the chain, and multiply the per-step velocity scale
       (``step_scale``) into ``total_scale``.
    3. Multiply every atom's velocity by its system's ``total_scale``.

    Parameters
    ----------
    velocities : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic velocities. Shape (N,). MODIFIED in-place.
    masses : wp.array(dtype=wp.float32 or wp.float64)
        Atomic masses. Shape (N,).
    eta : wp.array(dtype=wp.float64)
        Thermostat chain positions.
        Non-batched: Shape (chain_length,).
        Batched: Shape (num_systems, chain_length) as wp.array2d.
        MODIFIED in-place.
    eta_dot : wp.array(dtype=wp.float64)
        Thermostat chain velocities. Same shape as eta. MODIFIED in-place.
    eta_mass : wp.array(dtype=wp.float64)
        Thermostat chain masses. Same shape as eta.
    target_temp : wp.array(dtype=wp.float64)
        Target temperature (kT). Shape (1,) or (num_systems,).
    dt : wp.array(dtype=wp.float32 or wp.float64)
        Time step. Shape (1,) or (num_systems,).
    ndof : wp.array(dtype=wp.float64)
        Degrees of freedom. Shape (1,) or (num_systems,).
    ke2 : wp.array
        Scratch array for 2*KE computation. Zeroed internally before use.
        Shape (1,) for single system, (num_systems,) for batched.
    total_scale : wp.array
        Scratch array for the accumulated velocity scale factor. Reset to
        ones internally at the start of every call, so it may be reused
        across calls without re-initialization. After the call it holds the
        total factor that was applied to the velocities.
        Shape (1,) for single system, (num_systems,) for batched.
    step_scale : wp.array
        Scratch array for per-step velocity scale factor.
        Shape (1,) for single system, (num_systems,) for batched.
    dt_chain : wp.array
        Scratch array for weighted time steps.
        Shape (1,) for single system, (num_systems,) for batched.
    nloops : int, optional
        Number of Yoshida-Suzuki integration sub-steps. Default: 1.
        Use nloops=3 or 5 for higher accuracy. Other positive values fall
        back to ``nloops`` equal weights of ``1/nloops``.
    batch_idx : wp.array(dtype=wp.int32), optional
        System index for each atom. Required for batched mode; its presence
        selects the batched kernels.
    num_systems : int, optional
        Number of systems for batched mode. Default: 1.
    device : str, optional
        Warp device. If None, inferred from velocities.

    Raises
    ------
    ValueError
        If ``nloops < 1`` or the chain length exceeds ``MAX_CHAIN_LENGTH``.
    """
    if nloops < 1:
        # nloops == 0 would divide by zero below, and a negative value would
        # produce an empty weight list that silently skips the thermostat.
        raise ValueError(f"nloops must be a positive integer, got {nloops}")

    if device is None:
        device = velocities.device

    num_atoms = velocities.shape[0]
    is_batched = batch_idx is not None

    # Get Yoshida-Suzuki weights
    if nloops == 1:
        weights = [1.0]
    elif nloops == 3:
        weights = YOSHIDA_SUZUKI_3
    elif nloops == 5:
        weights = YOSHIDA_SUZUKI_5
    else:
        # Simple equal weights for other values
        weights = [1.0 / nloops] * nloops

    # Chain length is the last axis of eta in both layouts.
    chain_length = eta.shape[1] if is_batched else eta.shape[0]
    if chain_length > MAX_CHAIN_LENGTH:
        raise ValueError(
            f"Chain length {chain_length} exceeds maximum {MAX_CHAIN_LENGTH}"
        )

    vec_dtype = velocities.dtype
    chain_dtype = eta.dtype
    n_scale = num_systems if is_batched else 1

    # Reset both scratch accumulators. total_scale is only ever multiplied
    # into below, so without this reset, reusing it across calls would
    # compound stale scale factors into the velocities.
    ke2.zero_()
    total_scale.fill_(1.0)

    # Compute 2*KE (accumulated into ke2).
    if is_batched:
        wp.launch(
            _batch_compute_2ke_kernel_overload[(vec_dtype, chain_dtype)],
            dim=num_atoms,
            inputs=[velocities, masses, batch_idx, ke2],
            device=device,
        )
    else:
        wp.launch(
            _compute_2ke_kernel_overload[(vec_dtype, chain_dtype)],
            dim=num_atoms,
            inputs=[velocities, masses, ke2],
            device=device,
        )

    # Run Yoshida-Suzuki sub-steps.
    for w in weights:
        # dt_chain = w * dt for every system (one entry when not batched).
        _compute_weighted_dt(dt, dt_chain, w, n_scale, device)

        # Propagate chain
        if is_batched:
            wp.launch(
                _nhc_chain_propagate_kernel,
                dim=num_systems,
                inputs=[
                    eta,
                    eta_dot,
                    eta_mass,
                    ke2,
                    target_temp,
                    ndof,
                    dt_chain,
                    chain_length,
                    step_scale,
                ],
                device=device,
            )
        else:
            wp.launch(
                _nhc_chain_propagate_single_kernel_overload[chain_dtype],
                dim=1,
                inputs=[
                    eta,
                    eta_dot,
                    eta_mass,
                    ke2,
                    target_temp,
                    ndof,
                    dt_chain,
                    chain_length,
                    step_scale,
                ],
                device=device,
            )

        # Accumulate total scale factor: total_scale *= step_scale.
        wp.launch(
            _multiply_scale_factors_kernel_overload[chain_dtype],
            dim=n_scale,
            inputs=[total_scale, step_scale],
            device=device,
        )

    # Scale velocities by the accumulated per-system factor.
    if is_batched:
        wp.launch(
            _batch_scale_velocities_kernel_overload[(vec_dtype, chain_dtype)],
            dim=num_atoms,
            inputs=[velocities, batch_idx, total_scale],
            device=device,
        )
    else:
        wp.launch(
            _scale_velocities_kernel_overload[(vec_dtype, chain_dtype)],
            dim=num_atoms,
            inputs=[velocities, total_scale],
            device=device,
        )
@wp.kernel
def _compute_weighted_dt_kernel(
    dt: wp.array(dtype=Any),
    dt_chain: wp.array(dtype=Any),
    weight: Any,
):
    """Scale each system's time step by a scalar weight.

    Writes ``dt_chain[i] = dt[i] * weight``, casting both factors to the
    scalar type of ``dt_chain``.

    Launch Grid
    -----------
    dim = [num_systems]
    """
    i = wp.tid()
    dt_chain[i] = type(dt_chain[i])(dt[i]) * type(dt_chain[i])(weight)


# Overload table keyed by (dt dtype, dt_chain dtype), covering every pairing
# of the supported scalar types. The weight argument takes the dt_chain type.
_compute_weighted_dt_kernel_overload = {
    (t_in, t_out): wp.overload(
        _compute_weighted_dt_kernel,
        [wp.array(dtype=t_in), wp.array(dtype=t_out), t_out],
    )
    for t_in in _T
    for t_out in _T
}


def _compute_weighted_dt(
    dt: wp.array,
    dt_chain: wp.array,
    weight: float,
    num_systems: int,
    device: str,
):
    """Fill ``dt_chain`` with ``weight * dt`` on the GPU.

    The kernel overload is picked from the dtypes of ``dt`` and ``dt_chain``.
    ``weight`` is converted to the ``dt_chain`` scalar type before launch.
    """
    out_type = dt_chain.dtype
    kernel = _compute_weighted_dt_kernel_overload[(dt.dtype, out_type)]
    wp.launch(
        kernel,
        dim=num_systems,
        inputs=[dt, dt_chain, out_type(weight)],
        device=device,
    )
def nhc_thermostat_chain_update_out(
    velocities: wp.array,
    masses: wp.array,
    eta: wp.array,
    eta_dot: wp.array,
    eta_mass: wp.array,
    target_temp: wp.array,
    dt: wp.array,
    ndof: wp.array,
    ke2: wp.array,
    total_scale: wp.array,
    step_scale: wp.array,
    dt_chain: wp.array,
    velocities_out: wp.array,
    eta_out: wp.array,
    eta_dot_out: wp.array,
    nloops: int = 1,
    batch_idx: wp.array = None,
    num_systems: int = 1,
    device: str = None,
) -> tuple[wp.array, wp.array, wp.array]:
    """
    Propagate Nosé-Hoover chain and scale velocities (non-mutating).

    Copies ``velocities``, ``eta``, and ``eta_dot`` into the caller-provided
    output arrays, then runs ``nhc_thermostat_chain_update`` on those copies.
    The input arrays are left untouched. The scratch arrays are still written.

    Parameters
    ----------
    velocities : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic velocities. Shape (N,).
    masses : wp.array(dtype=wp.float32 or wp.float64)
        Atomic masses. Shape (N,).
    eta : wp.array(dtype=wp.float64)
        Thermostat chain positions. Shape (M,) or (B, M).
    eta_dot : wp.array(dtype=wp.float64)
        Thermostat chain velocities. Shape (M,) or (B, M).
    eta_mass : wp.array(dtype=wp.float64)
        Thermostat chain masses. Shape (M,) or (B, M).
    target_temp : wp.array(dtype=wp.float64)
        Target temperature (kT). Shape (1,) or (B,).
    dt : wp.array(dtype=wp.float32 or wp.float64)
        Time step. Shape (1,) or (B,).
    ndof : wp.array(dtype=wp.float64)
        Degrees of freedom. Shape (1,) or (B,).
    ke2 : wp.array
        Scratch array for 2*KE computation. Zeroed internally before each use.
        Shape (1,) for single system, (num_systems,) for batched.
    total_scale : wp.array
        Scratch array for accumulated velocity scale factor.
        Must be initialized to ones by caller (wp.ones).
        Shape (1,) for single system, (num_systems,) for batched.
    step_scale : wp.array
        Scratch array for per-step velocity scale factor.
        Shape (1,) for single system, (num_systems,) for batched.
    dt_chain : wp.array
        Scratch array for weighted time steps.
        Shape (1,) for single system, (num_systems,) for batched.
    velocities_out : wp.array
        Output velocities. Must be pre-allocated with same shape/dtype/device
        as velocities.
    eta_out : wp.array
        Output eta. Must be pre-allocated with same shape/dtype/device as eta.
    eta_dot_out : wp.array
        Output eta_dot. Must be pre-allocated with same shape/dtype/device as
        eta_dot.
    nloops : int, optional
        Number of Yoshida-Suzuki integration sub-steps. Default: 1.
    batch_idx : wp.array(dtype=wp.int32), optional
        System index for each atom. Required for batched mode.
    num_systems : int, optional
        Number of systems for batched mode. Default: 1.
    device : str, optional
        Warp device. If None, inferred from velocities.

    Returns
    -------
    tuple[wp.array, wp.array, wp.array]
        (velocities_out, eta_out, eta_dot_out)
    """
    target_device = velocities.device if device is None else device

    # (destination, source, label) for each state array that gets mirrored.
    mirrored = (
        (velocities_out, velocities, "velocities_out"),
        (eta_out, eta, "eta_out"),
        (eta_dot_out, eta_dot, "eta_dot_out"),
    )
    # Validate every output before touching any of them.
    for dst, src, label in mirrored:
        validate_out_array(dst, src, label)
    for dst, src, _ in mirrored:
        wp.copy(dst, src)

    # The in-place kernel path now operates solely on the output copies.
    nhc_thermostat_chain_update(
        velocities=velocities_out,
        masses=masses,
        eta=eta_out,
        eta_dot=eta_dot_out,
        eta_mass=eta_mass,
        target_temp=target_temp,
        dt=dt,
        ndof=ndof,
        ke2=ke2,
        total_scale=total_scale,
        step_scale=step_scale,
        dt_chain=dt_chain,
        nloops=nloops,
        batch_idx=batch_idx,
        num_systems=num_systems,
        device=target_device,
    )
    return velocities_out, eta_out, eta_dot_out
def nhc_velocity_half_step(
    velocities: wp.array,
    forces: wp.array,
    masses: wp.array,
    dt: wp.array,
    batch_idx: wp.array = None,
    atom_ptr: wp.array = None,
    device: str = None,
) -> None:
    """
    Half-step velocity update (in-place).

    v += 0.5 * (F/m) * dt

    The update is dispatched through ``velocity_kick_families``. The chosen
    kernel variant (single, batch_idx, or atom_ptr) writes its result back
    into ``velocities``.

    Parameters
    ----------
    velocities : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic velocities. Shape (N,). MODIFIED in-place.
    forces : wp.array(dtype=wp.vec3f or wp.vec3d)
        Forces on atoms. Shape (N,).
    masses : wp.array(dtype=wp.float32 or wp.float64)
        Atomic masses. Shape (N,).
    dt : wp.array(dtype=wp.float32 or wp.float64)
        Time step. Shape (1,) or (B,).
    batch_idx : wp.array(dtype=wp.int32), optional
        System index for each atom. For batched mode (atomic operations).
    atom_ptr : wp.array(dtype=wp.int32), optional
        CSR-style pointers. Shape (num_systems + 1,).
        For batched mode (sequential per-system).
    device : str, optional
        Warp device. If None, inferred from velocities.
    """
    # Kernel inputs share this prefix and suffix; batched variants insert
    # their indexing array in between. The output slot is velocities itself.
    leading = [velocities, forces, masses]
    trailing = [dt, velocities]
    dispatch_family(
        velocity_kick_families,
        velocities,
        batch_idx=batch_idx,
        atom_ptr=atom_ptr,
        device=device,
        inputs_single=[*leading, *trailing],
        inputs_batch=[*leading, batch_idx, *trailing],
        inputs_ptr=[*leading, atom_ptr, *trailing],
    )
def nhc_velocity_half_step_out(
    velocities: wp.array,
    forces: wp.array,
    masses: wp.array,
    dt: wp.array,
    velocities_out: wp.array,
    batch_idx: wp.array = None,
    atom_ptr: wp.array = None,
    device: str = None,
) -> wp.array:
    """
    Half-step velocity update (non-mutating).

    Same kick as ``nhc_velocity_half_step``, but the result is written to
    ``velocities_out`` instead of ``velocities``.

    Parameters
    ----------
    velocities : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic velocities. Shape (N,).
    forces : wp.array(dtype=wp.vec3f or wp.vec3d)
        Forces on atoms. Shape (N,).
    masses : wp.array(dtype=wp.float32 or wp.float64)
        Atomic masses. Shape (N,).
    dt : wp.array(dtype=wp.float32 or wp.float64)
        Time step. Shape (1,) or (B,).
    velocities_out : wp.array
        Output velocities. Must be pre-allocated with same shape/dtype/device
        as velocities.
    batch_idx : wp.array(dtype=wp.int32), optional
        System index for each atom. For batched mode (atomic operations).
    atom_ptr : wp.array(dtype=wp.int32), optional
        CSR-style pointers. Shape (num_systems + 1,).
        For batched mode (sequential per-system).
    device : str, optional
        Warp device. If None, inferred from velocities.

    Returns
    -------
    wp.array
        Updated velocities (``velocities_out``).
    """
    validate_out_array(velocities_out, velocities, "velocities_out")

    leading = [velocities, forces, masses]
    trailing = [dt, velocities_out]
    dispatch_family(
        velocity_kick_families,
        velocities,
        batch_idx=batch_idx,
        atom_ptr=atom_ptr,
        device=device,
        inputs_single=[*leading, *trailing],
        inputs_batch=[*leading, batch_idx, *trailing],
        inputs_ptr=[*leading, atom_ptr, *trailing],
    )
    return velocities_out
def nhc_position_update(
    positions: wp.array,
    velocities: wp.array,
    dt: wp.array,
    batch_idx: wp.array = None,
    atom_ptr: wp.array = None,
    device: str = None,
) -> None:
    """
    Full-step position update (in-place).

    r += v * dt

    The update is dispatched through ``position_update_families``. The chosen
    kernel variant writes its result back into ``positions``.

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic positions. Shape (N,). MODIFIED in-place.
    velocities : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic velocities. Shape (N,).
    dt : wp.array(dtype=wp.float32 or wp.float64)
        Time step. Shape (1,) or (B,).
    batch_idx : wp.array(dtype=wp.int32), optional
        System index for each atom. For batched mode (atomic operations).
    atom_ptr : wp.array(dtype=wp.int32), optional
        CSR-style pointers. Shape (num_systems + 1,).
        For batched mode (sequential per-system).
    device : str, optional
        Warp device. If None, inferred from positions.
    """
    # Shared prefix/suffix; the output slot is positions itself.
    leading = [positions, velocities]
    trailing = [dt, positions]
    dispatch_family(
        position_update_families,
        positions,
        batch_idx=batch_idx,
        atom_ptr=atom_ptr,
        device=device,
        inputs_single=[*leading, *trailing],
        inputs_batch=[*leading, batch_idx, *trailing],
        inputs_ptr=[*leading, atom_ptr, *trailing],
    )
def nhc_position_update_out(
    positions: wp.array,
    velocities: wp.array,
    dt: wp.array,
    positions_out: wp.array,
    batch_idx: wp.array = None,
    atom_ptr: wp.array = None,
    device: str = None,
) -> wp.array:
    """
    Full-step position update (non-mutating).

    Same drift as ``nhc_position_update``, but the result is written to
    ``positions_out`` instead of ``positions``.

    Parameters
    ----------
    positions : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic positions. Shape (N,).
    velocities : wp.array(dtype=wp.vec3f or wp.vec3d)
        Atomic velocities. Shape (N,).
    dt : wp.array(dtype=wp.float32 or wp.float64)
        Time step. Shape (1,) or (B,).
    positions_out : wp.array
        Output positions. Must be pre-allocated with same shape/dtype/device
        as positions.
    batch_idx : wp.array(dtype=wp.int32), optional
        System index for each atom. For batched mode (atomic operations).
    atom_ptr : wp.array(dtype=wp.int32), optional
        CSR-style pointers. Shape (num_systems + 1,).
        For batched mode (sequential per-system).
    device : str, optional
        Warp device. If None, inferred from positions.

    Returns
    -------
    wp.array
        Updated positions (``positions_out``).
    """
    validate_out_array(positions_out, positions, "positions_out")

    leading = [positions, velocities]
    trailing = [dt, positions_out]
    dispatch_family(
        position_update_families,
        positions,
        batch_idx=batch_idx,
        atom_ptr=atom_ptr,
        device=device,
        inputs_single=[*leading, *trailing],
        inputs_batch=[*leading, batch_idx, *trailing],
        inputs_ptr=[*leading, atom_ptr, *trailing],
    )
    return positions_out
def nhc_compute_chain_energy(
    eta: wp.array,
    eta_dot: wp.array,
    eta_mass: wp.array,
    target_temp: wp.array,
    ndof: wp.array,
    ke_chain: wp.array,
    pe_chain: wp.array,
    batch_idx: wp.array = None,
    num_systems: int = 1,
    device: str = None,
) -> tuple[wp.array, wp.array]:
    """
    Compute Nosé-Hoover chain kinetic and potential energy.

    For conservation checks, the extended system Hamiltonian is:

        H_ext = KE_particles + PE + KE_chain + PE_chain

    where:

        KE_chain = sum_k 0.5 * Q_k * eta_dot_k^2
        PE_chain = ndof * kT * eta_0 + kT * sum_{k>0} eta_k

    Parameters
    ----------
    eta : wp.array(dtype=wp.float64)
        Thermostat chain positions. Shape (M,) or (B, M). A 2-D array
        selects the batched kernel.
    eta_dot : wp.array(dtype=wp.float64)
        Thermostat chain velocities. Shape (M,) or (B, M).
    eta_mass : wp.array(dtype=wp.float64)
        Thermostat chain masses. Shape (M,) or (B, M).
    target_temp : wp.array(dtype=wp.float64)
        Target temperature (kT). Shape (1,) or (B,).
    ndof : wp.array(dtype=wp.float64)
        Degrees of freedom. Shape (1,) or (B,).
    ke_chain : wp.array
        Output kinetic energy of the chain.
        Shape (1,) for single system, (num_systems,) for batched.
    pe_chain : wp.array
        Output potential energy of the chain.
        Shape (1,) for single system, (num_systems,) for batched.
    batch_idx : wp.array(dtype=wp.int32), optional
        Unused. Batching is determined by the rank of ``eta``. Kept for API
        consistency with the other NHC functions.
    num_systems : int, optional
        Number of systems to evaluate in batched mode. Default: 1.
    device : str, optional
        Warp device. If None, inferred from eta.

    Returns
    -------
    tuple[wp.array, wp.array]
        (ke_chain, pe_chain) each with shape (1,) or (B,).

    Raises
    ------
    ValueError
        If ``eta`` is batched and ``num_systems`` exceeds ``eta.shape[0]``.
    """
    if device is None:
        device = eta.device

    # The layout of eta is authoritative: (M,) single, (B, M) batched.
    # Inferring this from num_systems/batch_idx instead misreads a (1, M)
    # batch as a single system (chain_length would become 1), and crashes
    # on eta.shape[1] when eta is 1-D but batch_idx is supplied.
    is_batched = eta.ndim == 2
    chain_length = eta.shape[-1]
    chain_dtype = eta.dtype

    if is_batched:
        if num_systems > eta.shape[0]:
            # Launching more threads than rows would read past the end of
            # the chain arrays on the device.
            raise ValueError(
                f"num_systems={num_systems} exceeds the number of chain "
                f"rows in eta ({eta.shape[0]})"
            )
        kernel = _batch_nhc_compute_chain_energy_kernel_overload[chain_dtype]
        launch_dim = num_systems
    else:
        kernel = _nhc_compute_chain_energy_kernel_overload[chain_dtype]
        launch_dim = 1

    wp.launch(
        kernel,
        dim=launch_dim,
        inputs=[
            eta,
            eta_dot,
            eta_mass,
            target_temp,
            ndof,
            chain_length,
            ke_chain,
            pe_chain,
        ],
        device=device,
    )
    return ke_chain, pe_chain