Speculative Decoding

Source: NVIDIA/TensorRT-LLM.

This example demonstrates three speculative-decoding algorithms supported by the TensorRT-LLM LLM API: MTP (Multi-Token Prediction), EAGLE-3, and NGram. Each helper builds a speculative_config, constructs an LLM, and generates a short completion for two sample prompts.

from typing import Optional

import click

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (EagleDecodingConfig, KvCacheConfig,
                                 MTPDecodingConfig, NGramDecodingConfig)

prompts = [
    "What is the capital of France?",
    "What is the future of AI?",
]


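# MTP (Multi-Token Prediction) speculation reuses the extra next-token
# prediction module(s) trained into models such as DeepSeek-R1 to draft
# tokens, so no separate draft model is needed. num_nextn_predict_layers
# sets how many MTP modules to use, and the relaxed_* options loosen
# draft-token acceptance during the model's reasoning ("thinking") phase,
# trading some output fidelity for a higher acceptance rate.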
def run_MTP(model: Optional[str] = None):
    spec_config = MTPDecodingConfig(num_nextn_predict_layers=1,
                                    use_relaxed_acceptance_for_thinking=True,
                                    relaxed_topk=10,
                                    relaxed_delta=0.01)

    llm = LLM(
        # You can change this to a local model path if you have the model downloaded
        model=model or "nvidia/DeepSeek-R1-FP4",
        speculative_config=spec_config,
    )

    for prompt in prompts:
        response = llm.generate(prompt, SamplingParams(max_tokens=10))
        print(response.outputs[0].text)


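# EAGLE-3 speculation uses a small, separately trained draft model to
# propose tokens. max_draft_len caps the number of draft tokens proposed
# per step, and eagle3_one_model=True runs the draft and target models
# together in a single engine rather than as two separate models.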
def run_Eagle3():
    spec_config = EagleDecodingConfig(
        max_draft_len=3,
        speculative_model_dir="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
        eagle3_one_model=True)

    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        speculative_config=spec_config,
        kv_cache_config=kv_cache_config,
    )

    for prompt in prompts:
        response = llm.generate(prompt, SamplingParams(max_tokens=10))
        print(response.outputs[0].text)


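# NGram speculation needs no draft model: it proposes draft tokens by
# matching recent n-grams from the prompt and generated text against a
# pattern pool. max_matching_ngram_size caps the n-gram length used for
# matching; is_keep_all, is_use_oldest, and is_public_pool control
# whether all patterns are retained, whether the oldest match is
# preferred, and whether the pool is shared across requests.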
def run_ngram():
    spec_config = NGramDecodingConfig(
        max_draft_len=3,
        max_matching_ngram_size=3,
        is_keep_all=True,
        is_use_oldest=True,
        is_public_pool=True,
    )

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        speculative_config=spec_config,
        # NGram speculation is not compatible with the overlap scheduler.
        disable_overlap_scheduler=True,
    )

    for prompt in prompts:
        response = llm.generate(prompt, SamplingParams(max_tokens=10))
        print(response.outputs[0].text)


@click.command()
@click.argument("algo", type=click.Choice(["MTP", "EAGLE3", "NGRAM"]))
@click.option("--model",
              type=str,
              default=None,
              help="Path or name of the model (only used by MTP).")
def main(algo: str, model: Optional[str] = None):
    algo = algo.upper()
    if algo == "MTP":
        run_MTP(model)
    elif algo == "EAGLE3":
        run_Eagle3()
    elif algo == "NGRAM":
        run_ngram()
    else:
        raise ValueError(f"Invalid algorithm: {algo}")


if __name__ == "__main__":
    main()
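
Each helper above submits one prompt at a time. generate() also accepts a list of prompts and returns one result per prompt, which lets the runtime batch them. Below is a minimal sketch that could be appended to the script above; run_Eagle3_batched is not part of the original example.

def run_Eagle3_batched():
    # Same EAGLE-3 setup as run_Eagle3(), but all prompts are submitted
    # in a single generate() call so the runtime can batch them.
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        speculative_config=EagleDecodingConfig(
            max_draft_len=3,
            speculative_model_dir="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
            eagle3_one_model=True),
    )
    # One output is returned per prompt, in order.
    for output in llm.generate(prompts, SamplingParams(max_tokens=10)):
        print(output.outputs[0].text)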