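"""TTS x Hallo Talking Portrait Generator.

Gradio app that chains three stages: load or generate a portrait image,
load or synthesize a short voice clip (Parler TTS, WhisperSpeech, or
MaskGCT), then animate the portrait with the Hallo talking-head model.
"""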
import os
import uuid
import argparse

from huggingface_hub import snapshot_download
import gradio as gr
from gradio_client import Client, handle_file
from mutagen.mp3 import MP3
from pydub import AudioSegment
from PIL import Image
import ffmpeg
# Set working directory to this script's location so relative paths resolve
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from scripts.inference import inference_process

# Constants
AUDIO_MAX_DURATION = 4000  # milliseconds; pydub measures audio length in ms

is_shared_ui = "fffiloni/tts-hallo-talking-portrait" in os.environ.get('SPACE_ID', '')

# Download the Hallo weights (cached by huggingface_hub after the first run)
hallo_dir = snapshot_download(repo_id="fudan-generative-ai/hallo", local_dir="pretrained_models")
# Utility Functions

def is_mp3(file_path):
    """Return True if mutagen can parse the file as an MP3."""
    try:
        MP3(file_path)
        return True
    except Exception:
        return False


def convert_mp3_to_wav(mp3_file_path, wav_file_path):
    audio = AudioSegment.from_mp3(mp3_file_path)
    audio.export(wav_file_path, format="wav")
    return wav_file_path


def trim_audio(file_path, output_path, max_duration):
    """Trim a WAV file to at most max_duration milliseconds."""
    audio = AudioSegment.from_wav(file_path)
    if len(audio) > max_duration:
        audio = audio[:max_duration]
    audio.export(output_path, format="wav")
    return output_path


def add_silence_to_wav(wav_file_path, duration_s=1):
    """Append duration_s seconds of silence to the WAV file, in place."""
    audio = AudioSegment.from_wav(wav_file_path)
    silence = AudioSegment.silent(duration=duration_s * 1000)
    (audio + silence).export(wav_file_path, format="wav")
    return wav_file_path
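
# A minimal sketch of how the helpers above chain together during
# preprocessing (paths are hypothetical):
#
#   wav = convert_mp3_to_wav("voice.mp3", "voice.wav")  # only for MP3 uploads
#   wav = trim_audio(wav, "voice-trimmed.wav", AUDIO_MAX_DURATION)
#   wav = add_silence_to_wav(wav)  # pads 1 s of silence in place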

def check_mp3(file_path):
    """Convert an uploaded MP3 to WAV; pass other uploads through unchanged."""
    if is_mp3(file_path):
        unique_id = uuid.uuid4()
        wav_file_path = f"{os.path.splitext(file_path)[0]}-{unique_id}.wav"
        converted_audio = convert_mp3_to_wav(file_path, wav_file_path)
        print(f"File converted to {wav_file_path}")
        return converted_audio, gr.update(value=converted_audio, visible=True)
    else:
        print("The file is not an MP3 file.")
        return file_path, gr.update(value=file_path, visible=True)

def check_and_convert_webp_to_png(input_path, output_path):
    try:
        with Image.open(input_path) as img:
            if img.format == 'WEBP':
                img.save(output_path, 'PNG')
                print(f"Converted {input_path} to {output_path}")
                return output_path
            else:
                print(f"The file {input_path} is not in WebP format.")
                return input_path
    except IOError:
        print(f"Cannot open {input_path}. The file might not exist or is not an image.")
        return None  # signal failure explicitly; gr.Image accepts None

def convert_user_uploaded_webp(input_path):
    unique_id = uuid.uuid4()
    output_file = f"converted_to_png_portrait-{unique_id}.png"
    ready_png = check_and_convert_webp_to_png(input_path, output_file)
    print(f"PORTRAIT PNG FILE: {ready_png}")
    return ready_png


def clear_audio_elms():
    return gr.update(value=None, visible=False)

def change_video_codec(input_file, output_file, codec='libx264', audio_codec='aac'):
    """Re-encode the video to H.264/AAC so browsers can play it."""
    try:
        ffmpeg.input(input_file).output(output_file, vcodec=codec, acodec=audio_codec).run(overwrite_output=True)
        print(f'Successfully changed codec of {input_file} and saved as {output_file}')
    except ffmpeg.Error as e:
        # e.stderr is only populated when run() captures it, so guard the decode
        stderr = e.stderr.decode() if e.stderr else str(e)
        print(f'Error occurred: {stderr}')
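
# ffmpeg-python is a thin wrapper over the ffmpeg binary (which must be on
# PATH); the call above builds roughly this command line:
#
#   ffmpeg -i <input_file> -vcodec libx264 -acodec aac -y <output_file>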

# Gradio APIs

def generate_portrait(prompt_image):
    if not prompt_image:
        raise gr.Error("Can't generate a portrait without a prompt!")
    try:
        client = Client("ByteDance/SDXL-Lightning")
    except Exception:
        raise gr.Error("ByteDance/SDXL-Lightning space's API might not be ready, please wait, or upload an image instead.")
    result = client.predict(prompt=prompt_image, ckpt="4-Step", api_name="/generate_image")
    return convert_user_uploaded_webp(result)
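
# The SDXL-Lightning Space may return a WebP file, so the generated image goes
# through the same WebP-to-PNG conversion as user uploads before reaching
# the gr.Image component (which is configured with format="png").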

def generate_voice_with_parler(prompt_audio, voice_description):
    if not prompt_audio:
        raise gr.Error("Can't generate a voice without text to synthesize!")
    if not voice_description:
        gr.Info("For better control, you may want to provide a voice character description next time.", duration=10, visible=True)
    try:
        client = Client("parler-tts/parler_tts_mini")
    except Exception:
        raise gr.Error("parler-tts/parler_tts_mini space's API might not be ready, please wait, or upload an audio file instead.")
    result = client.predict(text=prompt_audio, description=voice_description, api_name="/gen_tts")
    return result, gr.update(value=result, visible=True)

def get_whisperspeech(prompt_audio_whisperspeech, audio_to_clone):
    try:
        client = Client("collabora/WhisperSpeech")
    except Exception:
        raise gr.Error("collabora/WhisperSpeech space's API might not be ready, please wait, or upload an audio file instead.")
    result = client.predict(multilingual_text=prompt_audio_whisperspeech, speaker_audio=handle_file(audio_to_clone), speaker_url="", cps=14, api_name="/whisper_speech_demo")
    return result, gr.update(value=result, visible=True)


def get_maskGCT_TTS(prompt_audio_maskGCT, audio_to_clone):
    try:
        client = Client("amphion/maskgct")
    except Exception:
        raise gr.Error("amphion/maskgct space's API might not be ready, please wait, or upload an audio file instead.")
    result = client.predict(prompt_wav=handle_file(audio_to_clone), target_text=prompt_audio_maskGCT, target_len=-1, n_timesteps=25, api_name="/predict")
    return result, gr.update(value=result, visible=True)
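
# All three TTS helpers follow the same pattern: connect to a public Space
# with gradio_client, call its named endpoint, and return the audio path
# twice: once for the gr.Audio player and once for the hidden gr.File
# download element. handle_file() uploads a local file to the remote Space
# before the call.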

# Talking Portrait Generation

def run_hallo(source_image, driving_audio, progress=gr.Progress(track_tqdm=True)):
    """Run Hallo inference on one image/audio pair and return the output video path."""
    unique_id = uuid.uuid4()
    args = argparse.Namespace(
        config='configs/inference/default.yaml',
        source_image=source_image,
        driving_audio=driving_audio,
        output=f'output-{unique_id}.mp4',
        pose_weight=1.0,
        face_weight=1.0,
        lip_weight=1.0,
        face_expand_ratio=1.2,
        checkpoint=None
    )
    inference_process(args)
    return f'output-{unique_id}.mp4'
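
# run_hallo() assumes the Hallo repo layout: configs/inference/default.yaml
# and scripts/inference.py live next to this app, and the weights downloaded
# at startup sit in pretrained_models/. checkpoint=None leaves checkpoint
# selection to the config file.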

def generate_talking_portrait(portrait, voice, progress=gr.Progress(track_tqdm=True)):
    if not portrait:
        raise gr.Error("Please provide a portrait to animate.")
    if not voice:
        raise gr.Error("Please provide audio (4 seconds max).")
    if is_shared_ui:
        # On the shared Space, cap the driving audio at AUDIO_MAX_DURATION
        unique_id = uuid.uuid4()
        trimmed_output_file = f"trimmed-{unique_id}.wav"
        voice = trim_audio(voice, trimmed_output_file, AUDIO_MAX_DURATION)
    ready_audio = add_silence_to_wav(voice)
    print(f"1 second of silence added to {voice}")
    talking_portrait_vid = run_hallo(portrait, ready_audio)
    final_output_file = f"converted_{talking_portrait_vid}"
    change_video_codec(talking_portrait_vid, final_output_file)
    return final_output_file

# Gradio Interface
css = '''
/* Your CSS here */
'''

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# TTS x Hallo Talking Portrait Generator")
        with gr.Row(elem_id="column-names"):
            gr.Markdown("## 1. Load Portrait")
            gr.Markdown("## 2. Load Voice")
            gr.Markdown("## 3. Result")
        with gr.Group(elem_id="main-group"):
            with gr.Row():
                with gr.Column():
                    portrait = gr.Image(sources=["upload"], type="filepath", format="png", elem_id="image-block")
                    prompt_image = gr.Textbox(label="Generate image", lines=2, max_lines=2)
                    gen_image_btn = gr.Button("Generate portrait (optional)")
                with gr.Column(elem_id="audio-column"):
                    voice = gr.Audio(type="filepath", elem_id="audio-block")
                    preprocess_audio_file = gr.File(visible=False)
                    with gr.Tab("Parler TTS", elem_id="parler-tab"):
                        prompt_audio = gr.Textbox(label="Text to synthesize", lines=3, max_lines=3, elem_id="text-synth")
                        voice_description = gr.Textbox(label="Voice description", lines=3, max_lines=3, elem_id="voice-desc")
                        gen_voice_btn = gr.Button("Generate voice (optional)")
                    with gr.Tab("WhisperSpeech", elem_id="whisperspeech-tab"):
                        prompt_audio_whisperspeech = gr.Textbox(label="Text to synthesize", lines=2, max_lines=2, elem_id="text-synth-wsp")
                        audio_to_clone = gr.Audio(label="Voice to clone", type="filepath", elem_id="audio-clone-elm")
                        gen_wsp_voice_btn = gr.Button("Generate voice clone (optional)")
                    with gr.Tab("MaskGCT TTS", elem_id="maskGCT-tab"):
                        prompt_audio_maskGCT = gr.Textbox(label="Text to synthesize", lines=2, max_lines=2, elem_id="text-synth-maskGCT")
                        audio_to_clone_maskGCT = gr.Audio(label="Voice to clone", type="filepath", elem_id="audio-clone-elm-maskGCT")
                        gen_maskGCT_voice_btn = gr.Button("Generate voice clone (optional)")
                with gr.Column(elem_id="result-column"):
                    result = gr.Video(elem_id="video-block")
                    submit_btn = gr.Button("Go talking Portrait!", elem_id="main-submit")
        with gr.Row(elem_id="pro-tips"):
            gr.Markdown("# Hallo Pro Tips:")
            gr.Markdown("# TTS Pro Tips:")

    # Event wiring
    portrait.upload(convert_user_uploaded_webp, inputs=[portrait], outputs=[portrait], queue=False, show_api=False)
    voice.upload(check_mp3, inputs=[voice], outputs=[voice, preprocess_audio_file], queue=False, show_api=False)
    voice.clear(clear_audio_elms, inputs=None, outputs=[preprocess_audio_file], queue=False, show_api=False)
    gen_image_btn.click(generate_portrait, inputs=[prompt_image], outputs=[portrait], queue=False, show_api=False)
    gen_voice_btn.click(generate_voice_with_parler, inputs=[prompt_audio, voice_description], outputs=[voice, preprocess_audio_file], queue=False, show_api=False)
    gen_wsp_voice_btn.click(get_whisperspeech, inputs=[prompt_audio_whisperspeech, audio_to_clone], outputs=[voice, preprocess_audio_file], queue=False, show_api=False)
    gen_maskGCT_voice_btn.click(get_maskGCT_TTS, inputs=[prompt_audio_maskGCT, audio_to_clone_maskGCT], outputs=[voice, preprocess_audio_file], queue=False, show_api=False)
    submit_btn.click(generate_talking_portrait, inputs=[portrait, voice], outputs=[result], show_api=False)

demo.queue(max_size=2).launch(show_error=True, show_api=False)