# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Logging hook for recording per-sample simulation observables.
Provides :class:`LoggingHook`, which computes and logs per-graph scalar
statistics (energies, temperatures, max forces, etc.) at a configurable
frequency. Each graph in the batch is written as an individual row,
together with the current step and status (stage) information.
I/O is offloaded to a background thread via :class:`ThreadPoolExecutor`
so that file writes do not block the simulation loop. When the batch
resides on CUDA, scalar computation and the D2H transfer run on a
dedicated side stream (created in :meth:`__enter__`) so that they
overlap with ongoing compute on the default stream. The transfer uses
``non_blocking=True``; the worker thread calls
``stream.synchronize()`` before reading the resulting CPU tensor.
The hook implements the context manager protocol::
with LoggingHook(backend="csv", log_path="log.csv") as hook:
dynamics.register_hook(hook)
dynamics.run(batch)
# all pending writes flushed, files closed
"""
from __future__ import annotations
import contextlib
import csv
import io
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Literal, get_args
import torch
from tensordict import TensorDict
from nvalchemi.dynamics.hooks._base import _ObserverHook
from nvalchemi.dynamics.hooks._utils import (
scatter_reduce_per_graph,
temperature_per_graph,
)
if TYPE_CHECKING:
from collections.abc import Callable
from pathlib import Path
from nvalchemi.data import Batch
from nvalchemi.dynamics.base import BaseDynamics
# Boltzmann constant in eV/K — consistent with typical atomistic MD unit
# systems (positions in Å, masses in amu, velocities in Å/fs, energy in eV).
_KB_EV_PER_K: float = 8.617333262e-5
__all__ = ["LoggingHook"]
LogBackend = Literal["csv", "tensorboard", "custom"]
[docs]
class LoggingHook(_ObserverHook):
"""Log per-sample scalar observables from the simulation.
At each firing step, this hook computes per-graph scalars from the
:class:`~nvalchemi.data.Batch` and writes **one row per graph** to the
configured logging backend. Each row includes:
* **step** — the current ``dynamics.step_count``.
* **graph_idx** — the graph's index within the batch.
* **status** — the sample's status code (from ``batch.status``),
indicating which pipeline stage it belongs to. Always ``0`` for
single-stage dynamics.
* **energy** — per-graph potential energy (from ``batch.energies``).
* **fmax** — per-graph maximum atomic force norm.
* **temperature** — per-graph instantaneous kinetic temperature
(from ``batch.velocities`` and ``batch.atomic_masses`` via the
equipartition theorem), if velocities are present.
Users can extend or replace this set by providing a ``custom_scalars``
mapping of ``{name: callable}`` pairs, where each callable has
signature ``(batch, dynamics) -> Tensor`` of shape ``(B,)`` (one
value per graph) or a plain ``float`` (broadcast to all graphs).
**Asynchronous I/O.** Scalar computation and the GPU-to-CPU transfer
run on a dedicated CUDA side stream (set up in :meth:`__enter__`) so
they do not stall the default compute stream. The D2H copy uses
``non_blocking=True``; the single-worker
:class:`ThreadPoolExecutor` synchronizes the stream before reading
the CPU tensor and writing to the backend.
**Context manager.** Use ``with`` to guarantee the CUDA stream is
created, all pending writes are flushed, and file handles are
closed on exit::
with LoggingHook(backend="csv", log_path="out.csv") as hook:
dynamics.register_hook(hook)
dynamics.run(batch)
This hook is non-blocking: scalar computation and the D2H transfer
run on a dedicated CUDA side stream (when available), and all
``.item()`` calls and I/O happen in a background
:class:`~concurrent.futures.ThreadPoolExecutor` worker, so the
GPU pipeline is not stalled.
Parameters
----------
backend : {"csv", "tensorboard", "custom"}
Logging backend to use.
frequency : int, optional
Log every ``frequency`` steps. Default ``1``.
log_path : str | Path | None, optional
File path for file-based backends (``"csv"``,
``"tensorboard"``). Default ``None``.
custom_scalars : dict[str, Callable] | None, optional
Additional named scalars to compute and log. Each callable
receives ``(batch, dynamics)`` and returns either a ``(B,)``
tensor (per-graph values) or a ``float`` (broadcast to all
graphs). Name collisions override defaults. Default ``None``.
writer_fn : Callable[[int, list[dict[str, float]]], None] | None, optional
Custom writer function, required when ``backend="custom"``.
Receives ``(step_count, rows)``. Default ``None``.
Examples
--------
>>> from nvalchemi.dynamics.hooks import LoggingHook
>>> with LoggingHook(backend="csv", log_path="md_log.csv", frequency=100) as hook:
... dynamics = DemoDynamics(model=model, n_steps=10_000, dt=0.5, hooks=[hook])
... dynamics.run(batch)
Using custom scalars:
>>> def pressure(batch, dynamics):
... return compute_pressure(batch.stresses, batch.cell)
>>> hook = LoggingHook(
... backend="csv",
... log_path="md_log.csv",
... frequency=50,
... custom_scalars={"pressure": pressure},
... )
Notes
-----
* The default temperature calculation assumes an NVT-like system
with ``3N`` degrees of freedom (no constraint correction).
Override via ``custom_scalars`` if constraints remove DOFs.
* For distributed pipelines, each rank logs independently. Use
``log_path`` with rank-specific filenames to avoid file
contention.
"""
[docs]
def __init__(
self,
backend: LogBackend,
frequency: int = 1,
log_path: str | Path | None = None,
custom_scalars: (
dict[str, Callable[[Batch, BaseDynamics], float | torch.Tensor]] | None
) = None,
writer_fn: (Callable[[int, list[dict[str, float]]], None] | None) = None,
) -> None:
super().__init__(frequency=frequency)
self._csv_file: io.TextIOWrapper | None = None
self._csv_writer: csv.DictWriter | None = None
self._tb_writer = None
self._stream: torch.cuda.Stream | None = None
self._executor = ThreadPoolExecutor(max_workers=1)
if backend == "csv" and log_path is None:
raise ValueError("csv backend requires log_path")
if backend == "tensorboard" and log_path is None:
raise ValueError("tensorboard backend requires log_path")
if backend == "custom" and writer_fn is None:
raise ValueError("custom backend requires writer_fn")
if backend not in get_args(LogBackend):
raise ValueError(
f"LoggingHook only supports backends: "
f"{get_args(LogBackend)}, got {backend!r}."
)
self.backend = backend
self.log_path = log_path
self.custom_scalars = custom_scalars
self.writer_fn = writer_fn
# ------------------------------------------------------------------
# Context manager
# ------------------------------------------------------------------
def __enter__(self) -> LoggingHook:
if torch.cuda.is_available():
self._stream = torch.cuda.Stream()
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None: # noqa: ANN001
self.close()
return None
def close(self) -> None:
"""Flush pending writes and release stream / file resources."""
self._executor.shutdown(wait=True)
self._executor = ThreadPoolExecutor(max_workers=1)
if self._stream is not None:
self._stream.synchronize()
self._stream = None
if self._csv_file is not None:
self._csv_file.close()
self._csv_file = None
self._csv_writer = None
if self._tb_writer is not None:
self._tb_writer.close()
self._tb_writer = None
def __del__(self) -> None:
self._executor.shutdown(wait=True)
# ------------------------------------------------------------------
# Hook entry point
# ------------------------------------------------------------------
@torch.compiler.disable
def __call__(self, batch: Batch, dynamics: BaseDynamics) -> None:
"""Compute per-graph scalars and dispatch to the logging backend.
When a CUDA stream has been set up (via :meth:`__enter__`), the
column computation and ``non_blocking`` D2H transfer run on that
side stream so they do not stall the default compute stream.
"""
device = batch.device
use_stream = self._stream is not None and device.type == "cuda"
if use_stream:
main_stream = torch.cuda.current_stream(device)
ctx = torch.cuda.stream(self._stream)
else:
ctx = contextlib.nullcontext()
with ctx:
if use_stream:
self._stream.wait_stream(main_stream)
td = self._compute_columns(batch, dynamics)
td = td.to("cpu", non_blocking=True)
self._executor.submit(self._dispatch, td, dynamics.step_count)
# ------------------------------------------------------------------
# Column computation (on-device)
# ------------------------------------------------------------------
def _compute_columns(
self,
batch: Batch,
dynamics: BaseDynamics,
) -> TensorDict:
"""Build a ``TensorDict(batch_size=[B])`` of per-graph scalars."""
dev = batch.device
num_graphs = batch.num_graphs
td = TensorDict(
step=torch.full(
(num_graphs,), dynamics.step_count, device=dev, dtype=torch.int64
),
graph_idx=torch.arange(num_graphs, device=dev, dtype=torch.int64),
batch_size=[num_graphs],
)
# Status (stage code) — 0 if not present
if hasattr(batch, "status") and batch.status is not None:
status = (
batch.status.squeeze(-1) if batch.status.dim() == 2 else batch.status
)
td.set("status", status.float())
else:
td.set("status", torch.zeros(num_graphs, device=dev))
if batch.energies is not None:
td.set("energy", batch.energies.squeeze(-1))
if batch.forces is not None:
norms = torch.linalg.vector_norm(batch.forces, dim=-1)
td.set(
"fmax",
scatter_reduce_per_graph(norms, batch.batch, num_graphs, reduce="amax"),
)
if getattr(batch, "velocities", None) is not None:
td.set(
"temperature",
temperature_per_graph(
batch.velocities,
batch.atomic_masses,
batch.batch,
num_graphs,
batch.num_nodes_per_graph,
),
)
if self.custom_scalars:
for name, fn in self.custom_scalars.items():
val = fn(batch, dynamics)
if isinstance(val, torch.Tensor):
td.set(name, val)
else:
td.set(name, torch.full((num_graphs,), val, device=dev))
return td
# ------------------------------------------------------------------
# Dispatch (runs on ThreadPoolExecutor worker)
# ------------------------------------------------------------------
def _dispatch(self, td: TensorDict, step: int) -> None:
"""Synchronize stream (if needed), convert to rows, write."""
if self._stream is not None:
self._stream.synchronize()
rows = _tensordict_to_rows(td)
match self.backend:
case "csv":
self._write_csv(rows)
case "tensorboard":
self._write_tensorboard(rows, step)
case "custom":
self.writer_fn(step, rows) # type: ignore[misc]
def _write_csv(self, rows: list[dict[str, float]]) -> None:
"""Convert row data and write out CSV"""
if self._csv_writer is None:
self._csv_file = open(self.log_path, "w", newline="") # noqa: SIM115
self._csv_writer = csv.DictWriter(
self._csv_file,
fieldnames=list(rows[0].keys()),
)
self._csv_writer.writeheader()
self._csv_writer.writerows(rows)
self._csv_file.flush() # type: ignore[union-attr]
def _write_tensorboard(
self,
rows: list[dict[str, float]],
step: int,
) -> None:
if self._tb_writer is None:
from torch.utils.tensorboard import SummaryWriter
self._tb_writer = SummaryWriter(log_dir=str(self.log_path))
for row in rows:
graph_idx = int(row.get("graph_idx", 0))
for key, value in row.items():
if key in ("step", "graph_idx"):
continue
tag = f"{key}/graph_{graph_idx}" if len(rows) > 1 else key
self._tb_writer.add_scalar(tag, value, step)
def _tensordict_to_rows(td: TensorDict) -> list[dict[str, float]]:
"""Convert a ``TensorDict(batch_size=[B])`` into a list of dicts."""
cols = {k: td[k].tolist() for k in td.keys()}
num_graphs = td.batch_size[0]
keys = list(cols.keys())
return [{k: float(cols[k][i]) for k in keys} for i in range(num_graphs)]