mlgawd committed
Commit 2ac3e6a · verified · 1 Parent(s): dbf8f17

Create model.py

Files changed (1): model.py (+120, -0)
model.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import log2

"""
factors is used in the Discriminator and Generator to decide how much
the channels should be multiplied or shrunk at each layer: for the
first 5 layers the channels stay the same, whereas as img_size grows
(towards the later layers) the number of channels is reduced to
1/2, 1/4, etc.
"""
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]
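
# For example, with in_channels=512 the per-resolution channel counts are:
# 4x4: 512, 8x8: 512, 16x16: 512, 32x32: 512, 64x64: 256,
# 128x128: 128, 256x256: 64, 512x512: 32, 1024x1024: 16.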


class WSConv2d(nn.Module):
    """
    Weight-scaled Conv2d (Equalized Learning Rate).
    Note that the input is multiplied by the scale rather than the
    weights being changed; the result is the same.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # He-style scaling constant: sqrt(gain / fan_in), fan_in = in_channels * k**2
        self.scale = (gain / (in_channels * (kernel_size ** 2))) ** 0.5
        # keep the bias out of the scaled convolution; it is added separately
        self.bias = self.conv.bias
        self.conv.bias = None

        # initialize conv layer
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)
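
    # Since convolution is linear in its weights, conv(x * scale) equals a
    # convolution whose weights are (scale * W); multiplying the input
    # therefore has the same effect as rescaling the weights each pass.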


class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        # pixelwise feature-vector normalization (ProGAN, Karras et al. 2018)
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super(ConvBlock, self).__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.pn(x) if self.use_pn else x
        x = self.leaky(self.conv2(x))
        x = self.pn(x) if self.use_pn else x
        return x


class Generator(nn.Module):
    def __init__(self, z_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()

        # initial takes 1x1 -> 4x4
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(len(factors) - 1):  # -1 to prevent index error because of factors[i + 1]
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )
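
        # Note: prog_blocks[i] runs at resolution 4 * 2**(i + 1) (after the
        # upsample in forward), and rgb_layers[i + 1] maps its output to RGB.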

    def fade_in(self, alpha, upscaled, generated):
        # alpha is a scalar in [0, 1], ramped from 0 to 1 while a new
        # resolution is faded in; upscaled.shape == generated.shape
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, x, alpha, steps):
        out = self.initial(x)

        if steps == 0:
            return self.initial_rgb(out)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[step](upscaled)

        # The number of channels in upscaled stays the same, while out,
        # which has moved through prog_blocks, might change. To ensure we
        # can convert both to RGB we use different rgb_layers:
        # (steps - 1) for upscaled and steps for out.
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)
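

# Minimal shape smoke test (illustrative sketch; the Z_DIM and IN_CHANNELS
# values below are assumptions, not part of the committed file):
if __name__ == "__main__":
    Z_DIM, IN_CHANNELS = 512, 512
    gen = Generator(Z_DIM, IN_CHANNELS, img_channels=3)
    x = torch.randn((1, Z_DIM, 1, 1))
    for num_steps in range(len(factors)):
        img_size = 4 * 2 ** num_steps
        z = gen(x, alpha=0.5, steps=num_steps)
        assert z.shape == (1, 3, img_size, img_size)
        print(f"OK at img size: {img_size}")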