loss_balancers

Basic loss balancers for the distillation task.

Classes

DistillationLossBalancer

Interface for loss balancers.

StaticLossBalancer

Loss aggregation of KD losses using static weights.

class DistillationLossBalancer

Bases: Module

Interface for loss balancers.

__init__()

Constructor.

abstract forward(loss)

Compute aggregate loss.

Parameters:

loss (Dict[str, Tensor]) – The loss dict to aggregate. Each key is the class name of the loss function applied to obtain that loss, suffixed with _{idx} for uniqueness. If a student loss is provided to mtd.DistillationModel.compute_kd_loss, it appears under the key mtd.loss_balancers.STUDENT_LOSS_KEY. For example, if the criterion argument to mtd.convert is {("mod1_s", "mod1_t"): torch.nn.MSELoss(), ("mod2_s", "mod2_t"): torch.nn.MSELoss()} and the student_loss provided to mtd.DistillationModel.compute_kd_loss is not None, the loss dict will look like {"student_loss": torch.tensor(...), "MSELoss_0": torch.tensor(...), "MSELoss_1": torch.tensor(...)}.

Returns:

The total loss after balancing the student and KD loss components.

Return type:

Tensor
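
As an illustration of the interface, the sketch below implements a hypothetical subclass that weights every loss component equally. It assumes modelopt.torch.distill is importable as mtd, consistent with the mtd references above; MeanLossBalancer is not part of the library.

    import torch
    import modelopt.torch.distill as mtd

    class MeanLossBalancer(mtd.loss_balancers.DistillationLossBalancer):
        """Hypothetical balancer that weights every loss component equally."""

        def forward(self, loss: dict[str, torch.Tensor]) -> torch.Tensor:
            # `loss` maps keys such as "student_loss" or "MSELoss_0" to
            # scalar tensors; stack and average them into a single Tensor.
            return torch.stack(list(loss.values())).mean()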

set_student_loss_reduction_fn(student_loss_reduction_fn)

Set the student loss reduction function.

Needed in the special case where the student loss must be reduced prior to balancing.

Parameters:

student_loss_reduction_fn (Callable[[Any], Tensor]) – Function which reduces the student loss to a scalar Tensor prior to balancing.
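
A brief usage sketch (balancer here stands for any previously constructed balancer instance):

    # Hypothetical usage: if the training loop produces a per-token student
    # loss, reduce it to a scalar mean before balancing.
    balancer.set_student_loss_reduction_fn(lambda student_loss: student_loss.mean())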

class StaticLossBalancer

Bases: DistillationLossBalancer

Loss aggregation of KD losses using static weights.

__init__(kd_loss_weight=0.5)

Constructor.

Parameters:

kd_loss_weight (float | List[float]) – The static weight(s) applied to balance the knowledge distillation loss against the original student loss. If a float, it is applied to the sum of the KD losses. If a list, its entries correspond one-to-one to the KD losses, in the order specified by the criterion argument, and each weight is applied to its corresponding loss value. If the weights do not sum to 1.0, a student_loss should be passed into mtd.DistillationModel.compute_kd_loss, and the remaining weight is applied to that loss value.

Raises:

ValueError – If kd_loss_weight is out of bounds.
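
For example (a hedged sketch, assuming modelopt.torch.distill is imported as mtd and the two-KD-loss criterion from the forward() example above):

    import modelopt.torch.distill as mtd

    # A single float weights the sum of the KD losses; the remaining
    # 0.5 is applied to the student loss.
    balancer = mtd.loss_balancers.StaticLossBalancer(kd_loss_weight=0.5)

    # A list gives one weight per KD loss, in criterion order
    # (here MSELoss_0 -> 0.3, MSELoss_1 -> 0.2); since 0.3 + 0.2 < 1.0,
    # a student_loss must be passed to compute_kd_loss to receive the
    # remaining 0.5.
    balancer = mtd.loss_balancers.StaticLossBalancer(kd_loss_weight=[0.3, 0.2])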

forward(loss)

Compute aggregate loss.

Parameters:

loss (Dict[str, Tensor]) – The loss dict to aggregate; keys follow the format described in DistillationLossBalancer.forward().

Returns:

The total loss after balancing the student and KD loss components.

Return type:

Tensor
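
As a worked illustration of the aggregation described above (not the library's actual implementation), a scalar kd_loss_weight w yields w * sum(KD losses) + (1 - w) * student loss:

    import torch

    loss = {
        "student_loss": torch.tensor(1.0),
        "MSELoss_0": torch.tensor(0.4),
        "MSELoss_1": torch.tensor(0.6),
    }
    w = 0.5  # kd_loss_weight
    kd_total = sum(v for k, v in loss.items() if k != "student_loss")
    total = w * kd_total + (1.0 - w) * loss["student_loss"]
    # kd_total = 0.4 + 0.6 = 1.0, so total = 0.5 * 1.0 + 0.5 * 1.0 = tensor(1.0)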