Loss function for semantic segmentation using PyTorch (CrossEntropyLoss and BCELoss)
29 Dec 2022

Today I was trying to implement, in PyTorch, the Focal Loss (paperswithcode, original paper) for my semantic segmentation model. Focal Loss is “just” Cross Entropy Loss with some extra sauce: a parameter (γ) that adjusts how much weight you give to the examples that are harder to classify. Without it, your optimiser will focus on the easy examples, because they have more impact on the loss. To save time I didn’t even consider writing my own code (although the focal loss is fairly simple) and went straight to Google, where I found this nice example:
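To get a feel for what γ does before looking at any PyTorch code, here is a quick back-of-the-envelope in plain Python (the example loss values are mine, picked to represent an “easy” and a “hard” pixel):

```python
import math

def focal_weight(ce_loss, gamma):
    # pt is the model's probability for the true class,
    # recovered from the per-pixel cross entropy: ce = -log(pt)
    pt = math.exp(-ce_loss)
    return (1 - pt) ** gamma * ce_loss

# An "easy" pixel (high confidence, low CE) vs a "hard" one (low confidence, high CE)
easy, hard = 0.1, 2.0
for gamma in (0, 2):
    print(f"gamma={gamma}: easy={focal_weight(easy, gamma):.4f}, "
          f"hard={focal_weight(hard, gamma):.4f}")
```

With γ=0 you get plain cross entropy back (the hard pixel weighs 20× the easy one); with γ=2 the easy pixel is crushed to almost nothing, so the hard pixel dominates the loss by a factor of more than a thousand.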
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.size_average = size_average

    def forward(self, inputs, targets):
        # per-pixel cross entropy, kept unreduced so it can be reweighted
        ce_loss = F.cross_entropy(
            inputs, targets, reduction='none', ignore_index=self.ignore_index)
        pt = torch.exp(-ce_loss)  # probability of the true class (ce = -log(pt))
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.size_average:
            return focal_loss.mean()
        else:
            return focal_loss.sum()
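For reference, the same computation written out functionally for a multi-class segmentation batch (the shapes and class count here are made-up example values, not from the original post):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical 3-class segmentation batch:
# logits are (N, C, H, W); targets are (N, H, W) holding class indices.
logits = torch.randn(2, 3, 8, 8)
targets = torch.randint(0, 3, (2, 8, 8))

gamma = 2.0
ce = F.cross_entropy(logits, targets, reduction='none')  # per-pixel CE, (N, H, W)
pt = torch.exp(-ce)                                      # per-pixel true-class probability
focal = ((1 - pt) ** gamma * ce).mean()
print(focal.item())
```

Since (1 − pt)^γ is at most 1, the focal loss is never larger than the plain cross entropy it wraps; with γ=0 the two are identical.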
The example above uses PyTorch’s CrossEntropyLoss, and that caused problems because the model I was testing had only one output mask. CrossEntropyLoss wouldn’t raise any error; it just gave me 0 (or -0). In hindsight that makes sense: cross_entropy applies a softmax over the channel dimension, and with a single channel the predicted probability is always 1, so the loss is identically zero. Long story short, I couldn’t find clear usage examples for images (2D), so I created the gist that you can find embedded below: