kevinwang676 committed
Commit c120033 · verified · 1 Parent(s): fe5241b

Add files using upload-large-folder tool

third_party/Matcha-TTS/configs/logger/csv.yaml ADDED
@@ -0,0 +1,7 @@
1
+ # CSV logger built into Lightning
2
+
3
+ csv:
4
+ _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
5
+ save_dir: "${paths.output_dir}"
6
+ name: "csv/"
7
+ prefix: ""
third_party/Matcha-TTS/configs/trainer/default.yaml ADDED
@@ -0,0 +1,20 @@
1
+ _target_: lightning.pytorch.trainer.Trainer
2
+
3
+ default_root_dir: ${paths.output_dir}
4
+
5
+ max_epochs: -1
6
+
7
+ accelerator: gpu
8
+ devices: [0]
9
+
10
+ # mixed precision for extra speed-up
11
+ precision: 16-mixed
12
+
13
+ # perform a validation loop every N training epochs
14
+ check_val_every_n_epoch: 1
15
+
16
+ # set True to ensure deterministic results
17
+ # makes training slower but gives more reproducibility than just setting seeds
18
+ deterministic: False
19
+
20
+ gradient_clip_val: 5.0
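For readers unfamiliar with the Hydra-style `_target_` convention used in this config, below is a minimal sketch of how such a trainer YAML is typically instantiated into a `lightning.pytorch.Trainer`. The config path, the placeholder output directory, and the CPU/precision overrides are illustrative assumptions, not part of this commit.

```python
# Minimal sketch, assuming hydra-core, omegaconf and lightning are installed
# and the YAML above is reachable at this relative path.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("third_party/Matcha-TTS/configs/trainer/default.yaml")
cfg.default_root_dir = "outputs/"                           # stand-in for ${paths.output_dir}
cfg.accelerator, cfg.devices, cfg.precision = "cpu", 1, 32  # CPU-only smoke test
trainer = instantiate(cfg)     # resolves _target_ -> lightning.pytorch.trainer.Trainer(**kwargs)
print(type(trainer).__name__)  # Trainer
```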
third_party/Matcha-TTS/matcha/app.py ADDED
@@ -0,0 +1,357 @@
1
+ import tempfile
2
+ from argparse import Namespace
3
+ from pathlib import Path
4
+
5
+ import gradio as gr
6
+ import soundfile as sf
7
+ import torch
8
+
9
+ from matcha.cli import (
10
+ MATCHA_URLS,
11
+ VOCODER_URLS,
12
+ assert_model_downloaded,
13
+ get_device,
14
+ load_matcha,
15
+ load_vocoder,
16
+ process_text,
17
+ to_waveform,
18
+ )
19
+ from matcha.utils.utils import get_user_data_dir, plot_tensor
20
+
21
+ LOCATION = Path(get_user_data_dir())
22
+
23
+ args = Namespace(
24
+ cpu=False,
25
+ model="matcha_vctk",
26
+ vocoder="hifigan_univ_v1",
27
+ spk=0,
28
+ )
29
+
30
+ CURRENTLY_LOADED_MODEL = args.model
31
+
32
+
33
+ def MATCHA_TTS_LOC(x):
34
+ return LOCATION / f"{x}.ckpt"
35
+
36
+
37
+ def VOCODER_LOC(x):
38
+ return LOCATION / f"{x}"
39
+
40
+
41
+ LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png"
42
+ RADIO_OPTIONS = {
43
+ "Multi Speaker (VCTK)": {
44
+ "model": "matcha_vctk",
45
+ "vocoder": "hifigan_univ_v1",
46
+ },
47
+ "Single Speaker (LJ Speech)": {
48
+ "model": "matcha_ljspeech",
49
+ "vocoder": "hifigan_T2_v1",
50
+ },
51
+ }
52
+
53
+ # Ensure all the required models are downloaded
54
+ assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"])
55
+ assert_model_downloaded(VOCODER_LOC("hifigan_T2_v1"), VOCODER_URLS["hifigan_T2_v1"])
56
+ assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"])
57
+ assert_model_downloaded(VOCODER_LOC("hifigan_univ_v1"), VOCODER_URLS["hifigan_univ_v1"])
58
+
59
+ device = get_device(args)
60
+
61
+ # Load default model
62
+ model = load_matcha(args.model, MATCHA_TTS_LOC(args.model), device)
63
+ vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC(args.vocoder), device)
64
+
65
+
66
+ def load_model(model_name, vocoder_name):
67
+ model = load_matcha(model_name, MATCHA_TTS_LOC(model_name), device)
68
+ vocoder, denoiser = load_vocoder(vocoder_name, VOCODER_LOC(vocoder_name), device)
69
+ return model, vocoder, denoiser
70
+
71
+
72
+ def load_model_ui(model_type, textbox):
73
+ model_name, vocoder_name = RADIO_OPTIONS[model_type]["model"], RADIO_OPTIONS[model_type]["vocoder"]
74
+
75
+ global model, vocoder, denoiser, CURRENTLY_LOADED_MODEL # pylint: disable=global-statement
76
+ if CURRENTLY_LOADED_MODEL != model_name:
77
+ model, vocoder, denoiser = load_model(model_name, vocoder_name)
78
+ CURRENTLY_LOADED_MODEL = model_name
79
+
80
+ if model_name == "matcha_ljspeech":
81
+ spk_slider = gr.update(visible=False, value=-1)
82
+ single_speaker_examples = gr.update(visible=True)
83
+ multi_speaker_examples = gr.update(visible=False)
84
+ length_scale = gr.update(value=0.95)
85
+ else:
86
+ spk_slider = gr.update(visible=True, value=0)
87
+ single_speaker_examples = gr.update(visible=False)
88
+ multi_speaker_examples = gr.update(visible=True)
89
+ length_scale = gr.update(value=0.85)
90
+
91
+ return (
92
+ textbox,
93
+ gr.update(interactive=True),
94
+ spk_slider,
95
+ single_speaker_examples,
96
+ multi_speaker_examples,
97
+ length_scale,
98
+ )
99
+
100
+
101
+ @torch.inference_mode()
102
+ def process_text_gradio(text):
103
+ output = process_text(1, text, device)
104
+ return output["x_phones"][1::2], output["x"], output["x_lengths"]
105
+
106
+
107
+ @torch.inference_mode()
108
+ def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk):
109
+ spk = torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None
110
+ output = model.synthesise(
111
+ text,
112
+ text_length,
113
+ n_timesteps=n_timesteps,
114
+ temperature=temperature,
115
+ spks=spk,
116
+ length_scale=length_scale,
117
+ )
118
+ output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
119
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
120
+ sf.write(fp.name, output["waveform"], 22050, "PCM_24")
121
+
122
+ return fp.name, plot_tensor(output["mel"].squeeze().cpu().numpy())
123
+
124
+
125
+ def multispeaker_example_cacher(text, n_timesteps, mel_temp, length_scale, spk):
126
+ global CURRENTLY_LOADED_MODEL # pylint: disable=global-statement
127
+ if CURRENTLY_LOADED_MODEL != "matcha_vctk":
128
+ global model, vocoder, denoiser # pylint: disable=global-statement
129
+ model, vocoder, denoiser = load_model("matcha_vctk", "hifigan_univ_v1")
130
+ CURRENTLY_LOADED_MODEL = "matcha_vctk"
131
+
132
+ phones, text, text_lengths = process_text_gradio(text)
133
+ audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
134
+ return phones, audio, mel_spectrogram
135
+
136
+
137
+ def ljspeech_example_cacher(text, n_timesteps, mel_temp, length_scale, spk=-1):
138
+ global CURRENTLY_LOADED_MODEL # pylint: disable=global-statement
139
+ if CURRENTLY_LOADED_MODEL != "matcha_ljspeech":
140
+ global model, vocoder, denoiser # pylint: disable=global-statement
141
+ model, vocoder, denoiser = load_model("matcha_ljspeech", "hifigan_T2_v1")
142
+ CURRENTLY_LOADED_MODEL = "matcha_ljspeech"
143
+
144
+ phones, text, text_lengths = process_text_gradio(text)
145
+ audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
146
+ return phones, audio, mel_spectrogram
147
+
148
+
149
+ def main():
150
+ description = """# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching
151
+ ### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)
152
+ We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS that uses conditional flow matching (similar to rectified flows) to speed up ODE-based speech synthesis. Our method:
153
+
154
+
155
+ * Is probabilistic
156
+ * Has compact memory footprint
157
+ * Sounds highly natural
158
+ * Is very fast to synthesise from
159
+
160
+
161
+ Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS). Read our [arXiv preprint for more details](https://arxiv.org/abs/2309.03199).
162
+ Code is available in our [GitHub repository](https://github.com/shivammehta25/Matcha-TTS), along with pre-trained models.
163
+
164
+ Cached examples are available at the bottom of the page.
165
+ """
166
+
167
+ with gr.Blocks(title="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") as demo:
168
+ processed_text = gr.State(value=None)
169
+ processed_text_len = gr.State(value=None)
170
+
171
+ with gr.Box():
172
+ with gr.Row():
173
+ gr.Markdown(description, scale=3)
174
+ with gr.Column():
175
+ gr.Image(LOGO_URL, label="Matcha-TTS logo", height=50, width=50, scale=1, show_label=False)
176
+ html = '<br><iframe width="560" height="315" src="https://www.youtube.com/embed/xmvJkz3bqw0?si=jN7ILyDsbPwJCGoa" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>'
177
+ gr.HTML(html)
178
+
179
+ with gr.Box():
180
+ radio_options = list(RADIO_OPTIONS.keys())
181
+ model_type = gr.Radio(
182
+ radio_options, value=radio_options[0], label="Choose a Model", interactive=True, container=False
183
+ )
184
+
185
+ with gr.Row():
186
+ gr.Markdown("# Text Input")
187
+ with gr.Row():
188
+ text = gr.Textbox(value="", lines=2, label="Text to synthesise", scale=3)
189
+ spk_slider = gr.Slider(
190
+ minimum=0, maximum=107, step=1, value=args.spk, label="Speaker ID", interactive=True, scale=1
191
+ )
192
+
193
+ with gr.Row():
194
+ gr.Markdown("### Hyper parameters")
195
+ with gr.Row():
196
+ n_timesteps = gr.Slider(
197
+ label="Number of ODE steps",
198
+ minimum=1,
199
+ maximum=100,
200
+ step=1,
201
+ value=10,
202
+ interactive=True,
203
+ )
204
+ length_scale = gr.Slider(
205
+ label="Length scale (Speaking rate)",
206
+ minimum=0.5,
207
+ maximum=1.5,
208
+ step=0.05,
209
+ value=1.0,
210
+ interactive=True,
211
+ )
212
+ mel_temp = gr.Slider(
213
+ label="Sampling temperature",
214
+ minimum=0.00,
215
+ maximum=2.001,
216
+ step=0.16675,
217
+ value=0.667,
218
+ interactive=True,
219
+ )
220
+
221
+ synth_btn = gr.Button("Synthesise")
222
+
223
+ with gr.Box():
224
+ with gr.Row():
225
+ gr.Markdown("### Phonetised text")
226
+ phonetised_text = gr.Textbox(interactive=False, scale=10, label="Phonetised text")
227
+
228
+ with gr.Box():
229
+ with gr.Row():
230
+ mel_spectrogram = gr.Image(interactive=False, label="mel spectrogram")
231
+
232
+ # with gr.Row():
233
+ audio = gr.Audio(interactive=False, label="Audio")
234
+
235
+ with gr.Row(visible=False) as example_row_lj_speech:
236
+ examples = gr.Examples( # pylint: disable=unused-variable
237
+ examples=[
238
+ [
239
+ "We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.",
240
+ 50,
241
+ 0.677,
242
+ 0.95,
243
+ ],
244
+ [
245
+ "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
246
+ 2,
247
+ 0.677,
248
+ 0.95,
249
+ ],
250
+ [
251
+ "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
252
+ 4,
253
+ 0.677,
254
+ 0.95,
255
+ ],
256
+ [
257
+ "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
258
+ 10,
259
+ 0.677,
260
+ 0.95,
261
+ ],
262
+ [
263
+ "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
264
+ 50,
265
+ 0.677,
266
+ 0.95,
267
+ ],
268
+ [
269
+ "The narrative of these events is based largely on the recollections of the participants.",
270
+ 10,
271
+ 0.677,
272
+ 0.95,
273
+ ],
274
+ [
275
+ "The jury did not believe him, and the verdict was for the defendants.",
276
+ 10,
277
+ 0.677,
278
+ 0.95,
279
+ ],
280
+ ],
281
+ fn=ljspeech_example_cacher,
282
+ inputs=[text, n_timesteps, mel_temp, length_scale],
283
+ outputs=[phonetised_text, audio, mel_spectrogram],
284
+ cache_examples=True,
285
+ )
286
+
287
+ with gr.Row() as example_row_multispeaker:
288
+ multi_speaker_examples = gr.Examples( # pylint: disable=unused-variable
289
+ examples=[
290
+ [
291
+ "Hello everyone! I am speaker 0 and I am here to tell you that Matcha-TTS is amazing!",
292
+ 10,
293
+ 0.677,
294
+ 0.85,
295
+ 0,
296
+ ],
297
+ [
298
+ "Hello everyone! I am speaker 16 and I am here to tell you that Matcha-TTS is amazing!",
299
+ 10,
300
+ 0.677,
301
+ 0.85,
302
+ 16,
303
+ ],
304
+ [
305
+ "Hello everyone! I am speaker 44 and I am here to tell you that Matcha-TTS is amazing!",
306
+ 50,
307
+ 0.677,
308
+ 0.85,
309
+ 44,
310
+ ],
311
+ [
312
+ "Hello everyone! I am speaker 45 and I am here to tell you that Matcha-TTS is amazing!",
313
+ 50,
314
+ 0.677,
315
+ 0.85,
316
+ 45,
317
+ ],
318
+ [
319
+ "Hello everyone! I am speaker 58 and I am here to tell you that Matcha-TTS is amazing!",
320
+ 4,
321
+ 0.677,
322
+ 0.85,
323
+ 58,
324
+ ],
325
+ ],
326
+ fn=multispeaker_example_cacher,
327
+ inputs=[text, n_timesteps, mel_temp, length_scale, spk_slider],
328
+ outputs=[phonetised_text, audio, mel_spectrogram],
329
+ cache_examples=True,
330
+ label="Multi Speaker Examples",
331
+ )
332
+
333
+ model_type.change(lambda x: gr.update(interactive=False), inputs=[synth_btn], outputs=[synth_btn]).then(
334
+ load_model_ui,
335
+ inputs=[model_type, text],
336
+ outputs=[text, synth_btn, spk_slider, example_row_lj_speech, example_row_multispeaker, length_scale],
337
+ )
338
+
339
+ synth_btn.click(
340
+ fn=process_text_gradio,
341
+ inputs=[
342
+ text,
343
+ ],
344
+ outputs=[phonetised_text, processed_text, processed_text_len],
345
+ api_name="matcha_tts",
346
+ queue=True,
347
+ ).then(
348
+ fn=synthesise_mel,
349
+ inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale, spk_slider],
350
+ outputs=[audio, mel_spectrogram],
351
+ )
352
+
353
+ demo.queue().launch(share=True)
354
+
355
+
356
+ if __name__ == "__main__":
357
+ main()
third_party/Matcha-TTS/matcha/hifigan/README.md ADDED
@@ -0,0 +1,101 @@
1
+ # HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis
2
+
3
+ ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae
4
+
5
+ In our [paper](https://arxiv.org/abs/2010.05646),
6
+ we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.<br/>
7
+ We provide our implementation and pretrained models as open source in this repository.
8
+
9
+ **Abstract :**
10
+ Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms.
11
+ Although such methods improve the sampling efficiency and memory usage,
12
+ their sample quality has not yet reached that of autoregressive and flow-based generative models.
13
+ In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis.
14
+ As speech audio consists of sinusoidal signals with various periods,
15
+ we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality.
16
+ A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method
17
+ demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than
18
+ real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen
19
+ speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times
20
+ faster than real-time on CPU with comparable quality to an autoregressive counterpart.
21
+
22
+ Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples.
23
+
24
+ ## Pre-requisites
25
+
26
+ 1. Python >= 3.6
27
+ 2. Clone this repository.
28
+ 3. Install Python requirements. Please refer to [requirements.txt](requirements.txt).
29
+ 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/).
30
+ Then move all wav files to `LJSpeech-1.1/wavs`.
31
+
32
+ ## Training
33
+
34
+ ```
35
+ python train.py --config config_v1.json
36
+ ```
37
+
38
+ To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.<br>
39
+ Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.<br>
40
+ You can change the path by adding `--checkpoint_path` option.
41
+
42
+ Validation loss during training with V1 generator.<br>
43
+ ![validation loss](./validation_loss.png)
44
+
45
+ ## Pretrained Model
46
+
47
+ You can also use pretrained models we provide.<br/>
48
+ [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)<br/>
49
+ Details of each folder are as follows:
50
+
51
+ | Folder Name | Generator | Dataset | Fine-Tuned |
52
+ | ------------ | --------- | --------- | ------------------------------------------------------ |
53
+ | LJ_V1 | V1 | LJSpeech | No |
54
+ | LJ_V2 | V2 | LJSpeech | No |
55
+ | LJ_V3 | V3 | LJSpeech | No |
56
+ | LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
57
+ | LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
58
+ | LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
59
+ | VCTK_V1 | V1 | VCTK | No |
60
+ | VCTK_V2 | V2 | VCTK | No |
61
+ | VCTK_V3 | V3 | VCTK | No |
62
+ | UNIVERSAL_V1 | V1 | Universal | No |
63
+
64
+ We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets.
65
+
66
+ ## Fine-Tuning
67
+
68
+ 1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.<br/>
69
+ The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.<br/>
70
+ Example:
71
+ ` Audio File : LJ001-0001.wav
72
+ Mel-Spectrogram File : LJ001-0001.npy`
73
+ 2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.<br/>
74
+ 3. Run the following command.
75
+ ```
76
+ python train.py --fine_tuning True --config config_v1.json
77
+ ```
78
+ For other command line options, please refer to the training section.
79
+
80
+ ## Inference from wav file
81
+
82
+ 1. Make `test_files` directory and copy wav files into the directory.
83
+ 2. Run the following command.
84
+ ` python inference.py --checkpoint_file [generator checkpoint file path]`
85
+ Generated wav files are saved in `generated_files` by default.<br>
86
+ You can change the path by adding `--output_dir` option.
87
+
88
+ ## Inference for end-to-end speech synthesis
89
+
90
+ 1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.<br>
91
+ You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2),
92
+ [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth.
93
+ 2. Run the following command.
94
+ ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]`
95
+ Generated wav files are saved in `generated_files_from_mel` by default.<br>
96
+ You can change the path by adding `--output_dir` option.
97
+
98
+ ## Acknowledgements
99
+
100
+ We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips)
101
+ and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this.
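As a companion to the Fine-Tuning section above, the sketch below enforces the required file layout: one `.npy` mel per wav, with matching stems, collected in `ft_dataset/`. The `compute_teacher_forced_mel` helper is a hypothetical placeholder for your own Tacotron2 teacher-forced inference and is not provided by this repository.

```python
# Sketch only: LJ001-0001.wav -> ft_dataset/LJ001-0001.npy
from pathlib import Path

import numpy as np

def compute_teacher_forced_mel(wav_path: Path) -> np.ndarray:
    # Placeholder: run Tacotron2 with teacher forcing on this utterance here.
    raise NotImplementedError

wav_dir, out_dir = Path("LJSpeech-1.1/wavs"), Path("ft_dataset")
out_dir.mkdir(exist_ok=True)
for wav in sorted(wav_dir.glob("*.wav")):
    mel = compute_teacher_forced_mel(wav)      # expected shape: (num_mels, frames)
    np.save(out_dir / f"{wav.stem}.npy", mel)  # file name matches the audio file
```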
third_party/Matcha-TTS/matcha/hifigan/__init__.py ADDED
File without changes
third_party/Matcha-TTS/matcha/hifigan/config.py ADDED
@@ -0,0 +1,28 @@
1
+ v1 = {
2
+ "resblock": "1",
3
+ "num_gpus": 0,
4
+ "batch_size": 16,
5
+ "learning_rate": 0.0004,
6
+ "adam_b1": 0.8,
7
+ "adam_b2": 0.99,
8
+ "lr_decay": 0.999,
9
+ "seed": 1234,
10
+ "upsample_rates": [8, 8, 2, 2],
11
+ "upsample_kernel_sizes": [16, 16, 4, 4],
12
+ "upsample_initial_channel": 512,
13
+ "resblock_kernel_sizes": [3, 7, 11],
14
+ "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
15
+ "resblock_initial_channel": 256,
16
+ "segment_size": 8192,
17
+ "num_mels": 80,
18
+ "num_freq": 1025,
19
+ "n_fft": 1024,
20
+ "hop_size": 256,
21
+ "win_size": 1024,
22
+ "sampling_rate": 22050,
23
+ "fmin": 0,
24
+ "fmax": 8000,
25
+ "fmax_loss": None,
26
+ "num_workers": 4,
27
+ "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1},
28
+ }
third_party/Matcha-TTS/matcha/hifigan/denoiser.py ADDED
@@ -0,0 +1,64 @@
1
+ # Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py
2
+
3
+ """Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio."""
4
+ import torch
5
+
6
+
7
+ class Denoiser(torch.nn.Module):
8
+ """Removes model bias from audio produced with waveglow"""
9
+
10
+ def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"):
11
+ super().__init__()
12
+ self.filter_length = filter_length
13
+ self.hop_length = int(filter_length / n_overlap)
14
+ self.win_length = win_length
15
+
16
+ dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device
17
+ self.device = device
18
+ if mode == "zeros":
19
+ mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device)
20
+ elif mode == "normal":
21
+ mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
22
+ else:
23
+ raise Exception(f"Mode {mode} if not supported")
24
+
25
+ def stft_fn(audio, n_fft, hop_length, win_length, window):
26
+ spec = torch.stft(
27
+ audio,
28
+ n_fft=n_fft,
29
+ hop_length=hop_length,
30
+ win_length=win_length,
31
+ window=window,
32
+ return_complex=True,
33
+ )
34
+ spec = torch.view_as_real(spec)
35
+ return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0])
36
+
37
+ self.stft = lambda x: stft_fn(
38
+ audio=x,
39
+ n_fft=self.filter_length,
40
+ hop_length=self.hop_length,
41
+ win_length=self.win_length,
42
+ window=torch.hann_window(self.win_length, device=device),
43
+ )
44
+ self.istft = lambda x, y: torch.istft(
45
+ torch.complex(x * torch.cos(y), x * torch.sin(y)),
46
+ n_fft=self.filter_length,
47
+ hop_length=self.hop_length,
48
+ win_length=self.win_length,
49
+ window=torch.hann_window(self.win_length, device=device),
50
+ )
51
+
52
+ with torch.no_grad():
53
+ bias_audio = vocoder(mel_input).float().squeeze(0)
54
+ bias_spec, _ = self.stft(bias_audio)
55
+
56
+ self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])
57
+
58
+ @torch.inference_mode()
59
+ def forward(self, audio, strength=0.0005):
60
+ audio_spec, audio_angles = self.stft(audio)
61
+ audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength
62
+ audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
63
+ audio_denoised = self.istft(audio_spec_denoised, audio_angles)
64
+ return audio_denoised
third_party/Matcha-TTS/matcha/hifigan/env.py ADDED
@@ -0,0 +1,17 @@
1
+ """ from https://github.com/jik876/hifi-gan """
2
+
3
+ import os
4
+ import shutil
5
+
6
+
7
+ class AttrDict(dict):
8
+ def __init__(self, *args, **kwargs):
9
+ super().__init__(*args, **kwargs)
10
+ self.__dict__ = self
11
+
12
+
13
+ def build_env(config, config_name, path):
14
+ t_path = os.path.join(path, config_name)
15
+ if config != t_path:
16
+ os.makedirs(path, exist_ok=True)
17
+ shutil.copyfile(config, os.path.join(path, config_name))
third_party/Matcha-TTS/matcha/hifigan/models.py ADDED
@@ -0,0 +1,368 @@
1
+ """ from https://github.com/jik876/hifi-gan """
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
7
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
8
+
9
+ from .xutils import get_padding, init_weights
10
+
11
+ LRELU_SLOPE = 0.1
12
+
13
+
14
+ class ResBlock1(torch.nn.Module):
15
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
16
+ super().__init__()
17
+ self.h = h
18
+ self.convs1 = nn.ModuleList(
19
+ [
20
+ weight_norm(
21
+ Conv1d(
22
+ channels,
23
+ channels,
24
+ kernel_size,
25
+ 1,
26
+ dilation=dilation[0],
27
+ padding=get_padding(kernel_size, dilation[0]),
28
+ )
29
+ ),
30
+ weight_norm(
31
+ Conv1d(
32
+ channels,
33
+ channels,
34
+ kernel_size,
35
+ 1,
36
+ dilation=dilation[1],
37
+ padding=get_padding(kernel_size, dilation[1]),
38
+ )
39
+ ),
40
+ weight_norm(
41
+ Conv1d(
42
+ channels,
43
+ channels,
44
+ kernel_size,
45
+ 1,
46
+ dilation=dilation[2],
47
+ padding=get_padding(kernel_size, dilation[2]),
48
+ )
49
+ ),
50
+ ]
51
+ )
52
+ self.convs1.apply(init_weights)
53
+
54
+ self.convs2 = nn.ModuleList(
55
+ [
56
+ weight_norm(
57
+ Conv1d(
58
+ channels,
59
+ channels,
60
+ kernel_size,
61
+ 1,
62
+ dilation=1,
63
+ padding=get_padding(kernel_size, 1),
64
+ )
65
+ ),
66
+ weight_norm(
67
+ Conv1d(
68
+ channels,
69
+ channels,
70
+ kernel_size,
71
+ 1,
72
+ dilation=1,
73
+ padding=get_padding(kernel_size, 1),
74
+ )
75
+ ),
76
+ weight_norm(
77
+ Conv1d(
78
+ channels,
79
+ channels,
80
+ kernel_size,
81
+ 1,
82
+ dilation=1,
83
+ padding=get_padding(kernel_size, 1),
84
+ )
85
+ ),
86
+ ]
87
+ )
88
+ self.convs2.apply(init_weights)
89
+
90
+ def forward(self, x):
91
+ for c1, c2 in zip(self.convs1, self.convs2):
92
+ xt = F.leaky_relu(x, LRELU_SLOPE)
93
+ xt = c1(xt)
94
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
95
+ xt = c2(xt)
96
+ x = xt + x
97
+ return x
98
+
99
+ def remove_weight_norm(self):
100
+ for l in self.convs1:
101
+ remove_weight_norm(l)
102
+ for l in self.convs2:
103
+ remove_weight_norm(l)
104
+
105
+
106
+ class ResBlock2(torch.nn.Module):
107
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
108
+ super().__init__()
109
+ self.h = h
110
+ self.convs = nn.ModuleList(
111
+ [
112
+ weight_norm(
113
+ Conv1d(
114
+ channels,
115
+ channels,
116
+ kernel_size,
117
+ 1,
118
+ dilation=dilation[0],
119
+ padding=get_padding(kernel_size, dilation[0]),
120
+ )
121
+ ),
122
+ weight_norm(
123
+ Conv1d(
124
+ channels,
125
+ channels,
126
+ kernel_size,
127
+ 1,
128
+ dilation=dilation[1],
129
+ padding=get_padding(kernel_size, dilation[1]),
130
+ )
131
+ ),
132
+ ]
133
+ )
134
+ self.convs.apply(init_weights)
135
+
136
+ def forward(self, x):
137
+ for c in self.convs:
138
+ xt = F.leaky_relu(x, LRELU_SLOPE)
139
+ xt = c(xt)
140
+ x = xt + x
141
+ return x
142
+
143
+ def remove_weight_norm(self):
144
+ for l in self.convs:
145
+ remove_weight_norm(l)
146
+
147
+
148
+ class Generator(torch.nn.Module):
149
+ def __init__(self, h):
150
+ super().__init__()
151
+ self.h = h
152
+ self.num_kernels = len(h.resblock_kernel_sizes)
153
+ self.num_upsamples = len(h.upsample_rates)
154
+ self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
155
+ resblock = ResBlock1 if h.resblock == "1" else ResBlock2
156
+
157
+ self.ups = nn.ModuleList()
158
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
159
+ self.ups.append(
160
+ weight_norm(
161
+ ConvTranspose1d(
162
+ h.upsample_initial_channel // (2**i),
163
+ h.upsample_initial_channel // (2 ** (i + 1)),
164
+ k,
165
+ u,
166
+ padding=(k - u) // 2,
167
+ )
168
+ )
169
+ )
170
+
171
+ self.resblocks = nn.ModuleList()
172
+ for i in range(len(self.ups)):
173
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
174
+ for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
175
+ self.resblocks.append(resblock(h, ch, k, d))
176
+
177
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
178
+ self.ups.apply(init_weights)
179
+ self.conv_post.apply(init_weights)
180
+
181
+ def forward(self, x):
182
+ x = self.conv_pre(x)
183
+ for i in range(self.num_upsamples):
184
+ x = F.leaky_relu(x, LRELU_SLOPE)
185
+ x = self.ups[i](x)
186
+ xs = None
187
+ for j in range(self.num_kernels):
188
+ if xs is None:
189
+ xs = self.resblocks[i * self.num_kernels + j](x)
190
+ else:
191
+ xs += self.resblocks[i * self.num_kernels + j](x)
192
+ x = xs / self.num_kernels
193
+ x = F.leaky_relu(x)
194
+ x = self.conv_post(x)
195
+ x = torch.tanh(x)
196
+
197
+ return x
198
+
199
+ def remove_weight_norm(self):
200
+ print("Removing weight norm...")
201
+ for l in self.ups:
202
+ remove_weight_norm(l)
203
+ for l in self.resblocks:
204
+ l.remove_weight_norm()
205
+ remove_weight_norm(self.conv_pre)
206
+ remove_weight_norm(self.conv_post)
207
+
208
+
209
+ class DiscriminatorP(torch.nn.Module):
210
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
211
+ super().__init__()
212
+ self.period = period
213
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
214
+ self.convs = nn.ModuleList(
215
+ [
216
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
217
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
218
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
219
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
220
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
221
+ ]
222
+ )
223
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
224
+
225
+ def forward(self, x):
226
+ fmap = []
227
+
228
+ # 1d to 2d
229
+ b, c, t = x.shape
230
+ if t % self.period != 0: # pad first
231
+ n_pad = self.period - (t % self.period)
232
+ x = F.pad(x, (0, n_pad), "reflect")
233
+ t = t + n_pad
234
+ x = x.view(b, c, t // self.period, self.period)
235
+
236
+ for l in self.convs:
237
+ x = l(x)
238
+ x = F.leaky_relu(x, LRELU_SLOPE)
239
+ fmap.append(x)
240
+ x = self.conv_post(x)
241
+ fmap.append(x)
242
+ x = torch.flatten(x, 1, -1)
243
+
244
+ return x, fmap
245
+
246
+
247
+ class MultiPeriodDiscriminator(torch.nn.Module):
248
+ def __init__(self):
249
+ super().__init__()
250
+ self.discriminators = nn.ModuleList(
251
+ [
252
+ DiscriminatorP(2),
253
+ DiscriminatorP(3),
254
+ DiscriminatorP(5),
255
+ DiscriminatorP(7),
256
+ DiscriminatorP(11),
257
+ ]
258
+ )
259
+
260
+ def forward(self, y, y_hat):
261
+ y_d_rs = []
262
+ y_d_gs = []
263
+ fmap_rs = []
264
+ fmap_gs = []
265
+ for _, d in enumerate(self.discriminators):
266
+ y_d_r, fmap_r = d(y)
267
+ y_d_g, fmap_g = d(y_hat)
268
+ y_d_rs.append(y_d_r)
269
+ fmap_rs.append(fmap_r)
270
+ y_d_gs.append(y_d_g)
271
+ fmap_gs.append(fmap_g)
272
+
273
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
274
+
275
+
276
+ class DiscriminatorS(torch.nn.Module):
277
+ def __init__(self, use_spectral_norm=False):
278
+ super().__init__()
279
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
280
+ self.convs = nn.ModuleList(
281
+ [
282
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
283
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
284
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
285
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
286
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
287
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
288
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
289
+ ]
290
+ )
291
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
292
+
293
+ def forward(self, x):
294
+ fmap = []
295
+ for l in self.convs:
296
+ x = l(x)
297
+ x = F.leaky_relu(x, LRELU_SLOPE)
298
+ fmap.append(x)
299
+ x = self.conv_post(x)
300
+ fmap.append(x)
301
+ x = torch.flatten(x, 1, -1)
302
+
303
+ return x, fmap
304
+
305
+
306
+ class MultiScaleDiscriminator(torch.nn.Module):
307
+ def __init__(self):
308
+ super().__init__()
309
+ self.discriminators = nn.ModuleList(
310
+ [
311
+ DiscriminatorS(use_spectral_norm=True),
312
+ DiscriminatorS(),
313
+ DiscriminatorS(),
314
+ ]
315
+ )
316
+ self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
317
+
318
+ def forward(self, y, y_hat):
319
+ y_d_rs = []
320
+ y_d_gs = []
321
+ fmap_rs = []
322
+ fmap_gs = []
323
+ for i, d in enumerate(self.discriminators):
324
+ if i != 0:
325
+ y = self.meanpools[i - 1](y)
326
+ y_hat = self.meanpools[i - 1](y_hat)
327
+ y_d_r, fmap_r = d(y)
328
+ y_d_g, fmap_g = d(y_hat)
329
+ y_d_rs.append(y_d_r)
330
+ fmap_rs.append(fmap_r)
331
+ y_d_gs.append(y_d_g)
332
+ fmap_gs.append(fmap_g)
333
+
334
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
335
+
336
+
337
+ def feature_loss(fmap_r, fmap_g):
338
+ loss = 0
339
+ for dr, dg in zip(fmap_r, fmap_g):
340
+ for rl, gl in zip(dr, dg):
341
+ loss += torch.mean(torch.abs(rl - gl))
342
+
343
+ return loss * 2
344
+
345
+
346
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
347
+ loss = 0
348
+ r_losses = []
349
+ g_losses = []
350
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
351
+ r_loss = torch.mean((1 - dr) ** 2)
352
+ g_loss = torch.mean(dg**2)
353
+ loss += r_loss + g_loss
354
+ r_losses.append(r_loss.item())
355
+ g_losses.append(g_loss.item())
356
+
357
+ return loss, r_losses, g_losses
358
+
359
+
360
+ def generator_loss(disc_outputs):
361
+ loss = 0
362
+ gen_losses = []
363
+ for dg in disc_outputs:
364
+ l = torch.mean((1 - dg) ** 2)
365
+ gen_losses.append(l)
366
+ loss += l
367
+
368
+ return loss, gen_losses
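A hedged usage sketch tying this module together with `config.py`, `env.py` and `denoiser.py` from the same commit: the `v1` dict, wrapped in `AttrDict`, configures the `Generator`, and the `Denoiser` subtracts the vocoder's bias spectrum. It assumes the Matcha-TTS package is importable as `matcha`; the random mel and the denoiser strength are illustrative values, and untrained weights will of course only produce noise.

```python
import torch

from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.env import AttrDict
from matcha.hifigan.models import Generator

h = AttrDict(v1)                        # attribute access over the plain config dict
generator = Generator(h).eval()
denoiser = Denoiser(generator, mode="zeros")

mel = torch.randn(1, h.num_mels, 100)   # (batch, 80 mel bins, 100 frames)
with torch.no_grad():
    wav = generator(mel).clamp(-1.0, 1.0).squeeze(1)  # (1, 100 * 256) samples at 22.05 kHz
    wav = denoiser(wav, strength=0.00025)             # remove the model bias
print(wav.shape)
```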
third_party/Matcha-TTS/matcha/models/__init__.py ADDED
File without changes
third_party/Matcha-TTS/matcha/models/baselightningmodule.py ADDED
@@ -0,0 +1,209 @@
1
+ """
2
+ This is a base lightning module that can be used to train a model.
3
+ The benefit of this abstraction is that all the logic outside of model definition can be reused for different models.
4
+ """
5
+ import inspect
6
+ from abc import ABC
7
+ from typing import Any, Dict
8
+
9
+ import torch
10
+ from lightning import LightningModule
11
+ from lightning.pytorch.utilities import grad_norm
12
+
13
+ from matcha import utils
14
+ from matcha.utils.utils import plot_tensor
15
+
16
+ log = utils.get_pylogger(__name__)
17
+
18
+
19
+ class BaseLightningClass(LightningModule, ABC):
20
+ def update_data_statistics(self, data_statistics):
21
+ if data_statistics is None:
22
+ data_statistics = {
23
+ "mel_mean": 0.0,
24
+ "mel_std": 1.0,
25
+ }
26
+
27
+ self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
28
+ self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))
29
+
30
+ def configure_optimizers(self) -> Any:
31
+ optimizer = self.hparams.optimizer(params=self.parameters())
32
+ if self.hparams.scheduler not in (None, {}):
33
+ scheduler_args = {}
34
+ # Manage last epoch for exponential schedulers
35
+ if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters:
36
+ if hasattr(self, "ckpt_loaded_epoch"):
37
+ current_epoch = self.ckpt_loaded_epoch - 1
38
+ else:
39
+ current_epoch = -1
40
+
41
+ scheduler_args.update({"optimizer": optimizer})
42
+ scheduler = self.hparams.scheduler.scheduler(**scheduler_args)
43
+ scheduler.last_epoch = current_epoch
44
+ return {
45
+ "optimizer": optimizer,
46
+ "lr_scheduler": {
47
+ "scheduler": scheduler,
48
+ "interval": self.hparams.scheduler.lightning_args.interval,
49
+ "frequency": self.hparams.scheduler.lightning_args.frequency,
50
+ "name": "learning_rate",
51
+ },
52
+ }
53
+
54
+ return {"optimizer": optimizer}
55
+
56
+ def get_losses(self, batch):
57
+ x, x_lengths = batch["x"], batch["x_lengths"]
58
+ y, y_lengths = batch["y"], batch["y_lengths"]
59
+ spks = batch["spks"]
60
+
61
+ dur_loss, prior_loss, diff_loss = self(
62
+ x=x,
63
+ x_lengths=x_lengths,
64
+ y=y,
65
+ y_lengths=y_lengths,
66
+ spks=spks,
67
+ out_size=self.out_size,
68
+ )
69
+ return {
70
+ "dur_loss": dur_loss,
71
+ "prior_loss": prior_loss,
72
+ "diff_loss": diff_loss,
73
+ }
74
+
75
+ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
76
+ self.ckpt_loaded_epoch = checkpoint["epoch"] # pylint: disable=attribute-defined-outside-init
77
+
78
+ def training_step(self, batch: Any, batch_idx: int):
79
+ loss_dict = self.get_losses(batch)
80
+ self.log(
81
+ "step",
82
+ float(self.global_step),
83
+ on_step=True,
84
+ prog_bar=True,
85
+ logger=True,
86
+ sync_dist=True,
87
+ )
88
+
89
+ self.log(
90
+ "sub_loss/train_dur_loss",
91
+ loss_dict["dur_loss"],
92
+ on_step=True,
93
+ on_epoch=True,
94
+ logger=True,
95
+ sync_dist=True,
96
+ )
97
+ self.log(
98
+ "sub_loss/train_prior_loss",
99
+ loss_dict["prior_loss"],
100
+ on_step=True,
101
+ on_epoch=True,
102
+ logger=True,
103
+ sync_dist=True,
104
+ )
105
+ self.log(
106
+ "sub_loss/train_diff_loss",
107
+ loss_dict["diff_loss"],
108
+ on_step=True,
109
+ on_epoch=True,
110
+ logger=True,
111
+ sync_dist=True,
112
+ )
113
+
114
+ total_loss = sum(loss_dict.values())
115
+ self.log(
116
+ "loss/train",
117
+ total_loss,
118
+ on_step=True,
119
+ on_epoch=True,
120
+ logger=True,
121
+ prog_bar=True,
122
+ sync_dist=True,
123
+ )
124
+
125
+ return {"loss": total_loss, "log": loss_dict}
126
+
127
+ def validation_step(self, batch: Any, batch_idx: int):
128
+ loss_dict = self.get_losses(batch)
129
+ self.log(
130
+ "sub_loss/val_dur_loss",
131
+ loss_dict["dur_loss"],
132
+ on_step=True,
133
+ on_epoch=True,
134
+ logger=True,
135
+ sync_dist=True,
136
+ )
137
+ self.log(
138
+ "sub_loss/val_prior_loss",
139
+ loss_dict["prior_loss"],
140
+ on_step=True,
141
+ on_epoch=True,
142
+ logger=True,
143
+ sync_dist=True,
144
+ )
145
+ self.log(
146
+ "sub_loss/val_diff_loss",
147
+ loss_dict["diff_loss"],
148
+ on_step=True,
149
+ on_epoch=True,
150
+ logger=True,
151
+ sync_dist=True,
152
+ )
153
+
154
+ total_loss = sum(loss_dict.values())
155
+ self.log(
156
+ "loss/val",
157
+ total_loss,
158
+ on_step=True,
159
+ on_epoch=True,
160
+ logger=True,
161
+ prog_bar=True,
162
+ sync_dist=True,
163
+ )
164
+
165
+ return total_loss
166
+
167
+ def on_validation_end(self) -> None:
168
+ if self.trainer.is_global_zero:
169
+ one_batch = next(iter(self.trainer.val_dataloaders))
170
+ if self.current_epoch == 0:
171
+ log.debug("Plotting original samples")
172
+ for i in range(2):
173
+ y = one_batch["y"][i].unsqueeze(0).to(self.device)
174
+ self.logger.experiment.add_image(
175
+ f"original/{i}",
176
+ plot_tensor(y.squeeze().cpu()),
177
+ self.current_epoch,
178
+ dataformats="HWC",
179
+ )
180
+
181
+ log.debug("Synthesising...")
182
+ for i in range(2):
183
+ x = one_batch["x"][i].unsqueeze(0).to(self.device)
184
+ x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
185
+ spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
186
+ output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
187
+ y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
188
+ attn = output["attn"]
189
+ self.logger.experiment.add_image(
190
+ f"generated_enc/{i}",
191
+ plot_tensor(y_enc.squeeze().cpu()),
192
+ self.current_epoch,
193
+ dataformats="HWC",
194
+ )
195
+ self.logger.experiment.add_image(
196
+ f"generated_dec/{i}",
197
+ plot_tensor(y_dec.squeeze().cpu()),
198
+ self.current_epoch,
199
+ dataformats="HWC",
200
+ )
201
+ self.logger.experiment.add_image(
202
+ f"alignment/{i}",
203
+ plot_tensor(attn.squeeze().cpu()),
204
+ self.current_epoch,
205
+ dataformats="HWC",
206
+ )
207
+
208
+ def on_before_optimizer_step(self, optimizer):
209
+ self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})
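A minimal sketch of the `optimizer`/`scheduler` hyperparameters that `configure_optimizers()` above expects: a callable that receives `params=...`, and an object exposing `.scheduler` plus `.lightning_args.interval`/`.frequency`. In the Hydra configs these are typically produced via partial instantiation; the plain-Python equivalent below uses illustrative values.

```python
from functools import partial
from types import SimpleNamespace

import torch

optimizer = partial(torch.optim.Adam, lr=1e-4)  # called as optimizer(params=self.parameters())
scheduler = SimpleNamespace(
    scheduler=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.999),
    lightning_args=SimpleNamespace(interval="step", frequency=1),
)
# These would be passed to the model constructor (e.g. MatchaTTS(..., optimizer=optimizer,
# scheduler=scheduler)) and stored by save_hyperparameters().
```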
third_party/Matcha-TTS/matcha/models/components/__init__.py ADDED
File without changes
third_party/Matcha-TTS/matcha/models/components/flow_matching.py ADDED
@@ -0,0 +1,132 @@
1
+ from abc import ABC
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from matcha.models.components.decoder import Decoder
7
+ from matcha.utils.pylogger import get_pylogger
8
+
9
+ log = get_pylogger(__name__)
10
+
11
+
12
+ class BASECFM(torch.nn.Module, ABC):
13
+ def __init__(
14
+ self,
15
+ n_feats,
16
+ cfm_params,
17
+ n_spks=1,
18
+ spk_emb_dim=128,
19
+ ):
20
+ super().__init__()
21
+ self.n_feats = n_feats
22
+ self.n_spks = n_spks
23
+ self.spk_emb_dim = spk_emb_dim
24
+ self.solver = cfm_params.solver
25
+ if hasattr(cfm_params, "sigma_min"):
26
+ self.sigma_min = cfm_params.sigma_min
27
+ else:
28
+ self.sigma_min = 1e-4
29
+
30
+ self.estimator = None
31
+
32
+ @torch.inference_mode()
33
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
34
+ """Forward diffusion
35
+
36
+ Args:
37
+ mu (torch.Tensor): output of encoder
38
+ shape: (batch_size, n_feats, mel_timesteps)
39
+ mask (torch.Tensor): output_mask
40
+ shape: (batch_size, 1, mel_timesteps)
41
+ n_timesteps (int): number of diffusion steps
42
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
43
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
44
+ shape: (batch_size, spk_emb_dim)
45
+ cond: Not used but kept for future purposes
46
+
47
+ Returns:
48
+ sample: generated mel-spectrogram
49
+ shape: (batch_size, n_feats, mel_timesteps)
50
+ """
51
+ z = torch.randn_like(mu) * temperature
52
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
53
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
54
+
55
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
56
+ """
57
+ Fixed-step Euler solver for ODEs.
58
+ Args:
59
+ x (torch.Tensor): random noise
60
+ t_span (torch.Tensor): n_timesteps interpolated
61
+ shape: (n_timesteps + 1,)
62
+ mu (torch.Tensor): output of encoder
63
+ shape: (batch_size, n_feats, mel_timesteps)
64
+ mask (torch.Tensor): output_mask
65
+ shape: (batch_size, 1, mel_timesteps)
66
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
67
+ shape: (batch_size, spk_emb_dim)
68
+ cond: Not used but kept for future purposes
69
+ """
70
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
71
+
72
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
73
+ # Or in future might add like a return_all_steps flag
74
+ sol = []
75
+
76
+ for step in range(1, len(t_span)):
77
+ dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
78
+
79
+ x = x + dt * dphi_dt
80
+ t = t + dt
81
+ sol.append(x)
82
+ if step < len(t_span) - 1:
83
+ dt = t_span[step + 1] - t
84
+
85
+ return sol[-1]
86
+
87
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
88
+ """Computes diffusion loss
89
+
90
+ Args:
91
+ x1 (torch.Tensor): Target
92
+ shape: (batch_size, n_feats, mel_timesteps)
93
+ mask (torch.Tensor): target mask
94
+ shape: (batch_size, 1, mel_timesteps)
95
+ mu (torch.Tensor): output of encoder
96
+ shape: (batch_size, n_feats, mel_timesteps)
97
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
98
+ shape: (batch_size, spk_emb_dim)
99
+
100
+ Returns:
101
+ loss: conditional flow matching loss
102
+ y: conditional flow
103
+ shape: (batch_size, n_feats, mel_timesteps)
104
+ """
105
+ b, _, t = mu.shape
106
+
107
+ # random timestep
108
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
109
+ # sample noise p(x_0)
110
+ z = torch.randn_like(x1)
111
+
112
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
113
+ u = x1 - (1 - self.sigma_min) * z
114
+
115
+ loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
116
+ torch.sum(mask) * u.shape[1]
117
+ )
118
+ return loss, y
119
+
120
+
121
+ class CFM(BASECFM):
122
+ def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
123
+ super().__init__(
124
+ n_feats=in_channels,
125
+ cfm_params=cfm_params,
126
+ n_spks=n_spks,
127
+ spk_emb_dim=spk_emb_dim,
128
+ )
129
+
130
+ in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
131
+ # Just change the architecture of the estimator here
132
+ self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
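For reference, the interpolation and regression target in `compute_loss` above are the conditional flow matching (OT-CFM) objective; the block below is only a restatement of that code, writing $z$ for the sampled noise, $x_1$ for the target mel, $v_\theta$ for the estimator, and $\sigma_{\min}$ for `sigma_min`.

```latex
\begin{aligned}
x_t &= \bigl(1 - (1 - \sigma_{\min})\,t\bigr)\,z + t\,x_1,
      \qquad z \sim \mathcal{N}(0, I),\; t \sim \mathcal{U}(0, 1),\\
u_t &= x_1 - (1 - \sigma_{\min})\,z,\\
\mathcal{L}_{\mathrm{CFM}} &=
      \frac{\bigl\lVert v_\theta(x_t, t \mid \mu, s) - u_t \bigr\rVert_2^2}
           {n_{\text{feats}} \cdot \sum \text{mask}}.
\end{aligned}
```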
third_party/Matcha-TTS/matcha/models/matcha_tts.py ADDED
@@ -0,0 +1,239 @@
1
+ import datetime as dt
2
+ import math
3
+ import random
4
+
5
+ import torch
6
+
7
+ import matcha.utils.monotonic_align as monotonic_align
8
+ from matcha import utils
9
+ from matcha.models.baselightningmodule import BaseLightningClass
10
+ from matcha.models.components.flow_matching import CFM
11
+ from matcha.models.components.text_encoder import TextEncoder
12
+ from matcha.utils.model import (
13
+ denormalize,
14
+ duration_loss,
15
+ fix_len_compatibility,
16
+ generate_path,
17
+ sequence_mask,
18
+ )
19
+
20
+ log = utils.get_pylogger(__name__)
21
+
22
+
23
+ class MatchaTTS(BaseLightningClass): # 🍵
24
+ def __init__(
25
+ self,
26
+ n_vocab,
27
+ n_spks,
28
+ spk_emb_dim,
29
+ n_feats,
30
+ encoder,
31
+ decoder,
32
+ cfm,
33
+ data_statistics,
34
+ out_size,
35
+ optimizer=None,
36
+ scheduler=None,
37
+ prior_loss=True,
38
+ ):
39
+ super().__init__()
40
+
41
+ self.save_hyperparameters(logger=False)
42
+
43
+ self.n_vocab = n_vocab
44
+ self.n_spks = n_spks
45
+ self.spk_emb_dim = spk_emb_dim
46
+ self.n_feats = n_feats
47
+ self.out_size = out_size
48
+ self.prior_loss = prior_loss
49
+
50
+ if n_spks > 1:
51
+ self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
52
+
53
+ self.encoder = TextEncoder(
54
+ encoder.encoder_type,
55
+ encoder.encoder_params,
56
+ encoder.duration_predictor_params,
57
+ n_vocab,
58
+ n_spks,
59
+ spk_emb_dim,
60
+ )
61
+
62
+ self.decoder = CFM(
63
+ in_channels=2 * encoder.encoder_params.n_feats,
64
+ out_channel=encoder.encoder_params.n_feats,
65
+ cfm_params=cfm,
66
+ decoder_params=decoder,
67
+ n_spks=n_spks,
68
+ spk_emb_dim=spk_emb_dim,
69
+ )
70
+
71
+ self.update_data_statistics(data_statistics)
72
+
73
+ @torch.inference_mode()
74
+ def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0):
75
+ """
76
+ Generates mel-spectrogram from text. Returns:
77
+ 1. encoder outputs
78
+ 2. decoder outputs
79
+ 3. generated alignment
80
+
81
+ Args:
82
+ x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
83
+ shape: (batch_size, max_text_length)
84
+ x_lengths (torch.Tensor): lengths of texts in batch.
85
+ shape: (batch_size,)
86
+ n_timesteps (int): number of steps to use for reverse diffusion in decoder.
87
+ temperature (float, optional): controls variance of terminal distribution.
88
+ spks (torch.Tensor, optional): speaker ids.
89
+ shape: (batch_size,)
90
+ length_scale (float, optional): controls speech pace.
91
+ Increase value to slow down generated speech and vice versa.
92
+
93
+ Returns:
94
+ dict: {
95
+ "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
96
+ # Average mel spectrogram generated by the encoder
97
+ "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
98
+ # Refined mel spectrogram improved by the CFM
99
+ "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
100
+ # Alignment map between text and mel spectrogram
101
+ "mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
102
+ # Denormalized mel spectrogram
103
+ "mel_lengths": torch.Tensor, shape: (batch_size,),
104
+ # Lengths of mel spectrograms
105
+ "rtf": float,
106
+ # Real-time factor
107
+ """
108
+ # For RTF computation
109
+ t = dt.datetime.now()
110
+
111
+ if self.n_spks > 1:
112
+ # Get speaker embedding
113
+ spks = self.spk_emb(spks.long())
114
+
115
+ # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
116
+ mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
117
+
118
+ w = torch.exp(logw) * x_mask
119
+ w_ceil = torch.ceil(w) * length_scale
120
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
121
+ y_max_length = y_lengths.max()
122
+ y_max_length_ = fix_len_compatibility(y_max_length)
123
+
124
+ # Using obtained durations `w` construct alignment map `attn`
125
+ y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
126
+ attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
127
+ attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
128
+
129
+ # Align encoded text and get mu_y
130
+ mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
131
+ mu_y = mu_y.transpose(1, 2)
132
+ encoder_outputs = mu_y[:, :, :y_max_length]
133
+
134
+ # Generate sample tracing the probability flow
135
+ decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks)
136
+ decoder_outputs = decoder_outputs[:, :, :y_max_length]
137
+
138
+ t = (dt.datetime.now() - t).total_seconds()
139
+ rtf = t * 22050 / (decoder_outputs.shape[-1] * 256)
140
+
141
+ return {
142
+ "encoder_outputs": encoder_outputs,
143
+ "decoder_outputs": decoder_outputs,
144
+ "attn": attn[:, :, :y_max_length],
145
+ "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std),
146
+ "mel_lengths": y_lengths,
147
+ "rtf": rtf,
148
+ }
149
+
150
+ def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None):
151
+ """
152
+ Computes 3 losses:
153
+ 1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
154
+ 2. prior loss: loss between mel-spectrogram and encoder outputs.
155
+ 3. flow matching loss: loss between mel-spectrogram and decoder outputs.
156
+
157
+ Args:
158
+ x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
159
+ shape: (batch_size, max_text_length)
160
+ x_lengths (torch.Tensor): lengths of texts in batch.
161
+ shape: (batch_size,)
162
+ y (torch.Tensor): batch of corresponding mel-spectrograms.
163
+ shape: (batch_size, n_feats, max_mel_length)
164
+ y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
165
+ shape: (batch_size,)
166
+ out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
167
+ Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
168
+ spks (torch.Tensor, optional): speaker ids.
169
+ shape: (batch_size,)
170
+ """
171
+ if self.n_spks > 1:
172
+ # Get speaker embedding
173
+ spks = self.spk_emb(spks)
174
+
175
+ # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
176
+ mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
177
+ y_max_length = y.shape[-1]
178
+
179
+ y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
180
+ attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
181
+
182
+ # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
183
+ with torch.no_grad():
184
+ const = -0.5 * math.log(2 * math.pi) * self.n_feats
185
+ factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
186
+ y_square = torch.matmul(factor.transpose(1, 2), y**2)
187
+ y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
188
+ mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
189
+ log_prior = y_square - y_mu_double + mu_square + const
190
+
191
+ attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
192
+ attn = attn.detach()
193
+
194
+ # Compute loss between predicted log-scaled durations and those obtained from MAS
195
+ # referred to as prior loss in the paper
196
+ logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
197
+ dur_loss = duration_loss(logw, logw_, x_lengths)
198
+
199
+ # Cut a small segment of mel-spectrogram in order to increase batch size
200
+ # - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it
201
+ # - Do not need this hack for Matcha-TTS, but it works with it as well
202
+ if not isinstance(out_size, type(None)):
203
+ max_offset = (y_lengths - out_size).clamp(0)
204
+ offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
205
+ out_offset = torch.LongTensor(
206
+ [torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges]
207
+ ).to(y_lengths)
208
+ attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device)
209
+ y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device)
210
+
211
+ y_cut_lengths = []
212
+ for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
213
+ y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0)
214
+ y_cut_lengths.append(y_cut_length)
215
+ cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length
216
+ y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
217
+ attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]
218
+
219
+ y_cut_lengths = torch.LongTensor(y_cut_lengths)
220
+ y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask)
221
+
222
+ attn = attn_cut
223
+ y = y_cut
224
+ y_mask = y_cut_mask
225
+
226
+ # Align encoded text with mel-spectrogram and get mu_y segment
227
+ mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
228
+ mu_y = mu_y.transpose(1, 2)
229
+
230
+ # Compute loss of the decoder
231
+ diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
232
+
233
+ if self.prior_loss:
234
+ prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
235
+ prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
236
+ else:
237
+ prior_loss = 0
238
+
239
+ return dur_loss, prior_loss, diff_loss
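One reading aid for the `rtf` value returned by `synthesise` above: with the hard-coded 22 050 Hz sampling rate and 256-sample hop per mel frame, the real-time factor is wall-clock synthesis time divided by the duration of the generated audio.

```latex
\mathrm{RTF} = \frac{t_{\text{wall}}}{T_{\text{audio}}}
             = \frac{t_{\text{wall}} \cdot 22050}{N_{\text{frames}} \cdot 256},
\qquad
T_{\text{audio}} = \frac{N_{\text{frames}} \cdot 256}{22050}\ \text{seconds}.
```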
third_party/Matcha-TTS/matcha/onnx/__init__.py ADDED
File without changes
third_party/Matcha-TTS/matcha/onnx/export.py ADDED
@@ -0,0 +1,181 @@
1
+ import argparse
2
+ import random
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+ from lightning import LightningModule
8
+
9
+ from matcha.cli import VOCODER_URLS, load_matcha, load_vocoder
10
+
11
+ DEFAULT_OPSET = 15
12
+
13
+ SEED = 1234
14
+ random.seed(SEED)
15
+ np.random.seed(SEED)
16
+ torch.manual_seed(SEED)
17
+ torch.cuda.manual_seed(SEED)
18
+ torch.backends.cudnn.deterministic = True
19
+ torch.backends.cudnn.benchmark = False
20
+
21
+
22
+ class MatchaWithVocoder(LightningModule):
23
+ def __init__(self, matcha, vocoder):
24
+ super().__init__()
25
+ self.matcha = matcha
26
+ self.vocoder = vocoder
27
+
28
+ def forward(self, x, x_lengths, scales, spks=None):
29
+ mel, mel_lengths = self.matcha(x, x_lengths, scales, spks)
30
+ wavs = self.vocoder(mel).clamp(-1, 1)
31
+ lengths = mel_lengths * 256
32
+ return wavs.squeeze(1), lengths
33
+
34
+
35
+ def get_exportable_module(matcha, vocoder, n_timesteps):
36
+ """
37
+ Return an appropriate `LightningModule` and output-node names
38
+ based on whether the vocoder is embedded in the final graph
39
+ """
40
+
41
+ def onnx_forward_func(x, x_lengths, scales, spks=None):
42
+ """
43
+ Custom forward function for accepting
44
+ scale parameters packed into a single tensor
45
+ """
46
+ # Unpack the scale parameters from the scales tensor
47
+ temperature = scales[0]
48
+ length_scale = scales[1]
49
+ output = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale)
50
+ return output["mel"], output["mel_lengths"]
51
+
52
+ # Monkey-patch Matcha's forward function
53
+ matcha.forward = onnx_forward_func
54
+
55
+ if vocoder is None:
56
+ model, output_names = matcha, ["mel", "mel_lengths"]
57
+ else:
58
+ model = MatchaWithVocoder(matcha, vocoder)
59
+ output_names = ["wav", "wav_lengths"]
60
+ return model, output_names
61
+
62
+
63
+ def get_inputs(is_multi_speaker):
64
+ """
65
+ Create dummy inputs for tracing
66
+ """
67
+ dummy_input_length = 50
68
+ x = torch.randint(low=0, high=20, size=(1, dummy_input_length), dtype=torch.long)
69
+ x_lengths = torch.LongTensor([dummy_input_length])
70
+
71
+ # Scales
72
+ temperature = 0.667
73
+ length_scale = 1.0
74
+ scales = torch.Tensor([temperature, length_scale])
75
+
76
+ model_inputs = [x, x_lengths, scales]
77
+ input_names = [
78
+ "x",
79
+ "x_lengths",
80
+ "scales",
81
+ ]
82
+
83
+ if is_multi_speaker:
84
+ spks = torch.LongTensor([1])
85
+ model_inputs.append(spks)
86
+ input_names.append("spks")
87
+
88
+ return tuple(model_inputs), input_names
89
+
90
+
91
+ def main():
92
+ parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX")
93
+
94
+ parser.add_argument(
95
+ "checkpoint_path",
96
+ type=str,
97
+ help="Path to the model checkpoint",
98
+ )
99
+ parser.add_argument("output", type=str, help="Path to output `.onnx` file")
100
+ parser.add_argument(
101
+ "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)"
102
+ )
103
+ parser.add_argument(
104
+ "--vocoder-name",
105
+ type=str,
106
+ choices=list(VOCODER_URLS.keys()),
107
+ default=None,
108
+ help="Name of the vocoder to embed in the ONNX graph",
109
+ )
110
+ parser.add_argument(
111
+ "--vocoder-checkpoint-path",
112
+ type=str,
113
+ default=None,
114
+ help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience",
115
+ )
116
+ parser.add_argument("--opset", type=int, default=DEFAULT_OPSET, help="ONNX opset version to use (default 15)")
117
+
118
+ args = parser.parse_args()
119
+
120
+ print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}")
121
+ print(f"Setting n_timesteps to {args.n_timesteps}")
122
+
123
+ checkpoint_path = Path(args.checkpoint_path)
124
+ matcha = load_matcha(checkpoint_path.stem, checkpoint_path, "cpu")
125
+
126
+ if args.vocoder_name or args.vocoder_checkpoint_path:
127
+ assert (
128
+ args.vocoder_name and args.vocoder_checkpoint_path
129
+ ), "Both --vocoder-name and --vocoder-checkpoint-path are required when embedding the vocoder in the ONNX graph."
130
+ vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu")
131
+ else:
132
+ vocoder = None
133
+
134
+ is_multi_speaker = matcha.n_spks > 1
135
+
136
+ dummy_input, input_names = get_inputs(is_multi_speaker)
137
+ model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps)
138
+
139
+ # Set dynamic shape for inputs/outputs
140
+ dynamic_axes = {
141
+ "x": {0: "batch_size", 1: "time"},
142
+ "x_lengths": {0: "batch_size"},
143
+ }
144
+
145
+ if vocoder is None:
146
+ dynamic_axes.update(
147
+ {
148
+ "mel": {0: "batch_size", 2: "time"},
149
+ "mel_lengths": {0: "batch_size"},
150
+ }
151
+ )
152
+ else:
153
+ print("Embedding the vocoder in the ONNX graph")
154
+ dynamic_axes.update(
155
+ {
156
+ "wav": {0: "batch_size", 1: "time"},
157
+ "wav_lengths": {0: "batch_size"},
158
+ }
159
+ )
160
+
161
+ if is_multi_speaker:
162
+ dynamic_axes["spks"] = {0: "batch_size"}
163
+
164
+ # Create the output directory (if not exists)
165
+ Path(args.output).parent.mkdir(parents=True, exist_ok=True)
166
+
167
+ model.to_onnx(
168
+ args.output,
169
+ dummy_input,
170
+ input_names=input_names,
171
+ output_names=output_names,
172
+ dynamic_axes=dynamic_axes,
173
+ opset_version=args.opset,
174
+ export_params=True,
175
+ do_constant_folding=True,
176
+ )
177
+ print(f"[🍵] ONNX model exported to {args.output}")
178
+
179
+
180
+ if __name__ == "__main__":
181
+ main()
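Once exported, the graph can be smoke-tested with onnxruntime using the input names declared above ("x", "x_lengths", "scales", plus "spks" for multi-speaker models). A minimal sketch, assuming a single-speaker model was written to model.onnx and feeding dummy token ids instead of phonemized text:

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

    # Dummy inputs mirroring get_inputs(); real usage feeds phonemized token ids
    x = np.random.randint(low=0, high=20, size=(1, 50), dtype=np.int64)
    x_lengths = np.array([50], dtype=np.int64)
    scales = np.array([0.667, 1.0], dtype=np.float32)  # [temperature, length_scale]

    outputs = sess.run(None, {"x": x, "x_lengths": x_lengths, "scales": scales})
    # Outputs are ("mel", "mel_lengths"), or ("wav", "wav_lengths") if a vocoder was embedded
    print([o.shape for o in outputs])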
third_party/Matcha-TTS/matcha/text/cleaners.py ADDED
@@ -0,0 +1,116 @@
1
+ """ from https://github.com/keithito/tacotron
2
+
3
+ Cleaners are transformations that run over the input text at both training and eval time.
4
+
5
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
6
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
7
+ 1. "english_cleaners" for English text
8
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
9
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
10
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
11
+ the symbols in symbols.py to match your data).
12
+ """
13
+
14
+ import logging
15
+ import re
16
+
17
+ import phonemizer
18
+ import piper_phonemize
19
+ from unidecode import unidecode
20
+
21
+ # To avoid excessive logging, we set the log level of the phonemizer package to CRITICAL
22
+ critical_logger = logging.getLogger("phonemizer")
23
+ critical_logger.setLevel(logging.CRITICAL)
24
+
25
+ # Initializing the phonemizer globally significantly reduces overhead,
26
+ # since the phonemizer is no longer re-initialized on every call.
27
+ # It might be less flexible, but it is much, much faster.
28
+ global_phonemizer = phonemizer.backend.EspeakBackend(
29
+ language="en-us",
30
+ preserve_punctuation=True,
31
+ with_stress=True,
32
+ language_switch="remove-flags",
33
+ logger=critical_logger,
34
+ )
35
+
36
+
37
+ # Regular expression matching whitespace:
38
+ _whitespace_re = re.compile(r"\s+")
39
+
40
+ # List of (regular expression, replacement) pairs for abbreviations:
41
+ _abbreviations = [
42
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
43
+ for x in [
44
+ ("mrs", "misess"),
45
+ ("mr", "mister"),
46
+ ("dr", "doctor"),
47
+ ("st", "saint"),
48
+ ("co", "company"),
49
+ ("jr", "junior"),
50
+ ("maj", "major"),
51
+ ("gen", "general"),
52
+ ("drs", "doctors"),
53
+ ("rev", "reverend"),
54
+ ("lt", "lieutenant"),
55
+ ("hon", "honorable"),
56
+ ("sgt", "sergeant"),
57
+ ("capt", "captain"),
58
+ ("esq", "esquire"),
59
+ ("ltd", "limited"),
60
+ ("col", "colonel"),
61
+ ("ft", "fort"),
62
+ ]
63
+ ]
64
+
65
+
66
+ def expand_abbreviations(text):
67
+ for regex, replacement in _abbreviations:
68
+ text = re.sub(regex, replacement, text)
69
+ return text
70
+
71
+
72
+ def lowercase(text):
73
+ return text.lower()
74
+
75
+
76
+ def collapse_whitespace(text):
77
+ return re.sub(_whitespace_re, " ", text)
78
+
79
+
80
+ def convert_to_ascii(text):
81
+ return unidecode(text)
82
+
83
+
84
+ def basic_cleaners(text):
85
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
86
+ text = lowercase(text)
87
+ text = collapse_whitespace(text)
88
+ return text
89
+
90
+
91
+ def transliteration_cleaners(text):
92
+ """Pipeline for non-English text that transliterates to ASCII."""
93
+ text = convert_to_ascii(text)
94
+ text = lowercase(text)
95
+ text = collapse_whitespace(text)
96
+ return text
97
+
98
+
99
+ def english_cleaners2(text):
100
+ """Pipeline for English text, including abbreviation expansion, with punctuation and stress preserved."""
101
+ text = convert_to_ascii(text)
102
+ text = lowercase(text)
103
+ text = expand_abbreviations(text)
104
+ phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0]
105
+ phonemes = collapse_whitespace(phonemes)
106
+ return phonemes
107
+
108
+
109
+ def english_cleaners_piper(text):
110
+ """Pipeline for English text, including abbreviation expansion, with punctuation and stress preserved."""
111
+ text = convert_to_ascii(text)
112
+ text = lowercase(text)
113
+ text = expand_abbreviations(text)
114
+ phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
115
+ phonemes = collapse_whitespace(phonemes)
116
+ return phonemes
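A rough usage sketch for the cleaners, assuming phonemizer, piper-phonemize and an espeak backend are installed (importing the module already initializes the espeak backend); the example sentence is arbitrary:

    from matcha.text.cleaners import basic_cleaners, english_cleaners2

    # Lowercases and collapses whitespace only; abbreviations stay as-is
    print(basic_cleaners("Dr.  Smith   lives on   Main St."))
    # -> "dr. smith lives on main st."

    # Expands abbreviations and returns an IPA phoneme string;
    # the exact output depends on the installed espeak voice
    print(english_cleaners2("Dr. Smith lives on Main St."))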
third_party/Matcha-TTS/matcha/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
2
+ from matcha.utils.logging_utils import log_hyperparameters
3
+ from matcha.utils.pylogger import get_pylogger
4
+ from matcha.utils.rich_utils import enforce_tags, print_config_tree
5
+ from matcha.utils.utils import extras, get_metric_value, task_wrapper
third_party/Matcha-TTS/matcha/utils/instantiators.py ADDED
@@ -0,0 +1,56 @@
1
+ from typing import List
2
+
3
+ import hydra
4
+ from lightning import Callback
5
+ from lightning.pytorch.loggers import Logger
6
+ from omegaconf import DictConfig
7
+
8
+ from matcha.utils import pylogger
9
+
10
+ log = pylogger.get_pylogger(__name__)
11
+
12
+
13
+ def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
14
+ """Instantiates callbacks from config.
15
+
16
+ :param callbacks_cfg: A DictConfig object containing callback configurations.
17
+ :return: A list of instantiated callbacks.
18
+ """
19
+ callbacks: List[Callback] = []
20
+
21
+ if not callbacks_cfg:
22
+ log.warning("No callback configs found! Skipping...")
23
+ return callbacks
24
+
25
+ if not isinstance(callbacks_cfg, DictConfig):
26
+ raise TypeError("Callbacks config must be a DictConfig!")
27
+
28
+ for _, cb_conf in callbacks_cfg.items():
29
+ if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
30
+ log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access
31
+ callbacks.append(hydra.utils.instantiate(cb_conf))
32
+
33
+ return callbacks
34
+
35
+
36
+ def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
37
+ """Instantiates loggers from config.
38
+
39
+ :param logger_cfg: A DictConfig object containing logger configurations.
40
+ :return: A list of instantiated loggers.
41
+ """
42
+ logger: List[Logger] = []
43
+
44
+ if not logger_cfg:
45
+ log.warning("No logger configs found! Skipping...")
46
+ return logger
47
+
48
+ if not isinstance(logger_cfg, DictConfig):
49
+ raise TypeError("Logger config must be a DictConfig!")
50
+
51
+ for _, lg_conf in logger_cfg.items():
52
+ if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
53
+ log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access
54
+ logger.append(hydra.utils.instantiate(lg_conf))
55
+
56
+ return logger
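A small sketch of how instantiate_callbacks is meant to be driven: a Hydra-style config whose entries carry a _target_ key. The callback settings below are illustrative, not the repo's defaults, and assume the repo's requirements (hydra, lightning, omegaconf) are installed:

    from omegaconf import OmegaConf

    from matcha.utils.instantiators import instantiate_callbacks

    callbacks_cfg = OmegaConf.create(
        {
            "model_checkpoint": {
                "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
                "monitor": "val_loss",  # illustrative metric name
                "save_top_k": 3,
            },
            "lr_monitor": {"_target_": "lightning.pytorch.callbacks.LearningRateMonitor"},
        }
    )

    callbacks = instantiate_callbacks(callbacks_cfg)
    print(callbacks)  # [ModelCheckpoint(...), LearningRateMonitor(...)]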
third_party/Matcha-TTS/matcha/utils/model.py ADDED
@@ -0,0 +1,90 @@
1
+ """ from https://github.com/jaywalnut310/glow-tts """
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ def sequence_mask(length, max_length=None):
8
+ if max_length is None:
9
+ max_length = length.max()
10
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
11
+ return x.unsqueeze(0) < length.unsqueeze(1)
12
+
13
+
14
+ def fix_len_compatibility(length, num_downsamplings_in_unet=2):
15
+ factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
16
+ length = (length / factor).ceil() * factor
17
+ if not torch.onnx.is_in_onnx_export():
18
+ return length.int().item()
19
+ else:
20
+ return length
21
+
22
+
23
+ def convert_pad_shape(pad_shape):
24
+ inverted_shape = pad_shape[::-1]
25
+ pad_shape = [item for sublist in inverted_shape for item in sublist]
26
+ return pad_shape
27
+
28
+
29
+ def generate_path(duration, mask):
30
+ device = duration.device
31
+
32
+ b, t_x, t_y = mask.shape
33
+ cum_duration = torch.cumsum(duration, 1)
34
+ path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
35
+
36
+ cum_duration_flat = cum_duration.view(b * t_x)
37
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
38
+ path = path.view(b, t_x, t_y)
39
+ path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
40
+ path = path * mask
41
+ return path
42
+
43
+
44
+ def duration_loss(logw, logw_, lengths):
45
+ loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
46
+ return loss
47
+
48
+
49
+ def normalize(data, mu, std):
50
+ if not isinstance(mu, (float, int)):
51
+ if isinstance(mu, list):
52
+ mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
53
+ elif isinstance(mu, torch.Tensor):
54
+ mu = mu.to(data.device)
55
+ elif isinstance(mu, np.ndarray):
56
+ mu = torch.from_numpy(mu).to(data.device)
57
+ mu = mu.unsqueeze(-1)
58
+
59
+ if not isinstance(std, (float, int)):
60
+ if isinstance(std, list):
61
+ std = torch.tensor(std, dtype=data.dtype, device=data.device)
62
+ elif isinstance(std, torch.Tensor):
63
+ std = std.to(data.device)
64
+ elif isinstance(std, np.ndarray):
65
+ std = torch.from_numpy(std).to(data.device)
66
+ std = std.unsqueeze(-1)
67
+
68
+ return (data - mu) / std
69
+
70
+
71
+ def denormalize(data, mu, std):
72
+ if not isinstance(mu, (float, int)):
73
+ if isinstance(mu, list):
74
+ mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
75
+ elif isinstance(mu, torch.Tensor):
76
+ mu = mu.to(data.device)
77
+ elif isinstance(mu, np.ndarray):
78
+ mu = torch.from_numpy(mu).to(data.device)
79
+ mu = mu.unsqueeze(-1)
80
+
81
+ if not isinstance(std, (float, int)):
82
+ if isinstance(std, list):
83
+ std = torch.tensor(std, dtype=data.dtype, device=data.device)
84
+ elif isinstance(std, torch.Tensor):
85
+ std = std.to(data.device)
86
+ elif isinstance(std, np.ndarray):
87
+ std = torch.from_numpy(std).to(data.device)
88
+ std = std.unsqueeze(-1)
89
+
90
+ return data * std + mu
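To make the alignment helpers concrete: sequence_mask turns lengths into a boolean mask, and generate_path expands integer durations into the same kind of hard alignment matrix that MAS produces. A toy check, assuming the matcha package and its requirements are importable:

    import torch

    from matcha.utils.model import generate_path, sequence_mask

    duration = torch.tensor([[2, 3, 1]])           # (B, T_text): frames per phoneme
    t_y = int(duration.sum())                      # total number of mel frames
    mask = torch.ones(1, duration.shape[1], t_y)   # (B, T_text, T_mel), everything valid here

    path = generate_path(duration, mask)           # hard alignment, shape (1, 3, 6)
    assert torch.equal(path.sum(-1), duration.float())  # row sums recover the durations
    print(path)

    print(sequence_mask(torch.tensor([2, 4])))     # (2, 4) boolean mask of valid positions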
third_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py ADDED
@@ -0,0 +1,7 @@
1
+ # from distutils.core import setup
2
+ # from Cython.Build import cythonize
3
+ # import numpy
4
+
5
+ # setup(name='monotonic_align',
6
+ # ext_modules=cythonize("core.pyx"),
7
+ # include_dirs=[numpy.get_include()])
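For reference, if the Cython extension ever needed to be built by hand, the uncommented equivalent of this script would look roughly like the sketch below (typically invoked with `python setup.py build_ext --inplace`); it is kept commented out here, presumably because the extension is built through the package's own build setup.

    # Hypothetical uncommented equivalent of the setup script above
    from distutils.core import setup

    import numpy
    from Cython.Build import cythonize

    setup(
        name="monotonic_align",
        ext_modules=cythonize("core.pyx"),
        include_dirs=[numpy.get_include()],
    )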
third_party/Matcha-TTS/matcha/utils/rich_utils.py ADDED
@@ -0,0 +1,101 @@
1
+ from pathlib import Path
2
+ from typing import Sequence
3
+
4
+ import rich
5
+ import rich.syntax
6
+ import rich.tree
7
+ from hydra.core.hydra_config import HydraConfig
8
+ from lightning.pytorch.utilities import rank_zero_only
9
+ from omegaconf import DictConfig, OmegaConf, open_dict
10
+ from rich.prompt import Prompt
11
+
12
+ from matcha.utils import pylogger
13
+
14
+ log = pylogger.get_pylogger(__name__)
15
+
16
+
17
+ @rank_zero_only
18
+ def print_config_tree(
19
+ cfg: DictConfig,
20
+ print_order: Sequence[str] = (
21
+ "data",
22
+ "model",
23
+ "callbacks",
24
+ "logger",
25
+ "trainer",
26
+ "paths",
27
+ "extras",
28
+ ),
29
+ resolve: bool = False,
30
+ save_to_file: bool = False,
31
+ ) -> None:
32
+ """Prints the contents of a DictConfig as a tree structure using the Rich library.
33
+
34
+ :param cfg: A DictConfig composed by Hydra.
35
+ :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
36
+ "callbacks", "logger", "trainer", "paths", "extras")``.
37
+ :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
38
+ :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
39
+ """
40
+ style = "dim"
41
+ tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
42
+
43
+ queue = []
44
+
45
+ # add fields from `print_order` to queue
46
+ for field in print_order:
47
+ _ = (
48
+ queue.append(field)
49
+ if field in cfg
50
+ else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")
51
+ )
52
+
53
+ # add all the other fields to queue (not specified in `print_order`)
54
+ for field in cfg:
55
+ if field not in queue:
56
+ queue.append(field)
57
+
58
+ # generate config tree from queue
59
+ for field in queue:
60
+ branch = tree.add(field, style=style, guide_style=style)
61
+
62
+ config_group = cfg[field]
63
+ if isinstance(config_group, DictConfig):
64
+ branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
65
+ else:
66
+ branch_content = str(config_group)
67
+
68
+ branch.add(rich.syntax.Syntax(branch_content, "yaml"))
69
+
70
+ # print config tree
71
+ rich.print(tree)
72
+
73
+ # save config tree to file
74
+ if save_to_file:
75
+ with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
76
+ rich.print(tree, file=file)
77
+
78
+
79
+ @rank_zero_only
80
+ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
81
+ """Prompts user to input tags from command line if no tags are provided in config.
82
+
83
+ :param cfg: A DictConfig composed by Hydra.
84
+ :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
85
+ """
86
+ if not cfg.get("tags"):
87
+ if "id" in HydraConfig().cfg.hydra.job:
88
+ raise ValueError("Specify tags before launching a multirun!")
89
+
90
+ log.warning("No tags provided in config. Prompting user to input tags...")
91
+ tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
92
+ tags = [t.strip() for t in tags.split(",") if t != ""]
93
+
94
+ with open_dict(cfg):
95
+ cfg.tags = tags
96
+
97
+ log.info(f"Tags: {cfg.tags}")
98
+
99
+ if save_to_file:
100
+ with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
101
+ rich.print(cfg.tags, file=file)
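A quick way to see what print_config_tree produces, assuming a minimal hand-built config with only a few of the expected top-level groups (missing groups are reported as warnings and skipped); nothing is written to disk because save_to_file is False:

    from omegaconf import OmegaConf

    from matcha.utils.rich_utils import print_config_tree

    cfg = OmegaConf.create(
        {
            "data": {"batch_size": 32, "num_workers": 4},  # illustrative values
            "model": {"n_feats": 80},
            "trainer": {"max_epochs": -1, "precision": "16-mixed"},
        }
    )

    print_config_tree(cfg, resolve=False, save_to_file=False)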
third_party/Matcha-TTS/matcha/utils/utils.py ADDED
@@ -0,0 +1,219 @@
1
+ import os
2
+ import sys
3
+ import warnings
4
+ from importlib.util import find_spec
5
+ from pathlib import Path
6
+ from typing import Any, Callable, Dict, Tuple
7
+
8
+ import gdown
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import torch
12
+ import wget
13
+ from omegaconf import DictConfig
14
+
15
+ from matcha.utils import pylogger, rich_utils
16
+
17
+ log = pylogger.get_pylogger(__name__)
18
+
19
+
20
+ def extras(cfg: DictConfig) -> None:
21
+ """Applies optional utilities before the task is started.
22
+
23
+ Utilities:
24
+ - Ignoring python warnings
25
+ - Setting tags from command line
26
+ - Rich config printing
27
+
28
+ :param cfg: A DictConfig object containing the config tree.
29
+ """
30
+ # return if no `extras` config
31
+ if not cfg.get("extras"):
32
+ log.warning("Extras config not found! <cfg.extras=null>")
33
+ return
34
+
35
+ # disable python warnings
36
+ if cfg.extras.get("ignore_warnings"):
37
+ log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
38
+ warnings.filterwarnings("ignore")
39
+
40
+ # prompt user to input tags from command line if none are provided in the config
41
+ if cfg.extras.get("enforce_tags"):
42
+ log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
43
+ rich_utils.enforce_tags(cfg, save_to_file=True)
44
+
45
+ # pretty print config tree using Rich library
46
+ if cfg.extras.get("print_config"):
47
+ log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
48
+ rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
49
+
50
+
51
+ def task_wrapper(task_func: Callable) -> Callable:
52
+ """Optional decorator that controls the failure behavior when executing the task function.
53
+
54
+ This wrapper can be used to:
55
+ - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
56
+ - save the exception to a `.log` file
57
+ - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
58
+ - etc. (adjust depending on your needs)
59
+
60
+ Example:
61
+ ```
62
+ @utils.task_wrapper
63
+ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
64
+ ...
65
+ return metric_dict, object_dict
66
+ ```
67
+
68
+ :param task_func: The task function to be wrapped.
69
+
70
+ :return: The wrapped task function.
71
+ """
72
+
73
+ def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
74
+ # execute the task
75
+ try:
76
+ metric_dict, object_dict = task_func(cfg=cfg)
77
+
78
+ # things to do if exception occurs
79
+ except Exception as ex:
80
+ # save exception to `.log` file
81
+ log.exception("")
82
+
83
+ # some hyperparameter combinations might be invalid or cause out-of-memory errors
84
+ # so when using hparam search plugins like Optuna, you might want to disable
85
+ # raising the below exception to avoid multirun failure
86
+ raise ex
87
+
88
+ # things to always do after either success or exception
89
+ finally:
90
+ # display output dir path in terminal
91
+ log.info(f"Output dir: {cfg.paths.output_dir}")
92
+
93
+ # always close wandb run (even if exception occurs so multirun won't fail)
94
+ if find_spec("wandb"): # check if wandb is installed
95
+ import wandb
96
+
97
+ if wandb.run:
98
+ log.info("Closing wandb!")
99
+ wandb.finish()
100
+
101
+ return metric_dict, object_dict
102
+
103
+ return wrap
104
+
105
+
106
+ def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
107
+ """Safely retrieves value of the metric logged in LightningModule.
108
+
109
+ :param metric_dict: A dict containing metric values.
110
+ :param metric_name: The name of the metric to retrieve.
111
+ :return: The value of the metric.
112
+ """
113
+ if not metric_name:
114
+ log.info("Metric name is None! Skipping metric value retrieval...")
115
+ return None
116
+
117
+ if metric_name not in metric_dict:
118
+ raise ValueError(
119
+ f"Metric value not found! <metric_name={metric_name}>\n"
120
+ "Make sure metric name logged in LightningModule is correct!\n"
121
+ "Make sure `optimized_metric` name in `hparams_search` config is correct!"
122
+ )
123
+
124
+ metric_value = metric_dict[metric_name].item()
125
+ log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
126
+
127
+ return metric_value
128
+
129
+
130
+ def intersperse(lst, item):
131
+ # Adds blank symbol
132
+ result = [item] * (len(lst) * 2 + 1)
133
+ result[1::2] = lst
134
+ return result
135
+
136
+
137
+ def save_figure_to_numpy(fig):
138
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
139
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
140
+ return data
141
+
142
+
143
+ def plot_tensor(tensor):
144
+ plt.style.use("default")
145
+ fig, ax = plt.subplots(figsize=(12, 3))
146
+ im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
147
+ plt.colorbar(im, ax=ax)
148
+ plt.tight_layout()
149
+ fig.canvas.draw()
150
+ data = save_figure_to_numpy(fig)
151
+ plt.close()
152
+ return data
153
+
154
+
155
+ def save_plot(tensor, savepath):
156
+ plt.style.use("default")
157
+ fig, ax = plt.subplots(figsize=(12, 3))
158
+ im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
159
+ plt.colorbar(im, ax=ax)
160
+ plt.tight_layout()
161
+ fig.canvas.draw()
162
+ plt.savefig(savepath)
163
+ plt.close()
164
+
165
+
166
+ def to_numpy(tensor):
167
+ if isinstance(tensor, np.ndarray):
168
+ return tensor
169
+ elif isinstance(tensor, torch.Tensor):
170
+ return tensor.detach().cpu().numpy()
171
+ elif isinstance(tensor, list):
172
+ return np.array(tensor)
173
+ else:
174
+ raise TypeError("Unsupported type for conversion to numpy array")
175
+
176
+
177
+ def get_user_data_dir(appname="matcha_tts"):
178
+ """
179
+ Args:
180
+ appname (str): Name of application
181
+
182
+ Returns:
183
+ Path: path to user data directory
184
+ """
185
+
186
+ MATCHA_HOME = os.environ.get("MATCHA_HOME")
187
+ if MATCHA_HOME is not None:
188
+ ans = Path(MATCHA_HOME).expanduser().resolve(strict=False)
189
+ elif sys.platform == "win32":
190
+ import winreg # pylint: disable=import-outside-toplevel
191
+
192
+ key = winreg.OpenKey(
193
+ winreg.HKEY_CURRENT_USER,
194
+ r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
195
+ )
196
+ dir_, _ = winreg.QueryValueEx(key, "Local AppData")
197
+ ans = Path(dir_).resolve(strict=False)
198
+ elif sys.platform == "darwin":
199
+ ans = Path("~/Library/Application Support/").expanduser()
200
+ else:
201
+ ans = Path.home().joinpath(".local/share")
202
+
203
+ final_path = ans.joinpath(appname)
204
+ final_path.mkdir(parents=True, exist_ok=True)
205
+ return final_path
206
+
207
+
208
+ def assert_model_downloaded(checkpoint_path, url, use_wget=True):
209
+ if Path(checkpoint_path).exists():
210
+ log.debug(f"[+] Model already present at {checkpoint_path}!")
211
+ print(f"[+] Model already present at {checkpoint_path}!")
212
+ return
213
+ log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
214
+ print(f"[-] Model not found at {checkpoint_path}! Will download it")
215
+ checkpoint_path = str(checkpoint_path)
216
+ if not use_wget:
217
+ gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
218
+ else:
219
+ wget.download(url=url, out=checkpoint_path)
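Finally, a short sketch of the smaller helpers in this module, assuming the repo's requirements (matplotlib, gdown, wget) are installed; the token ids and mel values below are made up:

    import torch

    from matcha.utils.utils import get_user_data_dir, intersperse, save_plot, to_numpy

    # intersperse() inserts a blank symbol id (here 0) between and around the tokens
    print(intersperse([5, 9, 13], 0))  # -> [0, 5, 0, 9, 0, 13, 0]

    # save_plot() renders an (n_feats, T) array, e.g. a fake mel-spectrogram, to disk
    fake_mel = torch.randn(80, 120)
    save_plot(to_numpy(fake_mel), "fake_mel.png")

    # Per-user data directory (created on first call)
    print(get_user_data_dir())  # e.g. ~/.local/share/matcha_tts on Linux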