fastgen
Modules
Pydantic configuration classes for the fastgen distillation pipelines.
Discriminator modules for the DMD2 GAN branch.
Exponential moving average of a student network (FSDP2 DTensor-aware).
Convenience factory helpers for constructing the auxiliary DMD networks.
Rectified-flow (RF) helpers: forward process, inversions, timestep sampling.
YAML-driven configuration loading for fastgen distillation pipelines.
Pure loss functions used by the fastgen distillation pipelines.
Concrete distillation method implementations (DMD, future: Self-Forcing, CausVid, ...).
Base class for diffusion step-distillation pipelines.
Optional plugins for the fastgen subpackage (gated via
Small tensor helpers shared across the fastgen subpackage.
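The rectified-flow helpers are only named in the module list. Assuming the common RF convention in which t = 0 is data and t = 1 is pure noise (fastgen's own convention is not stated here), the forward process, the constant velocity target, the x0 inversion, and logit-normal timestep sampling can be sketched in plain Python. The function names are hypothetical, and lists of floats stand in for tensors.

```python
import math
import random
from typing import List, Sequence


def rf_forward(x0: Sequence[float], noise: Sequence[float], t: float) -> List[float]:
    # x_t = (1 - t) * x0 + t * noise: a straight line from data (t=0) to noise (t=1).
    return [(1.0 - t) * a + t * e for a, e in zip(x0, noise)]


def rf_velocity_target(x0: Sequence[float], noise: Sequence[float]) -> List[float]:
    # d x_t / d t = noise - x0, constant along the path; a common RF regression target.
    return [e - a for a, e in zip(x0, noise)]


def rf_x0_from_velocity(x_t: Sequence[float], v: Sequence[float], t: float) -> List[float]:
    # Inversion: x0 = x_t - t * v, exact when v is the true velocity.
    return [xt - t * vi for xt, vi in zip(x_t, v)]


def sample_logit_normal_t(rng: random.Random, mean: float = 0.0, std: float = 1.0) -> float:
    # Logit-normal sampling: t = sigmoid(u), u ~ N(mean, std^2).
    # With mean 0 this concentrates t around 0.5.
    u = rng.gauss(mean, std)
    return 1.0 / (1.0 + math.exp(-u))
```

With x0 = [1.0, -2.0], noise [0.5, 0.5] and t = 0.25, rf_forward returns [0.875, -1.375], and rf_x0_from_velocity recovers [1.0, -2.0] because x_t - t * (noise - x0) = x0 on this linear path.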
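For the DMD2 GAN branch, a common choice is the non-saturating logistic GAN objective on discriminator logits. The sketch below assumes that formulation; it is not confirmed to be what fastgen's discriminator loss computes, and the function names are illustrative only. Lists of floats stand in for logit tensors, and softplus is written in its numerically stable form.

```python
import math
from typing import Sequence


def softplus(x: float) -> float:
    # Stable log(1 + exp(x)): avoids overflow of exp(x) for large positive x.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def gan_generator_loss(fake_logits: Sequence[float]) -> float:
    # Non-saturating generator loss: mean softplus(-D(fake)).
    # Minimized by pushing the discriminator's logits on generated samples up.
    return sum(softplus(-z) for z in fake_logits) / len(fake_logits)


def gan_discriminator_loss(real_logits: Sequence[float], fake_logits: Sequence[float]) -> float:
    # Logistic discriminator loss: mean softplus(-D(real)) + mean softplus(D(fake)).
    real_term = sum(softplus(-z) for z in real_logits) / len(real_logits)
    fake_term = sum(softplus(z) for z in fake_logits) / len(fake_logits)
    return real_term + fake_term
```

At zero logits (a discriminator that cannot tell real from fake) the generator loss is log 2 and the discriminator loss is 2 log 2.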
Framework-agnostic diffusion step-distillation losses (FastGen port).
modelopt.torch.fastgen is a loss-computation library. It accepts already-built
nn.Module references (student / teacher / fake-score / optional discriminator) and
returns scalar loss tensors. It does not load models, manage optimizers, wrap
anything as a DynamicModule, or register itself in any mode registry.
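Because the library's contract is "modules in, scalar losses out", the core distribution-matching term can be illustrated without any model code. The sketch follows the DMD/DMD2 surrogate as it appears in the public reference code, to the best of our understanding: the update direction is the fake-score x0 prediction minus the teacher's x0 prediction, normalized by the mean absolute gap between the student sample and the teacher prediction, and the loss is 0.5 * MSE between the sample and a stop-gradient target. The eps term, the function name, and the plain-list representation are assumptions; fastgen's actual weighting may differ.

```python
from typing import List, Sequence, Tuple


def dmd_loss(
    x: Sequence[float],
    real_x0: Sequence[float],
    fake_x0: Sequence[float],
    eps: float = 1e-8,  # assumed guard against division by zero
) -> Tuple[float, List[float]]:
    """Return (loss value, gradient of the loss w.r.t. x) for a DMD-style surrogate."""
    n = len(x)
    # Normalizer: mean absolute distance between the student sample and the teacher's x0.
    weight = sum(abs(xi - ri) for xi, ri in zip(x, real_x0)) / n + eps
    # Distribution-matching direction: fake-score prediction minus teacher prediction.
    grad = [(fi - ri) / weight for fi, ri in zip(fake_x0, real_x0)]
    # Target x - grad is a constant (the .detach() in an autograd framework).
    target = [xi - gi for xi, gi in zip(x, grad)]
    loss = 0.5 * sum((xi - ti) ** 2 for xi, ti in zip(x, target)) / n
    # With the target held fixed, d(loss)/dx = (x - target) / n = grad / n.
    grad_wrt_x = [gi / n for gi in grad]
    return loss, grad_wrt_x
```

Since the target is treated as a constant, the loss value itself is not meaningful; what matters is that a gradient step on x moves it along (teacher prediction - fake-score prediction).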
Typical usage with a YAML-driven config:
import modelopt.torch.fastgen as mtf

# Model construction is framework-owned; this builder is user-provided, not part of mtf.
student, teacher = build_wan_student_and_teacher(...)
fake_score = mtf.create_fake_score(teacher)
cfg = mtf.load_dmd_config("general/distillation/dmd2_qwen_image")

# Optional discriminator (a user-built nn.Module); None disables the GAN branch.
disc = None
if cfg.gan_loss_weight_gen > 0:
    disc = build_discriminator(...)  # hypothetical, user-provided
    # Expose intermediate teacher features to the discriminator.
    mtf.plugins.qwen_image.attach_feature_capture(teacher, feature_indices=[30])

pipeline = mtf.DMDPipeline(student, teacher, fake_score, cfg, discriminator=disc)

# Inside the training loop (framework-owned):
if step % cfg.student_update_freq == 0:
    student_opt.zero_grad()
    losses = pipeline.compute_student_loss(
        latents, noise, text_embeds, negative_encoder_hidden_states=neg_embeds
    )
    losses["total"].backward()
    student_opt.step()
    pipeline.update_ema()
else:
    fake_score_opt.zero_grad()
    f = pipeline.compute_fake_score_loss(latents, noise, text_embeds)
    f["total"].backward()
    fake_score_opt.step()
    if disc is not None:
        disc_opt.zero_grad()
        d = pipeline.compute_discriminator_loss(latents, noise, text_embeds)
        d["total"].backward()
        disc_opt.step()
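The `step % cfg.student_update_freq == 0` test alternates between a student step and fake-score (plus optional discriminator) steps. A tiny hypothetical helper, not part of mtf, makes the resulting schedule explicit:

```python
from typing import List


def update_schedule(num_steps: int, student_update_freq: int) -> List[str]:
    # Mirrors the branch in the training loop: which network each step updates.
    if student_update_freq < 1:
        raise ValueError("student_update_freq must be >= 1")
    return [
        "student" if step % student_update_freq == 0 else "fake_score"
        for step in range(num_steps)
    ]
```

With student_update_freq = 5, the student is updated on steps 0, 5, 10, ... and the fake score on the four steps in between. A value of 1 sends every step down the student branch, so under this loop the fake score would never be trained.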
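pipeline.update_ema() maintains an exponential moving average of the student weights. The update rule is ema <- decay * ema + (1 - decay) * param; the sketch below applies it to a plain dict of floats and deliberately omits the FSDP2 DTensor handling that fastgen's EMA module is described as providing. The function name and the dict representation are hypothetical.

```python
from typing import Dict


def ema_update(ema_params: Dict[str, float], student_params: Dict[str, float], decay: float) -> None:
    # In-place EMA step: ema <- decay * ema + (1 - decay) * student.
    # decay close to 1 makes the average change slowly.
    if not 0.0 <= decay <= 1.0:
        raise ValueError("decay must be in [0, 1]")
    for name, value in student_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
```

In PyTorch the same per-parameter rule is commonly written as `ema_p.lerp_(p, 1 - decay)` under `torch.no_grad()`; how fastgen chooses `decay` is not specified here.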