alexnasa commited on
Commit
de56840
·
verified ·
1 Parent(s): c22a956

Upload run_facer_segmentation.py

Browse files
src/pixel3dmm/run_facer_segmentation.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import traceback
4
+
5
+ from math import ceil
6
+
7
+ import PIL.Image
8
+ import torch
9
+ import distinctipy
10
+ import matplotlib.pyplot as plt
11
+ from PIL import Image
12
+ import numpy as np
13
+ import tyro
14
+ import facer
15
+
16
+ from pixel3dmm import env_paths
17
+
18
+ colors = distinctipy.get_colors(22, rng=0)
19
+
20
+
21
+ def viz_results(img, seq_classes, n_classes, suppress_plot = False):
22
+
23
+ seg_img = np.zeros([img.shape[-2], img.shape[-1], 3])
24
+ #distinctipy.color_swatch(colors)
25
+ bad_indices = [
26
+ 0, # background,
27
+ 1, # neck
28
+ # 2, skin
29
+ 3, # cloth
30
+ 4, # ear_r (images-space r)
31
+ 5, # ear_l
32
+ # 6 brow_r
33
+ # 7 brow_l
34
+ # 8, # eye_r
35
+ # 9, # eye_l
36
+ # 10 noise
37
+ # 11 mouth
38
+ # 12 lower_lip
39
+ # 13 upper_lip
40
+ 14, # hair,
41
+ # 15, glasses
42
+ 16, # ??
43
+ 17, # earring_r
44
+ 18, # ?
45
+ ]
46
+ bad_indices = []
47
+
48
+ for i in range(n_classes):
49
+ if i not in bad_indices:
50
+ seg_img[seq_classes[0, :, :] == i] = np.array(colors[i])*255
51
+
52
+ if not suppress_plot:
53
+ plt.imshow(seg_img.astype(np.uint(8)))
54
+ plt.show()
55
+ return Image.fromarray(seg_img.astype(np.uint8))
56
+
57
+ def get_color_seg(img, seq_classes, n_classes):
58
+
59
+ seg_img = np.zeros([img.shape[-2], img.shape[-1], 3])
60
+ colors = distinctipy.get_colors(n_classes+1, rng=0)
61
+ #distinctipy.color_swatch(colors)
62
+ bad_indices = [
63
+ 0, # background,
64
+ 1, # neck
65
+ # 2, skin
66
+ 3, # cloth
67
+ 4, # ear_r (images-space r)
68
+ 5, # ear_l
69
+ # 6 brow_r
70
+ # 7 brow_l
71
+ # 8, # eye_r
72
+ # 9, # eye_l
73
+ # 10 noise
74
+ # 11 mouth
75
+ # 12 lower_lip
76
+ # 13 upper_lip
77
+ 14, # hair,
78
+ # 15, glasses
79
+ 16, # ??
80
+ 17, # earring_r
81
+ 18, # ?
82
+ ]
83
+
84
+ for i in range(n_classes):
85
+ if i not in bad_indices:
86
+ seg_img[seq_classes[0, :, :] == i] = np.array(colors[i])*255
87
+
88
+
89
+ return Image.fromarray(seg_img.astype(np.uint8))
90
+
91
+
92
+ def crop_gt_img(img, seq_classes, n_classes):
93
+
94
+ seg_img = np.zeros([img.shape[-2], img.shape[-1], 3])
95
+ colors = distinctipy.get_colors(n_classes+1, rng=0)
96
+ #distinctipy.color_swatch(colors)
97
+ bad_indices = [
98
+ 0, # background,
99
+ 1, # neck
100
+ # 2, skin
101
+ 3, # cloth
102
+ 4, #ear_r (images-space r)
103
+ 5, #ear_l
104
+ # 6 brow_r
105
+ # 7 brow_l
106
+ #8, # eye_r
107
+ #9, # eye_l
108
+ # 10 noise
109
+ # 11 mouth
110
+ # 12 lower_lip
111
+ # 13 upper_lip
112
+ 14, # hair,
113
+ # 15, glasses
114
+ 16, # ??
115
+ 17, # earring_r
116
+ 18, # ?
117
+ ]
118
+
119
+ for i in range(n_classes):
120
+ if i in bad_indices:
121
+ img[seq_classes[0, :, :] == i] = 0
122
+
123
+
124
+ #plt.imshow(img.astype(np.uint(8)))
125
+ #plt.show()
126
+ return img.astype(np.uint8)
127
+
128
+
129
+ def segment(video_name : str, face_detector, face_parser):
130
+
131
+
132
+ out = f'{env_paths.PREPROCESSED_DATA}/{video_name}'
133
+ out_seg = f'{out}/seg_og/'
134
+ out_seg_annot = f'{out}/seg_non_crop_annotations/'
135
+ os.makedirs(out_seg, exist_ok=True)
136
+ os.makedirs(out_seg_annot, exist_ok=True)
137
+ folder = f'{out}/cropped/' # '/home/giebenhain/GTA/data_kinect/color/'
138
+
139
+
140
+
141
+
142
+
143
+ frames = [f for f in os.listdir(folder) if f.endswith('.png') or f.endswith('.jpg')]
144
+
145
+ frames.sort()
146
+
147
+ if len(os.listdir(out_seg)) == len(frames):
148
+ print(f'''
149
+ <<<<<<<< ALREADY COMPLETED SEGMENTATION FOR {video_name}, SKIPPING >>>>>>>>
150
+ ''')
151
+ return
152
+
153
+ #for file in frames:
154
+ batch_size = 1
155
+
156
+ for i in range(len(frames)//batch_size):
157
+ image_stack = []
158
+ frame_stack = []
159
+ original_shapes = []
160
+ for j in range(batch_size):
161
+ file = frames[i * batch_size + j]
162
+
163
+ if os.path.exists(f'{out_seg_annot}/color_{file}.png'):
164
+ print('DONE')
165
+ continue
166
+ img = Image.open(f'{folder}/{file}')#.resize((512, 512))
167
+
168
+ og_size = img.size
169
+
170
+ image = facer.hwc2bchw(torch.from_numpy(np.array(img)[..., :3])).to(device="cuda") # image: 1 x 3 x h x w
171
+ image_stack.append(image)
172
+ frame_stack.append(file[:-4])
173
+
174
+ for batch_idx in range(ceil(len(image_stack)/batch_size)):
175
+ image_batch = torch.cat(image_stack[batch_idx*batch_size:(batch_idx+1)*batch_size], dim=0)
176
+ frame_idx_batch = frame_stack[batch_idx*batch_size:(batch_idx+1)*batch_size]
177
+ og_shape_batch = original_shapes[batch_idx*batch_size:(batch_idx+1)*batch_size]
178
+
179
+ #if True:
180
+ try:
181
+ with torch.inference_mode():
182
+ faces = face_detector(image_batch)
183
+ torch.cuda.empty_cache()
184
+ faces = face_parser(image_batch, faces, bbox_scale_factor=1.25)
185
+ torch.cuda.empty_cache()
186
+
187
+ seg_logits = faces['seg']['logits']
188
+ back_ground = torch.all(seg_logits == 0, dim=1, keepdim=True).detach().squeeze(1).cpu().numpy()
189
+ seg_probs = seg_logits.softmax(dim=1) # nfaces x nclasses x h x w
190
+ seg_classes = seg_probs.argmax(dim=1).detach().cpu().numpy().astype(np.uint8)
191
+ seg_classes[back_ground] = seg_probs.shape[1] + 1
192
+
193
+
194
+ for _iidx in range(seg_probs.shape[0]):
195
+ frame = frame_idx_batch[_iidx]
196
+ iidx = faces['image_ids'][_iidx].item()
197
+ try:
198
+ I_color = viz_results(image_batch[iidx:iidx+1], seq_classes=seg_classes[_iidx:_iidx+1], n_classes=seg_probs.shape[1] + 1, suppress_plot=True)
199
+ I_color.save(f'{out_seg_annot}/color_{frame}.png')
200
+ except Exception as ex:
201
+ pass
202
+ I = Image.fromarray(seg_classes[_iidx])
203
+ I.save(f'{out_seg}/{frame}.png')
204
+ torch.cuda.empty_cache()
205
+ except Exception as exx:
206
+ traceback.print_exc()
207
+ continue
208
+