Upload 11 files
- lib/net/BasePIFuNet.py +84 -0
- lib/net/FBNet.py +388 -0
- lib/net/HGFilters.py +197 -0
- lib/net/HGPIFuNet.py +403 -0
- lib/net/MLP.py +72 -0
- lib/net/NormalNet.py +122 -0
- lib/net/VE.py +183 -0
- lib/net/__init__.py +4 -0
- lib/net/geometry.py +82 -0
- lib/net/net_util.py +329 -0
- lib/net/voxelize.py +184 -0
lib/net/BasePIFuNet.py
ADDED
@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

import torch.nn as nn
import pytorch_lightning as pl

from .geometry import index, orthogonal, perspective


class BasePIFuNet(pl.LightningModule):
    def __init__(self, projection_mode='orthogonal', error_term=nn.MSELoss()):
        """
        :param projection_mode:
            Either orthogonal or perspective.
            It will call the corresponding function for projection.
        :param error_term:
            nn Loss between the predicted [B, Res, N] and the label [B, Res, N]
        """
        super(BasePIFuNet, self).__init__()
        self.name = 'base'

        self.error_term = error_term

        self.index = index
        self.projection = orthogonal if projection_mode == 'orthogonal' else perspective

    def forward(self, points, images, calibs, transforms=None):
        '''
        :param points: [B, 3, N] world space coordinates of points
        :param images: [B, C, H, W] input images
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :return: [B, Res, N] predictions for each point
        '''
        features = self.filter(images)
        preds = self.query(features, points, calibs, transforms)
        return preds

    def filter(self, images):
        '''
        Filter the input images
        store all intermediate features.
        :param images: [B, C, H, W] input images
        '''
        return None

    def query(self, features, points, calibs, transforms=None):
        '''
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed before this call.
        store all intermediate features.
        query() function may behave differently during training/testing.
        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :param labels: Optional [B, Res, N] gt labeling
        :return: [B, Res, N] predictions for each point
        '''
        return None

    def get_error(self, preds, labels):
        '''
        Get the network loss from the last query
        :return: loss term
        '''
        return self.error_term(preds, labels)
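As a rough usage sketch (not part of this upload): a subclass is expected to override filter() and query(), while the base class supplies the projection function and the grid-sampling helper. The ToyPIFuNet class and all tensor sizes below are illustrative only.

    import torch
    import torch.nn as nn
    from lib.net.BasePIFuNet import BasePIFuNet  # assumes this repo layout


    class ToyPIFuNet(BasePIFuNet):
        def __init__(self):
            super(ToyPIFuNet, self).__init__(projection_mode='orthogonal')
            self.backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.head = nn.Conv1d(8 + 1, 1, kernel_size=1)

        def filter(self, images):
            return self.backbone(images)                          # [B, 8, H, W]

        def query(self, features, points, calibs, transforms=None):
            xyz = self.projection(points, calibs, transforms)     # [B, 3, N]
            xy, z = xyz[:, :2, :], xyz[:, 2:3, :]
            pixel_feat = self.index(features, xy)                 # [B, 8, N]
            return self.head(torch.cat([pixel_feat, z], dim=1))   # [B, 1, N]


    # dummy run: one image, 100 query points, identity-like calibration
    net = ToyPIFuNet()
    images = torch.rand(1, 3, 64, 64)
    points = torch.rand(1, 3, 100) * 2 - 1
    calibs = torch.eye(4)[:3].unsqueeze(0)                        # [1, 3, 4]
    preds = net(points, images, calibs)                           # [1, 1, 100]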
lib/net/FBNet.py
ADDED
@@ -0,0 +1,388 @@
'''
Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu.
BSD License. All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
'''
import torch
import torch.nn as nn
import functools
import numpy as np
import pytorch_lightning as pl


###############################################################################
# Functions
###############################################################################
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' %
                                  norm_type)
    return norm_layer


def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3,
             n_blocks_global=9, n_local_enhancers=1, n_blocks_local=3,
             norm='instance', gpu_ids=[], last_op=nn.Tanh()):
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'global':
        netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global,
                               n_blocks_global, norm_layer, last_op=last_op)
    elif netG == 'local':
        netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global,
                             n_blocks_global, n_local_enhancers,
                             n_blocks_local, norm_layer)
    elif netG == 'encoder':
        netG = Encoder(input_nc, output_nc, ngf, n_downsample_global,
                       norm_layer)
    else:
        raise NotImplementedError('generator not implemented!')
    # print(netG)
    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        device = torch.device(f"cuda:{gpu_ids[0]}")
        netG = netG.to(device)
    netG.apply(weights_init)
    return netG


def print_network(net):
    if isinstance(net, list):
        net = net[0]
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)


##############################################################################
# Generator
##############################################################################
class LocalEnhancer(pl.LightningModule):
    def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3,
                 n_blocks_global=9, n_local_enhancers=1, n_blocks_local=3,
                 norm_layer=nn.BatchNorm2d, padding_type='reflect'):
        super(LocalEnhancer, self).__init__()
        self.n_local_enhancers = n_local_enhancers

        ###### global generator model #####
        ngf_global = ngf * (2**n_local_enhancers)
        model_global = GlobalGenerator(input_nc, output_nc, ngf_global,
                                       n_downsample_global, n_blocks_global,
                                       norm_layer).model
        model_global = [model_global[i] for i in range(len(model_global) - 3)
                        ]  # get rid of final convolution layers
        self.model = nn.Sequential(*model_global)

        ###### local enhancer layers #####
        for n in range(1, n_local_enhancers + 1):
            # downsample
            ngf_global = ngf * (2**(n_local_enhancers - n))
            model_downsample = [
                nn.ReflectionPad2d(3),
                nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
                norm_layer(ngf_global),
                nn.ReLU(True),
                nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3,
                          stride=2, padding=1),
                norm_layer(ngf_global * 2),
                nn.ReLU(True)
            ]
            # residual blocks
            model_upsample = []
            for i in range(n_blocks_local):
                model_upsample += [
                    ResnetBlock(ngf_global * 2,
                                padding_type=padding_type,
                                norm_layer=norm_layer)
                ]

            # upsample
            model_upsample += [
                nn.ConvTranspose2d(ngf_global * 2, ngf_global, kernel_size=3,
                                   stride=2, padding=1, output_padding=1),
                norm_layer(ngf_global),
                nn.ReLU(True)
            ]

            # final convolution
            if n == n_local_enhancers:
                model_upsample += [
                    nn.ReflectionPad2d(3),
                    nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                    nn.Tanh()
                ]

            setattr(self, 'model' + str(n) + '_1',
                    nn.Sequential(*model_downsample))
            setattr(self, 'model' + str(n) + '_2',
                    nn.Sequential(*model_upsample))

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1],
                                       count_include_pad=False)

    def forward(self, input):
        # create input pyramid
        input_downsampled = [input]
        for i in range(self.n_local_enhancers):
            input_downsampled.append(self.downsample(input_downsampled[-1]))

        # output at coarsest level
        output_prev = self.model(input_downsampled[-1])
        # build up one layer at a time
        for n_local_enhancers in range(1, self.n_local_enhancers + 1):
            model_downsample = getattr(self,
                                       'model' + str(n_local_enhancers) + '_1')
            model_upsample = getattr(self,
                                     'model' + str(n_local_enhancers) + '_2')
            input_i = input_downsampled[self.n_local_enhancers -
                                        n_local_enhancers]
            output_prev = model_upsample(
                model_downsample(input_i) + output_prev)
        return output_prev


class GlobalGenerator(pl.LightningModule):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
                 n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', last_op=nn.Tanh()):
        assert (n_blocks >= 0)
        super(GlobalGenerator, self).__init__()
        activation = nn.ReLU(True)

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf), activation
        ]
        # downsample
        for i in range(n_downsampling):
            mult = 2**i
            model += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                          stride=2, padding=1),
                norm_layer(ngf * mult * 2), activation
            ]

        # resnet blocks
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [
                ResnetBlock(ngf * mult,
                            padding_type=padding_type,
                            activation=activation,
                            norm_layer=norm_layer)
            ]

        # upsample
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [
                nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                   kernel_size=3, stride=2, padding=1,
                                   output_padding=1),
                norm_layer(int(ngf * mult / 2)), activation
            ]
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)
        ]
        if last_op is not None:
            model += [last_op]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)


# Define a resnet block
class ResnetBlock(pl.LightningModule):
    def __init__(self, dim, padding_type, norm_layer,
                 activation=nn.ReLU(True), use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer,
                                                activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation,
                         use_dropout):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' %
                                      padding_type)

        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p),
            norm_layer(dim), activation
        ]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' %
                                      padding_type)
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p),
            norm_layer(dim)
        ]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out


class Encoder(pl.LightningModule):
    def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4,
                 norm_layer=nn.BatchNorm2d):
        super(Encoder, self).__init__()
        self.output_nc = output_nc

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            nn.ReLU(True)
        ]
        # downsample
        for i in range(n_downsampling):
            mult = 2**i
            model += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                          stride=2, padding=1),
                norm_layer(ngf * mult * 2),
                nn.ReLU(True)
            ]

        # upsample
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [
                nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                   kernel_size=3, stride=2, padding=1,
                                   output_padding=1),
                norm_layer(int(ngf * mult / 2)),
                nn.ReLU(True)
            ]

        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh()
        ]
        self.model = nn.Sequential(*model)

    def forward(self, input, inst):
        outputs = self.model(input)

        # instance-wise average pooling
        outputs_mean = outputs.clone()
        inst_list = np.unique(inst.cpu().numpy().astype(int))
        for i in inst_list:
            for b in range(input.size()[0]):
                indices = (inst[b:b + 1] == int(i)).nonzero()  # n x 4
                for j in range(self.output_nc):
                    output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j,
                                         indices[:, 2], indices[:, 3]]
                    mean_feat = torch.mean(output_ins).expand_as(output_ins)
                    outputs_mean[indices[:, 0] + b, indices[:, 1] + j,
                                 indices[:, 2], indices[:, 3]] = mean_feat
        return outputs_mean
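A hedged usage sketch (not part of this upload) of define_G for the 'global' generator, which is the entry point NormalNet.py below relies on; the channel counts and image size are illustrative only.

    import torch
    from lib.net.FBNet import define_G

    # pix2pixHD-style global generator: 4 downsamplings, 9 resnet blocks
    netG = define_G(input_nc=6, output_nc=3, ngf=64, netG="global",
                    n_downsample_global=4, n_blocks_global=9, norm="instance")
    x = torch.rand(1, 6, 512, 512)   # e.g. image + one normal map, channel-stacked
    y = netG(x)                      # [1, 3, 512, 512], in [-1, 1] because of the Tanh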
lib/net/HGFilters.py
ADDED
@@ -0,0 +1,197 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

from lib.net.net_util import *
import torch.nn as nn
import torch.nn.functional as F


class HourGlass(nn.Module):
    def __init__(self, num_modules, depth, num_features, opt):
        super(HourGlass, self).__init__()
        self.num_modules = num_modules
        self.depth = depth
        self.features = num_features
        self.opt = opt

        self._generate_network(self.depth)

    def _generate_network(self, level):
        self.add_module('b1_' + str(level),
                        ConvBlock(self.features, self.features, self.opt))

        self.add_module('b2_' + str(level),
                        ConvBlock(self.features, self.features, self.opt))

        if level > 1:
            self._generate_network(level - 1)
        else:
            self.add_module('b2_plus_' + str(level),
                            ConvBlock(self.features, self.features, self.opt))

        self.add_module('b3_' + str(level),
                        ConvBlock(self.features, self.features, self.opt))

    def _forward(self, level, inp):
        # Upper branch
        up1 = inp
        up1 = self._modules['b1_' + str(level)](up1)

        # Lower branch
        low1 = F.avg_pool2d(inp, 2, stride=2)
        low1 = self._modules['b2_' + str(level)](low1)

        if level > 1:
            low2 = self._forward(level - 1, low1)
        else:
            low2 = low1
            low2 = self._modules['b2_plus_' + str(level)](low2)

        low3 = low2
        low3 = self._modules['b3_' + str(level)](low3)

        # NOTE: for newer PyTorch (1.3~), it seems that training results are degraded due to implementation diff in F.grid_sample
        # if the pretrained model behaves weirdly, switch with the commented line.
        # NOTE: I also found that "bicubic" works better.
        up2 = F.interpolate(low3,
                            scale_factor=2,
                            mode='bicubic',
                            align_corners=True)
        # up2 = F.interpolate(low3, scale_factor=2, mode='nearest')

        return up1 + up2

    def forward(self, x):
        return self._forward(self.depth, x)


class HGFilter(nn.Module):
    def __init__(self, opt, num_modules, in_dim):
        super(HGFilter, self).__init__()
        self.num_modules = num_modules

        self.opt = opt
        [k, s, d, p] = self.opt.conv1

        # self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3)
        self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=k, stride=s,
                               dilation=d, padding=p)

        if self.opt.norm == 'batch':
            self.bn1 = nn.BatchNorm2d(64)
        elif self.opt.norm == 'group':
            self.bn1 = nn.GroupNorm(32, 64)

        if self.opt.hg_down == 'conv64':
            self.conv2 = ConvBlock(64, 64, self.opt)
            self.down_conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2,
                                        padding=1)
        elif self.opt.hg_down == 'conv128':
            self.conv2 = ConvBlock(64, 128, self.opt)
            self.down_conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=2,
                                        padding=1)
        elif self.opt.hg_down == 'ave_pool':
            self.conv2 = ConvBlock(64, 128, self.opt)
        else:
            raise NameError('Unknown Fan Filter setting!')

        self.conv3 = ConvBlock(128, 128, self.opt)
        self.conv4 = ConvBlock(128, 256, self.opt)

        # Stacking part
        for hg_module in range(self.num_modules):
            self.add_module('m' + str(hg_module),
                            HourGlass(1, opt.num_hourglass, 256, self.opt))

            self.add_module('top_m_' + str(hg_module),
                            ConvBlock(256, 256, self.opt))
            self.add_module(
                'conv_last' + str(hg_module),
                nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
            if self.opt.norm == 'batch':
                self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
            elif self.opt.norm == 'group':
                self.add_module('bn_end' + str(hg_module),
                                nn.GroupNorm(32, 256))

            self.add_module(
                'l' + str(hg_module),
                nn.Conv2d(256, opt.hourglass_dim, kernel_size=1, stride=1,
                          padding=0))

            if hg_module < self.num_modules - 1:
                self.add_module(
                    'bl' + str(hg_module),
                    nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
                self.add_module(
                    'al' + str(hg_module),
                    nn.Conv2d(opt.hourglass_dim, 256, kernel_size=1, stride=1,
                              padding=0))

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)), True)
        tmpx = x
        if self.opt.hg_down == 'ave_pool':
            x = F.avg_pool2d(self.conv2(x), 2, stride=2)
        elif self.opt.hg_down in ['conv64', 'conv128']:
            x = self.conv2(x)
            x = self.down_conv2(x)
        else:
            raise NameError('Unknown Fan Filter setting!')

        x = self.conv3(x)
        x = self.conv4(x)

        previous = x

        outputs = []
        for i in range(self.num_modules):
            hg = self._modules['m' + str(i)](previous)

            ll = hg
            ll = self._modules['top_m_' + str(i)](ll)

            ll = F.relu(
                self._modules['bn_end' + str(i)](
                    self._modules['conv_last' + str(i)](ll)), True)

            # Predict heatmaps
            tmp_out = self._modules['l' + str(i)](ll)
            outputs.append(tmp_out)

            if i < self.num_modules - 1:
                ll = self._modules['bl' + str(i)](ll)
                tmp_out_ = self._modules['al' + str(i)](tmp_out)
                previous = previous + ll + tmp_out_

        return outputs
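A hedged sketch (not part of this upload) of the fields HGFilter reads from its opt object. The values are illustrative, the stem-conv settings follow the commented line in __init__, and it assumes ConvBlock in lib/net/net_util.py only needs opt.norm.

    from types import SimpleNamespace
    import torch
    from lib.net.HGFilters import HGFilter

    opt = SimpleNamespace(
        conv1=[7, 2, 1, 3],     # kernel, stride, dilation, padding of the stem conv
        norm='group',           # 'batch' or 'group'
        hg_down='ave_pool',     # 'conv64' | 'conv128' | 'ave_pool'
        num_hourglass=2,        # recursion depth of each HourGlass
        hourglass_dim=256)      # channels of every output stack

    hg = HGFilter(opt, num_modules=4, in_dim=6)
    feats = hg(torch.rand(1, 6, 512, 512))  # list of 4 tensors, each [1, 256, 128, 128]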
lib/net/HGPIFuNet.py
ADDED
@@ -0,0 +1,403 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

from lib.net.voxelize import Voxelization
from lib.dataset.mesh_util import cal_sdf_batch, feat_select, read_smpl_constants
from lib.net.NormalNet import NormalNet
from lib.net.MLP import MLP
from lib.dataset.mesh_util import SMPLX
from lib.net.VE import VolumeEncoder
from lib.net.HGFilters import *
from termcolor import colored
from lib.net.BasePIFuNet import BasePIFuNet
import torch.nn as nn
import torch


maskout = False


class HGPIFuNet(BasePIFuNet):
    '''
    HG PIFu network uses Hourglass stacks as the image filter.
    It does the following:
        1. Compute image feature stacks and store it in self.im_feat_list
            self.im_feat_list[-1] is the last stack (output stack)
        2. Calculate calibration
        3. If training, it index on every intermediate stacks,
            If testing, it index on the last stack.
        4. Classification.
        5. During training, error is calculated on all stacks.
    '''

    def __init__(self,
                 cfg,
                 projection_mode='orthogonal',
                 error_term=nn.MSELoss()):

        super(HGPIFuNet, self).__init__(projection_mode=projection_mode,
                                        error_term=error_term)

        self.l1_loss = nn.SmoothL1Loss()
        self.opt = cfg.net
        self.root = cfg.root
        self.overfit = cfg.overfit

        channels_IF = self.opt.mlp_dim

        self.use_filter = self.opt.use_filter
        self.prior_type = self.opt.prior_type
        self.smpl_feats = self.opt.smpl_feats

        self.smpl_dim = self.opt.smpl_dim
        self.voxel_dim = self.opt.voxel_dim
        self.hourglass_dim = self.opt.hourglass_dim
        self.sdf_clip = cfg.sdf_clip / 100.0

        self.in_geo = [item[0] for item in self.opt.in_geo]
        self.in_nml = [item[0] for item in self.opt.in_nml]

        self.in_geo_dim = sum([item[1] for item in self.opt.in_geo])
        self.in_nml_dim = sum([item[1] for item in self.opt.in_nml])

        self.in_total = self.in_geo + self.in_nml
        self.smpl_feat_dict = None
        self.smplx_data = SMPLX()

        if self.prior_type == 'icon':
            if 'image' in self.in_geo:
                self.channels_filter = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 6, 7, 8]]
            else:
                self.channels_filter = [[0, 1, 2], [3, 4, 5]]

        else:
            if 'image' in self.in_geo:
                self.channels_filter = [[0, 1, 2, 3, 4, 5, 6, 7, 8]]
            else:
                self.channels_filter = [[0, 1, 2, 3, 4, 5]]

        channels_IF[0] = self.hourglass_dim if self.use_filter else len(
            self.channels_filter[0])

        if self.prior_type == 'icon' and 'vis' not in self.smpl_feats:
            if self.use_filter:
                channels_IF[0] += self.hourglass_dim
            else:
                channels_IF[0] += len(self.channels_filter[0])

        if self.prior_type == 'icon':
            channels_IF[0] += self.smpl_dim
        elif self.prior_type == 'pamir':
            channels_IF[0] += self.voxel_dim
            smpl_vertex_code, smpl_face_code, smpl_faces, smpl_tetras = read_smpl_constants(
                self.smplx_data.tedra_dir)
            self.voxelization = Voxelization(
                smpl_vertex_code,
                smpl_face_code,
                smpl_faces,
                smpl_tetras,
                volume_res=128,
                sigma=0.05,
                smooth_kernel_size=7,
                batch_size=cfg.batch_size,
                device=torch.device(f"cuda:{cfg.gpus[0]}"))
            self.ve = VolumeEncoder(3, self.voxel_dim, self.opt.num_stack)
        else:
            channels_IF[0] += 1

        self.icon_keys = ["smpl_verts", "smpl_faces", "smpl_vis", "smpl_cmap"]
        self.pamir_keys = [
            "voxel_verts", "voxel_faces", "pad_v_num", "pad_f_num"
        ]

        self.if_regressor = MLP(
            filter_channels=channels_IF,
            name='if',
            res_layers=self.opt.res_layers,
            norm=self.opt.norm_mlp,
            last_op=nn.Sigmoid() if not cfg.test_mode else None)

        # network
        if self.use_filter:
            if self.opt.gtype == "HGPIFuNet":
                self.F_filter = HGFilter(self.opt, self.opt.num_stack,
                                         len(self.channels_filter[0]))
            else:
                print(
                    colored(f"Backbone {self.opt.gtype} is unimplemented",
                            'green'))

        summary_log = f"{self.prior_type.upper()}:\n" + \
            f"w/ Global Image Encoder: {self.use_filter}\n" + \
            f"Image Features used by MLP: {self.in_geo}\n"

        if self.prior_type == "icon":
            summary_log += f"Geometry Features used by MLP: {self.smpl_feats}\n"
            summary_log += f"Dim of Image Features (local): 6\n"
            summary_log += f"Dim of Geometry Features (ICON): {self.smpl_dim}\n"
        elif self.prior_type == "pamir":
            summary_log += f"Dim of Image Features (global): {self.hourglass_dim}\n"
            summary_log += f"Dim of Geometry Features (PaMIR): {self.voxel_dim}\n"
        else:
            summary_log += f"Dim of Image Features (global): {self.hourglass_dim}\n"
            summary_log += f"Dim of Geometry Features (PIFu): 1 (z-value)\n"

        summary_log += f"Dim of MLP's first layer: {channels_IF[0]}\n"

        print(colored(summary_log, "yellow"))

        self.normal_filter = NormalNet(cfg)
        init_net(self)

    def get_normal(self, in_tensor_dict):

        # insert normal features
        if (not self.training) and (not self.overfit):
            # print(colored("infer normal","blue"))
            with torch.no_grad():
                feat_lst = []
                if "image" in self.in_geo:
                    feat_lst.append(
                        in_tensor_dict['image'])  # [1, 3, 512, 512]
                if 'normal_F' in self.in_geo and 'normal_B' in self.in_geo:
                    if 'normal_F' not in in_tensor_dict.keys(
                    ) or 'normal_B' not in in_tensor_dict.keys():
                        (nmlF, nmlB) = self.normal_filter(in_tensor_dict)
                    else:
                        nmlF = in_tensor_dict['normal_F']
                        nmlB = in_tensor_dict['normal_B']
                    feat_lst.append(nmlF)  # [1, 3, 512, 512]
                    feat_lst.append(nmlB)  # [1, 3, 512, 512]
            in_filter = torch.cat(feat_lst, dim=1)

        else:
            in_filter = torch.cat([in_tensor_dict[key] for key in self.in_geo],
                                  dim=1)

        return in_filter

    def get_mask(self, in_filter, size=128):

        mask = F.interpolate(in_filter[:, self.channels_filter[0]],
                             size=(size, size),
                             mode="bilinear",
                             align_corners=True).abs().sum(dim=1,
                                                           keepdim=True) != 0.0

        return mask

    def filter(self, in_tensor_dict, return_inter=False):
        '''
        Filter the input images
        store all intermediate features.
        :param images: [B, C, H, W] input images
        '''

        in_filter = self.get_normal(in_tensor_dict)

        features_G = []

        if self.prior_type == 'icon':
            if self.use_filter:
                features_F = self.F_filter(
                    in_filter[:, self.channels_filter[0]])  # [(B,hg_dim,128,128) * 4]
                features_B = self.F_filter(
                    in_filter[:, self.channels_filter[1]])  # [(B,hg_dim,128,128) * 4]
            else:
                features_F = [in_filter[:, self.channels_filter[0]]]
                features_B = [in_filter[:, self.channels_filter[1]]]
            for idx in range(len(features_F)):
                features_G.append(
                    torch.cat([features_F[idx], features_B[idx]], dim=1))
        else:
            if self.use_filter:
                features_G = self.F_filter(in_filter[:,
                                                     self.channels_filter[0]])
            else:
                features_G = [in_filter[:, self.channels_filter[0]]]

        if self.prior_type == 'icon':
            self.smpl_feat_dict = {
                k: in_tensor_dict[k]
                for k in self.icon_keys
            }
        elif self.prior_type == "pamir":
            self.smpl_feat_dict = {
                k: in_tensor_dict[k]
                for k in self.pamir_keys
            }
        else:
            pass
            # print(colored("use z rather than icon or pamir", "green"))

        # If it is not in training, only produce the last im_feat
        if not self.training:
            features_out = [features_G[-1]]
        else:
            features_out = features_G

        if maskout:
            features_out_mask = []
            for feat in features_out:
                features_out_mask.append(
                    feat * self.get_mask(in_filter, size=feat.shape[2]))
            features_out = features_out_mask

        if return_inter:
            return features_out, in_filter
        else:
            return features_out

    def query(self, features, points, calibs, transforms=None, regressor=None):

        xyz = self.projection(points, calibs, transforms)

        (xy, z) = xyz.split([2, 1], dim=1)

        in_cube = (xyz > -1.0) & (xyz < 1.0)
        in_cube = in_cube.all(dim=1, keepdim=True).detach().float()

        preds_list = []

        if self.prior_type == 'icon':

            # smpl_verts [B, N_vert, 3]
            # smpl_faces [B, N_face, 3]
            # points [B, 3, N]

            smpl_sdf, smpl_norm, smpl_cmap, smpl_vis = cal_sdf_batch(
                self.smpl_feat_dict['smpl_verts'],
                self.smpl_feat_dict['smpl_faces'],
                self.smpl_feat_dict['smpl_cmap'],
                self.smpl_feat_dict['smpl_vis'],
                xyz.permute(0, 2, 1).contiguous())

            # smpl_sdf [B, N, 1]
            # smpl_norm [B, N, 3]
            # smpl_cmap [B, N, 3]
            # smpl_vis [B, N, 1]

            feat_lst = [smpl_sdf]
            if 'cmap' in self.smpl_feats:
                feat_lst.append(smpl_cmap)
            if 'norm' in self.smpl_feats:
                feat_lst.append(smpl_norm)
            if 'vis' in self.smpl_feats:
                feat_lst.append(smpl_vis)

            smpl_feat = torch.cat(feat_lst, dim=2).permute(0, 2, 1)
            vol_feats = features

        elif self.prior_type == "pamir":

            voxel_verts = self.smpl_feat_dict[
                'voxel_verts'][:, :-self.smpl_feat_dict['pad_v_num'][0], :]
            voxel_faces = self.smpl_feat_dict[
                'voxel_faces'][:, :-self.smpl_feat_dict['pad_f_num'][0], :]

            self.voxelization.update_param(
                batch_size=voxel_faces.shape[0],
                smpl_tetra=voxel_faces[0].detach().cpu().numpy())
            vol = self.voxelization(voxel_verts)  # vol ~ [0,1]
            vol_feats = self.ve(vol, intermediate_output=self.training)
        else:
            vol_feats = features

        for im_feat, vol_feat in zip(features, vol_feats):

            # [B, Feat_i + z, N]
            # normal feature choice by smpl_vis
            if self.prior_type == 'icon':
                if 'vis' in self.smpl_feats:
                    point_local_feat = feat_select(self.index(im_feat, xy),
                                                   smpl_feat[:, [-1], :])
                    if maskout:
                        normal_mask = torch.tile(
                            point_local_feat.sum(dim=1, keepdims=True) == 0.0,
                            (1, smpl_feat.shape[1], 1))
                        normal_mask[:, 1:, :] = False
                        smpl_feat[normal_mask] = -1.0
                    point_feat_list = [point_local_feat, smpl_feat[:, :-1, :]]
                else:
                    point_local_feat = self.index(im_feat, xy)
                    point_feat_list = [point_local_feat, smpl_feat[:, :, :]]

            elif self.prior_type == 'pamir':
                # im_feat [B, hg_dim, 128, 128]
                # vol_feat [B, vol_dim, 32, 32, 32]
                point_feat_list = [
                    self.index(im_feat, xy),
                    self.index(vol_feat, xyz)
                ]

            else:
                point_feat_list = [self.index(im_feat, xy), z]

            point_feat = torch.cat(point_feat_list, 1)

            # out of image plane is always set to 0
            preds = regressor(point_feat)
            preds = in_cube * preds

            preds_list.append(preds)

        return preds_list

    def get_error(self, preds_if_list, labels):
        """calculate error

        Args:
            preds_list (list): list of torch.tensor(B, 3, N)
            labels (torch.tensor): (B, N_knn, N)

        Returns:
            torch.tensor: error
        """
        error_if = 0

        for pred_id in range(len(preds_if_list)):
            pred_if = preds_if_list[pred_id]
            error_if += self.error_term(pred_if, labels)

        error_if /= len(preds_if_list)

        return error_if

    def forward(self, in_tensor_dict):
        """
        sample_tensor [B, 3, N]
        calib_tensor [B, 4, 4]
        label_tensor [B, 1, N]
        smpl_feat_tensor [B, 59, N]
        """

        sample_tensor = in_tensor_dict['sample']
        calib_tensor = in_tensor_dict['calib']
        label_tensor = in_tensor_dict['label']

        in_feat = self.filter(in_tensor_dict)

        preds_if_list = self.query(in_feat,
                                   sample_tensor,
                                   calib_tensor,
                                   regressor=self.if_regressor)

        error = self.get_error(preds_if_list, label_tensor)

        return preds_if_list[-1], error
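A small sketch (not part of this upload) of what get_error computes: the same label is compared against the prediction of every hourglass stack and the per-stack losses are averaged. Tensor sizes are illustrative.

    import torch
    import torch.nn as nn

    error_term = nn.MSELoss()
    labels = torch.rand(2, 1, 8000)                             # [B, 1, N] occupancy labels
    preds_if_list = [torch.rand(2, 1, 8000) for _ in range(4)]  # one prediction per stack

    # same averaging as HGPIFuNet.get_error
    error_if = sum(error_term(p, labels) for p in preds_if_list) / len(preds_if_list)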
lib/net/MLP.py
ADDED
@@ -0,0 +1,72 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import torch
import torch.nn as nn
import pytorch_lightning as pl


class MLP(pl.LightningModule):
    def __init__(self,
                 filter_channels,
                 name=None,
                 res_layers=[],
                 norm='group',
                 last_op=None):

        super(MLP, self).__init__()

        self.filters = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.res_layers = res_layers
        self.norm = norm
        self.last_op = last_op
        self.name = name
        self.activate = nn.LeakyReLU(inplace=True)

        for l in range(0, len(filter_channels) - 1):
            if l in self.res_layers:
                self.filters.append(
                    nn.Conv1d(filter_channels[l] + filter_channels[0],
                              filter_channels[l + 1], 1))
            else:
                self.filters.append(
                    nn.Conv1d(filter_channels[l], filter_channels[l + 1], 1))

            if l != len(filter_channels) - 2:
                if norm == 'group':
                    self.norms.append(nn.GroupNorm(32, filter_channels[l + 1]))
                elif norm == 'batch':
                    self.norms.append(nn.BatchNorm1d(filter_channels[l + 1]))
                elif norm == 'instance':
                    self.norms.append(nn.InstanceNorm1d(filter_channels[l + 1]))
                elif norm == 'weight':
                    self.filters[l] = nn.utils.weight_norm(self.filters[l],
                                                           name='weight')
                    # print(self.filters[l].weight_g.size(),
                    #       self.filters[l].weight_v.size())

    def forward(self, feature):
        '''
        feature may include multiple view inputs
        args:
            feature: [B, C_in, N]
        return:
            [B, C_out, N] prediction
        '''
        y = feature
        tmpy = feature

        for i, f in enumerate(self.filters):

            y = f(y if i not in self.res_layers else torch.cat([y, tmpy], 1))
            if i != len(self.filters) - 1:
                if self.norm not in ['batch', 'group', 'instance']:
                    y = self.activate(y)
                else:
                    y = self.activate(self.norms[i](y))

        if self.last_op is not None:
            y = self.last_op(y)

        return y
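A hedged sketch (not part of this upload) of MLP as a per-point regressor, which is the role if_regressor plays in HGPIFuNet.py: filter_channels[0] is the input dimension and res_layers re-concatenate that input at the listed layers. The channel list below is illustrative only.

    import torch
    from lib.net.MLP import MLP

    mlp = MLP(filter_channels=[13, 512, 256, 128, 1],
              res_layers=[2, 3],
              norm='group',
              last_op=torch.nn.Sigmoid())
    point_feat = torch.rand(2, 13, 8000)   # [B, C_in, N] per-point features
    occ = mlp(point_feat)                  # [B, 1, N], values in (0, 1)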
lib/net/NormalNet.py
ADDED
@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

from lib.net.FBNet import define_G
from lib.net.net_util import init_net, VGGLoss
from lib.net.HGFilters import *
from lib.net.BasePIFuNet import BasePIFuNet
import torch
import torch.nn as nn


class NormalNet(BasePIFuNet):
    '''
    HG PIFu network uses Hourglass stacks as the image filter.
    It does the following:
        1. Compute image feature stacks and store it in self.im_feat_list
            self.im_feat_list[-1] is the last stack (output stack)
        2. Calculate calibration
        3. If training, it index on every intermediate stacks,
            If testing, it index on the last stack.
        4. Classification.
        5. During training, error is calculated on all stacks.
    '''

    def __init__(self, cfg, error_term=nn.SmoothL1Loss()):

        super(NormalNet, self).__init__(error_term=error_term)

        self.l1_loss = nn.SmoothL1Loss()

        self.opt = cfg.net

        if self.training:
            self.vgg_loss = [VGGLoss()]

        self.in_nmlF = [
            item[0] for item in self.opt.in_nml
            if '_F' in item[0] or item[0] == 'image'
        ]
        self.in_nmlB = [
            item[0] for item in self.opt.in_nml
            if '_B' in item[0] or item[0] == 'image'
        ]
        self.in_nmlF_dim = sum([
            item[1] for item in self.opt.in_nml
            if '_F' in item[0] or item[0] == 'image'
        ])
        self.in_nmlB_dim = sum([
            item[1] for item in self.opt.in_nml
            if '_B' in item[0] or item[0] == 'image'
        ])

        self.netF = define_G(self.in_nmlF_dim, 3, 64, "global", 4, 9, 1, 3,
                             "instance")
        self.netB = define_G(self.in_nmlB_dim, 3, 64, "global", 4, 9, 1, 3,
                             "instance")

        init_net(self)

    def forward(self, in_tensor):

        inF_list = []
        inB_list = []

        for name in self.in_nmlF:
            inF_list.append(in_tensor[name])
        for name in self.in_nmlB:
            inB_list.append(in_tensor[name])

        nmlF = self.netF(torch.cat(inF_list, dim=1))
        nmlB = self.netB(torch.cat(inB_list, dim=1))

        # ||normal|| == 1
        nmlF /= torch.norm(nmlF, dim=1)
        nmlB /= torch.norm(nmlB, dim=1)

        # output: float_arr [-1,1] with [B, C, H, W]

        mask = (in_tensor['image'].abs().sum(dim=1, keepdim=True) !=
                0.0).detach().float()

        nmlF = nmlF * mask
        nmlB = nmlB * mask

        return nmlF, nmlB

    def get_norm_error(self, prd_F, prd_B, tgt):
        """calculate normal loss

        Args:
            pred (torch.tensor): [B, 6, 512, 512]
            tgt (torch.tensor): [B, 6, 512, 512]
        """

        tgt_F, tgt_B = tgt['normal_F'], tgt['normal_B']

        l1_F_loss = self.l1_loss(prd_F, tgt_F)
        l1_B_loss = self.l1_loss(prd_B, tgt_B)

        with torch.no_grad():
            vgg_F_loss = self.vgg_loss[0](prd_F, tgt_F)
            vgg_B_loss = self.vgg_loss[0](prd_B, tgt_B)

        total_loss = [
            5.0 * l1_F_loss + vgg_F_loss, 5.0 * l1_B_loss + vgg_B_loss
        ]

        return total_loss
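A minimal sketch (not part of this upload) of the post-processing in NormalNet.forward: predicted normals are divided by their per-pixel L2 norm and zeroed outside the mask derived from the input image. A single-sample batch is assumed so the division broadcasts the same way as the code above.

    import torch

    nml = torch.rand(1, 3, 512, 512) * 2 - 1           # raw generator output
    image = torch.rand(1, 3, 512, 512)                 # masked input image

    nml = nml / torch.norm(nml, dim=1)                 # unit length; broadcast works for B == 1
    mask = (image.abs().sum(dim=1, keepdim=True) != 0.0).float()
    nml = nml * mask                                   # background normals set to 0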
lib/net/VE.py
ADDED
@@ -0,0 +1,183 @@
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
|
19 |
+
import torch.nn as nn
|
20 |
+
import pytorch_lightning as pl
|
21 |
+
|
22 |
+
|
23 |
+
class BaseNetwork(pl.LightningModule):
|
24 |
+
def __init__(self):
|
25 |
+
super(BaseNetwork, self).__init__()
|
26 |
+
|
27 |
+
def init_weights(self, init_type='xavier', gain=0.02):
|
28 |
+
'''
|
29 |
+
initializes network's weights
|
30 |
+
init_type: normal | xavier | kaiming | orthogonal
|
31 |
+
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
|
32 |
+
'''
|
33 |
+
def init_func(m):
|
34 |
+
classname = m.__class__.__name__
|
35 |
+
if hasattr(m, 'weight') and (classname.find('Conv') != -1
|
36 |
+
or classname.find('Linear') != -1):
|
37 |
+
if init_type == 'normal':
|
38 |
+
nn.init.normal_(m.weight.data, 0.0, gain)
|
39 |
+
elif init_type == 'xavier':
|
40 |
+
nn.init.xavier_normal_(m.weight.data, gain=gain)
|
41 |
+
elif init_type == 'kaiming':
|
42 |
+
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
43 |
+
elif init_type == 'orthogonal':
|
44 |
+
nn.init.orthogonal_(m.weight.data, gain=gain)
|
45 |
+
|
46 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
47 |
+
nn.init.constant_(m.bias.data, 0.0)
|
48 |
+
|
49 |
+
elif classname.find('BatchNorm2d') != -1:
|
50 |
+
nn.init.normal_(m.weight.data, 1.0, gain)
|
51 |
+
nn.init.constant_(m.bias.data, 0.0)
|
52 |
+
|
53 |
+
self.apply(init_func)
|
54 |
+
|
55 |
+
|
56 |
+
class Residual3D(BaseNetwork):
|
57 |
+
def __init__(self, numIn, numOut):
|
58 |
+
        super(Residual3D, self).__init__()
        self.numIn = numIn
        self.numOut = numOut
        self.with_bias = True
        # self.bn = nn.GroupNorm(4, self.numIn)
        self.bn = nn.BatchNorm3d(self.numIn)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv3d(self.numIn,
                               self.numOut,
                               bias=self.with_bias,
                               kernel_size=3,
                               stride=1,
                               padding=2,
                               dilation=2)
        # self.bn1 = nn.GroupNorm(4, self.numOut)
        self.bn1 = nn.BatchNorm3d(self.numOut)
        self.conv2 = nn.Conv3d(self.numOut,
                               self.numOut,
                               bias=self.with_bias,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        # self.bn2 = nn.GroupNorm(4, self.numOut)
        self.bn2 = nn.BatchNorm3d(self.numOut)
        self.conv3 = nn.Conv3d(self.numOut,
                               self.numOut,
                               bias=self.with_bias,
                               kernel_size=3,
                               stride=1,
                               padding=1)

        if self.numIn != self.numOut:
            self.conv4 = nn.Conv3d(self.numIn,
                                   self.numOut,
                                   bias=self.with_bias,
                                   kernel_size=1)
        self.init_weights()

    def forward(self, x):
        residual = x
        # out = self.bn(x)
        # out = self.relu(out)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # out = self.conv3(out)
        # out = self.relu(out)

        if self.numIn != self.numOut:
            residual = self.conv4(x)

        return out + residual


class VolumeEncoder(BaseNetwork):
    """CycleGan Encoder"""

    def __init__(self, num_in=3, num_out=32, num_stacks=2):
        super(VolumeEncoder, self).__init__()
        self.num_in = num_in
        self.num_out = num_out
        self.num_inter = 8
        self.num_stacks = num_stacks
        self.with_bias = True

        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv3d(self.num_in,
                               self.num_inter,
                               bias=self.with_bias,
                               kernel_size=5,
                               stride=2,
                               padding=4,
                               dilation=2)
        # self.bn1 = nn.GroupNorm(4, self.num_inter)
        self.bn1 = nn.BatchNorm3d(self.num_inter)
        self.conv2 = nn.Conv3d(self.num_inter,
                               self.num_out,
                               bias=self.with_bias,
                               kernel_size=5,
                               stride=2,
                               padding=4,
                               dilation=2)
        # self.bn2 = nn.GroupNorm(4, self.num_out)
        self.bn2 = nn.BatchNorm3d(self.num_out)

        self.conv_out1 = nn.Conv3d(self.num_out,
                                   self.num_out,
                                   bias=self.with_bias,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1,
                                   dilation=1)
        self.conv_out2 = nn.Conv3d(self.num_out,
                                   self.num_out,
                                   bias=self.with_bias,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1,
                                   dilation=1)

        for idx in range(self.num_stacks):
            self.add_module("res" + str(idx),
                            Residual3D(self.num_out, self.num_out))

        self.init_weights()

    def forward(self, x, intermediate_output=True):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out_lst = []
        for idx in range(self.num_stacks):
            out = self._modules["res" + str(idx)](out)
            out_lst.append(out)

        if intermediate_output:
            return out_lst
        else:
            return [out_lst[-1]]
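For orientation, a minimal usage sketch of VolumeEncoder above (not part of the upload): the input volume size and the lib.net.VE import path are assumptions, and BaseNetwork.init_weights from FBNet.py is assumed to be callable without arguments, as the class itself does.

    import torch
    from lib.net.VE import VolumeEncoder

    encoder = VolumeEncoder(num_in=3, num_out=32, num_stacks=2).eval()
    vol = torch.randn(1, 3, 32, 32, 32)   # [B, C, D, H, W] toy input volume
    with torch.no_grad():
        feats = encoder(vol)              # list with one tensor per residual stack
    # The two stride-2 dilated convs reduce 32^3 to 8^3, so each entry is [1, 32, 8, 8, 8].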
lib/net/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .BasePIFuNet import BasePIFuNet
from .HGPIFuNet import HGPIFuNet
from .NormalNet import NormalNet
from .VE import VolumeEncoder
lib/net/geometry.py
ADDED
@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

import torch


def index(feat, uv):
    '''
    :param feat: [B, C, H, W] image features
    :param uv: [B, 2, N] uv coordinates in the image plane, range [0, 1]
    :return: [B, C, N] image features at the uv coordinates
    '''
    uv = uv.transpose(1, 2)  # [B, N, 2]

    (B, N, _) = uv.shape
    C = feat.shape[1]

    if uv.shape[-1] == 3:
        # uv = uv[:,:,[2,1,0]]
        # uv = uv * torch.tensor([1.0,-1.0,1.0]).type_as(uv)[None,None,...]
        uv = uv.unsqueeze(2).unsqueeze(3)  # [B, N, 1, 1, 3]
    else:
        uv = uv.unsqueeze(2)  # [B, N, 1, 2]

    # NOTE: for newer PyTorch, it seems that training results are degraded due to implementation diff in F.grid_sample
    # for old versions, simply remove the aligned_corners argument.
    samples = torch.nn.functional.grid_sample(
        feat, uv, align_corners=True)  # [B, C, N, 1]
    return samples.view(B, C, N)  # [B, C, N]


def orthogonal(points, calibrations, transforms=None):
    '''
    Compute the orthogonal projections of 3D points into the image plane by given projection matrix
    :param points: [B, 3, N] Tensor of 3D points
    :param calibrations: [B, 3, 4] Tensor of projection matrix
    :param transforms: [B, 2, 3] Tensor of image transform matrix
    :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
    '''
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    pts = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    if transforms is not None:
        scale = transforms[:2, :2]
        shift = transforms[:2, 2:3]
        pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :])
    return pts


def perspective(points, calibrations, transforms=None):
    '''
    Compute the perspective projections of 3D points into the image plane by given projection matrix
    :param points: [Bx3xN] Tensor of 3D points
    :param calibrations: [Bx3x4] Tensor of projection matrix
    :param transforms: [Bx2x3] Tensor of image transform matrix
    :return: xy: [Bx2xN] Tensor of xy coordinates in the image plane
    '''
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    homo = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    xy = homo[:, :2, :] / homo[:, 2:3, :]
    if transforms is not None:
        scale = transforms[:2, :2]
        shift = transforms[:2, 2:3]
        xy = torch.baddbmm(shift, scale, xy)

    xyz = torch.cat([xy, homo[:, 2:3, :]], 1)
    return xyz
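For orientation, a minimal sketch of how orthogonal() and index() compose in a PIFu-style point query (not part of the upload; the tensor sizes, the identity calibration, and the lib.net.geometry import path are illustrative assumptions).

    import torch
    from lib.net.geometry import index, orthogonal

    B, C, H, W, N = 2, 8, 128, 128, 1000
    feat = torch.randn(B, C, H, W)                           # per-pixel image features
    points = torch.randn(B, 3, N)                            # 3D query points
    calibs = torch.eye(4)[:3].unsqueeze(0).repeat(B, 1, 1)   # identity [B, 3, 4] calibration

    xyz = orthogonal(points, calibs)                         # [B, 3, N] image-space coordinates
    point_feat = index(feat, xyz[:, :2, :])                  # [B, C, N] bilinearly sampled features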
lib/net/net_util.py
ADDED
@@ -0,0 +1,329 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

from torchvision import models
import torch
from torch.nn import init
import torch.nn as nn
import torch.nn.functional as F
import functools
from torch.autograd import grad


def gradient(inputs, outputs):
    d_points = torch.ones_like(outputs,
                               requires_grad=False,
                               device=outputs.device)
    points_grad = grad(outputs=outputs,
                       inputs=inputs,
                       grad_outputs=d_points,
                       create_graph=True,
                       retain_graph=True,
                       only_inputs=True,
                       allow_unused=True)[0]
    return points_grad


# def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
#     "3x3 convolution with padding"
#     return nn.Conv2d(in_planes, out_planes, kernel_size=3,
#                      stride=strd, padding=padding, bias=bias)


def conv3x3(in_planes,
            out_planes,
            kernel=3,
            strd=1,
            dilation=1,
            padding=1,
            bias=False):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=kernel,
                     dilation=dilation,
                     stride=strd,
                     padding=padding,
                     bias=bias)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1
                                     or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError(
                    'initialization method [%s] is not implemented' %
                    init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find(
                'BatchNorm2d'
        ) != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    # print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='xavier', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        gain (float)       -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        net = torch.nn.DataParallel(net)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net


def imageSpaceRotation(xy, rot):
    '''
    args:
        xy: (B, 2, N) input
        rot: (B, 2) x,y axis rotation angles

    rotation center will be always image center (other rotation center can be represented by additional z translation)
    '''
    disp = rot.unsqueeze(2).sin().expand_as(xy)
    return (disp * xy).sum(dim=1)


def cal_gradient_penalty(netD,
                         real_data,
                         fake_data,
                         device,
                         type='mixed',
                         constant=1.0,
                         lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network)           -- discriminator network
        real_data (tensor array) -- real images
        fake_data (tensor array) -- generated images from the generator
        device (str)             -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        type (str)               -- if we mix real and fake data or not [real | fake | mixed].
        constant (float)         -- the constant used in formula (||gradient||_2 - constant)^2
        lambda_gp (float)        -- weight for this loss

    Returns the gradient penalty loss
    """
    if lambda_gp > 0.0:
        # either use real images, fake images, or a linear interpolation of two.
        if type == 'real':
            interpolatesv = real_data
        elif type == 'fake':
            interpolatesv = fake_data
        elif type == 'mixed':
            alpha = torch.rand(real_data.shape[0], 1)
            alpha = alpha.expand(
                real_data.shape[0],
                real_data.nelement() //
                real_data.shape[0]).contiguous().view(*real_data.shape)
            alpha = alpha.to(device)
            interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
        else:
            raise NotImplementedError('{} not implemented'.format(type))
        interpolatesv.requires_grad_(True)
        disc_interpolates = netD(interpolatesv)
        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolatesv,
            grad_outputs=torch.ones(disc_interpolates.size()).to(device),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)
        gradients = gradients[0].view(real_data.size(0), -1)  # flat the data
        gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) **
                            2).mean() * lambda_gp  # added eps
        return gradient_penalty, gradients
    else:
        return 0.0, None


def get_norm_layer(norm_type='instance'):
    """Return a normalization layer
    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none
    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d,
                                       affine=True,
                                       track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d,
                                       affine=False,
                                       track_running_stats=False)
    elif norm_type == 'group':
        norm_layer = functools.partial(nn.GroupNorm, 32)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' %
                                  norm_type)
    return norm_layer


class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


class ConvBlock(nn.Module):
    def __init__(self, in_planes, out_planes, opt):
        super(ConvBlock, self).__init__()
        [k, s, d, p] = opt.conv3x3
        self.conv1 = conv3x3(in_planes, int(out_planes / 2), k, s, d, p)
        self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4), k, s, d,
                             p)
        self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4), k, s, d,
                             p)

        if opt.norm == 'batch':
            self.bn1 = nn.BatchNorm2d(in_planes)
            self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
            self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
            self.bn4 = nn.BatchNorm2d(in_planes)
        elif opt.norm == 'group':
            self.bn1 = nn.GroupNorm(32, in_planes)
            self.bn2 = nn.GroupNorm(32, int(out_planes / 2))
            self.bn3 = nn.GroupNorm(32, int(out_planes / 4))
            self.bn4 = nn.GroupNorm(32, in_planes)

        if in_planes != out_planes:
            self.downsample = nn.Sequential(
                self.bn4,
                nn.ReLU(True),
                nn.Conv2d(in_planes,
                          out_planes,
                          kernel_size=1,
                          stride=1,
                          bias=False),
            )
        else:
            self.downsample = None

    def forward(self, x):
        residual = x

        out1 = self.bn1(x)
        out1 = F.relu(out1, True)
        out1 = self.conv1(out1)

        out2 = self.bn2(out1)
        out2 = F.relu(out2, True)
        out2 = self.conv2(out2)

        out3 = self.bn3(out2)
        out3 = F.relu(out3, True)
        out3 = self.conv3(out3)

        out3 = torch.cat((out1, out2, out3), 1)

        if self.downsample is not None:
            residual = self.downsample(residual)

        out3 += residual

        return out3


class Vgg19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            loss += self.weights[i] * self.criterion(x_vgg[i],
                                                     y_vgg[i].detach())
        return loss
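For orientation, a small sketch of the weight-init and normalization helpers above (not part of the upload; the toy network and the lib.net.net_util import path are assumptions).

    import torch
    import torch.nn as nn
    from lib.net.net_util import get_norm_layer, init_net

    net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
    net = init_net(net, init_type='xavier', init_gain=0.02, gpu_ids=[])  # CPU-only init here

    norm_layer = get_norm_layer('group')   # functools.partial(nn.GroupNorm, 32)
    gn = norm_layer(64)                    # equivalent to nn.GroupNorm(32, 64)

    out = net(torch.randn(4, 3, 64, 64))   # [4, 16, 64, 64]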
lib/net/voxelize.py
ADDED
@@ -0,0 +1,184 @@
from __future__ import division, print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Function

import voxelize_cuda


class VoxelizationFunction(Function):
    """
    Definition of differentiable voxelization function
    Currently implemented only for cuda Tensors
    """
    @staticmethod
    def forward(ctx, smpl_vertices, smpl_face_center, smpl_face_normal,
                smpl_vertex_code, smpl_face_code, smpl_tetrahedrons,
                volume_res, sigma, smooth_kernel_size):
        """
        forward pass
        Output format: (batch_size, z_dims, y_dims, x_dims, channel_num)
        """
        assert (smpl_vertices.size()[1] == smpl_vertex_code.size()[1])
        assert (smpl_face_center.size()[1] == smpl_face_normal.size()[1])
        assert (smpl_face_center.size()[1] == smpl_face_code.size()[1])
        ctx.batch_size = smpl_vertices.size()[0]
        ctx.volume_res = volume_res
        ctx.sigma = sigma
        ctx.smooth_kernel_size = smooth_kernel_size
        ctx.smpl_vertex_num = smpl_vertices.size()[1]
        ctx.device = smpl_vertices.device

        smpl_vertices = smpl_vertices.contiguous()
        smpl_face_center = smpl_face_center.contiguous()
        smpl_face_normal = smpl_face_normal.contiguous()
        smpl_vertex_code = smpl_vertex_code.contiguous()
        smpl_face_code = smpl_face_code.contiguous()
        smpl_tetrahedrons = smpl_tetrahedrons.contiguous()

        occ_volume = torch.cuda.FloatTensor(ctx.batch_size, ctx.volume_res,
                                            ctx.volume_res,
                                            ctx.volume_res).fill_(0.0)
        semantic_volume = torch.cuda.FloatTensor(ctx.batch_size,
                                                 ctx.volume_res,
                                                 ctx.volume_res,
                                                 ctx.volume_res, 3).fill_(0.0)
        weight_sum_volume = torch.cuda.FloatTensor(ctx.batch_size,
                                                   ctx.volume_res,
                                                   ctx.volume_res,
                                                   ctx.volume_res).fill_(1e-3)

        # occ_volume [B, volume_res, volume_res, volume_res]
        # semantic_volume [B, volume_res, volume_res, volume_res, 3]
        # weight_sum_volume [B, volume_res, volume_res, volume_res]

        occ_volume, semantic_volume, weight_sum_volume = voxelize_cuda.forward_semantic_voxelization(
            smpl_vertices, smpl_vertex_code, smpl_tetrahedrons, occ_volume,
            semantic_volume, weight_sum_volume, sigma)

        return semantic_volume


class Voxelization(nn.Module):
    """
    Wrapper around the autograd function VoxelizationFunction
    """

    def __init__(self, smpl_vertex_code, smpl_face_code, smpl_face_indices,
                 smpl_tetraderon_indices, volume_res, sigma,
                 smooth_kernel_size, batch_size, device):
        super(Voxelization, self).__init__()
        assert (len(smpl_face_indices.shape) == 2)
        assert (len(smpl_tetraderon_indices.shape) == 2)
        assert (smpl_face_indices.shape[1] == 3)
        assert (smpl_tetraderon_indices.shape[1] == 4)

        self.volume_res = volume_res
        self.sigma = sigma
        self.smooth_kernel_size = smooth_kernel_size
        self.batch_size = batch_size
        self.device = device

        self.smpl_vertex_code = smpl_vertex_code
        self.smpl_face_code = smpl_face_code
        self.smpl_face_indices = smpl_face_indices
        self.smpl_tetraderon_indices = smpl_tetraderon_indices

    def update_param(self, batch_size, smpl_tetra):

        self.batch_size = batch_size
        self.smpl_tetraderon_indices = smpl_tetra

        smpl_vertex_code_batch = np.tile(self.smpl_vertex_code,
                                         (self.batch_size, 1, 1))
        smpl_face_code_batch = np.tile(self.smpl_face_code,
                                       (self.batch_size, 1, 1))
        smpl_face_indices_batch = np.tile(self.smpl_face_indices,
                                          (self.batch_size, 1, 1))
        smpl_tetraderon_indices_batch = np.tile(self.smpl_tetraderon_indices,
                                                (self.batch_size, 1, 1))

        smpl_vertex_code_batch = torch.from_numpy(
            smpl_vertex_code_batch).contiguous().to(self.device)
        smpl_face_code_batch = torch.from_numpy(
            smpl_face_code_batch).contiguous().to(self.device)
        smpl_face_indices_batch = torch.from_numpy(
            smpl_face_indices_batch).contiguous().to(self.device)
        smpl_tetraderon_indices_batch = torch.from_numpy(
            smpl_tetraderon_indices_batch).contiguous().to(self.device)

        self.register_buffer('smpl_vertex_code_batch', smpl_vertex_code_batch)
        self.register_buffer('smpl_face_code_batch', smpl_face_code_batch)
        self.register_buffer('smpl_face_indices_batch',
                             smpl_face_indices_batch)
        self.register_buffer('smpl_tetraderon_indices_batch',
                             smpl_tetraderon_indices_batch)

    def forward(self, smpl_vertices):
        """
        Generate semantic volumes from SMPL vertices
        """
        assert (smpl_vertices.size()[0] == self.batch_size)
        self.check_input(smpl_vertices)
        smpl_faces = self.vertices_to_faces(smpl_vertices)
        smpl_tetrahedrons = self.vertices_to_tetrahedrons(smpl_vertices)
        smpl_face_center = self.calc_face_centers(smpl_faces)
        smpl_face_normal = self.calc_face_normals(smpl_faces)
        smpl_surface_vertex_num = self.smpl_vertex_code_batch.size()[1]
        smpl_vertices_surface = smpl_vertices[:, :smpl_surface_vertex_num, :]
        vol = VoxelizationFunction.apply(smpl_vertices_surface,
                                         smpl_face_center, smpl_face_normal,
                                         self.smpl_vertex_code_batch,
                                         self.smpl_face_code_batch,
                                         smpl_tetrahedrons, self.volume_res,
                                         self.sigma, self.smooth_kernel_size)
        return vol.permute((0, 4, 1, 2, 3))  # (bzyxc --> bcdhw)

    def vertices_to_faces(self, vertices):
        assert (vertices.ndimension() == 3)
        bs, nv = vertices.shape[:2]
        device = vertices.device
        face = self.smpl_face_indices_batch + (
            torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
        vertices_ = vertices.reshape((bs * nv, 3))
        return vertices_[face.long()]

    def vertices_to_tetrahedrons(self, vertices):
        assert (vertices.ndimension() == 3)
        bs, nv = vertices.shape[:2]
        device = vertices.device
        tets = self.smpl_tetraderon_indices_batch + (
            torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
        vertices_ = vertices.reshape((bs * nv, 3))
        return vertices_[tets.long()]

    def calc_face_centers(self, face_verts):
        assert len(face_verts.shape) == 4
        assert face_verts.shape[2] == 3
        assert face_verts.shape[3] == 3
        bs, nf = face_verts.shape[:2]
        face_centers = (face_verts[:, :, 0, :] + face_verts[:, :, 1, :] +
                        face_verts[:, :, 2, :]) / 3.0
        face_centers = face_centers.reshape((bs, nf, 3))
        return face_centers

    def calc_face_normals(self, face_verts):
        assert len(face_verts.shape) == 4
        assert face_verts.shape[2] == 3
        assert face_verts.shape[3] == 3
        bs, nf = face_verts.shape[:2]
        face_verts = face_verts.reshape((bs * nf, 3, 3))
        v10 = face_verts[:, 0] - face_verts[:, 1]
        v12 = face_verts[:, 2] - face_verts[:, 1]
        normals = F.normalize(torch.cross(v10, v12), eps=1e-5)
        normals = normals.reshape((bs, nf, 3))
        return normals

    def check_input(self, x):
        if x.device == 'cpu':
            raise TypeError('Voxelization module supports only cuda tensors')
        if x.type() != 'torch.cuda.FloatTensor':
            raise TypeError(
                'Voxelization module supports only float32 tensors')
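Finally, a hypothetical usage sketch of Voxelization (not part of the upload). It assumes the voxelize_cuda extension is compiled and a GPU is available; the array sizes, sigma, and kernel size below are placeholder values rather than real SMPL data, and the lib.net.voxelize import path is an assumption.

    import numpy as np
    import torch
    from lib.net.voxelize import Voxelization

    n_vert, n_face, n_tet = 6890, 13776, 20000            # placeholder counts
    voxelizer = Voxelization(
        smpl_vertex_code=np.random.rand(n_vert, 3).astype(np.float32),   # per-vertex semantic code
        smpl_face_code=np.random.rand(n_face, 3).astype(np.float32),     # per-face semantic code
        smpl_face_indices=np.random.randint(0, n_vert, (n_face, 3)),     # triangle indices
        smpl_tetraderon_indices=np.random.randint(0, n_vert, (n_tet, 4)),# tetrahedron indices
        volume_res=128, sigma=0.05, smooth_kernel_size=7,
        batch_size=1, device=torch.device('cuda:0'))
    voxelizer.update_param(1, voxelizer.smpl_tetraderon_indices)  # build the batched buffers
    verts = torch.rand(1, n_vert, 3).cuda()                       # toy SMPL vertices
    sem_vol = voxelizer(verts)                                    # [1, 3, 128, 128, 128] semantic volume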