sob111 committed on
Commit bfd1232 · verified · 1 Parent(s): d2aa88e

Create train_gpt_xtts.py

Files changed (1):
  1. train_gpt_xtts.py +177 -0
train_gpt_xtts.py ADDED
@@ -0,0 +1,177 @@
import os

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.utils.manage import ModelManager

# Logging parameters
RUN_NAME = "GPT_XTTS_v2.0_LJSpeech_FT"
PROJECT_NAME = "XTTS_trainer"
DASHBOARD_LOGGER = "tensorboard"
LOGGER_URI = None

# Set here the path where the checkpoints will be saved. Default: ./run/training/
OUT_PATH = "/tmp/output_model/run/training"


# Training parameters
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True  # for multi-GPU training, set this to False
START_WITH_EVAL = True  # if True, training starts with an evaluation pass
BATCH_SIZE = 3  # set the batch size here
GRAD_ACUMM_STEPS = 84  # set the gradient accumulation steps here
# Note: we recommend that BATCH_SIZE * GRAD_ACUMM_STEPS be at least 252 for more efficient training. You can increase/decrease BATCH_SIZE, but then adjust GRAD_ACUMM_STEPS accordingly.
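# Worked example: with BATCH_SIZE = 3 and GRAD_ACUMM_STEPS = 84, the effective
# batch size is 3 * 84 = 252 samples per optimizer step, which meets the
# recommended minimum above.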

# Define here the dataset that you want to use for fine-tuning.
config_dataset = BaseDatasetConfig(
    formatter="ljspeech",
    dataset_name="ljspeech",
    path="/raid/datasets/LJSpeech-1.1_24khz/",
    meta_file_train="/raid/datasets/LJSpeech-1.1_24khz/metadata.csv",
    language="en",
)
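
# A hypothetical example of pointing the recipe at your own corpus instead,
# assuming /data/my_voice/ follows the same LJSpeech metadata.csv layout:
# config_dataset = BaseDatasetConfig(
#     formatter="ljspeech",
#     dataset_name="my_voice",
#     path="/data/my_voice/",
#     meta_file_train="/data/my_voice/metadata.csv",
#     language="en",
# )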

# Add here the configs of the datasets
DATASETS_CONFIG_LIST = [config_dataset]

# Define the path where the XTTS v2.0.1 files will be downloaded
CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)


# DVAE files
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"

# Set the path to the downloaded files
DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK))
MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK))

# download DVAE files if needed
if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
    print(" > Downloading DVAE files!")
    ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)

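
# NOTE: ModelManager._download_model_files (used above and below) is a private
# helper (note the leading underscore); it works with the TTS version this
# recipe targets, but it is not a stable public API.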

# Download XTTS v2.0 checkpoint if needed
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"

# XTTS transfer learning parameters: provide the paths to the XTTS model checkpoint that you want to fine-tune.
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK))  # vocab.json file
XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK))  # model.pth file

# download XTTS v2.0 files if needed
if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
    print(" > Downloading XTTS v2.0 files!")
    ModelManager._download_model_files(
        [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
    )


# Speaker reference for training test sentences
SPEAKER_REFERENCE = [
    "./tests/data/ljspeech/wavs/LJ001-0002.wav"  # speaker reference to be used in training test sentences
]
LANGUAGE = config_dataset.language
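# NOTE: the reference wav above assumes this script is run from the Coqui TTS
# repository root; otherwise, point SPEAKER_REFERENCE at any short, clean
# recording of the target speaker.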


def main():
    # init args and config
    model_args = GPTArgs(
        max_conditioning_length=132300,  # 6 secs
        min_conditioning_length=66150,  # 3 secs
        debug_loading_failures=False,
        max_wav_length=255995,  # ~11.6 seconds
        max_text_length=200,
        mel_norm_file=MEL_NORM_FILE,
        dvae_checkpoint=DVAE_CHECKPOINT,
        xtts_checkpoint=XTTS_CHECKPOINT,  # checkpoint path of the model that you want to fine-tune
        tokenizer_file=TOKENIZER_FILE,
        gpt_num_audio_tokens=1026,
        gpt_start_audio_token=1024,
        gpt_stop_audio_token=1025,
        gpt_use_masking_gt_prompt_approach=True,
        gpt_use_perceiver_resampler=True,
    )
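    # The length limits above are expressed in audio samples at the 22050 Hz
    # training sample rate: 132300 / 22050 = 6 s, 66150 / 22050 = 3 s, and
    # 255995 / 22050 ≈ 11.6 s.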
    # define audio config
    audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
    # training parameters config
    config = GPTTrainerConfig(
        output_path=OUT_PATH,
        model_args=model_args,
        run_name=RUN_NAME,
        project_name=PROJECT_NAME,
        run_description="""
            GPT XTTS training
            """,
        dashboard_logger=DASHBOARD_LOGGER,
        logger_uri=LOGGER_URI,
        audio=audio_config,
        batch_size=BATCH_SIZE,
        batch_group_size=48,
        eval_batch_size=BATCH_SIZE,
        num_loader_workers=8,
        eval_split_max_size=256,
        print_step=50,
        plot_step=100,
        log_model_step=1000,
        save_step=10000,
        save_n_checkpoints=1,
        save_checkpoints=True,
        # target_loss="loss",
        print_eval=False,
        # Optimizer values like Tortoise; PyTorch AdamW implementation, modified so that weight decay is not applied to non-weight parameters.
        optimizer="AdamW",
        optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
        optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
        lr=5e-06,  # learning rate
        lr_scheduler="MultiStepLR",
        # it was adjusted accordingly for the new step scheme
        lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
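        # With gamma=0.5, the learning rate is halved at each milestone:
        # global steps 900000 (50000 * 18), 2700000, and 5400000.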
        test_sentences=[
            {
                "text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "speaker_wav": SPEAKER_REFERENCE,
                "language": LANGUAGE,
            },
            {
                "text": "This cake is great. It's so delicious and moist.",
                "speaker_wav": SPEAKER_REFERENCE,
                "language": LANGUAGE,
            },
        ],
    )

    # init the model from config
    model = GPTTrainer.init_from_config(config)

    # load training samples
    train_samples, eval_samples = load_tts_samples(
        DATASETS_CONFIG_LIST,
        eval_split=True,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )

    # init the trainer and 🚀
    trainer = Trainer(
        TrainerArgs(
            restore_path=None,  # the XTTS checkpoint is restored via the xtts_checkpoint key above, so there is no need to restore it with the Trainer's restore_path parameter
            skip_train_epoch=False,
            start_with_eval=START_WITH_EVAL,
            grad_accum_steps=GRAD_ACUMM_STEPS,
        ),
        config,
        output_path=OUT_PATH,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
    )
    trainer.fit()


if __name__ == "__main__":
    main()
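

# How to launch (a sketch, assuming the coqui-tts package and its `trainer`
# dependency are installed): single-GPU training can typically be started with
#   CUDA_VISIBLE_DEVICES=0 python train_gpt_xtts.py
# For multi-GPU runs, set OPTIMIZER_WD_ONLY_ON_WEIGHTS = False above and use
# the trainer's distribute helper:
#   python -m trainer.distribute --script train_gpt_xtts.py --gpus "0,1"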