Generate Text Using Medusa Decoding

Source: https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llm-api/llm_medusa_decoding.py.

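This example runs Medusa speculative decoding through the TensorRT-LLM `LLM` API. It exercises two setups: a single FP8 checkpoint from TensorRT Model Optimizer that already bundles the Medusa heads, and a Vicuna base model paired with separately hosted Medusa head weights.
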
```python
### Generate Text Using Medusa Decoding
import argparse
from pathlib import Path

from tensorrt_llm.llmapi import (LLM, BuildConfig, KvCacheConfig,
                                 MedusaDecodingConfig, SamplingParams)
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode


def run_medusa_decoding(use_modelopt_ckpt=False, model_dir=None):
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # The end user can customize the sampling configuration with the SamplingParams class
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # The end user can customize the build configuration with the BuildConfig class
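    # Note: max_draft_len bounds how many draft tokens Medusa may propose per
    # step; 63 is large enough for the Medusa choice trees used below.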
    build_config = BuildConfig(
        max_batch_size=1,
        max_seq_len=1024,
        max_draft_len=63,
        speculative_decoding_mode=SpeculativeDecodingMode.MEDUSA)

    # The end user can customize the kv cache configuration with the KvCacheConfig class
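    # Block reuse lets requests that share a prompt prefix reuse cached KV blocks.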
    kv_cache_config = KvCacheConfig(enable_block_reuse=True)

    llm_kwargs = {}

    if use_modelopt_ckpt:
        # This is a Llama-3.1-8B combined with Medusa heads provided by TensorRT Model Optimizer.
        # Both the base model (except lm_head) and the Medusa heads have been quantized to FP8.
        model = model_dir or "nvidia/Llama-3.1-8B-Medusa-FP8"

        # The ModelOpt ckpt uses 3 Medusa heads.
        speculative_config = MedusaDecodingConfig(
            num_medusa_heads=3,
            medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3],
                            [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1],
                            [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3],
                            [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7],
                            [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0],
                            [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [1, 6], [0, 7, 0]])
    else:
        # In this path, the base model and the Medusa heads are stored and loaded separately.
        model = "lmsys/vicuna-7b-v1.3"
        speculative_model = "FasterDecoding/medusa-vicuna-7b-v1.3"

        # The end user can customize the Medusa decoding configuration by specifying the
        # number of Medusa heads and the Medusa choices with the MedusaDecodingConfig class.
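        # Each entry in medusa_choices describes one path in the draft-token tree:
        # the i-th value in a path selects one of the top-k candidates predicted by
        # the i-th Medusa head.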
        speculative_config = MedusaDecodingConfig(
            num_medusa_heads=4,
            medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0],
                            [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0],
                            [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5],
                            [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9],
                            [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0],
                            [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]])

        llm_kwargs = {"speculative_model": speculative_model}

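    # With a Medusa speculative_config, each generation step proposes draft tokens
    # from the Medusa heads and verifies them against the base model in parallel.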
    llm = LLM(model=model,
              build_config=build_config,
              kv_cache_config=kv_cache_config,
              speculative_config=speculative_config,
              **llm_kwargs)

    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Generate text using Medusa decoding.")
    parser.add_argument(
        '--use_modelopt_ckpt',
        action='store_true',
        help="Use FP8-quantized checkpoint from TensorRT Model Optimizer.")
    # TODO: remove this arg after ModelOpt ckpt is public on HF
    parser.add_argument('--model_dir', type=Path, default=None)
    args = parser.parse_args()

    run_medusa_decoding(args.use_modelopt_ckpt, args.model_dir)
```
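
Run the script with no arguments to use the Vicuna base model with its separately hosted Medusa heads, or pass `--use_modelopt_ckpt` (optionally with `--model_dir` pointing at a local copy of the checkpoint) to use the FP8 checkpoint from TensorRT Model Optimizer.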