import torch.nn as nn | |
import torch.nn.functional as F | |
class KLDivLoss(nn.Module):
    """Kullback-Leibler divergence loss.

    Thin ``nn.Module`` wrapper around ``torch.nn.functional.kl_div`` that
    remembers the reduction mode chosen at construction time, so the loss
    can be used like any other module.

    Args:
        reduction: Reduction mode forwarded unchanged to ``F.kl_div`` on
            every call. Defaults to ``'batchmean'``.
    """

    def __init__(self, reduction: str = 'batchmean') -> None:
        super().__init__()
        # Stored as a plain attribute; it is read again on each forward call.
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the KL divergence between ``targets`` and ``inputs``.

        Both arguments are passed straight through to ``F.kl_div``; no
        normalisation or log transform is applied here.

        Args:
            inputs: Predictions. Per the ``F.kl_div`` contract these are
                expected to already be log-probabilities.
            targets: Reference distribution. ``F.kl_div`` treats these as
                plain probabilities by default.

        Returns:
            Whatever ``F.kl_div`` produces for ``self.reduction``.
        """
        return F.kl_div(inputs, targets, reduction=self.reduction)