HienK64BKHN committed on
Commit
7a62482
verified
1 Parent(s): e066bcd

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +23 -0
  2. utils.py +96 -0
app.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from utils import recognition
4
+
5
# Data-loading worker count (0 on Windows, 4 elsewhere) and compute device.
# Neither name is referenced by the rest of the visible app.py code; they are
# kept so that anything importing them keeps working.
workers = 0 if os.name == 'nt' else 4
device = "cpu"


# Texts shown on the Gradio page.
title = "Compare Faces"
# The previous description was a leftover from a food-classification demo;
# this one describes what the app actually does.
description = (
    "A FaceNet pipeline (MTCNN for the face bounding boxes and Inception "
    "Resnet V1, trained on the VGGFace2 dataset, for feature extraction) "
    "that detects the faces in an image and matches them against known people."
)
article = "Created by HienK64BKHN."
12
+
13
+
14
# Gradio UI: one uploaded image in; the recognised names (text) and the
# annotated image out.
image_input = gr.Image(label="Image")
result_outputs = [
    gr.Textbox(label="Names"),
    gr.Image(label="Result"),
]

demo = gr.Interface(
    fn=recognition,
    inputs=image_input,
    outputs=result_outputs,
    title=title,
    description=description,
    article=article,
)

demo.launch(debug=False)
utils.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ import numpy as np
4
+ import pandas as pd
5
+ import os
6
+ from PIL import Image
7
+ import cv2
8
+ from facenet_pytorch import MTCNN, InceptionResnetV1
9
+
10
# All tensors and models in this module are placed on this device.
device = "cpu"

# SINGLE_FACE_PATH is scanned by find_and_compare for "<user>/<user>.pt"
# embedding files. MULTI_FACES_PATH is not used by the code visible here.
SINGLE_FACE_PATH = "facenet-pytorch-master/data/test_images"
MULTI_FACES_PATH = "facenet-pytorch-master/data/test_images_2"
# Two embeddings closer than this (L2 norm of their difference) count as the
# same person.
THRESHOLD = 0.8

# Face detector; recognition() only uses its detect() method.
mtcnn_options = dict(
    image_size=160,
    margin=0,
    min_face_size=20,
    thresholds=[0.6, 0.7, 0.7],
    factor=0.709,
    post_process=True,
    device=device,
    keep_all=True,
)
mtcnn = MTCNN(**mtcnn_options)

# Embedding network with VGGFace2 weights, switched to eval mode for inference.
resnet = InceptionResnetV1(pretrained='vggface2')
resnet = resnet.eval().to(device)
24
+
25
def extract_face(box, img, margin=20):
    """Crop the face described by ``box`` out of ``img`` and resize it to 160x160.

    Args:
        box: Bounding box as ``(x1, y1, x2, y2)`` in pixel coordinates.
        img: Image as a NumPy array indexed ``[row, column, ...]``.
        margin: Padding, expressed relative to the final 160-pixel crop, that
            is scaled to the size of ``box`` and added around it.

    Returns:
        A ``PIL.Image.Image`` with the resized face, or ``False`` when the
        padded box does not overlap the image (callers test for ``False``).
    """
    face_size = 160
    # NumPy images are (rows, columns): shape[0] is the height and shape[1]
    # the width. x coordinates must be clamped to the width and y coordinates
    # to the height; mixing them up breaks crops on non-square images.
    img_height, img_width = img.shape[:2]

    # Build a margin around the original box, proportional to the box size.
    pad_x = margin * (box[2] - box[0]) / (face_size - margin)
    pad_y = margin * (box[3] - box[1]) / (face_size - margin)

    left = int(max(box[0] - pad_x / 2, 0))
    top = int(max(box[1] - pad_y / 2, 0))
    right = int(min(box[2] + pad_x / 2, img_width))
    bottom = int(min(box[3] + pad_y / 2, img_height))

    crop = img[top:bottom, left:right]
    if crop.size == 0:
        return False

    face = cv2.resize(crop, (face_size, face_size), interpolation=cv2.INTER_AREA)
    return Image.fromarray(face)
44
+
45
def trans(img):
    """Convert ``img`` to a tensor using torchvision's ``ToTensor`` transform."""
    to_tensor = torchvision.transforms.ToTensor()
    return to_tensor(img)
51
+
52
def find_and_compare(model, face):
    """Identify ``face`` by comparing its embedding with the stored ones.

    Every entry ``<user>`` of ``SINGLE_FACE_PATH`` that contains a file
    ``<user>/<user>.pt`` is treated as a known person whose embedding is
    loaded from that file.

    Args:
        model: Callable that maps a batch of image tensors to embeddings.
        face: Face crop in a format accepted by ``trans``.

    Returns:
        The name of the closest known person whose embedding lies within
        ``THRESHOLD`` (L2 norm of the difference) of the face embedding, or
        ``"unknown"`` when nobody is close enough.
    """
    # Inference only: no need to record an autograd graph.
    with torch.no_grad():
        face_embedded = model(trans(face).unsqueeze(dim=0).to(device))

    best_name = None
    best_distance = THRESHOLD
    for usr in os.listdir(SINGLE_FACE_PATH):
        embedding_file = os.path.join(SINGLE_FACE_PATH, usr, usr + ".pt")
        # Skip stray entries (plain files, folders without an embedding).
        if not os.path.isfile(embedding_file):
            continue
        # map_location lets embeddings saved on another device load here.
        usr_embedding = torch.load(embedding_file, map_location=device).to(device)
        distance = (face_embedded - usr_embedding).norm().item()
        # Keep the closest match rather than the first one under the
        # threshold, so the result does not depend on directory order.
        if distance < best_distance:
            best_name = usr
            best_distance = distance

    if best_name is None:
        return "unknown"

    print("Found " + best_name + " in the image!")
    return best_name
62
+
63
def recognition(img):
    """Detect the faces in ``img``, name the known ones and draw the result.

    Args:
        img: Image as a NumPy array (or array-like) indexed
            ``[row, column, channel]``, or ``None`` when no image was given.

    Returns:
        A tuple ``(message, annotated)``. ``message`` lists the recognised
        names (``"Found a, b"``) or is ``"Found no one!"``. ``annotated`` is a
        ``PIL.Image.Image`` with a box and a label drawn on every detected
        face, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        # Nothing to analyse.
        return "Found no one!", None

    frame = np.asarray(img)
    # Draw on a separate copy so boxes and labels painted for earlier faces
    # never end up inside the crop of a later, overlapping face.
    annotated = np.copy(frame)

    # Thinner strokes and smaller text for images under 1000 px on both sides.
    if frame.shape[0] < 1000 and frame.shape[1] < 1000:
        box_thickness = 3
        text_thickness = 1
        text_scale = 0.75
    else:
        box_thickness = 6
        text_thickness = 2
        text_scale = 2

    names = []

    boxes, _ = mtcnn.detect(frame)
    if boxes is None:
        # The detector reports "no face found" as None instead of an empty list.
        boxes = []

    for box in boxes:
        face = extract_face(box, frame)
        if face is False:
            continue
        name = find_and_compare(model=resnet, face=face)
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        annotated = cv2.rectangle(annotated, top_left, bottom_right, (255, 0, 0), box_thickness)
        annotated = cv2.putText(annotated, name, top_left, cv2.FONT_HERSHEY_DUPLEX,
                                text_scale, (0, 255, 0), text_thickness, cv2.LINE_8)
        if name != "unknown" and name not in names:
            names.append(name)
        print(names)

    if names:
        result = "Found " + ", ".join(names)
    else:
        result = "Found no one!"
    return result, Image.fromarray(annotated)