import torch
import torch.nn as nn


class StyleTransferLoss(nn.Module):
    def __init__(self, model, content_img, style_img, device="cuda"):
        super().__init__()
        self.device = device
        self.content_img = content_img.to(device)
        self.style_img = style_img.to(device)
        self.model = model.to(device)

    def gram_matrix(self, feature_maps):
        """
        Calculate the Gram matrix of a set of feature maps.

        The Gram matrix holds the correlations between feature channels,
        which capture style independently of spatial layout. Note that
        flattening to (B * C, H * W) mixes images across the batch, so
        this assumes a batch size of 1.
        """
        B, C, H, W = feature_maps.size()
        features = feature_maps.view(B * C, H * W)
        G = torch.mm(features, features.t())
        # Normalize so the loss scale is independent of feature-map size.
        return G.div(B * C * H * W)

    def get_features(self, image):
        """
        Get content and style features from the image. The wrapped model
        is expected to return a dict mapping layer names to activations
        (see the usage sketch below for one such model).
        """
        return self.model(image)

    def content_loss(self, target_features, content_features):
        """
        Calculate the content loss (mean squared error) between target
        and content feature maps.
        """
        return torch.mean((target_features - content_features) ** 2)

    def style_loss(self, target_features, style_features):
        """
        Calculate the style loss: the mean squared error between the Gram
        matrices of the target and style features, summed over the model's
        style layers.
        """
        loss = 0.0
        for key in self.model.style_layers:
            target_gram = self.gram_matrix(target_features[key])
            style_gram = self.gram_matrix(style_features[key])
            loss += torch.mean((target_gram - style_gram) ** 2)
        return loss

    def total_loss(
        self, target_features, content_features, style_features, alpha=1, beta=1e8
    ):
        """
        Calculate the total loss: a weighted sum of the content and style
        losses. Content is compared at a single deep layer ("block4");
        alpha and beta trade off content preservation against style
        matching. A usage sketch follows the class.
        """
        content = self.content_loss(
            target_features["block4"], content_features["block4"]
        )
        style = self.style_loss(target_features, style_features)
        return alpha * content + beta * style
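

# ---------------------------------------------------------------------------
# Usage sketch, part 1: a feature extractor. StyleTransferLoss expects the
# wrapped model to return a dict of activations keyed by layer name, to
# expose a `style_layers` attribute, and to provide a "block4" entry for
# content. `VGGFeatures` below is a hypothetical wrapper meeting that
# interface, built on torchvision's VGG-19; the slice indices follow the
# common relu1_1..relu5_1 layer choice and are an assumption, not part of
# the original code.
# ---------------------------------------------------------------------------
import torchvision.models as models


class VGGFeatures(nn.Module):
    # Layers whose Gram matrices define the style loss.
    style_layers = ["block1", "block2", "block3", "block4", "block5"]

    def __init__(self):
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval()
        for p in vgg.parameters():
            p.requires_grad_(False)  # the extractor stays frozen
        # Slice VGG-19 so each block ends at relu1_1 .. relu5_1.
        self.blocks = nn.ModuleDict({
            "block1": vgg[:2],
            "block2": vgg[2:7],
            "block3": vgg[7:12],
            "block4": vgg[12:21],
            "block5": vgg[21:30],
        })

    def forward(self, x):
        features = {}
        for name, block in self.blocks.items():
            x = block(x)  # blocks are applied in sequence
            features[name] = x
        return features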
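

# ---------------------------------------------------------------------------
# Usage sketch, part 2: the optimization loop. A minimal driver, assuming
# (1, 3, H, W) image tensors; image loading and preprocessing are omitted,
# and the random tensors, learning rate, and step count below are
# illustrative stand-ins. The pixels of a copy of the content image are
# optimized directly, as in classic style transfer.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    content_img = torch.rand(1, 3, 256, 256)  # stand-in for a real image
    style_img = torch.rand(1, 3, 256, 256)    # stand-in for a real image

    criterion = StyleTransferLoss(VGGFeatures(), content_img, style_img, device)

    # Features of the fixed content and style images are computed once.
    with torch.no_grad():
        content_features = criterion.get_features(criterion.content_img)
        style_features = criterion.get_features(criterion.style_img)

    # Optimize the pixels of a copy of the content image.
    target = criterion.content_img.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([target], lr=0.01)

    for step in range(300):
        optimizer.zero_grad()
        target_features = criterion.get_features(target)
        loss = criterion.total_loss(target_features, content_features,
                                    style_features)
        loss.backward()
        optimizer.step()
        if step % 50 == 0:
            print(f"step {step}: loss = {loss.item():.4f}")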