YoonaAI committed on
Commit bf91b23
1 Parent(s): bd58d25

Upload 11 files
lib/net/BasePIFuNet.py ADDED
@@ -0,0 +1,84 @@
+
+ # -*- coding: utf-8 -*-
+
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+ # holder of all proprietary rights on this computer program.
+ # You can only use this computer program if you have closed
+ # a license agreement with MPG or you get the right to use the computer
+ # program from someone who is authorized to grant you that right.
+ # Any use of the computer program without a valid license is prohibited and
+ # liable to prosecution.
+ #
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+ # for Intelligent Systems. All rights reserved.
+ #
+ # Contact: [email protected]
+
+ import torch.nn as nn
+ import pytorch_lightning as pl
+
+ from .geometry import index, orthogonal, perspective
+
+
+ class BasePIFuNet(pl.LightningModule):
+     def __init__(
+         self,
+         projection_mode='orthogonal',
+         error_term=nn.MSELoss(),
+     ):
+         """
+         :param projection_mode:
+             Either orthogonal or perspective.
+             It will call the corresponding function for projection.
+         :param error_term:
+             nn loss between the predicted [B, Res, N] and the label [B, Res, N]
+         """
+         super(BasePIFuNet, self).__init__()
+         self.name = 'base'
+
+         self.error_term = error_term
+
+         self.index = index
+         self.projection = orthogonal if projection_mode == 'orthogonal' else perspective
+
+     def forward(self, points, images, calibs, transforms=None):
+         '''
+         :param points: [B, 3, N] world-space coordinates of points
+         :param images: [B, C, H, W] input images
+         :param calibs: [B, 3, 4] calibration matrices for each image
+         :param transforms: optional [B, 2, 3] image-space coordinate transforms
+         :return: [B, Res, N] predictions for each point
+         '''
+         features = self.filter(images)
+         preds = self.query(features, points, calibs, transforms)
+         return preds
+
+     def filter(self, images):
+         '''
+         Filter the input images and store all intermediate features.
+         :param images: [B, C, H, W] input images
+         '''
+         return None
+
+     def query(self, features, points, calibs, transforms=None):
+         '''
+         Given 3D points, query the network predictions for each point.
+         Image features should be pre-computed before this call.
+         query() may behave differently during training and testing.
+         :param features: image features computed by filter()
+         :param points: [B, 3, N] world-space coordinates of points
+         :param calibs: [B, 3, 4] calibration matrices for each image
+         :param transforms: optional [B, 2, 3] image-space coordinate transforms
+         :return: [B, Res, N] predictions for each point
+         '''
+         return None
+
+     def get_error(self, preds, labels):
+         '''
+         Get the network loss from the last query.
+         :return: loss term
+         '''
+         return self.error_term(preds, labels)
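
For orientation, a minimal sketch of how a subclass is expected to plug into this interface (not part of the commit; the toy layers, shapes, and values are illustrative only):

import torch
import torch.nn as nn
from lib.net.BasePIFuNet import BasePIFuNet

class ToyPIFuNet(BasePIFuNet):
    def __init__(self):
        super().__init__(projection_mode='orthogonal')
        self.backbone = nn.Conv2d(3, 8, 3, padding=1)   # stand-in image filter
        self.head = nn.Conv1d(8 + 1, 1, 1)              # stand-in point regressor

    def filter(self, images):
        return self.backbone(images)                    # [B, 8, H, W]

    def query(self, features, points, calibs, transforms=None):
        xyz = self.projection(points, calibs, transforms)    # [B, 3, N]
        xy, z = xyz[:, :2, :], xyz[:, 2:3, :]
        pixel_feat = self.index(features, xy)                # [B, 8, N]
        return self.head(torch.cat([pixel_feat, z], dim=1))  # [B, 1, N]

net = ToyPIFuNet()
points = torch.rand(2, 3, 100) * 2 - 1      # [B, 3, N] in [-1, 1]
images = torch.rand(2, 3, 64, 64)           # [B, C, H, W]
calibs = torch.eye(3, 4).expand(2, 3, 4)    # [B, 3, 4] identity projection
preds = net(points, images, calibs)         # [B, 1, N]
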
lib/net/FBNet.py ADDED
@@ -0,0 +1,388 @@
+ '''
+ Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu.
+ BSD License. All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+ THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
+ IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
+ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
+ WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
+ OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ '''
+ import torch
+ import torch.nn as nn
+ import functools
+ import numpy as np
+ import pytorch_lightning as pl
+
+
+ ###############################################################################
+ # Functions
+ ###############################################################################
+ def weights_init(m):
+     classname = m.__class__.__name__
+     if classname.find('Conv') != -1:
+         m.weight.data.normal_(0.0, 0.02)
+     elif classname.find('BatchNorm2d') != -1:
+         m.weight.data.normal_(1.0, 0.02)
+         m.bias.data.fill_(0)
+
+
+ def get_norm_layer(norm_type='instance'):
+     if norm_type == 'batch':
+         norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
+     elif norm_type == 'instance':
+         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
+     else:
+         raise NotImplementedError('normalization layer [%s] is not found' %
+                                   norm_type)
+     return norm_layer
+
+
+ def define_G(input_nc,
+              output_nc,
+              ngf,
+              netG,
+              n_downsample_global=3,
+              n_blocks_global=9,
+              n_local_enhancers=1,
+              n_blocks_local=3,
+              norm='instance',
+              gpu_ids=[],
+              last_op=nn.Tanh()):
+     norm_layer = get_norm_layer(norm_type=norm)
+     if netG == 'global':
+         netG = GlobalGenerator(input_nc,
+                                output_nc,
+                                ngf,
+                                n_downsample_global,
+                                n_blocks_global,
+                                norm_layer,
+                                last_op=last_op)
+     elif netG == 'local':
+         netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global,
+                              n_blocks_global, n_local_enhancers,
+                              n_blocks_local, norm_layer)
+     elif netG == 'encoder':
+         netG = Encoder(input_nc, output_nc, ngf, n_downsample_global,
+                        norm_layer)
+     else:
+         raise NotImplementedError('generator [%s] is not implemented' % netG)
+     # print(netG)
+     if len(gpu_ids) > 0:
+         assert torch.cuda.is_available()
+         device = torch.device(f"cuda:{gpu_ids[0]}")
+         netG = netG.to(device)
+     netG.apply(weights_init)
+     return netG
+
+
+ def print_network(net):
+     if isinstance(net, list):
+         net = net[0]
+     num_params = 0
+     for param in net.parameters():
+         num_params += param.numel()
+     print(net)
+     print('Total number of parameters: %d' % num_params)
+
+
+ ##############################################################################
+ # Generator
+ ##############################################################################
+ class LocalEnhancer(pl.LightningModule):
+     def __init__(self,
+                  input_nc,
+                  output_nc,
+                  ngf=32,
+                  n_downsample_global=3,
+                  n_blocks_global=9,
+                  n_local_enhancers=1,
+                  n_blocks_local=3,
+                  norm_layer=nn.BatchNorm2d,
+                  padding_type='reflect'):
+         super(LocalEnhancer, self).__init__()
+         self.n_local_enhancers = n_local_enhancers
+
+         ###### global generator model #####
+         ngf_global = ngf * (2**n_local_enhancers)
+         model_global = GlobalGenerator(input_nc, output_nc, ngf_global,
+                                        n_downsample_global, n_blocks_global,
+                                        norm_layer).model
+         # get rid of the final convolution layers
+         model_global = [model_global[i] for i in range(len(model_global) - 3)]
+         self.model = nn.Sequential(*model_global)
+
+         ###### local enhancer layers #####
+         for n in range(1, n_local_enhancers + 1):
+             # downsample
+             ngf_global = ngf * (2**(n_local_enhancers - n))
+             model_downsample = [
+                 nn.ReflectionPad2d(3),
+                 nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
+                 norm_layer(ngf_global),
+                 nn.ReLU(True),
+                 nn.Conv2d(ngf_global,
+                           ngf_global * 2,
+                           kernel_size=3,
+                           stride=2,
+                           padding=1),
+                 norm_layer(ngf_global * 2),
+                 nn.ReLU(True)
+             ]
+             # residual blocks
+             model_upsample = []
+             for i in range(n_blocks_local):
+                 model_upsample += [
+                     ResnetBlock(ngf_global * 2,
+                                 padding_type=padding_type,
+                                 norm_layer=norm_layer)
+                 ]
+
+             # upsample
+             model_upsample += [
+                 nn.ConvTranspose2d(ngf_global * 2,
+                                    ngf_global,
+                                    kernel_size=3,
+                                    stride=2,
+                                    padding=1,
+                                    output_padding=1),
+                 norm_layer(ngf_global),
+                 nn.ReLU(True)
+             ]
+
+             # final convolution
+             if n == n_local_enhancers:
+                 model_upsample += [
+                     nn.ReflectionPad2d(3),
+                     nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
+                     nn.Tanh()
+                 ]
+
+             setattr(self, 'model' + str(n) + '_1',
+                     nn.Sequential(*model_downsample))
+             setattr(self, 'model' + str(n) + '_2',
+                     nn.Sequential(*model_upsample))
+
+         self.downsample = nn.AvgPool2d(3,
+                                        stride=2,
+                                        padding=[1, 1],
+                                        count_include_pad=False)
+
+     def forward(self, input):
+         # create input pyramid
+         input_downsampled = [input]
+         for i in range(self.n_local_enhancers):
+             input_downsampled.append(self.downsample(input_downsampled[-1]))
+
+         # output at the coarsest level
+         output_prev = self.model(input_downsampled[-1])
+         # build up one layer at a time
+         for n_local_enhancers in range(1, self.n_local_enhancers + 1):
+             model_downsample = getattr(self,
+                                        'model' + str(n_local_enhancers) + '_1')
+             model_upsample = getattr(self,
+                                      'model' + str(n_local_enhancers) + '_2')
+             input_i = input_downsampled[self.n_local_enhancers -
+                                         n_local_enhancers]
+             output_prev = model_upsample(
+                 model_downsample(input_i) + output_prev)
+         return output_prev
+
+
+ class GlobalGenerator(pl.LightningModule):
+     def __init__(self,
+                  input_nc,
+                  output_nc,
+                  ngf=64,
+                  n_downsampling=3,
+                  n_blocks=9,
+                  norm_layer=nn.BatchNorm2d,
+                  padding_type='reflect',
+                  last_op=nn.Tanh()):
+         assert n_blocks >= 0
+         super(GlobalGenerator, self).__init__()
+         activation = nn.ReLU(True)
+
+         model = [
+             nn.ReflectionPad2d(3),
+             nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
+             norm_layer(ngf), activation
+         ]
+         # downsample
+         for i in range(n_downsampling):
+             mult = 2**i
+             model += [
+                 nn.Conv2d(ngf * mult,
+                           ngf * mult * 2,
+                           kernel_size=3,
+                           stride=2,
+                           padding=1),
+                 norm_layer(ngf * mult * 2), activation
+             ]
+
+         # resnet blocks
+         mult = 2**n_downsampling
+         for i in range(n_blocks):
+             model += [
+                 ResnetBlock(ngf * mult,
+                             padding_type=padding_type,
+                             activation=activation,
+                             norm_layer=norm_layer)
+             ]
+
+         # upsample
+         for i in range(n_downsampling):
+             mult = 2**(n_downsampling - i)
+             model += [
+                 nn.ConvTranspose2d(ngf * mult,
+                                    int(ngf * mult / 2),
+                                    kernel_size=3,
+                                    stride=2,
+                                    padding=1,
+                                    output_padding=1),
+                 norm_layer(int(ngf * mult / 2)), activation
+             ]
+         model += [
+             nn.ReflectionPad2d(3),
+             nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)
+         ]
+         if last_op is not None:
+             model += [last_op]
+         self.model = nn.Sequential(*model)
+
+     def forward(self, input):
+         return self.model(input)
+
+
+ # Define a resnet block
+ class ResnetBlock(pl.LightningModule):
+     def __init__(self,
+                  dim,
+                  padding_type,
+                  norm_layer,
+                  activation=nn.ReLU(True),
+                  use_dropout=False):
+         super(ResnetBlock, self).__init__()
+         self.conv_block = self.build_conv_block(dim, padding_type, norm_layer,
+                                                 activation, use_dropout)
+
+     def build_conv_block(self, dim, padding_type, norm_layer, activation,
+                          use_dropout):
+         conv_block = []
+         p = 0
+         if padding_type == 'reflect':
+             conv_block += [nn.ReflectionPad2d(1)]
+         elif padding_type == 'replicate':
+             conv_block += [nn.ReplicationPad2d(1)]
+         elif padding_type == 'zero':
+             p = 1
+         else:
+             raise NotImplementedError('padding [%s] is not implemented' %
+                                       padding_type)
+
+         conv_block += [
+             nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+             norm_layer(dim), activation
+         ]
+         if use_dropout:
+             conv_block += [nn.Dropout(0.5)]
+
+         p = 0
+         if padding_type == 'reflect':
+             conv_block += [nn.ReflectionPad2d(1)]
+         elif padding_type == 'replicate':
+             conv_block += [nn.ReplicationPad2d(1)]
+         elif padding_type == 'zero':
+             p = 1
+         else:
+             raise NotImplementedError('padding [%s] is not implemented' %
+                                       padding_type)
+         conv_block += [
+             nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+             norm_layer(dim)
+         ]
+
+         return nn.Sequential(*conv_block)
+
+     def forward(self, x):
+         out = x + self.conv_block(x)
+         return out
+
+
+ class Encoder(pl.LightningModule):
+     def __init__(self,
+                  input_nc,
+                  output_nc,
+                  ngf=32,
+                  n_downsampling=4,
+                  norm_layer=nn.BatchNorm2d):
+         super(Encoder, self).__init__()
+         self.output_nc = output_nc
+
+         model = [
+             nn.ReflectionPad2d(3),
+             nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
+             norm_layer(ngf),
+             nn.ReLU(True)
+         ]
+         # downsample
+         for i in range(n_downsampling):
+             mult = 2**i
+             model += [
+                 nn.Conv2d(ngf * mult,
+                           ngf * mult * 2,
+                           kernel_size=3,
+                           stride=2,
+                           padding=1),
+                 norm_layer(ngf * mult * 2),
+                 nn.ReLU(True)
+             ]
+
+         # upsample
+         for i in range(n_downsampling):
+             mult = 2**(n_downsampling - i)
+             model += [
+                 nn.ConvTranspose2d(ngf * mult,
+                                    int(ngf * mult / 2),
+                                    kernel_size=3,
+                                    stride=2,
+                                    padding=1,
+                                    output_padding=1),
+                 norm_layer(int(ngf * mult / 2)),
+                 nn.ReLU(True)
+             ]
+
+         model += [
+             nn.ReflectionPad2d(3),
+             nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
+             nn.Tanh()
+         ]
+         self.model = nn.Sequential(*model)
+
+     def forward(self, input, inst):
+         outputs = self.model(input)
+
+         # instance-wise average pooling
+         outputs_mean = outputs.clone()
+         inst_list = np.unique(inst.cpu().numpy().astype(int))
+         for i in inst_list:
+             for b in range(input.size()[0]):
+                 indices = (inst[b:b + 1] == int(i)).nonzero()  # n x 4
+                 for j in range(self.output_nc):
+                     output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j,
+                                          indices[:, 2], indices[:, 3]]
+                     mean_feat = torch.mean(output_ins).expand_as(output_ins)
+                     outputs_mean[indices[:, 0] + b, indices[:, 1] + j,
+                                  indices[:, 2], indices[:, 3]] = mean_feat
+         return outputs_mean
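
A quick usage sketch for define_G (not part of the commit; channel counts and sizes are illustrative). It mirrors how NormalNet.py, later in this commit, builds its generators:

import torch
import torch.nn as nn
from lib.net.FBNet import define_G

# pix2pixHD-style global generator: 6 input channels (e.g. RGB + a normal map),
# 3 output channels, 64 base filters, 4 downsampling stages, 9 ResNet blocks.
netG = define_G(6, 3, 64, "global", n_downsample_global=4,
                n_blocks_global=9, norm="instance", last_op=nn.Tanh())

x = torch.rand(1, 6, 512, 512) * 2 - 1  # inputs normalized to [-1, 1]
y = netG(x)                             # [1, 3, 512, 512], Tanh-bounded
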
lib/net/HGFilters.py ADDED
@@ -0,0 +1,197 @@
+
+ # -*- coding: utf-8 -*-
+
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+ # holder of all proprietary rights on this computer program.
+ # You can only use this computer program if you have closed
+ # a license agreement with MPG or you get the right to use the computer
+ # program from someone who is authorized to grant you that right.
+ # Any use of the computer program without a valid license is prohibited and
+ # liable to prosecution.
+ #
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+ # for Intelligent Systems. All rights reserved.
+ #
+ # Contact: [email protected]
+
+ from lib.net.net_util import *
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class HourGlass(nn.Module):
+     def __init__(self, num_modules, depth, num_features, opt):
+         super(HourGlass, self).__init__()
+         self.num_modules = num_modules
+         self.depth = depth
+         self.features = num_features
+         self.opt = opt
+
+         self._generate_network(self.depth)
+
+     def _generate_network(self, level):
+         self.add_module('b1_' + str(level),
+                         ConvBlock(self.features, self.features, self.opt))
+
+         self.add_module('b2_' + str(level),
+                         ConvBlock(self.features, self.features, self.opt))
+
+         if level > 1:
+             self._generate_network(level - 1)
+         else:
+             self.add_module('b2_plus_' + str(level),
+                             ConvBlock(self.features, self.features, self.opt))
+
+         self.add_module('b3_' + str(level),
+                         ConvBlock(self.features, self.features, self.opt))
+
+     def _forward(self, level, inp):
+         # Upper branch
+         up1 = inp
+         up1 = self._modules['b1_' + str(level)](up1)
+
+         # Lower branch
+         low1 = F.avg_pool2d(inp, 2, stride=2)
+         low1 = self._modules['b2_' + str(level)](low1)
+
+         if level > 1:
+             low2 = self._forward(level - 1, low1)
+         else:
+             low2 = low1
+             low2 = self._modules['b2_plus_' + str(level)](low2)
+
+         low3 = low2
+         low3 = self._modules['b3_' + str(level)](low3)
+
+         # NOTE: for newer PyTorch (1.3+), training results seem to degrade due to an
+         # implementation difference in F.grid_sample; if the pretrained model behaves
+         # oddly, switch to the commented line below.
+         # NOTE: "bicubic" was also found to work better than "nearest".
+         up2 = F.interpolate(low3,
+                             scale_factor=2,
+                             mode='bicubic',
+                             align_corners=True)
+         # up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
+
+         return up1 + up2
+
+     def forward(self, x):
+         return self._forward(self.depth, x)
+
+
+ class HGFilter(nn.Module):
+     def __init__(self, opt, num_modules, in_dim):
+         super(HGFilter, self).__init__()
+         self.num_modules = num_modules
+
+         self.opt = opt
+         [k, s, d, p] = self.opt.conv1
+
+         # self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3)
+         self.conv1 = nn.Conv2d(in_dim,
+                                64,
+                                kernel_size=k,
+                                stride=s,
+                                dilation=d,
+                                padding=p)
+
+         if self.opt.norm == 'batch':
+             self.bn1 = nn.BatchNorm2d(64)
+         elif self.opt.norm == 'group':
+             self.bn1 = nn.GroupNorm(32, 64)
+
+         if self.opt.hg_down == 'conv64':
+             self.conv2 = ConvBlock(64, 64, self.opt)
+             self.down_conv2 = nn.Conv2d(64,
+                                         128,
+                                         kernel_size=3,
+                                         stride=2,
+                                         padding=1)
+         elif self.opt.hg_down == 'conv128':
+             self.conv2 = ConvBlock(64, 128, self.opt)
+             self.down_conv2 = nn.Conv2d(128,
+                                         128,
+                                         kernel_size=3,
+                                         stride=2,
+                                         padding=1)
+         elif self.opt.hg_down == 'ave_pool':
+             self.conv2 = ConvBlock(64, 128, self.opt)
+         else:
+             raise NameError('Unknown Fan Filter setting!')
+
+         self.conv3 = ConvBlock(128, 128, self.opt)
+         self.conv4 = ConvBlock(128, 256, self.opt)
+
+         # Stacking part
+         for hg_module in range(self.num_modules):
+             self.add_module('m' + str(hg_module),
+                             HourGlass(1, opt.num_hourglass, 256, self.opt))
+
+             self.add_module('top_m_' + str(hg_module),
+                             ConvBlock(256, 256, self.opt))
+             self.add_module(
+                 'conv_last' + str(hg_module),
+                 nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
+             if self.opt.norm == 'batch':
+                 self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
+             elif self.opt.norm == 'group':
+                 self.add_module('bn_end' + str(hg_module),
+                                 nn.GroupNorm(32, 256))
+
+             self.add_module(
+                 'l' + str(hg_module),
+                 nn.Conv2d(256,
+                           opt.hourglass_dim,
+                           kernel_size=1,
+                           stride=1,
+                           padding=0))
+
+             if hg_module < self.num_modules - 1:
+                 self.add_module(
+                     'bl' + str(hg_module),
+                     nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
+                 self.add_module(
+                     'al' + str(hg_module),
+                     nn.Conv2d(opt.hourglass_dim,
+                               256,
+                               kernel_size=1,
+                               stride=1,
+                               padding=0))
+
+     def forward(self, x):
+         x = F.relu(self.bn1(self.conv1(x)), True)
+         tmpx = x
+         if self.opt.hg_down == 'ave_pool':
+             x = F.avg_pool2d(self.conv2(x), 2, stride=2)
+         elif self.opt.hg_down in ['conv64', 'conv128']:
+             x = self.conv2(x)
+             x = self.down_conv2(x)
+         else:
+             raise NameError('Unknown Fan Filter setting!')
+
+         x = self.conv3(x)
+         x = self.conv4(x)
+
+         previous = x
+
+         outputs = []
+         for i in range(self.num_modules):
+             hg = self._modules['m' + str(i)](previous)
+
+             ll = hg
+             ll = self._modules['top_m_' + str(i)](ll)
+
+             ll = F.relu(
+                 self._modules['bn_end' + str(i)](
+                     self._modules['conv_last' + str(i)](ll)), True)
+
+             # Predict heatmaps
+             tmp_out = self._modules['l' + str(i)](ll)
+             outputs.append(tmp_out)
+
+             if i < self.num_modules - 1:
+                 ll = self._modules['bl' + str(i)](ll)
+                 tmp_out_ = self._modules['al' + str(i)](tmp_out)
+                 previous = previous + ll + tmp_out_
+
+         return outputs
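
A hypothetical instantiation for shape orientation (not part of the commit). The opt fields below are assumptions about the project config, and ConvBlock (defined in net_util.py, which is truncated in this view) is assumed to consume only opt.norm:

import torch
from types import SimpleNamespace
from lib.net.HGFilters import HGFilter

opt = SimpleNamespace(conv1=[7, 2, 1, 3],  # kernel, stride, dilation, padding
                      norm='group',
                      hg_down='ave_pool',
                      num_hourglass=2,
                      hourglass_dim=256)

filt = HGFilter(opt, num_modules=4, in_dim=6)
feats = filt(torch.rand(1, 6, 512, 512))
print(len(feats), feats[0].shape)  # 4 stacks, each [1, 256, 128, 128]
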
lib/net/HGPIFuNet.py ADDED
@@ -0,0 +1,403 @@
+
+ # -*- coding: utf-8 -*-
+
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+ # holder of all proprietary rights on this computer program.
+ # You can only use this computer program if you have closed
+ # a license agreement with MPG or you get the right to use the computer
+ # program from someone who is authorized to grant you that right.
+ # Any use of the computer program without a valid license is prohibited and
+ # liable to prosecution.
+ #
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+ # for Intelligent Systems. All rights reserved.
+ #
+ # Contact: [email protected]
+
+ from lib.net.voxelize import Voxelization
+ from lib.dataset.mesh_util import cal_sdf_batch, feat_select, read_smpl_constants
+ from lib.net.NormalNet import NormalNet
+ from lib.net.MLP import MLP
+ from lib.dataset.mesh_util import SMPLX
+ from lib.net.VE import VolumeEncoder
+ from lib.net.HGFilters import *
+ from termcolor import colored
+ from lib.net.BasePIFuNet import BasePIFuNet
+ import torch.nn as nn
+ import torch
+
+
+ maskout = False
+
+
+ class HGPIFuNet(BasePIFuNet):
+     '''
+     HGPIFuNet uses stacked hourglass networks as the image filter.
+     It does the following:
+         1. Compute image feature stacks and store them in self.im_feat_list;
+            self.im_feat_list[-1] is the last (output) stack.
+         2. Calculate calibration.
+         3. If training, index into every intermediate stack;
+            if testing, index into the last stack only.
+         4. Classification.
+         5. During training, the error is calculated on all stacks.
+     '''
+
+     def __init__(self,
+                  cfg,
+                  projection_mode='orthogonal',
+                  error_term=nn.MSELoss()):
+
+         super(HGPIFuNet, self).__init__(projection_mode=projection_mode,
+                                         error_term=error_term)
+
+         self.l1_loss = nn.SmoothL1Loss()
+         self.opt = cfg.net
+         self.root = cfg.root
+         self.overfit = cfg.overfit
+
+         channels_IF = self.opt.mlp_dim
+
+         self.use_filter = self.opt.use_filter
+         self.prior_type = self.opt.prior_type
+         self.smpl_feats = self.opt.smpl_feats
+
+         self.smpl_dim = self.opt.smpl_dim
+         self.voxel_dim = self.opt.voxel_dim
+         self.hourglass_dim = self.opt.hourglass_dim
+         self.sdf_clip = cfg.sdf_clip / 100.0
+
+         self.in_geo = [item[0] for item in self.opt.in_geo]
+         self.in_nml = [item[0] for item in self.opt.in_nml]
+
+         self.in_geo_dim = sum([item[1] for item in self.opt.in_geo])
+         self.in_nml_dim = sum([item[1] for item in self.opt.in_nml])
+
+         self.in_total = self.in_geo + self.in_nml
+         self.smpl_feat_dict = None
+         self.smplx_data = SMPLX()
+
+         if self.prior_type == 'icon':
+             if 'image' in self.in_geo:
+                 self.channels_filter = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 6, 7, 8]]
+             else:
+                 self.channels_filter = [[0, 1, 2], [3, 4, 5]]
+         else:
+             if 'image' in self.in_geo:
+                 self.channels_filter = [[0, 1, 2, 3, 4, 5, 6, 7, 8]]
+             else:
+                 self.channels_filter = [[0, 1, 2, 3, 4, 5]]
+
+         channels_IF[0] = self.hourglass_dim if self.use_filter else len(
+             self.channels_filter[0])
+
+         if self.prior_type == 'icon' and 'vis' not in self.smpl_feats:
+             if self.use_filter:
+                 channels_IF[0] += self.hourglass_dim
+             else:
+                 channels_IF[0] += len(self.channels_filter[0])
+
+         if self.prior_type == 'icon':
+             channels_IF[0] += self.smpl_dim
+         elif self.prior_type == 'pamir':
+             channels_IF[0] += self.voxel_dim
+             smpl_vertex_code, smpl_face_code, smpl_faces, smpl_tetras = \
+                 read_smpl_constants(self.smplx_data.tedra_dir)
+             self.voxelization = Voxelization(
+                 smpl_vertex_code,
+                 smpl_face_code,
+                 smpl_faces,
+                 smpl_tetras,
+                 volume_res=128,
+                 sigma=0.05,
+                 smooth_kernel_size=7,
+                 batch_size=cfg.batch_size,
+                 device=torch.device(f"cuda:{cfg.gpus[0]}"))
+             self.ve = VolumeEncoder(3, self.voxel_dim, self.opt.num_stack)
+         else:
+             channels_IF[0] += 1
+
+         self.icon_keys = ["smpl_verts", "smpl_faces", "smpl_vis", "smpl_cmap"]
+         self.pamir_keys = [
+             "voxel_verts", "voxel_faces", "pad_v_num", "pad_f_num"
+         ]
+
+         self.if_regressor = MLP(
+             filter_channels=channels_IF,
+             name='if',
+             res_layers=self.opt.res_layers,
+             norm=self.opt.norm_mlp,
+             last_op=nn.Sigmoid() if not cfg.test_mode else None)
+
+         # network
+         if self.use_filter:
+             if self.opt.gtype == "HGPIFuNet":
+                 self.F_filter = HGFilter(self.opt, self.opt.num_stack,
+                                          len(self.channels_filter[0]))
+             else:
+                 print(
+                     colored(f"Backbone {self.opt.gtype} is unimplemented",
+                             'green'))
+
+         summary_log = f"{self.prior_type.upper()}:\n" + \
+             f"w/ Global Image Encoder: {self.use_filter}\n" + \
+             f"Image Features used by MLP: {self.in_geo}\n"
+
+         if self.prior_type == "icon":
+             summary_log += f"Geometry Features used by MLP: {self.smpl_feats}\n"
+             summary_log += "Dim of Image Features (local): 6\n"
+             summary_log += f"Dim of Geometry Features (ICON): {self.smpl_dim}\n"
+         elif self.prior_type == "pamir":
+             summary_log += f"Dim of Image Features (global): {self.hourglass_dim}\n"
+             summary_log += f"Dim of Geometry Features (PaMIR): {self.voxel_dim}\n"
+         else:
+             summary_log += f"Dim of Image Features (global): {self.hourglass_dim}\n"
+             summary_log += "Dim of Geometry Features (PIFu): 1 (z-value)\n"
+
+         summary_log += f"Dim of MLP's first layer: {channels_IF[0]}\n"
+
+         print(colored(summary_log, "yellow"))
+
+         self.normal_filter = NormalNet(cfg)
+         init_net(self)
+
+     def get_normal(self, in_tensor_dict):
+
+         # insert normal features
+         if (not self.training) and (not self.overfit):
+             # print(colored("infer normal", "blue"))
+             with torch.no_grad():
+                 feat_lst = []
+                 if "image" in self.in_geo:
+                     feat_lst.append(in_tensor_dict['image'])  # [1, 3, 512, 512]
+                 if 'normal_F' in self.in_geo and 'normal_B' in self.in_geo:
+                     if 'normal_F' not in in_tensor_dict.keys() or \
+                             'normal_B' not in in_tensor_dict.keys():
+                         (nmlF, nmlB) = self.normal_filter(in_tensor_dict)
+                     else:
+                         nmlF = in_tensor_dict['normal_F']
+                         nmlB = in_tensor_dict['normal_B']
+                     feat_lst.append(nmlF)  # [1, 3, 512, 512]
+                     feat_lst.append(nmlB)  # [1, 3, 512, 512]
+                 in_filter = torch.cat(feat_lst, dim=1)
+         else:
+             in_filter = torch.cat([in_tensor_dict[key] for key in self.in_geo],
+                                   dim=1)
+
+         return in_filter
+
+     def get_mask(self, in_filter, size=128):
+
+         mask = F.interpolate(in_filter[:, self.channels_filter[0]],
+                              size=(size, size),
+                              mode="bilinear",
+                              align_corners=True).abs().sum(
+                                  dim=1, keepdim=True) != 0.0
+
+         return mask
+
+     def filter(self, in_tensor_dict, return_inter=False):
+         '''
+         Filter the input images and store all intermediate features.
+         :param in_tensor_dict: dict of [B, C, H, W] input tensors
+         '''
+
+         in_filter = self.get_normal(in_tensor_dict)
+
+         features_G = []
+
+         if self.prior_type == 'icon':
+             if self.use_filter:
+                 # each: [(B, hg_dim, 128, 128) * 4]
+                 features_F = self.F_filter(
+                     in_filter[:, self.channels_filter[0]])
+                 features_B = self.F_filter(
+                     in_filter[:, self.channels_filter[1]])
+             else:
+                 features_F = [in_filter[:, self.channels_filter[0]]]
+                 features_B = [in_filter[:, self.channels_filter[1]]]
+             for idx in range(len(features_F)):
+                 features_G.append(
+                     torch.cat([features_F[idx], features_B[idx]], dim=1))
+         else:
+             if self.use_filter:
+                 features_G = self.F_filter(
+                     in_filter[:, self.channels_filter[0]])
+             else:
+                 features_G = [in_filter[:, self.channels_filter[0]]]
+
+         if self.prior_type == 'icon':
+             self.smpl_feat_dict = {
+                 k: in_tensor_dict[k] for k in self.icon_keys
+             }
+         elif self.prior_type == "pamir":
+             self.smpl_feat_dict = {
+                 k: in_tensor_dict[k] for k in self.pamir_keys
+             }
+         else:
+             pass
+             # print(colored("use z rather than icon or pamir", "green"))
+
+         # if not training, only produce the last im_feat
+         if not self.training:
+             features_out = [features_G[-1]]
+         else:
+             features_out = features_G
+
+         if maskout:
+             features_out_mask = []
+             for feat in features_out:
+                 features_out_mask.append(
+                     feat * self.get_mask(in_filter, size=feat.shape[2]))
+             features_out = features_out_mask
+
+         if return_inter:
+             return features_out, in_filter
+         else:
+             return features_out
+
+     def query(self, features, points, calibs, transforms=None, regressor=None):
+
+         xyz = self.projection(points, calibs, transforms)
+
+         (xy, z) = xyz.split([2, 1], dim=1)
+
+         in_cube = (xyz > -1.0) & (xyz < 1.0)
+         in_cube = in_cube.all(dim=1, keepdim=True).detach().float()
+
+         preds_list = []
+
+         if self.prior_type == 'icon':
+
+             # smpl_verts [B, N_vert, 3]
+             # smpl_faces [B, N_face, 3]
+             # points     [B, 3, N]
+
+             smpl_sdf, smpl_norm, smpl_cmap, smpl_vis = cal_sdf_batch(
+                 self.smpl_feat_dict['smpl_verts'],
+                 self.smpl_feat_dict['smpl_faces'],
+                 self.smpl_feat_dict['smpl_cmap'],
+                 self.smpl_feat_dict['smpl_vis'],
+                 xyz.permute(0, 2, 1).contiguous())
+
+             # smpl_sdf  [B, N, 1]
+             # smpl_norm [B, N, 3]
+             # smpl_cmap [B, N, 3]
+             # smpl_vis  [B, N, 1]
+
+             feat_lst = [smpl_sdf]
+             if 'cmap' in self.smpl_feats:
+                 feat_lst.append(smpl_cmap)
+             if 'norm' in self.smpl_feats:
+                 feat_lst.append(smpl_norm)
+             if 'vis' in self.smpl_feats:
+                 feat_lst.append(smpl_vis)
+
+             smpl_feat = torch.cat(feat_lst, dim=2).permute(0, 2, 1)
+             vol_feats = features
+
+         elif self.prior_type == "pamir":
+
+             voxel_verts = self.smpl_feat_dict['voxel_verts'][
+                 :, :-self.smpl_feat_dict['pad_v_num'][0], :]
+             voxel_faces = self.smpl_feat_dict['voxel_faces'][
+                 :, :-self.smpl_feat_dict['pad_f_num'][0], :]
+
+             self.voxelization.update_param(
+                 batch_size=voxel_faces.shape[0],
+                 smpl_tetra=voxel_faces[0].detach().cpu().numpy())
+             vol = self.voxelization(voxel_verts)  # vol ~ [0, 1]
+             vol_feats = self.ve(vol, intermediate_output=self.training)
+         else:
+             vol_feats = features
+
+         for im_feat, vol_feat in zip(features, vol_feats):
+
+             # [B, Feat_i + z, N]
+             # normal feature selection by smpl_vis
+             if self.prior_type == 'icon':
+                 if 'vis' in self.smpl_feats:
+                     point_local_feat = feat_select(self.index(im_feat, xy),
+                                                    smpl_feat[:, [-1], :])
+                     if maskout:
+                         normal_mask = torch.tile(
+                             point_local_feat.sum(dim=1, keepdim=True) == 0.0,
+                             (1, smpl_feat.shape[1], 1))
+                         normal_mask[:, 1:, :] = False
+                         smpl_feat[normal_mask] = -1.0
+                     point_feat_list = [point_local_feat, smpl_feat[:, :-1, :]]
+                 else:
+                     point_local_feat = self.index(im_feat, xy)
+                     point_feat_list = [point_local_feat, smpl_feat[:, :, :]]
+
+             elif self.prior_type == 'pamir':
+                 # im_feat  [B, hg_dim, 128, 128]
+                 # vol_feat [B, vol_dim, 32, 32, 32]
+                 point_feat_list = [
+                     self.index(im_feat, xy),
+                     self.index(vol_feat, xyz)
+                 ]
+
+             else:
+                 point_feat_list = [self.index(im_feat, xy), z]
+
+             point_feat = torch.cat(point_feat_list, 1)
+
+             # points outside the image plane are always set to 0
+             preds = regressor(point_feat)
+             preds = in_cube * preds
+
+             preds_list.append(preds)
+
+         return preds_list
+
+     def get_error(self, preds_if_list, labels):
+         """Calculate error.
+
+         Args:
+             preds_if_list (list): list of torch.tensor(B, 3, N)
+             labels (torch.tensor): (B, N_knn, N)
+
+         Returns:
+             torch.tensor: error
+         """
+         error_if = 0
+
+         for pred_id in range(len(preds_if_list)):
+             pred_if = preds_if_list[pred_id]
+             error_if += self.error_term(pred_if, labels)
+
+         error_if /= len(preds_if_list)
+
+         return error_if
+
+     def forward(self, in_tensor_dict):
+         """
+         sample_tensor    [B, 3, N]
+         calib_tensor     [B, 4, 4]
+         label_tensor     [B, 1, N]
+         smpl_feat_tensor [B, 59, N]
+         """
+
+         sample_tensor = in_tensor_dict['sample']
+         calib_tensor = in_tensor_dict['calib']
+         label_tensor = in_tensor_dict['label']
+
+         in_feat = self.filter(in_tensor_dict)
+
+         preds_if_list = self.query(in_feat,
+                                    sample_tensor,
+                                    calib_tensor,
+                                    regressor=self.if_regressor)
+
+         error = self.get_error(preds_if_list, label_tensor)
+
+         return preds_if_list[-1], error
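
To make the channel bookkeeping in __init__ concrete, here is the MLP's first-layer width worked through for the ICON prior (a sketch; the dims are hypothetical defaults, not values asserted by this commit):

# Mirrors the channels_IF[0] logic in __init__ above; all values hypothetical.
use_filter = True
prior_type = 'icon'
hourglass_dim = 256                  # HGFilter output channels per side
smpl_feats = ['sdf', 'cmap', 'norm', 'vis']
smpl_dim = 7                         # e.g. sdf(1) + cmap(3) + norm(3); vis is
                                     # consumed by feat_select() in query()
channels_filter_0 = [0, 1, 2, 3, 4, 5]

dim0 = hourglass_dim if use_filter else len(channels_filter_0)
if prior_type == 'icon' and 'vis' not in smpl_feats:
    dim0 += hourglass_dim if use_filter else len(channels_filter_0)
if prior_type == 'icon':
    dim0 += smpl_dim

print(dim0)  # 263: the MLP sees 256 image channels + 7 geometry channels
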
lib/net/MLP.py ADDED
@@ -0,0 +1,72 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+ import torch
+ import torch.nn as nn
+ import pytorch_lightning as pl
+
+
+ class MLP(pl.LightningModule):
+     def __init__(self,
+                  filter_channels,
+                  name=None,
+                  res_layers=[],
+                  norm='group',
+                  last_op=None):
+
+         super(MLP, self).__init__()
+
+         self.filters = nn.ModuleList()
+         self.norms = nn.ModuleList()
+         self.res_layers = res_layers
+         self.norm = norm
+         self.last_op = last_op
+         self.name = name
+         self.activate = nn.LeakyReLU(inplace=True)
+
+         for l in range(0, len(filter_channels) - 1):
+             if l in self.res_layers:
+                 self.filters.append(
+                     nn.Conv1d(filter_channels[l] + filter_channels[0],
+                               filter_channels[l + 1], 1))
+             else:
+                 self.filters.append(
+                     nn.Conv1d(filter_channels[l], filter_channels[l + 1], 1))
+
+             if l != len(filter_channels) - 2:
+                 if norm == 'group':
+                     self.norms.append(nn.GroupNorm(32, filter_channels[l + 1]))
+                 elif norm == 'batch':
+                     self.norms.append(nn.BatchNorm1d(filter_channels[l + 1]))
+                 elif norm == 'instance':
+                     self.norms.append(nn.InstanceNorm1d(filter_channels[l + 1]))
+                 elif norm == 'weight':
+                     self.filters[l] = nn.utils.weight_norm(self.filters[l],
+                                                            name='weight')
+                     # print(self.filters[l].weight_g.size(),
+                     #       self.filters[l].weight_v.size())
+
+     def forward(self, feature):
+         '''
+         feature may include multiple view inputs
+         args:
+             feature: [B, C_in, N]
+         return:
+             [B, C_out, N] prediction
+         '''
+         y = feature
+         tmpy = feature
+
+         for i, f in enumerate(self.filters):
+
+             y = f(y if i not in self.res_layers else torch.cat([y, tmpy], 1))
+             if i != len(self.filters) - 1:
+                 if self.norm not in ['batch', 'group', 'instance']:
+                     y = self.activate(y)
+                 else:
+                     y = self.activate(self.norms[i](y))
+
+         if self.last_op is not None:
+             y = self.last_op(y)
+
+         return y
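
A usage sketch (not part of the commit; widths illustrative): the MLP is a per-point network built from 1x1 Conv1d layers, mapping [B, C_in, N] feature stacks to [B, C_out, N] predictions, here with a skip connection back to the input at layer 2:

import torch
from lib.net.MLP import MLP

mlp = MLP(filter_channels=[263, 512, 256, 128, 1],
          res_layers=[2],            # layer 2 re-concatenates the raw input
          norm='group',
          last_op=torch.nn.Sigmoid())

feat = torch.rand(2, 263, 8000)      # [B, C_in, N] per-point features
occ = mlp(feat)                      # [2, 1, 8000] occupancy in (0, 1)
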
lib/net/NormalNet.py ADDED
@@ -0,0 +1,122 @@
+
+ # -*- coding: utf-8 -*-
+
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+ # holder of all proprietary rights on this computer program.
+ # You can only use this computer program if you have closed
+ # a license agreement with MPG or you get the right to use the computer
+ # program from someone who is authorized to grant you that right.
+ # Any use of the computer program without a valid license is prohibited and
+ # liable to prosecution.
+ #
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+ # for Intelligent Systems. All rights reserved.
+ #
+ # Contact: [email protected]
+
+ from lib.net.FBNet import define_G
+ from lib.net.net_util import init_net, VGGLoss
+ from lib.net.HGFilters import *
+ from lib.net.BasePIFuNet import BasePIFuNet
+ import torch
+ import torch.nn as nn
+
+
+ class NormalNet(BasePIFuNet):
+     '''
+     NormalNet predicts front- and back-side normal maps from the input
+     image (and related) tensors, using two pix2pixHD-style "global"
+     generators (netF, netB) built via define_G.
+     During training, get_norm_error combines a SmoothL1 term with a VGG
+     perceptual term for each side.
+     '''
+
+     def __init__(self, cfg, error_term=nn.SmoothL1Loss()):
+
+         super(NormalNet, self).__init__(error_term=error_term)
+
+         self.l1_loss = nn.SmoothL1Loss()
+
+         self.opt = cfg.net
+
+         if self.training:
+             self.vgg_loss = [VGGLoss()]
+
+         self.in_nmlF = [
+             item[0] for item in self.opt.in_nml
+             if '_F' in item[0] or item[0] == 'image'
+         ]
+         self.in_nmlB = [
+             item[0] for item in self.opt.in_nml
+             if '_B' in item[0] or item[0] == 'image'
+         ]
+         self.in_nmlF_dim = sum([
+             item[1] for item in self.opt.in_nml
+             if '_F' in item[0] or item[0] == 'image'
+         ])
+         self.in_nmlB_dim = sum([
+             item[1] for item in self.opt.in_nml
+             if '_B' in item[0] or item[0] == 'image'
+         ])
+
+         self.netF = define_G(self.in_nmlF_dim, 3, 64, "global", 4, 9, 1, 3,
+                              "instance")
+         self.netB = define_G(self.in_nmlB_dim, 3, 64, "global", 4, 9, 1, 3,
+                              "instance")
+
+         init_net(self)
+
+     def forward(self, in_tensor):
+
+         inF_list = []
+         inB_list = []
+
+         for name in self.in_nmlF:
+             inF_list.append(in_tensor[name])
+         for name in self.in_nmlB:
+             inB_list.append(in_tensor[name])
+
+         nmlF = self.netF(torch.cat(inF_list, dim=1))
+         nmlB = self.netB(torch.cat(inB_list, dim=1))
+
+         # ||normal|| == 1 (keepdim so the [B, 1, H, W] norm broadcasts over C)
+         nmlF = nmlF / torch.norm(nmlF, dim=1, keepdim=True)
+         nmlB = nmlB / torch.norm(nmlB, dim=1, keepdim=True)
+
+         # output: float_arr [-1, 1] with [B, C, H, W]
+
+         mask = (in_tensor['image'].abs().sum(dim=1, keepdim=True) !=
+                 0.0).detach().float()
+
+         nmlF = nmlF * mask
+         nmlB = nmlB * mask
+
+         return nmlF, nmlB
+
+     def get_norm_error(self, prd_F, prd_B, tgt):
+         """Calculate the normal loss.
+
+         Args:
+             prd_F, prd_B (torch.tensor): [B, 3, 512, 512] predicted normals
+             tgt (dict): ground-truth normals under 'normal_F' / 'normal_B'
+         """
+
+         tgt_F, tgt_B = tgt['normal_F'], tgt['normal_B']
+
+         l1_F_loss = self.l1_loss(prd_F, tgt_F)
+         l1_B_loss = self.l1_loss(prd_B, tgt_B)
+
+         with torch.no_grad():
+             vgg_F_loss = self.vgg_loss[0](prd_F, tgt_F)
+             vgg_B_loss = self.vgg_loss[0](prd_B, tgt_B)
+
+         total_loss = [
+             5.0 * l1_F_loss + vgg_F_loss, 5.0 * l1_B_loss + vgg_B_loss
+         ]
+
+         return total_loss
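
A hypothetical forward pass (not part of the commit; the in_nml entries and tensor keys are assumptions about the project config, and VGGLoss will pull torchvision VGG weights on first construction):

import torch
from types import SimpleNamespace
from lib.net.NormalNet import NormalNet

# Assumed config: image feeds both branches, plus one SMPL-rendered
# normal map per side (hence 6 input channels for netF and netB each).
cfg = SimpleNamespace(net=SimpleNamespace(
    in_nml=[('image', 3), ('T_normal_F', 3), ('T_normal_B', 3)]))

net = NormalNet(cfg).eval()
in_tensor = {
    'image': torch.rand(1, 3, 512, 512) * 2 - 1,
    'T_normal_F': torch.rand(1, 3, 512, 512) * 2 - 1,
    'T_normal_B': torch.rand(1, 3, 512, 512) * 2 - 1,
}
with torch.no_grad():
    nmlF, nmlB = net(in_tensor)  # each [1, 3, 512, 512], unit-length, masked
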
lib/net/VE.py ADDED
@@ -0,0 +1,183 @@
+
+ # -*- coding: utf-8 -*-
+
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+ # holder of all proprietary rights on this computer program.
+ # You can only use this computer program if you have closed
+ # a license agreement with MPG or you get the right to use the computer
+ # program from someone who is authorized to grant you that right.
+ # Any use of the computer program without a valid license is prohibited and
+ # liable to prosecution.
+ #
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+ # for Intelligent Systems. All rights reserved.
+ #
+ # Contact: [email protected]
+
+
+ import torch.nn as nn
+ import pytorch_lightning as pl
+
+
+ class BaseNetwork(pl.LightningModule):
+     def __init__(self):
+         super(BaseNetwork, self).__init__()
+
+     def init_weights(self, init_type='xavier', gain=0.02):
+         '''
+         Initializes the network's weights.
+         init_type: normal | xavier | kaiming | orthogonal
+         https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
+         '''
+         def init_func(m):
+             classname = m.__class__.__name__
+             if hasattr(m, 'weight') and (classname.find('Conv') != -1
+                                          or classname.find('Linear') != -1):
+                 if init_type == 'normal':
+                     nn.init.normal_(m.weight.data, 0.0, gain)
+                 elif init_type == 'xavier':
+                     nn.init.xavier_normal_(m.weight.data, gain=gain)
+                 elif init_type == 'kaiming':
+                     nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+                 elif init_type == 'orthogonal':
+                     nn.init.orthogonal_(m.weight.data, gain=gain)
+
+                 if hasattr(m, 'bias') and m.bias is not None:
+                     nn.init.constant_(m.bias.data, 0.0)
+
+             elif classname.find('BatchNorm2d') != -1:
+                 nn.init.normal_(m.weight.data, 1.0, gain)
+                 nn.init.constant_(m.bias.data, 0.0)
+
+         self.apply(init_func)
+
+
+ class Residual3D(BaseNetwork):
+     def __init__(self, numIn, numOut):
+         super(Residual3D, self).__init__()
+         self.numIn = numIn
+         self.numOut = numOut
+         self.with_bias = True
+         # self.bn = nn.GroupNorm(4, self.numIn)
+         self.bn = nn.BatchNorm3d(self.numIn)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv1 = nn.Conv3d(self.numIn,
+                                self.numOut,
+                                bias=self.with_bias,
+                                kernel_size=3,
+                                stride=1,
+                                padding=2,
+                                dilation=2)
+         # self.bn1 = nn.GroupNorm(4, self.numOut)
+         self.bn1 = nn.BatchNorm3d(self.numOut)
+         self.conv2 = nn.Conv3d(self.numOut,
+                                self.numOut,
+                                bias=self.with_bias,
+                                kernel_size=3,
+                                stride=1,
+                                padding=1)
+         # self.bn2 = nn.GroupNorm(4, self.numOut)
+         self.bn2 = nn.BatchNorm3d(self.numOut)
+         self.conv3 = nn.Conv3d(self.numOut,
+                                self.numOut,
+                                bias=self.with_bias,
+                                kernel_size=3,
+                                stride=1,
+                                padding=1)
+
+         if self.numIn != self.numOut:
+             self.conv4 = nn.Conv3d(self.numIn,
+                                    self.numOut,
+                                    bias=self.with_bias,
+                                    kernel_size=1)
+         self.init_weights()
+
+     def forward(self, x):
+         residual = x
+         # out = self.bn(x)
+         # out = self.relu(out)
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+         out = self.conv2(out)
+         out = self.bn2(out)
+         # out = self.conv3(out)
+         # out = self.relu(out)
+
+         if self.numIn != self.numOut:
+             residual = self.conv4(x)
+
+         return out + residual
+
+
+ class VolumeEncoder(BaseNetwork):
+     """CycleGAN Encoder"""
+
+     def __init__(self, num_in=3, num_out=32, num_stacks=2):
+         super(VolumeEncoder, self).__init__()
+         self.num_in = num_in
+         self.num_out = num_out
+         self.num_inter = 8
+         self.num_stacks = num_stacks
+         self.with_bias = True
+
+         self.relu = nn.ReLU(inplace=True)
+         self.conv1 = nn.Conv3d(self.num_in,
+                                self.num_inter,
+                                bias=self.with_bias,
+                                kernel_size=5,
+                                stride=2,
+                                padding=4,
+                                dilation=2)
+         # self.bn1 = nn.GroupNorm(4, self.num_inter)
+         self.bn1 = nn.BatchNorm3d(self.num_inter)
+         self.conv2 = nn.Conv3d(self.num_inter,
+                                self.num_out,
+                                bias=self.with_bias,
+                                kernel_size=5,
+                                stride=2,
+                                padding=4,
+                                dilation=2)
+         # self.bn2 = nn.GroupNorm(4, self.num_out)
+         self.bn2 = nn.BatchNorm3d(self.num_out)
+
+         self.conv_out1 = nn.Conv3d(self.num_out,
+                                    self.num_out,
+                                    bias=self.with_bias,
+                                    kernel_size=3,
+                                    stride=1,
+                                    padding=1,
+                                    dilation=1)
+         self.conv_out2 = nn.Conv3d(self.num_out,
+                                    self.num_out,
+                                    bias=self.with_bias,
+                                    kernel_size=3,
+                                    stride=1,
+                                    padding=1,
+                                    dilation=1)
+
+         for idx in range(self.num_stacks):
+             self.add_module("res" + str(idx),
+                             Residual3D(self.num_out, self.num_out))
+
+         self.init_weights()
+
+     def forward(self, x, intermediate_output=True):
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out_lst = []
+         for idx in range(self.num_stacks):
+             out = self._modules["res" + str(idx)](out)
+             out_lst.append(out)
+
+         if intermediate_output:
+             return out_lst
+         else:
+             return [out_lst[-1]]
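
A shape sanity check (not part of the commit; sizes illustrative). The two stride-2, dilation-2 conv stages each halve the resolution, so a 128^3 volume becomes 32^3 feature volumes, one per residual stack:

import torch
from lib.net.VE import VolumeEncoder

ve = VolumeEncoder(num_in=3, num_out=32, num_stacks=2)
vol = torch.rand(1, 3, 128, 128, 128)      # e.g. a voxelized SMPL volume
feats = ve(vol, intermediate_output=True)
print(len(feats), feats[0].shape)          # 2 volumes, each [1, 32, 32, 32]
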
lib/net/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .BasePIFuNet import BasePIFuNet
+ from .HGPIFuNet import HGPIFuNet
+ from .NormalNet import NormalNet
+ from .VE import VolumeEncoder
lib/net/geometry.py ADDED
@@ -0,0 +1,82 @@
+
+ # -*- coding: utf-8 -*-
+
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+ # holder of all proprietary rights on this computer program.
+ # You can only use this computer program if you have closed
+ # a license agreement with MPG or you get the right to use the computer
+ # program from someone who is authorized to grant you that right.
+ # Any use of the computer program without a valid license is prohibited and
+ # liable to prosecution.
+ #
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+ # for Intelligent Systems. All rights reserved.
+ #
+ # Contact: [email protected]
+
+ import torch
+
+
+ def index(feat, uv):
+     '''
+     :param feat: [B, C, H, W] image features
+     :param uv: [B, 2, N] normalized uv coordinates in the image plane, range [-1, 1]
+     :return: [B, C, N] image features at the uv coordinates
+     '''
+     uv = uv.transpose(1, 2)  # [B, N, 2]
+
+     (B, N, _) = uv.shape
+     C = feat.shape[1]
+
+     if uv.shape[-1] == 3:
+         # uv = uv[:, :, [2, 1, 0]]
+         # uv = uv * torch.tensor([1.0, -1.0, 1.0]).type_as(uv)[None, None, ...]
+         uv = uv.unsqueeze(2).unsqueeze(3)  # [B, N, 1, 1, 3]
+     else:
+         uv = uv.unsqueeze(2)  # [B, N, 1, 2]
+
+     # NOTE: for newer PyTorch, training results seem to degrade due to an
+     # implementation difference in F.grid_sample;
+     # for old versions, simply remove the align_corners argument.
+     samples = torch.nn.functional.grid_sample(
+         feat, uv, align_corners=True)  # [B, C, N, 1]
+     return samples.view(B, C, N)  # [B, C, N]
+
+
+ def orthogonal(points, calibrations, transforms=None):
+     '''
+     Compute the orthogonal projection of 3D points onto the image plane with the given projection matrix.
+     :param points: [B, 3, N] tensor of 3D points
+     :param calibrations: [B, 3, 4] tensor of projection matrices
+     :param transforms: [B, 2, 3] tensor of image transform matrices
+     :return: xyz: [B, 3, N] tensor of xyz coordinates in the image plane
+     '''
+     rot = calibrations[:, :3, :3]
+     trans = calibrations[:, :3, 3:4]
+     pts = torch.baddbmm(trans, rot, points)  # [B, 3, N]
+     if transforms is not None:
+         scale = transforms[:2, :2]
+         shift = transforms[:2, 2:3]
+         pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :])
+     return pts
+
+
+ def perspective(points, calibrations, transforms=None):
+     '''
+     Compute the perspective projection of 3D points onto the image plane with the given projection matrix.
+     :param points: [B, 3, N] tensor of 3D points
+     :param calibrations: [B, 3, 4] tensor of projection matrices
+     :param transforms: [B, 2, 3] tensor of image transform matrices
+     :return: xyz: [B, 3, N] tensor of xy coordinates plus depth in the image plane
+     '''
+     rot = calibrations[:, :3, :3]
+     trans = calibrations[:, :3, 3:4]
+     homo = torch.baddbmm(trans, rot, points)  # [B, 3, N]
+     xy = homo[:, :2, :] / homo[:, 2:3, :]
+     if transforms is not None:
+         scale = transforms[:2, :2]
+         shift = transforms[:2, 2:3]
+         xy = torch.baddbmm(shift, scale, xy)
+
+     xyz = torch.cat([xy, homo[:, 2:3, :]], 1)
+     return xyz
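
A tiny end-to-end check of the two helpers above (not part of the commit; shapes illustrative): project points with an identity calibration, then sample image features at the projected xy locations:

import torch
from lib.net.geometry import index, orthogonal

B, N = 2, 100
points = torch.rand(B, 3, N) * 2 - 1      # world coords in [-1, 1]
calibs = torch.eye(3, 4).expand(B, 3, 4)  # [B, 3, 4] identity projection
feat = torch.rand(B, 16, 64, 64)          # [B, C, H, W] feature map

xyz = orthogonal(points, calibs)          # [B, 3, N] (== points here)
sampled = index(feat, xyz[:, :2, :])      # [B, 16, N]
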
lib/net/net_util.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ from torchvision import models
19
+ import torch
20
+ from torch.nn import init
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import functools
24
+ from torch.autograd import grad
25
+
26
+
27
+ def gradient(inputs, outputs):
28
+ d_points = torch.ones_like(outputs,
29
+ requires_grad=False,
30
+ device=outputs.device)
31
+ points_grad = grad(outputs=outputs,
32
+ inputs=inputs,
33
+ grad_outputs=d_points,
34
+ create_graph=True,
35
+ retain_graph=True,
36
+ only_inputs=True,
37
+ allow_unused=True)[0]
38
+ return points_grad
39
+
40
+
41
+ # def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
42
+ # "3x3 convolution with padding"
43
+ # return nn.Conv2d(in_planes, out_planes, kernel_size=3,
44
+ # stride=strd, padding=padding, bias=bias)
45
+
46
+
47
+ def conv3x3(in_planes,
48
+ out_planes,
49
+ kernel=3,
50
+ strd=1,
51
+ dilation=1,
52
+ padding=1,
53
+ bias=False):
54
+ "3x3 convolution with padding"
55
+ return nn.Conv2d(in_planes,
56
+ out_planes,
57
+ kernel_size=kernel,
58
+ dilation=dilation,
59
+ stride=strd,
60
+ padding=padding,
61
+ bias=bias)
62
+
63
+
64
+ def conv1x1(in_planes, out_planes, stride=1):
65
+ """1x1 convolution"""
66
+ return nn.Conv2d(in_planes,
67
+ out_planes,
68
+ kernel_size=1,
69
+ stride=stride,
70
+ bias=False)
71
+
72
+
73
+ def init_weights(net, init_type='normal', init_gain=0.02):
74
+ """Initialize network weights.
75
+
76
+ Parameters:
77
+ net (network) -- network to be initialized
78
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
79
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
80
+
81
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
82
+ work better for some applications. Feel free to try yourself.
83
+ """
84
+ def init_func(m): # define the initialization function
85
+ classname = m.__class__.__name__
86
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1
87
+ or classname.find('Linear') != -1):
88
+ if init_type == 'normal':
89
+ init.normal_(m.weight.data, 0.0, init_gain)
90
+ elif init_type == 'xavier':
91
+ init.xavier_normal_(m.weight.data, gain=init_gain)
92
+ elif init_type == 'kaiming':
93
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
94
+ elif init_type == 'orthogonal':
95
+ init.orthogonal_(m.weight.data, gain=init_gain)
96
+ else:
97
+ raise NotImplementedError(
98
+ 'initialization method [%s] is not implemented' %
99
+ init_type)
100
+ if hasattr(m, 'bias') and m.bias is not None:
101
+ init.constant_(m.bias.data, 0.0)
102
+ elif classname.find(
103
+ 'BatchNorm2d'
104
+ ) != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
105
+ init.normal_(m.weight.data, 1.0, init_gain)
106
+ init.constant_(m.bias.data, 0.0)
107
+
108
+ # print('initialize network with %s' % init_type)
109
+ net.apply(init_func) # apply the initialization function <init_func>
110
+
111
+
112
+ def init_net(net, init_type='xavier', init_gain=0.02, gpu_ids=[]):
+     """Initialize a network: 1. register CPU/GPU device (with multi-GPU
+     support); 2. initialize the network weights.
+     Parameters:
+         net (network)      -- the network to be initialized
+         init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+         init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
+         gpu_ids (int list) -- which GPUs the network runs on: e.g., [0, 1, 2]
+
+     Returns an initialized network.
+     """
+     if len(gpu_ids) > 0:
+         assert torch.cuda.is_available()
+         net = torch.nn.DataParallel(net)  # multi-GPU wrapper
+     init_weights(net, init_type, init_gain=init_gain)
+     return net
+
+
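+ # Usage sketch (hypothetical `MyNet`): wraps the model for multi-GPU use and
+ # applies `init_weights` in one call.
+ # >>> net = init_net(MyNet(), init_type='kaiming', gpu_ids=[0, 1])
+
+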
+ def imageSpaceRotation(xy, rot):
+     '''
+     args:
+         xy: (B, 2, N) input
+         rot: (B, 2) x,y axis rotation angles
+
+     The rotation center is always the image center (any other rotation center
+     can be represented by an additional z translation).
+     '''
+     disp = rot.unsqueeze(2).sin().expand_as(xy)
+     return (disp * xy).sum(dim=1)
+
+
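+ # Shape sketch (hypothetical tensors): per-point image coordinates plus
+ # per-batch angles yield one scalar displacement per point.
+ # >>> xy = torch.rand(4, 2, 100)
+ # >>> rot = torch.rand(4, 2)
+ # >>> imageSpaceRotation(xy, rot).shape   # torch.Size([4, 100])
+
+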
+ def cal_gradient_penalty(netD,
+                          real_data,
+                          fake_data,
+                          device,
+                          type='mixed',
+                          constant=1.0,
+                          lambda_gp=10.0):
+     """Calculate the gradient penalty loss, used in the WGAN-GP paper
+     https://arxiv.org/abs/1704.00028
+
+     Arguments:
+         netD (network)     -- discriminator network
+         real_data (tensor) -- real images
+         fake_data (tensor) -- generated images from the generator
+         device (str)       -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
+         type (str)         -- whether to mix real and fake data [real | fake | mixed]
+         constant (float)   -- the constant used in the formula (||gradient||_2 - constant)^2
+         lambda_gp (float)  -- weight for this loss
+
+     Returns the gradient penalty loss.
+     """
+     if lambda_gp > 0.0:
+         # either use real images, fake images, or a linear interpolation of the two
+         if type == 'real':
+             interpolatesv = real_data
+         elif type == 'fake':
+             interpolatesv = fake_data
+         elif type == 'mixed':
+             alpha = torch.rand(real_data.shape[0], 1)
+             alpha = alpha.expand(
+                 real_data.shape[0],
+                 real_data.nelement() //
+                 real_data.shape[0]).contiguous().view(*real_data.shape)
+             alpha = alpha.to(device)
+             interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
+         else:
+             raise NotImplementedError('{} not implemented'.format(type))
+         interpolatesv.requires_grad_(True)
+         disc_interpolates = netD(interpolatesv)
+         gradients = torch.autograd.grad(
+             outputs=disc_interpolates,
+             inputs=interpolatesv,
+             grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+             create_graph=True,
+             retain_graph=True,
+             only_inputs=True)
+         gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
+         gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) **
+                             2).mean() * lambda_gp  # eps keeps the norm finite at 0
+         return gradient_penalty, gradients
+     else:
+         return 0.0, None
+
+
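+ # Usage sketch inside a discriminator update (hypothetical `netD`, `real`,
+ # `fake` tensors of identical shape):
+ # >>> gp, _ = cal_gradient_penalty(netD, real, fake.detach(), real.device)
+ # >>> d_loss = d_loss_real + d_loss_fake + gp
+
+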
+ def get_norm_layer(norm_type='instance'):
+     """Return a normalization layer
+     Parameters:
+         norm_type (str) -- the name of the normalization layer: batch | instance | group | none
+     For BatchNorm, we use learnable affine parameters and track running
+     statistics (mean/stddev). For InstanceNorm, we use neither.
+     """
+     if norm_type == 'batch':
+         norm_layer = functools.partial(nn.BatchNorm2d,
+                                        affine=True,
+                                        track_running_stats=True)
+     elif norm_type == 'instance':
+         norm_layer = functools.partial(nn.InstanceNorm2d,
+                                        affine=False,
+                                        track_running_stats=False)
+     elif norm_type == 'group':
+         norm_layer = functools.partial(nn.GroupNorm, 32)
+     elif norm_type == 'none':
+         norm_layer = None
+     else:
+         raise NotImplementedError('normalization layer [%s] is not found' %
+                                   norm_type)
+     return norm_layer
+
+
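+ # Usage sketch: the return value is a constructor-like callable that takes
+ # the channel count, so layers can be built uniformly.
+ # >>> norm_layer = get_norm_layer('group')
+ # >>> norm = norm_layer(64)    # nn.GroupNorm(32, 64)
+
+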
+ class Flatten(nn.Module):
+     def forward(self, input):
+         return input.view(input.size(0), -1)
+
+
+ class ConvBlock(nn.Module):
+     def __init__(self, in_planes, out_planes, opt):
+         super(ConvBlock, self).__init__()
+         [k, s, d, p] = opt.conv3x3
+         # three branches with out_planes/2, /4, /4 channels; concatenated in
+         # forward(), they sum back to out_planes
+         self.conv1 = conv3x3(in_planes, int(out_planes / 2), k, s, d, p)
+         self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4), k, s,
+                              d, p)
+         self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4), k, s,
+                              d, p)
+
+         if opt.norm == 'batch':
+             self.bn1 = nn.BatchNorm2d(in_planes)
+             self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
+             self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
+             self.bn4 = nn.BatchNorm2d(in_planes)
+         elif opt.norm == 'group':
+             self.bn1 = nn.GroupNorm(32, in_planes)
+             self.bn2 = nn.GroupNorm(32, int(out_planes / 2))
+             self.bn3 = nn.GroupNorm(32, int(out_planes / 4))
+             self.bn4 = nn.GroupNorm(32, in_planes)
+         else:
+             # bn1..bn4 are required in forward(); fail early for other norms
+             raise NotImplementedError('norm [%s] is not supported in '
+                                       'ConvBlock' % opt.norm)
+
+         if in_planes != out_planes:
+             self.downsample = nn.Sequential(
+                 self.bn4,
+                 nn.ReLU(True),
+                 nn.Conv2d(in_planes,
+                           out_planes,
+                           kernel_size=1,
+                           stride=1,
+                           bias=False),
+             )
+         else:
+             self.downsample = None
+
+     def forward(self, x):
+         residual = x
+
+         out1 = self.bn1(x)
+         out1 = F.relu(out1, True)
+         out1 = self.conv1(out1)
+
+         out2 = self.bn2(out1)
+         out2 = F.relu(out2, True)
+         out2 = self.conv2(out2)
+
+         out3 = self.bn3(out2)
+         out3 = F.relu(out3, True)
+         out3 = self.conv3(out3)
+
+         out3 = torch.cat((out1, out2, out3), 1)
+
+         if self.downsample is not None:
+             residual = self.downsample(residual)
+
+         out3 += residual
+
+         return out3
+
+
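+ # Shape sketch (hypothetical `opt`): the concatenated branches line up with
+ # the 1x1-projected residual.
+ # >>> from types import SimpleNamespace
+ # >>> opt = SimpleNamespace(conv3x3=[3, 1, 1, 1], norm='group')
+ # >>> block = ConvBlock(64, 128, opt)
+ # >>> block(torch.rand(1, 64, 32, 32)).shape   # torch.Size([1, 128, 32, 32])
+
+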
+ class Vgg19(torch.nn.Module):
+     def __init__(self, requires_grad=False):
+         super(Vgg19, self).__init__()
+         vgg_pretrained_features = models.vgg19(pretrained=True).features
+         self.slice1 = torch.nn.Sequential()
+         self.slice2 = torch.nn.Sequential()
+         self.slice3 = torch.nn.Sequential()
+         self.slice4 = torch.nn.Sequential()
+         self.slice5 = torch.nn.Sequential()
+         # split VGG-19 into five slices, each ending after a ReLU stage
+         for x in range(2):
+             self.slice1.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(2, 7):
+             self.slice2.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(7, 12):
+             self.slice3.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(12, 21):
+             self.slice4.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(21, 30):
+             self.slice5.add_module(str(x), vgg_pretrained_features[x])
+         if not requires_grad:
+             for param in self.parameters():
+                 param.requires_grad = False
+
+     def forward(self, X):
+         h_relu1 = self.slice1(X)
+         h_relu2 = self.slice2(h_relu1)
+         h_relu3 = self.slice3(h_relu2)
+         h_relu4 = self.slice4(h_relu3)
+         h_relu5 = self.slice5(h_relu4)
+         out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
+         return out
+
+
+ class VGGLoss(nn.Module):
+     def __init__(self):
+         super(VGGLoss, self).__init__()
+         self.vgg = Vgg19()
+         self.criterion = nn.L1Loss()
+         # deeper feature maps get larger weights
+         self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
+
+     def forward(self, x, y):
+         x_vgg, y_vgg = self.vgg(x), self.vgg(y)
+         loss = 0
+         for i in range(len(x_vgg)):
+             loss += self.weights[i] * self.criterion(x_vgg[i],
+                                                      y_vgg[i].detach())
+         return loss
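+
+ # Usage sketch (hypothetical image batches, normalized the same way on both
+ # sides): `y` is treated as the target and is detached per scale.
+ # >>> criterion = VGGLoss()
+ # >>> loss = criterion(pred_rgb, gt_rgb)   # scalar perceptual loss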
lib/net/voxelize.py ADDED
@@ -0,0 +1,184 @@
+ from __future__ import division, print_function
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from torch.autograd import Function
+
+ import voxelize_cuda  # custom CUDA extension providing the voxelization kernel
+
+
+ class VoxelizationFunction(Function):
+     """
+     Definition of the differentiable voxelization function.
+     Currently implemented only for CUDA tensors.
+     """
+     @staticmethod
+     def forward(ctx, smpl_vertices, smpl_face_center, smpl_face_normal,
+                 smpl_vertex_code, smpl_face_code, smpl_tetrahedrons,
+                 volume_res, sigma, smooth_kernel_size):
+         """
+         Forward pass.
+         Output format: (batch_size, z_dims, y_dims, x_dims, channel_num)
+         """
+         assert (smpl_vertices.size()[1] == smpl_vertex_code.size()[1])
+         assert (smpl_face_center.size()[1] == smpl_face_normal.size()[1])
+         assert (smpl_face_center.size()[1] == smpl_face_code.size()[1])
+         ctx.batch_size = smpl_vertices.size()[0]
+         ctx.volume_res = volume_res
+         ctx.sigma = sigma
+         ctx.smooth_kernel_size = smooth_kernel_size
+         ctx.smpl_vertex_num = smpl_vertices.size()[1]
+         ctx.device = smpl_vertices.device
+
+         smpl_vertices = smpl_vertices.contiguous()
+         smpl_face_center = smpl_face_center.contiguous()
+         smpl_face_normal = smpl_face_normal.contiguous()
+         smpl_vertex_code = smpl_vertex_code.contiguous()
+         smpl_face_code = smpl_face_code.contiguous()
+         smpl_tetrahedrons = smpl_tetrahedrons.contiguous()
+
+         occ_volume = torch.cuda.FloatTensor(ctx.batch_size, ctx.volume_res,
+                                             ctx.volume_res,
+                                             ctx.volume_res).fill_(0.0)
+         semantic_volume = torch.cuda.FloatTensor(ctx.batch_size,
+                                                  ctx.volume_res,
+                                                  ctx.volume_res,
+                                                  ctx.volume_res, 3).fill_(0.0)
+         weight_sum_volume = torch.cuda.FloatTensor(ctx.batch_size,
+                                                    ctx.volume_res,
+                                                    ctx.volume_res,
+                                                    ctx.volume_res).fill_(1e-3)
+
+         # occ_volume        [B, volume_res, volume_res, volume_res]
+         # semantic_volume   [B, volume_res, volume_res, volume_res, 3]
+         # weight_sum_volume [B, volume_res, volume_res, volume_res]
+         occ_volume, semantic_volume, weight_sum_volume = voxelize_cuda.forward_semantic_voxelization(
+             smpl_vertices, smpl_vertex_code, smpl_tetrahedrons, occ_volume,
+             semantic_volume, weight_sum_volume, sigma)
+
+         return semantic_volume
+
+
+ class Voxelization(nn.Module):
+     """
+     Wrapper around the autograd function VoxelizationFunction
+     """
+     def __init__(self, smpl_vertex_code, smpl_face_code, smpl_face_indices,
+                  smpl_tetraderon_indices, volume_res, sigma,
+                  smooth_kernel_size, batch_size, device):
+         super(Voxelization, self).__init__()
+         assert (len(smpl_face_indices.shape) == 2)
+         assert (len(smpl_tetraderon_indices.shape) == 2)
+         assert (smpl_face_indices.shape[1] == 3)
+         assert (smpl_tetraderon_indices.shape[1] == 4)
+
+         self.volume_res = volume_res
+         self.sigma = sigma
+         self.smooth_kernel_size = smooth_kernel_size
+         self.batch_size = batch_size
+         self.device = device
+
+         self.smpl_vertex_code = smpl_vertex_code
+         self.smpl_face_code = smpl_face_code
+         self.smpl_face_indices = smpl_face_indices
+         self.smpl_tetraderon_indices = smpl_tetraderon_indices
+
+     def update_param(self, batch_size, smpl_tetra):
+         self.batch_size = batch_size
+         self.smpl_tetraderon_indices = smpl_tetra
+
+         # tile the per-template arrays across the batch dimension
+         smpl_vertex_code_batch = np.tile(self.smpl_vertex_code,
+                                          (self.batch_size, 1, 1))
+         smpl_face_code_batch = np.tile(self.smpl_face_code,
+                                        (self.batch_size, 1, 1))
+         smpl_face_indices_batch = np.tile(self.smpl_face_indices,
+                                           (self.batch_size, 1, 1))
+         smpl_tetraderon_indices_batch = np.tile(self.smpl_tetraderon_indices,
+                                                 (self.batch_size, 1, 1))
+
+         smpl_vertex_code_batch = torch.from_numpy(
+             smpl_vertex_code_batch).contiguous().to(self.device)
+         smpl_face_code_batch = torch.from_numpy(
+             smpl_face_code_batch).contiguous().to(self.device)
+         smpl_face_indices_batch = torch.from_numpy(
+             smpl_face_indices_batch).contiguous().to(self.device)
+         smpl_tetraderon_indices_batch = torch.from_numpy(
+             smpl_tetraderon_indices_batch).contiguous().to(self.device)
+
+         self.register_buffer('smpl_vertex_code_batch', smpl_vertex_code_batch)
+         self.register_buffer('smpl_face_code_batch', smpl_face_code_batch)
+         self.register_buffer('smpl_face_indices_batch',
+                              smpl_face_indices_batch)
+         self.register_buffer('smpl_tetraderon_indices_batch',
+                              smpl_tetraderon_indices_batch)
+
+     def forward(self, smpl_vertices):
+         """
+         Generate semantic volumes from SMPL vertices
+         """
+         assert (smpl_vertices.size()[0] == self.batch_size)
+         self.check_input(smpl_vertices)
+         smpl_faces = self.vertices_to_faces(smpl_vertices)
+         smpl_tetrahedrons = self.vertices_to_tetrahedrons(smpl_vertices)
+         smpl_face_center = self.calc_face_centers(smpl_faces)
+         smpl_face_normal = self.calc_face_normals(smpl_faces)
+         smpl_surface_vertex_num = self.smpl_vertex_code_batch.size()[1]
+         smpl_vertices_surface = smpl_vertices[:, :smpl_surface_vertex_num, :]
+         vol = VoxelizationFunction.apply(smpl_vertices_surface,
+                                          smpl_face_center, smpl_face_normal,
+                                          self.smpl_vertex_code_batch,
+                                          self.smpl_face_code_batch,
+                                          smpl_tetrahedrons, self.volume_res,
+                                          self.sigma, self.smooth_kernel_size)
+         return vol.permute((0, 4, 1, 2, 3))  # (bzyxc --> bcdhw)
+
+     def vertices_to_faces(self, vertices):
+         assert (vertices.ndimension() == 3)
+         bs, nv = vertices.shape[:2]
+         device = vertices.device
+         # offset the per-batch face indices into the flattened vertex array
+         face = self.smpl_face_indices_batch + (
+             torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
+         vertices_ = vertices.reshape((bs * nv, 3))
+         return vertices_[face.long()]
+
+     def vertices_to_tetrahedrons(self, vertices):
+         assert (vertices.ndimension() == 3)
+         bs, nv = vertices.shape[:2]
+         device = vertices.device
+         tets = self.smpl_tetraderon_indices_batch + (
+             torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
+         vertices_ = vertices.reshape((bs * nv, 3))
+         return vertices_[tets.long()]
+
+     def calc_face_centers(self, face_verts):
+         assert len(face_verts.shape) == 4
+         assert face_verts.shape[2] == 3
+         assert face_verts.shape[3] == 3
+         bs, nf = face_verts.shape[:2]
+         face_centers = (face_verts[:, :, 0, :] + face_verts[:, :, 1, :] +
+                         face_verts[:, :, 2, :]) / 3.0
+         face_centers = face_centers.reshape((bs, nf, 3))
+         return face_centers
+
+     def calc_face_normals(self, face_verts):
+         assert len(face_verts.shape) == 4
+         assert face_verts.shape[2] == 3
+         assert face_verts.shape[3] == 3
+         bs, nf = face_verts.shape[:2]
+         face_verts = face_verts.reshape((bs * nf, 3, 3))
+         v10 = face_verts[:, 0] - face_verts[:, 1]
+         v12 = face_verts[:, 2] - face_verts[:, 1]
+         # dim=1 makes the cross product explicit over the coordinate axis,
+         # avoiding torch.cross's ambiguous default when bs * nf equals 3
+         normals = F.normalize(torch.cross(v10, v12, dim=1), eps=1e-5)
+         normals = normals.reshape((bs, nf, 3))
+         return normals
+
+     def check_input(self, x):
+         # compare the device type, not the device object, to a string
+         if x.device.type == 'cpu':
+             raise TypeError('Voxelization module supports only cuda tensors')
+         if x.type() != 'torch.cuda.FloatTensor':
+             raise TypeError(
+                 'Voxelization module supports only float32 tensors')
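+
+ # Usage sketch (hypothetical SMPL assets on GPU): the codes are float32
+ # arrays of per-vertex/per-face colors, the indices are int32 arrays; the
+ # output is a semantic volume of shape (B, 3, D, H, W).
+ # >>> voxelizer = Voxelization(vertex_code, face_code, face_idx, tetra_idx,
+ # ...                          volume_res=128, sigma=0.05,
+ # ...                          smooth_kernel_size=7, batch_size=1,
+ # ...                          device=torch.device('cuda:0'))
+ # >>> voxelizer.update_param(1, tetra_idx)
+ # >>> vol = voxelizer(smpl_verts)    # (1, 3, 128, 128, 128)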