import torch
import time
import sys
sys.path.append('F5-TTS/src')
sys.path.append('SmoothCache/SmoothCache')
import os
from importlib.resources import files
from PIL import Image, ImageDraw
from functools import lru_cache
import gradio as gr
from smooth_cache_helper import SmoothCacheHelper
from f5_tts.infer.utils_infer import (
cross_fade_duration,
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
speed
)
import numpy as np
import tomli
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
try:
    import spaces
except ImportError:
    # `spaces` is only available on Hugging Face Spaces; run undecorated elsewhere.
    USING_SPACES = False
else:
    USING_SPACES = True


def gpu_decorator(func):
    """Wrap *func* with ``spaces.GPU`` when the ``spaces`` package is available.

    When ``spaces`` could not be imported, *func* is returned unchanged, so
    the decorator is a no-op for local runs.
    """
    return spaces.GPU(func) if USING_SPACES else func
# Grid layout / appearance constants for the cache-schedule image.
layer_names = ['ff', 'attn']  # grid row order: row 0 = FFN, row 1 = attention
colors_rgb = [(0, 210, 106), (255, 103, 35)]  # row colours: green (ff), orange (attn)
cell_size = 20  # side of one grid cell, in pixels
spacing = 2  # gap between neighbouring cells, in pixels
n_layers = len(layer_names)  # number of grid rows (2)


def _shared_schedule(pattern):
    """Return a preset entry using the same 0/1 *pattern* for 'attn' and 'ff'.

    Each key gets its own copy so toggling one row never affects the other.
    """
    return {'attn': list(pattern), 'ff': list(pattern)}


# Cache-schedule presets, one flag per sampling step: 1 = compute the layer,
# 0 = reuse the cached output (matches the grid legend shown in the UI).
presets = {
    "32 NFE, α=0.15": _shared_schedule(
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1]
    ),
    "32 NFE, α=0.25": _shared_schedule(
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
    ),
    "16 NFE, α=0.3": _shared_schedule(
        [1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1]
    ),
    "16 NFE, α=0.5": _shared_schedule(
        [1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1]
    ),
}
default_preset = "32 NFE, α=0.15"
# Seed torch with a fresh random seed for this process.
seed = np.random.randint(0, 2**31 - 1)
torch.manual_seed(seed)

# Inference settings come from the example config bundled with f5_tts; keys
# missing there fall back to the defaults below. The file is opened in a
# context manager so the handle is closed as soon as it has been parsed
# (tomli.load requires a binary file object).
_basic_toml_path = os.path.join(
    files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"
)
with open(_basic_toml_path, "rb") as _basic_toml_file:
    config = tomli.load(_basic_toml_file)

model = config.get("model", "F5TTS_v1_Base")
ckpt_file = config.get("ckpt_file", "")
vocab_file = config.get("vocab_file", "")
model_cfg = OmegaConf.load(
    config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
)
# Resolve the backbone class named in the model config under f5_tts.model.
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch

# Without explicit paths, fetch checkpoint and vocab from the SWivid/F5-TTS
# Hugging Face repo via cached_path.
# NOTE(review): step 1250000 is hard-coded; confirm it matches whatever
# `model` the config selects if that ever differs from the default.
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
if not ckpt_file:
    ckpt_file = str(cached_path(
        f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
if not vocab_file:
    vocab_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/vocab.txt"))

ema_model = load_model(
    model_cls, model_arc, ckpt_file, vocab_file=vocab_file
)
vocoder = load_vocoder()
def render_grid(schedule: dict) -> np.ndarray:
    """Render a cache schedule as an RGB grid image.

    Each column is one sampling step and each row one layer type, in
    ``layer_names`` order. A cell is filled with its row's colour from
    ``colors_rgb`` when the schedule flag is 1 and left white otherwise;
    every cell gets a black outline.

    This helper only does PIL/NumPy drawing, so it is intentionally not
    wrapped with ``gpu_decorator``: on ZeroGPU Spaces a decorated call
    requests a GPU allocation, which is wasted on drawing.

    Args:
        schedule: dict with 'attn' and 'ff' lists of 0/1 flags; the number
            of columns is ``len(schedule['attn'])``.

    Returns:
        The image as a NumPy array of shape (height, width, 3), RGB.
    """
    n_steps = len(schedule['attn'])
    pitch = cell_size + spacing  # distance between the left/top edges of neighbouring cells
    img = Image.new("RGB", (n_steps * pitch, n_layers * pitch), "white")
    draw = ImageDraw.Draw(img)
    for row in range(n_layers):
        layer = layer_names[row]
        y0 = row * pitch
        for col in range(n_steps):
            x0 = col * pitch
            color = colors_rgb[row] if schedule[layer][col] == 1 else "white"
            draw.rectangle([x0, y0, x0 + cell_size, y0 + cell_size], fill=color, outline="black")
    return np.array(img)
def apply_preset(preset_name, cache_schedule):
    """Load the named preset into the schedule state and refresh the grid.

    UI-only callback (no model work), so it is not wrapped with
    ``gpu_decorator``; on ZeroGPU Spaces that would request a GPU per click.

    Args:
        preset_name: a key of ``presets``. Any other value (e.g. "Custom")
            leaves the current schedule untouched.
        cache_schedule: per-session schedule dict with 'attn'/'ff' lists;
            updated in place when a preset is applied.

    Returns:
        (grid image, number of steps for the NFE slider, schedule).
    """
    preset = presets.get(preset_name)
    if preset is not None:
        # Copy the lists so later cell toggles never mutate the shared presets.
        cache_schedule['attn'] = list(preset['attn'])
        cache_schedule['ff'] = list(preset['ff'])
    return render_grid(cache_schedule), len(cache_schedule['attn']), cache_schedule
def toggle_cell(evt: gr.SelectData, cache_schedule):
    """Flip the 0/1 flag of the grid cell that was clicked.

    UI-only callback (no model work), so it is not wrapped with
    ``gpu_decorator``; on ZeroGPU Spaces that would request a GPU per click.

    Args:
        evt: Gradio select event on the grid image; ``evt.index`` holds the
            clicked pixel as (x, y), mapped to (step column, layer row).
        cache_schedule: per-session schedule dict with 'attn'/'ff' lists,
            modified in place.

    Returns:
        (grid image, preset label, schedule). The label becomes "Custom"
        after a toggle; a click outside the grid changes nothing and leaves
        the label as it was (``gr.update()``).
    """
    pitch = cell_size + spacing
    col = evt.index[0] // pitch
    row = evt.index[1] // pitch
    # Ignore clicks outside the grid. Without the row bound, layer_names[row]
    # would raise IndexError; without the lower bounds, a negative index
    # would silently toggle a cell from the end of the list.
    if not (0 <= row < n_layers and 0 <= col < len(cache_schedule[layer_names[row]])):
        return render_grid(cache_schedule), gr.update(), cache_schedule
    cache_schedule[layer_names[row]][col] ^= 1
    return render_grid(cache_schedule), "Custom", cache_schedule
def reset_schedule(n_steps):
    """Build a schedule of *n_steps* steps with every flag set to 1.

    UI-only helper (no model work), so it is not wrapped with
    ``gpu_decorator``; on ZeroGPU Spaces that would request a GPU per call.

    Args:
        n_steps: number of sampling steps (columns in the grid).

    Returns:
        (grid image, "Custom" preset label, new schedule dict).
    """
    # Guard against a float slider value: list repetition requires an int.
    n_steps = int(n_steps)
    cache_schedule = {
        'attn': [1] * n_steps,
        'ff': [1] * n_steps,
    }
    return render_grid(cache_schedule), "Custom", cache_schedule
def update_nfe(nfe_value):
    """Replace the schedule with an all-ones one of *nfe_value* steps.

    Thin UI callback around ``reset_schedule``. It does no model work, so it
    is not wrapped with ``gpu_decorator``; on ZeroGPU Spaces that would
    request a GPU per call.

    Returns:
        Whatever ``reset_schedule(nfe_value)`` returns.
    """
    return reset_schedule(nfe_value)
def load_default():
    """Render the grid for the default preset on page load.

    Only the image and the preset label are returned. The schedule state is
    not returned because its initial value is already a copy of the default
    preset. UI-only callback, so it is not wrapped with ``gpu_decorator``.

    Returns:
        (grid image, ``default_preset`` label).
    """
    default = presets[default_preset]
    cache_schedule = {
        'attn': list(default['attn']),
        'ff': list(default['ff']),
    }
    return render_grid(cache_schedule), default_preset
def _synthesize(ref_audio, ref_text, gen_text, nfe_step, show_info):
    """Run one F5-TTS synthesis pass with the module-level model and vocoder.

    Returns:
        (waveform, sample_rate, elapsed wall-clock seconds).
    """
    start_time = time.perf_counter()
    wave, sample_rate, _ = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )
    return wave, sample_rate, time.perf_counter() - start_time


@gpu_decorator
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    nfe_step,
    cache_schedule,
    recent_input
):
    """Synthesize *gen_text* without and with SmoothCache, for comparison.

    The uncached pass is skipped when the reference audio, the preprocessed
    reference text, the generation text and the NFE all equal the previous
    request stored in *recent_input*. In that case only the cache schedule
    can have changed, so the baseline already shown in the UI is kept.

    Args:
        ref_audio_orig: reference audio as received from the UI (falsy if
            not provided).
        ref_text: reference transcript. It is passed through
            ``preprocess_ref_audio_text``, and the processed text is returned
            to the UI (presumably filled in when left empty; verify).
        gen_text: text to synthesize.
        nfe_step: number of sampling steps.
        cache_schedule: dict with 'attn'/'ff' 0/1 lists for SmoothCacheHelper.
        recent_input: per-session dict of the last request's inputs;
            updated in place after a successful run.

    Returns:
        Six values, one per wired output: uncached audio, cached audio,
        reference text, uncached time, cached time, recent_input. Outputs
        that should keep their current value get ``gr.update()``.
    """
    show_info = gr.Info
    # Early exits must also return six values (one per wired output);
    # returning fewer makes Gradio reject the response.
    unchanged = (gr.update(), gr.update(), ref_text, gr.update(), gr.update(), recent_input)
    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        return unchanged
    if not gen_text or not gen_text.strip():
        gr.Warning("Please enter text to generate.")
        return unchanged

    ref_audio, ref_text = preprocess_ref_audio_text(
        ref_audio_orig, ref_text, show_info=show_info)

    skip_no_cache = (
        recent_input["ref_audio"] == ref_audio_orig
        and recent_input["ref_text"] == ref_text
        and recent_input["gen_text"] == gen_text
        and recent_input["nfe_step"] == nfe_step
    )

    if not skip_no_cache:
        final_wave, final_sample_rate, process_time = _synthesize(
            ref_audio, ref_text, gen_text, nfe_step, show_info)

    cache_helper = SmoothCacheHelper(
        model=ema_model.transformer,
        block_classes=get_class("f5_tts.model.modules.DiTBlock"),
        components_to_wrap=['attn', 'ff'],
        schedule=cache_schedule
    )
    cache_helper.enable()
    try:
        final_wave_cache, final_sample_rate_cache, process_time_cache = _synthesize(
            ref_audio, ref_text, gen_text, nfe_step, show_info)
    finally:
        # Always disable. ema_model is shared by every request, so an
        # exception here must not leave caching enabled for later runs.
        cache_helper.disable()

    recent_input["ref_audio"] = ref_audio_orig
    recent_input["ref_text"] = ref_text
    recent_input["gen_text"] = gen_text
    recent_input["nfe_step"] = nfe_step

    if skip_no_cache:
        print("skip")
        return gr.update(), (final_sample_rate_cache, final_wave_cache), ref_text, gr.update(), process_time_cache, recent_input
    return (final_sample_rate, final_wave), (final_sample_rate_cache, final_wave_cache), ref_text, process_time, process_time_cache, recent_input
with gr.Blocks() as demo:
    gr.Markdown("## F5-TTS + SmoothCache")

    # Per-session state. cache_schedule_state is the editable schedule and
    # starts as a copy of the default preset. recent_input_state holds the
    # last synthesis inputs, which `infer` compares against to decide
    # whether the uncached baseline can be reused.
    cache_schedule_state = gr.State({
        'attn': presets[default_preset]['attn'][:],
        'ff': presets[default_preset]['ff'][:],
    })
    recent_input_state = gr.State({
        "ref_audio": None,
        "ref_text": None,
        "gen_text": None,
        "nfe_step": None,
    })

    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    ref_text_input = gr.Textbox(label="Reference Text (Optional)")
    gen_text_input = gr.Textbox(label="Text to Generate")

    with gr.Row():
        with gr.Column(scale=0):
            preset_dropdown = gr.Dropdown(
                choices=list(presets.keys()) + ["Custom"],
                label="Choose Preset",
                value=default_preset,
            )
            nfe_slider = gr.Slider(4, 64, value=32, step=1, label="Number of Steps (NFE)")
        with gr.Group():
            # Trailing double spaces are Markdown hard line breaks, so each
            # legend entry renders on its own line. The original literal
            # spanned physical lines inside "...", which is a SyntaxError.
            gr.Markdown(
                "Click Grid to Customize Cache Schedule  \n"
                "🟧 = Compute Attn Layer  \n"
                "🟩 = Compute FFN Layer  \n"
                "⬜ = Cached Layer",
                container=True,
            )
            image = gr.Image(type="numpy", show_label=False, show_fullscreen_button=False,
                             sources=[], interactive=True, scale=1)

    generate_btn = gr.Button("Synthesize", variant="primary")

    with gr.Row():
        with gr.Group():
            audio_output = gr.Audio(label="Synthesized Audio (No Cache)")
            process_time = gr.Textbox(label="⏱ Process Time", interactive=False)
        with gr.Group():
            audio_output_cache = gr.Audio(label="Synthesized Audio (Cache)")
            process_time_cache = gr.Textbox(label="⏱ Process Time", interactive=False)

    # Wire up logic (event listeners must be attached inside the Blocks context).
    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown, cache_schedule_state],
        outputs=[image, nfe_slider, cache_schedule_state],
    )
    image.select(
        fn=toggle_cell,
        inputs=[cache_schedule_state],
        outputs=[image, preset_dropdown, cache_schedule_state],
    )
    # `release` rather than `change`, presumably so that apply_preset setting
    # the slider programmatically does not immediately reset the preset's schedule.
    nfe_slider.release(
        fn=update_nfe,
        inputs=nfe_slider,
        outputs=[image, preset_dropdown, cache_schedule_state],
    )
    generate_btn.click(
        infer,
        inputs=[ref_audio_input, ref_text_input, gen_text_input, nfe_slider,
                cache_schedule_state, recent_input_state],
        outputs=[audio_output, audio_output_cache, ref_text_input,
                 process_time, process_time_cache, recent_input_state],
    )
    demo.load(fn=load_default, outputs=[image, preset_dropdown])

demo.launch()