# Source code for physicsnemo_curator.run.base

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base classes for pipeline execution backends.

This module defines the abstract interface that all execution backends
must implement, along with common utilities.
"""

from __future__ import annotations

import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, ClassVar

if TYPE_CHECKING:
    from physicsnemo_curator.core.base import Pipeline


@dataclass
class RunConfig:
    """Execution settings for running a pipeline.

    Parameters
    ----------
    n_jobs : int
        Number of parallel workers. ``1`` forces sequential execution.
        ``-1`` uses all available CPUs. Values ``<= 0`` follow the
        convention ``cpu_count + 1 + n_jobs``.
    progress : bool
        Whether to show a progress indicator (if supported by backend).
    indices : list[int] | None
        Specific source indices to process. ``None`` processes all indices.
    backend_options : dict[str, Any]
        Additional backend-specific options.
    """

    n_jobs: int = 1
    progress: bool = True
    indices: list[int] | None = None
    backend_options: dict[str, Any] = field(default_factory=dict)

    @property
    def resolved_n_jobs(self) -> int:
        """Return the concrete positive worker count.

        Returns
        -------
        int
            Positive integer number of workers.
        """
        if self.n_jobs > 0:
            return self.n_jobs
        # Non-positive values follow the joblib-style convention:
        # -1 → all CPUs, -2 → all but one, … Never drop below one worker.
        available = os.cpu_count() or 1
        return max(1, available + 1 + self.n_jobs)
class RunBackend(ABC):
    """Abstract base class for pipeline execution backends.

    Concrete subclasses supply a parallelization strategy (threading,
    multiprocessing, distributed computing, workflow orchestrators, etc.).

    Class Attributes
    ----------------
    name : str
        Unique identifier for this backend (e.g., "sequential", "thread_pool").
    description : str
        Human-readable description of the backend.
    requires : tuple[str, ...]
        Optional package dependencies required by this backend.
    """

    name: ClassVar[str]
    description: ClassVar[str]
    requires: ClassVar[tuple[str, ...]] = ()

    @classmethod
    def is_available(cls) -> bool:
        """Check if this backend's dependencies are installed.

        Returns
        -------
        bool
            True if all required packages are available.
        """

        def _importable(package: str) -> bool:
            # Import by name only to probe availability; the module stays
            # cached in sys.modules, so a later real import is cheap.
            try:
                __import__(package)
            except ImportError:
                return False
            return True

        return all(_importable(package) for package in cls.requires)

    @abstractmethod
    def run(
        self,
        pipeline: Pipeline[Any],
        config: RunConfig,
    ) -> list[list[str]]:
        """Execute the pipeline over the configured indices.

        Parameters
        ----------
        pipeline : Pipeline
            A fully-configured pipeline (source + filters + sink).
        config : RunConfig
            Execution configuration.

        Returns
        -------
        list[list[str]]
            Outer list is ordered by the input indices; each inner list
            contains the file paths returned by the sink for that index.
        """
        ...
def _flush_filters(pipeline: Pipeline[Any], index: int) -> None:
    """Flush stateful filters after processing an index.

    For each filter that has a ``flush`` method and an ``_output_path``
    attribute, this function temporarily swaps the output path to a
    worker-specific path (``{stem}_worker_{pid}{suffix}``) before flushing,
    then restores the original path. Using worker PID (instead of index)
    ensures all indices processed by the same worker append to a single
    file, reducing output file count when running with parallel workers.

    After flushing, any filter artifacts (reported via
    :meth:`~Filter.artifacts`) are recorded in the pipeline store when
    metrics tracking is enabled.

    Parameters
    ----------
    pipeline : Pipeline
        The pipeline whose filters should be flushed.
    index : int
        The source index that was just processed (used for artifact
        tracking).
    """
    import os
    import pathlib

    pid = os.getpid()
    for i_f, f in enumerate(pipeline.filters):
        # Only stateful filters expose flush/_output_path; skip the rest.
        if not (hasattr(f, "flush") and hasattr(f, "_output_path")):
            continue
        original = f._output_path  # noqa: SLF001
        p = pathlib.Path(str(original))
        worker_path = p.parent / f"{p.stem}_worker_{pid}{p.suffix}"
        f._output_path = worker_path  # noqa: SLF001
        try:
            f.flush()
        finally:
            # Always restore, even if flush raises, so the filter's
            # configured path is never left pointing at the shard file.
            f._output_path = original  # noqa: SLF001

        # Record filter artifacts if metrics tracking is enabled
        if pipeline.track_metrics:
            artifact_paths = f.artifacts()
            if artifact_paths:
                store = pipeline._get_store()  # noqa: SLF001
                # Fix: classes expose their identifier via ``__name__``;
                # ``type(f).name`` raised AttributeError here.
                store.record_filter_artifacts(
                    index, type(f).__name__, i_f, artifact_paths
                )
def process_single_index(pipeline: Pipeline[Any], index: int) -> list[str]:
    """Run one source index through *pipeline* and flush stateful filters.

    Lives at module level (rather than as a method or closure) so it can be
    pickled by multiprocess backends. Any filter exposing a ``flush`` method
    is flushed to its shard file after the index is processed.

    Parameters
    ----------
    pipeline : Pipeline
        The pipeline to execute.
    index : int
        The index to process.

    Returns
    -------
    list[str]
        File paths written by the sink.
    """
    written = pipeline[index]
    _flush_filters(pipeline, index)
    return written
def process_single_index_packed(args: tuple[Pipeline[Any], int]) -> list[str]:
    """Process a single pipeline index from a packed argument pair.

    Convenience wrapper for map-style executors that supply one argument.

    Parameters
    ----------
    args : tuple[Pipeline, int]
        A ``(pipeline, index)`` pair.

    Returns
    -------
    list[str]
        File paths written by the sink.
    """
    pipeline, index = args
    result = pipeline[index]
    _flush_filters(pipeline, index)
    return result
def make_progress_bar(total: int, *, enabled: bool, desc: str = "run_pipeline") -> Any:
    """Build a tqdm progress bar, or return None when unavailable.

    Parameters
    ----------
    total : int
        Number of items.
    enabled : bool
        Whether to attempt tqdm import.
    desc : str
        Description for the progress bar.

    Returns
    -------
    Any
        A tqdm progress bar, or None if disabled or unavailable.
    """
    if not enabled:
        return None
    try:
        from tqdm.auto import tqdm
    except ImportError:
        # Progress was requested but the optional dependency is missing;
        # warn once and degrade to no progress display.
        import warnings

        warnings.warn(
            "progress=True was requested but tqdm is not installed. Install tqdm for progress bars: pip install tqdm",
            stacklevel=2,
        )
        return None
    return tqdm(total=total, desc=desc, unit="item")
# Hard cap on per-worker bars: WorkerProgressDisplay shows at most this many
# worker status lines so large worker pools don't flood the terminal.
_MAX_WORKER_BARS = 8
"""Maximum number of per-worker progress bars to display."""
class WorkerProgressDisplay:
    """Multi-line progress display showing per-worker activity.

    Shows one overall bar plus a status line per active worker, capped at
    :data:`_MAX_WORKER_BARS`. Degrades to a silent no-op when progress is
    disabled or *tqdm* is not installed.

    Parameters
    ----------
    total : int
        Total number of items to process.
    n_workers : int
        Number of parallel workers.
    enabled : bool
        Whether to show progress at all.
    desc : str
        Description label for the overall bar.
    """

    def __init__(
        self,
        total: int,
        n_workers: int,
        *,
        enabled: bool = True,
        desc: str = "run_pipeline",
    ) -> None:
        self._enabled = enabled
        self._n_display = min(n_workers, _MAX_WORKER_BARS)
        self._main_bar: Any = None
        self._worker_bars: list[Any] = []
        self._tqdm_cls: Any = None
        if not enabled:
            return
        try:
            from tqdm.auto import tqdm
        except ImportError:
            import warnings

            warnings.warn(
                "progress=True was requested but tqdm is not installed. "
                "Install tqdm for progress bars: pip install tqdm",
                stacklevel=3,
            )
            return
        self._tqdm_cls = tqdm
        # Overall bar occupies terminal position 0.
        self._main_bar = tqdm(
            total=total,
            desc=desc,
            unit="item",
            position=0,
            leave=True,
        )
        # One status-only bar (no counter) per displayed worker at
        # positions 1..n_display.
        self._worker_bars = [
            tqdm(
                total=0,
                desc=f" Worker {slot}",
                bar_format=" {desc}",
                position=slot + 1,
                leave=False,
            )
            for slot in range(self._n_display)
        ]

    @property
    def active(self) -> bool:
        """Return whether the display is active."""
        return self._main_bar is not None

    def worker_start(self, worker_id: int, index: int) -> None:
        """Mark a worker as starting to process an index.

        Parameters
        ----------
        worker_id : int
            Zero-based worker identifier.
        index : int
            The source index being processed.
        """
        # Workers beyond the display cap simply go unreported.
        if not self._worker_bars or worker_id >= self._n_display:
            return
        status = self._worker_bars[worker_id]
        status.set_description_str(f" Worker {worker_id}: index {index}")
        status.refresh()

    def worker_done(self, worker_id: int) -> None:
        """Mark a worker as idle and update the overall bar.

        Parameters
        ----------
        worker_id : int
            Zero-based worker identifier.
        """
        if self._main_bar is not None:
            self._main_bar.update(1)
        if self._worker_bars and worker_id < self._n_display:
            status = self._worker_bars[worker_id]
            status.set_description_str(f" Worker {worker_id}: idle")
            status.refresh()

    def complete_item(self) -> None:
        """Increment the overall bar without worker tracking.

        Use this for backends where individual worker identity is not
        available.
        """
        if self._main_bar is not None:
            self._main_bar.update(1)

    def close(self) -> None:
        """Close all bars and clean up terminal lines."""
        # Close from the bottom of the screen up, then drop references.
        while self._worker_bars:
            self._worker_bars.pop().close()
        if self._main_bar is not None:
            self._main_bar.close()
            self._main_bar = None