PyTorch Backend#
Note
This feature is currently experimental, and the related API is subject to change in future versions.
To enhance usability and improve developer efficiency, TensorRT-LLM provides a new experimental backend based on PyTorch.
The PyTorch backend of TensorRT-LLM is available in version 0.17 and later. You can try it by importing tensorrt_llm._torch.
Quick Start#
Here is a simple example showing how to use the tensorrt_llm._torch.LLM API with a Llama model.
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM


def main():
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(max_tokens=32)

    llm = LLM(model='TinyLlama/TinyLlama-1.1B-Chat-v1.0')
    outputs = llm.generate(prompts, sampling_params)

    for i, output in enumerate(outputs):
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    main()
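The LLM constructor also accepts common deployment options. As a minimal sketch, assuming a locally downloaded Hugging Face checkpoint directory (the path below is hypothetical) and the tensor_parallel_size argument for multi-GPU tensor parallelism, the same workflow looks like this:
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM

# Hypothetical local checkpoint path; assumes 2 visible GPUs for tensor parallelism.
llm = LLM(model='/path/to/local/llama-checkpoint',
          tensor_parallel_size=2)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)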
Quantization#
The PyTorch backend supports FP8 and NVFP4 quantization. You can pass quantized models from the Hugging Face model hub that were generated by NVIDIA TensorRT Model Optimizer.
from tensorrt_llm._torch import LLM
llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8')
llm.generate("Hello, my name is")
Alternatively, you can run the following commands to quantize a model yourself:
git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git
cd TensorRT-Model-Optimizer/examples/llm_ptq
scripts/huggingface_example.sh --model <huggingface_model_card> --quant fp8 --export_fmt hf
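The export step above writes a Hugging Face-style quantized checkpoint to a local directory. As a minimal sketch, assuming a hypothetical export path, you can then point the LLM API at that directory instead of a model card name:
from tensorrt_llm._torch import LLM

# Hypothetical path: use the export directory produced by huggingface_example.sh.
llm = LLM(model='/path/to/exported/fp8_checkpoint')
llm.generate("Hello, my name is")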
Sampling#
The PyTorch backend supports most of the sampling features supported by the C++ backend, such as temperature, top-k and top-p sampling, stop words, bad words, penalties, context and generation logits, and log probs.
To use these features, enable the enable_trtllm_sampler option in the LLM class and pass a SamplingParams object with the desired options. The following example prepares two identical prompts that will give different results due to the sampling parameters chosen:
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM

llm = LLM(model='nvidia/Llama-3.1-8B-Instruct-FP8',
          enable_trtllm_sampler=True)
sampling_params = SamplingParams(
    temperature=1.0,
    top_k=8,
    top_p=0.5,
)
llm.generate(["Hello, my name is",
              "Hello, my name is"], sampling_params)
When using speculative decoders such as MTP or Eagle-3, the enable_trtllm_sampler option is not yet supported, so the subset of available sampling options is more restricted.
Developer Guide#
Key Components#
Known Issues#
The PyTorch workflow on SBSA is incompatible with bare metal environments like Ubuntu 24.04. Please use the PyTorch NGC Container for optimal support on SBSA platforms.