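"""Gradio demo comparing F5-TTS inference with and without SmoothCache.

Runs the same synthesis twice (a plain baseline and a run with a
user-editable layer-cache schedule) and reports both process times
side by side.
"""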

import torch
import time
import sys

# Make the bundled F5-TTS and SmoothCache checkouts importable
sys.path.append('F5-TTS/src')
sys.path.append('SmoothCache/SmoothCache')

import os
from importlib.resources import files

from PIL import Image, ImageDraw
import gradio as gr
from smooth_cache_helper import SmoothCacheHelper
from f5_tts.infer.utils_infer import (
    cross_fade_duration,
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    speed,
)
import numpy as np
import tomli
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf

# Detect whether we are running on Hugging Face Spaces
try:
    import spaces
    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    # On Spaces, schedule the wrapped function on a GPU worker
    if USING_SPACES:
        return spaces.GPU(func)
    return func


# Constants for the schedule grid
layer_names = ['ff', 'attn']
colors_rgb = [(0, 210, 106), (255, 103, 35)]  # green = ff row, orange = attn row
cell_size = 20
spacing = 2
n_layers = 2

# Presets: per-step schedules where 1 = compute the layer at that step and
# 0 = reuse the cached output from the most recent computed step (α refers to
# the SmoothCache threshold each preset was generated with)
presets = {
    "32 NFE, α=0.15": {
        'attn': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1],
        'ff': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1],
    },
    "32 NFE, α=0.25": {
        'attn': [1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
        'ff': [1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
    },
    "16 NFE, α=0.3": {
        'attn': [1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1],
        'ff': [1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1],
    },
    "16 NFE, α=0.5": {
        'attn': [1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1],
        'ff': [1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1],
    },
}
default_preset = "32 NFE, α=0.15"

seed = np.random.randint(0, 2**31 - 1)
torch.manual_seed(seed)

# Load the inference defaults shipped with F5-TTS
with open(os.path.join(files("f5_tts").joinpath(
        "infer/examples/basic"), "basic.toml"), "rb") as f:
    config = tomli.load(f)
model = config.get("model", "F5TTS_v1_Base")
ckpt_file = config.get("ckpt_file", "")
vocab_file = config.get("vocab_file", "")
model_cfg = OmegaConf.load(
    config.get("model_cfg", str(
        files("f5_tts").joinpath(f"configs/{model}.yaml")))
)
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch

# Fetch checkpoint and vocab from the Hugging Face Hub if not set locally
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
if not ckpt_file:
    ckpt_file = str(cached_path(
        f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
if not vocab_file:
    vocab_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/vocab.txt"))

ema_model = load_model(
    model_cls, model_arc, ckpt_file, vocab_file=vocab_file
)
vocoder = load_vocoder()
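

# Draw the cache schedule as a clickable grid: one row per layer (ff, attn),
# one column per sampling step; filled cells are computed, white cells cached.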
def render_grid(schedule: dict) -> np.ndarray:
    n_steps = len(schedule['attn'])
    img = Image.new("RGB", (n_steps * (cell_size + spacing),
                            n_layers * (cell_size + spacing)), "white")
    draw = ImageDraw.Draw(img)
    for row in range(n_layers):
        layer = layer_names[row]
        for col in range(n_steps):
            x0 = col * (cell_size + spacing)
            y0 = row * (cell_size + spacing)
            x1 = x0 + cell_size
            y1 = y0 + cell_size
            color = colors_rgb[row] if schedule[layer][col] == 1 else "white"
            draw.rectangle([x0, y0, x1, y1], fill=color, outline="black")
    return np.array(img)
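

# Copy a preset's schedule into the mutable state and sync the NFE slider to
# the preset's step count; selecting "Custom" re-renders the current schedule.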
def apply_preset(preset_name, cache_schedule):
    if preset_name in presets:
        schedule = presets[preset_name]
        cache_schedule['attn'] = schedule['attn'][:]
        cache_schedule['ff'] = schedule['ff'][:]
    return render_grid(cache_schedule), len(cache_schedule['attn']), cache_schedule
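

# Flip one cell when the grid image is clicked: map the pixel coordinates in
# evt.index back to a (row, col) cell and mark the preset as "Custom".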
def toggle_cell(evt: gr.SelectData, cache_schedule):
    col = evt.index[0] // (cell_size + spacing)
    row = evt.index[1] // (cell_size + spacing)
    layer = layer_names[row]
    if col < len(cache_schedule[layer]):
        cache_schedule[layer][col] ^= 1
    return render_grid(cache_schedule), "Custom", cache_schedule
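

# Reset to an all-compute schedule (no caching) for the given step count.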
def reset_schedule(n_steps):
    cache_schedule = {
        'attn': [1] * n_steps,
        'ff': [1] * n_steps,
    }
    return render_grid(cache_schedule), "Custom", cache_schedule


def update_nfe(nfe_value):
    return reset_schedule(nfe_value)


def load_default():
    cache_schedule = {
        'attn': presets[default_preset]['attn'][:],
        'ff': presets[default_preset]['ff'][:],
    }
    return render_grid(cache_schedule), default_preset
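

# Synthesize twice with identical inputs: once without caching (the baseline)
# and once with SmoothCache enabled, so the two process times can be compared.
# The baseline does not depend on the cache schedule, so it is skipped when
# the other inputs are unchanged from the previous call.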
@gpu_decorator
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    nfe_step,
    cache_schedule,
    recent_input,
):
    show_info = gr.Info
    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        return gr.update(), gr.update(), ref_text, gr.update(), gr.update(), recent_input
    if not gen_text.strip():
        gr.Warning("Please enter text to generate.")
        return gr.update(), gr.update(), ref_text, gr.update(), gr.update(), recent_input

    ref_audio, ref_text = preprocess_ref_audio_text(
        ref_audio_orig, ref_text, show_info=show_info)

    skip_no_cache = (
        recent_input["ref_audio"] == ref_audio_orig
        and recent_input["ref_text"] == ref_text
        and recent_input["gen_text"] == gen_text
        and recent_input["nfe_step"] == nfe_step
    )

    if not skip_no_cache:
        start_time = time.time()
        final_wave, final_sample_rate, _ = infer_process(
            ref_audio,
            ref_text,
            gen_text,
            ema_model,
            vocoder,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            speed=speed,
            show_info=show_info,
            progress=gr.Progress(),
        )
        process_time = time.time() - start_time

    # Wrap the attn/ff components of every DiTBlock so that steps whose
    # schedule entry is 0 reuse the cached output instead of recomputing
    cache_helper = SmoothCacheHelper(
        model=ema_model.transformer,
        block_classes=get_class("f5_tts.model.modules.DiTBlock"),
        components_to_wrap=['attn', 'ff'],
        schedule=cache_schedule,
    )
    cache_helper.enable()
    start_time = time.time()
    final_wave_cache, final_sample_rate_cache, _ = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )
    process_time_cache = time.time() - start_time
    cache_helper.disable()

    recent_input["ref_audio"] = ref_audio_orig
    recent_input["ref_text"] = ref_text
    recent_input["gen_text"] = gen_text
    recent_input["nfe_step"] = nfe_step

    if skip_no_cache:
        print("Baseline inputs unchanged; reusing the previous no-cache result.")
        return (gr.update(), (final_sample_rate_cache, final_wave_cache), ref_text,
                gr.update(), f"{process_time_cache:.2f} s", recent_input)
    return ((final_sample_rate, final_wave), (final_sample_rate_cache, final_wave_cache),
            ref_text, f"{process_time:.2f} s", f"{process_time_cache:.2f} s", recent_input)
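

# Build the Gradio UI: reference/text inputs, the preset and schedule editor,
# and paired outputs for the uncached and cached runs.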
with gr.Blocks() as demo:
    gr.Markdown("## F5-TTS + SmoothCache")
    cache_schedule_state = gr.State({
        'attn': presets[default_preset]['attn'][:],
        'ff': presets[default_preset]['ff'][:],
    })
    recent_input_state = gr.State({
        "ref_audio": None,
        "ref_text": None,
        "gen_text": None,
        "nfe_step": None,
    })
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    ref_text_input = gr.Textbox(label="Reference Text (Optional)")
    gen_text_input = gr.Textbox(label="Text to Generate")
    with gr.Row():
        with gr.Column(scale=0):
            preset_dropdown = gr.Dropdown(
                choices=list(presets.keys()) + ["Custom"],
                label="Choose Preset", value=default_preset)
            nfe_slider = gr.Slider(4, 64, value=32, step=1,
                                   label="Number of Steps (NFE)")
        with gr.Group():
            gr.Markdown(
                "Click Grid to Customize Cache Schedule<br>"
                "🟧 = Compute Attn Layer<br>🟩 = Compute FFN Layer<br>⬜ = Cached Layer",
                container=True)
            image = gr.Image(type="numpy", show_label=False,
                             show_fullscreen_button=False, sources=[],
                             interactive=True, scale=1)
    generate_btn = gr.Button("Synthesize", variant="primary")
    with gr.Row():
        with gr.Group():
            audio_output = gr.Audio(label="Synthesized Audio (No Cache)")
            process_time = gr.Textbox(
                label="⏱ Process Time", interactive=False)
        with gr.Group():
            audio_output_cache = gr.Audio(label="Synthesized Audio (Cache)")
            process_time_cache = gr.Textbox(
                label="⏱ Process Time", interactive=False)

    # Wire up logic
    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown, cache_schedule_state],
        outputs=[image, nfe_slider, cache_schedule_state])
    image.select(fn=toggle_cell, inputs=[cache_schedule_state],
                 outputs=[image, preset_dropdown, cache_schedule_state])
    nfe_slider.release(fn=update_nfe, inputs=nfe_slider,
                       outputs=[image, preset_dropdown, cache_schedule_state])
    generate_btn.click(
        infer,
        inputs=[ref_audio_input, ref_text_input, gen_text_input,
                nfe_slider, cache_schedule_state, recent_input_state],
        outputs=[audio_output, audio_output_cache, ref_text_input,
                 process_time, process_time_cache, recent_input_state],
    )
    demo.load(fn=load_default, outputs=[image, preset_dropdown])

demo.launch()