API Reference

class tensorrt_llm.hlapi.LLM(model: str, tokenizer: str | Path | PreTrainedTokenizerBase | TokenizerBase | None = None, skip_tokenizer_init: bool = False, tensor_parallel_size: int = 1, dtype: str = 'auto', revision: str | None = None, tokenizer_revision: str | None = None, **kwargs: Any)[source]

Bases: object

The LLM class is the main entry point for running a large language model.

Parameters:
  • model (str) – The model name or a local path to the model directory. It can be a Hugging Face (HF) model name, a local path to an HF model, or a local path to a TRT-LLM engine or checkpoint.

  • tokenizer (Optional[Union[str, Path, TokenizerBase, PreTrainedTokenizerBase]]) – The tokenizer name or a local path to the tokenizer directory.

  • skip_tokenizer_init (bool) – If True, skip initialization of the tokenizer and detokenizer; generate and generate_async will then accept only prompt token ids as input.

  • tensor_parallel_size (int) – The number of processes for tensor parallelism.

  • dtype (str) – The data type for the model weights and activations.

  • revision (Optional[str]) – The revision of the model.

  • tokenizer_revision (Optional[str]) – The revision of the tokenizer.

  • build_config (BuildConfig, default=BuildConfig()) – The build configuration for the model. Default is an empty BuildConfig instance.

  • quant_config (QuantConfig, default=QuantConfig()) – The quantization configuration for the model. Default is an empty QuantConfig instance.

  • embedding_parallel_mode (str, default="SHARDING_ALONG_VOCAB") – The parallel mode for embeddings.

  • share_embedding_table (bool, default=False) – Whether to share the embedding table.

  • kv_cache_config (KvCacheConfig, optional) – The key-value cache configuration for the model. Default is None.

  • peft_cache_config (PeftCacheConfig, optional) – The PEFT cache configuration for the model. Default is None.

  • decoding_config (DecodingConfig, optional) – The decoding configuration for the model. Default is None.

  • logits_post_processor_map (Dict[str, Callable], optional) – A map of logit post-processing functions. Default is None.

  • scheduler_config (SchedulerConfig, default=SchedulerConfig()) – The scheduler configuration for the model. Default is an empty SchedulerConfig instance.

  • normalize_log_probs (bool, default=False) – Whether to normalize log probabilities for the model.

  • iter_stats_max_iterations (int, optional) – The maximum number of iterations for iteration statistics. Default is None.

  • request_stats_max_iterations (int, optional) – The maximum number of iterations for request statistics. Default is None.

  • batching_type (BatchingType, optional) – The batching type for the model. Default is None.

  • enable_build_cache (bool or BuildCacheConfig, optional) – Whether to enable build caching for the model. Default is None.

  • enable_tqdm (bool, default=False) – Whether to display a progress bar during model building.

__init__(model: str, tokenizer: str | Path | PreTrainedTokenizerBase | TokenizerBase | None = None, skip_tokenizer_init: bool = False, tensor_parallel_size: int = 1, dtype: str = 'auto', revision: str | None = None, tokenizer_revision: str | None = None, **kwargs: Any)[source]
generate(inputs: str | List[int] | Sequence[str | List[int]], sampling_params: SamplingParams | List[SamplingParams] | None = None, use_tqdm: bool = True) RequestOutput | List[RequestOutput][source]

Generate output for the given prompts in synchronous mode. Synchronous generation accepts either a single prompt or a batch of prompts.

Parameters:
  • inputs (Union[str, Iterable[str], List[int], Iterable[List[int]]]) – The prompt text or token ids, given either as a single prompt or as a batch of prompts.

  • sampling_params (Optional[Union[SamplingParams, List[SamplingParams]]]) – The sampling parameters for the generation. Default parameters are used if none are provided.

  • use_tqdm – Whether to use tqdm to display the progress bar.

Returns:

The output data of the completion request to the LLM.

Return type:

Union[RequestOutput, List[RequestOutput]]
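A minimal sketch of a synchronous call, assuming TensorRT-LLM is installed and a suitable GPU is available. The model string is a placeholder, run_batch and main are hypothetical helper names, and reading outputs[0].text assumes that CompletionOutput exposes a text attribute, which this reference does not list.

```python
def run_batch(llm, prompts, sampling_params=None):
    """Call llm.generate and always hand back a list of results."""
    results = llm.generate(prompts, sampling_params=sampling_params, use_tqdm=False)
    # generate returns one RequestOutput for a single prompt and a list for a batch.
    return results if isinstance(results, list) else [results]


def main():
    # Imported lazily so run_batch can be exercised without the library installed.
    from tensorrt_llm.hlapi import LLM, SamplingParams

    llm = LLM(model="<hf-model-name-or-local-path>")  # placeholder, not a real model
    params = SamplingParams(max_new_tokens=16)
    prompts = ["Hello, my name is", "The capital of France is"]
    for result in run_batch(llm, prompts, params):
        # Assumption: CompletionOutput has a `text` attribute.
        print(result.prompt, "->", result.outputs[0].text)
```

Normalizing the return value to a list keeps calling code identical for single and batched prompts. main is not invoked automatically; call it on a machine where the library and a model are available.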

generate_async(inputs: str | List[int], sampling_params: SamplingParams | None = None, streaming: bool = False) RequestOutput[source]

Generate output for the given prompt in asynchronous mode. Asynchronous generation accepts a single prompt only.

Parameters:
  • inputs (Union[str, List[int]]) – The prompt text or token ids; must be a single prompt.

  • sampling_params (Optional[SamplingParams]) – The sampling parameters for the generation. Default parameters are used if none are provided.

  • streaming (bool) – Whether to use the streaming mode for the generation.

Returns:

The output data of the completion request to the LLM.

Return type:

RequestOutput
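The following sketch shows one way a streamed request might be consumed. It assumes that the RequestOutput returned by generate_async can be iterated with `async for` when streaming=True; this reference does not state that, so treat it as an assumption. collect_stream is a hypothetical helper name.

```python
import asyncio


async def collect_stream(llm, prompt, sampling_params=None):
    """Gather every partial result yielded for one streamed request."""
    partials = []
    # Assumption: the object returned by generate_async supports `async for`.
    stream = llm.generate_async(prompt, sampling_params=sampling_params, streaming=True)
    async for partial in stream:
        partials.append(partial)
    return partials


# With a real LLM instance (not created here), the coroutine could be driven with:
# asyncio.run(collect_stream(llm, "Hello, my name is"))
```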

save(engine_dir: str)[source]

Save the built engine to the given path.

Parameters:

engine_dir (str) – The path to save the engine.

Returns:

None

property tokenizer: TokenizerBase | None
property workspace: Path
class tensorrt_llm.hlapi.RequestOutput(generation_result: GenerationResult, prompt: str | None = None, tokenizer: TokenizerBase | None = None)[source]

Bases: GenerationResult

The output data of a completion request to the LLM.

Fields:

  • request_id (int) – The unique ID of the request.

  • prompt (str) – The prompt string of the request.

  • prompt_token_ids (List[int]) – The token ids of the prompt.

  • outputs (List[CompletionOutput]) – The output sequences of the request.

  • context_logits (torch.Tensor) – The logits on the prompt token ids.

  • finished (bool) – Whether the whole request is finished.
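As an illustration of the fields listed above, the sketch below condenses a RequestOutput-like object into a small dictionary. summarize_request is a hypothetical helper; it reads only the attributes named in this section and works on any object that provides them.

```python
def summarize_request(output):
    """Condense a RequestOutput-like object into a plain dict for logging."""
    return {
        "request_id": output.request_id,
        "prompt": output.prompt,
        "num_prompt_tokens": len(output.prompt_token_ids),
        "num_sequences": len(output.outputs),
        "finished": output.finished,
    }
```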

__init__(generation_result: GenerationResult, prompt: str | None = None, tokenizer: TokenizerBase | None = None) None[source]
handle_generation_msg(tensors: tuple, error: str)[source]
class tensorrt_llm.hlapi.SamplingParams(end_id: int | None = None, pad_id: int | None = None, max_new_tokens: int = 32, bad: List[str] | str | None = None, bad_token_ids: List[int] | None = None, stop: List[str] | str | None = None, stop_token_ids: List[int] | None = None, include_stop_str_in_output: bool = False, embedding_bias: Tensor | None = None, external_draft_tokens_config: ExternalDraftTokensConfig | None = None, prompt_tuning_config: PromptTuningConfig | None = None, lora_config: LoraConfig | None = None, logits_post_processor_name: str | None = None, beam_width: int = 1, top_k: int | None = None, top_p: float | None = None, top_p_min: float | None = None, top_p_reset_ids: int | None = None, top_p_decay: float | None = None, random_seed: int | None = None, temperature: float | None = None, min_length: int | None = None, beam_search_diversity_rate: float | None = None, repetition_penalty: float | None = None, presence_penalty: float | None = None, frequency_penalty: float | None = None, length_penalty: float | None = None, early_stopping: int | None = None, no_repeat_ngram_size: int | None = None, return_log_probs: bool = False, return_context_logits: bool = False, return_generation_logits: bool = False, exclude_input_from_output: bool = True, return_encoder_output: bool = False)[source]

Bases: object

Sampling parameters for text generation.

Parameters:
  • end_id (int) – The end token id.

  • pad_id (int) – The pad token id.

  • max_new_tokens (int) – The maximum number of tokens to generate.

  • bad (Union[str, List[str]]) – A string or a list of strings that redirect the generation when they are generated, so that the bad strings are excluded from the returned output.

  • bad_token_ids (List[int]) – A list of token ids that redirect the generation when they are generated, so that the bad ids are excluded from the returned output.

  • stop (Union[str, List[str]]) – A string or a list of strings that stop the generation when they are generated. The returned output will not contain the stop strings unless include_stop_str_in_output is True.

  • stop_token_ids (List[int]) – A list of token ids that stop the generation when they are generated.

  • include_stop_str_in_output (bool) – Whether to include the stop strings in output text. Defaults to False.

  • embedding_bias (torch.Tensor) – The embedding bias tensor. Expected type is kFP32 and shape is [vocab_size].

  • external_draft_tokens_config (ExternalDraftTokensConfig) – The speculative decoding configuration.

  • prompt_tuning_config (PromptTuningConfig) – The prompt tuning configuration.

  • lora_config (LoraConfig) – The LoRA configuration.

  • logits_post_processor_name (str) – The logits postprocessor name. Must correspond to one of the logits postprocessor names provided to the ExecutorConfig.

  • beam_width (int) – The beam width. Default is 1 which disables beam search.

  • top_k (int) – Controls number of logits to sample from. Default is 0 (all logits).

  • top_p (float) – Controls the top-P probability to sample from. Default is 0.f

  • top_p_min (float) – Controls decay in the top-P algorithm. top_p_min is the lower bound. Default is 1.e-6.

  • top_p_reset_ids (int) – Controls decay in the top-P algorithm. Indicates where to reset the decay. Default is 1.

  • top_p_decay (float) – Controls decay in the top-P algorithm. The decay value. Default is 1.f

  • random_seed (int) – Controls the random seed used by the random number generator in sampling.

  • temperature (float) – Controls the modulation of logits when sampling new tokens. It can have values > 0.f. Default is 1.0f

  • min_length (int) – Lower bound on the number of tokens to generate. Values < 1 have no effect. Default is 1.

  • beam_search_diversity_rate (float) – Controls the diversity in beam search.

  • repetition_penalty (float) – Used to penalize tokens based on how often they appear in the sequence. It can have any value > 0.f. Values < 1.f encourage repetition, values > 1.f discourage it. Default is 1.f

  • presence_penalty (float) – Used to penalize tokens already present in the sequence (irrespective of the number of appearances). It can have any values. Values < 0.f encourage repetition, values > 0.f discourage it. Default is 0.f

  • frequency_penalty (float) – Used to penalize tokens already present in the sequence (dependent on the number of appearances). It can have any values. Values < 0.f encourage repetition, values > 0.f discourage it. Default is 0.f

  • length_penalty (float) – Controls how to penalize longer sequences in beam search. Default is 0.f

  • early_stopping (int) – Controls whether the generation process finishes once beam_width sentences have been generated (each ending with the end token).

  • no_repeat_ngram_size (int) – The size of n-grams that may not be repeated in the generated output. Default is 1 << 30.

  • return_log_probs (bool) – Controls if Result should contain log probabilities. Default is false.

  • return_context_logits (bool) – Controls if Result should contain the context logits. Default is false.

  • return_generation_logits (bool) – Controls if Result should contain the generation logits. Default is false.

  • exclude_input_from_output (bool) – Controls whether the input tokens are excluded from the output tokens in Result. Default is true.

  • return_encoder_output (bool) – Controls if Result should contain encoder output hidden states (for encoder-only and encoder-decoder models). Default is false.
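To make the interplay of temperature, top_k and top_p above more concrete, here is a pure-Python sketch of the textbook filtering scheme. It is not TensorRT-LLM's implementation: candidate_tokens is a hypothetical helper, and because this reference does not say how top_p = 0 is interpreted, the sketch requires 0 < top_p <= 1.

```python
import math


def candidate_tokens(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return {token_id: probability} for the tokens that survive filtering."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    if not 0.0 < top_p <= 1.0:
        raise ValueError("this sketch requires 0 < top_p <= 1")
    # Temperature rescales the logits before the softmax.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    weights = [math.exp(x - peak) for x in scaled]
    total = sum(weights)
    # Sort by descending probability, breaking ties by token id.
    ranked = sorted(((w / total, i) for i, w in enumerate(weights)),
                    key=lambda item: (-item[0], item[1]))
    # top_k == 0 keeps every token, matching "0 (all logits)" above.
    if top_k > 0:
        ranked = ranked[:top_k]
    # top_p keeps the shortest prefix whose cumulative mass reaches top_p.
    kept, mass = [], 0.0
    for prob, token in ranked:
        kept.append((prob, token))
        mass += prob
        if mass >= top_p:
            break
    norm = sum(prob for prob, _ in kept)
    return {token: prob / norm for prob, token in kept}
```

For example, with logits [2.0, 1.0, 0.0, -1.0], setting top_k=2 leaves tokens 0 and 1, while top_p=0.5 leaves only token 0 because it already carries more than half of the probability mass.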

__init__(end_id: int | None = None, pad_id: int | None = None, max_new_tokens: int = 32, bad: List[str] | str | None = None, bad_token_ids: List[int] | None = None, stop: List[str] | str | None = None, stop_token_ids: List[int] | None = None, include_stop_str_in_output: bool = False, embedding_bias: Tensor | None = None, external_draft_tokens_config: ExternalDraftTokensConfig | None = None, prompt_tuning_config: PromptTuningConfig | None = None, lora_config: LoraConfig | None = None, logits_post_processor_name: str | None = None, beam_width: int = 1, top_k: int | None = None, top_p: float | None = None, top_p_min: float | None = None, top_p_reset_ids: int | None = None, top_p_decay: float | None = None, random_seed: int | None = None, temperature: float | None = None, min_length: int | None = None, beam_search_diversity_rate: float | None = None, repetition_penalty: float | None = None, presence_penalty: float | None = None, frequency_penalty: float | None = None, length_penalty: float | None = None, early_stopping: int | None = None, no_repeat_ngram_size: int | None = None, return_log_probs: bool = False, return_context_logits: bool = False, return_generation_logits: bool = False, exclude_input_from_output: bool = True, return_encoder_output: bool = False) None
bad: List[str] | str | None
bad_token_ids: List[int] | None
beam_search_diversity_rate: float | None
beam_width: int
early_stopping: int | None
embedding_bias: Tensor | None
end_id: int | None
exclude_input_from_output: bool
external_draft_tokens_config: ExternalDraftTokensConfig | None
frequency_penalty: float | None
include_stop_str_in_output: bool
length_penalty: float | None
logits_post_processor_name: str | None
lora_config: LoraConfig | None
max_new_tokens: int
min_length: int | None
no_repeat_ngram_size: int | None
pad_id: int | None
presence_penalty: float | None
prompt_tuning_config: PromptTuningConfig | None
random_seed: int | None
repetition_penalty: float | None
return_context_logits: bool
return_encoder_output: bool
return_generation_logits: bool
return_log_probs: bool
setup(tokenizer, add_special_tokens: bool = False) SamplingParams[source]
stop: List[str] | str | None
stop_token_ids: List[int] | None
temperature: float | None
top_k: int | None
top_p: float | None
top_p_decay: float | None
top_p_min: float | None
top_p_reset_ids: int | None
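The three penalty parameters described in the SamplingParams parameter list can be illustrated with a common formulation of logit penalties. The exact arithmetic used by TensorRT-LLM is not given in this reference, so the formulas below are an assumption chosen to match the documented directions (repetition_penalty > 1 and positive presence or frequency penalties discourage repetition). apply_penalties is a hypothetical helper.

```python
from collections import Counter


def apply_penalties(logits, generated_ids, repetition_penalty=1.0,
                    presence_penalty=0.0, frequency_penalty=0.0):
    """Return a copy of logits with penalties applied to already-seen tokens."""
    if repetition_penalty <= 0:
        raise ValueError("repetition_penalty must be > 0")
    counts = Counter(generated_ids)
    adjusted = list(logits)
    for token, count in counts.items():
        value = adjusted[token]
        # Multiplicative penalty: shrink positive logits, push negative ones lower.
        value = value / repetition_penalty if value > 0 else value * repetition_penalty
        # Presence penalty is applied once, regardless of how often the token appeared.
        value -= presence_penalty
        # Frequency penalty grows with the number of appearances.
        value -= frequency_penalty * count
        adjusted[token] = value
    return adjusted
```

With the default arguments the logits come back unchanged, which mirrors the documented defaults of 1.f, 0.f and 0.f.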
class tensorrt_llm.hlapi.KvCacheConfig

Bases: pybind11_object

__init__(self: tensorrt_llm.bindings.executor.KvCacheConfig, enable_block_reuse: bool = False, max_tokens: int | None = None, max_attention_window: list[int] | None = None, sink_token_length: int | None = None, free_gpu_memory_fraction: float | None = None, host_cache_size: int | None = None, onboard_blocks: bool = True) None
property enable_block_reuse
property free_gpu_memory_fraction
property host_cache_size
property max_attention_window
property max_tokens
property onboard_blocks
property sink_token_length
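A small sketch of how keyword arguments for KvCacheConfig might be assembled and passed to LLM through kv_cache_config. The keyword names come from the __init__ signature above; the range checks are assumptions of this sketch rather than documented constraints, and kv_cache_kwargs and make_llm are hypothetical helper names.

```python
def kv_cache_kwargs(free_gpu_memory_fraction=None, max_tokens=None,
                    enable_block_reuse=False):
    """Validate a few KvCacheConfig arguments and drop the unset ones."""
    # Assumption: a fraction only makes sense in the interval (0, 1].
    if free_gpu_memory_fraction is not None and not 0.0 < free_gpu_memory_fraction <= 1.0:
        raise ValueError("free_gpu_memory_fraction must be in (0, 1]")
    # Assumption: a token budget must be positive.
    if max_tokens is not None and max_tokens <= 0:
        raise ValueError("max_tokens must be positive")
    kwargs = {
        "enable_block_reuse": enable_block_reuse,
        "max_tokens": max_tokens,
        "free_gpu_memory_fraction": free_gpu_memory_fraction,
    }
    return {key: value for key, value in kwargs.items() if value is not None}


def make_llm(model, **cache_options):
    # Imported lazily; requires TensorRT-LLM to be installed.
    from tensorrt_llm.hlapi import LLM, KvCacheConfig

    config = KvCacheConfig(**kv_cache_kwargs(**cache_options))
    return LLM(model=model, kv_cache_config=config)
```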
class tensorrt_llm.hlapi.SchedulerConfig

Bases: pybind11_object

__init__(*args, **kwargs)

Overloaded function.

  1. __init__(self: tensorrt_llm.bindings.executor.SchedulerConfig, capacity_scheduler_policy: tensorrt_llm.bindings.executor.CapacitySchedulerPolicy = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT) -> None

  2. __init__(self: tensorrt_llm.bindings.executor.SchedulerConfig, capacity_scheduler_policy: tensorrt_llm.bindings.executor.CapacitySchedulerPolicy, context_chunking_policy: Optional[tensorrt_llm.bindings.executor.ContextChunkingPolicy]) -> None

property capacity_scheduler_policy
property context_chunking_policy
class tensorrt_llm.hlapi.CapacitySchedulerPolicy

Bases: pybind11_object

Members:

MAX_UTILIZATION

GUARANTEED_NO_EVICT

GUARANTEED_NO_EVICT = <CapacitySchedulerPolicy.GUARANTEED_NO_EVICT: 1>
MAX_UTILIZATION = <CapacitySchedulerPolicy.MAX_UTILIZATION: 0>
__init__(self: tensorrt_llm.bindings.executor.CapacitySchedulerPolicy, value: int) None
property name
property value
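The member names and integer values above can be resolved from configuration input with a small helper. The sketch uses a stand-in IntEnum that mirrors the listed values so it runs without the bindings; CapacitySchedulerPolicyMirror and resolve_policy are hypothetical names.

```python
from enum import IntEnum


class CapacitySchedulerPolicyMirror(IntEnum):
    """Stand-in with the member names and values listed above."""
    MAX_UTILIZATION = 0
    GUARANTEED_NO_EVICT = 1


def resolve_policy(enum_cls, spec):
    """Accept a member name (any case) or an integer value."""
    if isinstance(spec, str):
        # Look the member up by its upper-case name.
        return getattr(enum_cls, spec.strip().upper())
    # Both IntEnum and the class above can be constructed from an int value.
    return enum_cls(spec)
```

Passing the real tensorrt_llm.hlapi.CapacitySchedulerPolicy as enum_cls should behave the same way, and the result could then be given to SchedulerConfig(capacity_scheduler_policy=...).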
class tensorrt_llm.hlapi.BuildConfig(max_input_len: int = 256, max_seq_len: int = 512, opt_batch_size: int = 8, max_batch_size: int = 8, max_beam_width: int = 1, max_num_tokens: Optional[int] = None, opt_num_tokens: Optional[int] = None, max_prompt_embedding_table_size: int = 0, kv_cache_type: tensorrt_llm.bindings.KVCacheType = None, gather_context_logits: int = False, gather_generation_logits: int = False, strongly_typed: bool = True, builder_opt: Optional[int] = None, force_num_profiles: Optional[int] = None, profiling_verbosity: str = 'layer_names_only', enable_debug_output: bool = False, max_draft_len: int = 0, speculative_decoding_mode: tensorrt_llm.models.modeling_utils.SpeculativeDecodingMode = <SpeculativeDecodingMode.NONE: 1>, use_refit: bool = False, input_timing_cache: str = None, output_timing_cache: str = None, lora_config: tensorrt_llm.lora_manager.LoraConfig = <factory>, auto_parallel_config: tensorrt_llm.auto_parallel.config.AutoParallelConfig = <factory>, weight_sparsity: bool = False, weight_streaming: bool = False, plugin_config: tensorrt_llm.plugin.plugin.PluginConfig = <factory>, use_strip_plan: bool = False, max_encoder_input_len: int = 1, use_fused_mlp: bool = False, dry_run: bool = False, visualize_network: bool = False)[source]

Bases: object

__init__(max_input_len: int = 256, max_seq_len: int = 512, opt_batch_size: int = 8, max_batch_size: int = 8, max_beam_width: int = 1, max_num_tokens: int | None = None, opt_num_tokens: int | None = None, max_prompt_embedding_table_size: int = 0, kv_cache_type: ~tensorrt_llm.bindings.KVCacheType | None = None, gather_context_logits: int = False, gather_generation_logits: int = False, strongly_typed: bool = True, builder_opt: int | None = None, force_num_profiles: int | None = None, profiling_verbosity: str = 'layer_names_only', enable_debug_output: bool = False, max_draft_len: int = 0, speculative_decoding_mode: ~tensorrt_llm.models.modeling_utils.SpeculativeDecodingMode = SpeculativeDecodingMode.NONE, use_refit: bool = False, input_timing_cache: str | None = None, output_timing_cache: str | None = None, lora_config: ~tensorrt_llm.lora_manager.LoraConfig = <factory>, auto_parallel_config: ~tensorrt_llm.auto_parallel.config.AutoParallelConfig = <factory>, weight_sparsity: bool = False, weight_streaming: bool = False, plugin_config: ~tensorrt_llm.plugin.plugin.PluginConfig = <factory>, use_strip_plan: bool = False, max_encoder_input_len: int = 1, use_fused_mlp: bool = False, dry_run: bool = False, visualize_network: bool = False) None
auto_parallel_config: AutoParallelConfig
builder_opt: int | None = None
dry_run: bool = False
enable_debug_output: bool = False
force_num_profiles: int | None = None
classmethod from_dict(config, plugin_config=None)[source]
classmethod from_json_file(config_file, plugin_config=None)[source]
gather_context_logits: int = False
gather_generation_logits: int = False
input_timing_cache: str = None
kv_cache_type: KVCacheType = None
lora_config: LoraConfig
max_batch_size: int = 8
max_beam_width: int = 1
max_draft_len: int = 0
max_encoder_input_len: int = 1
max_input_len: int = 256
max_num_tokens: int | None = None
max_prompt_embedding_table_size: int = 0
max_seq_len: int = 512
opt_batch_size: int = 8
opt_num_tokens: int | None = None
output_timing_cache: str = None
plugin_config: PluginConfig
profiling_verbosity: str = 'layer_names_only'
speculative_decoding_mode: SpeculativeDecodingMode = 1
strongly_typed: bool = True
to_dict()[source]
update(**kwargs)[source]
update_from_dict(config: dict)[source]
update_kv_cache_type(model_architecture: str)[source]
use_fused_mlp: bool = False
use_refit: bool = False
use_strip_plan: bool = False
visualize_network: bool = False
weight_sparsity: bool = False
weight_streaming: bool = False
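The sketch below checks a request against the max_input_len and max_seq_len defaults shown above. It assumes that max_seq_len bounds the prompt plus the generated tokens, which is the usual reading but is not stated in this reference. request_fits is a hypothetical helper.

```python
def request_fits(prompt_len, max_new_tokens, max_input_len=256, max_seq_len=512):
    """Return True if a request stays within the given build limits."""
    # The prompt alone must not exceed the input limit.
    if prompt_len > max_input_len:
        return False
    # Assumption: max_seq_len covers prompt tokens plus generated tokens.
    return prompt_len + max_new_tokens <= max_seq_len
```

With a real BuildConfig instance, the two limits could be read from its max_input_len and max_seq_len attributes instead of the defaults.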
class tensorrt_llm.hlapi.QuantConfig(quant_algo: QuantAlgo | None = None, kv_cache_quant_algo: QuantAlgo | None = None, group_size: int | None = 128, smoothquant_val: float = 0.5, clamp_val: List[float] | None = None, has_zero_point: bool | None = False, pre_quant_scale: bool | None = False, exclude_modules: List[str] | None = None)[source]

Bases: object

Serializable quantization configuration class, part of the PretrainedConfig

__init__(quant_algo: QuantAlgo | None = None, kv_cache_quant_algo: QuantAlgo | None = None, group_size: int | None = 128, smoothquant_val: float = 0.5, clamp_val: List[float] | None = None, has_zero_point: bool | None = False, pre_quant_scale: bool | None = False, exclude_modules: List[str] | None = None) None
clamp_val: List[float] | None = None
exclude_modules: List[str] | None = None
classmethod from_dict(config: dict)[source]
get_modelopt_kv_cache_dtype()[source]
get_modelopt_qformat()[source]
group_size: int | None = 128
has_zero_point: bool | None = False
kv_cache_quant_algo: QuantAlgo | None = None
pre_quant_scale: bool | None = False
quant_algo: QuantAlgo | None = None
property quant_mode: QuantMode
property requires_calibration
property requires_modelopt_quantization
smoothquant_val: float = 0.5
to_dict()[source]
property use_plugin_sq
class tensorrt_llm.hlapi.QuantAlgo(value)[source]

Bases: StrEnum

An enumeration.

FP8 = 'FP8'
FP8_PER_CHANNEL_PER_TOKEN = 'FP8_PER_CHANNEL_PER_TOKEN'
INT8 = 'INT8'
W4A16 = 'W4A16'
W4A16_AWQ = 'W4A16_AWQ'
W4A16_GPTQ = 'W4A16_GPTQ'
W4A8_AWQ = 'W4A8_AWQ'
W8A16 = 'W8A16'
W8A8_SQ_PER_CHANNEL = 'W8A8_SQ_PER_CHANNEL'
W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN = 'W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN'
W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN = 'W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN'
W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN = 'W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN'
W8A8_SQ_PER_TENSOR_PLUGIN = 'W8A8_SQ_PER_TENSOR_PLUGIN'
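Many of the member names above follow a W<bits>A<bits> pattern. Assuming the common convention that W gives the weight bit width and A the activation bit width, the widths can be parsed from a name as sketched below. weight_activation_bits and make_quant_config are hypothetical helpers.

```python
import re

# Matches names such as "W4A16" or "W8A8_SQ_PER_CHANNEL".
_WA_PATTERN = re.compile(r"^W(\d+)A(\d+)(?:_|$)")


def weight_activation_bits(name):
    """Return (weight_bits, activation_bits), or None for names like 'FP8'."""
    match = _WA_PATTERN.match(name)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def make_quant_config():
    # Imported lazily; requires TensorRT-LLM to be installed.
    from tensorrt_llm.hlapi import QuantAlgo, QuantConfig

    return QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ)
```

Because QuantAlgo is a StrEnum, its members are strings and can be passed to weight_activation_bits directly.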