bornet commited on
Commit
726c8e7
·
verified ·
1 Parent(s): 528e972

Sample usage with the Hugging Face model

Browse files
Files changed (1) hide show
  1. edgeface_sample_use.py +83 -0
edgeface_sample_use.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env -S uv run
2
+
3
+ # /// script
4
+ # requires-python = "<= 3.12"
5
+ # dependencies = [
6
+ # "torchvision",
7
+ # "huggingface_hub",
8
+ # "timm",
9
+ # "opencv-python",
10
+ # "mediapipe",
11
+ # "timm",
12
+ # ]
13
+ # ///
14
+
15
+ import os
16
+ import sys
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import torchvision.transforms as transforms
21
+ from huggingface_hub import hf_hub_download
22
+ import cv2
23
+
24
+ from utils import align_crop
25
+ from timmfrv2 import TimmFRWrapperV2, model_configs
26
+
27
+ model_name = sys.argv[1]
28
+ image_1 = sys.argv[2]
29
+ image_2 = sys.argv[3]
30
+
31
+
32
+ def load_and_crop(image_filename_1, image_filename_2):
33
+ img_1 = cv2.imread(image_filename_1)
34
+ print(img_1)
35
+ print(type(img_1))
36
+ print(img_1.shape)
37
+ img_1 = cv2.cvtColor(img_1, cv2.COLOR_RGB2BGR)
38
+ print(img_1)
39
+ print(type(img_1))
40
+ print(img_1.shape)
41
+ img_2 = cv2.imread(image_filename_2)
42
+ crop_1 = cv2.cvtColor(align_crop(img_1), cv2.COLOR_RGB2BGR)
43
+ crop_2 = cv2.cvtColor(align_crop(img_2), cv2.COLOR_RGB2BGR)
44
+ return crop_1, crop_2
45
+
46
+
47
+ transform = transforms.Compose(
48
+ [
49
+ transforms.ToTensor(),
50
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
51
+ ]
52
+ )
53
+
54
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
+
56
+ print("Download model")
57
+ model_path = hf_hub_download(
58
+ repo_id=model_configs[model_name]["repo"],
59
+ filename=model_configs[model_name]["filename"],
60
+ local_dir="models",
61
+ )
62
+ print(f"Model downloaded in {model_path}")
63
+
64
+ print("Create model")
65
+ model = TimmFRWrapperV2(model_configs[model_name]["timm_model"], batchnorm=False)
66
+ model = model_configs[model_name]["post_setup"](model)
67
+ model.load_state_dict(torch.load(model_path, map_location="cpu"))
68
+ model = model.eval()
69
+ model.to(device)
70
+
71
+ crop_a, crop_b = load_and_crop(image_1, image_2)
72
+
73
+ with torch.no_grad():
74
+ ea = model(transform(crop_a)[None].to(device))[0][None]
75
+ eb = model(transform(crop_b)[None].to(device))[0][None]
76
+ pct = float(F.cosine_similarity(ea, eb).item() * 100)
77
+ pct = max(0, min(100, pct))
78
+ print(f"{pct:.2f}% match")
79
+
80
+ # cv2.imshow("crop left", crop_a)
81
+ # cv2.imshow("crop right", crop_b)
82
+ # cv2.waitKey(0)
83
+ # cv2.destroyAllWindows()