Erlanggaa committed on
Commit
93d3f62
·
verified ·
1 Parent(s): b0257db

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ myenv/bin/python filter=lfs diff=lfs merge=lfs -text
37
+ myenv/bin/python3 filter=lfs diff=lfs merge=lfs -text
38
+ myenv/bin/python3.10 filter=lfs diff=lfs merge=lfs -text
.github/workflows/sync-hf.yaml ADDED
@@ -0,0 +1,18 @@
1
+ name: Sync to HF Space
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+
8
+ jobs:
9
+ trigger_curl:
10
+ runs-on: ubuntu-latest
11
+
12
+ steps:
13
+ - name: Send cURL POST request
14
+ run: |
15
+ curl -X POST https://mrfakename-sync-f5.hf.space/gradio_api/call/refresh \
16
+ -s \
17
+ -H "Content-Type: application/json" \
18
+ -d "{\"data\": [\"${{ secrets.REFRESH_PASSWORD }}\"]}"
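For reference, the refresh call made by this workflow can also be reproduced locally. The sketch below is a hypothetical Python equivalent of the cURL request above; the `REFRESH_PASSWORD` value is a repository secret and is assumed here to be available as an environment variable.

```python
# Hypothetical local equivalent of the workflow's cURL request.
# REFRESH_PASSWORD is a repository secret; it is assumed to be set in the environment.
import os
import requests

response = requests.post(
    "https://mrfakename-sync-f5.hf.space/gradio_api/call/refresh",
    json={"data": [os.environ.get("REFRESH_PASSWORD", "")]},
    timeout=30,
)
print(response.status_code, response.text)
```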
.gitignore ADDED
@@ -0,0 +1,173 @@
1
+ # Custom
2
+ .vscode/
3
+ tests/
4
+ runs/
5
+ data/
6
+ ckpts/
7
+ wandb/
8
+ results/
9
+
10
+
11
+
12
+ # Byte-compiled / optimized / DLL files
13
+ __pycache__/
14
+ *.py[cod]
15
+ *$py.class
16
+
17
+ # C extensions
18
+ *.so
19
+
20
+ # Distribution / packaging
21
+ .Python
22
+ build/
23
+ develop-eggs/
24
+ dist/
25
+ downloads/
26
+ eggs/
27
+ .eggs/
28
+ lib/
29
+ lib64/
30
+ parts/
31
+ sdist/
32
+ var/
33
+ wheels/
34
+ share/python-wheels/
35
+ *.egg-info/
36
+ .installed.cfg
37
+ *.egg
38
+ MANIFEST
39
+
40
+ # PyInstaller
41
+ # Usually these files are written by a python script from a template
42
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
43
+ *.manifest
44
+ *.spec
45
+
46
+ # Installer logs
47
+ pip-log.txt
48
+ pip-delete-this-directory.txt
49
+
50
+ # Unit test / coverage reports
51
+ htmlcov/
52
+ .tox/
53
+ .nox/
54
+ .coverage
55
+ .coverage.*
56
+ .cache
57
+ nosetests.xml
58
+ coverage.xml
59
+ *.cover
60
+ *.py,cover
61
+ .hypothesis/
62
+ .pytest_cache/
63
+ cover/
64
+
65
+ # Translations
66
+ *.mo
67
+ *.pot
68
+
69
+ # Django stuff:
70
+ *.log
71
+ local_settings.py
72
+ db.sqlite3
73
+ db.sqlite3-journal
74
+
75
+ # Flask stuff:
76
+ instance/
77
+ .webassets-cache
78
+
79
+ # Scrapy stuff:
80
+ .scrapy
81
+
82
+ # Sphinx documentation
83
+ docs/_build/
84
+
85
+ # PyBuilder
86
+ .pybuilder/
87
+ target/
88
+
89
+ # Jupyter Notebook
90
+ .ipynb_checkpoints
91
+
92
+ # IPython
93
+ profile_default/
94
+ ipython_config.py
95
+
96
+ # pyenv
97
+ # For a library or package, you might want to ignore these files since the code is
98
+ # intended to run in multiple environments; otherwise, check them in:
99
+ # .python-version
100
+
101
+ # pipenv
102
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
104
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
105
+ # install all needed dependencies.
106
+ #Pipfile.lock
107
+
108
+ # poetry
109
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
110
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
111
+ # commonly ignored for libraries.
112
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
113
+ #poetry.lock
114
+
115
+ # pdm
116
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
117
+ #pdm.lock
118
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
119
+ # in version control.
120
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
121
+ .pdm.toml
122
+ .pdm-python
123
+ .pdm-build/
124
+
125
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
126
+ __pypackages__/
127
+
128
+ # Celery stuff
129
+ celerybeat-schedule
130
+ celerybeat.pid
131
+
132
+ # SageMath parsed files
133
+ *.sage.py
134
+
135
+ # Environments
136
+ .env
137
+ .venv
138
+ env/
139
+ venv/
140
+ ENV/
141
+ env.bak/
142
+ venv.bak/
143
+
144
+ # Spyder project settings
145
+ .spyderproject
146
+ .spyproject
147
+
148
+ # Rope project settings
149
+ .ropeproject
150
+
151
+ # mkdocs documentation
152
+ /site
153
+
154
+ # mypy
155
+ .mypy_cache/
156
+ .dmypy.json
157
+ dmypy.json
158
+
159
+ # Pyre type checker
160
+ .pyre/
161
+
162
+ # pytype static type analyzer
163
+ .pytype/
164
+
165
+ # Cython debug symbols
166
+ cython_debug/
167
+
168
+ # PyCharm
169
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
170
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
171
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
172
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
173
+ #.idea/
Dockerfile ADDED
@@ -0,0 +1,25 @@
1
+ FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel
2
+
3
+ USER root
4
+
5
+ ARG DEBIAN_FRONTEND=noninteractive
6
+
7
+ LABEL github_repo="https://github.com/SWivid/F5-TTS"
8
+
9
+ RUN set -x \
10
+ && apt-get update \
11
+ && apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \
12
+ && apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \
13
+ && rm -rf /var/lib/apt/lists/* \
14
+ && apt-get clean
15
+
16
+ WORKDIR /workspace
17
+
18
+ RUN git clone https://github.com/SWivid/F5-TTS.git \
19
+ && cd F5-TTS \
20
+ && pip install --no-cache-dir -r requirements.txt \
21
+ && pip install --no-cache-dir -r requirements_eval.txt
22
+
23
+ ENV SHELL=/bin/bash
24
+
25
+ WORKDIR /workspace/F5-TTS
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Yushen CHEN
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,226 @@
1
+ <<<<<<< HEAD
2
+ # F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching
3
+
4
+ [![python](https://img.shields.io/badge/Python-3.10-brightgreen)](https://github.com/SWivid/F5-TTS)
5
+ [![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885)
6
+ [![demo](https://img.shields.io/badge/GitHub-Demo%20page-orange.svg)](https://swivid.github.io/F5-TTS/)
7
+ [![hfspace](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
8
+ [![msspace](https://img.shields.io/badge/🤖-Space%20demo-blue)](https://modelscope.cn/studios/modelscope/E2-F5-TTS)
9
+ [![lab](https://img.shields.io/badge/X--LANCE-Lab-grey?labelColor=lightgrey)](https://x-lance.sjtu.edu.cn/)
10
+ <img src="https://github.com/user-attachments/assets/12d7749c-071a-427c-81bf-b87b91def670" alt="Watermark" style="width: 40px; height: auto">
11
+
12
+ **F5-TTS**: Diffusion Transformer with ConvNeXt V2, offering faster training and inference.
13
+
14
+ **E2 TTS**: Flat-UNet Transformer, the closest reproduction of the [paper](https://arxiv.org/abs/2406.18009).
15
+
16
+ **Sway Sampling**: Inference-time flow-step sampling strategy that greatly improves performance.
17
+
18
+ ### Thanks to all the contributors !
19
+
20
+ ## Installation
21
+
22
+ Clone the repository:
23
+
24
+ ```bash
25
+ git clone https://github.com/SWivid/F5-TTS.git
26
+ cd F5-TTS
27
+ ```
28
+
29
+ Install torch matching your CUDA version, e.g.:
30
+
31
+ ```bash
32
+ pip install torch==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
33
+ pip install torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
34
+ ```
35
+
36
+ Install other packages:
37
+
38
+ ```bash
39
+ pip install -r requirements.txt
40
+ ```
41
+
42
+ **[Optional]**: We provide a [Dockerfile](https://github.com/SWivid/F5-TTS/blob/main/Dockerfile); you can build the image with the following command.
43
+ ```bash
44
+ docker build -t f5tts:v1 .
45
+ ```
46
+
47
+ ## Prepare Dataset
48
+
49
+ Example data processing scripts are provided for Emilia and WenetSpeech4TTS; you can tailor your own along with a Dataset class in `model/dataset.py` (a minimal sketch of the expected output layout follows the commands below).
50
+
51
+ ```bash
52
+ # prepare custom dataset up to your need
53
+ # download corresponding dataset first, and fill in the path in scripts
54
+
55
+ # Prepare the Emilia dataset
56
+ python scripts/prepare_emilia.py
57
+
58
+ # Prepare the Wenetspeech4TTS dataset
59
+ python scripts/prepare_wenetspeech4tts.py
60
+ ```
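If you prepare a custom dataset yourself, the sketch below shows one possible way to produce it. It assumes the `raw.arrow` / `duration.json` layout written by `finetune_gradio.py` later in this commit; the project path and sample metadata are placeholders.

```python
# Minimal sketch (assumed layout): write a custom dataset in the raw.arrow /
# duration.json format expected by model/dataset.py and the finetuning tools.
import json
import os

from datasets.arrow_writer import ArrowWriter

path_project = "data/my_speak_pinyin"  # placeholder project directory
os.makedirs(path_project, exist_ok=True)

# (audio_path, transcript, duration in seconds) triples prepared by your own pipeline
samples = [
    ("data/my_speak_pinyin/wavs/segment_0.wav", "hello world", 3.2),
    ("data/my_speak_pinyin/wavs/segment_1.wav", "another training sample", 4.7),
]

duration_list = []
with ArrowWriter(path=os.path.join(path_project, "raw.arrow"), writer_batch_size=1) as writer:
    for audio_path, text, duration in samples:
        writer.write({"audio_path": audio_path, "text": text, "duration": duration})
        duration_list.append(duration)

with open(os.path.join(path_project, "duration.json"), "w", encoding="utf-8") as f:
    json.dump({"duration": duration_list}, f, ensure_ascii=False)
```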
61
+
62
+ ## Training & Finetuning
63
+
64
+ Once your datasets are prepared, you can start the training process.
65
+
66
+ ```bash
67
+ # setup accelerate config, e.g. use multi-gpu ddp, fp16
68
+ # will be to: ~/.cache/huggingface/accelerate/default_config.yaml
69
+ accelerate config
70
+ accelerate launch train.py
71
+ ```
72
+ Initial guidance on finetuning is available in [#57](https://github.com/SWivid/F5-TTS/discussions/57).
73
+
74
+ For Gradio UI finetuning with `finetune_gradio.py`, see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
75
+
76
+ ## Inference
77
+
78
+ The pretrained model checkpoints are available at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or are downloaded automatically by `inference-cli` and `gradio_app`.
79
+
80
+ A single generation currently supports up to 30 seconds, which is the **TOTAL** length of the prompt audio plus the generated speech. Batch inference with chunks is supported by `inference-cli` and `gradio_app`.
81
+ - To avoid possible inference failures, make sure you read through the following notes (a sketch of the duration estimate follows this list).
82
+ - A longer prompt audio leaves less room for the generated output; anything beyond 30 seconds cannot be generated properly. Consider using a prompt audio shorter than 15 seconds.
83
+ - Uppercase letters are uttered letter by letter, so use lowercase letters for normal words.
84
+ - Add spaces (" ") or punctuation (e.g. "," ".") to explicitly introduce pauses. If the first few words are skipped in code-switched generation (different languages are spoken at different speeds), this can help.
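As a rough guide, the sketch below approximately mirrors the duration heuristic used in `gradio_app.py` (`infer_batch`) to estimate the total prompt-plus-generated length; the reference duration and texts are placeholders, and a character class is used here for the pause-punctuation count.

```python
# Sketch of the duration heuristic from gradio_app.py (infer_batch); values are placeholders.
import re

HOP_LENGTH = 256        # mel hop length used by the models
SAMPLE_RATE = 24000     # target sample rate
ZH_PAUSE_PUNC = "[。，、；：？！]"

def text_weight(text: str) -> int:
    # byte length, with extra weight for Chinese pause punctuation
    return len(text.encode("utf-8")) + 3 * len(re.findall(ZH_PAUSE_PUNC, text))

def estimated_total_seconds(ref_audio_seconds: float, ref_text: str, gen_text: str, speed: float = 1.0) -> float:
    ref_frames = int(ref_audio_seconds * SAMPLE_RATE) // HOP_LENGTH
    gen_frames = int(ref_frames / text_weight(ref_text) * text_weight(gen_text) / speed)
    return (ref_frames + gen_frames) * HOP_LENGTH / SAMPLE_RATE

# A ~8 s prompt plus this generation text stays well under the 30 s budget (~14 s total).
print(estimated_total_seconds(8.0,
                              "Some call me nature, others call me mother nature.",
                              "I don't really care what you call me."))
```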
85
+
86
+ ### CLI Inference
87
+
88
+ You can either specify everything in `inference-cli.toml` or override it with flags. Leaving `--ref_text ""` lets the ASR model transcribe the reference audio automatically (this uses extra GPU memory). If you encounter a network error, consider using a local checkpoint by setting `ckpt_path` in `inference-cli.py`.
89
+
90
+ To change the model, use `--ckpt_file` to specify the checkpoint you want to load;
91
+ to change the vocabulary, use `--vocab_file` to provide your own `vocab.txt` file.
92
+
93
+ ```bash
94
+ python inference-cli.py \
95
+ --model "F5-TTS" \
96
+ --ref_audio "tests/ref_audio/test_en_1_ref_short.wav" \
97
+ --ref_text "Some call me nature, others call me mother nature." \
98
+ --gen_text "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
99
+
100
+ python inference-cli.py \
101
+ --model "E2-TTS" \
102
+ --ref_audio "tests/ref_audio/test_zh_1_ref_short.wav" \
103
+ --ref_text "对,这就是我,万人敬仰的太乙真人。" \
104
+ --gen_text "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道,我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"
105
+
106
+ # Multi voice
107
+ python inference-cli.py -c samples/story.toml
108
+ ```
109
+
110
+ ### Gradio App
111
+ Currently supported features:
112
+ - Chunk inference
113
+ - Podcast Generation
114
+ - Multiple Speech-Type Generation
115
+
116
+ You can launch a Gradio web interface for inference (it loads checkpoints from Hugging Face; you may set `ckpt_path` to a local file in `gradio_app.py`). It currently loads the ASR model, F5-TTS, and E2 TTS all at once, so it uses more GPU memory than `inference-cli`.
117
+
118
+ ```bash
119
+ python gradio_app.py
120
+ ```
121
+
122
+ You can specify the port/host:
123
+
124
+ ```bash
125
+ python gradio_app.py --port 7860 --host 0.0.0.0
126
+ ```
127
+
128
+ Or launch a share link:
129
+
130
+ ```bash
131
+ python gradio_app.py --share
132
+ ```
133
+
134
+ ### Speech Editing
135
+
136
+ To test speech editing capabilities, use the following command.
137
+
138
+ ```bash
139
+ python speech_edit.py
140
+ ```
141
+
142
+ ## Evaluation
143
+
144
+ ### Prepare Test Datasets
145
+
146
+ 1. Seed-TTS test set: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
147
+ 2. LibriSpeech test-clean: Download from [OpenSLR](http://www.openslr.org/12/).
148
+ 3. Unzip the downloaded datasets and place them in the data/ directory.
149
+ 4. Update the path for the test-clean data in `scripts/eval_infer_batch.py`
150
+ 5. Our filtered LibriSpeech-PC 4-10s subset is already under data/ in this repo
151
+
152
+ ### Batch Inference for Test Set
153
+
154
+ To run batch inference for evaluations, execute the following commands:
155
+
156
+ ```bash
157
+ # batch inference for evaluations
158
+ accelerate config # if not set before
159
+ bash scripts/eval_infer_batch.sh
160
+ ```
161
+
162
+ ### Download Evaluation Model Checkpoints
163
+
164
+ 1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
165
+ 2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
166
+ 3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
167
+
168
+ ### Objective Evaluation
169
+
170
+ Install packages for evaluation:
171
+
172
+ ```bash
173
+ pip install -r requirements_eval.txt
174
+ ```
175
+
176
+ **Some Notes**
177
+
178
+ For faster-whisper with CUDA 11:
179
+
180
+ ```bash
181
+ pip install --force-reinstall ctranslate2==3.24.0
182
+ ```
183
+
184
+ (Recommended) To avoid possible ASR failures, such as abnormal repetitions in output:
185
+
186
+ ```bash
187
+ pip install faster-whisper==0.10.1
188
+ ```
189
+
190
+ Update the path to your batch inference results, then run the WER / SIM evaluations:
191
+ ```bash
192
+ # Evaluation for Seed-TTS test set
193
+ python scripts/eval_seedtts_testset.py
194
+
195
+ # Evaluation for LibriSpeech-PC test-clean (cross-sentence)
196
+ python scripts/eval_librispeech_test_clean.py
197
+ ```
198
+
199
+ ## Acknowledgements
200
+
201
+ - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
202
+ - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets
203
+ - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
204
+ - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
205
+ - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
206
+ - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
207
+ - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
208
+ - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
209
+ - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation of F5-TTS, with the MLX framework.
210
+
211
+ ## Citation
212
+ If our work and codebase are useful for you, please cite as follows:
213
+ ```
214
+ @article{chen-etal-2024-f5tts,
215
+ title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching},
216
+ author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen},
217
+ journal={arXiv preprint arXiv:2410.06885},
218
+ year={2024},
219
+ }
220
+ ```
221
+ ## License
222
+
223
+ Our code is released under the MIT License. The pre-trained models are licensed under CC-BY-NC because the training data, Emilia, is an in-the-wild dataset. Sorry for any inconvenience this may cause.
224
+ =======
225
+ # TTS
226
+ >>>>>>> e88b7d3df9854aa4bbc3db7b64fabc8be3e82f6a
finetune-cli.py ADDED
@@ -0,0 +1,108 @@
1
+ import argparse
2
+ from model import CFM, UNetT, DiT, MMDiT, Trainer
3
+ from model.utils import get_tokenizer
4
+ from model.dataset import load_dataset
5
+ from cached_path import cached_path
6
+ import shutil,os
7
+ # -------------------------- Dataset Settings --------------------------- #
8
+ target_sample_rate = 24000
9
+ n_mel_channels = 100
10
+ hop_length = 256
11
+
12
+ tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
13
+ tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
14
+
15
+ # -------------------------- Argument Parsing --------------------------- #
16
+ def parse_args():
17
+ parser = argparse.ArgumentParser(description='Train CFM Model')
18
+
19
+ parser.add_argument('--exp_name', type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"],help='Experiment name')
20
+ parser.add_argument('--dataset_name', type=str, default="Emilia_ZH_EN", help='Name of the dataset to use')
21
+ parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate for training')
22
+ parser.add_argument('--batch_size_per_gpu', type=int, default=256, help='Batch size per GPU')
23
+ parser.add_argument('--batch_size_type', type=str, default="frame", choices=["frame", "sample"],help='Batch size type')
24
+ parser.add_argument('--max_samples', type=int, default=16, help='Max sequences per batch')
25
+ parser.add_argument('--grad_accumulation_steps', type=int, default=1,help='Gradient accumulation steps')
26
+ parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
27
+ parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
28
+ parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
29
+ parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
30
+ parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
31
+ parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune')
32
+
33
+ return parser.parse_args()
34
+
35
+ # -------------------------- Training Settings -------------------------- #
36
+
37
+ def main():
38
+ args = parse_args()
39
+
40
+
41
+ # Model parameters based on experiment name
42
+ if args.exp_name == "F5TTS_Base":
43
+ wandb_resume_id = None
44
+ model_cls = DiT
45
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
46
+ if args.finetune:
47
+ ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
48
+ elif args.exp_name == "E2TTS_Base":
49
+ wandb_resume_id = None
50
+ model_cls = UNetT
51
+ model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
52
+ if args.finetune:
53
+ ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
54
+
55
+ if args.finetune:
56
+ path_ckpt = os.path.join("ckpts",args.dataset_name)
57
+ if os.path.isdir(path_ckpt)==False:
58
+ os.makedirs(path_ckpt,exist_ok=True)
59
+ shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
60
+
61
+ checkpoint_path=os.path.join("ckpts",args.dataset_name)
62
+
63
+ # Use the dataset_name provided in the command line
64
+ tokenizer_path = args.dataset_name if tokenizer != "custom" else tokenizer_path
65
+ vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
66
+
67
+ mel_spec_kwargs = dict(
68
+ target_sample_rate=target_sample_rate,
69
+ n_mel_channels=n_mel_channels,
70
+ hop_length=hop_length,
71
+ )
72
+
73
+ e2tts = CFM(
74
+ transformer=model_cls(
75
+ **model_cfg,
76
+ text_num_embeds=vocab_size,
77
+ mel_dim=n_mel_channels
78
+ ),
79
+ mel_spec_kwargs=mel_spec_kwargs,
80
+ vocab_char_map=vocab_char_map,
81
+ )
82
+
83
+ trainer = Trainer(
84
+ e2tts,
85
+ args.epochs,
86
+ args.learning_rate,
87
+ num_warmup_updates=args.num_warmup_updates,
88
+ save_per_updates=args.save_per_updates,
89
+ checkpoint_path=checkpoint_path,
90
+ batch_size=args.batch_size_per_gpu,
91
+ batch_size_type=args.batch_size_type,
92
+ max_samples=args.max_samples,
93
+ grad_accumulation_steps=args.grad_accumulation_steps,
94
+ max_grad_norm=args.max_grad_norm,
95
+ wandb_project="CFM-TTS",
96
+ wandb_run_name=args.exp_name,
97
+ wandb_resume_id=wandb_resume_id,
98
+ last_per_steps=args.last_per_steps,
99
+ )
100
+
101
+ train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
102
+ trainer.train(train_dataset,
103
+ resumable_with_seed=666 # seed for shuffling dataset
104
+ )
105
+
106
+
107
+ if __name__ == '__main__':
108
+ main()
finetune_gradio.py ADDED
@@ -0,0 +1,734 @@
1
+ import os,sys
2
+
3
+ from transformers import pipeline
4
+ import gradio as gr
5
+ import torch
6
+ import click
7
+ import torchaudio
8
+ from glob import glob
9
+ import librosa
10
+ import numpy as np
11
+ from scipy.io import wavfile
12
+ import shutil
13
+ import time
14
+
15
+ import json
16
+ from model.utils import convert_char_to_pinyin
17
+ import signal
18
+ import psutil
19
+ import platform
20
+ import subprocess
21
+ from datasets.arrow_writer import ArrowWriter
22
+
23
+ import json
24
+
25
+ training_process = None
26
+ system = platform.system()
27
+ python_executable = sys.executable or "python"
28
+
29
+ path_data="data"
30
+
31
+ device = (
32
+ "cuda"
33
+ if torch.cuda.is_available()
34
+ else "mps" if torch.backends.mps.is_available() else "cpu"
35
+ )
36
+
37
+ pipe = None
38
+
39
+ # Load metadata
40
+ def get_audio_duration(audio_path):
41
+ """Calculate the duration of an audio file."""
42
+ audio, sample_rate = torchaudio.load(audio_path)
43
+ num_channels = audio.shape[0]
44
+ return audio.shape[1] / (sample_rate * num_channels)
45
+
46
+ def clear_text(text):
47
+ """Clean and prepare text by lowering the case and stripping whitespace."""
48
+ return text.lower().strip()
49
+
50
+ def get_rms(y,frame_length=2048,hop_length=512,pad_mode="constant",): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
51
+ padding = (int(frame_length // 2), int(frame_length // 2))
52
+ y = np.pad(y, padding, mode=pad_mode)
53
+
54
+ axis = -1
55
+ # put our new within-frame axis at the end for now
56
+ out_strides = y.strides + tuple([y.strides[axis]])
57
+ # Reduce the shape on the framing axis
58
+ x_shape_trimmed = list(y.shape)
59
+ x_shape_trimmed[axis] -= frame_length - 1
60
+ out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
61
+ xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
62
+ if axis < 0:
63
+ target_axis = axis - 1
64
+ else:
65
+ target_axis = axis + 1
66
+ xw = np.moveaxis(xw, -1, target_axis)
67
+ # Downsample along the target axis
68
+ slices = [slice(None)] * xw.ndim
69
+ slices[axis] = slice(0, None, hop_length)
70
+ x = xw[tuple(slices)]
71
+
72
+ # Calculate power
73
+ power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
74
+
75
+ return np.sqrt(power)
76
+
77
+ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
78
+ def __init__(
79
+ self,
80
+ sr: int,
81
+ threshold: float = -40.0,
82
+ min_length: int = 2000,
83
+ min_interval: int = 300,
84
+ hop_size: int = 20,
85
+ max_sil_kept: int = 2000,
86
+ ):
87
+ if not min_length >= min_interval >= hop_size:
88
+ raise ValueError(
89
+ "The following condition must be satisfied: min_length >= min_interval >= hop_size"
90
+ )
91
+ if not max_sil_kept >= hop_size:
92
+ raise ValueError(
93
+ "The following condition must be satisfied: max_sil_kept >= hop_size"
94
+ )
95
+ min_interval = sr * min_interval / 1000
96
+ self.threshold = 10 ** (threshold / 20.0)
97
+ self.hop_size = round(sr * hop_size / 1000)
98
+ self.win_size = min(round(min_interval), 4 * self.hop_size)
99
+ self.min_length = round(sr * min_length / 1000 / self.hop_size)
100
+ self.min_interval = round(min_interval / self.hop_size)
101
+ self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
102
+
103
+ def _apply_slice(self, waveform, begin, end):
104
+ if len(waveform.shape) > 1:
105
+ return waveform[
106
+ :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)
107
+ ]
108
+ else:
109
+ return waveform[
110
+ begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)
111
+ ]
112
+
113
+ # @timeit
114
+ def slice(self, waveform):
115
+ if len(waveform.shape) > 1:
116
+ samples = waveform.mean(axis=0)
117
+ else:
118
+ samples = waveform
119
+ if samples.shape[0] <= self.min_length:
120
+ return [waveform]
121
+ rms_list = get_rms(
122
+ y=samples, frame_length=self.win_size, hop_length=self.hop_size
123
+ ).squeeze(0)
124
+ sil_tags = []
125
+ silence_start = None
126
+ clip_start = 0
127
+ for i, rms in enumerate(rms_list):
128
+ # Keep looping while frame is silent.
129
+ if rms < self.threshold:
130
+ # Record start of silent frames.
131
+ if silence_start is None:
132
+ silence_start = i
133
+ continue
134
+ # Keep looping while frame is not silent and silence start has not been recorded.
135
+ if silence_start is None:
136
+ continue
137
+ # Clear recorded silence start if interval is not enough or clip is too short
138
+ is_leading_silence = silence_start == 0 and i > self.max_sil_kept
139
+ need_slice_middle = (
140
+ i - silence_start >= self.min_interval
141
+ and i - clip_start >= self.min_length
142
+ )
143
+ if not is_leading_silence and not need_slice_middle:
144
+ silence_start = None
145
+ continue
146
+ # Need slicing. Record the range of silent frames to be removed.
147
+ if i - silence_start <= self.max_sil_kept:
148
+ pos = rms_list[silence_start : i + 1].argmin() + silence_start
149
+ if silence_start == 0:
150
+ sil_tags.append((0, pos))
151
+ else:
152
+ sil_tags.append((pos, pos))
153
+ clip_start = pos
154
+ elif i - silence_start <= self.max_sil_kept * 2:
155
+ pos = rms_list[
156
+ i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
157
+ ].argmin()
158
+ pos += i - self.max_sil_kept
159
+ pos_l = (
160
+ rms_list[
161
+ silence_start : silence_start + self.max_sil_kept + 1
162
+ ].argmin()
163
+ + silence_start
164
+ )
165
+ pos_r = (
166
+ rms_list[i - self.max_sil_kept : i + 1].argmin()
167
+ + i
168
+ - self.max_sil_kept
169
+ )
170
+ if silence_start == 0:
171
+ sil_tags.append((0, pos_r))
172
+ clip_start = pos_r
173
+ else:
174
+ sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
175
+ clip_start = max(pos_r, pos)
176
+ else:
177
+ pos_l = (
178
+ rms_list[
179
+ silence_start : silence_start + self.max_sil_kept + 1
180
+ ].argmin()
181
+ + silence_start
182
+ )
183
+ pos_r = (
184
+ rms_list[i - self.max_sil_kept : i + 1].argmin()
185
+ + i
186
+ - self.max_sil_kept
187
+ )
188
+ if silence_start == 0:
189
+ sil_tags.append((0, pos_r))
190
+ else:
191
+ sil_tags.append((pos_l, pos_r))
192
+ clip_start = pos_r
193
+ silence_start = None
194
+ # Deal with trailing silence.
195
+ total_frames = rms_list.shape[0]
196
+ if (
197
+ silence_start is not None
198
+ and total_frames - silence_start >= self.min_interval
199
+ ):
200
+ silence_end = min(total_frames, silence_start + self.max_sil_kept)
201
+ pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
202
+ sil_tags.append((pos, total_frames + 1))
203
+ # Apply and return slices.
204
+ ####音频+起始时间+终止时间
205
+ if len(sil_tags) == 0:
206
+ return [[waveform,0,int(total_frames*self.hop_size)]]
207
+ else:
208
+ chunks = []
209
+ if sil_tags[0][0] > 0:
210
+ chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]),0,int(sil_tags[0][0]*self.hop_size)])
211
+ for i in range(len(sil_tags) - 1):
212
+ chunks.append(
213
+ [self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),int(sil_tags[i][1]*self.hop_size),int(sil_tags[i + 1][0]*self.hop_size)]
214
+ )
215
+ if sil_tags[-1][1] < total_frames:
216
+ chunks.append(
217
+ [self._apply_slice(waveform, sil_tags[-1][1], total_frames),int(sil_tags[-1][1]*self.hop_size),int(total_frames*self.hop_size)]
218
+ )
219
+ return chunks
220
+
221
+ #terminal
222
+ def terminate_process_tree(pid, including_parent=True):
223
+ try:
224
+ parent = psutil.Process(pid)
225
+ except psutil.NoSuchProcess:
226
+ # Process already terminated
227
+ return
228
+
229
+ children = parent.children(recursive=True)
230
+ for child in children:
231
+ try:
232
+ os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
233
+ except OSError:
234
+ pass
235
+ if including_parent:
236
+ try:
237
+ os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
238
+ except OSError:
239
+ pass
240
+
241
+ def terminate_process(pid):
242
+ if system == "Windows":
243
+ cmd = f"taskkill /t /f /pid {pid}"
244
+ os.system(cmd)
245
+ else:
246
+ terminate_process_tree(pid)
247
+
248
+ def start_training(dataset_name="",
249
+ exp_name="F5TTS_Base",
250
+ learning_rate=1e-4,
251
+ batch_size_per_gpu=400,
252
+ batch_size_type="frame",
253
+ max_samples=64,
254
+ grad_accumulation_steps=1,
255
+ max_grad_norm=1.0,
256
+ epochs=11,
257
+ num_warmup_updates=200,
258
+ save_per_updates=400,
259
+ last_per_steps=800,
260
+ finetune=True,
261
+ ):
262
+
263
+
264
+ global training_process
265
+
266
+ path_project = os.path.join(path_data, dataset_name + "_pinyin")
267
+
268
+ if os.path.isdir(path_project)==False:
269
+ yield f"There is not project with name {dataset_name}",gr.update(interactive=True),gr.update(interactive=False)
270
+ return
271
+
272
+ file_raw = os.path.join(path_project,"raw.arrow")
273
+ if os.path.isfile(file_raw)==False:
274
+ yield f"There is no file {file_raw}",gr.update(interactive=True),gr.update(interactive=False)
275
+ return
276
+
277
+ # Check if a training process is already running
278
+ if training_process is not None:
279
+ return "Train run already!",gr.update(interactive=False),gr.update(interactive=True)
280
+
281
+ yield "start train",gr.update(interactive=False),gr.update(interactive=False)
282
+
283
+ # Command to run the training script with the specified arguments
284
+ cmd = f"accelerate launch finetune-cli.py --exp_name {exp_name} " \
285
+ f"--learning_rate {learning_rate} " \
286
+ f"--batch_size_per_gpu {batch_size_per_gpu} " \
287
+ f"--batch_size_type {batch_size_type} " \
288
+ f"--max_samples {max_samples} " \
289
+ f"--grad_accumulation_steps {grad_accumulation_steps} " \
290
+ f"--max_grad_norm {max_grad_norm} " \
291
+ f"--epochs {epochs} " \
292
+ f"--num_warmup_updates {num_warmup_updates} " \
293
+ f"--save_per_updates {save_per_updates} " \
294
+ f"--last_per_steps {last_per_steps} " \
295
+ f"--dataset_name {dataset_name}"
296
+ if finetune:cmd += f" --finetune {finetune}"
297
+
298
+ print(cmd)
299
+
300
+ try:
301
+ # Start the training process
302
+ training_process = subprocess.Popen(cmd, shell=True)
303
+
304
+ time.sleep(5)
305
+ yield "check terminal for wandb",gr.update(interactive=False),gr.update(interactive=True)
306
+
307
+ # Wait for the training process to finish
308
+ training_process.wait()
309
+ time.sleep(1)
310
+
311
+ if training_process is None:
312
+ text_info = 'train stop'
313
+ else:
314
+ text_info = "train complete !"
315
+
316
+ except Exception as e: # Catch all exceptions
317
+ # Ensure that we reset the training process variable in case of an error
318
+ text_info=f"An error occurred: {str(e)}"
319
+
320
+ training_process=None
321
+
322
+ yield text_info,gr.update(interactive=True),gr.update(interactive=False)
323
+
324
+ def stop_training():
325
+ global training_process
326
+ if training_process is None:return f"Train not run !",gr.update(interactive=True),gr.update(interactive=False)
327
+ terminate_process_tree(training_process.pid)
328
+ training_process = None
329
+ return 'train stop',gr.update(interactive=True),gr.update(interactive=False)
330
+
331
+ def create_data_project(name):
332
+ name+="_pinyin"
333
+ os.makedirs(os.path.join(path_data,name),exist_ok=True)
334
+ os.makedirs(os.path.join(path_data,name,"dataset"),exist_ok=True)
335
+
336
+ def transcribe(file_audio,language="indonesian"):
337
+ global pipe
338
+
339
+ if pipe is None:
340
+ pipe = pipeline("automatic-speech-recognition",model="openai/whisper-large-v3-turbo", torch_dtype=torch.float16,device=device)
341
+
342
+ text_transcribe = pipe(
343
+ file_audio,
344
+ chunk_length_s=30,
345
+ batch_size=128,
346
+ generate_kwargs={"task": "transcribe","language": language},
347
+ return_timestamps=False,
348
+ )["text"].strip()
349
+ return text_transcribe
350
+
351
+ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Progress()):
352
+ name_project+="_pinyin"
353
+ path_project= os.path.join(path_data,name_project)
354
+ path_dataset = os.path.join(path_project,"dataset")
355
+ path_project_wavs = os.path.join(path_project,"wavs")
356
+ file_metadata = os.path.join(path_project,"metadata.csv")
357
+
358
+ if audio_files is None:return "You need to load an audio file."
359
+
360
+ if os.path.isdir(path_project_wavs):
361
+ shutil.rmtree(path_project_wavs)
362
+
363
+ if os.path.isfile(file_metadata):
364
+ os.remove(file_metadata)
365
+
366
+ os.makedirs(path_project_wavs,exist_ok=True)
367
+
368
+ if user:
369
+ file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))]
370
+ if file_audios==[]:return "No audio file was found in the dataset."
371
+ else:
372
+ file_audios = audio_files
373
+
374
+
375
+ alpha = 0.5
376
+ _max = 1.0
377
+ slicer = Slicer(24000)
378
+
379
+ num = 0
380
+ error_num = 0
381
+ data=""
382
+ for file_audio in progress.tqdm(file_audios, desc="transcribe files",total=len((file_audios))):
383
+
384
+ audio, _ = librosa.load(file_audio, sr=24000, mono=True)
385
+
386
+ list_slicer=slicer.slice(audio)
387
+ for chunk, start, end in progress.tqdm(list_slicer,total=len(list_slicer), desc="slicer files"):
388
+
389
+ name_segment = os.path.join(f"segment_{num}")
390
+ file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")
391
+
392
+ tmp_max = np.abs(chunk).max()
393
+ if(tmp_max>1):chunk/=tmp_max
394
+ chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
395
+ wavfile.write(file_segment,24000, (chunk * 32767).astype(np.int16))
396
+
397
+ try:
398
+ text=transcribe(file_segment,language)
399
+ text = text.lower().strip().replace('"',"")
400
+
401
+ data+= f"{name_segment}|{text}\n"
402
+
403
+ num+=1
404
+ except:
405
+ error_num +=1
406
+
407
+ with open(file_metadata,"w",encoding="utf-8") as f:
408
+ f.write(data)
409
+
410
+ if error_num!=0:
411
+ error_text=f"\nerror files : {error_num}"
412
+ else:
413
+ error_text=""
414
+
415
+ return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
416
+
417
+ def format_seconds_to_hms(seconds):
418
+ hours = int(seconds / 3600)
419
+ minutes = int((seconds % 3600) / 60)
420
+ seconds = seconds % 60
421
+ return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
422
+
423
+ def create_metadata(name_project,progress=gr.Progress()):
424
+ name_project+="_pinyin"
425
+ path_project= os.path.join(path_data,name_project)
426
+ path_project_wavs = os.path.join(path_project,"wavs")
427
+ file_metadata = os.path.join(path_project,"metadata.csv")
428
+ file_raw = os.path.join(path_project,"raw.arrow")
429
+ file_duration = os.path.join(path_project,"duration.json")
430
+ file_vocab = os.path.join(path_project,"vocab.txt")
431
+
432
+ if os.path.isfile(file_metadata)==False: return "The file was not found in " + file_metadata
433
+
434
+ with open(file_metadata,"r",encoding="utf-8") as f:
435
+ data=f.read()
436
+
437
+ audio_path_list=[]
438
+ text_list=[]
439
+ duration_list=[]
440
+
441
+ count=data.split("\n")
442
+ lenght=0
443
+ result=[]
444
+ error_files=[]
445
+ for line in progress.tqdm(data.split("\n"),total=len(count)):
446
+ sp_line=line.split("|")
447
+ if len(sp_line)!=2:continue
448
+ name_audio,text = sp_line[:2]
449
+
450
+ file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
451
+
452
+ if os.path.isfile(file_audio)==False:
453
+ error_files.append(file_audio)
454
+ continue
455
+
456
+ duraction = get_audio_duration(file_audio)
457
+ if duraction<2 or duraction>15:continue  # keep only clips between 2 and 15 seconds
458
+ if len(text)<4:continue
459
+
460
+ text = clear_text(text)
461
+ text = convert_char_to_pinyin([text], polyphone = True)[0]
462
+
463
+ audio_path_list.append(file_audio)
464
+ duration_list.append(duraction)
465
+ text_list.append(text)
466
+
467
+ result.append({"audio_path": file_audio, "text": text, "duration": duraction})
468
+
469
+ lenght+=duraction
470
+
471
+ if duration_list==[]:
472
+ error_files_text="\n".join(error_files)
473
+ return f"Error: No audio files found in the specified path : \n{error_files_text}"
474
+
475
+ min_second = round(min(duration_list),2)
476
+ max_second = round(max(duration_list),2)
477
+
478
+ with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
479
+ for line in progress.tqdm(result,total=len(result), desc=f"prepare data"):
480
+ writer.write(line)
481
+
482
+ with open(file_duration, 'w', encoding='utf-8') as f:
483
+ json.dump({"duration": duration_list}, f, ensure_ascii=False)
484
+
485
+ file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
486
+ if os.path.isfile(file_vocab_finetune)==False:return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
487
+ shutil.copy2(file_vocab_finetune, file_vocab)
488
+
489
+ if error_files!=[]:
490
+ error_text="error files\n" + "\n".join(error_files)
491
+ else:
492
+ error_text=""
493
+
494
+ return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
495
+
496
+ def check_user(value):
497
+ return gr.update(visible=not value),gr.update(visible=value)
498
+
499
+ def calculate_train(name_project,batch_size_type,max_samples,learning_rate,num_warmup_updates,save_per_updates,last_per_steps,finetune):
500
+ name_project+="_pinyin"
501
+ path_project= os.path.join(path_data,name_project)
502
+ file_duraction = os.path.join(path_project,"duration.json")
503
+
504
+ with open(file_duraction, 'r') as file:
505
+ data = json.load(file)
506
+
507
+ duration_list = data['duration']
508
+
509
+ samples = len(duration_list)
510
+
511
+ if torch.cuda.is_available():
512
+ gpu_properties = torch.cuda.get_device_properties(0)
513
+ total_memory = gpu_properties.total_memory / (1024 ** 3)
514
+ elif torch.backends.mps.is_available():
515
+ total_memory = psutil.virtual_memory().available / (1024 ** 3)
516
+
517
+ if batch_size_type=="frame":
518
+ batch = int(total_memory * 0.5)
519
+ batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
520
+ batch_size_per_gpu = int(38400 / batch )
521
+ else:
522
+ batch_size_per_gpu = int(total_memory / 8)
523
+ batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
524
+ batch = batch_size_per_gpu
525
+
526
+ if batch_size_per_gpu<=0:batch_size_per_gpu=1
527
+
528
+ if samples<64:
529
+ max_samples = int(samples * 0.25)
530
+ else:
531
+ max_samples = 64
532
+
533
+ num_warmup_updates = int(samples * 0.10)
534
+ save_per_updates = int(samples * 0.25)
535
+ last_per_steps =int(save_per_updates * 5)
536
+
537
+ max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
538
+ num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
539
+ save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
540
+ last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
541
+
542
+ if finetune:learning_rate=1e-4
543
+ else:learning_rate=7.5e-5
544
+
545
+ return batch_size_per_gpu,max_samples,num_warmup_updates,save_per_updates,last_per_steps,samples,learning_rate
546
+
547
+ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
548
+ try:
549
+ checkpoint = torch.load(checkpoint_path)
550
+ print("Original Checkpoint Keys:", checkpoint.keys())
551
+
552
+ ema_model_state_dict = checkpoint.get('ema_model_state_dict', None)
553
+
554
+ if ema_model_state_dict is not None:
555
+ new_checkpoint = {'ema_model_state_dict': ema_model_state_dict}
556
+ torch.save(new_checkpoint, new_checkpoint_path)
557
+ return f"New checkpoint saved at: {new_checkpoint_path}"
558
+ else:
559
+ return "No 'ema_model_state_dict' found in the checkpoint."
560
+
561
+ except Exception as e:
562
+ return f"An error occurred: {e}"
563
+
564
+ def vocab_check(project_name):
565
+ name_project = project_name + "_pinyin"
566
+ path_project = os.path.join(path_data, name_project)
567
+
568
+ file_metadata = os.path.join(path_project, "metadata.csv")
569
+
570
+ file_vocab="data/Emilia_ZH_EN_pinyin/vocab.txt"
571
+ if os.path.isfile(file_vocab)==False:
572
+ return f"the file {file_vocab} not found !"
573
+
574
+ with open(file_vocab,"r",encoding="utf-8") as f:
575
+ data=f.read()
576
+
577
+ vocab = data.split("\n")
578
+
579
+ if os.path.isfile(file_metadata)==False:
580
+ return f"the file {file_metadata} not found !"
581
+
582
+ with open(file_metadata,"r",encoding="utf-8") as f:
583
+ data=f.read()
584
+
585
+ miss_symbols=[]
586
+ miss_symbols_keep={}
587
+ for item in data.split("\n"):
588
+ sp=item.split("|")
589
+ if len(sp)!=2:continue
590
+ text=sp[1].lower().strip()
591
+
592
+ for t in text:
593
+ if (t in vocab)==False and (t in miss_symbols_keep)==False:
594
+ miss_symbols.append(t)
595
+ miss_symbols_keep[t]=t
596
+
597
+
598
+ if miss_symbols==[]:info ="You can train using your language !"
599
+ else:info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
600
+
601
+ return info
602
+
603
+
604
+
605
+ with gr.Blocks() as app:
606
+
607
+ with gr.Row():
608
+ project_name=gr.Textbox(label="project name",value="my_speak")
609
+ bt_create=gr.Button("create new project")
610
+
611
+ bt_create.click(fn=create_data_project,inputs=[project_name])
612
+
613
+ with gr.Tabs():
614
+
615
+
616
+ with gr.TabItem("transcribe Data"):
617
+
618
+
619
+ ch_manual = gr.Checkbox(label="user",value=False)
620
+
621
+ mark_info_transcribe=gr.Markdown(
622
+ """```plaintext
623
+ Place your audio files in the '{your_project_name}/dataset' directory, as shown below.
624
+
625
+ my_speak/
626
+
627
+ └── dataset/
628
+ ├── audio1.wav
629
+ └── audio2.wav
630
+ ...
631
+ ```""",visible=False)
632
+
633
+ audio_speaker = gr.File(label="voice",type="filepath",file_count="multiple")
634
+ txt_lang = gr.Text(label="Language",value="indonesian")
635
+ bt_transcribe=bt_create=gr.Button("transcribe")
636
+ txt_info_transcribe=gr.Text(label="info",value="")
637
+ bt_transcribe.click(fn=transcribe_all,inputs=[project_name,audio_speaker,txt_lang,ch_manual],outputs=[txt_info_transcribe])
638
+ ch_manual.change(fn=check_user,inputs=[ch_manual],outputs=[audio_speaker,mark_info_transcribe])
639
+
640
+ with gr.TabItem("prepare Data"):
641
+ gr.Markdown(
642
+ """```plaintext
643
+ place all your wavs folder and your metadata.csv file in {your name project}
644
+ my_speak/
645
+
646
+ ├── wavs/
647
+ │ ├── audio1.wav
648
+ │ └── audio2.wav
649
+ | ...
650
+
651
+ └── metadata.csv
652
+
653
+ file format metadata.csv
654
+
655
+ audio1|text1
656
+ audio2|text1
657
+ ...
658
+
659
+ ```""")
660
+
661
+ bt_prepare=bt_create=gr.Button("prepare")
662
+ txt_info_prepare=gr.Text(label="info",value="")
663
+ bt_prepare.click(fn=create_metadata,inputs=[project_name],outputs=[txt_info_prepare])
664
+
665
+ with gr.TabItem("train Data"):
666
+
667
+ with gr.Row():
668
+ bt_calculate=bt_create=gr.Button("Auto Settings")
669
+ ch_finetune=bt_create=gr.Checkbox(label="finetune",value=True)
670
+ lb_samples = gr.Label(label="samples")
671
+ batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
672
+
673
+ with gr.Row():
674
+ exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
675
+ learning_rate = gr.Number(label="Learning Rate", value=1e-4, step=1e-4)
676
+
677
+ with gr.Row():
678
+ batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
679
+ max_samples = gr.Number(label="Max Samples", value=16)
680
+
681
+ with gr.Row():
682
+ grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
683
+ max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
684
+
685
+ with gr.Row():
686
+ epochs = gr.Number(label="Epochs", value=10)
687
+ num_warmup_updates = gr.Number(label="Warmup Updates", value=5)
688
+
689
+ with gr.Row():
690
+ save_per_updates = gr.Number(label="Save per Updates", value=10)
691
+ last_per_steps = gr.Number(label="Last per Steps", value=50)
692
+
693
+ with gr.Row():
694
+ start_button = gr.Button("Start Training")
695
+ stop_button = gr.Button("Stop Training",interactive=False)
696
+
697
+ txt_info_train=gr.Text(label="info",value="")
698
+ start_button.click(fn=start_training,inputs=[project_name,exp_name,learning_rate,batch_size_per_gpu,batch_size_type,max_samples,grad_accumulation_steps,max_grad_norm,epochs,num_warmup_updates,save_per_updates,last_per_steps,ch_finetune],outputs=[txt_info_train,start_button,stop_button])
699
+ stop_button.click(fn=stop_training,outputs=[txt_info_train,start_button,stop_button])
700
+ bt_calculate.click(fn=calculate_train,inputs=[project_name,batch_size_type,max_samples,learning_rate,num_warmup_updates,save_per_updates,last_per_steps,ch_finetune],outputs=[batch_size_per_gpu,max_samples,num_warmup_updates,save_per_updates,last_per_steps,lb_samples,learning_rate])
701
+
702
+ with gr.TabItem("reduse checkpoint"):
703
+ txt_path_checkpoint = gr.Text(label="path checkpoint :")
704
+ txt_path_checkpoint_small = gr.Text(label="path output :")
705
+ txt_info_reduse = gr.Text(label="info",value="")
706
+ reduse_button = gr.Button("reduce")
707
+ reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small],outputs=[txt_info_reduse])
708
+
709
+ with gr.TabItem("vocab check experiment"):
710
+ check_button = gr.Button("check vocab")
711
+ txt_info_check=gr.Text(label="info",value="")
712
+ check_button.click(fn=vocab_check,inputs=[project_name],outputs=[txt_info_check])
713
+
714
+
715
+ @click.command()
716
+ @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
717
+ @click.option("--host", "-H", default=None, help="Host to run the app on")
718
+ @click.option(
719
+ "--share",
720
+ "-s",
721
+ default=False,
722
+ is_flag=True,
723
+ help="Share the app via Gradio share link",
724
+ )
725
+ @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
726
+ def main(port, host, share, api):
727
+ global app
728
+ print(f"Starting app...")
729
+ app.queue(api_open=api).launch(
730
+ server_name=host, server_port=port, share=share, show_api=api
731
+ )
732
+
733
+ if __name__ == "__main__":
734
+ main()
gradio_app.py ADDED
@@ -0,0 +1,794 @@
1
+ import re
2
+ import torch
3
+ import torchaudio
4
+ import gradio as gr
5
+ import numpy as np
6
+ import tempfile
7
+ from einops import rearrange
8
+ from vocos import Vocos
9
+ from pydub import AudioSegment, silence
10
+ from model import CFM, UNetT, DiT, MMDiT
11
+ from cached_path import cached_path
12
+ from model.utils import (
13
+ load_checkpoint,
14
+ get_tokenizer,
15
+ convert_char_to_pinyin,
16
+ save_spectrogram,
17
+ )
18
+ from transformers import pipeline
19
+ import click
20
+ import soundfile as sf
21
+
22
+ try:
23
+ import spaces
24
+ USING_SPACES = True
25
+ except ImportError:
26
+ USING_SPACES = False
27
+
28
+ def gpu_decorator(func):
29
+ if USING_SPACES:
30
+ return spaces.GPU(func)
31
+ else:
32
+ return func
33
+
34
+ device = (
35
+ "cuda"
36
+ if torch.cuda.is_available()
37
+ else "mps" if torch.backends.mps.is_available() else "cpu"
38
+ )
39
+
40
+ print(f"Using {device} device")
41
+
42
+ pipe = pipeline(
43
+ "automatic-speech-recognition",
44
+ model="openai/whisper-large-v3-turbo",
45
+ torch_dtype=torch.float16,
46
+ device=device,
47
+ )
48
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
49
+
50
+ # --------------------- Settings -------------------- #
51
+
52
+ target_sample_rate = 24000
53
+ n_mel_channels = 100
54
+ hop_length = 256
55
+ target_rms = 0.1
56
+ nfe_step = 32 # 16, 32
57
+ cfg_strength = 2.0
58
+ ode_method = "euler"
59
+ sway_sampling_coef = -1.0
60
+ speed = 1.0
61
+ fix_duration = None
62
+
63
+
64
+ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
65
+ ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
66
+ # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
67
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
68
+ model = CFM(
69
+ transformer=model_cls(
70
+ **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
71
+ ),
72
+ mel_spec_kwargs=dict(
73
+ target_sample_rate=target_sample_rate,
74
+ n_mel_channels=n_mel_channels,
75
+ hop_length=hop_length,
76
+ ),
77
+ odeint_kwargs=dict(
78
+ method=ode_method,
79
+ ),
80
+ vocab_char_map=vocab_char_map,
81
+ ).to(device)
82
+
83
+ model = load_checkpoint(model, ckpt_path, device, use_ema = True)
84
+
85
+ return model
86
+
87
+
88
+ # load models
89
+ F5TTS_model_cfg = dict(
90
+ dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
91
+ )
92
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
93
+
94
+ F5TTS_ema_model = load_model(
95
+ "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
96
+ )
97
+ E2TTS_ema_model = load_model(
98
+ "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
99
+ )
100
+
101
+ def chunk_text(text, max_chars=135):
102
+ """
103
+ Splits the input text into chunks, each with a maximum number of characters.
104
+
105
+ Args:
106
+ text (str): The text to be split.
107
+ max_chars (int): The maximum number of characters per chunk.
108
+
109
+ Returns:
110
+ List[str]: A list of text chunks.
111
+ """
112
+ chunks = []
113
+ current_chunk = ""
114
+ # Split the text into sentences based on punctuation followed by whitespace
115
+ sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
116
+
117
+ for sentence in sentences:
118
+ if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
119
+ current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
120
+ else:
121
+ if current_chunk:
122
+ chunks.append(current_chunk.strip())
123
+ current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
124
+
125
+ if current_chunk:
126
+ chunks.append(current_chunk.strip())
127
+
128
+ return chunks
129
+
130
+ @gpu_decorator
131
+ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration=0.15, progress=gr.Progress()):
132
+ if exp_name == "F5-TTS":
133
+ ema_model = F5TTS_ema_model
134
+ elif exp_name == "E2-TTS":
135
+ ema_model = E2TTS_ema_model
136
+
137
+ audio, sr = ref_audio
138
+ if audio.shape[0] > 1:
139
+ audio = torch.mean(audio, dim=0, keepdim=True)
140
+
141
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
142
+ if rms < target_rms:
143
+ audio = audio * target_rms / rms
144
+ if sr != target_sample_rate:
145
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
146
+ audio = resampler(audio)
147
+ audio = audio.to(device)
148
+
149
+ generated_waves = []
150
+ spectrograms = []
151
+
152
+ for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
153
+ # Prepare the text
154
+ if len(ref_text[-1].encode('utf-8')) == 1:
155
+ ref_text = ref_text + " "
156
+ text_list = [ref_text + gen_text]
157
+ final_text_list = convert_char_to_pinyin(text_list)
158
+
159
+ # Calculate duration
160
+ ref_audio_len = audio.shape[-1] // hop_length
161
+ zh_pause_punc = r"。,、;:?!"
162
+ ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
163
+ gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
164
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
165
+
166
+ # inference
167
+ with torch.inference_mode():
168
+ generated, _ = ema_model.sample(
169
+ cond=audio,
170
+ text=final_text_list,
171
+ duration=duration,
172
+ steps=nfe_step,
173
+ cfg_strength=cfg_strength,
174
+ sway_sampling_coef=sway_sampling_coef,
175
+ )
176
+
177
+ generated = generated[:, ref_audio_len:, :]
178
+ generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
179
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
180
+ if rms < target_rms:
181
+ generated_wave = generated_wave * rms / target_rms
182
+
183
+ # wav -> numpy
184
+ generated_wave = generated_wave.squeeze().cpu().numpy()
185
+
186
+ generated_waves.append(generated_wave)
187
+ spectrograms.append(generated_mel_spec[0].cpu().numpy())
188
+
189
+ # Combine all generated waves with cross-fading
190
+ if cross_fade_duration <= 0:
191
+ # Simply concatenate
192
+ final_wave = np.concatenate(generated_waves)
193
+ else:
194
+ final_wave = generated_waves[0]
195
+ for i in range(1, len(generated_waves)):
196
+ prev_wave = final_wave
197
+ next_wave = generated_waves[i]
198
+
199
+ # Calculate cross-fade samples, ensuring it does not exceed wave lengths
200
+ cross_fade_samples = int(cross_fade_duration * target_sample_rate)
201
+ cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
202
+
203
+ if cross_fade_samples <= 0:
204
+ # No overlap possible, concatenate
205
+ final_wave = np.concatenate([prev_wave, next_wave])
206
+ continue
207
+
208
+ # Overlapping parts
209
+ prev_overlap = prev_wave[-cross_fade_samples:]
210
+ next_overlap = next_wave[:cross_fade_samples]
211
+
212
+ # Fade out and fade in
213
+ fade_out = np.linspace(1, 0, cross_fade_samples)
214
+ fade_in = np.linspace(0, 1, cross_fade_samples)
215
+
216
+ # Cross-faded overlap
217
+ cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
218
+
219
+ # Combine
220
+ new_wave = np.concatenate([
221
+ prev_wave[:-cross_fade_samples],
222
+ cross_faded_overlap,
223
+ next_wave[cross_fade_samples:]
224
+ ])
225
+
226
+ final_wave = new_wave
227
+
228
+ # Remove silence
229
+ if remove_silence:
230
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
231
+ sf.write(f.name, final_wave, target_sample_rate)
232
+ aseg = AudioSegment.from_file(f.name)
233
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
234
+ non_silent_wave = AudioSegment.silent(duration=0)
235
+ for non_silent_seg in non_silent_segs:
236
+ non_silent_wave += non_silent_seg
237
+ aseg = non_silent_wave
238
+ aseg.export(f.name, format="wav")
239
+ final_wave, _ = torchaudio.load(f.name)
240
+ final_wave = final_wave.squeeze().cpu().numpy()
241
+
242
+ # Create a combined spectrogram
243
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
244
+
245
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
246
+ spectrogram_path = tmp_spectrogram.name
247
+ save_spectrogram(combined_spectrogram, spectrogram_path)
248
+
249
+ return (target_sample_rate, final_wave), spectrogram_path
250
+
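The chunk-joining logic above overlaps adjacent waveforms with complementary linear ramps. A self-contained sketch of the same idea, with a hypothetical helper name and NumPy only, may make the indexing easier to follow:

```python
import numpy as np

def crossfade_pair(prev_wave, next_wave, sample_rate=24000, fade_s=0.15):
    """Join two 1-D waveforms, blending the boundary over fade_s seconds."""
    n = min(int(fade_s * sample_rate), len(prev_wave), len(next_wave))
    if n <= 0:
        return np.concatenate([prev_wave, next_wave])  # no overlap possible
    fade_out = np.linspace(1.0, 0.0, n)  # ramp the tail of the previous chunk down
    fade_in = np.linspace(0.0, 1.0, n)   # ramp the head of the next chunk up
    blended = prev_wave[-n:] * fade_out + next_wave[:n] * fade_in
    return np.concatenate([prev_wave[:-n], blended, next_wave[n:]])
```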
251
+ @gpu_decorator
252
+ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, cross_fade_duration=0.15):
253
+
254
+ print(gen_text)
255
+
256
+ gr.Info("Converting audio...")
257
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
258
+ aseg = AudioSegment.from_file(ref_audio_orig)
259
+
260
+ non_silent_segs = silence.split_on_silence(
261
+ aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000
262
+ )
263
+ non_silent_wave = AudioSegment.silent(duration=0)
264
+ for non_silent_seg in non_silent_segs:
265
+ non_silent_wave += non_silent_seg
266
+ aseg = non_silent_wave
267
+
268
+ audio_duration = len(aseg)
269
+ if audio_duration > 15000:
270
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
271
+ aseg = aseg[:15000]
272
+ aseg.export(f.name, format="wav")
273
+ ref_audio = f.name
274
+
275
+ if not ref_text.strip():
276
+ gr.Info("No reference text provided, transcribing reference audio...")
277
+ ref_text = pipe(
278
+ ref_audio,
279
+ chunk_length_s=30,
280
+ batch_size=128,
281
+ generate_kwargs={"task": "transcribe"},
282
+ return_timestamps=False,
283
+ )["text"].strip()
284
+ gr.Info("Finished transcription")
285
+ else:
286
+ gr.Info("Using custom reference text...")
287
+
288
+ # Add the functionality to ensure it ends with ". "
289
+ if not ref_text.endswith(". "):
290
+ if ref_text.endswith("."):
291
+ ref_text += " "
292
+ else:
293
+ ref_text += ". "
294
+
295
+ audio, sr = torchaudio.load(ref_audio)
296
+
297
+ # Use the new chunk_text function to split gen_text
298
+ max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
299
+ gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
300
+ print('ref_text', ref_text)
301
+ for i, batch_text in enumerate(gen_text_batches):
302
+ print(f'gen_text {i}', batch_text)
303
+
304
+ gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
305
+ return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration)
306
+
307
+
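The max_chars heuristic above scales the reference transcript's byte rate by the seconds left in a roughly 25-second generation budget. A worked example with made-up numbers:

```python
# Hypothetical numbers, for illustration only.
ref_bytes = 120          # len(ref_text.encode('utf-8'))
ref_seconds = 8.0        # audio.shape[-1] / sr
max_chars = int(ref_bytes / ref_seconds * (25 - ref_seconds))
print(max_chars)         # 15 bytes/s * 17 s remaining -> 255 bytes per batch
```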
308
+ @gpu_decorator
309
+ def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
310
+ # Split the script into speaker blocks
311
+ speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
312
+ speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
313
+
314
+ generated_audio_segments = []
315
+
316
+ for i in range(0, len(speaker_blocks), 2):
317
+ speaker = speaker_blocks[i]
318
+ text = speaker_blocks[i+1].strip()
319
+
320
+ # Determine which speaker is talking
321
+ if speaker == speaker1_name:
322
+ ref_audio = ref_audio1
323
+ ref_text = ref_text1
324
+ elif speaker == speaker2_name:
325
+ ref_audio = ref_audio2
326
+ ref_text = ref_text2
327
+ else:
328
+ continue # Skip if the speaker is neither speaker1 nor speaker2
329
+
330
+ # Generate audio for this block
331
+ audio, _ = infer(ref_audio, ref_text, text, exp_name, remove_silence)
332
+
333
+ # Convert the generated audio to a numpy array
334
+ sr, audio_data = audio
335
+
336
+ # Save the audio data as a WAV file
337
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
338
+ sf.write(temp_file.name, audio_data, sr)
339
+ audio_segment = AudioSegment.from_wav(temp_file.name)
340
+
341
+ generated_audio_segments.append(audio_segment)
342
+
343
+ # Add a short pause between speakers
344
+ pause = AudioSegment.silent(duration=500) # 500ms pause
345
+ generated_audio_segments.append(pause)
346
+
347
+ # Concatenate all audio segments
348
+ final_podcast = sum(generated_audio_segments)
349
+
350
+ # Export the final podcast
351
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
352
+ podcast_path = temp_file.name
353
+ final_podcast.export(podcast_path, format="wav")
354
+
355
+ return podcast_path
356
+
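The speaker regex above splits the script into alternating (speaker, text) pairs, which is why the loop steps through speaker_blocks two at a time. A small sketch with made-up names:

```python
import re

script = "Sean: How did you start?\nMeghan: I came to it late.\nSean: Interesting."
pattern = re.compile(r"^(Sean|Meghan):", re.MULTILINE)
blocks = pattern.split(script)[1:]   # drop the empty prefix before the first tag
# blocks == ['Sean', ' How did you start?\n',
#            'Meghan', ' I came to it late.\n',
#            'Sean', ' Interesting.']
pairs = list(zip(blocks[0::2], blocks[1::2]))   # (speaker, text) pairs
```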
357
+ def parse_speechtypes_text(gen_text):
358
+ # Pattern to find (Emotion)
359
+ pattern = r'\((.*?)\)'
360
+
361
+ # Split the text by the pattern
362
+ tokens = re.split(pattern, gen_text)
363
+
364
+ segments = []
365
+
366
+ current_emotion = 'Regular'
367
+
368
+ for i in range(len(tokens)):
369
+ if i % 2 == 0:
370
+ # This is text
371
+ text = tokens[i].strip()
372
+ if text:
373
+ segments.append({'emotion': current_emotion, 'text': text})
374
+ else:
375
+ # This is emotion
376
+ emotion = tokens[i].strip()
377
+ current_emotion = emotion
378
+
379
+ return segments
380
+
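Because the emotion pattern is captured, re.split returns text and emotion tokens alternately, so even indices are text and odd indices are emotion tags. A quick sketch of the expected output (sample string made up):

```python
segments = parse_speechtypes_text("(Regular) Hello there. (Sad) I miss you.")
# segments == [{'emotion': 'Regular', 'text': 'Hello there.'},
#              {'emotion': 'Sad', 'text': 'I miss you.'}]
```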
381
+ def update_speed(new_speed):
382
+ global speed
383
+ speed = new_speed
384
+ return f"Speed set to: {speed}"
385
+
386
+ with gr.Blocks() as app_credits:
387
+ gr.Markdown("""
388
+ # Credits
389
+
390
+ * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
391
+ * [RootingInLoad](https://github.com/RootingInLoad) for the podcast generation
392
+ * [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation
393
+ """)
394
+ with gr.Blocks() as app_tts:
395
+ gr.Markdown("# Batched TTS")
396
+ ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
397
+ gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
398
+ model_choice = gr.Radio(
399
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
400
+ )
401
+ generate_btn = gr.Button("Synthesize", variant="primary")
402
+ with gr.Accordion("Advanced Settings", open=False):
403
+ ref_text_input = gr.Textbox(
404
+ label="Reference Text",
405
+ info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
406
+ lines=2,
407
+ )
408
+ remove_silence = gr.Checkbox(
409
+ label="Remove Silences",
410
+ info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
411
+ value=False,
412
+ )
413
+ speed_slider = gr.Slider(
414
+ label="Speed",
415
+ minimum=0.3,
416
+ maximum=2.0,
417
+ value=speed,
418
+ step=0.1,
419
+ info="Adjust the speed of the audio.",
420
+ )
421
+ cross_fade_duration_slider = gr.Slider(
422
+ label="Cross-Fade Duration (s)",
423
+ minimum=0.0,
424
+ maximum=1.0,
425
+ value=0.15,
426
+ step=0.01,
427
+ info="Set the duration of the cross-fade between audio clips.",
428
+ )
429
+ speed_slider.change(update_speed, inputs=speed_slider)
430
+
431
+ audio_output = gr.Audio(label="Synthesized Audio")
432
+ spectrogram_output = gr.Image(label="Spectrogram")
433
+
434
+ generate_btn.click(
435
+ infer,
436
+ inputs=[
437
+ ref_audio_input,
438
+ ref_text_input,
439
+ gen_text_input,
440
+ model_choice,
441
+ remove_silence,
442
+ cross_fade_duration_slider,
443
+ ],
444
+ outputs=[audio_output, spectrogram_output],
445
+ )
446
+
447
+ with gr.Blocks() as app_podcast:
448
+ gr.Markdown("# Podcast Generation")
449
+ speaker1_name = gr.Textbox(label="Speaker 1 Name")
450
+ ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
451
+ ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
452
+
453
+ speaker2_name = gr.Textbox(label="Speaker 2 Name")
454
+ ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
455
+ ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
456
+
457
+ script_input = gr.Textbox(label="Podcast Script", lines=10,
458
+ placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...")
459
+
460
+ podcast_model_choice = gr.Radio(
461
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
462
+ )
463
+ podcast_remove_silence = gr.Checkbox(
464
+ label="Remove Silences",
465
+ value=True,
466
+ )
467
+ generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
468
+ podcast_output = gr.Audio(label="Generated Podcast")
469
+
470
+ def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence):
471
+ return generate_podcast(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence)
472
+
473
+ generate_podcast_btn.click(
474
+ podcast_generation,
475
+ inputs=[
476
+ script_input,
477
+ speaker1_name,
478
+ ref_audio_input1,
479
+ ref_text_input1,
480
+ speaker2_name,
481
+ ref_audio_input2,
482
+ ref_text_input2,
483
+ podcast_model_choice,
484
+ podcast_remove_silence,
485
+ ],
486
+ outputs=podcast_output,
487
+ )
488
+
489
+ def parse_emotional_text(gen_text):
490
+ # Pattern to find (Emotion)
491
+ pattern = r'\((.*?)\)'
492
+
493
+ # Split the text by the pattern
494
+ tokens = re.split(pattern, gen_text)
495
+
496
+ segments = []
497
+
498
+ current_emotion = 'Regular'
499
+
500
+ for i in range(len(tokens)):
501
+ if i % 2 == 0:
502
+ # This is text
503
+ text = tokens[i].strip()
504
+ if text:
505
+ segments.append({'emotion': current_emotion, 'text': text})
506
+ else:
507
+ # This is emotion
508
+ emotion = tokens[i].strip()
509
+ current_emotion = emotion
510
+
511
+ return segments
512
+
513
+ with gr.Blocks() as app_emotional:
514
+ # New section for emotional generation
515
+ gr.Markdown(
516
+ """
517
+ # Multiple Speech-Type Generation
518
+
519
+ This section allows you to upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the "Add Speech Type" button. Enter your text in the format shown below, and the system will generate speech using the appropriate emotions. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
520
+
521
+ **Example Input:**
522
+
523
+ (Regular) Hello, I'd like to order a sandwich please. (Surprised) What do you mean you're out of bread? (Sad) I really wanted a sandwich though... (Angry) You know what, darn you and your little shop, you suck! (Whisper) I'll just go back home and cry now. (Shouting) Why me?!
524
+ """
525
+ )
526
+
527
+ gr.Markdown("Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button.")
528
+
529
+ # Regular speech type (mandatory)
530
+ with gr.Row():
531
+ regular_name = gr.Textbox(value='Regular', label='Speech Type Name', interactive=False)
532
+ regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
533
+ regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)
534
+
535
+ # Additional speech types (up to 99 more)
536
+ max_speech_types = 100
537
+ speech_type_names = []
538
+ speech_type_audios = []
539
+ speech_type_ref_texts = []
540
+ speech_type_delete_btns = []
541
+
542
+ for i in range(max_speech_types - 1):
543
+ with gr.Row():
544
+ name_input = gr.Textbox(label='Speech Type Name', visible=False)
545
+ audio_input = gr.Audio(label='Reference Audio', type='filepath', visible=False)
546
+ ref_text_input = gr.Textbox(label='Reference Text', lines=2, visible=False)
547
+ delete_btn = gr.Button("Delete", variant="secondary", visible=False)
548
+ speech_type_names.append(name_input)
549
+ speech_type_audios.append(audio_input)
550
+ speech_type_ref_texts.append(ref_text_input)
551
+ speech_type_delete_btns.append(delete_btn)
552
+
553
+ # Button to add speech type
554
+ add_speech_type_btn = gr.Button("Add Speech Type")
555
+
556
+ # Keep track of current number of speech types
557
+ speech_type_count = gr.State(value=0)
558
+
559
+ # Function to add a speech type
560
+ def add_speech_type_fn(speech_type_count):
561
+ if speech_type_count < max_speech_types - 1:
562
+ speech_type_count += 1
563
+ # Prepare updates for the components
564
+ name_updates = []
565
+ audio_updates = []
566
+ ref_text_updates = []
567
+ delete_btn_updates = []
568
+ for i in range(max_speech_types - 1):
569
+ if i < speech_type_count:
570
+ name_updates.append(gr.update(visible=True))
571
+ audio_updates.append(gr.update(visible=True))
572
+ ref_text_updates.append(gr.update(visible=True))
573
+ delete_btn_updates.append(gr.update(visible=True))
574
+ else:
575
+ name_updates.append(gr.update())
576
+ audio_updates.append(gr.update())
577
+ ref_text_updates.append(gr.update())
578
+ delete_btn_updates.append(gr.update())
579
+ else:
580
+ # Optionally, show a warning
581
+ # gr.Warning("Maximum number of speech types reached.")
582
+ name_updates = [gr.update() for _ in range(max_speech_types - 1)]
583
+ audio_updates = [gr.update() for _ in range(max_speech_types - 1)]
584
+ ref_text_updates = [gr.update() for _ in range(max_speech_types - 1)]
585
+ delete_btn_updates = [gr.update() for _ in range(max_speech_types - 1)]
586
+ return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
587
+
588
+ add_speech_type_btn.click(
589
+ add_speech_type_fn,
590
+ inputs=speech_type_count,
591
+ outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
592
+ )
593
+
594
+ # Function to delete a speech type
595
+ def make_delete_speech_type_fn(index):
596
+ def delete_speech_type_fn(speech_type_count):
597
+ # Prepare updates
598
+ name_updates = []
599
+ audio_updates = []
600
+ ref_text_updates = []
601
+ delete_btn_updates = []
602
+
603
+ for i in range(max_speech_types - 1):
604
+ if i == index:
605
+ name_updates.append(gr.update(visible=False, value=''))
606
+ audio_updates.append(gr.update(visible=False, value=None))
607
+ ref_text_updates.append(gr.update(visible=False, value=''))
608
+ delete_btn_updates.append(gr.update(visible=False))
609
+ else:
610
+ name_updates.append(gr.update())
611
+ audio_updates.append(gr.update())
612
+ ref_text_updates.append(gr.update())
613
+ delete_btn_updates.append(gr.update())
614
+
615
+ speech_type_count = max(0, speech_type_count - 1)
616
+
617
+ return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
618
+
619
+ return delete_speech_type_fn
620
+
621
+ for i, delete_btn in enumerate(speech_type_delete_btns):
622
+ delete_fn = make_delete_speech_type_fn(i)
623
+ delete_btn.click(
624
+ delete_fn,
625
+ inputs=speech_type_count,
626
+ outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
627
+ )
628
+
629
+ # Text input for the prompt
630
+ gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
631
+
632
+ # Model choice
633
+ model_choice_emotional = gr.Radio(
634
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
635
+ )
636
+
637
+ with gr.Accordion("Advanced Settings", open=False):
638
+ remove_silence_emotional = gr.Checkbox(
639
+ label="Remove Silences",
640
+ value=True,
641
+ )
642
+
643
+ # Generate button
644
+ generate_emotional_btn = gr.Button("Generate Emotional Speech", variant="primary")
645
+
646
+ # Output audio
647
+ audio_output_emotional = gr.Audio(label="Synthesized Audio")
648
+ @gpu_decorator
649
+ def generate_emotional_speech(
650
+ regular_audio,
651
+ regular_ref_text,
652
+ gen_text,
653
+ *args,
654
+ ):
655
+ num_additional_speech_types = max_speech_types - 1
656
+ speech_type_names_list = args[:num_additional_speech_types]
657
+ speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types]
658
+ speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types]
659
+ model_choice = args[3 * num_additional_speech_types]
660
+ remove_silence = args[3 * num_additional_speech_types + 1]
661
+
662
+ # Collect the speech types and their audios into a dict
663
+ speech_types = {'Regular': {'audio': regular_audio, 'ref_text': regular_ref_text}}
664
+
665
+ for name_input, audio_input, ref_text_input in zip(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list):
666
+ if name_input and audio_input:
667
+ speech_types[name_input] = {'audio': audio_input, 'ref_text': ref_text_input}
668
+
669
+ # Parse the gen_text into segments
670
+ segments = parse_speechtypes_text(gen_text)
671
+
672
+ # For each segment, generate speech
673
+ generated_audio_segments = []
674
+ current_emotion = 'Regular'
675
+
676
+ for segment in segments:
677
+ emotion = segment['emotion']
678
+ text = segment['text']
679
+
680
+ if emotion in speech_types:
681
+ current_emotion = emotion
682
+ else:
683
+ # If emotion not available, default to Regular
684
+ current_emotion = 'Regular'
685
+
686
+ ref_audio = speech_types[current_emotion]['audio']
687
+ ref_text = speech_types[current_emotion].get('ref_text', '')
688
+
689
+ # Generate speech for this segment
690
+ audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0)
691
+ sr, audio_data = audio
692
+
693
+ generated_audio_segments.append(audio_data)
694
+
695
+ # Concatenate all audio segments
696
+ if generated_audio_segments:
697
+ final_audio_data = np.concatenate(generated_audio_segments)
698
+ return (sr, final_audio_data)
699
+ else:
700
+ gr.Warning("No audio generated.")
701
+ return None
702
+
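Since Gradio passes every speech-type component as one flat argument list, the slicing above recovers the three per-type lists plus the two trailing controls. A toy sketch with dummy values shows the layout:

```python
# Dummy values, for illustration only.
num_additional = 99                      # max_speech_types - 1
names  = ["Happy"] + [None] * 98
audios = ["happy.wav"] + [None] * 98
texts  = [""] * 99
args = tuple(names + audios + texts + ["F5-TTS", True])

assert args[:num_additional] == tuple(names)
assert args[num_additional:2 * num_additional] == tuple(audios)
assert args[2 * num_additional:3 * num_additional] == tuple(texts)
assert args[3 * num_additional] == "F5-TTS"     # model choice
assert args[3 * num_additional + 1] is True     # remove_silence flag
```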
703
+ generate_emotional_btn.click(
704
+ generate_emotional_speech,
705
+ inputs=[
706
+ regular_audio,
707
+ regular_ref_text,
708
+ gen_text_input_emotional,
709
+ ] + speech_type_names + speech_type_audios + speech_type_ref_texts + [
710
+ model_choice_emotional,
711
+ remove_silence_emotional,
712
+ ],
713
+ outputs=audio_output_emotional,
714
+ )
715
+
716
+ # Validation function to disable Generate button if speech types are missing
717
+ def validate_speech_types(
718
+ gen_text,
719
+ regular_name,
720
+ *args
721
+ ):
722
+ num_additional_speech_types = max_speech_types - 1
723
+ speech_type_names_list = args[:num_additional_speech_types]
724
+
725
+ # Collect the speech types names
726
+ speech_types_available = set()
727
+ if regular_name:
728
+ speech_types_available.add(regular_name)
729
+ for name_input in speech_type_names_list:
730
+ if name_input:
731
+ speech_types_available.add(name_input)
732
+
733
+ # Parse the gen_text to get the speech types used
734
+ segments = parse_emotional_text(gen_text)
735
+ speech_types_in_text = set(segment['emotion'] for segment in segments)
736
+
737
+ # Check if all speech types in text are available
738
+ missing_speech_types = speech_types_in_text - speech_types_available
739
+
740
+ if missing_speech_types:
741
+ # Disable the generate button
742
+ return gr.update(interactive=False)
743
+ else:
744
+ # Enable the generate button
745
+ return gr.update(interactive=True)
746
+
747
+ gen_text_input_emotional.change(
748
+ validate_speech_types,
749
+ inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
750
+ outputs=generate_emotional_btn
751
+ )
752
+ with gr.Blocks() as app:
753
+ gr.Markdown(
754
+ """
755
+ # E2/F5 TTS
756
+
757
+ This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
758
+
759
+ * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
760
+ * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
761
+
762
+ The checkpoints support English and Chinese.
763
+
764
+ If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
765
+
766
+ **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
767
+ """
768
+ )
769
+ gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
770
+
771
+ @click.command()
772
+ @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
773
+ @click.option("--host", "-H", default=None, help="Host to run the app on")
774
+ @click.option(
775
+ "--share",
776
+ "-s",
777
+ default=False,
778
+ is_flag=True,
779
+ help="Share the app via Gradio share link",
780
+ )
781
+ @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
782
+ def main(port, host, share, api):
783
+ global app
784
+ print(f"Starting app...")
785
+ app.queue(api_open=api).launch(
786
+ server_name=host, server_port=port, share=share, show_api=api
787
+ )
788
+
789
+
790
+ if __name__ == "__main__":
791
+ if not USING_SPACES:
792
+ main()
793
+ else:
794
+ app.queue().launch()
inference-cli.py ADDED
@@ -0,0 +1,428 @@
1
+ import argparse
2
+ import codecs
3
+ import re
4
+ import tempfile
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import soundfile as sf
9
+ import tomli
10
+ import torch
11
+ import torchaudio
12
+ import tqdm
13
+ from cached_path import cached_path
14
+ from einops import rearrange
15
+ from pydub import AudioSegment, silence
16
+ from transformers import pipeline
17
+ from vocos import Vocos
18
+
19
+ from model import CFM, DiT, MMDiT, UNetT
20
+ from model.utils import (convert_char_to_pinyin, get_tokenizer,
21
+ load_checkpoint, save_spectrogram)
22
+
23
+ parser = argparse.ArgumentParser(
24
+ prog="python3 inference-cli.py",
25
+ description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
26
+ epilog="Specify options above to override one or more settings from config.",
27
+ )
28
+ parser.add_argument(
29
+ "-c",
30
+ "--config",
31
+ help="Configuration file. Default=cli-config.toml",
32
+ default="inference-cli.toml",
33
+ )
34
+ parser.add_argument(
35
+ "-m",
36
+ "--model",
37
+ help="F5-TTS | E2-TTS",
38
+ )
39
+ parser.add_argument(
40
+ "-p",
41
+ "--ckpt_file",
42
+ help="The Checkpoint .pt",
43
+ )
44
+ parser.add_argument(
45
+ "-v",
46
+ "--vocab_file",
47
+ help="The vocab .txt",
48
+ )
49
+ parser.add_argument(
50
+ "-r",
51
+ "--ref_audio",
52
+ type=str,
53
+ help="Reference audio file < 15 seconds."
54
+ )
55
+ parser.add_argument(
56
+ "-s",
57
+ "--ref_text",
58
+ type=str,
59
+ default="666",
60
+ help="Subtitle for the reference audio."
61
+ )
62
+ parser.add_argument(
63
+ "-t",
64
+ "--gen_text",
65
+ type=str,
66
+ help="Text to generate.",
67
+ )
68
+ parser.add_argument(
69
+ "-f",
70
+ "--gen_file",
71
+ type=str,
72
+ help="File with text to generate. Ignores --text",
73
+ )
74
+ parser.add_argument(
75
+ "-o",
76
+ "--output_dir",
77
+ type=str,
78
+ help="Path to output folder..",
79
+ )
80
+ parser.add_argument(
81
+ "--remove_silence",
82
+ help="Remove silence.",
83
+ )
84
+ parser.add_argument(
85
+ "--load_vocoder_from_local",
86
+ action="store_true",
87
+ help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
88
+ )
89
+ args = parser.parse_args()
90
+
91
+ config = tomli.load(open(args.config, "rb"))
92
+
93
+ ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
94
+ ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
95
+ gen_text = args.gen_text if args.gen_text else config["gen_text"]
96
+ gen_file = args.gen_file if args.gen_file else config["gen_file"]
97
+ if gen_file:
98
+ gen_text = codecs.open(gen_file, "r", "utf-8").read()
99
+ output_dir = args.output_dir if args.output_dir else config["output_dir"]
100
+ model = args.model if args.model else config["model"]
101
+ ckpt_file = args.ckpt_file if args.ckpt_file else ""
102
+ vocab_file = args.vocab_file if args.vocab_file else ""
103
+ remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
104
+ wave_path = Path(output_dir) / "out.wav"
105
+ spectrogram_path = Path(output_dir) / "out.png"
106
+ vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
107
+
108
+ device = (
109
+ "cuda"
110
+ if torch.cuda.is_available()
111
+ else "mps" if torch.backends.mps.is_available() else "cpu"
112
+ )
113
+
114
+ if args.load_vocoder_from_local:
115
+ print(f"Load vocos from local path {vocos_local_path}")
116
+ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
117
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
118
+ vocos.load_state_dict(state_dict)
119
+ vocos.eval()
120
+ else:
121
+ print("Donwload Vocos from huggingface charactr/vocos-mel-24khz")
122
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
123
+
124
+ print(f"Using {device} device")
125
+
126
+ # --------------------- Settings -------------------- #
127
+
128
+ target_sample_rate = 24000
129
+ n_mel_channels = 100
130
+ hop_length = 256
131
+ target_rms = 0.1
132
+ nfe_step = 32 # 16, 32
133
+ cfg_strength = 2.0
134
+ ode_method = "euler"
135
+ sway_sampling_coef = -1.0
136
+ speed = 1.0
137
+ # fix_duration = 27 # None or float (duration in seconds)
138
+ fix_duration = None
139
+
140
+ def load_model(model_cls, model_cfg, ckpt_path, file_vocab):
141
+
142
+ if file_vocab=="":
143
+ file_vocab="Emilia_ZH_EN"
144
+ tokenizer="pinyin"
145
+ else:
146
+ tokenizer="custom"
147
+
148
+ print("\nvocab : ",vocab_file,tokenizer)
149
+ print("tokenizer : ",tokenizer)
150
+ print("model : ",ckpt_path,"\n")
151
+
152
+ vocab_char_map, vocab_size = get_tokenizer(file_vocab, tokenizer)
153
+ model = CFM(
154
+ transformer=model_cls(
155
+ **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
156
+ ),
157
+ mel_spec_kwargs=dict(
158
+ target_sample_rate=target_sample_rate,
159
+ n_mel_channels=n_mel_channels,
160
+ hop_length=hop_length,
161
+ ),
162
+ odeint_kwargs=dict(
163
+ method=ode_method,
164
+ ),
165
+ vocab_char_map=vocab_char_map,
166
+ ).to(device)
167
+
168
+ model = load_checkpoint(model, ckpt_path, device, use_ema = True)
169
+
170
+ return model
171
+
172
+ # load models
173
+ F5TTS_model_cfg = dict(
174
+ dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
175
+ )
176
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
177
+
178
+ def chunk_text(text, max_chars=135):
179
+ """
180
+ Splits the input text into chunks, each with a maximum number of characters.
181
+ Args:
182
+ text (str): The text to be split.
183
+ max_chars (int): The maximum number of characters per chunk.
184
+ Returns:
185
+ List[str]: A list of text chunks.
186
+ """
187
+ chunks = []
188
+ current_chunk = ""
189
+ # Split the text into sentences based on punctuation followed by whitespace
190
+ sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
191
+
192
+ for sentence in sentences:
193
+ if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
194
+ current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
195
+ else:
196
+ if current_chunk:
197
+ chunks.append(current_chunk.strip())
198
+ current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
199
+
200
+ if current_chunk:
201
+ chunks.append(current_chunk.strip())
202
+
203
+ return chunks
204
+
205
+ #ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
206
+ #if not Path(ckpt_path).exists():
207
+ #ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
208
+
209
+ def infer_batch(ref_audio, ref_text, gen_text_batches, model, ckpt_file, file_vocab, remove_silence, cross_fade_duration=0.15):
210
+ if model == "F5-TTS":
211
+
212
+ if ckpt_file == "":
213
+ repo_name= "F5-TTS"
214
+ exp_name = "F5TTS_Base"
215
+ ckpt_step= 1200000
216
+ ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
217
+
218
+ ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file, file_vocab)
219
+
220
+ elif model == "E2-TTS":
221
+ if ckpt_file == "":
222
+ repo_name= "E2-TTS"
223
+ exp_name = "E2TTS_Base"
224
+ ckpt_step= 1200000
225
+ ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
226
+
227
+ ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file, file_vocab)
228
+
229
+ audio, sr = ref_audio
230
+ if audio.shape[0] > 1:
231
+ audio = torch.mean(audio, dim=0, keepdim=True)
232
+
233
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
234
+ if rms < target_rms:
235
+ audio = audio * target_rms / rms
236
+ if sr != target_sample_rate:
237
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
238
+ audio = resampler(audio)
239
+ audio = audio.to(device)
240
+
241
+ generated_waves = []
242
+ spectrograms = []
243
+
244
+ for i, gen_text in enumerate(tqdm.tqdm(gen_text_batches)):
245
+ # Prepare the text
246
+ if len(ref_text[-1].encode('utf-8')) == 1:
247
+ ref_text = ref_text + " "
248
+ text_list = [ref_text + gen_text]
249
+ final_text_list = convert_char_to_pinyin(text_list)
250
+
251
+ # Calculate duration
252
+ ref_audio_len = audio.shape[-1] // hop_length
253
+ zh_pause_punc = r"。,、;:?!"
254
+ ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
255
+ gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
256
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
257
+
258
+ # inference
259
+ with torch.inference_mode():
260
+ generated, _ = ema_model.sample(
261
+ cond=audio,
262
+ text=final_text_list,
263
+ duration=duration,
264
+ steps=nfe_step,
265
+ cfg_strength=cfg_strength,
266
+ sway_sampling_coef=sway_sampling_coef,
267
+ )
268
+
269
+ generated = generated[:, ref_audio_len:, :]
270
+ generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
271
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
272
+ if rms < target_rms:
273
+ generated_wave = generated_wave * rms / target_rms
274
+
275
+ # wav -> numpy
276
+ generated_wave = generated_wave.squeeze().cpu().numpy()
277
+
278
+ generated_waves.append(generated_wave)
279
+ spectrograms.append(generated_mel_spec[0].cpu().numpy())
280
+
281
+ # Combine all generated waves with cross-fading
282
+ if cross_fade_duration <= 0:
283
+ # Simply concatenate
284
+ final_wave = np.concatenate(generated_waves)
285
+ else:
286
+ final_wave = generated_waves[0]
287
+ for i in range(1, len(generated_waves)):
288
+ prev_wave = final_wave
289
+ next_wave = generated_waves[i]
290
+
291
+ # Calculate cross-fade samples, ensuring it does not exceed wave lengths
292
+ cross_fade_samples = int(cross_fade_duration * target_sample_rate)
293
+ cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
294
+
295
+ if cross_fade_samples <= 0:
296
+ # No overlap possible, concatenate
297
+ final_wave = np.concatenate([prev_wave, next_wave])
298
+ continue
299
+
300
+ # Overlapping parts
301
+ prev_overlap = prev_wave[-cross_fade_samples:]
302
+ next_overlap = next_wave[:cross_fade_samples]
303
+
304
+ # Fade out and fade in
305
+ fade_out = np.linspace(1, 0, cross_fade_samples)
306
+ fade_in = np.linspace(0, 1, cross_fade_samples)
307
+
308
+ # Cross-faded overlap
309
+ cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
310
+
311
+ # Combine
312
+ new_wave = np.concatenate([
313
+ prev_wave[:-cross_fade_samples],
314
+ cross_faded_overlap,
315
+ next_wave[cross_fade_samples:]
316
+ ])
317
+
318
+ final_wave = new_wave
319
+
320
+ # Create a combined spectrogram
321
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
322
+
323
+ return final_wave, combined_spectrogram
324
+
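The duration passed to the sampler above is a frame count: the reference length in mel frames plus a proportional allowance for the new text. A worked example with made-up numbers:

```python
# Hypothetical numbers, for illustration only.
hop_length, target_sample_rate, speed = 256, 24000, 1.0
ref_samples = 6 * target_sample_rate            # a 6-second reference clip
ref_audio_len = ref_samples // hop_length       # 562 mel frames
ref_text_len, gen_text_len = 100, 200           # byte lengths (plus pause bonus)
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
print(duration)  # 562 + 1124 = 1686 frames, i.e. roughly 18 s at 24 kHz
```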
325
+ def process_voice(ref_audio_orig, ref_text):
326
+ print("Converting audio...")
327
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
328
+ aseg = AudioSegment.from_file(ref_audio_orig)
329
+
330
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
331
+ non_silent_wave = AudioSegment.silent(duration=0)
332
+ for non_silent_seg in non_silent_segs:
333
+ non_silent_wave += non_silent_seg
334
+ aseg = non_silent_wave
335
+
336
+ audio_duration = len(aseg)
337
+ if audio_duration > 15000:
338
+ print("Audio is over 15s, clipping to only first 15s.")
339
+ aseg = aseg[:15000]
340
+ aseg.export(f.name, format="wav")
341
+ ref_audio = f.name
342
+
343
+ if not ref_text.strip():
344
+ print("No reference text provided, transcribing reference audio...")
345
+ pipe = pipeline(
346
+ "automatic-speech-recognition",
347
+ model="openai/whisper-large-v3-turbo",
348
+ torch_dtype=torch.float16,
349
+ device=device,
350
+ )
351
+ ref_text = pipe(
352
+ ref_audio,
353
+ chunk_length_s=30,
354
+ batch_size=128,
355
+ generate_kwargs={"task": "transcribe"},
356
+ return_timestamps=False,
357
+ )["text"].strip()
358
+ print("Finished transcription")
359
+ else:
360
+ print("Using custom reference text...")
361
+ return ref_audio, ref_text
362
+
363
+ def infer(ref_audio, ref_text, gen_text, model, ckpt_file, file_vocab, remove_silence, cross_fade_duration=0.15):
364
+ print(gen_text)
365
+ # Add the functionality to ensure it ends with ". "
366
+ if not ref_text.endswith(". ") and not ref_text.endswith("。"):
367
+ if ref_text.endswith("."):
368
+ ref_text += " "
369
+ else:
370
+ ref_text += ". "
371
+
372
+ # Split the input text into batches
373
+ audio, sr = torchaudio.load(ref_audio)
374
+ max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
375
+ gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
376
+ print('ref_text', ref_text)
377
+ for i, gen_text in enumerate(gen_text_batches):
378
+ print(f'gen_text {i}', gen_text)
379
+
380
+ print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
381
+ return infer_batch((audio, sr), ref_text, gen_text_batches, model, ckpt_file, file_vocab, remove_silence, cross_fade_duration)
382
+
383
+
384
+ def process(ref_audio, ref_text, text_gen, model, ckpt_file, file_vocab, remove_silence):
385
+ main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
386
+ if "voices" not in config:
387
+ voices = {"main": main_voice}
388
+ else:
389
+ voices = config["voices"]
390
+ voices["main"] = main_voice
391
+ for voice in voices:
392
+ voices[voice]['ref_audio'], voices[voice]['ref_text'] = process_voice(voices[voice]['ref_audio'], voices[voice]['ref_text'])
393
+
394
+ generated_audio_segments = []
395
+ reg1 = r'(?=\[\w+\])'
396
+ chunks = re.split(reg1, text_gen)
397
+ reg2 = r'\[(\w+)\]'
398
+ for text in chunks:
399
+ match = re.match(reg2, text)
400
+ # fall back to "main" when there is no tag or the tag is unknown
401
+ voice = match[1] if match else "main"
402
+ if voice not in voices:
403
+ voice = "main"
404
+ text = re.sub(reg2, "", text)
405
+ gen_text = text.strip()
406
+ ref_audio = voices[voice]['ref_audio']
407
+ ref_text = voices[voice]['ref_text']
408
+ print(f"Voice: {voice}")
409
+ audio, spectrogram = infer(ref_audio, ref_text, gen_text, model, ckpt_file, file_vocab, remove_silence)
410
+ generated_audio_segments.append(audio)
411
+
412
+ if generated_audio_segments:
413
+ final_wave = np.concatenate(generated_audio_segments)
414
+ with open(wave_path, "wb") as f:
415
+ sf.write(f.name, final_wave, target_sample_rate)
416
+ # Remove silence
417
+ if remove_silence:
418
+ aseg = AudioSegment.from_file(f.name)
419
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
420
+ non_silent_wave = AudioSegment.silent(duration=0)
421
+ for non_silent_seg in non_silent_segs:
422
+ non_silent_wave += non_silent_seg
423
+ aseg = non_silent_wave
424
+ aseg.export(f.name, format="wav")
425
+ print(f.name)
426
+
427
+
428
+ process(ref_audio, ref_text, gen_text, model, ckpt_file, vocab_file, remove_silence)
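process() also supports multi-voice scripts via [voice] tags, which must match entries under [voices] in the config; untagged or unknown chunks fall back to the main voice. A small sketch of the splitting (the tag names are made up):

```python
import re

text_gen = "[main] Hello there. [town] Hi, I am the other voice."
chunks = re.split(r'(?=\[\w+\])', text_gen)
# chunks == ['', '[main] Hello there. ', '[town] Hi, I am the other voice.']
for chunk in chunks:
    if not chunk.strip():
        continue
    m = re.match(r'\[(\w+)\]', chunk)
    voice = m[1] if m else "main"
    print(voice, "->", re.sub(r'\[(\w+)\]', "", chunk).strip())
# main -> Hello there.
# town -> Hi, I am the other voice.
```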
inference-cli.toml ADDED
@@ -0,0 +1,10 @@
1
+ # F5-TTS | E2-TTS
2
+ model = "F5-TTS"
3
+ ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
4
+ # If an empty "", transcribes the reference audio automatically.
5
+ ref_text = "Some call me nature, others call me mother nature."
6
+ gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
7
+ # File with text to generate. Ignores the text above.
8
+ gen_file = ""
9
+ remove_silence = false
10
+ output_dir = "tests"
model/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from model.cfm import CFM
2
+
3
+ from model.backbones.unett import UNetT
4
+ from model.backbones.dit import DiT
5
+ from model.backbones.mmdit import MMDiT
6
+
7
+ from model.trainer import Trainer
model/backbones/README.md ADDED
@@ -0,0 +1,20 @@
1
+ ## Backbones quick introduction
2
+
3
+
4
+ ### unett.py
5
+ - flat unet transformer
6
+ - structure same as in e2-tts & voicebox paper except using rotary pos emb
7
+ - update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
8
+
9
+ ### dit.py
10
+ - adaln-zero dit
11
+ - embedded timestep as condition
12
+ - concatted noised_input + masked_cond + embedded_text, linear proj in
13
+ - possible abs pos emb & convnextv2 blocks for embedded text before concat
14
+ - possible long skip connection (first layer to last layer)
15
+
16
+ ### mmdit.py
17
+ - sd3 structure
18
+ - timestep as condition
19
+ - left stream: text embedded and applied an abs pos emb
20
+ - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
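For orientation, the configs used elsewhere in this commit map onto these backbones roughly as follows; the vocab size here is a placeholder (the real value comes from get_tokenizer):

```python
from model import DiT, UNetT

vocab_size = 2546  # placeholder; use the size returned by get_tokenizer()

# F5-TTS base: adaln-zero DiT with ConvNeXtV2 text blocks
f5_backbone = DiT(dim=1024, depth=22, heads=16, ff_mult=2,
                  text_dim=512, conv_layers=4,
                  mel_dim=100, text_num_embeds=vocab_size)

# E2-TTS base: flat UNet-Transformer
e2_backbone = UNetT(dim=1024, depth=24, heads=16, ff_mult=4,
                    mel_dim=100, text_num_embeds=vocab_size)
```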
model/backbones/dit.py ADDED
@@ -0,0 +1,158 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+ import torch.nn.functional as F
15
+
16
+ from einops import repeat
17
+
18
+ from x_transformers.x_transformers import RotaryEmbedding
19
+
20
+ from model.modules import (
21
+ TimestepEmbedding,
22
+ ConvNeXtV2Block,
23
+ ConvPositionEmbedding,
24
+ DiTBlock,
25
+ AdaLayerNormZero_Final,
26
+ precompute_freqs_cis, get_pos_embed_indices,
27
+ )
28
+
29
+
30
+ # Text embedding
31
+
32
+ class TextEmbedding(nn.Module):
33
+ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
34
+ super().__init__()
35
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
+
37
+ if conv_layers > 0:
38
+ self.extra_modeling = True
39
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
40
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
41
+ self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
42
+ else:
43
+ self.extra_modeling = False
44
+
45
+ def forward(self, text: int['b nt'], seq_len, drop_text = False):
46
+ batch, text_len = text.shape[0], text.shape[1]
47
+ text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
48
+ text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
49
+ text = F.pad(text, (0, seq_len - text_len), value = 0)
50
+
51
+ if drop_text: # cfg for text
52
+ text = torch.zeros_like(text)
53
+
54
+ text = self.text_embed(text) # b n -> b n d
55
+
56
+ # possible extra modeling
57
+ if self.extra_modeling:
58
+ # sinus pos emb
59
+ batch_start = torch.zeros((batch,), dtype=torch.long)
60
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
61
+ text_pos_embed = self.freqs_cis[pos_idx]
62
+ text = text + text_pos_embed
63
+
64
+ # convnextv2 blocks
65
+ text = self.text_blocks(text)
66
+
67
+ return text
68
+
69
+
70
+ # noised input audio and context mixing embedding
71
+
72
+ class InputEmbedding(nn.Module):
73
+ def __init__(self, mel_dim, text_dim, out_dim):
74
+ super().__init__()
75
+ self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
76
+ self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
77
+
78
+ def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
79
+ if drop_audio_cond: # cfg for cond audio
80
+ cond = torch.zeros_like(cond)
81
+
82
+ x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
83
+ x = self.conv_pos_embed(x) + x
84
+ return x
85
+
86
+
87
+ # Transformer backbone using DiT blocks
88
+
89
+ class DiT(nn.Module):
90
+ def __init__(self, *,
91
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
92
+ mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
93
+ long_skip_connection = False,
94
+ ):
95
+ super().__init__()
96
+
97
+ self.time_embed = TimestepEmbedding(dim)
98
+ if text_dim is None:
99
+ text_dim = mel_dim
100
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
101
+ self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
102
+
103
+ self.rotary_embed = RotaryEmbedding(dim_head)
104
+
105
+ self.dim = dim
106
+ self.depth = depth
107
+
108
+ self.transformer_blocks = nn.ModuleList(
109
+ [
110
+ DiTBlock(
111
+ dim = dim,
112
+ heads = heads,
113
+ dim_head = dim_head,
114
+ ff_mult = ff_mult,
115
+ dropout = dropout
116
+ )
117
+ for _ in range(depth)
118
+ ]
119
+ )
120
+ self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None
121
+
122
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
123
+ self.proj_out = nn.Linear(dim, mel_dim)
124
+
125
+ def forward(
126
+ self,
127
+ x: float['b n d'], # noised input audio
128
+ cond: float['b n d'], # masked cond audio
129
+ text: int['b nt'], # text
130
+ time: float['b'] | float[''], # time step
131
+ drop_audio_cond, # cfg for cond audio
132
+ drop_text, # cfg for text
133
+ mask: bool['b n'] | None = None,
134
+ ):
135
+ batch, seq_len = x.shape[0], x.shape[1]
136
+ if time.ndim == 0:
137
+ time = repeat(time, ' -> b', b = batch)
138
+
139
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
140
+ t = self.time_embed(time)
141
+ text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
142
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
143
+
144
+ rope = self.rotary_embed.forward_from_seq_len(seq_len)
145
+
146
+ if self.long_skip_connection is not None:
147
+ residual = x
148
+
149
+ for block in self.transformer_blocks:
150
+ x = block(x, t, mask = mask, rope = rope)
151
+
152
+ if self.long_skip_connection is not None:
153
+ x = self.long_skip_connection(torch.cat((x, residual), dim = -1))
154
+
155
+ x = self.norm_out(x, t)
156
+ output = self.proj_out(x)
157
+
158
+ return output
model/backbones/mmdit.py ADDED
@@ -0,0 +1,136 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from einops import repeat
16
+
17
+ from x_transformers.x_transformers import RotaryEmbedding
18
+
19
+ from model.modules import (
20
+ TimestepEmbedding,
21
+ ConvPositionEmbedding,
22
+ MMDiTBlock,
23
+ AdaLayerNormZero_Final,
24
+ precompute_freqs_cis, get_pos_embed_indices,
25
+ )
26
+
27
+
28
+ # text embedding
29
+
30
+ class TextEmbedding(nn.Module):
31
+ def __init__(self, out_dim, text_num_embeds):
32
+ super().__init__()
33
+ self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
34
+
35
+ self.precompute_max_pos = 1024
36
+ self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
+
38
+ def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']:
39
+ text = text + 1
40
+ if drop_text:
41
+ text = torch.zeros_like(text)
42
+ text = self.text_embed(text)
43
+
44
+ # sinus pos emb
45
+ batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
46
+ batch_text_len = text.shape[1]
47
+ pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
48
+ text_pos_embed = self.freqs_cis[pos_idx]
49
+
50
+ text = text + text_pos_embed
51
+
52
+ return text
53
+
54
+
55
+ # noised input & masked cond audio embedding
56
+
57
+ class AudioEmbedding(nn.Module):
58
+ def __init__(self, in_dim, out_dim):
59
+ super().__init__()
60
+ self.linear = nn.Linear(2 * in_dim, out_dim)
61
+ self.conv_pos_embed = ConvPositionEmbedding(out_dim)
62
+
63
+ def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
64
+ if drop_audio_cond:
65
+ cond = torch.zeros_like(cond)
66
+ x = torch.cat((x, cond), dim = -1)
67
+ x = self.linear(x)
68
+ x = self.conv_pos_embed(x) + x
69
+ return x
70
+
71
+
72
+ # Transformer backbone using MM-DiT blocks
73
+
74
+ class MMDiT(nn.Module):
75
+ def __init__(self, *,
76
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
77
+ text_num_embeds = 256, mel_dim = 100,
78
+ ):
79
+ super().__init__()
80
+
81
+ self.time_embed = TimestepEmbedding(dim)
82
+ self.text_embed = TextEmbedding(dim, text_num_embeds)
83
+ self.audio_embed = AudioEmbedding(mel_dim, dim)
84
+
85
+ self.rotary_embed = RotaryEmbedding(dim_head)
86
+
87
+ self.dim = dim
88
+ self.depth = depth
89
+
90
+ self.transformer_blocks = nn.ModuleList(
91
+ [
92
+ MMDiTBlock(
93
+ dim = dim,
94
+ heads = heads,
95
+ dim_head = dim_head,
96
+ dropout = dropout,
97
+ ff_mult = ff_mult,
98
+ context_pre_only = i == depth - 1,
99
+ )
100
+ for i in range(depth)
101
+ ]
102
+ )
103
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
104
+ self.proj_out = nn.Linear(dim, mel_dim)
105
+
106
+ def forward(
107
+ self,
108
+ x: float['b n d'], # noised input audio
109
+ cond: float['b n d'], # masked cond audio
110
+ text: int['b nt'], # text
111
+ time: float['b'] | float[''], # time step
112
+ drop_audio_cond, # cfg for cond audio
113
+ drop_text, # cfg for text
114
+ mask: bool['b n'] | None = None,
115
+ ):
116
+ batch = x.shape[0]
117
+ if time.ndim == 0:
118
+ time = repeat(time, ' -> b', b = batch)
119
+
120
+ # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
121
+ t = self.time_embed(time)
122
+ c = self.text_embed(text, drop_text = drop_text)
123
+ x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)
124
+
125
+ seq_len = x.shape[1]
126
+ text_len = text.shape[1]
127
+ rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
128
+ rope_text = self.rotary_embed.forward_from_seq_len(text_len)
129
+
130
+ for block in self.transformer_blocks:
131
+ c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)
132
+
133
+ x = self.norm_out(x, t)
134
+ output = self.proj_out(x)
135
+
136
+ return output
model/backbones/unett.py ADDED
@@ -0,0 +1,201 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Literal
12
+
13
+ import torch
14
+ from torch import nn
15
+ import torch.nn.functional as F
16
+
17
+ from einops import repeat, pack, unpack
18
+
19
+ from x_transformers import RMSNorm
20
+ from x_transformers.x_transformers import RotaryEmbedding
21
+
22
+ from model.modules import (
23
+ TimestepEmbedding,
24
+ ConvNeXtV2Block,
25
+ ConvPositionEmbedding,
26
+ Attention,
27
+ AttnProcessor,
28
+ FeedForward,
29
+ precompute_freqs_cis, get_pos_embed_indices,
30
+ )
31
+
32
+
33
+ # Text embedding
34
+
35
+ class TextEmbedding(nn.Module):
36
+ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
37
+ super().__init__()
38
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
+
40
+ if conv_layers > 0:
41
+ self.extra_modeling = True
42
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
43
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
44
+ self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
45
+ else:
46
+ self.extra_modeling = False
47
+
48
+ def forward(self, text: int['b nt'], seq_len, drop_text = False):
49
+ batch, text_len = text.shape[0], text.shape[1]
50
+ text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
51
+ text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
52
+ text = F.pad(text, (0, seq_len - text_len), value = 0)
53
+
54
+ if drop_text: # cfg for text
55
+ text = torch.zeros_like(text)
56
+
57
+ text = self.text_embed(text) # b n -> b n d
58
+
59
+ # possible extra modeling
60
+ if self.extra_modeling:
61
+ # sinus pos emb
62
+ batch_start = torch.zeros((batch,), dtype=torch.long)
63
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
64
+ text_pos_embed = self.freqs_cis[pos_idx]
65
+ text = text + text_pos_embed
66
+
67
+ # convnextv2 blocks
68
+ text = self.text_blocks(text)
69
+
70
+ return text
71
+
72
+
73
+ # noised input audio and context mixing embedding
74
+
75
+ class InputEmbedding(nn.Module):
76
+ def __init__(self, mel_dim, text_dim, out_dim):
77
+ super().__init__()
78
+ self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
79
+ self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
80
+
81
+ def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
82
+ if drop_audio_cond: # cfg for cond audio
83
+ cond = torch.zeros_like(cond)
84
+
85
+ x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
86
+ x = self.conv_pos_embed(x) + x
87
+ return x
88
+
89
+
90
+ # Flat UNet Transformer backbone
91
+
92
+ class UNetT(nn.Module):
93
+ def __init__(self, *,
94
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
95
+ mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
96
+ skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
97
+ ):
98
+ super().__init__()
99
+ assert depth % 2 == 0, "UNet-Transformer's depth should be even."
100
+
101
+ self.time_embed = TimestepEmbedding(dim)
102
+ if text_dim is None:
103
+ text_dim = mel_dim
104
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
105
+ self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
106
+
107
+ self.rotary_embed = RotaryEmbedding(dim_head)
108
+
109
+ # transformer layers & skip connections
110
+
111
+ self.dim = dim
112
+ self.skip_connect_type = skip_connect_type
113
+ needs_skip_proj = skip_connect_type == 'concat'
114
+
115
+ self.depth = depth
116
+ self.layers = nn.ModuleList([])
117
+
118
+ for idx in range(depth):
119
+ is_later_half = idx >= (depth // 2)
120
+
121
+ attn_norm = RMSNorm(dim)
122
+ attn = Attention(
123
+ processor = AttnProcessor(),
124
+ dim = dim,
125
+ heads = heads,
126
+ dim_head = dim_head,
127
+ dropout = dropout,
128
+ )
129
+
130
+ ff_norm = RMSNorm(dim)
131
+ ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
132
+
133
+ skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None
134
+
135
+ self.layers.append(nn.ModuleList([
136
+ skip_proj,
137
+ attn_norm,
138
+ attn,
139
+ ff_norm,
140
+ ff,
141
+ ]))
142
+
143
+ self.norm_out = RMSNorm(dim)
144
+ self.proj_out = nn.Linear(dim, mel_dim)
145
+
146
+ def forward(
147
+ self,
148
+ x: float['b n d'], # noised input audio
149
+ cond: float['b n d'], # masked cond audio
150
+ text: int['b nt'], # text
151
+ time: float['b'] | float[''], # time step
152
+ drop_audio_cond, # cfg for cond audio
153
+ drop_text, # cfg for text
154
+ mask: bool['b n'] | None = None,
155
+ ):
156
+ batch, seq_len = x.shape[0], x.shape[1]
157
+ if time.ndim == 0:
158
+ time = repeat(time, ' -> b', b = batch)
159
+
160
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
161
+ t = self.time_embed(time)
162
+ text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
163
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
164
+
165
+ # prepend time t to input x as an extra token, [b n d] -> [b n+1 d]
166
+ x, ps = pack((t, x), 'b * d')
167
+ if mask is not None:
168
+ mask = F.pad(mask, (1, 0), value=1)
169
+
170
+ rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
171
+
172
+ # flat unet transformer
173
+ skip_connect_type = self.skip_connect_type
174
+ skips = []
175
+ for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
176
+ layer = idx + 1
177
+
178
+ # skip connection logic
179
+ is_first_half = layer <= (self.depth // 2)
180
+ is_later_half = not is_first_half
181
+
182
+ if is_first_half:
183
+ skips.append(x)
184
+
185
+ if is_later_half:
186
+ skip = skips.pop()
187
+ if skip_connect_type == 'concat':
188
+ x = torch.cat((x, skip), dim = -1)
189
+ x = maybe_skip_proj(x)
190
+ elif skip_connect_type == 'add':
191
+ x = x + skip
192
+
193
+ # attention and feedforward blocks
194
+ x = attn(attn_norm(x), rope = rope, mask = mask) + x
195
+ x = ff(ff_norm(x)) + x
196
+
197
+ assert len(skips) == 0
198
+
199
+ _, x = unpack(self.norm_out(x), ps, 'b * d')
200
+
201
+ return self.proj_out(x)
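
A minimal usage sketch for the flat UNet-Transformer backbone above (illustration only, not part of the uploaded file). Shapes follow the defaults in this file (mel_dim = 100, text_num_embeds = 256); the "from model import UNetT" import path is an assumption, mirroring the "from model import CFM" import used in model/trainer.py.

import torch
from model import UNetT  # assumed export from the model package

backbone = UNetT(dim=512, depth=8, heads=8, dim_head=64)  # hypothetical small config
x = torch.randn(2, 300, 100)           # noised mel input, b n d
cond = torch.randn(2, 300, 100)        # masked conditioning mel, b n d
text = torch.randint(0, 256, (2, 60))  # token ids, b nt (batch padding would use -1)
time = torch.rand(2)                   # one flow-matching time step per sample
out = backbone(x, cond, text, time, drop_audio_cond=False, drop_text=False)
print(out.shape)                       # torch.Size([2, 300, 100]), same frame count as x
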
model/cfm.py ADDED
@@ -0,0 +1,279 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Callable
12
+ from random import random
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+ from torch.nn.utils.rnn import pad_sequence
18
+
19
+ from torchdiffeq import odeint
20
+
21
+ from einops import rearrange
22
+
23
+ from model.modules import MelSpec
24
+
25
+ from model.utils import (
26
+ default, exists,
27
+ list_str_to_idx, list_str_to_tensor,
28
+ lens_to_mask, mask_from_frac_lengths,
29
+ )
30
+
31
+
32
+ class CFM(nn.Module):
33
+ def __init__(
34
+ self,
35
+ transformer: nn.Module,
36
+ sigma = 0.,
37
+ odeint_kwargs: dict = dict(
38
+ # atol = 1e-5,
39
+ # rtol = 1e-5,
40
+ method = 'euler' # 'midpoint'
41
+ ),
42
+ audio_drop_prob = 0.3,
43
+ cond_drop_prob = 0.2,
44
+ num_channels = None,
45
+ mel_spec_module: nn.Module | None = None,
46
+ mel_spec_kwargs: dict = dict(),
47
+ frac_lengths_mask: tuple[float, float] = (0.7, 1.),
48
+ vocab_char_map: dict[str: int] | None = None
49
+ ):
50
+ super().__init__()
51
+
52
+ self.frac_lengths_mask = frac_lengths_mask
53
+
54
+ # mel spec
55
+ self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
56
+ num_channels = default(num_channels, self.mel_spec.n_mel_channels)
57
+ self.num_channels = num_channels
58
+
59
+ # classifier-free guidance
60
+ self.audio_drop_prob = audio_drop_prob
61
+ self.cond_drop_prob = cond_drop_prob
62
+
63
+ # transformer
64
+ self.transformer = transformer
65
+ dim = transformer.dim
66
+ self.dim = dim
67
+
68
+ # conditional flow related
69
+ self.sigma = sigma
70
+
71
+ # sampling related
72
+ self.odeint_kwargs = odeint_kwargs
73
+
74
+ # vocab map for tokenization
75
+ self.vocab_char_map = vocab_char_map
76
+
77
+ @property
78
+ def device(self):
79
+ return next(self.parameters()).device
80
+
81
+ @torch.no_grad()
82
+ def sample(
83
+ self,
84
+ cond: float['b n d'] | float['b nw'],
85
+ text: int['b nt'] | list[str],
86
+ duration: int | int['b'],
87
+ *,
88
+ lens: int['b'] | None = None,
89
+ steps = 32,
90
+ cfg_strength = 1.,
91
+ sway_sampling_coef = None,
92
+ seed: int | None = None,
93
+ max_duration = 4096,
94
+ vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
95
+ no_ref_audio = False,
96
+ duplicate_test = False,
97
+ t_inter = 0.1,
98
+ edit_mask = None,
99
+ ):
100
+ self.eval()
101
+
102
+ # raw wave
103
+
104
+ if cond.ndim == 2:
105
+ cond = self.mel_spec(cond)
106
+ cond = rearrange(cond, 'b d n -> b n d')
107
+ assert cond.shape[-1] == self.num_channels
108
+
109
+ batch, cond_seq_len, device = *cond.shape[:2], cond.device
110
+ if not exists(lens):
111
+ lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
112
+
113
+ # text
114
+
115
+ if isinstance(text, list):
116
+ if exists(self.vocab_char_map):
117
+ text = list_str_to_idx(text, self.vocab_char_map).to(device)
118
+ else:
119
+ text = list_str_to_tensor(text).to(device)
120
+ assert text.shape[0] == batch
121
+
122
+ if exists(text):
123
+ text_lens = (text != -1).sum(dim = -1)
124
+ lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
125
+
126
+ # duration
127
+
128
+ cond_mask = lens_to_mask(lens)
129
+ if edit_mask is not None:
130
+ cond_mask = cond_mask & edit_mask
131
+
132
+ if isinstance(duration, int):
133
+ duration = torch.full((batch,), duration, device = device, dtype = torch.long)
134
+
135
+ duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
136
+ duration = duration.clamp(max = max_duration)
137
+ max_duration = duration.amax()
138
+
139
+ # duplicate test corner for inner time step observation
140
+ if duplicate_test:
141
+ test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)
142
+
143
+ cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
144
+ cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
145
+ cond_mask = rearrange(cond_mask, '... -> ... 1')
146
+ step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
147
+
148
+ if batch > 1:
149
+ mask = lens_to_mask(duration)
150
+ else: # save memory and speed up, as single-sample inference currently needs no mask
151
+ mask = None
152
+
153
+ # test for no ref audio
154
+ if no_ref_audio:
155
+ cond = torch.zeros_like(cond)
156
+
157
+ # neural ode
158
+
159
+ def fn(t, x):
160
+ # at each step, conditioning is fixed
161
+ # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
162
+
163
+ # predict flow
164
+ pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
165
+ if cfg_strength < 1e-5:
166
+ return pred
167
+
168
+ null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
169
+ return pred + (pred - null_pred) * cfg_strength
170
+
171
+ # noise input
172
+ # to make sure batched inference gives the same result regardless of batch size, and matches single-sample inference
173
+ # some small differences may remain, likely due to the convolutional layers
174
+ y0 = []
175
+ for dur in duration:
176
+ if exists(seed):
177
+ torch.manual_seed(seed)
178
+ y0.append(torch.randn(dur, self.num_channels, device = self.device))
179
+ y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
180
+
181
+ t_start = 0
182
+
183
+ # duplicate test corner for inner time step observation
184
+ if duplicate_test:
185
+ t_start = t_inter
186
+ y0 = (1 - t_start) * y0 + t_start * test_cond
187
+ steps = int(steps * (1 - t_start))
188
+
189
+ t = torch.linspace(t_start, 1, steps, device = self.device)
190
+ if sway_sampling_coef is not None:
191
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
192
+
193
+ trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
194
+
195
+ sampled = trajectory[-1]
196
+ out = sampled
197
+ out = torch.where(cond_mask, cond, out)
198
+
199
+ if exists(vocoder):
200
+ out = rearrange(out, 'b n d -> b d n')
201
+ out = vocoder(out)
202
+
203
+ return out, trajectory
204
+
205
+ def forward(
206
+ self,
207
+ inp: float['b n d'] | float['b nw'], # mel or raw wave
208
+ text: int['b nt'] | list[str],
209
+ *,
210
+ lens: int['b'] | None = None,
211
+ noise_scheduler: str | None = None,
212
+ ):
213
+ # handle raw wave
214
+ if inp.ndim == 2:
215
+ inp = self.mel_spec(inp)
216
+ inp = rearrange(inp, 'b d n -> b n d')
217
+ assert inp.shape[-1] == self.num_channels
218
+
219
+ batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
220
+
221
+ # handle text as string
222
+ if isinstance(text, list):
223
+ if exists(self.vocab_char_map):
224
+ text = list_str_to_idx(text, self.vocab_char_map).to(device)
225
+ else:
226
+ text = list_str_to_tensor(text).to(device)
227
+ assert text.shape[0] == batch
228
+
229
+ # lens and mask
230
+ if not exists(lens):
231
+ lens = torch.full((batch,), seq_len, device = device)
232
+
233
+ mask = lens_to_mask(lens, length = seq_len) # not strictly needed here, as collate_fn pads to the max length in the batch
234
+
235
+ # get a random span to mask out for training conditionally
236
+ frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
237
+ rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
238
+
239
+ if exists(mask):
240
+ rand_span_mask &= mask
241
+
242
+ # mel is x1
243
+ x1 = inp
244
+
245
+ # x0 is gaussian noise
246
+ x0 = torch.randn_like(x1)
247
+
248
+ # time step
249
+ time = torch.rand((batch,), dtype = dtype, device = self.device)
250
+ # TODO. noise_scheduler
251
+
252
+ # sample xt (φ_t(x) in the paper)
253
+ t = rearrange(time, 'b -> b 1 1')
254
+ φ = (1 - t) * x0 + t * x1
255
+ flow = x1 - x0
256
+
257
+ # only predict what is within the random mask span for infilling
258
+ cond = torch.where(
259
+ rand_span_mask[..., None],
260
+ torch.zeros_like(x1), x1
261
+ )
262
+
263
+ # transformer and cfg training with a drop rate
264
+ drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
265
+ if random() < self.cond_drop_prob: # p_uncond in voicebox paper
266
+ drop_audio_cond = True
267
+ drop_text = True
268
+ else:
269
+ drop_text = False
270
+
271
+ # if we want to rigorously mask out padding, record it in collate_fn in dataset.py and pass it in here
272
+ # adding a mask uses more memory, so the batch sampler threshold would also need to be scaled down for long sequences
273
+ pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)
274
+
275
+ # flow matching loss
276
+ loss = F.mse_loss(pred, flow, reduction = 'none')
277
+ loss = loss[rand_span_mask]
278
+
279
+ return loss.mean(), cond, pred
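
A minimal sketch of how this CFM wrapper is typically driven (illustration only, not part of the uploaded file). It assumes UNetT and CFM are both exported from the model package, as model/trainer.py already does for CFM; texts are plain strings tokenized internally via list_str_to_tensor.

import torch
from model import CFM, UNetT  # assumed exports

cfm = CFM(transformer=UNetT(dim=512, depth=8, heads=8, dim_head=64))

# training step on raw 24 kHz waveforms
wave = torch.randn(2, 24_000 * 2)
loss, cond, pred = cfm(wave, ["hello there", "flow matching"])
loss.backward()

# inference: condition on reference audio, ask for a total duration in mel frames
out, trajectory = cfm.sample(wave, ["hello there and more", "flow matching and more"],
                             duration=500, steps=32, cfg_strength=2., sway_sampling_coef=-1.)
print(out.shape)  # (2, 500, 100): reference frames kept, remaining frames generated
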
model/dataset.py ADDED
@@ -0,0 +1,257 @@
1
+ import json
2
+ import random
3
+ from tqdm import tqdm
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import Dataset, Sampler
8
+ import torchaudio
9
+ from datasets import load_dataset, load_from_disk
10
+ from datasets import Dataset as Dataset_
11
+
12
+ from einops import rearrange
13
+
14
+ from model.modules import MelSpec
15
+
16
+
17
+ class HFDataset(Dataset):
18
+ def __init__(
19
+ self,
20
+ hf_dataset: Dataset,
21
+ target_sample_rate = 24_000,
22
+ n_mel_channels = 100,
23
+ hop_length = 256,
24
+ ):
25
+ self.data = hf_dataset
26
+ self.target_sample_rate = target_sample_rate
27
+ self.hop_length = hop_length
28
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
29
+
30
+ def get_frame_len(self, index):
31
+ row = self.data[index]
32
+ audio = row['audio']['array']
33
+ sample_rate = row['audio']['sampling_rate']
34
+ return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
35
+
36
+ def __len__(self):
37
+ return len(self.data)
38
+
39
+ def __getitem__(self, index):
40
+ row = self.data[index]
41
+ audio = row['audio']['array']
42
+
43
+ # logger.info(f"Audio shape: {audio.shape}")
44
+
45
+ sample_rate = row['audio']['sampling_rate']
46
+ duration = audio.shape[-1] / sample_rate
47
+
48
+ if duration > 30 or duration < 0.3:
49
+ return self.__getitem__((index + 1) % len(self.data))
50
+
51
+ audio_tensor = torch.from_numpy(audio).float()
52
+
53
+ if sample_rate != self.target_sample_rate:
54
+ resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
55
+ audio_tensor = resampler(audio_tensor)
56
+
57
+ audio_tensor = rearrange(audio_tensor, 't -> 1 t')
58
+
59
+ mel_spec = self.mel_spectrogram(audio_tensor)
60
+
61
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
62
+
63
+ text = row['text']
64
+
65
+ return dict(
66
+ mel_spec = mel_spec,
67
+ text = text,
68
+ )
69
+
70
+
71
+ class CustomDataset(Dataset):
72
+ def __init__(
73
+ self,
74
+ custom_dataset: Dataset,
75
+ durations = None,
76
+ target_sample_rate = 24_000,
77
+ hop_length = 256,
78
+ n_mel_channels = 100,
79
+ preprocessed_mel = False,
80
+ ):
81
+ self.data = custom_dataset
82
+ self.durations = durations
83
+ self.target_sample_rate = target_sample_rate
84
+ self.hop_length = hop_length
85
+ self.preprocessed_mel = preprocessed_mel
86
+ if not preprocessed_mel:
87
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)
88
+
89
+ def get_frame_len(self, index):
90
+ if self.durations is not None: # make sure any separately provided durations are correct, otherwise OOM is almost guaranteed
91
+ return self.durations[index] * self.target_sample_rate / self.hop_length
92
+ return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
93
+
94
+ def __len__(self):
95
+ return len(self.data)
96
+
97
+ def __getitem__(self, index):
98
+ row = self.data[index]
99
+ audio_path = row["audio_path"]
100
+ text = row["text"]
101
+ duration = row["duration"]
102
+
103
+ if self.preprocessed_mel:
104
+ mel_spec = torch.tensor(row["mel_spec"])
105
+
106
+ else:
107
+ audio, source_sample_rate = torchaudio.load(audio_path)
108
+
109
+ if duration > 30 or duration < 0.3:
110
+ return self.__getitem__((index + 1) % len(self.data))
111
+
112
+ if source_sample_rate != self.target_sample_rate:
113
+ resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
114
+ audio = resampler(audio)
115
+
116
+ mel_spec = self.mel_spectrogram(audio)
117
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
118
+
119
+ return dict(
120
+ mel_spec = mel_spec,
121
+ text = text,
122
+ )
123
+
124
+
125
+ # Dynamic Batch Sampler
126
+
127
+ class DynamicBatchSampler(Sampler[list[int]]):
128
+ """ Extension of Sampler that will do the following:
129
+ 1. Change the batch size (essentially number of sequences)
130
+ in a batch to ensure that the total number of frames are less
131
+ than a certain threshold.
132
+ 2. Make sure the padding efficiency in the batch is high.
133
+ """
134
+
135
+ def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
136
+ self.sampler = sampler
137
+ self.frames_threshold = frames_threshold
138
+ self.max_samples = max_samples
139
+
140
+ indices, batches = [], []
141
+ data_source = self.sampler.data_source
142
+
143
+ for idx in tqdm(self.sampler, desc="Sorting with sampler... if slow, check whether the dataset provides durations"):
144
+ indices.append((idx, data_source.get_frame_len(idx)))
145
+ indices.sort(key=lambda elem : elem[1])
146
+
147
+ batch = []
148
+ batch_frames = 0
149
+ for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
150
+ if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
151
+ batch.append(idx)
152
+ batch_frames += frame_len
153
+ else:
154
+ if len(batch) > 0:
155
+ batches.append(batch)
156
+ if frame_len <= self.frames_threshold:
157
+ batch = [idx]
158
+ batch_frames = frame_len
159
+ else:
160
+ batch = []
161
+ batch_frames = 0
162
+
163
+ if not drop_last and len(batch) > 0:
164
+ batches.append(batch)
165
+
166
+ del indices
167
+
168
+ # to get different batches between epochs, set a seed here and log it in the checkpoint
169
+ # during multi-gpu training, the per-gpu batches do not change between epochs, but the overall minibatch composition does
170
+ # e.g. for epoch n, use (random_seed + n)
171
+ random.seed(random_seed)
172
+ random.shuffle(batches)
173
+
174
+ self.batches = batches
175
+
176
+ def __iter__(self):
177
+ return iter(self.batches)
178
+
179
+ def __len__(self):
180
+ return len(self.batches)
181
+
182
+
183
+ # Load dataset
184
+
185
+ def load_dataset(
186
+ dataset_name: str,
187
+ tokenizer: str = "pinyin",
188
+ dataset_type: str = "CustomDataset",
189
+ audio_type: str = "raw",
190
+ mel_spec_kwargs: dict = dict()
191
+ ) -> CustomDataset | HFDataset:
192
+ '''
193
+ dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
194
+ - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
195
+ '''
196
+
197
+ print("Loading dataset ...")
198
+
199
+ if dataset_type == "CustomDataset":
200
+ if audio_type == "raw":
201
+ try:
202
+ train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
203
+ except:
204
+ train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
205
+ preprocessed_mel = False
206
+ elif audio_type == "mel":
207
+ train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
208
+ preprocessed_mel = True
209
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
210
+ data_dict = json.load(f)
211
+ durations = data_dict["duration"]
212
+ train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
213
+
214
+ elif dataset_type == "CustomDatasetPath":
215
+ try:
216
+ train_dataset = load_from_disk(f"{dataset_name}/raw")
217
+ except:
218
+ train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
219
+
220
+ with open(f"{dataset_name}/duration.json", 'r', encoding='utf-8') as f:
221
+ data_dict = json.load(f)
222
+ durations = data_dict["duration"]
223
+ train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=False, **mel_spec_kwargs)  # raw audio assumed; preprocessed_mel was otherwise undefined in this branch
224
+
225
+ elif dataset_type == "HFDataset":
226
+ print("You should manually modify the HuggingFace dataset path to suit your needs.\n" +
227
+ "You may also need to adapt the corresponding script, since different datasets can have different formats.")
228
+ pre, post = dataset_name.split("_")
229
+ train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)
230
+
231
+ return train_dataset
232
+
233
+
234
+ # collation
235
+
236
+ def collate_fn(batch):
237
+ mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
238
+ mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
239
+ max_mel_length = mel_lengths.amax()
240
+
241
+ padded_mel_specs = []
242
+ for spec in mel_specs: # TODO. maybe records mask for attention here
243
+ padding = (0, max_mel_length - spec.size(-1))
244
+ padded_spec = F.pad(spec, padding, value = 0)
245
+ padded_mel_specs.append(padded_spec)
246
+
247
+ mel_specs = torch.stack(padded_mel_specs)
248
+
249
+ text = [item['text'] for item in batch]
250
+ text_lengths = torch.LongTensor([len(item) for item in text])
251
+
252
+ return dict(
253
+ mel = mel_specs,
254
+ mel_lengths = mel_lengths,
255
+ text = text,
256
+ text_lengths = text_lengths,
257
+ )
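
A sketch of how these pieces fit together at data-loading time (illustration only, not part of the uploaded file); the dataset name and frames_threshold below are placeholder values, and the CustomDataset path expects data/<name>_<tokenizer>/raw plus duration.json as loaded above.

from torch.utils.data import DataLoader, SequentialSampler
from model.dataset import load_dataset, DynamicBatchSampler, collate_fn

train_dataset = load_dataset("my_dataset", tokenizer="pinyin")  # placeholder dataset name

sampler = SequentialSampler(train_dataset)
batch_sampler = DynamicBatchSampler(sampler, frames_threshold=38400, max_samples=64,
                                    random_seed=666, drop_last=False)
loader = DataLoader(train_dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=4)

batch = next(iter(loader))
print(batch["mel"].shape, batch["mel_lengths"], len(batch["text"]))
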
model/ecapa_tdnn.py ADDED
@@ -0,0 +1,268 @@
1
+ # just for speaker similarity evaluation, third-party code
2
+
3
+ # From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
4
+ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
5
+
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ ''' Res2Conv1d + BatchNorm1d + ReLU
13
+ '''
14
+
15
+ class Res2Conv1dReluBn(nn.Module):
16
+ '''
17
+ in_channels == out_channels == channels
18
+ '''
19
+
20
+ def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
21
+ super().__init__()
22
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
23
+ self.scale = scale
24
+ self.width = channels // scale
25
+ self.nums = scale if scale == 1 else scale - 1
26
+
27
+ self.convs = []
28
+ self.bns = []
29
+ for i in range(self.nums):
30
+ self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
31
+ self.bns.append(nn.BatchNorm1d(self.width))
32
+ self.convs = nn.ModuleList(self.convs)
33
+ self.bns = nn.ModuleList(self.bns)
34
+
35
+ def forward(self, x):
36
+ out = []
37
+ spx = torch.split(x, self.width, 1)
38
+ for i in range(self.nums):
39
+ if i == 0:
40
+ sp = spx[i]
41
+ else:
42
+ sp = sp + spx[i]
43
+ # Order: conv -> relu -> bn
44
+ sp = self.convs[i](sp)
45
+ sp = self.bns[i](F.relu(sp))
46
+ out.append(sp)
47
+ if self.scale != 1:
48
+ out.append(spx[self.nums])
49
+ out = torch.cat(out, dim=1)
50
+
51
+ return out
52
+
53
+
54
+ ''' Conv1d + BatchNorm1d + ReLU
55
+ '''
56
+
57
+ class Conv1dReluBn(nn.Module):
58
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
59
+ super().__init__()
60
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
61
+ self.bn = nn.BatchNorm1d(out_channels)
62
+
63
+ def forward(self, x):
64
+ return self.bn(F.relu(self.conv(x)))
65
+
66
+
67
+ ''' The SE connection of 1D case.
68
+ '''
69
+
70
+ class SE_Connect(nn.Module):
71
+ def __init__(self, channels, se_bottleneck_dim=128):
72
+ super().__init__()
73
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
74
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
75
+
76
+ def forward(self, x):
77
+ out = x.mean(dim=2)
78
+ out = F.relu(self.linear1(out))
79
+ out = torch.sigmoid(self.linear2(out))
80
+ out = x * out.unsqueeze(2)
81
+
82
+ return out
83
+
84
+
85
+ ''' SE-Res2Block of the ECAPA-TDNN architecture.
86
+ '''
87
+
88
+ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
89
+ # return nn.Sequential(
90
+ # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
91
+ # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
92
+ # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
93
+ # SE_Connect(channels)
94
+ # )
95
+
96
+ class SE_Res2Block(nn.Module):
97
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
98
+ super().__init__()
99
+ self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
100
+ self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
101
+ self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
102
+ self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
103
+
104
+ self.shortcut = None
105
+ if in_channels != out_channels:
106
+ self.shortcut = nn.Conv1d(
107
+ in_channels=in_channels,
108
+ out_channels=out_channels,
109
+ kernel_size=1,
110
+ )
111
+
112
+ def forward(self, x):
113
+ residual = x
114
+ if self.shortcut:
115
+ residual = self.shortcut(x)
116
+
117
+ x = self.Conv1dReluBn1(x)
118
+ x = self.Res2Conv1dReluBn(x)
119
+ x = self.Conv1dReluBn2(x)
120
+ x = self.SE_Connect(x)
121
+
122
+ return x + residual
123
+
124
+
125
+ ''' Attentive weighted mean and standard deviation pooling.
126
+ '''
127
+
128
+ class AttentiveStatsPool(nn.Module):
129
+ def __init__(self, in_dim, attention_channels=128, global_context_att=False):
130
+ super().__init__()
131
+ self.global_context_att = global_context_att
132
+
133
+ # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
134
+ if global_context_att:
135
+ self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
136
+ else:
137
+ self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
138
+ self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
139
+
140
+ def forward(self, x):
141
+
142
+ if self.global_context_att:
143
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
144
+ context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
145
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
146
+ else:
147
+ x_in = x
148
+
149
+ # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
150
+ alpha = torch.tanh(self.linear1(x_in))
151
+ # alpha = F.relu(self.linear1(x_in))
152
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
153
+ mean = torch.sum(alpha * x, dim=2)
154
+ residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
155
+ std = torch.sqrt(residuals.clamp(min=1e-9))
156
+ return torch.cat([mean, std], dim=1)
157
+
158
+
159
+ class ECAPA_TDNN(nn.Module):
160
+ def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
161
+ feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
162
+ super().__init__()
163
+
164
+ self.feat_type = feat_type
165
+ self.feature_selection = feature_selection
166
+ self.update_extract = update_extract
167
+ self.sr = sr
168
+
169
+ torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
170
+ try:
171
+ local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
172
+ self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
173
+ except:
174
+ self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
175
+
176
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
177
+ self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
178
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
179
+ self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
180
+
181
+ self.feat_num = self.get_feat_num()
182
+ self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
183
+
184
+ if feat_type != 'fbank' and feat_type != 'mfcc':
185
+ freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
186
+ for name, param in self.feature_extract.named_parameters():
187
+ for freeze_val in freeze_list:
188
+ if freeze_val in name:
189
+ param.requires_grad = False
190
+ break
191
+
192
+ if not self.update_extract:
193
+ for param in self.feature_extract.parameters():
194
+ param.requires_grad = False
195
+
196
+ self.instance_norm = nn.InstanceNorm1d(feat_dim)
197
+ # self.channels = [channels] * 4 + [channels * 3]
198
+ self.channels = [channels] * 4 + [1536]
199
+
200
+ self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
201
+ self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
202
+ self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
203
+ self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
204
+
205
+ # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
206
+ cat_channels = channels * 3
207
+ self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
208
+ self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
209
+ self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
210
+ self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
211
+
212
+
213
+ def get_feat_num(self):
214
+ self.feature_extract.eval()
215
+ wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
216
+ with torch.no_grad():
217
+ features = self.feature_extract(wav)
218
+ select_feature = features[self.feature_selection]
219
+ if isinstance(select_feature, (list, tuple)):
220
+ return len(select_feature)
221
+ else:
222
+ return 1
223
+
224
+ def get_feat(self, x):
225
+ if self.update_extract:
226
+ x = self.feature_extract([sample for sample in x])
227
+ else:
228
+ with torch.no_grad():
229
+ if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
230
+ x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
231
+ else:
232
+ x = self.feature_extract([sample for sample in x])
233
+
234
+ if self.feat_type == 'fbank':
235
+ x = x.log()
236
+
237
+ if self.feat_type != "fbank" and self.feat_type != "mfcc":
238
+ x = x[self.feature_selection]
239
+ if isinstance(x, (list, tuple)):
240
+ x = torch.stack(x, dim=0)
241
+ else:
242
+ x = x.unsqueeze(0)
243
+ norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
244
+ x = (norm_weights * x).sum(dim=0)
245
+ x = torch.transpose(x, 1, 2) + 1e-6
246
+
247
+ x = self.instance_norm(x)
248
+ return x
249
+
250
+ def forward(self, x):
251
+ x = self.get_feat(x)
252
+
253
+ out1 = self.layer1(x)
254
+ out2 = self.layer2(out1)
255
+ out3 = self.layer3(out2)
256
+ out4 = self.layer4(out3)
257
+
258
+ out = torch.cat([out2, out3, out4], dim=1)
259
+ out = F.relu(self.conv(out))
260
+ out = self.bn(self.pooling(out))
261
+ out = self.linear(out)
262
+
263
+ return out
264
+
265
+
266
+ def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
267
+ return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
268
+ feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
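
A sketch of using this third-party evaluator for speaker-similarity scoring (illustration only, not part of the uploaded file). It assumes the WavLM-Large extractor is reachable through torch.hub / s3prl and that feat_dim=1024 matches that extractor; a pretrained verification checkpoint would normally be loaded before scoring real audio.

import torch
import torch.nn.functional as F
from model.ecapa_tdnn import ECAPA_TDNN_SMALL

model = ECAPA_TDNN_SMALL(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')
model.eval()

wav_a = torch.randn(1, 16000)  # 1 s of 16 kHz audio, the rate this evaluator expects
wav_b = torch.randn(1, 16000)
with torch.no_grad():
    emb_a, emb_b = model(wav_a), model(wav_b)
sim = F.cosine_similarity(emb_a, emb_b)  # higher means more similar speakers
print(sim.item())
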
model/modules.py ADDED
@@ -0,0 +1,575 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Optional
12
+ import math
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+ import torchaudio
18
+
19
+ from einops import rearrange
20
+ from x_transformers.x_transformers import apply_rotary_pos_emb
21
+
22
+
23
+ # raw wav to mel spec
24
+
25
+ class MelSpec(nn.Module):
26
+ def __init__(
27
+ self,
28
+ filter_length = 1024,
29
+ hop_length = 256,
30
+ win_length = 1024,
31
+ n_mel_channels = 100,
32
+ target_sample_rate = 24_000,
33
+ normalize = False,
34
+ power = 1,
35
+ norm = None,
36
+ center = True,
37
+ ):
38
+ super().__init__()
39
+ self.n_mel_channels = n_mel_channels
40
+
41
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
42
+ sample_rate = target_sample_rate,
43
+ n_fft = filter_length,
44
+ win_length = win_length,
45
+ hop_length = hop_length,
46
+ n_mels = n_mel_channels,
47
+ power = power,
48
+ center = center,
49
+ normalized = normalize,
50
+ norm = norm,
51
+ )
52
+
53
+ self.register_buffer('dummy', torch.tensor(0), persistent = False)
54
+
55
+ def forward(self, inp):
56
+ if len(inp.shape) == 3:
57
+ inp = rearrange(inp, 'b 1 nw -> b nw')
58
+
59
+ assert len(inp.shape) == 2
60
+
61
+ if self.dummy.device != inp.device:
62
+ self.to(inp.device)
63
+
64
+ mel = self.mel_stft(inp)
65
+ mel = mel.clamp(min = 1e-5).log()
66
+ return mel
67
+
68
+
69
+ # sinusoidal position embedding
70
+
71
+ class SinusPositionEmbedding(nn.Module):
72
+ def __init__(self, dim):
73
+ super().__init__()
74
+ self.dim = dim
75
+
76
+ def forward(self, x, scale=1000):
77
+ device = x.device
78
+ half_dim = self.dim // 2
79
+ emb = math.log(10000) / (half_dim - 1)
80
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
81
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
82
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
83
+ return emb
84
+
85
+
86
+ # convolutional position embedding
87
+
88
+ class ConvPositionEmbedding(nn.Module):
89
+ def __init__(self, dim, kernel_size = 31, groups = 16):
90
+ super().__init__()
91
+ assert kernel_size % 2 != 0
92
+ self.conv1d = nn.Sequential(
93
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
94
+ nn.Mish(),
95
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
96
+ nn.Mish(),
97
+ )
98
+
99
+ def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
100
+ if mask is not None:
101
+ mask = mask[..., None]
102
+ x = x.masked_fill(~mask, 0.)
103
+
104
+ x = rearrange(x, 'b n d -> b d n')
105
+ x = self.conv1d(x)
106
+ out = rearrange(x, 'b d n -> b n d')
107
+
108
+ if mask is not None:
109
+ out = out.masked_fill(~mask, 0.)
110
+
111
+ return out
112
+
113
+
114
+ # rotary positional embedding related
115
+
116
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
117
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
118
+ # has some connection to NTK literature
119
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
120
+ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
121
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
122
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
123
+ t = torch.arange(end, device=freqs.device) # type: ignore
124
+ freqs = torch.outer(t, freqs).float() # type: ignore
125
+ freqs_cos = torch.cos(freqs) # real part
126
+ freqs_sin = torch.sin(freqs) # imaginary part
127
+ return torch.cat([freqs_cos, freqs_sin], dim=-1)
128
+
129
+ def get_pos_embed_indices(start, length, max_pos, scale=1.):
130
+ # length = length if isinstance(length, int) else length.max()
131
+ scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
132
+ pos = start.unsqueeze(1) + (
133
+ torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
134
+ scale.unsqueeze(1)).long()
135
+ # avoid extra long error.
136
+ pos = torch.where(pos < max_pos, pos, max_pos - 1)
137
+ return pos
138
+
139
+
140
+ # Global Response Normalization layer (Instance Normalization ?)
141
+
142
+ class GRN(nn.Module):
143
+ def __init__(self, dim):
144
+ super().__init__()
145
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
146
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
147
+
148
+ def forward(self, x):
149
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
150
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
151
+ return self.gamma * (x * Nx) + self.beta + x
152
+
153
+
154
+ # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
155
+ # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
156
+
157
+ class ConvNeXtV2Block(nn.Module):
158
+ def __init__(
159
+ self,
160
+ dim: int,
161
+ intermediate_dim: int,
162
+ dilation: int = 1,
163
+ ):
164
+ super().__init__()
165
+ padding = (dilation * (7 - 1)) // 2
166
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv
167
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
168
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
169
+ self.act = nn.GELU()
170
+ self.grn = GRN(intermediate_dim)
171
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
172
+
173
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
174
+ residual = x
175
+ x = x.transpose(1, 2) # b n d -> b d n
176
+ x = self.dwconv(x)
177
+ x = x.transpose(1, 2) # b d n -> b n d
178
+ x = self.norm(x)
179
+ x = self.pwconv1(x)
180
+ x = self.act(x)
181
+ x = self.grn(x)
182
+ x = self.pwconv2(x)
183
+ return residual + x
184
+
185
+
186
+ # AdaLayerNormZero
187
+ # return with modulated x for attn input, and params for later mlp modulation
188
+
189
+ class AdaLayerNormZero(nn.Module):
190
+ def __init__(self, dim):
191
+ super().__init__()
192
+
193
+ self.silu = nn.SiLU()
194
+ self.linear = nn.Linear(dim, dim * 6)
195
+
196
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
197
+
198
+ def forward(self, x, emb = None):
199
+ emb = self.linear(self.silu(emb))
200
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
201
+
202
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
203
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
204
+
205
+
206
+ # AdaLayerNormZero for final layer
207
+ # return only the modulated x for attn input, since there is no further mlp modulation
208
+
209
+ class AdaLayerNormZero_Final(nn.Module):
210
+ def __init__(self, dim):
211
+ super().__init__()
212
+
213
+ self.silu = nn.SiLU()
214
+ self.linear = nn.Linear(dim, dim * 2)
215
+
216
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
217
+
218
+ def forward(self, x, emb):
219
+ emb = self.linear(self.silu(emb))
220
+ scale, shift = torch.chunk(emb, 2, dim=1)
221
+
222
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
223
+ return x
224
+
225
+
226
+ # FeedForward
227
+
228
+ class FeedForward(nn.Module):
229
+ def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
230
+ super().__init__()
231
+ inner_dim = int(dim * mult)
232
+ dim_out = dim_out if dim_out is not None else dim
233
+
234
+ activation = nn.GELU(approximate=approximate)
235
+ project_in = nn.Sequential(
236
+ nn.Linear(dim, inner_dim),
237
+ activation
238
+ )
239
+ self.ff = nn.Sequential(
240
+ project_in,
241
+ nn.Dropout(dropout),
242
+ nn.Linear(inner_dim, dim_out)
243
+ )
244
+
245
+ def forward(self, x):
246
+ return self.ff(x)
247
+
248
+
249
+ # Attention with possible joint part
250
+ # modified from diffusers/src/diffusers/models/attention_processor.py
251
+
252
+ class Attention(nn.Module):
253
+ def __init__(
254
+ self,
255
+ processor: JointAttnProcessor | AttnProcessor,
256
+ dim: int,
257
+ heads: int = 8,
258
+ dim_head: int = 64,
259
+ dropout: float = 0.0,
260
+ context_dim: Optional[int] = None, # if not None -> joint attention
261
+ context_pre_only = None,
262
+ ):
263
+ super().__init__()
264
+
265
+ if not hasattr(F, "scaled_dot_product_attention"):
266
+ raise ImportError("Attention requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0.")
267
+
268
+ self.processor = processor
269
+
270
+ self.dim = dim
271
+ self.heads = heads
272
+ self.inner_dim = dim_head * heads
273
+ self.dropout = dropout
274
+
275
+ self.context_dim = context_dim
276
+ self.context_pre_only = context_pre_only
277
+
278
+ self.to_q = nn.Linear(dim, self.inner_dim)
279
+ self.to_k = nn.Linear(dim, self.inner_dim)
280
+ self.to_v = nn.Linear(dim, self.inner_dim)
281
+
282
+ if self.context_dim is not None:
283
+ self.to_k_c = nn.Linear(context_dim, self.inner_dim)
284
+ self.to_v_c = nn.Linear(context_dim, self.inner_dim)
285
+ if self.context_pre_only is not None:
286
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
287
+
288
+ self.to_out = nn.ModuleList([])
289
+ self.to_out.append(nn.Linear(self.inner_dim, dim))
290
+ self.to_out.append(nn.Dropout(dropout))
291
+
292
+ if self.context_pre_only is not None and not self.context_pre_only:
293
+ self.to_out_c = nn.Linear(self.inner_dim, dim)
294
+
295
+ def forward(
296
+ self,
297
+ x: float['b n d'], # noised input x
298
+ c: float['b n d'] = None, # context c
299
+ mask: bool['b n'] | None = None,
300
+ rope = None, # rotary position embedding for x
301
+ c_rope = None, # rotary position embedding for c
302
+ ) -> torch.Tensor:
303
+ if c is not None:
304
+ return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
305
+ else:
306
+ return self.processor(self, x, mask = mask, rope = rope)
307
+
308
+
309
+ # Attention processor
310
+
311
+ class AttnProcessor:
312
+ def __init__(self):
313
+ pass
314
+
315
+ def __call__(
316
+ self,
317
+ attn: Attention,
318
+ x: float['b n d'], # noised input x
319
+ mask: bool['b n'] | None = None,
320
+ rope = None, # rotary position embedding
321
+ ) -> torch.FloatTensor:
322
+
323
+ batch_size = x.shape[0]
324
+
325
+ # `sample` projections.
326
+ query = attn.to_q(x)
327
+ key = attn.to_k(x)
328
+ value = attn.to_v(x)
329
+
330
+ # apply rotary position embedding
331
+ if rope is not None:
332
+ freqs, xpos_scale = rope
333
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
334
+
335
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
336
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
337
+
338
+ # attention
339
+ inner_dim = key.shape[-1]
340
+ head_dim = inner_dim // attn.heads
341
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
342
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
343
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
344
+
345
+ # mask. e.g. at inference a batch may have different target durations, so mask out the padding
346
+ if mask is not None:
347
+ attn_mask = mask
348
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
349
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
350
+ else:
351
+ attn_mask = None
352
+
353
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
354
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
355
+ x = x.to(query.dtype)
356
+
357
+ # linear proj
358
+ x = attn.to_out[0](x)
359
+ # dropout
360
+ x = attn.to_out[1](x)
361
+
362
+ if mask is not None:
363
+ mask = rearrange(mask, 'b n -> b n 1')
364
+ x = x.masked_fill(~mask, 0.)
365
+
366
+ return x
367
+
368
+
369
+ # Joint Attention processor for MM-DiT
370
+ # modified from diffusers/src/diffusers/models/attention_processor.py
371
+
372
+ class JointAttnProcessor:
373
+ def __init__(self):
374
+ pass
375
+
376
+ def __call__(
377
+ self,
378
+ attn: Attention,
379
+ x: float['b n d'], # noised input x
380
+ c: float['b nt d'] = None, # context c, here text
381
+ mask: bool['b n'] | None = None,
382
+ rope = None, # rotary position embedding for x
383
+ c_rope = None, # rotary position embedding for c
384
+ ) -> torch.FloatTensor:
385
+ residual = x
386
+
387
+ batch_size = c.shape[0]
388
+
389
+ # `sample` projections.
390
+ query = attn.to_q(x)
391
+ key = attn.to_k(x)
392
+ value = attn.to_v(x)
393
+
394
+ # `context` projections.
395
+ c_query = attn.to_q_c(c)
396
+ c_key = attn.to_k_c(c)
397
+ c_value = attn.to_v_c(c)
398
+
399
+ # apply rope for context and noised input independently
400
+ if rope is not None:
401
+ freqs, xpos_scale = rope
402
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
403
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
404
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
405
+ if c_rope is not None:
406
+ freqs, xpos_scale = c_rope
407
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
408
+ c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
409
+ c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
410
+
411
+ # attention
412
+ query = torch.cat([query, c_query], dim=1)
413
+ key = torch.cat([key, c_key], dim=1)
414
+ value = torch.cat([value, c_value], dim=1)
415
+
416
+ inner_dim = key.shape[-1]
417
+ head_dim = inner_dim // attn.heads
418
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
419
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
420
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
421
+
422
+ # mask. e.g. at inference a batch may have different target durations, so mask out the padding
423
+ if mask is not None:
424
+ attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text)
425
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
426
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
427
+ else:
428
+ attn_mask = None
429
+
430
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
431
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
432
+ x = x.to(query.dtype)
433
+
434
+ # Split the attention outputs.
435
+ x, c = (
436
+ x[:, :residual.shape[1]],
437
+ x[:, residual.shape[1]:],
438
+ )
439
+
440
+ # linear proj
441
+ x = attn.to_out[0](x)
442
+ # dropout
443
+ x = attn.to_out[1](x)
444
+ if not attn.context_pre_only:
445
+ c = attn.to_out_c(c)
446
+
447
+ if mask is not None:
448
+ mask = rearrange(mask, 'b n -> b n 1')
449
+ x = x.masked_fill(~mask, 0.)
450
+ # c = c.masked_fill(~mask, 0.) # no mask for c (text)
451
+
452
+ return x, c
453
+
454
+
455
+ # DiT Block
456
+
457
+ class DiTBlock(nn.Module):
458
+
459
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
460
+ super().__init__()
461
+
462
+ self.attn_norm = AdaLayerNormZero(dim)
463
+ self.attn = Attention(
464
+ processor = AttnProcessor(),
465
+ dim = dim,
466
+ heads = heads,
467
+ dim_head = dim_head,
468
+ dropout = dropout,
469
+ )
470
+
471
+ self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
472
+ self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
473
+
474
+ def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
475
+ # pre-norm & modulation for attention input
476
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
477
+
478
+ # attention
479
+ attn_output = self.attn(x=norm, mask=mask, rope=rope)
480
+
481
+ # process attention output for input x
482
+ x = x + gate_msa.unsqueeze(1) * attn_output
483
+
484
+ norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
485
+ ff_output = self.ff(norm)
486
+ x = x + gate_mlp.unsqueeze(1) * ff_output
487
+
488
+ return x
489
+
490
+
491
+ # MMDiT Block https://arxiv.org/abs/2403.03206
492
+
493
+ class MMDiTBlock(nn.Module):
494
+ r"""
495
+ modified from diffusers/src/diffusers/models/attention.py
496
+
497
+ notes.
498
+ _c: context related. text, cond, etc. (left part in sd3 fig2.b)
499
+ _x: noised input related. (right part)
500
+ context_pre_only: the last layer only does pre-norm + modulation, since there is no further ffn
501
+ """
502
+
503
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
504
+ super().__init__()
505
+
506
+ self.context_pre_only = context_pre_only
507
+
508
+ self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
509
+ self.attn_norm_x = AdaLayerNormZero(dim)
510
+ self.attn = Attention(
511
+ processor = JointAttnProcessor(),
512
+ dim = dim,
513
+ heads = heads,
514
+ dim_head = dim_head,
515
+ dropout = dropout,
516
+ context_dim = dim,
517
+ context_pre_only = context_pre_only,
518
+ )
519
+
520
+ if not context_pre_only:
521
+ self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
522
+ self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
523
+ else:
524
+ self.ff_norm_c = None
525
+ self.ff_c = None
526
+ self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
527
+ self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
528
+
529
+ def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
530
+ # pre-norm & modulation for attention input
531
+ if self.context_pre_only:
532
+ norm_c = self.attn_norm_c(c, t)
533
+ else:
534
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
535
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
536
+
537
+ # attention
538
+ x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
539
+
540
+ # process attention output for context c
541
+ if self.context_pre_only:
542
+ c = None
543
+ else: # if not last layer
544
+ c = c + c_gate_msa.unsqueeze(1) * c_attn_output
545
+
546
+ norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
547
+ c_ff_output = self.ff_c(norm_c)
548
+ c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
549
+
550
+ # process attention output for input x
551
+ x = x + x_gate_msa.unsqueeze(1) * x_attn_output
552
+
553
+ norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
554
+ x_ff_output = self.ff_x(norm_x)
555
+ x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
556
+
557
+ return c, x
558
+
559
+
560
+ # time step conditioning embedding
561
+
562
+ class TimestepEmbedding(nn.Module):
563
+ def __init__(self, dim, freq_embed_dim=256):
564
+ super().__init__()
565
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
566
+ self.time_mlp = nn.Sequential(
567
+ nn.Linear(freq_embed_dim, dim),
568
+ nn.SiLU(),
569
+ nn.Linear(dim, dim)
570
+ )
571
+
572
+ def forward(self, timestep: float['b']):
573
+ time_hidden = self.time_embed(timestep)
574
+ time = self.time_mlp(time_hidden) # b d
575
+ return time
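
A small sketch of the tensor conventions the modules above assume (illustration only, not part of the uploaded file): MelSpec maps raw 24 kHz audio of shape b nw to log-mel features of shape b d n with d = 100 channels and one frame per hop_length = 256 samples, and TimestepEmbedding turns per-sample flow time steps into conditioning vectors.

import torch
from model.modules import MelSpec, TimestepEmbedding

mel_spec = MelSpec()            # defaults: 24 kHz, n_fft 1024, hop 256, 100 mel channels
wave = torch.randn(2, 24_000)   # one second per sample, b nw
mel = mel_spec(wave)
print(mel.shape)                # (2, 100, 94): about 24_000 / 256 frames, +1 from center padding

t_embed = TimestepEmbedding(dim=512)
t = t_embed(torch.rand(2))      # b -> b d
print(t.shape)                  # (2, 512)
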
model/trainer.py ADDED
@@ -0,0 +1,250 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import gc
5
+ from tqdm import tqdm
6
+ import wandb
7
+
8
+ import torch
9
+ from torch.optim import AdamW
10
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler
11
+ from torch.optim.lr_scheduler import LinearLR, SequentialLR
12
+
13
+ from einops import rearrange
14
+
15
+ from accelerate import Accelerator
16
+ from accelerate.utils import DistributedDataParallelKwargs
17
+
18
+ from ema_pytorch import EMA
19
+
20
+ from model import CFM
21
+ from model.utils import exists, default
22
+ from model.dataset import DynamicBatchSampler, collate_fn
23
+
24
+
25
+ # trainer
26
+
27
+ class Trainer:
28
+ def __init__(
29
+ self,
30
+ model: CFM,
31
+ epochs,
32
+ learning_rate,
33
+ num_warmup_updates = 20000,
34
+ save_per_updates = 1000,
35
+ checkpoint_path = None,
36
+ batch_size = 32,
37
+ batch_size_type: str = "sample",
38
+ max_samples = 32,
39
+ grad_accumulation_steps = 1,
40
+ max_grad_norm = 1.0,
41
+ noise_scheduler: str | None = None,
42
+ duration_predictor: torch.nn.Module | None = None,
43
+ wandb_project = "test_e2-tts",
44
+ wandb_run_name = "test_run",
45
+ wandb_resume_id: str = None,
46
+ last_per_steps = None,
47
+ accelerate_kwargs: dict = dict(),
48
+ ema_kwargs: dict = dict()
49
+ ):
50
+
51
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
52
+
53
+ self.accelerator = Accelerator(
54
+ log_with = "wandb",
55
+ kwargs_handlers = [ddp_kwargs],
56
+ gradient_accumulation_steps = grad_accumulation_steps,
57
+ **accelerate_kwargs
58
+ )
59
+
60
+ if exists(wandb_resume_id):
61
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
62
+ else:
63
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
64
+ self.accelerator.init_trackers(
65
+ project_name = wandb_project,
66
+ init_kwargs=init_kwargs,
67
+ config={"epochs": epochs,
68
+ "learning_rate": learning_rate,
69
+ "num_warmup_updates": num_warmup_updates,
70
+ "batch_size": batch_size,
71
+ "batch_size_type": batch_size_type,
72
+ "max_samples": max_samples,
73
+ "grad_accumulation_steps": grad_accumulation_steps,
74
+ "max_grad_norm": max_grad_norm,
75
+ "gpus": self.accelerator.num_processes,
76
+ "noise_scheduler": noise_scheduler}
77
+ )
78
+
79
+ self.model = model
80
+
81
+ if self.is_main:
82
+ self.ema_model = EMA(
83
+ model,
84
+ include_online_model = False,
85
+ **ema_kwargs
86
+ )
87
+
88
+ self.ema_model.to(self.accelerator.device)
89
+
90
+ self.epochs = epochs
91
+ self.num_warmup_updates = num_warmup_updates
92
+ self.save_per_updates = save_per_updates
93
+ self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
94
+ self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
95
+
96
+ self.batch_size = batch_size
97
+ self.batch_size_type = batch_size_type
98
+ self.max_samples = max_samples
99
+ self.grad_accumulation_steps = grad_accumulation_steps
100
+ self.max_grad_norm = max_grad_norm
101
+
102
+ self.noise_scheduler = noise_scheduler
103
+
104
+ self.duration_predictor = duration_predictor
105
+
106
+ self.optimizer = AdamW(model.parameters(), lr=learning_rate)
107
+ self.model, self.optimizer = self.accelerator.prepare(
108
+ self.model, self.optimizer
109
+ )
110
+
111
+ @property
112
+ def is_main(self):
113
+ return self.accelerator.is_main_process
114
+
115
+ def save_checkpoint(self, step, last=False):
116
+ self.accelerator.wait_for_everyone()
117
+ if self.is_main:
118
+ checkpoint = dict(
119
+ model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
120
+ optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
121
+ ema_model_state_dict = self.ema_model.state_dict(),
122
+ scheduler_state_dict = self.scheduler.state_dict(),
123
+ step = step
124
+ )
125
+ if not os.path.exists(self.checkpoint_path):
126
+ os.makedirs(self.checkpoint_path)
127
+ if last:
128
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
129
+ print(f"Saved last checkpoint at step {step}")
130
+ else:
131
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
132
+
133
+ def load_checkpoint(self):
134
+ if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
135
+ return 0
136
+
137
+ self.accelerator.wait_for_everyone()
138
+ if "model_last.pt" in os.listdir(self.checkpoint_path):
139
+ latest_checkpoint = "model_last.pt"
140
+ else:
141
+ latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
142
+ # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
+ checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
144
+
145
+ if self.is_main:
146
+ self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
147
+
148
+ if 'step' in checkpoint:
149
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
150
+ self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
151
+ if self.scheduler:
152
+ self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
153
+ step = checkpoint['step']
154
+ else:
155
+ checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
156
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
157
+ step = 0
158
+
159
+ del checkpoint; gc.collect()
160
+ return step
161
+
162
+ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
163
+
164
+ if exists(resumable_with_seed):
165
+ generator = torch.Generator()
166
+ generator.manual_seed(resumable_with_seed)
167
+ else:
168
+ generator = None
169
+
170
+ if self.batch_size_type == "sample":
171
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
172
+ batch_size=self.batch_size, shuffle=True, generator=generator)
173
+ elif self.batch_size_type == "frame":
174
+ self.accelerator.even_batches = False
175
+ sampler = SequentialSampler(train_dataset)
176
+ batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
177
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
178
+ batch_sampler=batch_sampler)
179
+ else:
180
+ raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
181
+
182
+ # accelerator.prepare() dispatches batches across devices,
183
+ # so the dataloader length computed beforehand must account for the number of devices
184
+ warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # keep a fixed number of warmup updates under accelerate multi-gpu ddp
185
+ # otherwise, with the default split_batches=False, the effective warmup steps would change with num_processes
186
+ total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
187
+ decay_steps = total_steps - warmup_steps
188
+ warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
189
+ decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
190
+ self.scheduler = SequentialLR(self.optimizer,
191
+ schedulers=[warmup_scheduler, decay_scheduler],
192
+ milestones=[warmup_steps])
193
+ train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus
194
+ start_step = self.load_checkpoint()
195
+ global_step = start_step
196
+
197
+ if exists(resumable_with_seed):
198
+ orig_epoch_step = len(train_dataloader)
199
+ skipped_epoch = int(start_step // orig_epoch_step)
200
+ skipped_batch = start_step % orig_epoch_step
201
+ skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
202
+ else:
203
+ skipped_epoch = 0
204
+
205
+ for epoch in range(skipped_epoch, self.epochs):
206
+ self.model.train()
207
+ if exists(resumable_with_seed) and epoch == skipped_epoch:
208
+ progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
209
+ initial=skipped_batch, total=orig_epoch_step)
210
+ else:
211
+ progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
212
+
213
+ for batch in progress_bar:
214
+ with self.accelerator.accumulate(self.model):
215
+ text_inputs = batch['text']
216
+ mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
217
+ mel_lengths = batch["mel_lengths"]
218
+
219
+ # TODO. add duration predictor training
220
+ if self.duration_predictor is not None and self.accelerator.is_local_main_process:
221
+ dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
222
+ self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
223
+
224
+ loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
225
+ self.accelerator.backward(loss)
226
+
227
+ if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
228
+ self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
229
+
230
+ self.optimizer.step()
231
+ self.scheduler.step()
232
+ self.optimizer.zero_grad()
233
+
234
+ if self.is_main:
235
+ self.ema_model.update()
236
+
237
+ global_step += 1
238
+
239
+ if self.accelerator.is_local_main_process:
240
+ self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
241
+
242
+ progress_bar.set_postfix(step=str(global_step), loss=loss.item())
243
+
244
+ if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
245
+ self.save_checkpoint(global_step)
246
+
247
+ if global_step % self.last_per_steps == 0:
248
+ self.save_checkpoint(global_step, last=True)
249
+
250
+ self.accelerator.end_training()
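
Note: the sketch below shows one way this Trainer could be wired up end to end. It is not part of the commit; the CFM/DiT construction mirrors scripts/eval_infer_batch.py further down, the hyperparameter values are illustrative only, and train_dataset is a placeholder for a dataset (built elsewhere, e.g. via the prepare_* scripts and model/dataset.py) that yields the 'text' / 'mel' / 'mel_lengths' batches consumed in train().

    # Minimal usage sketch for model/trainer.py (illustrative values, not the project's defaults)
    from model import CFM, DiT
    from model.trainer import Trainer
    from model.utils import get_tokenizer

    vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")  # expects data/Emilia_ZH_EN_pinyin/vocab.txt

    model = CFM(
        transformer=DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4,
                        text_num_embeds=vocab_size, mel_dim=100),
        mel_spec_kwargs=dict(target_sample_rate=24000, n_mel_channels=100, hop_length=256),
        vocab_char_map=vocab_char_map,
    )

    trainer = Trainer(
        model,
        epochs=10,                      # illustrative
        learning_rate=7.5e-5,           # illustrative
        num_warmup_updates=20000,
        save_per_updates=50000,
        checkpoint_path="ckpts/my_run",
        batch_size=38400,               # frames per GPU, as in scripts/count_max_epoch.py
        batch_size_type="frame",
        max_samples=64,
        grad_accumulation_steps=1,
        max_grad_norm=1.0,
    )

    train_dataset = ...  # placeholder: a Dataset compatible with model.dataset.collate_fn / DynamicBatchSampler
    trainer.train(train_dataset, num_workers=16)
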
model/utils.py ADDED
@@ -0,0 +1,580 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import re
5
+ import math
6
+ import random
7
+ import string
8
+ from tqdm import tqdm
9
+ from collections import defaultdict
10
+
11
+ import matplotlib
12
+ matplotlib.use("Agg")
13
+ import matplotlib.pylab as plt
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from torch.nn.utils.rnn import pad_sequence
18
+ import torchaudio
19
+
20
+ import einx
21
+ from einops import rearrange, reduce
22
+
23
+ import jieba
24
+ from pypinyin import lazy_pinyin, Style
25
+
26
+ from model.ecapa_tdnn import ECAPA_TDNN_SMALL
27
+ from model.modules import MelSpec
28
+
29
+
30
+ # seed everything
31
+
32
+ def seed_everything(seed = 0):
33
+ random.seed(seed)
34
+ os.environ['PYTHONHASHSEED'] = str(seed)
35
+ torch.manual_seed(seed)
36
+ torch.cuda.manual_seed(seed)
37
+ torch.cuda.manual_seed_all(seed)
38
+ torch.backends.cudnn.deterministic = True
39
+ torch.backends.cudnn.benchmark = False
40
+
41
+ # helpers
42
+
43
+ def exists(v):
44
+ return v is not None
45
+
46
+ def default(v, d):
47
+ return v if exists(v) else d
48
+
49
+ # tensor helpers
50
+
51
+ def lens_to_mask(
52
+ t: int['b'],
53
+ length: int | None = None
54
+ ) -> bool['b n']:
55
+
56
+ if not exists(length):
57
+ length = t.amax()
58
+
59
+ seq = torch.arange(length, device = t.device)
60
+ return einx.less('n, b -> b n', seq, t)
61
+
62
+ def mask_from_start_end_indices(
63
+ seq_len: int['b'],
64
+ start: int['b'],
65
+ end: int['b']
66
+ ):
67
+ max_seq_len = seq_len.max().item()
68
+ seq = torch.arange(max_seq_len, device = start.device).long()
69
+ return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
70
+
71
+ def mask_from_frac_lengths(
72
+ seq_len: int['b'],
73
+ frac_lengths: float['b']
74
+ ):
75
+ lengths = (frac_lengths * seq_len).long()
76
+ max_start = seq_len - lengths
77
+
78
+ rand = torch.rand_like(frac_lengths)
79
+ start = (max_start * rand).long().clamp(min = 0)
80
+ end = start + lengths
81
+
82
+ return mask_from_start_end_indices(seq_len, start, end)
83
+
84
+ def maybe_masked_mean(
85
+ t: float['b n d'],
86
+ mask: bool['b n'] = None
87
+ ) -> float['b d']:
88
+
89
+ if not exists(mask):
90
+ return t.mean(dim = 1)
91
+
92
+ t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
93
+ num = reduce(t, 'b n d -> b d', 'sum')
94
+ den = reduce(mask.float(), 'b n -> b', 'sum')
95
+
96
+ return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
97
+
98
+
99
+ # simple utf-8 tokenizer, since paper went character based
100
+ def list_str_to_tensor(
101
+ text: list[str],
102
+ padding_value = -1
103
+ ) -> int['b nt']:
104
+ list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
105
+ text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
106
+ return text
107
+
108
+ # char tokenizer, based on custom dataset's extracted .txt file
109
+ def list_str_to_idx(
110
+ text: list[str] | list[list[str]],
111
+ vocab_char_map: dict[str, int], # {char: idx}
112
+ padding_value = -1
113
+ ) -> int['b nt']:
114
+ list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
115
+ text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
116
+ return text
117
+
118
+
119
+ # Get tokenizer
120
+
121
+ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
122
+ '''
123
+ tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
124
+ - "char" for char-wise tokenizer, need .txt vocab_file
125
+ - "byte" for utf-8 tokenizer
126
+ - "custom" if you're directly passing in a path to the vocab.txt you want to use
127
+ vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
128
+ - if use "char", derived from unfiltered character & symbol counts of custom dataset
129
+ - if use "byte", set to 256 (unicode byte range)
130
+ '''
131
+ if tokenizer in ["pinyin", "char"]:
132
+ with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
133
+ vocab_char_map = {}
134
+ for i, char in enumerate(f):
135
+ vocab_char_map[char[:-1]] = i
136
+ vocab_size = len(vocab_char_map)
137
+ assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
138
+
139
+ elif tokenizer == "byte":
140
+ vocab_char_map = None
141
+ vocab_size = 256
142
+ elif tokenizer == "custom":
143
+ with open(dataset_name, "r", encoding="utf-8") as f:
144
+ vocab_char_map = {}
145
+ for i, char in enumerate(f):
146
+ vocab_char_map[char[:-1]] = i
147
+ vocab_size = len(vocab_char_map)
148
+
149
+ return vocab_char_map, vocab_size
150
+
151
+
152
+ # convert char to pinyin
153
+
154
+ def convert_char_to_pinyin(text_list, polyphone = True):
155
+ final_text_list = []
156
+ god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
157
+ custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov
158
+ for text in text_list:
159
+ char_list = []
160
+ text = text.translate(god_knows_why_en_testset_contains_zh_quote)
161
+ text = text.translate(custom_trans)
162
+ for seg in jieba.cut(text):
163
+ seg_byte_len = len(bytes(seg, 'UTF-8'))
164
+ if seg_byte_len == len(seg): # if pure alphabets and symbols
165
+ if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
166
+ char_list.append(" ")
167
+ char_list.extend(seg)
168
+ elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters
169
+ seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
170
+ for c in seg:
171
+ if c not in "。,、;:?!《》【】—…":
172
+ char_list.append(" ")
173
+ char_list.append(c)
174
+ else: # if mixed chinese characters, alphabets and symbols
175
+ for c in seg:
176
+ if ord(c) < 256:
177
+ char_list.extend(c)
178
+ else:
179
+ if c not in "。,、;:?!《》【】—…":
180
+ char_list.append(" ")
181
+ char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
182
+ else: # if is zh punc
183
+ char_list.append(c)
184
+ final_text_list.append(char_list)
185
+
186
+ return final_text_list
187
+
188
+
189
+ # save spectrogram
190
+ def save_spectrogram(spectrogram, path):
191
+ plt.figure(figsize=(12, 4))
192
+ plt.imshow(spectrogram, origin='lower', aspect='auto')
193
+ plt.colorbar()
194
+ plt.savefig(path)
195
+ plt.close()
196
+
197
+
198
+ # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
199
+ def get_seedtts_testset_metainfo(metalst):
200
+ f = open(metalst); lines = f.readlines(); f.close()
201
+ metainfo = []
202
+ for line in lines:
203
+ if len(line.strip().split('|')) == 5:
204
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
205
+ elif len(line.strip().split('|')) == 4:
206
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
207
+ gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
208
+ if not os.path.isabs(prompt_wav):
209
+ prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
210
+ metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
211
+ return metainfo
212
+
213
+
214
+ # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
215
+ def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
216
+ f = open(metalst); lines = f.readlines(); f.close()
217
+ metainfo = []
218
+ for line in lines:
219
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
220
+
221
+ # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
222
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
223
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
224
+
225
+ # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
226
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
227
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
228
+
229
+ metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
230
+
231
+ return metainfo
232
+
233
+
234
+ # padded to max length mel batch
235
+ def padded_mel_batch(ref_mels):
236
+ max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
237
+ padded_ref_mels = []
238
+ for mel in ref_mels:
239
+ padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
240
+ padded_ref_mels.append(padded_ref_mel)
241
+ padded_ref_mels = torch.stack(padded_ref_mels)
242
+ padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
243
+ return padded_ref_mels
244
+
245
+
246
+ # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
247
+
248
+ def get_inference_prompt(
249
+ metainfo,
250
+ speed = 1., tokenizer = "pinyin", polyphone = True,
251
+ target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
252
+ use_truth_duration = False,
253
+ infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
254
+ ):
255
+ prompts_all = []
256
+
257
+ min_tokens = min_secs * target_sample_rate // hop_length
258
+ max_tokens = max_secs * target_sample_rate // hop_length
259
+
260
+ batch_accum = [0] * num_buckets
261
+ utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
262
+ ([[] for _ in range(num_buckets)] for _ in range(6))
263
+
264
+ mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
265
+
266
+ for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
267
+
268
+ # Audio
269
+ ref_audio, ref_sr = torchaudio.load(prompt_wav)
270
+ ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
271
+ if ref_rms < target_rms:
272
+ ref_audio = ref_audio * target_rms / ref_rms
273
+ assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
274
+ if ref_sr != target_sample_rate:
275
+ resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
276
+ ref_audio = resampler(ref_audio)
277
+
278
+ # Text
279
+ if len(prompt_text[-1].encode('utf-8')) == 1:
280
+ prompt_text = prompt_text + " "
281
+ text = [prompt_text + gt_text]
282
+ if tokenizer == "pinyin":
283
+ text_list = convert_char_to_pinyin(text, polyphone = polyphone)
284
+ else:
285
+ text_list = text
286
+
287
+ # Duration, mel frame length
288
+ ref_mel_len = ref_audio.shape[-1] // hop_length
289
+ if use_truth_duration:
290
+ gt_audio, gt_sr = torchaudio.load(gt_wav)
291
+ if gt_sr != target_sample_rate:
292
+ resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
293
+ gt_audio = resampler(gt_audio)
294
+ total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
295
+
296
+ # # test vocoder resynthesis
297
+ # ref_audio = gt_audio
298
+ else:
299
+ zh_pause_punc = r"。,、;:?!"
300
+ ref_text_len = len(prompt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, prompt_text))
301
+ gen_text_len = len(gt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gt_text))
302
+ total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
303
+
304
+ # to mel spectrogram
305
+ ref_mel = mel_spectrogram(ref_audio)
306
+ ref_mel = rearrange(ref_mel, '1 d n -> d n')
307
+
308
+ # deal with batch
309
+ assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
310
+ assert min_tokens <= total_mel_len <= max_tokens, \
311
+ f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
312
+ bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
313
+
314
+ utts[bucket_i].append(utt)
315
+ ref_rms_list[bucket_i].append(ref_rms)
316
+ ref_mels[bucket_i].append(ref_mel)
317
+ ref_mel_lens[bucket_i].append(ref_mel_len)
318
+ total_mel_lens[bucket_i].append(total_mel_len)
319
+ final_text_list[bucket_i].extend(text_list)
320
+
321
+ batch_accum[bucket_i] += total_mel_len
322
+
323
+ if batch_accum[bucket_i] >= infer_batch_size:
324
+ # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
325
+ prompts_all.append((
326
+ utts[bucket_i],
327
+ ref_rms_list[bucket_i],
328
+ padded_mel_batch(ref_mels[bucket_i]),
329
+ ref_mel_lens[bucket_i],
330
+ total_mel_lens[bucket_i],
331
+ final_text_list[bucket_i]
332
+ ))
333
+ batch_accum[bucket_i] = 0
334
+ utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
335
+
336
+ # add residual
337
+ for bucket_i, bucket_frames in enumerate(batch_accum):
338
+ if bucket_frames > 0:
339
+ prompts_all.append((
340
+ utts[bucket_i],
341
+ ref_rms_list[bucket_i],
342
+ padded_mel_batch(ref_mels[bucket_i]),
343
+ ref_mel_lens[bucket_i],
344
+ total_mel_lens[bucket_i],
345
+ final_text_list[bucket_i]
346
+ ))
347
+ # shuffle, so the last processes are not left with only the short (easy) residual buckets
348
+ random.seed(666)
349
+ random.shuffle(prompts_all)
350
+
351
+ return prompts_all
352
+
353
+
354
+ # get wav_res_ref_text of seed-tts test metalst
355
+ # https://github.com/BytedanceSpeech/seed-tts-eval
356
+
357
+ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
358
+ f = open(metalst)
359
+ lines = f.readlines()
360
+ f.close()
361
+
362
+ test_set_ = []
363
+ for line in tqdm(lines):
364
+ if len(line.strip().split('|')) == 5:
365
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
366
+ elif len(line.strip().split('|')) == 4:
367
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
368
+
369
+ if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
370
+ continue
371
+ gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
372
+ if not os.path.isabs(prompt_wav):
373
+ prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
374
+
375
+ test_set_.append((gen_wav, prompt_wav, gt_text))
376
+
377
+ num_jobs = len(gpus)
378
+ if num_jobs == 1:
379
+ return [(gpus[0], test_set_)]
380
+
381
+ wav_per_job = len(test_set_) // num_jobs + 1
382
+ test_set = []
383
+ for i in range(num_jobs):
384
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
385
+
386
+ return test_set
387
+
388
+
389
+ # get librispeech test-clean cross sentence test
390
+
391
+ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
392
+ f = open(metalst)
393
+ lines = f.readlines()
394
+ f.close()
395
+
396
+ test_set_ = []
397
+ for line in tqdm(lines):
398
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
399
+
400
+ if eval_ground_truth:
401
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
402
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
403
+ else:
404
+ if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
405
+ raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
406
+ gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
407
+
408
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
409
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
410
+
411
+ test_set_.append((gen_wav, ref_wav, gen_txt))
412
+
413
+ num_jobs = len(gpus)
414
+ if num_jobs == 1:
415
+ return [(gpus[0], test_set_)]
416
+
417
+ wav_per_job = len(test_set_) // num_jobs + 1
418
+ test_set = []
419
+ for i in range(num_jobs):
420
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
421
+
422
+ return test_set
423
+
424
+
425
+ # load asr model
426
+
427
+ def load_asr_model(lang, ckpt_dir = ""):
428
+ if lang == "zh":
429
+ from funasr import AutoModel
430
+ model = AutoModel(
431
+ model = os.path.join(ckpt_dir, "paraformer-zh"),
432
+ # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
433
+ # punc_model = os.path.join(ckpt_dir, "ct-punc"),
434
+ # spk_model = os.path.join(ckpt_dir, "cam++"),
435
+ disable_update=True,
436
+ ) # following seed-tts setting
437
+ elif lang == "en":
438
+ from faster_whisper import WhisperModel
439
+ model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
440
+ model = WhisperModel(model_size, device="cuda", compute_type="float16")
441
+ return model
442
+
443
+
444
+ # WER Evaluation, the way Seed-TTS does
445
+
446
+ def run_asr_wer(args):
447
+ rank, lang, test_set, ckpt_dir = args
448
+
449
+ if lang == "zh":
450
+ import zhconv
451
+ torch.cuda.set_device(rank)
452
+ elif lang == "en":
453
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
454
+ else:
455
+ raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
456
+
457
+ asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
458
+
459
+ from zhon.hanzi import punctuation
460
+ punctuation_all = punctuation + string.punctuation
461
+ wers = []
462
+
463
+ from jiwer import compute_measures
464
+ for gen_wav, prompt_wav, truth in tqdm(test_set):
465
+ if lang == "zh":
466
+ res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
467
+ hypo = res[0]["text"]
468
+ hypo = zhconv.convert(hypo, 'zh-cn')
469
+ elif lang == "en":
470
+ segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
471
+ hypo = ''
472
+ for segment in segments:
473
+ hypo = hypo + ' ' + segment.text
474
+
475
+ # raw_truth = truth
476
+ # raw_hypo = hypo
477
+
478
+ for x in punctuation_all:
479
+ truth = truth.replace(x, '')
480
+ hypo = hypo.replace(x, '')
481
+
482
+ truth = truth.replace('  ', ' ')
484
+ hypo = hypo.replace('  ', ' ')
484
+
485
+ if lang == "zh":
486
+ truth = " ".join([x for x in truth])
487
+ hypo = " ".join([x for x in hypo])
488
+ elif lang == "en":
489
+ truth = truth.lower()
490
+ hypo = hypo.lower()
491
+
492
+ measures = compute_measures(truth, hypo)
493
+ wer = measures["wer"]
494
+
495
+ # ref_list = truth.split(" ")
496
+ # subs = measures["substitutions"] / len(ref_list)
497
+ # dele = measures["deletions"] / len(ref_list)
498
+ # inse = measures["insertions"] / len(ref_list)
499
+
500
+ wers.append(wer)
501
+
502
+ return wers
503
+
504
+
505
+ # SIM Evaluation
506
+
507
+ def run_sim(args):
508
+ rank, test_set, ckpt_dir = args
509
+ device = f"cuda:{rank}"
510
+
511
+ model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
512
+ state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
513
+ model.load_state_dict(state_dict['model'], strict=False)
514
+
515
+ use_gpu = torch.cuda.is_available()
516
+ if use_gpu:
517
+ model = model.cuda(device)
518
+ model.eval()
519
+
520
+ sim_list = []
521
+ for wav1, wav2, truth in tqdm(test_set):
522
+
523
+ wav1, sr1 = torchaudio.load(wav1)
524
+ wav2, sr2 = torchaudio.load(wav2)
525
+
526
+ resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
527
+ resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
528
+ wav1 = resample1(wav1)
529
+ wav2 = resample2(wav2)
530
+
531
+ if use_gpu:
532
+ wav1 = wav1.cuda(device)
533
+ wav2 = wav2.cuda(device)
534
+ with torch.no_grad():
535
+ emb1 = model(wav1)
536
+ emb2 = model(wav2)
537
+
538
+ sim = F.cosine_similarity(emb1, emb2)[0].item()
539
+ # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
540
+ sim_list.append(sim)
541
+
542
+ return sim_list
543
+
544
+
545
+ # filter func for dirty data with many repetitions
546
+
547
+ def repetition_found(text, length = 2, tolerance = 10):
548
+ pattern_count = defaultdict(int)
549
+ for i in range(len(text) - length + 1):
550
+ pattern = text[i:i + length]
551
+ pattern_count[pattern] += 1
552
+ for pattern, count in pattern_count.items():
553
+ if count > tolerance:
554
+ return True
555
+ return False
556
+
557
+
558
+ # load model checkpoint for inference
559
+
560
+ def load_checkpoint(model, ckpt_path, device, use_ema = True):
561
+ from ema_pytorch import EMA
562
+
563
+ ckpt_type = ckpt_path.split(".")[-1]
564
+ if ckpt_type == "safetensors":
565
+ from safetensors.torch import load_file
566
+ checkpoint = load_file(ckpt_path, device=device)
567
+ else:
568
+ checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
569
+
570
+ if use_ema:
571
+ ema_model = EMA(model, include_online_model = False).to(device)
572
+ if ckpt_type == "safetensors":
573
+ ema_model.load_state_dict(checkpoint)
574
+ else:
575
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
576
+ ema_model.copy_params_from_ema_to_model()
577
+ else:
578
+ model.load_state_dict(checkpoint['model_state_dict'])
579
+
580
+ return model
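
As a quick orientation to the text helpers above, the sketch below converts a mixed Chinese/English sentence into the pinyin token list used for training and maps it to ids. It assumes the pinyin vocab file (data/Emilia_ZH_EN_pinyin/vocab.txt) is present locally; the example sentence is arbitrary.

    # Usage sketch for the tokenizer helpers in model/utils.py
    from model.utils import convert_char_to_pinyin, get_tokenizer, list_str_to_idx

    texts = ["这是一个 mixed 中英文 sentence."]
    pinyin_texts = convert_char_to_pinyin(texts, polyphone=True)   # list of per-sample token lists
    print(pinyin_texts[0][:10])

    vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
    token_ids = list_str_to_idx(pinyin_texts, vocab_char_map)      # (b, nt) LongTensor, padded with -1
    print(token_ids.shape, "vocab size:", vocab_size)
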
myenv/bin/python ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b94e9b56bc1f96b18d36a0f1d14308575bb7b5960eda94ae0520f0376e95d12d
3
+ size 5909000
myenv/bin/python3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b94e9b56bc1f96b18d36a0f1d14308575bb7b5960eda94ae0520f0376e95d12d
3
+ size 5909000
myenv/bin/python3.10 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b94e9b56bc1f96b18d36a0f1d14308575bb7b5960eda94ae0520f0376e95d12d
3
+ size 5909000
myenv/pyvenv.cfg ADDED
@@ -0,0 +1,3 @@
1
+ home = /usr/bin
2
+ include-system-site-packages = false
3
+ version = 3.10.12
requirements.txt ADDED
@@ -0,0 +1,23 @@
1
+ accelerate>=0.33.0
2
+ cached_path
3
+ click
4
+ datasets
5
+ einops>=0.8.0
6
+ einx>=0.3.0
7
+ ema_pytorch>=0.5.2
8
+ gradio
9
+ jieba
10
+ librosa
11
+ matplotlib
12
+ numpy<=1.26.4
13
+ pydub
14
+ pypinyin
15
+ safetensors
16
+ soundfile
17
+ tomli
18
+ torchdiffeq
19
+ tqdm>=4.65.0
20
+ transformers
21
+ vocos
22
+ wandb
23
+ x_transformers>=1.31.14
requirements_eval.txt ADDED
@@ -0,0 +1,5 @@
1
+ faster_whisper
2
+ funasr
3
+ jiwer
4
+ zhconv
5
+ zhon
samples/country.flac ADDED
Binary file (180 kB).
samples/main.flac ADDED
Binary file (279 kB).
samples/story.toml ADDED
@@ -0,0 +1,19 @@
1
+ # F5-TTS | E2-TTS
2
+ model = "F5-TTS"
3
+ ref_audio = "samples/main.flac"
4
+ # If left empty (""), the reference audio is transcribed automatically.
5
+ ref_text = ""
6
+ gen_text = ""
7
+ # File with text to generate. Ignores the text above.
8
+ gen_file = "samples/story.txt"
9
+ remove_silence = true
10
+ output_dir = "samples"
11
+
12
+ [voices.town]
13
+ ref_audio = "samples/town.flac"
14
+ ref_text = ""
15
+
16
+ [voices.country]
17
+ ref_audio = "samples/country.flac"
18
+ ref_text = ""
19
+
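
The [voices.*] tables above pair with the bracketed speaker tags ([main], [town], [country]) embedded in samples/story.txt below. As a rough illustration only (the project's actual inference CLI may handle this differently), a config like this can be read with tomli, which is listed in requirements.txt, and the story split into per-voice chunks:

    # Illustrative sketch: parse samples/story.toml and split the gen_file text on [voice] tags.
    import re
    import tomli

    with open("samples/story.toml", "rb") as f:
        cfg = tomli.load(f)

    voices = {"main": {"ref_audio": cfg["ref_audio"], "ref_text": cfg["ref_text"]}}
    voices.update(cfg.get("voices", {}))

    with open(cfg["gen_file"], "r", encoding="utf-8") as f:
        story = f.read()

    pairs, current = [], "main"
    for part in re.split(r"\[(\w+)\]", story):   # yields text, tag, text, tag, ...
        if part.strip() in voices:
            current = part.strip()
        elif part.strip():
            pairs.append((current, part.strip()))

    for voice, text in pairs:
        print(f"[{voice}] {text[:40]}...")
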
samples/story.txt ADDED
@@ -0,0 +1 @@
1
+ A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] “My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land.” [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] “Goodbye,” [main] said he, [country] “I’m off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace.”
samples/town.flac ADDED
Binary file (229 kB).
scripts/count_max_epoch.py ADDED
@@ -0,0 +1,32 @@
1
+ '''ADAPTIVE BATCH SIZE'''
2
+ print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in')
3
+ print(' -> least padding, gather wavs with accumulated frames in a batch\n')
4
+
5
+ # data
6
+ total_hours = 95282
7
+ mel_hop_length = 256
8
+ mel_sampling_rate = 24000
9
+
10
+ # target
11
+ wanted_max_updates = 1000000
12
+
13
+ # train params
14
+ gpus = 8
15
+ frames_per_gpu = 38400 # 8 * 38400 = 307200
16
+ grad_accum = 1
17
+
18
+ # intermediate
19
+ mini_batch_frames = frames_per_gpu * grad_accum * gpus
20
+ mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
21
+ updates_per_epoch = total_hours / mini_batch_hours
22
+ steps_per_epoch = updates_per_epoch * grad_accum
23
+
24
+ # result
25
+ epochs = wanted_max_updates / updates_per_epoch
26
+ print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
27
+ print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
28
+ print(f" or approx. 0/{steps_per_epoch:.0f} steps")
29
+
30
+ # others
31
+ print(f"total {total_hours:.0f} hours")
32
+ print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")
scripts/count_params_gflops.py ADDED
@@ -0,0 +1,35 @@
1
+ import sys, os
2
+ sys.path.append(os.getcwd())
3
+
4
+ from model import M2_TTS, UNetT, DiT, MMDiT
5
+
6
+ import torch
7
+ import thop
8
+
9
+
10
+ ''' ~155M '''
11
+ # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
12
+ # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
13
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
14
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
15
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
16
+ # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
17
+
18
+ ''' ~335M '''
19
+ # FLOPs: 622.1 G, Params: 333.2 M
20
+ # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
21
+ # FLOPs: 363.4 G, Params: 335.8 M
22
+ transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
23
+
24
+
25
+ model = M2_TTS(transformer=transformer)
26
+ target_sample_rate = 24000
27
+ n_mel_channels = 100
28
+ hop_length = 256
29
+ duration = 20
30
+ frame_length = int(duration * target_sample_rate / hop_length)
31
+ text_length = 150
32
+
33
+ flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
34
+ print(f"FLOPs: {flops / 1e9} G")
35
+ print(f"Params: {params / 1e6} M")
scripts/eval_infer_batch.py ADDED
@@ -0,0 +1,199 @@
1
+ import sys, os
2
+ sys.path.append(os.getcwd())
3
+
4
+ import time
5
+ import random
6
+ from tqdm import tqdm
7
+ import argparse
8
+
9
+ import torch
10
+ import torchaudio
11
+ from accelerate import Accelerator
12
+ from einops import rearrange
13
+ from vocos import Vocos
14
+
15
+ from model import CFM, UNetT, DiT
16
+ from model.utils import (
17
+ load_checkpoint,
18
+ get_tokenizer,
19
+ get_seedtts_testset_metainfo,
20
+ get_librispeech_test_clean_metainfo,
21
+ get_inference_prompt,
22
+ )
23
+
24
+ accelerator = Accelerator()
25
+ device = f"cuda:{accelerator.process_index}"
26
+
27
+
28
+ # --------------------- Dataset Settings -------------------- #
29
+
30
+ target_sample_rate = 24000
31
+ n_mel_channels = 100
32
+ hop_length = 256
33
+ target_rms = 0.1
34
+
35
+ tokenizer = "pinyin"
36
+
37
+
38
+ # ---------------------- infer setting ---------------------- #
39
+
40
+ parser = argparse.ArgumentParser(description="batch inference")
41
+
42
+ parser.add_argument('-s', '--seed', default=None, type=int)
43
+ parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
44
+ parser.add_argument('-n', '--expname', required=True)
45
+ parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
46
+
47
+ parser.add_argument('-nfe', '--nfestep', default=32, type=int)
48
+ parser.add_argument('-o', '--odemethod', default="euler")
49
+ parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
50
+
51
+ parser.add_argument('-t', '--testset', required=True)
52
+
53
+ args = parser.parse_args()
54
+
55
+
56
+ seed = args.seed
57
+ dataset_name = args.dataset
58
+ exp_name = args.expname
59
+ ckpt_step = args.ckptstep
60
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
61
+
62
+ nfe_step = args.nfestep
63
+ ode_method = args.odemethod
64
+ sway_sampling_coef = args.swaysampling
65
+
66
+ testset = args.testset
67
+
68
+
69
+ infer_batch_size = 1 # frame budget per batch; 1 effectively batches one utterance at a time (recommended for ddp inference)
70
+ cfg_strength = 2.
71
+ speed = 1.
72
+ use_truth_duration = False
73
+ no_ref_audio = False
74
+
75
+
76
+ if exp_name == "F5TTS_Base":
77
+ model_cls = DiT
78
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
79
+
80
+ elif exp_name == "E2TTS_Base":
81
+ model_cls = UNetT
82
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
83
+
84
+
85
+ if testset == "ls_pc_test_clean":
86
+ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
87
+ librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
88
+ metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
89
+
90
+ elif testset == "seedtts_test_zh":
91
+ metalst = "data/seedtts_testset/zh/meta.lst"
92
+ metainfo = get_seedtts_testset_metainfo(metalst)
93
+
94
+ elif testset == "seedtts_test_en":
95
+ metalst = "data/seedtts_testset/en/meta.lst"
96
+ metainfo = get_seedtts_testset_metainfo(metalst)
97
+
98
+
99
+ # path to save genereted wavs
100
+ if seed is None: seed = random.randint(-10000, 10000)
101
+ output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
102
+ f"seed{seed}_{ode_method}_nfe{nfe_step}" \
103
+ f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
104
+ f"_cfg{cfg_strength}_speed{speed}" \
105
+ f"{'_gt-dur' if use_truth_duration else ''}" \
106
+ f"{'_no-ref-audio' if no_ref_audio else ''}"
107
+
108
+
109
+ # -------------------------------------------------#
110
+
111
+ use_ema = True
112
+
113
+ prompts_all = get_inference_prompt(
114
+ metainfo,
115
+ speed = speed,
116
+ tokenizer = tokenizer,
117
+ target_sample_rate = target_sample_rate,
118
+ n_mel_channels = n_mel_channels,
119
+ hop_length = hop_length,
120
+ target_rms = target_rms,
121
+ use_truth_duration = use_truth_duration,
122
+ infer_batch_size = infer_batch_size,
123
+ )
124
+
125
+ # Vocoder model
126
+ local = False
127
+ if local:
128
+ vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
129
+ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
130
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
131
+ vocos.load_state_dict(state_dict)
132
+ vocos.eval()
133
+ else:
134
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
135
+
136
+ # Tokenizer
137
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
138
+
139
+ # Model
140
+ model = CFM(
141
+ transformer = model_cls(
142
+ **model_cfg,
143
+ text_num_embeds = vocab_size,
144
+ mel_dim = n_mel_channels
145
+ ),
146
+ mel_spec_kwargs = dict(
147
+ target_sample_rate = target_sample_rate,
148
+ n_mel_channels = n_mel_channels,
149
+ hop_length = hop_length,
150
+ ),
151
+ odeint_kwargs = dict(
152
+ method = ode_method,
153
+ ),
154
+ vocab_char_map = vocab_char_map,
155
+ ).to(device)
156
+
157
+ model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
158
+
159
+ if not os.path.exists(output_dir) and accelerator.is_main_process:
160
+ os.makedirs(output_dir)
161
+
162
+ # start batch inference
163
+ accelerator.wait_for_everyone()
164
+ start = time.time()
165
+
166
+ with accelerator.split_between_processes(prompts_all) as prompts:
167
+
168
+ for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
169
+ utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
170
+ ref_mels = ref_mels.to(device)
171
+ ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
172
+ total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
173
+
174
+ # Inference
175
+ with torch.inference_mode():
176
+ generated, _ = model.sample(
177
+ cond = ref_mels,
178
+ text = final_text_list,
179
+ duration = total_mel_lens,
180
+ lens = ref_mel_lens,
181
+ steps = nfe_step,
182
+ cfg_strength = cfg_strength,
183
+ sway_sampling_coef = sway_sampling_coef,
184
+ no_ref_audio = no_ref_audio,
185
+ seed = seed,
186
+ )
187
+ # Final result
188
+ for i, gen in enumerate(generated):
189
+ gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
190
+ gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
191
+ generated_wave = vocos.decode(gen_mel_spec.cpu())
192
+ if ref_rms_list[i] < target_rms:
193
+ generated_wave = generated_wave * ref_rms_list[i] / target_rms
194
+ torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
195
+
196
+ accelerator.wait_for_everyone()
197
+ if accelerator.is_main_process:
198
+ timediff = time.time() - start
199
+ print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
scripts/eval_infer_batch.sh ADDED
@@ -0,0 +1,13 @@
1
+ #!/bin/bash
2
+
3
+ # e.g. F5-TTS, 16 NFE
4
+ accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
5
+ accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
6
+ accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
7
+
8
+ # e.g. Vanilla E2 TTS, 32 NFE
9
+ accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
10
+ accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
11
+ accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
12
+
13
+ # etc.
scripts/eval_librispeech_test_clean.py ADDED
@@ -0,0 +1,67 @@
1
+ # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
2
+
3
+ import sys, os
4
+ sys.path.append(os.getcwd())
5
+
6
+ import multiprocessing as mp
7
+ import numpy as np
8
+
9
+ from model.utils import (
10
+ get_librispeech_test,
11
+ run_asr_wer,
12
+ run_sim,
13
+ )
14
+
15
+
16
+ eval_task = "wer" # sim | wer
17
+ lang = "en"
18
+ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
19
+ librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
20
+ gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
21
+
22
+ gpus = [0,1,2,3,4,5,6,7]
23
+ test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
24
+
25
+ ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
26
+ ## leading to a low similarity for the ground truth in some cases.
27
+ # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth
28
+
29
+ local = False
30
+ if local: # use local custom checkpoint dir
31
+ asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
32
+ else:
33
+ asr_ckpt_dir = "" # auto download to cache dir
34
+
35
+ wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
36
+
37
+
38
+ # --------------------------- WER ---------------------------
39
+
40
+ if eval_task == "wer":
41
+ wers = []
42
+
43
+ with mp.Pool(processes=len(gpus)) as pool:
44
+ args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
45
+ results = pool.map(run_asr_wer, args)
46
+ for wers_ in results:
47
+ wers.extend(wers_)
48
+
49
+ wer = round(np.mean(wers)*100, 3)
50
+ print(f"\nTotal {len(wers)} samples")
51
+ print(f"WER : {wer}%")
52
+
53
+
54
+ # --------------------------- SIM ---------------------------
55
+
56
+ if eval_task == "sim":
57
+ sim_list = []
58
+
59
+ with mp.Pool(processes=len(gpus)) as pool:
60
+ args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
61
+ results = pool.map(run_sim, args)
62
+ for sim_ in results:
63
+ sim_list.extend(sim_)
64
+
65
+ sim = round(sum(sim_list)/len(sim_list), 3)
66
+ print(f"\nTotal {len(sim_list)} samples")
67
+ print(f"SIM : {sim}")
scripts/eval_seedtts_testset.py ADDED
@@ -0,0 +1,69 @@
1
+ # Evaluate with Seed-TTS testset
2
+
3
+ import sys, os
4
+ sys.path.append(os.getcwd())
5
+
6
+ import multiprocessing as mp
7
+ import numpy as np
8
+
9
+ from model.utils import (
10
+ get_seed_tts_test,
11
+ run_asr_wer,
12
+ run_sim,
13
+ )
14
+
15
+
16
+ eval_task = "wer" # sim | wer
17
+ lang = "zh" # zh | en
18
+ metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
19
+ # gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
20
+ gen_wav_dir = f"PATH_TO_GENERATED" # generated wavs
21
+
22
+
23
+ # NOTE: paraformer-zh results differ slightly with the number of GPUs, since the effective batch size differs
24
+ # zh 1.254 seems a result of 4 workers wer_seed_tts
25
+ gpus = [0,1,2,3,4,5,6,7]
26
+ test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
27
+
28
+ local = False
29
+ if local: # use local custom checkpoint dir
30
+ if lang == "zh":
31
+ asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
32
+ elif lang == "en":
33
+ asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
34
+ else:
35
+ asr_ckpt_dir = "" # auto download to cache dir
36
+
37
+ wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
38
+
39
+
40
+ # --------------------------- WER ---------------------------
41
+
42
+ if eval_task == "wer":
43
+ wers = []
44
+
45
+ with mp.Pool(processes=len(gpus)) as pool:
46
+ args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
47
+ results = pool.map(run_asr_wer, args)
48
+ for wers_ in results:
49
+ wers.extend(wers_)
50
+
51
+ wer = round(np.mean(wers)*100, 3)
52
+ print(f"\nTotal {len(wers)} samples")
53
+ print(f"WER : {wer}%")
54
+
55
+
56
+ # --------------------------- SIM ---------------------------
57
+
58
+ if eval_task == "sim":
59
+ sim_list = []
60
+
61
+ with mp.Pool(processes=len(gpus)) as pool:
62
+ args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
63
+ results = pool.map(run_sim, args)
64
+ for sim_ in results:
65
+ sim_list.extend(sim_)
66
+
67
+ sim = round(sum(sim_list)/len(sim_list), 3)
68
+ print(f"\nTotal {len(sim_list)} samples")
69
+ print(f"SIM : {sim}")
scripts/prepare_csv_wavs.py ADDED
@@ -0,0 +1,132 @@
1
+ import sys, os
2
+ sys.path.append(os.getcwd())
3
+
4
+ from pathlib import Path
5
+ import json
6
+ import shutil
7
+ import argparse
8
+
9
+ import csv
10
+ import torchaudio
11
+ from tqdm import tqdm
12
+ from datasets.arrow_writer import ArrowWriter
13
+
14
+ from model.utils import (
15
+ convert_char_to_pinyin,
16
+ )
17
+
18
+ PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
19
+
20
+ def is_csv_wavs_format(input_dataset_dir):
21
+ fpath = Path(input_dataset_dir)
22
+ metadata = fpath / "metadata.csv"
23
+ wavs = fpath / 'wavs'
24
+ return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
25
+
26
+
27
+ def prepare_csv_wavs_dir(input_dir):
28
+ assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
29
+ input_dir = Path(input_dir)
30
+ metadata_path = input_dir / "metadata.csv"
31
+ audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
32
+
33
+ sub_result, durations = [], []
34
+ vocab_set = set()
35
+ polyphone = True
36
+ for audio_path, text in audio_path_text_pairs:
37
+ if not Path(audio_path).exists():
38
+ print(f"audio {audio_path} not found, skipping")
39
+ continue
40
+ audio_duration = get_audio_duration(audio_path)
41
+ # assume tokenizer = "pinyin" ("pinyin" | "char")
42
+ text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
43
+ sub_result.append({"audio_path": audio_path, "text": text, "duration": audio_duration})
44
+ durations.append(audio_duration)
45
+ vocab_set.update(list(text))
46
+
47
+ return sub_result, durations, vocab_set
48
+
49
+ def get_audio_duration(audio_path):
50
+ audio, sample_rate = torchaudio.load(audio_path)
51
+ num_channels = audio.shape[0]
52
+ return audio.shape[1] / (sample_rate * num_channels)
53
+
54
+ def read_audio_text_pairs(csv_file_path):
55
+ audio_text_pairs = []
56
+
57
+ parent = Path(csv_file_path).parent
58
+ with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csvfile:
59
+ reader = csv.reader(csvfile, delimiter='|')
60
+ next(reader) # Skip the header row
61
+ for row in reader:
62
+ if len(row) >= 2:
63
+ audio_file = row[0].strip() # First column: audio file path
64
+ text = row[1].strip() # Second column: text
65
+ audio_file_path = parent / audio_file
66
+ audio_text_pairs.append((audio_file_path.as_posix(), text))
67
+
68
+ return audio_text_pairs
69
+
70
+
71
+ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
72
+ out_dir = Path(out_dir)
73
+ # save preprocessed dataset to disk
74
+ out_dir.mkdir(exist_ok=True, parents=True)
75
+ print(f"\nSaving to {out_dir} ...")
76
+
77
+ # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
78
+ # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
79
+ raw_arrow_path = out_dir / "raw.arrow"
80
+ with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
81
+ for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
82
+ writer.write(line)
83
+
84
+ # dup a json separately saving duration in case for DynamicBatchSampler ease
85
+ dur_json_path = out_dir / "duration.json"
86
+ with open(dur_json_path.as_posix(), 'w', encoding='utf-8') as f:
87
+ json.dump({"duration": duration_list}, f, ensure_ascii=False)
88
+
89
+ # vocab map, i.e. tokenizer
90
+ # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
91
+ # if tokenizer == "pinyin":
92
+ # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
93
+ voca_out_path = out_dir / "vocab.txt"
94
+ with open(voca_out_path.as_posix(), "w") as f:
95
+ for vocab in sorted(text_vocab_set):
96
+ f.write(vocab + "\n")
97
+
98
+ if is_finetune:
99
+ file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
100
+ shutil.copy2(file_vocab_finetune, voca_out_path)
101
+ else:
102
+ with open(voca_out_path, "w") as f:
103
+ for vocab in sorted(text_vocab_set):
104
+ f.write(vocab + "\n")
105
+
106
+ dataset_name = out_dir.stem
107
+ print(f"\nFor {dataset_name}, sample count: {len(result)}")
108
+ print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
109
+ print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
110
+
111
+
112
+ def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
113
+ if is_finetune:
114
+ assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
115
+ sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir)
116
+ save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
117
+
118
+
119
+ def cli():
120
+ # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
121
+ # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
122
+ parser = argparse.ArgumentParser(description="Prepare and save dataset.")
123
+ parser.add_argument('inp_dir', type=str, help="Input directory containing the data.")
124
+ parser.add_argument('out_dir', type=str, help="Output directory to save the prepared data.")
125
+ parser.add_argument('--pretrain', action='store_true', help="Enable for new pretrain, otherwise is a fine-tune")
126
+
127
+ args = parser.parse_args()
128
+
129
+ prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
130
+
131
+ if __name__ == "__main__":
132
+ cli()
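
For reference, read_audio_text_pairs() above expects a pipe-delimited metadata.csv next to a wavs/ folder, with one header row (skipped, so its wording is arbitrary) and wav paths given relative to the csv's directory. A minimal layout, with made-up file names, looks like:

    audio_file|text
    wavs/utt_0001.wav|Some transcript for the first clip.
    wavs/utt_0002.wav|这是第二条样本的中文文本。

It is then run as shown in cli():

    python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin

Add --pretrain to build the vocab from the dataset itself instead of copying the pretrained vocab.
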
scripts/prepare_emilia.py ADDED
@@ -0,0 +1,143 @@
1
+ # Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
2
+ # if using the updated WebDataset version, feel free to modify or draft your own script
3
+
4
+ # generate audio text map for Emilia ZH & EN
5
+ # evaluate for vocab size
6
+
7
+ import sys, os
8
+ sys.path.append(os.getcwd())
9
+
10
+ from pathlib import Path
11
+ import json
12
+ from tqdm import tqdm
13
+ from concurrent.futures import ProcessPoolExecutor
14
+
15
+ from datasets import Dataset
16
+ from datasets.arrow_writer import ArrowWriter
17
+
18
+ from model.utils import (
19
+ repetition_found,
20
+ convert_char_to_pinyin,
21
+ )
22
+
23
+
24
+ out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
25
+ zh_filters = ["い", "て"]
26
+ # seems synthesized audios, or heavily code-switched
27
+ out_en = {
28
+ "EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",
29
+
30
+ "EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995",
31
+ }
32
+ en_filters = ["ا", "い", "て"]
33
+
34
+
35
+ def deal_with_audio_dir(audio_dir):
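+ # process one Emilia audio sub-directory: read its .jsonl metadata, skip filtered/bad cases,
+ # and collect (audio_path, text, duration) entries plus the characters it contributes to the vocab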
36
+ audio_jsonl = audio_dir.with_suffix(".jsonl")
37
+ sub_result, durations = [], []
38
+ vocab_set = set()
39
+ bad_case_zh = 0
40
+ bad_case_en = 0
41
+ with open(audio_jsonl, "r") as f:
42
+ lines = f.readlines()
43
+ for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
44
+ obj = json.loads(line)
45
+ text = obj["text"]
46
+ if obj['language'] == "zh":
47
+ if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
48
+ bad_case_zh += 1
49
+ continue
50
+ else:
51
+ text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'})) # "。" is left as-is because of heavy code-switching
52
+ if obj['language'] == "en":
53
+ if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
54
+ bad_case_en += 1
55
+ continue
56
+ if tokenizer == "pinyin":
57
+ text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
58
+ duration = obj["duration"]
59
+ sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
60
+ durations.append(duration)
61
+ vocab_set.update(list(text))
62
+ return sub_result, durations, vocab_set, bad_case_zh, bad_case_en
63
+
64
+
65
+ def main():
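+ # fan the sub-directories out to worker processes, then aggregate the results and write
+ # raw.arrow, duration.json and vocab.txt under data/{dataset_name}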
66
+ assert tokenizer in ["pinyin", "char"]
67
+ result = []
68
+ duration_list = []
69
+ text_vocab_set = set()
70
+ total_bad_case_zh = 0
71
+ total_bad_case_en = 0
72
+
73
+ # process raw data
74
+ executor = ProcessPoolExecutor(max_workers=max_workers)
75
+ futures = []
76
+ for lang in langs:
77
+ dataset_path = Path(os.path.join(dataset_dir, lang))
78
+ [
79
+ futures.append(executor.submit(deal_with_audio_dir, audio_dir))
80
+ for audio_dir in dataset_path.iterdir()
81
+ if audio_dir.is_dir()
82
+ ]
83
+ for future in tqdm(futures, total=len(futures)):
84
+ sub_result, durations, vocab_set, bad_case_zh, bad_case_en = future.result()
85
+ result.extend(sub_result)
86
+ duration_list.extend(durations)
87
+ text_vocab_set.update(vocab_set)
88
+ total_bad_case_zh += bad_case_zh
89
+ total_bad_case_en += bad_case_en
90
+ executor.shutdown()
91
+
92
+ # save preprocessed dataset to disk
93
+ if not os.path.exists(f"data/{dataset_name}"):
94
+ os.makedirs(f"data/{dataset_name}")
95
+ print(f"\nSaving to data/{dataset_name} ...")
96
+ # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
97
+ # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
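+ # stream samples into a single Arrow file instead, to avoid holding the full dataset in memory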
98
+ with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
99
+ for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
100
+ writer.write(line)
101
+
102
+ # also save durations to a separate json, for easier use with DynamicBatchSampler
103
+ with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
104
+ json.dump({"duration": duration_list}, f, ensure_ascii=False)
105
+
106
+ # vocab map, i.e. tokenizer
107
+ # add alphabets and symbols (optional, if planning to fine-tune on de/fr etc.)
108
+ # if tokenizer == "pinyin":
109
+ # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
110
+ with open(f"data/{dataset_name}/vocab.txt", "w") as f:
111
+ for vocab in sorted(text_vocab_set):
112
+ f.write(vocab + "\n")
113
+
114
+ print(f"\nFor {dataset_name}, sample count: {len(result)}")
115
+ print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
116
+ print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
117
+ if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
118
+ if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
119
+
120
+
121
+ if __name__ == "__main__":
122
+
123
+ max_workers = 32
124
+
125
+ tokenizer = "pinyin" # "pinyin" | "char"
126
+ polyphone = True
127
+
128
+ langs = ["ZH", "EN"]
129
+ dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
130
+ dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
131
+ print(f"\nPrepare for {dataset_name}\n")
132
+
133
+ main()
134
+
135
+ # Emilia ZH & EN
136
+ # samples count 37837916 (after removal)
137
+ # pinyin vocab size 2543 (polyphone)
138
+ # total duration 95281.87 (hours)
139
+ # bad zh asr cnt 230435 (samples)
140
+ # bad en asr cnt 37217 (samples)
141
+
142
+ # vocab size may differ slightly due to the jieba tokenizer and pypinyin (e.g. polyphone handling)
143
+ # be careful if using a pretrained model: make sure the vocab.txt is the same
scripts/prepare_wenetspeech4tts.py ADDED
@@ -0,0 +1,116 @@
1
+ # generate the audio-text map for WenetSpeech4TTS
2
+ # evaluate the vocab size
3
+
4
+ import sys, os
5
+ sys.path.append(os.getcwd())
6
+
7
+ import json
8
+ from tqdm import tqdm
9
+ from concurrent.futures import ProcessPoolExecutor
10
+
11
+ import torchaudio
12
+ from datasets import Dataset
13
+
14
+ from model.utils import convert_char_to_pinyin
15
+
16
+
17
+ def deal_with_sub_path_files(dataset_path, sub_path):
18
+ print(f"Dealing with: {sub_path}")
19
+
20
+ text_dir = os.path.join(dataset_path, sub_path, "txts")
21
+ audio_dir = os.path.join(dataset_path, sub_path, "wavs")
22
+ text_files = os.listdir(text_dir)
23
+
24
+ audio_paths, texts, durations = [], [], []
25
+ for text_file in tqdm(text_files):
26
+ with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
27
+ first_line = file.readline().split("\t")
28
+ audio_nm = first_line[0]
29
+ audio_path = os.path.join(audio_dir, audio_nm + ".wav")
30
+ text = first_line[1].strip()
31
+
32
+ audio_paths.append(audio_path)
33
+
34
+ if tokenizer == "pinyin":
35
+ texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
36
+ elif tokenizer == "char":
37
+ texts.append(text)
38
+
39
+ audio, sample_rate = torchaudio.load(audio_path)
40
+ durations.append(audio.shape[-1] / sample_rate)
41
+
42
+ return audio_paths, texts, durations
43
+
44
+
45
+ def main():
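+ # process sub-paths in parallel, then save a HF Dataset plus duration.json and vocab.txt
+ # under data/{dataset_name}_{tokenizer}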
46
+ assert tokenizer in ["pinyin", "char"]
47
+
48
+ audio_path_list, text_list, duration_list = [], [], []
49
+
50
+ executor = ProcessPoolExecutor(max_workers=max_workers)
51
+ futures = []
52
+ for dataset_path in dataset_paths:
53
+ sub_items = os.listdir(dataset_path)
54
+ sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))]
55
+ for sub_path in sub_paths:
56
+ futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path))
57
+ for future in tqdm(futures, total=len(futures)):
58
+ audio_paths, texts, durations = future.result()
59
+ audio_path_list.extend(audio_paths)
60
+ text_list.extend(texts)
61
+ duration_list.extend(durations)
62
+ executor.shutdown()
63
+
64
+ if not os.path.exists("data"):
65
+ os.makedirs("data")
66
+
67
+ print(f"\nSaving to data/{dataset_name}_{tokenizer} ...")
68
+ dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
69
+ dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
70
+
71
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
72
+ json.dump({"duration": duration_list}, f, ensure_ascii=False) # dup a json separately saving duration in case for DynamicBatchSampler ease
73
+
74
+ print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
75
+ text_vocab_set = set()
76
+ for text in tqdm(text_list):
77
+ text_vocab_set.update(list(text))
78
+
79
+ # add alphabets and symbols (optional, if planning to fine-tune on de/fr etc.)
80
+ if tokenizer == "pinyin":
81
+ text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
82
+
83
+ with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "w") as f:
84
+ for vocab in sorted(text_vocab_set):
85
+ f.write(vocab + "\n")
86
+ print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
87
+ print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
88
+
89
+
90
+ if __name__ == "__main__":
91
+
92
+ max_workers = 32
93
+
94
+ tokenizer = "pinyin" # "pinyin" | "char"
95
+ polyphone = True
96
+ dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
97
+
98
+ dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
99
+ dataset_paths = [
100
+ "<SOME_PATH>/WenetSpeech4TTS/Basic",
101
+ "<SOME_PATH>/WenetSpeech4TTS/Standard",
102
+ "<SOME_PATH>/WenetSpeech4TTS/Premium",
103
+ ][-dataset_choice:]
104
+ print(f"\nChoose Dataset: {dataset_name}\n")
105
+
106
+ main()
107
+
108
+ # Results (if adding alphabets with accents and symbols):
109
+ # WenetSpeech4TTS Basic Standard Premium
110
+ # samples count 3932473 1941220 407494
111
+ # pinyin vocab size 1349 1348 1344 (no polyphone)
112
+ # - - 1459 (polyphone)
113
+ # char vocab size 5264 5219 5042
114
+
115
+ # vocab size may differ slightly due to the jieba tokenizer and pypinyin (e.g. polyphone handling)
116
+ # be careful if using a pretrained model: make sure the vocab.txt is the same
speech_edit.py ADDED
@@ -0,0 +1,183 @@
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torchaudio
6
+ from einops import rearrange
7
+ from vocos import Vocos
8
+
9
+ from model import CFM, UNetT, DiT, MMDiT
10
+ from model.utils import (
11
+ load_checkpoint,
12
+ get_tokenizer,
13
+ convert_char_to_pinyin,
14
+ save_spectrogram,
15
+ )
16
+
17
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
18
+
19
+
20
+ # --------------------- Dataset Settings -------------------- #
21
+
22
+ target_sample_rate = 24000
23
+ n_mel_channels = 100
24
+ hop_length = 256
25
+ target_rms = 0.1
26
+
27
+ tokenizer = "pinyin"
28
+ dataset_name = "Emilia_ZH_EN"
29
+
30
+
31
+ # ---------------------- infer setting ---------------------- #
32
+
33
+ seed = None # int | None
34
+
35
+ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
36
+ ckpt_step = 1200000
37
+
38
+ nfe_step = 32 # 16, 32
39
+ cfg_strength = 2.
40
+ ode_method = 'euler' # euler | midpoint
41
+ sway_sampling_coef = -1.
42
+ speed = 1.
43
+
44
+ if exp_name == "F5TTS_Base":
45
+ model_cls = DiT
46
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
47
+
48
+ elif exp_name == "E2TTS_Base":
49
+ model_cls = UNetT
50
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
51
+
52
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
53
+ output_dir = "tests"
54
+
55
+ # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char-level alignment]
56
+ # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
57
+ # [write the origin_text into a file, e.g. tests/test_edit.txt]
58
+ # ctc-forced-aligner --audio_path "tests/ref_audio/test_en_1_ref_short.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
59
+ # [the result will be saved alongside the audio file]
60
+ # [--language "zho" for Chinese, "eng" for English]
61
+ # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
62
+
63
+ audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
64
+ origin_text = "Some call me nature, others call me mother nature."
65
+ target_text = "Some call me optimist, others call me realist."
66
+ parts_to_edit = [[1.42, 2.44], [4.04, 4.9], ] # start/end times of "nature" & "mother nature", in seconds
67
+ fix_duration = [1.2, 1, ] # fix duration for "optimist" & "realist", in seconds
68
+
69
+ # audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
70
+ # origin_text = "对,这就是我,万人敬仰的太乙真人。"
71
+ # target_text = "对,那就是你,万人敬仰的太白金星。"
72
+ # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
73
+ # fix_duration = None # use origin text duration
74
+
75
+
76
+ # -------------------------------------------------#
77
+
78
+ use_ema = True
79
+
80
+ if not os.path.exists(output_dir):
81
+ os.makedirs(output_dir)
82
+
83
+ # Vocoder model
84
+ local = False
85
+ if local:
86
+ vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
87
+ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
88
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
89
+ vocos.load_state_dict(state_dict)
90
+
91
+ vocos.eval()
92
+ else:
93
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
94
+
95
+ # Tokenizer
96
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
97
+
98
+ # Model
99
+ model = CFM(
100
+ transformer = model_cls(
101
+ **model_cfg,
102
+ text_num_embeds = vocab_size,
103
+ mel_dim = n_mel_channels
104
+ ),
105
+ mel_spec_kwargs = dict(
106
+ target_sample_rate = target_sample_rate,
107
+ n_mel_channels = n_mel_channels,
108
+ hop_length = hop_length,
109
+ ),
110
+ odeint_kwargs = dict(
111
+ method = ode_method,
112
+ ),
113
+ vocab_char_map = vocab_char_map,
114
+ ).to(device)
115
+
116
+ model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
117
+
118
+ # Audio
119
+ audio, sr = torchaudio.load(audio_to_edit)
120
+ if audio.shape[0] > 1:
121
+ audio = torch.mean(audio, dim=0, keepdim=True)
122
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
123
+ if rms < target_rms:
124
+ audio = audio * target_rms / rms
125
+ if sr != target_sample_rate:
126
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
127
+ audio = resampler(audio)
128
+ offset = 0
129
+ audio_ = torch.zeros(1, 0)
130
+ edit_mask = torch.zeros(1, 0, dtype=torch.bool)
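+ # build the frame-level edit mask: True marks original frames to keep, False marks frames to
+ # regenerate; durations of the edited parts come from fix_duration if given, else from the original spans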
131
+ for part in parts_to_edit:
132
+ start, end = part
133
+ part_dur = end - start if fix_duration is None else fix_duration.pop(0)
134
+ part_dur = part_dur * target_sample_rate
135
+ start = start * target_sample_rate
136
+ audio_ = torch.cat((audio_, audio[:, round(offset):round(start)], torch.zeros(1, round(part_dur))), dim = -1)
137
+ edit_mask = torch.cat((edit_mask,
138
+ torch.ones(1, round((start - offset) / hop_length), dtype = torch.bool),
139
+ torch.zeros(1, round(part_dur / hop_length), dtype = torch.bool)
140
+ ), dim = -1)
141
+ offset = end * target_sample_rate
142
+ # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
143
+ edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value = True)
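+ # pad the mask with True so all frames after the last edited region are kept from the original audio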
144
+ audio = audio.to(device)
145
+ edit_mask = edit_mask.to(device)
146
+
147
+ # Text
148
+ text_list = [target_text]
149
+ if tokenizer == "pinyin":
150
+ final_text_list = convert_char_to_pinyin(text_list)
151
+ else:
152
+ final_text_list = [text_list]
153
+ print(f"text : {text_list}")
154
+ print(f"pinyin: {final_text_list}")
155
+
156
+ # Duration
157
+ ref_audio_len = 0
158
+ duration = audio.shape[-1] // hop_length
159
+
160
+ # Inference
161
+ with torch.inference_mode():
162
+ generated, trajectory = model.sample(
163
+ cond = audio,
164
+ text = final_text_list,
165
+ duration = duration,
166
+ steps = nfe_step,
167
+ cfg_strength = cfg_strength,
168
+ sway_sampling_coef = sway_sampling_coef,
169
+ seed = seed,
170
+ edit_mask = edit_mask,
171
+ )
172
+ print(f"Generated mel: {generated.shape}")
173
+
174
+ # Final result
175
+ generated = generated[:, ref_audio_len:, :]
176
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
177
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
178
+ if rms < target_rms:
179
+ generated_wave = generated_wave * rms / target_rms
180
+
181
+ save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single_edit.png")
182
+ torchaudio.save(f"{output_dir}/test_single_edit.wav", generated_wave, target_sample_rate)
183
+ print(f"Generated wav: {generated_wave.shape}")
train.py ADDED
@@ -0,0 +1,94 @@
1
+ from model import CFM, UNetT, DiT, MMDiT, Trainer
2
+ from model.utils import get_tokenizer
3
+ from model.dataset import load_dataset
4
+
5
+
6
+ # -------------------------- Dataset Settings --------------------------- #
7
+
8
+ target_sample_rate = 24000
9
+ n_mel_channels = 100
10
+ hop_length = 256
11
+
12
+ tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
13
+ tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
14
+ dataset_name = "Emilia_ZH_EN"
15
+
16
+ # -------------------------- Training Settings -------------------------- #
17
+
18
+ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
19
+
20
+ learning_rate = 7.5e-5
21
+
22
+ batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200
23
+ batch_size_type = "frame" # "frame" or "sample"
24
+ max_samples = 64 # max sequences per batch if using frame-wise batch_size; we set 32 for small models, 64 for base models
25
+ grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
26
+ max_grad_norm = 1.
27
+
28
+ epochs = 11 # use linear decay, thus epochs control the slope
29
+ num_warmup_updates = 20000 # warmup steps
30
+ save_per_updates = 50000 # save a checkpoint every this many updates
31
+ last_per_steps = 5000 # save the last checkpoint every this many steps
32
+
33
+ # model params
34
+ if exp_name == "F5TTS_Base":
35
+ wandb_resume_id = None
36
+ model_cls = DiT
37
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
38
+ elif exp_name == "E2TTS_Base":
39
+ wandb_resume_id = None
40
+ model_cls = UNetT
41
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
42
+
43
+
44
+ # ----------------------------------------------------------------------- #
45
+
46
+ def main():
47
+ if tokenizer == "custom":
48
+ tokenizer_select = tokenizer_path # use a new local name; rebinding tokenizer_path here would raise UnboundLocalError
49
+ else:
50
+ tokenizer_select = dataset_name
51
+ vocab_char_map, vocab_size = get_tokenizer(tokenizer_select, tokenizer)
52
+
53
+ mel_spec_kwargs = dict(
54
+ target_sample_rate = target_sample_rate,
55
+ n_mel_channels = n_mel_channels,
56
+ hop_length = hop_length,
57
+ )
58
+
59
+ e2tts = CFM(
60
+ transformer = model_cls(
61
+ **model_cfg,
62
+ text_num_embeds = vocab_size,
63
+ mel_dim = n_mel_channels
64
+ ),
65
+ mel_spec_kwargs = mel_spec_kwargs,
66
+ vocab_char_map = vocab_char_map,
67
+ )
68
+
69
+ trainer = Trainer(
70
+ e2tts,
71
+ epochs,
72
+ learning_rate,
73
+ num_warmup_updates = num_warmup_updates,
74
+ save_per_updates = save_per_updates,
75
+ checkpoint_path = f'ckpts/{exp_name}',
76
+ batch_size = batch_size_per_gpu,
77
+ batch_size_type = batch_size_type,
78
+ max_samples = max_samples,
79
+ grad_accumulation_steps = grad_accumulation_steps,
80
+ max_grad_norm = max_grad_norm,
81
+ wandb_project = "CFM-TTS",
82
+ wandb_run_name = exp_name,
83
+ wandb_resume_id = wandb_resume_id,
84
+ last_per_steps = last_per_steps,
85
+ )
86
+
87
+ train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
88
+ trainer.train(train_dataset,
89
+ resumable_with_seed = 666 # seed for shuffling dataset
90
+ )
91
+
92
+
93
+ if __name__ == '__main__':
94
+ main()
upload.py ADDED
@@ -0,0 +1,20 @@
1
+ from huggingface_hub import Repository
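+ # note: newer huggingface_hub releases deprecate the git-based Repository workflow;
+ # HfApi.upload_folder is the recommended alternative if this script is updated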
2
+ import os
3
+
4
+ # Local folder path in Colab
5
+ local_dir = "/content/TTS"
6
+
7
+ # Hugging Face repository name (create it on the Hugging Face site if it does not exist yet)
8
+ repo_name = "Erlanggaa/TTSmodels"
9
+
10
+ # Create a Repository instance and clone the repo in Google Colab
11
+ repo = Repository(local_dir=local_dir, clone_from=repo_name)
12
+
13
+ # Add files from the local folder to the repository
14
+ repo.git_add()
15
+
16
+ # Commit the files to the repository
17
+ repo.git_commit("Models")
18
+
19
+ # Push the changes to the Hugging Face Hub
20
+ repo.git_push()