# Source code for physicsnemo_curator.run

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pipeline execution with pluggable backends.

This module provides :func:`run_pipeline`, which processes every index
of a :class:`~physicsnemo_curator.core.base.Pipeline` — optionally in
parallel — and returns the collected sink outputs.

Available Backends
------------------
* ``"sequential"`` — simple ``for``-loop (default when *n_jobs=1*).
* ``"thread_pool"`` — :class:`concurrent.futures.ThreadPoolExecutor`
  (good for I/O-bound tasks).
* ``"process_pool"`` — :class:`concurrent.futures.ProcessPoolExecutor`
  (true parallelism for CPU-bound tasks).
* ``"loky"`` — ``joblib.Parallel`` with the ``loky`` backend
  (requires ``joblib``).
* ``"dask"`` — ``dask.bag`` for parallel/distributed execution
  (requires ``dask``).
* ``"prefect"`` — Prefect workflow orchestration with observability
  (requires ``prefect``).
* ``"auto"`` — picks the best available backend automatically.

Custom Backends
---------------
You can register custom backends using :func:`register_backend`::

    from physicsnemo_curator.run import register_backend, RunBackend, RunConfig

    class MyBackend(RunBackend):
        name = "my_backend"
        description = "My custom execution backend"

        def run(self, pipeline, config):
            # Custom execution logic
            ...

    register_backend(MyBackend)

.. warning::

   Stateful filters (e.g. those with ``flush()`` methods) automatically
   write per-index shard files when using parallel backends.  Call
   :func:`gather_pipeline` after :func:`run_pipeline` to merge all
   shards into a single output file.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from physicsnemo_curator.run.base import RunBackend, RunConfig
from physicsnemo_curator.run.dask import DaskBackend
from physicsnemo_curator.run.loky import LokyBackend
from physicsnemo_curator.run.prefect import PrefectBackend
from physicsnemo_curator.run.process_pool import ProcessPoolBackend
from physicsnemo_curator.run.sequential import SequentialBackend
from physicsnemo_curator.run.thread_pool import ThreadPoolBackend

if TYPE_CHECKING:
    from collections.abc import Iterable

    from physicsnemo_curator.core.base import Pipeline

# ---------------------------------------------------------------------------
# Backend Registry
# ---------------------------------------------------------------------------

# Module-level registry mapping backend names to backend classes.
# Populated below with the built-in backends at import time and extended
# by users via register_backend().
_BACKENDS: dict[str, type[RunBackend]] = {}


def register_backend(backend_cls: type[RunBackend]) -> None:
    """Register a custom execution backend.

    Parameters
    ----------
    backend_cls : type[RunBackend]
        The backend class to register. Must have a ``name`` class attribute.

    Raises
    ------
    ValueError
        If a backend with the same name is already registered.

    Examples
    --------
    >>> from physicsnemo_curator.run import register_backend, RunBackend
    >>> class MyBackend(RunBackend):
    ...     name = "my_backend"
    ...     description = "Custom backend"
    ...     def run(self, pipeline, config):
    ...         return []
    >>> register_backend(MyBackend)
    """
    backend_name = backend_cls.name
    # Refuse silent replacement: duplicate names would shadow an existing
    # backend and make dispatch ambiguous.
    if backend_name in _BACKENDS:
        msg = f"Backend {backend_name!r} is already registered."
        raise ValueError(msg)
    _BACKENDS[backend_name] = backend_cls
def get_backend(name: str) -> RunBackend:
    """Get an instance of a registered backend by name.

    Parameters
    ----------
    name : str
        The backend name.

    Returns
    -------
    RunBackend
        An instance of the requested backend.

    Raises
    ------
    ValueError
        If the backend is not registered.
    """
    backend_cls = _BACKENDS.get(name)
    if backend_cls is None:
        # Include the full list of registered names to make typos easy to spot.
        available = ", ".join(sorted(_BACKENDS.keys()))
        msg = f"Unknown backend {name!r}. Available: {available}"
        raise ValueError(msg)
    return backend_cls()
def list_backends() -> dict[str, dict[str, Any]]:
    """List all registered backends and their availability.

    Returns
    -------
    dict[str, dict[str, Any]]
        Dictionary mapping backend names to info dicts containing:

        - description: Human-readable description
        - available: Whether dependencies are installed
        - requires: Tuple of required packages
    """
    info: dict[str, dict[str, Any]] = {}
    for backend_name, backend_cls in _BACKENDS.items():
        info[backend_name] = {
            "description": backend_cls.description,
            "available": backend_cls.is_available(),
            "requires": backend_cls.requires,
        }
    return info
def _pick_auto_backend() -> str:
    """Select the best available parallel backend.

    Returns
    -------
    str
        Name of the best available backend.
    """
    # Preference order for parallel execution; fall back to the stdlib
    # process pool, which needs no optional dependency.
    for candidate in ("dask", "loky", "process_pool"):
        backend_cls = _BACKENDS.get(candidate)
        if backend_cls is not None and backend_cls.is_available():
            return candidate
    return "process_pool"


# Register built-in backends
register_backend(SequentialBackend)
register_backend(ThreadPoolBackend)
register_backend(ProcessPoolBackend)
register_backend(LokyBackend)
register_backend(DaskBackend)
register_backend(PrefectBackend)

# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def run_pipeline(
    pipeline: Pipeline[Any],
    *,
    n_jobs: int = 1,
    backend: str = "auto",
    indices: Iterable[int] | None = None,
    progress: bool = True,
    **backend_kwargs: Any,
) -> list[list[str]]:
    """Execute a pipeline over all (or selected) source indices.

    This is the primary entry-point for batch or parallel pipeline
    execution. It dispatches to the chosen backend and collects results.

    Parameters
    ----------
    pipeline : Pipeline
        A fully-configured pipeline (source + filters + sink).
    n_jobs : int
        Number of parallel workers. ``1`` forces sequential execution.
        ``-1`` uses all available CPUs.
    backend : str
        Execution backend. One of ``"auto"``, ``"sequential"``,
        ``"thread_pool"``, ``"process_pool"``, ``"loky"``, ``"dask"``,
        ``"prefect"``, or any custom registered backend.
    indices : Iterable[int] | None
        Specific source indices to process. ``None`` (default) processes
        all indices ``range(len(pipeline))``.
    progress : bool
        Show a progress indicator if the chosen backend supports it.
    **backend_kwargs : Any
        Extra keyword arguments forwarded to the backend.

    Returns
    -------
    list[list[str]]
        Outer list is ordered by the input indices; each inner list
        contains the file paths returned by the sink for that index.

    Raises
    ------
    ValueError
        If *backend* is not a recognised name.
    RuntimeError
        If the pipeline has no sink attached.
    ImportError
        If the selected backend's optional dependency is missing.

    Notes
    -----
    When using parallel backends, each worker operates on an independent
    copy of the pipeline. Stateful filters accumulate per-worker state
    that is **not** merged back. Use sequential execution when filter
    side-effects must be aggregated, or implement a custom reduce step.

    Examples
    --------
    Sequential execution (default):

    >>> results = run_pipeline(pipeline)

    Parallel with 4 processes:

    >>> results = run_pipeline(pipeline, n_jobs=4, backend="process_pool")

    Using Prefect with retries:

    >>> results = run_pipeline(
    ...     pipeline,
    ...     n_jobs=4,
    ...     backend="prefect",
    ...     retries=3,
    ...     retry_delay_seconds=10,
    ... )

    Process only a subset of indices:

    >>> results = run_pipeline(pipeline, indices=[0, 5, 10])
    """
    # Validate the backend name early ("auto" is resolved later, so it is
    # accepted here even though it never appears in the registry).
    if backend != "auto" and backend not in _BACKENDS:
        available = ", ".join(sorted(_BACKENDS.keys()))
        msg = f"Unknown backend {backend!r}. Available: {available}, auto"
        raise ValueError(msg)

    # A pipeline without a sink has nowhere to write results.
    if pipeline.sink is None:
        msg = "Pipeline has no sink. Call .write(sink) before run_pipeline()."
        raise RuntimeError(msg)

    # Materialize indices once so the config holds a concrete list.
    idx_list: list[int] | None = list(indices) if indices is not None else None
    config = RunConfig(
        n_jobs=n_jobs,
        progress=progress,
        indices=idx_list,
        backend_options=backend_kwargs,
    )

    # Nothing to do for an explicitly empty index list.
    if idx_list is not None and not idx_list:
        return []

    # Resolve the effective backend: a single worker always runs
    # sequentially, and "auto" picks the best available parallel backend.
    if config.resolved_n_jobs == 1 or backend == "sequential":
        effective_backend = "sequential"
    elif backend == "auto":
        effective_backend = _pick_auto_backend()
    else:
        effective_backend = backend

    # Fail with an actionable message if optional dependencies are missing.
    backend_cls = _BACKENDS[effective_backend]
    if not backend_cls.is_available():
        requires = ", ".join(backend_cls.requires)
        msg = f"Backend {effective_backend!r} requires: {requires}"
        raise ImportError(msg)

    runner = backend_cls()
    return runner.run(pipeline, config)
def gather_pipeline(pipeline: Pipeline[Any]) -> list[str]:
    """Merge per-index shard files produced by stateful filters.

    When :func:`run_pipeline` runs with a parallel backend, each worker
    flushes stateful filters to shard files named
    ``{stem}_shard_{index:06d}{suffix}``. This function discovers those
    shards, calls the filter's :meth:`merge` method to combine them into
    a single output file, and removes the shard files.

    Call this **after** :func:`run_pipeline` completes.

    Parameters
    ----------
    pipeline : Pipeline
        The same pipeline that was passed to :func:`run_pipeline`.

    Returns
    -------
    list[str]
        Paths to the merged output files (one per stateful filter).

    Examples
    --------
    >>> results = run_pipeline(pipeline, n_jobs=4, backend="process_pool")
    >>> merged = gather_pipeline(pipeline)
    >>> print(merged)
    ['outputs/mean_stats.parquet']
    """
    from pathlib import Path

    merged_paths: list[str] = []
    for filt in pipeline.filters:
        # Only "stateful" filters expose all three hooks; skip the rest.
        is_stateful = (
            hasattr(filt, "flush")
            and hasattr(filt, "_output_path")
            and hasattr(filt, "merge")
        )
        if not is_stateful:
            continue

        target = Path(str(filt._output_path))  # noqa: SLF001
        pattern = f"{target.stem}_shard_*{target.suffix}"
        shard_files = sorted(str(p) for p in target.parent.glob(pattern))
        if not shard_files:
            continue

        merged = filt.merge(shard_files, str(target))  # ty: ignore[call-non-callable]
        merged_paths.append(merged)

        # Shards are intermediate artifacts; remove them after the merge.
        for shard in shard_files:
            Path(shard).unlink(missing_ok=True)

    return merged_paths
# Public API of this module, kept alphabetically sorted.
__all__ = [
    "DaskBackend",
    "LokyBackend",
    "PrefectBackend",
    "ProcessPoolBackend",
    "RunBackend",
    "RunConfig",
    "SequentialBackend",
    "ThreadPoolBackend",
    "gather_pipeline",
    "get_backend",
    "list_backends",
    "register_backend",
    "run_pipeline",
]