Spaces:
Running
on
Zero
Running
on
Zero
| import re | |
| import gradio as gr | |
| from tqdm import tqdm | |
| from argparse import ArgumentParser | |
| from typing import Literal, List, Tuple | |
| import sys | |
| import importlib.util | |
| from datetime import datetime | |
| import spaces | |
| import torch | |
| import numpy as np | |
| import random | |
| import s3tokenizer | |
| from huggingface_hub import snapshot_download | |
| from soulxpodcast.models.soulxpodcast import SoulXPodcast | |
| from soulxpodcast.config import Config, SoulXPodcastLLMConfig, SamplingParams | |
| from soulxpodcast.utils.dataloader import ( | |
| PodcastInferHandler, | |
| SPK_DICT, TEXT_START, TEXT_END, AUDIO_START, TASK_PODCAST | |
| ) | |
| S1_PROMPT_WAV = "assets/audios/female_mandarin.wav" | |
| S2_PROMPT_WAV = "assets/audios/male_mandarin.wav" | |
| def load_dialect_prompt_data(): | |
| """ | |
| 加载方言提示文本文件并格式化为嵌套字典。 | |
| 返回结构: {dialect_key: {display_name: full_text, ...}, ...} | |
| """ | |
| dialect_data = {} | |
| dialect_files = [ | |
| ("sichuan", "assets/dialect_prompt/sichuan.txt", "<|Sichuan|>"), | |
| ("yueyu", "assets/dialect_prompt/yueyu.txt", "<|Yue|>"), | |
| ("henan", "assets/dialect_prompt/henan.txt", "<|Henan|>"), | |
| ] | |
| for key, file_path, prefix in dialect_files: | |
| dialect_data[key] = {"(无)": ""} | |
| try: | |
| with open(file_path, 'r', encoding='utf-8') as f: | |
| lines = f.readlines() | |
| for i, line in enumerate(lines): | |
| line = line.strip() | |
| if line: | |
| full_text = f"{prefix}{line}" | |
| display_name = f"例{i+1}: {line[:20]}..." | |
| dialect_data[key][display_name] = full_text | |
| except FileNotFoundError: | |
| print(f"[WARNING] 方言文件未找到: {file_path}") | |
| except Exception as e: | |
| print(f"[WARNING] 读取方言文件失败 {file_path}: {e}") | |
| return dialect_data | |
| DIALECT_PROMPT_DATA = load_dialect_prompt_data() | |
| DIALECT_CHOICES = ["(无)", "sichuan", "yueyu", "henan"] | |
| EXAMPLES_LIST = [ | |
| [ | |
| None, "", "", None, "", "", "" | |
| ], | |
| [ | |
| S1_PROMPT_WAV, | |
| "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。", | |
| "", | |
| S2_PROMPT_WAV, | |
| "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?", | |
| "", | |
| "[S1] 哈喽,AI时代的冲浪先锋们!欢迎收听《AI生活进行时》。啊,一个充满了未来感,然后,还有一点点,<|laughter|>神经质的播客节目,我是主持人小希。\n[S2] 哎,大家好呀!我是能唠,爱唠,天天都想唠的唠嗑!\n[S1] 最近活得特别赛博朋克哈!以前老是觉得AI是科幻片儿里的,<|sigh|> 现在,现在连我妈都用AI写广场舞文案了。\n[S2] 这个例子很生动啊。是的,特别是生成式AI哈,感觉都要炸了! 诶,那我们今天就聊聊AI是怎么走进我们的生活的哈!", | |
| ], | |
| [ | |
| S1_PROMPT_WAV, | |
| "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。", | |
| "<|Sichuan|>要得要得!前头几个耍洋盘,我后脚就背起铺盖卷去景德镇耍泥巴,巴适得喊老天爷!", | |
| S2_PROMPT_WAV, | |
| "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?", | |
| "<|Sichuan|>哎哟喂,这个搞反了噻!黑神话里头唱曲子的王二浪早八百年就在黄土高坡吼秦腔喽,游戏组专门跑切录的原汤原水,听得人汗毛儿都立起来!", | |
| "[S1] <|Sichuan|>各位《巴适得板》的听众些,大家好噻!我是你们主持人晶晶。今儿天气硬是巴适,不晓得大家是在赶路嘛,还是茶都泡起咯,准备跟我们好生摆一哈龙门阵喃?\n[S2] <|Sichuan|>晶晶好哦,大家安逸噻!我是李老倌。你刚开口就川味十足,摆龙门阵几个字一甩出来,我鼻子头都闻到茶香跟火锅香咯!\n[S1] <|Sichuan|>就是得嘛!李老倌,我前些天带个外地朋友切人民公园鹤鸣茶社坐了一哈。他硬是搞不醒豁,为啥子我们一堆人围到杯茶就可以吹一下午壳子,从隔壁子王嬢嬢娃儿耍朋友,扯到美国大选,中间还掺几盘斗地主。他说我们四川人简直是把摸鱼刻进骨子里头咯!\n[S2] <|Sichuan|>你那个朋友说得倒是有点儿趣,但他莫看到精髓噻。摆龙门阵哪是摸鱼嘛,这是我们川渝人特有的交际方式,更是一种活法。外省人天天说的松弛感,根根儿就在这龙门阵里头。今天我们就要好生摆一哈,为啥子四川人活得这么舒坦。就先从茶馆这个老窝子说起,看它咋个成了我们四川人的魂儿!", | |
| ], | |
| [ | |
| S1_PROMPT_WAV, | |
| "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。", | |
| "<|Yue|>真係冇讲错啊!攀山滑雪嘅语言专家几巴闭,都唔及我听日拖成副身家去景德镇玩泥巴,呢铺真系发哂白日梦咯!", | |
| S2_PROMPT_WAV, | |
| "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?", | |
| "<|Yue|>咪搞错啊!陕北民谣响度唱咗几十年,黑神话边有咁大面啊?你估佢哋抄游戏咩!", | |
| "[S1] <|Yue|>哈囉大家好啊,歡迎收聽我哋嘅節目。喂,我今日想問你樣嘢啊,你覺唔覺得,嗯,而家揸電動車,最煩,最煩嘅一樣嘢係咩啊?\n[S2] <|Yue|>梗係充電啦。大佬啊,搵個位都已經好煩,搵到個位仲要喺度等,你話快極都要半個鐘一個鐘,真係,有時諗起都覺得好冇癮。\n[S1] <|Yue|>係咪先。如果我而家同你講,充電可以快到同入油差唔多時間,你信唔信先?喂你平時喺油站入滿一缸油,要幾耐啊?五六分鐘?\n[S2] <|Yue|>差唔多啦,七八分鐘,點都走得啦。電車喎,可以做到咁快?你咪玩啦。", | |
| ], | |
| [ | |
| S1_PROMPT_WAV, | |
| "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。", | |
| "<|Henan|>俺这不是怕恁路上不得劲儿嘛!那景德镇瓷泥可娇贵着哩,得先拿咱河南人这实诚劲儿给它揉透喽。", | |
| S2_PROMPT_WAV, | |
| "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?", | |
| "<|Henan|>恁这想法真闹挺!陕北民谣比黑神话早几百年都有了,咱可不兴这弄颠倒啊,中不?恁这想法真闹挺!那陕北民谣在黄土高坡响了几百年,咋能说是跟黑神话学的咧?咱得把这事儿捋直喽,中不中!", | |
| "[S1] <|Henan|>哎,大家好啊,欢迎收听咱这一期嘞《瞎聊呗,就这么说》,我是恁嘞老朋友,燕子。\n[S2] <|Henan|>大家好,我是老张。燕子啊,今儿瞅瞅你这个劲儿,咋着,是有啥可得劲嘞事儿想跟咱唠唠?\n[S1] <|Henan|>哎哟,老张,你咋恁懂我嘞!我跟你说啊,最近我刷手机,老是刷住些可逗嘞方言视频,特别是咱河南话,咦~我哩个乖乖,一听我都憋不住笑,咋说嘞,得劲儿哩很,跟回到家一样。\n[S2] <|Henan|>你这回可算说到根儿上了!河南话,咱往大处说说,中原官话,它真嘞是有一股劲儿搁里头。它可不光是说话,它脊梁骨后头藏嘞,是咱一整套、鲜鲜活活嘞过法儿,一种活人嘞道理。\n[S1] <|Henan|>活人嘞道理?哎,这你这一说,我嘞兴致“腾”一下就上来啦!觉住咱这嗑儿,一下儿从搞笑视频蹿到文化顶上了。那你赶紧给我白话白话,这里头到底有啥道道儿?我特别想知道——为啥一提起咱河南人,好些人脑子里“蹦”出来嘞头一个词儿,就是实在?这个实在,骨子里到底是啥嘞?", | |
| ], | |
| ] | |
| model: SoulXPodcast = None | |
| dataset: PodcastInferHandler = None | |
| def initiate_model(config: Config, enable_tn: bool=False): | |
| global model | |
| if model is None: | |
| model = SoulXPodcast(config) | |
| global dataset | |
| if dataset is None: | |
| dataset = PodcastInferHandler(model.llm.tokenizer, None, config) | |
| _i18n_key2lang_dict = dict( | |
| # Speaker1 Prompt | |
| spk1_prompt_audio_label=dict( | |
| en="Speaker 1 Prompt Audio", | |
| zh="说话人 1 参考语音", | |
| ), | |
| spk1_prompt_text_label=dict( | |
| en="Speaker 1 Prompt Text", | |
| zh="说话人 1 参考文本", | |
| ), | |
| spk1_prompt_text_placeholder=dict( | |
| en="text of speaker 1 Prompt audio.", | |
| zh="说话人 1 参考文本", | |
| ), | |
| spk1_dialect_prompt_text_label=dict( | |
| en="Speaker 1 Dialect Prompt Text", | |
| zh="说话人 1 方言提示文本", | |
| ), | |
| spk1_dialect_prompt_text_placeholder=dict( | |
| en="Dialect prompt text with prefix: <|Sichuan|>/<|Yue|>/<|Henan|> ", | |
| zh="带前缀方言提示词思维链文本,前缀如下:<|Sichuan|>/<|Yue|>/<|Henan|>,如:<|Sichuan|>走嘛,切吃那家新开的麻辣烫,听别个说味道硬是霸道得很,好吃到不摆了,去晚了还得排队!", | |
| ), | |
| # Speaker2 Prompt | |
| spk2_prompt_audio_label=dict( | |
| en="Speaker 2 Prompt Audio", | |
| zh="说话人 2 参考语音", | |
| ), | |
| spk2_prompt_text_label=dict( | |
| en="Speaker 2 Prompt Text", | |
| zh="说话人 2 参考文本", | |
| ), | |
| spk2_prompt_text_placeholder=dict( | |
| en="text of speaker 2 prompt audio.", | |
| zh="说话人 2 参考文本", | |
| ), | |
| spk2_dialect_prompt_text_label=dict( | |
| en="Speaker 2 Dialect Prompt Text", | |
| zh="说话人 2 方言提示文本", | |
| ), | |
| spk2_dialect_prompt_text_placeholder=dict( | |
| en="Dialect prompt text with prefix: <|Sichuan|>/<|Yue|>/<|Henan|> ", | |
| zh="带前缀方言提示词思维链文本,前缀如下:<|Sichuan|>/<|Yue|>/<|Henan|>,如:<|Sichuan|>走嘛,切吃那家新开的麻辣烫,听别个说味道硬是霸道得很,好吃到不摆了,去晚了还得排队!", | |
| ), | |
| # Dialogue input textbox | |
| dialogue_text_input_label=dict( | |
| en="Dialogue Text Input", | |
| zh="合成文本输入", | |
| ), | |
| dialogue_text_input_placeholder=dict( | |
| en="[S1]text[S2]text[S1]text...", | |
| zh="[S1]文本[S2]文本[S1]文本...", | |
| ), | |
| # Generate button | |
| generate_btn_label=dict( | |
| en="Generate Audio", | |
| zh="合成", | |
| ), | |
| # Generated audio | |
| generated_audio_label=dict( | |
| en="Generated Dialogue Audio", | |
| zh="合成的对话音频", | |
| ), | |
| # Warining1: invalid text for prompt | |
| warn_invalid_spk1_prompt_text=dict( | |
| en='Invalid speaker 1 prompt text, should not be empty and strictly follow: "xxx"', | |
| zh='说话人 1 参考文本不合规,不能为空,格式:"xxx"', | |
| ), | |
| warn_invalid_spk2_prompt_text=dict( | |
| en='Invalid speaker 2 prompt text, should strictly follow: "[S2]xxx"', | |
| zh='说话人 2 参考文本不合规,格式:"[S2]xxx"', | |
| ), | |
| warn_invalid_dialogue_text=dict( | |
| en='Invalid dialogue input text, should strictly follow: "[S1]xxx[S2]xxx..."', | |
| zh='对话文本输入不合规,格式:"[S1]xxx[S2]xxx..."', | |
| ), | |
| # Warining3: incomplete prompt info | |
| warn_incomplete_prompt=dict( | |
| en="Please provide prompt audio and text for both speaker 1 and speaker 2", | |
| zh="请提供说话人 1 与说话人 2 的参考语音与参考文本", | |
| ), | |
| ) | |
| global_lang: Literal["zh", "en"] = "zh" | |
| def i18n(key): | |
| global global_lang | |
| return _i18n_key2lang_dict[key][global_lang] | |
| def check_monologue_text(text: str, prefix: str = None) -> bool: | |
| text = text.strip() | |
| # Check speaker tags | |
| if prefix is not None and (not text.startswith(prefix)): | |
| return False | |
| # Remove prefix | |
| if prefix is not None: | |
| text = text.removeprefix(prefix) | |
| text = text.strip() | |
| # If empty? | |
| if len(text) == 0: | |
| return False | |
| return True | |
| def check_dialect_prompt_text(text: str, prefix: str = None) -> bool: | |
| text = text.strip() | |
| # Check Dialect Prompt prefix tags | |
| if prefix is not None and (not text.startswith(prefix)): | |
| return False | |
| text = text.strip() | |
| # If empty? | |
| if len(text) == 0: | |
| return False | |
| return True | |
| def check_dialogue_text(text_list: List[str]) -> bool: | |
| if len(text_list) == 0: | |
| return False | |
| for text in text_list: | |
| if not ( | |
| check_monologue_text(text, "[S1]") | |
| or check_monologue_text(text, "[S2]") | |
| or check_monologue_text(text, "[S3]") | |
| or check_monologue_text(text, "[S4]") | |
| ): | |
| return False | |
| return True | |
| def process_single(target_text_list, prompt_wav_list, prompt_text_list, use_dialect_prompt, dialect_prompt_text): | |
| spks, texts = [], [] | |
| for target_text in target_text_list: | |
| pattern = r'(\[S[1-9]\])(.+)' | |
| match = re.match(pattern, target_text) | |
| text, spk = match.group(2), int(match.group(1)[2])-1 | |
| spks.append(spk) | |
| texts.append(text) | |
| global dataset | |
| dataitem = {"key": "001", "prompt_text": prompt_text_list, "prompt_wav": prompt_wav_list, | |
| "text": texts, "spk": spks, } | |
| if use_dialect_prompt: | |
| dataitem.update({ | |
| "dialect_prompt_text": dialect_prompt_text | |
| }) | |
| dataset.update_datasource( | |
| [ | |
| dataitem | |
| ] | |
| ) | |
| # assert one data only; | |
| data = dataset[0] | |
| prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(data["log_mel"]) # [B, num_mels=128, T] | |
| spk_emb_for_flow = torch.tensor(data["spk_emb"]) | |
| prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(data["mel"], batch_first=True, padding_value=0) # [B, T', num_mels=80] | |
| prompt_mels_lens_for_flow = torch.tensor(data['mel_len']) | |
| text_tokens_for_llm = data["text_tokens"] | |
| prompt_text_tokens_for_llm = data["prompt_text_tokens"] | |
| spk_ids = data["spks_list"] | |
| sampling_params = SamplingParams(use_ras=True,win_size=25,tau_r=0.2) | |
| infos = [data["info"]] | |
| processed_data = { | |
| "prompt_mels_for_llm": prompt_mels_for_llm, | |
| "prompt_mels_lens_for_llm": prompt_mels_lens_for_llm, | |
| "prompt_text_tokens_for_llm": prompt_text_tokens_for_llm, | |
| "text_tokens_for_llm": text_tokens_for_llm, | |
| "prompt_mels_for_flow_ori": prompt_mels_for_flow, | |
| "prompt_mels_lens_for_flow": prompt_mels_lens_for_flow, | |
| "spk_emb_for_flow": spk_emb_for_flow, | |
| "sampling_params": sampling_params, | |
| "spk_ids": spk_ids, | |
| "infos": infos, | |
| "use_dialect_prompt": use_dialect_prompt, | |
| } | |
| if use_dialect_prompt: | |
| processed_data.update({ | |
| "dialect_prompt_text_tokens_for_llm": data["dialect_prompt_text_tokens"], | |
| "dialect_prefix": data["dialect_prefix"], | |
| }) | |
| return processed_data | |
| def dialogue_synthesis_function( | |
| target_text: str, | |
| spk1_prompt_text: str | None = "", | |
| spk1_prompt_audio: str | None = None, | |
| spk1_dialect_prompt_text: str | None = "", | |
| spk2_prompt_text: str | None = "", | |
| spk2_prompt_audio: str | None = None, | |
| spk2_dialect_prompt_text: str | None = "", | |
| seed: int = 1988, | |
| ): | |
| global config | |
| initiate_model(config) | |
| seed = int(seed) | |
| torch.manual_seed(seed) | |
| np.random.seed(seed) | |
| random.seed(seed) | |
| # Check prompt info | |
| target_text_list: List[str] = re.findall(r"(\[S[0-9]\][^\[\]]*)", target_text) | |
| target_text_list = [text.strip() for text in target_text_list] | |
| if not check_dialogue_text(target_text_list): | |
| gr.Warning(message=i18n("warn_invalid_dialogue_text")) | |
| return None | |
| # Go synthesis | |
| progress_bar = gr.Progress(track_tqdm=True) | |
| prompt_wav_list = [spk1_prompt_audio, spk2_prompt_audio] | |
| prompt_text_list = [spk1_prompt_text, spk2_prompt_text] | |
| use_dialect_prompt = spk1_dialect_prompt_text.strip()!="" or spk2_dialect_prompt_text.strip()!="" | |
| dialect_prompt_text_list = [spk1_dialect_prompt_text, spk2_dialect_prompt_text] | |
| data = process_single( | |
| target_text_list, | |
| prompt_wav_list, | |
| prompt_text_list, | |
| use_dialect_prompt, | |
| dialect_prompt_text_list, | |
| ) | |
| results_dict = model.forward_longform( | |
| **data | |
| ) | |
| target_audio = None | |
| for i in range(len(results_dict['generated_wavs'])): | |
| if target_audio is None: | |
| target_audio = results_dict['generated_wavs'][i] | |
| else: | |
| target_audio = torch.concat([target_audio, results_dict['generated_wavs'][i]], axis=1) | |
| return (24000, target_audio.cpu().squeeze(0).numpy()) | |
| def update_example_choices(dialect_key: str): | |
| if dialect_key == "(无)": | |
| choices = ["(请先选择方言)"] | |
| return gr.update(choices=choices, value="(无)"), gr.update(choices=choices, value="(无)") | |
| choices = list(DIALECT_PROMPT_DATA.get(dialect_key, {}).keys()) | |
| return gr.update(choices=choices, value="(无)"), gr.update(choices=choices, value="(无)") | |
| def update_prompt_text(dialect_key: str, example_key: str): | |
| if dialect_key == "(无)" or example_key in ["(无)", "(请先选择方言)"]: | |
| return gr.update(value="") | |
| full_text = DIALECT_PROMPT_DATA.get(dialect_key, {}).get(example_key, "") | |
| return gr.update(value=full_text) | |
| def render_interface() -> gr.Blocks: | |
| with gr.Blocks(title="SoulX-Podcast", theme=gr.themes.Default()) as page: | |
| with gr.Row(): | |
| lang_choice = gr.Radio( | |
| choices=["中文", "English"], | |
| value="中文", | |
| label="Display Language/显示语言", | |
| type="index", | |
| interactive=True, | |
| scale=3, | |
| ) | |
| seed_input = gr.Number( | |
| label="Seed (种子)", | |
| value=1988, | |
| step=1, | |
| interactive=True, | |
| scale=1, | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| with gr.Group(visible=True) as spk1_prompt_group: | |
| spk1_prompt_audio = gr.Audio( | |
| label=i18n("spk1_prompt_audio_label"), | |
| type="filepath", | |
| editable=False, | |
| interactive=True, | |
| ) | |
| spk1_prompt_text = gr.Textbox( | |
| label=i18n("spk1_prompt_text_label"), | |
| placeholder=i18n("spk1_prompt_text_placeholder"), | |
| lines=3, | |
| ) | |
| spk1_dialect_prompt_text = gr.Textbox( | |
| label=i18n("spk1_dialect_prompt_text_label"), | |
| placeholder=i18n("spk1_dialect_prompt_text_placeholder"), | |
| value="", | |
| lines=3, | |
| ) | |
| with gr.Column(scale=1, visible=True): | |
| with gr.Group(visible=True) as spk2_prompt_group: | |
| spk2_prompt_audio = gr.Audio( | |
| label=i18n("spk2_prompt_audio_label"), | |
| type="filepath", | |
| editable=False, | |
| interactive=True, | |
| ) | |
| spk2_prompt_text = gr.Textbox( | |
| label=i18n("spk2_prompt_text_label"), | |
| placeholder=i18n("spk2_prompt_text_placeholder"), | |
| lines=3, | |
| ) | |
| spk2_dialect_prompt_text = gr.Textbox( | |
| label=i18n("spk2_dialect_prompt_text_label"), | |
| placeholder=i18n("spk2_dialect_prompt_text_placeholder"), | |
| value="", | |
| lines=3, | |
| ) | |
| with gr.Column(scale=2): | |
| with gr.Row(): | |
| dialogue_text_input = gr.Textbox( | |
| label=i18n("dialogue_text_input_label"), | |
| placeholder=i18n("dialogue_text_input_placeholder"), | |
| lines=18, | |
| ) | |
| # Generate button | |
| with gr.Row(): | |
| generate_btn = gr.Button( | |
| value=i18n("generate_btn_label"), | |
| variant="primary", | |
| scale=3, | |
| size="lg", | |
| ) | |
| # Long output audio | |
| generate_audio = gr.Audio( | |
| label=i18n("generated_audio_label"), | |
| interactive=False, | |
| ) | |
| with gr.Row(): | |
| inputs_for_examples = [ | |
| spk1_prompt_audio, | |
| spk1_prompt_text, | |
| spk1_dialect_prompt_text, | |
| spk2_prompt_audio, | |
| spk2_prompt_text, | |
| spk2_dialect_prompt_text, | |
| dialogue_text_input, | |
| ] | |
| gr.Examples( | |
| examples=EXAMPLES_LIST, | |
| inputs=inputs_for_examples, | |
| label="播客模板示例 (点击加载)", | |
| examples_per_page=5, | |
| ) | |
| with gr.Accordion("方言提示文本 (Dialect Prompt) 选择器", open=False): | |
| gr.Markdown("选择方言后,请分别为 S1 和 S2 选择一个示例。") | |
| dialect_selector = gr.Dropdown( | |
| label="选择方言 (Select Dialect)", | |
| choices=DIALECT_CHOICES, | |
| value="(无)", | |
| interactive=True | |
| ) | |
| with gr.Row(): | |
| s1_dialect_example_selector = gr.Dropdown( | |
| label="S1 方言示例 (S1 Dialect Example)", | |
| choices=["(请先选择方言)"], | |
| value="(无)", | |
| interactive=True, | |
| elem_classes="gradio-dropdown" | |
| ) | |
| s2_dialect_example_selector = gr.Dropdown( | |
| label="S2 方言示例 (S2 Dialect Example)", | |
| choices=["(请先选择方言)"], | |
| value="(无)", | |
| interactive=True, | |
| elem_classes="gradio-dropdown" | |
| ) | |
| dialect_selector.change( | |
| fn=update_example_choices, | |
| inputs=[dialect_selector], | |
| outputs=[s1_dialect_example_selector, s2_dialect_example_selector] | |
| ) | |
| s1_dialect_example_selector.change( | |
| fn=update_prompt_text, | |
| inputs=[dialect_selector, s1_dialect_example_selector], | |
| outputs=[spk1_dialect_prompt_text] | |
| ) | |
| s2_dialect_example_selector.change( | |
| fn=update_prompt_text, | |
| inputs=[dialect_selector, s2_dialect_example_selector], | |
| outputs=[spk2_dialect_prompt_text] | |
| ) | |
| def _change_component_language(lang): | |
| global global_lang | |
| global_lang = ["zh", "en"][lang] | |
| return [ | |
| # spk1_prompt_{audio,text,dialect_prompt_text} | |
| gr.update(label=i18n("spk1_prompt_audio_label")), | |
| gr.update( | |
| label=i18n("spk1_prompt_text_label"), | |
| placeholder=i18n("spk1_prompt_text_placeholder"), | |
| ), | |
| gr.update( | |
| label=i18n("spk1_dialect_prompt_text_label"), | |
| placeholder=i18n("spk1_dialect_prompt_text_placeholder"), | |
| ), | |
| # spk2_prompt_{audio,text} | |
| gr.update(label=i18n("spk2_prompt_audio_label")), | |
| gr.update( | |
| label=i18n("spk2_prompt_text_label"), | |
| placeholder=i18n("spk2_prompt_text_placeholder"), | |
| ), | |
| gr.update( | |
| label=i18n("spk2_dialect_prompt_text_label"), | |
| placeholder=i18n("spk2_dialect_prompt_text_placeholder"), | |
| ), | |
| # dialogue_text_input | |
| gr.update( | |
| label=i18n("dialogue_text_input_label"), | |
| placeholder=i18n("dialogue_text_input_placeholder"), | |
| ), | |
| # generate_btn | |
| gr.update(value=i18n("generate_btn_label")), | |
| # generate_audio | |
| gr.update(label=i18n("generated_audio_label")), | |
| ] | |
| lang_choice.change( | |
| fn=_change_component_language, | |
| inputs=[lang_choice], | |
| outputs=[ | |
| spk1_prompt_audio, | |
| spk1_prompt_text, | |
| spk1_dialect_prompt_text, | |
| spk2_prompt_audio, | |
| spk2_prompt_text, | |
| spk2_dialect_prompt_text, | |
| dialogue_text_input, | |
| generate_btn, | |
| generate_audio, | |
| ], | |
| ) | |
| generate_btn.click( | |
| fn=dialogue_synthesis_function, | |
| inputs=[ | |
| dialogue_text_input, | |
| spk1_prompt_text, | |
| spk1_prompt_audio, | |
| spk1_dialect_prompt_text, | |
| spk2_prompt_text, | |
| spk2_prompt_audio, | |
| spk2_dialect_prompt_text, | |
| seed_input, | |
| ], | |
| outputs=[generate_audio], | |
| ) | |
| return page | |
| def prepare_model(): | |
| ckpt_path = snapshot_download(repo_id="Soul-AILab/SoulX-Podcast-1.7B") | |
| return ckpt_path | |
| def get_args(): | |
| parser = ArgumentParser() | |
| parser.add_argument('--model_path', | |
| type=str, | |
| default=None, | |
| help='model path') | |
| parser.add_argument('--llm_engine', | |
| type=str, | |
| default="hf", | |
| help='model execute engine') | |
| parser.add_argument('--fp16_flow', | |
| action='store_true', | |
| help='enable fp16 flow') | |
| parser.add_argument('--seed', | |
| type=int, | |
| default=1988, | |
| help='random seed for generation') | |
| args = parser.parse_args() | |
| return args | |
| if __name__ == "__main__": | |
| args = get_args() | |
| model_path = prepare_model() | |
| print(f"current_download_modelpath = {model_path}") | |
| hf_config = SoulXPodcastLLMConfig.from_initial_and_json( | |
| initial_values={"fp16_flow": args.fp16_flow}, | |
| json_file=f"{model_path}/soulxpodcast_config.json") | |
| llm_engine = args.llm_engine | |
| if llm_engine == "vllm": | |
| if not importlib.util.find_spec("vllm"): | |
| llm_engine = "hf" | |
| timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] | |
| tqdm.write(f"[{timestamp}] - [WARNING]: No install VLLM, switch to hf engine.") | |
| config = Config(model=model_path, enforce_eager=True, llm_engine=llm_engine, | |
| hf_config=hf_config) | |
| torch.manual_seed(args.seed) | |
| np.random.seed(args.seed) | |
| random.seed(args.seed) | |
| initiate_model(config) | |
| print("[INFO] SoulX-Podcast loaded") | |
| page = render_interface() | |
| page.queue() | |
| page.launch() | |