# xiaoqinfeng's picture
# Update app.py
# f76461a verified
# Copyright (c) 2025 SparkAudio
# 2025 Xinsheng Wang ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import soundfile as sf
import logging
import argparse
import gradio as gr
from datetime import datetime
from cli.SparkTTS import SparkTTS
from sparktts.utils.token_parser import LEVELS_MAP_UI
from huggingface_hub import snapshot_download
import spaces
# Shared model object for the whole process; stays None until generate() or
# build_ui() loads it through initialize_model().
MODEL = None
def initialize_model(model_dir=None, device="cpu"):
    """Load the model once at the beginning.

    Args:
        model_dir: Directory containing the model files. When ``None``, the
            "SparkAudio/Spark-TTS-0.5B" snapshot is downloaded into
            ``pretrained_models/Spark-TTS-0.5B`` and the path returned by
            ``snapshot_download`` is used instead.
        device: Value passed to ``torch.device`` (defaults to ``"cpu"``).

    Returns:
        The ``SparkTTS`` object built from ``model_dir`` and ``device``.
    """
    if model_dir is None:
        # Log the real download target; ``model_dir`` is still None here, so
        # logging it would only ever print "None".
        local_dir = "pretrained_models/Spark-TTS-0.5B"
        logging.info(f"Downloading model to: {local_dir}")
        model_dir = snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir=local_dir)

    logging.info(f"Loading model from: {model_dir}")
    device = torch.device(device)
    model = SparkTTS(model_dir, device)
    return model
@spaces.GPU
def generate(text,
             prompt_speech,
             prompt_text,
             gender,
             pitch,
             speed,
             ):
    """Run model inference for ``text`` and return its result.

    The module-level ``MODEL`` is created on first use. All six arguments are
    forwarded positionally, in this order, to ``model.inference``; this
    function does not interpret them itself.
    """
    global MODEL

    # First call: build the shared model, on CUDA when it is available.
    if MODEL is None:
        startup_device = "cuda" if torch.cuda.is_available() else "cpu"
        MODEL = initialize_model(device=startup_device)
    tts = MODEL

    # The move to the GPU is requested on every call while CUDA is visible.
    if torch.cuda.is_available():
        print("Moving model to GPU")
        tts.to("cuda")

    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        return tts.inference(
            text, prompt_speech, prompt_text, gender, pitch, speed
        )
def run_tts(
    text,
    prompt_text=None,
    prompt_speech=None,
    gender=None,
    pitch=None,
    speed=None,
    save_dir="example/results",
):
    """Perform TTS inference and save the generated audio.

    Args:
        text: Text to synthesise.
        prompt_text: Optional transcript of the prompt speech. Values of
            length 0 or 1 are treated as "not provided" and replaced by None.
        prompt_speech: Optional reference audio, forwarded to ``generate``.
        gender: Optional gender setting, forwarded to ``generate``.
        pitch: Optional pitch setting, forwarded to ``generate``.
        speed: Optional speed setting, forwarded to ``generate``.
        save_dir: Directory the wav file is written to; created if missing.

    Returns:
        Path of the written ``.wav`` file.
    """
    logging.info(f"Saving audio to: {save_dir}")

    # A prompt text of at most one character carries no usable content.
    if prompt_text is not None and len(prompt_text) <= 1:
        prompt_text = None

    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Include microseconds (%f) in the timestamp: with one-second resolution,
    # two requests finishing in the same second would overwrite each other.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    save_path = os.path.join(save_dir, f"{timestamp}.wav")

    logging.info("Starting inference...")

    # Perform inference and save the output audio
    wav = generate(
        text,
        prompt_speech,
        prompt_text,
        gender,
        pitch,
        speed,
    )

    # NOTE(review): the sample rate is hard-coded; assumes the model emits
    # 16 kHz audio. TODO confirm against the model configuration.
    sf.write(save_path, wav, samplerate=16000)

    logging.info(f"Audio saved at: {save_path}")
    return save_path
def build_ui(model_dir, device=0):
    """Build the Gradio interface for voice cloning and voice creation.

    Args:
        model_dir: Model directory forwarded to ``initialize_model`` when the
            shared ``MODEL`` has not been loaded yet.
        device: ``"cpu"`` forces CPU. Any other value selects ``"cuda"`` when
            ``torch.cuda.is_available()`` is true and CPU otherwise.

    Returns:
        The assembled ``gr.Blocks`` object (not launched).
    """
    global MODEL

    # Initialize model with proper device handling
    device = "cuda" if torch.cuda.is_available() and device != "cpu" else "cpu"
    if MODEL is None:
        MODEL = initialize_model(model_dir, device=device)
        if device == "cuda":
            # NOTE(review): relies on ``.to`` returning the model object;
            # generate() calls ``.to`` without using the result. Confirm.
            MODEL = MODEL.to(device)

    # Define callback function for voice cloning
    def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
        """
        Gradio callback to clone voice using text and optional prompt speech.
        - text: The input text to be synthesised.
        - prompt_text: Additional textual info for the prompt (optional).
        - prompt_wav_upload/prompt_wav_record: Audio files used as reference.
        """
        # The uploaded file takes precedence over the microphone recording.
        prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record

        # Treat None, empty and single-character prompt text as "not given".
        # The None check avoids a TypeError from len(None).
        if prompt_text is not None and len(prompt_text) >= 2:
            prompt_text_clean = prompt_text
        else:
            prompt_text_clean = None

        audio_output_path = run_tts(
            text,
            prompt_text=prompt_text_clean,
            prompt_speech=prompt_speech,
        )
        return audio_output_path

    # Define callback function for creating new voices
    def voice_creation(text, gender, pitch, speed):
        """
        Gradio callback to create a synthetic voice with adjustable parameters.
        - text: The input text for synthesis.
        - gender: 'male' or 'female'.
        - pitch/speed: Ranges mapped by LEVELS_MAP_UI.
        """
        # Slider values are converted to int before the LEVELS_MAP_UI lookup.
        pitch_val = LEVELS_MAP_UI[int(pitch)]
        speed_val = LEVELS_MAP_UI[int(speed)]
        audio_output_path = run_tts(
            text,
            gender=gender,
            pitch=pitch_val,
            speed=speed_val,
        )
        return audio_output_path

    with gr.Blocks() as demo:
        # Use HTML for centered title
        gr.HTML('<h1 style="text-align: center;">(Official) Spark-TTS by SparkAudio</h1>')
        with gr.Tabs():
            # Voice Clone Tab
            with gr.TabItem("Voice Clone"):
                gr.Markdown(
                    "### Upload reference audio or recording (上传参考音频或者录音)"
                )

                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources="upload",
                        type="filepath",
                        label="Choose the prompt audio file, ensuring the sampling rate is no lower than 16kHz.",
                    )
                    prompt_wav_record = gr.Audio(
                        sources="microphone",
                        type="filepath",
                        label="Record the prompt audio file.",
                    )

                with gr.Row():
                    text_input = gr.Textbox(
                        label="Text", lines=3, placeholder="Enter text here"
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of prompt speech (Optional; recommended for cloning in the same language.)",
                        lines=3,
                        placeholder="Enter text of the prompt speech.",
                    )

                clone_audio_output = gr.Audio(
                    label="Generated Audio", autoplay=True, streaming=True
                )

                clone_button = gr.Button("Generate")

                clone_button.click(
                    voice_clone,
                    inputs=[
                        text_input,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[clone_audio_output],
                )

            # Voice Creation Tab
            with gr.TabItem("Voice Creation"):
                gr.Markdown(
                    "### Create your own voice based on the following parameters"
                )

                with gr.Row():
                    with gr.Column():
                        gender = gr.Radio(
                            choices=["male", "female"], value="male", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed"
                        )
                    with gr.Column():
                        text_input_creation = gr.Textbox(
                            label="Input Text",
                            lines=3,
                            placeholder="Enter text here",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )
                        create_button = gr.Button("Create Voice")

                creation_audio_output = gr.Audio(
                    label="Generated Audio", autoplay=True, streaming=True
                )
                create_button.click(
                    voice_creation,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[creation_audio_output],
                )

    return demo
def parse_arguments():
    """
    Read the server's command-line options from ``sys.argv``.

    Recognised flags: ``--model_dir``, ``--device``, ``--server_name`` and
    ``--server_port``. Returns the resulting ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(description="Spark TTS Gradio server.")

    # One row per option: (flag, value type, default, help text).
    option_table = (
        ("--model_dir", str, None, "Path to the model directory."),
        ("--device", str, "cpu", "Device to use (e.g., 'cpu' or 'cuda:0')."),
        ("--server_name", str, None, "Server host/IP for Gradio app."),
        ("--server_port", int, None, "Server port for Gradio app."),
    )
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)

    return cli.parse_args()
if __name__ == "__main__":
    # Read the command-line options.
    cli_args = parse_arguments()

    # Assemble the Gradio app from the requested model directory and device.
    demo = build_ui(model_dir=cli_args.model_dir, device=cli_args.device)

    # Serve the app on the requested host and port.
    demo.launch(
        server_name=cli_args.server_name,
        server_port=cli_args.server_port,
    )