Control generated text using logits processor#
Source NVIDIA/TensorRT-LLM.
1from typing import List, Optional
2
3import torch
4from transformers import PreTrainedTokenizer
5
6from tensorrt_llm import LLM
7from tensorrt_llm.sampling_params import LogitsProcessor, SamplingParams
8
9
10def text_to_token(tokenizer: PreTrainedTokenizer, text: str, last: bool):
11 tokens = tokenizer.encode(text, add_special_tokens=False)
12
13 max_token_count = 1
14 bos_token_added = getattr(tokenizer, 'bos_token', None) and getattr(
15 tokenizer, 'bos_token_id', None) in tokens
16 prefix_token_added = getattr(tokenizer, 'add_prefix_space',
17 None) is not False
18 if bos_token_added or prefix_token_added:
19 max_token_count = 2
20
21 if not last and len(tokens) > max_token_count:
22 raise Exception(
23 f"Can't convert {text} to token. It has {len(tokens)} tokens.")
24
25 return tokens[-1]
26
27
28# The recommended way to create a customized logits processor:
29# * Subclass LogitsProcessor and implement the processing logics in the __call__ method.
30# * Create an instance and pass to SamplingParams.
31# More LogitsProcessors references can be found at https://github.com/NVIDIA/logits-processor-zoo.
32class GenLengthLogitsProcessor(LogitsProcessor):
33 """
34 A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token
35 based on the length of the generated sequence, encouraging or discouraging shorter answers.
36 WARNING: Create a new object before every model.generate call since token_count is accumulated.
37
38 Parameters
39 ----------
40 tokenizer: The tokenizer used by the LLM.
41 boost_factor (float): A factor to boost the likelihood of the EOS token as the sequence length increases.
42 Suggested value range is [-1.0, 1.0]. Negative values are used for the opposite effect.
43 p (int, optional): The power to which the token count is raised when computing the boost value. Default is 2.
44 complete_sentences (bool, optional): If True, boosts EOS token likelihood only when the last token is a full stop
45 or a new line. Default is False.
46
47 """
48
49 def __init__(self,
50 tokenizer,
51 boost_factor: float,
52 p: int = 2,
53 complete_sentences: bool = False):
54 self.eos_token = tokenizer.eos_token_id
55 self.boost_factor = boost_factor
56 self.p = p
57 self.token_count = 0
58 self.full_stop_token = text_to_token(tokenizer,
59 "It is a sentence.",
60 last=True)
61 self.new_line_token = text_to_token(tokenizer,
62 "It is a new line\n",
63 last=True)
64 self.complete_sentences = complete_sentences
65
66 def __call__(self, req_ids: int, logits: torch.Tensor, ids: List[List[int]],
67 stream_ptr, client_id: Optional[int]):
68 boost_val = self.boost_factor * (self.token_count**self.p) / (10**
69 self.p)
70
71 stream = None if stream_ptr is None else torch.cuda.ExternalStream(
72 stream_ptr)
73
74 with torch.cuda.stream(stream):
75 ids = torch.LongTensor(ids).to(logits.device, non_blocking=True)
76
77 if self.complete_sentences:
78 enabled = (ids[:, -1] == self.full_stop_token) | (
79 ids[:, -1] == self.new_line_token)
80 logits[:, :, self.eos_token] += enabled * boost_val
81 else:
82 logits[:, :, self.eos_token] += boost_val
83
84 self.token_count += 1
85
86
87def main():
88
89 llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
90
91 # Sample prompts
92 prompts = [
93 "The future of AI is",
94 "The future of AI is",
95 ]
96
97 # Generate text
98 for prompt_id, prompt in enumerate(prompts):
99 if prompt_id % 2 == 0:
100 # Without logit processor
101 sampling_params = SamplingParams(top_p=1, max_tokens=200)
102 else:
103 # Each prompt can be specified with a logits processor at runtime
104 sampling_params = SamplingParams(
105 temperature=0.8,
106 top_p=0.95,
107 logits_processor=GenLengthLogitsProcessor(
108 llm.tokenizer, boost_factor=1, complete_sentences=True))
109
110 output = llm.generate(prompt, sampling_params)
111 print(
112 f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
113 )
114
115 # Got output like:
116 # Prompt (original): "bright, and it's not just for big companies. Small businesses can also benefit from AI technology. Here are some ways:\n\n1. Improved customer service: AI can help businesses provide better customer service by analyzing customer data and providing personalized recommendations.
117 # This can help businesses improve their customer experience and increase customer loyalty.\n\n2. Increased productivity: AI can help businesses automate repetitive tasks, freeing up employees to focus on more complex tasks. This can
118 # help businesses increase productivity and reduce costs.\n\n3. Enhanced marketing: AI can help businesses create more personalized marketing campaigns by analyzing customer data and targeting specific audiences. This can help businesses
119 # increase their marketing ROI and drive more sales.\n\n4. Improved supply chain management: AI can help businesses optimize their supply chain by analyzing data on demand,"'
120 #
121 # Prompt (with GenLenthLogitsProcesor): "bright, and it's not just for big companies. Small businesses can also benefit from AI technology."
122
123
124if __name__ == '__main__':
125 main()