Visual Generation (Prototype)#
Note
This feature is in prototype stage. APIs, supported models, and optimization options are actively evolving and may change in future releases.
Background#
Visual generation models based on diffusion transformers (DiT) have become the standard for high-quality image and video synthesis. These models iteratively denoise latent representations through a learned transformer backbone, then decode the final latents with a VAE to produce pixels.
TensorRT-LLM VisualGen provides a unified inference stack for diffusion models, with a pipeline architecture separate from the LLM inference path. Key capabilities include:
A shared pipeline abstraction covering the denoising loop, guidance strategies, and component loading.
Pluggable attention backends (PyTorch SDPA and TRT-LLM optimized kernels).
Quantization support (dynamic and static) using the ModelOpt configuration format.
Multi-GPU parallelism (CFG parallel, Ulysses sequence parallel).
TeaCache — a runtime caching optimization that skips transformer steps when timestep embeddings change slowly.
trtllm-serve integration with OpenAI-compatible API endpoints for image and video generation.
Supported Models#
| HuggingFace Model ID | Tasks |
|---|---|
|  | Text-to-Image |
|  | Text-to-Image |
|  | Text-to-Video |
|  | Text-to-Video |
|  | Image-to-Video |
|  | Image-to-Video |
|  | Text-to-Video |
|  | Image-to-Video |
Models are auto-detected from the model_index.json file in the checkpoint directory. The AutoPipeline registry selects the appropriate pipeline class automatically.
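As a sketch of what this detection involves: a diffusers-style checkpoint carries a `model_index.json` whose `_class_name` field names the pipeline class. The helper below is hypothetical (the real lookup lives inside the `AutoPipeline` registry), but it illustrates the mechanism under that assumption.

```python
import json
from pathlib import Path


def detect_pipeline_class(checkpoint_dir: str) -> str:
    """Read a diffusers-style model_index.json and return its _class_name.

    Hypothetical helper for illustration only; TensorRT-LLM's AutoPipeline
    performs this lookup internally and maps the name to a pipeline class.
    """
    index_path = Path(checkpoint_dir) / "model_index.json"
    with open(index_path) as f:
        index = json.load(f)
    return index["_class_name"]
```

For a Wan checkpoint, for example, this would return the pipeline class name recorded by diffusers when the checkpoint was saved.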
Feature Matrix#
Quick Start#
Here is a simple example to generate a video with Wan 2.1:
```python
#! /usr/bin/env python
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from tensorrt_llm import VisualGen, VisualGenParams
from tensorrt_llm.serve.media_storage import MediaStorage


def main():
    visual_gen = VisualGen(model_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
    params = VisualGenParams(
        height=480,
        width=832,
        num_frames=81,
        guidance_scale=5.0,
        num_inference_steps=50,
        seed=42,
    )
    output = visual_gen.generate(
        inputs="A cat sitting on a windowsill",
        params=params,
    )
    MediaStorage.save_video(output.video, "output.avi", frame_rate=params.frame_rate)


if __name__ == "__main__":
    main()
```
To learn more about VisualGen, see examples/visual_gen/ for additional examples, including text-to-image, image-to-video, and batch generation.
Usage with trtllm-serve#
The trtllm-serve command automatically detects diffusion models (by the presence of model_index.json) and launches an OpenAI-compatible visual generation server with image and video generation endpoints.
See examples/visual_gen/serve/ for server launch instructions, example configurations, and API usage.
Serving Endpoints#
When served via trtllm-serve, the following OpenAI-compatible endpoints are available:
| Endpoint | Method | Purpose |
|---|---|---|
|  | POST | Synchronous image generation |
|  | POST | Image editing |
|  | POST | Asynchronous video generation |
|  | POST | Synchronous video generation |
|  | GET | Video status / metadata |
|  | GET | Download generated video |
|  | DELETE | Delete generated video |
|  | GET | List all videos |
Optimizations#
Quantization#
VisualGen supports both dynamic quantization (on-the-fly at weight-loading time from BF16 checkpoints) and static quantization (loading pre-quantized checkpoints with embedded scales). Both modes use the ModelOpt quantization_config format.
Dynamic quantization via --linear_type:
```shell
python visual_gen_wan_t2v.py \
    --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
    --prompt "A cute cat playing piano" \
    --linear_type trtllm-fp8-per-tensor \
    --output_path output_fp8.mp4
```
Supported `--linear_type` values: `default` (BF16/FP16), `trtllm-fp8-per-tensor`, `trtllm-fp8-blockwise`, `trtllm-nvfp4`.
Programmatic usage via VisualGenArgs.quant_config:
```python
from tensorrt_llm import VisualGenArgs

args = VisualGenArgs(
    checkpoint_path="/path/to/model",
    quant_config={"quant_algo": "FP8", "dynamic": True},
)
```
TeaCache#
TeaCache caches transformer outputs when timestep embeddings change slowly between denoising steps, skipping redundant computation. Enable with teacache.enable_teacache: true (YAML config). The teacache_thresh parameter controls the similarity threshold.
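A minimal YAML sketch of such a config, using the field names given above (the threshold value shown here is illustrative, not a recommended default):

```yaml
# TeaCache section of a pipeline YAML config.
# enable_teacache and teacache_thresh come from the docs above;
# the 0.2 value is only an example.
teacache:
  enable_teacache: true
  teacache_thresh: 0.2  # skip a transformer step when the embedding change is below this
```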
Multi-GPU Parallelism#
Two parallelism modes can be combined:
- CFG Parallelism (`--cfg_size 2`): splits positive/negative guidance prompts across GPUs.
- Ulysses Parallelism (`--ulysses_size N`): splits the sequence dimension across GPUs for longer sequences.
Total GPU count = cfg_size * ulysses_size.
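A hypothetical launch combining both modes, reusing the quantization example's script: `--cfg_size 2` with `--ulysses_size 2` would occupy 2 × 2 = 4 GPUs.

```shell
# Illustrative only: CFG parallel (2) x Ulysses sequence parallel (2) = 4 GPUs.
python visual_gen_wan_t2v.py \
    --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
    --prompt "A cute cat playing piano" \
    --cfg_size 2 \
    --ulysses_size 2 \
    --output_path output_parallel.mp4
```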
Developer Guide#
Architecture Overview#
The VisualGen module lives under tensorrt_llm._torch.visual_gen. At a high level, the inference flow is:
1. Config: the user-facing `VisualGenArgs` (CLI / YAML) is merged with checkpoint metadata into `DiffusionModelConfig`.
2. Pipeline creation & loading: `AutoPipeline` detects the model type from `model_index.json`, instantiates the matching `BasePipeline` subclass, and loads weights (with optional dynamic quantization) and standard components (VAE, text encoder, tokenizer, scheduler).
3. Execution: `DiffusionExecutor` coordinates multi-GPU inference via worker processes communicating over ZeroMQ IPC.
Key components:
| Component | Location | Role |
|---|---|---|
|  |  | High-level API: manages workers |
|  |  | Worker process: loads pipeline, processes requests via ZeroMQ |
|  |  | Base class: denoising loop, CFG handling, TeaCache, CUDA graph |
|  |  | Factory: auto-detects model type, selects pipeline class |
|  |  | Resolves checkpoint, loads config/weights, creates pipeline |
|  |  | Runtime caching for transformer outputs |
|  |  | Loads transformer weights from safetensors/bin |
VisualGen is a parallel inference subsystem within TensorRT-LLM. It shares low-level primitives (Mapping, QuantConfig, Linear, RMSNorm, ZeroMqQueue, TrtllmAttention) but has its own executor, scheduler (diffusers-based), request types, and pipeline architecture separate from the LLM autoregressive decode path.
Implementing a New Diffusion Model#
Adding a new model (e.g., a hypothetical “MyDiT”) requires four steps. The framework handles weight loading, parallelism, quantization, and serving automatically once the pipeline is registered.
1. Create the Transformer Module#
Create the DiT backbone in tensorrt_llm/_torch/visual_gen/models/mydit/transformer_mydit.py. It should be an nn.Module that:
- Uses existing modules (e.g., `Attention` with a configurable attention backend, `Linear` for built-in linear ops) wherever possible.
- Implements `load_weights(weights: Dict[str, torch.Tensor])` to map checkpoint weight names to module parameters.
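The core of such a `load_weights` implementation is usually a name-remapping step from checkpoint keys to module parameter names. The sketch below illustrates that idea only; the rename rules and key names are invented for this example, and plain values stand in for tensors.

```python
# Illustrative name remapping for load_weights(); the checkpoint and
# module key names below are invented, not MyDiT's real layout.
RENAME_RULES = {
    "blocks.": "transformer_blocks.",  # checkpoint prefix -> module prefix
    "attn.qkv.": "attn.qkv_proj.",     # fused qkv naming difference
}


def remap_weight_names(weights: dict) -> dict:
    """Return a new dict whose keys use the module's parameter names."""
    remapped = {}
    for name, tensor in weights.items():
        for old, new in RENAME_RULES.items():
            name = name.replace(old, new)
        remapped[name] = tensor
    return remapped
```

In a real module, `load_weights` would then copy each remapped tensor into the matching `nn.Parameter`, applying any per-layer transforms (e.g., splitting fused projections) along the way.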
2. Create the Pipeline Class#
Create a pipeline class extending BasePipeline in tensorrt_llm/_torch/visual_gen/models/mydit/. Override methods for transformer initialization, component loading, and inference. BasePipeline provides the denoising loop, CFG handling, and TeaCache integration — your pipeline only needs to implement model-specific logic. See WanPipeline for a reference implementation.
3. Register the Pipeline#
Use the @register_pipeline("MyDiTPipeline") decorator on your pipeline class to register it in the global PIPELINE_REGISTRY. Make sure to export it from models/__init__.py.
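The registry pattern behind this decorator can be sketched in a few lines; the stand-alone `PIPELINE_REGISTRY` and `register_pipeline` below are simplified stand-ins for the real ones inside TensorRT-LLM.

```python
# Minimal sketch of a decorator-based pipeline registry.
# Stand-in for TensorRT-LLM's real PIPELINE_REGISTRY / register_pipeline.
PIPELINE_REGISTRY: dict = {}


def register_pipeline(name: str):
    """Register the decorated class under `name` and return it unchanged."""
    def decorator(cls):
        PIPELINE_REGISTRY[name] = cls
        return cls
    return decorator


@register_pipeline("MyDiTPipeline")
class MyDiTPipeline:
    pass
```

Because the decorator runs at import time, exporting the class from `models/__init__.py` is what guarantees the registration actually happens before `AutoPipeline` consults the registry.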
4. Update AutoPipeline Detection#
In pipeline_registry.py, add detection logic for your model’s _class_name in model_index.json.
After these steps, the framework automatically handles:
- Weight loading with optional dynamic quantization via `PipelineLoader`
- Multi-GPU execution via `DiffusionExecutor`
- TeaCache integration (if you call `self._setup_teacache()` in `post_load_weights()`)
- Serving via `trtllm-serve` with the full endpoint set