### Generate text with guided decoding
# Source: https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llm-api/llm_guided_decoding.py
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import BuildConfig, GuidedDecodingParams

def main():
    """Demonstrate guided (JSON-schema constrained) decoding with the LLM API.

    Builds a TinyLlama engine with paged context FMHA enabled, then generates
    from the same chat prompt twice: once unguided, and once constrained to a
    JSON schema via the xgrammar guided-decoding backend.
    """

    # TODO(jiahanc): Clean up build_config when use_paged_context_fmha is by default enabled
    build_config = BuildConfig()
    build_config.plugin_config.use_paged_context_fmha = True

    # Specify the guided decoding backend; xgrammar is supported currently.
    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
              build_config=build_config,
              guided_decoding_backend='xgrammar')

    # An example from json-mode-eval
    schema = '{"title": "WirelessAccessPoint", "type": "object", "properties": {"ssid": {"title": "SSID", "type": "string"}, "securityProtocol": {"title": "SecurityProtocol", "type": "string"}, "bandwidth": {"title": "Bandwidth", "type": "string"}}, "required": ["ssid", "securityProtocol", "bandwidth"]}'

    # Chat-style messages; the system turn embeds the schema the model must follow.
    prompt = [{
        'role':
        'system',
        'content':
        "You are a helpful assistant that answers in JSON. Here's the json schema you must adhere to:\n<schema>\n{'title': 'WirelessAccessPoint', 'type': 'object', 'properties': {'ssid': {'title': 'SSID', 'type': 'string'}, 'securityProtocol': {'title': 'SecurityProtocol', 'type': 'string'}, 'bandwidth': {'title': 'Bandwidth', 'type': 'string'}}, 'required': ['ssid', 'securityProtocol', 'bandwidth']}\n</schema>\n"
    }, {
        'role':
        'user',
        'content':
        "I'm currently configuring a wireless access point for our office network and I need to generate a JSON object that accurately represents its settings. The access point's SSID should be 'OfficeNetSecure', it uses WPA2-Enterprise as its security protocol, and it's capable of a bandwidth of up to 1300 Mbps on the 5 GHz band. This JSON object will be used to document our network configurations and to automate the setup process for additional access points in the future. Please provide a JSON object that includes these details."
    }]
    # Render the messages into the model's chat template as a single string.
    prompt = llm.tokenizer.apply_chat_template(prompt, tokenize=False)
    print(f"Prompt: {prompt!r}")

    # Unguided generation: the model may emit arbitrary text.
    output = llm.generate(prompt, sampling_params=SamplingParams(max_tokens=50))
    print(f"Generated text (unguided): {output.outputs[0].text!r}")

    # Guided generation: decoding is constrained so the output satisfies the schema.
    output = llm.generate(
        prompt,
        sampling_params=SamplingParams(
            max_tokens=50, guided_decoding=GuidedDecodingParams(json=schema)))
    print(f"Generated text (guided): {output.outputs[0].text!r}")

    # Got output like
    # Prompt: "<|system|>\nYou are a helpful assistant that answers in JSON. Here's the json schema you must adhere to:\n<schema>\n{'title': 'WirelessAccessPoint', 'type': 'object', 'properties': {'ssid': {'title': 'SSID', 'type': 'string'}, 'securityProtocol': {'title': 'SecurityProtocol', 'type': 'string'}, 'bandwidth': {'title': 'Bandwidth', 'type': 'string'}}, 'required': ['ssid', 'securityProtocol', 'bandwidth']}\n</schema>\n</s>\n<|user|>\nI'm currently configuring a wireless access point for our office network and I need to generate a JSON object that accurately represents its settings. The access point's SSID should be 'OfficeNetSecure', it uses WPA2-Enterprise as its security protocol, and it's capable of a bandwidth of up to 1300 Mbps on the 5 GHz band. This JSON object will be used to document our network configurations and to automate the setup process for additional access points in the future. Please provide a JSON object that includes these details.</s>\n"
    # Generated text (unguided): '<|assistant|>\nHere\'s a JSON object that accurately represents the settings of a wireless access point for our office network:\n\n```json\n{\n "title": "WirelessAccessPoint",\n "'
    # Generated text (guided): '{"ssid": "OfficeNetSecure", "securityProtocol": "WPA2-Enterprise", "bandwidth": "1300 Mbps"}'

# Script entry point: run the demo only when executed directly, not on import.
if __name__ == '__main__':
    main()