Sampling Techniques Showcase

Source: NVIDIA/TensorRT-LLM.
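
Every demonstration in this example follows the same pattern: construct an LLM, describe the decoding strategy with a SamplingParams object, and call generate. Condensed to its core (using the same TinyLlama checkpoint as the demos below), the pattern looks like this:

from tensorrt_llm import LLM, SamplingParams

# Build or load the model once, then reuse it for every request.
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# All decoding behavior is controlled through SamplingParams.
params = SamplingParams(
    max_tokens=50,
    temperature=0.8,
    top_k=40,
    top_p=0.9,
)

response = llm.generate("What is the future of artificial intelligence?", params)
print(response.outputs[0].text)

The full example file follows.
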

  1"""
  2This example demonstrates various sampling techniques available in TensorRT-LLM.
  3It showcases different sampling parameters and their effects on text generation.
  4"""
  5
  6from typing import Optional
  7
  8import click
  9
 10from tensorrt_llm import LLM, SamplingParams
 11
 12# Example prompts to demonstrate different sampling techniques
 13prompts = [
 14    "What is the future of artificial intelligence?",
 15    "Describe a beautiful sunset over the ocean.",
 16    "Write a short story about a robot discovering emotions.",
 17]
 18
 19
 20def demonstrate_greedy_decoding(prompt: str):
 21    """Demonstrates greedy decoding with temperature=0."""
 22    print("\n🎯 === GREEDY DECODING ===")
 23    print("Using temperature=0 for deterministic, focused output")
 24
 25    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 26
 27    sampling_params = SamplingParams(
 28        max_tokens=50,
 29        temperature=0.0,  # Greedy decoding
 30    )
 31
 32    response = llm.generate(prompt, sampling_params)
 33    print(f"Prompt: {prompt}")
 34    print(f"Response: {response.outputs[0].text}")
 35
 36
 37def demonstrate_temperature_sampling(prompt: str):
 38    """Demonstrates temperature sampling with different temperature values."""
 39    print("\n🌡️ === TEMPERATURE SAMPLING ===")
 40    print(
 41        "Higher temperature = more creative/random, Lower temperature = more focused"
 42    )
 43
 44    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 45
 46    temperatures = [0.3, 0.7, 1.0, 1.5]
 47    for temp in temperatures:
 48
 49        sampling_params = SamplingParams(
 50            max_tokens=50,
 51            temperature=temp,
 52        )
 53
 54        response = llm.generate(prompt, sampling_params)
 55        print(f"Temperature {temp}: {response.outputs[0].text}")
 56
 57
 58def demonstrate_top_k_sampling(prompt: str):
 59    """Demonstrates top-k sampling with different k values."""
 60    print("\n🔝 === TOP-K SAMPLING ===")
 61    print("Only consider the top-k most likely tokens at each step")
 62
 63    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 64
 65    top_k_values = [1, 5, 20, 50]
 66
 67    for k in top_k_values:
 68        sampling_params = SamplingParams(
 69            max_tokens=50,
 70            temperature=0.8,  # Use moderate temperature
 71            top_k=k,
 72        )
 73
 74        response = llm.generate(prompt, sampling_params)
 75        print(f"Top-k {k}: {response.outputs[0].text}")
 76
 77
 78def demonstrate_top_p_sampling(prompt: str):
 79    """Demonstrates top-p (nucleus) sampling with different p values."""
 80    print("\n🎯 === TOP-P (NUCLEUS) SAMPLING ===")
 81    print("Only consider tokens whose cumulative probability is within top-p")
 82
 83    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 84
 85    top_p_values = [0.1, 0.5, 0.9, 0.95]
 86
 87    for p in top_p_values:
 88        sampling_params = SamplingParams(
 89            max_tokens=50,
 90            temperature=0.8,  # Use moderate temperature
 91            top_p=p,
 92        )
 93
 94        response = llm.generate(prompt, sampling_params)
 95        print(f"Top-p {p}: {response.outputs[0].text}")
 96
 97
 98def demonstrate_combined_sampling(prompt: str):
 99    """Demonstrates combined top-k and top-p sampling."""
100    print("\n🔄 === COMBINED TOP-K + TOP-P SAMPLING ===")
101    print("Using both top-k and top-p together for balanced control")
102
103    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
104
105    sampling_params = SamplingParams(
106        max_tokens=50,
107        temperature=0.8,
108        top_k=40,  # Consider top 40 tokens
109        top_p=0.9,  # Within 90% cumulative probability
110    )
111
112    response = llm.generate(prompt, sampling_params)
113    print(f"Combined (k=40, p=0.9): {response.outputs[0].text}")
114
115
116def demonstrate_multiple_sequences(prompt: str):
117    """Demonstrates generating multiple sequences with different sampling."""
118    print("\n📚 === MULTIPLE SEQUENCES ===")
119    print("Generate multiple different responses for the same prompt")
120
121    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
122
123    sampling_params = SamplingParams(
124        max_tokens=40,
125        temperature=0.8,
126        top_k=50,
127        top_p=0.95,
128        n=3,  # Generate 3 different sequences
129    )
130
131    response = llm.generate(prompt, sampling_params)
132    print(f"Prompt: {prompt}")
133    for i, output in enumerate(response.outputs):
134        print(f"Sequence {i+1}: {output.text}")
135
136
137def demonstrate_with_logprobs(prompt: str):
138    """Demonstrates generation with log probabilities."""
139    print("\n📊 === GENERATION WITH LOG PROBABILITIES ===")
140    print("Get probability information for generated tokens")
141
142    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
143
144    sampling_params = SamplingParams(
145        max_tokens=20,
146        temperature=0.7,
147        top_k=50,
148        logprobs=True,  # Return log probabilities
149    )
150
151    response = llm.generate(prompt, sampling_params)
152    output = response.outputs[0]
153
154    print(f"Prompt: {prompt}")
155    print(f"Generated: {output.text}")
156    print(f"Logprobs: {output.logprobs}")
157
158
159def run_all_demonstrations(model_path: Optional[str] = None):
160    """Run all sampling demonstrations."""
161    print("🚀 TensorRT-LLM Sampling Techniques Showcase")
162    print("=" * 50)
163
164    # Use the first prompt for most demonstrations
165    demo_prompt = prompts[0]
166
167    # Run all demonstrations
168    demonstrate_greedy_decoding(demo_prompt)
169    demonstrate_temperature_sampling(demo_prompt)
170    demonstrate_top_k_sampling(demo_prompt)
171    demonstrate_top_p_sampling(demo_prompt)
172    demonstrate_combined_sampling(demo_prompt)
    # TODO[Superjomn]: enable these once the PyTorch backend supports them
    # demonstrate_multiple_sequences(demo_prompt)
    # demonstrate_beam_search(demo_prompt)
    demonstrate_with_logprobs(demo_prompt)

    print("\n🎉 All sampling demonstrations completed!")


@click.command()
@click.option("--model",
              type=str,
              default=None,
              help="Path to the model or model name")
@click.option("--demo",
              type=click.Choice([
                  "greedy", "temperature", "top_k", "top_p", "combined",
                  "multiple", "beam", "logprobs", "creative", "all"
              ]),
              default="all",
              help="Which demonstration to run")
@click.option("--prompt", type=str, default=None, help="Custom prompt to use")
def main(model: Optional[str], demo: str, prompt: Optional[str]):
    """
    Showcase various sampling techniques in TensorRT-LLM.

    Examples:
        python llm_sampling.py --demo all
        python llm_sampling.py --demo temperature --prompt "Tell me a joke"
        python llm_sampling.py --demo beam --model path/to/your/model
    """

    demo_prompt = prompt or prompts[0]

    # Run specific demonstration
    if demo == "greedy":
        demonstrate_greedy_decoding(demo_prompt)
    elif demo == "temperature":
        demonstrate_temperature_sampling(demo_prompt)
    elif demo == "top_k":
        demonstrate_top_k_sampling(demo_prompt)
    elif demo == "top_p":
        demonstrate_top_p_sampling(demo_prompt)
    elif demo == "combined":
        demonstrate_combined_sampling(demo_prompt)
    elif demo == "multiple":
        demonstrate_multiple_sequences(demo_prompt)
    elif demo == "logprobs":
        demonstrate_with_logprobs(demo_prompt)
    elif demo == "all":
        run_all_demonstrations(model)


if __name__ == "__main__":
    main()
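
Note that the CLI exposes "beam" and "creative" choices and the module docstring shows a --demo beam invocation, but neither demo is defined in this file; beam search is listed in the TODO above as pending PyTorch-backend support. For orientation only, here is a minimal sketch of what such a demo could look like. The use_beam_search and best_of parameters, and the build-time beam-width requirement mentioned in the comments, are assumptions about the SamplingParams/LLM API rather than something this example confirms.

from tensorrt_llm import LLM, SamplingParams


def demonstrate_beam_search(prompt: str):
    """Hypothetical beam-search demo; not part of the original example."""
    print("\n🔦 === BEAM SEARCH (sketch) ===")
    print("Keep the highest-scoring partial sequences (beams) at each step")

    # Assumption: the engine must be built with a beam width that matches best_of
    # (e.g. a max_beam_width build option); the exact setting may differ by version.
    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    sampling_params = SamplingParams(
        max_tokens=50,
        use_beam_search=True,  # assumed flag name
        best_of=4,  # number of beams to explore
        n=4,  # return all four beams
    )

    response = llm.generate(prompt, sampling_params)
    for i, output in enumerate(response.outputs):
        print(f"Beam {i + 1}: {output.text}")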