Guided Decoding
Guided decoding (also called constrained decoding or structured generation) guarantees that the LLM output conforms to a user-specified grammar (e.g., a JSON schema, regular expression, or EBNF grammar).
TensorRT LLM supports two grammar backends:
XGrammar: supports JSON schema, regular expression, EBNF, and structural tag.
LLGuidance: supports JSON schema, regular expression, and EBNF.
Online API: trtllm-serve
If you are using trtllm-serve, enable guided decoding by setting guided_decoding_backend to xgrammar or llguidance in the YAML configuration file passed to --extra_llm_api_options. For example:
cat > extra_llm_api_options.yaml <<EOF
guided_decoding_backend: xgrammar
EOF
trtllm-serve nvidia/Llama-3.1-8B-Instruct-FP8 --extra_llm_api_options extra_llm_api_options.yaml
You should see a log line like the following, indicating that the grammar backend was successfully enabled:
......
[TRT-LLM] [I] Guided decoder initialized with backend: GuidedDecodingBackend.XGRAMMAR
......
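To use the LLGuidance backend instead, set the same key to llguidance:

cat > extra_llm_api_options.yaml <<EOF
guided_decoding_backend: llguidance
EOF

The startup log should then report the LLGuidance backend instead of XGrammar.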
JSON Schema
Define a JSON schema and pass it to response_format when creating the OpenAI chat completion request. Alternatively, the JSON schema can be generated from a pydantic model, as shown after the example output below.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="tensorrt_llm",
)

json_schema = {
    "type": "object",
    "properties": {
        "name": {
            "type": "string",
            "pattern": "^[\\w]+$"
        },
        "population": {
            "type": "integer"
        },
    },
    "required": ["name", "population"],
}

messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant.",
    },
    {
        "role": "user",
        "content": "Give me the information of the capital of France in the JSON format.",
    },
]

chat_completion = client.chat.completions.create(
    model="nvidia/Llama-3.1-8B-Instruct-FP8",
    messages=messages,
    max_completion_tokens=256,
    response_format={
        "type": "json",
        "schema": json_schema
    },
)
message = chat_completion.choices[0].message
print(message.content)
The output would look like:
{
    "name": "Paris",
    "population": 2145200
}
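As an alternative to writing the schema by hand, a minimal sketch deriving it from a hypothetical CityInfo pydantic (v2) model:

from pydantic import BaseModel, Field

class CityInfo(BaseModel):
    # Mirrors the hand-written schema: a word-only name and an integer population.
    name: str = Field(pattern=r"^[\w]+$")
    population: int

json_schema = CityInfo.model_json_schema()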
Regular expression
Define a regular expression and pass it to response_format when creating the OpenAI chat completion request.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="tensorrt_llm",
)

messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant.",
    },
    {
        "role": "user",
        "content": "What is the capital of France?",
    },
]

chat_completion = client.chat.completions.create(
    model="nvidia/Llama-3.1-8B-Instruct-FP8",
    messages=messages,
    max_completion_tokens=256,
    response_format={
        "type": "regex",
        "regex": "(Paris|London)"
    },
)
message = chat_completion.choices[0].message
print(message.content)
The output would look like:
Paris
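The pattern constrains the entire completion, so any regular expression supported by the backend can be used. For instance, a hypothetical variant forcing a purely numeric answer:

# Hypothetical variant: the reply must consist of digits only.
chat_completion = client.chat.completions.create(
    model="nvidia/Llama-3.1-8B-Instruct-FP8",
    messages=[{"role": "user", "content": "What is the population of Paris? Answer with a number."}],
    max_completion_tokens=256,
    response_format={"type": "regex", "regex": "[0-9]+"},
)
print(chat_completion.choices[0].message.content)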
EBNF grammar
Define an EBNF grammar and pass it to response_format when creating the OpenAI chat completion request. In the grammar below, root is the start symbol, ::= defines a production rule, | separates alternatives, and quoted strings are literal terminals.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="tensorrt_llm",
)

ebnf_grammar = """root ::= description
city ::= "London" | "Paris" | "Berlin" | "Rome"
description ::= city " is " status
status ::= "the capital of " country
country ::= "England" | "France" | "Germany" | "Italy"
"""

messages = [
    {
        "role": "system",
        "content": "You are a helpful geography bot.",
    },
    {
        "role": "user",
        "content": "Give me the information of the capital of France.",
    },
]

chat_completion = client.chat.completions.create(
    model="nvidia/Llama-3.1-8B-Instruct-FP8",
    messages=messages,
    max_completion_tokens=256,
    response_format={
        "type": "ebnf",
        "ebnf": ebnf_grammar
    },
)
message = chat_completion.choices[0].message
print(message.content)
The output would look like:
Paris is the capital of France
Structural tag
Define a structural tag and pass it to response_format when creating the OpenAI chat completion request.
Structural tags are supported by the XGrammar backend only. They are a powerful and flexible way to express LLM output constraints; see the structural tag usage for a comprehensive tutorial. Below is an example of function calling with a customized function-call format for Llama-3.1-8B-Instruct.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="tensorrt_llm",
)

tool_get_current_weather = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The city to find the weather for, e.g. 'San Francisco'",
                },
                "state": {
                    "type": "string",
                    "description": "The two-letter abbreviation for the state that the city is in, e.g. 'CA' which would mean 'California'",
                },
                "unit": {
                    "type": "string",
                    "description": "The unit to fetch the temperature in",
                    "enum": ["celsius", "fahrenheit"],
                },
            },
            "required": ["city", "state", "unit"],
        },
    },
}

tool_get_current_date = {
    "type": "function",
    "function": {
        "name": "get_current_date",
        "description": "Get the current date and time for a given timezone",
        "parameters": {
            "type": "object",
            "properties": {
                "timezone": {
                    "type": "string",
                    "description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'",
                }
            },
            "required": ["timezone"],
        },
    },
}

system_prompt = f"""# Tool Instructions
- Always execute python code in messages that you share.
- When looking for real-time information, use relevant functions if available, else fall back to brave_search.
You have access to the following functions:
Use the function 'get_current_weather' to: Get the current weather in a given location
{tool_get_current_weather["function"]}
Use the function 'get_current_date' to: Get the current date and time for a given timezone
{tool_get_current_date["function"]}
If you choose to call a function, ONLY reply in the following format:
<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}
where
start_tag => `<function`
parameters => a JSON dict with the function argument name as key and function argument value as value.
end_tag => `</function>`
Here is an example,
<function=example_function_name>{{"example_name": "example_value"}}</function>
Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query
You are a helpful assistant."""

user_prompt = "You are in New York. Please get the current date and time, and the weather."

messages = [
    {
        "role": "system",
        "content": system_prompt,
    },
    {
        "role": "user",
        "content": user_prompt,
    },
]

chat_completion = client.chat.completions.create(
    model="nvidia/Llama-3.1-8B-Instruct-FP8",
    messages=messages,
    max_completion_tokens=256,
    response_format={
        "type": "structural_tag",
        "format": {
            "type": "triggered_tags",
            "triggers": ["<function="],
            "tags": [
                {
                    "begin": "<function=get_current_weather>",
                    "content": {
                        "type": "json_schema",
                        "json_schema": tool_get_current_weather["function"]["parameters"],
                    },
                    "end": "</function>",
                },
                {
                    "begin": "<function=get_current_date>",
                    "content": {
                        "type": "json_schema",
                        "json_schema": tool_get_current_date["function"]["parameters"],
                    },
                    "end": "</function>",
                },
            ],
        },
    },
)
message = chat_completion.choices[0].message
print(message.content)
The output would look like:
<function=get_current_date>{"timezone": "America/New_York"}</function>
<function=get_current_weather>{"city": "New York", "state": "NY", "unit": "fahrenheit"}</function>
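On the client side, the constrained output can be parsed back into structured calls. A minimal sketch (the regular expression below is illustrative, not part of any API):

import json
import re

# Extract "<function=NAME>{...}</function>" spans emitted under the
# structural-tag constraint and decode their JSON arguments.
for name, args in re.findall(r"<function=(\w+)>(\{.*?\})</function>", message.content):
    print(name, json.loads(args))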
Offline API: LLM API
If you are using the LLM API, enable guided decoding by setting guided_decoding_backend to xgrammar or llguidance when creating the LLM instance. For example:
from tensorrt_llm import LLM
llm = LLM("nvidia/Llama-3.1-8B-Instruct-FP8", guided_decoding_backend="xgrammar")
JSON Schema
Create a GuidedDecodingParams with the json field set to a JSON schema, use it to construct SamplingParams, and then pass that to llm.generate or llm.generate_async. Alternatively, the JSON schema can be generated from a pydantic model, as in the online example above.
from tensorrt_llm import LLM
from tensorrt_llm.sampling_params import SamplingParams, GuidedDecodingParams

if __name__ == "__main__":
    llm = LLM("nvidia/Llama-3.1-8B-Instruct-FP8", guided_decoding_backend="xgrammar")

    json_schema = {
        "type": "object",
        "properties": {
            "name": {
                "type": "string",
                "pattern": "^[\\w]+$"
            },
            "population": {
                "type": "integer"
            },
        },
        "required": ["name", "population"],
    }

    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant.",
        },
        {
            "role": "user",
            "content": "Give me the information of the capital of France in the JSON format.",
        },
    ]

    prompt = llm.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    output = llm.generate(
        prompt,
        sampling_params=SamplingParams(max_tokens=256, guided_decoding=GuidedDecodingParams(json=json_schema)),
    )
    print(output.outputs[0].text)
The output would look like:
{
    "name": "Paris",
    "population": 2145206
}
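Because decoding is constrained by the schema, the text can be loaded straight back into a pydantic model. A minimal sketch, reusing the hypothetical CityInfo model from the online example:

from pydantic import BaseModel, Field

class CityInfo(BaseModel):
    name: str = Field(pattern=r"^[\w]+$")
    population: int

# Validation cannot fail if decoding respected the schema.
info = CityInfo.model_validate_json(output.outputs[0].text)
print(info.name, info.population)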
Regular expression
Create a GuidedDecodingParams with the regex field set to a regular expression, use it to construct SamplingParams, and then pass that to llm.generate or llm.generate_async.
from tensorrt_llm import LLM
from tensorrt_llm.sampling_params import SamplingParams, GuidedDecodingParams

if __name__ == "__main__":
    llm = LLM("nvidia/Llama-3.1-8B-Instruct-FP8", guided_decoding_backend="xgrammar")

    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant.",
        },
        {
            "role": "user",
            "content": "What is the capital of France?",
        },
    ]

    prompt = llm.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    output = llm.generate(
        prompt,
        sampling_params=SamplingParams(max_tokens=256, guided_decoding=GuidedDecodingParams(regex="(Paris|London)")),
    )
    print(output.outputs[0].text)
The output would look like:
Paris
EBNF grammar
Create a GuidedDecodingParams with the grammar field set to an EBNF grammar, use it to construct SamplingParams, and then pass that to llm.generate or llm.generate_async.
from tensorrt_llm import LLM
from tensorrt_llm.sampling_params import SamplingParams, GuidedDecodingParams

if __name__ == "__main__":
    llm = LLM("nvidia/Llama-3.1-8B-Instruct-FP8", guided_decoding_backend="xgrammar")

    ebnf_grammar = """root ::= description
city ::= "London" | "Paris" | "Berlin" | "Rome"
description ::= city " is " status
status ::= "the capital of " country
country ::= "England" | "France" | "Germany" | "Italy"
"""

    messages = [
        {
            "role": "system",
            "content": "You are a helpful geography bot.",
        },
        {
            "role": "user",
            "content": "Give me the information of the capital of France.",
        },
    ]

    prompt = llm.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    output = llm.generate(
        prompt,
        sampling_params=SamplingParams(max_tokens=256, guided_decoding=GuidedDecodingParams(grammar=ebnf_grammar)),
    )
    print(output.outputs[0].text)
The output would look like:
Paris is the capital of France
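The LLM API also accepts a batch of prompts. A sketch continuing from the example above, assuming your version supports a list of per-request SamplingParams (check the API reference), mixing guided and unguided generation:

# Assumption: sampling_params may be a list matching the prompts one-to-one.
outputs = llm.generate(
    [prompt, prompt],
    sampling_params=[
        SamplingParams(max_tokens=256, guided_decoding=GuidedDecodingParams(grammar=ebnf_grammar)),
        SamplingParams(max_tokens=256),  # unconstrained, for comparison
    ],
)
for out in outputs:
    print(out.outputs[0].text)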
Structural tag
Create a GuidedDecodingParams with the structural_tag field set to a structural tag JSON string (hence the json.dumps call below), use it to construct SamplingParams, and then pass that to llm.generate or llm.generate_async.
Structural tags are supported by the XGrammar backend only. They are a powerful and flexible way to express LLM output constraints; see the structural tag usage for a comprehensive tutorial. Below is an example of function calling with a customized function-call format for Llama-3.1-8B-Instruct.
import json

from tensorrt_llm import LLM
from tensorrt_llm.sampling_params import SamplingParams, GuidedDecodingParams

if __name__ == "__main__":
    llm = LLM("nvidia/Llama-3.1-8B-Instruct-FP8", guided_decoding_backend="xgrammar")

    tool_get_current_weather = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'San Francisco'",
                    },
                    "state": {
                        "type": "string",
                        "description": "The two-letter abbreviation for the state that the city is in, e.g. 'CA' which would mean 'California'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    }

    tool_get_current_date = {
        "type": "function",
        "function": {
            "name": "get_current_date",
            "description": "Get the current date and time for a given timezone",
            "parameters": {
                "type": "object",
                "properties": {
                    "timezone": {
                        "type": "string",
                        "description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'",
                    }
                },
                "required": ["timezone"],
            },
        },
    }

    system_prompt = f"""# Tool Instructions
- Always execute python code in messages that you share.
- When looking for real-time information, use relevant functions if available, else fall back to brave_search.
You have access to the following functions:
Use the function 'get_current_weather' to: Get the current weather in a given location
{tool_get_current_weather["function"]}
Use the function 'get_current_date' to: Get the current date and time for a given timezone
{tool_get_current_date["function"]}
If you choose to call a function, ONLY reply in the following format:
<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}
where
start_tag => `<function`
parameters => a JSON dict with the function argument name as key and function argument value as value.
end_tag => `</function>`
Here is an example,
<function=example_function_name>{{"example_name": "example_value"}}</function>
Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query
You are a helpful assistant."""

    user_prompt = "You are in New York. Please get the current date and time, and the weather."

    structural_tag = {
        "type": "structural_tag",
        "format": {
            "type": "triggered_tags",
            "triggers": ["<function="],
            "tags": [
                {
                    "begin": "<function=get_current_weather>",
                    "content": {
                        "type": "json_schema",
                        "json_schema": tool_get_current_weather["function"]["parameters"],
                    },
                    "end": "</function>",
                },
                {
                    "begin": "<function=get_current_date>",
                    "content": {
                        "type": "json_schema",
                        "json_schema": tool_get_current_date["function"]["parameters"],
                    },
                    "end": "</function>",
                },
            ],
        },
    }

    messages = [
        {
            "role": "system",
            "content": system_prompt,
        },
        {
            "role": "user",
            "content": user_prompt,
        },
    ]

    prompt = llm.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    output = llm.generate(
        prompt,
        sampling_params=SamplingParams(max_tokens=256, guided_decoding=GuidedDecodingParams(structural_tag=json.dumps(structural_tag))),
    )
    print(output.outputs[0].text)
The output would look like:
<function=get_current_date>{"timezone": "America/New_York"}</function>
<function=get_current_weather>{"city": "New York", "state": "NY", "unit": "fahrenheit"}</function>