abreza committed

Commit 10e72d3 · 1 Parent(s): c61884d

initial commit

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitignore +2 -0
  2. README.md +2 -2
  3. app.py +33 -0
  4. config.py +18 -0
  5. interface.py +57 -0
  6. pmt2/.gitignore +31 -0
  7. pmt2/LICENSE +14 -0
  8. pmt2/README.md +97 -0
  9. pmt2/encoder/__init__.py +0 -0
  10. pmt2/encoder/audio.py +117 -0
  11. pmt2/encoder/config.py +53 -0
  12. pmt2/encoder/data_objects/__init__.py +2 -0
  13. pmt2/encoder/data_objects/random_cycler.py +37 -0
  14. pmt2/encoder/data_objects/speaker.py +40 -0
  15. pmt2/encoder/data_objects/speaker_batch.py +13 -0
  16. pmt2/encoder/data_objects/speaker_verification_dataset.py +56 -0
  17. pmt2/encoder/data_objects/utterance.py +26 -0
  18. pmt2/encoder/inference.py +179 -0
  19. pmt2/encoder/model.py +135 -0
  20. pmt2/encoder/params_data.py +29 -0
  21. pmt2/encoder/params_model.py +11 -0
  22. pmt2/encoder/preprocess.py +196 -0
  23. pmt2/encoder/train.py +125 -0
  24. pmt2/encoder/visualizations.py +179 -0
  25. pmt2/encoder_preprocess.py +69 -0
  26. pmt2/encoder_train.py +45 -0
  27. pmt2/inference.py +94 -0
  28. pmt2/prepare_data.py +96 -0
  29. pmt2/requirements.txt +0 -0
  30. pmt2/resources/model.JPG +0 -0
  31. pmt2/synthesizer/LICENSE.txt +24 -0
  32. pmt2/synthesizer/__init__.py +1 -0
  33. pmt2/synthesizer/audio.py +206 -0
  34. pmt2/synthesizer/audio_v2(support_hifigan).py +154 -0
  35. pmt2/synthesizer/english utils/__init__.py +45 -0
  36. pmt2/synthesizer/english utils/_cmudict.py +62 -0
  37. pmt2/synthesizer/english utils/cleaners.py +88 -0
  38. pmt2/synthesizer/english utils/numbers.py +69 -0
  39. pmt2/synthesizer/english utils/plot.py +82 -0
  40. pmt2/synthesizer/english utils/symbols.py +17 -0
  41. pmt2/synthesizer/english utils/text.py +75 -0
  42. pmt2/synthesizer/hparams.py +108 -0
  43. pmt2/synthesizer/hparams_new.py +108 -0
  44. pmt2/synthesizer/inference.py +168 -0
  45. pmt2/synthesizer/models/tacotron.py +519 -0
  46. pmt2/synthesizer/persian_utils/__init__.py +45 -0
  47. pmt2/synthesizer/persian_utils/plot.py +82 -0
  48. pmt2/synthesizer/persian_utils/symbols.py +17 -0
  49. pmt2/synthesizer/persian_utils/text.py +40 -0
  50. pmt2/synthesizer/preprocess.py +259 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ saved_models
+ train_nodev_all_vctk_hifigan.v1*
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: Persian Tts Demo
- emoji: 📉
+ emoji: 🏢
  colorFrom: yellow
- colorTo: blue
+ colorTo: green
  sdk: gradio
  sdk_version: 5.22.0
  app_file: app.py
app.py ADDED
@@ -0,0 +1,33 @@
+ import os
+ import warnings
+ from config import models_path, results_path, sample_path
+ from setup import setup_environment
+ from synthesis import load_models
+ from interface import create_interface
+
+ warnings.filterwarnings("ignore")
+
+ def main():
+     os.makedirs(models_path, exist_ok=True)
+     os.makedirs(results_path, exist_ok=True)
+
+     if (not os.path.exists(os.path.join(models_path, 'encoder.pt')) or
+             not os.path.exists(os.path.join(models_path, 'synthesizer.pt')) or
+             not os.path.exists(os.path.join(models_path, 'vocoder_HiFiGAN.pkl')) or
+             not os.path.exists(sample_path)):
+         setup_success = setup_environment()
+         if not setup_success:
+             print("Setup failed. Exiting.")
+             exit(1)
+         print("Setup completed successfully.")
+
+     load_success = load_models()
+     if not load_success:
+         print("Failed to load models. Exiting.")
+         exit(1)
+
+     demo = create_interface()
+     demo.launch()
+
+ if __name__ == "__main__":
+     main()
config.py ADDED
@@ -0,0 +1,18 @@
+ import os
+
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+ models_path = os.path.join(BASE_DIR, 'saved_models', 'final_models')
+ results_path = os.path.join(BASE_DIR, 'results')
+ sample_path = os.path.join(BASE_DIR, 'sample.wav')
+
+ custom_css = """
+ .gradio-container {max-width: 900px !important}
+ .ethical-note {
+     background-color: #f8f9fa;
+     border-left: 4px solid #5c5c5c;
+     padding: 10px;
+     margin: 10px 0;
+     font-size: 0.9em;
+ }
+ footer {text-align: center; margin-top: 20px;}
+ """
interface.py ADDED
@@ -0,0 +1,57 @@
+ import gradio as gr
+ from config import custom_css
+ from synthesis import generate_speech
+
+ def create_interface():
+     with gr.Blocks(title="Persian Text-to-Speech", css=custom_css) as demo:
+         gr.Markdown("# Persian Text-to-Speech with Tacotron2 and HiFiGAN")
+
+         with gr.Row():
+             with gr.Column(scale=2):
+                 text_input = gr.Textbox(
+                     label="Persian Text",
+                     placeholder="مدل تولید گفتار با دادگان نسل مانا",
+                     lines=5
+                 )
+
+                 generate_btn = gr.Button("Generate Speech", variant="primary")
+
+             with gr.Column(scale=2):
+                 audio_output = gr.Audio(label="Generated Speech")
+
+         generate_btn.click(
+             fn=generate_speech,
+             inputs=[text_input],
+             outputs=[audio_output]
+         )
+
+         gr.Examples(
+             examples=[
+                 ["سلام، چطور هستید؟"],
+                 ["ایران سرزمین زیبایی‌ها و افتخارات است."],
+                 ["فناوری هوش مصنوعی به سرعت در حال پیشرفت است."],
+                 ["مدل تولید گفتار با دادگان نسل مانا"]
+             ],
+             inputs=[text_input]
+         )
+
+         gr.Markdown("""
+         ### Acknowledgments
+
+         - [**Nasl-e-Mana**](https://naslemana.com/), the monthly magazine of the blind community of Iran
+         - [ManaTTS Dataset](https://huggingface.co/datasets/MahtaFetrat/Mana-TTS)
+         - [Persian-MultiSpeaker-Tacotron2](https://github.com/MahtaFetrat/Persian-MultiSpeaker-Tacotron2/)
+
+         ### Citation
+
+         ```bibtex
+         @article{fetrat2024manatts,
+           title={ManaTTS Persian: A Recipe for Creating TTS Datasets for Lower-Resource Languages},
+           author={Mahta Fetrat Qharabagh and Zahra Dehghanian and Hamid R. Rabiee},
+           journal={arXiv preprint arXiv:2409.07259},
+           year={2024},
+         }
+         ```
+         """)
+
+     return demo
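Note: the `synthesis` module imported by `app.py` and `interface.py` is not among the 50 files shown in this view. A minimal sketch of the surface it would need to expose, inferred only from the calls above (`load_models()` returning a success flag and `generate_speech(text)` returning something `gr.Audio` can render), might look like the following; the function bodies and the `(sample_rate, waveform)` return convention are assumptions, not the committed implementation.

```python
# Hypothetical sketch of synthesis.py's public surface, inferred from app.py / interface.py.
# Names, bodies and return types are assumptions; the real module is not shown in this view.
import numpy as np

_models_loaded = False

def load_models() -> bool:
    """Load encoder, synthesizer and vocoder; return True on success."""
    global _models_loaded
    _models_loaded = True  # the real code would load the checkpoints from models_path
    return _models_loaded

def generate_speech(text: str):
    """Return audio in a form gr.Audio accepts, e.g. a (sample_rate, np.ndarray) tuple."""
    if not _models_loaded:
        raise RuntimeError("Call load_models() first")
    sample_rate = 22050                                   # assumed synthesizer sample rate
    waveform = np.zeros(sample_rate, dtype=np.float32)    # placeholder: one second of silence
    return sample_rate, waveform
```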
pmt2/.gitignore ADDED
@@ -0,0 +1,31 @@
+ *.pyc
+ *.aux
+ *.log
+ *.out
+ *.synctex.gz
+ *.suo
+ *__pycache__
+ *.idea
+ *.ipynb_checkpoints
+ *.pickle
+ *.npy
+ *.blg
+ *.bbl
+ *.bcf
+ *.toc
+ *.sh
+
+ encoder/saved_models/*
+ synthesizer/saved_models/*
+ vocoder/saved_models/*
+ saved_models/my_run
+ saved_models/train_encoder
+ dataset/*
+ results/best_result
+ vocoder/hifigan/*.pkl
+ vocoder/hifigan2
+ evaluate_vocoder
+ features_check
+ auto_inference.py
+ start_instruction.txt
+ saved_models/final_models
pmt2/LICENSE ADDED
@@ -0,0 +1,14 @@
+ MIT License
+
+ Modified & original work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
+ Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
+ Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
+ Original work Copyright (c) 2015 braindead (https://github.com/braindead)
+ Modified work Copyright (c) 2025 Majid Adibian (https://github.com/Adibian)
+ Modified work Copyright (c) 2025 Mahta Fetrat (https://github.com/MahtaFetrat)
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
pmt2/README.md ADDED
@@ -0,0 +1,97 @@
+ # MultiSpeaker Tacotron2 in Persian Language
+ This repository implements [Transfer Learning from Speaker Verification to
+ Multispeaker Text-To-Speech Synthesis](https://arxiv.org/pdf/1806.04558.pdf) (SV2TTS) for the Persian language. The core codebase is derived from [this repository](https://github.com/Adibian/Persian-MultiSpeaker-Tacotron2), updated here to address deprecated features and to complete the setup for Persian. That codebase is in turn a modification of the original [Real-Time-Voice-Cloning repository](https://github.com/CorentinJ/Real-Time-Voice-Cloning/tree/master) adapted to Persian language requirements.
+
+ <img src="https://github.com/majidAdibian77/persian-SV2TTS/blob/master/resources/model.JPG" width="800">
+
+ ---
+
+ ## Training
+ **1. Character-set definition:**
+
+ Open the `synthesizer/persian_utils/symbols.py` file and update the `_characters` variable to include all the characters that exist in your text files. Most Persian characters and symbols are already included in this variable:
+ ```
+ _characters = "ءابتثجحخدذرزسشصضطظعغفقلمنهويِپچژکگیآۀأؤإئًَُّ!(),-.:;? ̠،…؛؟‌٪#ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_–@+/\u200c"
+ ```
+
+ **2. Data structure:**
+ ```
+ dataset/persian_data/
+     train_data/
+         speaker1/book-1/
+             sample1.txt
+             sample1.wav
+             ...
+         ...
+     test_data/
+         ...
+ ```
+
+ **3. Preprocessing:**
+
+ Audio preprocessing:
+ ```
+ python3 synthesizer_preprocess_audio.py dataset --datasets_name persian_data --subfolders train_data --no_alignments --skip_existing --n_processes 4 --out_dir dataset/train/SV2TTS/synthesizer
+ python3 synthesizer_preprocess_audio.py dataset --datasets_name persian_data --subfolders test_data --no_alignments --skip_existing --n_processes 4 --out_dir dataset/test/SV2TTS/synthesizer
+ ```
+ Embedding preprocessing:
+ ```
+ python3 synthesizer_preprocess_embeds.py dataset/train/SV2TTS/synthesizer
+ python3 synthesizer_preprocess_embeds.py dataset/test/SV2TTS/synthesizer
+ ```
+
+ **4. Train synthesizer:**
+ ```
+ python3 synthesizer_train.py my_run dataset/train/SV2TTS/synthesizer
+ ```
+
+ ## Inference
+
+ To generate a wav file, place all trained models in the `saved_models/final_models` directory. If you haven't trained the speaker encoder or vocoder models, you can use the pretrained models from `saved_models/default`. The required models are `encoder.pt`, your latest synthesizer checkpoint (e.g. `synthesizer_000300.pt`), and a vocoder, as described below.
+
+ ### Using WavRNN as Vocoder
+
+ ```
+ python3 inference.py --vocoder "WavRNN" --text "یک نمونه از خروجی" --ref_wav_path "/path/to/sample/reference.wav" --test_name "test1"
+ ```
+
+ ### Using HiFiGAN as Vocoder (Recommended)
+ WavRNN is an older vocoder; to use HiFiGAN instead, first download a pretrained (English) model.
+ 1. **Install Parallel WaveGAN**
+ ```
+ pip install parallel_wavegan
+ ```
+ 2. **Download Pretrained HiFiGAN Model**
+ ```
+ from parallel_wavegan.utils import download_pretrained_model
+ download_pretrained_model("vctk_hifigan.v1", "saved_models/final_models/vocoder_HiFiGAN")
+ ```
+ 3. **Run Inference with HiFiGAN**
+ ```
+ python3 inference.py --vocoder "HiFiGAN" --text "یک نمونه از خروجی" --ref_wav_path "/path/to/sample/reference.wav" --test_name "test1"
+ ```
+
+ ## ManaTTS-Trained Model
+
+ This architecture has been used to train a Persian Text-to-Speech (TTS) model on the [**ManaTTS dataset**](https://huggingface.co/datasets/MahtaFetrat/Mana-TTS), the largest publicly available single-speaker Persian corpus. The trained model weights and detailed inference instructions can be found in the following repositories:
+
+ - [Hugging Face Repository](https://huggingface.co/MahtaFetrat/Persian-Tacotron2-on-ManaTTS)
+ - [GitHub Repository](https://github.com/MahtaFetrat/ManaTTS-Persian-Tacotron2-Model)
+
+ ## References
+ - [Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis](https://arxiv.org/pdf/1806.04558.pdf), Ye Jia *et al.*
+ - [Real-Time-Voice-Cloning repository](https://github.com/CorentinJ/Real-Time-Voice-Cloning/tree/master)
+ - [ParallelWaveGAN repository](https://github.com/kan-bayashi/ParallelWaveGAN)
+ - [Persian-MultiSpeaker-Tacotron2](https://github.com/Adibian/Persian-MultiSpeaker-Tacotron2)
+
+ ## License
+ This project is based on [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning),
+ which is licensed under the MIT License.
+ ```
+ Modified & original work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
+ Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
+ Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
+ Original work Copyright (c) 2015 braindead (https://github.com/braindead)
+ Modified work Copyright (c) 2025 Majid Adibian (https://github.com/Adibian)
+ Modified work Copyright (c) 2025 Mahta Fetrat (https://github.com/MahtaFetrat)
+ ```
+
pmt2/encoder/__init__.py ADDED
File without changes
pmt2/encoder/audio.py ADDED
@@ -0,0 +1,117 @@
+ from scipy.ndimage.morphology import binary_dilation
+ from encoder.params_data import *
+ from pathlib import Path
+ from typing import Optional, Union
+ from warnings import warn
+ import numpy as np
+ import librosa
+ import struct
+
+ try:
+     import webrtcvad
+ except:
+     warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
+     webrtcvad = None
+
+ int16_max = (2 ** 15) - 1
+
+
+ def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
+                    source_sr: Optional[int] = None,
+                    normalize: Optional[bool] = True,
+                    trim_silence: Optional[bool] = True):
+     """
+     Applies the preprocessing operations used in training the Speaker Encoder to a waveform
+     either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
+
+     :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
+     just .wav), either the waveform as a numpy array of floats.
+     :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
+     preprocessing. After preprocessing, the waveform's sampling rate will match the data
+     hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
+     this argument will be ignored.
+     """
+     # Load the wav from disk if needed
+     if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
+         wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
+     else:
+         wav = fpath_or_wav
+
+     # Resample the wav if needed
+     if source_sr is not None and source_sr != sampling_rate:
+         wav = librosa.resample(wav, source_sr, sampling_rate)
+
+     # Apply the preprocessing: normalize volume and shorten long silences
+     if normalize:
+         wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
+     if webrtcvad and trim_silence:
+         wav = trim_long_silences(wav)
+
+     return wav
+
+
+ def wav_to_mel_spectrogram(wav):
+     """
+     Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
+     Note: this not a log-mel spectrogram.
+     """
+     frames = librosa.feature.melspectrogram(
+         y=wav,
+         sr=sampling_rate,
+         n_fft=int(sampling_rate * mel_window_length / 1000),
+         hop_length=int(sampling_rate * mel_window_step / 1000),
+         n_mels=mel_n_channels
+     )
+     return frames.astype(np.float32).T
+
+
+ def trim_long_silences(wav):
+     """
+     Ensures that segments without voice in the waveform remain no longer than a
+     threshold determined by the VAD parameters in params.py.
+
+     :param wav: the raw waveform as a numpy array of floats
+     :return: the same waveform with silences trimmed away (length <= original wav length)
+     """
+     # Compute the voice detection window size
+     samples_per_window = (vad_window_length * sampling_rate) // 1000
+
+     # Trim the end of the audio to have a multiple of the window size
+     wav = wav[:len(wav) - (len(wav) % samples_per_window)]
+
+     # Convert the float waveform to 16-bit mono PCM
+     pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
+
+     # Perform voice activation detection
+     voice_flags = []
+     vad = webrtcvad.Vad(mode=3)
+     for window_start in range(0, len(wav), samples_per_window):
+         window_end = window_start + samples_per_window
+         voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
+                                          sample_rate=sampling_rate))
+     voice_flags = np.array(voice_flags)
+
+     # Smooth the voice detection with a moving average
+     def moving_average(array, width):
+         array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
+         ret = np.cumsum(array_padded, dtype=float)
+         ret[width:] = ret[width:] - ret[:-width]
+         return ret[width - 1:] / width
+
+     audio_mask = moving_average(voice_flags, vad_moving_average_width)
+     audio_mask = np.round(audio_mask).astype(bool)
+
+     # Dilate the voiced regions
+     audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
+     audio_mask = np.repeat(audio_mask, samples_per_window)
+
+     return wav[audio_mask == True]
+
+
+ def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
+     if increase_only and decrease_only:
+         raise ValueError("Both increase only and decrease only are set")
+     dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
+     if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
+         return wav
+     return wav * (10 ** (dBFS_change / 20))
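A brief usage sketch of the two entry points above (the file path is a placeholder): `preprocess_wav` resamples to the 16 kHz `sampling_rate` from `params_data.py`, normalizes volume and trims long silences, and `wav_to_mel_spectrogram` then yields 40-channel mel frames.

```python
# Usage sketch for encoder/audio.py (the wav path is a placeholder).
from encoder import audio

wav = audio.preprocess_wav("some_utterance.wav")  # load, resample to 16 kHz, normalize, trim silences
mel = audio.wav_to_mel_spectrogram(wav)           # shape: (n_frames, mel_n_channels=40)
print(wav.shape, mel.shape)
```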
pmt2/encoder/config.py ADDED
@@ -0,0 +1,53 @@
+ persian_datasets = {
+     "train": {
+         "data": ["persian_data/train_data"]
+     },
+     "test": {
+         "data": ["persian_data/test_data"]
+     }
+ }
+ librispeech_datasets = {
+     "train": {
+         "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
+         "other": ["LibriSpeech/train-other-500"]
+     },
+     "test": {
+         "clean": ["LibriSpeech/test-clean"],
+         "other": ["LibriSpeech/test-other"]
+     },
+     "dev": {
+         "clean": ["LibriSpeech/dev-clean"],
+         "other": ["LibriSpeech/dev-other"]
+     },
+ }
+ libritts_datasets = {
+     "train": {
+         "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
+         "other": ["LibriTTS/train-other-500"]
+     },
+     "test": {
+         "clean": ["LibriTTS/test-clean"],
+         "other": ["LibriTTS/test-other"]
+     },
+     "dev": {
+         "clean": ["LibriTTS/dev-clean"],
+         "other": ["LibriTTS/dev-other"]
+     },
+ }
+ voxceleb_datasets = {
+     "voxceleb1": {
+         "train": ["VoxCeleb1/wav"],
+         "test": ["VoxCeleb1/test_wav"]
+     },
+     "voxceleb2": {
+         "train": ["VoxCeleb2/dev/aac"],
+         "test": ["VoxCeleb2/test_wav"]
+     }
+ }
+
+ other_datasets = [
+     "LJSpeech-1.1",
+     "VCTK-Corpus/wav48",
+ ]
+
+ anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
pmt2/encoder/data_objects/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
pmt2/encoder/data_objects/random_cycler.py ADDED
@@ -0,0 +1,37 @@
+ import random
+
+ class RandomCycler:
+     """
+     Creates an internal copy of a sequence and allows access to its items in a constrained random
+     order. For a source sequence of n items and one or several consecutive queries of a total
+     of m items, the following guarantees hold (one implies the other):
+         - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
+         - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
+     """
+
+     def __init__(self, source):
+         if len(source) == 0:
+             raise Exception("Can't create RandomCycler from an empty collection")
+         self.all_items = list(source)
+         self.next_items = []
+
+     def sample(self, count: int):
+         shuffle = lambda l: random.sample(l, len(l))
+
+         out = []
+         while count > 0:
+             if count >= len(self.all_items):
+                 out.extend(shuffle(list(self.all_items)))
+                 count -= len(self.all_items)
+                 continue
+             n = min(count, len(self.next_items))
+             out.extend(self.next_items[:n])
+             count -= n
+             self.next_items = self.next_items[n:]
+             if len(self.next_items) == 0:
+                 self.next_items = shuffle(list(self.all_items))
+         return out
+
+     def __next__(self):
+         return self.sample(1)[0]
+
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.data_objects.random_cycler import RandomCycler
2
+ from encoder.data_objects.utterance import Utterance
3
+ from pathlib import Path
4
+
5
+ # Contains the set of utterances of a single speaker
6
+ class Speaker:
7
+ def __init__(self, root: Path):
8
+ self.root = root
9
+ self.name = root.name
10
+ self.utterances = None
11
+ self.utterance_cycler = None
12
+
13
+ def _load_utterances(self):
14
+ with self.root.joinpath("_sources.txt").open("r") as sources_file:
15
+ sources = [l.split(",") for l in sources_file]
16
+ sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
17
+ self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
18
+ self.utterance_cycler = RandomCycler(self.utterances)
19
+
20
+ def random_partial(self, count, n_frames):
21
+ """
22
+ Samples a batch of <count> unique partial utterances from the disk in a way that all
23
+ utterances come up at least once every two cycles and in a random order every time.
24
+
25
+ :param count: The number of partial utterances to sample from the set of utterances from
26
+ that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
27
+ the number of utterances available.
28
+ :param n_frames: The number of frames in the partial utterance.
29
+ :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
30
+ frames are the frames of the partial utterances and range is the range of the partial
31
+ utterance with regard to the complete utterance.
32
+ """
33
+ if self.utterances is None:
34
+ self._load_utterances()
35
+
36
+ utterances = self.utterance_cycler.sample(count)
37
+
38
+ a = [(u,) + u.random_partial(n_frames) for u in utterances]
39
+
40
+ return a
pmt2/encoder/data_objects/speaker_batch.py ADDED
@@ -0,0 +1,13 @@
+ import numpy as np
+ from typing import List
+ from encoder.data_objects.speaker import Speaker
+
+
+ class SpeakerBatch:
+     def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
+         self.speakers = speakers
+         self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
+
+         # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
+         # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
+         self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
pmt2/encoder/data_objects/speaker_verification_dataset.py ADDED
@@ -0,0 +1,56 @@
+ from encoder.data_objects.random_cycler import RandomCycler
+ from encoder.data_objects.speaker_batch import SpeakerBatch
+ from encoder.data_objects.speaker import Speaker
+ from encoder.params_data import partials_n_frames
+ from torch.utils.data import Dataset, DataLoader
+ from pathlib import Path
+
+ # TODO: improve with a pool of speakers for data efficiency
+
+ class SpeakerVerificationDataset(Dataset):
+     def __init__(self, datasets_root: Path):
+         self.root = datasets_root
+         speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
+         if len(speaker_dirs) == 0:
+             raise Exception("No speakers found. Make sure you are pointing to the directory "
+                             "containing all preprocessed speaker directories.")
+         self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
+         self.speaker_cycler = RandomCycler(self.speakers)
+
+     def __len__(self):
+         return int(1e10)
+
+     def __getitem__(self, index):
+         return next(self.speaker_cycler)
+
+     def get_logs(self):
+         log_string = ""
+         for log_fpath in self.root.glob("*.txt"):
+             with log_fpath.open("r") as log_file:
+                 log_string += "".join(log_file.readlines())
+         return log_string
+
+
+ class SpeakerVerificationDataLoader(DataLoader):
+     def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
+                  batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
+                  worker_init_fn=None):
+         self.utterances_per_speaker = utterances_per_speaker
+
+         super().__init__(
+             dataset=dataset,
+             batch_size=speakers_per_batch,
+             shuffle=False,
+             sampler=sampler,
+             batch_sampler=batch_sampler,
+             num_workers=num_workers,
+             collate_fn=self.collate,
+             pin_memory=pin_memory,
+             drop_last=False,
+             timeout=timeout,
+             worker_init_fn=worker_init_fn
+         )
+
+     def collate(self, speakers):
+         return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
+
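A minimal sketch of wiring the dataset and loader above together, assuming encoder preprocessing has already produced per-speaker directories (the path is a placeholder); each yielded batch is a `SpeakerBatch` whose `.data` has shape `(speakers_per_batch * utterances_per_speaker, partials_n_frames, mel_n_channels)`.

```python
# Sketch: iterate one batch from the speaker-verification data pipeline (path is a placeholder).
from pathlib import Path
from encoder.data_objects import SpeakerVerificationDataset, SpeakerVerificationDataLoader

dataset = SpeakerVerificationDataset(Path("dataset/SV2TTS/encoder"))
loader = SpeakerVerificationDataLoader(dataset, speakers_per_batch=4, utterances_per_speaker=5)
batch = next(iter(loader))
print(batch.data.shape)   # (4 * 5, 160, 40) with the defaults in params_data.py
```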
pmt2/encoder/data_objects/utterance.py ADDED
@@ -0,0 +1,26 @@
+ import numpy as np
+
+
+ class Utterance:
+     def __init__(self, frames_fpath, wave_fpath):
+         self.frames_fpath = frames_fpath
+         self.wave_fpath = wave_fpath
+
+     def get_frames(self):
+         return np.load(self.frames_fpath)
+
+     def random_partial(self, n_frames):
+         """
+         Crops the frames into a partial utterance of n_frames
+
+         :param n_frames: The number of frames of the partial utterance
+         :return: the partial utterance frames and a tuple indicating the start and end of the
+         partial utterance in the complete utterance.
+         """
+         frames = self.get_frames()
+         if frames.shape[0] == n_frames:
+             start = 0
+         else:
+             start = np.random.randint(0, frames.shape[0] - n_frames)
+         end = start + n_frames
+         return frames[start:end], (start, end)
pmt2/encoder/inference.py ADDED
@@ -0,0 +1,179 @@
+ from encoder.params_data import *
+ from encoder.model import SpeakerEncoder
+ from encoder.audio import preprocess_wav  # We want to expose this function from here
+ from matplotlib import cm
+ from encoder import audio
+ from pathlib import Path
+ import numpy as np
+ import torch
+
+ _model = None  # type: SpeakerEncoder
+ _device = None  # type: torch.device
+
+
+ def load_model(weights_fpath: Path, device=None):
+     """
+     Loads the model in memory. If this function is not explicitely called, it will be run on the
+     first call to embed_frames() with the default weights file.
+
+     :param weights_fpath: the path to saved model weights.
+     :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
+     model will be loaded and will run on this device. Outputs will however always be on the cpu.
+     If None, will default to your GPU if it"s available, otherwise your CPU.
+     """
+     # TODO: I think the slow loading of the encoder might have something to do with the device it
+     # was saved on. Worth investigating.
+     global _model, _device
+     if device is None:
+         _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     elif isinstance(device, str):
+         _device = torch.device(device)
+     _model = SpeakerEncoder(_device, torch.device("cpu"))
+     checkpoint = torch.load(weights_fpath, _device)
+     _model.load_state_dict(checkpoint["model_state"])
+     _model.eval()
+     # print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
+     print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath, checkpoint["step"]))
+
+
+ def is_loaded():
+     return _model is not None
+
+
+ def embed_frames_batch(frames_batch):
+     """
+     Computes embeddings for a batch of mel spectrogram.
+
+     :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
+     (batch_size, n_frames, n_channels)
+     :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
+     """
+     if _model is None:
+         raise Exception("Model was not loaded. Call load_model() before inference.")
+
+     frames = torch.from_numpy(frames_batch).to(_device)
+     embed = _model.forward(frames).detach().cpu().numpy()
+     return embed
+
+
+ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
+                            min_pad_coverage=0.75, overlap=0.5):
+     """
+     Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
+     partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
+     spectrogram slices are returned, so as to make each partial utterance waveform correspond to
+     its spectrogram. This function assumes that the mel spectrogram parameters used are those
+     defined in params_data.py.
+
+     The returned ranges may be indexing further than the length of the waveform. It is
+     recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
+
+     :param n_samples: the number of samples in the waveform
+     :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
+     utterance
+     :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
+     enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
+     then the last partial utterance will be considered, as if we padded the audio. Otherwise,
+     it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
+     utterance, this parameter is ignored so that the function always returns at least 1 slice.
+     :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
+     utterances are entirely disjoint.
+     :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
+     respectively the waveform and the mel spectrogram with these slices to obtain the partial
+     utterances.
+     """
+     assert 0 <= overlap < 1
+     assert 0 < min_pad_coverage <= 1
+
+     samples_per_frame = int((sampling_rate * mel_window_step / 1000))
+     n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
+     frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
+
+     # Compute the slices
+     wav_slices, mel_slices = [], []
+     steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
+     for i in range(0, steps, frame_step):
+         mel_range = np.array([i, i + partial_utterance_n_frames])
+         wav_range = mel_range * samples_per_frame
+         mel_slices.append(slice(*mel_range))
+         wav_slices.append(slice(*wav_range))
+
+     # Evaluate whether extra padding is warranted or not
+     last_wav_range = wav_slices[-1]
+     coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
+     if coverage < min_pad_coverage and len(mel_slices) > 1:
+         mel_slices = mel_slices[:-1]
+         wav_slices = wav_slices[:-1]
+
+     return wav_slices, mel_slices
+
+
+ def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
+     """
+     Computes an embedding for a single utterance.
+
+     # TODO: handle multiple wavs to benefit from batching on GPU
+     :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
+     :param using_partials: if True, then the utterance is split in partial utterances of
+     <partial_utterance_n_frames> frames and the utterance embedding is computed from their
+     normalized average. If False, the utterance is instead computed from feeding the entire
+     spectogram to the network.
+     :param return_partials: if True, the partial embeddings will also be returned along with the
+     wav slices that correspond to the partial embeddings.
+     :param kwargs: additional arguments to compute_partial_splits()
+     :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
+     <return_partials> is True, the partial utterances as a numpy array of float32 of shape
+     (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
+     returned. If <using_partials> is simultaneously set to False, both these values will be None
+     instead.
+     """
+     # Process the entire utterance if not using partials
+     if not using_partials:
+         frames = audio.wav_to_mel_spectrogram(wav)
+         embed = embed_frames_batch(frames[None, ...])[0]
+         if return_partials:
+             return embed, None, None
+         return embed
+
+     # Compute where to split the utterance into partials and pad if necessary
+     wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
+     max_wave_length = wave_slices[-1].stop
+     if max_wave_length >= len(wav):
+         wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
+
+     # Split the utterance into partials
+     frames = audio.wav_to_mel_spectrogram(wav)
+     frames_batch = np.array([frames[s] for s in mel_slices])
+     partial_embeds = embed_frames_batch(frames_batch)
+
+     # Compute the utterance embedding from the partial embeddings
+     raw_embed = np.mean(partial_embeds, axis=0)
+     embed = raw_embed / np.linalg.norm(raw_embed, 2)
+
+     if return_partials:
+         return embed, partial_embeds, wave_slices
+     return embed
+
+
+ def embed_speaker(wavs, **kwargs):
+     raise NotImplemented()
+
+
+ def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
+     import matplotlib.pyplot as plt
+     if ax is None:
+         ax = plt.gca()
+
+     if shape is None:
+         height = int(np.sqrt(len(embed)))
+         shape = (height, -1)
+     embed = embed.reshape(shape)
+
+     cmap = cm.get_cmap()
+     mappable = ax.imshow(embed, cmap=cmap)
+     cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
+     sm = cm.ScalarMappable(cmap=cmap)
+     sm.set_clim(*color_range)
+
+     ax.set_xticks([]), ax.set_yticks([])
+     ax.set_title(title)
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.params_model import *
2
+ from encoder.params_data import *
3
+ from scipy.interpolate import interp1d
4
+ from sklearn.metrics import roc_curve
5
+ from torch.nn.utils import clip_grad_norm_
6
+ from scipy.optimize import brentq
7
+ from torch import nn
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class SpeakerEncoder(nn.Module):
13
+ def __init__(self, device, loss_device):
14
+ super().__init__()
15
+ self.loss_device = loss_device
16
+
17
+ # Network defition
18
+ self.lstm = nn.LSTM(input_size=mel_n_channels,
19
+ hidden_size=model_hidden_size,
20
+ num_layers=model_num_layers,
21
+ batch_first=True).to(device)
22
+ self.linear = nn.Linear(in_features=model_hidden_size,
23
+ out_features=model_embedding_size).to(device)
24
+ self.relu = torch.nn.ReLU().to(device)
25
+
26
+ # Cosine similarity scaling (with fixed initial parameter values)
27
+ self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
28
+ self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
29
+
30
+ # Loss
31
+ self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
32
+
33
+ def do_gradient_ops(self):
34
+ # Gradient scale
35
+ self.similarity_weight.grad *= 0.01
36
+ self.similarity_bias.grad *= 0.01
37
+
38
+ # Gradient clipping
39
+ clip_grad_norm_(self.parameters(), 3, norm_type=2)
40
+
41
+ def forward(self, utterances, hidden_init=None):
42
+ """
43
+ Computes the embeddings of a batch of utterance spectrograms.
44
+
45
+ :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
46
+ (batch_size, n_frames, n_channels)
47
+ :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
48
+ batch_size, hidden_size). Will default to a tensor of zeros if None.
49
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size)
50
+ """
51
+ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
52
+ # and the final cell state.
53
+ out, (hidden, cell) = self.lstm(utterances, hidden_init)
54
+
55
+ # We take only the hidden state of the last layer
56
+ embeds_raw = self.relu(self.linear(hidden[-1]))
57
+
58
+ # L2-normalize it
59
+ embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)
60
+
61
+ return embeds
62
+
63
+ def similarity_matrix(self, embeds):
64
+ """
65
+ Computes the similarity matrix according the section 2.1 of GE2E.
66
+
67
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
68
+ utterances_per_speaker, embedding_size)
69
+ :return: the similarity matrix as a tensor of shape (speakers_per_batch,
70
+ utterances_per_speaker, speakers_per_batch)
71
+ """
72
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
73
+
74
+ # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
75
+ centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
76
+ centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)
77
+
78
+ # Exclusive centroids (1 per utterance)
79
+ centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
80
+ centroids_excl /= (utterances_per_speaker - 1)
81
+ centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)
82
+
83
+ # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
84
+ # product of these vectors (which is just an element-wise multiplication reduced by a sum).
85
+ # We vectorize the computation for efficiency.
86
+ sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
87
+ speakers_per_batch).to(self.loss_device)
88
+ mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
89
+ for j in range(speakers_per_batch):
90
+ mask = np.where(mask_matrix[j])[0]
91
+ sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
92
+ sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
93
+
94
+ ## Even more vectorized version (slower maybe because of transpose)
95
+ # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
96
+ # ).to(self.loss_device)
97
+ # eye = np.eye(speakers_per_batch, dtype=np.int)
98
+ # mask = np.where(1 - eye)
99
+ # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
100
+ # mask = np.where(eye)
101
+ # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
102
+ # sim_matrix2 = sim_matrix2.transpose(1, 2)
103
+
104
+ sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
105
+ return sim_matrix
106
+
107
+ def loss(self, embeds):
108
+ """
109
+ Computes the softmax loss according the section 2.1 of GE2E.
110
+
111
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
112
+ utterances_per_speaker, embedding_size)
113
+ :return: the loss and the EER for this batch of embeddings.
114
+ """
115
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
116
+
117
+ # Loss
118
+ sim_matrix = self.similarity_matrix(embeds)
119
+ sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
120
+ speakers_per_batch))
121
+ ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
122
+ target = torch.from_numpy(ground_truth).long().to(self.loss_device)
123
+ loss = self.loss_fn(sim_matrix, target)
124
+
125
+ # EER (not backpropagated)
126
+ with torch.no_grad():
127
+ inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
128
+ labels = np.array([inv_argmax(i) for i in ground_truth])
129
+ preds = sim_matrix.detach().cpu().numpy()
130
+
131
+ # Snippet from https://yangcha.github.io/EER-ROC/
132
+ fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
133
+ eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
134
+
135
+ return loss, eer
pmt2/encoder/params_data.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Mel-filterbank
3
+ mel_window_length = 25 # In milliseconds
4
+ mel_window_step = 10 # In milliseconds
5
+ mel_n_channels = 40
6
+
7
+
8
+ ## Audio
9
+ sampling_rate = 16000
10
+ # Number of spectrogram frames in a partial utterance
11
+ partials_n_frames = 160 # 1600 ms
12
+ # Number of spectrogram frames at inference
13
+ inference_n_frames = 80 # 800 ms
14
+
15
+
16
+ ## Voice Activation Detection
17
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
18
+ # This sets the granularity of the VAD. Should not need to be changed.
19
+ vad_window_length = 30 # In milliseconds
20
+ # Number of frames to average together when performing the moving average smoothing.
21
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
22
+ vad_moving_average_width = 8
23
+ # Maximum number of consecutive silent frames a segment can have.
24
+ vad_max_silence_length = 6
25
+
26
+
27
+ ## Audio volume normalization
28
+ audio_norm_target_dBFS = -30
29
+
pmt2/encoder/params_model.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Model parameters
3
+ model_hidden_size = 256
4
+ model_embedding_size = 256
5
+ model_num_layers = 3
6
+
7
+
8
+ ## Training parameters
9
+ learning_rate_init = 1e-4
10
+ speakers_per_batch = 64
11
+ utterances_per_speaker = 10
pmt2/encoder/preprocess.py ADDED
@@ -0,0 +1,196 @@
+ from datetime import datetime
+ from functools import partial
+ from multiprocessing import Pool
+ from pathlib import Path
+
+ import numpy as np
+ from tqdm import tqdm
+
+ from encoder import audio
+ from encoder.config import librispeech_datasets, anglophone_nationalites, persian_datasets
+ from encoder.params_data import *
+
+
+ _AUDIO_EXTENSIONS = ("wav", "flac", "m4a", "mp3")
+
+ class DatasetLog:
+     """
+     Registers metadata about the dataset in a text file.
+     """
+     def __init__(self, root, name):
+         self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
+         self.sample_data = dict()
+
+         start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
+         self.write_line("Creating dataset %s on %s" % (name, start_time))
+         self.write_line("-----")
+         self._log_params()
+
+     def _log_params(self):
+         from encoder import params_data
+         self.write_line("Parameter values:")
+         for param_name in (p for p in dir(params_data) if not p.startswith("__")):
+             value = getattr(params_data, param_name)
+             self.write_line("\t%s: %s" % (param_name, value))
+         self.write_line("-----")
+
+     def write_line(self, line):
+         self.text_file.write("%s\n" % line)
+
+     def add_sample(self, **kwargs):
+         for param_name, value in kwargs.items():
+             if not param_name in self.sample_data:
+                 self.sample_data[param_name] = []
+             self.sample_data[param_name].append(value)
+
+     def finalize(self):
+         self.write_line("Statistics:")
+         for param_name, values in self.sample_data.items():
+             self.write_line("\t%s:" % param_name)
+             self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
+             self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
+         self.write_line("-----")
+         end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
+         self.write_line("Finished on %s" % end_time)
+         self.text_file.close()
+
+
+ def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
+     dataset_root = datasets_root.joinpath(dataset_name)
+     if not dataset_root.exists():
+         print("Couldn't find %s, skipping this dataset." % dataset_root)
+         return None, None
+     return dataset_root, DatasetLog(out_dir, dataset_name)
+
+
+ def _preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, skip_existing: bool):
+     # Give a name to the speaker that includes its dataset
+     speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
+
+     # Create an output directory with that name, as well as a txt file containing a
+     # reference to each source file.
+     speaker_out_dir = out_dir.joinpath(speaker_name)
+     speaker_out_dir.mkdir(exist_ok=True)
+     sources_fpath = speaker_out_dir.joinpath("_sources.txt")
+
+     # There's a possibility that the preprocessing was interrupted earlier, check if
+     # there already is a sources file.
+     if sources_fpath.exists():
+         try:
+             with sources_fpath.open("r") as sources_file:
+                 existing_fnames = {line.split(",")[0] for line in sources_file}
+         except:
+             existing_fnames = {}
+     else:
+         existing_fnames = {}
+
+     # Gather all audio files for that speaker recursively
+     sources_file = sources_fpath.open("a" if skip_existing else "w")
+     audio_durs = []
+     for extension in _AUDIO_EXTENSIONS:
+         for in_fpath in speaker_dir.glob("**/*.%s" % extension):
+             # Check if the target output file already exists
+             out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
+             out_fname = out_fname.replace(".%s" % extension, ".npy")
+             if skip_existing and out_fname in existing_fnames:
+                 continue
+
+             # Load and preprocess the waveform
+             wav = audio.preprocess_wav(in_fpath)
+             if len(wav) == 0:
+                 continue
+
+             # Create the mel spectrogram, discard those that are too short
+             frames = audio.wav_to_mel_spectrogram(wav)
+             if len(frames) < partials_n_frames:
+                 continue
+
+             out_fpath = speaker_out_dir.joinpath(out_fname)
+             np.save(out_fpath, frames)
+             sources_file.write("%s,%s\n" % (out_fname, in_fpath))
+             audio_durs.append(len(wav) / sampling_rate)
+
+     sources_file.close()
+
+     return audio_durs
+
+
+ def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger):
+     print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
+
+     # Process the utterances for each speaker
+     work_fn = partial(_preprocess_speaker, datasets_root=datasets_root, out_dir=out_dir, skip_existing=skip_existing)
+     with Pool(4) as pool:
+         tasks = pool.imap(work_fn, speaker_dirs)
+         for sample_durs in tqdm(tasks, dataset_name, len(speaker_dirs), unit="speakers"):
+             for sample_dur in sample_durs:
+                 logger.add_sample(duration=sample_dur)
+
+     logger.finalize()
+     print("Done preprocessing %s.\n" % dataset_name)
+
+
+ def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
+     for dataset_name in librispeech_datasets["train"]["other"]:
+         # Initialize the preprocessing
+         dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
+         if not dataset_root:
+             return
+
+         # Preprocess all speakers
+         speaker_dirs = list(dataset_root.glob("*"))
+         _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
+
+
+ def preprocess_persian(datasets_root: Path, out_dir: Path, skip_existing=False):
+     for dataset_name in persian_datasets["train"]["data"]:
+         # Initialize the preprocessing
+         dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
+         if not dataset_root:
+             return
+         # Preprocess all speakers
+         speaker_dirs = list(dataset_root.glob("*"))
+         _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
+
+
+ def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
+     # Initialize the preprocessing
+     dataset_name = "VoxCeleb1"
+     dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
+     if not dataset_root:
+         return
+
+     # Get the contents of the meta file
+     with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
+         metadata = [line.split("\t") for line in metafile][1:]
+
+     # Select the ID and the nationality, filter out non-anglophone speakers
+     nationalities = {line[0]: line[3] for line in metadata}
+     keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
+                         nationality.lower() in anglophone_nationalites]
+     print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
+           (len(keep_speaker_ids), len(nationalities)))
+
+     # Get the speaker directories for anglophone speakers only
+     speaker_dirs = dataset_root.joinpath("wav").glob("*")
+     speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
+                     speaker_dir.name in keep_speaker_ids]
+     print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
+           (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
+
+     # Preprocess all speakers
+     _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
+
+
+ def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
+     # Initialize the preprocessing
+     dataset_name = "VoxCeleb2"
+     dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
+     if not dataset_root:
+         return
+
+     # Get the speaker directories
+     # Preprocess all speakers
+     speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
+     _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
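A minimal driver sketch for `preprocess_persian` above; the paths are placeholders that mirror the layout in `encoder/config.py`, and the repo's `encoder_preprocess.py` wrapper (listed among the changed files but not shown in this view) is the intended entry point.

```python
# Sketch: run the Persian encoder preprocessing directly (paths are placeholders;
# the committed encoder_preprocess.py wrapper is the normal entry point).
from pathlib import Path
from encoder.preprocess import preprocess_persian

datasets_root = Path("dataset")              # contains persian_data/train_data/...
out_dir = Path("dataset/SV2TTS/encoder")
out_dir.mkdir(parents=True, exist_ok=True)
preprocess_persian(datasets_root, out_dir, skip_existing=True)
```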
pmt2/encoder/train.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import torch
4
+
5
+ from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
6
+ from encoder.model import SpeakerEncoder
7
+ from encoder.params_model import *
8
+ from encoder.visualizations import Visualizations
9
+ from utils.profiler import Profiler
10
+
11
+
12
+ def sync(device: torch.device):
13
+ # For correct profiling (cuda operations are async)
14
+ if device.type == "cuda":
15
+ torch.cuda.synchronize(device)
16
+
17
+
18
+ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
19
+ backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
20
+ no_visdom: bool):
21
+ # Create a dataset and a dataloader
22
+ dataset = SpeakerVerificationDataset(clean_data_root)
23
+ loader = SpeakerVerificationDataLoader(
24
+ dataset,
25
+ speakers_per_batch,
26
+ utterances_per_speaker,
27
+ num_workers=4,
28
+ )
29
+
30
+ # Setup the device on which to run the forward pass and the loss. These can be different,
31
+ # because the forward pass is faster on the GPU whereas the loss is often (depending on your
32
+ # hyperparameters) faster on the CPU.
33
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+ # FIXME: currently, the gradient is None if loss_device is cuda
35
+ loss_device = torch.device("cpu")
36
+
37
+ # Create the model and the optimizer
38
+ model = SpeakerEncoder(device, loss_device)
39
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
40
+ init_step = 1
41
+
42
+ # Configure file path for the model
43
+ model_dir = models_dir / run_id
44
+ model_dir.mkdir(exist_ok=True, parents=True)
45
+ state_fpath = model_dir / "encoder.pt"
46
+
47
+ # Load any existing model
48
+ if not force_restart:
49
+ if state_fpath.exists():
50
+ print("Found existing model \"%s\", loading it and resuming training." % run_id)
51
+ checkpoint = torch.load(state_fpath)
52
+ init_step = checkpoint["step"]
53
+ model.load_state_dict(checkpoint["model_state"])
54
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
55
+ optimizer.param_groups[0]["lr"] = learning_rate_init
56
+ else:
57
+ print("No model \"%s\" found, starting training from scratch." % run_id)
58
+ else:
59
+ print("Starting the training from scratch.")
60
+ model.train()
61
+
62
+ # Initialize the visualization environment
63
+ vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
64
+ vis.log_dataset(dataset)
65
+ vis.log_params()
66
+ device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
67
+ vis.log_implementation({"Device": device_name})
68
+
69
+ # Training loop
70
+ profiler = Profiler(summarize_every=10, disabled=False)
71
+ for step, speaker_batch in enumerate(loader, init_step):
72
+ profiler.tick("Blocking, waiting for batch (threaded)")
73
+
74
+ # Forward pass
75
+ inputs = torch.from_numpy(speaker_batch.data).to(device)
76
+ sync(device)
77
+ profiler.tick("Data to %s" % device)
78
+ embeds = model(inputs)
79
+ sync(device)
80
+ profiler.tick("Forward pass")
81
+ embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
82
+ loss, eer = model.loss(embeds_loss)
83
+ sync(loss_device)
84
+ profiler.tick("Loss")
85
+
86
+ # Backward pass
87
+ model.zero_grad()
88
+ loss.backward()
89
+ profiler.tick("Backward pass")
90
+ model.do_gradient_ops()
91
+ optimizer.step()
92
+ profiler.tick("Parameter update")
93
+
94
+ # Update visualizations
95
+ # learning_rate = optimizer.param_groups[0]["lr"]
96
+ vis.update(loss.item(), eer, step)
97
+
98
+ # Draw projections and save them to the backup folder
99
+ if umap_every != 0 and step % umap_every == 0:
100
+ print("Drawing and saving projections (step %d)" % step)
101
+ projection_fpath = model_dir / f"umap_{step:06d}.png"
102
+ embeds = embeds.detach().cpu().numpy()
103
+ vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
104
+ vis.save()
105
+
106
+ # Overwrite the latest version of the model
107
+ if save_every != 0 and step % save_every == 0:
108
+ print("Saving the model (step %d)" % step)
109
+ torch.save({
110
+ "step": step + 1,
111
+ "model_state": model.state_dict(),
112
+ "optimizer_state": optimizer.state_dict(),
113
+ }, state_fpath)
114
+
115
+ # Make a backup
116
+ if backup_every != 0 and step % backup_every == 0:
117
+ print("Making a backup (step %d)" % step)
118
+ backup_fpath = model_dir / f"encoder_{step:06d}.bak"
119
+ torch.save({
120
+ "step": step + 1,
121
+ "model_state": model.state_dict(),
122
+ "optimizer_state": optimizer.state_dict(),
123
+ }, backup_fpath)
124
+
125
+ profiler.tick("Extras (visualizations, saving)")
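A minimal sketch of the tensor shapes this loop works with, assuming the upstream SV2TTS defaults of 64 speakers per batch, 10 utterances per speaker and a 256-dimensional embedding (those values live in encoder/params_model.py and are assumptions here, not read from this diff):

import torch

# Hypothetical stand-ins for speakers_per_batch / utterances_per_speaker /
# model_embedding_size; only the reshape logic mirrors the training loop above.
speakers_per_batch, utterances_per_speaker, embed_size = 64, 10, 256

# model(inputs) yields one embedding per utterance, flattened across all speakers.
embeds = torch.randn(speakers_per_batch * utterances_per_speaker, embed_size)

# The GE2E-style loss expects embeddings grouped per speaker, hence the view() above.
embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1))
assert embeds_loss.shape == (speakers_per_batch, utterances_per_speaker, embed_size)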
pmt2/encoder/visualizations.py ADDED
@@ -0,0 +1,179 @@
1
+ from datetime import datetime
2
+ from time import perf_counter as timer
3
+
4
+ import numpy as np
5
+ import umap
6
+ import visdom
7
+
8
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
9
+
10
+
11
+ colormap = np.array([
12
+ [76, 255, 0],
13
+ [0, 127, 70],
14
+ [255, 0, 0],
15
+ [255, 217, 38],
16
+ [0, 135, 255],
17
+ [165, 0, 165],
18
+ [255, 167, 255],
19
+ [0, 255, 255],
20
+ [255, 96, 38],
21
+ [142, 76, 0],
22
+ [33, 0, 127],
23
+ [0, 0, 0],
24
+ [183, 183, 183],
25
+ ], dtype=float) / 255
26
+
27
+
28
+ class Visualizations:
29
+ def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
30
+ # Tracking data
31
+ self.last_update_timestamp = timer()
32
+ self.update_every = update_every
33
+ self.step_times = []
34
+ self.losses = []
35
+ self.eers = []
36
+ print("Updating the visualizations every %d steps." % update_every)
37
+
38
+ # If visdom is disabled. TODO: use a better pattern for this
39
+ self.disabled = disabled
40
+ if self.disabled:
41
+ return
42
+
43
+ # Set the environment name
44
+ now = str(datetime.now().strftime("%d-%m %Hh%M"))
45
+ if env_name is None:
46
+ self.env_name = now
47
+ else:
48
+ self.env_name = "%s (%s)" % (env_name, now)
49
+
50
+ # Connect to visdom and open the corresponding window in the browser
51
+ try:
52
+ self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
53
+ except ConnectionError:
54
+ raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
55
+ "start it.")
56
+ # webbrowser.open("http://localhost:8097/env/" + self.env_name)
57
+
58
+ # Create the windows
59
+ self.loss_win = None
60
+ self.eer_win = None
61
+ # self.lr_win = None
62
+ self.implementation_win = None
63
+ self.projection_win = None
64
+ self.implementation_string = ""
65
+
66
+ def log_params(self):
67
+ if self.disabled:
68
+ return
69
+ from encoder import params_data
70
+ from encoder import params_model
71
+ param_string = "<b>Model parameters</b>:<br>"
72
+ for param_name in (p for p in dir(params_model) if not p.startswith("__")):
73
+ value = getattr(params_model, param_name)
74
+ param_string += "\t%s: %s<br>" % (param_name, value)
75
+ param_string += "<b>Data parameters</b>:<br>"
76
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
77
+ value = getattr(params_data, param_name)
78
+ param_string += "\t%s: %s<br>" % (param_name, value)
79
+ self.vis.text(param_string, opts={"title": "Parameters"})
80
+
81
+ def log_dataset(self, dataset: SpeakerVerificationDataset):
82
+ if self.disabled:
83
+ return
84
+ dataset_string = ""
85
+ dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
86
+ dataset_string += "\n" + dataset.get_logs()
87
+ dataset_string = dataset_string.replace("\n", "<br>")
88
+ self.vis.text(dataset_string, opts={"title": "Dataset"})
89
+
90
+ def log_implementation(self, params):
91
+ if self.disabled:
92
+ return
93
+ implementation_string = ""
94
+ for param, value in params.items():
95
+ implementation_string += "<b>%s</b>: %s\n" % (param, value)
96
+ implementation_string = implementation_string.replace("\n", "<br>")
97
+ self.implementation_string = implementation_string
98
+ self.implementation_win = self.vis.text(
99
+ implementation_string,
100
+ opts={"title": "Training implementation"}
101
+ )
102
+
103
+ def update(self, loss, eer, step):
104
+ # Update the tracking data
105
+ now = timer()
106
+ self.step_times.append(1000 * (now - self.last_update_timestamp))
107
+ self.last_update_timestamp = now
108
+ self.losses.append(loss)
109
+ self.eers.append(eer)
110
+ print(".", end="")
111
+
112
+ # Update the plots every <update_every> steps
113
+ if step % self.update_every != 0:
114
+ return
115
+ time_string = "Step time: mean: %5dms std: %5dms" % \
116
+ (int(np.mean(self.step_times)), int(np.std(self.step_times)))
117
+ print("\nStep %6d Loss: %.4f EER: %.4f %s" %
118
+ (step, np.mean(self.losses), np.mean(self.eers), time_string))
119
+ if not self.disabled:
120
+ self.loss_win = self.vis.line(
121
+ [np.mean(self.losses)],
122
+ [step],
123
+ win=self.loss_win,
124
+ update="append" if self.loss_win else None,
125
+ opts=dict(
126
+ legend=["Avg. loss"],
127
+ xlabel="Step",
128
+ ylabel="Loss",
129
+ title="Loss",
130
+ )
131
+ )
132
+ self.eer_win = self.vis.line(
133
+ [np.mean(self.eers)],
134
+ [step],
135
+ win=self.eer_win,
136
+ update="append" if self.eer_win else None,
137
+ opts=dict(
138
+ legend=["Avg. EER"],
139
+ xlabel="Step",
140
+ ylabel="EER",
141
+ title="Equal error rate"
142
+ )
143
+ )
144
+ if self.implementation_win is not None:
145
+ self.vis.text(
146
+ self.implementation_string + ("<b>%s</b>" % time_string),
147
+ win=self.implementation_win,
148
+ opts={"title": "Training implementation"},
149
+ )
150
+
151
+ # Reset the tracking
152
+ self.losses.clear()
153
+ self.eers.clear()
154
+ self.step_times.clear()
155
+
156
+ def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, max_speakers=10):
157
+ import matplotlib.pyplot as plt
158
+
159
+ max_speakers = min(max_speakers, len(colormap))
160
+ embeds = embeds[:max_speakers * utterances_per_speaker]
161
+
162
+ n_speakers = len(embeds) // utterances_per_speaker
163
+ ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
164
+ colors = [colormap[i] for i in ground_truth]
165
+
166
+ reducer = umap.UMAP()
167
+ projected = reducer.fit_transform(embeds)
168
+ plt.scatter(projected[:, 0], projected[:, 1], c=colors)
169
+ plt.gca().set_aspect("equal", "datalim")
170
+ plt.title("UMAP projection (step %d)" % step)
171
+ if not self.disabled:
172
+ self.projection_win = self.vis.matplot(plt, win=self.projection_win)
173
+ if out_fpath is not None:
174
+ plt.savefig(out_fpath)
175
+ plt.clf()
176
+
177
+ def save(self):
178
+ if not self.disabled:
179
+ self.vis.save([self.env_name])
pmt2/encoder_preprocess.py ADDED
@@ -0,0 +1,69 @@
1
+ from encoder.preprocess import preprocess_persian
2
+ from utils.argutils import print_args
3
+ from pathlib import Path
4
+ import argparse
5
+
6
+
7
+ if __name__ == "__main__":
8
+ class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
9
+ pass
10
+
11
+ parser = argparse.ArgumentParser(
12
+ description="Preprocesses audio files from datasets, encodes them as mel spectrograms and "
13
+ "writes them to the disk. This will allow you to train the encoder. The "
14
+ "datasets required are at least one of VoxCeleb1, VoxCeleb2 and LibriSpeech. "
15
+ "Ideally, you should have all three. You should extract them as they are "
16
+ "after having downloaded them and put them in a same directory, e.g.:\n"
17
+ "-[datasets_root]\n"
18
+ " -LibriSpeech\n"
19
+ " -train-other-500\n"
20
+ " -VoxCeleb1\n"
21
+ " -wav\n"
22
+ " -vox1_meta.csv\n"
23
+ " -VoxCeleb2\n"
24
+ " -dev",
25
+ formatter_class=MyFormatter
26
+ )
27
+ parser.add_argument("datasets_root", type=Path, help=\
28
+ "Path to the directory containing your LibriSpeech/TTS and VoxCeleb datasets.")
29
+ parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
30
+ "Path to the output directory that will contain the mel spectrograms. If left out, "
31
+ "defaults to <datasets_root>/SV2TTS/encoder/")
32
+ parser.add_argument("-d", "--datasets", type=str,
33
+ default="librispeech_other,voxceleb1,voxceleb2", help=\
34
+ "Comma-separated list of the name of the datasets you want to preprocess. Only the train "
35
+ "set of these datasets will be used. Possible names: librispeech_other, voxceleb1, "
36
+ "voxceleb2.")
37
+ parser.add_argument("-s", "--skip_existing", action="store_true", help=\
38
+ "Whether to skip existing output files with the same name. Useful if this script was "
39
+ "interrupted.")
40
+ parser.add_argument("--no_trim", action="store_true", help=\
41
+ "Preprocess audio without trimming silences (not recommended).")
42
+ args = parser.parse_args()
43
+
44
+ # Verify webrtcvad is available
45
+ if not args.no_trim:
46
+ try:
47
+ import webrtcvad
48
+ except ImportError:
49
+ raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
50
+ "noise removal and is recommended. Please install and try again. If installation fails, "
51
+ "use --no_trim to disable this error message.")
52
+ del args.no_trim
53
+
54
+ # Process the arguments
55
+ args.datasets = args.datasets.split(",")
56
+ if not hasattr(args, "out_dir"):
57
+ args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder")
58
+ assert args.datasets_root.exists()
59
+ args.out_dir.mkdir(exist_ok=True, parents=True)
60
+
61
+ # Preprocess the datasets
62
+ print_args(args, parser)
63
+ preprocess_func = {
64
+ "persian_data": preprocess_persian
65
+ }
66
+ args = vars(args)
67
+ for dataset in args.pop("datasets"):
68
+ print("Preprocessing %s" % dataset)
69
+ preprocess_func[dataset](**args)
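A hedged usage sketch for this script: with the Persian data laid out by prepare_data.py, the equivalent of running `python encoder_preprocess.py dataset -d persian_data` looks roughly like the following. The keyword signature of preprocess_persian is inferred from the argument handling above, and the dataset root is a placeholder:

from pathlib import Path
from encoder.preprocess import preprocess_persian

datasets_root = Path("dataset")                        # placeholder, adjust to your layout
out_dir = datasets_root.joinpath("SV2TTS", "encoder")
out_dir.mkdir(exist_ok=True, parents=True)

# Keyword arguments mirror what remains in `args` after popping "datasets" and "no_trim".
preprocess_persian(datasets_root=datasets_root, out_dir=out_dir, skip_existing=False)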
pmt2/encoder_train.py ADDED
@@ -0,0 +1,45 @@
1
+ from utils.argutils import print_args
2
+ from encoder.train import train
3
+ from pathlib import Path
4
+ import argparse
5
+
6
+
7
+ if __name__ == "__main__":
8
+ parser = argparse.ArgumentParser(
9
+ description="Trains the speaker encoder. You must have run encoder_preprocess.py first.",
10
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
11
+ )
12
+
13
+ parser.add_argument("run_id", type=str, help= \
14
+ "Name for this model. By default, training outputs will be stored to saved_models/<run_id>/. If a model state "
15
+ "from the same run ID was previously saved, the training will restart from there. Pass -f to overwrite saved "
16
+ "states and restart from scratch.")
17
+ parser.add_argument("clean_data_root", type=Path, help= \
18
+ "Path to the output directory of encoder_preprocess.py. If you left the default "
19
+ "output directory when preprocessing, it should be <datasets_root>/SV2TTS/encoder/.")
20
+ parser.add_argument("-m", "--models_dir", type=Path, default="saved_models", help=\
21
+ "Path to the root directory that contains all models. A directory <run_name> will be created under this root."
22
+ "It will contain the saved model weights, as well as backups of those weights and plots generated during "
23
+ "training.")
24
+ parser.add_argument("-v", "--vis_every", type=int, default=10, help= \
25
+ "Number of steps between updates of the loss and the plots.")
26
+ parser.add_argument("-u", "--umap_every", type=int, default=500, help= \
27
+ "Number of steps between updates of the umap projection. Set to 0 to never update the "
28
+ "projections.")
29
+ parser.add_argument("-s", "--save_every", type=int, default=500, help= \
30
+ "Number of steps between updates of the model on the disk. Set to 0 to never save the "
31
+ "model.")
32
+ parser.add_argument("-b", "--backup_every", type=int, default=10000, help= \
33
+ "Number of steps between backups of the model. Set to 0 to never make backups of the "
34
+ "model.")
35
+ parser.add_argument("-f", "--force_restart", action="store_true", help= \
36
+ "Do not load any saved model.")
37
+ parser.add_argument("--visdom_server", type=str, default="http://localhost")
38
+ parser.add_argument("--no_visdom", action="store_true", help= \
39
+ "Disable visdom.")
40
+ args = parser.parse_args()
41
+
42
+ # Run the training
43
+ print_args(args, parser)
44
+ train(**vars(args))
45
+
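A hedged sketch of invoking the training entry point directly from Python, equivalent to `python encoder_train.py persian_run dataset/SV2TTS/encoder --no_visdom` with the defaults above (the run name and paths are placeholders):

from pathlib import Path
from encoder.train import train

train(run_id="persian_run",                              # placeholder run name
      clean_data_root=Path("dataset/SV2TTS/encoder"),    # output of encoder_preprocess.py
      models_dir=Path("saved_models"),
      umap_every=500, save_every=500, backup_every=10000, vis_every=10,
      force_restart=False, visdom_server="http://localhost", no_visdom=True)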
pmt2/inference.py ADDED
@@ -0,0 +1,94 @@
1
+ import numpy as np
2
+ import torch
3
+ import sys
4
+
5
+ from encoder import inference as encoder
6
+ from synthesizer.inference import Synthesizer
7
+ from vocoder import inference as vocoder_wavrnn
8
+ from parallel_wavegan.utils import load_model as vocoder_hifigan
9
+
10
+ import soundfile as sf
11
+ import os
12
+ import argparse
13
+
14
+
15
+ main_path = os.getcwd()
16
+ models_path = os.path.join(main_path, 'saved_models/final_models/')
17
+
18
+ def wavRNN_infer(text, ref_wav_path, test_name):
19
+ encoder.load_model(os.path.join(models_path, 'encoder.pt'))
20
+ synthesizer = Synthesizer(os.path.join(models_path, 'synthesizer.pt'))
21
+ vocoder_wavrnn.load_model(os.path.join(models_path, 'vocoder_WavRNN.pt'))
22
+
23
+ ref_wav_path = os.path.join(main_path, 'dataset/persian_data/train_data/book-1/', ref_wav_path) ## reference wav
24
+ wav = Synthesizer.load_preprocess_wav(ref_wav_path)
25
+
26
+ encoder_wav = encoder.preprocess_wav(wav)
27
+ embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
28
+
29
+ texts = [text]
30
+ embeds = [embed] * len(texts)
31
+ specs = synthesizer.synthesize_spectrograms(texts, embeds)
32
+ breaks = [spec.shape[1] for spec in specs]
33
+ spec = np.concatenate(specs, axis=1)
34
+
35
+ wav = vocoder_wavrnn.infer_waveform(spec)
36
+ b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
37
+ b_starts = np.concatenate(([0], b_ends[:-1]))
38
+ wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)]
39
+ breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
40
+ wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
41
+ wav = wav / np.abs(wav).max() * 0.97
42
+
43
+ res_path = os.path.join(main_path, 'results/', test_name+".wav")
44
+ sf.write(res_path, wav, Synthesizer.sample_rate)
45
+ print('\nWAV file saved.')
46
+
47
+
48
+ def hifigan_infer(text, ref_wav_path, test_name):
49
+ encoder.load_model(os.path.join(models_path, 'encoder.pt'))
50
+ synthesizer = Synthesizer(os.path.join(models_path, 'synthesizer.pt'))
51
+ vocoder = vocoder_hifigan(os.path.join(models_path, 'vocoder_HiFiGAN.pkl'))
52
+ vocoder.remove_weight_norm()
53
+ vocoder = vocoder.eval().to('cpu')
54
+
55
+ wav = Synthesizer.load_preprocess_wav(ref_wav_path)
56
+
57
+ encoder_wav = encoder.preprocess_wav(wav)
58
+ embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
59
+
60
+ texts = [text]
61
+ embeds = [embed] * len(texts)
62
+ specs = synthesizer.synthesize_spectrograms(texts, embeds)
63
+ spec = np.concatenate(specs, axis=1)
64
+ x = torch.from_numpy(spec.T).to('cpu')
65
+
66
+ with torch.no_grad():
67
+ wav = vocoder.inference(x)
68
+ wav = wav / np.abs(wav).max() * 0.97
69
+
70
+ res_path = os.path.join(main_path, 'results/', test_name+".wav")
71
+ sf.write(res_path, wav, Synthesizer.sample_rate)
72
+ print('\nWAV file saved.')
73
+
74
+
75
+ def main(args):
76
+ if str(args.vocoder).lower() == "wavrnn":
77
+ wavRNN_infer(args.text, args.ref_wav_path, args.test_name)
78
+ elif str(args.vocoder).lower() == "hifigan":
79
+ hifigan_infer(args.text, args.ref_wav_path, args.test_name)
80
+ else:
81
+ print("--vocoder must be one of HiFiGAN or WavRNN")
82
+
83
+
84
+ if __name__ == "__main__":
85
+ parser = argparse.ArgumentParser()
86
+ parser.add_argument("--vocoder", type=str, help= "vocoder name: HiFiGAN or WavRNN")
87
+ parser.add_argument("--text", type=str, help="input text")
88
+ parser.add_argument("--ref_wav_path", type=str, help="path to refrence wav to create speaker from that")
89
+ parser.add_argument("--test_name", type=str, default="test1", help="name of current test to save the result wav")
90
+ args = parser.parse_args()
91
+
92
+ main(args)
93
+
94
+
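A hedged usage sketch, assuming the pretrained encoder, synthesizer and HiFi-GAN checkpoints are present under saved_models/final_models/; the input text and reference path are placeholders. It mirrors `python inference.py --vocoder HiFiGAN --text "..." --ref_wav_path ... --test_name demo1`:

from inference import hifigan_infer

hifigan_infer(text="<Persian input text>",              # placeholder text
              ref_wav_path="path/to/reference.wav",     # placeholder reference utterance
              test_name="demo1")                        # result saved to results/demo1.wav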
pmt2/prepare_data.py ADDED
@@ -0,0 +1,96 @@
1
+ import json
2
+ import re
3
+ import csv
4
+ import shutil
5
+ import os
6
+ import argparse
7
+
8
+ main_path = os.getcwd()
9
+
10
+ def get_duration(row):
11
+ phone_durs = row.split()
12
+ dur_sum = 0
13
+ for phone_dur in phone_durs:
14
+ if phone_dur == '|':
15
+ continue
16
+ else:
17
+ phone_dur = phone_dur.split('[')
18
+ dur = float(phone_dur[1][:-1])/1000
19
+ dur_sum += dur
20
+ return dur_sum
21
+
22
+ def prepare_data_for_model(path, duration_lim):
23
+ f = open(path, 'r')
24
+ data = csv.DictReader(f)
25
+ data_lines = []
26
+ for row in data:
27
+ dur = get_duration(row['phenome'])
28
+ if dur > duration_lim:
29
+ continue
30
+ phoneme = row['phenome']
31
+ utterance_name = row['seg_id']
32
+ speaker_id = row['speaker_id']
33
+ phoneme = re.sub("\[([0-9]+)\]", '', phoneme)
34
+ phoneme = re.sub("\s+\|\s+", ' ', phoneme)
35
+ data_lines.append([phoneme, utterance_name, speaker_id])
36
+ f.close()
37
+ return data_lines
38
+
39
+
40
+ def save_files(train_data, test_data, data_path):
41
+ for line in train_data:
42
+ try:
43
+ original = os.path.join(data_path, 'train_wav/{}.wav'.format(line[1]))
44
+ target = os.path.join(main_path, 'dataset/persian_data/train_data/speaker-{0}/book-1/utterance-{1}.wav'.format(line[2], line[1]))
45
+ os.makedirs(os.path.dirname(target), exist_ok=True)
46
+ shutil.copyfile(original, target)
47
+ except Exception as e:
48
+ print(e)
49
+ return False
50
+
51
+ path = os.path.join(main_path, 'dataset/persian_data/train_data/speaker-{0}/book-1/utterance-{1}.txt'.format(line[2], line[1]))
52
+ with open(path, 'w') as fp:
53
+ fp.write(line[0])
54
+
55
+ for line in test_data:
56
+ try:
57
+ original = os.path.join(data_path, 'test_wav/{}.wav'.format(line[1]))
58
+ target = os.path.join(main_path, 'dataset/persian_data/test_data/speaker-{0}/book-1/utterance-{1}.wav'.format(line[2], line[1]))
59
+ os.makedirs(os.path.dirname(target), exist_ok=True)
60
+ shutil.copyfile(original, target)
61
+ except Exception as e:
62
+ print(e)
63
+ return False
64
+
65
+ path = os.path.join(main_path, 'dataset/persian_data/test_data/speaker-{0}/book-1/utterance-{1}.txt'.format(line[2], line[1]))
66
+ with open(path, 'w') as fp:
67
+ fp.write(line[0])
68
+ return True
69
+
70
+ def main():
71
+ parser = argparse.ArgumentParser()
72
+ parser.add_argument('--data_path', required=True)
73
+ args = parser.parse_args()
74
+ data_path = args.data_path
75
+
76
+ if os.path.isfile(os.path.join(data_path, 'train_info.csv')):
77
+ train_data_path = os.path.join(data_path, 'train_info.csv')
78
+ else:
79
+ print('train_info.csv not found in data_path!')
80
+ return -1
81
+ if os.path.isfile(os.path.join(data_path, 'test_info.csv')):
82
+ test_data_path = os.path.join(data_path, 'test_info.csv')
83
+ else:
84
+ print('data_path is not correct!')
85
+ return -1
86
+ train_data = prepare_data_for_model(train_data_path, 12)
87
+ test_data = prepare_data_for_model(test_data_path, 15)
88
+ print('number of train data: ' + str(len(train_data)))
89
+ print('number of test data: ' + str(len(test_data)))
90
+
91
+ res = save_files(train_data, test_data, data_path)
92
+ if res:
93
+ print('Data is created.')
94
+
95
+ if __name__ == "__main__":
96
+ main()
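The phenome column parsed above is assumed to carry per-phone durations in milliseconds inside square brackets, with | marking word boundaries. A quick check of get_duration under that assumption:

from prepare_data import get_duration

# Hypothetical row content in the bracketed-millisecond format described above.
row = "a[120] b[80] | t[95]"
print(get_duration(row))   # -> 0.295 (seconds)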
pmt2/requirements.txt ADDED
Binary file (562 Bytes).
pmt2/resources/model.JPG ADDED
pmt2/synthesizer/LICENSE.txt ADDED
@@ -0,0 +1,24 @@
1
+ MIT License
2
+
3
+ Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
4
+ Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
5
+ Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
6
+ Modified work Copyright (c) 2020 blue-fish (https://github.com/blue-fish)
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ of this software and associated documentation files (the "Software"), to deal
10
+ in the Software without restriction, including without limitation the rights
11
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ copies of the Software, and to permit persons to whom the Software is
13
+ furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice shall be included in all
16
+ copies or substantial portions of the Software.
17
+
18
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ SOFTWARE.
pmt2/synthesizer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ #
pmt2/synthesizer/audio.py ADDED
@@ -0,0 +1,206 @@
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+ from scipy import signal
5
+ from scipy.io import wavfile
6
+ import soundfile as sf
7
+
8
+
9
+ def load_wav(path, sr):
10
+ return librosa.core.load(path, sr=sr)[0]
11
+
12
+ def save_wav(wav, path, sr):
13
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
14
+ #proposed by @dsmiller
15
+ wavfile.write(path, sr, wav.astype(np.int16))
16
+
17
+ def save_wavenet_wav(wav, path, sr):
18
+ sf.write(path, wav.astype(np.float32), sr)
19
+
20
+ def preemphasis(wav, k, preemphasize=True):
21
+ if preemphasize:
22
+ return signal.lfilter([1, -k], [1], wav)
23
+ return wav
24
+
25
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
26
+ if inv_preemphasize:
27
+ return signal.lfilter([1], [1, -k], wav)
28
+ return wav
29
+
30
+ #From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py
31
+ def start_and_end_indices(quantized, silence_threshold=2):
32
+ for start in range(quantized.size):
33
+ if abs(quantized[start] - 127) > silence_threshold:
34
+ break
35
+ for end in range(quantized.size - 1, 1, -1):
36
+ if abs(quantized[end] - 127) > silence_threshold:
37
+ break
38
+
39
+ assert abs(quantized[start] - 127) > silence_threshold
40
+ assert abs(quantized[end] - 127) > silence_threshold
41
+
42
+ return start, end
43
+
44
+ def get_hop_size(hparams):
45
+ hop_size = hparams.hop_size
46
+ if hop_size is None:
47
+ assert hparams.frame_shift_ms is not None
48
+ hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
49
+ return hop_size
50
+
51
+ def linearspectrogram(wav, hparams):
52
+ D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
53
+ S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db
54
+
55
+ if hparams.signal_normalization:
56
+ return _normalize(S, hparams)
57
+ return S
58
+
59
+ def melspectrogram(wav, hparams):
60
+ D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
61
+ S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db
62
+
63
+ if hparams.signal_normalization:
64
+ return _normalize(S, hparams)
65
+ return S
66
+
67
+ def inv_linear_spectrogram(linear_spectrogram, hparams):
68
+ """Converts linear spectrogram to waveform using librosa"""
69
+ if hparams.signal_normalization:
70
+ D = _denormalize(linear_spectrogram, hparams)
71
+ else:
72
+ D = linear_spectrogram
73
+
74
+ S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear
75
+
76
+ if hparams.use_lws:
77
+ processor = _lws_processor(hparams)
78
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
79
+ y = processor.istft(D).astype(np.float32)
80
+ return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
81
+ else:
82
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
83
+
84
+ def inv_mel_spectrogram(mel_spectrogram, hparams):
85
+ """Converts mel spectrogram to waveform using librosa"""
86
+ if hparams.signal_normalization:
87
+ D = _denormalize(mel_spectrogram, hparams)
88
+ else:
89
+ D = mel_spectrogram
90
+
91
+ S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams) # Convert back to linear
92
+
93
+ if hparams.use_lws:
94
+ processor = _lws_processor(hparams)
95
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
96
+ y = processor.istft(D).astype(np.float32)
97
+ return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
98
+ else:
99
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
100
+
101
+ def _lws_processor(hparams):
102
+ import lws
103
+ return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech")
104
+
105
+ def _griffin_lim(S, hparams):
106
+ """librosa implementation of Griffin-Lim
107
+ Based on https://github.com/librosa/librosa/issues/434
108
+ """
109
+ angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
110
+ S_complex = np.abs(S).astype(np.complex128)
111
+ y = _istft(S_complex * angles, hparams)
112
+ for i in range(hparams.griffin_lim_iters):
113
+ angles = np.exp(1j * np.angle(_stft(y, hparams)))
114
+ y = _istft(S_complex * angles, hparams)
115
+ return y
116
+
117
+ def _stft(y, hparams):
118
+ if hparams.use_lws:
119
+ return _lws_processor(hparams).stft(y).T
120
+ else:
121
+ return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
122
+
123
+ def _istft(y, hparams):
124
+ return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
125
+
126
+ ##########################################################
127
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
128
+ def num_frames(length, fsize, fshift):
129
+ """Compute number of time frames of spectrogram
130
+ """
131
+ pad = (fsize - fshift)
132
+ if length % fshift == 0:
133
+ M = (length + pad * 2 - fsize) // fshift + 1
134
+ else:
135
+ M = (length + pad * 2 - fsize) // fshift + 2
136
+ return M
137
+
138
+
139
+ def pad_lr(x, fsize, fshift):
140
+ """Compute left and right padding
141
+ """
142
+ M = num_frames(len(x), fsize, fshift)
143
+ pad = (fsize - fshift)
144
+ T = len(x) + 2 * pad
145
+ r = (M - 1) * fshift + fsize - T
146
+ return pad, pad + r
147
+ ##########################################################
148
+ #Librosa correct padding
149
+ def librosa_pad_lr(x, fsize, fshift):
150
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
151
+
152
+ # Conversions
153
+ _mel_basis = None
154
+ _inv_mel_basis = None
155
+
156
+ def _linear_to_mel(spectrogram, hparams):
157
+ global _mel_basis
158
+ if _mel_basis is None:
159
+ _mel_basis = _build_mel_basis(hparams)
160
+ return np.dot(_mel_basis, spectrogram)
161
+
162
+ def _mel_to_linear(mel_spectrogram, hparams):
163
+ global _inv_mel_basis
164
+ if _inv_mel_basis is None:
165
+ _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
166
+ return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
167
+
168
+ def _build_mel_basis(hparams):
169
+ assert hparams.fmax <= hparams.sample_rate // 2
170
+ return librosa.filters.mel(sr=hparams.sample_rate, n_fft=hparams.n_fft, n_mels=hparams.num_mels,
171
+ fmin=hparams.fmin, fmax=hparams.fmax)
172
+
173
+ def _amp_to_db(x, hparams):
174
+ min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
175
+ return 20 * np.log10(np.maximum(min_level, x))
176
+
177
+ def _db_to_amp(x):
178
+ return np.power(10.0, (x) * 0.05)
179
+
180
+ def _normalize(S, hparams):
181
+ if hparams.allow_clipping_in_normalization:
182
+ if hparams.symmetric_mels:
183
+ return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
184
+ -hparams.max_abs_value, hparams.max_abs_value)
185
+ else:
186
+ return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)
187
+
188
+ assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
189
+ if hparams.symmetric_mels:
190
+ return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
191
+ else:
192
+ return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
193
+
194
+ def _denormalize(D, hparams):
195
+ if hparams.allow_clipping_in_normalization:
196
+ if hparams.symmetric_mels:
197
+ return (((np.clip(D, -hparams.max_abs_value,
198
+ hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
199
+ + hparams.min_level_db)
200
+ else:
201
+ return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
202
+
203
+ if hparams.symmetric_mels:
204
+ return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
205
+ else:
206
+ return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
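A hedged round-trip sketch of these helpers, using the synthesizer hparams defined later in this commit; "speech.wav" is a placeholder input file and the reconstruction quality is bounded by Griffin-Lim:

from synthesizer import audio
from synthesizer.hparams import hparams

wav = audio.load_wav("speech.wav", sr=hparams.sample_rate)   # placeholder input file
mel = audio.melspectrogram(wav, hparams)                     # shape: (num_mels, frames)
rec = audio.inv_mel_spectrogram(mel, hparams)                # Griffin-Lim reconstruction
audio.save_wav(rec, "speech_griffinlim.wav", sr=hparams.sample_rate)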
pmt2/synthesizer/audio_v2(support_hifigan).py ADDED
@@ -0,0 +1,154 @@
1
+ # raccoonML audio tools.
2
+ # MIT License
3
+ # Copyright (c) 2021 raccoonML (https://patreon.com/raccoonML)
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software") to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in
13
+ # all copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR ANY OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21
+ # THE SOFTWARE.
22
+
23
+ import librosa
24
+ import numpy as np
25
+ import soundfile as sf
26
+ import torch
27
+ from scipy import signal
28
+
29
+ _mel_basis = None
30
+
31
+
32
+ def load_wav(path, sr):
33
+ # Loads an audio file and returns the waveform data.
34
+ wav, _ = librosa.load(str(path), sr=sr)
35
+ return wav
36
+
37
+
38
+ def save_wav(wav, path, sr):
39
+ # Saves waveform data to audio file.
40
+ sf.write(path, wav, sr)
41
+
42
+
43
+ def melspectrogram(wav, hparams):
44
+ # Converts a waveform to a mel-scale spectrogram.
45
+ # Output shape = (num_mels, frames)
46
+
47
+ # Apply preemphasis
48
+ if hparams.preemphasize:
49
+ wav = preemphasis(wav, hparams.preemphasis)
50
+
51
+ # Short-time Fourier Transform (STFT)
52
+ D = librosa.stft(
53
+ y=wav,
54
+ n_fft=hparams.n_fft,
55
+ hop_length=hparams.hop_size,
56
+ win_length=hparams.win_size,
57
+ )
58
+
59
+ # Convert complex-valued output of STFT to absolute value (real)
60
+ S = np.abs(D)
61
+
62
+ # Build and cache mel basis
63
+ # This improves speed when calculating thousands of mel spectrograms.
64
+ global _mel_basis
65
+ if _mel_basis is None:
66
+ _mel_basis = _build_mel_basis(hparams)
67
+
68
+ # Transform to mel scale
69
+ S = np.dot(_mel_basis, S)
70
+
71
+ # Dynamic range compression
72
+ S = np.log(np.clip(S, a_min=1e-5, a_max=None))
73
+
74
+ return S.astype(np.float32)
75
+
76
+
77
+ def inv_mel_spectrogram(S, hparams):
78
+ # Converts a mel spectrogram to waveform using Griffin-Lim
79
+ # Input shape = (num_mels, frames)
80
+
81
+ # Denormalize
82
+ S = np.exp(S)
83
+
84
+ # Build and cache mel basis
85
+ # This improves speed when calculating thousands of mel spectrograms.
86
+ global _mel_basis
87
+ if _mel_basis is None:
88
+ _mel_basis = _build_mel_basis(hparams)
89
+
90
+ # Inverse mel basis
91
+ p = np.matmul(_mel_basis, _mel_basis.T)
92
+ d = [1.0 / x if np.abs(x) > 1.0e-8 else x for x in np.sum(p, axis=0)]
93
+ _inv_mel_basis = np.matmul(_mel_basis.T, np.diag(d))
94
+
95
+ # Invert mel basis to recover linear spectrogram
96
+ S = np.dot(_inv_mel_basis, S)
97
+
98
+ # Use Griffin-Lim to recover waveform
99
+ wav = _griffin_lim(S ** hparams.power, hparams)
100
+
101
+ # Invert preemphasis
102
+ if hparams.preemphasize:
103
+ wav = inv_preemphasis(wav, hparams.preemphasis)
104
+
105
+ return wav
106
+
107
+
108
+ def preemphasis(wav, k, preemphasize=True):
109
+ # Amplifies high frequency content in a waveform.
110
+ if preemphasize:
111
+ wav = signal.lfilter([1, -k], [1], wav)
112
+ return wav
113
+
114
+
115
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
116
+ # Inverts the preemphasis filter.
117
+ if inv_preemphasize:
118
+ wav = signal.lfilter([1], [1, -k], wav)
119
+ return wav
120
+
121
+
122
+ def _build_mel_basis(hparams):
123
+ return librosa.filters.mel(
124
+ sr=hparams.sample_rate,
125
+ n_fft=hparams.n_fft,
126
+ n_mels=hparams.num_mels,
127
+ fmin=hparams.fmin,
128
+ fmax=hparams.fmax,
129
+ )
130
+
131
+
132
+ def _griffin_lim(S, hparams):
133
+ angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
134
+ S = np.abs(S).astype(complex)
135
+ wav = librosa.istft(
136
+ S * angles, hop_length=hparams.hop_size, win_length=hparams.win_size
137
+ )
138
+ for i in range(hparams.griffin_lim_iters):
139
+ angles = np.exp(
140
+ 1j
141
+ * np.angle(
142
+ librosa.stft(
143
+ wav,
144
+ n_fft=hparams.n_fft,
145
+ hop_length=hparams.hop_size,
146
+ win_length=hparams.win_size,
147
+ )
148
+ )
149
+ )
150
+ wav = librosa.istft(
151
+ S * angles, hop_length=hparams.hop_size, win_length=hparams.win_size
152
+ )
153
+
154
+ return wav
pmt2/synthesizer/english utils/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ import torch
2
+
3
+
4
+ _output_ref = None
5
+ _replicas_ref = None
6
+
7
+ def data_parallel_workaround(model, *input):
8
+ global _output_ref
9
+ global _replicas_ref
10
+ device_ids = list(range(torch.cuda.device_count()))
11
+ output_device = device_ids[0]
12
+ replicas = torch.nn.parallel.replicate(model, device_ids)
13
+ # input.shape = (num_args, batch, ...)
14
+ inputs = torch.nn.parallel.scatter(input, device_ids)
15
+ # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
16
+ replicas = replicas[:len(inputs)]
17
+ outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
18
+ y_hat = torch.nn.parallel.gather(outputs, output_device)
19
+ _output_ref = outputs
20
+ _replicas_ref = replicas
21
+ return y_hat
22
+
23
+
24
+ class ValueWindow():
25
+ def __init__(self, window_size=100):
26
+ self._window_size = window_size
27
+ self._values = []
28
+
29
+ def append(self, x):
30
+ self._values = self._values[-(self._window_size - 1):] + [x]
31
+
32
+ @property
33
+ def sum(self):
34
+ return sum(self._values)
35
+
36
+ @property
37
+ def count(self):
38
+ return len(self._values)
39
+
40
+ @property
41
+ def average(self):
42
+ return self.sum / max(1, self.count)
43
+
44
+ def reset(self):
45
+ self._values = []
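A small sanity check of ValueWindow, the rolling window used to smooth training metrics (this assumes the class above is importable; the rest of the codebase refers to it through synthesizer.utils):

window = ValueWindow(window_size=3)
for loss in (4.0, 3.0, 2.0, 1.0):
    window.append(loss)
print(window.count, window.average)   # -> 3 2.0  (only the last three values are kept)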
pmt2/synthesizer/english utils/_cmudict.py ADDED
@@ -0,0 +1,62 @@
1
+ import re
2
+
3
+ valid_symbols = [
4
+ "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
5
+ "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
6
+ "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY",
7
+ "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1",
8
+ "IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0",
9
+ "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW",
10
+ "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH"
11
+ ]
12
+
13
+ _valid_symbol_set = set(valid_symbols)
14
+
15
+
16
+ class CMUDict:
17
+ """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
18
+ def __init__(self, file_or_path, keep_ambiguous=True):
19
+ if isinstance(file_or_path, str):
20
+ with open(file_or_path, encoding="latin-1") as f:
21
+ entries = _parse_cmudict(f)
22
+ else:
23
+ entries = _parse_cmudict(file_or_path)
24
+ if not keep_ambiguous:
25
+ entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
26
+ self._entries = entries
27
+
28
+
29
+ def __len__(self):
30
+ return len(self._entries)
31
+
32
+
33
+ def lookup(self, word):
34
+ """Returns list of ARPAbet pronunciations of the given word."""
35
+ return self._entries.get(word.upper())
36
+
37
+
38
+
39
+ _alt_re = re.compile(r"\([0-9]+\)")
40
+
41
+
42
+ def _parse_cmudict(file):
43
+ cmudict = {}
44
+ for line in file:
45
+ if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
46
+ parts = line.split(" ")
47
+ word = re.sub(_alt_re, "", parts[0])
48
+ pronunciation = _get_pronunciation(parts[1])
49
+ if pronunciation:
50
+ if word in cmudict:
51
+ cmudict[word].append(pronunciation)
52
+ else:
53
+ cmudict[word] = [pronunciation]
54
+ return cmudict
55
+
56
+
57
+ def _get_pronunciation(s):
58
+ parts = s.strip().split(" ")
59
+ for part in parts:
60
+ if part not in _valid_symbol_set:
61
+ return None
62
+ return " ".join(parts)
pmt2/synthesizer/english utils/cleaners.py ADDED
@@ -0,0 +1,88 @@
1
+ """
2
+ Cleaners are transformations that run over the input text at both training and eval time.
3
+
4
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
5
+ hyperparameter. Some cleaners are English-specific. You"ll typically want to use:
6
+ 1. "english_cleaners" for English text
7
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
8
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
9
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
10
+ the symbols in symbols.py to match your data).
11
+ """
12
+ import re
13
+ from unidecode import unidecode
14
+ from synthesizer.utils.numbers import normalize_numbers
15
+
16
+
17
+ # Regular expression matching whitespace:
18
+ _whitespace_re = re.compile(r"\s+")
19
+
20
+ # List of (regular expression, replacement) pairs for abbreviations:
21
+ _abbreviations = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [
22
+ ("mrs", "misess"),
23
+ ("mr", "mister"),
24
+ ("dr", "doctor"),
25
+ ("st", "saint"),
26
+ ("co", "company"),
27
+ ("jr", "junior"),
28
+ ("maj", "major"),
29
+ ("gen", "general"),
30
+ ("drs", "doctors"),
31
+ ("rev", "reverend"),
32
+ ("lt", "lieutenant"),
33
+ ("hon", "honorable"),
34
+ ("sgt", "sergeant"),
35
+ ("capt", "captain"),
36
+ ("esq", "esquire"),
37
+ ("ltd", "limited"),
38
+ ("col", "colonel"),
39
+ ("ft", "fort"),
40
+ ]]
41
+
42
+
43
+ def expand_abbreviations(text):
44
+ for regex, replacement in _abbreviations:
45
+ text = re.sub(regex, replacement, text)
46
+ return text
47
+
48
+
49
+ def expand_numbers(text):
50
+ return normalize_numbers(text)
51
+
52
+
53
+ def lowercase(text):
54
+ """lowercase input tokens."""
55
+ return text.lower()
56
+
57
+
58
+ def collapse_whitespace(text):
59
+ return re.sub(_whitespace_re, " ", text)
60
+
61
+
62
+ def convert_to_ascii(text):
63
+ return unidecode(text)
64
+
65
+
66
+ def basic_cleaners(text):
67
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
68
+ text = lowercase(text)
69
+ text = collapse_whitespace(text)
70
+ return text
71
+
72
+
73
+ def transliteration_cleaners(text):
74
+ """Pipeline for non-English text that transliterates to ASCII."""
75
+ text = convert_to_ascii(text)
76
+ text = lowercase(text)
77
+ text = collapse_whitespace(text)
78
+ return text
79
+
80
+
81
+ def english_cleaners(text):
82
+ """Pipeline for English text, including number and abbreviation expansion."""
83
+ text = convert_to_ascii(text)
84
+ # text = lowercase(text)
85
+ text = expand_numbers(text)
86
+ text = expand_abbreviations(text)
87
+ text = collapse_whitespace(text)
88
+ return text
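A quick illustration of english_cleaners, assuming unidecode, inflect and the numbers helper import cleanly; the exact output may vary slightly with library versions:

text = "Mr. Müller bought 3 apples for $2.50"
print(english_cleaners(text))
# -> roughly: "mister Muller bought three apples for two dollars, fifty cents"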
pmt2/synthesizer/english utils/numbers.py ADDED
@@ -0,0 +1,69 @@
1
+ import re
2
+ import inflect
3
+
4
+
5
+ _inflect = inflect.engine()
6
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
7
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
8
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
9
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
10
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
11
+ _number_re = re.compile(r"[0-9]+")
12
+
13
+
14
+ def _remove_commas(m):
15
+ return m.group(1).replace(",", "")
16
+
17
+
18
+ def _expand_decimal_point(m):
19
+ return m.group(1).replace(".", " point ")
20
+
21
+
22
+ def _expand_dollars(m):
23
+ match = m.group(1)
24
+ parts = match.split(".")
25
+ if len(parts) > 2:
26
+ return match + " dollars" # Unexpected format
27
+ dollars = int(parts[0]) if parts[0] else 0
28
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
29
+ if dollars and cents:
30
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
31
+ cent_unit = "cent" if cents == 1 else "cents"
32
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
33
+ elif dollars:
34
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
35
+ return "%s %s" % (dollars, dollar_unit)
36
+ elif cents:
37
+ cent_unit = "cent" if cents == 1 else "cents"
38
+ return "%s %s" % (cents, cent_unit)
39
+ else:
40
+ return "zero dollars"
41
+
42
+
43
+ def _expand_ordinal(m):
44
+ return _inflect.number_to_words(m.group(0))
45
+
46
+
47
+ def _expand_number(m):
48
+ num = int(m.group(0))
49
+ if num > 1000 and num < 3000:
50
+ if num == 2000:
51
+ return "two thousand"
52
+ elif num > 2000 and num < 2010:
53
+ return "two thousand " + _inflect.number_to_words(num % 100)
54
+ elif num % 100 == 0:
55
+ return _inflect.number_to_words(num // 100) + " hundred"
56
+ else:
57
+ return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
58
+ else:
59
+ return _inflect.number_to_words(num, andword="")
60
+
61
+
62
+ def normalize_numbers(text):
63
+ text = re.sub(_comma_number_re, _remove_commas, text)
64
+ text = re.sub(_pounds_re, r"\1 pounds", text)
65
+ text = re.sub(_dollars_re, _expand_dollars, text)
66
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
67
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
68
+ text = re.sub(_number_re, _expand_number, text)
69
+ return text
pmt2/synthesizer/english utils/plot.py ADDED
@@ -0,0 +1,82 @@
1
+ import numpy as np
2
+
3
+
4
+ def split_title_line(title_text, max_words=5):
5
+ """
6
+ Split a title string into chunks of at most max_words words,
7
+ joined with newlines so long titles wrap across lines in plot headers.
8
+ """
9
+ seq = title_text.split()
10
+ return "\n".join([" ".join(seq[i:i + max_words]) for i in range(0, len(seq), max_words)])
11
+
12
+
13
+ def plot_alignment(alignment, path, title=None, split_title=False, max_len=None):
14
+ import matplotlib
15
+ matplotlib.use("Agg")
16
+ import matplotlib.pyplot as plt
17
+
18
+ if max_len is not None:
19
+ alignment = alignment[:, :max_len]
20
+
21
+ fig = plt.figure(figsize=(8, 6))
22
+ ax = fig.add_subplot(111)
23
+
24
+ im = ax.imshow(
25
+ alignment,
26
+ aspect="auto",
27
+ origin="lower",
28
+ interpolation="none")
29
+ fig.colorbar(im, ax=ax)
30
+ xlabel = "Decoder timestep"
31
+
32
+ if split_title:
33
+ title = split_title_line(title)
34
+
35
+ plt.xlabel(xlabel)
36
+ plt.title(title)
37
+ plt.ylabel("Encoder timestep")
38
+ plt.tight_layout()
39
+ plt.savefig(path, format="png")
40
+ plt.close()
41
+
42
+
43
+ def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False):
44
+ import matplotlib
45
+ matplotlib.use("Agg")
46
+ import matplotlib.pyplot as plt
47
+
48
+ if max_len is not None:
49
+ target_spectrogram = target_spectrogram[:max_len]
50
+ pred_spectrogram = pred_spectrogram[:max_len]
51
+
52
+ if split_title:
53
+ title = split_title_line(title)
54
+
55
+ fig = plt.figure(figsize=(10, 8))
56
+ # Set common labels
57
+ fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)
58
+
59
+ #target spectrogram subplot
60
+ if target_spectrogram is not None:
61
+ ax1 = fig.add_subplot(311)
62
+ ax2 = fig.add_subplot(312)
63
+
64
+ if auto_aspect:
65
+ im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none")
66
+ else:
67
+ im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none")
68
+ ax1.set_title("Target Mel-Spectrogram")
69
+ fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
70
+ ax2.set_title("Predicted Mel-Spectrogram")
71
+ else:
72
+ ax2 = fig.add_subplot(211)
73
+
74
+ if auto_aspect:
75
+ im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none")
76
+ else:
77
+ im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none")
78
+ fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
79
+
80
+ plt.tight_layout()
81
+ plt.savefig(path, format="png")
82
+ plt.close()
pmt2/synthesizer/english utils/symbols.py ADDED
@@ -0,0 +1,17 @@
1
+ """
2
+ Defines the set of symbols used in text input to the model.
3
+
4
+ The default is a set of ASCII characters that works well for English or text that has been run
5
+ through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
6
+ """
7
+ # from . import cmudict
8
+
9
+ _pad = "_"
10
+ _eos = "~"
11
+ _characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'\"(),-.:;? "
12
+
13
+ # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
14
+ #_arpabet = ["@' + s for s in cmudict.valid_symbols]
15
+
16
+ # Export all symbols:
17
+ symbols = [_pad, _eos] + list(_characters) #+ _arpabet
pmt2/synthesizer/english utils/text.py ADDED
@@ -0,0 +1,75 @@
1
+ from synthesizer.utils.symbols import symbols
2
+ from synthesizer.utils import cleaners
3
+ import re
4
+
5
+
6
+ # Mappings from symbol to numeric ID and vice versa:
7
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9
+
10
+ # Regular expression matching text enclosed in curly braces:
11
+ _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
12
+
13
+
14
+ def text_to_sequence(text, cleaner_names):
15
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
16
+
17
+ The text can optionally have ARPAbet sequences enclosed in curly braces embedded
18
+ in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
19
+
20
+ Args:
21
+ text: string to convert to a sequence
22
+ cleaner_names: names of the cleaner functions to run the text through
23
+
24
+ Returns:
25
+ List of integers corresponding to the symbols in the text
26
+ """
27
+ sequence = []
28
+
29
+ # Check for curly braces and treat their contents as ARPAbet:
30
+ while len(text):
31
+ m = _curly_re.match(text)
32
+ if not m:
33
+ sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
34
+ break
35
+ sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
36
+ sequence += _arpabet_to_sequence(m.group(2))
37
+ text = m.group(3)
38
+
39
+ # Append EOS token
40
+ sequence.append(_symbol_to_id["~"])
41
+ return sequence
42
+
43
+
44
+ def sequence_to_text(sequence):
45
+ """Converts a sequence of IDs back to a string"""
46
+ result = ""
47
+ for symbol_id in sequence:
48
+ if symbol_id in _id_to_symbol:
49
+ s = _id_to_symbol[symbol_id]
50
+ # Enclose ARPAbet back in curly braces:
51
+ if len(s) > 1 and s[0] == "@":
52
+ s = "{%s}" % s[1:]
53
+ result += s
54
+ return result.replace("}{", " ")
55
+
56
+
57
+ def _clean_text(text, cleaner_names):
58
+ for name in cleaner_names:
59
+ cleaner = getattr(cleaners, name)
60
+ if not cleaner:
61
+ raise Exception("Unknown cleaner: %s" % name)
62
+ text = cleaner(text)
63
+ return text
64
+
65
+
66
+ def _symbols_to_sequence(symbols):
67
+ return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
68
+
69
+
70
+ def _arpabet_to_sequence(text):
71
+ return _symbols_to_sequence(["@" + s for s in text.split()])
72
+
73
+
74
+ def _should_keep_symbol(s):
75
+ return s in _symbol_to_id and s not in ("_", "~")
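A hedged sketch of the round trip through this module, assuming it is importable as synthesizer.utils.text (the import path the other synthesizer modules use) and the ASCII symbol set from symbols.py, where ARPAbet symbols are disabled:

from synthesizer.utils.text import text_to_sequence, sequence_to_text, _symbol_to_id

seq = text_to_sequence("salam donya.", ["basic_cleaners"])
print(seq[-1] == _symbol_to_id["~"])   # True: an EOS token is always appended
print(sequence_to_text(seq))           # -> "salam donya.~" (trailing ~ is the EOS marker)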
pmt2/synthesizer/hparams.py ADDED
@@ -0,0 +1,108 @@
1
+ import ast
2
+ import pprint
3
+
4
+ class HParams(object):
5
+ def __init__(self, **kwargs): self.__dict__.update(kwargs)
6
+ def __setitem__(self, key, value): setattr(self, key, value)
7
+ def __getitem__(self, key): return getattr(self, key)
8
+ def __repr__(self): return pprint.pformat(self.__dict__)
9
+
10
+ def parse(self, string):
11
+ # Overrides hparams from a comma-separated string of name=value pairs
12
+ if len(string) > 0:
13
+ overrides = [s.split("=") for s in string.split(",")]
14
+ keys, values = zip(*overrides)
15
+ keys = list(map(str.strip, keys))
16
+ values = list(map(str.strip, values))
17
+ for k in keys:
18
+ self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
19
+ return self
20
+
21
+ hparams = HParams(
22
+ ### Signal Processing (used in both synthesizer and vocoder)
23
+
24
+ # sample_rate = 22050,
25
+ # n_fft = 1024,
26
+ # num_mels = 80,
27
+ # hop_size = 256, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
28
+ # win_size = 1024, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
29
+ # fmin = 0,
30
+ # fmax = 11025,
31
+
32
+ sample_rate = 24000,
33
+ n_fft = 2048,
34
+ num_mels = 80,
35
+ hop_size = 300, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
36
+ win_size = 1200, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
37
+ fmin = 80,
38
+
39
+ # sample_rate = 16000,
40
+ # n_fft = 800,
41
+ # num_mels = 80,
42
+ # hop_size = 200, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
43
+ # win_size = 800, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
44
+ # fmin = 55,
45
+ min_level_db = -100,
46
+ ref_level_db = 20,
47
+ max_abs_value = 4., # Gradient explodes if too big, premature convergence if too small.
48
+ preemphasis = 0.97, # Filter coefficient to use if preemphasize is True
49
+ preemphasize = True,
50
+
51
+ ### Tacotron Text-to-Speech (TTS)
52
+ tts_embed_dims = 512, # Embedding dimension for the graphemes/phoneme inputs
53
+ tts_encoder_dims = 256,
54
+ tts_decoder_dims = 128,
55
+ tts_postnet_dims = 512,
56
+ tts_encoder_K = 5,
57
+ tts_lstm_dims = 1024,
58
+ tts_postnet_K = 5,
59
+ tts_num_highways = 4,
60
+ tts_dropout = 0.5,
61
+ tts_cleaner_names = ["persian_cleaners"],
62
+ tts_stop_threshold = -3.4, # Value below which audio generation ends.
63
+ # For example, for a range of [-4, 4], this
64
+ # will terminate the sequence at the first
65
+ # frame that has all values < -3.4
66
+
67
+ ### Tacotron Training
68
+ tts_schedule = [(2, 1e-3, 10_000, 16), # Progressive training schedule
69
+ (2, 5e-4, 20_000, 16), # (r, lr, step, batch_size)
70
+ (2, 2e-4, 40_000, 16), #
71
+ (2, 1e-4, 80_000, 16), # r = reduction factor (# of mel frames
72
+ (2, 3e-5, 160_000, 16), # synthesized for each decoder iteration)
73
+ (2, 1e-5, 320_000, 16)], # lr = learning rate
74
+
75
+ tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed
76
+ tts_eval_interval = 5000, # Number of steps between model evaluation (sample generation)
77
+ # Set to -1 to generate after completing epoch, or 0 to disable
78
+
79
+ tts_eval_num_samples = 1, # Makes this number of samples
80
+
81
+ ### Data Preprocessing
82
+ max_mel_frames = 900,
83
+ rescale = True,
84
+ rescaling_max = 0.9,
85
+ synthesis_batch_size = 16, # For vocoder preprocessing and inference.
86
+
87
+ ### Mel Visualization and Griffin-Lim
88
+ signal_normalization = True,
89
+ power = 1.5,
90
+ griffin_lim_iters = 60,
91
+
92
+ ### Audio processing options
93
+ fmax = 7600, # Should not exceed (sample_rate // 2)
94
+ allow_clipping_in_normalization = True, # Used when signal_normalization = True
95
+ clip_mels_length = True, # If true, discards samples exceeding max_mel_frames
96
+ use_lws = False, # "Fast spectrogram phase recovery using local weighted sums"
97
+ symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True,
98
+ # and [0, max_abs_value] if False
99
+ trim_silence = True, # Use with sample_rate of 16000 for best results
100
+
101
+ ### SV2TTS
102
+ speaker_embedding_size = 256, # Dimension for the speaker embedding
103
+ silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split
104
+ utterance_min_duration = 0.8, # Duration in seconds below which utterances are discarded
105
+ )
106
+
107
+ def hparams_debug_string():
108
+ return str(hparams)
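A short sketch of how the parse() override mechanism is meant to be used, for example from a training script's hparams override flag:

from synthesizer.hparams import hparams

hparams.parse("tts_eval_interval=2000,rescale=False")   # comma-separated name=value pairs
print(hparams.tts_eval_interval, hparams.rescale)       # -> 2000 False
print(hparams["sample_rate"])                           # item access also works -> 24000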
pmt2/synthesizer/hparams_new.py ADDED
@@ -0,0 +1,108 @@
1
+ import ast
2
+ import pprint
3
+
4
+ class HParams(object):
5
+ def __init__(self, **kwargs): self.__dict__.update(kwargs)
6
+ def __setitem__(self, key, value): setattr(self, key, value)
7
+ def __getitem__(self, key): return getattr(self, key)
8
+ def __repr__(self): return pprint.pformat(self.__dict__)
9
+
10
+ def parse(self, string):
11
+ # Overrides hparams from a comma-separated string of name=value pairs
12
+ if len(string) > 0:
13
+ overrides = [s.split("=") for s in string.split(",")]
14
+ keys, values = zip(*overrides)
15
+ keys = list(map(str.strip, keys))
16
+ values = list(map(str.strip, values))
17
+ for k in keys:
18
+ self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
19
+ return self
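+ # Example: hparams.parse("tts_dropout=0.4,griffin_lim_iters=30") overrides those two fields in place.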
20
+
21
+ hparams = HParams(
22
+ ### Signal Processing (used in both synthesizer and vocoder)
23
+
24
+ sample_rate = 22050,
25
+ n_fft = 1024,
26
+ num_mels = 80,
27
+ hop_size = 256, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
28
+ win_size = 1024, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
29
+ fmin = 0,
30
+ fmax = 11025,
31
+
32
+ # sample_rate = 24000,
33
+ # n_fft = 2048,
34
+ # num_mels = 80,
35
+ # hop_size = 300, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
36
+ # win_size = 1200, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
37
+ # fmin = 80,
38
+
39
+ # sample_rate = 16000,
40
+ # n_fft = 800,
41
+ # num_mels = 80,
42
+ # hop_size = 200, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
43
+ # win_size = 800, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
44
+ # fmin = 55,
45
+ min_level_db = -100,
46
+ ref_level_db = 20,
47
+ max_abs_value = 4., # Gradient explodes if too big, premature convergence if too small.
48
+ preemphasis = 0.97, # Filter coefficient to use if preemphasize is True
49
+ preemphasize = True,
50
+
51
+ ### Tacotron Text-to-Speech (TTS)
52
+ tts_embed_dims = 512, # Embedding dimension for the graphemes/phoneme inputs
53
+ tts_encoder_dims = 256,
54
+ tts_decoder_dims = 128,
55
+ tts_postnet_dims = 512,
56
+ tts_encoder_K = 5,
57
+ tts_lstm_dims = 1024,
58
+ tts_postnet_K = 5,
59
+ tts_num_highways = 4,
60
+ tts_dropout = 0.5,
61
+ tts_cleaner_names = ["persian_cleaners"],
62
+ tts_stop_threshold = -3.4, # Value below which audio generation ends.
63
+ # For example, for a range of [-4, 4], this
64
+ # will terminate the sequence at the first
65
+ # frame that has all values < -3.4
66
+
67
+ ### Tacotron Training
68
+ tts_schedule = [(2, 1e-3, 10_000, 16), # Progressive training schedule
69
+ (2, 5e-4, 20_000, 16), # (r, lr, step, batch_size)
70
+ (2, 2e-4, 40_000, 16), #
71
+ (2, 1e-4, 80_000, 16), # r = reduction factor (# of mel frames
72
+ (2, 3e-5, 160_000, 16), # synthesized for each decoder iteration)
73
+ (2, 1e-5, 320_000, 16)], # lr = learning rate
74
+
75
+ tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed
76
+ tts_eval_interval = 5000, # Number of steps between model evaluation (sample generation)
77
+ # Set to -1 to generate after completing epoch, or 0 to disable
78
+
79
+ tts_eval_num_samples = 1,                 # Number of evaluation samples to generate at each eval step
80
+
81
+ ### Data Preprocessing
82
+ max_mel_frames = 900,
83
+ rescale = True,
84
+ rescaling_max = 0.9,
85
+ synthesis_batch_size = 16, # For vocoder preprocessing and inference.
86
+
87
+ ### Mel Visualization and Griffin-Lim
88
+ signal_normalization = True,
89
+ power = 1.5,
90
+ griffin_lim_iters = 60,
91
+
92
+ ### Audio processing options
93
+ # fmax = 7600, # Should not exceed (sample_rate // 2)
94
+ allow_clipping_in_normalization = True, # Used when signal_normalization = True
95
+ clip_mels_length = True, # If true, discards samples exceeding max_mel_frames
96
+ use_lws = False, # "Fast spectrogram phase recovery using local weighted sums"
97
+ symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True,
98
+ # and [0, max_abs_value] if False
99
+ trim_silence = True, # Use with sample_rate of 16000 for best results
100
+
101
+ ### SV2TTS
102
+ speaker_embedding_size = 256, # Dimension for the speaker embedding
103
+ silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split
104
+ utterance_min_duration = 0.8, # Duration in seconds below which utterances are discarded
105
+ )
106
+
107
+ def hparams_debug_string():
108
+ return str(hparams)
pmt2/synthesizer/inference.py ADDED
@@ -0,0 +1,168 @@
1
+ import torch
2
+ from synthesizer import audio
3
+ from synthesizer.hparams import hparams
4
+ from synthesizer.models.tacotron import Tacotron
5
+ from synthesizer.persian_utils.symbols import symbols
6
+ from synthesizer.persian_utils.text import text_to_sequence
7
+ from vocoder.display import simple_table
8
+ from pathlib import Path
9
+ from typing import Union, List
10
+ import numpy as np
11
+ import librosa
12
+
13
+
14
+ class Synthesizer:
15
+ sample_rate = hparams.sample_rate
16
+ hparams = hparams
17
+
18
+ def __init__(self, model_fpath: Path, verbose=True):
19
+ """
20
+ The model isn't instantiated and loaded in memory until needed or until load() is called.
21
+
22
+ :param model_fpath: path to the trained model file
23
+ :param verbose: if False, prints less information when using the model
24
+ """
25
+ self.model_fpath = model_fpath
26
+ self.verbose = verbose
27
+
28
+ # Check for GPU
29
+ # if torch.cuda.is_available():
30
+ # self.device = torch.device("cuda")
31
+ # else:
32
+ # self.device = torch.device("cpu")
33
+ self.device = torch.device("cpu")
34
+
35
+ if self.verbose:
36
+ print("Synthesizer using device:", self.device)
37
+
38
+ # Tacotron model will be instantiated later on first use.
39
+ self._model = None
40
+
41
+ def is_loaded(self):
42
+ """
43
+ Whether the model is loaded in memory.
44
+ """
45
+ return self._model is not None
46
+
47
+ def load(self):
48
+ """
49
+ Instantiates and loads the model given the weights file that was passed in the constructor.
50
+ """
51
+ self._model = Tacotron(embed_dims=hparams.tts_embed_dims,
52
+ num_chars=len(symbols),
53
+ encoder_dims=hparams.tts_encoder_dims,
54
+ decoder_dims=hparams.tts_decoder_dims,
55
+ n_mels=hparams.num_mels,
56
+ fft_bins=hparams.num_mels,
57
+ postnet_dims=hparams.tts_postnet_dims,
58
+ encoder_K=hparams.tts_encoder_K,
59
+ lstm_dims=hparams.tts_lstm_dims,
60
+ postnet_K=hparams.tts_postnet_K,
61
+ num_highways=hparams.tts_num_highways,
62
+ dropout=hparams.tts_dropout,
63
+ stop_threshold=hparams.tts_stop_threshold,
64
+ speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
65
+
66
+ self._model.load(self.model_fpath)
67
+ self._model.eval()
68
+
69
+ if self.verbose:
70
+ # print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"]))
71
+ print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath, self._model.state_dict()["step"]))
72
+
73
+ def synthesize_spectrograms(self, texts: List[str],
74
+ embeddings: Union[np.ndarray, List[np.ndarray]],
75
+ return_alignments=False):
76
+ """
77
+ Synthesizes mel spectrograms from texts and speaker embeddings.
78
+
79
+ :param texts: a list of N text prompts to be synthesized
80
+ :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
81
+ :param return_alignments: if True, a matrix representing the alignments between the
82
+ characters
83
+ and each decoder output step will be returned for each spectrogram
84
+ :return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the
85
+ sequence length of spectrogram i, and possibly the alignments.
86
+ """
87
+ # Load the model on the first request.
88
+ if not self.is_loaded():
89
+ self.load()
90
+
91
+ # Preprocess text inputs
92
+ inputs = [text_to_sequence(text.strip(), hparams.tts_cleaner_names) for text in texts]
93
+ if not isinstance(embeddings, list):
94
+ embeddings = [embeddings]
95
+
96
+ # Batch inputs
97
+ batched_inputs = [inputs[i:i+hparams.synthesis_batch_size]
98
+ for i in range(0, len(inputs), hparams.synthesis_batch_size)]
99
+ batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size]
100
+ for i in range(0, len(embeddings), hparams.synthesis_batch_size)]
101
+
102
+ specs = []
103
+ for i, batch in enumerate(batched_inputs, 1):
104
+ if self.verbose:
105
+ print(f"\n| Generating {i}/{len(batched_inputs)}")
106
+
107
+ # Pad texts so they are all the same length
108
+ text_lens = [len(text) for text in batch]
109
+ max_text_len = max(text_lens)
110
+ chars = [pad1d(text, max_text_len) for text in batch]
111
+ chars = np.stack(chars)
112
+
113
+ # Stack speaker embeddings into 2D array for batch processing
114
+ speaker_embeds = np.stack(batched_embeds[i-1])
115
+
116
+ # Convert to tensor
117
+ chars = torch.tensor(chars).long().to(self.device)
118
+ speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
119
+
120
+ # Inference
121
+ _, mels, alignments = self._model.generate(chars, speaker_embeddings)
122
+ mels = mels.detach().cpu().numpy()
123
+ for m in mels:
124
+ # Trim silence from end of each spectrogram
125
+ while np.max(m[:, -1]) < hparams.tts_stop_threshold:
126
+ m = m[:, :-1]
127
+ specs.append(m)
128
+
129
+ if self.verbose:
130
+ print("\n\nDone.\n")
131
+ return (specs, alignments) if return_alignments else specs
132
+
133
+ @staticmethod
134
+ def load_preprocess_wav(fpath):
135
+ """
136
+ Loads and preprocesses an audio file under the same conditions the audio files were used to
137
+ train the synthesizer.
138
+ """
139
+ wav = librosa.load(str(fpath), sr=hparams.sample_rate)[0]
140
+ if hparams.rescale:
141
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
142
+ return wav
143
+
144
+ @staticmethod
145
+ def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]):
146
+ """
147
+ Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that
148
+ were fed to the synthesizer when training.
149
+ """
150
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
151
+ wav = Synthesizer.load_preprocess_wav(fpath_or_wav)
152
+ else:
153
+ wav = fpath_or_wav
154
+
155
+ mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
156
+ return mel_spectrogram
157
+
158
+ @staticmethod
159
+ def griffin_lim(mel):
160
+ """
161
+ Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built
162
+ with the same parameters present in hparams.py.
163
+ """
164
+ return audio.inv_mel_spectrogram(mel, hparams)
165
+
166
+
167
+ def pad1d(x, max_len, pad_value=0):
168
+ return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
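+ # Usage sketch (illustrative only; the model path and the random embedding below are
+ # placeholders, not values shipped with this commit):
+ #   synth = Synthesizer(Path("saved_models/synthesizer.pt"))
+ #   embed = np.random.rand(256).astype(np.float32)   # stand-in for a real encoder speaker embedding
+ #   specs = synth.synthesize_spectrograms(["سلام دنیا"], [embed])
+ #   wav = Synthesizer.griffin_lim(specs[0])           # invert the mel spectrogram to a waveform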
pmt2/synthesizer/models/tacotron.py ADDED
@@ -0,0 +1,519 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from pathlib import Path
7
+ from typing import Union
8
+
9
+
10
+ class HighwayNetwork(nn.Module):
11
+ def __init__(self, size):
12
+ super().__init__()
13
+ self.W1 = nn.Linear(size, size)
14
+ self.W2 = nn.Linear(size, size)
15
+ self.W1.bias.data.fill_(0.)
16
+
17
+ def forward(self, x):
18
+ x1 = self.W1(x)
19
+ x2 = self.W2(x)
20
+ g = torch.sigmoid(x2)
21
+ y = g * F.relu(x1) + (1. - g) * x
22
+ return y
23
+
24
+
25
+ class Encoder(nn.Module):
26
+ def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
27
+ super().__init__()
28
+ prenet_dims = (encoder_dims, encoder_dims)
29
+ cbhg_channels = encoder_dims
30
+ self.embedding = nn.Embedding(num_chars, embed_dims)
31
+ self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
32
+ dropout=dropout)
33
+ self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
34
+ proj_channels=[cbhg_channels, cbhg_channels],
35
+ num_highways=num_highways)
36
+
37
+ def forward(self, x, speaker_embedding=None):
38
+ x = self.embedding(x)
39
+ x = self.pre_net(x)
40
+ x.transpose_(1, 2)
41
+ x = self.cbhg(x)
42
+ if speaker_embedding is not None:
43
+ x = self.add_speaker_embedding(x, speaker_embedding)
44
+ return x
45
+
46
+ def add_speaker_embedding(self, x, speaker_embedding):
47
+ # SV2TTS
48
+ # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
49
+ # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
50
+ # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
51
+ # This concats the speaker embedding for each char in the encoder output
52
+
53
+ # Save the dimensions as human-readable names
54
+ batch_size = x.size()[0]
55
+ num_chars = x.size()[1]
56
+
57
+ if speaker_embedding.dim() == 1:
58
+ idx = 0
59
+ else:
60
+ idx = 1
61
+
62
+ # Start by making a copy of each speaker embedding to match the input text length
63
+ # The output of this has size (batch_size, num_chars * tts_embed_dims)
64
+ speaker_embedding_size = speaker_embedding.size()[idx]
65
+ e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
66
+
67
+ # Reshape it and transpose
68
+ e = e.reshape(batch_size, speaker_embedding_size, num_chars)
69
+ e = e.transpose(1, 2)
70
+
71
+ # Concatenate the tiled speaker embedding with the encoder output
72
+ x = torch.cat((x, e), 2)
73
+ return x
74
+
75
+
76
+ class BatchNormConv(nn.Module):
77
+ def __init__(self, in_channels, out_channels, kernel, relu=True):
78
+ super().__init__()
79
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
80
+ self.bnorm = nn.BatchNorm1d(out_channels)
81
+ self.relu = relu
82
+
83
+ def forward(self, x):
84
+ x = self.conv(x)
85
+ x = F.relu(x) if self.relu is True else x
86
+ return self.bnorm(x)
87
+
88
+
89
+ class CBHG(nn.Module):
90
+ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
91
+ super().__init__()
92
+
93
+ # List of all rnns to call `flatten_parameters()` on
94
+ self._to_flatten = []
95
+
96
+ self.bank_kernels = [i for i in range(1, K + 1)]
97
+ self.conv1d_bank = nn.ModuleList()
98
+ for k in self.bank_kernels:
99
+ conv = BatchNormConv(in_channels, channels, k)
100
+ self.conv1d_bank.append(conv)
101
+
102
+ self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
103
+
104
+ self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
105
+ self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
106
+
107
+ # Fix the highway input if necessary
108
+ if proj_channels[-1] != channels:
109
+ self.highway_mismatch = True
110
+ self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
111
+ else:
112
+ self.highway_mismatch = False
113
+
114
+ self.highways = nn.ModuleList()
115
+ for i in range(num_highways):
116
+ hn = HighwayNetwork(channels)
117
+ self.highways.append(hn)
118
+
119
+ self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
120
+ self._to_flatten.append(self.rnn)
121
+
122
+ # Avoid fragmentation of RNN parameters and associated warning
123
+ self._flatten_parameters()
124
+
125
+ def forward(self, x):
126
+ # Although we `_flatten_parameters()` on init, when using DataParallel
127
+ # the model gets replicated, making it no longer guaranteed that the
128
+ # weights are contiguous in GPU memory. Hence, we must call it again
129
+ self._flatten_parameters()
130
+
131
+ # Save these for later
132
+ residual = x
133
+ seq_len = x.size(-1)
134
+ conv_bank = []
135
+
136
+ # Convolution Bank
137
+ for conv in self.conv1d_bank:
138
+ c = conv(x) # Convolution
139
+ conv_bank.append(c[:, :, :seq_len])
140
+
141
+ # Stack along the channel axis
142
+ conv_bank = torch.cat(conv_bank, dim=1)
143
+
144
+ # dump the last padding to fit residual
145
+ x = self.maxpool(conv_bank)[:, :, :seq_len]
146
+
147
+ # Conv1d projections
148
+ x = self.conv_project1(x)
149
+ x = self.conv_project2(x)
150
+
151
+ # Residual Connect
152
+ x = x + residual
153
+
154
+ # Through the highways
155
+ x = x.transpose(1, 2)
156
+ if self.highway_mismatch is True:
157
+ x = self.pre_highway(x)
158
+ for h in self.highways: x = h(x)
159
+
160
+ # And then the RNN
161
+ x, _ = self.rnn(x)
162
+ return x
163
+
164
+ def _flatten_parameters(self):
165
+ """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
166
+ to improve efficiency and avoid PyTorch yelling at us."""
167
+ [m.flatten_parameters() for m in self._to_flatten]
168
+
169
+ class PreNet(nn.Module):
170
+ def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
171
+ super().__init__()
172
+ self.fc1 = nn.Linear(in_dims, fc1_dims)
173
+ self.fc2 = nn.Linear(fc1_dims, fc2_dims)
174
+ self.p = dropout
175
+
176
+ def forward(self, x):
177
+ x = self.fc1(x)
178
+ x = F.relu(x)
179
+ x = F.dropout(x, self.p, training=True)
180
+ x = self.fc2(x)
181
+ x = F.relu(x)
182
+ x = F.dropout(x, self.p, training=True)
183
+ return x
184
+
185
+
186
+ class Attention(nn.Module):
187
+ def __init__(self, attn_dims):
188
+ super().__init__()
189
+ self.W = nn.Linear(attn_dims, attn_dims, bias=False)
190
+ self.v = nn.Linear(attn_dims, 1, bias=False)
191
+
192
+ def forward(self, encoder_seq_proj, query, t):
193
+
194
+ # print(encoder_seq_proj.shape)
195
+ # Transform the query vector
196
+ query_proj = self.W(query).unsqueeze(1)
197
+
198
+ # Compute the scores
199
+ u = self.v(torch.tanh(encoder_seq_proj + query_proj))
200
+ scores = F.softmax(u, dim=1)
201
+
202
+ return scores.transpose(1, 2)
203
+
204
+
205
+ class LSA(nn.Module):
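+ # Location-sensitive attention: scores are computed from the decoder query, the projected
+ # encoder sequence, and convolved features of the cumulative attention from earlier steps.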
206
+ def __init__(self, attn_dim, kernel_size=31, filters=32):
207
+ super().__init__()
208
+ self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
209
+ self.L = nn.Linear(filters, attn_dim, bias=False)
210
+ self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
211
+ self.v = nn.Linear(attn_dim, 1, bias=False)
212
+ self.cumulative = None
213
+ self.attention = None
214
+
215
+ def init_attention(self, encoder_seq_proj):
216
+ device = next(self.parameters()).device # use same device as parameters
217
+ b, t, c = encoder_seq_proj.size()
218
+ self.cumulative = torch.zeros(b, t, device=device)
219
+ self.attention = torch.zeros(b, t, device=device)
220
+
221
+ def forward(self, encoder_seq_proj, query, t, chars):
222
+
223
+ if t == 0: self.init_attention(encoder_seq_proj)
224
+
225
+ processed_query = self.W(query).unsqueeze(1)
226
+
227
+ location = self.cumulative.unsqueeze(1)
228
+ processed_loc = self.L(self.conv(location).transpose(1, 2))
229
+
230
+ u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
231
+ u = u.squeeze(-1)
232
+
233
+ # Mask zero padding chars
234
+ u = u * (chars != 0).float()
235
+
236
+ # Smooth Attention
237
+ # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
238
+ scores = F.softmax(u, dim=1)
239
+ self.attention = scores
240
+ self.cumulative = self.cumulative + self.attention
241
+
242
+ return scores.unsqueeze(-1).transpose(1, 2)
243
+
244
+
245
+ class Decoder(nn.Module):
246
+ # Class variable because its value doesn't change between instances,
247
+ # yet it ought to be scoped by the class because it's a property of a Decoder
248
+ max_r = 20
249
+ def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
250
+ dropout, speaker_embedding_size):
251
+ super().__init__()
252
+ self.register_buffer("r", torch.tensor(1, dtype=torch.int))
253
+ self.n_mels = n_mels
254
+ prenet_dims = (decoder_dims * 2, decoder_dims * 2)
255
+ self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
256
+ dropout=dropout)
257
+ self.attn_net = LSA(decoder_dims)
258
+ self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
259
+ self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
260
+ self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
261
+ self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
262
+ self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
263
+ self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
264
+
265
+ def zoneout(self, prev, current, p=0.1):
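+ # Zoneout regularisation: during training, each unit keeps its previous value with probability p.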
266
+ device = next(self.parameters()).device # Use same device as parameters
267
+ mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
268
+ return prev * mask + current * (1 - mask)
269
+
270
+ def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
271
+ hidden_states, cell_states, context_vec, t, chars):
272
+
273
+ # Need this for reshaping mels
274
+ batch_size = encoder_seq.size(0)
275
+
276
+ # Unpack the hidden and cell states
277
+ attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
278
+ rnn1_cell, rnn2_cell = cell_states
279
+
280
+ # PreNet for the Attention RNN
281
+ prenet_out = self.prenet(prenet_in)
282
+
283
+ # Compute the Attention RNN hidden state
284
+ attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
285
+ attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
286
+
287
+ # Compute the attention scores
288
+ scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
289
+
290
+ # Dot product to create the context vector
291
+ context_vec = scores @ encoder_seq
292
+ context_vec = context_vec.squeeze(1)
293
+
294
+ # Concat Attention RNN output w. Context Vector & project
295
+ x = torch.cat([context_vec, attn_hidden], dim=1)
296
+ x = self.rnn_input(x)
297
+
298
+ # Compute first Residual RNN
299
+ rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
300
+ if self.training:
301
+ rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
302
+ else:
303
+ rnn1_hidden = rnn1_hidden_next
304
+ x = x + rnn1_hidden
305
+
306
+ # Compute second Residual RNN
307
+ rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
308
+ if self.training:
309
+ rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
310
+ else:
311
+ rnn2_hidden = rnn2_hidden_next
312
+ x = x + rnn2_hidden
313
+
314
+ # Project Mels
315
+ mels = self.mel_proj(x)
316
+ mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
317
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
318
+ cell_states = (rnn1_cell, rnn2_cell)
319
+
320
+ # Stop token prediction
321
+ s = torch.cat((x, context_vec), dim=1)
322
+ s = self.stop_proj(s)
323
+ stop_tokens = torch.sigmoid(s)
324
+
325
+ return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
326
+
327
+
328
+ class Tacotron(nn.Module):
329
+ def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
330
+ fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
331
+ dropout, stop_threshold, speaker_embedding_size):
332
+ super().__init__()
333
+ self.n_mels = n_mels
334
+ self.lstm_dims = lstm_dims
335
+ self.encoder_dims = encoder_dims
336
+ self.decoder_dims = decoder_dims
337
+ self.speaker_embedding_size = speaker_embedding_size
338
+ self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
339
+ encoder_K, num_highways, dropout)
340
+ self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
341
+ self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
342
+ dropout, speaker_embedding_size)
343
+ self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
344
+ [postnet_dims, fft_bins], num_highways)
345
+ self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
346
+
347
+ self.init_model()
348
+ self.num_params()
349
+
350
+ self.register_buffer("step", torch.zeros(1, dtype=torch.long))
351
+ self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
352
+
353
+ @property
354
+ def r(self):
355
+ return self.decoder.r.item()
356
+
357
+ @r.setter
358
+ def r(self, value):
359
+ self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
360
+
361
+ def forward(self, x, m, speaker_embedding):
362
+ device = next(self.parameters()).device # use same device as parameters
363
+
364
+ self.step += 1
365
+ batch_size, _, steps = m.size()
366
+
367
+ # Initialise all hidden states and pack into tuple
368
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
369
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
370
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
371
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
372
+
373
+ # Initialise all lstm cell states and pack into tuple
374
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
375
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
376
+ cell_states = (rnn1_cell, rnn2_cell)
377
+
378
+ # <GO> Frame for start of decoder loop
379
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
380
+
381
+ # Need an initial context vector
382
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
383
+
384
+ # SV2TTS: Run the encoder with the speaker embedding
385
+ # The projection avoids unnecessary matmuls in the decoder loop
386
+ encoder_seq = self.encoder(x, speaker_embedding)
387
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
388
+
389
+ # Need a couple of lists for outputs
390
+ mel_outputs, attn_scores, stop_outputs = [], [], []
391
+
392
+ # Run the decoder loop
393
+ for t in range(0, steps, self.r):
394
+ prenet_in = m[:, :, t - 1] if t > 0 else go_frame
395
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
396
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
397
+ hidden_states, cell_states, context_vec, t, x)
398
+ mel_outputs.append(mel_frames)
399
+ attn_scores.append(scores)
400
+ stop_outputs.extend([stop_tokens] * self.r)
401
+
402
+ # Concat the mel outputs into sequence
403
+ mel_outputs = torch.cat(mel_outputs, dim=2)
404
+
405
+ # Post-Process for Linear Spectrograms
406
+ postnet_out = self.postnet(mel_outputs)
407
+ linear = self.post_proj(postnet_out)
408
+ linear = linear.transpose(1, 2)
409
+
410
+ # For easy visualisation
411
+ attn_scores = torch.cat(attn_scores, 1)
412
+ # attn_scores = attn_scores.cpu().data.numpy()
413
+ stop_outputs = torch.cat(stop_outputs, 1)
414
+
415
+ return mel_outputs, linear, attn_scores, stop_outputs
416
+
417
+ def generate(self, x, speaker_embedding=None, steps=2000):
418
+ self.eval()
419
+ device = next(self.parameters()).device # use same device as parameters
420
+
421
+ batch_size, _ = x.size()
422
+
423
+ # Need to initialise all hidden states and pack into tuple for tidyness
424
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
425
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
426
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
427
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
428
+
429
+ # Need to initialise all lstm cell states and pack into tuple for tidyness
430
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
431
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
432
+ cell_states = (rnn1_cell, rnn2_cell)
433
+
434
+ # Need a <GO> Frame for start of decoder loop
435
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
436
+
437
+ # Need an initial context vector
438
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
439
+
440
+ # SV2TTS: Run the encoder with the speaker embedding
441
+ # The projection avoids unnecessary matmuls in the decoder loop
442
+ encoder_seq = self.encoder(x, speaker_embedding)
443
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
444
+
445
+ # Need a couple of lists for outputs
446
+ mel_outputs, attn_scores, stop_outputs = [], [], []
447
+
448
+ # Run the decoder loop
449
+ for t in range(0, steps, self.r):
450
+ prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
451
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
452
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
453
+ hidden_states, cell_states, context_vec, t, x)
454
+ mel_outputs.append(mel_frames)
455
+ attn_scores.append(scores)
456
+ stop_outputs.extend([stop_tokens] * self.r)
457
+ # Stop the loop when all stop tokens in batch exceed threshold
458
+ if (stop_tokens > 0.5).all() and t > 10: break
459
+
460
+ # Concat the mel outputs into sequence
461
+ mel_outputs = torch.cat(mel_outputs, dim=2)
462
+
463
+ # Post-Process for Linear Spectrograms
464
+ postnet_out = self.postnet(mel_outputs)
465
+ linear = self.post_proj(postnet_out)
466
+
467
+
468
+ linear = linear.transpose(1, 2)
469
+
470
+ # For easy visualisation
471
+ attn_scores = torch.cat(attn_scores, 1)
472
+ stop_outputs = torch.cat(stop_outputs, 1)
473
+
474
+ self.train()
475
+
476
+ return mel_outputs, linear, attn_scores
477
+
478
+ def init_model(self):
479
+ for p in self.parameters():
480
+ if p.dim() > 1: nn.init.xavier_uniform_(p)
481
+
482
+ def get_step(self):
483
+ return self.step.data.item()
484
+
485
+ def reset_step(self):
486
+ # assignment to parameters or buffers is overloaded, updates internal dict entry
487
+ self.step = self.step.data.new_tensor(1)
488
+
489
+ def log(self, path, msg):
490
+ with open(path, "a") as f:
491
+ print(msg, file=f)
492
+
493
+ def load(self, path, optimizer=None):
494
+ # Use device of model params as location for loaded state
495
+ device = next(self.parameters()).device
496
+ checkpoint = torch.load(str(path), map_location=device)
497
+ self.load_state_dict(checkpoint["model_state"])
498
+
499
+ if "optimizer_state" in checkpoint and optimizer is not None:
500
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
501
+
502
+ def save(self, path, optimizer=None):
503
+ if optimizer is not None:
504
+ torch.save({
505
+ "model_state": self.state_dict(),
506
+ "optimizer_state": optimizer.state_dict(),
507
+ }, str(path))
508
+ else:
509
+ torch.save({
510
+ "model_state": self.state_dict(),
511
+ }, str(path))
512
+
513
+
514
+ def num_params(self, print_out=True):
515
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
516
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
517
+ if print_out:
518
+ print("Trainable Parameters: %.3fM" % parameters)
519
+ return parameters
pmt2/synthesizer/persian_utils/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ import torch
2
+
3
+
4
+ _output_ref = None
5
+ _replicas_ref = None
6
+
7
+ def data_parallel_workaround(model, *input):
8
+ global _output_ref
9
+ global _replicas_ref
10
+ device_ids = list(range(torch.cuda.device_count()))
11
+ output_device = device_ids[0]
12
+ replicas = torch.nn.parallel.replicate(model, device_ids)
13
+ # input.shape = (num_args, batch, ...)
14
+ inputs = torch.nn.parallel.scatter(input, device_ids)
15
+ # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
16
+ replicas = replicas[:len(inputs)]
17
+ outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
18
+ y_hat = torch.nn.parallel.gather(outputs, output_device)
19
+ _output_ref = outputs
20
+ _replicas_ref = replicas
21
+ return y_hat
22
+
23
+
24
+ class ValueWindow():
25
+ def __init__(self, window_size=100):
26
+ self._window_size = window_size
27
+ self._values = []
28
+
29
+ def append(self, x):
30
+ self._values = self._values[-(self._window_size - 1):] + [x]
31
+
32
+ @property
33
+ def sum(self):
34
+ return sum(self._values)
35
+
36
+ @property
37
+ def count(self):
38
+ return len(self._values)
39
+
40
+ @property
41
+ def average(self):
42
+ return self.sum / max(1, self.count)
43
+
44
+ def reset(self):
45
+ self._values = []
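+ # ValueWindow keeps a sliding window of the last `window_size` values and exposes their running
+ # sum/count/average (typically used to smooth loss readouts during training).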
pmt2/synthesizer/persian_utils/plot.py ADDED
@@ -0,0 +1,82 @@
1
+ import numpy as np
2
+
3
+
4
+ def split_title_line(title_text, max_words=5):
5
+ """
6
+ Splits a title string into multiple lines, with at most max_words words per line.
8
+ """
9
+ seq = title_text.split()
10
+ return "\n".join([" ".join(seq[i:i + max_words]) for i in range(0, len(seq), max_words)])
11
+
12
+
13
+ def plot_alignment(alignment, path, title=None, split_title=False, max_len=None):
14
+ import matplotlib
15
+ matplotlib.use("Agg")
16
+ import matplotlib.pyplot as plt
17
+
18
+ if max_len is not None:
19
+ alignment = alignment[:, :max_len]
20
+
21
+ fig = plt.figure(figsize=(8, 6))
22
+ ax = fig.add_subplot(111)
23
+
24
+ im = ax.imshow(
25
+ alignment,
26
+ aspect="auto",
27
+ origin="lower",
28
+ interpolation="none")
29
+ fig.colorbar(im, ax=ax)
30
+ xlabel = "Decoder timestep"
31
+
32
+ if split_title:
33
+ title = split_title_line(title)
34
+
35
+ plt.xlabel(xlabel)
36
+ plt.title(title)
37
+ plt.ylabel("Encoder timestep")
38
+ plt.tight_layout()
39
+ plt.savefig(path, format="png")
40
+ plt.close()
41
+
42
+
43
+ def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False):
44
+ import matplotlib
45
+ matplotlib.use("Agg")
46
+ import matplotlib.pyplot as plt
47
+
48
+ if max_len is not None:
49
+ target_spectrogram = target_spectrogram[:max_len]
50
+ pred_spectrogram = pred_spectrogram[:max_len]
51
+
52
+ if split_title:
53
+ title = split_title_line(title)
54
+
55
+ fig = plt.figure(figsize=(10, 8))
56
+ # Set common labels
57
+ fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)
58
+
59
+ #target spectrogram subplot
60
+ if target_spectrogram is not None:
61
+ ax1 = fig.add_subplot(311)
62
+ ax2 = fig.add_subplot(312)
63
+
64
+ if auto_aspect:
65
+ im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none")
66
+ else:
67
+ im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none")
68
+ ax1.set_title("Target Mel-Spectrogram")
69
+ fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
70
+ ax2.set_title("Predicted Mel-Spectrogram")
71
+ else:
72
+ ax2 = fig.add_subplot(211)
73
+
74
+ if auto_aspect:
75
+ im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none")
76
+ else:
77
+ im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none")
78
+ fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
79
+
80
+ plt.tight_layout()
81
+ plt.savefig(path, format="png")
82
+ plt.close()
pmt2/synthesizer/persian_utils/symbols.py ADDED
@@ -0,0 +1,17 @@
1
+
2
+ """
3
+ Defines the set of symbols used in text input to the model.
4
+
5
+ The default is a set of Persian characters plus basic Latin letters and punctuation that works well for Persian text.
6
+ """
7
+ # from . import cmudict
8
+
9
+ _pad = "_"
10
+ _eos = "~"
11
+ _characters = "ءابتثجحخدذرزسشصضطظعغفقلمنهويِپچژکگیآۀةأؤإئًَُّ!(),-.:;? ̠،…؛؟‌٪#üABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_–@+/\u200c"
12
+
13
+ # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
14
+ #_arpabet = ["@' + s for s in cmudict.valid_symbols]
15
+
16
+ # Export all symbols:
17
+ symbols = [_pad, _eos] + list(_characters) #+ _arpabet
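+ # Index 0 is the padding symbol "_"; the attention module in models/tacotron.py relies on this
+ # when it masks padded character positions with (chars != 0).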
pmt2/synthesizer/persian_utils/text.py ADDED
@@ -0,0 +1,40 @@
1
+ from synthesizer.persian_utils.symbols import symbols
2
+
3
+
4
+ # Mappings from symbol to numeric ID and vice versa:
5
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
6
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
7
+
8
+
9
+ def text_to_sequence(text, cleaner_names):
10
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
11
+
12
+ The text can optionally have ARPAbet sequences enclosed in curly braces embedded
13
+ in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
14
+
15
+ Args:
16
+ text: string to convert to a sequence
17
+ cleaner_names: names of the cleaner functions to run the text through
18
+
19
+ Returns:
20
+ List of integers corresponding to the symbols in the text
21
+ """
22
+ # print("######")
23
+ # print(cleaner_names)
24
+ if cleaner_names != ['persian_cleaners']:
25
+ raise ValueError("cleaner is not persian!")
26
+ sequence = []
27
+ for phoneme in text:
28
+ sequence.append(_symbol_to_id[phoneme])
29
+ # print(sequence)
30
+ # print("************")
31
+ return sequence
32
+
33
+
34
+ def sequence_to_text(sequence):
35
+ """Converts a sequence of IDs back to a string"""
36
+ result = []
37
+ for symbol_id in sequence:
38
+ result.append(_id_to_symbol[symbol_id])
39
+ return " ".join(result)
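+ # Round-trip example (derived from the two functions above):
+ #   ids = text_to_sequence("سلام", ["persian_cleaners"])   # one ID per character
+ #   sequence_to_text(ids)                                   # maps the IDs back to their characters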
pmt2/synthesizer/preprocess.py ADDED
@@ -0,0 +1,259 @@
1
+ from multiprocessing.pool import Pool
2
+ from synthesizer import audio
3
+ from functools import partial
4
+ from itertools import chain
5
+ from encoder import inference as encoder
6
+ from pathlib import Path
7
+ from utils import logmmse
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import librosa
11
+
12
+
13
+ def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int, skip_existing: bool, hparams,
14
+ no_alignments: bool, datasets_name: str, subfolders: str):
15
+ # Gather the input directories
16
+ dataset_root = datasets_root.joinpath(datasets_name)
17
+ input_dirs = [dataset_root.joinpath(subfolder.strip()) for subfolder in subfolders.split(",")]
18
+ print("\n ".join(map(str, ["Using data from:"] + input_dirs)))
19
+ assert all(input_dir.exists() for input_dir in input_dirs)
20
+
21
+ # Create the output directories for each output file type
22
+ out_dir.joinpath("mels").mkdir(exist_ok=True)
23
+ out_dir.joinpath("audio").mkdir(exist_ok=True)
24
+
25
+ # Create a metadata file
26
+ metadata_fpath = out_dir.joinpath("train.txt")
27
+ metadata_file = metadata_fpath.open("a" if skip_existing else "w", encoding="utf-8")
28
+
29
+ # Preprocess the dataset
30
+ speaker_dirs = list(chain.from_iterable(input_dir.glob("*") for input_dir in input_dirs))
31
+ func = partial(preprocess_speaker, out_dir=out_dir, skip_existing=skip_existing,
32
+ hparams=hparams, no_alignments=no_alignments)
33
+ job = Pool(n_processes).imap(func, speaker_dirs)
34
+ for speaker_metadata in tqdm(job, datasets_name, len(speaker_dirs), unit="speakers"):
35
+ for metadatum in speaker_metadata:
36
+ metadata_file.write("|".join(str(x) for x in metadatum) + "\n")
37
+ metadata_file.close()
38
+ # Verify the contents of the metadata file
39
+ with metadata_fpath.open("r", encoding="utf-8") as metadata_file:
40
+ metadata = [line.split("|") for line in metadata_file]
41
+ mel_frames = sum([int(m[4]) for m in metadata])
42
+ timesteps = sum([int(m[3]) for m in metadata])
43
+ sample_rate = hparams.sample_rate
44
+ hours = (timesteps / sample_rate) / 3600
45
+ print("The dataset consists of %d utterances, %d mel frames, %d audio timesteps (%.2f hours)." %
46
+ (len(metadata), mel_frames, timesteps, hours))
47
+ print("Max input length (text chars): %d" % max(len(m[5]) for m in metadata))
48
+ print("Max mel frames length: %d" % max(int(m[4]) for m in metadata))
49
+ print("Max audio timesteps length: %d" % max(int(m[3]) for m in metadata))
50
+
51
+
52
+ def preprocess_speaker(speaker_dir, out_dir: Path, skip_existing: bool, hparams, no_alignments: bool):
53
+ metadata = []
54
+ for book_dir in speaker_dir.glob("*"):
55
+ if no_alignments:
56
+ # Gather the utterance audios and texts
57
+ # LibriTTS uses .wav but we will include extensions for compatibility with other datasets
58
+ extensions = ["*.wav", "*.flac", "*.mp3"]
59
+ for extension in extensions:
60
+ wav_fpaths = book_dir.glob(extension)
61
+
62
+ for wav_fpath in wav_fpaths:
63
+ # Load the audio waveform
64
+ wav, _ = librosa.load(str(wav_fpath), sr=hparams.sample_rate)
65
+ if hparams.rescale:
66
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
67
+
68
+ # Get the corresponding text
69
+ # Check for .txt (for compatibility with other datasets)
70
+ text_fpath = wav_fpath.with_suffix(".txt")
71
+ if not text_fpath.exists():
72
+ # Check for .normalized.txt (LibriTTS)
73
+ text_fpath = wav_fpath.with_suffix(".normalized.txt")
74
+ assert text_fpath.exists()
75
+ with text_fpath.open("r") as text_file:
76
+ text = "".join([line for line in text_file])
77
+ text = text.replace("\"", "")
78
+ text = text.strip()
79
+
80
+ # Process the utterance
81
+ metadata.append(process_utterance(wav, text, out_dir, str(wav_fpath.with_suffix("").name),
82
+ skip_existing, hparams))
83
+ else:
84
+ # Process alignment file (LibriSpeech support)
85
+ # Gather the utterance audios and texts
86
+ try:
87
+ alignments_fpath = next(book_dir.glob("*.alignment.txt"))
88
+ with alignments_fpath.open("r") as alignments_file:
89
+ alignments = [line.rstrip().split(" ") for line in alignments_file]
90
+ except StopIteration:
91
+ # A few alignment files will be missing
92
+ continue
93
+
94
+ # Iterate over each entry in the alignments file
95
+ for wav_fname, words, end_times in alignments:
96
+ wav_fpath = book_dir.joinpath(wav_fname + ".flac")
97
+ assert wav_fpath.exists()
98
+ words = words.replace("\"", "").split(",")
99
+ end_times = list(map(float, end_times.replace("\"", "").split(",")))
100
+
101
+ # Process each sub-utterance
102
+ wavs, texts = split_on_silences(wav_fpath, words, end_times, hparams)
103
+ for i, (wav, text) in enumerate(zip(wavs, texts)):
104
+ sub_basename = "%s_%02d" % (wav_fname, i)
105
+ metadata.append(process_utterance(wav, text, out_dir, sub_basename,
106
+ skip_existing, hparams))
107
+
108
+ return [m for m in metadata if m is not None]
109
+
110
+
111
+ def split_on_silences(wav_fpath, words, end_times, hparams):
112
+ # Load the audio waveform
113
+ wav, _ = librosa.load(str(wav_fpath), sr=hparams.sample_rate)
114
+ if hparams.rescale:
115
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
116
+
117
+ words = np.array(words)
118
+ start_times = np.array([0.0] + end_times[:-1])
119
+ end_times = np.array(end_times)
120
+ assert len(words) == len(end_times) == len(start_times)
121
+ assert words[0] == "" and words[-1] == ""
122
+
123
+ # Find pauses that are too long
124
+ mask = (words == "") & (end_times - start_times >= hparams.silence_min_duration_split)
125
+ mask[0] = mask[-1] = True
126
+ breaks = np.where(mask)[0]
127
+
128
+ # Profile the noise from the silences and perform noise reduction on the waveform
129
+ silence_times = [[start_times[i], end_times[i]] for i in breaks]
130
+ silence_times = (np.array(silence_times) * hparams.sample_rate).astype(int)
131
+ noisy_wav = np.concatenate([wav[stime[0]:stime[1]] for stime in silence_times])
132
+ if len(noisy_wav) > hparams.sample_rate * 0.02:
133
+ profile = logmmse.profile_noise(noisy_wav, hparams.sample_rate)
134
+ wav = logmmse.denoise(wav, profile, eta=0)
135
+
136
+ # Re-attach segments that are too short
137
+ segments = list(zip(breaks[:-1], breaks[1:]))
138
+ segment_durations = [start_times[end] - end_times[start] for start, end in segments]
139
+ i = 0
140
+ while i < len(segments) and len(segments) > 1:
141
+ if segment_durations[i] < hparams.utterance_min_duration:
142
+ # See if the segment can be re-attached with the right or the left segment
143
+ left_duration = float("inf") if i == 0 else segment_durations[i - 1]
144
+ right_duration = float("inf") if i == len(segments) - 1 else segment_durations[i + 1]
145
+ joined_duration = segment_durations[i] + min(left_duration, right_duration)
146
+
147
+ # Do not re-attach if it causes the joined utterance to be too long
148
+ if joined_duration > hparams.hop_size * hparams.max_mel_frames / hparams.sample_rate:
149
+ i += 1
150
+ continue
151
+
152
+ # Re-attach the segment with the neighbour of shortest duration
153
+ j = i - 1 if left_duration <= right_duration else i
154
+ segments[j] = (segments[j][0], segments[j + 1][1])
155
+ segment_durations[j] = joined_duration
156
+ del segments[j + 1], segment_durations[j + 1]
157
+ else:
158
+ i += 1
159
+
160
+ # Split the utterance
161
+ segment_times = [[end_times[start], start_times[end]] for start, end in segments]
162
+ segment_times = (np.array(segment_times) * hparams.sample_rate).astype(int)
163
+ wavs = [wav[segment_time[0]:segment_time[1]] for segment_time in segment_times]
164
+ texts = [" ".join(words[start + 1:end]).replace(" ", " ") for start, end in segments]
165
+
166
+ # # DEBUG: play the audio segments (run with -n=1)
167
+ # import sounddevice as sd
168
+ # if len(wavs) > 1:
169
+ # print("This sentence was split in %d segments:" % len(wavs))
170
+ # else:
171
+ # print("There are no silences long enough for this sentence to be split:")
172
+ # for wav, text in zip(wavs, texts):
173
+ # # Pad the waveform with 1 second of silence because sounddevice tends to cut them early
174
+ # # when playing them. You shouldn't need to do that in your parsers.
175
+ # wav = np.concatenate((wav, [0] * 16000))
176
+ # print("\t%s" % text)
177
+ # sd.play(wav, 16000, blocking=True)
178
+ # print("")
179
+
180
+ return wavs, texts
181
+
182
+
183
+ def process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
184
+ skip_existing: bool, hparams):
185
+ ## FOR REFERENCE:
186
+ # For you not to lose your head if you ever wish to change things here or implement your own
187
+ # synthesizer.
188
+ # - Both the audios and the mel spectrograms are saved as numpy arrays
189
+ # - There is no processing done to the audios that will be saved to disk beyond volume
190
+ # normalization (in split_on_silences)
191
+ # - However, pre-emphasis is applied to the audios before computing the mel spectrogram. This
192
+ # is why we re-apply it on the audio on the side of the vocoder.
193
+ # - Librosa pads the waveform before computing the mel spectrogram. Here, the waveform is saved
194
+ # without extra padding. This means that you won't have an exact relation between the length
195
+ # of the wav and of the mel spectrogram. See the vocoder data loader.
196
+
197
+
198
+ # Skip existing utterances if needed
199
+ mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename)
200
+ wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename)
201
+ if skip_existing and mel_fpath.exists() and wav_fpath.exists():
202
+ return None
203
+
204
+ #print(text)
205
+
206
+ # Trim silence
207
+ if hparams.trim_silence:
208
+ wav = encoder.preprocess_wav(wav, normalize=False, trim_silence=True)
209
+
210
+ # Skip utterances that are too short
211
+ if len(wav) < hparams.utterance_min_duration * hparams.sample_rate:
212
+ return None
213
+
214
+ # Compute the mel spectrogram
215
+ mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
216
+ mel_frames = mel_spectrogram.shape[1]
217
+
218
+ # Skip utterances that are too long
219
+ if mel_frames > hparams.max_mel_frames and hparams.clip_mels_length:
220
+ return None
221
+
222
+ # Write the spectrogram, embed and audio to disk
223
+ np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False)
224
+ np.save(wav_fpath, wav, allow_pickle=False)
225
+
226
+ # Return a tuple describing this training example
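+ # (audio filename | mel filename | embed filename | n audio samples | n mel frames | text);
+ # preprocess_dataset joins these fields with "|" into one line of train.txt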
227
+ return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text
228
+
229
+
230
+ def embed_utterance(fpaths, encoder_model_fpath):
231
+ if not encoder.is_loaded():
232
+ encoder.load_model(encoder_model_fpath)
233
+
234
+ # Compute the speaker embedding of the utterance
235
+ wav_fpath, embed_fpath = fpaths
236
+ wav = np.load(wav_fpath)
237
+ wav = encoder.preprocess_wav(wav)
238
+ embed = encoder.embed_utterance(wav)
239
+ np.save(embed_fpath, embed, allow_pickle=False)
240
+
241
+
242
+ def create_embeddings(synthesizer_root: Path, encoder_model_fpath: Path, n_processes: int):
243
+ wav_dir = synthesizer_root.joinpath("audio")
244
+ metadata_fpath = synthesizer_root.joinpath("train.txt")
245
+ assert wav_dir.exists() and metadata_fpath.exists()
246
+ embed_dir = synthesizer_root.joinpath("embeds")
247
+ embed_dir.mkdir(exist_ok=True)
248
+
249
+ # Gather the input wave filepath and the target output embed filepath
250
+ with metadata_fpath.open("r") as metadata_file:
251
+ metadata = [line.split("|") for line in metadata_file]
252
+ fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
253
+
254
+ # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
255
+ # Embed the utterances in separate threads
256
+ func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
257
+ job = Pool(n_processes).imap(func, fpaths)
258
+ list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
259
+