Pipelines and runners
FlashDreams model integrations are built from two public layers:
- Pipelines (StreamInferencePipelineConfig) that define model behavior.
- Runners (RunnerConfig + Runner) that define CLI-facing I/O.
Most actively developed model implementations now live under integrations/*
as plugin-style standalone packages. This page keeps documenting the in-tree
pipeline modules that are still exposed from flashdreams.recipes.
Note
Pipeline modules import the heavy GPU stack (transformer-engine, CUDA
ops) at import time, so this page shows them by automodule with
:no-undoc-members: to keep the rendered API focused on the names
that these in-tree modules actually expose. The unified flashdreams-run
CLI shows end-to-end usage; see Models for model launch
examples.
Integration structure (current)
For new model work, follow integrations/<name>/:
- config.py: pipeline + runner config literals (slugged entries).
- runner.py: runtime I/O, cache init, generate/finalize loop, persistence.
- pipeline.py and transformer/*: model compute path.
- pyproject.toml: plugin packaging + entry-point registration.
This makes each integration effectively a standalone repository while still
plugging into the same flashdreams-run registry.
Reference integration folders
NVIDIA OmniDreams
OmniDreams now ships as a plugin under integrations/omnidreams; it
registers its runners via the flashdreams.runner_configs entry-point
group and is no longer part of the in-tree flashdreams.recipes API
surface. See integrations/omnidreams/README.md for the plugin entry
point and flashdreams-run omnidreams-* for the user-facing CLI.
Wan
Public Wan integration surface for integration plugins.
- class Wan21Transformer(config: Wan21TransformerConfig)
Bases: Transformer[Wan21TransformerCache]
Wan 2.1 DiT adapted to the infra Transformer interface.
- property latent_shape: tuple[int, ...]
Per-rank post-patchify latent shape [*batch_shape, L/cp, D].
Wan flattens THW into one token axis and shards across the THW CP group. D = network.out_dim * prod(patch_size) is the noise channel count; the mask / image-latent channels added by concat_image_mask_to_latent come from input in predict_flow, not from the noise tensor.
Per-rollout (height, width) is populated by initialize_autoregressive_cache(); reading the property before that call fails an assertion.
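The shape arithmetic above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: the function name per_rank_latent_shape is hypothetical, the temporal extent is assumed to be len_t, and the token axis is assumed to split evenly across CP ranks.

```python
from math import prod


def per_rank_latent_shape(
    batch_shape: tuple[int, ...],
    len_t: int,
    height: int,
    width: int,
    patch_size: tuple[int, int, int],
    out_dim: int,
    cp_size: int = 1,
) -> tuple[int, ...]:
    """Compute [*batch_shape, L/cp, D] from pre-patchify latent sizes."""
    kt, kh, kw = patch_size
    # Every latent axis must be divisible by its patch size.
    assert len_t % kt == 0 and height % kh == 0 and width % kw == 0
    tokens = (len_t // kt) * (height // kh) * (width // kw)  # L
    # Assumption: the flattened token axis splits evenly across CP ranks.
    assert tokens % cp_size == 0, "token count must be divisible by cp_size"
    channels = out_dim * prod(patch_size)  # D, the noise channel count
    return (*batch_shape, tokens // cp_size, channels)


# 21 latent frames at 60x104 with patch (1, 2, 2) and out_dim=16.
print(per_rank_latent_shape((1,), 21, 60, 104, (1, 2, 2), 16))     # (1, 32760, 64)
print(per_rank_latent_shape((1,), 21, 60, 104, (1, 2, 2), 16, 4))  # (1, 8190, 64)
```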
- initialize_autoregressive_cache(*, height: int, width: int, text_embeddings: Tensor, image_embeddings: Tensor | None = None, negative_text_embeddings: Tensor | None = None, **_unused: Any) → Wan21TransformerCache
Build a seeded transformer cache for a new rollout.
I2V state is not baked into the cache; the latent + injection mask are passed per AR step as the input argument to predict_flow / postprocess_clean_latent.
- Parameters:
height – Pre-patchify latent height (post-VAE).
width – Pre-patchify latent width (post-VAE).
text_embeddings – Conditional UMT5 embeddings [..., text_len, text_dim].
image_embeddings – Conditional CLIP image embeddings (only used by networks with cross_attn_enable_img=True). Shared with the uncond branch.
negative_text_embeddings – Negative-prompt embeddings. Required iff config.guidance_scale > 1.0; must be None otherwise.
- Returns:
Populated cache. network_cache_uncond is None iff CFG is disabled.
- predict_flow(noisy_latent: Tensor, timestep: Tensor, cache: Wan21TransformerCache, input: I2VCtrl | None = None, network_extra_kwargs: dict[str, Any] | None = None) → Tensor
Predict the flow for one denoising step.
timestep may be a scalar / per-batch tensor (standard Wan 2.1 / 14B path) or a per-token tensor with the same trailing token axis as noisy_latent (Wan 2.2 TI2V 5B first-frame seeding at AR step 0). The per-token layout flows through WanDiTNetwork.forward(), which dispatches the sinusoidal embedding + AdaLN modulation on the native shape.
CUDA-graph capture is shape-sensitive: the captured replay region only sees AR step >= self._cuda_graph_capture_ar_idx (steady state). TI2V 5B is configured with len_t == window_size_t so the threshold lands at AR 1, putting the per-token AR-0 step inside the eager .drain branch where shape changes are safe. After AR 0 the pipeline switches back to scalar timesteps, so the captured branch sees a single stable shape across all AR steps it owns.
- postprocess_clean_latent(clean_latent: Tensor, cache: Wan21TransformerCache, input: I2VCtrl | None = None) → Tensor
Re-stamp x0 masked positions with the image latent (mask-inject I2V only).
T2V and the channel-concat I2V mode fall through unchanged.
- patchify_and_maybe_split_cp(x: Tensor) → Tensor
- patchify_and_maybe_split_cp(x: I2VCtrl) → I2VCtrl
Patchify and CP-split a noisy latent or an I2V control payload.
Tensors delegate to the network helper; I2V payloads patchify the latent and mask fields independently so the per-field channel layouts are preserved for the mask-inject blend downstream.
- class Wan21TransformerConfig(*, _target: type[~flashdreams.recipes.wan.transformer.wan21.Wan21Transformer] = <factory>, network: ~flashdreams.recipes.wan.transformer.impl.network.WanDiTNetworkConfig = <factory>, dtype: ~torch.dtype = torch.bfloat16, checkpoint_path: str | None = None, state_dict_transform: ~collections.abc.Callable[[dict[str, ~torch.Tensor]], dict[str, ~torch.Tensor]] | None = None, batch_shape: tuple[int, ...] = (1,), len_t: int = 21, guidance_scale: float = 1.0, window_size_t: int = 21, sink_size_t: int = 0, h_extrapolation_ratio: float = 1.0, w_extrapolation_ratio: float = 1.0, compile_network: bool = True, use_cuda_graph: bool = True, cuda_graph_warmup_iters: int = 2, stamp_image_latent: bool = False, concat_image_mask_to_latent: bool = False, ti2v_first_frame_per_token_timestep: bool = False, first_frame_timestep_value: float = 0.0)[source]¶
Bases: TransformerConfig
Config for the Wan 2.1 transformer.
Bakes in the temporal layout (len_t, window_size_t, optional sink_size_t) and the CFG / compile knobs. Per-rollout spatial layout (height, width) is supplied to Wan21Transformer.initialize_autoregressive_cache() so one instance can serve multiple resolutions. Wan flattens T*H*W into one token axis and shards it across the THW CP group; the CP size is auto-detected from torch.distributed.get_world_size() at construction time, so the launcher (torchrun --nproc_per_node=N) is the single source of truth.
The two I2V flags are independent and composable:
- stamp_image_latent: overwrite the noisy latent with the clean image latent at masked positions every denoising step, and re-stamp the predicted x0 the same way. network.in_dim is unchanged. (flashdreams mask-inject integration; used by the out-of-tree causal_forcing plugin.)
- concat_image_mask_to_latent: append the 4-channel mask and 16-channel image latent along the channel dim. Builders that set this flag must also set network.in_dim = 16 + 4 + 16 to match the official Wan 2.1 14B I2V layout.
With both enabled, the stamp runs first and the result is then concatenated with the mask + image latent.
- state_dict_transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None
Pre-load state-dict remap (e.g. Self-Forcing’s generator_ema.model.… layout).
- guidance_scale: float = 1.0
CFG scale s: flow = uncond + s * (cond - uncond). 1.0 disables CFG; > 1.0 requires negative-text embeddings at cache build time.
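A minimal sketch of the CFG rule, using plain Python lists in place of tensors. Both function names are hypothetical; the validation mirrors the documented contract that negative-text embeddings are required iff the scale is above 1.0 and must be None otherwise.

```python
from typing import Optional, Sequence


def check_negative_embeddings(guidance_scale: float, negative: Optional[object]) -> None:
    """Enforce: negative embeddings are required iff guidance_scale > 1.0."""
    cfg_enabled = guidance_scale > 1.0
    if cfg_enabled and negative is None:
        raise ValueError("guidance_scale > 1.0 requires negative-text embeddings")
    if not cfg_enabled and negative is not None:
        raise ValueError("negative-text embeddings must be None when CFG is disabled")


def apply_cfg(cond: Sequence[float], uncond: Sequence[float], scale: float) -> list[float]:
    """Combine the two branches: flow = uncond + s * (cond - uncond)."""
    assert len(cond) == len(uncond)
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


cond, uncond = [1.0, 2.0], [0.0, 1.0]
print(apply_cfg(cond, uncond, 1.0))  # [1.0, 2.0], identical to cond
print(apply_cfg(cond, uncond, 3.0))  # [3.0, 4.0]
```

At s = 1.0 the formula reduces to the conditional flow, which is why the unconditional branch can be skipped entirely in that case.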
- use_cuda_graph: bool = True
Wrap the network in CUDAGraphWrapper for steady-state replay. Caller must keep non-staged inputs at stable storage addresses across calls. predict_flow dispatches to wrapper.drain while the KV cache is still filling and to wrapper once it reaches steady state.
- ti2v_first_frame_per_token_timestep: bool = False
Wan 2.2 TI2V 5B first-frame conditioning. When True and an I2VCtrl input is provided at AR step 0, predict_flow rewrites the scheduler’s scalar timestep into a per-token tensor: t = first_frame_timestep_value at positions marked by the I2V mask (i.e. the first-frame latent), and the scheduler’s t elsewhere. AR steps >= 1 continue to use the scalar timestep, which keeps the CUDA-graph-captured replay branch on a single stable input shape.
Composes with stamp_image_latent: together they implement the upstream Wan 2.2 5B “VAE-seeded first-frame + per-token t=0” TI2V recipe, in which the latent is stamped clean every denoising step while the network sees t=0 for those tokens. The standard mask-inject I2V recipe leaves this flag off and relies on the classifier-free stamp alone.
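The per-token rewrite at AR step 0 can be sketched as follows, with a flat list of booleans standing in for the patchified I2V mask. The function name per_token_timesteps is hypothetical, and the real code operates on tensors rather than Python lists.

```python
def per_token_timesteps(
    scheduler_t: float,
    first_frame_mask: list[bool],
    first_frame_timestep_value: float = 0.0,
) -> list[float]:
    """Expand a scalar timestep into one value per token.

    Tokens marked by the mask (the first-frame latent) get
    first_frame_timestep_value; every other token keeps the scheduler's t.
    """
    return [
        first_frame_timestep_value if is_first_frame else scheduler_t
        for is_first_frame in first_frame_mask
    ]


# Two first-frame tokens followed by two ordinary tokens.
mask = [True, True, False, False]
print(per_token_timesteps(750.0, mask))  # [0.0, 0.0, 750.0, 750.0]
```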
- first_frame_timestep_value: float = 0.0
Per-token timestep assigned to first-frame conditioning tokens when ti2v_first_frame_per_token_timestep is True.
Defaults to 0.0 (the Wan 2.2 TI2V 5B base recipe, which makes AdaLN treat the first frame as fully clean). HY-WorldPlay’s distilled WAN-5B raises it to 14.0 (vendor’s stabilization_level - 1) so the AdaLN table sees a small nonzero sigma at the first frame.
Unused when ti2v_first_frame_per_token_timestep is False.
- class Wan22TI2V5BVAEDecoderConfig(*, _target: type[~flashdreams.recipes.wan.autoencoder.vae.WanVAEDecoder] = <factory>, checkpoint_path: str = 'https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers/resolve/main/vae/diffusion_pytorch_model.safetensors', dtype: ~torch.dtype = torch.bfloat16, use_cuda_graph: bool = True, use_compile: bool = False, base_dim: int = 160, decoder_base_dim: int | None = 256, z_dim: int = 48, patch_size: int = 2, is_residual: bool = True, latent_mean: tuple[float, ...] = (-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, -0.1382, 0.0542, 0.2813, 0.0891, 0.157, -0.0098, 0.0375, -0.1825, -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.123, -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.052, 0.3748, 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667), latent_std: tuple[float, ...] = (0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.499, 0.4818, 0.5013, 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, 0.3971, 1.06, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744), state_dict_transform: ~typing.Callable[[dict[str, ~torch.Tensor]], dict[str, ~torch.Tensor]] | None = <function wan22_ti2v_5b_vae_state_dict_transform>)[source]¶
Bases: WanVAEDecoderConfig
Pre-rolled config for the Wan 2.2 TI2V 5B decoder.
Mirrors Wan22TI2V5BVAEEncoderConfig but with the asymmetric decoder_base_dim=256.
- decoder_base_dim: int | None = 256
Decoder base channel count. None mirrors base_dim (Wan 2.1). Wan 2.2 TI2V 5B uses an asymmetric 256.
- state_dict_transform() → Dict[str, Tensor]
Remap a diffusers AutoencoderKLWan state-dict to WanVAE keys.
Applied automatically when Wan22TI2V5BVAEEncoderConfig / Wan22TI2V5BVAEDecoderConfig load the upstream Wan-AI/Wan2.2-TI2V-5B-Diffusers/vae/diffusion_pytorch_model.safetensors checkpoint. The mapping is purely structural: no tensors are copied or reshaped.
Note
Patterns are applied in iteration order via flashdreams.core.checkpoint.remap.remap_checkpoint_keys(). Any key without a matching pattern passes through unchanged, which surfaces as a load_state_dict unexpected_keys warning so missing remap entries are easy to spot.
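To illustrate the behaviour described in the note, here is a small stand-alone sketch of an ordered, regex-based key remap. It does not call remap_checkpoint_keys(), whose signature is not shown on this page; the helper remap_keys and both patterns are hypothetical, and the sketch assumes every pattern is tried against each key in list order.

```python
import re

# Hypothetical patterns for illustration only; the real remap table lives in
# the FlashDreams source and is not reproduced here.
EXAMPLE_PATTERNS: list[tuple[str, str]] = [
    (r"^encoder\.conv_in\.", "encoder.conv1."),
    (r"^encoder\.down_blocks\.(\d+)\.", r"encoder.downsamples.\1."),
]


def remap_keys(
    state_dict: dict[str, object],
    patterns: list[tuple[str, str]],
) -> dict[str, object]:
    """Rename keys by applying each (pattern, replacement) pair in order.

    Values are passed through by reference, so nothing is copied or reshaped.
    Keys that match no pattern keep their original name.
    """
    remapped: dict[str, object] = {}
    for key, value in state_dict.items():
        new_key = key
        for pattern, replacement in patterns:
            new_key = re.sub(pattern, replacement, new_key)
        remapped[new_key] = value
    return remapped


weights = {
    "encoder.conv_in.weight": object(),
    "encoder.down_blocks.0.bias": object(),
    "quant_conv.weight": object(),  # no pattern matches: passes through
}
print(sorted(remap_keys(weights, EXAMPLE_PATTERNS)))
```

The untouched quant_conv.weight key is the kind of entry that would later show up as an unexpected key when loading, which makes a missing pattern easy to notice.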
- class Wan22TI2V5BVAEEncoderConfig(*, _target: type[~flashdreams.recipes.wan.autoencoder.vae.WanVAEEncoder] = <factory>, checkpoint_path: str = 'https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers/resolve/main/vae/diffusion_pytorch_model.safetensors', dtype: ~torch.dtype = torch.bfloat16, use_cuda_graph: bool = True, use_compile: bool = False, base_dim: int = 160, z_dim: int = 48, patch_size: int = 2, is_residual: bool = True, latent_mean: tuple[float, ...] = (-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, -0.1382, 0.0542, 0.2813, 0.0891, 0.157, -0.0098, 0.0375, -0.1825, -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.123, -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.052, 0.3748, 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667), latent_std: tuple[float, ...] = (0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.499, 0.4818, 0.5013, 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, 0.3971, 1.06, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744), state_dict_transform: ~typing.Callable[[dict[str, ~torch.Tensor]], dict[str, ~torch.Tensor]] | None = <function wan22_ti2v_5b_vae_state_dict_transform>)[source]¶
Bases: WanVAEEncoderConfig
Pre-rolled config for the Wan 2.2 TI2V 5B encoder.
Pins the diffusers upstream checkpoint, the 16x-spatial / 48ch / residual / patchify architecture knobs, and the matching diffusers -> flashdreams key remap. Equivalent to the Wan 2.1 encoder config plus the 5B-specific knobs flipped on.
- base_dim: int = 160
Encoder base channel count (WanVAE dim). 96 for Wan 2.1, 160 for Wan 2.2 TI2V 5B.
- is_residual: bool = True
Use ResidualDownBlock (Wan 2.2) instead of the legacy ResidualBlock + AttentionBlock down-stage (Wan 2.1).
- latent_mean: tuple[float, ...] = (-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, -0.1382, 0.0542, 0.2813, 0.0891, 0.157, -0.0098, 0.0375, -0.1825, -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.123, -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.052, 0.3748, 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667)¶
Per-channel latent mean used for normalisation; its length must match z_dim.
- latent_std: tuple[float, ...] = (0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.499, 0.4818, 0.5013, 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, 0.3971, 1.06, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744)¶
Per-channel latent std used for normalisation.
- state_dict_transform() → Dict[str, Tensor]
Remap a diffusers AutoencoderKLWan state-dict to WanVAE keys.
Applied automatically when Wan22TI2V5BVAEEncoderConfig / Wan22TI2V5BVAEDecoderConfig load the upstream Wan-AI/Wan2.2-TI2V-5B-Diffusers/vae/diffusion_pytorch_model.safetensors checkpoint. The mapping is purely structural: no tensors are copied or reshaped.
Note
Patterns are applied in iteration order via flashdreams.core.checkpoint.remap.remap_checkpoint_keys(). Any key without a matching pattern passes through unchanged, which surfaces as a load_state_dict unexpected_keys warning so missing remap entries are easy to spot.
- class Wan22Transformer(config: Wan22TransformerConfig)
Bases: Transformer[Wan22TransformerCache]
Wan 2.2 dual-network DiT.
predict_flow dispatches to the branch selected by the timestep; finalize_kv_cache re-runs both branches once at the context noise so neither KV cache lags behind.
- property latent_shape: tuple[int, ...]
Per-rank latent shape (both branches share this, asserted by config).
- initialize_autoregressive_cache(*, height: int, width: int, text_embeddings: Tensor, image_embeddings: Tensor | None = None, **_unused: Any) → Wan22TransformerCache
Build a seeded transformer cache for a new rollout.
Both branches see the same text/image conditioning and the same per-rollout spatial layout, matching upstream Wan 2.2. CFG is not supported here.
- predict_flow(noisy_latent: Tensor, timestep: Tensor, cache: Wan22TransformerCache, input: Any = None) → Tensor
Predict the flow at timestep.
- Parameters:
noisy_latent – Patchified noisy latent for this denoising step.
timestep – Scalar timestep tensor.
cache – Per-rollout AR cache.
input – Patchified encoder output for this AR step, or None when the pipeline has no encoder. Subclasses should narrow the type to their encoder’s output type.
- Returns:
Predicted flow tensor with the same shape as noisy_latent.
- finalize_kv_cache(noisy_latent: Tensor, timestep: Tensor, cache: Wan22TransformerCache, input: Any = None) → None
Refresh both networks’ KV caches at the context-noise step.
Each Wan 2.2 denoising step only touches one branch, so the other lags by the end of the loop. Re-running both at context_noise keeps them in lock-step; flow outputs are discarded.
- patchify_and_maybe_split_cp(x: Tensor) → Tensor
Patchify and (optionally) CP-split a noisy latent or encoder payload.
Tensors patchify and split. Structured payloads (e.g. an image-control struct with latent + mask) patchify each tensor field and return the same struct type. Implement as identity when neither token packing nor CP sharding applies. Output preserves the input Python type; only shapes change.
- class Wan22TransformerConfig(*, _target: type[Wan22Transformer] = <factory>, transformer_high_noise: Wan21TransformerConfig = <factory>, transformer_low_noise: Wan21TransformerConfig = <factory>, boundary_ratio: float = 0.875, num_train_timesteps: int = 1000)
Bases: TransformerConfig
Config for the Wan 2.2 dual-network transformer.
Wan 2.2 dispatches to one of two Wan 2.1 networks based on whether the timestep is above or below boundary_ratio * num_train_timesteps. Both branches must agree on patch_size / in_dim / dim / batch_shape / len_t / guidance_scale (asserted in __post_init__). The per-rollout spatial layout (height, width) is supplied to Wan22Transformer.initialize_autoregressive_cache() and forwarded to both branches. The CP size is auto-detected from torch.distributed.get_world_size(), same as Wan 2.1. Wan 2.2 has no CFG or I2V support here.
- transformer_high_noise: Wan21TransformerConfig
Sub-config for the high-noise branch (timestep > boundary).
- transformer_low_noise: Wan21TransformerConfig
Sub-config for the low-noise branch (timestep <= boundary).
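The branch dispatch can be sketched with the default config values. The function name select_branch and its string return values are hypothetical; only the comparison against boundary_ratio * num_train_timesteps comes from the description above.

```python
def select_branch(
    timestep: float,
    boundary_ratio: float = 0.875,
    num_train_timesteps: int = 1000,
) -> str:
    """Pick the Wan 2.2 branch for one denoising step.

    Strictly above the boundary goes to the high-noise network; at or
    below the boundary goes to the low-noise network.
    """
    boundary = boundary_ratio * num_train_timesteps
    return "high_noise" if timestep > boundary else "low_noise"


# With the defaults the boundary sits at timestep 875.
print(select_branch(900.0))  # high_noise
print(select_branch(875.0))  # low_noise
print(select_branch(200.0))  # low_noise
```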
- class WanDiTNetwork(config: WanDiTNetworkConfig)
Bases: Module
WAN diffusion backbone for text-to-video and image-to-video.
- set_context_parallel_group(cp_group: ProcessGroup | None = None) → None
Set context-parallel process group for all blocks.
This must be called before initialize_cache when CP is used.
- patchify_and_maybe_split_cp(x: Tensor, process_groups: list[ProcessGroup | None] | None = None, cp_dims: list[int | None] | None = None) → Tensor
Patchify and optionally CP-split the input video tensor.
The patchify pattern is ... (t kt) c (h kh) (w kw) -> ... (t h w) (c kt kh kw).
- Returns:
Patched tensor with shape [..., L, D] where L = T * H * W / (kt * kh * kw).
- unpatchify_and_maybe_gather_cp(pH: int, pW: int, x: Tensor, process_groups: list[ProcessGroup | None] | None = None, cp_dims: list[int | None] | None = None) → Tensor
Unpatchify and optionally CP-gather the tensor back to video shape.
The unpatchify pattern is ... (t h w) (c kt kh kw) -> ... (t kt) c (h kh) (w kw).
- Returns:
Unpatched tensor with shape [..., T, C, H, W].
- initialize_cache(chunk_size: int, window_size: int, sink_size: int, text_embeddings: Tensor, img_embeddings: Tensor | None = None) → WanDiTNetworkCache
Initialize block caches from text/image context embeddings.
- Parameters:
chunk_size – Number of tokens appended per self-attention update.
window_size – Rolling-window size in tokens for self-attention cache.
sink_size – Sink-token capacity preserved across updates.
text_embeddings – Text embeddings. UMT5 has shape [..., 512, 4096].
img_embeddings – Optional image embeddings for I2V. CLIP has shape [..., 256, 1280].
- Returns:
WanDiTNetworkCache containing per-block caches.
- update_parameters_after_loading_checkpoint() → None
Fuse load-time-known ops into weights; call once after loading the checkpoint.
- forward(x: Tensor, timesteps: Tensor, cache: WanDiTNetworkCache, rope_freqs: Tensor, current_chunk_idx: int = 0, eager_mode: bool = True, block_extra_kwargs: dict[str, Any] = {}) → Tensor
Run one denoising forward pass.
- Parameters:
x – Input tokens after patchify + CP, shape [..., L, D_in]; layout "... (t h w) (c kt kh kw)".
timesteps – Diffusion timesteps. Two layouts are supported:
  - Scalar (per-batch). Shape broadcastable to [...] (i.e., to x.shape[:-2]). The same timestep is shared across every token, matching the standard Wan 2.1 / Wan 2.2 14B chunked-denoise path.
  - Per-token. Shape [..., L] matching x’s post-patchify token axis. Used by Wan 2.2 TI2V 5B at AR step 0 to stamp t=0 at the first-frame conditioning tokens while the rest of the chunk denoises at the current scheduler step. See Wan21Transformer.predict_flow for the higher-level entry point.
cache – Network KV caches.
rope_freqs – Full-width RoPE frequencies after CP. Standard mode uses current-chunk frequencies with shape [L, 1, 1, d]; KV-cache-relative mode uses frequencies relative to the KV cache.
current_chunk_idx – Current chunk index for streaming cache update.
eager_mode – If True, run cache before/after update hooks.
block_extra_kwargs – Extra kwargs forwarded to each block.
- Returns:
Network output, shape [..., L, prod(patch_size) * out_dim].
- class WanDiTNetwork1pt3BConfig(_target: type[WanDiTNetwork] = <factory>, patch_size: tuple[int, int, int]=(1, 2, 2), text_len: int = 512, in_dim: int = 16, dim: int = 1536, ffn_dim: int = 8960, freq_dim: int = 256, text_dim: int = 4096, out_dim: int = 16, num_heads: int = 12, num_layers: int = 30, cross_attn_norm: bool = True, cross_attn_enable_img: bool = False, eps: float = 1e-06, concat_padding_mask: bool = False, patch_embedding_type: Literal['linear', 'conv3d']='conv3d', apply_rope_before_kvcache: bool = True, cp_method: Literal['ring', 'ulysses']='ring')[source]¶
Bases: WanDiTNetworkConfig
Configuration for the 1.3B Wan DiT network.
- class WanDiTNetwork14BConfig(_target: type[WanDiTNetwork] = <factory>, patch_size: tuple[int, int, int]=(1, 2, 2), text_len: int = 512, in_dim: int = 16, dim: int = 5120, ffn_dim: int = 13824, freq_dim: int = 256, text_dim: int = 4096, out_dim: int = 16, num_heads: int = 40, num_layers: int = 40, cross_attn_norm: bool = True, cross_attn_enable_img: bool = False, eps: float = 1e-06, concat_padding_mask: bool = False, patch_embedding_type: Literal['linear', 'conv3d']='conv3d', apply_rope_before_kvcache: bool = True, cp_method: Literal['ring', 'ulysses']='ring')[source]¶
Bases: WanDiTNetworkConfig
Configuration for the 14B Wan DiT network.
- class WanDiTNetworkConfig(_target: type[WanDiTNetwork] = <factory>, patch_size: tuple[int, int, int]=(1, 2, 2), text_len: int = 512, in_dim: int = 16, dim: int = 1536, ffn_dim: int = 8960, freq_dim: int = 256, text_dim: int = 4096, out_dim: int = 16, num_heads: int = 12, num_layers: int = 30, cross_attn_norm: bool = True, cross_attn_enable_img: bool = False, eps: float = 1e-06, concat_padding_mask: bool = False, patch_embedding_type: Literal['linear', 'conv3d']='conv3d', apply_rope_before_kvcache: bool = True, cp_method: Literal['ring', 'ulysses']='ring')[source]¶
Bases: InstantiateConfig
Configuration for the Wan DiT network.
- cross_attn_enable_img: bool = False
If True, build image cross-attention and CLIP image projection (I2V).
- patch_embedding_type: Literal['linear', 'conv3d'] = 'conv3d'
Type of patch embedding: "linear" (flattened patch MLP) or "conv3d" (strided conv).
- class WanDiTNetworkTI2V5BConfig(_target: type[WanDiTNetwork] = <factory>, patch_size: tuple[int, int, int]=(1, 2, 2), text_len: int = 512, in_dim: int = 48, dim: int = 3072, ffn_dim: int = 14336, freq_dim: int = 256, text_dim: int = 4096, out_dim: int = 48, num_heads: int = 24, num_layers: int = 30, cross_attn_norm: bool = True, cross_attn_enable_img: bool = False, eps: float = 1e-06, concat_padding_mask: bool = False, patch_embedding_type: Literal['linear', 'conv3d']='conv3d', apply_rope_before_kvcache: bool = True, cp_method: Literal['ring', 'ulysses']='ring')[source]¶
Bases: WanDiTNetworkConfig
Configuration for the Wan 2.2 TI2V 5B DiT network.
Mirrors the official Wan-AI/Wan2.2-TI2V-5B-Diffusers/transformer config: 24 heads * 128 head_dim = 3072 inner dim, 30 layers, ffn_dim 14336, and 48-channel latent in/out (the matching 16x VAE in vae.py outputs 48 channels). Unlike Wan 2.1 14B I2V, TI2V 5B has no CLIP cross-attention branch (cross_attn_enable_img=False): the first frame is conditioned via a clean VAE-latent seed plus a per-token t=0 timestep on the AR-step-0 first-frame tokens, not via CLIP image features.
- class WanI2VCtrlEncoderConfig(*, _target: type[I2VCtrlEncoder] = <factory>, encoder: WanVAEEncoderConfig = <factory>)
Bases: EncoderConfig
Config for the I2V control encoder.
- encoder: WanVAEEncoderConfig
Streaming Wan VAE encoder. Pin its checkpoint to the decoder’s so the encoded latent matches the network’s input distribution.
- class WanInferencePipeline(config: WanInferencePipelineConfig)
Bases: StreamInferencePipeline[WanVAECache, Wan21TransformerCache | Wan22TransformerCache, WanVAECache]
Wan 2.1 / 2.2 inference pipeline, T2V and I2V.
T2V and I2V share the same rollout loop; the difference is whether you pass an image to initialize_cache. The pipeline config’s encoder slot must agree (None for T2V, an I2V config for I2V).
Example:
    pipeline: WanInferencePipeline = ...

    # T2V: pass latent ``height`` and ``width``.
    cache = pipeline.initialize_cache(
        text=["A cat surfing."], height=60, width=104
    )
    _chunk = pipeline.generate(0, cache)
    pipeline.finalize(0, cache)

    # I2V: pass the first frame and let sizes derive from it.
    # ``first_frame`` is a tensor of shape [*batch_shape, 1, 3, H, W] in [-1, 1].
    _i2v_cache = pipeline.initialize_cache(
        text=["A cat surfing."], image=first_frame
    )
- initialize_cache(text: list[str], image: Tensor | None = None, *, height: int | None = None, width: int | None = None, release_oneshot_encoders: bool = True) → WanInferencePipelineCache
Initialize the per-rollout cache for a batch of prompts.
- Parameters:
text – One prompt per batch element. Length must match the transformer’s batch_shape.
image – First-frame pixels of shape [*batch_shape, 1, 3, H, W] in [-1, 1]. Required for I2V (self.encoder is set), forbidden for T2V. H and W must equal height * decoder.spatial_compression_ratio and width * decoder.spatial_compression_ratio, respectively.
height – Pre-patchify latent height (post-VAE). Optional for I2V, where it is derived from image when omitted; required for T2V.
width – Pre-patchify latent width (post-VAE). Same rules as height.
release_oneshot_encoders – Free the text and image encoders after the cache is initialized. Later calls reload them from self.config before encoding new prompts/images.
- Returns:
Cache to thread through generate / finalize.
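The relation between first-frame pixel size and latent size can be sketched as below. The function name latent_hw_from_image is hypothetical, and the ratio of 8 used in the example is the 8x-spatial Wan 2.1 VAE default mentioned elsewhere on this page; other VAEs would pass their own decoder.spatial_compression_ratio.

```python
def latent_hw_from_image(
    image_height: int,
    image_width: int,
    spatial_compression_ratio: int,
) -> tuple[int, int]:
    """Derive latent (height, width) from first-frame pixel dimensions.

    Inverts H = height * ratio and W = width * ratio, and rejects pixel
    sizes that are not exact multiples of the ratio.
    """
    if image_height % spatial_compression_ratio or image_width % spatial_compression_ratio:
        raise ValueError("image size must be a multiple of the spatial compression ratio")
    return (
        image_height // spatial_compression_ratio,
        image_width // spatial_compression_ratio,
    )


# A 480x832 first frame with an 8x-spatial VAE gives the 60x104 latent
# used in the T2V example above.
print(latent_hw_from_image(480, 832, 8))  # (60, 104)
```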
- release_oneshot_encoders() → None
Free the per-rollout text and first-frame image encoders.
Idempotent. A later initialize_cache() call can reload the encoders from self.config before encoding raw prompts/images.
- generate(autoregressive_index: int, cache: WanInferencePipelineCache) → Tensor
Generate one decoded video chunk.
- Parameters:
autoregressive_index – AR step index, starting at 0.
cache – Per-rollout cache from initialize_cache.
- Returns:
Decoded video of shape [*batch_shape, T, C, H, W] in [-1, 1].
- class WanInferencePipelineCache(*, transformer_cache: TransformerCacheT, encoder_cache: StreamingEncoderCacheT | None = None, decoder_cache: StreamingDecoderCacheT | None = None, final_state: DiffusionModel.FinalState[TransformerCacheT] | None = None, autoregressive_index: int | None = None, event_profiler: EventProfiler | None = None, image: Tensor | None = None)[source]¶
Bases: StreamInferencePipelineCache[WanVAECache, Wan21TransformerCache | Wan22TransformerCache, WanVAECache]
Per-rollout state for the Wan pipeline.
Adds the I2V first-frame pixels on top of the inherited caches. Pixel-to-latent encoding happens per AR step inside the encoder, not here.
- class WanInferencePipelineConfig(*, _target: type['WanInferencePipeline'] = <factory>, name: str, diffusion_model: DiffusionModelConfig, decoder: DecoderConfig | None = None, encoder: EncoderConfig | None = None, enable_sync_and_profile: bool = False, text_encoder: UMT5TextEncoderConfig | None = <factory>, image_encoder: CLIPImageEncoderConfig | None = None)[source]¶
Bases: StreamInferencePipelineConfig
Config for the Wan inference pipeline.
T2V vs I2V is selected by the inherited encoder slot: None for T2V, an I2V control-encoder config for I2V.
- class WanVAEDecoder(config: WanVAEDecoderConfig)
Bases: StreamingVideoDecoder[WanVAECache]
Wan VAE decoder.
Forward input is a latent [..., Tl, Cl, Hl, Wl]; output is a video tensor [..., T, C, H, W] in [-1, 1]. The cache is advanced in-place across AR decode steps; passing cache=None allocates a fresh single-shot cache.
- initialize_autoregressive_cache() → WanVAECache
Build a fresh per-rollout cache.
Override to return the decoder’s concrete cache type.
- forward(input: Tensor, autoregressive_index: int = 0, cache: WanVAECache | None = None) → Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- property temporal_compression_ratio: int
Pixel frames ÷ latent frames in steady state (AR ≥ 1).
AR 0 typically yields fewer pixel frames per latent frame because of causal first-frame padding; that asymmetry lives inside get_output_temporal_size() / get_input_temporal_size().
- get_output_temporal_size(autoregressive_index: int, input_temporal_size: int) → int
Causal: AR 0 first latent frame decodes to a single pixel frame.
- get_input_temporal_size(autoregressive_index: int, output_temporal_size: int) → int
Latent frame count needed to produce output_temporal_size pixels.
Inverse of get_output_temporal_size(). Implementations should assert output_temporal_size is achievable at this AR step (i.e. divisible by the right ratio after subtracting any causal padding).
- Parameters:
autoregressive_index – AR step index (0-based).
output_temporal_size – Desired number of pixel frames.
- Returns:
Number of latent frames needed at this step.
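One way the causal decoder size mapping could look is sketched below. The exact formulas are an assumption, not taken from the FlashDreams source: the sketch supposes a temporal compression ratio of 4, that at AR 0 the first latent frame decodes to one pixel frame and each remaining latent frame to ratio pixel frames, and that every later step decodes at the steady-state ratio. Both function names are hypothetical.

```python
def decoder_output_temporal_size(ar_index: int, latent_frames: int, ratio: int = 4) -> int:
    """Pixel frames produced from latent_frames at one AR step (assumed formula)."""
    assert latent_frames >= 1
    if ar_index == 0:
        # Causal first frame: one latent frame -> one pixel frame.
        return 1 + (latent_frames - 1) * ratio
    return latent_frames * ratio


def decoder_input_temporal_size(ar_index: int, pixel_frames: int, ratio: int = 4) -> int:
    """Inverse mapping: latent frames needed for pixel_frames at one AR step."""
    if ar_index == 0:
        # Subtract the un-grouped first frame before checking divisibility.
        assert pixel_frames >= 1 and (pixel_frames - 1) % ratio == 0
        return (pixel_frames - 1) // ratio + 1
    assert pixel_frames >= ratio and pixel_frames % ratio == 0
    return pixel_frames // ratio


print(decoder_output_temporal_size(0, 21))  # 81
print(decoder_input_temporal_size(0, 81))   # 21
print(decoder_output_temporal_size(1, 3))   # 12
```

Under these assumptions the two functions are exact inverses at every AR step, which is the property the docstring above asks implementations to assert.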
- class WanVAEDecoderConfig(*, _target: type[~flashdreams.recipes.wan.autoencoder.vae.WanVAEDecoder] = <factory>, checkpoint_path: str = 'https://huggingface.co/lightx2v/Autoencoders/resolve/main/Wan2.1_VAE.pth', dtype: ~torch.dtype = torch.bfloat16, use_cuda_graph: bool = True, use_compile: bool = False, base_dim: int = 96, decoder_base_dim: int | None = None, z_dim: int = 16, patch_size: int = 1, is_residual: bool = False, latent_mean: tuple[float, ...] = (-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921), latent_std: tuple[float, ...] = (2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.916), state_dict_transform: ~typing.Callable[[dict[str, ~torch.Tensor]], dict[str, ~torch.Tensor]] | None = None)[source]¶
Bases: DecoderConfig
Config for the Wan VAE decoder.
Defaults reproduce the Wan 2.1 / 14B 8x-spatial 16-channel streaming VAE. Override the architecture knobs (or use Wan22TI2V5BVAEDecoderConfig) to load the Wan 2.2 TI2V 5B decoder, which has asymmetric base_dim=160 / decoder_base_dim=256 and the residual up-stage with DupUp3D shortcut.
- use_compile: bool = False
torch.compile(mode="max-autotune-no-cudagraphs"). See WanVAEEncoderConfig.use_compile for the VRAM caveat.
- class WanVAEEncoder(config: WanVAEEncoderConfig)[source]¶
Bases:
StreamingVideoEncoder[WanVAECache]Wan VAE encoder.
Forward input is a video tensor
[..., T, C, H, W]in[-1, 1]; output is the latent[..., Tl, Cl, Hl, Wl]. The cache is advanced in-place across AR encode steps; passingcache=Noneallocates a fresh single-shot cache.- initialize_autoregressive_cache() WanVAECache[source]¶
Build a fresh per-rollout cache.
Override to return the encoder’s concrete cache type.
- forward(input: Tensor, autoregressive_index: int = 0, cache: WanVAECache | None = None) Tensor[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- property temporal_compression_ratio: int¶
Pixel frames ÷ latent frames in steady state (AR ≥ 1).
AR 0 typically takes one extra (un-grouped) pixel frame to produce its first latent frame because of causal first-frame padding; that asymmetry lives inside get_output_temporal_size() / get_input_temporal_size().
- get_output_temporal_size(autoregressive_index: int, input_temporal_size: int) int[source]¶
Causal: AR 0 needs an extra (un-grouped) pixel frame for the first latent.
- get_input_temporal_size(autoregressive_index: int, output_temporal_size: int) int[source]¶
Pixel frame count needed to produce output_temporal_size latents.
Inverse of get_output_temporal_size(). Implementations should assert output_temporal_size is achievable at this AR step (i.e. the corresponding pixel count comes out as a positive integer).
- Parameters:
autoregressive_index – AR step index (0-based).
output_temporal_size – Desired number of latent frames.
- Returns:
Number of pixel frames needed at this step.
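The causal asymmetry between AR 0 and later steps can be illustrated with a small standalone sketch. It is not the WanVAEEncoder implementation: the function names are hypothetical, and the default ratio of 4 is an assumption standing in for temporal_compression_ratio.

```python
def encoder_output_temporal_size(
    autoregressive_index: int, input_temporal_size: int, ratio: int = 4
) -> int:
    """Latent frames produced from `input_temporal_size` pixel frames (sketch)."""
    if autoregressive_index == 0:
        # AR 0: one un-grouped pixel frame yields the first latent frame,
        # the remaining pixel frames are grouped `ratio` at a time.
        grouped = input_temporal_size - 1
        assert grouped >= 0 and grouped % ratio == 0, "not achievable at AR 0"
        return 1 + grouped // ratio
    # AR >= 1: steady state, every latent frame covers `ratio` pixel frames.
    assert input_temporal_size > 0 and input_temporal_size % ratio == 0
    return input_temporal_size // ratio


def encoder_input_temporal_size(
    autoregressive_index: int, output_temporal_size: int, ratio: int = 4
) -> int:
    """Pixel frames needed to produce `output_temporal_size` latents (sketch)."""
    assert output_temporal_size >= 1
    if autoregressive_index == 0:
        return 1 + (output_temporal_size - 1) * ratio
    return output_temporal_size * ratio


# Under the assumed ratio of 4, three latent frames need 9 pixel frames
# at AR 0 and 12 pixel frames at any later step.
print(encoder_input_temporal_size(0, 3), encoder_input_temporal_size(1, 3))
```

The two helpers are inverses of each other at a fixed AR step, which is the property the docstrings above ask implementations to preserve.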
- class WanVAEEncoderConfig(*, _target: type[~flashdreams.recipes.wan.autoencoder.vae.WanVAEEncoder] = <factory>, checkpoint_path: str = 'https://huggingface.co/lightx2v/Autoencoders/resolve/main/Wan2.1_VAE.pth', dtype: ~torch.dtype = torch.bfloat16, use_cuda_graph: bool = True, use_compile: bool = False, base_dim: int = 96, z_dim: int = 16, patch_size: int = 1, is_residual: bool = False, latent_mean: tuple[float, ...] = (-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921), latent_std: tuple[float, ...] = (2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.916), state_dict_transform: ~typing.Callable[[dict[str, ~torch.Tensor]], dict[str, ~torch.Tensor]] | None = None)[source]¶
Bases: EncoderConfig
Config for the Wan VAE encoder.
Defaults reproduce the Wan 2.1 / 14B 8x-spatial 16-channel streaming VAE. Override base_dim / z_dim / patch_size / is_residual (and the latent mean/std) to load the Wan 2.2 TI2V 5B 16x-spatial 48-channel residual VAE; see Wan22TI2V5BVAEEncoderConfig for the pre-rolled set.
- use_compile: bool = False¶
torch.compile(mode="max-autotune-no-cudagraphs"). Off by default: Inductor autotune workspaces can add several GiB of transient VRAM per unique input shape, surfacing as ‘illegal memory access’ on smaller GPUs with the full-channel vae checkpoint.
- base_dim: int = 96¶
Encoder base channel count (WanVAE dim). 96 for Wan 2.1, 160 for Wan 2.2 TI2V 5B.
- is_residual: bool = False¶
Use ResidualDownBlock (Wan 2.2) instead of the legacy ResidualBlock + AttentionBlock down-stage (Wan 2.1).
- latent_mean: tuple[float, ...] = (-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921)¶
Per-channel latent mean used for normalisation; must have exactly z_dim entries.
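As a rough illustration of the length constraint, the sketch below checks the default statistics against z_dim and applies them to one latent vector (one value per channel). The formula (z - mean) / std is an assumption about how the normalisation is applied, and the helper names are hypothetical; only the numeric defaults are taken from the signature above.

```python
Z_DIM = 16
LATENT_MEAN = (-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517,
               1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497,
               0.2503, -0.2921)
LATENT_STD = (2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
              3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.916)


def check_latent_stats(mean, std, z_dim):
    """Raise if the per-channel statistics do not cover every latent channel."""
    if len(mean) != z_dim or len(std) != z_dim:
        raise ValueError(f"expected {z_dim} entries, got {len(mean)}/{len(std)}")


def normalize(latent):
    """Assumed normalisation: subtract the mean, divide by the std, per channel."""
    return [(v - m) / s for v, m, s in zip(latent, LATENT_MEAN, LATENT_STD)]


def denormalize(latent):
    """Inverse of `normalize` under the same assumption."""
    return [v * s + m for v, m, s in zip(latent, LATENT_MEAN, LATENT_STD)]


check_latent_stats(LATENT_MEAN, LATENT_STD, Z_DIM)
```

When z_dim is overridden (for example for the 48-channel Wan 2.2 TI2V 5B VAE), both tuples have to be replaced with statistics of the matching length.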
- wan22_ti2v_5b_vae_state_dict_transform(state_dict: Dict[str, Tensor]) Dict[str, Tensor][source]¶
Remap a diffusers AutoencoderKLWan state-dict to WanVAE keys.
Applied automatically when Wan22TI2V5BVAEEncoderConfig / Wan22TI2V5BVAEDecoderConfig load the upstream Wan-AI/Wan2.2-TI2V-5B-Diffusers/vae/diffusion_pytorch_model.safetensors checkpoint. The mapping is purely structural: no tensors are copied or reshaped.
Note
Patterns are applied in iteration order via flashdreams.core.checkpoint.remap.remap_checkpoint_keys(). Any key without a matching pattern passes through unchanged, which surfaces as a load_state_dict unexpected_keys warning so missing remap entries are easy to spot.
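The pass-through behaviour described in the note can be sketched with plain regular expressions. This is not the real remap_checkpoint_keys(): the helper, its first-match-wins policy, and the key names are all hypothetical, chosen only to show ordered patterns and unmatched keys.

```python
import re


def remap_keys(state_dict, patterns):
    """Rename keys with the first matching (regex, replacement) pair.

    Values are reused as-is, so the remap stays purely structural.
    Returns the remapped dict plus the keys that matched no pattern.
    """
    remapped, unmatched = {}, []
    for key, value in state_dict.items():
        for pattern, replacement in patterns:
            new_key, count = re.subn(pattern, replacement, key)
            if count:
                remapped[new_key] = value
                break
        else:
            # No pattern matched: pass the key through unchanged.
            remapped[key] = value
            unmatched.append(key)
    return remapped, unmatched


# Hypothetical keys; the more specific pattern is listed first.
weights = {"old.block.0.weight": "w0", "old.head.bias": "b", "extra.scale": "s"}
patterns = [
    (r"^old\.block\.(\d+)\.", r"new.blocks.\1."),
    (r"^old\.", "new."),
]
remapped, unmatched = remap_keys(weights, patterns)
print(sorted(remapped), unmatched)
```

Returning the unmatched keys separately mirrors the intent of the note: a missing remap entry shows up immediately instead of being silently dropped.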
Unified Wan inference pipeline (Wan 2.1 / Wan 2.2, T2V and I2V).
- class WanInferencePipelineCache(*, transformer_cache: TransformerCacheT, encoder_cache: StreamingEncoderCacheT | None = None, decoder_cache: StreamingDecoderCacheT | None = None, final_state: DiffusionModel.FinalState[TransformerCacheT] | None = None, autoregressive_index: int | None = None, event_profiler: EventProfiler | None = None, image: Tensor | None = None)[source]¶
Bases: StreamInferencePipelineCache[WanVAECache, Wan21TransformerCache | Wan22TransformerCache, WanVAECache]
Per-rollout state for the Wan pipeline.
Adds the I2V first-frame pixels on top of the inherited caches. Pixel-to-latent encoding happens per AR step inside the encoder, not here.
- class WanInferencePipelineConfig(*, _target: type['WanInferencePipeline'] = <factory>, name: str, diffusion_model: DiffusionModelConfig, decoder: DecoderConfig | None = None, encoder: EncoderConfig | None = None, enable_sync_and_profile: bool = False, text_encoder: UMT5TextEncoderConfig | None = <factory>, image_encoder: CLIPImageEncoderConfig | None = None)[source]¶
Bases: StreamInferencePipelineConfig
Config for the Wan inference pipeline.
T2V vs I2V is selected by the inherited encoder slot: None for T2V, an I2V control-encoder config for I2V.
- class WanInferencePipeline(config: WanInferencePipelineConfig)[source]¶
Bases: StreamInferencePipeline[WanVAECache, Wan21TransformerCache | Wan22TransformerCache, WanVAECache]
Wan 2.1 / 2.2 inference pipeline, T2V and I2V.
T2V and I2V share the same rollout loop; the difference is whether you pass an image to initialize_cache. The pipeline config’s encoder slot must agree (None for T2V, an I2V config for I2V).
Example:

    pipeline: WanInferencePipeline = ...

    # T2V: pass latent ``height`` and ``width``.
    cache = pipeline.initialize_cache(
        text=["A cat surfing."], height=60, width=104
    )
    _chunk = pipeline.generate(0, cache)
    pipeline.finalize(0, cache)

    # I2V: pass the first frame and let sizes derive from it.
    _i2v_cache = pipeline.initialize_cache(
        text=["A cat surfing."], image=first_frame
    )

- initialize_cache(text: list[str], image: Tensor | None = None, *, height: int | None = None, width: int | None = None, release_oneshot_encoders: bool = True) WanInferencePipelineCache[source]¶
Initialize the per-rollout cache for a batch of prompts.
- Parameters:
text – One prompt per batch element. Length must match the transformer’s batch_shape.
image – First-frame pixels of shape [*batch_shape, 1, 3, H, W] in [-1, 1]. Required for I2V (self.encoder is set), forbidden for T2V. H and W must equal height * decoder.spatial_compression_ratio and width * decoder.spatial_compression_ratio, respectively.
height – Pre-patchify latent height (post-VAE). Optional for I2V (derived from image when omitted); required for T2V.
width – Pre-patchify latent width (post-VAE). Same rules as height.
release_oneshot_encoders – Free the text and image encoders after the cache is initialized. Later calls reload them from self.config before encoding new prompts/images.
- Returns:
Cache to thread through generate / finalize.
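The argument rules above (image required for I2V and forbidden for T2V, latent sizes tied to the pixel size by the decoder's spatial compression ratio) can be sketched as a standalone check. The helper name and its is_i2v flag are hypothetical, and the default ratio of 8 assumes the Wan 2.1 8x-spatial VAE; the actual validation inside initialize_cache may differ.

```python
from __future__ import annotations


def resolve_latent_size(
    image_hw: tuple[int, int] | None,
    height: int | None,
    width: int | None,
    *,
    is_i2v: bool,
    spatial_compression_ratio: int = 8,
) -> tuple[int, int]:
    """Return the latent (height, width) implied by the arguments (sketch)."""
    if not is_i2v:
        if image_hw is not None:
            raise ValueError("image is forbidden for T2V")
        if height is None or width is None:
            raise ValueError("height and width are required for T2V")
        return height, width

    if image_hw is None:
        raise ValueError("image is required for I2V")
    pixel_h, pixel_w = image_hw
    r = spatial_compression_ratio
    if pixel_h % r or pixel_w % r:
        raise ValueError("image size must be a multiple of the compression ratio")
    derived = (pixel_h // r, pixel_w // r)
    # Explicit sizes are optional for I2V but must agree with the image.
    resolved = (
        derived[0] if height is None else height,
        derived[1] if width is None else width,
    )
    if resolved != derived:
        raise ValueError(f"latent size {resolved} does not match image {derived}")
    return resolved


# A 480x832 first frame maps to the 60x104 latent grid used in the example above.
print(resolve_latent_size((480, 832), None, None, is_i2v=True))
```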
- release_oneshot_encoders() None[source]¶
Free the per-rollout text and first-frame image encoders.
Idempotent. A later initialize_cache() call can reload the encoders from self.config before encoding raw prompts/images.
- generate(autoregressive_index: int, cache: WanInferencePipelineCache) Tensor[source]¶
Generate one decoded video chunk.
- Parameters:
autoregressive_index – AR step index, starting at 0.
cache – Per-rollout cache from initialize_cache.
- Returns:
Decoded video of shape [*batch_shape, T, C, H, W] in [-1, 1].
TAEHV¶
TAEHV video decoder.
- AVAILABLE_TAEHV_CHECKPOINT_PATHS = {'lighttae': 'https://huggingface.co/lightx2v/Autoencoders/resolve/main/lighttaew2_1.pth'}¶
Checkpoint paths for the TAEHV decoder.
- lighttae_state_dict_transform(sd: Mapping[str, Tensor]) dict[str, Tensor]¶
Per-checkpoint remap for the lighttae weights. Rewrites the flat decoder.<i>.* keys to the current decoder.blocks.<i>.* layout and clips the stride=2 TGrow weights at idx 7 down to the stride=1 slice the live model expects.
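The key-rewrite half of this transform can be pictured with a one-line regular expression. This is only a sketch of the renaming step: the helper name and the key suffixes are hypothetical, and the TGrow weight clipping is not reproduced here.

```python
import re

# Matches the flat layout `decoder.<i>.` at the start of a key.
_FLAT_DECODER_KEY = re.compile(r"^decoder\.(\d+)\.")


def rename_flat_decoder_keys(state_dict):
    """Rewrite `decoder.<i>.*` keys to `decoder.blocks.<i>.*`, keeping values."""
    return {
        _FLAT_DECODER_KEY.sub(r"decoder.blocks.\1.", key): value
        for key, value in state_dict.items()
    }


renamed = rename_flat_decoder_keys(
    {"decoder.7.conv.weight": "w7", "decoder.blocks.3.bias": "b3"}
)
print(sorted(renamed))
```

Keys that already use the decoder.blocks.<i> layout do not match the pattern and are left untouched, so applying the rename twice gives the same result.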
- class TeahvVAEDecoderConfig(*, _target: type[~flashdreams.recipes.taehv.TeahvVAEDecoder] = <factory>, checkpoint_path: str = 'https://huggingface.co/lightx2v/Autoencoders/resolve/main/lighttaew2_1.pth', state_dict_transform: ~collections.abc.Callable[[~collections.abc.Mapping[str, ~torch.Tensor]], dict[str, ~torch.Tensor]] | None = <function compose.<locals>.composed>, dtype: ~torch.dtype = torch.bfloat16, use_cuda_graph: bool = True, use_compile: bool = True)[source]¶
Bases: DecoderConfig
Config for the TAEHV decoder.
- checkpoint_path: str = 'https://huggingface.co/lightx2v/Autoencoders/resolve/main/lighttaew2_1.pth'¶
Path to a pretrained TAEHV checkpoint. Defaults to the lighttae weights.
- state_dict_transform() dict[str, Tensor]¶
Pre-load state-dict remap. Defaults to lighttae_state_dict_transform; None falls through to the bare TAEHV default (see load_from_checkpoint()).
- class TeahvVAEDecoder(config: TeahvVAEDecoderConfig)[source]¶
Bases: StreamingVideoDecoder[TAEHVCache]
TAEHV (Tiny AutoEncoder for Hunyuan Video) decoder.
Forward input is a latent [..., Tl, Cl, Hl, Wl]; output is a video tensor [..., T, C, H, W] in [-1, 1].
Set torch.backends.cudnn.benchmark = True at process start for ~5% extra on the eager seed/tail chunks.
- forward(input: Tensor, autoregressive_index: int = 0, cache: TAEHVCache | None = None) Tensor[source]¶
Decode a latent chunk to a video tensor in [-1, 1].
- Parameters:
input – Latent of shape [..., Tl, Cl, Hl, Wl].
autoregressive_index – Unused by TAEHV; kept for the StreamingDecoder interface.
cache – Streaming decoder cache; created on the fly when None.
- Returns:
Video tensor of shape [..., T, C, H, W] in [-1, 1].
- get_output_temporal_size(autoregressive_index: int, input_temporal_size: int) int[source]¶
Return pixel frame count from input_temporal_size latent frames.
AR 0 applies causal padding: the first latent frame yields one pixel frame, remaining frames yield temporal_compression_ratio each. AR >= 1 is a plain multiply.
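To see how this rule adds up over a rollout, the sketch below counts decoded pixel frames per AR step and in total. The function names are hypothetical and the default ratio of 4 is an assumed value for temporal_compression_ratio, not one stated on this page.

```python
def decoder_output_temporal_size(
    autoregressive_index: int, input_temporal_size: int, ratio: int = 4
) -> int:
    """Pixel frames decoded from `input_temporal_size` latent frames (sketch)."""
    assert input_temporal_size >= 1
    if autoregressive_index == 0:
        # First latent frame yields one pixel frame, the rest yield `ratio` each.
        return 1 + (input_temporal_size - 1) * ratio
    # AR >= 1: plain multiply.
    return input_temporal_size * ratio


def total_pixel_frames(num_steps: int, latents_per_step: int, ratio: int = 4) -> int:
    """Total pixel frames after `num_steps` AR steps of equal latent length."""
    return sum(
        decoder_output_temporal_size(step, latents_per_step, ratio)
        for step in range(num_steps)
    )


# Under the assumed ratio, 3 latent frames per step decode to 9 pixel frames
# at AR 0 and 12 at each later step.
print(decoder_output_temporal_size(0, 3), decoder_output_temporal_size(1, 3))
```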