# Source code for nvalchemi.dynamics.hooks.profiling

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Per-stage wall-clock profiling for dynamics workflows.

Provides :class:`ProfilerHook`, a single hook that registers at multiple
stages and records the elapsed time between consecutive stages at each
step.  Supports NVTX range annotations for Nsight Systems, CSV logging,
and formatted console output via ``loguru``.

The hook supports dynamics and custom workflows via plum dispatch,
automatically detecting the stage type and annotating NVTX ranges with
the appropriate domain (``dynamics`` or ``custom``).
"""

from __future__ import annotations

import csv
import io
import statistics
import time
from enum import Enum
from pathlib import Path
from typing import Literal

import torch
from loguru import logger
from plum import dispatch

from nvalchemi.data import Batch
from nvalchemi.dynamics.base import DynamicsStage
from nvalchemi.hooks._context import HookContext

try:
    import nvtx
except ImportError:
    nvtx = None

__all__ = ["ProfilerHook"]


def _sort_stages(stages: set[Enum]) -> list[Enum]:
    """Sort stage enum members by their integer value."""
    return sorted(stages, key=lambda s: s.value)


class ProfilerHook:
    """Per-stage timing hook for dynamics workflows.

    A single ``ProfilerHook`` instance registers itself at every requested
    stage. On each call it records a timestamp; when the last profiled stage
    in a step fires, it computes the elapsed time between consecutive stages
    and (optionally) writes to CSV / console.

    The hook uses ``stages`` (plural) so that
    :meth:`~nvalchemi.dynamics.base.BaseDynamics.register_hook` registers it
    at all listed stages in one call.

    The hook supports :class:`DynamicsStage` and custom enum types via plum
    dispatch, automatically annotating NVTX ranges with the appropriate
    domain (``dynamics`` or ``custom``).

    Parameters
    ----------
    profiled_stages : set[Enum] | {"all", "step", "detailed"}
        Which stages to instrument.

        * ``"all"`` (default): every :class:`DynamicsStage` except
          ``ON_CONVERGE``.
        * ``"step"``: ``BEFORE_STEP`` and ``AFTER_STEP`` only.
        * ``"detailed"``: all stages from ``BEFORE_STEP`` through
          ``AFTER_STEP`` (excluding ``ON_CONVERGE``).
        * A custom ``set[Enum]`` for fine-grained control.
    frequency : int, optional
        Profile every ``frequency`` steps. Default ``1``.
    enable_nvtx : bool, optional
        Emit NVTX push/pop ranges for Nsight Systems. Default ``True``.
    timer_backend : {"cuda_event", "perf_counter", "auto"}, optional
        Timing backend. ``"auto"`` selects ``cuda_event`` on GPU devices
        and ``perf_counter`` on CPU. Default ``"auto"``.
    log_path : str | Path | None, optional
        Path to a CSV file for persistent timing logs. Each row records
        the rank, step, stage transition, wall-clock offset, and delta.
        Default ``None`` (no file).
    show_console : bool, optional
        Print a formatted timing table via ``loguru`` at each profiled
        step. Default ``False``.
    console_frequency : int, optional
        When ``show_console`` is ``True``, print every
        ``console_frequency`` profiled steps. Default ``1``.

    Attributes
    ----------
    _profiled_stages : list[Enum]
        Profiled stages in execution order (private).
    frequency : int
        Execution frequency in steps.
    timings : dict[Enum, list[float]]
        Accumulated per-transition timing data (seconds).

    Examples
    --------
    >>> from nvalchemi.dynamics.hooks import ProfilerHook
    >>> profiler = ProfilerHook()
    >>> dynamics = DemoDynamics(model=model, n_steps=100, dt=0.5, hooks=[profiler])
    >>> dynamics.run(batch)
    >>> print(profiler.summary())

    With CSV logging and console output:

    >>> profiler = ProfilerHook(
    ...     "detailed",
    ...     log_path="profiler.csv",
    ...     show_console=True,
    ...     console_frequency=10,
    ... )
    >>> dynamics = DemoDynamics(model=model, n_steps=1000, dt=0.5, hooks=[profiler])
    >>> dynamics.run(batch)
    """

    def __init__(
        self,
        profiled_stages: set[Enum] | Literal["all", "step", "detailed"] = "all",
        *,
        frequency: int = 1,
        enable_nvtx: bool = True,
        timer_backend: Literal["cuda_event", "perf_counter", "auto"] = "auto",
        log_path: str | Path | None = None,
        show_console: bool = False,
        console_frequency: int = 1,
        stage: Enum = DynamicsStage.BEFORE_STEP,
    ) -> None:
        # Init file handle early so __del__ is safe on validation errors.
        self._csv_file: io.TextIOWrapper | None = None
        self._csv_writer: csv.DictWriter | None = None

        if isinstance(profiled_stages, str):
            if profiled_stages == "all":
                resolved = {s for s in DynamicsStage if s != DynamicsStage.ON_CONVERGE}
            elif profiled_stages == "step":
                resolved = {DynamicsStage.BEFORE_STEP, DynamicsStage.AFTER_STEP}
            elif profiled_stages == "detailed":
                resolved = {
                    DynamicsStage.BEFORE_STEP,
                    DynamicsStage.BEFORE_PRE_UPDATE,
                    DynamicsStage.AFTER_PRE_UPDATE,
                    DynamicsStage.BEFORE_COMPUTE,
                    DynamicsStage.AFTER_COMPUTE,
                    DynamicsStage.BEFORE_POST_UPDATE,
                    DynamicsStage.AFTER_POST_UPDATE,
                    DynamicsStage.AFTER_STEP,
                }
            else:
                raise ValueError(
                    f"Unknown stages preset {profiled_stages!r}. "
                    f"Use 'all', 'step', 'detailed', or a set of Enum."
                )
        else:
            resolved = set(profiled_stages)

        # Deltas are measured between consecutive stages, so a single
        # stage yields nothing measurable.
        if len(resolved) < 2:
            raise ValueError(
                "At least two stages are required to measure timing deltas."
            )

        # Primary stage for protocol compliance
        self.stage = stage
        # Sorted by execution order — private profiled stages list.
        self._profiled_stages: list[Enum] = _sort_stages(resolved)
        self.frequency = frequency
        self.enable_nvtx = enable_nvtx
        self.timer_backend = timer_backend
        self.log_path = Path(log_path) if log_path is not None else None
        self.show_console = show_console
        self.console_frequency = console_frequency

        # Per-step scratch — separate dicts for type safety.
        self._current_step: int = -1
        self._step_cuda_events: dict[Enum, torch.cuda.Event] = {}
        self._step_cpu_timestamps: dict[Enum, int] = {}

        # Accumulated timing: transition endpoint -> list of delta_s.
        self.timings: dict[Enum, list[float]] = {s: [] for s in self._profiled_stages}
        # Reference origin for the "time since init" column in logs.
        self._t0_ns: int = time.perf_counter_ns()
        # Resolved lazily on the first _record() call, once the batch
        # device is known ("cuda_event" or "perf_counter").
        self._backend_resolved: str | None = None
        self._steps_recorded: int = 0

    # ------------------------------------------------------------------
    # Hook entry point
    # ------------------------------------------------------------------
    def _runs_on_stage(self, stage: Enum) -> bool:
        """Check if this hook should run on the given stage.

        Parameters
        ----------
        stage : Enum
            The stage to check.

        Returns
        -------
        bool
            True if this hook runs on the given stage.
        """
        return stage in set(self._profiled_stages)

    @torch.compiler.disable
    def _record(
        self,
        batch: Batch,
        current_stage: Enum,
        step_count: int,
        global_rank: int,
        domain: str = "dynamics",
    ) -> None:
        """Record a timestamp for the current stage.

        Parameters
        ----------
        batch : Batch
            The current batch of atomic data.
        current_stage : Enum
            The current dynamics stage being executed.
        step_count : int
            The current step number.
        global_rank : int
            The distributed rank of this process.
        domain : str, optional
            The domain for NVTX annotation (e.g., "dynamics", "custom").
            Default ``"dynamics"``.
        """
        # New step: flush the previous one, then reset scratch.
        if step_count != self._current_step:
            if self._current_step >= 0:
                self._flush_step(global_rank)
            self._current_step = step_count
            self._step_cuda_events.clear()
            self._step_cpu_timestamps.clear()

        # NVTX annotation. Any stage after the first closes the previous
        # stage's range before opening its own, so ranges tile the step.
        # NOTE(review): if a step records fewer than two stages,
        # _flush_step returns before its pop_range(), which would leave
        # one NVTX range unbalanced — confirm against expected call
        # patterns.
        if self.enable_nvtx and nvtx is not None:
            idx = self._profiled_stages.index(current_stage)
            if idx > 0:
                nvtx.pop_range()
            nvtx.push_range(f"{domain}/{current_stage.name}/{step_count}")

        # Timestamp. Backend is resolved once from the first batch's device.
        dev = batch.device
        if isinstance(dev, str):
            dev = torch.device(dev)
        if self._backend_resolved is None:
            self._backend_resolved = self._resolve_backend(dev)
        if self._backend_resolved == "cuda_event":
            # CUDA events measure GPU-side time without forcing a sync here;
            # synchronization is deferred to _flush_step.
            event = torch.cuda.Event(enable_timing=True)
            event.record()
            self._step_cuda_events[current_stage] = event
        else:
            self._step_cpu_timestamps[current_stage] = time.perf_counter_ns()

        # If this is the last profiled stage in the step, flush now.
        if current_stage == self._profiled_stages[-1]:
            self._flush_step(global_rank)
            self._current_step = -1
            self._step_cuda_events.clear()
            self._step_cpu_timestamps.clear()

    # plum dispatch selects the most specific overload: DynamicsStage
    # arguments hit the first method, any other Enum falls through to the
    # generic one with domain="custom".
    @dispatch
    def __call__(self, ctx: HookContext, stage: DynamicsStage) -> None:  # noqa: F811
        """Record timing for a dynamics stage."""
        self._record(
            ctx.batch, stage, ctx.step_count, ctx.global_rank or 0, domain="dynamics"
        )

    @dispatch
    def __call__(self, ctx: HookContext, stage: Enum) -> None:  # noqa: F811
        """Record timing for a generic stage."""
        self._record(
            ctx.batch, stage, ctx.step_count, ctx.global_rank or 0, domain="custom"
        )

    # ------------------------------------------------------------------
    # Backend resolution
    # ------------------------------------------------------------------
    def _resolve_backend(self, device: torch.device) -> str:
        """Resolve the timing backend based on configuration and device.

        An explicit ``timer_backend`` wins; ``"auto"`` picks ``cuda_event``
        on CUDA devices and ``perf_counter`` otherwise.
        """
        if self.timer_backend != "auto":
            return self.timer_backend
        if device.type == "cuda":
            return "cuda_event"
        return "perf_counter"

    # ------------------------------------------------------------------
    # Step flush — compute deltas, log
    # ------------------------------------------------------------------
    def _flush_step(self, rank: int) -> None:
        """Compute per-transition deltas for the current step and log."""
        use_cuda = self._backend_resolved == "cuda_event"
        # Keep only the stages that actually fired this step, in execution
        # order, so skipped stages don't break pairwise differencing.
        if use_cuda:
            ordered = [s for s in self._profiled_stages if s in self._step_cuda_events]
        else:
            ordered = [
                s for s in self._profiled_stages if s in self._step_cpu_timestamps
            ]
        if len(ordered) < 2:
            return

        if use_cuda:
            # One sync per step so all recorded events have completed
            # before elapsed_time() is queried.
            torch.cuda.synchronize()

        deltas: dict[Enum, float] = {}
        for i in range(1, len(ordered)):
            prev_stage, curr_stage = ordered[i - 1], ordered[i]
            if use_cuda:
                prev_ev = self._step_cuda_events[prev_stage]
                curr_ev = self._step_cuda_events[curr_stage]
                # Event.elapsed_time returns milliseconds; store seconds.
                delta_s = prev_ev.elapsed_time(curr_ev) / 1000.0
            else:
                prev_ts = self._step_cpu_timestamps[prev_stage]
                curr_ts = self._step_cpu_timestamps[curr_stage]
                delta_s = (curr_ts - prev_ts) / 1e9
            deltas[curr_stage] = delta_s
            self.timings[curr_stage].append(delta_s)

        t_since_init_s = (time.perf_counter_ns() - self._t0_ns) / 1e9
        self._steps_recorded += 1
        if self.log_path is not None:
            self._write_csv(rank, self._current_step, t_since_init_s, ordered, deltas)
        if self.show_console and (self._steps_recorded % self.console_frequency == 0):
            self._print_console(
                rank, self._current_step, t_since_init_s, ordered, deltas
            )
        # Close NVTX range for the last stage in this step.
        if self.enable_nvtx and nvtx is not None:
            nvtx.pop_range()

    # ------------------------------------------------------------------
    # CSV output
    # ------------------------------------------------------------------
    def _write_csv(
        self,
        rank: int,
        step: int,
        t_since_init: float,
        ordered: list[Enum],
        deltas: dict[Enum, float],
    ) -> None:
        """Append one row per transition to the CSV log."""
        rows = []
        for i, stage in enumerate(ordered[1:], start=1):
            rows.append(
                {
                    "rank": rank,
                    "step": step,
                    "stage": f"{ordered[i - 1].name}->{stage.name}",
                    "t_since_init_s": f"{t_since_init:.6f}",
                    "delta_s": f"{deltas[stage]:.6f}",
                }
            )
        # Lazily open the file on first write so construction has no
        # filesystem side effects. "w" mode truncates any previous log.
        if self._csv_writer is None:
            log_path = self.log_path
            if log_path is None:
                return
            fh = open(log_path, "w", newline="")  # noqa: SIM115
            self._csv_file = fh
            self._csv_writer = csv.DictWriter(
                fh,
                fieldnames=["rank", "step", "stage", "t_since_init_s", "delta_s"],
            )
            self._csv_writer.writeheader()
        self._csv_writer.writerows(rows)
        if self._csv_file is not None:
            # Flush per step so logs survive a crash mid-run.
            self._csv_file.flush()

    # ------------------------------------------------------------------
    # Console output
    # ------------------------------------------------------------------
    def _print_console(
        self,
        rank: int,
        step: int,
        t_since_init: float,
        ordered: list[Enum],
        deltas: dict[Enum, float],
    ) -> None:
        """Print a formatted timing table for the current step."""
        lines = [f"[Profiler] rank={rank} step={step} t={t_since_init:.3f}s"]
        for i, stage in enumerate(ordered[1:], start=1):
            prev_name = ordered[i - 1].name
            lines.append(
                f"  {prev_name} -> {stage.name}: {deltas[stage] * 1000:.3f} ms"
            )
        logger.info("\n".join(lines))

    # ------------------------------------------------------------------
    # Summary / reset / close
    # ------------------------------------------------------------------
    def summary(self) -> dict[str, dict[str, float]]:
        """Return per-transition timing statistics.

        Returns
        -------
        dict[str, dict[str, float]]
            Mapping from ``"PREV_STAGE->STAGE"`` label to a stats dict with
            keys ``mean_s``, ``std_s``, ``min_s``, ``max_s``, ``total_s``,
            ``n_samples``.
        """
        result: dict[str, dict[str, float]] = {}
        for idx, stage in enumerate(self._profiled_stages):
            # The first profiled stage never accumulates samples (it is
            # only ever a transition start), so it is skipped here and the
            # idx - 1 lookup below is safe for every emitted label.
            samples = self.timings[stage]
            if not samples:
                continue
            prev_name = self._profiled_stages[idx - 1].name
            label = f"{prev_name}->{stage.name}"
            n = len(samples)
            result[label] = {
                "mean_s": statistics.mean(samples),
                "std_s": statistics.stdev(samples) if n > 1 else 0.0,
                "min_s": min(samples),
                "max_s": max(samples),
                "total_s": sum(samples),
                "n_samples": float(n),
            }
        return result

    def reset(self) -> None:
        """Clear all accumulated timing data."""
        for stage in self.timings:
            self.timings[stage].clear()
        self._step_cuda_events.clear()
        self._step_cpu_timestamps.clear()
        self._current_step = -1
        # Backend is re-resolved from the next batch's device.
        self._backend_resolved = None
        self._t0_ns = time.perf_counter_ns()
        self._steps_recorded = 0

    def close(self) -> None:
        """Flush and close the CSV log file, if open."""
        if self._csv_file is not None:
            self._csv_file.close()
            self._csv_file = None
            self._csv_writer = None

    def __del__(self) -> None:
        # Best-effort cleanup; close() is idempotent and _csv_file is
        # initialized first in __init__, so this is safe even after a
        # constructor validation error.
        self.close()