# Source code for nvalchemi.dynamics.hooks.profiling

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Per-stage wall-clock profiling for dynamics simulations.

Provides :class:`ProfilerHook`, a single hook that registers at multiple
stages and records the elapsed time between consecutive stages at each
step.  Supports NVTX range annotations for Nsight Systems, CSV logging,
and formatted console output via ``loguru``.
"""

from __future__ import annotations

import csv
import io
import statistics
import time
from pathlib import Path
from typing import TYPE_CHECKING, Literal

import torch
from loguru import logger

from nvalchemi.dynamics.base import HookStageEnum

if TYPE_CHECKING:
    from nvalchemi.data import Batch
    from nvalchemi.dynamics.base import BaseDynamics

try:
    import nvtx
except ImportError:
    nvtx = None

__all__ = ["ProfilerHook"]

# HookStageEnum members in execution order.  The numeric values of the
# enum are monotonically increasing with execution order.
_STAGE_ORDER: list[HookStageEnum] = sorted(HookStageEnum, key=lambda s: s.value)


[docs] class ProfilerHook: """Per-stage timing hook for dynamics simulations. A single ``ProfilerHook`` instance registers itself at every requested stage. On each call it records a timestamp; when the last profiled stage in a step fires, it computes the elapsed time between consecutive stages and (optionally) writes to CSV / console. The hook uses ``stages`` (plural) so that :meth:`~nvalchemi.dynamics.base.BaseDynamics.register_hook` registers it at all listed stages in one call. Parameters ---------- stages : set[HookStageEnum] | {"all", "step", "detailed"} Which stages to instrument. * ``"all"`` (default): every stage except ``ON_CONVERGE``. * ``"step"``: ``BEFORE_STEP`` and ``AFTER_STEP`` only. * ``"detailed"``: all stages from ``BEFORE_STEP`` through ``AFTER_STEP`` (excluding ``ON_CONVERGE``). * A custom ``set[HookStageEnum]`` for fine-grained control. frequency : int, optional Profile every ``frequency`` steps. Default ``1``. enable_nvtx : bool, optional Emit NVTX push/pop ranges for Nsight Systems. Default ``True``. timer_backend : {"cuda_event", "perf_counter", "auto"}, optional Timing backend. ``"auto"`` selects ``cuda_event`` on GPU devices and ``perf_counter`` on CPU. Default ``"auto"``. log_path : str | Path | None, optional Path to a CSV file for persistent timing logs. Each row records the rank, step, stage transition, wall-clock offset, and delta. Default ``None`` (no file). show_console : bool, optional Print a formatted timing table via ``loguru`` at each profiled step. Default ``False``. console_frequency : int, optional When ``show_console`` is ``True``, print every ``console_frequency`` profiled steps. Default ``1``. Attributes ---------- stages : list[HookStageEnum] Profiled stages in execution order (used by ``register_hook``). frequency : int Execution frequency in steps. timings : dict[HookStageEnum, list[float]] Accumulated per-transition timing data (seconds). 
Examples -------- >>> from nvalchemi.dynamics.hooks import ProfilerHook >>> profiler = ProfilerHook() >>> dynamics = DemoDynamics(model=model, n_steps=100, dt=0.5, hooks=[profiler]) >>> dynamics.run(batch) >>> print(profiler.summary()) With CSV logging and console output: >>> profiler = ProfilerHook( ... "detailed", ... log_path="profiler.csv", ... show_console=True, ... console_frequency=10, ... ) >>> dynamics = DemoDynamics(model=model, n_steps=1000, dt=0.5, hooks=[profiler]) >>> dynamics.run(batch) """
[docs] def __init__( self, stages: set[HookStageEnum] | Literal["all", "step", "detailed"] = "all", *, frequency: int = 1, enable_nvtx: bool = True, timer_backend: Literal["cuda_event", "perf_counter", "auto"] = "auto", log_path: str | Path | None = None, show_console: bool = False, console_frequency: int = 1, ) -> None: # Init file handle early so __del__ is safe on validation errors. self._csv_file: io.TextIOWrapper | None = None self._csv_writer: csv.DictWriter | None = None S = HookStageEnum if isinstance(stages, str): if stages == "all": resolved = {s for s in S if s != S.ON_CONVERGE} elif stages == "step": resolved = {S.BEFORE_STEP, S.AFTER_STEP} elif stages == "detailed": resolved = { S.BEFORE_STEP, S.BEFORE_PRE_UPDATE, S.AFTER_PRE_UPDATE, S.BEFORE_COMPUTE, S.AFTER_COMPUTE, S.BEFORE_POST_UPDATE, S.AFTER_POST_UPDATE, S.AFTER_STEP, } else: raise ValueError( f"Unknown stages preset {stages!r}. " f"Use 'all', 'step', 'detailed', or a set of HookStageEnum." ) else: resolved = set(stages) if len(resolved) < 2: raise ValueError( "At least two stages are required to measure timing deltas." ) # Sorted by execution order — used by register_hook. self.stages: list[HookStageEnum] = [s for s in _STAGE_ORDER if s in resolved] self.frequency = frequency self.enable_nvtx = enable_nvtx self.timer_backend = timer_backend self.log_path = Path(log_path) if log_path is not None else None self.show_console = show_console self.console_frequency = console_frequency # Per-step scratch — separate dicts for type safety. self._current_step: int = -1 self._step_cuda_events: dict[HookStageEnum, torch.cuda.Event] = {} self._step_cpu_timestamps: dict[HookStageEnum, int] = {} # Accumulated timing: transition endpoint -> list of delta_s. self.timings: dict[HookStageEnum, list[float]] = {s: [] for s in self.stages} self._t0_ns: int = time.perf_counter_ns() self._backend_resolved: str | None = None self._steps_recorded: int = 0
# ------------------------------------------------------------------ # Hook entry point # ------------------------------------------------------------------ @torch.compiler.disable def __call__(self, batch: Batch, dynamics: BaseDynamics) -> None: """Record a timestamp for the current stage. Parameters ---------- batch : Batch The current batch of atomic data. dynamics : BaseDynamics The dynamics engine instance. """ stage: HookStageEnum = dynamics.current_hook_stage # type: ignore[assignment] step = dynamics.step_count # New step: flush the previous one, then reset scratch. if step != self._current_step: if self._current_step >= 0: self._flush_step(dynamics.global_rank) self._current_step = step self._step_cuda_events.clear() self._step_cpu_timestamps.clear() # NVTX annotation. if self.enable_nvtx and nvtx is not None: idx = self.stages.index(stage) if idx > 0: nvtx.pop_range() nvtx.push_range(f"{stage.name}/{step}") # Timestamp. dev = batch.device if isinstance(dev, str): dev = torch.device(dev) if self._backend_resolved is None: self._backend_resolved = self._resolve_backend(dev) if self._backend_resolved == "cuda_event": event = torch.cuda.Event(enable_timing=True) event.record() self._step_cuda_events[stage] = event else: self._step_cpu_timestamps[stage] = time.perf_counter_ns() # If this is the last profiled stage in the step, flush now. 
if stage == self.stages[-1]: self._flush_step(dynamics.global_rank) self._current_step = -1 self._step_cuda_events.clear() self._step_cpu_timestamps.clear() # ------------------------------------------------------------------ # Backend resolution # ------------------------------------------------------------------ def _resolve_backend(self, device: torch.device) -> str: """Resolve the timing backend based on configuration and device.""" if self.timer_backend != "auto": return self.timer_backend if device.type == "cuda": return "cuda_event" return "perf_counter" # ------------------------------------------------------------------ # Step flush — compute deltas, log # ------------------------------------------------------------------ def _flush_step(self, rank: int) -> None: """Compute per-transition deltas for the current step and log.""" use_cuda = self._backend_resolved == "cuda_event" if use_cuda: ordered = [s for s in self.stages if s in self._step_cuda_events] else: ordered = [s for s in self.stages if s in self._step_cpu_timestamps] if len(ordered) < 2: return if use_cuda: torch.cuda.synchronize() deltas: dict[HookStageEnum, float] = {} for i in range(1, len(ordered)): prev_stage, curr_stage = ordered[i - 1], ordered[i] if use_cuda: prev_ev = self._step_cuda_events[prev_stage] curr_ev = self._step_cuda_events[curr_stage] delta_s = prev_ev.elapsed_time(curr_ev) / 1000.0 else: prev_ts = self._step_cpu_timestamps[prev_stage] curr_ts = self._step_cpu_timestamps[curr_stage] delta_s = (curr_ts - prev_ts) / 1e9 deltas[curr_stage] = delta_s self.timings[curr_stage].append(delta_s) t_since_init_s = (time.perf_counter_ns() - self._t0_ns) / 1e9 self._steps_recorded += 1 if self.log_path is not None: self._write_csv(rank, self._current_step, t_since_init_s, ordered, deltas) if self.show_console and (self._steps_recorded % self.console_frequency == 0): self._print_console( rank, self._current_step, t_since_init_s, ordered, deltas ) # Close NVTX range for the last stage in 
this step. if self.enable_nvtx and nvtx is not None: nvtx.pop_range() # ------------------------------------------------------------------ # CSV output # ------------------------------------------------------------------ def _write_csv( self, rank: int, step: int, t_since_init: float, ordered: list[HookStageEnum], deltas: dict[HookStageEnum, float], ) -> None: """Append one row per transition to the CSV log.""" rows = [] for i, stage in enumerate(ordered[1:], start=1): rows.append( { "rank": rank, "step": step, "stage": f"{ordered[i - 1].name}->{stage.name}", "t_since_init_s": f"{t_since_init:.6f}", "delta_s": f"{deltas[stage]:.6f}", } ) if self._csv_writer is None: log_path = self.log_path if log_path is None: return fh = open(log_path, "w", newline="") # noqa: SIM115 self._csv_file = fh self._csv_writer = csv.DictWriter( fh, fieldnames=["rank", "step", "stage", "t_since_init_s", "delta_s"], ) self._csv_writer.writeheader() self._csv_writer.writerows(rows) if self._csv_file is not None: self._csv_file.flush() # ------------------------------------------------------------------ # Console output # ------------------------------------------------------------------ def _print_console( self, rank: int, step: int, t_since_init: float, ordered: list[HookStageEnum], deltas: dict[HookStageEnum, float], ) -> None: """Print a formatted timing table for the current step.""" lines = [f"[Profiler] rank={rank} step={step} t={t_since_init:.3f}s"] for i, stage in enumerate(ordered[1:], start=1): prev_name = ordered[i - 1].name lines.append( f" {prev_name} -> {stage.name}: {deltas[stage] * 1000:.3f} ms" ) logger.info("\n".join(lines)) # ------------------------------------------------------------------ # Summary / reset / close # ------------------------------------------------------------------ def summary(self) -> dict[str, dict[str, float]]: """Return per-transition timing statistics. 
Returns ------- dict[str, dict[str, float]] Mapping from ``"PREV_STAGE->STAGE"`` label to a stats dict with keys ``mean_s``, ``std_s``, ``min_s``, ``max_s``, ``total_s``, ``n_samples``. """ result: dict[str, dict[str, float]] = {} for idx, stage in enumerate(self.stages): samples = self.timings[stage] if not samples: continue prev_name = self.stages[idx - 1].name label = f"{prev_name}->{stage.name}" n = len(samples) result[label] = { "mean_s": statistics.mean(samples), "std_s": statistics.stdev(samples) if n > 1 else 0.0, "min_s": min(samples), "max_s": max(samples), "total_s": sum(samples), "n_samples": float(n), } return result def reset(self) -> None: """Clear all accumulated timing data.""" for stage in self.timings: self.timings[stage].clear() self._step_cuda_events.clear() self._step_cpu_timestamps.clear() self._current_step = -1 self._backend_resolved = None self._t0_ns = time.perf_counter_ns() self._steps_recorded = 0 def close(self) -> None: """Flush and close the CSV log file, if open.""" if self._csv_file is not None: self._csv_file.close() self._csv_file = None self._csv_writer = None def __del__(self) -> None: self.close()