Generate Text Using Medusa Decoding
Source: NVIDIA/TensorRT-LLM.
### Generate Text Using Medusa Decoding
import argparse
from pathlib import Path

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (BuildConfig, KvCacheConfig,
                                 MedusaDecodingConfig)


def run_medusa_decoding(use_modelopt_ckpt=False, model_dir=None):
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # The end user can customize the sampling configuration with the SamplingParams class
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # The end user can customize the build configuration with the BuildConfig class
    build_config = BuildConfig(
        max_batch_size=1,
        max_seq_len=1024,
    )

    # The end user can customize the kv cache configuration with the KvCacheConfig class
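    # enable_block_reuse lets requests that share a common prompt prefix reuse
    # KV-cache blocks instead of recomputing them.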
    kv_cache_config = KvCacheConfig(enable_block_reuse=True)

    llm_kwargs = {}

    if use_modelopt_ckpt:
        # This is a Llama-3.1-8B combined with Medusa heads provided by TensorRT Model Optimizer.
        # Both the base model (except lm_head) and Medusa heads have been quantized in FP8.
        model = model_dir or "nvidia/Llama-3.1-8B-Medusa-FP8"

        # ModelOpt ckpt uses 3 Medusa heads
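        # Each entry in medusa_choices is a path in the sparse candidate tree:
        # the i-th index in a path selects one of the top-k tokens proposed by
        # the i-th Medusa head, and max_draft_len caps the number of draft
        # tokens considered per decoding step.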
        speculative_config = MedusaDecodingConfig(
            max_draft_len=63,
            num_medusa_heads=3,
            medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3],
                            [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1],
                            [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3],
                            [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7],
                            [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [1, 6], [0, 7, 0]]
        )
    else:
        # In this path, base model and Medusa heads are stored and loaded separately.
        model = "lmsys/vicuna-7b-v1.3"

        # The end user can customize the Medusa decoding configuration by specifying
        # speculative_model, max_draft_len, the number of Medusa heads, and the
        # Medusa choices with the MedusaDecodingConfig class
        speculative_config = MedusaDecodingConfig(
            speculative_model="FasterDecoding/medusa-vicuna-7b-v1.3",
            max_draft_len=63,
            num_medusa_heads=4,
            medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0],
                            [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0],
                            [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5],
                            [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9],
                            [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0],
                            [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]
        )

    # Add 'tensor_parallel_size=2' if using ckpt for
    # a larger model like nvidia/Llama-3.1-70B-Medusa.
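    # For example, on a multi-GPU node:
    #   llm_kwargs["tensor_parallel_size"] = 2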
    llm = LLM(model=model,
              build_config=build_config,
              kv_cache_config=kv_cache_config,
              speculative_config=speculative_config,
              **llm_kwargs)

    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Generate text using Medusa decoding.")
    parser.add_argument(
        '--use_modelopt_ckpt',
        action='store_true',
        help="Use FP8-quantized checkpoint from TensorRT Model Optimizer.")
    # TODO: remove this arg after ModelOpt ckpt is public on HF
    parser.add_argument('--model_dir', type=Path, default=None)
    args = parser.parse_args()

    run_medusa_decoding(args.use_modelopt_ckpt, args.model_dir)
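To try the example, run the script with Python (the file name below is an assumption; in the TensorRT-LLM repository the LLM API examples live under examples/llm-api). Without flags it downloads lmsys/vicuna-7b-v1.3 together with the FasterDecoding/medusa-vicuna-7b-v1.3 Medusa heads; with --use_modelopt_ckpt it uses the FP8 ModelOpt checkpoint instead, and --model_dir points it at a local copy of that checkpoint:

python llm_medusa_decoding.py
python llm_medusa_decoding.py --use_modelopt_ckpt --model_dir <path/to/modelopt/ckpt>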