Spaces:
Sleeping
Sleeping
Update utils.py
Browse files
utils.py
CHANGED
@@ -8,6 +8,7 @@ import cv2
|
|
8 |
from models.mtcnn import MTCNN
|
9 |
from models.inception_resnet_v1 import InceptionResnetV1
|
10 |
|
|
|
11 |
device = 'cpu'
|
12 |
|
13 |
SINGLE_FACE_PATH = "data/test_images"
|
@@ -17,11 +18,11 @@ THRESHOLD = 0.8
|
|
17 |
mtcnn = MTCNN(
|
18 |
image_size=160, margin=0, min_face_size=20,
|
19 |
thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
|
20 |
-
device=
|
21 |
keep_all=True
|
22 |
)
|
23 |
|
24 |
-
resnet = InceptionResnetV1(pretrained='vggface2').eval().to(
|
25 |
|
26 |
def extract_face(box, img, margin=20):
|
27 |
img_size = (len(img), len(img[0]))
|
@@ -52,9 +53,9 @@ def trans(img):
|
|
52 |
|
53 |
def find_and_compare(model, face):
|
54 |
face = trans(face)
|
55 |
-
face_embedded = model(face.unsqueeze(dim=0).to(
|
56 |
for usr in os.listdir(SINGLE_FACE_PATH):
|
57 |
-
usr_embedding = torch.load(SINGLE_FACE_PATH + "/" + usr + "/" + usr + ".pt").to(
|
58 |
if (face_embedded - usr_embedding).norm().item() < THRESHOLD:
|
59 |
print("Found " + usr + " in the image!")
|
60 |
return usr
|
|
|
8 |
from models.mtcnn import MTCNN
|
9 |
from models.inception_resnet_v1 import InceptionResnetV1
|
10 |
|
11 |
+
|
12 |
device = 'cpu'
|
13 |
|
14 |
SINGLE_FACE_PATH = "data/test_images"
|
|
|
18 |
mtcnn = MTCNN(
|
19 |
image_size=160, margin=0, min_face_size=20,
|
20 |
thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
|
21 |
+
device=device,
|
22 |
keep_all=True
|
23 |
)
|
24 |
|
25 |
+
resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)
|
26 |
|
27 |
def extract_face(box, img, margin=20):
|
28 |
img_size = (len(img), len(img[0]))
|
|
|
53 |
|
54 |
def find_and_compare(model, face):
|
55 |
face = trans(face)
|
56 |
+
face_embedded = model(face.unsqueeze(dim=0).to(device))
|
57 |
for usr in os.listdir(SINGLE_FACE_PATH):
|
58 |
+
usr_embedding = torch.load(SINGLE_FACE_PATH + "/" + usr + "/" + usr + ".pt", map_location=torch.device(device)).to(device)
|
59 |
if (face_embedded - usr_embedding).norm().item() < THRESHOLD:
|
60 |
print("Found " + usr + " in the image!")
|
61 |
return usr
|