Infra

The flashdreams.infra package defines the swappable abstractions that every integration plugs into: a config system, an encoder / diffusion-model / decoder triple, and the streaming inference pipeline that drives them.

Config

Every component is built from a frozen InstantiateConfig dataclass via config.setup(). This makes the full configuration tree printable, hashable, and trivially serialisable.

class PrintableConfig[source]

Bases: object

Config base class providing a multi-line __str__ for human-readable dumps.

class InstantiateConfig(_target: type)[source]

Bases: PrintableConfig

Config carrying a _target class plus its kwargs, instantiable via setup.

setup(**kwargs: Any) Any[source]

Instantiate the configured object.
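The config-then-setup pattern can be illustrated with a minimal stand-in. This is a sketch under stated assumptions, not the library's code: `SketchInstantiateConfig`, `Greeter`, and `GreeterConfig` are hypothetical names, and the sketch assumes the target's constructor receives the config as its first argument.

```python
from dataclasses import dataclass, field
from typing import Any


class Greeter:
    """Hypothetical component; real targets would be pipelines, schedulers, etc."""

    def __init__(self, config: "GreeterConfig") -> None:
        self.config = config

    def greet(self) -> str:
        return f"hello {self.config.name}"


@dataclass(frozen=True)
class SketchInstantiateConfig:
    """Stand-in for InstantiateConfig (assumed shape, not the library code)."""

    _target: type

    def setup(self, **kwargs: Any) -> Any:
        # Assumption: the target's constructor takes the config as its first argument.
        return self._target(self, **kwargs)


@dataclass(frozen=True)
class GreeterConfig(SketchInstantiateConfig):
    # The default factory mirrors the `_target = <factory>` shown in the signatures.
    _target: type = field(default_factory=lambda: Greeter)
    name: str = "world"


config = GreeterConfig(name="flashdreams")
component = config.setup()
print(component.greet())
```

Because the dataclass is frozen, two configs with equal fields compare equal and hash identically, which is what makes a config tree usable as a dictionary key.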

derive_config(base_config: T, **changes: Any) T[source]

Deep-copy a base config and apply nested keyword overrides.

Nested dict values walk into both dataclass attributes and nested dicts; leaf values overwrite directly. Raises KeyError on unknown paths.

Example:

new_config = derive_config(
    base_config,
    tokenizer=WanVAEInterfaceConfig(checkpoint_path=...),
    dit=dict(len_t=3, checkpoint_path=...),
)
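One way the nested-override walk could work is sketched here with the standard library only. `derive_config_sketch`, `ToyDiTConfig`, and `ToyPipelineConfig` are hypothetical names; the real derive_config may differ in details such as error messages or how it writes into frozen dataclasses.

```python
import copy
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any


def _apply_overrides(node: Any, changes: dict) -> None:
    """Recursively apply overrides to a dataclass instance or a dict, in place."""
    for key, value in changes.items():
        if is_dataclass(node):
            if key not in {f.name for f in fields(node)}:
                raise KeyError(key)
            current = getattr(node, key)
        elif isinstance(node, dict):
            if key not in node:
                raise KeyError(key)
            current = node[key]
        else:
            raise KeyError(key)
        # A dict override walks into dataclass attributes and nested dicts.
        walks_in = isinstance(value, dict) and (
            is_dataclass(current) or isinstance(current, dict)
        )
        if walks_in:
            _apply_overrides(current, value)
        elif is_dataclass(node):
            # object.__setattr__ also works on frozen dataclasses.
            object.__setattr__(node, key, value)
        else:
            node[key] = value


def derive_config_sketch(base_config: Any, **changes: Any) -> Any:
    """Deep-copy the base config, then apply the overrides to the copy."""
    new_config = copy.deepcopy(base_config)
    _apply_overrides(new_config, changes)
    return new_config


@dataclass(frozen=True)
class ToyDiTConfig:
    len_t: int = 1
    checkpoint_path: str = "base.pt"


@dataclass(frozen=True)
class ToyPipelineConfig:
    name: str = "toy"
    dit: ToyDiTConfig = field(default_factory=ToyDiTConfig)


base = ToyPipelineConfig()
derived = derive_config_sketch(base, dit=dict(len_t=3))
```

Here `derived.dit.len_t` is 3 while `base` is left untouched, and an unknown key such as `dit=dict(depth=2)` raises KeyError.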

Pipeline

The pipeline is the top-level streaming inference loop. It autoregressively generates one chunk of latent video at a time by running the encoder, the diffusion model, and the decoder back-to-back, threading per-chunk caches through every component.

class StreamInferencePipelineConfig(*, _target: type[StreamInferencePipeline] = <factory>, name: str, diffusion_model: DiffusionModelConfig, decoder: DecoderConfig | None = None, encoder: EncoderConfig | None = None, enable_sync_and_profile: bool = False)[source]

Bases: InstantiateConfig

Config for the streaming inference pipeline.

Set encoder=None when the pipeline has no per-AR-step control input (pure T2V). Set decoder=None to return the clean latent directly (training, latent-space evaluation, or pipelines that own decoding).

name: str

Stable slug for this pipeline variant; the primary key of <NAME>_CONFIGS. Runners mirror it as runner_name so flashdreams-run <slug> resolves to this pipeline.

diffusion_model: DiffusionModelConfig

Transformer + scheduler config.

decoder: DecoderConfig | None = None

Optional output StreamingDecoder with a per-rollout cache, called as decoder(input, autoregressive_index, cache). Use None to return the clean latent unchanged.

encoder: EncoderConfig | None = None

Optional per-AR-step input encoder. Must be a StreamingEncoder; one-shot encoders go on transformer.context_encoder instead.

enable_sync_and_profile: bool = False

Record per-stage CUDA events and log timing per AR step. Calls torch.cuda.synchronize() once per step, which hurts throughput.

class StreamInferencePipeline(config: StreamInferencePipelineConfig)[source]

Bases: Module, Generic[StreamingEncoderCacheT, TransformerCacheT, StreamingDecoderCacheT]

End-to-end streaming inference pipeline.

Generic over the encoder, transformer, and decoder cache types. The encoder’s input/output types are forwarded as Any so the transformer’s predict_flow / postprocess_clean_latent overrides own the typing on the input argument they receive.

Examples

cache = pipeline.initialize_cache(transformer_context={…})
output = pipeline.generate(0, cache, input=…)
pipeline.finalize(0, cache)
output = pipeline.generate(1, cache, input=…)
pipeline.finalize(1, cache)  # optional for the last rollout

initialize_cache(transformer_context: dict[str, Any] | None = None, encoder_context: dict[str, Any] | None = None, decoder_context: dict[str, Any] | None = None) StreamInferencePipelineCache[StreamingEncoderCacheT, TransformerCacheT, StreamingDecoderCacheT][source]

Build a fresh per-rollout cache.

Each *_context dict is forwarded as keyword arguments to the corresponding component’s initialize_autoregressive_cache.

Parameters:
  • transformer_context – Per-rollout state for the transformer (e.g. {"text_embeddings": ..., "image_embeddings": ...}).

  • encoder_context – Per-rollout state for the encoder. Ignored when there is no encoder.

  • decoder_context – Per-rollout state for the decoder. Ignored when there is no decoder.

Returns:

A fresh cache to thread through generate / finalize.

generate(autoregressive_index: int, cache: StreamInferencePipelineCache[StreamingEncoderCacheT, TransformerCacheT, StreamingDecoderCacheT], input: Any = None) Tensor[source]

Generate one chunk for this AR step.

Parameters:
  • autoregressive_index – Must be cache.autoregressive_index + 1, or 0 for the first call after initialize_cache.

  • cache – Per-rollout cache from initialize_cache.

  • input – Raw input fed to the encoder. Required when an encoder is configured, must be None otherwise. Use NullEncoderConfig to pass an already-encoded tensor straight through.

Returns:

Decoded tensor (e.g. RGB video) when a decoder is configured; otherwise the unpatchified clean latent from the diffusion model.

finalize(autoregressive_index: int, cache: StreamInferencePipelineCache[StreamingEncoderCacheT, TransformerCacheT, StreamingDecoderCacheT]) dict[str, float] | None[source]

Advance the diffusion AR cache for the next AR step.

Parameters:
  • autoregressive_index – Must match the index passed to the most recent generate (asserted).

  • cache – Same cache used by generate. Consumes cache.final_state.

Returns:

None when profiling is disabled. Otherwise a snapshot of this AR step’s per-stage timings (ms) and GPU memory (GiB): {<stage>_ms, total_ms, total_ms_wo_finalize, mem_alloc_gib, mem_reserved_gib, mem_peak_gib}. The same numbers are also logged via logger.info.
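The index contract between generate and finalize can be made concrete with a toy stand-in. `ToyPipeline` and `ToyCache` are hypothetical and model only the bookkeeping described above (no encoder, diffusion model, or decoder), so treat this as a sketch of the calling convention rather than of the pipeline itself.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyCache:
    autoregressive_index: Optional[int] = None
    final_state: Optional[str] = None


class ToyPipeline:
    """Hypothetical stand-in that only enforces the AR index contract."""

    def initialize_cache(self) -> ToyCache:
        return ToyCache()

    def generate(self, autoregressive_index: int, cache: ToyCache) -> str:
        # Must be 0 right after initialize_cache, else previous index + 1.
        if cache.autoregressive_index is None:
            expected = 0
        else:
            expected = cache.autoregressive_index + 1
        assert autoregressive_index == expected, f"expected AR index {expected}"
        cache.autoregressive_index = autoregressive_index
        cache.final_state = f"state-{autoregressive_index}"
        return f"chunk-{autoregressive_index}"

    def finalize(self, autoregressive_index: int, cache: ToyCache) -> None:
        # Must match the index of the most recent generate.
        assert autoregressive_index == cache.autoregressive_index
        assert cache.final_state is not None, "finalize consumes generate's state"
        cache.final_state = None


pipeline = ToyPipeline()
cache = pipeline.initialize_cache()
chunks: List[str] = []
for step in range(3):
    chunks.append(pipeline.generate(step, cache))
    pipeline.finalize(step, cache)
```

Skipping an index, or calling finalize twice for the same step, trips an assertion in this sketch.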

class StreamInferencePipelineCache(*, transformer_cache: TransformerCacheT, encoder_cache: StreamingEncoderCacheT | None = None, decoder_cache: StreamingDecoderCacheT | None = None, final_state: FinalState[TransformerCacheT] | None = None, autoregressive_index: int | None = None, event_profiler: EventProfiler | None = None)[source]

Bases: Generic[StreamingEncoderCacheT, TransformerCacheT, StreamingDecoderCacheT]

Per-rollout cache held by the pipeline.

transformer_cache: TransformerCacheT

Long-lived transformer AR cache (always present).

encoder_cache: StreamingEncoderCacheT | None = None

Encoder AR cache; None iff the pipeline has no encoder.

decoder_cache: StreamingDecoderCacheT | None = None

Decoder AR cache; None iff the pipeline has no decoder.

final_state: FinalState[TransformerCacheT] | None = None

Diffusion-model state from the most recent generate, consumed by finalize.

autoregressive_index: int | None = None

AR step index of the most recent generate.

event_profiler: EventProfiler | None = None

Per-step profiler, populated only when profiling is on.

Diffusion model

Wraps a transformer backbone with a denoising scheduler. Callers see only noise → clean_latent; the per-step flow prediction and the iteration loop are hidden inside generate().

class DiffusionModelConfig(*, _target: type[DiffusionModel] = <factory>, transformer: TransformerConfig, scheduler: SchedulerConfig, seed: int | None = None, context_noise: int = 0, noise_in_unpatchified_shape: bool = False)[source]

Bases: InstantiateConfig

Config for the autoregressive diffusion model.

transformer: TransformerConfig

Flow-prediction network config.

scheduler: SchedulerConfig

Denoising-loop config.

seed: int | None = None

RNG seed for initial-noise draws and scheduler sampling. None uses the global RNG.

context_noise: int = 0

Timestep used by finalize for the AR cache-update forward. 0 skips add_noise.

noise_in_unpatchified_shape: bool = False

Debug-only: draw the initial noise in the unpatchified shape, then patchify. Slower than the default patchified path; useful when matching another implementation’s RNG sequence.

class DiffusionModel(config: DiffusionModelConfig)[source]

Bases: Module, Generic[TransformerCacheT]

Autoregressive diffusion model (scheduler + transformer).

Generic over the transformer’s AR cache type so user-facing typing on cache is preserved end-to-end.

Examples

model = config.setup().to("cuda")
cache = model.transformer.initialize_autoregressive_cache(…)
clean, final_state = model.generate(autoregressive_index=0, cache=cache)
model.finalize(final_state)

class FinalState(*, clean_latent: Tensor, autoregressive_index: int, cache: _FinalStateCacheT, input: Any = None)[source]

Bases: Generic[_FinalStateCacheT]

State passed from generate to finalize.

clean_latent: Tensor

Patchified clean latent at the end of denoising.

autoregressive_index: int

AR step this state was produced at.

cache: _FinalStateCacheT

Long-lived AR cache used during generation.

input: Any = None

Patchified per-AR-step encoder output, or None.

property rng: Generator | None

Per-model generator, lazily built on the current device.

Returns None when config.seed is None. Rebuilt the first time the model’s device changes after a .to(...). A device move resets the RNG stream — fine for “construct on CPU, .to(gpu) once” but mid-rollout device hops lose RNG state.

generate(autoregressive_index: int, cache: TransformerCacheT, input: Any = None) tuple[Tensor, FinalState[TransformerCacheT]][source]

Run the denoising loop for one AR step.

Parameters:
  • autoregressive_index – AR step index.

  • cache – Long-lived AR cache, mutated in place.

  • input – Optional per-AR-step encoder output. Patchified here and forwarded to predict_flow / postprocess_clean_latent, then stashed on the returned FinalState for finalize.

Returns:

(clean_latent, final_state). clean_latent is unpatchified; final_state should be passed to finalize.

finalize(final_state: FinalState[TransformerCacheT]) None[source]

Advance the AR cache using the clean latent from generate.

Re-noises the clean latent to config.context_noise and runs the transformer’s finalize_kv_cache (one forward for vanilla transformers, multiple for dual-network DiTs).

context_noise == 0 skips add_noise (sigma=0 is identity) and feeds the clean latent directly. This also dodges the requirement for schedulers to support a t=0 lookup (UniPC’s inference schedule has no t=0 entry).
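The context_noise == 0 shortcut can be sketched with scalars. `latent_for_cache_update` and `strict_add_noise` are hypothetical helpers; the toy sigma table deliberately has no t=0 entry, to mimic a scheduler whose schedule cannot be looked up at zero.

```python
from typing import Callable


def latent_for_cache_update(
    clean_latent: float,
    context_noise: int,
    add_noise: Callable[[float, int], float],
) -> float:
    """Pick the latent fed to the cache-update forward."""
    if context_noise == 0:
        # sigma = 0 would be the identity anyway, so skip the lookup entirely.
        return clean_latent
    return add_noise(clean_latent, context_noise)


def strict_add_noise(clean_latent: float, timestep: int) -> float:
    """Toy scheduler hook whose table has no entry for timestep 0."""
    sigma_table = {500: 0.5, 1000: 1.0}
    sigma = sigma_table[timestep]  # would raise KeyError for timestep 0
    eps = 0.25  # fixed stand-in for a noise draw, to keep the example deterministic
    return (1.0 - sigma) * clean_latent + sigma * eps


unchanged = latent_for_cache_update(2.0, 0, strict_add_noise)
renoised = latent_for_cache_update(2.0, 500, strict_add_noise)
```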

Transformer

class Transformer(config: TransformerConfig)[source]

Bases: Module, ABC, Generic[TransformerCacheT]

Flow-prediction transformer, generic over its AR cache subclass.

Subclasses implement predict_flow and the patchify hooks. AR transformers also subclass TransformerAutoregressiveCache and override initialize_autoregressive_cache.

Example:

class MyTransformer(Transformer[MyCache]):
    def predict_flow(self, noisy_latent, timestep, cache, input=None):
        ...

    def initialize_autoregressive_cache(self, **context) -> MyCache:
        ...

abstract property latent_shape: tuple[int, ...]

Shape of the input/output latent tensor for this rank.

Includes batch dims. May depend on hierarchical context-parallel group sizes (V/T/HW), so subclasses typically derive this from self.cp_groups rather than from static config alone.

abstractmethod predict_flow(noisy_latent: Tensor, timestep: Tensor, cache: TransformerCacheT, input: Any = None) Tensor[source]

Predict the flow at timestep.

Parameters:
  • noisy_latent – Patchified noisy latent for this denoising step.

  • timestep – Scalar timestep tensor.

  • cache – Per-rollout AR cache.

  • input – Patchified encoder output for this AR step, or None when the pipeline has no encoder. Subclasses should narrow the type to their encoder’s output type.

Returns:

Predicted flow tensor with the same shape as noisy_latent.

finalize_kv_cache(noisy_latent: Tensor, timestep: Tensor, cache: TransformerCacheT, input: Any = None) None[source]

Advance the AR cache so it is ready for the next AR step.

Called by DiffusionModel.finalize after the denoising loop; the flow is discarded — only the cache side effect matters. Default runs a single predict_flow forward. Override for transformers with multiple parallel networks that must stay in lock-step.

Parameters:
  • noisy_latent – Patchified latent at the AR-step’s context noise (or the clean latent when context_noise == 0).

  • timestep – 0-d context-noise timestep tensor.

  • cache – Per-rollout AR cache.

  • input – Same patchified encoder output passed to predict_flow.

initialize_autoregressive_cache(**context: Any) TransformerCacheT[source]

Build a fresh AR cache for a new rollout.

Default returns an empty cache, correct for non-AR transformers. Subclasses with custom cache types must override this and may declare typed per-rollout context (e.g. text_embeddings).

Parameters:

context – Per-rollout state forwarded as keyword arguments.

Returns:

Fresh AR cache.

postprocess_clean_latent(clean_latent: Tensor, cache: TransformerCacheT, input: Any = None) Tensor[source]

Optional postprocessing hook for the predicted clean latent.

Default is identity. Override to clamp or re-inject regions whose clean value is known a priori (e.g. I2V first-frame pinning). Called at the end of DiffusionModel.generate.

Parameters:
  • clean_latent – Patchified x0 from the denoising loop.

  • cache – Per-rollout AR cache.

  • input – Same patchified encoder output passed to predict_flow.

Returns:

Postprocessed clean latent with the same shape.

abstractmethod patchify_and_maybe_split_cp(x: Any) Any[source]

Patchify and (optionally) CP-split a noisy latent or encoder payload.

Tensors patchify and split. Structured payloads (e.g. an image-control struct with latent + mask) patchify each tensor field and return the same struct type. Implement as identity when neither token packing nor CP sharding applies. Output preserves the input Python type — only shapes change.

abstractmethod unpatchify_and_maybe_gather_cp(x: Tensor) Tensor[source]

Inverse of patchify_and_maybe_split_cp for the network output.

class TransformerAutoregressiveCache[source]

Bases: object

Cache that persists across an AR rollout.

Empty by default; safe to instantiate directly for non-AR transformers (default start / finalize are no-ops). Subclass and add fields plus AR bookkeeping for real per-rollout state.

Example:

cache.start(autoregressive_index)
# one or more denoising steps...
cache.finalize(autoregressive_index)

start(autoregressive_index: int) None[source]

Mark the start of an AR step. Default is a no-op.

Parameters:

autoregressive_index – Index of the AR step being started.

finalize(autoregressive_index: int) None[source]

Finalize bookkeeping after use at this AR step. Default is a no-op.

Parameters:

autoregressive_index – Index of the AR step just finalized.
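A subclass that adds fields plus AR bookkeeping might look like the following sketch. `SketchAutoregressiveCache` stands in for TransformerAutoregressiveCache, and `ToyKVCache` with its fields is hypothetical; a real cache would hold tensors rather than strings.

```python
from dataclasses import dataclass, field
from typing import List, Optional


class SketchAutoregressiveCache:
    """Stand-in for the base cache: both hooks default to no-ops."""

    def start(self, autoregressive_index: int) -> None:
        pass

    def finalize(self, autoregressive_index: int) -> None:
        pass


@dataclass
class ToyKVCache(SketchAutoregressiveCache):
    keys: List[str] = field(default_factory=list)
    active_index: Optional[int] = None
    committed_steps: int = 0

    def start(self, autoregressive_index: int) -> None:
        # Steps must arrive in order: 0, 1, 2, ...
        assert autoregressive_index == self.committed_steps
        self.active_index = autoregressive_index

    def finalize(self, autoregressive_index: int) -> None:
        # Commit this step's entry so the next AR step can attend to it.
        assert autoregressive_index == self.active_index
        self.keys.append(f"kv-for-step-{autoregressive_index}")
        self.committed_steps += 1
        self.active_index = None


kv_cache = ToyKVCache()
for step in range(2):
    kv_cache.start(step)
    # one or more denoising steps would run here
    kv_cache.finalize(step)
```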

Schedulers

A scheduler owns the entire denoising loop. It is shape-agnostic: every internal op is a broadcast against per-step scalar sigmas, so the same scheduler works for any latent layout.

class Scheduler(config: SchedulerConfig)[source]

Bases: Module, ABC

Denoising scheduler.

Owns the entire denoising loop. Callers see only noise → clean; the loop shape (renoise / multistep / plain ODE) is private.

Concrete configs inherit SchedulerConfig and declare their own num_inference_steps / shift fields (the base holds no shared dataclass fields).

Examples

scheduler = config.setup()
clean = scheduler.sample(initial_noise=noise, predict_flow=predictor)
noisy = scheduler.add_noise(clean_input=clean, timestep=t)

abstractmethod sample(initial_noise: Tensor, predict_flow: FlowPredictor, rng: Generator | None = None) Tensor[source]

Run the full denoising loop and return the clean latent.

Schedulers are shape-agnostic: every internal op broadcasts against per-step scalar sigmas. In practice initial_noise is a video latent [B, C, T, H, W], conventionally treated as a sample at sigma=1.

Parameters:
  • initial_noise – Gaussian noise on the caller’s device/dtype.

  • predict_flow – Per-step closure invoked num_inference_steps times.

  • rng – Generator on the same device. Used by self-forcing renoise loops; pure ODE solvers ignore it.

Returns:

Clean latent with the same shape, device, and dtype as initial_noise.

abstractmethod add_noise(clean_input: Tensor, timestep: Tensor, rng: Generator | None = None) Tensor[source]

Apply the forward corruption x_t = (1 - sigma(t)) * x_0 + sigma(t) * eps.

Timestep value semantics are scheduler-specific: each scheduler snaps timestep to the nearest entry of its own table (FlowMatchScheduler uses the warped training table, FlowMatchUniPCScheduler its inference schedule).

class FlowPredictor(*args, **kwargs)[source]

Bases: Protocol

Closure (noisy_latent, timestep) -> predicted_flow.

Built by DiffusionModel.generate by binding the per-AR-step cache / input to the transformer’s predict_flow. A scheduler invokes it once per denoising iteration. The scheduler decides the timestep dtype (UniPC uses int64, flow-match uses float).
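The binding step can be sketched with functools.partial. `FlowPredictorSketch` and `toy_predict_flow` are hypothetical; scalars stand in for tensors and a dict stands in for the AR cache, so this only illustrates how a four-argument predict_flow becomes a two-argument closure.

```python
from functools import partial
from typing import Any, Dict, Optional, Protocol


class FlowPredictorSketch(Protocol):
    """Two-argument callable the scheduler sees."""

    def __call__(self, noisy_latent: float, timestep: float) -> float: ...


def toy_predict_flow(
    noisy_latent: float,
    timestep: float,
    cache: Dict[str, Any],
    input: Optional[float] = None,
) -> float:
    """Toy network forward: counts calls in the cache and adds the input as a bias."""
    cache["calls"] += 1
    bias = 0.0 if input is None else input
    return noisy_latent * timestep + bias


ar_cache: Dict[str, Any] = {"calls": 0}
# Bind the per-AR-step cache and input once; the scheduler never sees them.
predictor: FlowPredictorSketch = partial(toy_predict_flow, cache=ar_cache, input=0.5)
flow = predictor(2.0, 0.25)
```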

class FlowMatchSchedulerConfig(*, _target: type[FlowMatchScheduler] = <factory>, num_inference_steps: int = 4, shift: float = 8.0, denoising_timesteps: list[int] = <factory>, warp_denoising_step: bool = True, num_train_timesteps: int = 1000, sigma_max: float = 1.0, sigma_min: float = 0.0, extra_one_step: bool = True, timestep_dtype: dtype = torch.float32, enable_tqdm: bool = False)[source]

Bases: SchedulerConfig

Config for the flow-matching scheduler.

num_inference_steps: int = 4

Must equal len(denoising_timesteps).

shift: float = 8.0

Schedule warp factor.

denoising_timesteps: list[int]

Per-step diffusion timesteps in [0, num_train_timesteps].

warp_denoising_step: bool = True

Map denoising_timesteps through the warped sigma schedule.

num_train_timesteps: int = 1000

Length of the training sigma table.

sigma_max: float = 1.0

Top of the linspace before warping; 1.0 matches DiffSynth, upstream Wan / Lingbot ships 0.999.

sigma_min: float = 0.0

Bottom of the linspace before warping. Reserved for upstream parity; only 0.0 is exercised.

extra_one_step: bool = True

If True, build the schedule from linspace(sigma_max, sigma_min, N+1)[:-1] (matches DiffSynth / upstream Wan); False uses N points and is kept for non-Wan recipes.

timestep_dtype: dtype = torch.float32

Dtype of denoising_step_list. Set to an integer dtype (e.g. torch.int64) when the network’s time embedding is sensitive to the fractional part of the warped timestep — upstream Wan stores scheduler.timesteps as int64 and lets the embedding upcast to float64 internally.

enable_tqdm: bool = False

Whether to show a tqdm progress bar.
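How sigma_max, sigma_min, extra_one_step, and shift could combine into a sigma table is sketched here. The warp formula is an assumption (the time shift commonly used in flow-matching samplers); the source does not spell it out. `build_sigmas`, `warp`, and `num_points` are hypothetical names.

```python
from typing import List


def linspace(start: float, stop: float, num: int) -> List[float]:
    """Evenly spaced points from start to stop, both inclusive."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def warp(sigma: float, shift: float) -> float:
    # Assumed warp: shift * sigma / (1 + (shift - 1) * sigma).
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)


def build_sigmas(
    num_points: int,
    shift: float,
    sigma_max: float = 1.0,
    sigma_min: float = 0.0,
    extra_one_step: bool = True,
) -> List[float]:
    if extra_one_step:
        # N + 1 points with the last one dropped, so sigma_min itself is excluded.
        raw = linspace(sigma_max, sigma_min, num_points + 1)[:-1]
    else:
        raw = linspace(sigma_max, sigma_min, num_points)
    return [warp(sigma, shift) for sigma in raw]


sigmas = build_sigmas(num_points=4, shift=8.0)
sigmas_plain = build_sigmas(num_points=4, shift=8.0, extra_one_step=False)
```

With shift greater than 1 the warped table stays close to 1 for longer, so more of the schedule is spent at high noise levels.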

class FlowMatchScheduler(config: FlowMatchSchedulerConfig)[source]

Bases: Scheduler

Flow-matching scheduler with self-forcing renoise (DiffSynth-style).

Each iteration converts the predicted flow to an x0 estimate, then re-noises at the same sigma to feed the next iteration. The final x0 is returned:

x_t = initial_noise
for t in denoising_step_list:
    v = predict_flow(x_t, t)
    x0 = x_t - sigma(t) * v
    x_t = (1 - sigma(t)) * x0 + sigma(t) * eps
return x0

Example:

scheduler = FlowMatchSchedulerConfig(
    num_inference_steps=4,
    shift=8.0,
    denoising_timesteps=[1000, 750, 500, 250],
).setup().to("cuda")
clean = scheduler.sample(initial_noise=noise, predict_flow=fn)

Schedule buffers are pinned to fp32 even after module.to(bf16); integer timesteps like 1000 would otherwise round to 1024.

sample(initial_noise: Tensor, predict_flow: FlowPredictor, rng: Generator | None = None) Tensor[source]

Run the self-forcing flow-match denoising loop.

Iteration 0 trusts initial_noise as the sigma=1 sample; later iterations re-noise the previous x0 estimate to the new sigma before the network forward. Schedule arithmetic auto-promotes to fp32; the result is cast back to initial_noise.dtype.
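A scalar version of this loop, following the description above (re-noise to the new sigma before each forward after the first), is sketched here. `sample_sketch` and `oracle_flow` are hypothetical; the sketch passes sigma to the predictor in place of a timestep and uses Python's random module instead of a torch Generator.

```python
import random
from typing import Callable, Sequence


def sample_sketch(
    initial_noise: float,
    predict_flow: Callable[[float, float], float],
    sigmas: Sequence[float],
    rng: random.Random,
) -> float:
    """Self-forcing loop on a scalar latent; returns the final x0 estimate."""
    x0 = initial_noise
    for i, sigma in enumerate(sigmas):
        if i == 0:
            # Iteration 0 trusts the initial noise as the sigma = 1 sample.
            x_t = initial_noise
        else:
            # Re-noise the previous x0 estimate to the new sigma.
            eps = rng.gauss(0.0, 1.0)
            x_t = (1.0 - sigma) * x0 + sigma * eps
        v = predict_flow(x_t, sigma)
        x0 = x_t - sigma * v  # convert the flow into an x0 estimate
    return x0


TARGET = 3.0


def oracle_flow(x_t: float, sigma: float) -> float:
    # For x_t = (1 - sigma) * x0 + sigma * eps, the flow eps - x0
    # equals (x_t - x0) / sigma, so a perfect network can be faked.
    return (x_t - TARGET) / sigma


rng = random.Random(0)
clean = sample_sketch(rng.gauss(0.0, 1.0), oracle_flow, [1.0, 0.96, 0.8, 0.5], rng)
```

With a perfect flow predictor every iteration recovers the target exactly, which makes the oracle a convenient sanity check for the loop arithmetic.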

add_noise(clean_input: Tensor, timestep: Tensor, rng: Generator | None = None) Tensor[source]

Apply the forward corruption at an arbitrary timestep.

Snaps timestep to the nearest entry of the warped training table and uses it as sigma in the standard lerp.
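Nearest-entry snapping followed by the lerp can be sketched on plain lists. The table values are made up for illustration, and `snap_to_table` and `add_noise_sketch` are hypothetical; the real scheduler operates on tensors and draws eps itself.

```python
from typing import Sequence


def snap_to_table(timestep: float, timesteps: Sequence[float]) -> int:
    """Index of the table entry closest to timestep."""
    return min(range(len(timesteps)), key=lambda i: abs(timesteps[i] - timestep))


def add_noise_sketch(
    clean_input: float,
    timestep: float,
    eps: float,
    timesteps: Sequence[float],
    sigmas: Sequence[float],
) -> float:
    """x_t = (1 - sigma) * x_0 + sigma * eps, with sigma taken from the snapped entry."""
    sigma = sigmas[snap_to_table(timestep, timesteps)]
    return (1.0 - sigma) * clean_input + sigma * eps


# Illustrative table only: timesteps and their (unwarped) sigmas.
table_timesteps = [1000.0, 750.0, 500.0, 250.0]
table_sigmas = [1.0, 0.75, 0.5, 0.25]

noisy = add_noise_sketch(
    clean_input=2.0,
    timestep=490.0,  # snaps to the 500.0 entry
    eps=-1.0,
    timesteps=table_timesteps,
    sigmas=table_sigmas,
)
```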

class FlowMatchUniPCSchedulerConfig(*, _target: type[FlowMatchUniPCScheduler] = <factory>, num_inference_steps: int = 50, shift: float = 5.0, num_train_timesteps: int = 1000, solver_order: int = 2, use_kerras_sigma: bool = False, enable_tqdm: bool = False)[source]

Bases: SchedulerConfig

Config for the flow-matching UniPC scheduler.

Defaults match the official Wan 2.1 inference integration (UniPC, BH2, order 2, shift 5.0). Override shift per checkpoint as recommended upstream (e.g. 3.0 for Wan 2.1 14B I2V 480P).

num_inference_steps: int = 50

Number of UniPC denoising steps.

shift: float = 5.0

Schedule warp factor.

num_train_timesteps: int = 1000

Length of the training sigma table.

solver_order: int = 2

UniPC solver order; only 2 is supported.

use_kerras_sigma: bool = False

Whether to use the exact sigmas of the EDM sampler.

enable_tqdm: bool = False

Whether to show a tqdm progress bar.

class FlowMatchUniPCScheduler(config: FlowMatchUniPCSchedulerConfig)[source]

Bases: Scheduler

Order-2 UniPC predictor-corrector for flow-matching.

Specialized + pre-baked variant of the upstream Wan 2.1 UniPC solver. Schedule buffers (sigmas + per-step coefficients) stay fp32 regardless of module.to(dtype).

Example:

scheduler = FlowMatchUniPCSchedulerConfig(
    num_inference_steps=50,
    shift=5.0,
).setup().to("cuda")
clean = scheduler.sample(initial_noise=noise, predict_flow=fn)

sample(initial_noise: Tensor, predict_flow: FlowPredictor, rng: Generator | None = None) Tensor[source]

Run the order-2 UniPC predictor-corrector denoising loop.

Each iteration: network → flow → x0 → corrector (skipped at step 0) → predictor. All per-step coefficients are pre-baked at construction; the loop is pure tensor ops. Internal arithmetic is fp32; the result is cast back to initial_noise.dtype. rng is unused (deterministic ODE) but accepted for interface conformance.

add_noise(clean_input: Tensor, timestep: Tensor, rng: Generator | None = None) Tensor[source]

Apply the forward corruption at an arbitrary timestep.

Snaps timestep to the nearest entry of the inference schedule on-device (no Python sync) and uses the matching sigma in the lerp.

Encoder

Encoders turn raw conditioning (text prompts, reference images, per-AR-step control inputs, …) into latent tensors. There are two flavours, plus a video-specific extension of the second:

  • Encoder is stateless and one-shot. forward(self, input). Used as transformer.context_encoder for text / CLIP-image / identity.

  • StreamingEncoder is stateful and per-AR-step. forward(self, input, autoregressive_index, cache) with a StreamingEncoderCache. Used as pipeline.encoder for per-step control (HDMap, camera trajectory, I2V first-frame VAE).

  • StreamingVideoEncoder extends StreamingEncoder with the contracts a streaming pixel-video encoder always needs: spatial / temporal compression ratios plus AR-step-aware temporal size mappers between pixel and latent space.

class Encoder(config: EncoderConfig)[source]

Bases: ABC, Module

Stateless encoder.

forward is not pinned by the base. Encoders used as a context_encoder (one-shot, called once inside Transformer.initialize_autoregressive_cache()) must match the slim call shape forward(self, input).

For per-AR-step encoders that need a per-rollout cache, inherit from StreamingEncoder instead.

class StreamingEncoder(config: EncoderConfig)[source]

Bases: ABC, Module, Generic[StreamingEncoderCacheT]

Streaming encoder, generic over the per-rollout cache type.

forward is not pinned by the base. Streaming encoders called by StreamInferencePipeline must match its call shape: forward(self, input, autoregressive_index=0, cache=None).

abstractmethod initialize_autoregressive_cache(**context: Any) StreamingEncoderCacheT[source]

Build a fresh per-rollout cache.

Override to return the encoder’s concrete cache type.

class StreamingVideoEncoder(config: EncoderConfig)[source]

Bases: StreamingEncoder[StreamingEncoderCacheT]

Streaming pixel-video encoder.

Pins down the contracts that every streaming pixel→latent video encoder satisfies in addition to StreamingEncoder:

  • Spatial and temporal compression ratios between the pixel and latent grids (constants of the architecture).

  • AR-step-aware temporal size mappers, so a pipeline can size its inputs and outputs without knowing the encoder’s concrete temporal cache topology (causal first-frame padding, sliding windows, etc.).

Spatial scaling is trivially side // spatial_compression_ratio in either direction; the AR-step-asymmetric piece is the temporal size, which gets its own mapper. Typically AR 0 takes 1 + (T_lat - 1) * r pixel frames to produce T_lat latent frames because of causal first-frame padding, while AR ≥ 1 takes T_lat * r pixel frames.

abstract property spatial_compression_ratio: int

Pixel side ÷ latent side. Constant across AR steps.

abstract property temporal_compression_ratio: int

Pixel frames ÷ latent frames in steady state (AR ≥ 1).

AR 0 typically takes one extra (un-grouped) pixel frame to produce its first latent frame because of causal first-frame padding; that asymmetry lives inside get_output_temporal_size() / get_input_temporal_size().

abstractmethod get_output_temporal_size(autoregressive_index: int, input_temporal_size: int) int[source]

Latent frame count produced from input_temporal_size pixel frames.

Parameters:
  • autoregressive_index – AR step index (0-based).

  • input_temporal_size – Number of pixel frames fed at this step.

Returns:

Number of latent frames emitted at this step.

abstractmethod get_input_temporal_size(autoregressive_index: int, output_temporal_size: int) int[source]

Pixel frame count needed to produce output_temporal_size latents.

Inverse of get_output_temporal_size(). Implementations should assert output_temporal_size is achievable at this AR step (i.e. the corresponding pixel count comes out as a positive integer).

Parameters:
  • autoregressive_index – AR step index (0-based).

  • output_temporal_size – Desired number of latent frames.

Returns:

Number of pixel frames needed at this step.
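The two temporal mappers can be sketched for the typical causal layout described above, where AR 0 takes 1 + (T_lat - 1) * r pixel frames and later steps take T_lat * r. `ToyCausalVideoEncoderSizes` is hypothetical, and the ratio of 4 is only an example value.

```python
class ToyCausalVideoEncoderSizes:
    """Hypothetical size mappers for a causal encoder with temporal ratio r."""

    def __init__(self, temporal_compression_ratio: int = 4) -> None:
        self.r = temporal_compression_ratio

    def get_output_temporal_size(
        self, autoregressive_index: int, input_temporal_size: int
    ) -> int:
        """Latent frames produced from input_temporal_size pixel frames."""
        if autoregressive_index == 0:
            # First step: one un-grouped frame, then groups of r.
            assert input_temporal_size >= 1
            assert (input_temporal_size - 1) % self.r == 0
            return 1 + (input_temporal_size - 1) // self.r
        assert input_temporal_size >= self.r
        assert input_temporal_size % self.r == 0
        return input_temporal_size // self.r

    def get_input_temporal_size(
        self, autoregressive_index: int, output_temporal_size: int
    ) -> int:
        """Pixel frames needed to produce output_temporal_size latent frames."""
        assert output_temporal_size >= 1
        if autoregressive_index == 0:
            return 1 + (output_temporal_size - 1) * self.r
        return output_temporal_size * self.r


sizes = ToyCausalVideoEncoderSizes(temporal_compression_ratio=4)
first_step_pixels = sizes.get_input_temporal_size(0, 4)
later_step_pixels = sizes.get_input_temporal_size(1, 4)
```

In this sketch the two mappers are exact inverses at every AR step, which is the property a pipeline relies on when sizing its inputs.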

class StreamingEncoderCache[source]

Bases: object

Per-rollout cache for StreamingEncoder.

Empty by default; subclass to add fields (e.g. last-frame latent, cross-step accumulators).
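A cache subclass carrying a last-frame field, together with an encoder that matches the pipeline's call shape, could look like this sketch. All class names are hypothetical stand-ins, floats stand in for frames, and the real classes are torch modules.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


class SketchStreamingEncoderCache:
    """Stand-in for StreamingEncoderCache (empty by default)."""


@dataclass
class LastFrameCache(SketchStreamingEncoderCache):
    last_frame: Optional[float] = None


class ToyDeltaEncoder:
    """Hypothetical encoder emitting frame-to-frame differences across AR steps."""

    def initialize_autoregressive_cache(self, **context: Any) -> LastFrameCache:
        return LastFrameCache(last_frame=context.get("first_frame"))

    def forward(
        self,
        input: List[float],
        autoregressive_index: int = 0,
        cache: Optional[LastFrameCache] = None,
    ) -> List[float]:
        assert cache is not None, "this toy encoder needs its per-rollout cache"
        # On the very first frame there is no history, so its delta is zero.
        previous = input[0] if cache.last_frame is None else cache.last_frame
        deltas = []
        for frame in input:
            deltas.append(frame - previous)
            previous = frame
        cache.last_frame = previous  # carried into the next AR step
        return deltas


encoder = ToyDeltaEncoder()
encoder_cache = encoder.initialize_autoregressive_cache()
first = encoder.forward([1.0, 2.0, 4.0], autoregressive_index=0, cache=encoder_cache)
second = encoder.forward([5.0, 5.0], autoregressive_index=1, cache=encoder_cache)
```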

class NullEncoderConfig(*, _target: type[NullEncoder] = <factory>)[source]

Bases: EncoderConfig

Config for the identity encoder.

class NullEncoder(config: EncoderConfig)[source]

Bases: Encoder

Identity encoder: returns its input unchanged.

Wire as the transformer’s context_encoder slot to pass already-encoded tensors straight to the diffusion model.

Example:

config = TransformerConfig(
    context_encoder=NullEncoderConfig(),
    ...,
)

forward(input: Any) Any[source]

Return input unchanged.

Decoder

Decoders turn the latents emitted by the diffusion model back into pixel frames. There is a single base class plus one video-specific specialisation:

  • StreamingDecoder is stateful. forward(self, input, autoregressive_index, cache) with a StreamingDecoderCache. Use for chunk-by-chunk streaming decoders (e.g. WAN VAE that maintains a temporal cache across AR steps); stateless decoders just return an empty StreamingDecoderCache from StreamingDecoder.initialize_autoregressive_cache() and ignore autoregressive_index / cache in forward.

  • StreamingVideoDecoder extends StreamingDecoder with the contracts a streaming pixel-video decoder always needs: spatial / temporal compression ratios plus AR-step-aware temporal size mappers between latent and pixel space.

class StreamingDecoder(config: DecoderConfig)[source]

Bases: ABC, Module, Generic[StreamingDecoderCacheT]

Streaming decoder, generic over the per-rollout cache type.

forward is not pinned by the base. Streaming decoders called by StreamInferencePipeline must match its call shape: forward(self, input, autoregressive_index=0, cache=None).

abstractmethod initialize_autoregressive_cache(**context: Any) StreamingDecoderCacheT[source]

Build a fresh per-rollout cache.

Override to return the decoder’s concrete cache type.

class StreamingVideoDecoder(config: DecoderConfig)[source]

Bases: StreamingDecoder[StreamingDecoderCacheT]

Streaming pixel-video decoder.

Pins down the contracts that every streaming latent→pixel video decoder satisfies in addition to StreamingDecoder:

  • Spatial and temporal compression ratios between the latent and pixel grids (constants of the architecture).

  • AR-step-aware temporal size mappers, so a pipeline can size its inputs and outputs without knowing the decoder’s concrete temporal cache topology (causal first-frame padding, sliding windows, etc.).

Spatial scaling is trivially side * spatial_compression_ratio in either direction; the AR-step-asymmetric piece is the temporal size, which gets its own mapper. Typically AR 0 produces fewer pixel frames per latent frame than AR ≥ 1 because of causal first-frame padding.

abstract property spatial_compression_ratio: int

Pixel side ÷ latent side. Constant across AR steps.

abstract property temporal_compression_ratio: int

Pixel frames ÷ latent frames in steady state (AR ≥ 1).

AR 0 typically yields fewer pixel frames per latent frame because of causal first-frame padding; that asymmetry lives inside get_output_temporal_size() / get_input_temporal_size().

abstractmethod get_output_temporal_size(autoregressive_index: int, input_temporal_size: int) int[source]

Pixel frame count produced by input_temporal_size latent frames.

Parameters:
  • autoregressive_index – AR step index (0-based).

  • input_temporal_size – Number of latent frames fed at this step.

Returns:

Number of pixel frames emitted at this step.

abstractmethod get_input_temporal_size(autoregressive_index: int, output_temporal_size: int) int[source]

Latent frame count needed to produce output_temporal_size pixels.

Inverse of get_output_temporal_size(). Implementations should assert output_temporal_size is achievable at this AR step (i.e. divisible by the right ratio after subtracting any causal padding).

Parameters:
  • autoregressive_index – AR step index (0-based).

  • output_temporal_size – Desired number of pixel frames.

Returns:

Number of latent frames needed at this step.

class StreamingDecoderCache[source]

Bases: object

Per-rollout cache for StreamingDecoder.

Empty by default; subclass to add fields (e.g. temporal feature buffers carried across AR steps).