Pipelines and runners

FlashDreams model integrations are built from two public layers:

  • Pipelines (StreamInferencePipelineConfig) that define model behavior.

  • Runners (RunnerConfig + Runner) that define CLI-facing I/O.

Most actively developed model implementations now live under integrations/* as plugin-style standalone packages. This page keeps documenting the in-tree pipeline modules that are still exposed from flashdreams.recipes.

Note

Pipeline modules import the heavy GPU stack (transformer-engine, CUDA ops) at import time, so this page renders them with automodule and :no-undoc-members:, which keeps the rendered API focused on the names these in-tree modules actually expose. The unified flashdreams-run CLI shows end-to-end usage; see Models for model launch examples.

Integration structure (current)

For new model work, follow integrations/<name>/:

  • config.py: pipeline + runner config literals (slugged entries).

  • runner.py: runtime I/O, cache init, generate/finalize loop, persistence.

  • pipeline.py and transformer/*: model compute path.

  • pyproject.toml: plugin packaging + entry-point registration.

This makes each integration effectively a standalone repository while still plugging into the same flashdreams-run registry.

Reference integration folders

NVIDIA OmniDreams

OmniDreams now ships as a plugin under integrations/omnidreams; it registers its runners via the flashdreams.runner_configs entry-point group and is no longer part of the in-tree flashdreams.recipes API surface. See integrations/omnidreams/README.md for the plugin entry point and flashdreams-run omnidreams-* for the user-facing CLI.

Wan

Public Wan integration surface for integration plugins.

class Wan21Transformer(config: Wan21TransformerConfig)

Bases: Transformer[Wan21TransformerCache]

Wan 2.1 DiT adapted to the infra Transformer interface.

property latent_shape: tuple[int, ...]

Per-rank post-patchify latent shape [*batch_shape, L/cp, D].

Wan flattens THW into one token axis and shards across the THW CP group. D = network.out_dim * prod(patch_size) is the noise channel count; the mask / image-latent channels added by concat_image_mask_to_latent come from input in predict_flow, not from the noise tensor.

Per-rollout (height, width) is populated by initialize_autoregressive_cache(); reading this property before that call triggers an assertion failure.
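To make the shape bookkeeping concrete, here is a small sketch that computes [*batch_shape, L/cp, D] from a pre-patchify latent size. It assumes the default WanDiTNetworkConfig values (patch_size (1, 2, 2), out_dim 16) and that the token count divides evenly across the CP group; per_rank_latent_shape is a hypothetical helper, not part of flashdreams.

```python
from math import prod


def per_rank_latent_shape(
    batch_shape: tuple[int, ...],
    len_t: int,
    height: int,
    width: int,
    patch_size: tuple[int, int, int] = (1, 2, 2),
    out_dim: int = 16,
    cp_size: int = 1,
) -> tuple[int, ...]:
    """Hypothetical helper: [*batch_shape, L / cp, D] for one Wan latent chunk."""
    kt, kh, kw = patch_size
    if len_t % kt or height % kh or width % kw:
        raise ValueError("latent T/H/W must be divisible by patch_size")
    # Wan flattens (t h w) into a single token axis.
    num_tokens = (len_t // kt) * (height // kh) * (width // kw)
    if num_tokens % cp_size:
        raise ValueError("token count must divide evenly across the CP group")
    # D is the noise channel count: out_dim * prod(patch_size).
    d = out_dim * prod(patch_size)
    return (*batch_shape, num_tokens // cp_size, d)


# 21 latent frames at a 60x104 latent size, sharded over 2 CP ranks.
print(per_rank_latent_shape((1,), len_t=21, height=60, width=104, cp_size=2))  # (1, 16380, 64)
```

For the TI2V 5B network (out_dim 48), the same arithmetic gives D = 48 * 4 = 192.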

initialize_autoregressive_cache(*, height: int, width: int, text_embeddings: Tensor, image_embeddings: Tensor | None = None, negative_text_embeddings: Tensor | None = None, **_unused: Any) → Wan21TransformerCache

Build a seeded transformer cache for a new rollout.

I2V state is not baked into the cache; the latent + injection mask are passed per AR step as the input argument to predict_flow / postprocess_clean_latent.

Parameters:
  • height – Pre-patchify latent height (post-VAE).

  • width – Pre-patchify latent width (post-VAE).

  • text_embeddings – Conditional UMT5 embeddings [..., text_len, text_dim].

  • image_embeddings – Conditional CLIP image embeddings (only used by networks with cross_attn_enable_img=True). Shared with the uncond branch.

  • negative_text_embeddings – Negative-prompt embeddings. Required iff config.guidance_scale > 1.0; must be None otherwise.

Returns:

Populated cache. network_cache_uncond is None iff CFG is disabled.

predict_flow(noisy_latent: Tensor, timestep: Tensor, cache: Wan21TransformerCache, input: I2VCtrl | None = None, network_extra_kwargs: dict[str, Any] | None = None) → Tensor

Predict the flow for one denoising step.

timestep may be a scalar / per-batch tensor (standard Wan 2.1 / 14B path) or a per-token tensor with the same trailing token axis as noisy_latent (Wan 2.2 TI2V 5B first-frame seeding at AR step 0). The per-token layout flows through WanDiTNetwork.forward(), which dispatches the sinusoidal embedding + AdaLN modulation on the native shape.

CUDA-graph capture is shape-sensitive: the captured replay region only sees AR step >= self._cuda_graph_capture_ar_idx (steady state). TI2V 5B is configured with len_t == window_size_t so the threshold lands at AR 1, putting the per-token AR-0 step inside the eager .drain branch where shape changes are safe. After AR 0 the pipeline switches back to scalar timesteps, so the captured branch sees a single stable shape across all AR steps it owns.

postprocess_clean_latent(clean_latent: Tensor, cache: Wan21TransformerCache, input: I2VCtrl | None = None) → Tensor

Re-stamp the masked positions of the predicted x0 with the image latent (mask-inject I2V only).

T2V and the channel-concat I2V mode fall through unchanged.

patchify_and_maybe_split_cp(x: Tensor) → Tensor
patchify_and_maybe_split_cp(x: I2VCtrl) → I2VCtrl

Patchify and CP-split a noisy latent or an I2V control payload.

Tensors delegate to the network helper; I2V payloads patchify the latent and mask fields independently so the per-field channel layouts are preserved for the mask-inject blend downstream.

unpatchify_and_maybe_gather_cp(x: Tensor) → Tensor

Inverse of patchify_and_maybe_split_cp for the network output.

class Wan21TransformerConfig(*, _target: type[flashdreams.recipes.wan.transformer.wan21.Wan21Transformer] = <factory>, network: flashdreams.recipes.wan.transformer.impl.network.WanDiTNetworkConfig = <factory>, dtype: torch.dtype = torch.bfloat16, checkpoint_path: str | None = None, state_dict_transform: collections.abc.Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] | None = None, batch_shape: tuple[int, ...] = (1,), len_t: int = 21, guidance_scale: float = 1.0, window_size_t: int = 21, sink_size_t: int = 0, h_extrapolation_ratio: float = 1.0, w_extrapolation_ratio: float = 1.0, compile_network: bool = True, use_cuda_graph: bool = True, cuda_graph_warmup_iters: int = 2, stamp_image_latent: bool = False, concat_image_mask_to_latent: bool = False, ti2v_first_frame_per_token_timestep: bool = False, first_frame_timestep_value: float = 0.0)

Bases: TransformerConfig

Config for the Wan 2.1 transformer.

Bakes in the temporal layout (len_t, window_size_t, optional sink_size_t) and the CFG / compile knobs. Per-rollout spatial layout (height, width) is supplied to Wan21Transformer.initialize_autoregressive_cache() so one instance can serve multiple resolutions. Wan flattens T*H*W into one token axis and shards it across the THW CP group; the CP size is auto-detected from torch.distributed.get_world_size() at construction time, so the launcher (torchrun --nproc_per_node=N) is the single source of truth.

The two I2V flags are independent and composable:

  • stamp_image_latent: overwrite the noisy latent with the clean image latent at masked positions every denoising step, and re-stamp the predicted x0 the same way. network.in_dim unchanged. (flashdreams mask-inject integration; used by the out-of-tree causal_forcing plugin.)

  • concat_image_mask_to_latent: append the 4-channel mask and 16-channel image latent along the channel dim. Builders that set this flag must also set network.in_dim = 16 + 4 + 16 to match the official Wan 2.1 14B I2V layout.

With both enabled, the stamp runs first and the result is then concatenated with the mask + image latent.
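A minimal NumPy sketch of how the two flags compose on one chunk, in a pre-patchify [T, C, H, W] layout. The function name, the binary mask values, and broadcasting the mask's first channel for the stamp are illustrative assumptions; the real blend runs on patchified tensors inside Wan21Transformer.

```python
import numpy as np


def apply_i2v_flags(noisy, image_latent, mask4, stamp=True, concat=True):
    """Hypothetical sketch of stamp_image_latent / concat_image_mask_to_latent.

    noisy, image_latent: [T, 16, H, W]; mask4: [T, 4, H, W], 1 = conditioned position.
    """
    x = noisy
    if stamp:
        # Overwrite masked positions with the clean image latent
        # (assumption: one mask channel broadcast over the 16 latent channels).
        m = mask4[:, :1]
        x = m * image_latent + (1.0 - m) * x
    if concat:
        # Append the 4-channel mask and the 16-channel image latent along channels.
        x = np.concatenate([x, mask4, image_latent], axis=1)
    return x


rng = np.random.default_rng(0)
T, H, W = 3, 4, 4
noisy = rng.standard_normal((T, 16, H, W), dtype=np.float32)
image_latent = rng.standard_normal((T, 16, H, W), dtype=np.float32)
mask4 = np.zeros((T, 4, H, W), dtype=np.float32)
mask4[0] = 1.0  # condition only the first latent frame

out = apply_i2v_flags(noisy, image_latent, mask4)
print(out.shape)  # (3, 36, 4, 4): matches network.in_dim = 16 + 4 + 16
```

The stamp runs before the concatenation, mirroring the ordering stated above for the case where both flags are enabled.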

state_dict_transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None

Pre-load state-dict remap (e.g. Self-Forcing’s generator_ema.model.… layout).

batch_shape: tuple[int, ...] = (1,)

Batch dims of the latent (excluding the L, D dims).

len_t: int = 21

Latent frames per AR chunk (post-VAE).

guidance_scale: float = 1.0

CFG scale s: flow = uncond + s * (cond - uncond). 1.0 disables CFG; > 1.0 requires negative-text embeddings at cache build time.
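The guidance formula in plain Python over flat lists of floats, as a sketch only; real flows are tensors and the combination happens inside Wan21Transformer. The s == 1.0 shortcut reflects that the uncond branch is not needed when CFG is disabled (the uncond network cache is None in that case).

```python
def cfg_combine(cond: list[float], uncond: list[float], s: float) -> list[float]:
    """Classifier-free guidance: flow = uncond + s * (cond - uncond), element-wise."""
    if len(cond) != len(uncond):
        raise ValueError("cond and uncond flows must have the same shape")
    if s == 1.0:
        # s = 1.0 reduces to the conditional flow, so the uncond pass can be skipped.
        return list(cond)
    return [u + s * (c - u) for c, u in zip(cond, uncond)]


print(cfg_combine([1.0, 2.0], [0.5, 0.5], s=5.0))  # [3.0, 8.0]
```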

window_size_t: int = 21

Self-attention sliding-window size (pre-patchify T frames).

sink_size_t: int = 0

Prefix sink size (pre-patchify T frames) for self-attention KV cache.

compile_network: bool = True

torch.compile the network on init.

use_cuda_graph: bool = True

Wrap the network in CUDAGraphWrapper for steady-state replay. Caller must keep non-staged inputs at stable storage addresses across calls. predict_flow dispatches to wrapper.drain while the KV cache is still filling and to wrapper once it reaches steady state.

cuda_graph_warmup_iters: int = 2

Eager calls before capture (>= 2 to drain Inductor autotune).

stamp_image_latent: bool = False

See class docstring (mask-inject I2V integration).

concat_image_mask_to_latent: bool = False

See class docstring (channel-concat I2V layout).

ti2v_first_frame_per_token_timestep: bool = False

Wan 2.2 TI2V 5B first-frame conditioning. When True and an I2VCtrl input is provided at AR step 0, predict_flow rewrites the scheduler’s scalar timestep into a per-token tensor: t = first_frame_timestep_value at positions marked by the I2V mask (i.e. the first-frame latent), and the scheduler’s t elsewhere. AR steps >= 1 continue to use the scalar timestep, which keeps the CUDA-graph-captured replay branch on a single stable input shape.

Composes with stamp_image_latent: together they implement the upstream Wan 2.2 5B “VAE-seeded first-frame + per-token t=0” TI2V recipe, in which the latent is stamped clean every denoising step while the network sees t=0 for those tokens. The standard mask-inject I2V recipe leaves this flag off and relies on the image-latent stamp alone.

first_frame_timestep_value: float = 0.0

Per-token timestep assigned to first-frame conditioning tokens when ti2v_first_frame_per_token_timestep is True.

Defaults to 0.0 (Wan 2.2 TI2V 5B’s base recipe — treats the first frame as fully clean by AdaLN). HY-WorldPlay’s distilled WAN-5B raises it to 14.0 (vendor’s stabilization_level - 1) so the AdaLN table sees a small nonzero sigma at the first frame.

Unused when ti2v_first_frame_per_token_timestep is False.

class Wan22TI2V5BVAEDecoderConfig(*, _target: type[flashdreams.recipes.wan.autoencoder.vae.WanVAEDecoder] = <factory>, checkpoint_path: str = 'https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers/resolve/main/vae/diffusion_pytorch_model.safetensors', dtype: torch.dtype = torch.bfloat16, use_cuda_graph: bool = True, use_compile: bool = False, base_dim: int = 160, decoder_base_dim: int | None = 256, z_dim: int = 48, patch_size: int = 2, is_residual: bool = True, latent_mean: tuple[float, ...] = (-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, -0.1382, 0.0542, 0.2813, 0.0891, 0.157, -0.0098, 0.0375, -0.1825, -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.123, -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.052, 0.3748, 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667), latent_std: tuple[float, ...] = (0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.499, 0.4818, 0.5013, 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, 0.3971, 1.06, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744), state_dict_transform: typing.Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] | None = <function wan22_ti2v_5b_vae_state_dict_transform>)

Bases: WanVAEDecoderConfig

Pre-rolled config for the Wan 2.2 TI2V 5B decoder.

Mirrors Wan22TI2V5BVAEEncoderConfig but with the asymmetric decoder_base_dim=256.

decoder_base_dim: int | None = 256

Decoder base channel count. None mirrors base_dim (Wan 2.1). Wan 2.2 TI2V 5B uses an asymmetric 256.

state_dict_transform() → Dict[str, Tensor]

Remap a diffusers AutoencoderKLWan state-dict to WanVAE keys.

Applied automatically when Wan22TI2V5BVAEEncoderConfig / Wan22TI2V5BVAEDecoderConfig load the upstream Wan-AI/Wan2.2-TI2V-5B-Diffusers/vae/diffusion_pytorch_model.safetensors checkpoint. The mapping is purely structural – no tensors are copied or reshaped.

Note

Patterns are applied in iteration order via flashdreams.core.checkpoint.remap.remap_checkpoint_keys(). Any key without a matching pattern passes through unchanged, which surfaces as a load_state_dict unexpected_keys warning so missing remap entries are easy to spot.

class Wan22TI2V5BVAEEncoderConfig(*, _target: type[flashdreams.recipes.wan.autoencoder.vae.WanVAEEncoder] = <factory>, checkpoint_path: str = 'https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers/resolve/main/vae/diffusion_pytorch_model.safetensors', dtype: torch.dtype = torch.bfloat16, use_cuda_graph: bool = True, use_compile: bool = False, base_dim: int = 160, z_dim: int = 48, patch_size: int = 2, is_residual: bool = True, latent_mean: tuple[float, ...] = (-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, -0.1382, 0.0542, 0.2813, 0.0891, 0.157, -0.0098, 0.0375, -0.1825, -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.123, -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.052, 0.3748, 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667), latent_std: tuple[float, ...] = (0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.499, 0.4818, 0.5013, 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, 0.3971, 1.06, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744), state_dict_transform: typing.Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] | None = <function wan22_ti2v_5b_vae_state_dict_transform>)

Bases: WanVAEEncoderConfig

Pre-rolled config for the Wan 2.2 TI2V 5B encoder.

Pins the diffusers upstream checkpoint, the 16x-spatial / 48ch / residual / patchify architecture knobs, and the matching diffusers -> flashdreams key remap. Equivalent to the Wan 2.1 encoder config plus the 5B-specific knobs flipped on.

base_dim: int = 160

Encoder base channel count (WanVAE dim). 96 for Wan 2.1, 160 for Wan 2.2 TI2V 5B.

z_dim: int = 48

Latent channels. 16 for Wan 2.1, 48 for Wan 2.2 TI2V 5B.

patch_size: int = 2

Outer spatial pixel-shuffle factor (1 = no patchify; 2 for Wan 2.2 TI2V 5B).

is_residual: bool = True

Use ResidualDownBlock (Wan 2.2) instead of the legacy ResidualBlock + AttentionBlock down-stage (Wan 2.1).

latent_mean: tuple[float, ...] = (-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, -0.1382, 0.0542, 0.2813, 0.0891, 0.157, -0.0098, 0.0375, -0.1825, -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.123, -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.052, 0.3748, 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667)

Per-channel latent mean used for normalisation; must contain exactly z_dim entries.

latent_std: tuple[float, ...] = (0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.499, 0.4818, 0.5013, 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, 0.3971, 1.06, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744)

Per-channel latent std used for normalisation.
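A sketch of per-channel normalisation with these statistics, assuming the common Wan VAE convention z_norm = (z - mean) / std after encoding and its inverse before decoding; where exactly WanVAEEncoder / WanVAEDecoder apply it is not shown on this page. Only the first three channel statistics are used, for brevity, and normalize / denormalize are hypothetical helpers.

```python
def normalize(z, mean, std):
    """Per-channel (z - mean) / std; z is a list of channels, each a list of floats."""
    if not (len(z) == len(mean) == len(std)):
        raise ValueError("need one mean/std entry per latent channel (z_dim)")
    return [[(v - m) / s for v in ch] for ch, m, s in zip(z, mean, std)]


def denormalize(z_norm, mean, std):
    """Inverse map, as would be applied before decoding: z = z_norm * std + mean."""
    return [[v * s + m for v in ch] for ch, m, s in zip(z_norm, mean, std)]


# First 3 of the 48 Wan 2.2 TI2V 5B statistics, for brevity.
mean = (-0.2289, -0.0052, -0.1323)
std = (0.4765, 1.0364, 0.4514)
z = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # 3 channels x 2 values
z_roundtrip = denormalize(normalize(z, mean, std), mean, std)
```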

state_dict_transform() → Dict[str, Tensor]

Remap a diffusers AutoencoderKLWan state-dict to WanVAE keys.

Applied automatically when Wan22TI2V5BVAEEncoderConfig / Wan22TI2V5BVAEDecoderConfig load the upstream Wan-AI/Wan2.2-TI2V-5B-Diffusers/vae/diffusion_pytorch_model.safetensors checkpoint. The mapping is purely structural – no tensors are copied or reshaped.

Note

Patterns are applied in iteration order via flashdreams.core.checkpoint.remap.remap_checkpoint_keys(). Any key without a matching pattern passes through unchanged, which surfaces as a load_state_dict unexpected_keys warning so missing remap entries are easy to spot.

class Wan22Transformer(config: Wan22TransformerConfig)

Bases: Transformer[Wan22TransformerCache]

Wan 2.2 dual-network DiT.

predict_flow dispatches to the branch selected by the timestep; finalize_kv_cache re-runs both branches once at the context noise so neither KV cache lags behind.

property latent_shape: tuple[int, ...]

Per-rank latent shape (both branches share this, asserted by config).

initialize_autoregressive_cache(*, height: int, width: int, text_embeddings: Tensor, image_embeddings: Tensor | None = None, **_unused: Any) → Wan22TransformerCache

Build a seeded transformer cache for a new rollout.

Both branches see the same text/image conditioning and the same per-rollout spatial layout, matching upstream Wan 2.2. CFG is not supported here.

predict_flow(noisy_latent: Tensor, timestep: Tensor, cache: Wan22TransformerCache, input: Any = None) → Tensor

Predict the flow at timestep.

Parameters:
  • noisy_latent – Patchified noisy latent for this denoising step.

  • timestep – Scalar timestep tensor.

  • cache – Per-rollout AR cache.

  • input – Patchified encoder output for this AR step, or None when the pipeline has no encoder. Subclasses should narrow the type to their encoder’s output type.

Returns:

Predicted flow tensor with the same shape as noisy_latent.

finalize_kv_cache(noisy_latent: Tensor, timestep: Tensor, cache: Wan22TransformerCache, input: Any = None) → None

Refresh both networks’ KV caches at the context-noise step.

Each Wan 2.2 denoising step only touches one branch, so the other lags by the end of the loop. Re-running both at context_noise keeps them in lock-step; flow outputs are discarded.

patchify_and_maybe_split_cp(x: Tensor) → Tensor

Patchify and (optionally) CP-split a noisy latent or encoder payload.

Tensors patchify and split. Structured payloads (e.g. an image-control struct with latent + mask) patchify each tensor field and return the same struct type. Implement as identity when neither token packing nor CP sharding applies. Output preserves the input Python type — only shapes change.

unpatchify_and_maybe_gather_cp(x: Tensor) → Tensor

Inverse of patchify_and_maybe_split_cp for the network output.

class Wan22TransformerConfig(*, _target: type[Wan22Transformer] = <factory>, transformer_high_noise: Wan21TransformerConfig = <factory>, transformer_low_noise: Wan21TransformerConfig = <factory>, boundary_ratio: float = 0.875, num_train_timesteps: int = 1000)

Bases: TransformerConfig

Config for the Wan 2.2 dual-network transformer.

Wan 2.2 dispatches to one of two Wan 2.1 networks based on whether the timestep is above or below boundary_ratio * num_train_timesteps. Both branches must agree on patch_size / in_dim / dim / batch_shape / len_t / guidance_scale (asserted in __post_init__). The per-rollout spatial layout (height, width) is supplied to Wan22Transformer.initialize_autoregressive_cache() and forwarded to both branches. The CP size is auto-detected from torch.distributed.get_world_size(), same as Wan 2.1. Wan 2.2 has no CFG or I2V support here.

transformer_high_noise: Wan21TransformerConfig

Sub-config for the high-noise branch (timestep > boundary).

transformer_low_noise: Wan21TransformerConfig

Sub-config for the low-noise branch (timestep <= boundary).

boundary_ratio: float = 0.875

Fraction of num_train_timesteps separating the two branches. Default 0.875 matches upstream Wan 2.2.

property boundary_timestep: float

Absolute timestep boundary: boundary_ratio * num_train_timesteps (875.0 with the defaults).
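The branch selection rule can be sketched in a few lines. This is an illustration under the assumption of a scalar timestep (as documented for predict_flow); select_branch is a hypothetical helper, and the returned strings stand in for the two Wan 2.1 sub-networks.

```python
def select_branch(timestep: float, boundary_ratio: float = 0.875, num_train_timesteps: int = 1000) -> str:
    """Hypothetical sketch: high-noise branch above the boundary, low-noise at or below it."""
    boundary = boundary_ratio * num_train_timesteps  # 875.0 with the defaults
    return "high_noise" if timestep > boundary else "low_noise"


print(select_branch(999.0))  # high_noise
print(select_branch(875.0))  # low_noise (the boundary itself goes to the low-noise branch)
```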

class WanDiTNetwork(config: WanDiTNetworkConfig)

Bases: Module

WAN diffusion backbone for text-to-video and image-to-video.

set_context_parallel_group(cp_group: ProcessGroup | None = None) → None

Set context-parallel process group for all blocks.

This must be called before initialize_cache when CP is used.

patchify_and_maybe_split_cp(x: Tensor, process_groups: list[ProcessGroup | None] | None = None, cp_dims: list[int | None] | None = None) → Tensor

Patchify and optionally CP-split the input video tensor.

The patchify pattern is ... (t kt) c (h kh) (w kw) -> ... (t h w) (c kt kh kw).

Returns:

Patched tensor with shape [..., L, D] where L = T * H * W / (kt * kh * kw).

unpatchify_and_maybe_gather_cp(pH: int, pW: int, x: Tensor, process_groups: list[ProcessGroup | None] | None = None, cp_dims: list[int | None] | None = None) → Tensor

Unpatchify and optionally CP-gather the tensor back to video shape.

The unpatchify pattern is ... (t h w) (c kt kh kw) -> ... (t kt) c (h kh) (w kw).

Returns:

Unpatched tensor with shape [..., T, C, H, W].
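A NumPy sketch of the two patterns on a single rank (no CP split or gather), to make the token and feature layouts concrete. patchify / unpatchify are hypothetical names; the real methods also take process groups and CP dims. With the default patch_size (1, 2, 2), a [B, T, C, H, W] latent becomes [B, T*(H/2)*(W/2), 4*C].

```python
import numpy as np


def patchify(x, patch_size=(1, 2, 2)):
    """... (t kt) c (h kh) (w kw) -> ... (t h w) (c kt kh kw)."""
    kt, kh, kw = patch_size
    *batch, T, C, H, W = x.shape
    t, h, w = T // kt, H // kh, W // kw
    x = x.reshape(*batch, t, kt, C, h, kh, w, kw)
    n = len(batch)
    # Axes after batch are (t, kt, c, h, kh, w, kw); reorder to (t, h, w, c, kt, kh, kw).
    perm = list(range(n)) + [n + i for i in (0, 3, 5, 2, 1, 4, 6)]
    return x.transpose(perm).reshape(*batch, t * h * w, C * kt * kh * kw)


def unpatchify(x, pH, pW, patch_size=(1, 2, 2)):
    """... (t h w) (c kt kh kw) -> ... (t kt) c (h kh) (w kw); pH/pW are patched H/W."""
    kt, kh, kw = patch_size
    *batch, L, D = x.shape
    t, C = L // (pH * pW), D // (kt * kh * kw)
    x = x.reshape(*batch, t, pH, pW, C, kt, kh, kw)
    n = len(batch)
    # Axes after batch are (t, h, w, c, kt, kh, kw); reorder to (t, kt, c, h, kh, w, kw).
    perm = list(range(n)) + [n + i for i in (0, 4, 3, 1, 5, 2, 6)]
    return x.transpose(perm).reshape(*batch, t * kt, C, pH * kh, pW * kw)


x = np.arange(1 * 3 * 16 * 4 * 6, dtype=np.float32).reshape(1, 3, 16, 4, 6)  # [B, T, C, H, W]
tokens = patchify(x)
print(tokens.shape)  # (1, 18, 64)
```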

initialize_cache(chunk_size: int, window_size: int, sink_size: int, text_embeddings: Tensor, img_embeddings: Tensor | None = None) → WanDiTNetworkCache

Initialize block caches from text/image context embeddings.

Parameters:
  • chunk_size – Number of tokens appended per self-attention update.

  • window_size – Rolling-window size in tokens for self-attention cache.

  • sink_size – Sink-token capacity preserved across updates.

  • text_embeddings – Text embeddings. UMT5 has shape […, 512, 4096].

  • img_embeddings – Optional image embeddings for I2V. CLIP has shape […, 256, 1280].

Returns:

WanDiTNetworkCache containing per-block caches.

update_parameters_after_loading_checkpoint() → None

Fuse load-time-known ops into weights; call once after loading the checkpoint.

forward(x: Tensor, timesteps: Tensor, cache: WanDiTNetworkCache, rope_freqs: Tensor, current_chunk_idx: int = 0, eager_mode: bool = True, block_extra_kwargs: dict[str, Any] = {}) → Tensor

Run one denoising forward pass.

Parameters:
  • x – Input tokens after patchify + CP, shape [..., L, D_in]; layout "... (t h w) (c kt kh kw)".

  • timesteps

    Diffusion timesteps. Two layouts are supported:

    • Scalar (per-batch). Shape broadcastable to [...] (i.e., to x.shape[:-2]). The same timestep is shared across every token, matching the standard Wan 2.1 / Wan 2.2 14B chunked-denoise path.

    • Per-token. Shape [..., L] matching x’s post-patchify token axis. Used by Wan 2.2 TI2V 5B at AR step 0 to stamp t=0 at the first-frame conditioning tokens while the rest of the chunk denoises at the current scheduler step. See Wan21Transformer.predict_flow for the higher-level entry point.

  • cache – Network KV caches.

  • rope_freqs – Full-width RoPE frequencies after CP. Standard mode uses current-chunk frequencies with shape [L, 1, 1, d]; KV-cache-relative mode uses frequencies relative to the KV cache.

  • current_chunk_idx – Current chunk index for streaming cache update.

  • eager_mode – If True, run cache before/after update hooks.

  • block_extra_kwargs – Extra kwargs forwarded to each block.

Returns:

Network output, shape [..., L, prod(patch_size) * out_dim].

class WanDiTNetwork1pt3BConfig(_target: type[WanDiTNetwork] = <factory>, patch_size: tuple[int, int, int] = (1, 2, 2), text_len: int = 512, in_dim: int = 16, dim: int = 1536, ffn_dim: int = 8960, freq_dim: int = 256, text_dim: int = 4096, out_dim: int = 16, num_heads: int = 12, num_layers: int = 30, cross_attn_norm: bool = True, cross_attn_enable_img: bool = False, eps: float = 1e-06, concat_padding_mask: bool = False, patch_embedding_type: Literal['linear', 'conv3d'] = 'conv3d', apply_rope_before_kvcache: bool = True, cp_method: Literal['ring', 'ulysses'] = 'ring')

Bases: WanDiTNetworkConfig

Configuration for the 1.3B Wan DiT network.

dim: int = 1536

Transformer hidden size (width).

ffn_dim: int = 8960

Feed-forward hidden dimension.

num_heads: int = 12

Number of attention heads.

num_layers: int = 30

Number of transformer blocks.

class WanDiTNetwork14BConfig(_target: type[WanDiTNetwork] = <factory>, patch_size: tuple[int, int, int] = (1, 2, 2), text_len: int = 512, in_dim: int = 16, dim: int = 5120, ffn_dim: int = 13824, freq_dim: int = 256, text_dim: int = 4096, out_dim: int = 16, num_heads: int = 40, num_layers: int = 40, cross_attn_norm: bool = True, cross_attn_enable_img: bool = False, eps: float = 1e-06, concat_padding_mask: bool = False, patch_embedding_type: Literal['linear', 'conv3d'] = 'conv3d', apply_rope_before_kvcache: bool = True, cp_method: Literal['ring', 'ulysses'] = 'ring')

Bases: WanDiTNetworkConfig

Configuration for the 14B Wan DiT network.

dim: int = 5120

Transformer hidden size (width).

ffn_dim: int = 13824

Feed-forward hidden dimension.

num_heads: int = 40

Number of attention heads.

num_layers: int = 40

Number of transformer blocks.

class WanDiTNetworkConfig(_target: type[WanDiTNetwork] = <factory>, patch_size: tuple[int, int, int] = (1, 2, 2), text_len: int = 512, in_dim: int = 16, dim: int = 1536, ffn_dim: int = 8960, freq_dim: int = 256, text_dim: int = 4096, out_dim: int = 16, num_heads: int = 12, num_layers: int = 30, cross_attn_norm: bool = True, cross_attn_enable_img: bool = False, eps: float = 1e-06, concat_padding_mask: bool = False, patch_embedding_type: Literal['linear', 'conv3d'] = 'conv3d', apply_rope_before_kvcache: bool = True, cp_method: Literal['ring', 'ulysses'] = 'ring')

Bases: InstantiateConfig

Configuration for the Wan DiT network.

patch_size: tuple[int, int, int] = (1, 2, 2)

Patch size for the input tensor.

text_len: int = 512

Maximum text token length.

in_dim: int = 16

Number of input latent channels before patch embedding.

dim: int = 1536

Transformer hidden size (width).

ffn_dim: int = 8960

Feed-forward hidden dimension.

freq_dim: int = 256

Sinusoidal timestep embedding dimension.

text_dim: int = 4096

Text encoder output dimension.

out_dim: int = 16

Output latent channels after the head.

num_heads: int = 12

Number of attention heads.

num_layers: int = 30

Number of transformer blocks.

cross_attn_norm: bool = True

If True, apply LayerNorm before cross-attention.

cross_attn_enable_img: bool = False

If True, build image cross-attention and CLIP image projection (I2V).

eps: float = 1e-06

Epsilon for normalization layers.

concat_padding_mask: bool = False

If True, concatenate one mask channel into the input channels.

patch_embedding_type: Literal['linear', 'conv3d'] = 'conv3d'

Type of patch embedding: "linear" (flattened patch MLP) or "conv3d" (strided conv).

apply_rope_before_kvcache: bool = True

If True, apply RoPE to keys before storing them in the KV cache.

cp_method: Literal['ring', 'ulysses'] = 'ring'

Context-parallel attention method for transformer attention ops.

class WanDiTNetworkTI2V5BConfig(_target: type[WanDiTNetwork] = <factory>, patch_size: tuple[int, int, int] = (1, 2, 2), text_len: int = 512, in_dim: int = 48, dim: int = 3072, ffn_dim: int = 14336, freq_dim: int = 256, text_dim: int = 4096, out_dim: int = 48, num_heads: int = 24, num_layers: int = 30, cross_attn_norm: bool = True, cross_attn_enable_img: bool = False, eps: float = 1e-06, concat_padding_mask: bool = False, patch_embedding_type: Literal['linear', 'conv3d'] = 'conv3d', apply_rope_before_kvcache: bool = True, cp_method: Literal['ring', 'ulysses'] = 'ring')

Bases: WanDiTNetworkConfig

Configuration for the Wan 2.2 TI2V 5B DiT network.

Mirrors the official Wan-AI/Wan2.2-TI2V-5B-Diffusers/transformer config: 24 heads * 128 head_dim = 3072 inner dim, 30 layers, ffn_dim 14336, and 48-channel latent in/out (the matching 16x VAE in vae.py outputs 48 channels). Unlike Wan 2.1 14B I2V, TI2V 5B has no CLIP cross-attention branch (cross_attn_enable_img=False): the first frame is conditioned via a clean VAE-latent seed plus a per-token t=0 timestep on the AR-step-0 first-frame tokens, not via CLIP image features.

in_dim: int = 48

Number of input latent channels before patch embedding.

out_dim: int = 48

Output latent channels after the head.

dim: int = 3072

Transformer hidden size (width).

ffn_dim: int = 14336

Feed-forward hidden dimension.

num_heads: int = 24

Number of attention heads.

num_layers: int = 30

Number of transformer blocks.

cross_attn_enable_img: bool = False

If True, build image cross-attention and CLIP image projection (I2V).

class WanI2VCtrlEncoderConfig(*, _target: type[I2VCtrlEncoder] = <factory>, encoder: WanVAEEncoderConfig = <factory>)

Bases: EncoderConfig

Config for the I2V control encoder.

encoder: WanVAEEncoderConfig

Streaming Wan VAE encoder. Pin its checkpoint to the decoder’s so the encoded latent matches the network’s input distribution.

class WanInferencePipeline(config: WanInferencePipelineConfig)

Bases: StreamInferencePipeline[WanVAECache, Wan21TransformerCache | Wan22TransformerCache, WanVAECache]

Wan 2.1 / 2.2 inference pipeline, T2V and I2V.

T2V and I2V share the same rollout loop; the difference is whether you pass an image to initialize_cache. The pipeline config’s encoder slot must agree (None for T2V, an I2V config for I2V).

Example:

pipeline: WanInferencePipeline = ...

# T2V: pass the latent (post-VAE) ``height`` and ``width``.
cache = pipeline.initialize_cache(
    text=["A cat surfing."], height=60, width=104
)
chunks = []
for ar_idx in range(num_ar_steps):  # num_ar_steps: hypothetical, chosen by the caller
    # Each chunk is [*batch_shape, T, C, H, W] in [-1, 1].
    chunks.append(pipeline.generate(ar_idx, cache))
    pipeline.finalize(ar_idx, cache)

# I2V (pipeline built with an I2V encoder): pass the first frame,
# [*batch_shape, 1, 3, H, W] in [-1, 1]; latent sizes are derived from it.
i2v_cache = pipeline.initialize_cache(
    text=["A cat surfing."], image=first_frame
)

initialize_cache(text: list[str], image: Tensor | None = None, *, height: int | None = None, width: int | None = None, release_oneshot_encoders: bool = True) → WanInferencePipelineCache

Initialize the per-rollout cache for a batch of prompts.

Parameters:
  • text – One prompt per batch element. Length must match the transformer’s batch_shape.

  • image – First-frame pixels of shape [*batch_shape, 1, 3, H, W] in [-1, 1]. Required for I2V (self.encoder is set), forbidden for T2V. H and W must equal height * decoder.spatial_compression_ratio and width * decoder.spatial_compression_ratio, respectively.

  • height – Pre-patchify latent height (post-VAE). Optional for I2V — derived from image when omitted; required for T2V.

  • width – Pre-patchify latent width (post-VAE). Same rules as height.

  • release_oneshot_encoders – Free the text and image encoders after the cache is initialized. Later calls reload them from self.config before encoding new prompts/images.

Returns:

Cache to thread through generate / finalize.
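A sketch of the size rule for the image argument: the latent (height, width) is the pixel size divided by decoder.spatial_compression_ratio. The ratio of 8 used for Wan 2.1 below is an assumption (this page does not state it); 16 follows from the 16x-spatial TI2V 5B VAE config. latent_hw_from_image is a hypothetical helper.

```python
def latent_hw_from_image(image_h: int, image_w: int, spatial_compression_ratio: int) -> tuple[int, int]:
    """Hypothetical helper: pre-patchify latent (height, width) from first-frame pixels."""
    if image_h % spatial_compression_ratio or image_w % spatial_compression_ratio:
        raise ValueError("image H and W must be multiples of spatial_compression_ratio")
    return image_h // spatial_compression_ratio, image_w // spatial_compression_ratio


print(latent_hw_from_image(480, 832, 8))    # (60, 104), assumed 8x Wan 2.1 VAE
print(latent_hw_from_image(704, 1280, 16))  # (44, 80), 16x TI2V 5B VAE
```

Under the 8x assumption, a 480x832 first frame corresponds to the height=60, width=104 used in the T2V example.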

release_oneshot_encoders() → None

Free the per-rollout text and first-frame image encoders.

Idempotent. A later initialize_cache() call can reload the encoders from self.config before encoding raw prompts/images.

generate(autoregressive_index: int, cache: WanInferencePipelineCache) → Tensor

Generate one decoded video chunk.

Parameters:
  • autoregressive_index – AR step index, starting at 0.

  • cache – Per-rollout cache from initialize_cache.

Returns:

Decoded video of shape [*batch_shape, T, C, H, W] in [-1, 1].

get_num_input_frames(autoregressive_index: int) → int

Number of input video frames the model expects at this AR step.

get_num_output_frames(autoregressive_index: int) → int

Number of decoded video frames produced at this AR step.

class WanInferencePipelineCache(*, transformer_cache: TransformerCacheT, encoder_cache: StreamingEncoderCacheT | None = None, decoder_cache: StreamingDecoderCacheT | None = None, final_state: DiffusionModel.FinalState[TransformerCacheT] | None = None, autoregressive_index: int | None = None, event_profiler: EventProfiler | None = None, image: Tensor | None = None)

Bases: StreamInferencePipelineCache[WanVAECache, Wan21TransformerCache | Wan22TransformerCache, WanVAECache]

Per-rollout state for the Wan pipeline.

Adds the I2V first-frame pixels on top of the inherited caches. Pixel-to-latent encoding happens per AR step inside the encoder, not here.

image: Tensor | None = None

First-frame pixels [*batch_shape, 1, 3, H, W] in [-1, 1]; None for T2V.

class WanInferencePipelineConfig(*, _target: type['WanInferencePipeline'] = <factory>, name: str, diffusion_model: DiffusionModelConfig, decoder: DecoderConfig | None = None, encoder: EncoderConfig | None = None, enable_sync_and_profile: bool = False, text_encoder: UMT5TextEncoderConfig | None = <factory>, image_encoder: CLIPImageEncoderConfig | None = None)

Bases: StreamInferencePipelineConfig

Config for the Wan inference pipeline.

T2V vs I2V is selected by the inherited encoder slot: None for T2V, an I2V control-encoder config for I2V.
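The T2V/I2V switch is a consistency rule between the config's encoder slot and whether initialize_cache receives an image. A sketch of that check, using a hypothetical check_mode helper (the pipeline's own validation and error messages may differ):

```python
from typing import Optional


def check_mode(encoder_config: Optional[object], image: Optional[object]) -> str:
    """Return "t2v" or "i2v", or raise if the config and the call disagree."""
    if encoder_config is None:
        if image is not None:
            raise ValueError("T2V config (encoder=None) must not receive an image")
        return "t2v"
    if image is None:
        raise ValueError("I2V config (encoder set) requires a first-frame image")
    return "i2v"


print(check_mode(None, None))          # t2v
print(check_mode(object(), object()))  # i2v
```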

text_encoder: UMT5TextEncoderConfig | None

UMT5 text encoder run once per rollout.

image_encoder: CLIPImageEncoderConfig | None = None

CLIP image encoder for I2V variants trained with cross_attn_enable_img=True (Wan 2.1 14B I2V). None skips CLIP cross-attention conditioning.

class WanVAEDecoder(config: WanVAEDecoderConfig)[source]

Bases: StreamingVideoDecoder[WanVAECache]

Wan VAE decoder.

Forward input is a latent [..., Tl, Cl, Hl, Wl]; output is a video tensor [..., T, C, H, W] in [-1, 1]. The cache is advanced in-place across AR decode steps; passing cache=None allocates a fresh single-shot cache.

initialize_autoregressive_cache() WanVAECache[source]

Build a fresh per-rollout cache (a WanVAECache).

forward(input: Tensor, autoregressive_index: int = 0, cache: WanVAECache | None = None) Tensor[source]

Decode a latent chunk [..., Tl, Cl, Hl, Wl] into a video tensor [..., T, C, H, W] in [-1, 1].

Pass the rollout's cache to advance it in place across AR decode steps; cache=None allocates a fresh single-shot cache.

Note

Call the module instance (decoder(input, ...)) rather than forward() directly: the instance call runs any registered hooks, while a direct forward() call silently skips them.

property temporal_compression_ratio: int

Pixel frames ÷ latent frames in steady state (AR ≥ 1).

AR 0 typically yields fewer pixel frames per latent frame because of causal first-frame padding; that asymmetry lives inside get_output_temporal_size() / get_input_temporal_size().

property spatial_compression_ratio: int

Pixel side ÷ latent side. Constant across AR steps.

get_output_temporal_size(autoregressive_index: int, input_temporal_size: int) int[source]

Causal: at AR 0, the first latent frame decodes to a single pixel frame; each remaining latent frame decodes to temporal_compression_ratio pixel frames.

get_input_temporal_size(autoregressive_index: int, output_temporal_size: int) int[source]

Latent frame count needed to produce output_temporal_size pixels.

Inverse of get_output_temporal_size(). Implementations should assert output_temporal_size is achievable at this AR step (i.e. divisible by the right ratio after subtracting any causal padding).

Parameters:
  • autoregressive_index – AR step index (0-based).

  • output_temporal_size – Desired number of pixel frames.

Returns:

Number of latent frames needed at this step.
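The decoder's temporal bookkeeping follows from the causal rule: at AR 0 the first latent frame decodes to one pixel frame, and every other latent frame decodes to temporal_compression_ratio pixel frames. A sketch with hypothetical helpers mirroring get_output_temporal_size() / get_input_temporal_size(), assuming a ratio of 4 (the Wan 2.1 VAE's temporal factor; read the real value from temporal_compression_ratio):

```python
def output_temporal_size(ar_index: int, latent_frames: int, ratio: int = 4) -> int:
    """Pixel frames decoded from `latent_frames` latents at AR step `ar_index`."""
    if ar_index == 0:
        assert latent_frames >= 1, "AR 0 needs at least one latent frame"
        # Causal first frame: one pixel frame, then `ratio` per remaining latent.
        return 1 + ratio * (latent_frames - 1)
    return ratio * latent_frames


def input_temporal_size(ar_index: int, pixel_frames: int, ratio: int = 4) -> int:
    """Inverse: latent frames needed to decode exactly `pixel_frames` pixel frames."""
    grouped = pixel_frames - 1 if ar_index == 0 else pixel_frames
    # Achievable only if the frames left after causal padding split into full groups.
    assert grouped >= 0 and grouped % ratio == 0, "not achievable at this AR step"
    return grouped // ratio + (1 if ar_index == 0 else 0)


print(output_temporal_size(0, 21))  # 81
print(input_temporal_size(0, 81))   # 21
print(output_temporal_size(1, 5))   # 20
```

The assertion in the inverse is what rejects pixel counts such as 80 at AR 0, which no whole number of latent frames can produce.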

class WanVAEDecoderConfig(*, _target: type[~flashdreams.recipes.wan.autoencoder.vae.WanVAEDecoder] = <factory>, checkpoint_path: str = 'https://huggingface.co/lightx2v/Autoencoders/resolve/main/Wan2.1_VAE.pth', dtype: ~torch.dtype = torch.bfloat16, use_cuda_graph: bool = True, use_compile: bool = False, base_dim: int = 96, decoder_base_dim: int | None = None, z_dim: int = 16, patch_size: int = 1, is_residual: bool = False, latent_mean: tuple[float, ...] = (-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921), latent_std: tuple[float, ...] = (2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.916), state_dict_transform: ~typing.Callable[[dict[str, ~torch.Tensor]], dict[str, ~torch.Tensor]] | None = None)[source]

Bases: DecoderConfig

Config for the Wan VAE decoder.

Defaults reproduce the Wan 2.1 / 14B 8x-spatial 16-channel streaming VAE. Override the architecture knobs (or use Wan22TI2V5BVAEDecoderConfig) to load the Wan 2.2 TI2V 5B decoder, which has asymmetric base_dim=160 / decoder_base_dim=256 and the residual up-stage with DupUp3D shortcut.

use_cuda_graph: bool = True

Wrap the decoder forward in a CUDA graph for replay.

use_compile: bool = False

torch.compile(mode="max-autotune-no-cudagraphs"). See WanVAEEncoderConfig.use_compile for the VRAM caveat.

decoder_base_dim: int | None = None

Decoder base channel count. None mirrors base_dim (Wan 2.1). Wan 2.2 TI2V 5B uses an asymmetric 256.
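decoder_base_dim=None means "same as base_dim". A sketch of that fallback on a hypothetical trimmed-down dataclass holding only the two width fields (the real WanVAEDecoderConfig carries many more):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DecoderDims:
    """Hypothetical subset of WanVAEDecoderConfig's width fields."""
    base_dim: int = 96
    decoder_base_dim: Optional[int] = None

    @property
    def effective_decoder_base_dim(self) -> int:
        # None mirrors base_dim (Wan 2.1); Wan 2.2 TI2V 5B sets 256 explicitly.
        return self.base_dim if self.decoder_base_dim is None else self.decoder_base_dim


print(DecoderDims().effective_decoder_base_dim)                                   # 96
print(DecoderDims(base_dim=160, decoder_base_dim=256).effective_decoder_base_dim)  # 256
```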

state_dict_transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None

Optional key remap applied to the checkpoint state dict before load_state_dict. See wan22_ti2v_5b_vae_state_dict_transform() for the Wan 2.2 TI2V 5B diffusers -> flashdreams remap.

class WanVAEEncoder(config: WanVAEEncoderConfig)[source]

Bases: StreamingVideoEncoder[WanVAECache]

Wan VAE encoder.

Forward input is a video tensor [..., T, C, H, W] in [-1, 1]; output is the latent [..., Tl, Cl, Hl, Wl]. The cache is advanced in-place across AR encode steps; passing cache=None allocates a fresh single-shot cache.

initialize_autoregressive_cache() WanVAECache[source]

Build a fresh per-rollout cache (a WanVAECache).

forward(input: Tensor, autoregressive_index: int = 0, cache: WanVAECache | None = None) Tensor[source]

Encode a video tensor [..., T, C, H, W] in [-1, 1] into a latent [..., Tl, Cl, Hl, Wl].

Pass the rollout's cache to advance it in place across AR encode steps; cache=None allocates a fresh single-shot cache.

Note

Call the module instance (encoder(input, ...)) rather than forward() directly: the instance call runs any registered hooks, while a direct forward() call silently skips them.

property temporal_compression_ratio: int

Pixel frames ÷ latent frames in steady state (AR ≥ 1).

AR 0 typically takes one extra (un-grouped) pixel frame to produce its first latent frame because of causal first-frame padding; that asymmetry lives inside get_output_temporal_size() / get_input_temporal_size().

property spatial_compression_ratio: int

Pixel side ÷ latent side. Constant across AR steps.

get_output_temporal_size(autoregressive_index: int, input_temporal_size: int) int[source]

Causal: at AR 0, a single leading (un-grouped) pixel frame produces the first latent frame; the remaining pixel frames are grouped temporal_compression_ratio at a time.

get_input_temporal_size(autoregressive_index: int, output_temporal_size: int) int[source]

Pixel frame count needed to produce output_temporal_size latents.

Inverse of get_output_temporal_size(). Implementations should assert output_temporal_size is achievable at this AR step (i.e. the corresponding pixel count comes out as a positive integer).

Parameters:
  • autoregressive_index – AR step index (0-based).

  • output_temporal_size – Desired number of latent frames.

Returns:

Number of pixel frames needed at this step.
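The encoder mirrors the decoder: at AR 0 one leading un-grouped pixel frame yields the first latent frame, and every later latent consumes temporal_compression_ratio pixel frames. A sketch assuming a ratio of 4, with hypothetical helpers pixels_needed (the get_input_temporal_size() direction) and latents_produced (the get_output_temporal_size() direction), plus the pixel budget of a short rollout:

```python
def pixels_needed(ar_index: int, latent_frames: int, ratio: int = 4) -> int:
    """Pixel frames needed to produce `latent_frames` latents at AR step `ar_index`."""
    assert latent_frames >= 1, "pixel count must come out positive"
    if ar_index == 0:
        # Causal first frame: 1 un-grouped pixel frame, then `ratio` per latent.
        return 1 + ratio * (latent_frames - 1)
    return ratio * latent_frames


def latents_produced(ar_index: int, pixel_frames: int, ratio: int = 4) -> int:
    """Latent frames produced from `pixel_frames` pixel frames."""
    grouped = pixel_frames - 1 if ar_index == 0 else pixel_frames
    assert grouped >= 0 and grouped % ratio == 0, "pixel count not achievable"
    return grouped // ratio + (1 if ar_index == 0 else 0)


# Pixel frames consumed by a 3-step rollout that yields 3 latents per step.
budget = [pixels_needed(i, 3) for i in range(3)]
print(budget)  # [9, 12, 12]
```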

class WanVAEEncoderConfig(*, _target: type[~flashdreams.recipes.wan.autoencoder.vae.WanVAEEncoder] = <factory>, checkpoint_path: str = 'https://huggingface.co/lightx2v/Autoencoders/resolve/main/Wan2.1_VAE.pth', dtype: ~torch.dtype = torch.bfloat16, use_cuda_graph: bool = True, use_compile: bool = False, base_dim: int = 96, z_dim: int = 16, patch_size: int = 1, is_residual: bool = False, latent_mean: tuple[float, ...] = (-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921), latent_std: tuple[float, ...] = (2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.916), state_dict_transform: ~typing.Callable[[dict[str, ~torch.Tensor]], dict[str, ~torch.Tensor]] | None = None)[source]

Bases: EncoderConfig

Config for the Wan VAE encoder.

Defaults reproduce the Wan 2.1 / 14B 8x-spatial 16-channel streaming VAE. Override base_dim / z_dim / patch_size / is_residual (and the latent mean/std) to load the Wan 2.2 TI2V 5B 16x-spatial 48-channel residual VAE; see Wan22TI2V5BVAEEncoderConfig for the pre-rolled set.
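Switching to the Wan 2.2 TI2V 5B layout means overriding several fields together, and latent_mean / latent_std must grow to z_dim entries. The sketch below uses a hypothetical stand-in dataclass (not WanVAEEncoderConfig itself) with a __post_init__ length check, to show why changing z_dim alone fails. Its statistics are placeholders; it deliberately does not supply real 48-channel values, which come from Wan22TI2V5BVAEEncoderConfig.

```python
import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)
class VAEArch:
    """Hypothetical stand-in for the architecture fields of WanVAEEncoderConfig."""
    base_dim: int = 96
    z_dim: int = 16
    patch_size: int = 1
    is_residual: bool = False
    latent_mean: tuple = (0.0,) * 16  # placeholder values, not the real Wan 2.1 stats
    latent_std: tuple = (1.0,) * 16   # placeholder values, not the real Wan 2.1 stats

    def __post_init__(self) -> None:
        # Per-channel statistics must line up with the latent channel count.
        if len(self.latent_mean) != self.z_dim or len(self.latent_std) != self.z_dim:
            raise ValueError(f"latent_mean/latent_std need {self.z_dim} entries")


wan21 = VAEArch()
try:
    # Wan 2.2 TI2V 5B architecture knobs, but the 16-channel stats left in place.
    dataclasses.replace(wan21, base_dim=160, z_dim=48, patch_size=2, is_residual=True)
except ValueError as err:
    print(err)  # latent_mean/latent_std need 48 entries
```

dataclasses.replace() re-runs __init__ and therefore __post_init__, so the check also covers overrides applied to an existing config.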

use_cuda_graph: bool = True

Wrap the encoder forward in a CUDA graph for replay.

use_compile: bool = False

torch.compile(mode="max-autotune-no-cudagraphs"). Off by default: Inductor autotune workspaces can add several GiB of transient VRAM per unique input shape, surfacing as ‘illegal memory access’ errors on smaller GPUs with the full-channel VAE checkpoint.

base_dim: int = 96

Encoder base channel count (WanVAE dim). 96 for Wan 2.1, 160 for Wan 2.2 TI2V 5B.

z_dim: int = 16

Latent channels. 16 for Wan 2.1, 48 for Wan 2.2 TI2V 5B.

patch_size: int = 1

Outer spatial pixel-shuffle factor (1 = no patchify; 2 for Wan 2.2 TI2V 5B).

is_residual: bool = False

Use ResidualDownBlock (Wan 2.2) instead of the legacy ResidualBlock + AttentionBlock down-stage (Wan 2.1).

latent_mean: tuple[float, ...] = (-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921)

Per-channel latent mean used for normalisation; must have z_dim entries.

latent_std: tuple[float, ...] = (2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.916)

Per-channel latent std used for normalisation.
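latent_mean and latent_std are per-channel statistics. A common convention, and to our understanding the one used by the upstream Wan VAE code, is z = (mu - mean) / std after encoding and the inverse z * std + mean before decoding; treat that as an assumption about flashdreams' exact formula. A stdlib sketch over one value per channel, using the first three Wan 2.1 defaults:

```python
MEAN = (-0.7571, -0.7089, -0.9113)  # first three Wan 2.1 latent_mean defaults
STD = (2.8184, 1.4541, 2.3275)      # first three Wan 2.1 latent_std defaults


def normalize(latent: list) -> list:
    """Standardize one value per channel (assumed convention, applied after encoding)."""
    assert len(latent) == len(MEAN), "one value per latent channel"
    return [(x - m) / s for x, m, s in zip(latent, MEAN, STD)]


def denormalize(latent: list) -> list:
    """Undo normalize() (assumed convention, applied before decoding)."""
    assert len(latent) == len(MEAN), "one value per latent channel"
    return [z * s + m for z, m, s in zip(latent, MEAN, STD)]


raw = [0.5, -1.0, 2.0]
print([round(v, 6) for v in denormalize(normalize(raw))])  # [0.5, -1.0, 2.0]
```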

state_dict_transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None

Optional key remap applied to the checkpoint state dict before load_state_dict (e.g. diffusers -> flashdreams layout). See wan22_ti2v_5b_vae_state_dict_transform() for the Wan 2.2 TI2V 5B remap.

wan22_ti2v_5b_vae_state_dict_transform(state_dict: Dict[str, Tensor]) Dict[str, Tensor][source]

Remap a diffusers AutoencoderKLWan state-dict to WanVAE keys.

Applied automatically when Wan22TI2V5BVAEEncoderConfig / Wan22TI2V5BVAEDecoderConfig load the upstream Wan-AI/Wan2.2-TI2V-5B-Diffusers/vae/diffusion_pytorch_model.safetensors checkpoint. The mapping is purely structural – no tensors are copied or reshaped.

Note

Patterns are applied in iteration order via flashdreams.core.checkpoint.remap.remap_checkpoint_keys(). Any key without a matching pattern passes through unchanged, which surfaces as a load_state_dict unexpected_keys warning so missing remap entries are easy to spot.

Unified Wan inference pipeline (Wan 2.1 / Wan 2.2, T2V and I2V).

class WanInferencePipeline(config: WanInferencePipelineConfig)[source]

Bases: StreamInferencePipeline[WanVAECache, Wan21TransformerCache | Wan22TransformerCache, WanVAECache]

Wan 2.1 / 2.2 inference pipeline, T2V and I2V.

T2V and I2V share the same rollout loop; the difference is whether you pass an image to initialize_cache. The pipeline config’s encoder slot must agree (None for T2V, an I2V config for I2V).

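Example:

import torch

# T2V: the config's encoder slot is None; pass the latent ``height`` and ``width``.
t2v_pipeline: WanInferencePipeline = ...
cache = t2v_pipeline.initialize_cache(
    text=["A cat surfing."], height=60, width=104
)
chunk = t2v_pipeline.generate(0, cache)  # [*batch_shape, T, C, H, W] in [-1, 1]
t2v_pipeline.finalize(0, cache)

# I2V: the config's encoder slot holds an I2V encoder config; pass the first
# frame and let the latent sizes derive from it.
i2v_pipeline: WanInferencePipeline = ...
first_frame: torch.Tensor = ...  # [*batch_shape, 1, 3, H, W] in [-1, 1]
i2v_cache = i2v_pipeline.initialize_cache(
    text=["A cat surfing."], image=first_frame
)
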
initialize_cache(text: list[str], image: Tensor | None = None, *, height: int | None = None, width: int | None = None, release_oneshot_encoders: bool = True) WanInferencePipelineCache[source]

Initialize the per-rollout cache for a batch of prompts.

Parameters:
  • text – One prompt per batch element. Length must match the transformer’s batch_shape.

  • image – First-frame pixels of shape [*batch_shape, 1, 3, H, W] in [-1, 1]. Required for I2V (self.encoder is set), forbidden for T2V. H and W must equal height * decoder.spatial_compression_ratio and width * decoder.spatial_compression_ratio, respectively.

  • height – Pre-patchify latent height (post-VAE). Optional for I2V — derived from image when omitted; required for T2V.

  • width – Pre-patchify latent width (post-VAE). Same rules as height.

  • release_oneshot_encoders – Free the text and image encoders after the cache is initialized. Later calls reload them from self.config before encoding new prompts/images.

Returns:

Cache to thread through generate / finalize.
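For I2V, the latent height and width follow from the image's pixel size divided by decoder.spatial_compression_ratio, and the image must divide evenly. A sketch of that arithmetic with a hypothetical latent_size helper, assuming a ratio of 8 (the Wan 2.1 VAE; the config notes describe the Wan 2.2 TI2V 5B VAE as 16x-spatial):

```python
def latent_size(pixel_height: int, pixel_width: int, spatial_ratio: int = 8) -> tuple:
    """Derive pre-patchify latent (height, width) from first-frame pixel size."""
    if pixel_height % spatial_ratio or pixel_width % spatial_ratio:
        raise ValueError(
            f"image {pixel_height}x{pixel_width} is not divisible by {spatial_ratio}"
        )
    return pixel_height // spatial_ratio, pixel_width // spatial_ratio


print(latent_size(480, 832))  # (60, 104), the T2V example's height/width
```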

release_oneshot_encoders() None[source]

Free the per-rollout text and first-frame image encoders.

Idempotent. A later initialize_cache() call can reload the encoders from self.config before encoding raw prompts/images.

generate(autoregressive_index: int, cache: WanInferencePipelineCache) Tensor[source]

Generate one decoded video chunk.

Parameters:
  • autoregressive_index – AR step index, starting at 0.

  • cache – Per-rollout cache from initialize_cache.

Returns:

Decoded video of shape [*batch_shape, T, C, H, W] in [-1, 1].

get_num_input_frames(autoregressive_index: int) int[source]

Number of input video frames the model expects at this AR step.

get_num_output_frames(autoregressive_index: int) int[source]

Number of decoded video frames produced at this AR step.

TAEHV

TAEHV video decoder.

AVAILABLE_TAEHV_CHECKPOINT_PATHS = {'lighttae': 'https://huggingface.co/lightx2v/Autoencoders/resolve/main/lighttaew2_1.pth'}

Checkpoint paths for the TAEHV decoder.

lighttae_state_dict_transform(sd: Mapping[str, Tensor]) dict[str, Tensor]

Per-checkpoint remap for the lighttae weights. Rewrites the flat decoder.<i>.* keys to the current decoder.blocks.<i>.* layout and clips the stride=2 TGrow weights at idx 7 down to the stride=1 slice the live model expects.
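The key-rename half of that transform moves flat decoder.<i>.* keys under decoder.blocks.<i>.*. A regex sketch of just that rename, with plain strings standing in for tensors; rename_decoder_keys is hypothetical, and the TGrow weight slicing at index 7 is omitted because its exact slice is not described here.

```python
import re

_FLAT_DECODER_KEY = re.compile(r"^decoder\.(\d+)\.")


def rename_decoder_keys(state_dict: dict) -> dict:
    """Rewrite decoder.<i>.* keys to decoder.blocks.<i>.*; leave other keys unchanged."""
    return {
        _FLAT_DECODER_KEY.sub(r"decoder.blocks.\1.", key): value
        for key, value in state_dict.items()
    }


sd = {"decoder.7.conv.weight": "w", "decoder.blocks.2.bias": "b", "encoder.0.weight": "e"}
print(sorted(rename_decoder_keys(sd)))
# ['decoder.blocks.2.bias', 'decoder.blocks.7.conv.weight', 'encoder.0.weight']
```

Because the pattern requires a digit right after decoder., keys already in the decoder.blocks layout pass through untouched, so applying the rename twice is harmless.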

class TeahvVAEDecoderConfig(*, _target: type[~flashdreams.recipes.taehv.TeahvVAEDecoder] = <factory>, checkpoint_path: str = 'https://huggingface.co/lightx2v/Autoencoders/resolve/main/lighttaew2_1.pth', state_dict_transform: ~collections.abc.Callable[[~collections.abc.Mapping[str, ~torch.Tensor]], dict[str, ~torch.Tensor]] | None = <function compose.<locals>.composed>, dtype: ~torch.dtype = torch.bfloat16, use_cuda_graph: bool = True, use_compile: bool = True)[source]

Bases: DecoderConfig

Config for the TAEHV decoder.

checkpoint_path: str = 'https://huggingface.co/lightx2v/Autoencoders/resolve/main/lighttaew2_1.pth'

Path to a pretrained TAEHV checkpoint. Defaults to the lighttae weights.

state_dict_transform: Callable[[Mapping[str, Tensor]], dict[str, Tensor]] | None

Pre-load state-dict remap. Defaults to lighttae_state_dict_transform; None falls through to the bare TAEHV default (see load_from_checkpoint()).

dtype: dtype = torch.bfloat16

Network parameter / activation dtype.

use_cuda_graph: bool = True

Wrap the decoder forward in a CUDA graph for replay.

use_compile: bool = True

torch.compile(mode="max-autotune-no-cudagraphs").

class TeahvVAEDecoder(config: TeahvVAEDecoderConfig)[source]

Bases: StreamingVideoDecoder[TAEHVCache]

TAEHV (Tiny AutoEncoder for Hunyuan Video) decoder.

Forward input is a latent [..., Tl, Cl, Hl, Wl]; output is a video tensor [..., T, C, H, W] in [-1, 1].

Set torch.backends.cudnn.benchmark = True at process start for a roughly 5% speedup on the seed and tail chunks, which run eagerly.

mean: Tensor

Per-channel latent mean buffer; registered only when need_scaled.

std: Tensor

Per-channel latent standard deviation buffer; registered only when need_scaled.

initialize_autoregressive_cache() TAEHVCache[source]

Return an empty streaming decoder cache.

forward(input: Tensor, autoregressive_index: int = 0, cache: TAEHVCache | None = None) Tensor[source]

Decode a latent chunk to a video tensor in [-1, 1].

Parameters:
  • input – Latent of shape [..., Tl, Cl, Hl, Wl].

  • autoregressive_index – Unused by TAEHV; kept for the StreamingDecoder interface.

  • cache – Streaming decoder cache; created on the fly when None.

Returns:

Video tensor of shape [..., T, C, H, W] in [-1, 1].

property temporal_compression_ratio: int

Pixel frames / latent frames in steady state (AR >= 1).

property spatial_compression_ratio: int

Pixel side / latent side.

get_output_temporal_size(autoregressive_index: int, input_temporal_size: int) int[source]

Return the number of pixel frames decoded from input_temporal_size latent frames.

AR 0 applies causal padding: the first latent frame yields one pixel frame, remaining frames yield temporal_compression_ratio each. AR >= 1 is a plain multiply.

get_input_temporal_size(autoregressive_index: int, output_temporal_size: int) int[source]

Return latent frame count needed to produce output_temporal_size pixel frames.

Inverse of get_output_temporal_size().
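Following the rule above, the total pixel frames of a TAEHV rollout follow directly from the per-step latent counts. A sketch assuming temporal_compression_ratio = 4 (check the property for the real value) and a hypothetical schedule of 3 latent frames per AR step:

```python
from itertools import accumulate


def taehv_output_frames(ar_index: int, latent_frames: int, ratio: int = 4) -> int:
    """Pixel frames from `latent_frames` latents, per the AR 0 causal rule."""
    if ar_index == 0:
        return 1 + ratio * (latent_frames - 1)
    return ratio * latent_frames


latents_per_step = [3, 3, 3, 3]  # hypothetical schedule
per_step = [taehv_output_frames(i, n) for i, n in enumerate(latents_per_step)]
print(per_step)                    # [9, 12, 12, 12]
print(list(accumulate(per_step)))  # [9, 21, 33, 45]
```

The running total equals what a single causal pass over all 12 latents would give (1 + 4 * 11 = 45), so chunking the decode does not change the total frame count.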