# Source code for flashdreams.infra.decoder.base
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder interfaces.

Two flavours:

- :class:`StreamingDecoder` is stateful: ``forward(self, input,
  autoregressive_index=0, cache=None)`` plus a per-rollout
  :class:`StreamingDecoderCache`. Use it for chunk-by-chunk streaming
  decoders (e.g. the WAN VAE, which maintains a temporal cache across AR
  steps). It also works for stateless decoders: just return an empty
  :class:`StreamingDecoderCache` and ignore ``autoregressive_index`` /
  ``cache`` in ``forward``.
- :class:`StreamingVideoDecoder` extends :class:`StreamingDecoder` with
  the contracts a streaming pixel-video decoder always needs: spatial
  and temporal compression ratios, plus an AR-step-aware temporal size
  mapper from latent to pixel space.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Generic
import torch.nn as nn
from typing_extensions import TypeVar
from flashdreams.infra.config import InstantiateConfig
[docs]
@dataclass(kw_only=True)
class StreamingDecoderCache:
"""Per-rollout cache for :class:`StreamingDecoder`.
Empty by default; subclass to add fields (e.g. temporal feature
buffers carried across AR steps).
"""
# Cache type parameter shared by the decoder generics below. The
# ``typing_extensions`` TypeVar accepts the PEP 696 ``default=`` argument,
# so type checkers read a bare ``StreamingDecoder`` as
# ``StreamingDecoder[StreamingDecoderCache]``.
StreamingDecoderCacheT = TypeVar(
    "StreamingDecoderCacheT",
    default=StreamingDecoderCache,
    bound=StreamingDecoderCache,
)
class StreamingDecoder(ABC, nn.Module, Generic[StreamingDecoderCacheT]):
    """Streaming decoder, generic over the per-rollout cache type.

    ``forward`` is not pinned by the base. Streaming decoders called by
    :class:`StreamInferencePipeline` must match its call shape:
    ``forward(self, input, autoregressive_index=0, cache=None)``.

    Args:
        config: Decoder config; stored unchanged as ``self.config``.
    """

    def __init__(self, config: "DecoderConfig") -> None:
        # Set up nn.Module's internal state before assigning any attributes.
        super().__init__()
        self.config = config

    @abstractmethod
    def initialize_autoregressive_cache(self, **context: Any) -> StreamingDecoderCacheT:
        """Build a fresh per-rollout cache.

        Override to return the decoder's concrete cache type. Stateless
        decoders can return an empty :class:`StreamingDecoderCache`.

        Args:
            **context: Subclass-defined keyword context; the base
                declares no specific keys.

        Returns:
            A new cache instance for one rollout.
        """
class StreamingVideoDecoder(StreamingDecoder[StreamingDecoderCacheT]):
    """Streaming pixel-video decoder.

    Pins down the contracts that every streaming latent→pixel video
    decoder satisfies in addition to :class:`StreamingDecoder`:

    - Spatial and temporal compression ratios between the latent and
      pixel grids (constants of the architecture).
    - An AR-step-aware temporal size mapper,
      :meth:`get_output_temporal_size`, so a pipeline can size the
      decoder's output without knowing its concrete temporal cache
      topology (causal first-frame padding, sliding windows, etc.).

    Spatial scaling is trivially ``side * spatial_compression_ratio``
    in either direction; the AR-step-asymmetric piece is the temporal
    size, which gets its own mapper. Typically AR 0 produces fewer
    pixel frames per latent frame than AR ≥ 1 because of causal
    first-frame padding.
    """

    @property
    @abstractmethod
    def spatial_compression_ratio(self) -> int:
        """Pixel side ÷ latent side. Constant across AR steps."""

    @property
    @abstractmethod
    def temporal_compression_ratio(self) -> int:
        """Pixel frames ÷ latent frames in steady state (AR ≥ 1).

        AR 0 typically yields fewer pixel frames per latent frame
        because of causal first-frame padding. That asymmetry is not
        captured by this ratio; it is handled by
        :meth:`get_output_temporal_size`.
        """

    @abstractmethod
    def get_output_temporal_size(
        self,
        autoregressive_index: int,
        input_temporal_size: int,
    ) -> int:
        """Pixel frame count produced by ``input_temporal_size`` latent frames.

        Args:
            autoregressive_index: AR step index (0-based).
            input_temporal_size: Number of latent frames fed at this step.

        Returns:
            Number of pixel frames emitted at this step.
        """
@dataclass(kw_only=True)
class DecoderConfig(InstantiateConfig):
    """Common base class for all decoder configs.

    ``_target`` holds the decoder class this config is bound to. The
    default is the abstract :class:`StreamingDecoder`, which cannot be
    instantiated directly, so concrete decoder configs presumably
    override it with their own subclass.
    """

    _target: type["StreamingDecoder"] = field(
        default_factory=lambda: StreamingDecoder,
    )