YoonaAI commited on
Commit
7853fb6
·
1 Parent(s): b1f0cf3

Upload cloth_extraction.py

Browse files
Files changed (1) hide show
  1. lib/common/cloth_extraction.py +170 -0
lib/common/cloth_extraction.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import json
3
+ import os
4
+ import itertools
5
+ import trimesh
6
+ from matplotlib.path import Path
7
+ from collections import Counter
8
+ from sklearn.neighbors import KNeighborsClassifier
9
+
10
+
11
+ def load_segmentation(path, shape):
12
+ """
13
+ Get a segmentation mask for a given image
14
+ Arguments:
15
+ path: path to the segmentation json file
16
+ shape: shape of the output mask
17
+ Returns:
18
+ Returns a segmentation mask
19
+ """
20
+ with open(path) as json_file:
21
+ dict = json.load(json_file)
22
+ segmentations = []
23
+ for key, val in dict.items():
24
+ if not key.startswith('item'):
25
+ continue
26
+
27
+ # Each item can have multiple polygons. Combine them to one
28
+ # segmentation_coord = list(itertools.chain.from_iterable(val['segmentation']))
29
+ # segmentation_coord = np.round(np.array(segmentation_coord)).astype(int)
30
+
31
+ coordinates = []
32
+ for segmentation_coord in val['segmentation']:
33
+ # The format before is [x1,y1, x2, y2, ....]
34
+ x = segmentation_coord[::2]
35
+ y = segmentation_coord[1::2]
36
+ xy = np.vstack((x, y)).T
37
+ coordinates.append(xy)
38
+
39
+ segmentations.append(
40
+ {'type': val['category_name'], 'type_id': val['category_id'], 'coordinates': coordinates})
41
+
42
+ return segmentations
43
+
44
+
45
+ def smpl_to_recon_labels(recon, smpl, k=1):
46
+ """
47
+ Get the bodypart labels for the recon object by using the labels from the corresponding smpl object
48
+ Arguments:
49
+ recon: trimesh object (fully clothed model)
50
+ shape: trimesh object (smpl model)
51
+ k: number of nearest neighbours to use
52
+ Returns:
53
+ Returns a dictionary containing the bodypart and the corresponding indices
54
+ """
55
+ smpl_vert_segmentation = json.load(
56
+ open(os.path.join(os.path.dirname(__file__), 'smpl_vert_segmentation.json')))
57
+ n = smpl.vertices.shape[0]
58
+ y = np.array([None] * n)
59
+ for key, val in smpl_vert_segmentation.items():
60
+ y[val] = key
61
+
62
+ classifier = KNeighborsClassifier(n_neighbors=1)
63
+ classifier.fit(smpl.vertices, y)
64
+
65
+ y_pred = classifier.predict(recon.vertices)
66
+
67
+ recon_labels = {}
68
+ for key in smpl_vert_segmentation.keys():
69
+ recon_labels[key] = list(np.argwhere(
70
+ y_pred == key).flatten().astype(int))
71
+
72
+ return recon_labels
73
+
74
+
75
+ def extract_cloth(recon, segmentation, K, R, t, smpl=None):
76
+ """
77
+ Extract a portion of a mesh using 2d segmentation coordinates
78
+ Arguments:
79
+ recon: fully clothed mesh
80
+ seg_coord: segmentation coordinates in 2D (NDC)
81
+ K: intrinsic matrix of the projection
82
+ R: rotation matrix of the projection
83
+ t: translation vector of the projection
84
+ Returns:
85
+ Returns a submesh using the segmentation coordinates
86
+ """
87
+ seg_coord = segmentation['coord_normalized']
88
+ mesh = trimesh.Trimesh(recon.vertices, recon.faces)
89
+ extrinsic = np.zeros((3, 4))
90
+ extrinsic[:3, :3] = R
91
+ extrinsic[:, 3] = t
92
+ P = K[:3, :3] @ extrinsic
93
+
94
+ P_inv = np.linalg.pinv(P)
95
+
96
+ # Each segmentation can contain multiple polygons
97
+ # We need to check them separately
98
+ points_so_far = []
99
+ faces = recon.faces
100
+ for polygon in seg_coord:
101
+ n = len(polygon)
102
+ coords_h = np.hstack((polygon, np.ones((n, 1))))
103
+ # Apply the inverse projection on homogeneus 2D coordinates to get the corresponding 3d Coordinates
104
+ XYZ = P_inv @ coords_h[:, :, None]
105
+ XYZ = XYZ.reshape((XYZ.shape[0], XYZ.shape[1]))
106
+ XYZ = XYZ[:, :3] / XYZ[:, 3, None]
107
+
108
+ p = Path(XYZ[:, :2])
109
+
110
+ grid = p.contains_points(recon.vertices[:, :2])
111
+ indeces = np.argwhere(grid == True)
112
+ points_so_far += list(indeces.flatten())
113
+
114
+ if smpl is not None:
115
+ num_verts = recon.vertices.shape[0]
116
+ recon_labels = smpl_to_recon_labels(recon, smpl)
117
+ body_parts_to_remove = ['rightHand', 'leftToeBase', 'leftFoot', 'rightFoot', 'head',
118
+ 'leftHandIndex1', 'rightHandIndex1', 'rightToeBase', 'leftHand', 'rightHand']
119
+ type = segmentation['type_id']
120
+
121
+ # Remove additional bodyparts that are most likely not part of the segmentation but might intersect (e.g. hand in front of torso)
122
+ # https://github.com/switchablenorms/DeepFashion2
123
+ # Short sleeve clothes
124
+ if type == 1 or type == 3 or type == 10:
125
+ body_parts_to_remove += ['leftForeArm', 'rightForeArm']
126
+ # No sleeves at all or lower body clothes
127
+ elif type == 5 or type == 6 or type == 12 or type == 13 or type == 8 or type == 9:
128
+ body_parts_to_remove += ['leftForeArm',
129
+ 'rightForeArm', 'leftArm', 'rightArm']
130
+ # Shorts
131
+ elif type == 7:
132
+ body_parts_to_remove += ['leftLeg', 'rightLeg',
133
+ 'leftForeArm', 'rightForeArm', 'leftArm', 'rightArm']
134
+
135
+ verts_to_remove = list(itertools.chain.from_iterable(
136
+ [recon_labels[part] for part in body_parts_to_remove]))
137
+
138
+ label_mask = np.zeros(num_verts, dtype=bool)
139
+ label_mask[verts_to_remove] = True
140
+
141
+ seg_mask = np.zeros(num_verts, dtype=bool)
142
+ seg_mask[points_so_far] = True
143
+
144
+ # Remove points that belong to other bodyparts
145
+ # If a vertice in pointsSoFar is included in the bodyparts to remove, then these points should be removed
146
+ extra_verts_to_remove = np.array(list(seg_mask) and list(label_mask))
147
+
148
+ combine_mask = np.zeros(num_verts, dtype=bool)
149
+ combine_mask[points_so_far] = True
150
+ combine_mask[extra_verts_to_remove] = False
151
+
152
+ all_indices = np.argwhere(combine_mask == True).flatten()
153
+
154
+ i_x = np.where(np.in1d(faces[:, 0], all_indices))[0]
155
+ i_y = np.where(np.in1d(faces[:, 1], all_indices))[0]
156
+ i_z = np.where(np.in1d(faces[:, 2], all_indices))[0]
157
+
158
+ faces_to_keep = np.array(list(set(i_x).union(i_y).union(i_z)))
159
+ mask = np.zeros(len(recon.faces), dtype=bool)
160
+ if len(faces_to_keep) > 0:
161
+ mask[faces_to_keep] = True
162
+
163
+ mesh.update_faces(mask)
164
+ mesh.remove_unreferenced_vertices()
165
+
166
+ # mesh.rezero()
167
+
168
+ return mesh
169
+
170
+ return None