Speculative Decoding

Source: NVIDIA/TensorRT-LLM.

 1from typing import Optional
 2
 3import click
 4
 5from tensorrt_llm import LLM, SamplingParams
 6from tensorrt_llm.llmapi import (Eagle3DecodingConfig, KvCacheConfig,
 7                                 MTPDecodingConfig, NGramDecodingConfig)
 8
 9prompts = [
10    "What is the capital of France?",
11    "What is the future of AI?",
12]
13
14
def run_MTP(model: Optional[str] = None):
    """Run the Multi-Token Prediction (MTP) speculative-decoding example.

    Builds an ``MTPDecodingConfig`` with relaxed acceptance enabled for
    thinking tokens, loads the target model with it, and prints the generated
    text (at most 10 new tokens) for every entry in ``prompts``.

    Args:
        model: Hugging Face model name or local checkpoint path. Falls back to
            ``"nvidia/DeepSeek-R1-FP4"`` when not given (``None`` or empty).
    """
    # relaxed_topk / relaxed_delta tune the relaxed acceptance enabled by
    # use_relaxed_acceptance_for_thinking; see MTPDecodingConfig docs.
    spec_config = MTPDecodingConfig(use_relaxed_acceptance_for_thinking=True,
                                    relaxed_topk=10,
                                    relaxed_delta=0.01)
    # Identical for every prompt, so build it once.
    sampling_params = SamplingParams(max_tokens=10)

    # Used as a context manager, the LLM is shut down on exit (releasing its
    # executor and GPU memory) even if generation raises.
    with LLM(
            # You can change this to a local model path if you have the model
            # downloaded
            model=model or "nvidia/DeepSeek-R1-FP4",
            speculative_config=spec_config,
    ) as llm:
        # One batched call lets the runtime schedule all prompts together;
        # results are returned in the same order as ``prompts``.
        for output in llm.generate(prompts, sampling_params):
            print(output.outputs[0].text)
29
30
def run_Eagle3(model: Optional[str] = None):
    """Run the EAGLE-3 speculative-decoding example.

    Attaches the ``yuhuili/EAGLE3-LLaMA3.1-Instruct-8B`` draft model
    (``max_draft_len=3``, ``eagle3_one_model=True``) to the target model,
    sets ``free_gpu_memory_fraction=0.8`` for the KV cache, and prints the
    generated text (at most 10 new tokens) for every entry in ``prompts``.

    Args:
        model: Target model name or local checkpoint path. Falls back to
            ``"meta-llama/Llama-3.1-8B-Instruct"`` when not given. The draft
            model above is, per its name, built for Llama-3.1-8B-Instruct, so
            an override should point at that same (e.g. locally downloaded)
            checkpoint.
    """
    spec_config = Eagle3DecodingConfig(
        max_draft_len=3,
        speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
        eagle3_one_model=True)
    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
    # Identical for every prompt, so build it once.
    sampling_params = SamplingParams(max_tokens=10)

    # The context manager shuts the LLM down on exit, even on error.
    with LLM(
            # You can change this to a local model path if you have the model
            # downloaded
            model=model or "meta-llama/Llama-3.1-8B-Instruct",
            speculative_config=spec_config,
            kv_cache_config=kv_cache_config,
    ) as llm:
        # Batched generation; results keep the order of ``prompts``.
        for output in llm.generate(prompts, sampling_params):
            print(output.outputs[0].text)


def run_ngram(model: Optional[str] = None):
    """Run the n-gram (prompt-lookup) speculative-decoding example.

    Uses an ``NGramDecodingConfig`` (``max_draft_len=3``,
    ``max_matching_ngram_size=3``) with the overlap scheduler disabled, and
    prints the generated text (at most 10 new tokens) for every entry in
    ``prompts``.

    Args:
        model: Target model name or local checkpoint path. Falls back to
            ``"meta-llama/Llama-3.1-8B-Instruct"`` when not given.
    """
    spec_config = NGramDecodingConfig(
        max_draft_len=3,
        max_matching_ngram_size=3,
        is_keep_all=True,
        is_use_oldest=True,
        is_public_pool=True,
    )
    # Identical for every prompt, so build it once.
    sampling_params = SamplingParams(max_tokens=10)

    # The context manager shuts the LLM down on exit, even on error.
    with LLM(
            # You can change this to a local model path if you have the model
            # downloaded
            model=model or "meta-llama/Llama-3.1-8B-Instruct",
            speculative_config=spec_config,
            # ngram doesn't work with overlap_scheduler
            disable_overlap_scheduler=True,
    ) as llm:
        # Batched generation; results keep the order of ``prompts``.
        for output in llm.generate(prompts, sampling_params):
            print(output.outputs[0].text)


# Single source of truth mapping each CLI algorithm name to its example. The
# click choices below are derived from these keys, so an algorithm can no
# longer be advertised on the command line without an implementation.
_ALGO_RUNNERS = {
    "MTP": run_MTP,
    "EAGLE3": run_Eagle3,
    "NGRAM": run_ngram,
}


@click.command()
@click.argument("algo",
                type=click.Choice(list(_ALGO_RUNNERS), case_sensitive=False))
@click.option("--model",
              type=str,
              default=None,
              help="Path to the model or model name (the target model for "
              "every algorithm). Defaults to each example's reference "
              "checkpoint.")
def main(algo: str, model: Optional[str] = None):
    """Run the speculative-decoding example selected by ALGO.

    ALGO is one of MTP, EAGLE3 or NGRAM (case-insensitive).
    """
    # case_sensitive=False lets users type e.g. "mtp"; normalise before lookup.
    algo = algo.upper()
    runner = _ALGO_RUNNERS.get(algo)
    if runner is None:
        # Defensive only: click already rejects values outside the choices.
        raise ValueError(f"Invalid algorithm: {algo}")
    runner(model)


if __name__ == "__main__":
    main()