YoonaAI committed on
Commit 3fb8682 · 1 Parent(s): c31e128

Upload 5 files

lib/pymaf/models/__init__.py ADDED
@@ -0,0 +1,3 @@
from .hmr import hmr
from .pymaf_net import pymaf_net
from .smpl import SMPL
lib/pymaf/models/hmr.py ADDED
@@ -0,0 +1,303 @@
# This script is borrowed from https://github.com/nkolot/SPIN/blob/master/models/hmr.py

import torch
import torch.nn as nn
import torchvision.models.resnet as resnet
import numpy as np
import math
from lib.pymaf.utils.geometry import rot6d_to_rotmat

import logging

logger = logging.getLogger(__name__)

BN_MOMENTUM = 0.1


class Bottleneck(nn.Module):
    """ Redefinition of the Bottleneck residual block,
        adapted from the official PyTorch implementation.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes,
                               planes,
                               kernel_size=3,
                               stride=stride,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet_Backbone(nn.Module):
    """ Feature extractor with ResNet backbone """

    def __init__(self, model='res50', pretrained=True):
        if model == 'res50':
            block, layers = Bottleneck, [3, 4, 6, 3]
        else:
            # only res50 is wired up so far; fail loudly instead of
            # continuing with undefined block/layers
            raise NotImplementedError(f'unsupported backbone: {model}')

        self.inplanes = 64
        super().__init__()
        self.deconv_with_bias = False  # assumed default, as in pose_resnet configs
        self.conv1 = nn.Conv2d(3,
                               64,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)

        if pretrained:
            resnet_imagenet = resnet.resnet50(pretrained=True)
            self.load_state_dict(resnet_imagenet.state_dict(), strict=False)
            logger.info('loaded resnet50 imagenet pretrained model')

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,
                          planes * block.expansion,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'

        def _get_deconv_cfg(deconv_kernel, index):
            if deconv_kernel == 4:
                padding = 1
                output_padding = 0
            elif deconv_kernel == 3:
                padding = 1
                output_padding = 1
            elif deconv_kernel == 2:
                padding = 0
                output_padding = 0

            return deconv_kernel, padding, output_padding

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = _get_deconv_cfg(
                num_kernels[i], i)

            planes = num_filters[i]
            layers.append(
                nn.ConvTranspose2d(in_channels=self.inplanes,
                                   out_channels=planes,
                                   kernel_size=kernel,
                                   stride=2,
                                   padding=padding,
                                   output_padding=output_padding,
                                   bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
            layers.append(nn.ReLU(inplace=True))
            self.inplanes = planes

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        xf = self.avgpool(x4)
        xf = xf.view(xf.size(0), -1)

        x_featmap = x4

        return x_featmap, xf


class HMR(nn.Module):
    """ SMPL iterative regressor with ResNet50 backbone """

    def __init__(self, block, layers, smpl_mean_params):
        self.inplanes = 64
        super().__init__()
        npose = 24 * 6
        self.conv1 = nn.Conv2d(3,
                               64,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc1 = nn.Linear(512 * block.expansion + npose + 13, 1024)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout()
        self.decpose = nn.Linear(1024, npose)
        self.decshape = nn.Linear(1024, 10)
        self.deccam = nn.Linear(1024, 3)
        nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
        nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        mean_params = np.load(smpl_mean_params)
        init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        init_shape = torch.from_numpy(
            mean_params['shape'][:].astype('float32')).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,
                          planes * block.expansion,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self,
                x,
                init_pose=None,
                init_shape=None,
                init_cam=None,
                n_iter=3):

        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        xf = self.avgpool(x4)
        xf = xf.view(xf.size(0), -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for i in range(n_iter):
            xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        return pred_rotmat, pred_shape, pred_cam


def hmr(smpl_mean_params, pretrained=True, **kwargs):
    """ Constructs an HMR model with ResNet50 backbone.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = HMR(Bottleneck, [3, 4, 6, 3], smpl_mean_params, **kwargs)
    if pretrained:
        resnet_imagenet = resnet.resnet50(pretrained=True)
        model.load_state_dict(resnet_imagenet.state_dict(), strict=False)
    return model
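
For orientation, a minimal usage sketch of the hmr constructor above, not part of the commit. The mean-parameter path is a placeholder; the real .npz file must supply 'pose' (144,), 'shape' (10,), and 'cam' (3,) arrays, and the 224x224 crop matches the fixed 7x7 average pooling.

import torch
from lib.pymaf.models import hmr

# 'data/smpl_mean_params.npz' is a hypothetical local path.
model = hmr('data/smpl_mean_params.npz', pretrained=True).eval()
images = torch.randn(2, 3, 224, 224)  # batch of normalized 224x224 crops
with torch.no_grad():
    pred_rotmat, pred_shape, pred_cam = model(images)
print(pred_rotmat.shape)  # torch.Size([2, 24, 3, 3]): per-joint rotation matrices
print(pred_shape.shape)   # torch.Size([2, 10]): SMPL betas
print(pred_cam.shape)     # torch.Size([2, 3]): weak-perspective camera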
lib/pymaf/models/maf_extractor.py ADDED
@@ -0,0 +1,137 @@
# This script is borrowed and extended from https://github.com/shunsukesaito/PIFu/blob/master/lib/model/SurfaceClassifier.py

from packaging import version
import torch
import scipy.sparse
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from lib.common.config import cfg
from lib.pymaf.utils.geometry import projection
from lib.pymaf.core.path_config import MESH_DOWNSAMPLEING

import logging

logger = logging.getLogger(__name__)


class MAF_Extractor(nn.Module):
    ''' Mesh-Aligned Feature Extractor

    As discussed in the paper, we extract mesh-aligned features based on the 2D projection
    of the mesh vertices. The features extracted from the spatial feature maps then go
    through an MLP for dimension reduction.
    '''

    def __init__(self, device=torch.device('cuda')):
        super().__init__()

        self.device = device
        self.filters = []
        self.num_views = 1
        filter_channels = cfg.MODEL.PyMAF.MLP_DIM
        self.last_op = nn.ReLU(True)

        for l in range(0, len(filter_channels) - 1):
            if 0 != l:
                self.filters.append(
                    nn.Conv1d(filter_channels[l] + filter_channels[0],
                              filter_channels[l + 1], 1))
            else:
                self.filters.append(
                    nn.Conv1d(filter_channels[l], filter_channels[l + 1], 1))

            self.add_module("conv%d" % l, self.filters[l])

        self.im_feat = None
        self.cam = None

        # downsample SMPL mesh and assign part labels
        # from https://github.com/nkolot/GraphCMR/blob/master/data/mesh_downsampling.npz
        smpl_mesh_graph = np.load(MESH_DOWNSAMPLEING,
                                  allow_pickle=True,
                                  encoding='latin1')

        A = smpl_mesh_graph['A']
        U = smpl_mesh_graph['U']
        D = smpl_mesh_graph['D']  # shape: (2,)

        # downsampling
        ptD = []
        for lv in range(len(D)):
            d = scipy.sparse.coo_matrix(D[lv])
            indices = torch.LongTensor(np.array([d.row, d.col]))
            values = torch.FloatTensor(d.data)
            ptD.append(torch.sparse.FloatTensor(indices, values, d.shape))

        # downsampling mapping from 6890 points to 431 points
        # ptD[0].to_dense() - Size: [1723, 6890]
        # ptD[1].to_dense() - Size: [431, 1723]
        Dmap = torch.matmul(ptD[1].to_dense(),
                            ptD[0].to_dense())  # 6890 -> 431
        self.register_buffer('Dmap', Dmap)

    def reduce_dim(self, feature):
        '''
        Dimension reduction by multi-layer perceptrons
        :param feature: [B, C_s, N] point-wise features before dimension reduction
        :return: [B, C_p x N] concatenation of point-wise features after dimension reduction
        '''
        y = feature
        tmpy = feature
        for i, f in enumerate(self.filters):
            y = self._modules['conv' +
                              str(i)](y if i == 0 else torch.cat([y, tmpy], 1))
            if i != len(self.filters) - 1:
                y = F.leaky_relu(y)
            if self.num_views > 1 and i == len(self.filters) // 2:
                y = y.view(-1, self.num_views, y.shape[1],
                           y.shape[2]).mean(dim=1)
                tmpy = feature.view(-1, self.num_views, feature.shape[1],
                                    feature.shape[2]).mean(dim=1)

        y = self.last_op(y)

        y = y.view(y.shape[0], -1)
        return y

    def sampling(self, points, im_feat=None, z_feat=None):
        '''
        Given 2D points, sample the point-wise features for each point;
        the dimension of point-wise features is reduced from C_s to C_p by the MLP.
        Image features should be pre-computed before this call.
        :param points: [B, N, 2] image coordinates of points
        :param im_feat: [B, C_s, H_s, W_s] spatial feature maps
        :return: [B, C_p x N] concatenation of point-wise features after dimension reduction
        '''
        if im_feat is None:
            im_feat = self.im_feat

        if version.parse(torch.__version__) >= version.parse('1.3.0'):
            # The default grid_sample behavior changed to align_corners=False in 1.3.0.
            point_feat = torch.nn.functional.grid_sample(
                im_feat, points.unsqueeze(2), align_corners=True)[..., 0]
        else:
            point_feat = torch.nn.functional.grid_sample(
                im_feat, points.unsqueeze(2))[..., 0]

        mesh_align_feat = self.reduce_dim(point_feat)
        return mesh_align_feat

    def forward(self, p, s_feat=None, cam=None, **kwargs):
        ''' Returns mesh-aligned features for the 3D mesh points.

        Args:
            p (tensor): [B, N_m, 3] mesh vertices
            s_feat (tensor): [B, C_s, H_s, W_s] spatial feature maps
            cam (tensor): [B, 3] camera
        Return:
            mesh_align_feat (tensor): [B, C_p x N_m] mesh-aligned features
        '''
        if cam is None:
            cam = self.cam
        p_proj_2d = projection(p, cam, retain_z=False)
        mesh_align_feat = self.sampling(p_proj_2d, s_feat)
        return mesh_align_feat
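
The heart of sampling above is a single grid_sample call: the projected vertices act as a [B, N, 1, 2] sampling grid over the feature maps. A self-contained sketch of just that step, with made-up shapes so it runs without cfg or the mesh-downsampling file:

import torch
import torch.nn.functional as F

B, C_s, H_s, W_s, N = 2, 256, 56, 56, 431
im_feat = torch.randn(B, C_s, H_s, W_s)  # spatial feature maps
points = torch.rand(B, N, 2) * 2 - 1     # projected vertices, normalized to [-1, 1]
# unsqueeze(2) turns the point list into a [B, N, 1, 2] grid; [..., 0] drops the dummy axis
point_feat = F.grid_sample(im_feat, points.unsqueeze(2), align_corners=True)[..., 0]
print(point_feat.shape)  # torch.Size([2, 256, 431]): one C_s-dim feature per vertex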
lib/pymaf/models/res_module.py ADDED
@@ -0,0 +1,385 @@
# code brought in part from https://github.com/microsoft/human-pose-estimation.pytorch/blob/master/lib/models/pose_resnet.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import os
from lib.pymaf.core.cfgs import cfg

import logging

logger = logging.getLogger(__name__)

BN_MOMENTUM = 0.1


def conv3x3(in_planes, out_planes, stride=1, bias=False, groups=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes * groups,
                     out_planes * groups,
                     kernel_size=3,
                     stride=stride,
                     padding=1,
                     bias=bias,
                     groups=groups)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride, groups=groups)
        self.bn1 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, groups=groups)
        self.bn2 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes * groups,
                               planes * groups,
                               kernel_size=1,
                               bias=False,
                               groups=groups)
        self.bn1 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes * groups,
                               planes * groups,
                               kernel_size=3,
                               stride=stride,
                               padding=1,
                               bias=False,
                               groups=groups)
        self.bn2 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes * groups,
                               planes * self.expansion * groups,
                               kernel_size=1,
                               bias=False,
                               groups=groups)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion * groups,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


resnet_spec = {
    18: (BasicBlock, [2, 2, 2, 2]),
    34: (BasicBlock, [3, 4, 6, 3]),
    50: (Bottleneck, [3, 4, 6, 3]),
    101: (Bottleneck, [3, 4, 23, 3]),
    152: (Bottleneck, [3, 8, 36, 3])
}


class IUV_predict_layer(nn.Module):
    def __init__(self,
                 feat_dim=256,
                 final_cov_k=3,
                 part_out_dim=25,
                 with_uv=True):
        super().__init__()

        self.with_uv = with_uv
        if self.with_uv:
            self.predict_u = nn.Conv2d(in_channels=feat_dim,
                                       out_channels=25,
                                       kernel_size=final_cov_k,
                                       stride=1,
                                       padding=1 if final_cov_k == 3 else 0)

            self.predict_v = nn.Conv2d(in_channels=feat_dim,
                                       out_channels=25,
                                       kernel_size=final_cov_k,
                                       stride=1,
                                       padding=1 if final_cov_k == 3 else 0)

        self.predict_ann_index = nn.Conv2d(
            in_channels=feat_dim,
            out_channels=15,
            kernel_size=final_cov_k,
            stride=1,
            padding=1 if final_cov_k == 3 else 0)

        self.predict_uv_index = nn.Conv2d(in_channels=feat_dim,
                                          out_channels=25,
                                          kernel_size=final_cov_k,
                                          stride=1,
                                          padding=1 if final_cov_k == 3 else 0)

        self.inplanes = feat_dim

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,
                          planes * block.expansion,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        return_dict = {}

        predict_uv_index = self.predict_uv_index(x)
        predict_ann_index = self.predict_ann_index(x)

        return_dict['predict_uv_index'] = predict_uv_index
        return_dict['predict_ann_index'] = predict_ann_index

        if self.with_uv:
            predict_u = self.predict_u(x)
            predict_v = self.predict_v(x)
            return_dict['predict_u'] = predict_u
            return_dict['predict_v'] = predict_v
        else:
            return_dict['predict_u'] = None
            return_dict['predict_v'] = None

        return return_dict


class SmplResNet(nn.Module):
    def __init__(self,
                 resnet_nums,
                 in_channels=3,
                 num_classes=229,
                 last_stride=2,
                 n_extra_feat=0,
                 truncate=0,
                 **kwargs):
        super().__init__()

        self.inplanes = 64
        self.truncate = truncate
        block, layers = resnet_spec[resnet_nums]

        self.conv1 = nn.Conv2d(in_channels,
                               64,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2],
                                       stride=2) if truncate < 2 else None
        self.layer4 = self._make_layer(
            block, 512, layers[3],
            stride=last_stride) if truncate < 1 else None

        self.avg_pooling = nn.AdaptiveAvgPool2d(1)

        self.num_classes = num_classes
        if num_classes > 0:
            self.final_layer = nn.Linear(512 * block.expansion, num_classes)
            nn.init.xavier_uniform_(self.final_layer.weight, gain=0.01)

        self.n_extra_feat = n_extra_feat
        if n_extra_feat > 0:
            self.trans_conv = nn.Sequential(
                nn.Conv2d(n_extra_feat + 512 * block.expansion,
                          512 * block.expansion,
                          kernel_size=1,
                          bias=False),
                nn.BatchNorm2d(512 * block.expansion, momentum=BN_MOMENTUM),
                nn.ReLU(True))

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,
                          planes * block.expansion,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, infeat=None):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2) if self.truncate < 2 else x2
        x4 = self.layer4(x3) if self.truncate < 1 else x3

        if infeat is not None:
            x4 = self.trans_conv(torch.cat([infeat, x4], 1))

        if self.num_classes > 0:
            xp = self.avg_pooling(x4)
            cls = self.final_layer(xp.view(xp.size(0), -1))
            if not cfg.DANET.USE_MEAN_PARA:
                # keep the predicted scale non-negative
                scale = F.relu(cls[:, 0]).unsqueeze(1)
                cls = torch.cat((scale, cls[:, 1:]), dim=1)
        else:
            cls = None

        return cls, {'x4': x4}

    def init_weights(self, pretrained=''):
        if os.path.isfile(pretrained):
            logger.info('=> loading pretrained model {}'.format(pretrained))
            checkpoint = torch.load(pretrained)
            if isinstance(checkpoint, OrderedDict):
                # drop entries whose shapes do not match this model
                state_dict_old = self.state_dict()
                for key in state_dict_old.keys():
                    if key in checkpoint.keys():
                        if state_dict_old[key].shape != checkpoint[key].shape:
                            del checkpoint[key]
                state_dict = checkpoint
            elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                state_dict_old = checkpoint['state_dict']
                state_dict = OrderedDict()
                # delete the 'module.' prefix saved by DataParallel
                for key in state_dict_old.keys():
                    if key.startswith('module.'):
                        state_dict[key[7:]] = state_dict_old[key]
                    else:
                        state_dict[key] = state_dict_old[key]
            else:
                raise RuntimeError(
                    'No state_dict found in checkpoint file {}'.format(
                        pretrained))
            self.load_state_dict(state_dict, strict=False)
        else:
            logger.error('=> imagenet pretrained model does not exist')
            logger.error('=> please download it first')
            raise ValueError('imagenet pretrained model does not exist')


class LimbResLayers(nn.Module):
    def __init__(self,
                 resnet_nums,
                 inplanes,
                 outplanes=None,
                 groups=1,
                 **kwargs):
        super().__init__()

        self.inplanes = inplanes
        block, layers = resnet_spec[resnet_nums]
        self.outplanes = 512 if outplanes is None else outplanes
        self.layer4 = self._make_layer(block,
                                       self.outplanes,
                                       layers[3],
                                       stride=2,
                                       groups=groups)

        self.avg_pooling = nn.AdaptiveAvgPool2d(1)

    def _make_layer(self, block, planes, blocks, stride=1, groups=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes * groups,
                          planes * block.expansion * groups,
                          kernel_size=1,
                          stride=stride,
                          bias=False,
                          groups=groups),
                nn.BatchNorm2d(planes * block.expansion * groups,
                               momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, groups=groups))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=groups))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.layer4(x)
        x = self.avg_pooling(x)

        return x
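
A sketch of SmplResNet used as a plain feature extractor, assuming the repo's lib.pymaf.core.cfgs module imports cleanly: with num_classes=0 no classification head is built (and the cfg.DANET lookup is skipped), so forward returns None plus the x4 feature map.

import torch
from lib.pymaf.models.res_module import SmplResNet

net = SmplResNet(resnet_nums=50, in_channels=3, num_classes=0).eval()
with torch.no_grad():
    cls, feats = net(torch.randn(1, 3, 224, 224))
print(cls)                # None: no classification head was built
print(feats['x4'].shape)  # torch.Size([1, 2048, 7, 7]): ResNet-50 at stride 32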
lib/pymaf/models/smpl.py ADDED
@@ -0,0 +1,92 @@
# This script is borrowed from https://github.com/nkolot/SPIN/blob/master/models/smpl.py

import torch
import numpy as np
from lib.smplx import SMPL as _SMPL
from lib.smplx.body_models import ModelOutput
from lib.smplx.lbs import vertices2joints
from collections import namedtuple

from lib.pymaf.core import path_config, constants

SMPL_MEAN_PARAMS = path_config.SMPL_MEAN_PARAMS
SMPL_MODEL_DIR = path_config.SMPL_MODEL_DIR

# Indices to get the 14 LSP joints from the 17 H36M joints
H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9]
H36M_TO_J14 = H36M_TO_J17[:14]


class SMPL(_SMPL):
    """ Extension of the official SMPL implementation to support more joints """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES]
        J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer(
            'J_regressor_extra',
            torch.tensor(J_regressor_extra, dtype=torch.float32))
        self.joint_map = torch.tensor(joints, dtype=torch.long)
        self.ModelOutput = namedtuple(
            'ModelOutput_', ModelOutput._fields + (
                'smpl_joints',
                'joints_J19',
            ))
        self.ModelOutput.__new__.__defaults__ = (None, ) * len(
            self.ModelOutput._fields)

    def forward(self, *args, **kwargs):
        kwargs['get_skin'] = True
        smpl_output = super().forward(*args, **kwargs)
        extra_joints = vertices2joints(self.J_regressor_extra,
                                       smpl_output.vertices)
        # smpl_output.joints: [B, 45, 3]; extra_joints: [B, 9, 3]
        vertices = smpl_output.vertices
        joints = torch.cat([smpl_output.joints, extra_joints], dim=1)
        smpl_joints = smpl_output.joints[:, :24]
        joints = joints[:, self.joint_map, :]  # [B, 49, 3]
        joints_J24 = joints[:, -24:, :]
        joints_J19 = joints_J24[:, constants.J24_TO_J19, :]
        output = self.ModelOutput(vertices=vertices,
                                  global_orient=smpl_output.global_orient,
                                  body_pose=smpl_output.body_pose,
                                  joints=joints,
                                  joints_J19=joints_J19,
                                  smpl_joints=smpl_joints,
                                  betas=smpl_output.betas,
                                  full_pose=smpl_output.full_pose)
        return output


def get_smpl_faces():
    smpl = SMPL(SMPL_MODEL_DIR, batch_size=1, create_transl=False)
    return smpl.faces


def get_part_joints(smpl_joints):
    # average the endpoints of each skeleton segment, then append the
    # remaining single joints
    one_seg_pairs = [(0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14),
                     (12, 15), (13, 16), (14, 17)]
    two_seg_pairs = [(1, 4), (2, 5), (4, 7), (5, 8), (16, 18), (17, 19),
                     (18, 20), (19, 21)]

    one_seg_pairs.extend(two_seg_pairs)

    single_joints = [10, 11, 15, 22, 23]

    part_joints = []

    for j_p in one_seg_pairs:
        new_joint = torch.mean(smpl_joints[:, j_p], dim=1, keepdim=True)
        part_joints.append(new_joint)

    for j_p in single_joints:
        part_joints.append(smpl_joints[:, j_p:j_p + 1])

    part_joints = torch.cat(part_joints, dim=1)

    return part_joints
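
To make the bookkeeping in get_part_joints concrete: the 10 one-segment pairs and 8 two-segment pairs are averaged into 18 segment midpoints, then joints 10, 11, 15, 22, and 23 pass through unchanged, for 23 part joints in total. A small sketch, assuming the repo's path_config resolves so the module imports:

import torch
from lib.pymaf.models.smpl import get_part_joints

smpl_joints = torch.randn(2, 24, 3)         # [B, 24, 3] SMPL joint positions
part_joints = get_part_joints(smpl_joints)  # 18 segment midpoints + 5 originals
print(part_joints.shape)                    # torch.Size([2, 23, 3])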