Speculative Decoding#

There are two flavors of speculative decoding currently supported in the PyTorch backend:

  • The “one model” implementation – a variant which inserts a drafter directly into the model code as a submodule.

  • The “two model” implementation – a variant which produces draft tokens in the PyExecutor. The draft tokens are attached to requests before they are passed into the target model’s ModelEngine.

In general, the one model implementation is faster. It achieves better performance in extreme low-latency scenarios because it can launch the entire drafting loop as a single CUDA graph. The trade-off is flexibility: the one model implementation does not support dynamic draft lengths, and only a subset of models/speculative decoding algorithms support it. The table below enumerates all of the supported algorithm/model combinations.

Speculative Decoding Algorithm    Model
EAGLE 3                           Llama 4 Maverick
MTP                               Deepseek V3/R1
EAGLE-style MTP                   Deepseek V3/R1

The two model implementation supports the following speculative decoding algorithms:

Speculative Decoding Algorithm    Model
EAGLE 3                           Llama 4 Maverick, Llama 3.1 8B, Llama 3.3 70B
Draft/target                      All models
NGram                             All models
User-provided                     All models

Quick Start#

For all speculation algorithms, when speculation is enabled, a single sequence of draft tokens of length max_draft_len is created for every request. There is currently no way to disable speculation dynamically, so speedups are only observable at low batch sizes.

Draft/Target#

Draft/target is the simplest form of speculative decoding. In this approach, an arbitrary draft model is used to produce draft tokens. It is important to make sure that the draft and target models were trained with the same tokenizer; otherwise the acceptance rate will be extremely low and performance will regress.

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import DraftTargetDecodingConfig

speculative_config = DraftTargetDecodingConfig(
    max_draft_len=3, speculative_model="/path/to/draft_model")

# Draft/target uses the two model implementation, so the overlap scheduler must be disabled.
llm = LLM("/path/to/target_model", speculative_config=speculative_config, disable_overlap_scheduler=True)

EAGLE 3#

The EAGLE 3 algorithm is described in the paper EAGLE-3: Scaling up Inference Acceleration of Large Language Models via Training-Time Test. TRT-LLM supports a modified version of the algorithm presented in the paper: tree structures for draft sequences are not supported. Instead, each request uses a single sequence of draft tokens with length max_draft_len.

EAGLE 3 requires a draft model checkpoint that was trained for the target model; it is passed via speculative_model:

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig

# Set to True to use the faster one model implementation for Llama 4.
eagle3_one_model = False

speculative_config = EagleDecodingConfig(
    max_draft_len=3, speculative_model="/path/to/draft_model", eagle3_one_model=eagle3_one_model)

# The overlap scheduler only needs to be disabled when eagle3_one_model is False.
llm = LLM("/path/to/target_model", speculative_config=speculative_config, disable_overlap_scheduler=True)

NGram#

The NGram method is an implementation of the Prompt Lookup Decoding algorithm.

When the NGram algorithm is used, TRT-LLM maintains a map from token prefixes to candidate draft sequences. For example, the 3-gram ["The", " future", " is"] could map to the draft sequence [" bright", " because"]. The prefixes are token sequences extracted from the prompt and from the tokens generated by the target model; a minimal sketch of the matching idea follows the option list below. The NGram pool and matching procedure can be tuned with the following options:

  • max_draft_len: Maximum draft candidate length.

  • max_matching_ngram_size: Maximum prompt suffix length to match with keys in the pool.

  • is_public_pool: If true, a single ngram pool is shared for all requests. Otherwise, each request has its own ngram pool.

  • is_keep_all: If true, draft candidates will be retained in the pool forever. Otherwise, only the largest draft candidate is retained.

  • is_use_oldest: If true, the oldest draft candidate is always proposed for a given match. Otherwise, the newest draft candidate is used. Only applicable if is_keep_all == True because is_keep_all == False means we’ll only ever have a single value for each key.
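To make the matching procedure concrete, here is a minimal, illustrative sketch of prompt-lookup style drafting. The function names and data structures are hypothetical; this is not the TRT-LLM implementation.

from typing import Dict, List, Tuple


def build_ngram_pool(tokens: List[int], ngram_size: int,
                     draft_len: int) -> Dict[Tuple[int, ...], List[int]]:
    """Map each n-gram in the token sequence to the tokens that followed it."""
    pool: Dict[Tuple[int, ...], List[int]] = {}
    for i in range(len(tokens) - ngram_size):
        key = tuple(tokens[i:i + ngram_size])
        # Later occurrences overwrite earlier ones, so the newest continuation wins.
        pool[key] = tokens[i + ngram_size:i + ngram_size + draft_len]
    return pool


def propose_draft(tokens: List[int], pool: Dict[Tuple[int, ...], List[int]],
                  ngram_size: int) -> List[int]:
    """Return a draft sequence if the current suffix matches a key in the pool."""
    suffix = tuple(tokens[-ngram_size:])
    return pool.get(suffix, [])

The NGram configuration itself looks like this: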

from tensorrt_llm.llmapi import NGramDecodingConfig

speculative_config = NGramDecodingConfig(
    max_draft_len=3, max_matching_ngram_size=4, is_public_pool=True)

llm = LLM("/path/to/target_model", speculative_config=speculative_config, disable_overlap_scheduler=True)

MTP#

MTP is currently only supported by Deepseek. MTP can be tuned with the following configuration options:

  • max_draft_len: Maximum draft candidate length.

  • num_nextn_predict_layers: Number of MTP modules to use. Currently must match max_draft_len.

  • use_relaxed_acceptance_for_thinking: If true, use relaxed acceptance for reasoning models during the thinking phase. In this mode, a draft token may be accepted if it appears in a candidate set constructed with relaxed_topk and relaxed_delta.

  • relaxed_topk: The top K tokens are sampled from the target model’s logits to create the initial candidate set for relaxed decoding.

  • relaxed_delta: Used to further filter the top K candidate set for relaxed decoding. We remove tokens t for which log(P(top 1 token)) - log(P(t)) > relaxed_delta.
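The candidate set construction can be sketched as follows. This is a minimal, hypothetical helper written against raw logits; it is not the TRT-LLM implementation.

import torch


def relaxed_candidates(logits: torch.Tensor, relaxed_topk: int,
                       relaxed_delta: float) -> torch.Tensor:
    """Token IDs that a draft token may match during the thinking phase."""
    log_probs = torch.log_softmax(logits, dim=-1)
    top_logprobs, top_ids = torch.topk(log_probs, relaxed_topk)
    # Keep tokens whose log-probability is within relaxed_delta of the top-1 token.
    keep = (top_logprobs[0] - top_logprobs) <= relaxed_delta
    return top_ids[keep]

The MTP configuration itself looks like this: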

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import MTPDecodingConfig

speculative_config = MTPDecodingConfig(
    max_draft_len=3, num_nextn_predict_layers=3)

llm = LLM("/path/to/deepseek_model", speculative_config=speculative_config)

User-provided drafting#

A completely user-defined drafting method can be supplied with a UserProvidedDecodingConfig that includes

  • max_draft_len: Maximum draft candidate length.

  • drafter: An object of type Drafter that implements the prepare_draft_tokens method (see the Drafter entry in the Developer Guide below); a minimal sketch follows this list.

  • resource_manager: An optional ResourceManager object (see the ResourceManager entry in the Developer Guide below).
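The following is a minimal sketch of a custom drafter, assuming the Drafter interface described in the Developer Guide below. The import path and the ScheduledRequests attribute name are assumptions, and the drafting heuristic is a placeholder.

from tensorrt_llm._torch.speculative.drafter import Drafter  # assumed import path


class MyDrafter(Drafter):

    def prepare_draft_tokens(self, scheduled_requests) -> None:
        # Attach draft tokens to every request that speculation should run for
        # (assumes ScheduledRequests exposes a generation_requests list).
        for request in scheduled_requests.generation_requests:
            request.py_draft_tokens = [0, 0, 0]  # placeholder draft tokens

The drafter is then wired in through the configuration: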

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import UserProvidedDecodingConfig

speculative_config = UserProvidedDecodingConfig(
    max_draft_len=3, drafter=MyDrafter())

llm = LLM("/path/to/target_model", speculative_config=speculative_config)

Usage with trtllm-bench and trtllm-serve#

Speculative decoding options must be specified via --extra_llm_api_options config.yaml for both trtllm-bench and trtllm-serve. All speculative decoding options can be specified in this YAML file. An additional decoding_type option selects the type of speculation to use. The available values are:

  • MTP

  • Eagle (for EAGLE 3)

  • NGram

  • DraftTarget

The rest of the argument names/valid values are the same as in their corresponding configuration class described in the Quick Start section. For example, a YAML configuration could look like this:

disable_overlap_scheduler: true
speculative_config:
  decoding_type: Eagle
  max_draft_len: 4
  speculative_model: /path/to/draft/model

Developer Guide#

This section describes the components of a speculative decoding algorithm. All of the interfaces are defined in _torch/speculative/interface.py.

  1. SpeculativeDecodingMode: a simple IntEnum with one value for each supported algorithm. It has a few nontrivial methods, however:

  • needs_kv_cache_rewind: See “KV Cache Rewind” below. In general, this is true for all two model speculative decoding algorithms.

  • extend_ctx: If true, requests with py_draft_tokens attached to them are dispatched to the prefill version of the attention kernels. This usually needs to be true. The exception is the TensorRT LLM attention backend on Blackwell, which can use the generation kernels for better performance in this case. This optimized kernel has one limitation: all draft lengths must be the same (or padding must be used).

These may be refactored in the future to reduce the difficulty of adding a new speculative decoding algorithm. extend_ctx in particular is problematic. Ideally, we would move all of the kernel dispatching logic to a lower level of abstraction.

  2. SpecMetadata: Defines all metadata that should be passed to the model during the forward pass to facilitate speculative decoding. Each speculative decoding algorithm defines a subclass of SpecMetadata. Similar to AttentionMetadata, each CUDAGraphRunner owns its own SpecMetadata, and CUDA-graph compatible SpecMetadata objects may be created by invoking create_cuda_graph_metadata(batch_size). SpecMetadata has many fields. Many of them are exclusively used by the one model implementation. For the two model implementation, the main purpose of SpecMetadata is to facilitate the capture of hidden states. In EAGLE 3, we need to capture hidden states from the target model to use as draft model inputs. The SpecMetadata stores a list of layers to capture and the model calls maybe_capture_hidden_states(layer_id, hidden_states, residual) during its forward pass. If the layer ID is in the list of layers to capture, the hidden states are saved. For CUDA graph compatibility, these may be saved in pre-allocated buffers.
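As an illustration, a model's forward pass might hand hidden states to SpecMetadata roughly like this. This is hypothetical model code; only maybe_capture_hidden_states comes from the interface described above.

def run_decoder_layers(layers, hidden_states, residual, attn_metadata, spec_metadata):
    for layer_id, layer in enumerate(layers):
        hidden_states, residual = layer(hidden_states, residual, attn_metadata)
        if spec_metadata is not None:
            # A no-op unless layer_id is in the list of layers to capture
            # (e.g. the layers whose hidden states EAGLE 3 needs as draft inputs).
            spec_metadata.maybe_capture_hidden_states(layer_id, hidden_states, residual)
    return hidden_states, residual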

SpecMetadata is derived from a SpecConfig object in _torch/speculative/utils.py. There are a few other optional components created in this file too:

  1. ResourceManager: Create a custom resource manager to prepare and free resources before and after target forward passes; see the section on ResourceManager in arch.md. This is used by the n-gram method to manage its pool. The one model implementation also uses ResourceManagers to manage hidden states.

  2. Sampler: Each speculative decoding algorithm can optionally create its own sampler. This is mostly used by the one model implementation; the default TorchSampler is used as a fallback if no custom sampler is provided. The EAGLE 3 two model implementation also has a simple custom decoder to handle differences between the draft and target model vocabulary sizes.

  3. Worker: This is exclusive to the one-model implementation. The Worker is the object that gets injected into the target model as a submodule.

  4. Drafter: All of the logic required to actually produce draft tokens should be implemented in a Drafter subclass. There is a single abstract method, prepare_draft_tokens. It takes a set of requests (a ScheduledRequests object) and returns nothing. The PyExecutor expects draft tokens to be attached to the py_draft_tokens field of each request that speculation should be performed for.

Two Model Speculative Decoding Architecture#

Two model speculative decoding implementations do not support the overlap scheduler; it is disabled automatically.

In this approach, there are two new steps in the PyExecutor’s _executor_loop:

  • _prepare_draft_requests

  • _prepare_draft_tokens

_prepare_draft_requests#

This stage occurs for all speculative decoding algorithms before scheduling. Its purpose is to make the KV cache manager and scheduler aware that speculative decoding will occur. Draft tokens take up extra KV cache pages and count towards the executor’s max_num_tokens limit, so the scheduler needs to know that drafting will occur before scheduling happens.

To achieve this, we simply attach the maximum number of draft tokens to each request. The scheduler and KV cache manager will automatically account for tokens attached to the py_draft_tokens attribute.

# Placeholder draft tokens make the scheduler and KV cache manager reserve space for the real ones.
for req in self.active_requests:
    req.py_draft_tokens = [0] * max_draft_len

_prepare_draft_tokens#

This stage occurs after scheduling and KV cache allocation. The purpose of this stage is to attach draft tokens to the py_draft_tokens attribute. This occurs by calling self.drafter.prepare_draft_tokens; each speculative decoding algorithm should have a concrete instance of the Drafter class associated with it that defines the drafting logic.

In addition to producing all “real” draft tokens, _prepare_draft_tokens currently must also pad all py_draft_tokens to the maximum draft length. This is a CUDA graph limitation - the target model captures its CUDA graphs using the maximum number of draft tokens on each request.
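For illustration, the padding step might look roughly like this (hypothetical code with an assumed pad_token_id; not the PyExecutor implementation):

for req in scheduled_requests.generation_requests:
    num_missing = max_draft_len - len(req.py_draft_tokens)
    # Pad to the maximum draft length so the target model's CUDA graph shapes match.
    req.py_draft_tokens.extend([pad_token_id] * num_missing)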

Verification and Sampling#

Once the draft tokens are obtained, the target model runs a forward pass through the usual flow. Everything is the same, except that the logits for all the draft tokens are returned and passed to the sampler.

Currently, only greedy sampling is supported for speculative decoding. A draft token is accepted if it matches the previously decoded token exactly. For example, suppose there is a generation request [t, d1, d2, d3], where d1, d2, and d3 are draft tokens. Suppose the token sampled after t is d1 (determined with the argmax of the logits); d1 is then accepted. If the token sampled after d1 is d2, then d2 can also be accepted, and so on until a draft token is rejected.
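The acceptance loop can be sketched as follows (a minimal, hypothetical helper; not the TRT-LLM sampler):

def count_accepted_tokens(draft_tokens, target_argmax_tokens):
    """target_argmax_tokens[i] is the target model's greedy token at draft position i."""
    num_accepted = 0
    for draft, target in zip(draft_tokens, target_argmax_tokens):
        if draft != target:
            break
        num_accepted += 1
    return num_accepted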

KV Cache Rewind#

KV cache space allocated to rejected tokens is freed before the next iteration. This is achieved by setting the request.py_rewind_len attribute to num_draft_tokens_allocated - num_accepted_tokens. The pages are freed as part of the resource_manager.free_resources routine.

The purpose of KV cache rewind is to avoid complicated page reuse logic in the KV cache manager’s prepare_resources function. In practice, this is very cheap since the blocks are just marked as available; no memory is actually freed.