Sampling#
The PyTorch backend supports most of the sampling features available in the C++ backend, including temperature, top-k and top-p sampling, beam search, stop words, bad words, penalties, context and generation logits, log probabilities, guided decoding, and logits processors.
General usage#
To use these sampling options:

- Enable the enable_trtllm_sampler option in the LLM class.
- Pass a SamplingParams object with the desired options to the generate() function.
The following example prepares two identical prompts which will give different results due to the sampling parameters chosen:
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
          enable_trtllm_sampler=True)

sampling_params = SamplingParams(
    temperature=1.0,
    top_k=8,
    top_p=0.5,
)

llm.generate(["Hello, my name is",
              "Hello, my name is"], sampling_params)
Note: The enable_trtllm_sampler option is not currently supported with speculative decoders such as MTP or Eagle-3, so only a smaller subset of sampling options is available in those cases.
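The remaining options from the feature list above, such as stop words, bad words and penalties, are configured through additional SamplingParams fields in the same way. The following sketch is illustrative only and assumes field names such as max_tokens, stop, bad and repetition_penalty; check the SamplingParams reference for the exact names in your version.

from tensorrt_llm import LLM, SamplingParams

llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
          enable_trtllm_sampler=True)

# Illustrative sketch: field names are assumed and may vary between versions.
sampling_params = SamplingParams(
    max_tokens=64,            # cap the number of generated tokens
    temperature=0.8,
    stop=["\n\n"],            # stop words: generation halts when this string appears
    bad=["badword"],          # bad words: this string is never generated
    repetition_penalty=1.1,   # penalize repeated tokens
)

llm.generate(["Hello, my name is"], sampling_params)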
Beam search#
Beam search is a decoding strategy that maintains multiple candidate sequences (beams) during text generation, exploring different possible continuations to find higher quality outputs. Unlike greedy decoding or sampling, beam search considers multiple hypotheses simultaneously.
To enable beam search, you must:
- Enable the use_beam_search option in the SamplingParams object.
- Set the max_beam_width parameter in the LLM class to match the best_of parameter in SamplingParams.
- Disable overlap scheduling using the disable_overlap_scheduler parameter of the LLM class.
- Disable the usage of CUDA Graphs by passing None to the cuda_graph_config parameter of the LLM class.
Parameter Configuration:

- best_of: controls the number of beams processed during generation (the beam width).
- n: controls the number of output sequences returned (can be less than best_of).
- If best_of is omitted, the number of beams processed defaults to n.
- max_beam_width in the LLM class must equal best_of in SamplingParams.
The following example demonstrates beam search with a beam width of 4, returning the top 3 sequences:
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
          enable_trtllm_sampler=True,
          max_beam_width=4,  # must equal SamplingParams.best_of
          disable_overlap_scheduler=True,
          cuda_graph_config=None)

sampling_params = SamplingParams(
    best_of=4,             # must equal LLM.max_beam_width
    use_beam_search=True,
    n=3,                   # return top 3 sequences
)

llm.generate(["Hello, my name is",
              "Hello, my name is"], sampling_params)
Guided decoding#
Guided decoding controls the generation outputs to conform to pre-defined structured formats, ensuring outputs follow specific schemas or patterns.
The PyTorch backend supports guided decoding with the XGrammar and Low-level Guidance (llguidance) backends and the following formats:
- JSON schema
- JSON object
- Regular expressions
- Extended Backus-Naur form (EBNF) grammar
- Structural tags
To enable guided decoding, you must:
- Set the guided_decoding_backend parameter to 'xgrammar' or 'llguidance' in the LLM class.
- Create a GuidedDecodingParams object with the desired format specification. Note: depending on the format, a different parameter is used to construct the object (json, regex, grammar, or structural_tag).
- Pass the GuidedDecodingParams object to the guided_decoding parameter of the SamplingParams object.
The following example demonstrates guided decoding with a JSON schema:
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import GuidedDecodingParams

llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
          guided_decoding_backend='xgrammar')

structure = '{"title": "Example JSON", "type": "object", "properties": {...}}'
guided_decoding_params = GuidedDecodingParams(json=structure)

sampling_params = SamplingParams(
    guided_decoding=guided_decoding_params,
)

llm.generate("Generate a JSON response", sampling_params)
You can find a more detailed example on guided decoding here.
Logits processor#
Logits processors allow you to modify the logits produced by the network before sampling, enabling custom generation behavior and constraints.
To use a custom logits processor:
- Create a custom class that inherits from LogitsProcessor and implements the __call__ method.
- Pass an instance of this class to the logits_processor parameter of SamplingParams.
The following example demonstrates logits processing:
import torch
from typing import List, Optional

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.sampling_params import LogitsProcessor

class MyCustomLogitsProcessor(LogitsProcessor):
    def __call__(self,
                 req_id: int,
                 logits: torch.Tensor,
                 token_ids: List[List[int]],
                 stream_ptr: Optional[int],
                 client_id: Optional[int]) -> None:
        # Implement your custom inplace logits processing logic
        logits *= logits

llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8')

sampling_params = SamplingParams(
    logits_processor=MyCustomLogitsProcessor()
)

llm.generate(["Hello, my name is"], sampling_params)
You can find a more detailed example on logits processors here.