Respair committed on
Commit
04a2208
·
verified ·
1 Parent(s): 10c6e0f

Update discriminator.py

Browse files
Files changed (1) hide show
  1. discriminator.py +18 -244
discriminator.py CHANGED
@@ -1,231 +1,3 @@
1
- # import torch
2
- # import torch.nn as nn
3
- # import torch.nn.functional as F
4
- # from audiotools import AudioSignal
5
- # from audiotools import ml
6
- # from audiotools import STFTParams
7
- # from einops import rearrange
8
- # from torch.nn.utils import weight_norm
9
-
10
-
11
- # def WNConv1d(*args, **kwargs):
12
- # act = kwargs.pop("act", True)
13
- # conv = weight_norm(nn.Conv1d(*args, **kwargs))
14
- # if not act:
15
- # return conv
16
- # return nn.Sequential(conv, nn.LeakyReLU(0.1))
17
-
18
-
19
- # def WNConv2d(*args, **kwargs):
20
- # act = kwargs.pop("act", True)
21
- # conv = weight_norm(nn.Conv2d(*args, **kwargs))
22
- # if not act:
23
- # return conv
24
- # return nn.Sequential(conv, nn.LeakyReLU(0.1))
25
-
26
-
27
- # class MPD(nn.Module):
28
- # def __init__(self, period):
29
- # super().__init__()
30
- # self.period = period
31
- # self.convs = nn.ModuleList(
32
- # [
33
- # WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
34
- # WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
35
- # WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
36
- # WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
37
- # WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
38
- # ]
39
- # )
40
- # self.conv_post = WNConv2d(
41
- # 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
42
- # )
43
-
44
- # def pad_to_period(self, x):
45
- # t = x.shape[-1]
46
- # x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
47
- # return x
48
-
49
- # def forward(self, x):
50
- # fmap = []
51
-
52
- # x = self.pad_to_period(x)
53
- # x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
54
-
55
- # for layer in self.convs:
56
- # x = layer(x)
57
- # fmap.append(x)
58
-
59
- # x = self.conv_post(x)
60
- # fmap.append(x)
61
-
62
- # return fmap
63
-
64
-
65
- # class MSD(nn.Module):
66
- # def __init__(self, rate: int = 1, sample_rate: int = 44100):
67
- # super().__init__()
68
- # self.convs = nn.ModuleList(
69
- # [
70
- # WNConv1d(1, 16, 15, 1, padding=7),
71
- # WNConv1d(16, 64, 41, 4, groups=4, padding=20),
72
- # WNConv1d(64, 256, 41, 4, groups=16, padding=20),
73
- # WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
74
- # WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
75
- # WNConv1d(1024, 1024, 5, 1, padding=2),
76
- # ]
77
- # )
78
- # self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
79
- # self.sample_rate = sample_rate
80
- # self.rate = rate
81
-
82
- # def forward(self, x):
83
- # x = AudioSignal(x, self.sample_rate)
84
- # x.resample(self.sample_rate // self.rate)
85
- # x = x.audio_data
86
-
87
- # fmap = []
88
-
89
- # for l in self.convs:
90
- # x = l(x)
91
- # fmap.append(x)
92
- # x = self.conv_post(x)
93
- # fmap.append(x)
94
-
95
- # return fmap
96
-
97
-
98
- # BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
99
-
100
-
101
- # class MRD(nn.Module):
102
- # def __init__(
103
- # self,
104
- # window_length: int,
105
- # hop_factor: float = 0.25,
106
- # sample_rate: int = 44100,
107
- # bands: list = BANDS,
108
- # ):
109
- # """Complex multi-band spectrogram discriminator.
110
- # Parameters
111
- # ----------
112
- # window_length : int
113
- # Window length of STFT.
114
- # hop_factor : float, optional
115
- # Hop factor of the STFT, defaults to ``0.25 * window_length``.
116
- # sample_rate : int, optional
117
- # Sampling rate of audio in Hz, by default 44100
118
- # bands : list, optional
119
- # Bands to run discriminator over.
120
- # """
121
- # super().__init__()
122
-
123
- # self.window_length = window_length
124
- # self.hop_factor = hop_factor
125
- # self.sample_rate = sample_rate
126
- # self.stft_params = STFTParams(
127
- # window_length=window_length,
128
- # hop_length=int(window_length * hop_factor),
129
- # match_stride=True,
130
- # )
131
-
132
- # n_fft = window_length // 2 + 1
133
- # bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
134
- # self.bands = bands
135
-
136
- # ch = 32
137
- # convs = lambda: nn.ModuleList(
138
- # [
139
- # WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
140
- # WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
141
- # WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
142
- # WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
143
- # WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
144
- # ]
145
- # )
146
- # self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
147
- # self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
148
-
149
- # def spectrogram(self, x):
150
- # x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
151
- # x = torch.view_as_real(x.stft())
152
- # x = rearrange(x, "b 1 f t c -> (b 1) c t f")
153
- # # Split into bands
154
- # x_bands = [x[..., b[0] : b[1]] for b in self.bands]
155
- # return x_bands
156
-
157
- # def forward(self, x):
158
- # x_bands = self.spectrogram(x)
159
- # fmap = []
160
-
161
- # x = []
162
- # for band, stack in zip(x_bands, self.band_convs):
163
- # for layer in stack:
164
- # band = layer(band)
165
- # fmap.append(band)
166
- # x.append(band)
167
-
168
- # x = torch.cat(x, dim=-1)
169
- # x = self.conv_post(x)
170
- # fmap.append(x)
171
-
172
- # return fmap
173
-
174
-
175
- # class Discriminator(ml.BaseModel):
176
- # def __init__(
177
- # self,
178
- # rates: list = [],
179
- # periods: list = [2, 3, 5, 7, 11],
180
- # fft_sizes: list = [2048, 1024, 512],
181
- # sample_rate: int = 44100,
182
- # bands: list = BANDS,
183
- # ):
184
- # """Discriminator that combines multiple discriminators.
185
-
186
- # Parameters
187
- # ----------
188
- # rates : list, optional
189
- # sampling rates (in Hz) to run MSD at, by default []
190
- # If empty, MSD is not used.
191
- # periods : list, optional
192
- # periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
193
- # fft_sizes : list, optional
194
- # Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
195
- # sample_rate : int, optional
196
- # Sampling rate of audio in Hz, by default 44100
197
- # bands : list, optional
198
- # Bands to run MRD at, by default `BANDS`
199
- # """
200
- # super().__init__()
201
- # discs = []
202
- # discs += [MPD(p) for p in periods]
203
- # discs += [MSD(r, sample_rate=sample_rate) for r in rates]
204
- # discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
205
- # self.discriminators = nn.ModuleList(discs)
206
-
207
- # def preprocess(self, y):
208
- # # Remove DC offset
209
- # y = y - y.mean(dim=-1, keepdims=True)
210
- # # Peak normalize the volume of input audio
211
- # y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
212
- # return y
213
-
214
- # def forward(self, x):
215
- # x = self.preprocess(x)
216
- # fmaps = [d(x) for d in self.discriminators]
217
- # return fmaps
218
-
219
-
220
- # if __name__ == "__main__":
221
- # disc = Discriminator()
222
- # x = torch.zeros(1, 1, 44100)
223
- # results = disc(x)
224
- # for i, result in enumerate(results):
225
- # print(f"disc{i}")
226
- # for i, r in enumerate(result):
227
- # print(r.shape, r.mean(), r.min(), r.max())
228
- # print()
229
  import torch
230
  import torch.nn as nn
231
  import torch.nn.functional as F
@@ -313,7 +85,7 @@ class MPD(nn.Module):
313
 
314
 
315
  class MSD(nn.Module):
316
- def __init__(self, rate: int = 1, sample_rate: int = 44100):
317
  super().__init__()
318
  self.convs = nn.ModuleList([
319
  WNConv1d(1, 16, 15, 1, padding=7),
@@ -463,19 +235,19 @@ class DiscriminatorCQT(nn.Module):
463
 
464
  class MultiScaleSubbandCQT(nn.Module):
465
  """CQT discriminator at multiple scales"""
466
- def __init__(self, sample_rate=44100):
467
  super().__init__()
468
  cfg = Munch({
469
- "hop_lengths": [1024, 512, 512],
470
- "sampling_rate": sample_rate,
471
- "filters": 32,
472
- "max_filters": 1024,
473
- "filters_scale": 1,
474
- "dilations": [1, 2, 4],
475
- "in_channels": 1,
476
- "out_channels": 1,
477
- "n_octaves": [10, 10, 10],
478
- "bins_per_octaves": [24, 36, 48],
479
  })
480
  self.cfg = cfg
481
  self.discriminators = nn.ModuleList([
@@ -499,7 +271,7 @@ BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
499
 
500
  class MRD(nn.Module):
501
  def __init__(self, window_length: int, hop_factor: float = 0.25,
502
- sample_rate: int = 44100, bands: list = BANDS):
503
  """Multi-resolution spectrogram discriminator."""
504
  super().__init__()
505
  self.window_length = window_length
@@ -556,7 +328,7 @@ class Discriminator(ml.BaseModel):
556
  rates: list = [],
557
  periods: list = [2, 3, 5, 7, 11],
558
  fft_sizes: list = [2048, 1024, 512],
559
- sample_rate: int = 44100,
560
  ):
561
  """Discriminator combining MPD, MSD, MRD and CQT.
562
 
@@ -569,7 +341,7 @@ class Discriminator(ml.BaseModel):
569
  fft_sizes : list, optional
570
  FFT sizes for MRD, by default [2048, 1024, 512]
571
  sample_rate : int, optional
572
- Sampling rate of audio in Hz, by default 44100
573
  """
574
  super().__init__()
575
  discs = []
@@ -593,4 +365,6 @@ class Discriminator(ml.BaseModel):
593
  def forward(self, x):
594
  x = self.preprocess(x)
595
  fmaps = [d(x) for d in self.discriminators]
596
- return fmaps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
85
 
86
 
87
  class MSD(nn.Module):
88
+ def __init__(self, rate: int = 1, sample_rate: int = 24000):
89
  super().__init__()
90
  self.convs = nn.ModuleList([
91
  WNConv1d(1, 16, 15, 1, padding=7),
 
235
 
236
  class MultiScaleSubbandCQT(nn.Module):
237
  """CQT discriminator at multiple scales"""
238
+ def __init__(self, sample_rate=24000):
239
  super().__init__()
240
  cfg = Munch({
241
+ "hop_lengths": [512, 256, 256],
242
+ "sampling_rate": 24000,
243
+ "filters": 32,
244
+ "max_filters": 1024,
245
+ "filters_scale": 1,
246
+ "dilations": [1, 2, 4],
247
+ "in_channels": 1,
248
+ "out_channels": 1,
249
+ "n_octaves": [9, 9, 9],
250
+ "bins_per_octaves": [24, 36, 48],
251
  })
252
  self.cfg = cfg
253
  self.discriminators = nn.ModuleList([
 
271
 
272
  class MRD(nn.Module):
273
  def __init__(self, window_length: int, hop_factor: float = 0.25,
274
+ sample_rate: int = 24000, bands: list = BANDS):
275
  """Multi-resolution spectrogram discriminator."""
276
  super().__init__()
277
  self.window_length = window_length
 
328
  rates: list = [],
329
  periods: list = [2, 3, 5, 7, 11],
330
  fft_sizes: list = [2048, 1024, 512],
331
+ sample_rate: int = 24000,
332
  ):
333
  """Discriminator combining MPD, MSD, MRD and CQT.
334
 
 
341
  fft_sizes : list, optional
342
  FFT sizes for MRD, by default [2048, 1024, 512]
343
  sample_rate : int, optional
344
+ Sampling rate of audio in Hz, by default 24000
345
  """
346
  super().__init__()
347
  discs = []
 
365
  def forward(self, x):
366
  x = self.preprocess(x)
367
  fmaps = [d(x) for d in self.discriminators]
368
+ return fmaps
369
+
370
+