Source code for nv_dfm_core.api._prepared_pipeline
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
# Just a mock for now
from logging import Logger
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
    # Real types are only needed by static type checkers.
    from nv_dfm_core.exec import Frame
    from nv_dfm_core.gen.modgen.ir import BoundNetIR, NetIR, YieldPlace
else:
    # At runtime these names only appear in annotations; alias them all to Any
    # so that annotation evaluation does not require importing the real modules.
    BoundNetIR = NetIR = Frame = YieldPlace = Any
[docs]
class PreparedPipeline:
"""A pipeline that has been prepared and optimized for execution in a federation.
PreparedPipeline contains the intermediate representation (IR) for each
participating site, along with metadata needed for execution.
"""
def __init__(
self,
api_version: str,
federation_module_name: str,
homesite: str,
pipeline_name: str | None,
net_irs: list[NetIR],
yield_places: list[YieldPlace],
):
self._api_version: str = api_version
self._federation_module_name: str = federation_module_name
self._homesite: str = homesite
self._name: str | None = pipeline_name
self._net_irs: dict[str, NetIR] = {net_ir.site: net_ir for net_ir in net_irs}
self._yield_places: list[YieldPlace] = yield_places
@property
def api_version(self) -> str:
return self._api_version
@api_version.setter
def api_version(self, value: str) -> None:
self._api_version = value
@property
def federation_module(self) -> str:
return self._federation_module_name
@property
def homesite(self) -> str:
return self._homesite
@property
def pipeline_name(self) -> str | None:
return self._name
def net_irs(self) -> dict[str, NetIR]:
return self._net_irs
def has_param_places(self) -> bool:
return any(
net_ir.find_places(kind="data", origin="external")
for net_ir in self._net_irs.values()
)
[docs]
def check_callbacks(self, callback_places: list[str], logger: Logger):
"""Check and warn if callback registration doesn't match yield places in the pipeline."""
yield_places = set([yield_place.place for yield_place in self._yield_places])
cb_places = set(callback_places)
if not yield_places.issubset(cb_places):
logger.warning(
f"No callbacks were registered for yield places {yield_places - cb_places}."
)
if not cb_places.issubset(yield_places):
logger.warning(
f"Some callbacks were registered that don't have corresponding yields: {cb_places - yield_places}."
)
def get_participating_sites(self) -> list[str]:
return list(self._net_irs.keys())
[docs]
def bind_net_ir(
self,
site: str,
input_params: list[tuple[Frame, dict[str, Any]]],
) -> BoundNetIR:
"""Bind a specific site's NetIR with input parameters to create a BoundNetIR ready for execution."""
from nv_dfm_core.gen.modgen.ir import BoundNetIR
netir = self._net_irs[site]
return BoundNetIR.bind_netir(netir, input_params)
[docs]
def bind_net_irs(
self, input_params: list[tuple[Frame, dict[str, Any]]]
) -> dict[str, BoundNetIR]:
"""Bind all site NetIRs with input parameters to create BoundNetIRs for each participating site."""
from nv_dfm_core.gen.modgen.ir import BoundNetIR
return {
site: BoundNetIR.bind_netir(net_ir, input_params)
for site, net_ir in self._net_irs.items()
}