PyTorch Backend

Note

This feature is currently experimental, and the related API is subject to change in future versions.

To improve usability and developer efficiency, TensorRT-LLM introduces a new experimental backend based on PyTorch.

The PyTorch backend of TensorRT-LLM is available in version 0.17 and later. You can try it by importing tensorrt_llm._torch.
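As a quick sanity check (a minimal sketch, assuming a TensorRT-LLM 0.17+ installation), you can verify that the backend is importable:

import tensorrt_llm
from tensorrt_llm._torch import LLM  # raises ImportError on versions before 0.17

print(tensorrt_llm.__version__)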

Quick Start

Here is a simple example showing how to use the tensorrt_llm._torch.LLM API with a Llama model.

import argparse

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir',
                        type=str,
                        default='meta-llama/Llama-3.1-8B-Instruct')
    parser.add_argument('--tp_size', type=int, default=1)
    parser.add_argument('--enable_overlap_scheduler',
                        default=False,
                        action='store_true')
    parser.add_argument('--enable_chunked_prefill',
                        default=False,
                        action='store_true')
    parser.add_argument('--kv_cache_dtype', type=str, default='auto')
    args = parser.parse_args()
    return args


def main():
    args = parse_arguments()

    pytorch_config = PyTorchConfig(
        enable_overlap_scheduler=args.enable_overlap_scheduler,
        kv_cache_dtype=args.kv_cache_dtype)
    llm = LLM(model=args.model_dir,
              tensor_parallel_size=args.tp_size,
              enable_chunked_prefill=args.enable_chunked_prefill,
              pytorch_backend_config=pytorch_config)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(max_tokens=32)

    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    main()
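Assuming the script above is saved as quickstart.py (an illustrative file name), it can be run like this; note that --tp_size 2 requires two GPUs:

python quickstart.py --model_dir meta-llama/Llama-3.1-8B-Instruct --tp_size 2 --enable_overlap_scheduler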

Quantization

The PyTorch backend supports FP8 and NVFP4 quantization. You can pass a quantized model from the HF model hub that was generated by TensorRT Model Optimizer:

from tensorrt_llm._torch import LLM
llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8')
llm.generate("Hello, my name is")
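NVFP4 checkpoints are loaded the same way. The model id below is illustrative; substitute an actual NVFP4 checkpoint from the HF model hub:

from tensorrt_llm._torch import LLM
llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP4')  # illustrative NVFP4 model id
llm.generate("Hello, my name is")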

Alternatively, you can run the following commands to quantize a model yourself:

git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git
cd TensorRT-Model-Optimizer/examples/llm_ptq
scripts/huggingface_example.sh --model <huggingface_model_card> --quant fp8 --export_fmt hf
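The script exports a checkpoint in HF format. The export directory (shown below as a hypothetical path) can then be passed to the LLM API just like a model hub id:

from tensorrt_llm._torch import LLM

# Hypothetical path; use the export directory produced by huggingface_example.sh.
llm = LLM(model='/path/to/exported-fp8-checkpoint')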

Developer Guide

Key Components