juwaeze committed on
Commit 492d41d · verified · 1 Parent(s): f8c51d7

Upload stylegan2.py

Files changed (1): stylegan2.py (+779, -0)

stylegan2.py ADDED
import math
import random

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Embedding

from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix


class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        # Normalize each feature vector to unit length along the channel
        # dimension (the pixel norm used in the StyleGAN mapping network).
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


def make_kernel(k):
    # Build a normalized 2-D FIR filter from 1-D taps (outer product),
    # or normalize an already 2-D kernel.
    k = torch.tensor(k, dtype=torch.float32)

    if k.ndim == 1:
        k = k[None, :] * k[:, None]

    k /= k.sum()

    return k


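# Worked example (illustration added here, not part of the original upload):
# the default taps [1, 3, 3, 1] sum to 8, so their outer product sums to 64
# and make_kernel([1, 3, 3, 1]) returns the 4x4 binomial blur kernel
#   [[1, 3, 3, 1],
#    [3, 9, 9, 3],
#    [3, 9, 9, 3],
#    [1, 3, 3, 1]] / 64
# that Upsample, Downsample, and Blur below feed to upfirdn2d.

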
class Upsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel) * (factor ** 2)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)

        return out


class Downsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)

        return out


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)

        self.register_buffer("kernel", kernel)

        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)

        return out


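# Shape sanity check (illustrative sketch, assuming the upfirdn2d op in this
# repo's `op` package behaves like the reference StyleGAN2 implementation):
#   up = Upsample([1, 3, 3, 1])          # factor 2
#   x = torch.randn(4, 512, 32, 32)
#   up(x).shape                          # -> torch.Size([4, 512, 64, 64])
#   down = Downsample([1, 3, 3, 1])      # factor 2
#   down(x).shape                        # -> torch.Size([4, 512, 16, 16])

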
class EqualConv2d(nn.Module):
    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        # Equalized learning rate: weights are stored N(0, 1) and rescaled by
        # the He-init constant at every forward pass.
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))

        else:
            self.bias = None

    def forward(self, input):
        out = conv2d_gradfix.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
        )


class EqualLinear(nn.Module):
    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        # `self.bias` is None when the layer was built with bias=False; the
        # original multiplied None by lr_mul, which raises a TypeError.
        bias = self.bias * self.lr_mul if self.bias is not None else None

        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, bias)

        else:
            out = F.linear(input, self.weight * self.scale, bias=bias)

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
        )


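# Note on the equalized learning rate trick used by EqualConv2d/EqualLinear
# (background comment, added for clarity): because the He constant is applied
# at runtime rather than baked into the init, every weight tensor keeps unit
# variance and an optimizer step has the same relative effect on every layer.
# A quick check (illustrative, not from the original file):
#   lin = EqualLinear(512, 512)
#   lin.weight.std()   # ~1.0; the effective weight is weight * scale
#   lin.scale          # 1 / sqrt(512) ~ 0.0442

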
class ModulatedConv2d(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        fused=True,
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        # bias_init=1 so that styles modulate with a factor of ~1 at init.
        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate
        self.fused = fused

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})"
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        if not self.fused:
            # Non-fused path: scale the activations by the style instead of
            # building one grouped weight per sample.
            weight = self.scale * self.weight.squeeze(0)
            style = self.modulation(style)

            if self.demodulate:
                w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
                dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()

            input = input * style.reshape(batch, in_channel, 1, 1)

            if self.upsample:
                weight = weight.transpose(0, 1)
                out = conv2d_gradfix.conv_transpose2d(
                    input, weight, padding=0, stride=2
                )
                out = self.blur(out)

            elif self.downsample:
                input = self.blur(input)
                out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)

            else:
                out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)

            if self.demodulate:
                out = out * dcoefs.view(batch, -1, 1, 1)

            return out

        # Fused path: fold the per-sample style into the weights and run a
        # single grouped convolution over the whole batch.
        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = conv2d_gradfix.conv_transpose2d(
                input, weight, padding=0, stride=2, groups=batch
            )
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = conv2d_gradfix.conv2d(
                input, weight, padding=0, stride=2, groups=batch
            )
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = conv2d_gradfix.conv2d(
                input, weight, padding=self.padding, groups=batch
            )
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out


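# Demodulation in one line of math (explanatory note, added): with per-sample
# style s and base weight w, the modulated weight is w'_{oik} = s_i * w_{oik},
# and demodulation rescales each output map o by
#   sigma_o = 1 / sqrt(sum_{i,k} w'_{oik}^2 + eps),
# which keeps activations at roughly unit variance without an explicit AdaIN
# normalization step (Karras et al., "Analyzing and Improving the Image
# Quality of StyleGAN", 2020).

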
class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()

        return image + self.weight * noise


class ConstantInput(nn.Module):
    def __init__(self, channel, size=4):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        # Only the batch size of `input` is used; the learned constant is
        # broadcast across the batch.
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1, 1)

        return out


class StyledConv(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )

        self.noise = NoiseInjection()
        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
        # self.activate = ScaledLeakyReLU(0.2)
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        out = self.conv(input, style)
        out = self.noise(out, noise=noise)
        # out = out + self.bias
        out = self.activate(out)

        return out


class ToRGB(nn.Module):
    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style)
        out = out + self.bias

        if skip is not None:
            skip = self.upsample(skip)

            out = out + skip

        return out


class Generator(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
        conditional_gan=False,
        nof_classes=2,
        embedding_size=10,
    ):
        super().__init__()

        self.size = size

        self.style_dim = style_dim

        # Mapping network: PixelNorm followed by n_mlp equalized linear layers.
        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
                )
            )

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        if not conditional_gan:
            self.input = ConstantInput(self.channels[4])
            self.conv1 = StyledConv(
                self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
            )
        else:
            # Conditional variant: each of the nof_classes label slots holds a
            # binary index in {0, 1}, so the embedding table has two rows and
            # the embedded labels add embedding_size * nof_classes channels on
            # top of the constant input.
            self.embedding = Embedding(2, embedding_size)
            self.input = ConstantInput(self.channels[4] + (embedding_size * nof_classes))
            self.conv1 = StyledConv(
                self.channels[4] + (embedding_size * nof_classes),
                self.channels[4],
                3,
                style_dim,
                blur_kernel=blur_kernel,
            )

        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def mean_latent(self, n_latent):
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        return self.style(input)

    def forward(
        self,
        styles,
        labels=None,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
                ]

        # Truncation trick: pull each latent toward truncation_latent.
        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

            else:
                latent = styles[0]

        else:
            # Style mixing: use styles[0] up to inject_index, styles[1] after.
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        if labels is not None:
            batch_size = labels.size()[0]
            embedding = self.embedding(labels)
            embedding = embedding.flatten().reshape(batch_size, -1).unsqueeze(1).repeat(1, latent.size()[1], 1)
            latent_embed = torch.cat([latent, embedding], 2)
            # Note: ConstantInput only reads the batch size of its argument,
            # so the label embedding concatenated here does not itself reach
            # conv1; conv1 sees the (wider) learned constant.
            out = self.input(latent_embed)
        else:
            out = self.input(latent)

        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)

            i += 2

        image = skip

        if return_latents:
            return image, latent

        else:
            return image, None


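# Usage sketch (illustrative, with assumed shapes; not part of the original
# upload): an unconditional 256x256 generator sampled from Gaussian z.
#   g = Generator(size=256, style_dim=512, n_mlp=8)
#   z = torch.randn(4, 512)
#   images, _ = g([z])                      # images: (4, 3, 256, 256)
#   mixed, _ = g([z, torch.randn(4, 512)])  # style mixing between two latents
#   w_mean = g.mean_latent(4096)
#   trunc, _ = g([z], truncation=0.7, truncation_latent=w_mean)

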
class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            layers.append(FusedLeakyReLU(out_channel, bias=bias))

        super().__init__(*layers)


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        # Divide by sqrt(2) so the sum of two roughly unit-variance branches
        # stays at roughly unit variance.
        out = (out + skip) / math.sqrt(2)

        return out


class Discriminator(nn.Module):
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], nof_classes=2, conditional_gan=False):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        # In the conditional case each label slot is embedded into a full
        # size x size plane and concatenated to the RGB input.
        self.input_dim = nof_classes + 3 if conditional_gan else 3
        self.size = size
        self.nof_classes = nof_classes

        convs = [ConvLayer(self.input_dim, channels[size], 1)]

        if conditional_gan:
            self.embedding = Embedding(2, size * size)

        log_size = int(math.log(size, 2))

        in_channel = channels[size]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]

            convs.append(ResBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input, labels=None):
        if labels is not None:
            embed = self.embedding(labels)
            batch_size = labels.size()[0]
            embed = embed.flatten().reshape(batch_size, self.nof_classes, self.size, self.size)
            input = torch.cat((input, embed), dim=1)

        out = self.convs(input)

        # Minibatch standard deviation: append one feature map holding the
        # per-group standard deviation statistic.
        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)

        out = out.view(batch, -1)
        out = self.final_linear(out)

        return out


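# Conditional label convention (an inference from the Embedding(2, ...) tables
# above, not a documented contract): labels are LongTensors of shape
# (batch, nof_classes) with entries in {0, 1}, e.g.
#   d = Discriminator(size=256, conditional_gan=True, nof_classes=2)
#   x = torch.randn(8, 3, 256, 256)
#   y = torch.randint(0, 2, (8, 2))
#   logits = d(x, labels=y)                 # logits: (8, 1)

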
class Encoder(nn.Module):
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], output_channels=None):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        convs = [ConvLayer(3, channels[size], 1)]

        log_size = int(math.log(size, 2))

        in_channel = channels[size]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]

            convs.append(ResBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.final_conv = ConvLayer(in_channel, channels[4], 3)

        if output_channels is None:
            output_channels = channels[4]

        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, output_channels)
        )

    def forward(self, input):
        out = self.convs(input)
        out = self.final_conv(out)
        batch, _, _, _ = out.shape
        out = out.view(batch, -1)
        out = self.final_linear(out)

        return out
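# Encoder usage sketch (illustrative; output_channels is a free choice):
#   e = Encoder(size=256, output_channels=512)
#   x = torch.randn(8, 3, 256, 256)
#   e(x).shape                              # -> torch.Size([8, 512])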