mrfakename commited on
Commit
d37849f
·
verified ·
1 Parent(s): a7a80be

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

.github/workflows/pre-commit.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: pre-commit
2
+
3
+ on:
4
+ pull_request:
5
+ push:
6
+ branches: [main]
7
+
8
+ jobs:
9
+ pre-commit:
10
+ runs-on: ubuntu-latest
11
+ steps:
12
+ - uses: actions/checkout@v3
13
+ - uses: actions/setup-python@v3
14
+ - uses: pre-commit/[email protected]
.pre-commit-config.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/astral-sh/ruff-pre-commit
3
+ # Ruff version.
4
+ rev: v0.7.0
5
+ hooks:
6
+ # Run the linter.
7
+ - id: ruff
8
+ args: [--fix]
9
+ # Run the formatter.
10
+ - id: ruff-format
11
+ - repo: https://github.com/pre-commit/pre-commit-hooks
12
+ rev: v2.3.0
13
+ hooks:
14
+ - id: check-yaml
README_REPO.md CHANGED
@@ -43,6 +43,26 @@ pip install -r requirements.txt
43
  docker build -t f5tts:v1 .
44
  ```
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  ## Prepare Dataset
47
 
48
  Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `model/dataset.py`.
 
43
  docker build -t f5tts:v1 .
44
  ```
45
 
46
+ ### Development
47
+
48
+ When making a pull request, please use pre-commit to ensure code quality:
49
+
50
+ ```bash
51
+ pip install pre-commit
52
+ pre-commit install
53
+ ```
54
+
55
+ This will run linters and formatters automatically before each commit.
56
+
57
+ Manually run using:
58
+
59
+ ```bash
60
+ pre-commit run --all-files
61
+ ```
62
+
63
+ Note: Some model components have linting exceptions for E722 to accommodate tensor notation
64
+
65
+
66
  ## Prepare Dataset
67
 
68
  Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `model/dataset.py`.
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import re
2
  import tempfile
3
 
@@ -11,16 +14,19 @@ from pydub import AudioSegment
11
 
12
  try:
13
  import spaces
 
14
  USING_SPACES = True
15
  except ImportError:
16
  USING_SPACES = False
17
 
 
18
  def gpu_decorator(func):
19
  if USING_SPACES:
20
  return spaces.GPU(func)
21
  else:
22
  return func
23
 
 
24
  from model import DiT, UNetT
25
  from model.utils import (
26
  save_spectrogram,
@@ -38,15 +44,18 @@ vocos = load_vocoder()
38
 
39
  # load models
40
  F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
41
- F5TTS_ema_model = load_model(DiT, F5TTS_model_cfg, str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")))
 
 
42
 
43
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
44
- E2TTS_ema_model = load_model(UNetT, E2TTS_model_cfg, str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")))
 
 
45
 
46
 
47
  @gpu_decorator
48
  def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1):
49
-
50
  ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=gr.Info)
51
 
52
  if model == "F5-TTS":
@@ -54,7 +63,16 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_
54
  elif model == "E2-TTS":
55
  ema_model = E2TTS_ema_model
56
 
57
- final_wave, final_sample_rate, combined_spectrogram = infer_process(ref_audio, ref_text, gen_text, ema_model, cross_fade_duration=cross_fade_duration, speed=speed, show_info=gr.Info, progress=gr.Progress())
 
 
 
 
 
 
 
 
 
58
 
59
  # Remove silence
60
  if remove_silence:
@@ -73,17 +91,19 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_
73
 
74
 
75
  @gpu_decorator
76
- def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, model, remove_silence):
 
 
77
  # Split the script into speaker blocks
78
  speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
79
  speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
80
-
81
  generated_audio_segments = []
82
-
83
  for i in range(0, len(speaker_blocks), 2):
84
  speaker = speaker_blocks[i]
85
- text = speaker_blocks[i+1].strip()
86
-
87
  # Determine which speaker is talking
88
  if speaker == speaker1_name:
89
  ref_audio = ref_audio1
@@ -93,51 +113,52 @@ def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name
93
  ref_text = ref_text2
94
  else:
95
  continue # Skip if the speaker is neither speaker1 nor speaker2
96
-
97
  # Generate audio for this block
98
  audio, _ = infer(ref_audio, ref_text, text, model, remove_silence)
99
-
100
  # Convert the generated audio to a numpy array
101
  sr, audio_data = audio
102
-
103
  # Save the audio data as a WAV file
104
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
105
  sf.write(temp_file.name, audio_data, sr)
106
  audio_segment = AudioSegment.from_wav(temp_file.name)
107
-
108
  generated_audio_segments.append(audio_segment)
109
-
110
  # Add a short pause between speakers
111
  pause = AudioSegment.silent(duration=500) # 500ms pause
112
  generated_audio_segments.append(pause)
113
-
114
  # Concatenate all audio segments
115
  final_podcast = sum(generated_audio_segments)
116
-
117
  # Export the final podcast
118
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
119
  podcast_path = temp_file.name
120
  final_podcast.export(podcast_path, format="wav")
121
-
122
  return podcast_path
123
 
 
124
  def parse_speechtypes_text(gen_text):
125
  # Pattern to find (Emotion)
126
- pattern = r'\((.*?)\)'
127
 
128
  # Split the text by the pattern
129
  tokens = re.split(pattern, gen_text)
130
 
131
  segments = []
132
 
133
- current_emotion = 'Regular'
134
 
135
  for i in range(len(tokens)):
136
  if i % 2 == 0:
137
  # This is text
138
  text = tokens[i].strip()
139
  if text:
140
- segments.append({'emotion': current_emotion, 'text': text})
141
  else:
142
  # This is emotion
143
  emotion = tokens[i].strip()
@@ -158,9 +179,7 @@ with gr.Blocks() as app_tts:
158
  gr.Markdown("# Batched TTS")
159
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
160
  gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
161
- model_choice = gr.Radio(
162
- choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
163
- )
164
  generate_btn = gr.Button("Synthesize", variant="primary")
165
  with gr.Accordion("Advanced Settings", open=False):
166
  ref_text_input = gr.Textbox(
@@ -206,23 +225,24 @@ with gr.Blocks() as app_tts:
206
  ],
207
  outputs=[audio_output, spectrogram_output],
208
  )
209
-
210
  with gr.Blocks() as app_podcast:
211
  gr.Markdown("# Podcast Generation")
212
  speaker1_name = gr.Textbox(label="Speaker 1 Name")
213
  ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
214
  ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
215
-
216
  speaker2_name = gr.Textbox(label="Speaker 2 Name")
217
  ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
218
  ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
219
-
220
- script_input = gr.Textbox(label="Podcast Script", lines=10,
221
- placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...")
222
-
223
- podcast_model_choice = gr.Radio(
224
- choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
225
  )
 
 
226
  podcast_remove_silence = gr.Checkbox(
227
  label="Remove Silences",
228
  value=True,
@@ -230,8 +250,12 @@ with gr.Blocks() as app_podcast:
230
  generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
231
  podcast_output = gr.Audio(label="Generated Podcast")
232
 
233
- def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence):
234
- return generate_podcast(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence)
 
 
 
 
235
 
236
  generate_podcast_btn.click(
237
  podcast_generation,
@@ -249,23 +273,24 @@ with gr.Blocks() as app_podcast:
249
  outputs=podcast_output,
250
  )
251
 
 
252
  def parse_emotional_text(gen_text):
253
  # Pattern to find (Emotion)
254
- pattern = r'\((.*?)\)'
255
 
256
  # Split the text by the pattern
257
  tokens = re.split(pattern, gen_text)
258
 
259
  segments = []
260
 
261
- current_emotion = 'Regular'
262
 
263
  for i in range(len(tokens)):
264
  if i % 2 == 0:
265
  # This is text
266
  text = tokens[i].strip()
267
  if text:
268
- segments.append({'emotion': current_emotion, 'text': text})
269
  else:
270
  # This is emotion
271
  emotion = tokens[i].strip()
@@ -273,6 +298,7 @@ def parse_emotional_text(gen_text):
273
 
274
  return segments
275
 
 
276
  with gr.Blocks() as app_emotional:
277
  # New section for emotional generation
278
  gr.Markdown(
@@ -287,13 +313,15 @@ with gr.Blocks() as app_emotional:
287
  """
288
  )
289
 
290
- gr.Markdown("Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button.")
 
 
291
 
292
  # Regular speech type (mandatory)
293
  with gr.Row():
294
- regular_name = gr.Textbox(value='Regular', label='Speech Type Name', interactive=False)
295
- regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
296
- regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)
297
 
298
  # Additional speech types (up to 99 more)
299
  max_speech_types = 100
@@ -304,9 +332,9 @@ with gr.Blocks() as app_emotional:
304
 
305
  for i in range(max_speech_types - 1):
306
  with gr.Row():
307
- name_input = gr.Textbox(label='Speech Type Name', visible=False)
308
- audio_input = gr.Audio(label='Reference Audio', type='filepath', visible=False)
309
- ref_text_input = gr.Textbox(label='Reference Text', lines=2, visible=False)
310
  delete_btn = gr.Button("Delete", variant="secondary", visible=False)
311
  speech_type_names.append(name_input)
312
  speech_type_audios.append(audio_input)
@@ -351,7 +379,11 @@ with gr.Blocks() as app_emotional:
351
  add_speech_type_btn.click(
352
  add_speech_type_fn,
353
  inputs=speech_type_count,
354
- outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
 
 
 
 
355
  )
356
 
357
  # Function to delete a speech type
@@ -365,9 +397,9 @@ with gr.Blocks() as app_emotional:
365
 
366
  for i in range(max_speech_types - 1):
367
  if i == index:
368
- name_updates.append(gr.update(visible=False, value=''))
369
  audio_updates.append(gr.update(visible=False, value=None))
370
- ref_text_updates.append(gr.update(visible=False, value=''))
371
  delete_btn_updates.append(gr.update(visible=False))
372
  else:
373
  name_updates.append(gr.update())
@@ -386,16 +418,18 @@ with gr.Blocks() as app_emotional:
386
  delete_btn.click(
387
  delete_fn,
388
  inputs=speech_type_count,
389
- outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
 
 
 
 
390
  )
391
 
392
  # Text input for the prompt
393
  gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
394
 
395
  # Model choice
396
- model_choice_emotional = gr.Radio(
397
- choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
398
- )
399
 
400
  with gr.Accordion("Advanced Settings", open=False):
401
  remove_silence_emotional = gr.Checkbox(
@@ -408,6 +442,7 @@ with gr.Blocks() as app_emotional:
408
 
409
  # Output audio
410
  audio_output_emotional = gr.Audio(label="Synthesized Audio")
 
411
  @gpu_decorator
412
  def generate_emotional_speech(
413
  regular_audio,
@@ -417,37 +452,39 @@ with gr.Blocks() as app_emotional:
417
  ):
418
  num_additional_speech_types = max_speech_types - 1
419
  speech_type_names_list = args[:num_additional_speech_types]
420
- speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types]
421
- speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types]
422
  model_choice = args[3 * num_additional_speech_types]
423
  remove_silence = args[3 * num_additional_speech_types + 1]
424
 
425
  # Collect the speech types and their audios into a dict
426
- speech_types = {'Regular': {'audio': regular_audio, 'ref_text': regular_ref_text}}
427
 
428
- for name_input, audio_input, ref_text_input in zip(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list):
 
 
429
  if name_input and audio_input:
430
- speech_types[name_input] = {'audio': audio_input, 'ref_text': ref_text_input}
431
 
432
  # Parse the gen_text into segments
433
  segments = parse_speechtypes_text(gen_text)
434
 
435
  # For each segment, generate speech
436
  generated_audio_segments = []
437
- current_emotion = 'Regular'
438
 
439
  for segment in segments:
440
- emotion = segment['emotion']
441
- text = segment['text']
442
 
443
  if emotion in speech_types:
444
  current_emotion = emotion
445
  else:
446
  # If emotion not available, default to Regular
447
- current_emotion = 'Regular'
448
 
449
- ref_audio = speech_types[current_emotion]['audio']
450
- ref_text = speech_types[current_emotion].get('ref_text', '')
451
 
452
  # Generate speech for this segment
453
  audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0)
@@ -469,7 +506,11 @@ with gr.Blocks() as app_emotional:
469
  regular_audio,
470
  regular_ref_text,
471
  gen_text_input_emotional,
472
- ] + speech_type_names + speech_type_audios + speech_type_ref_texts + [
 
 
 
 
473
  model_choice_emotional,
474
  remove_silence_emotional,
475
  ],
@@ -477,11 +518,7 @@ with gr.Blocks() as app_emotional:
477
  )
478
 
479
  # Validation function to disable Generate button if speech types are missing
480
- def validate_speech_types(
481
- gen_text,
482
- regular_name,
483
- *args
484
- ):
485
  num_additional_speech_types = max_speech_types - 1
486
  speech_type_names_list = args[:num_additional_speech_types]
487
 
@@ -495,7 +532,7 @@ with gr.Blocks() as app_emotional:
495
 
496
  # Parse the gen_text to get the speech types used
497
  segments = parse_emotional_text(gen_text)
498
- speech_types_in_text = set(segment['emotion'] for segment in segments)
499
 
500
  # Check if all speech types in text are available
501
  missing_speech_types = speech_types_in_text - speech_types_available
@@ -510,7 +547,7 @@ with gr.Blocks() as app_emotional:
510
  gen_text_input_emotional.change(
511
  validate_speech_types,
512
  inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
513
- outputs=generate_emotional_btn
514
  )
515
  with gr.Blocks() as app:
516
  gr.Markdown(
@@ -531,6 +568,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
531
  )
532
  gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
533
 
 
534
  @click.command()
535
  @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
536
  @click.option("--host", "-H", default=None, help="Host to run the app on")
@@ -544,10 +582,8 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
544
  @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
545
  def main(port, host, share, api):
546
  global app
547
- print(f"Starting app...")
548
- app.queue(api_open=api).launch(
549
- server_name=host, server_port=port, share=share, show_api=api
550
- )
551
 
552
 
553
  if __name__ == "__main__":
 
1
+ # ruff: noqa: E402
2
+ # Above allows ruff to ignore E402: module level import not at top of file
3
+
4
  import re
5
  import tempfile
6
 
 
14
 
15
  try:
16
  import spaces
17
+
18
  USING_SPACES = True
19
  except ImportError:
20
  USING_SPACES = False
21
 
22
+
23
  def gpu_decorator(func):
24
  if USING_SPACES:
25
  return spaces.GPU(func)
26
  else:
27
  return func
28
 
29
+
30
  from model import DiT, UNetT
31
  from model.utils import (
32
  save_spectrogram,
 
44
 
45
  # load models
46
  F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
47
+ F5TTS_ema_model = load_model(
48
+ DiT, F5TTS_model_cfg, str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
49
+ )
50
 
51
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
52
+ E2TTS_ema_model = load_model(
53
+ UNetT, E2TTS_model_cfg, str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
54
+ )
55
 
56
 
57
  @gpu_decorator
58
  def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1):
 
59
  ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=gr.Info)
60
 
61
  if model == "F5-TTS":
 
63
  elif model == "E2-TTS":
64
  ema_model = E2TTS_ema_model
65
 
66
+ final_wave, final_sample_rate, combined_spectrogram = infer_process(
67
+ ref_audio,
68
+ ref_text,
69
+ gen_text,
70
+ ema_model,
71
+ cross_fade_duration=cross_fade_duration,
72
+ speed=speed,
73
+ show_info=gr.Info,
74
+ progress=gr.Progress(),
75
+ )
76
 
77
  # Remove silence
78
  if remove_silence:
 
91
 
92
 
93
  @gpu_decorator
94
+ def generate_podcast(
95
+ script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, model, remove_silence
96
+ ):
97
  # Split the script into speaker blocks
98
  speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
99
  speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
100
+
101
  generated_audio_segments = []
102
+
103
  for i in range(0, len(speaker_blocks), 2):
104
  speaker = speaker_blocks[i]
105
+ text = speaker_blocks[i + 1].strip()
106
+
107
  # Determine which speaker is talking
108
  if speaker == speaker1_name:
109
  ref_audio = ref_audio1
 
113
  ref_text = ref_text2
114
  else:
115
  continue # Skip if the speaker is neither speaker1 nor speaker2
116
+
117
  # Generate audio for this block
118
  audio, _ = infer(ref_audio, ref_text, text, model, remove_silence)
119
+
120
  # Convert the generated audio to a numpy array
121
  sr, audio_data = audio
122
+
123
  # Save the audio data as a WAV file
124
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
125
  sf.write(temp_file.name, audio_data, sr)
126
  audio_segment = AudioSegment.from_wav(temp_file.name)
127
+
128
  generated_audio_segments.append(audio_segment)
129
+
130
  # Add a short pause between speakers
131
  pause = AudioSegment.silent(duration=500) # 500ms pause
132
  generated_audio_segments.append(pause)
133
+
134
  # Concatenate all audio segments
135
  final_podcast = sum(generated_audio_segments)
136
+
137
  # Export the final podcast
138
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
139
  podcast_path = temp_file.name
140
  final_podcast.export(podcast_path, format="wav")
141
+
142
  return podcast_path
143
 
144
+
145
  def parse_speechtypes_text(gen_text):
146
  # Pattern to find (Emotion)
147
+ pattern = r"\((.*?)\)"
148
 
149
  # Split the text by the pattern
150
  tokens = re.split(pattern, gen_text)
151
 
152
  segments = []
153
 
154
+ current_emotion = "Regular"
155
 
156
  for i in range(len(tokens)):
157
  if i % 2 == 0:
158
  # This is text
159
  text = tokens[i].strip()
160
  if text:
161
+ segments.append({"emotion": current_emotion, "text": text})
162
  else:
163
  # This is emotion
164
  emotion = tokens[i].strip()
 
179
  gr.Markdown("# Batched TTS")
180
  ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
181
  gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
182
+ model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
 
 
183
  generate_btn = gr.Button("Synthesize", variant="primary")
184
  with gr.Accordion("Advanced Settings", open=False):
185
  ref_text_input = gr.Textbox(
 
225
  ],
226
  outputs=[audio_output, spectrogram_output],
227
  )
228
+
229
  with gr.Blocks() as app_podcast:
230
  gr.Markdown("# Podcast Generation")
231
  speaker1_name = gr.Textbox(label="Speaker 1 Name")
232
  ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
233
  ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
234
+
235
  speaker2_name = gr.Textbox(label="Speaker 2 Name")
236
  ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
237
  ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
238
+
239
+ script_input = gr.Textbox(
240
+ label="Podcast Script",
241
+ lines=10,
242
+ placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...",
 
243
  )
244
+
245
+ podcast_model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
246
  podcast_remove_silence = gr.Checkbox(
247
  label="Remove Silences",
248
  value=True,
 
250
  generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
251
  podcast_output = gr.Audio(label="Generated Podcast")
252
 
253
+ def podcast_generation(
254
+ script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence
255
+ ):
256
+ return generate_podcast(
257
+ script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence
258
+ )
259
 
260
  generate_podcast_btn.click(
261
  podcast_generation,
 
273
  outputs=podcast_output,
274
  )
275
 
276
+
277
  def parse_emotional_text(gen_text):
278
  # Pattern to find (Emotion)
279
+ pattern = r"\((.*?)\)"
280
 
281
  # Split the text by the pattern
282
  tokens = re.split(pattern, gen_text)
283
 
284
  segments = []
285
 
286
+ current_emotion = "Regular"
287
 
288
  for i in range(len(tokens)):
289
  if i % 2 == 0:
290
  # This is text
291
  text = tokens[i].strip()
292
  if text:
293
+ segments.append({"emotion": current_emotion, "text": text})
294
  else:
295
  # This is emotion
296
  emotion = tokens[i].strip()
 
298
 
299
  return segments
300
 
301
+
302
  with gr.Blocks() as app_emotional:
303
  # New section for emotional generation
304
  gr.Markdown(
 
313
  """
314
  )
315
 
316
+ gr.Markdown(
317
+ "Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
318
+ )
319
 
320
  # Regular speech type (mandatory)
321
  with gr.Row():
322
+ regular_name = gr.Textbox(value="Regular", label="Speech Type Name", interactive=False)
323
+ regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
324
+ regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)
325
 
326
  # Additional speech types (up to 99 more)
327
  max_speech_types = 100
 
332
 
333
  for i in range(max_speech_types - 1):
334
  with gr.Row():
335
+ name_input = gr.Textbox(label="Speech Type Name", visible=False)
336
+ audio_input = gr.Audio(label="Reference Audio", type="filepath", visible=False)
337
+ ref_text_input = gr.Textbox(label="Reference Text", lines=2, visible=False)
338
  delete_btn = gr.Button("Delete", variant="secondary", visible=False)
339
  speech_type_names.append(name_input)
340
  speech_type_audios.append(audio_input)
 
379
  add_speech_type_btn.click(
380
  add_speech_type_fn,
381
  inputs=speech_type_count,
382
+ outputs=[speech_type_count]
383
+ + speech_type_names
384
+ + speech_type_audios
385
+ + speech_type_ref_texts
386
+ + speech_type_delete_btns,
387
  )
388
 
389
  # Function to delete a speech type
 
397
 
398
  for i in range(max_speech_types - 1):
399
  if i == index:
400
+ name_updates.append(gr.update(visible=False, value=""))
401
  audio_updates.append(gr.update(visible=False, value=None))
402
+ ref_text_updates.append(gr.update(visible=False, value=""))
403
  delete_btn_updates.append(gr.update(visible=False))
404
  else:
405
  name_updates.append(gr.update())
 
418
  delete_btn.click(
419
  delete_fn,
420
  inputs=speech_type_count,
421
+ outputs=[speech_type_count]
422
+ + speech_type_names
423
+ + speech_type_audios
424
+ + speech_type_ref_texts
425
+ + speech_type_delete_btns,
426
  )
427
 
428
  # Text input for the prompt
429
  gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
430
 
431
  # Model choice
432
+ model_choice_emotional = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
 
 
433
 
434
  with gr.Accordion("Advanced Settings", open=False):
435
  remove_silence_emotional = gr.Checkbox(
 
442
 
443
  # Output audio
444
  audio_output_emotional = gr.Audio(label="Synthesized Audio")
445
+
446
  @gpu_decorator
447
  def generate_emotional_speech(
448
  regular_audio,
 
452
  ):
453
  num_additional_speech_types = max_speech_types - 1
454
  speech_type_names_list = args[:num_additional_speech_types]
455
+ speech_type_audios_list = args[num_additional_speech_types : 2 * num_additional_speech_types]
456
+ speech_type_ref_texts_list = args[2 * num_additional_speech_types : 3 * num_additional_speech_types]
457
  model_choice = args[3 * num_additional_speech_types]
458
  remove_silence = args[3 * num_additional_speech_types + 1]
459
 
460
  # Collect the speech types and their audios into a dict
461
+ speech_types = {"Regular": {"audio": regular_audio, "ref_text": regular_ref_text}}
462
 
463
+ for name_input, audio_input, ref_text_input in zip(
464
+ speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
465
+ ):
466
  if name_input and audio_input:
467
+ speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input}
468
 
469
  # Parse the gen_text into segments
470
  segments = parse_speechtypes_text(gen_text)
471
 
472
  # For each segment, generate speech
473
  generated_audio_segments = []
474
+ current_emotion = "Regular"
475
 
476
  for segment in segments:
477
+ emotion = segment["emotion"]
478
+ text = segment["text"]
479
 
480
  if emotion in speech_types:
481
  current_emotion = emotion
482
  else:
483
  # If emotion not available, default to Regular
484
+ current_emotion = "Regular"
485
 
486
+ ref_audio = speech_types[current_emotion]["audio"]
487
+ ref_text = speech_types[current_emotion].get("ref_text", "")
488
 
489
  # Generate speech for this segment
490
  audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0)
 
506
  regular_audio,
507
  regular_ref_text,
508
  gen_text_input_emotional,
509
+ ]
510
+ + speech_type_names
511
+ + speech_type_audios
512
+ + speech_type_ref_texts
513
+ + [
514
  model_choice_emotional,
515
  remove_silence_emotional,
516
  ],
 
518
  )
519
 
520
  # Validation function to disable Generate button if speech types are missing
521
+ def validate_speech_types(gen_text, regular_name, *args):
 
 
 
 
522
  num_additional_speech_types = max_speech_types - 1
523
  speech_type_names_list = args[:num_additional_speech_types]
524
 
 
532
 
533
  # Parse the gen_text to get the speech types used
534
  segments = parse_emotional_text(gen_text)
535
+ speech_types_in_text = set(segment["emotion"] for segment in segments)
536
 
537
  # Check if all speech types in text are available
538
  missing_speech_types = speech_types_in_text - speech_types_available
 
547
  gen_text_input_emotional.change(
548
  validate_speech_types,
549
  inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
550
+ outputs=generate_emotional_btn,
551
  )
552
  with gr.Blocks() as app:
553
  gr.Markdown(
 
568
  )
569
  gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
570
 
571
+
572
  @click.command()
573
  @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
574
  @click.option("--host", "-H", default=None, help="Host to run the app on")
 
582
  @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
583
  def main(port, host, share, api):
584
  global app
585
+ print("Starting app...")
586
+ app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
 
 
587
 
588
 
589
  if __name__ == "__main__":
finetune-cli.py CHANGED
@@ -1,42 +1,57 @@
1
  import argparse
2
- from model import CFM, UNetT, DiT, MMDiT, Trainer
3
  from model.utils import get_tokenizer
4
  from model.dataset import load_dataset
5
  from cached_path import cached_path
6
- import shutil,os
 
 
7
  # -------------------------- Dataset Settings --------------------------- #
8
  target_sample_rate = 24000
9
  n_mel_channels = 100
10
  hop_length = 256
11
 
12
- tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
13
- tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
14
 
15
  # -------------------------- Argument Parsing --------------------------- #
16
  def parse_args():
17
- parser = argparse.ArgumentParser(description='Train CFM Model')
18
-
19
- parser.add_argument('--exp_name', type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"],help='Experiment name')
20
- parser.add_argument('--dataset_name', type=str, default="Emilia_ZH_EN", help='Name of the dataset to use')
21
- parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate for training')
22
- parser.add_argument('--batch_size_per_gpu', type=int, default=256, help='Batch size per GPU')
23
- parser.add_argument('--batch_size_type', type=str, default="frame", choices=["frame", "sample"],help='Batch size type')
24
- parser.add_argument('--max_samples', type=int, default=16, help='Max sequences per batch')
25
- parser.add_argument('--grad_accumulation_steps', type=int, default=1,help='Gradient accumulation steps')
26
- parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
27
- parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
28
- parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
29
- parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
30
- parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
31
- parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune')
32
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  return parser.parse_args()
34
 
 
35
  # -------------------------- Training Settings -------------------------- #
36
 
 
37
  def main():
38
  args = parse_args()
39
-
40
 
41
  # Model parameters based on experiment name
42
  if args.exp_name == "F5TTS_Base":
@@ -44,24 +59,31 @@ def main():
44
  model_cls = DiT
45
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
46
  if args.finetune:
47
- ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
48
  elif args.exp_name == "E2TTS_Base":
49
  wandb_resume_id = None
50
  model_cls = UNetT
51
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
52
  if args.finetune:
53
- ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
54
-
55
  if args.finetune:
56
- path_ckpt = os.path.join("ckpts",args.dataset_name)
57
- if os.path.isdir(path_ckpt)==False:
58
- os.makedirs(path_ckpt,exist_ok=True)
59
- shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
60
-
61
- checkpoint_path=os.path.join("ckpts",args.dataset_name)
62
-
63
- # Use the dataset_name provided in the command line
64
- tokenizer_path = args.dataset_name if tokenizer != "custom" else tokenizer_path
 
 
 
 
 
 
 
65
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
66
 
67
  mel_spec_kwargs = dict(
@@ -71,11 +93,7 @@ def main():
71
  )
72
 
73
  e2tts = CFM(
74
- transformer=model_cls(
75
- **model_cfg,
76
- text_num_embeds=vocab_size,
77
- mel_dim=n_mel_channels
78
- ),
79
  mel_spec_kwargs=mel_spec_kwargs,
80
  vocab_char_map=vocab_char_map,
81
  )
@@ -99,10 +117,11 @@ def main():
99
  )
100
 
101
  train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
102
- trainer.train(train_dataset,
103
- resumable_with_seed=666 # seed for shuffling dataset
104
- )
 
105
 
106
 
107
- if __name__ == '__main__':
108
  main()
 
1
  import argparse
2
+ from model import CFM, UNetT, DiT, Trainer
3
  from model.utils import get_tokenizer
4
  from model.dataset import load_dataset
5
  from cached_path import cached_path
6
+ import shutil
7
+ import os
8
+
9
  # -------------------------- Dataset Settings --------------------------- #
10
  target_sample_rate = 24000
11
  n_mel_channels = 100
12
  hop_length = 256
13
 
 
 
14
 
15
  # -------------------------- Argument Parsing --------------------------- #
16
  def parse_args():
17
+ parser = argparse.ArgumentParser(description="Train CFM Model")
18
+
19
+ parser.add_argument(
20
+ "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
21
+ )
22
+ parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
23
+ parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for training")
24
+ parser.add_argument("--batch_size_per_gpu", type=int, default=256, help="Batch size per GPU")
25
+ parser.add_argument(
26
+ "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
27
+ )
28
+ parser.add_argument("--max_samples", type=int, default=16, help="Max sequences per batch")
29
+ parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
30
+ parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
31
+ parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
32
+ parser.add_argument("--num_warmup_updates", type=int, default=5, help="Warmup steps")
33
+ parser.add_argument("--save_per_updates", type=int, default=10, help="Save checkpoint every X steps")
34
+ parser.add_argument("--last_per_steps", type=int, default=10, help="Save last checkpoint every X steps")
35
+ parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
36
+
37
+ parser.add_argument(
38
+ "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
39
+ )
40
+ parser.add_argument(
41
+ "--tokenizer_path",
42
+ type=str,
43
+ default=None,
44
+ help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
45
+ )
46
+
47
  return parser.parse_args()
48
 
49
+
50
  # -------------------------- Training Settings -------------------------- #
51
 
52
+
53
  def main():
54
  args = parse_args()
 
55
 
56
  # Model parameters based on experiment name
57
  if args.exp_name == "F5TTS_Base":
 
59
  model_cls = DiT
60
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
61
  if args.finetune:
62
+ ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
63
  elif args.exp_name == "E2TTS_Base":
64
  wandb_resume_id = None
65
  model_cls = UNetT
66
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
67
  if args.finetune:
68
+ ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
69
+
70
  if args.finetune:
71
+ path_ckpt = os.path.join("ckpts", args.dataset_name)
72
+ if not os.path.isdir(path_ckpt):
73
+ os.makedirs(path_ckpt, exist_ok=True)
74
+ shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))
75
+
76
+ checkpoint_path = os.path.join("ckpts", args.dataset_name)
77
+
78
+ # Use the tokenizer and tokenizer_path provided in the command line arguments
79
+ tokenizer = args.tokenizer
80
+ if tokenizer == "custom":
81
+ if not args.tokenizer_path:
82
+ raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
83
+ tokenizer_path = args.tokenizer_path
84
+ else:
85
+ tokenizer_path = args.dataset_name
86
+
87
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
88
 
89
  mel_spec_kwargs = dict(
 
93
  )
94
 
95
  e2tts = CFM(
96
+ transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
 
 
 
 
97
  mel_spec_kwargs=mel_spec_kwargs,
98
  vocab_char_map=vocab_char_map,
99
  )
 
117
  )
118
 
119
  train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
120
+ trainer.train(
121
+ train_dataset,
122
+ resumable_with_seed=666, # seed for shuffling dataset
123
+ )
124
 
125
 
126
+ if __name__ == "__main__":
127
  main()
finetune_gradio.py CHANGED
@@ -1,4 +1,5 @@
1
- import os,sys
 
2
 
3
  from transformers import pipeline
4
  import gradio as gr
@@ -20,34 +21,37 @@ import platform
20
  import subprocess
21
  from datasets.arrow_writer import ArrowWriter
22
 
23
- import json
24
 
25
- training_process = None
26
  system = platform.system()
27
  python_executable = sys.executable or "python"
28
 
29
- path_data="data"
30
 
31
- device = (
32
- "cuda"
33
- if torch.cuda.is_available()
34
- else "mps" if torch.backends.mps.is_available() else "cpu"
35
- )
36
 
37
  pipe = None
38
 
 
39
  # Load metadata
40
  def get_audio_duration(audio_path):
41
  """Calculate the duration of an audio file."""
42
  audio, sample_rate = torchaudio.load(audio_path)
43
- num_channels = audio.shape[0]
44
  return audio.shape[1] / (sample_rate * num_channels)
45
 
 
46
  def clear_text(text):
47
  """Clean and prepare text by lowering the case and stripping whitespace."""
48
  return text.lower().strip()
49
 
50
- def get_rms(y,frame_length=2048,hop_length=512,pad_mode="constant",): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
 
 
 
 
 
 
51
  padding = (int(frame_length // 2), int(frame_length // 2))
52
  y = np.pad(y, padding, mode=pad_mode)
53
 
@@ -74,7 +78,8 @@ def get_rms(y,frame_length=2048,hop_length=512,pad_mode="constant",): # https://
74
 
75
  return np.sqrt(power)
76
 
77
- class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
 
78
  def __init__(
79
  self,
80
  sr: int,
@@ -85,13 +90,9 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.p
85
  max_sil_kept: int = 2000,
86
  ):
87
  if not min_length >= min_interval >= hop_size:
88
- raise ValueError(
89
- "The following condition must be satisfied: min_length >= min_interval >= hop_size"
90
- )
91
  if not max_sil_kept >= hop_size:
92
- raise ValueError(
93
- "The following condition must be satisfied: max_sil_kept >= hop_size"
94
- )
95
  min_interval = sr * min_interval / 1000
96
  self.threshold = 10 ** (threshold / 20.0)
97
  self.hop_size = round(sr * hop_size / 1000)
@@ -102,13 +103,9 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.p
102
 
103
  def _apply_slice(self, waveform, begin, end):
104
  if len(waveform.shape) > 1:
105
- return waveform[
106
- :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)
107
- ]
108
  else:
109
- return waveform[
110
- begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)
111
- ]
112
 
113
  # @timeit
114
  def slice(self, waveform):
@@ -118,9 +115,7 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.p
118
  samples = waveform
119
  if samples.shape[0] <= self.min_length:
120
  return [waveform]
121
- rms_list = get_rms(
122
- y=samples, frame_length=self.win_size, hop_length=self.hop_size
123
- ).squeeze(0)
124
  sil_tags = []
125
  silence_start = None
126
  clip_start = 0
@@ -136,10 +131,7 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.p
136
  continue
137
  # Clear recorded silence start if interval is not enough or clip is too short
138
  is_leading_silence = silence_start == 0 and i > self.max_sil_kept
139
- need_slice_middle = (
140
- i - silence_start >= self.min_interval
141
- and i - clip_start >= self.min_length
142
- )
143
  if not is_leading_silence and not need_slice_middle:
144
  silence_start = None
145
  continue
@@ -152,21 +144,10 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.p
152
  sil_tags.append((pos, pos))
153
  clip_start = pos
154
  elif i - silence_start <= self.max_sil_kept * 2:
155
- pos = rms_list[
156
- i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
157
- ].argmin()
158
  pos += i - self.max_sil_kept
159
- pos_l = (
160
- rms_list[
161
- silence_start : silence_start + self.max_sil_kept + 1
162
- ].argmin()
163
- + silence_start
164
- )
165
- pos_r = (
166
- rms_list[i - self.max_sil_kept : i + 1].argmin()
167
- + i
168
- - self.max_sil_kept
169
- )
170
  if silence_start == 0:
171
  sil_tags.append((0, pos_r))
172
  clip_start = pos_r
@@ -174,17 +155,8 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.p
174
  sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
175
  clip_start = max(pos_r, pos)
176
  else:
177
- pos_l = (
178
- rms_list[
179
- silence_start : silence_start + self.max_sil_kept + 1
180
- ].argmin()
181
- + silence_start
182
- )
183
- pos_r = (
184
- rms_list[i - self.max_sil_kept : i + 1].argmin()
185
- + i
186
- - self.max_sil_kept
187
- )
188
  if silence_start == 0:
189
  sil_tags.append((0, pos_r))
190
  else:
@@ -193,33 +165,39 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.p
193
  silence_start = None
194
  # Deal with trailing silence.
195
  total_frames = rms_list.shape[0]
196
- if (
197
- silence_start is not None
198
- and total_frames - silence_start >= self.min_interval
199
- ):
200
  silence_end = min(total_frames, silence_start + self.max_sil_kept)
201
  pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
202
  sil_tags.append((pos, total_frames + 1))
203
  # Apply and return slices.
204
  ####音频+起始时间+终止时间
205
  if len(sil_tags) == 0:
206
- return [[waveform,0,int(total_frames*self.hop_size)]]
207
  else:
208
  chunks = []
209
  if sil_tags[0][0] > 0:
210
- chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]),0,int(sil_tags[0][0]*self.hop_size)])
211
  for i in range(len(sil_tags) - 1):
212
  chunks.append(
213
- [self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),int(sil_tags[i][1]*self.hop_size),int(sil_tags[i + 1][0]*self.hop_size)]
 
 
 
 
214
  )
215
  if sil_tags[-1][1] < total_frames:
216
  chunks.append(
217
- [self._apply_slice(waveform, sil_tags[-1][1], total_frames),int(sil_tags[-1][1]*self.hop_size),int(total_frames*self.hop_size)]
 
 
 
 
218
  )
219
  return chunks
220
 
221
- #terminal
222
- def terminate_process_tree(pid, including_parent=True):
 
223
  try:
224
  parent = psutil.Process(pid)
225
  except psutil.NoSuchProcess:
@@ -238,6 +216,7 @@ def terminate_process_tree(pid, including_parent=True):
238
  except OSError:
239
  pass
240
 
 
241
  def terminate_process(pid):
242
  if system == "Windows":
243
  cmd = f"taskkill /t /f /pid {pid}"
@@ -245,132 +224,154 @@ def terminate_process(pid):
245
  else:
246
  terminate_process_tree(pid)
247
 
248
- def start_training(dataset_name="",
249
- exp_name="F5TTS_Base",
250
- learning_rate=1e-4,
251
- batch_size_per_gpu=400,
252
- batch_size_type="frame",
253
- max_samples=64,
254
- grad_accumulation_steps=1,
255
- max_grad_norm=1.0,
256
- epochs=11,
257
- num_warmup_updates=200,
258
- save_per_updates=400,
259
- last_per_steps=800,
260
- finetune=True,
261
- ):
262
-
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  global training_process
265
 
266
  path_project = os.path.join(path_data, dataset_name + "_pinyin")
267
 
268
- if os.path.isdir(path_project)==False:
269
- yield f"There is not project with name {dataset_name}",gr.update(interactive=True),gr.update(interactive=False)
 
 
 
 
270
  return
271
 
272
- file_raw = os.path.join(path_project,"raw.arrow")
273
- if os.path.isfile(file_raw)==False:
274
- yield f"There is no file {file_raw}",gr.update(interactive=True),gr.update(interactive=False)
275
- return
276
 
277
  # Check if a training process is already running
278
  if training_process is not None:
279
- return "Train run already!",gr.update(interactive=False),gr.update(interactive=True)
280
 
281
- yield "start train",gr.update(interactive=False),gr.update(interactive=False)
282
 
283
  # Command to run the training script with the specified arguments
284
- cmd = f"accelerate launch finetune-cli.py --exp_name {exp_name} " \
285
- f"--learning_rate {learning_rate} " \
286
- f"--batch_size_per_gpu {batch_size_per_gpu} " \
287
- f"--batch_size_type {batch_size_type} " \
288
- f"--max_samples {max_samples} " \
289
- f"--grad_accumulation_steps {grad_accumulation_steps} " \
290
- f"--max_grad_norm {max_grad_norm} " \
291
- f"--epochs {epochs} " \
292
- f"--num_warmup_updates {num_warmup_updates} " \
293
- f"--save_per_updates {save_per_updates} " \
294
- f"--last_per_steps {last_per_steps} " \
295
- f"--dataset_name {dataset_name}"
296
- if finetune:cmd += f" --finetune {finetune}"
 
 
 
297
 
298
  print(cmd)
299
-
300
  try:
301
- # Start the training process
302
- training_process = subprocess.Popen(cmd, shell=True)
303
 
304
- time.sleep(5)
305
- yield "check terminal for wandb",gr.update(interactive=False),gr.update(interactive=True)
306
-
307
- # Wait for the training process to finish
308
- training_process.wait()
309
- time.sleep(1)
310
-
311
- if training_process is None:
312
- text_info = 'train stop'
313
- else:
314
- text_info = "train complete !"
315
 
316
  except Exception as e: # Catch all exceptions
317
  # Ensure that we reset the training process variable in case of an error
318
- text_info=f"An error occurred: {str(e)}"
319
-
320
- training_process=None
 
 
321
 
322
- yield text_info,gr.update(interactive=True),gr.update(interactive=False)
323
 
324
  def stop_training():
325
  global training_process
326
- if training_process is None:return f"Train not run !",gr.update(interactive=True),gr.update(interactive=False)
 
327
  terminate_process_tree(training_process.pid)
328
  training_process = None
329
- return 'train stop',gr.update(interactive=True),gr.update(interactive=False)
 
330
 
331
  def create_data_project(name):
332
- name+="_pinyin"
333
- os.makedirs(os.path.join(path_data,name),exist_ok=True)
334
- os.makedirs(os.path.join(path_data,name,"dataset"),exist_ok=True)
335
-
336
- def transcribe(file_audio,language="english"):
 
337
  global pipe
338
 
339
  if pipe is None:
340
- pipe = pipeline("automatic-speech-recognition",model="openai/whisper-large-v3-turbo", torch_dtype=torch.float16,device=device)
 
 
 
 
 
341
 
342
  text_transcribe = pipe(
343
  file_audio,
344
  chunk_length_s=30,
345
  batch_size=128,
346
- generate_kwargs={"task": "transcribe","language": language},
347
  return_timestamps=False,
348
  )["text"].strip()
349
  return text_transcribe
350
 
351
- def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Progress()):
352
- name_project+="_pinyin"
353
- path_project= os.path.join(path_data,name_project)
354
- path_dataset = os.path.join(path_project,"dataset")
355
- path_project_wavs = os.path.join(path_project,"wavs")
356
- file_metadata = os.path.join(path_project,"metadata.csv")
357
 
358
- if audio_files is None:return "You need to load an audio file."
 
 
 
 
 
 
 
 
359
 
360
  if os.path.isdir(path_project_wavs):
361
- shutil.rmtree(path_project_wavs)
362
 
363
  if os.path.isfile(file_metadata):
364
- os.remove(file_metadata)
 
 
365
 
366
- os.makedirs(path_project_wavs,exist_ok=True)
367
-
368
  if user:
369
- file_audios = [file for format in ('*.wav', '*.ogg', '*.opus', '*.mp3', '*.flac') for file in glob(os.path.join(path_dataset, format))]
370
- if file_audios==[]:return "No audio file was found in the dataset."
 
 
 
 
 
371
  else:
372
- file_audios = audio_files
373
-
374
 
375
  alpha = 0.5
376
  _max = 1.0
@@ -378,181 +379,202 @@ def transcribe_all(name_project,audio_files,language,user=False,progress=gr.Prog
378
 
379
  num = 0
380
  error_num = 0
381
- data=""
382
- for file_audio in progress.tqdm(file_audios, desc="transcribe files",total=len((file_audios))):
383
-
384
- audio, _ = librosa.load(file_audio, sr=24000, mono=True)
385
-
386
- list_slicer=slicer.slice(audio)
387
- for chunk, start, end in progress.tqdm(list_slicer,total=len(list_slicer), desc="slicer files"):
388
-
389
  name_segment = os.path.join(f"segment_{num}")
390
- file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")
391
-
392
  tmp_max = np.abs(chunk).max()
393
- if(tmp_max>1):chunk/=tmp_max
 
394
  chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
395
- wavfile.write(file_segment,24000, (chunk * 32767).astype(np.int16))
396
-
397
  try:
398
- text=transcribe(file_segment,language)
399
- text = text.lower().strip().replace('"',"")
400
 
401
- data+= f"{name_segment}|{text}\n"
402
 
403
- num+=1
404
- except:
405
- error_num +=1
406
 
407
- with open(file_metadata,"w",encoding="utf-8") as f:
408
  f.write(data)
409
-
410
- if error_num!=[]:
411
- error_text=f"\nerror files : {error_num}"
412
  else:
413
- error_text=""
414
-
415
  return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
416
 
 
417
  def format_seconds_to_hms(seconds):
418
  hours = int(seconds / 3600)
419
  minutes = int((seconds % 3600) / 60)
420
  seconds = seconds % 60
421
  return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
422
 
423
- def create_metadata(name_project,progress=gr.Progress()):
424
- name_project+="_pinyin"
425
- path_project= os.path.join(path_data,name_project)
426
- path_project_wavs = os.path.join(path_project,"wavs")
427
- file_metadata = os.path.join(path_project,"metadata.csv")
428
- file_raw = os.path.join(path_project,"raw.arrow")
429
- file_duration = os.path.join(path_project,"duration.json")
430
- file_vocab = os.path.join(path_project,"vocab.txt")
431
-
432
- if os.path.isfile(file_metadata)==False: return "The file was not found in " + file_metadata
433
-
434
- with open(file_metadata,"r",encoding="utf-8") as f:
435
- data=f.read()
436
-
437
- audio_path_list=[]
438
- text_list=[]
439
- duration_list=[]
440
-
441
- count=data.split("\n")
442
- lenght=0
443
- result=[]
444
- error_files=[]
445
- for line in progress.tqdm(data.split("\n"),total=count):
446
- sp_line=line.split("|")
447
- if len(sp_line)!=2:continue
448
- name_audio,text = sp_line[:2]
 
 
 
449
 
450
  file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
451
 
452
- if os.path.isfile(file_audio)==False:
453
  error_files.append(file_audio)
454
  continue
455
 
456
  duraction = get_audio_duration(file_audio)
457
- if duraction<2 and duraction>15:continue
458
- if len(text)<4:continue
 
 
459
 
460
  text = clear_text(text)
461
- text = convert_char_to_pinyin([text], polyphone = True)[0]
462
 
463
  audio_path_list.append(file_audio)
464
  duration_list.append(duraction)
465
  text_list.append(text)
466
-
467
  result.append({"audio_path": file_audio, "text": text, "duration": duraction})
468
 
469
- lenght+=duraction
470
 
471
- if duration_list==[]:
472
- error_files_text="\n".join(error_files)
473
  return f"Error: No audio files found in the specified path : \n{error_files_text}"
474
-
475
- min_second = round(min(duration_list),2)
476
- max_second = round(max(duration_list),2)
477
 
478
  with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
479
- for line in progress.tqdm(result,total=len(result), desc=f"prepare data"):
480
  writer.write(line)
481
 
482
- with open(file_duration, 'w', encoding='utf-8') as f:
483
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
484
-
485
- file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
486
- if os.path.isfile(file_vocab_finetune==False):return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
 
487
  shutil.copy2(file_vocab_finetune, file_vocab)
488
-
489
- if error_files!=[]:
490
- error_text="error files\n" + "\n".join(error_files)
491
  else:
492
- error_text=""
493
-
494
  return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
495
 
 
496
  def check_user(value):
497
- return gr.update(visible=not value),gr.update(visible=value)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
 
499
- def calculate_train(name_project,batch_size_type,max_samples,learning_rate,num_warmup_updates,save_per_updates,last_per_steps,finetune):
500
- name_project+="_pinyin"
501
- path_project= os.path.join(path_data,name_project)
502
- file_duraction = os.path.join(path_project,"duration.json")
503
 
504
- with open(file_duraction, 'r') as file:
505
- data = json.load(file)
506
-
507
- duration_list = data['duration']
508
 
509
  samples = len(duration_list)
510
 
511
  if torch.cuda.is_available():
512
  gpu_properties = torch.cuda.get_device_properties(0)
513
- total_memory = gpu_properties.total_memory / (1024 ** 3)
514
  elif torch.backends.mps.is_available():
515
- total_memory = psutil.virtual_memory().available / (1024 ** 3)
516
-
517
- if batch_size_type=="frame":
518
- batch = int(total_memory * 0.5)
519
- batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
520
- batch_size_per_gpu = int(38400 / batch )
521
  else:
522
- batch_size_per_gpu = int(total_memory / 8)
523
- batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
524
- batch = batch_size_per_gpu
525
 
526
- if batch_size_per_gpu<=0:batch_size_per_gpu=1
 
527
 
528
- if samples<64:
529
- max_samples = int(samples * 0.25)
530
  else:
531
- max_samples = 64
532
-
533
- num_warmup_updates = int(samples * 0.10)
534
- save_per_updates = int(samples * 0.25)
535
- last_per_steps =int(save_per_updates * 5)
536
-
537
  max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
538
  num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
539
  save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
540
  last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
541
 
542
- if finetune:learning_rate=1e-4
543
- else:learning_rate=7.5e-5
 
 
 
 
544
 
545
- return batch_size_per_gpu,max_samples,num_warmup_updates,save_per_updates,last_per_steps,samples,learning_rate
546
 
547
  def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
548
  try:
549
  checkpoint = torch.load(checkpoint_path)
550
  print("Original Checkpoint Keys:", checkpoint.keys())
551
-
552
- ema_model_state_dict = checkpoint.get('ema_model_state_dict', None)
553
 
554
  if ema_model_state_dict is not None:
555
- new_checkpoint = {'ema_model_state_dict': ema_model_state_dict}
556
  torch.save(new_checkpoint, new_checkpoint_path)
557
  return f"New checkpoint saved at: {new_checkpoint_path}"
558
  else:
@@ -561,65 +583,61 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -
561
  except Exception as e:
562
  return f"An error occurred: {e}"
563
 
 
564
  def vocab_check(project_name):
565
  name_project = project_name + "_pinyin"
566
  path_project = os.path.join(path_data, name_project)
567
 
568
  file_metadata = os.path.join(path_project, "metadata.csv")
569
-
570
- file_vocab="data/Emilia_ZH_EN_pinyin/vocab.txt"
571
- if os.path.isfile(file_vocab)==False:
572
  return f"the file {file_vocab} not found !"
573
-
574
- with open(file_vocab,"r",encoding="utf-8") as f:
575
- data=f.read()
576
 
577
  vocab = data.split("\n")
578
 
579
- if os.path.isfile(file_metadata)==False:
580
  return f"the file {file_metadata} not found !"
581
 
582
- with open(file_metadata,"r",encoding="utf-8") as f:
583
- data=f.read()
584
 
585
- miss_symbols=[]
586
- miss_symbols_keep={}
587
  for item in data.split("\n"):
588
- sp=item.split("|")
589
- if len(sp)!=2:continue
590
- text=sp[1].lower().strip()
591
-
592
- for t in text:
593
- if (t in vocab)==False and (t in miss_symbols_keep)==False:
594
- miss_symbols.append(t)
595
- miss_symbols_keep[t]=t
596
-
597
-
598
- if miss_symbols==[]:info ="You can train using your language !"
599
- else:info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
 
600
 
601
  return info
602
 
603
 
604
-
605
  with gr.Blocks() as app:
606
-
607
  with gr.Row():
608
- project_name=gr.Textbox(label="project name",value="my_speak")
609
- bt_create=gr.Button("create new project")
610
-
611
- bt_create.click(fn=create_data_project,inputs=[project_name])
612
-
613
- with gr.Tabs():
614
-
615
 
616
- with gr.TabItem("transcribe Data"):
617
 
 
 
 
618
 
619
- ch_manual = gr.Checkbox(label="user",value=False)
620
-
621
- mark_info_transcribe=gr.Markdown(
622
- """```plaintext
623
  Place your 'wavs' folder and 'metadata.csv' file in the {your_project_name}' directory.
624
 
625
  my_speak/
@@ -628,18 +646,24 @@ with gr.Blocks() as app:
628
  ├── audio1.wav
629
  └── audio2.wav
630
  ...
631
- ```""",visible=False)
632
-
633
- audio_speaker = gr.File(label="voice",type="filepath",file_count="multiple")
634
- txt_lang = gr.Text(label="Language",value="english")
635
- bt_transcribe=bt_create=gr.Button("transcribe")
636
- txt_info_transcribe=gr.Text(label="info",value="")
637
- bt_transcribe.click(fn=transcribe_all,inputs=[project_name,audio_speaker,txt_lang,ch_manual],outputs=[txt_info_transcribe])
638
- ch_manual.change(fn=check_user,inputs=[ch_manual],outputs=[audio_speaker,mark_info_transcribe])
639
-
640
- with gr.TabItem("prepare Data"):
641
- gr.Markdown(
642
- """```plaintext
 
 
 
 
 
 
643
  place all your wavs folder and your metadata.csv file in {your name project}
644
  my_speak/
645
 
@@ -656,61 +680,104 @@ with gr.Blocks() as app:
656
  audio2|text1
657
  ...
658
 
659
- ```""")
660
-
661
- bt_prepare=bt_create=gr.Button("prepare")
662
- txt_info_prepare=gr.Text(label="info",value="")
663
- bt_prepare.click(fn=create_metadata,inputs=[project_name],outputs=[txt_info_prepare])
664
-
665
- with gr.TabItem("train Data"):
666
-
667
- with gr.Row():
668
- bt_calculate=bt_create=gr.Button("Auto Settings")
669
- ch_finetune=bt_create=gr.Checkbox(label="finetune",value=True)
670
- lb_samples = gr.Label(label="samples")
671
- batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
672
-
673
- with gr.Row():
674
- exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
675
- learning_rate = gr.Number(label="Learning Rate", value=1e-4, step=1e-4)
676
-
677
- with gr.Row():
678
- batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
679
- max_samples = gr.Number(label="Max Samples", value=16)
680
-
681
- with gr.Row():
682
- grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
683
- max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
684
-
685
- with gr.Row():
686
- epochs = gr.Number(label="Epochs", value=10)
687
- num_warmup_updates = gr.Number(label="Warmup Updates", value=5)
688
-
689
- with gr.Row():
690
- save_per_updates = gr.Number(label="Save per Updates", value=10)
691
- last_per_steps = gr.Number(label="Last per Steps", value=50)
692
-
693
- with gr.Row():
694
- start_button = gr.Button("Start Training")
695
- stop_button = gr.Button("Stop Training",interactive=False)
696
-
697
- txt_info_train=gr.Text(label="info",value="")
698
- start_button.click(fn=start_training,inputs=[project_name,exp_name,learning_rate,batch_size_per_gpu,batch_size_type,max_samples,grad_accumulation_steps,max_grad_norm,epochs,num_warmup_updates,save_per_updates,last_per_steps,ch_finetune],outputs=[txt_info_train,start_button,stop_button])
699
- stop_button.click(fn=stop_training,outputs=[txt_info_train,start_button,stop_button])
700
- bt_calculate.click(fn=calculate_train,inputs=[project_name,batch_size_type,max_samples,learning_rate,num_warmup_updates,save_per_updates,last_per_steps,ch_finetune],outputs=[batch_size_per_gpu,max_samples,num_warmup_updates,save_per_updates,last_per_steps,lb_samples,learning_rate])
701
-
702
- with gr.TabItem("reduse checkpoint"):
703
- txt_path_checkpoint = gr.Text(label="path checkpoint :")
704
- txt_path_checkpoint_small = gr.Text(label="path output :")
705
- txt_info_reduse = gr.Text(label="info",value="")
706
- reduse_button = gr.Button("reduse")
707
- reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small],outputs=[txt_info_reduse])
708
-
709
- with gr.TabItem("vocab check experiment"):
710
- check_button = gr.Button("check vocab")
711
- txt_info_check=gr.Text(label="info",value="")
712
- check_button.click(fn=vocab_check,inputs=[project_name],outputs=[txt_info_check])
713
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
714
 
715
  @click.command()
716
  @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@@ -725,10 +792,9 @@ with gr.Blocks() as app:
725
  @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
726
  def main(port, host, share, api):
727
  global app
728
- print(f"Starting app...")
729
- app.queue(api_open=api).launch(
730
- server_name=host, server_port=port, share=share, show_api=api
731
- )
732
 
733
  if __name__ == "__main__":
734
  main()
 
1
+ import os
2
+ import sys
3
 
4
  from transformers import pipeline
5
  import gradio as gr
 
21
  import subprocess
22
  from datasets.arrow_writer import ArrowWriter
23
 
 
24
 
25
+ training_process = None
26
  system = platform.system()
27
  python_executable = sys.executable or "python"
28
 
29
+ path_data = "data"
30
 
31
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 
 
32
 
33
  pipe = None
34
 
35
+
36
  # Load metadata
37
  def get_audio_duration(audio_path):
38
  """Calculate the duration of an audio file."""
39
  audio, sample_rate = torchaudio.load(audio_path)
40
+ num_channels = audio.shape[0]
41
  return audio.shape[1] / (sample_rate * num_channels)
42
 
43
+
44
  def clear_text(text):
45
  """Clean and prepare text by lowering the case and stripping whitespace."""
46
  return text.lower().strip()
47
 
48
+
49
+ def get_rms(
50
+ y,
51
+ frame_length=2048,
52
+ hop_length=512,
53
+ pad_mode="constant",
54
+ ): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
55
  padding = (int(frame_length // 2), int(frame_length // 2))
56
  y = np.pad(y, padding, mode=pad_mode)
57
 
 
78
 
79
  return np.sqrt(power)
80
 
81
+
82
+ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
83
  def __init__(
84
  self,
85
  sr: int,
 
90
  max_sil_kept: int = 2000,
91
  ):
92
  if not min_length >= min_interval >= hop_size:
93
+ raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size")
 
 
94
  if not max_sil_kept >= hop_size:
95
+ raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size")
 
 
96
  min_interval = sr * min_interval / 1000
97
  self.threshold = 10 ** (threshold / 20.0)
98
  self.hop_size = round(sr * hop_size / 1000)
 
103
 
104
  def _apply_slice(self, waveform, begin, end):
105
  if len(waveform.shape) > 1:
106
+ return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)]
 
 
107
  else:
108
+ return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)]
 
 
109
 
110
  # @timeit
111
  def slice(self, waveform):
 
115
  samples = waveform
116
  if samples.shape[0] <= self.min_length:
117
  return [waveform]
118
+ rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
 
 
119
  sil_tags = []
120
  silence_start = None
121
  clip_start = 0
 
131
  continue
132
  # Clear recorded silence start if interval is not enough or clip is too short
133
  is_leading_silence = silence_start == 0 and i > self.max_sil_kept
134
+ need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
 
 
 
135
  if not is_leading_silence and not need_slice_middle:
136
  silence_start = None
137
  continue
 
144
  sil_tags.append((pos, pos))
145
  clip_start = pos
146
  elif i - silence_start <= self.max_sil_kept * 2:
147
+ pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
 
 
148
  pos += i - self.max_sil_kept
149
+ pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
150
+ pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
 
 
 
 
 
 
 
 
 
151
  if silence_start == 0:
152
  sil_tags.append((0, pos_r))
153
  clip_start = pos_r
 
155
  sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
156
  clip_start = max(pos_r, pos)
157
  else:
158
+ pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
159
+ pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
 
 
 
 
 
 
 
 
 
160
  if silence_start == 0:
161
  sil_tags.append((0, pos_r))
162
  else:
 
165
  silence_start = None
166
  # Deal with trailing silence.
167
  total_frames = rms_list.shape[0]
168
+ if silence_start is not None and total_frames - silence_start >= self.min_interval:
 
 
 
169
  silence_end = min(total_frames, silence_start + self.max_sil_kept)
170
  pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
171
  sil_tags.append((pos, total_frames + 1))
172
  # Apply and return slices.
173
  ####音频+起始时间+终止时间
174
  if len(sil_tags) == 0:
175
+ return [[waveform, 0, int(total_frames * self.hop_size)]]
176
  else:
177
  chunks = []
178
  if sil_tags[0][0] > 0:
179
+ chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)])
180
  for i in range(len(sil_tags) - 1):
181
  chunks.append(
182
+ [
183
+ self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),
184
+ int(sil_tags[i][1] * self.hop_size),
185
+ int(sil_tags[i + 1][0] * self.hop_size),
186
+ ]
187
  )
188
  if sil_tags[-1][1] < total_frames:
189
  chunks.append(
190
+ [
191
+ self._apply_slice(waveform, sil_tags[-1][1], total_frames),
192
+ int(sil_tags[-1][1] * self.hop_size),
193
+ int(total_frames * self.hop_size),
194
+ ]
195
  )
196
  return chunks
197
 
198
+
199
+ # terminal
200
+ def terminate_process_tree(pid, including_parent=True):
201
  try:
202
  parent = psutil.Process(pid)
203
  except psutil.NoSuchProcess:
 
216
  except OSError:
217
  pass
218
 
219
+
220
  def terminate_process(pid):
221
  if system == "Windows":
222
  cmd = f"taskkill /t /f /pid {pid}"
 
224
  else:
225
  terminate_process_tree(pid)
226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
+ def start_training(
229
+ dataset_name="",
230
+ exp_name="F5TTS_Base",
231
+ learning_rate=1e-4,
232
+ batch_size_per_gpu=400,
233
+ batch_size_type="frame",
234
+ max_samples=64,
235
+ grad_accumulation_steps=1,
236
+ max_grad_norm=1.0,
237
+ epochs=11,
238
+ num_warmup_updates=200,
239
+ save_per_updates=400,
240
+ last_per_steps=800,
241
+ finetune=True,
242
+ ):
243
  global training_process
244
 
245
  path_project = os.path.join(path_data, dataset_name + "_pinyin")
246
 
247
+ if not os.path.isdir(path_project):
248
+ yield (
249
+ f"There is not project with name {dataset_name}",
250
+ gr.update(interactive=True),
251
+ gr.update(interactive=False),
252
+ )
253
  return
254
 
255
+ file_raw = os.path.join(path_project, "raw.arrow")
256
+ if not os.path.isfile(file_raw):
257
+ yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False)
258
+ return
259
 
260
  # Check if a training process is already running
261
  if training_process is not None:
262
+ return "Train run already!", gr.update(interactive=False), gr.update(interactive=True)
263
 
264
+ yield "start train", gr.update(interactive=False), gr.update(interactive=False)
265
 
266
  # Command to run the training script with the specified arguments
267
+ cmd = (
268
+ f"accelerate launch finetune-cli.py --exp_name {exp_name} "
269
+ f"--learning_rate {learning_rate} "
270
+ f"--batch_size_per_gpu {batch_size_per_gpu} "
271
+ f"--batch_size_type {batch_size_type} "
272
+ f"--max_samples {max_samples} "
273
+ f"--grad_accumulation_steps {grad_accumulation_steps} "
274
+ f"--max_grad_norm {max_grad_norm} "
275
+ f"--epochs {epochs} "
276
+ f"--num_warmup_updates {num_warmup_updates} "
277
+ f"--save_per_updates {save_per_updates} "
278
+ f"--last_per_steps {last_per_steps} "
279
+ f"--dataset_name {dataset_name}"
280
+ )
281
+ if finetune:
282
+ cmd += f" --finetune {finetune}"
283
 
284
  print(cmd)
285
+
286
  try:
287
+ # Start the training process
288
+ training_process = subprocess.Popen(cmd, shell=True)
289
 
290
+ time.sleep(5)
291
+ yield "check terminal for wandb", gr.update(interactive=False), gr.update(interactive=True)
292
+
293
+ # Wait for the training process to finish
294
+ training_process.wait()
295
+ time.sleep(1)
296
+
297
+ if training_process is None:
298
+ text_info = "train stop"
299
+ else:
300
+ text_info = "train complete !"
301
 
302
  except Exception as e: # Catch all exceptions
303
  # Ensure that we reset the training process variable in case of an error
304
+ text_info = f"An error occurred: {str(e)}"
305
+
306
+ training_process = None
307
+
308
+ yield text_info, gr.update(interactive=True), gr.update(interactive=False)
309
 
 
310
 
311
  def stop_training():
312
  global training_process
313
+ if training_process is None:
314
+ return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
315
  terminate_process_tree(training_process.pid)
316
  training_process = None
317
+ return "train stop", gr.update(interactive=True), gr.update(interactive=False)
318
+
319
 
320
  def create_data_project(name):
321
+ name += "_pinyin"
322
+ os.makedirs(os.path.join(path_data, name), exist_ok=True)
323
+ os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
324
+
325
+
326
+ def transcribe(file_audio, language="english"):
327
  global pipe
328
 
329
  if pipe is None:
330
+ pipe = pipeline(
331
+ "automatic-speech-recognition",
332
+ model="openai/whisper-large-v3-turbo",
333
+ torch_dtype=torch.float16,
334
+ device=device,
335
+ )
336
 
337
  text_transcribe = pipe(
338
  file_audio,
339
  chunk_length_s=30,
340
  batch_size=128,
341
+ generate_kwargs={"task": "transcribe", "language": language},
342
  return_timestamps=False,
343
  )["text"].strip()
344
  return text_transcribe
345
 
 
 
 
 
 
 
346
 
347
+ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
348
+ name_project += "_pinyin"
349
+ path_project = os.path.join(path_data, name_project)
350
+ path_dataset = os.path.join(path_project, "dataset")
351
+ path_project_wavs = os.path.join(path_project, "wavs")
352
+ file_metadata = os.path.join(path_project, "metadata.csv")
353
+
354
+ if audio_files is None:
355
+ return "You need to load an audio file."
356
 
357
  if os.path.isdir(path_project_wavs):
358
+ shutil.rmtree(path_project_wavs)
359
 
360
  if os.path.isfile(file_metadata):
361
+ os.remove(file_metadata)
362
+
363
+ os.makedirs(path_project_wavs, exist_ok=True)
364
 
 
 
365
  if user:
366
+ file_audios = [
367
+ file
368
+ for format in ("*.wav", "*.ogg", "*.opus", "*.mp3", "*.flac")
369
+ for file in glob(os.path.join(path_dataset, format))
370
+ ]
371
+ if file_audios == []:
372
+ return "No audio file was found in the dataset."
373
  else:
374
+ file_audios = audio_files
 
375
 
376
  alpha = 0.5
377
  _max = 1.0
 
379
 
380
  num = 0
381
  error_num = 0
382
+ data = ""
383
+ for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len((file_audios))):
384
+ audio, _ = librosa.load(file_audio, sr=24000, mono=True)
385
+
386
+ list_slicer = slicer.slice(audio)
387
+ for chunk, start, end in progress.tqdm(list_slicer, total=len(list_slicer), desc="slicer files"):
 
 
388
  name_segment = os.path.join(f"segment_{num}")
389
+ file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")
390
+
391
  tmp_max = np.abs(chunk).max()
392
+ if tmp_max > 1:
393
+ chunk /= tmp_max
394
  chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
395
+ wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16))
396
+
397
  try:
398
+ text = transcribe(file_segment, language)
399
+ text = text.lower().strip().replace('"', "")
400
 
401
+ data += f"{name_segment}|{text}\n"
402
 
403
+ num += 1
404
+ except: # noqa: E722
405
+ error_num += 1
406
 
407
+ with open(file_metadata, "w", encoding="utf-8") as f:
408
  f.write(data)
409
+
410
+ if error_num != []:
411
+ error_text = f"\nerror files : {error_num}"
412
  else:
413
+ error_text = ""
414
+
415
  return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
416
 
417
+
418
  def format_seconds_to_hms(seconds):
419
  hours = int(seconds / 3600)
420
  minutes = int((seconds % 3600) / 60)
421
  seconds = seconds % 60
422
  return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
423
 
424
+
425
+ def create_metadata(name_project, progress=gr.Progress()):
426
+ name_project += "_pinyin"
427
+ path_project = os.path.join(path_data, name_project)
428
+ path_project_wavs = os.path.join(path_project, "wavs")
429
+ file_metadata = os.path.join(path_project, "metadata.csv")
430
+ file_raw = os.path.join(path_project, "raw.arrow")
431
+ file_duration = os.path.join(path_project, "duration.json")
432
+ file_vocab = os.path.join(path_project, "vocab.txt")
433
+
434
+ if not os.path.isfile(file_metadata):
435
+ return "The file was not found in " + file_metadata
436
+
437
+ with open(file_metadata, "r", encoding="utf-8") as f:
438
+ data = f.read()
439
+
440
+ audio_path_list = []
441
+ text_list = []
442
+ duration_list = []
443
+
444
+ count = data.split("\n")
445
+ lenght = 0
446
+ result = []
447
+ error_files = []
448
+ for line in progress.tqdm(data.split("\n"), total=count):
449
+ sp_line = line.split("|")
450
+ if len(sp_line) != 2:
451
+ continue
452
+ name_audio, text = sp_line[:2]
453
 
454
  file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
455
 
456
+ if not os.path.isfile(file_audio):
457
  error_files.append(file_audio)
458
  continue
459
 
460
  duraction = get_audio_duration(file_audio)
461
+ if duraction < 2 and duraction > 15:
462
+ continue
463
+ if len(text) < 4:
464
+ continue
465
 
466
  text = clear_text(text)
467
+ text = convert_char_to_pinyin([text], polyphone=True)[0]
468
 
469
  audio_path_list.append(file_audio)
470
  duration_list.append(duraction)
471
  text_list.append(text)
472
+
473
  result.append({"audio_path": file_audio, "text": text, "duration": duraction})
474
 
475
+ lenght += duraction
476
 
477
+ if duration_list == []:
478
+ error_files_text = "\n".join(error_files)
479
  return f"Error: No audio files found in the specified path : \n{error_files_text}"
480
+
481
+ min_second = round(min(duration_list), 2)
482
+ max_second = round(max(duration_list), 2)
483
 
484
  with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
485
+ for line in progress.tqdm(result, total=len(result), desc="prepare data"):
486
  writer.write(line)
487
 
488
+ with open(file_duration, "w", encoding="utf-8") as f:
489
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
490
+
491
+ file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
492
+ if not os.path.isfile(file_vocab_finetune):
493
+ return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
494
  shutil.copy2(file_vocab_finetune, file_vocab)
495
+
496
+ if error_files != []:
497
+ error_text = "error files\n" + "\n".join(error_files)
498
  else:
499
+ error_text = ""
500
+
501
  return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
502
 
503
+
504
  def check_user(value):
505
+ return gr.update(visible=not value), gr.update(visible=value)
506
+
507
+
508
+ def calculate_train(
509
+ name_project,
510
+ batch_size_type,
511
+ max_samples,
512
+ learning_rate,
513
+ num_warmup_updates,
514
+ save_per_updates,
515
+ last_per_steps,
516
+ finetune,
517
+ ):
518
+ name_project += "_pinyin"
519
+ path_project = os.path.join(path_data, name_project)
520
+ file_duraction = os.path.join(path_project, "duration.json")
521
 
522
+ with open(file_duraction, "r") as file:
523
+ data = json.load(file)
 
 
524
 
525
+ duration_list = data["duration"]
 
 
 
526
 
527
  samples = len(duration_list)
528
 
529
  if torch.cuda.is_available():
530
  gpu_properties = torch.cuda.get_device_properties(0)
531
+ total_memory = gpu_properties.total_memory / (1024**3)
532
  elif torch.backends.mps.is_available():
533
+ total_memory = psutil.virtual_memory().available / (1024**3)
534
+
535
+ if batch_size_type == "frame":
536
+ batch = int(total_memory * 0.5)
537
+ batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
538
+ batch_size_per_gpu = int(38400 / batch)
539
  else:
540
+ batch_size_per_gpu = int(total_memory / 8)
541
+ batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
542
+ batch = batch_size_per_gpu
543
 
544
+ if batch_size_per_gpu <= 0:
545
+ batch_size_per_gpu = 1
546
 
547
+ if samples < 64:
548
+ max_samples = int(samples * 0.25)
549
  else:
550
+ max_samples = 64
551
+
552
+ num_warmup_updates = int(samples * 0.10)
553
+ save_per_updates = int(samples * 0.25)
554
+ last_per_steps = int(save_per_updates * 5)
555
+
556
  max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
557
  num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
558
  save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
559
  last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
560
 
561
+ if finetune:
562
+ learning_rate = 1e-4
563
+ else:
564
+ learning_rate = 7.5e-5
565
+
566
+ return batch_size_per_gpu, max_samples, num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate
567
 
 
568
 
569
  def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
570
  try:
571
  checkpoint = torch.load(checkpoint_path)
572
  print("Original Checkpoint Keys:", checkpoint.keys())
573
+
574
+ ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
575
 
576
  if ema_model_state_dict is not None:
577
+ new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
578
  torch.save(new_checkpoint, new_checkpoint_path)
579
  return f"New checkpoint saved at: {new_checkpoint_path}"
580
  else:
 
583
  except Exception as e:
584
  return f"An error occurred: {e}"
585
 
586
+
587
  def vocab_check(project_name):
588
  name_project = project_name + "_pinyin"
589
  path_project = os.path.join(path_data, name_project)
590
 
591
  file_metadata = os.path.join(path_project, "metadata.csv")
592
+
593
+ file_vocab = "data/Emilia_ZH_EN_pinyin/vocab.txt"
594
+ if not os.path.isfile(file_vocab):
595
  return f"the file {file_vocab} not found !"
596
+
597
+ with open(file_vocab, "r", encoding="utf-8") as f:
598
+ data = f.read()
599
 
600
  vocab = data.split("\n")
601
 
602
+ if not os.path.isfile(file_metadata):
603
  return f"the file {file_metadata} not found !"
604
 
605
+ with open(file_metadata, "r", encoding="utf-8") as f:
606
+ data = f.read()
607
 
608
+ miss_symbols = []
609
+ miss_symbols_keep = {}
610
  for item in data.split("\n"):
611
+ sp = item.split("|")
612
+ if len(sp) != 2:
613
+ continue
614
+ text = sp[1].lower().strip()
615
+
616
+ for t in text:
617
+ if t not in vocab and t not in miss_symbols_keep:
618
+ miss_symbols.append(t)
619
+ miss_symbols_keep[t] = t
620
+ if miss_symbols == []:
621
+ info = "You can train using your language !"
622
+ else:
623
+ info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
624
 
625
  return info
626
 
627
 
 
628
  with gr.Blocks() as app:
 
629
  with gr.Row():
630
+ project_name = gr.Textbox(label="project name", value="my_speak")
631
+ bt_create = gr.Button("create new project")
 
 
 
 
 
632
 
633
+ bt_create.click(fn=create_data_project, inputs=[project_name])
634
 
635
+ with gr.Tabs():
636
+ with gr.TabItem("transcribe Data"):
637
+ ch_manual = gr.Checkbox(label="user", value=False)
638
 
639
+ mark_info_transcribe = gr.Markdown(
640
+ """```plaintext
 
 
641
  Place your 'wavs' folder and 'metadata.csv' file in the {your_project_name}' directory.
642
 
643
  my_speak/
 
646
  ├── audio1.wav
647
  └── audio2.wav
648
  ...
649
+ ```""",
650
+ visible=False,
651
+ )
652
+
653
+ audio_speaker = gr.File(label="voice", type="filepath", file_count="multiple")
654
+ txt_lang = gr.Text(label="Language", value="english")
655
+ bt_transcribe = bt_create = gr.Button("transcribe")
656
+ txt_info_transcribe = gr.Text(label="info", value="")
657
+ bt_transcribe.click(
658
+ fn=transcribe_all,
659
+ inputs=[project_name, audio_speaker, txt_lang, ch_manual],
660
+ outputs=[txt_info_transcribe],
661
+ )
662
+ ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
663
+
664
+ with gr.TabItem("prepare Data"):
665
+ gr.Markdown(
666
+ """```plaintext
667
  place all your wavs folder and your metadata.csv file in {your name project}
668
  my_speak/
669
 
 
680
  audio2|text1
681
  ...
682
 
683
+ ```"""
684
+ )
685
+
686
+ bt_prepare = bt_create = gr.Button("prepare")
687
+ txt_info_prepare = gr.Text(label="info", value="")
688
+ bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
689
+
690
+ with gr.TabItem("train Data"):
691
+ with gr.Row():
692
+ bt_calculate = bt_create = gr.Button("Auto Settings")
693
+ ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
694
+ lb_samples = gr.Label(label="samples")
695
+ batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
696
+
697
+ with gr.Row():
698
+ exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
699
+ learning_rate = gr.Number(label="Learning Rate", value=1e-4, step=1e-4)
700
+
701
+ with gr.Row():
702
+ batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
703
+ max_samples = gr.Number(label="Max Samples", value=16)
704
+
705
+ with gr.Row():
706
+ grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
707
+ max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
708
+
709
+ with gr.Row():
710
+ epochs = gr.Number(label="Epochs", value=10)
711
+ num_warmup_updates = gr.Number(label="Warmup Updates", value=5)
712
+
713
+ with gr.Row():
714
+ save_per_updates = gr.Number(label="Save per Updates", value=10)
715
+ last_per_steps = gr.Number(label="Last per Steps", value=50)
716
+
717
+ with gr.Row():
718
+ start_button = gr.Button("Start Training")
719
+ stop_button = gr.Button("Stop Training", interactive=False)
720
+
721
+ txt_info_train = gr.Text(label="info", value="")
722
+ start_button.click(
723
+ fn=start_training,
724
+ inputs=[
725
+ project_name,
726
+ exp_name,
727
+ learning_rate,
728
+ batch_size_per_gpu,
729
+ batch_size_type,
730
+ max_samples,
731
+ grad_accumulation_steps,
732
+ max_grad_norm,
733
+ epochs,
734
+ num_warmup_updates,
735
+ save_per_updates,
736
+ last_per_steps,
737
+ ch_finetune,
738
+ ],
739
+ outputs=[txt_info_train, start_button, stop_button],
740
+ )
741
+ stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
742
+ bt_calculate.click(
743
+ fn=calculate_train,
744
+ inputs=[
745
+ project_name,
746
+ batch_size_type,
747
+ max_samples,
748
+ learning_rate,
749
+ num_warmup_updates,
750
+ save_per_updates,
751
+ last_per_steps,
752
+ ch_finetune,
753
+ ],
754
+ outputs=[
755
+ batch_size_per_gpu,
756
+ max_samples,
757
+ num_warmup_updates,
758
+ save_per_updates,
759
+ last_per_steps,
760
+ lb_samples,
761
+ learning_rate,
762
+ ],
763
+ )
764
+
765
+ with gr.TabItem("reduse checkpoint"):
766
+ txt_path_checkpoint = gr.Text(label="path checkpoint :")
767
+ txt_path_checkpoint_small = gr.Text(label="path output :")
768
+ txt_info_reduse = gr.Text(label="info", value="")
769
+ reduse_button = gr.Button("reduse")
770
+ reduse_button.click(
771
+ fn=extract_and_save_ema_model,
772
+ inputs=[txt_path_checkpoint, txt_path_checkpoint_small],
773
+ outputs=[txt_info_reduse],
774
+ )
775
+
776
+ with gr.TabItem("vocab check experiment"):
777
+ check_button = gr.Button("check vocab")
778
+ txt_info_check = gr.Text(label="info", value="")
779
+ check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
780
+
781
 
782
  @click.command()
783
  @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
 
792
  @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
793
  def main(port, host, share, api):
794
  global app
795
+ print("Starting app...")
796
+ app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
797
+
 
798
 
799
  if __name__ == "__main__":
800
  main()
inference-cli.py CHANGED
@@ -44,19 +44,8 @@ parser.add_argument(
44
  "--vocab_file",
45
  help="The vocab .txt",
46
  )
47
- parser.add_argument(
48
- "-r",
49
- "--ref_audio",
50
- type=str,
51
- help="Reference audio file < 15 seconds."
52
- )
53
- parser.add_argument(
54
- "-s",
55
- "--ref_text",
56
- type=str,
57
- default="666",
58
- help="Subtitle for the reference audio."
59
- )
60
  parser.add_argument(
61
  "-t",
62
  "--gen_text",
@@ -99,8 +88,8 @@ model = args.model if args.model else config["model"]
99
  ckpt_file = args.ckpt_file if args.ckpt_file else ""
100
  vocab_file = args.vocab_file if args.vocab_file else ""
101
  remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
102
- wave_path = Path(output_dir)/"out.wav"
103
- spectrogram_path = Path(output_dir)/"out.png"
104
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
105
 
106
  vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
@@ -110,44 +99,46 @@ vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_loc
110
  if model == "F5-TTS":
111
  model_cls = DiT
112
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
113
- if ckpt_file == "":
114
- repo_name= "F5-TTS"
115
  exp_name = "F5TTS_Base"
116
- ckpt_step= 1200000
117
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
118
  # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
119
 
120
  elif model == "E2-TTS":
121
  model_cls = UNetT
122
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
123
- if ckpt_file == "":
124
- repo_name= "E2-TTS"
125
  exp_name = "E2TTS_Base"
126
- ckpt_step= 1200000
127
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
128
  # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
129
 
130
  print(f"Using {model}...")
131
  ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
132
-
133
 
134
  def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
135
- main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
136
  if "voices" not in config:
137
  voices = {"main": main_voice}
138
  else:
139
  voices = config["voices"]
140
  voices["main"] = main_voice
141
  for voice in voices:
142
- voices[voice]['ref_audio'], voices[voice]['ref_text'] = preprocess_ref_audio_text(voices[voice]['ref_audio'], voices[voice]['ref_text'])
 
 
143
  print("Voice:", voice)
144
- print("Ref_audio:", voices[voice]['ref_audio'])
145
- print("Ref_text:", voices[voice]['ref_text'])
146
 
147
  generated_audio_segments = []
148
- reg1 = r'(?=\[\w+\])'
149
  chunks = re.split(reg1, text_gen)
150
- reg2 = r'\[(\w+)\]'
151
  for text in chunks:
152
  match = re.match(reg2, text)
153
  if match:
@@ -160,8 +151,8 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
160
  voice = "main"
161
  text = re.sub(reg2, "", text)
162
  gen_text = text.strip()
163
- ref_audio = voices[voice]['ref_audio']
164
- ref_text = voices[voice]['ref_text']
165
  print(f"Voice: {voice}")
166
  audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
167
  generated_audio_segments.append(audio)
 
44
  "--vocab_file",
45
  help="The vocab .txt",
46
  )
47
+ parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
48
+ parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
 
 
 
 
 
 
 
 
 
 
 
49
  parser.add_argument(
50
  "-t",
51
  "--gen_text",
 
88
  ckpt_file = args.ckpt_file if args.ckpt_file else ""
89
  vocab_file = args.vocab_file if args.vocab_file else ""
90
  remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
91
+ wave_path = Path(output_dir) / "out.wav"
92
+ spectrogram_path = Path(output_dir) / "out.png"
93
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
94
 
95
  vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
 
99
  if model == "F5-TTS":
100
  model_cls = DiT
101
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
102
+ if ckpt_file == "":
103
+ repo_name = "F5-TTS"
104
  exp_name = "F5TTS_Base"
105
+ ckpt_step = 1200000
106
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
107
  # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
108
 
109
  elif model == "E2-TTS":
110
  model_cls = UNetT
111
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
112
+ if ckpt_file == "":
113
+ repo_name = "E2-TTS"
114
  exp_name = "E2TTS_Base"
115
+ ckpt_step = 1200000
116
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
117
  # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
118
 
119
  print(f"Using {model}...")
120
  ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
121
+
122
 
123
  def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
124
+ main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
125
  if "voices" not in config:
126
  voices = {"main": main_voice}
127
  else:
128
  voices = config["voices"]
129
  voices["main"] = main_voice
130
  for voice in voices:
131
+ voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
132
+ voices[voice]["ref_audio"], voices[voice]["ref_text"]
133
+ )
134
  print("Voice:", voice)
135
+ print("Ref_audio:", voices[voice]["ref_audio"])
136
+ print("Ref_text:", voices[voice]["ref_text"])
137
 
138
  generated_audio_segments = []
139
+ reg1 = r"(?=\[\w+\])"
140
  chunks = re.split(reg1, text_gen)
141
+ reg2 = r"\[(\w+)\]"
142
  for text in chunks:
143
  match = re.match(reg2, text)
144
  if match:
 
151
  voice = "main"
152
  text = re.sub(reg2, "", text)
153
  gen_text = text.strip()
154
+ ref_audio = voices[voice]["ref_audio"]
155
+ ref_text = voices[voice]["ref_text"]
156
  print(f"Voice: {voice}")
157
  audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
158
  generated_audio_segments.append(audio)
model/__init__.py CHANGED
@@ -5,3 +5,6 @@ from model.backbones.dit import DiT
5
  from model.backbones.mmdit import MMDiT
6
 
7
  from model.trainer import Trainer
 
 
 
 
5
  from model.backbones.mmdit import MMDiT
6
 
7
  from model.trainer import Trainer
8
+
9
+
10
+ __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
model/backbones/dit.py CHANGED
@@ -21,14 +21,16 @@ from model.modules import (
21
  ConvPositionEmbedding,
22
  DiTBlock,
23
  AdaLayerNormZero_Final,
24
- precompute_freqs_cis, get_pos_embed_indices,
 
25
  )
26
 
27
 
28
  # Text embedding
29
 
 
30
  class TextEmbedding(nn.Module):
31
- def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
32
  super().__init__()
33
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
34
 
@@ -36,20 +38,22 @@ class TextEmbedding(nn.Module):
36
  self.extra_modeling = True
37
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
38
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
39
- self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
 
 
40
  else:
41
  self.extra_modeling = False
42
 
43
- def forward(self, text: int['b nt'], seq_len, drop_text = False):
44
  batch, text_len = text.shape[0], text.shape[1]
45
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
46
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
47
- text = F.pad(text, (0, seq_len - text_len), value = 0)
48
 
49
  if drop_text: # cfg for text
50
  text = torch.zeros_like(text)
51
 
52
- text = self.text_embed(text) # b n -> b n d
53
 
54
  # possible extra modeling
55
  if self.extra_modeling:
@@ -67,88 +71,91 @@ class TextEmbedding(nn.Module):
67
 
68
  # noised input audio and context mixing embedding
69
 
 
70
  class InputEmbedding(nn.Module):
71
  def __init__(self, mel_dim, text_dim, out_dim):
72
  super().__init__()
73
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
74
- self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
75
 
76
- def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
77
  if drop_audio_cond: # cfg for cond audio
78
  cond = torch.zeros_like(cond)
79
 
80
- x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
81
  x = self.conv_pos_embed(x) + x
82
  return x
83
-
84
 
85
  # Transformer backbone using DiT blocks
86
 
 
87
  class DiT(nn.Module):
88
- def __init__(self, *,
89
- dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
90
- mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
91
- long_skip_connection = False,
 
 
 
 
 
 
 
 
 
 
92
  ):
93
  super().__init__()
94
 
95
  self.time_embed = TimestepEmbedding(dim)
96
  if text_dim is None:
97
  text_dim = mel_dim
98
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
99
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
100
 
101
  self.rotary_embed = RotaryEmbedding(dim_head)
102
 
103
  self.dim = dim
104
  self.depth = depth
105
-
106
  self.transformer_blocks = nn.ModuleList(
107
- [
108
- DiTBlock(
109
- dim = dim,
110
- heads = heads,
111
- dim_head = dim_head,
112
- ff_mult = ff_mult,
113
- dropout = dropout
114
- )
115
- for _ in range(depth)
116
- ]
117
  )
118
- self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None
119
-
120
  self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
121
  self.proj_out = nn.Linear(dim, mel_dim)
122
 
123
  def forward(
124
  self,
125
- x: float['b n d'], # nosied input audio
126
- cond: float['b n d'], # masked cond audio
127
- text: int['b nt'], # text
128
- time: float['b'] | float[''], # time step
129
  drop_audio_cond, # cfg for cond audio
130
- drop_text, # cfg for text
131
- mask: bool['b n'] | None = None,
132
  ):
133
  batch, seq_len = x.shape[0], x.shape[1]
134
  if time.ndim == 0:
135
  time = time.repeat(batch)
136
-
137
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
138
  t = self.time_embed(time)
139
- text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
140
- x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
141
-
142
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
143
 
144
  if self.long_skip_connection is not None:
145
  residual = x
146
 
147
  for block in self.transformer_blocks:
148
- x = block(x, t, mask = mask, rope = rope)
149
 
150
  if self.long_skip_connection is not None:
151
- x = self.long_skip_connection(torch.cat((x, residual), dim = -1))
152
 
153
  x = self.norm_out(x, t)
154
  output = self.proj_out(x)
 
21
  ConvPositionEmbedding,
22
  DiTBlock,
23
  AdaLayerNormZero_Final,
24
+ precompute_freqs_cis,
25
+ get_pos_embed_indices,
26
  )
27
 
28
 
29
  # Text embedding
30
 
31
+
32
  class TextEmbedding(nn.Module):
33
+ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
34
  super().__init__()
35
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
 
 
38
  self.extra_modeling = True
39
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
40
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
41
+ self.text_blocks = nn.Sequential(
42
+ *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
43
+ )
44
  else:
45
  self.extra_modeling = False
46
 
47
+ def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
48
  batch, text_len = text.shape[0], text.shape[1]
49
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
50
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
51
+ text = F.pad(text, (0, seq_len - text_len), value=0)
52
 
53
  if drop_text: # cfg for text
54
  text = torch.zeros_like(text)
55
 
56
+ text = self.text_embed(text) # b n -> b n d
57
 
58
  # possible extra modeling
59
  if self.extra_modeling:
 
71
 
72
  # noised input audio and context mixing embedding
73
 
74
+
75
  class InputEmbedding(nn.Module):
76
  def __init__(self, mel_dim, text_dim, out_dim):
77
  super().__init__()
78
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
79
+ self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
80
 
81
+ def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
82
  if drop_audio_cond: # cfg for cond audio
83
  cond = torch.zeros_like(cond)
84
 
85
+ x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
86
  x = self.conv_pos_embed(x) + x
87
  return x
88
+
89
 
90
  # Transformer backbone using DiT blocks
91
 
92
+
93
  class DiT(nn.Module):
94
+ def __init__(
95
+ self,
96
+ *,
97
+ dim,
98
+ depth=8,
99
+ heads=8,
100
+ dim_head=64,
101
+ dropout=0.1,
102
+ ff_mult=4,
103
+ mel_dim=100,
104
+ text_num_embeds=256,
105
+ text_dim=None,
106
+ conv_layers=0,
107
+ long_skip_connection=False,
108
  ):
109
  super().__init__()
110
 
111
  self.time_embed = TimestepEmbedding(dim)
112
  if text_dim is None:
113
  text_dim = mel_dim
114
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
115
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
116
 
117
  self.rotary_embed = RotaryEmbedding(dim_head)
118
 
119
  self.dim = dim
120
  self.depth = depth
121
+
122
  self.transformer_blocks = nn.ModuleList(
123
+ [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
 
 
 
 
 
 
 
 
 
124
  )
125
+ self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
126
+
127
  self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
128
  self.proj_out = nn.Linear(dim, mel_dim)
129
 
130
  def forward(
131
  self,
132
+ x: float["b n d"], # nosied input audio # noqa: F722
133
+ cond: float["b n d"], # masked cond audio # noqa: F722
134
+ text: int["b nt"], # text # noqa: F722
135
+ time: float["b"] | float[""], # time step # noqa: F821 F722
136
  drop_audio_cond, # cfg for cond audio
137
+ drop_text, # cfg for text
138
+ mask: bool["b n"] | None = None, # noqa: F722
139
  ):
140
  batch, seq_len = x.shape[0], x.shape[1]
141
  if time.ndim == 0:
142
  time = time.repeat(batch)
143
+
144
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
145
  t = self.time_embed(time)
146
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
147
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
148
+
149
  rope = self.rotary_embed.forward_from_seq_len(seq_len)
150
 
151
  if self.long_skip_connection is not None:
152
  residual = x
153
 
154
  for block in self.transformer_blocks:
155
+ x = block(x, t, mask=mask, rope=rope)
156
 
157
  if self.long_skip_connection is not None:
158
+ x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
159
 
160
  x = self.norm_out(x, t)
161
  output = self.proj_out(x)
model/backbones/mmdit.py CHANGED
@@ -19,12 +19,14 @@ from model.modules import (
19
  ConvPositionEmbedding,
20
  MMDiTBlock,
21
  AdaLayerNormZero_Final,
22
- precompute_freqs_cis, get_pos_embed_indices,
 
23
  )
24
 
25
 
26
  # text embedding
27
 
 
28
  class TextEmbedding(nn.Module):
29
  def __init__(self, out_dim, text_num_embeds):
30
  super().__init__()
@@ -33,7 +35,7 @@ class TextEmbedding(nn.Module):
33
  self.precompute_max_pos = 1024
34
  self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
35
 
36
- def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']:
37
  text = text + 1
38
  if drop_text:
39
  text = torch.zeros_like(text)
@@ -52,27 +54,37 @@ class TextEmbedding(nn.Module):
52
 
53
  # noised input & masked cond audio embedding
54
 
 
55
  class AudioEmbedding(nn.Module):
56
  def __init__(self, in_dim, out_dim):
57
  super().__init__()
58
  self.linear = nn.Linear(2 * in_dim, out_dim)
59
  self.conv_pos_embed = ConvPositionEmbedding(out_dim)
60
 
61
- def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
62
  if drop_audio_cond:
63
  cond = torch.zeros_like(cond)
64
- x = torch.cat((x, cond), dim = -1)
65
  x = self.linear(x)
66
  x = self.conv_pos_embed(x) + x
67
  return x
68
-
69
 
70
  # Transformer backbone using MM-DiT blocks
71
 
 
72
  class MMDiT(nn.Module):
73
- def __init__(self, *,
74
- dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
75
- text_num_embeds = 256, mel_dim = 100,
 
 
 
 
 
 
 
 
76
  ):
77
  super().__init__()
78
 
@@ -84,16 +96,16 @@ class MMDiT(nn.Module):
84
 
85
  self.dim = dim
86
  self.depth = depth
87
-
88
  self.transformer_blocks = nn.ModuleList(
89
  [
90
  MMDiTBlock(
91
- dim = dim,
92
- heads = heads,
93
- dim_head = dim_head,
94
- dropout = dropout,
95
- ff_mult = ff_mult,
96
- context_pre_only = i == depth - 1,
97
  )
98
  for i in range(depth)
99
  ]
@@ -103,13 +115,13 @@ class MMDiT(nn.Module):
103
 
104
  def forward(
105
  self,
106
- x: float['b n d'], # nosied input audio
107
- cond: float['b n d'], # masked cond audio
108
- text: int['b nt'], # text
109
- time: float['b'] | float[''], # time step
110
  drop_audio_cond, # cfg for cond audio
111
- drop_text, # cfg for text
112
- mask: bool['b n'] | None = None,
113
  ):
114
  batch = x.shape[0]
115
  if time.ndim == 0:
@@ -117,16 +129,16 @@ class MMDiT(nn.Module):
117
 
118
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
119
  t = self.time_embed(time)
120
- c = self.text_embed(text, drop_text = drop_text)
121
- x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)
122
 
123
  seq_len = x.shape[1]
124
  text_len = text.shape[1]
125
  rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
126
  rope_text = self.rotary_embed.forward_from_seq_len(text_len)
127
-
128
  for block in self.transformer_blocks:
129
- c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)
130
 
131
  x = self.norm_out(x, t)
132
  output = self.proj_out(x)
 
19
  ConvPositionEmbedding,
20
  MMDiTBlock,
21
  AdaLayerNormZero_Final,
22
+ precompute_freqs_cis,
23
+ get_pos_embed_indices,
24
  )
25
 
26
 
27
  # text embedding
28
 
29
+
30
  class TextEmbedding(nn.Module):
31
  def __init__(self, out_dim, text_num_embeds):
32
  super().__init__()
 
35
  self.precompute_max_pos = 1024
36
  self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
 
38
+ def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
39
  text = text + 1
40
  if drop_text:
41
  text = torch.zeros_like(text)
 
54
 
55
  # noised input & masked cond audio embedding
56
 
57
+
58
  class AudioEmbedding(nn.Module):
59
  def __init__(self, in_dim, out_dim):
60
  super().__init__()
61
  self.linear = nn.Linear(2 * in_dim, out_dim)
62
  self.conv_pos_embed = ConvPositionEmbedding(out_dim)
63
 
64
+ def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
65
  if drop_audio_cond:
66
  cond = torch.zeros_like(cond)
67
+ x = torch.cat((x, cond), dim=-1)
68
  x = self.linear(x)
69
  x = self.conv_pos_embed(x) + x
70
  return x
71
+
72
 
73
  # Transformer backbone using MM-DiT blocks
74
 
75
+
76
  class MMDiT(nn.Module):
77
+ def __init__(
78
+ self,
79
+ *,
80
+ dim,
81
+ depth=8,
82
+ heads=8,
83
+ dim_head=64,
84
+ dropout=0.1,
85
+ ff_mult=4,
86
+ text_num_embeds=256,
87
+ mel_dim=100,
88
  ):
89
  super().__init__()
90
 
 
96
 
97
  self.dim = dim
98
  self.depth = depth
99
+
100
  self.transformer_blocks = nn.ModuleList(
101
  [
102
  MMDiTBlock(
103
+ dim=dim,
104
+ heads=heads,
105
+ dim_head=dim_head,
106
+ dropout=dropout,
107
+ ff_mult=ff_mult,
108
+ context_pre_only=i == depth - 1,
109
  )
110
  for i in range(depth)
111
  ]
 
115
 
116
  def forward(
117
  self,
118
+ x: float["b n d"], # nosied input audio # noqa: F722
119
+ cond: float["b n d"], # masked cond audio # noqa: F722
120
+ text: int["b nt"], # text # noqa: F722
121
+ time: float["b"] | float[""], # time step # noqa: F821 F722
122
  drop_audio_cond, # cfg for cond audio
123
+ drop_text, # cfg for text
124
+ mask: bool["b n"] | None = None, # noqa: F722
125
  ):
126
  batch = x.shape[0]
127
  if time.ndim == 0:
 
129
 
130
  # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
131
  t = self.time_embed(time)
132
+ c = self.text_embed(text, drop_text=drop_text)
133
+ x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
134
 
135
  seq_len = x.shape[1]
136
  text_len = text.shape[1]
137
  rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
138
  rope_text = self.rotary_embed.forward_from_seq_len(text_len)
139
+
140
  for block in self.transformer_blocks:
141
+ c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
142
 
143
  x = self.norm_out(x, t)
144
  output = self.proj_out(x)
model/backbones/unett.py CHANGED
@@ -24,14 +24,16 @@ from model.modules import (
24
  Attention,
25
  AttnProcessor,
26
  FeedForward,
27
- precompute_freqs_cis, get_pos_embed_indices,
 
28
  )
29
 
30
 
31
  # Text embedding
32
 
 
33
  class TextEmbedding(nn.Module):
34
- def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
35
  super().__init__()
36
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
37
 
@@ -39,20 +41,22 @@ class TextEmbedding(nn.Module):
39
  self.extra_modeling = True
40
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
41
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
42
- self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
 
 
43
  else:
44
  self.extra_modeling = False
45
 
46
- def forward(self, text: int['b nt'], seq_len, drop_text = False):
47
  batch, text_len = text.shape[0], text.shape[1]
48
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
49
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
50
- text = F.pad(text, (0, seq_len - text_len), value = 0)
51
 
52
  if drop_text: # cfg for text
53
  text = torch.zeros_like(text)
54
 
55
- text = self.text_embed(text) # b n -> b n d
56
 
57
  # possible extra modeling
58
  if self.extra_modeling:
@@ -70,28 +74,40 @@ class TextEmbedding(nn.Module):
70
 
71
  # noised input audio and context mixing embedding
72
 
 
73
  class InputEmbedding(nn.Module):
74
  def __init__(self, mel_dim, text_dim, out_dim):
75
  super().__init__()
76
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
77
- self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
78
 
79
- def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
80
  if drop_audio_cond: # cfg for cond audio
81
  cond = torch.zeros_like(cond)
82
 
83
- x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
84
  x = self.conv_pos_embed(x) + x
85
  return x
86
 
87
 
88
  # Flat UNet Transformer backbone
89
 
 
90
  class UNetT(nn.Module):
91
- def __init__(self, *,
92
- dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
93
- mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
94
- skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
 
 
 
 
 
 
 
 
 
 
95
  ):
96
  super().__init__()
97
  assert depth % 2 == 0, "UNet-Transformer's depth should be even."
@@ -99,7 +115,7 @@ class UNetT(nn.Module):
99
  self.time_embed = TimestepEmbedding(dim)
100
  if text_dim is None:
101
  text_dim = mel_dim
102
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
103
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
104
 
105
  self.rotary_embed = RotaryEmbedding(dim_head)
@@ -108,7 +124,7 @@ class UNetT(nn.Module):
108
 
109
  self.dim = dim
110
  self.skip_connect_type = skip_connect_type
111
- needs_skip_proj = skip_connect_type == 'concat'
112
 
113
  self.depth = depth
114
  self.layers = nn.ModuleList([])
@@ -118,53 +134,57 @@ class UNetT(nn.Module):
118
 
119
  attn_norm = RMSNorm(dim)
120
  attn = Attention(
121
- processor = AttnProcessor(),
122
- dim = dim,
123
- heads = heads,
124
- dim_head = dim_head,
125
- dropout = dropout,
126
- )
127
 
128
  ff_norm = RMSNorm(dim)
129
- ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
130
-
131
- skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None
132
-
133
- self.layers.append(nn.ModuleList([
134
- skip_proj,
135
- attn_norm,
136
- attn,
137
- ff_norm,
138
- ff,
139
- ]))
 
 
 
 
140
 
141
  self.norm_out = RMSNorm(dim)
142
  self.proj_out = nn.Linear(dim, mel_dim)
143
 
144
  def forward(
145
  self,
146
- x: float['b n d'], # nosied input audio
147
- cond: float['b n d'], # masked cond audio
148
- text: int['b nt'], # text
149
- time: float['b'] | float[''], # time step
150
  drop_audio_cond, # cfg for cond audio
151
- drop_text, # cfg for text
152
- mask: bool['b n'] | None = None,
153
  ):
154
  batch, seq_len = x.shape[0], x.shape[1]
155
  if time.ndim == 0:
156
  time = time.repeat(batch)
157
-
158
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
159
  t = self.time_embed(time)
160
- text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
161
- x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
162
 
163
  # postfix time t to input x, [b n d] -> [b n+1 d]
164
  x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
165
  if mask is not None:
166
  mask = F.pad(mask, (1, 0), value=1)
167
-
168
  rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
169
 
170
  # flat unet transformer
@@ -182,14 +202,14 @@ class UNetT(nn.Module):
182
 
183
  if is_later_half:
184
  skip = skips.pop()
185
- if skip_connect_type == 'concat':
186
- x = torch.cat((x, skip), dim = -1)
187
  x = maybe_skip_proj(x)
188
- elif skip_connect_type == 'add':
189
  x = x + skip
190
 
191
  # attention and feedforward blocks
192
- x = attn(attn_norm(x), rope = rope, mask = mask) + x
193
  x = ff(ff_norm(x)) + x
194
 
195
  assert len(skips) == 0
 
24
  Attention,
25
  AttnProcessor,
26
  FeedForward,
27
+ precompute_freqs_cis,
28
+ get_pos_embed_indices,
29
  )
30
 
31
 
32
  # Text embedding
33
 
34
+
35
  class TextEmbedding(nn.Module):
36
+ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
37
  super().__init__()
38
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
 
 
41
  self.extra_modeling = True
42
  self.precompute_max_pos = 4096 # ~44s of 24khz audio
43
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
44
+ self.text_blocks = nn.Sequential(
45
+ *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
46
+ )
47
  else:
48
  self.extra_modeling = False
49
 
50
+ def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
51
  batch, text_len = text.shape[0], text.shape[1]
52
  text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
53
  text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
54
+ text = F.pad(text, (0, seq_len - text_len), value=0)
55
 
56
  if drop_text: # cfg for text
57
  text = torch.zeros_like(text)
58
 
59
+ text = self.text_embed(text) # b n -> b n d
60
 
61
  # possible extra modeling
62
  if self.extra_modeling:
 
74
 
75
  # noised input audio and context mixing embedding
76
 
77
+
78
  class InputEmbedding(nn.Module):
79
  def __init__(self, mel_dim, text_dim, out_dim):
80
  super().__init__()
81
  self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
82
+ self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
83
 
84
+ def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
85
  if drop_audio_cond: # cfg for cond audio
86
  cond = torch.zeros_like(cond)
87
 
88
+ x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
89
  x = self.conv_pos_embed(x) + x
90
  return x
91
 
92
 
93
  # Flat UNet Transformer backbone
94
 
95
+
96
  class UNetT(nn.Module):
97
+ def __init__(
98
+ self,
99
+ *,
100
+ dim,
101
+ depth=8,
102
+ heads=8,
103
+ dim_head=64,
104
+ dropout=0.1,
105
+ ff_mult=4,
106
+ mel_dim=100,
107
+ text_num_embeds=256,
108
+ text_dim=None,
109
+ conv_layers=0,
110
+ skip_connect_type: Literal["add", "concat", "none"] = "concat",
111
  ):
112
  super().__init__()
113
  assert depth % 2 == 0, "UNet-Transformer's depth should be even."
 
115
  self.time_embed = TimestepEmbedding(dim)
116
  if text_dim is None:
117
  text_dim = mel_dim
118
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
119
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
120
 
121
  self.rotary_embed = RotaryEmbedding(dim_head)
 
124
 
125
  self.dim = dim
126
  self.skip_connect_type = skip_connect_type
127
+ needs_skip_proj = skip_connect_type == "concat"
128
 
129
  self.depth = depth
130
  self.layers = nn.ModuleList([])
 
134
 
135
  attn_norm = RMSNorm(dim)
136
  attn = Attention(
137
+ processor=AttnProcessor(),
138
+ dim=dim,
139
+ heads=heads,
140
+ dim_head=dim_head,
141
+ dropout=dropout,
142
+ )
143
 
144
  ff_norm = RMSNorm(dim)
145
+ ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
146
+
147
+ skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
148
+
149
+ self.layers.append(
150
+ nn.ModuleList(
151
+ [
152
+ skip_proj,
153
+ attn_norm,
154
+ attn,
155
+ ff_norm,
156
+ ff,
157
+ ]
158
+ )
159
+ )
160
 
161
  self.norm_out = RMSNorm(dim)
162
  self.proj_out = nn.Linear(dim, mel_dim)
163
 
164
  def forward(
165
  self,
166
+ x: float["b n d"], # nosied input audio # noqa: F722
167
+ cond: float["b n d"], # masked cond audio # noqa: F722
168
+ text: int["b nt"], # text # noqa: F722
169
+ time: float["b"] | float[""], # time step # noqa: F821 F722
170
  drop_audio_cond, # cfg for cond audio
171
+ drop_text, # cfg for text
172
+ mask: bool["b n"] | None = None, # noqa: F722
173
  ):
174
  batch, seq_len = x.shape[0], x.shape[1]
175
  if time.ndim == 0:
176
  time = time.repeat(batch)
177
+
178
  # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
179
  t = self.time_embed(time)
180
+ text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
181
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
182
 
183
  # postfix time t to input x, [b n d] -> [b n+1 d]
184
  x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
185
  if mask is not None:
186
  mask = F.pad(mask, (1, 0), value=1)
187
+
188
  rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
189
 
190
  # flat unet transformer
 
202
 
203
  if is_later_half:
204
  skip = skips.pop()
205
+ if skip_connect_type == "concat":
206
+ x = torch.cat((x, skip), dim=-1)
207
  x = maybe_skip_proj(x)
208
+ elif skip_connect_type == "add":
209
  x = x + skip
210
 
211
  # attention and feedforward blocks
212
+ x = attn(attn_norm(x), rope=rope, mask=mask) + x
213
  x = ff(ff_norm(x)) + x
214
 
215
  assert len(skips) == 0
model/cfm.py CHANGED
@@ -20,29 +20,32 @@ from torchdiffeq import odeint
20
 
21
  from model.modules import MelSpec
22
  from model.utils import (
23
- default, exists,
24
- list_str_to_idx, list_str_to_tensor,
25
- lens_to_mask, mask_from_frac_lengths,
26
- )
 
 
 
27
 
28
 
29
  class CFM(nn.Module):
30
  def __init__(
31
  self,
32
  transformer: nn.Module,
33
- sigma = 0.,
34
  odeint_kwargs: dict = dict(
35
  # atol = 1e-5,
36
  # rtol = 1e-5,
37
- method = 'euler' # 'midpoint'
38
  ),
39
- audio_drop_prob = 0.3,
40
- cond_drop_prob = 0.2,
41
- num_channels = None,
42
  mel_spec_module: nn.Module | None = None,
43
  mel_spec_kwargs: dict = dict(),
44
- frac_lengths_mask: tuple[float, float] = (0.7, 1.),
45
- vocab_char_map: dict[str: int] | None = None
46
  ):
47
  super().__init__()
48
 
@@ -78,21 +81,21 @@ class CFM(nn.Module):
78
  @torch.no_grad()
79
  def sample(
80
  self,
81
- cond: float['b n d'] | float['b nw'],
82
- text: int['b nt'] | list[str],
83
- duration: int | int['b'],
84
  *,
85
- lens: int['b'] | None = None,
86
- steps = 32,
87
- cfg_strength = 1.,
88
- sway_sampling_coef = None,
89
  seed: int | None = None,
90
- max_duration = 4096,
91
- vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
92
- no_ref_audio = False,
93
- duplicate_test = False,
94
- t_inter = 0.1,
95
- edit_mask = None,
96
  ):
97
  self.eval()
98
 
@@ -108,7 +111,7 @@ class CFM(nn.Module):
108
 
109
  batch, cond_seq_len, device = *cond.shape[:2], cond.device
110
  if not exists(lens):
111
- lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
112
 
113
  # text
114
 
@@ -120,8 +123,8 @@ class CFM(nn.Module):
120
  assert text.shape[0] == batch
121
 
122
  if exists(text):
123
- text_lens = (text != -1).sum(dim = -1)
124
- lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
125
 
126
  # duration
127
 
@@ -130,20 +133,22 @@ class CFM(nn.Module):
130
  cond_mask = cond_mask & edit_mask
131
 
132
  if isinstance(duration, int):
133
- duration = torch.full((batch,), duration, device = device, dtype = torch.long)
134
 
135
- duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
136
- duration = duration.clamp(max = max_duration)
137
  max_duration = duration.amax()
138
-
139
  # duplicate test corner for inner time step oberservation
140
  if duplicate_test:
141
- test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)
142
-
143
- cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
144
- cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
145
  cond_mask = cond_mask.unsqueeze(-1)
146
- step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
 
 
147
 
148
  if batch > 1:
149
  mask = lens_to_mask(duration)
@@ -161,11 +166,15 @@ class CFM(nn.Module):
161
  # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
162
 
163
  # predict flow
164
- pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
 
 
165
  if cfg_strength < 1e-5:
166
  return pred
167
-
168
- null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
 
 
169
  return pred + (pred - null_pred) * cfg_strength
170
 
171
  # noise input
@@ -175,8 +184,8 @@ class CFM(nn.Module):
175
  for dur in duration:
176
  if exists(seed):
177
  torch.manual_seed(seed)
178
- y0.append(torch.randn(dur, self.num_channels, device = self.device, dtype=step_cond.dtype))
179
- y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
180
 
181
  t_start = 0
182
 
@@ -186,12 +195,12 @@ class CFM(nn.Module):
186
  y0 = (1 - t_start) * y0 + t_start * test_cond
187
  steps = int(steps * (1 - t_start))
188
 
189
- t = torch.linspace(t_start, 1, steps, device = self.device, dtype=step_cond.dtype)
190
  if sway_sampling_coef is not None:
191
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
192
 
193
  trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
194
-
195
  sampled = trajectory[-1]
196
  out = sampled
197
  out = torch.where(cond_mask, cond, out)
@@ -204,10 +213,10 @@ class CFM(nn.Module):
204
 
205
  def forward(
206
  self,
207
- inp: float['b n d'] | float['b nw'], # mel or raw wave
208
- text: int['b nt'] | list[str],
209
  *,
210
- lens: int['b'] | None = None,
211
  noise_scheduler: str | None = None,
212
  ):
213
  # handle raw wave
@@ -216,7 +225,7 @@ class CFM(nn.Module):
216
  inp = inp.permute(0, 2, 1)
217
  assert inp.shape[-1] == self.num_channels
218
 
219
- batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
220
 
221
  # handle text as string
222
  if isinstance(text, list):
@@ -228,12 +237,12 @@ class CFM(nn.Module):
228
 
229
  # lens and mask
230
  if not exists(lens):
231
- lens = torch.full((batch,), seq_len, device = device)
232
-
233
- mask = lens_to_mask(lens, length = seq_len) # useless here, as collate_fn will pad to max length in batch
234
 
235
  # get a random span to mask out for training conditionally
236
- frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
237
  rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
238
 
239
  if exists(mask):
@@ -246,7 +255,7 @@ class CFM(nn.Module):
246
  x0 = torch.randn_like(x1)
247
 
248
  # time step
249
- time = torch.rand((batch,), dtype = dtype, device = self.device)
250
  # TODO. noise_scheduler
251
 
252
  # sample xt (φ_t(x) in the paper)
@@ -255,10 +264,7 @@ class CFM(nn.Module):
255
  flow = x1 - x0
256
 
257
  # only predict what is within the random mask span for infilling
258
- cond = torch.where(
259
- rand_span_mask[..., None],
260
- torch.zeros_like(x1), x1
261
- )
262
 
263
  # transformer and cfg training with a drop rate
264
  drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
@@ -267,13 +273,15 @@ class CFM(nn.Module):
267
  drop_text = True
268
  else:
269
  drop_text = False
270
-
271
  # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
272
  # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
273
- pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)
 
 
274
 
275
  # flow matching loss
276
- loss = F.mse_loss(pred, flow, reduction = 'none')
277
  loss = loss[rand_span_mask]
278
 
279
  return loss.mean(), cond, pred
 
20
 
21
  from model.modules import MelSpec
22
  from model.utils import (
23
+ default,
24
+ exists,
25
+ list_str_to_idx,
26
+ list_str_to_tensor,
27
+ lens_to_mask,
28
+ mask_from_frac_lengths,
29
+ )
30
 
31
 
32
  class CFM(nn.Module):
33
  def __init__(
34
  self,
35
  transformer: nn.Module,
36
+ sigma=0.0,
37
  odeint_kwargs: dict = dict(
38
  # atol = 1e-5,
39
  # rtol = 1e-5,
40
+ method="euler" # 'midpoint'
41
  ),
42
+ audio_drop_prob=0.3,
43
+ cond_drop_prob=0.2,
44
+ num_channels=None,
45
  mel_spec_module: nn.Module | None = None,
46
  mel_spec_kwargs: dict = dict(),
47
+ frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
48
+ vocab_char_map: dict[str:int] | None = None,
49
  ):
50
  super().__init__()
51
 
 
81
  @torch.no_grad()
82
  def sample(
83
  self,
84
+ cond: float["b n d"] | float["b nw"], # noqa: F722
85
+ text: int["b nt"] | list[str], # noqa: F722
86
+ duration: int | int["b"], # noqa: F821
87
  *,
88
+ lens: int["b"] | None = None, # noqa: F821
89
+ steps=32,
90
+ cfg_strength=1.0,
91
+ sway_sampling_coef=None,
92
  seed: int | None = None,
93
+ max_duration=4096,
94
+ vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
95
+ no_ref_audio=False,
96
+ duplicate_test=False,
97
+ t_inter=0.1,
98
+ edit_mask=None,
99
  ):
100
  self.eval()
101
 
 
111
 
112
  batch, cond_seq_len, device = *cond.shape[:2], cond.device
113
  if not exists(lens):
114
+ lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
115
 
116
  # text
117
 
 
123
  assert text.shape[0] == batch
124
 
125
  if exists(text):
126
+ text_lens = (text != -1).sum(dim=-1)
127
+ lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
128
 
129
  # duration
130
 
 
133
  cond_mask = cond_mask & edit_mask
134
 
135
  if isinstance(duration, int):
136
+ duration = torch.full((batch,), duration, device=device, dtype=torch.long)
137
 
138
+ duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
139
+ duration = duration.clamp(max=max_duration)
140
  max_duration = duration.amax()
141
+
142
  # duplicate test corner for inner time step oberservation
143
  if duplicate_test:
144
+ test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
145
+
146
+ cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
147
+ cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
148
  cond_mask = cond_mask.unsqueeze(-1)
149
+ step_cond = torch.where(
150
+ cond_mask, cond, torch.zeros_like(cond)
151
+ ) # allow direct control (cut cond audio) with lens passed in
152
 
153
  if batch > 1:
154
  mask = lens_to_mask(duration)
 
166
  # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
167
 
168
  # predict flow
169
+ pred = self.transformer(
170
+ x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
171
+ )
172
  if cfg_strength < 1e-5:
173
  return pred
174
+
175
+ null_pred = self.transformer(
176
+ x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
177
+ )
178
  return pred + (pred - null_pred) * cfg_strength
179
 
180
  # noise input
 
184
  for dur in duration:
185
  if exists(seed):
186
  torch.manual_seed(seed)
187
+ y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
188
+ y0 = pad_sequence(y0, padding_value=0, batch_first=True)
189
 
190
  t_start = 0
191
 
 
195
  y0 = (1 - t_start) * y0 + t_start * test_cond
196
  steps = int(steps * (1 - t_start))
197
 
198
+ t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
199
  if sway_sampling_coef is not None:
200
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
201
 
202
  trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
203
+
204
  sampled = trajectory[-1]
205
  out = sampled
206
  out = torch.where(cond_mask, cond, out)
 
213
 
214
  def forward(
215
  self,
216
+ inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
217
+ text: int["b nt"] | list[str], # noqa: F722
218
  *,
219
+ lens: int["b"] | None = None, # noqa: F821
220
  noise_scheduler: str | None = None,
221
  ):
222
  # handle raw wave
 
225
  inp = inp.permute(0, 2, 1)
226
  assert inp.shape[-1] == self.num_channels
227
 
228
+ batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
229
 
230
  # handle text as string
231
  if isinstance(text, list):
 
237
 
238
  # lens and mask
239
  if not exists(lens):
240
+ lens = torch.full((batch,), seq_len, device=device)
241
+
242
+ mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
243
 
244
  # get a random span to mask out for training conditionally
245
+ frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
246
  rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
247
 
248
  if exists(mask):
 
255
  x0 = torch.randn_like(x1)
256
 
257
  # time step
258
+ time = torch.rand((batch,), dtype=dtype, device=self.device)
259
  # TODO. noise_scheduler
260
 
261
  # sample xt (φ_t(x) in the paper)
 
264
  flow = x1 - x0
265
 
266
  # only predict what is within the random mask span for infilling
267
+ cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
 
 
 
268
 
269
  # transformer and cfg training with a drop rate
270
  drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
 
273
  drop_text = True
274
  else:
275
  drop_text = False
276
+
277
  # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
278
  # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
279
+ pred = self.transformer(
280
+ x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
281
+ )
282
 
283
  # flow matching loss
284
+ loss = F.mse_loss(pred, flow, reduction="none")
285
  loss = loss[rand_span_mask]
286
 
287
  return loss.mean(), cond, pred
model/dataset.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  import torch.nn.functional as F
7
  from torch.utils.data import Dataset, Sampler
8
  import torchaudio
9
- from datasets import load_dataset, load_from_disk
10
  from datasets import Dataset as Dataset_
11
 
12
  from model.modules import MelSpec
@@ -16,53 +16,55 @@ class HFDataset(Dataset):
16
  def __init__(
17
  self,
18
  hf_dataset: Dataset,
19
- target_sample_rate = 24_000,
20
- n_mel_channels = 100,
21
- hop_length = 256,
22
  ):
23
  self.data = hf_dataset
24
  self.target_sample_rate = target_sample_rate
25
  self.hop_length = hop_length
26
- self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
27
-
 
 
28
  def get_frame_len(self, index):
29
  row = self.data[index]
30
- audio = row['audio']['array']
31
- sample_rate = row['audio']['sampling_rate']
32
  return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
33
 
34
  def __len__(self):
35
  return len(self.data)
36
-
37
  def __getitem__(self, index):
38
  row = self.data[index]
39
- audio = row['audio']['array']
40
 
41
  # logger.info(f"Audio shape: {audio.shape}")
42
 
43
- sample_rate = row['audio']['sampling_rate']
44
  duration = audio.shape[-1] / sample_rate
45
 
46
  if duration > 30 or duration < 0.3:
47
  return self.__getitem__((index + 1) % len(self.data))
48
-
49
  audio_tensor = torch.from_numpy(audio).float()
50
-
51
  if sample_rate != self.target_sample_rate:
52
  resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
53
  audio_tensor = resampler(audio_tensor)
54
-
55
  audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
56
-
57
  mel_spec = self.mel_spectrogram(audio_tensor)
58
-
59
  mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
60
-
61
- text = row['text']
62
-
63
  return dict(
64
- mel_spec = mel_spec,
65
- text = text,
66
  )
67
 
68
 
@@ -70,11 +72,11 @@ class CustomDataset(Dataset):
70
  def __init__(
71
  self,
72
  custom_dataset: Dataset,
73
- durations = None,
74
- target_sample_rate = 24_000,
75
- hop_length = 256,
76
- n_mel_channels = 100,
77
- preprocessed_mel = False,
78
  ):
79
  self.data = custom_dataset
80
  self.durations = durations
@@ -82,16 +84,20 @@ class CustomDataset(Dataset):
82
  self.hop_length = hop_length
83
  self.preprocessed_mel = preprocessed_mel
84
  if not preprocessed_mel:
85
- self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)
 
 
86
 
87
  def get_frame_len(self, index):
88
- if self.durations is not None: # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
 
 
89
  return self.durations[index] * self.target_sample_rate / self.hop_length
90
  return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
91
-
92
  def __len__(self):
93
  return len(self.data)
94
-
95
  def __getitem__(self, index):
96
  row = self.data[index]
97
  audio_path = row["audio_path"]
@@ -108,45 +114,52 @@ class CustomDataset(Dataset):
108
 
109
  if duration > 30 or duration < 0.3:
110
  return self.__getitem__((index + 1) % len(self.data))
111
-
112
  if source_sample_rate != self.target_sample_rate:
113
  resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
114
  audio = resampler(audio)
115
-
116
  mel_spec = self.mel_spectrogram(audio)
117
  mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')
118
-
119
  return dict(
120
- mel_spec = mel_spec,
121
- text = text,
122
  )
123
-
124
 
125
  # Dynamic Batch Sampler
126
 
 
127
  class DynamicBatchSampler(Sampler[list[int]]):
128
- """ Extension of Sampler that will do the following:
129
- 1. Change the batch size (essentially number of sequences)
130
- in a batch to ensure that the total number of frames are less
131
- than a certain threshold.
132
- 2. Make sure the padding efficiency in the batch is high.
133
  """
134
 
135
- def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
 
 
136
  self.sampler = sampler
137
  self.frames_threshold = frames_threshold
138
  self.max_samples = max_samples
139
 
140
  indices, batches = [], []
141
  data_source = self.sampler.data_source
142
-
143
- for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"):
 
 
144
  indices.append((idx, data_source.get_frame_len(idx)))
145
- indices.sort(key=lambda elem : elem[1])
146
 
147
  batch = []
148
  batch_frames = 0
149
- for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
 
 
150
  if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
151
  batch.append(idx)
152
  batch_frames += frame_len
@@ -182,76 +195,86 @@ class DynamicBatchSampler(Sampler[list[int]]):
182
 
183
  # Load dataset
184
 
 
185
  def load_dataset(
186
- dataset_name: str,
187
- tokenizer: str = "pinyin",
188
- dataset_type: str = "CustomDataset",
189
- audio_type: str = "raw",
190
- mel_spec_kwargs: dict = dict()
191
- ) -> CustomDataset | HFDataset:
192
- '''
193
  dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
194
  - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
195
- '''
196
-
197
  print("Loading dataset ...")
198
 
199
  if dataset_type == "CustomDataset":
200
  if audio_type == "raw":
201
  try:
202
  train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
203
- except:
204
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
205
  preprocessed_mel = False
206
  elif audio_type == "mel":
207
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
208
  preprocessed_mel = True
209
- with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
210
  data_dict = json.load(f)
211
  durations = data_dict["duration"]
212
- train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
213
-
 
 
214
  elif dataset_type == "CustomDatasetPath":
215
  try:
216
  train_dataset = load_from_disk(f"{dataset_name}/raw")
217
- except:
218
  train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
219
-
220
- with open(f"{dataset_name}/duration.json", 'r', encoding='utf-8') as f:
221
  data_dict = json.load(f)
222
  durations = data_dict["duration"]
223
- train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
224
-
 
 
225
  elif dataset_type == "HFDataset":
226
- print("Should manually modify the path of huggingface dataset to your need.\n" +
227
- "May also the corresponding script cuz different dataset may have different format.")
 
 
228
  pre, post = dataset_name.split("_")
229
- train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)
 
 
230
 
231
  return train_dataset
232
 
233
 
234
  # collation
235
 
 
236
  def collate_fn(batch):
237
- mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
238
  mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
239
  max_mel_length = mel_lengths.amax()
240
 
241
  padded_mel_specs = []
242
  for spec in mel_specs: # TODO. maybe records mask for attention here
243
  padding = (0, max_mel_length - spec.size(-1))
244
- padded_spec = F.pad(spec, padding, value = 0)
245
  padded_mel_specs.append(padded_spec)
246
-
247
  mel_specs = torch.stack(padded_mel_specs)
248
 
249
- text = [item['text'] for item in batch]
250
  text_lengths = torch.LongTensor([len(item) for item in text])
251
 
252
  return dict(
253
- mel = mel_specs,
254
- mel_lengths = mel_lengths,
255
- text = text,
256
- text_lengths = text_lengths,
257
  )
 
6
  import torch.nn.functional as F
7
  from torch.utils.data import Dataset, Sampler
8
  import torchaudio
9
+ from datasets import load_from_disk
10
  from datasets import Dataset as Dataset_
11
 
12
  from model.modules import MelSpec
 
16
  def __init__(
17
  self,
18
  hf_dataset: Dataset,
19
+ target_sample_rate=24_000,
20
+ n_mel_channels=100,
21
+ hop_length=256,
22
  ):
23
  self.data = hf_dataset
24
  self.target_sample_rate = target_sample_rate
25
  self.hop_length = hop_length
26
+ self.mel_spectrogram = MelSpec(
27
+ target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
28
+ )
29
+
30
  def get_frame_len(self, index):
31
  row = self.data[index]
32
+ audio = row["audio"]["array"]
33
+ sample_rate = row["audio"]["sampling_rate"]
34
  return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
35
 
36
  def __len__(self):
37
  return len(self.data)
38
+
39
  def __getitem__(self, index):
40
  row = self.data[index]
41
+ audio = row["audio"]["array"]
42
 
43
  # logger.info(f"Audio shape: {audio.shape}")
44
 
45
+ sample_rate = row["audio"]["sampling_rate"]
46
  duration = audio.shape[-1] / sample_rate
47
 
48
  if duration > 30 or duration < 0.3:
49
  return self.__getitem__((index + 1) % len(self.data))
50
+
51
  audio_tensor = torch.from_numpy(audio).float()
52
+
53
  if sample_rate != self.target_sample_rate:
54
  resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
55
  audio_tensor = resampler(audio_tensor)
56
+
57
  audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
58
+
59
  mel_spec = self.mel_spectrogram(audio_tensor)
60
+
61
  mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
62
+
63
+ text = row["text"]
64
+
65
  return dict(
66
+ mel_spec=mel_spec,
67
+ text=text,
68
  )
69
 
70
 
 
72
  def __init__(
73
  self,
74
  custom_dataset: Dataset,
75
+ durations=None,
76
+ target_sample_rate=24_000,
77
+ hop_length=256,
78
+ n_mel_channels=100,
79
+ preprocessed_mel=False,
80
  ):
81
  self.data = custom_dataset
82
  self.durations = durations
 
84
  self.hop_length = hop_length
85
  self.preprocessed_mel = preprocessed_mel
86
  if not preprocessed_mel:
87
+ self.mel_spectrogram = MelSpec(
88
+ target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels
89
+ )
90
 
91
  def get_frame_len(self, index):
92
+ if (
93
+ self.durations is not None
94
+ ): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
95
  return self.durations[index] * self.target_sample_rate / self.hop_length
96
  return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
97
+
98
  def __len__(self):
99
  return len(self.data)
100
+
101
  def __getitem__(self, index):
102
  row = self.data[index]
103
  audio_path = row["audio_path"]
 
114
 
115
  if duration > 30 or duration < 0.3:
116
  return self.__getitem__((index + 1) % len(self.data))
117
+
118
  if source_sample_rate != self.target_sample_rate:
119
  resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
120
  audio = resampler(audio)
121
+
122
  mel_spec = self.mel_spectrogram(audio)
123
  mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')
124
+
125
  return dict(
126
+ mel_spec=mel_spec,
127
+ text=text,
128
  )
129
+
130
 
131
  # Dynamic Batch Sampler
132
 
133
+
134
  class DynamicBatchSampler(Sampler[list[int]]):
135
+ """Extension of Sampler that will do the following:
136
+ 1. Change the batch size (essentially number of sequences)
137
+ in a batch to ensure that the total number of frames are less
138
+ than a certain threshold.
139
+ 2. Make sure the padding efficiency in the batch is high.
140
  """
141
 
142
+ def __init__(
143
+ self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
144
+ ):
145
  self.sampler = sampler
146
  self.frames_threshold = frames_threshold
147
  self.max_samples = max_samples
148
 
149
  indices, batches = [], []
150
  data_source = self.sampler.data_source
151
+
152
+ for idx in tqdm(
153
+ self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
154
+ ):
155
  indices.append((idx, data_source.get_frame_len(idx)))
156
+ indices.sort(key=lambda elem: elem[1])
157
 
158
  batch = []
159
  batch_frames = 0
160
+ for idx, frame_len in tqdm(
161
+ indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
162
+ ):
163
  if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
164
  batch.append(idx)
165
  batch_frames += frame_len
 
195
 
196
  # Load dataset
197
 
198
+
199
  def load_dataset(
200
+ dataset_name: str,
201
+ tokenizer: str = "pinyin",
202
+ dataset_type: str = "CustomDataset",
203
+ audio_type: str = "raw",
204
+ mel_spec_kwargs: dict = dict(),
205
+ ) -> CustomDataset | HFDataset:
206
+ """
207
  dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
208
  - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
209
+ """
210
+
211
  print("Loading dataset ...")
212
 
213
  if dataset_type == "CustomDataset":
214
  if audio_type == "raw":
215
  try:
216
  train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
217
+ except: # noqa: E722
218
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
219
  preprocessed_mel = False
220
  elif audio_type == "mel":
221
  train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
222
  preprocessed_mel = True
223
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", "r", encoding="utf-8") as f:
224
  data_dict = json.load(f)
225
  durations = data_dict["duration"]
226
+ train_dataset = CustomDataset(
227
+ train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
228
+ )
229
+
230
  elif dataset_type == "CustomDatasetPath":
231
  try:
232
  train_dataset = load_from_disk(f"{dataset_name}/raw")
233
+ except: # noqa: E722
234
  train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
235
+
236
+ with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f:
237
  data_dict = json.load(f)
238
  durations = data_dict["duration"]
239
+ train_dataset = CustomDataset(
240
+ train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
241
+ )
242
+
243
  elif dataset_type == "HFDataset":
244
+ print(
245
+ "Should manually modify the path of huggingface dataset to your need.\n"
246
+ + "May also the corresponding script cuz different dataset may have different format."
247
+ )
248
  pre, post = dataset_name.split("_")
249
+ train_dataset = HFDataset(
250
+ load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),
251
+ )
252
 
253
  return train_dataset
254
 
255
 
256
  # collation
257
 
258
+
259
  def collate_fn(batch):
260
+ mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
261
  mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
262
  max_mel_length = mel_lengths.amax()
263
 
264
  padded_mel_specs = []
265
  for spec in mel_specs: # TODO. maybe records mask for attention here
266
  padding = (0, max_mel_length - spec.size(-1))
267
+ padded_spec = F.pad(spec, padding, value=0)
268
  padded_mel_specs.append(padded_spec)
269
+
270
  mel_specs = torch.stack(padded_mel_specs)
271
 
272
+ text = [item["text"] for item in batch]
273
  text_lengths = torch.LongTensor([len(item) for item in text])
274
 
275
  return dict(
276
+ mel=mel_specs,
277
+ mel_lengths=mel_lengths,
278
+ text=text,
279
+ text_lengths=text_lengths,
280
  )
model/ecapa_tdnn.py CHANGED
@@ -9,13 +9,14 @@ import torch.nn as nn
9
  import torch.nn.functional as F
10
 
11
 
12
- ''' Res2Conv1d + BatchNorm1d + ReLU
13
- '''
 
14
 
15
  class Res2Conv1dReluBn(nn.Module):
16
- '''
17
  in_channels == out_channels == channels
18
- '''
19
 
20
  def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
21
  super().__init__()
@@ -51,8 +52,9 @@ class Res2Conv1dReluBn(nn.Module):
51
  return out
52
 
53
 
54
- ''' Conv1d + BatchNorm1d + ReLU
55
- '''
 
56
 
57
  class Conv1dReluBn(nn.Module):
58
  def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
@@ -64,8 +66,9 @@ class Conv1dReluBn(nn.Module):
64
  return self.bn(F.relu(self.conv(x)))
65
 
66
 
67
- ''' The SE connection of 1D case.
68
- '''
 
69
 
70
  class SE_Connect(nn.Module):
71
  def __init__(self, channels, se_bottleneck_dim=128):
@@ -82,8 +85,8 @@ class SE_Connect(nn.Module):
82
  return out
83
 
84
 
85
- ''' SE-Res2Block of the ECAPA-TDNN architecture.
86
- '''
87
 
88
  # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
89
  # return nn.Sequential(
@@ -93,6 +96,7 @@ class SE_Connect(nn.Module):
93
  # SE_Connect(channels)
94
  # )
95
 
 
96
  class SE_Res2Block(nn.Module):
97
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
98
  super().__init__()
@@ -122,8 +126,9 @@ class SE_Res2Block(nn.Module):
122
  return x + residual
123
 
124
 
125
- ''' Attentive weighted mean and standard deviation pooling.
126
- '''
 
127
 
128
  class AttentiveStatsPool(nn.Module):
129
  def __init__(self, in_dim, attention_channels=128, global_context_att=False):
@@ -138,7 +143,6 @@ class AttentiveStatsPool(nn.Module):
138
  self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
139
 
140
  def forward(self, x):
141
-
142
  if self.global_context_att:
143
  context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
144
  context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
@@ -151,38 +155,52 @@ class AttentiveStatsPool(nn.Module):
151
  # alpha = F.relu(self.linear1(x_in))
152
  alpha = torch.softmax(self.linear2(alpha), dim=2)
153
  mean = torch.sum(alpha * x, dim=2)
154
- residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
155
  std = torch.sqrt(residuals.clamp(min=1e-9))
156
  return torch.cat([mean, std], dim=1)
157
 
158
 
159
  class ECAPA_TDNN(nn.Module):
160
- def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
161
- feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
 
 
 
 
 
 
 
 
 
 
162
  super().__init__()
163
 
164
  self.feat_type = feat_type
165
  self.feature_selection = feature_selection
166
  self.update_extract = update_extract
167
  self.sr = sr
168
-
169
- torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
170
  try:
171
  local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
172
- self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
173
- except:
174
- self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
175
 
176
- if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
 
 
177
  self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
178
- if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
 
 
179
  self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
180
 
181
  self.feat_num = self.get_feat_num()
182
  self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
183
 
184
- if feat_type != 'fbank' and feat_type != 'mfcc':
185
- freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
186
  for name, param in self.feature_extract.named_parameters():
187
  for freeze_val in freeze_list:
188
  if freeze_val in name:
@@ -198,18 +216,46 @@ class ECAPA_TDNN(nn.Module):
198
  self.channels = [channels] * 4 + [1536]
199
 
200
  self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
201
- self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
202
- self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
203
- self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
206
  cat_channels = channels * 3
207
  self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
208
- self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
 
 
209
  self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
210
  self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
211
 
212
-
213
  def get_feat_num(self):
214
  self.feature_extract.eval()
215
  wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
@@ -226,12 +272,12 @@ class ECAPA_TDNN(nn.Module):
226
  x = self.feature_extract([sample for sample in x])
227
  else:
228
  with torch.no_grad():
229
- if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
230
  x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
231
  else:
232
  x = self.feature_extract([sample for sample in x])
233
 
234
- if self.feat_type == 'fbank':
235
  x = x.log()
236
 
237
  if self.feat_type != "fbank" and self.feat_type != "mfcc":
@@ -263,6 +309,22 @@ class ECAPA_TDNN(nn.Module):
263
  return out
264
 
265
 
266
- def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
267
- return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
268
- feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import torch.nn.functional as F
10
 
11
 
12
+ """ Res2Conv1d + BatchNorm1d + ReLU
13
+ """
14
+
15
 
16
  class Res2Conv1dReluBn(nn.Module):
17
+ """
18
  in_channels == out_channels == channels
19
+ """
20
 
21
  def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
22
  super().__init__()
 
52
  return out
53
 
54
 
55
+ """ Conv1d + BatchNorm1d + ReLU
56
+ """
57
+
58
 
59
  class Conv1dReluBn(nn.Module):
60
  def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
 
66
  return self.bn(F.relu(self.conv(x)))
67
 
68
 
69
+ """ The SE connection of 1D case.
70
+ """
71
+
72
 
73
  class SE_Connect(nn.Module):
74
  def __init__(self, channels, se_bottleneck_dim=128):
 
85
  return out
86
 
87
 
88
+ """ SE-Res2Block of the ECAPA-TDNN architecture.
89
+ """
90
 
91
  # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
92
  # return nn.Sequential(
 
96
  # SE_Connect(channels)
97
  # )
98
 
99
+
100
  class SE_Res2Block(nn.Module):
101
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
102
  super().__init__()
 
126
  return x + residual
127
 
128
 
129
+ """ Attentive weighted mean and standard deviation pooling.
130
+ """
131
+
132
 
133
  class AttentiveStatsPool(nn.Module):
134
  def __init__(self, in_dim, attention_channels=128, global_context_att=False):
 
143
  self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
144
 
145
  def forward(self, x):
 
146
  if self.global_context_att:
147
  context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
148
  context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
 
155
  # alpha = F.relu(self.linear1(x_in))
156
  alpha = torch.softmax(self.linear2(alpha), dim=2)
157
  mean = torch.sum(alpha * x, dim=2)
158
+ residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
159
  std = torch.sqrt(residuals.clamp(min=1e-9))
160
  return torch.cat([mean, std], dim=1)
161
 
162
 
163
  class ECAPA_TDNN(nn.Module):
164
+ def __init__(
165
+ self,
166
+ feat_dim=80,
167
+ channels=512,
168
+ emb_dim=192,
169
+ global_context_att=False,
170
+ feat_type="wavlm_large",
171
+ sr=16000,
172
+ feature_selection="hidden_states",
173
+ update_extract=False,
174
+ config_path=None,
175
+ ):
176
  super().__init__()
177
 
178
  self.feat_type = feat_type
179
  self.feature_selection = feature_selection
180
  self.update_extract = update_extract
181
  self.sr = sr
182
+
183
+ torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
184
  try:
185
  local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
186
+ self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
187
+ except: # noqa: E722
188
+ self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
189
 
190
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
191
+ self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
192
+ ):
193
  self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
194
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
195
+ self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
196
+ ):
197
  self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
198
 
199
  self.feat_num = self.get_feat_num()
200
  self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
201
 
202
+ if feat_type != "fbank" and feat_type != "mfcc":
203
+ freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
204
  for name, param in self.feature_extract.named_parameters():
205
  for freeze_val in freeze_list:
206
  if freeze_val in name:
 
216
  self.channels = [channels] * 4 + [1536]
217
 
218
  self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
219
+ self.layer2 = SE_Res2Block(
220
+ self.channels[0],
221
+ self.channels[1],
222
+ kernel_size=3,
223
+ stride=1,
224
+ padding=2,
225
+ dilation=2,
226
+ scale=8,
227
+ se_bottleneck_dim=128,
228
+ )
229
+ self.layer3 = SE_Res2Block(
230
+ self.channels[1],
231
+ self.channels[2],
232
+ kernel_size=3,
233
+ stride=1,
234
+ padding=3,
235
+ dilation=3,
236
+ scale=8,
237
+ se_bottleneck_dim=128,
238
+ )
239
+ self.layer4 = SE_Res2Block(
240
+ self.channels[2],
241
+ self.channels[3],
242
+ kernel_size=3,
243
+ stride=1,
244
+ padding=4,
245
+ dilation=4,
246
+ scale=8,
247
+ se_bottleneck_dim=128,
248
+ )
249
 
250
  # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
251
  cat_channels = channels * 3
252
  self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
253
+ self.pooling = AttentiveStatsPool(
254
+ self.channels[-1], attention_channels=128, global_context_att=global_context_att
255
+ )
256
  self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
257
  self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
258
 
 
259
  def get_feat_num(self):
260
  self.feature_extract.eval()
261
  wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
 
272
  x = self.feature_extract([sample for sample in x])
273
  else:
274
  with torch.no_grad():
275
+ if self.feat_type == "fbank" or self.feat_type == "mfcc":
276
  x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
277
  else:
278
  x = self.feature_extract([sample for sample in x])
279
 
280
+ if self.feat_type == "fbank":
281
  x = x.log()
282
 
283
  if self.feat_type != "fbank" and self.feat_type != "mfcc":
 
309
  return out
310
 
311
 
312
+ def ECAPA_TDNN_SMALL(
313
+ feat_dim,
314
+ emb_dim=256,
315
+ feat_type="wavlm_large",
316
+ sr=16000,
317
+ feature_selection="hidden_states",
318
+ update_extract=False,
319
+ config_path=None,
320
+ ):
321
+ return ECAPA_TDNN(
322
+ feat_dim=feat_dim,
323
+ channels=512,
324
+ emb_dim=emb_dim,
325
+ feat_type=feat_type,
326
+ sr=sr,
327
+ feature_selection=feature_selection,
328
+ update_extract=update_extract,
329
+ config_path=config_path,
330
+ )
model/modules.py CHANGED
@@ -21,39 +21,40 @@ from x_transformers.x_transformers import apply_rotary_pos_emb
21
 
22
  # raw wav to mel spec
23
 
 
24
  class MelSpec(nn.Module):
25
  def __init__(
26
  self,
27
- filter_length = 1024,
28
- hop_length = 256,
29
- win_length = 1024,
30
- n_mel_channels = 100,
31
- target_sample_rate = 24_000,
32
- normalize = False,
33
- power = 1,
34
- norm = None,
35
- center = True,
36
  ):
37
  super().__init__()
38
  self.n_mel_channels = n_mel_channels
39
 
40
  self.mel_stft = torchaudio.transforms.MelSpectrogram(
41
- sample_rate = target_sample_rate,
42
- n_fft = filter_length,
43
- win_length = win_length,
44
- hop_length = hop_length,
45
- n_mels = n_mel_channels,
46
- power = power,
47
- center = center,
48
- normalized = normalize,
49
- norm = norm,
50
  )
51
 
52
- self.register_buffer('dummy', torch.tensor(0), persistent = False)
53
 
54
  def forward(self, inp):
55
  if len(inp.shape) == 3:
56
- inp = inp.squeeze(1) # 'b 1 nw -> b nw'
57
 
58
  assert len(inp.shape) == 2
59
 
@@ -61,12 +62,13 @@ class MelSpec(nn.Module):
61
  self.to(inp.device)
62
 
63
  mel = self.mel_stft(inp)
64
- mel = mel.clamp(min = 1e-5).log()
65
  return mel
66
-
67
 
68
  # sinusoidal position embedding
69
 
 
70
  class SinusPositionEmbedding(nn.Module):
71
  def __init__(self, dim):
72
  super().__init__()
@@ -84,35 +86,37 @@ class SinusPositionEmbedding(nn.Module):
84
 
85
  # convolutional position embedding
86
 
 
87
  class ConvPositionEmbedding(nn.Module):
88
- def __init__(self, dim, kernel_size = 31, groups = 16):
89
  super().__init__()
90
  assert kernel_size % 2 != 0
91
  self.conv1d = nn.Sequential(
92
- nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
93
  nn.Mish(),
94
- nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
95
  nn.Mish(),
96
  )
97
 
98
- def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
99
  if mask is not None:
100
  mask = mask[..., None]
101
- x = x.masked_fill(~mask, 0.)
102
 
103
  x = x.permute(0, 2, 1)
104
  x = self.conv1d(x)
105
  out = x.permute(0, 2, 1)
106
 
107
  if mask is not None:
108
- out = out.masked_fill(~mask, 0.)
109
 
110
  return out
111
 
112
 
113
  # rotary positional embedding related
114
 
115
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
 
116
  # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
117
  # has some connection to NTK literature
118
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
@@ -125,12 +129,14 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
125
  freqs_sin = torch.sin(freqs) # imaginary part
126
  return torch.cat([freqs_cos, freqs_sin], dim=-1)
127
 
128
- def get_pos_embed_indices(start, length, max_pos, scale=1.):
 
129
  # length = length if isinstance(length, int) else length.max()
130
  scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
131
- pos = start.unsqueeze(1) + (
132
- torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
133
- scale.unsqueeze(1)).long()
 
134
  # avoid extra long error.
135
  pos = torch.where(pos < max_pos, pos, max_pos - 1)
136
  return pos
@@ -138,6 +144,7 @@ def get_pos_embed_indices(start, length, max_pos, scale=1.):
138
 
139
  # Global Response Normalization layer (Instance Normalization ?)
140
 
 
141
  class GRN(nn.Module):
142
  def __init__(self, dim):
143
  super().__init__()
@@ -153,6 +160,7 @@ class GRN(nn.Module):
153
  # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
154
  # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
155
 
 
156
  class ConvNeXtV2Block(nn.Module):
157
  def __init__(
158
  self,
@@ -162,7 +170,9 @@ class ConvNeXtV2Block(nn.Module):
162
  ):
163
  super().__init__()
164
  padding = (dilation * (7 - 1)) // 2
165
- self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv
 
 
166
  self.norm = nn.LayerNorm(dim, eps=1e-6)
167
  self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
168
  self.act = nn.GELU()
@@ -185,6 +195,7 @@ class ConvNeXtV2Block(nn.Module):
185
  # AdaLayerNormZero
186
  # return with modulated x for attn input, and params for later mlp modulation
187
 
 
188
  class AdaLayerNormZero(nn.Module):
189
  def __init__(self, dim):
190
  super().__init__()
@@ -194,7 +205,7 @@ class AdaLayerNormZero(nn.Module):
194
 
195
  self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
196
 
197
- def forward(self, x, emb = None):
198
  emb = self.linear(self.silu(emb))
199
  shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
200
 
@@ -205,6 +216,7 @@ class AdaLayerNormZero(nn.Module):
205
  # AdaLayerNormZero for final layer
206
  # return only with modulated x for attn input, cuz no more mlp modulation
207
 
 
208
  class AdaLayerNormZero_Final(nn.Module):
209
  def __init__(self, dim):
210
  super().__init__()
@@ -224,22 +236,16 @@ class AdaLayerNormZero_Final(nn.Module):
224
 
225
  # FeedForward
226
 
 
227
  class FeedForward(nn.Module):
228
- def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
229
  super().__init__()
230
  inner_dim = int(dim * mult)
231
  dim_out = dim_out if dim_out is not None else dim
232
 
233
  activation = nn.GELU(approximate=approximate)
234
- project_in = nn.Sequential(
235
- nn.Linear(dim, inner_dim),
236
- activation
237
- )
238
- self.ff = nn.Sequential(
239
- project_in,
240
- nn.Dropout(dropout),
241
- nn.Linear(inner_dim, dim_out)
242
- )
243
 
244
  def forward(self, x):
245
  return self.ff(x)
@@ -248,6 +254,7 @@ class FeedForward(nn.Module):
248
  # Attention with possible joint part
249
  # modified from diffusers/src/diffusers/models/attention_processor.py
250
 
 
251
  class Attention(nn.Module):
252
  def __init__(
253
  self,
@@ -256,8 +263,8 @@ class Attention(nn.Module):
256
  heads: int = 8,
257
  dim_head: int = 64,
258
  dropout: float = 0.0,
259
- context_dim: Optional[int] = None, # if not None -> joint attention
260
- context_pre_only = None,
261
  ):
262
  super().__init__()
263
 
@@ -293,20 +300,21 @@ class Attention(nn.Module):
293
 
294
  def forward(
295
  self,
296
- x: float['b n d'], # noised input x
297
- c: float['b n d'] = None, # context c
298
- mask: bool['b n'] | None = None,
299
- rope = None, # rotary position embedding for x
300
- c_rope = None, # rotary position embedding for c
301
  ) -> torch.Tensor:
302
  if c is not None:
303
- return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
304
  else:
305
- return self.processor(self, x, mask = mask, rope = rope)
306
 
307
 
308
  # Attention processor
309
 
 
310
  class AttnProcessor:
311
  def __init__(self):
312
  pass
@@ -314,11 +322,10 @@ class AttnProcessor:
314
  def __call__(
315
  self,
316
  attn: Attention,
317
- x: float['b n d'], # noised input x
318
- mask: bool['b n'] | None = None,
319
- rope = None, # rotary position embedding
320
  ) -> torch.FloatTensor:
321
-
322
  batch_size = x.shape[0]
323
 
324
  # `sample` projections.
@@ -329,7 +336,7 @@ class AttnProcessor:
329
  # apply rotary position embedding
330
  if rope is not None:
331
  freqs, xpos_scale = rope
332
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
333
 
334
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
335
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
@@ -360,14 +367,15 @@ class AttnProcessor:
360
 
361
  if mask is not None:
362
  mask = mask.unsqueeze(-1)
363
- x = x.masked_fill(~mask, 0.)
364
 
365
  return x
366
-
367
 
368
  # Joint Attention processor for MM-DiT
369
  # modified from diffusers/src/diffusers/models/attention_processor.py
370
 
 
371
  class JointAttnProcessor:
372
  def __init__(self):
373
  pass
@@ -375,11 +383,11 @@ class JointAttnProcessor:
375
  def __call__(
376
  self,
377
  attn: Attention,
378
- x: float['b n d'], # noised input x
379
- c: float['b nt d'] = None, # context c, here text
380
- mask: bool['b n'] | None = None,
381
- rope = None, # rotary position embedding for x
382
- c_rope = None, # rotary position embedding for c
383
  ) -> torch.FloatTensor:
384
  residual = x
385
 
@@ -398,12 +406,12 @@ class JointAttnProcessor:
398
  # apply rope for context and noised input independently
399
  if rope is not None:
400
  freqs, xpos_scale = rope
401
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
402
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
403
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
404
  if c_rope is not None:
405
  freqs, xpos_scale = c_rope
406
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
407
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
408
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
409
 
@@ -420,7 +428,7 @@ class JointAttnProcessor:
420
 
421
  # mask. e.g. inference got a batch with different target durations, mask out the padding
422
  if mask is not None:
423
- attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text)
424
  attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
425
  attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
426
  else:
@@ -432,8 +440,8 @@ class JointAttnProcessor:
432
 
433
  # Split the attention outputs.
434
  x, c = (
435
- x[:, :residual.shape[1]],
436
- x[:, residual.shape[1]:],
437
  )
438
 
439
  # linear proj
@@ -445,7 +453,7 @@ class JointAttnProcessor:
445
 
446
  if mask is not None:
447
  mask = mask.unsqueeze(-1)
448
- x = x.masked_fill(~mask, 0.)
449
  # c = c.masked_fill(~mask, 0.) # no mask for c (text)
450
 
451
  return x, c
@@ -453,24 +461,24 @@ class JointAttnProcessor:
453
 
454
  # DiT Block
455
 
456
- class DiTBlock(nn.Module):
457
 
458
- def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
 
459
  super().__init__()
460
-
461
  self.attn_norm = AdaLayerNormZero(dim)
462
  self.attn = Attention(
463
- processor = AttnProcessor(),
464
- dim = dim,
465
- heads = heads,
466
- dim_head = dim_head,
467
- dropout = dropout,
468
- )
469
-
470
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
471
- self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
472
 
473
- def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
474
  # pre-norm & modulation for attention input
475
  norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
476
 
@@ -479,7 +487,7 @@ class DiTBlock(nn.Module):
479
 
480
  # process attention output for input x
481
  x = x + gate_msa.unsqueeze(1) * attn_output
482
-
483
  norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
484
  ff_output = self.ff(norm)
485
  x = x + gate_mlp.unsqueeze(1) * ff_output
@@ -489,8 +497,9 @@ class DiTBlock(nn.Module):
489
 
490
  # MMDiT Block https://arxiv.org/abs/2403.03206
491
 
 
492
  class MMDiTBlock(nn.Module):
493
- r"""
494
  modified from diffusers/src/diffusers/models/attention.py
495
 
496
  notes.
@@ -499,33 +508,33 @@ class MMDiTBlock(nn.Module):
499
  context_pre_only: last layer only do prenorm + modulation cuz no more ffn
500
  """
501
 
502
- def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
503
  super().__init__()
504
 
505
  self.context_pre_only = context_pre_only
506
-
507
  self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
508
  self.attn_norm_x = AdaLayerNormZero(dim)
509
  self.attn = Attention(
510
- processor = JointAttnProcessor(),
511
- dim = dim,
512
- heads = heads,
513
- dim_head = dim_head,
514
- dropout = dropout,
515
- context_dim = dim,
516
- context_pre_only = context_pre_only,
517
- )
518
 
519
  if not context_pre_only:
520
  self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
521
- self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
522
  else:
523
  self.ff_norm_c = None
524
  self.ff_c = None
525
  self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
526
- self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
527
 
528
- def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
529
  # pre-norm & modulation for attention input
530
  if self.context_pre_only:
531
  norm_c = self.attn_norm_c(c, t)
@@ -539,7 +548,7 @@ class MMDiTBlock(nn.Module):
539
  # process attention output for context c
540
  if self.context_pre_only:
541
  c = None
542
- else: # if not last layer
543
  c = c + c_gate_msa.unsqueeze(1) * c_attn_output
544
 
545
  norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
@@ -548,7 +557,7 @@ class MMDiTBlock(nn.Module):
548
 
549
  # process attention output for input x
550
  x = x + x_gate_msa.unsqueeze(1) * x_attn_output
551
-
552
  norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
553
  x_ff_output = self.ff_x(norm_x)
554
  x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
@@ -558,17 +567,14 @@ class MMDiTBlock(nn.Module):
558
 
559
  # time step conditioning embedding
560
 
 
561
  class TimestepEmbedding(nn.Module):
562
  def __init__(self, dim, freq_embed_dim=256):
563
  super().__init__()
564
  self.time_embed = SinusPositionEmbedding(freq_embed_dim)
565
- self.time_mlp = nn.Sequential(
566
- nn.Linear(freq_embed_dim, dim),
567
- nn.SiLU(),
568
- nn.Linear(dim, dim)
569
- )
570
 
571
- def forward(self, timestep: float['b']):
572
  time_hidden = self.time_embed(timestep)
573
  time_hidden = time_hidden.to(timestep.dtype)
574
  time = self.time_mlp(time_hidden) # b d
 
21
 
22
  # raw wav to mel spec
23
 
24
+
25
  class MelSpec(nn.Module):
26
  def __init__(
27
  self,
28
+ filter_length=1024,
29
+ hop_length=256,
30
+ win_length=1024,
31
+ n_mel_channels=100,
32
+ target_sample_rate=24_000,
33
+ normalize=False,
34
+ power=1,
35
+ norm=None,
36
+ center=True,
37
  ):
38
  super().__init__()
39
  self.n_mel_channels = n_mel_channels
40
 
41
  self.mel_stft = torchaudio.transforms.MelSpectrogram(
42
+ sample_rate=target_sample_rate,
43
+ n_fft=filter_length,
44
+ win_length=win_length,
45
+ hop_length=hop_length,
46
+ n_mels=n_mel_channels,
47
+ power=power,
48
+ center=center,
49
+ normalized=normalize,
50
+ norm=norm,
51
  )
52
 
53
+ self.register_buffer("dummy", torch.tensor(0), persistent=False)
54
 
55
  def forward(self, inp):
56
  if len(inp.shape) == 3:
57
+ inp = inp.squeeze(1) # 'b 1 nw -> b nw'
58
 
59
  assert len(inp.shape) == 2
60
 
 
62
  self.to(inp.device)
63
 
64
  mel = self.mel_stft(inp)
65
+ mel = mel.clamp(min=1e-5).log()
66
  return mel
67
+
68
 
69
  # sinusoidal position embedding
70
 
71
+
72
  class SinusPositionEmbedding(nn.Module):
73
  def __init__(self, dim):
74
  super().__init__()
 
86
 
87
  # convolutional position embedding
88
 
89
+
90
  class ConvPositionEmbedding(nn.Module):
91
+ def __init__(self, dim, kernel_size=31, groups=16):
92
  super().__init__()
93
  assert kernel_size % 2 != 0
94
  self.conv1d = nn.Sequential(
95
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
96
  nn.Mish(),
97
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
98
  nn.Mish(),
99
  )
100
 
101
+ def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
102
  if mask is not None:
103
  mask = mask[..., None]
104
+ x = x.masked_fill(~mask, 0.0)
105
 
106
  x = x.permute(0, 2, 1)
107
  x = self.conv1d(x)
108
  out = x.permute(0, 2, 1)
109
 
110
  if mask is not None:
111
+ out = out.masked_fill(~mask, 0.0)
112
 
113
  return out
114
 
115
 
116
  # rotary positional embedding related
117
 
118
+
119
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
120
  # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
121
  # has some connection to NTK literature
122
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
 
129
  freqs_sin = torch.sin(freqs) # imaginary part
130
  return torch.cat([freqs_cos, freqs_sin], dim=-1)
131
 
132
+
133
+ def get_pos_embed_indices(start, length, max_pos, scale=1.0):
134
  # length = length if isinstance(length, int) else length.max()
135
  scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
136
+ pos = (
137
+ start.unsqueeze(1)
138
+ + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
139
+ )
140
  # avoid extra long error.
141
  pos = torch.where(pos < max_pos, pos, max_pos - 1)
142
  return pos
 
144
 
145
  # Global Response Normalization layer (Instance Normalization ?)
146
 
147
+
148
  class GRN(nn.Module):
149
  def __init__(self, dim):
150
  super().__init__()
 
160
  # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
161
  # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
162
 
163
+
164
  class ConvNeXtV2Block(nn.Module):
165
  def __init__(
166
  self,
 
170
  ):
171
  super().__init__()
172
  padding = (dilation * (7 - 1)) // 2
173
+ self.dwconv = nn.Conv1d(
174
+ dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
175
+ ) # depthwise conv
176
  self.norm = nn.LayerNorm(dim, eps=1e-6)
177
  self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
178
  self.act = nn.GELU()
 
195
  # AdaLayerNormZero
196
  # return with modulated x for attn input, and params for later mlp modulation
197
 
198
+
199
  class AdaLayerNormZero(nn.Module):
200
  def __init__(self, dim):
201
  super().__init__()
 
205
 
206
  self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
207
 
208
+ def forward(self, x, emb=None):
209
  emb = self.linear(self.silu(emb))
210
  shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
211
 
 
216
  # AdaLayerNormZero for final layer
217
  # return only with modulated x for attn input, cuz no more mlp modulation
218
 
219
+
220
  class AdaLayerNormZero_Final(nn.Module):
221
  def __init__(self, dim):
222
  super().__init__()
 
236
 
237
  # FeedForward
238
 
239
+
240
  class FeedForward(nn.Module):
241
+ def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
242
  super().__init__()
243
  inner_dim = int(dim * mult)
244
  dim_out = dim_out if dim_out is not None else dim
245
 
246
  activation = nn.GELU(approximate=approximate)
247
+ project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
248
+ self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
 
 
 
 
 
 
249
 
250
  def forward(self, x):
251
  return self.ff(x)
 
254
  # Attention with possible joint part
255
  # modified from diffusers/src/diffusers/models/attention_processor.py
256
 
257
+
258
  class Attention(nn.Module):
259
  def __init__(
260
  self,
 
263
  heads: int = 8,
264
  dim_head: int = 64,
265
  dropout: float = 0.0,
266
+ context_dim: Optional[int] = None, # if not None -> joint attention
267
+ context_pre_only=None,
268
  ):
269
  super().__init__()
270
 
 
300
 
301
  def forward(
302
  self,
303
+ x: float["b n d"], # noised input x # noqa: F722
304
+ c: float["b n d"] = None, # context c # noqa: F722
305
+ mask: bool["b n"] | None = None, # noqa: F722
306
+ rope=None, # rotary position embedding for x
307
+ c_rope=None, # rotary position embedding for c
308
  ) -> torch.Tensor:
309
  if c is not None:
310
+ return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
311
  else:
312
+ return self.processor(self, x, mask=mask, rope=rope)
313
 
314
 
315
  # Attention processor
316
 
317
+
318
  class AttnProcessor:
319
  def __init__(self):
320
  pass
 
322
  def __call__(
323
  self,
324
  attn: Attention,
325
+ x: float["b n d"], # noised input x # noqa: F722
326
+ mask: bool["b n"] | None = None, # noqa: F722
327
+ rope=None, # rotary position embedding
328
  ) -> torch.FloatTensor:
 
329
  batch_size = x.shape[0]
330
 
331
  # `sample` projections.
 
336
  # apply rotary position embedding
337
  if rope is not None:
338
  freqs, xpos_scale = rope
339
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
340
 
341
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
342
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
 
367
 
368
  if mask is not None:
369
  mask = mask.unsqueeze(-1)
370
+ x = x.masked_fill(~mask, 0.0)
371
 
372
  return x
373
+
374
 
375
  # Joint Attention processor for MM-DiT
376
  # modified from diffusers/src/diffusers/models/attention_processor.py
377
 
378
+
379
  class JointAttnProcessor:
380
  def __init__(self):
381
  pass
 
383
  def __call__(
384
  self,
385
  attn: Attention,
386
+ x: float["b n d"], # noised input x # noqa: F722
387
+ c: float["b nt d"] = None, # context c, here text # noqa: F722
388
+ mask: bool["b n"] | None = None, # noqa: F722
389
+ rope=None, # rotary position embedding for x
390
+ c_rope=None, # rotary position embedding for c
391
  ) -> torch.FloatTensor:
392
  residual = x
393
 
 
406
  # apply rope for context and noised input independently
407
  if rope is not None:
408
  freqs, xpos_scale = rope
409
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
410
  query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
411
  key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
412
  if c_rope is not None:
413
  freqs, xpos_scale = c_rope
414
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
415
  c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
416
  c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
417
 
 
428
 
429
  # mask. e.g. inference got a batch with different target durations, mask out the padding
430
  if mask is not None:
431
+ attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
432
  attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
433
  attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
434
  else:
 
440
 
441
  # Split the attention outputs.
442
  x, c = (
443
+ x[:, : residual.shape[1]],
444
+ x[:, residual.shape[1] :],
445
  )
446
 
447
  # linear proj
 
453
 
454
  if mask is not None:
455
  mask = mask.unsqueeze(-1)
456
+ x = x.masked_fill(~mask, 0.0)
457
  # c = c.masked_fill(~mask, 0.) # no mask for c (text)
458
 
459
  return x, c
 
461
 
462
  # DiT Block
463
 
 
464
 
465
+ class DiTBlock(nn.Module):
466
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
467
  super().__init__()
468
+
469
  self.attn_norm = AdaLayerNormZero(dim)
470
  self.attn = Attention(
471
+ processor=AttnProcessor(),
472
+ dim=dim,
473
+ heads=heads,
474
+ dim_head=dim_head,
475
+ dropout=dropout,
476
+ )
477
+
478
  self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
479
+ self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
480
 
481
+ def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
482
  # pre-norm & modulation for attention input
483
  norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
484
 
 
487
 
488
  # process attention output for input x
489
  x = x + gate_msa.unsqueeze(1) * attn_output
490
+
491
  norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
492
  ff_output = self.ff(norm)
493
  x = x + gate_mlp.unsqueeze(1) * ff_output
 
497
 
498
  # MMDiT Block https://arxiv.org/abs/2403.03206
499
 
500
+
501
  class MMDiTBlock(nn.Module):
502
+ r"""
503
  modified from diffusers/src/diffusers/models/attention.py
504
 
505
  notes.
 
508
  context_pre_only: last layer only do prenorm + modulation cuz no more ffn
509
  """
510
 
511
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
512
  super().__init__()
513
 
514
  self.context_pre_only = context_pre_only
515
+
516
  self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
517
  self.attn_norm_x = AdaLayerNormZero(dim)
518
  self.attn = Attention(
519
+ processor=JointAttnProcessor(),
520
+ dim=dim,
521
+ heads=heads,
522
+ dim_head=dim_head,
523
+ dropout=dropout,
524
+ context_dim=dim,
525
+ context_pre_only=context_pre_only,
526
+ )
527
 
528
  if not context_pre_only:
529
  self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
530
+ self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
531
  else:
532
  self.ff_norm_c = None
533
  self.ff_c = None
534
  self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
535
+ self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
536
 
537
+ def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
538
  # pre-norm & modulation for attention input
539
  if self.context_pre_only:
540
  norm_c = self.attn_norm_c(c, t)
 
548
  # process attention output for context c
549
  if self.context_pre_only:
550
  c = None
551
+ else: # if not last layer
552
  c = c + c_gate_msa.unsqueeze(1) * c_attn_output
553
 
554
  norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
 
557
 
558
  # process attention output for input x
559
  x = x + x_gate_msa.unsqueeze(1) * x_attn_output
560
+
561
  norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
562
  x_ff_output = self.ff_x(norm_x)
563
  x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
 
567
 
568
  # time step conditioning embedding
569
 
570
+
571
  class TimestepEmbedding(nn.Module):
572
  def __init__(self, dim, freq_embed_dim=256):
573
  super().__init__()
574
  self.time_embed = SinusPositionEmbedding(freq_embed_dim)
575
+ self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
 
 
 
 
576
 
577
+ def forward(self, timestep: float["b"]): # noqa: F821
578
  time_hidden = self.time_embed(timestep)
579
  time_hidden = time_hidden.to(timestep.dtype)
580
  time = self.time_mlp(time_hidden) # b d
model/trainer.py CHANGED
@@ -22,71 +22,69 @@ from model.dataset import DynamicBatchSampler, collate_fn
22
 
23
  # trainer
24
 
 
25
  class Trainer:
26
  def __init__(
27
  self,
28
  model: CFM,
29
  epochs,
30
  learning_rate,
31
- num_warmup_updates = 20000,
32
- save_per_updates = 1000,
33
- checkpoint_path = None,
34
- batch_size = 32,
35
  batch_size_type: str = "sample",
36
- max_samples = 32,
37
- grad_accumulation_steps = 1,
38
- max_grad_norm = 1.0,
39
  noise_scheduler: str | None = None,
40
  duration_predictor: torch.nn.Module | None = None,
41
- wandb_project = "test_e2-tts",
42
- wandb_run_name = "test_run",
43
  wandb_resume_id: str = None,
44
- last_per_steps = None,
45
  accelerate_kwargs: dict = dict(),
46
  ema_kwargs: dict = dict(),
47
  bnb_optimizer: bool = False,
48
  ):
49
-
50
- ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
51
 
52
  logger = "wandb" if wandb.api.api_key else None
53
  print(f"Using logger: {logger}")
54
 
55
  self.accelerator = Accelerator(
56
- log_with = logger,
57
- kwargs_handlers = [ddp_kwargs],
58
- gradient_accumulation_steps = grad_accumulation_steps,
59
- **accelerate_kwargs
60
  )
61
 
62
  if logger == "wandb":
63
  if exists(wandb_resume_id):
64
- init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
65
  else:
66
- init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
67
  self.accelerator.init_trackers(
68
- project_name = wandb_project,
69
  init_kwargs=init_kwargs,
70
- config={"epochs": epochs,
71
- "learning_rate": learning_rate,
72
- "num_warmup_updates": num_warmup_updates,
73
- "batch_size": batch_size,
74
- "batch_size_type": batch_size_type,
75
- "max_samples": max_samples,
76
- "grad_accumulation_steps": grad_accumulation_steps,
77
- "max_grad_norm": max_grad_norm,
78
- "gpus": self.accelerator.num_processes,
79
- "noise_scheduler": noise_scheduler}
80
- )
 
 
81
 
82
  self.model = model
83
 
84
  if self.is_main:
85
- self.ema_model = EMA(
86
- model,
87
- include_online_model = False,
88
- **ema_kwargs
89
- )
90
 
91
  self.ema_model.to(self.accelerator.device)
92
 
@@ -94,7 +92,7 @@ class Trainer:
94
  self.num_warmup_updates = num_warmup_updates
95
  self.save_per_updates = save_per_updates
96
  self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
97
- self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
98
 
99
  self.batch_size = batch_size
100
  self.batch_size_type = batch_size_type
@@ -108,12 +106,11 @@ class Trainer:
108
 
109
  if bnb_optimizer:
110
  import bitsandbytes as bnb
 
111
  self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
112
  else:
113
  self.optimizer = AdamW(model.parameters(), lr=learning_rate)
114
- self.model, self.optimizer = self.accelerator.prepare(
115
- self.model, self.optimizer
116
- )
117
 
118
  @property
119
  def is_main(self):
@@ -123,81 +120,112 @@ class Trainer:
123
  self.accelerator.wait_for_everyone()
124
  if self.is_main:
125
  checkpoint = dict(
126
- model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
127
- optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
128
- ema_model_state_dict = self.ema_model.state_dict(),
129
- scheduler_state_dict = self.scheduler.state_dict(),
130
- step = step
131
  )
132
  if not os.path.exists(self.checkpoint_path):
133
  os.makedirs(self.checkpoint_path)
134
- if last == True:
135
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
136
  print(f"Saved last checkpoint at step {step}")
137
  else:
138
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
139
 
140
  def load_checkpoint(self):
141
- if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
 
 
 
 
142
  return 0
143
-
144
  self.accelerator.wait_for_everyone()
145
  if "model_last.pt" in os.listdir(self.checkpoint_path):
146
  latest_checkpoint = "model_last.pt"
147
  else:
148
- latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
 
 
 
149
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
150
  checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
151
 
152
  if self.is_main:
153
- self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
154
 
155
- if 'step' in checkpoint:
156
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
157
- self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
158
  if self.scheduler:
159
- self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
160
- step = checkpoint['step']
161
  else:
162
- checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
163
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
 
 
 
 
164
  step = 0
165
 
166
- del checkpoint; gc.collect()
 
167
  return step
168
 
169
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
170
-
171
  if exists(resumable_with_seed):
172
  generator = torch.Generator()
173
  generator.manual_seed(resumable_with_seed)
174
- else:
175
  generator = None
176
 
177
  if self.batch_size_type == "sample":
178
- train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
179
- batch_size=self.batch_size, shuffle=True, generator=generator)
 
 
 
 
 
 
 
 
180
  elif self.batch_size_type == "frame":
181
  self.accelerator.even_batches = False
182
  sampler = SequentialSampler(train_dataset)
183
- batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
184
- train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
185
- batch_sampler=batch_sampler)
 
 
 
 
 
 
 
 
186
  else:
187
  raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
188
-
189
  # accelerator.prepare() dispatches batches to devices;
190
  # which means the length of dataloader calculated before, should consider the number of devices
191
- warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # consider a fixed warmup steps while using accelerate multi-gpu ddp
192
- # otherwise by default with split_batches=False, warmup steps change with num_processes
 
 
193
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
194
  decay_steps = total_steps - warmup_steps
195
  warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
196
  decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
197
- self.scheduler = SequentialLR(self.optimizer,
198
- schedulers=[warmup_scheduler, decay_scheduler],
199
- milestones=[warmup_steps])
200
- train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus
 
 
201
  start_step = self.load_checkpoint()
202
  global_step = start_step
203
 
@@ -212,23 +240,36 @@ class Trainer:
212
  for epoch in range(skipped_epoch, self.epochs):
213
  self.model.train()
214
  if exists(resumable_with_seed) and epoch == skipped_epoch:
215
- progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
216
- initial=skipped_batch, total=orig_epoch_step)
 
 
 
 
 
 
217
  else:
218
- progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
 
 
 
 
 
219
 
220
  for batch in progress_bar:
221
  with self.accelerator.accumulate(self.model):
222
- text_inputs = batch['text']
223
- mel_spec = batch['mel'].permute(0, 2, 1)
224
  mel_lengths = batch["mel_lengths"]
225
 
226
  # TODO. add duration predictor training
227
  if self.duration_predictor is not None and self.accelerator.is_local_main_process:
228
- dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
229
  self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
230
 
231
- loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
 
 
232
  self.accelerator.backward(loss)
233
 
234
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
@@ -245,13 +286,13 @@ class Trainer:
245
 
246
  if self.accelerator.is_local_main_process:
247
  self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
248
-
249
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
250
-
251
  if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
252
  self.save_checkpoint(global_step)
253
-
254
  if global_step % self.last_per_steps == 0:
255
  self.save_checkpoint(global_step, last=True)
256
-
257
  self.accelerator.end_training()
 
22
 
23
  # trainer
24
 
25
+
26
  class Trainer:
27
  def __init__(
28
  self,
29
  model: CFM,
30
  epochs,
31
  learning_rate,
32
+ num_warmup_updates=20000,
33
+ save_per_updates=1000,
34
+ checkpoint_path=None,
35
+ batch_size=32,
36
  batch_size_type: str = "sample",
37
+ max_samples=32,
38
+ grad_accumulation_steps=1,
39
+ max_grad_norm=1.0,
40
  noise_scheduler: str | None = None,
41
  duration_predictor: torch.nn.Module | None = None,
42
+ wandb_project="test_e2-tts",
43
+ wandb_run_name="test_run",
44
  wandb_resume_id: str = None,
45
+ last_per_steps=None,
46
  accelerate_kwargs: dict = dict(),
47
  ema_kwargs: dict = dict(),
48
  bnb_optimizer: bool = False,
49
  ):
50
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
 
51
 
52
  logger = "wandb" if wandb.api.api_key else None
53
  print(f"Using logger: {logger}")
54
 
55
  self.accelerator = Accelerator(
56
+ log_with=logger,
57
+ kwargs_handlers=[ddp_kwargs],
58
+ gradient_accumulation_steps=grad_accumulation_steps,
59
+ **accelerate_kwargs,
60
  )
61
 
62
  if logger == "wandb":
63
  if exists(wandb_resume_id):
64
+ init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
65
  else:
66
+ init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
67
  self.accelerator.init_trackers(
68
+ project_name=wandb_project,
69
  init_kwargs=init_kwargs,
70
+ config={
71
+ "epochs": epochs,
72
+ "learning_rate": learning_rate,
73
+ "num_warmup_updates": num_warmup_updates,
74
+ "batch_size": batch_size,
75
+ "batch_size_type": batch_size_type,
76
+ "max_samples": max_samples,
77
+ "grad_accumulation_steps": grad_accumulation_steps,
78
+ "max_grad_norm": max_grad_norm,
79
+ "gpus": self.accelerator.num_processes,
80
+ "noise_scheduler": noise_scheduler,
81
+ },
82
+ )
83
 
84
  self.model = model
85
 
86
  if self.is_main:
87
+ self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
 
 
 
 
88
 
89
  self.ema_model.to(self.accelerator.device)
90
 
 
92
  self.num_warmup_updates = num_warmup_updates
93
  self.save_per_updates = save_per_updates
94
  self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
95
+ self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
96
 
97
  self.batch_size = batch_size
98
  self.batch_size_type = batch_size_type
 
106
 
107
  if bnb_optimizer:
108
  import bitsandbytes as bnb
109
+
110
  self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
111
  else:
112
  self.optimizer = AdamW(model.parameters(), lr=learning_rate)
113
+ self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
 
 
114
 
115
  @property
116
  def is_main(self):
 
120
  self.accelerator.wait_for_everyone()
121
  if self.is_main:
122
  checkpoint = dict(
123
+ model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
124
+ optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
125
+ ema_model_state_dict=self.ema_model.state_dict(),
126
+ scheduler_state_dict=self.scheduler.state_dict(),
127
+ step=step,
128
  )
129
  if not os.path.exists(self.checkpoint_path):
130
  os.makedirs(self.checkpoint_path)
131
+ if last:
132
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
133
  print(f"Saved last checkpoint at step {step}")
134
  else:
135
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
136
 
137
  def load_checkpoint(self):
138
+ if (
139
+ not exists(self.checkpoint_path)
140
+ or not os.path.exists(self.checkpoint_path)
141
+ or not os.listdir(self.checkpoint_path)
142
+ ):
143
  return 0
144
+
145
  self.accelerator.wait_for_everyone()
146
  if "model_last.pt" in os.listdir(self.checkpoint_path):
147
  latest_checkpoint = "model_last.pt"
148
  else:
149
+ latest_checkpoint = sorted(
150
+ [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
151
+ key=lambda x: int("".join(filter(str.isdigit, x))),
152
+ )[-1]
153
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
154
  checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
155
 
156
  if self.is_main:
157
+ self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
158
 
159
+ if "step" in checkpoint:
160
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
161
+ self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
162
  if self.scheduler:
163
+ self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
164
+ step = checkpoint["step"]
165
  else:
166
+ checkpoint["model_state_dict"] = {
167
+ k.replace("ema_model.", ""): v
168
+ for k, v in checkpoint["ema_model_state_dict"].items()
169
+ if k not in ["initted", "step"]
170
+ }
171
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
172
  step = 0
173
 
174
+ del checkpoint
175
+ gc.collect()
176
  return step
177
 
178
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
 
179
  if exists(resumable_with_seed):
180
  generator = torch.Generator()
181
  generator.manual_seed(resumable_with_seed)
182
+ else:
183
  generator = None
184
 
185
  if self.batch_size_type == "sample":
186
+ train_dataloader = DataLoader(
187
+ train_dataset,
188
+ collate_fn=collate_fn,
189
+ num_workers=num_workers,
190
+ pin_memory=True,
191
+ persistent_workers=True,
192
+ batch_size=self.batch_size,
193
+ shuffle=True,
194
+ generator=generator,
195
+ )
196
  elif self.batch_size_type == "frame":
197
  self.accelerator.even_batches = False
198
  sampler = SequentialSampler(train_dataset)
199
+ batch_sampler = DynamicBatchSampler(
200
+ sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
201
+ )
202
+ train_dataloader = DataLoader(
203
+ train_dataset,
204
+ collate_fn=collate_fn,
205
+ num_workers=num_workers,
206
+ pin_memory=True,
207
+ persistent_workers=True,
208
+ batch_sampler=batch_sampler,
209
+ )
210
  else:
211
  raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
212
+
213
  # accelerator.prepare() dispatches batches to devices;
214
  # which means the length of dataloader calculated before, should consider the number of devices
215
+ warmup_steps = (
216
+ self.num_warmup_updates * self.accelerator.num_processes
217
+ ) # consider a fixed warmup steps while using accelerate multi-gpu ddp
218
+ # otherwise by default with split_batches=False, warmup steps change with num_processes
219
  total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
220
  decay_steps = total_steps - warmup_steps
221
  warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
222
  decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
223
+ self.scheduler = SequentialLR(
224
+ self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
225
+ )
226
+ train_dataloader, self.scheduler = self.accelerator.prepare(
227
+ train_dataloader, self.scheduler
228
+ ) # actual steps = 1 gpu steps / gpus
229
  start_step = self.load_checkpoint()
230
  global_step = start_step
231
 
 
240
  for epoch in range(skipped_epoch, self.epochs):
241
  self.model.train()
242
  if exists(resumable_with_seed) and epoch == skipped_epoch:
243
+ progress_bar = tqdm(
244
+ skipped_dataloader,
245
+ desc=f"Epoch {epoch+1}/{self.epochs}",
246
+ unit="step",
247
+ disable=not self.accelerator.is_local_main_process,
248
+ initial=skipped_batch,
249
+ total=orig_epoch_step,
250
+ )
251
  else:
252
+ progress_bar = tqdm(
253
+ train_dataloader,
254
+ desc=f"Epoch {epoch+1}/{self.epochs}",
255
+ unit="step",
256
+ disable=not self.accelerator.is_local_main_process,
257
+ )
258
 
259
  for batch in progress_bar:
260
  with self.accelerator.accumulate(self.model):
261
+ text_inputs = batch["text"]
262
+ mel_spec = batch["mel"].permute(0, 2, 1)
263
  mel_lengths = batch["mel_lengths"]
264
 
265
  # TODO. add duration predictor training
266
  if self.duration_predictor is not None and self.accelerator.is_local_main_process:
267
+ dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
268
  self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
269
 
270
+ loss, cond, pred = self.model(
271
+ mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
272
+ )
273
  self.accelerator.backward(loss)
274
 
275
  if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
 
286
 
287
  if self.accelerator.is_local_main_process:
288
  self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
289
+
290
  progress_bar.set_postfix(step=str(global_step), loss=loss.item())
291
+
292
  if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
293
  self.save_checkpoint(global_step)
294
+
295
  if global_step % self.last_per_steps == 0:
296
  self.save_checkpoint(global_step, last=True)
297
+
298
  self.accelerator.end_training()
model/utils.py CHANGED
@@ -8,6 +8,7 @@ from tqdm import tqdm
8
  from collections import defaultdict
9
 
10
  import matplotlib
 
11
  matplotlib.use("Agg")
12
  import matplotlib.pylab as plt
13
 
@@ -25,109 +26,102 @@ from model.modules import MelSpec
25
 
26
  # seed everything
27
 
28
- def seed_everything(seed = 0):
 
29
  random.seed(seed)
30
- os.environ['PYTHONHASHSEED'] = str(seed)
31
  torch.manual_seed(seed)
32
  torch.cuda.manual_seed(seed)
33
  torch.cuda.manual_seed_all(seed)
34
  torch.backends.cudnn.deterministic = True
35
  torch.backends.cudnn.benchmark = False
36
 
 
37
  # helpers
38
 
 
39
  def exists(v):
40
  return v is not None
41
 
 
42
  def default(v, d):
43
  return v if exists(v) else d
44
 
 
45
  # tensor helpers
46
 
47
- def lens_to_mask(
48
- t: int['b'],
49
- length: int | None = None
50
- ) -> bool['b n']:
51
 
 
52
  if not exists(length):
53
  length = t.amax()
54
 
55
- seq = torch.arange(length, device = t.device)
56
  return seq[None, :] < t[:, None]
57
 
58
- def mask_from_start_end_indices(
59
- seq_len: int['b'],
60
- start: int['b'],
61
- end: int['b']
62
- ):
63
- max_seq_len = seq_len.max().item()
64
- seq = torch.arange(max_seq_len, device = start.device).long()
65
  start_mask = seq[None, :] >= start[:, None]
66
  end_mask = seq[None, :] < end[:, None]
67
  return start_mask & end_mask
68
 
69
- def mask_from_frac_lengths(
70
- seq_len: int['b'],
71
- frac_lengths: float['b']
72
- ):
73
  lengths = (frac_lengths * seq_len).long()
74
  max_start = seq_len - lengths
75
 
76
  rand = torch.rand_like(frac_lengths)
77
- start = (max_start * rand).long().clamp(min = 0)
78
  end = start + lengths
79
 
80
  return mask_from_start_end_indices(seq_len, start, end)
81
 
82
- def maybe_masked_mean(
83
- t: float['b n d'],
84
- mask: bool['b n'] = None
85
- ) -> float['b d']:
86
 
 
87
  if not exists(mask):
88
- return t.mean(dim = 1)
89
 
90
- t = torch.where(mask[:, :, None], t, torch.tensor(0., device=t.device))
91
  num = t.sum(dim=1)
92
  den = mask.float().sum(dim=1)
93
 
94
- return num / den.clamp(min=1.)
95
 
96
 
97
  # simple utf-8 tokenizer, since paper went character based
98
- def list_str_to_tensor(
99
- text: list[str],
100
- padding_value = -1
101
- ) -> int['b nt']:
102
- list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
103
- text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
104
  return text
105
 
 
106
  # char tokenizer, based on custom dataset's extracted .txt file
107
  def list_str_to_idx(
108
  text: list[str] | list[list[str]],
109
  vocab_char_map: dict[str, int], # {char: idx}
110
- padding_value = -1
111
- ) -> int['b nt']:
112
  list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
113
- text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
114
  return text
115
 
116
 
117
  # Get tokenizer
118
 
 
119
  def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
120
- '''
121
  tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
122
  - "char" for char-wise tokenizer, need .txt vocab_file
123
  - "byte" for utf-8 tokenizer
124
  - "custom" if you're directly passing in a path to the vocab.txt you want to use
125
  vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
126
  - if use "char", derived from unfiltered character & symbol counts of custom dataset
127
- - if use "byte", set to 256 (unicode byte range)
128
- '''
129
  if tokenizer in ["pinyin", "char"]:
130
- with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
131
  vocab_char_map = {}
132
  for i, char in enumerate(f):
133
  vocab_char_map[char[:-1]] = i
@@ -138,7 +132,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
138
  vocab_char_map = None
139
  vocab_size = 256
140
  elif tokenizer == "custom":
141
- with open (dataset_name, "r", encoding="utf-8") as f:
142
  vocab_char_map = {}
143
  for i, char in enumerate(f):
144
  vocab_char_map[char[:-1]] = i
@@ -149,16 +143,19 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
149
 
150
  # convert char to pinyin
151
 
152
- def convert_char_to_pinyin(text_list, polyphone = True):
 
153
  final_text_list = []
154
- god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
155
- custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov
 
 
156
  for text in text_list:
157
  char_list = []
158
  text = text.translate(god_knows_why_en_testset_contains_zh_quote)
159
  text = text.translate(custom_trans)
160
  for seg in jieba.cut(text):
161
- seg_byte_len = len(bytes(seg, 'UTF-8'))
162
  if seg_byte_len == len(seg): # if pure alphabets and symbols
163
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
164
  char_list.append(" ")
@@ -187,7 +184,7 @@ def convert_char_to_pinyin(text_list, polyphone = True):
187
  # save spectrogram
188
  def save_spectrogram(spectrogram, path):
189
  plt.figure(figsize=(12, 4))
190
- plt.imshow(spectrogram, origin='lower', aspect='auto')
191
  plt.colorbar()
192
  plt.savefig(path)
193
  plt.close()
@@ -195,13 +192,15 @@ def save_spectrogram(spectrogram, path):
195
 
196
  # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
197
  def get_seedtts_testset_metainfo(metalst):
198
- f = open(metalst); lines = f.readlines(); f.close()
 
 
199
  metainfo = []
200
  for line in lines:
201
- if len(line.strip().split('|')) == 5:
202
- utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
203
- elif len(line.strip().split('|')) == 4:
204
- utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
205
  gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
206
  if not os.path.isabs(prompt_wav):
207
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
@@ -211,18 +210,20 @@ def get_seedtts_testset_metainfo(metalst):
211
 
212
  # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
213
  def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
214
- f = open(metalst); lines = f.readlines(); f.close()
 
 
215
  metainfo = []
216
  for line in lines:
217
- ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
218
 
219
  # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
220
- ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
221
- ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
222
 
223
  # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
224
- gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
225
- gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
226
 
227
  metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
228
 
@@ -234,7 +235,7 @@ def padded_mel_batch(ref_mels):
234
  max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
235
  padded_ref_mels = []
236
  for mel in ref_mels:
237
- padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
238
  padded_ref_mels.append(padded_ref_mel)
239
  padded_ref_mels = torch.stack(padded_ref_mels)
240
  padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
@@ -243,12 +244,21 @@ def padded_mel_batch(ref_mels):
243
 
244
  # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
245
 
 
246
  def get_inference_prompt(
247
- metainfo,
248
- speed = 1., tokenizer = "pinyin", polyphone = True,
249
- target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
250
- use_truth_duration = False,
251
- infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
 
 
 
 
 
 
 
 
252
  ):
253
  prompts_all = []
254
 
@@ -256,13 +266,15 @@ def get_inference_prompt(
256
  max_tokens = max_secs * target_sample_rate // hop_length
257
 
258
  batch_accum = [0] * num_buckets
259
- utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
260
- ([[] for _ in range(num_buckets)] for _ in range(6))
 
261
 
262
- mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
 
 
263
 
264
  for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
265
-
266
  # Audio
267
  ref_audio, ref_sr = torchaudio.load(prompt_wav)
268
  ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
@@ -274,11 +286,11 @@ def get_inference_prompt(
274
  ref_audio = resampler(ref_audio)
275
 
276
  # Text
277
- if len(prompt_text[-1].encode('utf-8')) == 1:
278
  prompt_text = prompt_text + " "
279
  text = [prompt_text + gt_text]
280
  if tokenizer == "pinyin":
281
- text_list = convert_char_to_pinyin(text, polyphone = polyphone)
282
  else:
283
  text_list = text
284
 
@@ -294,8 +306,8 @@ def get_inference_prompt(
294
  # # test vocoder resynthesis
295
  # ref_audio = gt_audio
296
  else:
297
- ref_text_len = len(prompt_text.encode('utf-8'))
298
- gen_text_len = len(gt_text.encode('utf-8'))
299
  total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
300
 
301
  # to mel spectrogram
@@ -304,8 +316,9 @@ def get_inference_prompt(
304
 
305
  # deal with batch
306
  assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
307
- assert min_tokens <= total_mel_len <= max_tokens, \
308
- f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
 
309
  bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
310
 
311
  utts[bucket_i].append(utt)
@@ -319,28 +332,39 @@ def get_inference_prompt(
319
 
320
  if batch_accum[bucket_i] >= infer_batch_size:
321
  # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
322
- prompts_all.append((
323
- utts[bucket_i],
324
- ref_rms_list[bucket_i],
325
- padded_mel_batch(ref_mels[bucket_i]),
326
- ref_mel_lens[bucket_i],
327
- total_mel_lens[bucket_i],
328
- final_text_list[bucket_i]
329
- ))
 
 
330
  batch_accum[bucket_i] = 0
331
- utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
 
 
 
 
 
 
 
332
 
333
  # add residual
334
  for bucket_i, bucket_frames in enumerate(batch_accum):
335
  if bucket_frames > 0:
336
- prompts_all.append((
337
- utts[bucket_i],
338
- ref_rms_list[bucket_i],
339
- padded_mel_batch(ref_mels[bucket_i]),
340
- ref_mel_lens[bucket_i],
341
- total_mel_lens[bucket_i],
342
- final_text_list[bucket_i]
343
- ))
 
 
344
  # not only leave easy work for last workers
345
  random.seed(666)
346
  random.shuffle(prompts_all)
@@ -351,6 +375,7 @@ def get_inference_prompt(
351
  # get wav_res_ref_text of seed-tts test metalst
352
  # https://github.com/BytedanceSpeech/seed-tts-eval
353
 
 
354
  def get_seed_tts_test(metalst, gen_wav_dir, gpus):
355
  f = open(metalst)
356
  lines = f.readlines()
@@ -358,14 +383,14 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
358
 
359
  test_set_ = []
360
  for line in tqdm(lines):
361
- if len(line.strip().split('|')) == 5:
362
- utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
363
- elif len(line.strip().split('|')) == 4:
364
- utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
365
 
366
- if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
367
  continue
368
- gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
369
  if not os.path.isabs(prompt_wav):
370
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
371
 
@@ -374,65 +399,69 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
374
  num_jobs = len(gpus)
375
  if num_jobs == 1:
376
  return [(gpus[0], test_set_)]
377
-
378
  wav_per_job = len(test_set_) // num_jobs + 1
379
  test_set = []
380
  for i in range(num_jobs):
381
- test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
382
 
383
  return test_set
384
 
385
 
386
  # get librispeech test-clean cross sentence test
387
 
388
- def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
 
389
  f = open(metalst)
390
  lines = f.readlines()
391
  f.close()
392
 
393
  test_set_ = []
394
  for line in tqdm(lines):
395
- ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
396
 
397
  if eval_ground_truth:
398
- gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
399
- gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
400
  else:
401
- if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
402
  raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
403
- gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
404
 
405
- ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
406
- ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
407
 
408
  test_set_.append((gen_wav, ref_wav, gen_txt))
409
 
410
  num_jobs = len(gpus)
411
  if num_jobs == 1:
412
  return [(gpus[0], test_set_)]
413
-
414
  wav_per_job = len(test_set_) // num_jobs + 1
415
  test_set = []
416
  for i in range(num_jobs):
417
- test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
418
 
419
  return test_set
420
 
421
 
422
  # load asr model
423
 
424
- def load_asr_model(lang, ckpt_dir = ""):
 
425
  if lang == "zh":
426
  from funasr import AutoModel
 
427
  model = AutoModel(
428
- model = os.path.join(ckpt_dir, "paraformer-zh"),
429
- # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
430
  # punc_model = os.path.join(ckpt_dir, "ct-punc"),
431
- # spk_model = os.path.join(ckpt_dir, "cam++"),
432
  disable_update=True,
433
- ) # following seed-tts setting
434
  elif lang == "en":
435
  from faster_whisper import WhisperModel
 
436
  model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
437
  model = WhisperModel(model_size, device="cuda", compute_type="float16")
438
  return model
@@ -440,44 +469,50 @@ def load_asr_model(lang, ckpt_dir = ""):
440
 
441
  # WER Evaluation, the way Seed-TTS does
442
 
 
443
  def run_asr_wer(args):
444
  rank, lang, test_set, ckpt_dir = args
445
 
446
  if lang == "zh":
447
  import zhconv
 
448
  torch.cuda.set_device(rank)
449
  elif lang == "en":
450
  os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
451
  else:
452
- raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
 
 
 
 
453
 
454
- asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
455
-
456
  from zhon.hanzi import punctuation
 
457
  punctuation_all = punctuation + string.punctuation
458
  wers = []
459
 
460
  from jiwer import compute_measures
 
461
  for gen_wav, prompt_wav, truth in tqdm(test_set):
462
  if lang == "zh":
463
  res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
464
  hypo = res[0]["text"]
465
- hypo = zhconv.convert(hypo, 'zh-cn')
466
  elif lang == "en":
467
  segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
468
- hypo = ''
469
  for segment in segments:
470
- hypo = hypo + ' ' + segment.text
471
 
472
  # raw_truth = truth
473
  # raw_hypo = hypo
474
 
475
  for x in punctuation_all:
476
- truth = truth.replace(x, '')
477
- hypo = hypo.replace(x, '')
478
 
479
- truth = truth.replace(' ', ' ')
480
- hypo = hypo.replace(' ', ' ')
481
 
482
  if lang == "zh":
483
  truth = " ".join([x for x in truth])
@@ -501,22 +536,22 @@ def run_asr_wer(args):
501
 
502
  # SIM Evaluation
503
 
 
504
  def run_sim(args):
505
  rank, test_set, ckpt_dir = args
506
  device = f"cuda:{rank}"
507
 
508
- model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
509
  state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
510
- model.load_state_dict(state_dict['model'], strict=False)
511
 
512
- use_gpu=True if torch.cuda.is_available() else False
513
  if use_gpu:
514
  model = model.cuda(device)
515
  model.eval()
516
 
517
  sim_list = []
518
  for wav1, wav2, truth in tqdm(test_set):
519
-
520
  wav1, sr1 = torchaudio.load(wav1)
521
  wav2, sr2 = torchaudio.load(wav2)
522
 
@@ -531,20 +566,21 @@ def run_sim(args):
531
  with torch.no_grad():
532
  emb1 = model(wav1)
533
  emb2 = model(wav2)
534
-
535
  sim = F.cosine_similarity(emb1, emb2)[0].item()
536
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
537
  sim_list.append(sim)
538
-
539
  return sim_list
540
 
541
 
542
  # filter func for dirty data with many repetitions
543
 
544
- def repetition_found(text, length = 2, tolerance = 10):
 
545
  pattern_count = defaultdict(int)
546
  for i in range(len(text) - length + 1):
547
- pattern = text[i:i + length]
548
  pattern_count[pattern] += 1
549
  for pattern, count in pattern_count.items():
550
  if count > tolerance:
@@ -554,25 +590,31 @@ def repetition_found(text, length = 2, tolerance = 10):
554
 
555
  # load model checkpoint for inference
556
 
557
- def load_checkpoint(model, ckpt_path, device, use_ema = True):
 
558
  if device == "cuda":
559
  model = model.half()
560
 
561
  ckpt_type = ckpt_path.split(".")[-1]
562
  if ckpt_type == "safetensors":
563
  from safetensors.torch import load_file
 
564
  checkpoint = load_file(ckpt_path)
565
  else:
566
  checkpoint = torch.load(ckpt_path, weights_only=True)
567
 
568
  if use_ema:
569
  if ckpt_type == "safetensors":
570
- checkpoint = {'ema_model_state_dict': checkpoint}
571
- checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
572
- model.load_state_dict(checkpoint['model_state_dict'])
 
 
 
 
573
  else:
574
  if ckpt_type == "safetensors":
575
- checkpoint = {'model_state_dict': checkpoint}
576
- model.load_state_dict(checkpoint['model_state_dict'])
577
 
578
  return model.to(device)
 
8
  from collections import defaultdict
9
 
10
  import matplotlib
11
+
12
  matplotlib.use("Agg")
13
  import matplotlib.pylab as plt
14
 
 
26
 
27
  # seed everything
28
 
29
+
30
+ def seed_everything(seed=0):
31
  random.seed(seed)
32
+ os.environ["PYTHONHASHSEED"] = str(seed)
33
  torch.manual_seed(seed)
34
  torch.cuda.manual_seed(seed)
35
  torch.cuda.manual_seed_all(seed)
36
  torch.backends.cudnn.deterministic = True
37
  torch.backends.cudnn.benchmark = False
38
 
39
+
40
  # helpers
41
 
42
+
43
  def exists(v):
44
  return v is not None
45
 
46
+
47
  def default(v, d):
48
  return v if exists(v) else d
49
 
50
+
51
  # tensor helpers
52
 
 
 
 
 
53
 
54
+ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
55
  if not exists(length):
56
  length = t.amax()
57
 
58
+ seq = torch.arange(length, device=t.device)
59
  return seq[None, :] < t[:, None]
60
 
61
+
62
+ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
63
+ max_seq_len = seq_len.max().item()
64
+ seq = torch.arange(max_seq_len, device=start.device).long()
 
 
 
65
  start_mask = seq[None, :] >= start[:, None]
66
  end_mask = seq[None, :] < end[:, None]
67
  return start_mask & end_mask
68
 
69
+
70
+ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
 
 
71
  lengths = (frac_lengths * seq_len).long()
72
  max_start = seq_len - lengths
73
 
74
  rand = torch.rand_like(frac_lengths)
75
+ start = (max_start * rand).long().clamp(min=0)
76
  end = start + lengths
77
 
78
  return mask_from_start_end_indices(seq_len, start, end)
79
 
 
 
 
 
80
 
81
+ def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
82
  if not exists(mask):
83
+ return t.mean(dim=1)
84
 
85
+ t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
86
  num = t.sum(dim=1)
87
  den = mask.float().sum(dim=1)
88
 
89
+ return num / den.clamp(min=1.0)
90
 
91
 
92
  # simple utf-8 tokenizer, since paper went character based
93
+ def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
94
+ list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
95
+ text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
 
 
 
96
  return text
97
 
98
+
99
  # char tokenizer, based on custom dataset's extracted .txt file
100
  def list_str_to_idx(
101
  text: list[str] | list[list[str]],
102
  vocab_char_map: dict[str, int], # {char: idx}
103
+ padding_value=-1,
104
+ ) -> int["b nt"]: # noqa: F722
105
  list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
106
+ text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
107
  return text
108
 
109
 
110
  # Get tokenizer
111
 
112
+
113
  def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
114
+ """
115
  tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
116
  - "char" for char-wise tokenizer, need .txt vocab_file
117
  - "byte" for utf-8 tokenizer
118
  - "custom" if you're directly passing in a path to the vocab.txt you want to use
119
  vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
120
  - if use "char", derived from unfiltered character & symbol counts of custom dataset
121
+ - if use "byte", set to 256 (unicode byte range)
122
+ """
123
  if tokenizer in ["pinyin", "char"]:
124
+ with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
125
  vocab_char_map = {}
126
  for i, char in enumerate(f):
127
  vocab_char_map[char[:-1]] = i
 
132
  vocab_char_map = None
133
  vocab_size = 256
134
  elif tokenizer == "custom":
135
+ with open(dataset_name, "r", encoding="utf-8") as f:
136
  vocab_char_map = {}
137
  for i, char in enumerate(f):
138
  vocab_char_map[char[:-1]] = i
 
143
 
144
  # convert char to pinyin
145
 
146
+
147
+ def convert_char_to_pinyin(text_list, polyphone=True):
148
  final_text_list = []
149
+ god_knows_why_en_testset_contains_zh_quote = str.maketrans(
150
+ {"“": '"', "”": '"', "‘": "'", "’": "'"}
151
+ ) # in case librispeech (orig no-pc) test-clean
152
+ custom_trans = str.maketrans({";": ","}) # add custom trans here, to address oov
153
  for text in text_list:
154
  char_list = []
155
  text = text.translate(god_knows_why_en_testset_contains_zh_quote)
156
  text = text.translate(custom_trans)
157
  for seg in jieba.cut(text):
158
+ seg_byte_len = len(bytes(seg, "UTF-8"))
159
  if seg_byte_len == len(seg): # if pure alphabets and symbols
160
  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
161
  char_list.append(" ")
 
184
  # save spectrogram
185
  def save_spectrogram(spectrogram, path):
186
  plt.figure(figsize=(12, 4))
187
+ plt.imshow(spectrogram, origin="lower", aspect="auto")
188
  plt.colorbar()
189
  plt.savefig(path)
190
  plt.close()
 
192
 
193
  # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
194
  def get_seedtts_testset_metainfo(metalst):
195
+ f = open(metalst)
196
+ lines = f.readlines()
197
+ f.close()
198
  metainfo = []
199
  for line in lines:
200
+ if len(line.strip().split("|")) == 5:
201
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
202
+ elif len(line.strip().split("|")) == 4:
203
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
204
  gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
205
  if not os.path.isabs(prompt_wav):
206
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
 
210
 
211
  # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
212
  def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
213
+ f = open(metalst)
214
+ lines = f.readlines()
215
+ f.close()
216
  metainfo = []
217
  for line in lines:
218
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
219
 
220
  # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
221
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
222
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
223
 
224
  # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
225
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
226
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
227
 
228
  metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
229
 
 
235
  max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
236
  padded_ref_mels = []
237
  for mel in ref_mels:
238
+ padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
239
  padded_ref_mels.append(padded_ref_mel)
240
  padded_ref_mels = torch.stack(padded_ref_mels)
241
  padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
 
244
 
245
  # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
246
 
247
+
248
  def get_inference_prompt(
249
+ metainfo,
250
+ speed=1.0,
251
+ tokenizer="pinyin",
252
+ polyphone=True,
253
+ target_sample_rate=24000,
254
+ n_mel_channels=100,
255
+ hop_length=256,
256
+ target_rms=0.1,
257
+ use_truth_duration=False,
258
+ infer_batch_size=1,
259
+ num_buckets=200,
260
+ min_secs=3,
261
+ max_secs=40,
262
  ):
263
  prompts_all = []
264
 
 
266
  max_tokens = max_secs * target_sample_rate // hop_length
267
 
268
  batch_accum = [0] * num_buckets
269
+ utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
270
+ [[] for _ in range(num_buckets)] for _ in range(6)
271
+ )
272
 
273
+ mel_spectrogram = MelSpec(
274
+ target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
275
+ )
276
 
277
  for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
 
278
  # Audio
279
  ref_audio, ref_sr = torchaudio.load(prompt_wav)
280
  ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
 
286
  ref_audio = resampler(ref_audio)
287
 
288
  # Text
289
+ if len(prompt_text[-1].encode("utf-8")) == 1:
290
  prompt_text = prompt_text + " "
291
  text = [prompt_text + gt_text]
292
  if tokenizer == "pinyin":
293
+ text_list = convert_char_to_pinyin(text, polyphone=polyphone)
294
  else:
295
  text_list = text
296
 
 
306
  # # test vocoder resynthesis
307
  # ref_audio = gt_audio
308
  else:
309
+ ref_text_len = len(prompt_text.encode("utf-8"))
310
+ gen_text_len = len(gt_text.encode("utf-8"))
311
  total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
312
 
313
  # to mel spectrogram
 
316
 
317
  # deal with batch
318
  assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
319
+ assert (
320
+ min_tokens <= total_mel_len <= max_tokens
321
+ ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
322
  bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
323
 
324
  utts[bucket_i].append(utt)
 
332
 
333
  if batch_accum[bucket_i] >= infer_batch_size:
334
  # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
335
+ prompts_all.append(
336
+ (
337
+ utts[bucket_i],
338
+ ref_rms_list[bucket_i],
339
+ padded_mel_batch(ref_mels[bucket_i]),
340
+ ref_mel_lens[bucket_i],
341
+ total_mel_lens[bucket_i],
342
+ final_text_list[bucket_i],
343
+ )
344
+ )
345
  batch_accum[bucket_i] = 0
346
+ (
347
+ utts[bucket_i],
348
+ ref_rms_list[bucket_i],
349
+ ref_mels[bucket_i],
350
+ ref_mel_lens[bucket_i],
351
+ total_mel_lens[bucket_i],
352
+ final_text_list[bucket_i],
353
+ ) = [], [], [], [], [], []
354
 
355
  # add residual
356
  for bucket_i, bucket_frames in enumerate(batch_accum):
357
  if bucket_frames > 0:
358
+ prompts_all.append(
359
+ (
360
+ utts[bucket_i],
361
+ ref_rms_list[bucket_i],
362
+ padded_mel_batch(ref_mels[bucket_i]),
363
+ ref_mel_lens[bucket_i],
364
+ total_mel_lens[bucket_i],
365
+ final_text_list[bucket_i],
366
+ )
367
+ )
368
  # not only leave easy work for last workers
369
  random.seed(666)
370
  random.shuffle(prompts_all)
 
375
  # get wav_res_ref_text of seed-tts test metalst
376
  # https://github.com/BytedanceSpeech/seed-tts-eval
377
 
378
+
379
  def get_seed_tts_test(metalst, gen_wav_dir, gpus):
380
  f = open(metalst)
381
  lines = f.readlines()
 
383
 
384
  test_set_ = []
385
  for line in tqdm(lines):
386
+ if len(line.strip().split("|")) == 5:
387
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
388
+ elif len(line.strip().split("|")) == 4:
389
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
390
 
391
+ if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")):
392
  continue
393
+ gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
394
  if not os.path.isabs(prompt_wav):
395
  prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
396
 
 
399
  num_jobs = len(gpus)
400
  if num_jobs == 1:
401
  return [(gpus[0], test_set_)]
402
+
403
  wav_per_job = len(test_set_) // num_jobs + 1
404
  test_set = []
405
  for i in range(num_jobs):
406
+ test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
407
 
408
  return test_set
409
 
410
 
411
  # get librispeech test-clean cross sentence test
412
 
413
+
414
+ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
415
  f = open(metalst)
416
  lines = f.readlines()
417
  f.close()
418
 
419
  test_set_ = []
420
  for line in tqdm(lines):
421
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
422
 
423
  if eval_ground_truth:
424
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
425
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
426
  else:
427
+ if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
428
  raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
429
+ gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")
430
 
431
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
432
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
433
 
434
  test_set_.append((gen_wav, ref_wav, gen_txt))
435
 
436
  num_jobs = len(gpus)
437
  if num_jobs == 1:
438
  return [(gpus[0], test_set_)]
439
+
440
  wav_per_job = len(test_set_) // num_jobs + 1
441
  test_set = []
442
  for i in range(num_jobs):
443
+ test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
444
 
445
  return test_set
446
 
447
 
448
  # load asr model
449
 
450
+
451
+ def load_asr_model(lang, ckpt_dir=""):
452
  if lang == "zh":
453
  from funasr import AutoModel
454
+
455
  model = AutoModel(
456
+ model=os.path.join(ckpt_dir, "paraformer-zh"),
457
+ # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
458
  # punc_model = os.path.join(ckpt_dir, "ct-punc"),
459
+ # spk_model = os.path.join(ckpt_dir, "cam++"),
460
  disable_update=True,
461
+ ) # following seed-tts setting
462
  elif lang == "en":
463
  from faster_whisper import WhisperModel
464
+
465
  model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
466
  model = WhisperModel(model_size, device="cuda", compute_type="float16")
467
  return model
 
469
 
470
  # WER Evaluation, the way Seed-TTS does
471
 
472
+
473
  def run_asr_wer(args):
474
  rank, lang, test_set, ckpt_dir = args
475
 
476
  if lang == "zh":
477
  import zhconv
478
+
479
  torch.cuda.set_device(rank)
480
  elif lang == "en":
481
  os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
482
  else:
483
+ raise NotImplementedError(
484
+ "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
485
+ )
486
+
487
+ asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)
488
 
 
 
489
  from zhon.hanzi import punctuation
490
+
491
  punctuation_all = punctuation + string.punctuation
492
  wers = []
493
 
494
  from jiwer import compute_measures
495
+
496
  for gen_wav, prompt_wav, truth in tqdm(test_set):
497
  if lang == "zh":
498
  res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
499
  hypo = res[0]["text"]
500
+ hypo = zhconv.convert(hypo, "zh-cn")
501
  elif lang == "en":
502
  segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
503
+ hypo = ""
504
  for segment in segments:
505
+ hypo = hypo + " " + segment.text
506
 
507
  # raw_truth = truth
508
  # raw_hypo = hypo
509
 
510
  for x in punctuation_all:
511
+ truth = truth.replace(x, "")
512
+ hypo = hypo.replace(x, "")
513
 
514
+ truth = truth.replace(" ", " ")
515
+ hypo = hypo.replace(" ", " ")
516
 
517
  if lang == "zh":
518
  truth = " ".join([x for x in truth])
 
536
 
537
  # SIM Evaluation
538
 
539
+
540
  def run_sim(args):
541
  rank, test_set, ckpt_dir = args
542
  device = f"cuda:{rank}"
543
 
544
+ model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
545
  state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
546
+ model.load_state_dict(state_dict["model"], strict=False)
547
 
548
+ use_gpu = True if torch.cuda.is_available() else False
549
  if use_gpu:
550
  model = model.cuda(device)
551
  model.eval()
552
 
553
  sim_list = []
554
  for wav1, wav2, truth in tqdm(test_set):
 
555
  wav1, sr1 = torchaudio.load(wav1)
556
  wav2, sr2 = torchaudio.load(wav2)
557
 
 
566
  with torch.no_grad():
567
  emb1 = model(wav1)
568
  emb2 = model(wav2)
569
+
570
  sim = F.cosine_similarity(emb1, emb2)[0].item()
571
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
572
  sim_list.append(sim)
573
+
574
  return sim_list
575
 
576
 
577
  # filter func for dirty data with many repetitions
578
 
579
+
580
+ def repetition_found(text, length=2, tolerance=10):
581
  pattern_count = defaultdict(int)
582
  for i in range(len(text) - length + 1):
583
+ pattern = text[i : i + length]
584
  pattern_count[pattern] += 1
585
  for pattern, count in pattern_count.items():
586
  if count > tolerance:
 
590
 
591
  # load model checkpoint for inference
592
 
593
+
594
+ def load_checkpoint(model, ckpt_path, device, use_ema=True):
595
  if device == "cuda":
596
  model = model.half()
597
 
598
  ckpt_type = ckpt_path.split(".")[-1]
599
  if ckpt_type == "safetensors":
600
  from safetensors.torch import load_file
601
+
602
  checkpoint = load_file(ckpt_path)
603
  else:
604
  checkpoint = torch.load(ckpt_path, weights_only=True)
605
 
606
  if use_ema:
607
  if ckpt_type == "safetensors":
608
+ checkpoint = {"ema_model_state_dict": checkpoint}
609
+ checkpoint["model_state_dict"] = {
610
+ k.replace("ema_model.", ""): v
611
+ for k, v in checkpoint["ema_model_state_dict"].items()
612
+ if k not in ["initted", "step"]
613
+ }
614
+ model.load_state_dict(checkpoint["model_state_dict"])
615
  else:
616
  if ckpt_type == "safetensors":
617
+ checkpoint = {"model_state_dict": checkpoint}
618
+ model.load_state_dict(checkpoint["model_state_dict"])
619
 
620
  return model.to(device)
model/utils_infer.py CHANGED
@@ -19,11 +19,7 @@ from model.utils import (
19
  convert_char_to_pinyin,
20
  )
21
 
22
- device = (
23
- "cuda"
24
- if torch.cuda.is_available()
25
- else "mps" if torch.backends.mps.is_available() else "cpu"
26
- )
27
  print(f"Using {device} device")
28
 
29
  asr_pipe = pipeline(
@@ -54,6 +50,7 @@ fix_duration = None
54
 
55
  # chunk text into smaller pieces
56
 
 
57
  def chunk_text(text, max_chars=135):
58
  """
59
  Splits the input text into chunks, each with a maximum number of characters.
@@ -68,15 +65,15 @@ def chunk_text(text, max_chars=135):
68
  chunks = []
69
  current_chunk = ""
70
  # Split the text into sentences based on punctuation followed by whitespace
71
- sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
72
 
73
  for sentence in sentences:
74
- if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
75
- current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
76
  else:
77
  if current_chunk:
78
  chunks.append(current_chunk.strip())
79
- current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
80
 
81
  if current_chunk:
82
  chunks.append(current_chunk.strip())
@@ -86,6 +83,7 @@ def chunk_text(text, max_chars=135):
86
 
87
  # load vocoder
88
 
 
89
  def load_vocoder(is_local=False, local_path=""):
90
  if is_local:
91
  print(f"Load vocos from local path {local_path}")
@@ -101,23 +99,21 @@ def load_vocoder(is_local=False, local_path=""):
101
 
102
  # load model for inference
103
 
 
104
  def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
105
-
106
  if vocab_file == "":
107
  vocab_file = "Emilia_ZH_EN"
108
  tokenizer = "pinyin"
109
  else:
110
  tokenizer = "custom"
111
 
112
- print("\nvocab : ", vocab_file, tokenizer)
113
- print("tokenizer : ", tokenizer)
114
- print("model : ", ckpt_path,"\n")
115
 
116
  vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
117
  model = CFM(
118
- transformer=model_cls(
119
- **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
120
- ),
121
  mel_spec_kwargs=dict(
122
  target_sample_rate=target_sample_rate,
123
  n_mel_channels=n_mel_channels,
@@ -129,21 +125,20 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
129
  vocab_char_map=vocab_char_map,
130
  ).to(device)
131
 
132
- model = load_checkpoint(model, ckpt_path, device, use_ema = True)
133
 
134
  return model
135
 
136
 
137
  # preprocess reference audio and text
138
 
 
139
  def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
140
  show_info("Converting audio...")
141
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
142
  aseg = AudioSegment.from_file(ref_audio_orig)
143
 
144
- non_silent_segs = silence.split_on_silence(
145
- aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000
146
- )
147
  non_silent_wave = AudioSegment.silent(duration=0)
148
  for non_silent_seg in non_silent_segs:
149
  non_silent_wave += non_silent_seg
@@ -181,22 +176,27 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
181
 
182
  # infer process: chunk text -> infer batches [i.e. infer_batch_process()]
183
 
184
- def infer_process(ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=speed, show_info=print, progress=tqdm):
185
 
 
 
 
186
  # Split the input text into batches
187
  audio, sr = torchaudio.load(ref_audio)
188
- max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
189
  gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
190
  for i, gen_text in enumerate(gen_text_batches):
191
- print(f'gen_text {i}', gen_text)
192
-
193
  show_info(f"Generating audio in {len(gen_text_batches)} batches...")
194
  return infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, cross_fade_duration, speed, progress)
195
 
196
 
197
  # infer batches
198
 
199
- def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm):
 
 
 
200
  audio, sr = ref_audio
201
  if audio.shape[0] > 1:
202
  audio = torch.mean(audio, dim=0, keepdim=True)
@@ -212,7 +212,7 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
212
  generated_waves = []
213
  spectrograms = []
214
 
215
- if len(ref_text[-1].encode('utf-8')) == 1:
216
  ref_text = ref_text + " "
217
  for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
218
  # Prepare the text
@@ -221,8 +221,8 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
221
 
222
  # Calculate duration
223
  ref_audio_len = audio.shape[-1] // hop_length
224
- ref_text_len = len(ref_text.encode('utf-8'))
225
- gen_text_len = len(gen_text.encode('utf-8'))
226
  duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
227
 
228
  # inference
@@ -245,7 +245,7 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
245
 
246
  # wav -> numpy
247
  generated_wave = generated_wave.squeeze().cpu().numpy()
248
-
249
  generated_waves.append(generated_wave)
250
  spectrograms.append(generated_mel_spec[0].cpu().numpy())
251
 
@@ -280,11 +280,9 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
280
  cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
281
 
282
  # Combine
283
- new_wave = np.concatenate([
284
- prev_wave[:-cross_fade_samples],
285
- cross_faded_overlap,
286
- next_wave[cross_fade_samples:]
287
- ])
288
 
289
  final_wave = new_wave
290
 
@@ -296,6 +294,7 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
296
 
297
  # remove silence from generated wav
298
 
 
299
  def remove_silence_for_generated_wav(filename):
300
  aseg = AudioSegment.from_file(filename)
301
  non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
 
19
  convert_char_to_pinyin,
20
  )
21
 
22
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 
 
23
  print(f"Using {device} device")
24
 
25
  asr_pipe = pipeline(
 
50
 
51
  # chunk text into smaller pieces
52
 
53
+
54
  def chunk_text(text, max_chars=135):
55
  """
56
  Splits the input text into chunks, each with a maximum number of characters.
 
65
  chunks = []
66
  current_chunk = ""
67
  # Split the text into sentences based on punctuation followed by whitespace
68
+ sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
69
 
70
  for sentence in sentences:
71
+ if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
72
+ current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
73
  else:
74
  if current_chunk:
75
  chunks.append(current_chunk.strip())
76
+ current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
77
 
78
  if current_chunk:
79
  chunks.append(current_chunk.strip())
 
83
 
84
  # load vocoder
85
 
86
+
87
  def load_vocoder(is_local=False, local_path=""):
88
  if is_local:
89
  print(f"Load vocos from local path {local_path}")
 
99
 
100
  # load model for inference
101
 
102
+
103
  def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
 
104
  if vocab_file == "":
105
  vocab_file = "Emilia_ZH_EN"
106
  tokenizer = "pinyin"
107
  else:
108
  tokenizer = "custom"
109
 
110
+ print("\nvocab : ", vocab_file, tokenizer)
111
+ print("tokenizer : ", tokenizer)
112
+ print("model : ", ckpt_path, "\n")
113
 
114
  vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
115
  model = CFM(
116
+ transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
 
 
117
  mel_spec_kwargs=dict(
118
  target_sample_rate=target_sample_rate,
119
  n_mel_channels=n_mel_channels,
 
125
  vocab_char_map=vocab_char_map,
126
  ).to(device)
127
 
128
+ model = load_checkpoint(model, ckpt_path, device, use_ema=True)
129
 
130
  return model
131
 
132
 
133
  # preprocess reference audio and text
134
 
135
+
136
  def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
137
  show_info("Converting audio...")
138
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
139
  aseg = AudioSegment.from_file(ref_audio_orig)
140
 
141
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
 
 
142
  non_silent_wave = AudioSegment.silent(duration=0)
143
  for non_silent_seg in non_silent_segs:
144
  non_silent_wave += non_silent_seg
 
176
 
177
  # infer process: chunk text -> infer batches [i.e. infer_batch_process()]
178
 
 
179
 
180
+ def infer_process(
181
+ ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=speed, show_info=print, progress=tqdm
182
+ ):
183
  # Split the input text into batches
184
  audio, sr = torchaudio.load(ref_audio)
185
+ max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
186
  gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
187
  for i, gen_text in enumerate(gen_text_batches):
188
+ print(f"gen_text {i}", gen_text)
189
+
190
  show_info(f"Generating audio in {len(gen_text_batches)} batches...")
191
  return infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, cross_fade_duration, speed, progress)
192
 
193
 
194
  # infer batches
195
 
196
+
197
+ def infer_batch_process(
198
+ ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm
199
+ ):
200
  audio, sr = ref_audio
201
  if audio.shape[0] > 1:
202
  audio = torch.mean(audio, dim=0, keepdim=True)
 
212
  generated_waves = []
213
  spectrograms = []
214
 
215
+ if len(ref_text[-1].encode("utf-8")) == 1:
216
  ref_text = ref_text + " "
217
  for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
218
  # Prepare the text
 
221
 
222
  # Calculate duration
223
  ref_audio_len = audio.shape[-1] // hop_length
224
+ ref_text_len = len(ref_text.encode("utf-8"))
225
+ gen_text_len = len(gen_text.encode("utf-8"))
226
  duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
227
 
228
  # inference
 
245
 
246
  # wav -> numpy
247
  generated_wave = generated_wave.squeeze().cpu().numpy()
248
+
249
  generated_waves.append(generated_wave)
250
  spectrograms.append(generated_mel_spec[0].cpu().numpy())
251
 
 
280
  cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
281
 
282
  # Combine
283
+ new_wave = np.concatenate(
284
+ [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
285
+ )
 
 
286
 
287
  final_wave = new_wave
288
 
 
294
 
295
  # remove silence from generated wav
296
 
297
+
298
  def remove_silence_for_generated_wav(filename):
299
  aseg = AudioSegment.from_file(filename)
300
  non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
ruff.toml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ line-length = 120
2
+ target-version = "py310"
3
+
4
+ [lint]
5
+ # Only ignore variables with names starting with "_".
6
+ dummy-variable-rgx = "^_.*$"
7
+
8
+ [lint.isort]
9
+ force-single-line = true
10
+ lines-after-imports = 2
scripts/count_max_epoch.py CHANGED
@@ -1,6 +1,7 @@
1
- '''ADAPTIVE BATCH SIZE'''
2
- print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in')
3
- print(' -> least padding, gather wavs with accumulated frames in a batch\n')
 
4
 
5
  # data
6
  total_hours = 95282
 
1
+ """ADAPTIVE BATCH SIZE"""
2
+
3
+ print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in")
4
+ print(" -> least padding, gather wavs with accumulated frames in a batch\n")
5
 
6
  # data
7
  total_hours = 95282
scripts/count_params_gflops.py CHANGED
@@ -1,13 +1,15 @@
1
- import sys, os
 
 
2
  sys.path.append(os.getcwd())
3
 
4
- from model import M2_TTS, UNetT, DiT, MMDiT
5
 
6
  import torch
7
  import thop
8
 
9
 
10
- ''' ~155M '''
11
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
12
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
13
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
@@ -15,11 +17,11 @@ import thop
15
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
16
  # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
17
 
18
- ''' ~335M '''
19
  # FLOPs: 622.1 G, Params: 333.2 M
20
  # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
21
  # FLOPs: 363.4 G, Params: 335.8 M
22
- transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
23
 
24
 
25
  model = M2_TTS(transformer=transformer)
@@ -30,6 +32,8 @@ duration = 20
30
  frame_length = int(duration * target_sample_rate / hop_length)
31
  text_length = 150
32
 
33
- flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
 
 
34
  print(f"FLOPs: {flops / 1e9} G")
35
  print(f"Params: {params / 1e6} M")
 
1
+ import sys
2
+ import os
3
+
4
  sys.path.append(os.getcwd())
5
 
6
+ from model import M2_TTS, DiT
7
 
8
  import torch
9
  import thop
10
 
11
 
12
+ """ ~155M """
13
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
14
  # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
15
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
 
17
  # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
18
  # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
19
 
20
+ """ ~335M """
21
  # FLOPs: 622.1 G, Params: 333.2 M
22
  # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
23
  # FLOPs: 363.4 G, Params: 335.8 M
24
+ transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
25
 
26
 
27
  model = M2_TTS(transformer=transformer)
 
32
  frame_length = int(duration * target_sample_rate / hop_length)
33
  text_length = 150
34
 
35
+ flops, params = thop.profile(
36
+ model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))
37
+ )
38
  print(f"FLOPs: {flops / 1e9} G")
39
  print(f"Params: {params / 1e6} M")
scripts/eval_infer_batch.py CHANGED
@@ -1,4 +1,6 @@
1
- import sys, os
 
 
2
  sys.path.append(os.getcwd())
3
 
4
  import time
@@ -14,9 +16,9 @@ from vocos import Vocos
14
  from model import CFM, UNetT, DiT
15
  from model.utils import (
16
  load_checkpoint,
17
- get_tokenizer,
18
- get_seedtts_testset_metainfo,
19
- get_librispeech_test_clean_metainfo,
20
  get_inference_prompt,
21
  )
22
 
@@ -38,16 +40,16 @@ tokenizer = "pinyin"
38
 
39
  parser = argparse.ArgumentParser(description="batch inference")
40
 
41
- parser.add_argument('-s', '--seed', default=None, type=int)
42
- parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
43
- parser.add_argument('-n', '--expname', required=True)
44
- parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
45
 
46
- parser.add_argument('-nfe', '--nfestep', default=32, type=int)
47
- parser.add_argument('-o', '--odemethod', default="euler")
48
- parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
49
 
50
- parser.add_argument('-t', '--testset', required=True)
51
 
52
  args = parser.parse_args()
53
 
@@ -66,26 +68,26 @@ testset = args.testset
66
 
67
 
68
  infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
69
- cfg_strength = 2.
70
- speed = 1.
71
  use_truth_duration = False
72
  no_ref_audio = False
73
 
74
 
75
  if exp_name == "F5TTS_Base":
76
  model_cls = DiT
77
- model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
78
 
79
  elif exp_name == "E2TTS_Base":
80
  model_cls = UNetT
81
- model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
82
 
83
 
84
  if testset == "ls_pc_test_clean":
85
  metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
86
  librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
87
  metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
88
-
89
  elif testset == "seedtts_test_zh":
90
  metalst = "data/seedtts_testset/zh/meta.lst"
91
  metainfo = get_seedtts_testset_metainfo(metalst)
@@ -96,13 +98,16 @@ elif testset == "seedtts_test_en":
96
 
97
 
98
  # path to save genereted wavs
99
- if seed is None: seed = random.randint(-10000, 10000)
100
- output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
101
- f"seed{seed}_{ode_method}_nfe{nfe_step}" \
102
- f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
103
- f"_cfg{cfg_strength}_speed{speed}" \
104
- f"{'_gt-dur' if use_truth_duration else ''}" \
 
 
105
  f"{'_no-ref-audio' if no_ref_audio else ''}"
 
106
 
107
 
108
  # -------------------------------------------------#
@@ -110,15 +115,15 @@ output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
110
  use_ema = True
111
 
112
  prompts_all = get_inference_prompt(
113
- metainfo,
114
- speed = speed,
115
- tokenizer = tokenizer,
116
- target_sample_rate = target_sample_rate,
117
- n_mel_channels = n_mel_channels,
118
- hop_length = hop_length,
119
- target_rms = target_rms,
120
- use_truth_duration = use_truth_duration,
121
- infer_batch_size = infer_batch_size,
122
  )
123
 
124
  # Vocoder model
@@ -137,23 +142,19 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
137
 
138
  # Model
139
  model = CFM(
140
- transformer = model_cls(
141
- **model_cfg,
142
- text_num_embeds = vocab_size,
143
- mel_dim = n_mel_channels
 
144
  ),
145
- mel_spec_kwargs = dict(
146
- target_sample_rate = target_sample_rate,
147
- n_mel_channels = n_mel_channels,
148
- hop_length = hop_length,
149
  ),
150
- odeint_kwargs = dict(
151
- method = ode_method,
152
- ),
153
- vocab_char_map = vocab_char_map,
154
  ).to(device)
155
 
156
- model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
157
 
158
  if not os.path.exists(output_dir) and accelerator.is_main_process:
159
  os.makedirs(output_dir)
@@ -163,29 +164,28 @@ accelerator.wait_for_everyone()
163
  start = time.time()
164
 
165
  with accelerator.split_between_processes(prompts_all) as prompts:
166
-
167
  for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
168
  utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
169
  ref_mels = ref_mels.to(device)
170
- ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
171
- total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
172
-
173
  # Inference
174
  with torch.inference_mode():
175
  generated, _ = model.sample(
176
- cond = ref_mels,
177
- text = final_text_list,
178
- duration = total_mel_lens,
179
- lens = ref_mel_lens,
180
- steps = nfe_step,
181
- cfg_strength = cfg_strength,
182
- sway_sampling_coef = sway_sampling_coef,
183
- no_ref_audio = no_ref_audio,
184
- seed = seed,
185
  )
186
  # Final result
187
  for i, gen in enumerate(generated):
188
- gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
189
  gen_mel_spec = gen.permute(0, 2, 1)
190
  generated_wave = vocos.decode(gen_mel_spec.cpu())
191
  if ref_rms_list[i] < target_rms:
 
1
+ import sys
2
+ import os
3
+
4
  sys.path.append(os.getcwd())
5
 
6
  import time
 
16
  from model import CFM, UNetT, DiT
17
  from model.utils import (
18
  load_checkpoint,
19
+ get_tokenizer,
20
+ get_seedtts_testset_metainfo,
21
+ get_librispeech_test_clean_metainfo,
22
  get_inference_prompt,
23
  )
24
 
 
40
 
41
  parser = argparse.ArgumentParser(description="batch inference")
42
 
43
+ parser.add_argument("-s", "--seed", default=None, type=int)
44
+ parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
45
+ parser.add_argument("-n", "--expname", required=True)
46
+ parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
47
 
48
+ parser.add_argument("-nfe", "--nfestep", default=32, type=int)
49
+ parser.add_argument("-o", "--odemethod", default="euler")
50
+ parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
51
 
52
+ parser.add_argument("-t", "--testset", required=True)
53
 
54
  args = parser.parse_args()
55
 
 
68
 
69
 
70
  infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
71
+ cfg_strength = 2.0
72
+ speed = 1.0
73
  use_truth_duration = False
74
  no_ref_audio = False
75
 
76
 
77
  if exp_name == "F5TTS_Base":
78
  model_cls = DiT
79
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
80
 
81
  elif exp_name == "E2TTS_Base":
82
  model_cls = UNetT
83
+ model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
84
 
85
 
86
  if testset == "ls_pc_test_clean":
87
  metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
88
  librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
89
  metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
90
+
91
  elif testset == "seedtts_test_zh":
92
  metalst = "data/seedtts_testset/zh/meta.lst"
93
  metainfo = get_seedtts_testset_metainfo(metalst)
 
98
 
99
 
100
  # path to save genereted wavs
101
+ if seed is None:
102
+ seed = random.randint(-10000, 10000)
103
+ output_dir = (
104
+ f"results/{exp_name}_{ckpt_step}/{testset}/"
105
+ f"seed{seed}_{ode_method}_nfe{nfe_step}"
106
+ f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
107
+ f"_cfg{cfg_strength}_speed{speed}"
108
+ f"{'_gt-dur' if use_truth_duration else ''}"
109
  f"{'_no-ref-audio' if no_ref_audio else ''}"
110
+ )
111
 
112
 
113
  # -------------------------------------------------#
 
115
  use_ema = True
116
 
117
  prompts_all = get_inference_prompt(
118
+ metainfo,
119
+ speed=speed,
120
+ tokenizer=tokenizer,
121
+ target_sample_rate=target_sample_rate,
122
+ n_mel_channels=n_mel_channels,
123
+ hop_length=hop_length,
124
+ target_rms=target_rms,
125
+ use_truth_duration=use_truth_duration,
126
+ infer_batch_size=infer_batch_size,
127
  )
128
 
129
  # Vocoder model
 
142
 
143
  # Model
144
  model = CFM(
145
+ transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
146
+ mel_spec_kwargs=dict(
147
+ target_sample_rate=target_sample_rate,
148
+ n_mel_channels=n_mel_channels,
149
+ hop_length=hop_length,
150
  ),
151
+ odeint_kwargs=dict(
152
+ method=ode_method,
 
 
153
  ),
154
+ vocab_char_map=vocab_char_map,
 
 
 
155
  ).to(device)
156
 
157
+ model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
158
 
159
  if not os.path.exists(output_dir) and accelerator.is_main_process:
160
  os.makedirs(output_dir)
 
164
  start = time.time()
165
 
166
  with accelerator.split_between_processes(prompts_all) as prompts:
 
167
  for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
168
  utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
169
  ref_mels = ref_mels.to(device)
170
+ ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
171
+ total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
172
+
173
  # Inference
174
  with torch.inference_mode():
175
  generated, _ = model.sample(
176
+ cond=ref_mels,
177
+ text=final_text_list,
178
+ duration=total_mel_lens,
179
+ lens=ref_mel_lens,
180
+ steps=nfe_step,
181
+ cfg_strength=cfg_strength,
182
+ sway_sampling_coef=sway_sampling_coef,
183
+ no_ref_audio=no_ref_audio,
184
+ seed=seed,
185
  )
186
  # Final result
187
  for i, gen in enumerate(generated):
188
+ gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
189
  gen_mel_spec = gen.permute(0, 2, 1)
190
  generated_wave = vocos.decode(gen_mel_spec.cpu())
191
  if ref_rms_list[i] < target_rms:
scripts/eval_librispeech_test_clean.py CHANGED
@@ -1,6 +1,8 @@
1
  # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
2
 
3
- import sys, os
 
 
4
  sys.path.append(os.getcwd())
5
 
6
  import multiprocessing as mp
@@ -19,7 +21,7 @@ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
19
  librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
20
  gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
21
 
22
- gpus = [0,1,2,3,4,5,6,7]
23
  test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
24
 
25
  ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
@@ -46,7 +48,7 @@ if eval_task == "wer":
46
  for wers_ in results:
47
  wers.extend(wers_)
48
 
49
- wer = round(np.mean(wers)*100, 3)
50
  print(f"\nTotal {len(wers)} samples")
51
  print(f"WER : {wer}%")
52
 
@@ -62,6 +64,6 @@ if eval_task == "sim":
62
  for sim_ in results:
63
  sim_list.extend(sim_)
64
 
65
- sim = round(sum(sim_list)/len(sim_list), 3)
66
  print(f"\nTotal {len(sim_list)} samples")
67
  print(f"SIM : {sim}")
 
1
  # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
2
 
3
+ import sys
4
+ import os
5
+
6
  sys.path.append(os.getcwd())
7
 
8
  import multiprocessing as mp
 
21
  librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
22
  gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
23
 
24
+ gpus = [0, 1, 2, 3, 4, 5, 6, 7]
25
  test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
26
 
27
  ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
 
48
  for wers_ in results:
49
  wers.extend(wers_)
50
 
51
+ wer = round(np.mean(wers) * 100, 3)
52
  print(f"\nTotal {len(wers)} samples")
53
  print(f"WER : {wer}%")
54
 
 
64
  for sim_ in results:
65
  sim_list.extend(sim_)
66
 
67
+ sim = round(sum(sim_list) / len(sim_list), 3)
68
  print(f"\nTotal {len(sim_list)} samples")
69
  print(f"SIM : {sim}")
scripts/eval_seedtts_testset.py CHANGED
@@ -1,6 +1,8 @@
1
  # Evaluate with Seed-TTS testset
2
 
3
- import sys, os
 
 
4
  sys.path.append(os.getcwd())
5
 
6
  import multiprocessing as mp
@@ -14,21 +16,21 @@ from model.utils import (
14
 
15
 
16
  eval_task = "wer" # sim | wer
17
- lang = "zh" # zh | en
18
  metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
19
  # gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
20
- gen_wav_dir = f"PATH_TO_GENERATED" # generated wavs
21
 
22
 
23
  # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
24
- # zh 1.254 seems a result of 4 workers wer_seed_tts
25
- gpus = [0,1,2,3,4,5,6,7]
26
  test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
27
 
28
  local = False
29
  if local: # use local custom checkpoint dir
30
  if lang == "zh":
31
- asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
32
  elif lang == "en":
33
  asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
34
  else:
@@ -48,7 +50,7 @@ if eval_task == "wer":
48
  for wers_ in results:
49
  wers.extend(wers_)
50
 
51
- wer = round(np.mean(wers)*100, 3)
52
  print(f"\nTotal {len(wers)} samples")
53
  print(f"WER : {wer}%")
54
 
@@ -64,6 +66,6 @@ if eval_task == "sim":
64
  for sim_ in results:
65
  sim_list.extend(sim_)
66
 
67
- sim = round(sum(sim_list)/len(sim_list), 3)
68
  print(f"\nTotal {len(sim_list)} samples")
69
  print(f"SIM : {sim}")
 
1
  # Evaluate with Seed-TTS testset
2
 
3
+ import sys
4
+ import os
5
+
6
  sys.path.append(os.getcwd())
7
 
8
  import multiprocessing as mp
 
16
 
17
 
18
  eval_task = "wer" # sim | wer
19
+ lang = "zh" # zh | en
20
  metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
21
  # gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
22
+ gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
23
 
24
 
25
  # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
26
+ # zh 1.254 seems a result of 4 workers wer_seed_tts
27
+ gpus = [0, 1, 2, 3, 4, 5, 6, 7]
28
  test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
29
 
30
  local = False
31
  if local: # use local custom checkpoint dir
32
  if lang == "zh":
33
+ asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
34
  elif lang == "en":
35
  asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
36
  else:
 
50
  for wers_ in results:
51
  wers.extend(wers_)
52
 
53
+ wer = round(np.mean(wers) * 100, 3)
54
  print(f"\nTotal {len(wers)} samples")
55
  print(f"WER : {wer}%")
56
 
 
66
  for sim_ in results:
67
  sim_list.extend(sim_)
68
 
69
+ sim = round(sum(sim_list) / len(sim_list), 3)
70
  print(f"\nTotal {len(sim_list)} samples")
71
  print(f"SIM : {sim}")
scripts/prepare_csv_wavs.py CHANGED
@@ -1,4 +1,6 @@
1
- import sys, os
 
 
2
  sys.path.append(os.getcwd())
3
 
4
  from pathlib import Path
@@ -17,10 +19,11 @@ from model.utils import (
17
 
18
  PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
19
 
 
20
  def is_csv_wavs_format(input_dataset_dir):
21
  fpath = Path(input_dataset_dir)
22
  metadata = fpath / "metadata.csv"
23
- wavs = fpath / 'wavs'
24
  return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
25
 
26
 
@@ -46,22 +49,24 @@ def prepare_csv_wavs_dir(input_dir):
46
 
47
  return sub_result, durations, vocab_set
48
 
 
49
  def get_audio_duration(audio_path):
50
  audio, sample_rate = torchaudio.load(audio_path)
51
  num_channels = audio.shape[0]
52
  return audio.shape[1] / (sample_rate * num_channels)
53
 
 
54
  def read_audio_text_pairs(csv_file_path):
55
  audio_text_pairs = []
56
 
57
  parent = Path(csv_file_path).parent
58
- with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csvfile:
59
- reader = csv.reader(csvfile, delimiter='|')
60
  next(reader) # Skip the header row
61
  for row in reader:
62
  if len(row) >= 2:
63
  audio_file = row[0].strip() # First column: audio file path
64
- text = row[1].strip() # Second column: text
65
  audio_file_path = parent / audio_file
66
  audio_text_pairs.append((audio_file_path.as_posix(), text))
67
 
@@ -78,12 +83,12 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
78
  # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
79
  raw_arrow_path = out_dir / "raw.arrow"
80
  with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
81
- for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
82
  writer.write(line)
83
 
84
  # dup a json separately saving duration in case for DynamicBatchSampler ease
85
  dur_json_path = out_dir / "duration.json"
86
- with open(dur_json_path.as_posix(), 'w', encoding='utf-8') as f:
87
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
88
 
89
  # vocab map, i.e. tokenizer
@@ -120,13 +125,14 @@ def cli():
120
  # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
121
  # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
122
  parser = argparse.ArgumentParser(description="Prepare and save dataset.")
123
- parser.add_argument('inp_dir', type=str, help="Input directory containing the data.")
124
- parser.add_argument('out_dir', type=str, help="Output directory to save the prepared data.")
125
- parser.add_argument('--pretrain', action='store_true', help="Enable for new pretrain, otherwise is a fine-tune")
126
 
127
  args = parser.parse_args()
128
 
129
  prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
130
 
 
131
  if __name__ == "__main__":
132
  cli()
 
1
+ import sys
2
+ import os
3
+
4
  sys.path.append(os.getcwd())
5
 
6
  from pathlib import Path
 
19
 
20
  PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
21
 
22
+
23
  def is_csv_wavs_format(input_dataset_dir):
24
  fpath = Path(input_dataset_dir)
25
  metadata = fpath / "metadata.csv"
26
+ wavs = fpath / "wavs"
27
  return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
28
 
29
 
 
49
 
50
  return sub_result, durations, vocab_set
51
 
52
+
53
  def get_audio_duration(audio_path):
54
  audio, sample_rate = torchaudio.load(audio_path)
55
  num_channels = audio.shape[0]
56
  return audio.shape[1] / (sample_rate * num_channels)
57
 
58
+
59
  def read_audio_text_pairs(csv_file_path):
60
  audio_text_pairs = []
61
 
62
  parent = Path(csv_file_path).parent
63
+ with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
64
+ reader = csv.reader(csvfile, delimiter="|")
65
  next(reader) # Skip the header row
66
  for row in reader:
67
  if len(row) >= 2:
68
  audio_file = row[0].strip() # First column: audio file path
69
+ text = row[1].strip() # Second column: text
70
  audio_file_path = parent / audio_file
71
  audio_text_pairs.append((audio_file_path.as_posix(), text))
72
 
 
83
  # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
84
  raw_arrow_path = out_dir / "raw.arrow"
85
  with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
86
+ for line in tqdm(result, desc="Writing to raw.arrow ..."):
87
  writer.write(line)
88
 
89
  # dup a json separately saving duration in case for DynamicBatchSampler ease
90
  dur_json_path = out_dir / "duration.json"
91
+ with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
92
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
93
 
94
  # vocab map, i.e. tokenizer
 
125
  # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
126
  # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
127
  parser = argparse.ArgumentParser(description="Prepare and save dataset.")
128
+ parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
129
+ parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
130
+ parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
131
 
132
  args = parser.parse_args()
133
 
134
  prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
135
 
136
+
137
  if __name__ == "__main__":
138
  cli()
scripts/prepare_emilia.py CHANGED
@@ -4,7 +4,9 @@
4
  # generate audio text map for Emilia ZH & EN
5
  # evaluate for vocab size
6
 
7
- import sys, os
 
 
8
  sys.path.append(os.getcwd())
9
 
10
  from pathlib import Path
@@ -12,7 +14,6 @@ import json
12
  from tqdm import tqdm
13
  from concurrent.futures import ProcessPoolExecutor
14
 
15
- from datasets import Dataset
16
  from datasets.arrow_writer import ArrowWriter
17
 
18
  from model.utils import (
@@ -21,13 +22,89 @@ from model.utils import (
21
  )
22
 
23
 
24
- out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
 
 
 
 
 
 
 
25
  zh_filters = ["い", "て"]
26
  # seems synthesized audios, or heavily code-switched
27
  out_en = {
28
- "EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",
29
-
30
- "EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  }
32
  en_filters = ["ا", "い", "て"]
33
 
@@ -43,18 +120,24 @@ def deal_with_audio_dir(audio_dir):
43
  for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
44
  obj = json.loads(line)
45
  text = obj["text"]
46
- if obj['language'] == "zh":
47
  if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
48
  bad_case_zh += 1
49
  continue
50
  else:
51
- text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'})) # not "。" cuz much code-switched
52
- if obj['language'] == "en":
53
- if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
 
 
 
 
 
 
54
  bad_case_en += 1
55
  continue
56
  if tokenizer == "pinyin":
57
- text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
58
  duration = obj["duration"]
59
  sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
60
  durations.append(duration)
@@ -96,11 +179,11 @@ def main():
96
  # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
97
  # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
98
  with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
99
- for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
100
  writer.write(line)
101
 
102
  # dup a json separately saving duration in case for DynamicBatchSampler ease
103
- with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
104
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
105
 
106
  # vocab map, i.e. tokenizer
@@ -114,12 +197,13 @@ def main():
114
  print(f"\nFor {dataset_name}, sample count: {len(result)}")
115
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
116
  print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
117
- if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
118
- if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
 
 
119
 
120
 
121
  if __name__ == "__main__":
122
-
123
  max_workers = 32
124
 
125
  tokenizer = "pinyin" # "pinyin" | "char"
 
4
  # generate audio text map for Emilia ZH & EN
5
  # evaluate for vocab size
6
 
7
+ import sys
8
+ import os
9
+
10
  sys.path.append(os.getcwd())
11
 
12
  from pathlib import Path
 
14
  from tqdm import tqdm
15
  from concurrent.futures import ProcessPoolExecutor
16
 
 
17
  from datasets.arrow_writer import ArrowWriter
18
 
19
  from model.utils import (
 
22
  )
23
 
24
 
25
+ out_zh = {
26
+ "ZH_B00041_S06226",
27
+ "ZH_B00042_S09204",
28
+ "ZH_B00065_S09430",
29
+ "ZH_B00065_S09431",
30
+ "ZH_B00066_S09327",
31
+ "ZH_B00066_S09328",
32
+ }
33
  zh_filters = ["い", "て"]
34
  # seems synthesized audios, or heavily code-switched
35
  out_en = {
36
+ "EN_B00013_S00913",
37
+ "EN_B00042_S00120",
38
+ "EN_B00055_S04111",
39
+ "EN_B00061_S00693",
40
+ "EN_B00061_S01494",
41
+ "EN_B00061_S03375",
42
+ "EN_B00059_S00092",
43
+ "EN_B00111_S04300",
44
+ "EN_B00100_S03759",
45
+ "EN_B00087_S03811",
46
+ "EN_B00059_S00950",
47
+ "EN_B00089_S00946",
48
+ "EN_B00078_S05127",
49
+ "EN_B00070_S04089",
50
+ "EN_B00074_S09659",
51
+ "EN_B00061_S06983",
52
+ "EN_B00061_S07060",
53
+ "EN_B00059_S08397",
54
+ "EN_B00082_S06192",
55
+ "EN_B00091_S01238",
56
+ "EN_B00089_S07349",
57
+ "EN_B00070_S04343",
58
+ "EN_B00061_S02400",
59
+ "EN_B00076_S01262",
60
+ "EN_B00068_S06467",
61
+ "EN_B00076_S02943",
62
+ "EN_B00064_S05954",
63
+ "EN_B00061_S05386",
64
+ "EN_B00066_S06544",
65
+ "EN_B00076_S06944",
66
+ "EN_B00072_S08620",
67
+ "EN_B00076_S07135",
68
+ "EN_B00076_S09127",
69
+ "EN_B00065_S00497",
70
+ "EN_B00059_S06227",
71
+ "EN_B00063_S02859",
72
+ "EN_B00075_S01547",
73
+ "EN_B00061_S08286",
74
+ "EN_B00079_S02901",
75
+ "EN_B00092_S03643",
76
+ "EN_B00096_S08653",
77
+ "EN_B00063_S04297",
78
+ "EN_B00063_S04614",
79
+ "EN_B00079_S04698",
80
+ "EN_B00104_S01666",
81
+ "EN_B00061_S09504",
82
+ "EN_B00061_S09694",
83
+ "EN_B00065_S05444",
84
+ "EN_B00063_S06860",
85
+ "EN_B00065_S05725",
86
+ "EN_B00069_S07628",
87
+ "EN_B00083_S03875",
88
+ "EN_B00071_S07665",
89
+ "EN_B00071_S07665",
90
+ "EN_B00062_S04187",
91
+ "EN_B00065_S09873",
92
+ "EN_B00065_S09922",
93
+ "EN_B00084_S02463",
94
+ "EN_B00067_S05066",
95
+ "EN_B00106_S08060",
96
+ "EN_B00073_S06399",
97
+ "EN_B00073_S09236",
98
+ "EN_B00087_S00432",
99
+ "EN_B00085_S05618",
100
+ "EN_B00064_S01262",
101
+ "EN_B00072_S01739",
102
+ "EN_B00059_S03913",
103
+ "EN_B00069_S04036",
104
+ "EN_B00067_S05623",
105
+ "EN_B00060_S05389",
106
+ "EN_B00060_S07290",
107
+ "EN_B00062_S08995",
108
  }
109
  en_filters = ["ا", "い", "て"]
110
 
 
120
  for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
121
  obj = json.loads(line)
122
  text = obj["text"]
123
+ if obj["language"] == "zh":
124
  if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
125
  bad_case_zh += 1
126
  continue
127
  else:
128
+ text = text.translate(
129
+ str.maketrans({",": ",", "!": "", "?": "?"})
130
+ ) # not "" cuz much code-switched
131
+ if obj["language"] == "en":
132
+ if (
133
+ obj["wav"].split("/")[1] in out_en
134
+ or any(f in text for f in en_filters)
135
+ or repetition_found(text, length=4)
136
+ ):
137
  bad_case_en += 1
138
  continue
139
  if tokenizer == "pinyin":
140
+ text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
141
  duration = obj["duration"]
142
  sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
143
  durations.append(duration)
 
179
  # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
180
  # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
181
  with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
182
+ for line in tqdm(result, desc="Writing to raw.arrow ..."):
183
  writer.write(line)
184
 
185
  # dup a json separately saving duration in case for DynamicBatchSampler ease
186
+ with open(f"data/{dataset_name}/duration.json", "w", encoding="utf-8") as f:
187
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
188
 
189
  # vocab map, i.e. tokenizer
 
197
  print(f"\nFor {dataset_name}, sample count: {len(result)}")
198
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
199
  print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
200
+ if "ZH" in langs:
201
+ print(f"Bad zh transcription case: {total_bad_case_zh}")
202
+ if "EN" in langs:
203
+ print(f"Bad en transcription case: {total_bad_case_en}\n")
204
 
205
 
206
  if __name__ == "__main__":
 
207
  max_workers = 32
208
 
209
  tokenizer = "pinyin" # "pinyin" | "char"
scripts/prepare_wenetspeech4tts.py CHANGED
@@ -1,7 +1,9 @@
1
  # generate audio text map for WenetSpeech4TTS
2
  # evaluate for vocab size
3
 
4
- import sys, os
 
 
5
  sys.path.append(os.getcwd())
6
 
7
  import json
@@ -23,7 +25,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):
23
 
24
  audio_paths, texts, durations = [], [], []
25
  for text_file in tqdm(text_files):
26
- with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
27
  first_line = file.readline().split("\t")
28
  audio_nm = first_line[0]
29
  audio_path = os.path.join(audio_dir, audio_nm + ".wav")
@@ -32,7 +34,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):
32
  audio_paths.append(audio_path)
33
 
34
  if tokenizer == "pinyin":
35
- texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
36
  elif tokenizer == "char":
37
  texts.append(text)
38
 
@@ -46,7 +48,7 @@ def main():
46
  assert tokenizer in ["pinyin", "char"]
47
 
48
  audio_path_list, text_list, duration_list = [], [], []
49
-
50
  executor = ProcessPoolExecutor(max_workers=max_workers)
51
  futures = []
52
  for dataset_path in dataset_paths:
@@ -68,8 +70,10 @@ def main():
68
  dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
69
  dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
70
 
71
- with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
72
- json.dump({"duration": duration_list}, f, ensure_ascii=False) # dup a json separately saving duration in case for DynamicBatchSampler ease
 
 
73
 
74
  print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
75
  text_vocab_set = set()
@@ -85,22 +89,21 @@ def main():
85
  f.write(vocab + "\n")
86
  print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
87
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
88
-
89
 
90
- if __name__ == "__main__":
91
 
 
92
  max_workers = 32
93
 
94
  tokenizer = "pinyin" # "pinyin" | "char"
95
  polyphone = True
96
  dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
97
 
98
- dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
99
  dataset_paths = [
100
  "<SOME_PATH>/WenetSpeech4TTS/Basic",
101
  "<SOME_PATH>/WenetSpeech4TTS/Standard",
102
  "<SOME_PATH>/WenetSpeech4TTS/Premium",
103
- ][-dataset_choice:]
104
  print(f"\nChoose Dataset: {dataset_name}\n")
105
 
106
  main()
@@ -109,8 +112,8 @@ if __name__ == "__main__":
109
  # WenetSpeech4TTS Basic Standard Premium
110
  # samples count 3932473 1941220 407494
111
  # pinyin vocab size 1349 1348 1344 (no polyphone)
112
- # - - 1459 (polyphone)
113
  # char vocab size 5264 5219 5042
114
-
115
  # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
116
  # please be careful if using pretrained model, make sure the vocab.txt is same
 
1
  # generate audio text map for WenetSpeech4TTS
2
  # evaluate for vocab size
3
 
4
+ import sys
5
+ import os
6
+
7
  sys.path.append(os.getcwd())
8
 
9
  import json
 
25
 
26
  audio_paths, texts, durations = [], [], []
27
  for text_file in tqdm(text_files):
28
+ with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file:
29
  first_line = file.readline().split("\t")
30
  audio_nm = first_line[0]
31
  audio_path = os.path.join(audio_dir, audio_nm + ".wav")
 
34
  audio_paths.append(audio_path)
35
 
36
  if tokenizer == "pinyin":
37
+ texts.extend(convert_char_to_pinyin([text], polyphone=polyphone))
38
  elif tokenizer == "char":
39
  texts.append(text)
40
 
 
48
  assert tokenizer in ["pinyin", "char"]
49
 
50
  audio_path_list, text_list, duration_list = [], [], []
51
+
52
  executor = ProcessPoolExecutor(max_workers=max_workers)
53
  futures = []
54
  for dataset_path in dataset_paths:
 
70
  dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
71
  dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
72
 
73
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", "w", encoding="utf-8") as f:
74
+ json.dump(
75
+ {"duration": duration_list}, f, ensure_ascii=False
76
+ ) # dup a json separately saving duration in case for DynamicBatchSampler ease
77
 
78
  print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
79
  text_vocab_set = set()
 
89
  f.write(vocab + "\n")
90
  print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
91
  print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
 
92
 
 
93
 
94
+ if __name__ == "__main__":
95
  max_workers = 32
96
 
97
  tokenizer = "pinyin" # "pinyin" | "char"
98
  polyphone = True
99
  dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
100
 
101
+ dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
102
  dataset_paths = [
103
  "<SOME_PATH>/WenetSpeech4TTS/Basic",
104
  "<SOME_PATH>/WenetSpeech4TTS/Standard",
105
  "<SOME_PATH>/WenetSpeech4TTS/Premium",
106
+ ][-dataset_choice:]
107
  print(f"\nChoose Dataset: {dataset_name}\n")
108
 
109
  main()
 
112
  # WenetSpeech4TTS Basic Standard Premium
113
  # samples count 3932473 1941220 407494
114
  # pinyin vocab size 1349 1348 1344 (no polyphone)
115
+ # - - 1459 (polyphone)
116
  # char vocab size 5264 5219 5042
117
+
118
  # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
119
  # please be careful if using pretrained model, make sure the vocab.txt is same
speech_edit.py CHANGED
@@ -5,11 +5,11 @@ import torch.nn.functional as F
5
  import torchaudio
6
  from vocos import Vocos
7
 
8
- from model import CFM, UNetT, DiT, MMDiT
9
  from model.utils import (
10
  load_checkpoint,
11
- get_tokenizer,
12
- convert_char_to_pinyin,
13
  save_spectrogram,
14
  )
15
 
@@ -35,18 +35,18 @@ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
35
  ckpt_step = 1200000
36
 
37
  nfe_step = 32 # 16, 32
38
- cfg_strength = 2.
39
- ode_method = 'euler' # euler | midpoint
40
- sway_sampling_coef = -1.
41
- speed = 1.
42
 
43
  if exp_name == "F5TTS_Base":
44
  model_cls = DiT
45
- model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
46
 
47
  elif exp_name == "E2TTS_Base":
48
  model_cls = UNetT
49
- model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
50
 
51
  ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
52
  output_dir = "tests"
@@ -62,8 +62,14 @@ output_dir = "tests"
62
  audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
63
  origin_text = "Some call me nature, others call me mother nature."
64
  target_text = "Some call me optimist, others call me realist."
65
- parts_to_edit = [[1.42, 2.44], [4.04, 4.9], ] # stard_ends of "nature" & "mother nature", in seconds
66
- fix_duration = [1.2, 1, ] # fix duration for "optimist" & "realist", in seconds
 
 
 
 
 
 
67
 
68
  # audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
69
  # origin_text = "对,这就是我,万人敬仰的太乙真人。"
@@ -86,7 +92,7 @@ if local:
86
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
87
  state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
88
  vocos.load_state_dict(state_dict)
89
-
90
  vocos.eval()
91
  else:
92
  vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
@@ -96,23 +102,19 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
96
 
97
  # Model
98
  model = CFM(
99
- transformer = model_cls(
100
- **model_cfg,
101
- text_num_embeds = vocab_size,
102
- mel_dim = n_mel_channels
103
- ),
104
- mel_spec_kwargs = dict(
105
- target_sample_rate = target_sample_rate,
106
- n_mel_channels = n_mel_channels,
107
- hop_length = hop_length,
108
  ),
109
- odeint_kwargs = dict(
110
- method = ode_method,
111
  ),
112
- vocab_char_map = vocab_char_map,
113
  ).to(device)
114
 
115
- model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
116
 
117
  # Audio
118
  audio, sr = torchaudio.load(audio_to_edit)
@@ -132,14 +134,18 @@ for part in parts_to_edit:
132
  part_dur = end - start if fix_duration is None else fix_duration.pop(0)
133
  part_dur = part_dur * target_sample_rate
134
  start = start * target_sample_rate
135
- audio_ = torch.cat((audio_, audio[:, round(offset):round(start)], torch.zeros(1, round(part_dur))), dim = -1)
136
- edit_mask = torch.cat((edit_mask,
137
- torch.ones(1, round((start - offset) / hop_length), dtype = torch.bool),
138
- torch.zeros(1, round(part_dur / hop_length), dtype = torch.bool)
139
- ), dim = -1)
 
 
 
 
140
  offset = end * target_sample_rate
141
  # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
142
- edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value = True)
143
  audio = audio.to(device)
144
  edit_mask = edit_mask.to(device)
145
 
@@ -159,14 +165,14 @@ duration = audio.shape[-1] // hop_length
159
  # Inference
160
  with torch.inference_mode():
161
  generated, trajectory = model.sample(
162
- cond = audio,
163
- text = final_text_list,
164
- duration = duration,
165
- steps = nfe_step,
166
- cfg_strength = cfg_strength,
167
- sway_sampling_coef = sway_sampling_coef,
168
- seed = seed,
169
- edit_mask = edit_mask,
170
  )
171
  print(f"Generated mel: {generated.shape}")
172
 
 
5
  import torchaudio
6
  from vocos import Vocos
7
 
8
+ from model import CFM, UNetT, DiT
9
  from model.utils import (
10
  load_checkpoint,
11
+ get_tokenizer,
12
+ convert_char_to_pinyin,
13
  save_spectrogram,
14
  )
15
 
 
35
  ckpt_step = 1200000
36
 
37
  nfe_step = 32 # 16, 32
38
+ cfg_strength = 2.0
39
+ ode_method = "euler" # euler | midpoint
40
+ sway_sampling_coef = -1.0
41
+ speed = 1.0
42
 
43
  if exp_name == "F5TTS_Base":
44
  model_cls = DiT
45
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
46
 
47
  elif exp_name == "E2TTS_Base":
48
  model_cls = UNetT
49
+ model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
50
 
51
  ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
52
  output_dir = "tests"
 
62
  audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
63
  origin_text = "Some call me nature, others call me mother nature."
64
  target_text = "Some call me optimist, others call me realist."
65
+ parts_to_edit = [
66
+ [1.42, 2.44],
67
+ [4.04, 4.9],
68
+ ] # stard_ends of "nature" & "mother nature", in seconds
69
+ fix_duration = [
70
+ 1.2,
71
+ 1,
72
+ ] # fix duration for "optimist" & "realist", in seconds
73
 
74
  # audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
75
  # origin_text = "对,这就是我,万人敬仰的太乙真人。"
 
92
  vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
93
  state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
94
  vocos.load_state_dict(state_dict)
95
+
96
  vocos.eval()
97
  else:
98
  vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
102
 
103
  # Model
104
  model = CFM(
105
+ transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
106
+ mel_spec_kwargs=dict(
107
+ target_sample_rate=target_sample_rate,
108
+ n_mel_channels=n_mel_channels,
109
+ hop_length=hop_length,
 
 
 
 
110
  ),
111
+ odeint_kwargs=dict(
112
+ method=ode_method,
113
  ),
114
+ vocab_char_map=vocab_char_map,
115
  ).to(device)
116
 
117
+ model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
118
 
119
  # Audio
120
  audio, sr = torchaudio.load(audio_to_edit)
 
134
  part_dur = end - start if fix_duration is None else fix_duration.pop(0)
135
  part_dur = part_dur * target_sample_rate
136
  start = start * target_sample_rate
137
+ audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
138
+ edit_mask = torch.cat(
139
+ (
140
+ edit_mask,
141
+ torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
142
+ torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
143
+ ),
144
+ dim=-1,
145
+ )
146
  offset = end * target_sample_rate
147
  # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
148
+ edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
149
  audio = audio.to(device)
150
  edit_mask = edit_mask.to(device)
151
 
 
165
  # Inference
166
  with torch.inference_mode():
167
  generated, trajectory = model.sample(
168
+ cond=audio,
169
+ text=final_text_list,
170
+ duration=duration,
171
+ steps=nfe_step,
172
+ cfg_strength=cfg_strength,
173
+ sway_sampling_coef=sway_sampling_coef,
174
+ seed=seed,
175
+ edit_mask=edit_mask,
176
  )
177
  print(f"Generated mel: {generated.shape}")
178
 
train.py CHANGED
@@ -1,4 +1,4 @@
1
- from model import CFM, UNetT, DiT, MMDiT, Trainer
2
  from model.utils import get_tokenizer
3
  from model.dataset import load_dataset
4
 
@@ -9,8 +9,8 @@ target_sample_rate = 24000
9
  n_mel_channels = 100
10
  hop_length = 256
11
 
12
- tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
13
- tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
14
  dataset_name = "Emilia_ZH_EN"
15
 
16
  # -------------------------- Training Settings -------------------------- #
@@ -23,7 +23,7 @@ batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200
23
  batch_size_type = "frame" # "frame" or "sample"
24
  max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
25
  grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
26
- max_grad_norm = 1.
27
 
28
  epochs = 11 # use linear decay, thus epochs control the slope
29
  num_warmup_updates = 20000 # warmup steps
@@ -34,15 +34,16 @@ last_per_steps = 5000 # save last checkpoint per steps
34
  if exp_name == "F5TTS_Base":
35
  wandb_resume_id = None
36
  model_cls = DiT
37
- model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
38
  elif exp_name == "E2TTS_Base":
39
  wandb_resume_id = None
40
  model_cls = UNetT
41
- model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
42
 
43
 
44
  # ----------------------------------------------------------------------- #
45
 
 
46
  def main():
47
  if tokenizer == "custom":
48
  tokenizer_path = tokenizer_path
@@ -51,44 +52,41 @@ def main():
51
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
52
 
53
  mel_spec_kwargs = dict(
54
- target_sample_rate = target_sample_rate,
55
- n_mel_channels = n_mel_channels,
56
- hop_length = hop_length,
57
- )
58
-
59
  model = CFM(
60
- transformer = model_cls(
61
- **model_cfg,
62
- text_num_embeds = vocab_size,
63
- mel_dim = n_mel_channels
64
- ),
65
- mel_spec_kwargs = mel_spec_kwargs,
66
- vocab_char_map = vocab_char_map,
67
  )
68
 
69
  trainer = Trainer(
70
  model,
71
- epochs,
72
  learning_rate,
73
- num_warmup_updates = num_warmup_updates,
74
- save_per_updates = save_per_updates,
75
- checkpoint_path = f'ckpts/{exp_name}',
76
- batch_size = batch_size_per_gpu,
77
- batch_size_type = batch_size_type,
78
- max_samples = max_samples,
79
- grad_accumulation_steps = grad_accumulation_steps,
80
- max_grad_norm = max_grad_norm,
81
- wandb_project = "CFM-TTS",
82
- wandb_run_name = exp_name,
83
- wandb_resume_id = wandb_resume_id,
84
- last_per_steps = last_per_steps,
85
  )
86
 
87
  train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
88
- trainer.train(train_dataset,
89
- resumable_with_seed = 666 # seed for shuffling dataset
90
- )
 
91
 
92
 
93
- if __name__ == '__main__':
94
  main()
 
1
+ from model import CFM, UNetT, DiT, Trainer
2
  from model.utils import get_tokenizer
3
  from model.dataset import load_dataset
4
 
 
9
  n_mel_channels = 100
10
  hop_length = 256
11
 
12
+ tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
13
+ tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
14
  dataset_name = "Emilia_ZH_EN"
15
 
16
  # -------------------------- Training Settings -------------------------- #
 
23
  batch_size_type = "frame" # "frame" or "sample"
24
  max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
25
  grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
26
+ max_grad_norm = 1.0
27
 
28
  epochs = 11 # use linear decay, thus epochs control the slope
29
  num_warmup_updates = 20000 # warmup steps
 
34
  if exp_name == "F5TTS_Base":
35
  wandb_resume_id = None
36
  model_cls = DiT
37
+ model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
38
  elif exp_name == "E2TTS_Base":
39
  wandb_resume_id = None
40
  model_cls = UNetT
41
+ model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
42
 
43
 
44
  # ----------------------------------------------------------------------- #
45
 
46
+
47
  def main():
48
  if tokenizer == "custom":
49
  tokenizer_path = tokenizer_path
 
52
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
53
 
54
  mel_spec_kwargs = dict(
55
+ target_sample_rate=target_sample_rate,
56
+ n_mel_channels=n_mel_channels,
57
+ hop_length=hop_length,
58
+ )
59
+
60
  model = CFM(
61
+ transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
62
+ mel_spec_kwargs=mel_spec_kwargs,
63
+ vocab_char_map=vocab_char_map,
 
 
 
 
64
  )
65
 
66
  trainer = Trainer(
67
  model,
68
+ epochs,
69
  learning_rate,
70
+ num_warmup_updates=num_warmup_updates,
71
+ save_per_updates=save_per_updates,
72
+ checkpoint_path=f"ckpts/{exp_name}",
73
+ batch_size=batch_size_per_gpu,
74
+ batch_size_type=batch_size_type,
75
+ max_samples=max_samples,
76
+ grad_accumulation_steps=grad_accumulation_steps,
77
+ max_grad_norm=max_grad_norm,
78
+ wandb_project="CFM-TTS",
79
+ wandb_run_name=exp_name,
80
+ wandb_resume_id=wandb_resume_id,
81
+ last_per_steps=last_per_steps,
82
  )
83
 
84
  train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
85
+ trainer.train(
86
+ train_dataset,
87
+ resumable_with_seed=666, # seed for shuffling dataset
88
+ )
89
 
90
 
91
+ if __name__ == "__main__":
92
  main()