import json
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields
from typing import List, NamedTuple, Optional, Tuple, Union
import torch
from pydantic import BaseModel
from tensorrt_llm.bindings import executor as tllme
@dataclass(slots=True, kw_only=True)
class GuidedDecodingParams:
"""
    Guided decoding parameters for text generation. At most one of the fields may be effective.
Args:
        json (str, pydantic.main.BaseModel, dict, optional): The generated text is amenable to JSON format with additional user-specified restrictions, namely the schema. Defaults to None.
regex (str, optional): The generated text is amenable to the user-specified regular expression. Defaults to None.
grammar (str, optional): The generated text is amenable to the user-specified extended Backus-Naur form (EBNF) grammar. Defaults to None.
        json_object (bool): If True, the generated text is amenable to JSON format. Defaults to False.
"""
json: Optional[Union[str, BaseModel, dict]] = None
regex: Optional[str] = None
grammar: Optional[str] = None
json_object: bool = False
def _validate(self):
num_guides = 0
for field in fields(self):
num_guides += bool(getattr(self, field.name))
if num_guides > 1:
raise ValueError(
f"Only one guide can be used for a request, but got {num_guides}."
)
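

# Usage sketch (illustration only, not part of the original module): constrain
# generation to a JSON schema derived from a pydantic model. The `Answer` model
# below is a hypothetical schema introduced solely for this example.
def _example_guided_decoding_params() -> GuidedDecodingParams:
    class Answer(BaseModel):
        city: str
        population: int

    # Only one guide may be set per request; combining e.g. `json` and `regex`
    # would fail in GuidedDecodingParams._validate().
    return GuidedDecodingParams(json=Answer.model_json_schema())
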
class LogprobParams(NamedTuple):
prompt_logprobs: Optional[int] = None
logprobs: Optional[int] = None
# Drop the logits once the logprobs are computed
drop_context_logits: bool = False
    # Drop the generation_logits once the logprobs are computed
drop_generation_logits: bool = False
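

# Usage sketch (illustration only, not part of the original module): request the
# top-2 log probabilities per generated token and let the runtime drop the raw
# generation logits once the logprobs have been computed.
def _example_logprob_params() -> LogprobParams:
    return LogprobParams(logprobs=2, drop_generation_logits=True)
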
class LogitsProcessor(ABC):
"""Base class for logits processor.
The recommended way to create a customized logits processor:
        * Subclass this class and implement the processing logic in the __call__ method.
        * Create an instance and pass it to SamplingParams.
    Alternatively, you can create any callable with the same signature as the __call__ method.
"""
@abstractmethod
def __call__(self, req_id: int, logits: torch.Tensor,
token_ids: List[List[int]], stream_ptr: Optional[int],
client_id: Optional[int]) -> None:
"""Logits processing callback. The callback is expected to inplace modify the logits.
Args:
req_id (int): Request id.
logits (torch.Tensor): Logits tensor to be modified.
token_ids (List[List[int]]): Token ids produced by the request so far. The shape is beam_width * sequence_length.
stream_ptr (int, optional): The operation stream used by the logits tensor. Not required for PyTorch backend.
client_id (int, optional): An optional client id.
"""
pass # noqa
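

# Illustrative sketch (not part of the original module) of the recipe described in
# the LogitsProcessor docstring: subclass, implement __call__, and pass an instance
# to SamplingParams(logits_processor=...). This example masks a single token id by
# setting its logit to -inf; `banned_token_id` is a hypothetical parameter.
class _ExampleBanTokenLogitsProcessor(LogitsProcessor):

    def __init__(self, banned_token_id: int):
        self.banned_token_id = banned_token_id

    def __call__(self, req_id: int, logits: torch.Tensor,
                 token_ids: List[List[int]], stream_ptr: Optional[int],
                 client_id: Optional[int]) -> None:
        # stream_ptr is None on the PyTorch backend; otherwise wrap it so the
        # in-place update happens on the stream that owns the logits tensor.
        stream = None if stream_ptr is None else torch.cuda.ExternalStream(
            stream_ptr)
        with torch.cuda.stream(stream):
            logits[..., self.banned_token_id] = float("-inf")
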
class BatchedLogitsProcessor(ABC):
"""Base class for batched logits processor.
The recommended way to create a customized batched logits processor:
        * Subclass this class and implement the processing logic in the __call__ method.
        * Create an instance and pass it to LLM.
    Alternatively, you can create any callable with the same signature as the __call__ method.
"""
@abstractmethod
def __call__(self, req_ids: List[int], logits: List[torch.Tensor],
token_ids: List[List[List[int]]], stream_ptr: int,
client_ids: List[Optional[int]]) -> None:
"""Batched logits processing callback. The callback is expected to inplace modify the logits.
Args:
req_ids (List[int]): A batch of request ids.
logits (List[torch.Tensor]): A batch of the logits tensors.
token_ids (List[List[List[int]]]): A batch of the token ids produced by the requests so far. The shape is batch * beam_width * sequence_length.
stream_ptr (int): The operation stream used by the logits tensors.
client_ids (List[Optional[int]]): A batch of optional client ids.
"""
pass # noqa
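

# Illustrative sketch (not part of the original module): the batched counterpart of
# the example above. An instance would be passed to the LLM constructor, and each
# request opts in via SamplingParams(apply_batched_logits_processor=True).
class _ExampleBanTokenBatchedLogitsProcessor(BatchedLogitsProcessor):

    def __init__(self, banned_token_id: int):
        self.banned_token_id = banned_token_id

    def __call__(self, req_ids: List[int], logits: List[torch.Tensor],
                 token_ids: List[List[List[int]]], stream_ptr: int,
                 client_ids: List[Optional[int]]) -> None:
        # Apply the same in-place masking to every request in the batch.
        with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)):
            for logits_tensor in logits:
                logits_tensor[..., self.banned_token_id] = float("-inf")
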
@dataclass(slots=True, kw_only=True)
class AdditionalModelOutput:
"""
An additional output to gather from the model.
Args:
name (str): The name of the additional output to gather from the model.
gather_context (bool): A value indicating whether or not to gather the additional output from the context too. Defaults to False.
"""
name: str
gather_context: bool
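

# Usage sketch (illustration only, not part of the original module): gather a
# named output from the model, also during the context phase. "hidden_states" is
# a hypothetical output name used only for this example.
def _example_additional_model_output() -> AdditionalModelOutput:
    return AdditionalModelOutput(name="hidden_states", gather_context=True)
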
@dataclass(slots=True, kw_only=True)
class SamplingParams:
"""
Sampling parameters for text generation.
Usage Examples:
use_beam_search is False:
- best_of is None: (top-p/top-k) sampling n responses and return n generations
- best_of is not None: (top-p/top-k) sampling best_of responses and return n generations (best_of >= n must hold)
use_beam_search is True:
- best_of is None: beam search with beam width of n, return n generations
- best_of is not None: beam search with beam width of best_of, return n generations (best_of >= n must hold)
Args:
end_id (int, optional): The end token id. Defaults to None.
pad_id (int, optional): The pad token id. Defaults to None.
max_tokens (int): The maximum number of tokens to generate. Defaults to 32.
bad (str, List[str], optional): A string or a list of strings that redirect the generation when they are generated, so that the bad strings are excluded from the returned output. Defaults to None.
bad_token_ids (List[int], optional): A list of token ids that redirect the generation when they are generated, so that the bad ids are excluded from the returned output. Defaults to None.
stop (str, List[str], optional): A string or a list of strings that stop the generation when they are generated. The returned output will not contain the stop strings unless include_stop_str_in_output is True. Defaults to None.
stop_token_ids (List[int], optional): A list of token ids that stop the generation when they are generated. Defaults to None.
include_stop_str_in_output (bool): Whether to include the stop strings in output text. Defaults to False.
embedding_bias (torch.Tensor, optional): The embedding bias tensor. Expected type is kFP32 and shape is [vocab_size]. Defaults to None.
logits_processor (tensorrt_llm.sampling_params.LogitsProcessor, List[tensorrt_llm.sampling_params.LogitsProcessor], optional): The logits postprocessor callback(s). Defaults to None.
If a list, each processor is applied in order during generation (supported in PyTorch backend only).
apply_batched_logits_processor (bool): Whether to apply batched logits postprocessor callback. Defaults to False.
The BatchedLogitsProcessor class is recommended for callback creation. The callback must be provided when initializing LLM.
n (int): Number of sequences to generate. Defaults to 1.
best_of (int, optional): Number of sequences to consider for best output. Defaults to None.
use_beam_search (bool): Whether to use beam search. Defaults to False.
        top_k (int, optional): Controls the number of logits to sample from. None means using C++ runtime default 0, i.e., all logits. Defaults to None.
top_p (float, optional): Controls the top-P probability to sample from. None means using C++ runtime default 0.f. Defaults to None.
        top_p_min (float, optional): Controls decay in the top-P algorithm. top_p_min is the lower bound. None means using C++ runtime default 1.e-6. Defaults to None.
top_p_reset_ids (int, optional): Controls decay in the top-P algorithm. Indicates where to reset the decay. None means using C++ runtime default 1. Defaults to None.
top_p_decay (float, optional): Controls decay in the top-P algorithm. The decay value. None means using C++ runtime default 1.f. Defaults to None.
seed (int, optional): Controls the random seed used by the random number generator in sampling. None means using C++ runtime default 0. Defaults to None.
temperature (float, optional): Controls the modulation of logits when sampling new tokens. It can have values > 0.f. None means using C++ runtime default 1.0f. Defaults to None.
min_tokens (int, optional): Lower bound on the number of tokens to generate. Values < 1 have no effect. None means using C++ runtime default 1. Defaults to None.
        beam_search_diversity_rate (float, optional): Controls the diversity of results in beam search; larger values promote more diverse beams. None means using C++ runtime default 0.f. Defaults to None.
repetition_penalty (float, optional): Used to penalize tokens based on how often they appear in the sequence. It can have any value > 0.f. Values < 1.f encourages repetition, values > 1.f discourages it. None means using C++ runtime default 1.f. Defaults to None.
presence_penalty (float, optional): Used to penalize tokens already present in the sequence (irrespective of the number of appearances). It can have any values. Values < 0.f encourage repetition, values > 0.f discourage it. None means using C++ runtime default 0.f. Defaults to None.
frequency_penalty (float, optional): Used to penalize tokens already present in the sequence (dependent on the number of appearances). It can have any values. Values < 0.f encourage repetition, values > 0.f discourage it. None means using C++ runtime default 0.f. Defaults to None.
length_penalty (float, optional): Controls how to penalize longer sequences in beam search. None means using C++ runtime default 0.f. Defaults to None.
early_stopping (int, optional): Controls whether the generation process finishes once beamWidth sentences are generated (ends with end_token). None means using C++ runtime default 1. Defaults to None.
        no_repeat_ngram_size (int, optional): Controls the size of n-grams that are not allowed to repeat in the generated text. None means using C++ runtime default 1 << 30. Defaults to None.
        min_p (float, optional): Scales the probability of the most likely token to determine the minimum token probability. None means using C++ runtime default 0.0. Defaults to None.
        beam_width_array (List[int], optional): The array of beam widths used in Variable-Beam-Width-Search. Defaults to None.
logprobs (int, optional): Number of log probabilities to return per output token. Defaults to None.
prompt_logprobs (int, optional): Number of log probabilities to return per prompt token. Defaults to None.
return_context_logits (bool): Controls if Result should contain the context logits. Defaults to False.
return_generation_logits (bool): Controls if Result should contain the generation logits. Defaults to False.
        exclude_input_from_output (bool): Controls whether to exclude the input tokens from the output tokens in Result. Defaults to True.
return_encoder_output (bool): Controls if Result should contain encoder output hidden states (for encoder-only and encoder-decoder models). Defaults to False.
return_perf_metrics (bool): Controls if Result should contain the performance metrics for this request. Defaults to False.
additional_model_outputs (List[tensorrt_llm.sampling_params.AdditionalModelOutput], optional): The additional outputs to gather from the model. Defaults to None.
        lookahead_config (tensorrt_llm.bindings.executor.LookaheadDecodingConfig, optional): Lookahead decoding config. Defaults to None.
guided_decoding (tensorrt_llm.sampling_params.GuidedDecodingParams, optional): Guided decoding params. Defaults to None.
ignore_eos (bool): Whether to ignore the EOS token and continue generating tokens after the EOS token is generated. Defaults to False.
detokenize (bool): Whether to detokenize the output. Defaults to True.
add_special_tokens (bool): Whether to add special tokens to the prompt. Defaults to True.
truncate_prompt_tokens (int, optional): If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). Defaults to None.
skip_special_tokens (bool): Whether to skip special tokens in the output. Defaults to True.
spaces_between_special_tokens (bool): Whether to add spaces between special tokens in the output. Defaults to True.
"""
# [TO DEVELOPER] This class provides an interface to LLMAPI users.
# Internally, it manages and dispatches fields to Python bindings of C++ objects, currently including:
# (1) all fields of tllme.SamplingConfig;
# (2) all fields of tllme.OutputConfig;
# (3) some fields of tllme.Request.
# If you changed the implementation of C++ objects and corresponding Python bindings, please update:
# (1) the fields and corresponding docstring of this class, and
# (2) the expected_fields defined in _get_xxx_config methods.
end_id: Optional[int] = None
pad_id: Optional[int] = None
max_tokens: int = 32
bad: Optional[Union[str, List[str]]] = None
bad_token_ids: Optional[List[int]] = None
_bad_word_ids: Optional[List[List[int]]] = field(default=None,
init=False,
repr=False)
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[List[int]] = None
include_stop_str_in_output: bool = False
_stop_word_ids: Optional[List[List[int]]] = field(default=None,
init=False,
repr=False)
embedding_bias: Optional[torch.Tensor] = None
logits_processor: Optional[Union[LogitsProcessor,
List[LogitsProcessor]]] = None
apply_batched_logits_processor: bool = False
n: int = 1
best_of: Optional[int] = None
use_beam_search: bool = False
    # Keep the below fields in sync with tllme.SamplingConfig or maintain the mapping table.
top_k: Optional[int] = None
top_p: Optional[float] = None
top_p_min: Optional[float] = None
top_p_reset_ids: Optional[int] = None
top_p_decay: Optional[float] = None
seed: Optional[int] = None
temperature: Optional[float] = None
min_tokens: Optional[int] = None
beam_search_diversity_rate: Optional[float] = None
repetition_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
length_penalty: Optional[float] = None
early_stopping: Optional[int] = None
no_repeat_ngram_size: Optional[int] = None
min_p: Optional[float] = None
beam_width_array: Optional[List[int]] = None
# Keep the below fields in sync with tllme.OutputConfig
logprobs: Optional[int] = None
prompt_logprobs: Optional[int] = None
return_context_logits: bool = False
return_generation_logits: bool = False
exclude_input_from_output: bool = True
return_encoder_output: bool = False
return_perf_metrics: bool = False
additional_model_outputs: Optional[List[AdditionalModelOutput]] = None
    # Used in logprobs calculation in the TRT flow to drop logits early if the user did not explicitly request them.
# Can be deprecated after migration to PyTorch backend.
_context_logits_auto_enabled: bool = False
_generation_logits_auto_enabled: bool = False
    # TODO: deprecate this after trtllm-serve migrates to use TopK logprobs
_return_log_probs: bool = False
# Lookahead decoding config
lookahead_config: Optional[tllme.LookaheadDecodingConfig] = None
# Guided decoding params
guided_decoding: Optional[GuidedDecodingParams] = None
# Tokenizer-related configs
ignore_eos: bool = False
detokenize: bool = True
add_special_tokens: bool = True
truncate_prompt_tokens: Optional[int] = None
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
def __post_init__(self):
if self.pad_id is None:
self.pad_id = self.end_id
self.best_of = self.best_of or self.n
self._validate()
def _validate(self):
        ''' Verify the sampling parameters.
        This function verifies the sampling parameters in the LLM API, which
        may have stricter requirements than the Executor class of the C++ runtime.
        For instance, while greedy decoding with n > 1 is supported by the
        Executor class of the C++ runtime, the LLM API disallows such a combination.
        '''
if self.best_of is not None:
if self.best_of > 1 and self.best_of < self.n:
raise ValueError(
                    f'best_of ({self.best_of}) must be greater than or '
                    f'equal to n ({self.n}).')
if (self.best_of > 1 and self._greedy_decoding and
not os.environ.get('TLLM_ALLOW_N_GREEDY_DECODING', None)):
                raise ValueError(
                    'Greedy decoding in the LLM API does not allow multiple '
                    f'returns, got best_of={self.best_of}. Please set best_of=1, '
                    'or set the environment variable TLLM_ALLOW_N_GREEDY_DECODING=1 '
                    'to allow best_of > 1 under greedy decoding.')
if self.truncate_prompt_tokens is not None and self.truncate_prompt_tokens < 1:
raise ValueError(
f"truncate_prompt_tokens must be >= 1, got {self.truncate_prompt_tokens}"
)
if self.guided_decoding is not None:
self.guided_decoding._validate()
@property
def _greedy_decoding(self) -> bool:
return (not self.use_beam_search
and (self.top_k is None or self.top_k == 1)
and (self.top_p is None or self.top_p == 0.0))
@property
def _need_return_context_logits(self) -> bool:
return self.return_context_logits and not self._context_logits_auto_enabled
@property
def _need_return_generation_logits(self) -> bool:
return self.return_generation_logits and not self._generation_logits_auto_enabled
def _setup(self,
tokenizer,
add_special_tokens: bool = False) -> 'SamplingParams':
if self.end_id is None:
self.end_id = tokenizer.eos_token_id
self.pad_id = tokenizer.pad_token_id
if self.pad_id is None:
self.pad_id = self.end_id
if self.bad is not None:
strs = [self.bad] if isinstance(self.bad, str) else self.bad
self._bad_word_ids = [
tokenizer.encode(s, add_special_tokens=add_special_tokens)
for s in strs
]
if self.stop is not None:
strs = [self.stop] if isinstance(self.stop, str) else self.stop
self._stop_word_ids = [
tokenizer.encode(s, add_special_tokens=add_special_tokens)
for s in strs
]
return self
def _get_bad_words(self) -> List[List[int]]:
words = []
if self.bad_token_ids:
words = [[i] for i in self.bad_token_ids]
if self.bad is None:
return words
else:
if self._bad_word_ids is None:
raise RuntimeError(
f"{self.__class__.__name__}.bad ({self.bad}) is not processed by tokenizer, "
"please call the setup method.")
return words + self._bad_word_ids
def _get_stop_words(self) -> List[List[int]]:
words = []
if self.stop_token_ids:
words = [[i] for i in self.stop_token_ids]
if self.stop is None:
return words
else:
if self._stop_word_ids is None:
raise RuntimeError(
f"{self.__class__.__name__}.stop ({self.stop}) is not processed by tokenizer, "
"please call the setup method.")
return words + self._stop_word_ids
def _get_stop_reasons_and_words(
self) -> List[Tuple[Union[str, int], List[List[int]]]]:
stop_reasons = []
if self.stop_token_ids is not None:
stop_reasons.extend(self.stop_token_ids)
if self.stop is not None:
if isinstance(self.stop, str):
stop_reasons.append(self.stop)
else:
stop_reasons.extend(self.stop)
stop_words = self._get_stop_words()
return list(zip(stop_reasons, stop_words))
def _get_sampling_config(self) -> tllme.SamplingConfig:
# A map from the SamplingConfig fields of the LLM API to their
# corresponding field names of the Executor of TRT-LLM C++ runtime.
# In sampling, there is no parameter that directly matches 'best_of',
# so outputs must be trimmed during postprocessing.
        # |             | LLM API         | TRT-LLM Executor       |
        # |-------------|-----------------|------------------------|
        # | Beam search | use_beam_search | beam_width > 1         |
        # | Beam search | n               | num_return_sequences   |
        # | Beam search | best_of         | beam_width             |
        # |-------------|-----------------|------------------------|
        # | Sampling    | use_beam_search | beam_width == 1        |
        # | Sampling    | n               | num_return_sequences   |
        # | Sampling    | best_of         | no corresponding param |
fields = {
f
for f in dir(tllme.SamplingConfig) if not f.startswith('__')
}
unmatched_params = [
'num_return_sequences',
'beam_width',
'n',
'best_of',
'use_beam_search',
]
llmapi_to_rt_param_map = {
f: getattr(self, f)
for f in fields if f not in unmatched_params
}
if self.use_beam_search:
llmapi_to_rt_param_map['num_return_sequences'] = self.n
llmapi_to_rt_param_map['beam_width'] = self.best_of
else:
llmapi_to_rt_param_map['num_return_sequences'] = self.best_of
llmapi_to_rt_param_map['beam_width'] = 1
return tllme.SamplingConfig(**llmapi_to_rt_param_map)
def _get_output_config(self,
is_pytorch_backend: bool = False
) -> tllme.OutputConfig:
sampling_param_fields = set(dir(SamplingParams))
fields = [
f for f in dir(tllme.OutputConfig)
if not f.startswith('__') and f in sampling_param_fields
]
config_kwargs = {f: getattr(self, f) for f in fields}
if is_pytorch_backend:
config_kwargs["return_log_probs"] = bool(self.logprobs)
else:
config_kwargs["return_log_probs"] = self._return_log_probs
return tllme.OutputConfig(**config_kwargs)
    def _get_guided_decoding_params(self) -> Optional[tllme.GuidedDecodingParams]:
if self.guided_decoding is None:
return None
if self.guided_decoding.json_object:
return tllme.GuidedDecodingParams(
tllme.GuidedDecodingParams.GuideType.JSON)
elif self.guided_decoding.json is not None:
json_schema = self.guided_decoding.json
            if isinstance(json_schema, BaseModel):
json_schema = json_schema.model_json_schema()
if isinstance(json_schema, dict):
json_schema = json.dumps(json_schema)
return tllme.GuidedDecodingParams(
tllme.GuidedDecodingParams.GuideType.JSON_SCHEMA, json_schema)
elif self.guided_decoding.regex is not None:
return tllme.GuidedDecodingParams(
tllme.GuidedDecodingParams.GuideType.REGEX,
self.guided_decoding.regex)
elif self.guided_decoding.grammar is not None:
return tllme.GuidedDecodingParams(
tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR,
self.guided_decoding.grammar)
else:
return None
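

# Usage sketch (illustration only, not part of the original module), following the
# "Usage Examples" in the SamplingParams docstring: top-p sampling that draws
# best_of=4 candidates and returns n=2 of them, stops on a string, and constrains
# the output with a regex guide. All field values are arbitrary illustrations.
def _example_sampling_params() -> SamplingParams:
    return SamplingParams(
        max_tokens=128,
        n=2,
        best_of=4,  # best_of >= n must hold
        temperature=0.8,
        top_p=0.95,
        stop="</answer>",
        guided_decoding=GuidedDecodingParams(regex=r"[A-Za-z ,.]+"),
    )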