losses

Different types of distillation losses.

Classes

  • LogitsDistillationLoss – KL-Divergence loss on output logits.

  • MGDLoss – PyTorch version of Masked Generative Distillation.

class LogitsDistillationLoss

Bases: _Loss

KL-Divergence loss on output logits.

Implements the distillation loss described in the paper: https://arxiv.org/abs/1503.02531.

__init__(temperature=1.0, reduction='batchmean')

Constructor.

Parameters:
  • temperature (float) – Temperature used to soften logits_s and logits_t before the loss is computed.

  • reduction (str) – How to reduce the final pointwise loss before returning. Pass "none" to apply your own reduction afterwards, e.g. with loss masks.

forward(logits_s, logits_t)

Compute KD loss on student and teacher logits.

Parameters:
  • logits_s (Tensor) – Student’s logits, treated as prediction.

  • logits_t (Tensor) – Teacher’s logits, treated as label.

Return type:

Tensor

Note

Assumes class logits dimension is last.
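
For reference, the following is a minimal sketch of the temperature-scaled KL-divergence this loss computes, following Hinton et al. (https://arxiv.org/abs/1503.02531). The helper name kd_loss_sketch and the T**2 rescaling are illustrative assumptions, not the library's exact implementation.

    import torch
    import torch.nn.functional as F

    def kd_loss_sketch(logits_s, logits_t, temperature=1.0, reduction="batchmean"):
        # Soften both distributions with the temperature (class dim assumed last).
        log_p_s = F.log_softmax(logits_s / temperature, dim=-1)  # student log-probs
        p_t = F.softmax(logits_t / temperature, dim=-1)          # teacher probs

        # KL(teacher || student); "batchmean" divides the summed loss by the batch size.
        loss = F.kl_div(log_p_s, p_t, reduction=reduction)

        # Common convention (assumed here): rescale by T**2 so gradient
        # magnitudes stay roughly independent of the temperature.
        return loss * temperature**2

    # Example usage with random logits of shape (batch, num_classes).
    logits_s = torch.randn(8, 10, requires_grad=True)
    logits_t = torch.randn(8, 10)
    print(kd_loss_sketch(logits_s, logits_t, temperature=2.0))

Passing reduction="none" instead yields a pointwise loss tensor, which can be multiplied by a loss mask (e.g. to ignore padded tokens) before being reduced by the caller.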

class MGDLoss

Bases: _Loss

PyTorch version of Masked Generative Distillation.

Implements the distillation loss described in the paper: https://arxiv.org/abs/2205.01529.

__init__(num_student_channels, num_teacher_channels, alpha_mgd=1.0, lambda_mgd=0.65)

Constructor.

Parameters:
  • num_student_channels (int) – Number of channels in the student’s feature map.

  • num_teacher_channels (int) – Number of channels in the teacher’s feature map.

  • alpha_mgd (float) – Scalar that the final loss is multiplied by. Defaults to 1.0.

  • lambda_mgd (float) – Masked ratio, i.e. the fraction of the student's feature map that is masked. Defaults to 0.65.

forward(out_s, out_t)

Compute MGD loss on student and teacher feature maps.

Parameters:
  • out_s (Tensor) – Student’s feature map (shape BxCxHxW).

  • out_t (Tensor) – Teacher’s feature map (shape BxCxHxW).
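
To make the mechanism concrete, below is an illustrative sketch of Masked Generative Distillation in the spirit of the paper: a random spatial mask hides roughly a lambda_mgd fraction of the (channel-aligned) student feature map, a small generation network reconstructs the teacher's feature map from what remains, and the reconstruction error is the loss. The align/generation modules, layer sizes, and masking details here are assumptions for illustration, not the library's exact implementation.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MGDLossSketch(nn.Module):
        """Illustrative sketch of Masked Generative Distillation
        (https://arxiv.org/abs/2205.01529); not the library's implementation."""

        def __init__(self, num_student_channels, num_teacher_channels,
                     alpha_mgd=1.0, lambda_mgd=0.65):
            super().__init__()
            self.alpha_mgd = alpha_mgd
            self.lambda_mgd = lambda_mgd
            # Align student channels to teacher channels if they differ.
            self.align = (nn.Conv2d(num_student_channels, num_teacher_channels, 1)
                          if num_student_channels != num_teacher_channels
                          else nn.Identity())
            # Small generation network that reconstructs teacher features
            # from the masked student features (assumed architecture).
            self.generation = nn.Sequential(
                nn.Conv2d(num_teacher_channels, num_teacher_channels, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(num_teacher_channels, num_teacher_channels, 3, padding=1),
            )

        def forward(self, out_s, out_t):
            n, _, h, w = out_t.shape
            out_s = self.align(out_s)
            # Randomly zero out roughly a lambda_mgd fraction of spatial positions.
            mask = (torch.rand(n, 1, h, w, device=out_s.device) > self.lambda_mgd).float()
            reconstructed = self.generation(out_s * mask)
            # Sum-reduced MSE against the teacher, averaged over the batch.
            loss = F.mse_loss(reconstructed, out_t, reduction="sum") / n
            return self.alpha_mgd * loss

    # Example usage with random feature maps of shape BxCxHxW.
    loss_fn = MGDLossSketch(num_student_channels=64, num_teacher_channels=128)
    out_s = torch.randn(2, 64, 16, 16, requires_grad=True)
    out_t = torch.randn(2, 128, 16, 16)
    print(loss_fn(out_s, out_t))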