# climax-xview / losses.py
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable

try:
    from itertools import ifilterfalse
except ImportError:  # py3: ifilterfalse was renamed to filterfalse
    from itertools import filterfalse as ifilterfalse

eps = 1e-6

def dice_round(preds, trues):
    preds = preds.float()
    return soft_dice_loss(preds, trues)


def iou_round(preds, trues):
    preds = preds.float()
    return jaccard(preds, trues)

def soft_dice_loss(outputs, targets, per_image=False):
    """Soft dice loss on probabilities; with per_image=True the dice is
    computed per sample and averaged, otherwise over the whole batch."""
    batch_size = outputs.size()[0]
    if not per_image:
        batch_size = 1
    dice_target = targets.contiguous().view(batch_size, -1).float()
    dice_output = outputs.contiguous().view(batch_size, -1)
    intersection = torch.sum(dice_output * dice_target, dim=1)
    union = torch.sum(dice_output, dim=1) + torch.sum(dice_target, dim=1) + eps
    loss = (1 - (2 * intersection + eps) / union).mean()
    return loss
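
# Example usage (a minimal sketch; the shapes are illustrative assumptions,
# and the function expects probabilities, e.g. after a sigmoid):
#   probs = torch.sigmoid(logits)        # [B, 1, H, W]
#   loss = soft_dice_loss(probs, masks)  # scalar over the whole batch
#   loss = soft_dice_loss(probs, masks, per_image=True)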

def jaccard(outputs, targets, per_image=False):
    batch_size = outputs.size()[0]
    if not per_image:
        batch_size = 1
    jaccard_target = targets.contiguous().view(batch_size, -1).float()
    jaccard_output = outputs.contiguous().view(batch_size, -1)
    intersection = torch.sum(jaccard_output * jaccard_target, dim=1)
    union = torch.sum(jaccard_output, dim=1) + torch.sum(jaccard_target, dim=1) - intersection + eps
    losses = 1 - (intersection + eps) / union
    return losses.mean()

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True, per_image=False):
        super().__init__()
        self.size_average = size_average
        self.register_buffer('weight', weight)
        self.per_image = per_image

    def forward(self, input, target):
        return soft_dice_loss(input, target, per_image=self.per_image)


class JaccardLoss(nn.Module):
    def __init__(self, weight=None, size_average=True, per_image=False):
        super().__init__()
        self.size_average = size_average
        self.register_buffer('weight', weight)
        self.per_image = per_image

    def forward(self, input, target):
        return jaccard(input, target, per_image=self.per_image)

class StableBCELoss(nn.Module):
    def __init__(self):
        super(StableBCELoss, self).__init__()

    def forward(self, input, target):
        input = input.float().view(-1)
        target = target.float().view(-1)
        neg_abs = -input.abs()
        # numerically stable BCE with logits:
        # max(x, 0) - x * z + log(1 + exp(-|x|))
        loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
        return loss.mean()
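
# Sanity check (a sketch, not in the original file): the expression above is
# the standard numerically stable form of BCE-with-logits, so it should match
# F.binary_cross_entropy_with_logits on random inputs:
#   x = torch.randn(10)
#   z = (torch.rand(10) > 0.5).float()
#   assert torch.allclose(StableBCELoss()(x, z),
#                         F.binary_cross_entropy_with_logits(x, z))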

class ComboLoss(nn.Module):
    def __init__(self, weights, per_image=False):
        super().__init__()
        self.weights = weights
        self.bce = StableBCELoss()
        self.dice = DiceLoss(per_image=False)
        self.jaccard = JaccardLoss(per_image=False)
        self.lovasz = LovaszLoss(per_image=per_image)
        self.lovasz_sigmoid = LovaszLossSigmoid(per_image=per_image)
        self.focal = FocalLoss2d()
        self.mapping = {'bce': self.bce,
                        'dice': self.dice,
                        'focal': self.focal,
                        'jaccard': self.jaccard,
                        'lovasz': self.lovasz,
                        'lovasz_sigmoid': self.lovasz_sigmoid}
        self.expect_sigmoid = {'dice', 'focal', 'jaccard', 'lovasz_sigmoid'}
        self.values = {}

    def forward(self, outputs, targets):
        loss = 0
        weights = self.weights
        sigmoid_input = torch.sigmoid(outputs)
        for k, v in weights.items():
            if not v:
                continue
            val = self.mapping[k](sigmoid_input if k in self.expect_sigmoid else outputs, targets)
            self.values[k] = val
            loss += self.weights[k] * val
        return loss
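
# Example usage (hypothetical weighting; any subset of the mapping keys can be
# enabled, and a zero or None weight disables that term):
#   criterion = ComboLoss(weights={'bce': 1.0, 'dice': 1.0, 'focal': 0.5})
#   loss = criterion(raw_logits, masks)  # sigmoid is applied internally for
#                                        # the losses in expect_sigmoid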

def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts.float() - gt_sorted.float().cumsum(0)
    union = gts.float() + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:  # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard
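
# Worked example (a toy sorted ground-truth vector, purely illustrative):
#   gt_sorted = [1, 1, 0, 1]            -> gts = 3
#   intersection = 3 - cumsum(gt)       = [2, 1, 1, 0]
#   union = 3 + cumsum(1 - gt)          = [3, 3, 4, 4]
#   jaccard = 1 - intersection / union  = [1/3, 2/3, 3/4, 1]
#   after first-differencing:             [1/3, 1/3, 1/12, 1/4]
#   These are the per-pixel weights of the Lovasz extension; they sum to 1.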

def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    r"""
    Binary Lovasz hinge loss
    logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    per_image: compute the loss per image instead of per batch
    ignore: void class id
    """
    if per_image:
        loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
                    for log, lab in zip(logits, labels))
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss

def lovasz_hinge_flat(logits, labels):
    r"""
    Binary Lovasz hinge loss
    logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
    labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss
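
# Example usage (a minimal sketch with assumed [B, H, W] shapes, matching the
# docstrings above):
#   logits = torch.randn(4, 32, 32, requires_grad=True)
#   labels = (torch.rand(4, 32, 32) > 0.5).long()
#   loss = lovasz_hinge(logits, labels, per_image=True)
#   loss.backward()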

def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Removes labels equal to 'ignore'
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = (labels != ignore)
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels

def lovasz_sigmoid(probas, labels, per_image=False, ignore=None):
    """
    Binary Lovasz loss on sigmoid probabilities
    probas: [B, H, W] Variable, probabilities at each prediction (between 0 and 1)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    per_image: compute the loss per image instead of per batch
    ignore: void class labels
    """
    if per_image:
        loss = mean(lovasz_sigmoid_flat(*flatten_binary_scores(prob.unsqueeze(0), lab.unsqueeze(0), ignore))
                    for prob, lab in zip(probas, labels))
    else:
        loss = lovasz_sigmoid_flat(*flatten_binary_scores(probas, labels, ignore))
    return loss

def lovasz_sigmoid_flat(probas, labels):
    """
    Binary Lovasz loss on sigmoid probabilities
    probas: [P] Variable, probabilities at each prediction (between 0 and 1)
    labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    fg = labels.float()
    errors = (Variable(fg) - probas).abs()
    errors_sorted, perm = torch.sort(errors, 0, descending=True)
    perm = perm.data
    fg_sorted = fg[perm]
    loss = torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))
    return loss

def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(np.isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n
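
# Examples (illustrative; `mean` accepts any iterable, unlike np.mean):
#   mean(x * x for x in range(3))                    # -> 5/3
#   mean([0.5, float('nan'), 1.5], ignore_nan=True)  # -> 1.0
#   mean([], empty=0)                                # -> 0 on empty input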

class LovaszLoss(nn.Module):
    def __init__(self, ignore_index=255, per_image=True):
        super().__init__()
        self.ignore_index = ignore_index
        self.per_image = per_image

    def forward(self, outputs, targets):
        outputs = outputs.contiguous()
        targets = targets.contiguous()
        return lovasz_hinge(outputs, targets, per_image=self.per_image, ignore=self.ignore_index)


class LovaszLossSigmoid(nn.Module):
    def __init__(self, ignore_index=255, per_image=True):
        super().__init__()
        self.ignore_index = ignore_index
        self.per_image = per_image

    def forward(self, outputs, targets):
        outputs = outputs.contiguous()
        targets = targets.contiguous()
        return lovasz_sigmoid(outputs, targets, per_image=self.per_image, ignore=self.ignore_index)

class FocalLoss2d(nn.Module):
    def __init__(self, gamma=2, ignore_index=255):
        super().__init__()
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, outputs, targets):
        # `outputs` are expected to be probabilities in [0, 1]
        # (ComboLoss applies the sigmoid before calling this loss).
        outputs = outputs.contiguous()
        targets = targets.contiguous()
        non_ignored = targets.view(-1) != self.ignore_index
        targets = targets.view(-1)[non_ignored].float()
        outputs = outputs.contiguous().view(-1)[non_ignored]
        outputs = torch.clamp(outputs, eps, 1. - eps)
        targets = torch.clamp(targets, eps, 1. - eps)
        # pt is the probability assigned to the true class
        pt = (1 - targets) * (1 - outputs) + targets * outputs
        return (-(1. - pt) ** self.gamma * torch.log(pt)).mean()
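
if __name__ == '__main__':
    # Minimal smoke test (illustrative; not part of the original training code).
    # Random logits/masks just to confirm the combined loss runs end to end.
    torch.manual_seed(0)
    logits = torch.randn(2, 1, 8, 8, requires_grad=True)
    masks = (torch.rand(2, 1, 8, 8) > 0.5).float()
    criterion = ComboLoss(weights={'bce': 1.0, 'dice': 1.0, 'focal': 0.5})
    loss = criterion(logits, masks)
    loss.backward()
    print('combo loss:', loss.item())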