Control generated text using a logits post-processor

Source: https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llm-api/llm_logits_processor.py

 1### Control generated text using logits post processor
 2import typing as tp
 3
 4import torch
 5
 6from tensorrt_llm import LLM, SamplingParams
 7
 8
 9# Define the logits post-processor callback. This simple callback will output
10# a specific token at each step irrespective of prompt.
11# Refer to ../bindings/executor/example_logits_processor.py for a more
12# sophisticated callback that generates JSON structured output.
13def logits_post_processor(req_id: int, logits: torch.Tensor,
14                          ids: tp.List[tp.List[int]], stream_ptr: int,
15                          client_id: tp.Optional[int]):
16    target_token_id = 42
17    with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)):
18        logits[:] = float("-inf")
19        logits[..., target_token_id] = 0
20
21
def main():
    """Demonstrate per-request selection of a logits post-processor.

    The callback is registered with the ``LLM`` under the name
    ``'my_logits_pp'``. Each request then opts in (or not) by naming it in
    its ``SamplingParams``; here only the prompt at index 1 opts in.
    """
    # The map may hold several named callbacks; requests pick one by name.
    callbacks = {"my_logits_pp": logits_post_processor}
    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
              logits_post_processor_map=callbacks)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
    ]

    for idx, prompt in enumerate(prompts):
        sampling_kwargs = {"temperature": 0.8, "top_p": 0.95}
        # Odd indices are routed through the callback; a request selects one
        # of the callbacks that were registered with the LLM above.
        if idx % 2:
            sampling_kwargs["logits_post_processor_name"] = 'my_logits_pp'
        sampling_params = SamplingParams(**sampling_kwargs)

        outputs = llm.generate([prompt], sampling_params)
        for output in outputs:
            print(
                f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
            )

    # Got output like
    # Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
    # Prompt: 'The president of the United States is', Generated text: "''''''''''''''''''''''''''''''''"
54
55
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()