import torch
import torch.nn as nn


class StyleTransferLoss(nn.Module):
    def __init__(self, model, content_img, style_img, device="cuda"):
        super(StyleTransferLoss, self).__init__()
        self.device = device
        self.content_img = content_img.to(device)
        self.style_img = style_img.to(device)
        self.model = model.to(device)

    def gram_matrix(self, feature_maps):
        """
        Calculate Gram matrix for style features
        """
        B, C, H, W = feature_maps.size()
        features = feature_maps.view(B * C, H * W)
        G = torch.mm(features, features.t())
        # Normalize by the total number of elements
        return G.div(B * C * H * W)

    def get_features(self, image):
        """
        Get content and style features from the image
        """
        return self.model(image)

    def content_loss(self, target_features, content_features):
        """
        Calculate content loss (mean squared error) between target and content features
        """
        return torch.mean((target_features - content_features) ** 2)

    def style_loss(self, target_features, style_features):
        """
        Calculate style loss between target and style features
        """
        loss = 0.0
        for key in self.model.style_layers:
            target_gram = self.gram_matrix(target_features[key])
            style_gram = self.gram_matrix(style_features[key])
            loss += torch.mean((target_gram - style_gram) ** 2)
        return loss

    def total_loss(
        self, target_features, content_features, style_features, alpha=1, beta=1e8
    ):
        """
        Calculate total loss (weighted sum of content and style losses)
        """
        content = self.content_loss(
            target_features["block4"], content_features["block4"]
        )
        style = self.style_loss(target_features, style_features)
        return alpha * content + beta * style
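

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original file).
# StyleTransferLoss expects a feature extractor whose forward() returns a dict
# of feature maps keyed "block1"..."block5" and which exposes a `style_layers`
# list; the VGG19-based extractor below is one such hypothetical choice. The
# layer indices follow the common Gatys-style convention (first conv of each
# block); the original author's extractor and layer choices may differ.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torchvision.models import vgg19

    class VGGFeatures(nn.Module):
        # VGG19 `features` indices -> block names used by StyleTransferLoss.
        LAYER_MAP = {
            "0": "block1", "5": "block2", "10": "block3",
            "19": "block4", "28": "block5",
        }
        style_layers = ["block1", "block2", "block3", "block4", "block5"]

        def __init__(self):
            super().__init__()
            # weights="DEFAULT" loads pretrained ImageNet weights
            # (torchvision >= 0.13); pass weights=None to skip the download.
            self.vgg = vgg19(weights="DEFAULT").features.eval()
            for p in self.vgg.parameters():
                p.requires_grad_(False)

        def forward(self, x):
            feats = {}
            for name, layer in self.vgg.named_children():
                x = layer(x)
                if name in self.LAYER_MAP:
                    feats[self.LAYER_MAP[name]] = x
            return feats

    device = "cuda" if torch.cuda.is_available() else "cpu"
    content_img = torch.rand(1, 3, 256, 256)  # stand-in for a real image
    style_img = torch.rand(1, 3, 256, 256)    # stand-in for a real image

    criterion = StyleTransferLoss(VGGFeatures(), content_img, style_img, device)

    # Optimize a copy of the content image so its features stay close to the
    # content image while its Gram matrices approach the style image's.
    target = criterion.content_img.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([target], lr=0.01)
    content_feats = criterion.get_features(criterion.content_img)
    style_feats = criterion.get_features(criterion.style_img)

    for step in range(100):
        optimizer.zero_grad()
        target_feats = criterion.get_features(target)
        loss = criterion.total_loss(target_feats, content_feats, style_feats)
        loss.backward()
        optimizer.step()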