N-Gram Speculative Decoding in TensorRT‑LLM

N-Gram speculative decoding leverages the natural repetition in many LLM workloads. It splits previously seen text into configurable (key, value) n‑gram pairs and, during generation, swiftly proposes draft tokens by matching the current key against n-gram pools in memory.

In this blog, we describe the design choices behind TensorRT‑LLM’s N-Gram speculative decoding algorithm, share experimental results on its performance gains, and explain why N-Gram has a low barrier to adoption by deriving a simple heuristic to enable it.

Highlights

  • Fast & lightweight. The N‑Gram algorithm runs on the host with low overhead.

  • Real speed‑ups at low concurrency. On the first turn of the Magpie-Align/Magpie-Llama-3.1-Pro-MT-300K-Filtered dataset (link), N-Gram achieves an average accepted length of 1.37 or more, resulting in a 10–60% end-to-end (E2E) runtime speed-up.

  • Works even better with multi-turn conversations. With the cache built up during the first turn, the second turn achieves a higher accepted length of 1.66 and a 30–90% E2E runtime speed-up.

  • Excels on tasks with natural repetition, such as translation. On a translation dataset, the accepted length can exceed 4.0, and new requests benefit from the cache generated by earlier requests on similar tasks, reducing latency by up to 70%.

  • Heuristic “just works”. Set spec_decode_algo=AUTO to enable N‑Gram by default.

    • This policy adds less than 15% overhead to iteration latency yet nets double‑digit end‑to‑end speed‑ups.



Background & Motivation

Speculative decoding drafts several tokens, verifies them with the target model, and keeps the accepted prefix at each iteration of the generation loop. An N‑Gram proposer generates drafts without an extra LLM or additional model heads, making it a low-cost way to improve serving latency. The average accepted length (AL) is ~1.3 in generic chat (MT‑Bench, and the first turn of Magpie) and can exceed 4.0 on highly repetitive data such as a translation task.
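To make "keeps the accepted prefix" concrete, here is a minimal sketch of greedy (temperature 0) verification for a single request. It assumes the target model's greedy prediction at each of the v + 1 verified positions is already available as a list; `accept_prefix` is a hypothetical helper, not TensorRT-LLM's verification code, and sampling-based verification uses a different acceptance rule. If w drafts match, the sketch keeps w + 1 tokens, so the accepted length is never below 1.

```python
def accept_prefix(drafts, target_next):
    """Greedy verification for one request (hypothetical helper, not a TensorRT-LLM API).

    drafts:      the proposed draft tokens (v of them).
    target_next: the target model's greedy token at each of the v + 1 verified
                 positions, i.e. target_next[i] follows the prefix plus drafts[:i].
    Returns the tokens kept this iteration; len(result) is the accepted length (AL).
    """
    if len(target_next) != len(drafts) + 1:
        raise ValueError("need one target prediction per draft plus one bonus position")
    kept = []
    for draft, target in zip(drafts, target_next):
        if draft != target:          # the first mismatch ends the accepted prefix
            break
        kept.append(draft)
    # The target model always contributes one token: the correction at the first
    # mismatch, or the bonus token after the last draft if every draft matched.
    kept.append(target_next[len(kept)])
    return kept


print(accept_prefix([5, 6, 7], [5, 6, 9, 2]))  # [5, 6, 9] -> AL = 3
print(accept_prefix([5, 6, 7], [5, 6, 7, 8]))  # [5, 6, 7, 8] -> AL = 4
print(accept_prefix([1, 2, 3], [4, 5, 6, 7]))  # [4] -> AL = 1
```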


Algorithm & Complexity

NGramDecodingConfig in TensorRT-LLM:

from tensorrt_llm.llmapi import NGramDecodingConfig

k, v = 3, 5  # example values: max key length and max draft length

spec_config = NGramDecodingConfig(
    max_draft_len=v,            # max length of draft tokens (the value)
    max_matching_ngram_size=k,  # max length of keys
    is_keep_all=True,           # keep all candidate pattern-match pairs; if False, keep only one match per pattern
    is_use_oldest=True,         # on a hit, propose the oldest match; if False, propose the newest one
    is_public_pool=True,        # share one pool across all requests; if False, each request has a private pool
)

  • Processing a new request ‑ scan the input sequence once to create N-Gram key-value pairs for the new sequence.

    With max_matching_ngram_size = 3, max_draft_len = 5, and input_sequence_len = 8, Figure 1 shows the 18 new key-value pairs added to the cache pool.

    The number of cache pairs grows proportionally to the product of the maximum key length and the input sequence length.

Figure 1. Request initial scan
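The initial scan can be sketched in a few lines of Python, assuming that every key of length 1 to k that is followed by at least one token gets a value of up to v following tokens. This assumption reproduces the 18 pairs of Figure 1, but `initial_scan` is a hypothetical helper rather than TensorRT-LLM's code. Summing over key lengths gives (L − 1) + (L − 2) + … + (L − k) = k·L − k(k+1)/2 pairs for an L-token input, which is where the growth with k × L comes from: 3·8 − 6 = 18.

```python
def initial_scan(seq, k, v):
    """Return the (key, value) pairs created when a new request's input is scanned.

    Keys are every n-gram of length 1..k that is followed by at least one token;
    each value holds up to v of the tokens that follow the key.
    """
    pairs = []
    for n in range(1, k + 1):                  # key lengths 1..k
        for start in range(len(seq) - n):      # leave >= 1 token for the value
            key = tuple(seq[start:start + n])
            value = tuple(seq[start + n:start + n + v])
            pairs.append((key, value))
    return pairs


seq = list(range(8))                # an 8-token input with distinct token ids
pairs = initial_scan(seq, k=3, v=5)
print(len(pairs))                   # 18
print(pairs[0])                     # ((0,), (1, 2, 3, 4, 5))
print(pairs[-1])                    # ((4, 5, 6), (7,))
```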

  • Per‑token update ‑ slide the window and update the cache pool.

    When a new token is generated, the cache pool is updated as shown in Figure 2. For existing key-value pairs whose value is shorter than max_draft_len, the new token is appended to the value. The new token also becomes the value of new keys, which are marked as new pairs in the figure.

    The number of cache updates and additions is approximately max_draft_len × max_matching_ngram_size, which is constant for fixed parameters.

Figure 2. Per-token update
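The per-token update can be illustrated with the sketch below. It assumes entries are indexed by position (start, key length) so that each value can be extended in place; a real pool would be keyed by the token tuple, and `scan` and `update_after_new_token` are hypothetical names. Only keys ending in the last v positions can still have a value shorter than v, so the update writes about k × v entries, and the result matches a full rescan of the longer sequence.

```python
def scan(seq, k, v):
    """Full rescan: (start, key_len) -> (key, value) for every key followed by >= 1 token."""
    pool = {}
    for n in range(1, k + 1):
        for start in range(len(seq) - n):
            pool[(start, n)] = (tuple(seq[start:start + n]),
                                tuple(seq[start + n:start + n + v]))
    return pool


def update_after_new_token(seq, pool, k, v):
    """Incrementally update pool after seq[-1] was appended; return entries written."""
    last = len(seq) - 1                         # position of the new token
    written = 0
    for n in range(1, k + 1):
        # Keys ending at last-1 are new (their value is just the new token);
        # keys ending at last-v .. last-2 have short values that gain the new token.
        for end in range(max(n - 1, last - v), last):
            start = end - n + 1
            pool[(start, n)] = (tuple(seq[start:end + 1]),
                                tuple(seq[end + 1:end + 1 + v]))
            written += 1
    return written


seq = [7, 3, 9, 3, 7, 3, 9, 1]
pool = scan(seq, k=3, v=5)                          # 18 entries after the initial scan
seq.append(4)                                       # the model generates one new token
print(update_after_new_token(seq, pool, k=3, v=5))  # 15 == k * v
print(pool == scan(seq, k=3, v=5))                  # True: same as rescanning everything
```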

  • Lookup ‑ use the last k tokens as the key and propose its stored value as the draft tokens.

    If is_public_pool=True, a global pool is shared by all requests. If is_public_pool=False, each request has its own cache pool.

    Lookup takes amortized constant time, but extra latency can appear once the dictionary outgrows the CPU’s fastest cache.

  • Verification ‑ verify the proposed draft tokens with the target model.

    Run the target model with verification_batch = original_batch × (v + 1). Verification always produces at least one new token, even if no draft token is correct; in that case, the accepted length (AL) is 1. More generally, if w of the v draft tokens are accepted, the AL is w + 1.

    Because the verification batch is larger than the original batch, iteration latency grows, and the overhead increases further as max_draft_len (v) grows. Speculative decoding therefore tends to work best with small batch sizes and low concurrency.
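The lookup step described above can be sketched with a pool keyed by token tuples, where each key keeps every distinct value in insertion order (in the spirit of is_keep_all=True) and `use_oldest` mirrors is_use_oldest. The sketch tries the longest key first and falls back to shorter ones; that fallback order is an assumption, and `build_pool` and `propose_drafts` are hypothetical helpers, not TensorRT-LLM functions. Because the pool is filled by one request and queried by another, it also shows what a shared pool (is_public_pool=True) makes possible.

```python
from collections import defaultdict


def build_pool(seq, k, v):
    """Map each n-gram key (length 1..k) to all distinct values, oldest first."""
    pool = defaultdict(list)
    for n in range(1, k + 1):
        for start in range(len(seq) - n):
            key = tuple(seq[start:start + n])
            value = tuple(seq[start + n:start + n + v])
            if value not in pool[key]:          # skip duplicate values
                pool[key].append(value)
    return pool


def propose_drafts(pool, context, k, use_oldest=True):
    """Return draft tokens for the end of `context`, or [] when no key matches."""
    for n in range(min(k, len(context)), 0, -1):   # longest key first
        candidates = pool.get(tuple(context[-n:]))
        if candidates:
            return list(candidates[0] if use_oldest else candidates[-1])
    return []  # no match: this iteration decodes without drafts


# A pool filled by an earlier request, then queried by a new request.
shared = build_pool("the cat sat on the mat".split(), k=3, v=3)
print(propose_drafts(shared, "I saw the".split(), k=3))                    # ['cat', 'sat', 'on']
print(propose_drafts(shared, "I saw the".split(), k=3, use_oldest=False))  # ['mat']
print(propose_drafts(shared, "sat on the".split(), k=3))                   # ['mat']
print(propose_drafts(shared, "a dog".split(), k=3))                        # []
```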


Performance Study

Experimental Setup

  • Hardware: 8 × B200 GPUs (Blackwell)

  • Model: Llama‑4‑Scout‑17B‑16E, FP8 weights

  • Tensor Parallel: 8


Case 1 with Conversation Dataset

In this experiment, we used the Magpie-Align/Magpie-Llama-3.1-Pro-MT-300K-Filtered dataset (link), a conversational dataset with two turns. The user question in the second turn is related to the previous question and answer.

The first-turn data alone represents a general conversation with no prior context. Its repetition comes from the conversational structure and the correlation between questions and answers.

By the second turn, the global cache already contains the previous conversation, so additional repetition comes from the correlation between the second answer and the earlier exchange.

Speed-up for the First Turn

For batch sizes of 1, 4, and 32, we set the model's max_batch_size accordingly, run 20 × batch_size requests, and compare the E2E runtime with and without N-Gram speculative decoding.

Figure 3. First Turn Speed-up

N-Gram provides speed-ups for batch sizes up to 32 and works best at batch size 1. The main overhead at larger batch sizes is the verification cost. For batch sizes 1 and 4, k = 3, v = 5 is the best N-Gram configuration. For batch size 32, k = 5, v = 3 is the best configuration, because the shorter draft length keeps the verification batch smaller and reduces the overhead.
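The verification-cost argument follows directly from verification_batch = original_batch × (v + 1). The back-of-the-envelope script below (plain arithmetic, not a TensorRT-LLM API; `verification_batch` is a hypothetical helper) assumes every request carries a full set of v drafts and counts the token positions the target model processes per iteration for the configurations above.

```python
def verification_batch(original_batch, v):
    """Token positions verified per iteration: each request carries v drafts plus one."""
    return original_batch * (v + 1)


for batch, v in [(1, 5), (4, 5), (32, 5), (32, 3)]:
    print(f"batch={batch:2d} v={v}: {verification_batch(batch, v):3d} tokens per iteration")
# batch= 1 v=5:   6 tokens per iteration
# batch= 4 v=5:  24 tokens per iteration
# batch=32 v=5: 192 tokens per iteration
# batch=32 v=3: 128 tokens per iteration
```

At batch size 1, v = 5 adds only a handful of positions, while at batch size 32 lowering v from 5 to 3 removes 64 positions per iteration, which is consistent with v = 3 performing best there.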

Effect of Multi-turn Conversation

The table below shows the average accepted length (AL) over 3000 sampled conversations for different N-Gram configurations.

| k | v | AL Turn 1 | AL Turn 2 |
|---|---|-----------|-----------|
| 3 | 5 | 1.37      | 1.66      |
| 5 | 5 | 1.40      | 1.77      |
| 5 | 3 | 1.37      | 1.66      |

Figure 4 shows the distribution of accepted length (AL) with k = 3, v = 5. AL = 1 means none of the draft tokens were accepted, and AL = 6 means all five drafts were accepted.

Figure 4. Accepted draft token length distribution

In Figure 5, we plot the average accepted length (AL) across requests at each iteration. Each dot's transparency reflects the number of requests scheduled in that iteration, normalized by the maximum among all iterations; the fewer requests scheduled, the more transparent the dot.

Figure 5. AL over iteration

Figure 6 shows the speed-up from N-Gram speculative decoding on the second turn of conversation only. With k = 3, v = 5, N-Gram delivers a 96.13% speed-up at batch size 1 and a 63.99% speed-up at batch size 4. At batch size 32 with k = 5, v = 3, the speed-up is 33.06%.

Figure 6. Second Turn Speed-up

We conclude that N-Gram speculative decoding improves the runtime of conversational workloads, especially when the conversation has multiple turns.


Case 2 with Translation Dataset

The conversational dataset showed that N-Gram takes advantage of structural repetition. In the second case study, we explore N-Gram's full potential with a translation dataset that exhibits natural repetition in both context and language. The dataset has a single turn, with English prompts asking for translations into other languages.

The table below shows the accepted length (AL) measured over 4000 requests. AL grows with increasing max_draft_len (v), and the upward trend continues up to v = 23, the largest value we measured.

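| Config | k | v  | AL    |
|--------|---|----|-------|
| 1      | 3 | 7  | 3.44  |
| 2      | 5 | 7  | 3.62  |
| 3      | 3 | 9  | 3.708 |
| 4      | 5 | 9  | 3.925 |
| 5      | 3 | 11 | 3.878 |
| 6      | 5 | 11 | 4.092 |
| 7      | 3 | 13 | 4.079 |
| 8      | 5 | 13 | 4.214 |
| 9      | 3 | 15 | 4.198 |
| 10     | 5 | 15 | 4.36  |
| 11     | 5 | 17 | 4.43  |
| 12     | 5 | 19 | 4.55  |
| 13     | 5 | 21 | 4.59  |
| 14     | 5 | 23 | 4.73  |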

Figure 7 shows accepted-length statistics for N-Gram configured with k = 5, v = 7.

In the pie chart on the left, of the seven draft tokens proposed by N-Gram, roughly one-third of the cases accept none (AL = 1), while another one-third accept all of them (AL = 8). Compared with the corresponding chart for Case 1 (Figure 4), the share of fully accepted drafts is much higher. The graph on the right plots the accepted length at each iteration for five randomly selected requests.

Figure 7. Accepted Tokens from Drafts

Auto‑Enablement with Heuristic

A big part of N-Gram’s appeal is its simplicity of deployment. It needs neither a carefully selected draft model nor additional training of model heads to benefit from speculative decoding, so serving software can enable it automatically to take advantage of its performance gains.

Based on our experiments, we propose a simple batch-aware policy that keeps iteration overhead under control and yields a ~15% end-to-end speed-up at low to mid concurrency. Give it a try by setting spec_decode_algo=AUTO!
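As a rough first-order check of why a modest accepted length can still pay off, assume each iteration yields AL tokens on average and costs (1 + overhead) times a normal decoding iteration. The end-to-end speed-up is then about AL / (1 + overhead). This simplified model ignores prefill, scheduling, and host-side effects, and it is not the AUTO policy itself; `estimated_speedup` is a hypothetical helper.

```python
def estimated_speedup(accepted_length, iteration_overhead):
    """First-order model: tokens per iteration divided by relative iteration cost."""
    return accepted_length / (1.0 + iteration_overhead)


# AL of 1.37 (first-turn Magpie, k=3, v=5) with a 15% iteration-latency overhead.
print(round(estimated_speedup(1.37, 0.15), 3))  # 1.191
# Drafting only pays off when AL exceeds 1 + overhead.
print(estimated_speedup(1.10, 0.15) < 1.0)      # True
```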