YoonaAI committed
Commit be7be52 · 1 Parent(s): 5bea2a2

Update apps/ICON.py

Files changed (1)
  1. apps/ICON.py +246 -216
apps/ICON.py CHANGED
@@ -14,24 +14,35 @@
 #
 # Contact: [email protected]
 
+
+import os
+
 from lib.common.seg3d_lossless import Seg3dLossless
 from lib.dataset.Evaluator import Evaluator
 from lib.net import HGPIFuNet
 from lib.common.train_util import *
 from lib.common.render import Render
 from lib.dataset.mesh_util import SMPLX, update_mesh_shape_prior_losses, get_visibility
+import warnings
+import logging
 import torch
 import lib.smplx as smplx
 import numpy as np
 from torch import nn
+import os.path as osp
+
 from skimage.transform import resize
 import pytorch_lightning as pl
+from huggingface_hub import cached_download
 
 torch.backends.cudnn.benchmark = True
 
+logging.getLogger("lightning").setLevel(logging.ERROR)
 
-class ICON(pl.LightningModule):
+warnings.filterwarnings("ignore")
 
+
+class ICON(pl.LightningModule):
     def __init__(self, cfg):
         super(ICON, self).__init__()
 
@@ -50,31 +61,25 @@ class ICON(pl.LightningModule):
             error_term=nn.SmoothL1Loss() if self.use_sdf else nn.MSELoss(),
         )
 
+        # TODO: replace the renderer from opengl to pytorch3d
         self.evaluator = Evaluator(
            device=torch.device(f"cuda:{self.cfg.gpus[0]}"))
 
-        self.resolutions = (np.logspace(
-            start=5,
-            stop=np.log2(self.mcube_res),
-            base=2,
-            num=int(np.log2(self.mcube_res) - 4),
-            endpoint=True,
-        ) + 1.0)
+        self.resolutions = (
+            np.logspace(
+                start=5,
+                stop=np.log2(self.mcube_res),
+                base=2,
+                num=int(np.log2(self.mcube_res) - 4),
+                endpoint=True,
+            )
+            + 1.0
+        )
         self.resolutions = self.resolutions.astype(np.int16).tolist()
 
-        self.base_keys = ["smpl_verts", "smpl_faces"]
-        self.feat_names = self.cfg.net.smpl_feats
-
-        self.icon_keys = self.base_keys + [
-            f"smpl_{feat_name}" for feat_name in self.feat_names
-        ]
-        self.keypoint_keys = self.base_keys + [
-            f"smpl_{feat_name}" for feat_name in self.feat_names
-        ]
-        self.pamir_keys = [
-            "voxel_verts", "voxel_faces", "pad_v_num", "pad_f_num"
-        ]
-        self.pifu_keys = []
+        self.icon_keys = ["smpl_verts", "smpl_faces", "smpl_vis", "smpl_cmap"]
+        self.pamir_keys = ["voxel_verts",
+                           "voxel_faces", "pad_v_num", "pad_f_num"]
 
         self.reconEngine = Seg3dLossless(
             query_func=query_func,
@@ -91,15 +96,14 @@
         )
 
         self.render = Render(
-            size=512, device=torch.device(f"cuda:{self.cfg.test_gpus[0]}"))
+            size=512, device=torch.device(f"cuda:{self.cfg.test_gpus[0]}")
+        )
         self.smpl_data = SMPLX()
 
         self.get_smpl_model = lambda smpl_type, gender, age, v_template: smplx.create(
             self.smpl_data.model_dir,
-            kid_template_path=osp.join(
-                osp.realpath(self.smpl_data.model_dir),
-                f"{smpl_type}/{smpl_type}_kid_template.npy",
-            ),
+            kid_template_path=cached_download(osp.join(self.smpl_data.model_dir,
+                                                       f"{smpl_type}/{smpl_type}_kid_template.npy"), use_auth_token=os.environ['ICON']),
             model_type=smpl_type,
             gender=gender,
             age=age,
@@ -130,34 +134,31 @@
         weight_decay = self.cfg.weight_decay
         momentum = self.cfg.momentum
 
-        optim_params_G = [{
-            "params": self.netG.if_regressor.parameters(),
-            "lr": self.lr_G
-        }]
+        optim_params_G = [
+            {"params": self.netG.if_regressor.parameters(), "lr": self.lr_G}
+        ]
 
         if self.cfg.net.use_filter:
-            optim_params_G.append({
-                "params": self.netG.F_filter.parameters(),
-                "lr": self.lr_G
-            })
+            optim_params_G.append(
+                {"params": self.netG.F_filter.parameters(), "lr": self.lr_G}
+            )
 
         if self.cfg.net.prior_type == "pamir":
-            optim_params_G.append({
-                "params": self.netG.ve.parameters(),
-                "lr": self.lr_G
-            })
+            optim_params_G.append(
+                {"params": self.netG.ve.parameters(), "lr": self.lr_G}
+            )
 
         if self.cfg.optim == "Adadelta":
 
-            optimizer_G = torch.optim.Adadelta(optim_params_G,
-                                               lr=self.lr_G,
-                                               weight_decay=weight_decay)
+            optimizer_G = torch.optim.Adadelta(
+                optim_params_G, lr=self.lr_G, weight_decay=weight_decay
+            )
 
         elif self.cfg.optim == "Adam":
 
-            optimizer_G = torch.optim.Adam(optim_params_G,
-                                           lr=self.lr_G,
-                                           weight_decay=weight_decay)
+            optimizer_G = torch.optim.Adam(
+                optim_params_G, lr=self.lr_G, weight_decay=weight_decay
+            )
 
         elif self.cfg.optim == "RMSprop":
 
@@ -173,7 +174,8 @@
 
         # set scheduler
         scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
-            optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma)
+            optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma
+        )
 
         return [optimizer_G], [scheduler_G]
 
@@ -193,10 +195,14 @@
         for name in self.in_total:
             in_tensor_dict.update({name: batch[name]})
 
-        in_tensor_dict.update({
-            k: batch[k] if k in batch.keys() else None
-            for k in getattr(self, f"{self.prior_type}_keys")
-        })
+        if self.prior_type == "icon":
+            for key in self.icon_keys:
+                in_tensor_dict.update({key: batch[key]})
+        elif self.prior_type == "pamir":
+            for key in self.pamir_keys:
+                in_tensor_dict.update({key: batch[key]})
+        else:
+            pass
 
         preds_G, error_G = self.netG(in_tensor_dict)
 
@@ -225,15 +231,11 @@
             self.render_func(in_tensor_dict, dataset="train")
 
         metrics_return = {
-            k.replace("train_", ""): torch.tensor(v)
-            for k, v in metrics_log.items()
+            k.replace("train_", ""): torch.tensor(v) for k, v in metrics_log.items()
         }
 
-        metrics_return.update({
-            "loss": error_G,
-            "log": tf_log,
-            "progress_bar": bar_log
-        })
+        metrics_return.update(
+            {"loss": error_G, "log": tf_log, "progress_bar": bar_log})
 
         return metrics_return
 
@@ -269,11 +271,15 @@
         for name in self.in_total:
             in_tensor_dict.update({name: batch[name]})
 
-        in_tensor_dict.update({
-            k: batch[k] if k in batch.keys() else None
-            for k in getattr(self, f"{self.prior_type}_keys")
-        })
-
+        if self.prior_type == "icon":
+            for key in self.icon_keys:
+                in_tensor_dict.update({key: batch[key]})
+        elif self.prior_type == "pamir":
+            for key in self.pamir_keys:
+                in_tensor_dict.update({key: batch[key]})
+        else:
+            pass
+
         preds_G, error_G = self.netG(in_tensor_dict)
 
         acc, iou, prec, recall = self.evaluator.calc_acc(
@@ -316,7 +322,11 @@
 
         (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=1)
         smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())
-        smpl_cmap = self.smpl_data.cmap_smpl_vids(smpl_type)
+        if smpl_type == "smpl":
+            smplx_ind = self.smpl_data.smpl2smplx(np.arange(smpl_vis.shape[0]))
+        else:
+            smplx_ind = np.arange(smpl_vis.shape[0])
+        smpl_cmap = self.smpl_data.get_smpl_mat(smplx_ind)
 
         return {
             "smpl_vis": smpl_vis.unsqueeze(0).to(self.device),
327
  @torch.enable_grad()
328
  def optim_body(self, in_tensor_dict, batch):
329
 
330
- smpl_model = self.get_smpl_model(batch["type"][0], batch["gender"][0],
331
- batch["age"][0], None).to(self.device)
332
- in_tensor_dict["smpl_faces"] = (torch.tensor(
333
- smpl_model.faces.astype(np.int)).long().unsqueeze(0).to(
334
- self.device))
 
 
 
 
335
 
336
  # The optimizer and variables
337
- optimed_pose = torch.tensor(batch["body_pose"][0],
338
- device=self.device,
339
- requires_grad=True) # [1,23,3,3]
340
- optimed_trans = torch.tensor(batch["transl"][0],
341
- device=self.device,
342
- requires_grad=True) # [3]
343
- optimed_betas = torch.tensor(batch["betas"][0],
344
- device=self.device,
345
- requires_grad=True) # [1,10]
346
- optimed_orient = torch.tensor(batch["global_orient"][0],
347
- device=self.device,
348
- requires_grad=True) # [1,1,3,3]
349
 
350
  optimizer_smpl = torch.optim.SGD(
351
  [optimed_pose, optimed_trans, optimed_betas, optimed_orient],
@@ -353,12 +367,8 @@
             momentum=0.9,
         )
         scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            optimizer_smpl,
-            mode="min",
-            factor=0.5,
-            verbose=0,
-            min_lr=1e-5,
-            patience=5)
+            optimizer_smpl, mode="min", factor=0.5, verbose=0, min_lr=1e-5, patience=5
+        )
         loop_smpl = range(50)
         for i in loop_smpl:
 
@@ -374,12 +384,12 @@
             )
 
             smpl_verts = smpl_out.vertices[0] * 100.0
-            smpl_verts = projection(smpl_verts,
-                                    batch["calib"][0],
-                                    format="tensor")
+            smpl_verts = projection(
+                smpl_verts, batch["calib"][0], format="tensor")
             smpl_verts[:, 1] *= -1
             # render optimized mesh (normal, T_normal, image [-1,1])
-            self.render.load_meshes(smpl_verts, in_tensor_dict["smpl_faces"])
+            self.render.load_meshes(
+                smpl_verts, in_tensor_dict["smpl_faces"])
             (
                 in_tensor_dict["T_normal_F"],
                 in_tensor_dict["T_normal_B"],
@@ -394,20 +404,24 @@
             ) = self.netG.normal_filter(in_tensor_dict)
 
             # mask = torch.abs(in_tensor['T_normal_F']).sum(dim=0, keepdims=True) > 0.0
-            diff_F_smpl = torch.abs(in_tensor_dict["T_normal_F"] -
-                                    in_tensor_dict["normal_F"])
-            diff_B_smpl = torch.abs(in_tensor_dict["T_normal_B"] -
-                                    in_tensor_dict["normal_B"])
+            diff_F_smpl = torch.abs(
+                in_tensor_dict["T_normal_F"] - in_tensor_dict["normal_F"]
+            )
+            diff_B_smpl = torch.abs(
+                in_tensor_dict["T_normal_B"] - in_tensor_dict["normal_B"]
+            )
             loss = (diff_F_smpl + diff_B_smpl).mean()
 
             # silhouette loss
             smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
             gt_arr = torch.cat(
-                [in_tensor_dict["normal_F"][0], in_tensor_dict["normal_B"][0]],
-                dim=2).permute(1, 2, 0)
+                [in_tensor_dict["normal_F"][0], in_tensor_dict["normal_B"][0]], dim=2
+            ).permute(1, 2, 0)
             gt_arr = ((gt_arr + 1.0) * 0.5).to(self.device)
-            bg_color = (torch.Tensor(
-                [0.5, 0.5, 0.5]).unsqueeze(0).unsqueeze(0).to(self.device))
+            bg_color = (
+                torch.Tensor([0.5, 0.5, 0.5]).unsqueeze(
+                    0).unsqueeze(0).to(self.device)
+            )
             gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
             loss += torch.abs(smpl_arr - gt_arr).mean()
 
@@ -425,7 +439,8 @@
                 batch["type"][0],
                 in_tensor_dict["smpl_verts"][0],
                 in_tensor_dict["smpl_faces"][0],
-            ))
+            )
+        )
 
         features, inter = self.netG.filter(in_tensor_dict, return_inter=True)
 
@@ -439,46 +454,22 @@
         verts_pr /= (self.resolutions[-1] - 1) / 2.0
 
         losses = {
-            "cloth": {
-                "weight": 5.0,
-                "value": 0.0
-            },
-            "edge": {
-                "weight": 100.0,
-                "value": 0.0
-            },
-            "normal": {
-                "weight": 0.2,
-                "value": 0.0
-            },
-            "laplacian": {
-                "weight": 100.0,
-                "value": 0.0
-            },
-            "smpl": {
-                "weight": 1.0,
-                "value": 0.0
-            },
-            "deform": {
-                "weight": 20.0,
-                "value": 0.0
-            },
+            "cloth": {"weight": 5.0, "value": 0.0},
+            "edge": {"weight": 100.0, "value": 0.0},
+            "normal": {"weight": 0.2, "value": 0.0},
+            "laplacian": {"weight": 100.0, "value": 0.0},
+            "smpl": {"weight": 1.0, "value": 0.0},
+            "deform": {"weight": 20.0, "value": 0.0},
         }
 
-        deform_verts = torch.full(verts_pr.shape,
-                                  0.0,
-                                  device=self.device,
-                                  requires_grad=True)
-        optimizer_cloth = torch.optim.SGD([deform_verts],
-                                          lr=1e-1,
-                                          momentum=0.9)
+        deform_verts = torch.full(
+            verts_pr.shape, 0.0, device=self.device, requires_grad=True
+        )
+        optimizer_cloth = torch.optim.SGD(
+            [deform_verts], lr=1e-1, momentum=0.9)
         scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            optimizer_cloth,
-            mode="min",
-            factor=0.1,
-            verbose=0,
-            min_lr=1e-3,
-            patience=5)
+            optimizer_cloth, mode="min", factor=0.1, verbose=0, min_lr=1e-3, patience=5
+        )
         # cloth optimization
         loop_cloth = range(100)
 
@@ -498,7 +489,8 @@
             diff_B_cloth = torch.abs(P_normal_B[0] - inter[3:])
             losses["cloth"]["value"] = (diff_F_cloth + diff_B_cloth).mean()
             losses["deform"]["value"] = torch.topk(
-                torch.abs(deform_verts.flatten()), 30)[0].mean()
+                torch.abs(deform_verts.flatten()), 30
+            )[0].mean()
 
             # Weighted sum of the losses
             cloth_loss = torch.tensor(0.0, device=self.device)
@@ -518,8 +510,8 @@
 
         # convert from GT to SDF
         deform_verts = deform_verts.flatten().detach()
-        deform_verts[torch.topk(torch.abs(deform_verts),
-                                30)[1]] = deform_verts.mean()
+        deform_verts[torch.topk(torch.abs(deform_verts), 30)[
+            1]] = deform_verts.mean()
         deform_verts = deform_verts.view(-1, 3).cpu()
 
         verts_pr += deform_verts
@@ -530,6 +522,15 @@
 
     def test_step(self, batch, batch_idx):
 
+        # dict_keys(['dataset', 'subject', 'rotation', 'scale', 'calib',
+        # 'normal_F', 'normal_B', 'image', 'T_normal_F', 'T_normal_B',
+        # 'z-trans', 'verts', 'faces', 'samples_geo', 'labels_geo',
+        # 'smpl_verts', 'smpl_faces', 'smpl_vis', 'smpl_cmap', 'pts_signs',
+        # 'type', 'gender', 'age', 'body_pose', 'global_orient', 'betas', 'transl'])
+
+        if self.evaluator._normal_render is None:
+            self.evaluator.init_gl()
+
         self.netG.eval()
         self.netG.training = False
         in_tensor_dict = {}
  in_tensor_dict = {}
@@ -537,78 +538,111 @@ class ICON(pl.LightningModule):
537
  # export paths
538
  mesh_name = batch["subject"][0]
539
  mesh_rot = batch["rotation"][0].item()
 
 
 
 
540
 
541
- self.export_dir = osp.join(self.cfg.results_path, self.cfg.name,
542
- "-".join(self.cfg.dataset.types), mesh_name)
 
 
543
 
 
544
  os.makedirs(self.export_dir, exist_ok=True)
545
 
546
  for name in self.in_total:
547
  if name in batch.keys():
548
  in_tensor_dict.update({name: batch[name]})
549
 
550
- in_tensor_dict.update({
551
- k: batch[k] if k in batch.keys() else None
552
- for k in getattr(self, f"{self.prior_type}_keys")
553
- })
554
-
555
- if "T_normal_F" not in in_tensor_dict.keys(
556
- ) or "T_normal_B" not in in_tensor_dict.keys():
557
 
558
- # update the new T_normal_F/B
559
- self.render.load_meshes(
560
- batch["smpl_verts"] *
561
- torch.tensor([1.0, -1.0, 1.0]).to(self.device),
562
- batch["smpl_faces"])
563
- T_normal_F, T_noraml_B = self.render.get_rgb_image()
564
- in_tensor_dict.update({
565
- 'T_normal_F': T_normal_F,
566
- 'T_normal_B': T_noraml_B
567
- })
 
 
 
 
 
 
 
 
568
 
569
  with torch.no_grad():
570
- features, inter = self.netG.filter(in_tensor_dict,
571
- return_inter=True)
572
- sdf = self.reconEngine(opt=self.cfg,
573
- netG=self.netG,
574
- features=features,
575
- proj_matrix=None)
576
-
577
- def tensor2arr(x):
578
- return (x[0].permute(1, 2, 0).detach().cpu().numpy() +
579
- 1.0) * 0.5 * 255.0
580
 
581
  # save inter results
582
- image = tensor2arr(in_tensor_dict["image"])
583
- smpl_F = tensor2arr(in_tensor_dict["T_normal_F"])
584
- smpl_B = tensor2arr(in_tensor_dict["T_normal_B"])
585
- image_inter = np.concatenate(self.tensor2image(512, inter[0]) +
586
- [smpl_F, smpl_B, image],
587
- axis=1)
588
- Image.fromarray((image_inter).astype(np.uint8)).save(
589
- osp.join(self.export_dir, f"{mesh_rot}_inter.png"))
 
 
 
 
 
 
 
 
 
 
 
 
590
 
591
  verts_pr, faces_pr = self.reconEngine.export_mesh(sdf)
592
 
593
  if self.clean_mesh_flag:
594
  verts_pr, faces_pr = clean_mesh(verts_pr, faces_pr)
595
 
 
 
 
596
  verts_gt = batch["verts"][0]
597
  faces_gt = batch["faces"][0]
598
 
599
- self.result_eval.update({
600
- "verts_gt": verts_gt,
601
- "faces_gt": faces_gt,
602
- "verts_pr": verts_pr,
603
- "faces_pr": faces_pr,
604
- "recon_size": (self.resolutions[-1] - 1.0),
605
- "calib": batch["calib"][0],
606
- })
607
-
608
- self.evaluator.set_mesh(self.result_eval)
609
- chamfer, p2s = self.evaluator.calculate_chamfer_p2s(num_samples=1000)
 
 
 
 
 
610
  normal_consist = self.evaluator.calculate_normal_consist(
611
- osp.join(self.export_dir, f"{mesh_rot}_nc.png"))
 
612
 
613
  test_log = {"chamfer": chamfer, "p2s": p2s, "NC": normal_consist}
614
 
@@ -622,8 +656,7 @@
             outputs,
             rot_num=3,
             split={
-                "cape-easy": (0, 50),
-                "cape-hard": (50, 100)
+                "thuman2": (0, 5),
             },
         )
 
@@ -631,10 +664,7 @@
         print(colored(self.cfg.dataset.noise_scale, "green"))
 
         self.logger.experiment.add_hparams(
-            hparam_dict={
-                "lr_G": self.lr_G,
-                "bsize": self.batch_size
-            },
+            hparam_dict={"lr_G": self.lr_G, "bsize": self.batch_size},
             metric_dict=accu_outputs,
         )
 
 
@@ -652,8 +682,8 @@ class ICON(pl.LightningModule):
652
  for dim in self.in_geo_dim:
653
  img = resize(
654
  np.tile(
655
- ((inter[:dim].cpu().numpy() + 1.0) / 2.0 *
656
- 255.0).transpose(1, 2, 0),
657
  (1, 1, int(3 / dim)),
658
  ),
659
  (height, height),
@@ -668,15 +698,13 @@
     def render_func(self, in_tensor_dict, dataset="title", idx=0):
 
         for name in in_tensor_dict.keys():
-            if in_tensor_dict[name] is not None:
-                in_tensor_dict[name] = in_tensor_dict[name][0:1]
+            in_tensor_dict[name] = in_tensor_dict[name][0:1]
 
         self.netG.eval()
         features, inter = self.netG.filter(in_tensor_dict, return_inter=True)
-        sdf = self.reconEngine(opt=self.cfg,
-                               netG=self.netG,
-                               features=features,
-                               proj_matrix=None)
+        sdf = self.reconEngine(
+            opt=self.cfg, netG=self.netG, features=features, proj_matrix=None
+        )
 
         if sdf is not None:
             render = self.reconEngine.display(sdf)
@@ -685,14 +713,15 @@
             height = image_pred.shape[0]
 
             image_gt = resize(
-                ((in_tensor_dict["image"].cpu().numpy()[0] + 1.0) / 2.0 *
-                 255.0).transpose(1, 2, 0),
+                ((in_tensor_dict["image"].cpu().numpy()[0] + 1.0) / 2.0).transpose(
+                    1, 2, 0
+                ),
                 (height, height),
                 anti_aliasing=True,
             )
             image_inter = self.tensor2image(height, inter[0])
-            image = np.concatenate([image_pred, image_gt] + image_inter,
-                                   axis=1)
+            image = np.concatenate(
+                [image_pred, image_gt] + image_inter, axis=1)
 
             step_id = self.global_step if dataset == "train" else self.global_step + idx
             self.logger.experiment.add_image(
@@ -711,18 +740,19 @@
             if name in batch.keys():
                 in_tensor_dict.update({name: batch[name]})
 
-        in_tensor_dict.update({
-            k: batch[k] if k in batch.keys() else None
-            for k in getattr(self, f"{self.prior_type}_keys")
-        })
+        if self.prior_type == "icon":
+            for key in self.icon_keys:
+                in_tensor_dict.update({key: batch[key]})
+        elif self.prior_type == "pamir":
+            for key in self.pamir_keys:
+                in_tensor_dict.update({key: batch[key]})
+        else:
+            pass
 
-        with torch.no_grad():
-            features, inter = self.netG.filter(in_tensor_dict,
-                                               return_inter=True)
-            sdf = self.reconEngine(opt=self.cfg,
-                                   netG=self.netG,
-                                   features=features,
-                                   proj_matrix=None)
+        features, inter = self.netG.filter(in_tensor_dict, return_inter=True)
+        sdf = self.reconEngine(
+            opt=self.cfg, netG=self.netG, features=features, proj_matrix=None
+        )
 
         verts_pr, faces_pr = self.reconEngine.export_mesh(sdf)
 
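A quick way to sanity-check the refactored resolution schedule in `__init__` is to run it standalone. The sketch below copies the `np.logspace` expression from the diff verbatim and assumes `mcube_res = 256`, a typical marching-cubes resolution rather than a value fixed by this commit. Each entry is a power of two plus one, so every vertex of a coarse grid lies exactly on the next finer grid in the coarse-to-fine sweep handed to Seg3dLossless.

    import numpy as np

    mcube_res = 256  # assumed config value, not set by this commit
    resolutions = (
        np.logspace(
            start=5,
            stop=np.log2(mcube_res),
            base=2,
            num=int(np.log2(mcube_res) - 4),
            endpoint=True,
        )
        + 1.0
    )
    # grid sizes 2**k + 1 for k = 5..8
    print(resolutions.astype(np.int16).tolist())  # [33, 65, 129, 257]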
 
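The commit also replaces the reflective `getattr(self, f"{self.prior_type}_keys")` comprehension with explicit `icon`/`pamir` branches that index `batch` directly. One behavioral difference worth noting: the old comprehension mapped missing keys to `None`, while the new loops assume every listed key is present in the batch. A toy contrast (the `batch` dict below is hypothetical):

    icon_keys = ["smpl_verts", "smpl_faces", "smpl_vis", "smpl_cmap"]
    batch = {"smpl_verts": ..., "smpl_faces": ...}  # smpl_vis / smpl_cmap absent

    # old behavior: silently fills gaps with None
    old = {k: batch[k] if k in batch.keys() else None for k in icon_keys}
    assert old["smpl_vis"] is None

    # new behavior: equivalent to the explicit for-loops in the diff;
    # raises KeyError when a listed key is absent
    try:
        new = {k: batch[k] for k in icon_keys}
    except KeyError as e:
        print("missing key:", e)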
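Finally, `kid_template_path` is now resolved through `huggingface_hub.cached_download` with a token read from the `ICON` environment variable, which implies `self.smpl_data.model_dir` points at a remote URL rather than a local directory (the old code called `osp.realpath` on it). A minimal sketch of that lookup, under the assumption that `model_dir` is a Hub URL; the URL below is a placeholder, and `os.environ["ICON"]` raises `KeyError` if the variable is unset:

    import os
    import os.path as osp
    from huggingface_hub import cached_download

    model_dir = "https://huggingface.co/YoonaAI/some-model-repo/resolve/main"  # hypothetical
    smpl_type = "smplx"

    local_path = cached_download(
        osp.join(model_dir, f"{smpl_type}/{smpl_type}_kid_template.npy"),
        use_auth_token=os.environ["ICON"],  # token must be exported beforehand
    )

Note that `cached_download` is deprecated in recent `huggingface_hub` releases in favor of `hf_hub_download`, so this code pins the behavior of the library version current at commit time.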