API walkthrough
Source: NVIDIA/TensorRT-LLM.
1from tensorrt_llm import VisualGen, VisualGenArgs
2from tensorrt_llm.visual_gen.args import CompilationConfig
3
4
def main() -> None:
    """Walk through the VisualGen API end to end on a Wan text-to-video model.

    Prints the supported model ids, the pipeline-config defaults and the
    ``extra_params`` specs for the chosen model, dumps the request
    parameters as JSON, runs one generation from a fixed prompt, and saves
    the result to ``api_walkthrough_output.avi``.
    """
    # Single source of truth for the model used throughout the walkthrough.
    model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

    # Step 1: enumerate every model id known to the pipeline registry.
    print("\n=== Supported models ===")
    for model_name in VisualGen.supported_models():
        print(f" - {model_name}")

    # Step 2: look up the default ``pipeline_config`` knobs for the model.
    # These are per-architecture runtime settings (Lightricks/LTX-2 exposes
    # ``text_encoder_path``, for instance). The Wan 2.1 1.3B model registers
    # none, so an empty dict comes back and "(none)" is printed instead.
    defaults = VisualGen.pipeline_config(model_id)
    print(f"\n=== Pipeline config defaults for {model_id} ===")
    defaults_shown = defaults or "(none)"
    print(f" {defaults_shown}")

    # Step 3: assemble VisualGenArgs and construct the generator. The
    # registered defaults are forwarded unchanged as ``pipeline_config``; a
    # real caller would override entries such as ``text_encoder_path``.
    # ``skip_warmup=True`` skips the warmup pass that normally follows load.
    compilation = CompilationConfig(skip_warmup=True)
    generator_args = VisualGenArgs(
        pipeline_config=defaults,
        compilation_config=compilation,
    )
    generator = VisualGen(model=model_id, args=generator_args)

    # Step 4: list the model-specific ``extra_params`` the loaded pipeline
    # accepts. Wan 2.1 1.3B declares none, whereas
    # Wan-AI/Wan2.2-T2V-A14B-Diffusers reports ``guidance_scale_2`` and
    # ``boundary_ratio``.
    param_specs = generator.extra_param_specs
    print("\n=== Extra param specs (extra_params keys) ===")
    for key, spec in param_specs.items():
        print(f" - {key}: {spec}")
    if not param_specs:
        print(" (none for this model)")

    # Step 5: start from the pipeline's resolved defaults (height, width,
    # steps, ...) and override individual fields.
    request = generator.default_params
    # Wan needs num_frames = 4k + 1. 1.25x the model default of 81 gives
    # 101.25, and the closest valid count is 101 (4 * 25 + 1).
    request.num_frames = 101
    # ``default_params`` already seeds ``extra_params`` with each spec's
    # default, so this loop only demonstrates how a caller would set a
    # model-specific knob. It is a no-op for Wan 2.1 1.3B; the same wiring
    # applies on Wan 2.2 A14B, where ``extra_params["guidance_scale_2"]``
    # is honored.
    for key, spec in param_specs.items():
        request.extra_params[key] = spec.default

    print("\n=== Request params ===")
    print(request.model_dump_json(indent=2))

    prompt = "A cute cat playing piano in a sunny room"
    result = generator.generate(inputs=prompt, params=request)

    # Step 6: write the result to disk. ``save`` picks the container from
    # the file extension (.avi/.mp4) and uses the frame_rate carried on the
    # output object.
    save_result = result.save("api_walkthrough_output.avi")
    print(f"\nSaved: {save_result}")
66
67
# Run the walkthrough only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()