File size: 295 Bytes
4c1e73e
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
import torch.nn.functional as F

class CosineEmbeddingLoss:
    """Cosine embedding loss for similarity learning.

    Thin callable wrapper around ``torch.nn.functional.cosine_embedding_loss``
    that stores the loss hyper-parameters once so the object can be passed
    around like a loss function.

    Per sample, with ``cos = cosine_similarity(x1, x2)``:

    * ``label == 1``  -> ``1 - cos``            (pull the pair together)
    * ``label == -1`` -> ``max(0, cos - margin)`` (push the pair apart)

    Args:
        margin: Similarity threshold applied to dissimilar (``-1``) pairs.
            PyTorch documents values in ``[-1, 1]`` as meaningful.
        reduction: How per-sample losses are combined: ``"mean"`` (default,
            matches the previous behavior), ``"sum"``, or ``"none"`` for the
            unreduced per-sample tensor. Invalid values are rejected by
            PyTorch itself when the loss is computed.
    """

    def __init__(self, margin=0.0, reduction="mean"):
        self.margin = margin
        self.reduction = reduction

    def __call__(self, x1, x2, label):
        """Compute the loss for a batch of embedding pairs.

        Args:
            x1: First batch of embeddings.
            x2: Second batch of embeddings, same shape as ``x1``.
            label: Tensor of ``1`` (similar) / ``-1`` (dissimilar) targets.

        Returns:
            The loss tensor, reduced according to ``self.reduction``.
        """
        return F.cosine_embedding_loss(
            x1, x2, label, margin=self.margin, reduction=self.reduction
        )

    def __repr__(self):
        return (
            f"{type(self).__name__}(margin={self.margin!r}, "
            f"reduction={self.reduction!r})"
        )