mrfakename commited on
Commit
1cbd297
·
verified ·
1 Parent(s): b7a138c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -206
app.py CHANGED
@@ -1,209 +1,6 @@
1
- # ruff: noqa: E402
2
-
3
- import os
4
- import json
5
- import tempfile
6
- from functools import lru_cache
7
- from importlib.resources import files
8
-
9
  import gradio as gr
10
- import numpy as np
11
- import soundfile as sf
12
- import torch
13
- import torchaudio
14
- from cached_path import cached_path
15
- import spaces
16
-
17
- from f5_tts.infer.utils_infer import (
18
- infer_process,
19
- load_model,
20
- load_vocoder,
21
- preprocess_ref_audio_text,
22
- remove_silence_for_generated_wav,
23
- save_spectrogram,
24
- tempfile_kwargs,
25
- )
26
- from f5_tts.model import DiT, UNetT
27
-
28
- DEFAULT_TTS_MODEL = "F5-TTS_v1"
29
- DEFAULT_TTS_MODEL_CFG = [
30
- "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
31
- "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt",
32
- json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
33
- ]
34
-
35
- # Load vocoder and models on module load
36
- vocoder = load_vocoder()
37
- model_cache = {}
38
- model_cache[DEFAULT_TTS_MODEL] = load_model(
39
- DiT,
40
- json.loads(DEFAULT_TTS_MODEL_CFG[2]),
41
- str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
42
- )
43
- model_cache["E2-TTS"] = load_model(
44
- UNetT,
45
- dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1),
46
- str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
47
- )
48
- custom_ema_model, pre_custom_path = None, ""
49
- tts_model_choice = DEFAULT_TTS_MODEL
50
-
51
- def gpu_decorator(fn):
52
- return spaces.GPU(fn)
53
-
54
- with gr.Blocks() as app:
55
- gr.Markdown("# ZeroGPU TTS - F5/E2 Demo")
56
-
57
- ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
58
- gen_text_input = gr.Textbox(label="Text to Generate", lines=4)
59
- gen_text_file = gr.File(label="Upload Text File", file_types=[".txt"])
60
-
61
- ref_text_input = gr.Textbox(label="Reference Text (optional)", lines=2)
62
- ref_text_file = gr.File(label="Upload Reference Text", file_types=[".txt"])
63
-
64
- remove_silence = gr.Checkbox(label="Remove Silences", value=False)
65
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
66
- seed_input = gr.Number(value=0, precision=0, label="Seed")
67
-
68
- cross_fade_duration_slider = gr.Slider(label="Cross-Fade Duration", minimum=0.0, maximum=1.0, value=0.15)
69
- nfe_slider = gr.Slider(label="NFE Steps", minimum=4, maximum=64, value=32, step=2)
70
- speed_slider = gr.Slider(label="Speed", minimum=0.3, maximum=2.0, value=1.0, step=0.1)
71
-
72
- generate_btn = gr.Button("Generate")
73
- audio_output = gr.Audio(label="Synthesized Audio")
74
- spectrogram_output = gr.Image(label="Spectrogram")
75
-
76
- @gpu_decorator
77
- def infer(
78
- ref_audio_orig,
79
- ref_text,
80
- gen_text,
81
- model,
82
- remove_silence,
83
- seed,
84
- cross_fade_duration=0.15,
85
- nfe_step=32,
86
- speed=1,
87
- show_info=gr.Info,
88
- ):
89
- if not ref_audio_orig:
90
- gr.Warning("Please provide reference audio.")
91
- return gr.update(), gr.update(), ref_text
92
-
93
- if seed < 0 or seed > 2**31 - 1:
94
- gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
95
- seed = np.random.randint(0, 2**31 - 1)
96
- torch.manual_seed(seed)
97
- used_seed = seed
98
-
99
- if not gen_text.strip():
100
- gr.Warning("Please enter text to generate or upload a text file.")
101
- return gr.update(), gr.update(), ref_text
102
-
103
- ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
104
-
105
- if isinstance(model, tuple) and model[0] == "Custom":
106
- global custom_ema_model, pre_custom_path
107
- if pre_custom_path != model[1]:
108
- show_info("Loading Custom TTS model...")
109
- custom_ema_model = load_model(
110
- DiT,
111
- json.loads(model[3]),
112
- str(cached_path(model[1])),
113
- vocab_file=str(cached_path(model[2]))
114
- )
115
- pre_custom_path = model[1]
116
- ema_model = custom_ema_model
117
- else:
118
- ema_model = model_cache.get(model, model_cache[DEFAULT_TTS_MODEL])
119
-
120
- final_wave, final_sample_rate, combined_spectrogram = infer_process(
121
- ref_audio,
122
- ref_text,
123
- gen_text,
124
- ema_model,
125
- vocoder,
126
- cross_fade_duration=cross_fade_duration,
127
- nfe_step=nfe_step,
128
- speed=speed,
129
- show_info=show_info,
130
- progress=gr.Progress(),
131
- )
132
-
133
- if remove_silence:
134
- with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
135
- temp_path = f.name
136
- try:
137
- sf.write(temp_path, final_wave, final_sample_rate)
138
- remove_silence_for_generated_wav(f.name)
139
- final_wave, _ = torchaudio.load(f.name)
140
- finally:
141
- os.unlink(temp_path)
142
- final_wave = final_wave.squeeze().cpu().numpy()
143
-
144
- with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram:
145
- spectrogram_path = tmp_spectrogram.name
146
- save_spectrogram(combined_spectrogram, spectrogram_path)
147
-
148
- return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed
149
-
150
- @gpu_decorator
151
- def load_text_from_file(file):
152
- if file:
153
- with open(file, "r", encoding="utf-8") as f:
154
- text = f.read().strip()
155
- else:
156
- text = ""
157
- return gr.update(value=text)
158
-
159
- @gpu_decorator
160
- def basic_tts(
161
- ref_audio_input,
162
- ref_text_input,
163
- gen_text_input,
164
- remove_silence,
165
- randomize_seed,
166
- seed_input,
167
- cross_fade_duration_slider,
168
- nfe_slider,
169
- speed_slider,
170
- ):
171
- if randomize_seed:
172
- seed_input = np.random.randint(0, 2**31 - 1)
173
-
174
- audio_out, spectrogram_path, ref_text_out, used_seed = infer(
175
- ref_audio_input,
176
- ref_text_input,
177
- gen_text_input,
178
- tts_model_choice,
179
- remove_silence,
180
- seed=seed_input,
181
- cross_fade_duration=cross_fade_duration_slider,
182
- nfe_step=nfe_slider,
183
- speed=speed_slider,
184
- )
185
- return audio_out, spectrogram_path, ref_text_out, used_seed
186
-
187
- gen_text_file.upload(load_text_from_file, inputs=[gen_text_file], outputs=[gen_text_input])
188
- ref_text_file.upload(load_text_from_file, inputs=[ref_text_file], outputs=[ref_text_input])
189
-
190
- ref_audio_input.clear(lambda: [None, None], None, [ref_text_input, ref_text_file])
191
 
192
- generate_btn.click(
193
- basic_tts,
194
- inputs=[
195
- ref_audio_input,
196
- ref_text_input,
197
- gen_text_input,
198
- remove_silence,
199
- randomize_seed,
200
- seed_input,
201
- cross_fade_duration_slider,
202
- nfe_slider,
203
- speed_slider,
204
- ],
205
- outputs=[audio_output, spectrogram_output, ref_text_input, seed_input],
206
- )
207
 
208
- if __name__ == "__main__":
209
- app.queue().launch()
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ with gr.Blocks() as demo:
4
+ gr.Markdown("Hi everyone, due to breaking changes with ZeroGPU/Xet-storage spaces, this space is temporarily down. I hope to find a solution to this soon, so please stay tuned. Sorry for the inconvenience. In the mean time, please check out: https://huggingface.co/spaces/mrfakename/MegaTTS3-Voice-Cloning https://huggingface.co/spaces/styletts2/styletts2 if you need TTS spaces.")
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ demo.launch()