loss_balancers
Basic loss balancers for the distillation task.
Classes

| `DistillationLossBalancer` | Interface for loss balancers. |
| `StaticLossBalancer` | Static weights-based loss aggregation of KD losses. |
- class DistillationLossBalancer
Bases:
Module
Interface for loss balancers.
- __init__()
Constructor.
- abstract forward(loss)
Compute aggregate loss.
- Parameters:
loss (Dict[str, Tensor]) – The loss dict to aggregate. The keys are the class names of the loss functions applied to obtain each loss, suffixed with `_{idx}` for uniqueness. If a student loss is provided to `mtd.DistillationModel.compute_kd_loss`, it will additionally have the key `mtd.loss_balancers.STUDENT_LOSS_KEY`. For example, if the `criterion` argument to `mtd.convert` is `{("mod1_s", "mod1_t"): torch.nn.MSELoss(), ("mod2_s", "mod2_t"): torch.nn.MSELoss()}` and the `student_loss` provided to `mtd.DistillationModel.compute_kd_loss` is not None, then the loss dict here will look like `{"student_loss": torch.tensor(...), "MSELoss_0": torch.tensor(...), "MSELoss_1": torch.tensor(...)}`.
- Returns:
The total loss after balancing the student and KD loss components.
- Return type:
Tensor
- set_student_loss_reduction_fn(student_loss_reduction_fn)
Set the student loss reduction function.
Needed in the special case where the student loss must be reduced prior to balancing.
- Parameters:
student_loss_reduction_fn (Callable[[Any], Tensor]) –
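As a rough sketch of the interface described above (not the real `mtd` implementation: plain Python floats stand in for tensors, and all class names other than `STUDENT_LOSS_KEY` are illustrative), a balancer holds an optional student-loss reduction function and aggregates a keyed loss dict in `forward`:

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict

# Mirrors mtd.loss_balancers.STUDENT_LOSS_KEY from the docs above.
STUDENT_LOSS_KEY = "student_loss"

class BalancerSketch(ABC):
    """Stand-in for DistillationLossBalancer; floats replace Tensors."""

    def __init__(self) -> None:
        # Identity by default; replaced via set_student_loss_reduction_fn.
        self._student_loss_reduction_fn: Callable[[Any], float] = lambda x: x

    def set_student_loss_reduction_fn(self, fn: Callable[[Any], float]) -> None:
        """Reduce (e.g. average) the student loss before balancing."""
        self._student_loss_reduction_fn = fn

    @abstractmethod
    def forward(self, loss: Dict[str, float]) -> float:
        """Aggregate the loss dict into a single scalar."""

class EqualWeightBalancer(BalancerSketch):
    """Trivial concrete balancer: unweighted sum of all components."""

    def forward(self, loss: Dict[str, float]) -> float:
        total = 0.0
        for key, value in loss.items():
            if key == STUDENT_LOSS_KEY:
                value = self._student_loss_reduction_fn(value)
            total += value
        return total

balancer = EqualWeightBalancer()
# Key shapes match the example above: "MSELoss_{idx}" plus "student_loss".
agg = balancer.forward({"student_loss": 1.0, "MSELoss_0": 0.5, "MSELoss_1": 0.25})
# agg == 1.75
```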
- class StaticLossBalancer
Bases:
DistillationLossBalancer
Static weights-based loss aggregation of KD losses.
- __init__(kd_loss_weight=0.5)
Constructor.
- Parameters:
kd_loss_weight (float | List[float]) – The static weight(s) used to balance the knowledge distillation losses against the original student loss. If a float, it is applied to the sum of the KD losses. If a list, the weights correspond to the KD loss keys, in the order specified in the `criterion` argument, and each weight is applied to its corresponding loss value. If the weights do not sum to 1.0, a `student_loss` should be passed into `mtd.DistillationModel.compute_kd_loss`, and the remaining weight (the difference from 1.0) is applied to that loss value.
- Raises:
ValueError – If kd_loss_weight is out of bounds.
- forward(loss)
Compute aggregate loss.
- Parameters:
loss (Dict[str, Tensor]) – The loss dict to aggregate.
- Returns:
The total loss after balancing the student and KD loss components.
- Return type:
Tensor
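The weighting rule described for `StaticLossBalancer` can be sketched as follows (again a hedged, simplified stand-in using floats instead of tensors, not the real `mtd` code): a float weight scales the sum of the KD losses, a list weights each KD loss individually, and whatever weight remains below 1.0 falls to the student loss.

```python
from typing import Dict, List, Union

# Mirrors mtd.loss_balancers.STUDENT_LOSS_KEY from the docs above.
STUDENT_LOSS_KEY = "student_loss"

def static_balance(loss: Dict[str, float],
                   kd_loss_weight: Union[float, List[float]] = 0.5) -> float:
    """Sketch of StaticLossBalancer's aggregation rule."""
    student = loss.get(STUDENT_LOSS_KEY, 0.0)
    # KD losses, in criterion order (dict insertion order).
    kd_losses = [v for k, v in loss.items() if k != STUDENT_LOSS_KEY]
    if isinstance(kd_loss_weight, float):
        # A single float weights the sum of all KD losses.
        kd_total = kd_loss_weight * sum(kd_losses)
        student_weight = 1.0 - kd_loss_weight
    else:
        # A list weights each KD loss individually; the remainder
        # (1.0 - sum(weights)) is applied to the student loss.
        kd_total = sum(w * v for w, v in zip(kd_loss_weight, kd_losses))
        student_weight = 1.0 - sum(kd_loss_weight)
    return kd_total + student_weight * student

losses = {"student_loss": 2.0, "MSELoss_0": 1.0, "MSELoss_1": 3.0}
total = static_balance(losses, kd_loss_weight=0.5)
# 0.5 * (1.0 + 3.0) + (1.0 - 0.5) * 2.0 == 3.0
```

With a list of weights, e.g. `[0.25, 0.25]`, each `MSELoss_{idx}` value gets its own weight and the leftover 0.5 multiplies the student loss.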