|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import mediapipe as mp
|
|
import cv2
|
|
from decord import VideoReader
|
|
from einops import rearrange
|
|
import os
|
|
import numpy as np
|
|
import torch
|
|
import tqdm
|
|
from eval.fvd import compute_our_fvd
|
|
|
|
|
|
class FVD:
|
|
def __init__(self, resolution=(224, 224)):
|
|
self.face_detector = mp.solutions.face_detection.FaceDetection(model_selection=0, min_detection_confidence=0.5)
|
|
self.resolution = resolution
|
|
|
|
def detect_face(self, image):
|
|
height, width = image.shape[:2]
|
|
|
|
results = self.face_detector.process(image)
|
|
|
|
if not results.detections:
|
|
raise Exception("Face not detected")
|
|
|
|
detection = results.detections[0]
|
|
bounding_box = detection.location_data.relative_bounding_box
|
|
xmin = int(bounding_box.xmin * width)
|
|
ymin = int(bounding_box.ymin * height)
|
|
face_width = int(bounding_box.width * width)
|
|
face_height = int(bounding_box.height * height)
|
|
|
|
|
|
xmin = max(0, xmin)
|
|
ymin = max(0, ymin)
|
|
xmax = min(width, xmin + face_width)
|
|
ymax = min(height, ymin + face_height)
|
|
image = image[ymin:ymax, xmin:xmax]
|
|
|
|
return image
|
|
|
|
def detect_video(self, video_path, real: bool = True):
|
|
vr = VideoReader(video_path)
|
|
video_frames = vr[20:36].asnumpy()
|
|
vr.seek(0)
|
|
faces = []
|
|
for frame in video_frames:
|
|
face = self.detect_face(frame)
|
|
face = cv2.resize(face, (self.resolution[1], self.resolution[0]), interpolation=cv2.INTER_AREA)
|
|
faces.append(face)
|
|
|
|
if len(faces) != 16:
|
|
return None
|
|
faces = np.stack(faces, axis=0)
|
|
faces = torch.from_numpy(faces)
|
|
return faces
|
|
|
|
|
|
def eval_fvd(real_videos_dir, fake_videos_dir, device="cpu"):
    """Compute and print the FVD between paired real and generated face clips.

    Each ``*.mp4`` in ``fake_videos_dir`` is paired with the real clip in
    ``real_videos_dir`` of the same name, with a trailing ``_out`` removed
    (``clip_out.mp4`` -> ``clip.mp4``). A pair is skipped when either clip
    yields no features, i.e. ``FVD.detect_video`` returns None.

    Args:
        real_videos_dir: directory holding the reference videos.
        fake_videos_dir: directory holding the generated videos.
        device: device string passed to ``compute_our_fvd`` (default "cpu").

    Returns:
        Whatever ``compute_our_fvd`` returns; it is also printed.

    Raises:
        ValueError: if no usable video pair was found.
    """
    fvd = FVD()
    real_features_list = []
    fake_features_list = []
    fake_suffix = "_out.mp4"

    # sorted() makes the processing order deterministic across runs.
    for file in tqdm.tqdm(sorted(os.listdir(fake_videos_dir))):
        if not file.endswith(".mp4"):
            continue
        # Strip "_out" only when it is the actual suffix (str.replace would
        # rewrite every occurrence in the name).
        if file.endswith(fake_suffix):
            real_name = file[: -len(fake_suffix)] + ".mp4"
        else:
            real_name = file
        real_video_path = os.path.join(real_videos_dir, real_name)
        fake_video_path = os.path.join(fake_videos_dir, file)

        real_features = fvd.detect_video(real_video_path, real=True)
        fake_features = fvd.detect_video(fake_video_path, real=False)
        if real_features is None or fake_features is None:
            continue

        real_features_list.append(real_features)
        fake_features_list.append(fake_features)

    if not real_features_list:
        # torch.stack([]) would otherwise fail with an unhelpful message.
        raise ValueError(f"No usable video pairs found in {fake_videos_dir!r}")

    # Scale pixel values by 1/255 before computing FVD.
    real_features = torch.stack(real_features_list) / 255.0
    fake_features = torch.stack(fake_features_list) / 255.0
    result = compute_our_fvd(real_features, fake_features, device=device)
    print(result)
    return result
|
|
|
|
|
|
if __name__ == "__main__":
    # Local import: only the command-line entry point needs argparse.
    import argparse

    parser = argparse.ArgumentParser(description="Compute FVD between real and generated face videos.")
    # Defaults are the original hard-coded dataset locations, so running
    # with no arguments behaves exactly as before.
    parser.add_argument(
        "--real_videos_dir",
        default="/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/segmented/cross",
        help="directory with the reference videos",
    )
    parser.add_argument(
        "--fake_videos_dir",
        default="/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/segmented/latentsync_cross",
        help="directory with the generated (*_out.mp4) videos",
    )
    args = parser.parse_args()

    eval_fvd(args.real_videos_dir, args.fake_videos_dir)
|
|
|