YoonaAI committed on
Commit
b72eefa
·
1 Parent(s): 456161a

Create pymaf/utils/imutils.py

Files changed (1)
  1. lib/pymaf/utils/imutils.py +491 -0
lib/pymaf/utils/imutils.py ADDED
@@ -0,0 +1,491 @@
"""
This file contains functions that are used to perform data augmentation.
"""
import cv2
import io
import torch
import numpy as np
from PIL import Image
from rembg import remove
from rembg.session_factory import new_session
from torchvision.models import detection

from lib.pymaf.core import constants
from lib.pymaf.utils.streamer import aug_matrix
from lib.common.cloth_extraction import load_segmentation
from torchvision import transforms


def load_img(img_file):

    img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED)
    if len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    if not img_file.endswith("png"):
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)

    return img


def get_bbox(img, det):

    input = np.float32(img)
    input = (input / 255.0 -
             (0.5, 0.5, 0.5)) / (0.5, 0.5, 0.5)  # TO [-1.0, 1.0]
    input = input.transpose(2, 0, 1)  # TO [3 x H x W]
    bboxes, probs = det(torch.from_numpy(input).float().unsqueeze(0))

    probs = probs.unsqueeze(3)
    bboxes = (bboxes * probs).sum(dim=1, keepdim=True) / probs.sum(
        dim=1, keepdim=True)
    bbox = bboxes[0, 0, 0].cpu().numpy()

    return bbox


def get_transformer(input_res):

    image_to_tensor = transforms.Compose([
        transforms.Resize(input_res),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    mask_to_tensor = transforms.Compose([
        transforms.Resize(input_res),
        transforms.ToTensor(),
        transforms.Normalize((0.0, ), (1.0, ))
    ])

    image_to_pymaf_tensor = transforms.Compose([
        transforms.Resize(size=224),
        transforms.Normalize(mean=constants.IMG_NORM_MEAN,
                             std=constants.IMG_NORM_STD)
    ])

    image_to_pixie_tensor = transforms.Compose([
        transforms.Resize(224)
    ])

    def image_to_hybrik_tensor(img):
        # mean
        img[0].add_(-0.406)
        img[1].add_(-0.457)
        img[2].add_(-0.480)

        # std
        img[0].div_(0.225)
        img[1].div_(0.224)
        img[2].div_(0.229)
        return img

    return [
        image_to_tensor, mask_to_tensor, image_to_pymaf_tensor,
        image_to_pixie_tensor, image_to_hybrik_tensor
    ]


def process_image(img_file, hps_type, input_res=512, device=None, seg_path=None):
    """Read an image, preprocess it and crop it around the detected person.

    The image is first warped onto a square canvas, a person bounding box is
    obtained with Mask R-CNN, and the crop is normalized for the selected HPS
    backbone. If a cloth-segmentation file is given via `seg_path`, its polygon
    coordinates are warped and cropped along with the image.
    """

    [image_to_tensor, mask_to_tensor, image_to_pymaf_tensor,
     image_to_pixie_tensor, image_to_hybrik_tensor] = get_transformer(input_res)

    img_ori = load_img(img_file)

    in_height, in_width, _ = img_ori.shape
    M = aug_matrix(in_width, in_height, input_res * 2, input_res * 2)

    # from rectangle to square
    img_for_crop = cv2.warpAffine(img_ori, M[0:2, :],
                                  (input_res * 2, input_res * 2), flags=cv2.INTER_CUBIC)

    # detection for bbox
    detector = detection.maskrcnn_resnet50_fpn(pretrained=True)
    detector.eval()
    predictions = detector(
        [torch.from_numpy(img_for_crop).permute(2, 0, 1) / 255.])[0]
    human_ids = torch.where(
        predictions["scores"] == predictions["scores"][predictions['labels'] == 1].max())
    bbox = predictions["boxes"][human_ids, :].flatten().detach().cpu().numpy()

    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    center = np.array([(bbox[0] + bbox[2]) / 2.0,
                       (bbox[1] + bbox[3]) / 2.0])

    scale = max(height, width) / 180

    if hps_type == 'hybrik':
        img_np = crop_for_hybrik(img_for_crop, center,
                                 np.array([scale * 180, scale * 180]))
    else:
        img_np, cropping_parameters = crop(
            img_for_crop, center, scale, (input_res, input_res))

    img_pil = Image.fromarray(
        remove(img_np, post_process_mask=True, session=new_session("u2net")))

    # for icon
    img_rgb = image_to_tensor(img_pil.convert("RGB"))
    img_mask = torch.tensor(1.0) - (mask_to_tensor(img_pil.split()[-1]) <
                                    torch.tensor(0.5)).float()
    img_tensor = img_rgb * img_mask

    # for hps
    img_hps = img_np.astype(np.float32) / 255.
    img_hps = torch.from_numpy(img_hps).permute(2, 0, 1)

    if hps_type == 'bev':
        img_hps = img_np[:, :, [2, 1, 0]]
    elif hps_type == 'hybrik':
        img_hps = image_to_hybrik_tensor(img_hps).unsqueeze(0).to(device)
    elif hps_type != 'pixie':
        img_hps = image_to_pymaf_tensor(img_hps).unsqueeze(0).to(device)
    else:
        img_hps = image_to_pixie_tensor(img_hps).unsqueeze(0).to(device)

    # uncrop params
    uncrop_param = {'center': center,
                    'scale': scale,
                    'ori_shape': img_ori.shape,
                    'box_shape': img_np.shape,
                    'crop_shape': img_for_crop.shape,
                    'M': M}

    if seg_path is not None:
        segmentations = load_segmentation(seg_path, (in_height, in_width))
        seg_coord_normalized = []
        for seg in segmentations:
            coord_normalized = []
            for xy in seg['coordinates']:
                xy_h = np.vstack((xy[:, 0], xy[:, 1], np.ones(len(xy)))).T
                warped_indeces = M[0:2, :] @ xy_h[:, :, None]
                warped_indeces = np.array(warped_indeces).astype(int)
                warped_indeces.resize((warped_indeces.shape[:2]))

                # cropped_indeces = crop_segmentation(warped_indeces, center, scale, (input_res, input_res), img_np.shape)
                cropped_indeces = crop_segmentation(
                    warped_indeces, (input_res, input_res), cropping_parameters)

                indices = np.vstack(
                    (cropped_indeces[:, 0], cropped_indeces[:, 1])).T

                # Convert to NDC coordinates
                seg_cropped_normalized = 2 * (indices / input_res) - 1
                # Don't know why we need to divide by 50 but it works ¯\_(ツ)_/¯ (probably some scaling factor somewhere)
                # Divide only by 40 on the horizontal axis to take the curve of the human body into account
                seg_cropped_normalized[:, 0] = (
                    1 / 40) * seg_cropped_normalized[:, 0]
                seg_cropped_normalized[:, 1] = (
                    1 / 50) * seg_cropped_normalized[:, 1]
                coord_normalized.append(seg_cropped_normalized)

            seg['coord_normalized'] = coord_normalized
            seg_coord_normalized.append(seg)

        return img_tensor, img_hps, img_ori, img_mask, uncrop_param, seg_coord_normalized

    return img_tensor, img_hps, img_ori, img_mask, uncrop_param
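
# Illustrative usage sketch (hedged: the image path, hps_type and device below are
# placeholders, not values shipped with this file):
#
#     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#     img_icon, img_hps, img_ori, img_mask, uncrop_param = process_image(
#         "./examples/demo.png", hps_type="pymaf", input_res=512, device=device)
#
# When a cloth-segmentation file is passed via `seg_path`, a sixth value with the
# normalized segment coordinates is returned as well.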


def get_transform(center, scale, res):
    """Generate transformation matrix."""
    h = 200 * scale
    t = np.zeros((3, 3))
    t[0, 0] = float(res[1]) / h
    t[1, 1] = float(res[0]) / h
    t[0, 2] = res[1] * (-float(center[0]) / h + .5)
    t[1, 2] = res[0] * (-float(center[1]) / h + .5)
    t[2, 2] = 1

    return t


def transform(pt, center, scale, res, invert=0):
    """Transform pixel location to different reference."""
    t = get_transform(center, scale, res)
    if invert:
        t = np.linalg.inv(t)
    new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
    new_pt = np.dot(t, new_pt)
    return np.around(new_pt[:2]).astype(np.int16)
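
# Quick sanity check (illustrative numbers, not part of the pipeline): the bbox
# centre should land near the middle of the crop, and `invert=1` maps crop
# coordinates back into the original image, as used by `crop()` below.
#
#     center, scale, res = np.array([320., 240.]), 1.2, (224, 224)
#     transform(center, center, scale, res)            # -> approx. [112, 112]
#     transform([0, 0], center, scale, res, invert=1)  # upper-left corner in the original image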


def crop(img, center, scale, res):
    """Crop image according to the supplied bounding box."""

    # Upper left point
    ul = np.array(transform([0, 0], center, scale, res, invert=1))

    # Bottom right point
    br = np.array(transform(res, center, scale, res, invert=1))

    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(new_shape)

    # Range to fill new array
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]

    # Range to sample from original image
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])

    new_img[new_y[0]:new_y[1],
            new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
    if len(img.shape) == 2:
        new_img = np.array(Image.fromarray(new_img).resize(res))
    else:
        new_img = np.array(Image.fromarray(
            new_img.astype(np.uint8)).resize(res))

    return new_img, (old_x, new_x, old_y, new_y, new_shape)


def crop_segmentation(org_coord, res, cropping_parameters):
    old_x, new_x, old_y, new_y, new_shape = cropping_parameters

    new_coord = np.zeros((org_coord.shape))
    new_coord[:, 0] = new_x[0] + (org_coord[:, 0] - old_x[0])
    new_coord[:, 1] = new_y[0] + (org_coord[:, 1] - old_y[0])

    new_coord[:, 0] = res[0] * (new_coord[:, 0] / new_shape[1])
    new_coord[:, 1] = res[1] * (new_coord[:, 1] / new_shape[0])

    return new_coord


def crop_for_hybrik(img, center, scale):
    inp_h, inp_w = (256, 256)
    trans = get_affine_transform(center, scale, 0, [inp_w, inp_h])
    new_img = cv2.warpAffine(
        img, trans, (int(inp_w), int(inp_h)), flags=cv2.INTER_LINEAR)
    return new_img


def get_affine_transform(center,
                         scale,
                         rot,
                         output_size,
                         shift=np.array([0, 0], dtype=np.float32),
                         inv=0):

    def get_dir(src_point, rot_rad):
        """Rotate the point by `rot_rad` radians."""
        sn, cs = np.sin(rot_rad), np.cos(rot_rad)

        src_result = [0, 0]
        src_result[0] = src_point[0] * cs - src_point[1] * sn
        src_result[1] = src_point[0] * sn + src_point[1] * cs

        return src_result

    def get_3rd_point(a, b):
        """Return a point c such that (c - b) is perpendicular to (a - b)."""
        direct = a - b
        return b + np.array([-direct[1], direct[0]], dtype=np.float32)

    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        scale = np.array([scale, scale])

    scale_tmp = scale
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir

    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans
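
# Sanity check (illustrative values only): the estimated affine transform maps the
# bbox centre to the centre of the output patch.
#
#     center = np.array([320., 240.])
#     trans = get_affine_transform(center, np.array([200., 200.]), 0, [256, 256])
#     np.dot(trans, [center[0], center[1], 1.0])   # -> approx. [128., 128.]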


def corner_align(ul, br):

    if ul[1] - ul[0] != br[1] - br[0]:
        ul[1] = ul[0] + br[1] - br[0]

    return ul, br


def uncrop(img, center, scale, orig_shape):
    """'Undo' the image cropping/resizing.
    This function is used when evaluating mask/part segmentation.
    """

    res = img.shape[:2]

    # Upper left point
    ul = np.array(transform([0, 0], center, scale, res, invert=1))
    # Bottom right point
    br = np.array(transform(res, center, scale, res, invert=1))

    # quick fix
    ul, br = corner_align(ul, br)

    # size of cropped image
    crop_shape = [br[1] - ul[1], br[0] - ul[0]]
    new_img = np.zeros(orig_shape, dtype=np.uint8)

    # Range to fill new array
    new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]

    # Range to sample from original image
    old_x = max(0, ul[0]), min(orig_shape[1], br[0])
    old_y = max(0, ul[1]), min(orig_shape[0], br[1])

    img = np.array(Image.fromarray(img.astype(np.uint8)).resize(crop_shape))

    new_img[old_y[0]:old_y[1],
            old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]

    return new_img


def rot_aa(aa, rot):
    """Rotate axis-angle parameters."""
    # pose parameters
    R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
                  [np.sin(np.deg2rad(-rot)),
                   np.cos(np.deg2rad(-rot)), 0], [0, 0, 1]])
    # find the rotation of the body in camera frame
    per_rdg, _ = cv2.Rodrigues(aa)
    # apply the global rotation to the global orientation
    resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
    aa = (resrot.T)[0]
    return aa


def flip_img(img):
    """Flip RGB images or masks.
    Channels come last, e.g. (256, 256, 3).
    """
    img = np.fliplr(img)
    return img


def flip_kp(kp, is_smpl=False):
    """Flip keypoints."""
    if len(kp) == 24:
        if is_smpl:
            flipped_parts = constants.SMPL_JOINTS_FLIP_PERM
        else:
            flipped_parts = constants.J24_FLIP_PERM
    elif len(kp) == 49:
        if is_smpl:
            flipped_parts = constants.SMPL_J49_FLIP_PERM
        else:
            flipped_parts = constants.J49_FLIP_PERM
    kp = kp[flipped_parts]
    kp[:, 0] = -kp[:, 0]
    return kp


def flip_pose(pose):
    """Flip pose.
    The flipping is based on SMPL parameters.
    """
    flipped_parts = constants.SMPL_POSE_FLIP_PERM
    pose = pose[flipped_parts]
    # we also negate the second and the third dimension of the axis-angle
    pose[1::3] = -pose[1::3]
    pose[2::3] = -pose[2::3]
    return pose


def normalize_2d_kp(kp_2d, crop_size=224, inv=False):
    # Normalize keypoints between -1, 1
    if not inv:
        ratio = 1.0 / crop_size
        kp_2d = 2.0 * kp_2d * ratio - 1.0
    else:
        ratio = 1.0 / crop_size
        kp_2d = (kp_2d + 1.0) / (2 * ratio)

    return kp_2d
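
# Example (illustrative): pixel coordinates in a 224 x 224 crop map to [-1, 1] and back.
#
#     kp = np.array([[0., 0.], [112., 112.], [224., 224.]])
#     normalize_2d_kp(kp, crop_size=224)              # -> [[-1, -1], [0, 0], [1, 1]]
#     normalize_2d_kp(normalize_2d_kp(kp), inv=True)  # -> recovers kp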


def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None):
    '''
    param joints: [num_joints, 3]
    param joints_vis: [num_joints, 3]
    return: target, target_weight(1: visible, 0: invisible)
    '''
    num_joints = joints.shape[0]
    device = joints.device
    cur_device = torch.device(device.type, device.index)
    if not hasattr(heatmap_size, '__len__'):
        # width, height
        heatmap_size = [heatmap_size, heatmap_size]
    assert len(heatmap_size) == 2
    target_weight = np.ones((num_joints, 1), dtype=np.float32)
    if joints_vis is not None:
        target_weight[:, 0] = joints_vis[:, 0]
    target = torch.zeros((num_joints, heatmap_size[1], heatmap_size[0]),
                         dtype=torch.float32,
                         device=cur_device)

    tmp_size = sigma * 3

    for joint_id in range(num_joints):
        mu_x = int(joints[joint_id][0] * heatmap_size[0] + 0.5)
        mu_y = int(joints[joint_id][1] * heatmap_size[1] + 0.5)
        # Check that any part of the gaussian is in-bounds
        ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
        br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
        if ul[0] >= heatmap_size[0] or ul[1] >= heatmap_size[1] \
                or br[0] < 0 or br[1] < 0:
            # If not, mark the joint as invisible and skip it
            target_weight[joint_id] = 0
            continue

        # Generate gaussian
        size = 2 * tmp_size + 1
        x = torch.arange(0, size, dtype=torch.float32, device=cur_device)
        y = x.unsqueeze(-1)
        x0 = y0 = size // 2
        # The gaussian is not normalized, we want the center value to equal 1
        g = torch.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))

        # Usable gaussian range
        g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0]
        g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1]
        # Image range
        img_x = max(0, ul[0]), min(br[0], heatmap_size[0])
        img_y = max(0, ul[1]), min(br[1], heatmap_size[1])

        v = target_weight[joint_id]
        if v > 0.5:
            target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
                g[g_y[0]:g_y[1], g_x[0]:g_x[1]]

    return target, target_weight
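
# Illustrative call (values are arbitrary): joints are expected as a tensor of
# normalized [0, 1] coordinates on the device the heatmap should live on,
# typically the GPU used by the rest of the pipeline.
#
#     joints = torch.tensor([[0.25, 0.25, 1.0], [0.75, 0.50, 1.0]], device="cuda")
#     target, target_weight = generate_heatmap(joints, heatmap_size=56, sigma=2)
#     # target: (2, 56, 56) torch tensor, target_weight: (2, 1) numpy array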