# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Autoregressive inference pipeline: encode → diffuse → decode."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Generic
import torch
import torch.nn as nn
from loguru import logger
from torch import Tensor
from flashdreams.infra.config import InstantiateConfig
from flashdreams.infra.decoder import (
DecoderConfig,
StreamingDecoder,
StreamingDecoderCacheT,
)
from flashdreams.infra.diffusion.model import (
DiffusionModel,
DiffusionModelConfig,
)
from flashdreams.infra.diffusion.transformer import (
TransformerCacheT,
)
from flashdreams.infra.encoder import (
EncoderConfig,
StreamingEncoder,
StreamingEncoderCacheT,
)
from flashdreams.infra.profiler import EventProfiler
@dataclass(kw_only=True)
class StreamInferencePipelineConfig(InstantiateConfig):
    """Config for the streaming inference pipeline.

    Set ``encoder=None`` when the pipeline has no per-AR-step control input
    (pure T2V). Set ``decoder=None`` to return the clean latent directly
    (training, latent-space evaluation, or pipelines that own decoding).
    """

    # Resolved through a factory because ``StreamInferencePipeline`` is
    # defined further down in this module and does not exist yet when this
    # class body is evaluated.
    _target: type["StreamInferencePipeline"] = field(
        default_factory=lambda: StreamInferencePipeline
    )

    name: str
    """Stable slug for this pipeline variant; the primary key of
    ``<NAME>_CONFIGS``. Runners mirror it as ``runner_name`` so
    ``flashdreams-run <slug>`` resolves to this pipeline."""

    diffusion_model: DiffusionModelConfig
    """Transformer + scheduler config."""

    decoder: DecoderConfig | None = None
    """Optional output :class:`StreamingDecoder` with a per-rollout cache,
    called as ``decoder(input, autoregressive_index, cache)``. Use
    ``None`` to return the clean latent unchanged."""

    encoder: EncoderConfig | None = None
    """Optional per-AR-step input encoder. Must be a
    :class:`StreamingEncoder`; one-shot encoders go on
    ``transformer.context_encoder`` instead."""

    enable_sync_and_profile: bool = False
    """Record per-stage CUDA events and log timing per AR step. Calls
    ``torch.cuda.synchronize()`` once per step, which hurts throughput."""
@dataclass(kw_only=True)
class StreamInferencePipelineCache(
    Generic[StreamingEncoderCacheT, TransformerCacheT, StreamingDecoderCacheT]
):
    """Per-rollout cache held by the pipeline.

    Only ``transformer_cache`` is mandatory; every other field defaults to
    ``None`` and is filled in by the pipeline as the rollout progresses.
    """

    transformer_cache: TransformerCacheT
    """Long-lived transformer AR cache (always present)."""

    encoder_cache: StreamingEncoderCacheT | None = None
    """Encoder AR cache; ``None`` iff the pipeline has no encoder."""

    decoder_cache: StreamingDecoderCacheT | None = None
    """Decoder AR cache; ``None`` iff the pipeline has no decoder."""

    final_state: "DiffusionModel.FinalState[TransformerCacheT] | None" = None
    """Diffusion-model state from the most recent ``generate``, consumed
    by ``finalize``."""

    autoregressive_index: int | None = None
    """AR step index of the most recent ``generate``."""

    event_profiler: EventProfiler | None = None
    """Per-step profiler, populated only when profiling is on."""
class StreamInferencePipeline(
    nn.Module,
    Generic[
        StreamingEncoderCacheT,
        TransformerCacheT,
        StreamingDecoderCacheT,
    ],
):
    """End-to-end streaming inference pipeline.

    Generic over the encoder, transformer, and decoder cache types. The
    encoder's input/output types are forwarded as ``Any`` so the
    transformer's ``predict_flow`` / ``postprocess_clean_latent`` overrides
    own the typing on the ``input`` argument they receive.

    Examples:
        cache = pipeline.initialize_cache(transformer_context={...})
        output = pipeline.generate(0, cache, input=...)
        pipeline.finalize(0, cache)
        output = pipeline.generate(1, cache, input=...)
        pipeline.finalize(1, cache)  # optional for the last rollout
    """

    encoder: StreamingEncoder[StreamingEncoderCacheT] | None
    decoder: StreamingDecoder[StreamingDecoderCacheT] | None
    diffusion_model: DiffusionModel[TransformerCacheT]

    def __init__(self, config: StreamInferencePipelineConfig) -> None:
        """Instantiate the configured components.

        Args:
            config: Pipeline config. ``config.encoder`` / ``config.decoder``
                may be ``None``, in which case the matching attribute on
                the pipeline is ``None`` as well.
        """
        super().__init__()
        self.config = config
        self.encoder = config.encoder.setup() if config.encoder is not None else None
        self.decoder = config.decoder.setup() if config.decoder is not None else None
        self.diffusion_model = config.diffusion_model.setup()

    @property
    def device(self) -> torch.device:
        """Device reported by the diffusion model."""
        return self.diffusion_model.device

    def initialize_cache(
        self,
        transformer_context: dict[str, Any] | None = None,
        encoder_context: dict[str, Any] | None = None,
        decoder_context: dict[str, Any] | None = None,
    ) -> StreamInferencePipelineCache[
        StreamingEncoderCacheT, TransformerCacheT, StreamingDecoderCacheT
    ]:
        """Build a fresh per-rollout cache.

        Each ``*_context`` dict is forwarded as keyword arguments to the
        corresponding component's ``initialize_autoregressive_cache``.

        Args:
            transformer_context: Per-rollout state for the transformer
                (e.g. ``{"text_embeddings": ..., "image_embeddings": ...}``).
            encoder_context: Per-rollout state for the encoder. Ignored
                when there is no encoder.
            decoder_context: Per-rollout state for the decoder. Ignored
                when there is no decoder.

        Returns:
            A fresh cache to thread through ``generate`` / ``finalize``.
        """
        # ``None`` (or an empty dict) means "no extra keyword arguments".
        transformer_context = transformer_context or {}
        encoder_context = encoder_context or {}
        decoder_context = decoder_context or {}

        return StreamInferencePipelineCache(
            encoder_cache=(
                self.encoder.initialize_autoregressive_cache(**encoder_context)
                if self.encoder is not None
                else None
            ),
            decoder_cache=(
                self.decoder.initialize_autoregressive_cache(**decoder_context)
                if self.decoder is not None
                else None
            ),
            transformer_cache=self.diffusion_model.transformer.initialize_autoregressive_cache(
                **transformer_context
            ),
        )

    def _require_encoder(
        self,
        cache: StreamInferencePipelineCache[
            StreamingEncoderCacheT, TransformerCacheT, StreamingDecoderCacheT
        ],
    ) -> tuple[StreamingEncoder[StreamingEncoderCacheT], StreamingEncoderCacheT]:
        """Return the encoder and its cache, asserting that both exist."""
        assert self.encoder is not None, (
            "input was provided but the pipeline has no encoder. "
            "Configure StreamInferencePipelineConfig.encoder (e.g. with "
            "NullEncoderConfig() for an identity passthrough)."
        )
        assert cache.encoder_cache is not None  # invariant: paired with encoder
        return self.encoder, cache.encoder_cache

    @torch.no_grad()
    def generate(
        self,
        autoregressive_index: int,
        cache: StreamInferencePipelineCache[
            StreamingEncoderCacheT, TransformerCacheT, StreamingDecoderCacheT
        ],
        input: Any = None,
    ) -> Tensor:
        """Generate one chunk for this AR step.

        Args:
            autoregressive_index: Must be ``cache.autoregressive_index + 1``,
                or ``0`` for the first call after ``initialize_cache``.
            cache: Per-rollout cache from ``initialize_cache``.
            input: Raw input fed to the encoder. A non-``None`` value
                requires a configured encoder (asserted). When ``None``,
                the encoder is skipped, even if one is configured, and
                ``None`` is forwarded to the diffusion model. Use
                ``NullEncoderConfig`` to pass an already-encoded tensor
                straight through.

        Returns:
            Decoded tensor (e.g. RGB video) when a decoder is configured;
            otherwise the clean latent returned by the diffusion model.

        Raises:
            AssertionError: If the AR step is out of order, or ``input`` is
                given while the pipeline has no encoder.
        """
        previous_index = cache.autoregressive_index
        expected_index = 0 if previous_index is None else previous_index + 1
        assert autoregressive_index == expected_index, (
            f"AR step out of order: previous step was {previous_index}, expected next "
            f"{expected_index}, got {autoregressive_index}"
        )
        # Check every precondition before touching the cache, so a rejected
        # call leaves the rollout state intact and the step can be retried.
        encoder_and_cache = self._require_encoder(cache) if input is not None else None

        cache.autoregressive_index = autoregressive_index

        events: EventProfiler | None = None
        if self.config.enable_sync_and_profile:
            # A fresh profiler per step; ``finalize`` reads it back from the cache.
            events = EventProfiler()
            cache.event_profiler = events

        model_input = input
        if encoder_and_cache is not None:
            encoder, encoder_cache = encoder_and_cache
            model_input = encoder(
                input=input,
                autoregressive_index=autoregressive_index,
                cache=encoder_cache,
            )
        if events is not None:
            events.record("encode")

        clean_latent, final_state = self.diffusion_model.generate(
            autoregressive_index=autoregressive_index,
            cache=cache.transformer_cache,
            input=model_input,
        )
        # Stash the state so ``finalize`` can advance the diffusion AR cache.
        cache.final_state = final_state
        if events is not None:
            events.record("diffuse")

        if self.decoder is not None:
            assert cache.decoder_cache is not None  # invariant: paired with decoder
            output = self.decoder(
                input=clean_latent,
                autoregressive_index=autoregressive_index,
                cache=cache.decoder_cache,
            )
        else:
            output = clean_latent
        if events is not None:
            events.record("decode")

        return output

    @torch.no_grad()
    def finalize(
        self,
        autoregressive_index: int,
        cache: StreamInferencePipelineCache[
            StreamingEncoderCacheT, TransformerCacheT, StreamingDecoderCacheT
        ],
    ) -> dict[str, float] | None:
        """Advance the diffusion AR cache for the next AR step.

        Args:
            autoregressive_index: Must match the index passed to the most
                recent ``generate`` (asserted).
            cache: Same cache used by ``generate``. Reads
                ``cache.final_state``.

        Returns:
            ``None`` when profiling is disabled. Otherwise a snapshot of this
            AR step's per-stage timings (ms) and, when CUDA is available,
            GPU memory (GiB):
            ``{<stage>_ms, total_ms, total_ms_wo_finalize, mem_alloc_gib,
            mem_reserved_gib, mem_peak_gib}``. The same numbers are also
            logged via ``logger.info``.

        Raises:
            AssertionError: If the index does not match the last
                ``generate``, if no ``generate`` has run yet, or if
                profiling is enabled but the cache holds no profiler.
        """
        assert cache.autoregressive_index == autoregressive_index, (
            f"autoregressive_index mismatch: generate() ran with "
            f"{cache.autoregressive_index} but finalize() was called with "
            f"{autoregressive_index}."
        )
        assert cache.final_state is not None, (
            "finalize() called before generate() — no FinalState on the cache."
        )
        # NOTE(review): ``cache.final_state`` is not cleared afterwards, so a
        # second ``finalize`` for the same step would pass the checks above
        # and hand the same state to the diffusion model again — confirm
        # whether that is intended.
        self.diffusion_model.finalize(final_state=cache.final_state)

        if not self.config.enable_sync_and_profile:
            return None

        profiler = cache.event_profiler
        assert profiler is not None, (
            "finalize() called before any generate() — no EventProfiler on the cache."
        )
        profiler.record("finalize")
        return self._summarize_step(autoregressive_index, profiler)

    @staticmethod
    def _summarize_step(
        autoregressive_index: int, profiler: EventProfiler
    ) -> dict[str, float]:
        """Collect, log, and return the timing/memory stats of one AR step.

        Args:
            autoregressive_index: AR step index, used only in the log line.
            profiler: Profiler whose ``sync_and_summarize()`` yields a
                ``{stage: milliseconds}`` mapping for this step.

        Returns:
            ``{<stage>_ms, total_ms, total_ms_wo_finalize}`` plus the
            ``mem_*_gib`` entries when CUDA is available.
        """
        stats_ms = profiler.sync_and_summarize()
        total_ms = sum(stats_ms.values())
        total_ms_wo_finalize = total_ms - stats_ms.get("finalize", 0.0)
        stages_str = " ".join(f"{stage} {ms:.3f} ms" for stage, ms in stats_ms.items())

        stats: dict[str, float] = {f"{stage}_ms": ms for stage, ms in stats_ms.items()}
        stats["total_ms"] = total_ms
        stats["total_ms_wo_finalize"] = total_ms_wo_finalize

        mem_str = ""
        if torch.cuda.is_available():
            # Memory figures are taken from the current CUDA device.
            device = torch.cuda.current_device()
            gib = 1024**3
            mem_alloc_gib = torch.cuda.memory_allocated(device) / gib
            mem_reserved_gib = torch.cuda.memory_reserved(device) / gib
            mem_peak_gib = torch.cuda.max_memory_allocated(device) / gib
            stats["mem_alloc_gib"] = mem_alloc_gib
            stats["mem_reserved_gib"] = mem_reserved_gib
            stats["mem_peak_gib"] = mem_peak_gib
            mem_str = (
                f" | GPU mem alloc {mem_alloc_gib:.3f} GiB "
                f"reserved {mem_reserved_gib:.3f} GiB "
                f"peak {mem_peak_gib:.3f} GiB"
            )

        logger.info(
            f"AR {autoregressive_index} {stages_str} | "
            f"total(w/o finalize) {total_ms_wo_finalize:.3f} ms "
            f"total {total_ms:.3f} ms"
            f"{mem_str}"
        )
        return stats