Generate text with multiple LoRA adapters#
Source NVIDIA/TensorRT-LLM.
1
2import argparse
3from typing import Optional
4
5from huggingface_hub import snapshot_download
6
7from tensorrt_llm import LLM
8from tensorrt_llm.executor import LoRARequest
9from tensorrt_llm.lora_helper import LoraConfig
10
11
12def main(chatbot_lora_dir: Optional[str], mental_health_lora_dir: Optional[str],
13 tarot_lora_dir: Optional[str]):
14
15 # Download the LoRA adapters from huggingface hub, if not provided via command line args.
16 if chatbot_lora_dir is None:
17 chatbot_lora_dir = snapshot_download(
18 repo_id="snshrivas10/sft-tiny-chatbot")
19 if mental_health_lora_dir is None:
20 mental_health_lora_dir = snapshot_download(
21 repo_id=
22 "givyboy/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational")
23 if tarot_lora_dir is None:
24 tarot_lora_dir = snapshot_download(
25 repo_id="barissglc/tinyllama-tarot-v1")
26
27 # Currently, we need to pass at least one lora_dir to LLM constructor via build_config.lora_config.
28 # This is necessary because it requires some configuration in the lora_dir to build the engine with LoRA support.
29 lora_config = LoraConfig(lora_dir=[chatbot_lora_dir],
30 max_lora_rank=64,
31 max_loras=3,
32 max_cpu_loras=3)
33 llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
34 lora_config=lora_config)
35
36 # Sample prompts
37 prompts = [
38 "Hello, tell me a story: ",
39 "Hello, tell me a story: ",
40 "I've noticed you seem a bit down lately. Is there anything you'd like to talk about?",
41 "I've noticed you seem a bit down lately. Is there anything you'd like to talk about?",
42 "In this reading, the Justice card represents a situation where",
43 "In this reading, the Justice card represents a situation where",
44 ]
45
46 # At runtime, multiple LoRA adapters can be specified via lora_request; None means no LoRA used.
47 for output in llm.generate(prompts,
48 lora_request=[
49 None,
50 LoRARequest("chatbot", 1, chatbot_lora_dir),
51 None,
52 LoRARequest("mental-health", 2,
53 mental_health_lora_dir), None,
54 LoRARequest("tarot", 3, tarot_lora_dir)
55 ]):
56 prompt = output.prompt
57 generated_text = output.outputs[0].text
58 print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
59
60 # Got output like
61 # Prompt: 'Hello, tell me a story: ', Generated text: '1. Start with a question: "What\'s your favorite color?" 2. Ask a question that leads to a story: "What\'s your'
62 # Prompt: 'Hello, tell me a story: ', Generated text: '1. A person is walking down the street. 2. A person is sitting on a bench. 3. A person is reading a book.'
63 # Prompt: "I've noticed you seem a bit down lately. Is there anything you'd like to talk about?", Generated text: "\n\nJASON: (smiling) No, I'm just feeling a bit overwhelmed lately. I've been trying to"
64 # Prompt: "I've noticed you seem a bit down lately. Is there anything you'd like to talk about?", Generated text: "\n\nJASON: (sighs) Yeah, I've been struggling with some personal issues. I've been feeling like I'm"
65 # Prompt: 'In this reading, the Justice card represents a situation where', Generated text: 'you are being asked to make a decision that will have a significant impact on your life. The card suggests that you should take the time to consider all the options'
66 # Prompt: 'In this reading, the Justice card represents a situation where', Generated text: 'you are being asked to make a decision that will have a significant impact on your life. It is important to take the time to consider all the options and make'
67
68
69if __name__ == '__main__':
70 parser = argparse.ArgumentParser(
71 description="Generate text with multiple LoRA adapters")
72 parser.add_argument('--chatbot_lora_dir',
73 type=str,
74 default=None,
75 help='Path to the chatbot LoRA directory')
76 parser.add_argument('--mental_health_lora_dir',
77 type=str,
78 default=None,
79 help='Path to the mental health LoRA directory')
80 parser.add_argument('--tarot_lora_dir',
81 type=str,
82 default=None,
83 help='Path to the tarot LoRA directory')
84 args = parser.parse_args()
85 main(args.chatbot_lora_dir, args.mental_health_lora_dir,
86 args.tarot_lora_dir)