# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TAEHV video decoder."""
from __future__ import annotations
import math
from dataclasses import dataclass, field
import torch
from torch import Tensor
from flashdreams.infra.decoder import DecoderConfig, StreamingVideoDecoder
from flashdreams.recipes.taehv.checkpoint import (
StateDictTransform,
compose,
legacy_to_blocks_keys,
truncate_oversize_tgrow_weights,
)
from flashdreams.recipes.taehv.impl import TAEHV, TAEHVCache
AVAILABLE_TAEHV_CHECKPOINT_PATHS = {
"lighttae": "https://huggingface.co/lightx2v/Autoencoders/resolve/main/lighttaew2_1.pth",
}
"""Checkpoint paths for the TAEHV decoder."""
_LIGHTTAE_CHANNELS: tuple[int, int, int, int] = (256, 128, 64, 64)
"""TAEHV ``Decoder`` block widths the ``lighttae`` weights were trained
for. Mirrors :attr:`TAEHV.channels` for the default config; kept here so
:data:`lighttae_state_dict_transform` can pre-compute the per-``TGrow``
expected output channel widths without an instantiated model."""
lighttae_state_dict_transform: StateDictTransform = compose(
legacy_to_blocks_keys,
truncate_oversize_tgrow_weights(channels=_LIGHTTAE_CHANNELS),
)
"""Per-checkpoint remap for the ``lighttae`` weights. Rewrites the flat
``decoder.<i>.*`` keys to the current ``decoder.blocks.<i>.*`` layout
and clips the stride=2 ``TGrow`` weights at idx 7 down to the stride=1
slice the live model expects."""
[docs]
@dataclass(kw_only=True)
class TeahvVAEDecoderConfig(DecoderConfig):
"""Config for the TAEHV decoder."""
_target: type["TeahvVAEDecoder"] = field(default_factory=lambda: TeahvVAEDecoder)
checkpoint_path: str = AVAILABLE_TAEHV_CHECKPOINT_PATHS["lighttae"]
"""Path to a pretrained TAEHV checkpoint. Defaults to the ``lighttae`` weights."""
state_dict_transform: StateDictTransform | None = lighttae_state_dict_transform
"""Pre-load state-dict remap. Defaults to
:data:`lighttae_state_dict_transform`; ``None`` falls through to the
bare :class:`TAEHV` default (see
:meth:`~flashdreams.recipes.taehv.impl.TAEHV.load_from_checkpoint`)."""
dtype: torch.dtype = torch.bfloat16
"""Network parameter / activation dtype."""
use_cuda_graph: bool = True
"""Wrap the decoder forward in a CUDA graph for replay."""
use_compile: bool = True
"""``torch.compile(mode="max-autotune-no-cudagraphs")``."""
[docs]
class TeahvVAEDecoder(StreamingVideoDecoder[TAEHVCache]):
"""TAEHV (Tiny AutoEncoder for Hunyuan Video) decoder.
Forward input is a latent ``[..., Tl, Cl, Hl, Wl]``; output is a video
tensor ``[..., T, C, H, W]`` in ``[-1, 1]``.
Set ``torch.backends.cudnn.benchmark = True`` at process start for ~5%
extra on the eager seed/tail chunks.
"""
TEMPORAL_COMPRESSION_RATIO = TAEHV.TEMPORAL_COMPRESSION_RATIO
SPATIAL_COMPRESSION_RATIO = TAEHV.SPATIAL_COMPRESSION_RATIO
mean: Tensor
"""Per-channel latent mean buffer; registered only when ``need_scaled``."""
std: Tensor
"""Per-channel latent standard deviation buffer; registered only when ``need_scaled``."""
_LIGHTTAE_MEAN: tuple[float, ...] = (
-0.7571, -0.7089, -0.9113, 0.1075,
-0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632,
-0.1922, -0.9497, 0.2503, -0.2921,
) # fmt: skip
"""Per-channel mean for the ``lighttae`` checkpoint's latent scaling."""
_LIGHTTAE_STD: tuple[float, ...] = (
2.8184, 1.4541, 2.3275, 2.6558,
1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579,
1.6382, 1.1253, 2.8251, 1.9160,
) # fmt: skip
"""Per-channel standard deviation for the ``lighttae`` checkpoint's latent scaling."""
def __init__(self, config: TeahvVAEDecoderConfig) -> None:
super().__init__(config)
self.config: TeahvVAEDecoderConfig = config
self.need_scaled = "lighttae" in config.checkpoint_path
self.taehv = TAEHV(
checkpoint_path=config.checkpoint_path,
use_cuda_graph=config.use_cuda_graph,
use_compile=config.use_compile,
state_dict_transform=config.state_dict_transform,
).to(dtype=config.dtype)
if self.need_scaled:
self.register_buffer(
"mean",
torch.tensor(self._LIGHTTAE_MEAN, dtype=config.dtype).view(
1, 1, -1, 1, 1
),
persistent=False,
)
self.register_buffer(
"std",
torch.tensor(self._LIGHTTAE_STD, dtype=config.dtype).view(
1, 1, -1, 1, 1
),
persistent=False,
)
[docs]
def initialize_autoregressive_cache(self) -> TAEHVCache:
"""Return an empty streaming decoder cache."""
return self.taehv.prepare_cache()
[docs]
@torch.inference_mode()
def forward(
self,
input: Tensor,
autoregressive_index: int = 0,
cache: TAEHVCache | None = None,
) -> Tensor:
"""Decode a latent chunk to a video tensor in ``[-1, 1]``.
Args:
input: Latent of shape ``[..., Tl, Cl, Hl, Wl]``.
autoregressive_index: Unused by TAEHV; kept for the
:class:`~flashdreams.infra.decoder.StreamingDecoder` interface.
cache: Streaming decoder cache; created on the fly when ``None``.
Returns:
Video tensor of shape ``[..., T, C, H, W]`` in ``[-1, 1]``.
"""
if cache is None:
cache = self.initialize_autoregressive_cache()
assert input.ndim >= 4, (
f"Expected input to have shape [..., T, C, H, W] (ndim>=4), "
f"got ndim={input.ndim}"
)
*batch_shape, T, C, H, W = input.shape
batch_size = math.prod(batch_shape)
z = input.reshape(batch_size, T, C, H, W)
if self.need_scaled:
z = z * self.std
z = z + self.mean
x = self.taehv.decode(z, cache=cache).mul_(2).sub_(1)
return x.reshape(*batch_shape, *x.shape[1:])
@property
def temporal_compression_ratio(self) -> int:
"""Pixel frames / latent frames in steady state (AR >= 1)."""
return self.TEMPORAL_COMPRESSION_RATIO
@property
def spatial_compression_ratio(self) -> int:
"""Pixel side / latent side."""
return self.SPATIAL_COMPRESSION_RATIO
[docs]
def get_output_temporal_size(
self, autoregressive_index: int, input_temporal_size: int
) -> int:
"""Return pixel frame count from ``input_temporal_size`` latent frames.
AR 0 applies causal padding: the first latent frame yields one pixel
frame, remaining frames yield ``temporal_compression_ratio`` each.
AR >= 1 is a plain multiply."""
r = self.temporal_compression_ratio
if autoregressive_index == 0:
return 1 + (input_temporal_size - 1) * r
return input_temporal_size * r
if __name__ == "__main__":
import tyro
config = tyro.cli(TeahvVAEDecoderConfig)
model = config.setup()
print(model)