Visual Generation (Diffusion Models) [Beta]#

Background and Motivation#

Visual generation models based on diffusion transformers (DiT) have become the standard for high-quality image and video synthesis. These models iteratively denoise latent representations through a learned transformer backbone, then decode the final latents with a VAE to produce pixels. As model sizes and output resolutions grow, efficient inference becomes critical — demanding multi-GPU parallelism, weight quantization, and runtime caching to achieve practical throughput and latency.

TensorRT-LLM VisualGen module provides a unified inference stack for diffusion models. Key capabilities include (subject to change as the feature matures):

  • A shared pipeline abstraction for diffusion model families, covering the denoising loop, guidance strategies, and component loading.

  • Pluggable attention backends.

  • Quantization support (dynamic and static) using the ModelOpt configuration format.

  • Multi-GPU parallelism strategies.

  • TeaCache — a runtime caching optimization for the transformer backbone.

  • trtllm-serve integration with OpenAI-compatible API endpoints.

Note: This is the initial release of TensorRT-LLM VisualGen. APIs, supported models, and optimization options are actively evolving and may change in future releases.

Quick Start#

Prerequisites#

pip install -r requirements-dev.txt
pip install git+https://github.com/huggingface/diffusers.git
pip install av

Python API#

The example scripts under examples/visual_gen/ demonstrate direct Python usage. For Wan2.1 text-to-video generation:

cd examples/visual_gen

python visual_gen_wan_t2v.py \
    --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
    --prompt "A cute cat playing piano" \
    --height 480 --width 832 --num_frames 33 \
    --output_path output.mp4

Run python visual_gen_wan_t2v.py --help for the full list of arguments. Key options control resolution, denoising steps, quantization mode, attention backend, parallelism, and TeaCache settings.

Usage with trtllm-serve#

The trtllm-serve command automatically detects diffusion models (by the presence of model_index.json) and launches an OpenAI-compatible visual generation server.

1. Create a YAML configuration file:

# wan_config.yml
linear:
  type: default
teacache:
  enable_teacache: true
  teacache_thresh: 0.2
parallel:
  dit_cfg_size: 1
  dit_ulysses_size: 1

2. Launch the server:

trtllm-serve Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
    --extra_visual_gen_options wan_config.yml

3. Send requests using curl or any OpenAI-compatible client:

Synchronous video generation:

curl -X POST "http://localhost:8000/v1/videos/generations" \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "A cool cat on a motorcycle in the night",
    "seconds": 4.0,
    "fps": 24,
    "size": "480x832"
  }' -o output.mp4

Asynchronous video generation:

# Submit the job
curl -X POST "http://localhost:8000/v1/videos" \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "A cool cat on a motorcycle in the night",
    "seconds": 4.0,
    "fps": 24,
    "size": "480x832"
  }'
# Returns: {"id": "<video_id>", "status": "processing", ...}

# Poll for status
curl -X GET "http://localhost:8000/v1/videos/<video_id>"

# Download when complete
curl -X GET "http://localhost:8000/v1/videos/<video_id>/content" -o output.mp4

The server exposes OpenAI-compatible endpoints for image generation (/v1/images/generations), video generation (/v1/videos, /v1/videos/generations), video management, and standard health/model info endpoints.

The --extra_visual_gen_options YAML file configures quantization (linear), TeaCache (teacache), and parallelism (parallel). See examples/visual_gen/serve/configs/ for reference configurations.

Quantization#

TensorRT-LLM VisualGen supports both dynamic quantization (on-the-fly at weight-loading time from BF16 checkpoints) and static quantization (loading pre-quantized checkpoints with embedded scales). Both modes use the same ModelOpt quantization_config format.

Quick start — dynamic quantization via --linear_type:

python visual_gen_wan_t2v.py \
    --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
    --prompt "A cute cat playing piano" \
    --linear_type trtllm-fp8-per-tensor \
    --output_path output_fp8.mp4

Supported --linear_type values: default (BF16/FP16), trtllm-fp8-per-tensor, trtllm-fp8-blockwise, svd-nvfp4.

ModelOpt quantization_config format:

Both dynamic and static quantization use the ModelOpt quantization_config format — the same format found in a model’s config.json under the quantization_config field. This config can be passed as a dict to DiffusionArgs.quant_config when constructing the pipeline programmatically:

from tensorrt_llm._torch.visual_gen.config import DiffusionArgs

args = DiffusionArgs(
    checkpoint_path="/path/to/model",
    quant_config={"quant_algo": "FP8", "dynamic": True},  # dynamic FP8
)

The --linear_type CLI flag is a convenience shorthand that maps to these configs internally (e.g., trtllm-fp8-per-tensor{"quant_algo": "FP8", "dynamic": True}).

Key fields: "dynamic" controls load-time quantization (true) vs pre-quantized checkpoint (false); "ignore" excludes specific modules from quantization.

Developer Guide#

This section describes the TensorRT-LLM VisualGen module architecture and guides developers on how to add support for new diffusion model families.

Architecture Overview#

The VisualGen module lives under tensorrt_llm._torch.visual_gen. At a high level, the flow is:

  1. Config — User-facing DiffusionArgs (CLI / YAML) is merged with checkpoint metadata into DiffusionModelConfig.

  2. Pipeline creation & loadingAutoPipeline detects the model type from model_index.json, instantiates the matching BasePipeline subclass, and loads weights (with optional dynamic quantization) and standard components (VAE, text encoder, tokenizer, scheduler).

  3. ExecutionDiffusionExecutor coordinates multi-GPU inference via worker processes.

Note: Internal module structure is subject to change. Refer to inline docstrings in tensorrt_llm/_torch/visual_gen/ for the latest details.

Implementing a New Diffusion Model#

Adding a new model (e.g., a hypothetical “MyDiT”) requires four steps. The framework handles weight loading, parallelism, quantization, and serving automatically once the pipeline is registered.

1. Create the Transformer Module#

Create the DiT backbone in tensorrt_llm/_torch/visual_gen/models/mydit/transformer_mydit.py. It should be an nn.Module that:

  • Uses existing modules (e.g., Attention with configurable attention backend, Linear for builtin linear ops) wherever possible.

  • Implements load_weights(weights: Dict[str, torch.Tensor]) to map checkpoint weight names to module parameters.

2. Create the Pipeline Class#

Create a pipeline class extending BasePipeline in tensorrt_llm/_torch/visual_gen/models/mydit/. Override methods for transformer initialization, component loading, and inference. BasePipeline provides the denoising loop, CFG handling, and TeaCache integration — your pipeline only needs to implement model-specific logic. See WanPipeline for a reference implementation.

3. Register the Pipeline#

Use the @register_pipeline("MyDiTPipeline") decorator on your pipeline class to register it in the global PIPELINE_REGISTRY. Make sure to export it from models/__init__.py.

4. Update AutoPipeline Detection#

In pipeline_registry.py, add detection logic for your model’s _class_name in model_index.json.

After these steps, the framework automatically handles:

  • Weight loading with optional dynamic quantization via PipelineLoader

  • Multi-GPU execution via DiffusionExecutor

  • TeaCache integration (if you call self._setup_teacache() in post_load_weights())

  • Serving via trtllm-serve with the full endpoint set

Summary and Future Work#

Current Status#

Supported models: Wan2.1 and Wan2.2 families (text-to-video, image-to-video; 1.3B and 14B variants).

Supported features:

Feature

Status

Multi-GPU Parallelism

CFG parallel, Ulysses sequence parallel (more strategies planned)

TeaCache

Caches transformer outputs when timestep embeddings change slowly

Quantization

Dynamic (on-the-fly from BF16) and static (pre-quantized checkpoints), both via ModelOpt quantization_config format

Attention Backends

Vanilla (torch SDPA) and TRT-LLM optimized fused kernels

trtllm-serve

OpenAI-compatible endpoints for image/video generation (sync + async)

Future Work#

  • Additional model support: Extend to more diffusion model families.

  • More attention backends: Support for additional attention backends.

  • Advanced parallelism: Additional parallelism strategies for larger models and higher resolutions.

  • Serving enhancements: Improved throughput and user experience for production serving workloads.