HienK64BKHN committed on
Commit e066bcd (verified) · 1 Parent(s): f4788e4

Upload 44 files

Files changed (45)
  1. .gitattributes +3 -0
  2. data/.gitignore +3 -0
  3. data/facenet-pytorch-banner.png +0 -0
  4. data/multiface.jpg +0 -0
  5. data/multiface_detected.png +3 -0
  6. data/onet.pt +3 -0
  7. data/pnet.pt +3 -0
  8. data/rnet.pt +3 -0
  9. data/test_images/angelina_jolie/1.jpg +0 -0
  10. data/test_images/angelina_jolie/angelina_jolie.pt +3 -0
  11. data/test_images/bo_vinh/bo_vinh.pt +3 -0
  12. data/test_images/brad_pitt/brad_pitt.pt +3 -0
  13. data/test_images/bradley_cooper/1.jpg +3 -0
  14. data/test_images/bradley_cooper/bradley_cooper.pt +3 -0
  15. data/test_images/chau_anh/chau_anh.pt +3 -0
  16. data/test_images/daniel_radcliffe/daniel_radcliffe.pt +3 -0
  17. data/test_images/hermione_granger/hermione_granger.pt +3 -0
  18. data/test_images/hien/hien.pt +3 -0
  19. data/test_images/kate_siegel/1.jpg +0 -0
  20. data/test_images/kate_siegel/kate_siegel.pt +3 -0
  21. data/test_images/khanh/khanh.pt +3 -0
  22. data/test_images/me_hoa/me_hoa.pt +3 -0
  23. data/test_images/ny_khanh/ny_khanh.pt +3 -0
  24. data/test_images/paul_rudd/1.jpg +0 -0
  25. data/test_images/paul_rudd/paul_rudd.pt +3 -0
  26. data/test_images/ron_weasley/ron_weasley.pt +3 -0
  27. data/test_images/shea_whigham/1.jpg +3 -0
  28. data/test_images/shea_whigham/shea_whigham.pt +3 -0
  29. data/test_images/tu_linh/tu_linh.pt +3 -0
  30. data/test_images_2/angelina_jolie_brad_pitt/1.jpg +0 -0
  31. data/test_images_2/bong_chanh/1.jpg +0 -0
  32. data/test_images_2/bong_chanh/2.jpg +0 -0
  33. data/test_images_2/khanh_va_ny/1.jpg +0 -0
  34. data/test_images_2/the_golden_trio/1.jpg +0 -0
  35. data/test_images_aligned/angelina_jolie/1.png +0 -0
  36. data/test_images_aligned/bradley_cooper/1.png +0 -0
  37. data/test_images_aligned/kate_siegel/1.png +0 -0
  38. data/test_images_aligned/paul_rudd/1.png +0 -0
  39. data/test_images_aligned/shea_whigham/1.png +0 -0
  40. models/inception_resnet_v1.py +340 -0
  41. models/mtcnn.py +519 -0
  42. models/utils/detect_face.py +378 -0
  43. models/utils/download.py +102 -0
  44. models/utils/tensorflow2pytorch.py +416 -0
  45. models/utils/training.py +144 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/multiface_detected.png filter=lfs diff=lfs merge=lfs -text
+ data/test_images/bradley_cooper/1.jpg filter=lfs diff=lfs merge=lfs -text
+ data/test_images/shea_whigham/1.jpg filter=lfs diff=lfs merge=lfs -text
data/.gitignore ADDED
@@ -0,0 +1,3 @@
+ 2018*
+ *.json
+ profile.txt
data/facenet-pytorch-banner.png ADDED
data/multiface.jpg ADDED
data/multiface_detected.png ADDED

Git LFS Details

  • SHA256: 0cef0063ee90e397abf4875070d876bb9f7b0871b7b8c761bfc2fa5fa8bc06d6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.29 MB
data/onet.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:165bfbe42940416ccfb977545cf0e976d5bf321f67083ae2aaaa5c764280118d
+ size 1559269
data/pnet.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a2a71925e0b9996a42f63e47efc1ca19043e69558b5c523b978d611dfae49c8f
+ size 28570
data/rnet.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bbb937de72efc9ef83b186c49f5f558467a1d7e3453a8ece0d71a886633f6a86
+ size 403147
data/test_images/angelina_jolie/1.jpg ADDED
data/test_images/angelina_jolie/angelina_jolie.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:729a0aee6c1c8f1b7b7db49011b2a238d0d54fda993d49271f9f39d876e94a66
+ size 2816
data/test_images/bo_vinh/bo_vinh.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4d274aaf7b361cf62066c08e4153601bdf9f8cdc2458e0112bbfc1d6c7191235
+ size 2795
data/test_images/brad_pitt/brad_pitt.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f70ad8cf619238a4856781ac1ef9a2e7164b486adf29bcb618ac394d5725d289
+ size 2801
data/test_images/bradley_cooper/1.jpg ADDED

Git LFS Details

  • SHA256: be3742bdaefbdfa73c29bde065ebb2091fd45df16f6b1d3bcd6c241c8d571f5e
  • Pointer size: 132 Bytes
  • Size of remote file: 4.57 MB
data/test_images/bradley_cooper/bradley_cooper.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a22e6311cdd61223c03c955fa48fdb348dd2f607266e1be33cdf245deb0bc2b6
+ size 2816
data/test_images/chau_anh/chau_anh.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6b393e3ee47643ba619cf48bf58d30aaa816f8943112c8ea24e8e682d97d3d22
+ size 2798
data/test_images/daniel_radcliffe/daniel_radcliffe.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c5bc1e56b1b34386698721020a685b1deea3bb3870b7ea5c8043599ca607881d
+ size 2822
data/test_images/hermione_granger/hermione_granger.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9c94f592eb34e84957844f545d841c2559e2ed28a4956c1307d570d41c1bc4ec
+ size 2822
data/test_images/hien/hien.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7d6a1965b22ac8ab4441fe79dc39345612c398afb78ca58072e66e35df47cc67
+ size 2722
data/test_images/kate_siegel/1.jpg ADDED
data/test_images/kate_siegel/kate_siegel.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:43a64061efec1325ca8d8ab5822a0b12b974f64a610b21c699b0465b5ede7dd1
+ size 2807
data/test_images/khanh/khanh.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f9f31fc909cf9bd67055ca9ec2a375e97c314b944e9568e2e69b1cfcca081242
+ size 2725
data/test_images/me_hoa/me_hoa.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:64861c898832b1c5d30447c469b228d06c84439f52796fdabae90c15d1e61d03
+ size 2728
data/test_images/ny_khanh/ny_khanh.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dd277690d8a7a10057f3195574e764a5f17472ff1918b4ad8a2b871154778f47
+ size 2798
data/test_images/paul_rudd/1.jpg ADDED
data/test_images/paul_rudd/paul_rudd.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:92b5ea7bd25a17d2bdceb79a785ecab4944951555469ce27cfdf1709ba6cbdeb
+ size 2801
data/test_images/ron_weasley/ron_weasley.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:79a4ebf96e0ed3676d0b7ec7c8d017348736097286cef03dda2fe316a25b33c1
+ size 2807
data/test_images/shea_whigham/1.jpg ADDED

Git LFS Details

  • SHA256: 979b7fa0e662d68b7b22f81a173d03021dddf737d6480b6e500d15a9a01f7aba
  • Pointer size: 132 Bytes
  • Size of remote file: 2.75 MB
data/test_images/shea_whigham/shea_whigham.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ab6085300a2227199b57c7eb259c973d785e31ee5e36dbddc936e07019a5eaf9
+ size 2810
data/test_images/tu_linh/tu_linh.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c9a9afd2eaabb2e7a4808ee168ad7232c51a0bc71a1fc6d50b1ddf6c6a301fe5
+ size 2795
data/test_images_2/angelina_jolie_brad_pitt/1.jpg ADDED
data/test_images_2/bong_chanh/1.jpg ADDED
data/test_images_2/bong_chanh/2.jpg ADDED
data/test_images_2/khanh_va_ny/1.jpg ADDED
data/test_images_2/the_golden_trio/1.jpg ADDED
data/test_images_aligned/angelina_jolie/1.png ADDED
data/test_images_aligned/bradley_cooper/1.png ADDED
data/test_images_aligned/kate_siegel/1.png ADDED
data/test_images_aligned/paul_rudd/1.png ADDED
data/test_images_aligned/shea_whigham/1.png ADDED
models/inception_resnet_v1.py ADDED
@@ -0,0 +1,340 @@
1
+ import os
2
+ import requests
3
+ from requests.adapters import HTTPAdapter
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from .utils.download import download_url_to_file
10
+
11
+
12
+ class BasicConv2d(nn.Module):
13
+
14
+ def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
15
+ super().__init__()
16
+ self.conv = nn.Conv2d(
17
+ in_planes, out_planes,
18
+ kernel_size=kernel_size, stride=stride,
19
+ padding=padding, bias=False
20
+ ) # verify bias false
21
+ self.bn = nn.BatchNorm2d(
22
+ out_planes,
23
+ eps=0.001, # value found in tensorflow
24
+ momentum=0.1, # default pytorch value
25
+ affine=True
26
+ )
27
+ self.relu = nn.ReLU(inplace=False)
28
+
29
+ def forward(self, x):
30
+ x = self.conv(x)
31
+ x = self.bn(x)
32
+ x = self.relu(x)
33
+ return x
34
+
35
+
36
+ class Block35(nn.Module):
37
+
38
+ def __init__(self, scale=1.0):
39
+ super().__init__()
40
+
41
+ self.scale = scale
42
+
43
+ self.branch0 = BasicConv2d(256, 32, kernel_size=1, stride=1)
44
+
45
+ self.branch1 = nn.Sequential(
46
+ BasicConv2d(256, 32, kernel_size=1, stride=1),
47
+ BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
48
+ )
49
+
50
+ self.branch2 = nn.Sequential(
51
+ BasicConv2d(256, 32, kernel_size=1, stride=1),
52
+ BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1),
53
+ BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
54
+ )
55
+
56
+ self.conv2d = nn.Conv2d(96, 256, kernel_size=1, stride=1)
57
+ self.relu = nn.ReLU(inplace=False)
58
+
59
+ def forward(self, x):
60
+ x0 = self.branch0(x)
61
+ x1 = self.branch1(x)
62
+ x2 = self.branch2(x)
63
+ out = torch.cat((x0, x1, x2), 1)
64
+ out = self.conv2d(out)
65
+ out = out * self.scale + x
66
+ out = self.relu(out)
67
+ return out
68
+
69
+
70
+ class Block17(nn.Module):
71
+
72
+ def __init__(self, scale=1.0):
73
+ super().__init__()
74
+
75
+ self.scale = scale
76
+
77
+ self.branch0 = BasicConv2d(896, 128, kernel_size=1, stride=1)
78
+
79
+ self.branch1 = nn.Sequential(
80
+ BasicConv2d(896, 128, kernel_size=1, stride=1),
81
+ BasicConv2d(128, 128, kernel_size=(1,7), stride=1, padding=(0,3)),
82
+ BasicConv2d(128, 128, kernel_size=(7,1), stride=1, padding=(3,0))
83
+ )
84
+
85
+ self.conv2d = nn.Conv2d(256, 896, kernel_size=1, stride=1)
86
+ self.relu = nn.ReLU(inplace=False)
87
+
88
+ def forward(self, x):
89
+ x0 = self.branch0(x)
90
+ x1 = self.branch1(x)
91
+ out = torch.cat((x0, x1), 1)
92
+ out = self.conv2d(out)
93
+ out = out * self.scale + x
94
+ out = self.relu(out)
95
+ return out
96
+
97
+
98
+ class Block8(nn.Module):
99
+
100
+ def __init__(self, scale=1.0, noReLU=False):
101
+ super().__init__()
102
+
103
+ self.scale = scale
104
+ self.noReLU = noReLU
105
+
106
+ self.branch0 = BasicConv2d(1792, 192, kernel_size=1, stride=1)
107
+
108
+ self.branch1 = nn.Sequential(
109
+ BasicConv2d(1792, 192, kernel_size=1, stride=1),
110
+ BasicConv2d(192, 192, kernel_size=(1,3), stride=1, padding=(0,1)),
111
+ BasicConv2d(192, 192, kernel_size=(3,1), stride=1, padding=(1,0))
112
+ )
113
+
114
+ self.conv2d = nn.Conv2d(384, 1792, kernel_size=1, stride=1)
115
+ if not self.noReLU:
116
+ self.relu = nn.ReLU(inplace=False)
117
+
118
+ def forward(self, x):
119
+ x0 = self.branch0(x)
120
+ x1 = self.branch1(x)
121
+ out = torch.cat((x0, x1), 1)
122
+ out = self.conv2d(out)
123
+ out = out * self.scale + x
124
+ if not self.noReLU:
125
+ out = self.relu(out)
126
+ return out
127
+
128
+
129
+ class Mixed_6a(nn.Module):
130
+
131
+ def __init__(self):
132
+ super().__init__()
133
+
134
+ self.branch0 = BasicConv2d(256, 384, kernel_size=3, stride=2)
135
+
136
+ self.branch1 = nn.Sequential(
137
+ BasicConv2d(256, 192, kernel_size=1, stride=1),
138
+ BasicConv2d(192, 192, kernel_size=3, stride=1, padding=1),
139
+ BasicConv2d(192, 256, kernel_size=3, stride=2)
140
+ )
141
+
142
+ self.branch2 = nn.MaxPool2d(3, stride=2)
143
+
144
+ def forward(self, x):
145
+ x0 = self.branch0(x)
146
+ x1 = self.branch1(x)
147
+ x2 = self.branch2(x)
148
+ out = torch.cat((x0, x1, x2), 1)
149
+ return out
150
+
151
+
152
+ class Mixed_7a(nn.Module):
153
+
154
+ def __init__(self):
155
+ super().__init__()
156
+
157
+ self.branch0 = nn.Sequential(
158
+ BasicConv2d(896, 256, kernel_size=1, stride=1),
159
+ BasicConv2d(256, 384, kernel_size=3, stride=2)
160
+ )
161
+
162
+ self.branch1 = nn.Sequential(
163
+ BasicConv2d(896, 256, kernel_size=1, stride=1),
164
+ BasicConv2d(256, 256, kernel_size=3, stride=2)
165
+ )
166
+
167
+ self.branch2 = nn.Sequential(
168
+ BasicConv2d(896, 256, kernel_size=1, stride=1),
169
+ BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
170
+ BasicConv2d(256, 256, kernel_size=3, stride=2)
171
+ )
172
+
173
+ self.branch3 = nn.MaxPool2d(3, stride=2)
174
+
175
+ def forward(self, x):
176
+ x0 = self.branch0(x)
177
+ x1 = self.branch1(x)
178
+ x2 = self.branch2(x)
179
+ x3 = self.branch3(x)
180
+ out = torch.cat((x0, x1, x2, x3), 1)
181
+ return out
182
+
183
+
184
+ class InceptionResnetV1(nn.Module):
185
+ """Inception Resnet V1 model with optional loading of pretrained weights.
186
+
187
+ Model parameters can be loaded based on pretraining on the VGGFace2 or CASIA-Webface
188
+ datasets. Pretrained state_dicts are automatically downloaded on model instantiation if
189
+ requested and cached in the torch cache. Subsequent instantiations use the cache rather than
190
+ redownloading.
191
+
192
+ Keyword Arguments:
193
+ pretrained {str} -- Optional pretraining dataset. Either 'vggface2' or 'casia-webface'.
194
+ (default: {None})
195
+ classify {bool} -- Whether the model should output classification probabilities or feature
196
+ embeddings. (default: {False})
197
+ num_classes {int} -- Number of output classes. If 'pretrained' is set and num_classes not
198
+ equal to that used for the pretrained model, the final linear layer will be randomly
199
+ initialized. (default: {None})
200
+ dropout_prob {float} -- Dropout probability. (default: {0.6})
201
+ """
202
+ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6, device=None):
203
+ super().__init__()
204
+
205
+ # Set simple attributes
206
+ self.pretrained = pretrained
207
+ self.classify = classify
208
+ self.num_classes = num_classes
209
+
210
+ if pretrained == 'vggface2':
211
+ tmp_classes = 8631
212
+ elif pretrained == 'casia-webface':
213
+ tmp_classes = 10575
214
+ elif pretrained is None and self.classify and self.num_classes is None:
215
+ raise Exception('If "pretrained" is not specified and "classify" is True, "num_classes" must be specified')
216
+
217
+
218
+ # Define layers
219
+ self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
220
+ self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
221
+ self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
222
+ self.maxpool_3a = nn.MaxPool2d(3, stride=2)
223
+ self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
224
+ self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
225
+ self.conv2d_4b = BasicConv2d(192, 256, kernel_size=3, stride=2)
226
+ self.repeat_1 = nn.Sequential(
227
+ Block35(scale=0.17),
228
+ Block35(scale=0.17),
229
+ Block35(scale=0.17),
230
+ Block35(scale=0.17),
231
+ Block35(scale=0.17),
232
+ )
233
+ self.mixed_6a = Mixed_6a()
234
+ self.repeat_2 = nn.Sequential(
235
+ Block17(scale=0.10),
236
+ Block17(scale=0.10),
237
+ Block17(scale=0.10),
238
+ Block17(scale=0.10),
239
+ Block17(scale=0.10),
240
+ Block17(scale=0.10),
241
+ Block17(scale=0.10),
242
+ Block17(scale=0.10),
243
+ Block17(scale=0.10),
244
+ Block17(scale=0.10),
245
+ )
246
+ self.mixed_7a = Mixed_7a()
247
+ self.repeat_3 = nn.Sequential(
248
+ Block8(scale=0.20),
249
+ Block8(scale=0.20),
250
+ Block8(scale=0.20),
251
+ Block8(scale=0.20),
252
+ Block8(scale=0.20),
253
+ )
254
+ self.block8 = Block8(noReLU=True)
255
+ self.avgpool_1a = nn.AdaptiveAvgPool2d(1)
256
+ self.dropout = nn.Dropout(dropout_prob)
257
+ self.last_linear = nn.Linear(1792, 512, bias=False)
258
+ self.last_bn = nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True)
259
+
260
+ if pretrained is not None:
261
+ self.logits = nn.Linear(512, tmp_classes)
262
+ load_weights(self, pretrained)
263
+
264
+ if self.classify and self.num_classes is not None:
265
+ self.logits = nn.Linear(512, self.num_classes)
266
+
267
+ self.device = torch.device('cpu')
268
+ if device is not None:
269
+ self.device = device
270
+ self.to(device)
271
+
272
+ def forward(self, x):
273
+ """Calculate embeddings or logits given a batch of input image tensors.
274
+
275
+ Arguments:
276
+ x {torch.tensor} -- Batch of image tensors representing faces.
277
+
278
+ Returns:
279
+ torch.tensor -- Batch of embedding vectors or multinomial logits.
280
+ """
281
+ x = self.conv2d_1a(x)
282
+ x = self.conv2d_2a(x)
283
+ x = self.conv2d_2b(x)
284
+ x = self.maxpool_3a(x)
285
+ x = self.conv2d_3b(x)
286
+ x = self.conv2d_4a(x)
287
+ x = self.conv2d_4b(x)
288
+ x = self.repeat_1(x)
289
+ x = self.mixed_6a(x)
290
+ x = self.repeat_2(x)
291
+ x = self.mixed_7a(x)
292
+ x = self.repeat_3(x)
293
+ x = self.block8(x)
294
+ x = self.avgpool_1a(x)
295
+ x = self.dropout(x)
296
+ x = self.last_linear(x.view(x.shape[0], -1))
297
+ x = self.last_bn(x)
298
+ if self.classify:
299
+ x = self.logits(x)
300
+ else:
301
+ x = F.normalize(x, p=2, dim=1)
302
+ return x
303
+
304
+
305
+ def load_weights(mdl, name):
306
+ """Download pretrained state_dict and load into model.
307
+
308
+ Arguments:
309
+ mdl {torch.nn.Module} -- Pytorch model.
310
+ name {str} -- Name of dataset that was used to generate pretrained state_dict.
311
+
312
+ Raises:
313
+ ValueError: If 'pretrained' not equal to 'vggface2' or 'casia-webface'.
314
+ """
315
+ if name == 'vggface2':
316
+ path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180402-114759-vggface2.pt'
317
+ elif name == 'casia-webface':
318
+ path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180408-102900-casia-webface.pt'
319
+ else:
320
+ raise ValueError('Pretrained models only exist for "vggface2" and "casia-webface"')
321
+
322
+ model_dir = os.path.join(get_torch_home(), 'checkpoints')
323
+ os.makedirs(model_dir, exist_ok=True)
324
+
325
+ cached_file = os.path.join(model_dir, os.path.basename(path))
326
+ if not os.path.exists(cached_file):
327
+ download_url_to_file(path, cached_file)
328
+
329
+ state_dict = torch.load(cached_file)
330
+ mdl.load_state_dict(state_dict)
331
+
332
+
333
+ def get_torch_home():
334
+ torch_home = os.path.expanduser(
335
+ os.getenv(
336
+ 'TORCH_HOME',
337
+ os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')
338
+ )
339
+ )
340
+ return torch_home
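
For orientation, a minimal usage sketch of the InceptionResnetV1 module defined above. It assumes the repository's `models` package is importable from the repo root and that the pretrained VGGFace2 weights can be downloaded; the random tensor simply stands in for a standardized 160x160 face crop:

import torch

from models.inception_resnet_v1 import InceptionResnetV1

# Load VGGFace2-pretrained weights and switch to eval mode so that batch norm
# and dropout behave deterministically at inference time.
resnet = InceptionResnetV1(pretrained='vggface2').eval()

# Stand-in for a batch of one standardized 3x160x160 face crop.
face = torch.randn(1, 3, 160, 160)

with torch.no_grad():
    embedding = resnet(face)

# With classify=False (the default), the output is an L2-normalized
# 512-dimensional embedding per face.
print(embedding.shape)  # torch.Size([1, 512])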
models/mtcnn.py ADDED
@@ -0,0 +1,519 @@
1
+ import torch
2
+ from torch import nn
3
+ import numpy as np
4
+ import os
5
+
6
+ from .utils.detect_face import detect_face, extract_face
7
+
8
+
9
+ class PNet(nn.Module):
10
+ """MTCNN PNet.
11
+
12
+ Keyword Arguments:
13
+ pretrained {bool} -- Whether or not to load saved pretrained weights (default: {True})
14
+ """
15
+
16
+ def __init__(self, pretrained=True):
17
+ super().__init__()
18
+
19
+ self.conv1 = nn.Conv2d(3, 10, kernel_size=3)
20
+ self.prelu1 = nn.PReLU(10)
21
+ self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True)
22
+ self.conv2 = nn.Conv2d(10, 16, kernel_size=3)
23
+ self.prelu2 = nn.PReLU(16)
24
+ self.conv3 = nn.Conv2d(16, 32, kernel_size=3)
25
+ self.prelu3 = nn.PReLU(32)
26
+ self.conv4_1 = nn.Conv2d(32, 2, kernel_size=1)
27
+ self.softmax4_1 = nn.Softmax(dim=1)
28
+ self.conv4_2 = nn.Conv2d(32, 4, kernel_size=1)
29
+
30
+ self.training = False
31
+
32
+ if pretrained:
33
+ state_dict_path = os.path.join(os.path.dirname(__file__), '../data/pnet.pt')
34
+ state_dict = torch.load(state_dict_path)
35
+ self.load_state_dict(state_dict)
36
+
37
+ def forward(self, x):
38
+ x = self.conv1(x)
39
+ x = self.prelu1(x)
40
+ x = self.pool1(x)
41
+ x = self.conv2(x)
42
+ x = self.prelu2(x)
43
+ x = self.conv3(x)
44
+ x = self.prelu3(x)
45
+ a = self.conv4_1(x)
46
+ a = self.softmax4_1(a)
47
+ b = self.conv4_2(x)
48
+ return b, a
49
+
50
+
51
+ class RNet(nn.Module):
52
+ """MTCNN RNet.
53
+
54
+ Keyword Arguments:
55
+ pretrained {bool} -- Whether or not to load saved pretrained weights (default: {True})
56
+ """
57
+
58
+ def __init__(self, pretrained=True):
59
+ super().__init__()
60
+
61
+ self.conv1 = nn.Conv2d(3, 28, kernel_size=3)
62
+ self.prelu1 = nn.PReLU(28)
63
+ self.pool1 = nn.MaxPool2d(3, 2, ceil_mode=True)
64
+ self.conv2 = nn.Conv2d(28, 48, kernel_size=3)
65
+ self.prelu2 = nn.PReLU(48)
66
+ self.pool2 = nn.MaxPool2d(3, 2, ceil_mode=True)
67
+ self.conv3 = nn.Conv2d(48, 64, kernel_size=2)
68
+ self.prelu3 = nn.PReLU(64)
69
+ self.dense4 = nn.Linear(576, 128)
70
+ self.prelu4 = nn.PReLU(128)
71
+ self.dense5_1 = nn.Linear(128, 2)
72
+ self.softmax5_1 = nn.Softmax(dim=1)
73
+ self.dense5_2 = nn.Linear(128, 4)
74
+
75
+ self.training = False
76
+
77
+ if pretrained:
78
+ state_dict_path = os.path.join(os.path.dirname(__file__), '../data/rnet.pt')
79
+ state_dict = torch.load(state_dict_path)
80
+ self.load_state_dict(state_dict)
81
+
82
+ def forward(self, x):
83
+ x = self.conv1(x)
84
+ x = self.prelu1(x)
85
+ x = self.pool1(x)
86
+ x = self.conv2(x)
87
+ x = self.prelu2(x)
88
+ x = self.pool2(x)
89
+ x = self.conv3(x)
90
+ x = self.prelu3(x)
91
+ x = x.permute(0, 3, 2, 1).contiguous()
92
+ x = self.dense4(x.view(x.shape[0], -1))
93
+ x = self.prelu4(x)
94
+ a = self.dense5_1(x)
95
+ a = self.softmax5_1(a)
96
+ b = self.dense5_2(x)
97
+ return b, a
98
+
99
+
100
+ class ONet(nn.Module):
101
+ """MTCNN ONet.
102
+
103
+ Keyword Arguments:
104
+ pretrained {bool} -- Whether or not to load saved pretrained weights (default: {True})
105
+ """
106
+
107
+ def __init__(self, pretrained=True):
108
+ super().__init__()
109
+
110
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
111
+ self.prelu1 = nn.PReLU(32)
112
+ self.pool1 = nn.MaxPool2d(3, 2, ceil_mode=True)
113
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
114
+ self.prelu2 = nn.PReLU(64)
115
+ self.pool2 = nn.MaxPool2d(3, 2, ceil_mode=True)
116
+ self.conv3 = nn.Conv2d(64, 64, kernel_size=3)
117
+ self.prelu3 = nn.PReLU(64)
118
+ self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True)
119
+ self.conv4 = nn.Conv2d(64, 128, kernel_size=2)
120
+ self.prelu4 = nn.PReLU(128)
121
+ self.dense5 = nn.Linear(1152, 256)
122
+ self.prelu5 = nn.PReLU(256)
123
+ self.dense6_1 = nn.Linear(256, 2)
124
+ self.softmax6_1 = nn.Softmax(dim=1)
125
+ self.dense6_2 = nn.Linear(256, 4)
126
+ self.dense6_3 = nn.Linear(256, 10)
127
+
128
+ self.training = False
129
+
130
+ if pretrained:
131
+ state_dict_path = os.path.join(os.path.dirname(__file__), '../data/onet.pt')
132
+ state_dict = torch.load(state_dict_path)
133
+ self.load_state_dict(state_dict)
134
+
135
+ def forward(self, x):
136
+ x = self.conv1(x)
137
+ x = self.prelu1(x)
138
+ x = self.pool1(x)
139
+ x = self.conv2(x)
140
+ x = self.prelu2(x)
141
+ x = self.pool2(x)
142
+ x = self.conv3(x)
143
+ x = self.prelu3(x)
144
+ x = self.pool3(x)
145
+ x = self.conv4(x)
146
+ x = self.prelu4(x)
147
+ x = x.permute(0, 3, 2, 1).contiguous()
148
+ x = self.dense5(x.view(x.shape[0], -1))
149
+ x = self.prelu5(x)
150
+ a = self.dense6_1(x)
151
+ a = self.softmax6_1(a)
152
+ b = self.dense6_2(x)
153
+ c = self.dense6_3(x)
154
+ return b, c, a
155
+
156
+
157
+ class MTCNN(nn.Module):
158
+ """MTCNN face detection module.
159
+
160
+ This class loads pretrained P-, R-, and O-nets and returns images cropped to include the face
161
+ only, given raw input images of one of the following types:
162
+ - PIL image or list of PIL images
163
+ - numpy.ndarray (uint8) representing either a single image (3D) or a batch of images (4D).
164
+ Cropped faces can optionally be saved to file
165
+ also.
166
+
167
+ Keyword Arguments:
168
+ image_size {int} -- Output image size in pixels. The image will be square. (default: {160})
169
+ margin {int} -- Margin to add to bounding box, in terms of pixels in the final image.
170
+ Note that the application of the margin differs slightly from the davidsandberg/facenet
171
+ repo, which applies the margin to the original image before resizing, making the margin
172
+ dependent on the original image size (this is a bug in davidsandberg/facenet).
173
+ (default: {0})
174
+ min_face_size {int} -- Minimum face size to search for. (default: {20})
175
+ thresholds {list} -- MTCNN face detection thresholds (default: {[0.6, 0.7, 0.7]})
176
+ factor {float} -- Factor used to create a scaling pyramid of face sizes. (default: {0.709})
177
+ post_process {bool} -- Whether or not to post process images tensors before returning.
178
+ (default: {True})
179
+ select_largest {bool} -- If True, if multiple faces are detected, the largest is returned.
180
+ If False, the face with the highest detection probability is returned.
181
+ (default: {True})
182
+ selection_method {string} -- Which heuristic to use for selection. Default None. If
183
+ specified, will override select_largest:
184
+ "probability": highest probability selected
185
+ "largest": largest box selected
186
+ "largest_over_threshold": largest box over a certain probability selected
187
+ "center_weighted_size": box size minus weighted squared offset from image center
188
+ (default: {None})
189
+ keep_all {bool} -- If True, all detected faces are returned, in the order dictated by the
190
+ select_largest parameter. If a save_path is specified, the first face is saved to that
191
+ path and the remaining faces are saved to <save_path>1, <save_path>2 etc.
192
+ (default: {False})
193
+ device {torch.device} -- The device on which to run neural net passes. Image tensors and
194
+ models are copied to this device before running forward passes. (default: {None})
195
+ """
196
+
197
+ def __init__(
198
+ self, image_size=160, margin=0, min_face_size=20,
199
+ thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
200
+ select_largest=True, selection_method=None, keep_all=False, device=None
201
+ ):
202
+ super().__init__()
203
+
204
+ self.image_size = image_size
205
+ self.margin = margin
206
+ self.min_face_size = min_face_size
207
+ self.thresholds = thresholds
208
+ self.factor = factor
209
+ self.post_process = post_process
210
+ self.select_largest = select_largest
211
+ self.keep_all = keep_all
212
+ self.selection_method = selection_method
213
+
214
+ self.pnet = PNet()
215
+ self.rnet = RNet()
216
+ self.onet = ONet()
217
+
218
+ self.device = torch.device('cpu')
219
+ if device is not None:
220
+ self.device = device
221
+ self.to(device)
222
+
223
+ if not self.selection_method:
224
+ self.selection_method = 'largest' if self.select_largest else 'probability'
225
+
226
+ def forward(self, img, save_path=None, return_prob=False):
227
+ """Run MTCNN face detection on a PIL image or numpy array. This method performs both
228
+ detection and extraction of faces, returning tensors representing detected faces rather
229
+ than the bounding boxes. To access bounding boxes, see the MTCNN.detect() method below.
230
+
231
+ Arguments:
232
+ img {PIL.Image, np.ndarray, or list} -- A PIL image, np.ndarray, torch.Tensor, or list.
233
+
234
+ Keyword Arguments:
235
+ save_path {str} -- An optional save path for the cropped image. Note that when
236
+ self.post_process=True, although the returned tensor is post processed, the saved
237
+ face image is not, so it is a true representation of the face in the input image.
238
+ If `img` is a list of images, `save_path` should be a list of equal length.
239
+ (default: {None})
240
+ return_prob {bool} -- Whether or not to return the detection probability.
241
+ (default: {False})
242
+
243
+ Returns:
244
+ Union[torch.Tensor, tuple(torch.tensor, float)] -- If detected, cropped image of a face
245
+ with dimensions 3 x image_size x image_size. Optionally, the probability that a
246
+ face was detected. If self.keep_all is True, n detected faces are returned in an
247
+ n x 3 x image_size x image_size tensor with an optional list of detection
248
+ probabilities. If `img` is a list of images, the item(s) returned have an extra
249
+ dimension (batch) as the first dimension.
250
+
251
+ Example:
252
+ >>> from facenet_pytorch import MTCNN
253
+ >>> mtcnn = MTCNN()
254
+ >>> face_tensor, prob = mtcnn(img, save_path='face.png', return_prob=True)
255
+ """
256
+
257
+ # Detect faces
258
+ batch_boxes, batch_probs, batch_points = self.detect(img, landmarks=True)
259
+ # Select faces
260
+ if not self.keep_all:
261
+ batch_boxes, batch_probs, batch_points = self.select_boxes(
262
+ batch_boxes, batch_probs, batch_points, img, method=self.selection_method
263
+ )
264
+ # Extract faces
265
+ faces = self.extract(img, batch_boxes, save_path)
266
+
267
+ if return_prob:
268
+ return faces, batch_probs
269
+ else:
270
+ return faces
271
+
272
+ def detect(self, img, landmarks=False):
273
+ """Detect all faces in PIL image and return bounding boxes and optional facial landmarks.
274
+
275
+ This method is used by the forward method and is also useful for face detection tasks
276
+ that require lower-level handling of bounding boxes and facial landmarks (e.g., face
277
+ tracking). The functionality of the forward function can be emulated by using this method
278
+ followed by the extract_face() function.
279
+
280
+ Arguments:
281
+ img {PIL.Image, np.ndarray, or list} -- A PIL image, np.ndarray, torch.Tensor, or list.
282
+
283
+ Keyword Arguments:
284
+ landmarks {bool} -- Whether to return facial landmarks in addition to bounding boxes.
285
+ (default: {False})
286
+
287
+ Returns:
288
+ tuple(numpy.ndarray, list) -- For N detected faces, a tuple containing an
289
+ Nx4 array of bounding boxes and a length N list of detection probabilities.
290
+ Returned boxes will be sorted in descending order by detection probability if
291
+ self.select_largest=False, otherwise the largest face will be returned first.
292
+ If `img` is a list of images, the items returned have an extra dimension
293
+ (batch) as the first dimension. Optionally, a third item, the facial landmarks,
294
+ are returned if `landmarks=True`.
295
+
296
+ Example:
297
+ >>> from PIL import Image, ImageDraw
298
+ >>> from facenet_pytorch import MTCNN, extract_face
299
+ >>> mtcnn = MTCNN(keep_all=True)
300
+ >>> boxes, probs, points = mtcnn.detect(img, landmarks=True)
301
+ >>> # Draw boxes and save faces
302
+ >>> img_draw = img.copy()
303
+ >>> draw = ImageDraw.Draw(img_draw)
304
+ >>> for i, (box, point) in enumerate(zip(boxes, points)):
305
+ ... draw.rectangle(box.tolist(), width=5)
306
+ ... for p in point:
307
+ ... draw.rectangle((p - 10).tolist() + (p + 10).tolist(), width=10)
308
+ ... extract_face(img, box, save_path='detected_face_{}.png'.format(i))
309
+ >>> img_draw.save('annotated_faces.png')
310
+ """
311
+
312
+ with torch.no_grad():
313
+ batch_boxes, batch_points = detect_face(
314
+ img, self.min_face_size,
315
+ self.pnet, self.rnet, self.onet,
316
+ self.thresholds, self.factor,
317
+ self.device
318
+ )
319
+
320
+ boxes, probs, points = [], [], []
321
+ for box, point in zip(batch_boxes, batch_points):
322
+ box = np.array(box)
323
+ point = np.array(point)
324
+ if len(box) == 0:
325
+ boxes.append(None)
326
+ probs.append([None])
327
+ points.append(None)
328
+ elif self.select_largest:
329
+ box_order = np.argsort((box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]))[::-1]
330
+ box = box[box_order]
331
+ point = point[box_order]
332
+ boxes.append(box[:, :4])
333
+ probs.append(box[:, 4])
334
+ points.append(point)
335
+ else:
336
+ boxes.append(box[:, :4])
337
+ probs.append(box[:, 4])
338
+ points.append(point)
339
+ boxes = np.array(boxes, dtype=object)
340
+ probs = np.array(probs, dtype=object)
341
+ points = np.array(points, dtype=object)
342
+
343
+ if (
344
+ not isinstance(img, (list, tuple)) and
345
+ not (isinstance(img, np.ndarray) and len(img.shape) == 4) and
346
+ not (isinstance(img, torch.Tensor) and len(img.shape) == 4)
347
+ ):
348
+ boxes = boxes[0]
349
+ probs = probs[0]
350
+ points = points[0]
351
+
352
+ if landmarks:
353
+ return boxes, probs, points
354
+
355
+ return boxes, probs
356
+
357
+ def select_boxes(
358
+ self, all_boxes, all_probs, all_points, imgs, method='probability', threshold=0.9,
359
+ center_weight=2.0
360
+ ):
361
+ """Selects a single box from multiple for a given image using one of multiple heuristics.
362
+
363
+ Arguments:
364
+ all_boxes {np.ndarray} -- Ix0 ndarray where each element is a Nx4 ndarray of
365
+ bounding boxes for N detected faces in I images (output from self.detect).
366
+ all_probs {np.ndarray} -- Ix0 ndarray where each element is a Nx0 ndarray of
367
+ probabilities for N detected faces in I images (output from self.detect).
368
+ all_points {np.ndarray} -- Ix0 ndarray where each element is a Nx5x2 array of
369
+ points for N detected faces. (output from self.detect).
370
+ imgs {PIL.Image, np.ndarray, or list} -- A PIL image, np.ndarray, torch.Tensor, or list.
371
+
372
+ Keyword Arguments:
373
+ method {str} -- Which heuristic to use for selection:
374
+ "probability": highest probability selected
375
+ "largest": largest box selected
376
+ "largest_over_threshold": largest box over a certain probability selected
377
+ "center_weighted_size": box size minus weighted squared offset from image center
378
+ (default: {'probability'})
379
+ threshold {float} -- threshold for "largest_over_threshold" method. (default: {0.9})
380
+ center_weight {float} -- weight for squared offset in center weighted size method.
381
+ (default: {2.0})
382
+
383
+ Returns:
384
+ tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray) -- nx4 ndarray of bounding boxes
385
+ for n images. Ix0 array of probabilities for each box, array of landmark points.
386
+ """
387
+
388
+ # Copying batch detection from extract, but it would be easier to ensure detect creates consistent output.
389
+ batch_mode = True
390
+ if (
391
+ not isinstance(imgs, (list, tuple)) and
392
+ not (isinstance(imgs, np.ndarray) and len(imgs.shape) == 4) and
393
+ not (isinstance(imgs, torch.Tensor) and len(imgs.shape) == 4)
394
+ ):
395
+ imgs = [imgs]
396
+ all_boxes = [all_boxes]
397
+ all_probs = [all_probs]
398
+ all_points = [all_points]
399
+ batch_mode = False
400
+
401
+ selected_boxes, selected_probs, selected_points = [], [], []
402
+ for boxes, points, probs, img in zip(all_boxes, all_points, all_probs, imgs):
403
+
404
+ if boxes is None:
405
+ selected_boxes.append(None)
406
+ selected_probs.append([None])
407
+ selected_points.append(None)
408
+ continue
409
+
410
+ # If at least 1 box found
411
+ boxes = np.array(boxes)
412
+ probs = np.array(probs)
413
+ points = np.array(points)
414
+
415
+ if method == 'largest':
416
+ box_order = np.argsort((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]))[::-1]
417
+ elif method == 'probability':
418
+ box_order = np.argsort(probs)[::-1]
419
+ elif method == 'center_weighted_size':
420
+ box_sizes = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
421
+ img_center = (img.width / 2, img.height/2)
422
+ box_centers = np.array(list(zip((boxes[:, 0] + boxes[:, 2]) / 2, (boxes[:, 1] + boxes[:, 3]) / 2)))
423
+ offsets = box_centers - img_center
424
+ offset_dist_squared = np.sum(np.power(offsets, 2.0), 1)
425
+ box_order = np.argsort(box_sizes - offset_dist_squared * center_weight)[::-1]
426
+ elif method == 'largest_over_threshold':
427
+ box_mask = probs > threshold
428
+ boxes = boxes[box_mask]
429
+ box_order = np.argsort((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]))[::-1]
430
+ if sum(box_mask) == 0:
431
+ selected_boxes.append(None)
432
+ selected_probs.append([None])
433
+ selected_points.append(None)
434
+ continue
435
+
436
+ box = boxes[box_order][[0]]
437
+ prob = probs[box_order][[0]]
438
+ point = points[box_order][[0]]
439
+ selected_boxes.append(box)
440
+ selected_probs.append(prob)
441
+ selected_points.append(point)
442
+
443
+ if batch_mode:
444
+ selected_boxes = np.array(selected_boxes)
445
+ selected_probs = np.array(selected_probs)
446
+ selected_points = np.array(selected_points)
447
+ else:
448
+ selected_boxes = selected_boxes[0]
449
+ selected_probs = selected_probs[0][0]
450
+ selected_points = selected_points[0]
451
+
452
+ return selected_boxes, selected_probs, selected_points
453
+
454
+ def extract(self, img, batch_boxes, save_path):
455
+ # Determine if a batch or single image was passed
456
+ batch_mode = True
457
+ if (
458
+ not isinstance(img, (list, tuple)) and
459
+ not (isinstance(img, np.ndarray) and len(img.shape) == 4) and
460
+ not (isinstance(img, torch.Tensor) and len(img.shape) == 4)
461
+ ):
462
+ img = [img]
463
+ batch_boxes = [batch_boxes]
464
+ batch_mode = False
465
+
466
+ # Parse save path(s)
467
+ if save_path is not None:
468
+ if isinstance(save_path, str):
469
+ save_path = [save_path]
470
+ else:
471
+ save_path = [None for _ in range(len(img))]
472
+
473
+ # Process all bounding boxes
474
+ faces = []
475
+ for im, box_im, path_im in zip(img, batch_boxes, save_path):
476
+ if box_im is None:
477
+ faces.append(None)
478
+ continue
479
+
480
+ if not self.keep_all:
481
+ box_im = box_im[[0]]
482
+
483
+ faces_im = []
484
+ for i, box in enumerate(box_im):
485
+ face_path = path_im
486
+ if path_im is not None and i > 0:
487
+ save_name, ext = os.path.splitext(path_im)
488
+ face_path = save_name + '_' + str(i + 1) + ext
489
+
490
+ face = extract_face(im, box, self.image_size, self.margin, face_path)
491
+ if self.post_process:
492
+ face = fixed_image_standardization(face)
493
+ faces_im.append(face)
494
+
495
+ if self.keep_all:
496
+ faces_im = torch.stack(faces_im)
497
+ else:
498
+ faces_im = faces_im[0]
499
+
500
+ faces.append(faces_im)
501
+
502
+ if not batch_mode:
503
+ faces = faces[0]
504
+
505
+ return faces
506
+
507
+
508
+ def fixed_image_standardization(image_tensor):
509
+ processed_tensor = (image_tensor - 127.5) / 128.0
510
+ return processed_tensor
511
+
512
+
513
+ def prewhiten(x):
514
+ mean = x.mean()
515
+ std = x.std()
516
+ std_adj = std.clamp(min=1.0/(float(x.numel())**0.5))
517
+ y = (x - mean) / std_adj
518
+ return y
519
+
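
To show how the two modules fit together, here is a sketch of the detect-then-embed pipeline that MTCNN and InceptionResnetV1 are designed for. This is an assumed workflow, not part of the uploaded code; it reuses one of the test images included in this commit:

import torch
from PIL import Image

from models.mtcnn import MTCNN
from models.inception_resnet_v1 import InceptionResnetV1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Default settings: keep_all=False, so a single 3x160x160 face tensor is returned per image.
mtcnn = MTCNN(image_size=160, margin=0, device=device)
resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)

img = Image.open('data/test_images/angelina_jolie/1.jpg')

# forward() detects, crops, and post-processes the face; it returns None if no face is found.
face, prob = mtcnn(img, return_prob=True)
if face is not None:
    with torch.no_grad():
        embedding = resnet(face.unsqueeze(0).to(device))  # shape (1, 512)
    print(prob, embedding.shape)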
models/utils/detect_face.py ADDED
@@ -0,0 +1,378 @@
1
+ import torch
2
+ from torch.nn.functional import interpolate
3
+ from torchvision.transforms import functional as F
4
+ from torchvision.ops.boxes import batched_nms
5
+ from PIL import Image
6
+ import numpy as np
7
+ import os
8
+ import math
9
+
10
+ # OpenCV is optional, but required if using numpy arrays instead of PIL
11
+ try:
12
+ import cv2
13
+ except:
14
+ pass
15
+
16
+ def fixed_batch_process(im_data, model):
17
+ batch_size = 512
18
+ out = []
19
+ for i in range(0, len(im_data), batch_size):
20
+ batch = im_data[i:(i+batch_size)]
21
+ out.append(model(batch))
22
+
23
+ return tuple(torch.cat(v, dim=0) for v in zip(*out))
24
+
25
+ def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device):
26
+ if isinstance(imgs, (np.ndarray, torch.Tensor)):
27
+ if isinstance(imgs,np.ndarray):
28
+ imgs = torch.as_tensor(imgs.copy(), device=device)
29
+
30
+ if isinstance(imgs,torch.Tensor):
31
+ imgs = torch.as_tensor(imgs, device=device)
32
+
33
+ if len(imgs.shape) == 3:
34
+ imgs = imgs.unsqueeze(0)
35
+ else:
36
+ if not isinstance(imgs, (list, tuple)):
37
+ imgs = [imgs]
38
+ if any(img.size != imgs[0].size for img in imgs):
39
+ raise Exception("MTCNN batch processing only compatible with equal-dimension images.")
40
+ imgs = np.stack([np.uint8(img) for img in imgs])
41
+ imgs = torch.as_tensor(imgs.copy(), device=device)
42
+
43
+
44
+
45
+ model_dtype = next(pnet.parameters()).dtype
46
+ imgs = imgs.permute(0, 3, 1, 2).type(model_dtype)
47
+
48
+ batch_size = len(imgs)
49
+ h, w = imgs.shape[2:4]
50
+ m = 12.0 / minsize
51
+ minl = min(h, w)
52
+ minl = minl * m
53
+
54
+ # Create scale pyramid
55
+ scale_i = m
56
+ scales = []
57
+ while minl >= 12:
58
+ scales.append(scale_i)
59
+ scale_i = scale_i * factor
60
+ minl = minl * factor
61
+
62
+ # First stage
63
+ boxes = []
64
+ image_inds = []
65
+
66
+ scale_picks = []
67
+
68
+ all_i = 0
69
+ offset = 0
70
+ for scale in scales:
71
+ im_data = imresample(imgs, (int(h * scale + 1), int(w * scale + 1)))
72
+ im_data = (im_data - 127.5) * 0.0078125
73
+ reg, probs = pnet(im_data)
74
+
75
+ boxes_scale, image_inds_scale = generateBoundingBox(reg, probs[:, 1], scale, threshold[0])
76
+ boxes.append(boxes_scale)
77
+ image_inds.append(image_inds_scale)
78
+
79
+ pick = batched_nms(boxes_scale[:, :4], boxes_scale[:, 4], image_inds_scale, 0.5)
80
+ scale_picks.append(pick + offset)
81
+ offset += boxes_scale.shape[0]
82
+
83
+ boxes = torch.cat(boxes, dim=0)
84
+ image_inds = torch.cat(image_inds, dim=0)
85
+
86
+ scale_picks = torch.cat(scale_picks, dim=0)
87
+
88
+ # NMS within each scale + image
89
+ boxes, image_inds = boxes[scale_picks], image_inds[scale_picks]
90
+
91
+
92
+ # NMS within each image
93
+ pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7)
94
+ boxes, image_inds = boxes[pick], image_inds[pick]
95
+
96
+ regw = boxes[:, 2] - boxes[:, 0]
97
+ regh = boxes[:, 3] - boxes[:, 1]
98
+ qq1 = boxes[:, 0] + boxes[:, 5] * regw
99
+ qq2 = boxes[:, 1] + boxes[:, 6] * regh
100
+ qq3 = boxes[:, 2] + boxes[:, 7] * regw
101
+ qq4 = boxes[:, 3] + boxes[:, 8] * regh
102
+ boxes = torch.stack([qq1, qq2, qq3, qq4, boxes[:, 4]]).permute(1, 0)
103
+ boxes = rerec(boxes)
104
+ y, ey, x, ex = pad(boxes, w, h)
105
+
106
+ # Second stage
107
+ if len(boxes) > 0:
108
+ im_data = []
109
+ for k in range(len(y)):
110
+ if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1):
111
+ img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]].unsqueeze(0)
112
+ im_data.append(imresample(img_k, (24, 24)))
113
+ im_data = torch.cat(im_data, dim=0)
114
+ im_data = (im_data - 127.5) * 0.0078125
115
+
116
+ # This is equivalent to out = rnet(im_data) to avoid GPU out of memory.
117
+ out = fixed_batch_process(im_data, rnet)
118
+
119
+ out0 = out[0].permute(1, 0)
120
+ out1 = out[1].permute(1, 0)
121
+ score = out1[1, :]
122
+ ipass = score > threshold[1]
123
+ boxes = torch.cat((boxes[ipass, :4], score[ipass].unsqueeze(1)), dim=1)
124
+ image_inds = image_inds[ipass]
125
+ mv = out0[:, ipass].permute(1, 0)
126
+
127
+ # NMS within each image
128
+ pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7)
129
+ boxes, image_inds, mv = boxes[pick], image_inds[pick], mv[pick]
130
+ boxes = bbreg(boxes, mv)
131
+ boxes = rerec(boxes)
132
+
133
+ # Third stage
134
+ points = torch.zeros(0, 5, 2, device=device)
135
+ if len(boxes) > 0:
136
+ y, ey, x, ex = pad(boxes, w, h)
137
+ im_data = []
138
+ for k in range(len(y)):
139
+ if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1):
140
+ img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]].unsqueeze(0)
141
+ im_data.append(imresample(img_k, (48, 48)))
142
+ im_data = torch.cat(im_data, dim=0)
143
+ im_data = (im_data - 127.5) * 0.0078125
144
+
145
+ # This is equivalent to out = onet(im_data) to avoid GPU out of memory.
146
+ out = fixed_batch_process(im_data, onet)
147
+
148
+ out0 = out[0].permute(1, 0)
149
+ out1 = out[1].permute(1, 0)
150
+ out2 = out[2].permute(1, 0)
151
+ score = out2[1, :]
152
+ points = out1
153
+ ipass = score > threshold[2]
154
+ points = points[:, ipass]
155
+ boxes = torch.cat((boxes[ipass, :4], score[ipass].unsqueeze(1)), dim=1)
156
+ image_inds = image_inds[ipass]
157
+ mv = out0[:, ipass].permute(1, 0)
158
+
159
+ w_i = boxes[:, 2] - boxes[:, 0] + 1
160
+ h_i = boxes[:, 3] - boxes[:, 1] + 1
161
+ points_x = w_i.repeat(5, 1) * points[:5, :] + boxes[:, 0].repeat(5, 1) - 1
162
+ points_y = h_i.repeat(5, 1) * points[5:10, :] + boxes[:, 1].repeat(5, 1) - 1
163
+ points = torch.stack((points_x, points_y)).permute(2, 1, 0)
164
+ boxes = bbreg(boxes, mv)
165
+
166
+ # NMS within each image using "Min" strategy
167
+ # pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7)
168
+ pick = batched_nms_numpy(boxes[:, :4], boxes[:, 4], image_inds, 0.7, 'Min')
169
+ boxes, image_inds, points = boxes[pick], image_inds[pick], points[pick]
170
+
171
+ boxes = boxes.cpu().numpy()
172
+ points = points.cpu().numpy()
173
+
174
+ image_inds = image_inds.cpu()
175
+
176
+ batch_boxes = []
177
+ batch_points = []
178
+ for b_i in range(batch_size):
179
+ b_i_inds = np.where(image_inds == b_i)
180
+ batch_boxes.append(boxes[b_i_inds].copy())
181
+ batch_points.append(points[b_i_inds].copy())
182
+
183
+ batch_boxes, batch_points = np.array(batch_boxes, dtype=object), np.array(batch_points, dtype=object)
184
+
185
+ return batch_boxes, batch_points
186
+
187
+
188
+ def bbreg(boundingbox, reg):
189
+ if reg.shape[1] == 1:
190
+ reg = torch.reshape(reg, (reg.shape[2], reg.shape[3]))
191
+
192
+ w = boundingbox[:, 2] - boundingbox[:, 0] + 1
193
+ h = boundingbox[:, 3] - boundingbox[:, 1] + 1
194
+ b1 = boundingbox[:, 0] + reg[:, 0] * w
195
+ b2 = boundingbox[:, 1] + reg[:, 1] * h
196
+ b3 = boundingbox[:, 2] + reg[:, 2] * w
197
+ b4 = boundingbox[:, 3] + reg[:, 3] * h
198
+ boundingbox[:, :4] = torch.stack([b1, b2, b3, b4]).permute(1, 0)
199
+
200
+ return boundingbox
201
+
202
+
203
+ def generateBoundingBox(reg, probs, scale, thresh):
204
+ stride = 2
205
+ cellsize = 12
206
+
207
+ reg = reg.permute(1, 0, 2, 3)
208
+
209
+ mask = probs >= thresh
210
+ mask_inds = mask.nonzero()
211
+ image_inds = mask_inds[:, 0]
212
+ score = probs[mask]
213
+ reg = reg[:, mask].permute(1, 0)
214
+ bb = mask_inds[:, 1:].type(reg.dtype).flip(1)
215
+ q1 = ((stride * bb + 1) / scale).floor()
216
+ q2 = ((stride * bb + cellsize - 1 + 1) / scale).floor()
217
+ boundingbox = torch.cat([q1, q2, score.unsqueeze(1), reg], dim=1)
218
+ return boundingbox, image_inds
219
+
220
+
221
+ def nms_numpy(boxes, scores, threshold, method):
222
+ if boxes.size == 0:
223
+ return np.empty((0, 3))
224
+
225
+ x1 = boxes[:, 0].copy()
226
+ y1 = boxes[:, 1].copy()
227
+ x2 = boxes[:, 2].copy()
228
+ y2 = boxes[:, 3].copy()
229
+ s = scores
230
+ area = (x2 - x1 + 1) * (y2 - y1 + 1)
231
+
232
+ I = np.argsort(s)
233
+ pick = np.zeros_like(s, dtype=np.int16)
234
+ counter = 0
235
+ while I.size > 0:
236
+ i = I[-1]
237
+ pick[counter] = i
238
+ counter += 1
239
+ idx = I[0:-1]
240
+
241
+ xx1 = np.maximum(x1[i], x1[idx]).copy()
242
+ yy1 = np.maximum(y1[i], y1[idx]).copy()
243
+ xx2 = np.minimum(x2[i], x2[idx]).copy()
244
+ yy2 = np.minimum(y2[i], y2[idx]).copy()
245
+
246
+ w = np.maximum(0.0, xx2 - xx1 + 1).copy()
247
+ h = np.maximum(0.0, yy2 - yy1 + 1).copy()
248
+
249
+ inter = w * h
250
+ if method == 'Min':
251
+ o = inter / np.minimum(area[i], area[idx])
252
+ else:
253
+ o = inter / (area[i] + area[idx] - inter)
254
+ I = I[np.where(o <= threshold)]
255
+
256
+ pick = pick[:counter].copy()
257
+ return pick
258
+
259
+
260
+ def batched_nms_numpy(boxes, scores, idxs, threshold, method):
261
+ device = boxes.device
262
+ if boxes.numel() == 0:
263
+ return torch.empty((0,), dtype=torch.int64, device=device)
264
+ # strategy: in order to perform NMS independently per class.
265
+ # we add an offset to all the boxes. The offset is dependent
266
+ # only on the class idx, and is large enough so that boxes
267
+ # from different classes do not overlap
268
+ max_coordinate = boxes.max()
269
+ offsets = idxs.to(boxes) * (max_coordinate + 1)
270
+ boxes_for_nms = boxes + offsets[:, None]
271
+ boxes_for_nms = boxes_for_nms.cpu().numpy()
272
+ scores = scores.cpu().numpy()
273
+ keep = nms_numpy(boxes_for_nms, scores, threshold, method)
274
+ return torch.as_tensor(keep, dtype=torch.long, device=device)
275
+
276
+
277
+ def pad(boxes, w, h):
278
+ boxes = boxes.trunc().int().cpu().numpy()
279
+ x = boxes[:, 0]
280
+ y = boxes[:, 1]
281
+ ex = boxes[:, 2]
282
+ ey = boxes[:, 3]
283
+
284
+ x[x < 1] = 1
285
+ y[y < 1] = 1
286
+ ex[ex > w] = w
287
+ ey[ey > h] = h
288
+
289
+ return y, ey, x, ex
290
+
291
+
292
+ def rerec(bboxA):
293
+ h = bboxA[:, 3] - bboxA[:, 1]
294
+ w = bboxA[:, 2] - bboxA[:, 0]
295
+
296
+ l = torch.max(w, h)
297
+ bboxA[:, 0] = bboxA[:, 0] + w * 0.5 - l * 0.5
298
+ bboxA[:, 1] = bboxA[:, 1] + h * 0.5 - l * 0.5
299
+ bboxA[:, 2:4] = bboxA[:, :2] + l.repeat(2, 1).permute(1, 0)
300
+
301
+ return bboxA
302
+
303
+
304
+ def imresample(img, sz):
305
+ im_data = interpolate(img, size=sz, mode="area")
306
+ return im_data
307
+
308
+
309
+ def crop_resize(img, box, image_size):
310
+ if isinstance(img, np.ndarray):
311
+ img = img[box[1]:box[3], box[0]:box[2]]
312
+ out = cv2.resize(
313
+ img,
314
+ (image_size, image_size),
315
+ interpolation=cv2.INTER_AREA
316
+ ).copy()
317
+ elif isinstance(img, torch.Tensor):
318
+ img = img[box[1]:box[3], box[0]:box[2]]
319
+ out = imresample(
320
+ img.permute(2, 0, 1).unsqueeze(0).float(),
321
+ (image_size, image_size)
322
+ ).byte().squeeze(0).permute(1, 2, 0)
323
+ else:
324
+ out = img.crop(box).copy().resize((image_size, image_size), Image.BILINEAR)
325
+ return out
326
+
327
+
328
+ def save_img(img, path):
329
+ if isinstance(img, np.ndarray):
330
+ cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
331
+ else:
332
+ img.save(path)
333
+
334
+
335
+ def get_size(img):
336
+ if isinstance(img, (np.ndarray, torch.Tensor)):
337
+ return img.shape[1::-1]
338
+ else:
339
+ return img.size
340
+
341
+
342
+ def extract_face(img, box, image_size=160, margin=0, save_path=None):
343
+ """Extract face + margin from PIL Image given bounding box.
344
+
345
+ Arguments:
346
+ img {PIL.Image} -- A PIL Image.
347
+ box {numpy.ndarray} -- Four-element bounding box.
348
+ image_size {int} -- Output image size in pixels. The image will be square.
349
+ margin {int} -- Margin to add to bounding box, in terms of pixels in the final image.
350
+ Note that the application of the margin differs slightly from the davidsandberg/facenet
351
+ repo, which applies the margin to the original image before resizing, making the margin
352
+ dependent on the original image size.
353
+ save_path {str} -- Save path for extracted face image. (default: {None})
354
+
355
+ Returns:
356
+ torch.tensor -- tensor representing the extracted face.
357
+ """
358
+ margin = [
359
+ margin * (box[2] - box[0]) / (image_size - margin),
360
+ margin * (box[3] - box[1]) / (image_size - margin),
361
+ ]
362
+ raw_image_size = get_size(img)
363
+ box = [
364
+ int(max(box[0] - margin[0] / 2, 0)),
365
+ int(max(box[1] - margin[1] / 2, 0)),
366
+ int(min(box[2] + margin[0] / 2, raw_image_size[0])),
367
+ int(min(box[3] + margin[1] / 2, raw_image_size[1])),
368
+ ]
369
+
370
+ face = crop_resize(img, box, image_size)
371
+
372
+ if save_path is not None:
373
+ os.makedirs(os.path.dirname(save_path) + "/", exist_ok=True)
374
+ save_img(face, save_path)
375
+
376
+ face = F.to_tensor(np.float32(face))
377
+
378
+ return face
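
As a worked illustration of the scale pyramid built near the top of detect_face (standalone arithmetic, not part of the module; the image size is only an example):

# Pyramid construction for a 480x640 image with MTCNN's defaults
# minsize=20 and factor=0.709.
minsize, factor = 20, 0.709
h, w = 480, 640

m = 12.0 / minsize        # maps the smallest detectable face onto PNet's 12-pixel window
minl = min(h, w) * m      # 288.0

scales = []
scale_i = m
while minl >= 12:
    scales.append(scale_i)
    scale_i *= factor
    minl *= factor

# Each scale yields one resampled copy of the image for the PNet pass;
# here that is 10 pyramid levels, from 0.6 down to roughly 0.027.
print(len(scales), round(scales[0], 3), round(scales[-1], 3))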
models/utils/download.py ADDED
@@ -0,0 +1,102 @@
1
+ import hashlib
2
+ import os
3
+ import shutil
4
+ import sys
5
+ import tempfile
6
+
7
+ from urllib.request import urlopen, Request
8
+
9
+ try:
10
+ from tqdm.auto import tqdm # automatically select proper tqdm submodule if available
11
+ except ImportError:
12
+ try:
13
+ from tqdm import tqdm
14
+ except ImportError:
15
+ # fake tqdm if it's not installed
16
+ class tqdm(object): # type: ignore
17
+
18
+ def __init__(self, total=None, disable=False,
19
+ unit=None, unit_scale=None, unit_divisor=None):
20
+ self.total = total
21
+ self.disable = disable
22
+ self.n = 0
23
+ # ignore unit, unit_scale, unit_divisor; they're just for real tqdm
24
+
25
+ def update(self, n):
26
+ if self.disable:
27
+ return
28
+
29
+ self.n += n
30
+ if self.total is None:
31
+ sys.stderr.write("\r{0:.1f} bytes".format(self.n))
32
+ else:
33
+ sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
34
+ sys.stderr.flush()
35
+
36
+ def __enter__(self):
37
+ return self
38
+
39
+ def __exit__(self, exc_type, exc_val, exc_tb):
40
+ if self.disable:
41
+ return
42
+
43
+ sys.stderr.write('\n')
44
+
45
+
46
+ def download_url_to_file(url, dst, hash_prefix=None, progress=True):
47
+ r"""Download object at the given URL to a local path.
48
+ Args:
49
+ url (string): URL of the object to download
50
+ dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
51
+ hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`.
52
+ Default: None
53
+ progress (bool, optional): whether or not to display a progress bar to stderr
54
+ Default: True
55
+ Example:
56
+ >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
57
+ """
58
+ file_size = None
59
+ # We use a different API for python2 since urllib(2) doesn't recognize the CA
60
+ # certificates in older Python
61
+ req = Request(url, headers={"User-Agent": "torch.hub"})
62
+ u = urlopen(req)
63
+ meta = u.info()
64
+ if hasattr(meta, 'getheaders'):
65
+ content_length = meta.getheaders("Content-Length")
66
+ else:
67
+ content_length = meta.get_all("Content-Length")
68
+ if content_length is not None and len(content_length) > 0:
69
+ file_size = int(content_length[0])
70
+
71
+ # We deliberately save it in a temp file and move it after
72
+ # download is complete. This prevents a local working checkpoint
73
+ # being overridden by a broken download.
74
+ dst = os.path.expanduser(dst)
75
+ dst_dir = os.path.dirname(dst)
76
+ f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
77
+
78
+ try:
79
+ if hash_prefix is not None:
80
+ sha256 = hashlib.sha256()
81
+ with tqdm(total=file_size, disable=not progress,
82
+ unit='B', unit_scale=True, unit_divisor=1024) as pbar:
83
+ while True:
84
+ buffer = u.read(8192)
85
+ if len(buffer) == 0:
86
+ break
87
+ f.write(buffer)
88
+ if hash_prefix is not None:
89
+ sha256.update(buffer)
90
+ pbar.update(len(buffer))
91
+
92
+ f.close()
93
+ if hash_prefix is not None:
94
+ digest = sha256.hexdigest()
95
+ if digest[:len(hash_prefix)] != hash_prefix:
96
+ raise RuntimeError('invalid hash value (expected "{}", got "{}")'
97
+ .format(hash_prefix, digest))
98
+ shutil.move(f.name, dst)
99
+ finally:
100
+ f.close()
101
+ if os.path.exists(f.name):
102
+ os.remove(f.name)
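
A short usage sketch of download_url_to_file as defined above (URL and destination are placeholders; hash verification is skipped when hash_prefix is None):

from models.utils.download import download_url_to_file

# hash_prefix, when given, must match the leading hex characters of the file's SHA256 digest.
download_url_to_file('https://example.com/checkpoints/model.pt', '/tmp/model.pt',
                     hash_prefix=None, progress=True)
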
models/utils/tensorflow2pytorch.py ADDED
@@ -0,0 +1,416 @@
1
+ import tensorflow as tf
2
+ import torch
3
+ import json
4
+ import os, sys
5
+
6
+ from dependencies.facenet.src import facenet
7
+ from dependencies.facenet.src.models import inception_resnet_v1 as tf_mdl
8
+ from dependencies.facenet.src.align import detect_face
9
+
10
+ from models.inception_resnet_v1 import InceptionResnetV1
11
+ from models.mtcnn import PNet, RNet, ONet
12
+
13
+
14
+ def import_tf_params(tf_mdl_dir, sess):
15
+ """Import tensorflow model from save directory.
16
+
17
+ Arguments:
18
+ tf_mdl_dir {str} -- Location of protobuf, checkpoint, meta files.
19
+ sess {tensorflow.Session} -- Tensorflow session object.
20
+
21
+ Returns:
22
+ (list, list, list) -- Tuple of lists containing the layer names,
23
+ parameter arrays as numpy ndarrays, parameter shapes.
24
+ """
25
+ print('\nLoading tensorflow model\n')
26
+ if callable(tf_mdl_dir):
27
+ tf_mdl_dir(sess)
28
+ else:
29
+ facenet.load_model(tf_mdl_dir)
30
+
31
+ print('\nGetting model weights\n')
32
+ tf_layers = tf.trainable_variables()
33
+ tf_params = sess.run(tf_layers)
34
+
35
+ tf_shapes = [p.shape for p in tf_params]
36
+ tf_layers = [l.name for l in tf_layers]
37
+
38
+ if not callable(tf_mdl_dir):
39
+ path = os.path.join(tf_mdl_dir, 'layer_description.json')
40
+ else:
41
+ path = 'data/layer_description.json'
42
+ with open(path, 'w') as f:
43
+ json.dump({l: s for l, s in zip(tf_layers, tf_shapes)}, f)
44
+
45
+ return tf_layers, tf_params, tf_shapes
46
+
47
+
48
+ def get_layer_indices(layer_lookup, tf_layers):
49
+ """Giving a lookup of model layer attribute names and tensorflow variable names,
50
+ find matching parameters.
51
+
52
+ Arguments:
53
+ layer_lookup {dict} -- Dictionary mapping pytorch attribute names to (partial)
54
+ tensorflow variable names. Expects dict of the form {'attr': ['tf_name', ...]}
55
+ where the '...'s are ignored.
56
+ tf_layers {list} -- List of tensorflow variable names.
57
+
58
+ Returns:
59
+ dict -- The input dictionary with the list of matching indices appended to each item.
60
+ """
61
+ layer_inds = {}
62
+ for name, value in layer_lookup.items():
63
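+ # append the indices of the tf variables whose names contain the lookup prefix (value[0])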
+ layer_inds[name] = value + [[i for i, n in enumerate(tf_layers) if value[0] in n]]
64
+ return layer_inds
65
+
66
+
67
+ def load_tf_batchNorm(weights, layer):
68
+ """Load tensorflow weights into nn.BatchNorm object.
69
+
70
+ Arguments:
71
+ weights {list} -- Tensorflow parameters.
72
+ layer {torch.nn.Module} -- nn.BatchNorm.
73
+ """
74
+ layer.bias.data = torch.tensor(weights[0]).view(layer.bias.data.shape)
75
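+ # the TF facenet batch norm layers use no scale (gamma) term, so the PyTorch weight stays at 1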
+ layer.weight.data = torch.ones_like(layer.weight.data)
76
+ layer.running_mean = torch.tensor(weights[1]).view(layer.running_mean.shape)
77
+ layer.running_var = torch.tensor(weights[2]).view(layer.running_var.shape)
78
+
79
+
80
+ def load_tf_conv2d(weights, layer, transpose=False):
81
+ """Load tensorflow weights into nn.Conv2d object.
82
+
83
+ Arguments:
84
+ weights {list} -- Tensorflow parameters.
85
+ layer {torch.nn.Module} -- nn.Conv2d.
86
+ """
87
+ if isinstance(weights, list):
88
+ if len(weights) == 2:
89
+ layer.bias.data = (
90
+ torch.tensor(weights[1])
91
+ .view(layer.bias.data.shape)
92
+ )
93
+ weights = weights[0]
94
+
95
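+ # TF stores conv kernels as (H, W, in, out); PyTorch expects (out, in, H, W). The
+ # transposed order is used for the MTCNN nets, whose TF reference runs on transposed images.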
+ if transpose:
96
+ dim_order = (3, 2, 1, 0)
97
+ else:
98
+ dim_order = (3, 2, 0, 1)
99
+
100
+ layer.weight.data = (
101
+ torch.tensor(weights)
102
+ .permute(dim_order)
103
+ .view(layer.weight.data.shape)
104
+ )
105
+
106
+
107
+ def load_tf_conv2d_trans(weights, layer):
108
+ return load_tf_conv2d(weights, layer, transpose=True)
109
+
110
+
111
+ def load_tf_basicConv2d(weights, layer):
112
+ """Load tensorflow weights into grouped Conv2d+BatchNorm object.
113
+
114
+ Arguments:
115
+ weights {list} -- Tensorflow parameters.
116
+ layer {torch.nn.Module} -- Object containing Conv2d+BatchNorm.
117
+ """
118
+ load_tf_conv2d(weights[0], layer.conv)
119
+ load_tf_batchNorm(weights[1:], layer.bn)
120
+
121
+
122
+ def load_tf_linear(weights, layer):
123
+ """Load tensorflow weights into nn.Linear object.
124
+
125
+ Arguments:
126
+ weights {list} -- Tensorflow parameters.
127
+ layer {torch.nn.Module} -- nn.Linear.
128
+ """
129
+ if isinstance(weights, list):
130
+ if len(weights) == 2:
131
+ layer.bias.data = (
132
+ torch.tensor(weights[1])
133
+ .view(layer.bias.data.shape)
134
+ )
135
+ weights = weights[0]
136
+ layer.weight.data = (
137
+ torch.tensor(weights)
138
+ .transpose(-1, 0)
139
+ .view(layer.weight.data.shape)
140
+ )
141
+
142
+
143
+ # High-level parameter-loading functions:
144
+
145
+ def load_tf_block35(weights, layer):
146
+ load_tf_basicConv2d(weights[:4], layer.branch0)
147
+ load_tf_basicConv2d(weights[4:8], layer.branch1[0])
148
+ load_tf_basicConv2d(weights[8:12], layer.branch1[1])
149
+ load_tf_basicConv2d(weights[12:16], layer.branch2[0])
150
+ load_tf_basicConv2d(weights[16:20], layer.branch2[1])
151
+ load_tf_basicConv2d(weights[20:24], layer.branch2[2])
152
+ load_tf_conv2d(weights[24:26], layer.conv2d)
153
+
154
+
155
+ def load_tf_block17_8(weights, layer):
156
+ load_tf_basicConv2d(weights[:4], layer.branch0)
157
+ load_tf_basicConv2d(weights[4:8], layer.branch1[0])
158
+ load_tf_basicConv2d(weights[8:12], layer.branch1[1])
159
+ load_tf_basicConv2d(weights[12:16], layer.branch1[2])
160
+ load_tf_conv2d(weights[16:18], layer.conv2d)
161
+
162
+
163
+ def load_tf_mixed6a(weights, layer):
164
+ if len(weights) != 16:
165
+ raise ValueError(f'Number of weight arrays ({len(weights)}) not equal to 16')
166
+ load_tf_basicConv2d(weights[:4], layer.branch0)
167
+ load_tf_basicConv2d(weights[4:8], layer.branch1[0])
168
+ load_tf_basicConv2d(weights[8:12], layer.branch1[1])
169
+ load_tf_basicConv2d(weights[12:16], layer.branch1[2])
170
+
171
+
172
+ def load_tf_mixed7a(weights, layer):
173
+ if len(weights) != 28:
174
+ raise ValueError(f'Number of weight arrays ({len(weights)}) not equal to 28')
175
+ load_tf_basicConv2d(weights[:4], layer.branch0[0])
176
+ load_tf_basicConv2d(weights[4:8], layer.branch0[1])
177
+ load_tf_basicConv2d(weights[8:12], layer.branch1[0])
178
+ load_tf_basicConv2d(weights[12:16], layer.branch1[1])
179
+ load_tf_basicConv2d(weights[16:20], layer.branch2[0])
180
+ load_tf_basicConv2d(weights[20:24], layer.branch2[1])
181
+ load_tf_basicConv2d(weights[24:28], layer.branch2[2])
182
+
183
+
184
+ def load_tf_repeats(weights, layer, rptlen, subfun):
185
+ if len(weights) % rptlen != 0:
186
+ raise ValueError(f'Number of weight arrays ({len(weights)}) not divisible by {rptlen}')
187
+ weights_split = [weights[i:i+rptlen] for i in range(0, len(weights), rptlen)]
188
+ for i, w in enumerate(weights_split):
189
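+ # repeat blocks are stored in an nn.Sequential, so children are addressed by their string index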
+ subfun(w, getattr(layer, str(i)))
190
+
191
+
192
+ def load_tf_repeat_1(weights, layer):
193
+ load_tf_repeats(weights, layer, 26, load_tf_block35)
194
+
195
+
196
+ def load_tf_repeat_2(weights, layer):
197
+ load_tf_repeats(weights, layer, 18, load_tf_block17_8)
198
+
199
+
200
+ def load_tf_repeat_3(weights, layer):
201
+ load_tf_repeats(weights, layer, 18, load_tf_block17_8)
202
+
203
+
204
+ def test_loaded_params(mdl, tf_params, tf_layers):
205
+ """Check each parameter in a pytorch model for an equivalent parameter
206
+ in a list of tensorflow variables.
207
+
208
+ Arguments:
209
+ mdl {torch.nn.Module} -- Pytorch model.
210
+ tf_params {list} -- List of ndarrays representing tensorflow variables.
211
+ tf_layers {list} -- Corresponding list of tensorflow variable names.
212
+ """
213
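+ # heuristic check: match each pytorch parameter to tf variables with (nearly) the same mean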
+ tf_means = torch.stack([torch.tensor(p).mean() for p in tf_params])
214
+ for name, param in mdl.named_parameters():
215
+ pt_mean = param.data.mean()
216
+ matching_inds = ((tf_means - pt_mean).abs() < 1e-8).nonzero()
217
+ print(f'{name} equivalent to {[tf_layers[i] for i in matching_inds]}')
218
+
219
+
220
+ def compare_model_outputs(pt_mdl, sess, test_data):
221
+ """Given some testing data, compare the output of pytorch and tensorflow models.
222
+
223
+ Arguments:
224
+ pt_mdl {torch.nn.Module} -- Pytorch model.
225
+ sess {tensorflow.Session} -- Tensorflow session object.
226
+ test_data {torch.Tensor} -- Pytorch tensor.
227
+ """
228
+ print('\nPassing test data through TF model\n')
229
+ if isinstance(sess, tf.Session):
230
+ images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0")
231
+ phase_train_placeholder = tf.get_default_graph().get_tensor_by_name("phase_train:0")
232
+ embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")
233
+ feed_dict = {images_placeholder: test_data.numpy(), phase_train_placeholder: False}
234
+ tf_output = torch.tensor(sess.run(embeddings, feed_dict=feed_dict))
235
+ else:
236
+ tf_output = sess(test_data)
237
+
238
+ print(tf_output)
239
+
240
+ print('\nPassing test data through PT model\n')
241
+ pt_output = pt_mdl(test_data.permute(0, 3, 1, 2))
242
+ print(pt_output)
243
+
244
+ distance = (tf_output - pt_output).norm()
245
+ print(f'\nDistance {distance}\n')
246
+
247
+
248
+ def compare_mtcnn(pt_mdl, tf_fun, sess, ind, test_data):
249
+ tf_mdls = tf_fun(sess)
250
+ tf_mdl = tf_mdls[ind]
251
+
252
+ print('\nPassing test data through TF model\n')
253
+ tf_output = tf_mdl(test_data.numpy())
254
+ tf_output = [torch.tensor(out) for out in tf_output]
255
+ print('\n'.join([str(o.view(-1)[:10]) for o in tf_output]))
256
+
257
+ print('\nPassing test data through PT model\n')
258
+ with torch.no_grad():
259
+ pt_output = pt_mdl(test_data.permute(0, 3, 2, 1))
260
+ pt_output = [torch.tensor(out) for out in pt_output]
261
+ for i in range(len(pt_output)):
262
+ if len(pt_output[i].shape) == 4:
263
+ pt_output[i] = pt_output[i].permute(0, 3, 2, 1).contiguous()
264
+ print('\n'.join([str(o.view(-1)[:10]) for o in pt_output]))
265
+
266
+ distance = [(tf_o - pt_o).norm() for tf_o, pt_o in zip(tf_output, pt_output)]
267
+ print(f'\nDistance {distance}\n')
268
+
269
+
270
+ def load_tf_model_weights(mdl, layer_lookup, tf_mdl_dir, is_resnet=True, arg_num=None):
271
+ """Load tensorflow parameters into a pytorch model.
272
+
273
+ Arguments:
274
+ mdl {torch.nn.Module} -- Pytorch model.
275
+ layer_lookup {dict} -- Dictionary mapping pytorch attribute names to (partial)
276
+ tensorflow variable names, and a function suitable for loading weights.
277
+ Expects dict of the form {'attr': ['tf_name', function]}.
278
+ tf_mdl_dir {str} -- Location of protobuf, checkpoint, meta files.
279
+ """
280
+ tf.reset_default_graph()
281
+ with tf.Session() as sess:
282
+ tf_layers, tf_params, tf_shapes = import_tf_params(tf_mdl_dir, sess)
283
+ layer_info = get_layer_indices(layer_lookup, tf_layers)
284
+
285
+ for layer_name, info in layer_info.items():
286
+ print(f'Loading {info[0]}/* into {layer_name}')
287
+ weights = [tf_params[i] for i in info[2]]
288
+ layer = getattr(mdl, layer_name)
289
+ info[1](weights, layer)
290
+
291
+ test_loaded_params(mdl, tf_params, tf_layers)
292
+
293
+ if is_resnet:
294
+ compare_model_outputs(mdl, sess, torch.randn(5, 160, 160, 3).detach())
295
+
296
+
297
+ def tensorflow2pytorch():
298
+ lookup_inception_resnet_v1 = {
299
+ 'conv2d_1a': ['InceptionResnetV1/Conv2d_1a_3x3', load_tf_basicConv2d],
300
+ 'conv2d_2a': ['InceptionResnetV1/Conv2d_2a_3x3', load_tf_basicConv2d],
301
+ 'conv2d_2b': ['InceptionResnetV1/Conv2d_2b_3x3', load_tf_basicConv2d],
302
+ 'conv2d_3b': ['InceptionResnetV1/Conv2d_3b_1x1', load_tf_basicConv2d],
303
+ 'conv2d_4a': ['InceptionResnetV1/Conv2d_4a_3x3', load_tf_basicConv2d],
304
+ 'conv2d_4b': ['InceptionResnetV1/Conv2d_4b_3x3', load_tf_basicConv2d],
305
+ 'repeat_1': ['InceptionResnetV1/Repeat/block35', load_tf_repeat_1],
306
+ 'mixed_6a': ['InceptionResnetV1/Mixed_6a', load_tf_mixed6a],
307
+ 'repeat_2': ['InceptionResnetV1/Repeat_1/block17', load_tf_repeat_2],
308
+ 'mixed_7a': ['InceptionResnetV1/Mixed_7a', load_tf_mixed7a],
309
+ 'repeat_3': ['InceptionResnetV1/Repeat_2/block8', load_tf_repeat_3],
310
+ 'block8': ['InceptionResnetV1/Block8', load_tf_block17_8],
311
+ 'last_linear': ['InceptionResnetV1/Bottleneck/weights', load_tf_linear],
312
+ 'last_bn': ['InceptionResnetV1/Bottleneck/BatchNorm', load_tf_batchNorm],
313
+ 'logits': ['Logits', load_tf_linear],
314
+ }
315
+
316
+ print('\nLoad VGGFace2-trained weights and save\n')
317
+ mdl = InceptionResnetV1(num_classes=8631).eval()
318
+ tf_mdl_dir = 'data/20180402-114759'
319
+ data_name = 'vggface2'
320
+ load_tf_model_weights(mdl, lookup_inception_resnet_v1, tf_mdl_dir)
321
+ state_dict = mdl.state_dict()
322
+ torch.save(state_dict, f'{tf_mdl_dir}-{data_name}.pt')
323
+ torch.save(
324
+ {
325
+ 'logits.weight': state_dict['logits.weight'],
326
+ 'logits.bias': state_dict['logits.bias'],
327
+ },
328
+ f'{tf_mdl_dir}-{data_name}-logits.pt'
329
+ )
330
+ state_dict.pop('logits.weight')
331
+ state_dict.pop('logits.bias')
332
+ torch.save(state_dict, f'{tf_mdl_dir}-{data_name}-features.pt')
333
+
334
+ print('\nLoad CASIA-Webface-trained weights and save\n')
335
+ mdl = InceptionResnetV1(num_classes=10575).eval()
336
+ tf_mdl_dir = 'data/20180408-102900'
337
+ data_name = 'casia-webface'
338
+ load_tf_model_weights(mdl, lookup_inception_resnet_v1, tf_mdl_dir)
339
+ state_dict = mdl.state_dict()
340
+ torch.save(state_dict, f'{tf_mdl_dir}-{data_name}.pt')
341
+ torch.save(
342
+ {
343
+ 'logits.weight': state_dict['logits.weight'],
344
+ 'logits.bias': state_dict['logits.bias'],
345
+ },
346
+ f'{tf_mdl_dir}-{data_name}-logits.pt'
347
+ )
348
+ state_dict.pop('logits.weight')
349
+ state_dict.pop('logits.bias')
350
+ torch.save(state_dict, f'{tf_mdl_dir}-{data_name}-features.pt')
351
+
352
+ lookup_pnet = {
353
+ 'conv1': ['pnet/conv1', load_tf_conv2d_trans],
354
+ 'prelu1': ['pnet/PReLU1', load_tf_linear],
355
+ 'conv2': ['pnet/conv2', load_tf_conv2d_trans],
356
+ 'prelu2': ['pnet/PReLU2', load_tf_linear],
357
+ 'conv3': ['pnet/conv3', load_tf_conv2d_trans],
358
+ 'prelu3': ['pnet/PReLU3', load_tf_linear],
359
+ 'conv4_1': ['pnet/conv4-1', load_tf_conv2d_trans],
360
+ 'conv4_2': ['pnet/conv4-2', load_tf_conv2d_trans],
361
+ }
362
+ lookup_rnet = {
363
+ 'conv1': ['rnet/conv1', load_tf_conv2d_trans],
364
+ 'prelu1': ['rnet/prelu1', load_tf_linear],
365
+ 'conv2': ['rnet/conv2', load_tf_conv2d_trans],
366
+ 'prelu2': ['rnet/prelu2', load_tf_linear],
367
+ 'conv3': ['rnet/conv3', load_tf_conv2d_trans],
368
+ 'prelu3': ['rnet/prelu3', load_tf_linear],
369
+ 'dense4': ['rnet/conv4', load_tf_linear],
370
+ 'prelu4': ['rnet/prelu4', load_tf_linear],
371
+ 'dense5_1': ['rnet/conv5-1', load_tf_linear],
372
+ 'dense5_2': ['rnet/conv5-2', load_tf_linear],
373
+ }
374
+ lookup_onet = {
375
+ 'conv1': ['onet/conv1', load_tf_conv2d_trans],
376
+ 'prelu1': ['onet/prelu1', load_tf_linear],
377
+ 'conv2': ['onet/conv2', load_tf_conv2d_trans],
378
+ 'prelu2': ['onet/prelu2', load_tf_linear],
379
+ 'conv3': ['onet/conv3', load_tf_conv2d_trans],
380
+ 'prelu3': ['onet/prelu3', load_tf_linear],
381
+ 'conv4': ['onet/conv4', load_tf_conv2d_trans],
382
+ 'prelu4': ['onet/prelu4', load_tf_linear],
383
+ 'dense5': ['onet/conv5', load_tf_linear],
384
+ 'prelu5': ['onet/prelu5', load_tf_linear],
385
+ 'dense6_1': ['onet/conv6-1', load_tf_linear],
386
+ 'dense6_2': ['onet/conv6-2', load_tf_linear],
387
+ 'dense6_3': ['onet/conv6-3', load_tf_linear],
388
+ }
389
+
390
+ print('\nLoad PNet weights and save\n')
391
+ tf_mdl_dir = lambda sess: detect_face.create_mtcnn(sess, None)
392
+ mdl = PNet()
393
+ data_name = 'pnet'
394
+ load_tf_model_weights(mdl, lookup_pnet, tf_mdl_dir, is_resnet=False, arg_num=0)
395
+ torch.save(mdl.state_dict(), f'data/{data_name}.pt')
396
+ tf.reset_default_graph()
397
+ with tf.Session() as sess:
398
+ compare_mtcnn(mdl, tf_mdl_dir, sess, 0, torch.randn(1, 256, 256, 3).detach())
399
+
400
+ print('\nLoad RNet weights and save\n')
401
+ mdl = RNet()
402
+ data_name = 'rnet'
403
+ load_tf_model_weights(mdl, lookup_rnet, tf_mdl_dir, is_resnet=False, arg_num=1)
404
+ torch.save(mdl.state_dict(), f'data/{data_name}.pt')
405
+ tf.reset_default_graph()
406
+ with tf.Session() as sess:
407
+ compare_mtcnn(mdl, tf_mdl_dir, sess, 1, torch.randn(1, 24, 24, 3).detach())
408
+
409
+ print('\nLoad ONet weights and save\n')
410
+ mdl = ONet()
411
+ data_name = 'onet'
412
+ load_tf_model_weights(mdl, lookup_onet, tf_mdl_dir, is_resnet=False, arg_num=2)
413
+ torch.save(mdl.state_dict(), f'data/{data_name}.pt')
414
+ tf.reset_default_graph()
415
+ with tf.Session() as sess:
416
+ compare_mtcnn(mdl, tf_mdl_dir, sess, 2, torch.randn(1, 48, 48, 3).detach())
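
The conversion itself is a single call, but the script assumes a legacy environment: TensorFlow 1.x, the davidsandberg/facenet sources importable as dependencies.facenet, and the pretrained TF checkpoints unpacked at the data/ paths hard-coded above. A sketch under those assumptions:

from models.utils.tensorflow2pytorch import tensorflow2pytorch

# Converts the VGGFace2 and CASIA-Webface Inception-ResNet checkpoints plus the MTCNN
# P/R/O-Nets, saving PyTorch state dicts next to the TF checkpoints and under data/.
tensorflow2pytorch()
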
models/utils/training.py ADDED
@@ -0,0 +1,144 @@
1
+ import torch
2
+ import numpy as np
3
+ import time
4
+
5
+
6
+ class Logger(object):
7
+
8
+ def __init__(self, mode, length, calculate_mean=False):
9
+ self.mode = mode
10
+ self.length = length
11
+ self.calculate_mean = calculate_mean
12
+ if self.calculate_mean:
13
+ self.fn = lambda x, i: x / (i + 1)
14
+ else:
15
+ self.fn = lambda x, i: x
16
+
17
+ def __call__(self, loss, metrics, i):
18
+ track_str = '\r{} | {:5d}/{:<5d}| '.format(self.mode, i + 1, self.length)
19
+ loss_str = 'loss: {:9.4f} | '.format(self.fn(loss, i))
20
+ metric_str = ' | '.join('{}: {:9.4f}'.format(k, self.fn(v, i)) for k, v in metrics.items())
21
+ print(track_str + loss_str + metric_str + ' ', end='')
22
+ if i + 1 == self.length:
23
+ print('')
24
+
25
+
26
+ class BatchTimer(object):
27
+ """Batch timing class.
28
+ Use this class for tracking training and testing time/rate per batch or per sample.
29
+
30
+ Keyword Arguments:
31
+ rate {bool} -- Whether to report a rate (batches or samples per second) or a time (seconds
32
+ per batch or sample). (default: {True})
33
+ per_sample {bool} -- Whether to report times or rates per sample (True) or per batch (False).
34
+ (default: {True})
35
+ """
36
+
37
+ def __init__(self, rate=True, per_sample=True):
38
+ self.start = time.time()
39
+ self.end = None
40
+ self.rate = rate
41
+ self.per_sample = per_sample
42
+
43
+ def __call__(self, y_pred, y):
44
+ self.end = time.time()
45
+ elapsed = self.end - self.start
46
+ self.start = self.end
47
+ self.end = None
48
+
49
+ if self.per_sample:
50
+ elapsed /= len(y_pred)
51
+ if self.rate:
52
+ elapsed = 1 / elapsed
53
+
54
+ return torch.tensor(elapsed)
55
+
56
+
57
+ def accuracy(logits, y):
58
+ _, preds = torch.max(logits, 1)
59
+ return (preds == y).float().mean()
60
+
61
+
62
+ def pass_epoch(
63
+ model, loss_fn, loader, optimizer=None, scheduler=None,
64
+ batch_metrics={'time': BatchTimer()}, show_running=True,
65
+ device='cpu', writer=None
66
+ ):
67
+ """Train or evaluate over a data epoch.
68
+
69
+ Arguments:
70
+ model {torch.nn.Module} -- Pytorch model.
71
+ loss_fn {callable} -- A function to compute (scalar) loss.
72
+ loader {torch.utils.data.DataLoader} -- A pytorch data loader.
73
+
74
+ Keyword Arguments:
75
+ optimizer {torch.optim.Optimizer} -- A pytorch optimizer.
76
+ scheduler {torch.optim.lr_scheduler._LRScheduler} -- LR scheduler (default: {None})
77
+ batch_metrics {dict} -- Dictionary of metric functions to call on each batch. The default
78
+ is a simple timer. A progressive average of these metrics, along with the average
79
+ loss, is printed every batch. (default: {{'time': BatchTimer()}})
80
+ show_running {bool} -- Whether or not to print losses and metrics for the current batch
81
+ or rolling averages. (default: {False})
82
+ device {str or torch.device} -- Device for pytorch to use. (default: {'cpu'})
83
+ writer {torch.utils.tensorboard.SummaryWriter} -- Tensorboard SummaryWriter. (default: {None})
84
+
85
+ Returns:
86
+ tuple(torch.Tensor, dict) -- A tuple of the average loss and a dictionary of average
87
+ metric values across the epoch.
88
+ """
89
+
90
+ mode = 'Train' if model.training else 'Valid'
91
+ logger = Logger(mode, length=len(loader), calculate_mean=show_running)
92
+ loss = 0
93
+ metrics = {}
94
+
95
+ for i_batch, (x, y) in enumerate(loader):
96
+ x = x.to(device)
97
+ y = y.to(device)
98
+ y_pred = model(x)
99
+ loss_batch = loss_fn(y_pred, y)
100
+
101
+ if model.training:
102
+ loss_batch.backward()
103
+ optimizer.step()
104
+ optimizer.zero_grad()
105
+
106
+ metrics_batch = {}
107
+ for metric_name, metric_fn in batch_metrics.items():
108
+ metrics_batch[metric_name] = metric_fn(y_pred, y).detach().cpu()
109
+ metrics[metric_name] = metrics.get(metric_name, 0) + metrics_batch[metric_name]
110
+
111
+ if writer is not None and model.training:
112
+ if writer.iteration % writer.interval == 0:
113
+ writer.add_scalars('loss', {mode: loss_batch.detach().cpu()}, writer.iteration)
114
+ for metric_name, metric_batch in metrics_batch.items():
115
+ writer.add_scalars(metric_name, {mode: metric_batch}, writer.iteration)
116
+ writer.iteration += 1
117
+
118
+ loss_batch = loss_batch.detach().cpu()
119
+ loss += loss_batch
120
+ if show_running:
121
+ logger(loss, metrics, i_batch)
122
+ else:
123
+ logger(loss_batch, metrics_batch, i_batch)
124
+
125
+ if model.training and scheduler is not None:
126
+ scheduler.step()
127
+
128
+ loss = loss / (i_batch + 1)
129
+ metrics = {k: v / (i_batch + 1) for k, v in metrics.items()}
130
+
131
+ if writer is not None and not model.training:
132
+ writer.add_scalars('loss', {mode: loss.detach()}, writer.iteration)
133
+ for metric_name, metric in metrics.items():
134
+ writer.add_scalars(metric_name, {mode: metric})
135
+
136
+ return loss, metrics
137
+
138
+
139
+ def collate_pil(x):
140
+ out_x, out_y = [], []
141
+ for xx, yy in x:
142
+ out_x.append(xx)
143
+ out_y.append(yy)
144
+ return out_x, out_y
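
A minimal training-loop sketch built on pass_epoch. The dataset is a random stand-in, InceptionResnetV1's constructor arguments are assumed from this repository, and pass_epoch reads iteration and interval attributes off the SummaryWriter, which the caller must set:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

from models.inception_resnet_v1 import InceptionResnetV1
from models.utils.training import pass_epoch, accuracy, BatchTimer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = InceptionResnetV1(classify=True, num_classes=10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
metrics = {'acc': accuracy, 'fps': BatchTimer()}

# Random stand-in data: 32 RGB crops of 160x160 with labels in [0, 10).
dataset = TensorDataset(torch.randn(32, 3, 160, 160), torch.randint(0, 10, (32,)))
loader = DataLoader(dataset, batch_size=8)

writer = SummaryWriter()
writer.iteration, writer.interval = 0, 10  # attributes pass_epoch expects

model.train()
loss, epoch_metrics = pass_epoch(
    model, loss_fn, loader, optimizer=optimizer,
    batch_metrics=metrics, show_running=True, device=device, writer=writer,
)
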