Source code for tensorrt_llm.executor.result

import asyncio
import json
import weakref
from dataclasses import dataclass, field
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Union
from weakref import WeakMethod

import torch

from .._utils import nvtx_range_debug
from ..bindings import executor as tllm
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import AsyncQueue
from ..sampling_params import SamplingParams
from .utils import ErrorResponse, has_event_loop

if TYPE_CHECKING:
    from .executor import GenerationExecutor
    from .postproc_worker import PostprocParams, PostprocWorker
    from .request import GenerationRequest

__all__ = [
    "CompletionOutput",
    "GenerationResultBase",
    "DetokenizedGenerationResultBase",
    "GenerationResult",
    "IterationResult",
]


@dataclass(slots=True)
class CompletionOutput:
    """The output data of one completion of a request.

    Args:
        index (int): The index of the output in the request.
        text (str): The generated output text. Defaults to "".
        token_ids (List[int], optional): The token ids of the generated output text. Defaults to None.
        cumulative_logprob (float, optional): The cumulative log probability of the generated output text. Defaults to None.
        logprobs (List[float], optional): The log probabilities of the top probability words at each position if the logprobs are requested. Defaults to None.
        finish_reason (Literal['stop', 'length', 'timeout', 'cancelled'], optional): The reason why the sequence is finished. Defaults to None.
        stop_reason (int, str, optional): The stop string or token id that caused the completion to stop, None if the completion finished for some other reason. Defaults to None.
        generation_logits (torch.Tensor, optional): The logits on the generated output token ids. Defaults to None.
        disaggregated_params (tensorrt_llm.disaggregated_params.DisaggregatedParams, optional): Parameters needed for disaggregated serving. Includes the type of request, the first generated tokens, the context request id, and any additional state that needs to be transferred between context and generation instances. Defaults to None.

    Attributes:
        length (int): The number of generated tokens.
        token_ids_diff (List[int]): Newly generated token ids.
        logprobs_diff (List[float]): Logprobs of newly generated tokens.
        text_diff (str): Newly generated text.
    """
    index: int
    text: str = ""
    token_ids: Optional[List[int]] = None
    cumulative_logprob: Optional[float] = None
    logprobs: Optional[List[float]] = None
    finish_reason: Optional[Literal['stop', 'length', 'timeout',
                                    'cancelled']] = None
    stop_reason: Optional[Union[int, str]] = None
    generation_logits: Optional[torch.Tensor] = None
    disaggregated_params: Optional[DisaggregatedParams] = None

    # Hidden fields for tracking the diffs
    _last_text_len: int = field(default=0, init=False, repr=False)
    _last_token_ids_len: int = field(default=0, init=False, repr=False)
    _last_logprobs_len: int = field(default=0, init=False, repr=False)
    _incremental_states: Optional[dict] = field(default=None,
                                                init=False,
                                                repr=False)
    # The result of result_handler passed to postprocess workers
    _postprocess_result: Any = None

    def __post_init__(self):
        if self.token_ids is None:
            self.token_ids = []
        if self.logprobs is None:
            self.logprobs = []

    @property
    def length(self) -> int:
        return len(self.token_ids)

    @property
    def text_diff(self) -> str:
        return self.text[self._last_text_len:]

    @property
    def token_ids_diff(self) -> List[int]:
        return self.token_ids[self._last_token_ids_len:]

    @property
    def logprobs_diff(self) -> List[float]:
        return self.logprobs[self._last_logprobs_len:]
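
# Illustrative sketch (not part of the original module): the ``*_diff``
# properties let streaming consumers read only the portion produced since the
# previous iteration. The token ids and text below are made-up placeholders.
def _completion_output_diff_example() -> None:
    out = CompletionOutput(index=0)
    # Iteration 1: record the previous lengths, then append new tokens/text,
    # mirroring what the result-handling code above does.
    out._last_token_ids_len = len(out.token_ids)
    out._last_text_len = len(out.text)
    out.token_ids.extend([101, 102])
    out.text += "Hello"
    assert out.token_ids_diff == [101, 102]
    assert out.text_diff == "Hello"
    # Iteration 2: only the newly appended token and text show up in the diffs.
    out._last_token_ids_len = len(out.token_ids)
    out._last_text_len = len(out.text)
    out.token_ids.append(103)
    out.text += " world"
    assert out.token_ids_diff == [103]
    assert out.text_diff == " world"
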
class GenerationResultBase:
    ''' This holds the core logic of the GenerationResult class. '''

    def __init__(self,
                 id: int,
                 sampling_params: SamplingParams,
                 background_error_handler: Optional[Callable] = None,
                 postproc_params: "Optional[PostprocParams]" = None):
        self.id = id
        self.sampling_params = sampling_params
        self.postproc_params = postproc_params
        self.disaggregated_params = None
        self.decoding_iter = 0
        self._done = False

        if has_event_loop():
            self.aqueue = AsyncQueue()
            self.queue = self.aqueue.sync_q
        else:
            self.queue = Queue()
            self.aqueue = None

        # In sampling mode, the Executor runtime returns best_of sequences in
        # total, from which the LLM API selects the n best based on their
        # cumulative log probabilities.
        self._outputs: List[CompletionOutput] = [
            CompletionOutput(i) for i in range(self.sampling_params.best_of)
        ]
        self._context_logits: Optional[torch.Tensor] = None

        self._background_error_handler = None
        if background_error_handler is not None:
            if not isinstance(background_error_handler, WeakMethod):
                self._background_error_handler = WeakMethod(
                    background_error_handler)
            else:
                self._background_error_handler = background_error_handler

        # This is used to avoid transmitting sampling_params multiple times for
        # a request. SamplingParams is necessary for creating dummy
        # GenerationResultBase instances on postprocess worker processes.
        self._params_transmitted = False

    @property
    def outputs(self) -> List[CompletionOutput]:
        sampling_param = self.sampling_params
        if (sampling_param.use_beam_search
                or sampling_param.n == sampling_param.best_of):
            return self._outputs[:sampling_param.n]
        # Pick the top-n outputs, sorted by cumulative log probs.
        sorted_outputs = sorted(
            self._outputs,
            key=lambda x: (x.cumulative_logprob
                           if x.cumulative_logprob is not None else
                           float('-inf')),
            reverse=True)
        # Reindex the sequences.
        for i, sorted_out in enumerate(sorted_outputs):
            sorted_out.index = i
        return sorted_outputs[:sampling_param.n]

    @property
    def context_logits(self) -> Optional[torch.Tensor]:
        return self._context_logits

    def _handle_sequence(self, finish_reasons, response_tensors,
                         sequence_index):
        """ Handle a single sequence in the response. """

        seq_idx = sequence_index
        src_idx = sequence_index if self.sampling_params.use_beam_search else 0

        output = self._outputs[seq_idx]
        output.disaggregated_params = self.disaggregated_params
        output._last_token_ids_len = len(output.token_ids)
        if self.sampling_params.use_beam_search:
            # Beam search enforces returning all generated tokens
            output.token_ids = response_tensors.output_token_ids[src_idx]
        else:
            output.token_ids.extend(response_tensors.output_token_ids[src_idx])

        # In PD, the first token should be ignored in streaming mode,
        # since it has already been returned by the context server.
        if (self.disaggregated_params is not None
                and self.disaggregated_params.request_type == "generation_only"
                and self._streaming and self.decoding_iter == 2):
            output._last_token_ids_len = 1

        if response_tensors.cum_log_probs is not None:
            output.cumulative_logprob = response_tensors.cum_log_probs[src_idx]
        if response_tensors.log_probs is not None:
            output._last_logprobs_len = len(output.logprobs)
            output.logprobs = response_tensors.log_probs[src_idx]
            # overcome some WAR in the cpp executor
            if finish_reasons[src_idx] != tllm.FinishReason.CANCELLED:
                assert len(output.logprobs) == output.length
        if response_tensors.generation_logits is not None:
            output.generation_logits = response_tensors.generation_logits[
                src_idx, :output.length]

        # When sampling_params.n > 1 and the request is cancelled, make sure
        # all the outputs are marked as cancelled.
        if finish_reasons and finish_reasons[
                src_idx] == tllm.FinishReason.CANCELLED:
            output.finish_reason = 'cancelled'

        if self._done:
            if finish_reasons[src_idx] == tllm.FinishReason.END_ID:
                output.finish_reason = 'stop'
            elif finish_reasons[src_idx] == tllm.FinishReason.STOP_WORDS:
                output.finish_reason = 'stop'
                for stop_reason, stop_ids in self.sampling_params._get_stop_reasons_and_words(
                ):
                    if output.token_ids[-len(stop_ids):] == stop_ids:
                        output.stop_reason = stop_reason
                        if not self.sampling_params.include_stop_str_in_output:
                            output.token_ids = output.token_ids[:-len(stop_ids)]
                        break
            elif finish_reasons[src_idx] == tllm.FinishReason.LENGTH:
                output.finish_reason = 'length'
            elif finish_reasons[src_idx] == tllm.FinishReason.TIMED_OUT:
                output.finish_reason = 'timeout'
            elif finish_reasons[src_idx] == tllm.FinishReason.CANCELLED:
                pass
            else:
                raise ValueError(
                    f"Unknown finish reason: {finish_reasons[src_idx]}")

    @nvtx_range_debug("handle_response",
                      color="red",
                      category="GenerationResultBase")
    def _handle_response(self,
                         response: Union["PostprocWorker.Output", tllm.Response,
                                         ErrorResponse]):
        if isinstance(response, PostprocWorker.Output):
            self._done = response.is_final
            if isinstance(response.res, CompletionOutput):
                # In streaming mode
                self._outputs[0] = response.res
            else:
                self._outputs[0]._postprocess_result = response.res

            if response.error:
                if self._background_error_handler is not None and (
                        handler := self._background_error_handler()):
                    handler(response.error)
        elif isinstance(response, tllm.Response):
            if response.has_error():
                if self._background_error_handler is not None and (
                        handler := self._background_error_handler()):
                    handler(response.error_msg)

            response_result = response.result
            self._done = response_result.is_final
            context_phase_params = response_result.context_phase_params
            self.decoding_iter = response_result.decoding_iter
            if context_phase_params is not None:
                self.disaggregated_params = DisaggregatedParams(
                    request_type="context_only",
                    first_gen_tokens=context_phase_params.first_gen_tokens,
                    ctx_request_id=context_phase_params.req_id,
                    opaque_state=context_phase_params.opaque_state,
                    draft_tokens=context_phase_params.draft_tokens)

            finish_reasons = response_result.finish_reasons
            # output_token_ids = (beams, tokens)
            if self.sampling_params.use_beam_search:
                for beam_idx, _ in enumerate(response_result.output_token_ids):
                    self._handle_sequence(finish_reasons, response_result,
                                          beam_idx)
            else:
                self._handle_sequence(finish_reasons, response_result,
                                      response_result.sequence_index)

            if response_result.context_logits is not None:
                self._context_logits = response_result.context_logits

            # Process background errors here ASAP during generation.
            if self._background_error_handler and (
                    handler := self._background_error_handler()):
                handler()
        elif isinstance(response, ErrorResponse):
            if self._background_error_handler is not None and (
                    handler := self._background_error_handler()):
                handler(response.error_msg)
        else:
            raise ValueError(f"Unknown response type: {response}")


class DetokenizedGenerationResultBase(GenerationResultBase):
    ''' The base class for the generation result with detokenization support. '''
    # Import once and avoid a cyclic import
    from .postproc_worker import PostprocWorker

    def __init__(self,
                 id: int,
                 sampling_params: SamplingParams,
                 tokenizer: Optional[Callable] = None,
                 streaming: bool = False,
                 background_error_handler: Optional[Callable] = None,
                 postproc_params: Optional["PostprocParams"] = None):
        super().__init__(
            id,
            sampling_params,
            background_error_handler=background_error_handler,
            postproc_params=postproc_params,
        )
        self.tokenizer = tokenizer
        self._streaming = streaming

    @nvtx_range_debug("handle_response",
                      color="red",
                      category="DetokenizedGenerationResultBase")
    def _handle_response(self, response: "GenerationExecutor.Response"):
        GenerationResultBase._handle_response(self, response)

        # The postprocessing has already been performed, return directly
        if isinstance(response, PostprocWorker.Output):
            return

        kwargs = {
            'skip_special_tokens':
                self.sampling_params.skip_special_tokens,
            'spaces_between_special_tokens':
                self.sampling_params.spaces_between_special_tokens
        }
        if self.sampling_params.detokenize and self.tokenizer is not None:
            for beam_output in self.outputs:
                beam_output._last_text_len = len(beam_output.text)
                if hasattr(self.tokenizer, 'decode_incrementally'):
                    if self._streaming and not self.sampling_params.use_beam_search:
                        beam_output.text, beam_output._incremental_states = self.tokenizer.decode_incrementally(
                            beam_output.token_ids_diff,
                            prev_text=beam_output.text,
                            states=beam_output._incremental_states,
                            flush=self._done,
                            **kwargs)
                    else:
                        beam_output.text, _ = self.tokenizer.decode_incrementally(
                            beam_output.token_ids,
                            flush=self._done,
                            **kwargs)
                else:
                    beam_output.text = self.tokenizer.decode(
                        beam_output.token_ids, **kwargs)


# alias
PostprocWorker = DetokenizedGenerationResultBase.PostprocWorker
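
# Illustrative sketch (not part of the original module): shows the n-best
# selection done by ``GenerationResultBase.outputs`` when ``n < best_of`` and
# beam search is off. The SamplingParams arguments and cumulative logprob
# values are assumed/made-up placeholders; in normal use the executor creates
# the result and fills in these fields.
def _top_n_selection_example() -> None:
    result = GenerationResultBase(
        id=0, sampling_params=SamplingParams(n=1, best_of=3))
    # Pretend the runtime produced three candidate sequences with different
    # cumulative log probabilities.
    for out, logprob in zip(result._outputs, (-4.0, -1.5, -2.0)):
        out.cumulative_logprob = logprob
    best = result.outputs
    # Only n=1 sequence is returned: the one with the highest cumulative
    # logprob, reindexed to 0.
    assert len(best) == 1
    assert best[0].cumulative_logprob == -1.5 and best[0].index == 0
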
class GenerationResult(GenerationResultBase):
    ''' The result of a generation request. It can be used to wait for the completion of the request.

    Args:
        generation_request (GenerationRequest): The generation request object.
        background_error_handler (Callable, optional): The error handler to process the errors from the background threads/processes. Defaults to None.
        executor (GenerationExecutor, optional): The executor that created this result. Defaults to None.
    '''

    def __init__(
            self,
            generation_request: "GenerationRequest",
            background_error_handler: Optional[Callable] = None,
            executor: Optional["GenerationExecutor"] = None,
            disaggregated_params: Optional[DisaggregatedParams] = None
    ) -> None:
        super().__init__(
            generation_request.id,
            generation_request.sampling_params,
            background_error_handler,
            postproc_params=generation_request.postproc_params,
        )
        self._generation_request = generation_request
        self._streaming = generation_request.streaming
        self.disaggregated_params = disaggregated_params

        # for aborting the request
        self._executor: Optional[weakref.ReferenceType[
            "GenerationExecutor"]] = weakref.ref(executor) if executor else None
        self._aborted = False

    @property
    def request_id(self) -> int:
        return self._generation_request.id

    @property
    def prompt_token_ids(self) -> List[int]:
        return self._generation_request.prompt_token_ids

    def abort(self) -> None:
        """Abort the generation request."""
        assert self._executor is not None, "The executor is not set for this result."
        self._executor().abort_request(self.request_id)
        self._aborted = True

    def aborted(self) -> bool:
        """Return whether the generation request is aborted.

        Returns:
            bool: whether the generation request is aborted.
        """
        return self._aborted

    @property
    def finished(self) -> bool:
        return self._done

    def _result_step(self, timeout: Optional[float] = None):
        response = self.queue.get(timeout=timeout)
        self._handle_response(response)

    async def _aresult_step(self):
        assert self.aqueue is not None, "The asyncio event loop was not present during initialization, so async operations are not available."
        response = await self.aqueue.get()
        global_tracer().log_instant("result_step.get")
        self._handle_response(response)

    def result(self, timeout: Optional[float] = None) -> "GenerationResult":
        """Wait for the completion of the request, and return the result.

        Args:
            timeout (float, optional): Timeout. Defaults to None.

        Returns:
            tensorrt_llm.executor.result.GenerationResult: generation result.
        """
        while not self._done:
            self._result_step(timeout)
        return self

    async def aresult(self) -> "GenerationResult":
        """Wait for the completion of the request, and return the result.

        Returns:
            tensorrt_llm.executor.result.GenerationResult: generation result.
        """
        while not self._done:
            await self._aresult_step()
        return self

    def __await__(self):
        return self.aresult().__await__()

    def __iter__(self):
        return self

    def __next__(self):
        if self._done:
            raise StopIteration

        self._result_step()
        return self

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self._done:
            raise StopAsyncIteration

        await self._aresult_step()
        return self

    def _exception(self, timeout: Optional[float] = None):
        try:
            self.result(timeout)
        except RuntimeError as e:
            return e

    def _repr_fields(self):
        return [
            'request_id', 'prompt_token_ids', 'outputs', 'finished',
            "context_logits"
        ]

    def __repr__(self) -> str:
        repr = []
        for field in self._repr_fields():
            value = getattr(self, field)
            if isinstance(value, str):
                repr.append(f"{field}={value!r}")
            else:
                repr.append(f"{field}={value}")
        repr = ", ".join(repr)
        repr = f"{self.__class__.__name__}({repr})"
        return repr

    def __hash__(self):
        return hash(self.request_id)
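
# Illustrative sketch (not part of the original module): consuming a
# GenerationResult produced elsewhere (e.g. by a GenerationExecutor submit
# call, which is assumed here and not shown). In streaming mode each async
# iteration yields the same result object after one more response has been
# folded in, so the ``*_diff`` properties expose only the newly generated part.
async def _generation_result_usage_example(result: "GenerationResult") -> None:
    async for partial in result:
        for output in partial.outputs:
            # token_ids_diff holds only the tokens added by the latest
            # iteration; text_diff behaves the same way once detokenization or
            # postprocessing has populated ``text``.
            print(output.token_ids_diff)
    # Equivalent non-streaming wait: the object is awaitable and returns itself
    # once the request has finished.
    finished = await result
    assert finished.finished
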
""" def __init__(self): self._done = False self._timeout = 2 if has_event_loop(): self.aqueue = AsyncQueue() self.queue = self.aqueue.sync_q else: self.queue = Queue() self.aqueue = None def set_timeout(self, timeout: float): self._timeout = timeout def mark_undone(self): # should be called when new prompts are submitted self._done = False def get_results(self) -> List[dict]: """ Return all runtime results in the queue. """ results = [] while not self._done: try: data = self.queue.get(timeout=self._timeout) results.append(json.loads(data)) except Empty: self._done = True return results def __aiter__(self): return self async def __anext__(self): if self._done: raise StopAsyncIteration assert self.aqueue is not None, "The asyncio event loop was not present during initialization, so async operations are not available." try: data = await self.aqueue.get(timeout=self._timeout) return json.loads(data) except asyncio.TimeoutError: self._done = True raise StopAsyncIteration