tiamojames committed
Commit 33331a1 · verified · 1 Parent(s): 2076d61

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +567 -0
  2. gitattributes +35 -0
  3. requirements.txt +14 -0
  4. soulxpodcast/__init__.py +0 -0
  5. soulxpodcast/__pycache__/__init__.cpython-311.pyc +0 -0
  6. soulxpodcast/__pycache__/__init__.cpython-312.pyc +0 -0
  7. soulxpodcast/__pycache__/config.cpython-311.pyc +0 -0
  8. soulxpodcast/__pycache__/config.cpython-312.pyc +0 -0
  9. soulxpodcast/cli/__pycache__/soulxpodcast.cpython-311.pyc +0 -0
  10. soulxpodcast/cli/__pycache__/soulxpodcast.cpython-312.pyc +0 -0
  11. soulxpodcast/cli/engine_test.py +74 -0
  12. soulxpodcast/cli/soulxpodcast.py +273 -0
  13. soulxpodcast/config.py +141 -0
  14. soulxpodcast/engine/__init__.py +0 -0
  15. soulxpodcast/engine/__pycache__/__init__.cpython-311.pyc +0 -0
  16. soulxpodcast/engine/__pycache__/__init__.cpython-312.pyc +0 -0
  17. soulxpodcast/engine/__pycache__/llm_engine.cpython-311.pyc +0 -0
  18. soulxpodcast/engine/__pycache__/llm_engine.cpython-312.pyc +0 -0
  19. soulxpodcast/engine/llm_engine.py +116 -0
  20. soulxpodcast/models/__pycache__/soulxpodcast.cpython-311.pyc +0 -0
  21. soulxpodcast/models/__pycache__/soulxpodcast.cpython-312.pyc +0 -0
  22. soulxpodcast/models/modules/__init__.py +0 -0
  23. soulxpodcast/models/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  24. soulxpodcast/models/modules/__pycache__/__init__.cpython-312.pyc +0 -0
  25. soulxpodcast/models/modules/__pycache__/flow.cpython-311.pyc +0 -0
  26. soulxpodcast/models/modules/__pycache__/flow.cpython-312.pyc +0 -0
  27. soulxpodcast/models/modules/__pycache__/hifigan.cpython-311.pyc +0 -0
  28. soulxpodcast/models/modules/__pycache__/hifigan.cpython-312.pyc +0 -0
  29. soulxpodcast/models/modules/__pycache__/sampler.cpython-311.pyc +0 -0
  30. soulxpodcast/models/modules/__pycache__/sampler.cpython-312.pyc +0 -0
  31. soulxpodcast/models/modules/flow.py +197 -0
  32. soulxpodcast/models/modules/flow_components/__init__.py +0 -0
  33. soulxpodcast/models/modules/flow_components/__pycache__/__init__.cpython-311.pyc +0 -0
  34. soulxpodcast/models/modules/flow_components/__pycache__/__init__.cpython-312.pyc +0 -0
  35. soulxpodcast/models/modules/flow_components/__pycache__/estimator.cpython-311.pyc +0 -0
  36. soulxpodcast/models/modules/flow_components/__pycache__/estimator.cpython-312.pyc +0 -0
  37. soulxpodcast/models/modules/flow_components/__pycache__/upsample_encoder.cpython-311.pyc +0 -0
  38. soulxpodcast/models/modules/flow_components/__pycache__/upsample_encoder.cpython-312.pyc +0 -0
  39. soulxpodcast/models/modules/flow_components/estimator.py +974 -0
  40. soulxpodcast/models/modules/flow_components/upsample_encoder.py +998 -0
  41. soulxpodcast/models/modules/hifigan.py +249 -0
  42. soulxpodcast/models/modules/hifigan_components/__init__.py +0 -0
  43. soulxpodcast/models/modules/hifigan_components/__pycache__/__init__.cpython-311.pyc +0 -0
  44. soulxpodcast/models/modules/hifigan_components/__pycache__/__init__.cpython-312.pyc +0 -0
  45. soulxpodcast/models/modules/hifigan_components/__pycache__/layers.cpython-311.pyc +0 -0
  46. soulxpodcast/models/modules/hifigan_components/__pycache__/layers.cpython-312.pyc +0 -0
  47. soulxpodcast/models/modules/hifigan_components/layers.py +433 -0
  48. soulxpodcast/models/modules/sampler.py +221 -0
  49. soulxpodcast/models/soulxpodcast.py +192 -0
  50. soulxpodcast/utils/__init__.py +0 -0
app.py ADDED
@@ -0,0 +1,567 @@
1
+ import re
2
+ import gradio as gr
3
+ from tqdm import tqdm
4
+ from argparse import ArgumentParser
5
+ from typing import Literal, List, Tuple
6
+ import sys
7
+ import importlib.util
8
+ from datetime import datetime
9
+ import spaces
10
+ import torch
11
+ import numpy as np # make sure numpy is imported (used for seeding)
12
+ import random # make sure random is imported (used for seeding)
13
+ import s3tokenizer
14
+
15
+ from soulxpodcast.models.soulxpodcast import SoulXPodcast
16
+ from soulxpodcast.config import Config, SoulXPodcastLLMConfig, SamplingParams
17
+ from soulxpodcast.utils.dataloader import (
18
+ PodcastInferHandler,
19
+ SPK_DICT, TEXT_START, TEXT_END, AUDIO_START, TASK_PODCAST
20
+ )
21
+
22
+ # ================================================
23
+ # Example audio paths
24
+ # ================================================
25
+ S1_PROMPT_WAV = "assets/audios/female_mandarin.wav" # example path
26
+ S2_PROMPT_WAV = "assets/audios/male_mandarin.wav" # example path
27
+
28
+
29
+ # ================================================
30
+ # Example data (gr.Examples)
31
+ # ================================================
32
+ EXAMPLES_LIST = [
33
+ # 示例 1:清空所有
34
+ [
35
+ None, "", "", None, "", "", ""
36
+ ],
37
+ # Example 2: standard Mandarin podcast
38
+ [
39
+ S1_PROMPT_WAV,
40
+ "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。",
41
+ "",
42
+ S2_PROMPT_WAV,
43
+ "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?",
44
+ "",
45
+ "[S1] 哈喽,AI时代的冲浪先锋们!欢迎收听《AI生活进行时》。啊,一个充满了未来感,然后,还有一点点,<|laughter|>神经质的播客节目,我是主持人小希。\n[S2] 哎,大家好呀!我是能唠,爱唠,天天都想唠的唠嗑!\n[S1] 最近活得特别赛博朋克哈!以前老是觉得AI是科幻片儿里的,<|sigh|> 现在,现在连我妈都用AI写广场舞文案了。\n[S2] 这个例子很生动啊。是的,特别是生成式AI哈,感觉都要炸了! 诶,那我们今天就聊聊AI是怎么走进我们的生活的哈!",
46
+ ],
47
+ # Example 3: Sichuanese podcast
48
+ [
49
+ S1_PROMPT_WAV,
50
+ "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。",
51
+ "<|Sichuan|>要得要得!前头几个耍洋盘,我后脚就背起铺盖卷去景德镇耍泥巴,巴适得喊老天爷!",
52
+ S2_PROMPT_WAV,
53
+ "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?",
54
+ "<|Sichuan|>哎哟喂,这个搞反了噻!黑神话里头唱曲子的王二浪早八百年就在黄土高坡吼秦腔喽,游戏组专门跑切录的原汤原水,听得人汗毛儿都立起来!",
55
+ "[S1] <|Sichuan|>各位《巴适得板》的听众些,大家好噻!我是你们主持人晶晶。今儿天气硬是巴适,不晓得大家是在赶路嘛,还是茶都泡起咯,准备跟我们好生摆一哈龙门阵喃?\n[S2] <|Sichuan|>晶晶好哦,大家安逸噻!我是李老倌。你刚开口就川味十足,摆龙门阵几个字一甩出来,我鼻子头都闻到茶香跟火锅香咯!\n[S1] <|Sichuan|>就是得嘛!李老倌,我前些天带个外地朋友切人民公园鹤鸣茶社坐了一哈。他硬是搞不醒豁,为啥子我们一堆人围到杯茶就可以吹一下午壳子,从隔壁子王嬢嬢娃儿耍朋友,扯到美国大选,中间还掺几盘斗地主。他说我们四川人简直是把摸鱼刻进骨子里头咯!\n[S2] <|Sichuan|>你那个朋友说得倒是有点儿趣,但他莫看到精髓噻。摆龙门阵哪是摸鱼嘛,这是我们川渝人特有的交际方式,更是一种活法。外省人天天说的松弛感,根根儿就在这龙门阵里头。今天我们就要好生摆一哈,为啥子四川人活得这么舒坦。就先从茶馆这个老窝子说起,看它咋个成了我们四川人的魂儿!",
56
+ ],
57
+ # Example 4: Cantonese podcast
58
+ [
59
+ S1_PROMPT_WAV,
60
+ "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。",
61
+ "<|Yue|>真係冇讲错啊!攀山滑雪嘅语言专家几巴闭,都唔及我听日拖成副身家去景德镇玩泥巴,呢铺真系发哂白日梦咯!",
62
+ S2_PROMPT_WAV,
63
+ "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?",
64
+ "<|Yue|>咪搞错啊!陕北民谣响度唱咗几十年,��神话边有咁大面啊?你估佢哋抄游戏咩!",
65
+ "[S1] <|Yue|>哈囉大家好啊,歡迎收聽我哋嘅節目。喂,我今日想問你樣嘢啊,你覺唔覺得,嗯,而家揸電動車,最煩,最煩嘅一樣嘢係咩啊?\n[S2] <|Yue|>梗係充電啦。大佬啊,搵個位都已經好煩,搵到個位仲要喺度等,你話快極都要半個鐘一個鐘,真係,有時諗起都覺得好冇癮。\n[S1] <|Yue|>係咪先。如果我而家同你講,充電可以快到同入油差唔多時間,你信唔信先?喂你平時喺油站入滿一缸油,要幾耐啊?五六分鐘?\n[S2] <|Yue|>差唔多啦,七八分鐘,點都走得啦。電車喎,可以做到咁快?你咪玩啦。",
66
+ ],
67
+ # Example 5: Henan podcast
68
+ [
69
+ S1_PROMPT_WAV,
70
+ "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。",
71
+ "<|Henan|>俺这不是怕恁路上不得劲儿嘛!那景德镇瓷泥可娇贵着哩,得先拿咱河南人这实诚劲儿给它揉透喽。",
72
+ S2_PROMPT_WAV,
73
+ "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?",
74
+ "<|Henan|>恁这想法真闹挺!陕北民谣比黑神话早几百年都有了,咱可不兴这弄颠倒啊,中不?恁这想法真闹挺!那陕北民谣在黄土高坡响了几百年,咋能说是跟黑神话学的咧?咱得把这事儿捋直喽,中不中!",
75
+ "[S1] <|Henan|>哎,大家好啊,欢迎收听咱这一期嘞《瞎聊呗,就这么说》,我是恁嘞老朋友,燕子。\n[S2] <|Henan|>大家好,我是老张。燕子啊,今儿瞅瞅你这个劲儿,咋着,是有啥可得劲嘞事儿想跟咱唠唠?\n[S1] <|Henan|>哎哟,老张,你咋恁懂我嘞!我跟你说啊,最近我刷手机,老是刷住些可逗嘞方言视频,特别是咱河南话,咦~我哩个乖乖,一听我都憋不住笑,咋说嘞,得劲儿哩很,跟回到家一样。\n[S2] <|Henan|>你这回可算说到根儿上了!河南话,咱往大处说说,中原官话,它真嘞是有一股劲儿搁里头。它可不光是说话,它脊梁骨后头藏嘞,是咱一整套、鲜鲜活活嘞过法儿,一种活人嘞道理。\n[S1] <|Henan|>活人嘞道理?哎,这你这一说,我嘞兴致“腾”一下就上来啦!觉住咱这嗑儿,一下儿从搞笑视频蹿到文化顶上了。那你赶紧给我白话白话,这里头到底有啥道道儿?我特别想知道——为啥一提起咱河南人,好些人脑子里“蹦”出来嘞头一个词儿,就是实在?这个实在,骨子里到底是啥嘞?",
76
+ ],
77
+ ]
78
+
79
+
80
+ # ================================================
81
+ # SoulX-Podcast Model
82
+ # ================================================
83
+ model: SoulXPodcast = None
84
+ dataset: PodcastInferHandler = None
85
+ def initiate_model(config: Config, enable_tn: bool=False):
86
+ global model
87
+ if model is None:
88
+ model = SoulXPodcast(config)
89
+
90
+ global dataset
91
+ if dataset is None:
92
+ dataset = PodcastInferHandler(model.llm.tokenizer, None, config)
93
+
94
+ # ================================================
95
+ # Gradio
96
+ # ================================================
97
+
98
+ _i18n_key2lang_dict = dict(
99
+ # Speaker1 Prompt
100
+ spk1_prompt_audio_label=dict(
101
+ en="Speaker 1 Prompt Audio",
102
+ zh="说话人 1 参考语音",
103
+ ),
104
+ spk1_prompt_text_label=dict(
105
+ en="Speaker 1 Prompt Text",
106
+ zh="说话人 1 参考文本",
107
+ ),
108
+ spk1_prompt_text_placeholder=dict(
109
+ en="text of speaker 1 Prompt audio.",
110
+ zh="说话人 1 参考文本",
111
+ ),
112
+ spk1_prompt_cot_text_label=dict(
113
+ en="Speaker 1 Prompt COT Text",
114
+ zh="说话人 1 参考推理链文本",
115
+ ),
116
+ spk1_prompt_cot_text_placeholder=dict(
117
+ en="Dialect prompt cot text with prefix: <|Sichuan|>/<|Yue|>/<|Henan|> ",
118
+ zh="带前缀方言提示词思维链文本,前缀如下:<|Sichuan|>/<|Yue|>/<|Henan|>,如:<|Sichuan|>走嘛,切吃那家新开的麻辣烫,听别个说味道硬是霸道得很,好吃到不摆了,去晚了还得排队!",
119
+ ),
120
+ # Speaker2 Prompt
121
+ spk2_prompt_audio_label=dict(
122
+ en="Speaker 2 Prompt Audio",
123
+ zh="说话人 2 参考语音",
124
+ ),
125
+ spk2_prompt_text_label=dict(
126
+ en="Speaker 2 Prompt Text",
127
+ zh="说话人 2 参考文本",
128
+ ),
129
+ spk2_prompt_text_placeholder=dict(
130
+ en="[S2] text of speaker 2 prompt audio.",
131
+ zh="[S2] 说话人 2 参考文本",
132
+ ),
133
+ spk2_prompt_cot_text_label=dict(
134
+ en="Speaker 2 Prompt COT Text",
135
+ zh="说话人 2 参考推理链文本",
136
+ ),
137
+ spk2_prompt_cot_text_placeholder=dict(
138
+ en="Dialect prompt cot text with prefix: <|Sichuan|>/<|Yue|>/<|Henan|> ",
139
+ zh="带前缀方言提示词思维链文本,前缀如下:<|Sichuan|>/<|Yue|>/<|Henan|>,如:<|Sichuan|>走嘛,切吃那家新开的麻辣烫,听别个说味道硬是霸道得很,好吃到不摆了,去晚了还得排队!",
140
+ ),
141
+ # Dialogue input textbox
142
+ dialogue_text_input_label=dict(
143
+ en="Dialogue Text Input",
144
+ zh="合成文本输入",
145
+ ),
146
+ dialogue_text_input_placeholder=dict(
147
+ en="[S1]text[S2]text[S1]text...",
148
+ zh="[S1]文本[S2]文本[S1]文本...",
149
+ ),
150
+ # Generate button
151
+ generate_btn_label=dict(
152
+ en="Generate Audio",
153
+ zh="合成",
154
+ ),
155
+ # Generated audio
156
+ generated_audio_label=dict(
157
+ en="Generated Dialogue Audio",
158
+ zh="合成的对话音频",
159
+ ),
160
+ # Warning 1: invalid text for prompt
161
+ warn_invalid_spk1_prompt_text=dict(
162
+ en='Invalid speaker 1 prompt text, should not be empty and strictly follow: "xxx"',
163
+ zh='说话人 1 参考文本不合规,不能为空,格式:"xxx"',
164
+ ),
165
+ # warn_invalid_spk1_prompt_cot_text=dict(
166
+ # en='Invalid speaker 1 prompt cot text, should not be empty and strictly follow: "[S1]xxx"',
167
+ # zh='说话人 1 参考文本不合规,格式:"[S1]xxx"',
168
+ # ),
169
+ warn_invalid_spk2_prompt_text=dict(
170
+ en='Invalid speaker 2 prompt text, should strictly follow: "[S2]xxx"',
171
+ zh='说话人 2 参考文本不合规,格式:"[S2]xxx"',
172
+ ),
173
+ # Warning 2: invalid text for dialogue input
174
+ warn_invalid_dialogue_text=dict(
175
+ en='Invalid dialogue input text, should strictly follow: "[S1]xxx[S2]xxx..."',
176
+ zh='对话文本输入不合规,格式:"[S1]xxx[S2]xxx..."',
177
+ ),
178
+ # Warning 3: incomplete prompt info
179
+ warn_incomplete_prompt=dict(
180
+ en="Please provide prompt audio and text for both speaker 1 and speaker 2",
181
+ zh="请提供说话人 1 与说话人 2 的参考语音与参考文本",
182
+ ),
183
+ )
184
+
185
+
186
+ global_lang: Literal["zh", "en"] = "zh"
187
+
188
+ def i18n(key):
189
+ # (unchanged)
190
+ global global_lang
191
+ return _i18n_key2lang_dict[key][global_lang]
192
+
193
+ def check_monologue_text(text: str, prefix: str = None) -> bool:
194
+ text = text.strip()
195
+ # Check speaker tags
196
+ if prefix is not None and (not text.startswith(prefix)):
197
+ return False
198
+ # Remove prefix
199
+ if prefix is not None:
200
+ text = text.removeprefix(prefix)
201
+ text = text.strip()
202
+ # If empty?
203
+ if len(text) == 0:
204
+ return False
205
+ return True
206
+
207
+ def check_dialect_prompt_cot_text(text: str, prefix: str = None) -> bool:
208
+ text = text.strip()
209
+ # Check COT prefix tags
210
+ if prefix is not None and (not text.startswith(prefix)):
211
+ return False
212
+ text = text.strip()
213
+ # If empty?
214
+ if len(text) == 0:
215
+ return False
216
+ return True
217
+
218
+ def check_dialogue_text(text_list: List[str]) -> bool:
219
+ if len(text_list) == 0:
220
+ return False
221
+ for text in text_list:
222
+ if not (
223
+ check_monologue_text(text, "[S1]")
224
+ or check_monologue_text(text, "[S2]")
225
+ or check_monologue_text(text, "[S3]")
226
+ or check_monologue_text(text, "[S4]")
227
+ ):
228
+ return False
229
+ return True
230
+
231
+ def process_single(target_text_list, prompt_wav_list, prompt_text_list, use_prompt_cot, prompt_cot_text_list):
232
+ spks, texts = [], []
233
+ for target_text in target_text_list:
234
+ pattern = r'(\[S[1-9]\])(.+)'
235
+ match = re.match(pattern, target_text)
236
+ text, spk = match.group(2), int(match.group(1)[2])-1
237
+ spks.append(spk)
238
+ texts.append(text)
239
+
240
+ global dataset
241
+ dataitem = {"key": "001", "prompt_text": prompt_text_list, "prompt_wav": prompt_wav_list,
242
+ "text": texts, "spk": spks, }
243
+ if use_prompt_cot:
244
+ dataitem.update({
245
+ "prompt_cot_text": prompt_cot_text_list
246
+ })
247
+ dataset.update_datasource(
248
+ [
249
+ dataitem
250
+ ]
251
+ )
252
+
253
+ # exactly one data item is expected;
254
+ data = dataset[0]
255
+ prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(data["log_mel"]) # [B, num_mels=128, T]
256
+ spk_emb_for_flow = torch.tensor(data["spk_emb"])
257
+ prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(data["mel"], batch_first=True, padding_value=0) # [B, T', num_mels=80]
258
+ prompt_mels_lens_for_flow = torch.tensor(data['mel_len'])
259
+ text_tokens_for_llm = data["text_tokens"]
260
+ prompt_text_tokens_for_llm = data["prompt_text_tokens"]
261
+ spk_ids = data["spks_list"]
262
+ sampling_params = SamplingParams(use_ras=True,win_size=25,tau_r=0.2)
263
+ infos = [data["info"]]
264
+ processed_data = {
265
+ "prompt_mels_for_llm": prompt_mels_for_llm,
266
+ "prompt_mels_lens_for_llm": prompt_mels_lens_for_llm,
267
+ "prompt_text_tokens_for_llm": prompt_text_tokens_for_llm,
268
+ "text_tokens_for_llm": text_tokens_for_llm,
269
+ "prompt_mels_for_flow_ori": prompt_mels_for_flow,
270
+ "prompt_mels_lens_for_flow": prompt_mels_lens_for_flow,
271
+ "spk_emb_for_flow": spk_emb_for_flow,
272
+ "sampling_params": sampling_params,
273
+ "spk_ids": spk_ids,
274
+ "infos": infos,
275
+ "use_prompt_cot": use_prompt_cot,
276
+ }
277
+ if use_prompt_cot:
278
+ processed_data.update({
279
+ "prompt_cot_text_tokens_for_llm": data["prompt_cot_text_tokens"],
280
+ "prompt_cot_prefix": data["prompt_cot_prefix"],
281
+ })
282
+ return processed_data
283
+
284
+ @spaces.GPU
285
+ def dialogue_synthesis_function(
286
+ target_text: str,
287
+ spk1_prompt_text: str | None = "",
288
+ spk1_prompt_audio: str | None = None,
289
+ spk1_prompt_cot_text: str | None = "",
290
+ spk2_prompt_text: str | None = "",
291
+ spk2_prompt_audio: str | None = None,
292
+ spk2_prompt_cot_text: str | None = "",
293
+ seed: int = 1988, # <-- seed parameter kept
294
+ ):
295
+ # ================== Set random seeds ==================
296
+ seed = int(seed)
297
+ torch.manual_seed(seed)
298
+ np.random.seed(seed)
299
+ random.seed(seed)
300
+ # ================================================
301
+
302
+ # Check prompt info
303
+ target_text_list: List[str] = re.findall(r"(\[S[0-9]\][^\[\]]*)", target_text)
304
+ target_text_list = [text.strip() for text in target_text_list]
305
+ if not check_dialogue_text(target_text_list):
306
+ gr.Warning(message=i18n("warn_invalid_dialogue_text"))
307
+ return None
308
+
309
+ # Go synthesis
310
+ progress_bar = gr.Progress(track_tqdm=True)
311
+ prompt_wav_list = [spk1_prompt_audio, spk2_prompt_audio]
312
+ prompt_text_list = [spk1_prompt_text, spk2_prompt_text]
313
+ use_prompt_cot = spk1_prompt_cot_text.strip()!="" or spk2_prompt_cot_text.strip()!=""
314
+ prompt_cot_text_list = [spk1_prompt_cot_text, spk2_prompt_cot_text]
315
+ data = process_single(
316
+ target_text_list,
317
+ prompt_wav_list,
318
+ prompt_text_list,
319
+ use_prompt_cot,
320
+ prompt_cot_text_list,
321
+ )
322
+
323
+ results_dict = model.forward_longform(
324
+ **data
325
+ )
326
+ target_audio = None
327
+ for i in range(len(results_dict['generated_wavs'])):
328
+ if target_audio is None:
329
+ target_audio = results_dict['generated_wavs'][i]
330
+ else:
331
+ target_audio = torch.concat([target_audio, results_dict['generated_wavs'][i]], axis=1)
332
+ return (24000, target_audio.cpu().squeeze(0).numpy())
333
+
334
+
335
+ def render_interface() -> gr.Blocks:
336
+ with gr.Blocks(title="SoulX-Podcast", theme=gr.themes.Default()) as page:
337
+ # ======================== UI ========================
338
+ with gr.Row():
339
+ lang_choice = gr.Radio(
340
+ choices=["中文", "English"],
341
+ value="中文",
342
+ label="Display Language/显示语言",
343
+ type="index",
344
+ interactive=True,
345
+ scale=3,
346
+ )
347
+ seed_input = gr.Number(
348
+ label="Seed (种子)",
349
+ value=1988,
350
+ step=1,
351
+ interactive=True,
352
+ scale=1,
353
+ )
354
+
355
+ with gr.Row():
356
+ # ==== Speaker1 Prompt ====
357
+ with gr.Column(scale=1):
358
+ with gr.Group(visible=True) as spk1_prompt_group:
359
+ spk1_prompt_audio = gr.Audio(
360
+ label=i18n("spk1_prompt_audio_label"),
361
+ type="filepath",
362
+ editable=False,
363
+ interactive=True,
364
+ )
365
+ spk1_prompt_text = gr.Textbox(
366
+ label=i18n("spk1_prompt_text_label"),
367
+ placeholder=i18n("spk1_prompt_text_placeholder"),
368
+ lines=3,
369
+ )
370
+ spk1_prompt_cot_text = gr.Textbox(
371
+ label=i18n("spk1_prompt_cot_text_label"),
372
+ placeholder=i18n("spk1_prompt_cot_text_placeholder"),
373
+ value="",
374
+ lines=3,
375
+ )
376
+ # ==== Speaker2 Prompt ====
377
+ with gr.Column(scale=1, visible=True):
378
+ with gr.Group(visible=True) as spk2_prompt_group:
379
+ spk2_prompt_audio = gr.Audio(
380
+ label=i18n("spk2_prompt_audio_label"),
381
+ type="filepath",
382
+ editable=False,
383
+ interactive=True,
384
+ )
385
+ spk2_prompt_text = gr.Textbox(
386
+ label=i18n("spk2_prompt_text_label"),
387
+ placeholder=i18n("spk2_prompt_text_placeholder"),
388
+ lines=3,
389
+ )
390
+ spk2_prompt_cot_text = gr.Textbox(
391
+ label=i18n("spk2_prompt_cot_text_label"),
392
+ placeholder=i18n("spk2_prompt_cot_text_placeholder"),
393
+ value="",
394
+ lines=3,
395
+ )
396
+ # ==== Text input ====
397
+ with gr.Column(scale=2):
398
+ with gr.Row():
399
+ dialogue_text_input = gr.Textbox(
400
+ label=i18n("dialogue_text_input_label"),
401
+ placeholder=i18n("dialogue_text_input_placeholder"),
402
+ lines=18,
403
+ )
404
+
405
+ # Generate button
406
+ with gr.Row():
407
+ generate_btn = gr.Button(
408
+ value=i18n("generate_btn_label"),
409
+ variant="primary",
410
+ scale=3,
411
+ size="lg",
412
+ )
413
+
414
+ # Long output audio
415
+ generate_audio = gr.Audio(
416
+ label=i18n("generated_audio_label"),
417
+ interactive=False,
418
+ )
419
+
420
+ with gr.Row():
421
+ inputs_for_examples = [
422
+ spk1_prompt_audio,
423
+ spk1_prompt_text,
424
+ spk1_prompt_cot_text,
425
+ spk2_prompt_audio,
426
+ spk2_prompt_text,
427
+ spk2_prompt_cot_text,
428
+ dialogue_text_input,
429
+ ]
430
+
431
+ example_headers = [
432
+ "S1 音频", "S1 文本", "S1 COT",
433
+ "S2 音频", "S2 文本", "S2 COT",
434
+ "对话内容"
435
+ ]
436
+
437
+ gr.Examples(
438
+ examples=EXAMPLES_LIST,
439
+ inputs=inputs_for_examples,
440
+ label="播客模板示例 (点击加载)",
441
+ examples_per_page=5,
442
+ )
443
+
444
+ # ======================== Action ========================
445
+ def _change_component_language(lang):
446
+ global global_lang
447
+ global_lang = ["zh", "en"][lang]
448
+ return [
449
+
450
+ # spk1_prompt_{audio,text,prompt_cot_text}
451
+ gr.update(label=i18n("spk1_prompt_audio_label")),
452
+ gr.update(
453
+ label=i18n("spk1_prompt_text_label"),
454
+ placeholder=i18n("spk1_prompt_text_placeholder"),
455
+ ),
456
+ gr.update(
457
+ label=i18n("spk1_prompt_cot_text_label"),
458
+ placeholder=i18n("spk1_prompt_cot_text_placeholder"),
459
+ ),
460
+ # spk2_prompt_{audio,text}
461
+ gr.update(label=i18n("spk2_prompt_audio_label")),
462
+ gr.update(
463
+ label=i18n("spk2_prompt_text_label"),
464
+ placeholder=i18n("spk2_prompt_text_placeholder"),
465
+ ),
466
+ gr.update(
467
+ label=i18n("spk2_prompt_cot_text_label"),
468
+ placeholder=i18n("spk2_prompt_cot_text_placeholder"),
469
+ ),
470
+ # dialogue_text_input
471
+ gr.update(
472
+ label=i18n("dialogue_text_input_label"),
473
+ placeholder=i18n("dialogue_text_input_placeholder"),
474
+ ),
475
+ # generate_btn
476
+ gr.update(value=i18n("generate_btn_label")),
477
+ # generate_audio
478
+ gr.update(label=i18n("generated_audio_label")),
479
+ ]
480
+
481
+ lang_choice.change(
482
+ fn=_change_component_language,
483
+ inputs=[lang_choice],
484
+ outputs=[
485
+ spk1_prompt_audio,
486
+ spk1_prompt_text,
487
+ spk1_prompt_cot_text,
488
+ spk2_prompt_audio,
489
+ spk2_prompt_text,
490
+ spk2_prompt_cot_text,
491
+ dialogue_text_input,
492
+ generate_btn,
493
+ generate_audio,
494
+ ],
495
+ )
496
+
497
+ # Generate button click Action
498
+ generate_btn.click(
499
+ fn=dialogue_synthesis_function,
500
+ inputs=[
501
+ dialogue_text_input,
502
+ spk1_prompt_text,
503
+ spk1_prompt_audio,
504
+ spk1_prompt_cot_text,
505
+ spk2_prompt_text,
506
+ spk2_prompt_audio,
507
+ spk2_prompt_cot_text,
508
+ seed_input,
509
+ ],
510
+ outputs=[generate_audio],
511
+ )
512
+ return page
513
+
514
+
515
+ # ================================================
516
+ # Options
517
+ # ================================================
518
+ def get_args():
519
+ parser = ArgumentParser()
520
+ parser.add_argument('--model_path',
521
+ required=True,
522
+ type=str,
523
+ help='model path')
524
+ parser.add_argument('--llm_engine',
525
+ type=str,
526
+ default="hf",
527
+ help='model execute engine')
528
+ parser.add_argument('--fp16_flow',
529
+ action='store_true',
530
+ help='enable fp16 flow')
531
+ parser.add_argument('--seed',
532
+ type=int,
533
+ default=1988,
534
+ help='random seed for generation')
535
+ args = parser.parse_args()
536
+ return args
537
+
538
+
539
+ if __name__ == "__main__":
540
+ args = get_args()
541
+
542
+ # Initiate model
543
+ hf_config = SoulXPodcastLLMConfig.from_initial_and_json(
544
+ initial_values={"fp16_flow": args.fp16_flow},
545
+ json_file=f"{args.model_path}/soulxpodcast_config.json")
546
+
547
+ # Fall back to the HF engine when vLLM is not installed
548
+ llm_engine = args.llm_engine
549
+ if llm_engine == "vllm":
550
+ if not importlib.util.find_spec("vllm"):
551
+ llm_engine = "hf"
552
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
553
+ tqdm.write(f"[{timestamp}] - [WARNING]: No install VLLM, switch to hf engine.")
554
+ config = Config(model=args.model_path, enforce_eager=True, llm_engine=llm_engine,
555
+ hf_config=hf_config)
556
+
557
+ torch.manual_seed(args.seed)
558
+ np.random.seed(args.seed)
559
+ random.seed(args.seed)
560
+
561
+ initiate_model(config)
562
+ print("[INFO] SoulX-Podcast loaded")
563
+ # UI
564
+ page = render_interface()
565
+ page.queue()
566
+ page.launch()
567
+ # page.launch(share=False)
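Note: the dialogue textbox expects turns tagged [S1]/[S2]/... as in the examples above. A minimal sketch of the turn-splitting logic, using only the two regexes already present in dialogue_synthesis_function and process_single (nothing beyond them is assumed):

    import re

    dialogue = "[S1] 哈喽,欢迎收听。\n[S2] 大家好呀!"
    # Split the dialogue into per-speaker turns, as dialogue_synthesis_function does.
    turns = [t.strip() for t in re.findall(r"(\[S[0-9]\][^\[\]]*)", dialogue)]
    for turn in turns:
        # process_single then strips the tag and derives a 0-based speaker index.
        m = re.match(r'(\[S[1-9]\])(.+)', turn)
        spk, text = int(m.group(1)[2]) - 1, m.group(2).strip()
        print(spk, text)  # -> 0 哈喽,欢迎收听。 / 1 大家好呀!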
gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ librosa
2
+ numpy
3
+ scipy
4
+ s3tokenizer
5
+ diffusers
6
+ torch==2.7.1
7
+ torchaudio==2.7.1
8
+ triton>=3.0.0
9
+ transformers==4.57.1
10
+ accelerate==1.10.1
11
+ onnxruntime
12
+ onnxruntime-gpu
13
+ einops
14
+ gradio
soulxpodcast/__init__.py ADDED
File without changes
soulxpodcast/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (177 Bytes).
 
soulxpodcast/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (165 Bytes).
 
soulxpodcast/__pycache__/config.cpython-311.pyc ADDED
Binary file (9.22 kB).
 
soulxpodcast/__pycache__/config.cpython-312.pyc ADDED
Binary file (7.47 kB).
 
soulxpodcast/cli/__pycache__/soulxpodcast.cpython-311.pyc ADDED
Binary file (15.2 kB).
 
soulxpodcast/cli/__pycache__/soulxpodcast.cpython-312.pyc ADDED
Binary file (13.6 kB).
 
soulxpodcast/cli/engine_test.py ADDED
@@ -0,0 +1,74 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import random
5
+ import sys
6
+ from glob import glob
7
+ from copy import deepcopy
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+ from dataclasses import fields, asdict
12
+
13
+ from vllm import LLM
14
+ from vllm.inputs import TokensPrompt as TokensPrompt
15
+ from vllm import SamplingParams
16
+
17
+ def set_all_random_seed(seed):
18
+ import random
19
+ import numpy as np
20
+ import os
21
+ random.seed(seed)
22
+ os.environ["PYTHONHASHSEED"] = str(seed)
23
+ np.random.seed(seed)
24
+ torch.manual_seed(seed)
25
+ torch.cuda.manual_seed_all(seed)
26
+ torch.backends.cudnn.deterministic = True
27
+ torch.backends.cudnn.benchmark = False
28
+
29
+ def get_args():
30
+ parser = argparse.ArgumentParser(description='FlashCosyVoice')
31
+ parser.add_argument('--model_path',
32
+ required=True,
33
+ type=str,
34
+ help='model path')
35
+ parser.add_argument('--seed',
36
+ type=int,
37
+ default=1986,
38
+ help='random seed for generation')
39
+ args = parser.parse_args()
40
+ return args
41
+
42
+ def main():
43
+ args = get_args()
44
+
45
+ os.environ["VLLM_USE_V1"] = "0"
46
+ sampling_params = SamplingParams(temperature=0.6, repetition_penalty=1.25, top_k=100, top_p=0.9, max_tokens=3000, stop_token_ids=[153478], use_ras=True, win_size=25, tau_r=0.2)
47
+
48
+ llm = LLM(model=args.model_path, enforce_eager=True, dtype="bfloat16", )
49
+ input = [152477, 151674, 151712, 152596, 120432, 102115, 101417, 121407, 3837, 113562, 108896, 103554, 34187, 99744, 104283, 3837, 104074, 18830, 101417, 63109, 29826, 103924, 11319, 151713, 153477, 157342, 157987, 158071, 157264, 153912, 154998, 159371, 159391, 159562, 158917, 158870, 159680, 159672, 157485, 155058, 155207, 153630, 153846, 158058, 153814, 153841, 158204, 154543, 158928, 154055, 155761, 155026, 155298, 155061, 155884, 159856, 155815, 158542, 156292, 154332, 155055, 153606, 159474, 157971, 158055, 158607, 158193, 155035, 155038, 159420, 159643, 157699, 155778, 154224, 158652, 158461, 158551, 156465, 157728, 155541, 155307, 155297, 157894, 157894, 157813, 157570, 159511, 159616, 158914, 158833, 158935, 155045, 155770, 157962, 159491, 157384, 157398, 157077, 158859, 158693, 159419, 159565, 158796, 157484, 157248, 158029, 158348, 158271, 156337, 154890, 154638, 155268, 159637, 158827, 159646, 158179, 158159, 155244, 155280, 155283, 154488, 156999, 156279, 157007, 157891, 160081, 160000, 153478, 151675, 151712, 152596, 18830, 101417, 106909, 56568, 104138, 36587, 3837, 100132, 107133, 117320, 3837, 105146, 53153, 102155, 117927, 1773, 151713, 153477, 155303, 157049, 159236, 159344, 158936, 158693, 159425, 159424, 158609, 158266, 159972, 159969, 160052, 160055, 159971, 158375, 159670, 159427, 159428, 155783, 159321, 160053, 158284, 158132, 156714, 157470, 155283, 155286, 155167, 157324, 159511, 154330, 155635, 159613, 158481, 156348, 154160, 155572, 158516, 158326, 154645, 155619, 155457, 157343, 158882, 158155, 154777, 154249, 158695, 158675, 155055, 154332, 155079, 159004, 158510, 158591, 158593, 157151, 158697, 158688, 158595, 158577, 158576, 157814, 157894, 157894, 157813, 157813, 157813, 157813, 157732, 157240, 157483, 155296, 155626, 155761, 153659, 154194, 155052, 157968, 159421, 158032, 159320, 159247, 159015, 157789, 159974, 159977, 159986, 154641, 154392, 156328, 156094, 154413, 157728, 159509, 153733, 155680, 155860, 154889, 155618, 155609, 156581, 158191, 156949, 157482, 158691, 158694, 160100, 159125, 159772, 155627, 155626, 157489, 159514, 158785, 156598, 154411, 155626, 160054, 160130, 160051, 160066, 159337, 158365, 158194, 154783, 154867, 158372, 158615, 158616, 155943, 157263, 157587, 155425, 155380, 157321, 155080, 155734, 155029, 158690, 154426, 157259, 155847, 158961, 159988, 157069, 156513, 155541, 154569, 155316, 155403, 157427, 158238, 158588, 158668, 155935, 159970, 160050, 159725, 159778, 159967, 155369, 155370, 156828, 158446, 155005, 154246, 156433, 158539, 156455, 154308, 153738, 153990, 153909, 154152, 156429, 158426, 158452, 158659, 157085, 157813, 157813, 157813, 153478, 151674, 151712, 152596, 50930, 100014, 36587, 112491, 54039, 100391, 107512, 113524, 9370, 115639, 100930, 34718, 104038, 26940, 110602, 17447, 99469, 28029, 25067, 69249, 109624, 116542, 3837, 108243, 82166, 99172, 58364, 52129, 104207, 9370, 22697, 104097, 99882, 99615, 70074, 3837, 120432, 18830, 101417, 110721, 11319, 151713, 153477, 154327, 153757, 159436, 157245, 153603, 156051, 158027, 158273, 154133, 153918, 157908, 159372, 158681, 158163, 157428, 159572, 159886, 155049, 154305, 158538, 153973, 153595, 153676, 159430, 155060, 154575, 156006, 160129, 160138, 158666, 160017, 155156, 155437, 157459, 154713, 154962, 157239, 156016, 158272, 156688, 158384, 158541, 154218, 156186, 159262, 158614, 158535, 155203, 159574, 157316, 159669, 159657, 155284, 157300, 157456, 157159, 154906, 155781, 154720, 153652, 157183, 159422, 155206, 158190, 158082, 157996, 159239, 158513, 156169, 
154314, 155727, 159124, 155689, 158078, 157247, 155304, 153900, 154390, 159975, 159726, 159968, 158257, 158203, 155512, 158056, 158479, 158085, 156519, 154329, 154364, 155059, 159433, 159433, 158704, 155140, 155707, 153649, 156516, 155057, 154656, 154079, 155191, 155775, 157239, 156669, 154951, 158517, 158030, 158245, 158299, 154123, 154728, 154857, 157830, 159853, 158885, 158425, 158072, 157749, 155157, 159535, 159619, 158836, 156597, 155540, 155550, 154101, 156277, 156349, 156124, 158349, 153651, 153898, 156835, 156445, 159397, 157232, 155128, 157158, 156023, 159500, 155871, 154557, 157319, 159577, 158938, 158158, 158629, 159287, 157833, 157860, 157887, 159426, 156355, 158623, 154287, 155544, 157727, 159512, 158038, 158514, 158533, 155790, 154812, 154830, 155542, 155056, 157246, 157243, 155053, 155620, 154813, 154000, 158455, 158621, 158678, 155051, 154177, 160154, 158689, 153976, 158366, 156021, 154233, 154161, 158002, 158270, 153890, 158781, 158639, 158642, 160132, 160133, 160144, 157967, 155694, 157644, 156024, 158135, 155784, 155078, 155491, 155599, 154326, 155543, 154821, 156774, 159533, 158480, 158417, 158054, 157258, 155787, 153792, 159538, 158890, 158809, 158809, 158107, 156577, 155133, 154404, 155446, 158515, 157076, 158453, 159040, 158392, 157669, 159245, 158524, 153708, 155546, 154899, 158660, 160096, 158627, 159431, 157490, 154197, 156096, 158887, 156733, 154434, 154119, 160013, 159043, 155573, 155545, 157327, 159433, 156517, 155383, 155401, 155072, 153639, 156294, 156888, 158664, 158663, 159562, 155920, 157242, 157726, 157241, 154517, 155356, 157708, 160151, 160153, 159418, 159101, 159263, 158534, 158364, 158004, 156524, 157567, 157894, 157894, 153478, 151675, 151712, 152596, 110931, 45629, 101454, 17447, 40814, 99164, 100753, 41683, 9370, 2073, 119176, 103162, 102781, 854, 100074, 99615, 47872, 6313, 115639, 112491, 99615, 81668, 104462, 99319, 100307, 102561, 99964, 101182, 99319, 108375, 100074, 3837, 99557, 99601, 104631, 108439, 100372, 100369, 99375, 99261, 6313, 151713, 153477, 155328, 159429, 159432, 153612, 157265, 155322, 155653, 159246, 156346, 153981, 155055, 155539, 155541, 154270, 155677, 160105, 160075, 157207, 160139, 158897, 158267, 158518, 158052, 156756, 154092, 155559, 155318, 155554, 155299, 155623, 155302, 159433, 159541, 157354, 155411, 154843, 158351, 158362, 159343, 156105, 154397, 158521, 154965, 154221, 156089, 157840, 159325, 159319, 160067, 160070, 159340, 159094, 159244, 158428, 156352, 158458, 159032, 158134, 158000, 157261, 155158, 153668, 153597, 153867, 154110, 154838, 155569, 156997, 157000, 154898, 155636, 157941, 155672, 155673, 155754, 160127, 160126, 158692, 158858, 156319, 157048, 157075, 154242, 154359, 153653, 155077, 155071, 158700, 159104, 158374, 158383, 158232, 157017, 157503, 157506, 155319, 155399, 155644, 155545, 155053, 157243, 159457, 157270, 155083, 157786, 159243, 155782, 158941, 159194, 159752, 153849, 155562, 155643, 155722, 154991, 159915, 155058, 154440, 156501, 158209, 155518, 153661, 158696, 158200, 158861, 157566, 155622, 155706, 155733, 155760, 155624, 154814, 157810, 157813, 157813, 157813, 157813, 157246, 159508, 159454, 157267, 155326, 158239, 158590, 159322, 159403, 158599, 158680, 156482, 155646, 157509, 155482, 154135, 159510, 159435, 153954, 154070, 153598, 155863, 159670, 157250, 154332, 154143, 158454, 157861, 160135, 156503, 158600, 159935, 159773, 159693, 159774, 157909, 159267, 155055, 154602, 154062, 158207, 156364, 156436, 156481, 156510, 154839, 157394, 157557, 159023, 159103, 156036, 155640, 155400, 
155321, 155563, 155626, 155545, 159433, 159460, 158731, 154357, 155464, 155515, 157481, 159264, 153999, 153711, 153747, 158561, 159290, 156310, 158210, 158617, 158543, 158679, 154309, 155758, 154323, 153975, 159581, 158214, 159671, 154578, 155565, 155645, 154168, 154264, 160155, 160074, 160073, 159423, 155047, 155707, 155269, 157157, 156347, 153657, 154356, 153648, 159209, 157021, 158506, 158587, 158416, 155965, 159532, 157346, 159983, 155809, 158212, 159193, 159753, 157008, 154587, 155561, 155551, 157729, 157813, 157813, 157813, 157813, 157489, 159511, 159457, 156541, 155461, 154408, 153806, 153744, 156590, 159046, 158254, 158108, 158354, 158353, 156177, 155061, 153626, 159448, 159127, 159976, 158079, 158706, 155844, 154624, 159240, 156511, 159913, 154815, 154818, 156492, 158686, 158275, 159077, 160049, 159256, 157800, 157798, 158596, 159157, 159346, 158641, 155780, 155697, 158759, 158110, 154159, 154437, 155547, 154089, 156384, 158687, 156500, 157084, 159184, 159187, 154817, 155726, 157940, 157939, 158684, 158131, 158057, 153644, 157507, 155564, 158434, 157408, 156676, 158156, 155895, 155126, 157395, 153671, 156420, 155868, 156368, 160019, 153850, 156171, 158685, 157886, 156509, 159132, 155274, 155277, 155196, 155168, 155141, 155138, 155626, 153478, 151674, 151712, 152596, 102177, 73670, 99877, 2073, 102274, 108243, 854, 99257, 100896, 3837, 100345, 26940, 107513, 99815, 85361, 23656, 25067, 105636, 3837, 58364, 52129, 2073, 100917, 108471, 101187, 99243, 100859, 854, 2073, 106957, 99476, 23305, 44729, 854, 100001, 100175, 107600, 9370, 119305, 100399, 70074, 6313, 151713, 153477, 157543, 155855, 158784, 158780, 158703, 155796, 156531, 155402, 153838, 156187, 156287, 158109, 158373, 156321, 154403, 157057, 156088, 160152, 159421, 155290, 157309, 157315, 157401, 158940, 156753, 159485, 159643, 158920, 154474, 154650, 154910, 159770, 158318, 158507, 154134, 153882, 153618, 159449, 156940, 154084, 156106, 159196, 159350, 158624, 154305, 154322, 154978, 154267, 160047, 159398, 155050, 154897, 154291, 155043, 154917, 159854, 156859, 158278, 157236, 154116, 154842, 160048, 157942, 157731, 159186, 158555, 159286, 157822, 160081, 157894, 157894, 157894, 157894, 157813, 157813, 157813, 154570, 153802, 153624, 158075, 160099, 159091, 159911, 159912, 156996, 153674, 155410, 154141, 153903, 154278, 157185, 159121, 158884, 157066, 158281, 154971, 153669, 159204, 159367, 159850, 156449, 154086, 154833, 154188, 156365, 159352, 158633, 159833, 159832, 159589, 157447, 154451, 157238, 157965, 153835, 154870, 158274, 154888, 155376, 155605, 156817, 153627, 159513, 158868, 156141, 155331, 155384, 156760, 159433, 159433, 159433, 159433, 155137, 157837, 160102, 160129, 157954, 155700, 154968, 159588, 159183, 158189, 158783, 155799, 153864, 156068, 158188, 159163, 156967, 158192, 157976, 156536, 155320, 159253, 154647, 153873, 153603, 158229, 156320, 157039, 158444, 158860, 158546, 157104, 155725, 154298, 159593, 156114, 153819, 154384, 157405, 159437, 159995, 154104, 155724, 155716, 155755, 154646, 154863, 154374, 157746, 159045, 158291, 159650, 157444, 159181, 158202, 153600, 155117, 157313, 157393, 155811, 159284, 160016, 159804, 159910, 158197, 158137, 155795, 157262, 155347, 159980, 157556, 156585, 153663, 155114, 156518, 158704, 155788, 155221, 155113, 156739, 153789, 155852, 159000, 154132, 156087, 158081, 155194, 154621, 156025, 154081, 155613, 154137, 158186, 158996, 159220, 158286, 153894, 153654, 158026, 158597, 156184, 158619, 158651, 159409, 155164, 159643, 156703, 155210, 157314, 157977, 156339, 
154862, 154861, 154727, 155568, 155574, 155007, 158688, 156280, 158536, 158581, 158402, 156651, 159643, 154471, 154677, 156288, 159044, 155555, 157894, 157894, 157813, 153478, 151675, 151712, 152596, 99936, 30534, 104609, 106594, 99882, 99615, 70074, 9370, 101485, 99499, 57566, 6313, 102245, 101325, 99164, 45861, 102504, 9370, 2073, 100268, 93, 102634, 44729, 93, 121407, 93, 119921, 93, 99800, 101953, 93, 33590, 99601, 49187, 36407, 100132, 104666, 101941, 6313, 151713, 153477, 154258, 156489, 155054, 154331, 154349, 159775, 157831, 158516, 156148, 158443, 158165, 155817, 153636, 155074, 155419, 155329, 159433, 156517, 154816, 159235, 156015, 154896, 154230, 154948, 158515, 157222, 154275, 155540, 155567, 159914, 155971, 158515, 158608, 160071, 157884, 157155, 154320, 155039, 157807, 156754, 155323, 157030, 158347, 156504, 154296, 157914, 157590, 157617, 157724, 159668, 158198, 158162, 158001, 156533, 159453, 157266, 155105, 155330, 157246, 155086, 154870, 158111, 156427, 155976, 157001, 154098, 154206, 158669, 159370, 157906, 159266, 157244, 153927, 153675, 158753, 159969, 157060, 153660, 155315, 159776, 154633, 158025, 157998, 156054, 156027, 153840, 154083, 154595, 155299, 157240, 154412, 154826, 157642, 157480, 159664, 158206, 155940, 155180, 155103, 155102, 155183, 155200, 159665, 157725, 155295, 155441, 155479, 155477, 155898, 158445, 158427, 158319, 159047, 157823, 157813, 157813, 157813, 157813, 157813, 157813, 157813, 157813, 155398, 159445, 157970, 153623, 158512, 156342, 158670, 158643, 160101, 159369, 155011, 157078, 159751, 157591, 155407, 154627, 156133, 158542, 154178, 154302, 153982, 158269, 158682, 158321, 159973, 158511, 158698, 159679, 159103, 158311, 159695, 155483, 158516, 155869, 156526, 157494, 154821, 154911, 155314, 155838, 158322, 158241, 158223, 158303, 158222, 155324, 155570, 155300, 155545, 155296, 157483, 157483, 155539, 155707, 157912, 157666, 156452, 158650, 158647, 158648, 158568, 158571, 156357, 154170, 154179, 154845, 154844, 155654, 155545, 155626, 157813, 157486, 159427, 157246, 155383, 154843, 158271, 156175, 158696, 155979, 156600, 153635, 159451, 155428, 159982, 159985, 159744, 159096, 158043, 158034, 158115, 158087, 158087, 158978, 157495, 157733, 157813, 157813, 157813, 157813, 157813, 157813, 155302, 159430, 159430, 157243, 155059, 155707, 153892, 158195, 159653, 159654, 157467, 157476, 157397, 159582, 154398, 158139, 158166, 158112, 155925, 156168, 156249, 154071, 154719, 155439, 155358, 155087, 155383, 157813, 157732, 157732, 157813, 157813, 157813, 157813, 157243, 159430, 156514, 155377, 157894, 154297, 158618, 159104, 158943, 158457, 156270, 156351, 154167, 154893, 155595, 155325, 155570, 155545, 155626, 157813, 157813, 157813, 157813, 157813, 157813, 157813, 155626, 159973, 159248, 158295, 159269, 158544, 156393, 158661, 158604, 158995, 160131, 156340, 158294, 159024, 159105, 158862, 159618, 159621, 159645, 159630, 159629, 159548, 157579, 159766, 157498, 157741, 157813, 157813, 157813, 157813, 157894, 155707, 157786, 158997, 157969, 158950, 159590, 157329, 153684, 154035, 155862, 155434, 153901, 156837, 156273, 156354, 156381, 158649, 158676, 158657, 158656, 159388, 160010, 157742, 157732, 157732, 157732, 157813, 157813, 157813, 157813, 157813, 157726, 159670, 159697, 157510, 157510, 157483, 157483, 157726, 157483, 157483, 157726, 157483, 157483, 157486, 157732, 160000, 160000, 160000, 160000, 160000, 159919, 159919, 159919, 159919, 159919, 159919, 157732, 157732, 159217, 158243, 158489, 153889, 154386, 154353, 156529, 157588, 156097, 158051, 
157974, 153752, 155365, 157975, 156112, 153729, 155076, 157752, 157881, 156099, 158133, 155865, 153681, 153735, 154464, 155084, 155327, 155056, 157243, 157240, 157240, 157240, 157813, 160000, 160000, 157813, 157813, 157813, 157813, 157813, 157894, 155593, 158105, 158592, 156426, 154239, 156507, 157233, 158725, 158671, 158509, 157799, 157071, 155856, 153750, 153993, 153741, 153984, 154389, 155351, 155573, 155545, 155545, 155644, 157426, 156695, 158371, 158560, 158264, 157508, 157589, 155509, 156085, 156103, 155934, 154410, 155922, 158265, 159237, 157805, 157806, 156834, 156807, 158994, 159830, 159829, 158939, 156750, 153825, 153816, 153705, 154461, 155112, 155354, 155327, 155545, 157813, 157813, 153478, 151674, 151712, 152596, 110602, 17447, 99469, 99354, 75882, 29635, 18947, 2073, 70074, 85254, 101904, 23836, 33590, 99258, 104090, 111077, 18800, 101044, 106063, 18397, 115639, 119754, 46553, 3837, 45181, 99391, 22697, 9370, 101780, 17523, 99882, 99615, 104188, 99530, 22697, 9370, 120728, 121909, 99293, 115807, 101182, 6313, 151713, 153477, 156815, 158031, 156031, 154620, 154128, 159365, 157423, 158399, 158173, 158960, 157260, 159535, 156730, 157323, 155541, 154824, 156290, 156268, 156367, 158268, 153732, 159182, 158447, 159131, 159591, 157404, 155217, 154542, 155488, 153760, 155139, 155061, 158409, 158201, 158914, 158836, 155993, 156081, 154883, 154126, 156414, 153678, 156542, 159643, 158839, 153743, 154191, 155558, 159534, 157777, 159010, 159345, 158096, 159933, 155481, 155353, 158380, 156283, 159316, 158668, 158606, 158220, 155061, 154341, 155159, 156301, 159164, 159622, 155176, 155538, 154566, 158211, 155115, 155627, 158947, 159676, 159433, 159433, 159433, 155222, 155392, 155294, 155136, 157002, 155802, 156374, 157156, 158932, 158683, 158613, 154903, 153919, 156034, 158529, 156324, 156297, 155919, 155275, 153895, 159516, 157491, 154556, 155530, 155046, 155052, 157882, 157951, 158187, 158435, 158138, 159753, 157809, 159834, 158891, 158806, 158915, 155902, 156162, 156405, 155767, 156009, 158355, 156411, 154080, 155500, 159562, 157915, 154174, 155306, 155058, 156123, 159155, 159884, 153828, 153948, 156392, 158620, 158448, 156267, 158039, 158672, 158356, 156498, 155025, 159368, 155443, 158024, 159515, 156762, 153823, 155524, 158616, 156060, 153621, 155617, 156823, 154881, 153972, 158798, 159041, 155582, 155626, 157894, 157894, 157813, 159757, 159430, 159430, 159430, 159430, 155383, 157759, 158350, 156358, 156483, 157101, 155480, 157795, 159242, 158106, 159270, 158625, 158674, 159167, 159643, 158914, 156638, 157475, 155127, 155814, 153742, 156082, 159265, 158665, 159424, 159425, 156274, 156454, 160154, 159176, 154150, 158290, 154407, 157734, 158630, 160093, 158385, 158787, 156033, 153736, 153628, 154096, 155349, 158718, 157585, 155837, 159996, 157083, 156462, 155937, 156428, 155590, 155591, 154161, 154629, 158865, 158930, 158914, 156655, 159662, 153642, 155892, 154957, 154243, 156844, 158184, 156014, 156584, 158436, 158696, 158282, 159081, 158488, 156348, 155261, 154722, 156492, 158565, 156506, 154987, 154294, 160155, 159424, 155691, 155708, 157813, 153478, 151675, 151712, 152596, 102177, 99360, 91777, 100569, 17340, 101376, 102073, 22697, 70074, 104387, 12857, 74393, 41505, 120965, 101240, 120965, 102565, 97907, 102138, 34718, 70074, 91956, 99662, 99318, 121657, 44729, 97907, 99318, 100893, 70074, 3837, 100132, 75606, 110261, 104754, 6313, 151713, 153477]
50
+ outputs = llm.generate(TokensPrompt(prompt_token_ids=input),
51
+ sampling_params,
52
+ use_tqdm=False)[0].outputs[0].token_ids
53
+ print(outputs)
54
+ # files = glob(f"{args.data_list}/*_result.json")
55
+ # files.sort()
56
+ # for file in files:
57
+ # with open(file) as fin:
58
+ # test_sets = json.load(fin)
59
+ # for test_set in test_sets:
60
+ # input = test_set["input"]
61
+ # set_all_random_seed(args.seed)
62
+ # llm_outputs = model.llm.generate(input, sampling_params)['token_ids']
63
+ # set_all_random_seed(args.seed)
64
+ # import pdb;pdb.set_trace()
65
+ # outputs = llm.generate(TokensPrompt(prompt_token_ids=input),
66
+ # VllmSamplingParams(**asdict(sampling_params)),
67
+ # use_tqdm=False)[0].outputs[0].token_ids
68
+ # print(llm_outputs)
69
+ # print(outputs)
70
+ # print("=========")
71
+ # import pdb;pdb.set_trace()
72
+
73
+ if __name__ == "__main__":
74
+ main()
soulxpodcast/cli/soulxpodcast.py ADDED
@@ -0,0 +1,273 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import random
5
+ import sys
6
+ import time
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ from datetime import datetime
9
+ import importlib.util
10
+
11
+ import numpy as np
12
+ import onnxruntime
13
+ import s3tokenizer
14
+ import torch
15
+ import torchaudio
16
+ import torchaudio.compliance.kaldi as kaldi
17
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler
18
+ from tqdm import tqdm
19
+
20
+ from soulxpodcast.config import Config, SoulXPodcastLLMConfig, SamplingParams
21
+ from soulxpodcast.models.soulxpodcast import SoulXPodcast
22
+ from soulxpodcast.utils.dataloader import PodcastDataset
23
+ from soulxpodcast.utils.audio import mel_spectrogram
24
+
25
+
26
+ def set_all_random_seed(seed):
27
+ random.seed(seed)
28
+ np.random.seed(seed)
29
+ torch.manual_seed(seed)
30
+ torch.cuda.manual_seed_all(seed)
31
+
32
+
33
+ def save_file_async(
34
+ wav, uttid,
35
+ info,
36
+ is_sub,
37
+ ):
38
+ """Save audio asynchronously."""
39
+ try:
40
+ parentdir = f"{os.path.dirname(info['wav'])}"
41
+ basename = os.path.basename(info["wav"]).split(".")[0]
42
+ if is_sub:
43
+ parentdir = f"{parentdir}/individual_clips"
44
+ os.makedirs(parentdir, exist_ok=True)
45
+ if wav is not None:
46
+ wav = wav.cpu()
47
+ torchaudio.save(f'{parentdir}/{uttid}.wav', wav, 24000)
48
+ duration = wav.shape[-1] / 24000.0
49
+ else:
50
+ duration = 0.0
51
+ if not is_sub:
52
+ with open(f"{parentdir}/{basename}.json", "w") as f:
53
+ json.dump(info, f, ensure_ascii=False, indent=4)
54
+ return duration
55
+ except Exception as e:
56
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
57
+ tqdm.write(f"[{timestamp}] - [ERROR] - Error saving audio {info.get('key', 'unknown')}: {e}")
58
+ return 0.0
59
+
60
+ def collate_fn(batch):
61
+ assert len(batch) == 1
62
+ data = batch[0]
63
+
64
+ # prepare prompt mels for llm, spk_emb + prompt mel for flow;
65
+ prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(data["log_mel"]) # [B, num_mels=128, T]
66
+ spk_emb_for_flow = torch.tensor(data["spk_emb"])
67
+ prompt_mels_for_flow = data["mel"]
68
+
69
+ # prepare text + spk for llm;
70
+ text_tokens_for_llm = data["text_tokens"]
71
+ prompt_text_tokens_for_llm = data["prompt_text_tokens"]
72
+ spk_ids = data["spks_list"]
73
+ sampling_params = SamplingParams()
74
+ infos = [data["info"]]
75
+
76
+ processed_data = {
77
+ "prompt_mels_for_llm": prompt_mels_for_llm,
78
+ "prompt_mels_lens_for_llm": prompt_mels_lens_for_llm,
79
+ "prompt_text_tokens_for_llm": prompt_text_tokens_for_llm,
80
+ "text_tokens_for_llm": text_tokens_for_llm,
81
+ "prompt_mels_for_flow_ori": prompt_mels_for_flow,
82
+ "spk_emb_for_flow": spk_emb_for_flow,
83
+ "sampling_params": sampling_params,
84
+ "spk_ids": spk_ids,
85
+ "infos": infos,
86
+ }
87
+
88
+ if data.get("use_prompt_cot", False):
89
+ processed_data.update({
90
+ "use_prompt_cot": True,
91
+ "prompt_cot_text_tokens_for_llm": data["prompt_cot_text_tokens"],
92
+ "prompt_cot_prefix": data["prompt_cot_prefix"],
93
+ })
94
+ return processed_data
95
+
96
+
97
+ def get_args():
98
+ parser = argparse.ArgumentParser(description='FlashCosyVoice')
99
+ parser.add_argument('--model_path',
100
+ required=True,
101
+ type=str,
102
+ help='model path')
103
+ parser.add_argument('--data_list',
104
+ required=True,
105
+ type=str,
106
+ help='data list')
107
+ parser.add_argument('--num_workers',
108
+ type=int,
109
+ default=4,
110
+ help='workers for dataloader')
111
+ parser.add_argument('--prefetch',
112
+ type=int,
113
+ default=5,
114
+ help='prefetch for dataloader')
115
+ parser.add_argument('--llm_engine',
116
+ type=str,
117
+ default="hf",
118
+ help='model execute engine')
119
+ parser.add_argument('--fp16_flow',
120
+ action='store_true',
121
+ help='enable fp16 flow')
122
+ parser.add_argument('--seed',
123
+ type=int,
124
+ default=1986,
125
+ help='random seed for generation')
126
+ parser.add_argument('--save_intermediate',
127
+ action='store_true',
128
+ help='enable save intermediate result in long form.')
129
+ args = parser.parse_args()
130
+ return args
131
+
132
+
133
+ def main():
134
+ args = get_args()
135
+
136
+ assert (torch.cuda.is_available())
137
+ hf_config = SoulXPodcastLLMConfig.from_initial_and_json(
138
+ initial_values={"fp16_flow": args.fp16_flow},
139
+ json_file=f"{args.model_path}/soulxpodcast_config.json")
140
+
141
+ # Fall back to the HF engine when vLLM is not installed
142
+ llm_engine = args.llm_engine
143
+ if llm_engine == "vllm":
144
+ if not importlib.util.find_spec("vllm"):
145
+ llm_engine = "hf"
146
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
147
+ tqdm.write(f"[{timestamp}] - [WARNING]: No install VLLM, switch to hf engine.")
148
+
149
+ config = Config(model=args.model_path, enforce_eager=True, llm_engine=llm_engine,
150
+ hf_config=hf_config)
151
+ model = SoulXPodcast(config)
152
+
153
+ set_all_random_seed(args.seed)
154
+
155
+ dataset = PodcastDataset(model.llm.tokenizer, args.data_list, config)
156
+ sampler = SequentialSampler(dataset,)
157
+ dataloader = DataLoader(dataset, batch_size=1, num_workers=args.num_workers, pin_memory=True,
158
+ sampler=sampler, shuffle=False, prefetch_factor=args.prefetch, collate_fn=collate_fn)
159
+ total_steps = len(dataset)
160
+
161
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
162
+ tqdm.write(f"[{timestamp}] - [INFO] - {args}")
163
+ progress_bar = tqdm(total=total_steps, desc="Processing samples", unit="wav",
164
+ position=0, leave=True, dynamic_ncols=True)
165
+
166
+ cpu_counts = os.cpu_count()
167
+ executor = ThreadPoolExecutor(max_workers=min(args.num_workers, cpu_counts // 2))
168
+
169
+ pending_futures = []
170
+ dataloader_iter = iter(dataloader)
171
+ succeed_duration = 0.01 # avoid division by zero
172
+ start_time = time.time()
173
+ estimated_total_wavs = 0
174
+ succeed_wavs = 0
175
+ failed_wavs = 0
176
+ last_print_time = start_time
177
+
178
+ while True:
179
+ try:
180
+ batch = next(dataloader_iter)
181
+
182
+ if len(batch['infos']) == 0:
183
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
184
+ tqdm.write(f"[{timestamp}] - [WARNING]: No valid batch found, skipping this batch...")
185
+ continue
186
+
187
+ results_dict = model.forward_longform(**batch)
188
+
189
+ estimated_total_wavs += len(results_dict['generated_wavs'])
190
+
191
+ uttid = batch['infos'][0]['key']
192
+ result = None
193
+ for i in range(len(results_dict['generated_wavs'])):
194
+ is_sub = True
195
+ if args.save_intermediate:
196
+ future = executor.submit(
197
+ save_file_async, results_dict['generated_wavs'][i],
198
+ f"{uttid}_turn_{str(i).zfill(2)}", batch['infos'][0].copy(), is_sub
199
+ )
200
+ pending_futures.append(future)
201
+ if result is None:
202
+ result = results_dict['generated_wavs'][i]
203
+ else:
204
+ result = torch.concat([result, results_dict['generated_wavs'][i]], axis=1)
205
+ future = executor.submit(
206
+ save_file_async, result,
207
+ f"{uttid}", batch['infos'][0].copy(), False
208
+ )
209
+ pending_futures.append(future)
210
+ completed_futures = []
211
+ for future in pending_futures:
212
+ if future.done():
213
+ try:
214
+ duration = future.result()
215
+ succeed_duration += duration
216
+ succeed_wavs += 1
217
+ except Exception as e:
218
+ failed_wavs += 1
219
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
220
+ tqdm.write(f"[{timestamp}] - [ERROR]: Error in async save task: {e}")
221
+ completed_futures.append(future)
222
+
223
+ for future in completed_futures:
224
+ pending_futures.remove(future)
225
+
226
+ update_n = 1
227
+ if progress_bar.n + update_n > progress_bar.total:
228
+ progress_bar.update(progress_bar.total - progress_bar.n)
229
+ else:
230
+ progress_bar.update(update_n)
231
+
232
+ current_time = time.time()
233
+ if current_time - last_print_time >= 120:
234
+ elapsed_time = current_time - start_time
235
+ avg_duration = succeed_duration / succeed_wavs if succeed_wavs > 0 else 0
236
+ estimated_total_duration = avg_duration * estimated_total_wavs
237
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
238
+ tqdm.write(f"[{timestamp}] - [INFO]: Estimated total wavs: {estimated_total_wavs} ({estimated_total_wavs - succeed_wavs} pending to save), Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Estimated total duration: {estimated_total_duration:.2f}s ({estimated_total_duration / 3600:.2f} h), Elapsed time: {elapsed_time:.2f}s") # noqa
239
+ last_print_time = current_time
240
+ except StopIteration:
241
+ break
242
+ except Exception as e:
243
+ failed_wavs += 1
244
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
245
+ tqdm.write(f"[{timestamp}] - [ERROR]: Error in main loop: {e}")
246
+ continue
247
+
248
+ total_time = time.time() - start_time
249
+
250
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
251
+ tqdm.write(f"[{timestamp}] - [INFO] - Waiting for {len(pending_futures)} pending save tasks to complete...")
252
+
253
+ for future in pending_futures:
254
+ try:
255
+ duration = future.result(timeout=60)
256
+ succeed_duration += duration
257
+ succeed_wavs += 1
258
+ except Exception as e:
259
+ failed_wavs += 1
260
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
261
+ tqdm.write(f"[{timestamp}] - [ERROR]: Error in final async save task: {e}")
262
+ executor.shutdown(wait=True)
263
+
264
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
265
+ tqdm.write(f"[{timestamp}] - [INFO]: All async save tasks completed.")
266
+ progress_bar.close()
267
+
268
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
269
+ tqdm.write(f"[{timestamp}] - [INFO]: Final Report - Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Total duration: {succeed_duration:.2f}s ({succeed_duration / 3600:.2f} h).") # noqa
270
+
271
+
272
+ if __name__ == "__main__":
273
+ main()
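The loop above overlaps GPU synthesis with disk I/O: every finished waveform is handed to a ThreadPoolExecutor, completed futures are harvested opportunistically between batches, and the remainder is drained with a timeout before shutdown. A minimal, self-contained sketch of that pattern (save_one and its payloads are hypothetical stand-ins for save_file_async):

import time
from concurrent.futures import ThreadPoolExecutor

def save_one(name: str) -> float:
    # Hypothetical stand-in for save_file_async: pretend to write one wav
    # and return its duration in seconds.
    time.sleep(0.05)
    return 1.5

executor = ThreadPoolExecutor(max_workers=4)
pending, durations = [], []

for name in ["utt_00", "utt_01", "utt_02"]:
    pending.append(executor.submit(save_one, name))   # non-blocking submit
    still_pending = []
    for fut in pending:                               # poll without blocking
        if fut.done():
            durations.append(fut.result())
        else:
            still_pending.append(fut)
    pending = still_pending

for fut in pending:                                   # final drain, as in the script above
    durations.append(fut.result(timeout=60))
executor.shutdown(wait=True)
print(f"saved {len(durations)} wavs, {sum(durations):.1f}s of audio")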
soulxpodcast/config.py ADDED
@@ -0,0 +1,141 @@
1
+ import os
2
+ from dataclasses import dataclass, field, fields, is_dataclass, asdict
3
+ from typing import Any, Dict, List, Optional
4
+ from pathlib import Path
5
+ import json
6
+
7
+ import torch
8
+ from transformers import AutoConfig
9
+ from transformers import PretrainedConfig
10
+
11
+ @dataclass
12
+ class SoulXPodcastLLMConfig:
13
+ architectures: list[str] = field(default_factory=lambda: ["Qwen3ForCausalLM"])
14
+ attention_dropout: float = 0.0
15
+ bos_token_id: int = 151643
16
+ eos_token_id: int = 151675 # speech eos
17
+ hidden_act: str = "silu"
18
+ hidden_size: int = 2048
19
+ initializer_range: float = 0.02
20
+ intermediate_size: int = 6144
21
+ max_position_embeddings: int = 40960
22
+ max_window_layers: int = 28
23
+ model_type: str = "qwen3"
24
+ num_attention_heads: int = 16
25
+ num_hidden_layers: int = 28
26
+ num_key_value_heads: int = 8
27
+ head_dim: int = 128
28
+ rms_norm_eps: float = 1e-06
29
+ rope_scaling: dict | None = None
30
+ rope_theta: float = 1000000.0
31
+ sliding_window: int = 32768
32
+ tie_word_embeddings: bool = True
33
+ torch_dtype: str = "bfloat16"
34
+ transformers_version: str = "4.52.3"
35
+ use_cache: bool = True
36
+ use_sliding_window: bool = False
37
+ vocab_size: int = 159488 # text_vocab_size + speech_vocab_size + 2 (eos and task_id)
38
+ lm_head_bias: bool = False
39
+ qkv_bias: bool = False
40
+ fp16_flow: bool = False
41
+ speech_token_offset: int = 152927
42
+
43
+ @classmethod
44
+ def from_initial_and_json(
45
+ cls,
46
+ initial_values: Dict[str, Any] = None,
47
+ json_file: Optional[str] = None
48
+ ):
49
+ """
50
+ Create instance from initial values and JSON data
51
+
52
+ Args:
53
+ initial_values: Initial key-value dict; these values override all other configuration sources
54
+ json_file: JSON file path
55
+
56
+ Returns:
57
+ SoulXPodcastLLMConfig instance
58
+ """
59
+ # Merge all data sources
60
+ merged_data = {}
61
+
62
+ # 1. Load from JSON file first (lowest priority)
63
+ if json_file and os.path.exists(json_file):
64
+ file_data = cls._load_json_file(json_file)
65
+ merged_data.update(file_data)
66
+
67
+ # 2. Override with initial values last (highest priority)
68
+ if initial_values:
69
+ merged_data.update(initial_values)
70
+
71
+ # 3. Extract dataclass fields
72
+ valid_fields = {f.name for f in fields(cls)}
73
+ init_data = {k: v for k, v in merged_data.items() if k in valid_fields}
74
+
75
+ return cls(**init_data)
76
+
77
+ @staticmethod
78
+ def _load_json_file(file_path: str) -> Dict[str, Any]:
79
+ """Load configuration data from a JSON file."""
80
+ path = Path(file_path)
81
+ if not path.exists():
82
+ return {}
83
+ with open(path, 'r', encoding='utf-8') as f:
84
+ return json.load(f)
85
+
86
+ class AutoPretrainedConfig(PretrainedConfig):
87
+ model_type = "qwen3"
88
+
89
+ def __init__(self, **kwargs):
90
+ # Remove non-configuration parameters
91
+ config_kwargs = {k: v for k, v in kwargs.items()
92
+ if not k.startswith('_') and k != 'self'}
93
+ super().__init__(**config_kwargs)
94
+
95
+ @classmethod
96
+ def from_dataclass(cls, dataclass_config):
97
+ """Dynamically generate config from dataclass"""
98
+ if not is_dataclass(dataclass_config):
99
+ raise ValueError("Input must be a dataclass instance")
100
+
101
+ dataclass_dict = asdict(dataclass_config)
102
+ return cls(**dataclass_dict)
103
+
104
+
105
+ @dataclass
106
+ class SamplingParams:
107
+ temperature: float = 0.6
108
+ repetition_penalty: float = 1.25
109
+ top_k: int = 100
110
+ top_p: float = 0.9
111
+ max_tokens: int = 3000
112
+ min_tokens: int = 8
113
+ stop_token_ids: list[int] = field(default_factory=lambda: [151675])
114
+ # RasSampler parameters
115
+ use_ras: bool = True
116
+ win_size: int = 25
117
+ tau_r: float = 0.2
118
+
119
+
120
+ @dataclass
121
+ class Config:
122
+ model: str
123
+ max_model_len: int = 8192 # 15s prompt + 30s generated audio for 25hz audio tokenizer
124
+ gpu_memory_utilization: float = 0.9
125
+ tensor_parallel_size: int = 1
126
+ enforce_eager: bool = False
127
+ hf_config: SoulXPodcastLLMConfig | AutoConfig = field(default_factory=SoulXPodcastLLMConfig)
128
+ eos: int = -1
129
+ llm_engine: str = "hf" # support hf, nano-vllm
130
+ max_turn_size: int = 14
131
+ turn_tokens_threshold: int = 6192
132
+
133
+ prompt_context: int = 2 # default to 2 for two-speaker podcast;
134
+ history_context: int = 4
135
+ history_text_context: int = 4
136
+
137
+ def __post_init__(self):
138
+ assert os.path.isdir(self.model)
139
+
140
+ max_pos = getattr(self.hf_config, "max_position_embeddings", 8192)
141
+ self.max_model_len = min(self.max_model_len, max_pos)
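Taken together, SoulXPodcastLLMConfig.from_initial_and_json applies a simple two-level precedence (values from the JSON file first, then explicit initial_values on top), and Config.__post_init__ clamps max_model_len to the model's max_position_embeddings. A hedged usage sketch; the checkpoint directory below is a placeholder, and Config will assert that it actually exists:

from soulxpodcast.config import Config, SamplingParams, SoulXPodcastLLMConfig

model_dir = "/path/to/SoulX-Podcast-checkpoint"   # hypothetical; must be a real directory

# JSON values are loaded first; initial_values win on any key collision.
llm_cfg = SoulXPodcastLLMConfig.from_initial_and_json(
    initial_values={"fp16_flow": True},
    json_file=f"{model_dir}/config.json",
)

cfg = Config(model=model_dir, hf_config=llm_cfg, llm_engine="hf")
print(cfg.max_model_len)          # min(8192, llm_cfg.max_position_embeddings)
sampling = SamplingParams()       # defaults: temperature 0.6, top_k 100, top_p 0.9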
soulxpodcast/engine/__init__.py ADDED
File without changes
soulxpodcast/engine/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (184 Bytes).
soulxpodcast/engine/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (172 Bytes).
soulxpodcast/engine/__pycache__/llm_engine.cpython-311.pyc ADDED
Binary file (7.81 kB).
soulxpodcast/engine/__pycache__/llm_engine.cpython-312.pyc ADDED
Binary file (6.51 kB).
soulxpodcast/engine/llm_engine.py ADDED
@@ -0,0 +1,116 @@
1
+ import types
2
+ import atexit
3
+ from dataclasses import fields, asdict
4
+ from time import perf_counter
5
+ import os
6
+ from functools import partial
7
+
8
+ import torch
9
+ import torch.multiprocessing as mp
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList
11
+ from transformers import EosTokenCriteria, RepetitionPenaltyLogitsProcessor
12
+ try:
13
+ from vllm import LLM
14
+ from vllm import SamplingParams as VllmSamplingParams
15
+ from vllm.inputs import TokensPrompt as TokensPrompt
16
+ SUPPORT_VLLM = True
17
+ except ImportError:
18
+ SUPPORT_VLLM = False
19
+
20
+ from soulxpodcast.models.modules.sampler import _ras_sample_hf_engine
21
+ from soulxpodcast.config import Config, SamplingParams
22
+
23
+ class HFLLMEngine:
24
+
25
+ def __init__(self, model, **kwargs):
26
+ config_fields = {field.name for field in fields(Config)}
27
+ config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
28
+ config = Config(model, **config_kwargs)
29
+
30
+ self.tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
31
+ config.eos = config.hf_config.eos_token_id # speech eos token;
32
+ self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
33
+ self.model = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.bfloat16, device_map=self.device)
34
+ self.config = config
35
+ self.pad_token_id = self.tokenizer.pad_token_id
36
+
37
+ def generate(
38
+ self,
39
+ prompt: list[str],
40
+ sampling_param: SamplingParams,
41
+ past_key_values=None,
42
+ ) -> dict:
43
+
44
+ # Recreate stopping_criteria per request for thread safety
45
+ stopping_criteria = StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.config.hf_config.eos_token_id)])
46
+ if sampling_param.use_ras:
47
+ sample_hf_engine_handler = partial(_ras_sample_hf_engine,
48
+ use_ras=sampling_param.use_ras,
49
+ win_size=sampling_param.win_size, tau_r=sampling_param.tau_r)
50
+ else:
51
+ sample_hf_engine_handler = None
52
+ rep_pen_processor = RepetitionPenaltyLogitsProcessor(
53
+ penalty=sampling_param.repetition_penalty,
54
+ prompt_ignore_length=len(prompt)
55
+ ) # exclude the input prompt, consistent with vLLM implementation;
56
+ with torch.no_grad(): # disable gradient tracking during generation
57
+ input_len = len(prompt)
58
+ generated_ids = self.model.generate(
59
+ input_ids = torch.tensor([prompt], dtype=torch.int64).to(self.device),
60
+ do_sample=True,
61
+ top_k=sampling_param.top_k,
62
+ top_p=sampling_param.top_p,
63
+ min_new_tokens=sampling_param.min_tokens,
64
+ max_new_tokens=sampling_param.max_tokens,
65
+ temperature=sampling_param.temperature,
66
+ repetition_penalty=sampling_param.repetition_penalty,
67
+ stopping_criteria=stopping_criteria,
68
+ past_key_values=past_key_values,
69
+ custom_generate=sample_hf_engine_handler,
70
+ use_cache=True,
71
+ logits_processor=[rep_pen_processor]
72
+ )
73
+ generated_ids = generated_ids[:, input_len:].cpu().numpy().tolist()[0]
74
+ output = {
75
+ "text": self.tokenizer.decode(generated_ids),
76
+ "token_ids": generated_ids,
77
+ }
78
+ return output
79
+
80
+ class VLLMEngine:
81
+
82
+ def __init__(self, model, **kwargs):
83
+
84
+ config_fields = {field.name for field in fields(Config)}
85
+ config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
86
+ config = Config(model, **config_kwargs)
87
+
88
+ self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
89
+ config.eos = config.hf_config.eos_token_id # speech eos token;
90
+ self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
91
+ os.environ["VLLM_USE_V1"] = "0"
92
+ if SUPPORT_VLLM:
93
+ self.model = LLM(model=model, enforce_eager=True, dtype="bfloat16", max_model_len=8192, enable_prefix_caching=True)
94
+ else:
95
+ raise ImportError("vLLM is not installed, so VLLMEngine is unavailable; install vllm or use the hf engine.")
96
+ self.config = config
97
+ self.pad_token_id = self.tokenizer.pad_token_id
98
+
99
+ def generate(
100
+ self,
101
+ prompt: list[str],
102
+ sampling_param: SamplingParams,
103
+ past_key_values=None,
104
+ ) -> dict:
105
+ sampling_param.stop_token_ids = [self.config.hf_config.eos_token_id]
106
+ with torch.no_grad(): # disable gradient tracking during generation
107
+ generated_ids = self.model.generate(
108
+ TokensPrompt(prompt_token_ids=prompt),
109
+ VllmSamplingParams(**asdict(sampling_param)),
110
+ use_tqdm=False,
111
+ )[0].outputs[0].token_ids
112
+ output = {
113
+ "text": self.tokenizer.decode(generated_ids),
114
+ "token_ids": list(generated_ids),
115
+ }
116
+ return output
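Both engines expose the same generate(prompt, sampling_param) interface over a flat list of token ids, which is what lets the rest of the pipeline switch between the Hugging Face backend and vLLM via Config.llm_engine. A hedged driving sketch (the checkpoint path and token ids are placeholders):

from soulxpodcast.config import SamplingParams
from soulxpodcast.engine.llm_engine import HFLLMEngine  # or VLLMEngine when vllm is installed

engine = HFLLMEngine("/path/to/SoulX-Podcast-checkpoint")   # hypothetical checkpoint dir
params = SamplingParams(temperature=0.6, top_k=100, top_p=0.9, max_tokens=3000)

prompt_token_ids = [151643, 100, 101, 102]   # placeholder ids; real prompts come from the tokenizer
out = engine.generate(prompt_token_ids, params)
print(len(out["token_ids"]), out["text"][:80])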
soulxpodcast/models/__pycache__/soulxpodcast.cpython-311.pyc ADDED
Binary file (12.7 kB).
soulxpodcast/models/__pycache__/soulxpodcast.cpython-312.pyc ADDED
Binary file (11 kB).
soulxpodcast/models/modules/__init__.py ADDED
File without changes
soulxpodcast/models/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (192 Bytes).
soulxpodcast/models/modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (180 Bytes).
soulxpodcast/models/modules/__pycache__/flow.cpython-311.pyc ADDED
Binary file (11.7 kB).
soulxpodcast/models/modules/__pycache__/flow.cpython-312.pyc ADDED
Binary file (10.9 kB).
soulxpodcast/models/modules/__pycache__/hifigan.cpython-311.pyc ADDED
Binary file (15 kB).
soulxpodcast/models/modules/__pycache__/hifigan.cpython-312.pyc ADDED
Binary file (13.8 kB).
soulxpodcast/models/modules/__pycache__/sampler.cpython-311.pyc ADDED
Binary file (9.64 kB).
soulxpodcast/models/modules/__pycache__/sampler.cpython-312.pyc ADDED
Binary file (9.25 kB).
soulxpodcast/models/modules/flow.py ADDED
@@ -0,0 +1,197 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from soulxpodcast.models.modules.flow_components.estimator import \
8
+ CausalConditionalDecoder
9
+ from soulxpodcast.models.modules.flow_components.upsample_encoder import (
10
+ UpsampleConformerEncoder, make_pad_mask)
11
+
12
+
13
+ @dataclass
14
+ class CfmParams:
15
+ sigma_min: float = 1e-6
16
+ solver: str = "euler"
17
+ t_scheduler: str = "cosine"
18
+ training_cfg_rate: float = 0.2
19
+ inference_cfg_rate: float = 0.7
20
+
21
+
22
+ class CausalConditionalCFM(torch.nn.Module):
23
+ def __init__(self, in_channels=320, cfm_params=CfmParams(), n_spks=1, spk_emb_dim=80, estimator: torch.nn.Module = None):
24
+ super().__init__()
25
+ self.n_feats = in_channels
26
+ self.n_spks = n_spks
27
+ self.spk_emb_dim = spk_emb_dim
28
+ self.solver = cfm_params.solver
29
+ if hasattr(cfm_params, "sigma_min"):
30
+ self.sigma_min = cfm_params.sigma_min
31
+ else:
32
+ self.sigma_min = 1e-4
33
+ self.t_scheduler = cfm_params.t_scheduler
34
+ self.training_cfg_rate = cfm_params.training_cfg_rate
35
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
36
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
37
+ # Just change the architecture of the estimator here
38
+ self.estimator = CausalConditionalDecoder() if estimator is None else estimator
39
+
40
+ @torch.inference_mode()
41
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
42
+ """Forward diffusion
43
+
44
+ Args:
45
+ mu (torch.Tensor): output of encoder
46
+ shape: (batch_size, n_feats, mel_timesteps)
47
+ mask (torch.Tensor): output_mask
48
+ shape: (batch_size, 1, mel_timesteps)
49
+ n_timesteps (int): number of diffusion steps
50
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
51
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
52
+ shape: (batch_size, spk_emb_dim)
53
+ cond: prompt mel condition passed through to the estimator
54
+
55
+ Returns:
56
+ sample: generated mel-spectrogram
57
+ shape: (batch_size, n_feats, mel_timesteps)
58
+ """
59
+ z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
60
+ # fix prompt and overlap part mu and z
61
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
62
+ if self.t_scheduler == 'cosine':
63
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
64
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
65
+
66
+ def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
67
+ """
68
+ Fixed euler solver for ODEs.
69
+ Args:
70
+ x (torch.Tensor): random noise
71
+ t_span (torch.Tensor): n_timesteps interpolated
72
+ shape: (n_timesteps + 1,)
73
+ mu (torch.Tensor): output of encoder
74
+ shape: (batch_size, n_feats, mel_timesteps)
75
+ mask (torch.Tensor): output_mask
76
+ shape: (batch_size, 1, mel_timesteps)
77
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
78
+ shape: (batch_size, spk_emb_dim)
79
+ cond: prompt mel condition passed through to the estimator
80
+ """
81
+ batch_size = x.size(0)
82
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
83
+
84
+ # Intermediate states are collected so they can be inspected while debugging,
85
+ # or returned later behind a future return_all_steps flag.
86
+ sol = []
87
+
88
+ # Do not use concat: it may change the memory format and make TensorRT inference return wrong results!
89
+ # Create tensors with double batch size for CFG (conditional + unconditional)
90
+ x_in = torch.zeros([batch_size * 2, x.size(1), x.size(2)], device=x.device, dtype=x.dtype)
91
+ mask_in = torch.zeros([batch_size * 2, mask.size(1), mask.size(2)], device=x.device, dtype=x.dtype)
92
+ mu_in = torch.zeros([batch_size * 2, mu.size(1), mu.size(2)], device=x.device, dtype=x.dtype)
93
+ t_in = torch.zeros([batch_size * 2], device=x.device, dtype=x.dtype)
94
+ spks_in = torch.zeros([batch_size * 2, spks.size(1)], device=x.device, dtype=x.dtype)
95
+ cond_in = torch.zeros([batch_size * 2, cond.size(1), cond.size(2)], device=x.device, dtype=x.dtype)
96
+
97
+ for step in range(1, len(t_span)):
98
+ # Classifier-Free Guidance inference introduced in VoiceBox
99
+ # Copy conditional and unconditional input
100
+ x_in[:batch_size] = x
101
+ x_in[batch_size:] = x
102
+ mask_in[:batch_size] = mask
103
+ mask_in[batch_size:] = mask
104
+ mu_in[:batch_size] = mu
105
+ # Unconditional part remains 0
106
+ t_in.fill_(t)
107
+ spks_in[:batch_size] = spks
108
+ cond_in[:batch_size] = cond
109
+
110
+ dphi_dt = self.estimator(
111
+ x_in, mask_in,
112
+ mu_in, t_in,
113
+ spks_in,
114
+ cond_in,
115
+ streaming
116
+ )
117
+ dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [batch_size, batch_size], dim=0)
118
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
119
+ x = x + dt * dphi_dt
120
+ t = t + dt
121
+ sol.append(x)
122
+ if step < len(t_span) - 1:
123
+ dt = t_span[step + 1] - t
124
+
125
+ return sol[-1].float()
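Each Euler step above runs the estimator on a doubled batch (conditional inputs plus a zeroed unconditional copy) and then extrapolates the two velocity estimates before integrating. The update itself reduces to the small, self-contained sketch below (the toy tensors are illustrative only):

import torch

def cfg_euler_step(x, v_cond, v_uncond, dt, cfg_rate=0.7):
    # Classifier-free guidance extrapolation followed by one Euler step,
    # matching the update used in solve_euler above.
    v = (1.0 + cfg_rate) * v_cond - cfg_rate * v_uncond
    return x + dt * v

x = torch.zeros(1, 80, 10)            # toy state: (batch, n_feats, frames)
v_cond = torch.ones_like(x)           # stand-in conditional velocity
v_uncond = 0.5 * torch.ones_like(x)   # stand-in unconditional velocity
x = cfg_euler_step(x, v_cond, v_uncond, dt=0.1)
print(x[0, 0, 0].item())              # 0.1 * (1.7 * 1.0 - 0.7 * 0.5) = 0.135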
126
+
127
+
128
+ class CausalMaskedDiffWithXvec(torch.nn.Module):
129
+ def __init__(
130
+ self,
131
+ input_size: int = 512,
132
+ output_size: int = 80,
133
+ spk_embed_dim: int = 192,
134
+ output_type: str = "mel",
135
+ vocab_size: int = 6561,
136
+ input_frame_rate: int = 25,
137
+ token_mel_ratio: int = 2,
138
+ pre_lookahead_len: int = 3,
139
+ encoder: torch.nn.Module = None,
140
+ decoder: torch.nn.Module = None,
141
+ ):
142
+ super().__init__()
143
+ self.input_size = input_size
144
+ self.output_size = output_size
145
+ self.vocab_size = vocab_size
146
+ self.output_type = output_type
147
+ self.input_frame_rate = input_frame_rate
148
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
149
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
150
+ self.encoder = UpsampleConformerEncoder() if encoder is None else encoder
151
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
152
+ self.decoder = CausalConditionalCFM() if decoder is None else decoder
153
+ self.token_mel_ratio = token_mel_ratio
154
+ self.pre_lookahead_len = pre_lookahead_len
155
+
156
+ @torch.inference_mode()
157
+ def forward(self,
158
+ token,
159
+ token_len,
160
+ prompt_feat,
161
+ prompt_feat_len,
162
+ embedding,
163
+ streaming,
164
+ finalize):
165
+ # xvec projection
166
+ embedding = F.normalize(embedding, dim=1)
167
+ embedding = self.spk_embed_affine_layer(embedding)
168
+
169
+ # concat text and prompt_text
170
+ mask = (~make_pad_mask(token_len, max_len=token.shape[1])).unsqueeze(-1).to(embedding)
171
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
172
+
173
+ # text encode
174
+ if finalize is True:
175
+ h, h_lengths = self.encoder(token, token_len, streaming=streaming)
176
+ else:
177
+ token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
178
+ h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
179
+ h = self.encoder_proj(h)
180
+
181
+ # get conditions
182
+ conds = torch.zeros_like(h, device=token.device)
183
+ for i, j in enumerate(prompt_feat_len):
184
+ conds[i, :j] = prompt_feat[i, :j]
185
+ conds = conds.transpose(1, 2)
186
+
187
+ h_lengths = h_lengths.sum(dim=-1).squeeze(dim=1)
188
+ mask = (~make_pad_mask(h_lengths, max_len=h.shape[1])).to(h)
189
+ feat, _ = self.decoder(
190
+ mu=h.transpose(1, 2).contiguous(),
191
+ mask=mask.unsqueeze(1),
192
+ spks=embedding,
193
+ cond=conds,
194
+ n_timesteps=15,
195
+ streaming=streaming
196
+ ) # [B, num_mels, T]
197
+ return feat.float(), h_lengths
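For reference, CausalMaskedDiffWithXvec.forward consumes 25 Hz speech tokens plus a prompt mel and a speaker x-vector and returns a mel-spectrogram at token_mel_ratio frames per token. The sketch below only documents the expected shapes under the defaults above; it builds the input tensors but leaves the actual call commented out, since running it needs the full checkpoint-sized module:

import torch

B, T_tok, T_prompt_mel = 1, 50, 60
token           = torch.randint(0, 6561, (B, T_tok))   # speech tokens (vocab_size=6561)
token_len       = torch.tensor([T_tok])
prompt_feat     = torch.randn(B, T_prompt_mel, 80)     # prompt mel frames, 80 bins
prompt_feat_len = torch.tensor([T_prompt_mel])
embedding       = torch.randn(B, 192)                  # speaker x-vector (spk_embed_dim=192)

# flow = CausalMaskedDiffWithXvec()  # assumed constructed and checkpoint-loaded elsewhere
# feat, feat_len = flow(token, token_len, prompt_feat, prompt_feat_len,
#                       embedding, streaming=False, finalize=True)
# feat: (B, 80, ~T_tok * token_mel_ratio); feat_len: (B,)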
soulxpodcast/models/modules/flow_components/__init__.py ADDED
File without changes
soulxpodcast/models/modules/flow_components/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (208 Bytes).
soulxpodcast/models/modules/flow_components/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (196 Bytes).
soulxpodcast/models/modules/flow_components/__pycache__/estimator.cpython-311.pyc ADDED
Binary file (49.4 kB).
soulxpodcast/models/modules/flow_components/__pycache__/estimator.cpython-312.pyc ADDED
Binary file (43.6 kB).
soulxpodcast/models/modules/flow_components/__pycache__/upsample_encoder.cpython-311.pyc ADDED
Binary file (52 kB).
soulxpodcast/models/modules/flow_components/__pycache__/upsample_encoder.cpython-312.pyc ADDED
Binary file (49.6 kB).
soulxpodcast/models/modules/flow_components/estimator.py ADDED
@@ -0,0 +1,974 @@
1
+ import math
2
+ from typing import Any, Dict, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from diffusers.models.attention import (GEGLU, GELU, AdaLayerNorm,
8
+ AdaLayerNormZero, ApproximateGELU)
9
+ from diffusers.models.attention_processor import Attention
10
+ from diffusers.models.lora import LoRACompatibleLinear
11
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
12
+ from einops import pack, rearrange, repeat
13
+
14
+ from soulxpodcast.models.modules.flow_components.upsample_encoder import \
15
+ add_optional_chunk_mask
16
+
17
+
18
+ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
19
+ assert mask.dtype == torch.bool
20
+ assert dtype in [torch.float32, torch.bfloat16, torch.float16]
21
+ mask = mask.to(dtype)
22
+ # attention mask bias
23
+ # NOTE(Mddct): torch.finfo jit issues
24
+ # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
25
+ mask = (1.0 - mask) * -1.0e+10
26
+ return mask
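mask_to_bias turns a boolean attention mask into an additive bias: kept positions map to 0.0 and masked positions to -1e10, so the result can simply be added to attention logits before softmax. A quick check:

import torch
from soulxpodcast.models.modules.flow_components.estimator import mask_to_bias

keep = torch.tensor([[True, True, False]])   # True = attend, False = padding
bias = mask_to_bias(keep, torch.float32)
print(bias)   # kept positions ~0.0, the masked position -1e10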
27
+
28
+
29
+ class SnakeBeta(nn.Module):
30
+ """
31
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
32
+ Shape:
33
+ - Input: (B, C, T)
34
+ - Output: (B, C, T), same shape as the input
35
+ Parameters:
36
+ - alpha - trainable parameter that controls frequency
37
+ - beta - trainable parameter that controls magnitude
38
+ References:
39
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
40
+ https://arxiv.org/abs/2006.08195
41
+ Examples:
42
+ >>> a1 = SnakeBeta(256, 256)
43
+ >>> x = torch.randn(256)
44
+ >>> x = a1(x)
45
+
46
+ Args:
47
+ in_features: shape of the input
48
+ out_features: shape of the output
49
+ alpha: trainable parameter that controls frequency
50
+ alpha_trainable: whether alpha is trainable
51
+ alpha_logscale: whether to use log scale for alpha
52
+ alpha is initialized to 1 by default, higher values = higher-frequency.
53
+ beta is initialized to 1 by default, higher values = higher-magnitude.
54
+ alpha will be trained along with the rest of your model.
55
+ """
56
+
57
+ def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
58
+ super().__init__()
59
+ self.in_features = out_features if isinstance(out_features, list) else [out_features]
60
+ self.proj = LoRACompatibleLinear(in_features, out_features)
61
+
62
+ # initialize alpha
63
+ self.alpha_logscale = alpha_logscale
64
+ if self.alpha_logscale: # log scale alphas initialized to zeros
65
+ self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
66
+ self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
67
+ else: # linear scale alphas initialized to ones
68
+ self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
69
+ self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
70
+
71
+ self.alpha.requires_grad = alpha_trainable
72
+ self.beta.requires_grad = alpha_trainable
73
+
74
+ self.no_div_by_zero = 0.000000001
75
+
76
+ def forward(self, x):
77
+ """
78
+ Forward pass of the function.
79
+ Applies the function to the input elementwise.
80
+ SnakeBeta(x) := x + 1/b * sin^2(a * x)
81
+ """
82
+ x = self.proj(x)
83
+ if self.alpha_logscale:
84
+ alpha = torch.exp(self.alpha)
85
+ beta = torch.exp(self.beta)
86
+ else:
87
+ alpha = self.alpha
88
+ beta = self.beta
89
+
90
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
91
+
92
+ return x
93
+
94
+
95
+ class FeedForward(nn.Module):
96
+ r"""
97
+ A feed-forward layer.
98
+
99
+ Parameters:
100
+ dim (`int`): The number of channels in the input.
101
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
102
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
103
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
104
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
105
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ dim: int,
111
+ dim_out: Optional[int] = None,
112
+ mult: int = 4,
113
+ dropout: float = 0.0,
114
+ activation_fn: str = "geglu",
115
+ final_dropout: bool = False,
116
+ ):
117
+ super().__init__()
118
+ inner_dim = int(dim * mult)
119
+ dim_out = dim_out if dim_out is not None else dim
120
+
121
+ if activation_fn == "gelu":
122
+ act_fn = GELU(dim, inner_dim)
123
+ if activation_fn == "gelu-approximate":
124
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
125
+ elif activation_fn == "geglu":
126
+ act_fn = GEGLU(dim, inner_dim)
127
+ elif activation_fn == "geglu-approximate":
128
+ act_fn = ApproximateGELU(dim, inner_dim)
129
+ elif activation_fn == "snakebeta":
130
+ act_fn = SnakeBeta(dim, inner_dim)
131
+
132
+ self.net = nn.ModuleList([])
133
+ # project in
134
+ self.net.append(act_fn)
135
+ # project dropout
136
+ self.net.append(nn.Dropout(dropout))
137
+ # project out
138
+ self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
139
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
140
+ if final_dropout:
141
+ self.net.append(nn.Dropout(dropout))
142
+
143
+ def forward(self, hidden_states):
144
+ for module in self.net:
145
+ hidden_states = module(hidden_states)
146
+ return hidden_states
147
+
148
+
149
+ @maybe_allow_in_graph
150
+ class BasicTransformerBlock(nn.Module):
151
+ r"""
152
+ A basic Transformer block.
153
+
154
+ Parameters:
155
+ dim (`int`): The number of channels in the input and output.
156
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
157
+ attention_head_dim (`int`): The number of channels in each head.
158
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
159
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
160
+ only_cross_attention (`bool`, *optional*):
161
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
162
+ double_self_attention (`bool`, *optional*):
163
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
164
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
165
+ num_embeds_ada_norm (:
166
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
167
+ attention_bias (:
168
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
169
+ """
170
+
171
+ def __init__(
172
+ self,
173
+ dim: int,
174
+ num_attention_heads: int,
175
+ attention_head_dim: int,
176
+ dropout=0.0,
177
+ cross_attention_dim: Optional[int] = None,
178
+ activation_fn: str = "geglu",
179
+ num_embeds_ada_norm: Optional[int] = None,
180
+ attention_bias: bool = False,
181
+ only_cross_attention: bool = False,
182
+ double_self_attention: bool = False,
183
+ upcast_attention: bool = False,
184
+ norm_elementwise_affine: bool = True,
185
+ norm_type: str = "layer_norm",
186
+ final_dropout: bool = False,
187
+ ):
188
+ super().__init__()
189
+ self.only_cross_attention = only_cross_attention
190
+
191
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
192
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
193
+
194
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
195
+ raise ValueError(
196
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
197
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
198
+ )
199
+
200
+ # Define 3 blocks. Each block has its own normalization layer.
201
+ # 1. Self-Attn
202
+ if self.use_ada_layer_norm:
203
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
204
+ elif self.use_ada_layer_norm_zero:
205
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
206
+ else:
207
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
208
+ self.attn1 = Attention(
209
+ query_dim=dim,
210
+ heads=num_attention_heads,
211
+ dim_head=attention_head_dim,
212
+ dropout=dropout,
213
+ bias=attention_bias,
214
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
215
+ upcast_attention=upcast_attention,
216
+ )
217
+
218
+ # 2. Cross-Attn
219
+ if cross_attention_dim is not None or double_self_attention:
220
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
221
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
222
+ # the second cross attention block.
223
+ self.norm2 = (
224
+ AdaLayerNorm(dim, num_embeds_ada_norm)
225
+ if self.use_ada_layer_norm
226
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
227
+ )
228
+ self.attn2 = Attention(
229
+ query_dim=dim,
230
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
231
+ heads=num_attention_heads,
232
+ dim_head=attention_head_dim,
233
+ dropout=dropout,
234
+ bias=attention_bias,
235
+ upcast_attention=upcast_attention,
236
+ # scale_qk=False, # uncomment this to not to use flash attention
237
+ ) # is self-attn if encoder_hidden_states is none
238
+ else:
239
+ self.norm2 = None
240
+ self.attn2 = None
241
+
242
+ # 3. Feed-forward
243
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
244
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
245
+
246
+ # let chunk size default to None
247
+ self._chunk_size = None
248
+ self._chunk_dim = 0
249
+
250
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
251
+ # Sets chunk feed-forward
252
+ self._chunk_size = chunk_size
253
+ self._chunk_dim = dim
254
+
255
+ def forward(
256
+ self,
257
+ hidden_states: torch.FloatTensor,
258
+ attention_mask: Optional[torch.FloatTensor] = None,
259
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
260
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
261
+ timestep: Optional[torch.LongTensor] = None,
262
+ cross_attention_kwargs: Dict[str, Any] = None,
263
+ class_labels: Optional[torch.LongTensor] = None,
264
+ ):
265
+ # Notice that normalization is always applied before the real computation in the following blocks.
266
+ # 1. Self-Attention
267
+ if self.use_ada_layer_norm:
268
+ norm_hidden_states = self.norm1(hidden_states, timestep)
269
+ elif self.use_ada_layer_norm_zero:
270
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
271
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
272
+ )
273
+ else:
274
+ norm_hidden_states = self.norm1(hidden_states)
275
+
276
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
277
+
278
+ attn_output = self.attn1(
279
+ norm_hidden_states,
280
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
281
+ attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
282
+ **cross_attention_kwargs,
283
+ )
284
+ if self.use_ada_layer_norm_zero:
285
+ attn_output = gate_msa.unsqueeze(1) * attn_output
286
+ hidden_states = attn_output + hidden_states
287
+
288
+ # 2. Cross-Attention
289
+ if self.attn2 is not None:
290
+ norm_hidden_states = (
291
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
292
+ )
293
+
294
+ attn_output = self.attn2(
295
+ norm_hidden_states,
296
+ encoder_hidden_states=encoder_hidden_states,
297
+ attention_mask=encoder_attention_mask,
298
+ **cross_attention_kwargs,
299
+ )
300
+ hidden_states = attn_output + hidden_states
301
+
302
+ # 3. Feed-forward
303
+ norm_hidden_states = self.norm3(hidden_states)
304
+
305
+ if self.use_ada_layer_norm_zero:
306
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
307
+
308
+ if self._chunk_size is not None:
309
+ # "feed_forward_chunk_size" can be used to save memory
310
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
311
+ raise ValueError(
312
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
313
+ )
314
+
315
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
316
+ ff_output = torch.cat(
317
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
318
+ dim=self._chunk_dim,
319
+ )
320
+ else:
321
+ ff_output = self.ff(norm_hidden_states)
322
+
323
+ if self.use_ada_layer_norm_zero:
324
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
325
+
326
+ hidden_states = ff_output + hidden_states
327
+
328
+ return hidden_states
329
+
330
+
331
+ class SinusoidalPosEmb(torch.nn.Module):
332
+ def __init__(self, dim):
333
+ super().__init__()
334
+ self.dim = dim
335
+ assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
336
+
337
+ def forward(self, x, scale=1000):
338
+ if x.ndim < 1:
339
+ x = x.unsqueeze(0)
340
+ device = x.device
341
+ half_dim = self.dim // 2
342
+ emb = math.log(10000) / (half_dim - 1)
343
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
344
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
345
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
346
+ return emb
347
+
348
+
349
+ class Block1D(torch.nn.Module):
350
+ def __init__(self, dim, dim_out, groups=8):
351
+ super().__init__()
352
+ self.block = torch.nn.Sequential(
353
+ torch.nn.Conv1d(dim, dim_out, 3, padding=1),
354
+ torch.nn.GroupNorm(groups, dim_out),
355
+ nn.Mish(),
356
+ )
357
+
358
+ def forward(self, x, mask):
359
+ output = self.block(x * mask)
360
+ return output * mask
361
+
362
+
363
+ class ResnetBlock1D(torch.nn.Module):
364
+ def __init__(self, dim, dim_out, time_emb_dim, groups=8):
365
+ super().__init__()
366
+ self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
367
+
368
+ self.block1 = Block1D(dim, dim_out, groups=groups)
369
+ self.block2 = Block1D(dim_out, dim_out, groups=groups)
370
+
371
+ self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
372
+
373
+ def forward(self, x, mask, time_emb):
374
+ h = self.block1(x, mask)
375
+ h += self.mlp(time_emb).unsqueeze(-1)
376
+ h = self.block2(h, mask)
377
+ output = h + self.res_conv(x * mask)
378
+ return output
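Block1D and ResnetBlock1D operate on (batch, channels, frames) features with an explicit (batch, 1, frames) validity mask and inject a per-example time embedding between the two convolutions. A small self-contained shape check (the channel sizes are arbitrary):

import torch
from soulxpodcast.models.modules.flow_components.estimator import ResnetBlock1D

block = ResnetBlock1D(dim=16, dim_out=32, time_emb_dim=64)
x    = torch.randn(2, 16, 10)                      # (batch, channels, frames)
mask = torch.ones(2, 1, 10); mask[:, :, 8:] = 0    # last two frames are padding
t    = torch.randn(2, 64)                          # time embedding per example
y = block(x, mask, t)
print(y.shape)                                     # torch.Size([2, 32, 10])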
379
+
380
+
381
+ class Downsample1D(nn.Module):
382
+ def __init__(self, dim):
383
+ super().__init__()
384
+ self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
385
+
386
+ def forward(self, x):
387
+ return self.conv(x)
388
+
389
+
390
+ class TimestepEmbedding(nn.Module):
391
+ def __init__(
392
+ self,
393
+ in_channels: int,
394
+ time_embed_dim: int,
395
+ act_fn: str = "silu",
396
+ out_dim: int = None,
397
+ post_act_fn: Optional[str] = None,
398
+ cond_proj_dim=None,
399
+ ):
400
+ super().__init__()
401
+ assert act_fn == "silu", "act_fn must be silu"
402
+
403
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
404
+
405
+ if cond_proj_dim is not None:
406
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
407
+ else:
408
+ self.cond_proj = None
409
+
410
+ self.act = nn.SiLU()
411
+
412
+ if out_dim is not None:
413
+ time_embed_dim_out = out_dim
414
+ else:
415
+ time_embed_dim_out = time_embed_dim
416
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
417
+
418
+ if post_act_fn is None:
419
+ self.post_act = None
420
+ else:
421
+ self.post_act = nn.SiLU()
422
+
423
+ def forward(self, sample, condition=None):
424
+ if condition is not None:
425
+ sample = sample + self.cond_proj(condition)
426
+ sample = self.linear_1(sample)
427
+
428
+ if self.act is not None:
429
+ sample = self.act(sample)
430
+
431
+ sample = self.linear_2(sample)
432
+
433
+ if self.post_act is not None:
434
+ sample = self.post_act(sample)
435
+ return sample
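The diffusion timestep is encoded in two stages, exactly as the decoders below use it: SinusoidalPosEmb maps the scalar timestep to a fixed sinusoidal vector, and TimestepEmbedding projects it through a Linear-SiLU-Linear MLP. A quick shape check (dimensions are illustrative):

import torch
from soulxpodcast.models.modules.flow_components.estimator import (
    SinusoidalPosEmb, TimestepEmbedding)

pos_emb  = SinusoidalPosEmb(dim=320)                    # dim must be even
time_mlp = TimestepEmbedding(in_channels=320, time_embed_dim=1024)

t    = torch.rand(4)              # one timestep in [0, 1] per batch element
emb  = pos_emb(t)                 # (4, 320) sinusoidal features
cond = time_mlp(emb)              # (4, 1024) after Linear -> SiLU -> Linear
print(emb.shape, cond.shape)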
436
+
437
+
438
+ class Upsample1D(nn.Module):
439
+ """A 1D upsampling layer with an optional convolution.
440
+
441
+ Parameters:
442
+ channels (`int`):
443
+ number of channels in the inputs and outputs.
444
+ use_conv (`bool`, default `False`):
445
+ option to use a convolution.
446
+ use_conv_transpose (`bool`, default `False`):
447
+ option to use a convolution transpose.
448
+ out_channels (`int`, optional):
449
+ number of output channels. Defaults to `channels`.
450
+ """
451
+
452
+ def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
453
+ super().__init__()
454
+ self.channels = channels
455
+ self.out_channels = out_channels or channels
456
+ self.use_conv = use_conv
457
+ self.use_conv_transpose = use_conv_transpose
458
+ self.name = name
459
+
460
+ self.conv = None
461
+ if use_conv_transpose:
462
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
463
+ elif use_conv:
464
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
465
+
466
+ def forward(self, inputs):
467
+ assert inputs.shape[1] == self.channels
468
+ if self.use_conv_transpose:
469
+ return self.conv(inputs)
470
+
471
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
472
+
473
+ if self.use_conv:
474
+ outputs = self.conv(outputs)
475
+
476
+ return outputs
477
+
478
+
479
+ class Transpose(torch.nn.Module):
480
+ def __init__(self, dim0: int, dim1: int):
481
+ super().__init__()
482
+ self.dim0 = dim0
483
+ self.dim1 = dim1
484
+
485
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
486
+ x = torch.transpose(x, self.dim0, self.dim1)
487
+ return x
488
+
489
+
490
+ class CausalConv1d(torch.nn.Conv1d):
491
+ def __init__(
492
+ self,
493
+ in_channels: int,
494
+ out_channels: int,
495
+ kernel_size: int,
496
+ stride: int = 1,
497
+ dilation: int = 1,
498
+ groups: int = 1,
499
+ bias: bool = True,
500
+ padding_mode: str = 'zeros',
501
+ device=None,
502
+ dtype=None
503
+ ) -> None:
504
+ super(CausalConv1d, self).__init__(in_channels, out_channels,
505
+ kernel_size, stride,
506
+ padding=0, dilation=dilation,
507
+ groups=groups, bias=bias,
508
+ padding_mode=padding_mode,
509
+ device=device, dtype=dtype)
510
+ assert stride == 1
511
+ self.causal_padding = kernel_size - 1
512
+
513
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
514
+ x = F.pad(x, (self.causal_padding, 0), value=0.0)
515
+ x = super(CausalConv1d, self).forward(x)
516
+ return x
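CausalConv1d preserves the sequence length by left-padding with kernel_size - 1 zeros, so output frame t never sees input frames later than t. A quick check of that property (channel sizes are arbitrary):

import torch
from soulxpodcast.models.modules.flow_components.estimator import CausalConv1d

conv = CausalConv1d(in_channels=4, out_channels=4, kernel_size=3)
x  = torch.randn(1, 4, 10)
y1 = conv(x)
x2 = x.clone(); x2[:, :, 7:] = 0.0                  # perturb only the future frames
y2 = conv(x2)
print(y1.shape)                                     # torch.Size([1, 4, 10]) - length preserved
print(torch.allclose(y1[:, :, :7], y2[:, :, :7]))   # True: outputs up to t=6 are unchanged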
517
+
518
+
519
+ class CausalBlock1D(Block1D):
520
+ def __init__(self, dim: int, dim_out: int):
521
+ super(CausalBlock1D, self).__init__(dim, dim_out)
522
+ self.block = torch.nn.Sequential(
523
+ CausalConv1d(dim, dim_out, 3),
524
+ Transpose(1, 2),
525
+ nn.LayerNorm(dim_out),
526
+ Transpose(1, 2),
527
+ nn.Mish(),
528
+ )
529
+
530
+ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
531
+ output = self.block(x * mask)
532
+ return output * mask
533
+
534
+
535
+ class CausalResnetBlock1D(ResnetBlock1D):
536
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
537
+ super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
538
+ self.block1 = CausalBlock1D(dim, dim_out)
539
+ self.block2 = CausalBlock1D(dim_out, dim_out)
540
+
541
+
542
+ class ConditionalDecoder(nn.Module):
543
+ """
544
+ This decoder requires an input with the same shape as the target. So, if your text content
545
+ is shorter or longer than the output, please resample it before feeding it to the decoder.
546
+
547
+ Args:
548
+ in_channels: number of input channels
549
+ out_channels: number of output channels
550
+ channels: tuple of channel dimensions
551
+ dropout: dropout rate
552
+ attention_head_dim: dimension of attention heads
553
+ n_blocks: number of transformer blocks
554
+ num_mid_blocks: number of middle blocks
555
+ num_heads: number of attention heads
556
+ act_fn: activation function name
557
+ """
558
+
559
+ def __init__(
560
+ self,
561
+ in_channels,
562
+ out_channels,
563
+ channels=(256, 256),
564
+ dropout=0.05,
565
+ attention_head_dim=64,
566
+ n_blocks=1,
567
+ num_mid_blocks=2,
568
+ num_heads=4,
569
+ act_fn="snake",
570
+ ):
571
+ super().__init__()
572
+ channels = tuple(channels)
573
+ self.in_channels = in_channels
574
+ self.out_channels = out_channels
575
+
576
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
577
+ time_embed_dim = channels[0] * 4
578
+ self.time_mlp = TimestepEmbedding(
579
+ in_channels=in_channels,
580
+ time_embed_dim=time_embed_dim,
581
+ act_fn="silu",
582
+ )
583
+ self.down_blocks = nn.ModuleList([])
584
+ self.mid_blocks = nn.ModuleList([])
585
+ self.up_blocks = nn.ModuleList([])
586
+
587
+ output_channel = in_channels
588
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
589
+ input_channel = output_channel
590
+ output_channel = channels[i]
591
+ is_last = i == len(channels) - 1
592
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
593
+ transformer_blocks = nn.ModuleList(
594
+ [
595
+ BasicTransformerBlock(
596
+ dim=output_channel,
597
+ num_attention_heads=num_heads,
598
+ attention_head_dim=attention_head_dim,
599
+ dropout=dropout,
600
+ activation_fn=act_fn,
601
+ )
602
+ for _ in range(n_blocks)
603
+ ]
604
+ )
605
+ downsample = (
606
+ Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
607
+ )
608
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
609
+
610
+ for _ in range(num_mid_blocks):
611
+ input_channel = channels[-1]
612
+ out_channels = channels[-1]
613
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
614
+
615
+ transformer_blocks = nn.ModuleList(
616
+ [
617
+ BasicTransformerBlock(
618
+ dim=output_channel,
619
+ num_attention_heads=num_heads,
620
+ attention_head_dim=attention_head_dim,
621
+ dropout=dropout,
622
+ activation_fn=act_fn,
623
+ )
624
+ for _ in range(n_blocks)
625
+ ]
626
+ )
627
+
628
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
629
+
630
+ channels = channels[::-1] + (channels[0],)
631
+ for i in range(len(channels) - 1):
632
+ input_channel = channels[i] * 2
633
+ output_channel = channels[i + 1]
634
+ is_last = i == len(channels) - 2
635
+ resnet = ResnetBlock1D(
636
+ dim=input_channel,
637
+ dim_out=output_channel,
638
+ time_emb_dim=time_embed_dim,
639
+ )
640
+ transformer_blocks = nn.ModuleList(
641
+ [
642
+ BasicTransformerBlock(
643
+ dim=output_channel,
644
+ num_attention_heads=num_heads,
645
+ attention_head_dim=attention_head_dim,
646
+ dropout=dropout,
647
+ activation_fn=act_fn,
648
+ )
649
+ for _ in range(n_blocks)
650
+ ]
651
+ )
652
+ upsample = (
653
+ Upsample1D(output_channel, use_conv_transpose=True)
654
+ if not is_last
655
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
656
+ )
657
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
658
+ self.final_block = Block1D(channels[-1], channels[-1])
659
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
660
+ self.initialize_weights()
661
+
662
+ def initialize_weights(self):
663
+ for m in self.modules():
664
+ if isinstance(m, nn.Conv1d):
665
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
666
+ if m.bias is not None:
667
+ nn.init.constant_(m.bias, 0)
668
+ elif isinstance(m, nn.GroupNorm):
669
+ nn.init.constant_(m.weight, 1)
670
+ nn.init.constant_(m.bias, 0)
671
+ elif isinstance(m, nn.Linear):
672
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
673
+ if m.bias is not None:
674
+ nn.init.constant_(m.bias, 0)
675
+
676
+ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
677
+ """Forward pass of the UNet1DConditional model.
678
+
679
+ Args:
680
+ x (torch.Tensor): shape (batch_size, in_channels, time)
681
+ mask (torch.Tensor): shape (batch_size, 1, time)
682
+ t (torch.Tensor): shape (batch_size)
683
+ spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
684
+ cond (torch.Tensor, optional): prompt condition features, packed with the input. Defaults to None.
685
+
686
+ Returns:
691
+ torch.Tensor: output of shape (batch_size, out_channels, time)
692
+ """
693
+
694
+ t = self.time_embeddings(t).to(t.dtype)
695
+ t = self.time_mlp(t)
696
+
697
+ x = pack([x, mu], "b * t")[0]
698
+
699
+ if spks is not None:
700
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
701
+ x = pack([x, spks], "b * t")[0]
702
+ if cond is not None:
703
+ x = pack([x, cond], "b * t")[0]
704
+
705
+ hiddens = []
706
+ masks = [mask]
707
+ for resnet, transformer_blocks, downsample in self.down_blocks:
708
+ mask_down = masks[-1]
709
+ x = resnet(x, mask_down, t)
710
+ x = rearrange(x, "b c t -> b t c").contiguous()
711
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
712
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
713
+ for transformer_block in transformer_blocks:
714
+ x = transformer_block(
715
+ hidden_states=x,
716
+ attention_mask=attn_mask,
717
+ timestep=t,
718
+ )
719
+ x = rearrange(x, "b t c -> b c t").contiguous()
720
+ hiddens.append(x) # Save hidden states for skip connections
721
+ x = downsample(x * mask_down)
722
+ masks.append(mask_down[:, :, ::2])
723
+ masks = masks[:-1]
724
+ mask_mid = masks[-1]
725
+
726
+ for resnet, transformer_blocks in self.mid_blocks:
727
+ x = resnet(x, mask_mid, t)
728
+ x = rearrange(x, "b c t -> b t c").contiguous()
729
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
730
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
731
+ for transformer_block in transformer_blocks:
732
+ x = transformer_block(
733
+ hidden_states=x,
734
+ attention_mask=attn_mask,
735
+ timestep=t,
736
+ )
737
+ x = rearrange(x, "b t c -> b c t").contiguous()
738
+
739
+ for resnet, transformer_blocks, upsample in self.up_blocks:
740
+ mask_up = masks.pop()
741
+ skip = hiddens.pop()
742
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
743
+ x = resnet(x, mask_up, t)
744
+ x = rearrange(x, "b c t -> b t c").contiguous()
745
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
746
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
747
+ for transformer_block in transformer_blocks:
748
+ x = transformer_block(
749
+ hidden_states=x,
750
+ attention_mask=attn_mask,
751
+ timestep=t,
752
+ )
753
+ x = rearrange(x, "b t c -> b c t").contiguous()
754
+ x = upsample(x * mask_up)
755
+ x = self.final_block(x, mask_up)
756
+ output = self.final_proj(x * mask_up)
757
+ return output * mask
758
+
759
+
760
+ class CausalConditionalDecoder(ConditionalDecoder):
761
+ """
762
+ This decoder requires an input with the same shape as the target. So, if your text content
763
+ is shorter or longer than the output, please resample it before feeding it to the decoder.
764
+
765
+ Args:
766
+ in_channels: number of input channels
767
+ out_channels: number of output channels
768
+ channels: list of channel dimensions
769
+ dropout: dropout rate
770
+ attention_head_dim: dimension of attention heads
771
+ n_blocks: number of transformer blocks
772
+ num_mid_blocks: number of middle blocks
773
+ num_heads: number of attention heads
774
+ act_fn: activation function name
775
+ static_chunk_size: size of static chunks
776
+ num_decoding_left_chunks: number of left chunks for decoding
777
+ """
778
+
779
+ def __init__(
780
+ self,
781
+ in_channels=320,
782
+ out_channels=80,
783
+ channels=[256], # noqa
784
+ dropout=0.0,
785
+ attention_head_dim=64,
786
+ n_blocks=4,
787
+ num_mid_blocks=12,
788
+ num_heads=8,
789
+ act_fn="gelu",
790
+ static_chunk_size=50,
791
+ num_decoding_left_chunks=-1,
792
+ ):
793
+ torch.nn.Module.__init__(self)
794
+ channels = tuple(channels)
795
+ self.in_channels = in_channels
796
+ self.out_channels = out_channels
797
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
798
+ time_embed_dim = channels[0] * 4
799
+ self.time_mlp = TimestepEmbedding(
800
+ in_channels=in_channels,
801
+ time_embed_dim=time_embed_dim,
802
+ act_fn="silu",
803
+ )
804
+ self.static_chunk_size = static_chunk_size
805
+ self.num_decoding_left_chunks = num_decoding_left_chunks
806
+ self.down_blocks = nn.ModuleList([])
807
+ self.mid_blocks = nn.ModuleList([])
808
+ self.up_blocks = nn.ModuleList([])
809
+
810
+ output_channel = in_channels
811
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
812
+ input_channel = output_channel
813
+ output_channel = channels[i]
814
+ is_last = i == len(channels) - 1
815
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
816
+ transformer_blocks = nn.ModuleList(
817
+ [
818
+ BasicTransformerBlock(
819
+ dim=output_channel,
820
+ num_attention_heads=num_heads,
821
+ attention_head_dim=attention_head_dim,
822
+ dropout=dropout,
823
+ activation_fn=act_fn,
824
+ )
825
+ for _ in range(n_blocks)
826
+ ]
827
+ )
828
+ downsample = (
829
+ Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
830
+ )
831
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
832
+
833
+ for _ in range(num_mid_blocks):
834
+ input_channel = channels[-1]
835
+ out_channels = channels[-1]
836
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
837
+
838
+ transformer_blocks = nn.ModuleList(
839
+ [
840
+ BasicTransformerBlock(
841
+ dim=output_channel,
842
+ num_attention_heads=num_heads,
843
+ attention_head_dim=attention_head_dim,
844
+ dropout=dropout,
845
+ activation_fn=act_fn,
846
+ )
847
+ for _ in range(n_blocks)
848
+ ]
849
+ )
850
+
851
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
852
+
853
+ channels = channels[::-1] + (channels[0],)
854
+ for i in range(len(channels) - 1):
855
+ input_channel = channels[i] * 2
856
+ output_channel = channels[i + 1]
857
+ is_last = i == len(channels) - 2
858
+ resnet = CausalResnetBlock1D(
859
+ dim=input_channel,
860
+ dim_out=output_channel,
861
+ time_emb_dim=time_embed_dim,
862
+ )
863
+ transformer_blocks = nn.ModuleList(
864
+ [
865
+ BasicTransformerBlock(
866
+ dim=output_channel,
867
+ num_attention_heads=num_heads,
868
+ attention_head_dim=attention_head_dim,
869
+ dropout=dropout,
870
+ activation_fn=act_fn,
871
+ )
872
+ for _ in range(n_blocks)
873
+ ]
874
+ )
875
+ upsample = (
876
+ Upsample1D(output_channel, use_conv_transpose=True)
877
+ if not is_last
878
+ else CausalConv1d(output_channel, output_channel, 3)
879
+ )
880
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
881
+ self.final_block = CausalBlock1D(channels[-1], channels[-1])
882
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
883
+ self.initialize_weights()
884
+
885
+ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
886
+ """Forward pass of the UNet1DConditional model.
887
+
888
+ Args:
889
+ x (torch.Tensor): shape (batch_size, in_channels, time)
890
+ mask (torch.Tensor): shape (batch_size, 1, time)
891
+ t (torch.Tensor): shape (batch_size)
892
+ spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
893
+ cond (torch.Tensor, optional): prompt condition features, packed with the input. Defaults to None.
894
+
895
+ Returns:
900
+ torch.Tensor: output of shape (batch_size, out_channels, time)
901
+ """
902
+ t = self.time_embeddings(t).to(t.dtype)
903
+ t = self.time_mlp(t)
904
+
905
+ x = pack([x, mu], "b * t")[0]
906
+
907
+ if spks is not None:
908
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
909
+ x = pack([x, spks], "b * t")[0]
910
+ if cond is not None:
911
+ x = pack([x, cond], "b * t")[0]
912
+
913
+ hiddens = []
914
+ masks = [mask]
915
+ for resnet, transformer_blocks, downsample in self.down_blocks:
916
+ mask_down = masks[-1]
917
+ x = resnet(x, mask_down, t)
918
+ x = rearrange(x, "b c t -> b t c").contiguous()
919
+ if streaming is True:
920
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
921
+ else:
922
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
923
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
924
+ for transformer_block in transformer_blocks:
925
+ x = transformer_block(
926
+ hidden_states=x,
927
+ attention_mask=attn_mask,
928
+ timestep=t,
929
+ )
930
+ x = rearrange(x, "b t c -> b c t").contiguous()
931
+ hiddens.append(x) # Save hidden states for skip connections
932
+ x = downsample(x * mask_down)
933
+ masks.append(mask_down[:, :, ::2])
934
+ masks = masks[:-1]
935
+ mask_mid = masks[-1]
936
+
937
+ for resnet, transformer_blocks in self.mid_blocks:
938
+ x = resnet(x, mask_mid, t)
939
+ x = rearrange(x, "b c t -> b t c").contiguous()
940
+ if streaming is True:
941
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
942
+ else:
943
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
944
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
945
+ for transformer_block in transformer_blocks:
946
+ x = transformer_block(
947
+ hidden_states=x,
948
+ attention_mask=attn_mask,
949
+ timestep=t,
950
+ )
951
+ x = rearrange(x, "b t c -> b c t").contiguous()
952
+
953
+ for resnet, transformer_blocks, upsample in self.up_blocks:
954
+ mask_up = masks.pop()
955
+ skip = hiddens.pop()
956
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
957
+ x = resnet(x, mask_up, t)
958
+ x = rearrange(x, "b c t -> b t c").contiguous()
959
+ if streaming is True:
960
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
961
+ else:
962
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
963
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
964
+ for transformer_block in transformer_blocks:
965
+ x = transformer_block(
966
+ hidden_states=x,
967
+ attention_mask=attn_mask,
968
+ timestep=t,
969
+ )
970
+ x = rearrange(x, "b t c -> b c t").contiguous()
971
+ x = upsample(x * mask_up)
972
+ x = self.final_block(x, mask_up)
973
+ output = self.final_proj(x * mask_up)
974
+ return output * mask
soulxpodcast/models/modules/flow_components/upsample_encoder.py ADDED
@@ -0,0 +1,998 @@
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ def subsequent_chunk_mask(
10
+ size: int,
11
+ chunk_size: int,
12
+ num_left_chunks: int = -1,
13
+ device: torch.device = torch.device("cpu"),
14
+ ) -> torch.Tensor:
15
+ """Create mask for subsequent steps (size, size) with chunk size,
16
+ this is for streaming encoder
17
+
18
+ Args:
19
+ size (int): size of mask
20
+ chunk_size (int): size of chunk
21
+ num_left_chunks (int): number of left chunks
22
+ <0: use full chunk
23
+ >=0: use num_left_chunks
24
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
25
+
26
+ Returns:
27
+ torch.Tensor: mask
28
+
29
+ Examples:
30
+ >>> subsequent_chunk_mask(4, 2)
31
+ [[1, 1, 0, 0],
32
+ [1, 1, 0, 0],
33
+ [1, 1, 1, 1],
34
+ [1, 1, 1, 1]]
35
+ """
36
+ # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
37
+ pos_idx = torch.arange(size, device=device)
38
+ block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
39
+ ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
40
+ return ret
41
+
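+ # Illustrative sketch (hypothetical helper, not part of this module): for
+ # size=6 and chunk_size=2 every position can attend up to the end of its own
+ # chunk, e.g. row 0 is [1, 1, 0, 0, 0, 0] and row 2 is [1, 1, 1, 1, 0, 0].
+ def _demo_subsequent_chunk_mask():
+     mask = subsequent_chunk_mask(6, 2)
+     assert mask.shape == (6, 6)
+     assert bool(mask[0, 1]) and not bool(mask[0, 2])
+     return mask
+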
42
+
43
+ def add_optional_chunk_mask(xs: torch.Tensor,
44
+ masks: torch.Tensor,
45
+ use_dynamic_chunk: bool,
46
+ use_dynamic_left_chunk: bool,
47
+ decoding_chunk_size: int,
48
+ static_chunk_size: int,
49
+ num_decoding_left_chunks: int,
50
+ enable_full_context: bool = True):
51
+ """ Apply optional mask for encoder.
52
+
53
+ Args:
54
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
55
+ mask (torch.Tensor): mask for xs, (B, 1, L)
56
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
57
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
58
+ training.
59
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
60
+ 0: default for training, use random dynamic chunk.
61
+ <0: for decoding, use full chunk.
62
+ >0: for decoding, use fixed chunk size as set.
63
+ static_chunk_size (int): chunk size for static chunk training/decoding
64
+ if it's greater than 0, if use_dynamic_chunk is true,
65
+ this parameter will be ignored
66
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
67
+ the chunk size is decoding_chunk_size.
68
+ >=0: use num_decoding_left_chunks
69
+ <0: use all left chunks
70
+ enable_full_context (bool):
71
+ True: chunk size is either [1, 25] or full context(max_len)
72
+ False: chunk size ~ U[1, 25]
73
+
74
+ Returns:
75
+ torch.Tensor: chunk mask of the input xs.
76
+ """
77
+ # Whether to use chunk mask or not
78
+ if use_dynamic_chunk:
79
+ max_len = xs.size(1)
80
+ if decoding_chunk_size < 0:
81
+ chunk_size = max_len
82
+ num_left_chunks = -1
83
+ elif decoding_chunk_size > 0:
84
+ chunk_size = decoding_chunk_size
85
+ num_left_chunks = num_decoding_left_chunks
86
+ else:
87
+ # chunk size is either [1, 25] or full context(max_len).
88
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
89
+ # delay, the maximum frame is 100 / 4 = 25.
90
+ chunk_size = torch.randint(1, max_len, (1, )).item()
91
+ num_left_chunks = -1
92
+ if chunk_size > max_len // 2 and enable_full_context:
93
+ chunk_size = max_len
94
+ else:
95
+ chunk_size = chunk_size % 25 + 1
96
+ if use_dynamic_left_chunk:
97
+ max_left_chunks = (max_len - 1) // chunk_size
98
+ num_left_chunks = torch.randint(0, max_left_chunks,
99
+ (1, )).item()
100
+ chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
101
+ num_left_chunks,
102
+ xs.device) # (L, L)
103
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
104
+ chunk_masks = masks & chunk_masks # (B, L, L)
105
+ elif static_chunk_size > 0:
106
+ num_left_chunks = num_decoding_left_chunks
107
+ chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
108
+ num_left_chunks,
109
+ xs.device) # (L, L)
110
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
111
+ chunk_masks = masks & chunk_masks # (B, L, L)
112
+ else:
113
+ chunk_masks = masks
114
+ assert chunk_masks.dtype == torch.bool
115
+ if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
116
+ print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in future computation!')
117
+ chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
118
+ return chunk_masks
119
+
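+ # Illustrative sketch (hypothetical helper, not part of this module): with a
+ # positive static_chunk_size the result is a (B, L, L) block-causal mask, as
+ # used for streaming below; otherwise it is just the (B, 1, L) padding mask
+ # passed through unchanged.
+ def _demo_add_optional_chunk_mask():
+     xs = torch.zeros(1, 8, 4)                    # (B, L, D)
+     pad = torch.ones(1, 1, 8, dtype=torch.bool)  # (B, 1, L), no padding
+     chunked = add_optional_chunk_mask(xs, pad, False, False, 0, 2, -1)
+     full = add_optional_chunk_mask(xs, pad, False, False, 0, 0, -1)
+     return chunked.shape, full.shape             # (1, 8, 8) and (1, 1, 8)
+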
120
+
121
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
122
+ """Make mask tensor containing indices of padded part.
123
+
124
+ See description of make_non_pad_mask.
125
+
126
+ Args:
127
+ lengths (torch.Tensor): Batch of lengths (B,).
128
+ Returns:
129
+ torch.Tensor: Mask tensor containing indices of padded part.
130
+
131
+ Examples:
132
+ >>> lengths = [5, 3, 2]
133
+ >>> make_pad_mask(lengths)
134
+ masks = [[0, 0, 0, 0 ,0],
135
+ [0, 0, 0, 1, 1],
136
+ [0, 0, 1, 1, 1]]
137
+ """
138
+ batch_size = lengths.size(0)
139
+ max_len = max_len if max_len > 0 else lengths.max().item()
140
+ seq_range = torch.arange(0,
141
+ max_len,
142
+ dtype=torch.int64,
143
+ device=lengths.device)
144
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
145
+ seq_length_expand = lengths.unsqueeze(-1)
146
+ mask = seq_range_expand >= seq_length_expand
147
+ return mask
148
+
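+ # Illustrative sketch (hypothetical helper, not part of this module): the
+ # encoder below uses the negation of make_pad_mask, e.g. lengths [3, 1] give a
+ # pad mask [[F, F, F], [F, T, T]] and a non-pad mask [[T, T, T], [T, F, F]].
+ def _demo_make_pad_mask():
+     lengths = torch.tensor([3, 1])
+     pad = make_pad_mask(lengths)       # (B, T), True at padded positions
+     non_pad = ~pad.unsqueeze(1)        # (B, 1, T), as used in forward()
+     return pad, non_pad
+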
149
+
150
+ class EspnetRelPositionalEncoding(torch.nn.Module):
151
+ """Relative positional encoding module (new implementation).
152
+
153
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
154
+
155
+ See : Appendix B in https://arxiv.org/abs/1901.02860
156
+
157
+ Args:
158
+ d_model (int): Embedding dimension.
159
+ max_len (int): Maximum input length.
160
+
161
+ """
162
+
163
+ def __init__(self, d_model: int, max_len: int = 5000):
164
+ super(EspnetRelPositionalEncoding, self).__init__()
165
+ self.d_model = d_model
166
+ self.xscale = math.sqrt(self.d_model)
167
+ self.pe = None
168
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
169
+
170
+ def extend_pe(self, x: torch.Tensor):
171
+ """Reset the positional encodings."""
172
+ if self.pe is not None:
173
+ # self.pe contains both positive and negative parts
174
+ # the length of self.pe is 2 * input_len - 1
175
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
176
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
177
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
178
+ return
179
+ # Suppose `i` means the position of the query vector and `j` means the
180
+ # position of the key vector. We use positive relative positions when keys
181
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
182
+ pe_positive = torch.zeros(x.size(1), self.d_model)
183
+ pe_negative = torch.zeros(x.size(1), self.d_model)
184
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
185
+ div_term = torch.exp(
186
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
187
+ * -(math.log(10000.0) / self.d_model)
188
+ )
189
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
190
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
191
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
192
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
193
+
194
+ # Reverse the order of positive indices and concat both positive and
195
+ # negative indices. This is used to support the shifting trick
196
+ # as in https://arxiv.org/abs/1901.02860
197
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
198
+ pe_negative = pe_negative[1:].unsqueeze(0)
199
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
200
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
201
+
202
+ def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
203
+ -> Tuple[torch.Tensor, torch.Tensor]:
204
+ """Add positional encoding.
205
+
206
+ Args:
207
+ x (torch.Tensor): Input tensor (batch, time, `*`).
208
+
209
+ Returns:
210
+ torch.Tensor: Encoded tensor (batch, time, `*`).
211
+
212
+ """
213
+ self.extend_pe(x)
214
+ x = x * self.xscale
215
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
216
+ return x, pos_emb
217
+
218
+ def position_encoding(self,
219
+ offset: Union[int, torch.Tensor],
220
+ size: int) -> torch.Tensor:
221
+ """ For getting encoding in a streaming fashion
222
+
223
+ Attention!!!!!
224
+ we apply dropout only once at the whole utterance level in a
225
+ non-streaming way, but will call this function several times with
226
+ increasing input size in a streaming scenario, so the dropout will
227
+ be applied several times.
228
+
229
+ Args:
230
+ offset (int or torch.tensor): start offset
231
+ size (int): required size of position encoding
232
+
233
+ Returns:
234
+ torch.Tensor: Corresponding encoding
235
+ """
236
+ # How to subscript a Union type:
237
+ # https://github.com/pytorch/pytorch/issues/69434
238
+ if isinstance(offset, int):
239
+ pos_emb = self.pe[
240
+ :,
241
+ self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
242
+ ]
243
+ elif isinstance(offset, torch.Tensor):
244
+ pos_emb = self.pe[
245
+ :,
246
+ self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
247
+ ]
248
+ return pos_emb
249
+
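+ # Illustrative sketch (hypothetical helper, not part of this module): for an
+ # input of length T the relative table spans positions T-1 ... -(T-1), so the
+ # returned pos_emb has 2*T - 1 frames while x is scaled by sqrt(d_model).
+ def _demo_rel_positional_encoding():
+     enc = EspnetRelPositionalEncoding(d_model=8, max_len=16)
+     x = torch.randn(2, 5, 8)
+     x_scaled, pos_emb = enc(x)
+     return x_scaled.shape, pos_emb.shape   # (2, 5, 8) and (1, 9, 8)
+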
250
+
251
+ class LinearNoSubsampling(torch.nn.Module):
252
+ """Linear transform the input without subsampling
253
+
254
+ Args:
255
+ idim (int): Input dimension.
256
+ odim (int): Output dimension.
257
+ pos_enc_class (torch.nn.Module): Positional encoding class.
258
+
259
+ """
260
+
261
+ def __init__(self, idim: int, odim: int,
262
+ pos_enc_class: torch.nn.Module):
263
+ super().__init__()
264
+ self.out = torch.nn.Sequential(
265
+ torch.nn.Linear(idim, odim),
266
+ torch.nn.LayerNorm(odim, eps=1e-5),
267
+ )
268
+ self.pos_enc = pos_enc_class
269
+ self.right_context = 0
270
+ self.subsampling_rate = 1
271
+
272
+ def forward(
273
+ self,
274
+ x: torch.Tensor,
275
+ x_mask: torch.Tensor,
276
+ offset: Union[int, torch.Tensor] = 0
277
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
278
+ """Input x.
279
+
280
+ Args:
281
+ x (torch.Tensor): Input tensor (#batch, time, idim).
282
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
283
+
284
+ Returns:
285
+ torch.Tensor: linear input tensor (#batch, time', odim),
286
+ where time' = time .
287
+ torch.Tensor: linear input mask (#batch, 1, time'),
288
+ where time' = time .
289
+
290
+ """
291
+ x = self.out(x)
292
+ x, pos_emb = self.pos_enc(x, offset)
293
+ return x, pos_emb, x_mask
294
+
295
+ def position_encoding(self, offset: Union[int, torch.Tensor],
296
+ size: int) -> torch.Tensor:
297
+ return self.pos_enc.position_encoding(offset, size)
298
+
299
+
300
+ class Upsample1D(nn.Module):
301
+ """A 1D upsampling layer with an optional convolution.
302
+
303
+ Parameters:
304
+ channels (`int`):
305
+ number of channels in the inputs and outputs.
306
+ out_channels (`int`):
307
+ number of output channels.
308
+ stride (`int`, default `2`):
309
+ upsampling factor; the input is nearest-neighbour interpolated by this factor
310
+ and then refined with a stride-1 convolution.
312
+ """
313
+
314
+ def __init__(self, channels: int, out_channels: int, stride: int = 2):
315
+ super().__init__()
316
+ self.channels = channels
317
+ self.out_channels = out_channels
318
+ self.stride = stride
319
+ # In this mode, first repeat-interpolate, then conv with stride=1
320
+ self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
321
+
322
+ def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
323
+ outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
324
+ outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
325
+ outputs = self.conv(outputs)
326
+ return outputs, input_lengths * self.stride
327
+
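+ # Illustrative sketch (hypothetical helper, not part of this module): with
+ # stride=2 the layer doubles the time axis via nearest-neighbour interpolation,
+ # left-pads by stride * 2 and applies a kernel-(stride * 2 + 1), stride-1 conv,
+ # so both the feature length and the reported lengths are doubled.
+ def _demo_upsample1d():
+     up = Upsample1D(channels=4, out_channels=4, stride=2)
+     outputs, lengths = up(torch.randn(1, 4, 10), torch.tensor([10]))
+     return outputs.shape, lengths          # (1, 4, 20) and tensor([20])
+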
328
+
329
+ class PreLookaheadLayer(nn.Module):
330
+ def __init__(self, channels: int, pre_lookahead_len: int = 1):
331
+ super().__init__()
332
+ self.channels = channels
333
+ self.pre_lookahead_len = pre_lookahead_len
334
+ self.conv1 = nn.Conv1d(
335
+ channels, channels,
336
+ kernel_size=pre_lookahead_len + 1,
337
+ stride=1, padding=0,
338
+ )
339
+ self.conv2 = nn.Conv1d(
340
+ channels, channels,
341
+ kernel_size=3, stride=1, padding=0,
342
+ )
343
+
344
+ def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
345
+ """
346
+ inputs: (batch_size, seq_len, channels)
347
+ """
348
+ outputs = inputs.transpose(1, 2).contiguous()
349
+ context = context.transpose(1, 2).contiguous()
350
+ # look ahead
351
+ if context.size(2) == 0:
352
+ outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
353
+ else:
354
+ assert self.training is False, 'you have passed context, make sure that you are running inference mode'
355
+ assert context.size(2) == self.pre_lookahead_len
356
+ outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
357
+ outputs = F.leaky_relu(self.conv1(outputs))
358
+ # outputs
359
+ outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
360
+ outputs = self.conv2(outputs)
361
+ outputs = outputs.transpose(1, 2).contiguous()
362
+
363
+ # residual connection
364
+ outputs = outputs + inputs
365
+ return outputs
366
+
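+ # Illustrative sketch (hypothetical helper, not part of this module): the layer
+ # peeks pre_lookahead_len frames ahead (zero-padded when no context is given)
+ # and, thanks to the residual add, keeps the (batch, seq_len, channels) shape.
+ def _demo_pre_lookahead_layer():
+     layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
+     return layer(torch.randn(2, 10, 512)).shape   # torch.Size([2, 10, 512])
+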
367
+
368
+ class MultiHeadedAttention(nn.Module):
369
+ """Multi-Head Attention layer.
370
+
371
+ Args:
372
+ n_head (int): The number of heads.
373
+ n_feat (int): The number of features.
374
+ dropout_rate (float): Dropout rate.
375
+ key_bias (bool): Whether to use bias in key linear layer.
376
+
377
+ """
378
+
379
+ def __init__(self,
380
+ n_head: int,
381
+ n_feat: int,
382
+ dropout_rate: float,
383
+ key_bias: bool = True):
384
+ super().__init__()
385
+ assert n_feat % n_head == 0
386
+ # We assume d_v always equals d_k
387
+ self.d_k = n_feat // n_head
388
+ self.h = n_head
389
+ self.linear_q = nn.Linear(n_feat, n_feat)
390
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
391
+ self.linear_v = nn.Linear(n_feat, n_feat)
392
+ self.linear_out = nn.Linear(n_feat, n_feat)
393
+ self.dropout = nn.Dropout(p=dropout_rate)
394
+
395
+ def forward_qkv(
396
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
397
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
398
+ """Transform query, key and value.
399
+
400
+ Args:
401
+ query (torch.Tensor): Query tensor (#batch, time1, size).
402
+ key (torch.Tensor): Key tensor (#batch, time2, size).
403
+ value (torch.Tensor): Value tensor (#batch, time2, size).
404
+
405
+ Returns:
406
+ torch.Tensor: Transformed query tensor, size
407
+ (#batch, n_head, time1, d_k).
408
+ torch.Tensor: Transformed key tensor, size
409
+ (#batch, n_head, time2, d_k).
410
+ torch.Tensor: Transformed value tensor, size
411
+ (#batch, n_head, time2, d_k).
412
+
413
+ """
414
+ n_batch = query.size(0)
415
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
416
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
417
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
418
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
419
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
420
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
421
+
422
+ return q, k, v
423
+
424
+ def forward_attention(
425
+ self,
426
+ value: torch.Tensor,
427
+ scores: torch.Tensor,
428
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
429
+ ) -> torch.Tensor:
430
+ """Compute attention context vector.
431
+
432
+ Args:
433
+ value (torch.Tensor): Transformed value, size
434
+ (#batch, n_head, time2, d_k).
435
+ scores (torch.Tensor): Attention score, size
436
+ (#batch, n_head, time1, time2).
437
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
438
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
439
+
440
+ Returns:
441
+ torch.Tensor: Transformed value (#batch, time1, d_model)
442
+ weighted by the attention score (#batch, time1, time2).
443
+
444
+ """
445
+ n_batch = value.size(0)
446
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
447
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
448
+ # 1st chunk to ease the onnx export.]
449
+ # 2. pytorch training
450
+ if mask.size(2) > 0: # time2 > 0
451
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
452
+ # For last chunk, time2 might be larger than scores.size(-1)
453
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
454
+ scores = scores.masked_fill(mask, -float('inf'))
455
+ attn = torch.softmax(scores, dim=-1).masked_fill(
456
+ mask, 0.0) # (batch, head, time1, time2)
457
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
458
+ # 1. onnx(16/-1, -1/-1, 16/0)
459
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
460
+ else:
461
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
462
+
463
+ p_attn = self.dropout(attn)
464
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
465
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
466
+ self.h * self.d_k)
467
+ ) # (batch, time1, d_model)
468
+
469
+ return self.linear_out(x) # (batch, time1, d_model)
470
+
471
+ def forward(
472
+ self,
473
+ query: torch.Tensor,
474
+ key: torch.Tensor,
475
+ value: torch.Tensor,
476
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
477
+ pos_emb: torch.Tensor = torch.empty(0),
478
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
479
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
480
+ """Compute scaled dot product attention.
481
+
482
+ Args:
483
+ query (torch.Tensor): Query tensor (#batch, time1, size).
484
+ key (torch.Tensor): Key tensor (#batch, time2, size).
485
+ value (torch.Tensor): Value tensor (#batch, time2, size).
486
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
487
+ (#batch, time1, time2).
488
+ 1.When applying cross attention between decoder and encoder,
489
+ the batch padding mask for input is in (#batch, 1, T) shape.
490
+ 2.When applying self attention of encoder,
491
+ the mask is in (#batch, T, T) shape.
492
+ 3.When applying self attention of decoder,
493
+ the mask is in (#batch, L, L) shape.
494
+ 4.If the different position in decoder see different block
495
+ of the encoder, such as Mocha, the passed in mask could be
496
+ in (#batch, L, T) shape. But there is no such case in current
497
+ CosyVoice.
498
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
499
+ where `cache_t == chunk_size * num_decoding_left_chunks`
500
+ and `head * d_k == size`
501
+
502
+
503
+ Returns:
504
+ torch.Tensor: Output tensor (#batch, time1, d_model).
505
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
506
+ where `cache_t == chunk_size * num_decoding_left_chunks`
507
+ and `head * d_k == size`
508
+
509
+ """
510
+ q, k, v = self.forward_qkv(query, key, value)
511
+
512
+ # NOTE(xcsong):
513
+ # when export onnx model, for 1st chunk, we feed
514
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
515
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
516
+ # In all modes, `if cache.size(0) > 0` will always be `True`
516
+ # and we will always do splitting and
517
+ # concatenation (this will simplify onnx export). Note that
519
+ # it's OK to concat & split zero-shaped tensors(see code below).
520
+ # when export jit model, for 1st chunk, we always feed
521
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
522
+ # >>> a = torch.ones((1, 2, 0, 4))
523
+ # >>> b = torch.ones((1, 2, 3, 4))
524
+ # >>> c = torch.cat((a, b), dim=2)
525
+ # >>> torch.equal(b, c) # True
526
+ # >>> d = torch.split(a, 2, dim=-1)
527
+ # >>> torch.equal(d[0], d[1]) # True
528
+ if cache.size(0) > 0:
529
+ key_cache, value_cache = torch.split(cache,
530
+ cache.size(-1) // 2,
531
+ dim=-1)
532
+ k = torch.cat([key_cache, k], dim=2)
533
+ v = torch.cat([value_cache, v], dim=2)
534
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
535
+ # non-trivial to calculate `next_cache_start` here.
536
+ new_cache = torch.cat((k, v), dim=-1)
537
+
538
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
539
+ return self.forward_attention(v, scores, mask), new_cache
540
+
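+ # Illustrative sketch (hypothetical helper, not part of this module): the
+ # returned cache is k and v concatenated on the last dim, shaped
+ # (B, head, cache_t + time, d_k * 2); feeding it back on the next chunk simply
+ # prepends the cached keys and values before the scores are computed.
+ def _demo_attention_cache():
+     attn = MultiHeadedAttention(n_head=2, n_feat=8, dropout_rate=0.0)
+     x1, x2 = torch.randn(1, 4, 8), torch.randn(1, 3, 8)
+     _, cache = attn(x1, x1, x1)                   # cache: (1, 2, 4, 8)
+     out, cache = attn(x2, x2, x2, cache=cache)    # cache: (1, 2, 7, 8)
+     return out.shape, cache.shape
+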
541
+
542
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
543
+ """Multi-Head Attention layer with relative position encoding.
544
+ Paper: https://arxiv.org/abs/1901.02860
545
+ Args:
546
+ n_head (int): The number of heads.
547
+ n_feat (int): The number of features.
548
+ dropout_rate (float): Dropout rate.
549
+ key_bias (bool): Whether to use bias in key linear layer.
550
+ """
551
+
552
+ def __init__(self,
553
+ n_head: int,
554
+ n_feat: int,
555
+ dropout_rate: float,
556
+ key_bias: bool = True):
557
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
558
+ # linear transformation for positional encoding
559
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
560
+ # these two learnable biases are used in matrix c and matrix d
561
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
562
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
563
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
564
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
565
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
566
+
567
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
568
+ """Compute relative positional encoding.
569
+
570
+ Args:
571
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
572
+ time1 means the length of query vector.
573
+
574
+ Returns:
575
+ torch.Tensor: Output tensor.
576
+
577
+ """
578
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
579
+ device=x.device,
580
+ dtype=x.dtype)
581
+ x_padded = torch.cat([zero_pad, x], dim=-1)
582
+
583
+ x_padded = x_padded.view(x.size()[0],
584
+ x.size()[1],
585
+ x.size(3) + 1, x.size(2))
586
+ x = x_padded[:, :, 1:].view_as(x)[
587
+ :, :, :, : x.size(-1) // 2 + 1
588
+ ] # only keep the positions from 0 to time2
589
+ return x
590
+
591
+ def forward(
592
+ self,
593
+ query: torch.Tensor,
594
+ key: torch.Tensor,
595
+ value: torch.Tensor,
596
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
597
+ pos_emb: torch.Tensor = torch.empty(0),
598
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
599
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
600
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
601
+ Args:
602
+ query (torch.Tensor): Query tensor (#batch, time1, size).
603
+ key (torch.Tensor): Key tensor (#batch, time2, size).
604
+ value (torch.Tensor): Value tensor (#batch, time2, size).
605
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
606
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
607
+ pos_emb (torch.Tensor): Positional embedding tensor
608
+ (#batch, time2, size).
609
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
610
+ where `cache_t == chunk_size * num_decoding_left_chunks`
611
+ and `head * d_k == size`
612
+ Returns:
613
+ torch.Tensor: Output tensor (#batch, time1, d_model).
614
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
615
+ where `cache_t == chunk_size * num_decoding_left_chunks`
616
+ and `head * d_k == size`
617
+ """
618
+ q, k, v = self.forward_qkv(query, key, value)
619
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
620
+
621
+ # NOTE(xcsong):
622
+ # when export onnx model, for 1st chunk, we feed
623
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
624
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
625
+ # In all modes, `if cache.size(0) > 0` will always be `True`
626
+ # and we will always do splitting and
627
+ # concatenation (this will simplify onnx export). Note that
628
+ # it's OK to concat & split zero-shaped tensors(see code below).
629
+ # when export jit model, for 1st chunk, we always feed
630
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
631
+ # >>> a = torch.ones((1, 2, 0, 4))
632
+ # >>> b = torch.ones((1, 2, 3, 4))
633
+ # >>> c = torch.cat((a, b), dim=2)
634
+ # >>> torch.equal(b, c) # True
635
+ # >>> d = torch.split(a, 2, dim=-1)
636
+ # >>> torch.equal(d[0], d[1]) # True
637
+ if cache.size(0) > 0:
638
+ key_cache, value_cache = torch.split(cache,
639
+ cache.size(-1) // 2,
640
+ dim=-1)
641
+ k = torch.cat([key_cache, k], dim=2)
642
+ v = torch.cat([value_cache, v], dim=2)
643
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
644
+ # non-trivial to calculate `next_cache_start` here.
645
+ new_cache = torch.cat((k, v), dim=-1)
646
+
647
+ n_batch_pos = pos_emb.size(0)
648
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
649
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
650
+
651
+ # (batch, head, time1, d_k)
652
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
653
+ # (batch, head, time1, d_k)
654
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
655
+
656
+ # compute attention score
657
+ # first compute matrix a and matrix c
658
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
659
+ # (batch, head, time1, time2)
660
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
661
+
662
+ # compute matrix b and matrix d
663
+ # (batch, head, time1, time2)
664
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
665
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
666
+ if matrix_ac.shape != matrix_bd.shape:
667
+ matrix_bd = self.rel_shift(matrix_bd)
668
+
669
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
670
+ self.d_k) # (batch, head, time1, time2)
671
+
672
+ return self.forward_attention(v, scores, mask), new_cache
673
+
674
+
675
+ class PositionwiseFeedForward(torch.nn.Module):
676
+ """Positionwise feed forward layer.
677
+
678
+ FeedForward is applied on each position of the sequence.
679
+ The output dim is the same as the input dim.
680
+
681
+ Args:
682
+ idim (int): Input dimension.
683
+ hidden_units (int): The number of hidden units.
684
+ dropout_rate (float): Dropout rate.
685
+ activation (torch.nn.Module): Activation function
686
+ """
687
+
688
+ def __init__(
689
+ self,
690
+ idim: int,
691
+ hidden_units: int,
692
+ dropout_rate: float,
693
+ activation: torch.nn.Module = torch.nn.ReLU(),
694
+ ):
695
+ super(PositionwiseFeedForward, self).__init__()
696
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
697
+ self.activation = activation
698
+ self.dropout = torch.nn.Dropout(dropout_rate)
699
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
700
+
701
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
702
+ """Forward function.
703
+
704
+ Args:
705
+ xs: input tensor (B, L, D)
706
+ Returns:
707
+ output tensor, (B, L, D)
708
+ """
709
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
710
+
711
+
712
+ class ConformerEncoderLayer(nn.Module):
713
+ """Encoder layer module.
714
+ Args:
715
+ size (int): Input dimension.
716
+ self_attn (torch.nn.Module): Self-attention module instance.
717
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
718
+ instance can be used as the argument.
719
+ feed_forward (torch.nn.Module): Feed-forward module instance.
720
+ `PositionwiseFeedForward` instance can be used as the argument.
721
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
722
+ instance.
723
+ `PositionwiseFeedForward` instance can be used as the argument.
724
+ conv_module (torch.nn.Module): Convolution module instance.
725
+ `ConvolutionModule` instance can be used as the argument.
726
+ dropout_rate (float): Dropout rate.
727
+ normalize_before (bool):
728
+ True: use layer_norm before each sub-block.
729
+ False: use layer_norm after each sub-block.
730
+ """
731
+
732
+ def __init__(
733
+ self,
734
+ size: int,
735
+ self_attn: torch.nn.Module,
736
+ feed_forward: Optional[nn.Module] = None,
737
+ feed_forward_macaron: Optional[nn.Module] = None,
738
+ conv_module: Optional[nn.Module] = None,
739
+ dropout_rate: float = 0.0,
740
+ normalize_before: bool = True,
741
+ ):
742
+ super().__init__()
743
+ self.self_attn = self_attn
744
+ self.feed_forward = feed_forward
745
+ self.feed_forward_macaron = feed_forward_macaron
746
+ self.conv_module = conv_module
747
+ self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
748
+ self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
749
+ if feed_forward_macaron is not None:
750
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
751
+ self.ff_scale = 0.5
752
+ else:
753
+ self.ff_scale = 1.0
754
+ if self.conv_module is not None:
755
+ self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
756
+ self.norm_final = nn.LayerNorm(
757
+ size, eps=1e-12) # for the final output of the block
758
+ self.dropout = nn.Dropout(dropout_rate)
759
+ self.size = size
760
+ self.normalize_before = normalize_before
761
+
762
+ def forward(
763
+ self,
764
+ x: torch.Tensor,
765
+ mask: torch.Tensor,
766
+ pos_emb: torch.Tensor,
767
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
768
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
769
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
770
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
771
+ """Compute encoded features.
772
+
773
+ Args:
774
+ x (torch.Tensor): (#batch, time, size)
775
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
776
+ (0, 0, 0) means fake mask.
777
+ pos_emb (torch.Tensor): positional encoding, must not be None
778
+ for ConformerEncoderLayer.
779
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
780
+ (#batch, 1,time), (0, 0, 0) means fake mask.
781
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
782
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
783
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
784
+ (#batch=1, size, cache_t2)
785
+ Returns:
786
+ torch.Tensor: Output tensor (#batch, time, size).
787
+ torch.Tensor: Mask tensor (#batch, time, time).
788
+ torch.Tensor: att_cache tensor,
789
+ (#batch=1, head, cache_t1 + time, d_k * 2).
790
+ torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
791
+ """
792
+
793
+ # whether to use macaron style
794
+ if self.feed_forward_macaron is not None:
795
+ residual = x
796
+ if self.normalize_before:
797
+ x = self.norm_ff_macaron(x)
798
+ x = residual + self.ff_scale * self.dropout(
799
+ self.feed_forward_macaron(x))
800
+ if not self.normalize_before:
801
+ x = self.norm_ff_macaron(x)
802
+
803
+ # multi-headed self-attention module
804
+ residual = x
805
+ if self.normalize_before:
806
+ x = self.norm_mha(x)
807
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
808
+ att_cache)
809
+ x = residual + self.dropout(x_att)
810
+ if not self.normalize_before:
811
+ x = self.norm_mha(x)
812
+
813
+ # convolution module
814
+ # Fake new cnn cache here, and then change it in conv_module
815
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
816
+ if self.conv_module is not None:
817
+ residual = x
818
+ if self.normalize_before:
819
+ x = self.norm_conv(x)
820
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
821
+ x = residual + self.dropout(x)
822
+
823
+ if not self.normalize_before:
824
+ x = self.norm_conv(x)
825
+
826
+ # feed forward module
827
+ residual = x
828
+ if self.normalize_before:
829
+ x = self.norm_ff(x)
830
+
831
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
832
+ if not self.normalize_before:
833
+ x = self.norm_ff(x)
834
+
835
+ if self.conv_module is not None:
836
+ x = self.norm_final(x)
837
+
838
+ return x, mask, new_att_cache, new_cnn_cache
839
+
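+ # Illustrative sketch (hypothetical helper, not part of this module): a minimal
+ # layer built from the attention and feed-forward modules above; the output
+ # keeps the (batch, time, size) shape of the input.
+ def _demo_conformer_encoder_layer():
+     size = 8
+     layer = ConformerEncoderLayer(
+         size,
+         RelPositionMultiHeadedAttention(2, size, 0.0),
+         PositionwiseFeedForward(size, 16, 0.0),
+     )
+     x = torch.randn(1, 5, size)
+     mask = torch.ones(1, 5, 5, dtype=torch.bool)
+     _, pos_emb = EspnetRelPositionalEncoding(size)(x)
+     y, *_ = layer(x, mask, pos_emb)
+     return y.shape                         # torch.Size([1, 5, 8])
+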
840
+
841
+ class UpsampleConformerEncoder(torch.nn.Module):
842
+ """
843
+ Args:
844
+ input_size (int): input dim
845
+ output_size (int): dimension of attention
846
+ attention_heads (int): the number of heads of multi head attention
847
+ linear_units (int): the hidden units number of position-wise feed
848
+ forward
849
+ num_blocks (int): the number of decoder blocks
850
+ static_chunk_size (int): chunk size for static chunk training and
851
+ decoding
852
+ use_dynamic_chunk (bool): whether use dynamic chunk size for
853
+ training or not. You can only use fixed chunk (chunk_size > 0)
854
+ or dynamic chunk size (use_dynamic_chunk = True)
855
+ use_dynamic_left_chunk (bool): whether use dynamic left chunk in
856
+ dynamic chunk training
857
+ key_bias: whether to use bias in attention.linear_k; False for Whisper models.
858
+ """
859
+
860
+ def __init__(
861
+ self,
862
+ input_size: int = 512,
863
+ output_size: int = 512,
864
+ attention_heads: int = 8,
865
+ linear_units: int = 2048,
866
+ num_blocks: int = 6,
867
+ static_chunk_size: int = 25,
868
+ use_dynamic_chunk: bool = False,
869
+ use_dynamic_left_chunk: bool = False,
870
+ key_bias: bool = True,
871
+ ):
872
+ super().__init__()
873
+ self._output_size = output_size
874
+
875
+ self.embed = LinearNoSubsampling(
876
+ input_size, output_size,
877
+ EspnetRelPositionalEncoding(output_size),
878
+ )
879
+
880
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
881
+ self.static_chunk_size = static_chunk_size
882
+ self.use_dynamic_chunk = use_dynamic_chunk
883
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
884
+ activation = torch.nn.SiLU()
885
+ # self-attention module definition
886
+ encoder_selfattn_layer_args = (
887
+ attention_heads,
888
+ output_size,
889
+ 0.0,
890
+ key_bias,
891
+ )
892
+ # feed-forward module definition
893
+ positionwise_layer_args = (
894
+ output_size,
895
+ linear_units,
896
+ 0.0,
897
+ activation,
898
+ )
899
+ # pre-lookahead layer and conformer encoder definition
900
+ self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
901
+ self.encoders = torch.nn.ModuleList([
902
+ ConformerEncoderLayer(
903
+ output_size,
904
+ RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
905
+ PositionwiseFeedForward(*positionwise_layer_args),
906
+ ) for _ in range(num_blocks)
907
+ ])
908
+ self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
909
+ self.up_embed = LinearNoSubsampling(
910
+ input_size, output_size,
911
+ EspnetRelPositionalEncoding(output_size),
912
+ )
913
+ self.up_encoders = torch.nn.ModuleList([
914
+ ConformerEncoderLayer(
915
+ output_size,
916
+ RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
917
+ PositionwiseFeedForward(*positionwise_layer_args),
918
+ ) for _ in range(4)
919
+ ])
920
+
921
+ def output_size(self) -> int:
922
+ return self._output_size
923
+
924
+ def forward(
925
+ self,
926
+ xs: torch.Tensor,
927
+ xs_lens: torch.Tensor,
928
+ context: torch.Tensor = torch.zeros(0, 0, 0),
929
+ decoding_chunk_size: int = 0,
930
+ num_decoding_left_chunks: int = -1,
931
+ streaming: bool = False,
932
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
933
+ """Embed positions in tensor.
934
+
935
+ Args:
936
+ xs: padded input tensor (B, T, D)
937
+ xs_lens: input length (B)
938
+ decoding_chunk_size: decoding chunk size for dynamic chunk
939
+ 0: default for training, use random dynamic chunk.
940
+ <0: for decoding, use full chunk.
941
+ >0: for decoding, use fixed chunk size as set.
942
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
943
+ the chunk size is decoding_chunk_size.
944
+ >=0: use num_decoding_left_chunks
945
+ <0: use all left chunks
946
+ Returns:
947
+ encoder output tensor xs, and subsampled masks
948
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
949
+ masks: torch.Tensor batch padding mask after subsample
950
+ (B, 1, T' ~= T/subsample_rate)
951
+ NOTE(xcsong):
952
+ We pass the `__call__` method of the modules instead of `forward` to the
953
+ checkpointing API because `__call__` attaches all the hooks of the module.
954
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
955
+ """
956
+ T = xs.size(1)
957
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
958
+ xs, pos_emb, masks = self.embed(xs, masks)
959
+ if context.size(1) != 0:
960
+ assert self.training is False, 'you have passed context, make sure that you are running inference mode'
961
+ context_masks = torch.ones(1, 1, context.size(1)).to(masks)
962
+ context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
963
+ mask_pad = masks # (B, 1, T/subsample_rate)
964
+ chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
965
+ # lookahead + conformer encoder
966
+ xs = self.pre_lookahead_layer(xs, context=context)
967
+ xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
968
+
969
+ # upsample + conformer encoder
970
+ xs = xs.transpose(1, 2).contiguous()
971
+ xs, xs_lens = self.up_layer(xs, xs_lens)
972
+ xs = xs.transpose(1, 2).contiguous()
973
+ T = xs.size(1)
974
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
975
+ xs, pos_emb, masks = self.up_embed(xs, masks)
976
+ mask_pad = masks # (B, 1, T/subsample_rate)
977
+ chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
978
+ xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
979
+
980
+ xs = self.after_norm(xs)
981
+ # Here we assume the mask is not changed in encoder layers, so just
982
+ # return the masks before encoder layers, and the masks will be used
983
+ # for cross attention with decoder later
984
+ return xs, masks
985
+
986
+ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
987
+ pos_emb: torch.Tensor,
988
+ mask_pad: torch.Tensor) -> torch.Tensor:
989
+ for layer in self.encoders:
990
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
991
+ return xs
992
+
993
+ def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
994
+ pos_emb: torch.Tensor,
995
+ mask_pad: torch.Tensor) -> torch.Tensor:
996
+ for layer in self.up_encoders:
997
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
998
+ return xs
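+
+ # Illustrative sketch (hypothetical helper, not part of this module): token
+ # features of length T leave the encoder with 2 * T frames because of the
+ # stride-2 Upsample1D stage between the two conformer stacks.
+ def _demo_upsample_conformer_encoder():
+     enc = UpsampleConformerEncoder(input_size=512, output_size=512,
+                                    attention_heads=8, linear_units=64,
+                                    num_blocks=1)
+     ys, masks = enc(torch.randn(1, 20, 512), torch.tensor([20]))
+     return ys.shape, masks.shape           # (1, 40, 512) and (1, 1, 40)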
soulxpodcast/models/modules/hifigan.py ADDED
@@ -0,0 +1,249 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+
17
+ from typing import Dict, List
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from scipy.signal import get_window
24
+ from torch.nn import Conv1d, ConvTranspose1d
25
+ from torch.nn.utils import remove_weight_norm
26
+
27
+ try:
28
+ from torch.nn.utils.parametrizations import weight_norm
29
+ except ImportError:
30
+ from torch.nn.utils import weight_norm # noqa
31
+
32
+ from soulxpodcast.models.modules.hifigan_components.layers import (
33
+ ResBlock, SourceModuleHnNSF, SourceModuleHnNSF2, init_weights)
34
+
35
+
36
+ class ConvRNNF0Predictor(nn.Module):
37
+ def __init__(self,
38
+ num_class: int = 1,
39
+ in_channels: int = 80,
40
+ cond_channels: int = 512
41
+ ):
42
+ super().__init__()
43
+
44
+ self.num_class = num_class
45
+ self.condnet = nn.Sequential(
46
+ weight_norm( # noqa
47
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
48
+ ),
49
+ nn.ELU(),
50
+ weight_norm( # noqa
51
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
52
+ ),
53
+ nn.ELU(),
54
+ weight_norm( # noqa
55
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
56
+ ),
57
+ nn.ELU(),
58
+ weight_norm( # noqa
59
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
60
+ ),
61
+ nn.ELU(),
62
+ weight_norm( # noqa
63
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
64
+ ),
65
+ nn.ELU(),
66
+ )
67
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
68
+
69
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
70
+ x = self.condnet(x)
71
+ x = x.transpose(1, 2)
72
+ return torch.abs(self.classifier(x).squeeze(-1))
73
+
74
+
75
+ class HiFTGenerator(nn.Module):
76
+ """
77
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
78
+ https://arxiv.org/abs/2309.09493
79
+ """
80
+ def __init__(
81
+ self,
82
+ in_channels: int = 80,
83
+ base_channels: int = 512,
84
+ nb_harmonics: int = 8,
85
+ sampling_rate: int = 24000,
86
+ nsf_alpha: float = 0.1,
87
+ nsf_sigma: float = 0.003,
88
+ nsf_voiced_threshold: float = 10,
89
+ upsample_rates: List[int] = [8, 5, 3], # noqa
90
+ upsample_kernel_sizes: List[int] = [16, 11, 7], # noqa
91
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4}, # noqa
92
+ resblock_kernel_sizes: List[int] = [3, 7, 11], # noqa
93
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], # noqa
94
+ source_resblock_kernel_sizes: List[int] = [7, 7, 11], # noqa
95
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], # noqa
96
+ lrelu_slope: float = 0.1,
97
+ audio_limit: float = 0.99,
98
+ f0_predictor: torch.nn.Module = None,
99
+ ):
100
+ super(HiFTGenerator, self).__init__()
101
+
102
+ self.out_channels = 1
103
+ self.nb_harmonics = nb_harmonics
104
+ self.sampling_rate = sampling_rate
105
+ self.istft_params = istft_params
106
+ self.lrelu_slope = lrelu_slope
107
+ self.audio_limit = audio_limit
108
+
109
+ self.num_kernels = len(resblock_kernel_sizes)
110
+ self.num_upsamples = len(upsample_rates)
111
+ # NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
112
+ this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
113
+ self.m_source = this_SourceModuleHnNSF(
114
+ sampling_rate=sampling_rate,
115
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
116
+ harmonic_num=nb_harmonics,
117
+ sine_amp=nsf_alpha,
118
+ add_noise_std=nsf_sigma,
119
+ voiced_threshod=nsf_voiced_threshold)
120
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
121
+
122
+ self.conv_pre = weight_norm( # noqa
123
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
124
+ )
125
+
126
+ # Up
127
+ self.ups = nn.ModuleList()
128
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
129
+ self.ups.append(
130
+ weight_norm( # noqa
131
+ ConvTranspose1d(
132
+ base_channels // (2**i),
133
+ base_channels // (2**(i + 1)),
134
+ k,
135
+ u,
136
+ padding=(k - u) // 2,
137
+ )
138
+ )
139
+ )
140
+
141
+ # Down
142
+ self.source_downs = nn.ModuleList()
143
+ self.source_resblocks = nn.ModuleList()
144
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
145
+ downsample_cum_rates = np.cumprod(downsample_rates)
146
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
147
+ if u == 1:
148
+ self.source_downs.append(
149
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
150
+ )
151
+ else:
152
+ self.source_downs.append(
153
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
154
+ )
155
+
156
+ self.source_resblocks.append(
157
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
158
+ )
159
+
160
+ self.resblocks = nn.ModuleList()
161
+ for i in range(len(self.ups)):
162
+ ch = base_channels // (2**(i + 1))
163
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
164
+ self.resblocks.append(ResBlock(ch, k, d))
165
+
166
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)) # noqa
167
+ self.ups.apply(init_weights)
168
+ self.conv_post.apply(init_weights)
169
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
170
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
171
+ self.f0_predictor = ConvRNNF0Predictor() if f0_predictor is None else f0_predictor
172
+
173
+ def remove_weight_norm(self):
174
+ print('Removing weight norm...')
175
+ for up in self.ups:
176
+ remove_weight_norm(up)
177
+ for resblock in self.resblocks:
178
+ resblock.remove_weight_norm()
179
+ remove_weight_norm(self.conv_pre)
180
+ remove_weight_norm(self.conv_post)
181
+ self.m_source.remove_weight_norm()
182
+ for source_down in self.source_downs:
183
+ remove_weight_norm(source_down)
184
+ for source_resblock in self.source_resblocks:
185
+ source_resblock.remove_weight_norm()
186
+
187
+ def _stft(self, x):
188
+ spec = torch.stft(
189
+ x,
190
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
191
+ return_complex=True)
192
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
193
+ return spec[..., 0], spec[..., 1]
194
+
195
+ def _istft(self, magnitude, phase):
196
+ magnitude = torch.clip(magnitude, max=1e2)
197
+ real = magnitude * torch.cos(phase)
198
+ img = magnitude * torch.sin(phase)
199
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
200
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
201
+ return inverse_transform
202
+
203
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
204
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
205
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
206
+
207
+ x = self.conv_pre(x)
208
+ for i in range(self.num_upsamples):
209
+ x = F.leaky_relu(x, self.lrelu_slope)
210
+ x = self.ups[i](x)
211
+
212
+ if i == self.num_upsamples - 1:
213
+ x = self.reflection_pad(x)
214
+
215
+ # fusion
216
+ si = self.source_downs[i](s_stft)
217
+ si = self.source_resblocks[i](si)
218
+ x = x + si
219
+
220
+ xs = None
221
+ for j in range(self.num_kernels):
222
+ if xs is None:
223
+ xs = self.resblocks[i * self.num_kernels + j](x)
224
+ else:
225
+ xs += self.resblocks[i * self.num_kernels + j](x)
226
+ x = xs / self.num_kernels
227
+
228
+ x = F.leaky_relu(x)
229
+ x = self.conv_post(x)
230
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
231
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundant
232
+
233
+ x = self._istft(magnitude, phase)
234
+ x = torch.clamp(x*0.98, -self.audio_limit, self.audio_limit)
235
+ return x
236
+
237
+ @torch.inference_mode()
238
+ def forward(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
239
+ # mel->f0
240
+ f0 = self.f0_predictor(speech_feat)
241
+ # f0->source
242
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
243
+ s, _, _ = self.m_source(s)
244
+ s = s.transpose(1, 2)
245
+ # use cache_source to avoid glitch
246
+ if cache_source.shape[2] != 0:
247
+ s[:, :, :cache_source.shape[2]] = cache_source
248
+ generated_speech = self.decode(x=speech_feat, s=s)
249
+ return generated_speech, s
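+
+ # Illustrative sketch (hypothetical helper, not part of this module): the
+ # generator maps an 80-bin mel spectrogram to a waveform; with the default
+ # upsample_rates [8, 5, 3] and istft hop_len 4, each mel frame corresponds to
+ # roughly 8 * 5 * 3 * 4 = 480 output samples.
+ def _demo_hift_generator():
+     hift = HiFTGenerator()
+     mel = torch.randn(1, 80, 10)
+     audio, source = hift(speech_feat=mel)
+     return audio.shape, source.shape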
soulxpodcast/models/modules/hifigan_components/__init__.py ADDED
File without changes
soulxpodcast/models/modules/hifigan_components/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (211 Bytes). View file
 
soulxpodcast/models/modules/hifigan_components/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (199 Bytes). View file
 
soulxpodcast/models/modules/hifigan_components/__pycache__/layers.cpython-311.pyc ADDED
Binary file (20.5 kB). View file
 
soulxpodcast/models/modules/hifigan_components/__pycache__/layers.cpython-312.pyc ADDED
Binary file (18.7 kB). View file
 
soulxpodcast/models/modules/hifigan_components/layers.py ADDED
@@ -0,0 +1,433 @@
1
+ from typing import List
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.distributions.uniform import Uniform
7
+ from torch.nn import Conv1d
8
+ from torch.nn.utils import remove_weight_norm
9
+
10
+ try:
11
+ from torch.nn.utils.parametrizations import weight_norm
12
+ except ImportError:
13
+ from torch.nn.utils import weight_norm # noqa
14
+
15
+
16
+ def get_padding(kernel_size, dilation=1):
17
+ return int((kernel_size * dilation - dilation) / 2)
18
+
19
+
20
+ def init_weights(m, mean=0.0, std=0.01):
21
+ classname = m.__class__.__name__
22
+ if classname.find("Conv") != -1:
23
+ m.weight.data.normal_(mean, std)
24
+
25
+
26
+ """hifigan based generator implementation.
27
+
28
+ This code is modified from https://github.com/jik876/hifi-gan
29
+ , https://github.com/kan-bayashi/ParallelWaveGAN and
30
+ https://github.com/NVIDIA/BigVGAN
31
+
32
+ """
33
+
34
+
35
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
36
+ # LICENSE is in incl_licenses directory.
37
+ class Snake(nn.Module):
38
+ '''
39
+ Implementation of a sine-based periodic activation function
40
+ Shape:
41
+ - Input: (B, C, T)
42
+ - Output: (B, C, T), same shape as the input
43
+ Parameters:
44
+ - alpha - trainable parameter
45
+ References:
46
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
47
+ https://arxiv.org/abs/2006.08195
48
+ Examples:
49
+ >>> a1 = Snake(256)
50
+ >>> x = torch.randn(256)
51
+ >>> x = a1(x)
52
+
53
+ Args:
54
+ in_features: shape of the input
55
+ alpha: trainable parameter
56
+ alpha_trainable: whether alpha is trainable
57
+ alpha_logscale: whether to use log scale for alpha
58
+ alpha is initialized to 1 by default, higher values = higher-frequency.
59
+ alpha will be trained along with the rest of your model.
60
+ '''
61
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
62
+ super(Snake, self).__init__()
63
+ self.in_features = in_features
64
+
65
+ # initialize alpha
66
+ self.alpha_logscale = alpha_logscale
67
+ if self.alpha_logscale: # log scale alphas initialized to zeros
68
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
69
+ else: # linear scale alphas initialized to ones
70
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
71
+
72
+ self.alpha.requires_grad = alpha_trainable
73
+
74
+ self.no_div_by_zero = 0.000000001
75
+
76
+ def forward(self, x):
77
+ '''
78
+ Forward pass of the function.
79
+ Applies the function to the input elementwise.
80
+ Snake ∶= x + 1/a * sin^2 (xa)
81
+ '''
82
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
83
+ if self.alpha_logscale:
84
+ alpha = torch.exp(alpha)
85
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
86
+
87
+ return x
88
+
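+ # Illustrative sketch (hypothetical helper, not part of this module): with
+ # alpha = 1 the activation is x + sin(x)^2, so 0 maps to 0 and pi/2 maps to
+ # pi/2 + 1 (about 2.5708).
+ def _demo_snake():
+     act = Snake(in_features=1, alpha=1.0)
+     x = torch.tensor([[[0.0, np.pi / 2]]])   # (B, C, T) = (1, 1, 2)
+     return act(x)
+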
89
+
90
+ class ResBlock(torch.nn.Module):
+     """Residual block module in HiFiGAN/BigVGAN."""
+     def __init__(
+         self,
+         channels: int = 512,
+         kernel_size: int = 3,
+         dilations: List[int] = [1, 3, 5],  # noqa
+     ):
+         super(ResBlock, self).__init__()
+         self.convs1 = nn.ModuleList()
+         self.convs2 = nn.ModuleList()
+
+         for dilation in dilations:
+             self.convs1.append(
+                 weight_norm(  # noqa
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=dilation,
+                         padding=get_padding(kernel_size, dilation)
+                     )
+                 )
+             )
+             self.convs2.append(
+                 weight_norm(  # noqa
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=1,
+                         padding=get_padding(kernel_size, 1)
+                     )
+                 )
+             )
+         self.convs1.apply(init_weights)
+         self.convs2.apply(init_weights)
+         self.activations1 = nn.ModuleList([
+             Snake(channels, alpha_logscale=False)
+             for _ in range(len(self.convs1))
+         ])
+         self.activations2 = nn.ModuleList([
+             Snake(channels, alpha_logscale=False)
+             for _ in range(len(self.convs2))
+         ])
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for idx in range(len(self.convs1)):
+             xt = self.activations1[idx](x)
+             xt = self.convs1[idx](xt)
+             xt = self.activations2[idx](xt)
+             xt = self.convs2[idx](xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for idx in range(len(self.convs1)):
+             remove_weight_norm(self.convs1[idx])
+             remove_weight_norm(self.convs2[idx])
+
+
+ class SineGen(torch.nn.Module):
+     """ Definition of sine generator
+     SineGen(samp_rate, harmonic_num = 0,
+             sine_amp = 0.1, noise_std = 0.003,
+             voiced_threshold = 0,
+             flag_for_pulse=False)
+     samp_rate: sampling rate in Hz
+     harmonic_num: number of harmonic overtones (default 0)
+     sine_amp: amplitude of sine waveform (default 0.1)
+     noise_std: std of Gaussian noise (default 0.003)
+     voiced_threshold: F0 threshold for U/V classification (default 0)
+     flag_for_pulse: this SineGen is used inside PulseGen (default False)
+     Note: when flag_for_pulse is True, the first time step of a voiced
+     segment is always sin(np.pi) or cos(0)
+     """
+
+     def __init__(self, samp_rate, harmonic_num=0,
+                  sine_amp=0.1, noise_std=0.003,
+                  voiced_threshold=0):
+         super(SineGen, self).__init__()
+         self.sine_amp = sine_amp
+         self.noise_std = noise_std
+         self.harmonic_num = harmonic_num
+         self.sampling_rate = samp_rate
+         self.voiced_threshold = voiced_threshold
+
+     def _f02uv(self, f0):
+         # generate uv signal
+         uv = (f0 > self.voiced_threshold).type(torch.float32)
+         return uv
+
+     @torch.no_grad()
+     def forward(self, f0):
+         """
+         :param f0: [B, 1, sample_len], Hz
+         :return: [B, 1, sample_len]
+         """
+
+         F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
+         for i in range(self.harmonic_num + 1):
+             F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
+
+         theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
+         u_dist = Uniform(low=-np.pi, high=np.pi)
+         phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
+         phase_vec[:, 0, :] = 0
+
+         # generate sine waveforms
+         sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
+
+         # generate uv signal
+         uv = self._f02uv(f0)
+
+         # noise: for unvoiced regions the noise amplitude should be similar to sine_amp
+         #        (std = sine_amp / 3 -> max value ~ sine_amp);
+         #        for voiced regions it is self.noise_std
+         noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+         noise = noise_amp * torch.randn_like(sine_waves)
+
+         # first: set the unvoiced part to 0 by uv
+         # then: additive noise
+         sine_waves = sine_waves * uv + noise
+         return sine_waves, uv, noise
+
+
+ class SourceModuleHnNSF(torch.nn.Module):
+     """ SourceModule for hn-nsf
+     SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+                  add_noise_std=0.003, voiced_threshod=0)
+     sampling_rate: sampling rate in Hz
+     harmonic_num: number of harmonics above F0 (default: 0)
+     sine_amp: amplitude of sine source signal (default: 0.1)
+     add_noise_std: std of additive Gaussian noise (default: 0.003)
+         note that the amplitude of the noise in unvoiced regions is decided
+         by sine_amp
+     voiced_threshod: threshold to set U/V given F0 (default: 0)
+     Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+     F0_sampled (batchsize, length, 1)
+     Sine_source (batchsize, length, 1)
+     noise_source (batchsize, length, 1)
+     uv (batchsize, length, 1)
+     """
+
+     def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+                  add_noise_std=0.003, voiced_threshod=0):
+         super(SourceModuleHnNSF, self).__init__()
+
+         self.sine_amp = sine_amp
+         self.noise_std = add_noise_std
+
+         # to produce sine waveforms
+         self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
+                                  sine_amp, add_noise_std, voiced_threshod)
+
+         # to merge source harmonics into a single excitation
+         self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+         self.l_tanh = torch.nn.Tanh()
+
+     def forward(self, x):
+         """
+         Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+         F0_sampled (batchsize, length, 1)
+         Sine_source (batchsize, length, 1)
+         noise_source (batchsize, length, 1)
+         """
+         # source for harmonic branch
+         with torch.no_grad():
+             sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
+             sine_wavs = sine_wavs.transpose(1, 2)
+             uv = uv.transpose(1, 2)
+         sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+         # source for noise branch, in the same shape as uv
+         noise = torch.randn_like(uv) * self.sine_amp / 3
+         return sine_merge, noise, uv
+
+
+ class SineGen2(torch.nn.Module):
+     """ Definition of sine generator
+     SineGen(samp_rate, harmonic_num = 0,
+             sine_amp = 0.1, noise_std = 0.003,
+             voiced_threshold = 0,
+             flag_for_pulse=False)
+     samp_rate: sampling rate in Hz
+     harmonic_num: number of harmonic overtones (default 0)
+     sine_amp: amplitude of sine waveform (default 0.1)
+     noise_std: std of Gaussian noise (default 0.003)
+     voiced_threshold: F0 threshold for U/V classification (default 0)
+     flag_for_pulse: this SineGen is used inside PulseGen (default False)
+     Note: when flag_for_pulse is True, the first time step of a voiced
+     segment is always sin(np.pi) or cos(0)
+     """
+
+     def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
+                  sine_amp=0.1, noise_std=0.003,
+                  voiced_threshold=0,
+                  flag_for_pulse=False):
+         super(SineGen2, self).__init__()
+         self.sine_amp = sine_amp
+         self.noise_std = noise_std
+         self.harmonic_num = harmonic_num
+         self.dim = self.harmonic_num + 1
+         self.sampling_rate = samp_rate
+         self.voiced_threshold = voiced_threshold
+         self.flag_for_pulse = flag_for_pulse
+         self.upsample_scale = upsample_scale
+
+     def _f02uv(self, f0):
+         # generate uv signal
+         uv = (f0 > self.voiced_threshold).type(torch.float32)
+         return uv
+
+     def _f02sine(self, f0_values):
+         """ f0_values: (batchsize, length, dim)
+         where dim indicates fundamental tone and overtones
+         """
+         # convert to F0 in rad. The integer part n can be ignored
+         # because 2 * np.pi * n doesn't affect phase
+         rad_values = (f0_values / self.sampling_rate) % 1
+
+         # initial phase noise (no noise for fundamental component)
+         rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
+         rand_ini[:, 0] = 0
+         rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+         # instantaneous phase sine[t] = sin(2*pi \sum_{i=1}^{t} rad)
+         if not self.flag_for_pulse:
+             rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
+                                                          scale_factor=1 / self.upsample_scale,
+                                                          mode="linear").transpose(1, 2)
+
+             phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
+             phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
+                                                     scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
+             sines = torch.sin(phase)
+         else:
+             # If necessary, make sure that the first time step of every
+             # voiced segment is sin(pi) or cos(0).
+             # This is used for pulse-train generation.
+
+             # identify the last time step in unvoiced segments
+             uv = self._f02uv(f0_values)
+             uv_1 = torch.roll(uv, shifts=-1, dims=1)
+             uv_1[:, -1, :] = 1
+             u_loc = (uv < 1) * (uv_1 > 0)
+
+             # get the instantaneous phase
+             tmp_cumsum = torch.cumsum(rad_values, dim=1)
+             # different batches need to be processed differently
+             for idx in range(f0_values.shape[0]):
+                 temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+                 temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+                 # stores the accumulation of i.phase within
+                 # each voiced segment
+                 tmp_cumsum[idx, :, :] = 0
+                 tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+             # rad_values - tmp_cumsum: remove the accumulation of i.phase
+             # within the previous voiced segment.
+             i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+             # get the sines
+             sines = torch.cos(i_phase * 2 * np.pi)
+         return sines
+
+     def forward(self, f0):
+         """ sine_tensor, uv = forward(f0)
+         input F0: tensor(batchsize=1, length, dim=1)
+         f0 for unvoiced steps should be 0
+         output sine_tensor: tensor(batchsize=1, length, dim)
+         output uv: tensor(batchsize=1, length, 1)
+         """
+         # fundamental component
+         fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
+
+         # generate sine waveforms
+         sine_waves = self._f02sine(fn) * self.sine_amp
+
+         # generate uv signal
+         uv = self._f02uv(f0)
+
+         # noise: for unvoiced regions the noise amplitude should be similar to sine_amp
+         #        (std = sine_amp / 3 -> max value ~ sine_amp);
+         #        for voiced regions it is self.noise_std
+         noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+         noise = noise_amp * torch.randn_like(sine_waves)
+
+         # first: set the unvoiced part to 0 by uv
+         # then: additive noise
+         sine_waves = sine_waves * uv + noise
+         return sine_waves, uv, noise
+
+
+ class SourceModuleHnNSF2(torch.nn.Module):
+     """ SourceModule for hn-nsf
+     SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+                  add_noise_std=0.003, voiced_threshod=0)
+     sampling_rate: sampling rate in Hz
+     harmonic_num: number of harmonics above F0 (default: 0)
+     sine_amp: amplitude of sine source signal (default: 0.1)
+     add_noise_std: std of additive Gaussian noise (default: 0.003)
+         note that the amplitude of the noise in unvoiced regions is decided
+         by sine_amp
+     voiced_threshod: threshold to set U/V given F0 (default: 0)
+     Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+     F0_sampled (batchsize, length, 1)
+     Sine_source (batchsize, length, 1)
+     noise_source (batchsize, length, 1)
+     uv (batchsize, length, 1)
+     """
+
+     def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+                  add_noise_std=0.003, voiced_threshod=0):
+         super(SourceModuleHnNSF2, self).__init__()
+
+         self.sine_amp = sine_amp
+         self.noise_std = add_noise_std
+
+         # to produce sine waveforms
+         self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
+                                   sine_amp, add_noise_std, voiced_threshod)
+
+         # to merge source harmonics into a single excitation
+         self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+         self.l_tanh = torch.nn.Tanh()
+
+     def forward(self, x):
+         """
+         Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+         F0_sampled (batchsize, length, 1)
+         Sine_source (batchsize, length, 1)
+         noise_source (batchsize, length, 1)
+         """
+         # source for harmonic branch
+         with torch.no_grad():
+             sine_wavs, uv, _ = self.l_sin_gen(x)
+         sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+         # source for noise branch, in the same shape as uv
+         noise = torch.randn_like(uv) * self.sine_amp / 3
+         return sine_merge, noise, uv
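
The layers above follow the BigVGAN/NSF recipe: Snake is the periodic activation x + (1/alpha) * sin^2(alpha * x) applied per channel, and the SourceModuleHnNSF variants turn an F0 contour into a harmonic sine excitation plus noise for the generator. As a minimal, self-contained sketch of the Snake formula alone (tensor names and shapes below are illustrative and not part of the repository):

import torch

# Snake activation: y = x + (1/alpha) * sin(alpha * x)^2, broadcast per channel.
x = torch.randn(2, 4, 16)             # (B, C, T)
alpha = torch.ones(4)                 # one alpha per channel; default init is 1.0
a = alpha.unsqueeze(0).unsqueeze(-1)  # reshape to (1, C, 1) so it lines up with x
y = x + (1.0 / (a + 1e-9)) * torch.sin(x * a) ** 2
assert y.shape == x.shape             # output keeps the (B, C, T) shape

With alpha fixed at 1 this reduces to x + sin^2(x); the trainable per-channel alpha lets each channel learn its own oscillation frequency, which is why larger alpha values correspond to higher-frequency behaviour.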
soulxpodcast/models/modules/sampler.py ADDED
@@ -0,0 +1,221 @@
+ import os
+
+ from typing import Any, Callable, Optional, Union
+ import torch
+ from torch import nn
+ from transformers.generation.logits_process import (
+     LogitsProcessorList
+ )
+ from transformers.generation.stopping_criteria import (
+     StoppingCriteriaList
+ )
+ from transformers.generation.configuration_utils import (
+     GenerationConfig
+ )
+ from transformers.generation.streamers import BaseStreamer
+ from transformers.generation.utils import (
+     GenerateNonBeamOutput,
+     GenerateEncoderDecoderOutput,
+     GenerateDecoderOnlyOutput,
+ )
+ from transformers import StoppingCriteria
+
+
+ def _ras_sample_hf_engine(
+     self,
+     input_ids: torch.LongTensor,
+     logits_processor: LogitsProcessorList,
+     stopping_criteria: StoppingCriteriaList,
+     generation_config: GenerationConfig,
+     synced_gpus: bool = False,
+     streamer: Optional["BaseStreamer"] = None,
+     use_ras=False,
+     win_size=25,
+     tau_r=0.2,
+     **model_kwargs,
+ ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+     r"""
+     Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
+     can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+
+     Parameters:
+         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+             The sequence used as a prompt for the generation.
+         logits_processor (`LogitsProcessorList`):
+             An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+             used to modify the prediction scores of the language modeling head applied at each generation step.
+         stopping_criteria (`StoppingCriteriaList`):
+             An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+             used to tell if the generation loop should stop.
+         generation_config ([`~generation.GenerationConfig`]):
+             The generation configuration to be used as parametrization of the decoding method.
+         synced_gpus (`bool`):
+             Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+             `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+         streamer (`BaseStreamer`, *optional*):
+             Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+             through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+         use_ras (`bool`, defaults to `False`):
+             Whether to apply repetition aware sampling (RAS) as in VALL-E 2.
+         win_size (`int`, defaults to 25):
+             Size of the trailing token window used to count repetitions of the sampled candidate.
+         tau_r (`float`, defaults to 0.2):
+             Repetition ratio threshold; when the candidate fills at least `win_size * tau_r` of the window,
+             sampling falls back to the raw logits.
+         model_kwargs:
+             Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
+             an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+     Return:
+         [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
+             A `torch.LongTensor` containing the generated tokens (default behaviour) or a
+             [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+             `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
+             `model.config.is_encoder_decoder=True`.
+     """
+     # init values
+     pad_token_id = generation_config._pad_token_tensor
+     output_attentions = generation_config.output_attentions
+     output_hidden_states = generation_config.output_hidden_states
+     output_scores = generation_config.output_scores
+     output_logits = generation_config.output_logits
+     return_dict_in_generate = generation_config.return_dict_in_generate
+     has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+     do_sample = generation_config.do_sample
+
+     # init attention / hidden states / scores tuples
+     scores = () if (return_dict_in_generate and output_scores) else None
+     raw_logits = () if (return_dict_in_generate and output_logits) else None
+     decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+     cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+     decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+     # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+     if return_dict_in_generate and self.config.is_encoder_decoder:
+         encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+         encoder_hidden_states = (
+             model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+         )
+
+     # keep track of which sequences are already finished
+     batch_size, cur_len = input_ids.shape[:2]
+     this_peer_finished = False
+     unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+     model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
+
+     model_forward = self.__call__
+     compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
+     if compile_forward:
+         os.environ["TOKENIZERS_PARALLELISM"] = "0"
+         model_forward = self.get_compiled_call(generation_config.compile_config)
+
+     if generation_config.prefill_chunk_size is not None:
+         model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
+         is_prefill = False
+     else:
+         is_prefill = True
+
+     while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+         # prepare model inputs
+         model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+         # prepare variable output controls (note: some models won't accept all output controls)
+         model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+         model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+         if is_prefill:
+             outputs = self(**model_inputs, return_dict=True)
+             is_prefill = False
+         else:
+             outputs = model_forward(**model_inputs, return_dict=True)
+
+         # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+         model_kwargs = self._update_model_kwargs_for_generation(
+             outputs,
+             model_kwargs,
+             is_encoder_decoder=self.config.is_encoder_decoder,
+         )
+         if synced_gpus and this_peer_finished:
+             continue
+
+         # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+         # (the clone itself is always small)
+         next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
+
+         # pre-process distribution
+         next_token_scores = logits_processor(input_ids, next_token_logits)
+
+         # Repetition Aware Sampling in VALL-E 2: draw a candidate from the processed distribution;
+         # if it already dominates the last `win_size` tokens, fall back to the raw logits.
+         if use_ras:
+             probs_candidate = nn.functional.softmax(next_token_scores, dim=-1)
+             next_tokens_candidate = torch.multinomial(probs_candidate, num_samples=1).squeeze(1)
+             rep_num = (input_ids[:, -win_size:] == next_tokens_candidate).sum().item() + 1
+             if rep_num >= win_size * tau_r:
+                 next_token_scores = next_token_logits
+
+         # Store scores, attentions and hidden_states when required
+         if return_dict_in_generate:
+             if output_scores:
+                 scores += (next_token_scores,)
+             if output_logits:
+                 raw_logits += (next_token_logits,)
+             if output_attentions:
+                 decoder_attentions += (
+                     (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+                 )
+                 if self.config.is_encoder_decoder:
+                     cross_attentions += (outputs.cross_attentions,)
+
+             if output_hidden_states:
+                 decoder_hidden_states += (
+                     (outputs.decoder_hidden_states,)
+                     if self.config.is_encoder_decoder
+                     else (outputs.hidden_states,)
+                 )
+
+         # token selection
+         if do_sample:
+             probs = nn.functional.softmax(next_token_scores, dim=-1)
+             # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
+             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+         else:
+             next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+         # finished sentences should have their next token be a padding token
+         if has_eos_stopping_criteria:
+             next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+         # update generated ids, model inputs, and length for next step
+         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+         if streamer is not None:
+             streamer.put(next_tokens.cpu())
+
+         unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+         this_peer_finished = unfinished_sequences.max() == 0
+         cur_len += 1
+
+         # This is needed to properly delete outputs.logits which may be very large for first iteration
+         # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+         del outputs
+
+     if streamer is not None:
+         streamer.end()
+
+     if return_dict_in_generate:
+         if self.config.is_encoder_decoder:
+             return GenerateEncoderDecoderOutput(
+                 sequences=input_ids,
+                 scores=scores,
+                 logits=raw_logits,
+                 encoder_attentions=encoder_attentions,
+                 encoder_hidden_states=encoder_hidden_states,
+                 decoder_attentions=decoder_attentions,
+                 cross_attentions=cross_attentions,
+                 decoder_hidden_states=decoder_hidden_states,
+                 past_key_values=model_kwargs.get("past_key_values"),
+             )
+         else:
+             return GenerateDecoderOnlyOutput(
+                 sequences=input_ids,
+                 scores=scores,
+                 logits=raw_logits,
+                 attentions=decoder_attentions,
+                 hidden_states=decoder_hidden_states,
+                 past_key_values=model_kwargs.get("past_key_values"),
+             )
+     else:
+         return input_ids
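
The `use_ras` branch above is repetition-aware sampling (RAS) from VALL-E 2: a candidate token is drawn from the processed distribution, and if that candidate already fills at least `win_size * tau_r` of the trailing `win_size` tokens, the step falls back to sampling from the raw, unfiltered logits. A standalone sketch of just that decision rule, assuming batch size 1 (the function and variable names here are illustrative, not part of the repository):

import torch

def ras_step(raw_logits, filtered_scores, input_ids, win_size=25, tau_r=0.2):
    # Draw a candidate from the processed (e.g. top-k / top-p filtered) distribution.
    candidate = torch.multinomial(torch.softmax(filtered_scores, dim=-1), num_samples=1).squeeze(1)
    # Count how often the candidate already appears in the last `win_size` tokens (+1 for itself).
    rep_num = (input_ids[:, -win_size:] == candidate).sum().item() + 1
    # If the window is saturated with this token, resample from the raw logits instead.
    scores = raw_logits if rep_num >= win_size * tau_r else filtered_scores
    return torch.multinomial(torch.softmax(scores, dim=-1), num_samples=1).squeeze(1)

# Toy usage: batch size 1, a 10-token vocabulary, 40 tokens of history.
history = torch.randint(0, 10, (1, 40))
logits = torch.randn(1, 10)
next_token = ras_step(logits, logits.clone(), history)

Because the fallback samples from the unprocessed logits, a token that keeps winning under the filtered distribution can still be displaced by the rest of the vocabulary, which is what breaks degenerate repetition loops.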
soulxpodcast/models/soulxpodcast.py ADDED
@@ -0,0 +1,192 @@
+ import time
+ from datetime import datetime
+ from itertools import chain
+ from tqdm import tqdm
+ from copy import deepcopy
+
+ import numpy as np
+ import s3tokenizer
+ import torch
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
+ from soulxpodcast.config import Config, SamplingParams, AutoPretrainedConfig
+ from soulxpodcast.engine.llm_engine import (
+     HFLLMEngine, VLLMEngine
+ )
+ from soulxpodcast.models.modules.flow import CausalMaskedDiffWithXvec
+ from soulxpodcast.models.modules.hifigan import HiFTGenerator
+
+
+ class SoulXPodcast(torch.nn.Module):
+     def __init__(self, config: Config = None):
+         super().__init__()
+         self.config = Config() if config is None else config
+
+         self.audio_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").cuda().eval()
+         if self.config.llm_engine == "hf":
+             self.llm = HFLLMEngine(**self.config.__dict__)
+         elif self.config.llm_engine == "vllm":
+             self.llm = VLLMEngine(**self.config.__dict__)
+         else:
+             raise NotImplementedError
+
+         self.use_tqdm = True
+
+         self.flow = CausalMaskedDiffWithXvec()
+         if self.config.hf_config.fp16_flow:
+             timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
+             tqdm.write(f"[{timestamp}] - [INFO] - Casting flow to fp16")
+             self.flow.half()
+         self.flow.load_state_dict(torch.load(f"{self.config.model}/flow.pt", map_location="cpu", weights_only=True), strict=True)
+         self.flow.cuda().eval()
+
+         self.hift = HiFTGenerator()
+         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{self.config.model}/hift.pt", map_location="cpu", weights_only=True).items()}
+         self.hift.load_state_dict(hift_state_dict, strict=True)
+         self.hift.cuda().eval()
+
+
+     @torch.inference_mode()
+     def forward_longform(
+         self, prompt_mels_for_llm,
+         prompt_mels_lens_for_llm: torch.Tensor,
+         prompt_text_tokens_for_llm: list[list[int]],
+         text_tokens_for_llm: list[list[int]],
+         prompt_mels_for_flow_ori,
+         spk_emb_for_flow: torch.Tensor,
+         sampling_params: SamplingParams | list[SamplingParams],
+         spk_ids: list[list[int]],
+         use_prompt_cot: bool = False,
+         prompt_cot_text_tokens_for_llm: list[list[int]] = None,
+         prompt_cot_prefix: list[list[int]] = None,
+         **kwargs,  # for compatibility
+     ):
+         prompt_size, turn_size = len(prompt_mels_for_llm), len(text_tokens_for_llm)
+
+         # Audio tokenization
+         prompt_speech_tokens_ori, prompt_speech_tokens_lens_ori = self.audio_tokenizer.quantize(
+             prompt_mels_for_llm.cuda(), prompt_mels_lens_for_llm.cuda()
+         )
+
+         # Align the speech tokens with the speech features so as to reduce
+         # the noise ratio during the generation process.
+         prompt_speech_tokens = []
+         prompt_mels_for_flow, prompt_mels_lens_for_flow = [], []
+
+         for prompt_index in range(prompt_size):
+             prompt_speech_token_len = prompt_speech_tokens_lens_ori[prompt_index].item()
+             prompt_speech_token = prompt_speech_tokens_ori[prompt_index, :prompt_speech_token_len]
+             prompt_mel = prompt_mels_for_flow_ori[prompt_index]
+             prompt_mel_len = prompt_mel.shape[0]
+             if prompt_speech_token_len * 2 > prompt_mel_len:
+                 prompt_speech_token = prompt_speech_token[:int(prompt_mel_len / 2)]
+                 prompt_mel_len = torch.tensor([prompt_mel_len]).cuda()
+             else:
+                 prompt_mel = prompt_mel.detach().clone()[:prompt_speech_token_len * 2].cuda()
+                 prompt_mel_len = torch.tensor([prompt_speech_token_len * 2]).cuda()
+             prompt_speech_tokens.append(prompt_speech_token)
+             prompt_mels_for_flow.append(prompt_mel)
+             prompt_mels_lens_for_flow.append(prompt_mel_len)
+
+         # Prepare LLM inputs
+         prompt_inputs = []
+         history_inputs = []
+
+         # for i in range(prompt_size):
+         #     prompt_mels = prompt_mels_for_flow[i][None]
+         #     prompt_mels_lens = prompt_mels_lens_for_flow[i]
+         #     spk_emb = spk_emb_for_flow[i:i+1]
+
+         #     # Flow generation
+         #     with torch.amp.autocast("cuda", dtype=torch.float16 if self.config.hf_config.fp16_flow else torch.float32):
+         #         flow_input = torch.concat([prompt_speech_tokens[i].detach().clone(), prompt_speech_tokens[1].detach().clone()], axis=0)[None]
+         #         flow_inputs_len = torch.tensor(flow_input.shape[1])[None]
+         #         generated_mels, generated_mels_lens = self.flow(
+         #             flow_input.cuda(), flow_inputs_len.cuda(),
+         #             prompt_mels, prompt_mels_lens, spk_emb.cuda(),
+         #             streaming=False, finalize=True
+         #         )
+
+         #     # HiFi-GAN generation
+         #     mel = generated_mels[:, :, prompt_mels_lens[0].item():generated_mels_lens[0].item()]
+         #     wav, _ = self.hift(speech_feat=mel)
+         #     import soundfile as sf
+         #     sf.write(f"{str(i).zfill(2)}.wav", wav.cpu().squeeze(0).numpy(), 24000)
+
+         for i in range(prompt_size):
+             speech_tokens_i = [token + self.config.hf_config.speech_token_offset for token in prompt_speech_tokens[i].tolist()]
+             speech_tokens_i += [self.config.hf_config.eos_token_id]
+             if use_prompt_cot and len(prompt_cot_text_tokens_for_llm[i]) > 0:
+                 prompt_cot_input = prompt_text_tokens_for_llm[i] + speech_tokens_i + prompt_cot_text_tokens_for_llm[i]
+                 if i > 0:
+                     prompt_cot_input = prompt_cot_prefix[0] + prompt_cot_input
+                 cot_input = self.llm.generate(prompt_cot_input, sampling_params, past_key_values=None)['token_ids']
+                 prompt_inputs.append(prompt_cot_prefix[i + 1] + prompt_cot_text_tokens_for_llm[i] + cot_input)
+                 history_inputs.append(prompt_cot_prefix[i + 1] + prompt_cot_text_tokens_for_llm[i] + cot_input)
+             else:
+                 prompt_inputs.append(prompt_text_tokens_for_llm[i] + speech_tokens_i)
+                 history_inputs.append(prompt_text_tokens_for_llm[i] + speech_tokens_i)
+
+         generated_wavs, results_dict = [], {}
+
+         # LLM generation
+         inputs = list(chain.from_iterable(prompt_inputs))
+         cache_config = AutoPretrainedConfig().from_dataclass(self.llm.config.hf_config)
+         past_key_values = DynamicCache(config=cache_config)
+         valid_turn_size = prompt_size
+         for i in range(turn_size):
+
+             # Reset the rolling context and the KV cache once the turn count or token count exceeds the budget.
+             if valid_turn_size > self.config.max_turn_size or len(inputs) > self.config.turn_tokens_threshold:
+                 assert self.config.max_turn_size >= self.config.prompt_context + self.config.history_context, "Invalid long-history size setting."
+                 prompt_text_bound = max(self.config.prompt_context, len(history_inputs) - self.config.history_text_context - self.config.history_context)
+                 inputs = list(chain.from_iterable(
+                     history_inputs[:self.config.prompt_context] +
+                     history_inputs[prompt_text_bound:-self.config.history_context] +
+                     prompt_inputs[-self.config.history_context:]
+                 ))
+                 valid_turn_size = self.config.prompt_context + len(history_inputs) - prompt_text_bound
+                 past_key_values = DynamicCache(config=cache_config)
+             valid_turn_size += 1
+
+             inputs.extend(text_tokens_for_llm[i])
+             start_time = time.time()
+             llm_outputs = self.llm.generate(inputs, sampling_params, past_key_values=past_key_values)
+
+             inputs.extend(llm_outputs['token_ids'])
+             prompt_inputs.append(text_tokens_for_llm[i] + llm_outputs['token_ids'])
+             history_inputs.append(text_tokens_for_llm[i][:-1])  # remove the <|audio_start|>
+
+             # Prepare Flow inputs
+             turn_spk = spk_ids[i]
+             generated_speech_tokens = [token - self.config.hf_config.speech_token_offset for token in llm_outputs['token_ids'][:-1]]  # ignore last eos
+             prompt_speech_token = prompt_speech_tokens[turn_spk].tolist()
+             flow_input = torch.tensor([prompt_speech_token + generated_speech_tokens])
+             flow_inputs_len = torch.tensor([len(prompt_speech_token) + len(generated_speech_tokens)])
+
+             # Flow generation and HiFi-GAN generation
+             start_idx = spk_ids[i]
+             prompt_mels = prompt_mels_for_flow[start_idx][None]
+             prompt_mels_lens = prompt_mels_lens_for_flow[start_idx][None]
+             spk_emb = spk_emb_for_flow[start_idx:start_idx + 1]
+
+             # Flow generation
+             with torch.amp.autocast("cuda", dtype=torch.float16 if self.config.hf_config.fp16_flow else torch.float32):
+                 generated_mels, generated_mels_lens = self.flow(
+                     flow_input.cuda(), flow_inputs_len.cuda(),
+                     prompt_mels, prompt_mels_lens, spk_emb.cuda(),
+                     streaming=False, finalize=True
+                 )
+
+             # HiFi-GAN generation
+             mel = generated_mels[:, :, prompt_mels_lens[0].item():generated_mels_lens[0].item()]
+             try:
+                 wav, _ = self.hift(speech_feat=mel)
+             except Exception as e:
+                 tqdm.write(f"[ERROR] HiFi-GAN vocoding failed on turn {i}: {e}")
+                 raise
+             generated_wavs.append(wav)
+
+         # Save the generated wavs
+         results_dict['generated_wavs'] = generated_wavs
+         return results_dict
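
When the dialogue grows past `max_turn_size` turns or `turn_tokens_threshold` tokens, `forward_longform` rebuilds its rolling context from three pieces: the first `prompt_context` speaker-prompt turns, a window of recent text-only history, and the last `history_context` full (text plus speech) turns, then recreates the KV cache so the next `llm.generate` call re-prefills only the trimmed context. A toy sketch of that bookkeeping on plain lists (the function name and the default sizes are illustrative; the real values come from `Config`):

from itertools import chain

def rebuild_context(history_inputs, prompt_inputs,
                    prompt_context=2, history_context=2, history_text_context=4):
    # Keep: speaker prompts, a recent window of text-only history, the last few full turns.
    prompt_text_bound = max(prompt_context,
                            len(history_inputs) - history_text_context - history_context)
    kept = (history_inputs[:prompt_context]
            + history_inputs[prompt_text_bound:-history_context]
            + prompt_inputs[-history_context:])
    flat = list(chain.from_iterable(kept))
    valid_turn_size = prompt_context + len(history_inputs) - prompt_text_bound
    return flat, valid_turn_size

# Toy turns: each turn is just a short list of token ids.
history = [[i, i] for i in range(10)]
prompts = [[100 + i, 100 + i] for i in range(10)]
flat, turns = rebuild_context(history, prompts)
print(len(flat), turns)  # 16 8 with the toy sizes above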
soulxpodcast/utils/__init__.py ADDED
File without changes