jack.li committed
Commit 6340f23 · 1 Parent(s): dc02c06
Dockerfile CHANGED
@@ -25,7 +25,7 @@ ENV HOME=/home/user \
 # Set the working directory to the user's home directory
 WORKDIR $HOME/app
 
-RUN pip install tqdm
+RUN pip install tqdm nltk
 
 COPY ./download.py .
 
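Installing nltk with pip only provides the library itself; the tagger and pronunciation data it is presumably being added for still have to be fetched separately, which is what the new download_nltk_data() helper in download.py (later in this diff) does. As a sketch rather than code from this commit, an idempotent variant could check for the standard NLTK resource paths before downloading:

import nltk

def ensure_nltk_data():
    # Fetch the two resources named in this commit only if they are missing.
    for pkg, path in [
        ("averaged_perceptron_tagger", "taggers/averaged_perceptron_tagger"),
        ("cmudict", "corpora/cmudict"),
    ]:
        try:
            nltk.data.find(path)            # already present in an NLTK data directory
        except LookupError:
            nltk.download(pkg, quiet=True)  # download into the default data directory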
__pycache__/main.cpython-310.pyc ADDED
Binary file (1.25 kB).
 
__pycache__/model.cpython-310.pyc ADDED
Binary file (21 kB).
 
__pycache__/my_utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/my_utils.cpython-310.pyc and b/__pycache__/my_utils.cpython-310.pyc differ
 
app.py CHANGED
@@ -27,8 +27,8 @@ logging.getLogger("asyncio").setLevel(logging.ERROR)
 logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
 logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
 logging.getLogger("multipart").setLevel(logging.WARNING)
-from download import *
-download()
+# from download import *
+# download()
 
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
download.py CHANGED
@@ -43,5 +43,12 @@ def download():
 
     os.remove(output)
 
+
+def download_nltk_data():
+    import nltk
+    nltk.download('averaged_perceptron_tagger')
+    nltk.download('cmudict')
+
+
 if __name__ == '__main__':
     download()
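download_nltk_data() is defined here but nothing else in this commit appears to call it (the download() call in app.py is commented out in this same change). One way to wire it in, shown only as a sketch and not as part of the commit, is to run it next to download() in this module's own entry point, or from a Docker build step that executes this file:

if __name__ == '__main__':
    download()            # existing asset download
    download_nltk_data()  # also cache averaged_perceptron_tagger and cmudict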
main.py ADDED
@@ -0,0 +1,46 @@
+from fastapi import FastAPI, Body, File, Form, UploadFile
+from fastapi.responses import FileResponse
+from fastapi.staticfiles import StaticFiles
+from model import clone_voice
+import os
+from enum import Enum
+import uvicorn
+app = FastAPI()
+
+app.mount("/static", StaticFiles(directory="static"), name="static")
+
+
+class Language(str, Enum):
+    en = "English"
+    zh = "中文"
+
+
+class DefaultVoice(str, Enum):
+    zhs1 = "static/zh/s1.mp3"
+    zhs2 = "static/zh/s2.mp3"
+    ens1 = "static/en/s1.mp3"
+    ens2 = "static/en/s2.mp3"
+
+
+@app.post("/tts")
+async def tts(
+    custom_voice_file: UploadFile = File(None, description="user-provided reference voice"),
+    language: Language = Form(..., description="language selection"),
+    voice: DefaultVoice = Form(None),
+    text: str = Form(..., description="text to synthesize")
+):
+    os.makedirs("static/tmp", exist_ok=True)
+    if custom_voice_file is not None:
+        content = await custom_voice_file.read()
+        filename = f"static/tmp/{custom_voice_file.filename}"
+        with open(filename, "wb") as f:
+            f.write(content)
+        voice = filename
+    wav_path = clone_voice(
+        user_voice=voice, user_text=text, user_lang=language)
+    return FileResponse(wav_path)
+
+
+if __name__ == '__main__':
+
+    uvicorn.run(app="main:app", port=int(7860), host="0.0.0.0")
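For reference, a client can exercise the new /tts endpoint with an ordinary form POST. The sketch below assumes a locally running instance on the port used by uvicorn.run() above and uses the requests library (not a dependency declared in this commit); field names and enum values come from the endpoint definition:

import requests

resp = requests.post(
    "http://localhost:7860/tts",
    data={
        "language": "English",        # must be a Language value: "English" or "中文"
        "voice": "static/en/s1.mp3",  # one of the DefaultVoice presets
        "text": "Hello, this is a quick test of the cloned voice.",
    },
    # To clone from an uploaded reference instead of a preset, send a file:
    # files={"custom_voice_file": open("my_voice.mp3", "rb")},
)
resp.raise_for_status()
with open("result.wav", "wb") as f:
    f.write(resp.content)  # the endpoint returns the generated WAV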
model.py ADDED
@@ -0,0 +1,760 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import soundfile as sf
4
+ from datetime import datetime
5
+ from time import time as ttime
6
+ from my_utils import load_audio
7
+ from transformers import pipeline
8
+ from text.cleaner import clean_text
9
+ from polyglot.detect import Detector
10
+ from feature_extractor import cnhubert
11
+ from timeit import default_timer as timer
12
+ from text import cleaned_text_to_sequence
13
+ from module.models import SynthesizerTrn
14
+ from module.mel_processing import spectrogram_torch
15
+ from transformers.pipelines.audio_utils import ffmpeg_read
16
+ import os,re,sys,LangSegment,librosa,pdb,torch,pytz,random
17
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
18
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
19
+
20
+
21
+ import logging
22
+ logging.getLogger("markdown_it").setLevel(logging.ERROR)
23
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
24
+ logging.getLogger("httpcore").setLevel(logging.ERROR)
25
+ logging.getLogger("httpx").setLevel(logging.ERROR)
26
+ logging.getLogger("asyncio").setLevel(logging.ERROR)
27
+ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
28
+ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
29
+ logging.getLogger("multipart").setLevel(logging.WARNING)
30
+ # from download import *
31
+ # download()
32
+
33
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
34
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
35
+ tz = pytz.timezone('Asia/Singapore')
36
+ device = "cuda" if torch.cuda.is_available() else "cpu"
37
+
38
+ def abs_path(dir):
39
+ global_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
40
+ return(os.path.join(global_dir, dir))
41
+ gpt_path = abs_path("MODELS/22/22.ckpt")
42
+ sovits_path=abs_path("MODELS/22/22.pth")
43
+ cnhubert_base_path = os.environ.get("cnhubert_base_path", "pretrained_models/chinese-hubert-base")
44
+ bert_path = os.environ.get("bert_path", "pretrained_models/chinese-roberta-wwm-ext-large")
45
+
46
+ if not os.path.exists(cnhubert_base_path):
47
+ cnhubert_base_path = "TencentGameMate/chinese-hubert-base"
48
+ if not os.path.exists(bert_path):
49
+ bert_path = "hfl/chinese-roberta-wwm-ext-large"
50
+ cnhubert.cnhubert_base_path = cnhubert_base_path
51
+
52
+ whisper_path = os.environ.get("whisper_path", "pretrained_models/whisper-tiny")
53
+ if not os.path.exists(whisper_path):
54
+ whisper_path = "openai/whisper-tiny"
55
+
56
+ pipe = pipeline(
57
+ task="automatic-speech-recognition",
58
+ model=whisper_path,
59
+ chunk_length_s=30,
60
+ device=device,)
61
+
62
+
63
+ is_half = eval(
64
+ os.environ.get("is_half", "True" if torch.cuda.is_available() else "False")
65
+ )
66
+
67
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
68
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
69
+ if is_half == True:
70
+ bert_model = bert_model.half().to(device)
71
+ else:
72
+ bert_model = bert_model.to(device)
73
+
74
+
75
+ def get_bert_feature(text, word2ph):
76
+ with torch.no_grad():
77
+ inputs = tokenizer(text, return_tensors="pt")
78
+ for i in inputs:
79
+ inputs[i] = inputs[i].to(device)
80
+ res = bert_model(**inputs, output_hidden_states=True)
81
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
82
+ assert len(word2ph) == len(text)
83
+ phone_level_feature = []
84
+ for i in range(len(word2ph)):
85
+ repeat_feature = res[i].repeat(word2ph[i], 1)
86
+ phone_level_feature.append(repeat_feature)
87
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
88
+ return phone_level_feature.T
89
+
90
+
91
+ class DictToAttrRecursive(dict):
92
+ def __init__(self, input_dict):
93
+ super().__init__(input_dict)
94
+ for key, value in input_dict.items():
95
+ if isinstance(value, dict):
96
+ value = DictToAttrRecursive(value)
97
+ self[key] = value
98
+ setattr(self, key, value)
99
+
100
+ def __getattr__(self, item):
101
+ try:
102
+ return self[item]
103
+ except KeyError:
104
+ raise AttributeError(f"Attribute {item} not found")
105
+
106
+ def __setattr__(self, key, value):
107
+ if isinstance(value, dict):
108
+ value = DictToAttrRecursive(value)
109
+ super(DictToAttrRecursive, self).__setitem__(key, value)
110
+ super().__setattr__(key, value)
111
+
112
+ def __delattr__(self, item):
113
+ try:
114
+ del self[item]
115
+ except KeyError:
116
+ raise AttributeError(f"Attribute {item} not found")
117
+
118
+
119
+ ssl_model = cnhubert.get_model()
120
+ if is_half == True:
121
+ ssl_model = ssl_model.half().to(device)
122
+ else:
123
+ ssl_model = ssl_model.to(device)
124
+
125
+
126
+ def change_sovits_weights(sovits_path):
127
+ global vq_model, hps
128
+ dict_s2 = torch.load(sovits_path, map_location="cpu")
129
+ hps = dict_s2["config"]
130
+ hps = DictToAttrRecursive(hps)
131
+ hps.model.semantic_frame_rate = "25hz"
132
+ vq_model = SynthesizerTrn(
133
+ hps.data.filter_length // 2 + 1,
134
+ hps.train.segment_size // hps.data.hop_length,
135
+ n_speakers=hps.data.n_speakers,
136
+ **hps.model
137
+ )
138
+ if ("pretrained" not in sovits_path):
139
+ del vq_model.enc_q
140
+ if is_half == True:
141
+ vq_model = vq_model.half().to(device)
142
+ else:
143
+ vq_model = vq_model.to(device)
144
+ vq_model.eval()
145
+ print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
146
+ with open("./sweight.txt", "w", encoding="utf-8") as f:
147
+ f.write(sovits_path)
148
+
149
+
150
+ change_sovits_weights(sovits_path)
151
+
152
+
153
+ def change_gpt_weights(gpt_path):
154
+ global hz, max_sec, t2s_model, config
155
+ hz = 50
156
+ dict_s1 = torch.load(gpt_path, map_location="cpu")
157
+ config = dict_s1["config"]
158
+ max_sec = config["data"]["max_sec"]
159
+ t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
160
+ t2s_model.load_state_dict(dict_s1["weight"])
161
+ if is_half == True:
162
+ t2s_model = t2s_model.half()
163
+ t2s_model = t2s_model.to(device)
164
+ t2s_model.eval()
165
+ total = sum([param.nelement() for param in t2s_model.parameters()])
166
+ print("Number of parameter: %.2fM" % (total / 1e6))
167
+ with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
168
+
169
+
170
+ change_gpt_weights(gpt_path)
171
+
172
+
173
+ def get_spepc(hps, filename):
174
+ audio = load_audio(filename, int(hps.data.sampling_rate))
175
+ audio = torch.FloatTensor(audio)
176
+ audio_norm = audio
177
+ audio_norm = audio_norm.unsqueeze(0)
178
+ spec = spectrogram_torch(
179
+ audio_norm,
180
+ hps.data.filter_length,
181
+ hps.data.sampling_rate,
182
+ hps.data.hop_length,
183
+ hps.data.win_length,
184
+ center=False,
185
+ )
186
+ return spec
187
+
188
+
189
+ dict_language = {
190
+ ("中文1"): "all_zh",#全部按中文识别
191
+ ("English"): "en",#全部按英文识别#######不变
192
+ ("日文1"): "all_ja",#全部按日文识别
193
+ ("中文"): "zh",#按中英混合识别####不变
194
+ ("日本語"): "ja",#按日英混合识别####不变
195
+ ("混合"): "auto",#多语种启动切分识别语种
196
+ }
197
+
198
+
199
+ def splite_en_inf(sentence, language):
200
+ pattern = re.compile(r'[a-zA-Z ]+')
201
+ textlist = []
202
+ langlist = []
203
+ pos = 0
204
+ for match in pattern.finditer(sentence):
205
+ start, end = match.span()
206
+ if start > pos:
207
+ textlist.append(sentence[pos:start])
208
+ langlist.append(language)
209
+ textlist.append(sentence[start:end])
210
+ langlist.append("en")
211
+ pos = end
212
+ if pos < len(sentence):
213
+ textlist.append(sentence[pos:])
214
+ langlist.append(language)
215
+ # Merge punctuation into previous word
216
+ for i in range(len(textlist)-1, 0, -1):
217
+ if re.match(r'^[\W_]+$', textlist[i]):
218
+ textlist[i-1] += textlist[i]
219
+ del textlist[i]
220
+ del langlist[i]
221
+ # Merge consecutive words with the same language tag
222
+ i = 0
223
+ while i < len(langlist) - 1:
224
+ if langlist[i] == langlist[i+1]:
225
+ textlist[i] += textlist[i+1]
226
+ del textlist[i+1]
227
+ del langlist[i+1]
228
+ else:
229
+ i += 1
230
+
231
+ return textlist, langlist
232
+
233
+
234
+ def clean_text_inf(text, language):
235
+ formattext = ""
236
+ language = language.replace("all_","")
237
+ for tmp in LangSegment.getTexts(text):
238
+ if language == "ja":
239
+ if tmp["lang"] == language or tmp["lang"] == "zh":
240
+ formattext += tmp["text"] + " "
241
+ continue
242
+ if tmp["lang"] == language:
243
+ formattext += tmp["text"] + " "
244
+ while " " in formattext:
245
+ formattext = formattext.replace(" ", " ")
246
+ phones, word2ph, norm_text = clean_text(formattext, language)
247
+ phones = cleaned_text_to_sequence(phones)
248
+ return phones, word2ph, norm_text
249
+
250
+ dtype=torch.float16 if is_half == True else torch.float32
251
+ def get_bert_inf(phones, word2ph, norm_text, language):
252
+ language=language.replace("all_","")
253
+ if language == "zh":
254
+ bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
255
+ else:
256
+ bert = torch.zeros(
257
+ (1024, len(phones)),
258
+ dtype=torch.float16 if is_half == True else torch.float32,
259
+ ).to(device)
260
+
261
+ return bert
262
+
263
+
264
+ def nonen_clean_text_inf(text, language):
265
+ if(language!="auto"):
266
+ textlist, langlist = splite_en_inf(text, language)
267
+ else:
268
+ textlist=[]
269
+ langlist=[]
270
+ for tmp in LangSegment.getTexts(text):
271
+ langlist.append(tmp["lang"])
272
+ textlist.append(tmp["text"])
273
+ print(textlist)
274
+ print(langlist)
275
+ phones_list = []
276
+ word2ph_list = []
277
+ norm_text_list = []
278
+ for i in range(len(textlist)):
279
+ lang = langlist[i]
280
+ phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
281
+ phones_list.append(phones)
282
+ if lang == "zh":
283
+ word2ph_list.append(word2ph)
284
+ norm_text_list.append(norm_text)
285
+ print(word2ph_list)
286
+ phones = sum(phones_list, [])
287
+ word2ph = sum(word2ph_list, [])
288
+ norm_text = ' '.join(norm_text_list)
289
+
290
+ return phones, word2ph, norm_text
291
+
292
+
293
+ def nonen_get_bert_inf(text, language):
294
+ if(language!="auto"):
295
+ textlist, langlist = splite_en_inf(text, language)
296
+ else:
297
+ textlist=[]
298
+ langlist=[]
299
+ for tmp in LangSegment.getTexts(text):
300
+ langlist.append(tmp["lang"])
301
+ textlist.append(tmp["text"])
302
+ print(textlist)
303
+ print(langlist)
304
+ bert_list = []
305
+ for i in range(len(textlist)):
306
+ lang = langlist[i]
307
+ phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
308
+ bert = get_bert_inf(phones, word2ph, norm_text, lang)
309
+ bert_list.append(bert)
310
+ bert = torch.cat(bert_list, dim=1)
311
+
312
+ return bert
313
+
314
+
315
+ splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
316
+
317
+
318
+ def get_first(text):
319
+ pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
320
+ text = re.split(pattern, text)[0].strip()
321
+ return text
322
+
323
+
324
+ def get_cleaned_text_final(text,language):
325
+ if language in {"en","all_zh","all_ja"}:
326
+ phones, word2ph, norm_text = clean_text_inf(text, language)
327
+ elif language in {"zh", "ja","auto"}:
328
+ phones, word2ph, norm_text = nonen_clean_text_inf(text, language)
329
+ return phones, word2ph, norm_text
330
+
331
+ def get_bert_final(phones, word2ph, text,language,device):
332
+ if language == "en":
333
+ bert = get_bert_inf(phones, word2ph, text, language)
334
+ elif language in {"zh", "ja","auto"}:
335
+ bert = nonen_get_bert_inf(text, language)
336
+ elif language == "all_zh":
337
+ bert = get_bert_feature(text, word2ph).to(device)
338
+ else:
339
+ bert = torch.zeros((1024, len(phones))).to(device)
340
+ return bert
341
+
342
+ def merge_short_text_in_array(texts, threshold):
343
+ if (len(texts)) < 2:
344
+ return texts
345
+ result = []
346
+ text = ""
347
+ for ele in texts:
348
+ text += ele
349
+ if len(text) >= threshold:
350
+ result.append(text)
351
+ text = ""
352
+ if (len(text) > 0):
353
+ if len(result) == 0:
354
+ result.append(text)
355
+ else:
356
+ result[len(result) - 1] += text
357
+ return result
358
+
359
+
360
+ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=("Do not split"), volume_scale=1.0):
361
+ if not duration(ref_wav_path):
362
+ return None
363
+ if text == '':
364
+ wprint("Please enter text to generate/请输入生成文字")
365
+ return None
366
+ t0 = ttime()
367
+ startTime=timer()
368
+ text=trim_text(text,text_language)
369
+ change_sovits_weights(sovits_path)
370
+ tprint(f'🏕️LOADED SoVITS Model: {sovits_path}')
371
+ change_gpt_weights(gpt_path)
372
+ tprint(f'🏕️LOADED GPT Model: {gpt_path}')
373
+
374
+ prompt_language = dict_language[prompt_language]
375
+ try:
376
+ text_language = dict_language[text_language]
377
+ except KeyError as e:
378
+ wprint(f"Unsupported language type: {e}")
379
+ return None
380
+
381
+ prompt_text = prompt_text.strip("\n")
382
+ if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
383
+ text = text.strip("\n")
384
+ if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
385
+ #print(("实际输入的参考文本:"), prompt_text)
386
+ #print(("📝实际输入的目标文本:"), text)
387
+ zero_wav = np.zeros(
388
+ int(hps.data.sampling_rate * 0.3),
389
+ dtype=np.float16 if is_half == True else np.float32,
390
+ )
391
+ with torch.no_grad():
392
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
393
+ if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
394
+ errinfo='参考音频在3~10秒范围外,请更换!'
395
+ raise OSError((errinfo))
396
+ wav16k = torch.from_numpy(wav16k)
397
+ zero_wav_torch = torch.from_numpy(zero_wav)
398
+ if is_half == True:
399
+ wav16k = wav16k.half().to(device)
400
+ zero_wav_torch = zero_wav_torch.half().to(device)
401
+ else:
402
+ wav16k = wav16k.to(device)
403
+ zero_wav_torch = zero_wav_torch.to(device)
404
+ wav16k = torch.cat([wav16k, zero_wav_torch])
405
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
406
+ "last_hidden_state"
407
+ ].transpose(
408
+ 1, 2
409
+ ) # .float()
410
+ codes = vq_model.extract_latent(ssl_content)
411
+ prompt_semantic = codes[0, 0]
412
+ t1 = ttime()
413
+
414
+ phones1, word2ph1, norm_text1=get_cleaned_text_final(prompt_text, prompt_language)
415
+
416
+ if (how_to_cut == ("Split into groups of 4 sentences")):
417
+ text = cut1(text)
418
+ elif (how_to_cut == ("Split every 50 characters")):
419
+ text = cut2(text)
420
+ elif (how_to_cut == ("Split at CN/JP periods (。)")):
421
+ text = cut3(text)
422
+ elif (how_to_cut == ("Split at English periods (.)")):
423
+ text = cut4(text)
424
+ elif (how_to_cut == ("Split at punctuation marks")):
425
+ text = cut5(text)
426
+ while "\n\n" in text:
427
+ text = text.replace("\n\n", "\n")
428
+ print(f"🧨实际输入的目标文本(切句后):{text}\n")
429
+ texts = text.split("\n")
430
+ texts = merge_short_text_in_array(texts, 5)
431
+ audio_opt = []
432
+ bert1=get_bert_final(phones1, word2ph1, norm_text1,prompt_language,device).to(dtype)
433
+
434
+ for text in texts:
435
+ if (len(text.strip()) == 0):
436
+ continue
437
+ if (text[-1] not in splits): text += "。" if text_language != "en" else "."
438
+ print(("\n🎈实际输入的目标文本(每句):"), text)
439
+ phones2, word2ph2, norm_text2 = get_cleaned_text_final(text, text_language)
440
+ try:
441
+ bert2 = get_bert_final(phones2, word2ph2, norm_text2, text_language, device).to(dtype)
442
+ except RuntimeError as e:
443
+ wprint(f"The input text does not match the language/输入文本与语言不匹配: {e}")
444
+ return None
445
+ bert = torch.cat([bert1, bert2], 1)
446
+
447
+ all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
448
+ bert = bert.to(device).unsqueeze(0)
449
+ all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
450
+ prompt = prompt_semantic.unsqueeze(0).to(device)
451
+ t2 = ttime()
452
+ with torch.no_grad():
453
+ # pred_semantic = t2s_model.model.infer(
454
+ pred_semantic, idx = t2s_model.model.infer_panel(
455
+ all_phoneme_ids,
456
+ all_phoneme_len,
457
+ prompt,
458
+ bert,
459
+ # prompt_phone_len=ph_offset,
460
+ top_k=config["inference"]["top_k"],
461
+ early_stop_num=hz * max_sec,
462
+ )
463
+ t3 = ttime()
464
+ # print(pred_semantic.shape,idx)
465
+ pred_semantic = pred_semantic[:, -idx:].unsqueeze(
466
+ 0
467
+ ) # .unsqueeze(0)#mq要多unsqueeze一次
468
+ refer = get_spepc(hps, ref_wav_path) # .to(device)
469
+ if is_half == True:
470
+ refer = refer.half().to(device)
471
+ else:
472
+ refer = refer.to(device)
473
+ # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
474
+ try:
475
+ audio = (
476
+ vq_model.decode(
477
+ pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
478
+ )
479
+ .detach()
480
+ .cpu()
481
+ .numpy()[0, 0]
482
+ )
483
+ except RuntimeError as e:
484
+ wprint(f"The input text does not match the language/输入文本与语言不匹配: {e}")
485
+ return None
486
+
487
+ max_audio=np.abs(audio).max()
488
+ if max_audio>1:audio/=max_audio
489
+ audio_opt.append(audio)
490
+ audio_opt.append(zero_wav)
491
+ t4 = ttime()
492
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
493
+ #yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
494
+ audio_data = (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
495
+
496
+ audio_data = (audio_data.astype(np.float32) * volume_scale).astype(np.int16)
497
+ output_wav = "output_audio.wav"
498
+ sf.write(output_wav, audio_data, hps.data.sampling_rate)
499
+ endTime=timer()
500
+ tprint(f'🆗TTS COMPLETE,{round(endTime-startTime,4)}s')
501
+ return output_wav
502
+
503
+ def split(todo_text):
504
+ todo_text = todo_text.replace("……", "。").replace("——", ",")
505
+ if todo_text[-1] not in splits:
506
+ todo_text += "。"
507
+ i_split_head = i_split_tail = 0
508
+ len_text = len(todo_text)
509
+ todo_texts = []
510
+ while 1:
511
+ if i_split_head >= len_text:
512
+ break
513
+ if todo_text[i_split_head] in splits:
514
+ i_split_head += 1
515
+ todo_texts.append(todo_text[i_split_tail:i_split_head])
516
+ i_split_tail = i_split_head
517
+ else:
518
+ i_split_head += 1
519
+ return todo_texts
520
+
521
+
522
+ def cut1(inp):
523
+ inp = inp.strip("\n")
524
+ inps = split(inp)
525
+ split_idx = list(range(0, len(inps), 4))
526
+ split_idx[-1] = None
527
+ if len(split_idx) > 1:
528
+ opts = []
529
+ for idx in range(len(split_idx) - 1):
530
+ opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
531
+ else:
532
+ opts = [inp]
533
+ return "\n".join(opts)
534
+
535
+
536
+ def cut2(inp):
537
+ inp = inp.strip("\n")
538
+ inps = split(inp)
539
+ if len(inps) < 2:
540
+ return inp
541
+ opts = []
542
+ summ = 0
543
+ tmp_str = ""
544
+ for i in range(len(inps)):
545
+ summ += len(inps[i])
546
+ tmp_str += inps[i]
547
+ if summ > 50:
548
+ summ = 0
549
+ opts.append(tmp_str)
550
+ tmp_str = ""
551
+ if tmp_str != "":
552
+ opts.append(tmp_str)
553
+ # print(opts)
554
+ if len(opts) > 1 and len(opts[-1]) < 50:
555
+ opts[-2] = opts[-2] + opts[-1]
556
+ opts = opts[:-1]
557
+ return "\n".join(opts)
558
+
559
+
560
+ def cut3(inp):
561
+ inp = inp.strip("\n")
562
+ return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
563
+
564
+
565
+ def cut4(inp):
566
+ inp = inp.strip("\n")
567
+ return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
568
+
569
+
570
+ # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
571
+ def cut5(inp):
572
+ # if not re.search(r'[^\w\s]', inp[-1]):
573
+ # inp += '。'
574
+ inp = inp.strip("\n")
575
+ punds = r'[,.;?!、,。?!;:…]'
576
+ items = re.split(f'({punds})', inp)
577
+ mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
578
+ if len(items)%2 == 1:
579
+ mergeitems.append(items[-1])
580
+ opt = "\n".join(mergeitems)
581
+ return opt
582
+
583
+
584
+
585
+ def custom_sort_key(s):
586
+ # 使用正则表达式提取字符串中的数字部分和非数字部分
587
+ parts = re.split('(\d+)', s)
588
+ # 将数字部分转换为整数,非数字部分保持不变
589
+ parts = [int(part) if part.isdigit() else part for part in parts]
590
+ return parts
591
+
592
+ #==========custom functions============
593
+
594
+ def tprint(text):
595
+ now=datetime.now(tz).strftime('%H:%M:%S')
596
+ print(f'UTC+8 - {now} - {text}')
597
+
598
+ def wprint(text):
599
+ tprint(text)
600
+ gr.Warning(text)
601
+
602
+ def lang_detector(text):
603
+ min_chars = 5
604
+ if len(text) < min_chars:
605
+ return "Input text too short/输入文本太短"
606
+ try:
607
+ detector = Detector(text).language
608
+ lang_info = str(detector)
609
+ code = re.search(r"name: (\w+)", lang_info).group(1)
610
+ if code == 'Japanese':
611
+ return "日本語"
612
+ elif code == 'Chinese':
613
+ return "中文"
614
+ elif code == 'English':
615
+ return 'English'
616
+ else:
617
+ return code
618
+ except Exception as e:
619
+ return f"ERROR:{str(e)}"
620
+
621
+ def trim_text(text,language):
622
+ limit_cj = 120 #character
623
+ limit_en = 60 #words
624
+ search_limit_cj = limit_cj+30
625
+ search_limit_en = limit_en +30
626
+ text = text.replace('\n', '').strip()
627
+
628
+ if language =='English':
629
+ words = text.split()
630
+ if len(words) <= limit_en:
631
+ return text
632
+ # English
633
+ for i in range(limit_en, -1, -1):
634
+ if any(punct in words[i] for punct in splits):
635
+ return ' '.join(words[:i+1])
636
+ for i in range(limit_en, min(len(words), search_limit_en)):
637
+ if any(punct in words[i] for punct in splits):
638
+ return ' '.join(words[:i+1])
639
+ return ' '.join(words[:limit_en])
640
+
641
+ else:#中文日文
642
+ if len(text) <= limit_cj:
643
+ return text
644
+ for i in range(limit_cj, -1, -1):
645
+ if text[i] in splits:
646
+ return text[:i+1]
647
+ for i in range(limit_cj, min(len(text), search_limit_cj)):
648
+ if text[i] in splits:
649
+ return text[:i+1]
650
+ return text[:limit_cj]
651
+
652
+ def duration(audio_file_path):
653
+ if not audio_file_path:
654
+ wprint("Failed to obtain uploaded audio/未找到音频文件")
655
+ return False
656
+ try:
657
+ audio_duration = librosa.get_duration(filename=audio_file_path)
658
+ if not 3 < audio_duration < 10:
659
+ wprint("The audio length must be between 3~10 seconds/音频时长须在3~10秒之间")
660
+ return False
661
+ return True
662
+ except FileNotFoundError:
663
+ return False
664
+
665
+ def update_model(choice):
666
+ global gpt_path, sovits_path
667
+ model_info = models[choice]
668
+ gpt_path = abs_path(model_info["gpt_weight"])
669
+ sovits_path = abs_path(model_info["sovits_weight"])
670
+ model_name = choice
671
+ tone_info = model_info["tones"]["tone1"]
672
+ tone_sample_path = abs_path(tone_info["sample"])
673
+ tprint(f'✅SELECT MODEL:{choice}')
674
+ # 返回默认tone“tone1”
675
+ return (
676
+ tone_info["example_voice_wav"],
677
+ tone_info["example_voice_wav_words"],
678
+ model_info["default_language"],
679
+ model_info["default_language"],
680
+ model_name,
681
+ "tone1" ,
682
+ tone_sample_path
683
+ )
684
+
685
+ def update_tone(model_choice, tone_choice):
686
+ model_info = models[model_choice]
687
+ tone_info = model_info["tones"][tone_choice]
688
+ example_voice_wav = abs_path(tone_info["example_voice_wav"])
689
+ example_voice_wav_words = tone_info["example_voice_wav_words"]
690
+ tone_sample_path = abs_path(tone_info["sample"])
691
+ return example_voice_wav, example_voice_wav_words,tone_sample_path
692
+
693
+ def transcribe(voice):
694
+ time1=timer()
695
+ tprint('⚡Start Clone - transcribe')
696
+ task="transcribe"
697
+ if voice is None:
698
+ wprint("No audio file submitted! Please upload or record an audio file before submitting your request.")
699
+ R = pipe(voice, batch_size=8, generate_kwargs={"task": task}, return_timestamps=True,return_language=True)
700
+ text=R['text']
701
+ lang=R['chunks'][0]['language']
702
+ if lang=='english':
703
+ language='English'
704
+ elif lang =='chinese':
705
+ language='中文'
706
+ elif lang=='japanese':
707
+ language = '日本語'
708
+
709
+ time2=timer()
710
+ tprint(f'transcribe COMPLETE,{round(time2-time1,4)}s')
711
+ tprint(f'\nTRANSCRIBE RESULT:\n 🔣Language:{language} \n 🔣Text:{text}' )
712
+ return text,language
713
+
714
+ def clone_voice(user_voice,user_text,user_lang):
715
+ if not duration(user_voice):
716
+ return None
717
+ if user_text == '':
718
+ wprint("Please enter text to generate/请输入生成文字")
719
+ return None
720
+ user_text=trim_text(user_text,user_lang)
721
+ time1=timer()
722
+ global gpt_path, sovits_path
723
+ gpt_path = abs_path("pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
724
+ #tprint(f'Model loaded:{gpt_path}')
725
+ sovits_path = abs_path("pretrained_models/s2G488k.pth")
726
+ #tprint(f'Model loaded:{sovits_path}')
727
+ try:
728
+ prompt_text, prompt_language = transcribe(user_voice)
729
+ except UnboundLocalError as e:
730
+ wprint(f"The language in the audio cannot be recognized :{str(e)}")
731
+ return None
732
+
733
+ output_wav = get_tts_wav(
734
+ user_voice,
735
+ prompt_text,
736
+ prompt_language,
737
+ user_text,
738
+ user_lang,
739
+ how_to_cut="Do not split",
740
+ volume_scale=1.0)
741
+ time2=timer()
742
+ tprint(f'🆗CLONE COMPLETE,{round(time2-time1,4)}s')
743
+ return output_wav
744
+
745
+ with open('dummy') as f:
746
+ dummy_txt = f.read().strip().splitlines()
747
+
748
+ def dice():
749
+ return random.choice(dummy_txt), '🎲'
750
+
751
+ from info import models
752
+ models_by_language = {
753
+ "English": [],
754
+ "中文": [],
755
+ "日本語": []
756
+ }
757
+ for model_name, model_info in models.items():
758
+ language = model_info["default_language"]
759
+ models_by_language[language].append((model_name, model_info))
760
+
requirements.txt CHANGED
@@ -28,4 +28,6 @@ pyicu
 morfessor
 pycld2
 polyglot
-wordsegment
+wordsegment
+fastapi
+uvicorn
static/en/s1.mp3 ADDED
Binary file (44.1 kB).
 
static/en/s2.mp3 ADDED
Binary file (53.8 kB).
 
static/zh/s1.mp3 ADDED
Binary file (32.4 kB).
 
static/zh/s2.mp3 ADDED
Binary file (32.4 kB).
 
text/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/text/__pycache__/__init__.cpython-310.pyc and b/text/__pycache__/__init__.cpython-310.pyc differ
 
text/__pycache__/chinese.cpython-310.pyc CHANGED
Binary files a/text/__pycache__/chinese.cpython-310.pyc and b/text/__pycache__/chinese.cpython-310.pyc differ
 
text/__pycache__/cleaner.cpython-310.pyc CHANGED
Binary files a/text/__pycache__/cleaner.cpython-310.pyc and b/text/__pycache__/cleaner.cpython-310.pyc differ
 
text/__pycache__/symbols.cpython-310.pyc CHANGED
Binary files a/text/__pycache__/symbols.cpython-310.pyc and b/text/__pycache__/symbols.cpython-310.pyc differ