# Speculative-decoding examples (MTP, EAGLE-3, NGram).
# Source: NVIDIA/TensorRT-LLM.
1from typing import Optional
2
3import click
4
5from tensorrt_llm import LLM, SamplingParams
6from tensorrt_llm.llmapi import (EagleDecodingConfig, MTPDecodingConfig,
7 NGramDecodingConfig)
8
# Shared example prompts used by every decoding demo below.
prompts = [
    "What is the capital of France?",
    "What is the future of AI?",
]
13
14
def run_MTP(model: Optional[str] = None):
    """Demo speculative decoding with Multi-Token Prediction (MTP).

    Args:
        model: Local path or model name for the target model. Falls back to
            "nvidia/DeepSeek-R1-FP4" when not provided.
    """
    mtp_config = MTPDecodingConfig(
        num_nextn_predict_layers=1,
        use_relaxed_acceptance_for_thinking=True,
        relaxed_topk=10,
        relaxed_delta=0.01,
    )

    # A local model path works here too, if the model is already downloaded.
    target_model = model or "nvidia/DeepSeek-R1-FP4"
    llm = LLM(model=target_model, speculative_config=mtp_config)

    for prompt in prompts:
        result = llm.generate(prompt, SamplingParams(max_tokens=10))
        print(result.outputs[0].text)
30
31
def run_Eagle3(model: Optional[str] = None):
    """Demo speculative decoding with an EAGLE-3 draft model.

    Args:
        model: Optional target-model path or name; defaults to
            "meta-llama/Llama-3.1-8B-Instruct". Added (with a default) for
            consistency with run_MTP so the CLI's --model option can reach
            this demo; existing callers are unaffected.
    """
    spec_config = EagleDecodingConfig(
        max_draft_len=3,
        speculative_model_dir="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
        eagle3_one_model=True)

    llm = LLM(
        # You can change this to a local model path if you have the model downloaded
        model=model or "meta-llama/Llama-3.1-8B-Instruct",
        speculative_config=spec_config,
    )

    for prompt in prompts:
        response = llm.generate(prompt, SamplingParams(max_tokens=10))
        print(response.outputs[0].text)
46
47
def run_ngram(model: Optional[str] = None):
    """Demo speculative decoding with n-gram draft matching.

    Args:
        model: Optional target-model path or name; defaults to
            "meta-llama/Llama-3.1-8B-Instruct". Added (with a default) for
            consistency with run_MTP; existing callers are unaffected.
    """
    spec_config = NGramDecodingConfig(
        max_draft_len=3,
        max_matching_ngram_size=3,
        is_keep_all=True,
        is_use_oldest=True,
        is_public_pool=True,
    )

    llm = LLM(
        # You can change this to a local model path if you have the model downloaded
        model=model or "meta-llama/Llama-3.1-8B-Instruct",
        speculative_config=spec_config,
        # ngram doesn't work with overlap_scheduler
        disable_overlap_scheduler=True,
    )

    for prompt in prompts:
        response = llm.generate(prompt, SamplingParams(max_tokens=10))
        print(response.outputs[0].text)
67
68
@click.command()
@click.argument("algo",
                type=click.Choice(["MTP", "EAGLE3", "DRAFT_TARGET", "NGRAM"]))
@click.option("--model",
              type=str,
              default=None,
              help="Path to the model or model name.")
def main(algo: str, model: Optional[str] = None):
    """Dispatch to the requested speculative-decoding demo.

    Args:
        algo: Algorithm to run; one of MTP, EAGLE3, DRAFT_TARGET, NGRAM.
        model: Optional model path/name, currently forwarded only to MTP.

    Raises:
        NotImplementedError: for DRAFT_TARGET, which the CLI advertises but
            this example does not implement.
        ValueError: for any other unrecognized algorithm string.
    """
    # click.Choice is case-sensitive by default; upper() keeps the dispatch
    # robust if the choices are ever made case-insensitive.
    algo = algo.upper()
    if algo == "MTP":
        run_MTP(model)
    elif algo == "EAGLE3":
        run_Eagle3()
    elif algo == "NGRAM":
        run_ngram()
    elif algo == "DRAFT_TARGET":
        # Previously this valid CLI choice fell through to the misleading
        # "Invalid algorithm" error; fail with an accurate message instead.
        raise NotImplementedError(
            "DRAFT_TARGET is not implemented in this example")
    else:
        raise ValueError(f"Invalid algorithm: {algo}")
86
87
# Script entry point: run the click CLI when executed directly.
if __name__ == "__main__":
    main()