losses

Different types of distillation losses.

Classes

LogitsDistillationLoss

KL-Divergence loss on output logits.

MFTLoss

KL-divergence loss with Minifinetuning threshold modification.

MGDLoss

PyTorch version of Masked Generative Distillation.

class LogitsDistillationLoss

Bases: _Loss

KL-Divergence loss on output logits.

This class implements the distillation loss described in the paper: https://arxiv.org/abs/1503.02531.

__init__(temperature=1.0, reduction='batchmean')

Constructor.

Parameters:
  • temperature (float) – Temperature used to soften logits_t and logits_s before the loss is computed on them.

  • reduction (str) – How to reduce the final pointwise loss before returning. Pass "none" to apply your own reduction afterwards, e.g. with loss masks.

forward(logits_s, logits_t)

Compute KD loss on student and teacher logits.

Parameters:
  • logits_s (Tensor) – Student’s logits, treated as prediction.

  • logits_t (Tensor) – Teacher’s logits, treated as label.

Return type:

Tensor

Note

Assumes class logits dimension is last.
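
Example

A minimal usage sketch, assuming LogitsDistillationLoss is importable from a module named losses (a hypothetical path; adjust to your package layout). The loss is the KL divergence between temperature-softened teacher and student distributions, as described in the paper above.

    import torch
    from losses import LogitsDistillationLoss  # hypothetical import path

    # Fake student/teacher logits: batch of 8, sequence length 16, 32000-way vocabulary.
    logits_s = torch.randn(8, 16, 32000, requires_grad=True)
    logits_t = torch.randn(8, 16, 32000)

    criterion = LogitsDistillationLoss(temperature=2.0, reduction="batchmean")
    loss = criterion(logits_s, logits_t)  # scalar tensor; class dimension is last
    loss.backward()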

class MFTLoss

Bases: _Loss

KL-divergence loss with Minifinetuning threshold modification.

This class implements the distillation loss described in the paper: https://arxiv.org/abs/2506.15702.

__init__(temperature=1.0, threshold=0.2, reduction='batchmean')

Constructor.

Parameters:
  • temperature (float) – A value used to soften the logits_t and logits_s before computing the MFT loss on them.

  • reduction (str) – How to reduce the final pointwise loss before returning. Pass "none" to apply your own reduction afterwards, e.g. with loss masks.

  • threshold (float) – A value used to correct the teacher’s distribution, ensuring that the separation between the correct and incorrect argmax tokens is large enough. Must be in the range [0, 1]. Defaults to 0.2.

forward(logits_s, logits_t, labels)

Compute KD loss on student and teacher logits.

Parameters:
  • logits_s (Tensor) – Student’s logits, treated as prediction.

  • logits_t (Tensor) – Teacher’s logits, treated as training target.

  • labels (Tensor) – Ground-truth labels, used to prepare the corrected teacher distributions.

Return type:

Tensor

Note

Assumes class logits dimension is last.
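
Example

A minimal usage sketch, assuming MFTLoss is importable from a module named losses (a hypothetical path) and that labels hold ground-truth class indices with the same leading dimensions as the logits.

    import torch
    from losses import MFTLoss  # hypothetical import path

    batch, seq, vocab = 4, 32, 1000
    logits_s = torch.randn(batch, seq, vocab, requires_grad=True)
    logits_t = torch.randn(batch, seq, vocab)
    labels = torch.randint(0, vocab, (batch, seq))  # ground-truth token ids

    criterion = MFTLoss(temperature=1.0, threshold=0.2, reduction="batchmean")
    loss = criterion(logits_s, logits_t, labels)  # KL to the threshold-corrected teacher distribution
    loss.backward()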

class MGDLoss

Bases: _Loss

PyTorch version of Masked Generative Distillation.

This class implements the distillation loss described in the paper: https://arxiv.org/abs/2205.01529.

__init__(num_student_channels, num_teacher_channels, alpha_mgd=1.0, lambda_mgd=0.65)

Constructor.

Parameters:
  • num_student_channels (int) – Number of channels in the student’s feature map.

  • num_teacher_channels (int) – Number of channels in the teacher’s feature map.

  • alpha_mgd (float) – Scalar that the final loss is multiplied by. Defaults to 1.0.

  • lambda_mgd (float) – Masked ratio of the feature map. Defaults to 0.65.

forward(out_s, out_t)

Compute the MGD loss on student and teacher feature maps.

Parameters:
  • out_s (Tensor) – Student’s feature map (shape BxCxHxW).

  • out_t (Tensor) – Teacher’s feature map (shape BxCxHxW).
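
Example

A minimal usage sketch, assuming MGDLoss is importable from a module named losses (a hypothetical path). The channel counts passed to the constructor must match the feature maps fed to forward; the student and teacher maps are assumed to share the same spatial size.

    import torch
    from losses import MGDLoss  # hypothetical import path

    # Student/teacher feature maps in BxCxHxW layout.
    out_s = torch.randn(2, 64, 28, 28, requires_grad=True)
    out_t = torch.randn(2, 256, 28, 28)

    criterion = MGDLoss(num_student_channels=64, num_teacher_channels=256,
                        alpha_mgd=1.0, lambda_mgd=0.65)
    loss = criterion(out_s, out_t)
    loss.backward()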