alexnasa commited on
Commit
9d79572
·
verified ·
1 Parent(s): d431838

Update src/pixel3dmm/preprocessing/MICA/demo.py

Browse files
src/pixel3dmm/preprocessing/MICA/demo.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  # -*- coding: utf-8 -*-
2
 
3
  # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
@@ -131,18 +134,11 @@ def load_checkpoint(args, mica):
131
  mica.flameModel.load_state_dict(checkpoint['flameModel'])
132
 
133
 
134
- def main(cfg, args):
135
- device = 'cuda:0'
136
- cfg.model.testing = True
137
- mica = util.find_model_using_name(model_dir='micalib.models', model_name=cfg.model.name)(cfg, device)
138
- load_checkpoint(args, mica)
139
- mica.eval()
140
 
141
  faces = mica.flameModel.generator.faces_tensor.cpu()
142
  Path(args.o).mkdir(exist_ok=True, parents=True)
143
 
144
- app = LandmarksDetector(model=detectors.RETINAFACE)
145
-
146
  with torch.no_grad():
147
  logger.info(f'Processing has started...')
148
  paths = process(args, app, draw_bbox=False)
@@ -186,4 +182,14 @@ if __name__ == '__main__':
186
  <<<<<<<< ALREADY COMPLETE MICA PREDICTION FOR {args.video_name}, SKIPPING >>>>>>>>
187
  ''')
188
  exit()
189
- main(cfg, args)
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
  # -*- coding: utf-8 -*-
5
 
6
  # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
 
134
  mica.flameModel.load_state_dict(checkpoint['flameModel'])
135
 
136
 
137
+ def main(args, mica, app):
 
 
 
 
 
138
 
139
  faces = mica.flameModel.generator.faces_tensor.cpu()
140
  Path(args.o).mkdir(exist_ok=True, parents=True)
141
 
 
 
142
  with torch.no_grad():
143
  logger.info(f'Processing has started...')
144
  paths = process(args, app, draw_bbox=False)
 
182
  <<<<<<<< ALREADY COMPLETE MICA PREDICTION FOR {args.video_name}, SKIPPING >>>>>>>>
183
  ''')
184
  exit()
185
+
186
+ # instantiate models outside main
187
+ device = 'cuda'
188
+ cfg.model.testing = True
189
+ mica = util.find_model_using_name(model_dir='micalib.models', model_name=cfg.model.name)(cfg, device)
190
+ load_checkpoint(args, mica)
191
+ mica.eval()
192
+
193
+ app = LandmarksDetector(model=detectors.RETINAFACE)
194
+
195
+ main(args, mica, app)