# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract base classes for pipeline components and the Pipeline builder."""
from __future__ import annotations

import logging
import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, ClassVar

if TYPE_CHECKING:
    # Imported only for type annotations to avoid import cycles and
    # runtime cost; the names are resolved lazily thanks to
    # ``from __future__ import annotations``.
    import pathlib
    from collections.abc import Generator, Iterator

    from physicsnemo_curator.core.pipeline_store import PipelineMetrics, PipelineStore

logger = logging.getLogger(__name__)
# Sentinel for required parameters (no default).
_REQUIRED = object()
REQUIRED: Any = _REQUIRED
"""Sentinel value indicating a :class:`Param` has no default and must be provided."""


@dataclass(frozen=True)
class Param:
    """Descriptor for a configurable parameter on a pipeline component.

    Parameters
    ----------
    name : str
        Parameter name (should match the ``__init__`` keyword argument).
    description : str
        Human-readable help text shown in the interactive CLI.
    type : type
        Expected Python type (``str``, ``int``, ``float``, ``pathlib.Path``, …).
    default : Any
        Default value. Use :data:`REQUIRED` (the default) to indicate the
        parameter must be supplied by the user.
    choices : list[str] | None
        If not *None*, the CLI will present a selection prompt instead of
        free-text input.
    """

    name: str
    description: str
    type: type = str  # ty: ignore[invalid-type-form]
    default: Any = REQUIRED
    choices: list[str] | None = None

    @property
    def required(self) -> bool:
        """Return ``True`` if this parameter has no default value."""
        # Identity check against the private sentinel so that any user-supplied
        # default (including None, 0, "") still counts as "has a default".
        return self.default is _REQUIRED
# ---------------------------------------------------------------------------
# Source
# ---------------------------------------------------------------------------
[docs]
class Source[T](ABC):
"""Abstract data source that yields items of type *T*.
A source represents a collection of data items (e.g. files on disk).
Each item is accessed by integer index and may yield one or more *T*
objects (generator semantics allow a single source item to expand into
multiple outputs).
Subclasses must set the class-level :attr:`name` and :attr:`description`
attributes and implement :meth:`params`, :meth:`__len__`, and
:meth:`__getitem__`.
Examples
--------
>>> pipeline = MySource(path="/data").filter(MyFilter()).write(MySink())
>>> pipeline[0] # process first source item lazily
"""
name: ClassVar[str]
"""Human-readable display name for the interactive CLI."""
description: ClassVar[str]
"""Short description shown in the interactive CLI."""
[docs]
@classmethod
@abstractmethod
def params(cls) -> list[Param]:
"""Declare the configurable parameters for this source.
Returns
-------
list[Param]
Ordered list of parameter descriptors.
"""
...
@abstractmethod
def __len__(self) -> int:
"""Return the number of items available in this source."""
...
@abstractmethod
def __getitem__(self, index: int) -> Generator[T]:
"""Yield one or more *T* items for the given *index*.
Parameters
----------
index : int
Zero-based index into the source's item collection.
Yields
------
T
Data item(s) produced from the source at *index*.
"""
...
# -- Convenience builder methods -----------------------------------------
[docs]
def filter(self, f: Filter[T]) -> Pipeline[T]:
"""Create a :class:`Pipeline` with this source and a single filter.
Parameters
----------
f : Filter[T]
The filter to append.
Returns
-------
Pipeline[T]
A new pipeline containing this source and the given filter.
"""
return Pipeline(source=self, filters=[f])
[docs]
def write(self, s: Sink[T]) -> Pipeline[T]:
"""Create a :class:`Pipeline` with this source and a sink (no filters).
If the sink exposes a ``set_source`` method, the source is
automatically injected so the sink can resolve naming
placeholders (e.g. ``{relpath}``, ``{stem}``) from the source.
Parameters
----------
s : Sink[T]
The sink to attach.
Returns
-------
Pipeline[T]
A new pipeline containing this source and the given sink.
"""
if hasattr(s, "set_source"):
s.set_source(self) # ty: ignore[call-non-callable]
return Pipeline(source=self, sink=s)
# ---------------------------------------------------------------------------
# Filter
# ---------------------------------------------------------------------------
[docs]
class Filter[T](ABC):
"""Abstract filter/transform that processes a stream of *T* items.
Filters receive a generator of items and yield zero or more items per
input (full generator semantics — can expand, contract, or pass through).
Subclasses must set :attr:`name` and :attr:`description` and implement
:meth:`params` and :meth:`__call__`.
"""
name: ClassVar[str]
"""Human-readable display name for the interactive CLI."""
description: ClassVar[str]
"""Short description shown in the interactive CLI."""
[docs]
@classmethod
@abstractmethod
def params(cls) -> list[Param]:
"""Declare the configurable parameters for this filter.
Returns
-------
list[Param]
Ordered list of parameter descriptors.
"""
...
@abstractmethod
def __call__(self, items: Generator[T]) -> Generator[T]:
"""Process a stream of items, yielding transformed results.
Parameters
----------
items : Generator[T]
Incoming stream of data items.
Yields
------
T
Transformed data item(s).
"""
...
[docs]
def artifacts(self) -> list[str]:
"""Return paths of files produced by this filter since the last call.
Stateful filters that write side-effect files (statistics, logs,
etc.) should override this to report the paths written during the
most recent :meth:`flush` or :meth:`__call__` cycle. The
framework calls this after each index to record filter artifacts
in the pipeline store.
The default implementation returns an empty list, which is
correct for stateless (pass-through) filters.
Returns
-------
list[str]
Paths of files written, or ``[]`` if none.
"""
return []
# ---------------------------------------------------------------------------
# Sink
# ---------------------------------------------------------------------------
[docs]
class Sink[T](ABC):
"""Abstract sink that persists items and returns output file paths.
The sink consumes a generator of items and writes each one to storage,
returning the file paths of the written outputs.
Subclasses must set :attr:`name` and :attr:`description` and implement
:meth:`params` and :meth:`__call__`.
"""
name: ClassVar[str]
"""Human-readable display name for the interactive CLI."""
description: ClassVar[str]
"""Short description shown in the interactive CLI."""
[docs]
@classmethod
@abstractmethod
def params(cls) -> list[Param]:
"""Declare the configurable parameters for this sink.
Returns
-------
list[Param]
Ordered list of parameter descriptors.
"""
...
@abstractmethod
def __call__(self, items: Iterator[T], index: int) -> list[str]:
"""Consume items and persist them to storage.
Parameters
----------
items : Iterator[T]
Stream of data items to write.
index : int
Source index being processed (useful for naming output files).
Returns
-------
list[str]
Paths of the files written.
"""
...
# ---------------------------------------------------------------------------
# Pipeline
# ---------------------------------------------------------------------------
[docs]
@dataclass
class Pipeline[T]:
"""Lazy pipeline that chains a source through filters into a sink.
The pipeline is built incrementally using the :meth:`filter` and
:meth:`write` builder methods. Execution is deferred until the
pipeline is indexed with ``pipeline[i]``, which processes only
the *i*-th source item.
Parameters
----------
source : Source[T]
The data source.
filters : list[Filter[T]]
Ordered list of filters to apply.
sink : Sink[T] | None
Optional sink for writing output.
Examples
--------
>>> pipeline = (
... MySource(path="/data")
... .filter(FilterA())
... .filter(FilterB())
... .write(MySink(output="/out"))
... )
>>> pipeline[0] # lazily process source item 0
['/out/item_0']
"""
source: Source[T]
filters: list[Filter[T]] = field(default_factory=list)
sink: Sink[T] | None = None
track_metrics: bool = True
track_memory: bool = True
track_gpu: bool = False
db_dir: pathlib.Path | None = None
_store: PipelineStore | None = field(default=None, init=False, repr=False, compare=False)
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False, compare=False)
[docs]
def filter(self, f: Filter[T]) -> Pipeline[T]:
"""Return a new pipeline with an additional filter appended.
Parameters
----------
f : Filter[T]
The filter to append.
Returns
-------
Pipeline[T]
A new pipeline instance (the original is unchanged).
"""
return Pipeline(
source=self.source,
filters=[*self.filters, f],
sink=self.sink,
track_metrics=self.track_metrics,
track_memory=self.track_memory,
track_gpu=self.track_gpu,
db_dir=self.db_dir,
)
[docs]
def write(self, s: Sink[T]) -> Pipeline[T]:
"""Return a new pipeline with the given sink attached.
If the sink exposes a ``set_source`` method, the pipeline's
source is automatically injected so the sink can resolve
naming placeholders (e.g. ``{relpath}``, ``{stem}``).
Parameters
----------
s : Sink[T]
The sink to attach.
Returns
-------
Pipeline[T]
A new pipeline instance (the original is unchanged).
"""
if hasattr(s, "set_source"):
s.set_source(self.source) # ty: ignore[call-non-callable]
return Pipeline(
source=self.source,
filters=list(self.filters),
sink=s,
track_metrics=self.track_metrics,
track_memory=self.track_memory,
track_gpu=self.track_gpu,
db_dir=self.db_dir,
)
def __len__(self) -> int:
"""Return the number of items in the source."""
return len(self.source)
def __getitem__(self, index: int) -> list[str]:
"""Lazily process the *index*-th source item through the full chain.
When :attr:`track_metrics` is ``True``, each stage is wrapped with
:class:`~physicsnemo_curator.core.pipeline_store._TimedGenerator` for
per-stage timing, memory tracking via ``tracemalloc``, and optional
GPU memory tracking. Results and errors are recorded in the
:class:`~physicsnemo_curator.core.pipeline_store.PipelineStore`.
Parameters
----------
index : int
Zero-based index into the source.
Returns
-------
list[str]
File paths produced by the sink.
Raises
------
RuntimeError
If no sink has been attached to the pipeline.
IndexError
If *index* is out of range.
"""
if self.sink is None:
msg = "Pipeline has no sink. Call .write(sink) before indexing."
raise RuntimeError(msg)
n = len(self.source)
if index < 0:
index += n
if index < 0 or index >= n:
msg = f"Index {index} out of range for source with {n} items."
raise IndexError(msg)
# Fast path: no instrumentation
if not self.track_metrics:
stream: Generator[T] = self.source[index]
for f in self.filters:
stream = f(stream)
return self.sink(stream, index)
# Instrumented path
return self._getitem_instrumented(index)
def _getitem_instrumented(self, index: int) -> list[str]:
"""Execute index with full metrics instrumentation.
Parameters
----------
index : int
Validated, non-negative index into the source.
Returns
-------
list[str]
File paths produced by the sink.
"""
import os
import socket
import time
import tracemalloc
from physicsnemo_curator.core.pipeline_store import (
StageMetrics,
_get_worker_id,
_TimedGenerator,
)
store = self._get_store()
# --- Worker registration ---
worker_id = _get_worker_id()
store.register_worker(worker_id, os.getpid(), socket.gethostname())
store.worker_start_index(worker_id, index)
# Checkpoint hit — return cached paths
cached = store.is_completed(index)
if cached is not None:
logger.debug("Checkpoint hit for index %d — returning cached paths", index)
store.worker_finish_index(worker_id)
return cached
# --- GPU baseline ---
gpu_baseline: int | None = None
if self.track_gpu:
gpu_baseline = Pipeline._gpu_setup()
# --- Memory tracking ---
was_tracing = tracemalloc.is_tracing()
if self.track_memory:
if not was_tracing:
tracemalloc.start()
tracemalloc.reset_peak()
overall_start = time.perf_counter_ns()
started_tracemalloc = self.track_memory and not was_tracing
try:
# 1. Wrap source generator with timing
source_gen = self.source[index]
timed_source: _TimedGenerator[T] = _TimedGenerator(source_gen)
# 2. Chain through filters, wrapping each output
filter_wrappers: list[_TimedGenerator[T]] = []
current_stream: _TimedGenerator[T] = timed_source
for f in self.filters:
raw_output = f(current_stream) # type: ignore[arg-type] # ty: ignore[invalid-argument-type]
wrapped: _TimedGenerator[T] = _TimedGenerator(raw_output)
filter_wrappers.append(wrapped)
current_stream = wrapped
# 3. Run the sink (forces full chain evaluation)
assert self.sink is not None # guaranteed by caller
result = self.sink(current_stream, index)
overall_elapsed = time.perf_counter_ns() - overall_start
# 4. Compute per-stage times using chain subtraction
stage_metrics: list[StageMetrics] = []
source_time = timed_source.elapsed_ns
stage_metrics.append(StageMetrics(name="source", wall_time_ns=source_time))
prev_elapsed = source_time
for i_f, fw in enumerate(filter_wrappers):
filter_own_time = max(0, fw.elapsed_ns - prev_elapsed)
fname = type(self.filters[i_f]).name
stage_metrics.append(StageMetrics(name=fname, wall_time_ns=filter_own_time))
prev_elapsed = fw.elapsed_ns
last_elapsed = filter_wrappers[-1].elapsed_ns if filter_wrappers else source_time
sink_own_time = max(0, overall_elapsed - last_elapsed)
stage_metrics.append(StageMetrics(name="sink", wall_time_ns=sink_own_time))
# 5. Memory measurement
peak_memory: int = 0
if self.track_memory:
_, peak_memory = tracemalloc.get_traced_memory()
# 6. GPU measurement
gpu_delta: int | None = None
if self.track_gpu and gpu_baseline is not None:
gpu_delta = Pipeline._gpu_measure(gpu_baseline)
# 7. Record success
store.record_success(index, result, overall_elapsed, peak_memory, gpu_delta, stage_metrics)
return result
except Exception as exc:
elapsed = time.perf_counter_ns() - overall_start
store.record_error(index, str(exc), elapsed)
raise
finally:
store.worker_finish_index(worker_id)
if started_tracemalloc:
tracemalloc.stop()
def _get_store(self) -> PipelineStore:
"""Lazily create and return the pipeline store.
Thread-safe via a lock to prevent multiple stores from being
created when threads race on ``_get_store()``.
The database path is resolved in priority order:
1. ``db_dir`` field (explicit per-pipeline override)
2. :func:`~physicsnemo_curator.core.cache.default_cache_dir` which
honours the ``PSNC_CACHE_DIR`` environment variable, then
``$XDG_CACHE_HOME/psnc/``, then ``~/.cache/psnc/``
Returns
-------
PipelineStore
The SQLite-backed pipeline store for this pipeline.
"""
if self._store is not None:
return self._store
with self._lock:
# Double-check after acquiring lock
if self._store is not None:
return self._store
import pathlib
from physicsnemo_curator.core.cache import default_cache_dir
from physicsnemo_curator.core.pipeline_store import (
PipelineStore,
_config_hash,
_pipeline_config,
)
config = _pipeline_config(self)
hash_ = _config_hash(config)
if self.db_dir is not None:
db_path = pathlib.Path(self.db_dir) / f"{hash_[:16]}.db"
else:
db_path = default_cache_dir() / f"{hash_[:16]}.db"
self._store = PipelineStore(db_path=db_path, pipeline_config=config, config_hash=hash_)
return self._store
@staticmethod
def _gpu_setup() -> int | None:
"""Reset GPU peak stats and return baseline memory.
Returns
-------
int | None
Baseline GPU memory in bytes, or ``None`` if unavailable.
"""
try:
import torch
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
return torch.cuda.memory_allocated()
except ImportError:
pass
return None
@staticmethod
def _gpu_measure(baseline: int) -> int:
"""Measure peak GPU memory delta from baseline.
Parameters
----------
baseline : int
GPU memory at start of ``__getitem__``.
Returns
-------
int
Peak GPU memory minus baseline (bytes).
"""
import torch
return torch.cuda.max_memory_allocated() - baseline
def __getstate__(self) -> dict[str, Any]:
"""Return picklable state, dropping the non-serializable store and lock.
Returns
-------
dict[str, Any]
Instance state with ``_store`` and ``_lock`` excluded.
"""
state = self.__dict__.copy()
state["_store"] = None
state.pop("_lock", None)
return state
def __setstate__(self, state: dict[str, Any]) -> None:
"""Restore state from pickle, ensuring ``_store`` is ``None``.
Parameters
----------
state : dict[str, Any]
Pickled state dictionary.
"""
state["_store"] = None
state["_lock"] = threading.Lock()
self.__dict__.update(state)
# -- Query API (delegates to store) ----------------------------------------
def _require_metrics(self) -> PipelineStore:
"""Return the store or raise if metrics are disabled.
Returns
-------
PipelineStore
The pipeline store.
Raises
------
RuntimeError
If ``track_metrics`` is ``False``.
"""
if not self.track_metrics:
msg = "Pipeline metrics are disabled (track_metrics=False)"
raise RuntimeError(msg)
return self._get_store()
@property
def completed_indices(self) -> set[int]:
"""Return the set of successfully completed indices.
Returns
-------
set[int]
Indices with recorded successful completions.
Raises
------
RuntimeError
If ``track_metrics`` is ``False``.
"""
return self._require_metrics().completed_indices()
@property
def db_path(self) -> pathlib.Path | None:
"""Return the resolved database path, or ``None`` if metrics are disabled.
Returns
-------
pathlib.Path or None
Absolute path to the SQLite database file, or ``None`` when
``track_metrics`` is ``False``.
"""
if not self.track_metrics:
return None
return self._get_store()._db_path
@property
def failed_indices(self) -> dict[int, str]:
"""Return indices that failed with their error messages.
Returns
-------
dict[int, str]
Mapping from index to error message string.
Raises
------
RuntimeError
If ``track_metrics`` is ``False``.
"""
return self._require_metrics().failed_indices()
@property
def metrics(self) -> PipelineMetrics:
"""Return aggregated metrics from the store.
Returns
-------
PipelineMetrics
Aggregated metrics across all completed indices.
Raises
------
RuntimeError
If ``track_metrics`` is ``False``.
"""
return self._require_metrics().metrics()
@property
def active_workers(self) -> list[dict[str, Any]]:
"""Return all workers registered for this pipeline run.
Returns
-------
list[dict[str, Any]]
List of worker dictionaries with keys: ``worker_id``, ``pid``,
``hostname``, ``started_at``, ``last_heartbeat``, ``current_index``.
Raises
------
RuntimeError
If ``track_metrics`` is ``False``.
"""
return self._require_metrics().active_workers()
[docs]
def remaining_indices(self) -> list[int]:
"""Return indices not yet completed or failed.
Returns
-------
list[int]
Sorted list of indices still needing processing.
Raises
------
RuntimeError
If ``track_metrics`` is ``False``.
"""
store = self._require_metrics()
return store.remaining_indices(len(self.source))
[docs]
def summary(self) -> dict[str, Any]:
"""Return a summary of the store state.
Returns
-------
dict[str, Any]
Dictionary with ``total``, ``completed``, ``failed``,
``remaining``, ``config_hash``, ``db_path``, ``total_elapsed_s``.
Raises
------
RuntimeError
If ``track_metrics`` is ``False``.
"""
store = self._require_metrics()
return store.summary(len(self.source))
[docs]
def reset(self) -> None:
"""Clear all records for this pipeline run.
Raises
------
RuntimeError
If ``track_metrics`` is ``False``.
"""
self._require_metrics().reset()
[docs]
def reset_index(self, index: int) -> None:
"""Remove records for a single index.
Parameters
----------
index : int
Source index to remove.
Raises
------
RuntimeError
If ``track_metrics`` is ``False``.
"""
self._require_metrics().reset_index(index)
[docs]
def index_for_path(self, path: str) -> int | None:
"""Find which source index produced a given output file.
Parameters
----------
path : str
Output file path to look up.
Returns
-------
int | None
Source index that produced the file, or ``None`` if not found.
Raises
------
RuntimeError
If ``track_metrics`` is ``False``.
"""
return self._require_metrics().index_for_path(path)
[docs]
def output_paths_for_index(self, index: int) -> list[str]:
"""Return the output file paths produced by a given source index.
Parameters
----------
index : int
Source index to query.
Returns
-------
list[str]
Output file paths ordered by sequence, or empty list if none.
Raises
------
RuntimeError
If ``track_metrics`` is ``False``.
"""
return self._require_metrics().output_paths_for_index(index)
[docs]
def filter_artifacts_for_index(self, index: int) -> dict[str, list[str]]:
"""Return filter artifact paths for a given source index.
Parameters
----------
index : int
Source index to query.
Returns
-------
dict[str, list[str]]
Mapping of filter name to list of artifact paths.
Raises
------
RuntimeError
If ``track_metrics`` is ``False``.
"""
return self._require_metrics().filter_artifacts_for_index(index)
[docs]
def all_filter_artifacts(self) -> dict[str, list[str]]:
"""Return all filter artifact paths grouped by filter name.
Returns
-------
dict[str, list[str]]
Mapping of filter name to list of all artifact paths.
Raises
------
RuntimeError
If ``track_metrics`` is ``False``.
"""
return self._require_metrics().all_filter_artifacts()
[docs]
def save(self, path: str | pathlib.Path) -> None:
"""Save this pipeline's configuration to a YAML or JSON file.
The file format is determined by the extension:
``.yaml`` / ``.yml`` → YAML, ``.json`` → JSON.
Parameters
----------
path : str | pathlib.Path
Destination file path.
Raises
------
ValueError
If the file extension is not supported.
See Also
--------
Pipeline.load : Restore a pipeline from a saved file.
"""
from physicsnemo_curator.core.serialization import save_pipeline
save_pipeline(self, path)
[docs]
@classmethod
def load(cls, path: str | pathlib.Path) -> Pipeline[Any]:
"""Load a pipeline from a YAML or JSON configuration file.
Parameters
----------
path : str | pathlib.Path
Path to the pipeline configuration file.
Returns
-------
Pipeline
Fully constructed pipeline ready for execution.
Raises
------
FileNotFoundError
If the file does not exist.
ValueError
If the file extension is not supported.
See Also
--------
Pipeline.save : Save a pipeline configuration.
"""
from physicsnemo_curator.core.serialization import load_pipeline
return load_pipeline(path)