# Source code for ``physicsnemo_curator.core.pipeline_store``

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unified pipeline store with SQLite-backed metrics, checkpointing, and provenance.

Provides :class:`PipelineStore`, a single SQLite database that combines
checkpoint tracking (completed/failed indices), per-index and per-stage
wall-clock metrics, and pipeline provenance (config hashing).

Also contains the metrics dataclasses (:class:`StageMetrics`,
:class:`IndexMetrics`, :class:`PipelineMetrics`), the :class:`_TimedGenerator`
timing utility, and provenance helpers for serializing pipeline configuration.

Usage
-----
>>> from physicsnemo_curator.core.pipeline_store import PipelineStore
>>> config = _pipeline_config(pipeline)
>>> chash = _config_hash(config)
>>> store = PipelineStore(db_path=Path("run.db"), pipeline_config=config, config_hash=chash)
>>> store.is_completed(0)  # None — not yet completed
>>> store.record_success(0, ["/out/0.vtk"], wall_time_ns=1_000_000, ...)
"""

from __future__ import annotations

import csv
import hashlib
import inspect
import json
import logging
import pathlib
import sqlite3
import threading
import time
import uuid
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any

from physicsnemo_curator.core.base import REQUIRED

if TYPE_CHECKING:
    from collections.abc import Iterator

    from physicsnemo_curator.core.base import Pipeline

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Metrics dataclasses
# ---------------------------------------------------------------------------


@dataclass
class StageMetrics:
    """Wall-clock timing record for one pipeline stage.

    A stage is the source, a single filter, or the sink.

    Parameters
    ----------
    name : str
        Human-readable stage name (e.g. ``"source"``, ``"DoubleFilter"``,
        ``"sink"``).
    wall_time_ns : int
        Wall-clock nanoseconds spent inside this stage.
    """

    name: str
    wall_time_ns: int

    def to_dict(self) -> dict[str, Any]:
        """Serialize this record to a plain dictionary.

        Returns
        -------
        dict[str, Any]
            Mapping with the ``"name"`` and ``"wall_time_ns"`` keys.
        """
        return {
            "name": self.name,
            "wall_time_ns": self.wall_time_ns,
        }
@dataclass
class IndexMetrics:
    """Metrics captured for one ``__getitem__`` call (one source index).

    Parameters
    ----------
    index : int
        The source index that was processed.
    stages : list[StageMetrics]
        Per-stage timing breakdown.
    wall_time_ns : int
        Total wall-clock time for this index in nanoseconds.
    peak_memory_bytes : int
        Peak Python memory usage during this index (from ``tracemalloc``).
    gpu_memory_bytes : int | None
        Peak GPU memory delta, or ``None`` if GPU tracking was disabled.
    """

    index: int
    stages: list[StageMetrics]
    wall_time_ns: int
    peak_memory_bytes: int
    gpu_memory_bytes: int | None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a nested plain dictionary.

        Returns
        -------
        dict[str, Any]
            Nested dictionary containing every metric field.
        """
        return {
            "index": self.index,
            "stages": [stage.to_dict() for stage in self.stages],
            "wall_time_ns": self.wall_time_ns,
            "peak_memory_bytes": self.peak_memory_bytes,
            "gpu_memory_bytes": self.gpu_memory_bytes,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> IndexMetrics:
        """Rebuild an instance from a dictionary (e.g. deserialized JSON).

        Parameters
        ----------
        data : dict[str, Any]
            Dictionary in the shape produced by :meth:`to_dict`.

        Returns
        -------
        IndexMetrics
            Reconstructed metrics object.
        """
        return cls(
            index=data["index"],
            stages=[StageMetrics(**entry) for entry in data["stages"]],
            wall_time_ns=data["wall_time_ns"],
            peak_memory_bytes=data["peak_memory_bytes"],
            # Absent key means GPU tracking was off for that run.
            gpu_memory_bytes=data.get("gpu_memory_bytes"),
        )
@dataclass
class PipelineMetrics:
    """Aggregated metrics across all processed indices.

    Parameters
    ----------
    indices : list[IndexMetrics]
        Per-index metrics, one entry per ``__getitem__`` call.
    """

    indices: list[IndexMetrics] = field(default_factory=list)

    @property
    def total_wall_time_ns(self) -> int:
        """Total wall-clock time across all indices (nanoseconds).

        Returns
        -------
        int
            Sum of per-index wall times.
        """
        return sum(m.wall_time_ns for m in self.indices)

    @property
    def mean_index_time_ns(self) -> float:
        """Mean wall-clock time per index (nanoseconds).

        Returns
        -------
        float
            Average per-index time, or ``0.0`` if no indices were processed.
        """
        if not self.indices:
            return 0.0
        return self.total_wall_time_ns / len(self.indices)

    @property
    def total_peak_memory_bytes(self) -> int:
        """Maximum peak memory observed across all indices (bytes).

        Returns
        -------
        int
            Max of per-index peak memory values, or ``0`` if empty.
        """
        if not self.indices:
            return 0
        return max(m.peak_memory_bytes for m in self.indices)

    def summary(self) -> dict[str, Any]:
        """Return a summary dictionary for programmatic use.

        Returns
        -------
        dict[str, Any]
            Dictionary with total/mean wall time, peak memory, index count,
            and per-index breakdowns.
        """
        return {
            "num_indices": len(self.indices),
            "total_wall_time_ns": self.total_wall_time_ns,
            "mean_index_time_ns": self.mean_index_time_ns,
            "total_peak_memory_bytes": self.total_peak_memory_bytes,
            "indices": [m.to_dict() for m in self.indices],
        }

    def to_json(self, path: str | pathlib.Path) -> None:
        """Write metrics to a JSON file (parent directories are created).

        Parameters
        ----------
        path : str | pathlib.Path
            Output file path.
        """
        data = self.summary()
        p = pathlib.Path(path)
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_text(json.dumps(data, indent=2))

    def to_csv(self, path: str | pathlib.Path) -> None:
        """Write per-index metrics to a CSV file.

        Each row represents one index. Stage timings are included as
        separate columns named ``stage_<name>_ns``.

        Parameters
        ----------
        path : str | pathlib.Path
            Output file path. Parent directories are created if missing.
        """
        p = pathlib.Path(path)
        # Bug fix: create parent dirs in the empty case too, so behavior is
        # consistent with the populated branch and with to_json.
        p.parent.mkdir(parents=True, exist_ok=True)
        if not self.indices:
            p.write_text("")
            return
        # Collect all unique stage names across indices (preserving order),
        # since different indices may have recorded different stages.
        stage_names: list[str] = []
        seen: set[str] = set()
        for idx_m in self.indices:
            for s in idx_m.stages:
                if s.name not in seen:
                    stage_names.append(s.name)
                    seen.add(s.name)
        fieldnames = [
            "index",
            "wall_time_ns",
            "peak_memory_bytes",
            "gpu_memory_bytes",
        ] + [f"stage_{name}_ns" for name in stage_names]
        with p.open("w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            for idx_m in self.indices:
                row: dict[str, Any] = {
                    "index": idx_m.index,
                    "wall_time_ns": idx_m.wall_time_ns,
                    "peak_memory_bytes": idx_m.peak_memory_bytes,
                    # Empty cell (not "None") when GPU tracking was disabled.
                    "gpu_memory_bytes": idx_m.gpu_memory_bytes
                    if idx_m.gpu_memory_bytes is not None
                    else "",
                }
                stage_map = {s.name: s.wall_time_ns for s in idx_m.stages}
                for sn in stage_names:
                    row[f"stage_{sn}_ns"] = stage_map.get(sn, "")
                writer.writerow(row)

    def to_console(self) -> None:
        """Print a formatted summary table to stdout.

        Outputs a human-readable table showing per-index and aggregate
        metrics. Uses only stdlib formatting (no external dependencies).
        """
        if not self.indices:
            print("No profiling metrics collected.")
            return
        print("\n=== Pipeline Profiling Results ===\n")
        # Aggregate summary, converted to human-friendly units.
        total_ms = self.total_wall_time_ns / 1e6
        mean_ms = self.mean_index_time_ns / 1e6
        peak_mb = self.total_peak_memory_bytes / (1024 * 1024)
        print(f"  Indices processed  : {len(self.indices)}")
        print(f"  Total wall time    : {total_ms:,.2f} ms")
        print(f"  Mean per index     : {mean_ms:,.2f} ms")
        print(f"  Peak memory        : {peak_mb:,.2f} MB")
        # GPU line only when at least one index tracked GPU memory.
        gpu_indices = [m for m in self.indices if m.gpu_memory_bytes is not None]
        if gpu_indices:
            max_gpu = max(m.gpu_memory_bytes for m in gpu_indices)  # type: ignore[arg-type]
            print(f"  Peak GPU memory    : {max_gpu / (1024 * 1024):,.2f} MB")
        # Per-index table
        print(f"\n{'Index':>7} {'Wall (ms)':>12} {'Memory (MB)':>13} {'GPU (MB)':>10}")
        print("  " + "-" * 46)
        for m in self.indices:
            wall = m.wall_time_ns / 1e6
            mem = m.peak_memory_bytes / (1024 * 1024)
            gpu = (
                f"{m.gpu_memory_bytes / (1024 * 1024):>10.2f}"
                if m.gpu_memory_bytes is not None
                else " N/A"
            )
            print(f"  {m.index:>5} {wall:>12.2f} {mem:>13.2f} {gpu}")
        # Per-stage averages across all indices.
        if self.indices and self.indices[0].stages:
            print("\n  Stage Averages:")
            stage_totals: dict[str, list[int]] = {}
            for idx_m in self.indices:
                for s in idx_m.stages:
                    stage_totals.setdefault(s.name, []).append(s.wall_time_ns)
            for name, times in stage_totals.items():
                avg_ms = (sum(times) / len(times)) / 1e6
                print(f"    {name:<30s} {avg_ms:>10.2f} ms (avg)")
        print()
# --------------------------------------------------------------------------- # _TimedGenerator # --------------------------------------------------------------------------- class _TimedGenerator[T]: """Generator wrapper that accumulates wall-clock time across ``__next__`` calls. This is used internally to attribute time to each pipeline stage. The wrapper preserves the full iterator protocol. Parameters ---------- inner : Iterator[T] The generator or iterator to wrap. """ def __init__(self, inner: Iterator[T]) -> None: """Initialize with the inner iterator.""" self._inner = inner self._elapsed_ns: int = 0 @property def elapsed_ns(self) -> int: """Total nanoseconds spent inside ``__next__`` of the inner iterator. Returns ------- int Accumulated wall-clock nanoseconds. """ return self._elapsed_ns def __iter__(self) -> _TimedGenerator[T]: """Return self (iterator protocol).""" return self def __next__(self) -> T: """Delegate to inner iterator, timing the call. Returns ------- T Next value from the inner iterator. Raises ------ StopIteration When the inner iterator is exhausted. """ start = time.perf_counter_ns() try: value = next(self._inner) except StopIteration: self._elapsed_ns += time.perf_counter_ns() - start raise self._elapsed_ns += time.perf_counter_ns() - start return value # --------------------------------------------------------------------------- # Provenance helpers # --------------------------------------------------------------------------- def _component_config(component: object) -> dict[str, Any]: """Serialize a pipeline component (source, filter, or sink) to a dict. Parameters ---------- component : object A pipeline component with ``name``, ``params()`` classmethod, and ``__init__`` arguments. Returns ------- dict[str, Any] Serialized configuration dictionary. 
""" cls = type(component) config: dict[str, Any] = { "class": cls.__name__, "module": cls.__module__, } if hasattr(cls, "name"): config["name"] = cls.name if hasattr(cls, "description"): config["description"] = cls.description # Extract current parameter values from __init__ signature sig = inspect.signature(cls.__init__) params: dict[str, Any] = {} for pname, param in sig.parameters.items(): if pname == "self": continue # Try to read the stored attribute (common convention: _<name>) for attr_name in (f"_{pname}", pname): if hasattr(component, attr_name): val = getattr(component, attr_name) params[pname] = _safe_serialize(val) break else: # Fall back to default if available if param.default is not inspect.Parameter.empty: params[pname] = _safe_serialize(param.default) config["params"] = params return config def _safe_serialize(value: object) -> object: """Convert a value to a JSON-safe type. Parameters ---------- value : object Any Python object. Returns ------- object JSON-safe representation. """ if value is REQUIRED: return "<REQUIRED>" if isinstance(value, (str, int, float, bool, type(None))): return value if isinstance(value, (list, tuple)): return [_safe_serialize(v) for v in value] if isinstance(value, dict): return {str(k): _safe_serialize(v) for k, v in value.items()} # pathlib.Path, other types -> string return str(value) def _pipeline_config(pipeline: Pipeline[Any]) -> dict[str, Any]: """Build the full pipeline configuration dictionary. Parameters ---------- pipeline : Pipeline The pipeline to serialize. Returns ------- dict[str, Any] Full pipeline configuration. """ config: dict[str, Any] = { "source": _component_config(pipeline.source), "filters": [_component_config(f) for f in pipeline.filters], } if pipeline.sink is not None: config["sink"] = _component_config(pipeline.sink) return config def _config_hash(config: dict[str, Any]) -> str: """Compute a stable SHA-256 hash of a pipeline configuration. 
Parameters ---------- config : dict[str, Any] Pipeline configuration dict. Returns ------- str Hex-encoded SHA-256 hash. """ blob = json.dumps(config, sort_keys=True, default=str).encode() return hashlib.sha256(blob).hexdigest() # --------------------------------------------------------------------------- # Worker identity # --------------------------------------------------------------------------- _worker_id_local = threading.local() def _get_worker_id() -> str: """Return a stable worker ID for the current thread. Each thread (and by extension each process in multi-process backends) gets a unique UUID4 identifier that is stable across multiple calls within the same thread. Returns ------- str UUID4 hex string identifying the current worker thread. """ wid: str | None = getattr(_worker_id_local, "worker_id", None) if wid is None: wid = uuid.uuid4().hex _worker_id_local.worker_id = wid return wid # --------------------------------------------------------------------------- # SQL schema # --------------------------------------------------------------------------- _SCHEMA_SQL = """\ CREATE TABLE IF NOT EXISTS pipeline_runs ( run_id INTEGER PRIMARY KEY AUTOINCREMENT, config_hash TEXT UNIQUE NOT NULL, config_json TEXT NOT NULL, started_at TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS index_results ( idx INTEGER NOT NULL, run_id INTEGER NOT NULL, status TEXT NOT NULL CHECK (status IN ('completed', 'error')), output_paths TEXT, completed_at TEXT NOT NULL, wall_time_ns INTEGER, peak_memory_bytes INTEGER, gpu_memory_bytes INTEGER, error TEXT, PRIMARY KEY (idx, run_id), FOREIGN KEY (run_id) REFERENCES pipeline_runs (run_id) ); CREATE TABLE IF NOT EXISTS stage_metrics ( idx INTEGER NOT NULL, run_id INTEGER NOT NULL, stage_order INTEGER NOT NULL, stage_name TEXT NOT NULL, wall_time_ns INTEGER NOT NULL, PRIMARY KEY (idx, run_id, stage_order), FOREIGN KEY (idx, run_id) REFERENCES index_results (idx, run_id) ); CREATE TABLE IF NOT EXISTS workers ( worker_id TEXT PRIMARY KEY, 
run_id INTEGER NOT NULL, pid INTEGER NOT NULL, hostname TEXT NOT NULL, started_at TEXT NOT NULL, last_heartbeat TEXT NOT NULL, current_index INTEGER, FOREIGN KEY (run_id) REFERENCES pipeline_runs (run_id) ); CREATE TABLE IF NOT EXISTS output_files ( path TEXT NOT NULL, idx INTEGER NOT NULL, run_id INTEGER NOT NULL, seq INTEGER NOT NULL, PRIMARY KEY (path, run_id), FOREIGN KEY (idx, run_id) REFERENCES index_results (idx, run_id) ); CREATE TABLE IF NOT EXISTS filter_artifacts ( path TEXT NOT NULL, idx INTEGER NOT NULL, run_id INTEGER NOT NULL, filter_name TEXT NOT NULL, filter_order INTEGER NOT NULL, PRIMARY KEY (path, idx, run_id), FOREIGN KEY (idx, run_id) REFERENCES index_results (idx, run_id) ); """ # --------------------------------------------------------------------------- # PipelineStore # ---------------------------------------------------------------------------
class PipelineStore:
    """SQLite-backed store for checkpoints, metrics, provenance, and worker progress.

    A single database holds six tables: ``pipeline_runs``, ``index_results``,
    ``stage_metrics``, ``output_files``, ``filter_artifacts``, and
    ``workers``. The store supports checkpoint resumption via config hashing,
    per-index success/error recording, aggregated metrics queries, and live
    worker progress tracking.

    Parameters
    ----------
    db_path : pathlib.Path
        Path to the SQLite database file. Created automatically if it
        does not exist.
    pipeline_config : dict
        Full pipeline configuration dictionary (from :func:`_pipeline_config`).
    config_hash : str
        SHA-256 hex hash of the pipeline configuration.

    Examples
    --------
    >>> config = _pipeline_config(pipeline)
    >>> chash = _config_hash(config)
    >>> store = PipelineStore(Path("run.db"), config, chash)
    >>> store.is_completed(0)  # None if not yet done
    >>> store.record_success(0, ["/out/0.vtk"], 1_000_000, 4096, None, [])
    """

    def __init__(self, db_path: pathlib.Path, pipeline_config: dict, config_hash: str) -> None:
        """Initialize the pipeline store.

        Parameters
        ----------
        db_path : pathlib.Path
            Path to the SQLite database file.
        pipeline_config : dict
            Full pipeline configuration dictionary.
        config_hash : str
            SHA-256 hex hash of the pipeline configuration.
        """
        self._db_path = db_path
        self._pipeline_config = pipeline_config
        self._config_hash = config_hash
        # Ensure the database's directory exists before the first connection.
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()
[docs] @classmethod def from_db(cls, db_path: str | pathlib.Path) -> PipelineStore: """Open an existing pipeline database in read-only mode. This is the entry point for the dashboard and post-hoc analysis tools. It reads the ``pipeline_runs`` table to recover ``config_hash`` and ``pipeline_config``, so the caller does not need to know them. Parameters ---------- db_path : str or pathlib.Path Path to an existing ``.db`` file produced by a pipeline run. Returns ------- PipelineStore A store instance backed by the existing database. Raises ------ FileNotFoundError If *db_path* does not exist. ValueError If the database contains no pipeline run records. """ db_path = pathlib.Path(db_path) if not db_path.exists(): msg = f"Database file not found: {db_path}" raise FileNotFoundError(msg) conn = sqlite3.connect(str(db_path), timeout=30) try: row = conn.execute( "SELECT config_hash, config_json FROM pipeline_runs ORDER BY run_id DESC LIMIT 1", ).fetchone() finally: conn.close() if row is None: msg = f"No pipeline run records found in {db_path}" raise ValueError(msg) config_hash, config_json = row pipeline_config = json.loads(config_json) return cls(db_path, pipeline_config, config_hash)
def _connect(self) -> sqlite3.Connection: """Open a WAL-mode connection to the database. Retries the ``PRAGMA journal_mode=WAL`` statement with exponential backoff to handle concurrent process initialization on Windows, where the WAL mode switch requires a brief exclusive lock that ``busy_timeout`` alone does not reliably cover. Returns ------- sqlite3.Connection Database connection with WAL journal mode and busy timeout. """ import time conn = sqlite3.connect(str(self._db_path), timeout=30) conn.execute("PRAGMA busy_timeout=30000") max_retries = 10 delay = 0.05 # 50 ms initial backoff for attempt in range(max_retries): try: conn.execute("PRAGMA journal_mode=WAL") break except sqlite3.OperationalError: if attempt == max_retries - 1: raise time.sleep(delay) delay = min(delay * 2, 2.0) return conn def _init_db(self) -> None: """Create schema and register or resume a pipeline run by config hash. Uses INSERT OR IGNORE to handle concurrent init from multiple threads/processes safely (avoids TOCTOU race on the UNIQUE ``config_hash`` column). """ conn = self._connect() try: conn.executescript(_SCHEMA_SQL) # Atomically insert if not exists, then SELECT to get run_id. # This avoids the TOCTOU race where two threads both see no # existing row and both try to INSERT. now = datetime.now(tz=UTC).isoformat() conn.execute( "INSERT OR IGNORE INTO pipeline_runs (config_hash, config_json, started_at) VALUES (?, ?, ?)", (self._config_hash, json.dumps(self._pipeline_config, sort_keys=True, default=str), now), ) row = conn.execute( "SELECT run_id FROM pipeline_runs WHERE config_hash = ?", (self._config_hash,), ).fetchone() self._run_id: int = row[0] conn.commit() logger.info("Pipeline run_id=%d (config hash %s...)", self._run_id, self._config_hash[:12]) finally: conn.close()
[docs] def is_completed(self, index: int) -> list[str] | None: """Check if an index has been completed successfully. Parameters ---------- index : int Source index to check. Returns ------- list[str] | None Cached output paths if completed, ``None`` otherwise. """ conn = self._connect() try: row = conn.execute( "SELECT output_paths FROM index_results WHERE idx = ? AND run_id = ? AND status = 'completed'", (index, self._run_id), ).fetchone() if row is not None: return json.loads(row[0]) # type: ignore[no-any-return] return None finally: conn.close()
[docs] def record_success( self, index: int, output_paths: list[str], wall_time_ns: int, peak_memory_bytes: int, gpu_memory_bytes: int | None, stages: list[StageMetrics], ) -> None: """Record a successfully completed index with metrics. Parameters ---------- index : int Source index that completed. output_paths : list[str] File paths written by the sink. wall_time_ns : int Total wall-clock time in nanoseconds. peak_memory_bytes : int Peak memory usage in bytes. gpu_memory_bytes : int | None Peak GPU memory delta, or ``None``. stages : list[StageMetrics] Per-stage timing breakdown. """ now = datetime.now(tz=UTC).isoformat() conn = self._connect() try: conn.execute( "INSERT OR REPLACE INTO index_results " "(idx, run_id, status, output_paths, completed_at, wall_time_ns, " "peak_memory_bytes, gpu_memory_bytes, error) " "VALUES (?, ?, 'completed', ?, ?, ?, ?, ?, NULL)", ( index, self._run_id, json.dumps(output_paths), now, wall_time_ns, peak_memory_bytes, gpu_memory_bytes, ), ) # Delete any existing stage_metrics for this index/run before inserting conn.execute( "DELETE FROM stage_metrics WHERE idx = ? AND run_id = ?", (index, self._run_id), ) for order, stage in enumerate(stages): conn.execute( "INSERT INTO stage_metrics (idx, run_id, stage_order, stage_name, wall_time_ns) " "VALUES (?, ?, ?, ?, ?)", (index, self._run_id, order, stage.name, stage.wall_time_ns), ) # Populate the normalized output_files table for reverse lookup conn.execute( "DELETE FROM output_files WHERE idx = ? AND run_id = ?", (index, self._run_id), ) for seq, path in enumerate(output_paths): conn.execute( "INSERT OR REPLACE INTO output_files (path, idx, run_id, seq) VALUES (?, ?, ?, ?)", (path, index, self._run_id, seq), ) conn.commit() finally: conn.close()
[docs] def record_error(self, index: int, error: str, wall_time_ns: int) -> None: """Record a failed index execution. Parameters ---------- index : int Source index that failed. error : str Error message. wall_time_ns : int Wall-clock time before the error in nanoseconds. """ now = datetime.now(tz=UTC).isoformat() conn = self._connect() try: conn.execute( "INSERT OR REPLACE INTO index_results " "(idx, run_id, status, output_paths, completed_at, wall_time_ns, " "peak_memory_bytes, gpu_memory_bytes, error) " "VALUES (?, ?, 'error', NULL, ?, ?, NULL, NULL, ?)", (index, self._run_id, now, wall_time_ns, error), ) conn.commit() finally: conn.close()
[docs] def completed_indices(self) -> set[int]: """Return the set of successfully completed indices for this run. Returns ------- set[int] Indices with ``status='completed'``. """ conn = self._connect() try: rows = conn.execute( "SELECT idx FROM index_results WHERE run_id = ? AND status = 'completed'", (self._run_id,), ).fetchall() return {r[0] for r in rows} finally: conn.close()
[docs] def failed_indices(self) -> dict[int, str]: """Return indices that failed with their error messages. Returns ------- dict[int, str] Mapping from index to error message string. """ conn = self._connect() try: rows = conn.execute( "SELECT idx, error FROM index_results WHERE run_id = ? AND status = 'error'", (self._run_id,), ).fetchall() return {r[0]: r[1] for r in rows} finally: conn.close()
[docs] def remaining_indices(self, total: int) -> list[int]: """Return indices not yet completed or failed for this run. Parameters ---------- total : int Total number of source indices. Returns ------- list[int] Sorted list of indices still needing processing. """ done = self.completed_indices() | set(self.failed_indices().keys()) return sorted(i for i in range(total) if i not in done)
[docs] def summary(self, total: int) -> dict[str, Any]: """Return a summary of the store state. Parameters ---------- total : int Total number of source indices. Returns ------- dict[str, Any] Dictionary with keys: ``total``, ``completed``, ``failed``, ``remaining``, ``config_hash``, ``db_path``, ``total_elapsed_s``, ``workers``. """ conn = self._connect() try: completed = conn.execute( "SELECT COUNT(*) FROM index_results WHERE run_id = ? AND status = 'completed'", (self._run_id,), ).fetchone()[0] failed = conn.execute( "SELECT COUNT(*) FROM index_results WHERE run_id = ? AND status = 'error'", (self._run_id,), ).fetchone()[0] elapsed_row = conn.execute( "SELECT COALESCE(SUM(wall_time_ns), 0) FROM index_results WHERE run_id = ? AND status = 'completed'", (self._run_id,), ).fetchone() total_elapsed_ns: int = elapsed_row[0] worker_count = conn.execute( "SELECT COUNT(*) FROM workers WHERE run_id = ?", (self._run_id,), ).fetchone()[0] finally: conn.close() return { "total": total, "completed": completed, "failed": failed, "remaining": total - completed - failed, "config_hash": self._config_hash, "db_path": str(self._db_path), "total_elapsed_s": total_elapsed_ns / 1e9, "workers": worker_count, }
[docs] def reset(self) -> None: """Clear all records for this run and re-register. Deletes all index results, stage metrics, and pipeline run metadata from the database. The database file is kept and a fresh run is registered. """ conn = self._connect() try: conn.execute("DELETE FROM filter_artifacts WHERE run_id = ?", (self._run_id,)) conn.execute("DELETE FROM output_files WHERE run_id = ?", (self._run_id,)) conn.execute("DELETE FROM stage_metrics WHERE run_id = ?", (self._run_id,)) conn.execute("DELETE FROM index_results WHERE run_id = ?", (self._run_id,)) conn.execute("DELETE FROM workers WHERE run_id = ?", (self._run_id,)) conn.execute("DELETE FROM pipeline_runs WHERE run_id = ?", (self._run_id,)) conn.commit() finally: conn.close() # Re-initialize to register a fresh run self._init_db() logger.info("Pipeline store reset — all records cleared for run_id=%d", self._run_id)
[docs] def reset_index(self, index: int) -> None: """Remove records for a single index from this run. Parameters ---------- index : int Source index to remove. """ conn = self._connect() try: conn.execute( "DELETE FROM filter_artifacts WHERE idx = ? AND run_id = ?", (index, self._run_id), ) conn.execute( "DELETE FROM output_files WHERE idx = ? AND run_id = ?", (index, self._run_id), ) conn.execute( "DELETE FROM stage_metrics WHERE idx = ? AND run_id = ?", (index, self._run_id), ) conn.execute( "DELETE FROM index_results WHERE idx = ? AND run_id = ?", (index, self._run_id), ) conn.commit() finally: conn.close()
# -- Output file lookup ------------------------------------------------------
[docs] def index_for_path(self, path: str) -> int | None: """Find which source index produced a given output file. Parameters ---------- path : str Output file path to look up. Returns ------- int | None Source index that produced the file, or ``None`` if not found. """ conn = self._connect() try: row = conn.execute( "SELECT idx FROM output_files WHERE path = ? AND run_id = ?", (path, self._run_id), ).fetchone() return row[0] if row is not None else None finally: conn.close()
[docs] def output_paths_for_index(self, index: int) -> list[str]: """Return the output file paths produced by a given source index. Parameters ---------- index : int Source index to query. Returns ------- list[str] Output file paths ordered by sequence, or empty list if none. """ conn = self._connect() try: rows = conn.execute( "SELECT path FROM output_files WHERE idx = ? AND run_id = ? ORDER BY seq", (index, self._run_id), ).fetchall() return [r[0] for r in rows] finally: conn.close()
# -- Filter artifact tracking ------------------------------------------------
[docs] def record_filter_artifacts( self, index: int, filter_name: str, filter_order: int, paths: list[str], ) -> None: """Record file artifacts produced by a filter for a given index. Parameters ---------- index : int Source index that was processed. filter_name : str Human-readable name of the filter. filter_order : int Position of the filter in the pipeline (0-indexed). paths : list[str] File paths produced by the filter for this index. """ if not paths: return conn = self._connect() try: for path in paths: conn.execute( "INSERT OR REPLACE INTO filter_artifacts " "(path, idx, run_id, filter_name, filter_order) " "VALUES (?, ?, ?, ?, ?)", (path, index, self._run_id, filter_name, filter_order), ) conn.commit() finally: conn.close()
[docs] def filter_artifacts_for_index(self, index: int) -> dict[str, list[str]]: """Return filter artifact paths for a given source index. Parameters ---------- index : int Source index to query. Returns ------- dict[str, list[str]] Mapping of filter name to list of artifact paths. """ conn = self._connect() try: rows = conn.execute( "SELECT filter_name, path FROM filter_artifacts " "WHERE idx = ? AND run_id = ? ORDER BY filter_order, path", (index, self._run_id), ).fetchall() result: dict[str, list[str]] = {} for name, path in rows: result.setdefault(name, []).append(path) return result finally: conn.close()
[docs] def all_filter_artifacts(self) -> dict[str, list[str]]: """Return all filter artifact paths grouped by filter name. Returns ------- dict[str, list[str]] Mapping of filter name to list of all artifact paths. """ conn = self._connect() try: rows = conn.execute( "SELECT filter_name, path FROM filter_artifacts WHERE run_id = ? ORDER BY filter_order, idx, path", (self._run_id,), ).fetchall() result: dict[str, list[str]] = {} for name, path in rows: result.setdefault(name, []).append(path) return result finally: conn.close()
# -- Worker progress tracking ------------------------------------------------
[docs] def register_worker(self, worker_id: str, pid: int, hostname: str) -> None: """Register a worker or update its heartbeat if already known. Parameters ---------- worker_id : str Unique identifier for this worker (UUID hex). pid : int OS process ID of the worker. hostname : str Hostname of the machine running the worker. """ now = datetime.now(tz=UTC).isoformat() conn = self._connect() try: conn.execute( "INSERT OR IGNORE INTO workers " "(worker_id, run_id, pid, hostname, started_at, last_heartbeat, current_index) " "VALUES (?, ?, ?, ?, ?, ?, NULL)", (worker_id, self._run_id, pid, hostname, now, now), ) conn.execute( "UPDATE workers SET last_heartbeat = ? WHERE worker_id = ?", (now, worker_id), ) conn.commit() finally: conn.close()
[docs] def worker_start_index(self, worker_id: str, index: int) -> None: """Record that a worker is starting to process an index. Parameters ---------- worker_id : str Unique identifier for this worker. index : int Source index being processed. """ now = datetime.now(tz=UTC).isoformat() conn = self._connect() try: conn.execute( "UPDATE workers SET current_index = ?, last_heartbeat = ? WHERE worker_id = ?", (index, now, worker_id), ) conn.commit() finally: conn.close()
[docs] def worker_finish_index(self, worker_id: str) -> None: """Record that a worker has finished processing its current index. Parameters ---------- worker_id : str Unique identifier for this worker. """ now = datetime.now(tz=UTC).isoformat() conn = self._connect() try: conn.execute( "UPDATE workers SET current_index = NULL, last_heartbeat = ? WHERE worker_id = ?", (now, worker_id), ) conn.commit() finally: conn.close()
[docs] def active_workers(self) -> list[dict[str, Any]]: """Return all workers registered for this pipeline run. Returns ------- list[dict[str, Any]] List of worker dictionaries with keys: ``worker_id``, ``pid``, ``hostname``, ``started_at``, ``last_heartbeat``, ``current_index``. """ conn = self._connect() try: rows = conn.execute( "SELECT worker_id, pid, hostname, started_at, last_heartbeat, current_index " "FROM workers WHERE run_id = ? ORDER BY started_at", (self._run_id,), ).fetchall() return [ { "worker_id": r[0], "pid": r[1], "hostname": r[2], "started_at": r[3], "last_heartbeat": r[4], "current_index": r[5], } for r in rows ] finally: conn.close()
[docs] def metrics(self) -> PipelineMetrics: """Build aggregated metrics from the database. Returns ------- PipelineMetrics Aggregated metrics across all completed indices in this run. """ conn = self._connect() try: rows = conn.execute( "SELECT idx, wall_time_ns, peak_memory_bytes, gpu_memory_bytes " "FROM index_results WHERE run_id = ? AND status = 'completed' " "ORDER BY idx", (self._run_id,), ).fetchall() index_metrics_list: list[IndexMetrics] = [] for row in rows: idx, wall_ns, peak_mem, gpu_mem = row # Fetch stage metrics for this index stage_rows = conn.execute( "SELECT stage_name, wall_time_ns FROM stage_metrics " "WHERE idx = ? AND run_id = ? ORDER BY stage_order", (idx, self._run_id), ).fetchall() stages = [StageMetrics(name=sr[0], wall_time_ns=sr[1]) for sr in stage_rows] index_metrics_list.append( IndexMetrics( index=idx, stages=stages, wall_time_ns=wall_ns, peak_memory_bytes=peak_mem, gpu_memory_bytes=gpu_mem, ) ) return PipelineMetrics(indices=index_metrics_list) finally: conn.close()
[docs] def index_metrics(self, index: int) -> IndexMetrics | None: """Retrieve metrics for a single index. Parameters ---------- index : int Source index to query. Returns ------- IndexMetrics | None Metrics for the index, or ``None`` if not found. """ conn = self._connect() try: row = conn.execute( "SELECT wall_time_ns, peak_memory_bytes, gpu_memory_bytes " "FROM index_results WHERE idx = ? AND run_id = ? AND status = 'completed'", (index, self._run_id), ).fetchone() if row is None: return None wall_ns, peak_mem, gpu_mem = row stage_rows = conn.execute( "SELECT stage_name, wall_time_ns FROM stage_metrics WHERE idx = ? AND run_id = ? ORDER BY stage_order", (index, self._run_id), ).fetchall() stages = [StageMetrics(name=sr[0], wall_time_ns=sr[1]) for sr in stage_rows] return IndexMetrics( index=index, stages=stages, wall_time_ns=wall_ns, peak_memory_bytes=peak_mem, gpu_memory_bytes=gpu_mem, ) finally: conn.close()