Sampling Techniques Showcase
Source: NVIDIA/TensorRT-LLM.
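The script below walks through the sampling controls exposed by the TensorRT-LLM LLM API: greedy decoding (temperature=0), temperature scaling, top-k and top-p (nucleus) filtering, a combined top-k + top-p configuration, multi-sequence generation with n, and per-token log probabilities via logprobs. Each demonstration builds a SamplingParams object and passes it to LLM.generate() with the TinyLlama/TinyLlama-1.1B-Chat-v1.0 model.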
1"""
2This example demonstrates various sampling techniques available in TensorRT-LLM.
3It showcases different sampling parameters and their effects on text generation.
4"""
5
6from typing import Optional
7
8import click
9
10from tensorrt_llm import LLM, SamplingParams
11
12# Example prompts to demonstrate different sampling techniques
13prompts = [
14 "What is the future of artificial intelligence?",
15 "Describe a beautiful sunset over the ocean.",
16 "Write a short story about a robot discovering emotions.",
17]
18
19
20def demonstrate_greedy_decoding(prompt: str):
21 """Demonstrates greedy decoding with temperature=0."""
22 print("\n🎯 === GREEDY DECODING ===")
23 print("Using temperature=0 for deterministic, focused output")
24
25 llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
26
27 sampling_params = SamplingParams(
28 max_tokens=50,
29 temperature=0.0, # Greedy decoding
30 )
31
32 response = llm.generate(prompt, sampling_params)
33 print(f"Prompt: {prompt}")
34 print(f"Response: {response.outputs[0].text}")
35
36
37def demonstrate_temperature_sampling(prompt: str):
38 """Demonstrates temperature sampling with different temperature values."""
39 print("\n🌡️ === TEMPERATURE SAMPLING ===")
40 print(
41 "Higher temperature = more creative/random, Lower temperature = more focused"
42 )
43
44 llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
45
46 temperatures = [0.3, 0.7, 1.0, 1.5]
47 for temp in temperatures:
48
49 sampling_params = SamplingParams(
50 max_tokens=50,
51 temperature=temp,
52 )
53
54 response = llm.generate(prompt, sampling_params)
55 print(f"Temperature {temp}: {response.outputs[0].text}")
56
57
def demonstrate_top_k_sampling(prompt: str):
    """Demonstrates top-k sampling with different k values."""
    print("\n🔝 === TOP-K SAMPLING ===")
    print("Only consider the top-k most likely tokens at each step")

    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    top_k_values = [1, 5, 20, 50]

    for k in top_k_values:
        sampling_params = SamplingParams(
            max_tokens=50,
            temperature=0.8,  # Use moderate temperature
            top_k=k,
        )

        response = llm.generate(prompt, sampling_params)
        print(f"Top-k {k}: {response.outputs[0].text}")


def demonstrate_top_p_sampling(prompt: str):
    """Demonstrates top-p (nucleus) sampling with different p values."""
    print("\n🎯 === TOP-P (NUCLEUS) SAMPLING ===")
    print("Only consider tokens whose cumulative probability is within top-p")

    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    top_p_values = [0.1, 0.5, 0.9, 0.95]

    for p in top_p_values:
        sampling_params = SamplingParams(
            max_tokens=50,
            temperature=0.8,  # Use moderate temperature
            top_p=p,
        )

        response = llm.generate(prompt, sampling_params)
        print(f"Top-p {p}: {response.outputs[0].text}")


def demonstrate_combined_sampling(prompt: str):
    """Demonstrates combined top-k and top-p sampling."""
    print("\n🔄 === COMBINED TOP-K + TOP-P SAMPLING ===")
    print("Using both top-k and top-p together for balanced control")

    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    sampling_params = SamplingParams(
        max_tokens=50,
        temperature=0.8,
        top_k=40,  # Consider top 40 tokens
        top_p=0.9,  # Within 90% cumulative probability
    )

    response = llm.generate(prompt, sampling_params)
    print(f"Combined (k=40, p=0.9): {response.outputs[0].text}")


def demonstrate_multiple_sequences(prompt: str):
    """Demonstrates generating multiple sequences with different sampling."""
    print("\n📚 === MULTIPLE SEQUENCES ===")
    print("Generate multiple different responses for the same prompt")

    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    sampling_params = SamplingParams(
        max_tokens=40,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
        n=3,  # Generate 3 different sequences
    )

    response = llm.generate(prompt, sampling_params)
    print(f"Prompt: {prompt}")
    for i, output in enumerate(response.outputs):
        print(f"Sequence {i+1}: {output.text}")


def demonstrate_with_logprobs(prompt: str):
    """Demonstrates generation with log probabilities."""
    print("\n📊 === GENERATION WITH LOG PROBABILITIES ===")
    print("Get probability information for generated tokens")

    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    sampling_params = SamplingParams(
        max_tokens=20,
        temperature=0.7,
        top_k=50,
        logprobs=True,  # Return log probabilities
    )

    response = llm.generate(prompt, sampling_params)
    output = response.outputs[0]

    print(f"Prompt: {prompt}")
    print(f"Generated: {output.text}")
    print(f"Logprobs: {output.logprobs}")


def run_all_demonstrations(model_path: Optional[str] = None):
    """Run all sampling demonstrations."""
    print("🚀 TensorRT-LLM Sampling Techniques Showcase")
    print("=" * 50)

    # Use the first prompt for most demonstrations
    demo_prompt = prompts[0]

    # Run all demonstrations
    demonstrate_greedy_decoding(demo_prompt)
    demonstrate_temperature_sampling(demo_prompt)
    demonstrate_top_k_sampling(demo_prompt)
    demonstrate_top_p_sampling(demo_prompt)
    demonstrate_combined_sampling(demo_prompt)
    # TODO[Superjomn]: enable these once the PyTorch backend supports them
    # demonstrate_multiple_sequences(demo_prompt)
    # demonstrate_beam_search(demo_prompt)
    demonstrate_with_logprobs(demo_prompt)

    print("\n🎉 All sampling demonstrations completed!")


@click.command()
@click.option("--model",
              type=str,
              default=None,
              help="Path to the model or model name")
@click.option("--demo",
              type=click.Choice([
                  "greedy", "temperature", "top_k", "top_p", "combined",
                  "multiple", "beam", "logprobs", "creative", "all"
              ]),
              default="all",
              help="Which demonstration to run")
@click.option("--prompt", type=str, default=None, help="Custom prompt to use")
def main(model: Optional[str], demo: str, prompt: Optional[str]):
    """
    Showcase various sampling techniques in TensorRT-LLM.

    Examples:
        python llm_sampling.py --demo all
        python llm_sampling.py --demo temperature --prompt "Tell me a joke"
        python llm_sampling.py --demo beam --model path/to/your/model
    """

    demo_prompt = prompt or prompts[0]

    # Run specific demonstration
    if demo == "greedy":
        demonstrate_greedy_decoding(demo_prompt)
    elif demo == "temperature":
        demonstrate_temperature_sampling(demo_prompt)
    elif demo == "top_k":
        demonstrate_top_k_sampling(demo_prompt)
    elif demo == "top_p":
        demonstrate_top_p_sampling(demo_prompt)
    elif demo == "combined":
        demonstrate_combined_sampling(demo_prompt)
    elif demo == "multiple":
        demonstrate_multiple_sequences(demo_prompt)
    elif demo == "logprobs":
        demonstrate_with_logprobs(demo_prompt)
    elif demo == "all":
        run_all_demonstrations(model)
    else:
        # "beam" and "creative" are accepted by the CLI but not implemented yet
        print(f"Demo '{demo}' is not implemented in this example yet.")


if __name__ == "__main__":
    main()
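
Every demonstration above repeats the same core pattern, so it is easy to lift into your own code. A minimal sketch distilled from the listing (same model and parameters as the combined-sampling demo; adjust them as needed):

    from tensorrt_llm import LLM, SamplingParams

    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    params = SamplingParams(
        max_tokens=50,
        temperature=0.8,  # moderate randomness
        top_k=40,         # keep only the 40 most likely tokens
        top_p=0.9,        # and stay within 90% cumulative probability
    )
    result = llm.generate("What is the future of artificial intelligence?", params)
    print(result.outputs[0].text)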