Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -34,12 +34,13 @@ DEFAULT_TTS_MODEL_CFG = [
|
|
| 34 |
|
| 35 |
# Load vocoder and models on module load
|
| 36 |
vocoder = load_vocoder()
|
| 37 |
-
|
|
|
|
| 38 |
DiT,
|
| 39 |
json.loads(DEFAULT_TTS_MODEL_CFG[2]),
|
| 40 |
str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
|
| 41 |
)
|
| 42 |
-
|
| 43 |
UNetT,
|
| 44 |
dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1),
|
| 45 |
str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
|
|
@@ -72,11 +73,6 @@ with gr.Blocks() as app:
|
|
| 72 |
audio_output = gr.Audio(label="Synthesized Audio")
|
| 73 |
spectrogram_output = gr.Image(label="Spectrogram")
|
| 74 |
|
| 75 |
-
model_cache = {
|
| 76 |
-
DEFAULT_TTS_MODEL: F5TTS_ema_model,
|
| 77 |
-
"E2-TTS": E2TTS_ema_model
|
| 78 |
-
}
|
| 79 |
-
|
| 80 |
@gpu_decorator
|
| 81 |
def infer(
|
| 82 |
ref_audio_orig,
|
|
@@ -119,7 +115,7 @@ with gr.Blocks() as app:
|
|
| 119 |
pre_custom_path = model[1]
|
| 120 |
ema_model = custom_ema_model
|
| 121 |
else:
|
| 122 |
-
ema_model = model_cache.get(model,
|
| 123 |
|
| 124 |
final_wave, final_sample_rate, combined_spectrogram = infer_process(
|
| 125 |
ref_audio,
|
|
|
|
| 34 |
|
| 35 |
# Load vocoder and models on module load
|
| 36 |
vocoder = load_vocoder()
|
| 37 |
+
model_cache = {}
|
| 38 |
+
model_cache[DEFAULT_TTS_MODEL] = load_model(
|
| 39 |
DiT,
|
| 40 |
json.loads(DEFAULT_TTS_MODEL_CFG[2]),
|
| 41 |
str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
|
| 42 |
)
|
| 43 |
+
model_cache["E2-TTS"] = load_model(
|
| 44 |
UNetT,
|
| 45 |
dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1),
|
| 46 |
str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
|
|
|
|
| 73 |
audio_output = gr.Audio(label="Synthesized Audio")
|
| 74 |
spectrogram_output = gr.Image(label="Spectrogram")
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
@gpu_decorator
|
| 77 |
def infer(
|
| 78 |
ref_audio_orig,
|
|
|
|
| 115 |
pre_custom_path = model[1]
|
| 116 |
ema_model = custom_ema_model
|
| 117 |
else:
|
| 118 |
+
ema_model = model_cache.get(model, model_cache[DEFAULT_TTS_MODEL])
|
| 119 |
|
| 120 |
final_wave, final_sample_rate, combined_spectrogram = infer_process(
|
| 121 |
ref_audio,
|