Source code for flashdreams.recipes.wan.transformer.wan22

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wan 2.2 MoE DiT (two Wan 2.1 networks + timestep-based dispatch)."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Literal

import torch
from torch import Tensor

from flashdreams.infra.diffusion.transformer import (
    Transformer,
    TransformerAutoregressiveCache,
    TransformerConfig,
)
from flashdreams.recipes.wan.transformer.impl.network import WanDiTNetwork14BConfig
from flashdreams.recipes.wan.transformer.wan21 import (
    Wan21Transformer,
    Wan21TransformerCache,
    Wan21TransformerConfig,
)


@dataclass(kw_only=True)
class Wan22TransformerCache(TransformerAutoregressiveCache):
    """Autoregressive cache for one Wan 2.2 rollout.

    Holds a separate Wan 2.1 cache for each branch, since the high- and
    low-noise stacks do not share a residual stream. ``start`` and
    ``finalize`` forward the same index to both caches (high-noise first)
    so they advance together.
    """

    transformer_high_noise: Wan21TransformerCache
    """Cache owned by the high-noise branch."""

    transformer_low_noise: Wan21TransformerCache
    """Cache owned by the low-noise branch."""

    def start(self, autoregressive_index: int) -> None:
        """Call ``start`` on the high-noise cache, then the low-noise cache."""
        for branch_cache in (self.transformer_high_noise, self.transformer_low_noise):
            branch_cache.start(autoregressive_index)

    def finalize(self, autoregressive_index: int) -> None:
        """Call ``finalize`` on the high-noise cache, then the low-noise cache."""
        for branch_cache in (self.transformer_high_noise, self.transformer_low_noise):
            branch_cache.finalize(autoregressive_index)


## Transformer


@dataclass(kw_only=True)
class Wan22TransformerConfig(TransformerConfig):
    """Config for the Wan 2.2 dual-network transformer.

    Wan 2.2 dispatches to one of two Wan 2.1 networks based on whether the
    timestep is above or below ``boundary_ratio * num_train_timesteps``. Both
    branches must agree on patch_size / in_dim / dim / batch_shape / len_t /
    guidance_scale (enforced in ``__post_init__``, which raises
    ``AssertionError`` on a mismatch).

    The per-rollout spatial layout (``height``, ``width``) is supplied to
    :meth:`Wan22Transformer.initialize_autoregressive_cache` and forwarded to
    both branches. The CP size is auto-detected from
    ``torch.distributed.get_world_size()``, same as Wan 2.1.

    Wan 2.2 has no CFG or I2V support here.
    """

    _target: type["Wan22Transformer"] = field(default_factory=lambda: Wan22Transformer)

    transformer_high_noise: Wan21TransformerConfig = field(
        default_factory=Wan21TransformerConfig
    )
    """Sub-config for the high-noise branch (timestep > boundary)."""

    transformer_low_noise: Wan21TransformerConfig = field(
        default_factory=Wan21TransformerConfig
    )
    """Sub-config for the low-noise branch (timestep <= boundary)."""

    boundary_ratio: float = 0.875
    """Fraction of ``num_train_timesteps`` separating the two branches.
    Default ``0.875`` matches upstream Wan 2.2."""

    num_train_timesteps: int = 1000
    """Multiplied by ``boundary_ratio`` to obtain ``boundary_timestep``."""

    @staticmethod
    def _require_shared(
        scope: str, label: str, high_value: Any, low_value: Any
    ) -> None:
        """Raise ``AssertionError`` unless the two branch values compare equal.

        Args:
            scope: Noun used in the message (``"networks"`` / ``"sub-configs"``).
            label: Name of the field being compared, as shown in the message.
            high_value: Value read from the high-noise branch.
            low_value: Value read from the low-noise branch.

        Raises:
            AssertionError: If ``high_value == low_value`` is falsy.
        """
        if high_value == low_value:
            return
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped when Python runs with ``-O``.
        raise AssertionError(
            f"high/low noise {scope} must share {label}; got "
            f"{high_value} vs {low_value}"
        )

    def __post_init__(self) -> None:
        """Validate that both branch sub-configs agree on the shared fields.

        Raises:
            AssertionError: On the first field whose high- and low-noise
                values differ.
        """
        hi, lo = self.transformer_high_noise, self.transformer_low_noise

        # Per-branch network must agree on the token layout.
        for attr, label in (
            ("patch_size", "patch_size"),
            ("in_dim", "in_dim"),
            ("dim", "dim (head sizing)"),
        ):
            self._require_shared(
                "networks", label, getattr(hi.network, attr), getattr(lo.network, attr)
            )

        # guidance_scale is part of this list because the unified pipeline
        # reads it off this config to decide whether to build an uncond text
        # branch — the two sub-configs can't disagree.
        for key in (
            "batch_shape",
            "len_t",
            "guidance_scale",
        ):
            self._require_shared("sub-configs", key, getattr(hi, key), getattr(lo, key))

    @property
    def boundary_timestep(self) -> float:
        """Absolute timestep boundary (``boundary_ratio * num_train_timesteps``)."""
        return self.boundary_ratio * self.num_train_timesteps

    ## Shared-field aliases (mirror the fields both branches must agree on)

    @property
    def batch_shape(self) -> tuple[int, ...]:
        """``batch_shape`` of the high-noise sub-config."""
        return self.transformer_high_noise.batch_shape

    @property
    def len_t(self) -> int:
        """``len_t`` of the high-noise sub-config."""
        return self.transformer_high_noise.len_t

    @property
    def dtype(self) -> torch.dtype:
        """``dtype`` of the high-noise sub-config.

        NOTE(review): unlike the other aliases, ``dtype`` is not compared
        against the low-noise sub-config in ``__post_init__`` — confirm
        whether the branches are allowed to differ.
        """
        return self.transformer_high_noise.dtype

    @property
    def guidance_scale(self) -> float:
        """``guidance_scale`` of the high-noise sub-config."""
        return self.transformer_high_noise.guidance_scale
# Name of the Wan 2.1 branch picked for a timestep; this is the return type
# of ``Wan22Transformer._choose_network``.
_NetworkChoice = Literal["high_noise", "low_noise"]
class Wan22Transformer(Transformer[Wan22TransformerCache]):
    """Wan 2.2 dual-network DiT.

    ``predict_flow`` dispatches to the branch selected by the timestep;
    ``finalize_kv_cache`` re-runs both branches once at the context noise so
    neither KV cache lags behind.
    """

    transformer_high_noise: Wan21Transformer
    transformer_low_noise: Wan21Transformer

    def __init__(self, config: Wan22TransformerConfig) -> None:
        """Build one Wan 2.1 transformer per branch from the sub-configs.

        Args:
            config: Dual-network config; its ``transformer_high_noise`` and
                ``transformer_low_noise`` sub-configs each construct one
                :class:`Wan21Transformer`.
        """
        super().__init__(config)
        self.config: Wan22TransformerConfig = config
        self.transformer_high_noise = Wan21Transformer(config.transformer_high_noise)
        self.transformer_low_noise = Wan21Transformer(config.transformer_low_noise)

    @property
    def latent_shape(self) -> tuple[int, ...]:
        """Latent shape reported by the high-noise branch.

        The config's cross-branch checks are expected to keep the low-noise
        branch's shape identical.
        """
        return self.transformer_high_noise.latent_shape

    @torch.no_grad()
    def initialize_autoregressive_cache(
        self,
        *,
        height: int,
        width: int,
        text_embeddings: Tensor,
        image_embeddings: Tensor | None = None,
        **_unused: Any,
    ) -> Wan22TransformerCache:
        """Build a seeded transformer cache for a new rollout.

        Both branches receive the same text/image conditioning and the same
        per-rollout spatial layout, matching upstream Wan 2.2. CFG is not
        supported here.

        Args:
            height: Spatial height forwarded to both branches.
            width: Spatial width forwarded to both branches.
            text_embeddings: Text conditioning forwarded to both branches.
            image_embeddings: Optional image conditioning forwarded to both
                branches.
            **_unused: Extra keyword arguments; accepted and ignored.

        Returns:
            A cache wrapping one freshly initialized Wan 2.1 cache per branch.
        """
        shared_kwargs = {
            "height": height,
            "width": width,
            "text_embeddings": text_embeddings,
            "image_embeddings": image_embeddings,
        }
        # High-noise cache is built first, then low-noise.
        high_cache = self.transformer_high_noise.initialize_autoregressive_cache(
            **shared_kwargs
        )
        low_cache = self.transformer_low_noise.initialize_autoregressive_cache(
            **shared_kwargs
        )
        return Wan22TransformerCache(
            transformer_high_noise=high_cache,
            transformer_low_noise=low_cache,
        )

    def _choose_network(self, timestep: Tensor) -> _NetworkChoice:
        """High-noise branch above the boundary, low-noise at or below.

        Only the first element of ``timestep`` is inspected.
        NOTE(review): this presumes every entry carries the same timestep —
        confirm against callers.

        Raises:
            ValueError: If ``timestep`` has no elements, since there is then
                no value to compare against the boundary.
        """
        if timestep.numel() == 0:
            # Comparing an empty tensor and using it as a condition fails with
            # an opaque "Boolean value of Tensor with no values is ambiguous"
            # error; fail early with an actionable message instead.
            raise ValueError(
                "timestep must contain at least one element to select a "
                "Wan 2.2 branch"
            )
        first = timestep.flatten()[0]
        return "high_noise" if first > self.config.boundary_timestep else "low_noise"

    def predict_flow(
        self,
        noisy_latent: Tensor,
        timestep: Tensor,
        cache: Wan22TransformerCache,
        input: Any = None,
    ) -> Tensor:
        """Run the branch selected by ``timestep`` and return its flow output.

        Args:
            noisy_latent: Latent passed through to the selected branch.
            timestep: Timestep tensor; also decides which branch runs.
            cache: Dual cache; only the selected branch's cache is passed on.
            input: Accepted and ignored.

        Returns:
            The selected branch's ``predict_flow`` result.
        """
        # Wan 2.2 (FastVideo T2V) has no per-AR-step encoder input; accept
        # and ignore ``input`` to satisfy the DiffusionModel.generate contract.
        if self._choose_network(timestep) == "high_noise":
            branch = self.transformer_high_noise
            branch_cache = cache.transformer_high_noise
        else:
            branch = self.transformer_low_noise
            branch_cache = cache.transformer_low_noise
        return branch.predict_flow(noisy_latent, timestep, branch_cache)

    def finalize_kv_cache(
        self,
        noisy_latent: Tensor,
        timestep: Tensor,
        cache: Wan22TransformerCache,
        input: Any = None,
    ) -> None:
        """Refresh both networks' KV caches at the context-noise step.

        Each Wan 2.2 denoising step only touches one branch, so the other
        lags by the end of the loop. Re-running both at ``context_noise``
        keeps them in lock-step; flow outputs are discarded.

        Args:
            noisy_latent: Latent passed to both branches.
            timestep: Timestep passed to both branches (no dispatch here).
            cache: Dual cache; each branch receives its own sub-cache.
            input: Accepted and ignored.
        """
        branches = (
            (self.transformer_high_noise, cache.transformer_high_noise),
            (self.transformer_low_noise, cache.transformer_low_noise),
        )
        # High-noise branch runs first, then low-noise.
        for branch, branch_cache in branches:
            branch.predict_flow(noisy_latent, timestep, branch_cache)

    def patchify_and_maybe_split_cp(self, x: Tensor) -> Tensor:
        """Delegate to the high-noise branch's ``patchify_and_maybe_split_cp``."""
        return self.transformer_high_noise.patchify_and_maybe_split_cp(x)

    def unpatchify_and_maybe_gather_cp(self, x: Tensor) -> Tensor:
        """Delegate to the high-noise branch's ``unpatchify_and_maybe_gather_cp``."""
        return self.transformer_high_noise.unpatchify_and_maybe_gather_cp(x)