dmd

Distribution Matching Distillation (DMD2) pipeline.

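DMDPipeline holds references to the student, teacher, fake-score, and (optional) discriminator modules and exposes three loss-computation entry points, one for each kind of update step in a training loop:

  • compute_student_loss() – loss for the student update.

  • compute_fake_score_loss() – loss for the fake-score update.

  • compute_discriminator_loss() – loss for the discriminator update.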

The pipeline does not own optimizers, schedulers, gradient toggles, or device placement. Callers drive the alternation between student / fake-score / discriminator updates, toggle requires_grad, and call the appropriate compute_*_loss each step.
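
Because the caller owns the alternation, a training loop needs some rule for choosing which entry point to call on a given step. The following is a minimal sketch of one such rule, not the schedule prescribed by this module: student_update_freq and both helper functions are hypothetical names introduced here, and the "student every N steps, fake score and discriminator otherwise" policy is an assumption. The mapping from entry point to trainable module follows the gradient-flow notes in the method docstrings.

```python
from typing import List


def phases_for_step(step: int, student_update_freq: int, has_discriminator: bool) -> List[str]:
    """Return the compute_*_loss entry points to call on this step.

    Assumed schedule (hypothetical, not taken from the pipeline): the student
    is updated once every `student_update_freq` steps; on every other step the
    fake score is updated, followed by the discriminator if one is configured.
    """
    if student_update_freq < 1:
        raise ValueError("student_update_freq must be >= 1")
    if step % student_update_freq == 0:
        return ["compute_student_loss"]
    phases = ["compute_fake_score_loss"]
    if has_discriminator:
        phases.append("compute_discriminator_loss")
    return phases


def trainable_module(phase: str) -> str:
    """Name of the module the caller leaves with requires_grad=True for a phase."""
    return {
        "compute_student_loss": "student",
        "compute_fake_score_loss": "fake_score",
        "compute_discriminator_loss": "discriminator",
    }[phase]


if __name__ == "__main__":
    for step in range(4):
        for phase in phases_for_step(step, student_update_freq=2, has_discriminator=True):
            print(step, phase, "-> train", trainable_module(phase))
```

In a real loop, each phase would be followed by the caller's own backward pass on the returned "total" loss and the matching optimizer step.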

Math is a close port of FastGen/fastgen/methods/distribution_matching/dmd2.py (lines 45-455). See the docstrings on the individual methods for line-level references.

Classes

DMDPipeline

DMD2 loss pipeline.

class DMDPipeline

Bases: DistillationPipeline

DMD2 loss pipeline.

Parameters:
  • student – Trainable student module. Must be callable with (hidden_states, timestep, encoder_hidden_states=..., **kwargs) and return either a Tensor, a (Tensor, ...) tuple (as diffusers returns with return_dict=False), or an object with a .sample attribute.

  • teacher – Frozen reference module with the same call signature. If discriminator is provided, feature-capture hooks must be attached to teacher before calling compute_*_loss — see modelopt.torch.fastgen.plugins.qwen_image.attach_feature_capture().

  • fake_score – Trainable auxiliary module (same signature as teacher/student). Used to approximate the student’s generated distribution for the VSD gradient.

  • config – DMDConfig with the hyperparameters.

  • discriminator – Optional discriminator. Required when config.gan_loss_weight_gen > 0. Must accept list[Tensor] (the captured teacher features) and return a 2D logit tensor.
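
The three accepted return shapes for student, teacher, and fake_score (a bare tensor, a tuple, or an object with a .sample attribute) can be normalized with a small helper. This is an illustrative sketch only: unwrap_model_output is a hypothetical name, and the pipeline's own handling may differ in the order of checks.

```python
from types import SimpleNamespace
from typing import Any


def unwrap_model_output(out: Any) -> Any:
    """Return the primary prediction from any of the three accepted shapes."""
    # Tuple output, e.g. what diffusers returns with return_dict=False:
    # the prediction is assumed to be the first element.
    if isinstance(out, tuple):
        return out[0]
    # Output object exposing the prediction as `.sample`.
    if hasattr(out, "sample"):
        return out.sample
    # Otherwise assume the value is already the prediction tensor.
    return out


if __name__ == "__main__":
    print(unwrap_model_output(("pred", "extra")))
    print(unwrap_model_output(SimpleNamespace(sample="pred")))
    print(unwrap_model_output("pred"))
```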

__init__(student, teacher, fake_score, config, *, discriminator=None)

Wire up student / teacher / fake-score / discriminator and create the EMA tracker.

Parameters:
  • student (nn.Module)

  • teacher (nn.Module)

  • fake_score (nn.Module)

  • config (DMDConfig)

  • discriminator (nn.Module | None)

Return type:

None

compute_discriminator_loss(latents, noise, encoder_hidden_states=None, **model_kwargs)

Compute the discriminator update loss (GAN + optional R1).

Teacher and student forwards are wrapped in torch.no_grad(); gradient flows only through the discriminator.

Returns a dict with "gan_disc" and "total". When config.gan_r1_reg_weight > 0 the dict also contains "r1".
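
The shape of the returned dict can be illustrated with plain floats. The sketch below assumes a logistic (softplus) GAN objective and assumes that "total" is the GAN term plus gan_r1_reg_weight times the R1 term; neither detail is stated in this reference, and discriminator_loss_dict is a hypothetical stand-in rather than the pipeline's code.

```python
import math
from typing import Dict, Sequence


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def _mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def discriminator_loss_dict(
    real_logits: Sequence[float],
    fake_logits: Sequence[float],
    r1_penalty: float = 0.0,
    gan_r1_reg_weight: float = 0.0,
) -> Dict[str, float]:
    # Assumed objective: push real logits up and fake logits down.
    gan_disc = _mean([softplus(-r) for r in real_logits]) + _mean(
        [softplus(f) for f in fake_logits]
    )
    losses = {"gan_disc": gan_disc}
    total = gan_disc
    # "r1" is only reported when the regularizer is switched on.
    if gan_r1_reg_weight > 0:
        losses["r1"] = r1_penalty
        total += gan_r1_reg_weight * r1_penalty  # assumed weighting
    losses["total"] = total
    return losses


if __name__ == "__main__":
    print(sorted(discriminator_loss_dict([1.5], [-0.5])))
    print(sorted(discriminator_loss_dict([1.5], [-0.5], 0.1, gan_r1_reg_weight=10.0)))
```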

Parameters:
  • latents (Tensor)

  • noise (Tensor)

  • encoder_hidden_states (Tensor | None)

  • model_kwargs (Any)

Return type:

dict[str, Tensor]

compute_fake_score_loss(latents, noise, encoder_hidden_states=None, **model_kwargs)

Compute the fake-score (auxiliary) update loss.

The fake score is trained with denoising score matching against the student’s generated samples. The student forward is wrapped in torch.no_grad() — the gradient here is w.r.t. fake_score only.

Returns a dict with "fake_score" and "total" (both equal).
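
As a rough illustration of the objective described above, the loss can be written out under two assumptions that this reference does not state: an x_0-prediction parameterization and a generic forward process with coefficients α_t and σ_t. The symbols G_θ (student), f_ψ (fake score), and sg (stop-gradient, corresponding to the torch.no_grad() wrapper on the student) are notation introduced here.

```latex
\hat{x}_0 = \operatorname{sg}\!\left[ G_\theta(\cdot) \right], \qquad
x_t = \alpha_t \hat{x}_0 + \sigma_t \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)

\mathcal{L}_{\text{fake}}(\psi) = \mathbb{E}_{t,\,\epsilon}
  \left\| f_\psi(x_t, t) - \hat{x}_0 \right\|_2^2
```

With a different prediction type (noise or velocity), the regression target changes accordingly, but the gradient still reaches only ψ.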

Parameters:
  • latents (Tensor)

  • noise (Tensor)

  • encoder_hidden_states (Tensor | None)

  • model_kwargs (Any)

Return type:

dict[str, Tensor]

compute_student_loss(latents, noise, encoder_hidden_states=None, *, negative_encoder_hidden_states=None, negative_encoder_hidden_states_mask=None, guidance_scale=None, **model_kwargs)

Compute the student update losses.

The returned dict always contains "vsd" and "total". When the GAN branch is enabled (discriminator is not None and config.gan_loss_weight_gen > 0), "gan_gen" is also present.

Gradient flow summary:

  • VSD gradient: flows through student only (teacher_x0 is detached, fake_score_x0 is computed under torch.no_grad()).

  • GAN generator gradient: flows through student via the feature-capture hooks on the teacher. The teacher forward is therefore not wrapped in torch.no_grad() when the GAN branch is active.
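
The VSD bullet above can be made concrete with a scalar toy example. DMD-style implementations commonly express the update as a surrogate regression loss whose gradient with respect to the student output equals the difference between the fake-score and teacher predictions. The sketch below assumes that formulation, omits any gradient normalization, and uses hypothetical function names; it is not the pipeline's code.

```python
from typing import Callable, Tuple


def vsd_surrogate_loss(
    x0_student: float, teacher_x0: float, fake_score_x0: float
) -> Tuple[Callable[[float], float], float]:
    """Build a surrogate loss whose derivative at x0_student is the VSD direction.

    teacher_x0 and fake_score_x0 are plain constants here, mirroring the fact
    that one is detached and the other is computed without gradient tracking.
    """
    direction = fake_score_x0 - teacher_x0
    target = x0_student - direction  # frozen regression target

    def loss(x: float) -> float:
        return 0.5 * (x - target) ** 2

    return loss, direction


def finite_difference(f: Callable[[float], float], x: float, eps: float = 1e-6) -> float:
    """Central-difference estimate of df/dx."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


if __name__ == "__main__":
    loss, direction = vsd_surrogate_loss(1.0, teacher_x0=0.4, fake_score_x0=0.9)
    # The numerical derivative w.r.t. the student output matches the direction.
    print(direction, finite_difference(loss, 1.0))
```

Because the target is frozen, the only path for the gradient is through the student output, which is the behavior the bullet describes.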

Parameters:
  • latents (Tensor) – Real clean-data latents x_0. Used only when student_sample_steps > 1 to construct input_student.

  • noise (Tensor) – Pure Gaussian noise tensor matching latents in shape/dtype.

  • encoder_hidden_states (Tensor | None) – Positive conditioning passed unchanged to all three models.

  • negative_encoder_hidden_states (Tensor | None) – Negative conditioning used by classifier-free guidance. Required when guidance_scale (or DMDConfig.guidance_scale) is not None.

  • negative_encoder_hidden_states_mask (Tensor | None) – Optional negative-conditioning mask. Used for models such as Qwen-Image whose positional embedding depends on the real text sequence length.

  • guidance_scale (float | None) – Overrides DMDConfig.guidance_scale for this call. None keeps the config-level value.

  • **model_kwargs (Any) – Forwarded verbatim to student, teacher, and fake_score.

Returns:

Dictionary with keys "vsd", "total", and optionally "gan_gen".

Return type:

dict[str, Tensor]
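
The interaction between the per-call guidance_scale and DMDConfig.guidance_scale, and the role of the negative conditioning, can be sketched with scalars. Both helpers are hypothetical names; the combination rule shown is the standard classifier-free guidance formula, and this reference does not state which model's prediction it is applied to.

```python
from typing import Optional


def resolve_guidance_scale(
    call_value: Optional[float], config_value: Optional[float]
) -> Optional[float]:
    """Per-call value wins; None falls back to the config-level value."""
    return config_value if call_value is None else call_value


def apply_cfg(cond_pred: float, uncond_pred: float, scale: float) -> float:
    """Standard classifier-free guidance: extrapolate from the negative
    (unconditional) prediction towards the positive (conditional) one."""
    return uncond_pred + scale * (cond_pred - uncond_pred)


if __name__ == "__main__":
    scale = resolve_guidance_scale(None, 4.0)  # falls back to the config value
    if scale is not None:
        # Negative conditioning is needed to produce uncond_pred.
        print(apply_cfg(cond_pred=2.0, uncond_pred=1.0, scale=scale))
```

A scale of 1.0 reduces to the conditional prediction, and a resolved scale of None means guidance is skipped, so no negative conditioning is needed.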

config: DMDConfig

property ema: ExponentialMovingAverage | None

Reference to the student EMA tracker, if configured.

update_ema(*, iteration=None)

Update the student EMA tracker (no-op if config.ema is None).

Typically called after the student optimizer step. If iteration is not provided, an internal counter is auto-incremented.

Parameters:
  • iteration (int | None)

Return type:

None
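
The auto-incrementing iteration behavior of update_ema can be pictured with a scalar stand-in. ScalarEMA below is a hypothetical class, not the ExponentialMovingAverage used by the pipeline: the plain decay rule, the initialization from the first value, and the choice to resync the internal counter when an explicit iteration is passed are all assumptions.

```python
from typing import Optional


class ScalarEMA:
    """Toy EMA over a single float, with an optional explicit iteration."""

    def __init__(self, decay: float = 0.999) -> None:
        self.decay = decay
        self.shadow: Optional[float] = None
        self._iteration = 0

    def update(self, value: float, *, iteration: Optional[int] = None) -> int:
        # No iteration given: advance the internal counter by one.
        if iteration is None:
            self._iteration += 1
        else:
            # Assumed behavior: an explicit iteration resyncs the counter.
            self._iteration = iteration
        if self.shadow is None:
            self.shadow = value  # first update initializes the average
        else:
            self.shadow = self.decay * self.shadow + (1.0 - self.decay) * value
        return self._iteration


if __name__ == "__main__":
    ema = ScalarEMA(decay=0.5)
    for weight in (1.0, 3.0):
        # In a training loop this would follow the student optimizer step.
        print(ema.update(weight), ema.shadow)
```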