# Speculative Decoding
Speculative decoding is a technique for accelerating LLM inference at low batch sizes. A lightweight drafting mechanism proposes candidate tokens, and the target model verifies them in a single forward pass. Tokens that match are accepted, reducing the number of sequential forward passes needed.
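To make the verify/accept step concrete, here is a toy sketch of greedy draft verification in pure Python. The function name and token values are illustrative, not part of the TRT-LLM API; `target_tokens` stands in for the target model's greedy predictions from its single verification pass.

```python
# Toy illustration of the speculative decoding accept/verify loop.
# target_tokens must contain one more entry than draft_tokens: the
# target's prediction at each draft position, plus one past the end.

def verify(draft_tokens, target_tokens):
    """Accept the longest prefix of draft tokens the target agrees with,
    then append the target's own token at the first mismatch."""
    accepted = []
    for d, t in zip(draft_tokens, target_tokens):
        if d == t:
            accepted.append(d)
        else:
            break
    # The target's token at the first mismatch (or one past the drafts)
    # comes for free, so every step emits at least one token.
    accepted.append(target_tokens[len(accepted)])
    return accepted
```

When all drafts are accepted, a step emits `max_draft_len + 1` tokens for a single target forward pass, which is where the speedup comes from.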
## Quick Start
For all speculation algorithms, when speculation is enabled, a single sequence of draft tokens of length `max_draft_len` is created for every request. There is currently no way to disable speculation dynamically, so speedups are only observable at low batch sizes.
## Draft/Target
Draft/target is the simplest form of speculative decoding. In this approach, an arbitrary draft model is used to produce draft tokens. Make sure that the draft and target models were trained with the same tokenizer; otherwise the acceptance rate will be extremely low and performance will regress.
```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import DraftTargetDecodingConfig

# Option 1: Use a HuggingFace Hub model ID (auto-downloaded)
speculative_config = DraftTargetDecodingConfig(
    max_draft_len=3, speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B")

# Option 2: Use a local path
# speculative_config = DraftTargetDecodingConfig(
#     max_draft_len=3, speculative_model="/path/to/draft_model")

llm = LLM("/path/to/target_model", speculative_config=speculative_config,
          disable_overlap_scheduler=True)
```
## EAGLE 3
The EAGLE 3 algorithm is described in the paper EAGLE-3: Scaling up Inference Acceleration of Large Language Models via Training-Time Test.
TRT-LLM supports a modified version of the algorithm presented in the paper: tree structures for draft sequences are not supported. Instead, each request uses a single sequence of draft tokens with length max_draft_len.
The following draft model checkpoints can be used for EAGLE 3:

- Llama 3 variants: use the checkpoints from the authors of the original EAGLE 3 paper.
- Llama 4 Maverick: use the checkpoint from the NVIDIA HuggingFace repository.
- Other models, including `gpt-oss-120b` and `Qwen3`: check out the Speculative Decoding Modules collection from NVIDIA.
```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import Eagle3DecodingConfig

model = "meta-llama/Llama-3.1-8B-Instruct"
speculative_model = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"

speculative_config = Eagle3DecodingConfig(
    max_draft_len=3,
    speculative_model=speculative_model)

llm = LLM(model, speculative_config=speculative_config)
```
EAGLE 3 can be combined with the Suffix Automaton enhancement for improved acceptance rates on repetitive content. See the SA section below for details.
## NGram
The NGram method is an implementation of the Prompt Lookup Decoding algorithm.
When the NGram algorithm is used, TRT-LLM maintains a map from token prefixes to candidate draft sequences. For example, the 3-gram `["The", " future", " is"]` could map to the draft sequence `[" bright", " because"]`. The prefixes are token sequences extracted from the prompt and from the tokens generated by the target model. The NGram pool and matching procedure can be tuned with the following options:
- `max_draft_len`: Maximum draft candidate length.
- `max_matching_ngram_size`: Maximum prompt suffix length to match with keys in the pool.
- `is_public_pool`: If true, a single ngram pool is shared for all requests. Otherwise, each request has its own ngram pool.
- `is_keep_all`: If true, draft candidates are retained in the pool forever. Otherwise, only the largest draft candidate is retained.
- `is_use_oldest`: If true, the oldest draft candidate is always proposed for a given match. Otherwise, the newest draft candidate is used. Only applicable if `is_keep_all == True`, because `is_keep_all == False` means there is only ever a single value for each key.
```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import NGramDecodingConfig

speculative_config = NGramDecodingConfig(
    max_draft_len=3, max_matching_ngram_size=4, is_public_pool=True)

llm = LLM("/path/to/target_model", speculative_config=speculative_config,
          disable_overlap_scheduler=True)
```
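The pool-and-match idea can be sketched in plain Python. This is a simplified model of the described behavior, not TRT-LLM's internal data structure; it keeps only the newest candidate per key, i.e. `is_keep_all == False` semantics, and the function names are illustrative.

```python
# Minimal sketch of an ngram pool: map token-prefix tuples to the tokens
# that followed them, then match the current suffix against the pool.

def build_pool(tokens, max_ngram, draft_len):
    pool = {}
    for n in range(1, max_ngram + 1):
        for i in range(len(tokens) - n):
            key = tuple(tokens[i:i + n])
            # Later occurrences overwrite earlier ones: newest candidate wins.
            pool[key] = tokens[i + n:i + n + draft_len]
    return pool

def propose(pool, tokens, max_ngram):
    # Prefer the longest matching suffix; longer matches are more specific.
    for n in range(max_ngram, 0, -1):
        key = tuple(tokens[-n:])
        if key in pool:
            return pool[key]
    return []  # no match: no draft tokens this step
```

For example, after seeing `[1, 2, 3, 4, 1, 2, 3, 5]`, the suffix `(1, 2)` maps to the tokens that most recently followed it, so a sequence ending in `1, 2` would get the draft `[3, 5]`.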
## MTP
MTP is currently only supported for DeepSeek models. MTP can be tuned with the following configuration options:
- `max_draft_len`: Maximum draft candidate length.
- `num_nextn_predict_layers`: Number of MTP modules to use. Currently must match `max_draft_len`.
- `use_relaxed_acceptance_for_thinking`: If true, use relaxed decoding for reasoning models in the thinking phase. In this mode, speculation requirements are relaxed during the thinking phase: a draft token may be accepted if it appears in a candidate set constructed with `relaxed_topk` and `relaxed_delta`.
- `relaxed_topk`: The top K tokens are sampled from the target model's logits to create the initial candidate set for relaxed decoding.
- `relaxed_delta`: Used to further filter the top K candidate set for relaxed decoding. Tokens `t` for which `log(P(top 1 token)) - log(P(t)) > relaxed_delta` are removed.
```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import MTPDecodingConfig

speculative_config = MTPDecodingConfig(
    max_draft_len=3, num_nextn_predict_layers=3)

llm = LLM("/path/to/deepseek_model", speculative_config=speculative_config)
```
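The relaxed-acceptance candidate set described above can be sketched as follows. This is illustrative pure Python over a toy probability table; `relaxed_candidates` is not a TRT-LLM function.

```python
import math

# Build the relaxed candidate set: take the top-K tokens by probability,
# then drop any token t with log(P(top1)) - log(P(t)) > relaxed_delta.

def relaxed_candidates(probs, relaxed_topk, relaxed_delta):
    # probs: dict mapping token id -> probability under the target model
    ranked = sorted(probs, key=probs.get, reverse=True)[:relaxed_topk]
    top_logp = math.log(probs[ranked[0]])
    return [t for t in ranked
            if top_logp - math.log(probs[t]) <= relaxed_delta]
```

A draft token is then accepted during the thinking phase if it appears anywhere in this set, rather than only when it exactly matches the target's top-1 token.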
MTP can be combined with the Suffix Automaton enhancement for improved acceptance rates on repetitive content. See the SA section below for details.
## PARD
PARD (PARallel Draft) is a target-independent speculative decoding method that predicts all draft tokens in a single forward pass using mask tokens. Unlike MTP or EAGLE 3 which generate drafts one token at a time, PARD produces K draft tokens in parallel.
Reference: PARD: Parallel Drafting for Speculative Decoding
- `max_draft_len`: Maximum draft candidate length.
- `speculative_model`: Path or HuggingFace model ID for the PARD draft model.
- `mask_token_id`: Token ID used as the mask token for parallel prediction. If not set, it is read from the draft model config.
```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import PARDDecodingConfig

speculative_config = PARDDecodingConfig(
    max_draft_len=4, speculative_model="/path/to/pard_model")

llm = LLM("/path/to/target_model", speculative_config=speculative_config)
```
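The parallel-drafting input can be sketched as follows. This only illustrates the shape of the idea from the description above; `build_pard_input` is not a TRT-LLM function, and the real draft model call is omitted.

```python
# PARD-style parallel drafting: the draft model sees the context followed
# by max_draft_len mask tokens, and one forward pass yields a prediction
# at every mask position, i.e. all draft tokens in parallel.

def build_pard_input(context_tokens, mask_token_id, max_draft_len):
    return context_tokens + [mask_token_id] * max_draft_len
```

Contrast this with EAGLE 3 or MTP, where the drafter runs once per draft token; PARD's single draft pass trades some per-token draft quality for lower drafting latency.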
PARD can be combined with the Suffix Automaton enhancement for improved acceptance rates on repetitive content. See the SA section below for details.
## User-provided drafting
A completely user-defined drafting method can be supplied with a `UserProvidedDecodingConfig` that includes:

- `max_draft_len`: Maximum draft candidate length.
- `drafter`: An object of type `Drafter` that implements the `prepare_draft_tokens` method (see Developer Guide 7).
- `resource_manager`: An optional `ResourceManager` object (see Developer Guide 4).
```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import UserProvidedDecodingConfig

speculative_config = UserProvidedDecodingConfig(
    max_draft_len=3, drafter=MyDrafter())

llm = LLM("/path/to/target_model", speculative_config=speculative_config)
```
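As a shape-of-the-idea sketch, a trivial drafter might always propose the same fixed sequence. A real implementation must subclass TRT-LLM's `Drafter`, and the exact `prepare_draft_tokens` signature and request type are defined in the Developer Guide; everything below (the class, the dict-based requests) is illustrative pseudocode for the contract, not runnable TRT-LLM code.

```python
# Hypothetical drafter sketch: prepare_draft_tokens attaches draft tokens
# to each scheduled request. Requests are modeled here as plain dicts.

class ConstantDrafter:
    """Always proposes the same draft tokens (illustrative only)."""

    def __init__(self, draft_tokens):
        self.draft_tokens = draft_tokens

    def prepare_draft_tokens(self, requests):
        for req in requests:
            req["draft_tokens"] = list(self.draft_tokens)
```

However the drafts are produced (a small model, a lookup table, a heuristic), the target model verifies them the same way as in the built-in algorithms.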
## Suffix Automaton (SA) Enhancement
The Suffix Automaton (SA) is a model-free, GPU-based pattern-matching draft enhancer. It finds suffix matches in previously generated tokens and proposes draft tokens when the match is long enough. SA is very accurate when it matches (exact pattern repetition), while neural methods are better for novel content; combining them gives the best of both worlds.
SA can be combined with the following speculative decoding techniques:

- MTP (`MTPDecodingConfig`)
- EAGLE 3 (`Eagle3DecodingConfig`)
- PARD (`PARDDecodingConfig`)
To enable the SA combination, set `use_sa_spec=True` on the speculative config. The `sa_spec_threshold` parameter controls the minimum suffix match length required to override the neural draft (default: 4).
```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import Eagle3DecodingConfig

speculative_config = Eagle3DecodingConfig(
    max_draft_len=4,
    speculative_model="/path/to/eagle3_model",
    use_sa_spec=True,
    sa_spec_threshold=4)

llm = LLM("/path/to/target_model", speculative_config=speculative_config)
```
SA can also be used as a standalone speculative decoding technique via `SADecodingConfig`:

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import SADecodingConfig

speculative_config = SADecodingConfig(max_draft_len=4)
llm = LLM("/path/to/target_model", speculative_config=speculative_config)
```
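The matching rule can be illustrated with a brute-force sketch: find the longest earlier occurrence of the current suffix and, if the match reaches the threshold, propose the tokens that followed it. A real suffix automaton does this incrementally rather than with the O(n²) loop below; the function name and structure are purely illustrative.

```python
# Brute-force illustration of suffix matching for draft proposal.

def sa_propose(tokens, threshold, draft_len):
    """Propose the tokens that followed the longest earlier occurrence
    of the current suffix, if the match is at least `threshold` long."""
    last = len(tokens) - 1
    best_len, best_end = 0, -1
    for end in range(last):  # candidate earlier end position of a match
        n = 0
        while n <= end and tokens[end - n] == tokens[last - n]:
            n += 1
        if n > best_len:
            best_len, best_end = n, end
    if best_len >= threshold:
        return tokens[best_end + 1: best_end + 1 + draft_len]
    return []
```

For example, in a sequence ending with `1, 2, 3` that contained `1, 2, 3, 4, 9` earlier, a threshold of 3 is met and `[4, 9]` is proposed; this is why SA helps most on repetitive content such as code or structured output.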
## Usage with trtllm-bench and trtllm-serve
> **Note:** Non-breaking: `--config <file.yaml>` is the preferred flag for passing a YAML configuration file. Existing workflows using `--extra_llm_api_options <file.yaml>` continue to work; it is an equivalent alias.
Speculative decoding options must be specified via `--config config.yaml` for both `trtllm-bench` and `trtllm-serve`. All speculative decoding options can be specified in this YAML file. An additional `decoding_type` option is used to specify the type of speculation to use. The available options are:

- `MTP`
- `Eagle3`
- `NGram`
- `DraftTarget`
- `PARD`
- `SA`

Note: The PyTorch backend supports only `Eagle3`. `decoding_type: Eagle` is accepted as a backward-compatible alias for `Eagle3`, but EAGLE (v1/v2) draft checkpoints are incompatible.
The rest of the argument names/valid values are the same as in their corresponding configuration class described in the Quick Start section. For example, a YAML configuration could look like this:
```yaml
# Using a HuggingFace Hub model ID (auto-downloaded)
speculative_config:
  decoding_type: Eagle3
  max_draft_len: 4
  speculative_model: yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
```

```yaml
# Or using a local path
speculative_config:
  decoding_type: Eagle3
  max_draft_len: 4
  speculative_model: /path/to/draft/model
```

```yaml
# SA combination: enable the Suffix Automaton enhancement with any supported technique
speculative_config:
  decoding_type: Eagle3
  max_draft_len: 4
  speculative_model: /path/to/draft/model
  use_sa_spec: true
  sa_spec_threshold: 4
```
> **Note:** The field name `speculative_model_dir` can also be used as an alias for `speculative_config.speculative_model`. For example:

```yaml
speculative_config:
  decoding_type: Eagle3
  max_draft_len: 4
  speculative_model_dir: /path/to/draft/model
```