# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import logging
import argparse
from datetime import datetime

import torch
import soundfile as sf
import gradio as gr
import spaces
from huggingface_hub import snapshot_download

from cli.SparkTTS import SparkTTS
from sparktts.utils.token_parser import LEVELS_MAP_UI

MODEL = None


def initialize_model(model_dir=None, device="cpu"):
    """Load the model once at the beginning."""
    if model_dir is None:
        model_dir = snapshot_download(
            "SparkAudio/Spark-TTS-0.5B",
            local_dir="pretrained_models/Spark-TTS-0.5B",
        )
        logging.info(f"Downloaded model to: {model_dir}")

    logging.info(f"Loading model from: {model_dir}")
    device = torch.device(device)
    model = SparkTTS(model_dir, device)
    return model


@spaces.GPU
def generate(text, prompt_speech, prompt_text, gender, pitch, speed):
    """Generate audio from text."""
    global MODEL

    # Initialize the model lazily if it has not been loaded yet.
    if MODEL is None:
        MODEL = initialize_model(device="cuda" if torch.cuda.is_available() else "cpu")

    model = MODEL

    # If a GPU is available, move the model to it before inference.
    if torch.cuda.is_available():
        print("Moving model to GPU")
        model.to("cuda")

    with torch.no_grad():
        wav = model.inference(
            text,
            prompt_speech,
            prompt_text,
            gender,
            pitch,
            speed,
        )
    return wav


def run_tts(
    text,
    prompt_text=None,
    prompt_speech=None,
    gender=None,
    pitch=None,
    speed=None,
    save_dir="example/results",
):
    """Perform TTS inference and save the generated audio."""
    logging.info(f"Saving audio to: {save_dir}")

    # Treat an empty or single-character prompt transcript as no prompt text.
    if prompt_text is not None:
        prompt_text = None if len(prompt_text) <= 1 else prompt_text

    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Generate a unique filename using a timestamp
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    save_path = os.path.join(save_dir, f"{timestamp}.wav")

    logging.info("Starting inference...")

    # Perform inference and save the output audio
    wav = generate(text, prompt_speech, prompt_text, gender, pitch, speed)
    sf.write(save_path, wav, samplerate=16000)

    logging.info(f"Audio saved at: {save_path}")

    return save_path


def build_ui(model_dir, device=0):
    global MODEL

    # Initialize the model with proper device handling
    device = "cuda" if torch.cuda.is_available() and device != "cpu" else "cpu"
    if MODEL is None:
        MODEL = initialize_model(model_dir, device=device)
        if device == "cuda":
            MODEL = MODEL.to(device)

    # Define the callback function for voice cloning
    def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
        """
        Gradio callback to clone a voice using text and optional prompt speech.
        - text: The input text to be synthesised.
        - prompt_text: Additional textual info for the prompt (optional).
        - prompt_wav_upload/prompt_wav_record: Audio files used as reference.
        """
        prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
        prompt_text_clean = None if len(prompt_text) < 2 else prompt_text

        audio_output_path = run_tts(
            text,
            prompt_text=prompt_text_clean,
            prompt_speech=prompt_speech,
        )
        return audio_output_path

    # Define the callback function for creating new voices
    def voice_creation(text, gender, pitch, speed):
        """
        Gradio callback to create a synthetic voice with adjustable parameters.
        - text: The input text for synthesis.
        - gender: 'male' or 'female'.
        - pitch/speed: Ranges mapped by LEVELS_MAP_UI.
        """
        pitch_val = LEVELS_MAP_UI[int(pitch)]
        speed_val = LEVELS_MAP_UI[int(speed)]

        audio_output_path = run_tts(
            text,
            gender=gender,
            pitch=pitch_val,
            speed=speed_val,
        )
        return audio_output_path

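    # Note (assumption): LEVELS_MAP_UI from sparktts.utils.token_parser is
    # expected to map the 1-5 slider positions used in the UI below onto the
    # discrete pitch/speed labels that SparkTTS.inference accepts; check
    # sparktts/utils/token_parser.py for the exact mapping.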
    with gr.Blocks() as demo:
        # Use HTML for the centered title
        gr.HTML(
            '<h1 style="text-align: center;">(Official) Spark-TTS by SparkAudio</h1>'
        )

        with gr.Tabs():
            # Voice Clone Tab
            with gr.TabItem("Voice Clone"):
                gr.Markdown(
                    "### Upload reference audio or recording (上传参考音频或者录音)"
                )

                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources="upload",
                        type="filepath",
                        label="Choose the prompt audio file, ensuring the sampling rate is no lower than 16kHz.",
                    )
                    prompt_wav_record = gr.Audio(
                        sources="microphone",
                        type="filepath",
                        label="Record the prompt audio file.",
                    )

                with gr.Row():
                    text_input = gr.Textbox(
                        label="Text", lines=3, placeholder="Enter text here"
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of prompt speech (Optional; recommended for cloning in the same language.)",
                        lines=3,
                        placeholder="Enter text of the prompt speech.",
                    )

                audio_output = gr.Audio(
                    label="Generated Audio", autoplay=True, streaming=True
                )

                generate_button_clone = gr.Button("Generate")

                generate_button_clone.click(
                    voice_clone,
                    inputs=[
                        text_input,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[audio_output],
                )

            # Voice Creation Tab
            with gr.TabItem("Voice Creation"):
                gr.Markdown(
                    "### Create your own voice based on the following parameters"
                )

                with gr.Row():
                    with gr.Column():
                        gender = gr.Radio(
                            choices=["male", "female"], value="male", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed"
                        )
                    with gr.Column():
                        text_input_creation = gr.Textbox(
                            label="Input Text",
                            lines=3,
                            placeholder="Enter text here",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )
                        create_button = gr.Button("Create Voice")

                audio_output = gr.Audio(
                    label="Generated Audio", autoplay=True, streaming=True
                )

                create_button.click(
                    voice_creation,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output],
                )

    return demo


def parse_arguments():
    """Parse command-line arguments such as the model directory and device."""
    parser = argparse.ArgumentParser(description="Spark TTS Gradio server.")
    parser.add_argument(
        "--model_dir", type=str, default=None, help="Path to the model directory."
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to use (e.g., 'cpu' or 'cuda:0').",
    )
    parser.add_argument(
        "--server_name", type=str, default=None, help="Server host/IP for the Gradio app."
    )
    parser.add_argument(
        "--server_port", type=int, default=None, help="Server port for the Gradio app."
    )
    return parser.parse_args()


if __name__ == "__main__":
    # Parse command-line arguments
    args = parse_arguments()

    # Build the Gradio demo using the specified model directory and device
    demo = build_ui(
        model_dir=args.model_dir,
        device=args.device,
    )

    # Launch Gradio with the specified server name and port
    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
    )
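# Example usage (a sketch, kept in comments so the module has no import-time
# side effects; the script name "webui.py", host, and port below are
# assumptions rather than values defined in this file):
#
#   Launch the Gradio demo:
#       python webui.py --device cuda:0 --server_name 0.0.0.0 --server_port 7860
#   With --model_dir omitted, initialize_model() downloads the
#   SparkAudio/Spark-TTS-0.5B checkpoint into pretrained_models/Spark-TTS-0.5B
#   via huggingface_hub.snapshot_download.
#
#   Call the pipeline directly (hypothetical prompt path):
#       from webui import run_tts
#       wav_path = run_tts("Hello there.", prompt_speech="example/prompt.wav")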