# Source code for nv_dfm_core.targets.local._job_execution

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing
from logging import Logger
from typing import TYPE_CHECKING

from nv_dfm_core.exec import TokenPackage

from ._context import get_spawn_context

if TYPE_CHECKING:
    from ._federation_runner import FederationRunner
    from ._job_handle import JobHandle
    from ._job_runner import JobSubmission
else:
    # Runtime placeholders: the real classes are only imported for type
    # checking, so these names must still resolve at runtime where they are
    # referenced in annotations. NOTE(review): presumably this avoids a
    # circular import between the local-target modules — confirm.
    FederationRunner = object
    JobHandle = object  # type: ignore
    JobSubmission = object

# Use explicit spawn context for all multiprocessing objects.
_spawn_ctx = get_spawn_context()


[docs] class JobExecution: """ A JobExecution object has a JobRunner process for every site to execute a pipeline. We reuse executions in a pool. We do this because we want to a) start all the background processes in a pool and b) we cannot send (normal) queues to processes after they have been created. Therefore, we pre-create JobExecution objects which are all essentially full federations with a process for every existing site. For those sites we can create all the comm queues at startup. """ def __init__( self, id: int, federation: "FederationRunner", sites: list[str], inter_job_queue: multiprocessing.Queue, logger: Logger, ): self.id: int = id self._federation: "FederationRunner" = federation self._inter_job_queue: multiprocessing.Queue = inter_job_queue self._logger: Logger = logger self._ctx = _spawn_ctx # create the input queues for each site self._channels: dict[str, multiprocessing.Queue[TokenPackage]] = { site: self._ctx.Queue() for site in sites } self.yield_queue: multiprocessing.Queue[TokenPackage] = self._ctx.Queue() # Create worker processes from ._job_runner import JobRunner self._workers: dict[str, JobRunner] = {} for site in sites: worker_logger = self._logger.getChild("JobRunner_" + site) self._workers[site] = JobRunner( site=site, channels=self._channels, yield_queue=self.yield_queue, inter_job_queue=self._inter_job_queue, # pyright: ignore[reportUnknownMemberType] logger=worker_logger, ) # start the workers for worker in self._workers.values(): worker.start() self._logger.info( f"JobRunner {self.id} started {len(self._workers)} worker processes" ) # the job handle this execution is assigned to self._job_handle: "JobHandle | None" = None
[docs] def is_alive(self) -> bool: """Returns True if the execution is alive.""" return all(worker.is_alive() for worker in self._workers.values())
[docs] def get_dead_sites(self) -> list[str]: """Returns a list of sites whose worker processes have died.""" return [site for site, worker in self._workers.items() if not worker.is_alive()]
def abort_job(self): if self._job_handle is not None: for worker in self._workers.values(): worker.abort_netrunner_event.set()
[docs] def shutdown(self): """Shuts down the execution. Note that workers are generally waiting for new data on the command queue and only have a timeout of 0.5 seconds on the command queue. So we give them 5 seconds to finish.""" for worker in self._workers.values(): # try to shut down gracefully worker.abort_netrunner_event.set() worker.shutdown_event.set() for worker in self._workers.values(): # NetRunner shutdown can take up to 5 seconds per site, so we give it some time. worker.join(timeout=5.0) for worker in self._workers.values(): if worker.is_alive(): self._logger.warning( f"JobRunner {self.id} shutting down worker {worker.pid} but it is still alive. Terminating." ) worker.terminate() worker.join(timeout=1.0) # Close all multiprocessing queues to avoid leaked semaphores for q in self._channels.values(): try: q.close() q.join_thread() except Exception: pass try: self.yield_queue.close() self.yield_queue.join_thread() except Exception: pass for worker in self._workers.values(): try: worker.command_queue.close() worker.command_queue.join_thread() worker.ack_command_queue.close() worker.ack_command_queue.join_thread() except Exception: pass
[docs] def job_id(self) -> str | None: """Returns the job ID if this execution is assigned to a job. Otherwise, returns None.""" if self._job_handle is None: return None return self._job_handle.job_id
[docs] def handle(self): """Returns the job handle if this execution is assigned. Otherwise, returns None.""" return self._job_handle
[docs] def receive_inter_job_token(self, token: TokenPackage): """Called by the FederationRunner background thread that polls the inter job queue. Receives an inter-job token and forwards it to the correct queue.""" assert token.target_job == self.job_id(), ( f"Token target job {token.target_job} does not match execution job {self.job_id()}" ) self._channels[token.target_site].put(token)
[docs] def handle_was_assigned(self, handle: "JobHandle"): """Assign a job handle to this execution.""" if self._job_handle is not None: raise RuntimeError( f"JobExecution {self.id} is already assigned to job {self._job_handle.job_id}" ) self._job_handle = handle
[docs] def handle_was_released(self): """Called when the job execution completes.""" self._job_handle = None
[docs] def submit( self, site: str, js: JobSubmission, ): """Called by the job object when it is started.""" if not self.is_alive(): dead_sites = self.get_dead_sites() raise RuntimeError( f"JobRunner {self.id} is not alive. Workers for sites {dead_sites} have died." ) worker = self._workers[site] worker.command_queue.put(js) ack = worker.ack_command_queue.get() assert ack == js.job_id, f"JobRunner {self.id} received ack for wrong job {ack}"