import torch | |
import torch.nn as nn | |
class DiceLoss(nn.Module):
    """Dice Loss for segmentation.

    Computes ``1 - Dice`` between ``sigmoid(inputs)`` and ``targets``. Both
    tensors are flattened, so one Dice score is computed over every element
    of the batch, not per sample.

    Args:
        smooth: Additive smoothing constant applied to both numerator and
            denominator. It keeps the ratio defined (and equal to 1) when
            both prediction and target are empty.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return the scalar Dice loss.

        Args:
            inputs: Raw logits. ``sigmoid`` is applied here, so do not pass
                probabilities. Any shape with the same element count as
                ``targets`` is accepted.
            targets: Ground-truth mask. It is cast to float, so bool/int
                masks are accepted. Values presumably lie in {0, 1} — confirm
                against callers.

        Returns:
            A 0-dim tensor equal to ``1 - dice``.
        """
        # reshape (not view): view() raises on non-contiguous tensors, e.g.
        # permuted/transposed masks. sigmoid preserves the input's strides,
        # so permuted logits would fail as well.
        probs = torch.sigmoid(inputs).reshape(-1)
        targets = targets.reshape(-1).float()

        intersection = (probs * targets).sum()
        dice = (2.0 * intersection + self.smooth) / (
            probs.sum() + targets.sum() + self.smooth
        )
        return 1 - dice