# Source code for tensorrt_llm.executor.utils
import asyncio
import concurrent.futures
import os
from concurrent.futures import ProcessPoolExecutor
from queue import Empty, Queue
from typing import Any, Callable, List, NamedTuple, Optional
from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.bindings.executor import Response
from tensorrt_llm.llmapi.utils import print_colored_debug
from tensorrt_llm.logger import logger
from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession,
                                  RemoteMpiCommSessionClient)
from ..llmapi.utils import print_colored_debug
# Opt-in switch, evaluated once at import time: it is True only when the
# TLLM_EXECUTOR_PERIODICAL_RESP_IN_AWAIT environment variable is exactly "1".
PERIODICAL_RESP_IN_AWAIT = (
    os.environ.get("TLLM_EXECUTOR_PERIODICAL_RESP_IN_AWAIT") == "1")
def get_spawn_proxy_process_ipc_addr_env() -> str | None:
    ''' Get the IPC address for the spawn proxy process dynamically. '''
    return os.getenv("TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR")
def get_spawn_proxy_process_env() -> bool:
    ''' Tell whether the spawn proxy process is enabled through the environment.

    The variable is read on every call instead of being cached.

    Returns:
        True only when TLLM_SPAWN_PROXY_PROCESS is exactly "1"; False for any
        other value or when the variable is unset.
    '''
    flag = os.environ.get("TLLM_SPAWN_PROXY_PROCESS")
    return flag == "1"
# Emitted once, at import time, when PERIODICAL_RESP_IN_AWAIT is enabled.
if PERIODICAL_RESP_IN_AWAIT:
    logger.info("Using periodical responses in await_responses")
def create_mpi_comm_session(
        n_workers: int) -> RemoteMpiCommSessionClient | MpiCommSession:
    ''' Create the MPI session used to reach the MPI worker processes.

    The kind of session is selected by the TLLM_SPAWN_PROXY_PROCESS environment
    variable (see get_spawn_proxy_process_env()).

    Args:
        n_workers: Number of workers, forwarded to MpiCommSession. It is not
            used when a RemoteMpiCommSessionClient is created.

    Returns:
        A RemoteMpiCommSessionClient connected to the address found in
        TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR when the spawn proxy process is
        enabled, otherwise MpiCommSession(n_workers=n_workers).

    Raises:
        AssertionError: If the caller is not MPI rank 0, or if the spawn proxy
            process is enabled but TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is unset
            or empty.
        ValueError: If TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY is set but is not
            a valid hex string (raised by bytes.fromhex).
    '''
    rank = mpi_rank()
    assert rank == 0, f"create_mpi_comm_session must be called by rank 0, but it was called by rank {rank}"

    if not get_spawn_proxy_process_env():
        print_colored_debug(
            "Using MpiCommSession to bind to external MPI processes\n",
            "yellow")
        return MpiCommSession(n_workers=n_workers)

    # Read the address once so the check, the log line and the client all use
    # the same value.
    ipc_addr = get_spawn_proxy_process_ipc_addr_env()
    assert ipc_addr, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set."
    print_colored_debug(
        f"Using RemoteMpiCommSessionClient to bind to external MPI processes at {ipc_addr}\n",
        "yellow")

    # The HMAC key is optional; when present it is passed as a hex string and
    # must be converted to bytes.
    hmac_key = os.getenv("TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY")
    if hmac_key is not None:
        hmac_key = bytes.fromhex(hmac_key)
    return RemoteMpiCommSessionClient(addr=ipc_addr, hmac_key=hmac_key)
def has_event_loop() -> bool:
    ''' Tell whether an asyncio event loop is currently running for the caller.

    Returns:
        True when asyncio.get_running_loop() succeeds, False when it raises
        RuntimeError.
    '''
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() signals "no running loop" with RuntimeError.
        loop_is_running = False
    else:
        loop_is_running = True
    return loop_is_running
# [docs]
class RequestError(RuntimeError):
    ''' The error raised when a request fails. '''
class ProcessPoolExecutorSession(MpiSession):
    ''' MpiSession implementation backed by a ProcessPoolExecutor.

    This process pool is introduced for better recoverable exceptions handling.
    It replaces MpiPoolExecutor for single-gpu case.
    '''

    def __init__(self, n_workers: int, **kwargs):
        ''' Create the process pool.

        Args:
            n_workers: Size of the pool, and the number of times each task is
                submitted by submit() / submit_sync().
            **kwargs: Extra keyword arguments forwarded to ProcessPoolExecutor.
        '''
        self.n_workers = n_workers
        self.mpi_pool = ProcessPoolExecutor(max_workers=n_workers, **kwargs)

    def submit(self, task: Callable, *args,
               **kwargs) -> List[concurrent.futures.Future]:
        ''' Schedule `task(*args, **kwargs)` n_workers times without waiting.

        Returns:
            One future per submission, in submission order.
        '''
        pending = []
        for _ in range(self.n_workers):
            pending.append(self.mpi_pool.submit(task, *args, **kwargs))
        return pending

    def submit_sync(self, task: Callable, *args, **kwargs) -> List[Any]:
        ''' Schedule `task(*args, **kwargs)` n_workers times and wait for all.

        Every submission is queued before the first result is awaited.

        Returns:
            The task results, in submission order.
        '''
        pending = []
        for _ in range(self.n_workers):
            pending.append(self.mpi_pool.submit(task, *args, **kwargs))

        results = []
        for future in pending:
            results.append(future.result())
        return results

    def shutdown(self):
        ''' Shut the pool down, blocking until pending work has finished. '''
        self.mpi_pool.shutdown(wait=True)
class ErrorResponse(NamedTuple):
    ''' Tuple carrying an error message together with a client id and a request id. '''
    client_id: int
    error_msg: str
    request_id: int
class IntraProcessQueue:
    ''' A Queue-like container for IPC within the same process.

    It wraps a standard ``queue.Queue`` (exposed as ``self.queue``) behind a
    put / get / poll / close interface.
    '''

    def __init__(self):
        self.queue = Queue()

    def put(self, obj: Any):
        ''' Append `obj` to the queue. '''
        self.queue.put(obj)

    def get(self, timeout=None) -> Any:
        ''' Remove and return the oldest item.

        Args:
            timeout: Seconds to wait for an item; None waits indefinitely.

        Raises:
            queue.Empty: If no item became available within `timeout`.
        '''
        return self.queue.get(timeout=timeout)

    def close(self):
        ''' No-op; nothing is released here. '''

    def poll(self, timeout=None) -> bool:
        ''' Wait until the queue holds at least one item, without consuming it.

        The previous implementation removed the head item and re-appended it,
        which moved it behind every other queued item and briefly hid it from
        concurrent consumers. This version only inspects the queue, so the
        FIFO order is left untouched.

        Args:
            timeout: Seconds to wait for an item. None waits indefinitely and
                0 checks once without waiting.

        Returns:
            True if an item is available, False if `timeout` elapsed first.

        Raises:
            ValueError: If `timeout` is negative (same as Queue.get).
        '''
        if timeout is not None and timeout < 0:
            raise ValueError("'timeout' must be a non-negative number")

        backing_queue = self.queue
        # `not_empty` is the condition Queue.put() notifies after adding an
        # item; waiting on it allows blocking for data without a get()/put()
        # round trip. `_qsize()` must be called with the queue mutex held,
        # which the `with` block guarantees.
        with backing_queue.not_empty:
            return bool(
                backing_queue.not_empty.wait_for(
                    lambda: backing_queue._qsize() > 0, timeout))
class WorkerCommIpcAddrs(NamedTuple):
    ''' IPC addresses (str) and HMAC keys (bytes) for communication with the worker processes.

    Every field is an ``(address, hmac_key)`` pair whose ``hmac_key`` may be
    None.
    '''
    # One (address, optional HMAC key) pair per queue; each field is named
    # after the queue it addresses.
    request_queue_addr: tuple[str, Optional[bytes]]
    request_error_queue_addr: tuple[str, Optional[bytes]]
    result_queue_addr: tuple[str, Optional[bytes]]
    stats_queue_addr: tuple[str, Optional[bytes]]
    kv_cache_events_queue_addr: tuple[str, Optional[bytes]]
def is_llm_response(instance):
    ''' Tell whether `instance` is one of the known LLM response types.

    Returns:
        True when `instance` is a Response, an LlmResponse from
        tensorrt_llm._torch.pyexecutor.llm_request, or a ResponseWrapper.
    '''
    # NOTE(review): these imports are kept function-local as in the original,
    # presumably to avoid an import cycle or an eager heavy import -- confirm.
    from tensorrt_llm._torch.pyexecutor.llm_request import \
        LlmResponse as PyLlmResponse

    from .result import ResponseWrapper

    response_types = (Response, PyLlmResponse, ResponseWrapper)
    return isinstance(instance, response_types)