Generate Text Using Eagle Decoding

Source: NVIDIA/TensorRT-LLM.

 1### Generate Text Using Eagle Decoding
 2
 3from tensorrt_llm import LLM, SamplingParams
 4from tensorrt_llm.llmapi import (LLM, BuildConfig, EagleDecodingConfig,
 5                                 KvCacheConfig, SamplingParams)
 6
 7
 8def main():
 9    # Sample prompts.
10    prompts = [
11        "Hello, my name is",
12        "The president of the United States is",
13        "The capital of France is",
14        "The future of AI is",
15    ]
16    # The end user can customize the sampling configuration with the SamplingParams class
17    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
18
19    # The end user can customize the build configuration with the BuildConfig class
20    build_config = BuildConfig(max_batch_size=1, max_seq_len=1024)
21
22    # The end user can customize the kv cache configuration with the KVCache class
23    kv_cache_config = KvCacheConfig(enable_block_reuse=True)
24
25    llm_kwargs = {}
26
27    model = "lmsys/vicuna-7b-v1.3"
28
29    # The end user can customize the eagle decoding configuration by specifying the
30    # speculative_model, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
31    # greedy_sampling,posterior_threshold, use_dynamic_tree and dynamic_tree_max_topK
32    # with the EagleDecodingConfig class
33
34    speculative_config = EagleDecodingConfig(
35        speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
36        max_draft_len=63,
37        num_eagle_layers=4,
38        max_non_leaves_per_layer=10,
39                            eagle_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
40                                            [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], \
41                                            [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], \
42                                            [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], \
43                                            [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], \
44                                            [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]
45    )
46
47    llm = LLM(model=model,
48              build_config=build_config,
49              kv_cache_config=kv_cache_config,
50              speculative_config=speculative_config,
51              **llm_kwargs)
52
53    outputs = llm.generate(prompts, sampling_params)
54
55    # Print the outputs.
56    for output in outputs:
57        prompt = output.prompt
58        generated_text = output.outputs[0].text
59        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
60
61
if __name__ == "__main__":
    main()