YoonaAI commited on
Commit
456161a
·
1 Parent(s): 2b839ad

Create dataset/mesh_util.py

Browse files
Files changed (1) hide show
  1. lib/dataset/mesh_util.py +911 -0
lib/dataset/mesh_util.py ADDED
@@ -0,0 +1,911 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import numpy as np
19
+ import cv2
20
+ import pymeshlab
21
+ import torch
22
+ import torchvision
23
+ import trimesh
24
+ from pytorch3d.io import load_obj
25
+ from termcolor import colored
26
+ from scipy.spatial import cKDTree
27
+
28
+ from pytorch3d.structures import Meshes
29
+ import torch.nn.functional as F
30
+
31
+ import os
32
+ from lib.pymaf.utils.imutils import uncrop
33
+ from lib.common.render_utils import Pytorch3dRasterizer, face_vertices
34
+
35
+ from pytorch3d.renderer.mesh import rasterize_meshes
36
+ from PIL import Image, ImageFont, ImageDraw
37
+ from kaolin.ops.mesh import check_sign
38
+ from kaolin.metrics.trianglemesh import point_to_mesh_distance
39
+
40
+ from pytorch3d.loss import (
41
+ mesh_laplacian_smoothing,
42
+ mesh_normal_consistency
43
+ )
44
+
45
+ from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
46
+
47
+ def rot6d_to_rotmat(x):
48
+ """Convert 6D rotation representation to 3x3 rotation matrix.
49
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
50
+ Input:
51
+ (B,6) Batch of 6-D rotation representations
52
+ Output:
53
+ (B,3,3) Batch of corresponding rotation matrices
54
+ """
55
+ x = x.view(-1, 3, 2)
56
+ a1 = x[:, :, 0]
57
+ a2 = x[:, :, 1]
58
+ b1 = F.normalize(a1)
59
+ b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1)
60
+ b3 = torch.cross(b1, b2)
61
+ return torch.stack((b1, b2, b3), dim=-1)
62
+
63
+
64
+ def tensor2variable(tensor, device):
65
+ # [1,23,3,3]
66
+ return torch.tensor(tensor, device=device, requires_grad=True)
67
+
68
+
69
+ def normal_loss(vec1, vec2):
70
+
71
+ # vec1_mask = vec1.sum(dim=1) != 0.0
72
+ # vec2_mask = vec2.sum(dim=1) != 0.0
73
+ # union_mask = vec1_mask * vec2_mask
74
+ vec_sim = torch.nn.CosineSimilarity(dim=1, eps=1e-6)(vec1, vec2)
75
+ # vec_diff = ((vec_sim-1.0)**2)[union_mask].mean()
76
+ vec_diff = ((vec_sim-1.0)**2).mean()
77
+
78
+ return vec_diff
79
+
80
+
81
+ class GMoF(torch.nn.Module):
82
+ def __init__(self, rho=1):
83
+ super(GMoF, self).__init__()
84
+ self.rho = rho
85
+
86
+ def extra_repr(self):
87
+ return 'rho = {}'.format(self.rho)
88
+
89
+ def forward(self, residual):
90
+ dist = torch.div(residual, residual + self.rho ** 2)
91
+ return self.rho ** 2 * dist
92
+
93
+
94
+ def mesh_edge_loss(meshes, target_length: float = 0.0):
95
+ """
96
+ Computes mesh edge length regularization loss averaged across all meshes
97
+ in a batch. Each mesh contributes equally to the final loss, regardless of
98
+ the number of edges per mesh in the batch by weighting each mesh with the
99
+ inverse number of edges. For example, if mesh 3 (out of N) has only E=4
100
+ edges, then the loss for each edge in mesh 3 should be multiplied by 1/E to
101
+ contribute to the final loss.
102
+
103
+ Args:
104
+ meshes: Meshes object with a batch of meshes.
105
+ target_length: Resting value for the edge length.
106
+
107
+ Returns:
108
+ loss: Average loss across the batch. Returns 0 if meshes contains
109
+ no meshes or all empty meshes.
110
+ """
111
+ if meshes.isempty():
112
+ return torch.tensor(
113
+ [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
114
+ )
115
+
116
+ N = len(meshes)
117
+ edges_packed = meshes.edges_packed() # (sum(E_n), 3)
118
+ verts_packed = meshes.verts_packed() # (sum(V_n), 3)
119
+ edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx() # (sum(E_n), )
120
+ num_edges_per_mesh = meshes.num_edges_per_mesh() # N
121
+
122
+ # Determine the weight for each edge based on the number of edges in the
123
+ # mesh it corresponds to.
124
+ # TODO (nikhilar) Find a faster way of computing the weights for each edge
125
+ # as this is currently a bottleneck for meshes with a large number of faces.
126
+ weights = num_edges_per_mesh.gather(0, edge_to_mesh_idx)
127
+ weights = 1.0 / weights.float()
128
+
129
+ verts_edges = verts_packed[edges_packed]
130
+ v0, v1 = verts_edges.unbind(1)
131
+ loss = ((v0 - v1).norm(dim=1, p=2) - target_length) ** 2.0
132
+ loss_vertex = loss * weights
133
+ # loss_outlier = torch.topk(loss, 100)[0].mean()
134
+ # loss_all = (loss_vertex.sum() + loss_outlier.mean()) / N
135
+ loss_all = loss_vertex.sum() / N
136
+
137
+ return loss_all
138
+
139
+
140
+ def remesh(obj_path, perc, device):
141
+
142
+ ms = pymeshlab.MeshSet()
143
+ ms.load_new_mesh(obj_path)
144
+ ms.laplacian_smooth()
145
+ ms.remeshing_isotropic_explicit_remeshing(
146
+ targetlen=pymeshlab.Percentage(perc), adaptive=True)
147
+ ms.save_current_mesh(obj_path.replace("recon", "remesh"))
148
+ polished_mesh = trimesh.load_mesh(obj_path.replace("recon", "remesh"))
149
+ verts_pr = torch.tensor(polished_mesh.vertices).float().unsqueeze(0).to(device)
150
+ faces_pr = torch.tensor(polished_mesh.faces).long().unsqueeze(0).to(device)
151
+
152
+ return verts_pr, faces_pr
153
+
154
+
155
+ def possion(mesh, obj_path):
156
+
157
+ mesh.export(obj_path)
158
+ ms = pymeshlab.MeshSet()
159
+ ms.load_new_mesh(obj_path)
160
+ ms.surface_reconstruction_screened_poisson(depth=10)
161
+ ms.set_current_mesh(1)
162
+ ms.save_current_mesh(obj_path)
163
+
164
+ return trimesh.load(obj_path)
165
+
166
+
167
+ def get_mask(tensor, dim):
168
+
169
+ mask = torch.abs(tensor).sum(dim=dim, keepdims=True) > 0.0
170
+ mask = mask.type_as(tensor)
171
+
172
+ return mask
173
+
174
+
175
+ def blend_rgb_norm(rgb, norm, mask):
176
+
177
+ # [0,0,0] or [127,127,127] should be marked as mask
178
+ final = rgb * (1-mask) + norm * (mask)
179
+
180
+ return final.astype(np.uint8)
181
+
182
+
183
+ def unwrap(image, data):
184
+
185
+ img_uncrop = uncrop(np.array(Image.fromarray(image).resize(data['uncrop_param']['box_shape'][:2])),
186
+ data['uncrop_param']['center'],
187
+ data['uncrop_param']['scale'],
188
+ data['uncrop_param']['crop_shape'])
189
+
190
+ img_orig = cv2.warpAffine(img_uncrop,
191
+ np.linalg.inv(data['uncrop_param']['M'])[:2, :],
192
+ data['uncrop_param']['ori_shape'][::-1][1:],
193
+ flags=cv2.INTER_CUBIC)
194
+
195
+ return img_orig
196
+
197
+
198
+ # Losses to smooth / regularize the mesh shape
199
+ def update_mesh_shape_prior_losses(mesh, losses):
200
+
201
+ # and (b) the edge length of the predicted mesh
202
+ losses["edge"]['value'] = mesh_edge_loss(mesh)
203
+ # mesh normal consistency
204
+ losses["nc"]['value'] = mesh_normal_consistency(mesh)
205
+ # mesh laplacian smoothing
206
+ losses["laplacian"]['value'] = mesh_laplacian_smoothing(
207
+ mesh, method="uniform")
208
+
209
+
210
+ def rename(old_dict, old_name, new_name):
211
+ new_dict = {}
212
+ for key, value in zip(old_dict.keys(), old_dict.values()):
213
+ new_key = key if key != old_name else new_name
214
+ new_dict[new_key] = old_dict[key]
215
+ return new_dict
216
+
217
+
218
+ def load_checkpoint(model, cfg):
219
+
220
+ model_dict = model.state_dict()
221
+ main_dict = {}
222
+ normal_dict = {}
223
+
224
+ device = torch.device(f"cuda:{cfg['test_gpus'][0]}")
225
+
226
+ main_dict = torch.load(cached_download(cfg.resume_path, use_auth_token=os.environ['ICON']),
227
+ map_location=device)['state_dict']
228
+
229
+ main_dict = {
230
+ k: v
231
+ for k, v in main_dict.items()
232
+ if k in model_dict and v.shape == model_dict[k].shape and (
233
+ 'reconEngine' not in k) and ("normal_filter" not in k) and (
234
+ 'voxelization' not in k)
235
+ }
236
+ print(colored(f"Resume MLP weights from {cfg.resume_path}", 'green'))
237
+
238
+ normal_dict = torch.load(cached_download(cfg.normal_path, use_auth_token=os.environ['ICON']),
239
+ map_location=device)['state_dict']
240
+
241
+ for key in normal_dict.keys():
242
+ normal_dict = rename(normal_dict, key,
243
+ key.replace("netG", "netG.normal_filter"))
244
+
245
+ normal_dict = {
246
+ k: v
247
+ for k, v in normal_dict.items()
248
+ if k in model_dict and v.shape == model_dict[k].shape
249
+ }
250
+ print(colored(f"Resume normal model from {cfg.normal_path}", 'green'))
251
+
252
+ model_dict.update(main_dict)
253
+ model_dict.update(normal_dict)
254
+ model.load_state_dict(model_dict)
255
+
256
+ model.netG = model.netG.to(device)
257
+ model.reconEngine = model.reconEngine.to(device)
258
+
259
+ model.netG.training = False
260
+ model.netG.eval()
261
+
262
+ del main_dict
263
+ del normal_dict
264
+ del model_dict
265
+
266
+ return model
267
+
268
+
269
+ def read_smpl_constants(folder):
270
+ """Load smpl vertex code"""
271
+ smpl_vtx_std = np.loadtxt(cached_download(os.path.join(folder, 'vertices.txt'), use_auth_token=os.environ['ICON']))
272
+ min_x = np.min(smpl_vtx_std[:, 0])
273
+ max_x = np.max(smpl_vtx_std[:, 0])
274
+ min_y = np.min(smpl_vtx_std[:, 1])
275
+ max_y = np.max(smpl_vtx_std[:, 1])
276
+ min_z = np.min(smpl_vtx_std[:, 2])
277
+ max_z = np.max(smpl_vtx_std[:, 2])
278
+
279
+ smpl_vtx_std[:, 0] = (smpl_vtx_std[:, 0] - min_x) / (max_x - min_x)
280
+ smpl_vtx_std[:, 1] = (smpl_vtx_std[:, 1] - min_y) / (max_y - min_y)
281
+ smpl_vtx_std[:, 2] = (smpl_vtx_std[:, 2] - min_z) / (max_z - min_z)
282
+ smpl_vertex_code = np.float32(np.copy(smpl_vtx_std))
283
+ """Load smpl faces & tetrahedrons"""
284
+ smpl_faces = np.loadtxt(cached_download(os.path.join(folder, 'faces.txt'), use_auth_token=os.environ['ICON']),
285
+ dtype=np.int32) - 1
286
+ smpl_face_code = (smpl_vertex_code[smpl_faces[:, 0]] +
287
+ smpl_vertex_code[smpl_faces[:, 1]] +
288
+ smpl_vertex_code[smpl_faces[:, 2]]) / 3.0
289
+ smpl_tetras = np.loadtxt(cached_download(os.path.join(folder, 'tetrahedrons.txt'), use_auth_token=os.environ['ICON']),
290
+ dtype=np.int32) - 1
291
+
292
+ return smpl_vertex_code, smpl_face_code, smpl_faces, smpl_tetras
293
+
294
+
295
+ def feat_select(feat, select):
296
+
297
+ # feat [B, featx2, N]
298
+ # select [B, 1, N]
299
+ # return [B, feat, N]
300
+
301
+ dim = feat.shape[1] // 2
302
+ idx = torch.tile((1-select), (1, dim, 1))*dim + \
303
+ torch.arange(0, dim).unsqueeze(0).unsqueeze(2).type_as(select)
304
+ feat_select = torch.gather(feat, 1, idx.long())
305
+
306
+ return feat_select
307
+
308
+
309
+ def get_visibility(xy, z, faces):
310
+ """get the visibility of vertices
311
+
312
+ Args:
313
+ xy (torch.tensor): [N,2]
314
+ z (torch.tensor): [N,1]
315
+ faces (torch.tensor): [N,3]
316
+ size (int): resolution of rendered image
317
+ """
318
+
319
+ xyz = torch.cat((xy, -z), dim=1)
320
+ xyz = (xyz + 1.0) / 2.0
321
+ faces = faces.long()
322
+
323
+ rasterizer = Pytorch3dRasterizer(image_size=2**12)
324
+ meshes_screen = Meshes(verts=xyz[None, ...], faces=faces[None, ...])
325
+ raster_settings = rasterizer.raster_settings
326
+
327
+ pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
328
+ meshes_screen,
329
+ image_size=raster_settings.image_size,
330
+ blur_radius=raster_settings.blur_radius,
331
+ faces_per_pixel=raster_settings.faces_per_pixel,
332
+ bin_size=raster_settings.bin_size,
333
+ max_faces_per_bin=raster_settings.max_faces_per_bin,
334
+ perspective_correct=raster_settings.perspective_correct,
335
+ cull_backfaces=raster_settings.cull_backfaces,
336
+ )
337
+
338
+ vis_vertices_id = torch.unique(faces[torch.unique(pix_to_face), :])
339
+ vis_mask = torch.zeros(size=(z.shape[0], 1))
340
+ vis_mask[vis_vertices_id] = 1.0
341
+
342
+ # print("------------------------\n")
343
+ # print(f"keep points : {vis_mask.sum()/len(vis_mask)}")
344
+
345
+ return vis_mask
346
+
347
+
348
+ def barycentric_coordinates_of_projection(points, vertices):
349
+ ''' https://github.com/MPI-IS/mesh/blob/master/mesh/geometry/barycentric_coordinates_of_projection.py
350
+ '''
351
+ """Given a point, gives projected coords of that point to a triangle
352
+ in barycentric coordinates.
353
+ See
354
+ **Heidrich**, Computing the Barycentric Coordinates of a Projected Point, JGT 05
355
+ at http://www.cs.ubc.ca/~heidrich/Papers/JGT.05.pdf
356
+
357
+ :param p: point to project. [B, 3]
358
+ :param v0: first vertex of triangles. [B, 3]
359
+ :returns: barycentric coordinates of ``p``'s projection in triangle defined by ``q``, ``u``, ``v``
360
+ vectorized so ``p``, ``q``, ``u``, ``v`` can all be ``3xN``
361
+ """
362
+ #(p, q, u, v)
363
+ v0, v1, v2 = vertices[:, 0], vertices[:, 1], vertices[:, 2]
364
+ p = points
365
+
366
+ q = v0
367
+ u = v1 - v0
368
+ v = v2 - v0
369
+ n = torch.cross(u, v)
370
+ s = torch.sum(n * n, dim=1)
371
+ # If the triangle edges are collinear, cross-product is zero,
372
+ # which makes "s" 0, which gives us divide by zero. So we
373
+ # make the arbitrary choice to set s to epsv (=numpy.spacing(1)),
374
+ # the closest thing to zero
375
+ s[s == 0] = 1e-6
376
+ oneOver4ASquared = 1.0 / s
377
+ w = p - q
378
+ b2 = torch.sum(torch.cross(u, w) * n, dim=1) * oneOver4ASquared
379
+ b1 = torch.sum(torch.cross(w, v) * n, dim=1) * oneOver4ASquared
380
+ weights = torch.stack((1 - b1 - b2, b1, b2), dim=-1)
381
+ # check barycenric weights
382
+ # p_n = v0*weights[:,0:1] + v1*weights[:,1:2] + v2*weights[:,2:3]
383
+ return weights
384
+
385
+
386
+ def cal_sdf_batch(verts, faces, cmaps, vis, points):
387
+
388
+ # verts [B, N_vert, 3]
389
+ # faces [B, N_face, 3]
390
+ # triangles [B, N_face, 3, 3]
391
+ # points [B, N_point, 3]
392
+ # cmaps [B, N_vert, 3]
393
+
394
+ Bsize = points.shape[0]
395
+
396
+ normals = Meshes(verts, faces).verts_normals_padded()
397
+
398
+ triangles = face_vertices(verts, faces)
399
+ normals = face_vertices(normals, faces)
400
+ cmaps = face_vertices(cmaps, faces)
401
+ vis = face_vertices(vis, faces)
402
+
403
+ residues, pts_ind, _ = point_to_mesh_distance(points, triangles)
404
+ closest_triangles = torch.gather(
405
+ triangles, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
406
+ closest_normals = torch.gather(
407
+ normals, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
408
+ closest_cmaps = torch.gather(
409
+ cmaps, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
410
+ closest_vis = torch.gather(
411
+ vis, 1, pts_ind[:, :, None, None].expand(-1, -1, 3, 1)).view(-1, 3, 1)
412
+ bary_weights = barycentric_coordinates_of_projection(
413
+ points.view(-1, 3), closest_triangles)
414
+
415
+ pts_cmap = (closest_cmaps*bary_weights[:, :, None]).sum(1).unsqueeze(0).clamp_(min=0.0, max=1.0)
416
+ pts_vis = (closest_vis*bary_weights[:,
417
+ :, None]).sum(1).unsqueeze(0).ge(1e-1)
418
+ pts_norm = (closest_normals*bary_weights[:, :, None]).sum(
419
+ 1).unsqueeze(0) * torch.tensor([-1.0, 1.0, -1.0]).type_as(normals)
420
+ pts_norm = F.normalize(pts_norm, dim=2)
421
+ pts_dist = torch.sqrt(residues) / torch.sqrt(torch.tensor(3))
422
+
423
+ pts_signs = 2.0 * (check_sign(verts, faces[0], points).float() - 0.5)
424
+ pts_sdf = (pts_dist * pts_signs).unsqueeze(-1)
425
+
426
+ return pts_sdf.view(Bsize, -1, 1), pts_norm.view(Bsize, -1, 3), pts_cmap.view(Bsize, -1, 3), pts_vis.view(Bsize, -1, 1)
427
+
428
+
429
+ def orthogonal(points, calibrations, transforms=None):
430
+ '''
431
+ Compute the orthogonal projections of 3D points into the image plane by given projection matrix
432
+ :param points: [B, 3, N] Tensor of 3D points
433
+ :param calibrations: [B, 3, 4] Tensor of projection matrix
434
+ :param transforms: [B, 2, 3] Tensor of image transform matrix
435
+ :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
436
+ '''
437
+ rot = calibrations[:, :3, :3]
438
+ trans = calibrations[:, :3, 3:4]
439
+ pts = torch.baddbmm(trans, rot, points) # [B, 3, N]
440
+ if transforms is not None:
441
+ scale = transforms[:2, :2]
442
+ shift = transforms[:2, 2:3]
443
+ pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :])
444
+ return pts
445
+
446
+
447
+ def projection(points, calib, format='numpy'):
448
+ if format == 'tensor':
449
+ return torch.mm(calib[:3, :3], points.T).T + calib[:3, 3]
450
+ else:
451
+ return np.matmul(calib[:3, :3], points.T).T + calib[:3, 3]
452
+
453
+
454
+ def load_calib(calib_path):
455
+ calib_data = np.loadtxt(calib_path, dtype=float)
456
+ extrinsic = calib_data[:4, :4]
457
+ intrinsic = calib_data[4:8, :4]
458
+ calib_mat = np.matmul(intrinsic, extrinsic)
459
+ calib_mat = torch.from_numpy(calib_mat).float()
460
+ return calib_mat
461
+
462
+
463
+ def load_obj_mesh_for_Hoppe(mesh_file):
464
+ vertex_data = []
465
+ face_data = []
466
+
467
+ if isinstance(mesh_file, str):
468
+ f = open(mesh_file, "r")
469
+ else:
470
+ f = mesh_file
471
+ for line in f:
472
+ if isinstance(line, bytes):
473
+ line = line.decode("utf-8")
474
+ if line.startswith('#'):
475
+ continue
476
+ values = line.split()
477
+ if not values:
478
+ continue
479
+
480
+ if values[0] == 'v':
481
+ v = list(map(float, values[1:4]))
482
+ vertex_data.append(v)
483
+
484
+ elif values[0] == 'f':
485
+ # quad mesh
486
+ if len(values) > 4:
487
+ f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
488
+ face_data.append(f)
489
+ f = list(
490
+ map(lambda x: int(x.split('/')[0]),
491
+ [values[3], values[4], values[1]]))
492
+ face_data.append(f)
493
+ # tri mesh
494
+ else:
495
+ f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
496
+ face_data.append(f)
497
+
498
+ vertices = np.array(vertex_data)
499
+ faces = np.array(face_data)
500
+ faces[faces > 0] -= 1
501
+
502
+ normals, _ = compute_normal(vertices, faces)
503
+
504
+ return vertices, normals, faces
505
+
506
+
507
+ def load_obj_mesh_with_color(mesh_file):
508
+ vertex_data = []
509
+ color_data = []
510
+ face_data = []
511
+
512
+ if isinstance(mesh_file, str):
513
+ f = open(mesh_file, "r")
514
+ else:
515
+ f = mesh_file
516
+ for line in f:
517
+ if isinstance(line, bytes):
518
+ line = line.decode("utf-8")
519
+ if line.startswith('#'):
520
+ continue
521
+ values = line.split()
522
+ if not values:
523
+ continue
524
+
525
+ if values[0] == 'v':
526
+ v = list(map(float, values[1:4]))
527
+ vertex_data.append(v)
528
+ c = list(map(float, values[4:7]))
529
+ color_data.append(c)
530
+
531
+ elif values[0] == 'f':
532
+ # quad mesh
533
+ if len(values) > 4:
534
+ f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
535
+ face_data.append(f)
536
+ f = list(
537
+ map(lambda x: int(x.split('/')[0]),
538
+ [values[3], values[4], values[1]]))
539
+ face_data.append(f)
540
+ # tri mesh
541
+ else:
542
+ f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
543
+ face_data.append(f)
544
+
545
+ vertices = np.array(vertex_data)
546
+ colors = np.array(color_data)
547
+ faces = np.array(face_data)
548
+ faces[faces > 0] -= 1
549
+
550
+ return vertices, colors, faces
551
+
552
+
553
+ def load_obj_mesh(mesh_file, with_normal=False, with_texture=False):
554
+ vertex_data = []
555
+ norm_data = []
556
+ uv_data = []
557
+
558
+ face_data = []
559
+ face_norm_data = []
560
+ face_uv_data = []
561
+
562
+ if isinstance(mesh_file, str):
563
+ f = open(mesh_file, "r")
564
+ else:
565
+ f = mesh_file
566
+ for line in f:
567
+ if isinstance(line, bytes):
568
+ line = line.decode("utf-8")
569
+ if line.startswith('#'):
570
+ continue
571
+ values = line.split()
572
+ if not values:
573
+ continue
574
+
575
+ if values[0] == 'v':
576
+ v = list(map(float, values[1:4]))
577
+ vertex_data.append(v)
578
+ elif values[0] == 'vn':
579
+ vn = list(map(float, values[1:4]))
580
+ norm_data.append(vn)
581
+ elif values[0] == 'vt':
582
+ vt = list(map(float, values[1:3]))
583
+ uv_data.append(vt)
584
+
585
+ elif values[0] == 'f':
586
+ # quad mesh
587
+ if len(values) > 4:
588
+ f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
589
+ face_data.append(f)
590
+ f = list(
591
+ map(lambda x: int(x.split('/')[0]),
592
+ [values[3], values[4], values[1]]))
593
+ face_data.append(f)
594
+ # tri mesh
595
+ else:
596
+ f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
597
+ face_data.append(f)
598
+
599
+ # deal with texture
600
+ if len(values[1].split('/')) >= 2:
601
+ # quad mesh
602
+ if len(values) > 4:
603
+ f = list(map(lambda x: int(x.split('/')[1]), values[1:4]))
604
+ face_uv_data.append(f)
605
+ f = list(
606
+ map(lambda x: int(x.split('/')[1]),
607
+ [values[3], values[4], values[1]]))
608
+ face_uv_data.append(f)
609
+ # tri mesh
610
+ elif len(values[1].split('/')[1]) != 0:
611
+ f = list(map(lambda x: int(x.split('/')[1]), values[1:4]))
612
+ face_uv_data.append(f)
613
+ # deal with normal
614
+ if len(values[1].split('/')) == 3:
615
+ # quad mesh
616
+ if len(values) > 4:
617
+ f = list(map(lambda x: int(x.split('/')[2]), values[1:4]))
618
+ face_norm_data.append(f)
619
+ f = list(
620
+ map(lambda x: int(x.split('/')[2]),
621
+ [values[3], values[4], values[1]]))
622
+ face_norm_data.append(f)
623
+ # tri mesh
624
+ elif len(values[1].split('/')[2]) != 0:
625
+ f = list(map(lambda x: int(x.split('/')[2]), values[1:4]))
626
+ face_norm_data.append(f)
627
+
628
+ vertices = np.array(vertex_data)
629
+ faces = np.array(face_data)
630
+ faces[faces > 0] -= 1
631
+
632
+ if with_texture and with_normal:
633
+ uvs = np.array(uv_data)
634
+ face_uvs = np.array(face_uv_data)
635
+ face_uvs[face_uvs > 0] -= 1
636
+ norms = np.array(norm_data)
637
+ if norms.shape[0] == 0:
638
+ norms, _ = compute_normal(vertices, faces)
639
+ face_normals = faces
640
+ else:
641
+ norms = normalize_v3(norms)
642
+ face_normals = np.array(face_norm_data)
643
+ face_normals[face_normals > 0] -= 1
644
+ return vertices, faces, norms, face_normals, uvs, face_uvs
645
+
646
+ if with_texture:
647
+ uvs = np.array(uv_data)
648
+ face_uvs = np.array(face_uv_data) - 1
649
+ return vertices, faces, uvs, face_uvs
650
+
651
+ if with_normal:
652
+ norms = np.array(norm_data)
653
+ norms = normalize_v3(norms)
654
+ face_normals = np.array(face_norm_data) - 1
655
+ return vertices, faces, norms, face_normals
656
+
657
+ return vertices, faces
658
+
659
+
660
+ def normalize_v3(arr):
661
+ ''' Normalize a numpy array of 3 component vectors shape=(n,3) '''
662
+ lens = np.sqrt(arr[:, 0]**2 + arr[:, 1]**2 + arr[:, 2]**2)
663
+ eps = 0.00000001
664
+ lens[lens < eps] = eps
665
+ arr[:, 0] /= lens
666
+ arr[:, 1] /= lens
667
+ arr[:, 2] /= lens
668
+ return arr
669
+
670
+
671
+ def compute_normal(vertices, faces):
672
+ # Create a zeroed array with the same type and shape as our vertices i.e., per vertex normal
673
+ vert_norms = np.zeros(vertices.shape, dtype=vertices.dtype)
674
+ # Create an indexed view into the vertex array using the array of three indices for triangles
675
+ tris = vertices[faces]
676
+ # Calculate the normal for all the triangles, by taking the cross product of the vectors v1-v0, and v2-v0 in each triangle
677
+ face_norms = np.cross(tris[::, 1] - tris[::, 0], tris[::, 2] - tris[::, 0])
678
+ # n is now an array of normals per triangle. The length of each normal is dependent the vertices,
679
+ # we need to normalize these, so that our next step weights each normal equally.
680
+ normalize_v3(face_norms)
681
+ # now we have a normalized array of normals, one per triangle, i.e., per triangle normals.
682
+ # But instead of one per triangle (i.e., flat shading), we add to each vertex in that triangle,
683
+ # the triangles' normal. Multiple triangles would then contribute to every vertex, so we need to normalize again afterwards.
684
+ # The cool part, we can actually add the normals through an indexed view of our (zeroed) per vertex normal array
685
+ vert_norms[faces[:, 0]] += face_norms
686
+ vert_norms[faces[:, 1]] += face_norms
687
+ vert_norms[faces[:, 2]] += face_norms
688
+ normalize_v3(vert_norms)
689
+
690
+ return vert_norms, face_norms
691
+
692
+
693
+ def save_obj_mesh(mesh_path, verts, faces):
694
+ file = open(mesh_path, 'w')
695
+ for v in verts:
696
+ file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
697
+ for f in faces:
698
+ f_plus = f + 1
699
+ file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2]))
700
+ file.close()
701
+
702
+
703
+ def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
704
+ file = open(mesh_path, 'w')
705
+
706
+ for idx, v in enumerate(verts):
707
+ c = colors[idx]
708
+ file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' %
709
+ (v[0], v[1], v[2], c[0], c[1], c[2]))
710
+ for f in faces:
711
+ f_plus = f + 1
712
+ file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2]))
713
+ file.close()
714
+
715
+
716
+ def calculate_mIoU(outputs, labels):
717
+
718
+ SMOOTH = 1e-6
719
+
720
+ outputs = outputs.int()
721
+ labels = labels.int()
722
+
723
+ intersection = (
724
+ outputs
725
+ & labels).float().sum() # Will be zero if Truth=0 or Prediction=0
726
+ union = (outputs | labels).float().sum() # Will be zzero if both are 0
727
+
728
+ iou = (intersection + SMOOTH) / (union + SMOOTH
729
+ ) # We smooth our devision to avoid 0/0
730
+
731
+ thresholded = torch.clamp(
732
+ 20 * (iou - 0.5), 0,
733
+ 10).ceil() / 10 # This is equal to comparing with thresolds
734
+
735
+ return thresholded.mean().detach().cpu().numpy(
736
+ ) # Or thresholded.mean() if you are interested in average across the batch
737
+
738
+
739
+ def mask_filter(mask, number=1000):
740
+ """only keep {number} True items within a mask
741
+
742
+ Args:
743
+ mask (bool array): [N, ]
744
+ number (int, optional): total True item. Defaults to 1000.
745
+ """
746
+ true_ids = np.where(mask)[0]
747
+ keep_ids = np.random.choice(true_ids, size=number)
748
+ filter_mask = np.isin(np.arange(len(mask)), keep_ids)
749
+
750
+ return filter_mask
751
+
752
+
753
+ def query_mesh(path):
754
+
755
+ verts, faces_idx, _ = load_obj(path)
756
+
757
+ return verts, faces_idx.verts_idx
758
+
759
+
760
+ def add_alpha(colors, alpha=0.7):
761
+
762
+ colors_pad = np.pad(colors, ((0, 0), (0, 1)),
763
+ mode='constant',
764
+ constant_values=alpha)
765
+
766
+ return colors_pad
767
+
768
+
769
+ def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type='smpl'):
770
+
771
+ font_path = os.path.join(os.path.dirname(__file__), "tbfo.ttf")
772
+ font = ImageFont.truetype(font_path, 30)
773
+ grid_img = torchvision.utils.make_grid(torch.cat(per_loop_lst, dim=0),
774
+ nrow=nrow)
775
+ grid_img = Image.fromarray(
776
+ ((grid_img.permute(1, 2, 0).detach().cpu().numpy() + 1.0) * 0.5 *
777
+ 255.0).astype(np.uint8))
778
+
779
+ # add text
780
+ draw = ImageDraw.Draw(grid_img)
781
+ grid_size = 512
782
+ if loss is not None:
783
+ draw.text((10, 5), f"error: {loss:.3f}", (255, 0, 0), font=font)
784
+
785
+ if type == 'smpl':
786
+ for col_id, col_txt in enumerate(
787
+ ['image', 'smpl-norm(render)', 'cloth-norm(pred)', 'diff-norm', 'diff-mask']):
788
+ draw.text((10+(col_id*grid_size), 5),
789
+ col_txt, (255, 0, 0), font=font)
790
+ elif type == 'cloth':
791
+ for col_id, col_txt in enumerate(
792
+ ['image', 'cloth-norm(recon)', 'cloth-norm(pred)', 'diff-norm']):
793
+ draw.text((10+(col_id*grid_size), 5),
794
+ col_txt, (255, 0, 0), font=font)
795
+ for col_id, col_txt in enumerate(
796
+ ['0', '90', '180', '270']):
797
+ draw.text((10+(col_id*grid_size), grid_size*2+5),
798
+ col_txt, (255, 0, 0), font=font)
799
+ else:
800
+ print(f"{type} should be 'smpl' or 'cloth'")
801
+
802
+ grid_img = grid_img.resize((grid_img.size[0], grid_img.size[1]),
803
+ Image.ANTIALIAS)
804
+
805
+ return grid_img
806
+
807
+
808
+ def clean_mesh(verts, faces):
809
+
810
+ device = verts.device
811
+
812
+ mesh_lst = trimesh.Trimesh(verts.detach().cpu().numpy(),
813
+ faces.detach().cpu().numpy())
814
+ mesh_lst = mesh_lst.split(only_watertight=False)
815
+ comp_num = [mesh.vertices.shape[0] for mesh in mesh_lst]
816
+ mesh_clean = mesh_lst[comp_num.index(max(comp_num))]
817
+
818
+ final_verts = torch.as_tensor(mesh_clean.vertices).float().to(device)
819
+ final_faces = torch.as_tensor(mesh_clean.faces).int().to(device)
820
+
821
+ return final_verts, final_faces
822
+
823
+
824
+ def merge_mesh(verts_A, faces_A, verts_B, faces_B, color=False):
825
+
826
+ sep_mesh = trimesh.Trimesh(np.concatenate([verts_A, verts_B], axis=0),
827
+ np.concatenate(
828
+ [faces_A, faces_B + faces_A.max() + 1],
829
+ axis=0),
830
+ maintain_order=True,
831
+ process=False)
832
+ if color:
833
+ colors = np.ones_like(sep_mesh.vertices)
834
+ colors[:verts_A.shape[0]] *= np.array([255.0, 0.0, 0.0])
835
+ colors[verts_A.shape[0]:] *= np.array([0.0, 255.0, 0.0])
836
+ sep_mesh.visual.vertex_colors = colors
837
+
838
+ # union_mesh = trimesh.boolean.union([trimesh.Trimesh(verts_A, faces_A),
839
+ # trimesh.Trimesh(verts_B, faces_B)], engine='blender')
840
+
841
+ return sep_mesh
842
+
843
+
844
+ def mesh_move(mesh_lst, step, scale=1.0):
845
+
846
+ trans = np.array([1.0, 0.0, 0.0]) * step
847
+
848
+ resize_matrix = trimesh.transformations.scale_and_translate(
849
+ scale=(scale), translate=trans)
850
+
851
+ results = []
852
+
853
+ for mesh in mesh_lst:
854
+ mesh.apply_transform(resize_matrix)
855
+ results.append(mesh)
856
+
857
+ return results
858
+
859
+
860
+ class SMPLX():
861
+ def __init__(self):
862
+
863
+ REPO_ID = "Yuliang/SMPL"
864
+
865
+ self.smpl_verts_path = hf_hub_download(REPO_ID, filename='smpl_data/smpl_verts.npy', use_auth_token=os.environ['ICON'])
866
+ self.smplx_verts_path = hf_hub_download(REPO_ID, filename='smpl_data/smplx_verts.npy', use_auth_token=os.environ['ICON'])
867
+ self.faces_path = hf_hub_download(REPO_ID, filename='smpl_data/smplx_faces.npy', use_auth_token=os.environ['ICON'])
868
+ self.cmap_vert_path = hf_hub_download(REPO_ID, filename='smpl_data/smplx_cmap.npy', use_auth_token=os.environ['ICON'])
869
+
870
+ self.faces = np.load(self.faces_path)
871
+ self.verts = np.load(self.smplx_verts_path)
872
+ self.smpl_verts = np.load(self.smpl_verts_path)
873
+
874
+ self.model_dir = hf_hub_url(REPO_ID, filename='models')
875
+ self.tedra_dir = hf_hub_url(REPO_ID, filename='tedra_data')
876
+
877
+ def get_smpl_mat(self, vert_ids):
878
+
879
+ mat = torch.as_tensor(np.load(self.cmap_vert_path)).float()
880
+ return mat[vert_ids, :]
881
+
882
+ def smpl2smplx(self, vert_ids=None):
883
+ """convert vert_ids in smpl to vert_ids in smplx
884
+
885
+ Args:
886
+ vert_ids ([int.array]): [n, knn_num]
887
+ """
888
+ smplx_tree = cKDTree(self.verts, leafsize=1)
889
+ _, ind = smplx_tree.query(self.smpl_verts, k=1) # ind: [smpl_num, 1]
890
+
891
+ if vert_ids is not None:
892
+ smplx_vert_ids = ind[vert_ids]
893
+ else:
894
+ smplx_vert_ids = ind
895
+
896
+ return smplx_vert_ids
897
+
898
+ def smplx2smpl(self, vert_ids=None):
899
+ """convert vert_ids in smplx to vert_ids in smpl
900
+
901
+ Args:
902
+ vert_ids ([int.array]): [n, knn_num]
903
+ """
904
+ smpl_tree = cKDTree(self.smpl_verts, leafsize=1)
905
+ _, ind = smpl_tree.query(self.verts, k=1) # ind: [smplx_num, 1]
906
+ if vert_ids is not None:
907
+ smpl_vert_ids = ind[vert_ids]
908
+ else:
909
+ smpl_vert_ids = ind
910
+
911
+ return smpl_vert_ids