How to Use Class Weights with Focal Loss in PyTorch for Imbalanced dataset for MultiClass Classification

Written by Aionlinecourse


In this tutorial, we will learn how to use class weights with focal loss in PyTorch for multi-class classification on an imbalanced dataset.

The first step is to import PyTorch and create a neural network. For this example, we will create a network with two hidden layers of 10 nodes each, using the ReLU activation function. The last layer will have 3 nodes corresponding to the three classes: A, B, and C, as sketched below.
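
A minimal sketch of such a network (the input size of 4 features is an assumption for illustration):

    import torch.nn as nn

    # Two hidden layers of 10 nodes each with ReLU activations,
    # and a 3-node output layer for classes A, B, and C.
    model = nn.Sequential(
        nn.Linear(4, 10),   # 4 input features is an assumed example value
        nn.ReLU(),
        nn.Linear(10, 10),
        nn.ReLU(),
        nn.Linear(10, 3),
    )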

The next step is to define the ground truth labels for each class as well as the weight associated with each class. In this example there are 100 observations of class A, 200 of class B, and 300 of class C. Weighting each class by the inverse of its frequency gives A a weight of 3, B a weight of 1.5, and C a weight of 1, so the rarest class contributes the most to the loss.
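
A sketch of this step in PyTorch (tensor names are illustrative):

    import torch

    counts = torch.tensor([100., 200., 300.])   # observations of A, B, C
    class_weights = counts.max() / counts       # inverse frequency: [3.0, 1.5, 1.0]

With these pieces in place, here are some solutions to the original question.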


Solution 1:

You may find answers to your questions as follows:

  1. Focal loss automatically handles class imbalance, so weights are not required for the focal loss. The alpha and gamma factors handle the class imbalance in the focal loss equation.
  2. There is no need for extra weights, because focal loss handles imbalance through its alpha and gamma modulating factors.
  3. The implementation you mentioned is correct according to the focal loss formula, but I had trouble getting my model to converge with that version, so I used the following implementation from the mmdetection framework (shown here wrapped as a function, with imports added):
    import torch.nn.functional as F
    from mmdet.models.losses.utils import weight_reduce_loss  # mmdetection helper

    def py_sigmoid_focal_loss(pred, target, weight=None, gamma=2.0,
                              alpha=0.25, reduction='mean', avg_factor=None):
        pred_sigmoid = pred.sigmoid()
        target = target.type_as(pred)
        # pt is the probability of the wrong class; pt**gamma down-weights easy examples
        pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
        focal_weight = (alpha * target + (1 - alpha) *
                        (1 - target)) * pt.pow(gamma)
        loss = F.binary_cross_entropy_with_logits(
            pred, target, reduction='none') * focal_weight
        # apply the optional per-element weights and the requested reduction
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss
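
For example, a quick usage sketch (the logits and one-hot targets here are made up for illustration):

    import torch

    logits = torch.randn(8, 3)                          # raw scores for 3 classes
    targets = torch.eye(3)[torch.randint(0, 3, (8,))]   # one-hot ground truth
    loss = py_sigmoid_focal_loss(logits, targets, gamma=2.0, alpha=0.25)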

You can also experiment with other available focal loss versions.


Solution 2:

I think the OP will have gotten their answer by now. I am writing this for other people who might stumble upon this question.

There is one problem in the OP's implementation of focal loss:

  1. F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

In this line, the same alpha value is multiplied with every class output probability, i.e. pt. Additionally, the code doesn't show how pt is obtained. A very good implementation of focal loss can be found here, but that implementation only handles binary classification, since its self.alpha tensor holds just alpha and 1 - alpha for the two classes.

For multi-class or multi-label classification, the self.alpha tensor should contain one element per label. The values could be the inverse label frequency or the inverse normalized label frequency (just be cautious with labels whose frequency is 0, since the inverse is undefined). A sketch of this idea is shown below.
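
A minimal sketch of a multi-class focal loss with a per-class alpha tensor (the class name MultiClassFocalLoss and the inverse-normalized-frequency weighting are illustrative assumptions, not the linked implementation):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MultiClassFocalLoss(nn.Module):
        def __init__(self, alpha, gamma=2.0):
            super().__init__()
            # alpha: one weight per label, e.g. inverse normalized frequency
            self.register_buffer('alpha', alpha)
            self.gamma = gamma

        def forward(self, logits, target):
            # log(pt): log-probability of each sample's true class
            log_pt = F.log_softmax(logits, dim=-1)
            log_pt = log_pt.gather(1, target.unsqueeze(1)).squeeze(1)
            pt = log_pt.exp()
            # pick the alpha matching each sample's true label
            alpha_t = self.alpha[target]
            loss = -alpha_t * (1 - pt) ** self.gamma * log_pt
            return loss.mean()

    counts = torch.tensor([100., 200., 300.])        # label frequencies A, B, C
    alpha = (1.0 / counts) / (1.0 / counts).sum()    # inverse normalized frequency
    criterion = MultiClassFocalLoss(alpha, gamma=2.0)
    loss = criterion(torch.randn(8, 3), torch.randint(0, 3, (8,)))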


Solution 3:

I think the implementation in your question is wrong. The alpha is the class weight.

In cross entropy the class weight is the alpha_t shown in the following expression:

    CE(p_t) = -alpha_t * log(p_t)

You see that it is alpha_t rather than alpha: the weight is indexed by each sample's true class t.

In focal loss the formula is

    FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)

and we can see from this popular PyTorch implementation that alpha acts the same way as a class weight.
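
As a quick sanity check of that claim (a sketch with made-up logits): setting gamma = 0 makes the modulating factor (1 - p_t)**gamma equal to 1, so the per-sample focal loss reduces exactly to PyTorch's weighted cross entropy:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    logits = torch.randn(4, 3)              # made-up logits: 4 samples, 3 classes
    target = torch.tensor([0, 2, 1, 2])
    alpha = torch.tensor([3.0, 1.5, 1.0])   # per-class weights

    # focal loss per sample with gamma = 0: -alpha_t * log(pt)
    log_pt = F.log_softmax(logits, dim=-1).gather(1, target.unsqueeze(1)).squeeze(1)
    focal_g0 = -alpha[target] * log_pt

    # weighted cross entropy per sample: the same quantity
    ce = F.cross_entropy(logits, target, weight=alpha, reduction='none')
    assert torch.allclose(focal_g0, ce)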

Thank you for reading the article. If you face any problem, please comment below.