Spaces:
Build error
Build error
Upload cloth_extraction.py
Browse files- lib/common/cloth_extraction.py +170 -0
lib/common/cloth_extraction.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import itertools
|
5 |
+
import trimesh
|
6 |
+
from matplotlib.path import Path
|
7 |
+
from collections import Counter
|
8 |
+
from sklearn.neighbors import KNeighborsClassifier
|
9 |
+
|
10 |
+
|
11 |
+
def load_segmentation(path, shape):
|
12 |
+
"""
|
13 |
+
Get a segmentation mask for a given image
|
14 |
+
Arguments:
|
15 |
+
path: path to the segmentation json file
|
16 |
+
shape: shape of the output mask
|
17 |
+
Returns:
|
18 |
+
Returns a segmentation mask
|
19 |
+
"""
|
20 |
+
with open(path) as json_file:
|
21 |
+
dict = json.load(json_file)
|
22 |
+
segmentations = []
|
23 |
+
for key, val in dict.items():
|
24 |
+
if not key.startswith('item'):
|
25 |
+
continue
|
26 |
+
|
27 |
+
# Each item can have multiple polygons. Combine them to one
|
28 |
+
# segmentation_coord = list(itertools.chain.from_iterable(val['segmentation']))
|
29 |
+
# segmentation_coord = np.round(np.array(segmentation_coord)).astype(int)
|
30 |
+
|
31 |
+
coordinates = []
|
32 |
+
for segmentation_coord in val['segmentation']:
|
33 |
+
# The format before is [x1,y1, x2, y2, ....]
|
34 |
+
x = segmentation_coord[::2]
|
35 |
+
y = segmentation_coord[1::2]
|
36 |
+
xy = np.vstack((x, y)).T
|
37 |
+
coordinates.append(xy)
|
38 |
+
|
39 |
+
segmentations.append(
|
40 |
+
{'type': val['category_name'], 'type_id': val['category_id'], 'coordinates': coordinates})
|
41 |
+
|
42 |
+
return segmentations
|
43 |
+
|
44 |
+
|
45 |
+
def smpl_to_recon_labels(recon, smpl, k=1):
|
46 |
+
"""
|
47 |
+
Get the bodypart labels for the recon object by using the labels from the corresponding smpl object
|
48 |
+
Arguments:
|
49 |
+
recon: trimesh object (fully clothed model)
|
50 |
+
shape: trimesh object (smpl model)
|
51 |
+
k: number of nearest neighbours to use
|
52 |
+
Returns:
|
53 |
+
Returns a dictionary containing the bodypart and the corresponding indices
|
54 |
+
"""
|
55 |
+
smpl_vert_segmentation = json.load(
|
56 |
+
open(os.path.join(os.path.dirname(__file__), 'smpl_vert_segmentation.json')))
|
57 |
+
n = smpl.vertices.shape[0]
|
58 |
+
y = np.array([None] * n)
|
59 |
+
for key, val in smpl_vert_segmentation.items():
|
60 |
+
y[val] = key
|
61 |
+
|
62 |
+
classifier = KNeighborsClassifier(n_neighbors=1)
|
63 |
+
classifier.fit(smpl.vertices, y)
|
64 |
+
|
65 |
+
y_pred = classifier.predict(recon.vertices)
|
66 |
+
|
67 |
+
recon_labels = {}
|
68 |
+
for key in smpl_vert_segmentation.keys():
|
69 |
+
recon_labels[key] = list(np.argwhere(
|
70 |
+
y_pred == key).flatten().astype(int))
|
71 |
+
|
72 |
+
return recon_labels
|
73 |
+
|
74 |
+
|
75 |
+
def extract_cloth(recon, segmentation, K, R, t, smpl=None):
|
76 |
+
"""
|
77 |
+
Extract a portion of a mesh using 2d segmentation coordinates
|
78 |
+
Arguments:
|
79 |
+
recon: fully clothed mesh
|
80 |
+
seg_coord: segmentation coordinates in 2D (NDC)
|
81 |
+
K: intrinsic matrix of the projection
|
82 |
+
R: rotation matrix of the projection
|
83 |
+
t: translation vector of the projection
|
84 |
+
Returns:
|
85 |
+
Returns a submesh using the segmentation coordinates
|
86 |
+
"""
|
87 |
+
seg_coord = segmentation['coord_normalized']
|
88 |
+
mesh = trimesh.Trimesh(recon.vertices, recon.faces)
|
89 |
+
extrinsic = np.zeros((3, 4))
|
90 |
+
extrinsic[:3, :3] = R
|
91 |
+
extrinsic[:, 3] = t
|
92 |
+
P = K[:3, :3] @ extrinsic
|
93 |
+
|
94 |
+
P_inv = np.linalg.pinv(P)
|
95 |
+
|
96 |
+
# Each segmentation can contain multiple polygons
|
97 |
+
# We need to check them separately
|
98 |
+
points_so_far = []
|
99 |
+
faces = recon.faces
|
100 |
+
for polygon in seg_coord:
|
101 |
+
n = len(polygon)
|
102 |
+
coords_h = np.hstack((polygon, np.ones((n, 1))))
|
103 |
+
# Apply the inverse projection on homogeneus 2D coordinates to get the corresponding 3d Coordinates
|
104 |
+
XYZ = P_inv @ coords_h[:, :, None]
|
105 |
+
XYZ = XYZ.reshape((XYZ.shape[0], XYZ.shape[1]))
|
106 |
+
XYZ = XYZ[:, :3] / XYZ[:, 3, None]
|
107 |
+
|
108 |
+
p = Path(XYZ[:, :2])
|
109 |
+
|
110 |
+
grid = p.contains_points(recon.vertices[:, :2])
|
111 |
+
indeces = np.argwhere(grid == True)
|
112 |
+
points_so_far += list(indeces.flatten())
|
113 |
+
|
114 |
+
if smpl is not None:
|
115 |
+
num_verts = recon.vertices.shape[0]
|
116 |
+
recon_labels = smpl_to_recon_labels(recon, smpl)
|
117 |
+
body_parts_to_remove = ['rightHand', 'leftToeBase', 'leftFoot', 'rightFoot', 'head',
|
118 |
+
'leftHandIndex1', 'rightHandIndex1', 'rightToeBase', 'leftHand', 'rightHand']
|
119 |
+
type = segmentation['type_id']
|
120 |
+
|
121 |
+
# Remove additional bodyparts that are most likely not part of the segmentation but might intersect (e.g. hand in front of torso)
|
122 |
+
# https://github.com/switchablenorms/DeepFashion2
|
123 |
+
# Short sleeve clothes
|
124 |
+
if type == 1 or type == 3 or type == 10:
|
125 |
+
body_parts_to_remove += ['leftForeArm', 'rightForeArm']
|
126 |
+
# No sleeves at all or lower body clothes
|
127 |
+
elif type == 5 or type == 6 or type == 12 or type == 13 or type == 8 or type == 9:
|
128 |
+
body_parts_to_remove += ['leftForeArm',
|
129 |
+
'rightForeArm', 'leftArm', 'rightArm']
|
130 |
+
# Shorts
|
131 |
+
elif type == 7:
|
132 |
+
body_parts_to_remove += ['leftLeg', 'rightLeg',
|
133 |
+
'leftForeArm', 'rightForeArm', 'leftArm', 'rightArm']
|
134 |
+
|
135 |
+
verts_to_remove = list(itertools.chain.from_iterable(
|
136 |
+
[recon_labels[part] for part in body_parts_to_remove]))
|
137 |
+
|
138 |
+
label_mask = np.zeros(num_verts, dtype=bool)
|
139 |
+
label_mask[verts_to_remove] = True
|
140 |
+
|
141 |
+
seg_mask = np.zeros(num_verts, dtype=bool)
|
142 |
+
seg_mask[points_so_far] = True
|
143 |
+
|
144 |
+
# Remove points that belong to other bodyparts
|
145 |
+
# If a vertice in pointsSoFar is included in the bodyparts to remove, then these points should be removed
|
146 |
+
extra_verts_to_remove = np.array(list(seg_mask) and list(label_mask))
|
147 |
+
|
148 |
+
combine_mask = np.zeros(num_verts, dtype=bool)
|
149 |
+
combine_mask[points_so_far] = True
|
150 |
+
combine_mask[extra_verts_to_remove] = False
|
151 |
+
|
152 |
+
all_indices = np.argwhere(combine_mask == True).flatten()
|
153 |
+
|
154 |
+
i_x = np.where(np.in1d(faces[:, 0], all_indices))[0]
|
155 |
+
i_y = np.where(np.in1d(faces[:, 1], all_indices))[0]
|
156 |
+
i_z = np.where(np.in1d(faces[:, 2], all_indices))[0]
|
157 |
+
|
158 |
+
faces_to_keep = np.array(list(set(i_x).union(i_y).union(i_z)))
|
159 |
+
mask = np.zeros(len(recon.faces), dtype=bool)
|
160 |
+
if len(faces_to_keep) > 0:
|
161 |
+
mask[faces_to_keep] = True
|
162 |
+
|
163 |
+
mesh.update_faces(mask)
|
164 |
+
mesh.remove_unreferenced_vertices()
|
165 |
+
|
166 |
+
# mesh.rezero()
|
167 |
+
|
168 |
+
return mesh
|
169 |
+
|
170 |
+
return None
|