losses
Pure loss functions used by the fastgen distillation pipelines.
All functions in this module are stateless: they take tensors in and return a scalar
loss tensor. They do not touch any nn.Module. Higher-level orchestration (teacher
forward, CFG, noise scheduling) lives in modelopt.torch.fastgen.methods.dmd.
Math ported from FastGen/fastgen/methods/common_loss.py (lines 12-136) and
FastGen/fastgen/methods/distribution_matching/dmd2.py (lines 287-317, R1).
Functions

- dsm_loss – Denoising score-matching loss for x0/eps/v/flow predictions.
- gan_disc_loss – Softplus GAN discriminator loss.
- gan_gen_loss – Softplus GAN generator loss.
- r1_loss – Approximate R1 regularization (APT formulation).
- vsd_loss – Variational score-distillation (VSD) loss used by the DMD student update.
- dsm_loss(pred_type, net_pred, *, x0=None, eps=None, t=None, alpha_fn=None, sigma_fn=None)

  Denoising score-matching loss for x0/eps/v/flow predictions.
  The forward process is x_t = alpha_t * x_0 + sigma_t * eps. For pred_type='v' we need
  alpha_t and sigma_t, which are supplied as callables rather than a full noise-scheduler
  object so this function stays scheduler-agnostic.

  - Parameters:
    - pred_type (str) – One of "x0", "eps", "v", "flow".
    - net_pred (torch.Tensor) – The network output; its interpretation is determined by pred_type.
    - x0 (torch.Tensor | None) – Clean data. Required for every pred_type except "eps".
    - eps (torch.Tensor | None) – Noise used in the forward process. Required for every pred_type except "x0".
    - t (torch.Tensor | None) – Timesteps in [0, 1] (or scheduler convention). Required for pred_type='v'.
    - alpha_fn (Callable[[torch.Tensor], torch.Tensor] | None) – Callable mapping t -> alpha_t. Required for pred_type='v'.
    - sigma_fn (Callable[[torch.Tensor], torch.Tensor] | None) – Callable mapping t -> sigma_t. Required for pred_type='v'.
  - Returns:
    Scalar MSE loss.
  - Return type:
    torch.Tensor
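To make the per-pred_type targets concrete, here is a minimal sketch of the computation described above. The name dsm_loss_sketch is hypothetical, and the flow target eps - x0 is an assumption (a common rectified-flow convention), not confirmed by this page:

```python
import torch
import torch.nn.functional as F

def dsm_loss_sketch(pred_type, net_pred, *, x0=None, eps=None, t=None,
                    alpha_fn=None, sigma_fn=None):
    """Hypothetical re-implementation: MSE against the target implied by pred_type."""
    if pred_type == "x0":
        target = x0
    elif pred_type == "eps":
        target = eps
    elif pred_type == "v":
        # v-prediction target: v = alpha_t * eps - sigma_t * x0,
        # with the scalar schedules reshaped to broadcast over non-batch dims.
        alpha_t = alpha_fn(t).view(-1, *([1] * (x0.dim() - 1)))
        sigma_t = sigma_fn(t).view(-1, *([1] * (x0.dim() - 1)))
        target = alpha_t * eps - sigma_t * x0
    elif pred_type == "flow":
        # Assumed rectified-flow velocity target: eps - x0.
        target = eps - x0
    else:
        raise ValueError(f"unknown pred_type: {pred_type}")
    return F.mse_loss(net_pred, target)
```

When the network output equals the target, the loss is exactly zero, which makes the convention easy to sanity-check in a unit test.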
- gan_disc_loss(real_logits, fake_logits)

  Softplus GAN discriminator loss: E[softplus(fake_logits)] + E[softplus(-real_logits)].

  - Parameters:
    - real_logits (Tensor)
    - fake_logits (Tensor)
  - Return type:
    Tensor
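A one-line sketch of this formula (gan_disc_loss_sketch is a hypothetical name, not the module's implementation):

```python
import torch
import torch.nn.functional as F

def gan_disc_loss_sketch(real_logits, fake_logits):
    # Non-saturating softplus discriminator loss: push real logits up
    # (softplus(-real) -> 0 as real -> +inf) and fake logits down
    # (softplus(fake) -> 0 as fake -> -inf).
    return F.softplus(fake_logits).mean() + F.softplus(-real_logits).mean()
```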
- gan_gen_loss(fake_logits)

  Softplus GAN generator loss: E[softplus(-fake_logits)].

  - Parameters:
    - fake_logits (Tensor) – Discriminator logits on generated samples. Must be 2D: (B, num_heads).
  - Return type:
    Tensor
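The generator side is the mirror image, rewarding fake logits the discriminator scores highly. A sketch with the documented 2D shape check (gan_gen_loss_sketch is a hypothetical name):

```python
import torch
import torch.nn.functional as F

def gan_gen_loss_sketch(fake_logits):
    # Enforce the documented (B, num_heads) shape contract.
    if fake_logits.dim() != 2:
        raise ValueError("fake_logits must be 2D: (B, num_heads)")
    # softplus(-fake) -> 0 as the discriminator is fooled (fake -> +inf).
    return F.softplus(-fake_logits).mean()
```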
- r1_loss(real_logits, perturbed_real_logits)

  Approximate R1 regularization (APT formulation).
  Penalizes the discriminator for being sensitive to small noise perturbations of the real
  data. The caller is responsible for computing perturbed_real_logits by re-running the
  teacher feature extractor and discriminator on real data that has been perturbed with
  alpha * randn_like(real); this function only applies the final MSE between the two logit
  sets. See FastGen/fastgen/methods/distribution_matching/dmd2.py, lines 287-317.

  - Parameters:
    - real_logits (Tensor)
    - perturbed_real_logits (Tensor)
  - Return type:
    Tensor
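The split of responsibility described above can be sketched as follows. Both function names and the alpha default are hypothetical; `disc` stands in for the caller's feature-extractor-plus-discriminator stack:

```python
import torch
import torch.nn.functional as F

def r1_loss_sketch(real_logits, perturbed_real_logits):
    # Approximate R1 (APT): penalize logit sensitivity to input perturbations.
    return F.mse_loss(real_logits, perturbed_real_logits)

def approximate_r1(disc, real, alpha=0.01):
    # Caller-side sketch: re-run the discriminator on noise-perturbed real data,
    # then hand both logit sets to the pure loss function.
    real_logits = disc(real)
    perturbed_logits = disc(real + alpha * torch.randn_like(real))
    return r1_loss_sketch(real_logits, perturbed_logits)
```

Keeping the perturbation on the caller's side is what lets r1_loss stay stateless: it never needs a reference to the discriminator module.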
- vsd_loss(gen_data, teacher_x0, fake_score_x0, additional_scale=None)

  Variational score-distillation (VSD) loss used by the DMD student update.
  Implements the FastGen formulation: a per-sample weight
  w = 1 / (mean_abs(gen_data - teacher_x0) + 1e-6) is computed in fp32 for numerical
  stability, then the gradient (fake_score_x0 - teacher_x0) * w is subtracted from the
  generated data to form a pseudo-target. The loss is 0.5 * MSE(gen_data, pseudo_target).

  - Parameters:
    - gen_data (Tensor) – Student-generated clean data x_0.
    - teacher_x0 (Tensor) – Teacher x_0 prediction (after CFG, if enabled). Detached.
    - fake_score_x0 (Tensor) – Fake-score x_0 prediction. Detached.
    - additional_scale (Tensor | None) – Optional per-sample scale applied multiplicatively to the weight.
  - Returns:
    Scalar VSD loss.
  - Return type:
    Tensor
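The formulation above maps onto code almost line for line. A sketch under the stated assumptions (vsd_loss_sketch is a hypothetical name; the per-sample mean is taken over all non-batch dims, which this page implies but does not spell out):

```python
import torch
import torch.nn.functional as F

def vsd_loss_sketch(gen_data, teacher_x0, fake_score_x0, additional_scale=None):
    with torch.no_grad():
        # Per-sample weight in fp32: mean |gen - teacher| over non-batch dims.
        diff = (gen_data.float() - teacher_x0.float()).abs()
        reduce_dims = list(range(1, diff.dim()))
        w = 1.0 / (diff.mean(dim=reduce_dims, keepdim=True) + 1e-6)
        if additional_scale is not None:
            w = w * additional_scale
        # Weighted score-difference gradient, subtracted to form the pseudo-target.
        grad = (fake_score_x0.float() - teacher_x0.float()) * w
        pseudo_target = gen_data.float() - grad
    # 0.5 * MSE against the detached pseudo-target; only gen_data carries grads.
    return 0.5 * F.mse_loss(gen_data.float(), pseudo_target)
```

One consequence worth noting: when the fake score agrees with the teacher, the pseudo-target collapses to gen_data itself and the loss is exactly zero, so the student receives no update signal.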