mrfakename commited on
Commit
b7a138c
·
verified ·
1 Parent(s): 95ed3d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -8
app.py CHANGED
@@ -34,12 +34,13 @@ DEFAULT_TTS_MODEL_CFG = [
34
 
35
  # Load vocoder and models on module load
36
  vocoder = load_vocoder()
37
- F5TTS_ema_model = load_model(
 
38
  DiT,
39
  json.loads(DEFAULT_TTS_MODEL_CFG[2]),
40
  str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
41
  )
42
- E2TTS_ema_model = load_model(
43
  UNetT,
44
  dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1),
45
  str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
@@ -72,11 +73,6 @@ with gr.Blocks() as app:
72
  audio_output = gr.Audio(label="Synthesized Audio")
73
  spectrogram_output = gr.Image(label="Spectrogram")
74
 
75
- model_cache = {
76
- DEFAULT_TTS_MODEL: F5TTS_ema_model,
77
- "E2-TTS": E2TTS_ema_model
78
- }
79
-
80
  @gpu_decorator
81
  def infer(
82
  ref_audio_orig,
@@ -119,7 +115,7 @@ with gr.Blocks() as app:
119
  pre_custom_path = model[1]
120
  ema_model = custom_ema_model
121
  else:
122
- ema_model = model_cache.get(model, F5TTS_ema_model)
123
 
124
  final_wave, final_sample_rate, combined_spectrogram = infer_process(
125
  ref_audio,
 
34
 
35
  # Load vocoder and models on module load
36
  vocoder = load_vocoder()
37
+ model_cache = {}
38
+ model_cache[DEFAULT_TTS_MODEL] = load_model(
39
  DiT,
40
  json.loads(DEFAULT_TTS_MODEL_CFG[2]),
41
  str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
42
  )
43
+ model_cache["E2-TTS"] = load_model(
44
  UNetT,
45
  dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1),
46
  str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
 
73
  audio_output = gr.Audio(label="Synthesized Audio")
74
  spectrogram_output = gr.Image(label="Spectrogram")
75
 
 
 
 
 
 
76
  @gpu_decorator
77
  def infer(
78
  ref_audio_orig,
 
115
  pre_custom_path = model[1]
116
  ema_model = custom_ema_model
117
  else:
118
+ ema_model = model_cache.get(model, model_cache[DEFAULT_TTS_MODEL])
119
 
120
  final_wave, final_sample_rate, combined_spectrogram = infer_process(
121
  ref_audio,