import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

try:  # Python 2
    from itertools import ifilterfalse
except ImportError:  # Python 3: alias filterfalse so `mean` below still works
    from itertools import filterfalse as ifilterfalse

eps = 1e-6

def dice_round(preds, trues):
    # Threshold soft predictions to {0, 1} before scoring, as the name implies.
    preds = (preds > 0.5).float()
    return soft_dice_loss(preds, trues)


def iou_round(preds, trues):
    preds = (preds > 0.5).float()
    return jaccard(preds, trues)

def soft_dice_loss(outputs, targets, per_image=False):
    batch_size = outputs.size()[0]
    if not per_image:
        batch_size = 1  # flatten the whole batch into a single sample
    dice_target = targets.contiguous().view(batch_size, -1).float()
    dice_output = outputs.contiguous().view(batch_size, -1)
    intersection = torch.sum(dice_output * dice_target, dim=1)
    union = torch.sum(dice_output, dim=1) + torch.sum(dice_target, dim=1) + eps
    loss = (1 - (2 * intersection + eps) / union).mean()
    return loss

def jaccard(outputs, targets, per_image=False):
    batch_size = outputs.size()[0]
    if not per_image:
        batch_size = 1  # flatten the whole batch into a single sample
    jaccard_target = targets.contiguous().view(batch_size, -1).float()
    jaccard_output = outputs.contiguous().view(batch_size, -1)
    intersection = torch.sum(jaccard_output * jaccard_target, dim=1)
    union = torch.sum(jaccard_output, dim=1) + torch.sum(jaccard_target, dim=1) - intersection + eps
    losses = 1 - (intersection + eps) / union
    return losses.mean()

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True, per_image=False):
        super().__init__()
        self.size_average = size_average  # kept for API compatibility; unused
        self.register_buffer('weight', weight)
        self.per_image = per_image

    def forward(self, input, target):
        return soft_dice_loss(input, target, per_image=self.per_image)

class JaccardLoss(nn.Module):
    def __init__(self, weight=None, size_average=True, per_image=False):
        super().__init__()
        self.size_average = size_average  # kept for API compatibility; unused
        self.register_buffer('weight', weight)
        self.per_image = per_image

    def forward(self, input, target):
        return jaccard(input, target, per_image=self.per_image)

class StableBCELoss(nn.Module):
    """Numerically stable binary cross-entropy on raw logits:
    max(x, 0) - x * y + log(1 + exp(-|x|)), the standard log-sum-exp
    rewriting of BCE-with-logits that avoids overflow in exp().
    """

    def forward(self, input, target):
        input = input.float().view(-1)
        target = target.float().view(-1)
        neg_abs = -input.abs()
        loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
        return loss.mean()
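
# Sanity check (illustrative): StableBCELoss()(x, y) should match
# F.binary_cross_entropy_with_logits(x, y) up to floating-point error,
# since both compute the same log-sum-exp form with mean reduction.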

class ComboLoss(nn.Module):
    """Weighted sum of several segmentation losses.

    `weights` maps a loss name ('bce', 'dice', 'focal', 'jaccard', 'lovasz',
    'lovasz_sigmoid') to its coefficient; losses with weight 0 are skipped.
    """

    def __init__(self, weights, per_image=False):
        super().__init__()
        self.weights = weights
        self.bce = StableBCELoss()
        self.dice = DiceLoss(per_image=False)
        self.jaccard = JaccardLoss(per_image=False)
        self.lovasz = LovaszLoss(per_image=per_image)
        self.lovasz_sigmoid = LovaszLossSigmoid(per_image=per_image)
        self.focal = FocalLoss2d()
        self.mapping = {'bce': self.bce,
                        'dice': self.dice,
                        'focal': self.focal,
                        'jaccard': self.jaccard,
                        'lovasz': self.lovasz,
                        'lovasz_sigmoid': self.lovasz_sigmoid}
        # These losses consume probabilities; the rest consume raw logits.
        self.expect_sigmoid = {'dice', 'focal', 'jaccard', 'lovasz_sigmoid'}
        self.values = {}  # per-loss values from the last forward pass

    def forward(self, outputs, targets):
        loss = 0
        weights = self.weights
        sigmoid_input = torch.sigmoid(outputs)
        for k, v in weights.items():
            if not v:
                continue
            val = self.mapping[k](sigmoid_input if k in self.expect_sigmoid else outputs, targets)
            self.values[k] = val
            loss += self.weights[k] * val
        return loss
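
# Illustrative usage sketch for ComboLoss (shapes and weights here are
# assumptions for demonstration, not fixed by this module):
#
#     criterion = ComboLoss(weights={'bce': 1.0, 'dice': 1.0})
#     logits = model(images)           # raw logits, e.g. [B, H, W]
#     loss = criterion(logits, masks)  # masks are binary {0, 1} tensors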

def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts.float() - gt_sorted.float().cumsum(0)
    union = gts.float() + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard
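
# Worked example for lovasz_grad (hand-checked, illustrative): for
# gt_sorted = [1, 1, 0, 1],
#   intersection = 3 - cumsum(gt)     = [2, 1, 1, 0]
#   union        = 3 + cumsum(1 - gt) = [3, 3, 4, 4]
#   jaccard      = 1 - inter / union  = [1/3, 2/3, 3/4, 1]
# and after differencing, the returned gradient is [1/3, 1/3, 1/12, 1/4].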

def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
    logits: [B, H, W] Tensor, logits at each pixel (between -infinity and +infinity)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    per_image: compute the loss per image instead of per batch
    ignore: void class id
    """
    if per_image:
        loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
                    for log, lab in zip(logits, labels))
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss
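
# Usage sketch for lovasz_hinge (shapes are assumptions): call it as
# lovasz_hinge(logits, masks) with logits [B, H, W] and masks in {0, 1};
# pass ignore=255 to drop void pixels before sorting the hinge errors.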

def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
    logits: [P] Tensor, logits at each prediction (between -infinity and +infinity)
    labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = 1. - logits * signs
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), grad)
    return loss

def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = (labels != ignore)
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels

def lovasz_sigmoid(probas, labels, per_image=False, ignore=None):
    """
    Binary Lovasz loss on sigmoid probabilities
    probas: [B, H, W] Tensor, probabilities at each pixel (between 0 and 1)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    per_image: compute the loss per image instead of per batch
    ignore: void class labels
    """
    if per_image:
        loss = mean(lovasz_sigmoid_flat(*flatten_binary_scores(prob.unsqueeze(0), lab.unsqueeze(0), ignore))
                    for prob, lab in zip(probas, labels))
    else:
        loss = lovasz_sigmoid_flat(*flatten_binary_scores(probas, labels, ignore))
    return loss

def lovasz_sigmoid_flat(probas, labels):
    """
    Binary Lovasz loss on sigmoid probabilities
    probas: [P] Tensor, probabilities at each prediction (between 0 and 1)
    labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    fg = labels.float()
    errors = (fg - probas).abs()
    errors_sorted, perm = torch.sort(errors, 0, descending=True)
    fg_sorted = fg[perm]
    loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
    return loss

def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(np.isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n
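
# Examples (illustrative): mean([1.0, float('nan'), 3.0], ignore_nan=True)
# returns 2.0, while mean([], empty='raise') raises ValueError instead of
# returning the default 0.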

class LovaszLoss(nn.Module):
    def __init__(self, ignore_index=255, per_image=True):
        super().__init__()
        self.ignore_index = ignore_index
        self.per_image = per_image

    def forward(self, outputs, targets):
        outputs = outputs.contiguous()
        targets = targets.contiguous()
        return lovasz_hinge(outputs, targets, per_image=self.per_image, ignore=self.ignore_index)

class LovaszLossSigmoid(nn.Module):
    def __init__(self, ignore_index=255, per_image=True):
        super().__init__()
        self.ignore_index = ignore_index
        self.per_image = per_image

    def forward(self, outputs, targets):
        outputs = outputs.contiguous()
        targets = targets.contiguous()
        return lovasz_sigmoid(outputs, targets, per_image=self.per_image, ignore=self.ignore_index)

class FocalLoss2d(nn.Module):
    def __init__(self, gamma=2, ignore_index=255):
        super().__init__()
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, outputs, targets):
        # Expects probabilities (sigmoid already applied), not raw logits.
        outputs = outputs.contiguous()
        targets = targets.contiguous()

        non_ignored = targets.view(-1) != self.ignore_index
        targets = targets.view(-1)[non_ignored].float()
        outputs = outputs.view(-1)[non_ignored]
        outputs = torch.clamp(outputs, eps, 1. - eps)
        targets = torch.clamp(targets, eps, 1. - eps)
        # pt is the predicted probability of the true class.
        pt = (1 - targets) * (1 - outputs) + targets * outputs
        return (-(1. - pt) ** self.gamma * torch.log(pt)).mean()
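

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, assuming binary masks and [B, H, W] logits
# (the shapes and loss weights below are illustrative, not fixed by this
# module).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    torch.manual_seed(0)
    logits = torch.randn(2, 64, 64, requires_grad=True)  # raw logits
    masks = (torch.rand(2, 64, 64) > 0.5).float()        # binary targets

    criterion = ComboLoss(weights={'bce': 1.0, 'dice': 1.0, 'focal': 1.0})
    loss = criterion(logits, masks)
    loss.backward()
    print('combo loss: %.4f' % loss.item())
    print({k: float(v) for k, v in criterion.values.items()})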