Generate Text Using Medusa Decoding

Source: https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llm-api/llm_medusa_decoding.py

```python
### Generate Text Using Medusa Decoding

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (BuildConfig, KvCacheConfig,
                                 MedusaDecodingConfig)
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode

def main():
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # The end user can customize the sampling configuration with the SamplingParams class
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
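    # temperature rescales the token distribution; top_p samples from the
    # smallest set of tokens whose cumulative probability exceeds 0.95.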

    # The end user can customize the build configuration with the BuildConfig class
    build_config = BuildConfig(
        max_batch_size=1,
        max_seq_len=1024,
        max_draft_len=63,
        speculative_decoding_mode=SpeculativeDecodingMode.MEDUSA)
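    # max_draft_len bounds the number of draft tokens proposed per step; it
    # should cover the number of medusa_choices paths below (63 here).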

    # The end user can customize the KV cache configuration with the KvCacheConfig class
    kv_cache_config = KvCacheConfig(enable_block_reuse=True)
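    # enable_block_reuse lets requests that share a prompt prefix reuse the
    # corresponding KV cache blocks instead of recomputing them.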

    # The end user can customize the Medusa decoding configuration by specifying
    # the number of Medusa heads and the Medusa choices with the MedusaDecodingConfig class
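    # Each entry in medusa_choices is a path in the candidate token tree:
    # element i indexes the top-k predictions of Medusa head i + 1, e.g.
    # [0, 1] pairs head 1's top-1 candidate with head 2's top-2 candidate.
    # This 63-path set is the sparse tree published with the original
    # Medusa checkpoints (mc_sim_7b_63).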
    speculative_config = MedusaDecodingConfig(
        num_medusa_heads=4,
        medusa_choices=[
            [0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3],
            [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6],
            [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9],
            [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1],
            [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0],
            [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0],
            [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0],
            [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0],
            [0, 0, 0, 1], [1, 6], [0, 7, 0],
        ])
    llm = LLM(model="lmsys/vicuna-7b-v1.3",
              speculative_model="FasterDecoding/medusa-vicuna-7b-v1.3",
              build_config=build_config,
              kv_cache_config=kv_cache_config,
              speculative_config=speculative_config)
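    # speculative_model points to the Medusa heads trained for this base
    # model: FasterDecoding/medusa-vicuna-7b-v1.3 matches lmsys/vicuna-7b-v1.3.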

    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    main()
```
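
The 63-path tree used above is the sparse tree published with the original Medusa checkpoints for Vicuna-7B. A smaller tree reduces per-step compute at the cost of expected acceptance length. The sketch below shows how such a tree would be passed, assuming the same `MedusaDecodingConfig` API; the variable name `small_tree_config` and the four paths are illustrative, not a tuned configuration:

```python
# A minimal sketch of a smaller Medusa tree; the four paths below are
# illustrative only, not a tuned configuration.
from tensorrt_llm.llmapi import MedusaDecodingConfig

small_tree_config = MedusaDecodingConfig(
    num_medusa_heads=4,
    medusa_choices=[
        [0],        # head 1's top-1 candidate
        [0, 0],     # ... extended by head 2's top-1 candidate
        [0, 0, 0],  # ... extended by head 3's top-1 candidate
        [1],        # head 1's top-2 candidate
    ])
```

Passing this object as `speculative_config` to `LLM(...)` replaces the default tree; `max_draft_len` in `BuildConfig` then only needs to cover the four paths.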