aishams commited on
Commit
89556c8
·
1 Parent(s): ec0fd54

Upload 17 files

Browse files
config (2).json ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 200,
5
+ "seed": 1234,
6
+ "epochs": 10000,
7
+ "learning_rate": 0.0001,
8
+ "betas": [
9
+ 0.8,
10
+ 0.99
11
+ ],
12
+ "eps": 1e-09,
13
+ "batch_size": 16,
14
+ "fp16_run": false,
15
+ "bf16_run": false,
16
+ "lr_decay": 0.999875,
17
+ "segment_size": 10240,
18
+ "init_lr_ratio": 1,
19
+ "warmup_epochs": 0,
20
+ "c_mel": 45,
21
+ "c_kl": 1.0,
22
+ "use_sr": true,
23
+ "max_speclen": 512,
24
+ "port": "8001",
25
+ "keep_ckpts": 3,
26
+ "num_workers": 4,
27
+ "log_version": 0,
28
+ "ckpt_name_by_step": false,
29
+ "accumulate_grad_batches": 1
30
+ },
31
+ "data": {
32
+ "training_files": "filelists/44k/train.txt",
33
+ "validation_files": "filelists/44k/val.txt",
34
+ "max_wav_value": 32768.0,
35
+ "sampling_rate": 44100,
36
+ "filter_length": 2048,
37
+ "hop_length": 512,
38
+ "win_length": 2048,
39
+ "n_mel_channels": 80,
40
+ "mel_fmin": 0.0,
41
+ "mel_fmax": 22050,
42
+ "contentvec_final_proj": false
43
+ },
44
+ "model": {
45
+ "inter_channels": 192,
46
+ "hidden_channels": 192,
47
+ "filter_channels": 768,
48
+ "n_heads": 2,
49
+ "n_layers": 6,
50
+ "kernel_size": 3,
51
+ "p_dropout": 0.1,
52
+ "resblock": "1",
53
+ "resblock_kernel_sizes": [
54
+ 3,
55
+ 7,
56
+ 11
57
+ ],
58
+ "resblock_dilation_sizes": [
59
+ [
60
+ 1,
61
+ 3,
62
+ 5
63
+ ],
64
+ [
65
+ 1,
66
+ 3,
67
+ 5
68
+ ],
69
+ [
70
+ 1,
71
+ 3,
72
+ 5
73
+ ]
74
+ ],
75
+ "upsample_rates": [
76
+ 8,
77
+ 8,
78
+ 2,
79
+ 2,
80
+ 2
81
+ ],
82
+ "upsample_initial_channel": 512,
83
+ "upsample_kernel_sizes": [
84
+ 16,
85
+ 16,
86
+ 4,
87
+ 4,
88
+ 4
89
+ ],
90
+ "n_layers_q": 3,
91
+ "use_spectral_norm": false,
92
+ "gin_channels": 256,
93
+ "ssl_dim": 768,
94
+ "n_speakers": 200,
95
+ "type_": "hifi-gan",
96
+ "pretrained": {
97
+ "D_0.pth": "https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_D_320000.pth",
98
+ "G_0.pth": "https://huggingface.co/datasets/ms903/sovits4.0-768vec-layer12/resolve/main/sovits_768l12_pre_large_320k/clean_G_320000.pth"
99
+ }
100
+ },
101
+ "spk": {
102
+ "Mr.ameli": 0
103
+ }
104
+ }
data_utils.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import os
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import torch.utils.data
7
+
8
+ import modules.commons as commons
9
+ import utils
10
+ from modules.mel_processing import spectrogram_torch, spec_to_mel_torch
11
+ from utils import load_wav_to_torch, load_filepaths_and_text
12
+
13
+ # import h5py
14
+
15
+
16
+ """Multi speaker version"""
17
+
18
+
19
+ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
20
+ """
21
+ 1) loads audio, speaker_id, text pairs
22
+ 2) normalizes text and converts them to sequences of integers
23
+ 3) computes spectrograms from audio files.
24
+ """
25
+
26
+ def __init__(self, audiopaths, hparams):
27
+ self.audiopaths = load_filepaths_and_text(audiopaths)
28
+ self.max_wav_value = hparams.data.max_wav_value
29
+ self.sampling_rate = hparams.data.sampling_rate
30
+ self.filter_length = hparams.data.filter_length
31
+ self.hop_length = hparams.data.hop_length
32
+ self.win_length = hparams.data.win_length
33
+ self.sampling_rate = hparams.data.sampling_rate
34
+ self.use_sr = hparams.train.use_sr
35
+ self.spec_len = hparams.train.max_speclen
36
+ self.spk_map = hparams.spk
37
+
38
+ random.seed(1234)
39
+ random.shuffle(self.audiopaths)
40
+
41
+ def get_audio(self, filename):
42
+ filename = filename.replace("\\", "/")
43
+ audio, sampling_rate = load_wav_to_torch(filename)
44
+ if sampling_rate != self.sampling_rate:
45
+ raise ValueError("{} SR doesn't match target {} SR".format(
46
+ sampling_rate, self.sampling_rate))
47
+ audio_norm = audio / self.max_wav_value
48
+ audio_norm = audio_norm.unsqueeze(0)
49
+ spec_filename = filename.replace(".wav", ".spec.pt")
50
+ if os.path.exists(spec_filename):
51
+ spec = torch.load(spec_filename)
52
+ else:
53
+ spec = spectrogram_torch(audio_norm, self.filter_length,
54
+ self.sampling_rate, self.hop_length, self.win_length,
55
+ center=False)
56
+ spec = torch.squeeze(spec, 0)
57
+ torch.save(spec, spec_filename)
58
+
59
+ spk = filename.split("/")[-2]
60
+ spk = torch.LongTensor([self.spk_map[spk]])
61
+
62
+ f0 = np.load(filename + ".f0.npy")
63
+ f0, uv = utils.interpolate_f0(f0)
64
+ f0 = torch.FloatTensor(f0)
65
+ uv = torch.FloatTensor(uv)
66
+
67
+ c = torch.load(filename+ ".soft.pt")
68
+ c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0])
69
+
70
+
71
+ lmin = min(c.size(-1), spec.size(-1))
72
+ assert abs(c.size(-1) - spec.size(-1)) < 3, (c.size(-1), spec.size(-1), f0.shape, filename)
73
+ assert abs(audio_norm.shape[1]-lmin * self.hop_length) < 3 * self.hop_length
74
+ spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin]
75
+ audio_norm = audio_norm[:, :lmin * self.hop_length]
76
+ # if spec.shape[1] < 30:
77
+ # print("skip too short audio:", filename)
78
+ # return None
79
+ if spec.shape[1] > 800:
80
+ start = random.randint(0, spec.shape[1]-800)
81
+ end = start + 790
82
+ spec, c, f0, uv = spec[:, start:end], c[:, start:end], f0[start:end], uv[start:end]
83
+ audio_norm = audio_norm[:, start * self.hop_length : end * self.hop_length]
84
+
85
+ return c, f0, spec, audio_norm, spk, uv
86
+
87
+ def __getitem__(self, index):
88
+ return self.get_audio(self.audiopaths[index][0])
89
+
90
+ def __len__(self):
91
+ return len(self.audiopaths)
92
+
93
+
94
+ class TextAudioCollate:
95
+
96
+ def __call__(self, batch):
97
+ batch = [b for b in batch if b is not None]
98
+
99
+ input_lengths, ids_sorted_decreasing = torch.sort(
100
+ torch.LongTensor([x[0].shape[1] for x in batch]),
101
+ dim=0, descending=True)
102
+
103
+ max_c_len = max([x[0].size(1) for x in batch])
104
+ max_wav_len = max([x[3].size(1) for x in batch])
105
+
106
+ lengths = torch.LongTensor(len(batch))
107
+
108
+ c_padded = torch.FloatTensor(len(batch), batch[0][0].shape[0], max_c_len)
109
+ f0_padded = torch.FloatTensor(len(batch), max_c_len)
110
+ spec_padded = torch.FloatTensor(len(batch), batch[0][2].shape[0], max_c_len)
111
+ wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
112
+ spkids = torch.LongTensor(len(batch), 1)
113
+ uv_padded = torch.FloatTensor(len(batch), max_c_len)
114
+
115
+ c_padded.zero_()
116
+ spec_padded.zero_()
117
+ f0_padded.zero_()
118
+ wav_padded.zero_()
119
+ uv_padded.zero_()
120
+
121
+ for i in range(len(ids_sorted_decreasing)):
122
+ row = batch[ids_sorted_decreasing[i]]
123
+
124
+ c = row[0]
125
+ c_padded[i, :, :c.size(1)] = c
126
+ lengths[i] = c.size(1)
127
+
128
+ f0 = row[1]
129
+ f0_padded[i, :f0.size(0)] = f0
130
+
131
+ spec = row[2]
132
+ spec_padded[i, :, :spec.size(1)] = spec
133
+
134
+ wav = row[3]
135
+ wav_padded[i, :, :wav.size(1)] = wav
136
+
137
+ spkids[i, 0] = row[4]
138
+
139
+ uv = row[5]
140
+ uv_padded[i, :uv.size(0)] = uv
141
+
142
+ return c_padded, f0_padded, spec_padded, wav_padded, spkids, lengths, uv_padded
flask_api.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import logging
3
+
4
+ import soundfile
5
+ import torch
6
+ import torchaudio
7
+ from flask import Flask, request, send_file
8
+ from flask_cors import CORS
9
+
10
+ from inference.infer_tool import Svc, RealTimeVC
11
+
12
+ app = Flask(__name__)
13
+
14
+ CORS(app)
15
+
16
+ logging.getLogger('numba').setLevel(logging.WARNING)
17
+
18
+
19
+ @app.route("/voiceChangeModel", methods=["POST"])
20
+ def voice_change_model():
21
+ request_form = request.form
22
+ wave_file = request.files.get("sample", None)
23
+ # 变调信息
24
+ f_pitch_change = float(request_form.get("fPitchChange", 0))
25
+ # DAW所需的采样率
26
+ daw_sample = int(float(request_form.get("sampleRate", 0)))
27
+ speaker_id = int(float(request_form.get("sSpeakId", 0)))
28
+ # http获得wav文件并转换
29
+ input_wav_path = io.BytesIO(wave_file.read())
30
+
31
+ # 模型推理
32
+ if raw_infer:
33
+ out_audio, out_sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path)
34
+ tar_audio = torchaudio.functional.resample(out_audio, svc_model.target_sample, daw_sample)
35
+ else:
36
+ out_audio = svc.process(svc_model, speaker_id, f_pitch_change, input_wav_path)
37
+ tar_audio = torchaudio.functional.resample(torch.from_numpy(out_audio), svc_model.target_sample, daw_sample)
38
+ # 返回音频
39
+ out_wav_path = io.BytesIO()
40
+ soundfile.write(out_wav_path, tar_audio.cpu().numpy(), daw_sample, format="wav")
41
+ out_wav_path.seek(0)
42
+ return send_file(out_wav_path, download_name="temp.wav", as_attachment=True)
43
+
44
+
45
+ if __name__ == '__main__':
46
+ # 启用则为直接切片合成,False为交叉淡化方式
47
+ # vst插件调整0.3-0.5s切片时间可以降低延迟,直接切片方法会有连接处爆音、交叉淡化会有轻微重叠声音
48
+ # 自行选择能接受的方法,或将vst最大切片时间调整为1s,此处设为Ture,延迟大音质稳定一些
49
+ raw_infer = True
50
+ # 每个模型和config是唯一对应的
51
+ model_name = "logs/32k/G_174000-Copy1.pth"
52
+ config_name = "configs/config.json"
53
+ svc_model = Svc(model_name, config_name)
54
+ svc = RealTimeVC()
55
+ # 此处与vst插件对应,不建议更改
56
+ app.run(port=6842, host="0.0.0.0", debug=False, threaded=False)
flask_api_full_song.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import numpy as np
3
+ import soundfile
4
+ from flask import Flask, request, send_file
5
+
6
+ from inference import infer_tool
7
+ from inference import slicer
8
+
9
+ app = Flask(__name__)
10
+
11
+
12
+ @app.route("/wav2wav", methods=["POST"])
13
+ def wav2wav():
14
+ request_form = request.form
15
+ audio_path = request_form.get("audio_path", None) # wav文件地址
16
+ tran = int(float(request_form.get("tran", 0))) # 音调
17
+ spk = request_form.get("spk", 0) # 说话人(id或者name都可以,具体看你的config)
18
+ wav_format = request_form.get("wav_format", 'wav') # 范围文件格式
19
+ infer_tool.format_wav(audio_path)
20
+ chunks = slicer.cut(audio_path, db_thresh=-40)
21
+ audio_data, audio_sr = slicer.chunks2audio(audio_path, chunks)
22
+
23
+ audio = []
24
+ for (slice_tag, data) in audio_data:
25
+ print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
26
+
27
+ length = int(np.ceil(len(data) / audio_sr * svc_model.target_sample))
28
+ if slice_tag:
29
+ print('jump empty segment')
30
+ _audio = np.zeros(length)
31
+ else:
32
+ # padd
33
+ pad_len = int(audio_sr * 0.5)
34
+ data = np.concatenate([np.zeros([pad_len]), data, np.zeros([pad_len])])
35
+ raw_path = io.BytesIO()
36
+ soundfile.write(raw_path, data, audio_sr, format="wav")
37
+ raw_path.seek(0)
38
+ out_audio, out_sr = svc_model.infer(spk, tran, raw_path)
39
+ svc_model.clear_empty()
40
+ _audio = out_audio.cpu().numpy()
41
+ pad_len = int(svc_model.target_sample * 0.5)
42
+ _audio = _audio[pad_len:-pad_len]
43
+
44
+ audio.extend(list(infer_tool.pad_array(_audio, length)))
45
+ out_wav_path = io.BytesIO()
46
+ soundfile.write(out_wav_path, audio, svc_model.target_sample, format=wav_format)
47
+ out_wav_path.seek(0)
48
+ return send_file(out_wav_path, download_name=f"temp.{wav_format}", as_attachment=True)
49
+
50
+
51
+ if __name__ == '__main__':
52
+ model_name = "logs/44k/G_60000.pth" # 模型地址
53
+ config_name = "configs/config.json" # config地址
54
+ svc_model = infer_tool.Svc(model_name, config_name)
55
+ app.run(port=1145, host="0.0.0.0", debug=False, threaded=False)
inference_main.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import logging
3
+ import time
4
+ from pathlib import Path
5
+
6
+ import librosa
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import soundfile
10
+
11
+ from inference import infer_tool
12
+ from inference import slicer
13
+ from inference.infer_tool import Svc
14
+
15
+ logging.getLogger('numba').setLevel(logging.WARNING)
16
+ chunks_dict = infer_tool.read_temp("inference/chunks_temp.json")
17
+
18
+
19
+
20
+ def main():
21
+ import argparse
22
+
23
+ parser = argparse.ArgumentParser(description='sovits4 inference')
24
+
25
+ # 一定要设置的部分
26
+ parser.add_argument('-m', '--model_path', type=str, default="logs/44k/G_0.pth", help='模型路径')
27
+ parser.add_argument('-c', '--config_path', type=str, default="configs/config.json", help='配置文件路径')
28
+ parser.add_argument('-n', '--clean_names', type=str, nargs='+', default=["君の知らない物語-src.wav"], help='wav文件名列表,放在raw文件夹下')
29
+ parser.add_argument('-t', '--trans', type=int, nargs='+', default=[0], help='音高调整,支持正负(半音)')
30
+ parser.add_argument('-s', '--spk_list', type=str, nargs='+', default=['nen'], help='合成目标说话人名称')
31
+
32
+ # 可选项部分
33
+ parser.add_argument('-a', '--auto_predict_f0', action='store_true', default=False,
34
+ help='语音转换自动预测音高,转换歌声时不要打开这个会严重跑调')
35
+ parser.add_argument('-cm', '--cluster_model_path', type=str, default="logs/44k/kmeans_10000.pt", help='聚类模型路径,如果没有训练聚类则随便填')
36
+ parser.add_argument('-cr', '--cluster_infer_ratio', type=float, default=0, help='聚类方案占比,范围0-1,若没有训练聚类模型则填0即可')
37
+
38
+ # 不用动的部分
39
+ parser.add_argument('-sd', '--slice_db', type=int, default=-40, help='默认-40,嘈杂的音频可以-30,干声保留呼吸可以-50')
40
+ parser.add_argument('-d', '--device', type=str, default=None, help='推理设备,None则为自动选择cpu和gpu')
41
+ parser.add_argument('-ns', '--noice_scale', type=float, default=0.4, help='噪音级别,会影响咬字和音质,较为玄学')
42
+ parser.add_argument('-p', '--pad_seconds', type=float, default=0.5, help='推理音频pad秒数,由于未知原因开头结尾会有异响,pad一小段静音段后就不会出现')
43
+ parser.add_argument('-wf', '--wav_format', type=str, default='flac', help='音频输出格式')
44
+
45
+ args = parser.parse_args()
46
+
47
+ svc_model = Svc(args.model_path, args.config_path, args.device, args.cluster_model_path)
48
+ infer_tool.mkdir(["raw", "results"])
49
+ clean_names = args.clean_names
50
+ trans = args.trans
51
+ spk_list = args.spk_list
52
+ slice_db = args.slice_db
53
+ wav_format = args.wav_format
54
+ auto_predict_f0 = args.auto_predict_f0
55
+ cluster_infer_ratio = args.cluster_infer_ratio
56
+ noice_scale = args.noice_scale
57
+ pad_seconds = args.pad_seconds
58
+
59
+ infer_tool.fill_a_to_b(trans, clean_names)
60
+ for clean_name, tran in zip(clean_names, trans):
61
+ raw_audio_path = f"raw/{clean_name}"
62
+ if "." not in raw_audio_path:
63
+ raw_audio_path += ".wav"
64
+ infer_tool.format_wav(raw_audio_path)
65
+ wav_path = Path(raw_audio_path).with_suffix('.wav')
66
+ chunks = slicer.cut(wav_path, db_thresh=slice_db)
67
+ audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks)
68
+
69
+ for spk in spk_list:
70
+ audio = []
71
+ for (slice_tag, data) in audio_data:
72
+ print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
73
+
74
+ length = int(np.ceil(len(data) / audio_sr * svc_model.target_sample))
75
+ if slice_tag:
76
+ print('jump empty segment')
77
+ _audio = np.zeros(length)
78
+ else:
79
+ # padd
80
+ pad_len = int(audio_sr * pad_seconds)
81
+ data = np.concatenate([np.zeros([pad_len]), data, np.zeros([pad_len])])
82
+ raw_path = io.BytesIO()
83
+ soundfile.write(raw_path, data, audio_sr, format="wav")
84
+ raw_path.seek(0)
85
+ out_audio, out_sr = svc_model.infer(spk, tran, raw_path,
86
+ cluster_infer_ratio=cluster_infer_ratio,
87
+ auto_predict_f0=auto_predict_f0,
88
+ noice_scale=noice_scale
89
+ )
90
+ _audio = out_audio.cpu().numpy()
91
+ pad_len = int(svc_model.target_sample * pad_seconds)
92
+ _audio = _audio[pad_len:-pad_len]
93
+
94
+ audio.extend(list(infer_tool.pad_array(_audio, length)))
95
+ key = "auto" if auto_predict_f0 else f"{tran}key"
96
+ cluster_name = "" if cluster_infer_ratio == 0 else f"_{cluster_infer_ratio}"
97
+ res_path = f'./results/{clean_name}_{key}_{spk}{cluster_name}.{wav_format}'
98
+ soundfile.write(res_path, audio, svc_model.target_sample, format=wav_format)
99
+
100
+ if __name__ == '__main__':
101
+ main()
models.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ import modules.attentions as attentions
8
+ import modules.commons as commons
9
+ import modules.modules as modules
10
+
11
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+
14
+ import utils
15
+ from modules.commons import init_weights, get_padding
16
+ from vdecoder.hifigan.models import Generator
17
+ from utils import f0_to_coarse
18
+
19
+ class ResidualCouplingBlock(nn.Module):
20
+ def __init__(self,
21
+ channels,
22
+ hidden_channels,
23
+ kernel_size,
24
+ dilation_rate,
25
+ n_layers,
26
+ n_flows=4,
27
+ gin_channels=0):
28
+ super().__init__()
29
+ self.channels = channels
30
+ self.hidden_channels = hidden_channels
31
+ self.kernel_size = kernel_size
32
+ self.dilation_rate = dilation_rate
33
+ self.n_layers = n_layers
34
+ self.n_flows = n_flows
35
+ self.gin_channels = gin_channels
36
+
37
+ self.flows = nn.ModuleList()
38
+ for i in range(n_flows):
39
+ self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
40
+ self.flows.append(modules.Flip())
41
+
42
+ def forward(self, x, x_mask, g=None, reverse=False):
43
+ if not reverse:
44
+ for flow in self.flows:
45
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
46
+ else:
47
+ for flow in reversed(self.flows):
48
+ x = flow(x, x_mask, g=g, reverse=reverse)
49
+ return x
50
+
51
+
52
+ class Encoder(nn.Module):
53
+ def __init__(self,
54
+ in_channels,
55
+ out_channels,
56
+ hidden_channels,
57
+ kernel_size,
58
+ dilation_rate,
59
+ n_layers,
60
+ gin_channels=0):
61
+ super().__init__()
62
+ self.in_channels = in_channels
63
+ self.out_channels = out_channels
64
+ self.hidden_channels = hidden_channels
65
+ self.kernel_size = kernel_size
66
+ self.dilation_rate = dilation_rate
67
+ self.n_layers = n_layers
68
+ self.gin_channels = gin_channels
69
+
70
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
71
+ self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
72
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
73
+
74
+ def forward(self, x, x_lengths, g=None):
75
+ # print(x.shape,x_lengths.shape)
76
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
77
+ x = self.pre(x) * x_mask
78
+ x = self.enc(x, x_mask, g=g)
79
+ stats = self.proj(x) * x_mask
80
+ m, logs = torch.split(stats, self.out_channels, dim=1)
81
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
82
+ return z, m, logs, x_mask
83
+
84
+
85
+ class TextEncoder(nn.Module):
86
+ def __init__(self,
87
+ out_channels,
88
+ hidden_channels,
89
+ kernel_size,
90
+ n_layers,
91
+ gin_channels=0,
92
+ filter_channels=None,
93
+ n_heads=None,
94
+ p_dropout=None):
95
+ super().__init__()
96
+ self.out_channels = out_channels
97
+ self.hidden_channels = hidden_channels
98
+ self.kernel_size = kernel_size
99
+ self.n_layers = n_layers
100
+ self.gin_channels = gin_channels
101
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
102
+ self.f0_emb = nn.Embedding(256, hidden_channels)
103
+
104
+ self.enc_ = attentions.Encoder(
105
+ hidden_channels,
106
+ filter_channels,
107
+ n_heads,
108
+ n_layers,
109
+ kernel_size,
110
+ p_dropout)
111
+
112
+ def forward(self, x, x_mask, f0=None, noice_scale=1):
113
+ x = x + self.f0_emb(f0).transpose(1,2)
114
+ x = self.enc_(x * x_mask, x_mask)
115
+ stats = self.proj(x) * x_mask
116
+ m, logs = torch.split(stats, self.out_channels, dim=1)
117
+ z = (m + torch.randn_like(m) * torch.exp(logs) * noice_scale) * x_mask
118
+
119
+ return z, m, logs, x_mask
120
+
121
+
122
+
123
+ class DiscriminatorP(torch.nn.Module):
124
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
125
+ super(DiscriminatorP, self).__init__()
126
+ self.period = period
127
+ self.use_spectral_norm = use_spectral_norm
128
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
129
+ self.convs = nn.ModuleList([
130
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
131
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
132
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
133
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
134
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
135
+ ])
136
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
137
+
138
+ def forward(self, x):
139
+ fmap = []
140
+
141
+ # 1d to 2d
142
+ b, c, t = x.shape
143
+ if t % self.period != 0: # pad first
144
+ n_pad = self.period - (t % self.period)
145
+ x = F.pad(x, (0, n_pad), "reflect")
146
+ t = t + n_pad
147
+ x = x.view(b, c, t // self.period, self.period)
148
+
149
+ for l in self.convs:
150
+ x = l(x)
151
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
152
+ fmap.append(x)
153
+ x = self.conv_post(x)
154
+ fmap.append(x)
155
+ x = torch.flatten(x, 1, -1)
156
+
157
+ return x, fmap
158
+
159
+
160
+ class DiscriminatorS(torch.nn.Module):
161
+ def __init__(self, use_spectral_norm=False):
162
+ super(DiscriminatorS, self).__init__()
163
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
164
+ self.convs = nn.ModuleList([
165
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
166
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
167
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
168
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
169
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
170
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
171
+ ])
172
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
173
+
174
+ def forward(self, x):
175
+ fmap = []
176
+
177
+ for l in self.convs:
178
+ x = l(x)
179
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
180
+ fmap.append(x)
181
+ x = self.conv_post(x)
182
+ fmap.append(x)
183
+ x = torch.flatten(x, 1, -1)
184
+
185
+ return x, fmap
186
+
187
+
188
+ class MultiPeriodDiscriminator(torch.nn.Module):
189
+ def __init__(self, use_spectral_norm=False):
190
+ super(MultiPeriodDiscriminator, self).__init__()
191
+ periods = [2,3,5,7,11]
192
+
193
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
194
+ discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
195
+ self.discriminators = nn.ModuleList(discs)
196
+
197
+ def forward(self, y, y_hat):
198
+ y_d_rs = []
199
+ y_d_gs = []
200
+ fmap_rs = []
201
+ fmap_gs = []
202
+ for i, d in enumerate(self.discriminators):
203
+ y_d_r, fmap_r = d(y)
204
+ y_d_g, fmap_g = d(y_hat)
205
+ y_d_rs.append(y_d_r)
206
+ y_d_gs.append(y_d_g)
207
+ fmap_rs.append(fmap_r)
208
+ fmap_gs.append(fmap_g)
209
+
210
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
211
+
212
+
213
+ class SpeakerEncoder(torch.nn.Module):
214
+ def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256):
215
+ super(SpeakerEncoder, self).__init__()
216
+ self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
217
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
218
+ self.relu = nn.ReLU()
219
+
220
+ def forward(self, mels):
221
+ self.lstm.flatten_parameters()
222
+ _, (hidden, _) = self.lstm(mels)
223
+ embeds_raw = self.relu(self.linear(hidden[-1]))
224
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
225
+
226
+ def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
227
+ mel_slices = []
228
+ for i in range(0, total_frames-partial_frames, partial_hop):
229
+ mel_range = torch.arange(i, i+partial_frames)
230
+ mel_slices.append(mel_range)
231
+
232
+ return mel_slices
233
+
234
+ def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
235
+ mel_len = mel.size(1)
236
+ last_mel = mel[:,-partial_frames:]
237
+
238
+ if mel_len > partial_frames:
239
+ mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop)
240
+ mels = list(mel[:,s] for s in mel_slices)
241
+ mels.append(last_mel)
242
+ mels = torch.stack(tuple(mels), 0).squeeze(1)
243
+
244
+ with torch.no_grad():
245
+ partial_embeds = self(mels)
246
+ embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
247
+ #embed = embed / torch.linalg.norm(embed, 2)
248
+ else:
249
+ with torch.no_grad():
250
+ embed = self(last_mel)
251
+
252
+ return embed
253
+
254
+ class F0Decoder(nn.Module):
255
+ def __init__(self,
256
+ out_channels,
257
+ hidden_channels,
258
+ filter_channels,
259
+ n_heads,
260
+ n_layers,
261
+ kernel_size,
262
+ p_dropout,
263
+ spk_channels=0):
264
+ super().__init__()
265
+ self.out_channels = out_channels
266
+ self.hidden_channels = hidden_channels
267
+ self.filter_channels = filter_channels
268
+ self.n_heads = n_heads
269
+ self.n_layers = n_layers
270
+ self.kernel_size = kernel_size
271
+ self.p_dropout = p_dropout
272
+ self.spk_channels = spk_channels
273
+
274
+ self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 3, padding=1)
275
+ self.decoder = attentions.FFT(
276
+ hidden_channels,
277
+ filter_channels,
278
+ n_heads,
279
+ n_layers,
280
+ kernel_size,
281
+ p_dropout)
282
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
283
+ self.f0_prenet = nn.Conv1d(1, hidden_channels , 3, padding=1)
284
+ self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)
285
+
286
+ def forward(self, x, norm_f0, x_mask, spk_emb=None):
287
+ x = torch.detach(x)
288
+ if (spk_emb is not None):
289
+ x = x + self.cond(spk_emb)
290
+ x += self.f0_prenet(norm_f0)
291
+ x = self.prenet(x) * x_mask
292
+ x = self.decoder(x * x_mask, x_mask)
293
+ x = self.proj(x) * x_mask
294
+ return x
295
+
296
+
297
+ class SynthesizerTrn(nn.Module):
298
+ """
299
+ Synthesizer for Training
300
+ """
301
+
302
+ def __init__(self,
303
+ spec_channels,
304
+ segment_size,
305
+ inter_channels,
306
+ hidden_channels,
307
+ filter_channels,
308
+ n_heads,
309
+ n_layers,
310
+ kernel_size,
311
+ p_dropout,
312
+ resblock,
313
+ resblock_kernel_sizes,
314
+ resblock_dilation_sizes,
315
+ upsample_rates,
316
+ upsample_initial_channel,
317
+ upsample_kernel_sizes,
318
+ gin_channels,
319
+ ssl_dim,
320
+ n_speakers,
321
+ sampling_rate=44100,
322
+ **kwargs):
323
+
324
+ super().__init__()
325
+ self.spec_channels = spec_channels
326
+ self.inter_channels = inter_channels
327
+ self.hidden_channels = hidden_channels
328
+ self.filter_channels = filter_channels
329
+ self.n_heads = n_heads
330
+ self.n_layers = n_layers
331
+ self.kernel_size = kernel_size
332
+ self.p_dropout = p_dropout
333
+ self.resblock = resblock
334
+ self.resblock_kernel_sizes = resblock_kernel_sizes
335
+ self.resblock_dilation_sizes = resblock_dilation_sizes
336
+ self.upsample_rates = upsample_rates
337
+ self.upsample_initial_channel = upsample_initial_channel
338
+ self.upsample_kernel_sizes = upsample_kernel_sizes
339
+ self.segment_size = segment_size
340
+ self.gin_channels = gin_channels
341
+ self.ssl_dim = ssl_dim
342
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
343
+
344
+ self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2)
345
+
346
+ self.enc_p = TextEncoder(
347
+ inter_channels,
348
+ hidden_channels,
349
+ filter_channels=filter_channels,
350
+ n_heads=n_heads,
351
+ n_layers=n_layers,
352
+ kernel_size=kernel_size,
353
+ p_dropout=p_dropout
354
+ )
355
+ hps = {
356
+ "sampling_rate": sampling_rate,
357
+ "inter_channels": inter_channels,
358
+ "resblock": resblock,
359
+ "resblock_kernel_sizes": resblock_kernel_sizes,
360
+ "resblock_dilation_sizes": resblock_dilation_sizes,
361
+ "upsample_rates": upsample_rates,
362
+ "upsample_initial_channel": upsample_initial_channel,
363
+ "upsample_kernel_sizes": upsample_kernel_sizes,
364
+ "gin_channels": gin_channels,
365
+ }
366
+ self.dec = Generator(h=hps)
367
+ self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
368
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
369
+ self.f0_decoder = F0Decoder(
370
+ 1,
371
+ hidden_channels,
372
+ filter_channels,
373
+ n_heads,
374
+ n_layers,
375
+ kernel_size,
376
+ p_dropout,
377
+ spk_channels=gin_channels
378
+ )
379
+ self.emb_uv = nn.Embedding(2, hidden_channels)
380
+
381
+ def forward(self, c, f0, uv, spec, g=None, c_lengths=None, spec_lengths=None):
382
+ g = self.emb_g(g).transpose(1,2)
383
+ # ssl prenet
384
+ x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
385
+ x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1,2)
386
+
387
+ # f0 predict
388
+ lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500
389
+ norm_lf0 = utils.normalize_f0(lf0, x_mask, uv)
390
+ pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
391
+
392
+ # encoder
393
+ z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask, f0=f0_to_coarse(f0))
394
+ z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
395
+
396
+ # flow
397
+ z_p = self.flow(z, spec_mask, g=g)
398
+ z_slice, pitch_slice, ids_slice = commons.rand_slice_segments_with_pitch(z, f0, spec_lengths, self.segment_size)
399
+
400
+ # nsf decoder
401
+ o = self.dec(z_slice, g=g, f0=pitch_slice)
402
+
403
+ return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0
404
+
405
+ def infer(self, c, f0, uv, g=None, noice_scale=0.35, predict_f0=False):
406
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
407
+ g = self.emb_g(g).transpose(1,2)
408
+ x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
409
+ x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1,2)
410
+
411
+ if predict_f0:
412
+ lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500
413
+ norm_lf0 = utils.normalize_f0(lf0, x_mask, uv, random_scale=False)
414
+ pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
415
+ f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1)
416
+
417
+ z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, f0=f0_to_coarse(f0), noice_scale=noice_scale)
418
+ z = self.flow(z_p, c_mask, g=g, reverse=True)
419
+ o = self.dec(z * c_mask, g=g, f0=f0)
420
+ return o
onnx_export.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from onnxexport.model_onnx import SynthesizerTrn
3
+ import utils
4
+
5
+ def main(NetExport):
6
+ path = "SoVits4.0"
7
+ if NetExport:
8
+ device = torch.device("cpu")
9
+ hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
10
+ SVCVITS = SynthesizerTrn(
11
+ hps.data.filter_length // 2 + 1,
12
+ hps.train.segment_size // hps.data.hop_length,
13
+ **hps.model)
14
+ _ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None)
15
+ _ = SVCVITS.eval().to(device)
16
+ for i in SVCVITS.parameters():
17
+ i.requires_grad = False
18
+
19
+ test_hidden_unit = torch.rand(1, 10, 256)
20
+ test_pitch = torch.rand(1, 10)
21
+ test_mel2ph = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0)
22
+ test_uv = torch.ones(1, 10, dtype=torch.float32)
23
+ test_noise = torch.randn(1, 192, 10)
24
+ test_sid = torch.LongTensor([0])
25
+ input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"]
26
+ output_names = ["audio", ]
27
+
28
+ torch.onnx.export(SVCVITS,
29
+ (
30
+ test_hidden_unit.to(device),
31
+ test_pitch.to(device),
32
+ test_mel2ph.to(device),
33
+ test_uv.to(device),
34
+ test_noise.to(device),
35
+ test_sid.to(device)
36
+ ),
37
+ f"checkpoints/{path}/model.onnx",
38
+ dynamic_axes={
39
+ "c": [0, 1],
40
+ "f0": [1],
41
+ "mel2ph": [1],
42
+ "uv": [1],
43
+ "noise": [2],
44
+ },
45
+ do_constant_folding=False,
46
+ opset_version=16,
47
+ verbose=False,
48
+ input_names=input_names,
49
+ output_names=output_names)
50
+
51
+
52
+ if __name__ == '__main__':
53
+ main(True)
package.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ !svc infer "filename.wav" -m G_754.pth -c "config (2).json"
preprocess_flist_config.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import re
4
+
5
+ from tqdm import tqdm
6
+ from random import shuffle
7
+ import json
8
+ import wave
9
+
10
+ config_template = json.load(open("configs_template/config_template.json"))
11
+
12
+ pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$')
13
+
14
+ def get_wav_duration(file_path):
15
+ with wave.open(file_path, 'rb') as wav_file:
16
+ # 获取音频帧数
17
+ n_frames = wav_file.getnframes()
18
+ # 获取采样率
19
+ framerate = wav_file.getframerate()
20
+ # 计算时长(秒)
21
+ duration = n_frames / float(framerate)
22
+ return duration
23
+
24
+ if __name__ == "__main__":
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list")
27
+ parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list")
28
+ parser.add_argument("--test_list", type=str, default="./filelists/test.txt", help="path to test list")
29
+ parser.add_argument("--source_dir", type=str, default="./dataset/44k", help="path to source dir")
30
+ args = parser.parse_args()
31
+
32
+ train = []
33
+ val = []
34
+ test = []
35
+ idx = 0
36
+ spk_dict = {}
37
+ spk_id = 0
38
+ for speaker in tqdm(os.listdir(args.source_dir)):
39
+ spk_dict[speaker] = spk_id
40
+ spk_id += 1
41
+ wavs = ["/".join([args.source_dir, speaker, i]) for i in os.listdir(os.path.join(args.source_dir, speaker))]
42
+ new_wavs = []
43
+ for file in wavs:
44
+ if not file.endswith("wav"):
45
+ continue
46
+ if not pattern.match(file):
47
+ print(f"warning:文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)")
48
+ if get_wav_duration(file) < 0.3:
49
+ print("skip too short audio:", file)
50
+ continue
51
+ new_wavs.append(file)
52
+ wavs = new_wavs
53
+ shuffle(wavs)
54
+ train += wavs[2:-2]
55
+ val += wavs[:2]
56
+ test += wavs[-2:]
57
+
58
+ shuffle(train)
59
+ shuffle(val)
60
+ shuffle(test)
61
+
62
+ print("Writing", args.train_list)
63
+ with open(args.train_list, "w") as f:
64
+ for fname in tqdm(train):
65
+ wavpath = fname
66
+ f.write(wavpath + "\n")
67
+
68
+ print("Writing", args.val_list)
69
+ with open(args.val_list, "w") as f:
70
+ for fname in tqdm(val):
71
+ wavpath = fname
72
+ f.write(wavpath + "\n")
73
+
74
+ print("Writing", args.test_list)
75
+ with open(args.test_list, "w") as f:
76
+ for fname in tqdm(test):
77
+ wavpath = fname
78
+ f.write(wavpath + "\n")
79
+
80
+ config_template["spk"] = spk_dict
81
+ print("Writing configs/config.json")
82
+ with open("configs/config.json", "w") as f:
83
+ json.dump(config_template, f, indent=2)
preprocess_hubert_f0.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import multiprocessing
3
+ import os
4
+ import argparse
5
+ from random import shuffle
6
+
7
+ import torch
8
+ from glob import glob
9
+ from tqdm import tqdm
10
+
11
+ import utils
12
+ import logging
13
+ logging.getLogger('numba').setLevel(logging.WARNING)
14
+ import librosa
15
+ import numpy as np
16
+
17
+ hps = utils.get_hparams_from_file("configs/config.json")
18
+ sampling_rate = hps.data.sampling_rate
19
+ hop_length = hps.data.hop_length
20
+
21
+
22
+ def process_one(filename, hmodel):
23
+ # print(filename)
24
+ wav, sr = librosa.load(filename, sr=sampling_rate)
25
+ soft_path = filename + ".soft.pt"
26
+ if not os.path.exists(soft_path):
27
+ devive = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ wav16k = librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000)
29
+ wav16k = torch.from_numpy(wav16k).to(devive)
30
+ c = utils.get_hubert_content(hmodel, wav_16k_tensor=wav16k)
31
+ torch.save(c.cpu(), soft_path)
32
+ f0_path = filename + ".f0.npy"
33
+ if not os.path.exists(f0_path):
34
+ f0 = utils.compute_f0_dio(wav, sampling_rate=sampling_rate, hop_length=hop_length)
35
+ np.save(f0_path, f0)
36
+
37
+
38
+ def process_batch(filenames):
39
+ print("Loading hubert for content...")
40
+ device = "cuda" if torch.cuda.is_available() else "cpu"
41
+ hmodel = utils.get_hubert_model().to(device)
42
+ print("Loaded hubert.")
43
+ for filename in tqdm(filenames):
44
+ process_one(filename, hmodel)
45
+
46
+
47
+ if __name__ == "__main__":
48
+ parser = argparse.ArgumentParser()
49
+ parser.add_argument("--in_dir", type=str, default="dataset/44k", help="path to input dir")
50
+
51
+ args = parser.parse_args()
52
+ filenames = glob(f'{args.in_dir}/*/*.wav', recursive=True) # [:10]
53
+ shuffle(filenames)
54
+ multiprocessing.set_start_method('spawn',force=True)
55
+
56
+ num_processes = 1
57
+ chunk_size = int(math.ceil(len(filenames) / num_processes))
58
+ chunks = [filenames[i:i + chunk_size] for i in range(0, len(filenames), chunk_size)]
59
+ print([len(c) for c in chunks])
60
+ processes = [multiprocessing.Process(target=process_batch, args=(chunk,)) for chunk in chunks]
61
+ for p in processes:
62
+ p.start()
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Flask
2
+ Flask_Cors
3
+ gradio
4
+ numpy
5
+ pyworld==0.2.5
6
+ scipy==1.7.3
7
+ SoundFile==0.12.1
8
+ torch==1.13.1
9
+ torchaudio==0.13.1
10
+ tqdm
11
+ scikit-maad
12
+ praat-parselmouth
13
+ onnx
14
+ onnxsim
15
+ onnxoptimizer
16
+ fairseq==0.12.2
17
+ librosa==0.8.1
18
+ tensorboard
requirements_win.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ librosa==0.9.2
2
+ fairseq==0.12.2
3
+ Flask==2.1.2
4
+ Flask_Cors==3.0.10
5
+ gradio==3.4.1
6
+ numpy==1.20.0
7
+ playsound==1.3.0
8
+ PyAudio==0.2.12
9
+ pydub==0.25.1
10
+ pyworld==0.3.0
11
+ requests==2.28.1
12
+ scipy==1.7.3
13
+ sounddevice==0.4.5
14
+ SoundFile==0.10.3.post1
15
+ starlette==0.19.1
16
+ tqdm==4.63.0
17
+ scikit-maad
18
+ praat-parselmouth
19
+ onnx
20
+ onnxsim
21
+ onnxoptimizer
resample.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import librosa
4
+ import numpy as np
5
+ from multiprocessing import Pool, cpu_count
6
+ from scipy.io import wavfile
7
+ from tqdm import tqdm
8
+
9
+
10
+ def process(item):
11
+ spkdir, wav_name, args = item
12
+ # speaker 's5', 'p280', 'p315' are excluded,
13
+ speaker = spkdir.replace("\\", "/").split("/")[-1]
14
+ wav_path = os.path.join(args.in_dir, speaker, wav_name)
15
+ if os.path.exists(wav_path) and '.wav' in wav_path:
16
+ os.makedirs(os.path.join(args.out_dir2, speaker), exist_ok=True)
17
+ wav, sr = librosa.load(wav_path, sr=None)
18
+ wav, _ = librosa.effects.trim(wav, top_db=20)
19
+ peak = np.abs(wav).max()
20
+ if peak > 1.0:
21
+ wav = 0.98 * wav / peak
22
+ wav2 = librosa.resample(wav, orig_sr=sr, target_sr=args.sr2)
23
+ wav2 /= max(wav2.max(), -wav2.min())
24
+ save_name = wav_name
25
+ save_path2 = os.path.join(args.out_dir2, speaker, save_name)
26
+ wavfile.write(
27
+ save_path2,
28
+ args.sr2,
29
+ (wav2 * np.iinfo(np.int16).max).astype(np.int16)
30
+ )
31
+
32
+
33
+
34
+ if __name__ == "__main__":
35
+ parser = argparse.ArgumentParser()
36
+ parser.add_argument("--sr2", type=int, default=44100, help="sampling rate")
37
+ parser.add_argument("--in_dir", type=str, default="./dataset_raw", help="path to source dir")
38
+ parser.add_argument("--out_dir2", type=str, default="./dataset/44k", help="path to target dir")
39
+ args = parser.parse_args()
40
+ processs = cpu_count()-2 if cpu_count() >4 else 1
41
+ pool = Pool(processes=processs)
42
+
43
+ for speaker in os.listdir(args.in_dir):
44
+ spk_dir = os.path.join(args.in_dir, speaker)
45
+ if os.path.isdir(spk_dir):
46
+ print(spk_dir)
47
+ for _ in tqdm(pool.imap_unordered(process, [(spk_dir, i, args) for i in os.listdir(spk_dir) if i.endswith("wav")])):
48
+ pass
spec_gen.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data_utils import TextAudioSpeakerLoader
2
+ import json
3
+ from tqdm import tqdm
4
+
5
+ from utils import HParams
6
+
7
+ config_path = 'configs/config.json'
8
+ with open(config_path, "r") as f:
9
+ data = f.read()
10
+ config = json.loads(data)
11
+ hps = HParams(**config)
12
+
13
+ train_dataset = TextAudioSpeakerLoader("filelists/train.txt", hps)
14
+ test_dataset = TextAudioSpeakerLoader("filelists/test.txt", hps)
15
+ eval_dataset = TextAudioSpeakerLoader("filelists/val.txt", hps)
16
+
17
+ for _ in tqdm(train_dataset):
18
+ pass
19
+ for _ in tqdm(eval_dataset):
20
+ pass
21
+ for _ in tqdm(test_dataset):
22
+ pass
train.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import multiprocessing
3
+ import time
4
+
5
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
6
+ import os
7
+ import json
8
+ import argparse
9
+ import itertools
10
+ import math
11
+ import torch
12
+ from torch import nn, optim
13
+ from torch.nn import functional as F
14
+ from torch.utils.data import DataLoader
15
+ from torch.utils.tensorboard import SummaryWriter
16
+ import torch.multiprocessing as mp
17
+ import torch.distributed as dist
18
+ from torch.nn.parallel import DistributedDataParallel as DDP
19
+ from torch.cuda.amp import autocast, GradScaler
20
+
21
+ import modules.commons as commons
22
+ import utils
23
+ from data_utils import TextAudioSpeakerLoader, TextAudioCollate
24
+ from models import (
25
+ SynthesizerTrn,
26
+ MultiPeriodDiscriminator,
27
+ )
28
+ from modules.losses import (
29
+ kl_loss,
30
+ generator_loss, discriminator_loss, feature_loss
31
+ )
32
+
33
+ from modules.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
34
+
35
+ torch.backends.cudnn.benchmark = True
36
+ global_step = 0
37
+ start_time = time.time()
38
+
39
+ # os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO'
40
+
41
+
42
+ def main():
43
+ """Assume Single Node Multi GPUs Training Only"""
44
+ assert torch.cuda.is_available(), "CPU training is not allowed."
45
+ hps = utils.get_hparams()
46
+
47
+ n_gpus = torch.cuda.device_count()
48
+ os.environ['MASTER_ADDR'] = 'localhost'
49
+ os.environ['MASTER_PORT'] = hps.train.port
50
+
51
+ mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))
52
+
53
+
54
+ def run(rank, n_gpus, hps):
55
+ global global_step
56
+ if rank == 0:
57
+ logger = utils.get_logger(hps.model_dir)
58
+ logger.info(hps)
59
+ utils.check_git_hash(hps.model_dir)
60
+ writer = SummaryWriter(log_dir=hps.model_dir)
61
+ writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
62
+
63
+ # for pytorch on win, backend use gloo
64
+ dist.init_process_group(backend= 'gloo' if os.name == 'nt' else 'nccl', init_method='env://', world_size=n_gpus, rank=rank)
65
+ torch.manual_seed(hps.train.seed)
66
+ torch.cuda.set_device(rank)
67
+ collate_fn = TextAudioCollate()
68
+ train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps)
69
+ num_workers = 5 if multiprocessing.cpu_count() > 4 else multiprocessing.cpu_count()
70
+ train_loader = DataLoader(train_dataset, num_workers=num_workers, shuffle=False, pin_memory=True,
71
+ batch_size=hps.train.batch_size, collate_fn=collate_fn)
72
+ if rank == 0:
73
+ eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps)
74
+ eval_loader = DataLoader(eval_dataset, num_workers=1, shuffle=False,
75
+ batch_size=1, pin_memory=False,
76
+ drop_last=False, collate_fn=collate_fn)
77
+
78
+ net_g = SynthesizerTrn(
79
+ hps.data.filter_length // 2 + 1,
80
+ hps.train.segment_size // hps.data.hop_length,
81
+ **hps.model).cuda(rank)
82
+ net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
83
+ optim_g = torch.optim.AdamW(
84
+ net_g.parameters(),
85
+ hps.train.learning_rate,
86
+ betas=hps.train.betas,
87
+ eps=hps.train.eps)
88
+ optim_d = torch.optim.AdamW(
89
+ net_d.parameters(),
90
+ hps.train.learning_rate,
91
+ betas=hps.train.betas,
92
+ eps=hps.train.eps)
93
+ net_g = DDP(net_g, device_ids=[rank]) # , find_unused_parameters=True)
94
+ net_d = DDP(net_d, device_ids=[rank])
95
+
96
+ skip_optimizer = False
97
+ try:
98
+ _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
99
+ optim_g, skip_optimizer)
100
+ _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d,
101
+ optim_d, skip_optimizer)
102
+ epoch_str = max(epoch_str, 1)
103
+ global_step = (epoch_str - 1) * len(train_loader)
104
+ except:
105
+ print("load old checkpoint failed...")
106
+ epoch_str = 1
107
+ global_step = 0
108
+ if skip_optimizer:
109
+ epoch_str = 1
110
+ global_step = 0
111
+
112
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
113
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
114
+
115
+ scaler = GradScaler(enabled=hps.train.fp16_run)
116
+
117
+ for epoch in range(epoch_str, hps.train.epochs + 1):
118
+ if rank == 0:
119
+ train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
120
+ [train_loader, eval_loader], logger, [writer, writer_eval])
121
+ else:
122
+ train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
123
+ [train_loader, None], None, None)
124
+ scheduler_g.step()
125
+ scheduler_d.step()
126
+
127
+
128
+ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
129
+ net_g, net_d = nets
130
+ optim_g, optim_d = optims
131
+ scheduler_g, scheduler_d = schedulers
132
+ train_loader, eval_loader = loaders
133
+ if writers is not None:
134
+ writer, writer_eval = writers
135
+
136
+ # train_loader.batch_sampler.set_epoch(epoch)
137
+ global global_step
138
+
139
+ net_g.train()
140
+ net_d.train()
141
+ for batch_idx, items in enumerate(train_loader):
142
+ c, f0, spec, y, spk, lengths, uv = items
143
+ g = spk.cuda(rank, non_blocking=True)
144
+ spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True)
145
+ c = c.cuda(rank, non_blocking=True)
146
+ f0 = f0.cuda(rank, non_blocking=True)
147
+ uv = uv.cuda(rank, non_blocking=True)
148
+ lengths = lengths.cuda(rank, non_blocking=True)
149
+ mel = spec_to_mel_torch(
150
+ spec,
151
+ hps.data.filter_length,
152
+ hps.data.n_mel_channels,
153
+ hps.data.sampling_rate,
154
+ hps.data.mel_fmin,
155
+ hps.data.mel_fmax)
156
+
157
+ with autocast(enabled=hps.train.fp16_run):
158
+ y_hat, ids_slice, z_mask, \
159
+ (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0 = net_g(c, f0, uv, spec, g=g, c_lengths=lengths,
160
+ spec_lengths=lengths)
161
+
162
+ y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
163
+ y_hat_mel = mel_spectrogram_torch(
164
+ y_hat.squeeze(1),
165
+ hps.data.filter_length,
166
+ hps.data.n_mel_channels,
167
+ hps.data.sampling_rate,
168
+ hps.data.hop_length,
169
+ hps.data.win_length,
170
+ hps.data.mel_fmin,
171
+ hps.data.mel_fmax
172
+ )
173
+ y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
174
+
175
+ # Discriminator
176
+ y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
177
+
178
+ with autocast(enabled=False):
179
+ loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
180
+ loss_disc_all = loss_disc
181
+
182
+ optim_d.zero_grad()
183
+ scaler.scale(loss_disc_all).backward()
184
+ scaler.unscale_(optim_d)
185
+ grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
186
+ scaler.step(optim_d)
187
+
188
+ with autocast(enabled=hps.train.fp16_run):
189
+ # Generator
190
+ y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
191
+ with autocast(enabled=False):
192
+ loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
193
+ loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
194
+ loss_fm = feature_loss(fmap_r, fmap_g)
195
+ loss_gen, losses_gen = generator_loss(y_d_hat_g)
196
+ loss_lf0 = F.mse_loss(pred_lf0, lf0)
197
+ loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + loss_lf0
198
+ optim_g.zero_grad()
199
+ scaler.scale(loss_gen_all).backward()
200
+ scaler.unscale_(optim_g)
201
+ grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
202
+ scaler.step(optim_g)
203
+ scaler.update()
204
+
205
+ if rank == 0:
206
+ if global_step % hps.train.log_interval == 0:
207
+ lr = optim_g.param_groups[0]['lr']
208
+ losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
209
+ logger.info('Train Epoch: {} [{:.0f}%]'.format(
210
+ epoch,
211
+ 100. * batch_idx / len(train_loader)))
212
+ logger.info(f"Losses: {[x.item() for x in losses]}, step: {global_step}, lr: {lr}")
213
+
214
+ scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr,
215
+ "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
216
+ scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl,
217
+ "loss/g/lf0": loss_lf0})
218
+
219
+ # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
220
+ # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
221
+ # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
222
+ image_dict = {
223
+ "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
224
+ "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
225
+ "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
226
+ "all/lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(),
227
+ pred_lf0[0, 0, :].detach().cpu().numpy()),
228
+ "all/norm_lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(),
229
+ norm_lf0[0, 0, :].detach().cpu().numpy())
230
+ }
231
+
232
+ utils.summarize(
233
+ writer=writer,
234
+ global_step=global_step,
235
+ images=image_dict,
236
+ scalars=scalar_dict
237
+ )
238
+
239
+ if global_step % hps.train.eval_interval == 0:
240
+ evaluate(hps, net_g, eval_loader, writer_eval)
241
+ utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch,
242
+ os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
243
+ utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch,
244
+ os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
245
+ keep_ckpts = getattr(hps.train, 'keep_ckpts', 0)
246
+ if keep_ckpts > 0:
247
+ utils.clean_checkpoints(path_to_models=hps.model_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
248
+
249
+ global_step += 1
250
+
251
+ if rank == 0:
252
+ global start_time
253
+ now = time.time()
254
+ durtaion = format(now - start_time, '.2f')
255
+ logger.info(f'====> Epoch: {epoch}, cost {durtaion} s')
256
+ start_time = now
257
+
258
+
259
+ def evaluate(hps, generator, eval_loader, writer_eval):
260
+ generator.eval()
261
+ image_dict = {}
262
+ audio_dict = {}
263
+ with torch.no_grad():
264
+ for batch_idx, items in enumerate(eval_loader):
265
+ c, f0, spec, y, spk, _, uv = items
266
+ g = spk[:1].cuda(0)
267
+ spec, y = spec[:1].cuda(0), y[:1].cuda(0)
268
+ c = c[:1].cuda(0)
269
+ f0 = f0[:1].cuda(0)
270
+ uv= uv[:1].cuda(0)
271
+ mel = spec_to_mel_torch(
272
+ spec,
273
+ hps.data.filter_length,
274
+ hps.data.n_mel_channels,
275
+ hps.data.sampling_rate,
276
+ hps.data.mel_fmin,
277
+ hps.data.mel_fmax)
278
+ y_hat = generator.module.infer(c, f0, uv, g=g)
279
+
280
+ y_hat_mel = mel_spectrogram_torch(
281
+ y_hat.squeeze(1).float(),
282
+ hps.data.filter_length,
283
+ hps.data.n_mel_channels,
284
+ hps.data.sampling_rate,
285
+ hps.data.hop_length,
286
+ hps.data.win_length,
287
+ hps.data.mel_fmin,
288
+ hps.data.mel_fmax
289
+ )
290
+
291
+ audio_dict.update({
292
+ f"gen/audio_{batch_idx}": y_hat[0],
293
+ f"gt/audio_{batch_idx}": y[0]
294
+ })
295
+ image_dict.update({
296
+ f"gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()),
297
+ "gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())
298
+ })
299
+ utils.summarize(
300
+ writer=writer_eval,
301
+ global_step=global_step,
302
+ images=image_dict,
303
+ audios=audio_dict,
304
+ audio_sampling_rate=hps.data.sampling_rate
305
+ )
306
+ generator.train()
307
+
308
+
309
+ if __name__ == "__main__":
310
+ main()
utils.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import re
4
+ import sys
5
+ import argparse
6
+ import logging
7
+ import json
8
+ import subprocess
9
+ import random
10
+
11
+ import librosa
12
+ import numpy as np
13
+ from scipy.io.wavfile import read
14
+ import torch
15
+ from torch.nn import functional as F
16
+ from modules.commons import sequence_mask
17
+ from hubert import hubert_model
18
+ MATPLOTLIB_FLAG = False
19
+
20
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
21
+ logger = logging
22
+
23
+ f0_bin = 256
24
+ f0_max = 1100.0
25
+ f0_min = 50.0
26
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
27
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
28
+
29
+
30
+ # def normalize_f0(f0, random_scale=True):
31
+ # f0_norm = f0.clone() # create a copy of the input Tensor
32
+ # batch_size, _, frame_length = f0_norm.shape
33
+ # for i in range(batch_size):
34
+ # means = torch.mean(f0_norm[i, 0, :])
35
+ # if random_scale:
36
+ # factor = random.uniform(0.8, 1.2)
37
+ # else:
38
+ # factor = 1
39
+ # f0_norm[i, 0, :] = (f0_norm[i, 0, :] - means) * factor
40
+ # return f0_norm
41
+ # def normalize_f0(f0, random_scale=True):
42
+ # means = torch.mean(f0[:, 0, :], dim=1, keepdim=True)
43
+ # if random_scale:
44
+ # factor = torch.Tensor(f0.shape[0],1).uniform_(0.8, 1.2).to(f0.device)
45
+ # else:
46
+ # factor = torch.ones(f0.shape[0], 1, 1).to(f0.device)
47
+ # f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
48
+ # return f0_norm
49
+ def normalize_f0(f0, x_mask, uv, random_scale=True):
50
+ # calculate means based on x_mask
51
+ uv_sum = torch.sum(uv, dim=1, keepdim=True)
52
+ uv_sum[uv_sum == 0] = 9999
53
+ means = torch.sum(f0[:, 0, :] * uv, dim=1, keepdim=True) / uv_sum
54
+
55
+ if random_scale:
56
+ factor = torch.Tensor(f0.shape[0], 1).uniform_(0.8, 1.2).to(f0.device)
57
+ else:
58
+ factor = torch.ones(f0.shape[0], 1).to(f0.device)
59
+ # normalize f0 based on means and factor
60
+ f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
61
+ if torch.isnan(f0_norm).any():
62
+ exit(0)
63
+ return f0_norm * x_mask
64
+
65
+
66
+ def plot_data_to_numpy(x, y):
67
+ global MATPLOTLIB_FLAG
68
+ if not MATPLOTLIB_FLAG:
69
+ import matplotlib
70
+ matplotlib.use("Agg")
71
+ MATPLOTLIB_FLAG = True
72
+ mpl_logger = logging.getLogger('matplotlib')
73
+ mpl_logger.setLevel(logging.WARNING)
74
+ import matplotlib.pylab as plt
75
+ import numpy as np
76
+
77
+ fig, ax = plt.subplots(figsize=(10, 2))
78
+ plt.plot(x)
79
+ plt.plot(y)
80
+ plt.tight_layout()
81
+
82
+ fig.canvas.draw()
83
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
84
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
85
+ plt.close()
86
+ return data
87
+
88
+
89
+
90
+ def interpolate_f0(f0):
91
+ '''
92
+ 对F0进行插值处理
93
+ '''
94
+
95
+ data = np.reshape(f0, (f0.size, 1))
96
+
97
+ vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
98
+ vuv_vector[data > 0.0] = 1.0
99
+ vuv_vector[data <= 0.0] = 0.0
100
+
101
+ ip_data = data
102
+
103
+ frame_number = data.size
104
+ last_value = 0.0
105
+ for i in range(frame_number):
106
+ if data[i] <= 0.0:
107
+ j = i + 1
108
+ for j in range(i + 1, frame_number):
109
+ if data[j] > 0.0:
110
+ break
111
+ if j < frame_number - 1:
112
+ if last_value > 0.0:
113
+ step = (data[j] - data[i - 1]) / float(j - i)
114
+ for k in range(i, j):
115
+ ip_data[k] = data[i - 1] + step * (k - i + 1)
116
+ else:
117
+ for k in range(i, j):
118
+ ip_data[k] = data[j]
119
+ else:
120
+ for k in range(i, frame_number):
121
+ ip_data[k] = last_value
122
+ else:
123
+ ip_data[i] = data[i]
124
+ last_value = data[i]
125
+
126
+ return ip_data[:,0], vuv_vector[:,0]
127
+
128
+
129
+ def compute_f0_parselmouth(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
130
+ import parselmouth
131
+ x = wav_numpy
132
+ if p_len is None:
133
+ p_len = x.shape[0]//hop_length
134
+ else:
135
+ assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error"
136
+ time_step = hop_length / sampling_rate * 1000
137
+ f0_min = 50
138
+ f0_max = 1100
139
+ f0 = parselmouth.Sound(x, sampling_rate).to_pitch_ac(
140
+ time_step=time_step / 1000, voicing_threshold=0.6,
141
+ pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
142
+
143
+ pad_size=(p_len - len(f0) + 1) // 2
144
+ if(pad_size>0 or p_len - len(f0) - pad_size>0):
145
+ f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
146
+ return f0
147
+
148
+ def resize_f0(x, target_len):
149
+ source = np.array(x)
150
+ source[source<0.001] = np.nan
151
+ target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source)
152
+ res = np.nan_to_num(target)
153
+ return res
154
+
155
+ def compute_f0_dio(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
156
+ import pyworld
157
+ if p_len is None:
158
+ p_len = wav_numpy.shape[0]//hop_length
159
+ f0, t = pyworld.dio(
160
+ wav_numpy.astype(np.double),
161
+ fs=sampling_rate,
162
+ f0_ceil=800,
163
+ frame_period=1000 * hop_length / sampling_rate,
164
+ )
165
+ f0 = pyworld.stonemask(wav_numpy.astype(np.double), f0, t, sampling_rate)
166
+ for index, pitch in enumerate(f0):
167
+ f0[index] = round(pitch, 1)
168
+ return resize_f0(f0, p_len)
169
+
170
+ def f0_to_coarse(f0):
171
+ is_torch = isinstance(f0, torch.Tensor)
172
+ f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
173
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
174
+
175
+ f0_mel[f0_mel <= 1] = 1
176
+ f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
177
+ f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int)
178
+ assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
179
+ return f0_coarse
180
+
181
+
182
+ def get_hubert_model():
183
+ vec_path = "hubert/checkpoint_best_legacy_500.pt"
184
+ print("load model(s) from {}".format(vec_path))
185
+ from fairseq import checkpoint_utils
186
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
187
+ [vec_path],
188
+ suffix="",
189
+ )
190
+ model = models[0]
191
+ model.eval()
192
+ return model
193
+
194
+ def get_hubert_content(hmodel, wav_16k_tensor):
195
+ feats = wav_16k_tensor
196
+ if feats.dim() == 2: # double channels
197
+ feats = feats.mean(-1)
198
+ assert feats.dim() == 1, feats.dim()
199
+ feats = feats.view(1, -1)
200
+ padding_mask = torch.BoolTensor(feats.shape).fill_(False)
201
+ inputs = {
202
+ "source": feats.to(wav_16k_tensor.device),
203
+ "padding_mask": padding_mask.to(wav_16k_tensor.device),
204
+ "output_layer": 9, # layer 9
205
+ }
206
+ with torch.no_grad():
207
+ logits = hmodel.extract_features(**inputs)
208
+ feats = hmodel.final_proj(logits[0])
209
+ return feats.transpose(1, 2)
210
+
211
+
212
+ def get_content(cmodel, y):
213
+ with torch.no_grad():
214
+ c = cmodel.extract_features(y.squeeze(1))[0]
215
+ c = c.transpose(1, 2)
216
+ return c
217
+
218
+
219
+
220
+ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
221
+ assert os.path.isfile(checkpoint_path)
222
+ checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
223
+ iteration = checkpoint_dict['iteration']
224
+ learning_rate = checkpoint_dict['learning_rate']
225
+ if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
226
+ optimizer.load_state_dict(checkpoint_dict['optimizer'])
227
+ saved_state_dict = checkpoint_dict['model']
228
+ if hasattr(model, 'module'):
229
+ state_dict = model.module.state_dict()
230
+ else:
231
+ state_dict = model.state_dict()
232
+ new_state_dict = {}
233
+ for k, v in state_dict.items():
234
+ try:
235
+ # assert "dec" in k or "disc" in k
236
+ # print("load", k)
237
+ new_state_dict[k] = saved_state_dict[k]
238
+ assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
239
+ except:
240
+ print("error, %s is not in the checkpoint" % k)
241
+ logger.info("%s is not in the checkpoint" % k)
242
+ new_state_dict[k] = v
243
+ if hasattr(model, 'module'):
244
+ model.module.load_state_dict(new_state_dict)
245
+ else:
246
+ model.load_state_dict(new_state_dict)
247
+ print("load ")
248
+ logger.info("Loaded checkpoint '{}' (iteration {})".format(
249
+ checkpoint_path, iteration))
250
+ return model, optimizer, learning_rate, iteration
251
+
252
+
253
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
254
+ logger.info("Saving model and optimizer state at iteration {} to {}".format(
255
+ iteration, checkpoint_path))
256
+ if hasattr(model, 'module'):
257
+ state_dict = model.module.state_dict()
258
+ else:
259
+ state_dict = model.state_dict()
260
+ torch.save({'model': state_dict,
261
+ 'iteration': iteration,
262
+ 'optimizer': optimizer.state_dict(),
263
+ 'learning_rate': learning_rate}, checkpoint_path)
264
+
265
+ def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
266
+ """Freeing up space by deleting saved ckpts
267
+
268
+ Arguments:
269
+ path_to_models -- Path to the model directory
270
+ n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
271
+ sort_by_time -- True -> chronologically delete ckpts
272
+ False -> lexicographically delete ckpts
273
+ """
274
+ ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
275
+ name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1)))
276
+ time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
277
+ sort_key = time_key if sort_by_time else name_key
278
+ x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], key=sort_key)
279
+ to_del = [os.path.join(path_to_models, fn) for fn in
280
+ (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
281
+ del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
282
+ del_routine = lambda x: [os.remove(x), del_info(x)]
283
+ rs = [del_routine(fn) for fn in to_del]
284
+
285
+ def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
286
+ for k, v in scalars.items():
287
+ writer.add_scalar(k, v, global_step)
288
+ for k, v in histograms.items():
289
+ writer.add_histogram(k, v, global_step)
290
+ for k, v in images.items():
291
+ writer.add_image(k, v, global_step, dataformats='HWC')
292
+ for k, v in audios.items():
293
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
294
+
295
+
296
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
297
+ f_list = glob.glob(os.path.join(dir_path, regex))
298
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
299
+ x = f_list[-1]
300
+ print(x)
301
+ return x
302
+
303
+
304
+ def plot_spectrogram_to_numpy(spectrogram):
305
+ global MATPLOTLIB_FLAG
306
+ if not MATPLOTLIB_FLAG:
307
+ import matplotlib
308
+ matplotlib.use("Agg")
309
+ MATPLOTLIB_FLAG = True
310
+ mpl_logger = logging.getLogger('matplotlib')
311
+ mpl_logger.setLevel(logging.WARNING)
312
+ import matplotlib.pylab as plt
313
+ import numpy as np
314
+
315
+ fig, ax = plt.subplots(figsize=(10,2))
316
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
317
+ interpolation='none')
318
+ plt.colorbar(im, ax=ax)
319
+ plt.xlabel("Frames")
320
+ plt.ylabel("Channels")
321
+ plt.tight_layout()
322
+
323
+ fig.canvas.draw()
324
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
325
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
326
+ plt.close()
327
+ return data
328
+
329
+
330
+ def plot_alignment_to_numpy(alignment, info=None):
331
+ global MATPLOTLIB_FLAG
332
+ if not MATPLOTLIB_FLAG:
333
+ import matplotlib
334
+ matplotlib.use("Agg")
335
+ MATPLOTLIB_FLAG = True
336
+ mpl_logger = logging.getLogger('matplotlib')
337
+ mpl_logger.setLevel(logging.WARNING)
338
+ import matplotlib.pylab as plt
339
+ import numpy as np
340
+
341
+ fig, ax = plt.subplots(figsize=(6, 4))
342
+ im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
343
+ interpolation='none')
344
+ fig.colorbar(im, ax=ax)
345
+ xlabel = 'Decoder timestep'
346
+ if info is not None:
347
+ xlabel += '\n\n' + info
348
+ plt.xlabel(xlabel)
349
+ plt.ylabel('Encoder timestep')
350
+ plt.tight_layout()
351
+
352
+ fig.canvas.draw()
353
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
354
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
355
+ plt.close()
356
+ return data
357
+
358
+
359
+ def load_wav_to_torch(full_path):
360
+ sampling_rate, data = read(full_path)
361
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
362
+
363
+
364
+ def load_filepaths_and_text(filename, split="|"):
365
+ with open(filename, encoding='utf-8') as f:
366
+ filepaths_and_text = [line.strip().split(split) for line in f]
367
+ return filepaths_and_text
368
+
369
+
370
+ def get_hparams(init=True):
371
+ parser = argparse.ArgumentParser()
372
+ parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
373
+ help='JSON file for configuration')
374
+ parser.add_argument('-m', '--model', type=str, required=True,
375
+ help='Model name')
376
+
377
+ args = parser.parse_args()
378
+ model_dir = os.path.join("./logs", args.model)
379
+
380
+ if not os.path.exists(model_dir):
381
+ os.makedirs(model_dir)
382
+
383
+ config_path = args.config
384
+ config_save_path = os.path.join(model_dir, "config.json")
385
+ if init:
386
+ with open(config_path, "r") as f:
387
+ data = f.read()
388
+ with open(config_save_path, "w") as f:
389
+ f.write(data)
390
+ else:
391
+ with open(config_save_path, "r") as f:
392
+ data = f.read()
393
+ config = json.loads(data)
394
+
395
+ hparams = HParams(**config)
396
+ hparams.model_dir = model_dir
397
+ return hparams
398
+
399
+
400
+ def get_hparams_from_dir(model_dir):
401
+ config_save_path = os.path.join(model_dir, "config.json")
402
+ with open(config_save_path, "r") as f:
403
+ data = f.read()
404
+ config = json.loads(data)
405
+
406
+ hparams =HParams(**config)
407
+ hparams.model_dir = model_dir
408
+ return hparams
409
+
410
+
411
+ def get_hparams_from_file(config_path):
412
+ with open(config_path, "r") as f:
413
+ data = f.read()
414
+ config = json.loads(data)
415
+
416
+ hparams =HParams(**config)
417
+ return hparams
418
+
419
+
420
+ def check_git_hash(model_dir):
421
+ source_dir = os.path.dirname(os.path.realpath(__file__))
422
+ if not os.path.exists(os.path.join(source_dir, ".git")):
423
+ logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
424
+ source_dir
425
+ ))
426
+ return
427
+
428
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
429
+
430
+ path = os.path.join(model_dir, "githash")
431
+ if os.path.exists(path):
432
+ saved_hash = open(path).read()
433
+ if saved_hash != cur_hash:
434
+ logger.warn("git hash values are different. {}(saved) != {}(current)".format(
435
+ saved_hash[:8], cur_hash[:8]))
436
+ else:
437
+ open(path, "w").write(cur_hash)
438
+
439
+
440
+ def get_logger(model_dir, filename="train.log"):
441
+ global logger
442
+ logger = logging.getLogger(os.path.basename(model_dir))
443
+ logger.setLevel(logging.DEBUG)
444
+
445
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
446
+ if not os.path.exists(model_dir):
447
+ os.makedirs(model_dir)
448
+ h = logging.FileHandler(os.path.join(model_dir, filename))
449
+ h.setLevel(logging.DEBUG)
450
+ h.setFormatter(formatter)
451
+ logger.addHandler(h)
452
+ return logger
453
+
454
+
455
+ def repeat_expand_2d(content, target_len):
456
+ # content : [h, t]
457
+
458
+ src_len = content.shape[-1]
459
+ target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device)
460
+ temp = torch.arange(src_len+1) * target_len / src_len
461
+ current_pos = 0
462
+ for i in range(target_len):
463
+ if i < temp[current_pos+1]:
464
+ target[:, i] = content[:, current_pos]
465
+ else:
466
+ current_pos += 1
467
+ target[:, i] = content[:, current_pos]
468
+
469
+ return target
470
+
471
+
472
+ class HParams():
473
+ def __init__(self, **kwargs):
474
+ for k, v in kwargs.items():
475
+ if type(v) == dict:
476
+ v = HParams(**v)
477
+ self[k] = v
478
+
479
+ def keys(self):
480
+ return self.__dict__.keys()
481
+
482
+ def items(self):
483
+ return self.__dict__.items()
484
+
485
+ def values(self):
486
+ return self.__dict__.values()
487
+
488
+ def __len__(self):
489
+ return len(self.__dict__)
490
+
491
+ def __getitem__(self, key):
492
+ return getattr(self, key)
493
+
494
+ def __setitem__(self, key, value):
495
+ return setattr(self, key, value)
496
+
497
+ def __contains__(self, key):
498
+ return key in self.__dict__
499
+
500
+ def __repr__(self):
501
+ return self.__dict__.__repr__()
502
+
wav_upload.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from google.colab import files
2
+ import shutil
3
+ import os
4
+ import argparse
5
+ if __name__ == "__main__":
6
+ parser = argparse.ArgumentParser()
7
+ parser.add_argument("--type", type=str, required=True, help="type of file to upload")
8
+ args = parser.parse_args()
9
+ file_type = args.type
10
+
11
+ basepath = os.getcwd()
12
+ uploaded = files.upload() # 上传文件
13
+ assert(file_type in ['zip', 'audio'])
14
+ if file_type == "zip":
15
+ upload_path = "./upload/"
16
+ for filename in uploaded.keys():
17
+ #将上传的文件移动到指定的位置上
18
+ shutil.move(os.path.join(basepath, filename), os.path.join(upload_path, "userzip.zip"))
19
+ elif file_type == "audio":
20
+ upload_path = "./raw/"
21
+ for filename in uploaded.keys():
22
+ #将上传的文件移动到指定的位置上
23
+ shutil.move(os.path.join(basepath, filename), os.path.join(upload_path, filename))