# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base classes for pipeline execution backends.
This module defines the abstract interface that all execution backends
must implement, along with common utilities.
"""
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, ClassVar
if TYPE_CHECKING:
from physicsnemo_curator.core.base import Pipeline
from physicsnemo_curator.core.logging import DatabaseLogHandler
[docs]
@dataclass
class RunConfig:
"""Configuration for pipeline execution.
Parameters
----------
n_jobs : int
Number of parallel workers. ``1`` forces sequential execution.
``-1`` uses all available CPUs. Values ``<= 0`` follow the
convention ``cpu_count + 1 + n_jobs``.
use_tui : bool
Whether to show the full-screen Textual TUI for progress
(requires an interactive terminal). When ``False``, prints
simple timestamped log lines to the console instead.
indices : list[int] | None
Specific source indices to process. ``None`` processes all indices.
backend_options : dict[str, Any]
Additional backend-specific options.
"""
n_jobs: int = 1
use_tui: bool = True
indices: list[int] | None = None
backend_options: dict[str, Any] = field(default_factory=dict)
@property
def resolved_n_jobs(self) -> int:
"""Return the concrete positive worker count.
Returns
-------
int
Positive integer number of workers.
"""
if self.n_jobs > 0:
return self.n_jobs
cpu = os.cpu_count() or 1
resolved = cpu + 1 + self.n_jobs # -1 → cpu, -2 → cpu-1, …
return max(1, resolved)
[docs]
class RunBackend(ABC):
    """Abstract base class for pipeline execution backends.

    Subclasses implement different parallelization strategies (threading,
    multiprocessing, distributed computing, workflow orchestrators, etc.).

    Class Attributes
    ----------------
    name : str
        Unique identifier for this backend (e.g., "sequential", "process_pool").
    description : str
        Human-readable description of the backend.
    requires : tuple[str, ...]
        Importable module names this backend depends on; checked by
        :meth:`is_available`. Empty by default.
    """

    name: ClassVar[str]
    description: ClassVar[str]
    requires: ClassVar[tuple[str, ...]] = ()

    @classmethod
    def is_available(cls) -> bool:
        """Report whether every module listed in ``requires`` can be imported.

        Each name is handed to :func:`__import__` in order, and the check
        stops at the first one raising :class:`ImportError`. A successful
        check therefore actually imports the modules.

        Returns
        -------
        bool
            ``True`` if all required modules import cleanly (trivially so
            when ``requires`` is empty), ``False`` otherwise.
        """
        for module_name in cls.requires:
            try:
                __import__(module_name)
            except ImportError:
                break
        else:
            # Loop finished without hitting a missing dependency.
            return True
        return False

    @abstractmethod
    def run(
        self,
        pipeline: Pipeline[Any],
        config: RunConfig,
    ) -> list[list[str]]:
        """Execute the pipeline over the configured indices.

        Parameters
        ----------
        pipeline : Pipeline
            A fully-configured pipeline (source + filters + sink).
        config : RunConfig
            Execution configuration.

        Returns
        -------
        list[list[str]]
            Outer list is ordered by the input indices; each inner list
            contains the file paths returned by the sink for that index.
        """
        ...
def _get_worker_id() -> str:
    """Return an identifier for the current worker, based on its PID.

    Process-based backends run each worker in its own OS process, so the
    PID is enough to give every worker distinct shard file names. Threads
    inside one process share a PID and therefore share this identifier.

    Returns
    -------
    str
        The decimal process id, e.g. ``"12345"``.
    """
    pid = os.getpid()
    return f"{pid}"
def _flush_filters(pipeline: Pipeline[Any], index: int) -> None:
    """Flush stateful filters after processing an index.

    For each filter that has a ``flush`` method and an ``_output_path``
    attribute, this function resolves a worker-specific output path and
    flushes the filter's accumulated state.

    On first use in a process, the filter's original ``_output_path`` is
    remembered as ``_output_path_template``. Every later call derives the
    worker path from that template, never from an already-rewritten path.
    If the template contains the literal placeholder ``{worker_id}``, only
    that placeholder is replaced with the worker identifier. Any other
    brace sequences in the path (e.g. a directory literally named
    ``{run}``) are left untouched. Otherwise the path is rewritten in the
    same directory as ``{stem}_worker_{worker_id}{suffix}``.

    The worker ID is derived from the PID so that process-based backends
    produce unique per-worker output files.

    After flushing, any filter artifacts (reported via
    :meth:`~Filter.artifacts`) are recorded in the pipeline store when
    metrics tracking is enabled.

    Parameters
    ----------
    pipeline : Pipeline
        The pipeline whose filters should be flushed.
    index : int
        The source index that was just processed (used for artifact
        tracking).
    """
    import pathlib

    worker_id = _get_worker_id()
    for i_f, f in enumerate(pipeline.filters):
        if not (hasattr(f, "flush") and hasattr(f, "_output_path")):
            continue
        # Store the original template path once (first call in this process).
        if not hasattr(f, "_output_path_template"):
            f._output_path_template = f._output_path  # noqa: SLF001 # ty: ignore[invalid-assignment]
        template_str = str(f._output_path_template)  # noqa: SLF001 # ty: ignore[unresolved-attribute]
        if "{worker_id}" in template_str:
            # Literal substitution rather than str.format(): format() would
            # raise KeyError/IndexError on any other brace pair in the path.
            worker_path = pathlib.Path(template_str.replace("{worker_id}", worker_id))
        else:
            p = pathlib.Path(template_str)
            worker_path = p.parent / f"{p.stem}_worker_{worker_id}{p.suffix}"
        f._output_path = worker_path  # noqa: SLF001 # ty: ignore[invalid-assignment]
        f.flush()  # ty: ignore[call-non-callable]

        # Record filter artifacts if metrics tracking is enabled
        if pipeline.track_metrics:
            artifact_paths = f.artifacts()
            if artifact_paths:
                store = pipeline._get_store()  # noqa: SLF001
                store._resilient_write(  # noqa: SLF001
                    "record_filter_artifacts",
                    store.record_filter_artifacts,
                    index,
                    type(f).name,
                    i_f,
                    artifact_paths,
                )
# Per-process cache of the worker's database log handler; filled lazily by
# _ensure_worker_logging the first time setup succeeds in a worker process.
_worker_log_handler: DatabaseLogHandler | None = None


def _ensure_worker_logging(pipeline: Pipeline[Any]) -> DatabaseLogHandler | None:
    """Ensure database logging is set up in worker processes.

    Creates a DatabaseLogHandler (via ``setup_worker_logging``) bound to the
    pipeline's store. Once created, the handler is cached in a module-level
    global and returned on every later call in the same process. Nothing
    is set up when metrics tracking is off or when called from the process
    named ``"MainProcess"``. Setup failures are swallowed, and a later call
    will try again.

    Parameters
    ----------
    pipeline : Pipeline
        The pipeline being executed.

    Returns
    -------
    DatabaseLogHandler | None
        The handler, or None if metrics tracking is disabled, the caller is
        the main process, or setup failed.
    """
    global _worker_log_handler  # noqa: PLW0603

    # Fast path: this process already has a handler.
    if _worker_log_handler is not None:
        return _worker_log_handler

    if not pipeline.track_metrics:
        return None

    import multiprocessing

    # Only worker processes get a database handler; the parent is skipped.
    if multiprocessing.current_process().name == "MainProcess":
        return None

    try:
        from physicsnemo_curator.core.logging import setup_worker_logging

        _worker_log_handler = setup_worker_logging(pipeline._get_store())  # noqa: SLF001
    except Exception:  # noqa: BLE001
        # Logging is best-effort: never abort processing because of it.
        return None
    return _worker_log_handler
[docs]
def process_single_index(pipeline: Pipeline[Any], index: int) -> list[str]:
    """Process a single pipeline index.

    Defined at module level so it can be pickled for multiprocess
    backends. After the pipeline runs for *index*, stateful filters with a
    ``flush`` method are flushed to shard files. When a worker log handler
    exists, it is tagged with *index* while processing. It is flushed and
    reset afterwards, even if processing raises.

    Parameters
    ----------
    pipeline : Pipeline
        The pipeline to execute.
    index : int
        The index to process.

    Returns
    -------
    list[str]
        File paths written by the sink.
    """
    # Database logging is configured once per worker process.
    log_handler = _ensure_worker_logging(pipeline)
    if log_handler is not None:
        log_handler.set_current_index(index)

    try:
        written_paths = pipeline[index]
        _flush_filters(pipeline, index)
    finally:
        # Persist this index's log records before returning (or raising).
        if log_handler is not None:
            log_handler.flush()
            log_handler.set_current_index(None)
    return written_paths
[docs]
def process_single_index_packed(args: tuple[Pipeline[Any], int]) -> list[str]:
    """Process a single pipeline index from one packed argument.

    Adapter for ``map``-style APIs that pass a single argument per call. It
    unpacks the pair and defers to :func:`process_single_index`, which
    also handles worker logging setup.

    Parameters
    ----------
    args : tuple[Pipeline, int]
        A ``(pipeline, index)`` pair.

    Returns
    -------
    list[str]
        File paths written by the sink.
    """
    (pipeline, index) = args
    return process_single_index(pipeline=pipeline, index=index)
[docs]
def intersect_partitions(
source_groups: list[list[int]] | None,
sink_groups: list[list[int]] | None,
) -> list[list[int]] | None:
"""Intersect source and sink partition constraints.
Both the source and sink may independently declare that certain
indices MUST be processed by the same worker. This function
computes the finest partition that satisfies both constraints,
or raises :class:`ValueError` if the constraints are incompatible.
Parameters
----------
source_groups : list[list[int]] | None
Groups from :meth:`Source.partition_indices`, or ``None``.
sink_groups : list[list[int]] | None
Groups from :meth:`Sink.partition_indices`, or ``None``.
Returns
-------
list[list[int]] | None
Merged groups satisfying both constraints, or ``None`` if
neither source nor sink requires partitioning.
Raises
------
ValueError
If the source and sink constraints are incompatible (one
requires indices together that the other requires apart).
"""
if source_groups is None and sink_groups is None:
return None
if source_groups is None:
return sink_groups
if sink_groups is None:
return source_groups
# Build index → group_id mappings.
source_map: dict[int, int] = {}
for gid, group in enumerate(source_groups):
for idx in group:
source_map[idx] = gid
sink_map: dict[int, int] = {}
for gid, group in enumerate(sink_groups):
for idx in group:
sink_map[idx] = gid
# Group by (source_group_id, sink_group_id) pair.
from collections import defaultdict
pair_groups: dict[tuple[int, int], list[int]] = defaultdict(list)
all_indices = set(source_map.keys()) | set(sink_map.keys())
for idx in all_indices:
s_gid = source_map.get(idx, -1)
k_gid = sink_map.get(idx, -1)
pair_groups[(s_gid, k_gid)].append(idx)
# Validate: no original group was split.
# For each source group, all its indices must map to the same
# intersection group. If they don't, the constraints conflict.
intersection_groups = list(pair_groups.values())
# Build reverse: idx → intersection group id
idx_to_intersection: dict[int, int] = {}
for ig_id, ig in enumerate(intersection_groups):
for idx in ig:
idx_to_intersection[idx] = ig_id
# Check source groups are not split.
for s_gid, s_group in enumerate(source_groups):
ig_ids = {idx_to_intersection[idx] for idx in s_group}
if len(ig_ids) > 1:
# Find conflicting sink groups
conflicting_sinks = {sink_map.get(idx, -1) for idx in s_group}
msg = (
f"Incompatible partition constraints: source requires indices "
f"{s_group} to be processed together (source group {s_gid}), "
f"but they span {len(conflicting_sinks)} different sink groups. "
f"Adjust sink chunk_size so chunk boundaries align with source "
f"file boundaries."
)
raise ValueError(msg)
# Check sink groups are not split.
for k_gid, k_group in enumerate(sink_groups):
ig_ids = {idx_to_intersection[idx] for idx in k_group}
if len(ig_ids) > 1:
# Find conflicting source groups
conflicting_sources = {source_map.get(idx, -1) for idx in k_group}
msg = (
f"Incompatible partition constraints: sink requires indices "
f"{k_group} to be processed together (sink group {k_gid}), "
f"but they span {len(conflicting_sources)} different source groups. "
f"Adjust sink chunk_size so chunk boundaries align with source "
f"file boundaries."
)
raise ValueError(msg)
# Sort groups by their minimum index for deterministic ordering.
intersection_groups.sort(key=lambda g: min(g))
# Sort indices within each group.
for g in intersection_groups:
g.sort()
return intersection_groups
[docs]
def batch_groups(groups: list[list[int]], n_workers: int) -> list[list[int]]:
    """Merge partition groups into at most *n_workers* batches.

    When there are more groups than workers, groups are distributed
    across workers using a greedy bin-packing strategy. Groups are taken
    largest first, and each is assigned to the batch with the fewest
    indices so far, to balance load.

    Each batch is a flat list of indices preserving the constraint
    that indices from the same original group are always together.

    Parameters
    ----------
    groups : list[list[int]]
        Partition groups (from :func:`intersect_partitions`).
    n_workers : int
        Maximum number of worker batches. Values below ``1`` are treated
        as ``1`` (everything in a single batch).

    Returns
    -------
    list[list[int]]
        At most *n_workers* batches, each a list of indices. If there are
        no more groups than workers, *groups* itself is returned.
    """
    import heapq

    # Clamp like RunConfig.resolved_n_jobs does; otherwise a non-positive
    # count would pop from an empty heap and fail with an opaque IndexError.
    n_workers = max(1, n_workers)
    if len(groups) <= n_workers:
        return groups

    # Greedy: assign largest groups first to lightest batch.
    # Sort groups descending by size for best packing (stable for ties).
    sorted_groups = sorted(groups, key=len, reverse=True)

    # Min-heap of (batch_size, batch_index); ties go to the lower index.
    batches: list[list[int]] = [[] for _ in range(n_workers)]
    heap: list[tuple[int, int]] = [(0, i) for i in range(n_workers)]
    heapq.heapify(heap)

    for group in sorted_groups:
        size, batch_idx = heapq.heappop(heap)
        batches[batch_idx].extend(group)
        heapq.heappush(heap, (size + len(group), batch_idx))

    # Drop batches that received nothing (only possible with empty groups).
    return [b for b in batches if b]
[docs]
def process_index_group(pipeline: Pipeline[Any], indices: list[int]) -> dict[int, list[str]]:
    """Process a group of pipeline indices sequentially.

    Used when the sink provides :meth:`partition_indices` to batch
    related indices onto the same worker (e.g. for chunk-aligned
    parallel writes). For each index the pipeline is run and stateful
    filters are flushed, as in :func:`process_single_index`.

    Worker log records are flushed after every index, mirroring
    :func:`process_single_index`, rather than once at the end of the
    group. Each flush therefore happens while the matching index is
    still set on the handler. Less is left unflushed if the worker dies
    abruptly part-way through a long group.

    Parameters
    ----------
    pipeline : Pipeline
        The pipeline to execute.
    indices : list[int]
        The indices to process (in order).

    Returns
    -------
    dict[int, list[str]]
        Mapping of index to sink output paths.

    Raises
    ------
    Exception
        Any exception raised while processing an index propagates; results
        for earlier indices in the group are not returned in that case.
    """
    # Set up database logging in worker processes (once per process)
    handler = _ensure_worker_logging(pipeline)
    results: dict[int, list[str]] = {}
    try:
        for idx in indices:
            if handler is not None:
                handler.set_current_index(idx)
            try:
                results[idx] = pipeline[idx]
                _flush_filters(pipeline, idx)
            finally:
                # Persist this index's log records before moving on.
                if handler is not None:
                    handler.flush()
    finally:
        if handler is not None:
            handler.set_current_index(None)
    return results
[docs]
def make_progress_bar(total: int, *, enabled: bool, desc: str = "run_pipeline") -> Any:
    """Create a tqdm progress bar, or return ``None``.

    Parameters
    ----------
    total : int
        Number of items.
    enabled : bool
        Whether to attempt tqdm import.
    desc : str
        Description for the progress bar.

    Returns
    -------
    Any
        A ``tqdm.auto.tqdm`` bar, or ``None`` when *enabled* is false or
        importing/constructing the bar raises :class:`ImportError`. In the
        latter case a :class:`UserWarning` is issued first.
    """
    if not enabled:
        return None

    bar = None
    try:
        from tqdm.auto import tqdm as _tqdm

        bar = _tqdm(total=total, desc=desc, unit="item")
    except ImportError:
        import warnings

        warnings.warn(
            "progress=True was requested but tqdm is not installed. "
            "Install tqdm for progress bars: pip install tqdm",
            stacklevel=2,
        )
    return bar
# Upper bound on how many per-worker bars WorkerProgressDisplay renders.
_MAX_WORKER_BARS: int = 8
"""Maximum number of per-worker progress bars to display."""
[docs]
class WorkerProgressDisplay:
    """Multi-line progress display showing per-worker activity.

    Renders an overall progress bar plus optionally one bar per active worker
    (up to :data:`_MAX_WORKER_BARS`). Falls back gracefully when
    *tqdm* is not installed or progress is disabled.

    The display can be used as a context manager; :meth:`close` is called
    when the ``with`` block exits.

    Parameters
    ----------
    total : int
        Total number of items to process.
    n_workers : int
        Number of parallel workers.
    enabled : bool
        Whether to show progress at all.
    desc : str
        Description label for the overall bar.
    show_worker_bars : bool
        Whether to show per-worker progress bars. Defaults to False
        to avoid console conflicts with multiple processes.
    """

    def __init__(
        self,
        total: int,
        n_workers: int,
        *,
        enabled: bool = True,
        desc: str = "run_pipeline",
        show_worker_bars: bool = False,
    ) -> None:
        self._enabled = enabled
        self._show_worker_bars = show_worker_bars
        self._n_display = min(n_workers, _MAX_WORKER_BARS) if show_worker_bars else 0
        self._main_bar: Any = None
        self._worker_bars: list[Any] = []
        self._tqdm_cls: Any = None

        if not enabled:
            return

        try:
            from tqdm.auto import tqdm

            self._tqdm_cls = tqdm
        except ImportError:
            import warnings

            warnings.warn(
                "progress=True was requested but tqdm is not installed. "
                "Install tqdm for progress bars: pip install tqdm",
                stacklevel=3,
            )
            return

        # Position 0: overall bar
        self._main_bar = tqdm(
            total=total,
            desc=desc,
            unit="item",
            position=0,
            leave=True,
        )

        # Positions 1..n_display: per-worker bars (only if requested)
        if show_worker_bars:
            for w in range(self._n_display):
                bar = tqdm(
                    total=0,
                    desc=f" Worker {w}",
                    bar_format=" {desc}",
                    position=w + 1,
                    leave=False,
                )
                self._worker_bars.append(bar)

    def __enter__(self) -> WorkerProgressDisplay:
        """Return the display itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
        """Close all bars; exceptions from the ``with`` body propagate."""
        self.close()

    @property
    def active(self) -> bool:
        """Return whether the display is active."""
        return self._main_bar is not None

    def _bar_for(self, worker_id: int) -> Any:
        """Return the displayed bar for *worker_id*, or None if it has none."""
        # Reject negative ids explicitly: plain list indexing would silently
        # address bars from the end of the list.
        if 0 <= worker_id < min(self._n_display, len(self._worker_bars)):
            return self._worker_bars[worker_id]
        return None

    def worker_start(self, worker_id: int, index: int) -> None:
        """Mark a worker as starting to process an index.

        Ids without a displayed bar (negative, or beyond the displayed
        count) are ignored.

        Parameters
        ----------
        worker_id : int
            Zero-based worker identifier.
        index : int
            The source index being processed.
        """
        bar = self._bar_for(worker_id)
        if bar is not None:
            bar.set_description_str(f" Worker {worker_id}: index {index}")
            bar.refresh()

    def worker_done(self, worker_id: int) -> None:
        """Mark a worker as idle (does NOT update main bar - use complete_item).

        Ids without a displayed bar are ignored.

        Parameters
        ----------
        worker_id : int
            Zero-based worker identifier.
        """
        bar = self._bar_for(worker_id)
        if bar is not None:
            bar.set_description_str(f" Worker {worker_id}: idle")
            bar.refresh()

    def complete_item(self) -> None:
        """Increment the overall bar without worker tracking.

        Use this for backends where individual worker identity is not
        available.
        """
        if self._main_bar is not None:
            self._main_bar.update(1)

    def close(self) -> None:
        """Close all bars and clean up terminal lines. Safe to call twice."""
        # Close worker bars bottom-up before the overall bar.
        for bar in reversed(self._worker_bars):
            bar.close()
        self._worker_bars.clear()
        if self._main_bar is not None:
            self._main_bar.close()
            self._main_bar = None