# Source code for physicsnemo_curator.run.thread_pool

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Thread pool execution backend.

Uses :class:`concurrent.futures.ThreadPoolExecutor` for parallel execution.
Suitable for I/O-bound workloads where the GIL is not a bottleneck.
"""

from __future__ import annotations

import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import TYPE_CHECKING, Any, ClassVar

from physicsnemo_curator.run.base import (
    RunBackend,
    RunConfig,
    WorkerProgressDisplay,
    process_single_index,
)

if TYPE_CHECKING:
    from physicsnemo_curator.core.base import Pipeline


class ThreadPoolBackend(RunBackend):
    """Execute pipeline items using a thread pool.

    This backend uses Python's :class:`concurrent.futures.ThreadPoolExecutor`.
    It's suitable for I/O-bound workloads but may not provide speedup for
    CPU-bound tasks due to the GIL.

    Backend Options
    ---------------
    max_workers : int | None
        Maximum number of threads. Defaults to ``config.resolved_n_jobs``.
    thread_name_prefix : str
        Prefix for thread names.
    """

    name: ClassVar[str] = "thread_pool"
    description: ClassVar[str] = "Thread pool executor (good for I/O-bound tasks)"
    requires: ClassVar[tuple[str, ...]] = ()

    def run(
        self,
        pipeline: Pipeline[Any],
        config: RunConfig,
    ) -> list[list[str]]:
        """Execute pipeline indices using a thread pool.

        Parameters
        ----------
        pipeline : Pipeline
            The pipeline to execute.
        config : RunConfig
            Execution configuration.

        Returns
        -------
        list[list[str]]
            Sink outputs, one list per index.
        """
        indices = (
            config.indices
            if config.indices is not None
            else list(range(len(pipeline)))
        )
        n_jobs = config.resolved_n_jobs

        # Extract ThreadPoolExecutor-specific options; ignore any other
        # backend_options keys so unrelated options don't break the executor.
        executor_kwargs = {
            k: v
            for k, v in config.backend_options.items()
            if k in ("max_workers", "thread_name_prefix", "initializer", "initargs")
        }
        if "max_workers" not in executor_kwargs:
            executor_kwargs["max_workers"] = n_jobs

        # Fix: size the progress display to the *actual* executor width.
        # backend_options may override max_workers above n_jobs, in which
        # case _get_slot() would hand out slot indices beyond the display's
        # worker count. Fall back to n_jobs if max_workers was set to None.
        n_workers = executor_kwargs.get("max_workers") or n_jobs
        display = WorkerProgressDisplay(
            total=len(indices),
            n_workers=n_workers,
            enabled=config.progress,
        )

        # Map worker threads to display slots. Each thread gets a stable
        # slot the first time it processes an item; a thread pool never
        # creates more threads than max_workers, so slots stay in range.
        _slot_lock = threading.Lock()
        _thread_slots: dict[int, int] = {}
        _next_slot = [0]  # boxed int so the closure can mutate it

        def _get_slot() -> int:
            # Return this thread's display slot, assigning one on first use.
            tid = threading.get_ident()
            with _slot_lock:
                if tid not in _thread_slots:
                    _thread_slots[tid] = _next_slot[0]
                    _next_slot[0] += 1
                return _thread_slots[tid]

        def _process_tracked(idx: int) -> list[str]:
            # Process one index while keeping the display's per-worker
            # status in sync with what this thread is doing.
            slot = _get_slot()
            display.worker_start(slot, idx)
            try:
                return process_single_index(pipeline, idx)
            finally:
                # Fix: release the slot even when processing raises, so the
                # display doesn't show a stuck worker while other in-flight
                # futures drain during executor shutdown.
                display.worker_done(slot)

        # Use as_completed so the main bar updates as soon as each
        # future finishes rather than waiting for ordered completion.
        result_map: dict[int, list[str]] = {}
        try:
            with ThreadPoolExecutor(**executor_kwargs) as executor:
                future_to_idx = {
                    executor.submit(_process_tracked, idx): idx for idx in indices
                }
                for future in as_completed(future_to_idx):
                    idx = future_to_idx[future]
                    # .result() re-raises any worker exception here.
                    result_map[idx] = future.result()
        finally:
            display.close()

        # Return results in original index order, not completion order.
        return [result_map[idx] for idx in indices]