# Source code for nv_dfm_core.targets.flare._flare_router
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pyright: reportMissingTypeStubs=false
from logging import Logger

from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import (
    Shareable,
)
from nvflare.apis.signal import Signal
from nvflare.apis.utils.reliable_message import ReliableMessage
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from typing_extensions import override

from nv_dfm_core.exec import Router, TokenPackage
from nv_dfm_core.targets.flare._app_io_manager import AppIOManager

from ._defs import Constant
# [docs]
class FlareRouter(Router):
    """Router for the Flare target.

    Token packages addressed to another site of the current job are wrapped
    in a ``Shareable`` and sent with ``ReliableMessage.send_request``. A
    router that is not on ``FQCN.ROOT_SERVER`` always sends to the root
    server; the router on the root server either hands the package to the
    ``AppIOManager`` (homesite / yielded packages) or forwards it to the
    target client. Sending to a different job is not supported and raises.

    ``this_site``, ``homesite`` and ``job_id`` are read from ``self`` and are
    presumably provided by the ``Router`` base class (not visible here).
    """

    # Timeouts forwarded unchanged to ``ReliableMessage.send_request``.
    # NOTE(review): presumably seconds -- confirm against the NVFlare docs.
    _PER_MSG_TIMEOUT = 10
    _TX_TIMEOUT = 50

    def __init__(
        self,
        fl_ctx: FLContext,
        abort_signal: Signal,
        logger: Logger,
        app_io_manager: AppIOManager | None = None,  # for the server only
        client_names: list[str]
        | None = None,  # for the server controller, just to check that recipients are valid
    ):
        """Store the Flare context objects needed to send messages.

        Args:
            fl_ctx: Flare context passed to every reliable send.
            abort_signal: Signal passed to every reliable send.
            logger: Logger used to report sends whose reply is not OK.
            app_io_manager: Receiver for packages that end on the server
                side; must be set when this router runs on the root server.
            client_names: Optional allow-list of client sites; when given,
                the server refuses to forward to a site not in the list.
        """
        super().__init__()
        self._fl_ctx: FLContext = fl_ctx
        self._abort_signal: Signal = abort_signal
        # Kept under a Flare-specific name so it cannot shadow an attribute
        # of the Router base class (which is not visible from this file).
        self._flare_logger: Logger = logger
        self._client_names: list[str] | None = client_names
        self._app_io_manager: AppIOManager | None = app_io_manager

    def _send_flare_message(self, recipient: str, token_package: TokenPackage):
        """Send ``token_package`` to ``recipient`` as a reliable Flare message.

        The package is serialised with ``model_dump()`` and stored in the
        request under ``Constant.MSG_KEY_TOKEN_PACKAGE_DICT``. A reply whose
        return code is not OK is logged as an error; it is not raised, so a
        failed delivery does not abort the caller.
        """
        request = Shareable()
        request[Constant.MSG_KEY_TOKEN_PACKAGE_DICT] = token_package.model_dump()
        reply = ReliableMessage.send_request(
            target=recipient,
            topic=Constant.TOPIC_SEND_TO_PLACE,
            request=request,
            fl_ctx=self._fl_ctx,
            per_msg_timeout=self._PER_MSG_TIMEOUT,
            tx_timeout=self._TX_TIMEOUT,
            abort_signal=self._abort_signal,
        )
        # NOTE(review): assumes the reply Shareable carries the delivery
        # status in its return code (OK when absent) -- confirm against the
        # ReliableMessage documentation.
        return_code = reply.get_return_code() if reply is not None else None
        if return_code != ReturnCode.OK:
            self._flare_logger.error(
                "Sending token package to %s failed with return code %s",
                recipient,
                return_code,
            )

    @override
    def _send_other_job_remote_token_package_sync(self, token_package: TokenPackage):
        """Reject packages addressed to a different job.

        Raises:
            RuntimeError: Always; cross-job communication is not supported.
        """
        raise RuntimeError(
            f"Target job {token_package.target_job} is not the current job {self.job_id}. Cross-job communication is not yet supported in the Flare target."
        )

    @override
    def _send_this_job_remote_token_package_sync(self, token_package: TokenPackage):
        """Route a package of the current job to another site.

        Raises:
            RuntimeError: On the server, when ``client_names`` was given and
                the target site is not one of them.
        """
        assert (
            token_package.target_site != self.this_site
        )  # this package is certainly not for us
        assert token_package.target_job == self.job_id

        # if we are not on the server, we need to send the message via the server
        if self.this_site != FQCN.ROOT_SERVER:
            self._send_flare_message(
                recipient=FQCN.ROOT_SERVER, token_package=token_package
            )
        # we are on the server, message is for the homesite
        elif token_package.target_site == self.homesite:
            assert self._app_io_manager is not None
            # target is the app
            self._app_io_manager.receive_token_package(token_package)
        # we are on the server, target is a client
        else:
            if (
                self._client_names is not None
                and token_package.target_site not in self._client_names
            ):
                raise RuntimeError(f"Invalid target site: {token_package.target_site}")
            self._send_flare_message(
                recipient=token_package.target_site, token_package=token_package
            )

    @override
    def _send_this_job_yield_token_package_sync(self, token_package: TokenPackage):
        """Deliver a yielded package of the current job.

        Off the root server the package is sent to the root server; on the
        root server it is handed directly to the ``AppIOManager``.
        """
        if self.this_site != FQCN.ROOT_SERVER:
            self._send_flare_message(
                recipient=FQCN.ROOT_SERVER, token_package=token_package
            )
        else:
            assert self._app_io_manager is not None
            self._app_io_manager.receive_token_package(token_package)