Source code for tensorrt_llm.disaggregated_params
from dataclasses import dataclass
from typing import List, Optional
from tensorrt_llm.bindings import executor as tllme
[docs]
@dataclass(slots=True, kw_only=True)
class DisaggregatedParams:
"""
Disaggregated seving parameters
Args:
request_type (str): The type of request ("context_only" or "generation_only")
first_gen_tokens (List[int]): The first tokens of the generation request
ctx_request_id (int): The context request id
opaque_state(bytes): Any additional state needing to be exchanged between context and gen instances
"""
request_type: Optional[str] = None
first_gen_tokens: Optional[List[int]] = None
ctx_request_id: Optional[int] = None
opaque_state: Optional[bytes] = None
draft_tokens: Optional[List[int]] = None
[docs]
def get_context_phase_params(self) -> tllme.ContextPhaseParams:
return tllme.ContextPhaseParams(self.first_gen_tokens,
self.ctx_request_id, self.opaque_state,
self.draft_tokens)
[docs]
def get_request_type(self) -> tllme.RequestType:
if self.request_type == "context_only":
return tllme.RequestType.REQUEST_TYPE_CONTEXT_ONLY
elif self.request_type == "generation_only":
return tllme.RequestType.REQUEST_TYPE_GENERATION_ONLY
elif self.request_type == "context_and_generation":
return tllme.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
else:
raise ValueError(
f"Unknown request type: {self.request_type}. Must be context_only, generation_only or context_and_generation"
)