Upload 44 files
- .gitattributes +3 -0
- data/.gitignore +3 -0
- data/facenet-pytorch-banner.png +0 -0
- data/multiface.jpg +0 -0
- data/multiface_detected.png +3 -0
- data/onet.pt +3 -0
- data/pnet.pt +3 -0
- data/rnet.pt +3 -0
- data/test_images/angelina_jolie/1.jpg +0 -0
- data/test_images/angelina_jolie/angelina_jolie.pt +3 -0
- data/test_images/bo_vinh/bo_vinh.pt +3 -0
- data/test_images/brad_pitt/brad_pitt.pt +3 -0
- data/test_images/bradley_cooper/1.jpg +3 -0
- data/test_images/bradley_cooper/bradley_cooper.pt +3 -0
- data/test_images/chau_anh/chau_anh.pt +3 -0
- data/test_images/daniel_radcliffe/daniel_radcliffe.pt +3 -0
- data/test_images/hermione_granger/hermione_granger.pt +3 -0
- data/test_images/hien/hien.pt +3 -0
- data/test_images/kate_siegel/1.jpg +0 -0
- data/test_images/kate_siegel/kate_siegel.pt +3 -0
- data/test_images/khanh/khanh.pt +3 -0
- data/test_images/me_hoa/me_hoa.pt +3 -0
- data/test_images/ny_khanh/ny_khanh.pt +3 -0
- data/test_images/paul_rudd/1.jpg +0 -0
- data/test_images/paul_rudd/paul_rudd.pt +3 -0
- data/test_images/ron_weasley/ron_weasley.pt +3 -0
- data/test_images/shea_whigham/1.jpg +3 -0
- data/test_images/shea_whigham/shea_whigham.pt +3 -0
- data/test_images/tu_linh/tu_linh.pt +3 -0
- data/test_images_2/angelina_jolie_brad_pitt/1.jpg +0 -0
- data/test_images_2/bong_chanh/1.jpg +0 -0
- data/test_images_2/bong_chanh/2.jpg +0 -0
- data/test_images_2/khanh_va_ny/1.jpg +0 -0
- data/test_images_2/the_golden_trio/1.jpg +0 -0
- data/test_images_aligned/angelina_jolie/1.png +0 -0
- data/test_images_aligned/bradley_cooper/1.png +0 -0
- data/test_images_aligned/kate_siegel/1.png +0 -0
- data/test_images_aligned/paul_rudd/1.png +0 -0
- data/test_images_aligned/shea_whigham/1.png +0 -0
- models/inception_resnet_v1.py +340 -0
- models/mtcnn.py +519 -0
- models/utils/detect_face.py +378 -0
- models/utils/download.py +102 -0
- models/utils/tensorflow2pytorch.py +416 -0
- models/utils/training.py +144 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/multiface_detected.png filter=lfs diff=lfs merge=lfs -text
+data/test_images/bradley_cooper/1.jpg filter=lfs diff=lfs merge=lfs -text
+data/test_images/shea_whigham/1.jpg filter=lfs diff=lfs merge=lfs -text
data/.gitignore
ADDED
@@ -0,0 +1,3 @@
2018*
*.json
profile.txt
data/facenet-pytorch-banner.png
ADDED
data/multiface.jpg
ADDED
data/multiface_detected.png
ADDED
data/onet.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:165bfbe42940416ccfb977545cf0e976d5bf321f67083ae2aaaa5c764280118d
size 1559269
data/pnet.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a2a71925e0b9996a42f63e47efc1ca19043e69558b5c523b978d611dfae49c8f
size 28570
data/rnet.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bbb937de72efc9ef83b186c49f5f558467a1d7e3453a8ece0d71a886633f6a86
size 403147
data/test_images/angelina_jolie/1.jpg
ADDED
data/test_images/angelina_jolie/angelina_jolie.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:729a0aee6c1c8f1b7b7db49011b2a238d0d54fda993d49271f9f39d876e94a66
size 2816
data/test_images/bo_vinh/bo_vinh.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4d274aaf7b361cf62066c08e4153601bdf9f8cdc2458e0112bbfc1d6c7191235
size 2795
data/test_images/brad_pitt/brad_pitt.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f70ad8cf619238a4856781ac1ef9a2e7164b486adf29bcb618ac394d5725d289
size 2801
data/test_images/bradley_cooper/1.jpg
ADDED
data/test_images/bradley_cooper/bradley_cooper.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a22e6311cdd61223c03c955fa48fdb348dd2f607266e1be33cdf245deb0bc2b6
size 2816
data/test_images/chau_anh/chau_anh.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6b393e3ee47643ba619cf48bf58d30aaa816f8943112c8ea24e8e682d97d3d22
size 2798
data/test_images/daniel_radcliffe/daniel_radcliffe.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c5bc1e56b1b34386698721020a685b1deea3bb3870b7ea5c8043599ca607881d
size 2822
data/test_images/hermione_granger/hermione_granger.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c94f592eb34e84957844f545d841c2559e2ed28a4956c1307d570d41c1bc4ec
size 2822
data/test_images/hien/hien.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7d6a1965b22ac8ab4441fe79dc39345612c398afb78ca58072e66e35df47cc67
size 2722
data/test_images/kate_siegel/1.jpg
ADDED
data/test_images/kate_siegel/kate_siegel.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:43a64061efec1325ca8d8ab5822a0b12b974f64a610b21c699b0465b5ede7dd1
size 2807
data/test_images/khanh/khanh.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f9f31fc909cf9bd67055ca9ec2a375e97c314b944e9568e2e69b1cfcca081242
size 2725
data/test_images/me_hoa/me_hoa.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64861c898832b1c5d30447c469b228d06c84439f52796fdabae90c15d1e61d03
size 2728
data/test_images/ny_khanh/ny_khanh.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dd277690d8a7a10057f3195574e764a5f17472ff1918b4ad8a2b871154778f47
size 2798
data/test_images/paul_rudd/1.jpg
ADDED
data/test_images/paul_rudd/paul_rudd.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:92b5ea7bd25a17d2bdceb79a785ecab4944951555469ce27cfdf1709ba6cbdeb
size 2801
data/test_images/ron_weasley/ron_weasley.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:79a4ebf96e0ed3676d0b7ec7c8d017348736097286cef03dda2fe316a25b33c1
size 2807
data/test_images/shea_whigham/1.jpg
ADDED
data/test_images/shea_whigham/shea_whigham.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ab6085300a2227199b57c7eb259c973d785e31ee5e36dbddc936e07019a5eaf9
size 2810
data/test_images/tu_linh/tu_linh.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c9a9afd2eaabb2e7a4808ee168ad7232c51a0bc71a1fc6d50b1ddf6c6a301fe5
size 2795
data/test_images_2/angelina_jolie_brad_pitt/1.jpg
ADDED
data/test_images_2/bong_chanh/1.jpg
ADDED
data/test_images_2/bong_chanh/2.jpg
ADDED
data/test_images_2/khanh_va_ny/1.jpg
ADDED
data/test_images_2/the_golden_trio/1.jpg
ADDED
data/test_images_aligned/angelina_jolie/1.png
ADDED
data/test_images_aligned/bradley_cooper/1.png
ADDED
data/test_images_aligned/kate_siegel/1.png
ADDED
data/test_images_aligned/paul_rudd/1.png
ADDED
data/test_images_aligned/shea_whigham/1.png
ADDED
models/inception_resnet_v1.py
ADDED
@@ -0,0 +1,340 @@
import os
import requests
from requests.adapters import HTTPAdapter

import torch
from torch import nn
from torch.nn import functional as F

from .utils.download import download_url_to_file


class BasicConv2d(nn.Module):

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            in_planes, out_planes,
            kernel_size=kernel_size, stride=stride,
            padding=padding, bias=False
        )  # verify bias false
        self.bn = nn.BatchNorm2d(
            out_planes,
            eps=0.001,  # value found in tensorflow
            momentum=0.1,  # default pytorch value
            affine=True
        )
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Block35(nn.Module):

    def __init__(self, scale=1.0):
        super().__init__()

        self.scale = scale

        self.branch0 = BasicConv2d(256, 32, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(256, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(256, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
        )

        self.conv2d = nn.Conv2d(96, 256, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        out = self.relu(out)
        return out


class Block17(nn.Module):

    def __init__(self, scale=1.0):
        super().__init__()

        self.scale = scale

        self.branch0 = BasicConv2d(896, 128, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(896, 128, kernel_size=1, stride=1),
            BasicConv2d(128, 128, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(128, 128, kernel_size=(7, 1), stride=1, padding=(3, 0))
        )

        self.conv2d = nn.Conv2d(256, 896, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        out = self.relu(out)
        return out


class Block8(nn.Module):

    def __init__(self, scale=1.0, noReLU=False):
        super().__init__()

        self.scale = scale
        self.noReLU = noReLU

        self.branch0 = BasicConv2d(1792, 192, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(1792, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=(1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(192, 192, kernel_size=(3, 1), stride=1, padding=(1, 0))
        )

        self.conv2d = nn.Conv2d(384, 1792, kernel_size=1, stride=1)
        if not self.noReLU:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        if not self.noReLU:
            out = self.relu(out)
        return out


class Mixed_6a(nn.Module):

    def __init__(self):
        super().__init__()

        self.branch0 = BasicConv2d(256, 384, kernel_size=3, stride=2)

        self.branch1 = nn.Sequential(
            BasicConv2d(256, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=3, stride=1, padding=1),
            BasicConv2d(192, 256, kernel_size=3, stride=2)
        )

        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        return out


class Mixed_7a(nn.Module):

    def __init__(self):
        super().__init__()

        self.branch0 = nn.Sequential(
            BasicConv2d(896, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2)
        )

        self.branch1 = nn.Sequential(
            BasicConv2d(896, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=3, stride=2)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(896, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
            BasicConv2d(256, 256, kernel_size=3, stride=2)
        )

        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out


class InceptionResnetV1(nn.Module):
    """Inception Resnet V1 model with optional loading of pretrained weights.

    Model parameters can be loaded based on pretraining on the VGGFace2 or CASIA-Webface
    datasets. Pretrained state_dicts are automatically downloaded on model instantiation if
    requested and cached in the torch cache. Subsequent instantiations use the cache rather than
    redownloading.

    Keyword Arguments:
        pretrained {str} -- Optional pretraining dataset. Either 'vggface2' or 'casia-webface'.
            (default: {None})
        classify {bool} -- Whether the model should output classification probabilities or feature
            embeddings. (default: {False})
        num_classes {int} -- Number of output classes. If 'pretrained' is set and num_classes not
            equal to that used for the pretrained model, the final linear layer will be randomly
            initialized. (default: {None})
        dropout_prob {float} -- Dropout probability. (default: {0.6})
    """
    def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6, device=None):
        super().__init__()

        # Set simple attributes
        self.pretrained = pretrained
        self.classify = classify
        self.num_classes = num_classes

        if pretrained == 'vggface2':
            tmp_classes = 8631
        elif pretrained == 'casia-webface':
            tmp_classes = 10575
        elif pretrained is None and self.classify and self.num_classes is None:
            raise Exception('If "pretrained" is not specified and "classify" is True, "num_classes" must be specified')

        # Define layers
        self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
        self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
        self.conv2d_4b = BasicConv2d(192, 256, kernel_size=3, stride=2)
        self.repeat_1 = nn.Sequential(
            Block35(scale=0.17),
            Block35(scale=0.17),
            Block35(scale=0.17),
            Block35(scale=0.17),
            Block35(scale=0.17),
        )
        self.mixed_6a = Mixed_6a()
        self.repeat_2 = nn.Sequential(
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
            Block17(scale=0.10),
        )
        self.mixed_7a = Mixed_7a()
        self.repeat_3 = nn.Sequential(
            Block8(scale=0.20),
            Block8(scale=0.20),
            Block8(scale=0.20),
            Block8(scale=0.20),
            Block8(scale=0.20),
        )
        self.block8 = Block8(noReLU=True)
        self.avgpool_1a = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(dropout_prob)
        self.last_linear = nn.Linear(1792, 512, bias=False)
        self.last_bn = nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True)

        if pretrained is not None:
            self.logits = nn.Linear(512, tmp_classes)
            load_weights(self, pretrained)

        if self.classify and self.num_classes is not None:
            self.logits = nn.Linear(512, self.num_classes)

        self.device = torch.device('cpu')
        if device is not None:
            self.device = device
            self.to(device)

    def forward(self, x):
        """Calculate embeddings or logits given a batch of input image tensors.

        Arguments:
            x {torch.tensor} -- Batch of image tensors representing faces.

        Returns:
            torch.tensor -- Batch of embedding vectors or multinomial logits.
        """
        x = self.conv2d_1a(x)
        x = self.conv2d_2a(x)
        x = self.conv2d_2b(x)
        x = self.maxpool_3a(x)
        x = self.conv2d_3b(x)
        x = self.conv2d_4a(x)
        x = self.conv2d_4b(x)
        x = self.repeat_1(x)
        x = self.mixed_6a(x)
        x = self.repeat_2(x)
        x = self.mixed_7a(x)
        x = self.repeat_3(x)
        x = self.block8(x)
        x = self.avgpool_1a(x)
        x = self.dropout(x)
        x = self.last_linear(x.view(x.shape[0], -1))
        x = self.last_bn(x)
        if self.classify:
            x = self.logits(x)
        else:
            x = F.normalize(x, p=2, dim=1)
        return x


def load_weights(mdl, name):
    """Download pretrained state_dict and load into model.

    Arguments:
        mdl {torch.nn.Module} -- Pytorch model.
        name {str} -- Name of dataset that was used to generate pretrained state_dict.

    Raises:
        ValueError: If 'pretrained' not equal to 'vggface2' or 'casia-webface'.
    """
    if name == 'vggface2':
        path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180402-114759-vggface2.pt'
    elif name == 'casia-webface':
        path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180408-102900-casia-webface.pt'
    else:
        raise ValueError('Pretrained models only exist for "vggface2" and "casia-webface"')

    model_dir = os.path.join(get_torch_home(), 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    cached_file = os.path.join(model_dir, os.path.basename(path))
    if not os.path.exists(cached_file):
        download_url_to_file(path, cached_file)

    state_dict = torch.load(cached_file)
    mdl.load_state_dict(state_dict)


def get_torch_home():
    torch_home = os.path.expanduser(
        os.getenv(
            'TORCH_HOME',
            os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')
        )
    )
    return torch_home
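For orientation, a minimal usage sketch of the module above (not part of the commit): it assumes the file is importable as models.inception_resnet_v1 inside this Space, that the VGGFace2 weights can be downloaded on first use, and it uses a random tensor purely as a stand-in for a batch of aligned 160x160 face crops.

import torch
from models.inception_resnet_v1 import InceptionResnetV1

# Embedding network with pretrained VGGFace2 weights, in eval mode so BatchNorm/Dropout are frozen
resnet = InceptionResnetV1(pretrained='vggface2').eval()

# Dummy batch of two 3x160x160 face crops (normally produced by MTCNN)
faces = torch.randn(2, 3, 160, 160)

with torch.no_grad():
    embeddings = resnet(faces)

print(embeddings.shape)  # torch.Size([2, 512]); L2-normalized because classify=False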
models/mtcnn.py
ADDED
@@ -0,0 +1,519 @@
import torch
from torch import nn
import numpy as np
import os

from .utils.detect_face import detect_face, extract_face


class PNet(nn.Module):
    """MTCNN PNet.

    Keyword Arguments:
        pretrained {bool} -- Whether or not to load saved pretrained weights (default: {True})
    """

    def __init__(self, pretrained=True):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 10, kernel_size=3)
        self.prelu1 = nn.PReLU(10)
        self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.conv2 = nn.Conv2d(10, 16, kernel_size=3)
        self.prelu2 = nn.PReLU(16)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=3)
        self.prelu3 = nn.PReLU(32)
        self.conv4_1 = nn.Conv2d(32, 2, kernel_size=1)
        self.softmax4_1 = nn.Softmax(dim=1)
        self.conv4_2 = nn.Conv2d(32, 4, kernel_size=1)

        self.training = False

        if pretrained:
            state_dict_path = os.path.join(os.path.dirname(__file__), '../data/pnet.pt')
            state_dict = torch.load(state_dict_path)
            self.load_state_dict(state_dict)

    def forward(self, x):
        x = self.conv1(x)
        x = self.prelu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.prelu2(x)
        x = self.conv3(x)
        x = self.prelu3(x)
        a = self.conv4_1(x)
        a = self.softmax4_1(a)
        b = self.conv4_2(x)
        return b, a


class RNet(nn.Module):
    """MTCNN RNet.

    Keyword Arguments:
        pretrained {bool} -- Whether or not to load saved pretrained weights (default: {True})
    """

    def __init__(self, pretrained=True):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 28, kernel_size=3)
        self.prelu1 = nn.PReLU(28)
        self.pool1 = nn.MaxPool2d(3, 2, ceil_mode=True)
        self.conv2 = nn.Conv2d(28, 48, kernel_size=3)
        self.prelu2 = nn.PReLU(48)
        self.pool2 = nn.MaxPool2d(3, 2, ceil_mode=True)
        self.conv3 = nn.Conv2d(48, 64, kernel_size=2)
        self.prelu3 = nn.PReLU(64)
        self.dense4 = nn.Linear(576, 128)
        self.prelu4 = nn.PReLU(128)
        self.dense5_1 = nn.Linear(128, 2)
        self.softmax5_1 = nn.Softmax(dim=1)
        self.dense5_2 = nn.Linear(128, 4)

        self.training = False

        if pretrained:
            state_dict_path = os.path.join(os.path.dirname(__file__), '../data/rnet.pt')
            state_dict = torch.load(state_dict_path)
            self.load_state_dict(state_dict)

    def forward(self, x):
        x = self.conv1(x)
        x = self.prelu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.prelu2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.prelu3(x)
        x = x.permute(0, 3, 2, 1).contiguous()
        x = self.dense4(x.view(x.shape[0], -1))
        x = self.prelu4(x)
        a = self.dense5_1(x)
        a = self.softmax5_1(a)
        b = self.dense5_2(x)
        return b, a


class ONet(nn.Module):
    """MTCNN ONet.

    Keyword Arguments:
        pretrained {bool} -- Whether or not to load saved pretrained weights (default: {True})
    """

    def __init__(self, pretrained=True):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
        self.prelu1 = nn.PReLU(32)
        self.pool1 = nn.MaxPool2d(3, 2, ceil_mode=True)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.prelu2 = nn.PReLU(64)
        self.pool2 = nn.MaxPool2d(3, 2, ceil_mode=True)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3)
        self.prelu3 = nn.PReLU(64)
        self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=2)
        self.prelu4 = nn.PReLU(128)
        self.dense5 = nn.Linear(1152, 256)
        self.prelu5 = nn.PReLU(256)
        self.dense6_1 = nn.Linear(256, 2)
        self.softmax6_1 = nn.Softmax(dim=1)
        self.dense6_2 = nn.Linear(256, 4)
        self.dense6_3 = nn.Linear(256, 10)

        self.training = False

        if pretrained:
            state_dict_path = os.path.join(os.path.dirname(__file__), '../data/onet.pt')
            state_dict = torch.load(state_dict_path)
            self.load_state_dict(state_dict)

    def forward(self, x):
        x = self.conv1(x)
        x = self.prelu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.prelu2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.prelu3(x)
        x = self.pool3(x)
        x = self.conv4(x)
        x = self.prelu4(x)
        x = x.permute(0, 3, 2, 1).contiguous()
        x = self.dense5(x.view(x.shape[0], -1))
        x = self.prelu5(x)
        a = self.dense6_1(x)
        a = self.softmax6_1(a)
        b = self.dense6_2(x)
        c = self.dense6_3(x)
        return b, c, a


class MTCNN(nn.Module):
    """MTCNN face detection module.

    This class loads pretrained P-, R-, and O-nets and returns images cropped to include the face
    only, given raw input images of one of the following types:
        - PIL image or list of PIL images
        - numpy.ndarray (uint8) representing either a single image (3D) or a batch of images (4D).
    Cropped faces can optionally be saved to file also.

    Keyword Arguments:
        image_size {int} -- Output image size in pixels. The image will be square. (default: {160})
        margin {int} -- Margin to add to bounding box, in terms of pixels in the final image.
            Note that the application of the margin differs slightly from the davidsandberg/facenet
            repo, which applies the margin to the original image before resizing, making the margin
            dependent on the original image size (this is a bug in davidsandberg/facenet).
            (default: {0})
        min_face_size {int} -- Minimum face size to search for. (default: {20})
        thresholds {list} -- MTCNN face detection thresholds (default: {[0.6, 0.7, 0.7]})
        factor {float} -- Factor used to create a scaling pyramid of face sizes. (default: {0.709})
        post_process {bool} -- Whether or not to post process images tensors before returning.
            (default: {True})
        select_largest {bool} -- If True, if multiple faces are detected, the largest is returned.
            If False, the face with the highest detection probability is returned.
            (default: {True})
        selection_method {string} -- Which heuristic to use for selection. Default None. If
            specified, will override select_largest:
                "probability": highest probability selected
                "largest": largest box selected
                "largest_over_threshold": largest box over a certain probability selected
                "center_weighted_size": box size minus weighted squared offset from image center
            (default: {None})
        keep_all {bool} -- If True, all detected faces are returned, in the order dictated by the
            select_largest parameter. If a save_path is specified, the first face is saved to that
            path and the remaining faces are saved to <save_path>1, <save_path>2 etc.
            (default: {False})
        device {torch.device} -- The device on which to run neural net passes. Image tensors and
            models are copied to this device before running forward passes. (default: {None})
    """

    def __init__(
        self, image_size=160, margin=0, min_face_size=20,
        thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
        select_largest=True, selection_method=None, keep_all=False, device=None
    ):
        super().__init__()

        self.image_size = image_size
        self.margin = margin
        self.min_face_size = min_face_size
        self.thresholds = thresholds
        self.factor = factor
        self.post_process = post_process
        self.select_largest = select_largest
        self.keep_all = keep_all
        self.selection_method = selection_method

        self.pnet = PNet()
        self.rnet = RNet()
        self.onet = ONet()

        self.device = torch.device('cpu')
        if device is not None:
            self.device = device
            self.to(device)

        if not self.selection_method:
            self.selection_method = 'largest' if self.select_largest else 'probability'

    def forward(self, img, save_path=None, return_prob=False):
        """Run MTCNN face detection on a PIL image or numpy array. This method performs both
        detection and extraction of faces, returning tensors representing detected faces rather
        than the bounding boxes. To access bounding boxes, see the MTCNN.detect() method below.

        Arguments:
            img {PIL.Image, np.ndarray, or list} -- A PIL image, np.ndarray, torch.Tensor, or list.

        Keyword Arguments:
            save_path {str} -- An optional save path for the cropped image. Note that when
                self.post_process=True, although the returned tensor is post processed, the saved
                face image is not, so it is a true representation of the face in the input image.
                If `img` is a list of images, `save_path` should be a list of equal length.
                (default: {None})
            return_prob {bool} -- Whether or not to return the detection probability.
                (default: {False})

        Returns:
            Union[torch.Tensor, tuple(torch.tensor, float)] -- If detected, cropped image of a face
                with dimensions 3 x image_size x image_size. Optionally, the probability that a
                face was detected. If self.keep_all is True, n detected faces are returned in an
                n x 3 x image_size x image_size tensor with an optional list of detection
                probabilities. If `img` is a list of images, the item(s) returned have an extra
                dimension (batch) as the first dimension.

        Example:
        >>> from facenet_pytorch import MTCNN
        >>> mtcnn = MTCNN()
        >>> face_tensor, prob = mtcnn(img, save_path='face.png', return_prob=True)
        """

        # Detect faces
        batch_boxes, batch_probs, batch_points = self.detect(img, landmarks=True)
        # Select faces
        if not self.keep_all:
            batch_boxes, batch_probs, batch_points = self.select_boxes(
                batch_boxes, batch_probs, batch_points, img, method=self.selection_method
            )
        # Extract faces
        faces = self.extract(img, batch_boxes, save_path)

        if return_prob:
            return faces, batch_probs
        else:
            return faces

    def detect(self, img, landmarks=False):
        """Detect all faces in PIL image and return bounding boxes and optional facial landmarks.

        This method is used by the forward method and is also useful for face detection tasks
        that require lower-level handling of bounding boxes and facial landmarks (e.g., face
        tracking). The functionality of the forward function can be emulated by using this method
        followed by the extract_face() function.

        Arguments:
            img {PIL.Image, np.ndarray, or list} -- A PIL image, np.ndarray, torch.Tensor, or list.

        Keyword Arguments:
            landmarks {bool} -- Whether to return facial landmarks in addition to bounding boxes.
                (default: {False})

        Returns:
            tuple(numpy.ndarray, list) -- For N detected faces, a tuple containing an
                Nx4 array of bounding boxes and a length N list of detection probabilities.
                Returned boxes will be sorted in descending order by detection probability if
                self.select_largest=False, otherwise the largest face will be returned first.
                If `img` is a list of images, the items returned have an extra dimension
                (batch) as the first dimension. Optionally, a third item, the facial landmarks,
                are returned if `landmarks=True`.

        Example:
        >>> from PIL import Image, ImageDraw
        >>> from facenet_pytorch import MTCNN, extract_face
        >>> mtcnn = MTCNN(keep_all=True)
        >>> boxes, probs, points = mtcnn.detect(img, landmarks=True)
        >>> # Draw boxes and save faces
        >>> img_draw = img.copy()
        >>> draw = ImageDraw.Draw(img_draw)
        >>> for i, (box, point) in enumerate(zip(boxes, points)):
        ...     draw.rectangle(box.tolist(), width=5)
        ...     for p in point:
        ...         draw.rectangle((p - 10).tolist() + (p + 10).tolist(), width=10)
        ...     extract_face(img, box, save_path='detected_face_{}.png'.format(i))
        >>> img_draw.save('annotated_faces.png')
        """

        with torch.no_grad():
            batch_boxes, batch_points = detect_face(
                img, self.min_face_size,
                self.pnet, self.rnet, self.onet,
                self.thresholds, self.factor,
                self.device
            )

        boxes, probs, points = [], [], []
        for box, point in zip(batch_boxes, batch_points):
            box = np.array(box)
            point = np.array(point)
            if len(box) == 0:
                boxes.append(None)
                probs.append([None])
                points.append(None)
            elif self.select_largest:
                box_order = np.argsort((box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]))[::-1]
                box = box[box_order]
                point = point[box_order]
                boxes.append(box[:, :4])
                probs.append(box[:, 4])
                points.append(point)
            else:
                boxes.append(box[:, :4])
                probs.append(box[:, 4])
                points.append(point)
        boxes = np.array(boxes, dtype=object)
        probs = np.array(probs, dtype=object)
        points = np.array(points, dtype=object)

        if (
            not isinstance(img, (list, tuple)) and
            not (isinstance(img, np.ndarray) and len(img.shape) == 4) and
            not (isinstance(img, torch.Tensor) and len(img.shape) == 4)
        ):
            boxes = boxes[0]
            probs = probs[0]
            points = points[0]

        if landmarks:
            return boxes, probs, points

        return boxes, probs

    def select_boxes(
        self, all_boxes, all_probs, all_points, imgs, method='probability', threshold=0.9,
        center_weight=2.0
    ):
        """Selects a single box from multiple for a given image using one of multiple heuristics.

        Arguments:
            all_boxes {np.ndarray} -- Ix0 ndarray where each element is a Nx4 ndarray of
                bounding boxes for N detected faces in I images (output from self.detect).
            all_probs {np.ndarray} -- Ix0 ndarray where each element is a Nx0 ndarray of
                probabilities for N detected faces in I images (output from self.detect).
            all_points {np.ndarray} -- Ix0 ndarray where each element is a Nx5x2 array of
                points for N detected faces. (output from self.detect).
            imgs {PIL.Image, np.ndarray, or list} -- A PIL image, np.ndarray, torch.Tensor, or list.

        Keyword Arguments:
            method {str} -- Which heuristic to use for selection:
                "probability": highest probability selected
                "largest": largest box selected
                "largest_over_threshold": largest box over a certain probability selected
                "center_weighted_size": box size minus weighted squared offset from image center
                (default: {'probability'})
            threshold {float} -- threshold for "largest_over_threshold" method. (default: {0.9})
            center_weight {float} -- weight for squared offset in center weighted size method.
                (default: {2.0})

        Returns:
            tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray) -- nx4 ndarray of bounding boxes
                for n images. Ix0 array of probabilities for each box, array of landmark points.
        """

        # Copying batch detection from extract, but it would be easier to ensure detect creates
        # consistent output.
        batch_mode = True
        if (
            not isinstance(imgs, (list, tuple)) and
            not (isinstance(imgs, np.ndarray) and len(imgs.shape) == 4) and
            not (isinstance(imgs, torch.Tensor) and len(imgs.shape) == 4)
        ):
            imgs = [imgs]
            all_boxes = [all_boxes]
            all_probs = [all_probs]
            all_points = [all_points]
            batch_mode = False

        selected_boxes, selected_probs, selected_points = [], [], []
        for boxes, points, probs, img in zip(all_boxes, all_points, all_probs, imgs):

            if boxes is None:
                selected_boxes.append(None)
                selected_probs.append([None])
                selected_points.append(None)
                continue

            # If at least 1 box found
            boxes = np.array(boxes)
            probs = np.array(probs)
            points = np.array(points)

            if method == 'largest':
                box_order = np.argsort((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]))[::-1]
            elif method == 'probability':
                box_order = np.argsort(probs)[::-1]
            elif method == 'center_weighted_size':
                box_sizes = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
                img_center = (img.width / 2, img.height / 2)
                box_centers = np.array(list(zip((boxes[:, 0] + boxes[:, 2]) / 2, (boxes[:, 1] + boxes[:, 3]) / 2)))
                offsets = box_centers - img_center
                offset_dist_squared = np.sum(np.power(offsets, 2.0), 1)
                box_order = np.argsort(box_sizes - offset_dist_squared * center_weight)[::-1]
            elif method == 'largest_over_threshold':
                box_mask = probs > threshold
                boxes = boxes[box_mask]
                box_order = np.argsort((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]))[::-1]
                if sum(box_mask) == 0:
                    selected_boxes.append(None)
                    selected_probs.append([None])
                    selected_points.append(None)
                    continue

            box = boxes[box_order][[0]]
            prob = probs[box_order][[0]]
            point = points[box_order][[0]]
            selected_boxes.append(box)
            selected_probs.append(prob)
            selected_points.append(point)

        if batch_mode:
            selected_boxes = np.array(selected_boxes)
            selected_probs = np.array(selected_probs)
            selected_points = np.array(selected_points)
        else:
            selected_boxes = selected_boxes[0]
            selected_probs = selected_probs[0][0]
            selected_points = selected_points[0]

        return selected_boxes, selected_probs, selected_points

    def extract(self, img, batch_boxes, save_path):
        # Determine if a batch or single image was passed
        batch_mode = True
        if (
            not isinstance(img, (list, tuple)) and
            not (isinstance(img, np.ndarray) and len(img.shape) == 4) and
            not (isinstance(img, torch.Tensor) and len(img.shape) == 4)
        ):
            img = [img]
            batch_boxes = [batch_boxes]
            batch_mode = False

        # Parse save path(s)
        if save_path is not None:
            if isinstance(save_path, str):
                save_path = [save_path]
        else:
            save_path = [None for _ in range(len(img))]

        # Process all bounding boxes
        faces = []
        for im, box_im, path_im in zip(img, batch_boxes, save_path):
            if box_im is None:
                faces.append(None)
                continue

            if not self.keep_all:
                box_im = box_im[[0]]

            faces_im = []
            for i, box in enumerate(box_im):
                face_path = path_im
                if path_im is not None and i > 0:
                    save_name, ext = os.path.splitext(path_im)
                    face_path = save_name + '_' + str(i + 1) + ext

                face = extract_face(im, box, self.image_size, self.margin, face_path)
                if self.post_process:
                    face = fixed_image_standardization(face)
                faces_im.append(face)

            if self.keep_all:
                faces_im = torch.stack(faces_im)
            else:
                faces_im = faces_im[0]

            faces.append(faces_im)

        if not batch_mode:
            faces = faces[0]

        return faces


def fixed_image_standardization(image_tensor):
    processed_tensor = (image_tensor - 127.5) / 128.0
    return processed_tensor


def prewhiten(x):
    mean = x.mean()
    std = x.std()
    std_adj = std.clamp(min=1.0 / (float(x.numel()) ** 0.5))
    y = (x - mean) / std_adj
    return y
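To connect the two modules, a short end-to-end sketch (again not part of the commit) that detects a face and embeds it; it assumes the modules are importable as models.mtcnn and models.inception_resnet_v1 and that the test image path from this upload exists locally.

from PIL import Image
import torch
from models.mtcnn import MTCNN
from models.inception_resnet_v1 import InceptionResnetV1

mtcnn = MTCNN(image_size=160, margin=0)                   # detector + face cropper
resnet = InceptionResnetV1(pretrained='vggface2').eval()  # embedding network

img = Image.open('data/test_images/angelina_jolie/1.jpg')

# Detect, crop, and post-process the most prominent face (None if no face is found)
face, prob = mtcnn(img, return_prob=True)

if face is not None:
    with torch.no_grad():
        embedding = resnet(face.unsqueeze(0))             # 1 x 512 face embedding
    print(prob, embedding.shape)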
models/utils/detect_face.py
ADDED
@@ -0,0 +1,378 @@
import torch
from torch.nn.functional import interpolate
from torchvision.transforms import functional as F
from torchvision.ops.boxes import batched_nms
from PIL import Image
import numpy as np
import os
import math

# OpenCV is optional, but required if using numpy arrays instead of PIL
try:
    import cv2
except:
    pass


def fixed_batch_process(im_data, model):
    batch_size = 512
    out = []
    for i in range(0, len(im_data), batch_size):
        batch = im_data[i:(i + batch_size)]
        out.append(model(batch))

    return tuple(torch.cat(v, dim=0) for v in zip(*out))


def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device):
    if isinstance(imgs, (np.ndarray, torch.Tensor)):
        if isinstance(imgs, np.ndarray):
            imgs = torch.as_tensor(imgs.copy(), device=device)

        if isinstance(imgs, torch.Tensor):
            imgs = torch.as_tensor(imgs, device=device)

        if len(imgs.shape) == 3:
            imgs = imgs.unsqueeze(0)
    else:
        if not isinstance(imgs, (list, tuple)):
            imgs = [imgs]
        if any(img.size != imgs[0].size for img in imgs):
            raise Exception("MTCNN batch processing only compatible with equal-dimension images.")
        imgs = np.stack([np.uint8(img) for img in imgs])
        imgs = torch.as_tensor(imgs.copy(), device=device)

    model_dtype = next(pnet.parameters()).dtype
    imgs = imgs.permute(0, 3, 1, 2).type(model_dtype)

    batch_size = len(imgs)
    h, w = imgs.shape[2:4]
    m = 12.0 / minsize
    minl = min(h, w)
    minl = minl * m

    # Create scale pyramid
    scale_i = m
    scales = []
    while minl >= 12:
        scales.append(scale_i)
        scale_i = scale_i * factor
        minl = minl * factor

    # First stage
    boxes = []
    image_inds = []

    scale_picks = []

    all_i = 0
    offset = 0
    for scale in scales:
        im_data = imresample(imgs, (int(h * scale + 1), int(w * scale + 1)))
        im_data = (im_data - 127.5) * 0.0078125
        reg, probs = pnet(im_data)

        boxes_scale, image_inds_scale = generateBoundingBox(reg, probs[:, 1], scale, threshold[0])
        boxes.append(boxes_scale)
        image_inds.append(image_inds_scale)

        pick = batched_nms(boxes_scale[:, :4], boxes_scale[:, 4], image_inds_scale, 0.5)
        scale_picks.append(pick + offset)
        offset += boxes_scale.shape[0]

    boxes = torch.cat(boxes, dim=0)
    image_inds = torch.cat(image_inds, dim=0)

    scale_picks = torch.cat(scale_picks, dim=0)

    # NMS within each scale + image
    boxes, image_inds = boxes[scale_picks], image_inds[scale_picks]

    # NMS within each image
    pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7)
    boxes, image_inds = boxes[pick], image_inds[pick]

    regw = boxes[:, 2] - boxes[:, 0]
    regh = boxes[:, 3] - boxes[:, 1]
    qq1 = boxes[:, 0] + boxes[:, 5] * regw
    qq2 = boxes[:, 1] + boxes[:, 6] * regh
    qq3 = boxes[:, 2] + boxes[:, 7] * regw
    qq4 = boxes[:, 3] + boxes[:, 8] * regh
    boxes = torch.stack([qq1, qq2, qq3, qq4, boxes[:, 4]]).permute(1, 0)
    boxes = rerec(boxes)
    y, ey, x, ex = pad(boxes, w, h)

    # Second stage
    if len(boxes) > 0:
        im_data = []
        for k in range(len(y)):
            if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1):
                img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]].unsqueeze(0)
                im_data.append(imresample(img_k, (24, 24)))
        im_data = torch.cat(im_data, dim=0)
        im_data = (im_data - 127.5) * 0.0078125

        # This is equivalent to out = rnet(im_data) to avoid GPU out of memory.
        out = fixed_batch_process(im_data, rnet)

        out0 = out[0].permute(1, 0)
        out1 = out[1].permute(1, 0)
        score = out1[1, :]
        ipass = score > threshold[1]
        boxes = torch.cat((boxes[ipass, :4], score[ipass].unsqueeze(1)), dim=1)
        image_inds = image_inds[ipass]
        mv = out0[:, ipass].permute(1, 0)

        # NMS within each image
        pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7)
        boxes, image_inds, mv = boxes[pick], image_inds[pick], mv[pick]
        boxes = bbreg(boxes, mv)
        boxes = rerec(boxes)

    # Third stage
    points = torch.zeros(0, 5, 2, device=device)
    if len(boxes) > 0:
        y, ey, x, ex = pad(boxes, w, h)
        im_data = []
        for k in range(len(y)):
            if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1):
                img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]].unsqueeze(0)
                im_data.append(imresample(img_k, (48, 48)))
        im_data = torch.cat(im_data, dim=0)
        im_data = (im_data - 127.5) * 0.0078125

        # This is equivalent to out = onet(im_data) to avoid GPU out of memory.
        out = fixed_batch_process(im_data, onet)

        out0 = out[0].permute(1, 0)
        out1 = out[1].permute(1, 0)
        out2 = out[2].permute(1, 0)
        score = out2[1, :]
        points = out1
        ipass = score > threshold[2]
        points = points[:, ipass]
        boxes = torch.cat((boxes[ipass, :4], score[ipass].unsqueeze(1)), dim=1)
        image_inds = image_inds[ipass]
        mv = out0[:, ipass].permute(1, 0)

        w_i = boxes[:, 2] - boxes[:, 0] + 1
        h_i = boxes[:, 3] - boxes[:, 1] + 1
        points_x = w_i.repeat(5, 1) * points[:5, :] + boxes[:, 0].repeat(5, 1) - 1
        points_y = h_i.repeat(5, 1) * points[5:10, :] + boxes[:, 1].repeat(5, 1) - 1
        points = torch.stack((points_x, points_y)).permute(2, 1, 0)
        boxes = bbreg(boxes, mv)

        # NMS within each image using "Min" strategy
        # pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7)
        pick = batched_nms_numpy(boxes[:, :4], boxes[:, 4], image_inds, 0.7, 'Min')
        boxes, image_inds, points = boxes[pick], image_inds[pick], points[pick]

    boxes = boxes.cpu().numpy()
    points = points.cpu().numpy()

    image_inds = image_inds.cpu()

    batch_boxes = []
    batch_points = []
    for b_i in range(batch_size):
        b_i_inds = np.where(image_inds == b_i)
        batch_boxes.append(boxes[b_i_inds].copy())
        batch_points.append(points[b_i_inds].copy())

    batch_boxes, batch_points = np.array(batch_boxes, dtype=object), np.array(batch_points, dtype=object)

    return batch_boxes, batch_points


def bbreg(boundingbox, reg):
    if reg.shape[1] == 1:
        reg = torch.reshape(reg, (reg.shape[2], reg.shape[3]))

    w = boundingbox[:, 2] - boundingbox[:, 0] + 1
    h = boundingbox[:, 3] - boundingbox[:, 1] + 1
    b1 = boundingbox[:, 0] + reg[:, 0] * w
    b2 = boundingbox[:, 1] + reg[:, 1] * h
    b3 = boundingbox[:, 2] + reg[:, 2] * w
    b4 = boundingbox[:, 3] + reg[:, 3] * h
    boundingbox[:, :4] = torch.stack([b1, b2, b3, b4]).permute(1, 0)

    return boundingbox


def generateBoundingBox(reg, probs, scale, thresh):
    stride = 2
    cellsize = 12

    reg = reg.permute(1, 0, 2, 3)

    mask = probs >= thresh
    mask_inds = mask.nonzero()
    image_inds = mask_inds[:, 0]
    score = probs[mask]
    reg = reg[:, mask].permute(1, 0)
    bb = mask_inds[:, 1:].type(reg.dtype).flip(1)
    q1 = ((stride * bb + 1) / scale).floor()
    q2 = ((stride * bb + cellsize - 1 + 1) / scale).floor()
    boundingbox = torch.cat([q1, q2, score.unsqueeze(1), reg], dim=1)
    return boundingbox, image_inds


def nms_numpy(boxes, scores, threshold, method):
    if boxes.size == 0:
        return np.empty((0, 3))

    x1 = boxes[:, 0].copy()
    y1 = boxes[:, 1].copy()
    x2 = boxes[:, 2].copy()
    y2 = boxes[:, 3].copy()
    s = scores
    area = (x2 - x1 + 1) * (y2 - y1 + 1)

    I = np.argsort(s)
    pick = np.zeros_like(s, dtype=np.int16)
    counter = 0
    while I.size > 0:
        i = I[-1]
        pick[counter] = i
        counter += 1
        idx = I[0:-1]

        xx1 = np.maximum(x1[i], x1[idx]).copy()
        yy1 = np.maximum(y1[i], y1[idx]).copy()
        xx2 = np.minimum(x2[i], x2[idx]).copy()
        yy2 = np.minimum(y2[i], y2[idx]).copy()

        w = np.maximum(0.0, xx2 - xx1 + 1).copy()
        h = np.maximum(0.0, yy2 - yy1 + 1).copy()

        inter = w * h
        if method == 'Min':
            o = inter / np.minimum(area[i], area[idx])
        else:
            o = inter / (area[i] + area[idx] - inter)
        I = I[np.where(o <= threshold)]

    pick = pick[:counter].copy()
    return pick


def batched_nms_numpy(boxes, scores, idxs, threshold, method):
    device = boxes.device
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=device)
    # strategy: in order to perform NMS independently per class,
    # we add an offset to all the boxes. The offset is dependent
    # only on the class idx, and is large enough so that boxes
    # from different classes do not overlap
    max_coordinate = boxes.max()
    offsets = idxs.to(boxes) * (max_coordinate + 1)
    boxes_for_nms = boxes + offsets[:, None]
    boxes_for_nms = boxes_for_nms.cpu().numpy()
    scores = scores.cpu().numpy()
    keep = nms_numpy(boxes_for_nms, scores, threshold, method)
    return torch.as_tensor(keep, dtype=torch.long, device=device)


def pad(boxes, w, h):
    boxes = boxes.trunc().int().cpu().numpy()
    x = boxes[:, 0]
    y = boxes[:, 1]
    ex = boxes[:, 2]
    ey = boxes[:, 3]

    x[x < 1] = 1
    y[y < 1] = 1
    ex[ex > w] = w
    ey[ey > h] = h

    return y, ey, x, ex


def rerec(bboxA):
    h = bboxA[:, 3] - bboxA[:, 1]
    w = bboxA[:, 2] - bboxA[:, 0]

    l = torch.max(w, h)
    bboxA[:, 0] = bboxA[:, 0] + w * 0.5 - l * 0.5
    bboxA[:, 1] = bboxA[:, 1] + h * 0.5 - l * 0.5
    bboxA[:, 2:4] = bboxA[:, :2] + l.repeat(2, 1).permute(1, 0)

    return bboxA


def imresample(img, sz):
    im_data = interpolate(img, size=sz, mode="area")
    return im_data


def crop_resize(img, box, image_size):
    if isinstance(img, np.ndarray):
        img = img[box[1]:box[3], box[0]:box[2]]
        out = cv2.resize(
            img,
            (image_size, image_size),
            interpolation=cv2.INTER_AREA
        ).copy()
    elif isinstance(img, torch.Tensor):
        img = img[box[1]:box[3], box[0]:box[2]]
        out = imresample(
            img.permute(2, 0, 1).unsqueeze(0).float(),
            (image_size, image_size)
        ).byte().squeeze(0).permute(1, 2, 0)
    else:
        out = img.crop(box).copy().resize((image_size, image_size), Image.BILINEAR)
    return out


def save_img(img, path):
    if isinstance(img, np.ndarray):
        cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    else:
        img.save(path)


def get_size(img):
    if isinstance(img, (np.ndarray, torch.Tensor)):
        return img.shape[1::-1]
    else:
        return img.size


def extract_face(img, box, image_size=160, margin=0, save_path=None):
    """Extract face + margin from PIL Image given bounding box.

    Arguments:
        img {PIL.Image} -- A PIL Image.
        box {numpy.ndarray} -- Four-element bounding box.
        image_size {int} -- Output image size in pixels. The image will be square.
        margin {int} -- Margin to add to bounding box, in terms of pixels in the final image.
            Note that the application of the margin differs slightly from the davidsandberg/facenet
            repo, which applies the margin to the original image before resizing, making the margin
            dependent on the original image size.
        save_path {str} -- Save path for extracted face image. (default: {None})

    Returns:
        torch.tensor -- tensor representing the extracted face.
    """
    margin = [
        margin * (box[2] - box[0]) / (image_size - margin),
        margin * (box[3] - box[1]) / (image_size - margin),
    ]
    raw_image_size = get_size(img)
    box = [
        int(max(box[0] - margin[0] / 2, 0)),
|
365 |
+
int(max(box[1] - margin[1] / 2, 0)),
|
366 |
+
int(min(box[2] + margin[0] / 2, raw_image_size[0])),
|
367 |
+
int(min(box[3] + margin[1] / 2, raw_image_size[1])),
|
368 |
+
]
|
369 |
+
|
370 |
+
face = crop_resize(img, box, image_size)
|
371 |
+
|
372 |
+
if save_path is not None:
|
373 |
+
os.makedirs(os.path.dirname(save_path) + "/", exist_ok=True)
|
374 |
+
save_img(face, save_path)
|
375 |
+
|
376 |
+
face = F.to_tensor(np.float32(face))
|
377 |
+
|
378 |
+
return face
|
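For orientation, here is a minimal usage sketch of the extract_face helper defined above. The image path, bounding box values, and save path are placeholders for illustration, not values shipped with this repo, and the import path assumes the module is importable as models.utils.detect_face.

# Usage sketch only; the image path and box coordinates below are placeholders.
import numpy as np
from PIL import Image

from models.utils.detect_face import extract_face  # assumption: importable under this path

img = Image.open('path/to/image.jpg')                              # any RGB image
box = np.array([80, 40, 240, 220])                                 # [x1, y1, x2, y2], e.g. from MTCNN
face = extract_face(img, box, image_size=160, margin=20,
                    save_path='path/to/face.png')                  # also saves the crop to disk
print(face.shape)                                                  # torch.Size([3, 160, 160]), float CHW tensor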
models/utils/download.py
ADDED
@@ -0,0 +1,102 @@
import hashlib
import os
import shutil
import sys
import tempfile

from urllib.request import urlopen, Request

try:
    from tqdm.auto import tqdm  # automatically select proper tqdm submodule if available
except ImportError:
    try:
        from tqdm import tqdm
    except ImportError:
        # fake tqdm if it's not installed
        class tqdm(object):  # type: ignore

            def __init__(self, total=None, disable=False,
                         unit=None, unit_scale=None, unit_divisor=None):
                self.total = total
                self.disable = disable
                self.n = 0
                # ignore unit, unit_scale, unit_divisor; they're just for real tqdm

            def update(self, n):
                if self.disable:
                    return

                self.n += n
                if self.total is None:
                    sys.stderr.write("\r{0:.1f} bytes".format(self.n))
                else:
                    sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
                sys.stderr.flush()

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                if self.disable:
                    return

                sys.stderr.write('\n')


def download_url_to_file(url, dst, hash_prefix=None, progress=True):
    r"""Download object at the given URL to a local path.
    Args:
        url (string): URL of the object to download
        dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
        hash_prefix (string, optional): If not None, the SHA256 of the downloaded file should start with `hash_prefix`.
            Default: None
        progress (bool, optional): whether or not to display a progress bar to stderr
            Default: True
    Example:
        >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
    """
    file_size = None
    # We use a different API for python2 since urllib(2) doesn't recognize the CA
    # certificates in older Python
    req = Request(url, headers={"User-Agent": "torch.hub"})
    u = urlopen(req)
    meta = u.info()
    if hasattr(meta, 'getheaders'):
        content_length = meta.getheaders("Content-Length")
    else:
        content_length = meta.get_all("Content-Length")
    if content_length is not None and len(content_length) > 0:
        file_size = int(content_length[0])

    # We deliberately save it in a temp file and move it after
    # download is complete. This prevents a local working checkpoint
    # being overridden by a broken download.
    dst = os.path.expanduser(dst)
    dst_dir = os.path.dirname(dst)
    f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)

    try:
        if hash_prefix is not None:
            sha256 = hashlib.sha256()
        with tqdm(total=file_size, disable=not progress,
                  unit='B', unit_scale=True, unit_divisor=1024) as pbar:
            while True:
                buffer = u.read(8192)
                if len(buffer) == 0:
                    break
                f.write(buffer)
                if hash_prefix is not None:
                    sha256.update(buffer)
                pbar.update(len(buffer))

        f.close()
        if hash_prefix is not None:
            digest = sha256.hexdigest()
            if digest[:len(hash_prefix)] != hash_prefix:
                raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                                   .format(hash_prefix, digest))
        shutil.move(f.name, dst)
    finally:
        f.close()
        if os.path.exists(f.name):
            os.remove(f.name)
models/utils/tensorflow2pytorch.py
ADDED
@@ -0,0 +1,416 @@
import tensorflow as tf
import torch
import json
import os, sys

from dependencies.facenet.src import facenet
from dependencies.facenet.src.models import inception_resnet_v1 as tf_mdl
from dependencies.facenet.src.align import detect_face

from models.inception_resnet_v1 import InceptionResnetV1
from models.mtcnn import PNet, RNet, ONet


def import_tf_params(tf_mdl_dir, sess):
    """Import tensorflow model from save directory.

    Arguments:
        tf_mdl_dir {str} -- Location of protobuf, checkpoint, meta files.
        sess {tensorflow.Session} -- Tensorflow session object.

    Returns:
        (list, list, list) -- Tuple of lists containing the layer names,
            parameter arrays as numpy ndarrays, parameter shapes.
    """
    print('\nLoading tensorflow model\n')
    if callable(tf_mdl_dir):
        tf_mdl_dir(sess)
    else:
        facenet.load_model(tf_mdl_dir)

    print('\nGetting model weights\n')
    tf_layers = tf.trainable_variables()
    tf_params = sess.run(tf_layers)

    tf_shapes = [p.shape for p in tf_params]
    tf_layers = [l.name for l in tf_layers]

    if not callable(tf_mdl_dir):
        path = os.path.join(tf_mdl_dir, 'layer_description.json')
    else:
        path = 'data/layer_description.json'
    with open(path, 'w') as f:
        json.dump({l: s for l, s in zip(tf_layers, tf_shapes)}, f)

    return tf_layers, tf_params, tf_shapes


def get_layer_indices(layer_lookup, tf_layers):
    """Given a lookup of model layer attribute names and tensorflow variable names,
    find matching parameters.

    Arguments:
        layer_lookup {dict} -- Dictionary mapping pytorch attribute names to (partial)
            tensorflow variable names. Expects dict of the form {'attr': ['tf_name', ...]}
            where the '...'s are ignored.
        tf_layers {list} -- List of tensorflow variable names.

    Returns:
        dict -- The input dictionary with the list of matching inds appended to each item.
    """
    layer_inds = {}
    for name, value in layer_lookup.items():
        layer_inds[name] = value + [[i for i, n in enumerate(tf_layers) if value[0] in n]]
    return layer_inds


def load_tf_batchNorm(weights, layer):
    """Load tensorflow weights into nn.BatchNorm object.

    Arguments:
        weights {list} -- Tensorflow parameters.
        layer {torch.nn.Module} -- nn.BatchNorm.
    """
    layer.bias.data = torch.tensor(weights[0]).view(layer.bias.data.shape)
    layer.weight.data = torch.ones_like(layer.weight.data)
    layer.running_mean = torch.tensor(weights[1]).view(layer.running_mean.shape)
    layer.running_var = torch.tensor(weights[2]).view(layer.running_var.shape)


def load_tf_conv2d(weights, layer, transpose=False):
    """Load tensorflow weights into nn.Conv2d object.

    Arguments:
        weights {list} -- Tensorflow parameters.
        layer {torch.nn.Module} -- nn.Conv2d.
    """
    if isinstance(weights, list):
        if len(weights) == 2:
            layer.bias.data = (
                torch.tensor(weights[1])
                .view(layer.bias.data.shape)
            )
        weights = weights[0]

    if transpose:
        dim_order = (3, 2, 1, 0)
    else:
        dim_order = (3, 2, 0, 1)

    layer.weight.data = (
        torch.tensor(weights)
        .permute(dim_order)
        .view(layer.weight.data.shape)
    )


def load_tf_conv2d_trans(weights, layer):
    return load_tf_conv2d(weights, layer, transpose=True)


def load_tf_basicConv2d(weights, layer):
    """Load tensorflow weights into grouped Conv2d+BatchNorm object.

    Arguments:
        weights {list} -- Tensorflow parameters.
        layer {torch.nn.Module} -- Object containing Conv2d+BatchNorm.
    """
    load_tf_conv2d(weights[0], layer.conv)
    load_tf_batchNorm(weights[1:], layer.bn)


def load_tf_linear(weights, layer):
    """Load tensorflow weights into nn.Linear object.

    Arguments:
        weights {list} -- Tensorflow parameters.
        layer {torch.nn.Module} -- nn.Linear.
    """
    if isinstance(weights, list):
        if len(weights) == 2:
            layer.bias.data = (
                torch.tensor(weights[1])
                .view(layer.bias.data.shape)
            )
        weights = weights[0]
    layer.weight.data = (
        torch.tensor(weights)
        .transpose(-1, 0)
        .view(layer.weight.data.shape)
    )


# High-level parameter-loading functions:

def load_tf_block35(weights, layer):
    load_tf_basicConv2d(weights[:4], layer.branch0)
    load_tf_basicConv2d(weights[4:8], layer.branch1[0])
    load_tf_basicConv2d(weights[8:12], layer.branch1[1])
    load_tf_basicConv2d(weights[12:16], layer.branch2[0])
    load_tf_basicConv2d(weights[16:20], layer.branch2[1])
    load_tf_basicConv2d(weights[20:24], layer.branch2[2])
    load_tf_conv2d(weights[24:26], layer.conv2d)


def load_tf_block17_8(weights, layer):
    load_tf_basicConv2d(weights[:4], layer.branch0)
    load_tf_basicConv2d(weights[4:8], layer.branch1[0])
    load_tf_basicConv2d(weights[8:12], layer.branch1[1])
    load_tf_basicConv2d(weights[12:16], layer.branch1[2])
    load_tf_conv2d(weights[16:18], layer.conv2d)


def load_tf_mixed6a(weights, layer):
    if len(weights) != 16:
        raise ValueError(f'Number of weight arrays ({len(weights)}) not equal to 16')
    load_tf_basicConv2d(weights[:4], layer.branch0)
    load_tf_basicConv2d(weights[4:8], layer.branch1[0])
    load_tf_basicConv2d(weights[8:12], layer.branch1[1])
    load_tf_basicConv2d(weights[12:16], layer.branch1[2])


def load_tf_mixed7a(weights, layer):
    if len(weights) != 28:
        raise ValueError(f'Number of weight arrays ({len(weights)}) not equal to 28')
    load_tf_basicConv2d(weights[:4], layer.branch0[0])
    load_tf_basicConv2d(weights[4:8], layer.branch0[1])
    load_tf_basicConv2d(weights[8:12], layer.branch1[0])
    load_tf_basicConv2d(weights[12:16], layer.branch1[1])
    load_tf_basicConv2d(weights[16:20], layer.branch2[0])
    load_tf_basicConv2d(weights[20:24], layer.branch2[1])
    load_tf_basicConv2d(weights[24:28], layer.branch2[2])


def load_tf_repeats(weights, layer, rptlen, subfun):
    if len(weights) % rptlen != 0:
        raise ValueError(f'Number of weight arrays ({len(weights)}) not divisible by {rptlen}')
    weights_split = [weights[i:i+rptlen] for i in range(0, len(weights), rptlen)]
    for i, w in enumerate(weights_split):
        subfun(w, getattr(layer, str(i)))


def load_tf_repeat_1(weights, layer):
    load_tf_repeats(weights, layer, 26, load_tf_block35)


def load_tf_repeat_2(weights, layer):
    load_tf_repeats(weights, layer, 18, load_tf_block17_8)


def load_tf_repeat_3(weights, layer):
    load_tf_repeats(weights, layer, 18, load_tf_block17_8)


def test_loaded_params(mdl, tf_params, tf_layers):
    """Check each parameter in a pytorch model for an equivalent parameter
    in a list of tensorflow variables.

    Arguments:
        mdl {torch.nn.Module} -- Pytorch model.
        tf_params {list} -- List of ndarrays representing tensorflow variables.
        tf_layers {list} -- Corresponding list of tensorflow variable names.
    """
    tf_means = torch.stack([torch.tensor(p).mean() for p in tf_params])
    for name, param in mdl.named_parameters():
        pt_mean = param.data.mean()
        matching_inds = ((tf_means - pt_mean).abs() < 1e-8).nonzero()
        print(f'{name} equivalent to {[tf_layers[i] for i in matching_inds]}')


def compare_model_outputs(pt_mdl, sess, test_data):
    """Given some testing data, compare the output of pytorch and tensorflow models.

    Arguments:
        pt_mdl {torch.nn.Module} -- Pytorch model.
        sess {tensorflow.Session} -- Tensorflow session object.
        test_data {torch.Tensor} -- Pytorch tensor.
    """
    print('\nPassing test data through TF model\n')
    if isinstance(sess, tf.Session):
        images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0")
        phase_train_placeholder = tf.get_default_graph().get_tensor_by_name("phase_train:0")
        embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")
        feed_dict = {images_placeholder: test_data.numpy(), phase_train_placeholder: False}
        tf_output = torch.tensor(sess.run(embeddings, feed_dict=feed_dict))
    else:
        tf_output = sess(test_data)

    print(tf_output)

    print('\nPassing test data through PT model\n')
    pt_output = pt_mdl(test_data.permute(0, 3, 1, 2))
    print(pt_output)

    distance = (tf_output - pt_output).norm()
    print(f'\nDistance {distance}\n')


def compare_mtcnn(pt_mdl, tf_fun, sess, ind, test_data):
    tf_mdls = tf_fun(sess)
    tf_mdl = tf_mdls[ind]

    print('\nPassing test data through TF model\n')
    tf_output = tf_mdl(test_data.numpy())
    tf_output = [torch.tensor(out) for out in tf_output]
    print('\n'.join([str(o.view(-1)[:10]) for o in tf_output]))

    print('\nPassing test data through PT model\n')
    with torch.no_grad():
        pt_output = pt_mdl(test_data.permute(0, 3, 2, 1))
    pt_output = [torch.tensor(out) for out in pt_output]
    for i in range(len(pt_output)):
        if len(pt_output[i].shape) == 4:
            pt_output[i] = pt_output[i].permute(0, 3, 2, 1).contiguous()
    print('\n'.join([str(o.view(-1)[:10]) for o in pt_output]))

    distance = [(tf_o - pt_o).norm() for tf_o, pt_o in zip(tf_output, pt_output)]
    print(f'\nDistance {distance}\n')


def load_tf_model_weights(mdl, layer_lookup, tf_mdl_dir, is_resnet=True, arg_num=None):
    """Load tensorflow parameters into a pytorch model.

    Arguments:
        mdl {torch.nn.Module} -- Pytorch model.
        layer_lookup {dict} -- Dictionary mapping pytorch attribute names to (partial)
            tensorflow variable names, and a function suitable for loading weights.
            Expects dict of the form {'attr': ['tf_name', function]}.
        tf_mdl_dir {str} -- Location of protobuf, checkpoint, meta files.
    """
    tf.reset_default_graph()
    with tf.Session() as sess:
        tf_layers, tf_params, tf_shapes = import_tf_params(tf_mdl_dir, sess)
        layer_info = get_layer_indices(layer_lookup, tf_layers)

        for layer_name, info in layer_info.items():
            print(f'Loading {info[0]}/* into {layer_name}')
            weights = [tf_params[i] for i in info[2]]
            layer = getattr(mdl, layer_name)
            info[1](weights, layer)

        test_loaded_params(mdl, tf_params, tf_layers)

        if is_resnet:
            compare_model_outputs(mdl, sess, torch.randn(5, 160, 160, 3).detach())


def tensorflow2pytorch():
    lookup_inception_resnet_v1 = {
        'conv2d_1a': ['InceptionResnetV1/Conv2d_1a_3x3', load_tf_basicConv2d],
        'conv2d_2a': ['InceptionResnetV1/Conv2d_2a_3x3', load_tf_basicConv2d],
        'conv2d_2b': ['InceptionResnetV1/Conv2d_2b_3x3', load_tf_basicConv2d],
        'conv2d_3b': ['InceptionResnetV1/Conv2d_3b_1x1', load_tf_basicConv2d],
        'conv2d_4a': ['InceptionResnetV1/Conv2d_4a_3x3', load_tf_basicConv2d],
        'conv2d_4b': ['InceptionResnetV1/Conv2d_4b_3x3', load_tf_basicConv2d],
        'repeat_1': ['InceptionResnetV1/Repeat/block35', load_tf_repeat_1],
        'mixed_6a': ['InceptionResnetV1/Mixed_6a', load_tf_mixed6a],
        'repeat_2': ['InceptionResnetV1/Repeat_1/block17', load_tf_repeat_2],
        'mixed_7a': ['InceptionResnetV1/Mixed_7a', load_tf_mixed7a],
        'repeat_3': ['InceptionResnetV1/Repeat_2/block8', load_tf_repeat_3],
        'block8': ['InceptionResnetV1/Block8', load_tf_block17_8],
        'last_linear': ['InceptionResnetV1/Bottleneck/weights', load_tf_linear],
        'last_bn': ['InceptionResnetV1/Bottleneck/BatchNorm', load_tf_batchNorm],
        'logits': ['Logits', load_tf_linear],
    }

    print('\nLoad VGGFace2-trained weights and save\n')
    mdl = InceptionResnetV1(num_classes=8631).eval()
    tf_mdl_dir = 'data/20180402-114759'
    data_name = 'vggface2'
    load_tf_model_weights(mdl, lookup_inception_resnet_v1, tf_mdl_dir)
    state_dict = mdl.state_dict()
    torch.save(state_dict, f'{tf_mdl_dir}-{data_name}.pt')
    torch.save(
        {
            'logits.weight': state_dict['logits.weight'],
            'logits.bias': state_dict['logits.bias'],
        },
        f'{tf_mdl_dir}-{data_name}-logits.pt'
    )
    state_dict.pop('logits.weight')
    state_dict.pop('logits.bias')
    torch.save(state_dict, f'{tf_mdl_dir}-{data_name}-features.pt')

    print('\nLoad CASIA-Webface-trained weights and save\n')
    mdl = InceptionResnetV1(num_classes=10575).eval()
    tf_mdl_dir = 'data/20180408-102900'
    data_name = 'casia-webface'
    load_tf_model_weights(mdl, lookup_inception_resnet_v1, tf_mdl_dir)
    state_dict = mdl.state_dict()
    torch.save(state_dict, f'{tf_mdl_dir}-{data_name}.pt')
    torch.save(
        {
            'logits.weight': state_dict['logits.weight'],
            'logits.bias': state_dict['logits.bias'],
        },
        f'{tf_mdl_dir}-{data_name}-logits.pt'
    )
    state_dict.pop('logits.weight')
    state_dict.pop('logits.bias')
    torch.save(state_dict, f'{tf_mdl_dir}-{data_name}-features.pt')

    lookup_pnet = {
        'conv1': ['pnet/conv1', load_tf_conv2d_trans],
        'prelu1': ['pnet/PReLU1', load_tf_linear],
        'conv2': ['pnet/conv2', load_tf_conv2d_trans],
        'prelu2': ['pnet/PReLU2', load_tf_linear],
        'conv3': ['pnet/conv3', load_tf_conv2d_trans],
        'prelu3': ['pnet/PReLU3', load_tf_linear],
        'conv4_1': ['pnet/conv4-1', load_tf_conv2d_trans],
        'conv4_2': ['pnet/conv4-2', load_tf_conv2d_trans],
    }
    lookup_rnet = {
        'conv1': ['rnet/conv1', load_tf_conv2d_trans],
        'prelu1': ['rnet/prelu1', load_tf_linear],
        'conv2': ['rnet/conv2', load_tf_conv2d_trans],
        'prelu2': ['rnet/prelu2', load_tf_linear],
        'conv3': ['rnet/conv3', load_tf_conv2d_trans],
        'prelu3': ['rnet/prelu3', load_tf_linear],
        'dense4': ['rnet/conv4', load_tf_linear],
        'prelu4': ['rnet/prelu4', load_tf_linear],
        'dense5_1': ['rnet/conv5-1', load_tf_linear],
        'dense5_2': ['rnet/conv5-2', load_tf_linear],
    }
    lookup_onet = {
        'conv1': ['onet/conv1', load_tf_conv2d_trans],
        'prelu1': ['onet/prelu1', load_tf_linear],
        'conv2': ['onet/conv2', load_tf_conv2d_trans],
        'prelu2': ['onet/prelu2', load_tf_linear],
        'conv3': ['onet/conv3', load_tf_conv2d_trans],
        'prelu3': ['onet/prelu3', load_tf_linear],
        'conv4': ['onet/conv4', load_tf_conv2d_trans],
        'prelu4': ['onet/prelu4', load_tf_linear],
        'dense5': ['onet/conv5', load_tf_linear],
        'prelu5': ['onet/prelu5', load_tf_linear],
        'dense6_1': ['onet/conv6-1', load_tf_linear],
        'dense6_2': ['onet/conv6-2', load_tf_linear],
        'dense6_3': ['onet/conv6-3', load_tf_linear],
    }

    print('\nLoad PNet weights and save\n')
    tf_mdl_dir = lambda sess: detect_face.create_mtcnn(sess, None)
    mdl = PNet()
    data_name = 'pnet'
    load_tf_model_weights(mdl, lookup_pnet, tf_mdl_dir, is_resnet=False, arg_num=0)
    torch.save(mdl.state_dict(), f'data/{data_name}.pt')
    tf.reset_default_graph()
    with tf.Session() as sess:
        compare_mtcnn(mdl, tf_mdl_dir, sess, 0, torch.randn(1, 256, 256, 3).detach())

    print('\nLoad RNet weights and save\n')
    mdl = RNet()
    data_name = 'rnet'
    load_tf_model_weights(mdl, lookup_rnet, tf_mdl_dir, is_resnet=False, arg_num=1)
    torch.save(mdl.state_dict(), f'data/{data_name}.pt')
    tf.reset_default_graph()
    with tf.Session() as sess:
        compare_mtcnn(mdl, tf_mdl_dir, sess, 1, torch.randn(1, 24, 24, 3).detach())

    print('\nLoad ONet weights and save\n')
    mdl = ONet()
    data_name = 'onet'
    load_tf_model_weights(mdl, lookup_onet, tf_mdl_dir, is_resnet=False, arg_num=2)
    torch.save(mdl.state_dict(), f'data/{data_name}.pt')
    tf.reset_default_graph()
    with tf.Session() as sess:
        compare_mtcnn(mdl, tf_mdl_dir, sess, 2, torch.randn(1, 48, 48, 3).detach())
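A hedged sketch of running the conversion and reloading its output. It assumes a TensorFlow 1.x environment with the davidsandberg/facenet sources available under dependencies/facenet and the pretrained checkpoint directories referenced above; once the .pt files exist, they load without TensorFlow.

# Conversion sketch; requires TF 1.x plus the dependencies/facenet checkout and the
# checkpoint directories named in tensorflow2pytorch() (e.g. data/20180402-114759).
import torch

from models.inception_resnet_v1 import InceptionResnetV1
from models.utils.tensorflow2pytorch import tensorflow2pytorch  # assumption: importable under this path

tensorflow2pytorch()  # writes *.pt state dicts next to the TF checkpoints and into data/

# Afterwards the converted VGGFace2 weights load straight into the pytorch model:
mdl = InceptionResnetV1(num_classes=8631).eval()
mdl.load_state_dict(torch.load('data/20180402-114759-vggface2.pt'))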
models/utils/training.py
ADDED
@@ -0,0 +1,144 @@
import torch
import numpy as np
import time


class Logger(object):

    def __init__(self, mode, length, calculate_mean=False):
        self.mode = mode
        self.length = length
        self.calculate_mean = calculate_mean
        if self.calculate_mean:
            self.fn = lambda x, i: x / (i + 1)
        else:
            self.fn = lambda x, i: x

    def __call__(self, loss, metrics, i):
        track_str = '\r{} | {:5d}/{:<5d}| '.format(self.mode, i + 1, self.length)
        loss_str = 'loss: {:9.4f} | '.format(self.fn(loss, i))
        metric_str = ' | '.join('{}: {:9.4f}'.format(k, self.fn(v, i)) for k, v in metrics.items())
        print(track_str + loss_str + metric_str + '   ', end='')
        if i + 1 == self.length:
            print('')


class BatchTimer(object):
    """Batch timing class.
    Use this class for tracking training and testing time/rate per batch or per sample.

    Keyword Arguments:
        rate {bool} -- Whether to report a rate (batches or samples per second) or a time (seconds
            per batch or sample). (default: {True})
        per_sample {bool} -- Whether to report times or rates per sample or per batch.
            (default: {True})
    """

    def __init__(self, rate=True, per_sample=True):
        self.start = time.time()
        self.end = None
        self.rate = rate
        self.per_sample = per_sample

    def __call__(self, y_pred, y):
        self.end = time.time()
        elapsed = self.end - self.start
        self.start = self.end
        self.end = None

        if self.per_sample:
            elapsed /= len(y_pred)
        if self.rate:
            elapsed = 1 / elapsed

        return torch.tensor(elapsed)


def accuracy(logits, y):
    _, preds = torch.max(logits, 1)
    return (preds == y).float().mean()


def pass_epoch(
    model, loss_fn, loader, optimizer=None, scheduler=None,
    batch_metrics={'time': BatchTimer()}, show_running=True,
    device='cpu', writer=None
):
    """Train or evaluate over a data epoch.

    Arguments:
        model {torch.nn.Module} -- Pytorch model.
        loss_fn {callable} -- A function to compute (scalar) loss.
        loader {torch.utils.data.DataLoader} -- A pytorch data loader.

    Keyword Arguments:
        optimizer {torch.optim.Optimizer} -- A pytorch optimizer.
        scheduler {torch.optim.lr_scheduler._LRScheduler} -- LR scheduler (default: {None})
        batch_metrics {dict} -- Dictionary of metric functions to call on each batch. The default
            is a simple timer. A progressive average of these metrics, along with the average
            loss, is printed every batch. (default: {{'time': BatchTimer()}})
        show_running {bool} -- Whether to print rolling averages (True) or losses and metrics
            for the current batch only (False). (default: {True})
        device {str or torch.device} -- Device for pytorch to use. (default: {'cpu'})
        writer {torch.utils.tensorboard.SummaryWriter} -- Tensorboard SummaryWriter. (default: {None})

    Returns:
        tuple(torch.Tensor, dict) -- A tuple of the average loss and a dictionary of average
            metric values across the epoch.
    """

    mode = 'Train' if model.training else 'Valid'
    logger = Logger(mode, length=len(loader), calculate_mean=show_running)
    loss = 0
    metrics = {}

    for i_batch, (x, y) in enumerate(loader):
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)
        loss_batch = loss_fn(y_pred, y)

        if model.training:
            loss_batch.backward()
            optimizer.step()
            optimizer.zero_grad()

        metrics_batch = {}
        for metric_name, metric_fn in batch_metrics.items():
            metrics_batch[metric_name] = metric_fn(y_pred, y).detach().cpu()
            metrics[metric_name] = metrics.get(metric_name, 0) + metrics_batch[metric_name]

        if writer is not None and model.training:
            if writer.iteration % writer.interval == 0:
                writer.add_scalars('loss', {mode: loss_batch.detach().cpu()}, writer.iteration)
                for metric_name, metric_batch in metrics_batch.items():
                    writer.add_scalars(metric_name, {mode: metric_batch}, writer.iteration)
            writer.iteration += 1

        loss_batch = loss_batch.detach().cpu()
        loss += loss_batch
        if show_running:
            logger(loss, metrics, i_batch)
        else:
            logger(loss_batch, metrics_batch, i_batch)

    if model.training and scheduler is not None:
        scheduler.step()

    loss = loss / (i_batch + 1)
    metrics = {k: v / (i_batch + 1) for k, v in metrics.items()}

    if writer is not None and not model.training:
        writer.add_scalars('loss', {mode: loss.detach()}, writer.iteration)
        for metric_name, metric in metrics.items():
            writer.add_scalars(metric_name, {mode: metric})

    return loss, metrics


def collate_pil(x):
    out_x, out_y = [], []
    for xx, yy in x:
        out_x.append(xx)
        out_y.append(yy)
    return out_x, out_y
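To show how these utilities fit together, a minimal training-loop sketch around pass_epoch follows. The dataset path, batch size, learning rate, and the classify=True constructor flag are illustrative assumptions, not values prescribed by this file.

# Training-loop sketch; dataset path and hyperparameters are placeholders.
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from models.inception_resnet_v1 import InceptionResnetV1           # assumption: accepts classify/num_classes
from models.utils.training import pass_epoch, accuracy, BatchTimer  # assumption: importable under this path

device = 'cuda' if torch.cuda.is_available() else 'cpu'
trans = transforms.Compose([transforms.Resize((160, 160)), transforms.ToTensor()])
dataset = datasets.ImageFolder('path/to/aligned_faces', transform=trans)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

model = InceptionResnetV1(classify=True, num_classes=len(dataset.classes)).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
metrics = {'acc': accuracy, 'fps': BatchTimer()}

for epoch in range(3):
    model.train()
    pass_epoch(model, loss_fn, loader, optimizer, batch_metrics=metrics, device=device)
    model.eval()
    with torch.no_grad():
        pass_epoch(model, loss_fn, loader, batch_metrics=metrics, device=device)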