Commit 16747e4 (verified) by PhoenixStormJr · 1 Parent(s): 8d57554

Upload folder using huggingface_hub
train/cmd.txt ADDED
@@ -0,0 +1 @@
+python train_nsf_sim_cache_sid.py -c configs/mi_mix40k_nsf_co256_cs1sid_ms2048.json -m ft-mi
train/data_utils.py ADDED
@@ -0,0 +1,512 @@
+import os, traceback
+import numpy as np
+import torch
+import torch.utils.data
+
+from mel_processing import spectrogram_torch
+from utils import load_wav_to_torch, load_filepaths_and_text
+
+
+class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
+    """
+    1) loads audio, text pairs
+    2) normalizes text and converts them to sequences of integers
+    3) computes spectrograms from audio files.
+    """
+
+    def __init__(self, audiopaths_and_text, hparams):
+        self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
+        self.max_wav_value = hparams.max_wav_value
+        self.sampling_rate = hparams.sampling_rate
+        self.filter_length = hparams.filter_length
+        self.hop_length = hparams.hop_length
+        self.win_length = hparams.win_length
+        self.min_text_len = getattr(hparams, "min_text_len", 1)
+        self.max_text_len = getattr(hparams, "max_text_len", 5000)
+        self._filter()
+
+    def _filter(self):
+        """
+        Filter text & store spec lengths
+        """
+        # Store spectrogram lengths for Bucketing
+        # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
+        # spec_length = wav_length // hop_length
+        audiopaths_and_text_new = []
+        lengths = []
+        for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
+            if self.min_text_len <= len(text) <= self.max_text_len:
+                audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
+                lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
+        self.audiopaths_and_text = audiopaths_and_text_new
+        self.lengths = lengths
+
+    def get_sid(self, sid):
+        sid = torch.LongTensor([int(sid)])
+        return sid
+
+    def get_audio_text_pair(self, audiopath_and_text):
+        # separate filename and text
+        file = audiopath_and_text[0]
+        phone = audiopath_and_text[1]
+        pitch = audiopath_and_text[2]
+        pitchf = audiopath_and_text[3]
+        dv = audiopath_and_text[4]
+
+        phone, pitch, pitchf = self.get_labels(phone, pitch, pitchf)
+        spec, wav = self.get_audio(file)
+        dv = self.get_sid(dv)
+
+        len_phone = phone.size()[0]
+        len_spec = spec.size()[-1]
+        if len_phone != len_spec:
+            len_min = min(len_phone, len_spec)
+            len_wav = len_min * self.hop_length
+
+            spec = spec[:, :len_min]
+            wav = wav[:, :len_wav]
+
+            phone = phone[:len_min, :]
+            pitch = pitch[:len_min]
+            pitchf = pitchf[:len_min]
+
+        return (spec, wav, phone, pitch, pitchf, dv)
+
+    def get_labels(self, phone, pitch, pitchf):
+        phone = np.load(phone)
+        phone = np.repeat(phone, 2, axis=0)
+        pitch = np.load(pitch)
+        pitchf = np.load(pitchf)
+        n_num = min(phone.shape[0], 900)  # DistributedBucketSampler
+        phone = phone[:n_num, :]
+        pitch = pitch[:n_num]
+        pitchf = pitchf[:n_num]
+        phone = torch.FloatTensor(phone)
+        pitch = torch.LongTensor(pitch)
+        pitchf = torch.FloatTensor(pitchf)
+        return phone, pitch, pitchf
+
+    def get_audio(self, filename):
+        audio, sampling_rate = load_wav_to_torch(filename)
+        if sampling_rate != self.sampling_rate:
+            raise ValueError(
+                "{} SR doesn't match target {} SR".format(
+                    sampling_rate, self.sampling_rate
+                )
+            )
+        audio_norm = audio
+        # audio_norm = audio / self.max_wav_value
+        # audio_norm = audio / np.abs(audio).max()
+
+        audio_norm = audio_norm.unsqueeze(0)
+        spec_filename = filename.replace(".wav", ".spec.pt")
+        if os.path.exists(spec_filename):
+            try:
+                spec = torch.load(spec_filename)
+            except Exception:
+                print(spec_filename, traceback.format_exc())
+                spec = spectrogram_torch(
+                    audio_norm,
+                    self.filter_length,
+                    self.sampling_rate,
+                    self.hop_length,
+                    self.win_length,
+                    center=False,
+                )
+                spec = torch.squeeze(spec, 0)
+                torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
+        else:
+            spec = spectrogram_torch(
+                audio_norm,
+                self.filter_length,
+                self.sampling_rate,
+                self.hop_length,
+                self.win_length,
+                center=False,
+            )
+            spec = torch.squeeze(spec, 0)
+            torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
+        return spec, audio_norm
+
+    def __getitem__(self, index):
+        return self.get_audio_text_pair(self.audiopaths_and_text[index])
+
+    def __len__(self):
+        return len(self.audiopaths_and_text)
+
+
+class TextAudioCollateMultiNSFsid:
+    """Zero-pads model inputs and targets"""
+
+    def __init__(self, return_ids=False):
+        self.return_ids = return_ids
+
+    def __call__(self, batch):
+        """Collates a training batch from normalized text and audio
+        PARAMS
+        ------
+        batch: [text_normalized, spec_normalized, wav_normalized]
+        """
+        # Right zero-pad all one-hot text sequences to max input length
+        _, ids_sorted_decreasing = torch.sort(
+            torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True
+        )
+
+        max_spec_len = max([x[0].size(1) for x in batch])
+        max_wave_len = max([x[1].size(1) for x in batch])
+        spec_lengths = torch.LongTensor(len(batch))
+        wave_lengths = torch.LongTensor(len(batch))
+        spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
+        wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
+        spec_padded.zero_()
+        wave_padded.zero_()
+
+        max_phone_len = max([x[2].size(0) for x in batch])
+        phone_lengths = torch.LongTensor(len(batch))
+        phone_padded = torch.FloatTensor(
+            len(batch), max_phone_len, batch[0][2].shape[1]
+        )  # (spec, wav, phone, pitch)
+        pitch_padded = torch.LongTensor(len(batch), max_phone_len)
+        pitchf_padded = torch.FloatTensor(len(batch), max_phone_len)
+        phone_padded.zero_()
+        pitch_padded.zero_()
+        pitchf_padded.zero_()
+        # dv = torch.FloatTensor(len(batch), 256)  # gin = 256
+        sid = torch.LongTensor(len(batch))
+
+        for i in range(len(ids_sorted_decreasing)):
+            row = batch[ids_sorted_decreasing[i]]
+
+            spec = row[0]
+            spec_padded[i, :, : spec.size(1)] = spec
+            spec_lengths[i] = spec.size(1)
+
+            wave = row[1]
+            wave_padded[i, :, : wave.size(1)] = wave
+            wave_lengths[i] = wave.size(1)
+
+            phone = row[2]
+            phone_padded[i, : phone.size(0), :] = phone
+            phone_lengths[i] = phone.size(0)
+
+            pitch = row[3]
+            pitch_padded[i, : pitch.size(0)] = pitch
+            pitchf = row[4]
+            pitchf_padded[i, : pitchf.size(0)] = pitchf
+
+            # dv[i] = row[5]
+            sid[i] = row[5]
+
+        return (
+            phone_padded,
+            phone_lengths,
+            pitch_padded,
+            pitchf_padded,
+            spec_padded,
+            spec_lengths,
+            wave_padded,
+            wave_lengths,
+            # dv
+            sid,
+        )
+
+
+class TextAudioLoader(torch.utils.data.Dataset):
+    """
+    1) loads audio, text pairs
+    2) normalizes text and converts them to sequences of integers
+    3) computes spectrograms from audio files.
+    """
+
+    def __init__(self, audiopaths_and_text, hparams):
+        self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
+        self.max_wav_value = hparams.max_wav_value
+        self.sampling_rate = hparams.sampling_rate
+        self.filter_length = hparams.filter_length
+        self.hop_length = hparams.hop_length
+        self.win_length = hparams.win_length
+        self.min_text_len = getattr(hparams, "min_text_len", 1)
+        self.max_text_len = getattr(hparams, "max_text_len", 5000)
+        self._filter()
+
+    def _filter(self):
+        """
+        Filter text & store spec lengths
+        """
+        # Store spectrogram lengths for Bucketing
+        # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
+        # spec_length = wav_length // hop_length
+        audiopaths_and_text_new = []
+        lengths = []
+        for audiopath, text, dv in self.audiopaths_and_text:
+            if self.min_text_len <= len(text) <= self.max_text_len:
+                audiopaths_and_text_new.append([audiopath, text, dv])
+                lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
+        self.audiopaths_and_text = audiopaths_and_text_new
+        self.lengths = lengths
+
+    def get_sid(self, sid):
+        sid = torch.LongTensor([int(sid)])
+        return sid
+
+    def get_audio_text_pair(self, audiopath_and_text):
+        # separate filename and text
+        file = audiopath_and_text[0]
+        phone = audiopath_and_text[1]
+        dv = audiopath_and_text[2]
+
+        phone = self.get_labels(phone)
+        spec, wav = self.get_audio(file)
+        dv = self.get_sid(dv)
+
+        len_phone = phone.size()[0]
+        len_spec = spec.size()[-1]
+        if len_phone != len_spec:
+            len_min = min(len_phone, len_spec)
+            len_wav = len_min * self.hop_length
+            spec = spec[:, :len_min]
+            wav = wav[:, :len_wav]
+            phone = phone[:len_min, :]
+        return (spec, wav, phone, dv)
+
+    def get_labels(self, phone):
+        phone = np.load(phone)
+        phone = np.repeat(phone, 2, axis=0)
+        n_num = min(phone.shape[0], 900)  # DistributedBucketSampler
+        phone = phone[:n_num, :]
+        phone = torch.FloatTensor(phone)
+        return phone
+
+    def get_audio(self, filename):
+        audio, sampling_rate = load_wav_to_torch(filename)
+        if sampling_rate != self.sampling_rate:
+            raise ValueError(
+                "{} SR doesn't match target {} SR".format(
+                    sampling_rate, self.sampling_rate
+                )
+            )
+        audio_norm = audio
+        # audio_norm = audio / self.max_wav_value
+        # audio_norm = audio / np.abs(audio).max()
+
+        audio_norm = audio_norm.unsqueeze(0)
+        spec_filename = filename.replace(".wav", ".spec.pt")
+        if os.path.exists(spec_filename):
+            try:
+                spec = torch.load(spec_filename)
+            except Exception:
+                print(spec_filename, traceback.format_exc())
+                spec = spectrogram_torch(
+                    audio_norm,
+                    self.filter_length,
+                    self.sampling_rate,
+                    self.hop_length,
+                    self.win_length,
+                    center=False,
+                )
+                spec = torch.squeeze(spec, 0)
+                torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
+        else:
+            spec = spectrogram_torch(
+                audio_norm,
+                self.filter_length,
+                self.sampling_rate,
+                self.hop_length,
+                self.win_length,
+                center=False,
+            )
+            spec = torch.squeeze(spec, 0)
+            torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
+        return spec, audio_norm
+
+    def __getitem__(self, index):
+        return self.get_audio_text_pair(self.audiopaths_and_text[index])
+
+    def __len__(self):
+        return len(self.audiopaths_and_text)
+
+
+class TextAudioCollate:
+    """Zero-pads model inputs and targets"""
+
+    def __init__(self, return_ids=False):
+        self.return_ids = return_ids
+
+    def __call__(self, batch):
+        """Collates a training batch from normalized text and audio
+        PARAMS
+        ------
+        batch: [text_normalized, spec_normalized, wav_normalized]
+        """
+        # Right zero-pad all one-hot text sequences to max input length
+        _, ids_sorted_decreasing = torch.sort(
+            torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True
+        )
+
+        max_spec_len = max([x[0].size(1) for x in batch])
+        max_wave_len = max([x[1].size(1) for x in batch])
+        spec_lengths = torch.LongTensor(len(batch))
+        wave_lengths = torch.LongTensor(len(batch))
+        spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
+        wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
+        spec_padded.zero_()
+        wave_padded.zero_()
+
+        max_phone_len = max([x[2].size(0) for x in batch])
+        phone_lengths = torch.LongTensor(len(batch))
+        phone_padded = torch.FloatTensor(
+            len(batch), max_phone_len, batch[0][2].shape[1]
+        )
+        phone_padded.zero_()
+        sid = torch.LongTensor(len(batch))
+
+        for i in range(len(ids_sorted_decreasing)):
+            row = batch[ids_sorted_decreasing[i]]
+
+            spec = row[0]
+            spec_padded[i, :, : spec.size(1)] = spec
+            spec_lengths[i] = spec.size(1)
+
+            wave = row[1]
+            wave_padded[i, :, : wave.size(1)] = wave
+            wave_lengths[i] = wave.size(1)
+
+            phone = row[2]
+            phone_padded[i, : phone.size(0), :] = phone
+            phone_lengths[i] = phone.size(0)
+
+            sid[i] = row[3]
+
+        return (
+            phone_padded,
+            phone_lengths,
+            spec_padded,
+            spec_lengths,
+            wave_padded,
+            wave_lengths,
+            sid,
+        )
+
+
+class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
+    """
+    Maintain similar input lengths in a batch.
+    Length groups are specified by boundaries.
+    Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
+
+    It removes samples which are not included in the boundaries.
+    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
+    """
+
+    def __init__(
+        self,
+        dataset,
+        batch_size,
+        boundaries,
+        num_replicas=None,
+        rank=None,
+        shuffle=True,
+    ):
+        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+        self.lengths = dataset.lengths
+        self.batch_size = batch_size
+        self.boundaries = boundaries
+
+        self.buckets, self.num_samples_per_bucket = self._create_buckets()
+        self.total_size = sum(self.num_samples_per_bucket)
+        self.num_samples = self.total_size // self.num_replicas
+
+    def _create_buckets(self):
+        buckets = [[] for _ in range(len(self.boundaries) - 1)]
+        for i in range(len(self.lengths)):
+            length = self.lengths[i]
+            idx_bucket = self._bisect(length)
+            if idx_bucket != -1:
+                buckets[idx_bucket].append(i)
+
+        for i in range(len(buckets) - 1, -1, -1):
+            if len(buckets[i]) == 0:
+                buckets.pop(i)
+                self.boundaries.pop(i + 1)
+
+        num_samples_per_bucket = []
+        for i in range(len(buckets)):
+            len_bucket = len(buckets[i])
+            total_batch_size = self.num_replicas * self.batch_size
+            rem = (
+                total_batch_size - (len_bucket % total_batch_size)
+            ) % total_batch_size
+            num_samples_per_bucket.append(len_bucket + rem)
+        return buckets, num_samples_per_bucket
+
+    def __iter__(self):
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+
+        indices = []
+        if self.shuffle:
+            for bucket in self.buckets:
+                indices.append(torch.randperm(len(bucket), generator=g).tolist())
+        else:
+            for bucket in self.buckets:
+                indices.append(list(range(len(bucket))))
+
+        batches = []
+        for i in range(len(self.buckets)):
+            bucket = self.buckets[i]
+            len_bucket = len(bucket)
+            ids_bucket = indices[i]
+            num_samples_bucket = self.num_samples_per_bucket[i]
+
+            # add extra samples to make it evenly divisible
+            rem = num_samples_bucket - len_bucket
+            ids_bucket = (
+                ids_bucket
+                + ids_bucket * (rem // len_bucket)
+                + ids_bucket[: (rem % len_bucket)]
+            )
+
+            # subsample
+            ids_bucket = ids_bucket[self.rank :: self.num_replicas]
+
+            # batching
+            for j in range(len(ids_bucket) // self.batch_size):
+                batch = [
+                    bucket[idx]
+                    for idx in ids_bucket[
+                        j * self.batch_size : (j + 1) * self.batch_size
+                    ]
+                ]
+                batches.append(batch)
+
+        if self.shuffle:
+            batch_ids = torch.randperm(len(batches), generator=g).tolist()
+            batches = [batches[i] for i in batch_ids]
+        self.batches = batches
+
+        assert len(self.batches) * self.batch_size == self.num_samples
+        return iter(self.batches)
+
+    def _bisect(self, x, lo=0, hi=None):
+        if hi is None:
+            hi = len(self.boundaries) - 1
+
+        if hi > lo:
+            mid = (hi + lo) // 2
+            if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
+                return mid
+            elif x <= self.boundaries[mid]:
+                return self._bisect(x, lo, mid)
+            else:
+                return self._bisect(x, mid + 1, hi)
+        else:
+            return -1
+
+    def __len__(self):
+        return self.num_samples // self.batch_size
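
A minimal sketch of how the three pieces above are typically wired together into a training DataLoader. The hparams values, bucket boundaries, and filelist path are illustrative assumptions, not values taken from this commit; each filelist row is expected to hold `wav|phone|pitch|pitchf|sid` fields.

from torch.utils.data import DataLoader

class SketchHParams:  # assumed values for a 40k-style setup
    max_wav_value = 32768.0
    sampling_rate = 40000
    filter_length = 2048
    hop_length = 400
    win_length = 2048

dataset = TextAudioLoaderMultiNSFsid("logs/ft-mi/filelist.txt", SketchHParams())
sampler = DistributedBucketSampler(
    dataset,
    batch_size=4,
    boundaries=[100, 200, 300, 400, 500, 600, 700, 800, 900],
    num_replicas=1,  # single process; no torch.distributed init needed
    rank=0,
    shuffle=True,
)
loader = DataLoader(
    dataset,
    num_workers=4,
    collate_fn=TextAudioCollateMultiNSFsid(),  # pads every field to the batch max
    batch_sampler=sampler,                     # groups similar lengths per batch
    pin_memory=True,
)
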
train/losses.py ADDED
@@ -0,0 +1,59 @@
+import torch
+from torch.nn import functional as F
+
+
+def feature_loss(fmap_r, fmap_g):
+    loss = 0
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            rl = rl.float().detach()
+            gl = gl.float()
+            loss += torch.mean(torch.abs(rl - gl))
+
+    return loss * 2
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+    loss = 0
+    r_losses = []
+    g_losses = []
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        dr = dr.float()
+        dg = dg.float()
+        r_loss = torch.mean((1 - dr) ** 2)
+        g_loss = torch.mean(dg**2)
+        loss += r_loss + g_loss
+        r_losses.append(r_loss.item())
+        g_losses.append(g_loss.item())
+
+    return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs):
+    loss = 0
+    gen_losses = []
+    for dg in disc_outputs:
+        dg = dg.float()
+        l = torch.mean((1 - dg) ** 2)
+        gen_losses.append(l)
+        loss += l
+
+    return loss, gen_losses
+
+
+def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
+    """
+    z_p, logs_q: [b, h, t_t]
+    m_p, logs_p: [b, h, t_t]
+    """
+    z_p = z_p.float()
+    logs_q = logs_q.float()
+    m_p = m_p.float()
+    logs_p = logs_p.float()
+    z_mask = z_mask.float()
+
+    kl = logs_p - logs_q - 0.5
+    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+    kl = torch.sum(kl * z_mask)
+    l = kl / torch.sum(z_mask)
+    return l
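
A toy sanity check of the three GAN losses above, using random tensors as stand-ins for discriminator outputs and feature maps (all shapes are illustrative only):

import torch

real_outs = [torch.rand(2, 1, 100)]   # per-discriminator outputs on real audio
fake_outs = [torch.rand(2, 1, 100)]   # per-discriminator outputs on generated audio
fmap_r = [[torch.rand(2, 16, 50)]]    # feature maps from the real forward pass
fmap_g = [[torch.rand(2, 16, 50)]]    # feature maps from the generated forward pass

loss_disc, r_losses, g_losses = discriminator_loss(real_outs, fake_outs)
loss_gen, gen_losses = generator_loss(fake_outs)
loss_fm = feature_loss(fmap_r, fmap_g)
print(loss_disc.item(), loss_gen.item(), loss_fm.item())
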
train/mel_processing.py ADDED
@@ -0,0 +1,130 @@
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+
+
+MAX_WAV_VALUE = 32768.0
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    """
+    PARAMS
+    ------
+    C: compression factor
+    """
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+    """
+    PARAMS
+    ------
+    C: compression factor used to compress
+    """
+    return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+    return dynamic_range_compression_torch(magnitudes)
+
+
+def spectral_de_normalize_torch(magnitudes):
+    return dynamic_range_decompression_torch(magnitudes)
+
+
+# Reusable banks
+mel_basis = {}
+hann_window = {}
+
+
+def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
+    """Convert waveform into Linear-frequency Linear-amplitude spectrogram.
+
+    Args:
+        y :: (B, T) - Audio waveforms
+        n_fft
+        sampling_rate
+        hop_size
+        win_size
+        center
+    Returns:
+        :: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram
+    """
+    # Validation
+    if torch.min(y) < -1.07:
+        print("min value is ", torch.min(y))
+    if torch.max(y) > 1.07:
+        print("max value is ", torch.max(y))
+
+    # Window - Cache if needed
+    global hann_window
+    dtype_device = str(y.dtype) + "_" + str(y.device)
+    wnsize_dtype_device = str(win_size) + "_" + dtype_device
+    if wnsize_dtype_device not in hann_window:
+        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
+            dtype=y.dtype, device=y.device
+        )
+
+    # Padding
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1),
+        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+        mode="reflect",
+    )
+    y = y.squeeze(1)
+
+    # Complex Spectrogram :: (B, T) -> (B, Freq, Frame, RealComplex=2)
+    spec = torch.stft(
+        y,
+        n_fft,
+        hop_length=hop_size,
+        win_length=win_size,
+        window=hann_window[wnsize_dtype_device],
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=False,
+    )
+
+    # Linear-frequency Linear-amplitude spectrogram :: (B, Freq, Frame, RealComplex=2) -> (B, Freq, Frame)
+    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+    return spec
+
+
+def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
+    # MelBasis - Cache if needed
+    global mel_basis
+    dtype_device = str(spec.dtype) + "_" + str(spec.device)
+    fmax_dtype_device = str(fmax) + "_" + dtype_device
+    if fmax_dtype_device not in mel_basis:
+        mel = librosa_mel_fn(
+            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+        )
+        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
+            dtype=spec.dtype, device=spec.device
+        )
+
+    # Mel-frequency Log-amplitude spectrogram :: (B, Freq=num_mels, Frame)
+    melspec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+    melspec = spectral_normalize_torch(melspec)
+    return melspec
+
+
+def mel_spectrogram_torch(
+    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+):
+    """Convert waveform into Mel-frequency Log-amplitude spectrogram.
+
+    Args:
+        y :: (B, T) - Waveforms
+    Returns:
+        melspec :: (B, Freq, Frame) - Mel-frequency Log-amplitude spectrogram
+    """
+    # Linear-frequency Linear-amplitude spectrogram :: (B, T) -> (B, Freq, Frame)
+    spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center)
+
+    # Mel-frequency Log-amplitude spectrogram :: (B, Freq, Frame) -> (B, Freq=num_mels, Frame)
+    melspec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax)
+
+    return melspec
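
A quick usage sketch: one second of random "audio" through the cached STFT/mel pipeline above. The STFT/mel parameters mirror a typical 40k configuration and are assumptions, not values read from this commit.

import torch

y = torch.rand(1, 40000) * 2 - 1  # (B=1, T) waveform roughly in [-1, 1]
spec = spectrogram_torch(y, n_fft=2048, sampling_rate=40000, hop_size=400, win_size=2048)
mel = mel_spectrogram_torch(
    y, n_fft=2048, num_mels=125, sampling_rate=40000,
    hop_size=400, win_size=2048, fmin=0.0, fmax=None,
)
print(spec.shape, mel.shape)  # (1, 1025, frames), (1, 125, frames)
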
train/process_ckpt.py ADDED
@@ -0,0 +1,215 @@
+import torch, traceback, os, sys
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+from collections import OrderedDict
+from i18n import I18nAuto
+
+i18n = I18nAuto()
+
+
+def savee(ckpt, sr, if_f0, name, epoch, version, hps):
+    try:
+        opt = OrderedDict()
+        opt["weight"] = {}
+        for key in ckpt.keys():
+            if "enc_q" in key:
+                continue
+            opt["weight"][key] = ckpt[key].half()
+        opt["config"] = [
+            hps.data.filter_length // 2 + 1,
+            32,
+            hps.model.inter_channels,
+            hps.model.hidden_channels,
+            hps.model.filter_channels,
+            hps.model.n_heads,
+            hps.model.n_layers,
+            hps.model.kernel_size,
+            hps.model.p_dropout,
+            hps.model.resblock,
+            hps.model.resblock_kernel_sizes,
+            hps.model.resblock_dilation_sizes,
+            hps.model.upsample_rates,
+            hps.model.upsample_initial_channel,
+            hps.model.upsample_kernel_sizes,
+            hps.model.spk_embed_dim,
+            hps.model.gin_channels,
+            hps.data.sampling_rate,
+        ]
+        opt["info"] = "%sepoch" % epoch
+        opt["sr"] = sr
+        opt["f0"] = if_f0
+        opt["version"] = version
+        torch.save(opt, "weights/%s.pth" % name)
+        return "Success."
+    except Exception:
+        return traceback.format_exc()
+
+
+def show_info(path):
+    try:
+        a = torch.load(path, map_location="cpu")
+        return "Model info: %s\nSample rate: %s\nPitch guidance (f0): %s\nVersion: %s" % (
+            a.get("info", "None"),
+            a.get("sr", "None"),
+            a.get("f0", "None"),
+            a.get("version", "None"),
+        )
+    except Exception:
+        return traceback.format_exc()
+
+
+def extract_small_model(path, name, sr, if_f0, info, version):
+    try:
+        ckpt = torch.load(path, map_location="cpu")
+        if "model" in ckpt:
+            ckpt = ckpt["model"]
+        opt = OrderedDict()
+        opt["weight"] = {}
+        for key in ckpt.keys():
+            if "enc_q" in key:
+                continue
+            opt["weight"][key] = ckpt[key].half()
+        if sr == "40k":
+            opt["config"] = [
+                1025,
+                32,
+                192,
+                192,
+                768,
+                2,
+                6,
+                3,
+                0,
+                "1",
+                [3, 7, 11],
+                [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                [10, 10, 2, 2],
+                512,
+                [16, 16, 4, 4],
+                109,
+                256,
+                40000,
+            ]
+        elif sr == "48k":
+            opt["config"] = [
+                1025,
+                32,
+                192,
+                192,
+                768,
+                2,
+                6,
+                3,
+                0,
+                "1",
+                [3, 7, 11],
+                [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                [10, 6, 2, 2, 2],
+                512,
+                [16, 16, 4, 4, 4],
+                109,
+                256,
+                48000,
+            ]
+        elif sr == "32k":
+            opt["config"] = [
+                513,
+                32,
+                192,
+                192,
+                768,
+                2,
+                6,
+                3,
+                0,
+                "1",
+                [3, 7, 11],
+                [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                [10, 4, 2, 2, 2],
+                512,
+                [16, 16, 4, 4, 4],
+                109,
+                256,
+                32000,
+            ]
+        if info == "":
+            info = "Extracted model."
+        opt["info"] = info
+        opt["version"] = version
+        opt["sr"] = sr
+        opt["f0"] = int(if_f0)
+        torch.save(opt, "weights/%s.pth" % name)
+        return "Success."
+    except Exception:
+        return traceback.format_exc()
+
+
+def change_info(path, info, name):
+    try:
+        ckpt = torch.load(path, map_location="cpu")
+        ckpt["info"] = info
+        if name == "":
+            name = os.path.basename(path)
+        torch.save(ckpt, "weights/%s" % name)
+        return "Success."
+    except Exception:
+        return traceback.format_exc()
+
+
+def merge(path1, path2, alpha1, sr, f0, info, name, version):
+    try:
+
+        def extract(ckpt):
+            a = ckpt["model"]
+            opt = OrderedDict()
+            opt["weight"] = {}
+            for key in a.keys():
+                if "enc_q" in key:
+                    continue
+                opt["weight"][key] = a[key]
+            return opt
+
+        ckpt1 = torch.load(path1, map_location="cpu")
+        ckpt2 = torch.load(path2, map_location="cpu")
+        cfg = ckpt1["config"]
+        if "model" in ckpt1:
+            ckpt1 = extract(ckpt1)
+        else:
+            ckpt1 = ckpt1["weight"]
+        if "model" in ckpt2:
+            ckpt2 = extract(ckpt2)
+        else:
+            ckpt2 = ckpt2["weight"]
+        if sorted(list(ckpt1.keys())) != sorted(list(ckpt2.keys())):
+            return "Failed to merge the models. The model architectures are not the same."
+        opt = OrderedDict()
+        opt["weight"] = {}
+        for key in ckpt1.keys():
+            if key == "emb_g.weight" and ckpt1[key].shape != ckpt2[key].shape:
+                min_shape0 = min(ckpt1[key].shape[0], ckpt2[key].shape[0])
+                opt["weight"][key] = (
+                    alpha1 * (ckpt1[key][:min_shape0].float())
+                    + (1 - alpha1) * (ckpt2[key][:min_shape0].float())
+                ).half()
+            else:
+                opt["weight"][key] = (
+                    alpha1 * (ckpt1[key].float()) + (1 - alpha1) * (ckpt2[key].float())
+                ).half()
+        opt["config"] = cfg
+        opt["sr"] = sr
+        opt["f0"] = 1 if f0 == i18n("是") else 0  # "是" = "Yes"
+        opt["version"] = version
+        opt["info"] = info
+        torch.save(opt, "weights/%s.pth" % name)
+        return "Success."
+    except Exception:
+        return traceback.format_exc()
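
Hypothetical usage of the checkpoint helpers above. The paths are placeholders; both writing functions save into a local "weights/" directory, which must already exist, and importing this module requires the repo's i18n package on the path.

print(show_info("logs/ft-mi/G_2333333.pth"))
extract_small_model(
    "logs/ft-mi/G_2333333.pth",  # full training checkpoint (placeholder path)
    "ft-mi",                     # output name -> weights/ft-mi.pth
    "40k",                       # sample-rate tag selects the hardcoded config
    if_f0=1,
    info="",
    version="v1",
)
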
train/utils.py ADDED
@@ -0,0 +1,486 @@
+import os, traceback
+import glob
+import sys
+import argparse
+import logging
+import json
+import subprocess
+import numpy as np
+from scipy.io.wavfile import read
+import torch
+
+MATPLOTLIB_FLAG = False
+
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
+logger = logging
+
+
+def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
+    assert os.path.isfile(checkpoint_path)
+    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
+
+    def go(model, bkey):
+        saved_state_dict = checkpoint_dict[bkey]
+        if hasattr(model, "module"):
+            state_dict = model.module.state_dict()
+        else:
+            state_dict = model.state_dict()
+        new_state_dict = {}
+        for k, v in state_dict.items():  # iterate over the shapes the model expects
+            try:
+                new_state_dict[k] = saved_state_dict[k]
+                if saved_state_dict[k].shape != state_dict[k].shape:
+                    print(
+                        "shape-%s-mismatch|need-%s|get-%s"
+                        % (k, state_dict[k].shape, saved_state_dict[k].shape)
+                    )
+                    raise KeyError
+            except Exception:
+                logger.info("%s is not in the checkpoint" % k)  # missing from the pretrained checkpoint
+                new_state_dict[k] = v  # keep the model's own randomly initialized value
+        if hasattr(model, "module"):
+            model.module.load_state_dict(new_state_dict, strict=False)
+        else:
+            model.load_state_dict(new_state_dict, strict=False)
+
+    go(combd, "combd")
+    go(sbd, "sbd")
+    logger.info("Loaded model weights")
+
+    iteration = checkpoint_dict["iteration"]
+    learning_rate = checkpoint_dict["learning_rate"]
+    if (
+        optimizer is not None and load_opt == 1
+    ):  # if the optimizer state cannot be loaded (e.g. it is empty), it gets reinitialized; that can also disturb the LR schedule update, so errors are caught at the outermost level of the train script
+        optimizer.load_state_dict(checkpoint_dict["optimizer"])
+    logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
+    return combd, optimizer, learning_rate, iteration
+
+
+def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
+    assert os.path.isfile(checkpoint_path)
+    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
+
+    saved_state_dict = checkpoint_dict["model"]
+    if hasattr(model, "module"):
+        state_dict = model.module.state_dict()
+    else:
+        state_dict = model.state_dict()
+    new_state_dict = {}
+    for k, v in state_dict.items():  # iterate over the shapes the model expects
+        try:
+            new_state_dict[k] = saved_state_dict[k]
+            if saved_state_dict[k].shape != state_dict[k].shape:
+                print(
+                    "shape-%s-mismatch|need-%s|get-%s"
+                    % (k, state_dict[k].shape, saved_state_dict[k].shape)
+                )
+                raise KeyError
+        except Exception:
+            logger.info("%s is not in the checkpoint" % k)  # missing from the pretrained checkpoint
+            new_state_dict[k] = v  # keep the model's own randomly initialized value
+    if hasattr(model, "module"):
+        model.module.load_state_dict(new_state_dict, strict=False)
+    else:
+        model.load_state_dict(new_state_dict, strict=False)
+    logger.info("Loaded model weights")
+
+    iteration = checkpoint_dict["iteration"]
+    learning_rate = checkpoint_dict["learning_rate"]
+    if (
+        optimizer is not None and load_opt == 1
+    ):  # same caveat as above: optimizer-state load failures are caught by the caller
+        optimizer.load_state_dict(checkpoint_dict["optimizer"])
+    logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
+    return model, optimizer, learning_rate, iteration
+
+
+def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
+    logger.info(
+        "Saving model and optimizer state at epoch {} to {}".format(
+            iteration, checkpoint_path
+        )
+    )
+    if hasattr(model, "module"):
+        state_dict = model.module.state_dict()
+    else:
+        state_dict = model.state_dict()
+    torch.save(
+        {
+            "model": state_dict,
+            "iteration": iteration,
+            "optimizer": optimizer.state_dict(),
+            "learning_rate": learning_rate,
+        },
+        checkpoint_path,
+    )
+
+
+def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
+    logger.info(
+        "Saving model and optimizer state at epoch {} to {}".format(
+            iteration, checkpoint_path
+        )
+    )
+    if hasattr(combd, "module"):
+        state_dict_combd = combd.module.state_dict()
+    else:
+        state_dict_combd = combd.state_dict()
+    if hasattr(sbd, "module"):
+        state_dict_sbd = sbd.module.state_dict()
+    else:
+        state_dict_sbd = sbd.state_dict()
+    torch.save(
+        {
+            "combd": state_dict_combd,
+            "sbd": state_dict_sbd,
+            "iteration": iteration,
+            "optimizer": optimizer.state_dict(),
+            "learning_rate": learning_rate,
+        },
+        checkpoint_path,
+    )
+
+
+def summarize(
+    writer,
+    global_step,
+    scalars={},
+    histograms={},
+    images={},
+    audios={},
+    audio_sampling_rate=22050,
+):
+    for k, v in scalars.items():
+        writer.add_scalar(k, v, global_step)
+    for k, v in histograms.items():
+        writer.add_histogram(k, v, global_step)
+    for k, v in images.items():
+        writer.add_image(k, v, global_step, dataformats="HWC")
+    for k, v in audios.items():
+        writer.add_audio(k, v, global_step, audio_sampling_rate)
+
+
+def latest_checkpoint_path(dir_path, regex="G_*.pth"):
+    f_list = glob.glob(os.path.join(dir_path, regex))
+    f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
+    x = f_list[-1]
+    print(x)
+    return x
+
+
+def plot_spectrogram_to_numpy(spectrogram):
+    global MATPLOTLIB_FLAG
+    if not MATPLOTLIB_FLAG:
+        import matplotlib
+
+        matplotlib.use("Agg")
+        MATPLOTLIB_FLAG = True
+        mpl_logger = logging.getLogger("matplotlib")
+        mpl_logger.setLevel(logging.WARNING)
+    import matplotlib.pylab as plt
+    import numpy as np
+
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+    plt.colorbar(im, ax=ax)
+    plt.xlabel("Frames")
+    plt.ylabel("Channels")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
+
+
+def plot_alignment_to_numpy(alignment, info=None):
+    global MATPLOTLIB_FLAG
+    if not MATPLOTLIB_FLAG:
+        import matplotlib
+
+        matplotlib.use("Agg")
+        MATPLOTLIB_FLAG = True
+        mpl_logger = logging.getLogger("matplotlib")
+        mpl_logger.setLevel(logging.WARNING)
+    import matplotlib.pylab as plt
+    import numpy as np
+
+    fig, ax = plt.subplots(figsize=(6, 4))
+    im = ax.imshow(
+        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
+    )
+    fig.colorbar(im, ax=ax)
+    xlabel = "Decoder timestep"
+    if info is not None:
+        xlabel += "\n\n" + info
+    plt.xlabel(xlabel)
+    plt.ylabel("Encoder timestep")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
+
+
+def load_wav_to_torch(full_path):
+    sampling_rate, data = read(full_path)
+    return torch.FloatTensor(data.astype(np.float32)), sampling_rate
+
+
+def load_filepaths_and_text(filename, split="|"):
+    with open(filename, encoding="utf-8") as f:
+        filepaths_and_text = [line.strip().split(split) for line in f]
+    return filepaths_and_text
+
+
+def get_hparams(init=True):
+    """
+    todo:
+      the final group of seven options:
+        save frequency, total epochs: done
+        batch size: done
+        pretrainG, pretrainD: done
+        GPU ids: os.en["CUDA_VISIBLE_DEVICES"]: done
+        if_latest: done
+        model: if_f0: done
+        sample rate: auto-select the matching config: done
+        whether to cache the dataset in GPU memory: if_cache_data_in_gpu: done
+
+      -m:
+        derive the training_files path automatically, replacing hps.data.training_files in train_nsf_load_pretrain.py: done
+      -c is no longer needed
+    """
+    parser = argparse.ArgumentParser()
+    # parser.add_argument('-c', '--config', type=str, default="configs/40k.json", help='JSON file for configuration')
+    parser.add_argument(
+        "-se",
+        "--save_every_epoch",
+        type=int,
+        required=True,
+        help="checkpoint save frequency (epoch)",
+    )
+    parser.add_argument(
+        "-te", "--total_epoch", type=int, required=True, help="total_epoch"
+    )
+    parser.add_argument(
+        "-pg", "--pretrainG", type=str, default="", help="Pretrained Generator path"
+    )
+    parser.add_argument(
+        "-pd", "--pretrainD", type=str, default="", help="Pretrained Discriminator path"
+    )
+    parser.add_argument("-g", "--gpus", type=str, default="0", help="GPU ids, split by -")
+    parser.add_argument(
+        "-bs", "--batch_size", type=int, required=True, help="batch size"
+    )
+    parser.add_argument(
+        "-e", "--experiment_dir", type=str, required=True, help="experiment dir"
+    )  # formerly -m
+    parser.add_argument(
+        "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k"
+    )
+    parser.add_argument(
+        "-sw",
+        "--save_every_weights",
+        type=str,
+        default="0",
+        help="save the extracted model in weights directory when saving checkpoints",
+    )
+    parser.add_argument(
+        "-v", "--version", type=str, required=True, help="model version"
+    )
+    parser.add_argument(
+        "-f0",
+        "--if_f0",
+        type=int,
+        required=True,
+        help="use f0 as one of the inputs of the model, 1 or 0",
+    )
+    parser.add_argument(
+        "-l",
+        "--if_latest",
+        type=int,
+        required=True,
+        help="if only save the latest G/D pth file, 1 or 0",
+    )
+    parser.add_argument(
+        "-c",
+        "--if_cache_data_in_gpu",
+        type=int,
+        required=True,
+        help="if caching the dataset in GPU memory, 1 or 0",
+    )
+
+    args = parser.parse_args()
+    name = args.experiment_dir
+    experiment_dir = os.path.join("./logs", args.experiment_dir)
+
+    if not os.path.exists(experiment_dir):
+        os.makedirs(experiment_dir)
+
+    if args.version == "v1" or args.sample_rate == "40k":
+        config_path = "configs/%s.json" % args.sample_rate
+    else:
+        config_path = "configs/%s_v2.json" % args.sample_rate
+    config_save_path = os.path.join(experiment_dir, "config.json")
+    if init:
+        with open(config_path, "r") as f:
+            data = f.read()
+        with open(config_save_path, "w") as f:
+            f.write(data)
+    else:
+        with open(config_save_path, "r") as f:
+            data = f.read()
+    config = json.loads(data)
+
+    hparams = HParams(**config)
+    hparams.model_dir = hparams.experiment_dir = experiment_dir
+    hparams.save_every_epoch = args.save_every_epoch
+    hparams.name = name
+    hparams.total_epoch = args.total_epoch
+    hparams.pretrainG = args.pretrainG
+    hparams.pretrainD = args.pretrainD
+    hparams.version = args.version
+    hparams.gpus = args.gpus
+    hparams.train.batch_size = args.batch_size
+    hparams.sample_rate = args.sample_rate
+    hparams.if_f0 = args.if_f0
+    hparams.if_latest = args.if_latest
+    hparams.save_every_weights = args.save_every_weights
+    hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu
+    hparams.data.training_files = "%s/filelist.txt" % experiment_dir
+    return hparams
+
+
+def get_hparams_from_dir(model_dir):
+    config_save_path = os.path.join(model_dir, "config.json")
+    with open(config_save_path, "r") as f:
+        data = f.read()
+    config = json.loads(data)
+
+    hparams = HParams(**config)
+    hparams.model_dir = model_dir
+    return hparams
+
+
+def get_hparams_from_file(config_path):
+    with open(config_path, "r") as f:
+        data = f.read()
+    config = json.loads(data)
+
+    hparams = HParams(**config)
+    return hparams
+
+
+def check_git_hash(model_dir):
+    source_dir = os.path.dirname(os.path.realpath(__file__))
+    if not os.path.exists(os.path.join(source_dir, ".git")):
+        logger.warning(
+            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
+                source_dir
+            )
+        )
+        return
+
+    cur_hash = subprocess.getoutput("git rev-parse HEAD")
+
+    path = os.path.join(model_dir, "githash")
+    if os.path.exists(path):
+        saved_hash = open(path).read()
+        if saved_hash != cur_hash:
+            logger.warning(
+                "git hash values are different. {}(saved) != {}(current)".format(
+                    saved_hash[:8], cur_hash[:8]
+                )
+            )
+    else:
+        open(path, "w").write(cur_hash)
+
+
+def get_logger(model_dir, filename="train.log"):
+    global logger
+    logger = logging.getLogger(os.path.basename(model_dir))
+    logger.setLevel(logging.DEBUG)
+
+    formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    h = logging.FileHandler(os.path.join(model_dir, filename))
+    h.setLevel(logging.DEBUG)
+    h.setFormatter(formatter)
+    logger.addHandler(h)
+    return logger
+
+
+class HParams:
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            if isinstance(v, dict):
+                v = HParams(**v)
+            self[k] = v
+
+    def keys(self):
+        return self.__dict__.keys()
+
+    def items(self):
+        return self.__dict__.items()
+
+    def values(self):
+        return self.__dict__.values()
+
+    def __len__(self):
+        return len(self.__dict__)
+
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        return setattr(self, key, value)
+
+    def __contains__(self, key):
+        return key in self.__dict__
+
+    def __repr__(self):
+        return self.__dict__.__repr__()
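
get_hparams above is driven entirely by command-line flags (for example `-se 5 -te 200 -bs 4 -e ft-mi -sr 40k -v v1 -f0 1 -l 0 -c 0`, plus optional -pg/-pd/-g/-sw), and the loaded JSON config is wrapped in HParams. A small demonstration of the HParams wrapper, using made-up values: nested dicts become nested HParams objects with both attribute and item access.

hps = HParams(
    train={"batch_size": 4, "learning_rate": 1e-4},
    data={"sampling_rate": 40000},
)
assert hps.train.batch_size == hps["train"]["batch_size"] == 4
hps.train.batch_size = 8  # attribute assignment works like a normal object
print("sampling_rate" in hps.data, len(hps))  # -> True 2
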