File size: 492 Bytes
4c1e73e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
    """Soft Dice loss for segmentation, computed from raw logits.

    The logits are passed through a sigmoid, both tensors are flattened to
    one dimension, and a single Dice coefficient is computed over every
    element at once (there is no per-sample or per-channel averaging).
    The loss is ``1 - dice``.
    """

    def __init__(self, smooth: float = 1.0) -> None:
        """Create the loss.

        Args:
            smooth: Constant added to both the numerator and the denominator
                of the Dice ratio. It keeps the ratio defined when both the
                prediction sum and the target sum are zero.
        """
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return ``1 - dice`` for the given logits and targets.

        Args:
            inputs: Raw (pre-sigmoid) scores; the sigmoid is applied here.
            targets: Ground-truth values; cast to float before use.
                NOTE(review): presumably a 0/1 mask with the same number of
                elements as ``inputs`` -- confirm against callers.

        Returns:
            A zero-dimensional tensor holding the loss.
        """
        # reshape(-1) rather than view(-1): view raises on non-contiguous
        # tensors (e.g. after transpose/permute), while reshape returns a
        # view when it can and copies only when it has to.
        probs = torch.sigmoid(inputs).reshape(-1)
        flat_targets = targets.reshape(-1).float()

        intersection = (probs * flat_targets).sum()
        dice = (2.0 * intersection + self.smooth) / (
            probs.sum() + flat_targets.sum() + self.smooth
        )
        return 1 - dice
|