Source code for nv_dfm_core.targets.flare._flare_app

#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
from pathlib import Path
from typing import Any

from nvflare.fuel.f3.cellnet.fqcn import FQCN  # pyright: ignore[reportMissingImports]
from nvflare.job_config.api import FedJob  # pyright: ignore[reportMissingImports]

from nv_dfm_core.api import PreparedPipeline
from nv_dfm_core.exec import Frame

from ._controller import Controller
from ._executor import Executor
from ._flare_options import FlareOptions


class FlareApp:
    """Build an NVFlare ``FedJob`` for a prepared pipeline and run or export it.

    The job is built lazily: nothing is configured until :meth:`simulate` or
    :meth:`export` is called for the first time, and the resulting job object
    is cached on the instance and reused by later calls.
    """

    def __init__(
        self,
        pipeline: PreparedPipeline,
        input_params: list[tuple[Frame, dict[str, Any]]],
        options: FlareOptions,
        force_modgen: bool = False,
    ):
        """Store the inputs needed to build the job later.

        Args:
            pipeline: The prepared pipeline to turn into a Flare job.
            input_params: Passed as-is to ``pipeline.bind_net_irs`` when the
                job is built.
            options: Forwarded to the server-side ``Controller``.
            force_modgen: Forwarded to the ``Controller`` and every
                ``Executor``.
        """
        self._pipeline = pipeline
        # Built on first use by _prepare(); None until then.
        self._job: FedJob | None = None
        self._input_params = input_params
        self._force_modgen = force_modgen
        self._options = options

    def _prepare(self) -> None:
        """Create the ``FedJob`` and attach the controller and executors.

        The result is stored in ``self._job``.

        Raises:
            ValueError: If the root server is not among the pipeline's
                participating sites (raised by ``list.remove``).
        """
        # Create the Flare job configuration object
        rnd_id = str(uuid.uuid4())
        pipe_name = (
            f"{self._pipeline.pipeline_name}"
            if self._pipeline.pipeline_name
            else "pipeline"
        )
        job_name = (
            f"{self._pipeline.federation_module}-{self._pipeline.homesite}"
            f"-{pipe_name}-{rnd_id}"
        )

        # Work on a copy so that removing the server entry cannot alter
        # whatever collection the pipeline handed out.
        clients = list(self._pipeline.get_participating_sites())
        clients.remove(FQCN.ROOT_SERVER)

        job = FedJob(
            name=job_name,
            min_clients=len(clients),
            mandatory_clients=clients,
        )

        # Send DfmExecutor to each client and Controller to root server
        bound_net_irs = self._pipeline.bind_net_irs(input_params=self._input_params)

        # handle the server
        # NOTE: the server_netir should not be False ever, but some tests
        # currently rely on this so I keep it in.
        server_netir = (
            bound_net_irs.pop(FQCN.ROOT_SERVER).model_dump()
            if FQCN.ROOT_SERVER in bound_net_irs
            else False
        )
        controller = Controller(
            submitted_api_version=self._pipeline.api_version,
            federation_name=self._pipeline.federation_module,
            homesite=self._pipeline.homesite,
            bound_net_ir=server_netir,
            options=self._options,
            force_modgen=self._force_modgen,
        )
        job.to_server(controller)

        # and the clients
        for bound_net_ir in bound_net_irs.values():
            assert bound_net_ir.site != FQCN.ROOT_SERVER, (
                "Root server net IR should have been handled above"
            )
            executor = Executor(
                submitted_api_version=self._pipeline.api_version,
                federation_name=self._pipeline.federation_module,
                homesite=self._pipeline.homesite,
                bound_net_ir=bound_net_ir.model_dump(),
                force_modgen=self._force_modgen,
            )
            job.to(obj=executor, target=bound_net_ir.site)

        self._job = job

    def _ensure_job(self) -> FedJob:
        """Return the cached job, building it first if it does not exist yet."""
        if not self._job:
            self._prepare()
        assert self._job
        return self._job

    def simulate(self, workspace: str | Path | None = None) -> str:
        """Run the job through ``FedJob.simulator_run``.

        Args:
            workspace: Directory handed to the simulator. When omitted (or
                otherwise falsy), ``/tmp/dfm_workspace`` is used.

        Returns:
            The workspace path that was used, as a string.
        """
        if not workspace:
            workspace = "/tmp/dfm_workspace"
        workspace = str(workspace)

        self._ensure_job().simulator_run(workspace, n_clients=1)
        return workspace

    def export(self, workspace: str | Path) -> Path:
        """Export the job to the given workspace.

        Args:
            workspace: Directory to export into. A plain string is accepted
                as well as a ``Path``, mirroring :meth:`simulate`.

        Returns:
            The path to the job directory inside the workspace, i.e.
            ``workspace / <job name>``.
        """
        # Normalise first so that joinpath() below also works for str input.
        workspace = Path(workspace)
        job = self._ensure_job()
        job.export_job(str(workspace))
        return workspace.joinpath(job.name)