Generate Text Using Lookahead Decoding
Source: https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llm-api/llm_lookahead_decoding.py
1### Generate Text Using Lookahead Decoding
2from tensorrt_llm import LLM, SamplingParams
3from tensorrt_llm.llmapi import (LLM, BuildConfig, KvCacheConfig,
4 LookaheadDecodingConfig, SamplingParams)
5
6
def main():
    """Generate a completion for one prompt with lookahead decoding.

    Creates an ``LLM`` for the TinyLlama chat model with a lookahead
    speculative-decoding configuration, prints the prompt, runs generation
    with the same lookahead configuration in the sampling parameters, and
    prints the returned output object.
    """
    # Build options can be customized by the end user through BuildConfig.
    engine_build = BuildConfig()
    engine_build.max_batch_size = 32

    # Lookahead decoding settings, gathered in one place before construction.
    lookahead_settings = dict(max_window_size=4,
                              max_ngram_size=4,
                              max_verification_set_size=4)
    lookahead = LookaheadDecodingConfig(**lookahead_settings)

    cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
    llm = LLM(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        kv_cache_config=cache_config,
        build_config=engine_build,
        speculative_config=lookahead,
    )

    prompt = "NVIDIA is a great company because"
    print(f"Prompt: {prompt!r}")

    # The request carries the same lookahead configuration as the LLM.
    result = llm.generate(
        prompt,
        sampling_params=SamplingParams(lookahead_config=lookahead),
    )
    print(result)

    # Output should be similar to:
    # Prompt: 'NVIDIA is a great company because'
    # RequestOutput(request_id=2, prompt='NVIDIA is a great company because', prompt_token_ids=[1, 405, 13044, 10764, 338, 263, 2107, 5001, 1363], outputs=[CompletionOutput(index=0, text='they are always pushing the envelope. They are always trying to make the best graphics cards and the best processors. They are always trying to make the best', token_ids=[896, 526, 2337, 27556, 278, 427, 21367, 29889, 2688, 526, 2337, 1811, 304, 1207, 278, 1900, 18533, 15889, 322, 278, 1900, 1889, 943, 29889, 2688, 526, 2337, 1811, 304, 1207, 278, 1900], cumulative_logprob=None, logprobs=[], finish_reason='length', stop_reason=None, generation_logits=None)], finished=True)
35
36
# Run the example only when this file is executed directly as a script.
if __name__ == '__main__':
    main()