API walkthrough

Source: NVIDIA/TensorRT-LLM.

 1from tensorrt_llm import VisualGen, VisualGenArgs
 2from tensorrt_llm.visual_gen.args import CompilationConfig
 3
 4
 5def main():
 6    # 1. List supported models registered with the pipeline registry.
 7    print("\n=== Supported models ===")
 8    for hf_id in VisualGen.supported_models():
 9        print(f"  - {hf_id}")
10
11    # 2. Inspect default pipeline_config knobs for the chosen model. These
12    #    are per-architecture runtime knobs (e.g. Lightricks/LTX-2's
13    #    ``text_encoder_path``); Wan-AI/Wan2.1-T2V-1.3B-Diffusers registers
14    #    none, so the dict is empty.
15    pipeline_defaults = VisualGen.pipeline_config("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
16    print("\n=== Pipeline config defaults for Wan-AI/Wan2.1-T2V-1.3B-Diffusers ===")
17    print(f"  {pipeline_defaults or '(none)'}")
18
19    # 3. Build VisualGenArgs. ``pipeline_config`` carries the per-architecture
20    #    knobs from step 2 (here we just forward the registered defaults;
21    #    real callers would override entries like ``text_encoder_path``).
22    #    ``compilation_config.skip_warmup`` skips the post-load warmup pass.
23    visual_gen = VisualGen(
24        model="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
25        args=VisualGenArgs(
26            pipeline_config=pipeline_defaults,
27            compilation_config=CompilationConfig(skip_warmup=True),
28        ),
29    )
30
31    # 4. Discover model-specific ``extra_params`` accepted by the loaded
32    #    pipeline. Wan-AI/Wan2.1-T2V-1.3B-Diffusers declares none;
33    #    Wan-AI/Wan2.2-T2V-A14B-Diffusers surfaces ``guidance_scale_2`` and
34    #    ``boundary_ratio`` here.
35    specs = visual_gen.extra_param_specs
36    print("\n=== Extra param specs (extra_params keys) ===")
37    for name, spec in specs.items():
38        print(f"  - {name}: {spec}")
39    if not specs:
40        print("  (none for this model)")
41
42    # 5. Take the pipeline's resolved defaults (height/width/steps/etc.)
43    #    and override fields. ``default_params`` already pre-populates
44    #    ``params.extra_params`` with each declared spec's default, so the
45    #    override below shows how a caller would set a model-specific knob
46    #    -- no-op on Wan-AI/Wan2.1-T2V-1.3B-Diffusers, but the wiring is
47    #    the same on Wan-AI/Wan2.2-T2V-A14B-Diffusers where
48    #    ``extra_params["guidance_scale_2"]`` is honored.
49    params = visual_gen.default_params
50    # Wan requires num_frames of the form 4k+1; 1.25x the model default (81)
51    # is 101.25, so we round to the nearest valid value, 101 (= 4*25 + 1).
52    params.num_frames = 101
53    for name, spec in specs.items():
54        params.extra_params[name] = spec.default
55
56    print("\n=== Request params ===")
57    print(params.model_dump_json(indent=2))
58
59    output = visual_gen.generate(inputs="A cute cat playing piano in a sunny room", params=params)
60
61    # 6. Persist to disk. ``save`` infers the container from the file
62    #    extension (.avi/.mp4) and uses the frame_rate carried on the
63    #    output.
64    saved = output.save("api_walkthrough_output.avi")
65    print(f"\nSaved: {saved}")
66
67
if __name__ == "__main__":  # script entry point; importing this module runs nothing
    main()