# Source code for physicsnemo_curator.run.dask

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dask execution backend.

Uses ``dask.bag`` for parallel and distributed execution.
Supports local execution and can scale to clusters.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar

from physicsnemo_curator.run.base import (
    RunBackend,
    RunConfig,
    process_single_index_packed,
)

if TYPE_CHECKING:
    from physicsnemo_curator.core.base import Pipeline


class DaskBackend(RunBackend):
    """Execute pipeline items using Dask bags.

    Dask provides parallel execution that can scale from a single machine
    to a distributed cluster. This backend uses ``dask.bag`` for
    task-parallel execution.

    .. warning::

        Stateful filters accumulate per-worker state that is **not**
        merged back. Design a post-hoc merge strategy if needed.

    Backend Options
    ---------------
    scheduler : str
        Dask scheduler ("synchronous", "threads", "processes", "distributed").
    num_workers : int | None
        Number of workers (for local schedulers).
    client : distributed.Client | None
        Pre-configured Dask distributed client.
    """

    name: ClassVar[str] = "dask"
    description: ClassVar[str] = "Dask bags for parallel/distributed execution"
    requires: ClassVar[tuple[str, ...]] = ("dask",)

    def run(
        self,
        pipeline: Pipeline[Any],
        config: RunConfig,
    ) -> list[list[str]]:
        """Execute pipeline indices using Dask.

        Parameters
        ----------
        pipeline : Pipeline
            The pipeline to execute.
        config : RunConfig
            Execution configuration.

        Returns
        -------
        list[list[str]]
            Sink outputs, one list per index.

        Raises
        ------
        ImportError
            If dask is not installed.
        """
        try:
            import dask.bag as db
        except ImportError:
            msg = "The 'dask' backend requires dask. Install with: pip install 'physicsnemo-curator[dask]'"
            raise ImportError(msg) from None

        indices = (
            config.indices
            if config.indices is not None
            else list(range(len(pipeline)))
        )
        # Nothing to process: avoid creating a bag with zero partitions,
        # which db.from_sequence does not accept.
        if not indices:
            return []

        n_jobs = config.resolved_n_jobs

        # Create dask bag from index-pipeline pairs.
        bag = db.from_sequence(
            [(pipeline, i) for i in indices],
            npartitions=min(n_jobs, len(indices)),
        )

        # Extract compute options recognized by dask's local schedulers.
        compute_kwargs = {
            k: v
            for k, v in config.backend_options.items()
            if k in ("scheduler", "num_workers")
        }

        task = bag.map(process_single_index_packed)

        # Optional progress bar: best-effort (skipped if dask.diagnostics
        # is unavailable). Use ProgressBar as a context manager so the
        # global diagnostic callback is unregistered after compute — the
        # previous register()-only approach leaked it into every later
        # compute in the process.
        if config.progress:
            try:
                from dask.diagnostics import ProgressBar
            except ImportError:
                ProgressBar = None
            if ProgressBar is not None:
                with ProgressBar():
                    results: list[list[str]] = task.compute(**compute_kwargs)
                    return results

        results = task.compute(**compute_kwargs)
        return results