import torch
import torch.nn as nn


class StyleTransferModel(nn.Module):
    """Frozen feature extractor that exposes five intermediate activations.

    The wrapped network is cut into five blocks separated by four single
    layers (stored as ``pool1`` .. ``pool4``).  ``forward`` returns the output
    of every block, keyed ``"block1"`` .. ``"block5"``.

    Attributes:
        block1 .. block5: Slices of ``base_model`` run as sub-networks.
        pool1 .. pool4: The single layers sitting between consecutive blocks.
        content_layers: Feature keys intended for content comparison.
        style_layers: Feature keys intended for style comparison.

    ``content_layers`` and ``style_layers`` are not read by this class;
    presumably they are consumed by loss code elsewhere.
    """

    def __init__(self, base_model: nn.Sequential) -> None:
        """Freeze ``base_model`` and split it into blocks.

        Args:
            base_model: An indexable, sliceable layer stack whose slices are
                callable (e.g. ``nn.Sequential``).  The slice boundaries below
                match the layout of torchvision's ``vgg19().features`` --
                assumed, confirm at the call site.

        Note:
            The parameters of the ``base_model`` object that is passed in are
            modified in place (``requires_grad`` is set to ``False``).
        """
        super().__init__()

        # The network is only used to read activations, so none of its
        # weights should receive gradients.
        for param in base_model.parameters():
            param.requires_grad = False

        # Assignment order is kept as block/pool pairs so that sub-module
        # registration order follows the order of execution.
        self.block1 = base_model[:4]
        self.pool1 = base_model[4]
        self.block2 = base_model[5:9]
        self.pool2 = base_model[9]
        self.block3 = base_model[10:18]
        self.pool3 = base_model[18]
        self.block4 = base_model[19:27]
        self.pool4 = base_model[27]
        # Index 36 (the layer after the last block) is deliberately not used.
        self.block5 = base_model[28:36]

        self.content_layers = ["block4"]
        self.style_layers = [
            "block1",
            "block2",
            "block3",
            "block4",
            "block5",
        ]

    def forward(self, x: torch.Tensor) -> dict:
        """Run ``x`` through all blocks and collect each block's output.

        Args:
            x: Input accepted by the first layer of the wrapped network.

        Returns:
            A dict with keys ``"block1"`` .. ``"block5"`` (inserted in that
            order), each mapping to the activation produced by that block
            before the following ``pool`` layer is applied.
        """
        features = {}

        x = self.block1(x)
        features["block1"] = x
        x = self.pool1(x)

        x = self.block2(x)
        features["block2"] = x
        x = self.pool2(x)

        x = self.block3(x)
        features["block3"] = x
        x = self.pool3(x)

        x = self.block4(x)
        features["block4"] = x
        x = self.pool4(x)

        # No layer is applied after the last block; its output is final.
        x = self.block5(x)
        features["block5"] = x

        return features