Speculative Decoding
Source: NVIDIA/TensorRT-LLM.
1from typing import Optional
2
3import click
4
5from tensorrt_llm import LLM, SamplingParams
6from tensorrt_llm.llmapi import (Eagle3DecodingConfig, KvCacheConfig,
7 MTPDecodingConfig, NGramDecodingConfig)
8
# Prompts shared by every example below; each one is sent to the model in turn.
prompts = ["What is the capital of France?", "What is the future of AI?"]
13
14
def run_MTP(model: Optional[str] = None):
    """Generate a completion for each prompt using an ``MTPDecodingConfig``.

    Args:
        model: Model name or path handed to ``LLM``. When it is ``None`` (or
            otherwise falsy) ``"nvidia/DeepSeek-R1-FP4"`` is used instead.

    The text of the first output of every response is printed to stdout.
    """
    # You can change this to a local model path if you have the model downloaded
    model_name = model if model else "nvidia/DeepSeek-R1-FP4"

    mtp_config = MTPDecodingConfig(
        use_relaxed_acceptance_for_thinking=True,
        relaxed_topk=10,
        relaxed_delta=0.01,
    )
    engine = LLM(model=model_name, speculative_config=mtp_config)

    for question in prompts:
        # A fresh SamplingParams per request, capped at 10 new tokens.
        result = engine.generate(question, SamplingParams(max_tokens=10))
        print(result.outputs[0].text)
29
30
def run_Eagle3():
    """Generate a completion for each prompt using an ``Eagle3DecodingConfig``.

    The target model is ``meta-llama/Llama-3.1-8B-Instruct`` and the
    speculative model is ``yuhuili/EAGLE3-LLaMA3.1-Instruct-8B``. The KV cache
    is configured with ``free_gpu_memory_fraction=0.8``.

    The text of the first output of every response is printed to stdout.
    """
    eagle_config = Eagle3DecodingConfig(
        max_draft_len=3,
        speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
        eagle3_one_model=True,
    )
    cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)

    llm_kwargs = {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "speculative_config": eagle_config,
        "kv_cache_config": cache_config,
    }
    engine = LLM(**llm_kwargs)

    for question in prompts:
        # A fresh SamplingParams per request, capped at 10 new tokens.
        result = engine.generate(question, SamplingParams(max_tokens=10))
        print(result.outputs[0].text)
48
49
def run_ngram():
    """Generate a completion for each prompt using an ``NGramDecodingConfig``.

    The model is ``meta-llama/Llama-3.1-8B-Instruct``; the overlap scheduler
    is disabled for this configuration.

    The text of the first output of every response is printed to stdout.
    """
    ngram_config = NGramDecodingConfig(
        max_draft_len=3,
        max_matching_ngram_size=3,
        is_keep_all=True,
        is_use_oldest=True,
        is_public_pool=True,
    )

    llm_kwargs = {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "speculative_config": ngram_config,
        # ngram doesn't work with overlap_scheduler
        "disable_overlap_scheduler": True,
    }
    engine = LLM(**llm_kwargs)

    for question in prompts:
        # A fresh SamplingParams per request, capped at 10 new tokens.
        result = engine.generate(question, SamplingParams(max_tokens=10))
        print(result.outputs[0].text)
69
70
@click.command()
@click.argument(
    "algo",
    # Offer only the algorithms this script can dispatch. "DRAFT_TARGET" was
    # previously listed here without a matching branch below, so selecting it
    # always failed with "Invalid algorithm". case_sensitive=False lets users
    # type e.g. "mtp"; the value is upper-cased again before dispatch.
    type=click.Choice(["MTP", "EAGLE3", "NGRAM"], case_sensitive=False),
)
@click.option("--model",
              type=str,
              default=None,
              help="Path to the model or model name.")
def main(algo: str, model: Optional[str] = None):
    """Run the speculative decoding example selected by ALGO.

    --model is only forwarded to the MTP example; the EAGLE3 and NGRAM
    examples use the models hard-coded in their functions.
    """
    # Normalise the name so dispatch does not depend on the casing click
    # hands back for a case-insensitive choice.
    algo = algo.upper()
    if algo == "MTP":
        run_MTP(model)
    elif algo == "EAGLE3":
        run_Eagle3()
    elif algo == "NGRAM":
        run_ngram()
    else:
        # Defensive: unreachable through the CLI because click validates the
        # choice, but kept for callers that bypass argument parsing.
        raise ValueError(f"Invalid algorithm: {algo}")
88
89
# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()