import torch.nn as nn


class StyleTransferModel(nn.Module):
    """Frozen feature extractor that exposes the output of five layer blocks.

    The wrapped network is cut into five blocks separated by four single
    layers (named ``pool1`` .. ``pool4``). ``forward`` returns the activation
    produced by every block, keyed ``"block1"`` .. ``"block5"``.

    The slice boundaries and the layer names in the comments below assume the
    layer ordering of torchvision's VGG19 ``features`` container -- confirm
    against the caller if a different backbone is passed in.

    Attributes:
        content_layers: Keys of the ``forward`` result meant for content
            comparison.
        style_layers: Keys of the ``forward`` result meant for style
            comparison.
    """

    def __init__(self, base_model: nn.Sequential):
        """Freeze ``base_model`` and split it into blocks.

        Args:
            base_model: Container that supports both integer indexing and
                slicing, where a slice is itself callable (as
                ``nn.Sequential`` is). Its parameters are modified in place:
                ``requires_grad`` is set to ``False`` on every one of them.
        """
        super().__init__()

        # Freeze the backbone: its parameters are never trained here.
        # NOTE: this mutates the caller's model object as well.
        for weight in base_model.parameters():
            weight.requires_grad = False

        # Split the backbone into blocks for feature extraction.
        self.block1 = base_model[:4]     # conv1_1, relu, conv1_2, relu
        self.pool1 = base_model[4]       # maxpool
        self.block2 = base_model[5:9]    # conv2_1, relu, conv2_2, relu
        self.pool2 = base_model[9]       # maxpool
        self.block3 = base_model[10:18]  # conv3_1 to relu3_4
        self.pool3 = base_model[18]      # maxpool
        self.block4 = base_model[19:27]  # conv4_1 to relu4_4
        self.pool4 = base_model[27]      # maxpool
        self.block5 = base_model[28:36]  # conv5_1 to relu5_4

        # Which entries of the forward() result serve which purpose.
        self.content_layers = ["block4"]
        self.style_layers = [
            "block1",
            "block2",
            "block3",
            "block4",
            "block5",
        ]

    def forward(self, x):
        """Run ``x`` through all blocks and collect each block's output.

        Args:
            x: Input passed to the first block.

        Returns:
            Dict mapping ``"block1"`` .. ``"block5"`` to the output of the
            corresponding block, taken *before* the layer that follows it
            (``pool1`` .. ``pool4``) is applied. Insertion order is block1
            through block5.
        """
        features = {}

        # Each block's activation is recorded first; the separating layer is
        # applied only to produce the input of the next block.
        x = self.block1(x)
        features["block1"] = x

        x = self.block2(self.pool1(x))
        features["block2"] = x

        x = self.block3(self.pool2(x))
        features["block3"] = x

        x = self.block4(self.pool3(x))
        features["block4"] = x

        # There is no separating layer after block5.
        x = self.block5(self.pool4(x))
        features["block5"] = x

        return features