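# Sonic inference module for the Hugging Face Space (orientation comment added):
# it assembles SonicPipeline from a spatio-temporal UNet, a temporal-decoder VAE,
# a CLIP image encoder, Whisper audio features, and optional RIFE interpolation.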
import os

import torch
import torch.utils.checkpoint
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
from tqdm import tqdm
import cv2
from diffusers import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from transformers import WhisperModel, CLIPVisionModelWithProjection, AutoFeatureExtractor

from src.utils.util import save_videos_grid, seed_everything
from src.dataset.test_preprocess import process_bbox, image_audio_to_tensor
from src.models.base.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel, add_ip_adapters
from src.pipelines.pipeline_sonic import SonicPipeline
from src.models.audio_adapter.audio_proj import AudioProjModel
from src.models.audio_adapter.audio_to_bucket import Audio2bucketModel
from src.utils.RIFE.RIFE_HDv3 import RIFEModel
from src.dataset.face_align.align import AlignImage

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
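

# test() generates one clip for a preprocessed (image, audio) batch: it encodes
# the driving audio with Whisper, builds per-frame conditioning tokens and
# motion buckets, and renders the frames with SonicPipeline.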
def test(
    pipe,
    config,
    wav_enc,
    audio_pe,
    audio2bucket,
    image_encoder,
    width,
    height,
    batch,
):
    # Add a batch dimension and move every tensor to the pipeline device.
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.unsqueeze(0).to(pipe.device).float()

    ref_img = batch['ref_img']
    clip_img = batch['clip_images']
    face_mask = batch['face_mask']
    image_embeds = image_encoder(clip_img).image_embeds

    audio_feature = batch['audio_feature']
    audio_len = batch['audio_len']
    step = int(config.step)

    # Run the Whisper encoder over the mel features in windows of 3000 frames
    # (Whisper's 30 s input length), keeping all hidden states for the audio
    # adapter and the last hidden state for the motion-bucket model.
    window = 3000
    audio_prompts = []
    last_audio_prompts = []
    for i in range(0, audio_feature.shape[-1], window):
        audio_prompt = wav_enc.encoder(audio_feature[:, :, i:i + window], output_hidden_states=True).hidden_states
        last_audio_prompt = wav_enc.encoder(audio_feature[:, :, i:i + window]).last_hidden_state
        last_audio_prompt = last_audio_prompt.unsqueeze(-2)
        audio_prompt = torch.stack(audio_prompt, dim=2)
        audio_prompts.append(audio_prompt)
        last_audio_prompts.append(last_audio_prompt)

    # Trim to the real audio length (two encoder states per video frame) and
    # zero-pad both ends so every frame below can read a full context window.
    audio_prompts = torch.cat(audio_prompts, dim=1)
    audio_prompts = audio_prompts[:, :audio_len * 2]
    audio_prompts = torch.cat([torch.zeros_like(audio_prompts[:, :4]), audio_prompts, torch.zeros_like(audio_prompts[:, :6])], 1)

    last_audio_prompts = torch.cat(last_audio_prompts, dim=1)
    last_audio_prompts = last_audio_prompts[:, :audio_len * 2]
    last_audio_prompts = torch.cat([torch.zeros_like(last_audio_prompts[:, :24]), last_audio_prompts, torch.zeros_like(last_audio_prompts[:, :26])], 1)
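
    # For every output frame, collect the audio conditioning tensors and a
    # motion bucket; the zeroed (unconditional) audio clip is kept for
    # classifier-free guidance inside the pipeline.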
    ref_tensor_list = []
    audio_tensor_list = []
    uncond_audio_tensor_list = []
    motion_buckets = []
    for i in tqdm(range(audio_len // step)):
        # A short 10-state window conditions the audio adapter for this frame;
        # a longer 50-state window of last-layer features drives the motion bucket.
        audio_clip = audio_prompts[:, i * 2 * step:i * 2 * step + 10].unsqueeze(0)
        audio_clip_for_bucket = last_audio_prompts[:, i * 2 * step:i * 2 * step + 50].unsqueeze(0)

        motion_bucket = audio2bucket(audio_clip_for_bucket, image_embeds)
        motion_bucket = motion_bucket * 16 + 16  # affine rescale of the raw bucket prediction
        motion_buckets.append(motion_bucket[0])

        cond_audio_clip = audio_pe(audio_clip).squeeze(0)
        uncond_audio_clip = audio_pe(torch.zeros_like(audio_clip)).squeeze(0)

        ref_tensor_list.append(ref_img[0])
        audio_tensor_list.append(cond_audio_clip[0])
        uncond_audio_tensor_list.append(uncond_audio_clip[0])
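
    # Render all frames with SonicPipeline. Guidance scale 1 controls the
    # appearance (reference image) conditioning and guidance scale 2 the audio
    # conditioning, as the config parameter names suggest.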
    video = pipe(
        ref_img,
        clip_img,
        face_mask,
        audio_tensor_list,
        uncond_audio_tensor_list,
        motion_buckets,
        height=height,
        width=width,
        num_frames=len(audio_tensor_list),
        decode_chunk_size=config.decode_chunk_size,
        motion_bucket_scale=config.motion_bucket_scale,
        fps=config.fps,
        noise_aug_strength=config.noise_aug_strength,
        min_guidance_scale1=config.min_appearance_guidance_scale,
        max_guidance_scale1=config.max_appearance_guidance_scale,
        min_guidance_scale2=config.audio_guidance_scale,
        max_guidance_scale2=config.audio_guidance_scale,
        overlap=config.overlap,
        shift_offset=config.shift_offset,
        frames_per_batch=config.n_sample_frames,
        num_inference_steps=config.num_inference_steps,
        i2i_noise_strength=config.i2i_noise_strength,
    ).frames

    # Map frames from [-1, 1] to [0, 1] and move the result to the CPU.
    video = (video * 0.5 + 0.5).clamp(0, 1)
    video = torch.cat([video.to(pipe.device)], dim=0).cpu()

    return video
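

# Sonic bundles model loading (__init__), face-aware preprocessing helpers
# (preprocess, crop_image), and end-to-end generation (process) into one class.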
class Sonic():
    config_file = os.path.join(BASE_DIR, 'config/inference/sonic.yaml')
    config = OmegaConf.load(config_file)

    def __init__(self,
                 device_id=0,
                 enable_interpolate_frame=True,
                 ):
        config = self.config
        config.use_interframe = enable_interpolate_frame

        device = 'cuda:{}'.format(device_id) if device_id > -1 else 'cpu'

        config.pretrained_model_name_or_path = os.path.join(BASE_DIR, config.pretrained_model_name_or_path)

        vae = AutoencoderKLTemporalDecoder.from_pretrained(
            config.pretrained_model_name_or_path,
            subfolder="vae",
            variant="fp16")
        val_noise_scheduler = EulerDiscreteScheduler.from_pretrained(
            config.pretrained_model_name_or_path,
            subfolder="scheduler")
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            config.pretrained_model_name_or_path,
            subfolder="image_encoder",
            variant="fp16")
        unet = UNetSpatioTemporalConditionModel.from_pretrained(
            config.pretrained_model_name_or_path,
            subfolder="unet",
            variant="fp16")

        # Attach the audio IP-Adapter layers to the UNet, then build the two
        # audio adapters: one projecting Whisper features into conditioning
        # tokens, one predicting the motion bucket.
        add_ip_adapters(unet, [32], [config.ip_audio_scale])

        audio2token = AudioProjModel(seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=1024, context_tokens=32).to(device)
        audio2bucket = Audio2bucketModel(seq_len=50, blocks=1, channels=384, clip_channels=1024, intermediate_dim=1024, output_dim=1, context_tokens=2).to(device)

        unet_checkpoint_path = os.path.join(BASE_DIR, config.unet_checkpoint_path)
        audio2token_checkpoint_path = os.path.join(BASE_DIR, config.audio2token_checkpoint_path)
        audio2bucket_checkpoint_path = os.path.join(BASE_DIR, config.audio2bucket_checkpoint_path)

        unet.load_state_dict(
            torch.load(unet_checkpoint_path, map_location="cpu"),
            strict=True,
        )
        audio2token.load_state_dict(
            torch.load(audio2token_checkpoint_path, map_location="cpu"),
            strict=True,
        )
        audio2bucket.load_state_dict(
            torch.load(audio2bucket_checkpoint_path, map_location="cpu"),
            strict=True,
        )

        if config.weight_dtype == "fp16":
            weight_dtype = torch.float16
        elif config.weight_dtype == "fp32":
            weight_dtype = torch.float32
        elif config.weight_dtype == "bf16":
            weight_dtype = torch.bfloat16
        else:
            raise ValueError(
                f"Do not support weight dtype: {config.weight_dtype}"
            )
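
        # Frozen Whisper-tiny encoder for audio features, a YOLO-based face
        # detector/aligner, and (optionally) RIFE for 2x frame interpolation.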
        whisper = WhisperModel.from_pretrained(os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/')).to(device).eval()
        whisper.requires_grad_(False)

        self.feature_extractor = AutoFeatureExtractor.from_pretrained(os.path.join(BASE_DIR, 'checkpoints/whisper-tiny/'))

        det_path = os.path.join(BASE_DIR, 'checkpoints/yoloface_v5m.pt')
        self.face_det = AlignImage(device, det_path=det_path)

        if config.use_interframe:
            rife = RIFEModel(device=device)
            rife.load_model(os.path.join(BASE_DIR, 'checkpoints', 'RIFE/'))
            self.rife = rife

        image_encoder.to(weight_dtype)
        vae.to(weight_dtype)
        unet.to(weight_dtype)

        pipe = SonicPipeline(
            unet=unet,
            image_encoder=image_encoder,
            vae=vae,
            scheduler=val_noise_scheduler,
        )
        pipe = pipe.to(device=device, dtype=weight_dtype)

        self.pipe = pipe
        self.whisper = whisper
        self.audio2token = audio2token
        self.audio2bucket = audio2bucket
        self.image_encoder = image_encoder
        self.device = device

        print('init done')
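
    # preprocess() runs face detection on the input image and returns the face
    # count plus an expanded crop box; crop_image() applies such a box on disk.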
    def preprocess(self,
                   image_path, expand_ratio=1.0):
        face_image = cv2.imread(image_path)
        h, w = face_image.shape[:2]
        _, _, bboxes = self.face_det(face_image, maxface=True)
        face_num = len(bboxes)
        bbox = []
        bbox_s = None  # stays None when no face is detected
        if face_num > 0:
            x1, y1, ww, hh = bboxes[0]
            x2, y2 = x1 + ww, y1 + hh
            bbox = x1, y1, x2, y2
            bbox_s = process_bbox(bbox, expand_radio=expand_ratio, height=h, width=w)

        return {
            'face_num': face_num,
            'crop_bbox': bbox_s,
        }

    def crop_image(self,
                   input_image_path,
                   output_image_path,
                   crop_bbox):
        face_image = cv2.imread(input_image_path)
        crop_image = face_image[crop_bbox[1]:crop_bbox[3], crop_bbox[0]:crop_bbox[2]]
        cv2.imwrite(output_image_path, crop_image)
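
    # process() is the end-to-end entry point: build the test batch from the
    # image/audio pair, run the diffusion pipeline, optionally interpolate
    # frames with RIFE, and mux the audio back in with ffmpeg.
    # Returns 0 on success and -1 if preprocessing fails.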
    def process(self,
                image_path,
                audio_path,
                output_path,
                min_resolution=512,
                inference_steps=25,
                dynamic_scale=1.0,
                keep_resolution=False,
                seed=None):
        config = self.config
        device = self.device
        pipe = self.pipe
        whisper = self.whisper
        audio2token = self.audio2token
        audio2bucket = self.audio2bucket
        image_encoder = self.image_encoder

        # Per-call overrides of the config defaults.
        if seed is not None:
            config.seed = seed
        config.num_inference_steps = inference_steps
        config.motion_bucket_scale = dynamic_scale
        seed_everything(config.seed)

        video_path = output_path.replace('.mp4', '_noaudio.mp4')
        audio_video_path = output_path

        imSrc_ = Image.open(image_path).convert('RGB')
        raw_w, raw_h = imSrc_.size

        test_data = image_audio_to_tensor(self.face_det, self.feature_extractor, image_path, audio_path, limit=config.frame_num, image_size=min_resolution, area=config.area)
        if test_data is None:
            return -1

        height, width = test_data['ref_img'].shape[-2:]
        if keep_resolution:
            resolution = f'{raw_w // 2 * 2}x{raw_h // 2 * 2}'  # round down to even dimensions for the encoder
        else:
            resolution = f'{width}x{height}'

        video = test(
            pipe,
            config,
            wav_enc=whisper,
            audio_pe=audio2token,
            audio2bucket=audio2bucket,
            image_encoder=image_encoder,
            width=width,
            height=height,
            batch=test_data,
        )

        # Optionally double the frame rate by inserting a RIFE-interpolated
        # frame between each pair of generated frames.
        if config.use_interframe:
            rife = self.rife
            out = video.to(device)
            results = []
            video_len = out.shape[2]
            for idx in tqdm(range(video_len - 1), ncols=0):
                I1 = out[:, :, idx]
                I2 = out[:, :, idx + 1]
                middle = rife.inference(I1, I2).clamp(0, 1).detach()
                results.append(out[:, :, idx])
                results.append(middle)
            results.append(out[:, :, video_len - 1])
            video = torch.stack(results, 2).cpu()

        save_videos_grid(video, video_path, n_rows=video.shape[0], fps=config.fps * 2 if config.use_interframe else config.fps)
        os.system(f"ffmpeg -i '{video_path}' -i '{audio_path}' -s {resolution} -vcodec libx264 -acodec aac -crf 18 -shortest '{audio_video_path}' -y; rm '{video_path}'")
        return 0
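

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original Space code. All paths below
    # are hypothetical placeholders; the checkpoints referenced in __init__ must
    # already be present under BASE_DIR/checkpoints for this to run.
    sonic = Sonic(device_id=0, enable_interpolate_frame=True)

    image_path = 'examples/portrait.png'        # placeholder input image
    cropped_path = 'examples/portrait_crop.png' # placeholder cropped image
    audio_path = 'examples/speech.wav'          # placeholder driving audio
    output_path = 'examples/result.mp4'         # placeholder output video

    face_info = sonic.preprocess(image_path)
    if face_info['face_num'] > 0:
        sonic.crop_image(image_path, cropped_path, face_info['crop_bbox'])
        image_path = cropped_path

    ret = sonic.process(image_path, audio_path, output_path, seed=42)
    print('done' if ret == 0 else 'preprocessing failed')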