[feat] v1 commit
- .gitignore +164 -0
- apg_guidance.py +90 -0
- app.py +39 -0
- data_sampler.py +23 -0
- examples/input_params/output_20250426071706_0_input_params.json +25 -0
- examples/input_params/output_20250426071812_0_input_params.json +25 -0
- examples/input_params/output_20250426072346_0_input_params.json +25 -0
- examples/input_params/output_20250426072508_0_input_params.json +25 -0
- examples/input_params/output_20250426073829_0_input_params.json +25 -0
- examples/input_params/output_20250426074037_0_input_params.json +25 -0
- examples/input_params/output_20250426074214_0_input_params.json +25 -0
- examples/input_params/output_20250426074413_0_input_params.json +25 -0
- examples/input_params/output_20250426075107_0_input_params.json +25 -0
- examples/input_params/output_20250426075537_0_input_params.json +25 -0
- examples/input_params/output_20250426075843_0_input_params.json +25 -0
- examples/input_params/output_20250426080234_0_input_params.json +25 -0
- examples/input_params/output_20250426080407_0_input_params.json +25 -0
- examples/input_params/output_20250426080601_0_input_params.json +25 -0
- examples/input_params/output_20250426081134_0_input_params.json +25 -0
- examples/input_params/output_20250426091716_0_input_params.json +25 -0
- examples/input_params/output_20250426092025_0_input_params.json +25 -0
- examples/input_params/output_20250426093007_0_input_params.json +25 -0
- examples/input_params/output_20250426093146_0_input_params.json +25 -0
- language_segmentation/LangSegment.py +866 -0
- language_segmentation/__init__.py +9 -0
- language_segmentation/utils/__init__.py +0 -0
- language_segmentation/utils/num.py +327 -0
- models/ace_step_transformer.py +475 -0
- models/attention.py +319 -0
- models/config.json +23 -0
- models/customer_attention_processor.py +339 -0
- models/lyrics_utils/lyric_encoder.py +1070 -0
- models/lyrics_utils/lyric_normalizer.py +66 -0
- models/lyrics_utils/lyric_tokenizer.py +883 -0
- models/lyrics_utils/vocab.json +0 -0
- models/lyrics_utils/zh_num2words.py +1209 -0
- music_dcae/__init__.py +0 -0
- music_dcae/music_dcae_pipeline.py +150 -0
- music_dcae/music_log_mel.py +107 -0
- music_dcae/music_vocoder.py +576 -0
- packages.txt +1 -0
- pipeline_ace_step.py +735 -0
- requirements.txt +22 -0
- schedulers/scheduling_flow_match_euler_discrete.py +394 -0
- schedulers/scheduling_flow_match_heun_discrete.py +348 -0
- ui/components.py +244 -0
.gitignore
ADDED
@@ -0,0 +1,164 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

checkpoints/
apg_guidance.py
ADDED
@@ -0,0 +1,90 @@
import torch


class MomentumBuffer:
    def __init__(self, momentum: float = -0.75):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average


def project(
    v0: torch.Tensor,  # [B, C, H, W]
    v1: torch.Tensor,  # [B, C, H, W]
    dims=[-1, -2],
):
    dtype = v0.dtype
    v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=dims)
    v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(dtype), v0_orthogonal.to(dtype)


def apg_forward(
    pred_cond: torch.Tensor,  # [B, C, H, W]
    pred_uncond: torch.Tensor,  # [B, C, H, W]
    guidance_scale: float,
    momentum_buffer: MomentumBuffer = None,
    eta: float = 0.0,
    norm_threshold: float = 2.5,
    dims=[-1, -2],
):
    diff = pred_cond - pred_uncond
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=dims, keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor

    diff_parallel, diff_orthogonal = project(diff, pred_cond, dims)
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    return pred_guided


def cfg_forward(cond_output, uncond_output, cfg_strength):
    return uncond_output + cfg_strength * (cond_output - uncond_output)


def cfg_double_condition_forward(
    cond_output,
    uncond_output,
    only_text_cond_output,
    guidance_scale_text,
    guidance_scale_lyric,
):
    return (1 - guidance_scale_text) * uncond_output + (guidance_scale_text - guidance_scale_lyric) * only_text_cond_output + guidance_scale_lyric * cond_output


def optimized_scale(positive_flat, negative_flat):

    # Calculate the dot product
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)

    # Squared norm of the unconditional prediction
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8

    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    st_star = dot_product / squared_norm

    return st_star


def cfg_zero_star(noise_pred_with_cond, noise_pred_uncond, guidance_scale, i, zero_steps=1, use_zero_init=True):
    bsz = noise_pred_with_cond.shape[0]
    positive_flat = noise_pred_with_cond.view(bsz, -1)
    negative_flat = noise_pred_uncond.view(bsz, -1)
    alpha = optimized_scale(positive_flat, negative_flat)
    alpha = alpha.view(bsz, 1, 1, 1)
    if (i <= zero_steps) and use_zero_init:
        noise_pred = noise_pred_with_cond * 0.
    else:
        noise_pred = noise_pred_uncond * alpha + guidance_scale * (noise_pred_with_cond - noise_pred_uncond * alpha)
    return noise_pred
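For orientation, a minimal sketch of how these guidance helpers could be driven from a denoising loop. The tensor shapes, the three-step loop, and the concrete settings are illustrative assumptions only (guidance_scale=15 simply mirrors the example JSONs below); this is not code from the commit.

from apg_guidance import MomentumBuffer, apg_forward, cfg_forward
import torch

# Hypothetical latents: batch of 2, 8 channels, a 16 x 256 latent grid.
pred_cond = torch.randn(2, 8, 16, 256)
pred_uncond = torch.randn(2, 8, 16, 256)

# One MomentumBuffer is kept alive across steps so the cond/uncond difference
# is smoothed by the running average (note the negative momentum default).
buffer = MomentumBuffer(momentum=-0.75)

for step in range(3):  # stand-in for the real scheduler loop
    guided = apg_forward(
        pred_cond=pred_cond,
        pred_uncond=pred_uncond,
        guidance_scale=15.0,
        momentum_buffer=buffer,
        eta=0.0,
        norm_threshold=2.5,
    )
    # Plain CFG for comparison: uncond + scale * (cond - uncond).
    baseline = cfg_forward(pred_cond, pred_uncond, 15.0)
    print(step, guided.shape, baseline.shape)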
app.py
ADDED
@@ -0,0 +1,39 @@
1 |
+
import argparse
|
2 |
+
from ui.components import create_main_demo_ui
|
3 |
+
from pipeline_ace_step import ACEStepPipeline
|
4 |
+
from data_sampler import DataSampler
|
5 |
+
import os
|
6 |
+
|
7 |
+
|
8 |
+
parser = argparse.ArgumentParser()
|
9 |
+
parser.add_argument("--checkpoint_path", type=str, default=None)
|
10 |
+
parser.add_argument("--port", type=int, default=7860)
|
11 |
+
parser.add_argument("--device_id", type=int, default=0)
|
12 |
+
parser.add_argument("--share", action='store_true', default=False)
|
13 |
+
parser.add_argument("--bf16", action='store_true', default=True)
|
14 |
+
|
15 |
+
args = parser.parse_args()
|
16 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
|
17 |
+
|
18 |
+
|
19 |
+
persistent_storage_path = "/data"
|
20 |
+
|
21 |
+
|
22 |
+
def main(args):
|
23 |
+
|
24 |
+
model_demo = ACEStepPipeline(
|
25 |
+
checkpoint_dir=args.checkpoint_path,
|
26 |
+
dtype="bfloat16" if args.bf16 else "float32",
|
27 |
+
persistent_storage_path=persistent_storage_path
|
28 |
+
)
|
29 |
+
data_sampler = DataSampler()
|
30 |
+
|
31 |
+
demo = create_main_demo_ui(
|
32 |
+
text2music_process_func=model_demo.__call__,
|
33 |
+
sample_data_func=data_sampler.sample,
|
34 |
+
)
|
35 |
+
demo.launch()
|
36 |
+
|
37 |
+
|
38 |
+
if __name__ == "__main__":
|
39 |
+
main(args)
|
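Given the argparse flags above, a local launch would look something like `python app.py --checkpoint_path /path/to/checkpoints --bf16` (the checkpoint path is a placeholder). Note that `--port` and `--share` are parsed but, as written here, are not forwarded to `demo.launch()`.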
data_sampler.py
ADDED
@@ -0,0 +1,23 @@
import json
from pathlib import Path
import random


DEFAULT_ROOT_DIR = "examples/input_params"


class DataSampler:
    def __init__(self, root_dir=DEFAULT_ROOT_DIR):
        self.root_dir = root_dir

        # glob
        self.input_params_files = list(Path(self.root_dir).glob("*.json"))

    def load_json(self, file_path):
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def sample(self):
        json_path = random.choice(self.input_params_files)
        json_data = self.load_json(json_path)
        return json_data
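A minimal usage sketch, assuming it is run from the repository root so the default `examples/input_params` directory resolves:

from data_sampler import DataSampler

sampler = DataSampler()        # collects examples/input_params/*.json
params = sampler.sample()      # one randomly chosen input-parameter dict
print(params["prompt"], params["audio_duration"])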
examples/input_params/output_20250426071706_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "pop, rap, electronic, blues, hip-house, rhythm and blues",
"lyrics": "[verse]\n我走过深夜的街道\n冷风吹乱思念的漂亮外套\n你的微笑像星光很炫耀\n照亮了我孤独的每分每秒\n\n[chorus]\n愿你是风吹过我的脸\n带我飞过最远最遥远的山间\n愿你是风轻触我的梦\n停在心头不再飘散无迹无踪\n\n[verse]\n一起在喧哗避开世俗的骚动\n独自在天台探望月色的朦胧\n你说爱像音乐带点重节奏\n一拍一跳让我忘了心的温度多空洞\n\n[bridge]\n唱起对你的想念不隐藏\n像诗又像画写满藏不了的渴望\n你的影子挥不掉像风的倔强\n追着你飞扬穿越云海一样泛光\n\n[chorus]\n愿你是风吹过我的手\n暖暖的触碰像春日细雨温柔\n愿你是风盘绕我的身\n深情万万重不会有一天走远走\n\n[verse]\n深夜的钢琴弹起动人的旋律\n低音鼓砸进心底的每一次呼吸\n要是能将爱化作歌声传递\n你是否会听见我心里的真心实意",
"audio_duration": 170.63997916666668,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 3.191075086593628,
"diffusion": 17.459356784820557,
"latent2audio": 1.7095518112182617
},
"actual_seeds": [
3299954530
]
}
examples/input_params/output_20250426071812_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "country rock, folk rock, southern rock, bluegrass, country pop",
"lyrics": "[verse]\nWoke up to the sunrise glow\nTook my heart and hit the road\nWheels hummin' the only tune I know\nStraight to where the wildflowers grow\n\n[verse]\nGot that old map all wrinkled and torn\nDestination unknown but I'm reborn\nWith a smile that the wind has worn\nChasin' dreams that can't be sworn\n\n[chorus]\nRidin' on a highway to sunshine\nGot my shades and my radio on fine\nLeave the shadows in the rearview rhyme\nHeart's racing as we chase the time\n\n[verse]\nMet a girl with a heart of gold\nTold stories that never get old\nHer laugh like a tale that's been told\nA melody so bold yet uncontrolled\n\n[bridge]\nClouds roll by like silent ghosts\nAs we drive along the coast\nWe toast to the days we love the most\nFreedom's song is what we post\n\n[chorus]\nRidin' on a highway to sunshine\nGot my shades and my radio on fine\nLeave the shadows in the rearview rhyme\nHeart's racing as we chase the time",
"audio_duration": 224.23997916666667,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 4.262240648269653,
"diffusion": 15.380569219589233,
"latent2audio": 2.3227272033691406
},
"actual_seeds": [
401640
]
}
examples/input_params/output_20250426072346_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "hip-house, funk",
"lyrics": "[verse]\n哎呀跳起来,脚尖踩节拍 (oo-yeah!)\n灯光闪烁像星星盛开 (uh-huh!)\n人人都醒来,把烦恼踹开 (get it!)\n热血沸腾,汗水自己安排\n\n[chorus]\n嘿,你还等啥?快抓住节拍 (come on!)\n光芒指引,让心都不存在 (whoa!)\n点燃热火,我们一起飙high (let’s go!)\n跳入午夜的狂欢时代\n\n[bridge]\n咚咚鼓声啊,让你的灵魂起飞 (woo!)\n手心拍一拍,能量翻倍 (ah-hah!)\n键盘响起来,如宇宙的交汇 (oh yeah!)\n就是这感觉,兄弟姐妹都陶醉\n\n[verse]\n灵魂从不睡,只想继续燃烧 (woo!)\n节奏像热浪,席卷这街道 (ow!)\n大伙儿涌上楼台,满面微笑 (yeah!)\n这一刻属于我们,无可替代\n\n[chorus]\n嘿,你还等啥?快抓住节拍 (come on!)\n光芒指引,让心都不存在 (whoa!)\n点燃热火,我们一起飙high (let’s go!)\n跳入午夜的狂欢时代\n\n[verse]\n世界多精彩,握紧把它打开 (alright!)\n每一步都像星球在摇摆 (uh-huh!)\n无边无际的律动像大海 (oo-yeah!)\n跟着光芒之舞,一起澎湃",
"audio_duration": 204.19997916666668,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.05196118354797363,
"diffusion": 15.530808210372925,
"latent2audio": 2.5604095458984375
},
"actual_seeds": [
401640
]
}
examples/input_params/output_20250426072508_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic",
"lyrics": "[verse]\nNeon lights they flicker bright\nCity hums in dead of night\nRhythms pulse through concrete veins\nLost in echoes of refrains\n\n[verse]\nBassline groovin' in my chest\nHeartbeats match the city's zest\nElectric whispers fill the air\nSynthesized dreams everywhere\n\n[chorus]\nTurn it up and let it flow\nFeel the fire let it grow\nIn this rhythm we belong\nHear the night sing out our song\n\n[verse]\nGuitar strings they start to weep\nWake the soul from silent sleep\nEvery note a story told\nIn this night we’re bold and gold\n\n[bridge]\nVoices blend in harmony\nLost in pure cacophony\nTimeless echoes timeless cries\nSoulful shouts beneath the skies\n\n[verse]\nKeyboard dances on the keys\nMelodies on evening breeze\nCatch the tune and hold it tight\nIn this moment we take flight",
"audio_duration": 178.87997916666666,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.02882218360900879,
"diffusion": 16.91233205795288,
"latent2audio": 1.7794082164764404
},
"actual_seeds": [
401640
]
}
examples/input_params/output_20250426073829_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "electronic rap",
"lyrics": "[verse]\nWaves on the bass, pulsing in the speakers,\nTurn the dial up, we chasing six-figure features,\nGrinding on the beats, codes in the creases,\nDigital hustler, midnight in sneakers.\n\n[chorus]\nElectro vibes, hearts beat with the hum,\nUrban legends ride, we ain't ever numb,\nCircuits sparking live, tapping on the drum,\nLiving on the edge, never succumb.\n\n[verse]\nSynthesizers blaze, city lights a glow,\nRhythm in the haze, moving with the flow,\nSwagger on stage, energy to blow,\nFrom the blocks to the booth, you already know.\n\n[bridge]\nNight's electric, streets full of dreams,\nBass hits collective, bursting at seams,\nHustle perspective, all in the schemes,\nRise and reflective, ain't no in-betweens.\n\n[verse]\nVibin' with the crew, sync in the wire,\nGot the dance moves, fire in the attire,\nRhythm and blues, soul's our supplier,\nRun the digital zoo, higher and higher.\n\n[chorus]\nElectro vibes, hearts beat with the hum,\nUrban legends ride, we ain't ever numb,\nCircuits sparking live, tapping on the drum,\nLiving on the edge, never succumb.",
"audio_duration": 221.42547916666666,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.024875164031982422,
"diffusion": 20.566852569580078,
"latent2audio": 2.2281734943389893
},
"actual_seeds": [
401640
]
}
examples/input_params/output_20250426074037_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "electronic, house, electro house, synthesizer, drums, bass, percussion, fast, energetic, uplifting, exciting",
"lyrics": "[verse]\n霓虹灯下我们追逐\n人群跃动像潮水满布\n热浪袭来吹散孤独\n跳进节奏不如停下脚步\n\n[pre-chorus]\n脚尖触电快点感受\n迎着风声释放自由\n心跳节拍配合节奏\n一切烦恼请靠边游\n\n[chorus]\n夏夜狂奔没有尽头\n星光闪烁舞池不朽\n尽情挥洒所有节奏\n无边热情把你包裹哦\n\n[verse]\n天空翻滚黑云入夜\n每颗星星像音乐律贴\n耳边回响那低音线\n环绕耳际如梦境般甜\n\n[pre-chorus]\n脚尖触电快点感受\n迎着风声释放自由\n心跳节拍配合节奏\n一切烦恼请靠边游\n\n[chorus]\n夏夜狂奔没有尽头\n星光闪烁舞池不朽\n尽情挥洒所有节奏\n无边热情把你包裹哦",
"audio_duration": 221.47997916666668,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.028400182723999023,
"diffusion": 13.195815324783325,
"latent2audio": 2.1679723262786865
},
"actual_seeds": [
3440445703
]
}
examples/input_params/output_20250426074214_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "synth-pop, electronic, pop, synthesizer, drums, bass, piano, 128 BPM, energetic, uplifting, modern",
"lyrics": "[verse]\nWoke up in a city that's always alive\nNeon lights they shimmer they thrive\nElectric pulses beat they drive\nMy heart races just to survive\n\n[chorus]\nOh electric dreams they keep me high\nThrough the wires I soar and fly\nMidnight rhythms in the sky\nElectric dreams together we’ll defy\n\n[verse]\nLost in the labyrinth of screens\nVirtual love or so it seems\nIn the night the city gleams\nDigital faces haunted by memes\n\n[chorus]\nOh electric dreams they keep me high\nThrough the wires I soar and fly\nMidnight rhythms in the sky\nElectric dreams together we’ll defy\n\n[bridge]\nSilent whispers in my ear\nPixelated love serene and clear\nThrough the chaos find you near\nIn electric dreams no fear\n\n[verse]\nBound by circuits intertwined\nLove like ours is hard to find\nIn this world we’re truly blind\nBut electric dreams free the mind",
"audio_duration": 221.27997916666666,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.025463581085205078,
"diffusion": 15.243804454803467,
"latent2audio": 2.170398473739624
},
"actual_seeds": [
3400270027
]
}
examples/input_params/output_20250426074413_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "Cuban music, salsa, son, Afro-Cuban, traditional Cuban",
"lyrics": "[verse]\nSun dips low the night ignites\nBassline hums with gleaming lights\nElectric guitar singing tales so fine\nIn the rhythm we all intertwine\n\n[verse]\nDrums beat steady calling out\nPercussion guides no room for doubt\nElectric pulse through every vein\nDance away every ounce of pain\n\n[chorus]\nFeel the rhythm feel the flow\nLet the music take control\nBassline deep electric hum\nIn this night we're never numb\n\n[bridge]\nStars above they start to glow\nEchoes of the night's soft glow\nElectric strings weave through the air\nIn this moment none compare\n\n[verse]\nHeartbeats sync with every tone\nLost in music never alone\nElectric tales of love and peace\nIn this groove we find release\n\n[chorus]\nFeel the rhythm feel the flow\nLet the music take control\nBassline deep electric hum\nIn this night we're never numb",
"audio_duration": 208.27997916666666,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.026132583618164062,
"diffusion": 15.139378070831299,
"latent2audio": 2.2071540355682373
},
"actual_seeds": [
3358899399
]
}
examples/input_params/output_20250426075107_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "pop, piano, rap, dark, atmospheric",
"lyrics": "[verse]\n月光爬上窗 染白冷的床\n心跳的方向 带我入迷惘\n黑夜吞噬光 命运的纸张\n爱是血色霜 邪恶又芬芳\n\n[chorus]\n你是猎人的欲望 我是迷途的小羊\n深陷你眼眸的荒 唐突献出心脏\n我在夜里回荡 是谁给我希望\n黑暗风中飘荡 假装不再受伤\n\n[verse]\n心锁在门外 谁会解开关怀\n温柔的手拍 藏着冷酷杀害\n思绪如尘埃 撞击爱的霹雳\n灵魂的独白 为你沾满血迹\n\n[bridge]\n你是噩梦的歌唱 是灵魂的捆绑\n绝望中带着光 悬崖边的渴望\n心跳被你鼓掌 恶魔也痴痴想\n渐渐没了抵抗 古老诡计流淌\n\n[chorus]\n你是猎人的欲望 我是迷途的小羊\n深陷你眼眸的荒 唐突献出心脏\n我在夜里回荡 是谁给我希望\n黑暗风中飘荡 假装不再受伤\n\n[outro]\n爱如月黑无光 渗进梦的战场\n逃入无声的场 放手或心嚷嚷\n隐秘的极端 爱是极致风浪\n灵魂彻底交偿 你是终极虚妄",
"audio_duration": 146.91997916666668,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.03876018524169922,
"diffusion": 15.962624549865723,
"latent2audio": 1.4594337940216064
},
"actual_seeds": [
2065110378
]
}
examples/input_params/output_20250426075537_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "surf music",
"lyrics": "[verse]\nSunshine on the boulevard the beach is calling loud\nWaves are dancing golden sand under a cotton cloud\nElectric heartbeat pounding fast the tide is on our side\nCatch a wave and feel alive we’ll take it for a ride\n\n[verse]\nPalm trees swaying left to right they know where we belong\nFeel the rhythm of the night it keeps us moving strong\nSea spray kisses salty air we’re flying with the breeze\nChampagne states of mind we ride we do just as we please\n\n[chorus]\nWe’re riding waves of life together hand in hand\nWith every beat we chase the beat it’s our own wonderland\nFeel the music take you higher as the shorelines blur\nThis is our world our endless summer as we live and learn\n\n[bridge]\nMoonlight paints the ocean blue reflections in our eyes\nStars align to light our path we’re surfing through the skies\nEvery moment like a song we sing it loud and clear\nEvery day’s a new adventure with you always near\n\n[verse]\nNeon lights and city sounds they blend with ocean views\nWe’re unstoppable tonight no way that we can lose\nDreams are written in the sand they sparkle in the sun\nTogether we’re a masterpiece our story’s just begun\n\n[chorus]\nWe’re riding waves of life together hand in hand\nWith every beat we chase the beat it’s our own wonderland\nFeel the music take you higher as the shorelines blur\nThis is our world our endless summer as we live and learn",
"audio_duration": 236.55997916666666,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.033666133880615234,
"diffusion": 16.291455507278442,
"latent2audio": 2.3726775646209717
},
"actual_seeds": [
508630535
]
}
examples/input_params/output_20250426075843_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "alternative rock, pop, rock",
"lyrics": "[verse]\nBright lights flashing in the city sky\nRunning fast and we don't know why\nElectric nights got our hearts on fire\nChasing dreams we'll never tire\n\n[verse]\nGrit in our eyes wind in our hair\nBreaking rules we don't even care\nShouting loud above the crowd\nLiving life like we're unbowed\n\n[chorus]\nRunning wild in the night so free\nFeel the beat pumping endlessly\nHearts collide in the midnight air\nWe belong we don't have a care\n\n[verse]\nPiercing through like a lightning strike\nEvery moment feels like a hike\nDaring bold never backing down\nKings and queens without a crown\n\n[chorus]\nRunning wild in the night so free\nFeel the beat pumping endlessly\nHearts collide in the midnight air\nWe belong we don't have a care\n\n[bridge]\nClose your eyes let your spirit soar\nWe are the ones who wanted more\nBreaking chains of the mundane\nIn this world we'll make our claim",
"audio_duration": 202.19997916666668,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.02512216567993164,
"diffusion": 18.860822677612305,
"latent2audio": 2.0361969470977783
},
"actual_seeds": [
1255121549
]
}
examples/input_params/output_20250426080234_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "rock, hip - hop, orchestral, bass, drums, electric guitar, piano, synthesizer, violin, viola, cello, fast, energetic, motivational, inspirational, empowering",
"lyrics": "### **[Intro – Spoken]** \n*\"The streets whisper, their echoes never fade. \nEvery step I take leaves a mark—this ain't just a game.\"* \n\n### **[Hook/Chorus]** \nBorn in the chaos, I weather the storm, \nRising from ashes where warriors are born. \nChains couldn't hold me, the system’s a maze, \nI rewrite the rules, set the city ablaze! \n\n### **[Verse 1]** \nCold nights, empty pockets, dreams laced with fight, \nEvery loss made me sharper, cut deep like a knife. \nThey said I wouldn’t make it, now they watch in despair, \nFrom the curb to the throne, took the pain, made it rare. \nEvery siren’s a melody, every alley holds a tale, \nRose from the shadows, left my name on the trail. \nStreetlights flicker like warnings in the haze, \nBut I move like a phantom, unfazed by the blaze. \n\n### **[Hook/Chorus]** \nBorn in the chaos, I weather the storm, \nRising from ashes where warriors are born. \nChains couldn't hold me, the system’s a maze, \nI rewrite the rules, set the city ablaze! \n\n### **[Verse 2]** \nBarbed wire fences couldn't lock in my mind, \nEvery cage they designed, I left broken behind. \nThey want control, but I’m destined to roam, \nWhere the lost find their voice, where the heart sets the tone. \nSteel and concrete, where the lessons run deep, \nEvery crack in the pavement tells a story of heat. \nBut I rise, undefeated, like a king with no throne, \nWriting scripts in the struggle, my legacy’s stone. \n\n### **[Bridge]** \nFeel the rhythm of the underground roar, \nEvery wound tells a story of the battles before. \nBlood, sweat, and echoes fill the cold midnight, \nBut we move with the fire—unshaken, upright. \n\n### **[Verse 3]** \nNo regrets, no retreat, this game has no pause, \nEvery step that I take is a win for the lost. \nI took lessons from hustlers, wisdom from pain, \nNow the echoes of struggle carve power in my name. \nThey built walls, but I walk through the cracks, \nTurned dirt into gold, never looked back. \nThrough the struggle we rise, through the fire we claim, \nThis is more than just music—it's life in the frame. \n\n### **[Hook/Chorus – Reprise]** \nBorn in the chaos, I weather the storm, \nRising from ashes where warriors are born. \nChains couldn't hold me, the system’s a maze, \nI rewrite the rules, set the city ablaze! \n\n### **[Outro – Spoken]** \n*\"The scars, the struggle, the grind—it’s all part of the rhythm. \nWe never break, we never fold. We rise.\"*",
"audio_duration": 153.95997916666667,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.04368758201599121,
"diffusion": 17.16369390487671,
"latent2audio": 1.5405471324920654
},
"actual_seeds": [
2659225017
]
}
examples/input_params/output_20250426080407_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "tango finlandés, campanas, disco, dark pop, electro, guitarra clásica, corridos tumba",
"lyrics": "[inst]",
"audio_duration": 162.79997916666667,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.011058568954467773,
"diffusion": 9.924944400787354,
"latent2audio": 1.6034839153289795
},
"actual_seeds": [
780297686
]
}
examples/input_params/output_20250426080601_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "Nightclubs, dance parties, workout playlists, radio broadcasts",
"lyrics": "Burning in motion, set me alight!\nEvery heartbeat turns into a fight!\nCaged in rhythm, chained in time!\nLove’s a battle— You're Mine! You're Mine!",
"audio_duration": 221.83997916666667,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.012485980987548828,
"diffusion": 14.345409154891968,
"latent2audio": 2.174558639526367
},
"actual_seeds": [
1318394052
]
}
examples/input_params/output_20250426081134_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "melancholic, world, sad, medieval, soulful",
"lyrics": "[Verse]\nIn a world so grand he roams the skies alone\nHis heart a heavy stone a tale untold\nWhispers of his past echo through the night\nA lonely dragon searching for the light\n\n[Verse 2]\nOnce a mighty force now he drifts in pain\nHis scales once shimmered now they're dark with shame\nCast out by his kin in shadows he does hide\nA haunting sorrow burns deep inside\n\n[Chorus]\nRoaming endless fields with no friend in sight\nHis roar a mournful cry beneath the moon's pale light\nTears fall like stars as he flies on his way\nA lonely dragon yearning for the break of day\n\n[Bridge]\nThe world turns cold the nights grow long\nIn his heart he carries an ancient song\nOf battles fought and love long gone\nA legend now but his soul is torn\n\n[Verse 3]\nHoping for a day he'll find a kindred soul\nTo share his pain and make him whole\nTill then he drifts a shadow in the sky\nA lonely dragon with tears in his eye\n\n[Chorus]\nRoaming endless fields with no friend in sight\nHis roar a mournful cry beneath the moon's pale light\nTears fall like stars as he flies on his way\nA lonely dragon yearning for the break of day",
"audio_duration": 239.99997916666666,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.029100656509399414,
"diffusion": 22.503791570663452,
"latent2audio": 2.3603708744049072
},
"actual_seeds": [
2166832218
]
}
examples/input_params/output_20250426091716_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "anime, cute female vocals, kawaii pop, j-pop, childish, piano, guitar, synthesizer, fast, happy, cheerful, lighthearted",
"lyrics": "[Chorus]\nねぇ、顔が赤いよ?\nどうしたの? 熱があるの?\nそれとも怒ってるの?\nねぇ、言ってよ!\n\nどうしてそんな目で見るの?\n私、悪いことした?\n何か間違えたの?\nお願い、やめて… 怖いから…\nだから、やめてよ…\n\n[Bridge]\n目を閉じて、くるっと背を向けて、\n何も見なかったフリするから、\n怒らないで… 許してよ…\n\n[Chorus]\nねぇ、顔が赤いよ?\nどうしたの? 熱があるの?\nそれとも怒ってるの?\nねぇ、言ってよ!\n\nどうしてそんな目で見るの?\n私、悪いことした?\n何か間違えたの?\nお願い、やめて… 怖いから…\nだから、やめてよ…\n\n[Bridge 2]\n待って、もし私が悪いなら、\nごめんなさいって言うから、\nアイスクリームあげるから、\nもう怒らないで?\n\nOoooh… 言ってよ!",
"audio_duration": 160,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.0282442569732666,
"diffusion": 12.104875326156616,
"latent2audio": 1.587641954421997
},
"actual_seeds": [
4028738662
]
}
examples/input_params/output_20250426092025_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "dark, death rock, metal, hardcore, electric guitar, powerful, bass, drums, 110 bpm, G major",
"lyrics": "[Verse]\nMy lovers betray me\nThe snake in my garden is hissing\nIn the air is the sweetness of roses\nAnd under my skin\nThere's a thorn\n\n[Verse 2]\nI should have known\nThat God sends his angel in shadows\nWith blood in his veins\nI watch the enemy\nGivin' me the hand of my savior\n\n[Chorus]\nAnd I can't love again\nWith the echo of your name in my head\nWith the demons in my bed\nWith the memories\nYour ghost\nI see it\n'Cause it comes to haunt me\nJust to taunt me\nIt comes to haunt me\nJust to taunt me\n\n[Verse 3]\nWith sugar and spice\nIt's hard to ignore the nostalgia\nWith the men on their knees\nAt the gates of my heart\nHow they beg me\n\n[Verse 4]\nThey say\n\"No one will ever love you\nThe way that I do\nNo one will ever touch you\nThe way that I do\"\n\n[Chorus]\nAnd I can't love again\nWith the echo of your name in my head\nWith the demons in my bed\nWith the memories\nYour ghost\nI see it\n'Cause it comes to haunt me\nJust to taunt me\nIt comes to haunt me\nJust to taunt me",
"audio_duration": 174.27997916666666,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 3.8372838497161865,
"diffusion": 13.039669275283813,
"latent2audio": 1.7923030853271484
},
"actual_seeds": [
4064916393
]
}
examples/input_params/output_20250426093007_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "aggressive, Heavy Riffs, Blast Beats, Satanic Black Metal",
"lyrics": "[verse]\nFloating through the galaxy on a midnight ride\nStars are dancing all around in cosmic tides\nFeel the pulse of space and time beneath our feet\nEvery beat a heartbeat in this endless suite\n\n[chorus]\nGalactic dreams under neon lights\nSailing through the velvet nights\nWe are echoes in a cosmic sea\nIn a universe where we are free\n\n[verse]\nPlanetary whispers in the sky tonight\nEvery constellation's got a secret sight\nDistant worlds and moons we have yet to see\nIn the void of space where we can just be\n\n[bridge]\nAsteroids and comets in a ballet they spin\nLost in the rhythm of where our dreams begin\nClose your eyes and let the synths take flight\nWe're voyagers on an electric night\n\n[verse]\nLet the piano keys unlock the stars above\nEvery chord a memory every note is love\nIn this synth symphony we find our grace\nDrifting forever in this boundless space\n\n[chorus]\nGalactic dreams under neon lights\nSailing through the velvet nights\nWe are echoes in a cosmic sea\nIn a universe where we are free",
"audio_duration": 181.99997916666666,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.025065898895263672,
"diffusion": 17.176705837249756,
"latent2audio": 1.8225171566009521
},
"actual_seeds": [
1132623236
]
}
examples/input_params/output_20250426093146_0_input_params.json
ADDED
@@ -0,0 +1,25 @@
{
"prompt": "r&b, soul, funk/soul",
"lyrics": "[verse]\nDancing through electric fires\nHeart is buzzing like live wires\nIn your arms I find desire\nFeel the beat as we get higher\n\n[chorus]\nElectric love in the night sky\nWe’re gonna soar baby you and I\nDrop the bass let the rhythm fly\nFeel the heat and don't ask why\n\n[verse]\nWhisper secrets that make me blush\nUnder the neon city hush\nYour touch gives me such a rush\nTurn it up we're feeling lush\n\n[chorus]\nElectric love in the night sky\nWe’re gonna soar baby you and I\nDrop the bass let the rhythm fly\nFeel the heat and don't ask why\n\n[bridge]\nThrough the lights and the smoky haze\nI see you in a thousand ways\nLove's a script and we’re the play\nTurn the page stay till we sway\n\n[chorus]\nElectric love in the night sky\nWe’re gonna soar baby you and I\nDrop the bass let the rhythm fly\nFeel the heat and don't ask why",
"audio_duration": 195.15997916666666,
"infer_step": 60,
"guidance_scale": 15,
"scheduler_type": "euler",
"cfg_type": "apg",
"omega_scale": 10,
"guidance_interval": 0.5,
"guidance_interval_decay": 0,
"min_guidance_scale": 3,
"use_erg_tag": true,
"use_erg_lyric": true,
"use_erg_diffusion": true,
"oss_steps": [],
"timecosts": {
"preprocess": 0.025553464889526367,
"diffusion": 18.250118494033813,
"latent2audio": 1.9400627613067627
},
"actual_seeds": [
2853131993
]
}
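Each of these JSON files records both the generation request and some run metadata. A small sketch for separating the two; the key names are taken directly from the files above, and nothing about the pipeline's call signature is implied:

import json

with open("examples/input_params/output_20250426071706_0_input_params.json", encoding="utf-8") as f:
    record = json.load(f)

# "timecosts" and "actual_seeds" log the run; the remaining keys describe the request.
run_metadata = {k: record[k] for k in ("timecosts", "actual_seeds")}
request_params = {k: v for k, v in record.items() if k not in run_metadata}

print(sorted(request_params))      # audio_duration, cfg_type, guidance_scale, ...
print(run_metadata["timecosts"])   # per-stage timings in seconds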
language_segmentation/LangSegment.py
ADDED
@@ -0,0 +1,866 @@
1 |
+
"""
|
2 |
+
This file bundles language identification functions.
|
3 |
+
|
4 |
+
Modifications (fork): Copyright (c) 2021, Adrien Barbaresi.
|
5 |
+
|
6 |
+
Original code: Copyright (c) 2011 Marco Lui <saffsd@gmail.com>.
|
7 |
+
Based on research by Marco Lui and Tim Baldwin.
|
8 |
+
|
9 |
+
See LICENSE file for more info.
|
10 |
+
https://github.com/adbar/py3langid
|
11 |
+
|
12 |
+
Projects:
|
13 |
+
https://github.com/juntaosun/LangSegment
|
14 |
+
"""
|
15 |
+
|
16 |
+
import os
|
17 |
+
import re
|
18 |
+
import sys
|
19 |
+
import numpy as np
|
20 |
+
from collections import Counter
|
21 |
+
from collections import defaultdict
|
22 |
+
|
23 |
+
# import langid
|
24 |
+
# import py3langid as langid
|
25 |
+
# pip install py3langid==0.2.2
|
26 |
+
|
27 |
+
# 启用语言预测概率归一化,概率预测的分数。因此,实现重新规范化 产生 0-1 范围内的输出。
|
28 |
+
# langid disables probability normalization by default. For command-line usages of , it can be enabled by passing the flag.
|
29 |
+
# For probability normalization in library use, the user must instantiate their own . An example of such usage is as follows:
|
30 |
+
from py3langid.langid import LanguageIdentifier, MODEL_FILE
|
31 |
+
|
32 |
+
# Digital processing
|
33 |
+
try:from .utils.num import num2str
|
34 |
+
except ImportError:
|
35 |
+
try:from utils.num import num2str
|
36 |
+
except ImportError as e:
|
37 |
+
raise e
|
38 |
+
|
39 |
+
# -----------------------------------
|
40 |
+
# 更新日志:新版本分词更加精准。
|
41 |
+
# Changelog: The new version of the word segmentation is more accurate.
|
42 |
+
# チェンジログ:新しいバージョンの単語セグメンテーションはより正確です。
|
43 |
+
# Changelog: 분할이라는 단어의 새로운 버전이 더 정확합니다.
|
44 |
+
# -----------------------------------
|
45 |
+
|
46 |
+
|
47 |
+
# Word segmentation function:
|
48 |
+
# automatically identify and split the words (Chinese/English/Japanese/Korean) in the article or sentence according to different languages,
|
49 |
+
# making it more suitable for TTS processing.
|
50 |
+
# This code is designed for front-end text multi-lingual mixed annotation distinction, multi-language mixed training and inference of various TTS projects.
|
51 |
+
# This processing result is mainly for (Chinese = zh, Japanese = ja, English = en, Korean = ko), and can actually support up to 97 different language mixing processing.
|
52 |
+
|
53 |
+
#===========================================================================================================
|
54 |
+
#分かち書き機能:文章や文章の中の例えば(中国語/英語/日本語/韓国語)を、異なる言語で自動的に認識して分割し、TTS処理により適したものにします。
|
55 |
+
#このコードは、さまざまなTTSプロジェクトのフロントエンドテキストの多言語混合注釈区別、多言語混合トレーニング、および推論のために特別に作成されています。
|
56 |
+
#===========================================================================================================
|
57 |
+
#(1)自動分詞:「韓国語では何を読むのですかあなたの体育の先生は誰ですか?今回の発表会では、iPhone 15シリーズの4機種が登場しました」
|
58 |
+
#(2)手动分词:“あなたの名前は<ja>佐々木ですか?<ja>ですか?”
|
59 |
+
#この処理結果は主に(中国語=ja、日本語=ja、英語=en、韓国語=ko)を対象としており、実際には最大97の異なる言語の混合処理をサポートできます。
|
60 |
+
#===========================================================================================================
|
61 |
+
|
62 |
+
#===========================================================================================================
|
63 |
+
# 단어 분할 기능: 기사 또는 문장에서 단어(중국어/영어/일본어/한국어)를 다른 언어에 따라 자동으로 식별하고 분할하여 TTS 처리에 더 적합합니다.
|
64 |
+
# 이 코드는 프런트 엔드 텍스트 다국어 혼합 주석 분화, 다국어 혼합 교육 및 다양한 TTS 프로젝트의 추론을 위해 설계되었습니다.
|
65 |
+
#===========================================================================================================
|
66 |
+
# (1) 자동 단어 분할: "한국어로 무엇을 읽습니까? 스포츠 씨? 이 컨퍼런스는 4개의 iPhone 15 시리즈 모델을 제공합니다."
|
67 |
+
# (2) 수동 참여: "이름이 <ja>Saki입니까? <ja>?"
|
68 |
+
# 이 처리 결과는 주로 (중국어 = zh, 일본어 = ja, 영어 = en, 한국어 = ko)를 위한 것이며 실제로 혼합 처리를 위해 최대 97개의 언어를 지원합니다.
|
69 |
+
#===========================================================================================================
|
70 |
+
|
71 |
+
# ===========================================================================================================
|
72 |
+
# 分词功能:将文章或句子里的例如(中/英/日/韩),按不同语言自动识别并拆分,让它更适合TTS处理。
|
73 |
+
# 本代码专为各种 TTS 项目的前端文本多语种混合标注区分,多语言混合训练和推理而编写。
|
74 |
+
# ===========================================================================================================
|
75 |
+
# (1)自动分词:“韩语中的오빠读什么呢?あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型”
|
76 |
+
# (2)手动分词:“你的名字叫<ja>佐々木?<ja>吗?”
|
77 |
+
# 本处理结果主要针对(中文=zh , 日文=ja , 英文=en , 韩语=ko), 实际上可支持多达 97 种不同的语言混合处理。
|
78 |
+
# ===========================================================================================================
|
79 |
+
|
80 |
+
|
81 |
+
# 手动分词标签规范:<语言标签>文本内容</语言标签>
|
82 |
+
# 수동 단어 분할 태그 사양: <언어 태그> 텍스트 내용</언어 태그>
|
83 |
+
# Manual word segmentation tag specification: <language tags> text content </language tags>
|
84 |
+
# 手動分詞タグ仕様:<言語タグ>テキスト内容</言語タグ>
|
85 |
+
# ===========================================================================================================
|
86 |
+
# For manual word segmentation, labels need to appear in pairs, such as:
|
87 |
+
# 如需手动分词,标签需要成对出现,例如:“<ja>佐々木<ja>” 或者 “<ja>佐々木</ja>”
|
88 |
+
# 错误示范:“你的名字叫<ja>佐々木。” 此句子中出现的单个<ja>标签将被忽略,不会处理。
|
89 |
+
# Error demonstration: "Your name is <ja>佐々木。" Single <ja> tags that appear in this sentence will be ignored and will not be processed.
|
90 |
+
# ===========================================================================================================
|
91 |
+
|
92 |
+
|
93 |
+
# ===========================================================================================================
|
94 |
+
# 语音合成标记语言 SSML , 这里只支持它的标签(非 XML)Speech Synthesis Markup Language SSML, only its tags are supported here (not XML)
|
95 |
+
# 想支持更多的 SSML 标签?欢迎 PR! Want to support more SSML tags? PRs are welcome!
|
96 |
+
# 说明:除了中文以外,它也可改造成支持多语种 SSML ,不仅仅是中文。
|
97 |
+
# Note: In addition to Chinese, it can also be modified to support multi-language SSML, not just Chinese.
|
98 |
+
# ===========================================================================================================
|
99 |
+
# 中文实现:Chinese implementation:
|
100 |
+
# 【SSML】<number>=中文大写数字读法(单字)
|
101 |
+
# 【SSML】<telephone>=数字转成中文电话号码大写汉字(单字)
|
102 |
+
# 【SSML】<currency>=按金额发音。
|
103 |
+
# 【SSML】<date>=按日期发音。支持 2024年08月24, 2024/8/24, 2024-08, 08-24, 24 等输入。
|
104 |
+
# ===========================================================================================================
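# Hedged usage sketch (illustration only): the SSML-style tags above are written inline in the input
# text and are expanded into Chinese readings by the LangSSML helpers defined below, for example:
#   ssml = LangSSML()
#   ssml.to_chinese_telephone("13712345678")   # digit-by-digit reading, with "一" spoken as "幺"
#   ssml.to_chinese_date("2024-08-24")         # read as a year-month-day date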
|
105 |
+
class LangSSML:
|
106 |
+
|
107 |
+
def __init__(self):
|
108 |
+
# 纯数字
|
109 |
+
self._zh_numerals_number = {
|
110 |
+
'0': '零',
|
111 |
+
'1': '一',
|
112 |
+
'2': '二',
|
113 |
+
'3': '三',
|
114 |
+
'4': '四',
|
115 |
+
'5': '五',
|
116 |
+
'6': '六',
|
117 |
+
'7': '七',
|
118 |
+
'8': '八',
|
119 |
+
'9': '九'
|
120 |
+
}
|
121 |
+
|
122 |
+
# 将2024/8/24, 2024-08, 08-24, 24 标准化“年月日”
|
123 |
+
# Standardize 2024/8/24, 2024-08, 08-24, 24 to "year-month-day"
|
124 |
+
def _format_chinese_data(self, date_str:str):
|
125 |
+
# 处理日期格式
|
126 |
+
input_date = date_str
|
127 |
+
if date_str is None or date_str.strip() == "":return ""
|
128 |
+
date_str = re.sub(r"[\/\._|年|月]","-",date_str)
|
129 |
+
date_str = re.sub(r"日",r"",date_str)
|
130 |
+
date_arrs = date_str.split(' ')
|
131 |
+
if len(date_arrs) == 1 and ":" in date_arrs[0]:
|
132 |
+
time_str = date_arrs[0]
|
133 |
+
date_arrs = []
|
134 |
+
else:
|
135 |
+
time_str = date_arrs[1] if len(date_arrs) >=2 else ""
|
136 |
+
def nonZero(num,cn,func=None):
|
137 |
+
if func is not None:num=func(num)
|
138 |
+
return f"{num}{cn}" if num is not None and num != "" and num != "0" else ""
|
139 |
+
f_number = self.to_chinese_number
|
140 |
+
f_currency = self.to_chinese_currency
|
141 |
+
# year, month, day
|
142 |
+
year_month_day = ""
|
143 |
+
if len(date_arrs) > 0:
|
144 |
+
year, month, day = "","",""
|
145 |
+
parts = date_arrs[0].split('-')
|
146 |
+
if len(parts) == 3: # 格式为 YYYY-MM-DD
|
147 |
+
year, month, day = parts
|
148 |
+
elif len(parts) == 2: # 格式为 MM-DD 或 YYYY-MM
|
149 |
+
if len(parts[0]) == 4: # 年-月
|
150 |
+
year, month = parts
|
151 |
+
else:month, day = parts # 月-日
|
152 |
+
elif len(parts[0]) > 0: # 仅有月-日或年
|
153 |
+
if len(parts[0]) == 4:
|
154 |
+
year = parts[0]
|
155 |
+
else:day = parts[0]
|
156 |
+
year,month,day = nonZero(year,"年",f_number),nonZero(month,"月",f_currency),nonZero(day,"日",f_currency)
|
157 |
+
year_month_day = re.sub(r"([年|月|日])+",r"\1",f"{year}{month}{day}")
|
158 |
+
# hours, minutes, seconds
|
159 |
+
time_str = re.sub(r"[\/\.\-:_]",":",time_str)
|
160 |
+
time_arrs = time_str.split(":")
|
161 |
+
hours, minutes, seconds = "","",""
|
162 |
+
if len(time_arrs) == 3: # H/M/S
|
163 |
+
hours, minutes, seconds = time_arrs
|
164 |
+
elif len(time_arrs) == 2:# H/M
|
165 |
+
hours, minutes = time_arrs
|
166 |
+
elif len(time_arrs[0]) > 0:hours = f'{time_arrs[0]}点' # H
|
167 |
+
if len(time_arrs) > 1:
|
168 |
+
hours, minutes, seconds = nonZero(hours,"点",f_currency),nonZero(minutes,"分",f_currency),nonZero(seconds,"秒",f_currency)
|
169 |
+
hours_minutes_seconds = re.sub(r"([点|分|秒])+",r"\1",f"{hours}{minutes}{seconds}")
|
170 |
+
output_date = f"{year_month_day}{hours_minutes_seconds}"
|
171 |
+
return output_date
|
172 |
+
|
173 |
+
# 【SSML】number=中文大写数字读法(单字)
|
174 |
+
# Chinese Numbers(single word)
|
175 |
+
def to_chinese_number(self, num:str):
|
176 |
+
pattern = r'(\d+)'
|
177 |
+
zh_numerals = self._zh_numerals_number
|
178 |
+
arrs = re.split(pattern, num)
|
179 |
+
output = ""
|
180 |
+
for item in arrs:
|
181 |
+
if re.match(pattern,item):
|
182 |
+
output += ''.join(zh_numerals[digit] if digit in zh_numerals else "" for digit in str(item))
|
183 |
+
else:output += item
|
184 |
+
output = output.replace(".","点")
|
185 |
+
return output
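# Hedged examples (illustration only): digits are read one by one, e.g.
#   to_chinese_number("2024")  -> "二零二四"
#   to_chinese_number("3.14")  -> "三点一四"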
|
186 |
+
|
187 |
+
# 【SSML】telephone=数字转成中文电话号码大写汉字(单字)
|
188 |
+
# Convert numbers to Chinese phone numbers in uppercase Chinese characters(single word)
|
189 |
+
def to_chinese_telephone(self, num:str):
|
190 |
+
output = self.to_chinese_number(num.replace("+86","")) # zh +86
|
191 |
+
output = output.replace("一","幺")
|
192 |
+
return output
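# Hedged example (illustration only): the "+86" prefix is stripped and "一" becomes "幺", e.g.
#   to_chinese_telephone("+8613712345678")  -> "幺三七幺二三四五六七八"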
|
193 |
+
|
194 |
+
# 【SSML】currency=按金额发音。
|
195 |
+
# Digital processing from GPT_SoVITS num.py (thanks)
|
196 |
+
def to_chinese_currency(self, num:str):
|
197 |
+
pattern = r'(\d+)'
|
198 |
+
arrs = re.split(pattern, num)
|
199 |
+
output = ""
|
200 |
+
for item in arrs:
|
201 |
+
if re.match(pattern,item):
|
202 |
+
output += num2str(item)
|
203 |
+
else:output += item
|
204 |
+
output = output.replace(".","点")
|
205 |
+
return output
|
206 |
+
|
207 |
+
# 【SSML】date=按日期发音。支持 2024年08月24, 2024/8/24, 2024-08, 08-24, 24 等输入。
|
208 |
+
def to_chinese_date(self, num:str):
|
209 |
+
chinese_date = self._format_chinese_data(num)
|
210 |
+
return chinese_date
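# Hedged examples (illustration only): common date formats are normalized first, e.g.
#   to_chinese_date("2024年08月24")  -> "二零二四年八月二十四日"
#   to_chinese_date("08-24")          -> "八月二十四日"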
|
211 |
+
|
212 |
+
|
213 |
+
class LangSegment:
|
214 |
+
|
215 |
+
def __init__(self):
|
216 |
+
|
217 |
+
self.langid = LanguageIdentifier.from_pickled_model(MODEL_FILE, norm_probs=True)
|
218 |
+
|
219 |
+
self._text_cache = None
|
220 |
+
self._text_lasts = None
|
221 |
+
self._text_langs = None
|
222 |
+
self._lang_count = None
|
223 |
+
self._lang_eos = None
|
224 |
+
|
225 |
+
# 可自定义语言匹配标签:カスタマイズ可能な言語対応タグ:사용자 지정 가능한 언어 일치 태그:
|
226 |
+
# Customizable language matching tags: These are supported,이 표현들은 모두 지지합니다
|
227 |
+
# <zh>你好<zh> , <ja>佐々木</ja> , <en>OK<en> , <ko>오빠</ko> 这些写法均支持
|
228 |
+
self.SYMBOLS_PATTERN = r'(<([a-zA-Z|-]*)>(.*?)<\/*[a-zA-Z|-]*>)'
|
229 |
+
|
230 |
+
# 语言过滤组功能, 可以指定保留语言。不在过滤组中的语言将被清除。您可随心搭配TTS语音合成所支持的语言。
|
231 |
+
# 언어 필터 그룹 기능을 사용하면 예약된 언어를 지정할 수 있습니다. 필터 그룹에 없는 언어는 지워집니다. TTS 텍스트에서 지원하는 언어를 원하는 대로 일치시킬 수 있습니다.
|
232 |
+
# 言語フィルターグループ機能では、予約言語を指定できます。フィルターグループに含まれていない言語はクリアされます。TTS音声合成がサポートする言語を自由に組み合わせることができます。
|
233 |
+
# The language filter group function allows you to specify reserved languages.
|
234 |
+
# Languages not in the filter group will be cleared. You can match the languages supported by TTS Text To Speech as you like.
|
235 |
+
# 排名越前,优先级越高,The higher the ranking, the higher the priority,ランキングが上位になるほど、優先度が高くなります。
|
236 |
+
|
237 |
+
# 系统默认过滤器。System default filter。(ISO 639-1 codes given)
|
238 |
+
# ----------------------------------------------------------------------------------------------------------------------------------
|
239 |
+
# "zh"中文=Chinese ,"en"英语=English ,"ja"日语=Japanese ,"ko"韩语=Korean ,"fr"法语=French ,"vi"越南语=Vietnamese , "ru"俄语=Russian
|
240 |
+
# "th"泰语=Thai
|
241 |
+
# ----------------------------------------------------------------------------------------------------------------------------------
|
242 |
+
self.DEFAULT_FILTERS = ["zh", "ja", "ko", "en"]
|
243 |
+
|
244 |
+
# 用户可自定义过滤器。User-defined filters
|
245 |
+
self.Langfilters = self.DEFAULT_FILTERS[:] # 创建副本
|
246 |
+
|
247 |
+
# 合并文本 Merge adjacent segments of the same language
|
248 |
+
self.isLangMerge = True
|
249 |
+
|
250 |
+
# 试验性支持:您可自定义添加:"fr"法语 , "vi"越南语。Experimental: You can customize to add: "fr" French, "vi" Vietnamese.
|
251 |
+
# 请使用API启用:self.setfilters(["zh", "en", "ja", "ko", "fr", "vi" , "ru" , "th"]) # 您可自定义添加,如:"fr"法语 , "vi"越南语。
|
252 |
+
|
253 |
+
# 预览版功能,自动启用或禁用,无需设置
|
254 |
+
# Preview feature, automatically enabled or disabled, no settings required
|
255 |
+
self.EnablePreview = False
|
256 |
+
|
257 |
+
# 除此以外,它支持简写过滤器,只需按不同语种任意组合即可。
|
258 |
+
# In addition to that, it supports abbreviation filters, allowing for any combination of different languages.
|
259 |
+
# 示例:您可以任意指定多种组合,进行过滤
|
260 |
+
# Example: You can specify any combination to filter
|
261 |
+
|
262 |
+
# 中/日语言优先级阀值(评分范围为 0 ~ 1):评分低于设定阀值 <0.89 时,启用 filters 中的优先级。\n
|
263 |
+
# 중/일본어 우선 순위 임계값(점수 범위 0-1): 점수가 설정된 임계값 <0.89보다 낮을 때 필터에서 우선 순위를 활성화합니다.
|
264 |
+
# 中国語/日本語の優先度しきい値(スコア範囲0〜1):スコアが設定されたしきい値<0.89未満の場合、フィルターの優先度が有効になります。\n
|
265 |
+
# Chinese and Japanese language priority threshold (score range is 0 ~ 1): The default threshold is 0.89. \n
|
266 |
+
# Only the common characters between Chinese and Japanese are processed with confidence and priority. \n
|
267 |
+
self.LangPriorityThreshold = 0.89
|
268 |
+
|
269 |
+
# Langfilters = ["zh"] # 按中文识别
|
270 |
+
# Langfilters = ["en"] # 按英文识别
|
271 |
+
# Langfilters = ["ja"] # 按日文识别
|
272 |
+
# Langfilters = ["ko"] # 按韩文识别
|
273 |
+
# Langfilters = ["zh_ja"] # 中日混合识别
|
274 |
+
# Langfilters = ["zh_en"] # 中英混合识别
|
275 |
+
# Langfilters = ["ja_en"] # 日英混合识别
|
276 |
+
# Langfilters = ["zh_ko"] # 中韩混合识别
|
277 |
+
# Langfilters = ["ja_ko"] # 日韩混合识别
|
278 |
+
# Langfilters = ["en_ko"] # 英韩混合识别
|
279 |
+
# Langfilters = ["zh_ja_en"] # 中日英混合识别
|
280 |
+
# Langfilters = ["zh_ja_en_ko"] # 中日英韩混合识别
|
281 |
+
|
282 |
+
# 更多过滤组合,请您随意。。。For more filter combinations, please feel free to......
|
283 |
+
# より多くのフィルターの組み合わせ、お気軽に。。。더 많은 필터 조합을 원하시면 자유롭게 해주세요. .....
|
284 |
+
|
285 |
+
# 可选保留:支持中文数字拼音格式,更方便前端实现拼音音素修改和推理,默认关闭 False。Optional: keep Chinese digit-pinyin fragments so the front end can edit pinyin phonemes for inference more easily; disabled (False) by default.
|
286 |
+
# 开启后 True,括号内的数字拼音格式均保留,并识别输出为:"zh"中文。When set to True, digit-pinyin fragments inside brackets are kept and emitted as "zh" Chinese.
|
287 |
+
self.keepPinyin = False
|
288 |
+
|
289 |
+
# DEFINITION
|
290 |
+
self.PARSE_TAG = re.compile(r'(⑥\$*\d+[\d]{6,}⑥)')
|
291 |
+
|
292 |
+
self.LangSSML = LangSSML()
|
293 |
+
|
294 |
+
def _clears(self):
|
295 |
+
self._text_cache = None
|
296 |
+
self._text_lasts = None
|
297 |
+
self._text_langs = None
|
298 |
+
self._text_waits = None
|
299 |
+
self._lang_count = None
|
300 |
+
self._lang_eos = None
|
301 |
+
|
302 |
+
def _is_english_word(self, word):
|
303 |
+
return bool(re.match(r'^[a-zA-Z]+$', word))
|
304 |
+
|
305 |
+
def _is_chinese(self, word):
|
306 |
+
for char in word:
|
307 |
+
if '\u4e00' <= char <= '\u9fff':
|
308 |
+
return True
|
309 |
+
return False
|
310 |
+
|
311 |
+
def _is_japanese_kana(self, word):
|
312 |
+
pattern = re.compile(r'[\u3040-\u309F\u30A0-\u30FF]+')
|
313 |
+
matches = pattern.findall(word)
|
314 |
+
return len(matches) > 0
|
315 |
+
|
316 |
+
def _insert_english_uppercase(self, word):
|
317 |
+
modified_text = re.sub(r'(?<!\b)([A-Z])', r' \1', word)
|
318 |
+
modified_text = modified_text.strip('-')
|
319 |
+
return modified_text + " "
|
320 |
+
|
321 |
+
def _split_camel_case(self, word):
|
322 |
+
return re.sub(r'(?<!^)(?=[A-Z])', ' ', word)
|
323 |
+
|
324 |
+
def _statistics(self, language, text):
|
325 |
+
# Language word statistics:
|
326 |
+
# Chinese characters usually occupy double bytes
|
327 |
+
if self._lang_count is None or not isinstance(self._lang_count, defaultdict):
|
328 |
+
self._lang_count = defaultdict(int)
|
329 |
+
lang_count = self._lang_count
|
330 |
+
if not "|" in language:
|
331 |
+
lang_count[language] += int(len(text)*2) if language == "zh" else len(text)
|
332 |
+
self._lang_count = lang_count
|
333 |
+
|
334 |
+
def _clear_text_number(self, text):
|
335 |
+
if text == "\n":return text,False # Keep Line Breaks
|
336 |
+
clear_text = re.sub(r'([^\w\s]+)','',re.sub(r'\n+','',text)).strip()
|
337 |
+
is_number = len(re.sub(re.compile(r'(\d+)'),'',clear_text)) == 0
|
338 |
+
return clear_text,is_number
|
339 |
+
|
340 |
+
def _saveData(self, words,language:str,text:str,score:float,symbol=None):
|
341 |
+
# Pre-detection
|
342 |
+
clear_text , is_number = self._clear_text_number(text)
|
343 |
+
# Merge the same language and save the results
|
344 |
+
preData = words[-1] if len(words) > 0 else None
|
345 |
+
if symbol is not None:pass
|
346 |
+
elif preData is not None and preData["symbol"] is None:
|
347 |
+
if len(clear_text) == 0:language = preData["lang"]
|
348 |
+
elif is_number == True:language = preData["lang"]
|
349 |
+
_ , pre_is_number = self._clear_text_number(preData["text"])
|
350 |
+
if (preData["lang"] == language):
|
351 |
+
self._statistics(preData["lang"],text)
|
352 |
+
text = preData["text"] + text
|
353 |
+
preData["text"] = text
|
354 |
+
return preData
|
355 |
+
elif pre_is_number == True:
|
356 |
+
text = f'{preData["text"]}{text}'
|
357 |
+
words.pop()
|
358 |
+
elif is_number == True:
|
359 |
+
priority_language = self._get_filters_string()[:2]
|
360 |
+
if priority_language in "ja-zh-en-ko-fr-vi":language = priority_language
|
361 |
+
data = {"lang":language,"text": text,"score":score,"symbol":symbol}
|
362 |
+
filters = self.Langfilters
|
363 |
+
if filters is None or len(filters) == 0 or "?" in language or \
|
364 |
+
language in filters or language in filters[0] or \
|
365 |
+
filters[0] == "*" or filters[0] in "alls-mixs-autos":
|
366 |
+
words.append(data)
|
367 |
+
self._statistics(data["lang"],data["text"])
|
368 |
+
return data
|
369 |
+
|
370 |
+
def _addwords(self, words,language,text,score,symbol=None):
|
371 |
+
if text == "\n":pass # Keep Line Breaks
|
372 |
+
elif text is None or len(text.strip()) == 0:return True
|
373 |
+
if language is None:language = ""
|
374 |
+
language = language.lower()
|
375 |
+
if language == 'en':text = self._insert_english_uppercase(text)
|
376 |
+
# text = re.sub(r'[(())]', ',' , text) # Keep it.
|
377 |
+
text_waits = self._text_waits
|
378 |
+
ispre_waits = len(text_waits)>0
|
379 |
+
preResult = text_waits.pop() if ispre_waits else None
|
380 |
+
if preResult is None:preResult = words[-1] if len(words) > 0 else None
|
381 |
+
if preResult and ("|" in preResult["lang"]):
|
382 |
+
pre_lang = preResult["lang"]
|
383 |
+
if language in pre_lang:preResult["lang"] = language = language.split("|")[0]
|
384 |
+
else:preResult["lang"]=pre_lang.split("|")[0]
|
385 |
+
if ispre_waits:preResult = self._saveData(words,preResult["lang"],preResult["text"],preResult["score"],preResult["symbol"])
|
386 |
+
pre_lang = preResult["lang"] if preResult else None
|
387 |
+
if ("|" in language) and (pre_lang and not pre_lang in language and not "…" in language):language = language.split("|")[0]
|
388 |
+
if "|" in language:self._text_waits.append({"lang":language,"text": text,"score":score,"symbol":symbol})
|
389 |
+
else:self._saveData(words,language,text,score,symbol)
|
390 |
+
return False
|
391 |
+
|
392 |
+
def _get_prev_data(self, words):
|
393 |
+
data = words[-1] if words and len(words) > 0 else None
|
394 |
+
if data:return (data["lang"] , data["text"])
|
395 |
+
return (None,"")
|
396 |
+
|
397 |
+
def _match_ending(self, input , index):
|
398 |
+
if input is None or len(input) == 0:return False,None
|
399 |
+
input = re.sub(r'\s+', '', input)
|
400 |
+
if len(input) == 0 or abs(index) > len(input):return False,None
|
401 |
+
ending_pattern = re.compile(r'([「」“”‘’"\'::。.!!?.?])')
|
402 |
+
return ending_pattern.match(input[index]),input[index]
|
403 |
+
|
404 |
+
def _cleans_text(self, cleans_text):
|
405 |
+
cleans_text = re.sub(r'(.*?)([^\w]+)', r'\1 ', cleans_text)
|
406 |
+
cleans_text = re.sub(r'(.)\1+', r'\1', cleans_text)
|
407 |
+
return cleans_text.strip()
|
408 |
+
|
409 |
+
def _mean_processing(self, text:str):
|
410 |
+
if text is None or (text.strip()) == "":return None , 0.0
|
411 |
+
arrs = self._split_camel_case(text).split(" ")
|
412 |
+
langs = []
|
413 |
+
for t in arrs:
|
414 |
+
if len(t.strip()) <= 3:continue
|
415 |
+
language, score = self.langid.classify(t)
|
416 |
+
langs.append({"lang":language})
|
417 |
+
if len(langs) == 0:return None , 0.0
|
418 |
+
return Counter([item['lang'] for item in langs]).most_common(1)[0][0],1.0
|
419 |
+
|
420 |
+
def _lang_classify(self, cleans_text):
|
421 |
+
language, score = self.langid.classify(cleans_text)
|
422 |
+
# fix: Huggingface is np.float32
|
423 |
+
if score is not None and isinstance(score, np.generic) and hasattr(score,"item"):
|
424 |
+
score = score.item()
|
425 |
+
score = round(score , 3)
|
426 |
+
return language, score
|
427 |
+
|
428 |
+
def _get_filters_string(self):
|
429 |
+
filters = self.Langfilters
|
430 |
+
return "-".join(filters).lower().strip() if filters is not None else ""
|
431 |
+
|
432 |
+
def _parse_language(self, words , segment):
|
433 |
+
LANG_JA = "ja"
|
434 |
+
LANG_ZH = "zh"
|
435 |
+
LANG_ZH_JA = f'{LANG_ZH}|{LANG_JA}'
|
436 |
+
LANG_JA_ZH = f'{LANG_JA}|{LANG_ZH}'
|
437 |
+
language = LANG_ZH
|
438 |
+
regex_pattern = re.compile(r'([^\w\s]+)')
|
439 |
+
lines = regex_pattern.split(segment)
|
440 |
+
lines_max = len(lines)
|
441 |
+
LANG_EOS =self._lang_eos
|
442 |
+
for index, text in enumerate(lines):
|
443 |
+
if len(text) == 0:continue
|
444 |
+
EOS = index >= (lines_max - 1)
|
445 |
+
nextId = index + 1
|
446 |
+
nextText = lines[nextId] if not EOS else ""
|
447 |
+
nextPunc = len(re.sub(regex_pattern,'',re.sub(r'\n+','',nextText)).strip()) == 0
|
448 |
+
textPunc = len(re.sub(regex_pattern,'',re.sub(r'\n+','',text)).strip()) == 0
|
449 |
+
if not EOS and (textPunc == True or ( len(nextText.strip()) >= 0 and nextPunc == True)):
|
450 |
+
lines[nextId] = f'{text}{nextText}'
|
451 |
+
continue
|
452 |
+
number_tags = re.compile(r'(⑥\d{6,}⑥)')
|
453 |
+
cleans_text = re.sub(number_tags, '' ,text)
|
454 |
+
cleans_text = re.sub(r'\d+', '' ,cleans_text)
|
455 |
+
cleans_text = self._cleans_text(cleans_text)
|
456 |
+
# fix:Langid's recognition of short sentences is inaccurate, and it is spliced longer.
|
457 |
+
if not EOS and len(cleans_text) <= 2:
|
458 |
+
lines[nextId] = f'{text}{nextText}'
|
459 |
+
continue
|
460 |
+
language,score = self._lang_classify(cleans_text)
|
461 |
+
prev_language , prev_text = self._get_prev_data(words)
|
462 |
+
if language != LANG_ZH and all('\u4e00' <= c <= '\u9fff' for c in re.sub(r'\s','',cleans_text)):language,score = LANG_ZH,1
|
463 |
+
if len(cleans_text) <= 5 and self._is_chinese(cleans_text):
|
464 |
+
filters_string = self._get_filters_string()
|
465 |
+
if score < self.LangPriorityThreshold and len(filters_string) > 0:
|
466 |
+
index_ja , index_zh = filters_string.find(LANG_JA) , filters_string.find(LANG_ZH)
|
467 |
+
if index_ja != -1 and index_ja < index_zh:language = LANG_JA
|
468 |
+
elif index_zh != -1 and index_zh < index_ja:language = LANG_ZH
|
469 |
+
if self._is_japanese_kana(cleans_text):language = LANG_JA
|
470 |
+
elif len(cleans_text) > 2 and score > 0.90:pass
|
471 |
+
elif EOS and LANG_EOS:language = LANG_ZH if len(cleans_text) <= 1 else language
|
472 |
+
else:
|
473 |
+
LANG_UNKNOWN = LANG_ZH_JA if language == LANG_ZH or (len(cleans_text) <=2 and prev_language == LANG_ZH) else LANG_JA_ZH
|
474 |
+
match_end,match_char = self._match_ending(text, -1)
|
475 |
+
referen = prev_language in LANG_UNKNOWN or LANG_UNKNOWN in prev_language if prev_language else False
|
476 |
+
if match_char in "。.": language = prev_language if referen and len(words) > 0 else language
|
477 |
+
else:language = f"{LANG_UNKNOWN}|…"
|
478 |
+
text,*_ = re.subn(number_tags , self._restore_number , text )
|
479 |
+
self._addwords(words,language,text,score)
|
480 |
+
|
481 |
+
# ----------------------------------------------------------
|
482 |
+
# 【SSML】中文数字处理:Chinese Number Processing (SSML support)
|
483 |
+
# 这里默认都是中文,用于处理 SSML 中文标签。当然可以支持任意语言,例如:
|
484 |
+
# The default here is Chinese, which is used to process SSML Chinese tags. Of course, any language can be supported, for example:
|
485 |
+
# 中文电话号码:<telephone>1234567</telephone>
|
486 |
+
# 中文数字号码:<number>1234567</number>
|
487 |
+
def _process_symbol_SSML(self, words,data):
|
488 |
+
tag , match = data
|
489 |
+
language = SSML = match[1]
|
490 |
+
text = match[2]
|
491 |
+
score = 1.0
|
492 |
+
if SSML == "telephone":
|
493 |
+
# 中文-电话号码
|
494 |
+
language = "zh"
|
495 |
+
text = self.LangSSML.to_chinese_telephone(text)
|
496 |
+
elif SSML == "number":
|
497 |
+
# 中文-数字读法
|
498 |
+
language = "zh"
|
499 |
+
text = self.LangSSML.to_chinese_number(text)
|
500 |
+
elif SSML == "currency":
|
501 |
+
# 中文-按金额发音
|
502 |
+
language = "zh"
|
503 |
+
text = self.LangSSML.to_chinese_currency(text)
|
504 |
+
elif SSML == "date":
|
505 |
+
# 中文-按日期发音
|
506 |
+
language = "zh"
|
507 |
+
text = self.LangSSML.to_chinese_date(text)
|
508 |
+
self._addwords(words,language,text,score,SSML)
|
509 |
+
|
510 |
+
# ----------------------------------------------------------
|
511 |
+
def _restore_number(self, matche):
|
512 |
+
value = matche.group(0)
|
513 |
+
text_cache = self._text_cache
|
514 |
+
if value in text_cache:
|
515 |
+
process , data = text_cache[value]
|
516 |
+
tag , match = data
|
517 |
+
value = match
|
518 |
+
return value
|
519 |
+
|
520 |
+
def _pattern_symbols(self, item , text):
|
521 |
+
if text is None:return text
|
522 |
+
tag , pattern , process = item
|
523 |
+
matches = pattern.findall(text)
|
524 |
+
if len(matches) == 1 and "".join(matches[0]) == text:
|
525 |
+
return text
|
526 |
+
for i , match in enumerate(matches):
|
527 |
+
key = f"⑥{tag}{i:06d}⑥"
|
528 |
+
text = re.sub(pattern , key , text , count=1)
|
529 |
+
self._text_cache[key] = (process , (tag , match))
|
530 |
+
return text
|
531 |
+
|
532 |
+
def _process_symbol(self, words,data):
|
533 |
+
tag , match = data
|
534 |
+
language = match[1]
|
535 |
+
text = match[2]
|
536 |
+
score = 1.0
|
537 |
+
filters = self._get_filters_string()
|
538 |
+
if language not in filters:
|
539 |
+
self._process_symbol_SSML(words,data)
|
540 |
+
else:
|
541 |
+
self._addwords(words,language,text,score,True)
|
542 |
+
|
543 |
+
def _process_english(self, words,data):
|
544 |
+
tag , match = data
|
545 |
+
text = match[0]
|
546 |
+
filters = self._get_filters_string()
|
547 |
+
priority_language = filters[:2]
|
548 |
+
# Preview feature, other language segmentation processing
|
549 |
+
enablePreview = self.EnablePreview
|
550 |
+
if enablePreview == True:
|
551 |
+
# Experimental: Other language support
|
552 |
+
regex_pattern = re.compile(r'(.*?[。.??!!]+[\n]{,1})')
|
553 |
+
lines = regex_pattern.split(text)
|
554 |
+
for index , text in enumerate(lines):
|
555 |
+
if len(text.strip()) == 0:continue
|
556 |
+
cleans_text = self._cleans_text(text)
|
557 |
+
language,score = self._lang_classify(cleans_text)
|
558 |
+
if language not in filters:
|
559 |
+
language,score = self._mean_processing(cleans_text)
|
560 |
+
if language is None or score <= 0.0:continue
|
561 |
+
elif language in filters:pass # pass
|
562 |
+
elif score >= 0.95:continue # High score, but not in the filter, excluded.
|
563 |
+
elif score <= 0.15 and filters[:2] == "fr":language = priority_language
|
564 |
+
else:language = "en"
|
565 |
+
self._addwords(words,language,text,score)
|
566 |
+
else:
|
567 |
+
# Default is English
|
568 |
+
language, score = "en", 1.0
|
569 |
+
self._addwords(words,language,text,score)
|
570 |
+
|
571 |
+
def _process_Russian(self, words,data):
|
572 |
+
tag , match = data
|
573 |
+
text = match[0]
|
574 |
+
language = "ru"
|
575 |
+
score = 1.0
|
576 |
+
self._addwords(words,language,text,score)
|
577 |
+
|
578 |
+
def _process_Thai(self, words,data):
|
579 |
+
tag , match = data
|
580 |
+
text = match[0]
|
581 |
+
language = "th"
|
582 |
+
score = 1.0
|
583 |
+
self._addwords(words,language,text,score)
|
584 |
+
|
585 |
+
def _process_korean(self, words,data):
|
586 |
+
tag , match = data
|
587 |
+
text = match[0]
|
588 |
+
language = "ko"
|
589 |
+
score = 1.0
|
590 |
+
self._addwords(words,language,text,score)
|
591 |
+
|
592 |
+
def _process_quotes(self, words,data):
|
593 |
+
tag , match = data
|
594 |
+
text = "".join(match)
|
595 |
+
childs = self.PARSE_TAG.findall(text)
|
596 |
+
if len(childs) > 0:
|
597 |
+
self._process_tags(words , text , False)
|
598 |
+
else:
|
599 |
+
cleans_text = self._cleans_text(match[1])
|
600 |
+
if len(cleans_text) <= 5:
|
601 |
+
self._parse_language(words,text)
|
602 |
+
else:
|
603 |
+
language,score = self._lang_classify(cleans_text)
|
604 |
+
self._addwords(words,language,text,score)
|
605 |
+
|
606 |
+
def _process_pinyin(self, words,data):
|
607 |
+
tag , match = data
|
608 |
+
text = match
|
609 |
+
language = "zh"
|
610 |
+
score = 1.0
|
611 |
+
self._addwords(words,language,text,score)
|
612 |
+
|
613 |
+
def _process_number(self, words,data): # "$0" process only
|
614 |
+
"""
|
615 |
+
Numbers alone cannot accurately identify language.
|
616 |
+
Because numbers are universal in all languages.
|
617 |
+
So it won't be executed here, just for testing.
|
618 |
+
"""
|
619 |
+
tag , match = data
|
620 |
+
language = words[0]["lang"] if len(words) > 0 else "zh"
|
621 |
+
text = match
|
622 |
+
score = 0.0
|
623 |
+
self._addwords(words,language,text,score)
|
624 |
+
|
625 |
+
def _process_tags(self, words , text , root_tag):
|
626 |
+
text_cache = self._text_cache
|
627 |
+
segments = re.split(self.PARSE_TAG, text)
|
628 |
+
segments_len = len(segments) - 1
|
629 |
+
for index , text in enumerate(segments):
|
630 |
+
if root_tag:self._lang_eos = index >= segments_len
|
631 |
+
if self.PARSE_TAG.match(text):
|
632 |
+
process , data = text_cache[text]
|
633 |
+
if process:process(words , data)
|
634 |
+
else:
|
635 |
+
self._parse_language(words , text)
|
636 |
+
return words
|
637 |
+
|
638 |
+
def _merge_results(self, words):
|
639 |
+
new_word = []
|
640 |
+
for index , cur_data in enumerate(words):
|
641 |
+
if "symbol" in cur_data:del cur_data["symbol"]
|
642 |
+
if index == 0:new_word.append(cur_data)
|
643 |
+
else:
|
644 |
+
pre_data = new_word[-1]
|
645 |
+
if cur_data["lang"] == pre_data["lang"]:
|
646 |
+
pre_data["text"] = f'{pre_data["text"]}{cur_data["text"]}'
|
647 |
+
else:new_word.append(cur_data)
|
648 |
+
return new_word
|
649 |
+
|
650 |
+
def _parse_symbols(self, text):
|
651 |
+
TAG_NUM = "00" # "00" => default channels , "$0" => testing channel
|
652 |
+
TAG_S1,TAG_S2,TAG_P1,TAG_P2,TAG_EN,TAG_KO,TAG_RU,TAG_TH = "$1" ,"$2" ,"$3" ,"$4" ,"$5" ,"$6" ,"$7","$8"
|
653 |
+
TAG_BASE = re.compile(fr'(([【《((“‘"\']*[LANGUAGE]+[\W\s]*)+)')
|
654 |
+
# Get custom language filter
|
655 |
+
filters = self.Langfilters
|
656 |
+
filters = filters if filters is not None else ""
|
657 |
+
# =======================================================================================================
|
658 |
+
# Experimental: Other language support.Thử nghiệm: Hỗ trợ ngôn ngữ khác.Expérimental : prise en charge d’autres langues.
|
659 |
+
# 相关语言字符如有缺失,熟悉相关语言的朋友,可以提交把缺失的发音符号补全。
|
660 |
+
# If relevant language characters are missing, friends who are familiar with the relevant languages can submit a submission to complete the missing pronunciation symbols.
|
661 |
+
# S'il manque des caractères linguistiques pertinents, les amis qui connaissent les langues concernées peuvent soumettre une soumission pour compléter les symboles de prononciation manquants.
|
662 |
+
# Nếu thiếu ký tự ngôn ngữ liên quan, những người bạn quen thuộc với ngôn ngữ liên quan có thể gửi bài để hoàn thành các ký hiệu phát âm còn thiếu.
|
663 |
+
# -------------------------------------------------------------------------------------------------------
|
664 |
+
# Preview feature, other language support
|
665 |
+
enablePreview = self.EnablePreview
|
666 |
+
if "fr" in filters or \
|
667 |
+
"vi" in filters:enablePreview = True
|
668 |
+
self.EnablePreview = enablePreview
|
669 |
+
# 实验性:法语字符支持。Prise en charge des caractères français
|
670 |
+
RE_FR = "" if not enablePreview else "àáâãäåæçèéêëìíîïðñòóôõöùúûüýþÿ"
|
671 |
+
# 实验性:越南语字符支持。Hỗ trợ ký tự tiếng Việt
|
672 |
+
RE_VI = "" if not enablePreview else "đơưăáàảãạắằẳẵặấầẩẫậéèẻẽẹếềểễệíìỉĩịóòỏõọốồổỗộớờởỡợúùủũụứừửữựôâêơưỷỹ"
|
673 |
+
# -------------------------------------------------------------------------------------------------------
|
674 |
+
# Basic options:
|
675 |
+
process_list = [
|
676 |
+
( TAG_S1 , re.compile(self.SYMBOLS_PATTERN) , self._process_symbol ), # Symbol Tag
|
677 |
+
( TAG_KO , re.compile(re.sub(r'LANGUAGE',f'\uac00-\ud7a3',TAG_BASE.pattern)) , self._process_korean ), # Korean words
|
678 |
+
( TAG_TH , re.compile(re.sub(r'LANGUAGE',f'\u0E00-\u0E7F',TAG_BASE.pattern)) , self._process_Thai ), # Thai words support.
|
679 |
+
( TAG_RU , re.compile(re.sub(r'LANGUAGE',f'А-Яа-яЁё',TAG_BASE.pattern)) , self._process_Russian ), # Russian words support.
|
680 |
+
( TAG_NUM , re.compile(r'(\W*\d+\W+\d*\W*\d*)') , self._process_number ), # Number words, Universal in all languages, Ignore it.
|
681 |
+
( TAG_EN , re.compile(re.sub(r'LANGUAGE',f'a-zA-Z{RE_FR}{RE_VI}',TAG_BASE.pattern)) , self._process_english ), # English words + Other language support.
|
682 |
+
( TAG_P1 , re.compile(r'(["\'])(.*?)(\1)') , self._process_quotes ), # Regular quotes
|
683 |
+
( TAG_P2 , re.compile(r'([\n]*[【《((“‘])([^【《((“‘’”))》】]{3,})([’”))》】][\W\s]*[\n]{,1})') , self._process_quotes ), # Special quotes, There are left and right.
|
684 |
+
]
|
685 |
+
# Extended options: Default False
|
686 |
+
if self.keepPinyin == True:process_list.insert(1 ,
|
687 |
+
( TAG_S2 , re.compile(r'([\(({](?:\s*\w*\d\w*\s*)+[})\)])') , self._process_pinyin ), # Chinese Pinyin Tag.
|
688 |
+
)
|
689 |
+
# -------------------------------------------------------------------------------------------------------
|
690 |
+
words = []
|
691 |
+
lines = re.findall(r'.*\n*', re.sub(self.PARSE_TAG, '' ,text))
|
692 |
+
for index , text in enumerate(lines):
|
693 |
+
if len(text.strip()) == 0:continue
|
694 |
+
self._lang_eos = False
|
695 |
+
self._text_cache = {}
|
696 |
+
for item in process_list:
|
697 |
+
text = self._pattern_symbols(item , text)
|
698 |
+
cur_word = self._process_tags([] , text , True)
|
699 |
+
if len(cur_word) == 0:continue
|
700 |
+
cur_data = cur_word[0] if len(cur_word) > 0 else None
|
701 |
+
pre_data = words[-1] if len(words) > 0 else None
|
702 |
+
if cur_data and pre_data and cur_data["lang"] == pre_data["lang"] \
|
703 |
+
and cur_data["symbol"] == False and pre_data["symbol"] :
|
704 |
+
cur_data["text"] = f'{pre_data["text"]}{cur_data["text"]}'
|
705 |
+
words.pop()
|
706 |
+
words += cur_word
|
707 |
+
if self.isLangMerge == True:words = self._merge_results(words)
|
708 |
+
lang_count = self._lang_count
|
709 |
+
if lang_count and len(lang_count) > 0:
|
710 |
+
lang_count = dict(sorted(lang_count.items(), key=lambda x: x[1], reverse=True))
|
711 |
+
lang_count = list(lang_count.items())
|
712 |
+
self._lang_count = lang_count
|
713 |
+
return words
|
714 |
+
|
715 |
+
def setfilters(self, filters):
|
716 |
+
# 当过滤器更改时,清除缓存
|
717 |
+
# 필터가 변경되면 캐시를 지웁니다.
|
718 |
+
# フィルタが変更されると、キャッシュがクリアされます
|
719 |
+
# When the filter changes, clear the cache
|
720 |
+
if self.Langfilters != filters:
|
721 |
+
self._clears()
|
722 |
+
self.Langfilters = filters
|
723 |
+
|
724 |
+
def getfilters(self):
|
725 |
+
return self.Langfilters
|
726 |
+
|
727 |
+
def setPriorityThreshold(self, threshold:float):
|
728 |
+
self.LangPriorityThreshold = threshold
|
729 |
+
|
730 |
+
def getPriorityThreshold(self):
|
731 |
+
return self.LangPriorityThreshold
|
732 |
+
|
733 |
+
def getCounts(self):
|
734 |
+
lang_count = self._lang_count
|
735 |
+
if lang_count is not None:return lang_count
|
736 |
+
text_langs = self._text_langs
|
737 |
+
if text_langs is None or len(text_langs) == 0:return [("zh",0)]
|
738 |
+
lang_counts = defaultdict(int)
|
739 |
+
for d in text_langs:lang_counts[d['lang']] += int(len(d['text'])*2) if d['lang'] == "zh" else len(d['text'])
|
740 |
+
lang_counts = dict(sorted(lang_counts.items(), key=lambda x: x[1], reverse=True))
|
741 |
+
lang_counts = list(lang_counts.items())
|
742 |
+
self._lang_count = lang_counts
|
743 |
+
return lang_counts
|
744 |
+
|
745 |
+
def getTexts(self, text:str):
|
746 |
+
if text is None or len(text.strip()) == 0:
|
747 |
+
self._clears()
|
748 |
+
return []
|
749 |
+
# lasts
|
750 |
+
text_langs = self._text_langs
|
751 |
+
if self._text_lasts == text and text_langs is not None:return text_langs
|
752 |
+
# parse
|
753 |
+
self._text_waits = []
|
754 |
+
self._lang_count = None
|
755 |
+
self._text_lasts = text
|
756 |
+
text = self._parse_symbols(text)
|
757 |
+
self._text_langs = text
|
758 |
+
return text
|
759 |
+
|
760 |
+
def classify(self, text:str):
|
761 |
+
return self.getTexts(text)
|
762 |
+
|
763 |
+
def printList(langlist):
|
764 |
+
"""
|
765 |
+
功能:打印数组结果
|
766 |
+
기능: 어레이 결과 인쇄
|
767 |
+
機能:配列結果を印刷
|
768 |
+
Function: Print array results
|
769 |
+
"""
|
770 |
+
print("\n===================【打印结果】===================")
|
771 |
+
if langlist is None or len(langlist) == 0:
|
772 |
+
print("无内容结果,No content result")
|
773 |
+
return
|
774 |
+
for line in langlist:
|
775 |
+
print(line)
|
776 |
+
pass
|
777 |
+
|
778 |
+
|
779 |
+
|
780 |
+
def main():
|
781 |
+
|
782 |
+
# -----------------------------------
|
783 |
+
# 更新日志:新版本分词更加精准。
|
784 |
+
# Changelog: The new version of the word segmentation is more accurate.
|
785 |
+
# チェンジログ:新しいバージョンの単語セグメンテーションはより正確です。
|
786 |
+
# Changelog: 분할이라는 단어의 새로운 버전이 더 정확합니다.
|
787 |
+
# -----------------------------------
|
788 |
+
|
789 |
+
# 输入示例1:(包含日文,中文)Input Example 1: (including Japanese, Chinese)
|
790 |
+
# text = "“昨日は雨が降った,音楽、映画。。。”你今天学习日语了吗?春は桜の季節です。语种分词是语音合成必不可少的环节。言語分詞は音声合成に欠かせない環節である!"
|
791 |
+
|
792 |
+
# 输入示例2:(包含日文,中文)Input Example 2: (including Japanese, Chinese)
|
793 |
+
# text = "欢迎来玩。東京,は日本の首都です。欢迎来玩. 太好了!"
|
794 |
+
|
795 |
+
# 输入示例3:(包含日文,中文)Input Example 3: (including Japanese, Chinese)
|
796 |
+
# text = "明日、私たちは海辺にバカンスに行きます。你会说日语吗:“中国語、話せますか” 你的日语真好啊!"
|
797 |
+
|
798 |
+
|
799 |
+
# 输入示例4:(包含日文,中文,韩语,英文)Input Example 4: (including Japanese, Chinese, Korean, English)
|
800 |
+
# text = "你的名字叫<ja>佐々木?<ja>吗?韩语中的안녕 오빠读什么呢?あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型和三款Apple Watch等一系列新品,这次的iPad Air采用了LCD屏幕"
|
801 |
+
|
802 |
+
|
803 |
+
# 试验性支持:"fr"法语 , "vi"越南语 , "ru"俄语 , "th"泰语。Experimental: Other language support.
|
804 |
+
langsegment = LangSegment()
|
805 |
+
langsegment.setfilters(["fr", "vi" , "ja", "zh", "ko", "en" , "ru" , "th"])
|
806 |
+
text = """
|
807 |
+
我喜欢在雨天里听音乐。
|
808 |
+
I enjoy listening to music on rainy days.
|
809 |
+
雨の日に音楽を聴くのが好きです。
|
810 |
+
비 오는 날에 음악을 듣는 것을 즐깁니다。
|
811 |
+
J'aime écouter de la musique les jours de pluie.
|
812 |
+
Tôi thích nghe nhạc vào những ngày mưa.
|
813 |
+
Мне нравится слушать музыку в дождливую погоду.
|
814 |
+
ฉันชอบฟังเพลงในวันที่ฝนตก
|
815 |
+
"""
|
816 |
+
|
817 |
+
|
818 |
+
|
819 |
+
# 进行分词:(接入TTS项目仅需一行代码调用)Segmentation: (Only one line of code is required to access the TTS project)
|
820 |
+
langlist = langsegment.getTexts(text)
|
821 |
+
printList(langlist)
|
822 |
+
|
823 |
+
|
824 |
+
# 语种统计:Language statistics:
|
825 |
+
print("\n===================【语种统计】===================")
|
826 |
+
# 获取所有语种数组结果,根据内容字数降序排列
|
827 |
+
# Get the array results in all languages, sorted in descending order according to the number of content words
|
828 |
+
langCounts = langsegment.getCounts()
|
829 |
+
print(langCounts , "\n")
|
830 |
+
|
831 |
+
# 根据结果获取内容的主要语种 (语言,字数含标点)
|
832 |
+
# Get the main language of content based on the results (language, word count including punctuation)
|
833 |
+
lang , count = langCounts[0]
|
834 |
+
print(f"输入内容的主要语言为 = {lang} ,字数 = {count}")
|
835 |
+
print("==================================================\n")
|
836 |
+
|
837 |
+
|
838 |
+
# 分词输出:lang=语言,text=内容。Word output: lang = language, text = content
|
839 |
+
# ===================【打印结果】===================
|
840 |
+
# {'lang': 'zh', 'text': '你的名字叫'}
|
841 |
+
# {'lang': 'ja', 'text': '佐々木?'}
|
842 |
+
# {'lang': 'zh', 'text': '吗?韩语中的'}
|
843 |
+
# {'lang': 'ko', 'text': '안녕 오빠'}
|
844 |
+
# {'lang': 'zh', 'text': '读什么呢?'}
|
845 |
+
# {'lang': 'ja', 'text': 'あなたの体育の先生は誰ですか?'}
|
846 |
+
# {'lang': 'zh', 'text': ' 此次发布会带来了四款'}
|
847 |
+
# {'lang': 'en', 'text': 'i Phone '}
|
848 |
+
# {'lang': 'zh', 'text': '15系列机型和三款'}
|
849 |
+
# {'lang': 'en', 'text': 'Apple Watch '}
|
850 |
+
# {'lang': 'zh', 'text': '等一系列新品,这次的'}
|
851 |
+
# {'lang': 'en', 'text': 'i Pad Air '}
|
852 |
+
# {'lang': 'zh', 'text': '采用了'}
|
853 |
+
# {'lang': 'en', 'text': 'L C D '}
|
854 |
+
# {'lang': 'zh', 'text': '屏幕'}
|
855 |
+
|
856 |
+
|
857 |
+
# ===================【语种统计】===================
|
858 |
+
# [('zh', 51), ('ja', 19), ('en', 18), ('ko', 5)]
|
859 |
+
|
860 |
+
# 输入内容的主要语言为 = zh ,字数 = 51
|
861 |
+
# ==================================================
|
862 |
+
# The main language of the input content is = zh, word count = 51
|
863 |
+
|
864 |
+
|
865 |
+
if __name__ == "__main__":
|
866 |
+
main()
|
language_segmentation/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .LangSegment import LangSegment
|
2 |
+
|
3 |
+
|
4 |
+
# release
|
5 |
+
__version__ = '0.3.5'
|
6 |
+
|
7 |
+
|
8 |
+
# develop
|
9 |
+
__develop__ = 'dev-0.0.1'
|
language_segmentation/utils/__init__.py
ADDED
File without changes
|
language_segmentation/utils/num.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# Digital processing from GPT_SoVITS num.py (thanks)
|
15 |
+
"""
|
16 |
+
Rules to verbalize numbers into Chinese characters.
|
17 |
+
https://zh.wikipedia.org/wiki/中文数字#現代中文
|
18 |
+
"""
|
19 |
+
|
20 |
+
import re
|
21 |
+
from collections import OrderedDict
|
22 |
+
from typing import List
|
23 |
+
|
24 |
+
DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')}
|
25 |
+
UNITS = OrderedDict({
|
26 |
+
1: '十',
|
27 |
+
2: '百',
|
28 |
+
3: '千',
|
29 |
+
4: '万',
|
30 |
+
8: '亿',
|
31 |
+
})
|
32 |
+
|
33 |
+
COM_QUANTIFIERS = '(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)'
|
34 |
+
|
35 |
+
# 分数表达式
|
36 |
+
RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
|
37 |
+
|
38 |
+
|
39 |
+
def replace_frac(match) -> str:
|
40 |
+
"""
|
41 |
+
Args:
|
42 |
+
match (re.Match)
|
43 |
+
Returns:
|
44 |
+
str
|
45 |
+
"""
|
46 |
+
sign = match.group(1)
|
47 |
+
nominator = match.group(2)
|
48 |
+
denominator = match.group(3)
|
49 |
+
sign: str = "负" if sign else ""
|
50 |
+
nominator: str = num2str(nominator)
|
51 |
+
denominator: str = num2str(denominator)
|
52 |
+
result = f"{sign}{denominator}分之{nominator}"
|
53 |
+
return result
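# Hedged example (illustration only):
#   RE_FRAC.sub(replace_frac, "-3/4")  -> "负四分之三"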
|
54 |
+
|
55 |
+
|
56 |
+
# 百分数表达式
|
57 |
+
RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')
|
58 |
+
|
59 |
+
|
60 |
+
def replace_percentage(match) -> str:
|
61 |
+
"""
|
62 |
+
Args:
|
63 |
+
match (re.Match)
|
64 |
+
Returns:
|
65 |
+
str
|
66 |
+
"""
|
67 |
+
sign = match.group(1)
|
68 |
+
percent = match.group(2)
|
69 |
+
sign: str = "负" if sign else ""
|
70 |
+
percent: str = num2str(percent)
|
71 |
+
result = f"{sign}百分之{percent}"
|
72 |
+
return result
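# Hedged example (illustration only):
#   RE_PERCENTAGE.sub(replace_percentage, "95.5%")  -> "百分之九十五点五"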
|
73 |
+
|
74 |
+
|
75 |
+
# 整数表达式
|
76 |
+
# 带负号的整数 -10
|
77 |
+
RE_INTEGER = re.compile(r'(-)' r'(\d+)')
|
78 |
+
|
79 |
+
|
80 |
+
def replace_negative_num(match) -> str:
|
81 |
+
"""
|
82 |
+
Args:
|
83 |
+
match (re.Match)
|
84 |
+
Returns:
|
85 |
+
str
|
86 |
+
"""
|
87 |
+
sign = match.group(1)
|
88 |
+
number = match.group(2)
|
89 |
+
sign: str = "负" if sign else ""
|
90 |
+
number: str = num2str(number)
|
91 |
+
result = f"{sign}{number}"
|
92 |
+
return result
|
93 |
+
|
94 |
+
|
95 |
+
# 编号-无符号整形
|
96 |
+
# 00078
|
97 |
+
RE_DEFAULT_NUM = re.compile(r'\d{3}\d*')
|
98 |
+
|
99 |
+
|
100 |
+
def replace_default_num(match):
|
101 |
+
"""
|
102 |
+
Args:
|
103 |
+
match (re.Match)
|
104 |
+
Returns:
|
105 |
+
str
|
106 |
+
"""
|
107 |
+
number = match.group(0)
|
108 |
+
return verbalize_digit(number, alt_one=True)
|
109 |
+
|
110 |
+
|
111 |
+
# 加减乘除
|
112 |
+
# RE_ASMD = re.compile(
|
113 |
+
# r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
|
114 |
+
RE_ASMD = re.compile(
|
115 |
+
r'((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))')
|
116 |
+
|
117 |
+
asmd_map = {
|
118 |
+
'+': '加',
|
119 |
+
'-': '减',
|
120 |
+
'×': '乘',
|
121 |
+
'÷': '除',
|
122 |
+
'=': '等于'
|
123 |
+
}
|
124 |
+
|
125 |
+
def replace_asmd(match) -> str:
|
126 |
+
"""
|
127 |
+
Args:
|
128 |
+
match (re.Match)
|
129 |
+
Returns:
|
130 |
+
str
|
131 |
+
"""
|
132 |
+
result = match.group(1) + asmd_map[match.group(8)] + match.group(9)
|
133 |
+
return result
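# Hedged example (illustration only): only the operator is verbalized here; the remaining digits
# are converted later by the other number rules.
#   RE_ASMD.sub(replace_asmd, "3+5")  -> "3加5"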
|
134 |
+
|
135 |
+
|
136 |
+
# 次方专项
|
137 |
+
RE_POWER = re.compile(r'[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+')
|
138 |
+
|
139 |
+
power_map = {
|
140 |
+
'⁰': '0',
|
141 |
+
'¹': '1',
|
142 |
+
'²': '2',
|
143 |
+
'³': '3',
|
144 |
+
'⁴': '4',
|
145 |
+
'⁵': '5',
|
146 |
+
'⁶': '6',
|
147 |
+
'⁷': '7',
|
148 |
+
'⁸': '8',
|
149 |
+
'⁹': '9',
|
150 |
+
'ˣ': 'x',
|
151 |
+
'ʸ': 'y',
|
152 |
+
'ⁿ': 'n'
|
153 |
+
}
|
154 |
+
|
155 |
+
def replace_power(match) -> str:
|
156 |
+
"""
|
157 |
+
Args:
|
158 |
+
match (re.Match)
|
159 |
+
Returns:
|
160 |
+
str
|
161 |
+
"""
|
162 |
+
power_num = ""
|
163 |
+
for m in match.group(0):
|
164 |
+
power_num += power_map[m]
|
165 |
+
result = "的" + power_num + "次方"
|
166 |
+
return result
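# Hedged example (illustration only):
#   RE_POWER.sub(replace_power, "x²")  -> "x的2次方"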
|
167 |
+
|
168 |
+
|
169 |
+
# 数字表达式
|
170 |
+
# 纯小数
|
171 |
+
RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
|
172 |
+
# 正整数 + 量词
|
173 |
+
RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS)
|
174 |
+
RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))')
|
175 |
+
|
176 |
+
|
177 |
+
def replace_positive_quantifier(match) -> str:
|
178 |
+
"""
|
179 |
+
Args:
|
180 |
+
match (re.Match)
|
181 |
+
Returns:
|
182 |
+
str
|
183 |
+
"""
|
184 |
+
number = match.group(1)
|
185 |
+
match_2 = match.group(2)
|
186 |
+
if match_2 == "+":
|
187 |
+
match_2 = "多"
|
188 |
+
match_2: str = match_2 if match_2 else ""
|
189 |
+
quantifiers: str = match.group(3)
|
190 |
+
number: str = num2str(number)
|
191 |
+
result = f"{number}{match_2}{quantifiers}"
|
192 |
+
return result
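# Hedged example (illustration only):
#   RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier, "3个苹果")  -> "三个苹果"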
|
193 |
+
|
194 |
+
|
195 |
+
def replace_number(match) -> str:
|
196 |
+
"""
|
197 |
+
Args:
|
198 |
+
match (re.Match)
|
199 |
+
Returns:
|
200 |
+
str
|
201 |
+
"""
|
202 |
+
sign = match.group(1)
|
203 |
+
number = match.group(2)
|
204 |
+
pure_decimal = match.group(5)
|
205 |
+
if pure_decimal:
|
206 |
+
result = num2str(pure_decimal)
|
207 |
+
else:
|
208 |
+
sign: str = "负" if sign else ""
|
209 |
+
number: str = num2str(number)
|
210 |
+
result = f"{sign}{number}"
|
211 |
+
return result
|
212 |
+
|
213 |
+
|
214 |
+
# 范围表达式
|
215 |
+
# match.group(1) and match.group(8) are copy from RE_NUMBER
|
216 |
+
|
217 |
+
RE_RANGE = re.compile(
|
218 |
+
r"""
|
219 |
+
(?<![\d\+\-\×÷=]) # 使用反向前瞻以确保数字范围之前没有其他数字和操作符
|
220 |
+
((-?)((\d+)(\.\d+)?)) # 匹配范围起始的负数或正数(整数或小数)
|
221 |
+
[-~] # 匹配范围分隔符
|
222 |
+
((-?)((\d+)(\.\d+)?)) # 匹配范围结束的负数或正数(整数或小数)
|
223 |
+
(?![\d\+\-\×÷=]) # 使用正向前瞻以确保数字范围之后没有其他数字和操作符
|
224 |
+
""", re.VERBOSE)
|
225 |
+
|
226 |
+
|
227 |
+
def replace_range(match) -> str:
|
228 |
+
"""
|
229 |
+
Args:
|
230 |
+
match (re.Match)
|
231 |
+
Returns:
|
232 |
+
str
|
233 |
+
"""
|
234 |
+
first, second = match.group(1), match.group(6)
|
235 |
+
first = RE_NUMBER.sub(replace_number, first)
|
236 |
+
second = RE_NUMBER.sub(replace_number, second)
|
237 |
+
result = f"{first}到{second}"
|
238 |
+
return result
|
239 |
+
|
240 |
+
|
241 |
+
# ~至表达式
|
242 |
+
RE_TO_RANGE = re.compile(
|
243 |
+
r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)')
|
244 |
+
|
245 |
+
def replace_to_range(match) -> str:
|
246 |
+
"""
|
247 |
+
Args:
|
248 |
+
match (re.Match)
|
249 |
+
Returns:
|
250 |
+
str
|
251 |
+
"""
|
252 |
+
result = match.group(0).replace('~', '至')
|
253 |
+
return result
|
254 |
+
|
255 |
+
|
256 |
+
def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
|
257 |
+
stripped = value_string.lstrip('0')
|
258 |
+
if len(stripped) == 0:
|
259 |
+
return []
|
260 |
+
elif len(stripped) == 1:
|
261 |
+
if use_zero and len(stripped) < len(value_string):
|
262 |
+
return [DIGITS['0'], DIGITS[stripped]]
|
263 |
+
else:
|
264 |
+
return [DIGITS[stripped]]
|
265 |
+
else:
|
266 |
+
largest_unit = next(
|
267 |
+
power for power in reversed(UNITS.keys()) if power < len(stripped))
|
268 |
+
first_part = value_string[:-largest_unit]
|
269 |
+
second_part = value_string[-largest_unit:]
|
270 |
+
return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(
|
271 |
+
second_part)
|
272 |
+
|
273 |
+
|
274 |
+
def verbalize_cardinal(value_string: str) -> str:
|
275 |
+
if not value_string:
|
276 |
+
return ''
|
277 |
+
|
278 |
+
# 000 -> '零' , 0 -> '零'
|
279 |
+
value_string = value_string.lstrip('0')
|
280 |
+
if len(value_string) == 0:
|
281 |
+
return DIGITS['0']
|
282 |
+
|
283 |
+
result_symbols = _get_value(value_string)
|
284 |
+
# verbalized number starting with '一十*' is abbreviated as `十*`
|
285 |
+
if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[
|
286 |
+
'1'] and result_symbols[1] == UNITS[1]:
|
287 |
+
result_symbols = result_symbols[1:]
|
288 |
+
return ''.join(result_symbols)
|
289 |
+
|
290 |
+
|
291 |
+
def verbalize_digit(value_string: str, alt_one=False) -> str:
|
292 |
+
result_symbols = [DIGITS[digit] for digit in value_string]
|
293 |
+
result = ''.join(result_symbols)
|
294 |
+
if alt_one:
|
295 |
+
result = result.replace("一", "幺")
|
296 |
+
return result
|
297 |
+
|
298 |
+
|
299 |
+
def num2str(value_string: str) -> str:
|
300 |
+
integer_decimal = value_string.split('.')
|
301 |
+
if len(integer_decimal) == 1:
|
302 |
+
integer = integer_decimal[0]
|
303 |
+
decimal = ''
|
304 |
+
elif len(integer_decimal) == 2:
|
305 |
+
integer, decimal = integer_decimal
|
306 |
+
else:
|
307 |
+
raise ValueError(
|
308 |
+
f"The value string: '${value_string}' has more than one point in it."
|
309 |
+
)
|
310 |
+
|
311 |
+
result = verbalize_cardinal(integer)
|
312 |
+
|
313 |
+
decimal = decimal.rstrip('0')
|
314 |
+
if decimal:
|
315 |
+
# '.22' is verbalized as '零点二二'
|
316 |
+
# '3.20' is verbalized as '三点二'
|
317 |
+
result = result if result else "零"
|
318 |
+
result += '点' + verbalize_digit(decimal)
|
319 |
+
return result
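# Hedged examples (illustration only):
#   num2str("10086")  -> "一万零八十六"
#   num2str("3.14")   -> "三点一四"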
|
320 |
+
|
321 |
+
|
322 |
+
if __name__ == "__main__":
|
323 |
+
|
324 |
+
text = ""
|
325 |
+
text = num2str(text)
|
326 |
+
print(text)
|
327 |
+
pass
|
models/ace_step_transformer.py
ADDED
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, Optional, Tuple, List, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from diffusers.utils import BaseOutput, is_torch_version
|
23 |
+
from diffusers.models.modeling_utils import ModelMixin
|
24 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
25 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
26 |
+
|
27 |
+
|
28 |
+
from .attention import LinearTransformerBlock, t2i_modulate
|
29 |
+
from .lyrics_utils.lyric_encoder import ConformerEncoder as LyricEncoder
|
30 |
+
|
31 |
+
|
32 |
+
def cross_norm(hidden_states, controlnet_input):
|
33 |
+
# input N x T x c
|
34 |
+
mean_hidden_states, std_hidden_states = hidden_states.mean(dim=(1,2), keepdim=True), hidden_states.std(dim=(1,2), keepdim=True)
|
35 |
+
mean_controlnet_input, std_controlnet_input = controlnet_input.mean(dim=(1,2), keepdim=True), controlnet_input.std(dim=(1,2), keepdim=True)
|
36 |
+
controlnet_input = (controlnet_input - mean_controlnet_input) * (std_hidden_states / (std_controlnet_input + 1e-12)) + mean_hidden_states
|
37 |
+
return controlnet_input
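# Hedged shape note (illustration only): both inputs are expected as (N, T, C) tensors; the
# controlnet features are re-scaled to match the per-sample mean/std of hidden_states, e.g.:
#   hidden = torch.randn(2, 100, 1536); ctrl = torch.randn(2, 100, 1536)
#   ctrl = cross_norm(hidden, ctrl)   # same shape, statistics aligned with `hidden`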
|
38 |
+
|
39 |
+
|
40 |
+
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
|
41 |
+
class Qwen2RotaryEmbedding(nn.Module):
|
42 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
43 |
+
super().__init__()
|
44 |
+
|
45 |
+
self.dim = dim
|
46 |
+
self.max_position_embeddings = max_position_embeddings
|
47 |
+
self.base = base
|
48 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
49 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
50 |
+
|
51 |
+
# Build here to make `torch.jit.trace` work.
|
52 |
+
self._set_cos_sin_cache(
|
53 |
+
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
54 |
+
)
|
55 |
+
|
56 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
57 |
+
self.max_seq_len_cached = seq_len
|
58 |
+
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


class T2IFinalLayer(nn.Module):
    """
    The final layer of Sana.
    """

    def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256):
        super().__init__()
        self.norm_final = nn.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True)
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
        self.out_channels = out_channels
        self.patch_size = patch_size

    def unpatchfy(
        self,
        hidden_states: torch.Tensor,
        width: int,
    ):
        # 4 unpatchify
        new_height, new_width = 1, hidden_states.size(1)
        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
        ).contiguous()
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
        ).contiguous()
        if width > new_width:
            output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
        elif width < new_width:
            output = output[:, :, :, :width]
        return output

    def forward(self, x, t, output_length):
        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        # unpatchify
        output = self.unpatchfy(x, output_length)
        return output


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(
        self,
        height=16,
        width=4096,
        patch_size=(16, 1),
        in_channels=8,
        embed_dim=1152,
        bias=True,
    ):
        super().__init__()
        patch_size_h, patch_size_w = patch_size
        self.early_conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias),
            torch.nn.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True),
            nn.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias)
        )
        self.patch_size = patch_size
        self.height, self.width = height // patch_size_h, width // patch_size_w
        self.base_size = self.width

    def forward(self, latent):
        # early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
        latent = self.early_conv_layers(latent)
        latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return latent

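For orientation, the reshaping in unpatchfy above can be reproduced standalone. The following is a minimal sketch with toy shapes (batch 2, 10 tokens, patch_size [16, 1], 8 output channels), not the module itself; it shows how a flat token sequence becomes a [B, C, 16, width] latent that is then padded or cropped to the requested width:

import torch

# Assumed toy sizes: batch of 2, 10 patches, patch_size = [16, 1], out_channels = 8.
B, T, p_h, p_w, C = 2, 10, 16, 1, 8
tokens = torch.randn(B, T, p_h * p_w * C)           # output of the final linear layer

# Same reshaping as T2IFinalLayer.unpatchfy: a 1 x T grid of (p_h x p_w) patches.
x = tokens.reshape(B, 1, T, p_h, p_w, C)
x = torch.einsum("nhwpqc->nchpwq", x)                # channels first, patch dims interleaved
latent = x.reshape(B, C, 1 * p_h, T * p_w)           # [B, 8, 16, T]

# Pad or crop the width to the requested output length, as in the module above.
width = 12
if width > latent.shape[-1]:
    latent = torch.nn.functional.pad(latent, (0, width - latent.shape[-1]))
else:
    latent = latent[..., :width]
print(latent.shape)                                   # torch.Size([2, 8, 16, 12])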
@dataclass
class Transformer2DModelOutput(BaseOutput):

    sample: torch.FloatTensor
    proj_losses: Optional[Tuple[Tuple[str, torch.Tensor]]] = None


class ACEStepTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: Optional[int] = 8,
        num_layers: int = 28,
        inner_dim: int = 1536,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        mlp_ratio: float = 4.0,
        out_channels: int = 8,
        max_position: int = 32768,
        rope_theta: float = 1000000.0,
        speaker_embedding_dim: int = 512,
        text_embedding_dim: int = 768,
        ssl_encoder_depths: List[int] = [9, 9],
        ssl_names: List[str] = ["mert", "m-hubert"],
        ssl_latent_dims: List[int] = [1024, 768],
        lyric_encoder_vocab_size: int = 6681,
        lyric_hidden_size: int = 1024,
        patch_size: List[int] = [16, 1],
        max_height: int = 16,
        max_width: int = 4096,
        **kwargs,
    ):
        super().__init__()

        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim
        self.out_channels = out_channels
        self.max_position = max_position
        self.patch_size = patch_size

        self.rope_theta = rope_theta

        self.rotary_emb = Qwen2RotaryEmbedding(
            dim=self.attention_head_dim,
            max_position_embeddings=self.max_position,
            base=self.rope_theta,
        )

        # 2. Define input layers
        self.in_channels = in_channels

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                LinearTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    add_cross_attention=True,
                    add_cross_attention_dim=self.inner_dim,
                )
                for i in range(self.config.num_layers)
            ]
        )
        self.num_layers = num_layers

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(self.inner_dim, 6 * self.inner_dim, bias=True))

        # speaker
        self.speaker_embedder = nn.Linear(speaker_embedding_dim, self.inner_dim)

        # genre
        self.genre_embedder = nn.Linear(text_embedding_dim, self.inner_dim)

        # lyric
        self.lyric_embs = nn.Embedding(lyric_encoder_vocab_size, lyric_hidden_size)
        self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0)
        self.lyric_proj = nn.Linear(lyric_hidden_size, self.inner_dim)

        projector_dim = 2 * self.inner_dim

        self.projectors = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.inner_dim, projector_dim),
                nn.SiLU(),
                nn.Linear(projector_dim, projector_dim),
                nn.SiLU(),
                nn.Linear(projector_dim, ssl_dim),
            ) for ssl_dim in ssl_latent_dims
        ])

        self.ssl_latent_dims = ssl_latent_dims
        self.ssl_encoder_depths = ssl_encoder_depths

        self.cosine_loss = torch.nn.CosineEmbeddingLoss(margin=0.0, reduction='mean')
        self.ssl_names = ssl_names

        self.proj_in = PatchEmbed(
            height=max_height,
            width=max_width,
            patch_size=patch_size,
            embed_dim=self.inner_dim,
            bias=True,
        )

        self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels)
        self.gradient_checkpointing = False

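The three modules above (time_proj, timestep_embedder, t_block) form the adaLN-single conditioning path: sinusoidal timestep features are embedded to inner_dim and expanded into six modulation vectors that the transformer blocks later split into shift/scale/gate pairs. A minimal pure-PyTorch sketch of that path follows; the sinusoidal stand-in and the two-layer embedder only approximate diffusers' Timesteps/TimestepEmbedding and are assumptions for illustration:

import math
import torch
from torch import nn

def sinusoidal_features(timestep: torch.Tensor, num_channels: int = 256) -> torch.Tensor:
    # Simplified stand-in for diffusers' Timesteps projection.
    half = num_channels // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    angles = timestep.float()[:, None] * freqs[None, :]
    return torch.cat([angles.cos(), angles.sin()], dim=-1)        # [B, num_channels]

inner_dim = 2560                                                  # matches models/config.json
timestep_embedder = nn.Sequential(nn.Linear(256, inner_dim), nn.SiLU(), nn.Linear(inner_dim, inner_dim))
t_block = nn.Sequential(nn.SiLU(), nn.Linear(inner_dim, 6 * inner_dim, bias=True))

t = torch.tensor([0.25, 0.8])                                     # toy flow-matching timesteps
embedded_timestep = timestep_embedder(sinusoidal_features(t))     # [B, inner_dim]
temb = t_block(embedded_timestep)                                 # [B, 6 * inner_dim]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb.reshape(-1, 6, inner_dim).unbind(1)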
    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward_lyric_encoder(
        self,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
    ):
        # N x T x D
        lyric_embs = self.lyric_embs(lyric_token_idx)
        prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
        return prompt_prenet_out

    def encode(
        self,
        encoder_text_hidden_states: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
    ):

        bs = encoder_text_hidden_states.shape[0]
        device = encoder_text_hidden_states.device

        # speaker embedding
        encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
        speaker_mask = torch.ones(bs, 1, device=device)

        # genre embedding
        encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)

        # lyric
        encoder_lyric_hidden_states = self.forward_lyric_encoder(
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
        )

        encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
        encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
        return encoder_hidden_states, encoder_hidden_mask

    def decode(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_mask: torch.Tensor,
        timestep: Optional[torch.Tensor],
        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
        output_length: int = 0,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        return_dict: bool = True,
    ):

        embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
        temb = self.t_block(embedded_timestep)

        hidden_states = self.proj_in(hidden_states)

        # controlnet logic
        if block_controlnet_hidden_states is not None:
            control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
            hidden_states = hidden_states + control_condi * controlnet_scale

        inner_hidden_states = []

        rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
        encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])

        for index_block, block in enumerate(self.transformer_blocks):

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_hidden_mask,
                    rotary_freqs_cis=rotary_freqs_cis,
                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                    temb=temb,
                    **ckpt_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_hidden_mask,
                    rotary_freqs_cis=rotary_freqs_cis,
                    rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                    temb=temb,
                )

            for ssl_encoder_depth in self.ssl_encoder_depths:
                if index_block == ssl_encoder_depth:
                    inner_hidden_states.append(hidden_states)

        proj_losses = []
        if len(inner_hidden_states) > 0 and ssl_hidden_states is not None and len(ssl_hidden_states) > 0:

            for inner_hidden_state, projector, ssl_hidden_state, ssl_name in zip(inner_hidden_states, self.projectors, ssl_hidden_states, self.ssl_names):
                if ssl_hidden_state is None:
                    continue
                # 1. N x T x D1 -> N x D x D2
                est_ssl_hidden_state = projector(inner_hidden_state)
                # 3. projection loss
                bs = inner_hidden_state.shape[0]
                proj_loss = 0.0
                for i, (z, z_tilde) in enumerate(zip(ssl_hidden_state, est_ssl_hidden_state)):
                    # 2. interpolate
                    z_tilde = F.interpolate(z_tilde.unsqueeze(0).transpose(1, 2), size=len(z), mode='linear', align_corners=False).transpose(1, 2).squeeze(0)

                    z_tilde = torch.nn.functional.normalize(z_tilde, dim=-1)
                    z = torch.nn.functional.normalize(z, dim=-1)
                    # T x d -> T x 1 -> 1
                    target = torch.ones(z.shape[0], device=z.device)
                    proj_loss += self.cosine_loss(z, z_tilde, target)
                proj_losses.append((ssl_name, proj_loss / bs))

        output = self.final_layer(hidden_states, embedded_timestep, output_length)
        if not return_dict:
            return (output, proj_losses)

        return Transformer2DModelOutput(sample=output, proj_losses=proj_losses)

    # @torch.compile
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_text_hidden_states: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
        timestep: Optional[torch.Tensor] = None,
        ssl_hidden_states: Optional[List[torch.Tensor]] = None,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        return_dict: bool = True,
    ):
        encoder_hidden_states, encoder_hidden_mask = self.encode(
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embeds=speaker_embeds,
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
        )

        output_length = hidden_states.shape[-1]

        output = self.decode(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_mask=encoder_hidden_mask,
            timestep=timestep,
            ssl_hidden_states=ssl_hidden_states,
            output_length=output_length,
            block_controlnet_hidden_states=block_controlnet_hidden_states,
            controlnet_scale=controlnet_scale,
            return_dict=return_dict,
        )

        return output
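The projection loss inside decode() compares projected intermediate hidden states against SSL targets (MERT / m-hubert features) frame by frame. A self-contained sketch of just that loss, with hypothetical shapes, shows the interpolate-normalize-cosine pattern used above:

import torch
import torch.nn.functional as F

cosine_loss = torch.nn.CosineEmbeddingLoss(margin=0.0, reduction="mean")

# Hypothetical shapes: one projected hidden state [T1, D] and one SSL target [T2, D].
z_tilde = torch.randn(100, 1024)     # projector output for one sample
z = torch.randn(150, 1024)           # e.g. a MERT feature sequence

# Resample the prediction to the target length along time (same call as in decode()).
z_tilde = F.interpolate(z_tilde.unsqueeze(0).transpose(1, 2), size=len(z), mode="linear",
                        align_corners=False).transpose(1, 2).squeeze(0)

# Cosine embedding loss on L2-normalized frames; target=1 means "should be similar".
z_tilde = F.normalize(z_tilde, dim=-1)
z = F.normalize(z, dim=-1)
target = torch.ones(z.shape[0])
loss = cosine_loss(z, z_tilde, target)
print(float(loss))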
models/attention.py
ADDED
@@ -0,0 +1,319 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.utils import logging
from diffusers.models.normalization import RMSNorm


try:
    # from .dcformer import DCMHAttention
    from .customer_attention_processor import Attention, CustomLiteLAProcessor2_0, CustomerAttnProcessor2_0
except ImportError:
    # from dcformer import DCMHAttention
    from customer_attention_processor import Attention, CustomLiteLAProcessor2_0, CustomerAttnProcessor2_0


logger = logging.get_logger(__name__)


def val2list(x: list or tuple or any, repeat_time=1) -> list:  # type: ignore
    """Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x for _ in range(repeat_time)]


def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple:  # type: ignore
    """Return tuple with min_len by repeating element at idx_repeat."""
    # convert to list first
    x = val2list(x)

    # repeat elements if necessary
    if len(x) > 0:
        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]

    return tuple(x)


def t2i_modulate(x, shift, scale):
    return x * (1 + scale) + shift


def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
    if isinstance(kernel_size, tuple):
        return tuple([get_same_padding(ks) for ks in kernel_size])
    else:
        assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
        return kernel_size // 2

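A brief usage note on these helpers (values below are illustrative and assume the functions above are in scope): val2tuple pads a scalar or short sequence out to a fixed arity, which is how GLUMBConv below accepts a single bias/norm/act setting for its three conv layers, and get_same_padding returns the padding that preserves length for odd kernels.

import torch

print(val2tuple(True, 3))               # (True, True, True)
print(val2tuple((True, False), 3))      # (True, False, False) -- last element repeated
print(val2list("silu", repeat_time=3))  # ['silu', 'silu', 'silu']
print(get_same_padding(3))              # 1
print(get_same_padding((5, 1)))         # (2, 0)

x = torch.randn(2, 4, 8)
shift, scale = torch.zeros(1, 1, 8), 0.5 * torch.ones(1, 1, 8)
print(t2i_modulate(x, shift, scale).shape)  # torch.Size([2, 4, 8]); here x * 1.5 + 0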
class ConvLayer(nn.Module):
|
66 |
+
def __init__(
|
67 |
+
self,
|
68 |
+
in_dim: int,
|
69 |
+
out_dim: int,
|
70 |
+
kernel_size=3,
|
71 |
+
stride=1,
|
72 |
+
dilation=1,
|
73 |
+
groups=1,
|
74 |
+
padding: Union[int, None] = None,
|
75 |
+
use_bias=False,
|
76 |
+
norm=None,
|
77 |
+
act=None,
|
78 |
+
):
|
79 |
+
super().__init__()
|
80 |
+
if padding is None:
|
81 |
+
padding = get_same_padding(kernel_size)
|
82 |
+
padding *= dilation
|
83 |
+
|
84 |
+
self.in_dim = in_dim
|
85 |
+
self.out_dim = out_dim
|
86 |
+
self.kernel_size = kernel_size
|
87 |
+
self.stride = stride
|
88 |
+
self.dilation = dilation
|
89 |
+
self.groups = groups
|
90 |
+
self.padding = padding
|
91 |
+
self.use_bias = use_bias
|
92 |
+
|
93 |
+
self.conv = nn.Conv1d(
|
94 |
+
in_dim,
|
95 |
+
out_dim,
|
96 |
+
kernel_size=kernel_size,
|
97 |
+
stride=stride,
|
98 |
+
padding=padding,
|
99 |
+
dilation=dilation,
|
100 |
+
groups=groups,
|
101 |
+
bias=use_bias,
|
102 |
+
)
|
103 |
+
if norm is not None:
|
104 |
+
self.norm = RMSNorm(out_dim, elementwise_affine=False)
|
105 |
+
else:
|
106 |
+
self.norm = None
|
107 |
+
if act is not None:
|
108 |
+
self.act = nn.SiLU(inplace=True)
|
109 |
+
else:
|
110 |
+
self.act = None
|
111 |
+
|
112 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
113 |
+
x = self.conv(x)
|
114 |
+
if self.norm:
|
115 |
+
x = self.norm(x)
|
116 |
+
if self.act:
|
117 |
+
x = self.act(x)
|
118 |
+
return x
|
119 |
+
|
120 |
+
|
121 |
+
class GLUMBConv(nn.Module):
|
122 |
+
def __init__(
|
123 |
+
self,
|
124 |
+
in_features: int,
|
125 |
+
hidden_features: int,
|
126 |
+
out_feature=None,
|
127 |
+
kernel_size=3,
|
128 |
+
stride=1,
|
129 |
+
padding: Union[int, None] = None,
|
130 |
+
use_bias=False,
|
131 |
+
norm=(None, None, None),
|
132 |
+
act=("silu", "silu", None),
|
133 |
+
dilation=1,
|
134 |
+
):
|
135 |
+
out_feature = out_feature or in_features
|
136 |
+
super().__init__()
|
137 |
+
use_bias = val2tuple(use_bias, 3)
|
138 |
+
norm = val2tuple(norm, 3)
|
139 |
+
act = val2tuple(act, 3)
|
140 |
+
|
141 |
+
self.glu_act = nn.SiLU(inplace=False)
|
142 |
+
self.inverted_conv = ConvLayer(
|
143 |
+
in_features,
|
144 |
+
hidden_features * 2,
|
145 |
+
1,
|
146 |
+
use_bias=use_bias[0],
|
147 |
+
norm=norm[0],
|
148 |
+
act=act[0],
|
149 |
+
)
|
150 |
+
self.depth_conv = ConvLayer(
|
151 |
+
hidden_features * 2,
|
152 |
+
hidden_features * 2,
|
153 |
+
kernel_size,
|
154 |
+
stride=stride,
|
155 |
+
groups=hidden_features * 2,
|
156 |
+
padding=padding,
|
157 |
+
use_bias=use_bias[1],
|
158 |
+
norm=norm[1],
|
159 |
+
act=None,
|
160 |
+
dilation=dilation,
|
161 |
+
)
|
162 |
+
self.point_conv = ConvLayer(
|
163 |
+
hidden_features,
|
164 |
+
out_feature,
|
165 |
+
1,
|
166 |
+
use_bias=use_bias[2],
|
167 |
+
norm=norm[2],
|
168 |
+
act=act[2],
|
169 |
+
)
|
170 |
+
|
171 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
172 |
+
x = x.transpose(1, 2)
|
173 |
+
x = self.inverted_conv(x)
|
174 |
+
x = self.depth_conv(x)
|
175 |
+
|
176 |
+
x, gate = torch.chunk(x, 2, dim=1)
|
177 |
+
gate = self.glu_act(gate)
|
178 |
+
x = x * gate
|
179 |
+
|
180 |
+
x = self.point_conv(x)
|
181 |
+
x = x.transpose(1, 2)
|
182 |
+
|
183 |
+
return x
|
184 |
+
|
185 |
+
|
186 |
+
class LinearTransformerBlock(nn.Module):
|
187 |
+
"""
|
188 |
+
A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.
|
189 |
+
"""
|
190 |
+
def __init__(
|
191 |
+
self,
|
192 |
+
dim,
|
193 |
+
num_attention_heads,
|
194 |
+
attention_head_dim,
|
195 |
+
use_adaln_single=True,
|
196 |
+
cross_attention_dim=None,
|
197 |
+
added_kv_proj_dim=None,
|
198 |
+
context_pre_only=False,
|
199 |
+
mlp_ratio=4.0,
|
200 |
+
add_cross_attention=False,
|
201 |
+
add_cross_attention_dim=None,
|
202 |
+
qk_norm=None,
|
203 |
+
):
|
204 |
+
super().__init__()
|
205 |
+
|
206 |
+
self.norm1 = RMSNorm(dim, elementwise_affine=False, eps=1e-6)
|
207 |
+
self.attn = Attention(
|
208 |
+
query_dim=dim,
|
209 |
+
cross_attention_dim=cross_attention_dim,
|
210 |
+
added_kv_proj_dim=added_kv_proj_dim,
|
211 |
+
dim_head=attention_head_dim,
|
212 |
+
heads=num_attention_heads,
|
213 |
+
out_dim=dim,
|
214 |
+
bias=True,
|
215 |
+
qk_norm=qk_norm,
|
216 |
+
processor=CustomLiteLAProcessor2_0(),
|
217 |
+
)
|
218 |
+
|
219 |
+
self.add_cross_attention = add_cross_attention
|
220 |
+
self.context_pre_only = context_pre_only
|
221 |
+
|
222 |
+
if add_cross_attention and add_cross_attention_dim is not None:
|
223 |
+
self.cross_attn = Attention(
|
224 |
+
query_dim=dim,
|
225 |
+
cross_attention_dim=add_cross_attention_dim,
|
226 |
+
added_kv_proj_dim=add_cross_attention_dim,
|
227 |
+
dim_head=attention_head_dim,
|
228 |
+
heads=num_attention_heads,
|
229 |
+
out_dim=dim,
|
230 |
+
context_pre_only=context_pre_only,
|
231 |
+
bias=True,
|
232 |
+
qk_norm=qk_norm,
|
233 |
+
processor=CustomerAttnProcessor2_0(),
|
234 |
+
)
|
235 |
+
|
236 |
+
self.norm2 = RMSNorm(dim, 1e-06, elementwise_affine=False)
|
237 |
+
|
238 |
+
self.ff = GLUMBConv(
|
239 |
+
in_features=dim,
|
240 |
+
hidden_features=int(dim * mlp_ratio),
|
241 |
+
use_bias=(True, True, False),
|
242 |
+
norm=(None, None, None),
|
243 |
+
act=("silu", "silu", None),
|
244 |
+
)
|
245 |
+
self.use_adaln_single = use_adaln_single
|
246 |
+
if use_adaln_single:
|
247 |
+
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
248 |
+
|
249 |
+
def forward(
|
250 |
+
self,
|
251 |
+
hidden_states: torch.FloatTensor,
|
252 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
253 |
+
attention_mask: torch.FloatTensor = None,
|
254 |
+
encoder_attention_mask: torch.FloatTensor = None,
|
255 |
+
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
256 |
+
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
257 |
+
temb: torch.FloatTensor = None,
|
258 |
+
):
|
259 |
+
|
260 |
+
N = hidden_states.shape[0]
|
261 |
+
|
262 |
+
# step 1: AdaLN single
|
263 |
+
if self.use_adaln_single:
|
264 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
265 |
+
self.scale_shift_table[None] + temb.reshape(N, 6, -1)
|
266 |
+
).chunk(6, dim=1)
|
267 |
+
|
268 |
+
norm_hidden_states = self.norm1(hidden_states)
|
269 |
+
if self.use_adaln_single:
|
270 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
271 |
+
|
272 |
+
# step 2: attention
|
273 |
+
if not self.add_cross_attention:
|
274 |
+
attn_output, encoder_hidden_states = self.attn(
|
275 |
+
hidden_states=norm_hidden_states,
|
276 |
+
attention_mask=attention_mask,
|
277 |
+
encoder_hidden_states=encoder_hidden_states,
|
278 |
+
encoder_attention_mask=encoder_attention_mask,
|
279 |
+
rotary_freqs_cis=rotary_freqs_cis,
|
280 |
+
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
281 |
+
)
|
282 |
+
else:
|
283 |
+
attn_output, _ = self.attn(
|
284 |
+
hidden_states=norm_hidden_states,
|
285 |
+
attention_mask=attention_mask,
|
286 |
+
encoder_hidden_states=None,
|
287 |
+
encoder_attention_mask=None,
|
288 |
+
rotary_freqs_cis=rotary_freqs_cis,
|
289 |
+
rotary_freqs_cis_cross=None,
|
290 |
+
)
|
291 |
+
|
292 |
+
if self.use_adaln_single:
|
293 |
+
attn_output = gate_msa * attn_output
|
294 |
+
hidden_states = attn_output + hidden_states
|
295 |
+
|
296 |
+
if self.add_cross_attention:
|
297 |
+
attn_output = self.cross_attn(
|
298 |
+
hidden_states=hidden_states,
|
299 |
+
attention_mask=attention_mask,
|
300 |
+
encoder_hidden_states=encoder_hidden_states,
|
301 |
+
encoder_attention_mask=encoder_attention_mask,
|
302 |
+
rotary_freqs_cis=rotary_freqs_cis,
|
303 |
+
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
|
304 |
+
)
|
305 |
+
hidden_states = attn_output + hidden_states
|
306 |
+
|
307 |
+
# step 3: add norm
|
308 |
+
norm_hidden_states = self.norm2(hidden_states)
|
309 |
+
if self.use_adaln_single:
|
310 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
311 |
+
|
312 |
+
# step 4: feed forward
|
313 |
+
ff_output = self.ff(norm_hidden_states)
|
314 |
+
if self.use_adaln_single:
|
315 |
+
ff_output = gate_mlp * ff_output
|
316 |
+
|
317 |
+
hidden_states = hidden_states + ff_output
|
318 |
+
|
319 |
+
return hidden_states
|
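The GLUMBConv feed-forward used by LinearTransformerBlock above is a gated conv MLP: the inverted conv doubles the channels, the tensor is split into a value half and a gate half, and the SiLU-activated gate multiplies the value. A minimal sketch of just that gating step, with toy sizes, for illustration:

import torch
from torch import nn

glu_act = nn.SiLU(inplace=False)
x = torch.randn(2, 2 * 640, 100)        # [batch, 2 * hidden, time] after the depthwise conv (toy sizes)
value, gate = torch.chunk(x, 2, dim=1)  # split channels into value / gate halves
out = value * glu_act(gate)             # [2, 640, 100]
print(out.shape)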
models/config.json
ADDED
@@ -0,0 +1,23 @@
{
  "_class_name": "Transformer2DModel",
  "_diffusers_version": "0.27.2",
  "in_channels": 8,
  "num_layers": 24,
  "inner_dim": 2560,
  "attention_head_dim": 128,
  "num_attention_heads": 20,
  "mlp_ratio": 2.5,
  "out_channels": 8,
  "max_position": 32768,
  "rope_theta": 1000000.0,
  "speaker_embedding_dim": 512,
  "text_embedding_dim": 768,
  "ssl_encoder_depths": [8, 8],
  "ssl_names": ["mert", "m-hubert"],
  "ssl_latent_dims": [1024, 768],
  "patch_size": [16, 1],
  "max_height": 16,
  "max_width": 32768,
  "lyric_encoder_vocab_size": 6693,
  "lyric_hidden_size": 1024
}
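For reference, these shipped hyperparameters are consistent with the constructor in models/ace_step_transformer.py: inner_dim is recomputed as num_attention_heads * attention_head_dim (20 * 128 = 2560), and the [16, 1] patch size folds each 8-channel x 16-band latent column into one token. A small sanity-check sketch (the relative path is an assumption about where the file is read from):

import json

with open("models/config.json") as f:   # assumed path within this repo layout
    cfg = json.load(f)

assert cfg["num_attention_heads"] * cfg["attention_head_dim"] == cfg["inner_dim"]   # 20 * 128 == 2560
assert cfg["patch_size"] == [16, 1] and cfg["max_height"] == 16
print(cfg["num_layers"], "transformer blocks,", cfg["inner_dim"], "hidden size")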
models/customer_attention_processor.py
ADDED
@@ -0,0 +1,339 @@
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Optional, Union, Tuple
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from diffusers.utils import logging
|
21 |
+
from diffusers.models.attention_processor import Attention
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
24 |
+
|
25 |
+
|
26 |
+
class CustomLiteLAProcessor2_0:
|
27 |
+
"""Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""
|
28 |
+
|
29 |
+
def __init__(self):
|
30 |
+
self.kernel_func = nn.ReLU(inplace=False)
|
31 |
+
self.eps = 1e-15
|
32 |
+
self.pad_val = 1.0
|
33 |
+
|
34 |
+
def apply_rotary_emb(
|
35 |
+
self,
|
36 |
+
x: torch.Tensor,
|
37 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
38 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
39 |
+
"""
|
40 |
+
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
41 |
+
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
42 |
+
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
43 |
+
tensors contain rotary embeddings and are returned as real tensors.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
x (`torch.Tensor`):
|
47 |
+
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
48 |
+
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
52 |
+
"""
|
53 |
+
cos, sin = freqs_cis # [S, D]
|
54 |
+
cos = cos[None, None]
|
55 |
+
sin = sin[None, None]
|
56 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
57 |
+
|
58 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
59 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
60 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
61 |
+
|
62 |
+
return out
|
63 |
+
|
64 |
+
def __call__(
|
65 |
+
self,
|
66 |
+
attn: Attention,
|
67 |
+
hidden_states: torch.FloatTensor,
|
68 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
69 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
70 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
71 |
+
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
72 |
+
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
73 |
+
*args,
|
74 |
+
**kwargs,
|
75 |
+
) -> torch.FloatTensor:
|
76 |
+
hidden_states_len = hidden_states.shape[1]
|
77 |
+
|
78 |
+
input_ndim = hidden_states.ndim
|
79 |
+
if input_ndim == 4:
|
80 |
+
batch_size, channel, height, width = hidden_states.shape
|
81 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
82 |
+
if encoder_hidden_states is not None:
|
83 |
+
context_input_ndim = encoder_hidden_states.ndim
|
84 |
+
if context_input_ndim == 4:
|
85 |
+
batch_size, channel, height, width = encoder_hidden_states.shape
|
86 |
+
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
87 |
+
|
88 |
+
batch_size = hidden_states.shape[0]
|
89 |
+
|
90 |
+
# `sample` projections.
|
91 |
+
dtype = hidden_states.dtype
|
92 |
+
query = attn.to_q(hidden_states)
|
93 |
+
key = attn.to_k(hidden_states)
|
94 |
+
value = attn.to_v(hidden_states)
|
95 |
+
|
96 |
+
# `context` projections.
|
97 |
+
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
|
98 |
+
if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
|
99 |
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
100 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
101 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
102 |
+
|
103 |
+
# attention
|
104 |
+
if not attn.is_cross_attention:
|
105 |
+
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
|
106 |
+
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
|
107 |
+
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
|
108 |
+
else:
|
109 |
+
query = hidden_states
|
110 |
+
key = encoder_hidden_states
|
111 |
+
value = encoder_hidden_states
|
112 |
+
|
113 |
+
inner_dim = key.shape[-1]
|
114 |
+
head_dim = inner_dim // attn.heads
|
115 |
+
|
116 |
+
query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
|
117 |
+
key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
|
118 |
+
value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
|
119 |
+
|
120 |
+
# RoPE expects [B, H, S, D] input
|
121 |
+
# query is currently [B, H, D, S]; convert it to [B, H, S, D] before applying RoPE
|
122 |
+
query = query.permute(0, 1, 3, 2)  # [B, H, S, D] (from [B, H, D, S])
|
123 |
+
|
124 |
+
# Apply query and key normalization if needed
|
125 |
+
if attn.norm_q is not None:
|
126 |
+
query = attn.norm_q(query)
|
127 |
+
if attn.norm_k is not None:
|
128 |
+
key = attn.norm_k(key)
|
129 |
+
|
130 |
+
# Apply RoPE if needed
|
131 |
+
if rotary_freqs_cis is not None:
|
132 |
+
query = self.apply_rotary_emb(query, rotary_freqs_cis)
|
133 |
+
if not attn.is_cross_attention:
|
134 |
+
key = self.apply_rotary_emb(key, rotary_freqs_cis)
|
135 |
+
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
|
136 |
+
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
|
137 |
+
|
138 |
+
# query is now [B, H, S, D]; restore it to [B, H, D, S]
|
139 |
+
query = query.permute(0, 1, 3, 2) # [B, H, D, S]
|
140 |
+
|
141 |
+
if attention_mask is not None:
|
142 |
+
# attention_mask: [B, S] -> [B, 1, S, 1]
|
143 |
+
attention_mask = attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S, 1]
|
144 |
+
query = query * attention_mask.permute(0, 1, 3, 2) # [B, H, S, D] * [B, 1, S, 1]
|
145 |
+
if not attn.is_cross_attention:
|
146 |
+
key = key * attention_mask  # key: [B, h, S, D] multiplied by mask [B, 1, S, 1]
|
147 |
+
value = value * attention_mask.permute(0, 1, 3, 2)  # value is [B, h, D, S], so the mask is permuted to match the S dimension
|
148 |
+
|
149 |
+
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
|
150 |
+
encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S_enc, 1]
|
151 |
+
# at this point key: [B, h, S_enc, D], value: [B, h, D, S_enc]
|
152 |
+
key = key * encoder_attention_mask # [B, h, S_enc, D] * [B, 1, S_enc, 1]
|
153 |
+
value = value * encoder_attention_mask.permute(0, 1, 3, 2) # [B, h, D, S_enc] * [B, 1, 1, S_enc]
|
154 |
+
|
155 |
+
query = self.kernel_func(query)
|
156 |
+
key = self.kernel_func(key)
|
157 |
+
|
158 |
+
query, key, value = query.float(), key.float(), value.float()
|
159 |
+
|
160 |
+
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
|
161 |
+
|
162 |
+
vk = torch.matmul(value, key)
|
163 |
+
|
164 |
+
hidden_states = torch.matmul(vk, query)
|
165 |
+
|
166 |
+
if hidden_states.dtype in [torch.float16, torch.bfloat16]:
|
167 |
+
hidden_states = hidden_states.float()
|
168 |
+
|
169 |
+
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
|
170 |
+
|
171 |
+
hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
|
172 |
+
|
173 |
+
hidden_states = hidden_states.to(dtype)
|
174 |
+
if encoder_hidden_states is not None:
|
175 |
+
encoder_hidden_states = encoder_hidden_states.to(dtype)
|
176 |
+
|
177 |
+
# Split the attention outputs.
|
178 |
+
if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
|
179 |
+
hidden_states, encoder_hidden_states = (
|
180 |
+
hidden_states[:, : hidden_states_len],
|
181 |
+
hidden_states[:, hidden_states_len:],
|
182 |
+
)
|
183 |
+
|
184 |
+
# linear proj
|
185 |
+
hidden_states = attn.to_out[0](hidden_states)
|
186 |
+
# dropout
|
187 |
+
hidden_states = attn.to_out[1](hidden_states)
|
188 |
+
if encoder_hidden_states is not None and not attn.context_pre_only and not attn.is_cross_attention and hasattr(attn, "to_add_out"):
|
189 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
190 |
+
|
191 |
+
if input_ndim == 4:
|
192 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
193 |
+
if encoder_hidden_states is not None and context_input_ndim == 4:
|
194 |
+
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
195 |
+
|
196 |
+
if torch.get_autocast_gpu_dtype() == torch.float16:
|
197 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
198 |
+
if encoder_hidden_states is not None:
|
199 |
+
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
200 |
+
|
201 |
+
return hidden_states, encoder_hidden_states
|
202 |
+
|
203 |
+
|
204 |
+
class CustomerAttnProcessor2_0:
|
205 |
+
r"""
|
206 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
207 |
+
"""
|
208 |
+
|
209 |
+
def __init__(self):
|
210 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
211 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
212 |
+
|
213 |
+
def apply_rotary_emb(
|
214 |
+
self,
|
215 |
+
x: torch.Tensor,
|
216 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
217 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
218 |
+
"""
|
219 |
+
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
220 |
+
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
221 |
+
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
222 |
+
tensors contain rotary embeddings and are returned as real tensors.
|
223 |
+
|
224 |
+
Args:
|
225 |
+
x (`torch.Tensor`):
|
226 |
+
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
227 |
+
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
228 |
+
|
229 |
+
Returns:
|
230 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
231 |
+
"""
|
232 |
+
cos, sin = freqs_cis # [S, D]
|
233 |
+
cos = cos[None, None]
|
234 |
+
sin = sin[None, None]
|
235 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
236 |
+
|
237 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
238 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
239 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
240 |
+
|
241 |
+
return out
|
242 |
+
|
243 |
+
def __call__(
|
244 |
+
self,
|
245 |
+
attn: Attention,
|
246 |
+
hidden_states: torch.FloatTensor,
|
247 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
248 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
249 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
250 |
+
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
251 |
+
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
|
252 |
+
*args,
|
253 |
+
**kwargs,
|
254 |
+
) -> torch.Tensor:
|
255 |
+
|
256 |
+
residual = hidden_states
|
257 |
+
input_ndim = hidden_states.ndim
|
258 |
+
|
259 |
+
if input_ndim == 4:
|
260 |
+
batch_size, channel, height, width = hidden_states.shape
|
261 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
262 |
+
|
263 |
+
batch_size, sequence_length, _ = (
|
264 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
265 |
+
)
|
266 |
+
|
267 |
+
has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
|
268 |
+
|
269 |
+
if attn.group_norm is not None:
|
270 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
271 |
+
|
272 |
+
query = attn.to_q(hidden_states)
|
273 |
+
|
274 |
+
if encoder_hidden_states is None:
|
275 |
+
encoder_hidden_states = hidden_states
|
276 |
+
elif attn.norm_cross:
|
277 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
278 |
+
|
279 |
+
key = attn.to_k(encoder_hidden_states)
|
280 |
+
value = attn.to_v(encoder_hidden_states)
|
281 |
+
|
282 |
+
inner_dim = key.shape[-1]
|
283 |
+
head_dim = inner_dim // attn.heads
|
284 |
+
|
285 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
286 |
+
|
287 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
288 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
289 |
+
|
290 |
+
if attn.norm_q is not None:
|
291 |
+
query = attn.norm_q(query)
|
292 |
+
if attn.norm_k is not None:
|
293 |
+
key = attn.norm_k(key)
|
294 |
+
|
295 |
+
# Apply RoPE if needed
|
296 |
+
if rotary_freqs_cis is not None:
|
297 |
+
query = self.apply_rotary_emb(query, rotary_freqs_cis)
|
298 |
+
if not attn.is_cross_attention:
|
299 |
+
key = self.apply_rotary_emb(key, rotary_freqs_cis)
|
300 |
+
elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
|
301 |
+
key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
|
302 |
+
|
303 |
+
if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
|
304 |
+
# attention_mask: N x S1
|
305 |
+
# encoder_attention_mask: N x S2
|
306 |
+
# cross attention: combine attention_mask and encoder_attention_mask
|
307 |
+
combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
|
308 |
+
attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
|
309 |
+
attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)
|
310 |
+
|
311 |
+
elif not attn.is_cross_attention and attention_mask is not None:
|
312 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
313 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
314 |
+
# (batch, heads, source_length, target_length)
|
315 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
316 |
+
|
317 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
318 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
319 |
+
hidden_states = F.scaled_dot_product_attention(
|
320 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
321 |
+
)
|
322 |
+
|
323 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
324 |
+
hidden_states = hidden_states.to(query.dtype)
|
325 |
+
|
326 |
+
# linear proj
|
327 |
+
hidden_states = attn.to_out[0](hidden_states)
|
328 |
+
# dropout
|
329 |
+
hidden_states = attn.to_out[1](hidden_states)
|
330 |
+
|
331 |
+
if input_ndim == 4:
|
332 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
333 |
+
|
334 |
+
if attn.residual_connection:
|
335 |
+
hidden_states = hidden_states + residual
|
336 |
+
|
337 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
338 |
+
|
339 |
+
return hidden_states
|
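CustomLiteLAProcessor2_0 above implements ReLU-kernel linear attention: queries and keys are made non-negative, a row of ones is appended to the values so that one matmul chain yields both the numerator and the normalizer, and the output is their ratio. A compact standalone sketch of that computation with toy shapes (masking, RoPE and the fp16 clipping are omitted):

import torch
import torch.nn.functional as F

B, H, S, D = 1, 2, 5, 4
q = F.relu(torch.randn(B, H, S, D))          # kernel_func(query) >= 0
k = F.relu(torch.randn(B, H, S, D))
v = torch.randn(B, H, S, D)

# Arrange as in the processor: query/value as [B, H, D, S], key as [B, H, S, D].
q_t = q.transpose(-1, -2)                    # [B, H, D, S]
v_t = v.transpose(-1, -2)                    # [B, H, D, S]
v_t = F.pad(v_t, (0, 0, 0, 1), value=1.0)    # append a row of ones -> [B, H, D+1, S]

vk = torch.matmul(v_t, k)                    # [B, H, D+1, D]
out = torch.matmul(vk, q_t)                  # [B, H, D+1, S]
out = out[:, :, :-1] / (out[:, :, -1:] + 1e-15)   # numerator / normalizer -> [B, H, D, S]

# The same softmax-free attention written explicitly, for comparison:
ref = torch.einsum("bhsd,bhtd,bhte->bhse", q, k, v) / (torch.einsum("bhsd,bhtd->bhs", q, k)[..., None] + 1e-15)
print(torch.allclose(out.transpose(-1, -2), ref, atol=1e-5))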
models/lyrics_utils/lyric_encoder.py
ADDED
@@ -0,0 +1,1070 @@
1 |
+
from typing import Optional, Tuple, Union
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
class ConvolutionModule(nn.Module):
|
7 |
+
"""ConvolutionModule in Conformer model."""
|
8 |
+
|
9 |
+
def __init__(self,
|
10 |
+
channels: int,
|
11 |
+
kernel_size: int = 15,
|
12 |
+
activation: nn.Module = nn.ReLU(),
|
13 |
+
norm: str = "batch_norm",
|
14 |
+
causal: bool = False,
|
15 |
+
bias: bool = True):
|
16 |
+
"""Construct an ConvolutionModule object.
|
17 |
+
Args:
|
18 |
+
channels (int): The number of channels of conv layers.
|
19 |
+
kernel_size (int): Kernel size of conv layers.
|
20 |
+
causal (int): Whether use causal convolution or not
|
21 |
+
"""
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
self.pointwise_conv1 = nn.Conv1d(
|
25 |
+
channels,
|
26 |
+
2 * channels,
|
27 |
+
kernel_size=1,
|
28 |
+
stride=1,
|
29 |
+
padding=0,
|
30 |
+
bias=bias,
|
31 |
+
)
|
32 |
+
# self.lorder is used to distinguish if it's a causal convolution,
|
33 |
+
# if self.lorder > 0: it's a causal convolution, the input will be
|
34 |
+
# padded with self.lorder frames on the left in forward.
|
35 |
+
# else: it's a symmetrical convolution
|
36 |
+
if causal:
|
37 |
+
padding = 0
|
38 |
+
self.lorder = kernel_size - 1
|
39 |
+
else:
|
40 |
+
# kernel_size should be an odd number for non-causal convolution
|
41 |
+
assert (kernel_size - 1) % 2 == 0
|
42 |
+
padding = (kernel_size - 1) // 2
|
43 |
+
self.lorder = 0
|
44 |
+
self.depthwise_conv = nn.Conv1d(
|
45 |
+
channels,
|
46 |
+
channels,
|
47 |
+
kernel_size,
|
48 |
+
stride=1,
|
49 |
+
padding=padding,
|
50 |
+
groups=channels,
|
51 |
+
bias=bias,
|
52 |
+
)
|
53 |
+
|
54 |
+
assert norm in ['batch_norm', 'layer_norm']
|
55 |
+
if norm == "batch_norm":
|
56 |
+
self.use_layer_norm = False
|
57 |
+
self.norm = nn.BatchNorm1d(channels)
|
58 |
+
else:
|
59 |
+
self.use_layer_norm = True
|
60 |
+
self.norm = nn.LayerNorm(channels)
|
61 |
+
|
62 |
+
self.pointwise_conv2 = nn.Conv1d(
|
63 |
+
channels,
|
64 |
+
channels,
|
65 |
+
kernel_size=1,
|
66 |
+
stride=1,
|
67 |
+
padding=0,
|
68 |
+
bias=bias,
|
69 |
+
)
|
70 |
+
self.activation = activation
|
71 |
+
|
72 |
+
def forward(
|
73 |
+
self,
|
74 |
+
x: torch.Tensor,
|
75 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
76 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0)),
|
77 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
78 |
+
"""Compute convolution module.
|
79 |
+
Args:
|
80 |
+
x (torch.Tensor): Input tensor (#batch, time, channels).
|
81 |
+
mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
|
82 |
+
(0, 0, 0) means fake mask.
|
83 |
+
cache (torch.Tensor): left context cache, it is only
|
84 |
+
used in causal convolution (#batch, channels, cache_t),
|
85 |
+
(0, 0, 0) means fake cache.
|
86 |
+
Returns:
|
87 |
+
torch.Tensor: Output tensor (#batch, time, channels).
|
88 |
+
"""
|
89 |
+
# exchange the temporal dimension and the feature dimension
|
90 |
+
x = x.transpose(1, 2) # (#batch, channels, time)
|
91 |
+
|
92 |
+
# mask batch padding
|
93 |
+
if mask_pad.size(2) > 0: # time > 0
|
94 |
+
x.masked_fill_(~mask_pad, 0.0)
|
95 |
+
|
96 |
+
if self.lorder > 0:
|
97 |
+
if cache.size(2) == 0: # cache_t == 0
|
98 |
+
x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
|
99 |
+
else:
|
100 |
+
assert cache.size(0) == x.size(0) # equal batch
|
101 |
+
assert cache.size(1) == x.size(1) # equal channel
|
102 |
+
x = torch.cat((cache, x), dim=2)
|
103 |
+
assert (x.size(2) > self.lorder)
|
104 |
+
new_cache = x[:, :, -self.lorder:]
|
105 |
+
else:
|
106 |
+
# It's better we just return None if no cache is required,
|
107 |
+
# However, for JIT export, here we just fake one tensor instead of
|
108 |
+
# None.
|
109 |
+
new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
110 |
+
|
111 |
+
# GLU mechanism
|
112 |
+
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
|
113 |
+
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
|
114 |
+
|
115 |
+
# 1D Depthwise Conv
|
116 |
+
x = self.depthwise_conv(x)
|
117 |
+
if self.use_layer_norm:
|
118 |
+
x = x.transpose(1, 2)
|
119 |
+
x = self.activation(self.norm(x))
|
120 |
+
if self.use_layer_norm:
|
121 |
+
x = x.transpose(1, 2)
|
122 |
+
x = self.pointwise_conv2(x)
|
123 |
+
# mask batch padding
|
124 |
+
if mask_pad.size(2) > 0: # time > 0
|
125 |
+
x.masked_fill_(~mask_pad, 0.0)
|
126 |
+
|
127 |
+
return x.transpose(1, 2), new_cache
|
128 |
+
|
129 |
+
class PositionwiseFeedForward(torch.nn.Module):
|
130 |
+
"""Positionwise feed forward layer.
|
131 |
+
|
132 |
+
FeedForward is applied at each position of the sequence.
|
133 |
+
The output dim is same with the input dim.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
idim (int): Input dimension.
|
137 |
+
hidden_units (int): The number of hidden units.
|
138 |
+
dropout_rate (float): Dropout rate.
|
139 |
+
activation (torch.nn.Module): Activation function
|
140 |
+
"""
|
141 |
+
|
142 |
+
def __init__(
|
143 |
+
self,
|
144 |
+
idim: int,
|
145 |
+
hidden_units: int,
|
146 |
+
dropout_rate: float,
|
147 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
148 |
+
):
|
149 |
+
"""Construct a PositionwiseFeedForward object."""
|
150 |
+
super(PositionwiseFeedForward, self).__init__()
|
151 |
+
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
152 |
+
self.activation = activation
|
153 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
154 |
+
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
155 |
+
|
156 |
+
def forward(self, xs: torch.Tensor) -> torch.Tensor:
|
157 |
+
"""Forward function.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
xs: input tensor (B, L, D)
|
161 |
+
Returns:
|
162 |
+
output tensor, (B, L, D)
|
163 |
+
"""
|
164 |
+
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
|
165 |
+
|
166 |
+
class Swish(torch.nn.Module):
|
167 |
+
"""Construct an Swish object."""
|
168 |
+
|
169 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
170 |
+
"""Return Swish activation function."""
|
171 |
+
return x * torch.sigmoid(x)
|
172 |
+
|
173 |
+
class MultiHeadedAttention(nn.Module):
|
174 |
+
"""Multi-Head Attention layer.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
n_head (int): The number of heads.
|
178 |
+
n_feat (int): The number of features.
|
179 |
+
dropout_rate (float): Dropout rate.
|
180 |
+
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self,
|
184 |
+
n_head: int,
|
185 |
+
n_feat: int,
|
186 |
+
dropout_rate: float,
|
187 |
+
key_bias: bool = True):
|
188 |
+
"""Construct an MultiHeadedAttention object."""
|
189 |
+
super().__init__()
|
190 |
+
assert n_feat % n_head == 0
|
191 |
+
# We assume d_v always equals d_k
|
192 |
+
self.d_k = n_feat // n_head
|
193 |
+
self.h = n_head
|
194 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
195 |
+
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
|
196 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
197 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
198 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
199 |
+
|
200 |
+
def forward_qkv(
|
201 |
+
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
202 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
203 |
+
"""Transform query, key and value.
|
204 |
+
|
205 |
+
Args:
|
206 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
207 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
208 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
209 |
+
|
210 |
+
Returns:
|
211 |
+
torch.Tensor: Transformed query tensor, size
|
212 |
+
(#batch, n_head, time1, d_k).
|
213 |
+
torch.Tensor: Transformed key tensor, size
|
214 |
+
(#batch, n_head, time2, d_k).
|
215 |
+
torch.Tensor: Transformed value tensor, size
|
216 |
+
(#batch, n_head, time2, d_k).
|
217 |
+
|
218 |
+
"""
|
219 |
+
n_batch = query.size(0)
|
220 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
221 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
222 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
223 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
224 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
225 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
226 |
+
return q, k, v
|
227 |
+
|
228 |
+
def forward_attention(
|
229 |
+
self,
|
230 |
+
value: torch.Tensor,
|
231 |
+
scores: torch.Tensor,
|
232 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
|
233 |
+
) -> torch.Tensor:
|
234 |
+
"""Compute attention context vector.
|
235 |
+
|
236 |
+
Args:
|
237 |
+
value (torch.Tensor): Transformed value, size
|
238 |
+
(#batch, n_head, time2, d_k).
|
239 |
+
scores (torch.Tensor): Attention score, size
|
240 |
+
(#batch, n_head, time1, time2).
|
241 |
+
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
|
242 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
243 |
+
|
244 |
+
Returns:
|
245 |
+
torch.Tensor: Transformed value (#batch, time1, d_model)
|
246 |
+
weighted by the attention score (#batch, time1, time2).
|
247 |
+
|
248 |
+
"""
|
249 |
+
n_batch = value.size(0)
|
250 |
+
|
251 |
+
if mask.size(2) > 0: # time2 > 0
|
252 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
253 |
+
# For last chunk, time2 might be larger than scores.size(-1)
|
254 |
+
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
|
255 |
+
scores = scores.masked_fill(mask, -float('inf'))
|
256 |
+
attn = torch.softmax(scores, dim=-1).masked_fill(
|
257 |
+
mask, 0.0) # (batch, head, time1, time2)
|
258 |
+
|
259 |
+
else:
|
260 |
+
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
261 |
+
|
262 |
+
p_attn = self.dropout(attn)
|
263 |
+
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
264 |
+
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
|
265 |
+
self.h * self.d_k)
|
266 |
+
) # (batch, time1, d_model)
|
267 |
+
|
268 |
+
return self.linear_out(x) # (batch, time1, d_model)
|
269 |
+
|
270 |
+
def forward(
|
271 |
+
self,
|
272 |
+
query: torch.Tensor,
|
273 |
+
key: torch.Tensor,
|
274 |
+
value: torch.Tensor,
|
275 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
276 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
277 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
278 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
279 |
+
"""Compute scaled dot product attention.
|
280 |
+
|
281 |
+
Args:
|
282 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
283 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
284 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
285 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
286 |
+
(#batch, time1, time2).
|
287 |
+
1.When applying cross attention between decoder and encoder,
|
288 |
+
the batch padding mask for input is in (#batch, 1, T) shape.
|
289 |
+
2.When applying self attention of encoder,
|
290 |
+
the mask is in (#batch, T, T) shape.
|
291 |
+
3.When applying self attention of decoder,
|
292 |
+
the mask is in (#batch, L, L) shape.
|
293 |
+
4.If the different position in decoder see different block
|
294 |
+
of the encoder, such as Mocha, the passed in mask could be
|
295 |
+
in (#batch, L, T) shape. But there is no such case in current
|
296 |
+
CosyVoice.
|
297 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
298 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
299 |
+
and `head * d_k == size`
|
300 |
+
|
301 |
+
|
302 |
+
Returns:
|
303 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
304 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
305 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
306 |
+
and `head * d_k == size`
|
307 |
+
|
308 |
+
"""
|
309 |
+
q, k, v = self.forward_qkv(query, key, value)
|
310 |
+
if cache.size(0) > 0:
|
311 |
+
key_cache, value_cache = torch.split(cache,
|
312 |
+
cache.size(-1) // 2,
|
313 |
+
dim=-1)
|
314 |
+
k = torch.cat([key_cache, k], dim=2)
|
315 |
+
v = torch.cat([value_cache, v], dim=2)
|
316 |
+
new_cache = torch.cat((k, v), dim=-1)
|
317 |
+
|
318 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
319 |
+
return self.forward_attention(v, scores, mask), new_cache
|
320 |
+
|
321 |
+
|
322 |
+
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
323 |
+
"""Multi-Head Attention layer with relative position encoding.
|
324 |
+
Paper: https://arxiv.org/abs/1901.02860
|
325 |
+
Args:
|
326 |
+
n_head (int): The number of heads.
|
327 |
+
n_feat (int): The number of features.
|
328 |
+
dropout_rate (float): Dropout rate.
|
329 |
+
"""
|
330 |
+
|
331 |
+
def __init__(self,
|
332 |
+
n_head: int,
|
333 |
+
n_feat: int,
|
334 |
+
dropout_rate: float,
|
335 |
+
key_bias: bool = True):
|
336 |
+
"""Construct an RelPositionMultiHeadedAttention object."""
|
337 |
+
super().__init__(n_head, n_feat, dropout_rate, key_bias)
|
338 |
+
# linear transformation for positional encoding
|
339 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
340 |
+
# these two learnable bias are used in matrix c and matrix d
|
341 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
342 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
343 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
344 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
345 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
346 |
+
|
347 |
+
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
|
348 |
+
"""Compute relative positional encoding.
|
349 |
+
|
350 |
+
Args:
|
351 |
+
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
|
352 |
+
time1 means the length of query vector.
|
353 |
+
|
354 |
+
Returns:
|
355 |
+
torch.Tensor: Output tensor.
|
356 |
+
|
357 |
+
"""
|
358 |
+
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
|
359 |
+
device=x.device,
|
360 |
+
dtype=x.dtype)
|
361 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
362 |
+
|
363 |
+
x_padded = x_padded.view(x.size()[0],
|
364 |
+
x.size()[1],
|
365 |
+
x.size(3) + 1, x.size(2))
|
366 |
+
x = x_padded[:, :, 1:].view_as(x)[
|
367 |
+
:, :, :, : x.size(-1) // 2 + 1
|
368 |
+
] # only keep the positions from 0 to time2
|
369 |
+
return x
|
370 |
+
|
371 |
+
def forward(
|
372 |
+
self,
|
373 |
+
query: torch.Tensor,
|
374 |
+
key: torch.Tensor,
|
375 |
+
value: torch.Tensor,
|
376 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
377 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
378 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
379 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
380 |
+
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
381 |
+
Args:
|
382 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
383 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
384 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
385 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
386 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
387 |
+
pos_emb (torch.Tensor): Positional embedding tensor
|
388 |
+
(#batch, time2, size).
|
389 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
390 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
391 |
+
and `head * d_k == size`
|
392 |
+
Returns:
|
393 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
394 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
395 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
396 |
+
and `head * d_k == size`
|
397 |
+
"""
|
398 |
+
q, k, v = self.forward_qkv(query, key, value)
|
399 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
400 |
+
|
401 |
+
if cache.size(0) > 0:
|
402 |
+
key_cache, value_cache = torch.split(cache,
|
403 |
+
cache.size(-1) // 2,
|
404 |
+
dim=-1)
|
405 |
+
k = torch.cat([key_cache, k], dim=2)
|
406 |
+
v = torch.cat([value_cache, v], dim=2)
|
407 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
408 |
+
# non-trivial to calculate `next_cache_start` here.
|
409 |
+
new_cache = torch.cat((k, v), dim=-1)
|
410 |
+
|
411 |
+
n_batch_pos = pos_emb.size(0)
|
412 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
413 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
414 |
+
|
415 |
+
# (batch, head, time1, d_k)
|
416 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
417 |
+
# (batch, head, time1, d_k)
|
418 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
419 |
+
|
420 |
+
# compute attention score
|
421 |
+
# first compute matrix a and matrix c
|
422 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
423 |
+
# (batch, head, time1, time2)
|
424 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
425 |
+
|
426 |
+
# compute matrix b and matrix d
|
427 |
+
# (batch, head, time1, time2)
|
428 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
429 |
+
# NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
|
430 |
+
if matrix_ac.shape != matrix_bd.shape:
|
431 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
432 |
+
|
433 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
434 |
+
self.d_k) # (batch, head, time1, time2)
|
435 |
+
|
436 |
+
return self.forward_attention(v, scores, mask), new_cache
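For orientation, a small self-attention sketch using the base MultiHeadedAttention defined above; the rel-position variant adds pos_emb and the two learned biases but shares the same call shape. All sizes here are illustrative only.

import torch

mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
x = torch.randn(2, 10, 256)                    # (batch, time, feat)
mask = torch.ones(2, 1, 10, dtype=torch.bool)  # no padding in this toy batch
out, new_cache = mha(x, x, x, mask)
# out: (2, 10, 256); new_cache stacks K and V on the last dim: (2, 4, 10, 128)
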

def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in the decoder, which works in an auto-regressive mode.
    This means the current step can only attend to the steps on its left.

    In the encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in the encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    arange = torch.arange(size, device=device)
    mask = arange.expand(size, size)
    arange = arange.unsqueeze(-1)
    mask = mask <= arange
    return mask


def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size;
    this is for the streaming encoder.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret

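A quick check of the two helpers above, mirroring the docstring examples.

causal = subsequent_mask(3)            # lower-triangular (3, 3) boolean mask
chunked = subsequent_chunk_mask(4, 2)  # 2-frame chunks, full left context
# causal.int()  -> [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
# chunked.int() -> [[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]
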
def add_optional_chunk_mask(xs: torch.Tensor,
|
514 |
+
masks: torch.Tensor,
|
515 |
+
use_dynamic_chunk: bool,
|
516 |
+
use_dynamic_left_chunk: bool,
|
517 |
+
decoding_chunk_size: int,
|
518 |
+
static_chunk_size: int,
|
519 |
+
num_decoding_left_chunks: int,
|
520 |
+
enable_full_context: bool = True):
|
521 |
+
""" Apply optional mask for encoder.
|
522 |
+
|
523 |
+
Args:
|
524 |
+
xs (torch.Tensor): padded input, (B, L, D), L for max length
|
525 |
+
mask (torch.Tensor): mask for xs, (B, 1, L)
|
526 |
+
use_dynamic_chunk (bool): whether to use dynamic chunk or not
|
527 |
+
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
|
528 |
+
training.
|
529 |
+
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
|
530 |
+
0: default for training, use random dynamic chunk.
|
531 |
+
<0: for decoding, use full chunk.
|
532 |
+
>0: for decoding, use fixed chunk size as set.
|
533 |
+
static_chunk_size (int): chunk size for static chunk training/decoding
|
534 |
+
if it's greater than 0, if use_dynamic_chunk is true,
|
535 |
+
this parameter will be ignored
|
536 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
537 |
+
the chunk size is decoding_chunk_size.
|
538 |
+
>=0: use num_decoding_left_chunks
|
539 |
+
<0: use all left chunks
|
540 |
+
enable_full_context (bool):
|
541 |
+
True: chunk size is either [1, 25] or full context(max_len)
|
542 |
+
False: chunk size ~ U[1, 25]
|
543 |
+
|
544 |
+
Returns:
|
545 |
+
torch.Tensor: chunk mask of the input xs.
|
546 |
+
"""
|
547 |
+
# Whether to use chunk mask or not
|
548 |
+
if use_dynamic_chunk:
|
549 |
+
max_len = xs.size(1)
|
550 |
+
if decoding_chunk_size < 0:
|
551 |
+
chunk_size = max_len
|
552 |
+
num_left_chunks = -1
|
553 |
+
elif decoding_chunk_size > 0:
|
554 |
+
chunk_size = decoding_chunk_size
|
555 |
+
num_left_chunks = num_decoding_left_chunks
|
556 |
+
else:
|
557 |
+
# chunk size is either [1, 25] or full context(max_len).
|
558 |
+
# Since we use 4 times subsampling and allow up to 1s(100 frames)
|
559 |
+
# delay, the maximum frame is 100 / 4 = 25.
|
560 |
+
chunk_size = torch.randint(1, max_len, (1, )).item()
|
561 |
+
num_left_chunks = -1
|
562 |
+
if chunk_size > max_len // 2 and enable_full_context:
|
563 |
+
chunk_size = max_len
|
564 |
+
else:
|
565 |
+
chunk_size = chunk_size % 25 + 1
|
566 |
+
if use_dynamic_left_chunk:
|
567 |
+
max_left_chunks = (max_len - 1) // chunk_size
|
568 |
+
num_left_chunks = torch.randint(0, max_left_chunks,
|
569 |
+
(1, )).item()
|
570 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
|
571 |
+
num_left_chunks,
|
572 |
+
xs.device) # (L, L)
|
573 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
574 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
575 |
+
elif static_chunk_size > 0:
|
576 |
+
num_left_chunks = num_decoding_left_chunks
|
577 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
|
578 |
+
num_left_chunks,
|
579 |
+
xs.device) # (L, L)
|
580 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
581 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
582 |
+
else:
|
583 |
+
chunk_masks = masks
|
584 |
+
return chunk_masks
|
585 |
+
|
586 |
+
|
587 |
+
class ConformerEncoderLayer(nn.Module):
|
588 |
+
"""Encoder layer module.
|
589 |
+
Args:
|
590 |
+
size (int): Input dimension.
|
591 |
+
self_attn (torch.nn.Module): Self-attention module instance.
|
592 |
+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
|
593 |
+
instance can be used as the argument.
|
594 |
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
595 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
596 |
+
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
|
597 |
+
instance.
|
598 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
599 |
+
conv_module (torch.nn.Module): Convolution module instance.
|
600 |
+
`ConvolutionModule` instance can be used as the argument.
|
601 |
+
dropout_rate (float): Dropout rate.
|
602 |
+
normalize_before (bool):
|
603 |
+
True: use layer_norm before each sub-block.
|
604 |
+
False: use layer_norm after each sub-block.
|
605 |
+
"""
|
606 |
+
|
607 |
+
def __init__(
|
608 |
+
self,
|
609 |
+
size: int,
|
610 |
+
self_attn: torch.nn.Module,
|
611 |
+
feed_forward: Optional[nn.Module] = None,
|
612 |
+
feed_forward_macaron: Optional[nn.Module] = None,
|
613 |
+
conv_module: Optional[nn.Module] = None,
|
614 |
+
dropout_rate: float = 0.1,
|
615 |
+
normalize_before: bool = True,
|
616 |
+
):
|
617 |
+
"""Construct an EncoderLayer object."""
|
618 |
+
super().__init__()
|
619 |
+
self.self_attn = self_attn
|
620 |
+
self.feed_forward = feed_forward
|
621 |
+
self.feed_forward_macaron = feed_forward_macaron
|
622 |
+
self.conv_module = conv_module
|
623 |
+
self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
|
624 |
+
self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
|
625 |
+
if feed_forward_macaron is not None:
|
626 |
+
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
|
627 |
+
self.ff_scale = 0.5
|
628 |
+
else:
|
629 |
+
self.ff_scale = 1.0
|
630 |
+
if self.conv_module is not None:
|
631 |
+
self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
|
632 |
+
self.norm_final = nn.LayerNorm(
|
633 |
+
size, eps=1e-5) # for the final output of the block
|
634 |
+
self.dropout = nn.Dropout(dropout_rate)
|
635 |
+
self.size = size
|
636 |
+
self.normalize_before = normalize_before
|
637 |
+
|
638 |
+
def forward(
|
639 |
+
self,
|
640 |
+
x: torch.Tensor,
|
641 |
+
mask: torch.Tensor,
|
642 |
+
pos_emb: torch.Tensor,
|
643 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
644 |
+
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
645 |
+
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
646 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
647 |
+
"""Compute encoded features.
|
648 |
+
|
649 |
+
Args:
|
650 |
+
x (torch.Tensor): (#batch, time, size)
|
651 |
+
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
|
652 |
+
(0, 0, 0) means fake mask.
|
653 |
+
pos_emb (torch.Tensor): positional encoding, must not be None
|
654 |
+
for ConformerEncoderLayer.
|
655 |
+
mask_pad (torch.Tensor): batch padding mask used for conv module.
|
656 |
+
(#batch, 1,time), (0, 0, 0) means fake mask.
|
657 |
+
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
|
658 |
+
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
|
659 |
+
cnn_cache (torch.Tensor): Convolution cache in conformer layer
|
660 |
+
(#batch=1, size, cache_t2)
|
661 |
+
Returns:
|
662 |
+
torch.Tensor: Output tensor (#batch, time, size).
|
663 |
+
torch.Tensor: Mask tensor (#batch, time, time).
|
664 |
+
torch.Tensor: att_cache tensor,
|
665 |
+
(#batch=1, head, cache_t1 + time, d_k * 2).
|
666 |
+
torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
|
667 |
+
"""
|
668 |
+
|
669 |
+
# whether to use macaron style
|
670 |
+
if self.feed_forward_macaron is not None:
|
671 |
+
residual = x
|
672 |
+
if self.normalize_before:
|
673 |
+
x = self.norm_ff_macaron(x)
|
674 |
+
x = residual + self.ff_scale * self.dropout(
|
675 |
+
self.feed_forward_macaron(x))
|
676 |
+
if not self.normalize_before:
|
677 |
+
x = self.norm_ff_macaron(x)
|
678 |
+
|
679 |
+
# multi-headed self-attention module
|
680 |
+
residual = x
|
681 |
+
if self.normalize_before:
|
682 |
+
x = self.norm_mha(x)
|
683 |
+
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
|
684 |
+
att_cache)
|
685 |
+
x = residual + self.dropout(x_att)
|
686 |
+
if not self.normalize_before:
|
687 |
+
x = self.norm_mha(x)
|
688 |
+
|
689 |
+
# convolution module
|
690 |
+
# Fake new cnn cache here, and then change it in conv_module
|
691 |
+
new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
692 |
+
if self.conv_module is not None:
|
693 |
+
residual = x
|
694 |
+
if self.normalize_before:
|
695 |
+
x = self.norm_conv(x)
|
696 |
+
x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
|
697 |
+
x = residual + self.dropout(x)
|
698 |
+
|
699 |
+
if not self.normalize_before:
|
700 |
+
x = self.norm_conv(x)
|
701 |
+
|
702 |
+
# feed forward module
|
703 |
+
residual = x
|
704 |
+
if self.normalize_before:
|
705 |
+
x = self.norm_ff(x)
|
706 |
+
|
707 |
+
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
|
708 |
+
if not self.normalize_before:
|
709 |
+
x = self.norm_ff(x)
|
710 |
+
|
711 |
+
if self.conv_module is not None:
|
712 |
+
x = self.norm_final(x)
|
713 |
+
|
714 |
+
return x, mask, new_att_cache, new_cnn_cache
|
715 |
+
|
716 |
+
|
717 |
+
|
718 |
+
class EspnetRelPositionalEncoding(torch.nn.Module):
|
719 |
+
"""Relative positional encoding module (new implementation).
|
720 |
+
|
721 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
722 |
+
|
723 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
724 |
+
|
725 |
+
Args:
|
726 |
+
d_model (int): Embedding dimension.
|
727 |
+
dropout_rate (float): Dropout rate.
|
728 |
+
max_len (int): Maximum input length.
|
729 |
+
|
730 |
+
"""
|
731 |
+
|
732 |
+
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
|
733 |
+
"""Construct an PositionalEncoding object."""
|
734 |
+
super(EspnetRelPositionalEncoding, self).__init__()
|
735 |
+
self.d_model = d_model
|
736 |
+
self.xscale = math.sqrt(self.d_model)
|
737 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
738 |
+
self.pe = None
|
739 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
740 |
+
|
741 |
+
def extend_pe(self, x: torch.Tensor):
|
742 |
+
"""Reset the positional encodings."""
|
743 |
+
if self.pe is not None:
|
744 |
+
# self.pe contains both positive and negative parts
|
745 |
+
# the length of self.pe is 2 * input_len - 1
|
746 |
+
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
747 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
748 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
749 |
+
return
|
750 |
+
# Suppose `i` means the position of the query vector and `j` means the
|
751 |
+
# position of the key vector. We use positive relative positions when keys
|
752 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
753 |
+
pe_positive = torch.zeros(x.size(1), self.d_model)
|
754 |
+
pe_negative = torch.zeros(x.size(1), self.d_model)
|
755 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
756 |
+
div_term = torch.exp(
|
757 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
758 |
+
* -(math.log(10000.0) / self.d_model)
|
759 |
+
)
|
760 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
761 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
762 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
763 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
764 |
+
|
765 |
+
# Reverse the order of positive indices and concat both positive and
|
766 |
+
# negative indices. This is used to support the shifting trick
|
767 |
+
# as in https://arxiv.org/abs/1901.02860
|
768 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
769 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
770 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
771 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
772 |
+
|
773 |
+
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
|
774 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
775 |
+
"""Add positional encoding.
|
776 |
+
|
777 |
+
Args:
|
778 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
779 |
+
|
780 |
+
Returns:
|
781 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
782 |
+
|
783 |
+
"""
|
784 |
+
self.extend_pe(x)
|
785 |
+
x = x * self.xscale
|
786 |
+
pos_emb = self.position_encoding(size=x.size(1), offset=offset)
|
787 |
+
return self.dropout(x), self.dropout(pos_emb)
|
788 |
+
|
789 |
+
def position_encoding(self,
|
790 |
+
offset: Union[int, torch.Tensor],
|
791 |
+
size: int) -> torch.Tensor:
|
792 |
+
""" For getting encoding in a streaming fashion
|
793 |
+
|
794 |
+
Attention!!!!!
|
795 |
+
we apply dropout only once at the whole utterance level in a non-
|
796 |
+
streaming way, but will call this function several times with
|
797 |
+
increasing input size in a streaming scenario, so the dropout will
|
798 |
+
be applied several times.
|
799 |
+
|
800 |
+
Args:
|
801 |
+
offset (int or torch.tensor): start offset
|
802 |
+
size (int): required size of position encoding
|
803 |
+
|
804 |
+
Returns:
|
805 |
+
torch.Tensor: Corresponding encoding
|
806 |
+
"""
|
807 |
+
pos_emb = self.pe[
|
808 |
+
:,
|
809 |
+
self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
|
810 |
+
]
|
811 |
+
return pos_emb
|
812 |
+
|
813 |
+
|
814 |
+
|
815 |
+
class LinearEmbed(torch.nn.Module):
|
816 |
+
"""Linear transform the input without subsampling
|
817 |
+
|
818 |
+
Args:
|
819 |
+
idim (int): Input dimension.
|
820 |
+
odim (int): Output dimension.
|
821 |
+
dropout_rate (float): Dropout rate.
|
822 |
+
|
823 |
+
"""
|
824 |
+
|
825 |
+
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
826 |
+
pos_enc_class: torch.nn.Module):
|
827 |
+
"""Construct an linear object."""
|
828 |
+
super().__init__()
|
829 |
+
self.out = torch.nn.Sequential(
|
830 |
+
torch.nn.Linear(idim, odim),
|
831 |
+
torch.nn.LayerNorm(odim, eps=1e-5),
|
832 |
+
torch.nn.Dropout(dropout_rate),
|
833 |
+
)
|
834 |
+
self.pos_enc = pos_enc_class #rel_pos_espnet
|
835 |
+
|
836 |
+
def position_encoding(self, offset: Union[int, torch.Tensor],
|
837 |
+
size: int) -> torch.Tensor:
|
838 |
+
return self.pos_enc.position_encoding(offset, size)
|
839 |
+
|
840 |
+
def forward(
|
841 |
+
self,
|
842 |
+
x: torch.Tensor,
|
843 |
+
offset: Union[int, torch.Tensor] = 0
|
844 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
845 |
+
"""Input x.
|
846 |
+
|
847 |
+
Args:
|
848 |
+
x (torch.Tensor): Input tensor (#batch, time, idim).
|
849 |
+
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
850 |
+
|
851 |
+
Returns:
|
852 |
+
torch.Tensor: linear input tensor (#batch, time', odim),
|
853 |
+
where time' = time .
|
854 |
+
torch.Tensor: linear input mask (#batch, 1, time'),
|
855 |
+
where time' = time .
|
856 |
+
|
857 |
+
"""
|
858 |
+
x = self.out(x)
|
859 |
+
x, pos_emb = self.pos_enc(x, offset)
|
860 |
+
return x, pos_emb
|
861 |
+
|
862 |
+
|
863 |
+
ATTENTION_CLASSES = {
    "selfattn": MultiHeadedAttention,
    "rel_selfattn": RelPositionMultiHeadedAttention,
}

ACTIVATION_CLASSES = {
    "hardtanh": torch.nn.Hardtanh,
    "tanh": torch.nn.Tanh,
    "relu": torch.nn.ReLU,
    "selu": torch.nn.SELU,
    "swish": getattr(torch.nn, "SiLU", Swish),
    "gelu": torch.nn.GELU,
}


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask

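A small example of make_pad_mask; the last line derives the (B, T) valid-frame mask which, judging from the unsqueeze(1) usage, appears to be what ConformerEncoder.forward below expects as pad_mask (an assumption, not stated in the commit).

import torch

lengths = torch.tensor([5, 3, 2])
pad = make_pad_mask(lengths)
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True],
#         [False, False,  True,  True,  True]])
valid = ~pad   # True marks real frames
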
#https://github.com/FunAudioLLM/CosyVoice/blob/main/examples/magicdata-read/cosyvoice/conf/cosyvoice.yaml
|
907 |
+
class ConformerEncoder(torch.nn.Module):
|
908 |
+
"""Conformer encoder module."""
|
909 |
+
|
910 |
+
def __init__(
|
911 |
+
self,
|
912 |
+
input_size: int,
|
913 |
+
output_size: int = 1024,
|
914 |
+
attention_heads: int = 16,
|
915 |
+
linear_units: int = 4096,
|
916 |
+
num_blocks: int = 6,
|
917 |
+
dropout_rate: float = 0.1,
|
918 |
+
positional_dropout_rate: float = 0.1,
|
919 |
+
attention_dropout_rate: float = 0.0,
|
920 |
+
input_layer: str = 'linear',
|
921 |
+
pos_enc_layer_type: str = 'rel_pos_espnet',
|
922 |
+
normalize_before: bool = True,
|
923 |
+
static_chunk_size: int = 1, # 1: causal_mask; 0: full_mask
|
924 |
+
use_dynamic_chunk: bool = False,
|
925 |
+
use_dynamic_left_chunk: bool = False,
|
926 |
+
positionwise_conv_kernel_size: int = 1,
|
927 |
+
macaron_style: bool =False,
|
928 |
+
selfattention_layer_type: str = "rel_selfattn",
|
929 |
+
activation_type: str = "swish",
|
930 |
+
use_cnn_module: bool = False,
|
931 |
+
cnn_module_kernel: int = 15,
|
932 |
+
causal: bool = False,
|
933 |
+
cnn_module_norm: str = "batch_norm",
|
934 |
+
key_bias: bool = True,
|
935 |
+
gradient_checkpointing: bool = False,
|
936 |
+
):
|
937 |
+
"""Construct ConformerEncoder
|
938 |
+
|
939 |
+
Args:
|
940 |
+
input_size to use_dynamic_chunk, see in BaseEncoder
|
941 |
+
positionwise_conv_kernel_size (int): Kernel size of positionwise
|
942 |
+
conv1d layer.
|
943 |
+
macaron_style (bool): Whether to use macaron style for
|
944 |
+
positionwise layer.
|
945 |
+
selfattention_layer_type (str): Encoder attention layer type,
|
946 |
+
the parameter has no effect now, it's just for configuration
|
947 |
+
compatibility. #'rel_selfattn'
|
948 |
+
activation_type (str): Encoder activation function type.
|
949 |
+
use_cnn_module (bool): Whether to use convolution module.
|
950 |
+
cnn_module_kernel (int): Kernel size of convolution module.
|
951 |
+
causal (bool): whether to use causal convolution or not.
|
952 |
+
key_bias: whether to use bias in attention.linear_k, False for whisper models.
|
953 |
+
"""
|
954 |
+
super().__init__()
|
955 |
+
self.output_size = output_size
|
956 |
+
self.embed = LinearEmbed(input_size, output_size, dropout_rate,
|
957 |
+
EspnetRelPositionalEncoding(output_size, positional_dropout_rate))
|
958 |
+
self.normalize_before = normalize_before
|
959 |
+
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
|
960 |
+
self.gradient_checkpointing = gradient_checkpointing
|
961 |
+
self.use_dynamic_chunk = use_dynamic_chunk
|
962 |
+
|
963 |
+
self.static_chunk_size = static_chunk_size
|
964 |
+
self.use_dynamic_chunk = use_dynamic_chunk
|
965 |
+
self.use_dynamic_left_chunk = use_dynamic_left_chunk
|
966 |
+
activation = ACTIVATION_CLASSES[activation_type]()
|
967 |
+
|
968 |
+
# self-attention module definition
|
969 |
+
encoder_selfattn_layer_args = (
|
970 |
+
attention_heads,
|
971 |
+
output_size,
|
972 |
+
attention_dropout_rate,
|
973 |
+
key_bias,
|
974 |
+
)
|
975 |
+
# feed-forward module definition
|
976 |
+
positionwise_layer_args = (
|
977 |
+
output_size,
|
978 |
+
linear_units,
|
979 |
+
dropout_rate,
|
980 |
+
activation,
|
981 |
+
)
|
982 |
+
# convolution module definition
|
983 |
+
convolution_layer_args = (output_size, cnn_module_kernel, activation,
|
984 |
+
cnn_module_norm, causal)
|
985 |
+
|
986 |
+
self.encoders = torch.nn.ModuleList([
|
987 |
+
ConformerEncoderLayer(
|
988 |
+
output_size,
|
989 |
+
RelPositionMultiHeadedAttention(
|
990 |
+
*encoder_selfattn_layer_args),
|
991 |
+
PositionwiseFeedForward(*positionwise_layer_args),
|
992 |
+
PositionwiseFeedForward(
|
993 |
+
*positionwise_layer_args) if macaron_style else None,
|
994 |
+
ConvolutionModule(
|
995 |
+
*convolution_layer_args) if use_cnn_module else None,
|
996 |
+
dropout_rate,
|
997 |
+
normalize_before,
|
998 |
+
) for _ in range(num_blocks)
|
999 |
+
])
|
1000 |
+
|
1001 |
+
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
1002 |
+
pos_emb: torch.Tensor,
|
1003 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
1004 |
+
for layer in self.encoders:
|
1005 |
+
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
1006 |
+
return xs
|
1007 |
+
|
1008 |
+
@torch.jit.unused
|
1009 |
+
def forward_layers_checkpointed(self, xs: torch.Tensor,
|
1010 |
+
chunk_masks: torch.Tensor,
|
1011 |
+
pos_emb: torch.Tensor,
|
1012 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
1013 |
+
for layer in self.encoders:
|
1014 |
+
xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
|
1015 |
+
chunk_masks, pos_emb,
|
1016 |
+
mask_pad)
|
1017 |
+
return xs
|
1018 |
+
|
1019 |
+
def forward(
|
1020 |
+
self,
|
1021 |
+
xs: torch.Tensor,
|
1022 |
+
pad_mask: torch.Tensor,
|
1023 |
+
decoding_chunk_size: int = 0,
|
1024 |
+
num_decoding_left_chunks: int = -1,
|
1025 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
1026 |
+
"""Embed positions in tensor.
|
1027 |
+
|
1028 |
+
Args:
|
1029 |
+
xs: padded input tensor (B, T, D)
|
1030 |
+
xs_lens: input length (B)
|
1031 |
+
decoding_chunk_size: decoding chunk size for dynamic chunk
|
1032 |
+
0: default for training, use random dynamic chunk.
|
1033 |
+
<0: for decoding, use full chunk.
|
1034 |
+
>0: for decoding, use fixed chunk size as set.
|
1035 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
1036 |
+
the chunk size is decoding_chunk_size.
|
1037 |
+
>=0: use num_decoding_left_chunks
|
1038 |
+
<0: use all left chunks
|
1039 |
+
Returns:
|
1040 |
+
encoder output tensor xs, and subsampled masks
|
1041 |
+
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
|
1042 |
+
masks: torch.Tensor batch padding mask after subsample
|
1043 |
+
(B, 1, T' ~= T/subsample_rate)
|
1044 |
+
NOTE(xcsong):
|
1045 |
+
We pass the `__call__` method of the modules instead of `forward` to the
|
1046 |
+
checkpointing API because `__call__` attaches all the hooks of the module.
|
1047 |
+
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
1048 |
+
"""
|
1049 |
+
T = xs.size(1)
|
1050 |
+
masks = pad_mask.to(torch.bool).unsqueeze(1) # (B, 1, T)
|
1051 |
+
xs, pos_emb = self.embed(xs)
|
1052 |
+
mask_pad = masks # (B, 1, T/subsample_rate)
|
1053 |
+
chunk_masks = add_optional_chunk_mask(xs, masks,
|
1054 |
+
self.use_dynamic_chunk,
|
1055 |
+
self.use_dynamic_left_chunk,
|
1056 |
+
decoding_chunk_size,
|
1057 |
+
self.static_chunk_size,
|
1058 |
+
num_decoding_left_chunks)
|
1059 |
+
if self.gradient_checkpointing and self.training:
|
1060 |
+
xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
|
1061 |
+
mask_pad)
|
1062 |
+
else:
|
1063 |
+
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
1064 |
+
if self.normalize_before:
|
1065 |
+
xs = self.after_norm(xs)
|
1066 |
+
# Here we assume the mask is not changed in encoder layers, so just
|
1067 |
+
# return the masks before encoder layers, and the masks will be used
|
1068 |
+
# for cross attention with decoder later
|
1069 |
+
return xs, masks
|
1070 |
+
|
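A minimal end-to-end sketch of the ConformerEncoder defined above; all sizes are small made-up values chosen so the example runs quickly, not the configuration used elsewhere in this commit.

import torch

encoder = ConformerEncoder(
    input_size=64,       # feature dim of the incoming sequence
    output_size=128,
    attention_heads=4,
    linear_units=256,
    num_blocks=2,
)
xs = torch.randn(2, 20, 64)    # (B, T, D_in)
pad_mask = torch.ones(2, 20)   # 1 = real frame, 0 = padding
out, masks = encoder(xs, pad_mask)
# out: (2, 20, 128); masks: (2, 1, 20)
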
models/lyrics_utils/lyric_normalizer.py
ADDED
@@ -0,0 +1,66 @@
import re
from opencc import OpenCC


t2s_converter = OpenCC('t2s')
s2t_converter = OpenCC('s2t')


EMOJI_PATTERN = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # Emoticons
    "]+", flags=re.UNICODE
)

# Translation table used to replace and remove characters
TRANSLATION_TABLE = str.maketrans({
    '-': ' ',  # replace '-' with a space
    ',': None,
    '.': None,
    '，': None,
    '。': None,
    '!': None,
    '！': None,
    '?': None,
    '？': None,
    '…': None,
    ';': None,
    '；': None,
    ':': None,
    '：': None,
    '\u3000': ' ',  # replace full-width spaces with regular spaces
})

# Pattern for bracketed content, both square brackets and parentheses
BACKSLASH_PATTERN = re.compile(r'\(.*?\)|\[.*?\]')

SPACE_PATTERN = re.compile(r'(?<!^)\s+(?!$)')


def normalize_text(text, language, strip=True):
    """
    Normalize the text: strip punctuation and lowercase it (where applicable).
    """
    # Step 1: replace '-' with ' ' and remove punctuation
    text = text.translate(TRANSLATION_TABLE)

    # Step 2: remove emoji
    text = EMOJI_PATTERN.sub('', text)

    # Step 3: collapse runs of whitespace into a single space, except at the ends
    text = SPACE_PATTERN.sub(' ', text)

    # Step 4: strip leading/trailing whitespace (if requested)
    if strip:
        text = text.strip()

    # Step 5: lowercase
    text = text.lower()

    # Step 6: language-specific conversion
    if language == "zh":
        text = t2s_converter.convert(text)
    if language == "yue":
        text = s2t_converter.convert(text)
    # add other languages here as needed
    return text
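Two illustrative calls to normalize_text (both input strings are made-up examples):

print(normalize_text("Hello, WORLD!  My demo-line", language="en"))
# -> "hello world my demo line"   punctuation removed, '-' becomes a space, lowercased
print(normalize_text("愛與夢想!", language="zh"))
# -> "爱与梦想"   Traditional Chinese converted to Simplified via OpenCC
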
models/lyrics_utils/lyric_tokenizer.py
ADDED
@@ -0,0 +1,883 @@
import os
import re
import textwrap
from functools import cached_property

import pypinyin
import torch
from hangul_romanize import Transliter
from hangul_romanize.rule import academic
from num2words import num2words
from spacy.lang.ar import Arabic
from spacy.lang.en import English
from spacy.lang.es import Spanish
from spacy.lang.ja import Japanese
from spacy.lang.zh import Chinese
from tokenizers import Tokenizer

from .zh_num2words import TextNorm as zh_num2words
from typing import Dict, List, Optional, Set, Union


# copied from https://github.com/coqui-ai/TTS/blob/dbf1a08a0d4e47fdad6172e433eeb34bc6b13b4e/TTS/tts/layers/xtts/tokenizer.py
def get_spacy_lang(lang):
    if lang == "zh":
        return Chinese()
    elif lang == "ja":
        return Japanese()
    elif lang == "ar":
        return Arabic()
    elif lang == "es":
        return Spanish()
    else:
        # For most languages, English does the job
        return English()


def split_sentence(text, lang, text_split_length=250):
    """Preprocess the input text"""
    text_splits = []
    if text_split_length is not None and len(text) >= text_split_length:
        text_splits.append("")
        nlp = get_spacy_lang(lang)
        nlp.add_pipe("sentencizer")
        doc = nlp(text)
        for sentence in doc.sents:
            if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
                # if the last sentence + the current sentence is less than the text_split_length
                # then add the current sentence to the last sentence
                text_splits[-1] += " " + str(sentence)
                text_splits[-1] = text_splits[-1].lstrip()
            elif len(str(sentence)) > text_split_length:
                # if the current sentence is greater than the text_split_length
                for line in textwrap.wrap(
                    str(sentence),
                    width=text_split_length,
                    drop_whitespace=True,
                    break_on_hyphens=False,
                    tabsize=1,
                ):
                    text_splits.append(str(line))
            else:
                text_splits.append(str(sentence))

        if len(text_splits) > 1:
            if text_splits[0] == "":
                del text_splits[0]
    else:
        text_splits = [text.lstrip()]

    return text_splits

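A short illustration of split_sentence (the strings and lengths are arbitrary): short inputs come back as a single chunk, long inputs are packed into sentence-aligned chunks of at most roughly text_split_length characters.

print(split_sentence("Just one short line.", "en", text_split_length=250))
# -> ['Just one short line.']

long_text = "First sentence here. " * 30
chunks = split_sentence(long_text, "en", text_split_length=100)
# each chunk groups whole sentences and stays around or below 100 characters
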
_whitespace_re = re.compile(r"\s+")
|
74 |
+
|
75 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
76 |
+
_abbreviations = {
|
77 |
+
"en": [
|
78 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
79 |
+
for x in [
|
80 |
+
("mrs", "misess"),
|
81 |
+
("mr", "mister"),
|
82 |
+
("dr", "doctor"),
|
83 |
+
("st", "saint"),
|
84 |
+
("co", "company"),
|
85 |
+
("jr", "junior"),
|
86 |
+
("maj", "major"),
|
87 |
+
("gen", "general"),
|
88 |
+
("drs", "doctors"),
|
89 |
+
("rev", "reverend"),
|
90 |
+
("lt", "lieutenant"),
|
91 |
+
("hon", "honorable"),
|
92 |
+
("sgt", "sergeant"),
|
93 |
+
("capt", "captain"),
|
94 |
+
("esq", "esquire"),
|
95 |
+
("ltd", "limited"),
|
96 |
+
("col", "colonel"),
|
97 |
+
("ft", "fort"),
|
98 |
+
]
|
99 |
+
],
|
100 |
+
"es": [
|
101 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
102 |
+
for x in [
|
103 |
+
("sra", "señora"),
|
104 |
+
("sr", "señor"),
|
105 |
+
("dr", "doctor"),
|
106 |
+
("dra", "doctora"),
|
107 |
+
("st", "santo"),
|
108 |
+
("co", "compañía"),
|
109 |
+
("jr", "junior"),
|
110 |
+
("ltd", "limitada"),
|
111 |
+
]
|
112 |
+
],
|
113 |
+
"fr": [
|
114 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
115 |
+
for x in [
|
116 |
+
("mme", "madame"),
|
117 |
+
("mr", "monsieur"),
|
118 |
+
("dr", "docteur"),
|
119 |
+
("st", "saint"),
|
120 |
+
("co", "compagnie"),
|
121 |
+
("jr", "junior"),
|
122 |
+
("ltd", "limitée"),
|
123 |
+
]
|
124 |
+
],
|
125 |
+
"de": [
|
126 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
127 |
+
for x in [
|
128 |
+
("fr", "frau"),
|
129 |
+
("dr", "doktor"),
|
130 |
+
("st", "sankt"),
|
131 |
+
("co", "firma"),
|
132 |
+
("jr", "junior"),
|
133 |
+
]
|
134 |
+
],
|
135 |
+
"pt": [
|
136 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
137 |
+
for x in [
|
138 |
+
("sra", "senhora"),
|
139 |
+
("sr", "senhor"),
|
140 |
+
("dr", "doutor"),
|
141 |
+
("dra", "doutora"),
|
142 |
+
("st", "santo"),
|
143 |
+
("co", "companhia"),
|
144 |
+
("jr", "júnior"),
|
145 |
+
("ltd", "limitada"),
|
146 |
+
]
|
147 |
+
],
|
148 |
+
"it": [
|
149 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
150 |
+
for x in [
|
151 |
+
# ("sig.ra", "signora"),
|
152 |
+
("sig", "signore"),
|
153 |
+
("dr", "dottore"),
|
154 |
+
("st", "santo"),
|
155 |
+
("co", "compagnia"),
|
156 |
+
("jr", "junior"),
|
157 |
+
("ltd", "limitata"),
|
158 |
+
]
|
159 |
+
],
|
160 |
+
"pl": [
|
161 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
162 |
+
for x in [
|
163 |
+
("p", "pani"),
|
164 |
+
("m", "pan"),
|
165 |
+
("dr", "doktor"),
|
166 |
+
("sw", "święty"),
|
167 |
+
("jr", "junior"),
|
168 |
+
]
|
169 |
+
],
|
170 |
+
"ar": [
|
171 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
172 |
+
for x in [
|
173 |
+
# There are not many common abbreviations in Arabic as in English.
|
174 |
+
]
|
175 |
+
],
|
176 |
+
"zh": [
|
177 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
178 |
+
for x in [
|
179 |
+
# Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
|
180 |
+
]
|
181 |
+
],
|
182 |
+
"cs": [
|
183 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
184 |
+
for x in [
|
185 |
+
("dr", "doktor"), # doctor
|
186 |
+
("ing", "inženýr"), # engineer
|
187 |
+
("p", "pan"), # Could also map to pani for woman but no easy way to do it
|
188 |
+
# Other abbreviations would be specialized and not as common.
|
189 |
+
]
|
190 |
+
],
|
191 |
+
"ru": [
|
192 |
+
(re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
|
193 |
+
for x in [
|
194 |
+
("г-жа", "госпожа"), # Mrs.
|
195 |
+
("г-н", "господин"), # Mr.
|
196 |
+
("д-р", "доктор"), # doctor
|
197 |
+
# Other abbreviations are less common or specialized.
|
198 |
+
]
|
199 |
+
],
|
200 |
+
"nl": [
|
201 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
202 |
+
for x in [
|
203 |
+
("dhr", "de heer"), # Mr.
|
204 |
+
("mevr", "mevrouw"), # Mrs.
|
205 |
+
("dr", "dokter"), # doctor
|
206 |
+
("jhr", "jonkheer"), # young lord or nobleman
|
207 |
+
# Dutch uses more abbreviations, but these are the most common ones.
|
208 |
+
]
|
209 |
+
],
|
210 |
+
"tr": [
|
211 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
212 |
+
for x in [
|
213 |
+
("b", "bay"), # Mr.
|
214 |
+
("byk", "büyük"), # büyük
|
215 |
+
("dr", "doktor"), # doctor
|
216 |
+
# Add other Turkish abbreviations here if needed.
|
217 |
+
]
|
218 |
+
],
|
219 |
+
"hu": [
|
220 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
221 |
+
for x in [
|
222 |
+
("dr", "doktor"), # doctor
|
223 |
+
("b", "bácsi"), # Mr.
|
224 |
+
("nőv", "nővér"), # nurse
|
225 |
+
# Add other Hungarian abbreviations here if needed.
|
226 |
+
]
|
227 |
+
],
|
228 |
+
"ko": [
|
229 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
230 |
+
for x in [
|
231 |
+
# Korean doesn't typically use abbreviations in the same way as Latin-based scripts.
|
232 |
+
]
|
233 |
+
],
|
234 |
+
}
|
235 |
+
|
236 |
+
|
237 |
+
def expand_abbreviations_multilingual(text, lang="en"):
|
238 |
+
for regex, replacement in _abbreviations[lang]:
|
239 |
+
text = re.sub(regex, replacement, text)
|
240 |
+
return text
|
241 |
+
|
242 |
+
|
243 |
+
_symbols_multilingual = {
|
244 |
+
"en": [
|
245 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
246 |
+
for x in [
|
247 |
+
("&", " and "),
|
248 |
+
("@", " at "),
|
249 |
+
("%", " percent "),
|
250 |
+
("#", " hash "),
|
251 |
+
("$", " dollar "),
|
252 |
+
("£", " pound "),
|
253 |
+
("°", " degree "),
|
254 |
+
]
|
255 |
+
],
|
256 |
+
"es": [
|
257 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
258 |
+
for x in [
|
259 |
+
("&", " y "),
|
260 |
+
("@", " arroba "),
|
261 |
+
("%", " por ciento "),
|
262 |
+
("#", " numeral "),
|
263 |
+
("$", " dolar "),
|
264 |
+
("£", " libra "),
|
265 |
+
("°", " grados "),
|
266 |
+
]
|
267 |
+
],
|
268 |
+
"fr": [
|
269 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
270 |
+
for x in [
|
271 |
+
("&", " et "),
|
272 |
+
("@", " arobase "),
|
273 |
+
("%", " pour cent "),
|
274 |
+
("#", " dièse "),
|
275 |
+
("$", " dollar "),
|
276 |
+
("£", " livre "),
|
277 |
+
("°", " degrés "),
|
278 |
+
]
|
279 |
+
],
|
280 |
+
"de": [
|
281 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
282 |
+
for x in [
|
283 |
+
("&", " und "),
|
284 |
+
("@", " at "),
|
285 |
+
("%", " prozent "),
|
286 |
+
("#", " raute "),
|
287 |
+
("$", " dollar "),
|
288 |
+
("£", " pfund "),
|
289 |
+
("°", " grad "),
|
290 |
+
]
|
291 |
+
],
|
292 |
+
"pt": [
|
293 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
294 |
+
for x in [
|
295 |
+
("&", " e "),
|
296 |
+
("@", " arroba "),
|
297 |
+
("%", " por cento "),
|
298 |
+
("#", " cardinal "),
|
299 |
+
("$", " dólar "),
|
300 |
+
("£", " libra "),
|
301 |
+
("°", " graus "),
|
302 |
+
]
|
303 |
+
],
|
304 |
+
"it": [
|
305 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
306 |
+
for x in [
|
307 |
+
("&", " e "),
|
308 |
+
("@", " chiocciola "),
|
309 |
+
("%", " per cento "),
|
310 |
+
("#", " cancelletto "),
|
311 |
+
("$", " dollaro "),
|
312 |
+
("£", " sterlina "),
|
313 |
+
("°", " gradi "),
|
314 |
+
]
|
315 |
+
],
|
316 |
+
"pl": [
|
317 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
318 |
+
for x in [
|
319 |
+
("&", " i "),
|
320 |
+
("@", " małpa "),
|
321 |
+
("%", " procent "),
|
322 |
+
("#", " krzyżyk "),
|
323 |
+
("$", " dolar "),
|
324 |
+
("£", " funt "),
|
325 |
+
("°", " stopnie "),
|
326 |
+
]
|
327 |
+
],
|
328 |
+
"ar": [
|
329 |
+
# Arabic
|
330 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
331 |
+
for x in [
|
332 |
+
("&", " و "),
|
333 |
+
("@", " على "),
|
334 |
+
("%", " في المئة "),
|
335 |
+
("#", " رقم "),
|
336 |
+
("$", " دولار "),
|
337 |
+
("£", " جنيه "),
|
338 |
+
("°", " درجة "),
|
339 |
+
]
|
340 |
+
],
|
341 |
+
"zh": [
|
342 |
+
# Chinese
|
343 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
344 |
+
for x in [
|
345 |
+
("&", " 和 "),
|
346 |
+
("@", " 在 "),
|
347 |
+
("%", " 百分之 "),
|
348 |
+
("#", " 号 "),
|
349 |
+
("$", " 美元 "),
|
350 |
+
("£", " 英镑 "),
|
351 |
+
("°", " 度 "),
|
352 |
+
]
|
353 |
+
],
|
354 |
+
"cs": [
|
355 |
+
# Czech
|
356 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
357 |
+
for x in [
|
358 |
+
("&", " a "),
|
359 |
+
("@", " na "),
|
360 |
+
("%", " procento "),
|
361 |
+
("#", " křížek "),
|
362 |
+
("$", " dolar "),
|
363 |
+
("£", " libra "),
|
364 |
+
("°", " stupně "),
|
365 |
+
]
|
366 |
+
],
|
367 |
+
"ru": [
|
368 |
+
# Russian
|
369 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
370 |
+
for x in [
|
371 |
+
("&", " и "),
|
372 |
+
("@", " собака "),
|
373 |
+
("%", " процентов "),
|
374 |
+
("#", " номер "),
|
375 |
+
("$", " доллар "),
|
376 |
+
("£", " фунт "),
|
377 |
+
("°", " градус "),
|
378 |
+
]
|
379 |
+
],
|
380 |
+
"nl": [
|
381 |
+
# Dutch
|
382 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
383 |
+
for x in [
|
384 |
+
("&", " en "),
|
385 |
+
("@", " bij "),
|
386 |
+
("%", " procent "),
|
387 |
+
("#", " hekje "),
|
388 |
+
("$", " dollar "),
|
389 |
+
("£", " pond "),
|
390 |
+
("°", " graden "),
|
391 |
+
]
|
392 |
+
],
|
393 |
+
"tr": [
|
394 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
395 |
+
for x in [
|
396 |
+
("&", " ve "),
|
397 |
+
("@", " at "),
|
398 |
+
("%", " yüzde "),
|
399 |
+
("#", " diyez "),
|
400 |
+
("$", " dolar "),
|
401 |
+
("£", " sterlin "),
|
402 |
+
("°", " derece "),
|
403 |
+
]
|
404 |
+
],
|
405 |
+
"hu": [
|
406 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
407 |
+
for x in [
|
408 |
+
("&", " és "),
|
409 |
+
("@", " kukac "),
|
410 |
+
("%", " százalék "),
|
411 |
+
("#", " kettőskereszt "),
|
412 |
+
("$", " dollár "),
|
413 |
+
("£", " font "),
|
414 |
+
("°", " fok "),
|
415 |
+
]
|
416 |
+
],
|
417 |
+
"ko": [
|
418 |
+
# Korean
|
419 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
420 |
+
for x in [
|
421 |
+
("&", " 그리고 "),
|
422 |
+
("@", " 에 "),
|
423 |
+
("%", " 퍼센트 "),
|
424 |
+
("#", " 번호 "),
|
425 |
+
("$", " 달러 "),
|
426 |
+
("£", " 파운드 "),
|
427 |
+
("°", " 도 "),
|
428 |
+
]
|
429 |
+
],
|
430 |
+
}
|
431 |
+
|
432 |
+
|
433 |
+
def expand_symbols_multilingual(text, lang="en"):
|
434 |
+
for regex, replacement in _symbols_multilingual[lang]:
|
435 |
+
text = re.sub(regex, replacement, text)
|
436 |
+
text = text.replace("  ", " ")  # Ensure there are no double spaces
|
437 |
+
return text.strip()
|
438 |
+
|
439 |
+
|
440 |
+
_ordinal_re = {
|
441 |
+
"en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
|
442 |
+
"es": re.compile(r"([0-9]+)(º|ª|er|o|a|os|as)"),
|
443 |
+
"fr": re.compile(r"([0-9]+)(º|ª|er|re|e|ème)"),
|
444 |
+
"de": re.compile(r"([0-9]+)(st|nd|rd|th|º|ª|\.(?=\s|$))"),
|
445 |
+
"pt": re.compile(r"([0-9]+)(º|ª|o|a|os|as)"),
|
446 |
+
"it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"),
|
447 |
+
"pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"),
|
448 |
+
"ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"),
|
449 |
+
"cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals.
|
450 |
+
"ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
|
451 |
+
"nl": re.compile(r"([0-9]+)(de|ste|e)"),
|
452 |
+
"tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
|
453 |
+
"hu": re.compile(r"([0-9]+)(\.|adik|edik|odik|edik|ödik|ödike|ik)"),
|
454 |
+
"ko": re.compile(r"([0-9]+)(번째|번|차|째)"),
|
455 |
+
}
|
456 |
+
_number_re = re.compile(r"[0-9]+")
|
457 |
+
_currency_re = {
|
458 |
+
"USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
|
459 |
+
"GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
|
460 |
+
"EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
|
461 |
+
}
|
462 |
+
|
463 |
+
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
|
464 |
+
_dot_number_re = re.compile(r"\b\d{1,3}(\.\d{3})*(\,\d+)?\b")
|
465 |
+
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
|
466 |
+
|
467 |
+
|
468 |
+
def _remove_commas(m):
|
469 |
+
text = m.group(0)
|
470 |
+
if "," in text:
|
471 |
+
text = text.replace(",", "")
|
472 |
+
return text
|
473 |
+
|
474 |
+
|
475 |
+
def _remove_dots(m):
|
476 |
+
text = m.group(0)
|
477 |
+
if "." in text:
|
478 |
+
text = text.replace(".", "")
|
479 |
+
return text
|
480 |
+
|
481 |
+
|
482 |
+
def _expand_decimal_point(m, lang="en"):
|
483 |
+
amount = m.group(1).replace(",", ".")
|
484 |
+
return num2words(float(amount), lang=lang if lang != "cs" else "cz")
|
485 |
+
|
486 |
+
|
487 |
+
def _expand_currency(m, lang="en", currency="USD"):
|
488 |
+
amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
|
489 |
+
full_amount = num2words(amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz")
|
490 |
+
|
491 |
+
and_equivalents = {
|
492 |
+
"en": ", ",
|
493 |
+
"es": " con ",
|
494 |
+
"fr": " et ",
|
495 |
+
"de": " und ",
|
496 |
+
"pt": " e ",
|
497 |
+
"it": " e ",
|
498 |
+
"pl": ", ",
|
499 |
+
"cs": ", ",
|
500 |
+
"ru": ", ",
|
501 |
+
"nl": ", ",
|
502 |
+
"ar": ", ",
|
503 |
+
"tr": ", ",
|
504 |
+
"hu": ", ",
|
505 |
+
"ko": ", ",
|
506 |
+
}
|
507 |
+
|
508 |
+
if amount.is_integer():
|
509 |
+
last_and = full_amount.rfind(and_equivalents[lang])
|
510 |
+
if last_and != -1:
|
511 |
+
full_amount = full_amount[:last_and]
|
512 |
+
|
513 |
+
return full_amount
|
514 |
+
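The integer branch above is easiest to see on a whole amount; a small sketch of what it computes (num2words is already used by this file, and the final string matches the "$20" English test case below):

from num2words import num2words

# num2words' currency form can spell out a cents clause, so for whole amounts
# everything after the last ", " (the English "and"-equivalent) is cut off.
spoken = num2words(20.0, to="currency", currency="USD", lang="en")
last_and = spoken.rfind(", ")
if last_and != -1:
    spoken = spoken[:last_and]
print(spoken)  # -> "twenty dollars"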
|
515 |
+
|
516 |
+
def _expand_ordinal(m, lang="en"):
|
517 |
+
return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
|
518 |
+
|
519 |
+
|
520 |
+
def _expand_number(m, lang="en"):
|
521 |
+
return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
|
522 |
+
|
523 |
+
|
524 |
+
def expand_numbers_multilingual(text, lang="en"):
|
525 |
+
if lang == "zh":
|
526 |
+
text = zh_num2words()(text)
|
527 |
+
else:
|
528 |
+
if lang in ["en", "ru"]:
|
529 |
+
text = re.sub(_comma_number_re, _remove_commas, text)
|
530 |
+
else:
|
531 |
+
text = re.sub(_dot_number_re, _remove_dots, text)
|
532 |
+
try:
|
533 |
+
text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
|
534 |
+
text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
|
535 |
+
text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
|
536 |
+
except:
|
537 |
+
pass
|
538 |
+
if lang != "tr":
|
539 |
+
text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
|
540 |
+
text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
|
541 |
+
text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
|
542 |
+
return text
|
543 |
+
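Read top to bottom, the passes above run in a fixed order: thousands separators are stripped first, then currency, decimals, ordinals, and finally bare integers. A minimal end-to-end sketch (import path assumed; the expected strings are the English test cases below):

from models.lyrics_utils.lyric_tokenizer import expand_numbers_multilingual

print(expand_numbers_multilingual("In 12.5 seconds.", lang="en"))
# -> "In twelve point five seconds."   (decimal pass)
print(expand_numbers_multilingual("This is a 1st test", lang="en"))
# -> "This is a first test"            (ordinal pass)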
|
544 |
+
|
545 |
+
def lowercase(text):
|
546 |
+
return text.lower()
|
547 |
+
|
548 |
+
|
549 |
+
def collapse_whitespace(text):
|
550 |
+
return re.sub(_whitespace_re, " ", text)
|
551 |
+
|
552 |
+
|
553 |
+
def multilingual_cleaners(text, lang):
|
554 |
+
text = text.replace('"', "")
|
555 |
+
if lang == "tr":
|
556 |
+
text = text.replace("İ", "i")
|
557 |
+
text = text.replace("Ö", "ö")
|
558 |
+
text = text.replace("Ü", "ü")
|
559 |
+
text = lowercase(text)
|
560 |
+
try:
|
561 |
+
text = expand_numbers_multilingual(text, lang)
|
562 |
+
except:
|
563 |
+
pass
|
564 |
+
try:
|
565 |
+
text = expand_abbreviations_multilingual(text, lang)
|
566 |
+
except:
|
567 |
+
pass
|
568 |
+
try:
|
569 |
+
text = expand_symbols_multilingual(text, lang=lang)
|
570 |
+
except:
|
571 |
+
pass
|
572 |
+
text = collapse_whitespace(text)
|
573 |
+
return text
|
574 |
+
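multilingual_cleaners is the umbrella entry point: quotes are dropped, Turkish casing is special-cased before lowercasing, and the number/abbreviation/symbol expansions each run inside a try/except so one failing normalizer cannot break the whole pipeline. A short sketch (import path assumed):

from models.lyrics_utils.lyric_tokenizer import multilingual_cleaners

cleaned = multilingual_cleaners('Dr. Smith said: "That will be $20."', "en")
# Quotes removed, text lowercased, "$20" expanded to words, "dr." expanded to
# "doctor", and whitespace collapsed before the string is returned.
print(cleaned)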
|
575 |
+
|
576 |
+
def basic_cleaners(text):
|
577 |
+
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
578 |
+
text = lowercase(text)
|
579 |
+
text = collapse_whitespace(text)
|
580 |
+
return text
|
581 |
+
|
582 |
+
|
583 |
+
def chinese_transliterate(text):
|
584 |
+
return "".join(
|
585 |
+
[p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]
|
586 |
+
)
|
587 |
+
|
588 |
+
|
589 |
+
def japanese_cleaners(text, katsu):
|
590 |
+
text = katsu.romaji(text)
|
591 |
+
text = lowercase(text)
|
592 |
+
return text
|
593 |
+
|
594 |
+
|
595 |
+
def korean_transliterate(text):
|
596 |
+
r = Transliter(academic)
|
597 |
+
return r.translit(text)
|
598 |
+
|
599 |
+
|
600 |
+
DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "vocab.json")
|
601 |
+
|
602 |
+
|
603 |
+
class VoiceBpeTokenizer:
|
604 |
+
def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
|
605 |
+
self.tokenizer = None
|
606 |
+
if vocab_file is not None:
|
607 |
+
self.tokenizer = Tokenizer.from_file(vocab_file)
|
608 |
+
self.char_limits = {
|
609 |
+
"en": 10000,
|
610 |
+
"de": 253,
|
611 |
+
"fr": 273,
|
612 |
+
"es": 239,
|
613 |
+
"it": 213,
|
614 |
+
"pt": 203,
|
615 |
+
"pl": 224,
|
616 |
+
"zh": 82,
|
617 |
+
"ar": 166,
|
618 |
+
"cs": 186,
|
619 |
+
"ru": 182,
|
620 |
+
"nl": 251,
|
621 |
+
"tr": 226,
|
622 |
+
"ja": 71,
|
623 |
+
"hu": 224,
|
624 |
+
"ko": 95,
|
625 |
+
}
|
626 |
+
|
627 |
+
@cached_property
|
628 |
+
def katsu(self):
|
629 |
+
import cutlet
|
630 |
+
|
631 |
+
return cutlet.Cutlet()
|
632 |
+
|
633 |
+
def check_input_length(self, txt, lang):
|
634 |
+
lang = lang.split("-")[0] # remove the region
|
635 |
+
limit = self.char_limits.get(lang, 250)
|
636 |
+
# if len(txt) > limit:
|
637 |
+
# print(
|
638 |
+
# f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."
|
639 |
+
# )
|
640 |
+
|
641 |
+
def preprocess_text(self, txt, lang):
|
642 |
+
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
|
643 |
+
txt = multilingual_cleaners(txt, lang)
|
644 |
+
if lang == "zh":
|
645 |
+
txt = chinese_transliterate(txt)
|
646 |
+
if lang == "ko":
|
647 |
+
txt = korean_transliterate(txt)
|
648 |
+
elif lang == "ja":
|
649 |
+
txt = japanese_cleaners(txt, self.katsu)
|
650 |
+
elif lang == "hi":
|
651 |
+
# @manmay will implement this
|
652 |
+
txt = basic_cleaners(txt)
|
653 |
+
else:
|
654 |
+
raise NotImplementedError(f"Language '{lang}' is not supported.")
|
655 |
+
return txt
|
656 |
+
|
657 |
+
def encode(self, txt, lang):
|
658 |
+
lang = lang.split("-")[0] # remove the region
|
659 |
+
self.check_input_length(txt, lang)
|
660 |
+
txt = self.preprocess_text(txt, lang)
|
661 |
+
lang = "zh-cn" if lang == "zh" else lang
|
662 |
+
txt = f"[{lang}]{txt}"
|
663 |
+
txt = txt.replace(" ", "[SPACE]")
|
664 |
+
return self.tokenizer.encode(txt).ids
|
665 |
+
|
666 |
+
def decode(self, seq, skip_special_tokens=False):
|
667 |
+
if isinstance(seq, torch.Tensor):
|
668 |
+
seq = seq.cpu().numpy()
|
669 |
+
txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "")
|
670 |
+
txt = txt.replace("[SPACE]", " ")
|
671 |
+
txt = txt.replace("[STOP]", "")
|
672 |
+
# txt = txt.replace("[UNK]", "")
|
673 |
+
return txt
|
674 |
+
|
675 |
+
|
676 |
+
#copy from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3936
|
677 |
+
def batch_decode(
|
678 |
+
self,
|
679 |
+
sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
|
680 |
+
skip_special_tokens: bool = False,
|
681 |
+
) -> List[str]:
|
682 |
+
"""
|
683 |
+
Convert a list of lists of token ids into a list of strings by calling decode.
|
684 |
+
|
685 |
+
Args:
|
686 |
+
sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
|
687 |
+
List of tokenized input ids. Can be obtained using the `__call__` method.
|
688 |
+
skip_special_tokens (`bool`, *optional*, defaults to `False`):
|
689 |
+
Whether or not to remove special tokens in the decoding.
|
690 |
+
kwargs (additional keyword arguments, *optional*):
|
691 |
+
Will be passed to the underlying model specific decode method.
|
692 |
+
|
693 |
+
Returns:
|
694 |
+
`List[str]`: The list of decoded sentences.
|
695 |
+
"""
|
696 |
+
return [
|
697 |
+
self.decode(seq)
|
698 |
+
for seq in sequences
|
699 |
+
]
|
700 |
+
|
701 |
+
#https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/layers/xtts/trainer/dataset.py#L202
|
702 |
+
# def pad(self):
|
703 |
+
|
704 |
+
def __len__(self):
|
705 |
+
return self.tokenizer.get_vocab_size()
|
706 |
+
|
707 |
+
def get_number_tokens(self):
|
708 |
+
return max(self.tokenizer.get_vocab().values()) + 1
|
709 |
+
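A minimal round-trip sketch for the tokenizer class above (vocab.json ships alongside this file in models/lyrics_utils/, so the default constructor argument resolves to it; import path assumed):

from models.lyrics_utils.lyric_tokenizer import VoiceBpeTokenizer

tok = VoiceBpeTokenizer()                    # loads the bundled vocab.json
ids = tok.encode("a quick test", lang="en")  # cleaned, prefixed with "[en]",
                                             # spaces mapped to the [SPACE] token
print(ids)
print(tok.decode(ids))                       # "[SPACE]" mapped back to a space,
                                             # "[STOP]" tokens stripped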
|
710 |
+
|
711 |
+
def test_expand_numbers_multilingual():
|
712 |
+
test_cases = [
|
713 |
+
# English
|
714 |
+
("In 12.5 seconds.", "In twelve point five seconds.", "en"),
|
715 |
+
("There were 50 soldiers.", "There were fifty soldiers.", "en"),
|
716 |
+
("This is a 1st test", "This is a first test", "en"),
|
717 |
+
("That will be $20 sir.", "That will be twenty dollars sir.", "en"),
|
718 |
+
("That will be 20€ sir.", "That will be twenty euro sir.", "en"),
|
719 |
+
("That will be 20.15€ sir.", "That will be twenty euro, fifteen cents sir.", "en"),
|
720 |
+
("That's 100,000.5.", "That's one hundred thousand point five.", "en"),
|
721 |
+
# French
|
722 |
+
("En 12,5 secondes.", "En douze virgule cinq secondes.", "fr"),
|
723 |
+
("Il y avait 50 soldats.", "Il y avait cinquante soldats.", "fr"),
|
724 |
+
("Ceci est un 1er test", "Ceci est un premier test", "fr"),
|
725 |
+
("Cela vous fera $20 monsieur.", "Cela vous fera vingt dollars monsieur.", "fr"),
|
726 |
+
("Cela vous fera 20€ monsieur.", "Cela vous fera vingt euros monsieur.", "fr"),
|
727 |
+
("Cela vous fera 20,15€ monsieur.", "Cela vous fera vingt euros et quinze centimes monsieur.", "fr"),
|
728 |
+
("Ce sera 100.000,5.", "Ce sera cent mille virgule cinq.", "fr"),
|
729 |
+
# German
|
730 |
+
("In 12,5 Sekunden.", "In zwölf Komma fünf Sekunden.", "de"),
|
731 |
+
("Es gab 50 Soldaten.", "Es gab fünfzig Soldaten.", "de"),
|
732 |
+
("Dies ist ein 1. Test", "Dies ist ein erste Test", "de"), # Issue with gender
|
733 |
+
("Das macht $20 Herr.", "Das macht zwanzig Dollar Herr.", "de"),
|
734 |
+
("Das macht 20€ Herr.", "Das macht zwanzig Euro Herr.", "de"),
|
735 |
+
("Das macht 20,15€ Herr.", "Das macht zwanzig Euro und fünfzehn Cent Herr.", "de"),
|
736 |
+
# Spanish
|
737 |
+
("En 12,5 segundos.", "En doce punto cinco segundos.", "es"),
|
738 |
+
("Había 50 soldados.", "Había cincuenta soldados.", "es"),
|
739 |
+
("Este es un 1er test", "Este es un primero test", "es"),
|
740 |
+
("Eso le costará $20 señor.", "Eso le costará veinte dólares señor.", "es"),
|
741 |
+
("Eso le costará 20€ señor.", "Eso le costará veinte euros señor.", "es"),
|
742 |
+
("Eso le costará 20,15€ señor.", "Eso le costará veinte euros con quince céntimos señor.", "es"),
|
743 |
+
# Italian
|
744 |
+
("In 12,5 secondi.", "In dodici virgola cinque secondi.", "it"),
|
745 |
+
("C'erano 50 soldati.", "C'erano cinquanta soldati.", "it"),
|
746 |
+
("Questo è un 1° test", "Questo è un primo test", "it"),
|
747 |
+
("Ti costerà $20 signore.", "Ti costerà venti dollari signore.", "it"),
|
748 |
+
("Ti costerà 20€ signore.", "Ti costerà venti euro signore.", "it"),
|
749 |
+
("Ti costerà 20,15€ signore.", "Ti costerà venti euro e quindici centesimi signore.", "it"),
|
750 |
+
# Portuguese
|
751 |
+
("Em 12,5 segundos.", "Em doze vírgula cinco segundos.", "pt"),
|
752 |
+
("Havia 50 soldados.", "Havia cinquenta soldados.", "pt"),
|
753 |
+
("Este é um 1º teste", "Este é um primeiro teste", "pt"),
|
754 |
+
("Isso custará $20 senhor.", "Isso custará vinte dólares senhor.", "pt"),
|
755 |
+
("Isso custará 20€ senhor.", "Isso custará vinte euros senhor.", "pt"),
|
756 |
+
(
|
757 |
+
"Isso custará 20,15€ senhor.",
|
758 |
+
"Isso custará vinte euros e quinze cêntimos senhor.",
|
759 |
+
"pt",
|
760 |
+
), # "cêntimos" should be "centavos" num2words issue
|
761 |
+
# Polish
|
762 |
+
("W 12,5 sekundy.", "W dwanaście przecinek pięć sekundy.", "pl"),
|
763 |
+
("Było 50 żołnierzy.", "Było pięćdziesiąt żołnierzy.", "pl"),
|
764 |
+
("To będzie kosztować 20€ panie.", "To będzie kosztować dwadzieścia euro panie.", "pl"),
|
765 |
+
("To będzie kosztować 20,15€ panie.", "To będzie kosztować dwadzieścia euro, piętnaście centów panie.", "pl"),
|
766 |
+
# Arabic
|
767 |
+
("في الـ 12,5 ثانية.", "في الـ اثنا عشر , خمسون ثانية.", "ar"),
|
768 |
+
("كان هناك 50 جنديًا.", "كان هناك خمسون جنديًا.", "ar"),
|
769 |
+
# ("ستكون النتيجة $20 يا سيد.", 'ستكون النتيجة عشرون دولار يا سيد.', 'ar'), # $ and € are mising from num2words
|
770 |
+
# ("ستكون النتيجة 20€ يا سيد.", 'ستكون النتيجة عشرون يورو يا سيد.', 'ar'),
|
771 |
+
# Czech
|
772 |
+
("Za 12,5 vteřiny.", "Za dvanáct celá pět vteřiny.", "cs"),
|
773 |
+
("Bylo tam 50 vojáků.", "Bylo tam padesát vojáků.", "cs"),
|
774 |
+
("To bude stát 20€ pane.", "To bude stát dvacet euro pane.", "cs"),
|
775 |
+
("To bude 20.15€ pane.", "To bude dvacet euro, patnáct centů pane.", "cs"),
|
776 |
+
# Russian
|
777 |
+
("Через 12.5 секунды.", "Через двенадцать запятая пять секунды.", "ru"),
|
778 |
+
("Там было 50 солдат.", "Там было пятьдесят солдат.", "ru"),
|
779 |
+
("Это будет 20.15€ сэр.", "Это будет двадцать евро, пятнадцать центов сэр.", "ru"),
|
780 |
+
("Это будет стоить 20€ господин.", "Это будет стоить двадцать евро господин.", "ru"),
|
781 |
+
# Dutch
|
782 |
+
("In 12,5 seconden.", "In twaalf komma vijf seconden.", "nl"),
|
783 |
+
("Er waren 50 soldaten.", "Er waren vijftig soldaten.", "nl"),
|
784 |
+
("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
|
785 |
+
("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
|
786 |
+
# Chinese (Simplified)
|
787 |
+
("在12.5秒内", "在十二点五秒内", "zh"),
|
788 |
+
("有50名士兵", "有五十名士兵", "zh"),
|
789 |
+
# ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
|
790 |
+
# ("那将是20€先生", '那将是二十欧元先生', 'zh'),
|
791 |
+
# Turkish
|
792 |
+
# ("12,5 saniye içinde.", 'On iki virgül beş saniye içinde.', 'tr'), # decimal doesn't work for TR
|
793 |
+
("50 asker vardı.", "elli asker vardı.", "tr"),
|
794 |
+
("Bu 1. test", "Bu birinci test", "tr"),
|
795 |
+
# ("Bu 100.000,5.", 'Bu yüz bin virgül beş.', 'tr'),
|
796 |
+
# Hungarian
|
797 |
+
("12,5 másodperc alatt.", "tizenkettő egész öt tized másodperc alatt.", "hu"),
|
798 |
+
("50 katona volt.", "ötven katona volt.", "hu"),
|
799 |
+
("Ez az 1. teszt", "Ez az első teszt", "hu"),
|
800 |
+
# Korean
|
801 |
+
("12.5 초 안에.", "십이 점 다섯 초 안에.", "ko"),
|
802 |
+
("50 명의 병사가 있었다.", "오십 명의 병사가 있었다.", "ko"),
|
803 |
+
("이것은 1 번째 테스트입니다", "이것은 첫 번째 테스트입니다", "ko"),
|
804 |
+
]
|
805 |
+
for a, b, lang in test_cases:
|
806 |
+
out = expand_numbers_multilingual(a, lang=lang)
|
807 |
+
assert out == b, f"'{out}' vs '{b}'"
|
808 |
+
|
809 |
+
|
810 |
+
def test_abbreviations_multilingual():
|
811 |
+
test_cases = [
|
812 |
+
# English
|
813 |
+
("Hello Mr. Smith.", "Hello mister Smith.", "en"),
|
814 |
+
("Dr. Jones is here.", "doctor Jones is here.", "en"),
|
815 |
+
# Spanish
|
816 |
+
("Hola Sr. Garcia.", "Hola señor Garcia.", "es"),
|
817 |
+
("La Dra. Martinez es muy buena.", "La doctora Martinez es muy buena.", "es"),
|
818 |
+
# French
|
819 |
+
("Bonjour Mr. Dupond.", "Bonjour monsieur Dupond.", "fr"),
|
820 |
+
("Mme. Moreau est absente aujourd'hui.", "madame Moreau est absente aujourd'hui.", "fr"),
|
821 |
+
# German
|
822 |
+
("Frau Dr. Müller ist sehr klug.", "Frau doktor Müller ist sehr klug.", "de"),
|
823 |
+
# Portuguese
|
824 |
+
("Olá Sr. Silva.", "Olá senhor Silva.", "pt"),
|
825 |
+
("Dra. Costa, você está disponível?", "doutora Costa, você está disponível?", "pt"),
|
826 |
+
# Italian
|
827 |
+
("Buongiorno, Sig. Rossi.", "Buongiorno, signore Rossi.", "it"),
|
828 |
+
# ("Sig.ra Bianchi, posso aiutarti?", 'signora Bianchi, posso aiutarti?', 'it'), # Issue with matching that pattern
|
829 |
+
# Polish
|
830 |
+
("Dzień dobry, P. Kowalski.", "Dzień dobry, pani Kowalski.", "pl"),
|
831 |
+
("M. Nowak, czy mogę zadać pytanie?", "pan Nowak, czy mogę zadać pytanie?", "pl"),
|
832 |
+
# Czech
|
833 |
+
("P. Novák", "pan Novák", "cs"),
|
834 |
+
("Dr. Vojtěch", "doktor Vojtěch", "cs"),
|
835 |
+
# Dutch
|
836 |
+
("Dhr. Jansen", "de heer Jansen", "nl"),
|
837 |
+
("Mevr. de Vries", "mevrouw de Vries", "nl"),
|
838 |
+
# Russian
|
839 |
+
("Здравствуйте Г-�� Иванов.", "Здравствуйте господин Иванов.", "ru"),
|
840 |
+
("Д-р Смирнов здесь, чтобы увидеть вас.", "доктор Смирнов здесь, чтобы увидеть вас.", "ru"),
|
841 |
+
# Turkish
|
842 |
+
("Merhaba B. Yılmaz.", "Merhaba bay Yılmaz.", "tr"),
|
843 |
+
("Dr. Ayşe burada.", "doktor Ayşe burada.", "tr"),
|
844 |
+
# Hungarian
|
845 |
+
("Dr. Szabó itt van.", "doktor Szabó itt van.", "hu"),
|
846 |
+
]
|
847 |
+
|
848 |
+
for a, b, lang in test_cases:
|
849 |
+
out = expand_abbreviations_multilingual(a, lang=lang)
|
850 |
+
assert out == b, f"'{out}' vs '{b}'"
|
851 |
+
|
852 |
+
|
853 |
+
def test_symbols_multilingual():
|
854 |
+
test_cases = [
|
855 |
+
("I have 14% battery", "I have 14 percent battery", "en"),
|
856 |
+
("Te veo @ la fiesta", "Te veo arroba la fiesta", "es"),
|
857 |
+
("J'ai 14° de fièvre", "J'ai 14 degrés de fièvre", "fr"),
|
858 |
+
("Die Rechnung beträgt £ 20", "Die Rechnung beträgt pfund 20", "de"),
|
859 |
+
("O meu email é ana&joao@gmail.com", "O meu email é ana e joao arroba gmail.com", "pt"),
|
860 |
+
("linguaggio di programmazione C#", "linguaggio di programmazione C cancelletto", "it"),
|
861 |
+
("Moja temperatura to 36.6°", "Moja temperatura to 36.6 stopnie", "pl"),
|
862 |
+
("Mám 14% baterie", "Mám 14 procento baterie", "cs"),
|
863 |
+
("Těším se na tebe @ party", "Těším se na tebe na party", "cs"),
|
864 |
+
("У меня 14% заряда", "У меня 14 процентов заряда", "ru"),
|
865 |
+
("Я буду @ дома", "Я буду собака дома", "ru"),
|
866 |
+
("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
|
867 |
+
("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
|
868 |
+
("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
|
869 |
+
("我的电量为 14%", "我的电量为 14 百分之", "zh"),
|
870 |
+
("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
|
871 |
+
("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
|
872 |
+
("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),
|
873 |
+
]
|
874 |
+
|
875 |
+
for a, b, lang in test_cases:
|
876 |
+
out = expand_symbols_multilingual(a, lang=lang)
|
877 |
+
assert out == b, f"'{out}' vs '{b}'"
|
878 |
+
|
879 |
+
|
880 |
+
if __name__ == "__main__":
|
881 |
+
test_expand_numbers_multilingual()
|
882 |
+
test_abbreviations_multilingual()
|
883 |
+
test_symbols_multilingual()
|
models/lyrics_utils/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/lyrics_utils/zh_num2words.py
ADDED
@@ -0,0 +1,1209 @@
1 |
+
# Authors:
|
2 |
+
# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
|
3 |
+
# 2019.9 - 2022 Jiayu DU
|
4 |
+
#copy from https://github.com/coqui-ai/TTS/blob/dbf1a08a0d4e47fdad6172e433eeb34bc6b13b4e/TTS/tts/layers/xtts/zh_num2words.py
|
5 |
+
import argparse
|
6 |
+
import csv
|
7 |
+
import os
|
8 |
+
import re
|
9 |
+
import string
|
10 |
+
import sys
|
11 |
+
|
12 |
+
# fmt: off
|
13 |
+
|
14 |
+
# ================================================================================ #
|
15 |
+
# basic constant
|
16 |
+
# ================================================================================ #
|
17 |
+
CHINESE_DIGIS = "零一二三四五六七八九"
|
18 |
+
BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
|
19 |
+
BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
|
20 |
+
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
|
21 |
+
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
|
22 |
+
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
|
23 |
+
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
|
24 |
+
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
|
25 |
+
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
|
26 |
+
|
27 |
+
ZERO_ALT = "〇"
|
28 |
+
ONE_ALT = "幺"
|
29 |
+
TWO_ALTS = ["两", "兩"]
|
30 |
+
|
31 |
+
POSITIVE = ["正", "正"]
|
32 |
+
NEGATIVE = ["负", "負"]
|
33 |
+
POINT = ["点", "點"]
|
34 |
+
# PLUS = [u'加', u'加']
|
35 |
+
# SIL = [u'杠', u'槓']
|
36 |
+
|
37 |
+
FILLER_CHARS = ["呃", "啊"]
|
38 |
+
|
39 |
+
ER_WHITELIST = (
|
40 |
+
"(儿女|儿子|儿孙|女儿|儿媳|妻儿|"
|
41 |
+
"胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|"
|
42 |
+
"儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|"
|
43 |
+
"佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)"
|
44 |
+
)
|
45 |
+
ER_WHITELIST_PATTERN = re.compile(ER_WHITELIST)
|
46 |
+
|
47 |
+
# Types of Chinese numbering systems
|
48 |
+
NUMBERING_TYPES = ["low", "mid", "high"]
|
49 |
+
|
50 |
+
CURRENCY_NAMES = "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|" "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
|
51 |
+
CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
|
52 |
+
COM_QUANTIFIERS = (
|
53 |
+
"(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
|
54 |
+
"砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
|
55 |
+
"针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
|
56 |
+
"毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
|
57 |
+
"盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
|
58 |
+
"纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)"
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
# Punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git)
|
63 |
+
CN_PUNCS_STOP = "!?。。"
|
64 |
+
CN_PUNCS_NONSTOP = ""#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏·〈〉-"
|
65 |
+
CN_PUNCS = CN_PUNCS_STOP + CN_PUNCS_NONSTOP
|
66 |
+
|
67 |
+
PUNCS = CN_PUNCS + string.punctuation
|
68 |
+
PUNCS_TRANSFORM = str.maketrans(PUNCS, "," * len(PUNCS), "") # replace puncs with English comma
|
69 |
+
|
70 |
+
|
71 |
+
# https://zh.wikipedia.org/wiki/全行和半行
|
72 |
+
QJ2BJ = {
|
73 |
+
" ": " ",
|
74 |
+
"!": "!",
|
75 |
+
""": '"',
|
76 |
+
"#": "#",
|
77 |
+
"$": "$",
|
78 |
+
"%": "%",
|
79 |
+
"&": "&",
|
80 |
+
"'": "'",
|
81 |
+
"(": "(",
|
82 |
+
")": ")",
|
83 |
+
"*": "*",
|
84 |
+
"+": "+",
|
85 |
+
",": ",",
|
86 |
+
"-": "-",
|
87 |
+
".": ".",
|
88 |
+
"/": "/",
|
89 |
+
"0": "0",
|
90 |
+
"1": "1",
|
91 |
+
"2": "2",
|
92 |
+
"3": "3",
|
93 |
+
"4": "4",
|
94 |
+
"5": "5",
|
95 |
+
"6": "6",
|
96 |
+
"7": "7",
|
97 |
+
"8": "8",
|
98 |
+
"9": "9",
|
99 |
+
":": ":",
|
100 |
+
";": ";",
|
101 |
+
"<": "<",
|
102 |
+
"=": "=",
|
103 |
+
">": ">",
|
104 |
+
"?": "?",
|
105 |
+
"@": "@",
|
106 |
+
"A": "A",
|
107 |
+
"B": "B",
|
108 |
+
"C": "C",
|
109 |
+
"D": "D",
|
110 |
+
"E": "E",
|
111 |
+
"F": "F",
|
112 |
+
"G": "G",
|
113 |
+
"H": "H",
|
114 |
+
"I": "I",
|
115 |
+
"J": "J",
|
116 |
+
"K": "K",
|
117 |
+
"L": "L",
|
118 |
+
"M": "M",
|
119 |
+
"N": "N",
|
120 |
+
"O": "O",
|
121 |
+
"P": "P",
|
122 |
+
"Q": "Q",
|
123 |
+
"R": "R",
|
124 |
+
"S": "S",
|
125 |
+
"T": "T",
|
126 |
+
"U": "U",
|
127 |
+
"V": "V",
|
128 |
+
"W": "W",
|
129 |
+
"X": "X",
|
130 |
+
"Y": "Y",
|
131 |
+
"Z": "Z",
|
132 |
+
"[": "[",
|
133 |
+
"\": "\\",
|
134 |
+
"]": "]",
|
135 |
+
"^": "^",
|
136 |
+
"_": "_",
|
137 |
+
"`": "`",
|
138 |
+
"a": "a",
|
139 |
+
"b": "b",
|
140 |
+
"c": "c",
|
141 |
+
"d": "d",
|
142 |
+
"e": "e",
|
143 |
+
"f": "f",
|
144 |
+
"g": "g",
|
145 |
+
"h": "h",
|
146 |
+
"i": "i",
|
147 |
+
"j": "j",
|
148 |
+
"k": "k",
|
149 |
+
"l": "l",
|
150 |
+
"m": "m",
|
151 |
+
"n": "n",
|
152 |
+
"o": "o",
|
153 |
+
"p": "p",
|
154 |
+
"q": "q",
|
155 |
+
"r": "r",
|
156 |
+
"s": "s",
|
157 |
+
"t": "t",
|
158 |
+
"u": "u",
|
159 |
+
"v": "v",
|
160 |
+
"w": "w",
|
161 |
+
"x": "x",
|
162 |
+
"y": "y",
|
163 |
+
"z": "z",
|
164 |
+
"{": "{",
|
165 |
+
"|": "|",
|
166 |
+
"}": "}",
|
167 |
+
"~": "~",
|
168 |
+
}
|
169 |
+
QJ2BJ_TRANSFORM = str.maketrans("".join(QJ2BJ.keys()), "".join(QJ2BJ.values()), "")
|
170 |
+
|
171 |
+
|
172 |
+
# 2013 China National Standard: https://zh.wikipedia.org/wiki/通用规范汉字表, raw resources:
|
173 |
+
# https://github.com/mozillazg/pinyin-data/blob/master/kMandarin_8105.txt with 8105 chinese chars in total
|
174 |
+
CN_CHARS_COMMON = (
|
175 |
+
"一丁七万丈三上下不与丏丐丑专且丕世丘丙业丛东丝丞丢两严丧个丫中丰串临丸丹为主丽举"
|
176 |
+
"乂乃久么义之乌乍乎乏乐乒乓乔乖乘乙乜九乞也习乡书乩买乱乳乸乾了予争事二亍于亏云互"
|
177 |
+
"亓五井亘亚些亟亡亢交亥亦产亨亩享京亭亮亲亳亵亶亸亹人亿什仁仂仃仄仅仆仇仉今介仍从"
|
178 |
+
"仑仓仔仕他仗付仙仝仞仟仡代令以仨仪仫们仰仲仳仵件价任份仿企伈伉伊伋伍伎伏伐休众优"
|
179 |
+
"伙会伛伞伟传伢伣伤伥伦伧伪伫伭伯估伲伴伶伸伺似伽伾佁佃但位低住佐佑体何佖佗佘余佚"
|
180 |
+
"佛作佝佞佟你佣佤佥佩佬佯佰佳佴佶佸佺佻佼佽佾使侁侂侃侄侈侉例侍侏侑侔侗侘供依侠侣"
|
181 |
+
"侥侦侧侨侩侪侬侮侯侴侵侹便促俄俅俊俍俎俏俐俑俗俘俙俚俜保俞俟信俣俦俨俩俪俫俭修俯"
|
182 |
+
"俱俳俵俶俸俺俾倌倍倏倒倓倔倕倘候倚倜倞借倡倥倦倧倨倩倪倬倭倮倴债倻值倾偁偃假偈偌"
|
183 |
+
"偎偏偓偕做停偡健偬偭偰偲偶偷偻偾偿傀傃傅傈傉傍傒傕傣傥傧储傩催傲傺傻僇僎像僔僖僚"
|
184 |
+
"僦僧僬僭僮僰僳僵僻儆儇儋儒儡儦儳儴儿兀允元兄充兆先光克免兑兔兕兖党兜兢入全八公六"
|
185 |
+
"兮兰共关兴兵其具典兹养兼兽冀冁内冈冉册再冏冒冔冕冗写军农冠冢冤冥冬冮冯冰冱冲决况"
|
186 |
+
"冶冷冻冼冽净凄准凇凉凋凌减凑凓凘凛凝几凡凤凫凭凯凰凳凶凸凹出击凼函凿刀刁刃分切刈"
|
187 |
+
"刊刍刎刑划刖列刘则刚创初删判刨利别刬刭刮到刳制刷券刹刺刻刽刿剀剁剂剃剅削剋剌前剐"
|
188 |
+
"剑剔剕剖剜剞剟剡剥剧剩剪副割剽剿劁劂劄劈劐劓力劝办功加务劢劣动助努劫劬劭励劲劳劼"
|
189 |
+
"劾势勃勇勉勋勍勐勒勔勖勘勚募勠勤勰勺勾勿匀包匆匈匍匏匐匕化北匙匜匝匠匡匣匦匪匮匹"
|
190 |
+
"区医匼匾匿十千卅升午卉半华协卑卒卓单卖南博卜卞卟占卡卢卣卤卦卧卫卬卮卯印危即却卵"
|
191 |
+
"卷卸卺卿厂厄厅历厉压厌厍厕厖厘厚厝原厢厣厥厦厨厩厮去厾县叁参叆叇又叉及友双反发叔"
|
192 |
+
"叕取受变叙叚叛叟叠口古句另叨叩只叫召叭叮可台叱史右叵叶号司叹叻叼叽吁吃各吆合吉吊"
|
193 |
+
"同名后吏吐向吒吓吕吖吗君吝吞吟吠吡吣否吧吨吩含听吭吮启吱吲吴吵吸吹吻吼吽吾呀呃呆"
|
194 |
+
"呇呈告呋呐呒呓呔呕呖呗员呙呛呜呢呣呤呦周呱呲味呵呶呷呸呻呼命咀咂咄咆咇咉咋和咍咎"
|
195 |
+
"咏咐咒咔咕咖咙咚咛咝咡咣咤咥咦咧咨咩咪咫咬咯咱咳咴咸咺咻咽咿哀品哂哃哄哆哇哈哉哌"
|
196 |
+
"响哎哏哐哑哒哓哔哕哗哙哚哝哞哟哢哥哦哧哨哩哪哭哮哱哲哳哺哼哽哿唁唆唇唉唏唐唑唔唛"
|
197 |
+
"唝唠唢唣唤唧唪唬售唯唰唱唳唵唷唼唾唿啁啃啄商啉啊啐啕啖啜啡啤啥啦啧啪啫啬啭啮啰啴"
|
198 |
+
"啵啶啷啸啻啼啾喀喁喂喃善喆喇喈喉喊喋喏喑喔喘喙喜喝喟喤喧喱喳喵喷喹喻喽喾嗄嗅嗉嗌"
|
199 |
+
"嗍嗐嗑嗒嗓嗔嗖嗜嗝嗞嗟嗡嗣嗤嗥嗦嗨嗪嗫嗬嗯嗲嗳嗵嗷嗽嗾嘀嘁嘈嘉嘌嘎嘏嘘嘚嘛嘞嘟嘡"
|
200 |
+
"嘣嘤嘧嘬嘭嘱嘲嘴嘶嘹嘻嘿噀噂噇噌噍噎噔噗噘噙噜噢噤器噩噪噫噬噱噶噻噼嚄嚅嚆嚎嚏嚓"
|
201 |
+
"嚚嚣嚭嚯嚷嚼囊囔囚四回囟因囡团囤囫园困囱围囵囷囹固国图囿圃圄圆圈圉圊圌圐圙圜土圢"
|
202 |
+
"圣在圩圪圫圬圭圮圯地圲圳圹场圻圾址坂均坉坊坋坌坍坎坏坐坑坒块坚坛坜坝坞坟坠坡坤坥"
|
203 |
+
"坦坨坩坪坫坬坭坯坰坳坷坻坼坽垂垃垄垆垈型垌垍垎垏垒垓垕垙垚垛垞垟垠垡垢垣垤垦垧垩"
|
204 |
+
"垫垭垮垯垱垲垴垵垸垺垾垿埂埃埆埇埋埌城埏埒埔埕埗埘埙埚埝域埠埤埪埫埭埯埴埵埸培基"
|
205 |
+
"埼埽堂堃堆堇堉堋堌堍堎堐堑堕堙堞堠堡堤堧堨堪堰堲堵堼堽堾塄塅塆塌塍塑塔塘塝塞塥填"
|
206 |
+
"塬塱塾墀墁境墅墈墉墐墒墓墕墘墙墚增墟墡墣墦墨墩墼壁壅壑壕壤士壬壮声壳壶壸壹处备复"
|
207 |
+
"夏夐夔夕外夙多夜够夤夥大天太夫夬夭央夯失头夷夸夹夺夼奁奂奄奇奈奉奋奎奏契奓奔奕奖"
|
208 |
+
"套奘奚奠奡奢奥奭女奴奶奸她好妁如妃妄妆妇妈妊妍妒妓妖妗妘妙妞妣妤妥妧妨妩妪妫妭妮"
|
209 |
+
"妯妲妹妻妾姆姈姊始姐姑姒姓委姗姘姚姜姝姞姣姤姥姨姬姮姱姶姹姻姽姿娀威娃娄娅娆娇娈"
|
210 |
+
"娉娌娑娓娘娜娟娠娣娥娩娱娲娴娵娶娼婀婆婉婊婌婍婕婘婚婞婠婢婤婧婪婫婳婴婵婶婷婺���"
|
211 |
+
"婼婿媂媄媆媒媓媖媚媛媞媪媭媱媲媳媵媸媾嫁嫂嫄嫉嫌嫒嫔嫕嫖嫘嫚嫜嫠嫡嫣嫦嫩嫪嫫嫭嫱"
|
212 |
+
"嫽嬉嬖嬗嬛嬥嬬嬴嬷嬿孀孅子孑孓孔孕孖字存孙孚孛孜孝孟孢季孤孥学孩孪孬孰孱孳孵孺孽"
|
213 |
+
"宁它宄宅宇守安宋完宏宓宕宗官宙定宛宜宝实宠审客宣室宥宦宧宪宫宬宰害宴宵家宸容宽宾"
|
214 |
+
"宿寁寂寄寅密寇富寐寒寓寝寞察寡寤寥寨寮寰寸对寺寻导寿封射将尉尊小少尔尕尖尘尚尜尝"
|
215 |
+
"尢尤尥尧尨尪尬就尴尸尹尺尻尼尽尾尿局屁层屃居屈屉届屋屎屏屐屑展屙属屠屡屣履屦屯山"
|
216 |
+
"屹屺屼屾屿岁岂岈岊岌岍岐岑岔岖岗岘岙岚岛岜岞岠岢岣岨岩岫岬岭岱岳岵岷岸岽岿峁峂峃"
|
217 |
+
"峄峋峒峗峘峙峛峡峣峤峥峦峧峨峪峭峰峱峻峿崀崁崂崃崄崆崇崌崎崒崔崖崚崛崞崟崡崤崦崧"
|
218 |
+
"崩崭崮崴崶崽崾崿嵁嵅嵇嵊嵋嵌嵎嵖嵘嵚嵛嵝嵩嵫嵬嵯嵲嵴嶂嶅嶍嶒嶓嶙嶝嶟嶦嶲嶷巅巇巉"
|
219 |
+
"巍川州巡巢工左巧巨巩巫差巯己已巳巴巷巽巾币市布帅帆师希帏帐帑帔帕帖帘帙帚帛帜帝帡"
|
220 |
+
"带帧帨席帮帱帷常帻帼帽幂幄幅幌幔幕幖幛幞幡幢幪干平年并幸幺幻幼幽广庄庆庇床庋序庐"
|
221 |
+
"庑库应底庖店庙庚府庞废庠庤庥度座庭庱庳庵庶康庸庹庼庾廆廉廊廋廑廒廓廖廙廛廨廪延廷"
|
222 |
+
"建廿开弁异弃弄弆弇弈弊弋式弑弓引弗弘弛弟张弢弥弦弧弨弩弭弯弱弶弸弹强弼彀归当录彖"
|
223 |
+
"彗彘彝彟形彤彦彧彩彪彬彭彰影彳彷役彻彼往征徂径待徇很徉徊律徐徒徕得徘徙徛徜御徨循"
|
224 |
+
"徭微徵德徼徽心必忆忉忌忍忏忐忑忒忖志忘忙忝忞忠忡忤忧忪快忭忮忱忳念忸忺忻忽忾忿怀"
|
225 |
+
"态怂怃怄怅怆怊怍怎怏怒怔怕怖怙怛怜思怠怡急怦性怨怩怪怫怯怵总怼怿恁恂恃恋恍恐恒恓"
|
226 |
+
"恔恕恙恚恝恢恣恤恧恨恩恪恫恬恭息恰恳恶恸恹恺恻恼恽恿悃悄悆悈悉悌悍悒悔悖悚悛悝悟"
|
227 |
+
"悠悢患悦您悫悬悭悯悰悱悲悴悸悻悼情惆惇惊惋惎惑惔惕惘惙惚惛惜惝惟惠惦惧惨惩惫惬惭"
|
228 |
+
"惮惯惰想惴惶惹惺愀愁愃愆愈愉愍愎意愐愔愕愚感愠愣愤愦愧愫愭愿慆慈慊慌慎慑慕慝慢慥"
|
229 |
+
"慧慨慬慭慰慵慷憋憎憔憕憙憧憨憩憬憭憷憺憾懂懈懊懋懑懒懔懦懵懿戆戈戊戋戌戍戎戏成我"
|
230 |
+
"戒戕或戗战戚戛戟戡戢戣戤戥截戬戭戮戳戴户戽戾房所扁扂扃扅扆扇扈扉扊手才扎扑扒打扔"
|
231 |
+
"托扛扞扣扦执扩扪扫扬扭扮扯扰扳扶批扺扼扽找承技抃抄抉把抑抒抓抔投抖抗折抚抛抟抠抡"
|
232 |
+
"抢护报抨披抬抱抵抹抻押抽抿拂拃拄担拆拇拈拉拊拌拍拎拐拒拓拔拖拗拘拙招拜拟拢拣拤拥"
|
233 |
+
"拦拧拨择括拭拮拯拱拳拴拶拷拼拽拾拿持挂指挈按挎挑挓挖挚挛挝挞挟挠挡挣挤挥挦挨挪挫"
|
234 |
+
"振挲挹挺挽捂捃捅捆捉捋捌捍捎捏捐捕捞损捡换捣捧捩捭据捯捶捷捺捻捽掀掂掇授掉掊掌掎"
|
235 |
+
"掏掐排掖掘掞掠探掣接控推掩措掬掭掮掰掳掴掷掸掺掼掾揄揆揉揍描提插揕揖揠握揣揩揪揭"
|
236 |
+
"揳援揶揸揽揿搀搁搂搅搋搌搏搐搒搓搔搛搜搞搠搡搦搪搬搭搴携搽摁摄摅摆摇摈摊摏摒摔摘"
|
237 |
+
"摛摞摧摩摭摴摸摹摽撂撄撅撇撑撒撕撖撙撞撤撩撬播撮撰撵撷撸撺撼擀擂擅操擎擐擒擘擞擢"
|
238 |
+
"擤擦擿攀攉攒攘攥攫攮支收攸改攻攽放政故效敉敌敏救敔敕敖教敛敝敞敢散敦敩敫敬数敲整"
|
239 |
+
"敷文斋斌斐斑斓斗料斛斜斝斟斠斡斤斥斧斩斫断斯新斶方於施旁旃旄旅旆旋旌旎族旐旒旖旗"
|
240 |
+
"旞无既日旦旧旨早旬旭旮旯旰旱旴旵时旷旸旺旻旿昀昂昃昄昆昇昈昉昊昌明昏昒易昔昕昙昝"
|
241 |
+
"星映昡昣昤春昧昨昪昫昭是昱昳昴昵昶昺昼昽显晁晃晅晊晋晌晏晐晒晓晔晕晖晗晙晚晞晟晡"
|
242 |
+
"晢晤晦晨晪晫普景晰晱晴晶晷智晾暂暄暅暇暌暑暕暖暗暝暧暨暮暲暴暵暶暹暾暿曈曌曙曛曜"
|
243 |
+
"曝曦曩曰曲曳更曷曹曼曾替最月有朋服朏朐朓朔朕朗望朝期朦木未末本札术朱朳朴朵朸机朽"
|
244 |
+
"杀杂权杄杆杈杉杌李杏材村杓杕杖杙杜杞束杠条来杧杨杩杪杭杯杰杲杳杵杷杻杼松板极构枅"
|
245 |
+
"枇枉枋枍析枕林枘枚果枝枞枢枣枥枧枨枪枫枭枯枰枲枳枵架枷枸枹柁柃柄柈柊柏某柑柒染柔"
|
246 |
+
"柖柘柙柚柜柝柞柠柢查柩柬柯柰柱柳柴柷柽柿栀栅标栈栉栊栋栌栎栏栐树栒栓栖栗栝栟校栩"
|
247 |
+
"株栲栳栴样核根栻格栽栾桀桁桂桃桄桅框案桉桊桌桎桐桑桓桔桕桠桡桢档桤桥桦桧桨桩桫桯"
|
248 |
+
"桲桴桶桷桹梁梃梅梆梌梏梓梗梠梢梣梦梧梨梭梯械梳梴梵梼梽梾梿检棁棂棉棋棍棐棒棓棕棘"
|
249 |
+
"棚棠棣棤棨棪棫棬森棰棱棵棹棺棻棼棽椀椁椅椆椋植椎椐椑椒椓椟椠椤椪椭椰椴椸椹椽椿楂"
|
250 |
+
"楒楔楗楙楚楝楞楠楣楦楩楪楫楮楯楷楸楹楼概榃榄榅榆榇榈榉榍榑榔榕榖榛榜榧榨榫榭榰���"
|
251 |
+
"榴榷榻槁槃槊槌槎槐槔槚槛槜槟槠槭槱槲槽槿樊樗樘樟模樨横樯樱樵樽樾橄橇橐橑橘橙橛橞"
|
252 |
+
"橡橥橦橱橹橼檀檄檎檐檑檗檞檠檩檫檬櫆欂欠次欢欣欤欧欲欸欹欺欻款歃歅歆歇歉歌歙止正"
|
253 |
+
"此步武歧歪歹死歼殁殂殃殄殆殇殉殊残殍殒殓殖殚殛殡殣殪殳殴段殷殿毁毂毅毋毌母每毐毒"
|
254 |
+
"毓比毕毖毗毙毛毡毪毫毯毳毵毹毽氅氆氇氍氏氐民氓气氕氖氘氙氚氛氟氡氢氤氦氧氨氩氪氮"
|
255 |
+
"氯氰氲水永氾氿汀汁求汆汇汈汉汊汋汐汔汕汗汛汜汝汞江池污汤汧汨汩汪汫汭汰汲汴汶汹汽"
|
256 |
+
"汾沁沂沃沄沅沆沇沈沉沌沏沐沓沔沘沙沚沛沟没沣沤沥沦沧沨沩沪沫沭沮沱河沸油沺治沼沽"
|
257 |
+
"沾沿泂泃泄泅泇泉泊泌泐泓泔法泖泗泙泚泛泜泞泠泡波泣泥注泪泫泮泯泰泱泳泵泷泸泺泻泼"
|
258 |
+
"泽泾洁洄洇洈洋洌洎洑洒洓洗洘洙洚洛洞洢洣津洧洨洪洫洭洮洱洲洳洴洵洸洹洺活洼洽派洿"
|
259 |
+
"流浃浅浆浇浈浉浊测浍济浏浐浑浒浓浔浕浙浚浛浜浞浟浠浡浣浥浦浩浪浬浭浮浯浰浲浴海浸"
|
260 |
+
"浼涂涄涅消涉涌涍涎涐涑涓涔涕涘涛涝涞涟涠涡涢涣涤润涧涨涩涪涫涮涯液涴涵涸涿淀淄淅"
|
261 |
+
"淆淇淋淌淏淑淖淘淙淜淝淞淟淠淡淤淦淫淬淮淯深淳淴混淹添淼清渊渌渍渎渐渑渔渗渚渝渟"
|
262 |
+
"渠渡渣渤渥温渫渭港渰渲渴游渺渼湃湄湉湍湎湑湓湔湖湘湛湜湝湟湣湫湮湲湴湾湿溁溃溅溆"
|
263 |
+
"溇溉溍溏源溘溚溜溞溟溠溢溥溦溧溪溯溱溲溴溵溶溷溹溺溻溽滁滂滃滆滇滉滋滍滏滑滓滔滕"
|
264 |
+
"滗滘滚滞滟滠满滢滤滥滦滧滨滩滪滫滴滹漂漆漈漉漋漏漓演漕漖漠漤漦漩漪漫漭漯漱漳漴漶"
|
265 |
+
"漷漹漻漼漾潆潇潋潍潏潖潘潜潞潟潢潦潩潭潮潲潴潵潸潺潼潽潾澂澄澈澉澌澍澎澛澜澡澥澧"
|
266 |
+
"澪澭澳澴澶澹澼澽激濂濉濋濑濒濞濠濡濩濮濯瀌瀍瀑瀔瀚瀛瀣瀱瀵瀹瀼灈灌灏灞火灭灯灰灵"
|
267 |
+
"灶灸灼灾灿炀炅炆炉炊炌炎炒炔炕炖炘炙炜炝炟炣炫炬炭炮炯炱炳炷炸点炻炼炽烀烁烂烃烈"
|
268 |
+
"烊烔烘烙烛烜烝烟烠烤烦烧烨烩烫烬热烯烶烷烹烺烻烽焆焉焊焌焐焓焕焖焗焘焙焚焜焞焦焯"
|
269 |
+
"焰焱然煁煃煅煊煋煌煎煓煜煞煟煤煦照煨煮煲煳煴煸煺煽熄熇熊熏熔熘熙熛熜熟熠熥熨熬熵"
|
270 |
+
"熹熻燃燊燋燎燏燔燕燚燠燥燧燮燹爆爇爔爚爝爟爨爪爬爰爱爵父爷爸爹爻爽爿牁牂片版牌牍"
|
271 |
+
"牒牖牙牚牛牝牟牡牢牤牥牦牧物牮牯牲牵特牺牻牾牿犀犁犄犇犊犋犍犏犒犟犨犬犯犰犴状犷"
|
272 |
+
"犸犹狁狂狃狄狈狉狍狎狐狒狗狙狝狞狠狡狨狩独狭狮狯狰狱狲狳狴狷狸狺狻狼猁猃猄猇猊猎"
|
273 |
+
"猕猖猗猛猜猝猞猡猢猥猩猪猫猬献猯猰猱猴猷猹猺猾猿獍獐獒獗獠獬獭獯獴獾玃玄率玉王玎"
|
274 |
+
"玑玒玓玕玖玘玙玚玛玞玟玠玡玢玤玥玦玩玫玭玮环现玱玲玳玶玷玹玺玻玼玿珀珂珅珇珈珉珊"
|
275 |
+
"珋珌珍珏珐珑珒珕珖珙珛珝珞珠珢珣珥珦珧珩珪珫班珰珲珵珷珸珹珺珽琀球琄琅理琇琈琉琊"
|
276 |
+
"琎琏琐琔琚琛琟琡琢琤琥琦琨琪琫琬琭琮琯琰琲琳琴琵琶琼瑀瑁瑂瑃瑄瑅瑆瑑瑓瑔瑕瑖瑗瑙"
|
277 |
+
"瑚瑛瑜瑝瑞瑟瑢瑧瑨瑬瑭瑰瑱瑳瑶瑷瑾璀璁璃璆璇璈璋璎璐璒璘璜璞璟璠璥璧璨璩璪璬璮璱"
|
278 |
+
"璲璺瓀瓒瓖瓘瓜瓞瓠瓢瓣瓤瓦瓮瓯瓴瓶瓷瓻瓿甄甍甏甑甓甗甘甚甜生甡甥甦用甩甪甫甬甭甯"
|
279 |
+
"田由甲申电男甸町画甾畀畅畈畋界畎畏畔畖留畚畛畜畤略畦番畬畯畲畴畸畹畿疁疃疆疍疏疐"
|
280 |
+
"疑疔疖疗疙疚疝疟疠疡疢疣疤疥疫疬疭疮疯疰疱疲疳疴疵疸疹疼疽疾痂痃痄病症痈痉痊痍痒"
|
281 |
+
"痓痔痕痘痛痞痢痣痤痦痧痨痪痫痰痱痴痹痼痿瘀瘁瘃瘅瘆瘊瘌瘐瘕瘗瘘瘙瘛瘟瘠瘢瘤瘥瘦瘩"
|
282 |
+
"瘪瘫瘭瘰瘳瘴瘵瘸瘼瘾瘿癀癃癌癍癔癖癗癜癞癣癫癯癸登白百癿皂的皆皇皈皋皎皑皓皕皖皙"
|
283 |
+
"皛皞皤皦皭皮皱皲皴皿盂盅盆盈盉益盍盎盏盐监盒盔盖盗盘盛盟盥盦目盯盱盲直盷相盹盼盾"
|
284 |
+
"省眄眇眈眉眊看眍眙眚真眠眢眦眨眩眬眭眯眵眶眷眸眺眼着睁睃睄睇睎睐睑睚睛睡睢督睥睦"
|
285 |
+
"睨睫睬睹睽睾睿瞀瞄瞅瞋瞌瞍瞎瞑瞒瞟瞠瞢瞥瞧瞩瞪瞫瞬瞭瞰瞳瞵瞻瞽瞿矍矗矛矜矞矢矣知"
|
286 |
+
"矧矩矫矬短矮矰石矶矸矻矼矾矿砀码砂砄砆砉砌砍砑砒研砖砗砘砚砜砝砟砠砣砥砧砫砬砭砮"
|
287 |
+
"砰破砵砷砸砹砺砻砼砾础硁硅硇硊硌硍硎硐硒硔硕硖硗硙硚硝硪硫硬硭确硼硿碃碇碈碉碌碍"
|
288 |
+
"碎碏碑碓碗碘碚碛碜碟碡碣碥碧碨碰碱碲碳碴碶碹碾磁磅磉磊磋磏磐磔磕磙磜磡磨磬磲磴磷"
|
289 |
+
"磹磻礁礅礌礓礞礴礵示礼社祀祁祃祆祇祈祉祊祋祎祏祐祓祕祖祗祚祛祜祝神祟祠祢祥祧票祭"
|
290 |
+
"祯祲祷祸祺祼祾禀禁禄禅禊禋福禒禔禘禚禛禤禧禳禹禺离禽禾秀私秃秆秉秋种科秒秕秘租���"
|
291 |
+
"秤秦秧秩秫秬秭积称秸移秽秾稀稂稃稆程稌稍税稑稔稗稙稚稞稠稣稳稷稹稻稼稽稿穄穆穑穗"
|
292 |
+
"穙穜穟穰穴究穷穸穹空穿窀突窃窄窅窈窊窍窎窑窒窕窖窗窘窜窝窟窠窣窥窦窨窬窭窳窸窿立"
|
293 |
+
"竑竖竘站竞竟章竣童竦竫竭端竹竺竽竿笃笄笆笈笊笋笏笑笔笕笙笛笞笠笤笥符笨笪笫第笮笯"
|
294 |
+
"笱笳笸笺笼笾筀筅筇等筋筌筏筐筑筒答策筘筚筛筜筝筠筢筤筥筦筮筱筲筵筶筷筹筻筼签简箅"
|
295 |
+
"箍箐箓箔箕箖算箜管箢箦箧箨箩箪箫箬箭箱箴箸篁篆篇篌篑篓篙篚篝篡篥篦篪篮篯篱篷篼篾"
|
296 |
+
"簃簇簉簋簌簏簕簖簝簟簠簧簪簰簸簿籀籁籍籥米籴类籼籽粉粑粒粕粗粘粜粝粞粟粢粤粥粪粮"
|
297 |
+
"粱粲粳粹粼粽精粿糁糅糇糈糊糌糍糒糕糖糗糙糜糟糠糨糯糵系紊素索紧紫累絜絮絷綦綮縠縢"
|
298 |
+
"縻繁繄繇纂纛纠纡红纣纤纥约级纨纩纪纫纬纭纮纯纰纱纲纳纴纵纶纷纸纹纺纻纼纽纾线绀绁"
|
299 |
+
"绂练组绅细织终绉绊绋绌绍绎经绐绑绒结绔绕绖绗绘给绚绛络绝绞统绠绡绢绣绤绥绦继绨绩"
|
300 |
+
"绪绫续绮绯绰绱绲绳维绵绶绷绸绹绺绻综绽绾绿缀缁缂缃缄缅缆缇缈缉缊缌缎缐缑缒缓缔缕"
|
301 |
+
"编缗缘缙缚缛缜缝缞缟缠缡缢缣缤缥缦缧缨缩缪缫缬缭缮缯缰缱缲缳缴缵缶缸缺罂罄罅罍罐"
|
302 |
+
"网罔罕罗罘罚罟罡罢罨罩罪置罱署罴罶罹罽罾羁羊羌美羑羓羔羕羖羚羝羞羟羡群羧羯羰羱羲"
|
303 |
+
"羸羹羼羽羿翀翁翂翃翅翈翊翌翎翔翕翘翙翚翛翟翠翡翥翦翩翮翯翰翱翳翷翻翼翾耀老考耄者"
|
304 |
+
"耆耇耋而耍耏耐耑耒耔耕耖耗耘耙耜耠耢耤耥耦耧耨耩耪耰耱耳耵耶耷耸耻耽耿聂聃聆聊聋"
|
305 |
+
"职聍聒联聘聚聩聪聱聿肃肄肆肇肉肋肌肓肖肘肚肛肝肟肠股肢肤肥肩肪肫肭肮肯肱育肴肷肸"
|
306 |
+
"肺肼肽肾肿胀胁胂胃胄胆胈背胍胎胖胗胙胚胛胜胝胞胠胡胣胤胥胧胨胩胪胫胬胭胯胰胱胲胳"
|
307 |
+
"胴胶胸胺胼能脂脆脉脊脍脎脏脐脑脒脓脔脖脘脚脞脟脩脬脯脱脲脶脸脾脿腆腈腊腋腌腐腑腒"
|
308 |
+
"腓腔腕腘腙腚腠腥腧腨腩腭腮腯腰腱腴腹腺腻腼腽腾腿膀膂膈膊膏膑膘膙膛膜膝膦膨膳膺膻"
|
309 |
+
"臀臂臃臆臊臌臑臜臣臧自臬臭至致臻臼臾舀舁舂舄舅舆舌舍舐舒舔舛舜舞舟舠舢舣舥航舫般"
|
310 |
+
"舭舯舰舱舲舳舴舵舶舷舸船舻舾艄艅艇艉艋艎艏艘艚艟艨艮良艰色艳艴艺艽艾艿节芃芄芈芊"
|
311 |
+
"芋芍芎芏芑芒芗芘芙芜芝芟芠芡芣芤芥芦芨芩芪芫芬芭芮芯芰花芳芴芷芸芹芼芽芾苁苄苇苈"
|
312 |
+
"苉苊苋苌苍苎苏苑苒苓苔苕苗苘苛苜苞苟苠苡苣苤若苦苧苫苯英苴苷苹苻苾茀茁茂范茄茅茆"
|
313 |
+
"茈茉茋茌茎茏茑茓茔茕茗茚茛茜茝茧茨茫茬茭茯茱茳茴茵茶茸茹茺茼茽荀荁荃荄荆荇草荏荐"
|
314 |
+
"荑荒荓荔荖荙荚荛荜荞荟荠荡荣荤荥荦荧荨荩荪荫荬荭荮药荷荸荻荼荽莅莆莉莎莒莓莘莙莛"
|
315 |
+
"莜莝莞莠莨莩莪莫莰莱莲莳莴莶获莸莹莺莼莽莿菀菁菂菅菇菉菊菌菍菏菔菖菘菜菝菟菠菡菥"
|
316 |
+
"菩菪菰菱菲菹菼菽萁萃萄萆萋萌萍萎萏萑萘萚萜萝萣萤营萦萧萨萩萱萳萸萹萼落葆葎葑葖著"
|
317 |
+
"葙葚葛葜葡董葩葫葬葭葰葱葳葴葵葶葸葺蒂蒄蒇蒈蒉蒋蒌蒎蒐蒗蒙蒜蒟蒡蒨蒯蒱蒲蒴蒸蒹蒺"
|
318 |
+
"蒻蒽蒿蓁蓂蓄蓇蓉蓊蓍蓏蓐蓑蓓蓖蓝蓟蓠蓢蓣蓥蓦蓬蓰蓼蓿蔀蔃蔈蔊蔌蔑蔓蔗蔚蔟蔡蔫蔬蔷"
|
319 |
+
"蔸蔹蔺蔻蔼蔽蕃蕈蕉蕊蕖蕗蕙蕞蕤蕨蕰蕲蕴蕹蕺蕻蕾薁薄薅薇薏薛薜薢薤薨薪薮薯薰薳薷薸"
|
320 |
+
"薹薿藁藉藏藐藓藕藜藟藠藤藦藨藩藻藿蘅蘑蘖蘘蘧蘩蘸蘼虎虏虐虑虒虓虔虚虞虢虤虫虬虮虱"
|
321 |
+
"虷虸虹虺虻虼虽虾虿蚀蚁蚂蚄蚆蚊蚋蚌蚍蚓蚕蚜蚝蚣蚤蚧蚨蚩蚪蚬蚯蚰蚱蚲蚴蚶蚺蛀蛃蛄蛆"
|
322 |
+
"蛇蛉蛊蛋蛎蛏蛐蛑蛔蛘蛙蛛蛞蛟蛤蛩蛭蛮蛰蛱蛲蛳蛴蛸蛹蛾蜀蜂蜃蜇蜈蜉蜊蜍蜎蜐蜒蜓蜕蜗"
|
323 |
+
"蜘蜚蜜蜞蜡蜢蜣蜥蜩蜮蜱蜴蜷蜻蜾蜿蝇蝈蝉蝌蝎蝓蝗蝘蝙蝠蝣蝤蝥蝮蝰蝲蝴蝶蝻蝼蝽蝾螂螃"
|
324 |
+
"螅螈螋融螗螟螠螣螨螫螬螭螯螱螳螵螺螽蟀蟆蟊蟋蟏蟑蟒蟛蟠蟥蟪蟫蟮蟹蟾蠃蠊蠋蠓蠕蠖蠡"
|
325 |
+
"蠢蠲蠹蠼血衃衄衅行衍衎衒衔街衙衠衡衢衣补表衩衫衬衮衰衲衷衽衾衿袁袂袄袅袆袈袋袍袒"
|
326 |
+
"袖袗袜袢袤袪被袭袯袱袷袼裁裂装裆裈裉裎裒裔裕裘裙裛裟裢裣裤裥裨裰裱裳裴裸裹裼裾褂"
|
327 |
+
"褊褐褒褓褕褙褚褛褟褡褥褪褫褯褰褴褶襁襄襕襚襜襞襟襦襫襻西要覃覆见观觃规觅视觇览觉"
|
328 |
+
"觊觋觌觎觏觐觑角觖觚觜觞觟解觥触觫觭觯觱觳觿言訄訇訚訾詈詟詹誉誊誓謇警譬计订讣认"
|
329 |
+
"讥讦讧讨让讪讫训议讯记讱讲讳讴讵讶讷许讹论讻讼讽设访诀证诂诃评诅识诇诈诉诊诋诌词"
|
330 |
+
"诎诏诐译诒诓诔试诖诗诘诙诚诛诜话诞诟诠诡询诣诤该详诧诨诩诫诬语诮误诰诱诲诳说诵���"
|
331 |
+
"诸诹诺读诼诽课诿谀谁谂调谄谅谆谇谈谊谋谌谍谎谏谐谑谒谓谔谕谖谗谙谚谛谜谝谞谟谠谡"
|
332 |
+
"谢谣谤谥谦谧谨谩谪谫谬谭谮谯谰谱谲谳谴谵谶谷谼谿豁豆豇豉豌豕豚象豢豨豪豫豮豳豸豹"
|
333 |
+
"豺貂貅貆貉貊貌貔貘贝贞负贡财责贤败账货质贩贪贫贬购贮贯贰贱贲贳贴贵贶贷贸费贺贻贼"
|
334 |
+
"贽贾贿赀赁赂赃资赅赆赇赈赉赊赋赌赍赎赏赐赑赒赓赔赕赖赗赘赙赚赛赜赝赞赟赠赡赢赣赤"
|
335 |
+
"赦赧赪赫赭走赳赴赵赶起趁趄超越趋趑趔趟趣趯趱足趴趵趸趺趼趾趿跂跃跄跆跋跌跎跏跐跑"
|
336 |
+
"跖跗跚跛距跞跟跣跤跨跪跬路跱跳践跶跷跸跹跺跻跽踅踉踊踌踏踒踔踝踞踟踢踣踦踩踪踬踮"
|
337 |
+
"踯踱踵踶踹踺踽蹀蹁蹂蹄蹅蹇蹈蹉蹊蹋蹐蹑蹒蹙蹚蹜蹢蹦蹩蹬蹭蹯蹰蹲蹴蹶蹼蹽蹾蹿躁躅躇"
|
338 |
+
"躏躐躔躜躞身躬躯躲躺车轧轨轩轪轫转轭轮软轰轱轲轳轴轵轶轷轸轹轺轻轼载轾轿辀辁辂较"
|
339 |
+
"辄辅辆辇辈辉辊辋辌辍辎辏辐辑辒输辔辕辖辗辘辙辚辛辜辞辟辣辨辩辫辰辱边辽达辿迁迂迄"
|
340 |
+
"迅过迈迎运近迓返迕还这进远违连迟迢迤迥迦迨迩迪迫迭迮述迳迷迸迹迺追退送适逃逄逅逆"
|
341 |
+
"选逊逋逍透逐逑递途逖逗通逛逝逞速造逡逢逦逭逮逯逴逵逶逸逻逼逾遁遂遄遆遇遍遏遐遑遒"
|
342 |
+
"道遗遘遛遢遣遥遨遭遮遴遵遹遽避邀邂邃邈邋邑邓邕邗邘邙邛邝邠邡邢那邦邨邪邬邮邯邰邱"
|
343 |
+
"邲邳邴邵邶邸邹邺邻邽邾邿郁郃郄郅郇郈郊郎郏郐郑郓郗郚郛郜郝郡郢郤郦郧部郪郫郭郯郴"
|
344 |
+
"郸都郾郿鄀鄂鄃鄄鄅鄌鄑鄗鄘鄙鄚鄜鄞鄠鄢鄣鄫鄯鄱鄹酂酃酅酆酉酊酋酌配酎酏酐酒酗酚酝"
|
345 |
+
"酞酡酢酣酤酥酦酩酪酬酮酯酰酱酲酴酵酶酷酸酹酺酽酾酿醅醇醉醋醌醍醐醑醒醚醛醢醨醪醭"
|
346 |
+
"醮醯醴醵醺醾采釉释里重野量釐金釜鉴銎銮鋆鋈錾鍪鎏鏊鏖鐾鑫钆钇针钉钊钋钌钍钎钏钐钒"
|
347 |
+
"钓钔钕钖钗钘钙钚钛钜钝钞钟钠钡钢钣钤钥钦钧钨钩钪钫钬钭钮钯钰钱钲钳钴钵钷钹钺钻钼"
|
348 |
+
"钽钾钿铀铁铂铃铄铅铆铈铉铊铋铌铍铎铏铐铑铒铕铖铗铘铙铚铛铜铝铞铟铠铡铢铣铤铥铧铨"
|
349 |
+
"铩铪铫铬铭铮铯铰铱铲铳铴铵银铷铸铹铺铻铼铽链铿销锁锂锃锄锅锆锇锈锉锊锋锌锍锎锏锐"
|
350 |
+
"锑锒锓锔锕锖锗锘错锚锛锜锝锞锟锡锢锣锤锥锦锧锨锩锪锫锬锭键锯锰锱锲锳锴锵锶锷锸锹"
|
351 |
+
"锺锻锼锽锾锿镀镁镂镃镄镅镆镇镈镉镊镋镌镍镎镏镐镑镒镓镔镕镖镗镘镚镛镜镝镞镠镡镢镣"
|
352 |
+
"镤镥镦镧镨镩镪镫镬镭镮镯镰镱镲镳镴镵镶长门闩闪闫闭问闯闰闱闲闳间闵闶闷闸闹闺闻闼"
|
353 |
+
"闽闾闿阀阁阂阃阄阅阆阇阈阉阊阋阌阍阎阏阐阑阒阔阕阖阗阘阙阚阜队阡阪阮阱防阳阴阵阶"
|
354 |
+
"阻阼阽阿陀陂附际陆陇陈陉陋陌降陎限陑陔陕陛陞陟陡院除陧陨险陪陬陲陴陵陶陷隃隅隆隈"
|
355 |
+
"隋隍随隐隔隗隘隙障隧隩隰隳隶隹隺隼隽难雀雁雄雅集雇雉雊雌雍雎雏雒雕雠雨雩雪雯雱雳"
|
356 |
+
"零雷雹雾需霁霄霅霆震霈霉霍霎霏霓霖霜霞霨霪霭霰露霸霹霾青靓靖静靛非靠靡面靥革靬靰"
|
357 |
+
"靳靴靶靸靺靼靽靿鞁鞅鞋鞍鞑鞒鞔鞘鞠鞡鞣鞧鞨鞫鞬鞭鞮鞯鞲鞳鞴韂韦韧韨韩韪韫韬韭音韵"
|
358 |
+
"韶页顶顷顸项顺须顼顽顾顿颀颁颂颃预颅领颇颈颉颊颋颌颍颎颏颐频颓颔颖颗题颙颚颛颜额"
|
359 |
+
"颞颟颠颡颢颤颥颦颧风飏飐飑飒飓飔飕飗飘飙飞食飧飨餍餐餮饔饕饥饧饨饩饪饫饬饭饮饯饰"
|
360 |
+
"饱饲饳饴饵饶饷饸饹饺饻饼饽饿馁馃馄馅馆馇馈馉馊馋馌馍馏馐馑馒馓馔馕首馗馘香馝馞馥"
|
361 |
+
"馧馨马驭驮驯驰驱驲驳驴驵驶驷驸驹驺驻驼驽驾驿骀骁骂骃骄骅骆骇骈骉骊骋验骍骎骏骐骑"
|
362 |
+
"骒骓骕骖骗骘骙骚骛骜骝骞骟骠骡骢骣骤骥骦骧骨骰骱骶骷骸骺骼髀髁髂髃髅髋髌髎髑髓高"
|
363 |
+
"髡髢髦髫髭髯髹髻髽鬃鬈鬏鬒鬓鬘鬟鬣鬯鬲鬶鬷鬻鬼魁魂魃魄魅魆魇魈魉魋魍魏魑魔鱼鱽鱾"
|
364 |
+
"鱿鲀鲁鲂鲃鲅鲆鲇鲈鲉鲊鲋鲌鲍鲎鲏鲐鲑鲒鲔鲕鲖鲗鲘鲙鲚鲛鲜鲝鲞鲟鲠鲡鲢鲣鲤鲥鲦鲧鲨"
|
365 |
+
"鲩鲪鲫鲬鲭鲮鲯鲰鲱鲲鲳鲴鲵鲷鲸鲹鲺鲻鲼鲽鲾鲿鳀鳁鳂鳃鳄鳅鳇鳈鳉鳊鳌鳍鳎鳏鳐鳑鳒鳓"
|
366 |
+
"鳔鳕鳖鳗鳘鳙鳚鳛鳜鳝鳞鳟鳠鳡鳢鳣鳤鸟鸠鸡鸢鸣鸤鸥鸦鸧鸨鸩鸪鸫鸬鸭鸮鸯鸰鸱鸲鸳鸵鸶"
|
367 |
+
"鸷鸸鸹鸺鸻鸼鸽鸾鸿鹀鹁鹂鹃鹄鹅鹆鹇鹈鹉鹊鹋鹌鹍鹎鹏鹐鹑鹒鹔鹕鹖鹗鹘鹙鹚鹛鹜鹝鹞鹟"
|
368 |
+
"鹠鹡鹢鹣鹤鹦鹧鹨鹩鹪鹫鹬鹭鹮鹯鹰鹱鹲鹳鹴鹾鹿麀麂麇麈麋麑麒麓麖麝麟麦麸麹麻麽麾黄"
|
369 |
+
"黇黉黍黎黏黑黔默黛黜黝黟黠黡黢黥黧黩黪黯黹黻黼黾鼋鼍鼎鼐鼒鼓鼗鼙鼠鼢鼩鼫鼬鼯鼱鼷"
|
370 |
+
"鼹鼻鼽鼾齁齇齉齐齑齿龀龁龂龃龄龅龆龇龈龉龊龋龌龙龚龛龟龠龢鿍鿎鿏㑇㑊㕮㘎㙍㙘㙦���"
|
371 |
+
"㛚㛹㟃㠇㠓㤘㥄㧐㧑㧟㫰㬊㬎㬚㭎㭕㮾㰀㳇㳘㳚㴔㵐㶲㸆㸌㺄㻬㽏㿠䁖䂮䃅䃎䅟䌹䎃䎖䏝䏡"
|
372 |
+
"䏲䐃䓖䓛䓨䓫䓬䗖䗛䗪䗴䜣䝙䢺䢼䣘䥽䦃䲟䲠䲢䴓䴔䴕䴖䴗䴘䴙䶮𠅤𠙶𠳐𡎚𡐓𣗋𣲗𣲘𣸣𤧛𤩽"
|
373 |
+
"𤫉𥔲𥕢𥖨𥻗𦈡𦒍𦙶𦝼𦭜𦰡𧿹𨐈𨙸𨚕𨟠𨭉𨱇𨱏𨱑𨱔𨺙𩽾𩾃𩾌𪟝𪣻𪤗𪨰𪨶𪩘𪾢𫄧𫄨𫄷𫄸𫇭𫌀𫍣𫍯"
|
374 |
+
"𫍲𫍽𫐄𫐐𫐓𫑡𫓧𫓯𫓶𫓹𫔍𫔎𫔶𫖮𫖯𫖳𫗧𫗴𫘜𫘝𫘦𫘧𫘨𫘪𫘬𫚕𫚖𫚭𫛭𫞩𫟅𫟦𫟹𫟼𫠆𫠊𫠜𫢸𫫇𫭟"
|
375 |
+
"𫭢𫭼𫮃𫰛𫵷𫶇𫷷𫸩𬀩𬀪𬂩𬃊𬇕𬇙𬇹𬉼𬊈𬊤𬌗𬍛𬍡𬍤𬒈𬒔𬒗𬕂𬘓𬘘𬘡𬘩𬘫𬘬𬘭𬘯𬙂𬙊𬙋𬜬𬜯𬞟"
|
376 |
+
"𬟁𬟽𬣙𬣞𬣡𬣳𬤇𬤊𬤝𬨂𬨎𬩽𬪩𬬩𬬭𬬮𬬱𬬸𬬹𬬻𬬿𬭁𬭊𬭎𬭚𬭛𬭤𬭩𬭬𬭯𬭳𬭶𬭸𬭼𬮱𬮿𬯀𬯎𬱖𬱟"
|
377 |
+
"𬳵𬳶𬳽𬳿𬴂𬴃𬴊𬶋𬶍𬶏𬶐𬶟𬶠𬶨𬶭𬶮𬷕𬸘𬸚𬸣𬸦𬸪𬹼𬺈𬺓"
|
378 |
+
)
|
379 |
+
CN_CHARS_EXT = "吶诶屌囧飚屄"
|
380 |
+
|
381 |
+
CN_CHARS = CN_CHARS_COMMON + CN_CHARS_EXT
|
382 |
+
IN_CH_CHARS = {c: True for c in CN_CHARS}
|
383 |
+
|
384 |
+
EN_CHARS = string.ascii_letters + string.digits
|
385 |
+
IN_EN_CHARS = {c: True for c in EN_CHARS}
|
386 |
+
|
387 |
+
VALID_CHARS = CN_CHARS + EN_CHARS + " "
|
388 |
+
IN_VALID_CHARS = {c: True for c in VALID_CHARS}
|
389 |
+
|
390 |
+
|
391 |
+
# ================================================================================ #
|
392 |
+
# basic class
|
393 |
+
# ================================================================================ #
|
394 |
+
class ChineseChar(object):
|
395 |
+
"""
|
396 |
+
中文字符
|
397 |
+
每个字符对应简体和繁体,
|
398 |
+
e.g. 简体 = '负', 繁体 = '負'
|
399 |
+
转换时可转换为简体或繁体
|
400 |
+
"""
|
401 |
+
|
402 |
+
def __init__(self, simplified, traditional):
|
403 |
+
self.simplified = simplified
|
404 |
+
self.traditional = traditional
|
405 |
+
# self.__repr__ = self.__str__
|
406 |
+
|
407 |
+
def __str__(self):
|
408 |
+
return self.simplified or self.traditional or None
|
409 |
+
|
410 |
+
def __repr__(self):
|
411 |
+
return self.__str__()
|
412 |
+
|
413 |
+
|
414 |
+
class ChineseNumberUnit(ChineseChar):
|
415 |
+
"""
|
416 |
+
中文数字/数位字符
|
417 |
+
每个字符除繁简体外还有一个额外的大写字符
|
418 |
+
e.g. '陆' 和 '陸'
|
419 |
+
"""
|
420 |
+
|
421 |
+
def __init__(self, power, simplified, traditional, big_s, big_t):
|
422 |
+
super(ChineseNumberUnit, self).__init__(simplified, traditional)
|
423 |
+
self.power = power
|
424 |
+
self.big_s = big_s
|
425 |
+
self.big_t = big_t
|
426 |
+
|
427 |
+
def __str__(self):
|
428 |
+
return "10^{}".format(self.power)
|
429 |
+
|
430 |
+
@classmethod
|
431 |
+
def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
|
432 |
+
if small_unit:
|
433 |
+
return ChineseNumberUnit(
|
434 |
+
power=index + 1, simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1]
|
435 |
+
)
|
436 |
+
elif numbering_type == NUMBERING_TYPES[0]:
|
437 |
+
return ChineseNumberUnit(
|
438 |
+
power=index + 8, simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]
|
439 |
+
)
|
440 |
+
elif numbering_type == NUMBERING_TYPES[1]:
|
441 |
+
return ChineseNumberUnit(
|
442 |
+
power=(index + 2) * 4, simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]
|
443 |
+
)
|
444 |
+
elif numbering_type == NUMBERING_TYPES[2]:
|
445 |
+
return ChineseNumberUnit(
|
446 |
+
power=pow(2, index + 3), simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]
|
447 |
+
)
|
448 |
+
else:
|
449 |
+
raise ValueError("Counting type should be in {0} ({1} provided).".format(NUMBERING_TYPES, numbering_type))
|
450 |
+
|
451 |
+
|
452 |
+
class ChineseNumberDigit(ChineseChar):
|
453 |
+
"""
|
454 |
+
中文数字字符
|
455 |
+
"""
|
456 |
+
|
457 |
+
def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
|
458 |
+
super(ChineseNumberDigit, self).__init__(simplified, traditional)
|
459 |
+
self.value = value
|
460 |
+
self.big_s = big_s
|
461 |
+
self.big_t = big_t
|
462 |
+
self.alt_s = alt_s
|
463 |
+
self.alt_t = alt_t
|
464 |
+
|
465 |
+
def __str__(self):
|
466 |
+
return str(self.value)
|
467 |
+
|
468 |
+
@classmethod
|
469 |
+
def create(cls, i, v):
|
470 |
+
return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
|
471 |
+
|
472 |
+
|
473 |
+
class ChineseMath(ChineseChar):
|
474 |
+
"""
|
475 |
+
中文数位字符
|
476 |
+
"""
|
477 |
+
|
478 |
+
def __init__(self, simplified, traditional, symbol, expression=None):
|
479 |
+
super(ChineseMath, self).__init__(simplified, traditional)
|
480 |
+
self.symbol = symbol
|
481 |
+
self.expression = expression
|
482 |
+
self.big_s = simplified
|
483 |
+
self.big_t = traditional
|
484 |
+
|
485 |
+
|
486 |
+
CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
|
487 |
+
|
488 |
+
|
489 |
+
class NumberSystem(object):
|
490 |
+
"""
|
491 |
+
中文数字系统
|
492 |
+
"""
|
493 |
+
|
494 |
+
pass
|
495 |
+
|
496 |
+
|
497 |
+
class MathSymbol(object):
|
498 |
+
"""
|
499 |
+
用于中文数字系统的数学符号 (繁/简体), e.g.
|
500 |
+
positive = ['正', '正']
|
501 |
+
negative = ['负', '負']
|
502 |
+
point = ['点', '點']
|
503 |
+
"""
|
504 |
+
|
505 |
+
def __init__(self, positive, negative, point):
|
506 |
+
self.positive = positive
|
507 |
+
self.negative = negative
|
508 |
+
self.point = point
|
509 |
+
|
510 |
+
def __iter__(self):
|
511 |
+
for v in self.__dict__.values():
|
512 |
+
yield v
|
513 |
+
|
514 |
+
|
515 |
+
# class OtherSymbol(object):
|
516 |
+
# """
|
517 |
+
# 其他符号
|
518 |
+
# """
|
519 |
+
#
|
520 |
+
# def __init__(self, sil):
|
521 |
+
# self.sil = sil
|
522 |
+
#
|
523 |
+
# def __iter__(self):
|
524 |
+
# for v in self.__dict__.values():
|
525 |
+
# yield v
|
526 |
+
|
527 |
+
|
528 |
+
# ================================================================================ #
|
529 |
+
# basic utils
|
530 |
+
# ================================================================================ #
|
531 |
+
def create_system(numbering_type=NUMBERING_TYPES[1]):
|
532 |
+
"""
|
533 |
+
根据数字系统类型返回创建相应的数字系统,默认为 mid
|
534 |
+
NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型
|
535 |
+
low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc.
|
536 |
+
mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
|
537 |
+
high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
|
538 |
+
返回对应的数字系统
|
539 |
+
"""
|
540 |
+
|
541 |
+
# chinese number units of '亿' and larger
|
542 |
+
all_larger_units = zip(LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
|
543 |
+
larger_units = [CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)]
|
544 |
+
# chinese number units of '十, 百, 千, 万'
|
545 |
+
all_smaller_units = zip(SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
|
546 |
+
smaller_units = [CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)]
|
547 |
+
# digis
|
548 |
+
chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS, BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
|
549 |
+
digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
|
550 |
+
digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
|
551 |
+
digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
|
552 |
+
digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
|
553 |
+
|
554 |
+
# symbols
|
555 |
+
positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
|
556 |
+
negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
|
557 |
+
point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
|
558 |
+
# sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
|
559 |
+
system = NumberSystem()
|
560 |
+
system.units = smaller_units + larger_units
|
561 |
+
system.digits = digits
|
562 |
+
system.math = MathSymbol(positive_cn, negative_cn, point_cn)
|
563 |
+
# system.symbols = OtherSymbol(sil_cn)
|
564 |
+
return system
|
565 |
+
|
566 |
+
|
567 |
+
def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
|
568 |
+
def get_symbol(char, system):
|
569 |
+
for u in system.units:
|
570 |
+
if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
|
571 |
+
return u
|
572 |
+
for d in system.digits:
|
573 |
+
if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
|
574 |
+
return d
|
575 |
+
for m in system.math:
|
576 |
+
if char in [m.traditional, m.simplified]:
|
577 |
+
return m
|
578 |
+
|
579 |
+
def string2symbols(chinese_string, system):
|
580 |
+
int_string, dec_string = chinese_string, ""
|
581 |
+
for p in [system.math.point.simplified, system.math.point.traditional]:
|
582 |
+
if p in chinese_string:
|
583 |
+
int_string, dec_string = chinese_string.split(p)
|
584 |
+
break
|
585 |
+
return [get_symbol(c, system) for c in int_string], [get_symbol(c, system) for c in dec_string]
|
586 |
+
|
587 |
+
def correct_symbols(integer_symbols, system):
|
588 |
+
"""
|
589 |
+
一百八 to 一百八十
|
590 |
+
一亿一千三百万 to 一亿 一千万 三百万
|
591 |
+
"""
|
592 |
+
|
593 |
+
if integer_symbols and isinstance(integer_symbols[0], CNU):
|
594 |
+
if integer_symbols[0].power == 1:
|
595 |
+
integer_symbols = [system.digits[1]] + integer_symbols
|
596 |
+
|
597 |
+
if len(integer_symbols) > 1:
|
598 |
+
if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
|
599 |
+
integer_symbols.append(CNU(integer_symbols[-2].power - 1, None, None, None, None))
|
600 |
+
|
601 |
+
result = []
|
602 |
+
unit_count = 0
|
603 |
+
for s in integer_symbols:
|
604 |
+
if isinstance(s, CND):
|
605 |
+
result.append(s)
|
606 |
+
unit_count = 0
|
607 |
+
elif isinstance(s, CNU):
|
608 |
+
current_unit = CNU(s.power, None, None, None, None)
|
609 |
+
unit_count += 1
|
610 |
+
|
611 |
+
if unit_count == 1:
|
612 |
+
result.append(current_unit)
|
613 |
+
elif unit_count > 1:
|
614 |
+
for i in range(len(result)):
|
615 |
+
if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
|
616 |
+
result[-i - 1] = CNU(result[-i - 1].power + current_unit.power, None, None, None, None)
|
617 |
+
return result
|
618 |
+
|
619 |
+
def compute_value(integer_symbols):
|
620 |
+
"""
|
621 |
+
Compute the value.
|
622 |
+
When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
|
623 |
+
e.g. '两千万' = 2000 * 10000 not 2000 + 10000
|
624 |
+
"""
|
625 |
+
value = [0]
|
626 |
+
last_power = 0
|
627 |
+
for s in integer_symbols:
|
628 |
+
if isinstance(s, CND):
|
629 |
+
value[-1] = s.value
|
630 |
+
elif isinstance(s, CNU):
|
631 |
+
value[-1] *= pow(10, s.power)
|
632 |
+
if s.power > last_power:
|
633 |
+
value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
|
634 |
+
last_power = s.power
|
635 |
+
value.append(0)
|
636 |
+
return sum(value)
|
637 |
+
|
638 |
+
system = create_system(numbering_type)
|
639 |
+
int_part, dec_part = string2symbols(chinese_string, system)
|
640 |
+
int_part = correct_symbols(int_part, system)
|
641 |
+
int_str = str(compute_value(int_part))
|
642 |
+
dec_str = "".join([str(d.value) for d in dec_part])
|
643 |
+
if dec_part:
|
644 |
+
return "{0}.{1}".format(int_str, dec_str)
|
645 |
+
else:
|
646 |
+
return int_str
|
647 |
+
|
648 |
+
|
649 |
+
def num2chn(
|
650 |
+
number_string,
|
651 |
+
numbering_type=NUMBERING_TYPES[1],
|
652 |
+
big=False,
|
653 |
+
traditional=False,
|
654 |
+
alt_zero=False,
|
655 |
+
alt_one=False,
|
656 |
+
alt_two=True,
|
657 |
+
use_zeros=True,
|
658 |
+
use_units=True,
|
659 |
+
):
|
660 |
+
def get_value(value_string, use_zeros=True):
|
661 |
+
striped_string = value_string.lstrip("0")
|
662 |
+
|
663 |
+
# record nothing if all zeros
|
664 |
+
if not striped_string:
|
665 |
+
return []
|
666 |
+
|
667 |
+
# record one digits
|
668 |
+
elif len(striped_string) == 1:
|
669 |
+
if use_zeros and len(value_string) != len(striped_string):
|
670 |
+
return [system.digits[0], system.digits[int(striped_string)]]
|
671 |
+
else:
|
672 |
+
return [system.digits[int(striped_string)]]
|
673 |
+
|
674 |
+
# recursively record multiple digits
|
675 |
+
else:
|
676 |
+
result_unit = next(u for u in reversed(system.units) if u.power < len(striped_string))
|
677 |
+
result_string = value_string[: -result_unit.power]
|
678 |
+
return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power :])
|
679 |
+
|
680 |
+
system = create_system(numbering_type)
|
681 |
+
|
682 |
+
int_dec = number_string.split(".")
|
683 |
+
if len(int_dec) == 1:
|
684 |
+
int_string = int_dec[0]
|
685 |
+
dec_string = ""
|
686 |
+
elif len(int_dec) == 2:
|
687 |
+
int_string = int_dec[0]
|
688 |
+
dec_string = int_dec[1]
|
689 |
+
else:
|
690 |
+
raise ValueError("invalid input num string with more than one dot: {}".format(number_string))
|
691 |
+
|
692 |
+
if use_units and len(int_string) > 1:
|
693 |
+
result_symbols = get_value(int_string)
|
694 |
+
else:
|
695 |
+
result_symbols = [system.digits[int(c)] for c in int_string]
|
696 |
+
dec_symbols = [system.digits[int(c)] for c in dec_string]
|
697 |
+
if dec_string:
|
698 |
+
result_symbols += [system.math.point] + dec_symbols
|
699 |
+
|
700 |
+
if alt_two:
|
701 |
+
liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t, system.digits[2].big_s, system.digits[2].big_t)
|
702 |
+
for i, v in enumerate(result_symbols):
|
703 |
+
if isinstance(v, CND) and v.value == 2:
|
704 |
+
next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None
|
705 |
+
previous_symbol = result_symbols[i - 1] if i > 0 else None
|
706 |
+
if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
|
707 |
+
if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
|
708 |
+
result_symbols[i] = liang
|
709 |
+
|
710 |
+
# if big is True, '两' will not be used and `alt_two` has no impact on output
|
711 |
+
if big:
|
712 |
+
attr_name = "big_"
|
713 |
+
if traditional:
|
714 |
+
attr_name += "t"
|
715 |
+
else:
|
716 |
+
attr_name += "s"
|
717 |
+
else:
|
718 |
+
if traditional:
|
719 |
+
attr_name = "traditional"
|
720 |
+
else:
|
721 |
+
attr_name = "simplified"
|
722 |
+
|
723 |
+
result = "".join([getattr(s, attr_name) for s in result_symbols])
|
724 |
+
|
725 |
+
# if not use_zeros:
|
726 |
+
# result = result.strip(getattr(system.digits[0], attr_name))
|
727 |
+
|
728 |
+
if alt_zero:
|
729 |
+
result = result.replace(getattr(system.digits[0], attr_name), system.digits[0].alt_s)
|
730 |
+
|
731 |
+
if alt_one:
|
732 |
+
result = result.replace(getattr(system.digits[1], attr_name), system.digits[1].alt_s)
|
733 |
+
|
734 |
+
for i, p in enumerate(POINT):
|
735 |
+
if result.startswith(p):
|
736 |
+
return CHINESE_DIGIS[0] + result
|
737 |
+
|
738 |
+
# ^10, 11, .., 19
|
739 |
+
if (
|
740 |
+
len(result) >= 2
|
741 |
+
and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]]
|
742 |
+
and result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]
|
743 |
+
):
|
744 |
+
result = result[1:]
|
745 |
+
|
746 |
+
return result
|
747 |
+
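A couple of worked conversions for the two helpers above (a sketch assuming the module is importable as models.lyrics_utils.zh_num2words; the outputs follow from the default "mid" numbering system, the trailing-unit padding in correct_symbols, and the '两'/leading-'一十' rules in num2chn):

from models.lyrics_utils.zh_num2words import chn2num, num2chn

print(chn2num("两千万"))    # "20000000"  (2000 * 10000, per compute_value above)
print(chn2num("一百八"))    # "180"       (trailing 八 is promoted to 八十)
print(num2chn("15"))        # "十五"      (leading 一 dropped for 10..19)
print(num2chn("20000000"))  # "两千万"    (lone 2 before a large unit becomes 两)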
|
748 |
+
|
749 |
+
# ================================================================================ #
|
750 |
+
# different types of rewriters
|
751 |
+
# ================================================================================ #
|
752 |
+
class Cardinal:
|
753 |
+
"""
|
754 |
+
CARDINAL类
|
755 |
+
"""
|
756 |
+
|
757 |
+
def __init__(self, cardinal=None, chntext=None):
|
758 |
+
self.cardinal = cardinal
|
759 |
+
self.chntext = chntext
|
760 |
+
|
761 |
+
def chntext2cardinal(self):
|
762 |
+
return chn2num(self.chntext)
|
763 |
+
|
764 |
+
def cardinal2chntext(self):
|
765 |
+
return num2chn(self.cardinal)
|
766 |
+
|
767 |
+
|
768 |
+
class Digit:
|
769 |
+
"""
|
770 |
+
DIGIT类
|
771 |
+
"""
|
772 |
+
|
773 |
+
def __init__(self, digit=None, chntext=None):
|
774 |
+
self.digit = digit
|
775 |
+
self.chntext = chntext
|
776 |
+
|
777 |
+
# def chntext2digit(self):
|
778 |
+
# return chn2num(self.chntext)
|
779 |
+
|
780 |
+
def digit2chntext(self):
|
781 |
+
return num2chn(self.digit, alt_two=False, use_units=False)
|
782 |
+
|
783 |
+
|
784 |
+
class TelePhone:
|
785 |
+
"""
|
786 |
+
TELEPHONE类
|
787 |
+
"""
|
788 |
+
|
789 |
+
def __init__(self, telephone=None, raw_chntext=None, chntext=None):
|
790 |
+
self.telephone = telephone
|
791 |
+
self.raw_chntext = raw_chntext
|
792 |
+
self.chntext = chntext
|
793 |
+
|
794 |
+
# def chntext2telephone(self):
|
795 |
+
# sil_parts = self.raw_chntext.split('<SIL>')
|
796 |
+
# self.telephone = '-'.join([
|
797 |
+
# str(chn2num(p)) for p in sil_parts
|
798 |
+
# ])
|
799 |
+
# return self.telephone
|
800 |
+
|
801 |
+
def telephone2chntext(self, fixed=False):
|
802 |
+
if fixed:
|
803 |
+
sil_parts = self.telephone.split("-")
|
804 |
+
self.raw_chntext = "<SIL>".join([num2chn(part, alt_two=False, use_units=False) for part in sil_parts])
|
805 |
+
self.chntext = self.raw_chntext.replace("<SIL>", "")
|
806 |
+
else:
|
807 |
+
sp_parts = self.telephone.strip("+").split()
|
808 |
+
self.raw_chntext = "<SP>".join([num2chn(part, alt_two=False, use_units=False) for part in sp_parts])
|
809 |
+
self.chntext = self.raw_chntext.replace("<SP>", "")
|
810 |
+
return self.chntext
|
811 |
+
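telephone2chntext reads the number digit by digit (num2chn with use_units=False) rather than as one cardinal; a small sketch (import path assumed):

from models.lyrics_utils.zh_num2words import TelePhone

t = TelePhone(telephone="139 1234 5678")
print(t.telephone2chntext())  # "一三九一二三四五六七八" (every digit read separately)
print(t.raw_chntext)          # same reading with "<SP>" markers between groups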
|
812 |
+
|
813 |
+
class Fraction:
|
814 |
+
"""
|
815 |
+
FRACTION类
|
816 |
+
"""
|
817 |
+
|
818 |
+
def __init__(self, fraction=None, chntext=None):
|
819 |
+
self.fraction = fraction
|
820 |
+
self.chntext = chntext
|
821 |
+
|
822 |
+
def chntext2fraction(self):
|
823 |
+
denominator, numerator = self.chntext.split("分之")
|
824 |
+
return chn2num(numerator) + "/" + chn2num(denominator)
|
825 |
+
|
826 |
+
def fraction2chntext(self):
|
827 |
+
numerator, denominator = self.fraction.split("/")
|
828 |
+
return num2chn(denominator) + "分之" + num2chn(numerator)
|
829 |
+
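Note the argument order in the two methods: the Chinese reading puts the denominator first, i.e. "X分之Y" means Y/X. A one-line sketch each way (import path assumed):

from models.lyrics_utils.zh_num2words import Fraction

print(Fraction(fraction="3/4").fraction2chntext())      # -> "四分之三"
print(Fraction(chntext="四分之三").chntext2fraction())  # -> "3/4"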
|
830 |
+
|
831 |
+
class Date:
|
832 |
+
"""
|
833 |
+
DATE类
|
834 |
+
"""
|
835 |
+
|
836 |
+
def __init__(self, date=None, chntext=None):
|
837 |
+
self.date = date
|
838 |
+
self.chntext = chntext
|
839 |
+
|
840 |
+
# def chntext2date(self):
|
841 |
+
# chntext = self.chntext
|
842 |
+
# try:
|
843 |
+
# year, other = chntext.strip().split('年', maxsplit=1)
|
844 |
+
# year = Digit(chntext=year).digit2chntext() + '年'
|
845 |
+
# except ValueError:
|
846 |
+
# other = chntext
|
847 |
+
# year = ''
|
848 |
+
# if other:
|
849 |
+
# try:
|
850 |
+
# month, day = other.strip().split('月', maxsplit=1)
|
851 |
+
# month = Cardinal(chntext=month).chntext2cardinal() + '月'
|
852 |
+
# except ValueError:
|
853 |
+
# day = chntext
|
854 |
+
# month = ''
|
855 |
+
# if day:
|
856 |
+
# day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
|
857 |
+
# else:
|
858 |
+
# month = ''
|
859 |
+
# day = ''
|
860 |
+
# date = year + month + day
|
861 |
+
# self.date = date
|
862 |
+
# return self.date
|
863 |
+
|
864 |
+
def date2chntext(self):
|
865 |
+
date = self.date
|
866 |
+
try:
|
867 |
+
year, other = date.strip().split("年", 1)
|
868 |
+
year = Digit(digit=year).digit2chntext() + "年"
|
869 |
+
except ValueError:
|
870 |
+
other = date
|
871 |
+
year = ""
|
872 |
+
if other:
|
873 |
+
try:
|
874 |
+
month, day = other.strip().split("月", 1)
|
875 |
+
month = Cardinal(cardinal=month).cardinal2chntext() + "月"
|
876 |
+
except ValueError:
|
877 |
+
day = date
|
878 |
+
month = ""
|
879 |
+
if day:
|
880 |
+
day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
|
881 |
+
else:
|
882 |
+
month = ""
|
883 |
+
day = ""
|
884 |
+
chntext = year + month + day
|
885 |
+
self.chntext = chntext
|
886 |
+
return self.chntext
|
887 |
+
|
888 |
+
|
889 |
+
class Money:
|
890 |
+
"""
|
891 |
+
MONEY类
|
892 |
+
"""
|
893 |
+
|
894 |
+
def __init__(self, money=None, chntext=None):
|
895 |
+
self.money = money
|
896 |
+
self.chntext = chntext
|
897 |
+
|
898 |
+
# def chntext2money(self):
|
899 |
+
# return self.money
|
900 |
+
|
901 |
+
def money2chntext(self):
|
902 |
+
money = self.money
|
903 |
+
pattern = re.compile(r"(\d+(\.\d+)?)")
|
904 |
+
matchers = pattern.findall(money)
|
905 |
+
if matchers:
|
906 |
+
for matcher in matchers:
|
907 |
+
money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
|
908 |
+
self.chntext = money
|
909 |
+
return self.chntext
|
910 |
+
|
911 |
+
|
912 |
+
class Percentage:
|
913 |
+
"""
|
914 |
+
PERCENTAGE类
|
915 |
+
"""
|
916 |
+
|
917 |
+
def __init__(self, percentage=None, chntext=None):
|
918 |
+
self.percentage = percentage
|
919 |
+
self.chntext = chntext
|
920 |
+
|
921 |
+
def chntext2percentage(self):
|
922 |
+
return chn2num(self.chntext.strip().strip("百分之")) + "%"
|
923 |
+
|
924 |
+
def percentage2chntext(self):
|
925 |
+
return "百分之" + num2chn(self.percentage.strip().strip("%"))
|
926 |
+
|
927 |
+
|
928 |
+
def normalize_nsw(raw_text):
|
929 |
+
text = "^" + raw_text + "$"
|
930 |
+
|
931 |
+
# 规范化日期
|
932 |
+
pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
|
933 |
+
matchers = pattern.findall(text)
|
934 |
+
if matchers:
|
935 |
+
# print('date')
|
936 |
+
for matcher in matchers:
|
937 |
+
text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
|
938 |
+
|
939 |
+
# 规范化金钱
|
940 |
+
pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
|
941 |
+
matchers = pattern.findall(text)
|
942 |
+
if matchers:
|
943 |
+
# print('money')
|
944 |
+
for matcher in matchers:
|
945 |
+
text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
|
946 |
+
|
947 |
+
# 规范化固话/手机号码
|
948 |
+
# 手机
|
949 |
+
# http://www.jihaoba.com/news/show/13680
|
950 |
+
# 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
|
951 |
+
# 联通:130、131、132、156、155、186、185、176
|
952 |
+
# 电信:133、153、189、180、181、177
|
953 |
+
pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
|
954 |
+
matchers = pattern.findall(text)
|
955 |
+
if matchers:
|
956 |
+
# print('telephone')
|
957 |
+
for matcher in matchers:
|
958 |
+
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
|
959 |
+
# 固话
|
960 |
+
pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
|
961 |
+
matchers = pattern.findall(text)
|
962 |
+
if matchers:
|
963 |
+
# print('fixed telephone')
|
964 |
+
for matcher in matchers:
|
965 |
+
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
|
966 |
+
|
967 |
+
# 规范化分数
|
968 |
+
pattern = re.compile(r"(\d+/\d+)")
|
969 |
+
matchers = pattern.findall(text)
|
970 |
+
if matchers:
|
971 |
+
# print('fraction')
|
972 |
+
for matcher in matchers:
|
973 |
+
text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
|
974 |
+
|
975 |
+
# 规范化百分数
|
976 |
+
text = text.replace("%", "%")
|
977 |
+
pattern = re.compile(r"(\d+(\.\d+)?%)")
|
978 |
+
matchers = pattern.findall(text)
|
979 |
+
if matchers:
|
980 |
+
# print('percentage')
|
981 |
+
for matcher in matchers:
|
982 |
+
text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
|
983 |
+
|
984 |
+
# 规范化纯数+量词
|
985 |
+
pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
|
986 |
+
matchers = pattern.findall(text)
|
987 |
+
if matchers:
|
988 |
+
# print('cardinal+quantifier')
|
989 |
+
for matcher in matchers:
|
990 |
+
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
|
991 |
+
|
992 |
+
# 规范化数字编号
|
993 |
+
pattern = re.compile(r"(\d{4,32})")
|
994 |
+
matchers = pattern.findall(text)
|
995 |
+
if matchers:
|
996 |
+
# print('digit')
|
997 |
+
for matcher in matchers:
|
998 |
+
text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
|
999 |
+
|
1000 |
+
# 规范化纯数
|
1001 |
+
pattern = re.compile(r"(\d+(\.\d+)?)")
|
1002 |
+
matchers = pattern.findall(text)
|
1003 |
+
if matchers:
|
1004 |
+
# print('cardinal')
|
1005 |
+
for matcher in matchers:
|
1006 |
+
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
|
1007 |
+
|
1008 |
+
# restore P2P, O2O, B2C, B2B etc
|
1009 |
+
pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
|
1010 |
+
matchers = pattern.findall(text)
|
1011 |
+
if matchers:
|
1012 |
+
# print('particular')
|
1013 |
+
for matcher in matchers:
|
1014 |
+
text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
|
1015 |
+
|
1016 |
+
return text.lstrip("^").rstrip("$")
|
1017 |
+
|
1018 |
+
|
1019 |
+
def remove_erhua(text):
|
1020 |
+
"""
|
1021 |
+
去除儿化音词中的儿:
|
1022 |
+
他女儿在那边儿 -> 他女儿在那边
|
1023 |
+
"""
|
1024 |
+
|
1025 |
+
new_str = ""
|
1026 |
+
while re.search("儿", text):
|
1027 |
+
a = re.search("儿", text).span()
|
1028 |
+
remove_er_flag = 0
|
1029 |
+
|
1030 |
+
if ER_WHITELIST_PATTERN.search(text):
|
1031 |
+
b = ER_WHITELIST_PATTERN.search(text).span()
|
1032 |
+
if b[0] <= a[0]:
|
1033 |
+
remove_er_flag = 1
|
1034 |
+
|
1035 |
+
if remove_er_flag == 0:
|
1036 |
+
new_str = new_str + text[0 : a[0]]
|
1037 |
+
text = text[a[1] :]
|
1038 |
+
else:
|
1039 |
+
new_str = new_str + text[0 : b[1]]
|
1040 |
+
text = text[b[1] :]
|
1041 |
+
|
1042 |
+
text = new_str + text
|
1043 |
+
return text
|
1044 |
+
|
1045 |
+
|
1046 |
+
def remove_space(text):
|
1047 |
+
tokens = text.split()
|
1048 |
+
new = []
|
1049 |
+
for k, t in enumerate(tokens):
|
1050 |
+
if k != 0:
|
1051 |
+
if IN_EN_CHARS.get(tokens[k - 1][-1]) and IN_EN_CHARS.get(t[0]):
|
1052 |
+
new.append(" ")
|
1053 |
+
new.append(t)
|
1054 |
+
return "".join(new)
|
1055 |
+
|
1056 |
+
|
1057 |
+
class TextNorm:
|
1058 |
+
def __init__(
|
1059 |
+
self,
|
1060 |
+
to_banjiao: bool = False,
|
1061 |
+
to_upper: bool = False,
|
1062 |
+
to_lower: bool = False,
|
1063 |
+
remove_fillers: bool = False,
|
1064 |
+
remove_erhua: bool = False,
|
1065 |
+
check_chars: bool = False,
|
1066 |
+
remove_space: bool = False,
|
1067 |
+
cc_mode: str = "",
|
1068 |
+
):
|
1069 |
+
self.to_banjiao = to_banjiao
|
1070 |
+
self.to_upper = to_upper
|
1071 |
+
self.to_lower = to_lower
|
1072 |
+
self.remove_fillers = remove_fillers
|
1073 |
+
self.remove_erhua = remove_erhua
|
1074 |
+
self.check_chars = check_chars
|
1075 |
+
self.remove_space = remove_space
|
1076 |
+
|
1077 |
+
self.cc = None
|
1078 |
+
if cc_mode:
|
1079 |
+
from opencc import OpenCC # Open Chinese Convert: pip install opencc
|
1080 |
+
|
1081 |
+
self.cc = OpenCC(cc_mode)
|
1082 |
+
|
1083 |
+
def __call__(self, text):
|
1084 |
+
if self.cc:
|
1085 |
+
text = self.cc.convert(text)
|
1086 |
+
|
1087 |
+
if self.to_banjiao:
|
1088 |
+
text = text.translate(QJ2BJ_TRANSFORM)
|
1089 |
+
|
1090 |
+
if self.to_upper:
|
1091 |
+
text = text.upper()
|
1092 |
+
|
1093 |
+
if self.to_lower:
|
1094 |
+
text = text.lower()
|
1095 |
+
|
1096 |
+
if self.remove_fillers:
|
1097 |
+
for c in FILLER_CHARS:
|
1098 |
+
text = text.replace(c, "")
|
1099 |
+
|
1100 |
+
if self.remove_erhua:
|
1101 |
+
text = remove_erhua(text)
|
1102 |
+
|
1103 |
+
text = normalize_nsw(text)
|
1104 |
+
|
1105 |
+
text = text.translate(PUNCS_TRANSFORM)
|
1106 |
+
|
1107 |
+
if self.check_chars:
|
1108 |
+
for c in text:
|
1109 |
+
if not IN_VALID_CHARS.get(c):
|
1110 |
+
print(f"WARNING: illegal char {c} in: {text}", file=sys.stderr)
|
1111 |
+
return ""
|
1112 |
+
|
1113 |
+
if self.remove_space:
|
1114 |
+
text = remove_space(text)
|
1115 |
+
|
1116 |
+
return text
|
1117 |
+
|
1118 |
+
|
1119 |
+
if __name__ == "__main__":
|
1120 |
+
p = argparse.ArgumentParser()
|
1121 |
+
|
1122 |
+
# normalizer options
|
1123 |
+
p.add_argument("--to_banjiao", action="store_true", help="convert quanjiao chars to banjiao")
|
1124 |
+
p.add_argument("--to_upper", action="store_true", help="convert to upper case")
|
1125 |
+
p.add_argument("--to_lower", action="store_true", help="convert to lower case")
|
1126 |
+
p.add_argument("--remove_fillers", action="store_true", help='remove filler chars such as "呃, 啊"')
|
1127 |
+
p.add_argument("--remove_erhua", action="store_true", help='remove erhua chars such as "他女儿在那边儿 -> 他女儿在那边"')
|
1128 |
+
p.add_argument("--check_chars", action="store_true", help="skip sentences containing illegal chars")
|
1129 |
+
p.add_argument("--remove_space", action="store_true", help="remove whitespace")
|
1130 |
+
p.add_argument(
|
1131 |
+
"--cc_mode", choices=["", "t2s", "s2t"], default="", help="convert between traditional to simplified"
|
1132 |
+
)
|
1133 |
+
|
1134 |
+
# I/O options
|
1135 |
+
p.add_argument("--log_interval", type=int, default=10000, help="log interval in number of processed lines")
|
1136 |
+
p.add_argument("--has_key", action="store_true", help="will be deprecated, set --format ark instead")
|
1137 |
+
p.add_argument("--format", type=str, choices=["txt", "ark", "tsv"], default="txt", help="input format")
|
1138 |
+
p.add_argument("ifile", help="input filename, assume utf-8 encoding")
|
1139 |
+
p.add_argument("ofile", help="output filename")
|
1140 |
+
|
1141 |
+
args = p.parse_args()
|
1142 |
+
|
1143 |
+
if args.has_key:
|
1144 |
+
args.format = "ark"
|
1145 |
+
|
1146 |
+
normalizer = TextNorm(
|
1147 |
+
to_banjiao=args.to_banjiao,
|
1148 |
+
to_upper=args.to_upper,
|
1149 |
+
to_lower=args.to_lower,
|
1150 |
+
remove_fillers=args.remove_fillers,
|
1151 |
+
remove_erhua=args.remove_erhua,
|
1152 |
+
check_chars=args.check_chars,
|
1153 |
+
remove_space=args.remove_space,
|
1154 |
+
cc_mode=args.cc_mode,
|
1155 |
+
)
|
1156 |
+
|
1157 |
+
normalizer = TextNorm(
|
1158 |
+
to_banjiao=args.to_banjiao,
|
1159 |
+
to_upper=args.to_upper,
|
1160 |
+
to_lower=args.to_lower,
|
1161 |
+
remove_fillers=args.remove_fillers,
|
1162 |
+
remove_erhua=args.remove_erhua,
|
1163 |
+
check_chars=args.check_chars,
|
1164 |
+
remove_space=args.remove_space,
|
1165 |
+
cc_mode=args.cc_mode,
|
1166 |
+
)
|
1167 |
+
|
1168 |
+
ndone = 0
|
1169 |
+
with open(args.ifile, "r", encoding="utf-8") as istream, open(args.ofile, "w+", encoding="utf-8") as ostream:
|
1170 |
+
if args.format == "tsv":
|
1171 |
+
reader = csv.DictReader(istream, delimiter="\t")
|
1172 |
+
assert "TEXT" in reader.fieldnames
|
1173 |
+
print("\t".join(reader.fieldnames), file=ostream)
|
1174 |
+
|
1175 |
+
for item in reader:
|
1176 |
+
text = item["TEXT"]
|
1177 |
+
|
1178 |
+
if text:
|
1179 |
+
text = normalizer(text)
|
1180 |
+
|
1181 |
+
if text:
|
1182 |
+
item["TEXT"] = text
|
1183 |
+
print("\t".join([item[f] for f in reader.fieldnames]), file=ostream)
|
1184 |
+
|
1185 |
+
ndone += 1
|
1186 |
+
if ndone % args.log_interval == 0:
|
1187 |
+
print(f"text norm: {ndone} lines done.", file=sys.stderr, flush=True)
|
1188 |
+
else:
|
1189 |
+
for l in istream:
|
1190 |
+
key, text = "", ""
|
1191 |
+
if args.format == "ark": # KALDI archive, line format: "key text"
|
1192 |
+
cols = l.strip().split(maxsplit=1)
|
1193 |
+
key, text = cols[0], cols[1] if len(cols) == 2 else ""
|
1194 |
+
else:
|
1195 |
+
text = l.strip()
|
1196 |
+
|
1197 |
+
if text:
|
1198 |
+
text = normalizer(text)
|
1199 |
+
|
1200 |
+
if text:
|
1201 |
+
if args.format == "ark":
|
1202 |
+
print(key + "\t" + text, file=ostream)
|
1203 |
+
else:
|
1204 |
+
print(text, file=ostream)
|
1205 |
+
|
1206 |
+
ndone += 1
|
1207 |
+
if ndone % args.log_interval == 0:
|
1208 |
+
print(f"text norm: {ndone} lines done.", file=sys.stderr, flush=True)
|
1209 |
+
print(f"text norm: {ndone} lines done in total.", file=sys.stderr, flush=True)
|
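A minimal usage sketch of the TextNorm normalizer above (assuming the file is importable as models.lyrics_utils.zh_num2words from the repo root; the sample sentence and option values are illustrative only):

    from models.lyrics_utils.zh_num2words import TextNorm

    # expand digits, dates, money and percentages into Chinese characters before lyric tokenization
    tn = TextNorm(to_banjiao=True, remove_erhua=False)
    print(tn("他2024年3月花了35.5元, 占预算的15%"))
    # numeric spans are verbalized, e.g. "15%" -> "百分之十五"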
music_dcae/__init__.py
ADDED
File without changes
music_dcae/music_dcae_pipeline.py
ADDED
@@ -0,0 +1,150 @@
import os
import torch
from diffusers import AutoencoderDC
import torchaudio
import torchvision.transforms as transforms
from diffusers.models.modeling_utils import ModelMixin
from diffusers.loaders import FromOriginalModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config


try:
    from .music_vocoder import ADaMoSHiFiGANV1
except ImportError:
    from music_vocoder import ADaMoSHiFiGANV1


root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DEFAULT_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_dcae_f8c8")
VOCODER_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_vocoder")


class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    @register_to_config
    def __init__(self, source_sample_rate=None, dcae_checkpoint_path=DEFAULT_PRETRAINED_PATH, vocoder_checkpoint_path=VOCODER_PRETRAINED_PATH):
        super(MusicDCAE, self).__init__()

        self.dcae = AutoencoderDC.from_pretrained(dcae_checkpoint_path)
        self.vocoder = ADaMoSHiFiGANV1.from_pretrained(vocoder_checkpoint_path)

        if source_sample_rate is None:
            source_sample_rate = 48000

        self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)

        self.transform = transforms.Compose([
            transforms.Normalize(0.5, 0.5),
        ])
        self.min_mel_value = -11.0
        self.max_mel_value = 3.0
        self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
        self.mel_chunk_size = 1024
        self.time_dimention_multiple = 8
        self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
        self.scale_factor = 0.1786
        self.shift_factor = -1.9091

    def load_audio(self, audio_path):
        audio, sr = torchaudio.load(audio_path)
        return audio, sr

    def forward_mel(self, audios):
        mels = []
        for i in range(len(audios)):
            image = self.vocoder.mel_transform(audios[i])
            mels.append(image)
        mels = torch.stack(mels)
        return mels

    @torch.no_grad()
    def encode(self, audios, audio_lengths=None, sr=None):
        if audio_lengths is None:
            audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
        audio_lengths = audio_lengths.to(audios.device)

        # audios: N x 2 x T, 48kHz
        device = audios.device
        dtype = audios.dtype

        if sr is None:
            sr = 48000
            resampler = self.resampler
        else:
            resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)

        audio = resampler(audios)

        max_audio_len = audio.shape[-1]
        if max_audio_len % (8 * 512) != 0:
            audio = torch.nn.functional.pad(audio, (0, 8 * 512 - max_audio_len % (8 * 512)))

        mels = self.forward_mel(audio)
        mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
        mels = self.transform(mels)
        latents = []
        for mel in mels:
            latent = self.dcae.encoder(mel.unsqueeze(0))
            latents.append(latent)
        latents = torch.cat(latents, dim=0)
        latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
        latents = (latents - self.shift_factor) * self.scale_factor
        return latents, latent_lengths

    @torch.no_grad()
    def decode(self, latents, audio_lengths=None, sr=None):
        latents = latents / self.scale_factor + self.shift_factor

        mels = []

        for latent in latents:
            mel = self.dcae.decoder(latent.unsqueeze(0))
            mels.append(mel)
        mels = torch.cat(mels, dim=0)

        mels = mels * 0.5 + 0.5
        mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
        bsz, channels, num_mel, mel_width = mels.shape
        pred_wavs = []
        for i in range(bsz):
            mel = mels[i]
            wav = self.vocoder.decode(mel).squeeze(1)
            pred_wavs.append(wav)

        pred_wavs = torch.stack(pred_wavs)

        if sr is not None:
            resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
            pred_wavs = [resampler(wav) for wav in pred_wavs]
        else:
            sr = 44100
        if audio_lengths is not None:
            pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
        return sr, pred_wavs

    def forward(self, audios, audio_lengths=None, sr=None):
        latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
        sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
        return sr, pred_wavs, latents, latent_lengths


if __name__ == "__main__":

    audio, sr = torchaudio.load("test.wav")
    audio_lengths = torch.tensor([audio.shape[1]])
    audios = audio.unsqueeze(0)

    # test encode only
    model = MusicDCAE()
    # latents, latent_lengths = model.encode(audios, audio_lengths)
    # print("latents shape: ", latents.shape)
    # print("latent_lengths: ", latent_lengths)

    # test encode and decode
    sr, pred_wavs, latents, latent_lengths = model(audios, audio_lengths, sr)
    print("reconstructed wavs: ", pred_wavs[0].shape)
    print("latents shape: ", latents.shape)
    print("latent_lengths: ", latent_lengths)
    print("sr: ", sr)
    torchaudio.save("test_reconstructed.flac", pred_wavs[0], sr)
    print("test_reconstructed.flac")
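The latent length bookkeeping in encode() above follows directly from the fixed rates: audio is resampled to 44.1 kHz, the mel transform uses a hop of 512 samples, and the DCAE compresses the time axis by a further factor of 8. A small standalone sketch of that arithmetic (values are illustrative):

    # mirrors: latent_lengths = (audio_lengths / sr * 44100 / 512 / 8).long()
    def expected_latent_frames(num_samples: int, sr: int = 48000) -> int:
        seconds = num_samples / sr
        mel_frames = seconds * 44100 / 512    # mel hop length is 512 at 44.1 kHz
        return int(mel_frames / 8)            # DCAE time_dimention_multiple == 8

    print(expected_latent_frames(10 * 48000))  # 10 s of 48 kHz audio -> ~107 latent frames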
music_dcae/music_log_mel.py
ADDED
@@ -0,0 +1,107 @@
import torch
import torch.nn as nn
from torch import Tensor
from torchaudio.transforms import MelScale


class LinearSpectrogram(nn.Module):
    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.mode = mode

        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, y: Tensor) -> Tensor:
        if y.ndim == 3:
            y = y.squeeze(1)

        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (
                (self.win_length - self.hop_length) // 2,
                (self.win_length - self.hop_length + 1) // 2,
            ),
            mode="reflect",
        ).squeeze(1)
        dtype = y.dtype
        spec = torch.stft(
            y.float(),
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)

        if self.mode == "pow2_sqrt":
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
        spec = spec.to(dtype)
        return spec


class LogMelSpectrogram(nn.Module):
    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        self.f_max = f_max or sample_rate // 2

        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
        self.mel_scale = MelScale(
            self.n_mels,
            self.sample_rate,
            self.f_min,
            self.f_max,
            self.n_fft // 2 + 1,
            "slaney",
            "slaney",
        )

    def compress(self, x: Tensor) -> Tensor:
        return torch.log(torch.clamp(x, min=1e-5))

    def decompress(self, x: Tensor) -> Tensor:
        return torch.exp(x)

    def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
        linear = self.spectrogram(x)
        x = self.mel_scale(linear)
        x = self.compress(x)
        # print(x.shape)
        if return_linear:
            return x, self.compress(linear)

        return x
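A short shape check for LogMelSpectrogram under the defaults above (44.1 kHz, hop 512, 128 mel bins); with the symmetric pre-padding and center=False the frame count works out to roughly samples // hop_length. The import path is assumed from the repo layout:

    import torch
    from music_dcae.music_log_mel import LogMelSpectrogram  # import path assumed

    mel_fn = LogMelSpectrogram()       # sample_rate=44100, n_fft=2048, hop_length=512, n_mels=128
    wav = torch.randn(1, 44100)        # 1 second of audio
    mel = mel_fn(wav)
    print(mel.shape)                   # torch.Size([1, 128, 86]) -> ~44100 / 512 frames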
music_dcae/music_vocoder.py
ADDED
@@ -0,0 +1,576 @@
import librosa
import torch
from torch import nn

from functools import partial
from math import prod
from typing import Callable, Tuple, List

import numpy as np
import torch.nn.functional as F
from torch.nn import Conv1d
from torch.nn.utils import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
from diffusers.models.modeling_utils import ModelMixin
from diffusers.loaders import FromOriginalModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config


try:
    from music_log_mel import LogMelSpectrogram
except ImportError:
    from .music_log_mel import LogMelSpectrogram


def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """  # noqa: E501

    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""  # noqa: E501

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob,3):0.3f}"


class LayerNorm(nn.Module):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """  # noqa: E501

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None] * x + self.bias[:, None]
            return x


class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Dilation for depthwise conv. Default: 1.
    """  # noqa: E501

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()

        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=int(dilation * (kernel_size - 1) / 2),
            groups=dim,
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, int(mlp_ratio * dim)
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        input = x

        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        if self.gamma is not None:
            x = self.gamma * x

        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        x = self.drop_path(x)

        if apply_residual:
            x = input + x

        return x


class ParallelConvNeXtBlock(nn.Module):
    def __init__(self, kernel_sizes: List[int], *args, **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
                for kernel_size in kernel_sizes
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.stack(
            [block(x, apply_residual=False) for block in self.blocks] + [x],
            dim=1,
        ).sum(dim=1)


class ConvNeXtEncoder(nn.Module):
    def __init__(
        self,
        input_channels=3,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        kernel_sizes: Tuple[int] = (7,),
    ):
        super().__init__()
        assert len(depths) == len(dims)

        self.channel_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv1d(
                input_channels,
                dims[0],
                kernel_size=7,
                padding=3,
                padding_mode="replicate",
            ),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.channel_layers.append(stem)

        for i in range(len(depths) - 1):
            mid_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
            )
            self.channel_layers.append(mid_layer)

        block_fn = (
            partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
            if len(kernel_sizes) == 1
            else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
        )

        self.stages = nn.ModuleList()
        drop_path_rates = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]

        cur = 0
        for i in range(len(depths)):
            stage = nn.Sequential(
                *[
                    block_fn(
                        dim=dims[i],
                        drop_path=drop_path_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        for channel_layer, stage in zip(self.channel_layers, self.stages):
            x = channel_layer(x)
            x = stage(x)

        return self.norm(x)


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return (kernel_size * dilation - dilation) // 2


class ResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.silu(x)
            xt = c1(xt)
            xt = F.silu(xt)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for conv in self.convs1:
            remove_weight_norm(conv)
        for conv in self.convs2:
            remove_weight_norm(conv)


class HiFiGANGenerator(nn.Module):
    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: Tuple[Tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        use_template: bool = True,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()

        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        self.conv_pre = weight_norm(
            nn.Conv1d(
                num_mels,
                upsample_initial_channel,
                pre_conv_kernel_size,
                1,
                padding=get_padding(pre_conv_kernel_size),
            )
        )

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        self.noise_convs = nn.ModuleList()
        self.use_template = use_template
        self.ups = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

            if not use_template:
                continue

            if i + 1 < len(upsample_rates):
                stride_f0 = np.prod(upsample_rates[i + 1:])
                self.noise_convs.append(
                    Conv1d(
                        1,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(ResBlock1(ch, k, d))

        self.activation_post = post_activation()
        self.conv_post = weight_norm(
            nn.Conv1d(
                ch,
                1,
                post_conv_kernel_size,
                1,
                padding=get_padding(post_conv_kernel_size),
            )
        )
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, template=None):
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.use_template:
                x = x + self.noise_convs[i](template)

            xs = None

            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)

            x = xs / self.num_kernels

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class ADaMoSHiFiGANV1(ModelMixin, ConfigMixin, FromOriginalModelMixin):

    @register_to_config
    def __init__(
        self,
        input_channels: int = 128,
        depths: List[int] = [3, 3, 9, 3],
        dims: List[int] = [128, 256, 384, 512],
        drop_path_rate: float = 0.0,
        kernel_sizes: Tuple[int] = (7,),
        upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
        resblock_dilation_sizes: Tuple[Tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 512,
        upsample_initial_channel: int = 1024,
        use_template: bool = False,
        pre_conv_kernel_size: int = 13,
        post_conv_kernel_size: int = 13,
        sampling_rate: int = 44100,
        n_fft: int = 2048,
        win_length: int = 2048,
        hop_length: int = 512,
        f_min: int = 40,
        f_max: int = 16000,
        n_mels: int = 128,
    ):
        super().__init__()

        self.backbone = ConvNeXtEncoder(
            input_channels=input_channels,
            depths=depths,
            dims=dims,
            drop_path_rate=drop_path_rate,
            kernel_sizes=kernel_sizes,
        )

        self.head = HiFiGANGenerator(
            hop_length=hop_length,
            upsample_rates=upsample_rates,
            upsample_kernel_sizes=upsample_kernel_sizes,
            resblock_kernel_sizes=resblock_kernel_sizes,
            resblock_dilation_sizes=resblock_dilation_sizes,
            num_mels=num_mels,
            upsample_initial_channel=upsample_initial_channel,
            use_template=use_template,
            pre_conv_kernel_size=pre_conv_kernel_size,
            post_conv_kernel_size=post_conv_kernel_size,
        )
        self.sampling_rate = sampling_rate
        self.mel_transform = LogMelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
        )
        self.eval()

    @torch.no_grad()
    def decode(self, mel):
        y = self.backbone(mel)
        y = self.head(y)
        return y

    @torch.no_grad()
    def encode(self, x):
        return self.mel_transform(x)

    def forward(self, mel):
        y = self.backbone(mel)
        y = self.head(y)
        return y


if __name__ == "__main__":
    import soundfile as sf

    x = "test_audio.flac"
    model = ADaMoSHiFiGANV1.from_pretrained("./checkpoints/music_vocoder", local_files_only=True)

    wav, sr = librosa.load(x, sr=44100, mono=True)
    wav = torch.from_numpy(wav).float()[None]
    mel = model.encode(wav)

    wav = model.decode(mel)[0].mT
    sf.write("test_audio_vocoder_rec.flac", wav.cpu().numpy(), 44100)
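The generator's assert above ties the upsample schedule to the mel hop length: with upsample_rates (4, 4, 2, 2, 2, 2, 2) their product is 512, so each mel frame expands back into 512 audio samples at 44.1 kHz. A quick sanity sketch using the default values above (the frame count is illustrative):

    from math import prod

    upsample_rates = (4, 4, 2, 2, 2, 2, 2)
    hop_length = 512
    assert prod(upsample_rates) == hop_length   # enforced in HiFiGANGenerator.__init__

    mel_frames = 86                             # e.g. roughly 1 s of 44.1 kHz audio
    print(mel_frames * hop_length)              # 44032 samples reconstructed by the head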
packages.txt
ADDED
@@ -0,0 +1 @@
ffmpeg
pipeline_ace_step.py
ADDED
@@ -0,0 +1,735 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import time
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import glob
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from loguru import logger
|
10 |
+
from tqdm import tqdm
|
11 |
+
import json
|
12 |
+
import math
|
13 |
+
from huggingface_hub import hf_hub_download
|
14 |
+
|
15 |
+
# from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
16 |
+
from schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
17 |
+
from schedulers.scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
|
18 |
+
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
|
19 |
+
from diffusers.utils.torch_utils import randn_tensor
|
20 |
+
from transformers import UMT5EncoderModel, AutoTokenizer
|
21 |
+
|
22 |
+
from language_segmentation import LangSegment
|
23 |
+
from music_dcae.music_dcae_pipeline import MusicDCAE
|
24 |
+
from models.ace_step_transformer import ACEStepTransformer2DModel
|
25 |
+
from models.lyrics_utils.lyric_tokenizer import VoiceBpeTokenizer
|
26 |
+
from apg_guidance import apg_forward, MomentumBuffer, cfg_forward, cfg_zero_star, cfg_double_condition_forward
|
27 |
+
import torchaudio
|
28 |
+
|
29 |
+
|
30 |
+
SUPPORT_LANGUAGES = {
|
31 |
+
"en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
|
32 |
+
"pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
|
33 |
+
"nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
|
34 |
+
"ko": 6152, "hi": 6680
|
35 |
+
}
|
36 |
+
|
37 |
+
structure_pattern = re.compile(r"\[.*?\]")
|
38 |
+
|
39 |
+
|
40 |
+
def ensure_directory_exists(directory):
|
41 |
+
directory = str(directory)
|
42 |
+
if not os.path.exists(directory):
|
43 |
+
os.makedirs(directory)
|
44 |
+
|
45 |
+
|
46 |
+
REPO_ID = "ACE-Step/ACE-Step-v1-3.5B"
|
47 |
+
|
48 |
+
|
49 |
+
class ACEStepPipeline:
|
50 |
+
|
51 |
+
def __init__(self, checkpoint_dir=None, device_id=0, dtype="bfloat16", text_encoder_checkpoint_path=None, persistent_storage_path=None, **kwargs):
|
52 |
+
if checkpoint_dir is None:
|
53 |
+
if persistent_storage_path is None:
|
54 |
+
checkpoint_dir = os.path.join(os.path.dirname(__file__), "checkpoints")
|
55 |
+
else:
|
56 |
+
checkpoint_dir = os.path.join(persistent_storage_path, "checkpoints")
|
57 |
+
ensure_directory_exists(checkpoint_dir)
|
58 |
+
|
59 |
+
self.checkpoint_dir = checkpoint_dir
|
60 |
+
device = torch.device(f"cuda:{device_id}") if torch.cuda.is_available() else torch.device("cpu")
|
61 |
+
self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32
|
62 |
+
self.device = device
|
63 |
+
self.loaded = False
|
64 |
+
|
65 |
+
def load_checkpoint(self, checkpoint_dir=None):
|
66 |
+
device = self.device
|
67 |
+
|
68 |
+
dcae_model_path = os.path.join(checkpoint_dir, "music_dcae_f8c8")
|
69 |
+
vocoder_model_path = os.path.join(checkpoint_dir, "music_vocoder")
|
70 |
+
ace_step_model_path = os.path.join(checkpoint_dir, "ace_step_transformer")
|
71 |
+
text_encoder_model_path = os.path.join(checkpoint_dir, "umt5-base")
|
72 |
+
|
73 |
+
files_exist = (
|
74 |
+
os.path.exists(os.path.join(dcae_model_path, "config.json")) and
|
75 |
+
os.path.exists(os.path.join(dcae_model_path, "diffusion_pytorch_model.safetensors")) and
|
76 |
+
os.path.exists(os.path.join(vocoder_model_path, "config.json")) and
|
77 |
+
os.path.exists(os.path.join(vocoder_model_path, "diffusion_pytorch_model.safetensors")) and
|
78 |
+
os.path.exists(os.path.join(ace_step_model_path, "config.json")) and
|
79 |
+
os.path.exists(os.path.join(ace_step_model_path, "diffusion_pytorch_model.safetensors")) and
|
80 |
+
os.path.exists(os.path.join(text_encoder_model_path, "config.json")) and
|
81 |
+
os.path.exists(os.path.join(text_encoder_model_path, "model.safetensors")) and
|
82 |
+
os.path.exists(os.path.join(text_encoder_model_path, "special_tokens_map.json")) and
|
83 |
+
os.path.exists(os.path.join(text_encoder_model_path, "tokenizer_config.json")) and
|
84 |
+
os.path.exists(os.path.join(text_encoder_model_path, "tokenizer.json"))
|
85 |
+
)
|
86 |
+
|
87 |
+
if not files_exist:
|
88 |
+
logger.info(f"Checkpoint directory {checkpoint_dir} is not complete, downloading from Hugging Face Hub")
|
89 |
+
|
90 |
+
# download music dcae model
|
91 |
+
os.makedirs(dcae_model_path, exist_ok=True)
|
92 |
+
hf_hub_download(repo_id=REPO_ID, subfolder="music_dcae_f8c8",
|
93 |
+
filename="config.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
|
94 |
+
hf_hub_download(repo_id=REPO_ID, subfolder="music_dcae_f8c8",
|
95 |
+
filename="diffusion_pytorch_model.safetensors", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
|
96 |
+
|
97 |
+
# download vocoder model
|
98 |
+
os.makedirs(vocoder_model_path, exist_ok=True)
|
99 |
+
hf_hub_download(repo_id=REPO_ID, subfolder="music_vocoder",
|
100 |
+
filename="config.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
|
101 |
+
hf_hub_download(repo_id=REPO_ID, subfolder="music_vocoder",
|
102 |
+
filename="diffusion_pytorch_model.safetensors", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
|
103 |
+
|
104 |
+
# download ace_step transformer model
|
105 |
+
os.makedirs(ace_step_model_path, exist_ok=True)
|
106 |
+
hf_hub_download(repo_id=REPO_ID, subfolder="ace_step_transformer",
|
107 |
+
filename="config.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
|
108 |
+
hf_hub_download(repo_id=REPO_ID, subfolder="ace_step_transformer",
|
109 |
+
filename="diffusion_pytorch_model.safetensors", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
|
110 |
+
|
111 |
+
# download text encoder model
|
112 |
+
os.makedirs(text_encoder_model_path, exist_ok=True)
|
113 |
+
hf_hub_download(repo_id=REPO_ID, subfolder="umt5-base",
|
114 |
+
filename="config.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
|
115 |
+
hf_hub_download(repo_id=REPO_ID, subfolder="umt5-base",
|
116 |
+
filename="model.safetensors", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
|
117 |
+
hf_hub_download(repo_id=REPO_ID, subfolder="umt5-base",
|
118 |
+
filename="special_tokens_map.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
|
119 |
+
hf_hub_download(repo_id=REPO_ID, subfolder="umt5-base",
|
120 |
+
filename="tokenizer_config.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
|
121 |
+
hf_hub_download(repo_id=REPO_ID, subfolder="umt5-base",
|
122 |
+
filename="tokenizer.json", local_dir=checkpoint_dir, local_dir_use_symlinks=False)
|
123 |
+
|
124 |
+
logger.info("Models downloaded")
|
125 |
+
|
126 |
+
dcae_checkpoint_path = dcae_model_path
|
127 |
+
vocoder_checkpoint_path = vocoder_model_path
|
128 |
+
ace_step_checkpoint_path = ace_step_model_path
|
129 |
+
text_encoder_checkpoint_path = text_encoder_model_path
|
130 |
+
|
131 |
+
self.music_dcae = MusicDCAE(dcae_checkpoint_path=dcae_checkpoint_path, vocoder_checkpoint_path=vocoder_checkpoint_path)
|
132 |
+
self.music_dcae.to(device).eval().to(self.dtype)
|
133 |
+
|
134 |
+
self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path)
|
135 |
+
self.ace_step_transformer.to(device).eval().to(self.dtype)
|
136 |
+
|
137 |
+
lang_segment = LangSegment()
|
138 |
+
|
139 |
+
lang_segment.setfilters([
|
140 |
+
'af', 'am', 'an', 'ar', 'as', 'az', 'be', 'bg', 'bn', 'br', 'bs', 'ca', 'cs', 'cy', 'da', 'de', 'dz', 'el',
|
141 |
+
'en', 'eo', 'es', 'et', 'eu', 'fa', 'fi', 'fo', 'fr', 'ga', 'gl', 'gu', 'he', 'hi', 'hr', 'ht', 'hu', 'hy',
|
142 |
+
'id', 'is', 'it', 'ja', 'jv', 'ka', 'kk', 'km', 'kn', 'ko', 'ku', 'ky', 'la', 'lb', 'lo', 'lt', 'lv', 'mg',
|
143 |
+
'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'nb', 'ne', 'nl', 'nn', 'no', 'oc', 'or', 'pa', 'pl', 'ps', 'pt', 'qu',
|
144 |
+
'ro', 'ru', 'rw', 'se', 'si', 'sk', 'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'th', 'tl', 'tr', 'ug', 'uk',
|
145 |
+
'ur', 'vi', 'vo', 'wa', 'xh', 'zh', 'zu'
|
146 |
+
])
|
147 |
+
self.lang_segment = lang_segment
|
148 |
+
self.lyric_tokenizer = VoiceBpeTokenizer()
|
149 |
+
text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path).eval()
|
150 |
+
text_encoder_model = text_encoder_model.to(device).to(self.dtype)
|
151 |
+
text_encoder_model.requires_grad_(False)
|
152 |
+
self.text_encoder_model = text_encoder_model
|
153 |
+
self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_checkpoint_path)
|
154 |
+
self.loaded = True
|
155 |
+
|
156 |
+
# compile
|
157 |
+
# self.music_dcae = torch.compile(self.music_dcae)
|
158 |
+
# self.ace_step_transformer = torch.compile(self.ace_step_transformer)
|
159 |
+
# self.text_encoder_model = torch.compile(self.text_encoder_model)
|
160 |
+
|
    def get_text_embeddings(self, texts, device, text_max_length=256):
        inputs = self.text_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=text_max_length)
        inputs = {key: value.to(device) for key, value in inputs.items()}
        if self.text_encoder_model.device != device:
            self.text_encoder_model.to(device)
        with torch.no_grad():
            outputs = self.text_encoder_model(**inputs)
            last_hidden_states = outputs.last_hidden_state
        attention_mask = inputs["attention_mask"]
        return last_hidden_states, attention_mask

    def get_text_embeddings_null(self, texts, device, text_max_length=256, tau=0.01, l_min=8, l_max=10):
        inputs = self.text_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=text_max_length)
        inputs = {key: value.to(device) for key, value in inputs.items()}
        if self.text_encoder_model.device != device:
            self.text_encoder_model.to(device)

        def forward_with_temperature(inputs, tau=0.01, l_min=8, l_max=10):
            handlers = []

            def hook(module, input, output):
                output[:] *= tau
                return output

            for i in range(l_min, l_max):
                handler = self.text_encoder_model.encoder.block[i].layer[0].SelfAttention.q.register_forward_hook(hook)
                handlers.append(handler)

            with torch.no_grad():
                outputs = self.text_encoder_model(**inputs)
                last_hidden_states = outputs.last_hidden_state

            for hook in handlers:
                hook.remove()

            return last_hidden_states

        last_hidden_states = forward_with_temperature(inputs, tau, l_min, l_max)
        return last_hidden_states
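get_text_embeddings_null produces the "null" text condition by temporarily damping the query projections of encoder blocks l_min..l_max with forward hooks. A self-contained sketch of the same mechanism on a toy linear layer (the toy layer stands in for the UMT5 SelfAttention.q modules):

import torch
import torch.nn as nn

layer = nn.Linear(16, 16)
tau = 0.01

def scale_hook(module, inputs, output):
    # damp the module output in place, exactly like the pipeline's hook
    output[:] *= tau
    return output

handle = layer.register_forward_hook(scale_hook)
with torch.no_grad():
    damped = layer(torch.randn(2, 16))  # outputs are scaled by tau while the hook is attached
handle.remove()                          # detach afterwards, as the pipeline does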
    def set_seeds(self, batch_size, manual_seeds=None):
        seeds = None
        if manual_seeds is not None:
            if isinstance(manual_seeds, str):
                if "," in manual_seeds:
                    seeds = list(map(int, manual_seeds.split(",")))
                elif manual_seeds.isdigit():
                    seeds = int(manual_seeds)

        random_generators = [torch.Generator(device=self.device) for _ in range(batch_size)]
        actual_seeds = []
        for i in range(batch_size):
            seed = None
            if seeds is None:
                seed = torch.randint(0, 2**32, (1,)).item()
            if isinstance(seeds, int):
                seed = seeds
            if isinstance(seeds, list):
                seed = seeds[i]
            random_generators[i].manual_seed(seed)
            actual_seeds.append(seed)
        return random_generators, actual_seeds
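A quick illustration of how set_seeds interprets its manual_seeds argument, assuming pipe is an already-constructed pipeline instance:

# "1,2" -> one generator per batch element, seeded 1 and 2
generators, seeds = pipe.set_seeds(batch_size=2, manual_seeds="1,2")
assert seeds == [1, 2]

# "42" -> every generator in the batch reuses seed 42
generators, seeds = pipe.set_seeds(batch_size=2, manual_seeds="42")
assert seeds == [42, 42]

# None -> a fresh random seed is drawn for each element
generators, seeds = pipe.set_seeds(batch_size=2, manual_seeds=None)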
    def get_lang(self, text):
        language = "en"
        try:
            _ = self.lang_segment.getTexts(text)
            langCounts = self.lang_segment.getCounts()
            language = langCounts[0][0]
            if len(langCounts) > 1 and language == "en":
                language = langCounts[1][0]
        except Exception as err:
            language = "en"
        return language

    def tokenize_lyrics(self, lyrics, debug=False):
        lines = lyrics.split("\n")
        lyric_token_idx = [261]
        for line in lines:
            line = line.strip()
            if not line:
                lyric_token_idx += [2]
                continue

            lang = self.get_lang(line)

            if lang not in SUPPORT_LANGUAGES:
                lang = "en"
            if "zh" in lang:
                lang = "zh"
            if "spa" in lang:
                lang = "es"

            try:
                if structure_pattern.match(line):
                    token_idx = self.lyric_tokenizer.encode(line, "en")
                else:
                    token_idx = self.lyric_tokenizer.encode(line, lang)
                if debug:
                    toks = self.lyric_tokenizer.batch_decode([[tok_id] for tok_id in token_idx])
                    logger.info(f"debug {line} --> {lang} --> {toks}")
                lyric_token_idx = lyric_token_idx + token_idx + [2]
            except Exception as e:
                print("tokenize error", e, "for line", line, "major_language", lang)
        return lyric_token_idx
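tokenize_lyrics always starts the sequence with token id 261 and closes every line (including blank ones) with id 2, so the returned list is a start token followed by per-line token ids, each line terminated by the separator. A sketch of what to expect (ids other than 261 and 2 depend on the VoiceBpeTokenizer vocabulary):

# Illustrative only; the concrete ids come from VoiceBpeTokenizer.
lyrics = "[verse]\nNeon lights they flicker bright"
token_idx = pipe.tokenize_lyrics(lyrics)
# token_idx == [261, *ids("[verse]"), 2, *ids("Neon lights they flicker bright"), 2]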
    @torch.no_grad()
    def text2music_diffusion_process(
        self,
        duration,
        encoder_text_hidden_states,
        text_attention_mask,
        speaker_embds,
        lyric_token_ids,
        lyric_mask,
        random_generators=None,
        infer_steps=60,
        guidance_scale=15.0,
        omega_scale=10.0,
        scheduler_type="euler",
        cfg_type="apg",
        zero_steps=1,
        use_zero_init=True,
        guidance_interval=0.5,
        guidance_interval_decay=1.0,
        min_guidance_scale=3.0,
        oss_steps=[],
        encoder_text_hidden_states_null=None,
        use_erg_lyric=False,
        use_erg_diffusion=False,
        retake_random_generators=None,
        retake_variance=0.5,
        add_retake_noise=False,
        guidance_scale_text=0.0,
        guidance_scale_lyric=0.0,
    ):

        logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
        do_classifier_free_guidance = True
        if guidance_scale == 0.0 or guidance_scale == 1.0:
            do_classifier_free_guidance = False

        do_double_condition_guidance = False
        if guidance_scale_text is not None and guidance_scale_text > 1.0 and guidance_scale_lyric is not None and guidance_scale_lyric > 1.0:
            do_double_condition_guidance = True
            logger.info("do_double_condition_guidance: {}, guidance_scale_text: {}, guidance_scale_lyric: {}".format(do_double_condition_guidance, guidance_scale_text, guidance_scale_lyric))

        device = encoder_text_hidden_states.device
        dtype = encoder_text_hidden_states.dtype
        bsz = encoder_text_hidden_states.shape[0]

        if scheduler_type == "euler":
            scheduler = FlowMatchEulerDiscreteScheduler(
                num_train_timesteps=1000,
                shift=3.0,
            )
        elif scheduler_type == "heun":
            scheduler = FlowMatchHeunDiscreteScheduler(
                num_train_timesteps=1000,
                shift=3.0,
            )
        frame_length = int(duration * 44100 / 512 / 8)

        if len(oss_steps) > 0:
            infer_steps = max(oss_steps)
            scheduler.set_timesteps
            timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=infer_steps, device=device, timesteps=None)
            new_timesteps = torch.zeros(len(oss_steps), dtype=dtype, device=device)
            for idx in range(len(oss_steps)):
                new_timesteps[idx] = timesteps[oss_steps[idx]-1]
            num_inference_steps = len(oss_steps)
            sigmas = (new_timesteps / 1000).float().cpu().numpy()
            timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=num_inference_steps, device=device, sigmas=sigmas)
            logger.info(f"oss_steps: {oss_steps}, num_inference_steps: {num_inference_steps} after remapping to timesteps {timesteps}")
        else:
            timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=infer_steps, device=device, timesteps=None)

        target_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=random_generators, device=device, dtype=dtype)
        if add_retake_noise:
            retake_variance = torch.tensor(retake_variance * math.pi/2).to(device).to(dtype)
            retake_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=retake_random_generators, device=device, dtype=dtype)
            # to make sure mean = 0, std = 1
            target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
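The retake branch mixes the two noise tensors with cos/sin weights; because cos²θ + sin²θ = 1, a mix of two independent unit-variance Gaussians stays unit-variance, which is what the in-line comment promises. A quick numerical check:

import math
import torch

theta = torch.tensor(0.5 * math.pi / 2)            # retake_variance = 0.5
a, b = torch.randn(100_000), torch.randn(100_000)  # independent unit Gaussians
mixed = torch.cos(theta) * a + torch.sin(theta) * b
print(mixed.mean(), mixed.std())                    # approximately 0 and 1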
        attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)

        # guidance interval logic
        start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))
        end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))
        logger.info(f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}")

        momentum_buffer = MomentumBuffer()
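With the defaults exposed by the Gradio UI (60 inference steps, guidance_interval 0.5), the two formulas above confine guidance to the middle half of the denoising trajectory:

num_inference_steps, guidance_interval = 60, 0.5
start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))  # 15
end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))    # 45
# guidance is applied for steps 15..44; all other steps take the unguided branch of the loop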
        def forward_encoder_with_temperature(self, inputs, tau=0.01, l_min=4, l_max=6):
            handlers = []

            def hook(module, input, output):
                output[:] *= tau
                return output

            for i in range(l_min, l_max):
                handler = self.ace_step_transformer.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook)
                handlers.append(handler)

            encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(**inputs)

            for hook in handlers:
                hook.remove()

            return encoder_hidden_states

        # P(speaker, text, lyric)
        encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(
            encoder_text_hidden_states,
            text_attention_mask,
            speaker_embds,
            lyric_token_ids,
            lyric_mask,
        )

        if use_erg_lyric:
            # P(null_speaker, text_weaker, lyric_weaker)
            encoder_hidden_states_null = forward_encoder_with_temperature(
                self,
                inputs={
                    "encoder_text_hidden_states": encoder_text_hidden_states_null if encoder_text_hidden_states_null is not None else torch.zeros_like(encoder_text_hidden_states),
                    "text_attention_mask": text_attention_mask,
                    "speaker_embeds": torch.zeros_like(speaker_embds),
                    "lyric_token_idx": lyric_token_ids,
                    "lyric_mask": lyric_mask,
                }
            )
        else:
            # P(null_speaker, null_text, null_lyric)
            encoder_hidden_states_null, _ = self.ace_step_transformer.encode(
                torch.zeros_like(encoder_text_hidden_states),
                text_attention_mask,
                torch.zeros_like(speaker_embds),
                torch.zeros_like(lyric_token_ids),
                lyric_mask,
            )

        encoder_hidden_states_no_lyric = None
        if do_double_condition_guidance:
            # P(null_speaker, text, lyric_weaker)
            if use_erg_lyric:
                encoder_hidden_states_no_lyric = forward_encoder_with_temperature(
                    self,
                    inputs={
                        "encoder_text_hidden_states": encoder_text_hidden_states,
                        "text_attention_mask": text_attention_mask,
                        "speaker_embeds": torch.zeros_like(speaker_embds),
                        "lyric_token_idx": lyric_token_ids,
                        "lyric_mask": lyric_mask,
                    }
                )
            # P(null_speaker, text, no_lyric)
            else:
                encoder_hidden_states_no_lyric, _ = self.ace_step_transformer.encode(
                    encoder_text_hidden_states,
                    text_attention_mask,
                    torch.zeros_like(speaker_embds),
                    torch.zeros_like(lyric_token_ids),
                    lyric_mask,
                )

        def forward_diffusion_with_temperature(self, hidden_states, timestep, inputs, tau=0.01, l_min=15, l_max=20):
            handlers = []

            def hook(module, input, output):
                output[:] *= tau
                return output

            for i in range(l_min, l_max):
                handler = self.ace_step_transformer.transformer_blocks[i].attn.to_q.register_forward_hook(hook)
                handlers.append(handler)
                handler = self.ace_step_transformer.transformer_blocks[i].cross_attn.to_q.register_forward_hook(hook)
                handlers.append(handler)

            sample = self.ace_step_transformer.decode(hidden_states=hidden_states, timestep=timestep, **inputs).sample

            for hook in handlers:
                hook.remove()

            return sample
        for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
            # expand the latents if we are doing classifier free guidance
            latents = target_latents

            is_in_guidance_interval = start_idx <= i < end_idx
            if is_in_guidance_interval and do_classifier_free_guidance:
                # compute current guidance scale
                if guidance_interval_decay > 0:
                    # Linearly interpolate to calculate the current guidance scale
                    progress = (i - start_idx) / (end_idx - start_idx - 1)  # normalized to [0, 1]
                    current_guidance_scale = guidance_scale - (guidance_scale - min_guidance_scale) * progress * guidance_interval_decay
                else:
                    current_guidance_scale = guidance_scale

                latent_model_input = latents
                timestep = t.expand(latent_model_input.shape[0])
                output_length = latent_model_input.shape[-1]
                # P(x|speaker, text, lyric)
                noise_pred_with_cond = self.ace_step_transformer.decode(
                    hidden_states=latent_model_input,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_mask=encoder_hidden_mask,
                    output_length=output_length,
                    timestep=timestep,
                ).sample

                noise_pred_with_only_text_cond = None
                if do_double_condition_guidance and encoder_hidden_states_no_lyric is not None:
                    noise_pred_with_only_text_cond = self.ace_step_transformer.decode(
                        hidden_states=latent_model_input,
                        attention_mask=attention_mask,
                        encoder_hidden_states=encoder_hidden_states_no_lyric,
                        encoder_hidden_mask=encoder_hidden_mask,
                        output_length=output_length,
                        timestep=timestep,
                    ).sample

                if use_erg_diffusion:
                    noise_pred_uncond = forward_diffusion_with_temperature(
                        self,
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        inputs={
                            "encoder_hidden_states": encoder_hidden_states_null,
                            "encoder_hidden_mask": encoder_hidden_mask,
                            "output_length": output_length,
                            "attention_mask": attention_mask,
                        },
                    )
                else:
                    noise_pred_uncond = self.ace_step_transformer.decode(
                        hidden_states=latent_model_input,
                        attention_mask=attention_mask,
                        encoder_hidden_states=encoder_hidden_states_null,
                        encoder_hidden_mask=encoder_hidden_mask,
                        output_length=output_length,
                        timestep=timestep,
                    ).sample

                if do_double_condition_guidance and noise_pred_with_only_text_cond is not None:
                    noise_pred = cfg_double_condition_forward(
                        cond_output=noise_pred_with_cond,
                        uncond_output=noise_pred_uncond,
                        only_text_cond_output=noise_pred_with_only_text_cond,
                        guidance_scale_text=guidance_scale_text,
                        guidance_scale_lyric=guidance_scale_lyric,
                    )

                elif cfg_type == "apg":
                    noise_pred = apg_forward(
                        pred_cond=noise_pred_with_cond,
                        pred_uncond=noise_pred_uncond,
                        guidance_scale=current_guidance_scale,
                        momentum_buffer=momentum_buffer,
                    )
                elif cfg_type == "cfg":
                    noise_pred = cfg_forward(
                        cond_output=noise_pred_with_cond,
                        uncond_output=noise_pred_uncond,
                        cfg_strength=current_guidance_scale,
                    )
                elif cfg_type == "cfg_star":
                    noise_pred = cfg_zero_star(
                        noise_pred_with_cond=noise_pred_with_cond,
                        noise_pred_uncond=noise_pred_uncond,
                        guidance_scale=current_guidance_scale,
                        i=i,
                        zero_steps=zero_steps,
                        use_zero_init=use_zero_init
                    )
            else:
                latent_model_input = latents
                timestep = t.expand(latent_model_input.shape[0])
                noise_pred = self.ace_step_transformer.decode(
                    hidden_states=latent_model_input,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_mask=encoder_hidden_mask,
                    output_length=latent_model_input.shape[-1],
                    timestep=timestep,
                ).sample

            target_latents = scheduler.step(model_output=noise_pred, timestep=t, sample=target_latents, return_dict=False, omega=omega_scale)[0]

        return target_latents
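Of the guidance variants dispatched in the loop above, plain "cfg" is the simplest. A minimal stand-in for what such a cfg_forward helper typically computes (the project's actual implementations live in apg_guidance.py and may differ in detail):

import torch

def cfg_forward_sketch(cond_output: torch.Tensor, uncond_output: torch.Tensor, cfg_strength: float) -> torch.Tensor:
    # classic classifier-free guidance: move the prediction away from the
    # unconditional branch along the (cond - uncond) direction
    return uncond_output + cfg_strength * (cond_output - uncond_output)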
    def latents2audio(self, latents, target_wav_duration_second=30, sample_rate=48000, save_path=None, format="flac"):
        output_audio_paths = []
        bs = latents.shape[0]
        audio_lengths = [target_wav_duration_second * sample_rate] * bs
        pred_latents = latents
        with torch.no_grad():
            _, pred_wavs = self.music_dcae.decode(pred_latents, sr=sample_rate)
        pred_wavs = [pred_wav.cpu().float() for pred_wav in pred_wavs]
        for i in tqdm(range(bs)):
            output_audio_path = self.save_wav_file(pred_wavs[i], i, save_path=save_path, sample_rate=sample_rate, format=format)
            output_audio_paths.append(output_audio_path)
        return output_audio_paths

    def save_wav_file(self, target_wav, idx, save_path=None, sample_rate=48000, format="flac"):
        if save_path is None:
            logger.warning("save_path is None, using default path ./outputs/")
            base_path = "./outputs"
            ensure_directory_exists(base_path)
        else:
            base_path = save_path
            ensure_directory_exists(base_path)

        output_path_flac = f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.{format}"
        target_wav = target_wav.float()
        torchaudio.save(output_path_flac, target_wav, sample_rate=sample_rate, format=format)
        return output_path_flac
    def __call__(
        self,
        audio_duration: float = 60.0,
        prompt: str = None,
        lyrics: str = None,
        infer_step: int = 60,
        guidance_scale: float = 15.0,
        scheduler_type: str = "euler",
        cfg_type: str = "apg",
        omega_scale: float = 10.0,
        manual_seeds: list = None,
        guidance_interval: float = 0.5,
        guidance_interval_decay: float = 0.,
        min_guidance_scale: float = 3.0,
        use_erg_tag: bool = True,
        use_erg_lyric: bool = True,
        use_erg_diffusion: bool = True,
        oss_steps: str = None,
        guidance_scale_text: float = 0.0,
        guidance_scale_lyric: float = 0.0,
        retake_seeds: list = None,
        retake_variance: float = 0.5,
        task: str = "text2music",
        save_path: str = None,
        format: str = "flac",
        batch_size: int = 1,
    ):

        start_time = time.time()

        if not self.loaded:
            logger.warning("Checkpoint not loaded, loading checkpoint...")
            self.load_checkpoint(self.checkpoint_dir)
            load_model_cost = time.time() - start_time
            logger.info(f"Model loaded in {load_model_cost:.2f} seconds.")

        start_time = time.time()

        random_generators, actual_seeds = self.set_seeds(batch_size, manual_seeds)
        retake_random_generators, actual_retake_seeds = self.set_seeds(batch_size, retake_seeds)

        if isinstance(oss_steps, str) and len(oss_steps) > 0:
            oss_steps = list(map(int, oss_steps.split(",")))
        else:
            oss_steps = []

        texts = [prompt]
        encoder_text_hidden_states, text_attention_mask = self.get_text_embeddings(texts, self.device)
        encoder_text_hidden_states = encoder_text_hidden_states.repeat(batch_size, 1, 1)
        text_attention_mask = text_attention_mask.repeat(batch_size, 1)

        encoder_text_hidden_states_null = None
        if use_erg_tag:
            encoder_text_hidden_states_null = self.get_text_embeddings_null(texts, self.device)
            encoder_text_hidden_states_null = encoder_text_hidden_states_null.repeat(batch_size, 1, 1)

        # speaker embeddings are not supported by the released checkpoint
        speaker_embeds = torch.zeros(batch_size, 512).to(self.device).to(self.dtype)

        # lyric
        lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
        lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
        if len(lyrics) > 0:
            lyric_token_idx = self.tokenize_lyrics(lyrics, debug=True)
            lyric_mask = [1] * len(lyric_token_idx)
            lyric_token_idx = torch.tensor(lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1)
            lyric_mask = torch.tensor(lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1)

        if audio_duration <= 0:
            audio_duration = random.uniform(30.0, 240.0)
            logger.info(f"random audio duration: {audio_duration}")

        end_time = time.time()
        preprocess_time_cost = end_time - start_time
        start_time = end_time

        target_latents = self.text2music_diffusion_process(
            duration=audio_duration,
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embds=speaker_embeds,
            lyric_token_ids=lyric_token_idx,
            lyric_mask=lyric_mask,
            guidance_scale=guidance_scale,
            omega_scale=omega_scale,
            infer_steps=infer_step,
            random_generators=random_generators,
            scheduler_type=scheduler_type,
            cfg_type=cfg_type,
            guidance_interval=guidance_interval,
            guidance_interval_decay=guidance_interval_decay,
            min_guidance_scale=min_guidance_scale,
            oss_steps=oss_steps,
            encoder_text_hidden_states_null=encoder_text_hidden_states_null,
            use_erg_lyric=use_erg_lyric,
            use_erg_diffusion=use_erg_diffusion,
            retake_random_generators=retake_random_generators,
            retake_variance=retake_variance,
            add_retake_noise=task == "retake",
            guidance_scale_text=guidance_scale_text,
            guidance_scale_lyric=guidance_scale_lyric,
        )

        end_time = time.time()
        diffusion_time_cost = end_time - start_time
        start_time = end_time

        output_paths = self.latents2audio(
            latents=target_latents,
            target_wav_duration_second=audio_duration,
            save_path=save_path,
            format=format,
        )

        end_time = time.time()
        latent2audio_time_cost = end_time - start_time
        timecosts = {
            "preprocess": preprocess_time_cost,
            "diffusion": diffusion_time_cost,
            "latent2audio": latent2audio_time_cost,
        }

        input_params_json = {
            "task": task,
            "prompt": prompt,
            "lyrics": lyrics,
            "audio_duration": audio_duration,
            "infer_step": infer_step,
            "guidance_scale": guidance_scale,
            "scheduler_type": scheduler_type,
            "cfg_type": cfg_type,
            "omega_scale": omega_scale,
            "guidance_interval": guidance_interval,
            "guidance_interval_decay": guidance_interval_decay,
            "min_guidance_scale": min_guidance_scale,
            "use_erg_tag": use_erg_tag,
            "use_erg_lyric": use_erg_lyric,
            "use_erg_diffusion": use_erg_diffusion,
            "oss_steps": oss_steps,
            "timecosts": timecosts,
            "actual_seeds": actual_seeds,
            "retake_seeds": actual_retake_seeds,
            "retake_variance": retake_variance,
            "guidance_scale_text": guidance_scale_text,
            "guidance_scale_lyric": guidance_scale_lyric,
        }
        # save input_params_json
        for output_audio_path in output_paths:
            input_params_json_save_path = output_audio_path.replace(f".{format}", "_input_params.json")
            input_params_json["audio_path"] = output_audio_path
            with open(input_params_json_save_path, "w", encoding="utf-8") as f:
                json.dump(input_params_json, f, indent=4, ensure_ascii=False)

        return output_paths + [input_params_json]
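For orientation, a rough usage sketch of the class defined in this file; the class name, constructor argument, and keyword values below are assumptions taken from the rest of the commit (see app.py for the wiring that is actually used):

# Hypothetical invocation; adjust names to match app.py.
pipeline = ACEStepPipeline(checkpoint_dir="./checkpoints")
results = pipeline(
    audio_duration=60.0,
    prompt="funk, pop, soul, 105 BPM, energetic",
    lyrics="[verse]\nNeon lights they flicker bright",
    infer_step=60,
    guidance_scale=15.0,
    manual_seeds="42",
)
audio_paths, input_params = results[:-1], results[-1]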
requirements.txt
ADDED
@@ -0,0 +1,22 @@
datasets==3.4.1
diffusers==0.32.2
gradio==5.23.3
librosa==0.11.0
loguru==0.7.3
matplotlib==3.10.1
numpy
pypinyin==0.53.0
pytorch_lightning==2.5.1
soundfile==0.13.1
torch
torchaudio
torchvision
tqdm==4.67.1
transformers==4.50.0
py3langid==0.3.0
hangul-romanize==0.1.0
num2words==0.5.14
spacy==3.8.4
accelerate==1.6.0
cutlet
fugashi[unidic-lite]
schedulers/scheduling_flow_match_euler_discrete.py
ADDED
@@ -0,0 +1,394 @@
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.schedulers.scheduling_utils import SchedulerMixin


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting=False,
        base_shift: Optional[float] = 0.5,
        max_shift: Optional[float] = 1.15,
        base_image_seq_len: Optional[int] = 256,
        max_image_seq_len: Optional[int] = 4096,
    ):
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)

        if sample.device.type == "mps" and torch.is_floating_point(timestep):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
            timestep = timestep.to(sample.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(sample.device)
            timestep = timestep.to(sample.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timestep.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent (img2img)
            step_indices = [self.begin_index] * timestep.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(sample.shape):
            sigma = sigma.unsqueeze(-1)

        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[float] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """

        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError("you have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`")

        if sigmas is None:
            self.num_inference_steps = num_inference_steps
            timesteps = np.linspace(
                self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
            )

            sigmas = timesteps / self.config.num_train_timesteps

        if self.config.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        timesteps = sigmas * self.config.num_train_timesteps

        self.timesteps = timesteps.to(device=device)
        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self._step_index = None
        self._begin_index = None
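The shift transform used in __init__ and set_timesteps, sigma' = shift * sigma / (1 + (shift - 1) * sigma), pushes the schedule toward higher noise levels; with the shift of 3.0 that the pipeline passes in, a mid-schedule sigma of 0.5 becomes 0.75:

shift, sigma = 3.0, 0.5
shifted = shift * sigma / (1 + (shift - 1) * sigma)  # 1.5 / 2.0 = 0.75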
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        omega: Union[float, np.array] = 0.0
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
            s_tmin (`float`):
            s_tmax (`float`):
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """

        def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
            # L = Lower bound
            # U = Upper bound
            # x_0 = Midpoint (x corresponding to y = 1.0)
            # k = Steepness, can adjust based on preference

            if isinstance(x, torch.Tensor):
                device_ = x.device
                x = x.to(torch.float).cpu().numpy()

            new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))

            if isinstance(new_x, np.ndarray):
                new_x = torch.from_numpy(new_x).to(device_)
            return new_x

        self.omega_bef_rescale = omega
        omega = logistic_function(omega, k=0.1)
        self.omega_aft_rescale = omega

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]

        ## --
        ## mean shift 1
        dx = (sigma_next - sigma) * model_output
        m = dx.mean()
        # print(dx.shape)  # torch.Size([1, 16, 128, 128])
        # print(f'm: {m}')  # m: -0.0014209747314453125
        # raise NotImplementedError
        dx_ = (dx - m) * omega + m
        prev_sample = sample + dx_

        # ## --
        # ## mean shift 2
        # m = model_output.mean()
        # model_output_ = (model_output - m) * omega + m
        # prev_sample = sample + (sigma_next - sigma) * model_output_

        # ## --
        # ## original
        # prev_sample = sample + (sigma_next - sigma) * model_output * omega

        # ## --
        # ## spatial mean 1
        # dx = (sigma_next - sigma) * model_output
        # m = dx.mean(dim=(0, 1), keepdim=True)
        # # print(dx.shape)  # torch.Size([1, 16, 128, 128])
        # # print(m.shape)  # torch.Size([1, 1, 128, 128])
        # # raise NotImplementedError
        # dx_ = (dx - m) * omega + m
        # prev_sample = sample + dx_

        # ## --
        # ## spatial mean 2
        # m = model_output.mean(dim=(0, 1), keepdim=True)
        # model_output_ = (model_output - m) * omega + m
        # prev_sample = sample + (sigma_next - sigma) * model_output_

        # ## --
        # ## channel mean 1
        # m = model_output.mean(dim=(2, 3), keepdim=True)
        # # print(m.shape)  # torch.Size([1, 16, 1, 1])
        # model_output_ = (model_output - m) * omega + m
        # prev_sample = sample + (sigma_next - sigma) * model_output_

        # ## --
        # ## channel mean 2
        # dx = (sigma_next - sigma) * model_output
        # m = dx.mean(dim=(2, 3), keepdim=True)
        # # print(m.shape)  # torch.Size([1, 16, 1, 1])
        # dx_ = (dx - m) * omega + m
        # prev_sample = sample + dx_

        # ## --
        # ## keep sample mean
        # m_tgt = sample.mean()
        # prev_sample_ = sample + (sigma_next - sigma) * model_output * omega
        # m_src = prev_sample_.mean()
        # prev_sample = prev_sample_ - m_src + m_tgt

        # ## --
        # ## test
        # # print(sample.mean())
        # prev_sample = sample + (sigma_next - sigma) * model_output * omega
        # # raise NotImplementedError

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
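The omega passed into step (omega_scale in the pipeline) is squashed through the logistic above before the mean-shifted Euler update, so arbitrary user values land in roughly [0.9, 1.1]; omega = 0 maps to exactly 1.0 (a plain Euler step) and the default omega_scale of 10 maps to about 1.05:

import numpy as np

def logistic(x, L=0.9, U=1.1, x_0=0.0, k=0.1):
    return L + (U - L) / (1 + np.exp(-k * (x - x_0)))

print(logistic(0.0))   # 1.0   -> the update keeps its original magnitude
print(logistic(10.0))  # ~1.05 -> the deviation from the mean is slightly amplified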
schedulers/scheduling_flow_match_heun_discrete.py
ADDED
@@ -0,0 +1,348 @@
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import SchedulerMixin


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Heun scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
    """

    _compatibles = []
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
    ):
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        timesteps = np.linspace(
            self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
        )

        sigmas = timesteps / self.config.num_train_timesteps
        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        timesteps = sigmas * self.config.num_train_timesteps
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
        self.timesteps = timesteps.to(device=device)

        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

        # empty dt and derivative
        self.prev_derivative = None
        self.dt = None

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    @property
    def state_in_first_order(self):
        return self.dt is None

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        omega: Union[float, np.array] = 0.0
    ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
            s_tmin (`float`):
            s_tmax (`float`):
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """

        def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
            # L = Lower bound
            # U = Upper bound
            # x_0 = Midpoint (x corresponding to y = 1.0)
            # k = Steepness, can adjust based on preference

            if isinstance(x, torch.Tensor):
                device_ = x.device
                x = x.to(torch.float).cpu().numpy()

            new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))

            if isinstance(new_x, np.ndarray):
                new_x = torch.from_numpy(new_x).to(device_)
            return new_x

        self.omega_bef_rescale = omega
        omega = logistic_function(omega, k=0.1)
        self.omega_aft_rescale = omega

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `HeunDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
            sigma_next = self.sigmas[self.step_index + 1]
        else:
            # 2nd order / Heun's method
            sigma = self.sigmas[self.step_index - 1]
            sigma_next = self.sigmas[self.step_index]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
            )
            eps = noise * s_noise
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        if self.state_in_first_order:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma
            # 2. convert to an ODE derivative for 1st order
            derivative = (sample - denoised) / sigma_hat
            # 3. Delta timestep
            dt = sigma_next - sigma_hat

            # store for 2nd order step
            self.prev_derivative = derivative
            self.dt = dt
            self.sample = sample
        else:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma_next
            # 2. 2nd order / Heun's method
            derivative = (sample - denoised) / sigma_next
            derivative = 0.5 * (self.prev_derivative + derivative)

            # 3. take prev timestep & sample
            dt = self.dt
            sample = self.sample

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.prev_derivative = None
            self.dt = None
            self.sample = None

        # original sample way
        # prev_sample = sample + derivative * dt

        dx = derivative * dt
        m = dx.mean()
        dx_ = (dx - m) * omega + m
        prev_sample = sample + dx_

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
ui/components.py
ADDED
@@ -0,0 +1,244 @@
import gradio as gr


TAG_PLACEHOLDER = "funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic"
LYRIC_PLACEHOLDER = """[verse]
Neon lights they flicker bright
City hums in dead of night
Rhythms pulse through concrete veins
Lost in echoes of refrains

[verse]
Bassline groovin' in my chest
Heartbeats match the city's zest
Electric whispers fill the air
Synthesized dreams everywhere

[chorus]
Turn it up and let it flow
Feel the fire let it grow
In this rhythm we belong
Hear the night sing out our song

[verse]
Guitar strings they start to weep
Wake the soul from silent sleep
Every note a story told
In this night we’re bold and gold

[bridge]
Voices blend in harmony
Lost in pure cacophony
Timeless echoes timeless cries
Soulful shouts beneath the skies

[verse]
Keyboard dances on the keys
Melodies on evening breeze
Catch the tune and hold it tight
In this moment we take flight
"""


def create_output_ui(task_name="Text2Music"):
    # Many consumer-grade GPUs can only generate one batch at a time
    output_audio1 = gr.Audio(type="filepath", label=f"{task_name} Generated Audio 1")
    # output_audio2 = gr.Audio(type="filepath", label="Generated Audio 2")
    with gr.Accordion(f"{task_name} Parameters", open=False):
        input_params_json = gr.JSON(label=f"{task_name} Parameters")
    # outputs = [output_audio1, output_audio2]
    outputs = [output_audio1]
    return outputs, input_params_json


def dump_func(*args):
    print(args)
    return []


def create_text2music_ui(
    gr,
    text2music_process_func,
    sample_data_func=None,
):
    with gr.Row():
        with gr.Column():

            with gr.Row(equal_height=True):
                audio_duration = gr.Slider(-1, 240.0, step=0.00001, value=180, label="Audio Duration", interactive=True, info="-1 means a random duration (30 ~ 240).", scale=9)
                sample_bnt = gr.Button("Sample", variant="primary", scale=1)

            prompt = gr.Textbox(lines=2, label="Tags", max_lines=4, placeholder=TAG_PLACEHOLDER, info="Supports tags, descriptions, and scenes. Use commas to separate different tags.")
            lyrics = gr.Textbox(lines=9, label="Lyrics", max_lines=13, placeholder=LYRIC_PLACEHOLDER, info="Supports lyric structure tags like [verse], [chorus], and [bridge] to separate different parts of the lyrics.\nUse [instrumental] or [inst] to generate instrumental music. Genre structure tags are not supported in lyrics.")

            with gr.Accordion("Basic Settings", open=True):
                infer_step = gr.Slider(minimum=1, maximum=1000, step=1, value=60, label="Infer Steps", interactive=True)
                guidance_scale = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=15.0, label="Guidance Scale", interactive=True, info="When guidance_scale_lyric > 1 and guidance_scale_text > 1, this guidance scale is not applied.")
                guidance_scale_text = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=5.0, label="Guidance Scale Text", interactive=True, info="Guidance scale for the text condition. Only applies to cfg. Start with guidance_scale_text=5.0, guidance_scale_lyric=1.5.")
                guidance_scale_lyric = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=1.5, label="Guidance Scale Lyric", interactive=True)

                manual_seeds = gr.Textbox(label="manual seeds (default None)", placeholder="1,2,3,4", value=None, info="Seeds for the generation.")

            with gr.Accordion("Advanced Settings", open=False):
                scheduler_type = gr.Radio(["euler", "heun"], value="euler", label="Scheduler Type", elem_id="scheduler_type", info="Scheduler type for the generation. euler is recommended; heun takes more time.")
                cfg_type = gr.Radio(["cfg", "apg", "cfg_star"], value="apg", label="CFG Type", elem_id="cfg_type", info="CFG type for the generation. apg is recommended; cfg and cfg_star are almost the same.")
                use_erg_tag = gr.Checkbox(label="use ERG for tag", value=True, info="Use Entropy Rectifying Guidance for the tag condition. It multiplies the attention by a temperature to weaken the tag condition and improve diversity.")
                use_erg_lyric = gr.Checkbox(label="use ERG for lyric", value=True, info="The same, but applied to the lyric encoder's attention.")
                use_erg_diffusion = gr.Checkbox(label="use ERG for diffusion", value=True, info="The same, but applied to the diffusion model's attention.")

                omega_scale = gr.Slider(minimum=-100.0, maximum=100.0, step=0.1, value=10.0, label="Granularity Scale", interactive=True, info="Granularity scale for the generation. Higher values can reduce artifacts.")

                guidance_interval = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.5, label="Guidance Interval", interactive=True, info="Guidance interval for the generation. 0.5 means guidance is only applied in the middle steps (0.25 * infer_steps to 0.75 * infer_steps).")
                guidance_interval_decay = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, label="Guidance Interval Decay", interactive=True, info="Guidance interval decay for the generation. The guidance scale decays from guidance_scale to min_guidance_scale over the interval. 0.0 means no decay.")
                min_guidance_scale = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=3.0, label="Min Guidance Scale", interactive=True, info="Minimum guidance scale reached at the end of guidance interval decay.")
                oss_steps = gr.Textbox(label="OSS Steps", placeholder="16, 29, 52, 96, 129, 158, 172, 183, 189, 200", value=None, info="Optimal steps for the generation, not yet well tested.")

            text2music_bnt = gr.Button(variant="primary")

        with gr.Column():
            outputs, input_params_json = create_output_ui()
            with gr.Tab("retake"):
                retake_variance = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2, label="variance", info="Variance for the retake. 0.0 means no variance; 1.0 means full variance.")
                retake_seeds = gr.Textbox(label="retake seeds (default None)", placeholder="", value=None, info="Seeds for the retake.")
                retake_bnt = gr.Button(variant="primary")
                retake_outputs, retake_input_params_json = create_output_ui("Retake")

                def retake_process_func(json_data, retake_variance, retake_seeds):
                    return text2music_process_func(
                        json_data["audio_duration"],
                        json_data["prompt"],
                        json_data["lyrics"],
                        json_data["infer_step"],
                        json_data["guidance_scale"],
                        json_data["scheduler_type"],
                        json_data["cfg_type"],
                        json_data["omega_scale"],
                        ", ".join(map(str, json_data["actual_seeds"])),
                        json_data["guidance_interval"],
                        json_data["guidance_interval_decay"],
                        json_data["min_guidance_scale"],
                        json_data["use_erg_tag"],
                        json_data["use_erg_lyric"],
                        json_data["use_erg_diffusion"],
                        ", ".join(map(str, json_data["oss_steps"])),
                        json_data["guidance_scale_text"] if "guidance_scale_text" in json_data else 0.0,
                        json_data["guidance_scale_lyric"] if "guidance_scale_lyric" in json_data else 0.0,
                        retake_seeds=retake_seeds,
                        retake_variance=retake_variance,
                        task="retake",
                    )

                retake_bnt.click(
                    fn=retake_process_func,
                    inputs=[
                        input_params_json,
                        retake_variance,
                        retake_seeds,
                    ],
                    outputs=retake_outputs + [retake_input_params_json],
                )
            with gr.Tab("repainting"):
                pass
            with gr.Tab("edit"):
                pass

    def sample_data():
        json_data = sample_data_func()
        return (
            json_data["audio_duration"],
            json_data["prompt"],
            json_data["lyrics"],
            json_data["infer_step"],
            json_data["guidance_scale"],
            json_data["scheduler_type"],
            json_data["cfg_type"],
            json_data["omega_scale"],
            ", ".join(map(str, json_data["actual_seeds"])),
            json_data["guidance_interval"],
            json_data["guidance_interval_decay"],
            json_data["min_guidance_scale"],
            json_data["use_erg_tag"],
            json_data["use_erg_lyric"],
            json_data["use_erg_diffusion"],
            ", ".join(map(str, json_data["oss_steps"])),
            json_data["guidance_scale_text"] if "guidance_scale_text" in json_data else 0.0,
            json_data["guidance_scale_lyric"] if "guidance_scale_lyric" in json_data else 0.0,
        )

    sample_bnt.click(
        sample_data,
        outputs=[
            audio_duration,
            prompt,
            lyrics,
            infer_step,
            guidance_scale,
            scheduler_type,
            cfg_type,
            omega_scale,
            manual_seeds,
            guidance_interval,
            guidance_interval_decay,
            min_guidance_scale,
            use_erg_tag,
            use_erg_lyric,
            use_erg_diffusion,
            oss_steps,
            guidance_scale_text,
            guidance_scale_lyric,
        ],
    )

    text2music_bnt.click(
        fn=text2music_process_func,
        inputs=[
            audio_duration,
            prompt,
            lyrics,
            infer_step,
            guidance_scale,
            scheduler_type,
            cfg_type,
            omega_scale,
            manual_seeds,
            guidance_interval,
            guidance_interval_decay,
            min_guidance_scale,
            use_erg_tag,
            use_erg_lyric,
            use_erg_diffusion,
            oss_steps,
            guidance_scale_text,
            guidance_scale_lyric,
        ],
        outputs=outputs + [input_params_json],
    )


def create_main_demo_ui(
    text2music_process_func=dump_func,
    sample_data_func=dump_func,
):
    with gr.Blocks(
        title="FusicModel 1.0 DEMO",
    ) as demo:
        gr.Markdown(
            """
            <h1 style="text-align: center;">FusicModel 1.0 DEMO</h1>
            """
        )

        with gr.Tab("text2music"):
            create_text2music_ui(
                gr=gr,
                text2music_process_func=text2music_process_func,
                sample_data_func=sample_data_func,
            )
    return demo


if __name__ == "__main__":
    demo = create_main_demo_ui()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
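
create_main_demo_ui only wires the Gradio layout; the actual generation callback is supplied by the caller (in this Space, presumably by app.py via the ACE-Step pipeline). A minimal sketch of that wiring, assuming the repository root is on the import path and that the callback takes the same positional arguments as the `inputs` list above and returns one audio filepath plus a parameter dict; the stub below is hypothetical and only echoes its inputs:

from ui.components import create_main_demo_ui


def fake_text2music(*args):
    # Positional args arrive in the order of the `inputs` list wired to
    # text2music_bnt: audio_duration, prompt, lyrics, infer_step, ...
    # A real handler runs the diffusion pipeline and returns
    # (audio_filepath, input_params_dict); here we return no audio and
    # simply surface the received parameters in the JSON accordion.
    params = {"received_args": [str(a) for a in args]}
    return None, params


if __name__ == "__main__":
    demo = create_main_demo_ui(text2music_process_func=fake_text2music)
    demo.launch(server_name="0.0.0.0", server_port=7860)

Because the outputs list is `outputs + [input_params_json]` (one audio component plus the JSON panel), any replacement callback must return exactly those two values in that order.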