import abc
import itertools
import os
import socket
import sys
import threading
import time
import traceback
from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TypeVar
import zmq
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm.logger import logger
from .._utils import global_mpi_rank, mpi_barrier, mpi_rank
from .utils import logger_debug, print_colored
if ENABLE_MULTI_DEVICE:
import mpi4py
from mpi4py.futures import MPICommExecutor, MPIPoolExecutor
from tensorrt_llm._utils import global_mpi_size, mpi_world_size
T = TypeVar("T")
class MPINodeState:
    ''' MPINodeState acts as central global state shared between tasks on an MPI node.
An example:
def task():
if MPINodeState.state is None:
MPINodeState.state = 0
MPINodeState.state += 1
return MPINodeState.state
n_workers = 4
with MPIPoolExecutor(max_workers=n_workers) as executor:
for i in range(2):
futures = [executor.submit(task) for i in range(n_workers)]
This should produce the following output:
- [1, 1, 1, 1]
- [2, 2, 2, 2]
'''
state = None
# Global MPICommExecutor instance to be reused across multiple MpiCommSession instances
# This is necessary because MPICommExecutor can only be created once per MPI process
_global_comm_executor = None
_global_mpi_pool = None
@staticmethod
def is_initialized() -> bool:
return MPINodeState.state is not None
def external_mpi_comm_available(model_world_size: int) -> bool:
    ''' Check whether the current process was launched by mpirun, so an external
    MPI communicator is already available and MPIPoolExecutor is not needed to
    spawn processes.
    e.g. mpirun -np 4 python script.py
    '''
if ENABLE_MULTI_DEVICE:
return (get_mpi_world_size() == model_world_size
and model_world_size > 1) or (global_mpi_size()
> get_mpi_world_size())
else:
return False
def need_spawn_mpi_workers(model_world_size: int) -> bool:
''' Check if the current process needs to spawn MPI workers. '''
if ENABLE_MULTI_DEVICE:
return get_mpi_world_size() == 1 and model_world_size > 1
else:
return False
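
# Hedged illustration (not called anywhere in this module, and not the library's
# actual selection logic): how the two predicates above can be combined to pick
# a session class. The classes referenced below are defined later in this file
# and are resolved at call time.
def _example_select_mpi_session(model_world_size: int) -> "MpiSession":
    if external_mpi_comm_available(model_world_size):
        # Launched via mpirun with matching ranks: reuse the external communicator.
        return MpiCommSession(n_workers=model_world_size)
    if need_spawn_mpi_workers(model_world_size):
        # Launched as a single rank but the model needs more: spawn workers.
        return MpiPoolSession(n_workers=model_world_size)
    # Single-device case: a one-worker pool stands in as a trivial session.
    return MpiPoolSession(n_workers=1)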
def set_mpi_session_cpp(comm):
if ENABLE_MULTI_DEVICE:
comm_fortran = comm.py2f()
from tensorrt_llm.bindings import MpiComm
MpiComm.set_raw_mpi_session_by_fortran_handle(comm_fortran)
class MpiSession(abc.ABC):
@abc.abstractmethod
def submit(self, task: Callable[..., T], *args,
**kwargs) -> List[Future[T]]:
raise NotImplementedError()
@abc.abstractmethod
def submit_sync(self, task: Callable[..., T], *args, **kwargs) -> List[T]:
raise NotImplementedError()
@abc.abstractmethod
def shutdown(self, wait=True):
raise NotImplementedError()
@abc.abstractmethod
def abort(self):
raise NotImplementedError()
def is_comm_session(self) -> bool:
return isinstance(self, (MpiCommSession, RemoteMpiCommSessionClient))
def _abort_on_timeout(self, fut: Future, timeout: float, reason=None):
try:
fut.result(timeout=timeout)
        except FuturesTimeoutError:
logger.critical("MpiSession shutdown timeout, aborting...")
if reason is not None:
logger.info(f"Reason to shutdown: {repr(reason)}")
self.abort()
def shutdown_abort(self, grace: float = 60, reason=None):
if sys.is_finalizing():
# cannot start thread at interpreter shutdown
# simply don't wait to avoid hang
return self.shutdown(wait=False)
fut = Future()
killer = threading.Thread(group=None,
target=self._abort_on_timeout,
name="MpiSessionTimeoutKiller",
args=(fut, grace, reason))
killer.start()
self.shutdown()
fut.set_result(None)
killer.join()
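
# Hedged sketch of the watchdog idiom used by shutdown_abort() above, extracted
# into a standalone helper. `slow_shutdown` and `hard_abort` are hypothetical
# callables supplied by the caller, not part of this module.
def _example_shutdown_watchdog(slow_shutdown: Callable[[], None],
                               hard_abort: Callable[[], None],
                               grace: float = 60.0) -> None:
    done: Future = Future()

    def _abort_on_timeout():
        try:
            done.result(timeout=grace)  # resolved once shutdown completes
        except FuturesTimeoutError:
            hard_abort()  # shutdown hung for longer than `grace` seconds

    watchdog = threading.Thread(target=_abort_on_timeout,
                                name="ExampleShutdownWatchdog")
    watchdog.start()
    slow_shutdown()        # the potentially hanging operation
    done.set_result(None)  # unblock the watchdog so it exits cleanly
    watchdog.join()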
class MpiPoolSession(MpiSession):
def __init__(self, n_workers: int):
self.n_workers = n_workers
self.mpi_pool: Optional[MPIPoolExecutor] = None
self._start_mpi_pool()
if ENABLE_MULTI_DEVICE:
self.comm = mpi4py.MPI.COMM_WORLD
def get_comm(self):
return self.comm
def submit(self, task: Callable[..., T], *args,
**kwargs) -> List[Future[T]]:
return [
self.mpi_pool.submit(task, *args, **kwargs)
for i in range(self.n_workers)
]
def submit_sync(self, task: Callable[..., T], *args, **kwargs) -> List[T]:
futures = [
self.mpi_pool.submit(task, *args, **kwargs)
for i in range(self.n_workers)
]
return [future.result() for future in futures]
def shutdown(self, wait=True):
if self.mpi_pool is not None:
self.mpi_pool.shutdown(wait=wait)
self.mpi_pool = None
def abort(self):
self.get_comm().Abort(1)
def _start_mpi_pool(self):
assert not self.mpi_pool, 'MPI session already started'
self.mpi_pool = MPIPoolExecutor(max_workers=self.n_workers,
path=sys.path)
def __del__(self):
self.shutdown_abort()
def __reduce__(self):
raise TypeError('cannot pickle MPI session')
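
# Hedged usage sketch for MpiPoolSession: the submitted task must be picklable
# (a module-level function or a builtin), since it is shipped to spawned ranks.
def _example_mpi_pool_usage() -> List[int]:
    session = MpiPoolSession(n_workers=4)
    try:
        # Every worker evaluates the same call; submit_sync gathers one result per rank.
        results = session.submit_sync(sum, [1, 2, 3])  # -> [6, 6, 6, 6]
    finally:
        session.shutdown()
    return results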
class MpiCommSession(MpiSession):
def __init__(self, comm=None, n_workers: int = 1):
self.comm = comm
self.n_workers = n_workers
self.thread_pool: Optional[ThreadPoolExecutor] = None
self.mpi_pool: Optional[MPIPoolExecutor] = None
self.owns_mpi_pool = False # Track if this instance owns the mpi_pool
if n_workers <= 0:
raise ValueError(
                f'n_workers must be positive, but got {n_workers}')
if ENABLE_MULTI_DEVICE:
if not self.comm:
self.comm = mpi4py.MPI.COMM_WORLD
if self.comm.Get_rank() != 0:
raise RuntimeError(
f'only rank 0 can start multi-node session, got {self.comm.Get_rank()}'
)
if self.comm.Get_size() != n_workers:
raise ValueError(
                    f'n_workers must equal the size of the MPI communicator, got {n_workers} vs {self.comm.Get_size()}'
)
self._start_mpi_pool()
def get_comm(self):
return self.comm
def submit(self, task: Callable[..., T], *args,
**kwargs) -> List[Future[T]]:
        ''' Submit a task to MPI workers.
        Args:
            task: The task to be submitted.
            args: Positional arguments for the task.
            kwargs: Keyword arguments for the task.
        Returns:
            One future per worker; the rank-0 task runs in a local thread while
            the remaining ranks run through the MPI pool.
        '''
assert self.mpi_pool is not None, 'MPI session not started'
worker_futures = [
self.mpi_pool.submit(task, *args, **kwargs)
for i in range(self.n_workers - 1)
]
rank0_future = self.thread_pool.submit(task, *args, **kwargs)
return [rank0_future] + worker_futures
def submit_sync(self, task: Callable[..., T], *args, **kwargs) -> List[T]:
futures = self.submit(task, *args, **kwargs)
return [future.result() for future in futures]
def shutdown(self, wait=True):
# Only shutdown the mpi_pool if this instance created it
# For shared global mpi_pool, we don't shut it down
if self.mpi_pool is not None and self.owns_mpi_pool:
self.mpi_pool.shutdown(wait=wait)
self.mpi_pool = None
if self.thread_pool is not None:
self.thread_pool.shutdown(wait=wait)
self.thread_pool = None
def abort(self):
self.get_comm().Abort(1)
def _start_mpi_pool(self):
assert not self.mpi_pool, 'MPI session already started'
self.thread_pool = ThreadPoolExecutor(max_workers=2)
# Use global MPICommExecutor if using COMM_WORLD
# This is necessary because MPICommExecutor can only be created once per MPI process
logger_debug(
f"_start_mpi_pool: ENABLE_MULTI_DEVICE={ENABLE_MULTI_DEVICE}, self.comm={self.comm}\n",
"grey")
if ENABLE_MULTI_DEVICE:
logger_debug(
f"_start_mpi_pool: Checking if self.comm == mpi4py.MPI.COMM_WORLD: {self.comm == mpi4py.MPI.COMM_WORLD}\n",
"grey")
if ENABLE_MULTI_DEVICE and self.comm == mpi4py.MPI.COMM_WORLD:
if MPINodeState._global_comm_executor is None:
logger_debug("Creating global MPICommExecutor for COMM_WORLD\n",
"yellow")
MPINodeState._global_comm_executor = MPICommExecutor(self.comm)
MPINodeState._global_mpi_pool = MPINodeState._global_comm_executor.__enter__(
)
else:
logger_debug("Reusing global MPICommExecutor for COMM_WORLD\n",
"yellow")
self.mpi_pool = MPINodeState._global_mpi_pool
self.owns_mpi_pool = False
else:
logger_debug(
f"_start_mpi_pool: Creating new MPICommExecutor (not COMM_WORLD or ENABLE_MULTI_DEVICE=False)\n",
"grey")
# For non-COMM_WORLD communicators, create a new executor
comm_executor = MPICommExecutor(self.comm)
self.mpi_pool = comm_executor.__enter__()
self.owns_mpi_pool = True
def __del__(self):
self.shutdown_abort()
def __reduce__(self):
raise TypeError('cannot pickle MPI session')
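
# Hedged usage sketch for MpiCommSession: intended for jobs launched as
# `mpirun -np <N> python script.py`; only rank 0 may construct the session.
def _example_mpi_comm_usage(model_world_size: int) -> Optional[List[int]]:
    if mpi_rank() != 0:
        return None  # non-zero ranks are driven through the shared MPICommExecutor
    session = MpiCommSession(n_workers=model_world_size)
    try:
        # Rank 0 runs the task in a local thread; ranks 1..N-1 run it via the executor.
        return session.submit_sync(os.getpid)
    finally:
        session.shutdown()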
class RemoteTask(NamedTuple):
task: Callable[..., T]
args: Tuple[Any, ...]
kwargs: Dict[str, Any]
sync: bool = False # if True, the result will be sent back to the client
class RemoteMpiCommSessionClient(MpiSession):
'''
    RemoteMpiCommSessionClient is an MpiSession variant that forwards tasks to a remote MPI pool served by a RemoteMpiCommSessionServer.
Note: This class uses a global singleton pattern because ZeroMQ PAIR sockets only support
one connection at a time. Multiple LLM instances will reuse the same client connection.
'''
_global_instance = None
_global_instance_lock = threading.Lock()
def __new__(cls, addr: str, hmac_key: Optional[bytes] = None):
# Implement singleton pattern to reuse the same client connection
# for multiple LLM instances, since PAIR sockets only support one connection
with cls._global_instance_lock:
if cls._global_instance is None or cls._global_instance.addr != addr:
logger_debug(
f"Creating new global RemoteMpiCommSessionClient for {addr}\n",
"yellow")
instance = super().__new__(cls)
cls._global_instance = instance
instance._initialized = False
else:
logger_debug(
f"Reusing existing global RemoteMpiCommSessionClient for {addr}\n",
"yellow")
return cls._global_instance
def __init__(self, addr: str, hmac_key: Optional[bytes] = None):
# Only initialize once
if self._initialized:
return
# FIXME: this is a hack to avoid circular import, resolve later
from tensorrt_llm.executor.ipc import ZeroMqQueue
self.addr = addr
logger_debug(f"RemoteMpiCommSessionClient connecting to {addr}\n",
"yellow")
self.queue = ZeroMqQueue((addr, hmac_key),
is_server=False,
socket_type=zmq.PAIR,
use_hmac_encryption=bool(hmac_key))
self._is_shutdown = False
self._initialized = True
def submit(self,
task: Callable[..., T],
*args,
sync: bool = False,
**kwargs) -> list:
        ''' Submit a task to the remote MPI pool.
        Always returns an empty list; results of synchronous tasks are retrieved
        via submit_sync / poll rather than per-task futures.
        '''
if self._is_shutdown:
logger_debug("RemoteMpiCommSessionClient is already shut down\n",
"yellow")
return []
logger_debug(
f"RemoteMpiCommSessionClient [rank{global_mpi_rank()}] sending task {task} to {self.addr}\n",
"yellow")
self.queue.put(RemoteTask(task, args, kwargs, sync=sync))
return []
    SYNC_IDLE_INTERVAL = 8  # seconds to sleep between polls while waiting for a synchronous result
def submit_sync(self, task, *args, **kwargs) -> List[T]:
''' Submit a task to the remote MPI pool and wait for task completion. '''
self.submit(task, *args, sync=True, **kwargs)
while not ((res := self.poll()) or self._is_shutdown):
logger_debug(f"Waiting for task completion... {res}\n", "grey")
time.sleep(self.SYNC_IDLE_INTERVAL)
logger_debug(
f"rank{global_mpi_rank()} RemoteMpiCommSessionClient.send_sync received results: {res}\n",
"green")
if not res:
raise RuntimeError(
"RemoteMpiCommSessionClient received unexpected response")
return res
    def poll(self) -> Any:
        ''' Poll the queue for a response.
        Returns:
            The response payload (e.g. the list of per-rank results for a
            synchronous task) if one is available, False otherwise.
        '''
if self._is_shutdown:
return False
response = self.queue.poll(0.1)
if response:
            return self.queue.get()  # the payload sent back by the server
return False
def abort(self):
self.shutdown()
def shutdown(self, wait=True):
# NOTE: We do NOT close the queue or mark as shutdown for the singleton instance.
# The RemoteMpiCommSessionClient is a global singleton that's reused across multiple
# LLM instances. Marking it as shutdown would prevent subsequent LLM instances from
# using it. The connection stays open for the entire lifetime of the mgmn setup.
logger_debug(
f"RemoteMpiCommSessionClient.shutdown() called (no-op for singleton)\n",
"grey")
def shutdown_abort(self, grace: float = 60, reason=None):
self.shutdown()
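
# Hedged sketch of the client side: connect to a running
# RemoteMpiCommSessionServer (the address and HMAC key are assumed to come from
# the deployment) and run one synchronous task across the remote ranks.
def _example_remote_client(addr: str,
                           hmac_key: Optional[bytes] = None) -> List[int]:
    client = RemoteMpiCommSessionClient(addr, hmac_key=hmac_key)
    # submit_sync blocks until the server has gathered one result per remote rank.
    return client.submit_sync(os.getpid)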
class RemoteMpiCommSessionServer():
'''
    RemoteMpiCommSessionServer hosts an MPI session (MpiCommSession or MpiPoolSession) and serves tasks submitted by a RemoteMpiCommSessionClient over ZeroMQ.
'''
def __init__(self,
n_workers: int = 0,
                 addr: str = 'tcp://127.0.0.1:*',
hmac_key: Optional[bytes] = None,
comm=None,
is_comm: bool = False):
# FIXME: this is a hack to avoid circular import, resolve later
from tensorrt_llm.executor.ipc import ZeroMqQueue
self.addr = addr
self.queue = ZeroMqQueue((addr, hmac_key),
is_server=True,
socket_type=zmq.PAIR,
use_hmac_encryption=bool(hmac_key))
self.comm = comm
self.results = [] # the results may arrive in any order
if self.comm is not None:
self.session = MpiCommSession(n_workers=self.comm.Get_size(),
comm=self.comm)
else:
self.session = MpiCommSession(
n_workers=n_workers) if is_comm else MpiPoolSession(
n_workers=n_workers)
@staticmethod
def task_wrapper(task: Callable[..., T], *args, **kwargs) -> T:
logger_debug(
f"MpiCommSession rank{mpi_rank()} with world_size {mpi_world_size()}\n",
"green")
logger_debug(
f"MpiCommSession rank{mpi_rank()} start task [{task}] with args: {args} and kwargs: {kwargs}\n",
"green")
# wait for all ranks to start the task
mpi_barrier()
try:
return task(*args, **kwargs)
except Exception as e:
print_colored(
f"MpiCommSession rank{mpi_rank()} task [{task}] failed with exception: {e}\n",
"red")
traceback.print_exc()
raise e
finally:
logger_debug(
f"MpiCommSession rank{mpi_rank()} task [{task}] finished\n",
"green")
mpi_barrier()
def serve(self):
logger_debug(f"RemoteMpiCommSessionServer listening on {self.addr}\n",
"yellow")
pending_futures = []
while True:
# Wait for any pending futures from previous tasks to complete
# This ensures all ranks are ready before accepting the next task
if pending_futures:
logger_debug(
f"RemoteMpiCommSessionServer waiting for {len(pending_futures)} pending futures to complete\n",
"grey")
for future in pending_futures:
try:
future.result() # Wait for completion
except Exception as e:
print_colored(
f"RemoteMpiCommSessionServer future failed with exception: {e}\n",
"red")
pending_futures.clear()
logger_debug(
"RemoteMpiCommSessionServer all pending futures completed\n",
"grey")
message: Optional[RemoteTask] = self.queue.get()
if message is None:
logger_debug(
f"RemoteMpiCommSessionServer [rank{global_mpi_rank()}] received shutdown signal\n",
"green")
self.session.shutdown_abort()
break
else:
logger_debug(
f"RemoteMpiCommSessionServer [rank{global_mpi_rank()}] received task [{message.task}] from {self.addr}\n",
"green")
futures = self.session.submit(
RemoteMpiCommSessionServer.task_wrapper, message.task,
*message.args, **message.kwargs)
self.num_results = self.session.n_workers
assert len(futures) == self.num_results == mpi_world_size()
# Store futures to wait for them before the next task
pending_futures = list(futures)
if message.sync:
for future in futures:
future.add_done_callback(self.mpi_future_callback)
def mpi_future_callback(self, future):
logger_debug(f"rank{global_mpi_rank()} got future: {future}\n", "red")
if future.exception() is not None:
logger_debug(
f"mpi_future got exception: {future.exception()}, quitting\n",
"red")
self.queue.put(future.exception())
return
result = future.result()
self.results.append(result)
logger_debug(
f"RemoteMpiCommSessionServer working status: {len(self.results)}/{self.num_results}\n",
"grey")
if len(self.results) == self.num_results:
logger_debug(
f"RemoteMpiCommSessionServer received all results, sending to client\n",
"green")
try:
self.queue.put_noblock(self.results, retry=2)
except zmq.ZMQError as e:
# The client could be shutdown first.
if e.errno == zmq.EAGAIN:
pass
else:
raise e
logger_debug(f"RemoteMpiCommSessionServer sent results to client\n",
"green")
self.results.clear()
def find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
return s.getsockname()[1]
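
# Hedged sketch of the server side: bind a RemoteMpiCommSessionServer to a free
# local port and serve tasks until a client sends the shutdown signal (None).
def _example_remote_server(n_workers: int) -> None:
    addr = f"tcp://127.0.0.1:{find_free_port()}"
    server = RemoteMpiCommSessionServer(n_workers=n_workers,
                                        addr=addr,
                                        is_comm=True)
    print_colored(f"Remote MPI session server listening on {addr}\n", "yellow")
    server.serve()  # blocks; returns after the shutdown message is received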
def get_mpi_world_size() -> int:
# avoid cyclic import
from ..executor.utils import get_spawn_proxy_process_env
    # If a proxy process is spawned, the MPI-related environment variables are
    # cleaned inside it, so a dedicated variable (tllm_mpi_size) carries the MPI
    # world size instead.
if get_spawn_proxy_process_env():
return int(os.getenv("tllm_mpi_size") or 1)
else:
return mpi_world_size()
def split_mpi_env(mpi_env_keys: List[str] | None = None) -> Tuple[dict, dict]:
'''
Splits the environment variables into MPI-related and non-MPI-related dictionaries.
Args:
mpi_env_keys: Additional environment variables to be considered as MPI-related.
Returns:
Tuple[dict, dict]: (non_mpi_env, mpi_env)
- non_mpi_env: Environment dictionary without MPI-related variables
- mpi_env: Environment dictionary containing only MPI-related variables
'''
current_env = os.environ.copy()
# Identify MPI-related variables
    mpi_vars = set(
        itertools.chain([
            var for var in current_env if var.startswith((
                'MPI_',
                'OMPI_',
                'PMIX_',
                'PMI_',
                'SLURM_',
                'UCX_',
                'I_MPI_',
                'HYDRA_',
                'KMP_',
                'MPICH_',
                'MV2_',
                'CRAY_',
            ))
        ], mpi_env_keys or []))
# Split into two dictionaries
non_mpi_env = {k: v for k, v in current_env.items() if k not in mpi_vars}
mpi_env = {k: v for k, v in current_env.items() if k in mpi_vars}
return non_mpi_env, mpi_env
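
# Hedged sketch for split_mpi_env: launch a helper process with an MPI-free
# environment so it does not inherit the launcher's MPI/SLURM context.
def _example_spawn_without_mpi_env(cmd: List[str]) -> int:
    import subprocess
    non_mpi_env, mpi_env = split_mpi_env()
    logger_debug(f"Stripping {len(mpi_env)} MPI-related variables\n", "grey")
    return subprocess.run(cmd, env=non_mpi_env).returncode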