YoonaAI committed
Commit 8f4792e · Parent(s): f7266a6

Upload 5 files

lib/pymaf/core/base_trainer.py ADDED
@@ -0,0 +1,107 @@
+ # This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/utils/base_trainer.py
+ from __future__ import division
+ import logging
+ from utils import CheckpointSaver
+ from tensorboardX import SummaryWriter
+
+ import torch
+ from tqdm import tqdm
+
+ tqdm.monitor_interval = 0
+
+ logger = logging.getLogger(__name__)
+
+
+ class BaseTrainer(object):
+     """Base class for Trainer objects.
+     Takes care of checkpointing/logging/resuming training.
+     """
+
+     def __init__(self, options):
+         self.options = options
+         if options.multiprocessing_distributed:
+             self.device = torch.device('cuda', options.gpu)
+         else:
+             self.device = torch.device(
+                 'cuda' if torch.cuda.is_available() else 'cpu')
+         # Override init_fn() to define your model, optimizers etc.
+         self.saver = CheckpointSaver(save_dir=options.checkpoint_dir,
+                                      overwrite=options.overwrite)
+         if options.rank == 0:
+             self.summary_writer = SummaryWriter(self.options.summary_dir)
+         self.init_fn()
+
+         self.checkpoint = None
+         if options.resume and self.saver.exists_checkpoint():
+             self.checkpoint = self.saver.load_checkpoint(
+                 self.models_dict, self.optimizers_dict)
+
+         if self.checkpoint is None:
+             self.epoch_count = 0
+             self.step_count = 0
+             self.checkpoint_batch_idx = 0
+         else:
+             self.epoch_count = self.checkpoint['epoch']
+             self.step_count = self.checkpoint['total_step_count']
+             self.checkpoint_batch_idx = self.checkpoint['batch_idx']
+
+         self.best_performance = float('inf')
+
+     def load_pretrained(self, checkpoint_file=None):
+         """Load a pretrained checkpoint.
+         This is different from resuming training using --resume.
+         """
+         if checkpoint_file is not None:
+             checkpoint = torch.load(checkpoint_file)
+             for model in self.models_dict:
+                 if model in checkpoint:
+                     self.models_dict[model].load_state_dict(checkpoint[model],
+                                                             strict=True)
+                     print(f'Checkpoint {model} loaded')
+
+     def move_dict_to_device(self, tensor_dict, device, tensor2float=False):
+         for k, v in tensor_dict.items():
+             if isinstance(v, torch.Tensor):
+                 if tensor2float:
+                     tensor_dict[k] = v.float().to(device)
+                 else:
+                     tensor_dict[k] = v.to(device)
+
+     # The following methods (with the possible exception of test) have to be
+     # implemented in the derived classes.
+     def train(self, epoch):
+         raise NotImplementedError('You need to provide a train method')
+
+     def init_fn(self):
+         raise NotImplementedError('You need to provide an init_fn method')
+
+     def train_step(self, input_batch):
+         raise NotImplementedError('You need to provide a train_step method')
+
+     def train_summaries(self, input_batch):
+         raise NotImplementedError(
+             'You need to provide a train_summaries method')
+
+     def visualize(self, input_batch):
+         raise NotImplementedError('You need to provide a visualize method')
+
+     def validate(self):
+         pass
+
+     def test(self):
+         pass
+
+     def evaluate(self):
+         pass
+
+     def fit(self):
+         # Run training for num_epochs epochs
+         for epoch in tqdm(range(self.epoch_count, self.options.num_epochs),
+                           total=self.options.num_epochs,
+                           initial=self.epoch_count):
+             self.epoch_count = epoch
+             self.train(epoch)
+         return
lib/pymaf/core/cfgs.py ADDED
@@ -0,0 +1,100 @@
+ # -*- coding: utf-8 -*-
+
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+ # holder of all proprietary rights on this computer program.
+ # You can only use this computer program if you have closed
+ # a license agreement with MPG or you get the right to use the computer
+ # program from someone who is authorized to grant you that right.
+ # Any use of the computer program without a valid license is prohibited and
+ # liable to prosecution.
+ #
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+ # der Wissenschaften e.V. (MPG), acting on behalf of its Max Planck Institute
+ # for Intelligent Systems. All rights reserved.
+ #
+ # Contact: [email protected]
+
+ import os
+ import json
+ from yacs.config import CfgNode as CN
+
+ # Configuration variables
+ cfg = CN(new_allowed=True)
+
+ cfg.OUTPUT_DIR = 'results'
+ cfg.DEVICE = 'cuda'
+ cfg.DEBUG = False
+ cfg.LOGDIR = ''
+ cfg.VAL_VIS_BATCH_FREQ = 200
+ cfg.TRAIN_VIS_ITER_FERQ = 1000
+ cfg.SEED_VALUE = -1
+
+ cfg.TRAIN = CN(new_allowed=True)
+
+ cfg.LOSS = CN(new_allowed=True)
+ cfg.LOSS.KP_2D_W = 300.0
+ cfg.LOSS.KP_3D_W = 300.0
+ cfg.LOSS.SHAPE_W = 0.06
+ cfg.LOSS.POSE_W = 60.0
+ cfg.LOSS.VERT_W = 0.0
+
+ # Loss weights for dense correspondences
+ cfg.LOSS.INDEX_WEIGHTS = 2.0
+ # Loss weights for surface parts (24 parts)
+ cfg.LOSS.PART_WEIGHTS = 0.3
+ # Loss weights for UV regression
+ cfg.LOSS.POINT_REGRESSION_WEIGHTS = 0.5
+
+ cfg.MODEL = CN(new_allowed=True)
+
+ cfg.MODEL.PyMAF = CN(new_allowed=True)
+
+ # switch
+ cfg.TRAIN.VAL_LOOP = True
+
+ cfg.TEST = CN(new_allowed=True)
+
+
+ def get_cfg_defaults():
+     """Get a yacs CfgNode object with default values for my_project."""
+     # Return a clone so that the defaults will not be altered
+     # This is for the "local variable" use pattern
+     # return cfg.clone()
+     return cfg
+
+
+ def update_cfg(cfg_file):
+     # cfg = get_cfg_defaults()
+     cfg.merge_from_file(cfg_file)
+     # return cfg.clone()
+     return cfg
+
+
+ def parse_args(args):
+     if args.cfg_file is not None:
+         cfg = update_cfg(args.cfg_file)
+     else:
+         cfg = get_cfg_defaults()
+
+     # if args.misc is not None:
+     #     cfg.merge_from_list(args.misc)
+
+     return cfg
+
+
+ def parse_args_extend(args):
+     if args.resume:
+         if not os.path.exists(args.log_dir):
+             raise ValueError(
+                 'Experiment is set to resume mode, but the log directory does not exist.'
+             )
+
+         # load the cfg saved alongside the logs
+         cfg_file = os.path.join(args.log_dir, 'cfg.yaml')
+         cfg = update_cfg(cfg_file)
+
+         if args.misc is not None:
+             cfg.merge_from_list(args.misc)
+     else:
+         parse_args(args)
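
A short usage sketch of the yacs flow above; the Namespace and YAML path are illustrative only. Note the design choice in this file: update_cfg mutates the shared module-level cfg node rather than returning a clone, so every merge is visible globally.

from argparse import Namespace

args = Namespace(cfg_file='configs/pymaf_config.yaml', misc=None)
cfg = parse_args(args)      # merges the YAML into the shared cfg node
print(cfg.LOSS.KP_2D_W)     # 300.0 unless the YAML overrides it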
lib/pymaf/core/fits_dict.py ADDED
@@ -0,0 +1,133 @@
+ '''
+ This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py
+ '''
+ import os
+ import cv2
+ import torch
+ import numpy as np
+ from torchgeometry import angle_axis_to_rotation_matrix, rotation_matrix_to_angle_axis
+
+ from core import path_config, constants
+
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class FitsDict():
+     """ Dictionary keeping track of the best fit per image in the training set """
+
+     def __init__(self, options, train_dataset):
+         self.options = options
+         self.train_dataset = train_dataset
+         self.fits_dict = {}
+         self.valid_fit_state = {}
+         # array used to flip SMPL pose parameters
+         self.flipped_parts = torch.tensor(constants.SMPL_POSE_FLIP_PERM,
+                                           dtype=torch.int64)
+         # Load dictionary state
+         for ds_name, ds in train_dataset.dataset_dict.items():
+             if ds_name in ['h36m']:
+                 dict_file = os.path.join(path_config.FINAL_FITS_DIR,
+                                          ds_name + '.npy')
+                 self.fits_dict[ds_name] = torch.from_numpy(np.load(dict_file))
+                 self.valid_fit_state[ds_name] = torch.ones(
+                     len(self.fits_dict[ds_name]), dtype=torch.uint8)
+             else:
+                 dict_file = os.path.join(path_config.FINAL_FITS_DIR,
+                                          ds_name + '.npz')
+                 fits_dict = np.load(dict_file)
+                 opt_pose = torch.from_numpy(fits_dict['pose'])
+                 opt_betas = torch.from_numpy(fits_dict['betas'])
+                 opt_valid_fit = torch.from_numpy(fits_dict['valid_fit']).to(
+                     torch.uint8)
+                 self.fits_dict[ds_name] = torch.cat([opt_pose, opt_betas],
+                                                     dim=1)
+                 self.valid_fit_state[ds_name] = opt_valid_fit
+
+         if not options.single_dataset:
+             for ds in train_dataset.datasets:
+                 if ds.dataset not in ['h36m']:
+                     ds.pose = self.fits_dict[ds.dataset][:, :72].numpy()
+                     ds.betas = self.fits_dict[ds.dataset][:, 72:].numpy()
+                     ds.has_smpl = self.valid_fit_state[ds.dataset].numpy()
+
+     def save(self):
+         """ Save dictionary state to disk """
+         for ds_name in self.train_dataset.dataset_dict.keys():
+             dict_file = os.path.join(self.options.checkpoint_dir,
+                                      ds_name + '_fits.npy')
+             np.save(dict_file, self.fits_dict[ds_name].cpu().numpy())
+
+     def __getitem__(self, x):
+         """ Retrieve dictionary entries """
+         dataset_name, ind, rot, is_flipped = x
+         batch_size = len(dataset_name)
+         pose = torch.zeros((batch_size, 72))
+         betas = torch.zeros((batch_size, 10))
+         for ds, i, n in zip(dataset_name, ind, range(batch_size)):
+             params = self.fits_dict[ds][i]
+             pose[n, :] = params[:72]
+             betas[n, :] = params[72:]
+         pose = pose.clone()
+         # Apply flipping and rotation
+         pose = self.flip_pose(self.rotate_pose(pose, rot), is_flipped)
+         betas = betas.clone()
+         return pose, betas
+
+     def get_vaild_state(self, dataset_name, ind):
+         batch_size = len(dataset_name)
+         valid_fit = torch.zeros(batch_size, dtype=torch.uint8)
+         for ds, i, n in zip(dataset_name, ind, range(batch_size)):
+             valid_fit[n] = self.valid_fit_state[ds][i]
+         valid_fit = valid_fit.clone()
+         return valid_fit
+
+     def __setitem__(self, x, val):
+         """ Update dictionary entries """
+         dataset_name, ind, rot, is_flipped, update = x
+         pose, betas = val
+         batch_size = len(dataset_name)
+         # Undo flipping and rotation
+         pose = self.rotate_pose(self.flip_pose(pose, is_flipped), -rot)
+         params = torch.cat((pose, betas), dim=-1).cpu()
+         for ds, i, n in zip(dataset_name, ind, range(batch_size)):
+             if update[n]:
+                 self.fits_dict[ds][i] = params[n]
+
+     def flip_pose(self, pose, is_flipped):
+         """Flip SMPL pose parameters"""
+         is_flipped = is_flipped.byte()
+         pose_f = pose.clone()
+         pose_f[is_flipped, :] = pose[is_flipped][:, self.flipped_parts]
+         # we also negate the second and the third dimension of the axis-angle representation
+         pose_f[is_flipped, 1::3] *= -1
+         pose_f[is_flipped, 2::3] *= -1
+         return pose_f
+
+     def rotate_pose(self, pose, rot):
+         """Rotate SMPL pose parameters by rot degrees"""
+         pose = pose.clone()
+         cos = torch.cos(-np.pi * rot / 180.)
+         sin = torch.sin(-np.pi * rot / 180.)
+         zeros = torch.zeros_like(cos)
+         r3 = torch.zeros(cos.shape[0], 1, 3, device=cos.device)
+         r3[:, 0, -1] = 1
+         R = torch.cat([
+             torch.stack([cos, -sin, zeros], dim=-1).unsqueeze(1),
+             torch.stack([sin, cos, zeros], dim=-1).unsqueeze(1), r3
+         ], dim=1)
+         global_pose = pose[:, :3]
+         global_pose_rotmat = angle_axis_to_rotation_matrix(global_pose)
+         global_pose_rotmat_3b3 = global_pose_rotmat[:, :3, :3]
+         global_pose_rotmat_3b3 = torch.matmul(R, global_pose_rotmat_3b3)
+         global_pose_rotmat[:, :3, :3] = global_pose_rotmat_3b3
+         global_pose_rotmat = global_pose_rotmat[:, :-1, :-1].cpu().numpy()
+         global_pose_np = np.zeros((global_pose.shape[0], 3))
+         for i in range(global_pose.shape[0]):
+             aa, _ = cv2.Rodrigues(global_pose_rotmat[i])
+             global_pose_np[i, :] = aa.squeeze()
+         pose[:, :3] = torch.from_numpy(global_pose_np).to(pose.device)
+         return pose
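
FitsDict overloads indexing with tuple keys. Below is a sketch of that calling convention; the instance, the 'lsp' dataset name, and the batch values are hypothetical, with shapes inferred from the code above.

import torch

# Assume `fits_dict` is a FitsDict whose loaded datasets include 'lsp'.
names = ['lsp', 'lsp']                # per-sample dataset names
ind = [3, 7]                          # per-sample indices into the dataset
rot = torch.tensor([0.0, 15.0])       # augmentation rotations in degrees
flipped = torch.tensor([0, 1])        # per-sample flip flags

# Read: returns (2, 72) pose and (2, 10) betas with augmentation applied.
pose, betas = fits_dict[(names, ind, rot, flipped)]

# Write: the extra `update` mask selects which samples to overwrite,
# after the flipping/rotation augmentation is undone internally.
update = torch.tensor([True, False])
fits_dict[(names, ind, rot, flipped, update)] = (pose, betas)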
lib/pymaf/core/path_config.py ADDED
@@ -0,0 +1,23 @@
+ """
+ This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/path_config.py
+ path configuration
+ This file contains definitions of useful data structures and the paths
+ for the datasets and data files necessary to run the code.
+ Things you need to change: *_ROOT that indicate the path to each dataset
+ """
+ import os
+ from huggingface_hub import hf_hub_url, cached_download
+
+ # pymaf
+ pymaf_data_dir = hf_hub_url('Yuliang/PyMAF', '')
+ smpl_data_dir = hf_hub_url('Yuliang/SMPL', '')
+ SMPL_MODEL_DIR = os.path.join(smpl_data_dir, 'models/smpl')
+
+ SMPL_MEAN_PARAMS = cached_download(os.path.join(pymaf_data_dir, 'smpl_mean_params.npz'), use_auth_token=os.environ['ICON'])
+ MESH_DOWNSAMPLEING = cached_download(os.path.join(pymaf_data_dir, 'mesh_downsampling.npz'), use_auth_token=os.environ['ICON'])
+ CUBE_PARTS_FILE = cached_download(os.path.join(pymaf_data_dir, 'cube_parts.npy'), use_auth_token=os.environ['ICON'])
+ JOINT_REGRESSOR_TRAIN_EXTRA = cached_download(os.path.join(pymaf_data_dir, 'J_regressor_extra.npy'), use_auth_token=os.environ['ICON'])
+ JOINT_REGRESSOR_H36M = cached_download(os.path.join(pymaf_data_dir, 'J_regressor_h36m.npy'), use_auth_token=os.environ['ICON'])
+ VERTEX_TEXTURE_FILE = cached_download(os.path.join(pymaf_data_dir, 'vertex_texture.npy'), use_auth_token=os.environ['ICON'])
+ CHECKPOINT_FILE = cached_download(os.path.join(pymaf_data_dir, 'pretrained_model/PyMAF_model_checkpoint.pt'), use_auth_token=os.environ['ICON'])
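
Because these downloads run at import time and read os.environ['ICON'], the token must be set before the module is imported or the import raises KeyError. A minimal sketch, assuming the module path lib.pymaf.core and a placeholder token value:

import os
os.environ['ICON'] = 'hf_xxx'           # placeholder Hugging Face token

from lib.pymaf.core import path_config  # triggers the cached downloads
print(path_config.SMPL_MEAN_PARAMS)     # local path of the cached .npz file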
lib/pymaf/core/train_options.py ADDED
@@ -0,0 +1,135 @@
+ import argparse
+
+
+ class TrainOptions():
+     def __init__(self):
+         self.parser = argparse.ArgumentParser()
+
+         gen = self.parser.add_argument_group('General')
+         gen.add_argument(
+             '--resume',
+             dest='resume',
+             default=False,
+             action='store_true',
+             help='Resume from checkpoint (use the latest checkpoint by default)')
+
+         io = self.parser.add_argument_group('io')
+         io.add_argument('--log_dir',
+                         default='logs',
+                         help='Directory to store logs')
+         io.add_argument(
+             '--pretrained_checkpoint',
+             default=None,
+             help='Load a pretrained checkpoint at the beginning of training')
+
+         train = self.parser.add_argument_group('Training Options')
+         train.add_argument('--num_epochs',
+                            type=int,
+                            default=200,
+                            help='Total number of training epochs')
+         train.add_argument('--regressor',
+                            type=str,
+                            choices=['hmr', 'pymaf_net'],
+                            default='pymaf_net',
+                            help='Name of the SMPL regressor.')
+         train.add_argument('--cfg_file',
+                            type=str,
+                            default='./configs/pymaf_config.yaml',
+                            help='config file path for PyMAF.')
+         train.add_argument(
+             '--img_res',
+             type=int,
+             default=224,
+             help='Rescale bounding boxes to size [img_res, img_res] before feeding them into the network'
+         )
+         train.add_argument(
+             '--rot_factor',
+             type=float,
+             default=30,
+             help='Random rotation in the range [-rot_factor, rot_factor]')
+         train.add_argument(
+             '--noise_factor',
+             type=float,
+             default=0.4,
+             help='Randomly multiply pixel values with a factor in the range [1-noise_factor, 1+noise_factor]'
+         )
+         train.add_argument(
+             '--scale_factor',
+             type=float,
+             default=0.25,
+             help='Rescale bounding boxes by a factor of [1-scale_factor, 1+scale_factor]'
+         )
+         train.add_argument(
+             '--openpose_train_weight',
+             default=0.,
+             help='Weight for OpenPose keypoints during training')
+         train.add_argument('--gt_train_weight',
+                            default=1.,
+                            help='Weight for GT keypoints during training')
+         train.add_argument('--eval_dataset',
+                            type=str,
+                            default='h36m-p2-mosh',
+                            help='Name of the evaluation dataset.')
+         train.add_argument('--single_dataset',
+                            default=False,
+                            action='store_true',
+                            help='Use a single dataset')
+         train.add_argument('--single_dataname',
+                            type=str,
+                            default='h36m',
+                            help='Name of the single dataset.')
+         train.add_argument('--eval_pve',
+                            default=False,
+                            action='store_true',
+                            help='evaluate PVE')
+         train.add_argument('--overwrite',
+                            default=False,
+                            action='store_true',
+                            help='overwrite the latest checkpoint')
+
+         train.add_argument('--distributed',
+                            action='store_true',
+                            help='Use distributed training')
+         train.add_argument('--dist_backend',
+                            default='nccl',
+                            type=str,
+                            help='distributed backend')
+         train.add_argument('--dist_url',
+                            default='tcp://127.0.0.1:10356',
+                            type=str,
+                            help='url used to set up distributed training')
+         train.add_argument('--world_size',
+                            default=1,
+                            type=int,
+                            help='number of nodes for distributed training')
+         train.add_argument("--local_rank", default=0, type=int)
+         train.add_argument('--rank',
+                            default=0,
+                            type=int,
+                            help='node rank for distributed training')
+         train.add_argument(
+             '--multiprocessing_distributed',
+             action='store_true',
+             help='Use multi-processing distributed training to launch '
+             'N processes per node, which has N GPUs. This is the '
+             'fastest way to use PyTorch for either single node or '
+             'multi node data parallel training')
+
+         misc = self.parser.add_argument_group('Misc Options')
+         misc.add_argument('--misc',
+                           help="Modify config options using the command-line",
+                           default=None,
+                           nargs=argparse.REMAINDER)
+         return
+
+     def parse_args(self):
+         """Parse input arguments."""
+         self.args = self.parser.parse_args()
+         self.save_dump()
+         return self.args
+
+     def save_dump(self):
+         """Store all argument values to a json file.
+         The default location is logs/expname/args.json.
+         """
+         pass
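
A sketch of how TrainOptions and the cfgs helpers might be wired together in a training entry point. The `from core import ...` module paths are an assumption based on the import style in fits_dict.py above; note that parse_args_extend mutates the shared cfg node rather than returning one.

from core.train_options import TrainOptions
from core import cfgs

options = TrainOptions().parse_args()     # reads sys.argv
cfgs.parse_args_extend(options)           # merges --cfg_file, or resumes from log_dir
print(options.num_epochs, cfgs.cfg.LOSS.KP_2D_W)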