Generate Text Using Lookahead Decoding

Source: https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/llm-api/llm_lookahead_decoding.py

 1### Generate Text Using Lookahead Decoding
 2from tensorrt_llm import LLM, SamplingParams
 3from tensorrt_llm.llmapi import (LLM, BuildConfig, KvCacheConfig,
 4                                 LookaheadDecodingConfig, SamplingParams)
 5
 6
 7def main():
 8
 9    # The end user can customize the build configuration with the build_config class
10    build_config = BuildConfig()
11    build_config.max_batch_size = 32
12
13    # The configuration for lookahead decoding
14    lookahead_config = LookaheadDecodingConfig(max_window_size=4,
15                                               max_ngram_size=4,
16                                               max_verification_set_size=4)
17
18    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
19    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
20              kv_cache_config=kv_cache_config,
21              build_config=build_config,
22              speculative_config=lookahead_config)
23
24    prompt = "NVIDIA is a great company because"
25    print(f"Prompt: {prompt!r}")
26
27    sampling_params = SamplingParams(lookahead_config=lookahead_config)
28
29    output = llm.generate(prompt, sampling_params=sampling_params)
30    print(output)
31
32    #Output should be similar to:
33    # Prompt: 'NVIDIA is a great company because'
34    #RequestOutput(request_id=2, prompt='NVIDIA is a great company because', prompt_token_ids=[1, 405, 13044, 10764, 338, 263, 2107, 5001, 1363], outputs=[CompletionOutput(index=0, text='they are always pushing the envelope. They are always trying to make the best graphics cards and the best processors. They are always trying to make the best', token_ids=[896, 526, 2337, 27556, 278, 427, 21367, 29889, 2688, 526, 2337, 1811, 304, 1207, 278, 1900, 18533, 15889, 322, 278, 1900, 1889, 943, 29889, 2688, 526, 2337, 1811, 304, 1207, 278, 1900], cumulative_logprob=None, logprobs=[], finish_reason='length', stop_reason=None, generation_logits=None)], finished=True)
35
36
if __name__ == "__main__":
    # Run the example only when executed as a script, not when imported.
    main()