HienK64BKHN commited on
Commit
2b77606
·
verified ·
1 Parent(s): 2fd2924

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +5 -4
utils.py CHANGED
@@ -8,6 +8,7 @@ import cv2
8
  from models.mtcnn import MTCNN
9
  from models.inception_resnet_v1 import InceptionResnetV1
10
 
 
11
  device = 'cpu'
12
 
13
  SINGLE_FACE_PATH = "data/test_images"
@@ -17,11 +18,11 @@ THRESHOLD = 0.8
17
  mtcnn = MTCNN(
18
  image_size=160, margin=0, min_face_size=20,
19
  thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
20
- device='cpu',
21
  keep_all=True
22
  )
23
 
24
- resnet = InceptionResnetV1(pretrained='vggface2').eval().to('cpu')
25
 
26
  def extract_face(box, img, margin=20):
27
  img_size = (len(img), len(img[0]))
@@ -52,9 +53,9 @@ def trans(img):
52
 
53
  def find_and_compare(model, face):
54
  face = trans(face)
55
- face_embedded = model(face.unsqueeze(dim=0).to('cpu'))
56
  for usr in os.listdir(SINGLE_FACE_PATH):
57
- usr_embedding = torch.load(SINGLE_FACE_PATH + "/" + usr + "/" + usr + ".pt").to('cpu')
58
  if (face_embedded - usr_embedding).norm().item() < THRESHOLD:
59
  print("Found " + usr + " in the image!")
60
  return usr
 
8
  from models.mtcnn import MTCNN
9
  from models.inception_resnet_v1 import InceptionResnetV1
10
 
11
+
12
  device = 'cpu'
13
 
14
  SINGLE_FACE_PATH = "data/test_images"
 
18
  mtcnn = MTCNN(
19
  image_size=160, margin=0, min_face_size=20,
20
  thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
21
+ device=device,
22
  keep_all=True
23
  )
24
 
25
+ resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)
26
 
27
  def extract_face(box, img, margin=20):
28
  img_size = (len(img), len(img[0]))
 
53
 
54
  def find_and_compare(model, face):
55
  face = trans(face)
56
+ face_embedded = model(face.unsqueeze(dim=0).to(device))
57
  for usr in os.listdir(SINGLE_FACE_PATH):
58
+ usr_embedding = torch.load(SINGLE_FACE_PATH + "/" + usr + "/" + usr + ".pt", map_location=torch.device(device)).to(device)
59
  if (face_embedded - usr_embedding).norm().item() < THRESHOLD:
60
  print("Found " + usr + " in the image!")
61
  return usr