losses
Different types of distillation losses.
Classes

- LogitsDistillationLoss – KL-Divergence loss on output logits.
- MFTLoss – KL-divergence loss with Minifinetuning threshold modification.
- MGDLoss – PyTorch version of Masked Generative Distillation.
- class LogitsDistillationLoss
Bases: _Loss
KL-Divergence loss on output logits.
This loss implements the distillation loss described in https://arxiv.org/abs/1503.02531.
- __init__(temperature=1.0, reduction='batchmean')
Constructor.
- Parameters:
temperature (float) – Softmax temperature used to soften logits_t and logits_s before computing the loss on them.
reduction (str) – How to reduce the final pointwise loss before returning. Pass "none" to apply your own reduction afterwards, e.g. with loss masks. Defaults to "batchmean".
- forward(logits_s, logits_t)
Compute KD loss on student and teacher logits.
- Parameters:
logits_s (Tensor) – Student’s logits, treated as prediction.
logits_t (Tensor) – Teacher’s logits, treated as label.
- Return type:
Tensor
Note
Assumes class logits dimension is last.
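A minimal usage sketch for this class. The import path is an assumption (adjust it to wherever this losses module lives in your package), and the batch/class sizes are illustrative:

```python
import torch

from losses import LogitsDistillationLoss  # assumed import path

# Softened KD loss between student and teacher logits (class dim last).
criterion = LogitsDistillationLoss(temperature=2.0, reduction="batchmean")

logits_s = torch.randn(8, 100, requires_grad=True)  # student predictions
logits_t = torch.randn(8, 100)                       # teacher targets

loss = criterion(logits_s, logits_t.detach())  # scalar Tensor under "batchmean"
loss.backward()
```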
- class MFTLoss
Bases: _Loss
KL-divergence loss with Minifinetuning threshold modification.
This loss implements the distillation loss described in https://arxiv.org/abs/2506.15702.
- __init__(temperature=1.0, threshold=0.2, reduction='batchmean')
Constructor.
- Parameters:
temperature (float) – Softmax temperature used to soften logits_t and logits_s before computing the MFT loss on them.
threshold (float) – A value used to correct the teacher’s distribution, ensuring that the separation between the correct and incorrect argmax tokens is large enough. Must be in the range [0, 1]. Defaults to 0.2.
reduction (str) – How to reduce the final pointwise loss before returning. Pass "none" to apply your own reduction afterwards, e.g. with loss masks. Defaults to "batchmean".
- forward(logits_s, logits_t, labels)
Compute KD loss on student and teacher logits.
- Parameters:
logits_s (Tensor) – Student’s logits, treated as prediction.
logits_t (Tensor) – Teacher’s logits, treated as training target.
labels (Tensor) – Ground-truth labels, used to prepare the corrected teacher distributions.
- Return type:
Tensor
Note
Assumes class logits dimension is last.
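A minimal usage sketch, assuming token-level logits of shape (batch, seq_len, vocab) with the class dimension last, as the note above requires. The import path and sizes are assumptions:

```python
import torch

from losses import MFTLoss  # assumed import path

criterion = MFTLoss(temperature=1.0, threshold=0.2, reduction="batchmean")

batch, seq_len, vocab = 4, 16, 32000
logits_s = torch.randn(batch, seq_len, vocab, requires_grad=True)  # student logits
logits_t = torch.randn(batch, seq_len, vocab)                      # teacher logits
labels = torch.randint(0, vocab, (batch, seq_len))                 # ground-truth token ids

loss = criterion(logits_s, logits_t.detach(), labels)  # scalar Tensor
loss.backward()
```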
- class MGDLoss
Bases: _Loss
PyTorch version of Masked Generative Distillation.
This loss implements the distillation loss described in https://arxiv.org/abs/2205.01529.
- __init__(num_student_channels, num_teacher_channels, alpha_mgd=1.0, lambda_mgd=0.65)
Constructor.
- Parameters:
num_student_channels (int) – Number of channels in the student’s feature map.
num_teacher_channels (int) – Number of channels in the teacher’s feature map.
alpha_mgd (float) – Scalar that the final loss is multiplied by. Defaults to 1.0.
lambda_mgd (float) – Masked ratio, i.e. the fraction of the feature map that is masked. Defaults to 0.65.
- forward(out_s, out_t)
Compute MGD loss on student and teacher feature maps.
- Parameters:
out_s (Tensor) – Student’s feature map (shape BxCxHxW).
out_t (Tensor) – Teacher’s feature map (shape BxCxHxW).
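A minimal usage sketch. The channel counts, spatial size, and import path are assumptions for illustration; note that the student and teacher feature maps may differ in channel count but are assumed to share the same spatial size:

```python
import torch

from losses import MGDLoss  # assumed import path

criterion = MGDLoss(
    num_student_channels=256,
    num_teacher_channels=512,
    alpha_mgd=1.0,
    lambda_mgd=0.65,
)

out_s = torch.randn(2, 256, 32, 32, requires_grad=True)  # student feature map, BxCxHxW
out_t = torch.randn(2, 512, 32, 32)                       # teacher feature map, BxCxHxW

loss = criterion(out_s, out_t.detach())
loss.backward()
```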