sebastiansarasti's picture
Upload 5 files
8b06175 verified
import torch.nn as nn
class StyleTransferModel(nn.Module):
    """Frozen feature extractor that exposes five intermediate activations.

    The wrapped ``base_model`` is cut into five convolutional blocks separated
    by four pooling layers. ``forward`` returns the output of every block so a
    caller can build content and style losses from them.

    Attributes:
        block1 .. block5: Slices of ``base_model`` run as the five stages.
        pool1 .. pool4: Single layers of ``base_model`` applied between stages.
        content_layers: Names of the feature maps intended for content loss.
        style_layers: Names of the feature maps intended for style loss.
    """

    def __init__(self, base_model: nn.Sequential):
        """Freeze ``base_model`` and split it into blocks.

        Args:
            base_model: Indexable and sliceable layer container with at least
                36 layers. The index comments below assume the layer ordering
                of a VGG19 convolutional stack without batch norm
                (e.g. torchvision's ``vgg19().features``) -- TODO confirm
                against the caller.

        Note:
            ``requires_grad`` is switched off on the parameters of the object
            that is passed in, so the caller's model is modified in place.
        """
        super().__init__()
        vgg19 = base_model

        # Freeze the parameters: the network is only used to extract features.
        for param in vgg19.parameters():
            param.requires_grad = False

        # Split VGG19 into blocks for feature extraction.
        self.block1 = vgg19[:4]  # conv1_1, relu, conv1_2, relu
        self.pool1 = vgg19[4]  # maxpool
        self.block2 = vgg19[5:9]  # conv2_1, relu, conv2_2, relu
        self.pool2 = vgg19[9]  # maxpool
        self.block3 = vgg19[10:18]  # conv3_1 to relu3_4
        self.pool3 = vgg19[18]  # maxpool
        self.block4 = vgg19[19:27]  # conv4_1 to relu4_4
        self.pool4 = vgg19[27]  # maxpool
        self.block5 = vgg19[28:36]  # conv5_1 to relu5_4

        # Names refer to the keys of the dict returned by ``forward``.
        self.content_layers = ["block4"]  # output of block4 for content
        self.style_layers = [
            "block1",
            "block2",
            "block3",
            "block4",
            "block5",
        ]  # all blocks for style

    def forward(self, x):
        """Run ``x`` through all blocks and collect each block's output.

        Args:
            x: Input accepted by the first layer of ``block1``.

        Returns:
            dict mapping ``"block1"`` .. ``"block5"`` (in that insertion
            order) to the activation produced by the corresponding block,
            taken *before* the pooling layer that follows it.
        """
        # (feature name, block to run, pooling applied afterwards or None).
        stages = (
            ("block1", self.block1, self.pool1),
            ("block2", self.block2, self.pool2),
            ("block3", self.block3, self.pool3),
            ("block4", self.block4, self.pool4),
            ("block5", self.block5, None),  # nothing runs after the last block
        )

        features = {}
        for name, block, pool in stages:
            x = block(x)
            features[name] = x
            if pool is not None:
                x = pool(x)
        return features