Source code for flashdreams.recipes.wan.transformer.wan21

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wan 2.1 DiT."""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, overload

import torch
from torch import Tensor

from flashdreams.core.attention.rope import (
    KVCacheRelativeRotaryPositionEmbedding3D,
    RotaryPositionEmbedding3D,
)
from flashdreams.core.checkpoint.load import load_checkpoint
from flashdreams.infra.compile import compile_module
from flashdreams.infra.cuda_graph import CUDAGraphWrapper
from flashdreams.infra.diffusion.transformer import (
    Transformer,
    TransformerAutoregressiveCache,
    TransformerConfig,
)
from flashdreams.recipes.wan.autoencoder.i2v import I2VCtrl
from flashdreams.recipes.wan.transformer.impl.network import (
    WanDiTNetwork,
    WanDiTNetwork1pt3BConfig,
    WanDiTNetworkCache,
    WanDiTNetworkConfig,
)

## Autoregressive cache (per-rollout, mutated across AR steps)


@dataclass(kw_only=True)
class Wan21TransformerCache(TransformerAutoregressiveCache):
    """Per-rollout AR cache for the Wan 2.1 transformer.

    Holds an always-present conditional network cache and an optional
    unconditional one for classifier-free guidance (``None`` disables CFG).
    Both branches own independent per-block self-attention KV buffers since
    the residual stream diverges after the first cross-attention layer.
    """

    network_cache: WanDiTNetworkCache
    """Conditional per-block KV / cross-attention caches."""

    network_cache_uncond: WanDiTNetworkCache | None = None
    """Unconditional caches; ``None`` disables CFG."""

    rope_adapter: RotaryPositionEmbedding3D | KVCacheRelativeRotaryPositionEmbedding3D
    """3D RoPE adapter for self-attention position frequencies."""

    rope_freqs: Tensor | None = None
    """Self-attention RoPE frequencies for the current AR step.
    Standard mode stores K after applying current-chunk RoPE.
    KV-cache-relative mode stores unrotated K and applies cache-slot RoPE on cache read.
    Shape ``[L, 1, 1, head_dim]`` after CP in standard mode. Recomputed once per
    AR step in :meth:`start` and reused across cond and uncond branches
    (and across all scheduler steps within the AR step)."""

    autoregressive_index: int = -1
    """Current AR step index, set by ``start``."""

    def start(self, autoregressive_index: int) -> None:
        # Hoist per-block KV pre-update out of the (graph-captured) network
        # forward; predict_flow runs with eager_mode=False so the network
        # itself does not call before_update. Same for shift_t: tying the
        # AR index into the captured graph as a Python int would re-trigger
        # cat/repeat on every cond/uncond pass.
        self.rope_freqs = self.rope_adapter.shift_t(autoregressive_index)

        self.autoregressive_index = autoregressive_index
        self.network_cache.before_update(autoregressive_index)
        if self.network_cache_uncond is not None:
            self.network_cache_uncond.before_update(autoregressive_index)

    def finalize(self, autoregressive_index: int) -> None:
        self.network_cache.after_update(autoregressive_index)
        if self.network_cache_uncond is not None:
            self.network_cache_uncond.after_update(autoregressive_index)


## Transformer


[docs] @dataclass(kw_only=True) class Wan21TransformerConfig(TransformerConfig): """Config for the Wan 2.1 transformer. Bakes in the temporal layout (``len_t``, ``window_size_t``, optional ``sink_size_t``) and the CFG / compile knobs. Per-rollout spatial layout (``height``, ``width``) is supplied to :meth:`Wan21Transformer.initialize_autoregressive_cache` so one instance can serve multiple resolutions. Wan flattens ``T*H*W`` into one token axis and shards it across the THW CP group; the CP size is auto-detected from ``torch.distributed.get_world_size()`` at construction time, so the launcher (``torchrun --nproc_per_node=N``) is the single source of truth. The two I2V flags are independent and composable: - ``stamp_image_latent``: overwrite the noisy latent with the clean image latent at masked positions every denoising step, and re-stamp the predicted ``x0`` the same way. ``network.in_dim`` unchanged. (flashdreams mask-inject integration; used by the out-of-tree ``causal_forcing`` plugin.) - ``concat_image_mask_to_latent``: append the 4-channel mask and 16-channel image latent along the channel dim. Builders that set this flag must also set ``network.in_dim = 16 + 4 + 16`` to match the official Wan 2.1 14B I2V layout. With both enabled, the stamp runs first and the result is then concatenated with the mask + image latent. """ _target: type["Wan21Transformer"] = field(default_factory=lambda: Wan21Transformer) network: WanDiTNetworkConfig = field(default_factory=WanDiTNetwork1pt3BConfig) dtype: torch.dtype = torch.bfloat16 checkpoint_path: str | None = None state_dict_transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None """Pre-load state-dict remap (e.g. Self-Forcing's ``generator_ema.model.…`` layout).""" batch_shape: tuple[int, ...] = (1,) """Batch dims of the latent (excluding the L, D dims).""" len_t: int = 21 """Latent frames per AR chunk (post-VAE).""" guidance_scale: float = 1.0 """CFG scale ``s``: ``flow = uncond + s * (cond - uncond)``. ``1.0`` disables CFG; ``> 1.0`` requires negative-text embeddings at cache build time.""" window_size_t: int = 21 """Self-attention sliding-window size (pre-patchify T frames).""" sink_size_t: int = 0 """Prefix sink size (pre-patchify T frames) for self-attention KV cache.""" h_extrapolation_ratio: float = 1.0 w_extrapolation_ratio: float = 1.0 compile_network: bool = True """``torch.compile`` the network on init.""" use_cuda_graph: bool = True """Wrap the network in ``CUDAGraphWrapper`` for steady-state replay. Caller must keep non-staged inputs at stable storage addresses across calls. ``predict_flow`` dispatches to ``wrapper.drain`` while the KV cache is still filling and to ``wrapper`` once it reaches steady state.""" cuda_graph_warmup_iters: int = 2 """Eager calls before capture (>= 2 to drain Inductor autotune).""" stamp_image_latent: bool = False """See class docstring (mask-inject I2V integration).""" concat_image_mask_to_latent: bool = False """See class docstring (channel-concat I2V layout).""" ti2v_first_frame_per_token_timestep: bool = False """Wan 2.2 TI2V 5B first-frame conditioning. When ``True`` and an :class:`I2VCtrl` input is provided at AR step 0, ``predict_flow`` rewrites the scheduler's scalar timestep into a per-token tensor: ``t = first_frame_timestep_value`` at positions marked by the I2V mask (i.e. the first-frame latent), and the scheduler's ``t`` elsewhere. AR steps ``>= 1`` continue to use the scalar timestep, which keeps the CUDA-graph-captured replay branch on a single stable input shape. Composes with ``stamp_image_latent``: together they implement the upstream Wan 2.2 5B "VAE-seeded first-frame + per-token ``t=0``" TI2V recipe -- the latent is stamped clean every denoising step while the network sees ``t=0`` for those tokens. The standard mask-inject I2V recipe leaves this flag off and relies on the classifier-free stamp alone.""" first_frame_timestep_value: float = 0.0 """Per-token timestep assigned to first-frame conditioning tokens when :attr:`ti2v_first_frame_per_token_timestep` is ``True``. Defaults to ``0.0`` (Wan 2.2 TI2V 5B's base recipe — treats the first frame as fully clean by AdaLN). HY-WorldPlay's distilled WAN-5B raises it to ``14.0`` (vendor's ``stabilization_level - 1``) so the AdaLN table sees a small nonzero sigma at the first frame. Unused when :attr:`ti2v_first_frame_per_token_timestep` is ``False``. """
[docs] class Wan21Transformer(Transformer[Wan21TransformerCache]): """Wan 2.1 DiT adapted to the infra Transformer interface.""" config: Wan21TransformerConfig network: WanDiTNetwork def __init__(self, config: Wan21TransformerConfig) -> None: super().__init__(config) self.config = config # Auto-detect CP size from the launcher (``torchrun # --nproc_per_node=N``) — the single source of truth. Wan flattens # T*H*W into one token axis and shards it across the WORLD group. if torch.distributed.is_initialized(): self._cp_size = torch.distributed.get_world_size() self._cp_group = ( torch.distributed.group.WORLD if self._cp_size > 1 else None ) else: self._cp_size = 1 self._cp_group = None # Pre-patchify temporal divisibility check; per-rollout # (height, width) is populated by initialize_autoregressive_cache. kt, _, _ = config.network.patch_size assert config.len_t % kt == 0, ( f"len_t ({config.len_t}) must be divisible by patch_size[0] ({kt})." ) assert config.window_size_t % kt == 0, ( f"window_size_t ({config.window_size_t}) must be divisible by " f"patch_size[0] ({kt})." ) assert config.sink_size_t % kt == 0, ( f"sink_size_t ({config.sink_size_t}) must be divisible by " f"patch_size[0] ({kt})" ) len_t = config.len_t // kt window_size_t = config.window_size_t // kt sink_size_t = config.sink_size_t // kt assert (sink_size_t + window_size_t) % len_t == 0, ( f"sink_size_t + window_size_t ({sink_size_t + window_size_t}) must be " f"divisible by post-patch len_t ({len_t}) so the BlockKVCache can " f"fit a whole number of AR chunks." ) self._output_height: int | None = None self._output_width: int | None = None self.network = config.network.setup() self.network = self.network.to(dtype=config.dtype) self.network.eval() self.network.set_context_parallel_group(cp_group=self._cp_group) if config.checkpoint_path is not None: state_dict = load_checkpoint(config.checkpoint_path) if config.state_dict_transform is not None: state_dict = config.state_dict_transform(state_dict) self.network.load_state_dict(state_dict) self.network.update_parameters_after_loading_checkpoint() if config.compile_network: self.network = compile_module(self.network) # Per-rollout dispatch when use_cuda_graph=True: # filling phase -> wrapper.drain (eager, drains Inductor autotune); # steady-state -> wrapper.__call__ (warmup + capture + replay). # Cond and CFG-uncond branches each get their own wrapper since each # mutates an independent rolling KV cache. The dispatch threshold # matches the KV cache's filling -> steady transition so the captured # region only sees steady-state paths. self._use_cuda_graph = config.use_cuda_graph chunks_total = sink_size_t + window_size_t self._cuda_graph_capture_ar_idx: int = chunks_total // len_t self._network_call: CUDAGraphWrapper | WanDiTNetwork = ( CUDAGraphWrapper(self.network, warmup_iters=config.cuda_graph_warmup_iters) if config.use_cuda_graph else self.network ) self._network_call_uncond: CUDAGraphWrapper | WanDiTNetwork = ( CUDAGraphWrapper(self.network, warmup_iters=config.cuda_graph_warmup_iters) if config.use_cuda_graph else self.network ) @property def latent_shape(self) -> tuple[int, ...]: """Per-rank post-patchify latent shape ``[*batch_shape, L/cp, D]``. Wan flattens THW into one token axis and shards across the THW CP group. ``D = network.out_dim * prod(patch_size)`` is the noise channel count; the mask / image-latent channels added by ``concat_image_mask_to_latent`` come from ``input`` in ``predict_flow``, not from the noise tensor. Per-rollout ``(height, width)`` is populated by :meth:`initialize_autoregressive_cache`; reading earlier asserts. """ assert self._output_height is not None and self._output_width is not None, ( "latent_shape requires an initialized rollout; call " "initialize_autoregressive_cache(..., height=..., width=...) first." ) cfg = self.config kt, kh, kw = cfg.network.patch_size L = (cfg.len_t // kt) * (self._output_height // kh) * (self._output_width // kw) return ( *cfg.batch_shape, L // self._cp_size, cfg.network.out_dim * kt * kh * kw, ) @torch.no_grad() def _build_network_cache( self, *, text_embeddings: Tensor, image_embeddings: Tensor | None = None, ) -> WanDiTNetworkCache: """Build one network cache (cond or uncond branch). Caller must have populated ``self._output_height/_output_width`` (done by :meth:`initialize_autoregressive_cache`) before invoking this. """ assert self._output_height is not None and self._output_width is not None, ( "_build_network_cache called before height/width were stashed." ) cfg = self.config kt, kh, kw = cfg.network.patch_size pHW = (self._output_height // kh) * (self._output_width // kw) cp_size = self._cp_size chunk_size = self.latent_shape[-2] # already CP-divided window_size_t = cfg.window_size_t // kt sink_size_t = cfg.sink_size_t // kt assert (window_size_t * pHW) % cp_size == 0, ( f"window_size_t * frame_token_count ({window_size_t * pHW}) must be " f"divisible by cp_size ({cp_size})" ) assert (sink_size_t * pHW) % cp_size == 0, ( f"sink_size_t * frame_token_count ({sink_size_t * pHW}) must be " f"divisible by cp_size ({cp_size})" ) window_size = (window_size_t * pHW) // cp_size sink_size = (sink_size_t * pHW) // cp_size return self.network.initialize_cache( chunk_size=chunk_size, window_size=window_size, sink_size=sink_size, text_embeddings=text_embeddings, img_embeddings=image_embeddings, )
[docs] @torch.no_grad() def initialize_autoregressive_cache( self, *, height: int, width: int, text_embeddings: Tensor, image_embeddings: Tensor | None = None, negative_text_embeddings: Tensor | None = None, **_unused: Any, ) -> Wan21TransformerCache: """Build a seeded transformer cache for a new rollout. I2V state is *not* baked into the cache; the latent + injection mask are passed per AR step as the ``input`` argument to ``predict_flow`` / ``postprocess_clean_latent``. Args: height: Pre-patchify latent height (post-VAE). width: Pre-patchify latent width (post-VAE). text_embeddings: Conditional UMT5 embeddings ``[..., text_len, text_dim]``. image_embeddings: Conditional CLIP image embeddings (only used by networks with ``cross_attn_enable_img=True``). Shared with the uncond branch. negative_text_embeddings: Negative-prompt embeddings. Required iff ``config.guidance_scale > 1.0``; must be ``None`` otherwise. Returns: Populated cache. ``network_cache_uncond`` is ``None`` iff CFG is disabled. """ # Stash the per-rollout spatial layout. ``_build_network_cache`` # below and ``latent_shape`` / ``unpatchify_and_maybe_gather_cp`` # at AR-step time read these. cfg = self.config kt, kh, kw = cfg.network.patch_size assert height % kh == 0 and width % kw == 0, ( f"(height, width) = ({height}, {width}) must be divisible by " f"patch_size={cfg.network.patch_size[1:]}." ) self._output_height = height self._output_width = width total_tokens = (cfg.len_t // kt) * (height // kh) * (width // kw) assert total_tokens % self._cp_size == 0, ( f"Wan token length ({total_tokens} from len_t={cfg.len_t}, " f"height={height}, width={width}, " f"patch_size={cfg.network.patch_size}) must be divisible by " f"cp_size={self._cp_size}" ) network_cache = self._build_network_cache( text_embeddings=text_embeddings, image_embeddings=image_embeddings, ) network_cache_uncond: WanDiTNetworkCache | None = None if self.config.guidance_scale > 1.0: assert negative_text_embeddings is not None, ( f"WanTransformerConfig.guidance_scale=" f"{self.config.guidance_scale} > 1.0 requires " f"negative_text_embeddings." ) network_cache_uncond = self._build_network_cache( text_embeddings=negative_text_embeddings, image_embeddings=image_embeddings, ) head_dim = self.config.network.dim // self.config.network.num_heads rope_kwargs: dict[str, Any] = { "len_t": cfg.len_t // kt, "len_h": height // kh, "len_w": width // kw, "head_dim": head_dim, "h_extrapolation_ratio": self.config.h_extrapolation_ratio, "w_extrapolation_ratio": self.config.w_extrapolation_ratio, "interleaved": True, "device": self.device, } if cfg.network.apply_rope_before_kvcache: rope_adapter = RotaryPositionEmbedding3D(**rope_kwargs) else: rope_kwargs["sink_size_t"] = cfg.sink_size_t // kt rope_kwargs["window_size_t"] = cfg.window_size_t // kt rope_adapter = KVCacheRelativeRotaryPositionEmbedding3D(**rope_kwargs) rope_adapter.set_context_parallel_group(cp_group=self._cp_group) # Reset any prior CUDA graph: it refers to slot pointers from the # previous cache, which the new cache invalidates. if self._use_cuda_graph: assert isinstance(self._network_call, CUDAGraphWrapper) self._network_call.reset() assert isinstance(self._network_call_uncond, CUDAGraphWrapper) self._network_call_uncond.reset() return Wan21TransformerCache( network_cache=network_cache, network_cache_uncond=network_cache_uncond, rope_adapter=rope_adapter, )
def _maybe_build_per_token_timestep( self, timestep: Tensor, input: I2VCtrl | None, autoregressive_index: int, ) -> Tensor: """Optionally rewrite ``timestep`` into a per-token tensor for TI2V. Off-path for everything except Wan 2.2 TI2V 5B AR-step 0 with a non-``None`` :class:`I2VCtrl`. When on-path, the scalar scheduler timestep is broadcast to ``[..., L]`` then zeroed at positions marked by the I2V mask, so the first-frame conditioning tokens see ``t=0`` while the rest of the chunk denoises at the current scheduler step. The post-patchify mask is constant across the patchified channel axis (the encoder fills a per-pixel binary mask, and patchify concatenates channel * kt * kh * kw entries that all share the same value), so ``mask[..., 0]`` recovers a per-token boolean without an ``any`` reduction. """ if not self.config.ti2v_first_frame_per_token_timestep: return timestep if autoregressive_index != 0: # CUDA-graph capture starts at AR ``_cuda_graph_capture_ar_idx``; # AR>=1 must keep the scalar shape stable across the captured # replay branch. return timestep if input is None: return timestep assert isinstance(input, I2VCtrl), ( "ti2v_first_frame_per_token_timestep requires the I2V control " f"payload to be an I2VCtrl (got {type(input).__name__})" ) per_token_mask = input.mask[..., 0] # [..., L] # Broadcast scalar / per-batch ``timestep`` to ``[..., L]`` and # blend with ``first_frame_timestep_value`` at masked positions. # Multiplying preserves the scheduler dtype so downstream # sinusoidal embedding stays bit-identical to the scalar path # on non-masked tokens. timestep = timestep.to(per_token_mask.device) mask = per_token_mask.to(timestep.dtype) first_frame_value = timestep.new_tensor(self.config.first_frame_timestep_value) return timestep.unsqueeze(-1) * (1.0 - mask) + first_frame_value * mask def _stamp_image_latent( self, latent: Tensor, control: I2VCtrl, ) -> Tensor: """Overwrite ``latent`` with the image latent at masked positions. All three tensors share the same patchified + CP-split shape, so this is a plain per-token blend ``(1 - m) * latent + m * control.latent``. """ return latent * (1.0 - control.mask) + control.latent * control.mask def _select_network(self, autoregressive_index: int, *, uncond: bool) -> Any: # Filling phase: eager ``.drain`` (drains Inductor autotune and # exercises the KV cache's slice-returning filling path). # Steady phase: ``wrapper.__call__`` (warmup + capture + replay). # Cond and CFG-uncond branches both mutate their rolling KV cache, # so neither branch can be graph-captured until the cache is steady. if not self._use_cuda_graph: return self.network network_call = self._network_call_uncond if uncond else self._network_call assert isinstance(network_call, CUDAGraphWrapper) return ( network_call.drain if autoregressive_index < self._cuda_graph_capture_ar_idx else network_call ) def _build_network_input( self, noisy_latent: Tensor, input: I2VCtrl | None, ) -> Tensor: """Apply the (optional) I2V stamp / channel-concat to the noisy latent. See :class:`Wan21TransformerConfig` for the two composable I2V modes. T2V (``input is None``) takes neither path. """ network_input = noisy_latent if self.config.stamp_image_latent: assert isinstance(input, I2VCtrl), ( "stamp_image_latent requires input to be an " f"I2VCtrl (got {type(input).__name__})" ) network_input = self._stamp_image_latent(network_input, input) if self.config.concat_image_mask_to_latent: assert isinstance(input, I2VCtrl), ( "concat_image_mask_to_latent requires input to be " f"an I2VCtrl (got {type(input).__name__})" ) # The patchified mask carries the encoder's 16-channel uniform # tag. Slicing the leading 16 entries recovers the 4-channel mask # the official 14B I2V network expects (4 ch * K=4 patch entries). mask = input.mask[..., :16] network_input = torch.cat([network_input, mask, input.latent], dim=-1) return network_input def _predict_flow( self, network_input: Tensor, timestep: Tensor, cache: Wan21TransformerCache, autoregressive_index: int, network_extra_kwargs: dict[str, Any], *, uncond: bool, ) -> Tensor: network_cache = cache.network_cache_uncond if uncond else cache.network_cache assert network_cache is not None, ( "uncond=True requires cache.network_cache_uncond, but it is None " "(CFG was not enabled at cache build time)." ) assert cache.rope_freqs is not None, ( "Wan21TransformerCache.start() must populate rope_freqs before predict_flow" ) return self._select_network(autoregressive_index, uncond=uncond)( x=network_input, timesteps=timestep, cache=network_cache, rope_freqs=cache.rope_freqs, current_chunk_idx=autoregressive_index, eager_mode=False, **network_extra_kwargs, )
[docs] def predict_flow( self, noisy_latent: Tensor, timestep: Tensor, cache: Wan21TransformerCache, input: I2VCtrl | None = None, network_extra_kwargs: dict[str, Any] | None = None, ) -> Tensor: """Predict the flow for one denoising step. ``timestep`` may be a scalar / per-batch tensor (standard Wan 2.1 / 14B path) or a per-token tensor with the same trailing token axis as ``noisy_latent`` (Wan 2.2 TI2V 5B first-frame seeding at AR step 0). The per-token layout flows through :meth:`WanDiTNetwork.forward`, which dispatches the sinusoidal embedding + AdaLN modulation on the native shape. CUDA-graph capture is shape-sensitive: the captured replay region only sees AR step ``>= self._cuda_graph_capture_ar_idx`` (steady state). TI2V 5B is configured with ``len_t == window_size_t`` so the threshold lands at AR 1, putting the per-token AR-0 step inside the eager ``.drain`` branch where shape changes are safe. After AR 0 the pipeline switches back to scalar timesteps, so the captured branch sees a single stable shape across all AR steps it owns. """ ar_idx = cache.autoregressive_index assert ar_idx >= 0, ( "Wan21TransformerCache.start(autoregressive_index) must be called " "before predict_flow (DiffusionModel.generate handles this)." ) network_extra_kwargs = network_extra_kwargs or {} network_input = self._build_network_input(noisy_latent, input) timestep = self._maybe_build_per_token_timestep( timestep=timestep, input=input, autoregressive_index=ar_idx ) flow_cond = self._predict_flow( network_input=network_input, timestep=timestep, cache=cache, autoregressive_index=ar_idx, network_extra_kwargs=network_extra_kwargs, uncond=False, ) if cache.network_cache_uncond is None: return flow_cond flow_uncond = self._predict_flow( network_input=network_input, timestep=timestep, cache=cache, autoregressive_index=ar_idx, network_extra_kwargs=network_extra_kwargs, uncond=True, ) return flow_uncond + self.config.guidance_scale * (flow_cond - flow_uncond)
[docs] def postprocess_clean_latent( self, clean_latent: Tensor, cache: Wan21TransformerCache, input: I2VCtrl | None = None, ) -> Tensor: """Re-stamp ``x0`` masked positions with the image latent (mask-inject I2V only). T2V and the channel-concat I2V mode fall through unchanged. """ if input is None or not self.config.stamp_image_latent: return clean_latent return self._stamp_image_latent(clean_latent, input)
@overload def patchify_and_maybe_split_cp(self, x: Tensor) -> Tensor: ... @overload def patchify_and_maybe_split_cp(self, x: I2VCtrl) -> I2VCtrl: ...
[docs] def patchify_and_maybe_split_cp(self, x: Tensor | I2VCtrl) -> Tensor | I2VCtrl: """Patchify and CP-split a noisy latent or an I2V control payload. Tensors delegate to the network helper; I2V payloads patchify the ``latent`` and ``mask`` fields independently so the per-field channel layouts are preserved for the mask-inject blend downstream. """ if isinstance(x, I2VCtrl): if x._is_patchified: return x return I2VCtrl( latent=self.patchify_and_maybe_split_cp(x.latent), mask=self.patchify_and_maybe_split_cp(x.mask), _is_patchified=True, ) return self.network.patchify_and_maybe_split_cp( x, process_groups=[self._cp_group], cp_dims=[-2], )
[docs] def unpatchify_and_maybe_gather_cp(self, x: Tensor) -> Tensor: assert self._output_height is not None and self._output_width is not None, ( "unpatchify_and_maybe_gather_cp requires an initialized rollout; " "call initialize_autoregressive_cache(..., height=..., width=...) first." ) _, kh, kw = self.config.network.patch_size return self.network.unpatchify_and_maybe_gather_cp( pH=self._output_height // kh, pW=self._output_width // kw, x=x, process_groups=[self._cp_group], cp_dims=[-2], )