alexnasa committed on
Commit cf92dec · verified · 1 Parent(s): f90ed46

Upload 66 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the remaining files.
Files changed (50)
  1. .gitattributes +8 -0
  2. .gitignore +21 -0
  3. LICENSE +407 -0
  4. assets/flame_mirror_index.npy +3 -0
  5. assets/flame_uv_coords.npy +3 -0
  6. assets/flame_vertex_weights.npy +3 -0
  7. assets/head_template.obj +0 -0
  8. assets/head_template_color.obj +0 -0
  9. assets/head_template_color_tex.png +3 -0
  10. assets/test_rigid.ply +3 -0
  11. assets/uv_mask_eyes.png +0 -0
  12. assets/uv_valid_verty.npy +3 -0
  13. assets/uv_valid_verty_noEyes.npy +3 -0
  14. assets/uv_valid_verty_noEyes_debug.npy +3 -0
  15. assets/uv_valid_verty_noEyes_noEyeRegion_debug_wEars.npy +3 -0
  16. bin/release.sh +21 -0
  17. configs/base.yaml +234 -0
  18. configs/tracking.yaml +110 -0
  19. environment.yml +76 -0
  20. example_videos/ex1.mp4 +3 -0
  21. example_videos/ex2.mp4 +3 -0
  22. example_videos/ex3.mp4 +3 -0
  23. example_videos/ex4.mp4 +3 -0
  24. example_videos/ex5.mp4 +3 -0
  25. install_preprocessing_pipeline.sh +42 -0
  26. media/banner.gif +3 -0
  27. pyproject.toml +35 -0
  28. requirements.txt +26 -0
  29. scripts/.gitkeep +0 -0
  30. scripts/network_inference.py +229 -0
  31. scripts/run_cropping.py +107 -0
  32. scripts/run_facer_segmentation.py +221 -0
  33. scripts/run_preprocessing.py +23 -0
  34. scripts/track.py +27 -0
  35. scripts/viz_head_centric_cameras.py +103 -0
  36. setup.py +7 -0
  37. src/__init__.py +0 -0
  38. src/pixel3dmm/__init__.py +0 -0
  39. src/pixel3dmm/env_paths.py +34 -0
  40. src/pixel3dmm/lightning/p3dmm_network.py +0 -0
  41. src/pixel3dmm/lightning/p3dmm_system.py +491 -0
  42. src/pixel3dmm/lightning/utils.py +119 -0
  43. src/pixel3dmm/preprocessing/__init__.py +0 -0
  44. src/pixel3dmm/preprocessing/pipnet_utils.py +348 -0
  45. src/pixel3dmm/preprocessing/replacement_code/__init__.py +0 -0
  46. src/pixel3dmm/preprocessing/replacement_code/facer_transform.py +397 -0
  47. src/pixel3dmm/preprocessing/replacement_code/farl.py +94 -0
  48. src/pixel3dmm/preprocessing/replacement_code/mica.py +120 -0
  49. src/pixel3dmm/preprocessing/replacement_code/mica_demo.py +188 -0
  50. src/pixel3dmm/preprocessing/replacement_code/pipnet_demo.py +401 -0
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/head_template_color_tex.png filter=lfs diff=lfs merge=lfs -text
+ assets/test_rigid.ply filter=lfs diff=lfs merge=lfs -text
+ example_videos/ex1.mp4 filter=lfs diff=lfs merge=lfs -text
+ example_videos/ex2.mp4 filter=lfs diff=lfs merge=lfs -text
+ example_videos/ex3.mp4 filter=lfs diff=lfs merge=lfs -text
+ example_videos/ex4.mp4 filter=lfs diff=lfs merge=lfs -text
+ example_videos/ex5.mp4 filter=lfs diff=lfs merge=lfs -text
+ media/banner.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,21 @@
+ # Python cache
+ __pycache__
+ *.py[cod]
+
+ # PyCharm/Jupyter
+ .idea
+ .ipynb_checkpoints
+
+ # build artifacts
+ *.egg-info
+ /docs/build/
+ dist
+
+
+ /src/pixel3dmm/preprocessing/facer/
+ /src/pixel3dmm/preprocessing/MICA/
+ /src/pixel3dmm/preprocessing/PIPNet/
+ /pretrained_weights/
+ /assets/
+
+ *.mp4
LICENSE ADDED
@@ -0,0 +1,407 @@
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. Copyright and Similar Rights means copyright and/or similar rights
88
+ closely related to copyright including, without limitation,
89
+ performance, broadcast, sound recording, and Sui Generis Database
90
+ Rights, without regard to how the rights are labeled or
91
+ categorized. For purposes of this Public License, the rights
92
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
93
+ Rights.
94
+ d. Effective Technological Measures means those measures that, in the
95
+ absence of proper authority, may not be circumvented under laws
96
+ fulfilling obligations under Article 11 of the WIPO Copyright
97
+ Treaty adopted on December 20, 1996, and/or similar international
98
+ agreements.
99
+
100
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
101
+ any other exception or limitation to Copyright and Similar Rights
102
+ that applies to Your use of the Licensed Material.
103
+
104
+ f. Licensed Material means the artistic or literary work, database,
105
+ or other material to which the Licensor applied this Public
106
+ License.
107
+
108
+ g. Licensed Rights means the rights granted to You subject to the
109
+ terms and conditions of this Public License, which are limited to
110
+ all Copyright and Similar Rights that apply to Your use of the
111
+ Licensed Material and that the Licensor has authority to license.
112
+
113
+ h. Licensor means the individual(s) or entity(ies) granting rights
114
+ under this Public License.
115
+
116
+ i. NonCommercial means not primarily intended for or directed towards
117
+ commercial advantage or monetary compensation. For purposes of
118
+ this Public License, the exchange of the Licensed Material for
119
+ other material subject to Copyright and Similar Rights by digital
120
+ file-sharing or similar means is NonCommercial provided there is
121
+ no payment of monetary compensation in connection with the
122
+ exchange.
123
+
124
+ j. Share means to provide material to the public by any means or
125
+ process that requires permission under the Licensed Rights, such
126
+ as reproduction, public display, public performance, distribution,
127
+ dissemination, communication, or importation, and to make material
128
+ available to the public including in ways that members of the
129
+ public may access the material from a place and at a time
130
+ individually chosen by them.
131
+
132
+ k. Sui Generis Database Rights means rights other than copyright
133
+ resulting from Directive 96/9/EC of the European Parliament and of
134
+ the Council of 11 March 1996 on the legal protection of databases,
135
+ as amended and/or succeeded, as well as other essentially
136
+ equivalent rights anywhere in the world.
137
+
138
+ l. You means the individual or entity exercising the Licensed Rights
139
+ under this Public License. Your has a corresponding meaning.
140
+
141
+
142
+ Section 2 -- Scope.
143
+
144
+ a. License grant.
145
+
146
+ 1. Subject to the terms and conditions of this Public License,
147
+ the Licensor hereby grants You a worldwide, royalty-free,
148
+ non-sublicensable, non-exclusive, irrevocable license to
149
+ exercise the Licensed Rights in the Licensed Material to:
150
+
151
+ a. reproduce and Share the Licensed Material, in whole or
152
+ in part, for NonCommercial purposes only; and
153
+
154
+ b. produce, reproduce, and Share Adapted Material for
155
+ NonCommercial purposes only.
156
+
157
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
158
+ Exceptions and Limitations apply to Your use, this Public
159
+ License does not apply, and You do not need to comply with
160
+ its terms and conditions.
161
+
162
+ 3. Term. The term of this Public License is specified in Section
163
+ 6(a).
164
+
165
+ 4. Media and formats; technical modifications allowed. The
166
+ Licensor authorizes You to exercise the Licensed Rights in
167
+ all media and formats whether now known or hereafter created,
168
+ and to make technical modifications necessary to do so. The
169
+ Licensor waives and/or agrees not to assert any right or
170
+ authority to forbid You from making technical modifications
171
+ necessary to exercise the Licensed Rights, including
172
+ technical modifications necessary to circumvent Effective
173
+ Technological Measures. For purposes of this Public License,
174
+ simply making modifications authorized by this Section 2(a)
175
+ (4) never produces Adapted Material.
176
+
177
+ 5. Downstream recipients.
178
+
179
+ a. Offer from the Licensor -- Licensed Material. Every
180
+ recipient of the Licensed Material automatically
181
+ receives an offer from the Licensor to exercise the
182
+ Licensed Rights under the terms and conditions of this
183
+ Public License.
184
+
185
+ b. No downstream restrictions. You may not offer or impose
186
+ any additional or different terms or conditions on, or
187
+ apply any Effective Technological Measures to, the
188
+ Licensed Material if doing so restricts exercise of the
189
+ Licensed Rights by any recipient of the Licensed
190
+ Material.
191
+
192
+ 6. No endorsement. Nothing in this Public License constitutes or
193
+ may be construed as permission to assert or imply that You
194
+ are, or that Your use of the Licensed Material is, connected
195
+ with, or sponsored, endorsed, or granted official status by,
196
+ the Licensor or others designated to receive attribution as
197
+ provided in Section 3(a)(1)(A)(i).
198
+
199
+ b. Other rights.
200
+
201
+ 1. Moral rights, such as the right of integrity, are not
202
+ licensed under this Public License, nor are publicity,
203
+ privacy, and/or other similar personality rights; however, to
204
+ the extent possible, the Licensor waives and/or agrees not to
205
+ assert any such rights held by the Licensor to the limited
206
+ extent necessary to allow You to exercise the Licensed
207
+ Rights, but not otherwise.
208
+
209
+ 2. Patent and trademark rights are not licensed under this
210
+ Public License.
211
+
212
+ 3. To the extent possible, the Licensor waives any right to
213
+ collect royalties from You for the exercise of the Licensed
214
+ Rights, whether directly or through a collecting society
215
+ under any voluntary or waivable statutory or compulsory
216
+ licensing scheme. In all other cases the Licensor expressly
217
+ reserves any right to collect such royalties, including when
218
+ the Licensed Material is used other than for NonCommercial
219
+ purposes.
220
+
221
+
222
+ Section 3 -- License Conditions.
223
+
224
+ Your exercise of the Licensed Rights is expressly made subject to the
225
+ following conditions.
226
+
227
+ a. Attribution.
228
+
229
+ 1. If You Share the Licensed Material (including in modified
230
+ form), You must:
231
+
232
+ a. retain the following if it is supplied by the Licensor
233
+ with the Licensed Material:
234
+
235
+ i. identification of the creator(s) of the Licensed
236
+ Material and any others designated to receive
237
+ attribution, in any reasonable manner requested by
238
+ the Licensor (including by pseudonym if
239
+ designated);
240
+
241
+ ii. a copyright notice;
242
+
243
+ iii. a notice that refers to this Public License;
244
+
245
+ iv. a notice that refers to the disclaimer of
246
+ warranties;
247
+
248
+ v. a URI or hyperlink to the Licensed Material to the
249
+ extent reasonably practicable;
250
+
251
+ b. indicate if You modified the Licensed Material and
252
+ retain an indication of any previous modifications; and
253
+
254
+ c. indicate the Licensed Material is licensed under this
255
+ Public License, and include the text of, or the URI or
256
+ hyperlink to, this Public License.
257
+
258
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
259
+ reasonable manner based on the medium, means, and context in
260
+ which You Share the Licensed Material. For example, it may be
261
+ reasonable to satisfy the conditions by providing a URI or
262
+ hyperlink to a resource that includes the required
263
+ information.
264
+
265
+ 3. If requested by the Licensor, You must remove any of the
266
+ information required by Section 3(a)(1)(A) to the extent
267
+ reasonably practicable.
268
+
269
+ 4. If You Share Adapted Material You produce, the Adapter's
270
+ License You apply must not prevent recipients of the Adapted
271
+ Material from complying with this Public License.
272
+
273
+
274
+ Section 4 -- Sui Generis Database Rights.
275
+
276
+ Where the Licensed Rights include Sui Generis Database Rights that
277
+ apply to Your use of the Licensed Material:
278
+
279
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280
+ to extract, reuse, reproduce, and Share all or a substantial
281
+ portion of the contents of the database for NonCommercial purposes
282
+ only;
283
+
284
+ b. if You include all or a substantial portion of the database
285
+ contents in a database in which You have Sui Generis Database
286
+ Rights, then the database in which You have Sui Generis Database
287
+ Rights (but not its individual contents) is Adapted Material; and
288
+
289
+ c. You must comply with the conditions in Section 3(a) if You Share
290
+ all or a substantial portion of the contents of the database.
291
+
292
+ For the avoidance of doubt, this Section 4 supplements and does not
293
+ replace Your obligations under this Public License where the Licensed
294
+ Rights include other Copyright and Similar Rights.
295
+
296
+
297
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298
+
299
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309
+
310
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319
+
320
+ c. The disclaimer of warranties and limitation of liability provided
321
+ above shall be interpreted in a manner that, to the extent
322
+ possible, most closely approximates an absolute disclaimer and
323
+ waiver of all liability.
324
+
325
+
326
+ Section 6 -- Term and Termination.
327
+
328
+ a. This Public License applies for the term of the Copyright and
329
+ Similar Rights licensed here. However, if You fail to comply with
330
+ this Public License, then Your rights under this Public License
331
+ terminate automatically.
332
+
333
+ b. Where Your right to use the Licensed Material has terminated under
334
+ Section 6(a), it reinstates:
335
+
336
+ 1. automatically as of the date the violation is cured, provided
337
+ it is cured within 30 days of Your discovery of the
338
+ violation; or
339
+
340
+ 2. upon express reinstatement by the Licensor.
341
+
342
+ For the avoidance of doubt, this Section 6(b) does not affect any
343
+ right the Licensor may have to seek remedies for Your violations
344
+ of this Public License.
345
+
346
+ c. For the avoidance of doubt, the Licensor may also offer the
347
+ Licensed Material under separate terms or conditions or stop
348
+ distributing the Licensed Material at any time; however, doing so
349
+ will not terminate this Public License.
350
+
351
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352
+ License.
353
+
354
+
355
+ Section 7 -- Other Terms and Conditions.
356
+
357
+ a. The Licensor shall not be bound by any additional or different
358
+ terms or conditions communicated by You unless expressly agreed.
359
+
360
+ b. Any arrangements, understandings, or agreements regarding the
361
+ Licensed Material not stated herein are separate from and
362
+ independent of the terms and conditions of this Public License.
363
+
364
+
365
+ Section 8 -- Interpretation.
366
+
367
+ a. For the avoidance of doubt, this Public License does not, and
368
+ shall not be interpreted to, reduce, limit, restrict, or impose
369
+ conditions on any use of the Licensed Material that could lawfully
370
+ be made without permission under this Public License.
371
+
372
+ b. To the extent possible, if any provision of this Public License is
373
+ deemed unenforceable, it shall be automatically reformed to the
374
+ minimum extent necessary to make it enforceable. If the provision
375
+ cannot be reformed, it shall be severed from this Public License
376
+ without affecting the enforceability of the remaining terms and
377
+ conditions.
378
+
379
+ c. No term or condition of this Public License will be waived and no
380
+ failure to comply consented to unless expressly agreed to by the
381
+ Licensor.
382
+
383
+ d. Nothing in this Public License constitutes or may be interpreted
384
+ as a limitation upon, or waiver of, any privileges and immunities
385
+ that apply to the Licensor or You, including from the legal
386
+ processes of any jurisdiction or authority.
387
+
388
+ =======================================================================
389
+
390
+ Creative Commons is not a party to its public
391
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
392
+ its public licenses to material it publishes and in those instances
393
+ will be considered the "Licensor." The text of the Creative Commons
394
+ public licenses is dedicated to the public domain under the CC0 Public
395
+ Domain Dedication. Except for the limited purpose of indicating that
396
+ material is shared under a Creative Commons public license or as
397
+ otherwise permitted by the Creative Commons policies published at
398
+ creativecommons.org/policies, Creative Commons does not authorize the
399
+ use of the trademark "Creative Commons" or any other trademark or logo
400
+ of Creative Commons without its prior written consent including,
401
+ without limitation, in connection with any unauthorized modifications
402
+ to any of its public licenses or any other arrangements,
403
+ understandings, or agreements concerning use of licensed material. For
404
+ the avoidance of doubt, this paragraph does not form part of the
405
+ public licenses.
406
+
407
+ Creative Commons may be contacted at creativecommons.org.
assets/flame_mirror_index.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:374636351ce484f36a3524af05cb00ed5a4f4a38ab759ea9bffce6e3fb761153
+ size 40312
assets/flame_uv_coords.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:194adb9a1c1bc76a85c4e8c446d5dfaca93f0f42f157eb82659457c057c71527
+ size 80496
assets/flame_vertex_weights.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7e860e97be4209f3ee7351a7359cb01b9870d56fc93be3b76cdd7e657eb9346d
+ size 60404
assets/head_template.obj ADDED
The diff for this file is too large to render. See raw diff
 
assets/head_template_color.obj ADDED
The diff for this file is too large to render. See raw diff
 
assets/head_template_color_tex.png ADDED

Git LFS Details

  • SHA256: da8bb9232572dbc286aae59cc56be1a8ac4d5c97b58c0275eda8b7b63e2828bf
  • Pointer size: 131 Bytes
  • Size of remote file: 192 kB
assets/test_rigid.ply ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:427556b014087cf1aea6bbcddec815dbcc5c55a3334de1786a362068e716cfe6
+ size 210354
assets/uv_mask_eyes.png ADDED
assets/uv_valid_verty.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:84f36cdf5ebe1cb4db88f3126ef82c5a59a48ca45ea396ffa23e4b50ac0ce06b
+ size 14424
assets/uv_valid_verty_noEyes.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7428896b5b1d4c9db6c0c1e2a4e98c412e95cb62243e5463311cb714c8f3820c
+ size 9832
assets/uv_valid_verty_noEyes_debug.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bea26d06186de1f51d0d9687e4eb723a816911fe224c46ff328c0664f7e68bd5
+ size 15096
assets/uv_valid_verty_noEyes_noEyeRegion_debug_wEars.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:379a44c27fcc3e05818588f88073d6dd53a48c3ef11bb56afe7d03657f256fb6
+ size 19912
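The .npy assets above are committed as Git LFS pointers (only an oid and a size); the actual arrays arrive with `git lfs pull`. Below is a minimal Python sketch of inspecting them afterwards — the file names come from this commit, but shapes and dtypes are not guaranteed here.

import numpy as np

# Print basic metadata for the FLAME-related assets added above (assumes `git lfs pull` has run).
for name in [
    "assets/flame_mirror_index.npy",
    "assets/flame_uv_coords.npy",
    "assets/flame_vertex_weights.npy",
    "assets/uv_valid_verty.npy",
]:
    arr = np.load(name)
    print(name, arr.shape, arr.dtype)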
bin/release.sh ADDED
@@ -0,0 +1,21 @@
+ #!/bin/bash
+
+ while [[ "$#" -gt 0 ]]; do
+     case $1 in
+         -t|--test) test=1 ;;
+         *) echo "Unknown parameter: $1"; exit 1 ;;
+     esac
+     shift
+ done
+
+ if [[ $test ]]; then
+     twine_params="--repository testpypi"
+ else
+     twine_params=""
+ fi
+
+ cd "${0%/*}/.."
+ rm -r dist/*
+ python -m build
+ twine upload $twine_params dist/*
+ # Username: tobias.kirschstein
configs/base.yaml ADDED
@@ -0,0 +1,234 @@
1
+ gpu_id: [0] #[4,5,6,7]
2
+
3
+ exp_name: LaRa/release-test-head-cluster
4
+
5
+ n_views: 2
6
+
7
+ reconstruction_folder: recs
8
+
9
+ flame_folder : /home/giebenhain/face-tracking
10
+ flame_folder_assets : /home/giebenhain/face-tracking/flame/
11
+ flame_base_mesh: /home/giebenhain/PycharmProjects/non-rigid-registration/template/test_rigid.ply
12
+ exp_tag : _
13
+
14
+ viz_uv_mesh : False
15
+
16
+ model:
17
+ attn_drop : 0.2
18
+ model_type: 'flame_params'
19
+ prediction_type: 'normals'
20
+ network_type: 'transformer'
21
+ encoder_backbone: 'vit_small_patch16_224.dino' #'vit_base_patch16_224.dino' #'vit_base_patch16_224.dino' # ['vit_small_patch16_224.dino','vit_base_patch16_224.dino']
22
+
23
+ n_groups: [16] # n_groups for local attention
24
+ n_offset_groups: 32 # offset radius of 1/n_offset_groups of the scene size
25
+
26
+ K: 2 # primitives per-voxel
27
+ sh_degree: 1 # view dependent color
28
+
29
+ num_layers: 6 #6 #12
30
+ num_heads: 8 #16
31
+
32
+ view_embed_dim: 16 #32
33
+ embedding_dim: 256 #128 #256
34
+
35
+ vol_feat_reso: 16
36
+ vol_embedding_reso: 32
37
+
38
+ vol_embedding_out_dim: 40 #80
39
+
40
+ ckpt_path: null # specify a ckpt path if you want to continue training
41
+
42
+ flame_dim: 101
43
+
44
+ finetune_backbone: False
45
+
46
+ feature_map_type: DINO
47
+
48
+ pred_conf: False
49
+
50
+ pred_disentangled : False
51
+
52
+ nocs : True
53
+
54
+
55
+
56
+ use_pos_enc : False
57
+
58
+ conv_dec : True
59
+
60
+ use_plucker : False
61
+ use_uv_enc : True
62
+
63
+ n_facial_components : 0
64
+
65
+ render_super : False
66
+
67
+ flame_shape_dim : 300
68
+ flame_expr_dim : 100
69
+
70
+ prior_input : False
71
+ use_neutral : True
72
+
73
+ reg_inner : True
74
+ n_inner_steps : 0 #20
75
+
76
+ corresp_align : False
77
+
78
+ pred_dim : 4
79
+
80
+ outer_vertex_mask : False
81
+
82
+ downsample_inps : False
83
+
84
+ flame2020 : True
85
+ use_mica : False
86
+
87
+ branched : True
88
+
89
+
90
+
91
+ train_dataset:
92
+ dataset_name: gobjeverse
93
+ data_root: /mnt/rohan/cluster/andram/sgiebenhain/objaverse_imposter8_cropped_prepped_00.hdf5 #cluster/andram/sgiebenhain/objaverse_imposter3_prepped_00.hdf5 #/mnt/hdd/dataset/gobjaverse/gobjaverse_part_01.h5 #/mnt/rohan /home/giebenhain/proj4/objaverse_imposter2_prepped_00.hdf5 #dataset/gobjaverse/gobjaverse.h5
94
+
95
+ split: train
96
+ img_size: [512,512] # image resolution
97
+ n_group: ${n_views} # image resolution
98
+ n_scenes: 3000000
99
+
100
+ itl:
101
+ lr_expr: 0.1
102
+ lr_id: 0.05
103
+ lr_cam_pos: 0.0001 #0.005
104
+ lr_cam_rot: 0.001 #0.01
105
+ lr_fl: 0.01 #0.03
106
+ lr_pp: 0.00001 #0.002
107
+ lr_jaw : 0.0001
108
+
109
+ lr_expr_outer : 0.00001
110
+ lr_shape_outer : 0.00001
111
+ lr_cam_pos_outer : 0.000001
112
+ lr_cam_rot_outer : 0.000001
113
+ lr_fl_outer : 0.000001
114
+ lr_pp_outer : 0.000001
115
+
116
+ noise_strenght : 0.5
117
+
118
+ ffwd_init : True
119
+ ffwd_init_flame : True
120
+ ffwd_flame_weight : 0.01
121
+
122
+ scale_reg_id : 10
123
+ scale_reg_ex : 10
124
+ scale_confidence : 10
125
+
126
+ n_steps_cam : 0
127
+
128
+ use_uv : True
129
+ use_n : True
130
+ use_ncan : False
131
+ use_disp : False
132
+
133
+ reg_conf: 0.01
134
+ totvar_conf : 1.0
135
+
136
+ uv_loss_mult : 3
137
+ n_loss_mult : 0.0
138
+
139
+ const_conf : False
140
+ uv_l2 : False
141
+ n_mask_new : False
142
+
143
+ reg_shape: 0.01
144
+ reg_shape_ffwd: 0.01
145
+ reg_expr: 0.01
146
+ reg_expr_ffwd: 0.01
147
+
148
+ rnd_warmup : False
149
+ use_outer_normals : True
150
+ normal_inp : True
151
+ rnd_n_inner : False
152
+ n_inner_min : 20
153
+ n_inner_max : 100
154
+ fov_mult : 1.0
155
+
156
+ outer_l2 : True
157
+
158
+ pred_face_region : False
159
+ sup_back_more : True
160
+
161
+
162
+ data:
163
+ load_normal: False
164
+ load_flame: False
165
+ load_uv : False
166
+ load_pos_map : False
167
+ load_depth : False
168
+ load_verts : False
169
+ load_arcface : False
170
+ load_albedo : False
171
+ load_nocs : False
172
+ mirror_aug : False
173
+ disable_aug: False
174
+ disable_color_aug: False
175
+ use_nphm : True
176
+ use_ava : True
177
+ use_facescape : True
178
+ use_celeba : False
179
+ use_lyhm : True
180
+ use_stirling : True
181
+ use_video : False
182
+ use_cafca : True
183
+ use_now : False
184
+ use_mimicme : True
185
+
186
+ add_occ : False
187
+
188
+ use_p3dmm : True
189
+
190
+
191
+ load_consist : False
192
+ load_prior : False
193
+
194
+ overfit : False
195
+
196
+ more_verts: False
197
+
198
+ load_facer: False
199
+
200
+
201
+
202
+ test_dataset:
203
+ dataset_name: gobjeverse
204
+ data_root: /mnt/rohan/cluster/andram/sgiebenhain/objaverse_imposter8_cropped_prepped_00.hdf5 #cluster/andram/sgiebenhain/objaverse_imposter3_prepped_00.hdf5 #/mnt/hdd/dataset/gobjaverse/gobjaverse_part_01.h5 #/mnt/rohan /home/giebenhain/proj4/objaverse_imposter2_prepped_00.hdf5 #dataset/gobjaverse/gobjaverse.h5
205
+
206
+ split: test
207
+ img_size: [512,512]
208
+ n_group: ${n_views}
209
+ n_scenes: 3000000
210
+
211
+ train:
212
+ batch_size: 8 #3
213
+ lr: 4e-4 #1e-2 #4e-4
214
+ lr_backbone: 1e-5 #4e-4
215
+ beta1: 0.9
216
+ beta2: 0.95 #0.95
217
+ weight_decay: 0.05
218
+ warmup_iters: 200
219
+ n_epoch: 3000 #3000
220
+ limit_train_batches: 0.05 #0.2 #1.0 #0.1 #1.0 #0.2
221
+ limit_val_batches: 0.02 #0.05 #1 #0.02
222
+ check_val_every_n_epoch: 1
223
+ start_fine: 5000
224
+ use_rand_views: False
225
+ duster_loss: False
226
+ start_2d_vertex_loss : 500 #2500
227
+ start_normal_render_loss : 1000 #5000
228
+
229
+ test:
230
+ batch_size: 8 #3
231
+
232
+ logger:
233
+ name: wandb #tensorboard
234
+ dir: logs/${exp_name}
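configs/base.yaml is consumed through OmegaConf; scripts/network_inference.py further down in this diff loads it and merges command-line overrides on top. A minimal sketch of that pattern, assuming it is run from the repository root:

from omegaconf import OmegaConf

# Load the base config and let CLI overrides such as `model.prediction_type=normals` win,
# mirroring the __main__ block of scripts/network_inference.py.
base_conf = OmegaConf.load("configs/base.yaml")
cli_conf = OmegaConf.from_cli()
cfg = OmegaConf.merge(base_conf, cli_conf)
print(cfg.model.prediction_type, cfg.train.lr)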
configs/tracking.yaml ADDED
@@ -0,0 +1,110 @@
1
+ config_name : test
2
+
3
+ batch_size : 16
4
+ num_views : 1
5
+ size : 256
6
+ image_size : [256, 256] # use this instead of hardcoding a bunch of 512s and 256s
7
+
8
+ data_folder : none
9
+ p3dmm_folder : none
10
+
11
+ extra_cam_steps : False
12
+ big_normal_mask : False
13
+
14
+ start_frame : 0
15
+
16
+ num_shape_params : 300
17
+ num_exp_params : 100
18
+ tex_params : 140
19
+ iters : 200 #800 #400
20
+
21
+ no_lm : False
22
+ use_eyebrows : False
23
+ use_mouth_lmk : True
24
+ no_pho : True
25
+ no_sh : True
26
+ disable_edge : False
27
+
28
+ keyframes : []
29
+
30
+ ignore_mica : False
31
+ flame2023 : False
32
+
33
+ uv_map_super : 2000.0 #500.0 #100 #2001.0 #5000.0 #2000.0
34
+ normal_super : 1000.0 #202.0
35
+ normal_super_can : 0.0
36
+ sil_super : 500
37
+
38
+
39
+ uv_loss:
40
+ stricter_uv_mask : False
41
+ delta_uv : 0.00005 #0.00005 #0.0005 #0.00005
42
+ delta_uv_fine : 0.00005 #0.00005 #0.0005 #0.00005
43
+ dist_uv : 20 #20 #15
44
+ dist_uv_fine : 20 #35 #20 #15
45
+
46
+
47
+ occ_filter : True
48
+
49
+
50
+ lr_id : 0.002 #0.003 #0.006 #0.003
51
+ lr_exp : 0.005 # 0.005 #0.01 #0.01 #0.005
52
+ lr_jaw : 0.005 #0.003
53
+ lr_neck : 0.001 #0.0005
54
+ lr_R : 0.005 #0.005 #0.002 #0.01# 0.0001
55
+ lr_t : 0.001 #0.002 #0.001 #0.0005 #0.0005
56
+ lr_f : 0.1 #0.05 #0.01 #0.001
57
+ lr_pp : 0.00005
58
+
59
+ w_pho : 150
60
+ w_lmks : 3000
61
+ w_lmks_mouth : 1000
62
+ w_lmks_68 : 1000
63
+ w_lmks_lid : 1000
64
+ w_lmks_iris : 1000
65
+ w_lmks_oval : 2000
66
+ w_lmks_star : 0
67
+
68
+ include_neck : True
69
+
70
+ w_shape : 0.2
71
+ w_shape_general : 0.05
72
+ w_exp : 0.05
73
+ w_jaw : 0.01
74
+ w_neck : 0.1
75
+
76
+ n_fine : False
77
+ low_overhead : False
78
+
79
+ delta_n : 0.33
80
+
81
+ global_camera : True
82
+ global_iters : 5000
83
+
84
+ reg_smooth_exp : 50.0
85
+ reg_smooth_eyes : 10.0
86
+ reg_smooth_eyelids : 2.0
87
+ reg_smooth_jaw : 50.0
88
+ reg_smooth_neck : 1000.0
89
+ reg_smooth_R : 2000.0
90
+ reg_smooth_t : 15200.0
91
+ reg_smooth_pp : 420.0
92
+ reg_smooth_fl : 420.0
93
+
94
+ reg_smooth_mult : 1.0
95
+
96
+ uv_l2 : True
97
+ normal_l2 : False
98
+ smooth : True
99
+ normal_mask_ksize : 13
100
+
101
+ early_stopping_delta : 5.0
102
+
103
+ early_exit : False
104
+
105
+ draw_uv_corresp : False
106
+
107
+ save_landmarks : False
108
+
109
+ save_meshes : True
110
+ delete_preprocessing : False
environment.yml ADDED
@@ -0,0 +1,76 @@
1
+ # Note: conda dependencies have only 1 "="
2
+ # pip dependencies have 2 "=="
3
+ # Fuse pip dependencies together under one " - pip" item
4
+ # Otherwise, only some of them are installed, because conda creates a temporary requirements.txt file
5
+ # from only the last "- pip:" section
6
+
7
+ name: p3dmm
8
+
9
+ channels:
10
+ - pytorch
11
+ - conda-forge
12
+ - defaults
13
+
14
+ dependencies:
15
+ - python=3.9
16
+ - pip
17
+ - jupyter
18
+
19
+ # CUDA and PyTorch
20
+ - gcc<12 # Needs to be <12 because nvcc does not like gcc>11
21
+ - gxx
22
+ - torchvision
23
+
24
+ -
25
+ - nvidia/label/cuda-11.8.0::cuda-nvcc # for nvcc
26
+ - nvidia/label/cuda-11.8.0::cuda-cccl
27
+ - nvidia/label/cuda-11.8.0::cuda-cudart
28
+ - nvidia/label/cuda-11.8.0::cuda-cudart-dev # for cuda_runtime.h
29
+ - nvidia/label/cuda-11.8.0::libcusparse
30
+ - nvidia/label/cuda-11.8.0::libcusparse-dev
31
+ - nvidia/label/cuda-11.8.0::libcublas
32
+ - nvidia/label/cuda-11.8.0::libcublas-dev
33
+ - nvidia/label/cuda-11.8.0::libcurand
34
+ - nvidia/label/cuda-11.8.0::libcurand-dev
35
+ - nvidia/label/cuda-11.8.0::libcusolver
36
+ - nvidia/label/cuda-11.8.0::libcusolver-dev
37
+ - pip:
38
+
39
+ - pip:
40
+ - --extra-index-url https://download.pytorch.org/whl/cu118
41
+ - torch==2.7+cu118
42
+ - torchvision==0.22+cu118
43
+ - tyro
44
+ - environs
45
+
46
+ - omegaconf
47
+ - dreifus
48
+ - wandb
49
+ - pytorch_lightning
50
+ - opencv-python
51
+ - tensorboard
52
+ - wandb
53
+ - scikit-image
54
+ - pyvista
55
+ - chumpy
56
+ - h5py
57
+ - einops
58
+ - ninja
59
+ - mediapy
60
+ - face-alignment==1.3.3
61
+ - numpy==1.23
62
+
63
+
64
+ - git+https://github.com/facebookresearch/pytorch3d.git@stable
65
+ - git+https://github.com/NVlabs/nvdiffrast.git
66
+
67
+ # for MICA
68
+ - insightface
69
+ - onnxruntime
70
+ - loguru
71
+ - yacs
72
+
73
+ # facer
74
+ - distinctipy
75
+ - validators
76
+ - timm
example_videos/ex1.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9bcf05f5d3ff2dabaad3ec3562b1ea463bdc2324ffa1cb5875f4468f5341e5f4
+ size 662545
example_videos/ex2.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:71f5bc28eb0bc3fb23dfe4079e303c382e1036b25553c12a8dda208b5ebb9a44
+ size 822778
example_videos/ex3.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0e5127eb860778a01b0b33ff0a5760f604a29232f1cdd695fdc8499300d607a6
+ size 326767
example_videos/ex4.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d8f71ee7d60490725cb463b9da247c2b3d08f9d01a8dbd566726b599cee53199
+ size 375763
example_videos/ex5.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:05a0c8807a31740243d9a1e5ae34f3fd4990701202ffae256e33e70e1f5fa5a9
+ size 587737
install_preprocessing_pipeline.sh ADDED
@@ -0,0 +1,42 @@
1
+ #!/bin/bash
2
+
3
+ cd src/pixel3dmm/preprocessing/
4
+
5
+ # facer repository
6
+ git clone [email protected]:FacePerceiver/facer.git
7
+ cd facer
8
+ cp ../replacement_code/farl.py facer/face_parsing/farl.py
9
+ cp ../replacement_code/facer_transform.py facer/transform.py
10
+ pip install -e .
11
+ cd ..
12
+
13
+ # MICA
14
+ git clone [email protected]:Zielon/MICA.git
15
+ cd MICA
16
+ cp ../replacement_code/mica_demo.py demo.py
17
+ cp ../replacement_code/mica.py micalib/models/mica.py
18
+ ./install.sh
19
+ cd ..
20
+
21
+ #TODO: Maybe need to copy these flame weights to trackign/FLAME as well, or ideally adjust some paths instead
22
+
23
+
24
+ # PIPnet
25
+ git clone https://github.com/jhb86253817/PIPNet.git
26
+ cd PIPNet
27
+ cd FaceBoxesV2/utils
28
+ sh make.sh
29
+ cd ../..
30
+ mkdir snapshots
31
+ mkdir snapshots/WFLW/
32
+ mkdir snapshots/WFLW/pip_32_16_60_r18_l2_l1_10_1_nb10/
33
+ gdown --id 1nVkaSbxy3NeqblwMTGvLg4nF49cI_99C -O snapshots/WFLW/pip_32_16_60_r18_l2_l1_10_1_nb10/epoch59.pth
34
+ #mkdir snapshots/WFLW/pip_32_16_60_r101_l2_l1_10_1_nb10/
35
+ #gdown --id 1Jb97z5Z5ca61-6W2RDOK0e2w_RlbeWgS -O snapshots/WFLW/pip_32_16_60_r101_l2_l1_10_1_nb10/epoch59.pth
36
+
37
+
38
+ cd ../../../../
39
+ mkdir pretrained_weights
40
+ cd pretrained_weights
41
+ gdown --id 1SDV_8_qWTe__rX_8e4Fi-BE3aES0YzJY -O ./uv.ckpt
42
+ gdown --id 1KYYlpN-KGrYMVcAOT22NkVQC0UAfycMD -O ./normals.ckpt
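The two gdown calls at the end fetch the pretrained UV and normal prediction checkpoints into pretrained_weights/; scripts/network_inference.py below loads them through env_paths. A minimal sketch of that loading step, assuming the environment paths resolve to the downloaded files:

from pixel3dmm import env_paths
from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system

# Load the UV predictor downloaded above; use env_paths.CKPT_N_PRED for the normals network instead.
model = p3dmm_system.load_from_checkpoint(env_paths.CKPT_UV_PRED, strict=False).cuda()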
media/banner.gif ADDED

Git LFS Details

  • SHA256: 7a8efa82d3b64240743c3b5870f04bce8def66e8ee2021d315dfa649f6837ae2
  • Pointer size: 132 Bytes
  • Size of remote file: 3.13 MB
pyproject.toml ADDED
@@ -0,0 +1,35 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "pixel3dmm" # DON'T FORGET TO REMOVE empty FROM git remote!!!
7
+ version = "0.0.1"
8
+ description = "<<<ENTER_DESCRIPTION>>>"
9
+ authors = [
10
+ { name = "Simon Giebenhain", email = "[email protected]" },
11
+ ]
12
+ readme = "README.md"
13
+ license = { text = "CC BY-NC 4.0" }
14
+ requires-python = ">=3.9.0"
15
+ classifiers = [
16
+ "Development Status :: 3 - Alpha",
17
+ "Programming Language :: Python :: 3",
18
+ "Operating System :: OS Independent"
19
+ ]
20
+ # urls = { Documentation = "<<<ENTER_LINK_TO_DOCUMENTATION>>>" }
21
+ # Main dependencies
22
+ dependencies = [
23
+ ]
24
+
25
+ [project.optional-dependencies]
26
+ # Development packages, install via <<<PROJECT_NAME>>>[dev]
27
+ dev = [
28
+ ]
29
+
30
+ [project.scripts]
31
+ # E.g., ns-download-data = "scripts.downloads.download_data:entrypoint"
32
+
33
+ [tool.setuptools.packages.find]
34
+ where = ["src"]
35
+ include = ["pixel3dmm*"] # Keep the '*', otherwise submodules are not found
requirements.txt ADDED
@@ -0,0 +1,26 @@
1
+ numpy==1.23
2
+ omegaconf
3
+ opencv-python
4
+ tensorboard
5
+ wandb
6
+ scikit-image
7
+ pyvista
8
+ dreifus
9
+ chumpy
10
+ h5py
11
+ pytorch_lightning
12
+ einops
13
+ mediapy
14
+ face-alignment==1.3.3
15
+ ninja
16
+
17
+ insightface
18
+ onnxruntime
19
+ loguru
20
+ yacs
21
+
22
+ distinctipy
23
+ validators
24
+ timm
25
+ tyro
26
+ environs
scripts/.gitkeep ADDED
File without changes
scripts/network_inference.py ADDED
@@ -0,0 +1,229 @@
1
+ import traceback
2
+
3
+ from tqdm import tqdm
4
+ import os
5
+ import torch
6
+ import numpy as np
7
+ from PIL import Image
8
+ from omegaconf import OmegaConf
9
+ from time import time
10
+
11
+ from pixel3dmm.utils.uv import uv_pred_to_mesh
12
+ from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system
13
+ #from pixel3dmm.lightning.system_flame_params_legacy import system as system_flame_params_legacy
14
+ from pixel3dmm import env_paths
15
+
16
+
17
+
18
+ def pad_to_3_channels(img):
19
+ if img.shape[-1] == 3:
20
+ return img
21
+ elif img.shape[-1] == 1:
22
+ return np.concatenate([img, np.zeros_like(img[..., :1]), np.zeros_like(img[..., :1])], axis=-1)
23
+ elif img.shape[-1] == 2:
24
+ return np.concatenate([img, np.zeros_like(img[..., :1])], axis=-1)
25
+ else:
26
+ raise ValueError('too many dimensions in prediction type!')
27
+
28
+ def gaussian_fn(M, std):
29
+ n = torch.arange(0, M) - (M - 1.0) / 2.0
30
+ sig2 = 2 * std * std
31
+ w = torch.exp(-n ** 2 / sig2)
32
+ return w
33
+
34
+ def gkern(kernlen=256, std=128):
35
+ """Returns a 2D Gaussian kernel array."""
36
+ gkern1d_x = gaussian_fn(kernlen, std=std * 5)
37
+ gkern1d_y = gaussian_fn(kernlen, std=std)
38
+ gkern2d = torch.outer(gkern1d_y, gkern1d_x)
39
+ return gkern2d
40
+
41
+
42
+ valid_verts = np.load(f'{env_paths.VALID_VERTICES_WIDE_REGION}')
43
+
44
+ def main(cfg):
45
+
46
+ if cfg.model.prediction_type == 'flame_params':
47
+ cfg.data.mirror_aug = False
48
+
49
+ # data loader
50
+ if cfg.model.feature_map_type == 'DINO':
51
+ feature_map_size = 32
52
+ elif cfg.model.feature_map_type == 'sapiens':
53
+ feature_map_size = 64
54
+
55
+ batch_size = 1 #cfg.inference_batch_size
56
+
57
+ checkpoints = {
58
+ 'uv_map': f"{env_paths.CKPT_UV_PRED}",
59
+ 'normals': f"{env_paths.CKPT_N_PRED}",
60
+ }
61
+
62
+
63
+ model_checkpoint = checkpoints[cfg.model.prediction_type]
64
+
65
+ model = None
66
+
67
+
68
+ prediction_types = cfg.model.prediction_type.split(',')
69
+
70
+
71
+ conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=11, bias=False, padding='same')
72
+ g_weights = gkern(11, 2)
73
+ g_weights /= torch.sum(g_weights)
74
+ conv.weight = torch.nn.Parameter(g_weights.unsqueeze(0).unsqueeze(0))
75
+
76
+ OUT_NAMES = str(cfg.video_name).split(',')
77
+
78
+ print(f'''
79
+ <<<<<<<< STARTING PIXEL3DMM INFERENCE for {cfg.video_name} in {prediction_types} MODE >>>>>>>>
80
+ ''')
81
+
82
+ for OUT_NAME in OUT_NAMES:
83
+ folder = f'{env_paths.PREPROCESSED_DATA}/{OUT_NAME}/'
84
+ IMAGE_FOLDER = f'{folder}/cropped'
85
+ SEGMENTATION_FOLDER = f'{folder}/seg_og/'
86
+
87
+ out_folders = {}
88
+ out_folders_wGT = {}
89
+ out_folders_viz = {}
90
+
91
+ for prediction_type in prediction_types:
92
+ out_folders[prediction_type] = f'{env_paths.PREPROCESSED_DATA}/{OUT_NAME}/p3dmm/{prediction_type}/'
93
+ out_folders_wGT[prediction_type] = f'{env_paths.PREPROCESSED_DATA}/{OUT_NAME}/p3dmm_wGT/{prediction_type}/'
94
+ os.makedirs(out_folders[prediction_type], exist_ok=True)
95
+ os.makedirs(out_folders_wGT[prediction_type], exist_ok=True)
96
+ out_folders_viz[prediction_type] = f'{env_paths.PREPROCESSED_DATA}/{OUT_NAME}/p3dmm_extraViz/{prediction_type}/'
97
+ os.makedirs(out_folders_viz[prediction_type], exist_ok=True)
98
+
99
+
100
+ image_names = os.listdir(f'{IMAGE_FOLDER}')
101
+ image_names.sort()
102
+
103
+ if os.path.exists(out_folders[prediction_type]):
104
+ if len(os.listdir(out_folders[prediction_type])) == len(image_names):
105
+ return
106
+
107
+ if model is None:
108
+ model = p3dmm_system.load_from_checkpoint(model_checkpoint, strict=False)
109
+ # TODO: disable randomness, dropout, etc...
110
+ # model.eval()
111
+ model = model.cuda()
112
+
113
+
114
+
115
+ for i in tqdm(range(len(image_names))):
116
+ #if not int(image_names[i].split('_')[0]) in [17, 175, 226, 279]:
117
+ # continue
118
+ try:
119
+
120
+ for i_batch in range(batch_size):
121
+ img = np.array(Image.open(f'{IMAGE_FOLDER}/{image_names[i]}').resize((512, 512))) / 255 # need 512,512 images as input; normalize to [0, 1] range
122
+ img = torch.from_numpy(img)[None, None].float().cuda() # 1,1,512,512,3
123
+ img_seg = np.array(Image.open(f'{SEGMENTATION_FOLDER}/{image_names[i][:-4]}.png').resize((512, 512), Image.NEAREST))
124
+ if len(img_seg.shape) == 3:
125
+ img_seg = img_seg[..., 0]
126
+ #img_seg = np.array(Image.open(f'{SEGEMNTATION_FOLDER}/{int(image_names[i][:-4])*3:05d}.png').resize((512, 512), Image.NEAREST))
127
+ mask = ((img_seg == 2) | ((img_seg > 3) & (img_seg < 14)) ) & ~(img_seg==11)
128
+ mask = torch.from_numpy(mask).long().cuda()[None, None] # 1, 1, 512, 512
129
+ #mask = torch.ones_like(img[..., 0]).cuda().bool()
130
+ batch = {
131
+ 'tar_msk': mask,
132
+ 'tar_rgb': img,
133
+ }
134
+ batch_mirrored = {
135
+ 'tar_rgb': torch.flip(batch['tar_rgb'], dims=[3]).cuda(),
136
+ 'tar_msk': torch.flip(batch['tar_msk'], dims=[3]).cuda(),
137
+ }
138
+
139
+
140
+ # execute model twice, once with original image, once with mirrored original image,
141
+ # and then average results after undoing the mirroring operation on the prediction
142
+ with torch.no_grad():
143
+ output, conf = model.net(batch)
144
+ output_mirrored, conf = model.net(batch_mirrored)
145
+
146
+ if 'uv_map' in output:
147
+ fliped_uv_pred = torch.flip(output_mirrored['uv_map'], dims=[4])
148
+ fliped_uv_pred[:, :, 0, :, :] *= -1
149
+ fliped_uv_pred[:, :, 0, :, :] += 2*0.0075
150
+ output['uv_map'] = (output['uv_map'] + fliped_uv_pred)/2
151
+ if 'normals' in output:
152
+ fliped_uv_pred = torch.flip(output_mirrored['normals'], dims=[4])
153
+ fliped_uv_pred[:, :, 0, :, :] *= -1
154
+ output['normals'] = (output['normals'] + fliped_uv_pred)/2
155
+ if 'disps' in output:
156
+ fliped_uv_pred = torch.flip(output_mirrored['disps'], dims=[4])
157
+ fliped_uv_pred[:, :, 0, :, :] *= -1
158
+ output['disps'] = (output['disps'] + fliped_uv_pred)/2
159
+
160
+
161
+
162
+ for prediction_type in prediction_types:
163
+ for i_batch in range(batch_size):
164
+
165
+ i_view = 0
166
+ gt_rgb = batch['tar_rgb']
167
+
168
+ # normalize to [0,1] range
169
+ if prediction_type == 'uv_map':
170
+ tmp_output = torch.clamp((output[prediction_type][i_batch, i_view] + 1) / 2, 0, 1)
171
+ elif prediction_type == 'disps':
172
+ tmp_output = torch.clamp((output[prediction_type][i_batch, i_view] + 50) / 100, 0, 1)
173
+ elif prediction_type in ['normals', 'normals_can']:
174
+ tmp_output = output[prediction_type][i_batch, i_view]
175
+ tmp_output = tmp_output / torch.norm(tmp_output, dim=0).unsqueeze(0)
176
+ tmp_output = torch.clamp((tmp_output + 1) / 2, 0, 1)
177
+ # undo "weird" convention of normals that I used for preprocessing
178
+ tmp_output = torch.stack(
179
+ [tmp_output[0, ...], 1 - tmp_output[2, ...], 1 - tmp_output[1, ...]],
180
+ dim=0)
181
+
182
+
183
+ content = [
184
+ gt_rgb[i_batch, i_view].detach().cpu().numpy(),
185
+ pad_to_3_channels(tmp_output.permute(1, 2, 0).detach().cpu().float().numpy()),
186
+ ]
187
+
188
+ catted = (np.concatenate(content, axis=1) * 255).astype(np.uint8)
189
+ Image.fromarray(catted).save(f'{out_folders_wGT[prediction_type]}/{image_names[i]}')
190
+
191
+
192
+ Image.fromarray(
193
+ pad_to_3_channels(
194
+ tmp_output.permute(1, 2, 0).detach().cpu().float().numpy() * 255).astype(
195
+ np.uint8)).save(
196
+ f'{out_folders[prediction_type]}/{image_names[i][:-4]}.png')
197
+
198
+
199
+ # this visualization is quite slow, therefore it is disabled by default
200
+ if prediction_type == 'uv_map' and cfg.viz_uv_mesh:
201
+ to_show_non_mirr = uv_pred_to_mesh(
202
+ output[prediction_type][i_batch:i_batch + 1, ...],
203
+ batch['tar_msk'][i_batch:i_batch + 1, ...],
204
+ batch['tar_rgb'][i_batch:i_batch + 1, ...],
205
+ right_ear = [537, 1334, 857, 554, 941],
206
+ left_ear = [541, 476, 237, 502, 286],
207
+ )
208
+
209
+ Image.fromarray(to_show_non_mirr).save(f'{out_folders_viz[prediction_type]}/{image_names[i]}')
210
+
211
+ except Exception as exx:
212
+ traceback.print_exc()
213
+ pass
214
+
215
+ print(f'''
216
+ <<<<<<<< FINISHED PIXEL3DMM INFERENCE for {cfg.video_name} in {prediction_types} MODE >>>>>>>>
217
+ ''')
218
+
219
+
220
+
221
+
222
+
223
+ if __name__ == '__main__':
224
+ base_conf = OmegaConf.load(f'{env_paths.CODE_BASE}/configs/base.yaml')
225
+
226
+ cli_conf = OmegaConf.from_cli()
227
+ cfg = OmegaConf.merge(base_conf, cli_conf)
228
+
229
+ main(cfg)
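The inference loop above is a flip-averaging scheme: the network runs on the frame and on its horizontal mirror, the mirrored prediction is flipped back with its x channel negated (for UV and normal maps), and the two results are averaged. A minimal numpy sketch of the idea for a normal map; `predict` is a placeholder for the network call, not an API of this repository:

import numpy as np

def flip_average_normals(img, predict):
    # predict(img) is assumed to return an (H, W, 3) normal map aligned with its input.
    normals = predict(img)
    mirrored = predict(img[:, ::-1])     # run on the horizontally mirrored image
    mirrored = mirrored[:, ::-1].copy()  # undo the mirroring spatially
    mirrored[..., 0] *= -1               # the x component changes sign under a horizontal flip
    return 0.5 * (normals + mirrored)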
scripts/run_cropping.py ADDED
@@ -0,0 +1,107 @@
1
+ import traceback
2
+ import os
3
+ import sys
4
+ import importlib
5
+
6
+ import mediapy
7
+ from PIL import Image
8
+ import tyro
9
+
10
+ import torchvision.transforms as transforms
11
+
12
+
13
+ from pixel3dmm import env_paths
14
+ sys.path.append(f'{env_paths.CODE_BASE}/src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/')
15
+ from pixel3dmm.preprocessing.pipnet_utils import demo_image
16
+ from pixel3dmm import env_paths
17
+
18
+
19
+
20
+
21
+ def run(exp_path, image_dir, start_frame = 0,
22
+ vertical_crop : bool = False,
23
+ static_crop : bool = False,
24
+ max_bbox : bool = False,
25
+ disable_cropping : bool = False,
26
+ ):
27
+ experiment_name = exp_path.split('/')[-1][:-3]
28
+ data_name = exp_path.split('/')[-2]
29
+ config_path = '.experiments.{}.{}'.format(data_name, experiment_name)
30
+
31
+ my_config = importlib.import_module(config_path, package='pixel3dmm.preprocessing.PIPNet')
32
+ Config = getattr(my_config, 'Config')
33
+ cfg = Config()
34
+ cfg.experiment_name = experiment_name
35
+ cfg.data_name = data_name
36
+
37
+ save_dir = os.path.join(f'{env_paths.CODE_BASE}/src/pixel3dmm/preprocessing/PIPNet/snapshots', cfg.data_name, cfg.experiment_name)
38
+
39
+
40
+
41
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
42
+ std=[0.229, 0.224, 0.225])
43
+ preprocess = transforms.Compose(
44
+ [transforms.Resize((cfg.input_size, cfg.input_size)), transforms.ToTensor(), normalize])
45
+
46
+
47
+ #for pid in pids:
48
+ pid = "FaMoS_180424_03335_TA_selfie_IMG_0092.jpg"
49
+ pid = "FaMoS_180426_03336_TA_selfie_IMG_0152.jpg"
50
+
51
+
52
+
53
+ demo_image(image_dir, pid, save_dir, preprocess, cfg, cfg.input_size, cfg.net_stride, cfg.num_nb,
54
+ cfg.use_gpu,
55
+ start_frame=start_frame, vertical_crop=vertical_crop, static_crop=static_crop, max_bbox=max_bbox,
56
+ disable_cropping=disable_cropping)
57
+
58
+
59
+ def unpack_images(base_path, video_or_images_path):
60
+ if not os.path.exists(base_path):
61
+ os.makedirs(base_path, exist_ok=True)
62
+ if os.path.isdir(video_or_images_path):
63
+ files = os.listdir(f'{video_or_images_path}')
64
+ files.sort()
65
+ if len(os.listdir(base_path)) == len(files):
66
+ print(f'''
67
+ <<<<<<<< ALREADY COMPLETED IMAGE CROPPING for {video_or_images_path}, SKIPPING! >>>>>>>>
68
+ ''')
69
+ return
70
+ for i, file in enumerate(files):
71
+ I = Image.open(f'{video_or_images_path}/{file}')
72
+ I.save(f'{base_path}/{i:05d}.jpg', quality=95)
73
+ elif video_or_images_path.endswith('.jpg') or video_or_images_path.endswith('.jpeg') or video_or_images_path.endswith('.png'):
74
+ Image.open(video_or_images_path).save(f'{base_path}/{0:05d}.jpg', quality=95)
75
+ else:
76
+ frames = mediapy.read_video(f'{video_or_images_path}')
77
+ if len(frames) == len(os.listdir(base_path)):
78
+ return
79
+ for i, frame in enumerate(frames):
80
+ Image.fromarray(frame).save(f'{base_path}/{i:05d}.jpg', quality=95)
81
+
82
+ def main(video_or_images_path : str,
83
+ max_bbox : bool = True, # not used
84
+ disable_cropping : bool = False):
85
+ if os.path.isdir(video_or_images_path):
86
+ video_name = video_or_images_path.split('/')[-1]
87
+ else:
88
+ video_name = video_or_images_path.split('/')[-1][:-4]
89
+
90
+ base_path = f'{env_paths.PREPROCESSED_DATA}/{video_name}/rgb/'
91
+
92
+ unpack_images(base_path, video_or_images_path)
93
+
94
+ if os.path.exists(f'{env_paths.PREPROCESSED_DATA}/{video_name}/cropped/'):
95
+ if len(os.listdir(base_path)) == len(os.listdir(f'{env_paths.PREPROCESSED_DATA}/{video_name}/cropped/')):
96
+ return
97
+
98
+
99
+ start_frame = -1
100
+ run('experiments/WFLW/pip_32_16_60_r18_l2_l1_10_1_nb10.py', base_path, start_frame=start_frame, vertical_crop=False,
101
+ static_crop=True, max_bbox=max_bbox, disable_cropping=disable_cropping)
102
+ # run('experiments/WFLW/pip_32_16_60_r101_l2_l1_10_1_nb10.py', base_path, start_frame=start_frame, vertical_crop=False, static_crop=True)
103
+
104
+
105
+ if __name__ == '__main__':
106
+ tyro.cli(main)
107
+
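run_cropping.py exposes its main() through tyro, which builds a command-line interface from the function signature and type hints. A minimal, self-contained sketch of that pattern; the function below is a hypothetical stand-in, not part of the repository:

import tyro

def crop_demo(video_or_images_path: str, disable_cropping: bool = False) -> None:
    # tyro parses the CLI and calls this function with typed arguments.
    print(video_or_images_path, disable_cropping)

if __name__ == "__main__":
    tyro.cli(crop_demo)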
scripts/run_facer_segmentation.py ADDED
@@ -0,0 +1,221 @@
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ from math import ceil
6
+
7
+ import PIL.Image
8
+ import torch
9
+ import distinctipy
10
+ import matplotlib.pyplot as plt
11
+ from PIL import Image
12
+ import numpy as np
13
+ import facer
14
+ import tyro
15
+
16
+ from pixel3dmm import env_paths
17
+
18
+ colors = distinctipy.get_colors(22, rng=0)
19
+
20
+
21
+ def viz_results(img, seq_classes, n_classes, suppress_plot = False):
22
+
23
+ seg_img = np.zeros([img.shape[-2], img.shape[-1], 3])
24
+ #distinctipy.color_swatch(colors)
25
+ bad_indices = [
26
+ 0, # background,
27
+ 1, # neck
28
+ # 2, skin
29
+ 3, # cloth
30
+ 4, # ear_r (images-space r)
31
+ 5, # ear_l
32
+ # 6 brow_r
33
+ # 7 brow_l
34
+ # 8, # eye_r
35
+ # 9, # eye_l
36
+ # 10 noise
37
+ # 11 mouth
38
+ # 12 lower_lip
39
+ # 13 upper_lip
40
+ 14, # hair,
41
+ # 15, glasses
42
+ 16, # ??
43
+ 17, # earring_r
44
+ 18, # ?
45
+ ]
46
+ bad_indices = []
47
+
48
+ for i in range(n_classes):
49
+ if i not in bad_indices:
50
+ seg_img[seq_classes[0, :, :] == i] = np.array(colors[i])*255
51
+
52
+ if not suppress_plot:
53
+ plt.imshow(seg_img.astype(np.uint8))
54
+ plt.show()
55
+ return Image.fromarray(seg_img.astype(np.uint8))
56
+
57
+ def get_color_seg(img, seq_classes, n_classes):
58
+
59
+ seg_img = np.zeros([img.shape[-2], img.shape[-1], 3])
60
+ colors = distinctipy.get_colors(n_classes+1, rng=0)
61
+ #distinctipy.color_swatch(colors)
62
+ bad_indices = [
63
+ 0, # background,
64
+ 1, # neck
65
+ # 2, skin
66
+ 3, # cloth
67
+ 4, # ear_r (images-space r)
68
+ 5, # ear_l
69
+ # 6 brow_r
70
+ # 7 brow_l
71
+ # 8, # eye_r
72
+ # 9, # eye_l
73
+ # 10 nose
74
+ # 11 mouth
75
+ # 12 lower_lip
76
+ # 13 upper_lip
77
+ 14, # hair,
78
+ # 15, glasses
79
+ 16, # ??
80
+ 17, # earring_r
81
+ 18, # ?
82
+ ]
83
+
84
+ for i in range(n_classes):
85
+ if i not in bad_indices:
86
+ seg_img[seq_classes[0, :, :] == i] = np.array(colors[i])*255
87
+
88
+
89
+ return Image.fromarray(seg_img.astype(np.uint8))
90
+
91
+
92
+ def crop_gt_img(img, seq_classes, n_classes):
93
+
94
+ seg_img = np.zeros([img.shape[-2], img.shape[-1], 3])
95
+ colors = distinctipy.get_colors(n_classes+1, rng=0)
96
+ #distinctipy.color_swatch(colors)
97
+ bad_indices = [
98
+ 0, # background,
99
+ 1, # neck
100
+ # 2, skin
101
+ 3, # cloth
102
+ 4, #ear_r (images-space r)
103
+ 5, #ear_l
104
+ # 6 brow_r
105
+ # 7 brow_l
106
+ #8, # eye_r
107
+ #9, # eye_l
108
+ # 10 noise
109
+ # 11 mouth
110
+ # 12 lower_lip
111
+ # 13 upper_lip
112
+ 14, # hair,
113
+ # 15, glasses
114
+ 16, # ??
115
+ 17, # earring_r
116
+ 18, # ?
117
+ ]
118
+
119
+ for i in range(n_classes):
120
+ if i in bad_indices:
121
+ img[seq_classes[0, :, :] == i] = 0
122
+
123
+
124
+ #plt.imshow(img.astype(np.uint8))
125
+ #plt.show()
126
+ return img.astype(np.uint8)
127
+
128
+
129
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
130
+
131
+
132
+
133
+ face_detector = facer.face_detector('retinaface/mobilenet', device=device)
134
+ face_parser = facer.face_parser('farl/celebm/448', device=device) # optional "farl/lapa/448"
135
+
136
+
137
+ def main(video_name : str):
138
+
139
+
140
+ out = f'{env_paths.PREPROCESSED_DATA}/{video_name}'
141
+ out_seg = f'{out}/seg_og/'
142
+ out_seg_annot = f'{out}/seg_non_crop_annotations/'
143
+ os.makedirs(out_seg, exist_ok=True)
144
+ os.makedirs(out_seg_annot, exist_ok=True)
145
+ folder = f'{out}/cropped/' # '/home/giebenhain/GTA/data_kinect/color/'
146
+
147
+
148
+
149
+
150
+
151
+ frames = [f for f in os.listdir(folder) if f.endswith('.png') or f.endswith('.jpg')]
152
+
153
+ frames.sort()
154
+
155
+ if len(os.listdir(out_seg)) == len(frames):
156
+ print(f'''
157
+ <<<<<<<< ALREADY COMPLETED SEGMENTATION FOR {video_name}, SKIPPING >>>>>>>>
158
+ ''')
159
+ return
160
+
161
+ #for file in frames:
162
+ batch_size = 1
163
+
164
+ for i in range(len(frames)//batch_size):
165
+ image_stack = []
166
+ frame_stack = []
167
+ original_shapes = []
168
+ for j in range(batch_size):
169
+ file = frames[i * batch_size + j]
170
+
171
+ if os.path.exists(f'{out_seg_annot}/color_{file}.png'):
172
+ print('DONE')
173
+ continue
174
+ img = Image.open(f'{folder}/{file}')#.resize((512, 512))
175
+
176
+ og_size = img.size
177
+
178
+ image = facer.hwc2bchw(torch.from_numpy(np.array(img)[..., :3])).to(device=device) # image: 1 x 3 x h x w
179
+ image_stack.append(image)
180
+ frame_stack.append(file[:-4])
181
+
182
+ for batch_idx in range(ceil(len(image_stack)/batch_size)):
183
+ image_batch = torch.cat(image_stack[batch_idx*batch_size:(batch_idx+1)*batch_size], dim=0)
184
+ frame_idx_batch = frame_stack[batch_idx*batch_size:(batch_idx+1)*batch_size]
185
+ og_shape_batch = original_shapes[batch_idx*batch_size:(batch_idx+1)*batch_size]
186
+
187
+ #if True:
188
+ try:
189
+ with torch.inference_mode():
190
+ faces = face_detector(image_batch)
191
+ torch.cuda.empty_cache()
192
+ faces = face_parser(image_batch, faces, bbox_scale_factor=1.25)
193
+ torch.cuda.empty_cache()
194
+
195
+ seg_logits = faces['seg']['logits']
196
+ back_ground = torch.all(seg_logits == 0, dim=1, keepdim=True).detach().squeeze(1).cpu().numpy()
197
+ seg_probs = seg_logits.softmax(dim=1) # nfaces x nclasses x h x w
198
+ seg_classes = seg_probs.argmax(dim=1).detach().cpu().numpy().astype(np.uint8)
199
+ seg_classes[back_ground] = seg_probs.shape[1] + 1
200
+
201
+
202
+ for _iidx in range(seg_probs.shape[0]):
203
+ frame = frame_idx_batch[_iidx]
204
+ iidx = faces['image_ids'][_iidx].item()
205
+ try:
206
+ I_color = viz_results(image_batch[iidx:iidx+1], seq_classes=seg_classes[_iidx:_iidx+1], n_classes=seg_probs.shape[1] + 1, suppress_plot=True)
207
+ I_color.save(f'{out_seg_annot}/color_{frame}.png')
208
+ except Exception as ex:
209
+ pass
210
+ I = Image.fromarray(seg_classes[_iidx])
211
+ I.save(f'{out_seg}/{frame}.png')
212
+ torch.cuda.empty_cache()
213
+ except Exception as exx:
214
+ traceback.print_exc()
215
+ continue
216
+
217
+
218
+ if __name__ == '__main__':
219
+
220
+ tyro.cli(main)
221
+
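The script above writes one label map per frame to `seg_og/` (per-pixel class indices from the `farl/celebm/448` parser). Below is a sketch of turning such a map into a binary face mask; the excluded class indices follow the commented list above and should be treated as an assumption, not a guaranteed label layout.

```python
import numpy as np
from PIL import Image

# Indices assumed from the commented CelebAMask-style layout above:
# 0 background, 1 neck, 3 cloth, 14 hair, 16/17/18 accessories.
EXCLUDE = (0, 1, 3, 14, 16, 17, 18)

def face_mask_from_seg(seg_png_path: str) -> np.ndarray:
    """Load a seg_og/*.png label map and return a boolean face-region mask."""
    labels = np.array(Image.open(seg_png_path))
    return ~np.isin(labels, EXCLUDE)

# placeholder path:
# mask = face_mask_from_seg('<PREPROCESSED_DATA>/<video_name>/seg_og/00000.png')
```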
scripts/run_preprocessing.py ADDED
@@ -0,0 +1,23 @@
1
+ import os
2
+ import tyro
3
+
4
+ from pixel3dmm import env_paths
5
+
6
+
7
+ def main(video_or_images_path : str):
8
+
9
+ if os.path.isdir(video_or_images_path):
10
+ vid_name = video_or_images_path.split('/')[-1]
11
+ else:
12
+ vid_name = video_or_images_path.split('/')[-1][:-4]
13
+
14
+ os.system(f'cd {env_paths.CODE_BASE}/scripts/ ; python run_cropping.py --video_or_images_path {video_or_images_path}')
15
+
16
+ os.system(f'cd {env_paths.CODE_BASE}/src/pixel3dmm/preprocessing/MICA ; python demo.py -video_name {vid_name}')
17
+
18
+ os.system(f'cd {env_paths.CODE_BASE}/scripts/ ; python run_facer_segmentation.py --video_name {vid_name}')
19
+
20
+
21
+
22
+ if __name__ == '__main__':
23
+ tyro.cli(main)
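run_preprocessing.py chains the three stages (cropping, MICA identity prediction, facer segmentation) with `os.system`. As a sketch only, the same chain could be run with `subprocess.run(check=True)` so a failing stage stops the pipeline instead of being silently ignored; the script locations are taken from the calls above, and the example video name is a placeholder.

```python
import subprocess
from pixel3dmm import env_paths

def run_stage(workdir: str, *cmd: str) -> None:
    # check=True raises CalledProcessError if the stage exits non-zero.
    subprocess.run(list(cmd), cwd=workdir, check=True)

vid_path = f'{env_paths.CODE_BASE}/example_videos/ex1.mp4'  # placeholder input
vid_name = 'ex1'

run_stage(f'{env_paths.CODE_BASE}/scripts',
          'python', 'run_cropping.py', '--video_or_images_path', vid_path)
run_stage(f'{env_paths.CODE_BASE}/src/pixel3dmm/preprocessing/MICA',
          'python', 'demo.py', '-video_name', vid_name)
run_stage(f'{env_paths.CODE_BASE}/scripts',
          'python', 'run_facer_segmentation.py', '--video_name', vid_name)
```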
scripts/track.py ADDED
@@ -0,0 +1,27 @@
1
+ import os
2
+ import wandb
3
+
4
+ from omegaconf import OmegaConf
5
+ from pixel3dmm.tracking.tracker import Tracker
6
+ from pixel3dmm import env_paths
7
+
8
+ def main(cfg):
9
+ tracker = Tracker(cfg)
10
+ tracker.run()
11
+
12
+ if __name__ == '__main__':
13
+ base_conf = OmegaConf.load(f'{env_paths.CODE_BASE}/configs/tracking.yaml')
14
+
15
+ cli_conf = OmegaConf.from_cli()
16
+ cfg = OmegaConf.merge(base_conf, cli_conf)
17
+
18
+ #os.makedirs('/home/giebenhain/debug_wandb_p3dmm/', exist_ok=True)
19
+ #wandb.init(
20
+ # dir='/home/giebenhain/debug_wandb_p3dmm/',
21
+ # #config=config,
22
+ # project='face-tracking-p3dmm',
23
+ # #tags=wandb_tags,
24
+ # #name=cfg.config_name,
25
+ #
26
+ #)
27
+ main(cfg)
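track.py builds its config by merging `configs/tracking.yaml` with `OmegaConf.from_cli()`, so any dotted key on the command line overrides the YAML. A small self-contained check of that precedence; the keys here are illustrative, not the real tracking schema.

```python
from omegaconf import OmegaConf

base = OmegaConf.create({'video_name': 'ex1', 'iters': 150})   # stands in for tracking.yaml
cli = OmegaConf.from_dotlist(['video_name=ex2', 'iters=300'])  # stands in for CLI overrides
cfg = OmegaConf.merge(base, cli)                               # later configs win

assert cfg.video_name == 'ex2' and cfg.iters == 300
```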
scripts/viz_head_centric_cameras.py ADDED
@@ -0,0 +1,103 @@
1
+ import os
2
+ import tyro
3
+ import mediapy
4
+ import torch
5
+ import numpy as np
6
+ import pyvista as pv
7
+ import trimesh
8
+ from PIL import Image
9
+
10
+ from dreifus.matrix import Intrinsics, Pose, CameraCoordinateConvention, PoseType
11
+ from dreifus.pyvista import add_camera_frustum, render_from_camera
12
+
13
+ from pixel3dmm.utils.utils_3d import rotation_6d_to_matrix
14
+ from pixel3dmm.env_paths import PREPROCESSED_DATA, TRACKING_OUTPUT
15
+
16
+
17
+ def main(vid_name : str,
18
+ HEAD_CENTRIC : bool = True,
19
+ DO_PROJECTION_TEST : bool = False,
20
+ ):
21
+ tracking_dir = f'{TRACKING_OUTPUT}/{vid_name}_nV1_noPho_uv2000.0_n1000.0'
22
+
23
+ meshes = [f for f in os.listdir(f'{tracking_dir}/mesh/') if f.endswith('.ply') and not 'canonical' in f]
24
+ meshes.sort()
25
+
26
+ ckpts = [f for f in os.listdir(f'{tracking_dir}/checkpoint/') if f.endswith('.frame')]
27
+ ckpts.sort()
28
+
29
+ N_STEPS = len(meshes)
30
+
31
+ pl = pv.Plotter()
32
+ vid_frames = []
33
+ for i in range(N_STEPS):
34
+ ckpt = torch.load(f'{tracking_dir}/checkpoint/{ckpts[i]}', weights_only=False)
35
+
36
+ mesh = trimesh.load(f'{tracking_dir}/mesh/{meshes[i]}', process=False)
37
+
38
+ head_rot = rotation_6d_to_matrix(torch.from_numpy(ckpt['flame']['R'])).numpy()[0]
39
+
40
+ if not HEAD_CENTRIC:
41
+ # move mesh from FLAME Space into World Space
42
+ mesh.vertices = mesh.vertices @ head_rot.T + (ckpt['flame']['t'])
43
+ else:
44
+ # undo neck rotation
45
+ verts_hom = np.concatenate([mesh.vertices, np.ones_like(mesh.vertices[..., :1])], axis=-1)
46
+ verts_hom = verts_hom @ np.linalg.inv(ckpt['joint_transforms'][0, 1, :, :]).T
47
+ mesh.vertices = verts_hom[..., :3]
48
+
49
+
50
+
51
+ extr_open_gl_world_to_cam = np.eye(4)
52
+ extr_open_gl_world_to_cam[:3, :3] = ckpt['camera']['R_base_0'][0]
53
+ extr_open_gl_world_to_cam[:3, 3] = ckpt['camera']['t_base_0'][0]
54
+ if HEAD_CENTRIC:
55
+ flame2world = np.eye(4)
56
+ flame2world[:3, :3] = head_rot
57
+ flame2world[:3, 3] = np.squeeze(ckpt['flame']['t'])
58
+ #TODO include neck transform as well
59
+ extr_open_gl_world_to_cam = extr_open_gl_world_to_cam @ flame2world @ ckpt['joint_transforms'][0, 1, :, :]
60
+
61
+
62
+
63
+
64
+ extr_open_gl_world_to_cam = Pose(extr_open_gl_world_to_cam,
65
+ camera_coordinate_convention=CameraCoordinateConvention.OPEN_GL,
66
+ pose_type=PoseType.WORLD_2_CAM)
67
+
68
+ intr = np.eye(3)
69
+ intr[0, 0] = ckpt['camera']['fl'][0, 0] * 256
70
+ intr[1, 1] = ckpt['camera']['fl'][0, 0] * 256
71
+ intr[:2, 2] = ckpt['camera']['pp'][0] * (256/2+0.5) + 256/2 + 0.5
72
+
73
+ intr = Intrinsics(intr)
74
+
75
+
76
+
77
+ pl.add_mesh(mesh, color=[(i/N_STEPS), 0, ((N_STEPS-i)/N_STEPS)])
78
+ add_camera_frustum(pl, extr_open_gl_world_to_cam, intr, color=[(i/N_STEPS), 0, ((N_STEPS-i)/N_STEPS)])
79
+
80
+ if DO_PROJECTION_TEST:
81
+ pll = pv.Plotter(off_screen=True, window_size=(256, 256))
82
+ pll.add_mesh(mesh)
83
+ img = render_from_camera(pll, extr_open_gl_world_to_cam, intr)
84
+
85
+ gt_img = np.array(Image.open(f'{PREPROCESSED_DATA}/{vid_name}/cropped/{i:05d}.jpg').resize((256, 256)))
86
+
87
+ alpha = img[..., 3]
88
+
89
+ overlay = (gt_img *0.5 + img[..., :3]*0.5).astype(np.uint8)
90
+ vid_frames.append(overlay)
91
+
92
+
93
+
94
+
95
+ pl.show()
96
+
97
+ if DO_PROJECTION_TEST:
98
+ mediapy.write_video(f'{tracking_dir}/projection_test.mp4', images=vid_frames)
99
+
100
+
101
+
102
+ if __name__ == '__main__':
103
+ tyro.cli(main)
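The script above moves FLAME-space vertices into world space with `verts @ R.T + t` and composes the same transform as a 4x4 matrix when adjusting the camera pose. A tiny numpy sketch of why those two forms agree, using random data rather than anything Pixel3DMM-specific:

```python
import numpy as np

rng = np.random.default_rng(0)
V = rng.normal(size=(5, 3))                   # row-vector vertices
R, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # orthonormal matrix (sign not fixed)
t = rng.normal(size=3)

T = np.eye(4)
T[:3, :3], T[:3, 3] = R, t                    # same map as a homogeneous 4x4

V_hom = np.concatenate([V, np.ones((len(V), 1))], axis=-1)
assert np.allclose(V @ R.T + t, (V_hom @ T.T)[:, :3])
```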
setup.py ADDED
@@ -0,0 +1,7 @@
1
+ #!/usr/bin/env python
2
+
3
+ import setuptools
4
+
5
+ if __name__ == "__main__":
6
+ # Still necessary, otherwise we get a pip error
7
+ setuptools.setup()
src/__init__.py ADDED
File without changes
src/pixel3dmm/__init__.py ADDED
File without changes
src/pixel3dmm/env_paths.py ADDED
@@ -0,0 +1,34 @@
1
+ import json
2
+ from pathlib import Path
3
+ from environs import Env
4
+
5
+
6
+ env = Env(expand_vars=True)
7
+ env_file_path = Path(f"{Path.home()}/.config/pixel3dmm/.env")
8
+ if env_file_path.exists():
9
+ env.read_env(str(env_file_path), recurse=False)
10
+
11
+
12
+ with env.prefixed("PIXEL3DMM_"):
13
+ CODE_BASE = env("CODE_BASE")
14
+ PREPROCESSED_DATA = env("PREPROCESSED_DATA")
15
+ TRACKING_OUTPUT = env("TRACKING_OUTPUT")
16
+
17
+
18
+
19
+ head_template = f'{CODE_BASE}/assets/head_template.obj'
20
+ head_template_color = f'{CODE_BASE}/assets/head_template_color.obj'
21
+ head_template_ply = f'{CODE_BASE}/assets/test_rigid.ply'
22
+ VALID_VERTICES_WIDE_REGION = f'{CODE_BASE}/assets/uv_valid_verty_noEyes_debug.npy'
23
+ VALID_VERTS_UV_MESH = f'{CODE_BASE}/assets/uv_valid_verty.npy'
24
+ VERTEX_WEIGHT_MASK = f'{CODE_BASE}/assets/flame_vertex_weights.npy'
25
+ MIRROR_INDEX = f'{CODE_BASE}/assets/flame_mirror_index.npy'
26
+ EYE_MASK = f'{CODE_BASE}/assets/uv_mask_eyes.png'
27
+ FLAME_UV_COORDS = f'{CODE_BASE}/assets/flame_uv_coords.npy'
28
+ VALID_VERTS_NARROW = f'{CODE_BASE}/assets/uv_valid_verty_noEyes.npy'
29
+ VALID_VERTS = f'{CODE_BASE}/assets/uv_valid_verty_noEyes_noEyeRegion_debug_wEars.npy'
30
+ FLAME_ASSETS = f'{CODE_BASE}/src/pixel3dmm/preprocessing/MICA/data/'
31
+
32
+ # paths to pretrained pixel3dmm checkpoints
33
+ CKPT_UV_PRED = f'{CODE_BASE}/pretrained_weights/uv.ckpt'
34
+ CKPT_N_PRED = f'{CODE_BASE}/pretrained_weights/normals.ckpt'
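env_paths.py reads `PIXEL3DMM_`-prefixed variables from `~/.config/pixel3dmm/.env` via `environs`. A sketch of creating that file programmatically; the three paths are placeholders to replace with your own setup.

```python
from pathlib import Path

env_file = Path.home() / '.config' / 'pixel3dmm' / '.env'
env_file.parent.mkdir(parents=True, exist_ok=True)
env_file.write_text(
    'PIXEL3DMM_CODE_BASE=/path/to/pixel3dmm\n'
    'PIXEL3DMM_PREPROCESSED_DATA=/path/to/preprocessed_data\n'
    'PIXEL3DMM_TRACKING_OUTPUT=/path/to/tracking_output\n'
)
```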
src/pixel3dmm/lightning/p3dmm_network.py ADDED
The diff for this file is too large to render. See raw diff
 
src/pixel3dmm/lightning/p3dmm_system.py ADDED
@@ -0,0 +1,491 @@
1
+ from PIL import Image, ImageDraw
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+ import pytorch_lightning as L
6
+ import torch.nn as nn
7
+
8
+ from pixel3dmm.lightning.utils import CosineWarmupScheduler, WarmupScheduler
9
+ from pixel3dmm.lightning.p3dmm_network import Network
10
+ from pixel3dmm import env_paths
11
+
12
+
13
+ def fov_to_ixt(fov, reso=512):
14
+ ixt = torch.eye(3).float().unsqueeze(0).repeat(fov.shape[0], 1, 1).to(fov.device)
15
+ ixt[:, 0, 2] = reso / 2
16
+ ixt[:, 1, 2] = reso / 2
17
+ focal = .5 * reso / torch.tan(.5 * fov)
18
+ ixt[:, 0, 0] = focal
19
+ ixt[:, 1, 1] = focal
20
+ return ixt
21
+
22
+
23
+ def batch_rodrigues(
24
+ rot_vecs: torch.Tensor,
25
+ epsilon: float = 1e-8,
26
+ ) -> torch.Tensor:
27
+ ''' Calculates the rotation matrices for a batch of rotation vectors
28
+ Parameters
29
+ ----------
30
+ rot_vecs: torch.tensor Nx3
31
+ array of N axis-angle vectors
32
+ Returns
33
+ -------
34
+ R: torch.tensor Nx3x3
35
+ The rotation matrices for the given axis-angle parameters
36
+ '''
37
+
38
+ batch_size = rot_vecs.shape[0]
39
+ device, dtype = rot_vecs.device, rot_vecs.dtype
40
+
41
+ angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
42
+ rot_dir = rot_vecs / angle
43
+
44
+ cos = torch.unsqueeze(torch.cos(angle), dim=1)
45
+ sin = torch.unsqueeze(torch.sin(angle), dim=1)
46
+
47
+ # Bx1 arrays
48
+ rx, ry, rz = torch.split(rot_dir, 1, dim=1)
49
+ K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
50
+
51
+ zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
52
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
53
+ .view((batch_size, 3, 3))
54
+
55
+ ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
56
+ rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
57
+ return rot_mat
58
+
59
+
60
+ def pad_to_3_channels(img):
61
+ if img.shape[-1] == 3:
62
+ return img
63
+ elif img.shape[-1] == 1:
64
+ return np.concatenate([img, np.zeros_like(img[..., :1]), np.zeros_like(img[..., :1])], axis=-1)
65
+ elif img.shape[-1] == 2:
66
+ return np.concatenate([img, np.zeros_like(img[..., :1])], axis=-1)
67
+ else:
68
+ raise ValueError('too many dimensions in prediction type!')
69
+
70
+
71
+ class system(L.LightningModule):
72
+ def __init__(self, cfg):
73
+ super().__init__()
74
+
75
+ self.glctx = None
76
+ self.cfg = cfg
77
+ self.net = Network(cfg)
78
+
79
+ vertex_weight_mask = np.load(f'{env_paths.VERTEX_WEIGHT_MASK}')
80
+
81
+ self.register_buffer('vertex_weight_mask', torch.from_numpy(vertex_weight_mask).float())
82
+
83
+
84
+
85
+ self.validation_step_outputs = []
86
+ self.validation_step_outputs_per_dataset = []
87
+
88
+ self.dataset_types = [
89
+ 'facescape',
90
+ 'nphm',
91
+ 'ava',
92
+ ]
93
+
94
+
95
+ self.do_eval = True
96
+
97
+ self.alpha = 1.0
98
+
99
+ self.save_hyperparameters()
100
+
101
+ self.loss_weights = {
102
+ 'albedo': 1.0, # 1.0/0.13,
103
+ 'depth': 1.0,
104
+ 'pos_map': 1.0, # 1.0/0.0006,
105
+ 'pos_map_can': 1.0, # 1.0/0.0006,
106
+ 'normals': 0.1, # TODO: attention! #1.0/0.03,
107
+ 'normals_can': 1.0, # 1.0/0.03,
108
+ 'uv_map': 10.0, # 1.0/0.001,
109
+ 'nocs': 1.0, # 1.0/0.0006,
110
+ }
111
+
112
+
113
+ def training_step(self, batch, batch_idx):
114
+
115
+
116
+ output, conf = self.net(batch)
117
+
118
+ B = output[list(output.keys())[0]].shape[0]
119
+ V = output[list(output.keys())[0]].shape[1]
120
+
121
+ c_map = None
122
+
123
+
124
+
125
+ losses = {}
126
+
127
+
128
+ if 'normals' in self.cfg.model.prediction_type:
129
+
130
+ gt_normals = batch['normals'].permute(0, 1, 4, 2, 3)
131
+ if conf is None:
132
+ losses['normals'] = (batch['tar_msk'].unsqueeze(2) * (gt_normals - output['normals'])).abs().mean()
133
+ else:
134
+ losses['normals'] = (batch['tar_msk'].unsqueeze(2) * (
135
+ c_map * (gt_normals - output['normals']) - self.alpha * torch.log(c_map))).abs().mean()
136
+
137
+ if self.cfg.model.pred_disentangled:
138
+ gt_normals_can = batch['normals_can'].permute(0, 1, 4, 2, 3)
139
+ if conf is None:
140
+ losses['normals_can'] = (
141
+ batch['tar_msk'].unsqueeze(2) * (gt_normals_can - output['normals_can'])).abs().mean()
142
+ else:
143
+ losses['normals_can'] = (batch['tar_msk'].unsqueeze(2) * (
144
+ c_map * (gt_normals_can - output['normals_can']) - self.alpha * torch.log(
145
+ c_map))).abs().mean()
146
+
147
+
148
+ for prediction_type in ['uv_map', 'depth', 'nocs']:
149
+ if prediction_type in self.cfg.model.prediction_type:
150
+ weight_mask = torch.ones_like(output[prediction_type])
151
+ if prediction_type == 'uv_map' or (prediction_type == 'nocs'): # ATTENTION: only for nocs?
152
+ weight_mask = batch['uv_masks'].unsqueeze(2).float() + 0.2
153
+ gt_pos_map = batch[prediction_type].permute(0, 1, 4, 2, 3)
154
+ if conf is None:
155
+ losses[prediction_type] = (weight_mask * batch['tar_msk'].unsqueeze(2) * (
156
+ gt_pos_map - output[prediction_type])).abs().mean()
157
+ else:
158
+ losses[prediction_type] = (weight_mask * batch['tar_msk'].unsqueeze(2) * (
159
+ c_map * (gt_pos_map - output[prediction_type]) - self.alpha * torch.log(
160
+ c_map))).abs().mean()
161
+
162
+ total_loss = 0
163
+
164
+ loss = 0
165
+ for k in losses.keys():
166
+ if k in self.loss_weights:
167
+ loss += self.loss_weights[k] * losses[k]
168
+ else:
169
+ loss += losses[k]
170
+
171
+
172
+
173
+ self.log(f'train/loss', loss.item(), prog_bar=False)
174
+ # for prediction_type in self.cfg.model.prediction_type:
175
+ for k in losses.keys():
176
+ if k in self.cfg.model.prediction_type:
177
+ self.log(f'train/loss_{k}', losses[k])
178
+ if self.cfg.model.pred_disentangled:
179
+ for k in losses.keys():
180
+ if k[:-4] in self.cfg.model.prediction_type:
181
+ self.log(f'train/loss_{k}', losses[k])
182
+
183
+
184
+ self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'])
185
+
186
+ do_vis = (0 == self.trainer.global_step % 300) if os.path.exists('/mnt/rohan') else (
187
+ 0 == self.trainer.global_step % 3000)
188
+ if do_vis and (self.trainer.local_rank == 0):
189
+ output, conf = self.net(batch)
190
+
191
+
192
+ self.vis_results({k: v.detach() for (k, v) in output.items()}, conf, batch, prex='train')
193
+ self.do_eval = True
194
+ torch.cuda.empty_cache()
195
+
196
+
197
+ return loss
198
+
199
+
200
+
201
+ def optimizer_step(
202
+ self,
203
+ *args, **kwargs
204
+ ):
205
+ """
206
+ Skipping updates in case of unstable gradients
207
+ https://github.com/Lightning-AI/lightning/issues/4956
208
+ """
209
+ valid_gradients = True
210
+ grads = [
211
+ param.grad.detach().flatten()
212
+ for param in self.parameters()
213
+ if param.grad is not None
214
+ ]
215
+ if len(grads) > 0:
216
+ norm = torch.cat(grads).norm()
217
+ self.log(f'grad/norm', norm.item(), prog_bar=False) # , sync_dist=True)
218
+
219
+ if (norm > 10000 and self.global_step > 20 or torch.isnan(norm)):
220
+ valid_gradients = False
221
+
222
+ if not valid_gradients:
223
+ print(
224
+ f'detected inf or nan values in gradients. not updating model parameters, OTHER FUNCTION threshold: {10000}, value: {norm.item()}')
225
+ self.zero_grad()
226
+ for param in self.parameters():
227
+ param.grad = None
228
+
229
+ L.LightningModule.optimizer_step(self, *args, **kwargs)
230
+
231
+
232
+ def validation_step(self, batch, batch_idx):
233
+
234
+
235
+ self.net.eval()
236
+ output, conf = self.net(batch)
237
+
238
+ B = output[list(output.keys())[0]].shape[0]
239
+ V = output[list(output.keys())[0]].shape[1]
240
+
241
+
242
+
243
+ loss_dict = {}
244
+
245
+ dataset_indices = {}
246
+
247
+
248
+
249
+ val_losses = {}
250
+ for prediction_type in ['uv_map', 'depth', 'nocs']:
251
+ if prediction_type in self.cfg.model.prediction_type:
252
+ gt_pos_map = batch[prediction_type].permute(0, 1, 4, 2, 3)
253
+ weight_mask = torch.ones_like(output[prediction_type])
254
+ if prediction_type == 'uv_map' or (prediction_type == 'nocs'): # ATTENTION: only for nocs?
255
+ weight_mask = batch['uv_masks'].unsqueeze(2).float() + 0.2
256
+
257
+ val_losses[prediction_type] = (weight_mask * batch['tar_msk'].unsqueeze(2) * (
258
+ gt_pos_map - output[prediction_type])).abs().mean()
259
+ loss_dict[f'loss/{prediction_type}'] = val_losses[prediction_type].item()
260
+
261
+ if 'normals' in self.cfg.model.prediction_type:
262
+ prediction_type = 'normals'
263
+ gt_pos_map = batch[prediction_type].permute(0, 1, 4, 2, 3)
264
+
265
+ val_losses[prediction_type] = (
266
+ batch['tar_msk'].unsqueeze(2) * (gt_pos_map - output[prediction_type])).abs().mean()
267
+
268
+ loss_dict[f'loss/{prediction_type}'] = val_losses[prediction_type].item()
269
+
270
+ if self.cfg.model.pred_disentangled:
271
+ prediction_type = 'normals_can'
272
+ gt_pos_map = batch[prediction_type].permute(0, 1, 4, 2, 3)
273
+
274
+ val_losses[prediction_type] = (
275
+ batch['tar_msk'].unsqueeze(2) * (gt_pos_map - output[prediction_type])).abs().mean()
276
+
277
+ loss_dict[f'loss/{prediction_type}'] = val_losses[prediction_type].item()
278
+
279
+ # if self.cfg.model.prediction_type == 'depth_si':
280
+ # loss, pred_scale, target_scale = simae2_loss(output, batch['depth'].permute(0, 1, 4, 2, 3), batch['tar_msk'].unsqueeze(2), c_map=c_map, alpha=self.alpha)
281
+ # self.validation_step_outputs.append({'loss': loss.item()})
282
+
283
+ val_loss = 0
284
+
285
+ for prediction_type in self.cfg.model.prediction_type:
286
+ val_loss += self.loss_weights[prediction_type] * val_losses[prediction_type]
287
+
288
+
289
+ loss_dict['loss/total'] = val_loss.item()
290
+ self.validation_step_outputs.append(loss_dict)
291
+
292
+ #print('GLOBAL_STEP:', self.trainer.global_step)
293
+ if self.do_eval and self.trainer.local_rank == 0:
294
+ output, conf = self.net(batch)
295
+ if conf is not None:
296
+ conf = conf.detach()
297
+ tmp_dict = {k: v.detach() for (k, v) in output.items()}
298
+ self.vis_results(tmp_dict, conf, batch, prex='val')
299
+ self.do_eval = False
300
+ torch.cuda.empty_cache()
301
+
302
+ return val_loss
303
+
304
+ def on_validation_epoch_end(self):
305
+ # for key in keys:
306
+ # prog_bar = True if key in ['psnr','mask','depth'] else False
307
+ metric_mean = np.stack([np.array(x['loss/total']) for x in self.validation_step_outputs]).mean()
308
+ self.log(f'val/loss', metric_mean, prog_bar=False, sync_dist=True)
309
+ if self.net.n_facial_components == 0:
310
+
311
+ for prediction_type in self.cfg.model.prediction_type:
312
+ metric_mean_pred_type = np.stack(
313
+ [np.array(x[f'loss/{prediction_type}']) for x in self.validation_step_outputs]).mean()
314
+ self.log(f'val/loss_{prediction_type}', metric_mean_pred_type, sync_dist=True)
315
+
316
+ for dataset_type in self.dataset_types:
317
+ for loss_type in self.validation_step_outputs[0].keys():
318
+ content = [np.array(x[dataset_type][loss_type]) for x in self.validation_step_outputs_per_dataset if loss_type in x[dataset_type]]
319
+ if len(content) > 0:
320
+ metric_mean = np.nanmean(np.stack(content))
321
+ self.log(f'val_{dataset_type}/{loss_type}', metric_mean, sync_dist=True)
322
+
323
+ self.validation_step_outputs.clear() # free memory
324
+ torch.cuda.empty_cache()
325
+
326
+ def vis_results(self, output, conf, batch, prex):
327
+ out_folder = f'{self.cfg.reconstruction_folder}/{prex}_{self.trainer.global_step}/'
328
+ os.makedirs(out_folder, exist_ok=True)
329
+ output_gpu = {k: v for k, v in output.items()}
330
+ output = {k: v.cpu() for k, v in output.items()}
331
+ if self.net.n_facial_components == 0:
332
+ output_rows = {}
333
+
334
+ for prediction_type in ['normals', 'albedo', 'uv_map', 'nocs']:
335
+ if prediction_type in self.cfg.model.prediction_type:
336
+ output_rows[prediction_type] = (batch['tar_msk'][..., None].float() * batch[prediction_type]).permute(0, 1, 4, 2, 3).detach().cpu()
337
+ if prediction_type in self.cfg.model.prediction_type and prediction_type == 'normals' and self.cfg.model.pred_disentangled:
338
+ output_rows['normals_can'] = (batch['tar_msk'][..., None].float() * batch['normals_can']).permute(0, 1, 4, 2, 3).detach().cpu()
339
+
340
+ gt_rgb = batch['tar_rgb'].permute(0, 1, 4, 2, 3).detach().cpu()
341
+
342
+
343
+ for i_batch in range(output_rows[self.cfg.model.prediction_type[0]].shape[0]):
344
+
345
+ modalities = []
346
+ prediction_types = self.cfg.model.prediction_type.copy() # ['pos_map', 'normals', 'albedo', 'uv_map']
347
+ if self.cfg.model.pred_disentangled and "pos_map" in prediction_types:
348
+ prediction_types.append('pos_map_can')
349
+ if self.cfg.model.pred_disentangled and "normals" in prediction_types:
350
+ prediction_types.append('normals_can')
351
+ if self.cfg.model.pred_disentangled and "uv_map" in prediction_types:
352
+ prediction_types.append('disps')
353
+
354
+ for prediction_type in prediction_types:
355
+ rows = []
356
+ for i_view in range(output_rows[prediction_type].shape[1]):
357
+ with torch.no_grad():
358
+ mini = min(output_rows[prediction_type][i_batch, i_view].min().item(),
359
+ output[prediction_type][i_batch, i_view].min().item())
360
+ tmp_gt_pos_map = output_rows[prediction_type][i_batch, i_view].clone() - mini
361
+ tmp_output = output[prediction_type][i_batch, i_view].clone() - mini
362
+ maxi = max(tmp_gt_pos_map.max().item(), tmp_output.max().item())
363
+ tmp_gt_pos_map = tmp_gt_pos_map / maxi
364
+ tmp_output = tmp_output / maxi
365
+
366
+ catted = [
367
+ gt_rgb[i_batch, i_view].permute(1, 2, 0).detach().cpu().numpy(),
368
+ pad_to_3_channels(
369
+ (batch['tar_msk'][i_batch, i_view].cpu() * tmp_gt_pos_map.cpu()).permute(1, 2,
370
+ 0).detach().cpu().numpy()),
371
+ pad_to_3_channels(tmp_output.permute(1, 2, 0).detach().cpu().float().numpy()),
372
+ ]
373
+
374
+ if conf is not None:
375
+ mini_conf = conf[i_batch, i_view].min()
376
+ tmp_conf = conf[i_batch, i_view].clone() - mini_conf
377
+ maxi_conf = tmp_conf.max()
378
+ tmp_conf = tmp_conf / maxi_conf
379
+ catted.append(
380
+ pad_to_3_channels(tmp_conf.permute(1, 2, 0).detach().cpu().float().numpy()))
381
+
382
+ catted = (np.concatenate(catted, axis=1) * 255).astype(np.uint8)
383
+
384
+ rows.append(catted)
385
+ modalities.append(np.concatenate(rows, axis=0))
386
+
387
+ catted = Image.fromarray(np.concatenate(modalities, axis=0))
388
+ scene_name = batch['meta']['scene'][i_batch]
389
+ catted.save(f'{out_folder}/{scene_name}.png') # , quality=90)
390
+
391
+
392
+
393
+
394
+ keys = list(output.keys())
395
+ for k in keys:
396
+ del output[k]
397
+ del output
398
+ del gt_rgb
399
+ keys = list(output_rows.keys())
400
+ for k in keys:
401
+ del output_rows[k]
402
+ del output_rows
403
+
404
+ torch.cuda.empty_cache()
405
+ # pll.show()
406
+
407
+ def num_steps(self) -> int:
408
+ """Get number of steps"""
409
+ # Accessing _data_source is flaky and might break
410
+ dataset = self.trainer.fit_loop._data_source.dataloader()
411
+ dataset_size = len(dataset)
412
+ num_devices = max(1, self.trainer.num_devices)
413
+ num_steps = dataset_size * self.trainer.max_epochs * self.cfg.train.limit_train_batches // (
414
+ self.trainer.accumulate_grad_batches * num_devices)
415
+ return int(num_steps)
416
+
417
+ def configure_optimizers(self):
418
+ decay_params, no_decay_params = [], []
419
+
420
+ invalid_params = []
421
+ all_backbone_params = []
422
+ all_non_backbone_params = []
423
+ backbone_params = []
424
+ backbone_params_no_decay = []
425
+ # add all bias and LayerNorm params to no_decay_params
426
+ for name, module in self.named_modules():
427
+ if name == 'flame' or name == 'flame_generic':
428
+ invalid_params.extend([p for p in module.parameters()])
429
+ else:
430
+ if isinstance(module, nn.LayerNorm):
431
+ if 'img_encoder' in name:
432
+ backbone_params_no_decay.extend([p for p in module.parameters()])
433
+ else:
434
+ no_decay_params.extend([p for p in module.parameters()])
435
+ elif hasattr(module, 'bias') and module.bias is not None:
436
+ if 'img_encoder' in name:
437
+ backbone_params_no_decay.append(module.bias)
438
+ else:
439
+ no_decay_params.append(module.bias)
440
+
441
+ if 'img_encoder' in name:
442
+ all_backbone_params.extend([p for p in module.parameters()])
443
+ else:
444
+ all_non_backbone_params.extend([p for p in module.parameters()])
445
+
446
+ # add remaining parameters to decay_params
447
+ _no_decay_ids = set(map(id, no_decay_params))
448
+ _all_backbone_ids = set(map(id, all_backbone_params))
449
+ _all_non_backbone_ids = set(map(id, all_non_backbone_params))
450
+ _backbone_no_decay_ids = set(map(id, backbone_params_no_decay))
451
+ _invalid_ids = set(map(id, invalid_params))
452
+ decay_params = [p for p in self.parameters() if
453
+ id(p) not in _no_decay_ids and id(p) not in _all_backbone_ids and id(p) not in _invalid_ids]
454
+ decay_params_backbone = [p for p in self.parameters() if
455
+ id(p) not in _backbone_no_decay_ids and id(p) not in _all_non_backbone_ids and id(
456
+ p) not in _invalid_ids]
457
+ no_decay_params = [p for p in no_decay_params if id(p) not in _invalid_ids]
458
+ no_decay_params_backbone = [p for p in backbone_params_no_decay if id(p) not in _invalid_ids]
459
+
460
+ # filter out parameters with no grad
461
+ decay_params = list(filter(lambda p: p.requires_grad, decay_params))
462
+ no_decay_params = list(filter(lambda p: p.requires_grad, no_decay_params))
463
+ decay_params_backbone = list(filter(lambda p: p.requires_grad, decay_params_backbone))
464
+ no_decay_params_backbone = list(filter(lambda p: p.requires_grad, no_decay_params_backbone))
465
+
466
+ # Optimizer
467
+ opt_groups = [
468
+ {'params': decay_params, 'weight_decay': self.cfg.train.weight_decay, 'lr': self.cfg.train.lr},
469
+ {'params': decay_params_backbone, 'weight_decay': self.cfg.train.weight_decay,
470
+ 'lr': self.cfg.train.lr_backbone},
471
+ {'params': no_decay_params, 'weight_decay': 0.0, 'lr': self.cfg.train.lr},
472
+ {'params': no_decay_params_backbone, 'weight_decay': 0.0, 'lr': self.cfg.train.lr_backbone},
473
+ ]
474
+ optimizer = torch.optim.AdamW(
475
+ opt_groups,
476
+ betas=(self.cfg.train.beta1, self.cfg.train.beta2),
477
+ )
478
+
479
+ total_global_batches = self.num_steps()
480
+
481
+ scheduler = CosineWarmupScheduler(
482
+ optimizer=optimizer,
483
+ warmup_iters=self.cfg.train.warmup_iters,
484
+ max_iters=total_global_batches,
485
+ )
486
+
487
+ return {"optimizer": optimizer,
488
+ "lr_scheduler": {
489
+ 'scheduler': scheduler,
490
+ 'interval': 'step' # or 'epoch' for epoch-level updates
491
+ }}
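`system.optimizer_step` above skips a parameter update whenever the global gradient norm explodes or turns NaN. Below is the same guard sketched for a plain (non-Lightning) training loop; the threshold mirrors the hard-coded 10000 above, and the surrounding loop is only indicated in comments.

```python
import torch

def drop_bad_gradients(model: torch.nn.Module, threshold: float = 1e4) -> bool:
    """Return True (and clear grads) if the global grad norm is NaN or above threshold."""
    grads = [p.grad.detach().flatten() for p in model.parameters() if p.grad is not None]
    if not grads:
        return False
    norm = torch.cat(grads).norm()
    if torch.isnan(norm) or norm > threshold:
        for p in model.parameters():
            p.grad = None
        return True
    return False

# sketch of use in a plain loop:
# loss.backward()
# if not drop_bad_gradients(model):
#     optimizer.step()
# optimizer.zero_grad(set_to_none=True)
```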
src/pixel3dmm/lightning/utils.py ADDED
@@ -0,0 +1,119 @@
1
+ import torch, os, json, math
2
+ import numpy as np
3
+ from torch.optim.lr_scheduler import LRScheduler
4
+
5
+ def getProjectionMatrix(znear, zfar, fovX, fovY):
6
+
7
+ tanHalfFovY = torch.tan((fovY / 2))
8
+ tanHalfFovX = torch.tan((fovX / 2))
9
+
10
+ P = torch.zeros(4, 4)
11
+
12
+ z_sign = 1.0
13
+
14
+ P[0, 0] = 1 / tanHalfFovX
15
+ P[1, 1] = 1 / tanHalfFovY
16
+ P[3, 2] = z_sign
17
+ P[2, 2] = z_sign * zfar / (zfar - znear)
18
+ P[2, 3] = -(zfar * znear) / (zfar - znear)
19
+ return P
20
+
21
+
22
+ class MiniCam:
23
+ def __init__(self, c2w, width, height, fovy, fovx, znear, zfar, device):
24
+ # c2w (pose) should be in NeRF convention.
25
+
26
+ self.image_width = width
27
+ self.image_height = height
28
+ self.FoVy = fovy
29
+ self.FoVx = fovx
30
+ self.znear = znear
31
+ self.zfar = zfar
32
+
33
+ w2c = torch.inverse(c2w)
34
+
35
+ # rectify...
36
+ # w2c[1:3, :3] *= -1
37
+ # w2c[:3, 3] *= -1
38
+
39
+ self.world_view_transform = w2c.transpose(0, 1).to(device)
40
+ self.projection_matrix = (
41
+ getProjectionMatrix(
42
+ znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
43
+ )
44
+ .transpose(0, 1)
45
+ .to(device)
46
+ )
47
+ self.full_proj_transform = (self.world_view_transform @ self.projection_matrix).float()
48
+ self.camera_center = -c2w[:3, 3].to(device)
49
+
50
+
51
+ def rotation_matrix_to_quaternion(R):
52
+ tr = R[0, 0] + R[1, 1] + R[2, 2]
53
+ if tr > 0:
54
+ S = torch.sqrt(tr + 1.0) * 2.0
55
+ qw = 0.25 * S
56
+ qx = (R[2, 1] - R[1, 2]) / S
57
+ qy = (R[0, 2] - R[2, 0]) / S
58
+ qz = (R[1, 0] - R[0, 1]) / S
59
+ elif (R[0, 0] > R[1, 1]) and (R[0, 0] > R[2, 2]):
60
+ S = torch.sqrt(1.0 + R[0, 0] - R[1, 1] - R[2, 2]) * 2.0
61
+ qw = (R[2, 1] - R[1, 2]) / S
62
+ qx = 0.25 * S
63
+ qy = (R[0, 1] + R[1, 0]) / S
64
+ qz = (R[0, 2] + R[2, 0]) / S
65
+ elif R[1, 1] > R[2, 2]:
66
+ S = torch.sqrt(1.0 + R[1, 1] - R[0, 0] - R[2, 2]) * 2.0
67
+ qw = (R[0, 2] - R[2, 0]) / S
68
+ qx = (R[0, 1] + R[1, 0]) / S
69
+ qy = 0.25 * S
70
+ qz = (R[1, 2] + R[2, 1]) / S
71
+ else:
72
+ S = torch.sqrt(1.0 + R[2, 2] - R[0, 0] - R[1, 1]) * 2.0
73
+ qw = (R[1, 0] - R[0, 1]) / S
74
+ qx = (R[0, 2] + R[2, 0]) / S
75
+ qy = (R[1, 2] + R[2, 1]) / S
76
+ qz = 0.25 * S
77
+ return torch.stack([qw, qx, qy, qz], dim=1)
78
+
79
+ def rotate_quaternions(q, R):
80
+ # Convert quaternions to rotation matrices
81
+ q = torch.cat([q[:, :1], -q[:, 1:]], dim=1)
82
+ q = torch.cat([q[:, :3], q[:, 3:] * -1], dim=1)
83
+ rotated_R = torch.matmul(torch.matmul(q, R), q.inverse())
84
+
85
+ # Convert the rotated rotation matrices back to quaternions
86
+ return rotation_matrix_to_quaternion(rotated_R)
87
+
88
+ class WarmupScheduler(LRScheduler):
89
+ def __init__(self, optimizer, warmup_iters: int, max_iters: int, initial_lr: float = 1e-10, last_iter: int = -1):
90
+ self.warmup_iters = warmup_iters
91
+ self.max_iters = max_iters
92
+ self.initial_lr = initial_lr
93
+ super().__init__(optimizer, last_iter)
94
+
95
+ def get_lr(self):
96
+ return [
97
+ self.initial_lr + (base_lr - self.initial_lr) * min(self._step_count / self.warmup_iters, 1)
98
+ for base_lr in self.base_lrs]
99
+
100
+ # this function is borrowed from OpenLRM
101
+ class CosineWarmupScheduler(LRScheduler):
102
+ def __init__(self, optimizer, warmup_iters: int, max_iters: int, initial_lr: float = 1e-10, last_iter: int = -1):
103
+ self.warmup_iters = warmup_iters
104
+ self.max_iters = max_iters
105
+ self.initial_lr = initial_lr
106
+ super().__init__(optimizer, last_iter)
107
+
108
+ def get_lr(self):
109
+
110
+ if self._step_count <= self.warmup_iters:
111
+ return [
112
+ self.initial_lr + (base_lr - self.initial_lr) * self._step_count / self.warmup_iters
113
+ for base_lr in self.base_lrs]
114
+ else:
115
+ cos_iter = self._step_count - self.warmup_iters
116
+ cos_max_iter = self.max_iters - self.warmup_iters
117
+ cos_theta = cos_iter / cos_max_iter * math.pi
118
+ cos_lr = [base_lr * (1 + math.cos(cos_theta)) / 2 for base_lr in self.base_lrs]
119
+ return cos_lr
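`CosineWarmupScheduler` ramps the learning rate linearly for `warmup_iters` steps and then follows a cosine decay towards zero at `max_iters`. A minimal usage sketch with a throwaway parameter; the numbers are illustrative, and the scheduler is stepped once per optimizer step as in the Lightning config above.

```python
import torch
from pixel3dmm.lightning.utils import CosineWarmupScheduler

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=1e-4)
sched = CosineWarmupScheduler(opt, warmup_iters=10, max_iters=100)

lrs = []
for _ in range(100):
    opt.step()      # optimizer first, then scheduler (per-step schedule)
    sched.step()
    lrs.append(opt.param_groups[0]['lr'])

# lrs rises for ~10 steps, peaks near 1e-4, then decays towards 0
print(f'{max(lrs):.2e} -> {lrs[-1]:.2e}')
```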
src/pixel3dmm/preprocessing/__init__.py ADDED
File without changes
src/pixel3dmm/preprocessing/pipnet_utils.py ADDED
@@ -0,0 +1,348 @@
1
+ import importlib
2
+ import os
3
+ import torch.nn.parallel
4
+ import torch.utils.data
5
+ import torchvision.transforms as transforms
6
+
7
+
8
+ from pixel3dmm.preprocessing.PIPNet.FaceBoxesV2.faceboxes_detector import *
9
+ from pixel3dmm.preprocessing.PIPNet.lib.networks import *
10
+ from pixel3dmm.preprocessing.PIPNet.lib.functions import *
11
+ from pixel3dmm.preprocessing.PIPNet.lib.mobilenetv3 import mobilenetv3_large
12
+ from pixel3dmm import env_paths
13
+
14
+ def smooth(x, window_len=11, window='hanning'):
15
+ """smooth the data using a window with requested size.
16
+
17
+ This method is based on the convolution of a scaled window with the signal.
18
+ The signal is prepared by introducing reflected copies of the signal
19
+ (with the window size) in both ends so that transient parts are minimized
20
+ in the beginning and end part of the output signal.
21
+
22
+ input:
23
+ x: the input signal
24
+ window_len: the dimension of the smoothing window; should be an odd integer
25
+ window: the type of window from 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'
26
+ flat window will produce a moving average smoothing.
27
+
28
+ output:
29
+ the smoothed signal
30
+
31
+ example:
32
+
33
+ t=linspace(-2,2,0.1)
34
+ x=sin(t)+randn(len(t))*0.1
35
+ y=smooth(x)
36
+
37
+ see also:
38
+
39
+ numpy.hanning, numpy.hamming, numpy.bartlett, numpy.blackman, numpy.convolve
40
+ scipy.signal.lfilter
41
+
42
+ TODO: the window parameter could be the window itself if an array instead of a string
43
+ NOTE: length(output) != length(input), to correct this: return y[(window_len/2-1):-(window_len/2)] instead of just y.
44
+ """
45
+
46
+ if x.ndim != 1:
47
+ raise ValueError("smooth only accepts 1 dimension arrays.")
48
+
49
+ if x.size < window_len:
50
+ raise ValueError( "Input vector needs to be bigger than window size.")
51
+
52
+ if window_len < 3:
53
+ return x
54
+
55
+ if not window in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']:
56
+ raise ValueError("Window is one of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'")
57
+
58
+ s = np.r_[x[window_len - 1:0:-1], x, x[-2:-window_len - 1:-1]]
59
+ # print(len(s))
60
+ if window == 'flat': # moving average
61
+ w = np.ones(window_len, 'd')
62
+ else:
63
+ w = eval('np.' + window + '(window_len)')
64
+
65
+ y = np.convolve(w / w.sum(), s, mode='valid')
66
+ return y
67
+
68
+ def get_cstm_crop(image, detections, detections_max, max_bbox : bool = False):
69
+ #Image.fromarray(image).show()
70
+ image_width = image.shape[1]
71
+ image_height = image.shape[0]
72
+
73
+ det_box_scale = 1.42 #2.0#1.42
74
+ if detections[4]*1.42 * detections[5]*1.42 < detections_max[4] * 1.1 * detections_max[5] * 1.1:
75
+ detections = detections_max
76
+ det_box_scale = 1.1
77
+
78
+ det_xmin = detections[2]
79
+ det_ymin = detections[3]
80
+ det_width = detections[4]
81
+ det_height = detections[5]
82
+ if det_width > det_height:
83
+ det_ymin -= (det_width - det_height)//2
84
+ det_height = det_width
85
+ if det_width < det_height:
86
+ det_xmin -= (det_height - det_width)//2
87
+ det_width = det_height
88
+
89
+ det_xmax = det_xmin + det_width - 1
90
+ det_ymax = det_ymin + det_height - 1
91
+
92
+
93
+ det_xmin -= int(det_width * (det_box_scale - 1) / 2)
94
+ det_ymin -= int(det_height * (det_box_scale - 1) / 2)
95
+ det_xmax += int(det_width * (det_box_scale - 1) / 2)
96
+ det_ymax += int(det_height * (det_box_scale - 1) / 2)
97
+ if det_xmin < 0 or det_ymin < 0:
98
+ min_overflow = min(det_xmin, det_ymin)
99
+ det_xmin += -min_overflow
100
+ det_ymin += -min_overflow
101
+ if det_xmax > image_width -1 or det_ymax > image_height - 1:
102
+ max_overflow = max(det_xmax - image_width -1, det_ymax - image_height-1)
103
+ det_xmax -= max_overflow
104
+ det_ymax -= max_overflow
105
+
106
+ det_width = det_xmax - det_xmin + 1
107
+ det_height = det_ymax - det_ymin + 1
108
+ det_crop = image[det_ymin:det_ymax, det_xmin:det_xmax, :]
109
+ return det_crop, det_ymin, det_ymax, det_xmin, det_xmax
110
+ #Image.fromarray(det_crop).show()
111
+ #exit()
112
+
113
+
114
+ def demo_image(image_dir, pid, save_dir, preprocess, cfg, input_size, net_stride, num_nb, use_gpu, flip=False, start_frame=0,
115
+ vertical_crop : bool = False,
116
+ static_crop : bool = False,
117
+ max_bbox : bool = False,
118
+ disable_cropping : bool = False,
119
+ ):
120
+
121
+ if cfg.use_gpu:
122
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
123
+ else:
124
+ device = torch.device("cpu")
125
+
126
+ detector = FaceBoxesDetector('FaceBoxes', f'{env_paths.CODE_BASE}/src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/weights/FaceBoxesV2.pth', use_gpu, device)
127
+ my_thresh = 0.6
128
+ det_box_scale = 1.2
129
+ meanface_indices, reverse_index1, reverse_index2, max_len = get_meanface(
130
+ os.path.join(f'{env_paths.CODE_BASE}/src/pixel3dmm/preprocessing/', 'PIPNet', 'data', cfg.data_name, 'meanface.txt'), cfg.num_nb)
131
+
132
+ if cfg.backbone == 'resnet18':
133
+ resnet18 = models.resnet18(pretrained=cfg.pretrained)
134
+ net = Pip_resnet18(resnet18, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size,
135
+ net_stride=cfg.net_stride)
136
+ elif cfg.backbone == 'resnet50':
137
+ resnet50 = models.resnet50(pretrained=cfg.pretrained)
138
+ net = Pip_resnet50(resnet50, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size,
139
+ net_stride=cfg.net_stride)
140
+ elif cfg.backbone == 'resnet101':
141
+ resnet101 = models.resnet101(pretrained=cfg.pretrained)
142
+ net = Pip_resnet101(resnet101, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size,
143
+ net_stride=cfg.net_stride)
144
+ elif cfg.backbone == 'mobilenet_v2':
145
+ mbnet = models.mobilenet_v2(pretrained=cfg.pretrained)
146
+ net = Pip_mbnetv2(mbnet, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size, net_stride=cfg.net_stride)
147
+ elif cfg.backbone == 'mobilenet_v3':
148
+ mbnet = mobilenetv3_large()
149
+ if cfg.pretrained:
150
+ mbnet.load_state_dict(torch.load('lib/mobilenetv3-large-1cd25616.pth'))
151
+ net = Pip_mbnetv3(mbnet, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size, net_stride=cfg.net_stride)
152
+ else:
153
+ print('No such backbone!')
154
+ exit(0)
155
+
156
+
157
+ net = net.to(device)
158
+
159
+ weight_file = os.path.join(save_dir, 'epoch%d.pth' % (cfg.num_epochs - 1))
160
+ state_dict = torch.load(weight_file, map_location=device)
161
+ net.load_state_dict(state_dict)
162
+ net.eval()
163
+
164
+ if start_frame > 0:
165
+ files = [f for f in os.listdir(f'{image_dir}/') if (f.endswith('.jpg') or f.endswith('.png')) and (((int(f.split('_')[-1].split('.')[0]) - start_frame) % 3) == 0)]
166
+ else:
167
+ files = [f for f in os.listdir(f'{image_dir}/') if f.endswith('.jpg') or f.endswith('.png')]
168
+ files.sort()
169
+
170
+ if not vertical_crop:
171
+ all_detections = []
172
+ all_images = []
173
+ #all_normals = []
174
+ succ_files = []
175
+ for file_name in files:
176
+ image = cv2.imread(f'{image_dir}/{file_name}')
177
+ #normals = cv2.imread(f'{image_dir}/../normals/{file_name[:-4]}.png')
178
+
179
+ if len(image.shape) < 3 or image.shape[-1] != 3:
180
+ continue
181
+
182
+ image_height, image_width, _ = image.shape
183
+
184
+
185
+
186
+ detections, _ = detector.detect(image, my_thresh, 1)
187
+ dets_filtered = [det for det in detections if det[0] == 'face']
188
+ dets_filtered.sort(key=lambda x: -1 * x[1])
189
+ detections = dets_filtered
190
+ if detections[0][1] < 0.75:
191
+ raise ValueError("Found face with too low detections confidence as max confidence")
192
+ all_detections.append(detections[0])
193
+ all_images.append(image)
194
+ #all_normals.append(normals)
195
+ succ_files.append(file_name)
196
+
197
+ assert static_crop, 'Other options currently not supported anymore'
198
+ if static_crop:
199
+ #if max_bbox:
200
+ det1_max = np.min(np.array([x[2] for x in all_detections]), axis=0)
201
+ det2_max = np.min(np.array([x[3] for x in all_detections]), axis=0)
202
+ det3_max = np.max(np.array([x[4]+x[2]-det1_max for x in all_detections]), axis=0)
203
+ det4_max = np.max(np.array([x[5]+x[3]-det2_max for x in all_detections]), axis=0)
204
+ det1 = np.mean(np.array([x[2] for x in all_detections]), axis=0)
205
+ det2 = np.mean(np.array([x[3] for x in all_detections]), axis=0)
206
+ det3 = np.mean(np.array([x[4] for x in all_detections]), axis=0)
207
+ det4 = np.mean(np.array([x[5] for x in all_detections]), axis=0)
208
+
209
+ det_smoothed = np.stack([det1, det2, det3, det4], axis=0).astype(np.int32)
210
+ det_smoothed_max = np.stack([det1_max, det2_max, det3_max, det4_max], axis=0).astype(np.int32)
211
+ all_detections_smoothed = [] # = [[x[0], x[1], x_smoothed[0], x_smoothed[1], x_smoothed[2], x_smoothed[3]] for x, x_smoothed in zip()]
212
+ all_detections_max_smoothed = [] # = [[x[0], x[1], x_smoothed[0], x_smoothed[1], x_smoothed[2], x_smoothed[3]] for x, x_smoothed in zip()]
213
+ for i, det in enumerate(all_detections):
214
+ all_detections_smoothed.append(
215
+ [det[0], det[1], det_smoothed[0], det_smoothed[1], det_smoothed[2], det_smoothed[3]])
216
+ all_detections_max_smoothed.append(
217
+ [det[0], det[1], det_smoothed_max[0], det_smoothed_max[1], det_smoothed_max[2], det_smoothed_max[3]])
218
+ all_detections = all_detections_smoothed
219
+ all_detections_max = all_detections_max_smoothed
220
+ else:
221
+ if len(all_detections) > 11:
222
+ WINDOW_LENGTH = 11
223
+ det1 = smooth(np.array([x[2] for x in all_detections]), window_len=WINDOW_LENGTH)
224
+ det2 = smooth(np.array([x[3] for x in all_detections]), window_len=WINDOW_LENGTH)
225
+ det3 = smooth(np.array([x[4] for x in all_detections]), window_len=WINDOW_LENGTH)
226
+ det4 = smooth(np.array([x[5] for x in all_detections]), window_len=WINDOW_LENGTH)
227
+ det_smoothed = np.stack([det1, det2,det3,det4], axis=1).astype(np.int32)
228
+ all_detections_smoothed = [] #= [[x[0], x[1], x_smoothed[0], x_smoothed[1], x_smoothed[2], x_smoothed[3]] for x, x_smoothed in zip()]
229
+ for i, det in enumerate(all_detections):
230
+ all_detections_smoothed.append([det[0], det[1], det_smoothed[i, 0], det_smoothed[i, 1], det_smoothed[i, 2], det_smoothed[i, 3]])
231
+ all_detections = all_detections_smoothed
232
+ # TODO: smooth detections!!!
233
+ for file_name, detection, detection_max, image in zip(succ_files, all_detections, all_detections_max, all_images):
234
+
235
+ if not disable_cropping:
236
+ img_crop, det_ymin, det_ymax, det_xmin, det_xmax = get_cstm_crop(image, detection, detection_max, max_bbox=max_bbox)
237
+ #n_crop = get_cstm_crop(normals, detection)
238
+ image = img_crop
239
+ # save cropped image
240
+ os.makedirs(f'{image_dir}/../cropped/', exist_ok=True)
241
+ #os.makedirs(f'{image_dir}/../cropped_normals/', exist_ok=True)
242
+ cv2.imwrite(f'{image_dir}/../cropped/{file_name}', cv2.resize(image, (512, 512)))
243
+ #cv2.imwrite(f'{image_dir}/../cropped_normals/{file_name[:-4]}.png', cv2.resize(n_crop, (512, 512)))
244
+
245
+ # store cropping information:
246
+ if not os.path.exists(f'{image_dir}/../crop_ymin_ymax_xmin_xmax.npy'):
247
+ np.save(f'{image_dir}/../crop_ymin_ymax_xmin_xmax.npy', np.array([det_ymin, det_ymax, det_xmin, det_xmax]))
248
+ else:
249
+ for file_name in files:
250
+ image = cv2.imread(f'{image_dir}/{file_name}')
251
+ if image.shape[0] != image.shape[1]:
252
+ image = image[220:-220, 640:-640, :]
253
+ os.makedirs(f'{image_dir}/../cropped/', exist_ok=True)
254
+ cv2.imwrite(f'{image_dir}/../cropped/{file_name}', cv2.resize(image, (512, 512)))
255
+
256
+
257
+ # run landmark detection
258
+ lms = []
259
+ image_dir = f'{image_dir}/../cropped/'
260
+ for file_name in files:
261
+ image = cv2.imread(f'{image_dir}/{file_name}')
262
+
263
+ if len(image.shape) < 3 or image.shape[-1] != 3:
264
+ continue
265
+ if flip:
266
+ image = cv2.transpose(image)
267
+
268
+ image_height, image_width, _ = image.shape
269
+ detections, _ = detector.detect(image, my_thresh, 1)
270
+ pred_export = None
271
+ dets_filtered = [det for det in detections if det[0] == 'face']
272
+ dets_filtered.sort(key=lambda x: -1 * x[1])
273
+ detections = dets_filtered
274
+
275
+
276
+ print(detections)
277
+ for i in range(min(1, len(detections))):
278
+ if detections[i][1] < 0.99:
279
+ continue
280
+ det_xmin = detections[i][2]
281
+ det_ymin = detections[i][3]
282
+ det_width = detections[i][4]
283
+ det_height = detections[i][5]
284
+ det_xmax = det_xmin + det_width - 1
285
+ det_ymax = det_ymin + det_height - 1
286
+
287
+
288
+ det_xmin -= int(det_width * (det_box_scale - 1) / 2)
289
+ # remove a part of top area for alignment, see paper for details
290
+ det_ymin += int(det_height * (det_box_scale - 1) / 2)
291
+ det_xmax += int(det_width * (det_box_scale - 1) / 2)
292
+ det_ymax += int(det_height * (det_box_scale - 1) / 2)
293
+ det_xmin = max(det_xmin, 0)
294
+ det_ymin = max(det_ymin, 0)
295
+ det_xmax = min(det_xmax, image_width - 1)
296
+ det_ymax = min(det_ymax, image_height - 1)
297
+ det_width = det_xmax - det_xmin + 1
298
+ det_height = det_ymax - det_ymin + 1
299
+ cv2.rectangle(image, (det_xmin, det_ymin), (det_xmax, det_ymax), (0, 0, 255), 2)
300
+ det_crop = image[det_ymin:det_ymax, det_xmin:det_xmax, :]
301
+ #np.save(f'{CROP_DIR}/{pid[:-4]}.npy', np.array([det_ymin, det_ymax, det_xmin, det_xmax]))
302
+ det_crop = cv2.resize(det_crop, (input_size, input_size))
303
+ inputs = Image.fromarray(det_crop[:, :, ::-1].astype('uint8'), 'RGB')
304
+ #inputs.show()
305
+ inputs = preprocess(inputs).unsqueeze(0)
306
+ inputs = inputs.to(device)
307
+ lms_pred_x, lms_pred_y, lms_pred_nb_x, lms_pred_nb_y, outputs_cls, max_cls = forward_pip(net,
308
+ inputs,
309
+ preprocess,
310
+ input_size,
311
+ net_stride,
312
+ num_nb)
313
+ lms_pred = torch.cat((lms_pred_x, lms_pred_y), dim=1).flatten()
314
+ tmp_nb_x = lms_pred_nb_x[reverse_index1, reverse_index2].view(cfg.num_lms, max_len)
315
+ tmp_nb_y = lms_pred_nb_y[reverse_index1, reverse_index2].view(cfg.num_lms, max_len)
316
+ tmp_x = torch.mean(torch.cat((lms_pred_x, tmp_nb_x), dim=1), dim=1).view(-1, 1)
317
+ tmp_y = torch.mean(torch.cat((lms_pred_y, tmp_nb_y), dim=1), dim=1).view(-1, 1)
318
+ lms_pred_merge = torch.cat((tmp_x, tmp_y), dim=1).flatten()
319
+ lms_pred = lms_pred.cpu().numpy()
320
+ lms_pred_merge = lms_pred_merge.cpu().numpy()
321
+ pred_export = np.zeros([cfg.num_lms, 2])
322
+ for i in range(cfg.num_lms):
323
+ x_pred = lms_pred_merge[i * 2] * det_width
324
+ y_pred = lms_pred_merge[i * 2 + 1] * det_height
325
+ pred_export[i, 0] = (x_pred + det_xmin) / image_width
326
+ pred_export[i, 1] = (y_pred + det_ymin) / image_height
327
+ cv2.circle(image, (int(x_pred) + det_xmin, int(y_pred) + det_ymin), 1, (0, 0, 255), 2)
328
+ if i == 76:
329
+ cv2.circle(image, (int(x_pred) + det_xmin, int(y_pred) + det_ymin), 1, (255, 0, 0), 2)
330
+
331
+ if pred_export is not None:
332
+ print('exporting stuff to ' + image_dir)
333
+ landmark_dir = f'{image_dir}/../PIPnet_landmarks/'
334
+ os.makedirs(landmark_dir, exist_ok=True)
335
+ np.save(landmark_dir + f'/{file_name[:-4]}.npy', pred_export)
336
+ lms.append(pred_export)
337
+ exp_dir = image_dir + '/../PIPnet_annotated_images/'
338
+ os.makedirs(exp_dir, exist_ok=True)
339
+ cv2.imwrite(exp_dir + f'/{file_name}', image)
340
+
341
+ # cv2.imshow('1', image)
342
+ # cv2.waitKey(0)
343
+
344
+ lms = np.stack(lms, axis=0)
345
+ os.makedirs(f'{image_dir}/../pipnet', exist_ok=True)
346
+ np.save(f'{image_dir}/../pipnet/test.npy', lms)
347
+
348
+
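In its non-static branch, `demo_image` smooths the per-frame face-box coordinates with the hanning-window `smooth` helper (WINDOW_LENGTH=11) before cropping. Below is a standalone rerun of that smoothing on a noisy 1-D signal; note, as the docstring says, that the output is `window_len - 1` samples longer than the input.

```python
import numpy as np

def hanning_smooth(x: np.ndarray, window_len: int = 11) -> np.ndarray:
    # reflect-pad both ends, then convolve with a normalized hanning window
    s = np.r_[x[window_len - 1:0:-1], x, x[-2:-window_len - 1:-1]]
    w = np.hanning(window_len)
    return np.convolve(w / w.sum(), s, mode='valid')

x = np.sin(np.linspace(0, 4, 200)) + np.random.default_rng(0).normal(0, 0.1, 200)
y = hanning_smooth(x)
print(len(x), len(y))  # 200 210 -> output is window_len - 1 samples longer
```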
src/pixel3dmm/preprocessing/replacement_code/__init__.py ADDED
File without changes
src/pixel3dmm/preprocessing/replacement_code/facer_transform.py ADDED
@@ -0,0 +1,397 @@
1
+ from typing import List, Dict, Callable, Tuple, Optional
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import functools
5
+ import numpy as np
6
+
7
+
8
+ def get_crop_and_resize_matrix(
9
+ box: torch.Tensor, target_shape: Tuple[int, int],
10
+ target_face_scale: float = 1.0, make_square_crop: bool = True,
11
+ offset_xy: Optional[Tuple[float, float]] = None, align_corners: bool = True,
12
+ offset_box_coords: bool = False) -> torch.Tensor:
13
+ """
14
+ Args:
15
+ box: b x 4(x1, y1, x2, y2)
16
+ align_corners (bool): Set this to `True` only if the box you give has coordinates
17
+ ranging from `0` to `h-1` or `w-1`.
18
+
19
+ offset_box_coords (bool): Set this to `True` if the box you give has coordinates
20
+ ranging from `0` to `h` or `w`.
21
+
22
+ Set this to `False` if the box coordinates range from `-0.5` to `h-0.5` or `w-0.5`.
23
+
24
+ If the box coordinates range from `0` to `h-1` or `w-1`, set `align_corners=True`.
25
+
26
+ Returns:
27
+ torch.Tensor: b x 3 x 3.
28
+ """
29
+ if offset_xy is None:
30
+ offset_xy = (0.0, 0.0)
31
+
32
+ x1, y1, x2, y2 = box.split(1, dim=1) # b x 1
33
+ cx = (x1 + x2) / 2 + offset_xy[0]
34
+ cy = (y1 + y2) / 2 + offset_xy[1]
35
+ rx = (x2 - x1) / 2 / target_face_scale
36
+ ry = (y2 - y1) / 2 / target_face_scale
37
+ if make_square_crop:
38
+ rx = ry = torch.maximum(rx, ry)
39
+
40
+ x1, y1, x2, y2 = cx - rx, cy - ry, cx + rx, cy + ry
41
+
42
+ h, w, *_ = target_shape
43
+
44
+ zeros_pl = torch.zeros_like(x1)
45
+ ones_pl = torch.ones_like(x1)
46
+
47
+ if align_corners:
48
+ # x -> (x - x1) / (x2 - x1) * (w - 1)
49
+ # y -> (y - y1) / (y2 - y1) * (h - 1)
50
+ ax = 1.0 / (x2 - x1) * (w - 1)
51
+ ay = 1.0 / (y2 - y1) * (h - 1)
52
+ matrix = torch.cat([
53
+ ax, zeros_pl, -x1 * ax,
54
+ zeros_pl, ay, -y1 * ay,
55
+ zeros_pl, zeros_pl, ones_pl
56
+ ], dim=1).reshape(-1, 3, 3) # b x 3 x 3
57
+ else:
58
+ if offset_box_coords:
59
+ # x1, x2 \in [0, w], y1, y2 \in [0, h]
60
+ # first we should offset x1, x2, y1, y2 to be ranging in
61
+ # [-0.5, w-0.5] and [-0.5, h-0.5]
62
+ # so to convert these pixel coordinates into boundary coordinates.
63
+ x1, x2, y1, y2 = x1-0.5, x2-0.5, y1-0.5, y2-0.5
64
+
65
+ # x -> (x - x1) / (x2 - x1) * w - 0.5
66
+ # y -> (y - y1) / (y2 - y1) * h - 0.5
67
+ ax = 1.0 / (x2 - x1) * w
68
+ ay = 1.0 / (y2 - y1) * h
69
+ matrix = torch.cat([
70
+ ax, zeros_pl, -x1 * ax - 0.5*ones_pl,
71
+ zeros_pl, ay, -y1 * ay - 0.5*ones_pl,
72
+ zeros_pl, zeros_pl, ones_pl
73
+ ], dim=1).reshape(-1, 3, 3) # b x 3 x 3
74
+ return matrix
75
+
76
+
77
+ def get_similarity_transform_matrix(
78
+ from_pts: torch.Tensor, to_pts: torch.Tensor) -> torch.Tensor:
79
+ """
80
+ Args:
81
+ from_pts, to_pts: b x n x 2
82
+
83
+ Returns:
84
+ torch.Tensor: b x 3 x 3
85
+ """
86
+ mfrom = from_pts.mean(dim=1, keepdim=True) # b x 1 x 2
87
+ mto = to_pts.mean(dim=1, keepdim=True) # b x 1 x 2
88
+
89
+ a1 = (from_pts - mfrom).square().sum([1, 2], keepdim=False) # b
90
+ c1 = ((to_pts - mto) * (from_pts - mfrom)).sum([1, 2], keepdim=False) # b
91
+
92
+ to_delta = to_pts - mto
93
+ from_delta = from_pts - mfrom
94
+ c2 = (to_delta[:, :, 0] * from_delta[:, :, 1] - to_delta[:,
95
+ :, 1] * from_delta[:, :, 0]).sum([1], keepdim=False) # b
96
+
97
+ a = c1 / a1
98
+ b = c2 / a1
99
+ dx = mto[:, 0, 0] - a * mfrom[:, 0, 0] - b * mfrom[:, 0, 1] # b
100
+ dy = mto[:, 0, 1] + b * mfrom[:, 0, 0] - a * mfrom[:, 0, 1] # b
101
+
102
+ ones_pl = torch.ones_like(a1)
103
+ zeros_pl = torch.zeros_like(a1)
104
+
105
+ return torch.stack([
106
+ a, b, dx,
107
+ -b, a, dy,
108
+ zeros_pl, zeros_pl, ones_pl,
109
+ ], dim=-1).reshape(-1, 3, 3)
110
+
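+ # Sanity-check sketch (illustrative): the returned matrix, applied to `from_pts` in
+ # homogeneous coordinates, maps them onto `to_pts` when the two point sets are related
+ # by an exact similarity transform.
+ def _example_similarity_transform():
+     from_pts = torch.rand(2, 5, 2) * 100
+     angle = torch.tensor(0.3)
+     rot = torch.stack([torch.stack([torch.cos(angle), -torch.sin(angle)]),
+                        torch.stack([torch.sin(angle), torch.cos(angle)])])
+     to_pts = 2.0 * from_pts @ rot.T + torch.tensor([10.0, -5.0])
+     m = get_similarity_transform_matrix(from_pts, to_pts)  # b x 3 x 3
+     homo = torch.cat([from_pts, torch.ones_like(from_pts[..., :1])], dim=-1)
+     mapped = torch.bmm(homo, m.transpose(1, 2))[..., :2]
+     assert torch.allclose(mapped, to_pts, atol=1e-2)
+     return m
+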
111
+
112
+ @functools.lru_cache()
113
+ def _standard_face_pts():
114
+ pts = torch.tensor([
115
+ 196.0, 226.0,
116
+ 316.0, 226.0,
117
+ 256.0, 286.0,
118
+ 220.0, 360.4,
119
+ 292.0, 360.4], dtype=torch.float32) / 256.0 - 1.0
120
+ return torch.reshape(pts, (5, 2))
121
+
122
+
123
+ def get_face_align_matrix(
124
+ face_pts: torch.Tensor, target_shape: Tuple[int, int],
125
+ target_face_scale: float = 1.0, offset_xy: Optional[Tuple[float, float]] = None,
126
+ target_pts: Optional[torch.Tensor] = None):
127
+
128
+ if target_pts is None:
129
+ with torch.no_grad():
130
+ std_pts = _standard_face_pts().to(face_pts) # [-1 1]
131
+ h, w, *_ = target_shape
132
+ target_pts = (std_pts * target_face_scale + 1) * \
133
+ torch.tensor([w-1, h-1]).to(face_pts) / 2.0
134
+ if offset_xy is not None:
135
+ target_pts[:, 0] += offset_xy[0]
136
+ target_pts[:, 1] += offset_xy[1]
137
+ else:
138
+ target_pts = target_pts.to(face_pts)
139
+
140
+ if target_pts.dim() == 2:
141
+ target_pts = target_pts.unsqueeze(0)
142
+ if target_pts.size(0) == 1:
143
+ target_pts = target_pts.broadcast_to(face_pts.shape)
144
+
145
+ assert target_pts.shape == face_pts.shape
146
+
147
+ return get_similarity_transform_matrix(face_pts, target_pts)
148
+
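+ # Minimal usage sketch (illustrative; the landmark values are made-up placeholders):
+ # 5-point landmarks (left eye, right eye, nose tip, left/right mouth corner) in pixel
+ # coordinates to an alignment matrix for a 448 x 448 target crop.
+ def _example_face_align_matrix():
+     landmarks = torch.tensor([[[210.0, 180.0], [290.0, 178.0], [252.0, 230.0],
+                                [222.0, 280.0], [284.0, 282.0]]])  # b x 5 x 2
+     return get_face_align_matrix(landmarks, target_shape=(448, 448), target_face_scale=1.0)
+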
149
+
150
+ def rot90(v):
151
+ return np.array([-v[1], v[0]])
152
+
153
+
154
+ def get_quad(lm: torch.Tensor):
155
+ # N,2
156
+ lm = lm.detach().cpu().numpy()
157
+ # Choose oriented crop rectangle.
158
+ eye_avg = (lm[0] + lm[1]) * 0.5 + 0.5
159
+ mouth_avg = (lm[3] + lm[4]) * 0.5 + 0.5
160
+ eye_to_eye = lm[1] - lm[0]
161
+ eye_to_mouth = mouth_avg - eye_avg
162
+ x = eye_to_eye - rot90(eye_to_mouth)
163
+ x /= np.hypot(*x)
164
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
165
+ y = rot90(x)
166
+ c = eye_avg + eye_to_mouth * 0.1
167
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
168
+ quad_for_coeffs = quad[[0, 3, 2, 1]]  # reorder the corners
169
+ return torch.from_numpy(quad_for_coeffs).float()
170
+
171
+
172
+ def get_face_align_matrix_celebm(
173
+ face_pts: torch.Tensor, target_shape: Tuple[int, int], bbox_scale_factor: float = 1.0):
174
+
175
+ face_pts = torch.stack([get_quad(pts) for pts in face_pts], dim=0).to(face_pts)
176
+ face_mean = face_pts.mean(axis=1).unsqueeze(1)
177
+ diff = face_pts - face_mean
178
+ face_pts = face_mean + torch.tensor([[[1.5, 1.5]]], device=diff.device)*diff
179
+ assert target_shape[0] == target_shape[1]
180
+ diagonal = torch.norm(face_pts[:, 0, :] - face_pts[:, 2, :], dim=-1)
181
+ min_bbox_size = 350
182
+ max_bbox_size = 500
183
+ bbox_scale_factor = bbox_scale_factor + torch.clamp((max_bbox_size-diagonal)/(max_bbox_size-min_bbox_size), 0, 1)
184
+ # print(bbox_scale_factor)
185
+ target_size = target_shape[0]/bbox_scale_factor
186
+ #target_pts = torch.as_tensor([[0, 0], [target_size,0], [target_size, target_size], [0, target_size]]).to(face_pts)
187
+ target_ptss = []
188
+ for tidx in range(target_size.shape[0]):
189
+ target_pts = torch.as_tensor([[0, 0], [target_size[tidx],0], [target_size[tidx], target_size[tidx]], [0, target_size[tidx]]]).to(face_pts)
190
+ target_pts += int( (target_shape[0]-target_size[tidx])/2 )
191
+ target_ptss.append(target_pts)
192
+ target_pts = torch.stack(target_ptss, dim=0)
193
+
194
+ #if target_pts.dim() == 2:
195
+ # target_pts = target_pts.unsqueeze(0)
196
+ #if target_pts.size(0) == 1:
197
+ # target_pts = target_pts.broadcast_to(face_pts.shape)
198
+
199
+ assert target_pts.shape == face_pts.shape
200
+
201
+ return get_similarity_transform_matrix(face_pts, target_pts)
202
+
203
+ @functools.lru_cache(maxsize=128)
204
+ def _meshgrid(h, w) -> Tuple[torch.Tensor, torch.Tensor]:
205
+ yy, xx = torch.meshgrid(torch.arange(h).float(),
206
+ torch.arange(w).float(),
207
+ indexing='ij')
208
+ return yy, xx
209
+
210
+
211
+ def _forge_grid(batch_size: int, device: torch.device,
212
+ output_shape: Tuple[int, int],
213
+ fn: Callable[[torch.Tensor], torch.Tensor]
214
+ ) -> torch.Tensor:
215
+ """ Forge transform maps with a given function `fn`.
216
+
217
+ Args:
218
+ output_shape (tuple): (h, w, ...); only the leading (h, w) is used.
219
+ fn (Callable[[torch.Tensor], torch.Tensor]): The function that accepts
220
+ a bxnx2 array and outputs the transformed bxnx2 array. Both input
221
+ and output store (x, y) coordinates.
222
+
223
+ Note:
224
+ both input and output arrays of `fn` store (x, y) coordinates.
225
+
226
+ Returns:
227
+ torch.Tensor: a b x h x w x 2 map `out` such that, for each
228
+ pixel location (y, x),
229
+ `out[:, y, x] = fn([x, y])`.
230
+ """
231
+ h, w, *_ = output_shape
232
+ yy, xx = _meshgrid(h, w) # h x w
233
+ yy = yy.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)
234
+ xx = xx.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)
235
+
236
+ in_xxyy = torch.stack(
237
+ [xx, yy], dim=-1).reshape([batch_size, h*w, 2]) # (h x w) x 2
238
+ out_xxyy: torch.Tensor = fn(in_xxyy) # (h x w) x 2
239
+ return out_xxyy.reshape(batch_size, h, w, 2)
240
+
241
+
242
+ def _safe_arctanh(x: torch.Tensor, eps: float = 0.001) -> torch.Tensor:
243
+ return torch.clamp(x, -1+eps, 1-eps).arctanh()
244
+
245
+
246
+ def inverted_tanh_warp_transform(coords: torch.Tensor, matrix: torch.Tensor,
247
+ warp_factor: float, warped_shape: Tuple[int, int]):
248
+ """ Inverted tanh-warp function.
249
+
250
+ Args:
251
+ coords (torch.Tensor): b x n x 2 (x, y). The transformed coordinates.
252
+ matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
253
+ from the original image to the aligned yet not-warped image.
254
+ warp_factor (float): The warp factor.
255
+ 0 means linear transform, 1 means full tanh warp.
256
+ warped_shape (tuple): [height, width].
257
+
258
+ Returns:
259
+ torch.Tensor: b x n x 2 (x, y). The original coordinates.
260
+ """
261
+ h, w, *_ = warped_shape
262
+ # h -= 1
263
+ # w -= 1
264
+
265
+ w_h = torch.tensor([[w, h]]).to(coords)
266
+
267
+ if warp_factor > 0:
268
+ # normalize coordinates to [-1, +1]
269
+ coords = coords / w_h * 2 - 1
270
+
271
+ nl_part1 = coords > 1.0 - warp_factor
272
+ nl_part2 = coords < -1.0 + warp_factor
273
+
274
+ ret_nl_part1 = _safe_arctanh(
275
+ (coords - 1.0 + warp_factor) /
276
+ warp_factor) * warp_factor + \
277
+ 1.0 - warp_factor
278
+ ret_nl_part2 = _safe_arctanh(
279
+ (coords + 1.0 - warp_factor) /
280
+ warp_factor) * warp_factor - \
281
+ 1.0 + warp_factor
282
+
283
+ coords = torch.where(nl_part1, ret_nl_part1,
284
+ torch.where(nl_part2, ret_nl_part2, coords))
285
+
286
+ # denormalize
287
+ coords = (coords + 1) / 2 * w_h
288
+
289
+ coords_homo = torch.cat(
290
+ [coords, torch.ones_like(coords[:, :, [0]])], dim=-1) # b x n x 3
291
+
292
+ inv_matrix = torch.linalg.inv(matrix) # b x 3 x 3
293
+ # inv_matrix = np.linalg.inv(matrix)
294
+ coords_homo = torch.bmm(
295
+ coords_homo, inv_matrix.permute(0, 2, 1)) # b x n x 3
296
+ return coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]]
297
+
298
+
299
+ def tanh_warp_transform(
300
+ coords: torch.Tensor, matrix: torch.Tensor,
301
+ warp_factor: float, warped_shape: Tuple[int, int]):
302
+ """ Tanh-warp function.
303
+
304
+ Args:
305
+ coords (torch.Tensor): b x n x 2 (x, y). The original coordinates.
306
+ matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
307
+ from the original image to the aligned yet not-warped image.
308
+ warp_factor (float): The warp factor.
309
+ 0 means linear transform, 1 means full tanh warp.
310
+ warped_shape (tuple): [height, width].
311
+
312
+ Returns:
313
+ torch.Tensor: b x n x 2 (x, y). The transformed coordinates.
314
+ """
315
+ h, w, *_ = warped_shape
316
+ # h -= 1
317
+ # w -= 1
318
+ w_h = torch.tensor([[w, h]]).to(coords)
319
+
320
+ coords_homo = torch.cat(
321
+ [coords, torch.ones_like(coords[:, :, [0]])], dim=-1) # b x n x 3
322
+
323
+ coords_homo = torch.bmm(coords_homo, matrix.transpose(2, 1)) # b x n x 3
324
+ coords = (coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]]) # b x n x 2
325
+
326
+ if warp_factor > 0:
327
+ # normalize coordinates to [-1, +1]
328
+ coords = coords / w_h * 2 - 1
329
+
330
+ nl_part1 = coords > 1.0 - warp_factor
331
+ nl_part2 = coords < -1.0 + warp_factor
332
+
333
+ ret_nl_part1 = torch.tanh(
334
+ (coords - 1.0 + warp_factor) /
335
+ warp_factor) * warp_factor + \
336
+ 1.0 - warp_factor
337
+ ret_nl_part2 = torch.tanh(
338
+ (coords + 1.0 - warp_factor) /
339
+ warp_factor) * warp_factor - \
340
+ 1.0 + warp_factor
341
+
342
+ coords = torch.where(nl_part1, ret_nl_part1,
343
+ torch.where(nl_part2, ret_nl_part2, coords))
344
+
345
+ # denormalize
346
+ coords = (coords + 1) / 2 * w_h
347
+
348
+ return coords
349
+
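+ # Consistency sketch (illustrative): `tanh_warp_transform` and `inverted_tanh_warp_transform`
+ # invert each other for the same matrix, warp factor and warped shape.
+ def _example_tanh_warp_round_trip():
+     matrix = torch.eye(3).unsqueeze(0)        # identity alignment, b = 1
+     coords = torch.rand(1, 16, 2) * 400.0     # original (x, y) pixel coordinates
+     warped = tanh_warp_transform(coords, matrix, warp_factor=0.8, warped_shape=(448, 448))
+     restored = inverted_tanh_warp_transform(warped, matrix, warp_factor=0.8,
+                                             warped_shape=(448, 448))
+     assert torch.allclose(restored, coords, atol=1e-2)
+     return warped
+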
350
+
351
+ def make_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
352
+ warped_shape: Tuple[int, int],
353
+ orig_shape: Tuple[int, int]):
354
+ """
355
+ Args:
356
+ matrix: bx3x3 matrix.
357
+ warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla Tanh-warping,
358
+ `warp_factor=0.0` represents a cropping.
359
+ warped_shape: The target image shape to transform to.
360
+
361
+ Returns:
362
+ torch.Tensor: b x h x w x 2 (x, y).
363
+ """
364
+ orig_h, orig_w, *_ = orig_shape
365
+ w_h = torch.tensor([orig_w, orig_h]).to(matrix).reshape(1, 1, 1, 2)
366
+ return _forge_grid(
367
+ matrix.size(0), matrix.device,
368
+ warped_shape,
369
+ functools.partial(inverted_tanh_warp_transform,
370
+ matrix=matrix,
371
+ warp_factor=warp_factor,
372
+ warped_shape=warped_shape)) / w_h*2-1
373
+
374
+
375
+ def make_inverted_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
376
+ warped_shape: Tuple[int, int],
377
+ orig_shape: Tuple[int, int]):
378
+ """
379
+ Args:
380
+ matrix: bx3x3 matrix.
381
+ warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla Tanh-warping,
382
+ `warp_factor=0.0` represents a cropping.
383
+ warped_shape: The target image shape to transform to.
384
+ orig_shape: The original image shape that is transformed from.
385
+
386
+ Returns:
387
+ torch.Tensor: b x h x w x 2 (x, y).
388
+ """
389
+ h, w, *_ = warped_shape
390
+ w_h = torch.tensor([w, h]).to(matrix).reshape(1, 1, 1, 2)
391
+ return _forge_grid(
392
+ matrix.size(0), matrix.device,
393
+ orig_shape,
394
+ functools.partial(tanh_warp_transform,
395
+ matrix=matrix,
396
+ warp_factor=warp_factor,
397
+ warped_shape=warped_shape)) / w_h * 2-1
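+
+ # End-to-end warping sketch (illustrative; `images` and `landmarks` are placeholders for
+ # your own data): align a face with a similarity matrix, then tanh-warp it into a
+ # 448 x 448 crop, similar to how the FaRL face parser consumes these helpers.
+ def _example_warp_face_crop(images: torch.Tensor, landmarks: torch.Tensor) -> torch.Tensor:
+     # images: b x 3 x H x W float tensor, landmarks: b x 5 x 2 in pixel coordinates
+     _, _, H, W = images.shape
+     matrix = get_face_align_matrix(landmarks, target_shape=(448, 448), target_face_scale=1.0)
+     grid = make_tanh_warp_grid(matrix, warp_factor=0.8, warped_shape=(448, 448),
+                                orig_shape=(H, W))
+     return F.grid_sample(images, grid, mode='bilinear', align_corners=False)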
src/pixel3dmm/preprocessing/replacement_code/farl.py ADDED
@@ -0,0 +1,94 @@
1
+ from typing import Optional, Dict, Any
2
+ import functools
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from ..util import download_jit
7
+ from ..transform import (get_crop_and_resize_matrix, get_face_align_matrix, get_face_align_matrix_celebm,
8
+ make_inverted_tanh_warp_grid, make_tanh_warp_grid)
9
+ from .base import FaceParser
10
+
11
+ pretrain_settings = {
12
+ 'lapa/448': {
13
+ 'url': [
14
+ 'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.lapa.main_ema_136500_jit191.pt',
15
+ ],
16
+ 'matrix_src_tag': 'points',
17
+ 'get_matrix_fn': functools.partial(get_face_align_matrix,
18
+ target_shape=(448, 448), target_face_scale=1.0),
19
+ 'get_grid_fn': functools.partial(make_tanh_warp_grid,
20
+ warp_factor=0.8, warped_shape=(448, 448)),
21
+ 'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
22
+ warp_factor=0.8, warped_shape=(448, 448)),
23
+ 'label_names': ['background', 'face', 'rb', 'lb', 're',
24
+ 'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
25
+ },
26
+ 'celebm/448': {
27
+ 'url': [
28
+ 'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.celebm.main_ema_181500_jit.pt',
29
+ ],
30
+ 'matrix_src_tag': 'points',
31
+ 'get_matrix_fn': functools.partial(get_face_align_matrix_celebm,
32
+ target_shape=(448, 448)),
33
+ 'get_grid_fn': functools.partial(make_tanh_warp_grid,
34
+ warp_factor=0, warped_shape=(448, 448)),
35
+ 'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
36
+ warp_factor=0, warped_shape=(448, 448)),
37
+ 'label_names': [
38
+ 'background', 'neck', 'face', 'cloth', 'rr', 'lr', 'rb', 'lb', 're',
39
+ 'le', 'nose', 'imouth', 'llip', 'ulip', 'hair',
40
+ 'eyeg', 'hat', 'earr', 'neck_l']
41
+ }
42
+ }
43
+
44
+
45
+ class FaRLFaceParser(FaceParser):
46
+ """ The face parsing models from [FaRL](https://github.com/FacePerceiver/FaRL).
47
+
48
+ Please consider citing
49
+ ```bibtex
50
+ @article{zheng2021farl,
51
+ title={General Facial Representation Learning in a Visual-Linguistic Manner},
52
+ author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen,
53
+ Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
54
+ Dong and Zeng, Ming and Wen, Fang},
55
+ journal={arXiv preprint arXiv:2112.03109},
56
+ year={2021}
57
+ }
58
+ ```
59
+ """
60
+
61
+ def __init__(self, conf_name: Optional[str] = None,
62
+ model_path: Optional[str] = None, device=None) -> None:
63
+ super().__init__()
64
+ if conf_name is None:
65
+ conf_name = 'lapa/448'
66
+ if model_path is None:
67
+ model_path = pretrain_settings[conf_name]['url']
68
+ self.conf_name = conf_name
69
+ self.net = download_jit(model_path, map_location=device)
70
+ self.eval()
71
+
72
+ def forward(self, images: torch.Tensor, data: Dict[str, Any], bbox_scale_factor : float = 1.0):
73
+ setting = pretrain_settings[self.conf_name]
74
+ images = images.float() / 255.0
75
+ _, _, h, w = images.shape
76
+
77
+ simages = images[data['image_ids']]
78
+ matrix_fun = functools.partial(get_face_align_matrix_celebm,
79
+ target_shape=(448, 448), bbox_scale_factor=bbox_scale_factor)
80
+ matrix = matrix_fun(data[setting['matrix_src_tag']])
81
+ grid = setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
82
+ inv_grid = setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))
83
+
84
+ w_images = F.grid_sample(
85
+ simages, grid, mode='bilinear', align_corners=False)
86
+
87
+ w_seg_logits, _ = self.net(w_images) # (b*n) x c x h x w
88
+
89
+ seg_logits = F.grid_sample(w_seg_logits, inv_grid, mode='bilinear', align_corners=False)
90
+
91
+
92
+ data['seg'] = {'logits': seg_logits,
93
+ 'label_names': setting['label_names']}
94
+ return data
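+
+ # Minimal usage sketch (illustrative; `images` is a b x 3 x h x w uint8 tensor and `data`
+ # holds 'image_ids' (one image index per detected face) and 'points' (5-point landmarks
+ # per face), as produced by an upstream face detector):
+ def _example_farl_parsing(images, data, device='cuda'):
+     parser = FaRLFaceParser(conf_name='celebm/448', device=device)
+     with torch.no_grad():
+         data = parser(images.to(device), data, bbox_scale_factor=1.0)
+     seg_probs = data['seg']['logits'].softmax(dim=1)  # per-pixel class probabilities
+     return seg_probs, data['seg']['label_names']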
src/pixel3dmm/preprocessing/replacement_code/mica.py ADDED
@@ -0,0 +1,120 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2023 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: [email protected]
16
+
17
+
18
+ import os
19
+ import sys
20
+
21
+ sys.path.append("./nfclib")
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+
26
+ from models.arcface import Arcface
27
+ from models.generator import Generator
28
+ from micalib.base_model import BaseModel
29
+
30
+ from loguru import logger
31
+
32
+
33
+ class MICA(BaseModel):
34
+ def __init__(self, config=None, device=None, tag='MICA'):
35
+ super(MICA, self).__init__(config, device, tag)
36
+
37
+ self.initialize()
38
+
39
+ def create_model(self, model_cfg):
40
+ mapping_layers = model_cfg.mapping_layers
41
+ pretrained_path = None
42
+ if not model_cfg.use_pretrained:
43
+ pretrained_path = model_cfg.arcface_pretrained_model
44
+ self.arcface = Arcface(pretrained_path=pretrained_path).to(self.device)
45
+ self.flameModel = Generator(512, 300, self.cfg.model.n_shape, mapping_layers, model_cfg, self.device)
46
+
47
+ def load_model(self):
48
+ model_path = os.path.join(self.cfg.output_dir, 'model.tar')
49
+ if os.path.exists(self.cfg.pretrained_model_path) and self.cfg.model.use_pretrained:
50
+ model_path = self.cfg.pretrained_model_path
51
+ if os.path.exists(model_path):
52
+ logger.info(f'[{self.tag}] Trained model found. Path: {model_path} | GPU: {self.device}')
53
+ checkpoint = torch.load(model_path, weights_only=False)
54
+ if 'arcface' in checkpoint:
55
+ self.arcface.load_state_dict(checkpoint['arcface'])
56
+ if 'flameModel' in checkpoint:
57
+ self.flameModel.load_state_dict(checkpoint['flameModel'])
58
+ else:
59
+ logger.info(f'[{self.tag}] Checkpoint not available, starting from scratch!')
60
+
61
+ def model_dict(self):
62
+ return {
63
+ 'flameModel': self.flameModel.state_dict(),
64
+ 'arcface': self.arcface.state_dict()
65
+ }
66
+
67
+ def parameters_to_optimize(self):
68
+ return [
69
+ {'params': self.flameModel.parameters(), 'lr': self.cfg.train.lr},
70
+ {'params': self.arcface.parameters(), 'lr': self.cfg.train.arcface_lr},
71
+ ]
72
+
73
+ def encode(self, images, arcface_imgs):
74
+ codedict = {}
75
+
76
+ codedict['arcface'] = F.normalize(self.arcface(arcface_imgs))
77
+ codedict['images'] = images
78
+
79
+ return codedict
80
+
81
+ def decode(self, codedict, epoch=0):
82
+ self.epoch = epoch
83
+
84
+ flame_verts_shape = None
85
+ shapecode = None
86
+
87
+ if not self.testing:
88
+ flame = codedict['flame']
89
+ shapecode = flame['shape_params'].view(-1, flame['shape_params'].shape[2])
90
+ shapecode = shapecode.to(self.device)[:, :self.cfg.model.n_shape]
91
+ with torch.no_grad():
92
+ flame_verts_shape, _, _ = self.flame(shape_params=shapecode)
93
+
94
+ identity_code = codedict['arcface']
95
+ pred_canonical_vertices, pred_shape_code = self.flameModel(identity_code)
96
+
97
+ output = {
98
+ 'flame_verts_shape': flame_verts_shape,
99
+ 'flame_shape_code': shapecode,
100
+ 'pred_canonical_shape_vertices': pred_canonical_vertices,
101
+ 'pred_shape_code': pred_shape_code,
102
+ 'faceid': codedict['arcface']
103
+ }
104
+
105
+ return output
106
+
107
+ def compute_losses(self, input, encoder_output, decoder_output):
108
+ losses = {}
109
+
110
+ pred_verts = decoder_output['pred_canonical_shape_vertices']
111
+ gt_verts = decoder_output['flame_verts_shape'].detach()
112
+
113
+ pred_verts_shape_canonical_diff = (pred_verts - gt_verts).abs()
114
+
115
+ if self.use_mask:
116
+ pred_verts_shape_canonical_diff *= self.vertices_mask
117
+
118
+ losses['pred_verts_shape_canonical_diff'] = torch.mean(pred_verts_shape_canonical_diff) * 1000.0
119
+
120
+ return losses
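+
+ # Minimal inference sketch (illustrative; assumes `mica` is an already-constructed MICA
+ # instance with cfg.model.testing = True and a loaded checkpoint, and that `images` /
+ # `arcface` are preprocessed tensors as produced by the demo pipeline):
+ def _example_mica_inference(mica, images, arcface):
+     with torch.no_grad():
+         codedict = mica.encode(images, arcface)   # ArcFace identity embedding + input images
+         opdict = mica.decode(codedict)            # canonical FLAME vertices + shape code
+     return opdict['pred_canonical_shape_vertices'], opdict['pred_shape_code']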
src/pixel3dmm/preprocessing/replacement_code/mica_demo.py ADDED
@@ -0,0 +1,188 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2023 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: [email protected]
16
+
17
+
18
+ import argparse
19
+ import os
20
+ import random
21
+ import traceback
22
+ from glob import glob
23
+ from pathlib import Path
24
+ from PIL import Image
25
+
26
+ import cv2
27
+ import numpy as np
28
+ import torch
29
+ import torch.backends.cudnn as cudnn
30
+ import trimesh
31
+ from insightface.app.common import Face
32
+ from insightface.utils import face_align
33
+ from loguru import logger
34
+ from skimage.io import imread
35
+ from tqdm import tqdm
36
+ #from retinaface.pre_trained_models import get_model
37
+ #from retinaface.utils import vis_annotations
38
+ #from matplotlib import pyplot as plt
39
+
40
+
41
+ from pixel3dmm.preprocessing.MICA.configs.config import get_cfg_defaults
42
+ from pixel3dmm.preprocessing.MICA.datasets.creation.util import get_arcface_input, get_center, draw_on
43
+ from pixel3dmm.preprocessing.MICA.utils import util
44
+ from pixel3dmm.preprocessing.MICA.utils.landmark_detector import LandmarksDetector, detectors
45
+ from pixel3dmm import env_paths
46
+
47
+
48
+ #model = get_model("resnet50_2020-07-20", max_size=512)
49
+ #model.eval()
50
+
51
+
52
+ def deterministic(rank):
53
+ torch.manual_seed(rank)
54
+ torch.cuda.manual_seed(rank)
55
+ np.random.seed(rank)
56
+ random.seed(rank)
57
+
58
+ cudnn.deterministic = True
59
+ cudnn.benchmark = False
60
+
61
+
62
+ def process(args, app, image_size=224, draw_bbox=False):
63
+ dst = Path(args.a)
64
+ dst.mkdir(parents=True, exist_ok=True)
65
+ processes = []
66
+ image_paths = sorted(glob(args.i + '/*.*'))#[:1]
67
+ image_paths = image_paths[::max(1, len(image_paths)//10)]
68
+ for image_path in tqdm(image_paths):
69
+ name = Path(image_path).stem
70
+ img = cv2.imread(image_path)
71
+
72
+
73
+ # FOR pytorch retinaface use this: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
74
+ # I had issues with onnxruntime!
75
+ bboxes, kpss = app.detect(img)
76
+
77
+ #annotation = model.predict_jsons(img)
78
+ #Image.fromarray(vis_annotations(img, annotation)).show()
79
+
80
+ #bboxes = np.stack([np.array( annotation[0]['bbox'] + [annotation[0]['score']] ) for i in range(len(annotation))], axis=0)
81
+ #kpss = np.stack([np.array( annotation[0]['landmarks'] ) for i in range(len(annotation))], axis=0)
82
+ if bboxes.shape[0] == 0:
83
+ logger.error(f'[ERROR] Face not detected for {image_path}')
84
+ continue
85
+ i = get_center(bboxes, img)
86
+ bbox = bboxes[i, 0:4]
87
+ det_score = bboxes[i, 4]
88
+ kps = None
89
+ if kpss is not None:
90
+ kps = kpss[i]
91
+
92
+ ##for ikp in range(kps.shape[0]):
93
+ # img[int(kps[ikp][1]), int(kps[ikp][0]), 0] = 255
94
+ # img[int(kpss_[0][ikp][1]), int(kpss_[0][ikp][0]), 1] = 255
95
+ #Image.fromarray(img).show()
96
+ face = Face(bbox=bbox, kps=kps, det_score=det_score)
97
+ blob, aimg = get_arcface_input(face, img)
98
+ file = str(Path(dst, name))
99
+ np.save(file, blob)
100
+ processes.append(file + '.npy')
101
+ cv2.imwrite(file + '.jpg', face_align.norm_crop(img, landmark=face.kps, image_size=image_size))
102
+ if draw_bbox:
103
+ dimg = draw_on(img, [face])
104
+ cv2.imwrite(file + '_bbox.jpg', dimg)
105
+
106
+ return processes
107
+
108
+
109
+ def to_batch(path):
110
+ src = path.replace('npy', 'jpg')
111
+ if not os.path.exists(src):
112
+ src = path.replace('npy', 'png')
113
+
114
+ image = imread(src)[:, :, :3]
115
+ image = image / 255.
116
+ image = cv2.resize(image, (224, 224)).transpose(2, 0, 1)
117
+ image = torch.tensor(image).cuda()[None]
118
+
119
+ arcface = np.load(path)
120
+ arcface = torch.tensor(arcface).cuda()[None]
121
+
122
+ return image, arcface
123
+
124
+
125
+ def load_checkpoint(args, mica):
126
+ checkpoint = torch.load(args.m, weights_only=False)
127
+ if 'arcface' in checkpoint:
128
+ mica.arcface.load_state_dict(checkpoint['arcface'])
129
+ if 'flameModel' in checkpoint:
130
+ mica.flameModel.load_state_dict(checkpoint['flameModel'])
131
+
132
+
133
+ def main(cfg, args):
134
+ device = 'cuda:0'
135
+ cfg.model.testing = True
136
+ mica = util.find_model_using_name(model_dir='micalib.models', model_name=cfg.model.name)(cfg, device)
137
+ load_checkpoint(args, mica)
138
+ mica.eval()
139
+
140
+ faces = mica.flameModel.generator.faces_tensor.cpu()
141
+ Path(args.o).mkdir(exist_ok=True, parents=True)
142
+
143
+ app = LandmarksDetector(model=detectors.RETINAFACE)
144
+
145
+ with torch.no_grad():
146
+ logger.info(f'Processing has started...')
147
+ paths = process(args, app, draw_bbox=False)
148
+ for path in tqdm(paths):
149
+ name = Path(path).stem
150
+ images, arcface = to_batch(path)
151
+ codedict = mica.encode(images, arcface)
152
+ opdict = mica.decode(codedict)
153
+ meshes = opdict['pred_canonical_shape_vertices']
154
+ code = opdict['pred_shape_code']
155
+ lmk = mica.flameModel.generator.compute_landmarks(meshes)
156
+
157
+ mesh = meshes[0]
158
+ landmark_51 = lmk[0, 17:]
159
+ landmark_7 = landmark_51[[19, 22, 25, 28, 16, 31, 37]]
160
+
161
+ dst = Path(args.o, name)
162
+ dst.mkdir(parents=True, exist_ok=True)
163
+ trimesh.Trimesh(vertices=mesh.cpu() * 1000.0, faces=faces, process=False).export(f'{dst}/mesh.ply') # save in millimeters
164
+ trimesh.Trimesh(vertices=mesh.cpu() * 1000.0, faces=faces, process=False).export(f'{dst}/mesh.obj')
165
+ np.save(f'{dst}/identity', code[0].cpu().numpy())
166
+ np.save(f'{dst}/kpt7', landmark_7.cpu().numpy() * 1000.0)
167
+ np.save(f'{dst}/kpt68', lmk.cpu().numpy() * 1000.0)
168
+
169
+ logger.info(f'Processing finished. Results have been saved in {args.o}')
170
+
171
+
172
+ if __name__ == '__main__':
173
+ parser = argparse.ArgumentParser(description='MICA - Towards Metrical Reconstruction of Human Faces')
174
+ parser.add_argument('-video_name', required=True, type=str)
175
+ parser.add_argument('-a', default='demo/arcface', type=str, help='Processed images for MICA input')
176
+ parser.add_argument('-m', default='data/pretrained/mica.tar', type=str, help='Pretrained model path')
177
+
178
+ args = parser.parse_args()
179
+ cfg = get_cfg_defaults()
180
+ args.i = f'{env_paths.PREPROCESSED_DATA}/{args.video_name}/cropped/'
181
+ args.o = f'{env_paths.PREPROCESSED_DATA}/{args.video_name}/mica/'
182
+ if os.path.exists(f'{env_paths.PREPROCESSED_DATA}/{args.video_name}/mica/'):
183
+ if len(os.listdir(f'{env_paths.PREPROCESSED_DATA}/{args.video_name}/mica/')) >= 10:
184
+ print(f'''
185
+ <<<<<<<< ALREADY COMPLETE MICA PREDICTION FOR {args.video_name}, SKIPPING >>>>>>>>
186
+ ''')
187
+ exit()
188
+ main(cfg, args)
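+
+ # Programmatic invocation sketch (illustrative; the video name, checkpoint path and the
+ # 'demo/arcface' staging directory are placeholders for your own setup):
+ def _example_run_mica_demo(video_name='ex1', checkpoint='data/pretrained/mica.tar'):
+     cfg = get_cfg_defaults()
+     args = argparse.Namespace(video_name=video_name, a='demo/arcface', m=checkpoint)
+     args.i = f'{env_paths.PREPROCESSED_DATA}/{video_name}/cropped/'
+     args.o = f'{env_paths.PREPROCESSED_DATA}/{video_name}/mica/'
+     main(cfg, args)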
src/pixel3dmm/preprocessing/replacement_code/pipnet_demo.py ADDED
@@ -0,0 +1,401 @@
1
+ import traceback
2
+
3
+ import cv2, os
4
+ import sys
5
+ sys.path.insert(0, 'FaceBoxesV2')
6
+ sys.path.insert(0, '../..')
7
+ import numpy as np
8
+ import pickle
9
+ import importlib
10
+ from math import floor
11
+ from faceboxes_detector import *
12
+ import time
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.parallel
17
+ import torch.optim as optim
18
+ import torch.utils.data
19
+ import torch.nn.functional as F
20
+ import torchvision.transforms as transforms
21
+ import torchvision.datasets as datasets
22
+ import torchvision.models as models
23
+
24
+ from networks import *
25
+ import data_utils
26
+ from functions import *
27
+ from mobilenetv3 import mobilenetv3_large
28
+
29
+
30
+ def smooth(x, window_len=11, window='hanning'):
31
+ """smooth the data using a window with requested size.
32
+
33
+ This method is based on the convolution of a scaled window with the signal.
34
+ The signal is prepared by introducing reflected copies of the signal
35
+ (with the window size) in both ends so that transient parts are minimized
36
+ in the beginning and end parts of the output signal.
37
+
38
+ input:
39
+ x: the input signal
40
+ window_len: the dimension of the smoothing window; should be an odd integer
41
+ window: the type of window from 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'
42
+ flat window will produce a moving average smoothing.
43
+
44
+ output:
45
+ the smoothed signal
46
+
47
+ example:
48
+
49
+ t = np.linspace(-2, 2, 50)
51
+ x = np.sin(t) + np.random.randn(len(t)) * 0.1
52
+ y = smooth(x)
52
+
53
+ see also:
54
+
55
+ numpy.hanning, numpy.hamming, numpy.bartlett, numpy.blackman, numpy.convolve
56
+ scipy.signal.lfilter
57
+
58
+ TODO: the window parameter could be the window itself if an array instead of a string
59
+ NOTE: length(output) != length(input), to correct this: return y[(window_len/2-1):-(window_len/2)] instead of just y.
60
+ """
61
+
62
+ if x.ndim != 1:
63
+ raise ValueError("smooth only accepts 1 dimension arrays.")
64
+
65
+ if x.size < window_len:
66
+ raise ValueError( "Input vector needs to be bigger than window size.")
67
+
68
+ if window_len < 3:
69
+ return x
70
+
71
+ if window not in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']:
72
+ raise ValueError("Window must be one of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'")
73
+
74
+ s = np.r_[x[window_len - 1:0:-1], x, x[-2:-window_len - 1:-1]]
75
+ # print(len(s))
76
+ if window == 'flat': # moving average
77
+ w = np.ones(window_len, 'd')
78
+ else:
79
+ w = getattr(np, window)(window_len)  # e.g. np.hanning(window_len)
80
+
81
+ y = np.convolve(w / w.sum(), s, mode='valid')
82
+ return y
83
+
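+ # Quick usage sketch for `smooth` (illustrative; mirrors the docstring example):
+ def _example_smooth():
+     t = np.linspace(-2, 2, 200)
+     noisy = np.sin(t) + np.random.randn(len(t)) * 0.1
+     smoothed = smooth(noisy, window_len=11, window='hanning')
+     return smoothed  # note: len(smoothed) == len(noisy) + window_len - 1
+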
84
+ if not len(sys.argv) == 3:
85
+ print('Format:')
86
+ print('python lib/demo.py config_file image_file')
87
+ exit(0)
88
+
89
+
90
+ experiment_name = sys.argv[1].split('/')[-1][:-3]
91
+ data_name = sys.argv[1].split('/')[-2]
92
+ config_path = '.experiments.{}.{}'.format(data_name, experiment_name)
93
+
94
+ def get_cstm_crop(image, detections):
95
+ #Image.fromarray(image).show()
96
+ image_width = image.shape[1]
97
+ image_height = image.shape[0]
98
+ det_box_scale = 1.42 #2.0#1.42
99
+ det_xmin = detections[2]
100
+ det_ymin = detections[3]
101
+ det_width = detections[4]
102
+ det_height = detections[5]
103
+ if det_width > det_height:
104
+ det_ymin -= (det_width - det_height)//2
105
+ det_height = det_width
106
+ if det_width < det_height:
107
+ det_xmin -= (det_height - det_width)//2
108
+ det_width = det_height
109
+
110
+ det_xmax = det_xmin + det_width - 1
111
+ det_ymax = det_ymin + det_height - 1
112
+
113
+
114
+ det_xmin -= int(det_width * (det_box_scale - 1) / 2)
115
+ det_ymin -= int(det_height * (det_box_scale - 1) / 2)
116
+ det_xmax += int(det_width * (det_box_scale - 1) / 2)
117
+ det_ymax += int(det_height * (det_box_scale - 1) / 2)
118
+ if det_xmin < 0 or det_ymin < 0:
119
+ min_overflow = min(det_xmin, det_ymin)
120
+ det_xmin += -min_overflow
121
+ det_ymin += -min_overflow
122
+ if det_xmax > image_width -1 or det_ymax > image_height - 1:
123
+ max_overflow = max(det_xmax - (image_width - 1), det_ymax - (image_height - 1))
124
+ det_xmax -= max_overflow
125
+ det_ymax -= max_overflow
126
+
127
+ det_width = det_xmax - det_xmin + 1
128
+ det_height = det_ymax - det_ymin + 1
129
+ det_crop = image[det_ymin:det_ymax, det_xmin:det_xmax, :]
130
+ return det_crop
131
+ #Image.fromarray(det_crop).show()
132
+ #exit()
133
+
134
+ def demo_image(image_dir, pid, cam_dir, net, preprocess, cfg, input_size, net_stride, num_nb, use_gpu, device, flip=False, start_frame=0,
135
+ vertical_crop : bool = False,
136
+ static_crop : bool = False,
137
+ ):
138
+ detector = FaceBoxesDetector('FaceBoxes', '../PIPNet/FaceBoxesV2/weights/FaceBoxesV2.pth', use_gpu, device)
139
+ my_thresh = 0.6
140
+ det_box_scale = 1.2
141
+ meanface_indices, reverse_index1, reverse_index2, max_len = get_meanface(
142
+ os.path.join('../..', 'PIPNet', 'data', cfg.data_name, 'meanface.txt'), cfg.num_nb)
143
+
144
+ net.eval()
145
+
146
+ #CROP_DIR = '/mnt/rohan/cluster/angmar/sgiebenhain/now_crops_pipnet/'
147
+ #os.makedirs(CROP_DIR, exist_ok=True)
148
+
149
+
150
+ if start_frame > 0:
151
+ files = [f for f in os.listdir(f'{image_dir}/') if (f.endswith('.jpg') or f.endswith('.png')) and (((int(f.split('_')[-1].split('.')[0]) - start_frame) % 3) == 0)]
152
+ else:
153
+ files = [f for f in os.listdir(f'{image_dir}/') if f.endswith('.jpg') or f.endswith('.png')]
154
+ files.sort()
155
+
156
+ if not vertical_crop:
157
+ all_detections = []
158
+ all_images = []
159
+ #all_normals = []
160
+ succ_files = []
161
+ for file_name in files:
162
+ image = cv2.imread(f'{image_dir}/{file_name}')
163
+ #normals = cv2.imread(f'{image_dir}/../normals/{file_name[:-4]}.png')
164
+
165
+ if len(image.shape) < 3 or image.shape[-1] != 3:
166
+ continue
167
+
168
+ image_height, image_width, _ = image.shape
169
+ detections, _ = detector.detect(image, my_thresh, 1)
170
+ dets_filtered = [det for det in detections if det[0] == 'face']
171
+ dets_filtered.sort(key=lambda x: -1 * x[1])
172
+ detections = dets_filtered
173
+ if len(detections) == 0 or detections[0][1] < 0.75:
174
+ raise ValueError("No face detected with sufficiently high confidence")
175
+ all_detections.append(detections[0])
176
+ all_images.append(image)
177
+ #all_normals.append(normals)
178
+ succ_files.append(file_name)
179
+
180
+ if static_crop:
181
+ det1 = np.mean(np.array([x[2] for x in all_detections]), axis=0)
182
+ det2 = np.mean(np.array([x[3] for x in all_detections]), axis=0)
183
+ det3 = np.mean(np.array([x[4] for x in all_detections]), axis=0)
184
+ det4 = np.mean(np.array([x[5] for x in all_detections]), axis=0)
185
+ det_smoothed = np.stack([det1, det2, det3, det4], axis=0).astype(np.int32)
186
+ all_detections_smoothed = [] # = [[x[0], x[1], x_smoothed[0], x_smoothed[1], x_smoothed[2], x_smoothed[3]] for x, x_smoothed in zip()]
187
+ for i, det in enumerate(all_detections):
188
+ all_detections_smoothed.append(
189
+ [det[0], det[1], det_smoothed[0], det_smoothed[1], det_smoothed[2], det_smoothed[3]])
190
+ all_detections = all_detections_smoothed
191
+ else:
192
+ if len(all_detections) > 11:
193
+ WINDOW_LENGTH = 11
194
+ det1 = smooth(np.array([x[2] for x in all_detections]), window_len=WINDOW_LENGTH)
195
+ det2 = smooth(np.array([x[3] for x in all_detections]), window_len=WINDOW_LENGTH)
196
+ det3 = smooth(np.array([x[4] for x in all_detections]), window_len=WINDOW_LENGTH)
197
+ det4 = smooth(np.array([x[5] for x in all_detections]), window_len=WINDOW_LENGTH)
198
+ det_smoothed = np.stack([det1, det2,det3,det4], axis=1).astype(np.int32)
199
+ all_detections_smoothed = [] #= [[x[0], x[1], x_smoothed[0], x_smoothed[1], x_smoothed[2], x_smoothed[3]] for x, x_smoothed in zip()]
200
+ for i, det in enumerate(all_detections):
201
+ all_detections_smoothed.append([det[0], det[1], det_smoothed[i, 0], det_smoothed[i, 1], det_smoothed[i, 2], det_smoothed[i, 3]])
202
+ all_detections = all_detections_smoothed
203
+ # TODO: smooth detections!!!
204
+ for file_name, detection, image in zip(succ_files, all_detections, all_images):
205
+
206
+ img_crop = get_cstm_crop(image, detection)
207
+ #n_crop = get_cstm_crop(normals, detection)
208
+ image = img_crop
209
+ # save cropped image
210
+ os.makedirs(f'{image_dir}/../cropped/', exist_ok=True)
211
+ #os.makedirs(f'{image_dir}/../cropped_normals/', exist_ok=True)
212
+ cv2.imwrite(f'{image_dir}/../cropped/{file_name}', cv2.resize(image, (512, 512)))
213
+ #cv2.imwrite(f'{image_dir}/../cropped_normals/{file_name[:-4]}.png', cv2.resize(n_crop, (512, 512)))
214
+ else:
215
+ for file_name in files:
216
+ image = cv2.imread(f'{image_dir}/{file_name}')
217
+ if image.shape[0] != image.shape[1]:
218
+ image = image[220:-220, 640:-640, :]
219
+ os.makedirs(f'{image_dir}/../cropped/', exist_ok=True)
220
+ cv2.imwrite(f'{image_dir}/../cropped/{file_name}', cv2.resize(image, (512, 512)))
221
+
222
+
223
+ lms = []
224
+ image_dir = f'{image_dir}/../cropped/'
225
+ for file_name in files:
226
+ image = cv2.imread(f'{image_dir}/{file_name}')
227
+
228
+ if len(image.shape) < 3 or image.shape[-1] != 3:
229
+ continue
230
+ if flip:
231
+ image = cv2.transpose(image)
232
+
233
+ image_height, image_width, _ = image.shape
234
+ detections, _ = detector.detect(image, my_thresh, 1)
235
+ pred_export = None
236
+ dets_filtered = [det for det in detections if det[0] == 'face']
237
+ dets_filtered.sort(key=lambda x: -1 * x[1])
238
+ detections = dets_filtered
239
+
240
+
241
+ print(detections)
242
+ for i in range(min(1, len(detections))):
243
+ if detections[i][1] < 0.99:
244
+ continue
245
+ det_xmin = detections[i][2]
246
+ det_ymin = detections[i][3]
247
+ det_width = detections[i][4]
248
+ det_height = detections[i][5]
249
+ det_xmax = det_xmin + det_width - 1
250
+ det_ymax = det_ymin + det_height - 1
251
+
252
+
253
+ det_xmin -= int(det_width * (det_box_scale - 1) / 2)
254
+ # remove a part of top area for alignment, see paper for details
255
+ det_ymin += int(det_height * (det_box_scale - 1) / 2)
256
+ det_xmax += int(det_width * (det_box_scale - 1) / 2)
257
+ det_ymax += int(det_height * (det_box_scale - 1) / 2)
258
+ det_xmin = max(det_xmin, 0)
259
+ det_ymin = max(det_ymin, 0)
260
+ det_xmax = min(det_xmax, image_width - 1)
261
+ det_ymax = min(det_ymax, image_height - 1)
262
+ det_width = det_xmax - det_xmin + 1
263
+ det_height = det_ymax - det_ymin + 1
264
+ cv2.rectangle(image, (det_xmin, det_ymin), (det_xmax, det_ymax), (0, 0, 255), 2)
265
+ det_crop = image[det_ymin:det_ymax, det_xmin:det_xmax, :]
266
+ #np.save(f'{CROP_DIR}/{pid[:-4]}.npy', np.array([det_ymin, det_ymax, det_xmin, det_xmax]))
267
+ det_crop = cv2.resize(det_crop, (input_size, input_size))
268
+ inputs = Image.fromarray(det_crop[:, :, ::-1].astype('uint8'), 'RGB')
269
+ #inputs.show()
270
+ inputs = preprocess(inputs).unsqueeze(0)
271
+ inputs = inputs.to(device)
272
+ lms_pred_x, lms_pred_y, lms_pred_nb_x, lms_pred_nb_y, outputs_cls, max_cls = forward_pip(net,
273
+ inputs,
274
+ preprocess,
275
+ input_size,
276
+ net_stride,
277
+ num_nb)
278
+ lms_pred = torch.cat((lms_pred_x, lms_pred_y), dim=1).flatten()
279
+ tmp_nb_x = lms_pred_nb_x[reverse_index1, reverse_index2].view(cfg.num_lms, max_len)
280
+ tmp_nb_y = lms_pred_nb_y[reverse_index1, reverse_index2].view(cfg.num_lms, max_len)
281
+ tmp_x = torch.mean(torch.cat((lms_pred_x, tmp_nb_x), dim=1), dim=1).view(-1, 1)
282
+ tmp_y = torch.mean(torch.cat((lms_pred_y, tmp_nb_y), dim=1), dim=1).view(-1, 1)
283
+ lms_pred_merge = torch.cat((tmp_x, tmp_y), dim=1).flatten()
284
+ lms_pred = lms_pred.cpu().numpy()
285
+ lms_pred_merge = lms_pred_merge.cpu().numpy()
286
+ pred_export = np.zeros([cfg.num_lms, 2])
287
+ for i in range(cfg.num_lms):
288
+ x_pred = lms_pred_merge[i * 2] * det_width
289
+ y_pred = lms_pred_merge[i * 2 + 1] * det_height
290
+ pred_export[i, 0] = (x_pred + det_xmin) / image_width
291
+ pred_export[i, 1] = (y_pred + det_ymin) / image_height
292
+ cv2.circle(image, (int(x_pred) + det_xmin, int(y_pred) + det_ymin), 1, (0, 0, 255), 2)
293
+ if i == 76:
294
+ cv2.circle(image, (int(x_pred) + det_xmin, int(y_pred) + det_ymin), 1, (255, 0, 0), 2)
295
+
296
+ if pred_export is not None:
297
+ print('Exporting landmarks to ' + image_dir)
298
+ landmark_dir = f'{image_dir}/../PIPnet_landmarks/'
299
+ os.makedirs(landmark_dir, exist_ok=True)
300
+ np.save(landmark_dir + f'/{file_name[:-4]}.npy', pred_export)
301
+ lms.append(pred_export)
302
+ exp_dir = image_dir + '/../PIPnet_annotated_images/'
303
+ os.makedirs(exp_dir, exist_ok=True)
304
+ cv2.imwrite(exp_dir + f'/{file_name}', image)
305
+
306
+ # cv2.imshow('1', image)
307
+ # cv2.waitKey(0)
308
+
309
+ lms = np.stack(lms, axis=0)
310
+ os.makedirs(f'{image_dir}/../pipnet', exist_ok=True)
311
+ np.save(f'{image_dir}/../pipnet/test.npy', lms)
312
+
313
+
314
+ def run(exp_path, image_dir, start_frame = 0,
315
+ vertical_crop : bool = False,
316
+ static_crop : bool = False
317
+ ):
318
+ experiment_name = exp_path.split('/')[-1][:-3]
319
+ data_name = exp_path.split('/')[-2]
320
+ config_path = '.experiments.{}.{}'.format(data_name, experiment_name)
321
+
322
+ my_config = importlib.import_module(config_path, package='PIPNet')
323
+ Config = getattr(my_config, 'Config')
324
+ cfg = Config()
325
+ cfg.experiment_name = experiment_name
326
+ cfg.data_name = data_name
327
+
328
+ save_dir = os.path.join('../PIPNet/snapshots', cfg.data_name, cfg.experiment_name)
329
+
330
+ if cfg.backbone == 'resnet18':
331
+ resnet18 = models.resnet18(pretrained=cfg.pretrained)
332
+ net = Pip_resnet18(resnet18, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size,
333
+ net_stride=cfg.net_stride)
334
+ elif cfg.backbone == 'resnet50':
335
+ resnet50 = models.resnet50(pretrained=cfg.pretrained)
336
+ net = Pip_resnet50(resnet50, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size,
337
+ net_stride=cfg.net_stride)
338
+ elif cfg.backbone == 'resnet101':
339
+ resnet101 = models.resnet101(pretrained=cfg.pretrained)
340
+ net = Pip_resnet101(resnet101, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size,
341
+ net_stride=cfg.net_stride)
342
+ elif cfg.backbone == 'mobilenet_v2':
343
+ mbnet = models.mobilenet_v2(pretrained=cfg.pretrained)
344
+ net = Pip_mbnetv2(mbnet, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size, net_stride=cfg.net_stride)
345
+ elif cfg.backbone == 'mobilenet_v3':
346
+ mbnet = mobilenetv3_large()
347
+ if cfg.pretrained:
348
+ mbnet.load_state_dict(torch.load('lib/mobilenetv3-large-1cd25616.pth'))
349
+ net = Pip_mbnetv3(mbnet, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size, net_stride=cfg.net_stride)
350
+ else:
351
+ print('No such backbone!')
352
+ exit(0)
353
+
354
+ if cfg.use_gpu:
355
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
356
+ else:
357
+ device = torch.device("cpu")
358
+ net = net.to(device)
359
+
360
+ weight_file = os.path.join(save_dir, 'epoch%d.pth' % (cfg.num_epochs - 1))
361
+ state_dict = torch.load(weight_file, map_location=device)
362
+ net.load_state_dict(state_dict)
363
+
364
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
365
+ std=[0.229, 0.224, 0.225])
366
+ preprocess = transforms.Compose(
367
+ [transforms.Resize((cfg.input_size, cfg.input_size)), transforms.ToTensor(), normalize])
368
+
369
+
370
+ #for pid in pids:
371
+ pid = "FaMoS_180424_03335_TA_selfie_IMG_0092.jpg"
372
+ pid = "FaMoS_180426_03336_TA_selfie_IMG_0152.jpg"
373
+
374
+
375
+
376
+ demo_image(image_dir, pid, None, net, preprocess, cfg, cfg.input_size, cfg.net_stride, cfg.num_nb,
377
+ cfg.use_gpu,
378
+ device, start_frame=start_frame, vertical_crop=vertical_crop, static_crop=static_crop)
379
+
380
+
381
+
382
+ if __name__ == '__main__':
383
+ base_path = '/mnt/rohan/cluster/valinor/jschmidt/becominglit/1015/HEADROT/img_cc_4/cam_220700191/'
384
+ base_path = '/home/giebenhain/try_tracking_obama2/rgb'
385
+ #base_base_path = '/home/giebenhain/test_videos_p3dmm_full/'
386
+ base_base_path = '/mnt/rohan/cluster/andram/sgiebenhain/test_video_p3dmm_full/'
387
+ v_names = [f for f in os.listdir(base_base_path) if f.startswith('th1k')]
388
+ print(v_names)
389
+ #v_names = list(range(800, 813))
390
+ #v_names = ['yu', 'marc', 'karla', 'karla_light', 'karla_glasses_hat', 'karla_glasses'] #['merlin', 'haoxuan']
391
+ for video_name in v_names:
392
+ base_path = f'{base_base_path}/{video_name}/rgb/'
393
+ #if os.path.exists(f'{base_path}/../cropped/'):
394
+ # print('SKIP', base_path)
395
+ # continue
396
+ start_frame = -1
397
+ vertical_crop=True
398
+ try:
399
+ run('experiments/WFLW/pip_32_16_60_r18_l2_l1_10_1_nb10.py', base_path, start_frame=start_frame, vertical_crop=False, static_crop=True)
400
+ except Exception as ex:
401
+ traceback.print_exc()
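+
+ # Programmatic invocation sketch (illustrative; the experiment config and the image
+ # directory are placeholders for your own data layout):
+ def _example_run_pipnet(image_dir='/path/to/video/rgb/'):
+     run('experiments/WFLW/pip_32_16_60_r18_l2_l1_10_1_nb10.py', image_dir,
+         start_frame=-1, vertical_crop=False, static_crop=True)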