losses
Different types of distillation losses.
Classes
- LogitsDistillationLoss – KL-Divergence loss on output logits.
- MGDLoss – PyTorch version of Masked Generative Distillation.
- class LogitsDistillationLoss
Bases:
_Loss
KL-Divergence loss on output logits.
This class implements the distillation loss described in the paper: https://arxiv.org/abs/1503.02531.
- __init__(temperature=1.0, reduction='batchmean')
Constructor.
- Parameters:
temperature (float) – A value used to soften the logits_t and logits_s before computing loss on them.
reduction (str) – How to reduce the final pointwise loss before returning. Pass "none" to apply your own reduction afterwards, e.g. with loss masks.
- forward(logits_s, logits_t)
Compute KD loss on student and teacher logits.
- Parameters:
logits_s (Tensor) – Student’s logits, treated as prediction.
logits_t (Tensor) – Teacher’s logits, treated as label.
- Return type:
Tensor
Note
Assumes class logits dimension is last.
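For reference, a minimal sketch of the softened-KL computation described above, following the formulation in Hinton et al. (https://arxiv.org/abs/1503.02531). The function name kd_logits_loss is hypothetical and the internals are illustrative, not the library's actual implementation:

```python
import torch
import torch.nn.functional as F

def kd_logits_loss(logits_s: torch.Tensor,
                   logits_t: torch.Tensor,
                   temperature: float = 1.0,
                   reduction: str = "batchmean") -> torch.Tensor:
    """Illustrative softened KL-divergence between student and teacher logits.

    Assumes the class logits dimension is last, as noted above.
    """
    # Soften both distributions with the temperature.
    log_p_s = F.log_softmax(logits_s / temperature, dim=-1)  # student log-probabilities
    p_t = F.softmax(logits_t / temperature, dim=-1)          # teacher probabilities

    # KL(teacher || student); the T^2 factor keeps gradient magnitudes
    # comparable across temperatures (Hinton et al., 2015).
    return F.kl_div(log_p_s, p_t, reduction=reduction) * (temperature ** 2)
```

With reduction="none" the result is a pointwise tensor that can be multiplied by a loss mask before reducing it yourself.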
- class MGDLoss
Bases:
_Loss
PyTorch version of Masked Generative Distillation.
This class implements the distillation loss described in the paper: https://arxiv.org/abs/2205.01529.
- __init__(num_student_channels, num_teacher_channels, alpha_mgd=1.0, lambda_mgd=0.65)
Constructor.
- Parameters:
num_student_channels (int) – Number of channels in the student’s feature map.
num_teacher_channels (int) – Number of channels in the teacher’s feature map.
alpha_mgd (float) – Scalar the final loss is multiplied by. Defaults to 1.0.
lambda_mgd (float) – Masked ratio, i.e. the fraction of student feature-map positions that are masked. Defaults to 0.65.
- forward(out_s, out_t)
Compute MGD loss on student and teacher feature maps.
- Parameters:
out_s (Tensor) – Student’s feature map (shape BxCxHxW).
out_t (Tensor) – Teacher’s feature map (shape BxCxHxW).
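As an implementation and usage sketch, the following illustrates the masked-generative idea from the paper (https://arxiv.org/abs/2205.01529): a fraction lambda_mgd of spatial positions in the channel-aligned student feature map is zeroed, a small generation block reconstructs the teacher feature map, and the reconstruction error is scaled by alpha_mgd. The 1x1 alignment conv and the two-layer generation block here are assumptions for illustration, not this library's exact architecture:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MGDLossSketch(nn.Module):
    """Illustrative sketch of Masked Generative Distillation (arXiv:2205.01529)."""

    def __init__(self, num_student_channels: int, num_teacher_channels: int,
                 alpha_mgd: float = 1.0, lambda_mgd: float = 0.65):
        super().__init__()
        self.alpha_mgd = alpha_mgd
        self.lambda_mgd = lambda_mgd
        # Align student channels to teacher channels if they differ (assumed 1x1 conv).
        self.align = (nn.Conv2d(num_student_channels, num_teacher_channels, kernel_size=1)
                      if num_student_channels != num_teacher_channels else nn.Identity())
        # Generation block that reconstructs teacher features from masked student features.
        self.generation = nn.Sequential(
            nn.Conv2d(num_teacher_channels, num_teacher_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_teacher_channels, num_teacher_channels, kernel_size=3, padding=1),
        )

    def forward(self, out_s: torch.Tensor, out_t: torch.Tensor) -> torch.Tensor:
        n, _, h, w = out_s.shape
        out_s = self.align(out_s)
        # Randomly zero out roughly a lambda_mgd fraction of spatial positions.
        mask = (torch.rand(n, 1, h, w, device=out_s.device) > self.lambda_mgd).to(out_s.dtype)
        reconstructed = self.generation(out_s * mask)
        # Penalise the reconstruction error against the full teacher feature map.
        return self.alpha_mgd * F.mse_loss(reconstructed, out_t)
```

In use, out_s and out_t must share the same spatial size (BxCxHxW); only their channel counts may differ.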