Generate text with multiple LoRA adapters

Source: NVIDIA/TensorRT-LLM.

 1from huggingface_hub import snapshot_download
 2
 3from tensorrt_llm import LLM
 4from tensorrt_llm.executor import LoRARequest
 5from tensorrt_llm.lora_manager import LoraConfig
 6
 7
 8def main():
 9
10    # Download the LoRA adapters from huggingface hub.
11    lora_dir1 = snapshot_download(repo_id="snshrivas10/sft-tiny-chatbot")
12    lora_dir2 = snapshot_download(
13        repo_id="givyboy/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational")
14    lora_dir3 = snapshot_download(repo_id="barissglc/tinyllama-tarot-v1")
15
16    # Currently, we need to pass at least one lora_dir to LLM constructor via build_config.lora_config.
17    # This is necessary because it requires some configuration in the lora_dir to build the engine with LoRA support.
18    lora_config = LoraConfig(lora_dir=[lora_dir1],
19                             max_lora_rank=64,
20                             max_loras=3,
21                             max_cpu_loras=3)
22    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
23              lora_config=lora_config)
24
25    # Sample prompts
26    prompts = [
27        "Hello, tell me a story: ",
28        "Hello, tell me a story: ",
29        "I've noticed you seem a bit down lately. Is there anything you'd like to talk about?",
30        "I've noticed you seem a bit down lately. Is there anything you'd like to talk about?",
31        "In this reading, the Justice card represents a situation where",
32        "In this reading, the Justice card represents a situation where",
33    ]
34
35    # At runtime, multiple LoRA adapters can be specified via lora_request; None means no LoRA used.
36    for output in llm.generate(prompts,
37                               lora_request=[
38                                   None,
39                                   LoRARequest("chatbot", 1, lora_dir1), None,
40                                   LoRARequest("mental-health", 2, lora_dir2),
41                                   None,
42                                   LoRARequest("tarot", 3, lora_dir3)
43                               ]):
44        prompt = output.prompt
45        generated_text = output.outputs[0].text
46        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
47
48    # Got output like
49    # Prompt: 'Hello, tell me a story: ', Generated text: '1. Start with a question: "What\'s your favorite color?" 2. Ask a question that leads to a story: "What\'s your'
50    # Prompt: 'Hello, tell me a story: ', Generated text: '1. A person is walking down the street. 2. A person is sitting on a bench. 3. A person is reading a book.'
51    # Prompt: "I've noticed you seem a bit down lately. Is there anything you'd like to talk about?", Generated text: "\n\nJASON: (smiling) No, I'm just feeling a bit overwhelmed lately. I've been trying to"
52    # Prompt: "I've noticed you seem a bit down lately. Is there anything you'd like to talk about?", Generated text: "\n\nJASON: (sighs) Yeah, I've been struggling with some personal issues. I've been feeling like I'm"
53    # Prompt: 'In this reading, the Justice card represents a situation where', Generated text: 'you are being asked to make a decision that will have a significant impact on your life. The card suggests that you should take the time to consider all the options'
54    # Prompt: 'In this reading, the Justice card represents a situation where', Generated text: 'you are being asked to make a decision that will have a significant impact on your life. It is important to take the time to consider all the options and make'
55
56
57if __name__ == '__main__':
58    main()