import torch
import torch.nn as nn

class StyleTransferLoss(nn.Module):
    """Content/style loss computations for neural style transfer (Gatys et al. style).

    Holds a feature-extraction ``model`` together with the content and style
    reference images, all moved to a common device, and exposes content,
    style, and total loss helpers.

    Notes
    -----
    * This module intentionally defines no ``forward``; callers invoke the
      loss methods directly.
    * ``model`` is assumed to (a) map an image tensor to a dict of feature
      maps keyed by layer name, including a ``"block4"`` content layer, and
      (b) expose a ``style_layers`` iterable of those keys — TODO confirm
      against the model class used by callers.
    """

    def __init__(self, model, content_img, style_img, device=None):
        """Store model and reference images on the target device.

        Args:
            model: Feature extractor (see class notes for the assumed API).
            content_img: Content reference image tensor.
            style_img: Style reference image tensor.
            device: Target device. Defaults to ``"cuda"`` when available,
                otherwise ``"cpu"`` (the previous hard-coded ``"cuda"``
                default crashed on CPU-only machines).
        """
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.content_img = content_img.to(device)
        self.style_img = style_img.to(device)
        self.model = model.to(device)

    def gram_matrix(self, feature_maps):
        """Return the normalized Gram matrix of ``feature_maps``.

        Args:
            feature_maps: Tensor of shape (B, C, H, W).

        Returns:
            (B*C, B*C) tensor of channel correlations, normalized by the
            total number of elements so the loss scale is independent of
            feature-map size (as in the official PyTorch style-transfer
            tutorial).

        NOTE(review): the view collapses the batch dimension into the
        channel axis, so for B > 1 features from different images are
        correlated together. Correct for the usual B == 1 case — confirm
        callers never pass batched inputs.
        """
        B, C, H, W = feature_maps.size()
        features = feature_maps.view(B * C, H * W)
        G = torch.mm(features, features.t())
        # Normalize by total elements
        return G.div(B * C * H * W)

    def get_features(self, image):
        """Run ``image`` through the model and return its feature dict."""
        return self.model(image)

    def content_loss(self, target_features, content_features):
        """Mean squared error between target and content feature tensors."""
        return torch.mean((target_features - content_features) ** 2)

    def style_loss(self, target_features, style_features):
        """Sum of MSEs between Gram matrices over the model's style layers.

        Args:
            target_features: Dict of feature maps for the image being optimized.
            style_features: Dict of feature maps for the style reference.

        Returns:
            Scalar tensor accumulating one Gram-matrix MSE per style layer.
        """
        loss = 0.0
        for key in self.model.style_layers:
            target_gram = self.gram_matrix(target_features[key])
            style_gram = self.gram_matrix(style_features[key])
            loss += torch.mean((target_gram - style_gram) ** 2)
        return loss

    def total_loss(
        self, target_features, content_features, style_features, alpha=1, beta=1e8
    ):
        """Weighted sum ``alpha * content_loss + beta * style_loss``.

        Content is compared at the ``"block4"`` layer only; style is
        compared over all ``model.style_layers``. The large default ``beta``
        compensates for the small magnitude of normalized Gram differences.
        """
        content = self.content_loss(
            target_features["block4"], content_features["block4"]
        )
        style = self.style_loss(target_features, style_features)

        return alpha * content + beta * style