Sayoyo committed
Commit 5488167 · Parent: ba6d996

[feat] v1 commit

Files changed (46)
  1. .gitignore +164 -0
  2. apg_guidance.py +90 -0
  3. app.py +39 -0
  4. data_sampler.py +23 -0
  5. examples/input_params/output_20250426071706_0_input_params.json +25 -0
  6. examples/input_params/output_20250426071812_0_input_params.json +25 -0
  7. examples/input_params/output_20250426072346_0_input_params.json +25 -0
  8. examples/input_params/output_20250426072508_0_input_params.json +25 -0
  9. examples/input_params/output_20250426073829_0_input_params.json +25 -0
  10. examples/input_params/output_20250426074037_0_input_params.json +25 -0
  11. examples/input_params/output_20250426074214_0_input_params.json +25 -0
  12. examples/input_params/output_20250426074413_0_input_params.json +25 -0
  13. examples/input_params/output_20250426075107_0_input_params.json +25 -0
  14. examples/input_params/output_20250426075537_0_input_params.json +25 -0
  15. examples/input_params/output_20250426075843_0_input_params.json +25 -0
  16. examples/input_params/output_20250426080234_0_input_params.json +25 -0
  17. examples/input_params/output_20250426080407_0_input_params.json +25 -0
  18. examples/input_params/output_20250426080601_0_input_params.json +25 -0
  19. examples/input_params/output_20250426081134_0_input_params.json +25 -0
  20. examples/input_params/output_20250426091716_0_input_params.json +25 -0
  21. examples/input_params/output_20250426092025_0_input_params.json +25 -0
  22. examples/input_params/output_20250426093007_0_input_params.json +25 -0
  23. examples/input_params/output_20250426093146_0_input_params.json +25 -0
  24. language_segmentation/LangSegment.py +866 -0
  25. language_segmentation/__init__.py +9 -0
  26. language_segmentation/utils/__init__.py +0 -0
  27. language_segmentation/utils/num.py +327 -0
  28. models/ace_step_transformer.py +475 -0
  29. models/attention.py +319 -0
  30. models/config.json +23 -0
  31. models/customer_attention_processor.py +339 -0
  32. models/lyrics_utils/lyric_encoder.py +1070 -0
  33. models/lyrics_utils/lyric_normalizer.py +66 -0
  34. models/lyrics_utils/lyric_tokenizer.py +883 -0
  35. models/lyrics_utils/vocab.json +0 -0
  36. models/lyrics_utils/zh_num2words.py +1209 -0
  37. music_dcae/__init__.py +0 -0
  38. music_dcae/music_dcae_pipeline.py +150 -0
  39. music_dcae/music_log_mel.py +107 -0
  40. music_dcae/music_vocoder.py +576 -0
  41. packages.txt +1 -0
  42. pipeline_ace_step.py +735 -0
  43. requirements.txt +22 -0
  44. schedulers/scheduling_flow_match_euler_discrete.py +394 -0
  45. schedulers/scheduling_flow_match_heun_discrete.py +348 -0
  46. ui/components.py +244 -0
.gitignore ADDED
@@ -0,0 +1,164 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ checkpoints/
apg_guidance.py ADDED
@@ -0,0 +1,90 @@
+ import torch
+
+
+ class MomentumBuffer:
+     def __init__(self, momentum: float = -0.75):
+         self.momentum = momentum
+         self.running_average = 0
+
+     def update(self, update_value: torch.Tensor):
+         new_average = self.momentum * self.running_average
+         self.running_average = update_value + new_average
+
+
+ def project(
+     v0: torch.Tensor,  # [B, C, H, W]
+     v1: torch.Tensor,  # [B, C, H, W]
+     dims=[-1, -2],
+ ):
+     dtype = v0.dtype
+     v0, v1 = v0.double(), v1.double()
+     v1 = torch.nn.functional.normalize(v1, dim=dims)
+     v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
+     v0_orthogonal = v0 - v0_parallel
+     return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
+
+
+ def apg_forward(
+     pred_cond: torch.Tensor,  # [B, C, H, W]
+     pred_uncond: torch.Tensor,  # [B, C, H, W]
+     guidance_scale: float,
+     momentum_buffer: MomentumBuffer = None,
+     eta: float = 0.0,
+     norm_threshold: float = 2.5,
+     dims=[-1, -2],
+ ):
+     diff = pred_cond - pred_uncond
+     if momentum_buffer is not None:
+         momentum_buffer.update(diff)
+         diff = momentum_buffer.running_average
+
+     if norm_threshold > 0:
+         ones = torch.ones_like(diff)
+         diff_norm = diff.norm(p=2, dim=dims, keepdim=True)
+         scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+         diff = diff * scale_factor
+
+     diff_parallel, diff_orthogonal = project(diff, pred_cond, dims)
+     normalized_update = diff_orthogonal + eta * diff_parallel
+     pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
+     return pred_guided
+
+
+ def cfg_forward(cond_output, uncond_output, cfg_strength):
+     return uncond_output + cfg_strength * (cond_output - uncond_output)
+
+ def cfg_double_condition_forward(
+     cond_output,
+     uncond_output,
+     only_text_cond_output,
+     guidance_scale_text,
+     guidance_scale_lyric,
+ ):
+     return (1 - guidance_scale_text) * uncond_output + (guidance_scale_text - guidance_scale_lyric) * only_text_cond_output + guidance_scale_lyric * cond_output
+
+
+ def optimized_scale(positive_flat, negative_flat):
+
+     # Calculate the dot product
+     dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
+
+     # Squared norm of the unconditional prediction
+     squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
+
+     # st_star = v_cond^T * v_uncond / ||v_uncond||^2
+     st_star = dot_product / squared_norm
+
+     return st_star
+
+
+ def cfg_zero_star(noise_pred_with_cond, noise_pred_uncond, guidance_scale, i, zero_steps=1, use_zero_init=True):
+     bsz = noise_pred_with_cond.shape[0]
+     positive_flat = noise_pred_with_cond.view(bsz, -1)
+     negative_flat = noise_pred_uncond.view(bsz, -1)
+     alpha = optimized_scale(positive_flat, negative_flat)
+     alpha = alpha.view(bsz, 1, 1, 1)
+     if (i <= zero_steps) and use_zero_init:
+         noise_pred = noise_pred_with_cond * 0.
+     else:
+         noise_pred = noise_pred_uncond * alpha + guidance_scale * (noise_pred_with_cond - noise_pred_uncond * alpha)
+     return noise_pred
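
For orientation, here is a minimal sketch of how `apg_forward` and `MomentumBuffer` from the file above might be driven for one guidance step. The tensors, shapes and the model outputs are placeholder assumptions (the guidance_scale of 15 simply mirrors the example input_params files in this commit), not part of the committed code:

import torch

from apg_guidance import MomentumBuffer, apg_forward

# Hypothetical conditional / unconditional model outputs for one denoising step,
# following the [B, C, H, W] convention noted in apg_guidance.py.
pred_cond = torch.randn(2, 8, 16, 256)
pred_uncond = torch.randn(2, 8, 16, 256)

# Keep one buffer for the whole trajectory so the momentum term
# accumulates the guidance direction across steps.
momentum_buffer = MomentumBuffer(momentum=-0.75)

guided = apg_forward(
    pred_cond=pred_cond,
    pred_uncond=pred_uncond,
    guidance_scale=15.0,
    momentum_buffer=momentum_buffer,
    eta=0.0,              # drop the parallel component of the update
    norm_threshold=2.5,   # rescale updates whose per-sample norm exceeds 2.5
)
print(guided.shape)  # torch.Size([2, 8, 16, 256])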
app.py ADDED
@@ -0,0 +1,39 @@
+ import argparse
+ from ui.components import create_main_demo_ui
+ from pipeline_ace_step import ACEStepPipeline
+ from data_sampler import DataSampler
+ import os
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--checkpoint_path", type=str, default=None)
+ parser.add_argument("--port", type=int, default=7860)
+ parser.add_argument("--device_id", type=int, default=0)
+ parser.add_argument("--share", action='store_true', default=False)
+ parser.add_argument("--bf16", action='store_true', default=True)
+
+ args = parser.parse_args()
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
+
+
+ persistent_storage_path = "/data"
+
+
+ def main(args):
+
+     model_demo = ACEStepPipeline(
+         checkpoint_dir=args.checkpoint_path,
+         dtype="bfloat16" if args.bf16 else "float32",
+         persistent_storage_path=persistent_storage_path
+     )
+     data_sampler = DataSampler()
+
+     demo = create_main_demo_ui(
+         text2music_process_func=model_demo.__call__,
+         sample_data_func=data_sampler.sample,
+     )
+     demo.launch()
+
+
+ if __name__ == "__main__":
+     main(args)
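
Two details worth noting in app.py: `--bf16` is declared with `action='store_true'` but `default=True`, so it can never be switched off from the command line, and `--port` / `--share` are parsed but never forwarded to `demo.launch()`. A hedged sketch of how the flags could be made effective, assuming `create_main_demo_ui` returns a Gradio Blocks object (ui/components.py is added in this commit but not shown here):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--share", action="store_true", default=False)
# BooleanOptionalAction also generates a paired --no-bf16 switch, so bf16 can be disabled.
parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
args = parser.parse_args()

# ...build model_demo, data_sampler and demo exactly as in app.py above...

# Forward the parsed networking options to Gradio when launching:
# demo.launch(server_port=args.port, share=args.share)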
data_sampler.py ADDED
@@ -0,0 +1,23 @@
+ import json
+ from pathlib import Path
+ import random
+
+
+ DEFAULT_ROOT_DIR = "examples/input_params"
+
+
+ class DataSampler:
+     def __init__(self, root_dir=DEFAULT_ROOT_DIR):
+         self.root_dir = root_dir
+
+         # Glob all example input-parameter JSON files under the root directory
+         self.input_params_files = list(Path(self.root_dir).glob("*.json"))
+
+     def load_json(self, file_path):
+         with open(file_path, "r", encoding="utf-8") as f:
+             return json.load(f)
+
+     def sample(self):
+         json_path = random.choice(self.input_params_files)
+         json_data = self.load_json(json_path)
+         return json_data
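
A quick usage sketch for `DataSampler`; the keys printed here come from the example parameter files added below, and the only assumption is that the script runs from the repository root so that examples/input_params resolves:

from data_sampler import DataSampler

sampler = DataSampler()    # collects examples/input_params/*.json
params = sampler.sample()  # one randomly chosen set of generation parameters
print(params["prompt"])
print(params["audio_duration"], params["infer_step"], params["actual_seeds"])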
examples/input_params/output_20250426071706_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "pop, rap, electronic, blues, hip-house, rhythm and blues",
3
+ "lyrics": "[verse]\n我走过深夜的街道\n冷风吹乱思念的漂亮外套\n你的微笑像星光很炫耀\n照亮了我孤独的每分每秒\n\n[chorus]\n愿你是风吹过我的脸\n带我飞过最远最遥远的山间\n愿你是风轻触我的梦\n停在心头不再飘散无迹无踪\n\n[verse]\n一起在喧哗避开世俗的骚动\n独自在天台探望月色的朦胧\n你说爱像音乐带点重节奏\n一拍一跳让我忘了心的温度多空洞\n\n[bridge]\n唱起对你的想念不隐藏\n像诗又像画写满藏不了的渴望\n你的影子挥不掉像风的倔强\n追着你飞扬穿越云海一样泛光\n\n[chorus]\n愿你是风吹过我的手\n暖暖的触碰像春日细雨温柔\n愿你是风盘绕我的身\n深情万万重不会有一天走远走\n\n[verse]\n深夜的钢琴弹起动人的旋律\n低音鼓砸进心底的每一次呼吸\n要是能将爱化作歌声传递\n你是否会听见我心里的真心实意",
4
+ "audio_duration": 170.63997916666668,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 3.191075086593628,
19
+ "diffusion": 17.459356784820557,
20
+ "latent2audio": 1.7095518112182617
21
+ },
22
+ "actual_seeds": [
23
+ 3299954530
24
+ ]
25
+ }
examples/input_params/output_20250426071812_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "country rock, folk rock, southern rock, bluegrass, country pop",
3
+ "lyrics": "[verse]\nWoke up to the sunrise glow\nTook my heart and hit the road\nWheels hummin' the only tune I know\nStraight to where the wildflowers grow\n\n[verse]\nGot that old map all wrinkled and torn\nDestination unknown but I'm reborn\nWith a smile that the wind has worn\nChasin' dreams that can't be sworn\n\n[chorus]\nRidin' on a highway to sunshine\nGot my shades and my radio on fine\nLeave the shadows in the rearview rhyme\nHeart's racing as we chase the time\n\n[verse]\nMet a girl with a heart of gold\nTold stories that never get old\nHer laugh like a tale that's been told\nA melody so bold yet uncontrolled\n\n[bridge]\nClouds roll by like silent ghosts\nAs we drive along the coast\nWe toast to the days we love the most\nFreedom's song is what we post\n\n[chorus]\nRidin' on a highway to sunshine\nGot my shades and my radio on fine\nLeave the shadows in the rearview rhyme\nHeart's racing as we chase the time",
4
+ "audio_duration": 224.23997916666667,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 4.262240648269653,
19
+ "diffusion": 15.380569219589233,
20
+ "latent2audio": 2.3227272033691406
21
+ },
22
+ "actual_seeds": [
23
+ 401640
24
+ ]
25
+ }
examples/input_params/output_20250426072346_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "hip-house, funk",
3
+ "lyrics": "[verse]\n哎呀跳起来,脚尖踩节拍 (oo-yeah!)\n灯光闪烁像星星盛开 (uh-huh!)\n人人都醒来,把烦恼踹开 (get it!)\n热血沸腾,汗水自己安排\n\n[chorus]\n嘿,你还等啥?快抓住节拍 (come on!)\n光芒指引,让心都不存在 (whoa!)\n点燃热火,我们一起飙high (let’s go!)\n跳入午夜的狂欢时代\n\n[bridge]\n咚咚鼓声啊,让你的灵魂起飞 (woo!)\n手心拍一拍,能量翻倍 (ah-hah!)\n键盘响起来,如宇宙的交汇 (oh yeah!)\n就是这感觉,兄弟姐妹都陶醉\n\n[verse]\n灵魂从不睡,只想继续燃烧 (woo!)\n节奏像热浪,席卷这街道 (ow!)\n大伙儿涌上楼台,满面微笑 (yeah!)\n这一刻属于我们,无可替代\n\n[chorus]\n嘿,你还等啥?快抓住节拍 (come on!)\n光芒指引,让心都不存在 (whoa!)\n点燃热火,我们一起飙high (let’s go!)\n跳入午夜的狂欢时代\n\n[verse]\n世界多精彩,握紧把它打开 (alright!)\n每一步都像星球在摇摆 (uh-huh!)\n无边无际的律动像大海 (oo-yeah!)\n跟着光芒之舞,一起澎湃",
4
+ "audio_duration": 204.19997916666668,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.05196118354797363,
19
+ "diffusion": 15.530808210372925,
20
+ "latent2audio": 2.5604095458984375
21
+ },
22
+ "actual_seeds": [
23
+ 401640
24
+ ]
25
+ }
examples/input_params/output_20250426072508_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic",
3
+ "lyrics": "[verse]\nNeon lights they flicker bright\nCity hums in dead of night\nRhythms pulse through concrete veins\nLost in echoes of refrains\n\n[verse]\nBassline groovin' in my chest\nHeartbeats match the city's zest\nElectric whispers fill the air\nSynthesized dreams everywhere\n\n[chorus]\nTurn it up and let it flow\nFeel the fire let it grow\nIn this rhythm we belong\nHear the night sing out our song\n\n[verse]\nGuitar strings they start to weep\nWake the soul from silent sleep\nEvery note a story told\nIn this night we’re bold and gold\n\n[bridge]\nVoices blend in harmony\nLost in pure cacophony\nTimeless echoes timeless cries\nSoulful shouts beneath the skies\n\n[verse]\nKeyboard dances on the keys\nMelodies on evening breeze\nCatch the tune and hold it tight\nIn this moment we take flight",
4
+ "audio_duration": 178.87997916666666,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.02882218360900879,
19
+ "diffusion": 16.91233205795288,
20
+ "latent2audio": 1.7794082164764404
21
+ },
22
+ "actual_seeds": [
23
+ 401640
24
+ ]
25
+ }
examples/input_params/output_20250426073829_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "electronic rap",
3
+ "lyrics": "[verse]\nWaves on the bass, pulsing in the speakers,\nTurn the dial up, we chasing six-figure features,\nGrinding on the beats, codes in the creases,\nDigital hustler, midnight in sneakers.\n\n[chorus]\nElectro vibes, hearts beat with the hum,\nUrban legends ride, we ain't ever numb,\nCircuits sparking live, tapping on the drum,\nLiving on the edge, never succumb.\n\n[verse]\nSynthesizers blaze, city lights a glow,\nRhythm in the haze, moving with the flow,\nSwagger on stage, energy to blow,\nFrom the blocks to the booth, you already know.\n\n[bridge]\nNight's electric, streets full of dreams,\nBass hits collective, bursting at seams,\nHustle perspective, all in the schemes,\nRise and reflective, ain't no in-betweens.\n\n[verse]\nVibin' with the crew, sync in the wire,\nGot the dance moves, fire in the attire,\nRhythm and blues, soul's our supplier,\nRun the digital zoo, higher and higher.\n\n[chorus]\nElectro vibes, hearts beat with the hum,\nUrban legends ride, we ain't ever numb,\nCircuits sparking live, tapping on the drum,\nLiving on the edge, never succumb.",
4
+ "audio_duration": 221.42547916666666,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.024875164031982422,
19
+ "diffusion": 20.566852569580078,
20
+ "latent2audio": 2.2281734943389893
21
+ },
22
+ "actual_seeds": [
23
+ 401640
24
+ ]
25
+ }
examples/input_params/output_20250426074037_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "electronic, house, electro house, synthesizer, drums, bass, percussion, fast, energetic, uplifting, exciting",
3
+ "lyrics": "[verse]\n霓虹灯下我们追逐\n人群跃动像潮水满布\n热浪袭来吹散孤独\n跳进节奏不如停下脚步\n\n[pre-chorus]\n脚尖触电快点感受\n迎着风声释放自由\n心跳节拍配合节奏\n一切烦恼请靠边游\n\n[chorus]\n夏夜狂奔没有尽头\n星光闪烁舞池不朽\n尽情挥洒所有节奏\n无边热情把你包裹哦\n\n[verse]\n天空翻滚黑云入夜\n每颗星星像音乐律贴\n耳边回响那低音线\n环绕耳际如梦境般甜\n\n[pre-chorus]\n脚尖触电快点感受\n迎着风声释放自由\n心跳节拍配合节奏\n一切烦恼请靠边游\n\n[chorus]\n夏夜狂奔没有尽头\n星光闪烁舞池不朽\n尽情挥洒所有节奏\n无边热情把你包裹哦",
4
+ "audio_duration": 221.47997916666668,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.028400182723999023,
19
+ "diffusion": 13.195815324783325,
20
+ "latent2audio": 2.1679723262786865
21
+ },
22
+ "actual_seeds": [
23
+ 3440445703
24
+ ]
25
+ }
examples/input_params/output_20250426074214_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "synth-pop, electronic, pop, synthesizer, drums, bass, piano, 128 BPM, energetic, uplifting, modern",
3
+ "lyrics": "[verse]\nWoke up in a city that's always alive\nNeon lights they shimmer they thrive\nElectric pulses beat they drive\nMy heart races just to survive\n\n[chorus]\nOh electric dreams they keep me high\nThrough the wires I soar and fly\nMidnight rhythms in the sky\nElectric dreams together we’ll defy\n\n[verse]\nLost in the labyrinth of screens\nVirtual love or so it seems\nIn the night the city gleams\nDigital faces haunted by memes\n\n[chorus]\nOh electric dreams they keep me high\nThrough the wires I soar and fly\nMidnight rhythms in the sky\nElectric dreams together we’ll defy\n\n[bridge]\nSilent whispers in my ear\nPixelated love serene and clear\nThrough the chaos find you near\nIn electric dreams no fear\n\n[verse]\nBound by circuits intertwined\nLove like ours is hard to find\nIn this world we’re truly blind\nBut electric dreams free the mind",
4
+ "audio_duration": 221.27997916666666,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.025463581085205078,
19
+ "diffusion": 15.243804454803467,
20
+ "latent2audio": 2.170398473739624
21
+ },
22
+ "actual_seeds": [
23
+ 3400270027
24
+ ]
25
+ }
examples/input_params/output_20250426074413_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "Cuban music, salsa, son, Afro-Cuban, traditional Cuban",
3
+ "lyrics": "[verse]\nSun dips low the night ignites\nBassline hums with gleaming lights\nElectric guitar singing tales so fine\nIn the rhythm we all intertwine\n\n[verse]\nDrums beat steady calling out\nPercussion guides no room for doubt\nElectric pulse through every vein\nDance away every ounce of pain\n\n[chorus]\nFeel the rhythm feel the flow\nLet the music take control\nBassline deep electric hum\nIn this night we're never numb\n\n[bridge]\nStars above they start to glow\nEchoes of the night's soft glow\nElectric strings weave through the air\nIn this moment none compare\n\n[verse]\nHeartbeats sync with every tone\nLost in music never alone\nElectric tales of love and peace\nIn this groove we find release\n\n[chorus]\nFeel the rhythm feel the flow\nLet the music take control\nBassline deep electric hum\nIn this night we're never numb",
4
+ "audio_duration": 208.27997916666666,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.026132583618164062,
19
+ "diffusion": 15.139378070831299,
20
+ "latent2audio": 2.2071540355682373
21
+ },
22
+ "actual_seeds": [
23
+ 3358899399
24
+ ]
25
+ }
examples/input_params/output_20250426075107_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "pop, piano, rap, dark, atmospheric",
3
+ "lyrics": "[verse]\n月光爬上窗 染白冷的床\n心跳的方向 带我入迷惘\n黑夜吞噬光 命运的纸张\n爱是血色霜 邪恶又芬芳\n\n[chorus]\n你是猎人的欲望 我是迷途的小羊\n深陷你眼眸的荒 唐突献出心脏\n我在夜里回荡 是谁给我希望\n黑暗风中飘荡 假装不再受伤\n\n[verse]\n心锁在门外 谁会解开关怀\n温柔的手拍 藏着冷酷杀害\n思绪如尘埃 撞击爱的霹雳\n灵魂的独白 为你沾满血迹\n\n[bridge]\n你是噩梦的歌唱 是灵魂的捆绑\n绝望中带着光 悬崖边的渴望\n心跳被你鼓掌 恶魔也痴痴想\n渐渐没了抵抗 古老诡计流淌\n\n[chorus]\n你是猎人的欲望 我是迷途的小羊\n深陷你眼眸的荒 唐突献出心脏\n我在夜里回荡 是谁给我希望\n黑暗风中飘荡 假装不再受伤\n\n[outro]\n爱如月黑无光 渗进梦的战场\n逃入无声的场 放手或心嚷嚷\n隐秘的极端 爱是极致风浪\n灵魂彻底交偿 你是终极虚妄",
4
+ "audio_duration": 146.91997916666668,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.03876018524169922,
19
+ "diffusion": 15.962624549865723,
20
+ "latent2audio": 1.4594337940216064
21
+ },
22
+ "actual_seeds": [
23
+ 2065110378
24
+ ]
25
+ }
examples/input_params/output_20250426075537_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "surf music",
3
+ "lyrics": "[verse]\nSunshine on the boulevard the beach is calling loud\nWaves are dancing golden sand under a cotton cloud\nElectric heartbeat pounding fast the tide is on our side\nCatch a wave and feel alive we’ll take it for a ride\n\n[verse]\nPalm trees swaying left to right they know where we belong\nFeel the rhythm of the night it keeps us moving strong\nSea spray kisses salty air we’re flying with the breeze\nChampagne states of mind we ride we do just as we please\n\n[chorus]\nWe’re riding waves of life together hand in hand\nWith every beat we chase the beat it’s our own wonderland\nFeel the music take you higher as the shorelines blur\nThis is our world our endless summer as we live and learn\n\n[bridge]\nMoonlight paints the ocean blue reflections in our eyes\nStars align to light our path we’re surfing through the skies\nEvery moment like a song we sing it loud and clear\nEvery day’s a new adventure with you always near\n\n[verse]\nNeon lights and city sounds they blend with ocean views\nWe’re unstoppable tonight no way that we can lose\nDreams are written in the sand they sparkle in the sun\nTogether we’re a masterpiece our story’s just begun\n\n[chorus]\nWe’re riding waves of life together hand in hand\nWith every beat we chase the beat it’s our own wonderland\nFeel the music take you higher as the shorelines blur\nThis is our world our endless summer as we live and learn",
4
+ "audio_duration": 236.55997916666666,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.033666133880615234,
19
+ "diffusion": 16.291455507278442,
20
+ "latent2audio": 2.3726775646209717
21
+ },
22
+ "actual_seeds": [
23
+ 508630535
24
+ ]
25
+ }
examples/input_params/output_20250426075843_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "alternative rock, pop, rock",
3
+ "lyrics": "[verse]\nBright lights flashing in the city sky\nRunning fast and we don't know why\nElectric nights got our hearts on fire\nChasing dreams we'll never tire\n\n[verse]\nGrit in our eyes wind in our hair\nBreaking rules we don't even care\nShouting loud above the crowd\nLiving life like we're unbowed\n\n[chorus]\nRunning wild in the night so free\nFeel the beat pumping endlessly\nHearts collide in the midnight air\nWe belong we don't have a care\n\n[verse]\nPiercing through like a lightning strike\nEvery moment feels like a hike\nDaring bold never backing down\nKings and queens without a crown\n\n[chorus]\nRunning wild in the night so free\nFeel the beat pumping endlessly\nHearts collide in the midnight air\nWe belong we don't have a care\n\n[bridge]\nClose your eyes let your spirit soar\nWe are the ones who wanted more\nBreaking chains of the mundane\nIn this world we'll make our claim",
4
+ "audio_duration": 202.19997916666668,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.02512216567993164,
19
+ "diffusion": 18.860822677612305,
20
+ "latent2audio": 2.0361969470977783
21
+ },
22
+ "actual_seeds": [
23
+ 1255121549
24
+ ]
25
+ }
examples/input_params/output_20250426080234_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "rock, hip - hop, orchestral, bass, drums, electric guitar, piano, synthesizer, violin, viola, cello, fast, energetic, motivational, inspirational, empowering",
3
+ "lyrics": "### **[Intro – Spoken]** \n*\"The streets whisper, their echoes never fade. \nEvery step I take leaves a mark—this ain't just a game.\"* \n\n### **[Hook/Chorus]** \nBorn in the chaos, I weather the storm, \nRising from ashes where warriors are born. \nChains couldn't hold me, the system’s a maze, \nI rewrite the rules, set the city ablaze! \n\n### **[Verse 1]** \nCold nights, empty pockets, dreams laced with fight, \nEvery loss made me sharper, cut deep like a knife. \nThey said I wouldn’t make it, now they watch in despair, \nFrom the curb to the throne, took the pain, made it rare. \nEvery siren’s a melody, every alley holds a tale, \nRose from the shadows, left my name on the trail. \nStreetlights flicker like warnings in the haze, \nBut I move like a phantom, unfazed by the blaze. \n\n### **[Hook/Chorus]** \nBorn in the chaos, I weather the storm, \nRising from ashes where warriors are born. \nChains couldn't hold me, the system’s a maze, \nI rewrite the rules, set the city ablaze! \n\n### **[Verse 2]** \nBarbed wire fences couldn't lock in my mind, \nEvery cage they designed, I left broken behind. \nThey want control, but I’m destined to roam, \nWhere the lost find their voice, where the heart sets the tone. \nSteel and concrete, where the lessons run deep, \nEvery crack in the pavement tells a story of heat. \nBut I rise, undefeated, like a king with no throne, \nWriting scripts in the struggle, my legacy’s stone. \n\n### **[Bridge]** \nFeel the rhythm of the underground roar, \nEvery wound tells a story of the battles before. \nBlood, sweat, and echoes fill the cold midnight, \nBut we move with the fire—unshaken, upright. \n\n### **[Verse 3]** \nNo regrets, no retreat, this game has no pause, \nEvery step that I take is a win for the lost. \nI took lessons from hustlers, wisdom from pain, \nNow the echoes of struggle carve power in my name. \nThey built walls, but I walk through the cracks, \nTurned dirt into gold, never looked back. \nThrough the struggle we rise, through the fire we claim, \nThis is more than just music—it's life in the frame. \n\n### **[Hook/Chorus – Reprise]** \nBorn in the chaos, I weather the storm, \nRising from ashes where warriors are born. \nChains couldn't hold me, the system’s a maze, \nI rewrite the rules, set the city ablaze! \n\n### **[Outro – Spoken]** \n*\"The scars, the struggle, the grind—it’s all part of the rhythm. \nWe never break, we never fold. We rise.\"*",
4
+ "audio_duration": 153.95997916666667,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.04368758201599121,
19
+ "diffusion": 17.16369390487671,
20
+ "latent2audio": 1.5405471324920654
21
+ },
22
+ "actual_seeds": [
23
+ 2659225017
24
+ ]
25
+ }
examples/input_params/output_20250426080407_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "tango finlandés, campanas, disco, dark pop, electro, guitarra clásica, corridos tumba",
3
+ "lyrics": "[inst]",
4
+ "audio_duration": 162.79997916666667,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.011058568954467773,
19
+ "diffusion": 9.924944400787354,
20
+ "latent2audio": 1.6034839153289795
21
+ },
22
+ "actual_seeds": [
23
+ 780297686
24
+ ]
25
+ }
examples/input_params/output_20250426080601_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "Nightclubs, dance parties, workout playlists, radio broadcasts",
3
+ "lyrics": "Burning in motion, set me alight!\nEvery heartbeat turns into a fight!\nCaged in rhythm, chained in time!\nLove’s a battle— You're Mine! You're Mine!",
4
+ "audio_duration": 221.83997916666667,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.012485980987548828,
19
+ "diffusion": 14.345409154891968,
20
+ "latent2audio": 2.174558639526367
21
+ },
22
+ "actual_seeds": [
23
+ 1318394052
24
+ ]
25
+ }
examples/input_params/output_20250426081134_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "melancholic, world, sad, medieval, soulful",
3
+ "lyrics": "[Verse]\nIn a world so grand he roams the skies alone\nHis heart a heavy stone a tale untold\nWhispers of his past echo through the night\nA lonely dragon searching for the light\n\n[Verse 2]\nOnce a mighty force now he drifts in pain\nHis scales once shimmered now they're dark with shame\nCast out by his kin in shadows he does hide\nA haunting sorrow burns deep inside\n\n[Chorus]\nRoaming endless fields with no friend in sight\nHis roar a mournful cry beneath the moon's pale light\nTears fall like stars as he flies on his way\nA lonely dragon yearning for the break of day\n\n[Bridge]\nThe world turns cold the nights grow long\nIn his heart he carries an ancient song\nOf battles fought and love long gone\nA legend now but his soul is torn\n\n[Verse 3]\nHoping for a day he'll find a kindred soul\nTo share his pain and make him whole\nTill then he drifts a shadow in the sky\nA lonely dragon with tears in his eye\n\n[Chorus]\nRoaming endless fields with no friend in sight\nHis roar a mournful cry beneath the moon's pale light\nTears fall like stars as he flies on his way\nA lonely dragon yearning for the break of day",
4
+ "audio_duration": 239.99997916666666,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.029100656509399414,
19
+ "diffusion": 22.503791570663452,
20
+ "latent2audio": 2.3603708744049072
21
+ },
22
+ "actual_seeds": [
23
+ 2166832218
24
+ ]
25
+ }
examples/input_params/output_20250426091716_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "anime, cute female vocals, kawaii pop, j-pop, childish, piano, guitar, synthesizer, fast, happy, cheerful, lighthearted",
3
+ "lyrics": "[Chorus]\nねぇ、顔が赤いよ?\nどうしたの? 熱があるの?\nそれとも怒ってるの?\nねぇ、言ってよ!\n\nどうしてそんな目で見るの?\n私、悪いことした?\n何か間違えたの?\nお願い、やめて… 怖いから…\nだから、やめてよ…\n\n[Bridge]\n目を閉じて、くるっと背を向けて、\n何も見なかったフリするから、\n怒らないで… 許してよ…\n\n[Chorus]\nねぇ、顔が赤いよ?\nどうしたの? 熱があるの?\nそれとも怒ってるの?\nねぇ、言ってよ!\n\nどうしてそんな目で見るの?\n私、悪いことした?\n何か間違えたの?\nお願い、やめて… 怖いから…\nだから、やめてよ…\n\n[Bridge 2]\n待って、もし私が悪いなら、\nごめんなさいって言うから、\nアイスクリームあげるから、\nもう怒らないで?\n\nOoooh… 言ってよ!",
4
+ "audio_duration": 160,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.0282442569732666,
19
+ "diffusion": 12.104875326156616,
20
+ "latent2audio": 1.587641954421997
21
+ },
22
+ "actual_seeds": [
23
+ 4028738662
24
+ ]
25
+ }
examples/input_params/output_20250426092025_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "dark, death rock, metal, hardcore, electric guitar, powerful, bass, drums, 110 bpm, G major",
3
+ "lyrics": "[Verse]\nMy lovers betray me\nThe snake in my garden is hissing\nIn the air is the sweetness of roses\nAnd under my skin\nThere's a thorn\n\n[Verse 2]\nI should have known\nThat God sends his angel in shadows\nWith blood in his veins\nI watch the enemy\nGivin' me the hand of my savior\n\n[Chorus]\nAnd I can't love again\nWith the echo of your name in my head\nWith the demons in my bed\nWith the memories\nYour ghost\nI see it\n'Cause it comes to haunt me\nJust to taunt me\nIt comes to haunt me\nJust to taunt me\n\n[Verse 3]\nWith sugar and spice\nIt's hard to ignore the nostalgia\nWith the men on their knees\nAt the gates of my heart\nHow they beg me\n\n[Verse 4]\nThey say\n\"No one will ever love you\nThe way that I do\nNo one will ever touch you\nThe way that I do\"\n\n[Chorus]\nAnd I can't love again\nWith the echo of your name in my head\nWith the demons in my bed\nWith the memories\nYour ghost\nI see it\n'Cause it comes to haunt me\nJust to taunt me\nIt comes to haunt me\nJust to taunt me",
4
+ "audio_duration": 174.27997916666666,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 3.8372838497161865,
19
+ "diffusion": 13.039669275283813,
20
+ "latent2audio": 1.7923030853271484
21
+ },
22
+ "actual_seeds": [
23
+ 4064916393
24
+ ]
25
+ }
examples/input_params/output_20250426093007_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "aggressive, Heavy Riffs, Blast Beats, Satanic Black Metal",
3
+ "lyrics": "[verse]\nFloating through the galaxy on a midnight ride\nStars are dancing all around in cosmic tides\nFeel the pulse of space and time beneath our feet\nEvery beat a heartbeat in this endless suite\n\n[chorus]\nGalactic dreams under neon lights\nSailing through the velvet nights\nWe are echoes in a cosmic sea\nIn a universe where we are free\n\n[verse]\nPlanetary whispers in the sky tonight\nEvery constellation's got a secret sight\nDistant worlds and moons we have yet to see\nIn the void of space where we can just be\n\n[bridge]\nAsteroids and comets in a ballet they spin\nLost in the rhythm of where our dreams begin\nClose your eyes and let the synths take flight\nWe're voyagers on an electric night\n\n[verse]\nLet the piano keys unlock the stars above\nEvery chord a memory every note is love\nIn this synth symphony we find our grace\nDrifting forever in this boundless space\n\n[chorus]\nGalactic dreams under neon lights\nSailing through the velvet nights\nWe are echoes in a cosmic sea\nIn a universe where we are free",
4
+ "audio_duration": 181.99997916666666,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.025065898895263672,
19
+ "diffusion": 17.176705837249756,
20
+ "latent2audio": 1.8225171566009521
21
+ },
22
+ "actual_seeds": [
23
+ 1132623236
24
+ ]
25
+ }
examples/input_params/output_20250426093146_0_input_params.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "prompt": "r&b, soul, funk/soul",
3
+ "lyrics": "[verse]\nDancing through electric fires\nHeart is buzzing like live wires\nIn your arms I find desire\nFeel the beat as we get higher\n\n[chorus]\nElectric love in the night sky\nWe’re gonna soar baby you and I\nDrop the bass let the rhythm fly\nFeel the heat and don't ask why\n\n[verse]\nWhisper secrets that make me blush\nUnder the neon city hush\nYour touch gives me such a rush\nTurn it up we're feeling lush\n\n[chorus]\nElectric love in the night sky\nWe’re gonna soar baby you and I\nDrop the bass let the rhythm fly\nFeel the heat and don't ask why\n\n[bridge]\nThrough the lights and the smoky haze\nI see you in a thousand ways\nLove's a script and we’re the play\nTurn the page stay till we sway\n\n[chorus]\nElectric love in the night sky\nWe’re gonna soar baby you and I\nDrop the bass let the rhythm fly\nFeel the heat and don't ask why",
4
+ "audio_duration": 195.15997916666666,
5
+ "infer_step": 60,
6
+ "guidance_scale": 15,
7
+ "scheduler_type": "euler",
8
+ "cfg_type": "apg",
9
+ "omega_scale": 10,
10
+ "guidance_interval": 0.5,
11
+ "guidance_interval_decay": 0,
12
+ "min_guidance_scale": 3,
13
+ "use_erg_tag": true,
14
+ "use_erg_lyric": true,
15
+ "use_erg_diffusion": true,
16
+ "oss_steps": [],
17
+ "timecosts": {
18
+ "preprocess": 0.025553464889526367,
19
+ "diffusion": 18.250118494033813,
20
+ "latent2audio": 1.9400627613067627
21
+ },
22
+ "actual_seeds": [
23
+ 2853131993
24
+ ]
25
+ }
language_segmentation/LangSegment.py ADDED
@@ -0,0 +1,866 @@
+ """
+ This file bundles language identification functions.
+
+ Modifications (fork): Copyright (c) 2021, Adrien Barbaresi.
+
+ Original code: Copyright (c) 2011 Marco Lui <saffsd@gmail.com>.
+ Based on research by Marco Lui and Tim Baldwin.
+
+ See LICENSE file for more info.
+ https://github.com/adbar/py3langid
+
+ Projects:
+ https://github.com/juntaosun/LangSegment
+ """
+
+ import os
+ import re
+ import sys
+ import numpy as np
+ from collections import Counter
+ from collections import defaultdict
+
+ # import langid
+ # import py3langid as langid
+ # pip install py3langid==0.2.2
+
+ # Enable normalization of the language-prediction probabilities so that the scores are rescaled to the 0-1 range.
+ # langid disables probability normalization by default; for library use, the user must instantiate their own
+ # LanguageIdentifier with normalization enabled, as done below:
+ from py3langid.langid import LanguageIdentifier, MODEL_FILE
+
+ # Number-to-text processing
+ try:
+     from .utils.num import num2str
+ except ImportError:
+     try:
+         from utils.num import num2str
+     except ImportError as e:
+         raise e
+
+ # -----------------------------------
+ # Changelog: the new version of the word segmentation is more accurate.
+ # -----------------------------------
+
+
+ # Word segmentation function:
+ # automatically identify and split the words (Chinese/English/Japanese/Korean) in an article or sentence according to language,
+ # making it more suitable for TTS processing.
+ # This code is designed for front-end multi-lingual mixed text annotation, multi-language mixed training and inference in various TTS projects.
+ # The result mainly targets (Chinese = zh, Japanese = ja, English = en, Korean = ko), and can actually support up to 97 different languages mixed together.
+ # ===========================================================================================================
+ # (1) Automatic segmentation: "韩语中的오빠读什么呢?あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型"
+ # (2) Manual segmentation: "你的名字叫<ja>佐々木?<ja>吗?"
+ # ===========================================================================================================
+
+
+ # Manual word segmentation tag specification: <language tag>text content</language tag>
+ # ===========================================================================================================
+ # For manual word segmentation, tags need to appear in pairs, for example: "<ja>佐々木<ja>" or "<ja>佐々木</ja>"
+ # Error demonstration: "Your name is <ja>佐々木。" A single <ja> tag appearing in this sentence will be ignored and not processed.
+ # ===========================================================================================================
+
+
+ # ===========================================================================================================
+ # Speech Synthesis Markup Language (SSML): only its tags are supported here (not full XML).
+ # Want to support more SSML tags? PRs are welcome!
+ # Note: besides Chinese, this can also be adapted to support multi-language SSML, not just Chinese.
+ # ===========================================================================================================
+ # Chinese implementation:
+ # 【SSML】<number>    = read digits as Chinese numerals (character by character)
+ # 【SSML】<telephone> = convert digits to a Chinese telephone-number reading (character by character)
+ # 【SSML】<currency>  = read as an amount of money
+ # 【SSML】<date>      = read as a date; supports inputs such as 2024年08月24, 2024/8/24, 2024-08, 08-24, 24
+ # ===========================================================================================================
105
+ class LangSSML:
106
+
107
+ def __init__(self):
108
+ # 纯数字
109
+ self._zh_numerals_number = {
110
+ '0': '零',
111
+ '1': '一',
112
+ '2': '二',
113
+ '3': '三',
114
+ '4': '四',
115
+ '5': '五',
116
+ '6': '六',
117
+ '7': '七',
118
+ '8': '八',
119
+ '9': '九'
120
+ }
121
+
122
+ # 将2024/8/24, 2024-08, 08-24, 24 标准化“年月日”
123
+ # Standardize 2024/8/24, 2024-08, 08-24, 24 to "year-month-day"
124
+ def _format_chinese_data(self, date_str:str):
125
+ # 处理日期格式
126
+ input_date = date_str
127
+ if date_str is None or date_str.strip() == "":return ""
128
+ date_str = re.sub(r"[\/\._|年|月]","-",date_str)
129
+ date_str = re.sub(r"日",r"",date_str)
130
+ date_arrs = date_str.split(' ')
131
+ if len(date_arrs) == 1 and ":" in date_arrs[0]:
132
+ time_str = date_arrs[0]
133
+ date_arrs = []
134
+ else:
135
+ time_str = date_arrs[1] if len(date_arrs) >=2 else ""
136
+ def nonZero(num,cn,func=None):
137
+ if func is not None:num=func(num)
138
+ return f"{num}{cn}" if num is not None and num != "" and num != "0" else ""
139
+ f_number = self.to_chinese_number
140
+ f_currency = self.to_chinese_currency
141
+ # year, month, day
142
+ year_month_day = ""
143
+ if len(date_arrs) > 0:
144
+ year, month, day = "","",""
145
+ parts = date_arrs[0].split('-')
146
+ if len(parts) == 3: # 格式为 YYYY-MM-DD
147
+ year, month, day = parts
148
+ elif len(parts) == 2: # 格式为 MM-DD 或 YYYY-MM
149
+ if len(parts[0]) == 4: # 年-月
150
+ year, month = parts
151
+ else:month, day = parts # 月-日
152
+ elif len(parts[0]) > 0: # 仅有月-日或年
153
+ if len(parts[0]) == 4:
154
+ year = parts[0]
155
+ else:day = parts[0]
156
+ year,month,day = nonZero(year,"年",f_number),nonZero(month,"月",f_currency),nonZero(day,"日",f_currency)
157
+ year_month_day = re.sub(r"([年|月|日])+",r"\1",f"{year}{month}{day}")
158
+ # hours, minutes, seconds
159
+ time_str = re.sub(r"[\/\.\-:_]",":",time_str)
160
+ time_arrs = time_str.split(":")
161
+ hours, minutes, seconds = "","",""
162
+ if len(time_arrs) == 3: # H/M/S
163
+ hours, minutes, seconds = time_arrs
164
+ elif len(time_arrs) == 2:# H/M
165
+ hours, minutes = time_arrs
166
+ elif len(time_arrs[0]) > 0:hours = f'{time_arrs[0]}点' # H
167
+ if len(time_arrs) > 1:
168
+ hours, minutes, seconds = nonZero(hours,"点",f_currency),nonZero(minutes,"分",f_currency),nonZero(seconds,"秒",f_currency)
169
+ hours_minutes_seconds = re.sub(r"([点|分|秒])+",r"\1",f"{hours}{minutes}{seconds}")
170
+ output_date = f"{year_month_day}{hours_minutes_seconds}"
171
+ return output_date
172
+
173
+ # 【SSML】number=中文大写数字读法(单字)
174
+ # Chinese Numbers(single word)
175
+ def to_chinese_number(self, num:str):
176
+ pattern = r'(\d+)'
177
+ zh_numerals = self._zh_numerals_number
178
+ arrs = re.split(pattern, num)
179
+ output = ""
180
+ for item in arrs:
181
+ if re.match(pattern,item):
182
+ output += ''.join(zh_numerals[digit] if digit in zh_numerals else "" for digit in str(item))
183
+ else:output += item
184
+ output = output.replace(".","点")
185
+ return output
186
+
187
+ # 【SSML】telephone=数字转成中文电话号码大写汉字(单字)
188
+ # Convert numbers to Chinese phone numbers in uppercase Chinese characters(single word)
189
+ def to_chinese_telephone(self, num:str):
190
+ output = self.to_chinese_number(num.replace("+86","")) # zh +86
191
+ output = output.replace("一","幺")
192
+ return output
193
+
194
+ # 【SSML】currency=按金额发音。
195
+ # Digital processing from GPT_SoVITS num.py (thanks)
196
+ def to_chinese_currency(self, num:str):
197
+ pattern = r'(\d+)'
198
+ arrs = re.split(pattern, num)
199
+ output = ""
200
+ for item in arrs:
201
+ if re.match(pattern,item):
202
+ output += num2str(item)
203
+ else:output += item
204
+ output = output.replace(".","点")
205
+ return output
206
+
207
+ # 【SSML】date=按日期发音。支持 2024年08月24, 2024/8/24, 2024-08, 08-24, 24 等输入。
208
+ def to_chinese_date(self, num:str):
209
+ chinese_date = self._format_chinese_data(num)
210
+ return chinese_date
211
+
212
+
213
+ class LangSegment:
214
+
215
+ def __init__(self):
216
+
217
+ self.langid = LanguageIdentifier.from_pickled_model(MODEL_FILE, norm_probs=True)
218
+
219
+ self._text_cache = None
220
+ self._text_lasts = None
221
+ self._text_langs = None
222
+ self._lang_count = None
223
+ self._lang_eos = None
224
+
225
+ # 可自定义语言匹配标签:カスタマイズ可能な言語対応タグ:사용자 지정 가능한 언어 일치 태그:
226
+ # Customizable language matching tags: These are supported,이 표현들은 모두 지지합니다
227
+ # <zh>你好<zh> , <ja>佐々木</ja> , <en>OK<en> , <ko>오빠</ko> 这些写法均支持
228
+ self.SYMBOLS_PATTERN = r'(<([a-zA-Z|-]*)>(.*?)<\/*[a-zA-Z|-]*>)'
229
+
230
+ # 语言过滤组功能, 可以指定保留语言。不在过滤组中的语言将被清除。您可随心搭配TTS语音合成所支持的语言。
231
+ # 언어 필터 그룹 기능을 사용하면 예약된 언어를 지정할 수 있습니다. 필터 그룹에 없는 언어는 지워집니다. TTS 텍스트에서 지원하는 언어를 원하는 대로 일치시킬 수 있습니다.
232
+ # 言語フィルターグループ機能では、予約言語を指定できます。フィルターグループに含まれていない言語はクリアされます。TTS音声合成がサポートする言語を自由に組み合わせることができます。
233
+ # The language filter group function allows you to specify reserved languages.
234
+ # Languages not in the filter group will be cleared. You can match the languages supported by TTS Text To Speech as you like.
235
+ # 排名越前,优先级越高,The higher the ranking, the higher the priority,ランキングが上位になるほど、優先度が高くなります。
236
+
237
+ # 系统默认过滤器。System default filter。(ISO 639-1 codes given)
238
+ # ----------------------------------------------------------------------------------------------------------------------------------
239
+ # "zh"中文=Chinese ,"en"英语=English ,"ja"日语=Japanese ,"ko"韩语=Korean ,"fr"法语=French ,"vi"越南语=Vietnamese , "ru"俄语=Russian
240
+ # "th"泰语=Thai
241
+ # ----------------------------------------------------------------------------------------------------------------------------------
242
+ self.DEFAULT_FILTERS = ["zh", "ja", "ko", "en"]
243
+
244
+ # 用户可自定义过滤器。User-defined filters
245
+ self.Langfilters = self.DEFAULT_FILTERS[:] # 创建副本
246
+
247
+ # 合并文本
248
+ self.isLangMerge = True
249
+
250
+ # 试验性支持:您可自定义添加:"fr"法语 , "vi"越南语。Experimental: You can customize to add: "fr" French, "vi" Vietnamese.
251
+ # 请使用API启用:self.setfilters(["zh", "en", "ja", "ko", "fr", "vi" , "ru" , "th"]) # 您可自定义添加,如:"fr"法语 , "vi"越南语。
252
+
253
+ # 预览版功能,自动启用或禁用,无需设置
254
+ # Preview feature, automatically enabled or disabled, no settings required
255
+ self.EnablePreview = False
256
+
257
+ # 除此以外,它支持简写过滤器,只需按不同语种任意组合即可。
258
+ # In addition to that, it supports abbreviation filters, allowing for any combination of different languages.
259
+ # 示例:您可以任意指定多种组合,进行过滤
260
+ # Example: You can specify any combination to filter
261
+
262
+ # 中/日语言优先级阀值(评分范围为 0 ~ 1):评分低于设定阀值 <0.89 时,启用 filters 中的优先级。\n
263
+ # 중/일본어 우선 순위 임계값(점수 범위 0-1): 점수가 설정된 임계값 <0.89보다 낮을 때 필터에서 우선 순위를 활성화합니다.
264
+ # 中国語/日本語の優先度しきい値(スコア範囲0〜1):スコアが設定されたしきい値<0.89未満の場合、フィルターの優先度が有効になります。\n
265
+ # Chinese and Japanese language priority threshold (score range is 0 ~ 1): The default threshold is 0.89. \n
266
+ # Only the common characters between Chinese and Japanese are processed with confidence and priority. \n
267
+ self.LangPriorityThreshold = 0.89
268
+
269
+ # Langfilters = ["zh"] # 按中文识别
270
+ # Langfilters = ["en"] # 按英文识别
271
+ # Langfilters = ["ja"] # 按日文识别
272
+ # Langfilters = ["ko"] # 按韩文识别
273
+ # Langfilters = ["zh_ja"] # 中日混合识别
274
+ # Langfilters = ["zh_en"] # 中英混合识别
275
+ # Langfilters = ["ja_en"] # 日英混合识别
276
+ # Langfilters = ["zh_ko"] # 中韩混合识别
277
+ # Langfilters = ["ja_ko"] # 日韩混合识别
278
+ # Langfilters = ["en_ko"] # 英韩混合识别
279
+ # Langfilters = ["zh_ja_en"] # 中日英混合识别
280
+ # Langfilters = ["zh_ja_en_ko"] # 中日英韩混合识别
281
+
282
+ # 更多过滤组合,请您随意。。。For more filter combinations, please feel free to......
283
+ # より多くのフィルターの組み合わせ、お気軽に。。。더 많은 필터 조합을 원하시면 자유롭게 해주세요. .....
284
+
285
+ # Optional: keep the Chinese digit-pinyin format (inside brackets); this makes it easier for the front end to edit pinyin phonemes and run inference. Disabled (False) by default.
286
+ # When set to True, bracketed digit-pinyin text is kept and is recognized/output as "zh" Chinese.
287
+ self.keepPinyin = False
288
+
289
+ # DEFINITION
290
+ self.PARSE_TAG = re.compile(r'(⑥\$*\d+[\d]{6,}⑥)')
291
+
292
+ self.LangSSML = LangSSML()
293
+
294
+ def _clears(self):
295
+ self._text_cache = None
296
+ self._text_lasts = None
297
+ self._text_langs = None
298
+ self._text_waits = None
299
+ self._lang_count = None
300
+ self._lang_eos = None
301
+
302
+ def _is_english_word(self, word):
303
+ return bool(re.match(r'^[a-zA-Z]+$', word))
304
+
305
+ def _is_chinese(self, word):
306
+ for char in word:
307
+ if '\u4e00' <= char <= '\u9fff':
308
+ return True
309
+ return False
310
+
311
+ def _is_japanese_kana(self, word):
312
+ pattern = re.compile(r'[\u3040-\u309F\u30A0-\u30FF]+')
313
+ matches = pattern.findall(word)
314
+ return len(matches) > 0
315
+
316
+ def _insert_english_uppercase(self, word):
317
+ modified_text = re.sub(r'(?<!\b)([A-Z])', r' \1', word)
318
+ modified_text = modified_text.strip('-')
319
+ return modified_text + " "
320
+
321
+ def _split_camel_case(self, word):
322
+ return re.sub(r'(?<!^)(?=[A-Z])', ' ', word)
323
+
324
+ def _statistics(self, language, text):
325
+ # Language word statistics:
326
+ # Chinese characters usually occupy double bytes
327
+ if self._lang_count is None or not isinstance(self._lang_count, defaultdict):
328
+ self._lang_count = defaultdict(int)
329
+ lang_count = self._lang_count
330
+ if not "|" in language:
331
+ lang_count[language] += int(len(text)*2) if language == "zh" else len(text)
332
+ self._lang_count = lang_count
333
+
334
+ def _clear_text_number(self, text):
335
+ if text == "\n":return text,False # Keep Line Breaks
336
+ clear_text = re.sub(r'([^\w\s]+)','',re.sub(r'\n+','',text)).strip()
337
+ is_number = len(re.sub(re.compile(r'(\d+)'),'',clear_text)) == 0
338
+ return clear_text,is_number
339
+
340
+ def _saveData(self, words,language:str,text:str,score:float,symbol=None):
341
+ # Pre-detection
342
+ clear_text , is_number = self._clear_text_number(text)
343
+ # Merge the same language and save the results
344
+ preData = words[-1] if len(words) > 0 else None
345
+ if symbol is not None:pass
346
+ elif preData is not None and preData["symbol"] is None:
347
+ if len(clear_text) == 0:language = preData["lang"]
348
+ elif is_number == True:language = preData["lang"]
349
+ _ , pre_is_number = self._clear_text_number(preData["text"])
350
+ if (preData["lang"] == language):
351
+ self._statistics(preData["lang"],text)
352
+ text = preData["text"] + text
353
+ preData["text"] = text
354
+ return preData
355
+ elif pre_is_number == True:
356
+ text = f'{preData["text"]}{text}'
357
+ words.pop()
358
+ elif is_number == True:
359
+ priority_language = self._get_filters_string()[:2]
360
+ if priority_language in "ja-zh-en-ko-fr-vi":language = priority_language
361
+ data = {"lang":language,"text": text,"score":score,"symbol":symbol}
362
+ filters = self.Langfilters
363
+ if filters is None or len(filters) == 0 or "?" in language or \
364
+ language in filters or language in filters[0] or \
365
+ filters[0] == "*" or filters[0] in "alls-mixs-autos":
366
+ words.append(data)
367
+ self._statistics(data["lang"],data["text"])
368
+ return data
369
+
370
+ def _addwords(self, words,language,text,score,symbol=None):
371
+ if text == "\n":pass # Keep Line Breaks
372
+ elif text is None or len(text.strip()) == 0:return True
373
+ if language is None:language = ""
374
+ language = language.lower()
375
+ if language == 'en':text = self._insert_english_uppercase(text)
376
+ # text = re.sub(r'[(())]', ',' , text) # Keep it.
377
+ text_waits = self._text_waits
378
+ ispre_waits = len(text_waits)>0
379
+ preResult = text_waits.pop() if ispre_waits else None
380
+ if preResult is None:preResult = words[-1] if len(words) > 0 else None
381
+ if preResult and ("|" in preResult["lang"]):
382
+ pre_lang = preResult["lang"]
383
+ if language in pre_lang:preResult["lang"] = language = language.split("|")[0]
384
+ else:preResult["lang"]=pre_lang.split("|")[0]
385
+ if ispre_waits:preResult = self._saveData(words,preResult["lang"],preResult["text"],preResult["score"],preResult["symbol"])
386
+ pre_lang = preResult["lang"] if preResult else None
387
+ if ("|" in language) and (pre_lang and not pre_lang in language and not "…" in language):language = language.split("|")[0]
388
+ if "|" in language:self._text_waits.append({"lang":language,"text": text,"score":score,"symbol":symbol})
389
+ else:self._saveData(words,language,text,score,symbol)
390
+ return False
391
+
392
+ def _get_prev_data(self, words):
393
+ data = words[-1] if words and len(words) > 0 else None
394
+ if data:return (data["lang"] , data["text"])
395
+ return (None,"")
396
+
397
+ def _match_ending(self, input , index):
398
+ if input is None or len(input) == 0:return False,None
399
+ input = re.sub(r'\s+', '', input)
400
+ if len(input) == 0 or abs(index) > len(input):return False,None
401
+ ending_pattern = re.compile(r'([「」“”‘’"\'::。.!!?.?])')
402
+ return ending_pattern.match(input[index]),input[index]
403
+
404
+ def _cleans_text(self, cleans_text):
405
+ cleans_text = re.sub(r'(.*?)([^\w]+)', r'\1 ', cleans_text)
406
+ cleans_text = re.sub(r'(.)\1+', r'\1', cleans_text)
407
+ return cleans_text.strip()
408
+
409
+ def _mean_processing(self, text:str):
410
+ if text is None or (text.strip()) == "":return None , 0.0
411
+ arrs = self._split_camel_case(text).split(" ")
412
+ langs = []
413
+ for t in arrs:
414
+ if len(t.strip()) <= 3:continue
415
+ language, score = self.langid.classify(t)
416
+ langs.append({"lang":language})
417
+ if len(langs) == 0:return None , 0.0
418
+ return Counter([item['lang'] for item in langs]).most_common(1)[0][0],1.0
419
+
420
+ def _lang_classify(self, cleans_text):
421
+ language, score = self.langid.classify(cleans_text)
422
+ # fix: Huggingface is np.float32
423
+ if score is not None and isinstance(score, np.generic) and hasattr(score,"item"):
424
+ score = score.item()
425
+ score = round(score , 3)
426
+ return language, score
427
+
428
+ def _get_filters_string(self):
429
+ filters = self.Langfilters
430
+ return "-".join(filters).lower().strip() if filters is not None else ""
431
+
432
+ def _parse_language(self, words , segment):
433
+ LANG_JA = "ja"
434
+ LANG_ZH = "zh"
435
+ LANG_ZH_JA = f'{LANG_ZH}|{LANG_JA}'
436
+ LANG_JA_ZH = f'{LANG_JA}|{LANG_ZH}'
437
+ language = LANG_ZH
438
+ regex_pattern = re.compile(r'([^\w\s]+)')
439
+ lines = regex_pattern.split(segment)
440
+ lines_max = len(lines)
441
+ LANG_EOS =self._lang_eos
442
+ for index, text in enumerate(lines):
443
+ if len(text) == 0:continue
444
+ EOS = index >= (lines_max - 1)
445
+ nextId = index + 1
446
+ nextText = lines[nextId] if not EOS else ""
447
+ nextPunc = len(re.sub(regex_pattern,'',re.sub(r'\n+','',nextText)).strip()) == 0
448
+ textPunc = len(re.sub(regex_pattern,'',re.sub(r'\n+','',text)).strip()) == 0
449
+ if not EOS and (textPunc == True or ( len(nextText.strip()) >= 0 and nextPunc == True)):
450
+ lines[nextId] = f'{text}{nextText}'
451
+ continue
452
+ number_tags = re.compile(r'(⑥\d{6,}⑥)')
453
+ cleans_text = re.sub(number_tags, '' ,text)
454
+ cleans_text = re.sub(r'\d+', '' ,cleans_text)
455
+ cleans_text = self._cleans_text(cleans_text)
456
+ # fix:Langid's recognition of short sentences is inaccurate, and it is spliced longer.
457
+ if not EOS and len(cleans_text) <= 2:
458
+ lines[nextId] = f'{text}{nextText}'
459
+ continue
460
+ language,score = self._lang_classify(cleans_text)
461
+ prev_language , prev_text = self._get_prev_data(words)
462
+ if language != LANG_ZH and all('\u4e00' <= c <= '\u9fff' for c in re.sub(r'\s','',cleans_text)):language,score = LANG_ZH,1
463
+ if len(cleans_text) <= 5 and self._is_chinese(cleans_text):
464
+ filters_string = self._get_filters_string()
465
+ if score < self.LangPriorityThreshold and len(filters_string) > 0:
466
+ index_ja , index_zh = filters_string.find(LANG_JA) , filters_string.find(LANG_ZH)
467
+ if index_ja != -1 and index_ja < index_zh:language = LANG_JA
468
+ elif index_zh != -1 and index_zh < index_ja:language = LANG_ZH
469
+ if self._is_japanese_kana(cleans_text):language = LANG_JA
470
+ elif len(cleans_text) > 2 and score > 0.90:pass
471
+ elif EOS and LANG_EOS:language = LANG_ZH if len(cleans_text) <= 1 else language
472
+ else:
473
+ LANG_UNKNOWN = LANG_ZH_JA if language == LANG_ZH or (len(cleans_text) <=2 and prev_language == LANG_ZH) else LANG_JA_ZH
474
+ match_end,match_char = self._match_ending(text, -1)
475
+ referen = prev_language in LANG_UNKNOWN or LANG_UNKNOWN in prev_language if prev_language else False
476
+ if match_char in "。.": language = prev_language if referen and len(words) > 0 else language
477
+ else:language = f"{LANG_UNKNOWN}|…"
478
+ text,*_ = re.subn(number_tags , self._restore_number , text )
479
+ self._addwords(words,language,text,score)
480
+
481
+ # ----------------------------------------------------------
482
+ # 【SSML】中文数字处理:Chinese Number Processing (SSML support)
483
+ # 这里默认都是中文,用于处理 SSML 中文标签。当然可以支持任意语言,例如:
484
+ # The default here is Chinese, which is used to process SSML Chinese tags. Of course, any language can be supported, for example:
485
+ # 中文电话号码:<telephone>1234567</telephone>
486
+ # 中文数字号码:<number>1234567</number>
487
+ def _process_symbol_SSML(self, words,data):
488
+ tag , match = data
489
+ language = SSML = match[1]
490
+ text = match[2]
491
+ score = 1.0
492
+ if SSML == "telephone":
493
+ # Chinese - telephone number reading
494
+ language = "zh"
495
+ text = self.LangSSML.to_chinese_telephone(text)
496
+ elif SSML == "number":
497
+ # Chinese - digit-by-digit number reading
498
+ language = "zh"
499
+ text = self.LangSSML.to_chinese_number(text)
500
+ elif SSML == "currency":
501
+ # Chinese - read as a monetary amount
502
+ language = "zh"
503
+ text = self.LangSSML.to_chinese_currency(text)
504
+ elif SSML == "date":
505
+ # Chinese - read as a date
506
+ language = "zh"
507
+ text = self.LangSSML.to_chinese_date(text)
508
+ self._addwords(words,language,text,score,SSML)
509
+
510
+ # ----------------------------------------------------------
511
+ def _restore_number(self, match):
512
+ value = match.group(0)
513
+ text_cache = self._text_cache
514
+ if value in text_cache:
515
+ process , data = text_cache[value]
516
+ tag , match = data
517
+ value = match
518
+ return value
519
+
520
+ def _pattern_symbols(self, item , text):
521
+ if text is None:return text
522
+ tag , pattern , process = item
523
+ matches = pattern.findall(text)
524
+ if len(matches) == 1 and "".join(matches[0]) == text:
525
+ return text
526
+ for i , match in enumerate(matches):
527
+ key = f"⑥{tag}{i:06d}⑥"
528
+ text = re.sub(pattern , key , text , count=1)
529
+ self._text_cache[key] = (process , (tag , match))
530
+ return text
531
+
532
+ def _process_symbol(self, words,data):
533
+ tag , match = data
534
+ language = match[1]
535
+ text = match[2]
536
+ score = 1.0
537
+ filters = self._get_filters_string()
538
+ if language not in filters:
539
+ self._process_symbol_SSML(words,data)
540
+ else:
541
+ self._addwords(words,language,text,score,True)
542
+
543
+ def _process_english(self, words,data):
544
+ tag , match = data
545
+ text = match[0]
546
+ filters = self._get_filters_string()
547
+ priority_language = filters[:2]
548
+ # Preview feature, other language segmentation processing
549
+ enablePreview = self.EnablePreview
550
+ if enablePreview == True:
551
+ # Experimental: Other language support
552
+ regex_pattern = re.compile(r'(.*?[。.??!!]+[\n]{,1})')
553
+ lines = regex_pattern.split(text)
554
+ for index , text in enumerate(lines):
555
+ if len(text.strip()) == 0:continue
556
+ cleans_text = self._cleans_text(text)
557
+ language,score = self._lang_classify(cleans_text)
558
+ if language not in filters:
559
+ language,score = self._mean_processing(cleans_text)
560
+ if language is None or score <= 0.0:continue
561
+ elif language in filters:pass # pass
562
+ elif score >= 0.95:continue # High score, but not in the filter, excluded.
563
+ elif score <= 0.15 and filters[:2] == "fr":language = priority_language
564
+ else:language = "en"
565
+ self._addwords(words,language,text,score)
566
+ else:
567
+ # Default is English
568
+ language, score = "en", 1.0
569
+ self._addwords(words,language,text,score)
570
+
571
+ def _process_Russian(self, words,data):
572
+ tag , match = data
573
+ text = match[0]
574
+ language = "ru"
575
+ score = 1.0
576
+ self._addwords(words,language,text,score)
577
+
578
+ def _process_Thai(self, words,data):
579
+ tag , match = data
580
+ text = match[0]
581
+ language = "th"
582
+ score = 1.0
583
+ self._addwords(words,language,text,score)
584
+
585
+ def _process_korean(self, words,data):
586
+ tag , match = data
587
+ text = match[0]
588
+ language = "ko"
589
+ score = 1.0
590
+ self._addwords(words,language,text,score)
591
+
592
+ def _process_quotes(self, words,data):
593
+ tag , match = data
594
+ text = "".join(match)
595
+ childs = self.PARSE_TAG.findall(text)
596
+ if len(childs) > 0:
597
+ self._process_tags(words , text , False)
598
+ else:
599
+ cleans_text = self._cleans_text(match[1])
600
+ if len(cleans_text) <= 5:
601
+ self._parse_language(words,text)
602
+ else:
603
+ language,score = self._lang_classify(cleans_text)
604
+ self._addwords(words,language,text,score)
605
+
606
+ def _process_pinyin(self, words,data):
607
+ tag , match = data
608
+ text = match
609
+ language = "zh"
610
+ score = 1.0
611
+ self._addwords(words,language,text,score)
612
+
613
+ def _process_number(self, words,data): # "$0" process only
614
+ """
615
+ Numbers alone cannot accurately identify language.
616
+ Because numbers are universal in all languages.
617
+ So it won't be executed here, just for testing.
618
+ """
619
+ tag , match = data
620
+ language = words[0]["lang"] if len(words) > 0 else "zh"
621
+ text = match
622
+ score = 0.0
623
+ self._addwords(words,language,text,score)
624
+
625
+ def _process_tags(self, words , text , root_tag):
626
+ text_cache = self._text_cache
627
+ segments = re.split(self.PARSE_TAG, text)
628
+ segments_len = len(segments) - 1
629
+ for index , text in enumerate(segments):
630
+ if root_tag:self._lang_eos = index >= segments_len
631
+ if self.PARSE_TAG.match(text):
632
+ process , data = text_cache[text]
633
+ if process:process(words , data)
634
+ else:
635
+ self._parse_language(words , text)
636
+ return words
637
+
638
+ def _merge_results(self, words):
639
+ new_word = []
640
+ for index , cur_data in enumerate(words):
641
+ if "symbol" in cur_data:del cur_data["symbol"]
642
+ if index == 0:new_word.append(cur_data)
643
+ else:
644
+ pre_data = new_word[-1]
645
+ if cur_data["lang"] == pre_data["lang"]:
646
+ pre_data["text"] = f'{pre_data["text"]}{cur_data["text"]}'
647
+ else:new_word.append(cur_data)
648
+ return new_word
649
+
650
+ def _parse_symbols(self, text):
651
+ TAG_NUM = "00" # "00" => default channels , "$0" => testing channel
652
+ TAG_S1,TAG_S2,TAG_P1,TAG_P2,TAG_EN,TAG_KO,TAG_RU,TAG_TH = "$1" ,"$2" ,"$3" ,"$4" ,"$5" ,"$6" ,"$7","$8"
653
+ TAG_BASE = re.compile(fr'(([【《((“‘"\']*[LANGUAGE]+[\W\s]*)+)')
654
+ # Get custom language filter
655
+ filters = self.Langfilters
656
+ filters = filters if filters is not None else ""
657
+ # =======================================================================================================
658
+ # Experimental: Other language support.Thử nghiệm: Hỗ trợ ngôn ngữ khác.Expérimental : prise en charge d’autres langues.
659
+ # 相关语言字符如有缺失,熟悉相关语言的朋友,可以提交把缺失的发音符号补全。
660
+ # If relevant language characters are missing, friends who are familiar with the relevant languages can submit a submission to complete the missing pronunciation symbols.
661
+ # S'il manque des caractères linguistiques pertinents, les amis qui connaissent les langues concernées peuvent soumettre une soumission pour compléter les symboles de prononciation manquants.
662
+ # Nếu thiếu ký tự ngôn ngữ liên quan, những người bạn quen thuộc với ngôn ngữ liên quan có thể gửi bài để hoàn thành các ký hiệu phát âm còn thiếu.
663
+ # -------------------------------------------------------------------------------------------------------
664
+ # Preview feature, other language support
665
+ enablePreview = self.EnablePreview
666
+ if "fr" in filters or \
667
+ "vi" in filters:enablePreview = True
668
+ self.EnablePreview = enablePreview
669
+ # 实验性:法语字符支持。Prise en charge des caractères français
670
+ RE_FR = "" if not enablePreview else "àáâãäåæçèéêëìíîïðñòóôõöùúûüýþÿ"
671
+ # 实验性:越南语字符支持。Hỗ trợ ký tự tiếng Việt
672
+ RE_VI = "" if not enablePreview else "đơưăáàảãạắằẳẵặấầẩẫậéèẻẽẹếềểễệíìỉĩịóòỏõọốồổỗộớờởỡợúùủũụứừửữựôâêơưỷỹ"
673
+ # -------------------------------------------------------------------------------------------------------
674
+ # Basic options:
675
+ process_list = [
676
+ ( TAG_S1 , re.compile(self.SYMBOLS_PATTERN) , self._process_symbol ), # Symbol Tag
677
+ ( TAG_KO , re.compile(re.sub(r'LANGUAGE',f'\uac00-\ud7a3',TAG_BASE.pattern)) , self._process_korean ), # Korean words
678
+ ( TAG_TH , re.compile(re.sub(r'LANGUAGE',f'\u0E00-\u0E7F',TAG_BASE.pattern)) , self._process_Thai ), # Thai words support.
679
+ ( TAG_RU , re.compile(re.sub(r'LANGUAGE',f'А-Яа-яЁё',TAG_BASE.pattern)) , self._process_Russian ), # Russian words support.
680
+ ( TAG_NUM , re.compile(r'(\W*\d+\W+\d*\W*\d*)') , self._process_number ), # Number words, Universal in all languages, Ignore it.
681
+ ( TAG_EN , re.compile(re.sub(r'LANGUAGE',f'a-zA-Z{RE_FR}{RE_VI}',TAG_BASE.pattern)) , self._process_english ), # English words + Other language support.
682
+ ( TAG_P1 , re.compile(r'(["\'])(.*?)(\1)') , self._process_quotes ), # Regular quotes
683
+ ( TAG_P2 , re.compile(r'([\n]*[【《((“‘])([^【《((“‘’”))》】]{3,})([’”))》】][\W\s]*[\n]{,1})') , self._process_quotes ), # Special quotes, There are left and right.
684
+ ]
685
+ # Extended options: Default False
686
+ if self.keepPinyin == True:process_list.insert(1 ,
687
+ ( TAG_S2 , re.compile(r'([\(({](?:\s*\w*\d\w*\s*)+[})\)])') , self._process_pinyin ), # Chinese Pinyin Tag.
688
+ )
689
+ # -------------------------------------------------------------------------------------------------------
690
+ words = []
691
+ lines = re.findall(r'.*\n*', re.sub(self.PARSE_TAG, '' ,text))
692
+ for index , text in enumerate(lines):
693
+ if len(text.strip()) == 0:continue
694
+ self._lang_eos = False
695
+ self._text_cache = {}
696
+ for item in process_list:
697
+ text = self._pattern_symbols(item , text)
698
+ cur_word = self._process_tags([] , text , True)
699
+ if len(cur_word) == 0:continue
700
+ cur_data = cur_word[0] if len(cur_word) > 0 else None
701
+ pre_data = words[-1] if len(words) > 0 else None
702
+ if cur_data and pre_data and cur_data["lang"] == pre_data["lang"] \
703
+ and cur_data["symbol"] == False and pre_data["symbol"] :
704
+ cur_data["text"] = f'{pre_data["text"]}{cur_data["text"]}'
705
+ words.pop()
706
+ words += cur_word
707
+ if self.isLangMerge == True:words = self._merge_results(words)
708
+ lang_count = self._lang_count
709
+ if lang_count and len(lang_count) > 0:
710
+ lang_count = dict(sorted(lang_count.items(), key=lambda x: x[1], reverse=True))
711
+ lang_count = list(lang_count.items())
712
+ self._lang_count = lang_count
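A rough usage sketch of the LangSSML helpers above (illustrative only, not part of the committed file); the exact readings depend on the numeral tables defined earlier in this module:

ssml = LangSSML()
print(ssml.to_chinese_number("305"))             # digit-by-digit reading, e.g. 三零五
print(ssml.to_chinese_telephone("13512345678"))  # telephone reading: "一" becomes "幺"
print(ssml.to_chinese_currency("35.5"))          # amount reading via num2str, "." read as "点"
print(ssml.to_chinese_date("2024-08-24"))        # date reading via _format_chinese_data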
713
+ return words
714
+
715
+ def setfilters(self, filters):
716
+ # 当过滤器更改时,清除缓存
717
+ # 필터가 변경되면 캐시를 지웁니다.
718
+ # フィルタが変更されると、キャッシュがクリアされます
719
+ # When the filter changes, clear the cache
720
+ if self.Langfilters != filters:
721
+ self._clears()
722
+ self.Langfilters = filters
723
+
724
+ def getfilters(self):
725
+ return self.Langfilters
726
+
727
+ def setPriorityThreshold(self, threshold:float):
728
+ self.LangPriorityThreshold = threshold
729
+
730
+ def getPriorityThreshold(self):
731
+ return self.LangPriorityThreshold
732
+
733
+ def getCounts(self):
734
+ lang_count = self._lang_count
735
+ if lang_count is not None:return lang_count
736
+ text_langs = self._text_langs
737
+ if text_langs is None or len(text_langs) == 0:return [("zh",0)]
738
+ lang_counts = defaultdict(int)
739
+ for d in text_langs:lang_counts[d['lang']] += int(len(d['text'])*2) if d['lang'] == "zh" else len(d['text'])
740
+ lang_counts = dict(sorted(lang_counts.items(), key=lambda x: x[1], reverse=True))
741
+ lang_counts = list(lang_counts.items())
742
+ self._lang_count = lang_counts
743
+ return lang_counts
744
+
745
+ def getTexts(self, text:str):
746
+ if text is None or len(text.strip()) == 0:
747
+ self._clears()
748
+ return []
749
+ # lasts
750
+ text_langs = self._text_langs
751
+ if self._text_lasts == text and text_langs is not None:return text_langs
752
+ # parse
753
+ self._text_waits = []
754
+ self._lang_count = None
755
+ self._text_lasts = text
756
+ text = self._parse_symbols(text)
757
+ self._text_langs = text
758
+ return text
759
+
760
+ def classify(self, text:str):
761
+ return self.getTexts(text)
762
+
763
+ def printList(langlist):
764
+ """
765
+ 功能:打印数组结果
766
+ 기능: 어레이 결과 인쇄
767
+ 機能:配列結果を印刷
768
+ Function: Print array results
769
+ """
770
+ print("\n===================【打印结果】===================")
771
+ if langlist is None or len(langlist) == 0:
772
+ print("无内容结果,No content result")
773
+ return
774
+ for line in langlist:
775
+ print(line)
776
+ pass
777
+
778
+
779
+
780
+ def main():
781
+
782
+ # -----------------------------------
783
+ # 更新日志:新版本分词更加精准。
784
+ # Changelog: The new version of the word segmentation is more accurate.
785
+ # チェンジログ:新しいバージョンの単語セグメンテーションはより正確です。
786
+ # Changelog: 분할이라는 단어의 새로운 버전이 더 정확합니다.
787
+ # -----------------------------------
788
+
789
+ # 输入示例1:(包含日文,中文)Input Example 1: (including Japanese, Chinese)
790
+ # text = "“昨日は雨が降った,音楽、映画。。。”你今天学习日语了吗?春は桜の季節です。语种分词是语音合成必不可少的环节。言語分詞は音声合成に欠かせない環節である!"
791
+
792
+ # 输入示例2:(包含日文,中文)Input Example 2: (including Japanese, Chinese)
793
+ # text = "欢迎来玩。東京,は日本の首都です。欢迎来玩. 太好了!"
794
+
795
+ # 输入示例3:(包含日文,中文)Input Example 3: (including Japanese, Chinese)
796
+ # text = "明日、私たちは海辺にバカンスに行きます。你会说日语吗:“中国語、話せますか” 你的日语真好啊!"
797
+
798
+
799
+ # 输入示例4:(包含日文,中文,韩语,英文)Input Example 4: (including Japanese, Chinese, Korean, English)
800
+ # text = "你的名字叫<ja>佐々木?<ja>吗?韩语中的안녕 오빠读什么呢?あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型和三款Apple Watch等一系列新品,这次的iPad Air采用了LCD屏幕"
801
+
802
+
803
+ # 试验性支持:"fr"法语 , "vi"越南语 , "ru"俄语 , "th"泰语。Experimental: Other language support.
804
+ langsegment = LangSegment()
805
+ langsegment.setfilters(["fr", "vi" , "ja", "zh", "ko", "en" , "ru" , "th"])
806
+ text = """
807
+ 我喜欢在雨天里听音乐。
808
+ I enjoy listening to music on rainy days.
809
+ 雨の日に音楽を聴くのが好きです。
810
+ 비 오는 날에 음악을 듣는 것을 즐깁니다。
811
+ J'aime écouter de la musique les jours de pluie.
812
+ Tôi thích nghe nhạc vào những ngày mưa.
813
+ Мне нравится слушать музыку в дождливую погоду.
814
+ ฉันชอบฟังเพลงในวันที่ฝนตก
815
+ """
816
+
817
+
818
+
819
+ # 进行分词:(接入TTS项目仅需一行代码调用)Segmentation: integrating with a TTS project takes only a single call
820
+ langlist = langsegment.getTexts(text)
821
+ printList(langlist)
822
+
823
+
824
+ # 语种统计:Language statistics:
825
+ print("\n===================【语种统计】===================")
826
+ # 获取所有语种数组结果,根据内容字数降序排列
827
+ # Get the array results in all languages, sorted in descending order according to the number of content words
828
+ langCounts = langsegment.getCounts()
829
+ print(langCounts , "\n")
830
+
831
+ # 根据结果获取内容的主要语种 (语言,字数含标点)
832
+ # Get the main language of content based on the results (language, word count including punctuation)
833
+ lang , count = langCounts[0]
834
+ print(f"输入内容的主要语言为 = {lang} ,字数 = {count}")
835
+ print("==================================================\n")
836
+
837
+
838
+ # 分词输出:lang=语言,text=内容。Word output: lang = language, text = content
839
+ # ===================【打印结果】===================
840
+ # {'lang': 'zh', 'text': '你的名字叫'}
841
+ # {'lang': 'ja', 'text': '佐々木?'}
842
+ # {'lang': 'zh', 'text': '吗?韩语中的'}
843
+ # {'lang': 'ko', 'text': '안녕 오빠'}
844
+ # {'lang': 'zh', 'text': '读什么呢?'}
845
+ # {'lang': 'ja', 'text': 'あなたの体育の先生は誰ですか?'}
846
+ # {'lang': 'zh', 'text': ' 此次发布会带来了四款'}
847
+ # {'lang': 'en', 'text': 'i Phone '}
848
+ # {'lang': 'zh', 'text': '15系列机型和三款'}
849
+ # {'lang': 'en', 'text': 'Apple Watch '}
850
+ # {'lang': 'zh', 'text': '等一系列新品,这次的'}
851
+ # {'lang': 'en', 'text': 'i Pad Air '}
852
+ # {'lang': 'zh', 'text': '采用了'}
853
+ # {'lang': 'en', 'text': 'L C D '}
854
+ # {'lang': 'zh', 'text': '屏幕'}
855
+ # ===================【语种统计】===================
856
+
857
+ # ===================【语种统计】===================
858
+ # [('zh', 51), ('ja', 19), ('en', 18), ('ko', 5)]
859
+
860
+ # 输入内容的主要语言为 = zh ,字数 = 51
861
+ # ==================================================
862
+ # The main language of the input content is = zh, word count = 51
863
+
864
+
865
+ if __name__ == "__main__":
866
+ main()
language_segmentation/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .LangSegment import LangSegment
2
+
3
+
4
+ # release
5
+ __version__ = '0.3.5'
6
+
7
+
8
+ # develop
9
+ __develop__ = 'dev-0.0.1'
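Given the package __init__ above, importing from outside the package would presumably look like this (a small sketch, not from the repository):

from language_segmentation import LangSegment

segmenter = LangSegment()
for seg in segmenter.getTexts("你好,hello,こんにちは"):
    print(seg["lang"], seg["text"])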
language_segmentation/utils/__init__.py ADDED
File without changes
language_segmentation/utils/num.py ADDED
@@ -0,0 +1,327 @@
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Digital processing from GPT_SoVITS num.py (thanks)
15
+ """
16
+ Rules to verbalize numbers into Chinese characters.
17
+ https://zh.wikipedia.org/wiki/中文数字#現代中文
18
+ """
19
+
20
+ import re
21
+ from collections import OrderedDict
22
+ from typing import List
23
+
24
+ DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')}
25
+ UNITS = OrderedDict({
26
+ 1: '十',
27
+ 2: '百',
28
+ 3: '千',
29
+ 4: '万',
30
+ 8: '亿',
31
+ })
32
+
33
+ COM_QUANTIFIERS = '(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)'
34
+
35
+ # Fraction expressions
36
+ RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
37
+
38
+
39
+ def replace_frac(match) -> str:
40
+ """
41
+ Args:
42
+ match (re.Match)
43
+ Returns:
44
+ str
45
+ """
46
+ sign = match.group(1)
47
+ nominator = match.group(2)
48
+ denominator = match.group(3)
49
+ sign: str = "负" if sign else ""
50
+ nominator: str = num2str(nominator)
51
+ denominator: str = num2str(denominator)
52
+ result = f"{sign}{denominator}分之{nominator}"
53
+ return result
54
+
55
+
56
+ # Percentage expressions
57
+ RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')
58
+
59
+
60
+ def replace_percentage(match) -> str:
61
+ """
62
+ Args:
63
+ match (re.Match)
64
+ Returns:
65
+ str
66
+ """
67
+ sign = match.group(1)
68
+ percent = match.group(2)
69
+ sign: str = "负" if sign else ""
70
+ percent: str = num2str(percent)
71
+ result = f"{sign}百分之{percent}"
72
+ return result
73
+
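A quick sanity-check sketch for the two rules above (illustrative only, not part of the committed file):

print(RE_FRAC.sub(replace_frac, "1/3"))              # 三分之一
print(RE_PERCENTAGE.sub(replace_percentage, "-5%"))  # 负百分之五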
74
+
75
+ # Integer expressions
76
+ # Signed (negative) integers, e.g. -10
77
+ RE_INTEGER = re.compile(r'(-)' r'(\d+)')
78
+
79
+
80
+ def replace_negative_num(match) -> str:
81
+ """
82
+ Args:
83
+ match (re.Match)
84
+ Returns:
85
+ str
86
+ """
87
+ sign = match.group(1)
88
+ number = match.group(2)
89
+ sign: str = "负" if sign else ""
90
+ number: str = num2str(number)
91
+ result = f"{sign}{number}"
92
+ return result
93
+
94
+
95
+ # ID numbers - unsigned integers
96
+ # 00078
97
+ RE_DEFAULT_NUM = re.compile(r'\d{3}\d*')
98
+
99
+
100
+ def replace_default_num(match):
101
+ """
102
+ Args:
103
+ match (re.Match)
104
+ Returns:
105
+ str
106
+ """
107
+ number = match.group(0)
108
+ return verbalize_digit(number, alt_one=True)
109
+
110
+
111
+ # Basic arithmetic: addition, subtraction, multiplication, division
112
+ # RE_ASMD = re.compile(
113
+ # r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
114
+ RE_ASMD = re.compile(
115
+ r'((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))')
116
+
117
+ asmd_map = {
118
+ '+': '加',
119
+ '-': '减',
120
+ '×': '乘',
121
+ '÷': '除',
122
+ '=': '等于'
123
+ }
124
+
125
+ def replace_asmd(match) -> str:
126
+ """
127
+ Args:
128
+ match (re.Match)
129
+ Returns:
130
+ str
131
+ """
132
+ result = match.group(1) + asmd_map[match.group(8)] + match.group(9)
133
+ return result
134
+
135
+
136
+ # Superscript powers
137
+ RE_POWER = re.compile(r'[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+')
138
+
139
+ power_map = {
140
+ '⁰': '0',
141
+ '¹': '1',
142
+ '²': '2',
143
+ '³': '3',
144
+ '⁴': '4',
145
+ '⁵': '5',
146
+ '⁶': '6',
147
+ '⁷': '7',
148
+ '⁸': '8',
149
+ '⁹': '9',
150
+ 'ˣ': 'x',
151
+ 'ʸ': 'y',
152
+ 'ⁿ': 'n'
153
+ }
154
+
155
+ def replace_power(match) -> str:
156
+ """
157
+ Args:
158
+ match (re.Match)
159
+ Returns:
160
+ str
161
+ """
162
+ power_num = ""
163
+ for m in match.group(0):
164
+ power_num += power_map[m]
165
+ result = "的" + power_num + "次方"
166
+ return result
167
+
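Likewise, a quick sketch for the arithmetic and superscript rules above (illustrative only, not part of the committed file):

print(RE_ASMD.sub(replace_asmd, "1+1"))   # 1加1
print(RE_POWER.sub(replace_power, "x²"))  # x的2次方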
168
+
169
+ # Number expressions
170
+ # Pure decimals
171
+ RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
172
+ # Positive integer + quantifier (measure word)
173
+ RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS)
174
+ RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))')
175
+
176
+
177
+ def replace_positive_quantifier(match) -> str:
178
+ """
179
+ Args:
180
+ match (re.Match)
181
+ Returns:
182
+ str
183
+ """
184
+ number = match.group(1)
185
+ match_2 = match.group(2)
186
+ if match_2 == "+":
187
+ match_2 = "多"
188
+ match_2: str = match_2 if match_2 else ""
189
+ quantifiers: str = match.group(3)
190
+ number: str = num2str(number)
191
+ result = f"{number}{match_2}{quantifiers}"
192
+ return result
193
+
194
+
195
+ def replace_number(match) -> str:
196
+ """
197
+ Args:
198
+ match (re.Match)
199
+ Returns:
200
+ str
201
+ """
202
+ sign = match.group(1)
203
+ number = match.group(2)
204
+ pure_decimal = match.group(5)
205
+ if pure_decimal:
206
+ result = num2str(pure_decimal)
207
+ else:
208
+ sign: str = "负" if sign else ""
209
+ number: str = num2str(number)
210
+ result = f"{sign}{number}"
211
+ return result
212
+
213
+
214
+ # Range expressions
215
+ # the sub-patterns behind match.group(1) and match.group(6) are copied from RE_NUMBER
216
+
217
+ RE_RANGE = re.compile(
218
+ r"""
219
+ (?<![\d\+\-\×÷=]) # 使用反向前瞻以确保数字范围之前没有其他数字和操作符
220
+ ((-?)((\d+)(\.\d+)?)) # 匹配范围起始的负数或正数(整数或小数)
221
+ [-~] # 匹配范围分隔符
222
+ ((-?)((\d+)(\.\d+)?)) # 匹配范围结束的负数或正数(整数或小数)
223
+ (?![\d\+\-\×÷=]) # 使用正向前瞻以确保数字范围之后没有其他数字和操作符
224
+ """, re.VERBOSE)
225
+
226
+
227
+ def replace_range(match) -> str:
228
+ """
229
+ Args:
230
+ match (re.Match)
231
+ Returns:
232
+ str
233
+ """
234
+ first, second = match.group(1), match.group(6)
235
+ first = RE_NUMBER.sub(replace_number, first)
236
+ second = RE_NUMBER.sub(replace_number, second)
237
+ result = f"{first}到{second}"
238
+ return result
239
+
240
+
241
+ # "~" ("to") range expressions with units
242
+ RE_TO_RANGE = re.compile(
243
+ r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)')
244
+
245
+ def replace_to_range(match) -> str:
246
+ """
247
+ Args:
248
+ match (re.Match)
249
+ Returns:
250
+ str
251
+ """
252
+ result = match.group(0).replace('~', '至')
253
+ return result
254
+
255
+
256
+ def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
257
+ stripped = value_string.lstrip('0')
258
+ if len(stripped) == 0:
259
+ return []
260
+ elif len(stripped) == 1:
261
+ if use_zero and len(stripped) < len(value_string):
262
+ return [DIGITS['0'], DIGITS[stripped]]
263
+ else:
264
+ return [DIGITS[stripped]]
265
+ else:
266
+ largest_unit = next(
267
+ power for power in reversed(UNITS.keys()) if power < len(stripped))
268
+ first_part = value_string[:-largest_unit]
269
+ second_part = value_string[-largest_unit:]
270
+ return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(
271
+ second_part)
272
+
273
+
274
+ def verbalize_cardinal(value_string: str) -> str:
275
+ if not value_string:
276
+ return ''
277
+
278
+ # 000 -> '零' , 0 -> '零'
279
+ value_string = value_string.lstrip('0')
280
+ if len(value_string) == 0:
281
+ return DIGITS['0']
282
+
283
+ result_symbols = _get_value(value_string)
284
+ # verbalized number starting with '一十*' is abbreviated as `十*`
285
+ if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[
286
+ '1'] and result_symbols[1] == UNITS[1]:
287
+ result_symbols = result_symbols[1:]
288
+ return ''.join(result_symbols)
289
+
290
+
291
+ def verbalize_digit(value_string: str, alt_one=False) -> str:
292
+ result_symbols = [DIGITS[digit] for digit in value_string]
293
+ result = ''.join(result_symbols)
294
+ if alt_one:
295
+ result = result.replace("一", "幺")
296
+ return result
297
+
298
+
299
+ def num2str(value_string: str) -> str:
300
+ integer_decimal = value_string.split('.')
301
+ if len(integer_decimal) == 1:
302
+ integer = integer_decimal[0]
303
+ decimal = ''
304
+ elif len(integer_decimal) == 2:
305
+ integer, decimal = integer_decimal
306
+ else:
307
+ raise ValueError(
308
+ f"The value string: '${value_string}' has more than one point in it."
309
+ )
310
+
311
+ result = verbalize_cardinal(integer)
312
+
313
+ decimal = decimal.rstrip('0')
314
+ if decimal:
315
+ # '.22' is verbalized as '零点二二'
316
+ # '3.20' is verbalized as '三点二'
317
+ result = result if result else "零"
318
+ result += '点' + verbalize_digit(decimal)
319
+ return result
320
+
321
+
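A few expected readings for num2str / verbalize_digit, written as a small sanity-check sketch (illustrative only, not part of the committed file):

assert num2str("10") == "十"          # a leading 一十 is abbreviated to 十
assert num2str("305") == "三百零五"
assert num2str("10.50") == "十点五"   # trailing zeros of the decimal part are dropped
assert verbalize_digit("110", alt_one=True) == "幺幺零"  # telephone-style reading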
322
+ if __name__ == "__main__":
323
+
324
+ text = ""
325
+ text = num2str(text)
326
+ print(text)
327
+ pass
models/ace_step_transformer.py ADDED
@@ -0,0 +1,475 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional, Tuple, List, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.utils import BaseOutput, is_torch_version
23
+ from diffusers.models.modeling_utils import ModelMixin
24
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
25
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
26
+
27
+
28
+ from .attention import LinearTransformerBlock, t2i_modulate
29
+ from .lyrics_utils.lyric_encoder import ConformerEncoder as LyricEncoder
30
+
31
+
32
+ def cross_norm(hidden_states, controlnet_input):
33
+ # input N x T x c
34
+ mean_hidden_states, std_hidden_states = hidden_states.mean(dim=(1,2), keepdim=True), hidden_states.std(dim=(1,2), keepdim=True)
35
+ mean_controlnet_input, std_controlnet_input = controlnet_input.mean(dim=(1,2), keepdim=True), controlnet_input.std(dim=(1,2), keepdim=True)
36
+ controlnet_input = (controlnet_input - mean_controlnet_input) * (std_hidden_states / (std_controlnet_input + 1e-12)) + mean_hidden_states
37
+ return controlnet_input
38
+
39
+
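A small illustration of what cross_norm does (a sketch, not from the repository): it rescales the controlnet features so that their per-sample mean and std match those of the hidden states.

h = torch.randn(2, 100, 1536) * 3.0 + 1.0   # hidden states, N x T x C
c = torch.randn(2, 100, 1536) * 0.1 - 5.0   # controlnet features on a very different scale
c_aligned = cross_norm(h, c)
print(c_aligned.mean(dim=(1, 2)), h.mean(dim=(1, 2)))  # now approximately equal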
40
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
41
+ class Qwen2RotaryEmbedding(nn.Module):
42
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
43
+ super().__init__()
44
+
45
+ self.dim = dim
46
+ self.max_position_embeddings = max_position_embeddings
47
+ self.base = base
48
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
49
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
50
+
51
+ # Build here to make `torch.jit.trace` work.
52
+ self._set_cos_sin_cache(
53
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
54
+ )
55
+
56
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
57
+ self.max_seq_len_cached = seq_len
58
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
59
+
60
+ freqs = torch.outer(t, self.inv_freq)
61
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
62
+ emb = torch.cat((freqs, freqs), dim=-1)
63
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
64
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
65
+
66
+ def forward(self, x, seq_len=None):
67
+ # x: [bs, num_attention_heads, seq_len, head_size]
68
+ if seq_len > self.max_seq_len_cached:
69
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
70
+
71
+ return (
72
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
73
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
74
+ )
75
+
76
+
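Shape sketch for the rotary embedding above (illustrative only; the smaller max position is an assumption for the sketch, not the model default): the module returns per-position cos/sin tables of size [seq_len, head_dim].

rope = Qwen2RotaryEmbedding(dim=64, max_position_embeddings=4096, base=1000000.0)
x = torch.randn(2, 24, 128, 64)   # [bs, heads, seq_len, head_dim]
cos, sin = rope(x, seq_len=128)
print(cos.shape, sin.shape)       # torch.Size([128, 64]) each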
77
+ class T2IFinalLayer(nn.Module):
78
+ """
79
+ The final output layer (adapted from Sana's T2IFinalLayer).
80
+ """
81
+
82
+ def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256):
83
+ super().__init__()
84
+ self.norm_final = nn.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
85
+ self.linear = nn.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True)
86
+ self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
87
+ self.out_channels = out_channels
88
+ self.patch_size = patch_size
89
+
90
+ def unpatchfy(
91
+ self,
92
+ hidden_states: torch.Tensor,
93
+ width: int,
94
+ ):
95
+ # 4 unpatchify
96
+ new_height, new_width = 1, hidden_states.size(1)
97
+ hidden_states = hidden_states.reshape(
98
+ shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
99
+ ).contiguous()
100
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
101
+ output = hidden_states.reshape(
102
+ shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
103
+ ).contiguous()
104
+ if width > new_width:
105
+ output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
106
+ elif width < new_width:
107
+ output = output[:, :, :, :width]
108
+ return output
109
+
110
+ def forward(self, x, t, output_length):
111
+ shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
112
+ x = t2i_modulate(self.norm_final(x), shift, scale)
113
+ x = self.linear(x)
114
+ # unpatchify
115
+ output = self.unpatchfy(x, output_length)
116
+ return output
117
+
118
+
119
+ class PatchEmbed(nn.Module):
120
+ """2D Image to Patch Embedding"""
121
+
122
+ def __init__(
123
+ self,
124
+ height=16,
125
+ width=4096,
126
+ patch_size=(16, 1),
127
+ in_channels=8,
128
+ embed_dim=1152,
129
+ bias=True,
130
+ ):
131
+ super().__init__()
132
+ patch_size_h, patch_size_w = patch_size
133
+ self.early_conv_layers = nn.Sequential(
134
+ nn.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias),
135
+ torch.nn.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True),
136
+ nn.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias)
137
+ )
138
+ self.patch_size = patch_size
139
+ self.height, self.width = height // patch_size_h, width // patch_size_w
140
+ self.base_size = self.width
141
+
142
+ def forward(self, latent):
143
+ # early convolutions: N x C x H x W -> N x C*256 x H/patch_h x W/patch_w -> N x embed_dim x H/patch_h x W/patch_w
144
+ latent = self.early_conv_layers(latent)
145
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
146
+ return latent
147
+
148
+
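Shape sketch for PatchEmbed (illustrative only, not from the repository): with patch_size=(16, 1) the 16 latent "height" bins collapse into one token per frame.

patcher = PatchEmbed(height=16, width=4096, patch_size=(16, 1), in_channels=8, embed_dim=1536)
latents = torch.randn(2, 8, 16, 1000)   # N x C x H x W
tokens = patcher(latents)
print(tokens.shape)                     # torch.Size([2, 1000, 1536])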
149
+ @dataclass
150
+ class Transformer2DModelOutput(BaseOutput):
151
+
152
+ sample: torch.FloatTensor
153
+ proj_losses: Optional[Tuple[Tuple[str, torch.Tensor]]] = None
154
+
155
+
156
+ class ACEStepTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
157
+ _supports_gradient_checkpointing = True
158
+
159
+ @register_to_config
160
+ def __init__(
161
+ self,
162
+ in_channels: Optional[int] = 8,
163
+ num_layers: int = 28,
164
+ inner_dim: int = 1536,
165
+ attention_head_dim: int = 64,
166
+ num_attention_heads: int = 24,
167
+ mlp_ratio: float = 4.0,
168
+ out_channels: int = 8,
169
+ max_position: int = 32768,
170
+ rope_theta: float = 1000000.0,
171
+ speaker_embedding_dim: int = 512,
172
+ text_embedding_dim: int = 768,
173
+ ssl_encoder_depths: List[int] = [9, 9],
174
+ ssl_names: List[str] = ["mert", "m-hubert"],
175
+ ssl_latent_dims: List[int] = [1024, 768],
176
+ lyric_encoder_vocab_size: int = 6681,
177
+ lyric_hidden_size: int = 1024,
178
+ patch_size: List[int] = [16, 1],
179
+ max_height: int = 16,
180
+ max_width: int = 4096,
181
+ **kwargs,
182
+ ):
183
+ super().__init__()
184
+
185
+ self.num_attention_heads = num_attention_heads
186
+ self.attention_head_dim = attention_head_dim
187
+ inner_dim = num_attention_heads * attention_head_dim
188
+ self.inner_dim = inner_dim
189
+ self.out_channels = out_channels
190
+ self.max_position = max_position
191
+ self.patch_size = patch_size
192
+
193
+ self.rope_theta = rope_theta
194
+
195
+ self.rotary_emb = Qwen2RotaryEmbedding(
196
+ dim=self.attention_head_dim,
197
+ max_position_embeddings=self.max_position,
198
+ base=self.rope_theta,
199
+ )
200
+
201
+ # 2. Define input layers
202
+ self.in_channels = in_channels
203
+
204
+ # 3. Define transformers blocks
205
+ self.transformer_blocks = nn.ModuleList(
206
+ [
207
+ LinearTransformerBlock(
208
+ dim=self.inner_dim,
209
+ num_attention_heads=self.num_attention_heads,
210
+ attention_head_dim=attention_head_dim,
211
+ mlp_ratio=mlp_ratio,
212
+ add_cross_attention=True,
213
+ add_cross_attention_dim=self.inner_dim,
214
+ )
215
+ for i in range(self.config.num_layers)
216
+ ]
217
+ )
218
+ self.num_layers = num_layers
219
+
220
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
221
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
222
+ self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(self.inner_dim, 6 * self.inner_dim, bias=True))
223
+
224
+ # speaker
225
+ self.speaker_embedder = nn.Linear(speaker_embedding_dim, self.inner_dim)
226
+
227
+ # genre
228
+ self.genre_embedder = nn.Linear(text_embedding_dim, self.inner_dim)
229
+
230
+ # lyric
231
+ self.lyric_embs = nn.Embedding(lyric_encoder_vocab_size, lyric_hidden_size)
232
+ self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0)
233
+ self.lyric_proj = nn.Linear(lyric_hidden_size, self.inner_dim)
234
+
235
+ projector_dim = 2 * self.inner_dim
236
+
237
+ self.projectors = nn.ModuleList([
238
+ nn.Sequential(
239
+ nn.Linear(self.inner_dim, projector_dim),
240
+ nn.SiLU(),
241
+ nn.Linear(projector_dim, projector_dim),
242
+ nn.SiLU(),
243
+ nn.Linear(projector_dim, ssl_dim),
244
+ ) for ssl_dim in ssl_latent_dims
245
+ ])
246
+
247
+ self.ssl_latent_dims = ssl_latent_dims
248
+ self.ssl_encoder_depths = ssl_encoder_depths
249
+
250
+ self.cosine_loss = torch.nn.CosineEmbeddingLoss(margin=0.0, reduction='mean')
251
+ self.ssl_names = ssl_names
252
+
253
+ self.proj_in = PatchEmbed(
254
+ height=max_height,
255
+ width=max_width,
256
+ patch_size=patch_size,
257
+ embed_dim=self.inner_dim,
258
+ bias=True,
259
+ )
260
+
261
+ self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels)
262
+ self.gradient_checkpointing = False
263
+
264
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
265
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
266
+ """
267
+ Sets the attention processor to use [feed forward
268
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
269
+
270
+ Parameters:
271
+ chunk_size (`int`, *optional*):
272
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
273
+ over each tensor of dim=`dim`.
274
+ dim (`int`, *optional*, defaults to `0`):
275
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
276
+ or dim=1 (sequence length).
277
+ """
278
+ if dim not in [0, 1]:
279
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
280
+
281
+ # By default chunk size is 1
282
+ chunk_size = chunk_size or 1
283
+
284
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
285
+ if hasattr(module, "set_chunk_feed_forward"):
286
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
287
+
288
+ for child in module.children():
289
+ fn_recursive_feed_forward(child, chunk_size, dim)
290
+
291
+ for module in self.children():
292
+ fn_recursive_feed_forward(module, chunk_size, dim)
293
+
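A minimal sketch of how the chunking toggle above would be used, assuming an already-constructed model instance named `model` (the instance itself is assumed, not shown here):

# model = ACEStepTransformer2DModel(...)   # assumed to exist with a valid config
model.enable_forward_chunking(chunk_size=4, dim=1)  # chunk feed-forward over the sequence dim
model.enable_forward_chunking()                     # defaults to chunk_size=1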
294
+ def _set_gradient_checkpointing(self, module, value=False):
295
+ if hasattr(module, "gradient_checkpointing"):
296
+ module.gradient_checkpointing = value
297
+
298
+ def forward_lyric_encoder(
299
+ self,
300
+ lyric_token_idx: Optional[torch.LongTensor] = None,
301
+ lyric_mask: Optional[torch.LongTensor] = None,
302
+ ):
303
+ # N x T x D
304
+ lyric_embs = self.lyric_embs(lyric_token_idx)
305
+ prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
306
+ prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
307
+ return prompt_prenet_out
308
+
309
+ def encode(
310
+ self,
311
+ encoder_text_hidden_states: Optional[torch.Tensor] = None,
312
+ text_attention_mask: Optional[torch.LongTensor] = None,
313
+ speaker_embeds: Optional[torch.FloatTensor] = None,
314
+ lyric_token_idx: Optional[torch.LongTensor] = None,
315
+ lyric_mask: Optional[torch.LongTensor] = None,
316
+ ):
317
+
318
+ bs = encoder_text_hidden_states.shape[0]
319
+ device = encoder_text_hidden_states.device
320
+
321
+ # speaker embedding
322
+ encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
323
+ speaker_mask = torch.ones(bs, 1, device=device)
324
+
325
+ # genre embedding
326
+ encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)
327
+
328
+ # lyric
329
+ encoder_lyric_hidden_states = self.forward_lyric_encoder(
330
+ lyric_token_idx=lyric_token_idx,
331
+ lyric_mask=lyric_mask,
332
+ )
333
+
334
+ encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
335
+ encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
336
+ return encoder_hidden_states, encoder_hidden_mask
337
+
338
+ def decode(
339
+ self,
340
+ hidden_states: torch.Tensor,
341
+ attention_mask: torch.Tensor,
342
+ encoder_hidden_states: torch.Tensor,
343
+ encoder_hidden_mask: torch.Tensor,
344
+ timestep: Optional[torch.Tensor],
345
+ ssl_hidden_states: Optional[List[torch.Tensor]] = None,
346
+ output_length: int = 0,
347
+ block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
348
+ controlnet_scale: Union[float, torch.Tensor] = 1.0,
349
+ return_dict: bool = True,
350
+ ):
351
+
352
+ embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
353
+ temb = self.t_block(embedded_timestep)
354
+
355
+ hidden_states = self.proj_in(hidden_states)
356
+
357
+ # controlnet logic
358
+ if block_controlnet_hidden_states is not None:
359
+ control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
360
+ hidden_states = hidden_states + control_condi * controlnet_scale
361
+
362
+ inner_hidden_states = []
363
+
364
+ rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
365
+ encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])
366
+
367
+ for index_block, block in enumerate(self.transformer_blocks):
368
+
369
+ if self.training and self.gradient_checkpointing:
370
+
371
+ def create_custom_forward(module, return_dict=None):
372
+ def custom_forward(*inputs):
373
+ if return_dict is not None:
374
+ return module(*inputs, return_dict=return_dict)
375
+ else:
376
+ return module(*inputs)
377
+
378
+ return custom_forward
379
+
380
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
381
+ hidden_states = torch.utils.checkpoint.checkpoint(
382
+ create_custom_forward(block),
383
+ hidden_states=hidden_states,
384
+ attention_mask=attention_mask,
385
+ encoder_hidden_states=encoder_hidden_states,
386
+ encoder_attention_mask=encoder_hidden_mask,
387
+ rotary_freqs_cis=rotary_freqs_cis,
388
+ rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
389
+ temb=temb,
390
+ **ckpt_kwargs,
391
+ )
392
+
393
+ else:
394
+ hidden_states = block(
395
+ hidden_states=hidden_states,
396
+ attention_mask=attention_mask,
397
+ encoder_hidden_states=encoder_hidden_states,
398
+ encoder_attention_mask=encoder_hidden_mask,
399
+ rotary_freqs_cis=rotary_freqs_cis,
400
+ rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
401
+ temb=temb,
402
+ )
403
+
404
+ for ssl_encoder_depth in self.ssl_encoder_depths:
405
+ if index_block == ssl_encoder_depth:
406
+ inner_hidden_states.append(hidden_states)
407
+
408
+ proj_losses = []
409
+ if len(inner_hidden_states) > 0 and ssl_hidden_states is not None and len(ssl_hidden_states) > 0:
410
+
411
+ for inner_hidden_state, projector, ssl_hidden_state, ssl_name in zip(inner_hidden_states, self.projectors, ssl_hidden_states, self.ssl_names):
412
+ if ssl_hidden_state is None:
413
+ continue
414
+ # 1. project inner hidden states: N x T x D_model -> N x T x D_ssl
415
+ est_ssl_hidden_state = projector(inner_hidden_state)
416
+ # 3. projection loss
417
+ bs = inner_hidden_state.shape[0]
418
+ proj_loss = 0.0
419
+ for i, (z, z_tilde) in enumerate(zip(ssl_hidden_state, est_ssl_hidden_state)):
420
+ # 2. interpolate
421
+ z_tilde = F.interpolate(z_tilde.unsqueeze(0).transpose(1, 2), size=len(z), mode='linear', align_corners=False).transpose(1, 2).squeeze(0)
422
+
423
+ z_tilde = torch.nn.functional.normalize(z_tilde, dim=-1)
424
+ z = torch.nn.functional.normalize(z, dim=-1)
425
+ # T x d -> T x 1 -> 1
426
+ target = torch.ones(z.shape[0], device=z.device)
427
+ proj_loss += self.cosine_loss(z, z_tilde, target)
428
+ proj_losses.append((ssl_name, proj_loss / bs))
429
+
430
+ output = self.final_layer(hidden_states, embedded_timestep, output_length)
431
+ if not return_dict:
432
+ return (output, proj_losses)
433
+
434
+ return Transformer2DModelOutput(sample=output, proj_losses=proj_losses)
435
+
436
+ # @torch.compile
437
+ def forward(
438
+ self,
439
+ hidden_states: torch.Tensor,
440
+ attention_mask: torch.Tensor,
441
+ encoder_text_hidden_states: Optional[torch.Tensor] = None,
442
+ text_attention_mask: Optional[torch.LongTensor] = None,
443
+ speaker_embeds: Optional[torch.FloatTensor] = None,
444
+ lyric_token_idx: Optional[torch.LongTensor] = None,
445
+ lyric_mask: Optional[torch.LongTensor] = None,
446
+ timestep: Optional[torch.Tensor] = None,
447
+ ssl_hidden_states: Optional[List[torch.Tensor]] = None,
448
+ block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
449
+ controlnet_scale: Union[float, torch.Tensor] = 1.0,
450
+ return_dict: bool = True,
451
+ ):
452
+ encoder_hidden_states, encoder_hidden_mask = self.encode(
453
+ encoder_text_hidden_states=encoder_text_hidden_states,
454
+ text_attention_mask=text_attention_mask,
455
+ speaker_embeds=speaker_embeds,
456
+ lyric_token_idx=lyric_token_idx,
457
+ lyric_mask=lyric_mask,
458
+ )
459
+
460
+ output_length = hidden_states.shape[-1]
461
+
462
+ output = self.decode(
463
+ hidden_states=hidden_states,
464
+ attention_mask=attention_mask,
465
+ encoder_hidden_states=encoder_hidden_states,
466
+ encoder_hidden_mask=encoder_hidden_mask,
467
+ timestep=timestep,
468
+ ssl_hidden_states=ssl_hidden_states,
469
+ output_length=output_length,
470
+ block_controlnet_hidden_states=block_controlnet_hidden_states,
471
+ controlnet_scale=controlnet_scale,
472
+ return_dict=return_dict,
473
+ )
474
+
475
+ return output
models/attention.py ADDED
@@ -0,0 +1,319 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Tuple, Union
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import logging
21
+ from diffusers.models.normalization import RMSNorm
22
+
23
+
24
+ try:
25
+ # from .dcformer import DCMHAttention
26
+ from .customer_attention_processor import Attention, CustomLiteLAProcessor2_0, CustomerAttnProcessor2_0
27
+ except ImportError:
28
+ # from dcformer import DCMHAttention
29
+ from customer_attention_processor import Attention, CustomLiteLAProcessor2_0, CustomerAttnProcessor2_0
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore
36
+ """Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
37
+ if isinstance(x, (list, tuple)):
38
+ return list(x)
39
+ return [x for _ in range(repeat_time)]
40
+
41
+
42
+ def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore
43
+ """Return tuple with min_len by repeating element at idx_repeat."""
44
+ # convert to list first
45
+ x = val2list(x)
46
+
47
+ # repeat elements if necessary
48
+ if len(x) > 0:
49
+ x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
50
+
51
+ return tuple(x)
52
+
53
+
54
+ def t2i_modulate(x, shift, scale):
55
+ return x * (1 + scale) + shift
56
+
57
+
58
+ def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
59
+ if isinstance(kernel_size, tuple):
60
+ return tuple([get_same_padding(ks) for ks in kernel_size])
61
+ else:
62
+ assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be an odd number"
63
+ return kernel_size // 2
64
+
65
+ class ConvLayer(nn.Module):
66
+ def __init__(
67
+ self,
68
+ in_dim: int,
69
+ out_dim: int,
70
+ kernel_size=3,
71
+ stride=1,
72
+ dilation=1,
73
+ groups=1,
74
+ padding: Union[int, None] = None,
75
+ use_bias=False,
76
+ norm=None,
77
+ act=None,
78
+ ):
79
+ super().__init__()
80
+ if padding is None:
81
+ padding = get_same_padding(kernel_size)
82
+ padding *= dilation
83
+
84
+ self.in_dim = in_dim
85
+ self.out_dim = out_dim
86
+ self.kernel_size = kernel_size
87
+ self.stride = stride
88
+ self.dilation = dilation
89
+ self.groups = groups
90
+ self.padding = padding
91
+ self.use_bias = use_bias
92
+
93
+ self.conv = nn.Conv1d(
94
+ in_dim,
95
+ out_dim,
96
+ kernel_size=kernel_size,
97
+ stride=stride,
98
+ padding=padding,
99
+ dilation=dilation,
100
+ groups=groups,
101
+ bias=use_bias,
102
+ )
103
+ if norm is not None:
104
+ self.norm = RMSNorm(out_dim, elementwise_affine=False)
105
+ else:
106
+ self.norm = None
107
+ if act is not None:
108
+ self.act = nn.SiLU(inplace=True)
109
+ else:
110
+ self.act = None
111
+
112
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
113
+ x = self.conv(x)
114
+ if self.norm:
115
+ x = self.norm(x)
116
+ if self.act:
117
+ x = self.act(x)
118
+ return x
119
+
120
+
121
+ class GLUMBConv(nn.Module):
122
+ def __init__(
123
+ self,
124
+ in_features: int,
125
+ hidden_features: int,
126
+ out_feature=None,
127
+ kernel_size=3,
128
+ stride=1,
129
+ padding: Union[int, None] = None,
130
+ use_bias=False,
131
+ norm=(None, None, None),
132
+ act=("silu", "silu", None),
133
+ dilation=1,
134
+ ):
135
+ out_feature = out_feature or in_features
136
+ super().__init__()
137
+ use_bias = val2tuple(use_bias, 3)
138
+ norm = val2tuple(norm, 3)
139
+ act = val2tuple(act, 3)
140
+
141
+ self.glu_act = nn.SiLU(inplace=False)
142
+ self.inverted_conv = ConvLayer(
143
+ in_features,
144
+ hidden_features * 2,
145
+ 1,
146
+ use_bias=use_bias[0],
147
+ norm=norm[0],
148
+ act=act[0],
149
+ )
150
+ self.depth_conv = ConvLayer(
151
+ hidden_features * 2,
152
+ hidden_features * 2,
153
+ kernel_size,
154
+ stride=stride,
155
+ groups=hidden_features * 2,
156
+ padding=padding,
157
+ use_bias=use_bias[1],
158
+ norm=norm[1],
159
+ act=None,
160
+ dilation=dilation,
161
+ )
162
+ self.point_conv = ConvLayer(
163
+ hidden_features,
164
+ out_feature,
165
+ 1,
166
+ use_bias=use_bias[2],
167
+ norm=norm[2],
168
+ act=act[2],
169
+ )
170
+
171
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
172
+ x = x.transpose(1, 2)
173
+ x = self.inverted_conv(x)
174
+ x = self.depth_conv(x)
175
+
176
+ x, gate = torch.chunk(x, 2, dim=1)
177
+ gate = self.glu_act(gate)
178
+ x = x * gate
179
+
180
+ x = self.point_conv(x)
181
+ x = x.transpose(1, 2)
182
+
183
+ return x
184
+
185
+
186
+ class LinearTransformerBlock(nn.Module):
187
+ """
188
+ A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.
189
+ """
190
+ def __init__(
191
+ self,
192
+ dim,
193
+ num_attention_heads,
194
+ attention_head_dim,
195
+ use_adaln_single=True,
196
+ cross_attention_dim=None,
197
+ added_kv_proj_dim=None,
198
+ context_pre_only=False,
199
+ mlp_ratio=4.0,
200
+ add_cross_attention=False,
201
+ add_cross_attention_dim=None,
202
+ qk_norm=None,
203
+ ):
204
+ super().__init__()
205
+
206
+ self.norm1 = RMSNorm(dim, elementwise_affine=False, eps=1e-6)
207
+ self.attn = Attention(
208
+ query_dim=dim,
209
+ cross_attention_dim=cross_attention_dim,
210
+ added_kv_proj_dim=added_kv_proj_dim,
211
+ dim_head=attention_head_dim,
212
+ heads=num_attention_heads,
213
+ out_dim=dim,
214
+ bias=True,
215
+ qk_norm=qk_norm,
216
+ processor=CustomLiteLAProcessor2_0(),
217
+ )
218
+
219
+ self.add_cross_attention = add_cross_attention
220
+ self.context_pre_only = context_pre_only
221
+
222
+ if add_cross_attention and add_cross_attention_dim is not None:
223
+ self.cross_attn = Attention(
224
+ query_dim=dim,
225
+ cross_attention_dim=add_cross_attention_dim,
226
+ added_kv_proj_dim=add_cross_attention_dim,
227
+ dim_head=attention_head_dim,
228
+ heads=num_attention_heads,
229
+ out_dim=dim,
230
+ context_pre_only=context_pre_only,
231
+ bias=True,
232
+ qk_norm=qk_norm,
233
+ processor=CustomerAttnProcessor2_0(),
234
+ )
235
+
236
+ self.norm2 = RMSNorm(dim, 1e-06, elementwise_affine=False)
237
+
238
+ self.ff = GLUMBConv(
239
+ in_features=dim,
240
+ hidden_features=int(dim * mlp_ratio),
241
+ use_bias=(True, True, False),
242
+ norm=(None, None, None),
243
+ act=("silu", "silu", None),
244
+ )
245
+ self.use_adaln_single = use_adaln_single
246
+ if use_adaln_single:
247
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
248
+
249
+ def forward(
250
+ self,
251
+ hidden_states: torch.FloatTensor,
252
+ encoder_hidden_states: torch.FloatTensor = None,
253
+ attention_mask: torch.FloatTensor = None,
254
+ encoder_attention_mask: torch.FloatTensor = None,
255
+ rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
256
+ rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
257
+ temb: torch.FloatTensor = None,
258
+ ):
259
+
260
+ N = hidden_states.shape[0]
261
+
262
+ # step 1: AdaLN single
263
+ if self.use_adaln_single:
264
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
265
+ self.scale_shift_table[None] + temb.reshape(N, 6, -1)
266
+ ).chunk(6, dim=1)
267
+
268
+ norm_hidden_states = self.norm1(hidden_states)
269
+ if self.use_adaln_single:
270
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
271
+
272
+ # step 2: attention
273
+ if not self.add_cross_attention:
274
+ attn_output, encoder_hidden_states = self.attn(
275
+ hidden_states=norm_hidden_states,
276
+ attention_mask=attention_mask,
277
+ encoder_hidden_states=encoder_hidden_states,
278
+ encoder_attention_mask=encoder_attention_mask,
279
+ rotary_freqs_cis=rotary_freqs_cis,
280
+ rotary_freqs_cis_cross=rotary_freqs_cis_cross,
281
+ )
282
+ else:
283
+ attn_output, _ = self.attn(
284
+ hidden_states=norm_hidden_states,
285
+ attention_mask=attention_mask,
286
+ encoder_hidden_states=None,
287
+ encoder_attention_mask=None,
288
+ rotary_freqs_cis=rotary_freqs_cis,
289
+ rotary_freqs_cis_cross=None,
290
+ )
291
+
292
+ if self.use_adaln_single:
293
+ attn_output = gate_msa * attn_output
294
+ hidden_states = attn_output + hidden_states
295
+
296
+ if self.add_cross_attention:
297
+ attn_output = self.cross_attn(
298
+ hidden_states=hidden_states,
299
+ attention_mask=attention_mask,
300
+ encoder_hidden_states=encoder_hidden_states,
301
+ encoder_attention_mask=encoder_attention_mask,
302
+ rotary_freqs_cis=rotary_freqs_cis,
303
+ rotary_freqs_cis_cross=rotary_freqs_cis_cross,
304
+ )
305
+ hidden_states = attn_output + hidden_states
306
+
307
+ # step 3: add norm
308
+ norm_hidden_states = self.norm2(hidden_states)
309
+ if self.use_adaln_single:
310
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
311
+
312
+ # step 4: feed forward
313
+ ff_output = self.ff(norm_hidden_states)
314
+ if self.use_adaln_single:
315
+ ff_output = gate_mlp * ff_output
316
+
317
+ hidden_states = hidden_states + ff_output
318
+
319
+ return hidden_states
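A quick shape check for the GLUMBConv feed-forward used by LinearTransformerBlock. This is only a hedged smoke test: the bare import assumes the script is run from the models/ directory so the fallback imports in this file resolve, and the dimensions are arbitrary.

import torch
from attention import GLUMBConv  # assumed import path; adjust to how the package is actually laid out

ff = GLUMBConv(in_features=64, hidden_features=160)  # e.g. dim 64 with mlp_ratio 2.5
x = torch.randn(2, 100, 64)                          # (batch, sequence, features)
y = ff(x)
print(y.shape)                                       # torch.Size([2, 100, 64]): input and output widths match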
models/config.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "_class_name": "Transformer2DModel",
3
+ "_diffusers_version": "0.27.2",
4
+ "in_channels": 8,
5
+ "num_layers": 24,
6
+ "inner_dim": 2560,
7
+ "attention_head_dim": 128,
8
+ "num_attention_heads": 20,
9
+ "mlp_ratio": 2.5,
10
+ "out_channels": 8,
11
+ "max_position": 32768,
12
+ "rope_theta": 1000000.0,
13
+ "speaker_embedding_dim": 512,
14
+ "text_embedding_dim": 768,
15
+ "ssl_encoder_depths": [8, 8],
16
+ "ssl_names": ["mert", "m-hubert"],
17
+ "ssl_latent_dims": [1024, 768],
18
+ "patch_size": [16, 1],
19
+ "max_height": 16,
20
+ "max_width": 32768,
21
+ "lyric_encoder_vocab_size": 6693,
22
+ "lyric_hidden_size": 1024
23
+ }
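A small sanity check of the values above, assuming the file sits at models/config.json relative to the working directory: the attention width factors into heads times head dimension, and the SSL lists stay aligned.

import json

with open("models/config.json") as f:
    cfg = json.load(f)

assert cfg["inner_dim"] == cfg["num_attention_heads"] * cfg["attention_head_dim"]  # 2560 == 20 * 128
assert len(cfg["ssl_names"]) == len(cfg["ssl_latent_dims"]) == len(cfg["ssl_encoder_depths"])
print(cfg["num_layers"], "layers, inner_dim", cfg["inner_dim"])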
models/customer_attention_processor.py ADDED
@@ -0,0 +1,339 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional, Union, Tuple
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import logging
21
+ from diffusers.models.attention_processor import Attention
22
+
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+
26
+ class CustomLiteLAProcessor2_0:
27
+ """Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""
28
+
29
+ def __init__(self):
30
+ self.kernel_func = nn.ReLU(inplace=False)
31
+ self.eps = 1e-15
32
+ self.pad_val = 1.0
33
+
34
+ def apply_rotary_emb(
35
+ self,
36
+ x: torch.Tensor,
37
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
38
+ ) -> torch.Tensor:
39
+ """
40
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
41
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
42
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
43
+ tensors contain rotary embeddings and are returned as real tensors.
44
+
45
+ Args:
46
+ x (`torch.Tensor`):
47
+ Query or key tensor to apply rotary embeddings to, of shape [B, H, S, D].
48
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
49
+
50
+ Returns:
51
+ torch.Tensor: The input tensor with rotary embeddings applied.
52
+ """
53
+ cos, sin = freqs_cis # [S, D]
54
+ cos = cos[None, None]
55
+ sin = sin[None, None]
56
+ cos, sin = cos.to(x.device), sin.to(x.device)
57
+
58
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
59
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
60
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
61
+
62
+ return out
63
+
64
+ def __call__(
65
+ self,
66
+ attn: Attention,
67
+ hidden_states: torch.FloatTensor,
68
+ encoder_hidden_states: torch.FloatTensor = None,
69
+ attention_mask: Optional[torch.FloatTensor] = None,
70
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
71
+ rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
72
+ rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
73
+ *args,
74
+ **kwargs,
75
+ ) -> torch.FloatTensor:
76
+ hidden_states_len = hidden_states.shape[1]
77
+
78
+ input_ndim = hidden_states.ndim
79
+ if input_ndim == 4:
80
+ batch_size, channel, height, width = hidden_states.shape
81
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
82
+ if encoder_hidden_states is not None:
83
+ context_input_ndim = encoder_hidden_states.ndim
84
+ if context_input_ndim == 4:
85
+ batch_size, channel, height, width = encoder_hidden_states.shape
86
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
87
+
88
+ batch_size = hidden_states.shape[0]
89
+
90
+ # `sample` projections.
91
+ dtype = hidden_states.dtype
92
+ query = attn.to_q(hidden_states)
93
+ key = attn.to_k(hidden_states)
94
+ value = attn.to_v(hidden_states)
95
+
96
+ # `context` projections.
97
+ has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
98
+ if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
99
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
100
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
101
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
102
+
103
+ # attention
104
+ if not attn.is_cross_attention:
105
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
106
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
107
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
108
+ else:
109
+ query = hidden_states
110
+ key = encoder_hidden_states
111
+ value = encoder_hidden_states
112
+
113
+ inner_dim = key.shape[-1]
114
+ head_dim = inner_dim // attn.heads
115
+
116
+ query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
117
+ key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
118
+ value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
119
+
120
+ # RoPE expects input shaped [B, H, S, D]
121
+ # query is currently [B, H, D, S]; convert it to [B, H, S, D] before applying RoPE
122
+ query = query.permute(0, 1, 3, 2) # [B, H, S, D] (from [B, H, D, S])
123
+
124
+ # Apply query and key normalization if needed
125
+ if attn.norm_q is not None:
126
+ query = attn.norm_q(query)
127
+ if attn.norm_k is not None:
128
+ key = attn.norm_k(key)
129
+
130
+ # Apply RoPE if needed
131
+ if rotary_freqs_cis is not None:
132
+ query = self.apply_rotary_emb(query, rotary_freqs_cis)
133
+ if not attn.is_cross_attention:
134
+ key = self.apply_rotary_emb(key, rotary_freqs_cis)
135
+ elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
136
+ key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
137
+
138
+ # query is now [B, H, S, D]; restore it to [B, H, D, S]
139
+ query = query.permute(0, 1, 3, 2) # [B, H, D, S]
140
+
141
+ if attention_mask is not None:
142
+ # attention_mask: [B, S] -> [B, 1, S, 1]
143
+ attention_mask = attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S, 1]
144
+ query = query * attention_mask.permute(0, 1, 3, 2) # [B, H, D, S] * [B, 1, 1, S]
145
+ if not attn.is_cross_attention:
146
+ key = key * attention_mask # key: [B, h, S, D] multiplied by mask [B, 1, S, 1]
147
+ value = value * attention_mask.permute(0, 1, 3, 2) # value is [B, h, D, S], so the mask is permuted to match the S dimension
148
+
149
+ if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
150
+ encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S_enc, 1]
151
+ # here key: [B, h, S_enc, D], value: [B, h, D, S_enc]
152
+ key = key * encoder_attention_mask # [B, h, S_enc, D] * [B, 1, S_enc, 1]
153
+ value = value * encoder_attention_mask.permute(0, 1, 3, 2) # [B, h, D, S_enc] * [B, 1, 1, S_enc]
154
+
155
+ query = self.kernel_func(query)
156
+ key = self.kernel_func(key)
157
+
158
+ query, key, value = query.float(), key.float(), value.float()
159
+
160
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
161
+
162
+ vk = torch.matmul(value, key)
163
+
164
+ hidden_states = torch.matmul(vk, query)
165
+
166
+ if hidden_states.dtype in [torch.float16, torch.bfloat16]:
167
+ hidden_states = hidden_states.float()
168
+
169
+ hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
170
+
171
+ hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
172
+
173
+ hidden_states = hidden_states.to(dtype)
174
+ if encoder_hidden_states is not None:
175
+ encoder_hidden_states = encoder_hidden_states.to(dtype)
176
+
177
+ # Split the attention outputs.
178
+ if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
179
+ hidden_states, encoder_hidden_states = (
180
+ hidden_states[:, : hidden_states_len],
181
+ hidden_states[:, hidden_states_len:],
182
+ )
183
+
184
+ # linear proj
185
+ hidden_states = attn.to_out[0](hidden_states)
186
+ # dropout
187
+ hidden_states = attn.to_out[1](hidden_states)
188
+ if encoder_hidden_states is not None and not attn.context_pre_only and not attn.is_cross_attention and hasattr(attn, "to_add_out"):
189
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
190
+
191
+ if input_ndim == 4:
192
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
193
+ if encoder_hidden_states is not None and context_input_ndim == 4:
194
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
195
+
196
+ if torch.get_autocast_gpu_dtype() == torch.float16:
197
+ hidden_states = hidden_states.clip(-65504, 65504)
198
+ if encoder_hidden_states is not None:
199
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
200
+
201
+ return hidden_states, encoder_hidden_states
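The normalization trick in CustomLiteLAProcessor2_0 can be hard to see through the reshapes: a row of ones is appended to the value matrix so that one matmul chain yields both the un-normalized output and the per-position normalizer. A minimal, purely illustrative sketch on random tensors:

import torch
import torch.nn.functional as F

B, H, D, S = 1, 2, 4, 5
q = torch.relu(torch.randn(B, H, D, S))          # kernelized query,  [B, H, D, S]
k = torch.relu(torch.randn(B, H, S, D))          # kernelized key,    [B, H, S, D]
v = torch.randn(B, H, D, S)                      # value,             [B, H, D, S]

v_pad = F.pad(v, (0, 0, 0, 1), value=1.0)        # append a ones row along D -> [B, H, D+1, S]
out = (v_pad @ k) @ q                            # [B, H, D+1, S]; the extra row accumulates the kernel weights
out = out[:, :, :-1] / (out[:, :, -1:] + 1e-15)  # divide by that row to normalize -> [B, H, D, S]
print(out.shape)                                 # torch.Size([1, 2, 4, 5])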
202
+
203
+
204
+ class CustomerAttnProcessor2_0:
205
+ r"""
206
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
207
+ """
208
+
209
+ def __init__(self):
210
+ if not hasattr(F, "scaled_dot_product_attention"):
211
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
212
+
213
+ def apply_rotary_emb(
214
+ self,
215
+ x: torch.Tensor,
216
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
217
+ ) -> torch.Tensor:
218
+ """
219
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
220
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
221
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
222
+ tensors contain rotary embeddings and are returned as real tensors.
223
+
224
+ Args:
225
+ x (`torch.Tensor`):
226
+ Query or key tensor to apply rotary embeddings to, of shape [B, H, S, D].
227
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
228
+
229
+ Returns:
230
+ torch.Tensor: The input tensor with rotary embeddings applied.
231
+ """
232
+ cos, sin = freqs_cis # [S, D]
233
+ cos = cos[None, None]
234
+ sin = sin[None, None]
235
+ cos, sin = cos.to(x.device), sin.to(x.device)
236
+
237
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
238
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
239
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
240
+
241
+ return out
242
+
243
+ def __call__(
244
+ self,
245
+ attn: Attention,
246
+ hidden_states: torch.FloatTensor,
247
+ encoder_hidden_states: torch.FloatTensor = None,
248
+ attention_mask: Optional[torch.FloatTensor] = None,
249
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
250
+ rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
251
+ rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
252
+ *args,
253
+ **kwargs,
254
+ ) -> torch.Tensor:
255
+
256
+ residual = hidden_states
257
+ input_ndim = hidden_states.ndim
258
+
259
+ if input_ndim == 4:
260
+ batch_size, channel, height, width = hidden_states.shape
261
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
262
+
263
+ batch_size, sequence_length, _ = (
264
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
265
+ )
266
+
267
+ has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
268
+
269
+ if attn.group_norm is not None:
270
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
271
+
272
+ query = attn.to_q(hidden_states)
273
+
274
+ if encoder_hidden_states is None:
275
+ encoder_hidden_states = hidden_states
276
+ elif attn.norm_cross:
277
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
278
+
279
+ key = attn.to_k(encoder_hidden_states)
280
+ value = attn.to_v(encoder_hidden_states)
281
+
282
+ inner_dim = key.shape[-1]
283
+ head_dim = inner_dim // attn.heads
284
+
285
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
286
+
287
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
288
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
289
+
290
+ if attn.norm_q is not None:
291
+ query = attn.norm_q(query)
292
+ if attn.norm_k is not None:
293
+ key = attn.norm_k(key)
294
+
295
+ # Apply RoPE if needed
296
+ if rotary_freqs_cis is not None:
297
+ query = self.apply_rotary_emb(query, rotary_freqs_cis)
298
+ if not attn.is_cross_attention:
299
+ key = self.apply_rotary_emb(key, rotary_freqs_cis)
300
+ elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
301
+ key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
302
+
303
+ if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
304
+ # attention_mask: N x S1
305
+ # encoder_attention_mask: N x S2
306
+ # cross attention: combine attention_mask and encoder_attention_mask
307
+ combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
308
+ attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
309
+ attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)
310
+
311
+ elif not attn.is_cross_attention and attention_mask is not None:
312
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
313
+ # scaled_dot_product_attention expects attention_mask shape to be
314
+ # (batch, heads, source_length, target_length)
315
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
316
+
317
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
318
+ # TODO: add support for attn.scale when we move to Torch 2.1
319
+ hidden_states = F.scaled_dot_product_attention(
320
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
321
+ )
322
+
323
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
324
+ hidden_states = hidden_states.to(query.dtype)
325
+
326
+ # linear proj
327
+ hidden_states = attn.to_out[0](hidden_states)
328
+ # dropout
329
+ hidden_states = attn.to_out[1](hidden_states)
330
+
331
+ if input_ndim == 4:
332
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
333
+
334
+ if attn.residual_connection:
335
+ hidden_states = hidden_states + residual
336
+
337
+ hidden_states = hidden_states / attn.rescale_output_factor
338
+
339
+ return hidden_states
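CustomerAttnProcessor2_0 combines the query-side and key-side padding masks into a single additive bias before calling scaled_dot_product_attention. A tiny illustration of that construction, with arbitrary values:

import torch

attention_mask = torch.tensor([[1, 1, 0]])            # N x S1: valid query positions
encoder_attention_mask = torch.tensor([[1, 0]])       # N x S2: valid key positions

combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]  # N x S1 x S2
bias = torch.where(combined_mask == 1, 0.0, -torch.inf)  # 0 where attending is allowed, -inf elsewhere
print(bias)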
models/lyrics_utils/lyric_encoder.py ADDED
@@ -0,0 +1,1070 @@
1
+ from typing import Optional, Tuple, Union
2
+ import math
3
+ import torch
4
+ from torch import nn
5
+
6
+ class ConvolutionModule(nn.Module):
7
+ """ConvolutionModule in Conformer model."""
8
+
9
+ def __init__(self,
10
+ channels: int,
11
+ kernel_size: int = 15,
12
+ activation: nn.Module = nn.ReLU(),
13
+ norm: str = "batch_norm",
14
+ causal: bool = False,
15
+ bias: bool = True):
16
+ """Construct an ConvolutionModule object.
17
+ Args:
18
+ channels (int): The number of channels of conv layers.
19
+ kernel_size (int): Kernel size of conv layers.
20
+ causal (bool): Whether to use causal convolution or not
21
+ """
22
+ super().__init__()
23
+
24
+ self.pointwise_conv1 = nn.Conv1d(
25
+ channels,
26
+ 2 * channels,
27
+ kernel_size=1,
28
+ stride=1,
29
+ padding=0,
30
+ bias=bias,
31
+ )
32
+ # self.lorder is used to distinguish if it's a causal convolution,
33
+ # if self.lorder > 0: it's a causal convolution, the input will be
34
+ # padded with self.lorder frames on the left in forward.
35
+ # else: it's a symmetrical convolution
36
+ if causal:
37
+ padding = 0
38
+ self.lorder = kernel_size - 1
39
+ else:
40
+ # kernel_size should be an odd number for non-causal convolution
41
+ assert (kernel_size - 1) % 2 == 0
42
+ padding = (kernel_size - 1) // 2
43
+ self.lorder = 0
44
+ self.depthwise_conv = nn.Conv1d(
45
+ channels,
46
+ channels,
47
+ kernel_size,
48
+ stride=1,
49
+ padding=padding,
50
+ groups=channels,
51
+ bias=bias,
52
+ )
53
+
54
+ assert norm in ['batch_norm', 'layer_norm']
55
+ if norm == "batch_norm":
56
+ self.use_layer_norm = False
57
+ self.norm = nn.BatchNorm1d(channels)
58
+ else:
59
+ self.use_layer_norm = True
60
+ self.norm = nn.LayerNorm(channels)
61
+
62
+ self.pointwise_conv2 = nn.Conv1d(
63
+ channels,
64
+ channels,
65
+ kernel_size=1,
66
+ stride=1,
67
+ padding=0,
68
+ bias=bias,
69
+ )
70
+ self.activation = activation
71
+
72
+ def forward(
73
+ self,
74
+ x: torch.Tensor,
75
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
76
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
77
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
78
+ """Compute convolution module.
79
+ Args:
80
+ x (torch.Tensor): Input tensor (#batch, time, channels).
81
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
82
+ (0, 0, 0) means fake mask.
83
+ cache (torch.Tensor): left context cache, it is only
84
+ used in causal convolution (#batch, channels, cache_t),
85
+ (0, 0, 0) means fake cache.
86
+ Returns:
87
+ torch.Tensor: Output tensor (#batch, time, channels).
88
+ """
89
+ # exchange the temporal dimension and the feature dimension
90
+ x = x.transpose(1, 2) # (#batch, channels, time)
91
+
92
+ # mask batch padding
93
+ if mask_pad.size(2) > 0: # time > 0
94
+ x.masked_fill_(~mask_pad, 0.0)
95
+
96
+ if self.lorder > 0:
97
+ if cache.size(2) == 0: # cache_t == 0
98
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
99
+ else:
100
+ assert cache.size(0) == x.size(0) # equal batch
101
+ assert cache.size(1) == x.size(1) # equal channel
102
+ x = torch.cat((cache, x), dim=2)
103
+ assert (x.size(2) > self.lorder)
104
+ new_cache = x[:, :, -self.lorder:]
105
+ else:
106
+ # It's better we just return None if no cache is required,
107
+ # However, for JIT export, here we just fake one tensor instead of
108
+ # None.
109
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
110
+
111
+ # GLU mechanism
112
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
113
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
114
+
115
+ # 1D Depthwise Conv
116
+ x = self.depthwise_conv(x)
117
+ if self.use_layer_norm:
118
+ x = x.transpose(1, 2)
119
+ x = self.activation(self.norm(x))
120
+ if self.use_layer_norm:
121
+ x = x.transpose(1, 2)
122
+ x = self.pointwise_conv2(x)
123
+ # mask batch padding
124
+ if mask_pad.size(2) > 0: # time > 0
125
+ x.masked_fill_(~mask_pad, 0.0)
126
+
127
+ return x.transpose(1, 2), new_cache
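The causal branch above pads on the left and keeps the last kernel_size - 1 frames as a cache for the next chunk. A hedged smoke test of those shapes; the import path is an assumption, adjust it to wherever this module actually lives:

import torch
from models.lyrics_utils.lyric_encoder import ConvolutionModule  # assumed import path

conv = ConvolutionModule(channels=8, kernel_size=15, causal=True)
x = torch.randn(2, 10, 8)                  # (batch, time, channels)
y, cache = conv(x)                         # cache holds the last kernel_size - 1 = 14 frames
print(y.shape, cache.shape)                # torch.Size([2, 10, 8]) torch.Size([2, 8, 14])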
128
+
129
+ class PositionwiseFeedForward(torch.nn.Module):
130
+ """Positionwise feed forward layer.
131
+
132
+ FeedForward is applied on each position of the sequence.
133
+ The output dim is the same as the input dim.
134
+
135
+ Args:
136
+ idim (int): Input dimension.
137
+ hidden_units (int): The number of hidden units.
138
+ dropout_rate (float): Dropout rate.
139
+ activation (torch.nn.Module): Activation function
140
+ """
141
+
142
+ def __init__(
143
+ self,
144
+ idim: int,
145
+ hidden_units: int,
146
+ dropout_rate: float,
147
+ activation: torch.nn.Module = torch.nn.ReLU(),
148
+ ):
149
+ """Construct a PositionwiseFeedForward object."""
150
+ super(PositionwiseFeedForward, self).__init__()
151
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
152
+ self.activation = activation
153
+ self.dropout = torch.nn.Dropout(dropout_rate)
154
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
155
+
156
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
157
+ """Forward function.
158
+
159
+ Args:
160
+ xs: input tensor (B, L, D)
161
+ Returns:
162
+ output tensor, (B, L, D)
163
+ """
164
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
165
+
166
+ class Swish(torch.nn.Module):
167
+ """Construct an Swish object."""
168
+
169
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
170
+ """Return Swish activation function."""
171
+ return x * torch.sigmoid(x)
172
+
173
+ class MultiHeadedAttention(nn.Module):
174
+ """Multi-Head Attention layer.
175
+
176
+ Args:
177
+ n_head (int): The number of heads.
178
+ n_feat (int): The number of features.
179
+ dropout_rate (float): Dropout rate.
180
+
181
+ """
182
+
183
+ def __init__(self,
184
+ n_head: int,
185
+ n_feat: int,
186
+ dropout_rate: float,
187
+ key_bias: bool = True):
188
+ """Construct an MultiHeadedAttention object."""
189
+ super().__init__()
190
+ assert n_feat % n_head == 0
191
+ # We assume d_v always equals d_k
192
+ self.d_k = n_feat // n_head
193
+ self.h = n_head
194
+ self.linear_q = nn.Linear(n_feat, n_feat)
195
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
196
+ self.linear_v = nn.Linear(n_feat, n_feat)
197
+ self.linear_out = nn.Linear(n_feat, n_feat)
198
+ self.dropout = nn.Dropout(p=dropout_rate)
199
+
200
+ def forward_qkv(
201
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
202
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
203
+ """Transform query, key and value.
204
+
205
+ Args:
206
+ query (torch.Tensor): Query tensor (#batch, time1, size).
207
+ key (torch.Tensor): Key tensor (#batch, time2, size).
208
+ value (torch.Tensor): Value tensor (#batch, time2, size).
209
+
210
+ Returns:
211
+ torch.Tensor: Transformed query tensor, size
212
+ (#batch, n_head, time1, d_k).
213
+ torch.Tensor: Transformed key tensor, size
214
+ (#batch, n_head, time2, d_k).
215
+ torch.Tensor: Transformed value tensor, size
216
+ (#batch, n_head, time2, d_k).
217
+
218
+ """
219
+ n_batch = query.size(0)
220
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
221
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
222
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
223
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
224
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
225
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
226
+ return q, k, v
227
+
228
+ def forward_attention(
229
+ self,
230
+ value: torch.Tensor,
231
+ scores: torch.Tensor,
232
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
233
+ ) -> torch.Tensor:
234
+ """Compute attention context vector.
235
+
236
+ Args:
237
+ value (torch.Tensor): Transformed value, size
238
+ (#batch, n_head, time2, d_k).
239
+ scores (torch.Tensor): Attention score, size
240
+ (#batch, n_head, time1, time2).
241
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
242
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
243
+
244
+ Returns:
245
+ torch.Tensor: Transformed value (#batch, time1, d_model)
246
+ weighted by the attention score (#batch, time1, time2).
247
+
248
+ """
249
+ n_batch = value.size(0)
250
+
251
+ if mask.size(2) > 0: # time2 > 0
252
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
253
+ # For last chunk, time2 might be larger than scores.size(-1)
254
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
255
+ scores = scores.masked_fill(mask, -float('inf'))
256
+ attn = torch.softmax(scores, dim=-1).masked_fill(
257
+ mask, 0.0) # (batch, head, time1, time2)
258
+
259
+ else:
260
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
261
+
262
+ p_attn = self.dropout(attn)
263
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
264
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
265
+ self.h * self.d_k)
266
+ ) # (batch, time1, d_model)
267
+
268
+ return self.linear_out(x) # (batch, time1, d_model)
269
+
270
+ def forward(
271
+ self,
272
+ query: torch.Tensor,
273
+ key: torch.Tensor,
274
+ value: torch.Tensor,
275
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
276
+ pos_emb: torch.Tensor = torch.empty(0),
277
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
278
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
279
+ """Compute scaled dot product attention.
280
+
281
+ Args:
282
+ query (torch.Tensor): Query tensor (#batch, time1, size).
283
+ key (torch.Tensor): Key tensor (#batch, time2, size).
284
+ value (torch.Tensor): Value tensor (#batch, time2, size).
285
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
286
+ (#batch, time1, time2).
287
+ 1.When applying cross attention between decoder and encoder,
288
+ the batch padding mask for input is in (#batch, 1, T) shape.
289
+ 2.When applying self attention of encoder,
290
+ the mask is in (#batch, T, T) shape.
291
+ 3.When applying self attention of decoder,
292
+ the mask is in (#batch, L, L) shape.
293
+ 4.If the different position in decoder see different block
294
+ of the encoder, such as Mocha, the passed in mask could be
295
+ in (#batch, L, T) shape. But there is no such case in current
296
+ CosyVoice.
297
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
298
+ where `cache_t == chunk_size * num_decoding_left_chunks`
299
+ and `head * d_k == size`
300
+
301
+
302
+ Returns:
303
+ torch.Tensor: Output tensor (#batch, time1, d_model).
304
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
305
+ where `cache_t == chunk_size * num_decoding_left_chunks`
306
+ and `head * d_k == size`
307
+
308
+ """
309
+ q, k, v = self.forward_qkv(query, key, value)
310
+ if cache.size(0) > 0:
311
+ key_cache, value_cache = torch.split(cache,
312
+ cache.size(-1) // 2,
313
+ dim=-1)
314
+ k = torch.cat([key_cache, k], dim=2)
315
+ v = torch.cat([value_cache, v], dim=2)
316
+ new_cache = torch.cat((k, v), dim=-1)
317
+
318
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
319
+ return self.forward_attention(v, scores, mask), new_cache
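MultiHeadedAttention returns both the attended output and a concatenated key/value cache for streaming decoding. A hedged shape check, under the same import assumption as above:

import torch
from models.lyrics_utils.lyric_encoder import MultiHeadedAttention  # assumed import path

mha = MultiHeadedAttention(n_head=4, n_feat=32, dropout_rate=0.0)
x = torch.randn(2, 10, 32)
out, cache = mha(x, x, x)                  # default mask/cache arguments are the "fake" empty tensors
print(out.shape, cache.shape)              # torch.Size([2, 10, 32]) torch.Size([2, 4, 10, 16])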
320
+
321
+
322
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
323
+ """Multi-Head Attention layer with relative position encoding.
324
+ Paper: https://arxiv.org/abs/1901.02860
325
+ Args:
326
+ n_head (int): The number of heads.
327
+ n_feat (int): The number of features.
328
+ dropout_rate (float): Dropout rate.
329
+ """
330
+
331
+ def __init__(self,
332
+ n_head: int,
333
+ n_feat: int,
334
+ dropout_rate: float,
335
+ key_bias: bool = True):
336
+ """Construct an RelPositionMultiHeadedAttention object."""
337
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
338
+ # linear transformation for positional encoding
339
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
340
+ # these two learnable bias are used in matrix c and matrix d
341
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
342
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
343
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
344
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
345
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
346
+
347
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
348
+ """Compute relative positional encoding.
349
+
350
+ Args:
351
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
352
+ time1 means the length of query vector.
353
+
354
+ Returns:
355
+ torch.Tensor: Output tensor.
356
+
357
+ """
358
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
359
+ device=x.device,
360
+ dtype=x.dtype)
361
+ x_padded = torch.cat([zero_pad, x], dim=-1)
362
+
363
+ x_padded = x_padded.view(x.size()[0],
364
+ x.size()[1],
365
+ x.size(3) + 1, x.size(2))
366
+ x = x_padded[:, :, 1:].view_as(x)[
367
+ :, :, :, : x.size(-1) // 2 + 1
368
+ ] # only keep the positions from 0 to time2
369
+ return x
370
+
371
+ def forward(
372
+ self,
373
+ query: torch.Tensor,
374
+ key: torch.Tensor,
375
+ value: torch.Tensor,
376
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
377
+ pos_emb: torch.Tensor = torch.empty(0),
378
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
379
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
380
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
381
+ Args:
382
+ query (torch.Tensor): Query tensor (#batch, time1, size).
383
+ key (torch.Tensor): Key tensor (#batch, time2, size).
384
+ value (torch.Tensor): Value tensor (#batch, time2, size).
385
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
386
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
387
+ pos_emb (torch.Tensor): Positional embedding tensor
388
+ (#batch, time2, size).
389
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
390
+ where `cache_t == chunk_size * num_decoding_left_chunks`
391
+ and `head * d_k == size`
392
+ Returns:
393
+ torch.Tensor: Output tensor (#batch, time1, d_model).
394
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
395
+ where `cache_t == chunk_size * num_decoding_left_chunks`
396
+ and `head * d_k == size`
397
+ """
398
+ q, k, v = self.forward_qkv(query, key, value)
399
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
400
+
401
+ if cache.size(0) > 0:
402
+ key_cache, value_cache = torch.split(cache,
403
+ cache.size(-1) // 2,
404
+ dim=-1)
405
+ k = torch.cat([key_cache, k], dim=2)
406
+ v = torch.cat([value_cache, v], dim=2)
407
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
408
+ # non-trivial to calculate `next_cache_start` here.
409
+ new_cache = torch.cat((k, v), dim=-1)
410
+
411
+ n_batch_pos = pos_emb.size(0)
412
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
413
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
414
+
415
+ # (batch, head, time1, d_k)
416
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
417
+ # (batch, head, time1, d_k)
418
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
419
+
420
+ # compute attention score
421
+ # first compute matrix a and matrix c
422
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
423
+ # (batch, head, time1, time2)
424
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
425
+
426
+ # compute matrix b and matrix d
427
+ # (batch, head, time1, time2)
428
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
429
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
430
+ if matrix_ac.shape != matrix_bd.shape:
431
+ matrix_bd = self.rel_shift(matrix_bd)
432
+
433
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
434
+ self.d_k) # (batch, head, time1, time2)
435
+
436
+ return self.forward_attention(v, scores, mask), new_cache
437
+
438
+
439
+
440
+ def subsequent_mask(
441
+ size: int,
442
+ device: torch.device = torch.device("cpu"),
443
+ ) -> torch.Tensor:
444
+ """Create mask for subsequent steps (size, size).
445
+
446
+ This mask is used only in decoder which works in an auto-regressive mode.
447
+ This means the current step could only do attention with its left steps.
448
+
449
+ In encoder, full attention is used when streaming is not necessary and
450
+ the sequence is not long. In this case, no attention mask is needed.
451
+
452
+ When streaming is needed, chunk-based attention is used in encoder. See
453
+ subsequent_chunk_mask for the chunk-based attention mask.
454
+
455
+ Args:
456
+ size (int): size of mask
457
+ device (torch.device): "cpu", "cuda", or a torch device
458
+ dtype (torch.dtype): result dtype
459
+
460
+ Returns:
461
+ torch.Tensor: mask
462
+
463
+ Examples:
464
+ >>> subsequent_mask(3)
465
+ [[1, 0, 0],
466
+ [1, 1, 0],
467
+ [1, 1, 1]]
468
+ """
469
+ arange = torch.arange(size, device=device)
470
+ mask = arange.expand(size, size)
471
+ arange = arange.unsqueeze(-1)
472
+ mask = mask <= arange
473
+ return mask
474
+
475
+
476
+ def subsequent_chunk_mask(
477
+ size: int,
478
+ chunk_size: int,
479
+ num_left_chunks: int = -1,
480
+ device: torch.device = torch.device("cpu"),
481
+ ) -> torch.Tensor:
482
+ """Create mask for subsequent steps (size, size) with chunk size,
483
+ this is for streaming encoder
484
+
485
+ Args:
486
+ size (int): size of mask
487
+ chunk_size (int): size of chunk
488
+ num_left_chunks (int): number of left chunks
489
+ <0: use full chunk
490
+ >=0: use num_left_chunks
491
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
492
+
493
+ Returns:
494
+ torch.Tensor: mask
495
+
496
+ Examples:
497
+ >>> subsequent_chunk_mask(4, 2)
498
+ [[1, 1, 0, 0],
499
+ [1, 1, 0, 0],
500
+ [1, 1, 1, 1],
501
+ [1, 1, 1, 1]]
502
+ """
503
+ ret = torch.zeros(size, size, device=device, dtype=torch.bool)
504
+ for i in range(size):
505
+ if num_left_chunks < 0:
506
+ start = 0
507
+ else:
508
+ start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
509
+ ending = min((i // chunk_size + 1) * chunk_size, size)
510
+ ret[i, start:ending] = True
511
+ return ret
512
+
513
+ def add_optional_chunk_mask(xs: torch.Tensor,
514
+ masks: torch.Tensor,
515
+ use_dynamic_chunk: bool,
516
+ use_dynamic_left_chunk: bool,
517
+ decoding_chunk_size: int,
518
+ static_chunk_size: int,
519
+ num_decoding_left_chunks: int,
520
+ enable_full_context: bool = True):
521
+ """ Apply optional mask for encoder.
522
+
523
+ Args:
524
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
525
+ mask (torch.Tensor): mask for xs, (B, 1, L)
526
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
527
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
528
+ training.
529
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
530
+ 0: default for training, use random dynamic chunk.
531
+ <0: for decoding, use full chunk.
532
+ >0: for decoding, use fixed chunk size as set.
533
+ static_chunk_size (int): chunk size for static chunk training/decoding
534
+ if it's greater than 0, if use_dynamic_chunk is true,
535
+ this parameter will be ignored
536
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
537
+ the chunk size is decoding_chunk_size.
538
+ >=0: use num_decoding_left_chunks
539
+ <0: use all left chunks
540
+ enable_full_context (bool):
541
+ True: chunk size is either [1, 25] or full context(max_len)
542
+ False: chunk size ~ U[1, 25]
543
+
544
+ Returns:
545
+ torch.Tensor: chunk mask of the input xs.
546
+ """
547
+ # Whether to use chunk mask or not
548
+ if use_dynamic_chunk:
549
+ max_len = xs.size(1)
550
+ if decoding_chunk_size < 0:
551
+ chunk_size = max_len
552
+ num_left_chunks = -1
553
+ elif decoding_chunk_size > 0:
554
+ chunk_size = decoding_chunk_size
555
+ num_left_chunks = num_decoding_left_chunks
556
+ else:
557
+ # chunk size is either [1, 25] or full context(max_len).
558
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
559
+ # delay, the maximum frame is 100 / 4 = 25.
560
+ chunk_size = torch.randint(1, max_len, (1, )).item()
561
+ num_left_chunks = -1
562
+ if chunk_size > max_len // 2 and enable_full_context:
563
+ chunk_size = max_len
564
+ else:
565
+ chunk_size = chunk_size % 25 + 1
566
+ if use_dynamic_left_chunk:
567
+ max_left_chunks = (max_len - 1) // chunk_size
568
+ num_left_chunks = torch.randint(0, max_left_chunks,
569
+ (1, )).item()
570
+ chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
571
+ num_left_chunks,
572
+ xs.device) # (L, L)
573
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
574
+ chunk_masks = masks & chunk_masks # (B, L, L)
575
+ elif static_chunk_size > 0:
576
+ num_left_chunks = num_decoding_left_chunks
577
+ chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
578
+ num_left_chunks,
579
+ xs.device) # (L, L)
580
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
581
+ chunk_masks = masks & chunk_masks # (B, L, L)
582
+ else:
583
+ chunk_masks = masks
584
+ return chunk_masks
585
+
586
+
587
+ class ConformerEncoderLayer(nn.Module):
588
+ """Encoder layer module.
589
+ Args:
590
+ size (int): Input dimension.
591
+ self_attn (torch.nn.Module): Self-attention module instance.
592
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
593
+ instance can be used as the argument.
594
+ feed_forward (torch.nn.Module): Feed-forward module instance.
595
+ `PositionwiseFeedForward` instance can be used as the argument.
596
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
597
+ instance.
598
+ `PositionwiseFeedForward` instance can be used as the argument.
599
+ conv_module (torch.nn.Module): Convolution module instance.
600
+ `ConvolutionModule` instance can be used as the argument.
601
+ dropout_rate (float): Dropout rate.
602
+ normalize_before (bool):
603
+ True: use layer_norm before each sub-block.
604
+ False: use layer_norm after each sub-block.
605
+ """
606
+
607
+ def __init__(
608
+ self,
609
+ size: int,
610
+ self_attn: torch.nn.Module,
611
+ feed_forward: Optional[nn.Module] = None,
612
+ feed_forward_macaron: Optional[nn.Module] = None,
613
+ conv_module: Optional[nn.Module] = None,
614
+ dropout_rate: float = 0.1,
615
+ normalize_before: bool = True,
616
+ ):
617
+ """Construct an EncoderLayer object."""
618
+ super().__init__()
619
+ self.self_attn = self_attn
620
+ self.feed_forward = feed_forward
621
+ self.feed_forward_macaron = feed_forward_macaron
622
+ self.conv_module = conv_module
623
+ self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
624
+ self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
625
+ if feed_forward_macaron is not None:
626
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
627
+ self.ff_scale = 0.5
628
+ else:
629
+ self.ff_scale = 1.0
630
+ if self.conv_module is not None:
631
+ self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
632
+ self.norm_final = nn.LayerNorm(
633
+ size, eps=1e-5) # for the final output of the block
634
+ self.dropout = nn.Dropout(dropout_rate)
635
+ self.size = size
636
+ self.normalize_before = normalize_before
637
+
638
+ def forward(
639
+ self,
640
+ x: torch.Tensor,
641
+ mask: torch.Tensor,
642
+ pos_emb: torch.Tensor,
643
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
644
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
645
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
646
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
647
+ """Compute encoded features.
648
+
649
+ Args:
650
+ x (torch.Tensor): (#batch, time, size)
651
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
652
+ (0, 0, 0) means fake mask.
653
+ pos_emb (torch.Tensor): positional encoding, must not be None
654
+ for ConformerEncoderLayer.
655
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
656
+ (#batch, 1,time), (0, 0, 0) means fake mask.
657
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
658
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
659
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
660
+ (#batch=1, size, cache_t2)
661
+ Returns:
662
+ torch.Tensor: Output tensor (#batch, time, size).
663
+ torch.Tensor: Mask tensor (#batch, time, time).
664
+ torch.Tensor: att_cache tensor,
665
+ (#batch=1, head, cache_t1 + time, d_k * 2).
666
+ torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
667
+ """
668
+
669
+ # whether to use macaron style
670
+ if self.feed_forward_macaron is not None:
671
+ residual = x
672
+ if self.normalize_before:
673
+ x = self.norm_ff_macaron(x)
674
+ x = residual + self.ff_scale * self.dropout(
675
+ self.feed_forward_macaron(x))
676
+ if not self.normalize_before:
677
+ x = self.norm_ff_macaron(x)
678
+
679
+ # multi-headed self-attention module
680
+ residual = x
681
+ if self.normalize_before:
682
+ x = self.norm_mha(x)
683
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
684
+ att_cache)
685
+ x = residual + self.dropout(x_att)
686
+ if not self.normalize_before:
687
+ x = self.norm_mha(x)
688
+
689
+ # convolution module
690
+ # Fake new cnn cache here, and then change it in conv_module
691
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
692
+ if self.conv_module is not None:
693
+ residual = x
694
+ if self.normalize_before:
695
+ x = self.norm_conv(x)
696
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
697
+ x = residual + self.dropout(x)
698
+
699
+ if not self.normalize_before:
700
+ x = self.norm_conv(x)
701
+
702
+ # feed forward module
703
+ residual = x
704
+ if self.normalize_before:
705
+ x = self.norm_ff(x)
706
+
707
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
708
+ if not self.normalize_before:
709
+ x = self.norm_ff(x)
710
+
711
+ if self.conv_module is not None:
712
+ x = self.norm_final(x)
713
+
714
+ return x, mask, new_att_cache, new_cnn_cache
715
+
716
+
717
+
718
+ class EspnetRelPositionalEncoding(torch.nn.Module):
719
+ """Relative positional encoding module (new implementation).
720
+
721
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
722
+
723
+ See : Appendix B in https://arxiv.org/abs/1901.02860
724
+
725
+ Args:
726
+ d_model (int): Embedding dimension.
727
+ dropout_rate (float): Dropout rate.
728
+ max_len (int): Maximum input length.
729
+
730
+ """
731
+
732
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
733
+ """Construct an PositionalEncoding object."""
734
+ super(EspnetRelPositionalEncoding, self).__init__()
735
+ self.d_model = d_model
736
+ self.xscale = math.sqrt(self.d_model)
737
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
738
+ self.pe = None
739
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
740
+
741
+ def extend_pe(self, x: torch.Tensor):
742
+ """Reset the positional encodings."""
743
+ if self.pe is not None:
744
+ # self.pe contains both positive and negative parts
745
+ # the length of self.pe is 2 * input_len - 1
746
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
747
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
748
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
749
+ return
750
+ # Suppose `i` denotes the position of the query vector and `j` the
751
+ # position of the key vector. We use positive relative positions when keys
752
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
753
+ pe_positive = torch.zeros(x.size(1), self.d_model)
754
+ pe_negative = torch.zeros(x.size(1), self.d_model)
755
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
756
+ div_term = torch.exp(
757
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
758
+ * -(math.log(10000.0) / self.d_model)
759
+ )
760
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
761
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
762
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
763
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
764
+
765
+ # Reverse the order of positive indices and concat both positive and
766
+ # negative indices. This is used to support the shifting trick
767
+ # as in https://arxiv.org/abs/1901.02860
768
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
769
+ pe_negative = pe_negative[1:].unsqueeze(0)
770
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
771
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
772
+
773
+ def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
774
+ -> Tuple[torch.Tensor, torch.Tensor]:
775
+ """Add positional encoding.
776
+
777
+ Args:
778
+ x (torch.Tensor): Input tensor (batch, time, `*`).
779
+
780
+ Returns:
781
+ torch.Tensor: Encoded tensor (batch, time, `*`).
782
+
783
+ """
784
+ self.extend_pe(x)
785
+ x = x * self.xscale
786
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
787
+ return self.dropout(x), self.dropout(pos_emb)
788
+
789
+ def position_encoding(self,
790
+ offset: Union[int, torch.Tensor],
791
+ size: int) -> torch.Tensor:
792
+ """ For getting encoding in a streaming fashion
793
+
794
+ Attention!!!!!
795
+ we apply dropout only once at the whole utterance level in a
796
+ non-streaming way, but will call this function several times with
797
+ increasing input size in a streaming scenario, so the dropout will
798
+ be applied several times.
799
+
800
+ Args:
801
+ offset (int or torch.tensor): start offset
802
+ size (int): required size of position encoding
803
+
804
+ Returns:
805
+ torch.Tensor: Corresponding encoding
806
+ """
807
+ pos_emb = self.pe[
808
+ :,
809
+ self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
810
+ ]
811
+ return pos_emb
812
+
813
+
814
+
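A minimal sketch of how the relative positional encoding above behaves (illustrative only; the cached table holds 2*max_len - 1 vectors, and position_encoding(size=T) slices out the 2*T - 1 offsets centred on zero):

    import torch

    enc = EspnetRelPositionalEncoding(d_model=8, dropout_rate=0.0, max_len=16)
    x = torch.zeros(1, 10, 8)         # (batch, time, d_model)
    x_scaled, pos_emb = enc(x)        # forward() scales x and returns the sliced table
    print(pos_emb.shape)              # torch.Size([1, 19, 8]) == (1, 2*10 - 1, 8)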
815
+ class LinearEmbed(torch.nn.Module):
816
+ """Linear transform the input without subsampling
817
+
818
+ Args:
819
+ idim (int): Input dimension.
820
+ odim (int): Output dimension.
821
+ dropout_rate (float): Dropout rate.
822
+
823
+ """
824
+
825
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
826
+ pos_enc_class: torch.nn.Module):
827
+ """Construct an linear object."""
828
+ super().__init__()
829
+ self.out = torch.nn.Sequential(
830
+ torch.nn.Linear(idim, odim),
831
+ torch.nn.LayerNorm(odim, eps=1e-5),
832
+ torch.nn.Dropout(dropout_rate),
833
+ )
834
+ self.pos_enc = pos_enc_class #rel_pos_espnet
835
+
836
+ def position_encoding(self, offset: Union[int, torch.Tensor],
837
+ size: int) -> torch.Tensor:
838
+ return self.pos_enc.position_encoding(offset, size)
839
+
840
+ def forward(
841
+ self,
842
+ x: torch.Tensor,
843
+ offset: Union[int, torch.Tensor] = 0
844
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
845
+ """Input x.
846
+
847
+ Args:
848
+ x (torch.Tensor): Input tensor (#batch, time, idim).
849
+ offset (int or torch.Tensor): position offset for streaming decoding.
850
+
851
+ Returns:
852
+ torch.Tensor: linear input tensor (#batch, time', odim),
853
+ where time' = time .
854
+ torch.Tensor: relative positional embedding (1, 2*time-1, odim)
855
+ from the attached positional encoding module.
856
+
857
+ """
858
+ x = self.out(x)
859
+ x, pos_emb = self.pos_enc(x, offset)
860
+ return x, pos_emb
861
+
862
+
863
+ ATTENTION_CLASSES = {
864
+ "selfattn": MultiHeadedAttention,
865
+ "rel_selfattn": RelPositionMultiHeadedAttention,
866
+ }
867
+
868
+ ACTIVATION_CLASSES = {
869
+ "hardtanh": torch.nn.Hardtanh,
870
+ "tanh": torch.nn.Tanh,
871
+ "relu": torch.nn.ReLU,
872
+ "selu": torch.nn.SELU,
873
+ "swish": getattr(torch.nn, "SiLU", Swish),
874
+ "gelu": torch.nn.GELU,
875
+ }
876
+
877
+
878
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
879
+ """Make mask tensor containing indices of padded part.
880
+
881
+ See description of make_non_pad_mask.
882
+
883
+ Args:
884
+ lengths (torch.Tensor): Batch of lengths (B,).
885
+ Returns:
886
+ torch.Tensor: Mask tensor containing indices of padded part.
887
+
888
+ Examples:
889
+ >>> lengths = [5, 3, 2]
890
+ >>> make_pad_mask(lengths)
891
+ masks = [[0, 0, 0, 0 ,0],
892
+ [0, 0, 0, 1, 1],
893
+ [0, 0, 1, 1, 1]]
894
+ """
895
+ batch_size = lengths.size(0)
896
+ max_len = max_len if max_len > 0 else lengths.max().item()
897
+ seq_range = torch.arange(0,
898
+ max_len,
899
+ dtype=torch.int64,
900
+ device=lengths.device)
901
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
902
+ seq_length_expand = lengths.unsqueeze(-1)
903
+ mask = seq_range_expand >= seq_length_expand
904
+ return mask
905
+
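A quick sketch of the mask convention (illustrative; it mirrors the docstring example above, and the inverted mask appears to be what ConformerEncoder.forward below expects as pad_mask):

    import torch

    lengths = torch.tensor([5, 3, 2])
    pad = make_pad_mask(lengths)      # True marks padded positions
    # tensor([[False, False, False, False, False],
    #         [False, False, False,  True,  True],
    #         [False, False,  True,  True,  True]])
    valid = ~pad                      # True for real frames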
906
+ #https://github.com/FunAudioLLM/CosyVoice/blob/main/examples/magicdata-read/cosyvoice/conf/cosyvoice.yaml
907
+ class ConformerEncoder(torch.nn.Module):
908
+ """Conformer encoder module."""
909
+
910
+ def __init__(
911
+ self,
912
+ input_size: int,
913
+ output_size: int = 1024,
914
+ attention_heads: int = 16,
915
+ linear_units: int = 4096,
916
+ num_blocks: int = 6,
917
+ dropout_rate: float = 0.1,
918
+ positional_dropout_rate: float = 0.1,
919
+ attention_dropout_rate: float = 0.0,
920
+ input_layer: str = 'linear',
921
+ pos_enc_layer_type: str = 'rel_pos_espnet',
922
+ normalize_before: bool = True,
923
+ static_chunk_size: int = 1, # 1: causal_mask; 0: full_mask
924
+ use_dynamic_chunk: bool = False,
925
+ use_dynamic_left_chunk: bool = False,
926
+ positionwise_conv_kernel_size: int = 1,
927
+ macaron_style: bool = False,
928
+ selfattention_layer_type: str = "rel_selfattn",
929
+ activation_type: str = "swish",
930
+ use_cnn_module: bool = False,
931
+ cnn_module_kernel: int = 15,
932
+ causal: bool = False,
933
+ cnn_module_norm: str = "batch_norm",
934
+ key_bias: bool = True,
935
+ gradient_checkpointing: bool = False,
936
+ ):
937
+ """Construct ConformerEncoder
938
+
939
+ Args:
940
+ input_size to use_dynamic_chunk, see in BaseEncoder
941
+ positionwise_conv_kernel_size (int): Kernel size of positionwise
942
+ conv1d layer.
943
+ macaron_style (bool): Whether to use macaron style for
944
+ positionwise layer.
945
+ selfattention_layer_type (str): Encoder attention layer type,
946
+ the parameter has no effect now, it's just for configuration
947
+ compatibility. #'rel_selfattn'
948
+ activation_type (str): Encoder activation function type.
949
+ use_cnn_module (bool): Whether to use convolution module.
950
+ cnn_module_kernel (int): Kernel size of convolution module.
951
+ causal (bool): whether to use causal convolution or not.
952
+ key_bias: whether to use bias in attention.linear_k; False for whisper models.
953
+ """
954
+ super().__init__()
955
+ self.output_size = output_size
956
+ self.embed = LinearEmbed(input_size, output_size, dropout_rate,
957
+ EspnetRelPositionalEncoding(output_size, positional_dropout_rate))
958
+ self.normalize_before = normalize_before
959
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
960
+ self.gradient_checkpointing = gradient_checkpointing
961
+ self.use_dynamic_chunk = use_dynamic_chunk
962
+
963
+ self.static_chunk_size = static_chunk_size
964
+ self.use_dynamic_chunk = use_dynamic_chunk
965
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
966
+ activation = ACTIVATION_CLASSES[activation_type]()
967
+
968
+ # self-attention module definition
969
+ encoder_selfattn_layer_args = (
970
+ attention_heads,
971
+ output_size,
972
+ attention_dropout_rate,
973
+ key_bias,
974
+ )
975
+ # feed-forward module definition
976
+ positionwise_layer_args = (
977
+ output_size,
978
+ linear_units,
979
+ dropout_rate,
980
+ activation,
981
+ )
982
+ # convolution module definition
983
+ convolution_layer_args = (output_size, cnn_module_kernel, activation,
984
+ cnn_module_norm, causal)
985
+
986
+ self.encoders = torch.nn.ModuleList([
987
+ ConformerEncoderLayer(
988
+ output_size,
989
+ RelPositionMultiHeadedAttention(
990
+ *encoder_selfattn_layer_args),
991
+ PositionwiseFeedForward(*positionwise_layer_args),
992
+ PositionwiseFeedForward(
993
+ *positionwise_layer_args) if macaron_style else None,
994
+ ConvolutionModule(
995
+ *convolution_layer_args) if use_cnn_module else None,
996
+ dropout_rate,
997
+ normalize_before,
998
+ ) for _ in range(num_blocks)
999
+ ])
1000
+
1001
+ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
1002
+ pos_emb: torch.Tensor,
1003
+ mask_pad: torch.Tensor) -> torch.Tensor:
1004
+ for layer in self.encoders:
1005
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
1006
+ return xs
1007
+
1008
+ @torch.jit.unused
1009
+ def forward_layers_checkpointed(self, xs: torch.Tensor,
1010
+ chunk_masks: torch.Tensor,
1011
+ pos_emb: torch.Tensor,
1012
+ mask_pad: torch.Tensor) -> torch.Tensor:
1013
+ for layer in self.encoders:
1014
+ xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
1015
+ chunk_masks, pos_emb,
1016
+ mask_pad)
1017
+ return xs
1018
+
1019
+ def forward(
1020
+ self,
1021
+ xs: torch.Tensor,
1022
+ pad_mask: torch.Tensor,
1023
+ decoding_chunk_size: int = 0,
1024
+ num_decoding_left_chunks: int = -1,
1025
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1026
+ """Embed positions in tensor.
1027
+
1028
+ Args:
1029
+ xs: padded input tensor (B, T, D)
1030
+ pad_mask: padding mask (B, T), True/1 for valid frames
1031
+ decoding_chunk_size: decoding chunk size for dynamic chunk
1032
+ 0: default for training, use random dynamic chunk.
1033
+ <0: for decoding, use full chunk.
1034
+ >0: for decoding, use fixed chunk size as set.
1035
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
1036
+ the chunk size is decoding_chunk_size.
1037
+ >=0: use num_decoding_left_chunks
1038
+ <0: use all left chunks
1039
+ Returns:
1040
+ encoder output tensor xs, and subsampled masks
1041
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
1042
+ masks: torch.Tensor batch padding mask after subsample
1043
+ (B, 1, T' ~= T/subsample_rate)
1044
+ NOTE(xcsong):
1045
+ We pass the `__call__` method of the modules instead of `forward` to the
1046
+ checkpointing API because `__call__` attaches all the hooks of the module.
1047
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
1048
+ """
1049
+ T = xs.size(1)
1050
+ masks = pad_mask.to(torch.bool).unsqueeze(1) # (B, 1, T)
1051
+ xs, pos_emb = self.embed(xs)
1052
+ mask_pad = masks # (B, 1, T/subsample_rate)
1053
+ chunk_masks = add_optional_chunk_mask(xs, masks,
1054
+ self.use_dynamic_chunk,
1055
+ self.use_dynamic_left_chunk,
1056
+ decoding_chunk_size,
1057
+ self.static_chunk_size,
1058
+ num_decoding_left_chunks)
1059
+ if self.gradient_checkpointing and self.training:
1060
+ xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
1061
+ mask_pad)
1062
+ else:
1063
+ xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
1064
+ if self.normalize_before:
1065
+ xs = self.after_norm(xs)
1066
+ # Here we assume the mask is not changed in encoder layers, so just
1067
+ # return the masks before encoder layers, and the masks will be used
1068
+ # for cross attention with decoder later
1069
+ return xs, masks
1070
+
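A hypothetical end-to-end sketch (illustrative only; the sizes are made up, and it assumes the attention, feed-forward and masking helpers defined earlier in this file):

    import torch

    encoder = ConformerEncoder(input_size=512, output_size=1024,
                               attention_heads=16, linear_units=4096, num_blocks=6)
    xs = torch.randn(2, 100, 512)                        # (B, T, D)
    pad_mask = ~make_pad_mask(torch.tensor([100, 80]))   # True for valid frames
    out, masks = encoder(xs, pad_mask)
    print(out.shape)                                     # torch.Size([2, 100, 1024])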
models/lyrics_utils/lyric_normalizer.py ADDED
@@ -0,0 +1,66 @@
1
+ import re
2
+ from opencc import OpenCC
3
+
4
+
5
+ t2s_converter = OpenCC('t2s')
6
+ s2t_converter = OpenCC('s2t')
7
+
8
+
9
+ EMOJI_PATTERN = re.compile(
10
+ "["
11
+ "\U0001F600-\U0001F64F" # Emoticons
12
+ "]+", flags=re.UNICODE
13
+ )
14
+
15
+ # Build a translation table for replacing and removing characters
16
+ TRANSLATION_TABLE = str.maketrans({
17
+ '-': ' ', # replace '-' with a space
18
+ ',': None,
19
+ '.': None,
20
+ ',': None,
21
+ '。': None,
22
+ '!': None,
23
+ '!': None,
24
+ '?': None,
25
+ '?': None,
26
+ '…': None,
27
+ ';': None,
28
+ ';': None,
29
+ ':': None,
30
+ ':': None,
31
+ '\u3000': ' ', # replace the full-width space with a regular space
32
+ })
33
+
34
+ # Replace bracketed content, including square brackets and parentheses
35
+ BACKSLASH_PATTERN = re.compile(r'\(.*?\)|\[.*?\]')
36
+
37
+ SPACE_PATTERN = re.compile(r'(?<!^)\s+(?!$)')
38
+
39
+
40
+ def normalize_text(text, language, strip=True):
41
+ """
42
+ Normalize the text: strip punctuation and lowercase it (where applicable).
43
+ """
44
+ # Step 1: replace '-' with ' ' and remove punctuation
45
+ text = text.translate(TRANSLATION_TABLE)
46
+
47
+ # Step 2: remove emoji
48
+ text = EMOJI_PATTERN.sub('', text)
49
+
50
+ # Step 3: collapse consecutive whitespace into a single space, except at the ends
51
+ text = SPACE_PATTERN.sub(' ', text)
52
+
53
+ # Step 4: strip leading/trailing whitespace (if requested)
54
+ if strip:
55
+ text = text.strip()
56
+
57
+ # Step 5: convert to lowercase
58
+ text = text.lower()
59
+
60
+ # Step 6: language-specific conversion
61
+ if language == "zh":
62
+ text = t2s_converter.convert(text)
63
+ if language == "yue":
64
+ text = s2t_converter.convert(text)
65
+ # add other languages here as needed
66
+ return text
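A small usage sketch for the normalizer (illustrative; assumes opencc is installed):

    print(normalize_text("Hello, World!  你好嗎?", language="zh"))
    # -> "hello world 你好吗"
    # punctuation dropped, whitespace collapsed, lower-cased,
    # traditional characters converted to simplified for "zh"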
models/lyrics_utils/lyric_tokenizer.py ADDED
@@ -0,0 +1,883 @@
1
+ import os
2
+ import re
3
+ import textwrap
4
+ from functools import cached_property
5
+
6
+ import pypinyin
7
+ import torch
8
+ from hangul_romanize import Transliter
9
+ from hangul_romanize.rule import academic
10
+ from num2words import num2words
11
+ from spacy.lang.ar import Arabic
12
+ from spacy.lang.en import English
13
+ from spacy.lang.es import Spanish
14
+ from spacy.lang.ja import Japanese
15
+ from spacy.lang.zh import Chinese
16
+ from tokenizers import Tokenizer
17
+
18
+ from .zh_num2words import TextNorm as zh_num2words
19
+ from typing import Dict, List, Optional, Set, Union
20
+
21
+
22
+ #copy from https://github.com/coqui-ai/TTS/blob/dbf1a08a0d4e47fdad6172e433eeb34bc6b13b4e/TTS/tts/layers/xtts/tokenizer.py
23
+ def get_spacy_lang(lang):
24
+ if lang == "zh":
25
+ return Chinese()
26
+ elif lang == "ja":
27
+ return Japanese()
28
+ elif lang == "ar":
29
+ return Arabic()
30
+ elif lang == "es":
31
+ return Spanish()
32
+ else:
33
+ # For most languages, English does the job
34
+ return English()
35
+
36
+
37
+ def split_sentence(text, lang, text_split_length=250):
38
+ """Preprocess the input text"""
39
+ text_splits = []
40
+ if text_split_length is not None and len(text) >= text_split_length:
41
+ text_splits.append("")
42
+ nlp = get_spacy_lang(lang)
43
+ nlp.add_pipe("sentencizer")
44
+ doc = nlp(text)
45
+ for sentence in doc.sents:
46
+ if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
47
+ # if the last sentence + the current sentence is less than the text_split_length
48
+ # then add the current sentence to the last sentence
49
+ text_splits[-1] += " " + str(sentence)
50
+ text_splits[-1] = text_splits[-1].lstrip()
51
+ elif len(str(sentence)) > text_split_length:
52
+ # if the current sentence is greater than the text_split_length
53
+ for line in textwrap.wrap(
54
+ str(sentence),
55
+ width=text_split_length,
56
+ drop_whitespace=True,
57
+ break_on_hyphens=False,
58
+ tabsize=1,
59
+ ):
60
+ text_splits.append(str(line))
61
+ else:
62
+ text_splits.append(str(sentence))
63
+
64
+ if len(text_splits) > 1:
65
+ if text_splits[0] == "":
66
+ del text_splits[0]
67
+ else:
68
+ text_splits = [text.lstrip()]
69
+
70
+ return text_splits
71
+
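A quick sketch of the splitting behaviour (illustrative; inputs shorter than text_split_length pass through apart from an lstrip, longer ones are chunked at sentence boundaries):

    print(split_sentence("  Hello world. This is short.", "en"))
    # -> ['Hello world. This is short.']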
72
+
73
+ _whitespace_re = re.compile(r"\s+")
74
+
75
+ # List of (regular expression, replacement) pairs for abbreviations:
76
+ _abbreviations = {
77
+ "en": [
78
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
79
+ for x in [
80
+ ("mrs", "misess"),
81
+ ("mr", "mister"),
82
+ ("dr", "doctor"),
83
+ ("st", "saint"),
84
+ ("co", "company"),
85
+ ("jr", "junior"),
86
+ ("maj", "major"),
87
+ ("gen", "general"),
88
+ ("drs", "doctors"),
89
+ ("rev", "reverend"),
90
+ ("lt", "lieutenant"),
91
+ ("hon", "honorable"),
92
+ ("sgt", "sergeant"),
93
+ ("capt", "captain"),
94
+ ("esq", "esquire"),
95
+ ("ltd", "limited"),
96
+ ("col", "colonel"),
97
+ ("ft", "fort"),
98
+ ]
99
+ ],
100
+ "es": [
101
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
102
+ for x in [
103
+ ("sra", "señora"),
104
+ ("sr", "señor"),
105
+ ("dr", "doctor"),
106
+ ("dra", "doctora"),
107
+ ("st", "santo"),
108
+ ("co", "compañía"),
109
+ ("jr", "junior"),
110
+ ("ltd", "limitada"),
111
+ ]
112
+ ],
113
+ "fr": [
114
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
115
+ for x in [
116
+ ("mme", "madame"),
117
+ ("mr", "monsieur"),
118
+ ("dr", "docteur"),
119
+ ("st", "saint"),
120
+ ("co", "compagnie"),
121
+ ("jr", "junior"),
122
+ ("ltd", "limitée"),
123
+ ]
124
+ ],
125
+ "de": [
126
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
127
+ for x in [
128
+ ("fr", "frau"),
129
+ ("dr", "doktor"),
130
+ ("st", "sankt"),
131
+ ("co", "firma"),
132
+ ("jr", "junior"),
133
+ ]
134
+ ],
135
+ "pt": [
136
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
137
+ for x in [
138
+ ("sra", "senhora"),
139
+ ("sr", "senhor"),
140
+ ("dr", "doutor"),
141
+ ("dra", "doutora"),
142
+ ("st", "santo"),
143
+ ("co", "companhia"),
144
+ ("jr", "júnior"),
145
+ ("ltd", "limitada"),
146
+ ]
147
+ ],
148
+ "it": [
149
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
150
+ for x in [
151
+ # ("sig.ra", "signora"),
152
+ ("sig", "signore"),
153
+ ("dr", "dottore"),
154
+ ("st", "santo"),
155
+ ("co", "compagnia"),
156
+ ("jr", "junior"),
157
+ ("ltd", "limitata"),
158
+ ]
159
+ ],
160
+ "pl": [
161
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
162
+ for x in [
163
+ ("p", "pani"),
164
+ ("m", "pan"),
165
+ ("dr", "doktor"),
166
+ ("sw", "święty"),
167
+ ("jr", "junior"),
168
+ ]
169
+ ],
170
+ "ar": [
171
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
172
+ for x in [
173
+ # There are not many common abbreviations in Arabic as in English.
174
+ ]
175
+ ],
176
+ "zh": [
177
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
178
+ for x in [
179
+ # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
180
+ ]
181
+ ],
182
+ "cs": [
183
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
184
+ for x in [
185
+ ("dr", "doktor"), # doctor
186
+ ("ing", "inženýr"), # engineer
187
+ ("p", "pan"), # Could also map to pani for woman but no easy way to do it
188
+ # Other abbreviations would be specialized and not as common.
189
+ ]
190
+ ],
191
+ "ru": [
192
+ (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
193
+ for x in [
194
+ ("г-жа", "госпожа"), # Mrs.
195
+ ("г-н", "господин"), # Mr.
196
+ ("д-р", "доктор"), # doctor
197
+ # Other abbreviations are less common or specialized.
198
+ ]
199
+ ],
200
+ "nl": [
201
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
202
+ for x in [
203
+ ("dhr", "de heer"), # Mr.
204
+ ("mevr", "mevrouw"), # Mrs.
205
+ ("dr", "dokter"), # doctor
206
+ ("jhr", "jonkheer"), # young lord or nobleman
207
+ # Dutch uses more abbreviations, but these are the most common ones.
208
+ ]
209
+ ],
210
+ "tr": [
211
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
212
+ for x in [
213
+ ("b", "bay"), # Mr.
214
+ ("byk", "büyük"), # büyük
215
+ ("dr", "doktor"), # doctor
216
+ # Add other Turkish abbreviations here if needed.
217
+ ]
218
+ ],
219
+ "hu": [
220
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
221
+ for x in [
222
+ ("dr", "doktor"), # doctor
223
+ ("b", "bácsi"), # Mr.
224
+ ("nőv", "nővér"), # nurse
225
+ # Add other Hungarian abbreviations here if needed.
226
+ ]
227
+ ],
228
+ "ko": [
229
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
230
+ for x in [
231
+ # Korean doesn't typically use abbreviations in the same way as Latin-based scripts.
232
+ ]
233
+ ],
234
+ }
235
+
236
+
237
+ def expand_abbreviations_multilingual(text, lang="en"):
238
+ for regex, replacement in _abbreviations[lang]:
239
+ text = re.sub(regex, replacement, text)
240
+ return text
241
+
242
+
243
+ _symbols_multilingual = {
244
+ "en": [
245
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
246
+ for x in [
247
+ ("&", " and "),
248
+ ("@", " at "),
249
+ ("%", " percent "),
250
+ ("#", " hash "),
251
+ ("$", " dollar "),
252
+ ("£", " pound "),
253
+ ("°", " degree "),
254
+ ]
255
+ ],
256
+ "es": [
257
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
258
+ for x in [
259
+ ("&", " y "),
260
+ ("@", " arroba "),
261
+ ("%", " por ciento "),
262
+ ("#", " numeral "),
263
+ ("$", " dolar "),
264
+ ("£", " libra "),
265
+ ("°", " grados "),
266
+ ]
267
+ ],
268
+ "fr": [
269
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
270
+ for x in [
271
+ ("&", " et "),
272
+ ("@", " arobase "),
273
+ ("%", " pour cent "),
274
+ ("#", " dièse "),
275
+ ("$", " dollar "),
276
+ ("£", " livre "),
277
+ ("°", " degrés "),
278
+ ]
279
+ ],
280
+ "de": [
281
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
282
+ for x in [
283
+ ("&", " und "),
284
+ ("@", " at "),
285
+ ("%", " prozent "),
286
+ ("#", " raute "),
287
+ ("$", " dollar "),
288
+ ("£", " pfund "),
289
+ ("°", " grad "),
290
+ ]
291
+ ],
292
+ "pt": [
293
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
294
+ for x in [
295
+ ("&", " e "),
296
+ ("@", " arroba "),
297
+ ("%", " por cento "),
298
+ ("#", " cardinal "),
299
+ ("$", " dólar "),
300
+ ("£", " libra "),
301
+ ("°", " graus "),
302
+ ]
303
+ ],
304
+ "it": [
305
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
306
+ for x in [
307
+ ("&", " e "),
308
+ ("@", " chiocciola "),
309
+ ("%", " per cento "),
310
+ ("#", " cancelletto "),
311
+ ("$", " dollaro "),
312
+ ("£", " sterlina "),
313
+ ("°", " gradi "),
314
+ ]
315
+ ],
316
+ "pl": [
317
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
318
+ for x in [
319
+ ("&", " i "),
320
+ ("@", " małpa "),
321
+ ("%", " procent "),
322
+ ("#", " krzyżyk "),
323
+ ("$", " dolar "),
324
+ ("£", " funt "),
325
+ ("°", " stopnie "),
326
+ ]
327
+ ],
328
+ "ar": [
329
+ # Arabic
330
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
331
+ for x in [
332
+ ("&", " و "),
333
+ ("@", " على "),
334
+ ("%", " في المئة "),
335
+ ("#", " رقم "),
336
+ ("$", " دولار "),
337
+ ("£", " جنيه "),
338
+ ("°", " درجة "),
339
+ ]
340
+ ],
341
+ "zh": [
342
+ # Chinese
343
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
344
+ for x in [
345
+ ("&", " 和 "),
346
+ ("@", " 在 "),
347
+ ("%", " 百分之 "),
348
+ ("#", " 号 "),
349
+ ("$", " 美元 "),
350
+ ("£", " 英镑 "),
351
+ ("°", " 度 "),
352
+ ]
353
+ ],
354
+ "cs": [
355
+ # Czech
356
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
357
+ for x in [
358
+ ("&", " a "),
359
+ ("@", " na "),
360
+ ("%", " procento "),
361
+ ("#", " křížek "),
362
+ ("$", " dolar "),
363
+ ("£", " libra "),
364
+ ("°", " stupně "),
365
+ ]
366
+ ],
367
+ "ru": [
368
+ # Russian
369
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
370
+ for x in [
371
+ ("&", " и "),
372
+ ("@", " собака "),
373
+ ("%", " процентов "),
374
+ ("#", " номер "),
375
+ ("$", " доллар "),
376
+ ("£", " фунт "),
377
+ ("°", " градус "),
378
+ ]
379
+ ],
380
+ "nl": [
381
+ # Dutch
382
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
383
+ for x in [
384
+ ("&", " en "),
385
+ ("@", " bij "),
386
+ ("%", " procent "),
387
+ ("#", " hekje "),
388
+ ("$", " dollar "),
389
+ ("£", " pond "),
390
+ ("°", " graden "),
391
+ ]
392
+ ],
393
+ "tr": [
394
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
395
+ for x in [
396
+ ("&", " ve "),
397
+ ("@", " at "),
398
+ ("%", " yüzde "),
399
+ ("#", " diyez "),
400
+ ("$", " dolar "),
401
+ ("£", " sterlin "),
402
+ ("°", " derece "),
403
+ ]
404
+ ],
405
+ "hu": [
406
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
407
+ for x in [
408
+ ("&", " és "),
409
+ ("@", " kukac "),
410
+ ("%", " százalék "),
411
+ ("#", " kettőskereszt "),
412
+ ("$", " dollár "),
413
+ ("£", " font "),
414
+ ("°", " fok "),
415
+ ]
416
+ ],
417
+ "ko": [
418
+ # Korean
419
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
420
+ for x in [
421
+ ("&", " 그리고 "),
422
+ ("@", " 에 "),
423
+ ("%", " 퍼센트 "),
424
+ ("#", " 번호 "),
425
+ ("$", " 달러 "),
426
+ ("£", " 파운드 "),
427
+ ("°", " 도 "),
428
+ ]
429
+ ],
430
+ }
431
+
432
+
433
+ def expand_symbols_multilingual(text, lang="en"):
434
+ for regex, replacement in _symbols_multilingual[lang]:
435
+ text = re.sub(regex, replacement, text)
436
+ text = text.replace(" ", " ") # Ensure there are no double spaces
437
+ return text.strip()
438
+
439
+
440
+ _ordinal_re = {
441
+ "en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
442
+ "es": re.compile(r"([0-9]+)(º|ª|er|o|a|os|as)"),
443
+ "fr": re.compile(r"([0-9]+)(º|ª|er|re|e|ème)"),
444
+ "de": re.compile(r"([0-9]+)(st|nd|rd|th|º|ª|\.(?=\s|$))"),
445
+ "pt": re.compile(r"([0-9]+)(º|ª|o|a|os|as)"),
446
+ "it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"),
447
+ "pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"),
448
+ "ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"),
449
+ "cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals.
450
+ "ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
451
+ "nl": re.compile(r"([0-9]+)(de|ste|e)"),
452
+ "tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
453
+ "hu": re.compile(r"([0-9]+)(\.|adik|edik|odik|edik|ödik|ödike|ik)"),
454
+ "ko": re.compile(r"([0-9]+)(번째|번|차|째)"),
455
+ }
456
+ _number_re = re.compile(r"[0-9]+")
457
+ _currency_re = {
458
+ "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
459
+ "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
460
+ "EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
461
+ }
462
+
463
+ _comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
464
+ _dot_number_re = re.compile(r"\b\d{1,3}(\.\d{3})*(,\d+)?\b")
465
+ _decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
466
+
467
+
468
+ def _remove_commas(m):
469
+ text = m.group(0)
470
+ if "," in text:
471
+ text = text.replace(",", "")
472
+ return text
473
+
474
+
475
+ def _remove_dots(m):
476
+ text = m.group(0)
477
+ if "." in text:
478
+ text = text.replace(".", "")
479
+ return text
480
+
481
+
482
+ def _expand_decimal_point(m, lang="en"):
483
+ amount = m.group(1).replace(",", ".")
484
+ return num2words(float(amount), lang=lang if lang != "cs" else "cz")
485
+
486
+
487
+ def _expand_currency(m, lang="en", currency="USD"):
488
+ amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
489
+ full_amount = num2words(amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz")
490
+
491
+ and_equivalents = {
492
+ "en": ", ",
493
+ "es": " con ",
494
+ "fr": " et ",
495
+ "de": " und ",
496
+ "pt": " e ",
497
+ "it": " e ",
498
+ "pl": ", ",
499
+ "cs": ", ",
500
+ "ru": ", ",
501
+ "nl": ", ",
502
+ "ar": ", ",
503
+ "tr": ", ",
504
+ "hu": ", ",
505
+ "ko": ", ",
506
+ }
507
+
508
+ if amount.is_integer():
509
+ last_and = full_amount.rfind(and_equivalents[lang])
510
+ if last_and != -1:
511
+ full_amount = full_amount[:last_and]
512
+
513
+ return full_amount
514
+
515
+
516
+ def _expand_ordinal(m, lang="en"):
517
+ return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
518
+
519
+
520
+ def _expand_number(m, lang="en"):
521
+ return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
522
+
523
+
524
+ def expand_numbers_multilingual(text, lang="en"):
525
+ if lang == "zh":
526
+ text = zh_num2words()(text)
527
+ else:
528
+ if lang in ["en", "ru"]:
529
+ text = re.sub(_comma_number_re, _remove_commas, text)
530
+ else:
531
+ text = re.sub(_dot_number_re, _remove_dots, text)
532
+ try:
533
+ text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
534
+ text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
535
+ text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
536
+ except:
537
+ pass
538
+ if lang != "tr":
539
+ text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
540
+ text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
541
+ text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
542
+ return text
543
+
544
+
545
+ def lowercase(text):
546
+ return text.lower()
547
+
548
+
549
+ def collapse_whitespace(text):
550
+ return re.sub(_whitespace_re, " ", text)
551
+
552
+
553
+ def multilingual_cleaners(text, lang):
554
+ text = text.replace('"', "")
555
+ if lang == "tr":
556
+ text = text.replace("İ", "i")
557
+ text = text.replace("Ö", "ö")
558
+ text = text.replace("Ü", "ü")
559
+ text = lowercase(text)
560
+ try:
561
+ text = expand_numbers_multilingual(text, lang)
562
+ except:
563
+ pass
564
+ try:
565
+ text = expand_abbreviations_multilingual(text, lang)
566
+ except:
567
+ pass
568
+ try:
569
+ text = expand_symbols_multilingual(text, lang=lang)
570
+ except:
571
+ pass
572
+ text = collapse_whitespace(text)
573
+ return text
574
+
575
+
576
+ def basic_cleaners(text):
577
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
578
+ text = lowercase(text)
579
+ text = collapse_whitespace(text)
580
+ return text
581
+
582
+
583
+ def chinese_transliterate(text):
584
+ return "".join(
585
+ [p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]
586
+ )
587
+
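For example (illustrative; the exact string depends on pypinyin's TONE3 output):

    print(chinese_transliterate("你好"))   # -> "ni3hao3"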
588
+
589
+ def japanese_cleaners(text, katsu):
590
+ text = katsu.romaji(text)
591
+ text = lowercase(text)
592
+ return text
593
+
594
+
595
+ def korean_transliterate(text):
596
+ r = Transliter(academic)
597
+ return r.translit(text)
598
+
599
+
600
+ DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "vocab.json")
601
+
602
+
603
+ class VoiceBpeTokenizer:
604
+ def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
605
+ self.tokenizer = None
606
+ if vocab_file is not None:
607
+ self.tokenizer = Tokenizer.from_file(vocab_file)
608
+ self.char_limits = {
609
+ "en": 10000,
610
+ "de": 253,
611
+ "fr": 273,
612
+ "es": 239,
613
+ "it": 213,
614
+ "pt": 203,
615
+ "pl": 224,
616
+ "zh": 82,
617
+ "ar": 166,
618
+ "cs": 186,
619
+ "ru": 182,
620
+ "nl": 251,
621
+ "tr": 226,
622
+ "ja": 71,
623
+ "hu": 224,
624
+ "ko": 95,
625
+ }
626
+
627
+ @cached_property
628
+ def katsu(self):
629
+ import cutlet
630
+
631
+ return cutlet.Cutlet()
632
+
633
+ def check_input_length(self, txt, lang):
634
+ lang = lang.split("-")[0] # remove the region
635
+ limit = self.char_limits.get(lang, 250)
636
+ # if len(txt) > limit:
637
+ # print(
638
+ # f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."
639
+ # )
640
+
641
+ def preprocess_text(self, txt, lang):
642
+ if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
643
+ txt = multilingual_cleaners(txt, lang)
644
+ if lang == "zh":
645
+ txt = chinese_transliterate(txt)
646
+ if lang == "ko":
647
+ txt = korean_transliterate(txt)
648
+ elif lang == "ja":
649
+ txt = japanese_cleaners(txt, self.katsu)
650
+ elif lang == "hi":
651
+ # @manmay will implement this
652
+ txt = basic_cleaners(txt)
653
+ else:
654
+ raise NotImplementedError(f"Language '{lang}' is not supported.")
655
+ return txt
656
+
657
+ def encode(self, txt, lang):
658
+ lang = lang.split("-")[0] # remove the region
659
+ self.check_input_length(txt, lang)
660
+ txt = self.preprocess_text(txt, lang)
661
+ lang = "zh-cn" if lang == "zh" else lang
662
+ txt = f"[{lang}]{txt}"
663
+ txt = txt.replace(" ", "[SPACE]")
664
+ return self.tokenizer.encode(txt).ids
665
+
666
+ def decode(self, seq, skip_special_tokens=False):
667
+ if isinstance(seq, torch.Tensor):
668
+ seq = seq.cpu().numpy()
669
+ txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "")
670
+ txt = txt.replace("[SPACE]", " ")
671
+ txt = txt.replace("[STOP]", "")
672
+ # txt = txt.replace("[UNK]", "")
673
+ return txt
674
+
675
+
676
+ #copy from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3936
677
+ def batch_decode(
678
+ self,
679
+ sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
680
+ skip_special_tokens: bool = False,
681
+ ) -> List[str]:
682
+ """
683
+ Convert a list of lists of token ids into a list of strings by calling decode.
684
+
685
+ Args:
686
+ sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
687
+ List of tokenized input ids. Can be obtained using the `__call__` method.
688
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
689
+ Whether or not to remove special tokens in the decoding.
690
+ kwargs (additional keyword arguments, *optional*):
691
+ Will be passed to the underlying model specific decode method.
692
+
693
+ Returns:
694
+ `List[str]`: The list of decoded sentences.
695
+ """
696
+ return [
697
+ self.decode(seq)
698
+ for seq in sequences
699
+ ]
700
+
701
+ #https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/layers/xtts/trainer/dataset.py#L202
702
+ # def pad(self):
703
+
704
+ def __len__(self):
705
+ return self.tokenizer.get_vocab_size()
706
+
707
+ def get_number_tokens(self):
708
+ return max(self.tokenizer.get_vocab().values()) + 1
709
+
710
+
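A hypothetical round-trip sketch (illustrative only; assumes the bundled vocab.json and the per-language dependencies are available):

    tok = VoiceBpeTokenizer()                       # loads vocab.json next to this file
    ids = tok.encode("Twinkle twinkle little star", lang="en")
    print(tok.decode(ids))                          # spaces restored from [SPACE] tokens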
711
+ def test_expand_numbers_multilingual():
712
+ test_cases = [
713
+ # English
714
+ ("In 12.5 seconds.", "In twelve point five seconds.", "en"),
715
+ ("There were 50 soldiers.", "There were fifty soldiers.", "en"),
716
+ ("This is a 1st test", "This is a first test", "en"),
717
+ ("That will be $20 sir.", "That will be twenty dollars sir.", "en"),
718
+ ("That will be 20€ sir.", "That will be twenty euro sir.", "en"),
719
+ ("That will be 20.15€ sir.", "That will be twenty euro, fifteen cents sir.", "en"),
720
+ ("That's 100,000.5.", "That's one hundred thousand point five.", "en"),
721
+ # French
722
+ ("En 12,5 secondes.", "En douze virgule cinq secondes.", "fr"),
723
+ ("Il y avait 50 soldats.", "Il y avait cinquante soldats.", "fr"),
724
+ ("Ceci est un 1er test", "Ceci est un premier test", "fr"),
725
+ ("Cela vous fera $20 monsieur.", "Cela vous fera vingt dollars monsieur.", "fr"),
726
+ ("Cela vous fera 20€ monsieur.", "Cela vous fera vingt euros monsieur.", "fr"),
727
+ ("Cela vous fera 20,15€ monsieur.", "Cela vous fera vingt euros et quinze centimes monsieur.", "fr"),
728
+ ("Ce sera 100.000,5.", "Ce sera cent mille virgule cinq.", "fr"),
729
+ # German
730
+ ("In 12,5 Sekunden.", "In zwölf Komma fünf Sekunden.", "de"),
731
+ ("Es gab 50 Soldaten.", "Es gab fünfzig Soldaten.", "de"),
732
+ ("Dies ist ein 1. Test", "Dies ist ein erste Test", "de"), # Issue with gender
733
+ ("Das macht $20 Herr.", "Das macht zwanzig Dollar Herr.", "de"),
734
+ ("Das macht 20€ Herr.", "Das macht zwanzig Euro Herr.", "de"),
735
+ ("Das macht 20,15€ Herr.", "Das macht zwanzig Euro und fünfzehn Cent Herr.", "de"),
736
+ # Spanish
737
+ ("En 12,5 segundos.", "En doce punto cinco segundos.", "es"),
738
+ ("Había 50 soldados.", "Había cincuenta soldados.", "es"),
739
+ ("Este es un 1er test", "Este es un primero test", "es"),
740
+ ("Eso le costará $20 señor.", "Eso le costará veinte dólares señor.", "es"),
741
+ ("Eso le costará 20€ señor.", "Eso le costará veinte euros señor.", "es"),
742
+ ("Eso le costará 20,15€ señor.", "Eso le costará veinte euros con quince céntimos señor.", "es"),
743
+ # Italian
744
+ ("In 12,5 secondi.", "In dodici virgola cinque secondi.", "it"),
745
+ ("C'erano 50 soldati.", "C'erano cinquanta soldati.", "it"),
746
+ ("Questo è un 1° test", "Questo è un primo test", "it"),
747
+ ("Ti costerà $20 signore.", "Ti costerà venti dollari signore.", "it"),
748
+ ("Ti costerà 20€ signore.", "Ti costerà venti euro signore.", "it"),
749
+ ("Ti costerà 20,15€ signore.", "Ti costerà venti euro e quindici centesimi signore.", "it"),
750
+ # Portuguese
751
+ ("Em 12,5 segundos.", "Em doze vírgula cinco segundos.", "pt"),
752
+ ("Havia 50 soldados.", "Havia cinquenta soldados.", "pt"),
753
+ ("Este é um 1º teste", "Este é um primeiro teste", "pt"),
754
+ ("Isso custará $20 senhor.", "Isso custará vinte dólares senhor.", "pt"),
755
+ ("Isso custará 20€ senhor.", "Isso custará vinte euros senhor.", "pt"),
756
+ (
757
+ "Isso custará 20,15€ senhor.",
758
+ "Isso custará vinte euros e quinze cêntimos senhor.",
759
+ "pt",
760
+ ), # "cêntimos" should be "centavos" num2words issue
761
+ # Polish
762
+ ("W 12,5 sekundy.", "W dwanaście przecinek pięć sekundy.", "pl"),
763
+ ("Było 50 żołnierzy.", "Było pięćdziesiąt żołnierzy.", "pl"),
764
+ ("To będzie kosztować 20€ panie.", "To będzie kosztować dwadzieścia euro panie.", "pl"),
765
+ ("To będzie kosztować 20,15€ panie.", "To będzie kosztować dwadzieścia euro, piętnaście centów panie.", "pl"),
766
+ # Arabic
767
+ ("في الـ 12,5 ثانية.", "في الـ اثنا عشر , خمسون ثانية.", "ar"),
768
+ ("كان هناك 50 جنديًا.", "كان هناك خمسون جنديًا.", "ar"),
769
+ # ("ستكون النتيجة $20 يا سيد.", 'ستكون النتيجة عشرون دولار يا سيد.', 'ar'), # $ and € are mising from num2words
770
+ # ("ستكون النتيجة 20€ يا سيد.", 'ستكون النتيجة عشرون يورو يا سيد.', 'ar'),
771
+ # Czech
772
+ ("Za 12,5 vteřiny.", "Za dvanáct celá pět vteřiny.", "cs"),
773
+ ("Bylo tam 50 vojáků.", "Bylo tam padesát vojáků.", "cs"),
774
+ ("To bude stát 20€ pane.", "To bude stát dvacet euro pane.", "cs"),
775
+ ("To bude 20.15€ pane.", "To bude dvacet euro, patnáct centů pane.", "cs"),
776
+ # Russian
777
+ ("Через 12.5 секунды.", "Через двенадцать запятая пять секунды.", "ru"),
778
+ ("Там было 50 солдат.", "Там было пятьдесят солдат.", "ru"),
779
+ ("Это будет 20.15€ сэр.", "Это будет двадцать евро, пятнадцать центов сэр.", "ru"),
780
+ ("Это будет стоить 20€ господин.", "Это будет стоить двадцать евро господин.", "ru"),
781
+ # Dutch
782
+ ("In 12,5 seconden.", "In twaalf komma vijf seconden.", "nl"),
783
+ ("Er waren 50 soldaten.", "Er waren vijftig soldaten.", "nl"),
784
+ ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
785
+ ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
786
+ # Chinese (Simplified)
787
+ ("在12.5秒内", "在十二点五秒内", "zh"),
788
+ ("有50名士兵", "有五十名士兵", "zh"),
789
+ # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
790
+ # ("那将是20€先生", '那将是二十欧元先生', 'zh'),
791
+ # Turkish
792
+ # ("12,5 saniye içinde.", 'On iki virgül beş saniye içinde.', 'tr'), # decimal doesn't work for TR
793
+ ("50 asker vardı.", "elli asker vardı.", "tr"),
794
+ ("Bu 1. test", "Bu birinci test", "tr"),
795
+ # ("Bu 100.000,5.", 'Bu yüz bin virgül beş.', 'tr'),
796
+ # Hungarian
797
+ ("12,5 másodperc alatt.", "tizenkettő egész öt tized másodperc alatt.", "hu"),
798
+ ("50 katona volt.", "ötven katona volt.", "hu"),
799
+ ("Ez az 1. teszt", "Ez az első teszt", "hu"),
800
+ # Korean
801
+ ("12.5 초 안에.", "십이 점 다섯 초 안에.", "ko"),
802
+ ("50 명의 병사가 있었다.", "오십 명의 병사가 있었다.", "ko"),
803
+ ("이것은 1 번째 테스트입니다", "이것은 첫 번째 테스트입니다", "ko"),
804
+ ]
805
+ for a, b, lang in test_cases:
806
+ out = expand_numbers_multilingual(a, lang=lang)
807
+ assert out == b, f"'{out}' vs '{b}'"
808
+
809
+
810
+ def test_abbreviations_multilingual():
811
+ test_cases = [
812
+ # English
813
+ ("Hello Mr. Smith.", "Hello mister Smith.", "en"),
814
+ ("Dr. Jones is here.", "doctor Jones is here.", "en"),
815
+ # Spanish
816
+ ("Hola Sr. Garcia.", "Hola señor Garcia.", "es"),
817
+ ("La Dra. Martinez es muy buena.", "La doctora Martinez es muy buena.", "es"),
818
+ # French
819
+ ("Bonjour Mr. Dupond.", "Bonjour monsieur Dupond.", "fr"),
820
+ ("Mme. Moreau est absente aujourd'hui.", "madame Moreau est absente aujourd'hui.", "fr"),
821
+ # German
822
+ ("Frau Dr. Müller ist sehr klug.", "Frau doktor Müller ist sehr klug.", "de"),
823
+ # Portuguese
824
+ ("Olá Sr. Silva.", "Olá senhor Silva.", "pt"),
825
+ ("Dra. Costa, você está disponível?", "doutora Costa, você está disponível?", "pt"),
826
+ # Italian
827
+ ("Buongiorno, Sig. Rossi.", "Buongiorno, signore Rossi.", "it"),
828
+ # ("Sig.ra Bianchi, posso aiutarti?", 'signora Bianchi, posso aiutarti?', 'it'), # Issue with matching that pattern
829
+ # Polish
830
+ ("Dzień dobry, P. Kowalski.", "Dzień dobry, pani Kowalski.", "pl"),
831
+ ("M. Nowak, czy mogę zadać pytanie?", "pan Nowak, czy mogę zadać pytanie?", "pl"),
832
+ # Czech
833
+ ("P. Novák", "pan Novák", "cs"),
834
+ ("Dr. Vojtěch", "doktor Vojtěch", "cs"),
835
+ # Dutch
836
+ ("Dhr. Jansen", "de heer Jansen", "nl"),
837
+ ("Mevr. de Vries", "mevrouw de Vries", "nl"),
838
+ # Russian
839
+ ("Здравствуйте Г-�� Иванов.", "Здравствуйте господин Иванов.", "ru"),
840
+ ("Д-р Смирнов здесь, чтобы увидеть вас.", "доктор Смирнов здесь, чтобы увидеть вас.", "ru"),
841
+ # Turkish
842
+ ("Merhaba B. Yılmaz.", "Merhaba bay Yılmaz.", "tr"),
843
+ ("Dr. Ayşe burada.", "doktor Ayşe burada.", "tr"),
844
+ # Hungarian
845
+ ("Dr. Szabó itt van.", "doktor Szabó itt van.", "hu"),
846
+ ]
847
+
848
+ for a, b, lang in test_cases:
849
+ out = expand_abbreviations_multilingual(a, lang=lang)
850
+ assert out == b, f"'{out}' vs '{b}'"
851
+
852
+
853
+ def test_symbols_multilingual():
854
+ test_cases = [
855
+ ("I have 14% battery", "I have 14 percent battery", "en"),
856
+ ("Te veo @ la fiesta", "Te veo arroba la fiesta", "es"),
857
+ ("J'ai 14° de fièvre", "J'ai 14 degrés de fièvre", "fr"),
858
+ ("Die Rechnung beträgt £ 20", "Die Rechnung beträgt pfund 20", "de"),
859
+ ("O meu email é ana&joao@gmail.com", "O meu email é ana e joao arroba gmail.com", "pt"),
860
+ ("linguaggio di programmazione C#", "linguaggio di programmazione C cancelletto", "it"),
861
+ ("Moja temperatura to 36.6°", "Moja temperatura to 36.6 stopnie", "pl"),
862
+ ("Mám 14% baterie", "Mám 14 procento baterie", "cs"),
863
+ ("Těším se na tebe @ party", "Těším se na tebe na party", "cs"),
864
+ ("У меня 14% заряда", "У меня 14 процентов заряда", "ru"),
865
+ ("Я буду @ дома", "Я буду собака дома", "ru"),
866
+ ("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
867
+ ("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
868
+ ("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
869
+ ("我的电量为 14%", "我的电量为 14 百分之", "zh"),
870
+ ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
871
+ ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
872
+ ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),
873
+ ]
874
+
875
+ for a, b, lang in test_cases:
876
+ out = expand_symbols_multilingual(a, lang=lang)
877
+ assert out == b, f"'{out}' vs '{b}'"
878
+
879
+
880
+ if __name__ == "__main__":
881
+ test_expand_numbers_multilingual()
882
+ test_abbreviations_multilingual()
883
+ test_symbols_multilingual()
models/lyrics_utils/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
models/lyrics_utils/zh_num2words.py ADDED
@@ -0,0 +1,1209 @@
1
+ # Authors:
2
+ # 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
3
+ # 2019.9 - 2022 Jiayu DU
4
+ #copy from https://github.com/coqui-ai/TTS/blob/dbf1a08a0d4e47fdad6172e433eeb34bc6b13b4e/TTS/tts/layers/xtts/zh_num2words.py
5
+ import argparse
6
+ import csv
7
+ import os
8
+ import re
9
+ import string
10
+ import sys
11
+
12
+ # fmt: off
13
+
14
+ # ================================================================================ #
15
+ # basic constant
16
+ # ================================================================================ #
17
+ CHINESE_DIGIS = "零一二三四五六七八九"
18
+ BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
19
+ BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
20
+ SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
21
+ SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
22
+ LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
23
+ LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
24
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
25
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
26
+
27
+ ZERO_ALT = "〇"
28
+ ONE_ALT = "幺"
29
+ TWO_ALTS = ["两", "兩"]
30
+
31
+ POSITIVE = ["正", "正"]
32
+ NEGATIVE = ["负", "負"]
33
+ POINT = ["点", "點"]
34
+ # PLUS = [u'加', u'加']
35
+ # SIL = [u'杠', u'槓']
36
+
37
+ FILLER_CHARS = ["呃", "啊"]
38
+
39
+ ER_WHITELIST = (
40
+ "(儿女|儿子|儿孙|女儿|儿媳|妻儿|"
41
+ "胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|"
42
+ "儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|"
43
+ "佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)"
44
+ )
45
+ ER_WHITELIST_PATTERN = re.compile(ER_WHITELIST)
46
+
47
+ # 中文数字系统类型
48
+ NUMBERING_TYPES = ["low", "mid", "high"]
49
+
50
+ CURRENCY_NAMES = "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|" "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
51
+ CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
52
+ COM_QUANTIFIERS = (
53
+ "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
54
+ "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
55
+ "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
56
+ "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
57
+ "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
58
+ "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)"
59
+ )
60
+
61
+
62
+ # Punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git)
63
+ CN_PUNCS_STOP = "!?。。"
64
+ CN_PUNCS_NONSTOP = ""#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏·〈〉-"
65
+ CN_PUNCS = CN_PUNCS_STOP + CN_PUNCS_NONSTOP
66
+
67
+ PUNCS = CN_PUNCS + string.punctuation
68
+ PUNCS_TRANSFORM = str.maketrans(PUNCS, "," * len(PUNCS), "") # replace puncs with English comma
69
+
70
+
71
+ # https://zh.wikipedia.org/wiki/全行和半行
72
+ QJ2BJ = {
73
+ " ": " ",
74
+ "!": "!",
75
+ """: '"',
76
+ "#": "#",
77
+ "$": "$",
78
+ "%": "%",
79
+ "&": "&",
80
+ "'": "'",
81
+ "(": "(",
82
+ ")": ")",
83
+ "*": "*",
84
+ "+": "+",
85
+ ",": ",",
86
+ "-": "-",
87
+ ".": ".",
88
+ "/": "/",
89
+ "0": "0",
90
+ "1": "1",
91
+ "2": "2",
92
+ "3": "3",
93
+ "4": "4",
94
+ "5": "5",
95
+ "6": "6",
96
+ "7": "7",
97
+ "8": "8",
98
+ "9": "9",
99
+ ":": ":",
100
+ ";": ";",
101
+ "<": "<",
102
+ "=": "=",
103
+ ">": ">",
104
+ "?": "?",
105
+ "@": "@",
106
+ "A": "A",
107
+ "B": "B",
108
+ "C": "C",
109
+ "D": "D",
110
+ "E": "E",
111
+ "F": "F",
112
+ "G": "G",
113
+ "H": "H",
114
+ "I": "I",
115
+ "J": "J",
116
+ "K": "K",
117
+ "L": "L",
118
+ "M": "M",
119
+ "N": "N",
120
+ "O": "O",
121
+ "P": "P",
122
+ "Q": "Q",
123
+ "R": "R",
124
+ "S": "S",
125
+ "T": "T",
126
+ "U": "U",
127
+ "V": "V",
128
+ "W": "W",
129
+ "X": "X",
130
+ "Y": "Y",
131
+ "Z": "Z",
132
+ "[": "[",
133
+ "\": "\\",
134
+ "]": "]",
135
+ "^": "^",
136
+ "_": "_",
137
+ "`": "`",
138
+ "a": "a",
139
+ "b": "b",
140
+ "c": "c",
141
+ "d": "d",
142
+ "e": "e",
143
+ "f": "f",
144
+ "g": "g",
145
+ "h": "h",
146
+ "i": "i",
147
+ "j": "j",
148
+ "k": "k",
149
+ "l": "l",
150
+ "m": "m",
151
+ "n": "n",
152
+ "o": "o",
153
+ "p": "p",
154
+ "q": "q",
155
+ "r": "r",
156
+ "s": "s",
157
+ "t": "t",
158
+ "u": "u",
159
+ "v": "v",
160
+ "w": "w",
161
+ "x": "x",
162
+ "y": "y",
163
+ "z": "z",
164
+ "{": "{",
165
+ "|": "|",
166
+ "}": "}",
167
+ "~": "~",
168
+ }
169
+ QJ2BJ_TRANSFORM = str.maketrans("".join(QJ2BJ.keys()), "".join(QJ2BJ.values()), "")
170
+
171
+
172
+ # 2013 China National Standard: https://zh.wikipedia.org/wiki/通用规范汉字表, raw resources:
173
+ # https://github.com/mozillazg/pinyin-data/blob/master/kMandarin_8105.txt with 8105 chinese chars in total
174
+ CN_CHARS_COMMON = (
175
+ "一丁七万丈三上下不与丏丐丑专且丕世丘丙业丛东丝丞丢两严丧个丫中丰串临丸丹为主丽举"
176
+ "乂乃久么义之乌乍乎乏乐乒乓乔乖乘乙乜九乞也习乡书乩买乱乳乸乾了予争事二亍于亏云互"
177
+ "亓五井亘亚些亟亡亢交亥亦产亨亩享京亭亮亲亳亵亶亸亹人亿什仁仂仃仄仅仆仇仉今介仍从"
178
+ "仑仓仔仕他仗付仙仝仞仟仡代令以仨仪仫们仰仲仳仵件价任份仿企伈伉伊伋伍伎伏伐休众优"
179
+ "伙会伛伞伟传伢伣伤伥伦伧伪伫伭伯估伲伴伶伸伺似伽伾佁佃但位低住佐佑体何佖佗佘余佚"
180
+ "佛作佝佞佟你佣佤佥佩佬佯佰佳佴佶佸佺佻佼佽佾使侁侂侃侄侈侉例侍侏侑侔侗侘供依侠侣"
181
+ "侥侦侧侨侩侪侬侮侯侴侵侹便促俄俅俊俍俎俏俐俑俗俘俙俚俜保俞俟信俣俦俨俩俪俫俭修俯"
182
+ "俱俳俵俶俸俺俾倌倍倏倒倓倔倕倘候倚倜倞借倡倥倦倧倨倩倪倬倭倮倴债倻值倾偁偃假偈偌"
183
+ "偎偏偓偕做停偡健偬偭偰偲偶偷偻偾偿傀傃傅傈傉傍傒傕傣傥傧储傩催傲傺傻僇僎像僔僖僚"
184
+ "僦僧僬僭僮僰僳僵僻儆儇儋儒儡儦儳儴儿兀允元兄充兆先光克免兑兔兕兖党兜兢入全八公六"
185
+ "兮兰共关兴兵其具典兹养兼兽冀冁内冈冉册再冏冒冔冕冗写军农冠冢冤冥冬冮冯冰冱冲决况"
186
+ "冶冷冻冼冽净凄准凇凉凋凌减凑凓凘凛凝几凡凤凫凭凯凰凳凶凸凹出击凼函凿刀刁刃分切刈"
187
+ "刊刍刎刑划刖列刘则刚创初删判刨利别刬刭刮到刳制刷券刹刺刻刽刿剀剁剂剃剅削剋剌前剐"
188
+ "剑剔剕剖剜剞剟剡剥剧剩剪副割剽剿劁劂劄劈劐劓力劝办功加务劢劣动助努劫劬劭励劲劳劼"
189
+ "劾势勃勇勉勋勍勐勒勔勖勘勚募勠勤勰勺勾勿匀包匆匈匍匏匐匕化北匙匜匝匠匡匣匦匪匮匹"
190
+ "区医匼匾匿十千卅升午卉半华协卑卒卓单卖南博卜卞卟占卡卢卣卤卦卧卫卬卮卯印危即却卵"
191
+ "卷卸卺卿厂厄厅历厉压厌厍厕厖厘厚厝原厢厣厥厦厨厩厮去厾县叁参叆叇又叉及友双反发叔"
192
+ "叕取受变叙叚叛叟叠口古句另叨叩只叫召叭叮可台叱史右叵叶号司叹叻叼叽吁吃各吆合吉吊"
193
+ "同名后吏吐向吒吓吕吖吗君吝吞吟吠吡吣否吧吨吩含听吭吮启吱吲吴吵吸吹吻吼吽吾呀呃呆"
194
+ "呇呈告呋呐呒呓呔呕呖呗员呙呛呜呢呣呤呦周呱呲味呵呶呷呸呻呼命咀咂咄咆咇咉咋和咍咎"
195
+ "咏咐咒咔咕咖咙咚咛咝咡咣咤咥咦咧咨咩咪咫咬咯咱咳咴咸咺咻咽咿哀品哂哃哄哆哇哈哉哌"
196
+ "响哎哏哐哑哒哓哔哕哗哙哚哝哞哟哢哥哦哧哨哩哪哭哮哱哲哳哺哼哽哿唁唆唇唉唏唐唑唔唛"
197
+ "唝唠唢唣唤唧唪唬售唯唰唱唳唵唷唼唾唿啁啃啄商啉啊啐啕啖啜啡啤啥啦啧啪啫啬啭啮啰啴"
198
+ "啵啶啷啸啻啼啾喀喁喂喃善喆喇喈喉喊喋喏喑喔喘喙喜喝喟喤喧喱喳喵喷喹喻喽喾嗄嗅嗉嗌"
199
+ "嗍嗐嗑嗒嗓嗔嗖嗜嗝嗞嗟嗡嗣嗤嗥嗦嗨嗪嗫嗬嗯嗲嗳嗵嗷嗽嗾嘀嘁嘈嘉嘌嘎嘏嘘嘚嘛嘞嘟嘡"
200
+ "嘣嘤嘧嘬嘭嘱嘲嘴嘶嘹嘻嘿噀噂噇噌噍噎噔噗噘噙噜噢噤器噩噪噫噬噱噶噻噼嚄嚅嚆嚎嚏嚓"
201
+ "嚚嚣嚭嚯嚷嚼囊囔囚四回囟因囡团囤囫园困囱围囵囷囹固国图囿圃圄圆圈圉圊圌圐圙圜土圢"
202
+ "圣在圩圪圫圬圭圮圯地圲圳圹场圻圾址坂均坉坊坋坌坍坎坏坐坑坒块坚坛坜坝坞坟坠坡坤坥"
203
+ "坦坨坩坪坫坬坭坯坰坳坷坻坼坽垂垃垄垆垈型垌垍垎垏垒垓垕垙垚垛垞垟垠垡垢垣垤垦垧垩"
204
+ "垫垭垮垯垱垲垴垵垸垺垾垿埂埃埆埇埋埌城埏埒埔埕埗埘埙埚埝域埠埤埪埫埭埯埴埵埸培基"
205
+ "埼埽堂堃堆堇堉堋堌堍堎堐堑堕堙堞堠堡堤堧堨堪堰堲堵堼堽堾塄塅塆塌塍塑塔塘塝塞塥填"
206
+ "塬塱塾墀墁境墅墈墉墐墒墓墕墘墙墚增墟墡墣墦墨墩墼壁壅壑壕壤士壬壮声壳壶壸壹处备复"
207
+ "夏夐夔夕外夙多夜够夤夥大天太夫夬夭央夯失头夷夸夹夺夼奁奂奄奇奈奉奋奎奏契奓奔奕奖"
208
+ "套奘奚奠奡奢奥奭女奴奶奸她好妁如妃妄妆妇妈妊妍妒妓妖妗妘妙妞妣妤妥妧妨妩妪妫妭妮"
209
+ "妯妲妹妻妾姆姈姊始姐姑姒姓委姗姘姚姜姝姞姣姤姥姨姬姮姱姶姹姻姽姿娀威娃娄娅娆娇娈"
210
+ "娉娌娑娓娘娜娟娠娣娥娩娱娲娴娵娶娼婀婆婉婊婌婍婕婘婚婞婠婢婤婧婪婫婳婴婵婶婷婺���"
211
+ "婼婿媂媄媆媒媓媖媚媛媞媪媭媱媲媳媵媸媾嫁嫂嫄嫉嫌嫒嫔嫕嫖嫘嫚嫜嫠嫡嫣嫦嫩嫪嫫嫭嫱"
212
+ "嫽嬉嬖嬗嬛嬥嬬嬴嬷嬿孀孅子孑孓孔孕孖字存孙孚孛孜孝孟孢季孤孥学孩孪孬孰孱孳孵孺孽"
213
+ "宁它宄宅宇守安宋完宏宓宕宗官宙定宛宜宝实宠审客宣室宥宦宧宪宫宬宰害宴宵家宸容宽宾"
214
+ "宿寁寂寄寅密寇富寐寒寓寝寞察寡寤寥寨寮寰寸对寺寻导寿封射将尉尊小少尔尕尖尘尚尜尝"
215
+ "尢尤尥尧尨尪尬就尴尸尹尺尻尼尽尾尿局屁层屃居屈屉届屋屎屏屐屑展屙属屠屡屣履屦屯山"
216
+ "屹屺屼屾屿岁岂岈岊岌岍岐岑岔岖岗岘岙岚岛岜岞岠岢岣岨岩岫岬岭岱岳岵岷岸岽岿峁峂峃"
217
+ "峄峋峒峗峘峙峛峡峣峤峥峦峧峨峪峭峰峱峻峿崀崁崂崃崄崆崇崌崎崒崔崖崚崛崞崟崡崤崦崧"
218
+ "崩崭崮崴崶崽崾崿嵁嵅嵇嵊嵋嵌嵎嵖嵘嵚嵛嵝嵩嵫嵬嵯嵲嵴嶂嶅嶍嶒嶓嶙嶝嶟嶦嶲嶷巅巇巉"
219
+ "巍川州巡巢工左巧巨巩巫差巯己已巳巴巷巽巾币市布帅帆师希帏帐帑帔帕帖帘帙帚帛帜帝帡"
220
+ "带帧帨席帮帱帷常帻帼帽幂幄幅幌幔幕幖幛幞幡幢幪干平年并幸幺幻幼幽广庄庆庇床庋序庐"
221
+ "庑库应底庖店庙庚府庞废庠庤庥度座庭庱庳庵庶康庸庹庼庾廆廉廊廋廑廒廓廖廙廛廨廪延廷"
222
+ "建廿开弁异弃弄弆弇弈弊弋式弑弓引弗弘弛弟张弢弥弦弧弨弩弭弯弱弶弸弹强弼彀归当录彖"
223
+ "彗彘彝彟形彤彦彧彩彪彬彭彰影彳彷役彻彼往征徂径待徇很徉徊律徐徒徕得徘徙徛徜御徨循"
224
+ "徭微徵德徼徽心必忆忉忌忍忏忐忑忒忖志忘忙忝忞忠忡忤忧忪快忭忮忱忳念忸忺忻忽忾忿怀"
225
+ "态怂怃怄怅怆怊怍怎怏怒怔怕怖怙怛怜思怠怡急怦性怨怩怪怫怯怵总怼怿恁恂恃恋恍恐恒恓"
226
+ "恔恕恙恚恝恢恣恤恧恨恩恪恫恬恭息恰恳恶恸恹恺恻恼恽恿悃悄悆悈悉悌悍悒悔悖悚悛悝悟"
227
+ "悠悢患悦您悫悬悭悯悰悱悲悴悸悻悼情惆惇惊惋惎惑惔惕惘惙惚惛惜惝惟惠惦惧惨惩惫惬惭"
228
+ "惮惯惰想惴惶惹惺愀愁愃愆愈愉愍愎意愐愔愕愚感愠愣愤愦愧愫愭愿慆慈慊慌慎慑慕慝慢慥"
229
+ "慧慨慬慭慰慵慷憋憎憔憕憙憧憨憩憬憭憷憺憾懂懈懊懋懑懒懔懦懵懿戆戈戊戋戌戍戎戏成我"
230
+ "戒戕或戗战戚戛戟戡戢戣戤戥截戬戭戮戳戴户戽戾房所扁扂扃扅扆扇扈扉扊手才扎扑扒打扔"
231
+ "托扛扞扣扦执扩扪扫扬扭扮扯扰扳扶批扺扼扽找承技抃抄抉把抑抒抓抔投抖抗折抚抛抟抠抡"
232
+ "抢护报抨披抬抱抵抹抻押抽抿拂拃拄担拆拇拈拉拊拌拍拎拐拒拓拔拖拗拘拙招拜拟拢拣拤拥"
233
+ "拦拧拨择括拭拮拯拱拳拴拶拷拼拽拾拿持挂指挈按挎挑挓挖挚挛挝挞挟挠挡挣挤挥挦挨挪挫"
234
+ "振挲挹挺挽捂捃捅捆捉捋捌捍捎捏捐捕捞损捡换捣捧捩捭据捯捶捷捺捻捽掀掂掇授掉掊掌掎"
235
+ "掏掐排掖掘掞掠探掣接控推掩措掬掭掮掰掳掴掷掸掺掼掾揄揆揉揍描提插揕揖揠握揣揩揪揭"
236
+ "揳援揶揸揽揿搀搁搂搅搋搌搏搐搒搓搔搛搜搞搠搡搦搪搬搭搴携搽摁摄摅摆摇摈摊摏摒摔摘"
237
+ "摛摞摧摩摭摴摸摹摽撂撄撅撇撑撒撕撖撙撞撤撩撬播撮撰撵撷撸撺撼擀擂擅操擎擐擒擘擞擢"
238
+ "擤擦擿攀攉攒攘攥攫攮支收攸改攻攽放政故效敉敌敏救敔敕敖教敛敝敞敢散敦敩敫敬数敲整"
239
+ "敷文斋斌斐斑斓斗料斛斜斝斟斠斡斤斥斧斩斫断斯新斶方於施旁旃旄旅旆旋旌旎族旐旒旖旗"
240
+ "旞无既日旦旧旨早旬旭旮旯旰旱旴旵时旷旸旺旻旿昀昂昃昄昆昇昈昉昊昌明昏昒易昔昕昙昝"
241
+ "星映昡昣昤春昧昨昪昫昭是昱昳昴昵昶昺昼昽显晁晃晅晊晋晌晏晐晒晓晔晕晖晗晙晚晞晟晡"
242
+ "晢晤晦晨晪晫普景晰晱晴晶晷智晾暂暄暅暇暌暑暕暖暗暝暧暨暮暲暴暵暶暹暾暿曈曌曙曛曜"
243
+ "曝曦曩曰曲曳更曷曹曼曾替最月有朋服朏朐朓朔朕朗望朝期朦木未末本札术朱朳朴朵朸机朽"
244
+ "杀杂权杄杆杈杉杌李杏材村杓杕杖杙杜杞束杠条来杧杨杩杪杭杯杰杲杳杵杷杻杼松板极构枅"
245
+ "枇枉枋枍析枕林枘枚果枝枞枢枣枥枧枨枪枫枭枯枰枲枳枵架枷枸枹柁柃柄柈柊柏某柑柒染柔"
246
+ "柖柘柙柚柜柝柞柠柢查柩柬柯柰柱柳柴柷柽柿栀栅标栈栉栊栋栌栎栏栐树栒栓栖栗栝栟校栩"
247
+ "株栲栳栴样核根栻格栽栾桀桁桂桃桄桅框案桉桊桌桎桐桑桓桔桕桠桡桢档桤桥桦桧桨桩桫桯"
248
+ "桲桴桶桷桹梁梃梅梆梌梏梓梗梠梢梣梦梧梨梭梯械梳梴梵梼梽梾梿检棁棂棉棋棍棐棒棓棕棘"
249
+ "棚棠棣棤棨棪棫棬森棰棱棵棹棺棻棼棽椀椁椅椆椋植椎椐椑椒椓椟椠椤椪椭椰椴椸椹椽椿楂"
250
+ "楒楔楗楙楚楝楞楠楣楦楩楪楫楮楯楷楸楹楼概榃榄榅榆榇榈榉榍榑榔榕榖榛榜榧榨榫榭榰���"
251
+ "榴榷榻槁槃槊槌槎槐槔槚槛槜槟槠槭槱槲槽槿樊樗樘樟模樨横樯樱樵樽樾橄橇橐橑橘橙橛橞"
252
+ "橡橥橦橱橹橼檀檄檎檐檑檗檞檠檩檫檬櫆欂欠次欢欣欤欧欲欸欹欺欻款歃歅歆歇歉歌歙止正"
253
+ "此步武歧歪歹死歼殁殂殃殄殆殇殉殊残殍殒殓殖殚殛殡殣殪殳殴段殷殿毁毂毅毋毌母每毐毒"
254
+ "毓比毕毖毗毙毛毡毪毫毯毳毵毹毽氅氆氇氍氏氐民氓气氕氖氘氙氚氛氟氡氢氤氦氧氨氩氪氮"
255
+ "氯氰氲水永氾氿汀汁求汆汇汈汉汊汋汐汔汕汗汛汜汝汞江池污汤汧汨汩汪汫汭汰汲汴汶汹汽"
256
+ "汾沁沂沃沄沅沆沇沈沉沌沏沐沓沔沘沙沚沛沟没沣沤沥沦沧沨沩沪沫沭沮沱河沸油沺治沼沽"
257
+ "沾沿泂泃泄泅泇泉泊泌泐泓泔法泖泗泙泚泛泜泞泠泡波泣泥注泪泫泮泯泰泱泳泵泷泸泺泻泼"
258
+ "泽泾洁洄洇洈洋洌洎洑洒洓洗洘洙洚洛洞洢洣津洧洨洪洫洭洮洱洲洳洴洵洸洹洺活洼洽派洿"
259
+ "流浃浅浆浇浈浉浊测浍济浏浐浑浒浓浔浕浙浚浛浜浞浟浠浡浣浥浦浩浪浬浭浮浯浰浲浴海浸"
260
+ "浼涂涄涅消涉涌涍涎涐涑涓涔涕涘涛涝涞涟涠涡涢涣涤润涧涨涩涪涫涮涯液涴涵涸涿淀淄淅"
261
+ "淆淇淋淌淏淑淖淘淙淜淝淞淟淠淡淤淦淫淬淮淯深淳淴混淹添淼清渊渌渍渎渐渑渔渗渚渝渟"
262
+ "渠渡渣渤渥温渫渭港渰渲渴游渺渼湃湄湉湍湎湑湓湔湖湘湛湜湝湟湣湫湮湲湴湾湿溁溃溅溆"
263
+ "溇溉溍溏源溘溚溜溞溟溠溢溥溦溧溪溯溱溲溴溵溶溷溹溺溻溽滁滂滃滆滇滉滋滍滏滑滓滔滕"
264
+ "滗滘滚滞滟滠满滢滤滥滦滧滨滩滪滫滴滹漂漆漈漉漋漏漓演漕漖漠漤漦漩漪漫漭漯漱漳漴漶"
265
+ "漷漹漻漼漾潆潇潋潍潏潖潘潜潞潟潢潦潩潭潮潲潴潵潸潺潼潽潾澂澄澈澉澌澍澎澛澜澡澥澧"
266
+ "澪澭澳澴澶澹澼澽激濂濉濋濑濒濞濠濡濩濮濯瀌瀍瀑瀔瀚瀛瀣瀱瀵瀹瀼灈灌灏灞火灭灯灰灵"
267
+ "灶灸灼灾灿炀炅炆炉炊炌炎炒炔炕炖炘炙炜炝炟炣炫炬炭炮炯炱炳炷炸点炻炼炽烀烁烂烃烈"
268
+ "烊烔烘烙烛烜烝烟烠烤烦烧烨烩烫烬热烯烶烷烹烺烻烽焆焉焊焌焐焓焕焖焗焘焙焚焜焞焦焯"
269
+ "焰焱然煁煃煅煊煋煌煎煓煜煞煟煤煦照煨煮煲煳煴煸煺煽熄熇熊熏熔熘熙熛熜熟熠熥熨熬熵"
270
+ "熹熻燃燊燋燎燏燔燕燚燠燥燧燮燹爆爇爔爚爝爟爨爪爬爰爱爵父爷爸爹爻爽爿牁牂片版牌牍"
271
+ "牒牖牙牚牛牝牟牡牢牤牥牦牧物牮牯牲牵特牺牻牾牿犀犁犄犇犊犋犍犏犒犟犨犬犯犰犴状犷"
272
+ "犸犹狁狂狃狄狈狉狍狎狐狒狗狙狝狞狠狡狨狩独狭狮狯狰狱狲狳狴狷狸狺狻狼猁猃猄猇猊猎"
273
+ "猕猖猗猛猜猝猞猡猢猥猩猪猫猬献猯猰猱猴猷猹猺猾猿獍獐獒獗獠獬獭獯獴獾玃玄率玉王玎"
274
+ "玑玒玓玕玖玘玙玚玛玞玟玠玡玢玤玥玦玩玫玭玮环现玱玲玳玶玷玹玺玻玼玿珀珂珅珇珈珉珊"
275
+ "珋珌珍珏珐珑珒珕珖珙珛珝珞珠珢珣珥珦珧珩珪珫班珰珲珵珷珸珹珺珽琀球琄琅理琇琈琉琊"
276
+ "琎琏琐琔琚琛琟琡琢琤琥琦琨琪琫琬琭琮琯琰琲琳琴琵琶琼瑀瑁瑂瑃瑄瑅瑆瑑瑓瑔瑕瑖瑗瑙"
277
+ "瑚瑛瑜瑝瑞瑟瑢瑧瑨瑬瑭瑰瑱瑳瑶瑷瑾璀璁璃璆璇璈璋璎璐璒璘璜璞璟璠璥璧璨璩璪璬璮璱"
278
+ "璲璺瓀瓒瓖瓘瓜瓞瓠瓢瓣瓤瓦瓮瓯瓴瓶瓷瓻瓿甄甍甏甑甓甗甘甚甜生甡甥甦用甩甪甫甬甭甯"
279
+ "田由甲申电男甸町画甾畀畅畈畋界畎畏畔畖留畚畛畜畤略畦番畬畯畲畴畸畹畿疁疃疆疍疏疐"
280
+ "疑疔疖疗疙疚疝疟疠疡疢疣疤疥疫疬疭疮疯疰疱疲疳疴疵疸疹疼疽疾痂痃痄病症痈痉痊痍痒"
281
+ "痓痔痕痘痛痞痢痣痤痦痧痨痪痫痰痱痴痹痼痿瘀瘁瘃瘅瘆瘊瘌瘐瘕瘗瘘瘙瘛瘟瘠瘢瘤瘥瘦瘩"
282
+ "瘪瘫瘭瘰瘳瘴瘵瘸瘼瘾瘿癀癃癌癍癔癖癗癜癞癣癫癯癸登白百癿皂的皆皇皈皋皎皑皓皕皖皙"
283
+ "皛皞皤皦皭皮皱皲皴皿盂盅盆盈盉益盍盎盏盐监盒盔盖盗盘盛盟盥盦目盯盱盲直盷相盹盼盾"
284
+ "省眄眇眈眉眊看眍眙眚真眠眢眦眨眩眬眭眯眵眶眷眸眺眼着睁睃睄睇睎睐睑睚睛睡睢督睥睦"
285
+ "睨睫睬睹睽睾睿瞀瞄瞅瞋瞌瞍瞎瞑瞒瞟瞠瞢瞥瞧瞩瞪瞫瞬瞭瞰瞳瞵瞻瞽瞿矍矗矛矜矞矢矣知"
286
+ "矧矩矫矬短矮矰石矶矸矻矼矾矿砀码砂砄砆砉砌砍砑砒研砖砗砘砚砜砝砟砠砣砥砧砫砬砭砮"
287
+ "砰破砵砷砸砹砺砻砼砾础硁硅硇硊硌硍硎硐硒硔硕硖硗硙硚硝硪硫硬硭确硼硿碃碇碈碉碌碍"
288
+ "碎碏碑碓碗碘碚碛碜碟碡碣碥碧碨碰碱碲碳碴碶碹碾磁磅磉磊磋磏磐磔磕磙磜磡磨磬磲磴磷"
289
+ "磹磻礁礅礌礓礞礴礵示礼社祀祁祃祆祇祈祉祊祋祎祏祐祓祕祖祗祚祛祜祝神祟祠祢祥祧票祭"
290
+ "祯祲祷祸祺祼祾禀禁禄禅禊禋福禒禔禘禚禛禤禧禳禹禺离禽禾秀私秃秆秉秋种科秒秕秘租���"
291
+ "秤秦秧秩秫秬秭积称秸移秽秾稀稂稃稆程稌稍税稑稔稗稙稚稞稠稣稳稷稹稻稼稽稿穄穆穑穗"
292
+ "穙穜穟穰穴究穷穸穹空穿窀突窃窄窅窈窊窍窎窑窒窕窖窗窘窜窝窟窠窣窥窦窨窬窭窳窸窿立"
293
+ "竑竖竘站竞竟章竣童竦竫竭端竹竺竽竿笃笄笆笈笊笋笏笑笔笕笙笛笞笠笤笥符笨笪笫第笮笯"
294
+ "笱笳笸笺笼笾筀筅筇等筋筌筏筐筑筒答策筘筚筛筜筝筠筢筤筥筦筮筱筲筵筶筷筹筻筼签简箅"
295
+ "箍箐箓箔箕箖算箜管箢箦箧箨箩箪箫箬箭箱箴箸篁篆篇篌篑篓篙篚篝篡篥篦篪篮篯篱篷篼篾"
296
+ "簃簇簉簋簌簏簕簖簝簟簠簧簪簰簸簿籀籁籍籥米籴类籼籽粉粑粒粕粗粘粜粝粞粟粢粤粥粪粮"
297
+ "粱粲粳粹粼粽精粿糁糅糇糈糊糌糍糒糕糖糗糙糜糟糠糨糯糵系紊素索紧紫累絜絮絷綦綮縠縢"
298
+ "縻繁繄繇纂纛纠纡红纣纤纥约级纨纩纪纫纬纭纮纯纰纱纲纳纴纵纶纷纸纹纺纻纼纽纾线绀绁"
299
+ "绂练组绅细织终绉绊绋绌绍绎经绐绑绒结绔绕绖绗绘给绚绛络绝绞统绠绡绢绣绤绥绦继绨绩"
300
+ "绪绫续绮绯绰绱绲绳维绵绶绷绸绹绺绻综绽绾绿缀缁缂缃缄缅缆缇缈缉缊缌缎缐缑缒缓缔缕"
301
+ "编缗缘缙缚缛缜缝缞缟缠缡缢缣缤缥缦缧缨缩缪缫缬缭缮缯缰缱缲缳缴缵缶缸缺罂罄罅罍罐"
302
+ "网罔罕罗罘罚罟罡罢罨罩罪置罱署罴罶罹罽罾羁羊羌美羑羓羔羕羖羚羝羞羟羡群羧羯羰羱羲"
303
+ "羸羹羼羽羿翀翁翂翃翅翈翊翌翎翔翕翘翙翚翛翟翠翡翥翦翩翮翯翰翱翳翷翻翼翾耀老考耄者"
304
+ "耆耇耋而耍耏耐耑耒耔耕耖耗耘耙耜耠耢耤耥耦耧耨耩耪耰耱耳耵耶耷耸耻耽耿聂聃聆聊聋"
305
+ "职聍聒联聘聚聩聪聱聿肃肄肆肇肉肋肌肓肖肘肚肛肝肟肠股肢肤肥肩肪肫肭肮肯肱育肴肷肸"
306
+ "肺肼肽肾肿胀胁胂胃胄胆胈背胍胎胖胗胙胚胛胜胝胞胠胡胣胤胥胧胨胩胪胫胬胭胯胰胱胲胳"
307
+ "胴胶胸胺胼能脂脆脉脊脍脎脏脐脑脒脓脔脖脘脚脞脟脩脬脯脱脲脶脸脾脿腆腈腊腋腌腐腑腒"
308
+ "腓腔腕腘腙腚腠腥腧腨腩腭腮腯腰腱腴腹腺腻腼腽腾腿膀膂膈膊膏膑膘膙膛膜膝膦膨膳膺膻"
309
+ "臀臂臃臆臊臌臑臜臣臧自臬臭至致臻臼臾舀舁舂舄舅舆舌舍舐舒舔舛舜舞舟舠舢舣舥航舫般"
310
+ "舭舯舰舱舲舳舴舵舶舷舸船舻舾艄艅艇艉艋艎艏艘艚艟艨艮良艰色艳艴艺艽艾艿节芃芄芈芊"
311
+ "芋芍芎芏芑芒芗芘芙芜芝芟芠芡芣芤芥芦芨芩芪芫芬芭芮芯芰花芳芴芷芸芹芼芽芾苁苄苇苈"
312
+ "苉苊苋苌苍苎苏苑苒苓苔苕苗苘苛苜苞苟苠苡苣苤若苦苧苫苯英苴苷苹苻苾茀茁茂范茄茅茆"
313
+ "茈茉茋茌茎茏茑茓茔茕茗茚茛茜茝茧茨茫茬茭茯茱茳茴茵茶茸茹茺茼茽荀荁荃荄荆荇草荏荐"
314
+ "荑荒荓荔荖荙荚荛荜荞荟荠荡荣荤荥荦荧荨荩荪荫荬荭荮药荷荸荻荼荽莅莆莉莎莒莓莘莙莛"
315
+ "莜莝莞莠莨莩莪莫莰莱莲莳莴莶获莸莹莺莼莽莿菀菁菂菅菇菉菊菌菍菏菔菖菘菜菝菟菠菡菥"
316
+ "菩菪菰菱菲菹菼菽萁萃萄萆萋萌萍萎萏萑萘萚萜萝萣萤营萦萧萨萩萱萳萸萹萼落葆葎葑葖著"
317
+ "葙葚葛葜葡董葩葫葬葭葰葱葳葴葵葶葸葺蒂蒄蒇蒈蒉蒋蒌蒎蒐蒗蒙蒜蒟蒡蒨蒯蒱蒲蒴蒸蒹蒺"
318
+ "蒻蒽蒿蓁蓂蓄蓇蓉蓊蓍蓏蓐蓑蓓蓖蓝蓟蓠蓢蓣蓥蓦蓬蓰蓼蓿蔀蔃蔈蔊蔌蔑蔓蔗蔚蔟蔡蔫蔬蔷"
319
+ "蔸蔹蔺蔻蔼蔽蕃蕈蕉蕊蕖蕗蕙蕞蕤蕨蕰蕲蕴蕹蕺蕻蕾薁薄薅薇薏薛薜薢薤薨薪薮薯薰薳薷薸"
320
+ "薹薿藁藉藏藐藓藕藜藟藠藤藦藨藩藻藿蘅蘑蘖蘘蘧蘩蘸蘼虎虏虐虑虒虓虔虚虞虢虤虫虬虮虱"
321
+ "虷虸虹虺虻虼虽虾虿蚀蚁蚂蚄蚆蚊蚋蚌蚍蚓蚕蚜蚝蚣蚤蚧蚨蚩蚪蚬蚯蚰蚱蚲蚴蚶蚺蛀蛃蛄蛆"
322
+ "蛇蛉蛊蛋蛎蛏蛐蛑蛔蛘蛙蛛蛞蛟蛤蛩蛭蛮蛰蛱蛲蛳蛴蛸蛹蛾蜀蜂蜃蜇蜈蜉蜊蜍蜎蜐蜒蜓蜕蜗"
323
+ "蜘蜚蜜蜞蜡蜢蜣蜥蜩蜮蜱蜴蜷蜻蜾蜿蝇蝈蝉蝌蝎蝓蝗蝘蝙蝠蝣蝤蝥蝮蝰蝲蝴蝶蝻蝼蝽蝾螂螃"
324
+ "螅螈螋融螗螟螠螣螨螫螬螭螯螱螳螵螺螽蟀蟆蟊蟋蟏蟑蟒蟛蟠蟥蟪蟫蟮蟹蟾蠃蠊蠋蠓蠕蠖蠡"
325
+ "蠢蠲蠹蠼血衃衄衅行衍衎衒衔街衙衠衡衢衣补表衩衫衬衮衰衲衷衽衾衿袁袂袄袅袆袈袋袍袒"
326
+ "袖袗袜袢袤袪被袭袯袱袷袼裁裂装裆裈裉裎裒裔裕裘裙裛裟裢裣裤裥裨裰裱裳裴裸裹裼裾褂"
327
+ "褊褐褒褓褕褙褚褛褟褡褥褪褫褯褰褴褶襁襄襕襚襜襞襟襦襫襻西要覃覆见观觃规觅视觇览觉"
328
+ "觊觋觌觎觏觐觑角觖觚觜觞觟解觥触觫觭觯觱觳觿言訄訇訚訾詈詟詹誉誊誓謇警譬计订讣认"
329
+ "讥讦讧讨让讪讫训议讯记讱讲讳讴讵讶讷许讹论讻讼讽设访诀证诂诃评诅识诇诈诉诊诋诌词"
330
+ "诎诏诐译诒诓诔试诖诗诘诙诚诛诜话诞诟诠诡询诣诤该详诧诨诩诫诬语诮误诰诱诲诳说诵���"
331
+ "诸诹诺读诼诽课诿谀谁谂调谄谅谆谇谈谊谋谌谍谎谏谐谑谒谓谔谕谖谗谙谚谛谜谝谞谟谠谡"
332
+ "谢谣谤谥谦谧谨谩谪谫谬谭谮谯谰谱谲谳谴谵谶谷谼谿豁豆豇豉豌豕豚象豢豨豪豫豮豳豸豹"
333
+ "豺貂貅貆貉貊貌貔貘贝贞负贡财责贤败账货质贩贪贫贬购贮贯贰贱贲贳贴贵贶贷贸费贺贻贼"
334
+ "贽贾贿赀赁赂赃资赅赆赇赈赉赊赋赌赍赎赏赐赑赒赓赔赕赖赗赘赙赚赛赜赝赞赟赠赡赢赣赤"
335
+ "赦赧赪赫赭走赳赴赵赶起趁趄超越趋趑趔趟趣趯趱足趴趵趸趺趼趾趿跂跃跄跆跋跌跎跏跐跑"
336
+ "跖跗跚跛距跞跟跣跤跨跪跬路跱跳践跶跷跸跹跺跻跽踅踉踊踌踏踒踔踝踞踟踢踣踦踩踪踬踮"
337
+ "踯踱踵踶踹踺踽蹀蹁蹂蹄蹅蹇蹈蹉蹊蹋蹐蹑蹒蹙蹚蹜蹢蹦蹩蹬蹭蹯蹰蹲蹴蹶蹼蹽蹾蹿躁躅躇"
338
+ "躏躐躔躜躞身躬躯躲躺车轧轨轩轪轫转轭轮软轰轱轲轳轴轵轶轷轸轹轺轻轼载轾轿辀辁辂较"
339
+ "辄辅辆辇辈辉辊辋辌辍辎辏辐辑辒输辔辕辖辗辘辙辚辛辜辞辟辣辨辩辫辰辱边辽达辿迁迂迄"
340
+ "迅过迈迎运近迓返迕还这进远违连迟迢迤迥迦迨迩迪迫迭迮述迳迷迸迹迺追退送适逃逄逅逆"
341
+ "选逊逋逍透逐逑递途逖逗通逛逝逞速造逡逢逦逭逮逯逴逵逶逸逻逼逾遁遂遄遆遇遍遏遐遑遒"
342
+ "道遗遘遛遢遣遥遨遭遮遴遵遹遽避邀邂邃邈邋邑邓邕邗邘邙邛邝邠邡邢那邦邨邪邬邮邯邰邱"
343
+ "邲邳邴邵邶邸邹邺邻邽邾邿郁郃郄郅郇郈郊郎郏郐郑郓郗郚郛郜郝郡郢郤郦郧部郪郫郭郯郴"
344
+ "郸都郾郿鄀鄂鄃鄄鄅鄌鄑鄗鄘鄙鄚鄜鄞鄠鄢鄣鄫鄯鄱鄹酂酃酅酆酉酊酋酌配酎酏酐酒酗酚酝"
345
+ "酞酡酢酣酤酥酦酩酪酬酮酯酰酱酲酴酵酶酷酸酹酺酽酾酿醅醇醉醋醌醍醐醑醒醚醛醢醨醪醭"
346
+ "醮醯醴醵醺醾采釉释里重野量釐金釜鉴銎銮鋆鋈錾鍪鎏鏊鏖鐾鑫钆钇针钉钊钋钌钍钎钏钐钒"
347
+ "钓钔钕钖钗钘钙钚钛钜钝钞钟钠钡钢钣钤钥钦钧钨钩钪钫钬钭钮钯钰钱钲钳钴钵钷钹钺钻钼"
348
+ "钽钾钿铀铁铂铃铄铅铆铈铉铊铋铌铍铎铏铐铑铒铕铖铗铘铙铚铛铜铝铞铟铠铡铢铣铤铥铧铨"
349
+ "铩铪铫铬铭铮铯铰铱铲铳铴铵银铷铸铹铺铻铼铽链铿销锁锂锃锄锅锆锇锈锉锊锋锌锍锎锏锐"
350
+ "锑锒锓锔锕锖锗锘错锚锛锜锝锞锟锡锢锣锤锥锦锧锨锩锪锫锬锭键锯锰锱锲锳锴锵锶锷锸锹"
351
+ "锺锻锼锽锾锿镀镁镂镃镄镅镆镇镈镉镊镋镌镍镎镏镐镑镒镓镔镕镖镗镘镚镛镜镝镞镠镡镢镣"
352
+ "镤镥镦镧镨镩镪镫镬镭镮镯镰镱镲镳镴镵镶长门闩闪闫闭问闯闰闱闲闳间闵闶闷闸闹闺闻闼"
353
+ "闽闾闿阀阁阂阃阄阅阆阇阈阉阊阋阌阍阎阏阐阑阒阔阕阖阗阘阙阚阜队阡阪阮阱防阳阴阵阶"
354
+ "阻阼阽阿陀陂附际陆陇陈陉陋陌降陎限陑陔陕陛陞陟陡院除陧陨险陪陬陲陴陵陶陷隃隅隆隈"
355
+ "隋隍随隐隔隗隘隙障隧隩隰隳隶隹隺隼隽难雀雁雄雅集雇雉雊雌雍雎雏雒雕雠雨雩雪雯雱雳"
356
+ "零雷雹雾需霁霄霅霆震霈霉霍霎霏霓霖霜霞霨霪霭霰露霸霹霾青靓靖静靛非靠靡面靥革靬靰"
357
+ "靳靴靶靸靺靼靽靿鞁鞅鞋鞍鞑鞒鞔鞘鞠鞡鞣鞧鞨鞫鞬鞭鞮鞯鞲鞳鞴韂韦韧韨韩韪韫韬韭音韵"
358
+ "韶页顶顷顸项顺须顼顽顾顿颀颁颂颃预颅领颇颈颉颊颋颌颍颎颏颐频颓颔颖颗题颙颚颛颜额"
359
+ "颞颟颠颡颢颤颥颦颧风飏飐飑飒飓飔飕飗飘飙飞食飧飨餍餐餮饔饕饥饧饨饩饪饫饬饭饮饯饰"
360
+ "饱饲饳饴饵饶饷饸饹饺饻饼饽饿馁馃馄馅馆馇馈馉馊馋馌馍馏馐馑馒馓馔馕首馗馘香馝馞馥"
361
+ "馧馨马驭驮驯驰驱驲驳驴驵驶驷驸驹驺驻驼驽驾驿骀骁骂骃骄骅骆骇骈骉骊骋验骍骎骏骐骑"
362
+ "骒骓骕骖骗骘骙骚骛骜骝骞骟骠骡骢骣骤骥骦骧骨骰骱骶骷骸骺骼髀髁髂髃髅髋髌髎髑髓高"
363
+ "髡髢髦髫髭髯髹髻髽鬃鬈鬏鬒鬓鬘鬟鬣鬯鬲鬶鬷鬻鬼魁魂魃魄魅魆魇魈魉魋魍魏魑魔鱼鱽鱾"
364
+ "鱿鲀鲁鲂鲃鲅鲆鲇鲈鲉鲊鲋鲌鲍鲎鲏鲐鲑鲒鲔鲕鲖鲗鲘鲙鲚鲛鲜鲝鲞鲟鲠鲡鲢鲣鲤鲥鲦鲧鲨"
365
+ "鲩鲪鲫鲬鲭鲮鲯鲰鲱鲲鲳鲴鲵鲷鲸鲹鲺鲻鲼鲽鲾鲿鳀鳁鳂鳃鳄鳅鳇鳈鳉鳊鳌鳍鳎鳏鳐鳑鳒鳓"
366
+ "鳔鳕鳖鳗鳘鳙鳚鳛鳜鳝鳞鳟鳠鳡鳢鳣鳤鸟鸠鸡鸢鸣鸤鸥鸦鸧鸨鸩鸪鸫鸬鸭鸮鸯鸰鸱鸲鸳鸵鸶"
367
+ "鸷鸸鸹鸺鸻鸼鸽鸾鸿鹀鹁鹂鹃鹄鹅鹆鹇鹈鹉鹊鹋鹌鹍鹎鹏鹐鹑鹒鹔鹕鹖鹗鹘鹙鹚鹛鹜鹝鹞鹟"
368
+ "鹠鹡鹢鹣鹤鹦鹧鹨鹩鹪鹫鹬鹭鹮鹯鹰鹱鹲鹳鹴鹾鹿麀麂麇麈麋麑麒麓麖麝麟麦麸麹麻麽麾黄"
369
+ "黇黉黍黎黏黑黔默黛黜黝黟黠黡黢黥黧黩黪黯黹黻黼黾鼋鼍鼎鼐鼒鼓鼗鼙鼠鼢鼩鼫鼬鼯鼱鼷"
370
+ "鼹鼻鼽鼾齁齇齉齐齑齿龀龁龂龃龄龅龆龇龈龉龊龋龌龙龚龛龟龠龢鿍鿎鿏㑇㑊㕮㘎㙍㙘㙦���"
371
+ "㛚㛹㟃㠇㠓㤘㥄㧐㧑㧟㫰㬊㬎㬚㭎㭕㮾㰀㳇㳘㳚㴔㵐㶲㸆㸌㺄㻬㽏㿠䁖䂮䃅䃎䅟䌹䎃䎖䏝䏡"
372
+ "䏲䐃䓖䓛䓨䓫䓬䗖䗛䗪䗴䜣䝙䢺䢼䣘䥽䦃䲟䲠䲢䴓䴔䴕䴖䴗䴘䴙䶮𠅤𠙶𠳐𡎚𡐓𣗋𣲗𣲘𣸣𤧛𤩽"
373
+ "𤫉𥔲𥕢𥖨𥻗𦈡𦒍𦙶𦝼𦭜𦰡𧿹𨐈𨙸𨚕𨟠𨭉𨱇𨱏𨱑𨱔𨺙𩽾𩾃𩾌𪟝𪣻𪤗𪨰𪨶𪩘𪾢𫄧𫄨𫄷𫄸𫇭𫌀𫍣𫍯"
374
+ "𫍲𫍽𫐄𫐐𫐓𫑡𫓧𫓯𫓶𫓹𫔍𫔎𫔶𫖮𫖯𫖳𫗧𫗴𫘜𫘝𫘦𫘧𫘨𫘪𫘬𫚕𫚖𫚭𫛭𫞩𫟅𫟦𫟹𫟼𫠆𫠊𫠜𫢸𫫇𫭟"
375
+ "𫭢𫭼𫮃𫰛𫵷𫶇𫷷𫸩𬀩𬀪𬂩𬃊𬇕𬇙𬇹𬉼𬊈𬊤𬌗𬍛𬍡𬍤𬒈𬒔𬒗𬕂𬘓𬘘𬘡𬘩𬘫𬘬𬘭𬘯𬙂𬙊𬙋𬜬𬜯𬞟"
376
+ "𬟁𬟽𬣙𬣞𬣡𬣳𬤇𬤊𬤝𬨂𬨎𬩽𬪩𬬩𬬭𬬮𬬱𬬸𬬹𬬻𬬿𬭁𬭊𬭎𬭚𬭛𬭤𬭩𬭬𬭯𬭳𬭶𬭸𬭼𬮱𬮿𬯀𬯎𬱖𬱟"
377
+ "𬳵𬳶𬳽𬳿𬴂𬴃𬴊𬶋𬶍𬶏𬶐𬶟𬶠𬶨𬶭𬶮𬷕𬸘𬸚𬸣𬸦𬸪𬹼𬺈𬺓"
378
+ )
379
+ CN_CHARS_EXT = "吶诶屌囧飚屄"
380
+
381
+ CN_CHARS = CN_CHARS_COMMON + CN_CHARS_EXT
382
+ IN_CH_CHARS = {c: True for c in CN_CHARS}
383
+
384
+ EN_CHARS = string.ascii_letters + string.digits
385
+ IN_EN_CHARS = {c: True for c in EN_CHARS}
386
+
387
+ VALID_CHARS = CN_CHARS + EN_CHARS + " "
388
+ IN_VALID_CHARS = {c: True for c in VALID_CHARS}
389
+
390
+
391
+ # ================================================================================ #
392
+ # basic class
393
+ # ================================================================================ #
394
+ class ChineseChar(object):
395
+ """
396
+ A Chinese character.
397
+ Each character has a simplified and a traditional form,
398
+ e.g. simplified = '负', traditional = '負'.
399
+ It can be converted to either form.
400
+ """
401
+
402
+ def __init__(self, simplified, traditional):
403
+ self.simplified = simplified
404
+ self.traditional = traditional
405
+ # self.__repr__ = self.__str__
406
+
407
+ def __str__(self):
408
+ return self.simplified or self.traditional or None
409
+
410
+ def __repr__(self):
411
+ return self.__str__()
412
+
413
+
414
+ class ChineseNumberUnit(ChineseChar):
415
+ """
416
+ Chinese number / number-unit character.
417
+ Besides the simplified and traditional forms, each character also has an extra "capital" (banker's) form,
418
+ e.g. '陆' and '陸'.
419
+ """
420
+
421
+ def __init__(self, power, simplified, traditional, big_s, big_t):
422
+ super(ChineseNumberUnit, self).__init__(simplified, traditional)
423
+ self.power = power
424
+ self.big_s = big_s
425
+ self.big_t = big_t
426
+
427
+ def __str__(self):
428
+ return "10^{}".format(self.power)
429
+
430
+ @classmethod
431
+ def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
432
+ if small_unit:
433
+ return ChineseNumberUnit(
434
+ power=index + 1, simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1]
435
+ )
436
+ elif numbering_type == NUMBERING_TYPES[0]:
437
+ return ChineseNumberUnit(
438
+ power=index + 8, simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]
439
+ )
440
+ elif numbering_type == NUMBERING_TYPES[1]:
441
+ return ChineseNumberUnit(
442
+ power=(index + 2) * 4, simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]
443
+ )
444
+ elif numbering_type == NUMBERING_TYPES[2]:
445
+ return ChineseNumberUnit(
446
+ power=pow(2, index + 3), simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]
447
+ )
448
+ else:
449
+ raise ValueError("Counting type should be in {0} ({1} provided).".format(NUMBERING_TYPES, numbering_type))
450
+
451
+
452
+ class ChineseNumberDigit(ChineseChar):
453
+ """
454
+ Chinese digit character.
455
+ """
456
+
457
+ def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
458
+ super(ChineseNumberDigit, self).__init__(simplified, traditional)
459
+ self.value = value
460
+ self.big_s = big_s
461
+ self.big_t = big_t
462
+ self.alt_s = alt_s
463
+ self.alt_t = alt_t
464
+
465
+ def __str__(self):
466
+ return str(self.value)
467
+
468
+ @classmethod
469
+ def create(cls, i, v):
470
+ return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
471
+
472
+
473
+ class ChineseMath(ChineseChar):
474
+ """
475
+ Chinese math symbol character.
476
+ """
477
+
478
+ def __init__(self, simplified, traditional, symbol, expression=None):
479
+ super(ChineseMath, self).__init__(simplified, traditional)
480
+ self.symbol = symbol
481
+ self.expression = expression
482
+ self.big_s = simplified
483
+ self.big_t = traditional
484
+
485
+
486
+ CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
487
+
488
+
489
+ class NumberSystem(object):
490
+ """
491
+ Chinese numbering system.
492
+ """
493
+
494
+ pass
495
+
496
+
497
+ class MathSymbol(object):
498
+ """
499
+ Math symbols (simplified/traditional) used in the Chinese numbering system, e.g.
500
+ positive = ['正', '正']
501
+ negative = ['负', '負']
502
+ point = ['点', '點']
503
+ """
504
+
505
+ def __init__(self, positive, negative, point):
506
+ self.positive = positive
507
+ self.negative = negative
508
+ self.point = point
509
+
510
+ def __iter__(self):
511
+ for v in self.__dict__.values():
512
+ yield v
513
+
514
+
515
+ # class OtherSymbol(object):
516
+ # """
517
+ # 其他符号
518
+ # """
519
+ #
520
+ # def __init__(self, sil):
521
+ # self.sil = sil
522
+ #
523
+ # def __iter__(self):
524
+ # for v in self.__dict__.values():
525
+ # yield v
526
+
527
+
528
+ # ================================================================================ #
529
+ # basic utils
530
+ # ================================================================================ #
531
+ def create_system(numbering_type=NUMBERING_TYPES[1]):
532
+ """
533
+ Create the corresponding number system for the given numbering type; defaults to mid.
534
+ NUMBERING_TYPES = ['low', 'mid', 'high']: type of Chinese numbering system
535
+ low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc.
536
+ mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
537
+ high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
538
+ Returns the corresponding number system.
539
+ """
540
+
541
+ # chinese number units of '亿' and larger
542
+ all_larger_units = zip(LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
543
+ larger_units = [CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)]
544
+ # chinese number units of '十, 百, 千, 万'
545
+ all_smaller_units = zip(SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
546
+ smaller_units = [CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)]
547
+ # digits
548
+ chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS, BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
549
+ digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
550
+ digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
551
+ digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
552
+ digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
553
+
554
+ # symbols
555
+ positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
556
+ negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
557
+ point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
558
+ # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
559
+ system = NumberSystem()
560
+ system.units = smaller_units + larger_units
561
+ system.digits = digits
562
+ system.math = MathSymbol(positive_cn, negative_cn, point_cn)
563
+ # system.symbols = OtherSymbol(sil_cn)
564
+ return system
565
+
566
+
567
+ def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
568
+ def get_symbol(char, system):
569
+ for u in system.units:
570
+ if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
571
+ return u
572
+ for d in system.digits:
573
+ if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
574
+ return d
575
+ for m in system.math:
576
+ if char in [m.traditional, m.simplified]:
577
+ return m
578
+
579
+ def string2symbols(chinese_string, system):
580
+ int_string, dec_string = chinese_string, ""
581
+ for p in [system.math.point.simplified, system.math.point.traditional]:
582
+ if p in chinese_string:
583
+ int_string, dec_string = chinese_string.split(p)
584
+ break
585
+ return [get_symbol(c, system) for c in int_string], [get_symbol(c, system) for c in dec_string]
586
+
587
+ def correct_symbols(integer_symbols, system):
588
+ """
589
+ Expand/regroup units, e.g. 一百八 to 一百八十,
590
+ 一亿一千三百万 to 一亿 一千万 三百万
591
+ """
592
+
593
+ if integer_symbols and isinstance(integer_symbols[0], CNU):
594
+ if integer_symbols[0].power == 1:
595
+ integer_symbols = [system.digits[1]] + integer_symbols
596
+
597
+ if len(integer_symbols) > 1:
598
+ if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
599
+ integer_symbols.append(CNU(integer_symbols[-2].power - 1, None, None, None, None))
600
+
601
+ result = []
602
+ unit_count = 0
603
+ for s in integer_symbols:
604
+ if isinstance(s, CND):
605
+ result.append(s)
606
+ unit_count = 0
607
+ elif isinstance(s, CNU):
608
+ current_unit = CNU(s.power, None, None, None, None)
609
+ unit_count += 1
610
+
611
+ if unit_count == 1:
612
+ result.append(current_unit)
613
+ elif unit_count > 1:
614
+ for i in range(len(result)):
615
+ if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
616
+ result[-i - 1] = CNU(result[-i - 1].power + current_unit.power, None, None, None, None)
617
+ return result
618
+
619
+ def compute_value(integer_symbols):
620
+ """
621
+ Compute the value.
622
+ When the current unit is larger than the previous one, all previously accumulated values are multiplied by the current unit.
623
+ e.g. '两千万' = 2000 * 10000 not 2000 + 10000
624
+ """
625
+ value = [0]
626
+ last_power = 0
627
+ for s in integer_symbols:
628
+ if isinstance(s, CND):
629
+ value[-1] = s.value
630
+ elif isinstance(s, CNU):
631
+ value[-1] *= pow(10, s.power)
632
+ if s.power > last_power:
633
+ value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
634
+ last_power = s.power
635
+ value.append(0)
636
+ return sum(value)
637
+
638
+ system = create_system(numbering_type)
639
+ int_part, dec_part = string2symbols(chinese_string, system)
640
+ int_part = correct_symbols(int_part, system)
641
+ int_str = str(compute_value(int_part))
642
+ dec_str = "".join([str(d.value) for d in dec_part])
643
+ if dec_part:
644
+ return "{0}.{1}".format(int_str, dec_str)
645
+ else:
646
+ return int_str
647
+
648
+
649
+ def num2chn(
650
+ number_string,
651
+ numbering_type=NUMBERING_TYPES[1],
652
+ big=False,
653
+ traditional=False,
654
+ alt_zero=False,
655
+ alt_one=False,
656
+ alt_two=True,
657
+ use_zeros=True,
658
+ use_units=True,
659
+ ):
660
+ def get_value(value_string, use_zeros=True):
661
+ striped_string = value_string.lstrip("0")
662
+
663
+ # record nothing if all zeros
664
+ if not striped_string:
665
+ return []
666
+
667
+ # record a single digit
668
+ elif len(striped_string) == 1:
669
+ if use_zeros and len(value_string) != len(striped_string):
670
+ return [system.digits[0], system.digits[int(striped_string)]]
671
+ else:
672
+ return [system.digits[int(striped_string)]]
673
+
674
+ # recursively record multiple digits
675
+ else:
676
+ result_unit = next(u for u in reversed(system.units) if u.power < len(striped_string))
677
+ result_string = value_string[: -result_unit.power]
678
+ return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power :])
679
+
680
+ system = create_system(numbering_type)
681
+
682
+ int_dec = number_string.split(".")
683
+ if len(int_dec) == 1:
684
+ int_string = int_dec[0]
685
+ dec_string = ""
686
+ elif len(int_dec) == 2:
687
+ int_string = int_dec[0]
688
+ dec_string = int_dec[1]
689
+ else:
690
+ raise ValueError("invalid input num string with more than one dot: {}".format(number_string))
691
+
692
+ if use_units and len(int_string) > 1:
693
+ result_symbols = get_value(int_string)
694
+ else:
695
+ result_symbols = [system.digits[int(c)] for c in int_string]
696
+ dec_symbols = [system.digits[int(c)] for c in dec_string]
697
+ if dec_string:
698
+ result_symbols += [system.math.point] + dec_symbols
699
+
700
+ if alt_two:
701
+ liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t, system.digits[2].big_s, system.digits[2].big_t)
702
+ for i, v in enumerate(result_symbols):
703
+ if isinstance(v, CND) and v.value == 2:
704
+ next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None
705
+ previous_symbol = result_symbols[i - 1] if i > 0 else None
706
+ if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
707
+ if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
708
+ result_symbols[i] = liang
709
+
710
+ # if big is True, '两' will not be used and `alt_two` has no impact on output
711
+ if big:
712
+ attr_name = "big_"
713
+ if traditional:
714
+ attr_name += "t"
715
+ else:
716
+ attr_name += "s"
717
+ else:
718
+ if traditional:
719
+ attr_name = "traditional"
720
+ else:
721
+ attr_name = "simplified"
722
+
723
+ result = "".join([getattr(s, attr_name) for s in result_symbols])
724
+
725
+ # if not use_zeros:
726
+ # result = result.strip(getattr(system.digits[0], attr_name))
727
+
728
+ if alt_zero:
729
+ result = result.replace(getattr(system.digits[0], attr_name), system.digits[0].alt_s)
730
+
731
+ if alt_one:
732
+ result = result.replace(getattr(system.digits[1], attr_name), system.digits[1].alt_s)
733
+
734
+ for i, p in enumerate(POINT):
735
+ if result.startswith(p):
736
+ return CHINESE_DIGIS[0] + result
737
+
738
+ # ^10, 11, .., 19
739
+ if (
740
+ len(result) >= 2
741
+ and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]]
742
+ and result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]
743
+ ):
744
+ result = result[1:]
745
+
746
+ return result
747
+
748
+
749
+ # ================================================================================ #
750
+ # different types of rewriters
751
+ # ================================================================================ #
752
+ class Cardinal:
753
+ """
754
+ CARDINAL class
755
+ """
756
+
757
+ def __init__(self, cardinal=None, chntext=None):
758
+ self.cardinal = cardinal
759
+ self.chntext = chntext
760
+
761
+ def chntext2cardinal(self):
762
+ return chn2num(self.chntext)
763
+
764
+ def cardinal2chntext(self):
765
+ return num2chn(self.cardinal)
766
+
767
+
768
+ class Digit:
769
+ """
770
+ DIGIT class
771
+ """
772
+
773
+ def __init__(self, digit=None, chntext=None):
774
+ self.digit = digit
775
+ self.chntext = chntext
776
+
777
+ # def chntext2digit(self):
778
+ # return chn2num(self.chntext)
779
+
780
+ def digit2chntext(self):
781
+ return num2chn(self.digit, alt_two=False, use_units=False)
782
+
783
+
784
+ class TelePhone:
785
+ """
786
+ TELEPHONE class
787
+ """
788
+
789
+ def __init__(self, telephone=None, raw_chntext=None, chntext=None):
790
+ self.telephone = telephone
791
+ self.raw_chntext = raw_chntext
792
+ self.chntext = chntext
793
+
794
+ # def chntext2telephone(self):
795
+ # sil_parts = self.raw_chntext.split('<SIL>')
796
+ # self.telephone = '-'.join([
797
+ # str(chn2num(p)) for p in sil_parts
798
+ # ])
799
+ # return self.telephone
800
+
801
+ def telephone2chntext(self, fixed=False):
802
+ if fixed:
803
+ sil_parts = self.telephone.split("-")
804
+ self.raw_chntext = "<SIL>".join([num2chn(part, alt_two=False, use_units=False) for part in sil_parts])
805
+ self.chntext = self.raw_chntext.replace("<SIL>", "")
806
+ else:
807
+ sp_parts = self.telephone.strip("+").split()
808
+ self.raw_chntext = "<SP>".join([num2chn(part, alt_two=False, use_units=False) for part in sp_parts])
809
+ self.chntext = self.raw_chntext.replace("<SP>", "")
810
+ return self.chntext
811
+
812
+
813
+ class Fraction:
814
+ """
815
+ FRACTION class
816
+ """
817
+
818
+ def __init__(self, fraction=None, chntext=None):
819
+ self.fraction = fraction
820
+ self.chntext = chntext
821
+
822
+ def chntext2fraction(self):
823
+ denominator, numerator = self.chntext.split("分之")
824
+ return chn2num(numerator) + "/" + chn2num(denominator)
825
+
826
+ def fraction2chntext(self):
827
+ numerator, denominator = self.fraction.split("/")
828
+ return num2chn(denominator) + "分之" + num2chn(numerator)
829
+
830
+
831
+ class Date:
832
+ """
833
+ DATE class
834
+ """
835
+
836
+ def __init__(self, date=None, chntext=None):
837
+ self.date = date
838
+ self.chntext = chntext
839
+
840
+ # def chntext2date(self):
841
+ # chntext = self.chntext
842
+ # try:
843
+ # year, other = chntext.strip().split('年', maxsplit=1)
844
+ # year = Digit(chntext=year).digit2chntext() + '年'
845
+ # except ValueError:
846
+ # other = chntext
847
+ # year = ''
848
+ # if other:
849
+ # try:
850
+ # month, day = other.strip().split('月', maxsplit=1)
851
+ # month = Cardinal(chntext=month).chntext2cardinal() + '月'
852
+ # except ValueError:
853
+ # day = chntext
854
+ # month = ''
855
+ # if day:
856
+ # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
857
+ # else:
858
+ # month = ''
859
+ # day = ''
860
+ # date = year + month + day
861
+ # self.date = date
862
+ # return self.date
863
+
864
+ def date2chntext(self):
865
+ date = self.date
866
+ try:
867
+ year, other = date.strip().split("年", 1)
868
+ year = Digit(digit=year).digit2chntext() + "年"
869
+ except ValueError:
870
+ other = date
871
+ year = ""
872
+ if other:
873
+ try:
874
+ month, day = other.strip().split("月", 1)
875
+ month = Cardinal(cardinal=month).cardinal2chntext() + "月"
876
+ except ValueError:
877
+ day = date
878
+ month = ""
879
+ if day:
880
+ day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
881
+ else:
882
+ month = ""
883
+ day = ""
884
+ chntext = year + month + day
885
+ self.chntext = chntext
886
+ return self.chntext
887
+
888
+
889
+ class Money:
890
+ """
891
+ MONEY class
892
+ """
893
+
894
+ def __init__(self, money=None, chntext=None):
895
+ self.money = money
896
+ self.chntext = chntext
897
+
898
+ # def chntext2money(self):
899
+ # return self.money
900
+
901
+ def money2chntext(self):
902
+ money = self.money
903
+ pattern = re.compile(r"(\d+(\.\d+)?)")
904
+ matchers = pattern.findall(money)
905
+ if matchers:
906
+ for matcher in matchers:
907
+ money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
908
+ self.chntext = money
909
+ return self.chntext
910
+
911
+
912
+ class Percentage:
913
+ """
914
+ PERCENTAGE class
915
+ """
916
+
917
+ def __init__(self, percentage=None, chntext=None):
918
+ self.percentage = percentage
919
+ self.chntext = chntext
920
+
921
+ def chntext2percentage(self):
922
+ return chn2num(self.chntext.strip().strip("百分之")) + "%"
923
+
924
+ def percentage2chntext(self):
925
+ return "百分之" + num2chn(self.percentage.strip().strip("%"))
926
+
927
+
928
+ def normalize_nsw(raw_text):
929
+ text = "^" + raw_text + "$"
930
+
931
+ # normalize dates
932
+ pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
933
+ matchers = pattern.findall(text)
934
+ if matchers:
935
+ # print('date')
936
+ for matcher in matchers:
937
+ text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
938
+
939
+ # normalize money amounts
940
+ pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
941
+ matchers = pattern.findall(text)
942
+ if matchers:
943
+ # print('money')
944
+ for matcher in matchers:
945
+ text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
946
+
947
+ # normalize landline / mobile phone numbers
948
+ # mobile numbers
949
+ # http://www.jihaoba.com/news/show/13680
950
+ # China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
951
+ # China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
952
+ # China Telecom: 133, 153, 189, 180, 181, 177
953
+ pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
954
+ matchers = pattern.findall(text)
955
+ if matchers:
956
+ # print('telephone')
957
+ for matcher in matchers:
958
+ text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
959
+ # landline numbers
960
+ pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
961
+ matchers = pattern.findall(text)
962
+ if matchers:
963
+ # print('fixed telephone')
964
+ for matcher in matchers:
965
+ text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
966
+
967
+ # normalize fractions
968
+ pattern = re.compile(r"(\d+/\d+)")
969
+ matchers = pattern.findall(text)
970
+ if matchers:
971
+ # print('fraction')
972
+ for matcher in matchers:
973
+ text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
974
+
975
+ # normalize percentages
976
+ text = text.replace("%", "%")
977
+ pattern = re.compile(r"(\d+(\.\d+)?%)")
978
+ matchers = pattern.findall(text)
979
+ if matchers:
980
+ # print('percentage')
981
+ for matcher in matchers:
982
+ text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
983
+
984
+ # normalize number + quantifier (measure word)
985
+ pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
986
+ matchers = pattern.findall(text)
987
+ if matchers:
988
+ # print('cardinal+quantifier')
989
+ for matcher in matchers:
990
+ text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
991
+
992
+ # normalize digit strings (IDs / serial numbers)
993
+ pattern = re.compile(r"(\d{4,32})")
994
+ matchers = pattern.findall(text)
995
+ if matchers:
996
+ # print('digit')
997
+ for matcher in matchers:
998
+ text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
999
+
1000
+ # normalize plain numbers
1001
+ pattern = re.compile(r"(\d+(\.\d+)?)")
1002
+ matchers = pattern.findall(text)
1003
+ if matchers:
1004
+ # print('cardinal')
1005
+ for matcher in matchers:
1006
+ text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
1007
+
1008
+ # restore P2P, O2O, B2C, B2B etc
1009
+ pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
1010
+ matchers = pattern.findall(text)
1011
+ if matchers:
1012
+ # print('particular')
1013
+ for matcher in matchers:
1014
+ text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
1015
+
1016
+ return text.lstrip("^").rstrip("$")
1017
+
1018
+
1019
+ def remove_erhua(text):
1020
+ """
1021
+ Remove the erhua suffix '儿' from erhua words:
1022
+ 他女儿在那边儿 -> 他女儿在那边
1023
+ """
1024
+
1025
+ new_str = ""
1026
+ while re.search("儿", text):
1027
+ a = re.search("儿", text).span()
1028
+ remove_er_flag = 0
1029
+
1030
+ if ER_WHITELIST_PATTERN.search(text):
1031
+ b = ER_WHITELIST_PATTERN.search(text).span()
1032
+ if b[0] <= a[0]:
1033
+ remove_er_flag = 1
1034
+
1035
+ if remove_er_flag == 0:
1036
+ new_str = new_str + text[0 : a[0]]
1037
+ text = text[a[1] :]
1038
+ else:
1039
+ new_str = new_str + text[0 : b[1]]
1040
+ text = text[b[1] :]
1041
+
1042
+ text = new_str + text
1043
+ return text
1044
+
1045
+
1046
+ def remove_space(text):
1047
+ tokens = text.split()
1048
+ new = []
1049
+ for k, t in enumerate(tokens):
1050
+ if k != 0:
1051
+ if IN_EN_CHARS.get(tokens[k - 1][-1]) and IN_EN_CHARS.get(t[0]):
1052
+ new.append(" ")
1053
+ new.append(t)
1054
+ return "".join(new)
1055
+
1056
+
1057
+ class TextNorm:
1058
+ def __init__(
1059
+ self,
1060
+ to_banjiao: bool = False,
1061
+ to_upper: bool = False,
1062
+ to_lower: bool = False,
1063
+ remove_fillers: bool = False,
1064
+ remove_erhua: bool = False,
1065
+ check_chars: bool = False,
1066
+ remove_space: bool = False,
1067
+ cc_mode: str = "",
1068
+ ):
1069
+ self.to_banjiao = to_banjiao
1070
+ self.to_upper = to_upper
1071
+ self.to_lower = to_lower
1072
+ self.remove_fillers = remove_fillers
1073
+ self.remove_erhua = remove_erhua
1074
+ self.check_chars = check_chars
1075
+ self.remove_space = remove_space
1076
+
1077
+ self.cc = None
1078
+ if cc_mode:
1079
+ from opencc import OpenCC # Open Chinese Convert: pip install opencc
1080
+
1081
+ self.cc = OpenCC(cc_mode)
1082
+
1083
+ def __call__(self, text):
1084
+ if self.cc:
1085
+ text = self.cc.convert(text)
1086
+
1087
+ if self.to_banjiao:
1088
+ text = text.translate(QJ2BJ_TRANSFORM)
1089
+
1090
+ if self.to_upper:
1091
+ text = text.upper()
1092
+
1093
+ if self.to_lower:
1094
+ text = text.lower()
1095
+
1096
+ if self.remove_fillers:
1097
+ for c in FILLER_CHARS:
1098
+ text = text.replace(c, "")
1099
+
1100
+ if self.remove_erhua:
1101
+ text = remove_erhua(text)
1102
+
1103
+ text = normalize_nsw(text)
1104
+
1105
+ text = text.translate(PUNCS_TRANSFORM)
1106
+
1107
+ if self.check_chars:
1108
+ for c in text:
1109
+ if not IN_VALID_CHARS.get(c):
1110
+ print(f"WARNING: illegal char {c} in: {text}", file=sys.stderr)
1111
+ return ""
1112
+
1113
+ if self.remove_space:
1114
+ text = remove_space(text)
1115
+
1116
+ return text
1117
+
1118
+
1119
+ if __name__ == "__main__":
1120
+ p = argparse.ArgumentParser()
1121
+
1122
+ # normalizer options
1123
+ p.add_argument("--to_banjiao", action="store_true", help="convert quanjiao chars to banjiao")
1124
+ p.add_argument("--to_upper", action="store_true", help="convert to upper case")
1125
+ p.add_argument("--to_lower", action="store_true", help="convert to lower case")
1126
+ p.add_argument("--remove_fillers", action="store_true", help='remove filler chars such as "呃, 啊"')
1127
+ p.add_argument("--remove_erhua", action="store_true", help='remove erhua chars such as "他女儿在那边儿 -> 他女儿在那边"')
1128
+ p.add_argument("--check_chars", action="store_true", help="skip sentences containing illegal chars")
1129
+ p.add_argument("--remove_space", action="store_true", help="remove whitespace")
1130
+ p.add_argument(
1131
+ "--cc_mode", choices=["", "t2s", "s2t"], default="", help="convert between traditional to simplified"
1132
+ )
1133
+
1134
+ # I/O options
1135
+ p.add_argument("--log_interval", type=int, default=10000, help="log interval in number of processed lines")
1136
+ p.add_argument("--has_key", action="store_true", help="will be deprecated, set --format ark instead")
1137
+ p.add_argument("--format", type=str, choices=["txt", "ark", "tsv"], default="txt", help="input format")
1138
+ p.add_argument("ifile", help="input filename, assume utf-8 encoding")
1139
+ p.add_argument("ofile", help="output filename")
1140
+
1141
+ args = p.parse_args()
1142
+
1143
+ if args.has_key:
1144
+ args.format = "ark"
1145
+
1146
+ normalizer = TextNorm(
1147
+ to_banjiao=args.to_banjiao,
1148
+ to_upper=args.to_upper,
1149
+ to_lower=args.to_lower,
1150
+ remove_fillers=args.remove_fillers,
1151
+ remove_erhua=args.remove_erhua,
1152
+ check_chars=args.check_chars,
1153
+ remove_space=args.remove_space,
1154
+ cc_mode=args.cc_mode,
1155
+ )
1156
+
1168
+ ndone = 0
1169
+ with open(args.ifile, "r", encoding="utf-8") as istream, open(args.ofile, "w+", encoding="utf-8") as ostream:
1170
+ if args.format == "tsv":
1171
+ reader = csv.DictReader(istream, delimiter="\t")
1172
+ assert "TEXT" in reader.fieldnames
1173
+ print("\t".join(reader.fieldnames), file=ostream)
1174
+
1175
+ for item in reader:
1176
+ text = item["TEXT"]
1177
+
1178
+ if text:
1179
+ text = normalizer(text)
1180
+
1181
+ if text:
1182
+ item["TEXT"] = text
1183
+ print("\t".join([item[f] for f in reader.fieldnames]), file=ostream)
1184
+
1185
+ ndone += 1
1186
+ if ndone % args.log_interval == 0:
1187
+ print(f"text norm: {ndone} lines done.", file=sys.stderr, flush=True)
1188
+ else:
1189
+ for l in istream:
1190
+ key, text = "", ""
1191
+ if args.format == "ark": # KALDI archive, line format: "key text"
1192
+ cols = l.strip().split(maxsplit=1)
1193
+ key, text = cols[0], cols[1] if len(cols) == 2 else ""
1194
+ else:
1195
+ text = l.strip()
1196
+
1197
+ if text:
1198
+ text = normalizer(text)
1199
+
1200
+ if text:
1201
+ if args.format == "ark":
1202
+ print(key + "\t" + text, file=ostream)
1203
+ else:
1204
+ print(text, file=ostream)
1205
+
1206
+ ndone += 1
1207
+ if ndone % args.log_interval == 0:
1208
+ print(f"text norm: {ndone} lines done.", file=sys.stderr, flush=True)
1209
+ print(f"text norm: {ndone} lines done in total.", file=sys.stderr, flush=True)
music_dcae/__init__.py ADDED
File without changes
music_dcae/music_dcae_pipeline.py ADDED
@@ -0,0 +1,150 @@
1
+ import os
2
+ import torch
3
+ from diffusers import AutoencoderDC
4
+ import torchaudio
5
+ import torchvision.transforms as transforms
7
+ from diffusers.models.modeling_utils import ModelMixin
8
+ from diffusers.loaders import FromOriginalModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+
11
+
12
+ try:
13
+ from .music_vocoder import ADaMoSHiFiGANV1
14
+ except ImportError:
15
+ from music_vocoder import ADaMoSHiFiGANV1
16
+
17
+
18
+ root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
19
+ DEFAULT_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_dcae_f8c8")
20
+ VOCODER_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_vocoder")
21
+
22
+
23
+ class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
24
+ @register_to_config
25
+ def __init__(self, source_sample_rate=None, dcae_checkpoint_path=DEFAULT_PRETRAINED_PATH, vocoder_checkpoint_path=VOCODER_PRETRAINED_PATH):
26
+ super(MusicDCAE, self).__init__()
27
+
28
+ self.dcae = AutoencoderDC.from_pretrained(dcae_checkpoint_path)
29
+ self.vocoder = ADaMoSHiFiGANV1.from_pretrained(vocoder_checkpoint_path)
30
+
31
+ if source_sample_rate is None:
32
+ source_sample_rate = 48000
33
+
34
+ self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
35
+
36
+ self.transform = transforms.Compose([
37
+ transforms.Normalize(0.5, 0.5),
38
+ ])
39
+ self.min_mel_value = -11.0
40
+ self.max_mel_value = 3.0
41
+ self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
42
+ self.mel_chunk_size = 1024
43
+ self.time_dimention_multiple = 8
44
+ self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
45
+ self.scale_factor = 0.1786
46
+ self.shift_factor = -1.9091
47
+
48
+ def load_audio(self, audio_path):
49
+ audio, sr = torchaudio.load(audio_path)
50
+ return audio, sr
51
+
52
+ def forward_mel(self, audios):
53
+ mels = []
54
+ for i in range(len(audios)):
55
+ image = self.vocoder.mel_transform(audios[i])
56
+ mels.append(image)
57
+ mels = torch.stack(mels)
58
+ return mels
59
+
60
+ @torch.no_grad()
61
+ def encode(self, audios, audio_lengths=None, sr=None):
62
+ if audio_lengths is None:
63
+ audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
64
+ audio_lengths = audio_lengths.to(audios.device)
65
+
66
+ # audios: N x 2 x T, 48kHz
67
+ device = audios.device
68
+ dtype = audios.dtype
69
+
70
+ if sr is None:
71
+ sr = 48000
72
+ resampler = self.resampler
73
+ else:
74
+ resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)
75
+
76
+ audio = resampler(audios)
77
+
78
+ max_audio_len = audio.shape[-1]
79
+ if max_audio_len % (8 * 512) != 0:
80
+ audio = torch.nn.functional.pad(audio, (0, 8 * 512 - max_audio_len % (8 * 512)))
81
+
82
+ mels = self.forward_mel(audio)
83
+ mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
84
+ mels = self.transform(mels)
85
+ latents = []
86
+ for mel in mels:
87
+ latent = self.dcae.encoder(mel.unsqueeze(0))
88
+ latents.append(latent)
89
+ latents = torch.cat(latents, dim=0)
90
+ latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
91
+ latents = (latents - self.shift_factor) * self.scale_factor
92
+ return latents, latent_lengths
93
+
94
+ @torch.no_grad()
95
+ def decode(self, latents, audio_lengths=None, sr=None):
96
+ latents = latents / self.scale_factor + self.shift_factor
97
+
98
+ mels = []
99
+
100
+ for latent in latents:
101
+ mel = self.dcae.decoder(latent.unsqueeze(0))
102
+ mels.append(mel)
103
+ mels = torch.cat(mels, dim=0)
104
+
105
+ mels = mels * 0.5 + 0.5
106
+ mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
107
+ bsz, channels, num_mel, mel_width = mels.shape
108
+ pred_wavs = []
109
+ for i in range(bsz):
110
+ mel = mels[i]
111
+ wav = self.vocoder.decode(mel).squeeze(1)
112
+ pred_wavs.append(wav)
113
+
114
+ pred_wavs = torch.stack(pred_wavs)
115
+
116
+ if sr is not None:
117
+ resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
118
+ pred_wavs = [resampler(wav) for wav in pred_wavs]
119
+ else:
120
+ sr = 44100
121
+ if audio_lengths is not None:
122
+ pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
123
+ return sr, pred_wavs
124
+
125
+ def forward(self, audios, audio_lengths=None, sr=None):
126
+ latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
127
+ sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
128
+ return sr, pred_wavs, latents, latent_lengths
129
+
130
+
131
+ if __name__ == "__main__":
132
+
133
+ audio, sr = torchaudio.load("test.wav")
134
+ audio_lengths = torch.tensor([audio.shape[1]])
135
+ audios = audio.unsqueeze(0)
136
+
137
+ # test encode only
138
+ model = MusicDCAE()
139
+ # latents, latent_lengths = model.encode(audios, audio_lengths)
140
+ # print("latents shape: ", latents.shape)
141
+ # print("latent_lengths: ", latent_lengths)
142
+
143
+ # test encode and decode
144
+ sr, pred_wavs, latents, latent_lengths = model(audios, audio_lengths, sr)
145
+ print("reconstructed wavs: ", pred_wavs[0].shape)
146
+ print("latents shape: ", latents.shape)
147
+ print("latent_lengths: ", latent_lengths)
148
+ print("sr: ", sr)
149
+ torchaudio.save("test_reconstructed.flac", pred_wavs[0], sr)
150
+ print("test_reconstructed.flac")
music_dcae/music_log_mel.py ADDED
@@ -0,0 +1,107 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch import Tensor
4
+ from torchaudio.transforms import MelScale
5
+
6
+
7
+ class LinearSpectrogram(nn.Module):
8
+ def __init__(
9
+ self,
10
+ n_fft=2048,
11
+ win_length=2048,
12
+ hop_length=512,
13
+ center=False,
14
+ mode="pow2_sqrt",
15
+ ):
16
+ super().__init__()
17
+
18
+ self.n_fft = n_fft
19
+ self.win_length = win_length
20
+ self.hop_length = hop_length
21
+ self.center = center
22
+ self.mode = mode
23
+
24
+ self.register_buffer("window", torch.hann_window(win_length))
25
+
26
+ def forward(self, y: Tensor) -> Tensor:
27
+ if y.ndim == 3:
28
+ y = y.squeeze(1)
29
+
30
+ y = torch.nn.functional.pad(
31
+ y.unsqueeze(1),
32
+ (
33
+ (self.win_length - self.hop_length) // 2,
34
+ (self.win_length - self.hop_length + 1) // 2,
35
+ ),
36
+ mode="reflect",
37
+ ).squeeze(1)
38
+ dtype = y.dtype
39
+ spec = torch.stft(
40
+ y.float(),
41
+ self.n_fft,
42
+ hop_length=self.hop_length,
43
+ win_length=self.win_length,
44
+ window=self.window,
45
+ center=self.center,
46
+ pad_mode="reflect",
47
+ normalized=False,
48
+ onesided=True,
49
+ return_complex=True,
50
+ )
51
+ spec = torch.view_as_real(spec)
52
+
53
+ if self.mode == "pow2_sqrt":
54
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
55
+ spec = spec.to(dtype)
56
+ return spec
57
+
58
+
59
+ class LogMelSpectrogram(nn.Module):
60
+ def __init__(
61
+ self,
62
+ sample_rate=44100,
63
+ n_fft=2048,
64
+ win_length=2048,
65
+ hop_length=512,
66
+ n_mels=128,
67
+ center=False,
68
+ f_min=0.0,
69
+ f_max=None,
70
+ ):
71
+ super().__init__()
72
+
73
+ self.sample_rate = sample_rate
74
+ self.n_fft = n_fft
75
+ self.win_length = win_length
76
+ self.hop_length = hop_length
77
+ self.center = center
78
+ self.n_mels = n_mels
79
+ self.f_min = f_min
80
+ self.f_max = f_max or sample_rate // 2
81
+
82
+ self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
83
+ self.mel_scale = MelScale(
84
+ self.n_mels,
85
+ self.sample_rate,
86
+ self.f_min,
87
+ self.f_max,
88
+ self.n_fft // 2 + 1,
89
+ "slaney",
90
+ "slaney",
91
+ )
92
+
93
+ def compress(self, x: Tensor) -> Tensor:
94
+ return torch.log(torch.clamp(x, min=1e-5))
95
+
96
+ def decompress(self, x: Tensor) -> Tensor:
97
+ return torch.exp(x)
98
+
99
+ def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
100
+ linear = self.spectrogram(x)
101
+ x = self.mel_scale(linear)
102
+ x = self.compress(x)
103
+ # print(x.shape)
104
+ if return_linear:
105
+ return x, self.compress(linear)
106
+
107
+ return x
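LogMelSpectrogram is the 44.1 kHz log-mel front-end shared by the DCAE and the vocoder. A stand-alone sketch with the defaults defined above (an illustration, not part of the commit; "song.wav" is a hypothetical file):

    import torchaudio
    from music_dcae.music_log_mel import LogMelSpectrogram

    mel_fn = LogMelSpectrogram(sample_rate=44100, n_fft=2048,
                               win_length=2048, hop_length=512, n_mels=128)

    wav, sr = torchaudio.load("song.wav")
    wav = torchaudio.functional.resample(wav, sr, 44100)   # match the 44.1 kHz front-end
    mel = mel_fn(wav)                                       # (channels, 128, ~T/512), log-compressed
    print(mel.shape)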
music_dcae/music_vocoder.py ADDED
@@ -0,0 +1,576 @@
1
+ import librosa
2
+ import torch
3
+ from torch import nn
4
+
5
+ from functools import partial
6
+ from math import prod
7
+ from typing import Callable, Tuple, List
8
+
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+ from torch.nn import Conv1d
12
+ from torch.nn.utils import weight_norm
13
+ from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+ from diffusers.loaders import FromOriginalModelMixin
16
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
17
+
18
+
19
+ try:
20
+ from music_log_mel import LogMelSpectrogram
21
+ except ImportError:
22
+ from .music_log_mel import LogMelSpectrogram
23
+
24
+
25
+ def drop_path(
26
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
27
+ ):
28
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
29
+
30
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
31
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
32
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
33
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
34
+ 'survival rate' as the argument.
35
+
36
+ """ # noqa: E501
37
+
38
+ if drop_prob == 0.0 or not training:
39
+ return x
40
+ keep_prob = 1 - drop_prob
41
+ shape = (x.shape[0],) + (1,) * (
42
+ x.ndim - 1
43
+ ) # work with diff dim tensors, not just 2D ConvNets
44
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
45
+ if keep_prob > 0.0 and scale_by_keep:
46
+ random_tensor.div_(keep_prob)
47
+ return x * random_tensor
48
+
49
+
50
+ class DropPath(nn.Module):
51
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
52
+
53
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
54
+ super(DropPath, self).__init__()
55
+ self.drop_prob = drop_prob
56
+ self.scale_by_keep = scale_by_keep
57
+
58
+ def forward(self, x):
59
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
60
+
61
+ def extra_repr(self):
62
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
63
+
64
+
65
+ class LayerNorm(nn.Module):
66
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
67
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
68
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
69
+ with shape (batch_size, channels, height, width).
70
+ """ # noqa: E501
71
+
72
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
73
+ super().__init__()
74
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
75
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
76
+ self.eps = eps
77
+ self.data_format = data_format
78
+ if self.data_format not in ["channels_last", "channels_first"]:
79
+ raise NotImplementedError
80
+ self.normalized_shape = (normalized_shape,)
81
+
82
+ def forward(self, x):
83
+ if self.data_format == "channels_last":
84
+ return F.layer_norm(
85
+ x, self.normalized_shape, self.weight, self.bias, self.eps
86
+ )
87
+ elif self.data_format == "channels_first":
88
+ u = x.mean(1, keepdim=True)
89
+ s = (x - u).pow(2).mean(1, keepdim=True)
90
+ x = (x - u) / torch.sqrt(s + self.eps)
91
+ x = self.weight[:, None] * x + self.bias[:, None]
92
+ return x
93
+
94
+
95
+ class ConvNeXtBlock(nn.Module):
96
+ r"""ConvNeXt Block. There are two equivalent implementations:
97
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
98
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
99
+ We use (2) as we find it slightly faster in PyTorch
100
+
101
+ Args:
102
+ dim (int): Number of input channels.
103
+ drop_path (float): Stochastic depth rate. Default: 0.0
104
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
105
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
106
+ kernel_size (int): Kernel size for depthwise conv. Default: 7.
107
+ dilation (int): Dilation for depthwise conv. Default: 1.
108
+ """ # noqa: E501
109
+
110
+ def __init__(
111
+ self,
112
+ dim: int,
113
+ drop_path: float = 0.0,
114
+ layer_scale_init_value: float = 1e-6,
115
+ mlp_ratio: float = 4.0,
116
+ kernel_size: int = 7,
117
+ dilation: int = 1,
118
+ ):
119
+ super().__init__()
120
+
121
+ self.dwconv = nn.Conv1d(
122
+ dim,
123
+ dim,
124
+ kernel_size=kernel_size,
125
+ padding=int(dilation * (kernel_size - 1) / 2),
126
+ groups=dim,
127
+ ) # depthwise conv
128
+ self.norm = LayerNorm(dim, eps=1e-6)
129
+ self.pwconv1 = nn.Linear(
130
+ dim, int(mlp_ratio * dim)
131
+ ) # pointwise/1x1 convs, implemented with linear layers
132
+ self.act = nn.GELU()
133
+ self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
134
+ self.gamma = (
135
+ nn.Parameter(layer_scale_init_value *
136
+ torch.ones((dim)), requires_grad=True)
137
+ if layer_scale_init_value > 0
138
+ else None
139
+ )
140
+ self.drop_path = DropPath(
141
+ drop_path) if drop_path > 0.0 else nn.Identity()
142
+
143
+ def forward(self, x, apply_residual: bool = True):
144
+ input = x
145
+
146
+ x = self.dwconv(x)
147
+ x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
148
+ x = self.norm(x)
149
+ x = self.pwconv1(x)
150
+ x = self.act(x)
151
+ x = self.pwconv2(x)
152
+
153
+ if self.gamma is not None:
154
+ x = self.gamma * x
155
+
156
+ x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
157
+ x = self.drop_path(x)
158
+
159
+ if apply_residual:
160
+ x = input + x
161
+
162
+ return x
163
+
164
+
165
+ class ParallelConvNeXtBlock(nn.Module):
166
+ def __init__(self, kernel_sizes: List[int], *args, **kwargs):
167
+ super().__init__()
168
+ self.blocks = nn.ModuleList(
169
+ [
170
+ ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
171
+ for kernel_size in kernel_sizes
172
+ ]
173
+ )
174
+
175
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
176
+ return torch.stack(
177
+ [block(x, apply_residual=False) for block in self.blocks] + [x],
178
+ dim=1,
179
+ ).sum(dim=1)
180
+
181
+
182
+ class ConvNeXtEncoder(nn.Module):
183
+ def __init__(
184
+ self,
185
+ input_channels=3,
186
+ depths=[3, 3, 9, 3],
187
+ dims=[96, 192, 384, 768],
188
+ drop_path_rate=0.0,
189
+ layer_scale_init_value=1e-6,
190
+ kernel_sizes: Tuple[int] = (7,),
191
+ ):
192
+ super().__init__()
193
+ assert len(depths) == len(dims)
194
+
195
+ self.channel_layers = nn.ModuleList()
196
+ stem = nn.Sequential(
197
+ nn.Conv1d(
198
+ input_channels,
199
+ dims[0],
200
+ kernel_size=7,
201
+ padding=3,
202
+ padding_mode="replicate",
203
+ ),
204
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
205
+ )
206
+ self.channel_layers.append(stem)
207
+
208
+ for i in range(len(depths) - 1):
209
+ mid_layer = nn.Sequential(
210
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
211
+ nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
212
+ )
213
+ self.channel_layers.append(mid_layer)
214
+
215
+ block_fn = (
216
+ partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
217
+ if len(kernel_sizes) == 1
218
+ else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
219
+ )
220
+
221
+ self.stages = nn.ModuleList()
222
+ drop_path_rates = [
223
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
224
+ ]
225
+
226
+ cur = 0
227
+ for i in range(len(depths)):
228
+ stage = nn.Sequential(
229
+ *[
230
+ block_fn(
231
+ dim=dims[i],
232
+ drop_path=drop_path_rates[cur + j],
233
+ layer_scale_init_value=layer_scale_init_value,
234
+ )
235
+ for j in range(depths[i])
236
+ ]
237
+ )
238
+ self.stages.append(stage)
239
+ cur += depths[i]
240
+
241
+ self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
242
+ self.apply(self._init_weights)
243
+
244
+ def _init_weights(self, m):
245
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
246
+ nn.init.trunc_normal_(m.weight, std=0.02)
247
+ nn.init.constant_(m.bias, 0)
248
+
249
+ def forward(
250
+ self,
251
+ x: torch.Tensor,
252
+ ) -> torch.Tensor:
253
+ for channel_layer, stage in zip(self.channel_layers, self.stages):
254
+ x = channel_layer(x)
255
+ x = stage(x)
256
+
257
+ return self.norm(x)
258
+
259
+
260
+ def init_weights(m, mean=0.0, std=0.01):
261
+ classname = m.__class__.__name__
262
+ if classname.find("Conv") != -1:
263
+ m.weight.data.normal_(mean, std)
264
+
265
+
266
+ def get_padding(kernel_size, dilation=1):
267
+ return (kernel_size * dilation - dilation) // 2
268
+
269
+
270
+ class ResBlock1(torch.nn.Module):
271
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
272
+ super().__init__()
273
+
274
+ self.convs1 = nn.ModuleList(
275
+ [
276
+ weight_norm(
277
+ Conv1d(
278
+ channels,
279
+ channels,
280
+ kernel_size,
281
+ 1,
282
+ dilation=dilation[0],
283
+ padding=get_padding(kernel_size, dilation[0]),
284
+ )
285
+ ),
286
+ weight_norm(
287
+ Conv1d(
288
+ channels,
289
+ channels,
290
+ kernel_size,
291
+ 1,
292
+ dilation=dilation[1],
293
+ padding=get_padding(kernel_size, dilation[1]),
294
+ )
295
+ ),
296
+ weight_norm(
297
+ Conv1d(
298
+ channels,
299
+ channels,
300
+ kernel_size,
301
+ 1,
302
+ dilation=dilation[2],
303
+ padding=get_padding(kernel_size, dilation[2]),
304
+ )
305
+ ),
306
+ ]
307
+ )
308
+ self.convs1.apply(init_weights)
309
+
310
+ self.convs2 = nn.ModuleList(
311
+ [
312
+ weight_norm(
313
+ Conv1d(
314
+ channels,
315
+ channels,
316
+ kernel_size,
317
+ 1,
318
+ dilation=1,
319
+ padding=get_padding(kernel_size, 1),
320
+ )
321
+ ),
322
+ weight_norm(
323
+ Conv1d(
324
+ channels,
325
+ channels,
326
+ kernel_size,
327
+ 1,
328
+ dilation=1,
329
+ padding=get_padding(kernel_size, 1),
330
+ )
331
+ ),
332
+ weight_norm(
333
+ Conv1d(
334
+ channels,
335
+ channels,
336
+ kernel_size,
337
+ 1,
338
+ dilation=1,
339
+ padding=get_padding(kernel_size, 1),
340
+ )
341
+ ),
342
+ ]
343
+ )
344
+ self.convs2.apply(init_weights)
345
+
346
+ def forward(self, x):
347
+ for c1, c2 in zip(self.convs1, self.convs2):
348
+ xt = F.silu(x)
349
+ xt = c1(xt)
350
+ xt = F.silu(xt)
351
+ xt = c2(xt)
352
+ x = xt + x
353
+ return x
354
+
355
+ def remove_weight_norm(self):
356
+ for conv in self.convs1:
357
+ remove_weight_norm(conv)
358
+ for conv in self.convs2:
359
+ remove_weight_norm(conv)
360
+
361
+
362
+ class HiFiGANGenerator(nn.Module):
363
+ def __init__(
364
+ self,
365
+ *,
366
+ hop_length: int = 512,
367
+ upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
368
+ upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
369
+ resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
370
+ resblock_dilation_sizes: Tuple[Tuple[int]] = (
371
+ (1, 3, 5), (1, 3, 5), (1, 3, 5)),
372
+ num_mels: int = 128,
373
+ upsample_initial_channel: int = 512,
374
+ use_template: bool = True,
375
+ pre_conv_kernel_size: int = 7,
376
+ post_conv_kernel_size: int = 7,
377
+ post_activation: Callable = partial(nn.SiLU, inplace=True),
378
+ ):
379
+ super().__init__()
380
+
381
+ assert (
382
+ prod(upsample_rates) == hop_length
383
+ ), f"hop_length must be {prod(upsample_rates)}"
384
+
385
+ self.conv_pre = weight_norm(
386
+ nn.Conv1d(
387
+ num_mels,
388
+ upsample_initial_channel,
389
+ pre_conv_kernel_size,
390
+ 1,
391
+ padding=get_padding(pre_conv_kernel_size),
392
+ )
393
+ )
394
+
395
+ self.num_upsamples = len(upsample_rates)
396
+ self.num_kernels = len(resblock_kernel_sizes)
397
+
398
+ self.noise_convs = nn.ModuleList()
399
+ self.use_template = use_template
400
+ self.ups = nn.ModuleList()
401
+
402
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
403
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
404
+ self.ups.append(
405
+ weight_norm(
406
+ nn.ConvTranspose1d(
407
+ upsample_initial_channel // (2**i),
408
+ upsample_initial_channel // (2 ** (i + 1)),
409
+ k,
410
+ u,
411
+ padding=(k - u) // 2,
412
+ )
413
+ )
414
+ )
415
+
416
+ if not use_template:
417
+ continue
418
+
419
+ if i + 1 < len(upsample_rates):
420
+ stride_f0 = np.prod(upsample_rates[i + 1:])
421
+ self.noise_convs.append(
422
+ Conv1d(
423
+ 1,
424
+ c_cur,
425
+ kernel_size=stride_f0 * 2,
426
+ stride=stride_f0,
427
+ padding=stride_f0 // 2,
428
+ )
429
+ )
430
+ else:
431
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
432
+
433
+ self.resblocks = nn.ModuleList()
434
+ for i in range(len(self.ups)):
435
+ ch = upsample_initial_channel // (2 ** (i + 1))
436
+ for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
437
+ self.resblocks.append(ResBlock1(ch, k, d))
438
+
439
+ self.activation_post = post_activation()
440
+ self.conv_post = weight_norm(
441
+ nn.Conv1d(
442
+ ch,
443
+ 1,
444
+ post_conv_kernel_size,
445
+ 1,
446
+ padding=get_padding(post_conv_kernel_size),
447
+ )
448
+ )
449
+ self.ups.apply(init_weights)
450
+ self.conv_post.apply(init_weights)
451
+
452
+ def forward(self, x, template=None):
453
+ x = self.conv_pre(x)
454
+
455
+ for i in range(self.num_upsamples):
456
+ x = F.silu(x, inplace=True)
457
+ x = self.ups[i](x)
458
+
459
+ if self.use_template:
460
+ x = x + self.noise_convs[i](template)
461
+
462
+ xs = None
463
+
464
+ for j in range(self.num_kernels):
465
+ if xs is None:
466
+ xs = self.resblocks[i * self.num_kernels + j](x)
467
+ else:
468
+ xs += self.resblocks[i * self.num_kernels + j](x)
469
+
470
+ x = xs / self.num_kernels
471
+
472
+ x = self.activation_post(x)
473
+ x = self.conv_post(x)
474
+ x = torch.tanh(x)
475
+
476
+ return x
477
+
478
+ def remove_weight_norm(self):
479
+ for up in self.ups:
480
+ remove_weight_norm(up)
481
+ for block in self.resblocks:
482
+ block.remove_weight_norm()
483
+ remove_weight_norm(self.conv_pre)
484
+ remove_weight_norm(self.conv_post)
485
+
486
+
487
+ class ADaMoSHiFiGANV1(ModelMixin, ConfigMixin, FromOriginalModelMixin):
488
+
489
+ @register_to_config
490
+ def __init__(
491
+ self,
492
+ input_channels: int = 128,
493
+ depths: List[int] = [3, 3, 9, 3],
494
+ dims: List[int] = [128, 256, 384, 512],
495
+ drop_path_rate: float = 0.0,
496
+ kernel_sizes: Tuple[int] = (7,),
497
+ upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
498
+ upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
499
+ resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
500
+ resblock_dilation_sizes: Tuple[Tuple[int]] = (
501
+ (1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
502
+ num_mels: int = 512,
503
+ upsample_initial_channel: int = 1024,
504
+ use_template: bool = False,
505
+ pre_conv_kernel_size: int = 13,
506
+ post_conv_kernel_size: int = 13,
507
+ sampling_rate: int = 44100,
508
+ n_fft: int = 2048,
509
+ win_length: int = 2048,
510
+ hop_length: int = 512,
511
+ f_min: int = 40,
512
+ f_max: int = 16000,
513
+ n_mels: int = 128,
514
+ ):
515
+ super().__init__()
516
+
517
+ self.backbone = ConvNeXtEncoder(
518
+ input_channels=input_channels,
519
+ depths=depths,
520
+ dims=dims,
521
+ drop_path_rate=drop_path_rate,
522
+ kernel_sizes=kernel_sizes,
523
+ )
524
+
525
+ self.head = HiFiGANGenerator(
526
+ hop_length=hop_length,
527
+ upsample_rates=upsample_rates,
528
+ upsample_kernel_sizes=upsample_kernel_sizes,
529
+ resblock_kernel_sizes=resblock_kernel_sizes,
530
+ resblock_dilation_sizes=resblock_dilation_sizes,
531
+ num_mels=num_mels,
532
+ upsample_initial_channel=upsample_initial_channel,
533
+ use_template=use_template,
534
+ pre_conv_kernel_size=pre_conv_kernel_size,
535
+ post_conv_kernel_size=post_conv_kernel_size,
536
+ )
537
+ self.sampling_rate = sampling_rate
538
+ self.mel_transform = LogMelSpectrogram(
539
+ sample_rate=sampling_rate,
540
+ n_fft=n_fft,
541
+ win_length=win_length,
542
+ hop_length=hop_length,
543
+ f_min=f_min,
544
+ f_max=f_max,
545
+ n_mels=n_mels,
546
+ )
547
+ self.eval()
548
+
549
+ @torch.no_grad()
550
+ def decode(self, mel):
551
+ y = self.backbone(mel)
552
+ y = self.head(y)
553
+ return y
554
+
555
+ @torch.no_grad()
556
+ def encode(self, x):
557
+ return self.mel_transform(x)
558
+
559
+ def forward(self, mel):
560
+ y = self.backbone(mel)
561
+ y = self.head(y)
562
+ return y
563
+
564
+
565
+ if __name__ == "__main__":
566
+ import soundfile as sf
567
+
568
+ x = "test_audio.flac"
569
+ model = ADaMoSHiFiGANV1.from_pretrained("./checkpoints/music_vocoder", local_files_only=True)
570
+
571
+ wav, sr = librosa.load(x, sr=44100, mono=True)
572
+ wav = torch.from_numpy(wav).float()[None]
573
+ mel = model.encode(wav)
574
+
575
+ wav = model.decode(mel)[0].mT
576
+ sf.write("test_audio_vocoder_rec.flac", wav.cpu().numpy(), 44100)
packages.txt ADDED
@@ -0,0 +1 @@
1
+ ffmpeg
pipeline_ace_step.py ADDED
@@ -0,0 +1,735 @@
1
+ import random
2
+ import time
3
+ import os
4
+ import re
5
+ import glob
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from loguru import logger
10
+ from tqdm import tqdm
11
+ import json
12
+ import math
13
+ from huggingface_hub import hf_hub_download
14
+
15
+ # from diffusers.pipelines.pipeline_utils import DiffusionPipeline
16
+ from schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
17
+ from schedulers.scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
18
+ from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
19
+ from diffusers.utils.torch_utils import randn_tensor
20
+ from transformers import UMT5EncoderModel, AutoTokenizer
21
+
22
+ from language_segmentation import LangSegment
23
+ from music_dcae.music_dcae_pipeline import MusicDCAE
24
+ from models.ace_step_transformer import ACEStepTransformer2DModel
25
+ from models.lyrics_utils.lyric_tokenizer import VoiceBpeTokenizer
26
+ from apg_guidance import apg_forward, MomentumBuffer, cfg_forward, cfg_zero_star, cfg_double_condition_forward
27
+ import torchaudio
28
+
29
+
30
+ SUPPORT_LANGUAGES = {
31
+ "en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
32
+ "pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
33
+ "nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
34
+ "ko": 6152, "hi": 6680
35
+ }
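+ # language code -> language token id; these ids are presumably the language-tag entries in the lyric tokenizer vocabulary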
36
+
37
+ structure_pattern = re.compile(r"\[.*?\]")
38
+
39
+
40
+ def ensure_directory_exists(directory):
41
+ directory = str(directory)
42
+ if not os.path.exists(directory):
43
+ os.makedirs(directory)
44
+
45
+
46
+ REPO_ID = "ACE-Step/ACE-Step-v1-3.5B"
47
+
48
+
49
+ class ACEStepPipeline:
50
+
51
+ def __init__(self, checkpoint_dir=None, device_id=0, dtype="bfloat16", text_encoder_checkpoint_path=None, persistent_storage_path=None, **kwargs):
52
+ if checkpoint_dir is None:
53
+ if persistent_storage_path is None:
54
+ checkpoint_dir = os.path.join(os.path.dirname(__file__), "checkpoints")
55
+ else:
56
+ checkpoint_dir = os.path.join(persistent_storage_path, "checkpoints")
57
+ ensure_directory_exists(checkpoint_dir)
58
+
59
+ self.checkpoint_dir = checkpoint_dir
60
+ device = torch.device(f"cuda:{device_id}") if torch.cuda.is_available() else torch.device("cpu")
61
+ self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32
62
+ self.device = device
63
+ self.loaded = False
64
+
65
+ def load_checkpoint(self, checkpoint_dir=None):
66
+ device = self.device
67
+
68
+ dcae_model_path = os.path.join(checkpoint_dir, "music_dcae_f8c8")
69
+ vocoder_model_path = os.path.join(checkpoint_dir, "music_vocoder")
70
+ ace_step_model_path = os.path.join(checkpoint_dir, "ace_step_transformer")
71
+ text_encoder_model_path = os.path.join(checkpoint_dir, "umt5-base")
72
+
73
+ files_exist = (
74
+ os.path.exists(os.path.join(dcae_model_path, "config.json")) and
75
+ os.path.exists(os.path.join(dcae_model_path, "diffusion_pytorch_model.safetensors")) and
76
+ os.path.exists(os.path.join(vocoder_model_path, "config.json")) and
77
+ os.path.exists(os.path.join(vocoder_model_path, "diffusion_pytorch_model.safetensors")) and
78
+ os.path.exists(os.path.join(ace_step_model_path, "config.json")) and
79
+ os.path.exists(os.path.join(ace_step_model_path, "diffusion_pytorch_model.safetensors")) and
80
+ os.path.exists(os.path.join(text_encoder_model_path, "config.json")) and
81
+ os.path.exists(os.path.join(text_encoder_model_path, "model.safetensors")) and
82
+ os.path.exists(os.path.join(text_encoder_model_path, "special_tokens_map.json")) and
83
+ os.path.exists(os.path.join(text_encoder_model_path, "tokenizer_config.json")) and
84
+ os.path.exists(os.path.join(text_encoder_model_path, "tokenizer.json"))
85
+ )
86
+
87
+ if not files_exist:
88
+ logger.info(f"Checkpoint directory {checkpoint_dir} is not complete, downloading from Hugging Face Hub")
89
+
90
+ # download music dcae model
91
+ os.makedirs(dcae_model_path, exist_ok=True)
92
+ hf_hub_download(repo_id=REPO_ID, subfolder="music_dcae_f8c8",
93
+ filename="config.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
94
+ hf_hub_download(repo_id=REPO_ID, subfolder="music_dcae_f8c8",
95
+ filename="diffusion_pytorch_model.safetensors", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
96
+
97
+ # download vocoder model
98
+ os.makedirs(vocoder_model_path, exist_ok=True)
99
+ hf_hub_download(repo_id=REPO_ID, subfolder="music_vocoder",
100
+ filename="config.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
101
+ hf_hub_download(repo_id=REPO_ID, subfolder="music_vocoder",
102
+ filename="diffusion_pytorch_model.safetensors", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
103
+
104
+ # download ace_step transformer model
105
+ os.makedirs(ace_step_model_path, exist_ok=True)
106
+ hf_hub_download(repo_id=REPO_ID, subfolder="ace_step_transformer",
107
+ filename="config.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
108
+ hf_hub_download(repo_id=REPO_ID, subfolder="ace_step_transformer",
109
+ filename="diffusion_pytorch_model.safetensors", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
110
+
111
+ # download text encoder model
112
+ os.makedirs(text_encoder_model_path, exist_ok=True)
113
+ hf_hub_download(repo_id=REPO_ID, subfolder="umt5-base",
114
+ filename="config.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
115
+ hf_hub_download(repo_id=REPO_ID, subfolder="umt5-base",
116
+ filename="model.safetensors", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
117
+ hf_hub_download(repo_id=REPO_ID, subfolder="umt5-base",
118
+ filename="special_tokens_map.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
119
+ hf_hub_download(repo_id=REPO_ID, subfolder="umt5-base",
120
+ filename="tokenizer_config.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
121
+ hf_hub_download(repo_id=REPO_ID, subfolder="umt5-base",
122
+ filename="tokenizer.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
123
+
124
+ logger.info("Models downloaded")
125
+
126
+ dcae_checkpoint_path = dcae_model_path
127
+ vocoder_checkpoint_path = vocoder_model_path
128
+ ace_step_checkpoint_path = ace_step_model_path
129
+ text_encoder_checkpoint_path = text_encoder_model_path
130
+
131
+ self.music_dcae = MusicDCAE(dcae_checkpoint_path=dcae_checkpoint_path, vocoder_checkpoint_path=vocoder_checkpoint_path)
132
+ self.music_dcae.to(device).eval().to(self.dtype)
133
+
134
+ self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path)
135
+ self.ace_step_transformer.to(device).eval().to(self.dtype)
136
+
137
+ lang_segment = LangSegment()
138
+
139
+ lang_segment.setfilters([
140
+ 'af', 'am', 'an', 'ar', 'as', 'az', 'be', 'bg', 'bn', 'br', 'bs', 'ca', 'cs', 'cy', 'da', 'de', 'dz', 'el',
141
+ 'en', 'eo', 'es', 'et', 'eu', 'fa', 'fi', 'fo', 'fr', 'ga', 'gl', 'gu', 'he', 'hi', 'hr', 'ht', 'hu', 'hy',
142
+ 'id', 'is', 'it', 'ja', 'jv', 'ka', 'kk', 'km', 'kn', 'ko', 'ku', 'ky', 'la', 'lb', 'lo', 'lt', 'lv', 'mg',
143
+ 'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'nb', 'ne', 'nl', 'nn', 'no', 'oc', 'or', 'pa', 'pl', 'ps', 'pt', 'qu',
144
+ 'ro', 'ru', 'rw', 'se', 'si', 'sk', 'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'th', 'tl', 'tr', 'ug', 'uk',
145
+ 'ur', 'vi', 'vo', 'wa', 'xh', 'zh', 'zu'
146
+ ])
147
+ self.lang_segment = lang_segment
148
+ self.lyric_tokenizer = VoiceBpeTokenizer()
149
+ text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path).eval()
150
+ text_encoder_model = text_encoder_model.to(device).to(self.dtype)
151
+ text_encoder_model.requires_grad_(False)
152
+ self.text_encoder_model = text_encoder_model
153
+ self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_checkpoint_path)
154
+ self.loaded = True
155
+
156
+ # compile
157
+ # self.music_dcae = torch.compile(self.music_dcae)
158
+ # self.ace_step_transformer = torch.compile(self.ace_step_transformer)
159
+ # self.text_encoder_model = torch.compile(self.text_encoder_model)
160
+
161
+ def get_text_embeddings(self, texts, device, text_max_length=256):
162
+ inputs = self.text_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=text_max_length)
163
+ inputs = {key: value.to(device) for key, value in inputs.items()}
164
+ if self.text_encoder_model.device != device:
165
+ self.text_encoder_model.to(device)
166
+ with torch.no_grad():
167
+ outputs = self.text_encoder_model(**inputs)
168
+ last_hidden_states = outputs.last_hidden_state
169
+ attention_mask = inputs["attention_mask"]
170
+ return last_hidden_states, attention_mask
171
+
172
+ def get_text_embeddings_null(self, texts, device, text_max_length=256, tau=0.01, l_min=8, l_max=10):
173
+ inputs = self.text_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=text_max_length)
174
+ inputs = {key: value.to(device) for key, value in inputs.items()}
175
+ if self.text_encoder_model.device != device:
176
+ self.text_encoder_model.to(device)
177
+
178
+ def forward_with_temperature(inputs, tau=0.01, l_min=8, l_max=10):
179
+ handlers = []
180
+
181
+ def hook(module, input, output):
182
+ output[:] *= tau
183
+ return output
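+ # scaling the query projections by tau (default 0.01) in layers l_min..l_max drives the attention logits toward zero,
+ # i.e. near-uniform attention; the resulting weakened text embedding is used as the "null" condition for ERG-style guidance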
184
+
185
+ for i in range(l_min, l_max):
186
+ handler = self.text_encoder_model.encoder.block[i].layer[0].SelfAttention.q.register_forward_hook(hook)
187
+ handlers.append(handler)
188
+
189
+ with torch.no_grad():
190
+ outputs = self.text_encoder_model(**inputs)
191
+ last_hidden_states = outputs.last_hidden_state
192
+
193
+ for hook in handlers:
194
+ hook.remove()
195
+
196
+ return last_hidden_states
197
+
198
+ last_hidden_states = forward_with_temperature(inputs, tau, l_min, l_max)
199
+ return last_hidden_states
200
+
201
+ def set_seeds(self, batch_size, manual_seeds=None):
202
+ seeds = None
203
+ if manual_seeds is not None:
204
+ if isinstance(manual_seeds, str):
205
+ if "," in manual_seeds:
206
+ seeds = list(map(int, manual_seeds.split(",")))
207
+ elif manual_seeds.isdigit():
208
+ seeds = int(manual_seeds)
209
+
210
+ random_generators = [torch.Generator(device=self.device) for _ in range(batch_size)]
211
+ actual_seeds = []
212
+ for i in range(batch_size):
213
+ seed = None
214
+ if seeds is None:
215
+ seed = torch.randint(0, 2**32, (1,)).item()
216
+ if isinstance(seeds, int):
217
+ seed = seeds
218
+ if isinstance(seeds, list):
219
+ seed = seeds[i]
220
+ random_generators[i].manual_seed(seed)
221
+ actual_seeds.append(seed)
222
+ return random_generators, actual_seeds
223
+
224
+ def get_lang(self, text):
225
+ language = "en"
226
+ try:
227
+ _ = self.lang_segment.getTexts(text)
228
+ langCounts = self.lang_segment.getCounts()
229
+ language = langCounts[0][0]
230
+ if len(langCounts) > 1 and language == "en":
231
+ language = langCounts[1][0]
232
+ except Exception as err:
233
+ language = "en"
234
+ return language
235
+
236
+ def tokenize_lyrics(self, lyrics, debug=False):
237
+ lines = lyrics.split("\n")
238
+ lyric_token_idx = [261]
239
+ for line in lines:
240
+ line = line.strip()
241
+ if not line:
242
+ lyric_token_idx += [2]
243
+ continue
244
+
245
+ lang = self.get_lang(line)
246
+
247
+ if lang not in SUPPORT_LANGUAGES:
248
+ lang = "en"
249
+ if "zh" in lang:
250
+ lang = "zh"
251
+ if "spa" in lang:
252
+ lang = "es"
253
+
254
+ try:
255
+ if structure_pattern.match(line):
256
+ token_idx = self.lyric_tokenizer.encode(line, "en")
257
+ else:
258
+ token_idx = self.lyric_tokenizer.encode(line, lang)
259
+ if debug:
260
+ toks = self.lyric_tokenizer.batch_decode([[tok_id] for tok_id in token_idx])
261
+ logger.info(f"debbug {line} --> {lang} --> {toks}")
262
+ lyric_token_idx = lyric_token_idx + token_idx + [2]
263
+ except Exception as e:
264
+ print("tokenize error", e, "for line", line, "major_language", lang)
265
+ return lyric_token_idx
266
+
267
+ @torch.no_grad()
268
+ def text2music_diffusion_process(
269
+ self,
270
+ duration,
271
+ encoder_text_hidden_states,
272
+ text_attention_mask,
273
+ speaker_embds,
274
+ lyric_token_ids,
275
+ lyric_mask,
276
+ random_generators=None,
277
+ infer_steps=60,
278
+ guidance_scale=15.0,
279
+ omega_scale=10.0,
280
+ scheduler_type="euler",
281
+ cfg_type="apg",
282
+ zero_steps=1,
283
+ use_zero_init=True,
284
+ guidance_interval=0.5,
285
+ guidance_interval_decay=1.0,
286
+ min_guidance_scale=3.0,
287
+ oss_steps=[],
288
+ encoder_text_hidden_states_null=None,
289
+ use_erg_lyric=False,
290
+ use_erg_diffusion=False,
291
+ retake_random_generators=None,
292
+ retake_variance=0.5,
293
+ add_retake_noise=False,
294
+ guidance_scale_text=0.0,
295
+ guidance_scale_lyric=0.0,
296
+ ):
297
+
298
+ logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
299
+ do_classifier_free_guidance = True
300
+ if guidance_scale == 0.0 or guidance_scale == 1.0:
301
+ do_classifier_free_guidance = False
302
+
303
+ do_double_condition_guidance = False
304
+ if guidance_scale_text is not None and guidance_scale_text > 1.0 and guidance_scale_lyric is not None and guidance_scale_lyric > 1.0:
305
+ do_double_condition_guidance = True
306
+ logger.info("do_double_condition_guidance: {}, guidance_scale_text: {}, guidance_scale_lyric: {}".format(do_double_condition_guidance, guidance_scale_text, guidance_scale_lyric))
307
+
308
+ device = encoder_text_hidden_states.device
309
+ dtype = encoder_text_hidden_states.dtype
310
+ bsz = encoder_text_hidden_states.shape[0]
311
+
312
+ if scheduler_type == "euler":
313
+ scheduler = FlowMatchEulerDiscreteScheduler(
314
+ num_train_timesteps=1000,
315
+ shift=3.0,
316
+ )
317
+ elif scheduler_type == "heun":
318
+ scheduler = FlowMatchHeunDiscreteScheduler(
319
+ num_train_timesteps=1000,
320
+ shift=3.0,
321
+ )
322
+ frame_length = int(duration * 44100 / 512 / 8)
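+ # one latent frame covers 512 * 8 = 4096 samples (~92.9 ms at 44.1 kHz), e.g. duration=60 gives 645 latent frames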
323
+
324
+ if len(oss_steps) > 0:
325
+ infer_steps = max(oss_steps)
326
+ # scheduler timesteps are configured via retrieve_timesteps below
327
+ timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=infer_steps, device=device, timesteps=None)
328
+ new_timesteps = torch.zeros(len(oss_steps), dtype=dtype, device=device)
329
+ for idx in range(len(oss_steps)):
330
+ new_timesteps[idx] = timesteps[oss_steps[idx]-1]
331
+ num_inference_steps = len(oss_steps)
332
+ sigmas = (new_timesteps / 1000).float().cpu().numpy()
333
+ timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=num_inference_steps, device=device, sigmas=sigmas)
334
+ logger.info(f"oss_steps: {oss_steps}, num_inference_steps: {num_inference_steps} after remapping to timesteps {timesteps}")
335
+ else:
336
+ timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=infer_steps, device=device, timesteps=None)
337
+
338
+ target_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=random_generators, device=device, dtype=dtype)
339
+ if add_retake_noise:
340
+ retake_variance = torch.tensor(retake_variance * math.pi/2).to(device).to(dtype)
341
+ retake_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=retake_random_generators, device=device, dtype=dtype)
342
+ # to make sure mean = 0, std = 1
343
+ target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
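+ # since cos(a)^2 + sin(a)^2 = 1, blending two independent standard-normal latents this way stays zero-mean / unit-variance;
+ # retake_variance=0 keeps the original noise and retake_variance=1 replaces it entirely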
344
+
345
+ attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
346
+
347
+ # guidance interval logic
348
+ start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))
349
+ end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))
350
+ logger.info(f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}")
351
+
352
+ momentum_buffer = MomentumBuffer()
353
+
354
+ def forward_encoder_with_temperature(self, inputs, tau=0.01, l_min=4, l_max=6):
355
+ handlers = []
356
+
357
+ def hook(module, input, output):
358
+ output[:] *= tau
359
+ return output
360
+
361
+ for i in range(l_min, l_max):
362
+ handler = self.ace_step_transformer.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook)
363
+ handlers.append(handler)
364
+
365
+ encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(**inputs)
366
+
367
+ for hook in handlers:
368
+ hook.remove()
369
+
370
+ return encoder_hidden_states
371
+
372
+ # P(speaker, text, lyric)
373
+ encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(
374
+ encoder_text_hidden_states,
375
+ text_attention_mask,
376
+ speaker_embds,
377
+ lyric_token_ids,
378
+ lyric_mask,
379
+ )
380
+
381
+ if use_erg_lyric:
382
+ # P(null_speaker, text_weaker, lyric_weaker)
383
+ encoder_hidden_states_null = forward_encoder_with_temperature(
384
+ self,
385
+ inputs={
386
+ "encoder_text_hidden_states": encoder_text_hidden_states_null if encoder_text_hidden_states_null is not None else torch.zeros_like(encoder_text_hidden_states),
387
+ "text_attention_mask": text_attention_mask,
388
+ "speaker_embeds": torch.zeros_like(speaker_embds),
389
+ "lyric_token_idx": lyric_token_ids,
390
+ "lyric_mask": lyric_mask,
391
+ }
392
+ )
393
+ else:
394
+ # P(null_speaker, null_text, null_lyric)
395
+ encoder_hidden_states_null, _ = self.ace_step_transformer.encode(
396
+ torch.zeros_like(encoder_text_hidden_states),
397
+ text_attention_mask,
398
+ torch.zeros_like(speaker_embds),
399
+ torch.zeros_like(lyric_token_ids),
400
+ lyric_mask,
401
+ )
402
+
403
+ encoder_hidden_states_no_lyric = None
404
+ if do_double_condition_guidance:
405
+ # P(null_speaker, text, lyric_weaker)
406
+ if use_erg_lyric:
407
+ encoder_hidden_states_no_lyric = forward_encoder_with_temperature(
408
+ self,
409
+ inputs={
410
+ "encoder_text_hidden_states": encoder_text_hidden_states,
411
+ "text_attention_mask": text_attention_mask,
412
+ "speaker_embeds": torch.zeros_like(speaker_embds),
413
+ "lyric_token_idx": lyric_token_ids,
414
+ "lyric_mask": lyric_mask,
415
+ }
416
+ )
417
+ # P(null_speaker, text, no_lyric)
418
+ else:
419
+ encoder_hidden_states_no_lyric, _ = self.ace_step_transformer.encode(
420
+ encoder_text_hidden_states,
421
+ text_attention_mask,
422
+ torch.zeros_like(speaker_embds),
423
+ torch.zeros_like(lyric_token_ids),
424
+ lyric_mask,
425
+ )
426
+
427
+ def forward_diffusion_with_temperature(self, hidden_states, timestep, inputs, tau=0.01, l_min=15, l_max=20):
428
+ handlers = []
429
+
430
+ def hook(module, input, output):
431
+ output[:] *= tau
432
+ return output
433
+
434
+ for i in range(l_min, l_max):
435
+ handler = self.ace_step_transformer.transformer_blocks[i].attn.to_q.register_forward_hook(hook)
436
+ handlers.append(handler)
437
+ handler = self.ace_step_transformer.transformer_blocks[i].cross_attn.to_q.register_forward_hook(hook)
438
+ handlers.append(handler)
439
+
440
+ sample = self.ace_step_transformer.decode(hidden_states=hidden_states, timestep=timestep, **inputs).sample
441
+
442
+ for hook in handlers:
443
+ hook.remove()
444
+
445
+ return sample
446
+
447
+
448
+ for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
449
+ # expand the latents if we are doing classifier free guidance
450
+ latents = target_latents
451
+
452
+ is_in_guidance_interval = start_idx <= i < end_idx
453
+ if is_in_guidance_interval and do_classifier_free_guidance:
454
+ # compute current guidance scale
455
+ if guidance_interval_decay > 0:
456
+ # Linearly interpolate to calculate the current guidance scale
457
+ progress = (i - start_idx) / (end_idx - start_idx - 1)  # normalized to [0, 1]
458
+ current_guidance_scale = guidance_scale - (guidance_scale - min_guidance_scale) * progress * guidance_interval_decay
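+ # with guidance_interval_decay=1.0 this ramps linearly from guidance_scale at the first guided step down to min_guidance_scale at the last one (e.g. 15 -> 3)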
459
+ else:
460
+ current_guidance_scale = guidance_scale
461
+
462
+ latent_model_input = latents
463
+ timestep = t.expand(latent_model_input.shape[0])
464
+ output_length = latent_model_input.shape[-1]
465
+ # P(x|speaker, text, lyric)
466
+ noise_pred_with_cond = self.ace_step_transformer.decode(
467
+ hidden_states=latent_model_input,
468
+ attention_mask=attention_mask,
469
+ encoder_hidden_states=encoder_hidden_states,
470
+ encoder_hidden_mask=encoder_hidden_mask,
471
+ output_length=output_length,
472
+ timestep=timestep,
473
+ ).sample
474
+
475
+ noise_pred_with_only_text_cond = None
476
+ if do_double_condition_guidance and encoder_hidden_states_no_lyric is not None:
477
+ noise_pred_with_only_text_cond = self.ace_step_transformer.decode(
478
+ hidden_states=latent_model_input,
479
+ attention_mask=attention_mask,
480
+ encoder_hidden_states=encoder_hidden_states_no_lyric,
481
+ encoder_hidden_mask=encoder_hidden_mask,
482
+ output_length=output_length,
483
+ timestep=timestep,
484
+ ).sample
485
+
486
+ if use_erg_diffusion:
487
+ noise_pred_uncond = forward_diffusion_with_temperature(
488
+ self,
489
+ hidden_states=latent_model_input,
490
+ timestep=timestep,
491
+ inputs={
492
+ "encoder_hidden_states": encoder_hidden_states_null,
493
+ "encoder_hidden_mask": encoder_hidden_mask,
494
+ "output_length": output_length,
495
+ "attention_mask": attention_mask,
496
+ },
497
+ )
498
+ else:
499
+ noise_pred_uncond = self.ace_step_transformer.decode(
500
+ hidden_states=latent_model_input,
501
+ attention_mask=attention_mask,
502
+ encoder_hidden_states=encoder_hidden_states_null,
503
+ encoder_hidden_mask=encoder_hidden_mask,
504
+ output_length=output_length,
505
+ timestep=timestep,
506
+ ).sample
507
+
508
+ if do_double_condition_guidance and noise_pred_with_only_text_cond is not None:
509
+ noise_pred = cfg_double_condition_forward(
510
+ cond_output=noise_pred_with_cond,
511
+ uncond_output=noise_pred_uncond,
512
+ only_text_cond_output=noise_pred_with_only_text_cond,
513
+ guidance_scale_text=guidance_scale_text,
514
+ guidance_scale_lyric=guidance_scale_lyric,
515
+ )
516
+
517
+ elif cfg_type == "apg":
518
+ noise_pred = apg_forward(
519
+ pred_cond=noise_pred_with_cond,
520
+ pred_uncond=noise_pred_uncond,
521
+ guidance_scale=current_guidance_scale,
522
+ momentum_buffer=momentum_buffer,
523
+ )
524
+ elif cfg_type == "cfg":
525
+ noise_pred = cfg_forward(
526
+ cond_output=noise_pred_with_cond,
527
+ uncond_output=noise_pred_uncond,
528
+ cfg_strength=current_guidance_scale,
529
+ )
530
+ elif cfg_type == "cfg_star":
531
+ noise_pred = cfg_zero_star(
532
+ noise_pred_with_cond=noise_pred_with_cond,
533
+ noise_pred_uncond=noise_pred_uncond,
534
+ guidance_scale=current_guidance_scale,
535
+ i=i,
536
+ zero_steps=zero_steps,
537
+ use_zero_init=use_zero_init
538
+ )
539
+ else:
540
+ latent_model_input = latents
541
+ timestep = t.expand(latent_model_input.shape[0])
542
+ noise_pred = self.ace_step_transformer.decode(
543
+ hidden_states=latent_model_input,
544
+ attention_mask=attention_mask,
545
+ encoder_hidden_states=encoder_hidden_states,
546
+ encoder_hidden_mask=encoder_hidden_mask,
547
+ output_length=latent_model_input.shape[-1],
548
+ timestep=timestep,
549
+ ).sample
550
+
551
+ target_latents = scheduler.step(model_output=noise_pred, timestep=t, sample=target_latents, return_dict=False, omega=omega_scale)[0]
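+ # omega_scale is forwarded to the customized scheduler step, which rescales the per-step update around its mean by a logistic-squashed omega (see schedulers/scheduling_flow_match_euler_discrete.py)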
552
+
553
+ return target_latents
554
+
555
+ def latents2audio(self, latents, target_wav_duration_second=30, sample_rate=48000, save_path=None, format="flac"):
556
+ output_audio_paths = []
557
+ bs = latents.shape[0]
558
+ audio_lengths = [target_wav_duration_second * sample_rate] * bs
559
+ pred_latents = latents
560
+ with torch.no_grad():
561
+ _, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate)
562
+ pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs]
563
+ for i in tqdm(range(bs)):
564
+ output_audio_path = self.save_wav_file(pred_wavs[i], i, save_path=save_path, sample_rate=sample_rate, format=format)
565
+ output_audio_paths.append(output_audio_path)
566
+ return output_audio_paths
567
+
568
+ def save_wav_file(self, target_wav, idx, save_path=None, sample_rate=48000, format="flac"):
569
+ if save_path is None:
570
+ logger.warning("save_path is None, using default path ./outputs/")
571
+ base_path = f"./outputs/"
572
+ ensure_directory_exists(base_path)
573
+ else:
574
+ base_path = save_path
575
+ ensure_directory_exists(base_path)
576
+
577
+ output_path_flac = f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.{format}"
578
+ target_wav = target_wav.float()
579
+ torchaudio.save(output_path_flac, target_wav, sample_rate=sample_rate, format=format)
580
+ return output_path_flac
581
+
582
+ def __call__(
583
+ self,
584
+ audio_duration: float = 60.0,
585
+ prompt: str = None,
586
+ lyrics: str = None,
587
+ infer_step: int = 60,
588
+ guidance_scale: float = 15.0,
589
+ scheduler_type: str = "euler",
590
+ cfg_type: str = "apg",
591
+ omega_scale: int = 10.0,
592
+ manual_seeds: list = None,
593
+ guidance_interval: float = 0.5,
594
+ guidance_interval_decay: float = 0.,
595
+ min_guidance_scale: float = 3.0,
596
+ use_erg_tag: bool = True,
597
+ use_erg_lyric: bool = True,
598
+ use_erg_diffusion: bool = True,
599
+ oss_steps: str = None,
600
+ guidance_scale_text: float = 0.0,
601
+ guidance_scale_lyric: float = 0.0,
602
+ retake_seeds: list = None,
603
+ retake_variance: float = 0.5,
604
+ task: str = "text2music",
605
+ save_path: str = None,
606
+ format: str = "flac",
607
+ batch_size: int = 1,
608
+ ):
609
+
610
+ start_time = time.time()
611
+
612
+ if not self.loaded:
613
+ logger.warning("Checkpoint not loaded, loading checkpoint...")
614
+ self.load_checkpoint(self.checkpoint_dir)
615
+ load_model_cost = time.time() - start_time
616
+ logger.info(f"Model loaded in {load_model_cost:.2f} seconds.")
617
+
618
+ start_time = time.time()
619
+
620
+ random_generators, actual_seeds = self.set_seeds(batch_size, manual_seeds)
621
+ retake_random_generators, actual_retake_seeds = self.set_seeds(batch_size, retake_seeds)
622
+
623
+ if isinstance(oss_steps, str) and len(oss_steps) > 0:
624
+ oss_steps = list(map(int, oss_steps.split(",")))
625
+ else:
626
+ oss_steps = []
627
+
628
+ texts = [prompt]
629
+ encoder_text_hidden_states, text_attention_mask = self.get_text_embeddings(texts, self.device)
630
+ encoder_text_hidden_states = encoder_text_hidden_states.repeat(batch_size, 1, 1)
631
+ text_attention_mask = text_attention_mask.repeat(batch_size, 1)
632
+
633
+ encoder_text_hidden_states_null = None
634
+ if use_erg_tag:
635
+ encoder_text_hidden_states_null = self.get_text_embeddings_null(texts, self.device)
636
+ encoder_text_hidden_states_null = encoder_text_hidden_states_null.repeat(batch_size, 1, 1)
637
+
638
+ # not support for released checkpoint
639
+ speaker_embeds = torch.zeros(batch_size, 512).to(self.device).to(self.dtype)
640
+
641
+ # 6 lyric
642
+ lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
643
+ lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
644
+ if lyrics:  # handles both None and empty strings
645
+ lyric_token_idx = self.tokenize_lyrics(lyrics, debug=True)
646
+ lyric_mask = [1] * len(lyric_token_idx)
647
+ lyric_token_idx = torch.tensor(lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1)
648
+ lyric_mask = torch.tensor(lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1)
649
+
650
+ if audio_duration <= 0:
651
+ audio_duration = random.uniform(30.0, 240.0)
652
+ logger.info(f"random audio duration: {audio_duration}")
653
+
654
+ end_time = time.time()
655
+ preprocess_time_cost = end_time - start_time
656
+ start_time = end_time
657
+
658
+ target_latents = self.text2music_diffusion_process(
659
+ duration=audio_duration,
660
+ encoder_text_hidden_states=encoder_text_hidden_states,
661
+ text_attention_mask=text_attention_mask,
662
+ speaker_embds=speaker_embeds,
663
+ lyric_token_ids=lyric_token_idx,
664
+ lyric_mask=lyric_mask,
665
+ guidance_scale=guidance_scale,
666
+ omega_scale=omega_scale,
667
+ infer_steps=infer_step,
668
+ random_generators=random_generators,
669
+ scheduler_type=scheduler_type,
670
+ cfg_type=cfg_type,
671
+ guidance_interval=guidance_interval,
672
+ guidance_interval_decay=guidance_interval_decay,
673
+ min_guidance_scale=min_guidance_scale,
674
+ oss_steps=oss_steps,
675
+ encoder_text_hidden_states_null=encoder_text_hidden_states_null,
676
+ use_erg_lyric=use_erg_lyric,
677
+ use_erg_diffusion=use_erg_diffusion,
678
+ retake_random_generators=retake_random_generators,
679
+ retake_variance=retake_variance,
680
+ add_retake_noise=task == "retake",
681
+ guidance_scale_text=guidance_scale_text,
682
+ guidance_scale_lyric=guidance_scale_lyric,
683
+ )
684
+
685
+ end_time = time.time()
686
+ diffusion_time_cost = end_time - start_time
687
+ start_time = end_time
688
+
689
+ output_paths = self.latents2audio(
690
+ latents=target_latents,
691
+ target_wav_duration_second=audio_duration,
692
+ save_path=save_path,
693
+ format=format,
694
+ )
695
+
696
+ end_time = time.time()
697
+ latent2audio_time_cost = end_time - start_time
698
+ timecosts = {
699
+ "preprocess": preprocess_time_cost,
700
+ "diffusion": diffusion_time_cost,
701
+ "latent2audio": latent2audio_time_cost,
702
+ }
703
+
704
+ input_params_json = {
705
+ "task": task,
706
+ "prompt": prompt,
707
+ "lyrics": lyrics,
708
+ "audio_duration": audio_duration,
709
+ "infer_step": infer_step,
710
+ "guidance_scale": guidance_scale,
711
+ "scheduler_type": scheduler_type,
712
+ "cfg_type": cfg_type,
713
+ "omega_scale": omega_scale,
714
+ "guidance_interval": guidance_interval,
715
+ "guidance_interval_decay": guidance_interval_decay,
716
+ "min_guidance_scale": min_guidance_scale,
717
+ "use_erg_tag": use_erg_tag,
718
+ "use_erg_lyric": use_erg_lyric,
719
+ "use_erg_diffusion": use_erg_diffusion,
720
+ "oss_steps": oss_steps,
721
+ "timecosts": timecosts,
722
+ "actual_seeds": actual_seeds,
723
+ "retake_seeds": actual_retake_seeds,
724
+ "retake_variance": retake_variance,
725
+ "guidance_scale_text": guidance_scale_text,
726
+ "guidance_scale_lyric": guidance_scale_lyric,
727
+ }
728
+ # save input_params_json
729
+ for output_audio_path in output_paths:
730
+ input_params_json_save_path = output_audio_path.replace(f".{format}", "_input_params.json")
731
+ input_params_json["audio_path"] = output_audio_path
732
+ with open(input_params_json_save_path, "w", encoding="utf-8") as f:
733
+ json.dump(input_params_json, f, indent=4, ensure_ascii=False)
734
+
735
+ return output_paths + [input_params_json]
requirements.txt ADDED
@@ -0,0 +1,22 @@
1
+ datasets==3.4.1
2
+ diffusers==0.32.2
3
+ gradio==5.23.3
4
+ librosa==0.11.0
5
+ loguru==0.7.3
6
+ matplotlib==3.10.1
7
+ numpy
8
+ pypinyin==0.53.0
9
+ pytorch_lightning==2.5.1
10
+ soundfile==0.13.1
11
+ torch
12
+ torchaudio
13
+ torchvision
14
+ tqdm==4.67.1
15
+ transformers==4.50.0
16
+ py3langid==0.3.0
17
+ hangul-romanize==0.1.0
18
+ num2words==0.5.14
19
+ spacy==3.8.4
20
+ accelerate==1.6.0
21
+ cutlet
22
+ fugashi[unidic-lite]
schedulers/scheduling_flow_match_euler_discrete.py ADDED
@@ -0,0 +1,394 @@
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ @dataclass
31
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
32
+ """
33
+ Output class for the scheduler's `step` function output.
34
+
35
+ Args:
36
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
37
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
38
+ denoising loop.
39
+ """
40
+
41
+ prev_sample: torch.FloatTensor
42
+
43
+
44
+ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
45
+ """
46
+ Euler scheduler.
47
+
48
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
49
+ methods the library implements for all schedulers such as loading and saving.
50
+
51
+ Args:
52
+ num_train_timesteps (`int`, defaults to 1000):
53
+ The number of diffusion steps to train the model.
54
+ timestep_spacing (`str`, defaults to `"linspace"`):
55
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
56
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
57
+ shift (`float`, defaults to 1.0):
58
+ The shift value for the timestep schedule.
59
+ """
60
+
61
+ _compatibles = []
62
+ order = 1
63
+
64
+ @register_to_config
65
+ def __init__(
66
+ self,
67
+ num_train_timesteps: int = 1000,
68
+ shift: float = 1.0,
69
+ use_dynamic_shifting=False,
70
+ base_shift: Optional[float] = 0.5,
71
+ max_shift: Optional[float] = 1.15,
72
+ base_image_seq_len: Optional[int] = 256,
73
+ max_image_seq_len: Optional[int] = 4096,
74
+ ):
75
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
76
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
77
+
78
+ sigmas = timesteps / num_train_timesteps
79
+ if not use_dynamic_shifting:
80
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
81
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
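+ # the shift warps sigmas toward higher noise levels, e.g. with shift=3.0 a sigma of 0.5 becomes 3*0.5 / (1 + 2*0.5) = 0.75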
82
+
83
+ self.timesteps = sigmas * num_train_timesteps
84
+
85
+ self._step_index = None
86
+ self._begin_index = None
87
+
88
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
89
+ self.sigma_min = self.sigmas[-1].item()
90
+ self.sigma_max = self.sigmas[0].item()
91
+
92
+ @property
93
+ def step_index(self):
94
+ """
95
+ The index counter for current timestep. It will increase 1 after each scheduler step.
96
+ """
97
+ return self._step_index
98
+
99
+ @property
100
+ def begin_index(self):
101
+ """
102
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
103
+ """
104
+ return self._begin_index
105
+
106
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
107
+ def set_begin_index(self, begin_index: int = 0):
108
+ """
109
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
110
+
111
+ Args:
112
+ begin_index (`int`):
113
+ The begin index for the scheduler.
114
+ """
115
+ self._begin_index = begin_index
116
+
117
+ def scale_noise(
118
+ self,
119
+ sample: torch.FloatTensor,
120
+ timestep: Union[float, torch.FloatTensor],
121
+ noise: Optional[torch.FloatTensor] = None,
122
+ ) -> torch.FloatTensor:
123
+ """
124
+ Forward process in flow-matching
125
+
126
+ Args:
127
+ sample (`torch.FloatTensor`):
128
+ The input sample.
129
+ timestep (`int`, *optional*):
130
+ The current timestep in the diffusion chain.
131
+
132
+ Returns:
133
+ `torch.FloatTensor`:
134
+ A scaled input sample.
135
+ """
136
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
137
+ sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
138
+
139
+ if sample.device.type == "mps" and torch.is_floating_point(timestep):
140
+ # mps does not support float64
141
+ schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
142
+ timestep = timestep.to(sample.device, dtype=torch.float32)
143
+ else:
144
+ schedule_timesteps = self.timesteps.to(sample.device)
145
+ timestep = timestep.to(sample.device)
146
+
147
+ # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
148
+ if self.begin_index is None:
149
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
150
+ elif self.step_index is not None:
151
+ # add_noise is called after first denoising step (for inpainting)
152
+ step_indices = [self.step_index] * timestep.shape[0]
153
+ else:
154
+ # add noise is called before first denoising step to create initial latent(img2img)
155
+ step_indices = [self.begin_index] * timestep.shape[0]
156
+
157
+ sigma = sigmas[step_indices].flatten()
158
+ while len(sigma.shape) < len(sample.shape):
159
+ sigma = sigma.unsqueeze(-1)
160
+
161
+ sample = sigma * noise + (1.0 - sigma) * sample
162
+
163
+ return sample
164
+
165
+ def _sigma_to_t(self, sigma):
166
+ return sigma * self.config.num_train_timesteps
167
+
168
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
169
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
170
+
171
+ def set_timesteps(
172
+ self,
173
+ num_inference_steps: int = None,
174
+ device: Union[str, torch.device] = None,
175
+ sigmas: Optional[List[float]] = None,
176
+ mu: Optional[float] = None,
177
+ ):
178
+ """
179
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
180
+
181
+ Args:
182
+ num_inference_steps (`int`):
183
+ The number of diffusion steps used when generating samples with a pre-trained model.
184
+ device (`str` or `torch.device`, *optional*):
185
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
186
+ """
187
+
188
+ if self.config.use_dynamic_shifting and mu is None:
189
+ raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
190
+
191
+ if sigmas is None:
192
+ self.num_inference_steps = num_inference_steps
193
+ timesteps = np.linspace(
194
+ self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
195
+ )
196
+
197
+ sigmas = timesteps / self.config.num_train_timesteps
198
+
199
+ if self.config.use_dynamic_shifting:
200
+ sigmas = self.time_shift(mu, 1.0, sigmas)
201
+ else:
202
+ sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
203
+
204
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
205
+ timesteps = sigmas * self.config.num_train_timesteps
206
+
207
+ self.timesteps = timesteps.to(device=device)
208
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
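+ # a trailing sigma of 0 is appended so the final step integrates all the way to the fully denoised sample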
209
+
210
+ self._step_index = None
211
+ self._begin_index = None
212
+
213
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
214
+ if schedule_timesteps is None:
215
+ schedule_timesteps = self.timesteps
216
+
217
+ indices = (schedule_timesteps == timestep).nonzero()
218
+
219
+ # The sigma index that is taken for the **very** first `step`
220
+ # is always the second index (or the last index if there is only 1)
221
+ # This way we can ensure we don't accidentally skip a sigma in
222
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
223
+ pos = 1 if len(indices) > 1 else 0
224
+
225
+ return indices[pos].item()
226
+
227
+ def _init_step_index(self, timestep):
228
+ if self.begin_index is None:
229
+ if isinstance(timestep, torch.Tensor):
230
+ timestep = timestep.to(self.timesteps.device)
231
+ self._step_index = self.index_for_timestep(timestep)
232
+ else:
233
+ self._step_index = self._begin_index
234
+
235
+ def step(
236
+ self,
237
+ model_output: torch.FloatTensor,
238
+ timestep: Union[float, torch.FloatTensor],
239
+ sample: torch.FloatTensor,
240
+ s_churn: float = 0.0,
241
+ s_tmin: float = 0.0,
242
+ s_tmax: float = float("inf"),
243
+ s_noise: float = 1.0,
244
+ generator: Optional[torch.Generator] = None,
245
+ return_dict: bool = True,
246
+ omega: Union[float, np.array] = 0.0
247
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
248
+ """
249
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
250
+ process from the learned model outputs (most often the predicted noise).
251
+
252
+ Args:
253
+ model_output (`torch.FloatTensor`):
254
+ The direct output from learned diffusion model.
255
+ timestep (`float`):
256
+ The current discrete timestep in the diffusion chain.
257
+ sample (`torch.FloatTensor`):
258
+ A current instance of a sample created by the diffusion process.
259
+ s_churn (`float`):
260
+ s_tmin (`float`):
261
+ s_tmax (`float`):
262
+ s_noise (`float`, defaults to 1.0):
263
+ Scaling factor for noise added to the sample.
264
+ generator (`torch.Generator`, *optional*):
265
+ A random number generator.
266
+ return_dict (`bool`):
267
+ Whether or not to return a [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or
268
+ tuple.
269
+
270
+ Returns:
271
+ [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
272
+ If return_dict is `True`, [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is
273
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
274
+ """
275
+
276
+ def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
277
+ # L = Lower bound
278
+ # U = Upper bound
279
+ # x_0 = Midpoint (x corresponding to y = 1.0)
280
+ # k = Steepness, can adjust based on preference
281
+
282
+ if isinstance(x, torch.Tensor):
283
+ device_ = x.device
284
+ x = x.to(torch.float).cpu().numpy()
285
+
286
+ new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))
287
+
288
+ if isinstance(new_x, np.ndarray):
289
+ new_x = torch.from_numpy(new_x).to(device_)
290
+ return new_x
291
+
292
+ self.omega_bef_rescale = omega
293
+ omega = logistic_function(omega, k=0.1)
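+ # with L=0.9, U=1.1, k=0.1 the raw omega is squashed into (0.9, 1.1): omega=0 -> 1.0 and omega=10 (the pipeline default) -> ~1.05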
294
+ self.omega_aft_rescale = omega
295
+
296
+ if (
297
+ isinstance(timestep, int)
298
+ or isinstance(timestep, torch.IntTensor)
299
+ or isinstance(timestep, torch.LongTensor)
300
+ ):
301
+ raise ValueError(
302
+ (
303
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
304
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
305
+ " one of the `scheduler.timesteps` as a timestep."
306
+ ),
307
+ )
308
+
309
+ if self.step_index is None:
310
+ self._init_step_index(timestep)
311
+
312
+ # Upcast to avoid precision issues when computing prev_sample
313
+ sample = sample.to(torch.float32)
314
+
315
+ sigma = self.sigmas[self.step_index]
316
+ sigma_next = self.sigmas[self.step_index + 1]
317
+
318
+ ## --
319
+ ## mean shift 1
320
+ dx = (sigma_next - sigma) * model_output
321
+ m = dx.mean()
322
+ # print(dx.shape) # torch.Size([1, 16, 128, 128])
323
+ # print(f'm: {m}') # m: -0.0014209747314453125
324
+ # raise NotImplementedError
325
+ dx_ = (dx - m) * omega + m
326
+ prev_sample = sample + dx_
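+ # "mean shift 1": the Euler displacement dx is rescaled about its global mean m by omega, so only the fluctuations are scaled
+ # while the mean drift is preserved; the commented-out blocks below are alternative rescaling strategies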
327
+
328
+ # ## --
329
+ # ## mean shift 2
330
+ # m = model_output.mean()
331
+ # model_output_ = (model_output - m) * omega + m
332
+ # prev_sample = sample + (sigma_next - sigma) * model_output_
333
+
334
+ # ## --
335
+ # ## original
336
+ # prev_sample = sample + (sigma_next - sigma) * model_output * omega
337
+
338
+ # ## --
339
+ # ## spatial mean 1
340
+ # dx = (sigma_next - sigma) * model_output
341
+ # m = dx.mean(dim=(0, 1), keepdim=True)
342
+ # # print(dx.shape) # torch.Size([1, 16, 128, 128])
343
+ # # print(m.shape) # torch.Size([1, 1, 128, 128])
344
+ # # raise NotImplementedError
345
+ # dx_ = (dx - m) * omega + m
346
+ # prev_sample = sample + dx_
347
+
348
+ # ## --
349
+ # ## spatial mean 2
350
+ # m = model_output.mean(dim=(0, 1), keepdim=True)
351
+ # model_output_ = (model_output - m) * omega + m
352
+ # prev_sample = sample + (sigma_next - sigma) * model_output_
353
+
354
+ # ## --
355
+ # ## channel mean 1
356
+ # m = model_output.mean(dim=(2, 3), keepdim=True)
357
+ # # print(m.shape) # torch.Size([1, 16, 1, 1])
358
+ # model_output_ = (model_output - m) * omega + m
359
+ # prev_sample = sample + (sigma_next - sigma) * model_output_
360
+
361
+ # ## --
362
+ # ## channel mean 2
363
+ # dx = (sigma_next - sigma) * model_output
364
+ # m = dx.mean(dim=(2, 3), keepdim=True)
365
+ # # print(m.shape) # torch.Size([1, 16, 1, 1])
366
+ # dx_ = (dx - m) * omega + m
367
+ # prev_sample = sample + dx_
368
+
369
+ # ## --
370
+ # ## keep sample mean
371
+ # m_tgt = sample.mean()
372
+ # prev_sample_ = sample + (sigma_next - sigma) * model_output * omega
373
+ # m_src = prev_sample_.mean()
374
+ # prev_sample = prev_sample_ - m_src + m_tgt
375
+
376
+ # ## --
377
+ # ## test
378
+ # # print(sample.mean())
379
+ # prev_sample = sample + (sigma_next - sigma) * model_output * omega
380
+ # # raise NotImplementedError
381
+
382
+ # Cast sample back to model compatible dtype
383
+ prev_sample = prev_sample.to(model_output.dtype)
384
+
385
+ # upon completion increase step index by one
386
+ self._step_index += 1
387
+
388
+ if not return_dict:
389
+ return (prev_sample,)
390
+
391
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
392
+
393
+ def __len__(self):
394
+ return self.config.num_train_timesteps
schedulers/scheduling_flow_match_heun_discrete.py ADDED
@@ -0,0 +1,348 @@
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.utils import BaseOutput, logging
23
+ from diffusers.utils.torch_utils import randn_tensor
24
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ @dataclass
31
+ class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput):
32
+ """
33
+ Output class for the scheduler's `step` function output.
34
+
35
+ Args:
36
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
37
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
38
+ denoising loop.
39
+ """
40
+
41
+ prev_sample: torch.FloatTensor
42
+
43
+
44
+ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
45
+ """
46
+ Heun scheduler.
47
+
48
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
49
+ methods the library implements for all schedulers such as loading and saving.
50
+
51
+ Args:
52
+ num_train_timesteps (`int`, defaults to 1000):
53
+ The number of diffusion steps to train the model.
54
+ timestep_spacing (`str`, defaults to `"linspace"`):
55
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
56
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
57
+ shift (`float`, defaults to 1.0):
58
+ The shift value for the timestep schedule.
59
+ """
60
+
61
+ _compatibles = []
62
+ order = 2
63
+
64
+ @register_to_config
65
+ def __init__(
66
+ self,
67
+ num_train_timesteps: int = 1000,
68
+ shift: float = 1.0,
69
+ ):
70
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
71
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
72
+
73
+ sigmas = timesteps / num_train_timesteps
74
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
75
+
76
+ self.timesteps = sigmas * num_train_timesteps
77
+
78
+ self._step_index = None
79
+ self._begin_index = None
80
+
81
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
82
+ self.sigma_min = self.sigmas[-1].item()
83
+ self.sigma_max = self.sigmas[0].item()
84
+
85
+ @property
86
+ def step_index(self):
87
+ """
88
+ The index counter for current timestep. It will increase 1 after each scheduler step.
89
+ """
90
+ return self._step_index
91
+
92
+ @property
93
+ def begin_index(self):
94
+ """
95
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
96
+ """
97
+ return self._begin_index
98
+
99
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
100
+ def set_begin_index(self, begin_index: int = 0):
101
+ """
102
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
103
+
104
+ Args:
105
+ begin_index (`int`):
106
+ The begin index for the scheduler.
107
+ """
108
+ self._begin_index = begin_index
109
+
110
+ def scale_noise(
111
+ self,
112
+ sample: torch.FloatTensor,
113
+ timestep: Union[float, torch.FloatTensor],
114
+ noise: Optional[torch.FloatTensor] = None,
115
+ ) -> torch.FloatTensor:
116
+ """
117
+ Forward process in flow-matching
118
+
119
+ Args:
120
+ sample (`torch.FloatTensor`):
121
+ The input sample.
122
+ timestep (`int`, *optional*):
123
+ The current timestep in the diffusion chain.
124
+
125
+ Returns:
126
+ `torch.FloatTensor`:
127
+ A scaled input sample.
128
+ """
129
+ if self.step_index is None:
130
+ self._init_step_index(timestep)
131
+
132
+ sigma = self.sigmas[self.step_index]
133
+ sample = sigma * noise + (1.0 - sigma) * sample
134
+
135
+ return sample
136
+
137
+ def _sigma_to_t(self, sigma):
138
+ return sigma * self.config.num_train_timesteps
139
+
140
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
141
+ """
142
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
143
+
144
+ Args:
145
+ num_inference_steps (`int`):
146
+ The number of diffusion steps used when generating samples with a pre-trained model.
147
+ device (`str` or `torch.device`, *optional*):
148
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
149
+ """
150
+ self.num_inference_steps = num_inference_steps
151
+
152
+ timesteps = np.linspace(
153
+ self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
154
+ )
155
+
156
+ sigmas = timesteps / self.config.num_train_timesteps
157
+ sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
158
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
159
+
160
+ timesteps = sigmas * self.config.num_train_timesteps
161
+ timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
162
+ self.timesteps = timesteps.to(device=device)
163
+
164
+ sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
165
+ self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
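+ # interior sigmas/timesteps are duplicated because each Heun step takes two model evaluations (Euler predictor, then corrector);
+ # state_in_first_order tracks which half of the pair is currently running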
166
+
167
+ # empty dt and derivative
168
+ self.prev_derivative = None
169
+ self.dt = None
170
+
171
+ self._step_index = None
172
+ self._begin_index = None
173
+
174
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
175
+ if schedule_timesteps is None:
176
+ schedule_timesteps = self.timesteps
177
+
178
+ indices = (schedule_timesteps == timestep).nonzero()
179
+
180
+ # The sigma index that is taken for the **very** first `step`
181
+ # is always the second index (or the last index if there is only 1)
182
+ # This way we can ensure we don't accidentally skip a sigma in
183
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
184
+ pos = 1 if len(indices) > 1 else 0
185
+
186
+ return indices[pos].item()
187
+
188
+ def _init_step_index(self, timestep):
189
+ if self.begin_index is None:
190
+ if isinstance(timestep, torch.Tensor):
191
+ timestep = timestep.to(self.timesteps.device)
192
+ self._step_index = self.index_for_timestep(timestep)
193
+ else:
194
+ self._step_index = self._begin_index
195
+
196
+ @property
197
+ def state_in_first_order(self):
198
+ return self.dt is None
199
+
200
+ def step(
201
+ self,
202
+ model_output: torch.FloatTensor,
203
+ timestep: Union[float, torch.FloatTensor],
204
+ sample: torch.FloatTensor,
205
+ s_churn: float = 0.0,
206
+ s_tmin: float = 0.0,
207
+ s_tmax: float = float("inf"),
208
+ s_noise: float = 1.0,
209
+ generator: Optional[torch.Generator] = None,
210
+ return_dict: bool = True,
211
+ omega: Union[float, np.array] = 0.0
212
+ ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
213
+ """
214
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
215
+ process from the learned model outputs (most often the predicted noise).
216
+
217
+ Args:
218
+ model_output (`torch.FloatTensor`):
219
+ The direct output from learned diffusion model.
220
+ timestep (`float`):
221
+ The current discrete timestep in the diffusion chain.
222
+ sample (`torch.FloatTensor`):
223
+ A current instance of a sample created by the diffusion process.
224
+ s_churn (`float`):
225
+ s_tmin (`float`):
226
+ s_tmax (`float`):
227
+ s_noise (`float`, defaults to 1.0):
228
+ Scaling factor for noise added to the sample.
229
+ generator (`torch.Generator`, *optional*):
230
+ A random number generator.
231
+ return_dict (`bool`):
232
+ Whether or not to return a [`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] or
233
+ tuple.
234
+
235
+ Returns:
236
+ [`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] or `tuple`:
237
+ If return_dict is `True`, [`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] is
238
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
239
+ """
240
+
241
+ def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
242
+ # L = Lower bound
243
+ # U = Upper bound
244
+ # x_0 = Midpoint (x corresponding to y = 1.0)
245
+ # k = Steepness, can adjust based on preference
246
+
247
+ device_ = x.device if isinstance(x, torch.Tensor) else None
248
+ if device_ is not None:
249
+ x = x.to(torch.float).cpu().numpy()
250
+
251
+ new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))
252
+
253
+ if device_ is not None:
254
+ new_x = torch.from_numpy(np.asarray(new_x)).to(device_)
255
+ return new_x
256
+
257
+ self.omega_bef_rescale = omega
258
+ omega = logistic_function(omega, k=0.1)
259
+ self.omega_aft_rescale = omega
260
+
261
+ if (
262
+ isinstance(timestep, int)
263
+ or isinstance(timestep, torch.IntTensor)
264
+ or isinstance(timestep, torch.LongTensor)
265
+ ):
266
+ raise ValueError(
267
+ (
268
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
269
+ " `HeunDiscreteScheduler.step()` is not supported. Make sure to pass"
270
+ " one of the `scheduler.timesteps` as a timestep."
271
+ ),
272
+ )
273
+
274
+ if self.step_index is None:
275
+ self._init_step_index(timestep)
276
+
277
+ # Upcast to avoid precision issues when computing prev_sample
278
+ sample = sample.to(torch.float32)
279
+
280
+ if self.state_in_first_order:
281
+ sigma = self.sigmas[self.step_index]
282
+ sigma_next = self.sigmas[self.step_index + 1]
283
+ else:
284
+ # 2nd order / Heun's method
285
+ sigma = self.sigmas[self.step_index - 1]
286
+ sigma_next = self.sigmas[self.step_index]
287
+
288
+ gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
289
+
290
+ sigma_hat = sigma * (gamma + 1)
291
+
292
+ if gamma > 0:
293
+ noise = randn_tensor(
294
+ model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
295
+ )
296
+ eps = noise * s_noise
297
+ sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
298
+
299
+ if self.state_in_first_order:
300
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
301
+ denoised = sample - model_output * sigma
302
+ # 2. convert to an ODE derivative for 1st order
303
+ derivative = (sample - denoised) / sigma_hat
304
+ # 3. Delta timestep
305
+ dt = sigma_next - sigma_hat
306
+
307
+ # store for 2nd order step
308
+ self.prev_derivative = derivative
309
+ self.dt = dt
310
+ self.sample = sample
311
+ else:
312
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
313
+ denoised = sample - model_output * sigma_next
314
+ # 2. 2nd order / Heun's method
315
+ derivative = (sample - denoised) / sigma_next
316
+ derivative = 0.5 * (self.prev_derivative + derivative)
317
+
318
+ # 3. take prev timestep & sample
319
+ dt = self.dt
320
+ sample = self.sample
321
+
322
+ # free dt and derivative
323
+ # Note, this puts the scheduler in "first order mode"
324
+ self.prev_derivative = None
325
+ self.dt = None
326
+ self.sample = None
327
+
328
+ # original sample way
329
+ # prev_sample = sample + derivative * dt
330
+
331
+ dx = derivative * dt
332
+ m = dx.mean()
333
+ dx_ = (dx - m) * omega + m
334
+ prev_sample = sample + dx_
335
+
336
+ # Cast sample back to model compatible dtype
337
+ prev_sample = prev_sample.to(model_output.dtype)
338
+
339
+ # upon completion increase step index by one
340
+ self._step_index += 1
341
+
342
+ if not return_dict:
343
+ return (prev_sample,)
344
+
345
+ return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)
346
+
347
+ def __len__(self):
348
+ return self.config.num_train_timesteps
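
A minimal sketch of how `set_timesteps` and `step` above fit together in a denoising loop (not part of this commit): the constructor arguments, latent shape, dummy denoiser, and `omega` value are placeholder assumptions, and the real pipeline drives the scheduler elsewhere in the repo.

```python
import torch

# Hedged sketch, assuming the scheduler registers num_train_timesteps/shift in its
# config, as the set_timesteps code above implies.
scheduler = FlowMatchHeunDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
scheduler.set_timesteps(num_inference_steps=60, device="cpu")


def denoiser(x, t):
    # Stand-in for the real model; returns a velocity-like tensor of the same shape.
    return torch.randn_like(x)


latents = torch.randn(1, 8, 16, 256)  # hypothetical latent shape

for t in scheduler.timesteps:
    model_output = denoiser(latents, t)
    # All timesteps except the first appear twice in scheduler.timesteps, so Heun spends
    # two model evaluations per step: the first call caches the derivative
    # (state_in_first_order), the second applies the second-order update.
    latents = scheduler.step(model_output, t, latents, omega=10.0).prev_sample
```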
ui/components.py ADDED
@@ -0,0 +1,244 @@
1
+ import gradio as gr
2
+
3
+
4
+ TAG_PLACEHOLDER = "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic"
5
+ LYRIC_PLACEHOLDER = """[verse]
6
+ Neon lights they flicker bright
7
+ City hums in dead of night
8
+ Rhythms pulse through concrete veins
9
+ Lost in echoes of refrains
10
+
11
+ [verse]
12
+ Bassline groovin' in my chest
13
+ Heartbeats match the city's zest
14
+ Electric whispers fill the air
15
+ Synthesized dreams everywhere
16
+
17
+ [chorus]
18
+ Turn it up and let it flow
19
+ Feel the fire let it grow
20
+ In this rhythm we belong
21
+ Hear the night sing out our song
22
+
23
+ [verse]
24
+ Guitar strings they start to weep
25
+ Wake the soul from silent sleep
26
+ Every note a story told
27
+ In this night we’re bold and gold
28
+
29
+ [bridge]
30
+ Voices blend in harmony
31
+ Lost in pure cacophony
32
+ Timeless echoes timeless cries
33
+ Soulful shouts beneath the skies
34
+
35
+ [verse]
36
+ Keyboard dances on the keys
37
+ Melodies on evening breeze
38
+ Catch the tune and hold it tight
39
+ In this moment we take flight
40
+ """
41
+
42
+
43
+ def create_output_ui(task_name="Text2Music"):
44
+ # On many consumer-grade GPUs, only a single batch can be generated at a time
45
+ output_audio1 = gr.Audio(type="filepath", label=f"{task_name} Generated Audio 1")
46
+ # output_audio2 = gr.Audio(type="filepath", label="Generated Audio 2")
47
+ with gr.Accordion(f"{task_name} Parameters", open=False):
48
+ input_params_json = gr.JSON(label=f"{task_name} Parameters")
49
+ # outputs = [output_audio1, output_audio2]
50
+ outputs = [output_audio1]
51
+ return outputs, input_params_json
52
+
53
+
54
+ def dump_func(*args):
55
+ print(args)
56
+ return []
57
+
58
+
59
+ def create_text2music_ui(
60
+ gr,
61
+ text2music_process_func,
62
+ sample_data_func=None,
63
+ ):
64
+ with gr.Row():
65
+ with gr.Column():
66
+
67
+ with gr.Row(equal_height=True):
68
+ audio_duration = gr.Slider(-1, 240.0, step=0.00001, value=180, label="Audio Duration", interactive=True, info="-1 means random duration (30 ~ 240).", scale=9)
69
+ sample_bnt = gr.Button("Sample", variant="primary", scale=1)
70
+
71
+ prompt = gr.Textbox(lines=2, label="Tags", max_lines=4, placeholder=TAG_PLACEHOLDER, info="Supports tags, descriptions, and scene text. Use commas to separate different tags.")
72
+ lyrics = gr.Textbox(lines=9, label="Lyrics", max_lines=13, placeholder=LYRIC_PLACEHOLDER, info="Supports lyric structure tags such as [verse], [chorus], and [bridge] to separate different parts of the lyrics.\nUse [instrumental] or [inst] to generate instrumental music. Genre tags are not supported inside the lyrics.")
73
+
74
+ with gr.Accordion("Basic Settings", open=True):
75
+ infer_step = gr.Slider(minimum=1, maximum=1000, step=1, value=60, label="Infer Steps", interactive=True)
76
+ guidance_scale = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=15.0, label="Guidance Scale", interactive=True, info="When guidance_scale_lyric > 1 and guidance_scale_text > 1, the guidance scale will not be applied.")
77
+ guidance_scale_text = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=5.0, label="Guidance Scale Text", interactive=True, info="Guidance scale for the text condition. It only applies to cfg. Set guidance_scale_text=5.0 and guidance_scale_lyric=1.5 as a starting point.")
78
+ guidance_scale_lyric = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=1.5, label="Guidance Scale Lyric", interactive=True)
79
+
80
+ manual_seeds = gr.Textbox(label="manual seeds (default None)", placeholder="1,2,3,4", value=None, info="Seed for the generation")
81
+
82
+ with gr.Accordion("Advanced Settings", open=False):
83
+ scheduler_type = gr.Radio(["euler", "heun"], value="euler", label="Scheduler Type", elem_id="scheduler_type", info="Scheduler type for the generation. euler is recommended. heun will take more time.")
84
+ cfg_type = gr.Radio(["cfg", "apg", "cfg_star"], value="apg", label="CFG Type", elem_id="cfg_type", info="CFG type for the generation. apg is recommended. cfg and cfg_star are almost the same.")
85
+ use_erg_tag = gr.Checkbox(label="use ERG for tag", value=True, info="Use Entropy Rectifying Guidance for the tag condition. It applies a temperature to the attention weights to weaken the tag condition and improve diversity.")
86
+ use_erg_lyric = gr.Checkbox(label="use ERG for lyric", value=True, info="The same, but applied to the lyric encoder's attention.")
87
+ use_erg_diffusion = gr.Checkbox(label="use ERG for diffusion", value=True, info="The same, but applied to the diffusion model's attention.")
88
+
89
+ omega_scale = gr.Slider(minimum=-100.0, maximum=100.0, step=0.1, value=10.0, label="Granularity Scale", interactive=True, info="Granularity scale for the generation. Higher values can reduce artifacts")
90
+
91
+ guidance_interval = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.5, label="Guidance Interval", interactive=True, info="Guidance interval for the generation. 0.5 means only apply guidance in the middle steps (0.25 * infer_steps to 0.75 * infer_steps)")
92
+ guidance_interval_decay = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, label="Guidance Interval Decay", interactive=True, info="Guidance interval decay for the generation. Guidance scale will decay from guidance_scale to min_guidance_scale in the interval. 0.0 means no decay.")
93
+ min_guidance_scale = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=3.0, label="Min Guidance Scale", interactive=True, info="Minimum guidance scale at the end of the guidance interval decay.")
94
+ oss_steps = gr.Textbox(label="OSS Steps", placeholder="16, 29, 52, 96, 129, 158, 172, 183, 189, 200", value=None, info="Optimal steps for the generation, but not thoroughly tested.")
95
+
96
+ text2music_bnt = gr.Button(variant="primary")
97
+
98
+ with gr.Column():
99
+ outputs, input_params_json = create_output_ui()
100
+ with gr.Tab("retake"):
101
+ retake_variance = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2, label="variance", info="Variance for the retake. 0.0 means no variance. 1.0 means full variance.")
102
+ retake_seeds = gr.Textbox(label="retake seeds (default None)", placeholder="", value=None, info="Seed for the retake.")
103
+ retake_bnt = gr.Button(variant="primary")
104
+ retake_outputs, retake_input_params_json = create_output_ui("Retake")
105
+
106
+ def retake_process_func(json_data, retake_variance, retake_seeds):
107
+ return text2music_process_func(
108
+ json_data["audio_duration"],
109
+ json_data["prompt"],
110
+ json_data["lyrics"],
111
+ json_data["infer_step"],
112
+ json_data["guidance_scale"],
113
+ json_data["scheduler_type"],
114
+ json_data["cfg_type"],
115
+ json_data["omega_scale"],
116
+ ", ".join(map(str, json_data["actual_seeds"])),
117
+ json_data["guidance_interval"],
118
+ json_data["guidance_interval_decay"],
119
+ json_data["min_guidance_scale"],
120
+ json_data["use_erg_tag"],
121
+ json_data["use_erg_lyric"],
122
+ json_data["use_erg_diffusion"],
123
+ ", ".join(map(str, json_data["oss_steps"])),
124
+ json_data["guidance_scale_text"] if "guidance_scale_text" in json_data else 0.0,
125
+ json_data["guidance_scale_lyric"] if "guidance_scale_lyric" in json_data else 0.0,
126
+ retake_seeds=retake_seeds,
127
+ retake_variance=retake_variance,
128
+ task="retake",
129
+ )
130
+
131
+ retake_bnt.click(
132
+ fn=retake_process_func,
133
+ inputs=[
134
+ input_params_json,
135
+ retake_variance,
136
+ retake_seeds,
137
+ ],
138
+ outputs=retake_outputs + [retake_input_params_json],
139
+ )
140
+ with gr.Tab("repainting"):
141
+ pass
142
+ with gr.Tab("edit"):
143
+ pass
144
+
145
+ def sample_data():
146
+ json_data = sample_data_func()
147
+ return (
148
+ json_data["audio_duration"],
149
+ json_data["prompt"],
150
+ json_data["lyrics"],
151
+ json_data["infer_step"],
152
+ json_data["guidance_scale"],
153
+ json_data["scheduler_type"],
154
+ json_data["cfg_type"],
155
+ json_data["omega_scale"],
156
+ ", ".join(map(str, json_data["actual_seeds"])),
157
+ json_data["guidance_interval"],
158
+ json_data["guidance_interval_decay"],
159
+ json_data["min_guidance_scale"],
160
+ json_data["use_erg_tag"],
161
+ json_data["use_erg_lyric"],
162
+ json_data["use_erg_diffusion"],
163
+ ", ".join(map(str, json_data["oss_steps"])),
164
+ json_data["guidance_scale_text"] if "guidance_scale_text" in json_data else 0.0,
165
+ json_data["guidance_scale_lyric"] if "guidance_scale_lyric" in json_data else 0.0,
166
+ )
167
+
168
+ sample_bnt.click(
169
+ sample_data,
170
+ outputs=[
171
+ audio_duration,
172
+ prompt,
173
+ lyrics,
174
+ infer_step,
175
+ guidance_scale,
176
+ scheduler_type,
177
+ cfg_type,
178
+ omega_scale,
179
+ manual_seeds,
180
+ guidance_interval,
181
+ guidance_interval_decay,
182
+ min_guidance_scale,
183
+ use_erg_tag,
184
+ use_erg_lyric,
185
+ use_erg_diffusion,
186
+ oss_steps,
187
+ guidance_scale_text,
188
+ guidance_scale_lyric,
189
+ ],
190
+ )
191
+
192
+ text2music_bnt.click(
193
+ fn=text2music_process_func,
194
+ inputs=[
195
+ audio_duration,
196
+ prompt,
197
+ lyrics,
198
+ infer_step,
199
+ guidance_scale,
200
+ scheduler_type,
201
+ cfg_type,
202
+ omega_scale,
203
+ manual_seeds,
204
+ guidance_interval,
205
+ guidance_interval_decay,
206
+ min_guidance_scale,
207
+ use_erg_tag,
208
+ use_erg_lyric,
209
+ use_erg_diffusion,
210
+ oss_steps,
211
+ guidance_scale_text,
212
+ guidance_scale_lyric,
213
+ ], outputs=outputs + [input_params_json]
214
+ )
215
+
216
+
217
+ def create_main_demo_ui(
218
+ text2music_process_func=dump_func,
219
+ sample_data_func=dump_func,
220
+ ):
221
+ with gr.Blocks(
222
+ title="FusicModel 1.0 DEMO",
223
+ ) as demo:
224
+ gr.Markdown(
225
+ """
226
+ <h1 style="text-align: center;">FusicModel 1.0 DEMO</h1>
227
+ """
228
+ )
229
+
230
+ with gr.Tab("text2music"):
231
+ create_text2music_ui(
232
+ gr=gr,
233
+ text2music_process_func=text2music_process_func,
234
+ sample_data_func=sample_data_func,
235
+ )
236
+ return demo
237
+
238
+
239
+ if __name__ == "__main__":
240
+ demo = create_main_demo_ui()
241
+ demo.launch(
242
+ server_name="0.0.0.0",
243
+ server_port=7860,
244
+ )