# Source code for tensorrt_llm.executor.utils

import asyncio
import concurrent.futures
import os
from concurrent.futures import ProcessPoolExecutor
from queue import Empty, Queue
from typing import Any, Callable, List, NamedTuple, Optional

from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.llmapi.utils import print_colored_debug
from tensorrt_llm.logger import logger

from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession,
                                  RemoteMpiCommSessionClient)
from ..llmapi.utils import print_colored_debug

# Evaluated once, at import time: the flag is on only when the environment
# variable is set to exactly "1" (unset or any other value means off).
PERIODICAL_RESP_IN_AWAIT = (
    os.environ.get("TLLM_EXECUTOR_PERIODICAL_RESP_IN_AWAIT") == "1")


def get_spawn_proxy_process_ipc_addr_env() -> str | None:
    ''' Get the IPC address for the spawn proxy process dynamically. '''
    return os.getenv("TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR")


def get_spawn_proxy_process_env() -> bool:
    '''Report whether TLLM_SPAWN_PROXY_PROCESS is set to exactly "1".

    Unset, empty, or any other value yields False. The environment is re-read
    on every call.
    '''
    flag = os.environ.get("TLLM_SPAWN_PROXY_PROCESS")
    return flag == "1"


# Emit a one-time, import-time log line when the
# TLLM_EXECUTOR_PERIODICAL_RESP_IN_AWAIT flag turned PERIODICAL_RESP_IN_AWAIT on.
if PERIODICAL_RESP_IN_AWAIT:
    logger.info("Using periodical responses in await_responses")


def create_mpi_comm_session(
        n_workers: int) -> RemoteMpiCommSessionClient | MpiCommSession:
    '''Create the MPI session used to talk to the worker processes.

    Must be called on MPI rank 0.

    If TLLM_SPAWN_PROXY_PROCESS is "1", returns a RemoteMpiCommSessionClient
    bound to the address in TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR. The client is
    optionally authenticated with the hex-encoded key in
    TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY. Otherwise returns an
    MpiCommSession with `n_workers` workers.

    Args:
        n_workers: Number of workers; only used on the MpiCommSession path.

    Returns:
        A RemoteMpiCommSessionClient or an MpiCommSession, as described above.

    Raises:
        AssertionError: if called on a rank other than 0, or if the proxy
            mode is enabled but TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is unset or
            empty (asserts are skipped under `python -O`).
        ValueError: if TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY is not valid hex
            (raised by bytes.fromhex).
    '''
    rank = mpi_rank()
    assert rank == 0, f"create_mpi_comm_session must be called by rank 0, but it was called by rank {rank}"

    if get_spawn_proxy_process_env():
        # Read the address once so the check, the log line and the client all
        # see the same value.
        addr = get_spawn_proxy_process_ipc_addr_env()
        assert addr, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set."
        print_colored_debug(
            f"Using RemoteMpiCommSessionClient to bind to external MPI processes at {addr}\n",
            "yellow")
        hmac_key = os.getenv("TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY")
        # The key travels through the environment as a hex string.
        if hmac_key is not None:
            hmac_key = bytes.fromhex(hmac_key)
        return RemoteMpiCommSessionClient(addr=addr, hmac_key=hmac_key)

    print_colored_debug(
        "Using MpiCommSession to bind to external MPI processes\n", "yellow")
    return MpiCommSession(n_workers=n_workers)


def has_event_loop() -> bool:
    '''Return True when called from code running inside an asyncio event loop.

    asyncio.get_running_loop() raises RuntimeError when no loop is running in
    the current thread; that case maps to False.
    '''
    try:
        asyncio.get_running_loop()
        return True
    except RuntimeError:
        return False


class RequestError(RuntimeError):
    '''The error raised when a request fails.'''
class ProcessPoolExecutorSession(MpiSession):
    '''MpiSession implemented on top of a local ProcessPoolExecutor.

    This process pool is introduced for better handling of recoverable
    exceptions; it replaces MpiPoolExecutor for the single-GPU case.
    '''

    def __init__(self, n_workers: int, **kwargs):
        '''
        Args:
            n_workers: Pool size, and also how many copies of each task
                submit()/submit_sync() schedule.
            **kwargs: Forwarded verbatim to ProcessPoolExecutor.
        '''
        self.n_workers = n_workers
        self.mpi_pool = ProcessPoolExecutor(max_workers=self.n_workers,
                                            **kwargs)

    def submit(self, task: Callable, *args,
               **kwargs) -> List[concurrent.futures.Future]:
        '''Schedule `task(*args, **kwargs)` n_workers times; return the futures.

        ProcessPoolExecutor does not guarantee that each copy runs in a
        distinct process.
        '''
        return [
            self.mpi_pool.submit(task, *args, **kwargs)
            for _ in range(self.n_workers)
        ]

    def submit_sync(self, task: Callable, *args, **kwargs) -> List[Any]:
        '''Like submit(), but wait for every copy and return their results.

        Results are in submission order. If a copy raised, its exception is
        re-raised by Future.result() (the first failing copy in submission
        order wins).
        '''
        futures = self.submit(task, *args, **kwargs)
        return [future.result() for future in futures]

    def shutdown(self):
        '''Shut down the pool, waiting for pending work to finish.'''
        self.mpi_pool.shutdown(wait=True)


class ErrorResponse(NamedTuple):
    '''An error message tagged with the client id and request id it refers to.'''
    client_id: int
    error_msg: str
    request_id: int


class IntraProcessQueue:
    ''' A Queue-like container for IPC within the same process. '''

    def __init__(self):
        self.queue = Queue()

    def put(self, obj: Any):
        self.queue.put(obj)

    def get(self, timeout=None) -> Any:
        '''Remove and return the oldest item.

        Blocks indefinitely when timeout is None; otherwise raises queue.Empty
        if nothing arrives within `timeout` seconds.
        '''
        return self.queue.get(timeout=timeout)

    def close(self):
        # No-op; presumably kept so this class can stand in for IPC queues
        # that do need closing (NOTE(review): confirm against callers).
        pass

    def poll(self, timeout: Optional[float] = None) -> bool:
        '''Return True once an item is available, without consuming it.

        Args:
            timeout: None waits indefinitely; otherwise wait at most
                `timeout` seconds (0 checks and returns immediately).

        Returns:
            True if the queue is non-empty, False if the timeout expired first.

        Raises:
            ValueError: if timeout is negative (same as Queue.get).
        '''
        if timeout is not None and timeout < 0:
            raise ValueError("'timeout' must be a non-negative number")
        # Wait on the queue's own "not empty" condition instead of
        # get()-then-put(): the old approach moved the head item to the tail
        # (breaking FIFO order for the next get()) and briefly made the queue
        # look empty to other consumers. Queue.put() notifies this condition,
        # and _qsize() is the Queue hook meant to be called with its mutex held.
        not_empty = self.queue.not_empty
        with not_empty:
            return bool(not_empty.wait_for(self.queue._qsize, timeout))


class WorkerCommIpcAddrs(NamedTuple):
    ''' IPC addresses (str) and HMAC keys (bytes) for communication with the worker processes. '''
    request_queue_addr: tuple[str, Optional[bytes]]
    request_error_queue_addr: tuple[str, Optional[bytes]]
    result_queue_addr: tuple[str, Optional[bytes]]
    stats_queue_addr: tuple[str, Optional[bytes]]
    kv_cache_events_queue_addr: tuple[str, Optional[bytes]]