losses

Pure loss functions used by the fastgen distillation pipelines.

All functions in this module are stateless: they take tensors in and return a scalar loss tensor. They do not touch any nn.Module. Higher-level orchestration (teacher forward, CFG, noise scheduling) lives in modelopt.torch.fastgen.methods.dmd.

Math ported from FastGen/fastgen/methods/common_loss.py (lines 12-136) and FastGen/fastgen/methods/distribution_matching/dmd2.py lines 287-317 (R1).

Functions

dsm_loss

Denoising score-matching loss for x0 / eps / v / flow predictions.

gan_disc_loss

Softplus GAN discriminator loss: E[softplus(fake_logits)] + E[softplus(-real_logits)].

gan_gen_loss

Softplus GAN generator loss: E[softplus(-fake_logits)].

r1_loss

Approximate R1 regularization (APT formulation).

vsd_loss

Variational score-distillation (VSD) loss used by the DMD student update.

dsm_loss(pred_type, net_pred, *, x0=None, eps=None, t=None, alpha_fn=None, sigma_fn=None)

Denoising score-matching loss for x0 / eps / v / flow predictions.

The forward process is x_t = alpha_t * x_0 + sigma_t * eps. For pred_type="v", alpha_t and sigma_t are supplied as callables rather than as a full noise-scheduler object, so the function stays scheduler-agnostic.

Parameters:
  • pred_type (str) – One of "x0", "eps", "v", "flow".

  • net_pred (torch.Tensor) – The network output; its interpretation is determined by pred_type.

  • x0 (torch.Tensor | None) – Clean data. Required for all pred_type except "eps".

  • eps (torch.Tensor | None) – Noise used in the forward process. Required for all pred_type except "x0".

  • t (torch.Tensor | None) – Timesteps in [0, 1] (or the scheduler's convention). Required for pred_type="v".

  • alpha_fn (Callable[[torch.Tensor], torch.Tensor] | None) – Callable mapping t -> alpha_t. Required for pred_type="v".

  • sigma_fn (Callable[[torch.Tensor], torch.Tensor] | None) – Callable mapping t -> sigma_t. Required for pred_type="v".

Returns:

Scalar MSE loss.

Return type:

torch.Tensor
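As an illustration of the four target conventions, here is a minimal sketch (not the module's implementation; the function name is hypothetical, and the flow target eps - x0 is one common flow-matching convention assumed here):

```python
import torch
import torch.nn.functional as F


def dsm_loss_sketch(pred_type, net_pred, *, x0=None, eps=None, t=None,
                    alpha_fn=None, sigma_fn=None):
    """MSE between the network output and the target implied by pred_type."""
    if pred_type == "x0":
        target = x0
    elif pred_type == "eps":
        target = eps
    elif pred_type == "v":
        # v-prediction target: v = alpha_t * eps - sigma_t * x0.
        a, s = alpha_fn(t), sigma_fn(t)
        while a.ndim < x0.ndim:  # broadcast per-sample scalars over data dims
            a, s = a.unsqueeze(-1), s.unsqueeze(-1)
        target = a * eps - s * x0
    elif pred_type == "flow":
        # One common flow-matching velocity convention (assumption).
        target = eps - x0
    else:
        raise ValueError(f"unknown pred_type: {pred_type}")
    return F.mse_loss(net_pred, target)
```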

gan_disc_loss(real_logits, fake_logits)

Softplus GAN discriminator loss: E[softplus(fake_logits)] + E[softplus(-real_logits)].

Parameters:
  • real_logits (Tensor)

  • fake_logits (Tensor)

Return type:

Tensor
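The formula above can be sketched directly (a hypothetical stand-in, not the module's implementation):

```python
import torch
import torch.nn.functional as F


def gan_disc_loss_sketch(real_logits, fake_logits):
    # Softplus (non-saturating) discriminator loss:
    #   E[softplus(fake_logits)] + E[softplus(-real_logits)]
    # Minimizing it pushes real logits up and fake logits down.
    return F.softplus(fake_logits).mean() + F.softplus(-real_logits).mean()
```

At zero logits both terms equal log(2), so the loss starts near 2 * log(2) for an uninformative discriminator.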

gan_gen_loss(fake_logits)

Softplus GAN generator loss: E[softplus(-fake_logits)].

Parameters:

fake_logits (Tensor) – Discriminator logits on generated samples. Must be 2D: (B, num_heads).

Return type:

Tensor
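The generator side is the mirror of the discriminator term on fakes (again a hypothetical sketch):

```python
import torch
import torch.nn.functional as F


def gan_gen_loss_sketch(fake_logits):
    # Non-saturating generator loss: E[softplus(-fake_logits)]
    # drives the discriminator's logits on generated samples upward.
    return F.softplus(-fake_logits).mean()
```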

r1_loss(real_logits, perturbed_real_logits)

Approximate R1 regularization (APT formulation).

Penalizes the discriminator for being sensitive to small noise perturbations of the real data. The caller is responsible for computing perturbed_real_logits by re-running the teacher feature extractor and discriminator on real data that has been perturbed with alpha * randn_like(real); this function only applies the final MSE between the two logit sets.

See FastGen/fastgen/methods/distribution_matching/dmd2.py lines 287-317.

Parameters:
  • real_logits (Tensor)

  • perturbed_real_logits (Tensor)

Return type:

Tensor
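Since the caller owns the perturbation and re-running of the discriminator, the division of labor can be sketched as follows (function name and the `torch.nn.Linear` stand-in discriminator are hypothetical; only the final MSE reflects what this loss computes):

```python
import torch
import torch.nn.functional as F


def r1_loss_sketch(real_logits, perturbed_real_logits):
    # APT-style approximate R1: MSE between logits on clean and perturbed reals.
    return F.mse_loss(perturbed_real_logits, real_logits)


# Caller side (sketch): perturb real data, re-run the discriminator, apply loss.
disc = torch.nn.Linear(8, 1)  # stand-in for feature extractor + discriminator
real = torch.randn(4, 8)
alpha = 0.01
perturbed = real + alpha * torch.randn_like(real)
loss = r1_loss_sketch(disc(real), disc(perturbed))
```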

vsd_loss(gen_data, teacher_x0, fake_score_x0, additional_scale=None)

Variational score-distillation (VSD) loss used by the DMD student update.

Implements the FastGen formulation: a per-sample weight w = 1 / (mean_abs(gen_data - teacher_x0) + 1e-6) is computed in fp32 for numerical stability; the gradient (fake_score_x0 - teacher_x0) * w is then subtracted from the generated data to form a pseudo-target. The loss is 0.5 * MSE(gen_data, pseudo_target).

Parameters:
  • gen_data (Tensor) – Student-generated clean data x_0.

  • teacher_x0 (Tensor) – Teacher x_0 prediction (after CFG, if enabled). Detached.

  • fake_score_x0 (Tensor) – Fake-score x_0 prediction. Detached.

  • additional_scale (Tensor | None) – Optional per-sample scale applied multiplicatively to the weight.

Returns:

Scalar VSD loss.

Return type:

Tensor
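The weighting and pseudo-target construction described above can be sketched like this (hypothetical function name; building the target under no_grad, and the broadcast shape of additional_scale, are assumptions of the sketch):

```python
import torch
import torch.nn.functional as F


def vsd_loss_sketch(gen_data, teacher_x0, fake_score_x0, additional_scale=None):
    with torch.no_grad():
        # Per-sample weight in fp32 for numerical stability:
        # w = 1 / (mean_abs(gen_data - teacher_x0) + 1e-6).
        diff = (gen_data - teacher_x0).float()
        dims = list(range(1, diff.ndim))
        w = 1.0 / (diff.abs().mean(dim=dims, keepdim=True) + 1e-6)
        if additional_scale is not None:
            w = w * additional_scale  # assumed broadcastable per-sample scale
        # Pseudo-target: step gen_data against the weighted score difference.
        grad = (fake_score_x0 - teacher_x0).float() * w
        target = gen_data.float() - grad
    return 0.5 * F.mse_loss(gen_data.float(), target)
```

Because the pseudo-target is built without gradient tracking, backpropagating the MSE pushes gen_data along -(fake_score_x0 - teacher_x0) * w, i.e. toward the teacher's prediction.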