import os

import cv2
import numpy as np
import pandas as pd
import torch
import torchvision
from PIL import Image

from models.mtcnn import MTCNN
from models.inception_resnet_v1 import InceptionResnetV1

# Inference device; switch to 'cuda' when a GPU is available.
device = 'cpu'

# Enrollment root: one sub-folder per user, each holding "<user>.pt" with that
# user's reference face embedding.
SINGLE_FACE_PATH = "data/test_images"
MULTI_FACES_PATH = "data/test_images_2"
# Maximum L2 distance between embeddings for two faces to count as a match.
THRESHOLD = 0.8

# Face detector; keep_all=True so every face in a group photo is returned.
mtcnn = MTCNN(
    image_size=160,
    margin=0,
    min_face_size=20,
    thresholds=[0.6, 0.7, 0.7],
    factor=0.709,
    post_process=True,
    device=device,
    keep_all=True,
)

# Embedding network pre-trained on VGGFace2; eval mode for inference only.
resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)


def extract_face(box, img, margin=20):
    """Crop the detected face `box` out of `img` with a proportional margin.

    Args:
        box: (x1, y1, x2, y2) face bounding box from MTCNN.
        img: image as an H x W x C numpy array.
        margin: margin in pixels at the target 160x160 scale.

    Returns:
        A 160x160 PIL.Image of the face, or False when the crop is empty.
    """
    face_size = 160
    height, width = img.shape[0], img.shape[1]
    # Scale the requested pixel margin to the size of the detected box.
    margin_x = margin * (box[2] - box[0]) / (face_size - margin)
    margin_y = margin * (box[3] - box[1]) / (face_size - margin)
    # Expand the box by half the margin on each side, clamped to the image.
    # BUGFIX: x-coordinates must clamp to the image WIDTH and y-coordinates to
    # the HEIGHT — the original swapped the two, mis-cropping near the edges.
    x1 = int(max(box[0] - margin_x / 2, 0))
    y1 = int(max(box[1] - margin_y / 2, 0))
    x2 = int(min(box[2] + margin_x / 2, width))
    y2 = int(min(box[3] + margin_y / 2, height))
    crop = img[y1:y2, x1:x2]
    if crop.size == 0:
        # Degenerate box (zero area after clamping) — nothing to crop.
        return False
    face = cv2.resize(crop, (face_size, face_size),
                      interpolation=cv2.INTER_AREA)
    return Image.fromarray(face)


def trans(img):
    """Convert a PIL image (or HWC array) to a CHW float tensor in [0, 1]."""
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])
    return transform(img)


def find_and_compare(model, face):
    """Embed `face` with `model` and compare against every enrolled user.

    Returns the first user whose stored embedding is within THRESHOLD
    (L2 distance) of the face embedding, or "unknown" when nobody matches.
    """
    face_tensor = trans(face).unsqueeze(dim=0).to(device)
    face_embedded = model(face_tensor)
    for usr in os.listdir(SINGLE_FACE_PATH):
        # Each user's reference embedding is stored as <user>/<user>.pt.
        embedding_path = os.path.join(SINGLE_FACE_PATH, usr, usr + ".pt")
        usr_embedding = torch.load(
            embedding_path, map_location=torch.device(device)
        ).to(device)
        if (face_embedded - usr_embedding).norm().item() < THRESHOLD:
            print("Found " + usr + " in the image!")
            return usr
    return "unknown"


def recognition(img):
    """Detect every face in `img` and label each with the matching user name.

    Args:
        img: input image (PIL.Image or array-like).

    Returns:
        (summary string, annotated PIL.Image).
    """
    img = np.copy(np.asarray(img))
    # Scale annotation thickness and font size to the image resolution.
    if img.shape[0] < 1000 and img.shape[1] < 1000:
        box_thickness = 3
        text_thickness = 1
        text_scale = 0.75
    else:
        box_thickness = 6
        text_thickness = 2
        text_scale = 2
    names = []
    boxes, _ = mtcnn.detect(img)
    # BUGFIX: mtcnn.detect returns boxes=None when no face is found; the
    # original crashed iterating None.
    if boxes is None:
        boxes = []
    for box in boxes:
        face = extract_face(box, img)
        if face is False:
            continue
        name = find_and_compare(model=resnet, face=face)
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        img = cv2.rectangle(img, top_left, bottom_right,
                            (255, 0, 0), box_thickness)
        img = cv2.putText(img, name, top_left, cv2.FONT_HERSHEY_DUPLEX,
                          text_scale, (0, 255, 0), text_thickness, cv2.LINE_8)
        if name != "unknown" and name not in names:
            names.append(name)
    print(names)
    if names:
        # e.g. "Found alice, bob"
        result = "Found " + ", ".join(names)
    else:
        result = "Found no one!"
    return result, Image.fromarray(img)