Source code for nv_ingest_api.util.multi_processing.mp_pool_singleton

# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import logging
import math
import multiprocessing as mp
import os
from threading import Lock
from typing import Any, Callable, Optional

logger = logging.getLogger(__name__)


class SimpleFuture:
    """
    A simplified future object that uses a multiprocessing Pipe to receive its result.

    When the result() method is called, it blocks until the worker sends a
    tuple (result, error) over the pipe.
    """

    def __init__(self, parent_conn: mp.connection.Connection) -> None:
        """
        Parameters
        ----------
        parent_conn : mp.connection.Connection
            The parent end of the multiprocessing Pipe used to receive the result.
        """
        self._parent_conn: mp.connection.Connection = parent_conn
    def result(self) -> Any:
        """
        Retrieve the result from the future, blocking until it is available.

        Returns
        -------
        Any
            The result returned by the worker function.

        Raises
        ------
        Exception
            If the worker function raised an exception, it is re-raised here.
        """
        result, error = self._parent_conn.recv()
        if error is not None:
            raise error
        return result
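
A minimal sketch of the pipe protocol SimpleFuture expects (illustrative only, not part of the module): the worker side sends a (result, error) tuple over a one-way Pipe, and result() either returns the value or re-raises the error.

import multiprocessing as mp

# Success case: the "worker" side sends (value, None).
parent_conn, child_conn = mp.Pipe(duplex=False)
child_conn.send((42, None))
assert SimpleFuture(parent_conn).result() == 42

# Error case: the "worker" side sends (None, exception); result() re-raises it.
parent_conn, child_conn = mp.Pipe(duplex=False)
child_conn.send((None, ValueError("boom")))
try:
    SimpleFuture(parent_conn).result()
except ValueError as e:
    print(f"re-raised: {e}")
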
class ProcessWorkerPoolSingleton:
    """
    A singleton process worker pool using a dual-queue implementation.

    Instead of a global result queue, each submitted task gets its own Pipe.
    The submit_task() method returns a SimpleFuture, whose result() call
    blocks until the task completes.
    """

    _instance: Optional["ProcessWorkerPoolSingleton"] = None
    _lock: Lock = Lock()
    _total_workers: int = 0

    def __new__(cls) -> "ProcessWorkerPoolSingleton":
        """
        Create or return the singleton instance of ProcessWorkerPoolSingleton.

        Returns
        -------
        ProcessWorkerPoolSingleton
            The singleton instance.
        """
        logger.debug("Creating ProcessWorkerPoolSingleton instance...")
        with cls._lock:
            if cls._instance is None:
                max_worker_limit: int = int(os.environ.get("MAX_INGEST_PROCESS_WORKERS", -1))
                instance = super().__new__(cls)
                # Determine the available CPU count, preferring the process
                # affinity mask when the platform exposes it.
                available: Optional[int] = (
                    len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else os.cpu_count()
                )
                # Use 40% of the available CPUs, ensuring at least one worker.
                # os.cpu_count() may return None, so fall back to 1.
                max_workers: int = math.floor(max(1, (available or 1) * 0.4))
                if (max_worker_limit > 0) and (max_workers > max_worker_limit):
                    max_workers = max_worker_limit
                logger.debug("Creating ProcessWorkerPoolSingleton instance with max workers: %d", max_workers)
                instance._initialize(max_workers)
                logger.debug("ProcessWorkerPoolSingleton instance created: %s", instance)
                cls._instance = instance
            else:
                logger.debug("ProcessWorkerPoolSingleton instance already exists: %s", cls._instance)
        return cls._instance

    def _initialize(self, total_max_workers: int) -> None:
        """
        Initialize the worker pool with the specified number of worker processes.

        Parameters
        ----------
        total_max_workers : int
            The total number of worker processes to start.
        """
        self._total_workers = total_max_workers
        self._context: mp.context.ForkContext = mp.get_context("fork")
        # Bounded task queue: at most 2 * total_max_workers tasks may be queued.
        self._task_queue: mp.Queue = self._context.Queue(maxsize=2 * total_max_workers)
        self._next_task_id: int = 0
        self._processes: list[mp.Process] = []
        logger.debug(
            "Initializing ProcessWorkerPoolSingleton with %d workers and queue size %d.",
            total_max_workers,
            2 * total_max_workers,
        )
        for i in range(total_max_workers):
            p: mp.Process = self._context.Process(target=self._worker, args=(self._task_queue,))
            p.start()
            self._processes.append(p)
            logger.debug("Started worker process %d/%d: PID %d", i + 1, total_max_workers, p.pid)
        logger.debug("Initialized with max workers: %d", total_max_workers)

    @staticmethod
    def _worker(task_queue: mp.Queue) -> None:
        """
        Worker process that continuously processes tasks from the task queue.

        Parameters
        ----------
        task_queue : mp.Queue
            The queue from which tasks are retrieved.
        """
        logger.debug("Worker process started: PID %d", os.getpid())
        while True:
            task = task_queue.get()
            if task is None:
                # Stop signal received; exit the loop.
                logger.debug("Worker process %d received stop signal.", os.getpid())
                break
            # Unpack task: (task_id, process_fn, args, child_conn)
            task_id, process_fn, args, child_conn = task
            try:
                result = process_fn(*args)
                child_conn.send((result, None))
            except Exception as e:
                logger.error("Task %d error in worker %d: %s", task_id, os.getpid(), e)
                child_conn.send((None, e))
            finally:
                child_conn.close()
    def submit_task(self, process_fn: Callable, *args: Any) -> SimpleFuture:
        """
        Submit a task to the worker pool for asynchronous execution.

        If a single tuple is passed as the only argument, it is unpacked.

        Parameters
        ----------
        process_fn : Callable
            The function to be executed asynchronously.
        *args : Any
            The arguments to pass to the process function. If a single argument
            is a tuple, it will be unpacked as the function arguments.

        Returns
        -------
        SimpleFuture
            A future object that can be used to retrieve the result of the task.
        """
        # Unpack a single tuple argument into positional arguments.
        if len(args) == 1 and isinstance(args[0], tuple):
            args = args[0]
        parent_conn, child_conn = mp.Pipe(duplex=False)
        task_id: int = self._next_task_id
        self._next_task_id += 1
        self._task_queue.put((task_id, process_fn, args, child_conn))
        return SimpleFuture(parent_conn)
    def close(self) -> None:
        """
        Close the worker pool and terminate all worker processes.

        Sends a stop signal to each worker and waits for them to terminate.
        """
        logger.debug("Closing ProcessWorkerPoolSingleton...")
        # Send a stop signal (None) for each worker.
        for _ in range(self._total_workers):
            self._task_queue.put(None)
            logger.debug("Sent stop signal to worker.")
        # Wait for all worker processes to finish.
        for i, p in enumerate(self._processes):
            p.join()
            logger.debug("Worker process %d/%d joined: PID %d", i + 1, self._total_workers, p.pid)
        logger.debug("ProcessWorkerPoolSingleton closed.")
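
As a worked example of the sizing policy in __new__, here is a standalone sketch (illustrative only; compute_max_workers is a hypothetical helper, not part of the module):

import math

# Hypothetical helper mirroring the worker-count policy in __new__
# (illustrative sketch, not part of the module).
def compute_max_workers(available_cpus: int, env_limit: int = -1) -> int:
    # Use 40% of the available CPUs, but never fewer than one worker.
    workers = math.floor(max(1, available_cpus * 0.4))
    # A positive MAX_INGEST_PROCESS_WORKERS value acts as a hard cap.
    if env_limit > 0 and workers > env_limit:
        workers = env_limit
    return workers

print(compute_max_workers(16))               # 6  (floor of 16 * 0.4)
print(compute_max_workers(2))                # 1  (floor of max(1, 0.8))
print(compute_max_workers(16, env_limit=4))  # 4  (capped by the env var)
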
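An end-to-end usage sketch (illustrative only; assumes a Unix-like platform, since the pool uses the "fork" start method, and that the classes above are importable or defined in scope):

import os

def add(a: int, b: int) -> int:
    return a + b

if __name__ == "__main__":
    # Optional: cap the pool size before the singleton is first created
    # (the env var is only read on first construction).
    os.environ["MAX_INGEST_PROCESS_WORKERS"] = "2"

    pool = ProcessWorkerPoolSingleton()          # first call builds the pool
    assert pool is ProcessWorkerPoolSingleton()  # later calls reuse it

    # Positional args and a single tuple arg are equivalent, because
    # submit_task() unpacks a lone tuple argument.
    f1 = pool.submit_task(add, 1, 2)
    f2 = pool.submit_task(add, (3, 4))

    print(f1.result())  # 3 (blocks until the worker replies over its Pipe)
    print(f2.result())  # 7

    # Send one stop signal per worker and wait for all of them to exit.
    pool.close()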