Generate Text in Streaming
Source NVIDIA/TensorRT-LLM.
1### Generate Text in Streaming
2import asyncio
3
4from tensorrt_llm import SamplingParams
5from tensorrt_llm._tensorrt_engine import LLM
6
7
def main():
    """Stream completions for several prompts concurrently and print them.

    Builds an ``LLM`` for the TinyLlama chat model, then runs one coroutine
    per prompt under a single event loop. Each coroutine iterates
    ``llm.generate_async(..., streaming=True)`` and prints every output it
    receives, so lines belonging to different prompts may interleave on
    stdout (see the sample output at the end of this function).
    """
    # model could accept HF model name or a path to local HF model.
    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create a sampling params.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # Async based on Python coroutines
    async def stream_prompt(index: int, prompt: str) -> None:
        """Print every streamed output for one prompt, tagged with its index."""
        # streaming=True is used to enable streaming generation.
        async for output in llm.generate_async(prompt,
                                               sampling_params,
                                               streaming=True):
            print(
                f"Generation for prompt-{index}: {output.outputs[0].text!r}")

    # Named run_all (not main) so it does not shadow the enclosing function.
    async def run_all() -> None:
        """Start one streaming coroutine per prompt and wait for all of them."""
        pending = [
            stream_prompt(index, prompt)
            for index, prompt in enumerate(prompts)
        ]
        await asyncio.gather(*pending)

    asyncio.run(run_all())

    # Got output like follows:
    # Generation for prompt-0: '\n'
    # Generation for prompt-3: 'an'
    # Generation for prompt-2: 'Paris'
    # Generation for prompt-1: 'likely'
    # Generation for prompt-0: '\n\n'
    # Generation for prompt-3: 'an exc'
    # Generation for prompt-2: 'Paris.'
    # Generation for prompt-1: 'likely to'
    # Generation for prompt-0: '\n\nJ'
    # Generation for prompt-3: 'an exciting'
    # Generation for prompt-2: 'Paris.'
    # Generation for prompt-1: 'likely to nomin'
    # Generation for prompt-0: '\n\nJane'
    # Generation for prompt-3: 'an exciting time'
    # Generation for prompt-1: 'likely to nominate'
    # Generation for prompt-0: '\n\nJane Smith'
    # Generation for prompt-3: 'an exciting time for'
    # Generation for prompt-1: 'likely to nominate a'
    # Generation for prompt-0: '\n\nJane Smith.'
    # Generation for prompt-3: 'an exciting time for us'
    # Generation for prompt-1: 'likely to nominate a new'
61
62
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()