torchlosses / kldiv.py
GenAIDevTOProd's picture
Upload folder using huggingface_hub
4c1e73e verified
raw
history blame contribute delete
343 Bytes
import torch.nn as nn
import torch.nn.functional as F
class KLDivLoss(nn.Module):
    """Kullback-Leibler divergence loss, a thin wrapper around ``F.kl_div``.

    Following ``F.kl_div``'s convention, ``inputs`` must be
    *log*-probabilities (e.g. the output of ``log_softmax``). ``targets``
    are probabilities by default, or log-probabilities when
    ``log_target=True``. Passing raw probabilities as ``inputs`` does not
    raise an error; it silently computes the wrong quantity.

    Args:
        reduction: Reduction mode forwarded to ``F.kl_div``:
            ``'none'`` keeps per-element losses. ``'batchmean'`` sums the
            loss and divides by the batch size (``inputs.size(0)``), which
            matches the mathematical KL definition. ``'sum'`` sums over all
            elements. ``'mean'`` averages over all elements; PyTorch warns
            that this does not match the KL definition.
        log_target: If True, ``targets`` are interpreted as
            log-probabilities, which avoids an ``exp``/``log`` round trip
            when the targets are already in log space. Keyword-only.
            Defaults to False, which preserves the original behaviour.
    """

    def __init__(self, reduction='batchmean', *, log_target=False):
        super().__init__()
        self.reduction = reduction
        self.log_target = log_target

    def forward(self, inputs, targets):
        """Compute ``KL(targets || exp(inputs))``, reduced per ``self.reduction``.

        Args:
            inputs: Log-probabilities of the predicted distribution.
            targets: Target distribution with the same shape as ``inputs``.
                These are probabilities, or log-probabilities if
                ``self.log_target`` is True.

        Returns:
            The loss tensor produced by ``F.kl_div`` with this module's
            ``reduction`` and ``log_target`` settings.
        """
        return F.kl_div(
            inputs,
            targets,
            reduction=self.reduction,
            log_target=self.log_target,
        )

    def extra_repr(self):
        # Shown inside repr(module), mirroring PyTorch's built-in loss modules.
        return f"reduction={self.reduction!r}, log_target={self.log_target}"