Generate Text Using Eagle2 Decoding

Source: NVIDIA/TensorRT-LLM.

 1### Generate Text Using Eagle2 Decoding
 2
 3from tensorrt_llm import LLM, SamplingParams
 4from tensorrt_llm.llmapi import (LLM, EagleDecodingConfig, KvCacheConfig,
 5                                 SamplingParams)
 6
 7
 8def main():
 9    # Sample prompts.
10    prompts = [
11        "Hello, my name is",
12        "The president of the United States is",
13        "The capital of France is",
14        "The future of AI is",
15    ]
16    # The end user can customize the sampling configuration with the SamplingParams class
17    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
18
19    # The end user can customize the kv cache configuration with the KVCache class
20    kv_cache_config = KvCacheConfig(enable_block_reuse=True)
21
22    llm_kwargs = {}
23
24    model = "lmsys/vicuna-7b-v1.3"
25
26    # The end user can customize the eagle decoding configuration by specifying the
27    # speculative_model, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
28    # greedy_sampling,posterior_threshold, use_dynamic_tree and dynamic_tree_max_topK
29    # with the EagleDecodingConfig class
30
31    speculative_config = EagleDecodingConfig(
32        speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
33        max_draft_len=63,
34        num_eagle_layers=4,
35        max_non_leaves_per_layer=10,
36        use_dynamic_tree=True,
37        dynamic_tree_max_topK=10)
38
39    llm = LLM(model=model,
40              kv_cache_config=kv_cache_config,
41              speculative_config=speculative_config,
42              max_batch_size=1,
43              max_seq_len=1024,
44              **llm_kwargs)
45
46    outputs = llm.generate(prompts, sampling_params)
47
48    # Print the outputs.
49    for output in outputs:
50        prompt = output.prompt
51        generated_text = output.outputs[0].text
52        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
53
54
# Script entry point.
if __name__ == '__main__':
    main()