fffiloni committed
Commit 9661bf3 · verified · 1 Parent(s): 13bc874

Create app.py

Files changed (1)
  1. app.py +147 -0
app.py ADDED
@@ -0,0 +1,147 @@
+ import gradio as gr
+ import os
+ import sys
+ import shutil
+ import uuid
+ import subprocess
+ from glob import glob
+ from huggingface_hub import snapshot_download
+
+ # Download models
+ os.makedirs("checkpoints", exist_ok=True)
+
+ snapshot_download(
+     repo_id = "chunyu-li/LatentSync",
+     local_dir = "./checkpoints"
+ )
+
+ import argparse
+ from omegaconf import OmegaConf
+ import torch
+ from diffusers import AutoencoderKL, DDIMScheduler
+ from latentsync.models.unet import UNet3DConditionModel
+ from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
+ from diffusers.utils.import_utils import is_xformers_available
+ from accelerate.utils import set_seed
+ from latentsync.whisper.audio2feature import Audio2Feature
+
+
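+ # Inference entry point: lip-syncs the input video to the input audio and
+ # returns the path of the generated MP4.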
+ def main(video_path, audio_path, seed=-1, progress=gr.Progress(track_tqdm=True)):
+     # seed is not exposed in the Gradio UI; -1 means "use a random seed"
+     inference_ckpt_path = "checkpoints/latentsync_unet.pt"
+     unet_config_path = "configs/unet/second_stage.yaml"
+     config = OmegaConf.load(unet_config_path)
+
+     print(f"Input video path: {video_path}")
+     print(f"Input audio path: {audio_path}")
+     print(f"Loaded checkpoint path: {inference_ckpt_path}")
+
+     scheduler = DDIMScheduler.from_pretrained("configs")
+
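+     # The UNet's cross-attention width determines which Whisper encoder the
+     # checkpoint expects (768-dim -> whisper small, 384-dim -> whisper tiny).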
+     if config.model.cross_attention_dim == 768:
+         whisper_model_path = "checkpoints/whisper/small.pt"
+     elif config.model.cross_attention_dim == 384:
+         whisper_model_path = "checkpoints/whisper/tiny.pt"
+     else:
+         raise NotImplementedError("cross_attention_dim must be 768 or 384")
+
+     audio_encoder = Audio2Feature(model_path=whisper_model_path, device="cuda", num_frames=config.data.num_frames)
+
+     vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+     vae.config.scaling_factor = 0.18215
+     vae.config.shift_factor = 0
+
+     unet, _ = UNet3DConditionModel.from_pretrained(
+         OmegaConf.to_container(config.model),
+         inference_ckpt_path,  # load checkpoint
+         device="cpu",
+     )
+
+     unet = unet.to(dtype=torch.float16)
+
+     # set xformers
+     if is_xformers_available():
+         unet.enable_xformers_memory_efficient_attention()
+
+     pipeline = LipsyncPipeline(
+         vae=vae,
+         audio_encoder=audio_encoder,
+         unet=unet,
+         scheduler=scheduler,
+     ).to("cuda")
+
+     if seed != -1:
+         set_seed(seed)
+     else:
+         torch.seed()
+
+     print(f"Initial seed: {torch.initial_seed()}")
+
+     unique_id = str(uuid.uuid4())
+     video_out_path = f"video_out{unique_id}.mp4"
+
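+     # Run the lip-sync pipeline; the mask video is written alongside the
+     # output as "<name>_mask.mp4".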
+     pipeline(
+         video_path=video_path,
+         audio_path=audio_path,
+         video_out_path=video_out_path,
+         video_mask_path=video_out_path.replace(".mp4", "_mask.mp4"),
+         num_frames=config.data.num_frames,
+         num_inference_steps=config.run.inference_steps,
+         guidance_scale=1.0,
+         weight_dtype=torch.float16,
+         width=config.data.resolution,
+         height=config.data.resolution,
+     )
+
+     return video_out_path
+
+
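+ # Gradio UI: a driving video and an audio clip in, the lip-synced video out.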
+ css="""
+ div#col-container{
+     margin: 0 auto;
+     max-width: 982px;
+ }
+ """
+ with gr.Blocks(css=css) as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown("# LatentSync: Audio Conditioned Latent Diffusion Models for Lip Sync")
+         gr.Markdown("LatentSync is an end-to-end lip sync framework based on audio conditioned latent diffusion models without any intermediate motion representation, diverging from previous diffusion-based lip sync methods that rely on pixel space diffusion or two-stage generation.")
+         gr.HTML("""
+         <div style="display:flex;column-gap:4px;">
+             <a href="https://github.com/bytedance/LatentSync">
+                 <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
+             </a>
+             <a href="https://arxiv.org/abs/2412.09262">
+                 <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
+             </a>
+             <a href="https://huggingface.co/spaces/fffiloni/LatentSync?duplicate=true">
+                 <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
+             </a>
+             <a href="https://huggingface.co/fffiloni">
+                 <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
+             </a>
+         </div>
+         """)
+         with gr.Row():
+             with gr.Column():
+                 video_input = gr.Video(label="Video Control", format="mp4")
+                 audio_input = gr.Audio(label="Audio Input", type="filepath")
+                 submit_btn = gr.Button("Submit")
+             with gr.Column():
+                 video_result = gr.Video(label="Result")
+
+         gr.Examples(
+             examples = [
+                 ["assets/demo1_video.mp4", "assets/demo1_audio.wav"],
+                 ["assets/demo2_video.mp4", "assets/demo2_audio.wav"],
+                 ["assets/demo3_video.mp4", "assets/demo3_audio.wav"],
+             ],
+             inputs = [video_input, audio_input]
+         )
+
+         submit_btn.click(
+             fn = main,
+             inputs = [video_input, audio_input],
+             outputs = [video_result]
+         )
+
+ demo.queue().launch(show_api=False, show_error=True)
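
For reference, a minimal sketch of calling this app programmatically with gradio_client instead of the web UI. This is not part of the commit: it assumes a running instance whose API is reachable (the demo launches with show_api=False, so you may need to run it with the API enabled), that the endpoint is named "/main" after the function above, and that the repo's demo assets are present.

    from gradio_client import Client, handle_file

    # Point the client at a running instance of this app (local URL or the Space id).
    client = Client("http://127.0.0.1:7860")

    result_path = client.predict(
        handle_file("assets/demo1_video.mp4"),  # video_input
        handle_file("assets/demo1_audio.wav"),  # audio_input
        api_name="/main",
    )
    print(result_path)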