fastgen

Modules

modelopt.torch.fastgen.config

Pydantic configuration classes for the fastgen distillation pipelines.

modelopt.torch.fastgen.discriminators

Discriminator modules for the DMD2 GAN branch.

modelopt.torch.fastgen.ema

Exponential moving average of a student network, FSDP2 DTensor aware.

modelopt.torch.fastgen.factory

Convenience factory helpers for constructing the auxiliary DMD networks.

modelopt.torch.fastgen.flow_matching

Rectified-flow (RF) helpers: forward process, inversions, timestep sampling.

modelopt.torch.fastgen.loader

YAML-driven configuration loading for fastgen distillation pipelines.

modelopt.torch.fastgen.losses

Pure loss functions used by the fastgen distillation pipelines.

modelopt.torch.fastgen.methods

Concrete distillation method implementations (currently DMD; Self-Forcing, CausVid, and others are planned).

modelopt.torch.fastgen.pipeline

Base class for diffusion step-distillation pipelines.

modelopt.torch.fastgen.plugins

Optional plugins for the fastgen subpackage (gated via import_plugin).

modelopt.torch.fastgen.utils

Small tensor helpers shared across the fastgen subpackage.
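
As a rough illustration of the rectified-flow forward process that the flow_matching entry above refers to, the sketch below interpolates linearly between a clean sample and noise. It assumes the common convention that t = 0 is data and t = 1 is pure noise; the function names rf_forward and rf_velocity_target are hypothetical and are not taken from modelopt.torch.fastgen, whose actual convention and signatures may differ.

```python
from __future__ import annotations


def rf_forward(x0: list[float], noise: list[float], t: float) -> list[float]:
    """Noisy sample x_t = (1 - t) * x0 + t * noise (assumed convention)."""
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")
    return [(1.0 - t) * a + t * n for a, n in zip(x0, noise)]


def rf_velocity_target(x0: list[float], noise: list[float]) -> list[float]:
    """Velocity d(x_t)/dt = noise - x0, constant along the straight path."""
    return [n - a for a, n in zip(x0, noise)]


# Toy two-element "latent" and noise vector.
x0 = [1.0, 2.0]
noise = [0.0, 4.0]
x_half = rf_forward(x0, noise, 0.5)       # midpoint between data and noise
velocity = rf_velocity_target(x0, noise)  # regression target under this convention
```

Plain Python lists stand in for tensors here so that the example runs without any dependencies; the same arithmetic applies elementwise to latent tensors.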

Framework-agnostic diffusion step-distillation losses (FastGen port).

modelopt.torch.fastgen is a loss-computation library. It accepts already-built nn.Module references (student, teacher, fake-score, and an optional discriminator) and returns scalar loss tensors. It does not load models, manage optimizers, wrap anything as a DynamicModule, or register itself in any mode registry; those responsibilities stay with the calling training framework.

Typical usage with a YAML-driven config:

import modelopt.torch.fastgen as mtf

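# build_wan_student_and_teacher is a caller-provided helper; fastgen does not load models.
student, teacher = build_wan_student_and_teacher(...)
# Auxiliary fake-score network, constructed from the teacher.
fake_score = mtf.create_fake_score(teacher)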

cfg = mtf.load_dmd_config("general/distillation/dmd2_qwen_image")

# If GAN is enabled, expose intermediate teacher features to the discriminator.
if cfg.gan_loss_weight_gen > 0:
    mtf.plugins.qwen_image.attach_feature_capture(teacher, feature_indices=[30])

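# disc is an optional discriminator nn.Module built by the caller (None when the GAN branch is unused).
pipeline = mtf.DMDPipeline(student, teacher, fake_score, cfg, discriminator=disc)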

# Inside the training loop (framework-owned):
if step % cfg.student_update_freq == 0:
    # Student (generator) update.
    student_opt.zero_grad()
    losses = pipeline.compute_student_loss(
        latents, noise, text_embeds, negative_encoder_hidden_states=neg_embeds
    )
    losses["total"].backward()
    student_opt.step()
    pipeline.update_ema()
else:
    # Fake-score update, followed by the discriminator update when one is present.
    fake_score_opt.zero_grad()
    f = pipeline.compute_fake_score_loss(latents, noise, text_embeds)
    f["total"].backward()
    fake_score_opt.step()
    if disc is not None:
        disc_opt.zero_grad()
        d = pipeline.compute_discriminator_loss(latents, noise, text_embeds)
        d["total"].backward()
        disc_opt.step()
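
The control flow of the loop above can be sketched without any models. The example below mirrors the step % cfg.student_update_freq test to show which network is updated on each step, and applies the standard exponential-moving-average rule as a stand-in for what pipeline.update_ema() is assumed to do. The helper names branch_for_step and ema_update and the decay value are hypothetical; the real EMA implementation in modelopt.torch.fastgen.ema is not shown here and may differ (for example in warm-up or sharding behaviour).

```python
from __future__ import annotations


def branch_for_step(step: int, student_update_freq: int) -> str:
    """Mirror the `step % cfg.student_update_freq == 0` test from the loop."""
    return "student" if step % student_update_freq == 0 else "fake_score"


def ema_update(ema: list[float], params: list[float], decay: float) -> list[float]:
    """Standard EMA rule: ema <- decay * ema + (1 - decay) * params."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema, params)]


# With student_update_freq = 5 the student is updated on steps 0, 5, 10, ...
# and the fake score (plus the discriminator, if any) on every other step.
schedule = [branch_for_step(s, 5) for s in range(10)]
student_steps = [s for s, branch in enumerate(schedule) if branch == "student"]

# Toy EMA: a single scalar "weight" tracked from 0.0 toward a student value of 1.0.
ema = ema_update([0.0], [1.0], decay=0.9)
```

With a frequency of 5, the auxiliary networks receive four updates for every student update, which is why the EMA is refreshed only in the student branch.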