Upload 5 files
- lib/pymaf/core/base_trainer.py +107 -0
- lib/pymaf/core/cfgs.py +100 -0
- lib/pymaf/core/fits_dict.py +133 -0
- lib/pymaf/core/path_config.py +23 -0
- lib/pymaf/core/train_options.py +135 -0
lib/pymaf/core/base_trainer.py
ADDED
@@ -0,0 +1,107 @@
# This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/utils/base_trainer.py
from __future__ import division
import logging
from utils import CheckpointSaver
from tensorboardX import SummaryWriter

import torch
from tqdm import tqdm

tqdm.monitor_interval = 0

logger = logging.getLogger(__name__)


class BaseTrainer(object):
    """Base class for Trainer objects.
    Takes care of checkpointing/logging/resuming training.
    """

    def __init__(self, options):
        self.options = options
        if options.multiprocessing_distributed:
            self.device = torch.device('cuda', options.gpu)
        else:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
        self.saver = CheckpointSaver(save_dir=options.checkpoint_dir,
                                     overwrite=options.overwrite)
        if options.rank == 0:
            self.summary_writer = SummaryWriter(self.options.summary_dir)
        # Override init_fn in derived classes to define models, optimizers, etc.
        self.init_fn()

        self.checkpoint = None
        if options.resume and self.saver.exists_checkpoint():
            self.checkpoint = self.saver.load_checkpoint(
                self.models_dict, self.optimizers_dict)

        if self.checkpoint is None:
            self.epoch_count = 0
            self.step_count = 0
        else:
            self.epoch_count = self.checkpoint['epoch']
            self.step_count = self.checkpoint['total_step_count']

        if self.checkpoint is not None:
            self.checkpoint_batch_idx = self.checkpoint['batch_idx']
        else:
            self.checkpoint_batch_idx = 0

        self.best_performance = float('inf')

    def load_pretrained(self, checkpoint_file=None):
        """Load a pretrained checkpoint.
        This is different from resuming training using --resume.
        """
        if checkpoint_file is not None:
            checkpoint = torch.load(checkpoint_file)
            for model in self.models_dict:
                if model in checkpoint:
                    self.models_dict[model].load_state_dict(checkpoint[model],
                                                            strict=True)
                    print(f'Checkpoint {model} loaded')

    def move_dict_to_device(self, data_dict, device, tensor2float=False):
        # Move all tensor values of a dict to the given device, optionally casting to float.
        for k, v in data_dict.items():
            if isinstance(v, torch.Tensor):
                if tensor2float:
                    data_dict[k] = v.float().to(device)
                else:
                    data_dict[k] = v.to(device)

    # The following methods (with the possible exception of test) have to be
    # implemented in the derived classes.
    def train(self, epoch):
        raise NotImplementedError('You need to provide a train method')

    def init_fn(self):
        raise NotImplementedError('You need to provide an init_fn method')

    def train_step(self, input_batch):
        raise NotImplementedError('You need to provide a train_step method')

    def train_summaries(self, input_batch):
        raise NotImplementedError(
            'You need to provide a train_summaries method')

    def visualize(self, input_batch):
        raise NotImplementedError('You need to provide a visualize method')

    def validate(self):
        pass

    def test(self):
        pass

    def evaluate(self):
        pass

    def fit(self):
        # Run training for num_epochs epochs.
        for epoch in tqdm(range(self.epoch_count, self.options.num_epochs),
                          total=self.options.num_epochs,
                          initial=self.epoch_count):
            self.epoch_count = epoch
            self.train(epoch)
        return
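Usage sketch (not part of the commit): BaseTrainer assumes a derived class defines models_dict and optimizers_dict inside init_fn before checkpointing can work. DummyTrainer and the option values below are hypothetical, minimal stand-ins.

# Hypothetical minimal subclass of BaseTrainer (illustration only).
from argparse import Namespace
import torch

class DummyTrainer(BaseTrainer):
    def init_fn(self):
        # CheckpointSaver.load_checkpoint and load_pretrained consume these dicts.
        self.model = torch.nn.Linear(10, 1).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters())
        self.models_dict = {'model': self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}

    def train(self, epoch):
        # One synthetic batch per epoch, just to exercise the loop in fit().
        batch = {'x': torch.randn(4, 10), 'y': torch.randn(4, 1)}
        self.move_dict_to_device(batch, self.device)
        loss = torch.nn.functional.mse_loss(self.model(batch['x']), batch['y'])
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.step_count += 1

options = Namespace(multiprocessing_distributed=False, rank=0, resume=False,
                    checkpoint_dir='/tmp/ckpt', overwrite=True,
                    summary_dir='/tmp/summary', num_epochs=2)
DummyTrainer(options).fit()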
lib/pymaf/core/cfgs.py
ADDED
@@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG), acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

import os
from yacs.config import CfgNode as CN

# Configuration variables
cfg = CN(new_allowed=True)

cfg.OUTPUT_DIR = 'results'
cfg.DEVICE = 'cuda'
cfg.DEBUG = False
cfg.LOGDIR = ''
cfg.VAL_VIS_BATCH_FREQ = 200
cfg.TRAIN_VIS_ITER_FERQ = 1000
cfg.SEED_VALUE = -1

cfg.TRAIN = CN(new_allowed=True)

cfg.LOSS = CN(new_allowed=True)
cfg.LOSS.KP_2D_W = 300.0
cfg.LOSS.KP_3D_W = 300.0
cfg.LOSS.SHAPE_W = 0.06
cfg.LOSS.POSE_W = 60.0
cfg.LOSS.VERT_W = 0.0

# Loss weights for dense correspondences
cfg.LOSS.INDEX_WEIGHTS = 2.0
# Loss weights for surface parts (24 parts)
cfg.LOSS.PART_WEIGHTS = 0.3
# Loss weights for UV regression
cfg.LOSS.POINT_REGRESSION_WEIGHTS = 0.5

cfg.MODEL = CN(new_allowed=True)

cfg.MODEL.PyMAF = CN(new_allowed=True)

# switch
cfg.TRAIN.VAL_LOOP = True

cfg.TEST = CN(new_allowed=True)


def get_cfg_defaults():
    """Get a yacs CfgNode object with default values."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    # return cfg.clone()
    return cfg


def update_cfg(cfg_file):
    # cfg = get_cfg_defaults()
    cfg.merge_from_file(cfg_file)
    # return cfg.clone()
    return cfg


def parse_args(args):
    if args.cfg_file is not None:
        cfg = update_cfg(args.cfg_file)
    else:
        cfg = get_cfg_defaults()

    # if args.misc is not None:
    #     cfg.merge_from_list(args.misc)

    return cfg


def parse_args_extend(args):
    if args.resume:
        if not os.path.exists(args.log_dir):
            raise ValueError(
                'Experiment is set to resume mode, but the log directory does not exist.'
            )

        # load the cfg saved alongside the log
        cfg_file = os.path.join(args.log_dir, 'cfg.yaml')
        cfg = update_cfg(cfg_file)

        if args.misc is not None:
            cfg.merge_from_list(args.misc)
        return cfg
    else:
        return parse_args(args)
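For orientation, a sketch of how these helpers are meant to be driven (assumed usage; the Namespace below stands in for the parsed TrainOptions, and ./configs/pymaf_config.yaml must exist for the merge to succeed):

# Hypothetical driver code (illustration only).
from argparse import Namespace

args = Namespace(cfg_file='./configs/pymaf_config.yaml', resume=False,
                 log_dir='logs', misc=None)
cfg = parse_args(args)      # merges the YAML file over the defaults above
print(cfg.LOSS.KP_2D_W)     # 300.0 unless the YAML overrides it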
lib/pymaf/core/fits_dict.py
ADDED
@@ -0,0 +1,133 @@
'''
This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py
'''
import os
import cv2
import torch
import numpy as np
from torchgeometry import angle_axis_to_rotation_matrix

from core import path_config, constants

import logging

logger = logging.getLogger(__name__)


class FitsDict():
    """ Dictionary keeping track of the best fit per image in the training set """

    def __init__(self, options, train_dataset):
        self.options = options
        self.train_dataset = train_dataset
        self.fits_dict = {}
        self.valid_fit_state = {}
        # array used to flip SMPL pose parameters
        self.flipped_parts = torch.tensor(constants.SMPL_POSE_FLIP_PERM,
                                          dtype=torch.int64)
        # Load dictionary state
        for ds_name, ds in train_dataset.dataset_dict.items():
            if ds_name in ['h36m']:
                dict_file = os.path.join(path_config.FINAL_FITS_DIR,
                                         ds_name + '.npy')
                self.fits_dict[ds_name] = torch.from_numpy(np.load(dict_file))
                self.valid_fit_state[ds_name] = torch.ones(
                    len(self.fits_dict[ds_name]), dtype=torch.uint8)
            else:
                dict_file = os.path.join(path_config.FINAL_FITS_DIR,
                                         ds_name + '.npz')
                fits_dict = np.load(dict_file)
                opt_pose = torch.from_numpy(fits_dict['pose'])
                opt_betas = torch.from_numpy(fits_dict['betas'])
                opt_valid_fit = torch.from_numpy(fits_dict['valid_fit']).to(
                    torch.uint8)
                self.fits_dict[ds_name] = torch.cat([opt_pose, opt_betas],
                                                    dim=1)
                self.valid_fit_state[ds_name] = opt_valid_fit

        if not options.single_dataset:
            for ds in train_dataset.datasets:
                if ds.dataset not in ['h36m']:
                    ds.pose = self.fits_dict[ds.dataset][:, :72].numpy()
                    ds.betas = self.fits_dict[ds.dataset][:, 72:].numpy()
                    ds.has_smpl = self.valid_fit_state[ds.dataset].numpy()

    def save(self):
        """ Save dictionary state to disk """
        for ds_name in self.train_dataset.dataset_dict.keys():
            dict_file = os.path.join(self.options.checkpoint_dir,
                                     ds_name + '_fits.npy')
            np.save(dict_file, self.fits_dict[ds_name].cpu().numpy())

    def __getitem__(self, x):
        """ Retrieve dictionary entries """
        dataset_name, ind, rot, is_flipped = x
        batch_size = len(dataset_name)
        pose = torch.zeros((batch_size, 72))
        betas = torch.zeros((batch_size, 10))
        for ds, i, n in zip(dataset_name, ind, range(batch_size)):
            params = self.fits_dict[ds][i]
            pose[n, :] = params[:72]
            betas[n, :] = params[72:]
        pose = pose.clone()
        # Apply flipping and rotation
        pose = self.flip_pose(self.rotate_pose(pose, rot), is_flipped)
        betas = betas.clone()
        return pose, betas

    def get_vaild_state(self, dataset_name, ind):
        batch_size = len(dataset_name)
        valid_fit = torch.zeros(batch_size, dtype=torch.uint8)
        for ds, i, n in zip(dataset_name, ind, range(batch_size)):
            valid_fit[n] = self.valid_fit_state[ds][i]
        valid_fit = valid_fit.clone()
        return valid_fit

    def __setitem__(self, x, val):
        """ Update dictionary entries """
        dataset_name, ind, rot, is_flipped, update = x
        pose, betas = val
        batch_size = len(dataset_name)
        # Undo flipping and rotation
        pose = self.rotate_pose(self.flip_pose(pose, is_flipped), -rot)
        params = torch.cat((pose, betas), dim=-1).cpu()
        for ds, i, n in zip(dataset_name, ind, range(batch_size)):
            if update[n]:
                self.fits_dict[ds][i] = params[n]

    def flip_pose(self, pose, is_flipped):
        """Flip SMPL pose parameters"""
        is_flipped = is_flipped.byte()
        pose_f = pose.clone()
        pose_f[is_flipped, :] = pose[is_flipped][:, self.flipped_parts]
        # we also negate the second and the third dimension of the axis-angle representation
        pose_f[is_flipped, 1::3] *= -1
        pose_f[is_flipped, 2::3] *= -1
        return pose_f

    def rotate_pose(self, pose, rot):
        """Rotate SMPL pose parameters by rot degrees"""
        pose = pose.clone()
        # Build a batch of rotation matrices about the z-axis (in-plane rotation by -rot degrees).
        cos = torch.cos(-np.pi * rot / 180.)
        sin = torch.sin(-np.pi * rot / 180.)
        zeros = torch.zeros_like(cos)
        r3 = torch.zeros(cos.shape[0], 1, 3, device=cos.device)
        r3[:, 0, -1] = 1
        R = torch.cat([
            torch.stack([cos, -sin, zeros], dim=-1).unsqueeze(1),
            torch.stack([sin, cos, zeros], dim=-1).unsqueeze(1), r3
        ], dim=1)
        # Only the global orientation (first 3 axis-angle parameters) is rotated.
        global_pose = pose[:, :3]
        global_pose_rotmat = angle_axis_to_rotation_matrix(global_pose)
        global_pose_rotmat_3b3 = global_pose_rotmat[:, :3, :3]
        global_pose_rotmat_3b3 = torch.matmul(R, global_pose_rotmat_3b3)
        global_pose_rotmat[:, :3, :3] = global_pose_rotmat_3b3
        global_pose_rotmat = global_pose_rotmat[:, :-1, :-1].cpu().numpy()
        # Convert the rotated matrices back to axis-angle via cv2.Rodrigues.
        global_pose_np = np.zeros((global_pose.shape[0], 3))
        for i in range(global_pose.shape[0]):
            aa, _ = cv2.Rodrigues(global_pose_rotmat[i])
            global_pose_np[i, :] = aa.squeeze()
        pose[:, :3] = torch.from_numpy(global_pose_np).to(pose.device)
        return pose
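To make the tuple-key protocol of __getitem__/__setitem__ concrete without the fits files the constructor loads, here is a sketch that builds the object by hand; the 'demo' dataset and the identity permutation standing in for constants.SMPL_POSE_FLIP_PERM are purely illustrative:

# Illustration only: exercise the lookup/update protocol with synthetic data.
import torch

fd = FitsDict.__new__(FitsDict)                    # skip __init__ (it loads fits from disk)
fd.fits_dict = {'demo': torch.zeros(100, 82)}      # 72 pose + 10 betas per sample
fd.valid_fit_state = {'demo': torch.ones(100, dtype=torch.uint8)}
fd.flipped_parts = torch.arange(72)                # stand-in for SMPL_POSE_FLIP_PERM

names = ['demo', 'demo']
inds = torch.tensor([3, 7])
rot = torch.zeros(2)                               # rotation angles in degrees
flipped = torch.zeros(2)                           # no sample is flipped
pose, betas = fd[(names, inds, rot, flipped)]      # __getitem__
update = torch.tensor([1, 0], dtype=torch.uint8)   # write back only the first sample
fd[(names, inds, rot, flipped, update)] = (pose, betas)  # __setitem__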
lib/pymaf/core/path_config.py
ADDED
@@ -0,0 +1,23 @@
"""
This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/path_config.py
path configuration
This file contains definitions of useful data structures and the paths
for the datasets and data files necessary to run the code.
Things you need to change: *_ROOT that indicate the path to each dataset
"""
import os
from huggingface_hub import hf_hub_url, cached_download

# pymaf
pymaf_data_dir = hf_hub_url('Yuliang/PyMAF', '')
smpl_data_dir = hf_hub_url('Yuliang/SMPL', '')
SMPL_MODEL_DIR = os.path.join(smpl_data_dir, 'models/smpl')

SMPL_MEAN_PARAMS = cached_download(os.path.join(pymaf_data_dir, 'smpl_mean_params.npz'), use_auth_token=os.environ['ICON'])
MESH_DOWNSAMPLEING = cached_download(os.path.join(pymaf_data_dir, 'mesh_downsampling.npz'), use_auth_token=os.environ['ICON'])
CUBE_PARTS_FILE = cached_download(os.path.join(pymaf_data_dir, 'cube_parts.npy'), use_auth_token=os.environ['ICON'])
JOINT_REGRESSOR_TRAIN_EXTRA = cached_download(os.path.join(pymaf_data_dir, 'J_regressor_extra.npy'), use_auth_token=os.environ['ICON'])
JOINT_REGRESSOR_H36M = cached_download(os.path.join(pymaf_data_dir, 'J_regressor_h36m.npy'), use_auth_token=os.environ['ICON'])
VERTEX_TEXTURE_FILE = cached_download(os.path.join(pymaf_data_dir, 'vertex_texture.npy'), use_auth_token=os.environ['ICON'])
CHECKPOINT_FILE = cached_download(os.path.join(pymaf_data_dir, 'pretrained_model/PyMAF_model_checkpoint.pt'), use_auth_token=os.environ['ICON'])
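Every download above authenticates with a Hugging Face token read from os.environ['ICON'], so importing this module raises a bare KeyError when the variable is unset; a small guard in the same spirit (a sketch, not part of the commit):

# Illustration only: fail fast with a clear message when the token is missing.
import os

if 'ICON' not in os.environ:
    raise EnvironmentError(
        'Set the ICON environment variable to a Hugging Face access token '
        'before importing lib/pymaf/core/path_config.py.')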
lib/pymaf/core/train_options.py
ADDED
@@ -0,0 +1,135 @@
import argparse


class TrainOptions():
    def __init__(self):
        self.parser = argparse.ArgumentParser()

        gen = self.parser.add_argument_group('General')
        gen.add_argument(
            '--resume',
            dest='resume',
            default=False,
            action='store_true',
            help='Resume from checkpoint (uses the latest checkpoint by default)')

        io = self.parser.add_argument_group('io')
        io.add_argument('--log_dir',
                        default='logs',
                        help='Directory to store logs')
        io.add_argument(
            '--pretrained_checkpoint',
            default=None,
            help='Load a pretrained checkpoint at the beginning of training')

        train = self.parser.add_argument_group('Training Options')
        train.add_argument('--num_epochs',
                           type=int,
                           default=200,
                           help='Total number of training epochs')
        train.add_argument('--regressor',
                           type=str,
                           choices=['hmr', 'pymaf_net'],
                           default='pymaf_net',
                           help='Name of the SMPL regressor.')
        train.add_argument('--cfg_file',
                           type=str,
                           default='./configs/pymaf_config.yaml',
                           help='config file path for PyMAF.')
        train.add_argument(
            '--img_res',
            type=int,
            default=224,
            help='Rescale bounding boxes to size [img_res, img_res] before feeding them into the network'
        )
        train.add_argument(
            '--rot_factor',
            type=float,
            default=30,
            help='Random rotation in the range [-rot_factor, rot_factor]')
        train.add_argument(
            '--noise_factor',
            type=float,
            default=0.4,
            help='Randomly multiply pixel values with a factor in the range [1-noise_factor, 1+noise_factor]'
        )
        train.add_argument(
            '--scale_factor',
            type=float,
            default=0.25,
            help='Rescale bounding boxes by a factor in [1-scale_factor, 1+scale_factor]'
        )
        train.add_argument(
            '--openpose_train_weight',
            type=float,
            default=0.,
            help='Weight for OpenPose keypoints during training')
        train.add_argument('--gt_train_weight',
                           type=float,
                           default=1.,
                           help='Weight for GT keypoints during training')
        train.add_argument('--eval_dataset',
                           type=str,
                           default='h36m-p2-mosh',
                           help='Name of the evaluation dataset.')
        train.add_argument('--single_dataset',
                           default=False,
                           action='store_true',
                           help='Use a single dataset')
        train.add_argument('--single_dataname',
                           type=str,
                           default='h36m',
                           help='Name of the single dataset.')
        train.add_argument('--eval_pve',
                           default=False,
                           action='store_true',
                           help='Evaluate PVE')
        train.add_argument('--overwrite',
                           default=False,
                           action='store_true',
                           help='Overwrite the latest checkpoint')

        train.add_argument('--distributed',
                           action='store_true',
                           help='Use distributed training')
        train.add_argument('--dist_backend',
                           default='nccl',
                           type=str,
                           help='distributed backend')
        train.add_argument('--dist_url',
                           default='tcp://127.0.0.1:10356',
                           type=str,
                           help='url used to set up distributed training')
        train.add_argument('--world_size',
                           default=1,
                           type=int,
                           help='number of nodes for distributed training')
        train.add_argument('--local_rank', default=0, type=int)
        train.add_argument('--rank',
                           default=0,
                           type=int,
                           help='node rank for distributed training')
        train.add_argument(
            '--multiprocessing_distributed',
            action='store_true',
            help='Use multi-processing distributed training to launch '
            'N processes per node, which has N GPUs. This is the '
            'fastest way to use PyTorch for either single node or '
            'multi node data parallel training')

        misc = self.parser.add_argument_group('Misc Options')
        misc.add_argument('--misc',
                          help='Modify config options using the command line',
                          default=None,
                          nargs=argparse.REMAINDER)

    def parse_args(self):
        """Parse input arguments."""
        self.args = self.parser.parse_args()
        self.save_dump()
        return self.args

    def save_dump(self):
        """Store all argument values to a json file.
        The default location is logs/expname/args.json.
        """
        pass
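A sketch of how TrainOptions is typically combined with cfgs.parse_args (assumed wiring; the import paths follow the repository layout above, while the files themselves use flat core/utils imports):

# Hypothetical entry point (illustration only).
from lib.pymaf.core.train_options import TrainOptions
from lib.pymaf.core.cfgs import parse_args

options = TrainOptions().parse_args()   # also calls save_dump(), currently a no-op
cfg = parse_args(options)               # merges --cfg_file over the defaults in cfgs.py
print(options.num_epochs, cfg.SEED_VALUE)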