import torch
import torch.nn as nn


class StyleTransferLoss(nn.Module):
    """Weighted content + style loss for neural style transfer.

    Feature dictionaries passed to the loss methods are indexed by layer name:
    the content term reads a single layer (``"block4"`` by default) and the
    style term reads every key in ``self.model.style_layers``.
    NOTE(review): assumes ``model(image)`` returns such a mapping of
    (B, C, H, W) feature maps — confirm against the model definition.
    """

    def __init__(self, model, content_img, style_img, device="cuda"):
        """Store the images and model on ``device``.

        Args:
            model: feature extractor; must expose ``style_layers`` (an iterable
                of keys into the mapping it returns).
            content_img: content image tensor, moved to ``device``.
            style_img: style image tensor, moved to ``device``.
            device: device for the images and the model.
        """
        super().__init__()
        self.device = device
        self.content_img = content_img.to(device)
        self.style_img = style_img.to(device)
        # If ``model`` is an nn.Module, this assignment registers it as a submodule.
        self.model = model.to(device)

    def gram_matrix(self, feature_maps):
        """Calculate the Gram matrix of style features, per sample.

        Args:
            feature_maps: tensor of shape (B, C, H, W).

        Returns:
            (C, C) tensor when B == 1 (the historical shape), otherwise a
            (B, C, C) tensor with one Gram matrix per sample. Each Gram matrix
            is normalised by C * H * W.
        """
        B, C, H, W = feature_maps.size()
        # reshape (not view): works for non-contiguous inputs too. Keeping the
        # batch axis separate avoids correlating channels across different images.
        features = feature_maps.reshape(B, C, H * W)
        G = torch.bmm(features, features.transpose(1, 2))
        G = G.div(C * H * W)
        # Preserve the original (C, C) result for single-image batches.
        return G[0] if B == 1 else G

    def get_features(self, image):
        """Run ``image`` through the model and return its layer features."""
        return self.model(image)

    def content_loss(self, target_features, content_features):
        """Mean squared error between target and content features."""
        return torch.mean((target_features - content_features) ** 2)

    def style_loss(self, target_features, style_features):
        """Sum over ``self.model.style_layers`` of the MSE between Gram matrices.

        Returns 0.0 (a float) if ``style_layers`` is empty, otherwise a tensor.
        """
        loss = 0.0
        for key in self.model.style_layers:
            target_gram = self.gram_matrix(target_features[key])
            style_gram = self.gram_matrix(style_features[key])
            # Per-sample Grams broadcast, so a single style image can be
            # compared against a batch of targets.
            loss += torch.mean((target_gram - style_gram) ** 2)
        return loss

    def total_loss(
        self,
        target_features,
        content_features,
        style_features,
        alpha=1,
        beta=1e8,
        content_layer="block4",
    ):
        """Weighted sum ``alpha * content + beta * style``.

        Args:
            target_features: feature mapping of the image being optimised.
            content_features: feature mapping of the content image.
            style_features: feature mapping of the style image.
            alpha: weight of the content loss.
            beta: weight of the style loss.
            content_layer: feature key used for the content loss.
        """
        content = self.content_loss(
            target_features[content_layer], content_features[content_layer]
        )
        style = self.style_loss(target_features, style_features)
        return alpha * content + beta * style

    def forward(
        self,
        target_features,
        content_features,
        style_features,
        alpha=1,
        beta=1e8,
        content_layer="block4",
    ):
        """Make the module callable; delegates to :meth:`total_loss`."""
        return self.total_loss(
            target_features,
            content_features,
            style_features,
            alpha=alpha,
            beta=beta,
            content_layer=content_layer,
        )