Runtime Configuration Examples

Source: NVIDIA/TensorRT-LLM.

```python
import argparse

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig


def example_cuda_graph_config():
    """
    Example demonstrating CUDA graph configuration for performance optimization.

    CUDA graphs help with:
    - Reduced kernel launch overhead
    - Better GPU utilization
    - Improved throughput for repeated operations
    """
    print("\n=== CUDA Graph Configuration Example ===")

    # Capture graphs for these batch sizes; enable_padding rounds other
    # batch sizes up to the nearest captured size.
    cuda_graph_config = CudaGraphConfig(
        batch_sizes=[1, 2, 4],
        enable_padding=True,
    )

    llm = LLM(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        cuda_graph_config=cuda_graph_config,  # Enable CUDA graphs
        max_batch_size=4,
        max_seq_len=512,
        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.8,
                                      enable_block_reuse=True))

    prompts = [
        "Hello, my name is",
        "The capital of France is",
        "The future of AI is",
    ]

    sampling_params = SamplingParams(max_tokens=50, temperature=0.8, top_p=0.95)

    # Repeated decode steps at the captured batch sizes benefit from CUDA graphs
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        print(f"Prompt: {output.prompt}")
        print(f"Generated: {output.outputs[0].text}")
        print()
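

# A minimal sketch (not part of the upstream example) of one way to observe
# the effect of CUDA graphs: time the first generate() call, which includes
# graph capture and other warm-up costs, against a repeated call at the same
# batch size. It uses only the APIs shown above; timings are illustrative
# and vary by model and hardware.
def example_cuda_graph_timing():
    import time

    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
              cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4],
                                                enable_padding=True),
              max_batch_size=4,
              max_seq_len=512)

    prompts = ["Hello, my name is"] * 4
    sampling_params = SamplingParams(max_tokens=32)

    # First call pays one-time capture/warm-up costs.
    start = time.perf_counter()
    llm.generate(prompts, sampling_params)
    first = time.perf_counter() - start

    # A second call with the same batch size can replay the captured graphs.
    start = time.perf_counter()
    llm.generate(prompts, sampling_params)
    second = time.perf_counter() - start

    print(f"First call: {first:.2f}s, repeated call: {second:.2f}s")

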
def example_kv_cache_config():
    """
    Example demonstrating KV cache configuration for memory management.
    """
    print("\n=== KV Cache Configuration Example ===")

    # Reserve 85% of free GPU memory for the KV cache and reuse cached
    # blocks across requests that share a prefix.
    llm_advanced = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                       max_batch_size=8,
                       max_seq_len=1024,
                       kv_cache_config=KvCacheConfig(
                           free_gpu_memory_fraction=0.85,
                           enable_block_reuse=True))

    prompts = [
        "Hello, my name is",
        "The capital of France is",
        "The future of AI is",
    ]

    outputs = llm_advanced.generate(prompts)
    for i, output in enumerate(outputs):
        print(f"Query {i+1}: {output.prompt}")
        print(f"Answer: {output.outputs[0].text[:100]}...")
        print()
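

# A minimal sketch (an addition, not from the upstream example) illustrating
# why enable_block_reuse=True matters: requests whose prompts share a common
# prefix can reuse the KV cache blocks computed for that prefix instead of
# recomputing them.
def example_kv_cache_block_reuse():
    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
              max_batch_size=8,
              max_seq_len=1024,
              kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.85,
                                            enable_block_reuse=True))

    # All prompts start with the same instruction block, so with block
    # reuse enabled its KV cache can be shared across the three requests.
    shared_prefix = ("You are a helpful assistant. Answer briefly and "
                     "precisely. Question: ")
    prompts = [
        shared_prefix + "What is the capital of France?",
        shared_prefix + "What is the capital of Japan?",
        shared_prefix + "What is the capital of Italy?",
    ]

    for output in llm.generate(prompts, SamplingParams(max_tokens=32)):
        print(output.outputs[0].text)

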
def main():
    """
    Main function to run all runtime configuration examples.
    """
    parser = argparse.ArgumentParser(
        description="Runtime Configuration Examples")
    parser.add_argument("--example",
                        type=str,
                        choices=["kv_cache", "cuda_graph", "all"],
                        default="all",
                        help="Which example to run")

    args = parser.parse_args()

    if args.example in ("kv_cache", "all"):
        example_kv_cache_config()

    if args.example in ("cuda_graph", "all"):
        example_cuda_graph_config()


if __name__ == "__main__":
    main()
```
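
To run a single example, pass the `--example` flag, e.g. `python llm_runtime.py --example kv_cache` (the script name here is illustrative; use whatever filename you saved the listing under).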