dmd
Distribution Matching Distillation (DMD2) pipeline.
DMDPipeline holds references to the student / teacher / fake-score / (optional)
discriminator and exposes the three loss-computation entry points that a training loop
calls from each update step:
- DMDPipeline.compute_student_loss(): variational score-distillation loss plus an optional GAN generator term.
- DMDPipeline.compute_fake_score_loss(): denoising score matching against the student’s generated samples.
- DMDPipeline.compute_discriminator_loss(): GAN discriminator loss plus an optional R1 regularizer.
The pipeline does not own optimizers, schedulers, gradient toggles, or device placement.
Callers drive the alternation between student / fake-score / discriminator updates, toggle
requires_grad, and call the appropriate compute_*_loss each step.
Math is a close port of FastGen/fastgen/methods/distribution_matching/dmd2.py (lines
45-455). See the docstrings on the individual methods for line-level references.
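The caller-driven alternation described above can be sketched as follows. This is a minimal illustration rather than the library's training loop: `plan_updates`, `student_update_freq`, and `use_discriminator` are hypothetical names, and the ratio of student updates to fake-score updates is an assumption, since the text above only states that callers drive the alternation.

```python
from typing import List


def plan_updates(step: int, student_update_freq: int = 5,
                 use_discriminator: bool = True) -> List[str]:
    """Pick which DMDPipeline entry points to call on a given step.

    Hypothetical schedule: the student is updated once every
    `student_update_freq` steps; every other step updates the fake score
    and, if one is configured, the discriminator.
    """
    if student_update_freq < 1:
        raise ValueError("student_update_freq must be >= 1")
    if step % student_update_freq == 0:
        return ["compute_student_loss"]
    phases = ["compute_fake_score_loss"]
    if use_discriminator:
        phases.append("compute_discriminator_loss")
    return phases


# Print the plan for the first few steps.
for step in range(4):
    print(step, plan_updates(step, student_update_freq=3))
```

In a real loop, for each returned name the caller would enable requires_grad on the matching module only, call that method on the pipeline, backpropagate the "total" entry of the returned dict, and step that module's optimizer.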
Classes
DMDPipeline | DMD2 loss pipeline.
- class DMDPipeline
Bases: DistillationPipeline
DMD2 loss pipeline.
- Parameters:
student – Trainable student module. Must be callable with (hidden_states, timestep, encoder_hidden_states=..., **kwargs) and return either a Tensor, a (Tensor, ...) tuple (as diffusers returns with return_dict=False), or an object with a .sample attribute.
teacher – Frozen reference module with the same call signature. If discriminator is provided, feature-capture hooks must be attached to teacher before calling compute_*_loss; see modelopt.torch.fastgen.plugins.qwen_image.attach_feature_capture().
fake_score – Trainable auxiliary module (same signature as teacher/student). Used to approximate the student’s generated distribution for the VSD gradient.
config – DMDConfig with the hyperparameters.
discriminator – Optional discriminator. Required when config.gan_loss_weight_gen > 0. Must accept list[Tensor] (the captured teacher features) and return a 2D logit tensor.
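The three accepted return forms for student, teacher, and fake_score can be normalized with a small duck-typed helper. The sketch below is illustrative only: `unwrap_model_output` is a hypothetical name, not a function documented here, and the order in which the pipeline checks the three forms internally is an assumption.

```python
from types import SimpleNamespace
from typing import Any


def unwrap_model_output(out: Any) -> Any:
    """Return the prediction from any of the three accepted output forms."""
    if isinstance(out, tuple):
        return out[0]        # (Tensor, ...) tuple, e.g. return_dict=False
    if hasattr(out, "sample"):
        return out.sample    # output object exposing a .sample attribute
    return out               # assumed to already be the Tensor itself


# A plain object stands in for a Tensor so the sketch needs no dependencies.
pred = object()
print(unwrap_model_output(pred) is pred)
print(unwrap_model_output((pred, "extra")) is pred)
print(unwrap_model_output(SimpleNamespace(sample=pred)) is pred)
```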
- __init__(student, teacher, fake_score, config, *, discriminator=None)
Wire up student / teacher / fake-score / discriminator and create the EMA tracker.
- Parameters:
student (nn.Module)
teacher (nn.Module)
fake_score (nn.Module)
config (DMDConfig)
discriminator (nn.Module | None)
- Return type:
None
- compute_discriminator_loss(latents, noise, encoder_hidden_states=None, **model_kwargs)
Compute the discriminator update loss (GAN + optional R1).
Teacher and student forwards are wrapped in torch.no_grad(); gradient flows only through the discriminator.
Returns a dict with "gan_disc" and "total". When config.gan_r1_reg_weight > 0 the dict also contains "r1".
- Parameters:
latents (Tensor)
noise (Tensor)
encoder_hidden_states (Tensor | None)
model_kwargs (Any)
- Return type:
dict[str, Tensor]
- compute_fake_score_loss(latents, noise, encoder_hidden_states=None, **model_kwargs)
Compute the fake-score (auxiliary) update loss.
The fake score is trained with denoising score matching against the student’s generated samples. The student forward is wrapped in torch.no_grad(), so the gradient here is w.r.t. fake_score only.
Returns a dict with "fake_score" and "total" (both equal).
- Parameters:
latents (Tensor)
noise (Tensor)
encoder_hidden_states (Tensor | None)
model_kwargs (Any)
- Return type:
dict[str, Tensor]
- compute_student_loss(latents, noise, encoder_hidden_states=None, *, negative_encoder_hidden_states=None, negative_encoder_hidden_states_mask=None, guidance_scale=None, **model_kwargs)
Compute the student update losses.
The returned dict always contains "vsd" and "total". When the GAN branch is enabled (discriminator is not None and config.gan_loss_weight_gen > 0), "gan_gen" is also present.
Gradient flow summary:
- VSD gradient: flows through student only (teacher_x0 is detached, fake_score_x0 is computed under torch.no_grad()).
- GAN generator gradient: flows through student via the feature-capture hooks on the teacher. The teacher forward is therefore not wrapped in torch.no_grad() when the GAN branch is active.
- Parameters:
latents (Tensor) – Real clean-data latents x_0. Used only when student_sample_steps > 1 to construct input_student.
noise (Tensor) – Pure Gaussian noise tensor matching latents in shape/dtype.
encoder_hidden_states (Tensor | None) – Positive conditioning passed unchanged to all three models.
negative_encoder_hidden_states (Tensor | None) – Negative conditioning used by classifier-free guidance. Required when guidance_scale (or DMDConfig.guidance_scale) is not None.
negative_encoder_hidden_states_mask (Tensor | None) – Optional negative-conditioning mask. Used for models such as Qwen-Image whose positional embedding depends on the real text sequence length.
guidance_scale (float | None) – Overrides DMDConfig.guidance_scale for this call. None keeps the config-level value.
**model_kwargs (Any) – Forwarded verbatim to student, teacher, and fake_score.
- Returns:
Dictionary with keys "vsd", "total", and optionally "gan_gen".
- Return type:
dict[str, Tensor]
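The interaction between the per-call guidance_scale, the config-level value, and the negative conditioning can be sketched as below. This is an illustration of the documented rules only: `resolve_guidance_scale` is a hypothetical helper, and raising ValueError when the negative conditioning is missing is an assumption about how the requirement might be enforced.

```python
from typing import Any, Optional


def resolve_guidance_scale(call_scale: Optional[float],
                           config_scale: Optional[float],
                           negative_encoder_hidden_states: Any = None
                           ) -> Optional[float]:
    """Apply the per-call override, then check the CFG requirement."""
    # A per-call value wins; None keeps the config-level value.
    scale = config_scale if call_scale is None else call_scale
    # Negative conditioning is required whenever guidance is active.
    if scale is not None and negative_encoder_hidden_states is None:
        raise ValueError(
            "negative_encoder_hidden_states is required when "
            "guidance_scale is not None"
        )
    return scale


print(resolve_guidance_scale(None, 4.0, "neg"))  # config value is kept
print(resolve_guidance_scale(2.0, 4.0, "neg"))   # per-call override wins
print(resolve_guidance_scale(None, None))        # guidance disabled
```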
- property ema: ExponentialMovingAverage | None
Reference to the student EMA tracker, if configured.
- update_ema(*, iteration=None)
Update the student EMA tracker (no-op if config.ema is None).
Typically called after the student optimizer step. If iteration is not provided, an internal counter is auto-incremented.
- Parameters:
iteration (int | None)
- Return type:
None
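The counter behavior of update_ema can be pictured with a toy tracker over plain floats. This is a sketch under stated assumptions, not the ExponentialMovingAverage class referenced above: `TinyEMA` is a hypothetical name, the constant-decay update rule is the standard EMA formula rather than one confirmed by this page, and storing an explicitly passed iteration in the counter is likewise assumed.

```python
from typing import List, Optional


class TinyEMA:
    """Toy EMA tracker over a list of floats (stand-in for parameters)."""

    def __init__(self, params: List[float], decay: float = 0.999) -> None:
        self.decay = decay
        self.shadow = list(params)  # EMA copy starts equal to the params
        self.iteration = 0          # internal counter

    def update(self, params: List[float], *,
               iteration: Optional[int] = None) -> None:
        # Auto-increment the counter unless the caller supplies one.
        if iteration is None:
            self.iteration += 1
        else:
            self.iteration = iteration
        d = self.decay
        self.shadow = [d * s + (1.0 - d) * p
                       for s, p in zip(self.shadow, params)]


ema = TinyEMA([0.0], decay=0.5)
ema.update([1.0])                # counter auto-increments to 1
ema.update([1.0], iteration=10)  # caller-provided iteration
print(ema.shadow, ema.iteration)
```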