|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torchvision import transforms
|
|
import cv2
|
|
from einops import rearrange
|
|
import mediapipe as mp
|
|
import torch
|
|
import numpy as np
|
|
from typing import Union
|
|
from .affine_transform import AlignRestore, laplacianSmooth
|
|
import face_alignment
|
|
|
|
"""
|
|
If you are enlarging the image, you should prefer to use INTER_LINEAR or INTER_CUBIC interpolation. If you are shrinking the image, you should prefer to use INTER_AREA interpolation.
|
|
https://stackoverflow.com/questions/23853632/which-kind-of-interpolation-best-for-resizing-image
|
|
"""
|
|
|
|
|
|
def load_fixed_mask(resolution: int, mask_path: str = "latentsync/utils/mask.png") -> torch.Tensor:
    """Load the fixed face mask and return it as a (c, h, w) float tensor in [0, 1].

    Args:
        resolution: Target square size; the mask is resized to (resolution, resolution).
        mask_path: Path to the mask image (defaults to the bundled mask for
            backward compatibility).

    Returns:
        torch.Tensor of shape (c, h, w) with values in [0, 1].

    Raises:
        FileNotFoundError: If the mask image is missing or unreadable —
            cv2.imread returns None instead of raising, which would otherwise
            surface as a confusing cvtColor error.
    """
    mask_image = cv2.imread(mask_path)
    if mask_image is None:
        raise FileNotFoundError(f"Mask image not found or unreadable: {mask_path}")
    mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
    # INTER_AREA is preferred for shrinking (see module docstring).
    mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_AREA) / 255.0
    mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
    return mask_image
|
|
|
|
|
|
class ImageProcessor:
    """Face detection, alignment and mask preparation for lip-sync pipelines.

    Depending on ``mask``, images are either masked from facial landmarks
    ("mouth", "face", "eye"), masked by a fixed horizontal split ("half"),
    or aligned with an affine warp and combined with a fixed mask image
    ("fix_mask").

    Args:
        resolution: Output square size in pixels for faces and masks.
        mask: Masking strategy — "fix_mask", "mouth", "face", "eye" or "half".
        device: "cpu" uses MediaPipe FaceMesh for landmarks; any other value
            loads a ``face_alignment`` model on that device.
        mask_image: Optional (c, h, w) tensor overriding the default fixed mask.
    """

    def __init__(self, resolution: int = 512, mask: str = "fix_mask", device: str = "cpu", mask_image=None):
        self.resolution = resolution
        self.resize = transforms.Resize(
            (resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True
        )
        # Maps [0, 1] pixels to [-1, 1]; inplace to avoid an extra copy.
        self.normalize = transforms.Normalize([0.5], [0.5], inplace=True)
        self.mask = mask

        # Landmark-based masks need MediaPipe's 478-point face mesh.
        self.face_mesh = None
        if mask in ["mouth", "face", "eye"]:
            self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True)

        if mask == "fix_mask":
            self.smoother = laplacianSmooth()
            self.restorer = AlignRestore()
            if mask_image is None:
                self.mask_image = load_fixed_mask(resolution)
            else:
                self.mask_image = mask_image

        if device != "cpu":
            # Accelerated path: 68-point landmarks via the face_alignment package.
            self.fa = face_alignment.FaceAlignment(
                face_alignment.LandmarksType.TWO_D, flip_input=False, device=device
            )
        else:
            # CPU fallback: no face_alignment model. affine_transform() then
            # derives its 68 points from MediaPipe, so ensure a FaceMesh exists
            # instead of clobbering it with None — the previous code set
            # self.face_mesh = None in both device branches, which destroyed the
            # mesh created above and made every CPU landmark path raise
            # AttributeError on self.face_mesh.process(...).
            self.fa = None
            if self.face_mesh is None:
                self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True)

    def detect_facial_landmarks(self, image: np.ndarray):
        """Return the 478 MediaPipe landmarks as (x, y) pixel coordinates.

        Args:
            image: RGB image array of shape (height, width, channels).

        Raises:
            RuntimeError: If no face is found in the image.
        """
        height, width, _ = image.shape
        results = self.face_mesh.process(image)
        if not results.multi_face_landmarks:
            raise RuntimeError("Face not detected")
        face_landmarks = results.multi_face_landmarks[0]
        # MediaPipe returns normalized [0, 1] coordinates; scale to pixels.
        landmark_coordinates = [
            (int(landmark.x * width), int(landmark.y * height)) for landmark in face_landmarks.landmark
        ]
        return landmark_coordinates

    def preprocess_one_masked_image(self, image: torch.Tensor) -> np.ndarray:
        """Resize one image and build (pixel_values, masked_pixel_values, mask).

        The mask zeroes out the selected region (a "mouth"/"face" polygon, the
        lower "half", or everything below the eyes) and is returned inverted,
        i.e. 1 inside the masked-out region.

        Raises:
            ValueError: For an unsupported ``self.mask`` type.
        """
        image = self.resize(image)

        if self.mask == "mouth" or self.mask == "face":
            # NOTE(review): image is a (c, h, w) tensor after self.resize, but
            # detect_facial_landmarks unpacks (h, w, c) — confirm callers pass
            # channel-last data for these mask types.
            landmark_coordinates = self.detect_facial_landmarks(image)
            if self.mask == "mouth":
                surround_landmarks = mouth_surround_landmarks
            else:
                surround_landmarks = face_surround_landmarks

            points = [landmark_coordinates[landmark] for landmark in surround_landmarks]
            points = np.array(points)
            mask = np.ones((self.resolution, self.resolution))
            # Fill the region polygon with zeros (masked out).
            mask = cv2.fillPoly(mask, pts=[points], color=(0, 0, 0))
            mask = torch.from_numpy(mask)
            mask = mask.unsqueeze(0)
        elif self.mask == "half":
            mask = torch.ones((self.resolution, self.resolution))
            height = mask.shape[0]
            mask[height // 2 :, :] = 0  # zero the lower half
            mask = mask.unsqueeze(0)
        elif self.mask == "eye":
            mask = torch.ones((self.resolution, self.resolution))
            landmark_coordinates = self.detect_facial_landmarks(image)
            # Landmark 195 sits on the nose ridge; mask everything below it.
            y = landmark_coordinates[195][1]
            mask[y:, :] = 0
            mask = mask.unsqueeze(0)
        else:
            raise ValueError("Invalid mask type")

        image = image.to(dtype=torch.float32)
        pixel_values = self.normalize(image / 255.0)
        masked_pixel_values = pixel_values * mask
        mask = 1 - mask  # invert: 1 marks the masked-out region

        return pixel_values, masked_pixel_values, mask

    def affine_transform(self, image: torch.Tensor) -> np.ndarray:
        """Align the face in ``image`` and return (face, box, affine_matrix).

        Uses face_alignment landmarks when available (GPU path), otherwise
        converts MediaPipe's 478 landmarks to the 68-point layout.

        Returns:
            face: (c, h, w) tensor of the aligned face at self.resolution.
            box: [0, 0, w, h] of the warped face before the final resize.
            affine_matrix: The warp matrix, for later restoration.

        Raises:
            RuntimeError: If no face is detected.
        """
        if self.fa is None:
            landmark_coordinates = np.array(self.detect_facial_landmarks(image))
            lm68 = mediapipe_lm478_to_face_alignment_lm68(landmark_coordinates)
        else:
            detected_faces = self.fa.get_landmarks(image)
            if detected_faces is None:
                raise RuntimeError("Face not detected")
            lm68 = detected_faces[0]

        # Temporally smooth the landmarks, then collapse to 3 anchor points:
        # left brow mean, right brow mean, nose mean (iBUG 68-point ranges).
        points = self.smoother.smooth(lm68)
        lmk3_ = np.zeros((3, 2))
        lmk3_[0] = points[17:22].mean(0)
        lmk3_[1] = points[22:27].mean(0)
        lmk3_[2] = points[27:36].mean(0)

        # NOTE(review): image.copy() implies a numpy array here despite the
        # torch.Tensor annotation — confirm what callers actually pass.
        face, affine_matrix = self.restorer.align_warp_face(
            image.copy(), lmks3=lmk3_, smooth=True, border_mode="constant"
        )
        box = [0, 0, face.shape[1], face.shape[0]]  # warped-face extent before resize
        # INTER_CUBIC is preferred for enlarging (see module docstring).
        face = cv2.resize(face, (self.resolution, self.resolution), interpolation=cv2.INTER_CUBIC)
        face = rearrange(torch.from_numpy(face), "h w c -> c h w")
        return face, box, affine_matrix

    def preprocess_fixed_mask_image(self, image: torch.Tensor, affine_transform=False):
        """Normalize one image and apply the fixed mask.

        Returns (pixel_values, masked_pixel_values, single-channel mask).
        """
        if affine_transform:
            image, _, _ = self.affine_transform(image)
        else:
            image = self.resize(image)
        pixel_values = self.normalize(image / 255.0)
        masked_pixel_values = pixel_values * self.mask_image
        return pixel_values, masked_pixel_values, self.mask_image[0:1]

    def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray], affine_transform=False):
        """Batch version: returns stacked (pixel_values, masked_pixel_values, masks).

        Accepts (b, h, w, c) or (b, c, h, w) input; channel-last is converted.
        """
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        if images.shape[3] == 3:
            images = rearrange(images, "b h w c -> b c h w")
        if self.mask == "fix_mask":
            results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]
        else:
            results = [self.preprocess_one_masked_image(image) for image in images]

        pixel_values_list, masked_pixel_values_list, masks_list = list(zip(*results))
        return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list)

    def process_images(self, images: Union[torch.Tensor, np.ndarray]):
        """Resize and normalize a batch of images to [-1, 1] without masking."""
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        if images.shape[3] == 3:
            images = rearrange(images, "b h w c -> b c h w")
        images = self.resize(images)
        pixel_values = self.normalize(images / 255.0)
        return pixel_values

    def close(self):
        """Release the MediaPipe FaceMesh resources, if any were created."""
        if self.face_mesh is not None:
            self.face_mesh.close()
|
|
|
|
|
|
def mediapipe_lm478_to_face_alignment_lm68(lm478, return_2d=True):
    """Map MediaPipe's 478 face landmarks onto the 68-point face_alignment layout.

    Args:
        lm478: Array-like of shape (478, 2+) with x in column 0 and y in
            column 1 (extra columns such as z are ignored).
        return_2d: Unused; kept for backward compatibility. The output is
            always the (x, y) columns.

    Returns:
        np.ndarray of shape (68, 2) with the selected (x, y) coordinates.
    """
    lm478 = np.asarray(lm478)
    # Fancy indexing selects all 68 points in one vectorized step,
    # replacing the previous per-index Python loop.
    return lm478[landmark_points_68, :2]


# Indices into the 478-point MediaPipe mesh that approximate the classic
# 68-point (iBUG/dlib) ordering: jaw, brows, nose, eyes, mouth.
landmark_points_68 = [
    162,
    234,
    93,
    58,
    172,
    136,
    149,
    148,
    152,
    377,
    378,
    365,
    397,
    288,
    323,
    454,
    389,
    71,
    63,
    105,
    66,
    107,
    336,
    296,
    334,
    293,
    301,
    168,
    197,
    5,
    4,
    75,
    97,
    2,
    326,
    305,
    33,
    160,
    158,
    133,
    153,
    144,
    362,
    385,
    387,
    263,
    373,
    380,
    61,
    39,
    37,
    0,
    267,
    269,
    291,
    405,
    314,
    17,
    84,
    181,
    78,
    82,
    13,
    312,
    308,
    317,
    14,
    87,
]
|
|
|
|
|
|
|
|
# MediaPipe FaceMesh indices tracing a closed polygon around the mouth region;
# used by preprocess_one_masked_image (mask == "mouth") as cv2.fillPoly points.
mouth_surround_landmarks = [
    164,
    165,
    167,
    92,
    186,
    57,
    43,
    106,
    182,
    83,
    18,
    313,
    406,
    335,
    273,
    287,
    410,
    322,
    391,
    393,
]

# MediaPipe FaceMesh indices tracing a closed polygon around the lower face;
# used by preprocess_one_masked_image (mask == "face") as cv2.fillPoly points.
face_surround_landmarks = [
    152,
    377,
    400,
    378,
    379,
    365,
    397,
    288,
    435,
    433,
    411,
    425,
    423,
    327,
    326,
    94,
    97,
    98,
    203,
    205,
    187,
    213,
    215,
    58,
    172,
    136,
    150,
    149,
    176,
    148,
]
|
|
|
|
if __name__ == "__main__":
    # Smoke test: align the first frame of a sample video and save it to disk.
    image_processor = ImageProcessor(512, mask="fix_mask")
    video = cv2.VideoCapture("/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/original/val/RD_Radio57_000.mp4")
    try:
        while True:
            ret, frame = video.read()
            # read() signals failure via ret; frame is None then, so fail fast
            # instead of crashing inside rearrange/affine_transform.
            if not ret:
                raise RuntimeError("Could not read a frame from the video")

            frame = rearrange(torch.Tensor(frame).type(torch.uint8), "h w c -> c h w")
            face, _, _ = image_processor.affine_transform(frame)
            break  # only the first frame is needed
    finally:
        video.release()  # always free the capture handle

    face = (rearrange(face, "c h w -> h w c").detach().cpu().numpy()).astype(np.uint8)
    cv2.imwrite("face.jpg", face)
|
|
|
|
|
|
|
|
|