PyTorch Backend
Note: This feature is currently experimental, and the related API is subject to change in future versions.
To enhance system usability and improve developer efficiency, TensorRT-LLM provides a new experimental backend based on PyTorch.
The PyTorch backend of TensorRT-LLM is available in version 0.17 and later. You can try it by importing tensorrt_llm._torch.
Quick Start
Here is a simple example showing how to use the tensorrt_llm._torch.LLM API with a Llama model.
import argparse

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir',
                        type=str,
                        default='meta-llama/Llama-3.1-8B-Instruct')
    parser.add_argument('--tp_size', type=int, default=1)
    parser.add_argument('--enable_overlap_scheduler',
                        default=False,
                        action='store_true')
    parser.add_argument('--enable_chunked_prefill',
                        default=False,
                        action='store_true')
    parser.add_argument('--kv_cache_dtype', type=str, default='auto')
    args = parser.parse_args()
    return args


def main():
    args = parse_arguments()

    pytorch_config = PyTorchConfig(
        enable_overlap_scheduler=args.enable_overlap_scheduler,
        kv_cache_dtype=args.kv_cache_dtype)
    llm = LLM(model=args.model_dir,
              tensor_parallel_size=args.tp_size,
              enable_chunked_prefill=args.enable_chunked_prefill,
              pytorch_backend_config=pytorch_config)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(max_tokens=32)

    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    main()
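The same setup can be reduced to a few lines when used interactively. The sketch below enables the overlap scheduler and chunked prefill directly instead of through command-line flags; it assumes two GPUs are available and that 'fp8' is an accepted value for kv_cache_dtype in your version, so treat those values as illustrative rather than definitive.
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

# Backend-specific options; 'fp8' for the KV cache is an assumed example value.
pytorch_config = PyTorchConfig(enable_overlap_scheduler=True,
                               kv_cache_dtype='fp8')

# Tensor parallelism across 2 GPUs (assumption); drop tensor_parallel_size on a single GPU.
llm = LLM(model='meta-llama/Llama-3.1-8B-Instruct',
          tensor_parallel_size=2,
          enable_chunked_prefill=True,
          pytorch_backend_config=pytorch_config)

for output in llm.generate(["The future of AI is"], SamplingParams(max_tokens=32)):
    print(output.outputs[0].text)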
Quantization
The PyTorch backend supports FP8 and NVFP4 quantization. You can pass quantized models hosted on the HF model hub that are generated by TensorRT Model Optimizer.
from tensorrt_llm._torch import LLM
llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8')
llm.generate("Hello, my name is")
Alternatively, you can run the following commands to quantize a model yourself:
git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git
cd TensorRT-Model-Optimizer/examples/llm_ptq
scripts/huggingface_example.sh --model <huggingface_model_card> --quant fp8 --export_fmt hf
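Once the export finishes, the resulting Hugging Face-format checkpoint directory can be passed to the LLM API just like a Hub model. This is a minimal sketch; the path below is a placeholder for wherever the script writes its export, which depends on your ModelOpt setup.
from tensorrt_llm._torch import LLM

# Placeholder path: replace with the directory written by the export step above.
llm = LLM(model='/path/to/exported-fp8-checkpoint')
llm.generate("Hello, my name is")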