|
|
|
|
|
|
|
import os
import subprocess
import sys
import time

from cog import BasePredictor, Input, Path
|
|
|
# Local directory the model weights are downloaded into; Predictor.setup()
# only downloads them when this path does not already exist.
MODEL_CACHE = "checkpoints"

# Remote weights archive handed to download_weights() on first setup.
MODEL_URL = "https://weights.replicate.delivery/default/chunyu-li/LatentSync/model.tar"
|
|
|
def download_weights(url, dest):
    """Download the archive at ``url`` into ``dest`` using the ``pget`` CLI.

    Runs ``pget -xf <url> <dest>`` (the ``-x`` flag presumably extracts the
    archive — confirm against the pget version in the image), printing the
    source, destination and elapsed wall-clock seconds.

    Raises:
        subprocess.CalledProcessError: if ``pget`` exits with a non-zero status.
    """
    t0 = time.time()
    for label, value in (("downloading url: ", url), ("downloading to: ", dest)):
        print(label, value)
    cmd = ["pget", "-xf", url, dest]
    subprocess.check_call(cmd, close_fds=False)
    elapsed = time.time() - t0
    print("downloading took: ", elapsed)
|
|
|
class Predictor(BasePredictor):
    """Cog predictor that runs LatentSync's ``scripts.inference`` module in a
    subprocess to lip-sync an input video to an input audio track."""

    # Auxiliary checkpoints shipped under <MODEL_CACHE>/auxiliary. setup()
    # symlinks them into torch hub's checkpoint cache, presumably so the
    # libraries that load them find local copies instead of downloading.
    _AUX_CHECKPOINTS = (
        "2DFAN4-cd938726ad.zip",
        "s3fd-619a316812.pth",
        "vgg16-397923af.pth",
    )
    _CONFIG_PATH = "configs/unet/second_stage.yaml"
    _OUTPUT_PATH = "/tmp/video_out.mp4"

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient

        Downloads the weights into ``MODEL_CACHE`` if they are not present,
        then links the auxiliary checkpoints into
        ``~/.cache/torch/hub/checkpoints``. Linking is best-effort: an existing
        entry is left untouched and other OS errors only print a warning.
        """
        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)

        hub_dir = os.path.expanduser("~/.cache/torch/hub/checkpoints")
        os.makedirs(hub_dir, exist_ok=True)
        aux_dir = os.path.abspath(os.path.join(MODEL_CACHE, "auxiliary"))
        for name in self._AUX_CHECKPOINTS:
            link = os.path.join(hub_dir, name)
            # lexists() also sees dangling links, so re-running setup() does
            # not fail the way a repeated `ln -s` did.
            if os.path.lexists(link):
                continue
            try:
                os.symlink(os.path.join(aux_dir, name), link)
            except OSError as err:
                print(f"warning: could not link {name} into {hub_dir}: {err}")

    def predict(
        self,
        video: Path = Input(
            description="Input video", default=None
        ),
        audio: Path = Input(
            description="Input audio to lip-sync the video to", default=None
        ),
        guidance_scale: float = Input(
            description="Guidance scale", ge=0, le=10, default=1.0
        ),
        seed: int = Input(
            description="Set to 0 for Random seed", default=0
        ),
    ) -> Path:
        """Run a single prediction on the model

        Args:
            video: input video file.
            audio: audio file the video's lips are synced to.
            guidance_scale: forwarded to the inference script unchanged.
            seed: RNG seed; values <= 0 are replaced by a random 16-bit seed.

        Returns:
            Path to the generated video.

        Raises:
            ValueError: if ``video`` or ``audio`` is missing.
            subprocess.CalledProcessError: if the inference script fails.
            RuntimeError: if the script succeeds but writes no output file.
        """
        if video is None or audio is None:
            raise ValueError("Both `video` and `audio` inputs are required")

        if seed <= 0:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        output_path = self._OUTPUT_PATH
        # Drop any result from a previous prediction so a failed run can never
        # hand back a stale video.
        try:
            os.remove(output_path)
        except FileNotFoundError:
            pass

        cmd = [
            sys.executable, "-m", "scripts.inference",
            "--unet_config_path", self._CONFIG_PATH,
            "--inference_ckpt_path", os.path.join(MODEL_CACHE, "latentsync_unet.pt"),
            "--guidance_scale", str(guidance_scale),
            "--video_path", str(video),
            "--audio_path", str(audio),
            "--video_out_path", output_path,
            "--seed", str(seed),
        ]
        # argv list, no shell: user-supplied paths are passed verbatim (spaces
        # and shell metacharacters are safe), and a non-zero exit raises.
        subprocess.check_call(cmd)

        if not os.path.exists(output_path):
            raise RuntimeError(f"Inference finished but produced no output at {output_path}")
        return Path(output_path)
|
|