
Loss function for semantic segmentation using PyTorch (CrossEntropyLoss and BCELoss)

Today I was trying to implement, using PyTorch, the Focal Loss (paperswithcode, original paper) for my semantic segmentation model. Focal Loss is “just” Cross Entropy Loss with some extra sauce: a modulating factor, controlled by γ, that lets you give more weight to the examples that are harder to classify; without it, the optimiser tends to focus on the easy examples because they contribute more to the total loss. To save time, I didn’t even consider writing my own code (although the focal loss is fairly simple) and went straight to Google, where I found this nice example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255):
        super(FocalLoss, self).__init__()
        self.alpha = alpha                # weighting factor α
        self.gamma = gamma                # focusing parameter γ (0 recovers plain cross entropy)
        self.ignore_index = ignore_index  # target value to ignore (e.g. unlabeled pixels)
        self.size_average = size_average  # mean over pixels if True, sum otherwise

    def forward(self, inputs, targets):
        # Per-pixel cross entropy; inputs are raw logits (N, C, H, W),
        # targets are class indices (N, H, W).
        ce_loss = F.cross_entropy(
            inputs, targets, reduction='none', ignore_index=self.ignore_index)
        # pt is the predicted probability of the true class, since ce_loss = -log(pt).
        pt = torch.exp(-ce_loss)
        # Down-weight easy examples (pt close to 1) by the factor (1 - pt)**γ.
        focal_loss = self.alpha * (1 - pt)**self.gamma * ce_loss
        if self.size_average:
            return focal_loss.mean()
        else:
            return focal_loss.sum()
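
For reference, this is the α-balanced focal loss from the original paper, where pt is the predicted probability of the true class: FL(pt) = -α · (1 - pt)^γ · log(pt). With γ = 0 it reduces to plain cross entropy. A minimal usage sketch for the multi-class case (the shapes, number of classes and γ value below are my own assumptions, not part of the example above):

import torch

# Assumed setup: 3-class segmentation with logits of shape (N, C, H, W)
# and integer targets of shape (N, H, W) holding class indices in [0, C).
criterion = FocalLoss(alpha=1, gamma=2)
logits = torch.randn(4, 3, 64, 64, requires_grad=True)
targets = torch.randint(0, 3, (4, 64, 64))
loss = criterion(logits, targets)
loss.backward()
print(loss.item())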

The example above uses PyTorch’s CrossEntropyLoss, and that caused problems because the model I was testing had only one output mask (a single output channel). CrossEntropyLoss would not raise any error; it just gave me 0 or -0. Long story short, I couldn’t find clear usage examples for images (2D), so I created this gist, which you can find embedded below:
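
To illustrate why the loss collapses with a single output channel, and what a BCE-based focal loss can look like, here is a minimal sketch (my own illustration, not the gist; the BinaryFocalLoss class, shapes and γ value are assumptions):

import torch
import torch.nn as nn
import torch.nn.functional as F

# With a single output channel, softmax over dim=1 is always 1, so
# cross_entropy returns 0 (or -0) no matter what the target is.
logits = torch.randn(2, 1, 4, 4)                  # (N, C=1, H, W)
targets = torch.zeros(2, 4, 4, dtype=torch.long)  # the only valid class index is 0
print(F.cross_entropy(logits, targets))           # tensor(0.) or tensor(-0.)

# For a single mask, one way out is a binary focal loss built on
# binary_cross_entropy_with_logits (the numerically stable form of BCELoss).
class BinaryFocalLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0, size_average=True):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, inputs, targets):
        # inputs: raw logits (N, 1, H, W); targets: float mask in {0, 1}, same shape.
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-bce_loss)  # probability assigned to the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
        return focal_loss.mean() if self.size_average else focal_loss.sum()

criterion = BinaryFocalLoss(gamma=2.0)
mask_logits = torch.randn(2, 1, 4, 4, requires_grad=True)
mask_targets = torch.randint(0, 2, (2, 1, 4, 4)).float()
print(criterion(mask_logits, mask_targets))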