Source code for tensorrt_llm.disaggregated_params

from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np

# isort: off
# needed before trying to import bindings to load tensorrt_libs
import tensorrt as trt  # noqa
# isort: on

from tensorrt_llm.bindings import executor as tllme


@dataclass(slots=True, kw_only=True)
class DisaggregatedParams:
    """Disaggregated serving parameters.

    Args:
        request_type (str): The type of request ("context_only" | "generation_only" | "context_and_generation")
        first_gen_tokens (List[int]): The first tokens of the generation request
        ctx_request_id (int): The context request id
        opaque_state (bytes): Any additional state needing to be exchanged between context and gen instances
        draft_tokens (List[int]): The draft tokens of the generation request
        multimodal_embedding_handles (List[Dict[str, Any]]): The resulting multimodal embedding handles from ViT.
        multimodal_hashes (List[List[int]]): The multimodal hashes of each multimodal item in the request.
    """

    request_type: Optional[str] = None

    # P-D Disaggregated Params
    first_gen_tokens: Optional[List[int]] = None
    ctx_request_id: Optional[int] = None
    opaque_state: Optional[bytes] = None
    draft_tokens: Optional[List[int]] = None

    # E-P Disaggregated Params
    # Multimodal embedding handles should be a list of CUDA IPC handles,
    # one per mm_embedding.
    multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = None
    # User-provided multimodal hashes should be a list of 8 integers per item.
    multimodal_hashes: Optional[List[List[int]]] = None

    def get_context_phase_params(self) -> tllme.ContextPhaseParams:
        return tllme.ContextPhaseParams(self.first_gen_tokens,
                                        self.ctx_request_id, self.opaque_state,
                                        self.draft_tokens)

    def get_request_type(self) -> tllme.RequestType:
        if self.request_type == "context_only":
            return tllme.RequestType.REQUEST_TYPE_CONTEXT_ONLY
        elif self.request_type == "generation_only":
            return tllme.RequestType.REQUEST_TYPE_GENERATION_ONLY
        elif self.request_type == "context_and_generation":
            return tllme.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
        else:
            raise ValueError(
                f"Unknown request type: {self.request_type}. Must be "
                "context_only, generation_only or context_and_generation")

    def __post_init__(self):
        if self.request_type is not None:
            self.request_type = self.request_type.lower()
            if self.request_type not in [
                    "context_only", "generation_only", "context_and_generation"
            ]:
                raise ValueError(
                    f"Unknown request type: {self.request_type}. Must be "
                    "context_only, generation_only or context_and_generation")

        if self.multimodal_embedding_handles is not None:
            if self.multimodal_hashes is not None:
                # If multimodal hashes are provided, KV cache reuse can be enabled.
                assert len(self.multimodal_embedding_handles) == len(
                    self.multimodal_hashes
                ), "multimodal_embedding_handles and multimodal_hashes must have the same length"
                for mm_hash in self.multimodal_hashes:
                    assert isinstance(mm_hash, list), "mm_hash must be a list"
                    assert len(
                        mm_hash) == 8, "mm_hash must be a list of 8 integers"
                    assert all(isinstance(x, int)
                               for x in mm_hash), "mm_hash must contain integers"
            else:
                # If the user did not provide multimodal hashes, KV cache reuse
                # will be disabled; fill in random placeholder hashes.
                assert len(self.multimodal_embedding_handles) > 0, \
                    "multimodal_embedding_handles must be provided"
                vals = np.random.randint(np.iinfo(np.int32).min,
                                         np.iinfo(np.int32).max,
                                         size=8,
                                         dtype=np.int32).tolist()
                self.multimodal_hashes = [vals] * len(
                    self.multimodal_embedding_handles)
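

# Illustrative sketch, not part of the original module: how the P-D
# (prefill-decode) fields are typically round-tripped between a context
# instance and a generation instance. The literal values below are
# placeholders standing in for what a context instance would return
# alongside its response.
def _demo_prefill_decode_params() -> None:
    # Context phase: the context instance runs prefill only.
    # __post_init__ lower-cases and validates the request type.
    ctx_params = DisaggregatedParams(request_type="CONTEXT_ONLY")
    assert ctx_params.request_type == "context_only"
    assert ctx_params.get_request_type() == \
        tllme.RequestType.REQUEST_TYPE_CONTEXT_ONLY

    # Generation phase: rebuilt from the context instance's results.
    gen_params = DisaggregatedParams(
        request_type="generation_only",
        first_gen_tokens=[42],  # placeholder: first token from the context phase
        ctx_request_id=0,  # placeholder: id assigned by the context instance
        opaque_state=b"",  # placeholder: state exchanged between instances
    )
    # Bundle the context-phase results for the executor bindings.
    _ = gen_params.get_context_phase_params()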
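

# Another illustrative sketch, not part of the original module: the E-P
# (encoder-prefill) validation performed in __post_init__. The handle dict
# below is a dummy placeholder; real entries would be CUDA IPC handles,
# one per multimodal embedding.
def _demo_multimodal_params() -> None:
    dummy_handle: Dict[str, Any] = {"ipc_handle": b"\x00" * 64}  # placeholder

    # With user-provided hashes (each a list of 8 ints), KV cache reuse can
    # be enabled; __post_init__ checks the lengths and element types.
    params = DisaggregatedParams(
        request_type="context_only",
        multimodal_embedding_handles=[dummy_handle],
        multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8]],
    )
    assert params.multimodal_hashes is not None

    # Without hashes, __post_init__ fills in random placeholder hashes
    # (one list of 8 int32 values) and KV cache reuse stays disabled.
    params = DisaggregatedParams(
        request_type="context_only",
        multimodal_embedding_handles=[dummy_handle],
    )
    assert len(params.multimodal_hashes) == 1
    assert len(params.multimodal_hashes[0]) == 8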