Speculative Decoding
Source: the NVIDIA/TensorRT-LLM examples.
1from typing import Optional
2
3import click
4
5from tensorrt_llm import LLM, SamplingParams
6from tensorrt_llm.llmapi import (EagleDecodingConfig, KvCacheConfig,
7 MTPDecodingConfig, NGramDecodingConfig)
8
# Shared example prompts used by every runner below.
prompts = [
    "What is the capital of France?",
    "What is the future of AI?",
]
13
14
def run_MTP(model: Optional[str] = None):
    """Run speculative decoding with MTP (multi-token prediction) layers.

    Args:
        model: Model path or Hugging Face model name. Falls back to
            "nvidia/DeepSeek-R1-FP4" when not given.
    """
    mtp_config = MTPDecodingConfig(
        num_nextn_predict_layers=1,
        use_relaxed_acceptance_for_thinking=True,
        relaxed_topk=10,
        relaxed_delta=0.01,
    )

    # You can change this to a local model path if you have the model downloaded.
    llm = LLM(model=model or "nvidia/DeepSeek-R1-FP4",
              speculative_config=mtp_config)

    # Generate a short completion for each example prompt and print it.
    for query in prompts:
        result = llm.generate(query, SamplingParams(max_tokens=10))
        print(result.outputs[0].text)
30
31
def run_Eagle3(model: Optional[str] = None):
    """Run speculative decoding with an EAGLE3 draft model.

    Args:
        model: Target model path or Hugging Face model name. Falls back to
            "meta-llama/Llama-3.1-8B-Instruct" when not given, so existing
            zero-argument callers keep working.
    """
    spec_config = EagleDecodingConfig(
        max_draft_len=3,
        speculative_model_dir="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
        eagle3_one_model=True)

    # Leave headroom for the draft model by capping the KV-cache share of GPU memory.
    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)

    llm = LLM(
        model=model or "meta-llama/Llama-3.1-8B-Instruct",
        speculative_config=spec_config,
        kv_cache_config=kv_cache_config,
    )

    # Generate a short completion for each example prompt and print it.
    for prompt in prompts:
        response = llm.generate(prompt, SamplingParams(max_tokens=10))
        print(response.outputs[0].text)
49
50
def run_ngram(model: Optional[str] = None):
    """Run speculative decoding with n-gram draft proposals.

    Args:
        model: Model path or Hugging Face model name. Falls back to
            "meta-llama/Llama-3.1-8B-Instruct" when not given, so existing
            zero-argument callers keep working.
    """
    spec_config = NGramDecodingConfig(
        max_draft_len=3,
        max_matching_ngram_size=3,
        is_keep_all=True,
        is_use_oldest=True,
        is_public_pool=True,
    )

    llm = LLM(
        model=model or "meta-llama/Llama-3.1-8B-Instruct",
        speculative_config=spec_config,
        # ngram doesn't work with overlap_scheduler
        disable_overlap_scheduler=True,
    )

    # Generate a short completion for each example prompt and print it.
    for prompt in prompts:
        response = llm.generate(prompt, SamplingParams(max_tokens=10))
        print(response.outputs[0].text)
70
71
@click.command()
@click.argument("algo",
                type=click.Choice(["MTP", "EAGLE3", "DRAFT_TARGET", "NGRAM"]))
@click.option("--model",
              type=str,
              default=None,
              help="Path to the model or model name.")
def main(algo: str, model: Optional[str] = None):
    """Dispatch to the runner for the chosen speculative-decoding algorithm."""
    algo = algo.upper()
    if algo == "MTP":
        run_MTP(model)
    elif algo == "EAGLE3":
        run_Eagle3()
    elif algo == "NGRAM":
        run_ngram()
    else:
        # DRAFT_TARGET is accepted by the CLI choices above but has no runner
        # in this example, so report "not implemented" rather than "invalid".
        raise ValueError(f"Algorithm {algo} is not implemented in this example")
89
90
if __name__ == "__main__":
    # Script entry point: delegate to the click command.
    main()