# Copyright Alibaba Inc. All Rights Reserved.
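# Example invocation (assuming this file is saved as infer.py; the model and
# asset paths below are the script defaults, adjust them to your checkout):
#   python infer.py \
#       --image_path ./assets/images/woman.png \
#       --audio_path ./assets/audios/woman.wav \
#       --prompt "A woman is talking."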
import argparse
import os
import subprocess
from datetime import datetime
from pathlib import Path

import librosa
import torch
from transformers import Wav2Vec2Model, Wav2Vec2Processor

from FantasyTalking.Diffsynth import ModelManager, WanVideoPipeline
from FantasyTalking.model import FantasyTalkingAudioConditionModel
from FantasyTalking.utils import get_audio_features, resize_image_by_longest_edge, save_video

def parse_args():
    parser = argparse.ArgumentParser(description="Inference script for FantasyTalking.")
    parser.add_argument(
        "--wan_model_dir",
        type=str,
        default="./models/Wan2.1-I2V-14B-720P",
        required=False,
        help="The directory of the Wan I2V 14B model.",
    )
    parser.add_argument(
        "--fantasytalking_model_path",
        type=str,
        default="./models/fantasytalking_model.ckpt",
        required=False,
        help="The .ckpt path of the FantasyTalking model.",
    )
    parser.add_argument(
        "--wav2vec_model_dir",
        type=str,
        default="./models/wav2vec2-base-960h",
        required=False,
        help="The directory of the Wav2Vec2 model.",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        default="./assets/images/woman.png",
        required=False,
        help="The path of the input image.",
    )
    parser.add_argument(
        "--audio_path",
        type=str,
        default="./assets/audios/woman.wav",
        required=False,
        help="The path of the input audio.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="A woman is talking.",
        required=False,
        help="The text prompt.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./output",
        help="Directory in which to save the output video.",
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=512,
        help="The image is resized proportionally so its longest edge matches this size.",
    )
    parser.add_argument(
        "--audio_scale",
        type=float,
        default=1.0,
        help="Audio condition injection weight.",
    )
    parser.add_argument(
        "--prompt_cfg_scale",
        type=float,
        default=5.0,
        required=False,
        help="Prompt CFG scale.",
    )
    parser.add_argument(
        "--audio_cfg_scale",
        type=float,
        default=5.0,
        required=False,
        help="Audio CFG scale.",
    )
    parser.add_argument(
        "--max_num_frames",
        type=int,
        default=81,
        required=False,
        help="Maximum number of frames to generate; audio beyond max_num_frames / fps seconds is truncated.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=23,
        required=False,
    )
    parser.add_argument(
        "--num_persistent_param_in_dit",
        type=int,
        default=None,
        required=False,
        help="Maximum number of DiT parameters kept resident in VRAM; lower values reduce the VRAM required.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1111,
        required=False,
    )
    args = parser.parse_args()
    return args

def load_models(args):
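    # Sanity check (not in the original script): everything below assumes a
    # CUDA device, so fail early with a clear message if none is available.
    if not torch.cuda.is_available():
        raise RuntimeError("A CUDA-capable GPU is required to run this script.")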
print("🔄 Loading Wan I2V models...") | |
model_manager = ModelManager(device="cpu") | |
model_manager.load_models( | |
[ | |
[ | |
f"{args.wan_model_dir}/diffusion_pytorch_model-00001-of-00007.safetensors", | |
f"{args.wan_model_dir}/diffusion_pytorch_model-00002-of-00007.safetensors", | |
f"{args.wan_model_dir}/diffusion_pytorch_model-00003-of-00007.safetensors", | |
f"{args.wan_model_dir}/diffusion_pytorch_model-00004-of-00007.safetensors", | |
f"{args.wan_model_dir}/diffusion_pytorch_model-00005-of-00007.safetensors", | |
f"{args.wan_model_dir}/diffusion_pytorch_model-00006-of-00007.safetensors", | |
f"{args.wan_model_dir}/diffusion_pytorch_model-00007-of-00007.safetensors", | |
], | |
f"{args.wan_model_dir}/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", | |
f"{args.wan_model_dir}/models_t5_umt5-xxl-enc-bf16.pth", | |
f"{args.wan_model_dir}/Wan2.1_VAE.pth", | |
], | |
torch_dtype=torch.bfloat16, | |
) | |
print("✅ Wan I2V models loaded.") | |
pipe = WanVideoPipeline.from_model_manager( | |
model_manager, torch_dtype=torch.bfloat16, device="cuda" | |
) | |
print("🔄 Loading FantasyTalking model...") | |
fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda") | |
fantasytalking.load_audio_processor(args.fantasytalking_model_path, pipe.dit) | |
print("✅ FantasyTalking model loaded.") | |
print("🧠 Enabling VRAM management...") | |
pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit) | |
print("🔄 Loading Wav2Vec2 processor and model...") | |
wav2vec_processor = Wav2Vec2Processor.from_pretrained(args.wav2vec_model_dir) | |
wav2vec = Wav2Vec2Model.from_pretrained(args.wav2vec_model_dir).to("cuda") | |
print("✅ Wav2Vec2 loaded.") | |
return pipe, fantasytalking, wav2vec_processor, wav2vec | |

def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
    print("📁 Creating output directory...")
    os.makedirs(args.output_dir, exist_ok=True)

    print(f"🔊 Getting duration of audio: {args.audio_path}")
    # librosa >= 0.10 renamed the `filename` keyword of get_duration to `path`.
    duration = librosa.get_duration(path=args.audio_path)
    print(f"🎞️ Duration: {duration:.2f}s")
    num_frames = min(int(args.fps * duration), args.max_num_frames)
    print(f"📽️ Calculated number of frames: {num_frames}")
print("🎧 Extracting audio features...") | |
audio_wav2vec_fea = get_audio_features( | |
wav2vec, wav2vec_processor, args.audio_path, args.fps, num_frames | |
) | |
print("✅ Audio features extracted.") | |
print("🖼️ Loading and resizing image...") | |
image = resize_image_by_longest_edge(args.image_path, args.image_size) | |
width, height = image.size | |
print(f"✅ Image resized to: {width}x{height}") | |
print("🔄 Projecting audio features...") | |
audio_proj_fea = fantasytalking.get_proj_fea(audio_wav2vec_fea) | |
pos_idx_ranges = fantasytalking.split_audio_sequence( | |
audio_proj_fea.size(1), num_frames=num_frames | |
) | |
audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding( | |
audio_proj_fea, pos_idx_ranges, expand_length=4 | |
) | |
print("✅ Audio features projected and split.") | |
print("🚀 Generating video from image + audio...") | |
video_audio = pipe( | |
prompt=args.prompt, | |
negative_prompt="人物静止不动,静止,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", | |
input_image=image, | |
width=width, | |
height=height, | |
num_frames=num_frames, | |
num_inference_steps=30, | |
seed=args.seed, | |
tiled=True, | |
audio_scale=args.audio_scale, | |
cfg_scale=args.prompt_cfg_scale, | |
audio_cfg_scale=args.audio_cfg_scale, | |
audio_proj=audio_proj_split, | |
audio_context_lens=audio_context_lens, | |
latents_num_frames=(num_frames - 1) // 4 + 1, | |
) | |
print("✅ Video frames generated.") | |

    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path_tmp = f"{args.output_dir}/tmp_{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
    print(f"💾 Saving temporary video without audio to: {save_path_tmp}")
    save_video(video_audio, save_path_tmp, fps=args.fps, quality=5)

    save_path = f"{args.output_dir}/{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
print(f"🔊 Merging video with audio using FFmpeg...") | |
    final_command = [
        "ffmpeg", "-y", "-i", save_path_tmp, "-i", args.audio_path,
        "-c:v", "libx264", "-c:a", "aac", "-shortest", save_path,
    ]
    subprocess.run(final_command, check=True)
    print(f"✅ Final video saved to: {save_path}")

    print("🧹 Removing temporary video file...")
    os.remove(save_path_tmp)
    return save_path

if __name__ == "__main__":
    print("🚦 Starting main script...")
    args = parse_args()
    pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
    video_path = main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
    print(f"🎉 Done! Final video path: {video_path}")