import torch
import time
import sys

# The F5-TTS sources and the SmoothCache helper live in sibling checkouts, so
# they must be on sys.path before the project imports below are resolved.
sys.path.append('F5-TTS/src')
sys.path.append('SmoothCache/SmoothCache')

import os
from importlib.resources import files
from PIL import Image, ImageDraw
from functools import lru_cache
import gradio as gr
from smooth_cache_helper import SmoothCacheHelper
from f5_tts.infer.utils_infer import (
    cross_fade_duration,
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    speed,
)
import numpy as np
import tomli
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf

try:
    import spaces
    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    """Wrap *func* with ``spaces.GPU`` when the ``spaces`` package is available.

    Otherwise *func* is returned unchanged. Only functions that actually run
    the model should use this: on Spaces each call requests a GPU slot.
    """
    if USING_SPACES:
        return spaces.GPU(func)
    else:
        return func


# Constants
# Grid row order: row 0 shows the 'ff' schedule, row 1 the 'attn' schedule.
layer_names = ['ff', 'attn']
colors_rgb = [(0, 210, 106), (255, 103, 35)]  # green, orange
cell_size = 20
spacing = 2
n_layers = 2

# Presets: per-step flags for each wrapped component (1 = compute, 0 = cached,
# matching the UI legend). The list length equals the number of NFE steps.
presets = {
    "32 NFE, α=0.15": {
        'attn': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1],
        'ff': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1],
    },
    "32 NFE, α=0.25": {
        'attn': [1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
        'ff': [1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
    },
    "16 NFE, α=0.3": {
        'attn': [1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1],
        'ff': [1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1],
    },
    "16 NFE, α=0.5": {
        'attn': [1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1],
        'ff': [1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1],
    },
}
default_preset = "32 NFE, α=0.15"

# One random seed per process. It is re-applied before every synthesis pass
# (see _run_timed) so the cached and uncached runs are comparable.
seed = np.random.randint(0, 2**31 - 1)
torch.manual_seed(seed)

# Use a context manager so the TOML file handle is closed after parsing.
with open(os.path.join(files("f5_tts").joinpath(
        "infer/examples/basic"), "basic.toml"), "rb") as config_file:
    config = tomli.load(config_file)

model = config.get("model", "F5TTS_v1_Base")
ckpt_file = config.get("ckpt_file", "")
vocab_file = config.get("vocab_file", "")
model_cfg = OmegaConf.load(
    config.get("model_cfg", str(
        files("f5_tts").joinpath(f"configs/{model}.yaml")))
)
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch

repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
if not ckpt_file:
    ckpt_file = str(cached_path(
        f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
if not vocab_file:
    vocab_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/vocab.txt"))

ema_model = load_model(
    model_cls, model_arc, ckpt_file, vocab_file=vocab_file
)
vocoder = load_vocoder()


# The grid/schedule helpers below are pure CPU (PIL + lists), so they are
# intentionally NOT wrapped with gpu_decorator.
def render_grid(schedule: dict) -> np.ndarray:
    """Draw the cache schedule as an RGB image of n_layers x n_steps cells.

    A cell is filled with its row's color when the flag is 1 and left white
    when it is 0. The step count is taken from ``schedule['attn']``.
    """
    n_steps = len(schedule['attn'])
    pitch = cell_size + spacing
    img = Image.new("RGB", (n_steps * pitch, n_layers * pitch), "white")
    draw = ImageDraw.Draw(img)

    for row in range(n_layers):
        layer = layer_names[row]
        for col in range(n_steps):
            x0 = col * pitch
            y0 = row * pitch
            x1 = x0 + cell_size
            y1 = y0 + cell_size
            color = colors_rgb[row] if schedule[layer][col] == 1 else "white"
            draw.rectangle([x0, y0, x1, y1], fill=color, outline="black")

    return np.array(img)


def apply_preset(preset_name, cache_schedule):
    """Load a named preset into *cache_schedule* (in place) and redraw.

    Unknown names (e.g. "Custom") leave the schedule untouched. Returns
    (grid image, step count for the NFE slider, schedule).
    """
    if preset_name in presets:
        schedule = presets[preset_name]
        # Copy so that later cell toggles never mutate the shared preset lists.
        cache_schedule['attn'] = schedule['attn'][:]
        cache_schedule['ff'] = schedule['ff'][:]
    return render_grid(cache_schedule), len(cache_schedule['attn']), cache_schedule


def toggle_cell(evt: gr.SelectData, cache_schedule):
    """Flip the schedule flag under the clicked pixel and mark the preset Custom.

    Clicks that fall outside the grid's rows or columns are ignored.
    """
    pitch = cell_size + spacing
    col = evt.index[0] // pitch
    row = evt.index[1] // pitch
    if 0 <= row < n_layers:
        layer = layer_names[row]
        if 0 <= col < len(cache_schedule[layer]):
            cache_schedule[layer][col] ^= 1
    return render_grid(cache_schedule), "Custom", cache_schedule


def reset_schedule(n_steps):
    """Build an all-compute schedule (every flag = 1) with *n_steps* steps."""
    # The slider may deliver a float; list repetition requires an int.
    n_steps = int(n_steps)
    cache_schedule = {
        'attn': [1] * n_steps,
        'ff': [1] * n_steps,
    }
    return render_grid(cache_schedule), "Custom", cache_schedule


def update_nfe(nfe_value):
    """Slider handler: reset the schedule to match the new NFE step count."""
    return reset_schedule(nfe_value)


def load_default():
    """Page-load handler: render the default preset and select it in the dropdown."""
    cache_schedule = {
        'attn': presets[default_preset]['attn'][:],
        'ff': presets[default_preset]['ff'][:],
    }
    return render_grid(cache_schedule), default_preset


def _run_timed(ref_audio, ref_text, gen_text, nfe_step, show_info):
    """Run one F5-TTS synthesis pass; return (wave, sample_rate, seconds)."""
    # Re-seed so the cached and uncached passes consume the same RNG stream.
    # This keeps the comparison like-for-like and makes reusing a previous
    # uncached result for identical inputs sound.
    torch.manual_seed(seed)
    start = time.perf_counter()
    wave, sample_rate, _ = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )
    return wave, sample_rate, time.perf_counter() - start


@gpu_decorator
def infer(
    ref_audio_orig, ref_text, gen_text, nfe_step, cache_schedule, recent_input
):
    """Synthesize *gen_text* with and without SmoothCache and time both passes.

    Returns a 6-tuple matching the click handler's outputs:
    (uncached audio, cached audio, reference text, uncached time,
    cached time, recent_input).

    When the reference audio, reference text, generated text and NFE step
    all equal the previous call's (tracked in *recent_input*), the uncached
    pass is skipped and its outputs are left unchanged via ``gr.update()``.
    """
    show_info = gr.Info

    # Every return must supply all six outputs, including the state.
    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        return gr.update(), gr.update(), ref_text, gr.update(), gr.update(), recent_input

    if not gen_text or not gen_text.strip():
        gr.Warning("Please enter text to generate.")
        return gr.update(), gr.update(), ref_text, gr.update(), gr.update(), recent_input

    ref_audio, ref_text = preprocess_ref_audio_text(
        ref_audio_orig, ref_text, show_info=show_info)

    current_input = {
        "ref_audio": ref_audio_orig,
        "ref_text": ref_text,
        "gen_text": gen_text,
        "nfe_step": nfe_step,
    }
    skip_no_cache = all(
        recent_input[key] == value for key, value in current_input.items()
    )

    if not skip_no_cache:
        final_wave, final_sample_rate, elapsed = _run_timed(
            ref_audio, ref_text, gen_text, nfe_step, show_info)

    cache_helper = SmoothCacheHelper(
        model=ema_model.transformer,
        block_classes=get_class("f5_tts.model.modules.DiTBlock"),
        components_to_wrap=['attn', 'ff'],
        schedule=cache_schedule,
    )
    cache_helper.enable()
    try:
        final_wave_cache, final_sample_rate_cache, elapsed_cache = _run_timed(
            ref_audio, ref_text, gen_text, nfe_step, show_info)
    finally:
        # Always unwrap: a failure here must not leave caching enabled for
        # the next "no cache" pass.
        cache_helper.disable()

    recent_input.update(current_input)

    if skip_no_cache:
        print("skip")
        return gr.update(), (final_sample_rate_cache, final_wave_cache), ref_text, gr.update(), elapsed_cache, recent_input

    return (final_sample_rate, final_wave), (final_sample_rate_cache, final_wave_cache), ref_text, elapsed, elapsed_cache, recent_input


with gr.Blocks() as demo:
    gr.Markdown("## F5-TTS + SmoothCache")

    cache_schedule_state = gr.State({
        'attn': presets[default_preset]['attn'][:],
        'ff': presets[default_preset]['ff'][:],
    })
    recent_input_state = gr.State({
        "ref_audio": None,
        "ref_text": None,
        "gen_text": None,
        "nfe_step": None,
    })

    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    ref_text_input = gr.Textbox(label="Reference Text (Optional)")
    gen_text_input = gr.Textbox(label="Text to Generate")

    with gr.Row():
        with gr.Column(scale=0):
            preset_dropdown = gr.Dropdown(choices=list(
                presets.keys()) + ["Custom"], label="Choose Preset", value=default_preset)
            nfe_slider = gr.Slider(4, 64, value=32, step=1,
                                   label="Number of Steps (NFE)")
        with gr.Group():
            gr.Markdown(
                "Click Grid to Customize Cache Schedule\n"
                "🟧 = Compute Attn Layer\n"
                "🟩 = Compute FFN Layer\n"
                "⬜ = Cached Layer",
                container=True)
            image = gr.Image(type="numpy", show_label=False, show_fullscreen_button=False,
                             sources=[], interactive=True, scale=1)

    generate_btn = gr.Button("Synthesize", variant="primary")

    with gr.Row():
        with gr.Group():
            audio_output = gr.Audio(label="Synthesized Audio (No Cache)")
            process_time = gr.Textbox(
                label="⏱ Process Time", interactive=False)
        with gr.Group():
            audio_output_cache = gr.Audio(label="Synthesized Audio (Cache)")
            process_time_cache = gr.Textbox(
                label="⏱ Process Time", interactive=False)

    # Wire up logic
    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown, cache_schedule_state],
        outputs=[image, nfe_slider, cache_schedule_state])
    image.select(fn=toggle_cell, inputs=[cache_schedule_state], outputs=[
                 image, preset_dropdown, cache_schedule_state])
    nfe_slider.release(fn=update_nfe, inputs=nfe_slider, outputs=[
                       image, preset_dropdown, cache_schedule_state])
    generate_btn.click(
        infer,
        inputs=[ref_audio_input, ref_text_input, gen_text_input,
                nfe_slider, cache_schedule_state, recent_input_state],
        outputs=[audio_output, audio_output_cache, ref_text_input,
                 process_time, process_time_cache, recent_input_state],
    )

    demo.load(fn=load_default, outputs=[image, preset_dropdown])

demo.launch()